Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-20 21:06:54 +08:00)
Compare commits
Comparing 12979a3f21 ... nightly (202 commits)
| SHA1 | Author | Date | Message |
|---|---|---|---|
| 55c0468ac9 | |||
| eeb36a5ce7 | |||
| aceca266ff | |||
| d82e502a71 | |||
| 0494b92371 | |||
| 8683a5b1b7 | |||
| 4cbe470089 | |||
| 6cd1824a77 | |||
| 2844700dc4 | |||
| f8fd1ea7e1 | |||
| 57edc215d7 | |||
| 7a4044b05f | |||
| e84d5412bc | |||
| 151480dc85 | |||
| 2331b3a270 | |||
| 5cd1a678c8 | |||
| cc9546b761 | |||
| a63dcfed6f | |||
| 4dd8cdc38b | |||
| 1a4822d6be | |||
| ce161f09cc | |||
| 672958a192 | |||
| 3820de916c | |||
| ef44979b5c | |||
| d38f8a1562 | |||
| 8e4d011b15 | |||
| 7baa67dfe8 | |||
| e58271ef76 | |||
| 4fd4a41e7c | |||
| 82d4e5fb87 | |||
| d16643a53d | |||
| 93ca1e0b91 | |||
| 4046bffaf1 | |||
| 03f9be7cbb | |||
| 5e05f43c3d | |||
| 205a6483f5 | |||
| 2595644dfd | |||
| 30019dab9f | |||
| 4d46726eb7 | |||
| 0e8b9588ba | |||
| 344a106eba | |||
| bccad7b4a8 | |||
| f7926724aa | |||
| 5bba562048 | |||
| 49c74d08e8 | |||
| 1112b6291b | |||
| ef5d1d4b74 | |||
| a98887d4ca | |||
| 7ca3e11566 | |||
| a2e080c2d3 | |||
| ad6f7fd4b0 | |||
| 2a0f835ffe | |||
| 13d8241eee | |||
| 1ddd11f045 | |||
| 81eb03d230 | |||
| 7d23c3aed0 | |||
| 6be0338aa0 | |||
| 44dec89f1f | |||
| 2b260901df | |||
| 948bc93786 | |||
| 0f0fb53256 | |||
| 0fcb1680fd | |||
| 50715ba332 | |||
| f9510edbbc | |||
| 6560388f2b | |||
| e37aea5f81 | |||
| 7db9045b74 | |||
| a6bd765a02 | |||
| 74afb8d710 | |||
| ea4a5cd665 | |||
| 22a51a3868 | |||
| e9710b7aa9 | |||
| bd0eff2954 | |||
| e3cfe8e848 | |||
| c610bb605a | |||
| a6afb7dfe2 | |||
| 7b96113d4c | |||
| 8370bc61b7 | |||
| 74eb894453 | |||
| 34d29d7e8b | |||
| badf33e3b9 | |||
| 3cb72377d7 | |||
| ab4b62031f | |||
| 80f3ccf1ac | |||
| a1164b9c89 | |||
| fd7e55b23d | |||
| f128a1fa9e | |||
| 65a5a56d95 | |||
| ca2d6f3301 | |||
| a94b3b9df2 | |||
| 30377319d8 | |||
| 07dca37ef0 | |||
| 036b29f084 | |||
| 9863862348 | |||
| bb6022477e | |||
| 28bc87c5e2 | |||
| c51e6b2a58 | |||
| 481192300d | |||
| 1777620ea5 | |||
| f3a03b06b2 | |||
| dd046be976 | |||
| 5c9672a265 | |||
| 09a3854ed8 | |||
| 43f51baa96 | |||
| 5a2011e687 | |||
| 7dd9ce0b5f | |||
| b66881a371 | |||
| 4d7934061e | |||
| 660fa8888b | |||
| 3285f09c92 | |||
| 51ec708c58 | |||
| 9b8971a9de | |||
| 6546f86b4e | |||
| 8de6b97806 | |||
| e4e0a88053 | |||
| 7719fd6350 | |||
| 15ef6dd72f | |||
| 5b5f19cbc1 | |||
| ea38e12d42 | |||
| 885eb2eab9 | |||
| 6587acef88 | |||
| ad03ede7cd | |||
| 468e4042c2 | |||
| af1344033d | |||
| 4012d65b3c | |||
| e2bc1a3478 | |||
| 6c2c447a72 | |||
| e7022db9a4 | |||
| ca4a0ee1b2 | |||
| 27b0550876 | |||
| 797e03f843 | |||
| b4e06237ef | |||
| 751a13fb64 | |||
| fa7b857aa9 | |||
| 257af75ece | |||
| cbdacf21f6 | |||
| b1f3130519 | |||
| 3c224c817b | |||
| a3c9402218 | |||
| a7d40e9132 | |||
| 648342b62f | |||
| 4870d42949 | |||
| caaf7043cc | |||
| 237a66913b | |||
| 3c50c7d3ac | |||
| b44e65a12e | |||
| e3f40db963 | |||
| b5ad7b7062 | |||
| 6fc7def562 | |||
| c8f608b2dd | |||
| 5c81e01de5 | |||
| 83fac6d0a0 | |||
| a6681d6366 | |||
| 1388c4420d | |||
| 962bd5f5df | |||
| 627c11c429 | |||
| 4ba17361e9 | |||
| c946858328 | |||
| ba6e2af5fd | |||
| 2ffe6f7439 | |||
| e3987e21b9 | |||
| a713f54732 | |||
| 519f03097e | |||
| 299c655e39 | |||
| b8c0fb4572 | |||
| d1e172171f | |||
| 81ae6cf78d | |||
| 1120575021 | |||
| 221947acc4 | |||
| 21d8ffca56 | |||
| 41cff3e09e | |||
| b6c4722687 | |||
| 6ea4248bdc | |||
| 88a28212b3 | |||
| 9d0309aedc | |||
| 9a8ce9d3e2 | |||
| 7499608a8b | |||
| 0ebbb60102 | |||
| 80f6d22d2a | |||
| 088b049b4c | |||
| fa9b7b259c | |||
| 14616cf845 | |||
| d2915f6984 | |||
| ccce8beeeb | |||
| 3d2e0f1a1b | |||
| 918d5a9ff8 | |||
| 7d05d4ced7 | |||
| dbdda0fbab | |||
| cf7fdd274b | |||
| 982ed233a2 | |||
| 1f96c95b42 | |||
| 8604c4f57c | |||
| a674338c21 | |||
| 89d82ff031 | |||
| c71d25f744 | |||
| f57f32cf3a | |||
| b6314164c5 | |||
| 856201c0f2 | |||
| 9d8b96c1d0 | |||
| 7c3c185038 | |||
| a9259917c6 | |||
| 8c28587821 |
.github/copilot-instructions.md (vendored, new file, 1 line added)

@@ -0,0 +1 @@
+Refer to [AGENTS.MD](../AGENTS.md) for all repo instructions.
.github/workflows/release.yml (vendored, 21 lines changed)

@@ -3,11 +3,12 @@ name: release
 on:
 schedule:
 - cron: '0 13 * * *' # This schedule runs every 13:00:00Z(21:00:00+08:00)
+# https://github.com/orgs/community/discussions/26286?utm_source=chatgpt.com#discussioncomment-3251208
+# "The create event does not support branch filter and tag filter."
 # The "create tags" trigger is specifically focused on the creation of new tags, while the "push tags" trigger is activated when tags are pushed, including both new tag creations and updates to existing tags.
-create:
+push:
 tags:
 - "v*.*.*" # normal release
-- "nightly" # the only one mutable tag

 # https://docs.github.com/en/actions/using-jobs/using-concurrency
 concurrency:
@@ -21,9 +22,9 @@ jobs:
 - name: Ensure workspace ownership
 run: echo "chown -R ${USER} ${GITHUB_WORKSPACE}" && sudo chown -R ${USER} ${GITHUB_WORKSPACE}

-# https://github.com/actions/checkout/blob/v3/README.md
+# https://github.com/actions/checkout/blob/v6/README.md
 - name: Check out code
-uses: actions/checkout@v4
+uses: actions/checkout@v6
 with:
 token: ${{ secrets.GITHUB_TOKEN }} # Use the secret as an environment variable
 fetch-depth: 0
@@ -31,12 +32,12 @@ jobs:

 - name: Prepare release body
 run: |
-if [[ ${GITHUB_EVENT_NAME} == "create" ]]; then
+if [[ ${GITHUB_EVENT_NAME} != "schedule" ]]; then
 RELEASE_TAG=${GITHUB_REF#refs/tags/}
-if [[ ${RELEASE_TAG} == "nightly" ]]; then
+if [[ ${RELEASE_TAG} == v* ]]; then
-PRERELEASE=true
-else
 PRERELEASE=false
+else
+PRERELEASE=true
 fi
 echo "Workflow triggered by create tag: ${RELEASE_TAG}"
 else
@@ -55,7 +56,7 @@ jobs:
 git fetch --tags
 if [[ ${GITHUB_EVENT_NAME} == "schedule" ]]; then
 # Determine if a given tag exists and matches a specific Git commit.
-# actions/checkout@v4 fetch-tags doesn't work when triggered by schedule
+# actions/checkout@v6 fetch-tags doesn't work when triggered by schedule
 if [ "$(git rev-parse -q --verify "refs/tags/${RELEASE_TAG}")" = "${GITHUB_SHA}" ]; then
 echo "mutable tag ${RELEASE_TAG} exists and matches ${GITHUB_SHA}"
 else
@@ -88,7 +89,7 @@ jobs:
 - name: Build and push image
 run: |
 sudo docker login --username infiniflow --password-stdin <<< ${{ secrets.DOCKERHUB_TOKEN }}
-sudo docker build --build-arg NEED_MIRROR=1 -t infiniflow/ragflow:${RELEASE_TAG} -f Dockerfile .
+sudo docker build --build-arg NEED_MIRROR=1 --build-arg HTTPS_PROXY=${HTTPS_PROXY} --build-arg HTTP_PROXY=${HTTP_PROXY} -t infiniflow/ragflow:${RELEASE_TAG} -f Dockerfile .
 sudo docker tag infiniflow/ragflow:${RELEASE_TAG} infiniflow/ragflow:latest
 sudo docker push infiniflow/ragflow:${RELEASE_TAG}
 sudo docker push infiniflow/ragflow:latest
.github/workflows/tests.yml (vendored, 29 lines changed)

@@ -1,4 +1,6 @@
 name: tests
+permissions:
+contents: read

 on:
 push:
@@ -12,7 +14,7 @@ on:
 # The only difference between pull_request and pull_request_target is the context in which the workflow runs:
 # — pull_request_target workflows use the workflow files from the default branch, and secrets are available.
 # — pull_request workflows use the workflow files from the pull request branch, and secrets are unavailable.
-pull_request_target:
+pull_request:
 types: [ synchronize, ready_for_review ]
 paths-ignore:
 - 'docs/**'
@@ -31,12 +33,9 @@ jobs:
 name: ragflow_tests
 # https://docs.github.com/en/actions/using-jobs/using-conditions-to-control-job-execution
 # https://github.com/orgs/community/discussions/26261
-if: ${{ github.event_name != 'pull_request_target' || (contains(github.event.pull_request.labels.*.name, 'ci') && github.event.pull_request.mergeable == true) }}
+if: ${{ github.event_name != 'pull_request' || (github.event.pull_request.draft == false && contains(github.event.pull_request.labels.*.name, 'ci')) }}
 runs-on: [ "self-hosted", "ragflow-test" ]
 steps:
-# https://github.com/hmarr/debug-action
-#- uses: hmarr/debug-action@v2
-
 - name: Ensure workspace ownership
 run: |
 echo "Workflow triggered by ${{ github.event_name }}"
@@ -44,7 +43,7 @@ jobs:

 # https://github.com/actions/checkout/issues/1781
 - name: Check out code
-uses: actions/checkout@v4
+uses: actions/checkout@v6
 with:
 ref: ${{ (github.event_name == 'pull_request' || github.event_name == 'pull_request_target') && format('refs/pull/{0}/merge', github.event.pull_request.number) || github.sha }}
 fetch-depth: 0
@@ -53,7 +52,7 @@ jobs:
 - name: Check workflow duplication
 if: ${{ !cancelled() && !failure() }}
 run: |
-if [[ ${GITHUB_EVENT_NAME} != "pull_request_target" && ${GITHUB_EVENT_NAME} != "schedule" ]]; then
+if [[ ${GITHUB_EVENT_NAME} != "pull_request" && ${GITHUB_EVENT_NAME} != "schedule" ]]; then
 HEAD=$(git rev-parse HEAD)
 # Find a PR that introduced a given commit
 gh auth login --with-token <<< "${{ secrets.GITHUB_TOKEN }}"
@@ -78,7 +77,7 @@ jobs:
 fi
 fi
 fi
-elif [[ ${GITHUB_EVENT_NAME} == "pull_request_target" ]]; then
+elif [[ ${GITHUB_EVENT_NAME} == "pull_request" ]]; then
 PR_NUMBER=${{ github.event.pull_request.number }}
 PR_SHA_FP=${RUNNER_WORKSPACE_PREFIX}/artifacts/${GITHUB_REPOSITORY}/PR_${PR_NUMBER}
 # Calculate the hash of the current workspace content
@@ -98,7 +97,7 @@ jobs:
 - name: Check comments of changed Python files
 if: ${{ false }}
 run: |
-if [[ ${{ github.event_name }} == 'pull_request_target' ]]; then
+if [[ ${{ github.event_name }} == 'pull_request' || ${{ github.event_name }} == 'pull_request_target' ]]; then
 CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}...${{ github.event.pull_request.head.sha }} \
 | grep -E '\.(py)$' || true)

@@ -127,13 +126,21 @@ jobs:
 fi
 fi

+- name: Run unit test
+run: |
+uv sync --python 3.12 --group test --frozen
+source .venv/bin/activate
+which pytest || echo "pytest not in PATH"
+echo "Start to run unit test"
+python3 run_tests.py
+
 - name: Build ragflow:nightly
 run: |
 RUNNER_WORKSPACE_PREFIX=${RUNNER_WORKSPACE_PREFIX:-${HOME}}
 RAGFLOW_IMAGE=infiniflow/ragflow:${GITHUB_RUN_ID}
 echo "RAGFLOW_IMAGE=${RAGFLOW_IMAGE}" >> ${GITHUB_ENV}
 sudo docker pull ubuntu:22.04
-sudo DOCKER_BUILDKIT=1 docker build --build-arg NEED_MIRROR=1 -f Dockerfile -t ${RAGFLOW_IMAGE} .
+sudo DOCKER_BUILDKIT=1 docker build --build-arg NEED_MIRROR=1 --build-arg HTTPS_PROXY=${HTTPS_PROXY} --build-arg HTTP_PROXY=${HTTP_PROXY} -f Dockerfile -t ${RAGFLOW_IMAGE} .
 if [[ ${GITHUB_EVENT_NAME} == "schedule" ]]; then
 export HTTP_API_TEST_LEVEL=p3
 else
@@ -193,7 +200,7 @@ jobs:
 echo "HOST_ADDRESS=http://host.docker.internal:${SVR_HTTP_PORT}" >> ${GITHUB_ENV}

 sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} up -d
-uv sync --python 3.10 --only-group test --no-default-groups --frozen && uv pip install sdk/python --group test
+uv sync --python 3.12 --only-group test --no-default-groups --frozen && uv pip install sdk/python --group test

 - name: Run sdk tests against Elasticsearch
 run: |
.gitignore (vendored, 3 lines changed)

@@ -195,3 +195,6 @@ ragflow_cli.egg-info

 # Default backup dir
 backup
+
+
+.hypothesis
AGENTS.md (new file, 110 lines added)

@@ -0,0 +1,110 @@
+# RAGFlow Project Instructions for GitHub Copilot
+
+This file provides context, build instructions, and coding standards for the RAGFlow project.
+It is structured to follow GitHub Copilot's [customization guidelines](https://docs.github.com/en/copilot/concepts/prompting/response-customization).
+
+## 1. Project Overview
+RAGFlow is an open-source RAG (Retrieval-Augmented Generation) engine based on deep document understanding. It is a full-stack application with a Python backend and a React/TypeScript frontend.
+
+- **Backend**: Python 3.10+ (Flask/Quart)
+- **Frontend**: TypeScript, React, UmiJS
+- **Architecture**: Microservices based on Docker.
+- `api/`: Backend API server.
+- `rag/`: Core RAG logic (indexing, retrieval).
+- `deepdoc/`: Document parsing and OCR.
+- `web/`: Frontend application.
+
+## 2. Directory Structure
+- `api/`: Backend API server (Flask/Quart).
+- `apps/`: API Blueprints (Knowledge Base, Chat, etc.).
+- `db/`: Database models and services.
+- `rag/`: Core RAG logic.
+- `llm/`: LLM, Embedding, and Rerank model abstractions.
+- `deepdoc/`: Document parsing and OCR modules.
+- `agent/`: Agentic reasoning components.
+- `web/`: Frontend application (React + UmiJS).
+- `docker/`: Docker deployment configurations.
+- `sdk/`: Python SDK.
+- `test/`: Backend tests.
+
+## 3. Build Instructions
+
+### Backend (Python)
+The project uses **uv** for dependency management.
+
+1. **Setup Environment**:
+```bash
+uv sync --python 3.12 --all-extras
+uv run download_deps.py
+```
+
+2. **Run Server**:
+- **Pre-requisite**: Start dependent services (MySQL, ES/Infinity, Redis, MinIO).
+```bash
+docker compose -f docker/docker-compose-base.yml up -d
+```
+- **Launch**:
+```bash
+source .venv/bin/activate
+export PYTHONPATH=$(pwd)
+bash docker/launch_backend_service.sh
+```
+
+### Frontend (TypeScript/React)
+Located in `web/`.
+
+1. **Install Dependencies**:
+```bash
+cd web
+npm install
+```
+
+2. **Run Dev Server**:
+```bash
+npm run dev
+```
+Runs on port 8000 by default.
+
+### Docker Deployment
+To run the full stack using Docker:
+```bash
+cd docker
+docker compose -f docker-compose.yml up -d
+```
+
+## 4. Testing Instructions
+
+### Backend Tests
+- **Run All Tests**:
+```bash
+uv run pytest
+```
+- **Run Specific Test**:
+```bash
+uv run pytest test/test_api.py
+```
+
+### Frontend Tests
+- **Run Tests**:
+```bash
+cd web
+npm run test
+```
+
+## 5. Coding Standards & Guidelines
+- **Python Formatting**: Use `ruff` for linting and formatting.
+```bash
+ruff check
+ruff format
+```
+- **Frontend Linting**:
+```bash
+cd web
+npm run lint
+```
+- **Pre-commit**: Ensure pre-commit hooks are installed.
+```bash
+pre-commit install
+pre-commit run --all-files
+```
+
@@ -45,7 +45,7 @@ RAGFlow is an open-source RAG (Retrieval-Augmented Generation) engine based on d
 ### Backend Development
 ```bash
 # Install Python dependencies
-uv sync --python 3.10 --all-extras
+uv sync --python 3.12 --all-extras
 uv run download_deps.py
 pre-commit install

Dockerfile (42 lines changed)

@@ -1,5 +1,5 @@
 # base stage
-FROM ubuntu:22.04 AS base
+FROM ubuntu:24.04 AS base
 USER root
 SHELL ["/bin/bash", "-c"]

@@ -10,7 +10,6 @@ WORKDIR /ragflow
 # Copy models downloaded via download_deps.py
 RUN mkdir -p /ragflow/rag/res/deepdoc /root/.ragflow
 RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/huggingface.co,target=/huggingface.co \
-cp /huggingface.co/InfiniFlow/huqie/huqie.txt.trie /ragflow/rag/res/ && \
 tar --exclude='.*' -cf - \
 /huggingface.co/InfiniFlow/text_concat_xgb_v1.0 \
 /huggingface.co/InfiniFlow/deepdoc \
@@ -34,36 +33,41 @@ ENV DEBIAN_FRONTEND=noninteractive
 # selenium: libatk-bridge2.0-0 chrome-linux64-121-0-6167-85
 # Building C extensions: libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev
 RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
+apt update && \
+apt --no-install-recommends install -y ca-certificates; \
 if [ "$NEED_MIRROR" == "1" ]; then \
-sed -i 's|http://ports.ubuntu.com|http://mirrors.tuna.tsinghua.edu.cn|g' /etc/apt/sources.list; \
+sed -i 's|http://archive.ubuntu.com/ubuntu|https://mirrors.tuna.tsinghua.edu.cn/ubuntu|g' /etc/apt/sources.list.d/ubuntu.sources; \
-sed -i 's|http://archive.ubuntu.com|http://mirrors.tuna.tsinghua.edu.cn|g' /etc/apt/sources.list; \
+sed -i 's|http://security.ubuntu.com/ubuntu|https://mirrors.tuna.tsinghua.edu.cn/ubuntu|g' /etc/apt/sources.list.d/ubuntu.sources; \
 fi; \
 rm -f /etc/apt/apt.conf.d/docker-clean && \
 echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache && \
 chmod 1777 /tmp && \
 apt update && \
-apt --no-install-recommends install -y ca-certificates && \
-apt update && \
 apt install -y libglib2.0-0 libglx-mesa0 libgl1 && \
 apt install -y pkg-config libicu-dev libgdiplus && \
 apt install -y default-jdk && \
 apt install -y libatk-bridge2.0-0 && \
 apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \
 apt install -y libjemalloc-dev && \
-apt install -y python3-pip pipx nginx unzip curl wget git vim less && \
+apt install -y nginx unzip curl wget git vim less && \
 apt install -y ghostscript && \
 apt install -y pandoc && \
-apt install -y texlive
+apt install -y texlive && \
+apt install -y fonts-freefont-ttf fonts-noto-cjk

-RUN if [ "$NEED_MIRROR" == "1" ]; then \
+# Install uv
-pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \
+RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/,target=/deps \
-pip3 config set global.trusted-host pypi.tuna.tsinghua.edu.cn; \
+if [ "$NEED_MIRROR" == "1" ]; then \
 mkdir -p /etc/uv && \
-echo "[[index]]" > /etc/uv/uv.toml && \
+echo 'python-install-mirror = "https://registry.npmmirror.com/-/binary/python-build-standalone/"' > /etc/uv/uv.toml && \
+echo '[[index]]' >> /etc/uv/uv.toml && \
 echo 'url = "https://pypi.tuna.tsinghua.edu.cn/simple"' >> /etc/uv/uv.toml && \
-echo "default = true" >> /etc/uv/uv.toml; \
+echo 'default = true' >> /etc/uv/uv.toml; \
 fi; \
-pipx install uv
+tar xzf /deps/uv-x86_64-unknown-linux-gnu.tar.gz \
+&& cp uv-x86_64-unknown-linux-gnu/* /usr/local/bin/ \
+&& rm -rf uv-x86_64-unknown-linux-gnu \
+&& uv python install 3.11

 ENV PYTHONDONTWRITEBYTECODE=1 DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=1
 ENV PATH=/root/.local/bin:$PATH
@@ -79,12 +83,12 @@ RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
 # A modern version of cargo is needed for the latest version of the Rust compiler.
 RUN apt update && apt install -y curl build-essential \
 && if [ "$NEED_MIRROR" == "1" ]; then \
-# Use TUNA mirrors for rustup/rust dist files
+# Use TUNA mirrors for rustup/rust dist files \
 export RUSTUP_DIST_SERVER="https://mirrors.tuna.tsinghua.edu.cn/rustup"; \
 export RUSTUP_UPDATE_ROOT="https://mirrors.tuna.tsinghua.edu.cn/rustup/rustup"; \
 echo "Using TUNA mirrors for Rustup."; \
 fi; \
-# Force curl to use HTTP/1.1
+# Force curl to use HTTP/1.1 \
 curl --proto '=https' --tlsv1.2 --http1.1 -sSf https://sh.rustup.rs | bash -s -- -y --profile minimal \
 && echo 'export PATH="/root/.cargo/bin:${PATH}"' >> /root/.bashrc

@@ -101,10 +105,10 @@ RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
 apt update && \
 arch="$(uname -m)"; \
 if [ "$arch" = "arm64" ] || [ "$arch" = "aarch64" ]; then \
-# ARM64 (macOS/Apple Silicon or Linux aarch64)
+# ARM64 (macOS/Apple Silicon or Linux aarch64) \
 ACCEPT_EULA=Y apt install -y unixodbc-dev msodbcsql18; \
 else \
-# x86_64 or others
+# x86_64 or others \
 ACCEPT_EULA=Y apt install -y unixodbc-dev msodbcsql17; \
 fi || \
 { echo "Failed to install ODBC driver"; exit 1; }
@@ -148,7 +152,7 @@ RUN --mount=type=cache,id=ragflow_uv,target=/root/.cache/uv,sharing=locked \
 else \
 sed -i 's|pypi.tuna.tsinghua.edu.cn|pypi.org|g' uv.lock; \
 fi; \
-uv sync --python 3.10 --frozen
+uv sync --python 3.12 --frozen

 COPY web web
 COPY docs docs
@@ -3,7 +3,7 @@
 FROM scratch

 # Copy resources downloaded via download_deps.py
-COPY chromedriver-linux64-121-0-6167-85 chrome-linux64-121-0-6167-85 cl100k_base.tiktoken libssl1.1_1.1.1f-1ubuntu2_amd64.deb libssl1.1_1.1.1f-1ubuntu2_arm64.deb tika-server-standard-3.0.0.jar tika-server-standard-3.0.0.jar.md5 libssl*.deb /
+COPY chromedriver-linux64-121-0-6167-85 chrome-linux64-121-0-6167-85 cl100k_base.tiktoken libssl1.1_1.1.1f-1ubuntu2_amd64.deb libssl1.1_1.1.1f-1ubuntu2_arm64.deb tika-server-standard-3.0.0.jar tika-server-standard-3.0.0.jar.md5 libssl*.deb uv-x86_64-unknown-linux-gnu.tar.gz /

 COPY nltk_data /nltk_data

@@ -194,7 +194,7 @@ releases! 🌟

 # git checkout v0.22.1
 # Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases)
-# This steps ensures the **entrypoint.sh** file in the code matches the Docker image version.
+# This step ensures the **entrypoint.sh** file in the code matches the Docker image version.

 # Use CPU for DeepDoc tasks:
 $ docker compose -f docker-compose.yml up -d
@@ -207,7 +207,7 @@ releases! 🌟
 > Note: Prior to `v0.22.0`, we provided both images with embedding models and slim images without embedding models. Details as follows:

 | RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
-| ----------------- | --------------- | --------------------- | ------------------------ |
+|-------------------|-----------------|-----------------------|----------------|
 | v0.21.1 | ≈9 | ✔️ | Stable release |
 | v0.21.1-slim | ≈2 | ❌ | Stable release |

@@ -314,7 +314,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly
 ```bash
 git clone https://github.com/infiniflow/ragflow.git
 cd ragflow/
-uv sync --python 3.10 # install RAGFlow dependent python modules
+uv sync --python 3.12 # install RAGFlow dependent python modules
 uv run download_deps.py
 pre-commit install
 ```
@@ -207,7 +207,7 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io).
 > Catatan: Sebelum `v0.22.0`, kami menyediakan image dengan model embedding dan image slim tanpa model embedding. Detailnya sebagai berikut:

 | RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
-| ----------------- | --------------- | --------------------- | ------------------------ |
+|-------------------|-----------------|-----------------------|----------------|
 | v0.21.1 | ≈9 | ✔️ | Stable release |
 | v0.21.1-slim | ≈2 | ❌ | Stable release |

@@ -288,7 +288,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly
 ```bash
 git clone https://github.com/infiniflow/ragflow.git
 cd ragflow/
-uv sync --python 3.10 # install RAGFlow dependent python modules
+uv sync --python 3.12 # install RAGFlow dependent python modules
 uv run download_deps.py
 pre-commit install
 ```
@@ -187,7 +187,7 @@
 > 注意:`v0.22.0` より前のバージョンでは、embedding モデルを含むイメージと、embedding モデルを含まない slim イメージの両方を提供していました。詳細は以下の通りです:

 | RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
-| ----------------- | --------------- | --------------------- | ------------------------ |
+|-------------------|-----------------|-----------------------|----------------|
 | v0.21.1 | ≈9 | ✔️ | Stable release |
 | v0.21.1-slim | ≈2 | ❌ | Stable release |

@@ -288,7 +288,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly
 ```bash
 git clone https://github.com/infiniflow/ragflow.git
 cd ragflow/
-uv sync --python 3.10 # install RAGFlow dependent python modules
+uv sync --python 3.12 # install RAGFlow dependent python modules
 uv run download_deps.py
 pre-commit install
 ```
@@ -189,7 +189,7 @@
 > 참고: `v0.22.0` 이전 버전에서는 embedding 모델이 포함된 이미지와 embedding 모델이 포함되지 않은 slim 이미지를 모두 제공했습니다. 자세한 내용은 다음과 같습니다:

 | RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
-| ----------------- | --------------- | --------------------- | ------------------------ |
+|-------------------|-----------------|-----------------------|----------------|
 | v0.21.1 | ≈9 | ✔️ | Stable release |
 | v0.21.1-slim | ≈2 | ❌ | Stable release |

@@ -283,7 +283,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly
 ```bash
 git clone https://github.com/infiniflow/ragflow.git
 cd ragflow/
-uv sync --python 3.10 # install RAGFlow dependent python modules
+uv sync --python 3.12 # install RAGFlow dependent python modules
 uv run download_deps.py
 pre-commit install
 ```
@@ -207,7 +207,7 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io).
 > Nota: Antes da `v0.22.0`, fornecíamos imagens com modelos de embedding e imagens slim sem modelos de embedding. Detalhes a seguir:

 | RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
-| ----------------- | --------------- | --------------------- | ------------------------ |
+|-------------------|-----------------|-----------------------|----------------|
 | v0.21.1 | ≈9 | ✔️ | Stable release |
 | v0.21.1-slim | ≈2 | ❌ | Stable release |

@@ -305,7 +305,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly
 ```bash
 git clone https://github.com/infiniflow/ragflow.git
 cd ragflow/
-uv sync --python 3.10 # instala os módulos Python dependentes do RAGFlow
+uv sync --python 3.12 # instala os módulos Python dependentes do RAGFlow
 uv run download_deps.py
 pre-commit install
 ```
@@ -206,7 +206,7 @@
 > 注意:在 `v0.22.0` 之前的版本,我們會同時提供包含 embedding 模型的映像和不含 embedding 模型的 slim 映像。具體如下:

 | RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
-| ----------------- | --------------- | --------------------- | ------------------------ |
+|-------------------|-----------------|-----------------------|----------------|
 | v0.21.1 | ≈9 | ✔️ | Stable release |
 | v0.21.1-slim | ≈2 | ❌ | Stable release |

@@ -315,7 +315,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly
 ```bash
 git clone https://github.com/infiniflow/ragflow.git
 cd ragflow/
-uv sync --python 3.10 # install RAGFlow dependent python modules
+uv sync --python 3.12 # install RAGFlow dependent python modules
 uv run download_deps.py
 pre-commit install
 ```
@@ -207,7 +207,7 @@
 > 注意:在 `v0.22.0` 之前的版本,我们会同时提供包含 embedding 模型的镜像和不含 embedding 模型的 slim 镜像。具体如下:

 | RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
-| ----------------- | --------------- | --------------------- | ------------------------ |
+|-------------------|-----------------|-----------------------|----------------|
 | v0.21.1 | ≈9 | ✔️ | Stable release |
 | v0.21.1-slim | ≈2 | ❌ | Stable release |

@@ -315,7 +315,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly
 ```bash
 git clone https://github.com/infiniflow/ragflow.git
 cd ragflow/
-uv sync --python 3.10 # install RAGFlow dependent python modules
+uv sync --python 3.12 # install RAGFlow dependent python modules
 uv run download_deps.py
 pre-commit install
 ```
@@ -6,7 +6,7 @@ Use this section to tell people about which versions of your project are
 currently being supported with security updates.

 | Version | Supported |
-| ------- | ------------------ |
+|---------|--------------------|
 | <=0.7.0 | :white_check_mark: |

 ## Reporting a Vulnerability
@@ -351,7 +351,7 @@ class AdminCLI(Cmd):
 def verify_admin(self, arguments: dict, single_command: bool):
 self.host = arguments['host']
 self.port = arguments['port']
-print(f"Attempt to access ip: {self.host}, port: {self.port}")
+print("Attempt to access server for admin login")
 url = f"http://{self.host}:{self.port}/api/v1/admin/login"

 attempt_count = 3
@@ -390,7 +390,7 @@ class AdminCLI(Cmd):
 print(f"Bad response,status: {response.status_code}, password is wrong")
 except Exception as e:
 print(str(e))
-print(f"Can't access {self.host}, port: {self.port}")
+print("Can't access server for admin login (connection failed)")

 def _format_service_detail_table(self, data):
 if isinstance(data, list):
@@ -674,7 +674,7 @@ class AdminCLI(Cmd):
 user_name: str = user_name_tree.children[0].strip("'\"")
 password_tree: Tree = command['password']
 password: str = password_tree.children[0].strip("'\"")
-print(f"Alter user: {user_name}, password: {password}")
+print(f"Alter user: {user_name}, password: ******")
 url = f'http://{self.host}:{self.port}/api/v1/admin/users/{user_name}/password'
 response = self.session.put(url, json={'new_password': encrypt(password)})
 res_json = response.json()
@@ -689,7 +689,7 @@ class AdminCLI(Cmd):
 password_tree: Tree = command['password']
 password: str = password_tree.children[0].strip("'\"")
 role: str = command['role']
-print(f"Create user: {user_name}, password: {password}, role: {role}")
+print(f"Create user: {user_name}, password: ******, role: {role}")
 url = f'http://{self.host}:{self.port}/api/v1/admin/users'
 response = self.session.post(
 url,
@@ -951,7 +951,7 @@ def main():

 args = cli.parse_connection_args(sys.argv)
 if 'error' in args:
-print(f"Error: {args['error']}")
+print("Error: Invalid connection arguments")
 return

 if 'command' in args:
@@ -960,7 +960,7 @@ def main():
 return
 if cli.verify_admin(args, single_command=True):
 command: str = args['command']
-print(f"Run single command: {command}")
+# print(f"Run single command: {command}")
 cli.run_single_command(command)
 else:
 if cli.verify_admin(args, single_command=False):
@@ -5,7 +5,7 @@ description = "Admin Service's client of [RAGFlow](https://github.com/infiniflow
 authors = [{ name = "Lynn", email = "lynn_inf@hotmail.com" }]
 license = { text = "Apache License, Version 2.0" }
 readme = "README.md"
-requires-python = ">=3.10,<3.13"
+requires-python = ">=3.12,<3.15"
 dependencies = [
 "requests>=2.30.0,<3.0.0",
 "beartype>=0.20.0,<1.0.0",
@@ -176,11 +176,11 @@ def login_verify(f):
 "message": "Access denied",
 "data": None
 }), 200
-except Exception as e:
+except Exception:
-error_msg = str(e)
+logging.exception("An error occurred during admin login verification.")
 return jsonify({
 "code": 500,
-"message": error_msg
+"message": "An internal server error occurred."
 }), 200

 return f(*args, **kwargs)
@@ -13,6 +13,3 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
-from beartype.claw import beartype_this_package
-beartype_this_package()
221
agent/canvas.py
221
agent/canvas.py
@ -13,7 +13,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
|
import inspect
|
||||||
|
import binascii
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
@ -25,7 +28,10 @@ from typing import Any, Union, Tuple
|
|||||||
|
|
||||||
from agent.component import component_class
|
from agent.component import component_class
|
||||||
from agent.component.base import ComponentBase
|
from agent.component.base import ComponentBase
|
||||||
|
from api.db.services.file_service import FileService
|
||||||
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.db.services.task_service import has_canceled
|
from api.db.services.task_service import has_canceled
|
||||||
|
from common.constants import LLMType
|
||||||
from common.misc_utils import get_uuid, hash_str2int
|
from common.misc_utils import get_uuid, hash_str2int
|
||||||
from common.exceptions import TaskCanceledException
|
from common.exceptions import TaskCanceledException
|
||||||
from rag.prompts.generator import chunks_format
|
from rag.prompts.generator import chunks_format
|
||||||
@ -79,14 +85,12 @@ class Graph:
|
|||||||
self.dsl = json.loads(dsl)
|
self.dsl = json.loads(dsl)
|
||||||
self._tenant_id = tenant_id
|
self._tenant_id = tenant_id
|
||||||
self.task_id = task_id if task_id else get_uuid()
|
self.task_id = task_id if task_id else get_uuid()
|
||||||
|
self._thread_pool = ThreadPoolExecutor(max_workers=5)
|
||||||
self.load()
|
self.load()
|
||||||
|
|
||||||
def load(self):
|
def load(self):
|
||||||
self.components = self.dsl["components"]
|
self.components = self.dsl["components"]
|
||||||
cpn_nms = set([])
|
cpn_nms = set([])
|
||||||
for k, cpn in self.components.items():
|
|
||||||
cpn_nms.add(cpn["obj"]["component_name"])
|
|
||||||
|
|
||||||
for k, cpn in self.components.items():
|
for k, cpn in self.components.items():
|
||||||
cpn_nms.add(cpn["obj"]["component_name"])
|
cpn_nms.add(cpn["obj"]["component_name"])
|
||||||
param = component_class(cpn["obj"]["component_name"] + "Param")()
|
param = component_class(cpn["obj"]["component_name"] + "Param")()
|
||||||
@ -156,7 +160,7 @@ class Graph:
|
|||||||
return self._tenant_id
|
return self._tenant_id
|
||||||
|
|
||||||
def get_value_with_variable(self,value: str) -> Any:
|
def get_value_with_variable(self,value: str) -> Any:
|
||||||
pat = re.compile(r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*")
|
pat = re.compile(r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.-]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*")
|
||||||
out_parts = []
|
out_parts = []
|
||||||
last = 0
|
last = 0
|
||||||
|
|
||||||
@ -281,6 +285,7 @@ class Canvas(Graph):
|
|||||||
"sys.conversation_turns": 0,
|
"sys.conversation_turns": 0,
|
||||||
"sys.files": []
|
"sys.files": []
|
||||||
}
|
}
|
||||||
|
self.variables = {}
|
||||||
super().__init__(dsl, tenant_id, task_id)
|
super().__init__(dsl, tenant_id, task_id)
|
||||||
|
|
||||||
def load(self):
|
def load(self):
|
||||||
@ -295,6 +300,10 @@ class Canvas(Graph):
|
|||||||
"sys.conversation_turns": 0,
|
"sys.conversation_turns": 0,
|
||||||
"sys.files": []
|
"sys.files": []
|
||||||
}
|
}
|
||||||
|
if "variables" in self.dsl:
|
||||||
|
self.variables = self.dsl["variables"]
|
||||||
|
else:
|
||||||
|
self.variables = {}
|
||||||
|
|
||||||
self.retrieval = self.dsl["retrieval"]
|
self.retrieval = self.dsl["retrieval"]
|
||||||
self.memory = self.dsl.get("memory", [])
|
self.memory = self.dsl.get("memory", [])
|
||||||
@ -311,8 +320,9 @@ class Canvas(Graph):
|
|||||||
self.history = []
|
self.history = []
|
||||||
self.retrieval = []
|
self.retrieval = []
|
||||||
self.memory = []
|
self.memory = []
|
||||||
|
print(self.variables)
|
||||||
for k in self.globals.keys():
|
for k in self.globals.keys():
|
||||||
if k.startswith("sys.") or k.startswith("env."):
|
if k.startswith("sys."):
|
||||||
if isinstance(self.globals[k], str):
|
if isinstance(self.globals[k], str):
|
||||||
self.globals[k] = ""
|
self.globals[k] = ""
|
||||||
elif isinstance(self.globals[k], int):
|
elif isinstance(self.globals[k], int):
|
||||||
@ -325,9 +335,31 @@ class Canvas(Graph):
|
|||||||
self.globals[k] = {}
|
self.globals[k] = {}
|
||||||
else:
|
else:
|
||||||
self.globals[k] = None
|
self.globals[k] = None
|
||||||
|
if k.startswith("env."):
|
||||||
|
key = k[4:]
|
||||||
|
if key in self.variables:
|
||||||
|
variable = self.variables[key]
|
||||||
|
if variable["value"]:
|
||||||
|
self.globals[k] = variable["value"]
|
||||||
|
else:
|
||||||
|
if variable["type"] == "string":
|
||||||
|
self.globals[k] = ""
|
||||||
|
elif variable["type"] == "number":
|
||||||
|
self.globals[k] = 0
|
||||||
|
elif variable["type"] == "boolean":
|
||||||
|
self.globals[k] = False
|
||||||
|
elif variable["type"] == "object":
|
||||||
|
self.globals[k] = {}
|
||||||
|
elif variable["type"].startswith("array"):
|
||||||
|
self.globals[k] = []
|
||||||
|
else:
|
||||||
|
self.globals[k] = ""
|
||||||
|
else:
|
||||||
|
self.globals[k] = ""
|
||||||
|
|
||||||
async def run(self, **kwargs):
|
async def run(self, **kwargs):
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
|
self._loop = asyncio.get_running_loop()
|
||||||
self.message_id = get_uuid()
|
self.message_id = get_uuid()
|
||||||
created_at = int(time.time())
|
created_at = int(time.time())
|
||||||
self.add_user_input(kwargs.get("query"))
|
self.add_user_input(kwargs.get("query"))
|
||||||
@ -336,14 +368,19 @@ class Canvas(Graph):
|
|||||||
|
|
||||||
if kwargs.get("webhook_payload"):
|
if kwargs.get("webhook_payload"):
|
||||||
for k, cpn in self.components.items():
|
for k, cpn in self.components.items():
|
||||||
if self.components[k]["obj"].component_name.lower() == "webhook":
|
if self.components[k]["obj"].component_name.lower() == "begin" and self.components[k]["obj"]._param.mode == "Webhook":
|
||||||
for kk, vv in kwargs["webhook_payload"].items():
|
payload = kwargs.get("webhook_payload", {})
|
||||||
|
if "input" in payload:
|
||||||
|
self.components[k]["obj"].set_input_value("request", payload["input"])
|
||||||
|
for kk, vv in payload.items():
|
||||||
|
if kk == "input":
|
||||||
|
continue
|
||||||
self.components[k]["obj"].set_output(kk, vv)
|
self.components[k]["obj"].set_output(kk, vv)
|
||||||
|
|
||||||
for k in kwargs.keys():
|
for k in kwargs.keys():
|
||||||
if k in ["query", "user_id", "files"] and kwargs[k]:
|
if k in ["query", "user_id", "files"] and kwargs[k]:
|
||||||
if k == "files":
|
if k == "files":
|
||||||
self.globals[f"sys.{k}"] = self.get_files(kwargs[k])
|
self.globals[f"sys.{k}"] = await self.get_files_async(kwargs[k])
|
||||||
else:
|
else:
|
||||||
self.globals[f"sys.{k}"] = kwargs[k]
|
self.globals[f"sys.{k}"] = kwargs[k]
|
||||||
if not self.globals["sys.conversation_turns"] :
|
if not self.globals["sys.conversation_turns"] :
|
||||||
@ -373,19 +410,27 @@ class Canvas(Graph):
|
|||||||
yield decorate("workflow_started", {"inputs": kwargs.get("inputs")})
|
yield decorate("workflow_started", {"inputs": kwargs.get("inputs")})
|
||||||
self.retrieval.append({"chunks": {}, "doc_aggs": {}})
|
self.retrieval.append({"chunks": {}, "doc_aggs": {}})
|
||||||
|
|
||||||
def _run_batch(f, t):
|
async def _run_batch(f, t):
|
||||||
if self.is_canceled():
|
if self.is_canceled():
|
||||||
msg = f"Task {self.task_id} has been canceled during batch execution."
|
msg = f"Task {self.task_id} has been canceled during batch execution."
|
||||||
logging.info(msg)
|
logging.info(msg)
|
||||||
raise TaskCanceledException(msg)
|
raise TaskCanceledException(msg)
|
||||||
|
|
||||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
loop = asyncio.get_running_loop()
|
||||||
thr = []
|
tasks = []
|
||||||
|
|
||||||
|
def _run_async_in_thread(coro_func, **call_kwargs):
|
||||||
|
return asyncio.run(coro_func(**call_kwargs))
|
||||||
|
|
||||||
i = f
|
i = f
|
||||||
while i < t:
|
while i < t:
|
||||||
cpn = self.get_component_obj(self.path[i])
|
cpn = self.get_component_obj(self.path[i])
|
||||||
|
task_fn = None
|
||||||
|
call_kwargs = None
|
||||||
|
|
||||||
if cpn.component_name.lower() in ["begin", "userfillup"]:
|
if cpn.component_name.lower() in ["begin", "userfillup"]:
|
||||||
thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {})))
|
call_kwargs = {"inputs": kwargs.get("inputs", {})}
|
||||||
|
task_fn = cpn.invoke
|
||||||
i += 1
|
i += 1
|
||||||
else:
|
else:
|
||||||
for _, ele in cpn.get_input_elements().items():
|
for _, ele in cpn.get_input_elements().items():
|
||||||
@ -394,10 +439,21 @@ class Canvas(Graph):
|
|||||||
t -= 1
|
t -= 1
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
thr.append(executor.submit(cpn.invoke, **cpn.get_input()))
|
call_kwargs = cpn.get_input()
|
||||||
|
task_fn = cpn.invoke
|
||||||
i += 1
|
i += 1
|
||||||
for t in thr:
|
|
||||||
t.result()
|
if task_fn is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
invoke_async = getattr(cpn, "invoke_async", None)
|
||||||
|
if invoke_async and asyncio.iscoroutinefunction(invoke_async):
|
||||||
|
tasks.append(loop.run_in_executor(self._thread_pool, partial(_run_async_in_thread, invoke_async, **(call_kwargs or {}))))
|
||||||
|
else:
|
||||||
|
tasks.append(loop.run_in_executor(self._thread_pool, partial(task_fn, **(call_kwargs or {}))))
|
||||||
|
|
||||||
|
if tasks:
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
def _node_finished(cpn_obj):
return decorate("node_finished",{
@@ -414,6 +470,7 @@ class Canvas(Graph):
self.error = ""
idx = len(self.path) - 1
partials = []
+tts_mdl = None
while idx < len(self.path):
to = len(self.path)
for i in range(idx, to):
@@ -424,35 +481,70 @@ class Canvas(Graph):
"component_type": self.get_component_type(self.path[i]),
"thoughts": self.get_component_thoughts(self.path[i])
})
-_run_batch(idx, to)
+await _run_batch(idx, to)
to = len(self.path)
-# post processing of components invocation
+# post-processing of components invocation
for i in range(idx, to):
cpn = self.get_component(self.path[i])
cpn_obj = self.get_component_obj(self.path[i])
if cpn_obj.component_name.lower() == "message":
+if cpn_obj.get_param("auto_play"):
+tts_mdl = LLMBundle(self._tenant_id, LLMType.TTS)
if isinstance(cpn_obj.output("content"), partial):
_m = ""
-for m in cpn_obj.output("content")():
+buff_m = ""
+stream = cpn_obj.output("content")()
+async def _process_stream(m):
+nonlocal buff_m, _m, tts_mdl
if not m:
-continue
+return
if m == "<think>":
-yield decorate("message", {"content": "", "start_to_think": True})
+return decorate("message", {"content": "", "start_to_think": True})

elif m == "</think>":
-yield decorate("message", {"content": "", "end_to_think": True})
+return decorate("message", {"content": "", "end_to_think": True})
-else:
-yield decorate("message", {"content": m})
+buff_m += m
_m += m

+if len(buff_m) > 16:
+ev = decorate(
+"message",
+{
+"content": m,
+"audio_binary": self.tts(tts_mdl, buff_m)
+}
+)
+buff_m = ""
+return ev

+return decorate("message", {"content": m})

+if inspect.isasyncgen(stream):
+async for m in stream:
+ev= await _process_stream(m)
+if ev:
+yield ev
+else:
+for m in stream:
+ev= await _process_stream(m)
+if ev:
+yield ev
+if buff_m:
+yield decorate("message", {"content": "", "audio_binary": self.tts(tts_mdl, buff_m)})
+buff_m = ""
cpn_obj.set_output("content", _m)
cite = re.search(r"\[ID:[ 0-9]+\]", _m)
else:
yield decorate("message", {"content": cpn_obj.output("content")})
cite = re.search(r"\[ID:[ 0-9]+\]", cpn_obj.output("content"))

-if isinstance(cpn_obj.output("attachment"), tuple):
+message_end = {}
-yield decorate("message", {"attachment": cpn_obj.output("attachment")})
+if isinstance(cpn_obj.output("attachment"), dict):
+message_end["attachment"] = cpn_obj.output("attachment")
-yield decorate("message_end", {"reference": self.get_reference() if cite else None})
+if cite:
+message_end["reference"] = self.get_reference()
+yield decorate("message_end", message_end)

while partials:
_cpn_obj = self.get_component_obj(partials[0])
@@ -473,7 +565,7 @@ class Canvas(Graph):
else:
self.error = cpn_obj.error()

-if cpn_obj.component_name.lower() != "iteration":
+if cpn_obj.component_name.lower() not in ("iteration","loop"):
if isinstance(cpn_obj.output("content"), partial):
if self.error:
cpn_obj.set_output("content", None)
@@ -498,14 +590,16 @@ class Canvas(Graph):
for cpn_id in cpn_ids:
_append_path(cpn_id)

-if cpn_obj.component_name.lower() == "iterationitem" and cpn_obj.end():
+if cpn_obj.component_name.lower() in ("iterationitem","loopitem") and cpn_obj.end():
iter = cpn_obj.get_parent()
yield _node_finished(iter)
_extend_path(self.get_component(cpn["parent_id"])["downstream"])
elif cpn_obj.component_name.lower() in ["categorize", "switch"]:
_extend_path(cpn_obj.output("_next"))
-elif cpn_obj.component_name.lower() == "iteration":
+elif cpn_obj.component_name.lower() in ("iteration", "loop"):
_append_path(cpn_obj.get_start())
+elif cpn_obj.component_name.lower() == "exitloop" and cpn_obj.get_parent().component_name.lower() == "loop":
+_extend_path(self.get_component(cpn["parent_id"])["downstream"])
elif not cpn["downstream"] and cpn_obj.get_parent():
_append_path(cpn_obj.get_parent().get_start())
else:
@@ -561,6 +655,50 @@ class Canvas(Graph):
return False
return True

+def tts(self,tts_mdl, text):
+def clean_tts_text(text: str) -> str:
+if not text:
+return ""

+text = text.encode("utf-8", "ignore").decode("utf-8", "ignore")

+text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text)

+emoji_pattern = re.compile(
+"[\U0001F600-\U0001F64F"
+"\U0001F300-\U0001F5FF"
+"\U0001F680-\U0001F6FF"
+"\U0001F1E0-\U0001F1FF"
+"\U00002700-\U000027BF"
+"\U0001F900-\U0001F9FF"
+"\U0001FA70-\U0001FAFF"
+"\U0001FAD0-\U0001FAFF]+",
+flags=re.UNICODE
+)
+text = emoji_pattern.sub("", text)

+text = re.sub(r"\s+", " ", text).strip()

+MAX_LEN = 500
+if len(text) > MAX_LEN:
+text = text[:MAX_LEN]

+return text
+if not tts_mdl or not text:
+return None
+text = clean_tts_text(text)
+if not text:
+return None
+bin = b""
+try:
+for chunk in tts_mdl.tts(text):
+bin += chunk
+except Exception as e:
+logging.error(f"TTS failed: {e}, text={text!r}")
+return None
+return binascii.hexlify(bin).decode("utf-8")

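A hedged usage sketch for the new Canvas.tts helper: the method returns the synthesized audio as a hex string, so a consumer would decode it back to bytes before playback. The value below is a placeholder, not real TTS output.

    import binascii

    audio_hex = "0001ff"                         # placeholder for the value returned by Canvas.tts
    audio_bytes = binascii.unhexlify(audio_hex)  # raw audio bytes ready to be written or streamed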
def get_history(self, window_size):
convs = []
if window_size <= 0:
@@ -590,21 +728,30 @@ class Canvas(Graph):
def get_component_input_elements(self, cpnnm):
return self.components[cpnnm]["obj"].get_input_elements()

-def get_files(self, files: Union[None, list[dict]]) -> list[str]:
+async def get_files_async(self, files: Union[None, list[dict]]) -> list[str]:
-from api.db.services.file_service import FileService
if not files:
return []
def image_to_base64(file):
return "data:{};base64,{}".format(file["mime_type"],
base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
-exe = ThreadPoolExecutor(max_workers=5)
+loop = asyncio.get_running_loop()
-threads = []
+tasks = []
for file in files:
if file["mime_type"].find("image") >=0:
-threads.append(exe.submit(image_to_base64, file))
+tasks.append(loop.run_in_executor(self._thread_pool, image_to_base64, file))
continue
-threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
+tasks.append(loop.run_in_executor(self._thread_pool, FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
-return [th.result() for th in threads]
+return await asyncio.gather(*tasks)

+def get_files(self, files: Union[None, list[dict]]) -> list[str]:
+"""
+Synchronous wrapper for get_files_async, used by sync component invoke paths.
+"""
+loop = getattr(self, "_loop", None)
+if loop and loop.is_running():
+return asyncio.run_coroutine_threadsafe(self.get_files_async(files), loop).result()

+return asyncio.run(self.get_files_async(files))

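A minimal sketch of the sync-over-async pattern the new get_files wrapper relies on, assuming the canvas may or may not hold a running loop in `_loop`; the coroutine body here is a stand-in.

    import asyncio

    async def fetch():
        return ["parsed file content"]

    def fetch_blocking(loop=None):
        # If another thread's loop is already running, hand the coroutine to it and
        # block on the Future; otherwise create a temporary loop with asyncio.run().
        if loop and loop.is_running():
            return asyncio.run_coroutine_threadsafe(fetch(), loop).result()
        return asyncio.run(fetch())

    print(fetch_blocking())   # -> ['parsed file content']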
def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any, elapsed_time=None):
agent_ids = agent_id.split("-->")

@@ -13,10 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import asyncio
+import json
import logging
import os
import re
-from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from functools import partial
from typing import Any
@@ -28,8 +29,8 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.mcp_server_service import MCPServerService
from common.connection_utils import timeout
-from rag.prompts.generator import next_step, COMPLETE_TASK, analyze_task, \
+from rag.prompts.generator import next_step_async, COMPLETE_TASK, analyze_task_async, \
-citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in
+citation_prompt, reflect_async, kb_prompt, citation_plus, full_question, message_fit_in, structured_output_prompt
from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
from agent.component.llm import LLMParam, LLM

@@ -137,8 +138,34 @@ class Agent(LLM, ToolBase):
res.update(cpn.get_input_form())
return res

-@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60)))
+def _get_output_schema(self):
+try:
+cand = self._param.outputs.get("structured")
+except Exception:
+return None

+if isinstance(cand, dict):
+if isinstance(cand.get("properties"), dict) and len(cand["properties"]) > 0:
+return cand
+for k in ("schema", "structured"):
+if isinstance(cand.get(k), dict) and isinstance(cand[k].get("properties"), dict) and len(cand[k]["properties"]) > 0:
+return cand[k]

+return None

+async def _force_format_to_schema_async(self, text: str, schema_prompt: str) -> str:
+fmt_msgs = [
+{"role": "system", "content": schema_prompt + "\nIMPORTANT: Output ONLY valid JSON. No markdown, no extra text."},
+{"role": "user", "content": text},
+]
+_, fmt_msgs = message_fit_in(fmt_msgs, int(self.chat_mdl.max_length * 0.97))
+return await self._generate_async(fmt_msgs)

def _invoke(self, **kwargs):
+return asyncio.run(self._invoke_async(**kwargs))

+@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60)))
+async def _invoke_async(self, **kwargs):
if self.check_if_canceled("Agent processing"):
return

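An illustrative shape, not taken from the repository, of a "structured" output entry that _get_output_schema would accept: it only returns dicts carrying a non-empty "properties" map, either directly or nested under "schema"/"structured".

    example_outputs = {
        "structured": {
            "type": "object",
            "properties": {
                "title": {"type": "string"},
                "score": {"type": "number"},
            },
            "required": ["title"],
        }
    }
    # _get_output_schema would return example_outputs["structured"]
    # because its "properties" map is a non-empty dict.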
@@ -157,20 +184,25 @@ class Agent(LLM, ToolBase):
if not self.tools:
if self.check_if_canceled("Agent processing"):
return
-return LLM._invoke(self, **kwargs)
+return await LLM._invoke_async(self, **kwargs)

prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
+output_schema = self._get_output_schema()
+schema_prompt = ""
+if output_schema:
+schema = json.dumps(output_schema, ensure_ascii=False, indent=2)
+schema_prompt = structured_output_prompt(schema)

downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
ex = self.exception_handler()
-if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]):
+if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]) and not output_schema:
-self.set_output("content", partial(self.stream_output_with_tools, prompt, msg, user_defined_prompt))
+self.set_output("content", partial(self.stream_output_with_tools_async, prompt, deepcopy(msg), user_defined_prompt))
return

_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
use_tools = []
ans = ""
-for delta_ans, tk in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt):
+async for delta_ans, _tk in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt,schema_prompt=schema_prompt):
if self.check_if_canceled("Agent processing"):
return
ans += delta_ans
@@ -183,16 +215,38 @@ class Agent(LLM, ToolBase):
self.set_output("_ERROR", ans)
return

+if output_schema:
+error = ""
+for _ in range(self._param.max_retries + 1):
+try:
+def clean_formated_answer(ans: str) -> str:
+ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
+ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
+return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)
+obj = json_repair.loads(clean_formated_answer(ans))
+self.set_output("structured", obj)
+if use_tools:
+self.set_output("use_tools", use_tools)
+return obj
+except Exception:
+error = "The answer cannot be parsed as JSON"
+ans = await self._force_format_to_schema_async(ans, schema_prompt)
+if ans.find("**ERROR**") >= 0:
+continue

+self.set_output("_ERROR", error)
+return

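A small, self-contained sketch of the JSON-recovery step above; the raw answer is a made-up example, and json_repair is assumed to be available as it is elsewhere in the repo.

    import re
    import json_repair

    def clean_formated_answer(ans: str) -> str:
        ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
        ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
        return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)

    raw = "<think>planning...</think>```json\n{\"title\": \"Report\", \"score\": 0.9,}\n```"
    obj = json_repair.loads(clean_formated_answer(raw))  # trailing comma is repaired
    print(obj)  # {'title': 'Report', 'score': 0.9}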
self.set_output("content", ans)
if use_tools:
self.set_output("use_tools", use_tools)
return ans

-def stream_output_with_tools(self, prompt, msg, user_defined_prompt={}):
+async def stream_output_with_tools_async(self, prompt, msg, user_defined_prompt={}):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
answer_without_toolcall = ""
use_tools = []
-for delta_ans,_ in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt):
+async for delta_ans, _ in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt):
if self.check_if_canceled("Agent streaming"):
return

@@ -210,39 +264,23 @@ class Agent(LLM, ToolBase):
if use_tools:
self.set_output("use_tools", use_tools)

-def _gen_citations(self, text):
+async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
-retrievals = self._canvas.get_reference()
-retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
-formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
-for delta_ans in self._generate_streamly([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))},
-{"role": "user", "content": text}
-]):
-yield delta_ans

-def _react_with_tools_streamly(self, prompt, history: list[dict], use_tools, user_defined_prompt={}):
token_count = 0
tool_metas = self.tool_meta
hist = deepcopy(history)
last_calling = ""
if len(hist) > 3:
st = timer()
-user_request = full_question(messages=history, chat_mdl=self.chat_mdl)
+user_request = await full_question(messages=history, chat_mdl=self.chat_mdl)
self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer()-st)
else:
user_request = history[-1]["content"]

-def use_tool(name, args):
+async def use_tool_async(name, args):
-nonlocal hist, use_tools, token_count,last_calling,user_request
+nonlocal hist, use_tools, last_calling
logging.info(f"{last_calling=} == {name=}")
-# Summarize of function calling
-#if all([
-#    isinstance(self.toolcall_session.get_tool_obj(name), Agent),
-#    last_calling,
-#    last_calling != name
-#]):
-#    self.toolcall_session.get_tool_obj(name).add2system_prompt(f"The chat history with other agents are as following: \n" + self.get_useful_memory(user_request, str(args["user_prompt"]),user_defined_prompt))
last_calling = name
-tool_response = self.toolcall_session.tool_call(name, args)
+tool_response = await self.toolcall_session.tool_call_async(name, args)
use_tools.append({
"name": name,
"arguments": args,
@@ -253,12 +291,16 @@ class Agent(LLM, ToolBase):

return name, tool_response

-def complete():
+async def complete():
nonlocal hist
need2cite = self._param.cite and self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0
+if schema_prompt:
+need2cite = False
cited = False
-if hist[0]["role"] == "system" and need2cite:
+if hist and hist[0]["role"] == "system":
-if len(hist) < 7:
+if schema_prompt:
+hist[0]["content"] += "\n" + schema_prompt
+if need2cite and len(hist) < 7:
hist[0]["content"] += citation_prompt()
cited = True
yield "", token_count
@@ -267,7 +309,7 @@ class Agent(LLM, ToolBase):
if len(hist) > 12:
_hist = [hist[0], hist[1], *hist[-10:]]
entire_txt = ""
-for delta_ans in self._generate_streamly(_hist):
+async for delta_ans in self._generate_streamly(_hist):
if not need2cite or cited:
yield delta_ans, 0
entire_txt += delta_ans
@@ -276,7 +318,7 @@ class Agent(LLM, ToolBase):

st = timer()
txt = ""
-for delta_ans in self._gen_citations(entire_txt):
+async for delta_ans in self._gen_citations_async(entire_txt):
if self.check_if_canceled("Agent streaming"):
return
yield delta_ans, 0
@@ -291,14 +333,14 @@ class Agent(LLM, ToolBase):
hist.append({"role": "user", "content": content})

st = timer()
-task_desc = analyze_task(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
+task_desc = await analyze_task_async(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st)
for _ in range(self._param.max_rounds + 1):
if self.check_if_canceled("Agent streaming"):
return
-response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
+response, tk = await next_step_async(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
# self.callback("next_step", {}, str(response)[:256]+"...")
-token_count += tk
+token_count += tk or 0
hist.append({"role": "assistant", "content": response})
try:
functions = json_repair.loads(re.sub(r"```.*", "", response))
@@ -307,21 +349,22 @@ class Agent(LLM, ToolBase):
for f in functions:
if not isinstance(f, dict):
raise TypeError(f"An object type should be returned, but `{f}`")
-with ThreadPoolExecutor(max_workers=5) as executor:
-thr = []
+tool_tasks = []
for func in functions:
name = func["name"]
args = func["arguments"]
if name == COMPLETE_TASK:
append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be as the same as the first user request.\n")
-for txt, tkcnt in complete():
+async for txt, tkcnt in complete():
yield txt, tkcnt
return

-thr.append(executor.submit(use_tool, name, args))
+tool_tasks.append(asyncio.create_task(use_tool_async(name, args)))

+results = await asyncio.gather(*tool_tasks) if tool_tasks else []
st = timer()
-reflection = reflect(self.chat_mdl, hist, [th.result() for th in thr], user_defined_prompt)
+reflection = await reflect_async(self.chat_mdl, hist, results, user_defined_prompt)
append_user_content(hist, reflection)
self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)

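A minimal sketch of the fan-out pattern used for the parsed function calls: every tool call becomes a task and the round waits on all of them together. The tool body here is a stand-in, not the repository's implementation.

    import asyncio

    async def use_tool_async(name, args):
        await asyncio.sleep(0)               # stand-in for a real tool invocation
        return name, {"echo": args}

    async def run_round(functions):
        tool_tasks = [asyncio.create_task(use_tool_async(f["name"], f["arguments"])) for f in functions]
        return await asyncio.gather(*tool_tasks) if tool_tasks else []

    # asyncio.run(run_round([{"name": "search", "arguments": {"q": "ragflow"}}]))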
@@ -347,21 +390,17 @@ Respond immediately with your final comprehensive answer.
return
append_user_content(hist, final_instruction)

-for txt, tkcnt in complete():
+async for txt, tkcnt in complete():
yield txt, tkcnt

-def get_useful_memory(self, goal: str, sub_goal:str, topn=3, user_defined_prompt:dict={}) -> str:
+async def _gen_citations_async(self, text):
-# self.callback("get_useful_memory", {"topn": 3}, "...")
+retrievals = self._canvas.get_reference()
-mems = self._canvas.get_memory()
+retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
-rank = rank_memories(self.chat_mdl, goal, sub_goal, [summ for (user, assist, summ) in mems], user_defined_prompt)
+formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
-try:
+async for delta_ans in self._generate_streamly([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))},
-rank = json_repair.loads(re.sub(r"```.*", "", rank))[:topn]
+{"role": "user", "content": text}
-mems = [mems[r] for r in rank]
+]):
-return "\n\n".join([f"User: {u}\nAgent: {a}" for u, a,_ in mems])
+yield delta_ans
-except Exception as e:
-logging.exception(e)

-return "Error occurred."

def reset(self, only_output=False):
"""
@@ -378,4 +417,3 @@ Respond immediately with your final comprehensive answer.
for k in self._param.inputs.keys():
self._param.inputs[k]["value"] = None
self._param.debug_inputs = {}

@@ -14,6 +14,7 @@
# limitations under the License.
#

+import asyncio
import re
import time
from abc import ABC
@@ -23,11 +24,9 @@ import os
import logging
from typing import Any, List, Union
import pandas as pd
-import trio
from agent import settings
from common.connection_utils import timeout


_FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
_DEPRECATED_PARAMS = "_deprecated_params"
_USER_FEEDED_PARAMS = "_user_feeded_params"
@@ -253,96 +252,65 @@ class ComponentParamBase(ABC):
self._validate_param(attr, validation_json)

@staticmethod
-def check_string(param, descr):
+def check_string(param, description):
if type(param).__name__ not in ["str"]:
-raise ValueError(
+raise ValueError(description + " {} not supported, should be string type".format(param))
-descr + " {} not supported, should be string type".format(param)
-)

@staticmethod
-def check_empty(param, descr):
+def check_empty(param, description):
if not param:
-raise ValueError(
+raise ValueError(description + " does not support empty value.")
-descr + " does not support empty value."
-)

@staticmethod
-def check_positive_integer(param, descr):
+def check_positive_integer(param, description):
if type(param).__name__ not in ["int", "long"] or param <= 0:
-raise ValueError(
+raise ValueError(description + " {} not supported, should be positive integer".format(param))
-descr + " {} not supported, should be positive integer".format(param)
-)

@staticmethod
-def check_positive_number(param, descr):
+def check_positive_number(param, description):
if type(param).__name__ not in ["float", "int", "long"] or param <= 0:
-raise ValueError(
+raise ValueError(description + " {} not supported, should be positive numeric".format(param))
-descr + " {} not supported, should be positive numeric".format(param)
-)

@staticmethod
-def check_nonnegative_number(param, descr):
+def check_nonnegative_number(param, description):
if type(param).__name__ not in ["float", "int", "long"] or param < 0:
-raise ValueError(
+raise ValueError(description + " {} not supported, should be non-negative numeric".format(param))
-descr
-+ " {} not supported, should be non-negative numeric".format(param)
-)

@staticmethod
-def check_decimal_float(param, descr):
+def check_decimal_float(param, description):
if type(param).__name__ not in ["float", "int"] or param < 0 or param > 1:
-raise ValueError(
+raise ValueError(description + " {} not supported, should be a float number in range [0, 1]".format(param))
-descr
-+ " {} not supported, should be a float number in range [0, 1]".format(
-param
-)
-)

@staticmethod
-def check_boolean(param, descr):
+def check_boolean(param, description):
if type(param).__name__ != "bool":
-raise ValueError(
+raise ValueError(description + " {} not supported, should be bool type".format(param))
-descr + " {} not supported, should be bool type".format(param)
-)

@staticmethod
-def check_open_unit_interval(param, descr):
+def check_open_unit_interval(param, description):
if type(param).__name__ not in ["float"] or param <= 0 or param >= 1:
-raise ValueError(
+raise ValueError(description + " should be a numeric number between 0 and 1 exclusively")
-descr + " should be a numeric number between 0 and 1 exclusively"
-)

@staticmethod
-def check_valid_value(param, descr, valid_values):
+def check_valid_value(param, description, valid_values):
if param not in valid_values:
-raise ValueError(
+raise ValueError(description + " {} is not supported, it should be in {}".format(param, valid_values))
-descr
-+ " {} is not supported, it should be in {}".format(param, valid_values)
-)

@staticmethod
-def check_defined_type(param, descr, types):
+def check_defined_type(param, description, types):
if type(param).__name__ not in types:
-raise ValueError(
+raise ValueError(description + " {} not supported, should be one of {}".format(param, types))
-descr + " {} not supported, should be one of {}".format(param, types)
-)

@staticmethod
-def check_and_change_lower(param, valid_list, descr=""):
+def check_and_change_lower(param, valid_list, description=""):
if type(param).__name__ != "str":
-raise ValueError(
+raise ValueError(description + " {} not supported, should be one of {}".format(param, valid_list))
-descr
-+ " {} not supported, should be one of {}".format(param, valid_list)
-)

lower_param = param.lower()
if lower_param in valid_list:
return lower_param
else:
-raise ValueError(
+raise ValueError(description + " {} not supported, should be one of {}".format(param, valid_list))
-descr
-+ " {} not supported, should be one of {}".format(param, valid_list)
-)

@staticmethod
def _greater_equal_than(value, limit):
@@ -374,16 +342,16 @@ class ComponentParamBase(ABC):
def _not_in(value, wrong_value_list):
return value not in wrong_value_list

-def _warn_deprecated_param(self, param_name, descr):
+def _warn_deprecated_param(self, param_name, description):
if self._deprecated_params_set.get(param_name):
logging.warning(
-f"{descr} {param_name} is deprecated and ignored in this version."
+f"{description} {param_name} is deprecated and ignored in this version."
)

-def _warn_to_deprecate_param(self, param_name, descr, new_param):
+def _warn_to_deprecate_param(self, param_name, description, new_param):
if self._deprecated_params_set.get(param_name):
logging.warning(
-f"{descr} {param_name} will be deprecated in future release; "
+f"{description} {param_name} will be deprecated in future release; "
f"please use {new_param} instead."
)
return True
@@ -392,8 +360,8 @@ class ComponentParamBase(ABC):

class ComponentBase(ABC):
component_name: str
-thread_limiter = trio.CapacityLimiter(int(os.environ.get('MAX_CONCURRENT_CHATS', 10)))
+thread_limiter = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT_CHATS", 10)))
-variable_ref_patt = r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*"
+variable_ref_patt = r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.-]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*"

def __str__(self):
"""
@@ -445,6 +413,34 @@ class ComponentBase(ABC):
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
return self.output()

+async def invoke_async(self, **kwargs) -> dict[str, Any]:
+"""
+Async wrapper for component invocation.
+Prefers coroutine `_invoke_async` if present; otherwise falls back to `_invoke`.
+Handles timing and error recording consistently with `invoke`.
+"""
+self.set_output("_created_time", time.perf_counter())
+try:
+if self.check_if_canceled("Component processing"):
+return

+fn_async = getattr(self, "_invoke_async", None)
+if fn_async and asyncio.iscoroutinefunction(fn_async):
+await fn_async(**kwargs)
+elif asyncio.iscoroutinefunction(self._invoke):
+await self._invoke(**kwargs)
+else:
+await asyncio.to_thread(self._invoke, **kwargs)
+except Exception as e:
+if self.get_exception_default_value():
+self.set_exception_default_value()
+else:
+self.set_output("_ERROR", str(e))
+logging.exception(e)
+self._param.debug_inputs = {}
+self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
+return self.output()

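A condensed sketch of the fallback order implemented by invoke_async, usable against any component-like object; the helper name is illustrative and not part of the diff.

    import asyncio

    async def invoke_any(component, **kwargs):
        # Prefer a coroutine _invoke_async, then a coroutine _invoke, then push the
        # blocking _invoke onto a worker thread so the event loop stays responsive.
        fn_async = getattr(component, "_invoke_async", None)
        if fn_async and asyncio.iscoroutinefunction(fn_async):
            return await fn_async(**kwargs)
        if asyncio.iscoroutinefunction(component._invoke):
            return await component._invoke(**kwargs)
        return await asyncio.to_thread(component._invoke, **kwargs)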
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60)))
def _invoke(self, **kwargs):
raise NotImplementedError()
@@ -504,7 +500,7 @@ class ComponentBase(ABC):
res[exp] = {
"name": (self._canvas.get_component_name(cpn_id) + f"@{var_nm}") if cpn_id else exp,
"value": self._canvas.get_variable_value(exp),
-"_retrival": self._canvas.get_variable_value(f"{cpn_id}@_references") if cpn_id else None,
+"_retrieval": self._canvas.get_variable_value(f"{cpn_id}@_references") if cpn_id else None,
"_cpn_id": cpn_id
}
return res
@@ -555,6 +551,7 @@ class ComponentBase(ABC):
for n, v in kv.items():
def repl(_match, val=v):
return str(val) if val is not None else ""
+
content = re.sub(
r"\{%s\}" % re.escape(n),
repl,

@@ -14,6 +14,7 @@
# limitations under the License.
#
from agent.component.fillup import UserFillUpParam, UserFillUp
+from api.db.services.file_service import FileService


class BeginParam(UserFillUpParam):
@@ -27,7 +28,7 @@ class BeginParam(UserFillUpParam):
self.prologue = "Hi! I'm your smart assistant. What can I do for you?"

def check(self):
-self.check_valid_value(self.mode, "The 'mode' should be either `conversational` or `task`", ["conversational", "task"])
+self.check_valid_value(self.mode, "The 'mode' should be either `conversational` or `task`", ["conversational", "task","Webhook"])

def get_input_form(self) -> dict[str, dict]:
return getattr(self, "inputs")
@@ -48,7 +49,7 @@ class Begin(UserFillUp):
if v.get("optional") and v.get("value", None) is None:
v = None
else:
-v = self._canvas.get_files([v["value"]])
+v = FileService.get_files([v["value"]])
else:
v = v.get("value")
self.set_output(k, v)

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import asyncio
import logging
import os
import re
@@ -97,7 +98,7 @@ class Categorize(LLM, ABC):
component_name = "Categorize"

@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
-def _invoke(self, **kwargs):
+async def _invoke_async(self, **kwargs):
if self.check_if_canceled("Categorize processing"):
return

@@ -121,7 +122,7 @@ class Categorize(LLM, ABC):
if self.check_if_canceled("Categorize processing"):
return

-ans = chat_mdl.chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf())
+ans = await chat_mdl.async_chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf())
logging.info(f"input: {user_prompt}, answer: {str(ans)}")
if ERROR_PREFIX in ans:
raise Exception(ans)
@@ -144,5 +145,9 @@ class Categorize(LLM, ABC):
self.set_output("category_name", max_category)
self.set_output("_next", cpn_ids)

+@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
+def _invoke(self, **kwargs):
+return asyncio.run(self._invoke_async(**kwargs))

def thoughts(self) -> str:
return "Which should it falls into {}? ...".format(",".join([f"`{c}`" for c, _ in self._param.category_description.items()]))

agent/component/docs_generator.py (new file, 1570 lines): diff suppressed because it is too large.

agent/component/excel_processor.py (new file, 401 lines)
@@ -0,0 +1,401 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
ExcelProcessor Component

A component for reading, processing, and generating Excel files in RAGFlow agents.
Supports multiple Excel file inputs, data transformation, and Excel output generation.
"""

import logging
import os
from abc import ABC
from io import BytesIO

import pandas as pd

from agent.component.base import ComponentBase, ComponentParamBase
from api.db.services.file_service import FileService
from api.utils.api_utils import timeout
from common import settings
from common.misc_utils import get_uuid


class ExcelProcessorParam(ComponentParamBase):
"""
Define the ExcelProcessor component parameters.
"""
def __init__(self):
super().__init__()
# Input configuration
self.input_files = []  # Variable references to uploaded files
self.operation = "read"  # read, merge, transform, output

# Processing options
self.sheet_selection = "all"  # all, first, or comma-separated sheet names
self.merge_strategy = "concat"  # concat, join
self.join_on = ""  # Column name for join operations

# Transform options (for LLM-guided transformations)
self.transform_instructions = ""
self.transform_data = ""  # Variable reference to transformation data

# Output options
self.output_format = "xlsx"  # xlsx, csv
self.output_filename = "output"

# Component outputs
self.outputs = {
"data": {
"type": "object",
"value": {}
},
"summary": {
"type": "str",
"value": ""
},
"markdown": {
"type": "str",
"value": ""
}
}

def check(self):
self.check_valid_value(
self.operation,
"[ExcelProcessor] Operation",
["read", "merge", "transform", "output"]
)
self.check_valid_value(
self.output_format,
"[ExcelProcessor] Output format",
["xlsx", "csv"]
)
return True

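A hedged configuration example for the parameters defined above; the "{begin@files}" reference and the join column are assumptions about how an upstream component would expose uploaded files, not values taken from the diff.

    param = ExcelProcessorParam()
    param.operation = "merge"
    param.input_files = ["{begin@files}"]    # hypothetical variable reference
    param.merge_strategy = "join"
    param.join_on = "order_id"               # hypothetical join column
    param.output_format = "xlsx"
    param.output_filename = "merged_orders"
    param.check()                            # validates operation and output_format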
|
class ExcelProcessor(ComponentBase, ABC):
|
||||||
|
"""
|
||||||
|
Excel processing component for RAGFlow agents.
|
||||||
|
|
||||||
|
Operations:
|
||||||
|
- read: Parse Excel files into structured data
|
||||||
|
- merge: Combine multiple Excel files
|
||||||
|
- transform: Apply data transformations based on instructions
|
||||||
|
- output: Generate Excel file output
|
||||||
|
"""
|
||||||
|
component_name = "ExcelProcessor"
|
||||||
|
|
||||||
|
def get_input_form(self) -> dict[str, dict]:
|
||||||
|
"""Define input form for the component."""
|
||||||
|
res = {}
|
||||||
|
for ref in (self._param.input_files or []):
|
||||||
|
for k, o in self.get_input_elements_from_text(ref).items():
|
||||||
|
res[k] = {"name": o.get("name", ""), "type": "file"}
|
||||||
|
if self._param.transform_data:
|
||||||
|
for k, o in self.get_input_elements_from_text(self._param.transform_data).items():
|
||||||
|
res[k] = {"name": o.get("name", ""), "type": "object"}
|
||||||
|
return res
|
||||||
|
|
||||||
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||||
|
def _invoke(self, **kwargs):
|
||||||
|
if self.check_if_canceled("ExcelProcessor processing"):
|
||||||
|
return
|
||||||
|
|
||||||
|
operation = self._param.operation.lower()
|
||||||
|
|
||||||
|
if operation == "read":
|
||||||
|
self._read_excels()
|
||||||
|
elif operation == "merge":
|
||||||
|
self._merge_excels()
|
||||||
|
elif operation == "transform":
|
||||||
|
self._transform_data()
|
||||||
|
elif operation == "output":
|
||||||
|
self._output_excel()
|
||||||
|
else:
|
||||||
|
self.set_output("summary", f"Unknown operation: {operation}")
|
||||||
|
|
||||||
|
def _get_file_content(self, file_ref: str) -> tuple[bytes, str]:
|
||||||
|
"""
|
||||||
|
Get file content from a variable reference.
|
||||||
|
Returns (content_bytes, filename).
|
||||||
|
"""
|
||||||
|
value = self._canvas.get_variable_value(file_ref)
|
||||||
|
if value is None:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
# Handle different value formats
|
||||||
|
if isinstance(value, dict):
|
||||||
|
# File reference from Begin/UserFillUp component
|
||||||
|
file_id = value.get("id") or value.get("file_id")
|
||||||
|
created_by = value.get("created_by") or self._canvas.get_tenant_id()
|
||||||
|
filename = value.get("name") or value.get("filename", "unknown.xlsx")
|
||||||
|
if file_id:
|
||||||
|
content = FileService.get_blob(created_by, file_id)
|
||||||
|
return content, filename
|
||||||
|
elif isinstance(value, list) and len(value) > 0:
|
||||||
|
# List of file references - return first
|
||||||
|
return self._get_file_content_from_list(value[0])
|
||||||
|
elif isinstance(value, str):
|
||||||
|
# Could be base64 encoded or a path
|
||||||
|
if value.startswith("data:"):
|
||||||
|
import base64
|
||||||
|
# Extract base64 content
|
||||||
|
_, encoded = value.split(",", 1)
|
||||||
|
return base64.b64decode(encoded), "uploaded.xlsx"
|
||||||
|
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
def _get_file_content_from_list(self, item) -> tuple[bytes, str]:
|
||||||
|
"""Extract file content from a list item."""
|
||||||
|
if isinstance(item, dict):
|
||||||
|
return self._get_file_content(item)
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
def _parse_excel_to_dataframes(self, content: bytes, filename: str) -> dict[str, pd.DataFrame]:
|
||||||
|
"""Parse Excel content into a dictionary of DataFrames (one per sheet)."""
|
||||||
|
try:
|
||||||
|
excel_file = BytesIO(content)
|
||||||
|
|
||||||
|
if filename.lower().endswith(".csv"):
|
||||||
|
df = pd.read_csv(excel_file)
|
||||||
|
return {"Sheet1": df}
|
||||||
|
else:
|
||||||
|
# Read all sheets
|
||||||
|
xlsx = pd.ExcelFile(excel_file, engine='openpyxl')
|
||||||
|
sheet_selection = self._param.sheet_selection
|
||||||
|
|
||||||
|
if sheet_selection == "all":
|
||||||
|
sheets_to_read = xlsx.sheet_names
|
||||||
|
elif sheet_selection == "first":
|
||||||
|
sheets_to_read = [xlsx.sheet_names[0]] if xlsx.sheet_names else []
|
||||||
|
else:
|
||||||
|
# Comma-separated sheet names
|
||||||
|
requested = [s.strip() for s in sheet_selection.split(",")]
|
||||||
|
sheets_to_read = [s for s in requested if s in xlsx.sheet_names]
|
||||||
|
|
||||||
|
dfs = {}
|
||||||
|
for sheet in sheets_to_read:
|
||||||
|
dfs[sheet] = pd.read_excel(xlsx, sheet_name=sheet)
|
||||||
|
return dfs
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error parsing Excel file {filename}: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _read_excels(self):
|
||||||
|
"""Read and parse Excel files into structured data."""
|
||||||
|
all_data = {}
|
||||||
|
summaries = []
|
||||||
|
markdown_parts = []
|
||||||
|
|
||||||
|
for file_ref in (self._param.input_files or []):
|
||||||
|
if self.check_if_canceled("ExcelProcessor reading"):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Get variable value
|
||||||
|
value = self._canvas.get_variable_value(file_ref)
|
||||||
|
self.set_input_value(file_ref, str(value)[:200] if value else "")
|
||||||
|
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle file content
|
||||||
|
content, filename = self._get_file_content(file_ref)
|
||||||
|
if content is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Parse Excel
|
||||||
|
dfs = self._parse_excel_to_dataframes(content, filename)
|
||||||
|
|
||||||
|
for sheet_name, df in dfs.items():
|
||||||
|
key = f"{filename}_{sheet_name}" if len(dfs) > 1 else filename
|
||||||
|
all_data[key] = df.to_dict(orient="records")
|
||||||
|
|
||||||
|
# Build summary
|
||||||
|
summaries.append(f"**{key}**: {len(df)} rows, {len(df.columns)} columns ({', '.join(df.columns.tolist()[:5])}{'...' if len(df.columns) > 5 else ''})")
|
||||||
|
|
||||||
|
# Build markdown table
|
||||||
|
markdown_parts.append(f"### {key}\n\n{df.head(10).to_markdown(index=False)}\n")
|
||||||
|
|
||||||
|
# Set outputs
|
||||||
|
self.set_output("data", all_data)
|
||||||
|
self.set_output("summary", "\n".join(summaries) if summaries else "No Excel files found")
|
||||||
|
self.set_output("markdown", "\n\n".join(markdown_parts) if markdown_parts else "No data")
|
||||||
|
|
||||||
|
def _merge_excels(self):
|
||||||
|
"""Merge multiple Excel files/sheets into one."""
|
||||||
|
        all_dfs = []

        for file_ref in (self._param.input_files or []):
            if self.check_if_canceled("ExcelProcessor merging"):
                return

            value = self._canvas.get_variable_value(file_ref)
            self.set_input_value(file_ref, str(value)[:200] if value else "")

            if value is None:
                continue

            content, filename = self._get_file_content(file_ref)
            if content is None:
                continue

            dfs = self._parse_excel_to_dataframes(content, filename)
            all_dfs.extend(dfs.values())

        if not all_dfs:
            self.set_output("data", {})
            self.set_output("summary", "No data to merge")
            return

        # Merge strategy
        if self._param.merge_strategy == "concat":
            merged_df = pd.concat(all_dfs, ignore_index=True)
        elif self._param.merge_strategy == "join" and self._param.join_on:
            # Join on specified column
            merged_df = all_dfs[0]
            for df in all_dfs[1:]:
                merged_df = merged_df.merge(df, on=self._param.join_on, how="outer")
        else:
            merged_df = pd.concat(all_dfs, ignore_index=True)

        self.set_output("data", {"merged": merged_df.to_dict(orient="records")})
        self.set_output("summary", f"Merged {len(all_dfs)} sources into {len(merged_df)} rows, {len(merged_df.columns)} columns")
        self.set_output("markdown", merged_df.head(20).to_markdown(index=False))

    def _transform_data(self):
        """Apply transformations to data based on instructions or input data."""
        # Get the data to transform
        transform_ref = self._param.transform_data
        if not transform_ref:
            self.set_output("summary", "No transform data reference provided")
            return

        data = self._canvas.get_variable_value(transform_ref)
        self.set_input_value(transform_ref, str(data)[:300] if data else "")

        if data is None:
            self.set_output("summary", "Transform data is empty")
            return

        # Convert to DataFrame
        if isinstance(data, dict):
            # Could be {"sheet": [rows]} format
            if all(isinstance(v, list) for v in data.values()):
                # Multiple sheets
                all_markdown = []
                for sheet_name, rows in data.items():
                    df = pd.DataFrame(rows)
                    all_markdown.append(f"### {sheet_name}\n\n{df.to_markdown(index=False)}")
                self.set_output("data", data)
                self.set_output("markdown", "\n\n".join(all_markdown))
            else:
                df = pd.DataFrame([data])
                self.set_output("data", df.to_dict(orient="records"))
                self.set_output("markdown", df.to_markdown(index=False))
        elif isinstance(data, list):
            df = pd.DataFrame(data)
            self.set_output("data", df.to_dict(orient="records"))
            self.set_output("markdown", df.to_markdown(index=False))
        else:
            self.set_output("data", {"raw": str(data)})
            self.set_output("markdown", str(data))

        self.set_output("summary", "Transformed data ready for processing")

    def _output_excel(self):
        """Generate Excel file output from data."""
        # Get data from transform_data reference
        transform_ref = self._param.transform_data
        if not transform_ref:
            self.set_output("summary", "No data reference for output")
            return

        data = self._canvas.get_variable_value(transform_ref)
        self.set_input_value(transform_ref, str(data)[:300] if data else "")

        if data is None:
            self.set_output("summary", "No data to output")
            return

        try:
            # Prepare DataFrames
            if isinstance(data, dict):
                if all(isinstance(v, list) for v in data.values()):
                    # Multi-sheet format
                    dfs = {k: pd.DataFrame(v) for k, v in data.items()}
                else:
                    dfs = {"Sheet1": pd.DataFrame([data])}
            elif isinstance(data, list):
                dfs = {"Sheet1": pd.DataFrame(data)}
            else:
                self.set_output("summary", "Invalid data format for Excel output")
                return

            # Generate output
            doc_id = get_uuid()

            if self._param.output_format == "csv":
                # For CSV, only output first sheet
                first_df = list(dfs.values())[0]
                binary_content = first_df.to_csv(index=False).encode("utf-8")
                filename = f"{self._param.output_filename}.csv"
            else:
                # Excel output
                excel_io = BytesIO()
                with pd.ExcelWriter(excel_io, engine='openpyxl') as writer:
                    for sheet_name, df in dfs.items():
                        # Sanitize sheet name (max 31 chars, no special chars)
                        safe_name = sheet_name[:31].replace("/", "_").replace("\\", "_")
                        df.to_excel(writer, sheet_name=safe_name, index=False)
                excel_io.seek(0)
                binary_content = excel_io.read()
                filename = f"{self._param.output_filename}.xlsx"

            # Store file
            settings.STORAGE_IMPL.put(self._canvas._tenant_id, doc_id, binary_content)

            # Set attachment output
            self.set_output("attachment", {
                "doc_id": doc_id,
                "format": self._param.output_format,
                "file_name": filename
            })

            total_rows = sum(len(df) for df in dfs.values())
            self.set_output("summary", f"Generated {filename} with {len(dfs)} sheet(s), {total_rows} total rows")
            self.set_output("data", {k: v.to_dict(orient="records") for k, v in dfs.items()})

            logging.info(f"ExcelProcessor: Generated {filename} as {doc_id}")

        except Exception as e:
            logging.error(f"ExcelProcessor output error: {e}")
            self.set_output("summary", f"Error generating output: {str(e)}")

    def thoughts(self) -> str:
        """Return component thoughts for UI display."""
        op = self._param.operation
        if op == "read":
            return "Reading Excel files..."
        elif op == "merge":
            return "Merging Excel data..."
        elif op == "transform":
            return "Transforming data..."
        elif op == "output":
            return "Generating Excel output..."
        return "Processing Excel..."
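
As a side note on the merge step above, here is a minimal standalone sketch of the two strategies it switches between: plain concatenation versus an outer join on a shared key column. The column name "id" is only illustrative and is not a parameter of the component.

import pandas as pd

df_a = pd.DataFrame({"id": [1, 2], "name": ["alice", "bob"]})
df_b = pd.DataFrame({"id": [2, 3], "score": [0.7, 0.9]})

# "concat" strategy: stack rows from every source, reindexing from zero.
stacked = pd.concat([df_a, df_b], ignore_index=True)

# "join" strategy: outer-merge on the join column, keeping rows from both sides.
joined = df_a.merge(df_b, on="id", how="outer")

print(stacked)
print(joined)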
@@ -13,23 +13,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from agent.component.base import ComponentParamBase, ComponentBase
+from abc import ABC
+from agent.component.base import ComponentBase, ComponentParamBase
 
 
-class WebhookParam(ComponentParamBase):
-    """
-    Define the Begin component parameters.
-    """
-
-    def __init__(self):
-        super().__init__()
-
-    def get_input_form(self) -> dict[str, dict]:
-        return getattr(self, "inputs")
+class ExitLoopParam(ComponentParamBase, ABC):
+    def check(self):
+        return True
 
 
-class Webhook(ComponentBase):
-    component_name = "Webhook"
+class ExitLoop(ComponentBase, ABC):
+    component_name = "ExitLoop"
 
     def _invoke(self, **kwargs):
         pass
@@ -18,6 +18,7 @@ import re
 from functools import partial
 
 from agent.component.base import ComponentParamBase, ComponentBase
+from api.db.services.file_service import FileService
 
 
 class UserFillUpParam(ComponentParamBase):
@@ -63,6 +64,13 @@ class UserFillUp(ComponentBase):
         for k, v in kwargs.get("inputs", {}).items():
             if self.check_if_canceled("UserFillUp processing"):
                 return
+            if isinstance(v, dict) and v.get("type", "").lower().find("file") >=0:
+                if v.get("optional") and v.get("value", None) is None:
+                    v = None
+                else:
+                    v = FileService.get_files([v["value"]])
+            else:
+                v = v.get("value")
             self.set_output(k, v)
 
     def thoughts(self) -> str:
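
A rough standalone sketch of the branching this hunk adds to UserFillUp: file-typed inputs are resolved through the file service, everything else simply unwraps its "value" field. The get_files callable below is a stand-in for FileService.get_files and is used here only for illustration.

def resolve_input(v: dict, get_files):
    # File-typed inputs go through the file service; an optional, unset file resolves to None.
    if v.get("type", "").lower().find("file") >= 0:
        if v.get("optional") and v.get("value") is None:
            return None
        return get_files([v["value"]])
    # Everything else just unwraps its value.
    return v.get("value")

print(resolve_input({"type": "line", "value": "hello"}, lambda ids: ids))                 # -> "hello"
print(resolve_input({"type": "file", "optional": True, "value": None}, lambda ids: ids))  # -> None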
@@ -13,12 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import json
 import logging
 import os
 import re
 from copy import deepcopy
-from typing import Any, Generator
+from typing import Any, AsyncGenerator
 import json_repair
 from functools import partial
 from common.constants import LLMType
@@ -166,15 +167,17 @@ class LLM(ComponentBase):
             sys_prompt = re.sub(rf"<{tag}>(.*?)</{tag}>", "", sys_prompt, flags=re.DOTALL|re.IGNORECASE)
         return pts, sys_prompt
 
-    def _generate(self, msg:list[dict], **kwargs) -> str:
+    async def _generate_async(self, msg: list[dict], **kwargs) -> str:
         if not self.imgs:
-            return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
-        return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
+            return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
+        return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
 
-    def _generate_streamly(self, msg:list[dict], **kwargs) -> Generator[str, None, None]:
-        ans = ""
-        last_idx = 0
-        endswith_think = False
-
-        def delta(txt):
-            nonlocal ans, last_idx, endswith_think
-            delta_ans = txt[last_idx:]
+    async def _generate_streamly(self, msg: list[dict], **kwargs) -> AsyncGenerator[str, None]:
+        async def delta_wrapper(txt_iter):
+            ans = ""
+            last_idx = 0
+            endswith_think = False
+
+            def delta(txt):
+                nonlocal ans, last_idx, endswith_think
+                delta_ans = txt[last_idx:]
@@ -198,15 +201,68 @@ class LLM(ComponentBase):
                     last_idx -= len("</think>")
                 return re.sub(r"(<think>|</think>)", "", delta_ans)
 
+            async for t in txt_iter:
+                yield delta(t)
+
         if not self.imgs:
-            for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs):
-                yield delta(txt)
-        else:
-            for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
-                yield delta(txt)
+            async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)):
+                yield t
+            return
+
+        async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)):
+            yield t
+
+    async def _stream_output_async(self, prompt, msg):
+        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
+        answer = ""
+        last_idx = 0
+        endswith_think = False
+
+        def delta(txt):
+            nonlocal answer, last_idx, endswith_think
+            delta_ans = txt[last_idx:]
+            answer = txt
+
+            if delta_ans.find("<think>") == 0:
+                last_idx += len("<think>")
+                return "<think>"
+            elif delta_ans.find("<think>") > 0:
+                delta_ans = txt[last_idx:last_idx + delta_ans.find("<think>")]
+                last_idx += delta_ans.find("<think>")
+                return delta_ans
+            elif delta_ans.endswith("</think>"):
+                endswith_think = True
+            elif endswith_think:
+                endswith_think = False
+                return "</think>"
+
+            last_idx = len(answer)
+            if answer.endswith("</think>"):
+                last_idx -= len("</think>")
+            return re.sub(r"(<think>|</think>)", "", delta_ans)
+
+        stream_kwargs = {"images": self.imgs} if self.imgs else {}
+        async for ans in self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **stream_kwargs):
+            if self.check_if_canceled("LLM streaming"):
+                return
+
+            if isinstance(ans, int):
+                continue
+
+            if ans.find("**ERROR**") >= 0:
+                if self.get_exception_default_value():
+                    self.set_output("content", self.get_exception_default_value())
+                    yield self.get_exception_default_value()
+                else:
+                    self.set_output("_ERROR", ans)
+                return
+
+            yield delta(ans)
+
+        self.set_output("content", answer)
 
     @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
-    def _invoke(self, **kwargs):
+    async def _invoke_async(self, **kwargs):
         if self.check_if_canceled("LLM processing"):
             return
 
@@ -219,20 +275,23 @@ class LLM(ComponentBase):
         error: str = ""
         output_structure = None
         try:
-            output_structure = self._param.outputs['structured']
+            output_structure = self._param.outputs["structured"]
         except Exception:
             pass
-        if output_structure and isinstance(output_structure, dict) and output_structure.get("properties"):
+        if output_structure and isinstance(output_structure, dict) and output_structure.get("properties") and len(output_structure["properties"]) > 0:
             schema = json.dumps(output_structure, ensure_ascii=False, indent=2)
-            prompt += structured_output_prompt(schema)
+            prompt_with_schema = prompt + structured_output_prompt(schema)
             for _ in range(self._param.max_retries + 1):
                 if self.check_if_canceled("LLM processing"):
                     return
 
-                _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
+                _, msg_fit = message_fit_in(
+                    [{"role": "system", "content": prompt_with_schema}, *deepcopy(msg)],
+                    int(self.chat_mdl.max_length * 0.97),
+                )
                 error = ""
-                ans = self._generate(msg)
-                msg.pop(0)
+                ans = await self._generate_async(msg_fit)
+                msg_fit.pop(0)
                 if ans.find("**ERROR**") >= 0:
                     logging.error(f"LLM response error: {ans}")
                     error = ans
@@ -241,7 +300,7 @@ class LLM(ComponentBase):
                     self.set_output("structured", json_repair.loads(clean_formated_answer(ans)))
                     return
                 except Exception:
-                    msg.append({"role": "user", "content": "The answer can't not be parsed as JSON"})
+                    msg_fit.append({"role": "user", "content": "The answer can't not be parsed as JSON"})
                     error = "The answer can't not be parsed as JSON"
             if error:
                 self.set_output("_ERROR", error)
@@ -249,18 +308,23 @@ class LLM(ComponentBase):
 
         downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
         ex = self.exception_handler()
-        if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]):
-            self.set_output("content", partial(self._stream_output, prompt, msg))
+        if any([self._canvas.get_component_obj(cid).component_name.lower() == "message" for cid in downstreams]) and not (
+            ex and ex["goto"]
+        ):
+            self.set_output("content", partial(self._stream_output_async, prompt, deepcopy(msg)))
             return
 
+        error = ""
         for _ in range(self._param.max_retries + 1):
            if self.check_if_canceled("LLM processing"):
                 return
 
-            _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
+            _, msg_fit = message_fit_in(
+                [{"role": "system", "content": prompt}, *deepcopy(msg)], int(self.chat_mdl.max_length * 0.97)
+            )
             error = ""
-            ans = self._generate(msg)
-            msg.pop(0)
+            ans = await self._generate_async(msg_fit)
+            msg_fit.pop(0)
             if ans.find("**ERROR**") >= 0:
                 logging.error(f"LLM response error: {ans}")
                 error = ans
@@ -274,26 +338,12 @@ class LLM(ComponentBase):
             else:
                 self.set_output("_ERROR", error)
 
-    def _stream_output(self, prompt, msg):
-        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
-        answer = ""
-        for ans in self._generate_streamly(msg):
-            if self.check_if_canceled("LLM streaming"):
-                return
-
-            if ans.find("**ERROR**") >= 0:
-                if self.get_exception_default_value():
-                    self.set_output("content", self.get_exception_default_value())
-                    yield self.get_exception_default_value()
-                else:
-                    self.set_output("_ERROR", ans)
-                return
-            yield ans
-            answer += ans
-        self.set_output("content", answer)
-
-    def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str, user_defined_prompt:dict={}):
-        summ = tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt)
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
+    def _invoke(self, **kwargs):
+        return asyncio.run(self._invoke_async(**kwargs))
+
+    async def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str, user_defined_prompt:dict={}):
+        summ = await tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt)
         logging.info(f"[MEMORY]: {summ}")
         self._canvas.add_memory(user, assist, summ)
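
A simplified illustration of the delta trick both streaming paths above rely on: the model yields the full accumulated text on every step, so each step emits only the new suffix and strips the <think> markers. This is a sketch of the idea, not the component's exact code.

import re

def make_delta():
    last_idx = 0
    def delta(txt: str) -> str:
        # Emit only the part of the snapshot we have not yielded yet, minus think tags.
        nonlocal last_idx
        piece = txt[last_idx:]
        last_idx = len(txt)
        return re.sub(r"(<think>|</think>)", "", piece)
    return delta

delta = make_delta()
for snapshot in ["<think>plan", "<think>plan</think>Hel", "<think>plan</think>Hello"]:
    print(repr(delta(snapshot)))
# Prints the incremental pieces: 'plan', 'Hel', 'lo' (tags removed).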
agent/component/loop.py (new file, 80 lines)
@@ -0,0 +1,80 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC
from agent.component.base import ComponentBase, ComponentParamBase


class LoopParam(ComponentParamBase):
    """
    Define the Loop component parameters.
    """

    def __init__(self):
        super().__init__()
        self.loop_variables = []
        self.loop_termination_condition = []
        self.maximum_loop_count = 0

    def get_input_form(self) -> dict[str, dict]:
        return {
            "items": {
                "type": "json",
                "name": "Items"
            }
        }

    def check(self):
        return True


class Loop(ComponentBase, ABC):
    component_name = "Loop"

    def get_start(self):
        for cid in self._canvas.components.keys():
            if self._canvas.get_component(cid)["obj"].component_name.lower() != "loopitem":
                continue
            if self._canvas.get_component(cid)["parent_id"] == self._id:
                return cid

    def _invoke(self, **kwargs):
        if self.check_if_canceled("Loop processing"):
            return

        for item in self._param.loop_variables:
            if any([not item.get("variable"), not item.get("input_mode"), not item.get("value"), not item.get("type")]):
                assert "Loop Variable is not complete."
            if item["input_mode"] == "variable":
                self.set_output(item["variable"], self._canvas.get_variable_value(item["value"]))
            elif item["input_mode"] == "constant":
                self.set_output(item["variable"], item["value"])
            else:
                if item["type"] == "number":
                    self.set_output(item["variable"], 0)
                elif item["type"] == "string":
                    self.set_output(item["variable"], "")
                elif item["type"] == "boolean":
                    self.set_output(item["variable"], False)
                elif item["type"].startswith("object"):
                    self.set_output(item["variable"], {})
                elif item["type"].startswith("array"):
                    self.set_output(item["variable"], [])
                else:
                    self.set_output(item["variable"], "")

    def thoughts(self) -> str:
        return "Loop from canvas."
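
For reference, the fallback branch of Loop._invoke boils down to a type-to-default table. A compact sketch, assuming the same type names used above (the helper name is illustrative):

def default_for(type_name: str):
    # Each declared loop-variable type starts from its neutral value.
    if type_name == "number":
        return 0
    if type_name == "string":
        return ""
    if type_name == "boolean":
        return False
    if type_name.startswith("object"):
        return {}
    if type_name.startswith("array"):
        return []
    return ""

print(default_for("array<string>"))  # -> []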
agent/component/loopitem.py (new file, 163 lines)
@@ -0,0 +1,163 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC
from agent.component.base import ComponentBase, ComponentParamBase


class LoopItemParam(ComponentParamBase):
    """
    Define the LoopItem component parameters.
    """
    def check(self):
        return True


class LoopItem(ComponentBase, ABC):
    component_name = "LoopItem"

    def __init__(self, canvas, id, param: ComponentParamBase):
        super().__init__(canvas, id, param)
        self._idx = 0

    def _invoke(self, **kwargs):
        if self.check_if_canceled("LoopItem processing"):
            return
        parent = self.get_parent()
        maximum_loop_count = parent._param.maximum_loop_count
        if self._idx >= maximum_loop_count:
            self._idx = -1
            return
        if self._idx > 0:
            if self.check_if_canceled("LoopItem processing"):
                return
        self._idx += 1

    def evaluate_condition(self, var, operator, value):
        if isinstance(var, str):
            if operator == "contains":
                return value in var
            elif operator == "not contains":
                return value not in var
            elif operator == "start with":
                return var.startswith(value)
            elif operator == "end with":
                return var.endswith(value)
            elif operator == "is":
                return var == value
            elif operator == "is not":
                return var != value
            elif operator == "empty":
                return var == ""
            elif operator == "not empty":
                return var != ""

        elif isinstance(var, (int, float)):
            if operator == "=":
                return var == value
            elif operator == "≠":
                return var != value
            elif operator == ">":
                return var > value
            elif operator == "<":
                return var < value
            elif operator == "≥":
                return var >= value
            elif operator == "≤":
                return var <= value
            elif operator == "empty":
                return var is None
            elif operator == "not empty":
                return var is not None

        elif isinstance(var, bool):
            if operator == "is":
                return var is value
            elif operator == "is not":
                return var is not value
            elif operator == "empty":
                return var is None
            elif operator == "not empty":
                return var is not None

        elif isinstance(var, dict):
            if operator == "empty":
                return len(var) == 0
            elif operator == "not empty":
                return len(var) > 0

        elif isinstance(var, list):
            if operator == "contains":
                return value in var
            elif operator == "not contains":
                return value not in var
            elif operator == "is":
                return var == value
            elif operator == "is not":
                return var != value
            elif operator == "empty":
                return len(var) == 0
            elif operator == "not empty":
                return len(var) > 0

        raise Exception(f"Invalid operator: {operator}")

    def end(self):
        if self._idx == -1:
            return True
        parent = self.get_parent()
        logical_operator = parent._param.logical_operator if hasattr(parent._param, "logical_operator") else "and"
        conditions = []
        for item in parent._param.loop_termination_condition:
            if not item.get("variable") or not item.get("operator"):
                raise ValueError("Loop condition is incomplete.")
            var = self._canvas.get_variable_value(item["variable"])
            operator = item["operator"]
            input_mode = item.get("input_mode", "constant")

            if input_mode == "variable":
                value = self._canvas.get_variable_value(item.get("value", ""))
            elif input_mode == "constant":
                value = item.get("value", "")
            else:
                raise ValueError("Invalid input mode.")
            conditions.append(self.evaluate_condition(var, operator, value))
        should_end = (
            all(conditions) if logical_operator == "and"
            else any(conditions) if logical_operator == "or"
            else None
        )
        if should_end is None:
            raise ValueError("Invalid logical operator,should be 'and' or 'or'.")

        if should_end:
            self._idx = -1
            return True

        return False

    def next(self):
        if self._idx == -1:
            self._idx = 0
        else:
            self._idx += 1
        if self._idx >= len(self._items):
            self._idx = -1
            return False

    def thoughts(self) -> str:
        return "Next turn..."
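
A minimal usage sketch of the termination check in end(): every condition is evaluated the way evaluate_condition does for its operand type, then the results are combined with the loop's logical operator. The helper below is a standalone illustration, not the component itself.

def combine(conditions: list, logical_operator: str) -> bool:
    # Mirrors the and/or combination used by LoopItem.end().
    if logical_operator == "and":
        return all(conditions)
    if logical_operator == "or":
        return any(conditions)
    raise ValueError("Invalid logical operator, should be 'and' or 'or'.")

# e.g. counter >= 10  OR  status == "done"
conditions = [12 >= 10, "running" == "done"]
print(combine(conditions, "or"))   # True  -> the loop would stop
print(combine(conditions, "and"))  # False -> keep iterating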
@@ -13,6 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
+import nest_asyncio
+nest_asyncio.apply()
+import inspect
 import json
 import os
 import random
@@ -39,6 +43,7 @@ class MessageParam(ComponentParamBase):
         self.content = []
         self.stream = True
         self.output_format = None  # default output format
+        self.auto_play = False
         self.outputs = {
             "content": {
                 "type": "str"
@@ -66,7 +71,11 @@ class Message(ComponentBase):
                 v = ""
             ans = ""
             if isinstance(v, partial):
-                for t in v():
-                    ans += t
+                iter_obj = v()
+                if inspect.isasyncgen(iter_obj):
+                    ans = asyncio.run(self._consume_async_gen(iter_obj))
+                else:
+                    for t in iter_obj:
+                        ans += t
             elif isinstance(v, list) and delimiter:
                 ans = delimiter.join([str(vv) for vv in v])
@@ -89,7 +98,13 @@ class Message(ComponentBase):
             _kwargs[_n] = v
         return script, _kwargs
 
-    def _stream(self, rand_cnt:str):
+    async def _consume_async_gen(self, agen):
+        buf = ""
+        async for t in agen:
+            buf += t
+        return buf
+
+    async def _stream(self, rand_cnt:str):
         s = 0
         all_content = ""
         cache = {}
@@ -111,7 +126,17 @@ class Message(ComponentBase):
                 v = ""
             if isinstance(v, partial):
                 cnt = ""
-                for t in v():
+                iter_obj = v()
+                if inspect.isasyncgen(iter_obj):
+                    async for t in iter_obj:
+                        if self.check_if_canceled("Message streaming"):
+                            return
+
+                        all_content += t
+                        cnt += t
+                        yield t
+                else:
+                    for t in iter_obj:
                         if self.check_if_canceled("Message streaming"):
                             return
 
@@ -120,6 +145,8 @@ class Message(ComponentBase):
                         yield t
                 self.set_input_value(exp, cnt)
                 continue
+            elif inspect.isawaitable(v):
+                v = await v
             elif not isinstance(v, str):
                 try:
                     v = json.dumps(v, ensure_ascii=False)
@@ -175,6 +202,48 @@ class Message(ComponentBase):
     def thoughts(self) -> str:
         return ""
 
+    def _parse_markdown_table_lines(self, table_lines: list):
+        """
+        Parse a list of Markdown table lines into a pandas DataFrame.
+
+        Args:
+            table_lines: List of strings, each representing a row in the Markdown table
+                         (excluding separator lines like |---|---|)
+
+        Returns:
+            pandas DataFrame with the table data, or None if parsing fails
+        """
+        import pandas as pd
+
+        if not table_lines:
+            return None
+
+        rows = []
+        headers = None
+
+        for line in table_lines:
+            # Split by | and clean up
+            cells = [cell.strip() for cell in line.split('|')]
+            # Remove empty first and last elements from split (caused by leading/trailing |)
+            cells = [c for c in cells if c]
+
+            if headers is None:
+                headers = cells
+            else:
+                rows.append(cells)
+
+        if headers and rows:
+            # Ensure all rows have same number of columns as headers
+            normalized_rows = []
+            for row in rows:
+                while len(row) < len(headers):
+                    row.append('')
+                normalized_rows.append(row[:len(headers)])
+
+            return pd.DataFrame(normalized_rows, columns=headers)
+
+        return None
+
     def _convert_content(self, content):
         if not self._param.output_format:
             return
@@ -182,7 +251,7 @@ class Message(ComponentBase):
             import pypandoc
             doc_id = get_uuid()
 
-            if self._param.output_format.lower() not in {"markdown", "html", "pdf", "docx"}:
+            if self._param.output_format.lower() not in {"markdown", "html", "pdf", "docx", "xlsx"}:
                 self._param.output_format = "markdown"
 
             try:
@@ -202,6 +271,119 @@ class Message(ComponentBase):
 
                 binary_content = converted.encode("utf-8")
 
+            elif self._param.output_format == "xlsx":
+                import pandas as pd
+                from io import BytesIO
+
+                # Debug: log the content being parsed
+                logging.info(f"XLSX Parser: Content length={len(content) if content else 0}, first 500 chars: {content[:500] if content else 'None'}")
+
+                # Try to parse ALL Markdown tables from the content
+                # Each table will be written to a separate sheet
+                tables = []  # List of (sheet_name, dataframe)
+
+                if isinstance(content, str):
+                    lines = content.strip().split('\n')
+                    logging.info(f"XLSX Parser: Total lines={len(lines)}, lines starting with '|': {sum(1 for line in lines if line.strip().startswith('|'))}")
+                    current_table_lines = []
+                    current_table_title = None
+                    pending_title = None
+                    in_table = False
+                    table_count = 0
+
+                    for i, line in enumerate(lines):
+                        stripped = line.strip()
+
+                        # Check for potential table title (lines before a table)
+                        # Look for patterns like "Table 1:", "## Table", or markdown headers
+                        if not in_table and stripped and not stripped.startswith('|'):
+                            # Check if this could be a table title
+                            lower_stripped = stripped.lower()
+                            if (lower_stripped.startswith('table') or
+                                    stripped.startswith('#') or
+                                    ':' in stripped):
+                                pending_title = stripped.lstrip('#').strip()
+
+                        if stripped.startswith('|') and '|' in stripped[1:]:
+                            # Check if this is a separator line (|---|---|)
+                            cleaned = stripped.replace(' ', '').replace('|', '').replace('-', '').replace(':', '')
+                            if cleaned == '':
+                                continue  # Skip separator line
+
+                            if not in_table:
+                                # Starting a new table
+                                in_table = True
+                                current_table_lines = []
+                                current_table_title = pending_title
+                                pending_title = None
+
+                            current_table_lines.append(stripped)
+
+                        elif in_table and not stripped.startswith('|'):
+                            # End of current table - save it
+                            if current_table_lines:
+                                df = self._parse_markdown_table_lines(current_table_lines)
+                                if df is not None and not df.empty:
+                                    table_count += 1
+                                    # Generate sheet name
+                                    if current_table_title:
+                                        # Clean and truncate title for sheet name
+                                        sheet_name = current_table_title[:31]
+                                        sheet_name = sheet_name.replace('/', '_').replace('\\', '_').replace('*', '').replace('?', '').replace('[', '').replace(']', '').replace(':', '')
+                                    else:
+                                        sheet_name = f"Table_{table_count}"
+                                    tables.append((sheet_name, df))
+
+                            # Reset for next table
+                            in_table = False
+                            current_table_lines = []
+                            current_table_title = None
+
+                            # Check if this line could be a title for the next table
+                            if stripped:
+                                lower_stripped = stripped.lower()
+                                if (lower_stripped.startswith('table') or
+                                        stripped.startswith('#') or
+                                        ':' in stripped):
+                                    pending_title = stripped.lstrip('#').strip()
+
+                    # Don't forget the last table if content ends with a table
+                    if in_table and current_table_lines:
+                        df = self._parse_markdown_table_lines(current_table_lines)
+                        if df is not None and not df.empty:
+                            table_count += 1
+                            if current_table_title:
+                                sheet_name = current_table_title[:31]
+                                sheet_name = sheet_name.replace('/', '_').replace('\\', '_').replace('*', '').replace('?', '').replace('[', '').replace(']', '').replace(':', '')
+                            else:
+                                sheet_name = f"Table_{table_count}"
+                            tables.append((sheet_name, df))
+
+                # Fallback: if no tables found, create single sheet with content
+                if not tables:
+                    df = pd.DataFrame({"Content": [content if content else ""]})
+                    tables = [("Data", df)]
+
+                # Write all tables to Excel, each in a separate sheet
+                excel_io = BytesIO()
+                with pd.ExcelWriter(excel_io, engine='openpyxl') as writer:
+                    used_names = set()
+                    for sheet_name, df in tables:
+                        # Ensure unique sheet names
+                        original_name = sheet_name
+                        counter = 1
+                        while sheet_name in used_names:
+                            suffix = f"_{counter}"
+                            sheet_name = original_name[:31 - len(suffix)] + suffix
+                            counter += 1
+                        used_names.add(sheet_name)
+                        df.to_excel(writer, sheet_name=sheet_name, index=False)
+
+                excel_io.seek(0)
+                binary_content = excel_io.read()
+
+                logging.info(f"Generated Excel with {len(tables)} sheet(s): {[t[0] for t in tables]}")
+
             else:  # pdf, docx
                 with tempfile.NamedTemporaryFile(suffix=f".{self._param.output_format}", delete=False) as tmp:
                     tmp_name = tmp.name
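
To make the xlsx branch above easier to follow, here is a reduced sketch of the same idea: split a Markdown table into header and data rows, skip the separator row, and write the result to an in-memory workbook. It assumes pandas and openpyxl are available, as the component already does; the table content is made up for illustration.

from io import BytesIO
import pandas as pd

md = """| name | score |
|------|-------|
| a    | 1     |
| b    | 2     |"""

rows = []
for line in md.strip().split("\n"):
    cells = [c.strip() for c in line.split("|") if c.strip()]
    # Skip the |---|---| separator row.
    if cells and set("".join(cells)) <= set("-:"):
        continue
    rows.append(cells)

df = pd.DataFrame(rows[1:], columns=rows[0])
buf = BytesIO()
with pd.ExcelWriter(buf, engine="openpyxl") as writer:
    df.to_excel(writer, sheet_name="Table_1", index=False)
print(len(buf.getvalue()), "bytes of xlsx")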
@@ -193,7 +193,7 @@
         "presence_penalty": 0.4,
         "prompts": [
           {
-            "content": "Text Content:\n{Splitter:KindDingosJam@chunks}\n",
+            "content": "Text Content:\n{Splitter:NineTiesSin@chunks}\n",
             "role": "user"
           }
         ],
@@ -226,7 +226,7 @@
         "presence_penalty": 0.4,
         "prompts": [
           {
-            "content": "Text Content:\n\n{Splitter:KindDingosJam@chunks}\n",
+            "content": "Text Content:\n\n{Splitter:TastyPointsLay@chunks}\n",
             "role": "user"
           }
         ],
@@ -259,7 +259,7 @@
         "presence_penalty": 0.4,
         "prompts": [
           {
-            "content": "Content: \n\n{Splitter:KindDingosJam@chunks}",
+            "content": "Content: \n\n{Splitter:CuteBusesBet@chunks}",
             "role": "user"
           }
         ],
@@ -485,7 +485,7 @@
         "outputs": {},
         "presencePenaltyEnabled": false,
         "presence_penalty": 0.4,
-        "prompts": "Text Content:\n{Splitter:KindDingosJam@chunks}\n",
+        "prompts": "Text Content:\n{Splitter:NineTiesSin@chunks}\n",
         "sys_prompt": "Role\nYou are a text analyzer.\n\nTask\nExtract the most important keywords/phrases of a given piece of text content.\n\nRequirements\n- Summarize the text content, and give the top 5 important keywords/phrases.\n- The keywords MUST be in the same language as the given piece of text content.\n- The keywords are delimited by ENGLISH COMMA.\n- Output keywords ONLY.",
         "temperature": 0.1,
         "temperatureEnabled": false,
@@ -522,7 +522,7 @@
         "outputs": {},
         "presencePenaltyEnabled": false,
         "presence_penalty": 0.4,
-        "prompts": "Text Content:\n\n{Splitter:KindDingosJam@chunks}\n",
+        "prompts": "Text Content:\n\n{Splitter:TastyPointsLay@chunks}\n",
         "sys_prompt": "Role\nYou are a text analyzer.\n\nTask\nPropose 3 questions about a given piece of text content.\n\nRequirements\n- Understand and summarize the text content, and propose the top 3 important questions.\n- The questions SHOULD NOT have overlapping meanings.\n- The questions SHOULD cover the main content of the text as much as possible.\n- The questions MUST be in the same language as the given piece of text content.\n- One question per line.\n- Output questions ONLY.",
         "temperature": 0.1,
         "temperatureEnabled": false,
@@ -559,7 +559,7 @@
         "outputs": {},
         "presencePenaltyEnabled": false,
         "presence_penalty": 0.4,
-        "prompts": "Content: \n\n{Splitter:KindDingosJam@chunks}",
+        "prompts": "Content: \n\n{Splitter:BlueResultsWink@chunks}",
         "sys_prompt": "Extract important structured information from the given content. Output ONLY a valid JSON string with no additional text. If no important structured information is found, output an empty JSON object: {}.\n\nImportant structured information may include: names, dates, locations, events, key facts, numerical data, or other extractable entities.",
         "temperature": 0.1,
         "temperatureEnabled": false,
@@ -578,7 +578,7 @@
     {
       "data": {
         "form": {
-          "text": "Searches for relevant database creation statements.\n\nIt should label with a knowledgebase to which the schema is dumped in. You could use \" General \" as parsing method, \" 2 \" as chunk size and \" ; \" as delimiter."
+          "text": "Searches for relevant database creation statements.\n\nIt should label with a dataset to which the schema is dumped in. You could use \" General \" as parsing method, \" 2 \" as chunk size and \" ; \" as delimiter."
         },
         "label": "Note",
         "name": "Note Schema"
@@ -75,7 +75,7 @@
     },
     "history": [],
     "path": [],
-    "retrival": {"chunks": [], "doc_aggs": []},
+    "retrieval": {"chunks": [], "doc_aggs": []},
     "globals": {
       "sys.query": "",
       "sys.user_id": "",
@@ -82,7 +82,7 @@
     },
     "history": [],
     "path": [],
-    "retrival": {"chunks": [], "doc_aggs": []},
+    "retrieval": {"chunks": [], "doc_aggs": []},
     "globals": {
       "sys.query": "",
       "sys.user_id": "",
@@ -31,7 +31,7 @@
       "component_name": "LLM",
       "params": {
         "llm_id": "deepseek-chat",
-        "sys_prompt": "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\" Answers need to consider chat history.\n Here is the knowledge base:\n {retrieval:0@formalized_content}\n The above is the knowledge base.",
+        "sys_prompt": "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\" Answers need to consider chat history.\n Here is the knowledge base:\n {retrieval:0@formalized_content}\n Above is the knowledge base.",
         "temperature": 0.2
       }
     },
@@ -51,7 +51,7 @@
     },
     "history": [],
     "path": [],
-    "retrival": {"chunks": [], "doc_aggs": []},
+    "retrieval": {"chunks": [], "doc_aggs": []},
     "globals": {
       "sys.query": "",
       "sys.user_id": "",
@@ -65,7 +65,7 @@
       "component_name": "Agent",
       "params": {
         "llm_id": "deepseek-chat",
-        "sys_prompt": "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\" Answers need to consider chat history.\n Here is the knowledge base:\n {retrieval:0@formalized_content}\n The above is the knowledge base.",
+        "sys_prompt": "You are an intelligent assistant. Please summarize the content of the dataset to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\" Answers need to consider chat history.\n Here is the knowledge base:\n {retrieval:0@formalized_content}\n The above is the knowledge base.",
         "temperature": 0.2
       }
     },
@@ -85,7 +85,7 @@
     },
     "history": [],
     "path": [],
-    "retrival": {"chunks": [], "doc_aggs": []},
+    "retrieval": {"chunks": [], "doc_aggs": []},
     "globals": {
       "sys.query": "",
       "sys.user_id": "",
@@ -25,7 +25,7 @@
       "component_name": "LLM",
       "params": {
         "llm_id": "deepseek-chat",
-        "sys_prompt": "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\" Answers need to consider chat history.\n Here is the knowledge base:\n {tavily:0@formalized_content}\n The above is the knowledge base.",
+        "sys_prompt": "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\" Answers need to consider chat history.\n Here is the knowledge base:\n {tavily:0@formalized_content}\n Above is the knowledge base.",
         "temperature": 0.2
       }
     },
@@ -45,7 +45,7 @@
     },
     "history": [],
     "path": [],
-    "retrival": {"chunks": [], "doc_aggs": []},
+    "retrieval": {"chunks": [], "doc_aggs": []},
     "globals": {
       "sys.query": "",
       "sys.user_id": "",
|
|||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
import asyncio
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import TypedDict, List, Any
|
from typing import TypedDict, List, Any
|
||||||
from agent.component.base import ComponentParamBase, ComponentBase
|
from agent.component.base import ComponentParamBase, ComponentBase
|
||||||
@ -48,12 +49,19 @@ class LLMToolPluginCallSession(ToolCallSession):
|
|||||||
self.callback = callback
|
self.callback = callback
|
||||||
|
|
||||||
def tool_call(self, name: str, arguments: dict[str, Any]) -> Any:
|
def tool_call(self, name: str, arguments: dict[str, Any]) -> Any:
|
||||||
|
return asyncio.run(self.tool_call_async(name, arguments))
|
||||||
|
|
||||||
|
async def tool_call_async(self, name: str, arguments: dict[str, Any]) -> Any:
|
||||||
assert name in self.tools_map, f"LLM tool {name} does not exist"
|
assert name in self.tools_map, f"LLM tool {name} does not exist"
|
||||||
st = timer()
|
st = timer()
|
||||||
if isinstance(self.tools_map[name], MCPToolCallSession):
|
tool_obj = self.tools_map[name]
|
||||||
resp = self.tools_map[name].tool_call(name, arguments, 60)
|
if isinstance(tool_obj, MCPToolCallSession):
|
||||||
|
resp = await asyncio.to_thread(tool_obj.tool_call, name, arguments, 60)
|
||||||
else:
|
else:
|
||||||
resp = self.tools_map[name].invoke(**arguments)
|
if hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async):
|
||||||
|
resp = await tool_obj.invoke_async(**arguments)
|
||||||
|
else:
|
||||||
|
resp = await asyncio.to_thread(tool_obj.invoke, **arguments)
|
||||||
|
|
||||||
self.callback(name, arguments, resp, elapsed_time=timer()-st)
|
self.callback(name, arguments, resp, elapsed_time=timer()-st)
|
||||||
return resp
|
return resp
|
||||||
@ -139,6 +147,33 @@ class ToolBase(ComponentBase):
|
|||||||
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
async def invoke_async(self, **kwargs):
|
||||||
|
"""
|
||||||
|
Async wrapper for tool invocation.
|
||||||
|
If `_invoke` is a coroutine, await it directly; otherwise run in a thread to avoid blocking.
|
||||||
|
Mirrors the exception handling of `invoke`.
|
||||||
|
"""
|
||||||
|
if self.check_if_canceled("Tool processing"):
|
||||||
|
return
|
||||||
|
|
||||||
|
self.set_output("_created_time", time.perf_counter())
|
||||||
|
try:
|
||||||
|
fn_async = getattr(self, "_invoke_async", None)
|
||||||
|
if fn_async and asyncio.iscoroutinefunction(fn_async):
|
||||||
|
res = await fn_async(**kwargs)
|
||||||
|
elif asyncio.iscoroutinefunction(self._invoke):
|
||||||
|
res = await self._invoke(**kwargs)
|
||||||
|
else:
|
||||||
|
res = await asyncio.to_thread(self._invoke, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
self._param.outputs["_ERROR"] = {"value": str(e)}
|
||||||
|
logging.exception(e)
|
||||||
|
res = str(e)
|
||||||
|
self._param.debug_inputs = []
|
||||||
|
|
||||||
|
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||||
|
return res
|
||||||
|
|
||||||
def _retrieve_chunks(self, res_list: list, get_title, get_url, get_content, get_score=None):
|
def _retrieve_chunks(self, res_list: list, get_title, get_url, get_content, get_score=None):
|
||||||
chunks = []
|
chunks = []
|
||||||
aggs = []
|
aggs = []
|
||||||
|
|||||||
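
The dispatch pattern introduced in invoke_async above — await a coroutine _invoke_async when the tool provides one, otherwise push the blocking _invoke onto a worker thread with asyncio.to_thread — reduces to this standalone sketch; the two toy tool classes are illustrative only.

import asyncio

class BlockingTool:
    def _invoke(self, **kwargs):
        return f"sync result: {kwargs}"

class AsyncTool:
    async def _invoke_async(self, **kwargs):
        return f"async result: {kwargs}"

async def call_tool(tool, **kwargs):
    fn_async = getattr(tool, "_invoke_async", None)
    if fn_async and asyncio.iscoroutinefunction(fn_async):
        return await fn_async(**kwargs)
    # Blocking implementations run in a thread so the event loop stays responsive.
    return await asyncio.to_thread(tool._invoke, **kwargs)

print(asyncio.run(call_tool(BlockingTool(), x=1)))
print(asyncio.run(call_tool(AsyncTool(), x=2)))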
@@ -69,7 +69,7 @@ class CodeExecParam(ToolParamBase):
         self.meta: ToolMeta = {
             "name": "execute_code",
             "description": """
-This tool has a sandbox that can execute code written in 'Python'/'Javascript'. It recieves a piece of code and return a Json string.
+This tool has a sandbox that can execute code written in 'Python'/'Javascript'. It receives a piece of code and return a Json string.
 Here's a code example for Python(`main` function MUST be included):
 def main() -> dict:
     \"\"\"
@ -13,6 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import asyncio
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
@ -21,13 +22,13 @@ from abc import ABC
|
|||||||
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
|
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
|
||||||
from common.constants import LLMType
|
from common.constants import LLMType
|
||||||
from api.db.services.document_service import DocumentService
|
 from api.db.services.document_service import DocumentService
-from api.db.services.dialog_service import meta_filter
+from common.metadata_utils import apply_meta_data_filter
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
 from common import settings
 from common.connection_utils import timeout
 from rag.app.tag import label_question
-from rag.prompts.generator import cross_languages, kb_prompt, gen_meta_filter
+from rag.prompts.generator import cross_languages, kb_prompt


 class RetrievalParam(ToolParamBase):

@@ -81,7 +82,7 @@ class Retrieval(ToolBase, ABC):
 component_name = "Retrieval"

 @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
-def _invoke(self, **kwargs):
+async def _invoke_async(self, **kwargs):
 if self.check_if_canceled("Retrieval processing"):
 return

@@ -130,17 +131,10 @@ class Retrieval(ToolBase, ABC):
 doc_ids=[]
 if self._param.meta_data_filter!={}:
 metas = DocumentService.get_meta_by_kbs(kb_ids)
-if self._param.meta_data_filter.get("method") == "auto":
-chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT)
-filters: dict = gen_meta_filter(chat_mdl, metas, query)
-doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
-if not doc_ids:
-doc_ids = None
-elif self._param.meta_data_filter.get("method") == "manual":
-filters = self._param.meta_data_filter["manual"]
-for flt in filters:
+def _resolve_manual_filter(flt: dict) -> dict:
 pat = re.compile(self.variable_ref_patt)
-s = flt["value"]
+s = flt.get("value", "")
 out_parts = []
 last = 0

@@ -165,12 +159,23 @@ class Retrieval(ToolBase, ABC):

 out_parts.append(s[last:])
 flt["value"] = "".join(out_parts)
-doc_ids.extend(meta_filter(metas, filters, self._param.meta_data_filter.get("logic", "and")))
-if filters and not doc_ids:
-doc_ids = ["-999"]
+return flt
+
+chat_mdl = None
+if self._param.meta_data_filter.get("method") in ["auto", "semi_auto"]:
+chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT)
+
+doc_ids = await apply_meta_data_filter(
+self._param.meta_data_filter,
+metas,
+query,
+chat_mdl,
+doc_ids,
+_resolve_manual_filter if self._param.meta_data_filter.get("method") == "manual" else None,
+)

 if self._param.cross_languages:
-query = cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages)
+query = await cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages)

 if kbs:
 query = re.sub(r"^user[::\s]*", "", query, flags=re.IGNORECASE)

@@ -198,6 +203,7 @@ class Retrieval(ToolBase, ABC):
 return
 if cks:
 kbinfos["chunks"] = cks
+kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"], [kb.tenant_id for kb in kbs])
 if self._param.use_kg:
 ck = settings.kg_retriever.retrieval(query,
 [kb.tenant_id for kb in kbs],

@@ -242,6 +248,10 @@ class Retrieval(ToolBase, ABC):

 return form_cnt

+@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
+def _invoke(self, **kwargs):
+return asyncio.run(self._invoke_async(**kwargs))
+
 def thoughts(self) -> str:
 return """
 Keywords: {}
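The hunks above collapse the old inline auto/manual branches into one call to `apply_meta_data_filter` from `common.metadata_utils`, with `_resolve_manual_filter` passed in only for the manual mode. The helper's body is not part of this diff; the following is a minimal sketch only, assuming it reuses the pre-existing `gen_meta_filter`/`meta_filter` utilities and the `(filter_spec, metas, query, chat_mdl, doc_ids, resolver)` argument order seen at the call sites.

```python
# Hypothetical sketch: the real common.metadata_utils.apply_meta_data_filter is not shown in this diff.
from rag.prompts.generator import gen_meta_filter          # utilities the old inline code relied on
from api.db.services.dialog_service import meta_filter


async def apply_meta_data_filter(meta_data_filter, metas, query, chat_mdl, doc_ids, resolve_manual=None):
    method = meta_data_filter.get("method")
    if method in ("auto", "semi_auto") and chat_mdl is not None:
        # Let the chat model derive filter conditions from the query and the documents' metadata.
        filters = gen_meta_filter(chat_mdl, metas, query)
        doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
        return doc_ids or None                              # None means "no restriction"
    if method == "manual":
        manual = meta_data_filter.get("manual", [])
        if resolve_manual:
            manual = [resolve_manual(flt) for flt in manual]   # expand variable references first
        doc_ids.extend(meta_filter(metas, manual, meta_data_filter.get("logic", "and")))
        return doc_ids or ["-999"]                          # sentinel id: manual filter matched nothing
    return doc_ids
```

Both call sites await the helper, which is why `chat_mdl` is only constructed for the auto and semi_auto modes and `None` is passed otherwise.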
@@ -75,7 +75,7 @@ class YahooFinance(ToolBase, ABC):
 @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
 def _invoke(self, **kwargs):
 if self.check_if_canceled("YahooFinance processing"):
-return
+return None

 if not kwargs.get("stock_code"):
 self.set_output("report", "")

@@ -84,33 +84,33 @@ class YahooFinance(ToolBase, ABC):
 last_e = ""
 for _ in range(self._param.max_retries+1):
 if self.check_if_canceled("YahooFinance processing"):
-return
+return None

-yohoo_res = []
+yahoo_res = []
 try:
 msft = yf.Ticker(kwargs["stock_code"])
 if self.check_if_canceled("YahooFinance processing"):
-return
+return None

 if self._param.info:
-yohoo_res.append("# Information:\n" + pd.Series(msft.info).to_markdown() + "\n")
+yahoo_res.append("# Information:\n" + pd.Series(msft.info).to_markdown() + "\n")
 if self._param.history:
-yohoo_res.append("# History:\n" + msft.history().to_markdown() + "\n")
+yahoo_res.append("# History:\n" + msft.history().to_markdown() + "\n")
 if self._param.financials:
-yohoo_res.append("# Calendar:\n" + pd.DataFrame(msft.calendar).to_markdown() + "\n")
+yahoo_res.append("# Calendar:\n" + pd.DataFrame(msft.calendar).to_markdown() + "\n")
 if self._param.balance_sheet:
-yohoo_res.append("# Balance sheet:\n" + msft.balance_sheet.to_markdown() + "\n")
+yahoo_res.append("# Balance sheet:\n" + msft.balance_sheet.to_markdown() + "\n")
-yohoo_res.append("# Quarterly balance sheet:\n" + msft.quarterly_balance_sheet.to_markdown() + "\n")
+yahoo_res.append("# Quarterly balance sheet:\n" + msft.quarterly_balance_sheet.to_markdown() + "\n")
 if self._param.cash_flow_statement:
-yohoo_res.append("# Cash flow statement:\n" + msft.cashflow.to_markdown() + "\n")
+yahoo_res.append("# Cash flow statement:\n" + msft.cashflow.to_markdown() + "\n")
-yohoo_res.append("# Quarterly cash flow statement:\n" + msft.quarterly_cashflow.to_markdown() + "\n")
+yahoo_res.append("# Quarterly cash flow statement:\n" + msft.quarterly_cashflow.to_markdown() + "\n")
 if self._param.news:
-yohoo_res.append("# News:\n" + pd.DataFrame(msft.news).to_markdown() + "\n")
+yahoo_res.append("# News:\n" + pd.DataFrame(msft.news).to_markdown() + "\n")
-self.set_output("report", "\n\n".join(yohoo_res))
+self.set_output("report", "\n\n".join(yahoo_res))
 return self.output("report")
 except Exception as e:
 if self.check_if_canceled("YahooFinance processing"):
-return
+return None

 last_e = e
 logging.exception(f"YahooFinance error: {e}")
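Apart from switching the bare `return` statements to `return None`, the YahooFinance hunks only rename the misspelled `yohoo_res` list to `yahoo_res`; the retry structure is unchanged. For reference, the accumulate-and-retry shape looks roughly like this when reduced to the documented `yfinance.Ticker` calls (the function name and parameters here are illustrative, not the component's own):

```python
# Illustrative only: a minimal retry-and-accumulate loop around yfinance, mirroring the pattern above.
import logging
import pandas as pd
import yfinance as yf


def fetch_report(stock_code: str, max_retries: int = 2) -> str | None:
    last_e = None
    for _ in range(max_retries + 1):
        sections = []
        try:
            ticker = yf.Ticker(stock_code)
            sections.append("# Information:\n" + pd.Series(ticker.info).to_markdown() + "\n")
            sections.append("# History:\n" + ticker.history().to_markdown() + "\n")
            return "\n\n".join(sections)
        except Exception as e:
            last_e = e
            logging.exception(f"YahooFinance error: {e}")
    return f"YahooFinance error: {last_e}" if last_e else None
```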
@@ -51,7 +51,7 @@ class DeepResearcher:
 """Remove Result Tags"""
 return DeepResearcher._remove_tags(text, BEGIN_SEARCH_RESULT, END_SEARCH_RESULT)

-def _generate_reasoning(self, msg_history):
+async def _generate_reasoning(self, msg_history):
 """Generate reasoning steps"""
 query_think = ""
 if msg_history[-1]["role"] != "user":

@@ -59,13 +59,14 @@ class DeepResearcher:
 else:
 msg_history[-1]["content"] += "\n\nContinues reasoning with the new information.\n"

-for ans in self.chat_mdl.chat_streamly(REASON_PROMPT, msg_history, {"temperature": 0.7}):
+async for ans in self.chat_mdl.async_chat_streamly(REASON_PROMPT, msg_history, {"temperature": 0.7}):
 ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
 if not ans:
 continue
 query_think = ans
 yield query_think
-return query_think
+query_think = ""
+yield query_think

 def _extract_search_queries(self, query_think, question, step_index):
 """Extract search queries from thinking"""

@@ -143,10 +144,10 @@ class DeepResearcher:
 if d["doc_id"] not in dids:
 chunk_info["doc_aggs"].append(d)

-def _extract_relevant_info(self, truncated_prev_reasoning, search_query, kbinfos):
+async def _extract_relevant_info(self, truncated_prev_reasoning, search_query, kbinfos):
 """Extract and summarize relevant information"""
 summary_think = ""
-for ans in self.chat_mdl.chat_streamly(
+async for ans in self.chat_mdl.async_chat_streamly(
 RELEVANT_EXTRACTION_PROMPT.format(
 prev_reasoning=truncated_prev_reasoning,
 search_query=search_query,

@@ -160,10 +161,11 @@ class DeepResearcher:
 continue
 summary_think = ans
 yield summary_think
+summary_think = ""

-return summary_think
+yield summary_think

-def thinking(self, chunk_info: dict, question: str):
+async def thinking(self, chunk_info: dict, question: str):
 executed_search_queries = []
 msg_history = [{"role": "user", "content": f'Question:\"{question}\"\n'}]
 all_reasoning_steps = []

@@ -180,7 +182,7 @@ class DeepResearcher:

 # Step 1: Generate reasoning
 query_think = ""
-for ans in self._generate_reasoning(msg_history):
+async for ans in self._generate_reasoning(msg_history):
 query_think = ans
 yield {"answer": think + self._remove_query_tags(query_think) + "</think>", "reference": {}, "audio_binary": None}

@@ -223,7 +225,7 @@ class DeepResearcher:
 # Step 6: Extract relevant information
 think += "\n\n"
 summary_think = ""
-for ans in self._extract_relevant_info(truncated_prev_reasoning, search_query, kbinfos):
+async for ans in self._extract_relevant_info(truncated_prev_reasoning, search_query, kbinfos):
 summary_think = ans
 yield {"answer": think + self._remove_result_tags(summary_think) + "</think>", "reference": {}, "audio_binary": None}
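The DeepResearcher hunks turn the streaming helpers into async generators driven by `async_chat_streamly`. Because an async generator cannot `return` a value (PEP 525), the old trailing `return query_think` / `return summary_think` statements become a final `yield`, and callers simply keep the last item they receive. A reduced sketch of that pattern, with a hypothetical model wrapper standing in for `chat_mdl`:

```python
# Reduced sketch of the async-generator streaming pattern above; `model` is a stand-in object.
async def stream_think(model, prompt, history):
    think = ""
    async for ans in model.async_chat_streamly(prompt, history, {"temperature": 0.7}):
        if not ans:
            continue
        think = ans
        yield think          # stream partial output to the caller
    yield think              # final yield replaces "return think": async generators cannot return values


async def consume(model, prompt, history):
    final = ""
    async for chunk in stream_think(model, prompt, history):
        final = chunk        # the caller keeps the last yielded value as the result
    return final
```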
@@ -14,5 +14,5 @@
 # limitations under the License.
 #

-from beartype.claw import beartype_this_package
+# from beartype.claw import beartype_this_package
-beartype_this_package()
+# beartype_this_package()
@@ -13,13 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import logging
 import os
 import sys
-import logging
 from importlib.util import module_from_spec, spec_from_file_location
 from pathlib import Path
 from quart import Blueprint, Quart, request, g, current_app, session
-from werkzeug.wrappers.request import Request
 from flasgger import Swagger
 from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
 from quart_cors import cors

@@ -29,7 +28,6 @@ from api.db.services import UserService
 from api.utils.json_encode import CustomJSONEncoder
 from api.utils import commands

-from flask_mail import Mail
 from quart_auth import Unauthorized
 from common import settings
 from api.utils.api_utils import server_error_response

@@ -40,11 +38,9 @@ settings.init_settings()

 __all__ = ["app"]

-Request.json = property(lambda self: self.get_json(force=True, silent=True))

 app = Quart(__name__)
 app = cors(app, allow_origin="*")
-smtp_mail_server = Mail()

 # Add this at the beginning of your file to configure Swagger UI
 swagger_config = {

@@ -82,6 +78,11 @@ app.url_map.strict_slashes = False
 app.json_encoder = CustomJSONEncoder
 app.errorhandler(Exception)(server_error_response)

+# Configure Quart timeouts for slow LLM responses (e.g., local Ollama on CPU)
+# Default Quart timeouts are 60 seconds which is too short for many LLM backends
+app.config["RESPONSE_TIMEOUT"] = int(os.environ.get("QUART_RESPONSE_TIMEOUT", 600))
+app.config["BODY_TIMEOUT"] = int(os.environ.get("QUART_BODY_TIMEOUT", 600))
+
 ## convince for dev and debug
 # app.config["LOGIN_DISABLED"] = True
 app.config["SESSION_PERMANENT"] = False
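The app bootstrap drops the leftover Flask-Mail and werkzeug `Request` patching and raises Quart's response/body timeouts, which default to 60 seconds, so long-running LLM calls behind a route do not get cut off. The two config keys can be exercised in isolation like this (the route name is illustrative):

```python
# Standalone reproduction of the timeout settings above; "/slow" is an illustrative route.
import asyncio
import os

from quart import Quart

app = Quart(__name__)
app.config["RESPONSE_TIMEOUT"] = int(os.environ.get("QUART_RESPONSE_TIMEOUT", 600))  # seconds
app.config["BODY_TIMEOUT"] = int(os.environ.get("QUART_BODY_TIMEOUT", 600))


@app.route("/slow")
async def slow():
    await asyncio.sleep(90)   # would exceed the 60 s default, but fits within the raised limit
    return "ok"
```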
@@ -18,8 +18,7 @@ from quart import request
 from api.db.db_models import APIToken
 from api.db.services.api_service import APITokenService, API4ConversationService
 from api.db.services.user_service import UserTenantService
-from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
-    generate_confirmation_token
+from api.utils.api_utils import generate_confirmation_token, get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
 from common.time_utils import current_timestamp, datetime_format
 from api.apps import login_required, current_user

@@ -27,7 +26,7 @@ from api.apps import login_required, current_user
 @manager.route('/new_token', methods=['POST']) # noqa: F821
 @login_required
 async def new_token():
-req = await request.json
+req = await get_request_json()
 try:
 tenants = UserTenantService.query(user_id=current_user.id)
 if not tenants:

@@ -73,7 +72,7 @@ def token_list():
 @validate_request("tokens", "tenant_id")
 @login_required
 async def rm():
-req = await request.json
+req = await get_request_json()
 try:
 for token in req["tokens"]:
 APITokenService.filter_delete(

@@ -116,4 +115,3 @@ def stats():
 return get_json_result(data=res)
 except Exception as e:
 return server_error_response(e)
-
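These endpoints stop reading `await request.json` directly and go through `get_request_json()` from `api.utils.api_utils` instead. Its implementation is not shown in this diff; a plausible minimal form, assuming it simply wraps Quart's `request.get_json` with forgiving defaults, would be:

```python
# Hypothetical sketch; the real api.utils.api_utils.get_request_json is not part of this diff.
from quart import request


async def get_request_json() -> dict:
    # force=True parses even without a JSON content type; silent=True yields None instead of raising.
    data = await request.get_json(force=True, silent=True)
    return data or {}
```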
@@ -14,7 +14,7 @@
 # limitations under the License.
 #

-import requests
+from common.http_client import async_request, sync_request
 from .oauth import OAuthClient, UserInfo


@@ -34,24 +34,49 @@ class GithubOAuthClient(OAuthClient):

 def fetch_user_info(self, access_token, **kwargs):
 """
-Fetch GitHub user info.
+Fetch GitHub user info (synchronous).
 """
 user_info = {}
 try:
 headers = {"Authorization": f"Bearer {access_token}"}
-# user info
-response = requests.get(self.userinfo_url, headers=headers, timeout=self.http_request_timeout)
+response = sync_request("GET", self.userinfo_url, headers=headers, timeout=self.http_request_timeout)
 response.raise_for_status()
 user_info.update(response.json())
-# email info
-response = requests.get(self.userinfo_url+"/emails", headers=headers, timeout=self.http_request_timeout)
-response.raise_for_status()
-email_info = response.json()
-user_info["email"] = next(
-(email for email in email_info if email["primary"]), None
-)["email"]
+email_response = sync_request(
+"GET", self.userinfo_url + "/emails", headers=headers, timeout=self.http_request_timeout
+)
+email_response.raise_for_status()
+email_info = email_response.json()
+user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
 return self.normalize_user_info(user_info)
-except requests.exceptions.RequestException as e:
+except Exception as e:
+raise ValueError(f"Failed to fetch github user info: {e}")
+
+async def async_fetch_user_info(self, access_token, **kwargs):
+"""Async variant of fetch_user_info using httpx."""
+user_info = {}
+headers = {"Authorization": f"Bearer {access_token}"}
+try:
+response = await async_request(
+"GET",
+self.userinfo_url,
+headers=headers,
+timeout=self.http_request_timeout,
+)
+response.raise_for_status()
+user_info.update(response.json())
+
+email_response = await async_request(
+"GET",
+self.userinfo_url + "/emails",
+headers=headers,
+timeout=self.http_request_timeout,
+)
+email_response.raise_for_status()
+email_info = email_response.json()
+user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
+return self.normalize_user_info(user_info)
+except Exception as e:
 raise ValueError(f"Failed to fetch github user info: {e}")
@@ -14,8 +14,8 @@
 # limitations under the License.
 #

-import requests
 import urllib.parse
+from common.http_client import async_request, sync_request


 class UserInfo:

@@ -74,15 +74,40 @@ class OAuthClient:
 "redirect_uri": self.redirect_uri,
 "grant_type": "authorization_code"
 }
-response = requests.post(
+response = sync_request(
+"POST",
 self.token_url,
 data=payload,
 headers={"Accept": "application/json"},
-timeout=self.http_request_timeout
+timeout=self.http_request_timeout,
 )
 response.raise_for_status()
 return response.json()
-except requests.exceptions.RequestException as e:
+except Exception as e:
+raise ValueError(f"Failed to exchange authorization code for token: {e}")
+
+async def async_exchange_code_for_token(self, code):
+"""
+Async variant of exchange_code_for_token using httpx.
+"""
+payload = {
+"client_id": self.client_id,
+"client_secret": self.client_secret,
+"code": code,
+"redirect_uri": self.redirect_uri,
+"grant_type": "authorization_code",
+}
+try:
+response = await async_request(
+"POST",
+self.token_url,
+data=payload,
+headers={"Accept": "application/json"},
+timeout=self.http_request_timeout,
+)
+response.raise_for_status()
+return response.json()
+except Exception as e:
 raise ValueError(f"Failed to exchange authorization code for token: {e}")


@@ -92,11 +117,27 @@ class OAuthClient:
 """
 try:
 headers = {"Authorization": f"Bearer {access_token}"}
-response = requests.get(self.userinfo_url, headers=headers, timeout=self.http_request_timeout)
+response = sync_request("GET", self.userinfo_url, headers=headers, timeout=self.http_request_timeout)
 response.raise_for_status()
 user_info = response.json()
 return self.normalize_user_info(user_info)
-except requests.exceptions.RequestException as e:
+except Exception as e:
+raise ValueError(f"Failed to fetch user info: {e}")
+
+async def async_fetch_user_info(self, access_token, **kwargs):
+"""Async variant of fetch_user_info using httpx."""
+headers = {"Authorization": f"Bearer {access_token}"}
+try:
+response = await async_request(
+"GET",
+self.userinfo_url,
+headers=headers,
+timeout=self.http_request_timeout,
+)
+response.raise_for_status()
+user_info = response.json()
+return self.normalize_user_info(user_info)
+except Exception as e:
 raise ValueError(f"Failed to fetch user info: {e}")
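Every OAuth client now calls `sync_request` / `async_request` from `common.http_client` rather than `requests` directly, and the new docstrings say the async variants use httpx. The wrappers themselves are outside this diff; judging from the call sites (a method, a URL, plus `headers` / `data` / `timeout` keywords, returning an object with `raise_for_status()` and `json()`), a minimal stand-in could look like this:

```python
# Hypothetical stand-in for common.http_client, inferred from the call sites above.
import httpx


def sync_request(method: str, url: str, **kwargs) -> httpx.Response:
    # Synchronous helper: open a short-lived client and forward headers/data/timeout as-is.
    with httpx.Client() as client:
        return client.request(method, url, **kwargs)


async def async_request(method: str, url: str, **kwargs) -> httpx.Response:
    # Async counterpart used by the async_* OAuth methods.
    async with httpx.AsyncClient() as client:
        return await client.request(method, url, **kwargs)
```

Hiding the HTTP client behind one wrapper also fits the widened `except Exception` clauses above, since the concrete error type now depends on the underlying client rather than on `requests`.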
@@ -15,7 +15,7 @@
 #

 import jwt
-import requests
+from common.http_client import sync_request
 from .oauth import OAuthClient


@@ -50,10 +50,10 @@ class OIDCClient(OAuthClient):
 """
 try:
 metadata_url = f"{issuer}/.well-known/openid-configuration"
-response = requests.get(metadata_url, timeout=7)
+response = sync_request("GET", metadata_url, timeout=7)
 response.raise_for_status()
 return response.json()
-except requests.exceptions.RequestException as e:
+except Exception as e:
 raise ValueError(f"Failed to fetch OIDC metadata: {e}")


@@ -95,6 +95,13 @@ class OIDCClient(OAuthClient):
 user_info.update(super().fetch_user_info(access_token).to_dict())
 return self.normalize_user_info(user_info)

+async def async_fetch_user_info(self, access_token, id_token=None, **kwargs):
+user_info = {}
+if id_token:
+user_info = self.parse_id_token(id_token)
+user_info.update((await super().async_fetch_user_info(access_token)).to_dict())
+return self.normalize_user_info(user_info)
+
 def normalize_user_info(self, user_info):
 return super().normalize_user_info(user_info)
@@ -13,15 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
+import inspect
 import json
 import logging
-import re
-import sys
 from functools import partial
-import trio
 from quart import request, Response, make_response
 from agent.component import LLM
-from api.db import CanvasCategory, FileType
+from api.db import CanvasCategory
 from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
 from api.db.services.document_service import DocumentService
 from api.db.services.file_service import FileService

@@ -32,13 +31,12 @@ from api.db.services.user_canvas_version import UserCanvasVersionService
 from common.constants import RetCode
 from common.misc_utils import get_uuid
 from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result, \
-    request_json
+    get_request_json
 from agent.canvas import Canvas
 from peewee import MySQLDatabase, PostgresqlDatabase
 from api.db.db_models import APIToken, Task
 import time

-from api.utils.file_utils import filename_type, read_potential_broken_pdf
 from rag.flow.pipeline import Pipeline
 from rag.nlp import search
 from rag.utils.redis_conn import REDIS_CONN

@@ -56,7 +54,7 @@ def templates():
 @validate_request("canvas_ids")
 @login_required
 async def rm():
-req = await request_json()
+req = await get_request_json()
 for i in req["canvas_ids"]:
 if not UserCanvasService.accessible(i, current_user.id):
 return get_json_result(

@@ -70,7 +68,7 @@ async def rm():
 @validate_request("dsl", "title")
 @login_required
 async def save():
-req = await request_json()
+req = await get_request_json()
 if not isinstance(req["dsl"], str):
 req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
 req["dsl"] = json.loads(req["dsl"])

@@ -129,17 +127,17 @@ def getsse(canvas_id):
 @validate_request("id")
 @login_required
 async def run():
-req = await request_json()
+req = await get_request_json()
 query = req.get("query", "")
 files = req.get("files", [])
 inputs = req.get("inputs", {})
 user_id = req.get("user_id", current_user.id)
-if not UserCanvasService.accessible(req["id"], current_user.id):
+if not await asyncio.to_thread(UserCanvasService.accessible, req["id"], current_user.id):
 return get_json_result(
 data=False, message='Only owner of canvas authorized for this operation.',
 code=RetCode.OPERATING_ERROR)

-e, cvs = UserCanvasService.get_by_id(req["id"])
+e, cvs = await asyncio.to_thread(UserCanvasService.get_by_id, req["id"])
 if not e:
 return get_data_error_result(message="canvas not found.")

@@ -149,7 +147,7 @@ async def run():
 if cvs.canvas_category == CanvasCategory.DataFlow:
 task_id = get_uuid()
 Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
-ok, error_message = queue_dataflow(tenant_id=user_id, flow_id=req["id"], task_id=task_id, file=files[0], priority=0)
+ok, error_message = await asyncio.to_thread(queue_dataflow, user_id, req["id"], task_id, CANVAS_DEBUG_DOC_ID, files[0], 0)
 if not ok:
 return get_data_error_result(message=error_message)
 return get_json_result(data={"message_id": task_id})

@@ -186,7 +184,7 @@ async def run():
 @validate_request("id", "dsl", "component_id")
 @login_required
 async def rerun():
-req = await request_json()
+req = await get_request_json()
 doc = PipelineOperationLogService.get_documents_info(req["id"])
 if not doc:
 return get_data_error_result(message="Document not found.")

@@ -224,7 +222,7 @@ def cancel(task_id):
 @validate_request("id")
 @login_required
 async def reset():
-req = await request_json()
+req = await get_request_json()
 if not UserCanvasService.accessible(req["id"], current_user.id):
 return get_json_result(
 data=False, message='Only owner of canvas authorized for this operation.',

@@ -250,71 +248,10 @@ async def upload(canvas_id):
 return get_data_error_result(message="canvas not found.")

 user_id = cvs["user_id"]
-def structured(filename, filetype, blob, content_type):
-nonlocal user_id
-if filetype == FileType.PDF.value:
-blob = read_potential_broken_pdf(blob)
-
-location = get_uuid()
-FileService.put_blob(user_id, location, blob)
-
-return {
-"id": location,
-"name": filename,
-"size": sys.getsizeof(blob),
-"extension": filename.split(".")[-1].lower(),
-"mime_type": content_type,
-"created_by": user_id,
-"created_at": time.time(),
-"preview_url": None
-}
-
-if request.args.get("url"):
-from crawl4ai import (
-AsyncWebCrawler,
-BrowserConfig,
-CrawlerRunConfig,
-DefaultMarkdownGenerator,
-PruningContentFilter,
-CrawlResult
-)
-try:
-url = request.args.get("url")
-filename = re.sub(r"\?.*", "", url.split("/")[-1])
-async def adownload():
-browser_config = BrowserConfig(
-headless=True,
-verbose=False,
-)
-async with AsyncWebCrawler(config=browser_config) as crawler:
-crawler_config = CrawlerRunConfig(
-markdown_generator=DefaultMarkdownGenerator(
-content_filter=PruningContentFilter()
-),
-pdf=True,
-screenshot=False
-)
-result: CrawlResult = await crawler.arun(
-url=url,
-config=crawler_config
-)
-return result
-page = trio.run(adownload())
-if page.pdf:
-if filename.split(".")[-1].lower() != "pdf":
-filename += ".pdf"
-return get_json_result(data=structured(filename, "pdf", page.pdf, page.response_headers["content-type"]))
-
-return get_json_result(data=structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"], user_id))
-
-except Exception as e:
-return server_error_response(e)
-
 files = await request.files
-file = files['file']
+file = files['file'] if files and files.get("file") else None
 try:
-DocumentService.check_doc_health(user_id, file.filename)
-return get_json_result(data=structured(file.filename, filename_type(file.filename), file.read(), file.content_type))
+return get_json_result(data=FileService.upload_info(user_id, file, request.args.get("url")))
 except Exception as e:
 return server_error_response(e)

@@ -343,7 +280,7 @@ def input_form():
 @validate_request("id", "component_id", "params")
 @login_required
 async def debug():
-req = await request_json()
+req = await get_request_json()
 if not UserCanvasService.accessible(req["id"], current_user.id):
 return get_json_result(
 data=False, message='Only owner of canvas authorized for this operation.',

@@ -363,7 +300,12 @@ async def debug():
 for k in outputs.keys():
 if isinstance(outputs[k], partial):
 txt = ""
-for c in outputs[k]():
+iter_obj = outputs[k]()
+if inspect.isasyncgen(iter_obj):
+async for c in iter_obj:
+txt += c
+else:
+for c in iter_obj:
 txt += c
 outputs[k] = txt
 return get_json_result(data=outputs)

@@ -375,7 +317,7 @@ async def debug():
 @validate_request("db_type", "database", "username", "host", "port", "password")
 @login_required
 async def test_db_connect():
-req = await request_json()
+req = await get_request_json()
 try:
 if req["db_type"] in ["mysql", "mariadb"]:
 db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],

@@ -406,7 +348,15 @@ async def test_db_connect():
 f"UID={req['username']};"
 f"PWD={req['password']};"
 )
-logging.info(conn_str)
+redacted_conn_str = (
+f"DATABASE={req['database']};"
+f"HOSTNAME={req['host']};"
+f"PORT={req['port']};"
+f"PROTOCOL=TCPIP;"
+f"UID={req['username']};"
+f"PWD=****;"
+)
+logging.info(redacted_conn_str)
 conn = ibm_db.connect(conn_str, "", "")
 stmt = ibm_db.exec_immediate(conn, "SELECT 1 FROM sysibm.sysdummy1")
 ibm_db.fetch_assoc(stmt)

@@ -520,7 +470,7 @@ def list_canvas():
 @validate_request("id", "title", "permission")
 @login_required
 async def setting():
-req = await request_json()
+req = await get_request_json()
 req["user_id"] = current_user.id

 if not UserCanvasService.accessible(req["id"], current_user.id):
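Two patterns repeat through this file: blocking Peewee/service calls are pushed onto a worker thread with `asyncio.to_thread`, and component outputs are drained with an `inspect.isasyncgen` check because they may be either plain or async generators. A condensed illustration (the `make_output` callable is a placeholder):

```python
# Condensed illustration of the two patterns above; make_output() is a placeholder callable.
import asyncio
import inspect

from api.db.services.canvas_service import UserCanvasService


async def drain(canvas_id: str, user_id: str, make_output):
    # Peewee calls are synchronous; run them in a thread so the event loop keeps serving requests.
    if not await asyncio.to_thread(UserCanvasService.accessible, canvas_id, user_id):
        return None

    text = ""
    iter_obj = make_output()              # may be a sync generator or an async generator
    if inspect.isasyncgen(iter_obj):
        async for chunk in iter_obj:
            text += chunk
    else:
        for chunk in iter_obj:
            text += chunk
    return text
```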
@@ -13,25 +13,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import datetime
 import json
 import re
+import base64
 import xxhash
 from quart import request

-from api.db.services.dialog_service import meta_filter
 from api.db.services.document_service import DocumentService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
+from common.metadata_utils import apply_meta_data_filter
 from api.db.services.search_service import SearchService
 from api.db.services.user_service import UserTenantService
 from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
-    request_json
+    get_request_json
 from rag.app.qa import beAdoc, rmPrefix
 from rag.app.tag import label_question
 from rag.nlp import rag_tokenizer, search
-from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extraction
+from rag.prompts.generator import cross_languages, keyword_extraction
 from common.string_utils import remove_redundant_spaces
 from common.constants import RetCode, LLMType, ParserType, PAGERANK_FLD
 from common import settings

@@ -42,7 +43,7 @@ from api.apps import login_required, current_user
 @login_required
 @validate_request("doc_id")
 async def list_chunk():
-req = await request_json()
+req = await get_request_json()
 doc_id = req["doc_id"]
 page = int(req.get("page", 1))
 size = int(req.get("size", 30))

@@ -123,7 +124,7 @@ def get():
 @login_required
 @validate_request("doc_id", "chunk_id", "content_with_weight")
 async def set():
-req = await request_json()
+req = await get_request_json()
 d = {
 "id": req["chunk_id"],
 "content_with_weight": req["content_with_weight"]}

@@ -147,6 +148,7 @@ async def set():
 d["available_int"] = req["available_int"]

 try:
+def _set_sync():
 tenant_id = DocumentService.get_tenant_id(req["doc_id"])
 if not tenant_id:
 return get_data_error_result(message="Tenant not found!")

@@ -158,20 +160,31 @@ async def set():
 if not e:
 return get_data_error_result(message="Document not found!")

+_d = d
 if doc.parser_id == ParserType.QA:
 arr = [
 t for t in re.split(
 r"[\n\t]",
 req["content_with_weight"]) if len(t) > 1]
 q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
-d = beAdoc(d, q, a, not any(
+_d = beAdoc(d, q, a, not any(
 [rag_tokenizer.is_chinese(t) for t in q + a]))

-v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
+v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not _d.get("question_kwd") else "\n".join(_d["question_kwd"])])
 v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
-d["q_%d_vec" % len(v)] = v.tolist()
-settings.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id)
+_d["q_%d_vec" % len(v)] = v.tolist()
+settings.docStoreConn.update({"id": req["chunk_id"]}, _d, search.index_name(tenant_id), doc.kb_id)

+# update image
+image_id = req.get("img_id")
+bkt, name = image_id.split("-")
+image_base64 = req.get("image_base64", None)
+if image_base64:
+image_binary = base64.b64decode(image_base64)
+settings.STORAGE_IMPL.put(bkt, name, image_binary)
 return get_json_result(data=True)

+return await asyncio.to_thread(_set_sync)
 except Exception as e:
 return server_error_response(e)

@@ -180,8 +193,9 @@ async def set():
 @login_required
 @validate_request("chunk_ids", "available_int", "doc_id")
 async def switch():
-req = await request_json()
+req = await get_request_json()
 try:
+def _switch_sync():
 e, doc = DocumentService.get_by_id(req["doc_id"])
 if not e:
 return get_data_error_result(message="Document not found!")

@@ -192,6 +206,8 @@ async def switch():
 doc.kb_id):
 return get_data_error_result(message="Index updating failure")
 return get_json_result(data=True)

+return await asyncio.to_thread(_switch_sync)
 except Exception as e:
 return server_error_response(e)

@@ -200,8 +216,9 @@ async def switch():
 @login_required
 @validate_request("chunk_ids", "doc_id")
 async def rm():
-req = await request_json()
+req = await get_request_json()
 try:
+def _rm_sync():
 e, doc = DocumentService.get_by_id(req["doc_id"])
 if not e:
 return get_data_error_result(message="Document not found!")

@@ -216,6 +233,8 @@ async def rm():
 if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
 settings.STORAGE_IMPL.rm(doc.kb_id, cid)
 return get_json_result(data=True)

+return await asyncio.to_thread(_rm_sync)
 except Exception as e:
 return server_error_response(e)

@@ -224,7 +243,7 @@ async def rm():
 @login_required
 @validate_request("doc_id", "content_with_weight")
 async def create():
-req = await request_json()
+req = await get_request_json()
 chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest()
 d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
 "content_with_weight": req["content_with_weight"]}

@@ -245,6 +264,7 @@ async def create():
 d["tag_feas"] = req["tag_feas"]

 try:
+def _create_sync():
 e, doc = DocumentService.get_by_id(req["doc_id"])
 if not e:
 return get_data_error_result(message="Document not found!")

@@ -274,6 +294,8 @@ async def create():
 DocumentService.increment_chunk_num(
 doc.id, doc.kb_id, c, 1, 0)
 return get_json_result(data={"chunk_id": chunck_id})

+return await asyncio.to_thread(_create_sync)
 except Exception as e:
 return server_error_response(e)

@@ -282,7 +304,7 @@ async def create():
 @login_required
 @validate_request("kb_id", "question")
 async def retrieval_test():
-req = await request_json()
+req = await get_request_json()
 page = int(req.get("page", 1))
 size = int(req.get("size", 30))
 question = req["question"]

@@ -297,25 +319,29 @@ async def retrieval_test():
 use_kg = req.get("use_kg", False)
 top = int(req.get("top_k", 1024))
 langs = req.get("cross_languages", [])
+user_id = current_user.id
+
+async def _retrieval():
+local_doc_ids = list(doc_ids) if doc_ids else []
 tenant_ids = []
+
+meta_data_filter = {}
+chat_mdl = None
 if req.get("search_id", ""):
 search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
 meta_data_filter = search_config.get("meta_data_filter", {})
-metas = DocumentService.get_meta_by_kbs(kb_ids)
-if meta_data_filter.get("method") == "auto":
-chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
-filters: dict = gen_meta_filter(chat_mdl, metas, question)
-doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
-if not doc_ids:
-doc_ids = None
-elif meta_data_filter.get("method") == "manual":
-doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
-if meta_data_filter["manual"] and not doc_ids:
-doc_ids = ["-999"]
-
-try:
-tenants = UserTenantService.query(user_id=current_user.id)
+if meta_data_filter.get("method") in ["auto", "semi_auto"]:
+chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
+else:
+meta_data_filter = req.get("meta_data_filter") or {}
+if meta_data_filter.get("method") in ["auto", "semi_auto"]:
+chat_mdl = LLMBundle(user_id, LLMType.CHAT)
+
+if meta_data_filter:
+metas = DocumentService.get_meta_by_kbs(kb_ids)
+local_doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, local_doc_ids)
+
+tenants = UserTenantService.query(user_id=user_id)
 for kb_id in kb_ids:
 for tenant in tenants:
 if KnowledgebaseService.query(

@@ -324,15 +350,16 @@ async def retrieval_test():
 break
 else:
 return get_json_result(
-data=False, message='Only owner of knowledgebase authorized for this operation.',
+data=False, message='Only owner of dataset authorized for this operation.',
 code=RetCode.OPERATING_ERROR)

 e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
 if not e:
 return get_data_error_result(message="Knowledgebase not found!")

+_question = question
 if langs:
-question = cross_languages(kb.tenant_id, None, question, langs)
+_question = await cross_languages(kb.tenant_id, None, _question, langs)

 embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)

@@ -342,31 +369,35 @@ async def retrieval_test():

 if req.get("keyword", False):
 chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
-question += keyword_extraction(chat_mdl, question)
+_question += await keyword_extraction(chat_mdl, _question)

-labels = label_question(question, [kb])
-ranks = settings.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
+labels = label_question(_question, [kb])
+ranks = settings.retriever.retrieval(_question, embd_mdl, tenant_ids, kb_ids, page, size,
 float(req.get("similarity_threshold", 0.0)),
 float(req.get("vector_similarity_weight", 0.3)),
 top,
-doc_ids, rerank_mdl=rerank_mdl,
+local_doc_ids, rerank_mdl=rerank_mdl,
 highlight=req.get("highlight", False),
 rank_feature=labels
 )
 if use_kg:
-ck = settings.kg_retriever.retrieval(question,
+ck = settings.kg_retriever.retrieval(_question,
 tenant_ids,
 kb_ids,
 embd_mdl,
 LLMBundle(kb.tenant_id, LLMType.CHAT))
 if ck["content_with_weight"]:
 ranks["chunks"].insert(0, ck)
+ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], tenant_ids)

 for c in ranks["chunks"]:
 c.pop("vector", None)
 ranks["labels"] = labels

 return get_json_result(data=ranks)

+try:
+return await _retrieval()
 except Exception as e:
 if str(e).find("not_found") > 0:
 return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
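One functional addition in `set()` beyond the threading change: when the request carries an `img_id` (stored as `<bucket>-<object>`), an optional `image_base64` payload is decoded and written back through the storage layer, so editing a chunk can also refresh its stored image. Reduced to a helper, the step looks like this (the guards are added here for illustration; the diff performs the split unconditionally):

```python
# Trimmed sketch of the optional image update in set(); the request keys follow the diff above.
import base64

from common import settings


def update_chunk_image(req: dict) -> None:
    image_id = req.get("img_id")
    image_base64 = req.get("image_base64")
    if not image_id or not image_base64:
        return
    bkt, name = image_id.split("-")                 # "<bucket>-<object name>"
    settings.STORAGE_IMPL.put(bkt, name, base64.b64decode(image_base64))
```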
@ -26,19 +26,20 @@ from google_auth_oauthlib.flow import Flow
|
|||||||
|
|
||||||
from api.db import InputType
|
from api.db import InputType
|
||||||
from api.db.services.connector_service import ConnectorService, SyncLogsService
|
from api.db.services.connector_service import ConnectorService, SyncLogsService
|
||||||
from api.utils.api_utils import get_data_error_result, get_json_result, validate_request
|
from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, validate_request
|
||||||
from common.constants import RetCode, TaskStatus
|
from common.constants import RetCode, TaskStatus
|
||||||
from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, DocumentSource
|
from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, GMAIL_WEB_OAUTH_REDIRECT_URI, BOX_WEB_OAUTH_REDIRECT_URI, DocumentSource
|
||||||
from common.data_source.google_util.constant import GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES
|
from common.data_source.google_util.constant import WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES
|
||||||
from common.misc_utils import get_uuid
|
from common.misc_utils import get_uuid
|
||||||
from rag.utils.redis_conn import REDIS_CONN
|
from rag.utils.redis_conn import REDIS_CONN
|
||||||
from api.apps import login_required, current_user
|
from api.apps import login_required, current_user
|
||||||
|
from box_sdk_gen import BoxOAuth, OAuthConfig, GetAuthorizeUrlOptions
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/set", methods=["POST"]) # noqa: F821
|
@manager.route("/set", methods=["POST"]) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
async def set_connector():
|
async def set_connector():
|
||||||
req = await request.json
|
req = await get_request_json()
|
||||||
if req.get("id"):
|
if req.get("id"):
|
||||||
conn = {fld: req[fld] for fld in ["prune_freq", "refresh_freq", "config", "timeout_secs"] if fld in req}
|
conn = {fld: req[fld] for fld in ["prune_freq", "refresh_freq", "config", "timeout_secs"] if fld in req}
|
||||||
ConnectorService.update_by_id(req["id"], conn)
|
ConnectorService.update_by_id(req["id"], conn)
|
||||||
@@ -90,7 +91,7 @@ def list_logs(connector_id):
 @manager.route("/<connector_id>/resume", methods=["PUT"]) # noqa: F821
 @login_required
 async def resume(connector_id):
-req = await request.json
+req = await get_request_json()
 if req.get("resume"):
 ConnectorService.resume(connector_id, TaskStatus.SCHEDULE)
 else:
@@ -102,7 +103,7 @@ async def resume(connector_id):
 @login_required
 @validate_request("kb_id")
 async def rebuild(connector_id):
-req = await request.json
+req = await get_request_json()
 err = ConnectorService.rebuild(req["kb_id"], connector_id, current_user.id)
 if err:
 return get_json_result(data=False, message=err, code=RetCode.SERVER_ERROR)
@@ -117,17 +118,27 @@ def rm_connector(connector_id):
 return get_json_result(data=True)


-GOOGLE_WEB_FLOW_STATE_PREFIX = "google_drive_web_flow_state"
-GOOGLE_WEB_FLOW_RESULT_PREFIX = "google_drive_web_flow_result"
 WEB_FLOW_TTL_SECS = 15 * 60


-def _web_state_cache_key(flow_id: str) -> str:
-return f"{GOOGLE_WEB_FLOW_STATE_PREFIX}:{flow_id}"
+def _web_state_cache_key(flow_id: str, source_type: str | None = None) -> str:
+"""Return Redis key for web OAuth state.
+
+The default prefix keeps backward compatibility for Google Drive.
+When source_type == "gmail", a different prefix is used so that
+Drive/Gmail flows don't clash in Redis.
+"""
+prefix = f"{source_type}_web_flow_state"
+return f"{prefix}:{flow_id}"


-def _web_result_cache_key(flow_id: str) -> str:
-return f"{GOOGLE_WEB_FLOW_RESULT_PREFIX}:{flow_id}"
+def _web_result_cache_key(flow_id: str, source_type: str | None = None) -> str:
+"""Return Redis key for web OAuth result.
+
+Mirrors _web_state_cache_key logic for result storage.
+"""
+prefix = f"{source_type}_web_flow_result"
+return f"{prefix}:{flow_id}"


 def _load_credentials(payload: str | dict[str, Any]) -> dict[str, Any]:
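A tiny self-contained illustration of the per-source key namespacing the two helpers above rely on, assuming source_type values such as "google-drive", "gmail" and "box"; it only shows how state and result entries for different flows stay apart in Redis.

# Illustration of the prefixing scheme; not the module code above.
def web_flow_key(kind: str, flow_id: str, source_type: str) -> str:
    # kind is "state" or "result"; each source gets its own prefix so flows never collide
    return f"{source_type}_web_flow_{kind}:{flow_id}"

assert web_flow_key("state", "abc", "gmail") == "gmail_web_flow_state:abc"
assert web_flow_key("result", "abc", "box") == "box_web_flow_result:abc"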
@@ -146,19 +157,24 @@ def _get_web_client_config(credentials: dict[str, Any]) -> dict[str, Any]:
 return {"web": web_section}


-async def _render_web_oauth_popup(flow_id: str, success: bool, message: str):
+async def _render_web_oauth_popup(flow_id: str, success: bool, message: str, source="drive"):
 status = "success" if success else "error"
 auto_close = "window.close();" if success else ""
 escaped_message = escape(message)
+# Drive: ragflow-google-drive-oauth
+# Gmail: ragflow-gmail-oauth
+payload_type = f"ragflow-{source}-oauth"
 payload_json = json.dumps(
 {
-"type": "ragflow-google-drive-oauth",
+"type": payload_type,
 "status": status,
 "flowId": flow_id or "",
 "message": message,
 }
 )
-html = GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE.format(
+# TODO(google-oauth): title/heading/message may need to reflect drive/gmail based on cached type
+html = WEB_OAUTH_POPUP_TEMPLATE.format(
+title=f"Google {source.capitalize()} Authorization",
 heading="Authorization complete" if success else "Authorization failed",
 message=escaped_message,
 payload_json=payload_json,
@@ -169,20 +185,33 @@ async def _render_web_oauth_popup(flow_id: str, success: bool, message: str):
 return response


-@manager.route("/google-drive/oauth/web/start", methods=["POST"]) # noqa: F821
+@manager.route("/google/oauth/web/start", methods=["POST"]) # noqa: F821
 @login_required
 @validate_request("credentials")
-async def start_google_drive_web_oauth():
+async def start_google_web_oauth():
-if not GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI:
+source = request.args.get("type", "google-drive")
+if source not in ("google-drive", "gmail"):
+return get_json_result(code=RetCode.ARGUMENT_ERROR, message="Invalid Google OAuth type.")
+
+if source == "gmail":
+redirect_uri = GMAIL_WEB_OAUTH_REDIRECT_URI
+scopes = GOOGLE_SCOPES[DocumentSource.GMAIL]
+else:
+redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI
+scopes = GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE]
+
+if not redirect_uri:
 return get_json_result(
 code=RetCode.SERVER_ERROR,
-message="Google Drive OAuth redirect URI is not configured on the server.",
+message="Google OAuth redirect URI is not configured on the server.",
 )

-req = await request.json or {}
+req = await get_request_json()
 raw_credentials = req.get("credentials", "")

 try:
 credentials = _load_credentials(raw_credentials)
+print(credentials)
 except ValueError as exc:
 return get_json_result(code=RetCode.ARGUMENT_ERROR, message=str(exc))

@@ -199,8 +228,8 @@ async def start_google_drive_web_oauth():

 flow_id = str(uuid.uuid4())
 try:
-flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE])
+flow = Flow.from_client_config(client_config, scopes=scopes)
-flow.redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI
+flow.redirect_uri = redirect_uri
 authorization_url, _ = flow.authorization_url(
 access_type="offline",
 include_granted_scopes="true",
@@ -219,7 +248,7 @@ async def start_google_drive_web_oauth():
 "client_config": client_config,
 "created_at": int(time.time()),
 }
-REDIS_CONN.set_obj(_web_state_cache_key(flow_id), cache_payload, WEB_FLOW_TTL_SECS)
+REDIS_CONN.set_obj(_web_state_cache_key(flow_id, source), cache_payload, WEB_FLOW_TTL_SECS)

 return get_json_result(
 data={
@@ -230,60 +259,115 @@ async def start_google_drive_web_oauth():
 )


-@manager.route("/google-drive/oauth/web/callback", methods=["GET"]) # noqa: F821
+@manager.route("/gmail/oauth/web/callback", methods=["GET"]) # noqa: F821
-async def google_drive_web_oauth_callback():
+async def google_gmail_web_oauth_callback():
 state_id = request.args.get("state")
 error = request.args.get("error")
+source = "gmail"
+
 error_description = request.args.get("error_description") or error

 if not state_id:
-return await _render_web_oauth_popup("", False, "Missing OAuth state parameter.")
+return await _render_web_oauth_popup("", False, "Missing OAuth state parameter.", source)

-state_cache = REDIS_CONN.get(_web_state_cache_key(state_id))
+state_cache = REDIS_CONN.get(_web_state_cache_key(state_id, source))
 if not state_cache:
-return await _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.")
+return await _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.", source)

 state_obj = json.loads(state_cache)
 client_config = state_obj.get("client_config")
 if not client_config:
-REDIS_CONN.delete(_web_state_cache_key(state_id))
+REDIS_CONN.delete(_web_state_cache_key(state_id, source))
-return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.")
+return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.", source)

 if error:
-REDIS_CONN.delete(_web_state_cache_key(state_id))
+REDIS_CONN.delete(_web_state_cache_key(state_id, source))
-return await _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.")
+return await _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.", source)

 code = request.args.get("code")
 if not code:
-return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.")
+return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.", source)

 try:
-flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE])
+# TODO(google-oauth): branch scopes/redirect_uri based on source_type (drive vs gmail)
-flow.redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI
+flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GMAIL])
+flow.redirect_uri = GMAIL_WEB_OAUTH_REDIRECT_URI
 flow.fetch_token(code=code)
 except Exception as exc: # pragma: no cover - defensive
 logging.exception("Failed to exchange Google OAuth code: %s", exc)
-REDIS_CONN.delete(_web_state_cache_key(state_id))
+REDIS_CONN.delete(_web_state_cache_key(state_id, source))
-return await _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.")
+return await _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.", source)

 creds_json = flow.credentials.to_json()
 result_payload = {
 "user_id": state_obj.get("user_id"),
 "credentials": creds_json,
 }
-REDIS_CONN.set_obj(_web_result_cache_key(state_id), result_payload, WEB_FLOW_TTL_SECS)
+REDIS_CONN.set_obj(_web_result_cache_key(state_id, source), result_payload, WEB_FLOW_TTL_SECS)
-REDIS_CONN.delete(_web_state_cache_key(state_id))
+REDIS_CONN.delete(_web_state_cache_key(state_id, source))

-return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.")
+return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.", source)


-@manager.route("/google-drive/oauth/web/result", methods=["POST"]) # noqa: F821
+@manager.route("/google-drive/oauth/web/callback", methods=["GET"]) # noqa: F821
+async def google_drive_web_oauth_callback():
+state_id = request.args.get("state")
+error = request.args.get("error")
+source = "google-drive"
+
+error_description = request.args.get("error_description") or error
+
+if not state_id:
+return await _render_web_oauth_popup("", False, "Missing OAuth state parameter.", source)
+
+state_cache = REDIS_CONN.get(_web_state_cache_key(state_id, source))
+if not state_cache:
+return await _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.", source)
+
+state_obj = json.loads(state_cache)
+client_config = state_obj.get("client_config")
+if not client_config:
+REDIS_CONN.delete(_web_state_cache_key(state_id, source))
+return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.", source)
+
+if error:
+REDIS_CONN.delete(_web_state_cache_key(state_id, source))
+return await _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.", source)
+
+code = request.args.get("code")
+if not code:
+return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.", source)
+
+try:
+# TODO(google-oauth): branch scopes/redirect_uri based on source_type (drive vs gmail)
+flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE])
+flow.redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI
+flow.fetch_token(code=code)
+except Exception as exc: # pragma: no cover - defensive
+logging.exception("Failed to exchange Google OAuth code: %s", exc)
+REDIS_CONN.delete(_web_state_cache_key(state_id, source))
+return await _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.", source)
+
+creds_json = flow.credentials.to_json()
+result_payload = {
+"user_id": state_obj.get("user_id"),
+"credentials": creds_json,
+}
+REDIS_CONN.set_obj(_web_result_cache_key(state_id, source), result_payload, WEB_FLOW_TTL_SECS)
+REDIS_CONN.delete(_web_state_cache_key(state_id, source))
+
+return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.", source)
+
+@manager.route("/google/oauth/web/result", methods=["POST"]) # noqa: F821
 @login_required
 @validate_request("flow_id")
-async def poll_google_drive_web_result():
+async def poll_google_web_result():
 req = await request.json or {}
+source = request.args.get("type")
+if source not in ("google-drive", "gmail"):
+return get_json_result(code=RetCode.ARGUMENT_ERROR, message="Invalid Google OAuth type.")
 flow_id = req.get("flow_id")
-cache_raw = REDIS_CONN.get(_web_result_cache_key(flow_id))
+cache_raw = REDIS_CONN.get(_web_result_cache_key(flow_id, source))
 if not cache_raw:
 return get_json_result(code=RetCode.RUNNING, message="Authorization is still pending.")

@@ -291,5 +375,109 @@ async def poll_google_drive_web_result():
 if result.get("user_id") != current_user.id:
 return get_json_result(code=RetCode.PERMISSION_ERROR, message="You are not allowed to access this authorization result.")

-REDIS_CONN.delete(_web_result_cache_key(flow_id))
+REDIS_CONN.delete(_web_result_cache_key(flow_id, source))
 return get_json_result(data={"credentials": result.get("credentials")})
+
+@manager.route("/box/oauth/web/start", methods=["POST"]) # noqa: F821
+@login_required
+async def start_box_web_oauth():
+req = await get_request_json()
+
+client_id = req.get("client_id")
+client_secret = req.get("client_secret")
+redirect_uri = req.get("redirect_uri", BOX_WEB_OAUTH_REDIRECT_URI)
+
+if not client_id or not client_secret:
+return get_json_result(code=RetCode.ARGUMENT_ERROR, message="Box client_id and client_secret are required.")
+
+flow_id = str(uuid.uuid4())
+
+box_auth = BoxOAuth(
+OAuthConfig(
+client_id=client_id,
+client_secret=client_secret,
+)
+)
+
+auth_url = box_auth.get_authorize_url(
+options=GetAuthorizeUrlOptions(
+redirect_uri=redirect_uri,
+state=flow_id,
+)
+)
+
+cache_payload = {
+"user_id": current_user.id,
+"auth_url": auth_url,
+"client_id": client_id,
+"client_secret": client_secret,
+"created_at": int(time.time()),
+}
+REDIS_CONN.set_obj(_web_state_cache_key(flow_id, "box"), cache_payload, WEB_FLOW_TTL_SECS)
+return get_json_result(
+data = {
+"flow_id": flow_id,
+"authorization_url": auth_url,
+"expires_in": WEB_FLOW_TTL_SECS,}
+)
+
+@manager.route("/box/oauth/web/callback", methods=["GET"]) # noqa: F821
+async def box_web_oauth_callback():
+flow_id = request.args.get("state")
+if not flow_id:
+return await _render_web_oauth_popup("", False, "Missing OAuth parameters.", "box")
+
+code = request.args.get("code")
+if not code:
+return await _render_web_oauth_popup(flow_id, False, "Missing authorization code from Box.", "box")
+
+cache_payload = json.loads(REDIS_CONN.get(_web_state_cache_key(flow_id, "box")))
+if not cache_payload:
+return get_json_result(code=RetCode.ARGUMENT_ERROR, message="Box OAuth session expired or invalid.")
+
+error = request.args.get("error")
+error_description = request.args.get("error_description") or error
+if error:
+REDIS_CONN.delete(_web_state_cache_key(flow_id, "box"))
+return await _render_web_oauth_popup(flow_id, False, error_description or "Authorization failed.", "box")
+
+auth = BoxOAuth(
+OAuthConfig(
+client_id=cache_payload.get("client_id"),
+client_secret=cache_payload.get("client_secret"),
+)
+)
+
+auth.get_tokens_authorization_code_grant(code)
+token = auth.retrieve_token()
+result_payload = {
+"user_id": cache_payload.get("user_id"),
+"client_id": cache_payload.get("client_id"),
+"client_secret": cache_payload.get("client_secret"),
+"access_token": token.access_token,
+"refresh_token": token.refresh_token,
+}
+
+REDIS_CONN.set_obj(_web_result_cache_key(flow_id, "box"), result_payload, WEB_FLOW_TTL_SECS)
+REDIS_CONN.delete(_web_state_cache_key(flow_id, "box"))
+
+return await _render_web_oauth_popup(flow_id, True, "Authorization completed successfully.", "box")
+
+@manager.route("/box/oauth/web/result", methods=["POST"]) # noqa: F821
+@login_required
+@validate_request("flow_id")
+async def poll_box_web_result():
+req = await get_request_json()
+flow_id = req.get("flow_id")
+
+cache_blob = REDIS_CONN.get(_web_result_cache_key(flow_id, "box"))
+if not cache_blob:
+return get_json_result(code=RetCode.RUNNING, message="Authorization is still pending.")
+
+cache_raw = json.loads(cache_blob)
+if cache_raw.get("user_id") != current_user.id:
+return get_json_result(code=RetCode.PERMISSION_ERROR, message="You are not allowed to access this authorization result.")
+
+REDIS_CONN.delete(_web_result_cache_key(flow_id, "box"))
+
+return get_json_result(data={"credentials": cache_raw})
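As a rough usage sketch for the start/poll pair added above, the snippet below starts a Box web OAuth flow and polls /box/oauth/web/result until the callback has stored the tokens. The base URL, /v1/connector prefix and cookie-based login session are assumptions; in practice the frontend drives this from the OAuth popup.

# Hedged client-side sketch; endpoint prefix and auth handling are assumptions.
import time
import requests

BASE = "http://localhost:9380/v1/connector"  # assumed API prefix
session = requests.Session()                 # assumed to already carry a login cookie

start = session.post(f"{BASE}/box/oauth/web/start",
                     json={"client_id": "YOUR_ID", "client_secret": "YOUR_SECRET"}).json()
flow_id = start["data"]["flow_id"]
print("Open in a popup:", start["data"]["authorization_url"])

# An empty "data" field (RetCode.RUNNING) means the user has not finished the popup flow yet.
while True:
    res = session.post(f"{BASE}/box/oauth/web/result", json={"flow_id": flow_id}).json()
    if res.get("data"):
        print("credentials:", res["data"]["credentials"])
        break
    time.sleep(2)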
@@ -14,19 +14,21 @@
 # limitations under the License.
 #
 import json
+import os
 import re
 import logging
 from copy import deepcopy
+import tempfile
 from quart import Response, request
 from api.apps import current_user, login_required
 from api.db.db_models import APIToken
 from api.db.services.conversation_service import ConversationService, structure_answer
-from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap
+from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap
 from api.db.services.llm_service import LLMBundle
 from api.db.services.search_service import SearchService
 from api.db.services.tenant_llm_service import TenantLLMService
 from api.db.services.user_service import TenantService, UserTenantService
-from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
+from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
 from rag.prompts.template import load_prompt
 from rag.prompts.generator import chunks_format
 from common.constants import RetCode, LLMType
@@ -35,7 +37,7 @@ from common.constants import RetCode, LLMType
 @manager.route("/set", methods=["POST"]) # noqa: F821
 @login_required
 async def set_conversation():
-req = await request.json
+req = await get_request_json()
 conv_id = req.get("conversation_id")
 is_new = req.get("is_new")
 name = req.get("name", "New conversation")
@@ -78,7 +80,7 @@ async def set_conversation():

 @manager.route("/get", methods=["GET"]) # noqa: F821
 @login_required
-def get():
+async def get():
 conv_id = request.args["conversation_id"]
 try:
 e, conv = ConversationService.get_by_id(conv_id)
@@ -129,7 +131,7 @@ def getsse(dialog_id):
 @manager.route("/rm", methods=["POST"]) # noqa: F821
 @login_required
 async def rm():
-req = await request.json
+req = await get_request_json()
 conv_ids = req["conversation_ids"]
 try:
 for cid in conv_ids:
@@ -150,7 +152,7 @@ async def rm():

 @manager.route("/list", methods=["GET"]) # noqa: F821
 @login_required
-def list_conversation():
+async def list_conversation():
 dialog_id = request.args["dialog_id"]
 try:
 if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
@@ -167,7 +169,7 @@ def list_conversation():
 @login_required
 @validate_request("conversation_id", "messages")
 async def completion():
-req = await request.json
+req = await get_request_json()
 msg = []
 for m in req["messages"]:
 if m["role"] == "system":
@@ -216,10 +218,10 @@ async def completion():
 dia.llm_setting = chat_model_config

 is_embedded = bool(chat_model_id)
-def stream():
+async def stream():
 nonlocal dia, msg, req, conv
 try:
-for ans in chat(dia, msg, True, **req):
+async for ans in async_chat(dia, msg, True, **req):
 ans = structure_answer(conv, ans, message_id, conv.id)
 yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
 if not is_embedded:
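The async stream() above adapts an async generator into SSE frames; a stripped-down sketch of that pattern, with a stand-in for async_chat (whose real signature is only assumed here), looks like this:

# Generic async-generator-to-SSE sketch; fake_async_chat merely stands in for async_chat.
import json
from typing import AsyncIterator

async def fake_async_chat(question: str) -> AsyncIterator[dict]:
    # Stand-in for an async answer producer that yields partial results.
    for piece in ("Hel", "Hello", "Hello!"):
        yield {"answer": piece}

async def sse_stream(question: str) -> AsyncIterator[str]:
    # Each yielded string is one "data: ...\n\n" frame, mirroring the handler above.
    async for ans in fake_async_chat(question):
        yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
    yield "data:" + json.dumps({"code": 0, "data": True}, ensure_ascii=False) + "\n\n"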
@@ -239,7 +241,7 @@ async def completion():

 else:
 answer = None
-for ans in chat(dia, msg, **req):
+async for ans in async_chat(dia, msg, **req):
 answer = structure_answer(conv, ans, message_id, conv.id)
 if not is_embedded:
 ConversationService.update_by_id(conv.id, conv.to_dict())
@@ -248,11 +250,69 @@ async def completion():
 except Exception as e:
 return server_error_response(e)

+@manager.route("/sequence2txt", methods=["POST"]) # noqa: F821
+@login_required
+async def sequence2txt():
+req = await request.form
+stream_mode = req.get("stream", "false").lower() == "true"
+files = await request.files
+if "file" not in files:
+return get_data_error_result(message="Missing 'file' in multipart form-data")
+
+uploaded = files["file"]
+
+ALLOWED_EXTS = {
+".wav", ".mp3", ".m4a", ".aac",
+".flac", ".ogg", ".webm",
+".opus", ".wma"
+}
+
+filename = uploaded.filename or ""
+suffix = os.path.splitext(filename)[-1].lower()
+if suffix not in ALLOWED_EXTS:
+return get_data_error_result(message=
+f"Unsupported audio format: {suffix}. "
+f"Allowed: {', '.join(sorted(ALLOWED_EXTS))}"
+)
+fd, temp_audio_path = tempfile.mkstemp(suffix=suffix)
+os.close(fd)
+await uploaded.save(temp_audio_path)
+
+tenants = TenantService.get_info_by(current_user.id)
+if not tenants:
+return get_data_error_result(message="Tenant not found!")
+
+asr_id = tenants[0]["asr_id"]
+if not asr_id:
+return get_data_error_result(message="No default ASR model is set")
+
+asr_mdl=LLMBundle(tenants[0]["tenant_id"], LLMType.SPEECH2TEXT, asr_id)
+if not stream_mode:
+text = asr_mdl.transcription(temp_audio_path)
+try:
+os.remove(temp_audio_path)
+except Exception as e:
+logging.error(f"Failed to remove temp audio file: {str(e)}")
+return get_json_result(data={"text": text})
+async def event_stream():
+try:
+for evt in asr_mdl.stream_transcription(temp_audio_path):
+yield f"data: {json.dumps(evt, ensure_ascii=False)}\n\n"
+except Exception as e:
+err = {"event": "error", "text": str(e)}
+yield f"data: {json.dumps(err, ensure_ascii=False)}\n\n"
+finally:
+try:
+os.remove(temp_audio_path)
+except Exception as e:
+logging.error(f"Failed to remove temp audio file: {str(e)}")
+
+return Response(event_stream(), content_type="text/event-stream")
+
 @manager.route("/tts", methods=["POST"]) # noqa: F821
 @login_required
 async def tts():
-req = await request.json
+req = await get_request_json()
 text = req["text"]

 tenants = TenantService.get_info_by(current_user.id)
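A short sketch of how the new /sequence2txt endpoint could be exercised with multipart form data, in one-shot and SSE-streaming modes; the /v1/conversation prefix and the authenticated session are assumptions.

# Hedged client sketch; URL prefix and login handling are assumptions.
import requests

BASE = "http://localhost:9380/v1/conversation"
session = requests.Session()  # assumed to carry a login cookie

# One-shot transcription: response body carries {"data": {"text": "..."}}
with open("meeting.wav", "rb") as f:
    r = session.post(f"{BASE}/sequence2txt",
                     files={"file": ("meeting.wav", f, "audio/wav")},
                     data={"stream": "false"})
print(r.json())

# Streaming transcription: the server pushes "data: {...}" SSE events as they are produced.
with open("meeting.wav", "rb") as f, \
     session.post(f"{BASE}/sequence2txt",
                  files={"file": ("meeting.wav", f, "audio/wav")},
                  data={"stream": "true"}, stream=True) as resp:
    for line in resp.iter_lines():
        if line.startswith(b"data:"):
            print(line.decode())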
@@ -285,7 +345,7 @@ async def tts():
 @login_required
 @validate_request("conversation_id", "message_id")
 async def delete_msg():
-req = await request.json
+req = await get_request_json()
 e, conv = ConversationService.get_by_id(req["conversation_id"])
 if not e:
 return get_data_error_result(message="Conversation not found!")
@@ -308,7 +368,7 @@ async def delete_msg():
 @login_required
 @validate_request("conversation_id", "message_id")
 async def thumbup():
-req = await request.json
+req = await get_request_json()
 e, conv = ConversationService.get_by_id(req["conversation_id"])
 if not e:
 return get_data_error_result(message="Conversation not found!")
@@ -335,7 +395,7 @@ async def thumbup():
 @login_required
 @validate_request("question", "kb_ids")
 async def ask_about():
-req = await request.json
+req = await get_request_json()
 uid = current_user.id

 search_id = req.get("search_id", "")
@@ -346,10 +406,10 @@ async def ask_about():
 if search_app:
 search_config = search_app.get("search_config", {})

-def stream():
+async def stream():
 nonlocal req, uid
 try:
-for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
+async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config):
 yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
 except Exception as e:
 yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
@@ -367,7 +427,7 @@ async def ask_about():
 @login_required
 @validate_request("question", "kb_ids")
 async def mindmap():
-req = await request.json
+req = await get_request_json()
 search_id = req.get("search_id", "")
 search_app = SearchService.get_detail(search_id) if search_id else {}
 search_config = search_app.get("search_config", {}) if search_app else {}
@@ -375,7 +435,7 @@ async def mindmap():
 kb_ids.extend(req["kb_ids"])
 kb_ids = list(set(kb_ids))

-mind_map = gen_mindmap(req["question"], kb_ids, search_app.get("tenant_id", current_user.id), search_config)
+mind_map = await gen_mindmap(req["question"], kb_ids, search_app.get("tenant_id", current_user.id), search_config)
 if "error" in mind_map:
 return server_error_response(Exception(mind_map["error"]))
 return get_json_result(data=mind_map)
@@ -385,7 +445,7 @@ async def mindmap():
 @login_required
 @validate_request("question")
 async def related_questions():
-req = await request.json
+req = await get_request_json()

 search_id = req.get("search_id", "")
 search_config = {}
@@ -402,7 +462,7 @@ async def related_questions():
 if "parameter" in gen_conf:
 del gen_conf["parameter"]
 prompt = load_prompt("related_question")
-ans = chat_mdl.chat(
+ans = await chat_mdl.async_chat(
 prompt,
 [
 {
@@ -21,10 +21,9 @@ from common.constants import StatusEnum
 from api.db.services.tenant_llm_service import TenantLLMService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.user_service import TenantService, UserTenantService
-from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
+from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
 from common.misc_utils import get_uuid
 from common.constants import RetCode
-from api.utils.api_utils import get_json_result
 from api.apps import login_required, current_user


@@ -32,7 +31,7 @@ from api.apps import login_required, current_user
 @validate_request("prompt_config")
 @login_required
 async def set_dialog():
-req = await request.json
+req = await get_request_json()
 dialog_id = req.get("dialog_id", "")
 is_create = not dialog_id
 name = req.get("name", "New Dialog")
@@ -66,7 +65,7 @@ async def set_dialog():

 if not is_create:
 if not req.get("kb_ids", []) and not prompt_config.get("tavily_api_key") and "{knowledge}" in prompt_config['system']:
-return get_data_error_result(message="Please remove `{knowledge}` in system prompt since no knowledge base / Tavily used here.")
+return get_data_error_result(message="Please remove `{knowledge}` in system prompt since no dataset / Tavily used here.")

 for p in prompt_config["parameters"]:
 if p["optional"]:
@@ -181,7 +180,7 @@ async def list_dialogs_next():
 else:
 desc = True

-req = await request.get_json()
+req = await get_request_json()
 owner_ids = req.get("owner_ids", [])
 try:
 if not owner_ids:
@@ -209,7 +208,7 @@ async def list_dialogs_next():
 @login_required
 @validate_request("dialog_ids")
 async def rm():
-req = await request.json
+req = await get_request_json()
 dialog_list=[]
 tenants = UserTenantService.query(user_id=current_user.id)
 try:
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 #
+import asyncio
 import json
 import os.path
 import pathlib
@@ -26,6 +27,7 @@ from api.db import VALID_FILE_TYPES, FileType
 from api.db.db_models import Task
 from api.db.services import duplicate_name
 from api.db.services.document_service import DocumentService, doc_upload_and_parse
+from common.metadata_utils import meta_filter, convert_conditions
 from api.db.services.file2document_service import File2DocumentService
 from api.db.services.file_service import FileService
 from api.db.services.knowledgebase_service import KnowledgebaseService
@@ -36,7 +38,7 @@ from api.utils.api_utils import (
 get_data_error_result,
 get_json_result,
 server_error_response,
-validate_request, request_json,
+validate_request, get_request_json,
 )
 from api.utils.file_utils import filename_type, thumbnail
 from common.file_utils import get_project_base_directory
@@ -68,11 +70,11 @@ async def upload():

 e, kb = KnowledgebaseService.get_by_id(kb_id)
 if not e:
-raise LookupError("Can't find this knowledgebase!")
+raise LookupError("Can't find this dataset!")
 if not check_kb_team_permission(kb, current_user.id):
 return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)

-err, files = FileService.upload_document(kb, file_objs, current_user.id)
+err, files = await asyncio.to_thread(FileService.upload_document, kb, file_objs, current_user.id)
 if err:
 return get_json_result(data=files, message="\n".join(err), code=RetCode.SERVER_ERROR)

@@ -97,7 +99,7 @@ async def web_crawl():
 return get_json_result(data=False, message="The URL format is invalid", code=RetCode.ARGUMENT_ERROR)
 e, kb = KnowledgebaseService.get_by_id(kb_id)
 if not e:
-raise LookupError("Can't find this knowledgebase!")
+raise LookupError("Can't find this dataset!")
 if check_kb_team_permission(kb, current_user.id):
 return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)

@@ -153,7 +155,7 @@ async def web_crawl():
 @login_required
 @validate_request("name", "kb_id")
 async def create():
-req = await request_json()
+req = await get_request_json()
 kb_id = req["kb_id"]
 if not kb_id:
 return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
@@ -167,10 +169,10 @@ async def create():
 try:
 e, kb = KnowledgebaseService.get_by_id(kb_id)
 if not e:
-return get_data_error_result(message="Can't find this knowledgebase!")
+return get_data_error_result(message="Can't find this dataset!")

 if DocumentService.query(name=req["name"], kb_id=kb_id):
-return get_data_error_result(message="Duplicated document name in the same knowledgebase.")
+return get_data_error_result(message="Duplicated document name in the same dataset.")

 kb_root_folder = FileService.get_kb_folder(kb.tenant_id)
 if not kb_root_folder:
@@ -217,7 +219,7 @@ async def list_docs():
 if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
 break
 else:
-return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=RetCode.OPERATING_ERROR)
+return get_json_result(data=False, message="Only owner of dataset authorized for this operation.", code=RetCode.OPERATING_ERROR)
 keywords = request.args.get("keywords", "")

 page_number = int(request.args.get("page", 0))
@@ -230,7 +232,7 @@ async def list_docs():
 create_time_from = int(request.args.get("create_time_from", 0))
 create_time_to = int(request.args.get("create_time_to", 0))

-req = await request.get_json()
+req = await get_request_json()

 run_status = req.get("run_status", [])
 if run_status:
@@ -245,9 +247,19 @@ async def list_docs():
 return get_data_error_result(message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}")

 suffix = req.get("suffix", [])
+metadata_condition = req.get("metadata_condition", {}) or {}
+if metadata_condition and not isinstance(metadata_condition, dict):
+return get_data_error_result(message="metadata_condition must be an object.")
+
+doc_ids_filter = None
+if metadata_condition:
+metas = DocumentService.get_flatted_meta_by_kbs([kb_id])
+doc_ids_filter = meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))
+if metadata_condition.get("conditions") and not doc_ids_filter:
+return get_json_result(data={"total": 0, "docs": []})
+
 try:
-docs, tol = DocumentService.get_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types, suffix)
+docs, tol = DocumentService.get_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types, suffix, doc_ids_filter)

 if create_time_from or create_time_to:
 filtered_docs = []
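For the new metadata_condition filter accepted by /list above, a request body along these lines is the shape the handler expects; the exact condition keys are interpreted by convert_conditions/meta_filter, which are not shown in this diff, so the field names inside "conditions" are illustrative only.

# Illustrative body for POST /v1/document/list (condition field names partly assumed).
list_request = {
    "kb_id": "KB_ID",
    "page": 1,
    "page_size": 30,
    "metadata_condition": {
        "logic": "and",  # how the individual conditions combine
        "conditions": [
            {"name": "author", "comparison_operator": "is", "value": "Alice"},
        ],
    },
}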
@@ -271,7 +283,7 @@ async def list_docs():
 @manager.route("/filter", methods=["POST"]) # noqa: F821
 @login_required
 async def get_filter():
-req = await request.get_json()
+req = await get_request_json()

 kb_id = req.get("kb_id")
 if not kb_id:
@@ -281,7 +293,7 @@ async def get_filter():
 if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
 break
 else:
-return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=RetCode.OPERATING_ERROR)
+return get_json_result(data=False, message="Only owner of dataset authorized for this operation.", code=RetCode.OPERATING_ERROR)

 keywords = req.get("keywords", "")

@@ -309,7 +321,7 @@ async def get_filter():
 @manager.route("/infos", methods=["POST"]) # noqa: F821
 @login_required
 async def doc_infos():
-req = await request_json()
+req = await get_request_json()
 doc_ids = req["doc_ids"]
 for doc_id in doc_ids:
 if not DocumentService.accessible(doc_id, current_user.id):
@@ -318,6 +330,87 @@ async def doc_infos():
 return get_json_result(data=list(docs.dicts()))


+@manager.route("/metadata/summary", methods=["POST"]) # noqa: F821
+@login_required
+async def metadata_summary():
+req = await get_request_json()
+kb_id = req.get("kb_id")
+if not kb_id:
+return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
+
+tenants = UserTenantService.query(user_id=current_user.id)
+for tenant in tenants:
+if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
+break
+else:
+return get_json_result(data=False, message="Only owner of dataset authorized for this operation.", code=RetCode.OPERATING_ERROR)
+
+try:
+summary = DocumentService.get_metadata_summary(kb_id)
+return get_json_result(data={"summary": summary})
+except Exception as e:
+return server_error_response(e)
+
+
+@manager.route("/metadata/update", methods=["POST"]) # noqa: F821
+@login_required
+async def metadata_update():
+req = await get_request_json()
+kb_id = req.get("kb_id")
+if not kb_id:
+return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
+
+tenants = UserTenantService.query(user_id=current_user.id)
+for tenant in tenants:
+if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
+break
+else:
+return get_json_result(data=False, message="Only owner of dataset authorized for this operation.", code=RetCode.OPERATING_ERROR)
+
+selector = req.get("selector", {}) or {}
+updates = req.get("updates", []) or []
+deletes = req.get("deletes", []) or []
+
+if not isinstance(selector, dict):
+return get_json_result(data=False, message="selector must be an object.", code=RetCode.ARGUMENT_ERROR)
+if not isinstance(updates, list) or not isinstance(deletes, list):
+return get_json_result(data=False, message="updates and deletes must be lists.", code=RetCode.ARGUMENT_ERROR)
+
+metadata_condition = selector.get("metadata_condition", {}) or {}
+if metadata_condition and not isinstance(metadata_condition, dict):
+return get_json_result(data=False, message="metadata_condition must be an object.", code=RetCode.ARGUMENT_ERROR)
+
+document_ids = selector.get("document_ids", []) or []
+if document_ids and not isinstance(document_ids, list):
+return get_json_result(data=False, message="document_ids must be a list.", code=RetCode.ARGUMENT_ERROR)
+
+for upd in updates:
+if not isinstance(upd, dict) or not upd.get("key") or "value" not in upd:
+return get_json_result(data=False, message="Each update requires key and value.", code=RetCode.ARGUMENT_ERROR)
+for d in deletes:
+if not isinstance(d, dict) or not d.get("key"):
+return get_json_result(data=False, message="Each delete requires key.", code=RetCode.ARGUMENT_ERROR)
+
+kb_doc_ids = KnowledgebaseService.list_documents_by_ids([kb_id])
+target_doc_ids = set(kb_doc_ids)
+if document_ids:
+invalid_ids = set(document_ids) - set(kb_doc_ids)
+if invalid_ids:
+return get_json_result(data=False, message=f"These documents do not belong to dataset {kb_id}: {', '.join(invalid_ids)}", code=RetCode.ARGUMENT_ERROR)
+target_doc_ids = set(document_ids)
+
+if metadata_condition:
+metas = DocumentService.get_flatted_meta_by_kbs([kb_id])
+filtered_ids = set(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
+target_doc_ids = target_doc_ids & filtered_ids
+if metadata_condition.get("conditions") and not target_doc_ids:
+return get_json_result(data={"updated": 0, "matched_docs": 0})
+
+target_doc_ids = list(target_doc_ids)
+updated = DocumentService.batch_update_metadata(kb_id, target_doc_ids, updates, deletes)
+return get_json_result(data={"updated": updated, "matched_docs": len(target_doc_ids)})
+
+
 @manager.route("/thumbnails", methods=["GET"]) # noqa: F821
 # @login_required
 def thumbnails():
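The /metadata/update route above validates a selector plus updates/deletes lists; a body in the following shape passes those checks. Values are placeholders and the /v1/document prefix is an assumption.

# Illustrative payload for POST /v1/document/metadata/update (route prefix assumed).
update_request = {
    "kb_id": "KB_ID",
    "selector": {
        "document_ids": [],  # empty selector falls back to every document in the dataset
        "metadata_condition": {"logic": "and", "conditions": []},
    },
    "updates": [{"key": "reviewed", "value": "yes"}],  # each update needs key and value
    "deletes": [{"key": "draft"}],                     # each delete needs key
}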
@@ -341,7 +434,7 @@ def thumbnails():
 @login_required
 @validate_request("doc_ids", "status")
 async def change_status():
-req = await request.get_json()
+req = await get_request_json()
 doc_ids = req.get("doc_ids", [])
 status = str(req.get("status", ""))

@@ -361,7 +454,7 @@ async def change_status():
 continue
 e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
 if not e:
-result[doc_id] = {"error": "Can't find this knowledgebase!"}
+result[doc_id] = {"error": "Can't find this dataset!"}
 continue
 if not DocumentService.update_by_id(doc_id, {"status": str(status)}):
 result[doc_id] = {"error": "Database error (Document update)!"}
@@ -381,7 +474,7 @@ async def change_status():
 @login_required
 @validate_request("doc_id")
 async def rm():
-req = await request_json()
+req = await get_request_json()
 doc_ids = req["doc_id"]
 if isinstance(doc_ids, str):
 doc_ids = [doc_ids]
@@ -390,7 +483,7 @@ async def rm():
 if not DocumentService.accessible4deletion(doc_id, current_user.id):
 return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)

-errors = FileService.delete_docs(doc_ids, current_user.id)
+errors = await asyncio.to_thread(FileService.delete_docs, doc_ids, current_user.id)

 if errors:
 return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
@@ -402,11 +495,13 @@ async def rm():
 @login_required
 @validate_request("doc_ids", "run")
 async def run():
-req = await request_json()
+req = await get_request_json()
+try:
+def _run_sync():
 for doc_id in req["doc_ids"]:
 if not DocumentService.accessible(doc_id, current_user.id):
 return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
-try:
 kb_table_num_map = {}
 for id in req["doc_ids"]:
 info = {"run": str(req["run"]), "progress": 0}
@@ -437,10 +532,12 @@ async def run():
 settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)

 if str(req["run"]) == TaskStatus.RUNNING.value:
-doc = doc.to_dict()
+doc_dict = doc.to_dict()
-DocumentService.run(tenant_id, doc, kb_table_num_map)
+DocumentService.run(tenant_id, doc_dict, kb_table_num_map)

 return get_json_result(data=True)

+return await asyncio.to_thread(_run_sync)
 except Exception as e:
 return server_error_response(e)

@ -449,10 +546,12 @@ async def run():
|
|||||||
@login_required
|
@login_required
|
||||||
@validate_request("doc_id", "name")
|
@validate_request("doc_id", "name")
|
||||||
async def rename():
|
async def rename():
|
||||||
req = await request_json()
|
req = await get_request_json()
|
||||||
|
try:
|
||||||
|
def _rename_sync():
|
||||||
if not DocumentService.accessible(req["doc_id"], current_user.id):
|
if not DocumentService.accessible(req["doc_id"], current_user.id):
|
||||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||||
try:
|
|
||||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(message="Document not found!")
|
return get_data_error_result(message="Document not found!")
|
||||||
@ -463,7 +562,7 @@ async def rename():
|
|||||||
|
|
||||||
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
|
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
|
||||||
if d.name == req["name"]:
|
if d.name == req["name"]:
|
||||||
return get_data_error_result(message="Duplicated document name in the same knowledgebase.")
|
return get_data_error_result(message="Duplicated document name in the same dataset.")
|
||||||
|
|
||||||
if not DocumentService.update_by_id(req["doc_id"], {"name": req["name"]}):
|
if not DocumentService.update_by_id(req["doc_id"], {"name": req["name"]}):
|
||||||
return get_data_error_result(message="Database error (Document rename)!")
|
return get_data_error_result(message="Database error (Document rename)!")
|
||||||
@ -487,8 +586,10 @@ async def rename():
|
|||||||
search.index_name(tenant_id),
|
search.index_name(tenant_id),
|
||||||
doc.kb_id,
|
doc.kb_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
|
|
||||||
|
return await asyncio.to_thread(_rename_sync)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
@ -502,7 +603,8 @@ async def get(doc_id):
|
|||||||
return get_data_error_result(message="Document not found!")
|
return get_data_error_result(message="Document not found!")
|
||||||
|
|
||||||
b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
|
b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
|
||||||
response = await make_response(settings.STORAGE_IMPL.get(b, n))
|
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
|
||||||
|
response = await make_response(data)
|
||||||
|
|
||||||
ext = re.search(r"\.([^.]+)$", doc.name.lower())
|
ext = re.search(r"\.([^.]+)$", doc.name.lower())
|
||||||
ext = ext.group(1) if ext else None
|
ext = ext.group(1) if ext else None
|
||||||
@ -523,8 +625,7 @@ async def get(doc_id):
|
|||||||
async def download_attachment(attachment_id):
|
async def download_attachment(attachment_id):
|
||||||
try:
|
try:
|
||||||
ext = request.args.get("ext", "markdown")
|
ext = request.args.get("ext", "markdown")
|
||||||
data = settings.STORAGE_IMPL.get(current_user.id, attachment_id)
|
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, current_user.id, attachment_id)
|
||||||
# data = settings.STORAGE_IMPL.get("eb500d50bb0411f0907561d2782adda5", attachment_id)
|
|
||||||
response = await make_response(data)
|
response = await make_response(data)
|
||||||
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
|
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
|
||||||
|
|
||||||
@ -539,7 +640,7 @@ async def download_attachment(attachment_id):
|
|||||||
@validate_request("doc_id")
|
@validate_request("doc_id")
|
||||||
async def change_parser():
|
async def change_parser():
|
||||||
|
|
||||||
req = await request_json()
|
req = await get_request_json()
|
||||||
if not DocumentService.accessible(req["doc_id"], current_user.id):
|
if not DocumentService.accessible(req["doc_id"], current_user.id):
|
||||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||||
|
|
||||||
@ -596,7 +697,8 @@ async def get_image(image_id):
|
|||||||
if len(arr) != 2:
|
if len(arr) != 2:
|
||||||
return get_data_error_result(message="Image not found.")
|
return get_data_error_result(message="Image not found.")
|
||||||
bkt, nm = image_id.split("-")
|
bkt, nm = image_id.split("-")
|
||||||
response = await make_response(settings.STORAGE_IMPL.get(bkt, nm))
|
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, bkt, nm)
|
||||||
|
response = await make_response(data)
|
||||||
response.headers.set("Content-Type", "image/JPEG")
|
response.headers.set("Content-Type", "image/JPEG")
|
||||||
return response
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -607,7 +709,7 @@ async def get_image(image_id):
|
|||||||
@login_required
|
@login_required
|
||||||
@validate_request("conversation_id")
|
@validate_request("conversation_id")
|
||||||
async def upload_and_parse():
|
async def upload_and_parse():
|
||||||
files = await request.file
|
files = await request.files
|
||||||
if "file" not in files:
|
if "file" not in files:
|
||||||
return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)
|
return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)
|
||||||
|
|
||||||
@ -624,7 +726,8 @@ async def upload_and_parse():
|
|||||||
@manager.route("/parse", methods=["POST"]) # noqa: F821
|
@manager.route("/parse", methods=["POST"]) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
async def parse():
|
async def parse():
|
||||||
url = await request.json.get("url") if await request.json else ""
|
req = await get_request_json()
|
||||||
|
url = req.get("url", "")
|
||||||
if url:
|
if url:
|
||||||
if not is_valid_url(url):
|
if not is_valid_url(url):
|
||||||
return get_json_result(data=False, message="The URL format is invalid", code=RetCode.ARGUMENT_ERROR)
|
return get_json_result(data=False, message="The URL format is invalid", code=RetCode.ARGUMENT_ERROR)
|
||||||
@ -679,7 +782,7 @@ async def parse():
|
|||||||
@login_required
|
@login_required
|
||||||
@validate_request("doc_id", "meta")
|
@validate_request("doc_id", "meta")
|
||||||
async def set_meta():
|
async def set_meta():
|
||||||
req = await request_json()
|
req = await get_request_json()
|
||||||
if not DocumentService.accessible(req["doc_id"], current_user.id):
|
if not DocumentService.accessible(req["doc_id"], current_user.id):
|
||||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||||
try:
|
try:
|
||||||
@ -687,7 +790,10 @@ async def set_meta():
|
|||||||
if not isinstance(meta, dict):
|
if not isinstance(meta, dict):
|
||||||
return get_json_result(data=False, message="Only dictionary type supported.", code=RetCode.ARGUMENT_ERROR)
|
return get_json_result(data=False, message="Only dictionary type supported.", code=RetCode.ARGUMENT_ERROR)
|
||||||
for k, v in meta.items():
|
for k, v in meta.items():
|
||||||
if not isinstance(v, str) and not isinstance(v, int) and not isinstance(v, float):
|
if isinstance(v, list):
|
||||||
|
if not all(isinstance(i, (str, int, float)) for i in v):
|
||||||
|
return get_json_result(data=False, message=f"The type is not supported in list: {v}", code=RetCode.ARGUMENT_ERROR)
|
||||||
|
elif not isinstance(v, (str, int, float)):
|
||||||
return get_json_result(data=False, message=f"The type is not supported: {v}", code=RetCode.ARGUMENT_ERROR)
|
return get_json_result(data=False, message=f"The type is not supported: {v}", code=RetCode.ARGUMENT_ERROR)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return get_json_result(data=False, message=f"Json syntax error: {e}", code=RetCode.ARGUMENT_ERROR)
|
return get_json_result(data=False, message=f"Json syntax error: {e}", code=RetCode.ARGUMENT_ERROR)
|
||||||
@ -705,3 +811,13 @@ async def set_meta():
|
|||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/upload_info", methods=["POST"]) # noqa: F821
|
||||||
|
async def upload_info():
|
||||||
|
files = await request.files
|
||||||
|
file = files['file'] if files and files.get("file") else None
|
||||||
|
try:
|
||||||
|
return get_json_result(data=FileService.upload_info(current_user.id, file, request.args.get("url")))
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|||||||
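The dominant pattern in the handler changes above is moving blocking service and storage calls off the Quart event loop with asyncio.to_thread, and reading the request body through the shared get_request_json() helper. A minimal sketch of the offloading pattern, using a hypothetical slow_lookup stand-in rather than the real DocumentService/STORAGE_IMPL calls:

    import asyncio

    def slow_lookup(doc_id: str) -> dict:
        # stand-in for a blocking ORM or object-storage call
        return {"doc_id": doc_id}

    async def handler(doc_id: str) -> dict:
        # run the blocking call in a worker thread so the event loop keeps serving other requests
        return await asyncio.to_thread(slow_lookup, doc_id)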
api/apps/evaluation_app.py (new file, 479 lines added)
@@ -0,0 +1,479 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
RAG Evaluation API Endpoints

Provides REST API for RAG evaluation functionality including:
- Dataset management
- Test case management
- Evaluation execution
- Results retrieval
- Configuration recommendations
"""

from quart import request
from api.apps import login_required, current_user
from api.db.services.evaluation_service import EvaluationService
from api.utils.api_utils import (
get_data_error_result,
get_json_result,
get_request_json,
server_error_response,
validate_request
)
from common.constants import RetCode


# ==================== Dataset Management ====================

@manager.route('/dataset/create', methods=['POST']) # noqa: F821
@login_required
@validate_request("name", "kb_ids")
async def create_dataset():
"""
Create a new evaluation dataset.

Request body:
{
"name": "Dataset name",
"description": "Optional description",
"kb_ids": ["kb_id1", "kb_id2"]
}
"""
try:
req = await get_request_json()
name = req.get("name", "").strip()
description = req.get("description", "")
kb_ids = req.get("kb_ids", [])

if not name:
return get_data_error_result(message="Dataset name cannot be empty")

if not kb_ids or not isinstance(kb_ids, list):
return get_data_error_result(message="kb_ids must be a non-empty list")

success, result = EvaluationService.create_dataset(
name=name,
description=description,
kb_ids=kb_ids,
tenant_id=current_user.id,
user_id=current_user.id
)

if not success:
return get_data_error_result(message=result)

return get_json_result(data={"dataset_id": result})
except Exception as e:
return server_error_response(e)


@manager.route('/dataset/list', methods=['GET']) # noqa: F821
@login_required
async def list_datasets():
"""
List evaluation datasets for current tenant.

Query params:
- page: Page number (default: 1)
- page_size: Items per page (default: 20)
"""
try:
page = int(request.args.get("page", 1))
page_size = int(request.args.get("page_size", 20))

result = EvaluationService.list_datasets(
tenant_id=current_user.id,
user_id=current_user.id,
page=page,
page_size=page_size
)

return get_json_result(data=result)
except Exception as e:
return server_error_response(e)


@manager.route('/dataset/<dataset_id>', methods=['GET']) # noqa: F821
@login_required
async def get_dataset(dataset_id):
"""Get dataset details by ID"""
try:
dataset = EvaluationService.get_dataset(dataset_id)
if not dataset:
return get_data_error_result(
message="Dataset not found",
code=RetCode.DATA_ERROR
)

return get_json_result(data=dataset)
except Exception as e:
return server_error_response(e)


@manager.route('/dataset/<dataset_id>', methods=['PUT']) # noqa: F821
@login_required
async def update_dataset(dataset_id):
"""
Update dataset.

Request body:
{
"name": "New name",
"description": "New description",
"kb_ids": ["kb_id1", "kb_id2"]
}
"""
try:
req = await get_request_json()

# Remove fields that shouldn't be updated
req.pop("id", None)
req.pop("tenant_id", None)
req.pop("created_by", None)
req.pop("create_time", None)

success = EvaluationService.update_dataset(dataset_id, **req)

if not success:
return get_data_error_result(message="Failed to update dataset")

return get_json_result(data={"dataset_id": dataset_id})
except Exception as e:
return server_error_response(e)


@manager.route('/dataset/<dataset_id>', methods=['DELETE']) # noqa: F821
@login_required
async def delete_dataset(dataset_id):
"""Delete dataset (soft delete)"""
try:
success = EvaluationService.delete_dataset(dataset_id)

if not success:
return get_data_error_result(message="Failed to delete dataset")

return get_json_result(data={"dataset_id": dataset_id})
except Exception as e:
return server_error_response(e)


# ==================== Test Case Management ====================

@manager.route('/dataset/<dataset_id>/case/add', methods=['POST']) # noqa: F821
@login_required
@validate_request("question")
async def add_test_case(dataset_id):
"""
Add a test case to a dataset.

Request body:
{
"question": "Test question",
"reference_answer": "Optional ground truth answer",
"relevant_doc_ids": ["doc_id1", "doc_id2"],
"relevant_chunk_ids": ["chunk_id1", "chunk_id2"],
"metadata": {"key": "value"}
}
"""
try:
req = await get_request_json()
question = req.get("question", "").strip()

if not question:
return get_data_error_result(message="Question cannot be empty")

success, result = EvaluationService.add_test_case(
dataset_id=dataset_id,
question=question,
reference_answer=req.get("reference_answer"),
relevant_doc_ids=req.get("relevant_doc_ids"),
relevant_chunk_ids=req.get("relevant_chunk_ids"),
metadata=req.get("metadata")
)

if not success:
return get_data_error_result(message=result)

return get_json_result(data={"case_id": result})
except Exception as e:
return server_error_response(e)


@manager.route('/dataset/<dataset_id>/case/import', methods=['POST']) # noqa: F821
@login_required
@validate_request("cases")
async def import_test_cases(dataset_id):
"""
Bulk import test cases.

Request body:
{
"cases": [
{
"question": "Question 1",
"reference_answer": "Answer 1",
...
},
{
"question": "Question 2",
...
}
]
}
"""
try:
req = await get_request_json()
cases = req.get("cases", [])

if not cases or not isinstance(cases, list):
return get_data_error_result(message="cases must be a non-empty list")

success_count, failure_count = EvaluationService.import_test_cases(
dataset_id=dataset_id,
cases=cases
)

return get_json_result(data={
"success_count": success_count,
"failure_count": failure_count,
"total": len(cases)
})
except Exception as e:
return server_error_response(e)


@manager.route('/dataset/<dataset_id>/cases', methods=['GET']) # noqa: F821
@login_required
async def get_test_cases(dataset_id):
"""Get all test cases for a dataset"""
try:
cases = EvaluationService.get_test_cases(dataset_id)
return get_json_result(data={"cases": cases, "total": len(cases)})
except Exception as e:
return server_error_response(e)


@manager.route('/case/<case_id>', methods=['DELETE']) # noqa: F821
@login_required
async def delete_test_case(case_id):
"""Delete a test case"""
try:
success = EvaluationService.delete_test_case(case_id)

if not success:
return get_data_error_result(message="Failed to delete test case")

return get_json_result(data={"case_id": case_id})
except Exception as e:
return server_error_response(e)


# ==================== Evaluation Execution ====================

@manager.route('/run/start', methods=['POST']) # noqa: F821
@login_required
@validate_request("dataset_id", "dialog_id")
async def start_evaluation():
"""
Start an evaluation run.

Request body:
{
"dataset_id": "dataset_id",
"dialog_id": "dialog_id",
"name": "Optional run name"
}
"""
try:
req = await get_request_json()
dataset_id = req.get("dataset_id")
dialog_id = req.get("dialog_id")
name = req.get("name")

success, result = EvaluationService.start_evaluation(
dataset_id=dataset_id,
dialog_id=dialog_id,
user_id=current_user.id,
name=name
)

if not success:
return get_data_error_result(message=result)

return get_json_result(data={"run_id": result})
except Exception as e:
return server_error_response(e)


@manager.route('/run/<run_id>', methods=['GET']) # noqa: F821
@login_required
async def get_evaluation_run(run_id):
"""Get evaluation run details"""
try:
result = EvaluationService.get_run_results(run_id)

if not result:
return get_data_error_result(
message="Evaluation run not found",
code=RetCode.DATA_ERROR
)

return get_json_result(data=result)
except Exception as e:
return server_error_response(e)


@manager.route('/run/<run_id>/results', methods=['GET']) # noqa: F821
@login_required
async def get_run_results(run_id):
"""Get detailed results for an evaluation run"""
try:
result = EvaluationService.get_run_results(run_id)

if not result:
return get_data_error_result(
message="Evaluation run not found",
code=RetCode.DATA_ERROR
)

return get_json_result(data=result)
except Exception as e:
return server_error_response(e)


@manager.route('/run/list', methods=['GET']) # noqa: F821
@login_required
async def list_evaluation_runs():
"""
List evaluation runs.

Query params:
- dataset_id: Filter by dataset (optional)
- dialog_id: Filter by dialog (optional)
- page: Page number (default: 1)
- page_size: Items per page (default: 20)
"""
try:
# TODO: Implement list_runs in EvaluationService
return get_json_result(data={"runs": [], "total": 0})
except Exception as e:
return server_error_response(e)


@manager.route('/run/<run_id>', methods=['DELETE']) # noqa: F821
@login_required
async def delete_evaluation_run(run_id):
"""Delete an evaluation run"""
try:
# TODO: Implement delete_run in EvaluationService
return get_json_result(data={"run_id": run_id})
except Exception as e:
return server_error_response(e)


# ==================== Analysis & Recommendations ====================

@manager.route('/run/<run_id>/recommendations', methods=['GET']) # noqa: F821
@login_required
async def get_recommendations(run_id):
"""Get configuration recommendations based on evaluation results"""
try:
recommendations = EvaluationService.get_recommendations(run_id)
return get_json_result(data={"recommendations": recommendations})
except Exception as e:
return server_error_response(e)


@manager.route('/compare', methods=['POST']) # noqa: F821
@login_required
@validate_request("run_ids")
async def compare_runs():
"""
Compare multiple evaluation runs.

Request body:
{
"run_ids": ["run_id1", "run_id2", "run_id3"]
}
"""
try:
req = await get_request_json()
run_ids = req.get("run_ids", [])

if not run_ids or not isinstance(run_ids, list) or len(run_ids) < 2:
return get_data_error_result(
message="run_ids must be a list with at least 2 run IDs"
)

# TODO: Implement compare_runs in EvaluationService
return get_json_result(data={"comparison": {}})
except Exception as e:
return server_error_response(e)


@manager.route('/run/<run_id>/export', methods=['GET']) # noqa: F821
@login_required
async def export_results(run_id):
"""Export evaluation results as JSON/CSV"""
try:
# format_type = request.args.get("format", "json") # TODO: Use for CSV export

result = EvaluationService.get_run_results(run_id)

if not result:
return get_data_error_result(
message="Evaluation run not found",
code=RetCode.DATA_ERROR
)

# TODO: Implement CSV export
return get_json_result(data=result)
except Exception as e:
return server_error_response(e)


# ==================== Real-time Evaluation ====================

@manager.route('/evaluate_single', methods=['POST']) # noqa: F821
@login_required
@validate_request("question", "dialog_id")
async def evaluate_single():
"""
Evaluate a single question-answer pair in real-time.

Request body:
{
"question": "Test question",
"dialog_id": "dialog_id",
"reference_answer": "Optional ground truth",
"relevant_chunk_ids": ["chunk_id1", "chunk_id2"]
}
"""
try:
# req = await get_request_json() # TODO: Use for single evaluation implementation

# TODO: Implement single evaluation
# This would execute the RAG pipeline and return metrics immediately

return get_json_result(data={
"answer": "",
"metrics": {},
"retrieved_chunks": []
})
except Exception as e:
return server_error_response(e)
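The evaluation blueprint's URL prefix is set where the app registers its modules and is not shown in this diff; assuming a /v1/evaluation prefix on a local instance and an already authenticated session, a client-side smoke test of the new endpoints might look like the following sketch (host, prefix, and IDs are assumptions; the payload fields come from the docstrings above):

    import requests

    BASE = "http://localhost:9380/v1/evaluation"  # host and URL prefix are assumptions
    s = requests.Session()  # assumed to already carry a logged-in session cookie

    created = s.post(f"{BASE}/dataset/create", json={
        "name": "smoke-test",
        "description": "a few sanity questions",
        "kb_ids": ["kb_id1"],
    }).json()
    dataset_id = created["data"]["dataset_id"]

    s.post(f"{BASE}/dataset/{dataset_id}/case/add", json={"question": "What does RAGFlow index?"})
    run = s.post(f"{BASE}/run/start", json={"dataset_id": dataset_id, "dialog_id": "dialog_id"}).json()
    print(run["data"]["run_id"])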
@@ -19,22 +19,20 @@ from pathlib import Path
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService

-from quart import request
from api.apps import login_required, current_user
from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
+from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
from common.misc_utils import get_uuid
from common.constants import RetCode
from api.db import FileType
from api.db.services.document_service import DocumentService
-from api.utils.api_utils import get_json_result


@manager.route('/convert', methods=['POST']) # noqa: F821
@login_required
@validate_request("file_ids", "kb_ids")
async def convert():
-req = await request.json
+req = await get_request_json()
kb_ids = req["kb_ids"]
file_ids = req["file_ids"]
file2documents = []
@@ -70,7 +68,7 @@ async def convert():
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
return get_data_error_result(
-message="Can't find this knowledgebase!")
+message="Can't find this dataset!")
e, file = FileService.get_by_id(id)
if not e:
return get_data_error_result(
@@ -79,7 +77,8 @@ async def convert():
doc = DocumentService.insert({
"id": get_uuid(),
"kb_id": kb.id,
-"parser_id": FileService.get_parser(file.type, file.name, kb.parser_id),
+"parser_id": kb.parser_id,
+"pipeline_id": kb.pipeline_id,
"parser_config": kb.parser_config,
"created_by": current_user.id,
"type": file.type,
@@ -104,7 +103,7 @@ async def convert():
@login_required
@validate_request("file_ids")
async def rm():
-req = await request.json
+req = await get_request_json()
file_ids = req["file_ids"]
if not file_ids:
return get_json_result(
@@ -14,6 +14,7 @@
# limitations under the License
#
import logging
+import asyncio
import os
import pathlib
import re
@@ -29,7 +30,7 @@ from common.constants import RetCode, FileSource
from api.db import FileType
from api.db.services import duplicate_name
from api.db.services.file_service import FileService
-from api.utils.api_utils import get_json_result
+from api.utils.api_utils import get_json_result, get_request_json
from api.utils.file_utils import filename_type
from api.utils.web_utils import CONTENT_TYPE_MAP
from common import settings
@@ -61,9 +62,10 @@ async def upload():
e, pf_folder = FileService.get_by_id(pf_id)
if not e:
return get_data_error_result( message="Can't find this folder!")
-for file_obj in file_objs:
+async def _handle_single_file(file_obj):
MAX_FILE_NUM_PER_USER: int = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
-if 0 < MAX_FILE_NUM_PER_USER <= DocumentService.get_doc_count(current_user.id):
+if 0 < MAX_FILE_NUM_PER_USER <= await asyncio.to_thread(DocumentService.get_doc_count, current_user.id):
return get_data_error_result( message="Exceed the maximum file number of a free user!")

# split file name path
@@ -75,35 +77,36 @@ async def upload():
file_len = len(file_obj_names)

# get folder
-file_id_list = FileService.get_id_list_by_id(pf_id, file_obj_names, 1, [pf_id])
+file_id_list = await asyncio.to_thread(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id])
len_id_list = len(file_id_list)

# create folder
if file_len != len_id_list:
-e, file = FileService.get_by_id(file_id_list[len_id_list - 1])
+e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 1])
if not e:
return get_data_error_result(message="Folder not found!")
-last_folder = FileService.create_folder(file, file_id_list[len_id_list - 1], file_obj_names,
+last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names,
len_id_list)
else:
-e, file = FileService.get_by_id(file_id_list[len_id_list - 2])
+e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 2])
if not e:
return get_data_error_result(message="Folder not found!")
-last_folder = FileService.create_folder(file, file_id_list[len_id_list - 2], file_obj_names,
+last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names,
len_id_list)

# file type
filetype = filename_type(file_obj_names[file_len - 1])
location = file_obj_names[file_len - 1]
-while settings.STORAGE_IMPL.obj_exist(last_folder.id, location):
+while await asyncio.to_thread(settings.STORAGE_IMPL.obj_exist, last_folder.id, location):
location += "_"
-blob = file_obj.read()
+blob = await asyncio.to_thread(file_obj.read)
-filename = duplicate_name(
+filename = await asyncio.to_thread(
+duplicate_name,
FileService.query,
name=file_obj_names[file_len - 1],
parent_id=last_folder.id)
-settings.STORAGE_IMPL.put(last_folder.id, location, blob)
+await asyncio.to_thread(settings.STORAGE_IMPL.put, last_folder.id, location, blob)
-file = {
+file_data = {
"id": get_uuid(),
"parent_id": last_folder.id,
"tenant_id": current_user.id,
@@ -113,8 +116,13 @@ async def upload():
"location": location,
"size": len(blob),
}
-file = FileService.insert(file)
+inserted = await asyncio.to_thread(FileService.insert, file_data)
-file_res.append(file.to_json())
+return inserted.to_json()

+for file_obj in file_objs:
+res = await _handle_single_file(file_obj)
+file_res.append(res)

return get_json_result(data=file_res)
except Exception as e:
return server_error_response(e)
@@ -124,7 +132,7 @@ async def upload():
@login_required
@validate_request("name")
async def create():
-req = await request.json
+req = await get_request_json()
pf_id = req.get("parent_id")
input_file_type = req.get("type")
if not pf_id:
@@ -239,9 +247,10 @@ def get_all_parent_folders():
@login_required
@validate_request("file_ids")
async def rm():
-req = await request.json
+req = await get_request_json()
file_ids = req["file_ids"]

+try:
def _delete_single_file(file):
try:
if file.location:
@@ -271,7 +280,7 @@ async def rm():

FileService.delete(folder)

-try:
+def _rm_sync():
for file_id in file_ids:
e, file = FileService.get_by_id(file_id)
if not e or not file:
@@ -292,6 +301,8 @@ async def rm():

return get_json_result(data=True)

+return await asyncio.to_thread(_rm_sync)

except Exception as e:
return server_error_response(e)

@@ -300,7 +311,7 @@ async def rm():
@login_required
@validate_request("file_id", "name")
async def rename():
-req = await request.json
+req = await get_request_json()
try:
e, file = FileService.get_by_id(req["file_id"])
if not e:
@@ -346,10 +357,10 @@ async def get(file_id):
if not check_file_team_permission(file, current_user.id):
return get_json_result(data=False, message='No authorization.', code=RetCode.AUTHENTICATION_ERROR)

-blob = settings.STORAGE_IMPL.get(file.parent_id, file.location)
+blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, file.parent_id, file.location)
if not blob:
b, n = File2DocumentService.get_storage_address(file_id=file_id)
-blob = settings.STORAGE_IMPL.get(b, n)
+blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)

response = await make_response(blob)
ext = re.search(r"\.([^.]+)$", file.name.lower())
@@ -369,7 +380,7 @@ async def get(file_id):
@login_required
@validate_request("src_file_ids", "dest_file_id")
async def move():
-req = await request.json
+req = await get_request_json()
try:
file_ids = req["src_file_ids"]
dest_parent_id = req["dest_file_id"]
@@ -444,10 +455,12 @@ async def move():
},
)

+def _move_sync():
for file in files:
_move_entry_recursive(file, dest_folder)

return get_json_result(data=True)

+return await asyncio.to_thread(_move_sync)

except Exception as e:
return server_error_response(e)
@@ -17,6 +17,7 @@ import json
import logging
import random
import re
+import asyncio

from quart import request
import numpy as np
@@ -30,7 +31,7 @@ from api.db.services.pipeline_operation_log_service import PipelineOperationLogS
from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters, \
-request_json
+get_request_json
from api.db import VALID_FILE_TYPES
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.db_models import File
@@ -48,7 +49,7 @@ from api.apps import login_required, current_user
@login_required
@validate_request("name")
async def create():
-req = await request_json()
+req = await get_request_json()
e, res = KnowledgebaseService.create_with_name(
name = req.pop("name", None),
tenant_id = current_user.id,
@@ -72,7 +73,7 @@ async def create():
@validate_request("kb_id", "name", "description", "parser_id")
@not_allowed_parameters("id", "tenant_id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by")
async def update():
-req = await request_json()
+req = await get_request_json()
if not isinstance(req["name"], str):
return get_data_error_result(message="Dataset name must be string.")
if req["name"].strip() == "":
@@ -92,19 +93,19 @@ async def update():
if not KnowledgebaseService.query(
created_by=current_user.id, id=req["kb_id"]):
return get_json_result(
-data=False, message='Only owner of knowledgebase authorized for this operation.',
+data=False, message='Only owner of dataset authorized for this operation.',
code=RetCode.OPERATING_ERROR)

e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
if not e:
return get_data_error_result(
-message="Can't find this knowledgebase!")
+message="Can't find this dataset!")

if req["name"].lower() != kb.name.lower() \
and len(
KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) >= 1:
return get_data_error_result(
-message="Duplicated knowledgebase name.")
+message="Duplicated dataset name.")

del req["kb_id"]
connectors = []
@@ -116,12 +117,22 @@ async def update():

if kb.pagerank != req.get("pagerank", 0):
if req.get("pagerank", 0) > 0:
-settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
-search.index_name(kb.tenant_id), kb.id)
+await asyncio.to_thread(
+settings.docStoreConn.update,
+{"kb_id": kb.id},
+{PAGERANK_FLD: req["pagerank"]},
+search.index_name(kb.tenant_id),
+kb.id,
+)
else:
# Elasticsearch requires PAGERANK_FLD be non-zero!
-settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
-search.index_name(kb.tenant_id), kb.id)
+await asyncio.to_thread(
+settings.docStoreConn.update,
+{"exists": PAGERANK_FLD},
+{"remove": PAGERANK_FLD},
+search.index_name(kb.tenant_id),
+kb.id,
+)

e, kb = KnowledgebaseService.get_by_id(kb.id)
if not e:
@@ -151,12 +162,12 @@ def detail():
break
else:
return get_json_result(
-data=False, message='Only owner of knowledgebase authorized for this operation.',
+data=False, message='Only owner of dataset authorized for this operation.',
code=RetCode.OPERATING_ERROR)
kb = KnowledgebaseService.get_detail(kb_id)
if not kb:
return get_data_error_result(
-message="Can't find this knowledgebase!")
+message="Can't find this dataset!")
kb["size"] = DocumentService.get_total_size_by_kb_id(kb_id=kb["id"],keywords="", run_status=[], types=[])
kb["connectors"] = Connector2KbService.list_connectors(kb_id)

@@ -182,7 +193,7 @@ async def list_kbs():
else:
desc = True

-req = await request_json()
+req = await get_request_json()
owner_ids = req.get("owner_ids", [])
try:
if not owner_ids:
@@ -209,7 +220,7 @@ async def list_kbs():
@login_required
@validate_request("kb_id")
async def rm():
-req = await request_json()
+req = await get_request_json()
if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id):
return get_json_result(
data=False,
@@ -221,9 +232,10 @@ async def rm():
created_by=current_user.id, id=req["kb_id"])
if not kbs:
return get_json_result(
-data=False, message='Only owner of knowledgebase authorized for this operation.',
+data=False, message='Only owner of dataset authorized for this operation.',
code=RetCode.OPERATING_ERROR)

+def _rm_sync():
for doc in DocumentService.query(kb_id=req["kb_id"]):
if not DocumentService.remove_document(doc, kbs[0].tenant_id):
return get_data_error_result(
@@ -243,6 +255,8 @@ async def rm():
if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
settings.STORAGE_IMPL.remove_bucket(kb.id)
return get_json_result(data=True)

+return await asyncio.to_thread(_rm_sync)
except Exception as e:
return server_error_response(e)

@@ -286,7 +300,7 @@ def list_tags_from_kbs():
@manager.route('/<kb_id>/rm_tags', methods=['POST']) # noqa: F821
@login_required
async def rm_tags(kb_id):
-req = await request_json()
+req = await get_request_json()
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
@@ -306,7 +320,7 @@ async def rm_tags(kb_id):
@manager.route('/<kb_id>/rename_tag', methods=['POST']) # noqa: F821
@login_required
async def rename_tags(kb_id):
-req = await request_json()
+req = await get_request_json()
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
@@ -428,7 +442,7 @@ async def list_pipeline_logs():
if create_date_to > create_date_from:
return get_data_error_result(message="Create data filter is abnormal.")

-req = await request_json()
+req = await get_request_json()

operation_status = req.get("operation_status", [])
if operation_status:
@@ -470,7 +484,7 @@ async def list_pipeline_dataset_logs():
if create_date_to > create_date_from:
return get_data_error_result(message="Create data filter is abnormal.")

-req = await request_json()
+req = await get_request_json()

operation_status = req.get("operation_status", [])
if operation_status:
@@ -492,7 +506,7 @@ async def delete_pipeline_logs():
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)

-req = await request_json()
+req = await get_request_json()
log_ids = req.get("log_ids", [])

PipelineOperationLogService.delete_by_ids(log_ids)
@@ -517,7 +531,7 @@ def pipeline_log_detail():
@manager.route("/run_graphrag", methods=["POST"]) # noqa: F821
@login_required
async def run_graphrag():
-req = await request_json()
+req = await get_request_json()

kb_id = req.get("kb_id", "")
if not kb_id:
@@ -586,7 +600,7 @@ def trace_graphrag():
@manager.route("/run_raptor", methods=["POST"]) # noqa: F821
@login_required
async def run_raptor():
-req = await request_json()
+req = await get_request_json()

kb_id = req.get("kb_id", "")
if not kb_id:
@@ -655,7 +669,7 @@ def trace_raptor():
@manager.route("/run_mindmap", methods=["POST"]) # noqa: F821
@login_required
async def run_mindmap():
-req = await request_json()
+req = await get_request_json()

kb_id = req.get("kb_id", "")
if not kb_id:
@@ -861,7 +875,7 @@ async def check_embedding():
def _clean(s: str) -> str:
s = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", s or "")
return s if s else "None"
-req = await request_json()
+req = await get_request_json()
kb_id = req.get("kb_id", "")
embd_id = req.get("embd_id", "")
n = int(req.get("check_num", 5))
@@ -922,5 +936,3 @@ async def check_embedding():
if summary["avg_cos_sim"] > 0.9:
return get_json_result(data={"summary": summary, "results": results})
return get_json_result(code=RetCode.NOT_EFFECTIVE, message="Embedding model switch failed: the average similarity between old and new vectors is below 0.9, indicating incompatible vector spaces.", data={"summary": summary, "results": results})
-
-
@@ -15,28 +15,28 @@
 #
 
 
-from quart import request
 from api.apps import current_user, login_required
 from langfuse import Langfuse
 
 from api.db.db_models import DB
 from api.db.services.langfuse_service import TenantLangfuseService
-from api.utils.api_utils import get_error_data_result, get_json_result, server_error_response, validate_request
+from api.utils.api_utils import get_error_data_result, get_json_result, get_request_json, server_error_response, validate_request
 
 
 @manager.route("/api_key", methods=["POST", "PUT"]) # noqa: F821
 @login_required
 @validate_request("secret_key", "public_key", "host")
 async def set_api_key():
-    req = await request.get_json()
+    req = await get_request_json()
     secret_key = req.get("secret_key", "")
     public_key = req.get("public_key", "")
     host = req.get("host", "")
     if not all([secret_key, public_key, host]):
         return get_error_data_result(message="Missing required fields")
 
+    current_user_id = current_user.id
     langfuse_keys = dict(
-        tenant_id=current_user.id,
+        tenant_id=current_user_id,
         secret_key=secret_key,
         public_key=public_key,
         host=host,
@@ -46,23 +46,24 @@ async def set_api_key():
     if not langfuse.auth_check():
         return get_error_data_result(message="Invalid Langfuse keys")
 
-    langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user.id)
+    langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user_id)
     with DB.atomic():
         try:
             if not langfuse_entry:
                 TenantLangfuseService.save(**langfuse_keys)
             else:
-                TenantLangfuseService.update_by_tenant(tenant_id=current_user.id, langfuse_keys=langfuse_keys)
+                TenantLangfuseService.update_by_tenant(tenant_id=current_user_id, langfuse_keys=langfuse_keys)
             return get_json_result(data=langfuse_keys)
         except Exception as e:
-            server_error_response(e)
+            return server_error_response(e)
 
 
 @manager.route("/api_key", methods=["GET"]) # noqa: F821
 @login_required
 @validate_request()
 def get_api_key():
-    langfuse_entry = TenantLangfuseService.filter_by_tenant_with_info(tenant_id=current_user.id)
+    current_user_id = current_user.id
+    langfuse_entry = TenantLangfuseService.filter_by_tenant_with_info(tenant_id=current_user_id)
     if not langfuse_entry:
         return get_json_result(message="Have not record any Langfuse keys.")
 
@@ -73,7 +74,7 @@ def get_api_key():
     except langfuse.api.core.api_error.ApiError as api_err:
         return get_json_result(message=f"Error from Langfuse: {api_err}")
     except Exception as e:
-        server_error_response(e)
+        return server_error_response(e)
 
     langfuse_entry["project_id"] = langfuse.api.projects.get().dict()["data"][0]["id"]
     langfuse_entry["project_name"] = langfuse.api.projects.get().dict()["data"][0]["name"]
@@ -85,7 +86,8 @@ def get_api_key():
 @login_required
 @validate_request()
 def delete_api_key():
-    langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user.id)
+    current_user_id = current_user.id
+    langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user_id)
     if not langfuse_entry:
         return get_json_result(message="Have not record any Langfuse keys.")
 
@@ -94,4 +96,4 @@ def delete_api_key():
         TenantLangfuseService.delete_model(langfuse_entry)
         return get_json_result(data=True)
     except Exception as e:
-        server_error_response(e)
+        return server_error_response(e)
@@ -21,12 +21,11 @@ from quart import request
 from api.apps import login_required, current_user
 from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
 from api.db.services.llm_service import LLMService
-from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
+from api.utils.api_utils import get_allowed_llm_factories, get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
 from common.constants import StatusEnum, LLMType
 from api.db.db_models import TenantLLM
-from api.utils.api_utils import get_json_result, get_allowed_llm_factories
 from rag.utils.base64_image import test_image
-from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel
+from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel, OcrModel
 
 
 @manager.route("/factories", methods=["GET"]) # noqa: F821
@@ -44,7 +43,13 @@ def factories():
                 mdl_types[m.fid] = set([])
             mdl_types[m.fid].add(m.model_type)
         for f in fac:
-            f["model_types"] = list(mdl_types.get(f["name"], [LLMType.CHAT, LLMType.EMBEDDING, LLMType.RERANK, LLMType.IMAGE2TEXT, LLMType.SPEECH2TEXT, LLMType.TTS]))
+            f["model_types"] = list(
+                mdl_types.get(
+                    f["name"],
+                    [LLMType.CHAT, LLMType.EMBEDDING, LLMType.RERANK, LLMType.IMAGE2TEXT, LLMType.SPEECH2TEXT, LLMType.TTS, LLMType.OCR],
+                )
+            )
+
         return get_json_result(data=fac)
     except Exception as e:
         return server_error_response(e)
@@ -54,7 +59,7 @@ def factories():
 @login_required
 @validate_request("llm_factory", "api_key")
 async def set_api_key():
-    req = await request.json
+    req = await get_request_json()
     # test if api key works
     chat_passed, embd_passed, rerank_passed = False, False, False
     factory = req["llm_factory"]
@@ -75,7 +80,7 @@ async def set_api_key():
         assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
         mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra)
         try:
-            m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50})
+            m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50})
             if m.find("**ERROR**") >= 0:
                 raise Exception(m)
             chat_passed = True
@@ -124,7 +129,7 @@ async def set_api_key():
 @login_required
 @validate_request("llm_factory")
 async def add_llm():
-    req = await request.json
+    req = await get_request_json()
     factory = req["llm_factory"]
     api_key = req.get("api_key", "x")
     llm_name = req.get("llm_name")
@@ -152,7 +157,7 @@ async def add_llm():
     elif factory == "Bedrock":
         # For Bedrock, due to its special authentication method
         # Assemble bedrock_ak, bedrock_sk, bedrock_region
-        api_key = apikey_json(["bedrock_ak", "bedrock_sk", "bedrock_region"])
+        api_key = apikey_json(["auth_mode", "bedrock_ak", "bedrock_sk", "bedrock_region", "aws_role_arn"])
 
     elif factory == "LocalAI":
         llm_name += "___LocalAI"
@@ -187,6 +192,9 @@ async def add_llm():
     elif factory == "OpenRouter":
         api_key = apikey_json(["api_key", "provider_order"])
 
+    elif factory == "MinerU":
+        api_key = apikey_json(["api_key", "provider_order"])
+
     llm = {
         "tenant_id": current_user.id,
         "llm_factory": factory,
@@ -218,7 +226,7 @@ async def add_llm():
             **extra,
         )
         try:
-            m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
+            m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
             if not tc and m.find("**ERROR**:") >= 0:
                 raise Exception(m)
         except Exception as e:
@@ -252,6 +260,15 @@ async def add_llm():
                 pass
         except RuntimeError as e:
             msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
+    elif llm["model_type"] == LLMType.OCR.value:
+        assert factory in OcrModel, f"OCR model from {factory} is not supported yet."
+        try:
+            mdl = OcrModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm.get("api_base", ""))
+            ok, reason = mdl.check_available()
+            if not ok:
+                raise RuntimeError(reason or "Model not available")
+        except Exception as e:
+            msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
     else:
         # TODO: check other type of models
         pass
@@ -269,7 +286,7 @@ async def add_llm():
 @login_required
 @validate_request("llm_factory", "llm_name")
 async def delete_llm():
-    req = await request.json
+    req = await get_request_json()
     TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]])
     return get_json_result(data=True)
 
@@ -278,7 +295,7 @@ async def delete_llm():
 @login_required
 @validate_request("llm_factory", "llm_name")
 async def enable_llm():
-    req = await request.json
+    req = await get_request_json()
     TenantLLMService.filter_update(
         [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]], {"status": str(req.get("status", "1"))}
     )
@@ -289,7 +306,7 @@ async def enable_llm():
 @login_required
 @validate_request("llm_factory")
 async def delete_factory():
-    req = await request.json
+    req = await get_request_json()
     TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]])
     return get_json_result(data=True)
 
@@ -298,6 +315,7 @@ async def delete_factory():
 @login_required
 def my_llms():
     try:
+        TenantLLMService.ensure_mineru_from_env(current_user.id)
         include_details = request.args.get("include_details", "false").lower() == "true"
 
         if include_details:
@@ -345,6 +363,7 @@ def list_app():
     weighted = []
     model_type = request.args.get("model_type")
     try:
+        TenantLLMService.ensure_mineru_from_env(current_user.id)
         objs = TenantLLMService.query(tenant_id=current_user.id)
         facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key and o.status == StatusEnum.VALID.value])
         status = {(o.llm_name + "@" + o.llm_factory) for o in objs if o.status == StatusEnum.VALID.value}
@@ -22,8 +22,7 @@ from api.db.services.user_service import TenantService
 from common.constants import RetCode, VALID_MCP_SERVER_TYPES
 
 from common.misc_utils import get_uuid
-from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
-    get_mcp_tools
+from api.utils.api_utils import get_data_error_result, get_json_result, get_mcp_tools, get_request_json, server_error_response, validate_request
 from api.utils.web_utils import get_float, safe_json_parse
 from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
 
@@ -40,7 +39,7 @@ async def list_mcp() -> Response:
     else:
         desc = True
 
-    req = await request.get_json()
+    req = await get_request_json()
     mcp_ids = req.get("mcp_ids", [])
     try:
         servers = MCPServerService.get_servers(current_user.id, mcp_ids, 0, 0, orderby, desc, keywords) or []
@@ -73,7 +72,7 @@ def detail() -> Response:
 @login_required
 @validate_request("name", "url", "server_type")
 async def create() -> Response:
-    req = await request.get_json()
+    req = await get_request_json()
 
     server_type = req.get("server_type", "")
     if server_type not in VALID_MCP_SERVER_TYPES:
@@ -128,7 +127,7 @@ async def create() -> Response:
 @login_required
 @validate_request("mcp_id")
 async def update() -> Response:
-    req = await request.get_json()
+    req = await get_request_json()
 
     mcp_id = req.get("mcp_id", "")
     e, mcp_server = MCPServerService.get_by_id(mcp_id)
@@ -184,7 +183,7 @@ async def update() -> Response:
 @login_required
 @validate_request("mcp_ids")
 async def rm() -> Response:
-    req = await request.get_json()
+    req = await get_request_json()
     mcp_ids = req.get("mcp_ids", [])
 
     try:
@@ -202,7 +201,7 @@ async def rm() -> Response:
 @login_required
 @validate_request("mcpServers")
 async def import_multiple() -> Response:
-    req = await request.get_json()
+    req = await get_request_json()
     servers = req.get("mcpServers", {})
     if not servers:
         return get_data_error_result(message="No MCP servers provided.")
@@ -269,7 +268,7 @@ async def import_multiple() -> Response:
 @login_required
 @validate_request("mcp_ids")
 async def export_multiple() -> Response:
-    req = await request.get_json()
+    req = await get_request_json()
     mcp_ids = req.get("mcp_ids", [])
 
     if not mcp_ids:
@@ -301,7 +300,7 @@ async def export_multiple() -> Response:
 @login_required
 @validate_request("mcp_ids")
 async def list_tools() -> Response:
-    req = await request.get_json()
+    req = await get_request_json()
     mcp_ids = req.get("mcp_ids", [])
     if not mcp_ids:
         return get_data_error_result(message="No MCP server IDs provided.")
@@ -348,7 +347,7 @@ async def list_tools() -> Response:
 @login_required
 @validate_request("mcp_id", "tool_name", "arguments")
 async def test_tool() -> Response:
-    req = await request.get_json()
+    req = await get_request_json()
     mcp_id = req.get("mcp_id", "")
     if not mcp_id:
         return get_data_error_result(message="No MCP server ID provided.")
@@ -381,7 +380,7 @@ async def test_tool() -> Response:
 @login_required
 @validate_request("mcp_id", "tools")
 async def cache_tool() -> Response:
-    req = await request.get_json()
+    req = await get_request_json()
     mcp_id = req.get("mcp_id", "")
     if not mcp_id:
         return get_data_error_result(message="No MCP server ID provided.")
@@ -404,7 +403,7 @@ async def cache_tool() -> Response:
 @manager.route("/test_mcp", methods=["POST"]) # noqa: F821
 @validate_request("url", "server_type")
 async def test_mcp() -> Response:
-    req = await request.get_json()
+    req = await get_request_json()
 
     url = req.get("url", "")
     if not url:
api/apps/memories_app.py (new file, 185 lines)
@@ -0,0 +1,185 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import logging
+
+from quart import request
+from api.apps import login_required, current_user
+from api.db import TenantPermission
+from api.db.services.memory_service import MemoryService
+from api.db.services.user_service import UserTenantService
+from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result, \
+    not_allowed_parameters
+from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human
+from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT
+from common.constants import MemoryType, RetCode, ForgettingPolicy
+
+
+@manager.route("", methods=["POST"]) # noqa: F821
+@login_required
+@validate_request("name", "memory_type", "embd_id", "llm_id")
+async def create_memory():
+    req = await get_request_json()
+    # check name length
+    name = req["name"]
+    memory_name = name.strip()
+    if len(memory_name) == 0:
+        return get_error_argument_result("Memory name cannot be empty or whitespace.")
+    if len(memory_name) > MEMORY_NAME_LIMIT:
+        return get_error_argument_result(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.")
+    # check memory_type valid
+    memory_type = set(req["memory_type"])
+    invalid_type = memory_type - {e.name.lower() for e in MemoryType}
+    if invalid_type:
+        return get_error_argument_result(f"Memory type '{invalid_type}' is not supported.")
+    memory_type = list(memory_type)
+
+    try:
+        res, memory = MemoryService.create_memory(
+            tenant_id=current_user.id,
+            name=memory_name,
+            memory_type=memory_type,
+            embd_id=req["embd_id"],
+            llm_id=req["llm_id"]
+        )
+
+        if res:
+            return get_json_result(message=True, data=format_ret_data_from_memory(memory))
+
+        else:
+            return get_json_result(message=memory, code=RetCode.SERVER_ERROR)
+
+    except Exception as e:
+        return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
+
+
+@manager.route("/<memory_id>", methods=["PUT"]) # noqa: F821
+@login_required
+@not_allowed_parameters("id", "tenant_id", "memory_type", "storage_type", "embd_id")
+async def update_memory(memory_id):
+    req = await get_request_json()
+    update_dict = {}
+    # check name length
+    if "name" in req:
+        name = req["name"]
+        memory_name = name.strip()
+        if len(memory_name) == 0:
+            return get_error_argument_result("Memory name cannot be empty or whitespace.")
+        if len(memory_name) > MEMORY_NAME_LIMIT:
+            return get_error_argument_result(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.")
+        update_dict["name"] = memory_name
+    # check permissions valid
+    if req.get("permissions"):
+        if req["permissions"] not in [e.value for e in TenantPermission]:
+            return get_error_argument_result(f"Unknown permission '{req['permissions']}'.")
+        update_dict["permissions"] = req["permissions"]
+    if req.get("llm_id"):
+        update_dict["llm_id"] = req["llm_id"]
+    # check memory_size valid
+    if req.get("memory_size"):
+        if not 0 < int(req["memory_size"]) <= MEMORY_SIZE_LIMIT:
+            return get_error_argument_result(f"Memory size should be in range (0, {MEMORY_SIZE_LIMIT}] Bytes.")
+        update_dict["memory_size"] = req["memory_size"]
+    # check forgetting_policy valid
+    if req.get("forgetting_policy"):
+        if req["forgetting_policy"] not in [e.value for e in ForgettingPolicy]:
+            return get_error_argument_result(f"Forgetting policy '{req['forgetting_policy']}' is not supported.")
+        update_dict["forgetting_policy"] = req["forgetting_policy"]
+    # check temperature valid
+    if "temperature" in req:
+        temperature = float(req["temperature"])
+        if not 0 <= temperature <= 1:
+            return get_error_argument_result("Temperature should be in range [0, 1].")
+        update_dict["temperature"] = temperature
+    # allow update to empty fields
+    for field in ["avatar", "description", "system_prompt", "user_prompt"]:
+        if field in req:
+            update_dict[field] = req[field]
+    current_memory = MemoryService.get_by_memory_id(memory_id)
+    if not current_memory:
+        return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
+
+    memory_dict = current_memory.to_dict()
+    memory_dict.update({"memory_type": get_memory_type_human(current_memory.memory_type)})
+    to_update = {}
+    for k, v in update_dict.items():
+        if isinstance(v, list) and set(memory_dict[k]) != set(v):
+            to_update[k] = v
+        elif memory_dict[k] != v:
+            to_update[k] = v
+
+    if not to_update:
+        return get_json_result(message=True, data=memory_dict)
+
+    try:
+        MemoryService.update_memory(memory_id, to_update)
+        updated_memory = MemoryService.get_by_memory_id(memory_id)
+        return get_json_result(message=True, data=format_ret_data_from_memory(updated_memory))
+
+    except Exception as e:
+        logging.error(e)
+        return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
+
+
+@manager.route("/<memory_id>", methods=["DELETE"]) # noqa: F821
+@login_required
+async def delete_memory(memory_id):
+    memory = MemoryService.get_by_memory_id(memory_id)
+    if not memory:
+        return get_json_result(message=True, code=RetCode.NOT_FOUND)
+    try:
+        MemoryService.delete_memory(memory_id)
+        return get_json_result(message=True)
+    except Exception as e:
+        logging.error(e)
+        return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
+
+
+@manager.route("", methods=["GET"]) # noqa: F821
+@login_required
+async def list_memory():
+    args = request.args
+    try:
+        tenant_ids = args.getlist("tenant_id")
+        memory_types = args.getlist("memory_type")
+        storage_type = args.get("storage_type")
+        keywords = args.get("keywords", "")
+        page = int(args.get("page", 1))
+        page_size = int(args.get("page_size", 50))
+        # make filter dict
+        filter_dict = {"memory_type": memory_types, "storage_type": storage_type}
+        if not tenant_ids:
+            # restrict to current user's tenants
+            user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(current_user.id)
+            filter_dict["tenant_id"] = [tenant["tenant_id"] for tenant in user_tenants]
+        else:
+            filter_dict["tenant_id"] = tenant_ids
+
+        memory_list, count = MemoryService.get_by_filter(filter_dict, keywords, page, page_size)
+        [memory.update({"memory_type": get_memory_type_human(memory["memory_type"])}) for memory in memory_list]
+        return get_json_result(message=True, data={"memory_list": memory_list, "total_count": count})
+
+    except Exception as e:
+        logging.error(e)
+        return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
+
+
+@manager.route("/<memory_id>/config", methods=["GET"]) # noqa: F821
+@login_required
+async def get_memory_config(memory_id):
+    memory = MemoryService.get_with_owner_name_by_id(memory_id)
+    if not memory:
+        return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
+    return get_json_result(message=True, data=format_ret_data_from_memory(memory))
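For illustration only: the new memories_app.py create route validates a small JSON payload via @validate_request("name", "memory_type", "embd_id", "llm_id"). A hypothetical client-side sketch follows; the base URL, route prefix, auth header, memory_type value, and model IDs are all assumptions and are not taken from this diff:

import requests

BASE_URL = "http://localhost:9380/v1/memories"      # assumed mount point; the blueprint prefix is not shown in the diff
AUTH = {"Authorization": "Bearer <your-api-token>"}  # placeholder credential; the route actually uses @login_required

payload = {
    "name": "support-notes",                 # trimmed and length-checked against MEMORY_NAME_LIMIT
    "memory_type": ["chat"],                 # must be lower-cased MemoryType member names; "chat" is a guess
    "embd_id": "example-embedding@Provider",  # placeholder model identifiers
    "llm_id": "example-chat@Provider",
}

resp = requests.post(BASE_URL, json=payload, headers=AUTH)
print(resp.status_code, resp.json())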
@@ -14,20 +14,29 @@
 # limitations under the License.
 #
 
+import asyncio
+import base64
+import hashlib
+import hmac
+import ipaddress
 import json
 import logging
 import time
 from typing import Any, cast
 
+import jwt
+
 from agent.canvas import Canvas
 from api.db import CanvasCategory
 from api.db.services.canvas_service import UserCanvasService
+from api.db.services.file_service import FileService
 from api.db.services.user_canvas_version import UserCanvasVersionService
 from common.constants import RetCode
 from common.misc_utils import get_uuid
-from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, token_required
+from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, get_request_json, token_required
 from api.utils.api_utils import get_result
 from quart import request, Response
+from rag.utils.redis_conn import REDIS_CONN
 
 
 @manager.route('/agents', methods=['GET']) # noqa: F821
@@ -53,7 +62,7 @@ def list_agents(tenant_id):
 @manager.route("/agents", methods=["POST"]) # noqa: F821
 @token_required
 async def create_agent(tenant_id: str):
-    req: dict[str, Any] = cast(dict[str, Any], await request.json)
+    req: dict[str, Any] = cast(dict[str, Any], await get_request_json())
     req["user_id"] = tenant_id
 
     if req.get("dsl") is not None:
@@ -90,7 +99,7 @@ async def create_agent(tenant_id: str):
 @manager.route("/agents/<agent_id>", methods=["PUT"]) # noqa: F821
 @token_required
 async def update_agent(tenant_id: str, agent_id: str):
-    req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], (await request.json)).items() if v is not None}
+    req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], (await get_request_json())).items() if v is not None}
     req["user_id"] = tenant_id
 
     if req.get("dsl") is not None:
@@ -132,48 +141,776 @@ def delete_agent(tenant_id: str, agent_id: str):
     UserCanvasService.delete_by_id(agent_id)
     return get_json_result(data=True)
 
-@manager.route('/webhook/<agent_id>', methods=['POST']) # noqa: F821
-@token_required
-async def webhook(tenant_id: str, agent_id: str):
-    req = await request.json
-    if not UserCanvasService.accessible(req["id"], tenant_id):
-        return get_json_result(
-            data=False, message='Only owner of canvas authorized for this operation.',
-            code=RetCode.OPERATING_ERROR)
-
-    e, cvs = UserCanvasService.get_by_id(req["id"])
-    if not e:
-        return get_data_error_result(message="canvas not found.")
-
-    if not isinstance(cvs.dsl, str):
-        cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
+@manager.route("/webhook/<agent_id>", methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"]) # noqa: F821
+@manager.route("/webhook_test/<agent_id>",methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"],) # noqa: F821
+async def webhook(agent_id: str):
+    is_test = request.path.startswith("/api/v1/webhook_test")
+    start_ts = time.time()
+
+    # 1. Fetch canvas by agent_id
+    exists, cvs = UserCanvasService.get_by_id(agent_id)
+    if not exists:
+        return get_data_error_result(code=RetCode.BAD_REQUEST,message="Canvas not found."),RetCode.BAD_REQUEST
+
+    # 2. Check canvas category
     if cvs.canvas_category == CanvasCategory.DataFlow:
-        return get_data_error_result(message="Dataflow can not be triggered by webhook.")
+        return get_data_error_result(code=RetCode.BAD_REQUEST,message="Dataflow can not be triggered by webhook."),RetCode.BAD_REQUEST
+
+    # 3. Load DSL from canvas
+    dsl = getattr(cvs, "dsl", None)
+    if not isinstance(dsl, dict):
+        return get_data_error_result(code=RetCode.BAD_REQUEST,message="Invalid DSL format."),RetCode.BAD_REQUEST
+
+    # 4. Check webhook configuration in DSL
+    components = dsl.get("components", {})
+    for k, _ in components.items():
+        cpn_obj = components[k]["obj"]
+        if cpn_obj["component_name"].lower() == "begin" and cpn_obj["params"]["mode"] == "Webhook":
+            webhook_cfg = cpn_obj["params"]
+
+    if not webhook_cfg:
+        return get_data_error_result(code=RetCode.BAD_REQUEST,message="Webhook not configured for this agent."),RetCode.BAD_REQUEST
+
+    # 5. Validate request method against webhook_cfg.methods
+    allowed_methods = webhook_cfg.get("methods", [])
+    request_method = request.method.upper()
+    if allowed_methods and request_method not in allowed_methods:
+        return get_data_error_result(
+            code=RetCode.BAD_REQUEST,message=f"HTTP method '{request_method}' not allowed for this webhook."
+        ),RetCode.BAD_REQUEST
+
+    # 6. Validate webhook security
+    async def validate_webhook_security(security_cfg: dict):
+        """Validate webhook security rules based on security configuration."""
+
+        if not security_cfg:
+            return  # No security config → allowed by default
+
+        # 1. Validate max body size
+        await _validate_max_body_size(security_cfg)
+
+        # 2. Validate IP whitelist
+        _validate_ip_whitelist(security_cfg)
+
+        # # 3. Validate rate limiting
+        _validate_rate_limit(security_cfg)
+
+        # 4. Validate authentication
+        auth_type = security_cfg.get("auth_type", "none")
+
+        if auth_type == "none":
+            return
+
+        if auth_type == "token":
+            _validate_token_auth(security_cfg)
+
+        elif auth_type == "basic":
+            _validate_basic_auth(security_cfg)
+
+        elif auth_type == "jwt":
+            _validate_jwt_auth(security_cfg)
+
+        else:
+            raise Exception(f"Unsupported auth_type: {auth_type}")
|
|
||||||
|
async def _validate_max_body_size(security_cfg):
|
||||||
|
"""Check request size does not exceed max_body_size."""
|
||||||
|
max_size = security_cfg.get("max_body_size")
|
||||||
|
if not max_size:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Convert "10MB" → bytes
|
||||||
|
units = {"kb": 1024, "mb": 1024**2}
|
||||||
|
size_str = max_size.lower()
|
||||||
|
|
||||||
|
for suffix, factor in units.items():
|
||||||
|
if size_str.endswith(suffix):
|
||||||
|
limit = int(size_str.replace(suffix, "")) * factor
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise Exception("Invalid max_body_size format")
|
||||||
|
MAX_LIMIT = 10 * 1024 * 1024 # 10MB
|
||||||
|
if limit > MAX_LIMIT:
|
||||||
|
raise Exception("max_body_size exceeds maximum allowed size (10MB)")
|
||||||
|
|
||||||
|
content_length = request.content_length or 0
|
||||||
|
if content_length > limit:
|
||||||
|
raise Exception(f"Request body too large: {content_length} > {limit}")
|
||||||
|
|
||||||
|
def _validate_ip_whitelist(security_cfg):
|
||||||
|
"""Allow only IPs listed in ip_whitelist."""
|
||||||
|
whitelist = security_cfg.get("ip_whitelist", [])
|
||||||
|
if not whitelist:
|
||||||
|
return
|
||||||
|
|
||||||
|
client_ip = request.remote_addr
|
||||||
|
|
||||||
|
|
||||||
|
for rule in whitelist:
|
||||||
|
if "/" in rule:
|
||||||
|
# CIDR notation
|
||||||
|
if ipaddress.ip_address(client_ip) in ipaddress.ip_network(rule, strict=False):
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
# Single IP
|
||||||
|
if client_ip == rule:
|
||||||
|
return
|
||||||
|
|
||||||
|
raise Exception(f"IP {client_ip} is not allowed by whitelist")
|
||||||
|
|
||||||
|
def _validate_rate_limit(security_cfg):
|
||||||
|
"""Simple in-memory rate limiting."""
|
||||||
|
rl = security_cfg.get("rate_limit")
|
||||||
|
if not rl:
|
||||||
|
return
|
||||||
|
|
||||||
|
limit = int(rl.get("limit", 60))
|
||||||
|
if limit <= 0:
|
||||||
|
raise Exception("rate_limit.limit must be > 0")
|
||||||
|
per = rl.get("per", "minute")
|
||||||
|
|
||||||
|
window = {
|
||||||
|
"second": 1,
|
||||||
|
"minute": 60,
|
||||||
|
"hour": 3600,
|
||||||
|
"day": 86400,
|
||||||
|
}.get(per)
|
||||||
|
|
||||||
|
if not window:
|
||||||
|
raise Exception(f"Invalid rate_limit.per: {per}")
|
||||||
|
|
||||||
|
capacity = limit
|
||||||
|
rate = limit / window
|
||||||
|
cost = 1
|
||||||
|
|
||||||
|
key = f"rl:tb:{agent_id}"
|
||||||
|
now = time.time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
canvas = Canvas(cvs.dsl, tenant_id, agent_id)
|
res = REDIS_CONN.lua_token_bucket(
|
||||||
|
keys=[key],
|
||||||
|
args=[capacity, rate, now, cost],
|
||||||
|
client=REDIS_CONN.REDIS,
|
||||||
|
)
|
||||||
|
|
||||||
|
allowed = int(res[0])
|
||||||
|
if allowed != 1:
|
||||||
|
raise Exception("Too many requests (rate limit exceeded)")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return get_json_result(
|
raise Exception(f"Rate limit error: {e}")
|
||||||
data=False, message=str(e),
|
|
||||||
code=RetCode.EXCEPTION_ERROR)
|
|
||||||
|
|
||||||
async def sse():
|
def _validate_token_auth(security_cfg):
|
||||||
nonlocal canvas
|
"""Validate header-based token authentication."""
|
||||||
|
token_cfg = security_cfg.get("token",{})
|
||||||
|
header = token_cfg.get("token_header")
|
||||||
|
token_value = token_cfg.get("token_value")
|
||||||
|
|
||||||
|
provided = request.headers.get(header)
|
||||||
|
if provided != token_value:
|
||||||
|
raise Exception("Invalid token authentication")
|
||||||
|
|
||||||
|
def _validate_basic_auth(security_cfg):
|
||||||
|
"""Validate HTTP Basic Auth credentials."""
|
||||||
|
auth_cfg = security_cfg.get("basic_auth", {})
|
||||||
|
username = auth_cfg.get("username")
|
||||||
|
password = auth_cfg.get("password")
|
||||||
|
|
||||||
|
auth = request.authorization
|
||||||
|
if not auth or auth.username != username or auth.password != password:
|
||||||
|
raise Exception("Invalid Basic Auth credentials")
|
||||||
|
|
||||||
|
def _validate_jwt_auth(security_cfg):
|
||||||
|
"""Validate JWT token in Authorization header."""
|
||||||
|
jwt_cfg = security_cfg.get("jwt", {})
|
||||||
|
secret = jwt_cfg.get("secret")
|
||||||
|
if not secret:
|
||||||
|
raise Exception("JWT secret not configured")
|
||||||
|
required_claims = jwt_cfg.get("required_claims", [])
|
||||||
|
|
||||||
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
if not auth_header.startswith("Bearer "):
|
||||||
|
raise Exception("Missing Bearer token")
|
||||||
|
|
||||||
|
token = auth_header[len("Bearer "):].strip()
|
||||||
|
if not token:
|
||||||
|
raise Exception("Empty Bearer token")
|
||||||
|
|
||||||
|
alg = (jwt_cfg.get("algorithm") or "HS256").upper()
|
||||||
|
|
||||||
|
decode_kwargs = {
|
||||||
|
"key": secret,
|
||||||
|
"algorithms": [alg],
|
||||||
|
}
|
||||||
|
options = {}
|
||||||
|
if jwt_cfg.get("audience"):
|
||||||
|
decode_kwargs["audience"] = jwt_cfg["audience"]
|
||||||
|
options["verify_aud"] = True
|
||||||
|
else:
|
||||||
|
options["verify_aud"] = False
|
||||||
|
|
||||||
|
if jwt_cfg.get("issuer"):
|
||||||
|
decode_kwargs["issuer"] = jwt_cfg["issuer"]
|
||||||
|
options["verify_iss"] = True
|
||||||
|
else:
|
||||||
|
options["verify_iss"] = False
|
||||||
try:
|
try:
|
||||||
async for ans in canvas.run(query=req.get("query", ""), files=req.get("files", []), user_id=req.get("user_id", tenant_id), webhook_payload=req):
|
decoded = jwt.decode(
|
||||||
yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
|
token,
|
||||||
|
options=options,
|
||||||
|
**decode_kwargs,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Invalid JWT: {str(e)}")
|
||||||
|
|
||||||
|
raw_required_claims = jwt_cfg.get("required_claims", [])
|
||||||
|
if isinstance(raw_required_claims, str):
|
||||||
|
required_claims = [raw_required_claims]
|
||||||
|
elif isinstance(raw_required_claims, (list, tuple, set)):
|
||||||
|
required_claims = list(raw_required_claims)
|
||||||
|
else:
|
||||||
|
required_claims = []
|
||||||
|
|
||||||
|
required_claims = [
|
||||||
|
c for c in required_claims
|
||||||
|
if isinstance(c, str) and c.strip()
|
||||||
|
]
|
||||||
|
|
||||||
|
RESERVED_CLAIMS = {"exp", "sub", "aud", "iss", "nbf", "iat"}
|
||||||
|
for claim in required_claims:
|
||||||
|
if claim in RESERVED_CLAIMS:
|
||||||
|
raise Exception(f"Reserved JWT claim cannot be required: {claim}")
|
||||||
|
|
||||||
|
for claim in required_claims:
|
||||||
|
if claim not in decoded:
|
||||||
|
raise Exception(f"Missing JWT claim: {claim}")
|
||||||
|
|
||||||
|
return decoded
|
||||||
|
|
||||||
|
try:
|
||||||
|
security_config=webhook_cfg.get("security", {})
|
||||||
|
await validate_webhook_security(security_config)
|
||||||
|
except Exception as e:
|
||||||
|
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)),RetCode.BAD_REQUEST
|
||||||
|
if not isinstance(cvs.dsl, str):
|
||||||
|
dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||||
|
try:
|
||||||
|
canvas = Canvas(dsl, cvs.user_id, agent_id)
|
||||||
|
except Exception as e:
|
||||||
|
resp=get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e))
|
||||||
|
resp.status_code = RetCode.BAD_REQUEST
|
||||||
|
return resp
|
||||||
|
|
||||||
|
# 7. Parse request body
|
||||||
|
async def parse_webhook_request(content_type):
|
||||||
|
"""Parse request based on content-type and return structured data."""
|
||||||
|
|
||||||
|
# 1. Query
|
||||||
|
query_data = {k: v for k, v in request.args.items()}
|
||||||
|
|
||||||
|
# 2. Headers
|
||||||
|
header_data = {k: v for k, v in request.headers.items()}
|
||||||
|
|
||||||
|
# 3. Body
|
||||||
|
ctype = request.headers.get("Content-Type", "").split(";")[0].strip()
|
||||||
|
if ctype and ctype != content_type:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid Content-Type: expect '{content_type}', got '{ctype}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
body_data: dict = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
if ctype == "application/json":
|
||||||
|
body_data = await request.get_json() or {}
|
||||||
|
|
||||||
|
elif ctype == "multipart/form-data":
|
||||||
|
nonlocal canvas
|
||||||
|
form = await request.form
|
||||||
|
files = await request.files
|
||||||
|
|
||||||
|
body_data = {}
|
||||||
|
|
||||||
|
for key, value in form.items():
|
||||||
|
body_data[key] = value
|
||||||
|
|
||||||
|
if len(files) > 10:
|
||||||
|
raise Exception("Too many uploaded files")
|
||||||
|
for key, file in files.items():
|
||||||
|
desc = FileService.upload_info(
|
||||||
|
cvs.user_id, # user
|
||||||
|
file, # FileStorage
|
||||||
|
None # url (None for webhook)
|
||||||
|
)
|
||||||
|
file_parsed= await canvas.get_files_async([desc])
|
||||||
|
body_data[key] = file_parsed
|
||||||
|
|
||||||
|
elif ctype == "application/x-www-form-urlencoded":
|
||||||
|
form = await request.form
|
||||||
|
body_data = dict(form)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# text/plain / octet-stream / empty / unknown
|
||||||
|
raw = await request.get_data()
|
||||||
|
if raw:
|
||||||
|
try:
|
||||||
|
body_data = json.loads(raw.decode("utf-8"))
|
||||||
|
except Exception:
|
||||||
|
body_data = {}
|
||||||
|
else:
|
||||||
|
body_data = {}
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
body_data = {}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"query": query_data,
|
||||||
|
"headers": header_data,
|
||||||
|
"body": body_data,
|
||||||
|
"content_type": ctype,
|
||||||
|
}
|
||||||
|
|
||||||
|
def extract_by_schema(data, schema, name="section"):
|
||||||
|
"""
|
||||||
|
Extract only fields defined in schema.
|
||||||
|
Required fields must exist.
|
||||||
|
Optional fields default to type-based default values.
|
||||||
|
Type validation included.
|
||||||
|
"""
|
||||||
|
props = schema.get("properties", {})
|
||||||
|
required = schema.get("required", [])
|
||||||
|
|
||||||
|
extracted = {}
|
||||||
|
|
||||||
|
for field, field_schema in props.items():
|
||||||
|
field_type = field_schema.get("type")
|
||||||
|
|
||||||
|
# 1. Required field missing
|
||||||
|
if field in required and field not in data:
|
||||||
|
raise Exception(f"{name} missing required field: {field}")
|
||||||
|
|
||||||
|
# 2. Optional → default value
|
||||||
|
if field not in data:
|
||||||
|
extracted[field] = default_for_type(field_type)
|
||||||
|
continue
|
||||||
|
|
||||||
|
raw_value = data[field]
|
||||||
|
|
||||||
|
# 3. Auto convert value
|
||||||
|
try:
|
||||||
|
value = auto_cast_value(raw_value, field_type)
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"{name}.{field} auto-cast failed: {str(e)}")
|
||||||
|
|
||||||
|
# 4. Type validation
|
||||||
|
if not validate_type(value, field_type):
|
||||||
|
raise Exception(
|
||||||
|
f"{name}.{field} type mismatch: expected {field_type}, got {type(value).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
extracted[field] = value
|
||||||
|
|
||||||
|
return extracted
|
||||||
|
|
||||||
|
|
||||||
|
def default_for_type(t):
|
||||||
|
"""Return default value for the given schema type."""
|
||||||
|
if t == "file":
|
||||||
|
return []
|
||||||
|
if t == "object":
|
||||||
|
return {}
|
||||||
|
if t == "boolean":
|
||||||
|
return False
|
||||||
|
if t == "number":
|
||||||
|
return 0
|
||||||
|
if t == "string":
|
||||||
|
return ""
|
||||||
|
if t and t.startswith("array"):
|
||||||
|
return []
|
||||||
|
if t == "null":
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
def auto_cast_value(value, expected_type):
|
||||||
|
"""Convert string values into schema type when possible."""
|
||||||
|
|
||||||
|
# Non-string values already good
|
||||||
|
if not isinstance(value, str):
|
||||||
|
return value
|
||||||
|
|
||||||
|
v = value.strip()
|
||||||
|
|
||||||
|
# Boolean
|
||||||
|
if expected_type == "boolean":
|
||||||
|
if v.lower() in ["true", "1"]:
|
||||||
|
return True
|
||||||
|
if v.lower() in ["false", "0"]:
|
||||||
|
return False
|
||||||
|
raise Exception(f"Cannot convert '{value}' to boolean")
|
||||||
|
|
||||||
|
# Number
|
||||||
|
if expected_type == "number":
|
||||||
|
# integer
|
||||||
|
if v.isdigit() or (v.startswith("-") and v[1:].isdigit()):
|
||||||
|
return int(v)
|
||||||
|
|
||||||
|
# float
|
||||||
|
try:
|
||||||
|
return float(v)
|
||||||
|
except Exception:
|
||||||
|
raise Exception(f"Cannot convert '{value}' to number")
|
||||||
|
|
||||||
|
# Object
|
||||||
|
if expected_type == "object":
|
||||||
|
try:
|
||||||
|
parsed = json.loads(v)
|
||||||
|
if isinstance(parsed, dict):
|
||||||
|
return parsed
|
||||||
|
else:
|
||||||
|
raise Exception("JSON is not an object")
|
||||||
|
except Exception:
|
||||||
|
raise Exception(f"Cannot convert '{value}' to object")
|
||||||
|
|
||||||
|
# Array <T>
|
||||||
|
if expected_type.startswith("array"):
|
||||||
|
try:
|
||||||
|
parsed = json.loads(v)
|
||||||
|
if isinstance(parsed, list):
|
||||||
|
return parsed
|
||||||
|
else:
|
||||||
|
raise Exception("JSON is not an array")
|
||||||
|
except Exception:
|
||||||
|
raise Exception(f"Cannot convert '{value}' to array")
|
||||||
|
|
||||||
|
# String (accept original)
|
||||||
|
if expected_type == "string":
|
||||||
|
return value
|
||||||
|
|
||||||
|
# File
|
||||||
|
if expected_type == "file":
|
||||||
|
return value
|
||||||
|
# Default: do nothing
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def validate_type(value, t):
|
||||||
|
"""Validate value type against schema type t."""
|
||||||
|
if t == "file":
|
||||||
|
return isinstance(value, list)
|
||||||
|
|
||||||
|
if t == "string":
|
||||||
|
return isinstance(value, str)
|
||||||
|
|
||||||
|
if t == "number":
|
||||||
|
return isinstance(value, (int, float))
|
||||||
|
|
||||||
|
if t == "boolean":
|
||||||
|
return isinstance(value, bool)
|
||||||
|
|
||||||
|
if t == "object":
|
||||||
|
return isinstance(value, dict)
|
||||||
|
|
||||||
|
# array<string> / array<number> / array<object>
|
||||||
|
if t.startswith("array"):
|
||||||
|
if not isinstance(value, list):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if "<" in t and ">" in t:
|
||||||
|
inner = t[t.find("<") + 1 : t.find(">")]
|
||||||
|
|
||||||
|
# Check each element type
|
||||||
|
for item in value:
|
||||||
|
if not validate_type(item, inner):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
return True
|
||||||
|
parsed = await parse_webhook_request(webhook_cfg.get("content_types"))
|
||||||
|
SCHEMA = webhook_cfg.get("schema", {"query": {}, "headers": {}, "body": {}})
|
||||||
|
|
||||||
|
# Extract strictly by schema
|
||||||
|
try:
|
||||||
|
query_clean = extract_by_schema(parsed["query"], SCHEMA.get("query", {}), name="query")
|
||||||
|
header_clean = extract_by_schema(parsed["headers"], SCHEMA.get("headers", {}), name="headers")
|
||||||
|
body_clean = extract_by_schema(parsed["body"], SCHEMA.get("body", {}), name="body")
|
||||||
|
except Exception as e:
|
||||||
|
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)),RetCode.BAD_REQUEST
|
||||||
|
|
||||||
|
clean_request = {
|
||||||
|
"query": query_clean,
|
||||||
|
"headers": header_clean,
|
||||||
|
"body": body_clean,
|
||||||
|
"input": parsed
|
||||||
|
}
|
||||||
|
|
||||||
|
execution_mode = webhook_cfg.get("execution_mode", "Immediately")
|
||||||
|
response_cfg = webhook_cfg.get("response", {})
|
||||||
|
|
||||||
|
def append_webhook_trace(agent_id: str, start_ts: float,event: dict, ttl=600):
|
||||||
|
key = f"webhook-trace-{agent_id}-logs"
|
||||||
|
|
||||||
|
raw = REDIS_CONN.get(key)
|
||||||
|
obj = json.loads(raw) if raw else {"webhooks": {}}
|
||||||
|
|
||||||
|
ws = obj["webhooks"].setdefault(
|
||||||
|
str(start_ts),
|
||||||
|
{"start_ts": start_ts, "events": []}
|
||||||
|
)
|
||||||
|
|
||||||
|
ws["events"].append({
|
||||||
|
"ts": time.time(),
|
||||||
|
**event
|
||||||
|
})
|
||||||
|
|
||||||
|
REDIS_CONN.set_obj(key, obj, ttl)
|
||||||
|
|
||||||
|
if execution_mode == "Immediately":
|
||||||
|
status = response_cfg.get("status", 200)
|
||||||
|
try:
|
||||||
|
status = int(status)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(f"Invalid response status code: {status}")),RetCode.BAD_REQUEST
|
||||||
|
|
||||||
|
if not (200 <= status <= 399):
|
||||||
|
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(f"Invalid response status code: {status}, must be between 200 and 399")),RetCode.BAD_REQUEST
|
||||||
|
|
||||||
|
body_tpl = response_cfg.get("body_template", "")
|
||||||
|
|
||||||
|
def parse_body(body: str):
|
||||||
|
if not body:
|
||||||
|
return None, "application/json"
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed = json.loads(body)
|
||||||
|
return parsed, "application/json"
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
return body, "text/plain"
|
||||||
|
|
||||||
|
|
||||||
|
body, content_type = parse_body(body_tpl)
|
||||||
|
resp = Response(
|
||||||
|
json.dumps(body, ensure_ascii=False) if content_type == "application/json" else body,
|
||||||
|
status=status,
|
||||||
|
content_type=content_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def background_run():
|
||||||
|
try:
|
||||||
|
async for ans in canvas.run(
|
||||||
|
query="",
|
||||||
|
user_id=cvs.user_id,
|
||||||
|
webhook_payload=clean_request
|
||||||
|
):
|
||||||
|
if is_test:
|
||||||
|
append_webhook_trace(agent_id, start_ts, ans)
|
||||||
|
|
||||||
|
if is_test:
|
||||||
|
append_webhook_trace(
|
||||||
|
agent_id,
|
||||||
|
start_ts,
|
||||||
|
{
|
||||||
|
"event": "finished",
|
||||||
|
"elapsed_time": time.time() - start_ts,
|
||||||
|
"success": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
cvs.dsl = json.loads(str(canvas))
|
cvs.dsl = json.loads(str(canvas))
|
||||||
UserCanvasService.update_by_id(req["id"], cvs.to_dict())
|
UserCanvasService.update_by_id(cvs.user_id, cvs.to_dict())
|
||||||
except Exception as e:
|
|
||||||
logging.exception(e)
|
|
||||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": False}, ensure_ascii=False) + "\n\n"
|
|
||||||
|
|
||||||
resp = Response(sse(), mimetype="text/event-stream")
|
except Exception as e:
|
||||||
resp.headers.add_header("Cache-control", "no-cache")
|
logging.exception("Webhook background run failed")
|
||||||
resp.headers.add_header("Connection", "keep-alive")
|
if is_test:
|
||||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
try:
|
||||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
append_webhook_trace(
|
||||||
|
agent_id,
|
||||||
|
start_ts,
|
||||||
|
{
|
||||||
|
"event": "error",
|
||||||
|
"message": str(e),
|
||||||
|
"error_type": type(e).__name__,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
append_webhook_trace(
|
||||||
|
agent_id,
|
||||||
|
start_ts,
|
||||||
|
{
|
||||||
|
"event": "finished",
|
||||||
|
"elapsed_time": time.time() - start_ts,
|
||||||
|
"success": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logging.exception("Failed to append webhook trace")
|
||||||
|
|
||||||
|
asyncio.create_task(background_run())
|
||||||
return resp
|
return resp
|
||||||
|
+    else:
+        async def sse():
+            nonlocal canvas
+            contents: list[str] = []
+
+            try:
+                async for ans in canvas.run(
+                    query="",
+                    user_id=cvs.user_id,
+                    webhook_payload=clean_request,
+                ):
+                    if ans["event"] == "message":
+                        content = ans["data"]["content"]
+                        if ans["data"].get("start_to_think", False):
+                            content = "<think>"
+                        elif ans["data"].get("end_to_think", False):
+                            content = "</think>"
+                        if content:
+                            contents.append(content)
+                    if is_test:
+                        append_webhook_trace(
+                            agent_id,
+                            start_ts,
+                            ans
+                        )
+                if is_test:
+                    append_webhook_trace(
+                        agent_id,
+                        start_ts,
+                        {
+                            "event": "finished",
+                            "elapsed_time": time.time() - start_ts,
+                            "success": True,
+                        }
+                    )
+                final_content = "".join(contents)
+                yield json.dumps(final_content, ensure_ascii=False)
+
+            except Exception as e:
+                if is_test:
+                    append_webhook_trace(
+                        agent_id,
+                        start_ts,
+                        {
+                            "event": "error",
+                            "message": str(e),
+                            "error_type": type(e).__name__,
+                        }
+                    )
+                    append_webhook_trace(
+                        agent_id,
+                        start_ts,
+                        {
+                            "event": "finished",
+                            "elapsed_time": time.time() - start_ts,
+                            "success": False,
+                        }
+                    )
+                yield json.dumps({"code": 500, "message": str(e)}, ensure_ascii=False)
+
+        resp = Response(sse(), mimetype="application/json")
+        return resp
+
+
+@manager.route("/webhook_trace/<agent_id>", methods=["GET"]) # noqa: F821
+async def webhook_trace(agent_id: str):
+    def encode_webhook_id(start_ts: str) -> str:
+        WEBHOOK_ID_SECRET = "webhook_id_secret"
+        sig = hmac.new(
+            WEBHOOK_ID_SECRET.encode("utf-8"),
+            start_ts.encode("utf-8"),
+            hashlib.sha256,
+        ).digest()
+        return base64.urlsafe_b64encode(sig).decode("utf-8").rstrip("=")
+
+    def decode_webhook_id(enc_id: str, webhooks: dict) -> str | None:
+        for ts in webhooks.keys():
+            if encode_webhook_id(ts) == enc_id:
+                return ts
+        return None
+    since_ts = request.args.get("since_ts", type=float)
+    webhook_id = request.args.get("webhook_id")
+
+    key = f"webhook-trace-{agent_id}-logs"
+    raw = REDIS_CONN.get(key)
+
+    if since_ts is None:
+        now = time.time()
+        return get_json_result(
+            data={
+                "webhook_id": None,
+                "events": [],
+                "next_since_ts": now,
+                "finished": False,
+            }
+        )
+
+    if not raw:
+        return get_json_result(
+            data={
+                "webhook_id": None,
+                "events": [],
+                "next_since_ts": since_ts,
+                "finished": False,
+            }
+        )
+
+    obj = json.loads(raw)
+    webhooks = obj.get("webhooks", {})
+
+    if webhook_id is None:
+        candidates = [
+            float(k) for k in webhooks.keys() if float(k) > since_ts
+        ]
+
+        if not candidates:
+            return get_json_result(
+                data={
+                    "webhook_id": None,
+                    "events": [],
+                    "next_since_ts": since_ts,
+                    "finished": False,
+                }
+            )
+
+        start_ts = min(candidates)
+        real_id = str(start_ts)
+        webhook_id = encode_webhook_id(real_id)
+
+        return get_json_result(
+            data={
+                "webhook_id": webhook_id,
+                "events": [],
+                "next_since_ts": start_ts,
+                "finished": False,
+            }
+        )
+
+    real_id = decode_webhook_id(webhook_id, webhooks)
+
+    if not real_id:
+        return get_json_result(
+            data={
+                "webhook_id": webhook_id,
+                "events": [],
+                "next_since_ts": since_ts,
+                "finished": True,
+            }
+        )
+
+    ws = webhooks.get(str(real_id))
+    events = ws.get("events", [])
+    new_events = [e for e in events if e.get("ts", 0) > since_ts]
+
+    next_ts = since_ts
+    for e in new_events:
+        next_ts = max(next_ts, e["ts"])
+
+    finished = any(e.get("event") == "finished" for e in new_events)
+
+    return get_json_result(
+        data={
+            "webhook_id": webhook_id,
+            "events": new_events,
+            "next_since_ts": next_ts,
+            "finished": finished,
+        }
+    )
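(Illustrative sketch, not part of the diff: one way a client might poll the /webhook_trace/<agent_id> endpoint reconstructed above. The base URL, API-key header, and the {"data": ...} response envelope are assumptions about the surrounding API, not something this diff guarantees.)

# Hypothetical polling loop against the webhook_trace endpoint shown above.
import time
import requests

def poll_webhook_trace(base_url: str, agent_id: str, api_key: str) -> None:
    since_ts, webhook_id, finished = 0.0, None, False
    while not finished:
        params = {"since_ts": since_ts}
        if webhook_id:
            params["webhook_id"] = webhook_id
        r = requests.get(
            f"{base_url}/webhook_trace/{agent_id}",  # assumed route prefix
            params=params,
            headers={"Authorization": f"Bearer {api_key}"},  # assumed auth scheme
        )
        data = r.json().get("data", {})          # assumed response envelope
        webhook_id = data.get("webhook_id") or webhook_id
        since_ts = data.get("next_since_ts", since_ts)
        finished = data.get("finished", False)
        for event in data.get("events", []):     # each event carries "event", "ts", ...
            print(event)
        time.sleep(1)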
@@ -21,13 +21,13 @@ from api.db.services.tenant_llm_service import TenantLLMService
 from api.db.services.user_service import TenantService
 from common.misc_utils import get_uuid
 from common.constants import RetCode, StatusEnum
-from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required, request_json
+from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required, get_request_json


 @manager.route("/chats", methods=["POST"]) # noqa: F821
 @token_required
 async def create(tenant_id):
-    req = await request_json()
+    req = await get_request_json()
     ids = [i for i in req.get("dataset_ids", []) if i]
     for kb_id in ids:
         kbs = KnowledgebaseService.accessible(kb_id=kb_id, user_id=tenant_id)
@@ -92,7 +92,7 @@ async def create(tenant_id):
     req["tenant_id"] = tenant_id
     # prompt more parameter
     default_prompt = {
-        "system": """You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence "The answer you are looking for is not found in the knowledge base!" Answers need to consider chat history.
+        "system": """You are an intelligent assistant. Please summarize the content of the dataset to answer the question. Please list the data in the dataset and answer in detail. When all dataset content is irrelevant to the question, your answer must include the sentence "The answer you are looking for is not found in the dataset!" Answers need to consider chat history.
 Here is the knowledge base:
 {knowledge}
 The above is the knowledge base.""",
@@ -146,7 +146,7 @@ async def create(tenant_id):
 async def update(tenant_id, chat_id):
     if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
         return get_error_data_result(message="You do not own the chat")
-    req = await request_json()
+    req = await get_request_json()
     ids = req.get("dataset_ids", [])
     if "show_quotation" in req:
         req["do_refer"] = req.pop("show_quotation")
@@ -174,7 +174,9 @@ async def update(tenant_id, chat_id):
         req["llm_id"] = llm.pop("model_name")
         if req.get("llm_id") is not None:
             llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["llm_id"])
-            if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type="chat"):
+            model_type = llm.pop("model_type")
+            model_type = model_type if model_type in ["chat", "image2text"] else "chat"
+            if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type=model_type):
                 return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist")
         req["llm_setting"] = req.pop("llm")
     e, tenant = TenantService.get_by_id(tenant_id)
@@ -229,7 +231,7 @@ async def update(tenant_id, chat_id):
 async def delete_chats(tenant_id):
     errors = []
     success_count = 0
-    req = await request_json()
+    req = await get_request_json()
     if not req:
         ids = None
     else:
@@ -250,7 +252,6 @@ async def delete_chats(tenant_id):
             continue
         temp_dict = {"status": StatusEnum.INVALID.value}
         success_count += DialogService.update_by_id(id, temp_dict)
-        print(success_count, "$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$", flush=True)

     if errors:
         if success_count > 0:
@@ -15,14 +15,14 @@
 #
 import logging

-from quart import request, jsonify
+from quart import jsonify

 from api.db.services.document_service import DocumentService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
-from api.utils.api_utils import validate_request, build_error_result, apikey_required
+from common.metadata_utils import meta_filter, convert_conditions
+from api.utils.api_utils import apikey_required, build_error_result, get_request_json, validate_request
 from rag.app.tag import label_question
-from api.db.services.dialog_service import meta_filter, convert_conditions
 from common.constants import RetCode, LLMType
 from common import settings

@@ -113,7 +113,7 @@ async def retrieval(tenant_id):
       404:
         description: Knowledge base or document not found
     """
-    req = await request.json
+    req = await get_request_json()
     question = req["query"]
     kb_id = req["knowledge_id"]
     use_kg = req.get("use_kg", False)
@@ -14,6 +14,7 @@
 # limitations under the License.
 #
 import datetime
+import json
 import logging
 import pathlib
 import re
@@ -33,10 +34,10 @@ from api.db.services.file_service import FileService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
 from api.db.services.tenant_llm_service import TenantLLMService
-from api.db.services.task_service import TaskService, queue_tasks
-from api.db.services.dialog_service import meta_filter, convert_conditions
+from api.db.services.task_service import TaskService, queue_tasks, cancel_all_task_of
+from common.metadata_utils import meta_filter, convert_conditions
 from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required, \
-    request_json
+    get_request_json
 from rag.app.qa import beAdoc, rmPrefix
 from rag.app.tag import label_question
 from rag.nlp import rag_tokenizer, search
@@ -231,12 +232,12 @@ async def update_doc(tenant_id, dataset_id, document_id):
           schema:
             type: object
     """
-    req = await request_json()
+    req = await get_request_json()
     if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
         return get_error_data_result(message="You don't own the dataset.")
     e, kb = KnowledgebaseService.get_by_id(dataset_id)
     if not e:
-        return get_error_data_result(message="Can't find this knowledgebase!")
+        return get_error_data_result(message="Can't find this dataset!")
     doc = DocumentService.query(kb_id=dataset_id, id=document_id)
     if not doc:
         return get_error_data_result(message="The dataset doesn't own the document.")
@@ -321,9 +322,7 @@ async def update_doc(tenant_id, dataset_id, document_id):
         try:
             if not DocumentService.update_by_id(doc.id, {"status": str(status)}):
                 return get_error_data_result(message="Database error (Document update)!")
-
             settings.docStoreConn.update({"doc_id": doc.id}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
-            return get_result(data=True)
         except Exception as e:
             return server_error_response(e)

@@ -350,12 +349,10 @@ async def update_doc(tenant_id, dataset_id, document_id):
     }
     renamed_doc = {}
     for key, value in doc.to_dict().items():
-        if key == "run":
-            renamed_doc["run"] = run_mapping.get(str(value))
         new_key = key_mapping.get(key, key)
         renamed_doc[new_key] = value
         if key == "run":
-            renamed_doc["run"] = run_mapping.get(value)
+            renamed_doc["run"] = run_mapping.get(str(value))

     return get_result(data=renamed_doc)

@@ -555,13 +552,29 @@ def list_docs(dataset_id, tenant_id):
     run_status = q.getlist("run")
     create_time_from = int(q.get("create_time_from", 0))
     create_time_to = int(q.get("create_time_to", 0))
+    metadata_condition_raw = q.get("metadata_condition")
+    metadata_condition = {}
+    if metadata_condition_raw:
+        try:
+            metadata_condition = json.loads(metadata_condition_raw)
+        except Exception:
+            return get_error_data_result(message="metadata_condition must be valid JSON.")
+    if metadata_condition and not isinstance(metadata_condition, dict):
+        return get_error_data_result(message="metadata_condition must be an object.")

-    # map run status (accept text or numeric) - align with API parameter
+    # map run status (text or numeric) - align with API parameter
     run_status_text_to_numeric = {"UNSTART": "0", "RUNNING": "1", "CANCEL": "2", "DONE": "3", "FAIL": "4"}
     run_status_converted = [run_status_text_to_numeric.get(v, v) for v in run_status]

+    doc_ids_filter = None
+    if metadata_condition:
+        metas = DocumentService.get_flatted_meta_by_kbs([dataset_id])
+        doc_ids_filter = meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))
+        if metadata_condition.get("conditions") and not doc_ids_filter:
+            return get_result(data={"total": 0, "docs": []})
+
     docs, total = DocumentService.get_list(
-        dataset_id, page, page_size, orderby, desc, keywords, document_id, name, suffix, run_status_converted
+        dataset_id, page, page_size, orderby, desc, keywords, document_id, name, suffix, run_status_converted, doc_ids_filter
     )

     # time range filter (0 means no bound)
@@ -590,6 +603,70 @@ def list_docs(dataset_id, tenant_id):

     return get_result(data={"total": total, "docs": output_docs})

+
+@manager.route("/datasets/<dataset_id>/metadata/summary", methods=["GET"]) # noqa: F821
+@token_required
+def metadata_summary(dataset_id, tenant_id):
+    if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
+        return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
+
+    try:
+        summary = DocumentService.get_metadata_summary(dataset_id)
+        return get_result(data={"summary": summary})
+    except Exception as e:
+        return server_error_response(e)
+
+
+@manager.route("/datasets/<dataset_id>/metadata/update", methods=["POST"]) # noqa: F821
+@token_required
+async def metadata_batch_update(dataset_id, tenant_id):
+    if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
+        return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
+
+    req = await get_request_json()
+    selector = req.get("selector", {}) or {}
+    updates = req.get("updates", []) or []
+    deletes = req.get("deletes", []) or []
+
+    if not isinstance(selector, dict):
+        return get_error_data_result(message="selector must be an object.")
+    if not isinstance(updates, list) or not isinstance(deletes, list):
+        return get_error_data_result(message="updates and deletes must be lists.")
+
+    metadata_condition = selector.get("metadata_condition", {}) or {}
+    if metadata_condition and not isinstance(metadata_condition, dict):
+        return get_error_data_result(message="metadata_condition must be an object.")
+
+    document_ids = selector.get("document_ids", []) or []
+    if document_ids and not isinstance(document_ids, list):
+        return get_error_data_result(message="document_ids must be a list.")
+
+    for upd in updates:
+        if not isinstance(upd, dict) or not upd.get("key") or "value" not in upd:
+            return get_error_data_result(message="Each update requires key and value.")
+    for d in deletes:
+        if not isinstance(d, dict) or not d.get("key"):
+            return get_error_data_result(message="Each delete requires key.")
+
+    kb_doc_ids = KnowledgebaseService.list_documents_by_ids([dataset_id])
+    target_doc_ids = set(kb_doc_ids)
+    if document_ids:
+        invalid_ids = set(document_ids) - set(kb_doc_ids)
+        if invalid_ids:
+            return get_error_data_result(message=f"These documents do not belong to dataset {dataset_id}: {', '.join(invalid_ids)}")
+        target_doc_ids = set(document_ids)
+
+    if metadata_condition:
+        metas = DocumentService.get_flatted_meta_by_kbs([dataset_id])
+        filtered_ids = set(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
+        target_doc_ids = target_doc_ids & filtered_ids
+        if metadata_condition.get("conditions") and not target_doc_ids:
+            return get_result(data={"updated": 0, "matched_docs": 0})
+
+    target_doc_ids = list(target_doc_ids)
+    updated = DocumentService.batch_update_metadata(dataset_id, target_doc_ids, updates, deletes)
+    return get_result(data={"updated": updated, "matched_docs": len(target_doc_ids)})
+
+
 @manager.route("/datasets/<dataset_id>/documents", methods=["DELETE"]) # noqa: F821
 @token_required
 async def delete(tenant_id, dataset_id):
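(Illustrative sketch, not part of the diff: a possible request body for the metadata batch-update endpoint reconstructed above. The top-level keys come from the handler itself; the shape of each condition is handled by convert_conditions and is an assumption here, as are all concrete values.)

# Hypothetical payload for POST /datasets/<dataset_id>/metadata/update.
example_payload = {
    "selector": {
        "document_ids": [],              # optional: restrict the update to specific documents
        "metadata_condition": {          # optional: filter documents by existing metadata
            "logic": "and",
            "conditions": [
                # exact condition schema is an assumption; see convert_conditions
                {"name": "author", "comparison_operator": "is", "value": "alice"},
            ],
        },
    },
    "updates": [{"key": "reviewed", "value": "true"}],  # each update needs key and value
    "deletes": [{"key": "draft"}],                      # each delete needs key
}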
@@ -631,7 +708,7 @@ async def delete(tenant_id, dataset_id):
     """
     if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
         return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
-    req = await request_json()
+    req = await get_request_json()
     if not req:
         doc_ids = None
     else:
@@ -741,7 +818,7 @@ async def parse(tenant_id, dataset_id):
     """
     if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
         return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
-    req = await request_json()
+    req = await get_request_json()
     if not req.get("document_ids"):
         return get_error_data_result("`document_ids` is required")
     doc_list = req.get("document_ids")
@@ -824,7 +901,7 @@ async def stop_parsing(tenant_id, dataset_id):
     """
     if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
         return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
-    req = await request_json()
+    req = await get_request_json()

     if not req.get("document_ids"):
         return get_error_data_result("`document_ids` is required")
@@ -839,6 +916,8 @@ async def stop_parsing(tenant_id, dataset_id):
             return get_error_data_result(message=f"You don't own the document {id}.")
         if int(doc[0].progress) == 1 or doc[0].progress == 0:
             return get_error_data_result("Can't stop parsing document with progress at 0 or 1")
+        # Send cancellation signal via Redis to stop background task
+        cancel_all_task_of(id)
         info = {"run": "2", "progress": 0, "chunk_num": 0}
         DocumentService.update_by_id(id, info)
         settings.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id)
@@ -892,7 +971,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
           type: string
           required: false
           default: ""
-          description: Chunk Id.
+          description: Chunk id.
         - in: header
           name: Authorization
           type: string
@@ -1096,7 +1175,7 @@ async def add_chunk(tenant_id, dataset_id, document_id):
     if not doc:
         return get_error_data_result(message=f"You don't own the document {document_id}.")
     doc = doc[0]
-    req = await request_json()
+    req = await get_request_json()
     if not str(req.get("content", "")).strip():
         return get_error_data_result(message="`content` is required")
     if "important_keywords" in req:
@@ -1202,7 +1281,7 @@ async def rm_chunk(tenant_id, dataset_id, document_id):
     docs = DocumentService.get_by_ids([document_id])
     if not docs:
         raise LookupError(f"Can't find the document with ID {document_id}!")
-    req = await request_json()
+    req = await get_request_json()
     condition = {"doc_id": document_id}
     if "chunk_ids" in req:
         unique_chunk_ids, duplicate_messages = check_duplicate_ids(req["chunk_ids"], "chunk")
@@ -1288,7 +1367,7 @@ async def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
     if not doc:
         return get_error_data_result(message=f"You don't own the document {document_id}.")
     doc = doc[0]
-    req = await request_json()
+    req = await get_request_json()
     if "content" in req and req["content"] is not None:
         content = req["content"]
     else:
@@ -1411,7 +1490,7 @@ async def retrieval_test(tenant_id):
               format: float
               description: Similarity score.
     """
-    req = await request_json()
+    req = await get_request_json()
     if not req.get("dataset_ids"):
         return get_error_data_result("`dataset_ids` is required.")
     kb_ids = req["dataset_ids"]
@@ -1446,6 +1525,9 @@ async def retrieval_test(tenant_id):
     metadata_condition = req.get("metadata_condition", {}) or {}
     metas = DocumentService.get_meta_by_kbs(kb_ids)
     doc_ids = meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))
+    # If metadata_condition has conditions but no docs match, return empty result
+    if not doc_ids and metadata_condition.get("conditions"):
+        return get_result(data={"total": 0, "chunks": [], "doc_aggs": {}})
     if metadata_condition and not doc_ids:
         doc_ids = ["-999"]
     similarity_threshold = float(req.get("similarity_threshold", 0.2))
@@ -1467,11 +1549,11 @@ async def retrieval_test(tenant_id):
         rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK, llm_name=req["rerank_id"])

     if langs:
-        question = cross_languages(kb.tenant_id, None, question, langs)
+        question = await cross_languages(kb.tenant_id, None, question, langs)

     if req.get("keyword", False):
         chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
-        question += keyword_extraction(chat_mdl, question)
+        question += await keyword_extraction(chat_mdl, question)

     ranks = settings.retriever.retrieval(
         question,
@@ -14,7 +14,7 @@
 # limitations under the License.
 #

+import asyncio
 import pathlib
 import re
 from quart import request, make_response
@@ -23,14 +23,15 @@ from pathlib import Path
 from api.db.services.document_service import DocumentService
 from api.db.services.file2document_service import File2DocumentService
 from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.utils.api_utils import server_error_response, token_required
+from api.utils.api_utils import get_json_result, get_request_json, server_error_response, token_required
 from common.misc_utils import get_uuid
 from api.db import FileType
 from api.db.services import duplicate_name
 from api.db.services.file_service import FileService
-from api.utils.api_utils import get_json_result
 from api.utils.file_utils import filename_type
+from api.utils.web_utils import CONTENT_TYPE_MAP
 from common import settings
+from common.constants import RetCode


 @manager.route('/file/upload', methods=['POST']) # noqa: F821
@@ -40,7 +41,7 @@ async def upload(tenant_id):
     Upload a file to the system.
     ---
     tags:
-      - File Management
+      - File
     security:
      - ApiKeyAuth: []
     parameters:
@@ -86,19 +87,19 @@ async def upload(tenant_id):
         pf_id = root_folder["id"]

     if 'file' not in files:
-        return get_json_result(data=False, message='No file part!', code=400)
+        return get_json_result(data=False, message='No file part!', code=RetCode.BAD_REQUEST)
     file_objs = files.getlist('file')

     for file_obj in file_objs:
         if file_obj.filename == '':
-            return get_json_result(data=False, message='No selected file!', code=400)
+            return get_json_result(data=False, message='No selected file!', code=RetCode.BAD_REQUEST)

     file_res = []

     try:
         e, pf_folder = FileService.get_by_id(pf_id)
         if not e:
-            return get_json_result(data=False, message="Can't find this folder!", code=404)
+            return get_json_result(data=False, message="Can't find this folder!", code=RetCode.NOT_FOUND)

         for file_obj in file_objs:
             # Handle file path
@@ -114,13 +115,13 @@ async def upload(tenant_id):
             if file_len != len_id_list:
                 e, file = FileService.get_by_id(file_id_list[len_id_list - 1])
                 if not e:
-                    return get_json_result(data=False, message="Folder not found!", code=404)
+                    return get_json_result(data=False, message="Folder not found!", code=RetCode.NOT_FOUND)
                 last_folder = FileService.create_folder(file, file_id_list[len_id_list - 1], file_obj_names,
                                                         len_id_list)
             else:
                 e, file = FileService.get_by_id(file_id_list[len_id_list - 2])
                 if not e:
-                    return get_json_result(data=False, message="Folder not found!", code=404)
+                    return get_json_result(data=False, message="Folder not found!", code=RetCode.NOT_FOUND)
                 last_folder = FileService.create_folder(file, file_id_list[len_id_list - 2], file_obj_names,
                                                         len_id_list)

@@ -156,7 +157,7 @@ async def create(tenant_id):
     Create a new file or folder.
     ---
     tags:
-      - File Management
+      - File
     security:
       - ApiKeyAuth: []
     parameters:
@@ -193,16 +194,16 @@ async def create(tenant_id):
           type:
             type: string
     """
-    req = await request.json
-    pf_id = await request.json.get("parent_id")
-    input_file_type = await request.json.get("type")
+    req = await get_request_json()
+    pf_id = req.get("parent_id")
+    input_file_type = req.get("type")
     if not pf_id:
         root_folder = FileService.get_root_folder(tenant_id)
         pf_id = root_folder["id"]

     try:
         if not FileService.is_parent_folder_exist(pf_id):
-            return get_json_result(data=False, message="Parent Folder Doesn't Exist!", code=400)
+            return get_json_result(data=False, message="Parent Folder Doesn't Exist!", code=RetCode.BAD_REQUEST)
         if FileService.query(name=req["name"], parent_id=pf_id):
             return get_json_result(data=False, message="Duplicated folder name in the same folder.", code=409)

@@ -229,12 +230,12 @@ async def create(tenant_id):

 @manager.route('/file/list', methods=['GET']) # noqa: F821
 @token_required
-def list_files(tenant_id):
+async def list_files(tenant_id):
     """
     List files under a specific folder.
     ---
     tags:
-      - File Management
+      - File
     security:
       - ApiKeyAuth: []
     parameters:
@@ -306,13 +307,13 @@ def list_files(tenant_id):
     try:
         e, file = FileService.get_by_id(pf_id)
         if not e:
-            return get_json_result(message="Folder not found!", code=404)
+            return get_json_result(message="Folder not found!", code=RetCode.NOT_FOUND)

         files, total = FileService.get_by_pf_id(tenant_id, pf_id, page_number, items_per_page, orderby, desc, keywords)

         parent_folder = FileService.get_parent_folder(pf_id)
         if not parent_folder:
-            return get_json_result(message="File not found!", code=404)
+            return get_json_result(message="File not found!", code=RetCode.NOT_FOUND)

         return get_json_result(data={"total": total, "files": files, "parent_folder": parent_folder.to_json()})
     except Exception as e:
@@ -321,12 +322,12 @@ def list_files(tenant_id):

 @manager.route('/file/root_folder', methods=['GET']) # noqa: F821
 @token_required
-def get_root_folder(tenant_id):
+async def get_root_folder(tenant_id):
     """
     Get user's root folder.
     ---
     tags:
-      - File Management
+      - File
     security:
       - ApiKeyAuth: []
     responses:
@@ -357,12 +358,12 @@ def get_root_folder(tenant_id):

 @manager.route('/file/parent_folder', methods=['GET']) # noqa: F821
 @token_required
-def get_parent_folder():
+async def get_parent_folder():
     """
     Get parent folder info of a file.
     ---
     tags:
-      - File Management
+      - File
     security:
       - ApiKeyAuth: []
     parameters:
@@ -392,7 +393,7 @@ def get_parent_folder():
     try:
         e, file = FileService.get_by_id(file_id)
         if not e:
-            return get_json_result(message="Folder not found!", code=404)
+            return get_json_result(message="Folder not found!", code=RetCode.NOT_FOUND)

         parent_folder = FileService.get_parent_folder(file_id)
         return get_json_result(data={"parent_folder": parent_folder.to_json()})
@@ -402,12 +403,12 @@ def get_parent_folder():

 @manager.route('/file/all_parent_folder', methods=['GET']) # noqa: F821
 @token_required
-def get_all_parent_folders(tenant_id):
+async def get_all_parent_folders(tenant_id):
     """
     Get all parent folders of a file.
     ---
     tags:
-      - File Management
+      - File
     security:
       - ApiKeyAuth: []
     parameters:
@@ -439,7 +440,7 @@ def get_all_parent_folders(tenant_id):
     try:
         e, file = FileService.get_by_id(file_id)
         if not e:
-            return get_json_result(message="Folder not found!", code=404)
+            return get_json_result(message="Folder not found!", code=RetCode.NOT_FOUND)

         parent_folders = FileService.get_all_parent_folders(file_id)
         parent_folders_res = [folder.to_json() for folder in parent_folders]
@@ -455,7 +456,7 @@ async def rm(tenant_id):
     Delete one or multiple files/folders.
     ---
     tags:
-      - File Management
+      - File
     security:
       - ApiKeyAuth: []
     parameters:
@@ -481,40 +482,40 @@ async def rm(tenant_id):
           type: boolean
           example: true
     """
-    req = await request.json
+    req = await get_request_json()
     file_ids = req["file_ids"]
     try:
         for file_id in file_ids:
             e, file = FileService.get_by_id(file_id)
             if not e:
-                return get_json_result(message="File or Folder not found!", code=404)
+                return get_json_result(message="File or Folder not found!", code=RetCode.NOT_FOUND)
             if not file.tenant_id:
-                return get_json_result(message="Tenant not found!", code=404)
+                return get_json_result(message="Tenant not found!", code=RetCode.NOT_FOUND)

             if file.type == FileType.FOLDER.value:
                 file_id_list = FileService.get_all_innermost_file_ids(file_id, [])
                 for inner_file_id in file_id_list:
                     e, file = FileService.get_by_id(inner_file_id)
                     if not e:
-                        return get_json_result(message="File not found!", code=404)
+                        return get_json_result(message="File not found!", code=RetCode.NOT_FOUND)
                     settings.STORAGE_IMPL.rm(file.parent_id, file.location)
                 FileService.delete_folder_by_pf_id(tenant_id, file_id)
             else:
                 settings.STORAGE_IMPL.rm(file.parent_id, file.location)
                 if not FileService.delete(file):
-                    return get_json_result(message="Database error (File removal)!", code=500)
+                    return get_json_result(message="Database error (File removal)!", code=RetCode.SERVER_ERROR)

             informs = File2DocumentService.get_by_file_id(file_id)
             for inform in informs:
                 doc_id = inform.document_id
                 e, doc = DocumentService.get_by_id(doc_id)
                 if not e:
-                    return get_json_result(message="Document not found!", code=404)
+                    return get_json_result(message="Document not found!", code=RetCode.NOT_FOUND)
                 tenant_id = DocumentService.get_tenant_id(doc_id)
                 if not tenant_id:
-                    return get_json_result(message="Tenant not found!", code=404)
+                    return get_json_result(message="Tenant not found!", code=RetCode.NOT_FOUND)
                 if not DocumentService.remove_document(doc, tenant_id):
-                    return get_json_result(message="Database error (Document removal)!", code=500)
+                    return get_json_result(message="Database error (Document removal)!", code=RetCode.SERVER_ERROR)
                 File2DocumentService.delete_by_file_id(file_id)

         return get_json_result(data=True)
@@ -529,7 +530,7 @@ async def rename(tenant_id):
     Rename a file.
     ---
     tags:
-      - File Management
+      - File
     security:
       - ApiKeyAuth: []
     parameters:
@@ -556,27 +557,27 @@ async def rename(tenant_id):
           type: boolean
           example: true
     """
-    req = await request.json
+    req = await get_request_json()
     try:
         e, file = FileService.get_by_id(req["file_id"])
         if not e:
-            return get_json_result(message="File not found!", code=404)
+            return get_json_result(message="File not found!", code=RetCode.NOT_FOUND)

         if file.type != FileType.FOLDER.value and pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
                 file.name.lower()).suffix:
-            return get_json_result(data=False, message="The extension of file can't be changed", code=400)
+            return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.BAD_REQUEST)

         for existing_file in FileService.query(name=req["name"], pf_id=file.parent_id):
             if existing_file.name == req["name"]:
                 return get_json_result(data=False, message="Duplicated file name in the same folder.", code=409)

         if not FileService.update_by_id(req["file_id"], {"name": req["name"]}):
-            return get_json_result(message="Database error (File rename)!", code=500)
+            return get_json_result(message="Database error (File rename)!", code=RetCode.SERVER_ERROR)

         informs = File2DocumentService.get_by_file_id(req["file_id"])
         if informs:
             if not DocumentService.update_by_id(informs[0].document_id, {"name": req["name"]}):
-                return get_json_result(message="Database error (Document rename)!", code=500)
+                return get_json_result(message="Database error (Document rename)!", code=RetCode.SERVER_ERROR)

         return get_json_result(data=True)
     except Exception as e:
@@ -590,7 +591,7 @@ async def get(tenant_id, file_id):
     Download a file.
     ---
     tags:
-      - File Management
+      - File
     security:
       - ApiKeyAuth: []
     produces:
@@ -606,13 +607,13 @@ async def get(tenant_id, file_id):
         description: File stream
         schema:
           type: file
-      404:
+      RetCode.NOT_FOUND:
         description: File not found
     """
     try:
         e, file = FileService.get_by_id(file_id)
         if not e:
-            return get_json_result(message="Document not found!", code=404)
+            return get_json_result(message="Document not found!", code=RetCode.NOT_FOUND)

         blob = settings.STORAGE_IMPL.get(file.parent_id, file.location)
         if not blob:
@@ -630,6 +631,19 @@ async def get(tenant_id, file_id):
     except Exception as e:
         return server_error_response(e)

+
+@manager.route("/file/download/<attachment_id>", methods=["GET"]) # noqa: F821
+@token_required
+async def download_attachment(tenant_id,attachment_id):
+    try:
+        ext = request.args.get("ext", "markdown")
+        data = await asyncio.to_thread(settings.STORAGE_IMPL.get, tenant_id, attachment_id)
+        response = await make_response(data)
+        response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
+
+        return response
+
+    except Exception as e:
+        return server_error_response(e)
+
 @manager.route('/file/mv', methods=['POST']) # noqa: F821
 @token_required
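(Illustrative sketch, not part of the diff: how a client might call the new /file/download/<attachment_id> route added above. The URL prefix, port, and token are placeholders and assumptions; only the route path and the "ext" query parameter come from the handler.)

# Hypothetical client call for the download_attachment endpoint shown above.
import requests

resp = requests.get(
    "http://localhost:9380/v1/file/download/<attachment_id>",  # assumed host and API prefix
    params={"ext": "markdown"},                                # selects the Content-Type mapping
    headers={"Authorization": "Bearer <API_KEY>"},             # assumed auth scheme
)
with open("attachment.md", "wb") as fh:
    fh.write(resp.content)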
@@ -638,7 +652,7 @@ async def move(tenant_id):
     Move one or multiple files to another folder.
     ---
     tags:
-      - File Management
+      - File
     security:
       - ApiKeyAuth: []
     parameters:
@@ -667,7 +681,7 @@ async def move(tenant_id):
           type: boolean
           example: true
     """
-    req = await request.json
+    req = await get_request_json()
     try:
         file_ids = req["src_file_ids"]
         parent_id = req["dest_file_id"]
@@ -677,13 +691,13 @@ async def move(tenant_id):
         for file_id in file_ids:
             file = files_dict[file_id]
             if not file:
-                return get_json_result(message="File or Folder not found!", code=404)
+                return get_json_result(message="File or Folder not found!", code=RetCode.NOT_FOUND)
             if not file.tenant_id:
-                return get_json_result(message="Tenant not found!", code=404)
+                return get_json_result(message="Tenant not found!", code=RetCode.NOT_FOUND)

         fe, _ = FileService.get_by_id(parent_id)
         if not fe:
-            return get_json_result(message="Parent Folder not found!", code=404)
+            return get_json_result(message="Parent Folder not found!", code=RetCode.NOT_FOUND)

         FileService.move_file(file_ids, parent_id)
         return get_json_result(data=True)
@@ -694,7 +708,7 @@ async def move(tenant_id):
 @manager.route('/file/convert', methods=['POST']) # noqa: F821
 @token_required
 async def convert(tenant_id):
-    req = await request.json
+    req = await get_request_json()
     kb_ids = req["kb_ids"]
     file_ids = req["file_ids"]
     file2documents = []
@@ -705,7 +719,7 @@ async def convert(tenant_id):
     for file_id in file_ids:
         file = files_set[file_id]
         if not file:
-            return get_json_result(message="File not found!", code=404)
+            return get_json_result(message="File not found!", code=RetCode.NOT_FOUND)
         file_ids_list = [file_id]
         if file.type == FileType.FOLDER.value:
             file_ids_list = FileService.get_all_innermost_file_ids(file_id, [])
@@ -716,13 +730,13 @@ async def convert(tenant_id):
             doc_id = inform.document_id
             e, doc = DocumentService.get_by_id(doc_id)
             if not e:
-                return get_json_result(message="Document not found!", code=404)
+                return get_json_result(message="Document not found!", code=RetCode.NOT_FOUND)
             tenant_id = DocumentService.get_tenant_id(doc_id)
             if not tenant_id:
-                return get_json_result(message="Tenant not found!", code=404)
+                return get_json_result(message="Tenant not found!", code=RetCode.NOT_FOUND)
             if not DocumentService.remove_document(doc, tenant_id):
                 return get_json_result(
-                    message="Database error (Document removal)!", code=404)
+                    message="Database error (Document removal)!", code=RetCode.NOT_FOUND)
             File2DocumentService.delete_by_file_id(id)

     # insert
@@ -730,11 +744,11 @@ async def convert(tenant_id):
         e, kb = KnowledgebaseService.get_by_id(kb_id)
         if not e:
             return get_json_result(
-                message="Can't find this knowledgebase!", code=404)
+                message="Can't find this dataset!", code=RetCode.NOT_FOUND)
         e, file = FileService.get_by_id(id)
         if not e:
             return get_json_result(
-                message="Can't find this file!", code=404)
+                message="Can't find this file!", code=RetCode.NOT_FOUND)

         doc = DocumentService.insert({
             "id": get_uuid(),
@@ -14,6 +14,7 @@
 # limitations under the License.
 #
 import json
+import copy
 import re
 import time

@@ -25,27 +26,30 @@ from api.db.db_models import APIToken
 from api.db.services.api_service import API4ConversationService
 from api.db.services.canvas_service import UserCanvasService, completion_openai
 from api.db.services.canvas_service import completion as agent_completion
-from api.db.services.conversation_service import ConversationService, iframe_completion
-from api.db.services.conversation_service import completion as rag_completion
-from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap, meta_filter
+from api.db.services.conversation_service import ConversationService
+from api.db.services.conversation_service import async_iframe_completion as iframe_completion
+from api.db.services.conversation_service import async_completion as rag_completion
+from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap
 from api.db.services.document_service import DocumentService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
+from common.metadata_utils import apply_meta_data_filter, convert_conditions, meta_filter
 from api.db.services.search_service import SearchService
 from api.db.services.user_service import UserTenantService
 from common.misc_utils import get_uuid
 from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, \
-get_result, server_error_response, token_required, validate_request
+get_result, get_request_json, server_error_response, token_required, validate_request
 from rag.app.tag import label_question
 from rag.prompts.template import load_prompt
-from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format
+from rag.prompts.generator import cross_languages, keyword_extraction, chunks_format
 from common.constants import RetCode, LLMType, StatusEnum
 from common import settings


@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
|
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
|
||||||
@token_required
|
@token_required
|
||||||
async def create(tenant_id, chat_id):
|
async def create(tenant_id, chat_id):
|
||||||
req = await request.json
|
req = await get_request_json()
|
||||||
req["dialog_id"] = chat_id
|
req["dialog_id"] = chat_id
|
||||||
dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
|
dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
|
||||||
if not dia:
|
if not dia:
|
||||||
@ -73,7 +77,7 @@ async def create(tenant_id, chat_id):
|
|||||||
|
|
||||||
@manager.route("/agents/<agent_id>/sessions", methods=["POST"]) # noqa: F821
|
@manager.route("/agents/<agent_id>/sessions", methods=["POST"]) # noqa: F821
|
||||||
@token_required
|
@token_required
|
||||||
def create_agent_session(tenant_id, agent_id):
|
async def create_agent_session(tenant_id, agent_id):
|
||||||
user_id = request.args.get("user_id", tenant_id)
|
user_id = request.args.get("user_id", tenant_id)
|
||||||
e, cvs = UserCanvasService.get_by_id(agent_id)
|
e, cvs = UserCanvasService.get_by_id(agent_id)
|
||||||
if not e:
|
if not e:
|
||||||
@ -98,7 +102,7 @@ def create_agent_session(tenant_id, agent_id):
|
|||||||
@manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PUT"]) # noqa: F821
|
@manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PUT"]) # noqa: F821
|
||||||
@token_required
|
@token_required
|
||||||
async def update(tenant_id, chat_id, session_id):
|
async def update(tenant_id, chat_id, session_id):
|
||||||
req = await request.json
|
req = await get_request_json()
|
||||||
req["dialog_id"] = chat_id
|
req["dialog_id"] = chat_id
|
||||||
conv_id = session_id
|
conv_id = session_id
|
||||||
conv = ConversationService.query(id=conv_id, dialog_id=chat_id)
|
conv = ConversationService.query(id=conv_id, dialog_id=chat_id)
|
||||||
@ -120,16 +124,38 @@ async def update(tenant_id, chat_id, session_id):
|
|||||||
@manager.route("/chats/<chat_id>/completions", methods=["POST"]) # noqa: F821
|
@manager.route("/chats/<chat_id>/completions", methods=["POST"]) # noqa: F821
|
||||||
@token_required
|
@token_required
|
||||||
async def chat_completion(tenant_id, chat_id):
|
async def chat_completion(tenant_id, chat_id):
|
||||||
req = await request.json
|
req = await get_request_json()
|
||||||
if not req:
|
if not req:
|
||||||
req = {"question": ""}
|
req = {"question": ""}
|
||||||
if not req.get("session_id"):
|
if not req.get("session_id"):
|
||||||
req["question"] = ""
|
req["question"] = ""
|
||||||
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
|
dia = DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value)
|
||||||
|
if not dia:
|
||||||
return get_error_data_result(f"You don't own the chat {chat_id}")
|
return get_error_data_result(f"You don't own the chat {chat_id}")
|
||||||
|
dia = dia[0]
|
||||||
if req.get("session_id"):
|
if req.get("session_id"):
|
||||||
if not ConversationService.query(id=req["session_id"], dialog_id=chat_id):
|
if not ConversationService.query(id=req["session_id"], dialog_id=chat_id):
|
||||||
return get_error_data_result(f"You don't own the session {req['session_id']}")
|
return get_error_data_result(f"You don't own the session {req['session_id']}")
|
||||||
|
|
||||||
|
metadata_condition = req.get("metadata_condition") or {}
|
||||||
|
if metadata_condition and not isinstance(metadata_condition, dict):
|
||||||
|
return get_error_data_result(message="metadata_condition must be an object.")
|
||||||
|
|
||||||
|
if metadata_condition and req.get("question"):
|
||||||
|
metas = DocumentService.get_meta_by_kbs(dia.kb_ids or [])
|
||||||
|
filtered_doc_ids = meta_filter(
|
||||||
|
metas,
|
||||||
|
convert_conditions(metadata_condition),
|
||||||
|
metadata_condition.get("logic", "and"),
|
||||||
|
)
|
||||||
|
if metadata_condition.get("conditions") and not filtered_doc_ids:
|
||||||
|
filtered_doc_ids = ["-999"]
|
||||||
|
|
||||||
|
if filtered_doc_ids:
|
||||||
|
req["doc_ids"] = ",".join(filtered_doc_ids)
|
||||||
|
else:
|
||||||
|
req.pop("doc_ids", None)
|
||||||
|
|
||||||
if req.get("stream", True):
|
if req.get("stream", True):
|
||||||
resp = Response(rag_completion(tenant_id, chat_id, **req), mimetype="text/event-stream")
|
resp = Response(rag_completion(tenant_id, chat_id, **req), mimetype="text/event-stream")
|
||||||
resp.headers.add_header("Cache-control", "no-cache")
|
resp.headers.add_header("Cache-control", "no-cache")
|
||||||
@ -140,7 +166,7 @@ async def chat_completion(tenant_id, chat_id):
|
|||||||
return resp
|
return resp
|
||||||
else:
|
else:
|
||||||
answer = None
|
answer = None
|
||||||
for ans in rag_completion(tenant_id, chat_id, **req):
|
async for ans in rag_completion(tenant_id, chat_id, **req):
|
||||||
answer = ans
|
answer = ans
|
||||||
break
|
break
|
||||||
return get_result(data=answer)
|
return get_result(data=answer)
|
||||||
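
The hunks above add an optional metadata_condition filter to POST /api/v1/chats/<chat_id>/completions: when the request carries a question and a condition object, the matching document ids are joined into req["doc_ids"] before completion, and an unmatched condition falls back to the sentinel id "-999" so retrieval returns nothing rather than everything. A minimal client-side sketch follows; the field names mirror the hunk, while the base URL, API key, chat id and session id are placeholders.

    # Sketch of a request that exercises the new metadata_condition filter.
    # BASE_URL, API_KEY, chat_id and session_id are placeholders, not values from the repo.
    import requests

    BASE_URL = "http://localhost:9380/api/v1"
    API_KEY = "ragflow-..."  # an API token created in the web UI

    payload = {
        "question": "What did Bob write about indexing?",
        "session_id": "<session_id>",
        "stream": False,
        "metadata_condition": {
            "logic": "and",
            "conditions": [
                {"name": "author", "comparison_operator": "is", "value": "bob"}
            ],
        },
    }
    resp = requests.post(
        f"{BASE_URL}/chats/<chat_id>/completions",
        headers={"Authorization": f"Bearer {API_KEY}"},
        json=payload,
        timeout=60,
    )
    print(resp.json())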
@@ -192,7 +218,19 @@ async def chat_completion_openai_like(tenant_id, chat_id):
 {"role": "user", "content": "Can you tell me how to install neovim"},
 ],
 stream=stream,
-extra_body={"reference": reference}
+extra_body={
+"reference": reference,
+"metadata_condition": {
+"logic": "and",
+"conditions": [
+{
+"name": "author",
+"comparison_operator": "is",
+"value": "bob"
+}
+]
+}
+}
 )

 if stream:

@@ -206,9 +244,13 @@ async def chat_completion_openai_like(tenant_id, chat_id):
 if reference:
 print(completion.choices[0].message.reference)
 """
-req = await request.get_json()
+req = await get_request_json()

-need_reference = bool(req.get("reference", False))
+extra_body = req.get("extra_body") or {}
+if extra_body and not isinstance(extra_body, dict):
+return get_error_data_result("extra_body must be an object.")

+need_reference = bool(extra_body.get("reference", False))

 messages = req.get("messages", [])
 # To prevent empty [] input

@@ -226,6 +268,22 @@ async def chat_completion_openai_like(tenant_id, chat_id):
 return get_error_data_result(f"You don't own the chat {chat_id}")
 dia = dia[0]

+metadata_condition = extra_body.get("metadata_condition") or {}
+if metadata_condition and not isinstance(metadata_condition, dict):
+return get_error_data_result(message="metadata_condition must be an object.")

+doc_ids_str = None
+if metadata_condition:
+metas = DocumentService.get_meta_by_kbs(dia.kb_ids or [])
+filtered_doc_ids = meta_filter(
+metas,
+convert_conditions(metadata_condition),
+metadata_condition.get("logic", "and"),
+)
+if metadata_condition.get("conditions") and not filtered_doc_ids:
+filtered_doc_ids = ["-999"]
+doc_ids_str = ",".join(filtered_doc_ids) if filtered_doc_ids else None

 # Filter system and non-sense assistant messages
 msg = []
 for m in messages:

@@ -244,7 +302,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
 # The value for the usage field on all chunks except for the last one will be null.
 # The usage field on the last chunk contains token usage statistics for the entire request.
 # The choices field on the last chunk will always be an empty array [].
-def streamed_response_generator(chat_id, dia, msg):
+async def streamed_response_generator(chat_id, dia, msg):
 token_used = 0
 answer_cache = ""
 reasoning_cache = ""

@@ -273,7 +331,10 @@ async def chat_completion_openai_like(tenant_id, chat_id):
 }

 try:
-for ans in chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
+chat_kwargs = {"toolcall_session": toolcall_session, "tools": tools, "quote": need_reference}
+if doc_ids_str:
+chat_kwargs["doc_ids"] = doc_ids_str
+async for ans in async_chat(dia, msg, True, **chat_kwargs):
 last_ans = ans
 answer = ans["answer"]

@@ -325,8 +386,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
 response["choices"][0]["delta"]["content"] = None
 response["choices"][0]["delta"]["reasoning_content"] = None
 response["choices"][0]["finish_reason"] = "stop"
-response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used,
-"total_tokens": len(prompt) + token_used}
+response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, "total_tokens": len(prompt) + token_used}
 if need_reference:
 response["choices"][0]["delta"]["reference"] = chunks_format(last_ans.get("reference", []))
 response["choices"][0]["delta"]["final_content"] = last_ans.get("answer", "")

@@ -341,7 +401,10 @@ async def chat_completion_openai_like(tenant_id, chat_id):
 return resp
 else:
 answer = None
-for ans in chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
+chat_kwargs = {"toolcall_session": toolcall_session, "tools": tools, "quote": need_reference}
+if doc_ids_str:
+chat_kwargs["doc_ids"] = doc_ids_str
+async for ans in async_chat(dia, msg, False, **chat_kwargs):
 # focus answer content only
 answer = ans
 break
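
The OpenAI-compatible handler reuses the same helpers, meta_filter and convert_conditions from common.metadata_utils, to turn extra_body["metadata_condition"] into a doc_ids string handed to async_chat. Their implementations are not part of this diff; the sketch below only illustrates the behaviour the hunks rely on, under the assumption that metas maps a document id to its metadata dict.

    # Illustrative sketch only: the real helpers live in common.metadata_utils and may differ.
    def convert_conditions(metadata_condition: dict) -> list[dict]:
        # Assumed payload shape: {"logic": ..., "conditions": [{"name", "comparison_operator", "value"}]}
        return [
            {"key": c.get("name"), "op": c.get("comparison_operator", "is"), "value": c.get("value")}
            for c in (metadata_condition or {}).get("conditions", [])
        ]

    def meta_filter(metas: dict, conditions: list[dict], logic: str = "and") -> list[str]:
        # metas is assumed to map doc_id -> {metadata key: value}
        def matches(meta: dict, cond: dict) -> bool:
            value = meta.get(cond["key"])
            return value == cond["value"] if cond["op"] == "is" else value != cond["value"]

        combine = all if logic == "and" else any
        return [doc_id for doc_id, meta in metas.items() if combine(matches(meta, c) for c in conditions)]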
@@ -384,7 +447,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
 @validate_request("model", "messages")  # noqa: F821
 @token_required
 async def agents_completion_openai_compatibility(tenant_id, agent_id):
-req = await request.json
+req = await get_request_json()
 tiktokenenc = tiktoken.get_encoding("cl100k_base")
 messages = req.get("messages", [])
 if not messages:

@@ -442,11 +505,13 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id):
 @manager.route("/agents/<agent_id>/completions", methods=["POST"])  # noqa: F821
 @token_required
 async def agent_completions(tenant_id, agent_id):
-req = await request.json
+req = await get_request_json()
+return_trace = bool(req.get("return_trace", False))

 if req.get("stream", True):

 async def generate():
+trace_items = []
 async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
 if isinstance(answer, str):
 try:

@@ -454,7 +519,21 @@ async def agent_completions(tenant_id, agent_id):
 except Exception:
 continue

-if ans.get("event") not in ["message", "message_end"]:
+event = ans.get("event")
+if event == "node_finished":
+if return_trace:
+data = ans.get("data", {})
+trace_items.append(
+{
+"component_id": data.get("component_id"),
+"trace": [copy.deepcopy(data)],
+}
+)
+ans.setdefault("data", {})["trace"] = trace_items
+answer = "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
+yield answer

+if event not in ["message", "message_end"]:
 continue

 yield answer

@@ -471,6 +550,7 @@ async def agent_completions(tenant_id, agent_id):
 full_content = ""
 reference = {}
 final_ans = ""
+trace_items = []
 async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
 try:
 ans = json.loads(answer[5:])

@@ -481,17 +561,28 @@ async def agent_completions(tenant_id, agent_id):
 if ans.get("data", {}).get("reference", None):
 reference.update(ans["data"]["reference"])

+if return_trace and ans.get("event") == "node_finished":
+data = ans.get("data", {})
+trace_items.append(
+{
+"component_id": data.get("component_id"),
+"trace": [copy.deepcopy(data)],
+}
+)

 final_ans = ans
 except Exception as e:
 return get_result(data=f"**ERROR**: {str(e)}")
 final_ans["data"]["content"] = full_content
 final_ans["data"]["reference"] = reference
+if return_trace and final_ans:
+final_ans["data"]["trace"] = trace_items
 return get_result(data=final_ans)


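With return_trace enabled, each node_finished event now carries an accumulating data.trace list of {"component_id", "trace"} items, in both the streaming and the blocking branch. A rough consumer sketch follows; host, token and agent id are placeholders, and the event names are taken from the hunks above.

    # Sketch: collect the per-component trace from the SSE stream when return_trace=True.
    import json
    import requests

    resp = requests.post(
        "http://localhost:9380/api/v1/agents/<agent_id>/completions",
        headers={"Authorization": "Bearer <api_key>"},
        json={"question": "hi", "stream": True, "return_trace": True},
        stream=True,
        timeout=600,
    )
    traces = []
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data:"):
            continue
        event = json.loads(line[5:])
        if event.get("event") == "node_finished":
            # the server appends to one list, so the latest event holds the full trace so far
            traces = event.get("data", {}).get("trace", traces)
    print(f"collected {len(traces)} trace items")
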
@manager.route("/chats/<chat_id>/sessions", methods=["GET"]) # noqa: F821
|
@manager.route("/chats/<chat_id>/sessions", methods=["GET"]) # noqa: F821
|
||||||
@token_required
|
@token_required
|
||||||
def list_session(tenant_id, chat_id):
|
async def list_session(tenant_id, chat_id):
|
||||||
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
|
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
|
||||||
return get_error_data_result(message=f"You don't own the assistant {chat_id}.")
|
return get_error_data_result(message=f"You don't own the assistant {chat_id}.")
|
||||||
id = request.args.get("id")
|
id = request.args.get("id")
|
||||||
@ -545,7 +636,7 @@ def list_session(tenant_id, chat_id):
|
|||||||
|
|
||||||
@manager.route("/agents/<agent_id>/sessions", methods=["GET"]) # noqa: F821
|
@manager.route("/agents/<agent_id>/sessions", methods=["GET"]) # noqa: F821
|
||||||
@token_required
|
@token_required
|
||||||
def list_agent_session(tenant_id, agent_id):
|
async def list_agent_session(tenant_id, agent_id):
|
||||||
if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
|
if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
|
||||||
return get_error_data_result(message=f"You don't own the agent {agent_id}.")
|
return get_error_data_result(message=f"You don't own the agent {agent_id}.")
|
||||||
id = request.args.get("id")
|
id = request.args.get("id")
|
||||||
@ -614,7 +705,7 @@ async def delete(tenant_id, chat_id):
|
|||||||
|
|
||||||
errors = []
|
errors = []
|
||||||
success_count = 0
|
success_count = 0
|
||||||
req = await request.json
|
req = await get_request_json()
|
||||||
convs = ConversationService.query(dialog_id=chat_id)
|
convs = ConversationService.query(dialog_id=chat_id)
|
||||||
if not req:
|
if not req:
|
||||||
ids = None
|
ids = None
|
||||||
@ -662,7 +753,7 @@ async def delete(tenant_id, chat_id):
|
|||||||
async def delete_agent_session(tenant_id, agent_id):
|
async def delete_agent_session(tenant_id, agent_id):
|
||||||
errors = []
|
errors = []
|
||||||
success_count = 0
|
success_count = 0
|
||||||
req = await request.json
|
req = await get_request_json()
|
||||||
cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id)
|
cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id)
|
||||||
if not cvs:
|
if not cvs:
|
||||||
return get_error_data_result(f"You don't own the agent {agent_id}")
|
return get_error_data_result(f"You don't own the agent {agent_id}")
|
||||||
@ -715,7 +806,7 @@ async def delete_agent_session(tenant_id, agent_id):
|
|||||||
@manager.route("/sessions/ask", methods=["POST"]) # noqa: F821
|
@manager.route("/sessions/ask", methods=["POST"]) # noqa: F821
|
||||||
@token_required
|
@token_required
|
||||||
async def ask_about(tenant_id):
|
async def ask_about(tenant_id):
|
||||||
req = await request.json
|
req = await get_request_json()
|
||||||
if not req.get("question"):
|
if not req.get("question"):
|
||||||
return get_error_data_result("`question` is required.")
|
return get_error_data_result("`question` is required.")
|
||||||
if not req.get("dataset_ids"):
|
if not req.get("dataset_ids"):
|
||||||
@ -732,10 +823,10 @@ async def ask_about(tenant_id):
|
|||||||
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
|
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
|
||||||
uid = tenant_id
|
uid = tenant_id
|
||||||
|
|
||||||
def stream():
|
async def stream():
|
||||||
nonlocal req, uid
|
nonlocal req, uid
|
||||||
try:
|
try:
|
||||||
for ans in ask(req["question"], req["kb_ids"], uid):
|
async for ans in async_ask(req["question"], req["kb_ids"], uid):
|
||||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield "data:" + json.dumps(
|
yield "data:" + json.dumps(
|
||||||
@ -754,7 +845,7 @@ async def ask_about(tenant_id):
|
|||||||
@manager.route("/sessions/related_questions", methods=["POST"]) # noqa: F821
|
@manager.route("/sessions/related_questions", methods=["POST"]) # noqa: F821
|
||||||
@token_required
|
@token_required
|
||||||
async def related_questions(tenant_id):
|
async def related_questions(tenant_id):
|
||||||
req = await request.json
|
req = await get_request_json()
|
||||||
if not req.get("question"):
|
if not req.get("question"):
|
||||||
return get_error_data_result("`question` is required.")
|
return get_error_data_result("`question` is required.")
|
||||||
question = req["question"]
|
question = req["question"]
|
||||||
@ -787,7 +878,7 @@ Reason:
|
|||||||
- At the same time, related terms can also help search engines better understand user needs and return more accurate search results.
|
- At the same time, related terms can also help search engines better understand user needs and return more accurate search results.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
ans = chat_mdl.chat(
|
ans = await chat_mdl.async_chat(
|
||||||
prompt,
|
prompt,
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
@ -805,7 +896,7 @@ Related search terms:
|
|||||||
|
|
||||||
@manager.route("/chatbots/<dialog_id>/completions", methods=["POST"]) # noqa: F821
|
@manager.route("/chatbots/<dialog_id>/completions", methods=["POST"]) # noqa: F821
|
||||||
async def chatbot_completions(dialog_id):
|
async def chatbot_completions(dialog_id):
|
||||||
req = await request.json
|
req = await get_request_json()
|
||||||
|
|
||||||
token = request.headers.get("Authorization").split()
|
token = request.headers.get("Authorization").split()
|
||||||
if len(token) != 2:
|
if len(token) != 2:
|
||||||
@ -826,12 +917,12 @@ async def chatbot_completions(dialog_id):
|
|||||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
for answer in iframe_completion(dialog_id, **req):
|
async for answer in iframe_completion(dialog_id, **req):
|
||||||
return get_result(data=answer)
|
return get_result(data=answer)
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/chatbots/<dialog_id>/info", methods=["GET"]) # noqa: F821
|
@manager.route("/chatbots/<dialog_id>/info", methods=["GET"]) # noqa: F821
|
||||||
def chatbots_inputs(dialog_id):
|
async def chatbots_inputs(dialog_id):
|
||||||
token = request.headers.get("Authorization").split()
|
token = request.headers.get("Authorization").split()
|
||||||
if len(token) != 2:
|
if len(token) != 2:
|
||||||
return get_error_data_result(message='Authorization is not valid!"')
|
return get_error_data_result(message='Authorization is not valid!"')
|
||||||
@ -855,7 +946,7 @@ def chatbots_inputs(dialog_id):
|
|||||||
|
|
||||||
@manager.route("/agentbots/<agent_id>/completions", methods=["POST"]) # noqa: F821
|
@manager.route("/agentbots/<agent_id>/completions", methods=["POST"]) # noqa: F821
|
||||||
async def agent_bot_completions(agent_id):
|
async def agent_bot_completions(agent_id):
|
||||||
req = await request.json
|
req = await get_request_json()
|
||||||
|
|
||||||
token = request.headers.get("Authorization").split()
|
token = request.headers.get("Authorization").split()
|
||||||
if len(token) != 2:
|
if len(token) != 2:
|
||||||
@ -878,7 +969,7 @@ async def agent_bot_completions(agent_id):
|
|||||||
|
|
||||||
|
|
||||||
@manager.route("/agentbots/<agent_id>/inputs", methods=["GET"]) # noqa: F821
|
@manager.route("/agentbots/<agent_id>/inputs", methods=["GET"]) # noqa: F821
|
||||||
def begin_inputs(agent_id):
|
async def begin_inputs(agent_id):
|
||||||
token = request.headers.get("Authorization").split()
|
token = request.headers.get("Authorization").split()
|
||||||
if len(token) != 2:
|
if len(token) != 2:
|
||||||
return get_error_data_result(message='Authorization is not valid!"')
|
return get_error_data_result(message='Authorization is not valid!"')
|
||||||
@ -908,7 +999,7 @@ async def ask_about_embedded():
|
|||||||
if not objs:
|
if not objs:
|
||||||
return get_error_data_result(message='Authentication error: API key is invalid!"')
|
return get_error_data_result(message='Authentication error: API key is invalid!"')
|
||||||
|
|
||||||
req = await request.json
|
req = await get_request_json()
|
||||||
uid = objs[0].tenant_id
|
uid = objs[0].tenant_id
|
||||||
|
|
||||||
search_id = req.get("search_id", "")
|
search_id = req.get("search_id", "")
|
||||||
@ -917,10 +1008,10 @@ async def ask_about_embedded():
|
|||||||
if search_app := SearchService.get_detail(search_id):
|
if search_app := SearchService.get_detail(search_id):
|
||||||
search_config = search_app.get("search_config", {})
|
search_config = search_app.get("search_config", {})
|
||||||
|
|
||||||
def stream():
|
async def stream():
|
||||||
nonlocal req, uid
|
nonlocal req, uid
|
||||||
try:
|
try:
|
||||||
for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
|
async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config):
|
||||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield "data:" + json.dumps(
|
yield "data:" + json.dumps(
|
||||||
@@ -947,7 +1038,7 @@ async def retrieval_test_embedded():
 if not objs:
 return get_error_data_result(message='Authentication error: API key is invalid!"')

-req = await request.json
+req = await get_request_json()
 page = int(req.get("page", 1))
 size = int(req.get("size", 30))
 question = req["question"]

@@ -963,28 +1054,31 @@ async def retrieval_test_embedded():
 use_kg = req.get("use_kg", False)
 top = int(req.get("top_k", 1024))
 langs = req.get("cross_languages", [])
-tenant_ids = []

 tenant_id = objs[0].tenant_id
 if not tenant_id:
 return get_error_data_result(message="permission denined.")

+async def _retrieval():
+local_doc_ids = list(doc_ids) if doc_ids else []
+tenant_ids = []
+_question = question

+meta_data_filter = {}
+chat_mdl = None
 if req.get("search_id", ""):
 search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
 meta_data_filter = search_config.get("meta_data_filter", {})
-metas = DocumentService.get_meta_by_kbs(kb_ids)
-if meta_data_filter.get("method") == "auto":
+if meta_data_filter.get("method") in ["auto", "semi_auto"]:
 chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
-filters: dict = gen_meta_filter(chat_mdl, metas, question)
-doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
-if not doc_ids:
-doc_ids = None
-elif meta_data_filter.get("method") == "manual":
-doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
-if meta_data_filter["manual"] and not doc_ids:
-doc_ids = ["-999"]
+else:
+meta_data_filter = req.get("meta_data_filter") or {}
+if meta_data_filter.get("method") in ["auto", "semi_auto"]:
+chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
+if meta_data_filter:
+metas = DocumentService.get_meta_by_kbs(kb_ids)
+local_doc_ids = await apply_meta_data_filter(meta_data_filter, metas, _question, chat_mdl, local_doc_ids)

-try:
 tenants = UserTenantService.query(user_id=tenant_id)
 for kb_id in kb_ids:
 for tenant in tenants:

@@ -992,7 +1086,7 @@ async def retrieval_test_embedded():
 tenant_ids.append(tenant.tenant_id)
 break
 else:
-return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.",
+return get_json_result(data=False, message="Only owner of dataset authorized for this operation.",
 code=RetCode.OPERATING_ERROR)

 e, kb = KnowledgebaseService.get_by_id(kb_ids[0])

@@ -1000,7 +1094,7 @@ async def retrieval_test_embedded():
 return get_error_data_result(message="Knowledgebase not found!")

 if langs:
-question = cross_languages(kb.tenant_id, None, question, langs)
+_question = await cross_languages(kb.tenant_id, None, _question, langs)

 embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)

@@ -1010,15 +1104,15 @@ async def retrieval_test_embedded():

 if req.get("keyword", False):
 chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
-question += keyword_extraction(chat_mdl, question)
+_question += await keyword_extraction(chat_mdl, _question)

-labels = label_question(question, [kb])
+labels = label_question(_question, [kb])
 ranks = settings.retriever.retrieval(
-question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
-doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
+_question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
+local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
 )
 if use_kg:
-ck = settings.kg_retriever.retrieval(question, tenant_ids, kb_ids, embd_mdl,
+ck = settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl,
 LLMBundle(kb.tenant_id, LLMType.CHAT))
 if ck["content_with_weight"]:
 ranks["chunks"].insert(0, ck)

@@ -1028,6 +1122,9 @@ async def retrieval_test_embedded():
 ranks["labels"] = labels

 return get_json_result(data=ranks)

+try:
+return await _retrieval()
 except Exception as e:
 if str(e).find("not_found") > 0:
 return get_json_result(data=False, message="No chunk found! Check the chunk status please!",

@@ -1046,7 +1143,7 @@ async def related_questions_embedded():
 if not objs:
 return get_error_data_result(message='Authentication error: API key is invalid!"')

-req = await request.json
+req = await get_request_json()
 tenant_id = objs[0].tenant_id
 if not tenant_id:
 return get_error_data_result(message="permission denined.")

@@ -1064,7 +1161,7 @@ async def related_questions_embedded():

 gen_conf = search_config.get("llm_setting", {"temperature": 0.9})
 prompt = load_prompt("related_question")
-ans = chat_mdl.chat(
+ans = await chat_mdl.async_chat(
 prompt,
 [
 {

@@ -1081,7 +1178,7 @@ Related search terms:


 @manager.route("/searchbots/detail", methods=["GET"])  # noqa: F821
-def detail_share_embedded():
+async def detail_share_embedded():
 token = request.headers.get("Authorization").split()
 if len(token) != 2:
 return get_error_data_result(message='Authorization is not valid!"')

@@ -1123,12 +1220,12 @@ async def mindmap():
 return get_error_data_result(message='Authentication error: API key is invalid!"')

 tenant_id = objs[0].tenant_id
-req = await request.json
+req = await get_request_json()

 search_id = req.get("search_id", "")
 search_app = SearchService.get_detail(search_id) if search_id else {}

-mind_map = gen_mindmap(req["question"], req["kb_ids"], tenant_id, search_app.get("search_config", {}))
+mind_map = await gen_mindmap(req["question"], req["kb_ids"], tenant_id, search_app.get("search_config", {}))
 if "error" in mind_map:
 return server_error_response(Exception(mind_map["error"]))
 return get_json_result(data=mind_map)

@@ -24,14 +24,14 @@ from api.db.services.search_service import SearchService
 from api.db.services.user_service import TenantService, UserTenantService
 from common.misc_utils import get_uuid
 from common.constants import RetCode, StatusEnum
-from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, server_error_response, validate_request
+from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, get_request_json, server_error_response, validate_request


 @manager.route("/create", methods=["post"])  # noqa: F821
 @login_required
 @validate_request("name")
 async def create():
-req = await request.get_json()
+req = await get_request_json()
 search_name = req["name"]
 description = req.get("description", "")
 if not isinstance(search_name, str):

@@ -66,7 +66,7 @@ async def create():
 @validate_request("search_id", "name", "search_config", "tenant_id")
 @not_allowed_parameters("id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by")
 async def update():
-req = await request.get_json()
+req = await get_request_json()
 if not isinstance(req["name"], str):
 return get_data_error_result(message="Search name must be string.")
 if req["name"].strip() == "":

@@ -150,7 +150,7 @@ async def list_search_app():
 else:
 desc = True

-req = await request.get_json()
+req = await get_request_json()
 owner_ids = req.get("owner_ids", [])
 try:
 if not owner_ids:

@@ -174,7 +174,7 @@ async def list_search_app():
 @login_required
 @validate_request("search_id")
 async def rm():
-req = await request.get_json()
+req = await get_request_json()
 search_id = req["search_id"]
 if not SearchService.accessible4deletion(search_id, current_user.id):
 return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)

@@ -13,8 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from quart import request
+import logging
+import asyncio
 from api.db import UserTenantRole
 from api.db.db_models import UserTenant
 from api.db.services.user_service import UserTenantService, UserService

@@ -22,10 +22,10 @@ from api.db.services.user_service import UserTenantService, UserService
 from common.constants import RetCode, StatusEnum
 from common.misc_utils import get_uuid
 from common.time_utils import delta_seconds
-from api.utils.api_utils import get_json_result, validate_request, server_error_response, get_data_error_result
+from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
 from api.utils.web_utils import send_invite_email
 from common import settings
-from api.apps import smtp_mail_server, login_required, current_user
+from api.apps import login_required, current_user


 @manager.route("/<tenant_id>/user/list", methods=["GET"])  # noqa: F821

@@ -56,7 +56,7 @@ async def create(tenant_id):
 message='No authorization.',
 code=RetCode.AUTHENTICATION_ERROR)

-req = await request.json
+req = await get_request_json()
 invite_user_email = req["email"]
 invite_users = UserService.query(email=invite_user_email)
 if not invite_users:

@@ -81,20 +81,24 @@ async def create(tenant_id):
 role=UserTenantRole.INVITE,
 status=StatusEnum.VALID.value)

-if smtp_mail_server and settings.SMTP_CONF:
-from threading import Thread
+try:

 user_name = ""
 _, user = UserService.get_by_id(current_user.id)
 if user:
 user_name = user.nickname

-Thread(
-target=send_invite_email,
-args=(invite_user_email, settings.MAIL_FRONTEND_URL, tenant_id, user_name or current_user.email),
-daemon=True
-).start()
+asyncio.create_task(
+send_invite_email(
+to_email=invite_user_email,
+invite_url=settings.MAIL_FRONTEND_URL,
+tenant_id=tenant_id,
+inviter=user_name or current_user.email
+)
+)
+except Exception as e:
+logging.exception(f"Failed to send invite email to {invite_user_email}: {e}")
+return get_json_result(data=False, message="Failed to send invite email.", code=RetCode.SERVER_ERROR)
 usr = invite_users[0].to_dict()
 usr = {k: v for k, v in usr.items() if k in ["id", "avatar", "email", "nickname"]}

@@ -21,8 +21,9 @@ import re
 import secrets
 import time
 from datetime import datetime
+import base64

-from quart import redirect, request, session, make_response
+from quart import make_response, redirect, request, session
 from werkzeug.security import check_password_hash, generate_password_hash

 from api.apps.auth import get_auth_client

@@ -39,12 +40,13 @@ from common.connection_utils import construct_response
 from api.utils.api_utils import (
 get_data_error_result,
 get_json_result,
+get_request_json,
 server_error_response,
 validate_request,
 )
 from api.utils.crypt import decrypt
 from rag.utils.redis_conn import REDIS_CONN
-from api.apps import smtp_mail_server, login_required, current_user, login_user, logout_user
+from api.apps import login_required, current_user, login_user, logout_user
 from api.utils.web_utils import (
 send_email_html,
 OTP_LENGTH,

@@ -57,6 +59,7 @@ from api.utils.web_utils import (
 captcha_key,
 )
 from common import settings
+from common.http_client import async_request


@manager.route("/login", methods=["POST", "GET"]) # noqa: F821
|
@manager.route("/login", methods=["POST", "GET"]) # noqa: F821
|
||||||
@ -90,11 +93,14 @@ async def login():
|
|||||||
schema:
|
schema:
|
||||||
type: object
|
type: object
|
||||||
"""
|
"""
|
||||||
json_body = await request.json
|
json_body = await get_request_json()
|
||||||
if not json_body:
|
if not json_body:
|
||||||
return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!")
|
return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!")
|
||||||
|
|
||||||
email = json_body.get("email", "")
|
email = json_body.get("email", "")
|
||||||
|
if email == "admin@ragflow.io":
|
||||||
|
return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="Default admin account cannot be used to login normal services!")
|
||||||
|
|
||||||
users = UserService.query(email=email)
|
users = UserService.query(email=email)
|
||||||
if not users:
|
if not users:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
@ -121,8 +127,8 @@ async def login():
|
|||||||
response_data = user.to_json()
|
response_data = user.to_json()
|
||||||
user.access_token = get_uuid()
|
user.access_token = get_uuid()
|
||||||
login_user(user)
|
login_user(user)
|
||||||
user.update_time = (current_timestamp(),)
|
user.update_time = current_timestamp()
|
||||||
user.update_date = (datetime_format(datetime.now()),)
|
user.update_date = datetime_format(datetime.now())
|
||||||
user.save()
|
user.save()
|
||||||
msg = "Welcome back!"
|
msg = "Welcome back!"
|
||||||
|
|
||||||
@ -136,7 +142,7 @@ async def login():
|
|||||||
|
|
||||||
|
|
||||||
@manager.route("/login/channels", methods=["GET"]) # noqa: F821
|
@manager.route("/login/channels", methods=["GET"]) # noqa: F821
|
||||||
def get_login_channels():
|
async def get_login_channels():
|
||||||
"""
|
"""
|
||||||
Get all supported authentication channels.
|
Get all supported authentication channels.
|
||||||
"""
|
"""
|
||||||
@ -157,7 +163,7 @@ def get_login_channels():
|
|||||||
|
|
||||||
|
|
||||||
@manager.route("/login/<channel>", methods=["GET"]) # noqa: F821
|
@manager.route("/login/<channel>", methods=["GET"]) # noqa: F821
|
||||||
def oauth_login(channel):
|
async def oauth_login(channel):
|
||||||
channel_config = settings.OAUTH_CONFIG.get(channel)
|
channel_config = settings.OAUTH_CONFIG.get(channel)
|
||||||
if not channel_config:
|
if not channel_config:
|
||||||
raise ValueError(f"Invalid channel name: {channel}")
|
raise ValueError(f"Invalid channel name: {channel}")
|
||||||
@ -170,7 +176,7 @@ def oauth_login(channel):
|
|||||||
|
|
||||||
|
|
||||||
@manager.route("/oauth/callback/<channel>", methods=["GET"]) # noqa: F821
|
@manager.route("/oauth/callback/<channel>", methods=["GET"]) # noqa: F821
|
||||||
def oauth_callback(channel):
|
async def oauth_callback(channel):
|
||||||
"""
|
"""
|
||||||
Handle the OAuth/OIDC callback for various channels dynamically.
|
Handle the OAuth/OIDC callback for various channels dynamically.
|
||||||
"""
|
"""
|
||||||
@ -192,6 +198,9 @@ def oauth_callback(channel):
|
|||||||
return redirect("/?error=missing_code")
|
return redirect("/?error=missing_code")
|
||||||
|
|
||||||
# Exchange authorization code for access token
|
# Exchange authorization code for access token
|
||||||
|
if hasattr(auth_cli, "async_exchange_code_for_token"):
|
||||||
|
token_info = await auth_cli.async_exchange_code_for_token(code)
|
||||||
|
else:
|
||||||
token_info = auth_cli.exchange_code_for_token(code)
|
token_info = auth_cli.exchange_code_for_token(code)
|
||||||
access_token = token_info.get("access_token")
|
access_token = token_info.get("access_token")
|
||||||
if not access_token:
|
if not access_token:
|
||||||
@ -200,6 +209,9 @@ def oauth_callback(channel):
|
|||||||
id_token = token_info.get("id_token")
|
id_token = token_info.get("id_token")
|
||||||
|
|
||||||
# Fetch user info
|
# Fetch user info
|
||||||
|
if hasattr(auth_cli, "async_fetch_user_info"):
|
||||||
|
user_info = await auth_cli.async_fetch_user_info(access_token, id_token=id_token)
|
||||||
|
else:
|
||||||
user_info = auth_cli.fetch_user_info(access_token, id_token=id_token)
|
user_info = auth_cli.fetch_user_info(access_token, id_token=id_token)
|
||||||
if not user_info.email:
|
if not user_info.email:
|
||||||
return redirect("/?error=email_missing")
|
return redirect("/?error=email_missing")
|
||||||
@ -259,7 +271,7 @@ def oauth_callback(channel):
|
|||||||
|
|
||||||
|
|
||||||
@manager.route("/github_callback", methods=["GET"]) # noqa: F821
|
@manager.route("/github_callback", methods=["GET"]) # noqa: F821
|
||||||
def github_callback():
|
async def github_callback():
|
||||||
"""
|
"""
|
||||||
**Deprecated**, Use `/oauth/callback/<channel>` instead.
|
**Deprecated**, Use `/oauth/callback/<channel>` instead.
|
||||||
|
|
||||||
@ -279,9 +291,8 @@ def github_callback():
|
|||||||
schema:
|
schema:
|
||||||
type: object
|
type: object
|
||||||
"""
|
"""
|
||||||
import requests
|
res = await async_request(
|
||||||
|
"POST",
|
||||||
res = requests.post(
|
|
||||||
settings.GITHUB_OAUTH.get("url"),
|
settings.GITHUB_OAUTH.get("url"),
|
||||||
data={
|
data={
|
||||||
"client_id": settings.GITHUB_OAUTH.get("client_id"),
|
"client_id": settings.GITHUB_OAUTH.get("client_id"),
|
||||||
@ -299,7 +310,7 @@ def github_callback():
|
|||||||
|
|
||||||
session["access_token"] = res["access_token"]
|
session["access_token"] = res["access_token"]
|
||||||
session["access_token_from"] = "github"
|
session["access_token_from"] = "github"
|
||||||
user_info = user_info_from_github(session["access_token"])
|
user_info = await user_info_from_github(session["access_token"])
|
||||||
email_address = user_info["email"]
|
email_address = user_info["email"]
|
||||||
users = UserService.query(email=email_address)
|
users = UserService.query(email=email_address)
|
||||||
user_id = get_uuid()
|
user_id = get_uuid()
|
||||||
@ -348,7 +359,7 @@ def github_callback():
|
|||||||
|
|
||||||
|
|
||||||
@manager.route("/feishu_callback", methods=["GET"]) # noqa: F821
|
@manager.route("/feishu_callback", methods=["GET"]) # noqa: F821
|
||||||
def feishu_callback():
|
async def feishu_callback():
|
||||||
"""
|
"""
|
||||||
Feishu OAuth callback endpoint.
|
Feishu OAuth callback endpoint.
|
||||||
---
|
---
|
||||||
@ -366,9 +377,8 @@ def feishu_callback():
|
|||||||
schema:
|
schema:
|
||||||
type: object
|
type: object
|
||||||
"""
|
"""
|
||||||
import requests
|
app_access_token_res = await async_request(
|
||||||
|
"POST",
|
||||||
app_access_token_res = requests.post(
|
|
||||||
settings.FEISHU_OAUTH.get("app_access_token_url"),
|
settings.FEISHU_OAUTH.get("app_access_token_url"),
|
||||||
data=json.dumps(
|
data=json.dumps(
|
||||||
{
|
{
|
||||||
@ -382,7 +392,8 @@ def feishu_callback():
|
|||||||
if app_access_token_res["code"] != 0:
|
if app_access_token_res["code"] != 0:
|
||||||
return redirect("/?error=%s" % app_access_token_res)
|
return redirect("/?error=%s" % app_access_token_res)
|
||||||
|
|
||||||
res = requests.post(
|
res = await async_request(
|
||||||
|
"POST",
|
||||||
settings.FEISHU_OAUTH.get("user_access_token_url"),
|
settings.FEISHU_OAUTH.get("user_access_token_url"),
|
||||||
data=json.dumps(
|
data=json.dumps(
|
||||||
{
|
{
|
||||||
@ -403,7 +414,7 @@ def feishu_callback():
|
|||||||
return redirect("/?error=contact:user.email:readonly not in scope")
|
return redirect("/?error=contact:user.email:readonly not in scope")
|
||||||
session["access_token"] = res["data"]["access_token"]
|
session["access_token"] = res["data"]["access_token"]
|
||||||
session["access_token_from"] = "feishu"
|
session["access_token_from"] = "feishu"
|
||||||
user_info = user_info_from_feishu(session["access_token"])
|
user_info = await user_info_from_feishu(session["access_token"])
|
||||||
email_address = user_info["email"]
|
email_address = user_info["email"]
|
||||||
users = UserService.query(email=email_address)
|
users = UserService.query(email=email_address)
|
||||||
user_id = get_uuid()
|
user_id = get_uuid()
|
||||||
@ -451,36 +462,34 @@ def feishu_callback():
|
|||||||
return redirect("/?auth=%s" % user.get_id())
|
return redirect("/?auth=%s" % user.get_id())
|
||||||
|
|
||||||
|
|
||||||
def user_info_from_feishu(access_token):
|
async def user_info_from_feishu(access_token):
|
||||||
import requests
|
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json; charset=utf-8",
|
"Content-Type": "application/json; charset=utf-8",
|
||||||
"Authorization": f"Bearer {access_token}",
|
"Authorization": f"Bearer {access_token}",
|
||||||
}
|
}
|
||||||
res = requests.get("https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers)
|
res = await async_request("GET", "https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers)
|
||||||
user_info = res.json()["data"]
|
user_info = res.json()["data"]
|
||||||
user_info["email"] = None if user_info.get("email") == "" else user_info["email"]
|
user_info["email"] = None if user_info.get("email") == "" else user_info["email"]
|
||||||
return user_info
|
return user_info
|
||||||
|
|
||||||
|
|
||||||
def user_info_from_github(access_token):
|
async def user_info_from_github(access_token):
|
||||||
import requests
|
|
||||||
|
|
||||||
headers = {"Accept": "application/json", "Authorization": f"token {access_token}"}
|
headers = {"Accept": "application/json", "Authorization": f"token {access_token}"}
|
||||||
res = requests.get(f"https://api.github.com/user?access_token={access_token}", headers=headers)
|
res = await async_request("GET", f"https://api.github.com/user?access_token={access_token}", headers=headers)
|
||||||
user_info = res.json()
|
user_info = res.json()
|
||||||
email_info = requests.get(
|
email_info_response = await async_request(
|
||||||
|
"GET",
|
||||||
f"https://api.github.com/user/emails?access_token={access_token}",
|
f"https://api.github.com/user/emails?access_token={access_token}",
|
||||||
headers=headers,
|
headers=headers,
|
||||||
).json()
|
)
|
||||||
|
email_info = email_info_response.json()
|
||||||
user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
|
user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
|
||||||
return user_info
|
return user_info
|
||||||
|
|
||||||
|
|
||||||
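The blocking requests calls above are replaced with async_request from common.http_client, which is awaited and whose result is read with .json(). That helper is not shown in this diff; a plausible shape, assuming an httpx-based wrapper, would be:

    # Assumed shape of common.http_client.async_request; the real implementation may differ.
    import httpx

    async def async_request(method: str, url: str, **kwargs) -> httpx.Response:
        # Thin async wrapper: one-shot request; the response exposes .json() like requests does.
        async with httpx.AsyncClient(timeout=30) as client:
            return await client.request(method, url, **kwargs)
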
 @manager.route("/logout", methods=["GET"])  # noqa: F821
 @login_required
-def log_out():
+async def log_out():
     """
     User logout endpoint.
     ---
@@ -531,7 +540,7 @@ async def setting_user():
                 type: object
     """
     update_dict = {}
-    request_data = await request.json
+    request_data = await get_request_json()
     if request_data.get("password"):
         new_password = request_data.get("new_password")
         if not check_password_hash(current_user.password, decrypt(request_data["password"])):
@@ -570,7 +579,7 @@ async def setting_user():

 @manager.route("/info", methods=["GET"])  # noqa: F821
 @login_required
-def user_profile():
+async def user_profile():
     """
     Get user profile information.
     ---
@@ -698,7 +707,7 @@ async def user_add():
             code=RetCode.OPERATING_ERROR,
         )

-    req = await request.json
+    req = await get_request_json()
     email_address = req["email"]

     # Validate the email address
@@ -755,7 +764,7 @@ async def user_add():

 @manager.route("/tenant_info", methods=["GET"])  # noqa: F821
 @login_required
-def tenant_info():
+async def tenant_info():
     """
     Get tenant information.
     ---
@@ -831,7 +840,7 @@ async def set_tenant_info():
             schema:
                 type: object
     """
-    req = await request.json
+    req = await get_request_json()
     try:
         tid = req.pop("tenant_id")
         TenantService.update_by_id(tid, req)
@@ -875,7 +884,7 @@ async def forget_send_otp():
     - Verify the image captcha stored at captcha:{email} (case-insensitive).
     - On success, generate an email OTP (A–Z with length = OTP_LENGTH), store hash + salt (and timestamp) in Redis with TTL, reset attempts and cooldown, and send the OTP via email.
     """
-    req = await request.get_json()
+    req = await get_request_json()
     email = req.get("email") or ""
     captcha = (req.get("captcha") or "").strip()

@@ -918,47 +927,45 @@ async def forget_send_otp():

     ttl_min = OTP_TTL_SECONDS // 60

-    if not smtp_mail_server:
-        logging.warning("SMTP mail server not initialized; skip sending email.")
-    else:
     try:
-        send_email_html(
+        await send_email_html(
             subject="Your Password Reset Code",
             to_email=email,
             template_key="reset_code",
             code=otp,
             ttl_min=ttl_min,
         )
-    except Exception:
+    except Exception as e:
+        logging.exception(e)
         return get_json_result(data=False, code=RetCode.SERVER_ERROR, message="failed to send email")

     return get_json_result(data=True, code=RetCode.SUCCESS, message="verification passed, email sent")


-@manager.route("/forget", methods=["POST"])  # noqa: F821
-async def forget():
+def _verified_key(email: str) -> str:
+    return f"otp:verified:{email}"
+
+
+@manager.route("/forget/verify-otp", methods=["POST"])  # noqa: F821
+async def forget_verify_otp():
     """
-    POST: Verify email + OTP and reset password, then log the user in.
-    Request JSON: { email, otp, new_password, confirm_new_password }
+    Verify email + OTP only. On success:
+    - consume the OTP and attempt counters
+    - set a short-lived verified flag in Redis for the email
+    Request JSON: { email, otp }
     """
-    req = await request.get_json()
+    req = await get_request_json()
     email = req.get("email") or ""
     otp = (req.get("otp") or "").strip()
-    new_pwd = req.get("new_password")
-    new_pwd2 = req.get("confirm_new_password")

-    if not all([email, otp, new_pwd, new_pwd2]):
-        return get_json_result(data=False, code=RetCode.ARGUMENT_ERROR, message="email, otp and passwords are required")
+    if not all([email, otp]):
+        return get_json_result(data=False, code=RetCode.ARGUMENT_ERROR, message="email and otp are required")

-    # For reset, passwords are provided as-is (no decrypt needed)
-    if new_pwd != new_pwd2:
-        return get_json_result(data=False, code=RetCode.ARGUMENT_ERROR, message="passwords do not match")

     users = UserService.query(email=email)
     if not users:
         return get_json_result(data=False, code=RetCode.DATA_ERROR, message="invalid email")

-    user = users[0]
     # Verify OTP from Redis
     k_code, k_attempts, k_last, k_lock = otp_keys(email)
     if REDIS_CONN.get(k_lock):
@@ -974,7 +981,6 @@ async def forget():
     except Exception:
         return get_json_result(data=False, code=RetCode.EXCEPTION_ERROR, message="otp storage corrupted")

-    # Case-insensitive verification: OTP generated uppercase
     calc = hash_code(otp.upper(), salt)
     if calc != stored_hash:
         # bump attempts
@@ -987,23 +993,70 @@ async def forget():
             REDIS_CONN.set(k_lock, int(time.time()), ATTEMPT_LOCK_SECONDS)
         return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="expired otp")

-    # Success: consume OTP and reset password
+    # Success: consume OTP and attempts; mark verified
     REDIS_CONN.delete(k_code)
     REDIS_CONN.delete(k_attempts)
     REDIS_CONN.delete(k_last)
     REDIS_CONN.delete(k_lock)

+    # set verified flag with limited TTL, reuse OTP_TTL_SECONDS or smaller window
     try:
-        UserService.update_user_password(user.id, new_pwd)
+        REDIS_CONN.set(_verified_key(email), "1", OTP_TTL_SECONDS)
+    except Exception:
+        return get_json_result(data=False, code=RetCode.SERVER_ERROR, message="failed to set verification state")
+
+    return get_json_result(data=True, code=RetCode.SUCCESS, message="otp verified")
+
+
+@manager.route("/forget/reset-password", methods=["POST"])  # noqa: F821
+async def forget_reset_password():
+    """
+    Reset password after successful OTP verification.
+    Requires: { email, new_password, confirm_new_password }
+    Steps:
+    - check verified flag in Redis
+    - update user password
+    - auto login
+    - clear verified flag
+    """
+
+    req = await get_request_json()
+    email = req.get("email") or ""
+    new_pwd = req.get("new_password")
+    new_pwd2 = req.get("confirm_new_password")
+
+    new_pwd_base64 = decrypt(new_pwd)
+    new_pwd_string = base64.b64decode(new_pwd_base64).decode('utf-8')
+    new_pwd2_string = base64.b64decode(decrypt(new_pwd2)).decode('utf-8')
+
+    REDIS_CONN.get(_verified_key(email))
+    if not REDIS_CONN.get(_verified_key(email)):
+        return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="email not verified")
+
+    if not all([email, new_pwd, new_pwd2]):
+        return get_json_result(data=False, code=RetCode.ARGUMENT_ERROR, message="email and passwords are required")
+
+    if new_pwd_string != new_pwd2_string:
+        return get_json_result(data=False, code=RetCode.ARGUMENT_ERROR, message="passwords do not match")
+
+    users = UserService.query_user_by_email(email=email)
+    if not users:
+        return get_json_result(data=False, code=RetCode.DATA_ERROR, message="invalid email")
+
+    user = users[0]
+    try:
+        UserService.update_user_password(user.id, new_pwd_base64)
     except Exception as e:
         logging.exception(e)
         return get_json_result(data=False, code=RetCode.EXCEPTION_ERROR, message="failed to reset password")

-    # Auto login (reuse login flow)
-    user.access_token = get_uuid()
-    login_user(user)
-    user.update_time = (current_timestamp(),)
-    user.update_date = (datetime_format(datetime.now()),)
-    user.save()
+    # clear verified flag
+    try:
+        REDIS_CONN.delete(_verified_key(email))
+    except Exception:
+        pass

     msg = "Password reset successful. Logged in."
-    return construct_response(data=user.to_json(), auth=user.get_id(), message=msg)
+    return await construct_response(data=user.to_json(), auth=user.get_id(), message=msg)

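Taken together, the hunks above split the old single /forget endpoint into a verify step and a reset step guarded by a short-lived Redis flag. A rough client-side sketch of the resulting flow (the URL prefix and the encrypted password payload are assumptions, not shown in this diff):

    import requests

    BASE = "http://localhost:9380/v1/user"  # assumed route prefix
    encrypted_pwd = "<rsa-encrypted, base64-wrapped password>"  # same form the login page sends

    # Step 1: exchange the emailed OTP for a short-lived "verified" flag.
    requests.post(f"{BASE}/forget/verify-otp", json={"email": "a@b.c", "otp": "ABCDEF"})

    # Step 2: within the flag's TTL, submit the new password twice; on success the
    # response carries an auth token, i.e. the user is logged in automatically.
    requests.post(f"{BASE}/forget/reset-password",
                  json={"email": "a@b.c",
                        "new_password": encrypted_pwd,
                        "confirm_new_password": encrypted_pwd})
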
@@ -24,3 +24,5 @@ REQUEST_MAX_WAIT_SEC = 300

 DATASET_NAME_LIMIT = 128
 FILE_NAME_LEN_LIMIT = 255
+MEMORY_NAME_LIMIT = 128
+MEMORY_SIZE_LIMIT = 10*1024*1024  # Byte

@@ -749,7 +749,7 @@ class Knowledgebase(DataBaseModel):

     parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value, index=True)
     pipeline_id = CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)
-    parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
+    parser_config = JSONField(null=False, default={"pages": [[1, 1000000]], "table_context_size": 0, "image_context_size": 0})
     pagerank = IntegerField(default=0, index=False)

     graphrag_task_id = CharField(max_length=32, null=True, help_text="Graph RAG task ID", index=True)
@@ -774,7 +774,7 @@ class Document(DataBaseModel):
     kb_id = CharField(max_length=256, null=False, index=True)
     parser_id = CharField(max_length=32, null=False, help_text="default parser ID", index=True)
     pipeline_id = CharField(max_length=32, null=True, help_text="pipeline ID", index=True)
-    parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
+    parser_config = JSONField(null=False, default={"pages": [[1, 1000000]], "table_context_size": 0, "image_context_size": 0})
     source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document come from", index=True)
     type = CharField(max_length=32, null=False, help_text="file extension", index=True)
     created_by = CharField(max_length=32, null=False, help_text="who created it", index=True)
@@ -1113,6 +1113,91 @@ class SyncLogs(DataBaseModel):
         db_table = "sync_logs"

+
+class EvaluationDataset(DataBaseModel):
+    """Ground truth dataset for RAG evaluation"""
+    id = CharField(max_length=32, primary_key=True)
+    tenant_id = CharField(max_length=32, null=False, index=True, help_text="tenant ID")
+    name = CharField(max_length=255, null=False, index=True, help_text="dataset name")
+    description = TextField(null=True, help_text="dataset description")
+    kb_ids = JSONField(null=False, help_text="knowledge base IDs to evaluate against")
+    created_by = CharField(max_length=32, null=False, index=True, help_text="creator user ID")
+    create_time = BigIntegerField(null=False, index=True, help_text="creation timestamp")
+    update_time = BigIntegerField(null=False, help_text="last update timestamp")
+    status = IntegerField(null=False, default=1, help_text="1=valid, 0=invalid")
+
+    class Meta:
+        db_table = "evaluation_datasets"
+
+
+class EvaluationCase(DataBaseModel):
+    """Individual test case in an evaluation dataset"""
+    id = CharField(max_length=32, primary_key=True)
+    dataset_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_datasets")
+    question = TextField(null=False, help_text="test question")
+    reference_answer = TextField(null=True, help_text="optional ground truth answer")
+    relevant_doc_ids = JSONField(null=True, help_text="expected relevant document IDs")
+    relevant_chunk_ids = JSONField(null=True, help_text="expected relevant chunk IDs")
+    metadata = JSONField(null=True, help_text="additional context/tags")
+    create_time = BigIntegerField(null=False, help_text="creation timestamp")
+
+    class Meta:
+        db_table = "evaluation_cases"
+
+
+class EvaluationRun(DataBaseModel):
+    """A single evaluation run"""
+    id = CharField(max_length=32, primary_key=True)
+    dataset_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_datasets")
+    dialog_id = CharField(max_length=32, null=False, index=True, help_text="dialog configuration being evaluated")
+    name = CharField(max_length=255, null=False, help_text="run name")
+    config_snapshot = JSONField(null=False, help_text="dialog config at time of evaluation")
+    metrics_summary = JSONField(null=True, help_text="aggregated metrics")
+    status = CharField(max_length=32, null=False, default="PENDING", help_text="PENDING/RUNNING/COMPLETED/FAILED")
+    created_by = CharField(max_length=32, null=False, index=True, help_text="user who started the run")
+    create_time = BigIntegerField(null=False, index=True, help_text="creation timestamp")
+    complete_time = BigIntegerField(null=True, help_text="completion timestamp")
+
+    class Meta:
+        db_table = "evaluation_runs"
+
+
+class EvaluationResult(DataBaseModel):
+    """Result for a single test case in an evaluation run"""
+    id = CharField(max_length=32, primary_key=True)
+    run_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_runs")
+    case_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_cases")
+    generated_answer = TextField(null=False, help_text="generated answer")
+    retrieved_chunks = JSONField(null=False, help_text="chunks that were retrieved")
+    metrics = JSONField(null=False, help_text="all computed metrics")
+    execution_time = FloatField(null=False, help_text="response time in seconds")
+    token_usage = JSONField(null=True, help_text="prompt/completion tokens")
+    create_time = BigIntegerField(null=False, help_text="creation timestamp")
+
+    class Meta:
+        db_table = "evaluation_results"
+
+
+class Memory(DataBaseModel):
+    id = CharField(max_length=32, primary_key=True)
+    name = CharField(max_length=128, null=False, index=False, help_text="Memory name")
+    avatar = TextField(null=True, help_text="avatar base64 string")
+    tenant_id = CharField(max_length=32, null=False, index=True)
+    memory_type = IntegerField(null=False, default=1, index=True, help_text="Bit flags (LSB->MSB): 1=raw, 2=semantic, 4=episodic, 8=procedural. E.g., 5 enables raw + episodic.")
+    storage_type = CharField(max_length=32, default='table', null=False, index=True, help_text="table|graph")
+    embd_id = CharField(max_length=128, null=False, index=False, help_text="embedding model ID")
+    llm_id = CharField(max_length=128, null=False, index=False, help_text="chat model ID")
+    permissions = CharField(max_length=16, null=False, index=True, help_text="me|team", default="me")
+    description = TextField(null=True, help_text="description")
+    memory_size = IntegerField(default=5242880, null=False, index=False)
+    forgetting_policy = CharField(max_length=32, null=False, default="fifo", index=False, help_text="lru|fifo")
+    temperature = FloatField(default=0.5, index=False)
+    system_prompt = TextField(null=True, help_text="system prompt", index=False)
+    user_prompt = TextField(null=True, help_text="user prompt", index=False)
+
+    class Meta:
+        db_table = "memory"
+
+
 def migrate_db():
     logging.disable(logging.ERROR)
     migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
@@ -1293,4 +1378,43 @@ def migrate_db():
         migrate(migrator.add_column("llm_factories", "rank", IntegerField(default=0, index=False)))
     except Exception:
         pass
+
+    # RAG Evaluation tables
+    try:
+        migrate(migrator.add_column("evaluation_datasets", "id", CharField(max_length=32, primary_key=True)))
+    except Exception:
+        pass
+    try:
+        migrate(migrator.add_column("evaluation_datasets", "tenant_id", CharField(max_length=32, null=False, index=True)))
+    except Exception:
+        pass
+    try:
+        migrate(migrator.add_column("evaluation_datasets", "name", CharField(max_length=255, null=False, index=True)))
+    except Exception:
+        pass
+    try:
+        migrate(migrator.add_column("evaluation_datasets", "description", TextField(null=True)))
+    except Exception:
+        pass
+    try:
+        migrate(migrator.add_column("evaluation_datasets", "kb_ids", JSONField(null=False)))
+    except Exception:
+        pass
+    try:
+        migrate(migrator.add_column("evaluation_datasets", "created_by", CharField(max_length=32, null=False, index=True)))
+    except Exception:
+        pass
+    try:
+        migrate(migrator.add_column("evaluation_datasets", "create_time", BigIntegerField(null=False, index=True)))
+    except Exception:
+        pass
+    try:
+        migrate(migrator.add_column("evaluation_datasets", "update_time", BigIntegerField(null=False)))
+    except Exception:
+        pass
+    try:
+        migrate(migrator.add_column("evaluation_datasets", "status", IntegerField(null=False, default=1)))
+    except Exception:
+        pass
+
     logging.disable(logging.NOTSET)

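The migrate_db hunk above repeats one idiom: attempt each add_column and swallow the exception when the column already exists, so the migration stays idempotent across restarts. A hedged sketch of the same idiom factored into a helper, assuming the peewee/playhouse imports already used by db_models.py (the helper name is illustrative, not part of this diff):

    def add_column_if_missing(migrator, table: str, column: str, field):
        # peewee's migrator raises if the column already exists; treat that as "nothing to do".
        try:
            migrate(migrator.add_column(table, column, field))
        except Exception:
            pass

    # e.g. add_column_if_missing(migrator, "evaluation_datasets", "status", IntegerField(null=False, default=1))
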
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import logging
 import json
 import os
@@ -73,11 +74,10 @@ def init_superuser(nickname=DEFAULT_SUPERUSER_NICKNAME, email=DEFAULT_SUPERUSER_
     UserTenantService.insert(**usr_tenant)
     TenantLLMService.insert_many(tenant_llm)
     logging.info(
-        f"Super user initialized. email: {email}, password: {password}. Changing the password after login is strongly recommended.")
+        f"Super user initialized. email: {email},A default password has been set; changing the password after login is strongly recommended.")

     chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
-    msg = chat_mdl.chat(system="", history=[
-        {"role": "user", "content": "Hello!"}], gen_conf={})
+    msg = asyncio.run(chat_mdl.async_chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={}))
     if msg.find("ERROR: ") == 0:
         logging.error(
             "'{}' doesn't work. {}".format(
@@ -153,7 +153,7 @@ def delete_user_data(user_id: str) -> dict:
             done_msg += "Start to delete owned tenant.\n"
             tenant_id = owned_tenant[0]["tenant_id"]
             kb_ids = KnowledgebaseService.get_kb_ids(usr.id)
-            # step1.1 delete knowledgebase related file and info
+            # step1.1 delete dataset related file and info
             if kb_ids:
                 # step1.1.1 delete files in storage, remove bucket
                 for kb_id in kb_ids:
@@ -182,7 +182,7 @@ def delete_user_data(user_id: str) -> dict:
                     search.index_name(tenant_id), kb_ids)
                 done_msg += f"- Deleted {r} chunk records.\n"
                 kb_delete_res = KnowledgebaseService.delete_by_ids(kb_ids)
-                done_msg += f"- Deleted {kb_delete_res} knowledgebase records.\n"
+                done_msg += f"- Deleted {kb_delete_res} dataset records.\n"
                 # step1.1.4 delete agents
                 agent_delete_res = delete_user_agents(usr.id)
                 done_msg += f"- Deleted {agent_delete_res['agents_deleted_count']} agent, {agent_delete_res['version_deleted_count']} versions records.\n"
@@ -258,7 +258,7 @@ def delete_user_data(user_id: str) -> dict:
             # step2.1.5 delete document record
             doc_delete_res = DocumentService.delete_by_ids([d['id'] for d in created_documents])
             done_msg += f"- Deleted {doc_delete_res} documents.\n"
-            # step2.1.6 update knowledge base doc&chunk&token cnt
+            # step2.1.6 update dataset doc&chunk&token cnt
             for kb_id, doc_num in kb_doc_info.items():
                 KnowledgebaseService.decrease_document_num_in_delete(kb_id, doc_num)

@@ -273,7 +273,7 @@ def delete_user_data(user_id: str) -> dict:

     except Exception as e:
         logging.exception(e)
-        return {"success": False, "message": f"Error: {str(e)}. Already done:\n{done_msg}"}
+        return {"success": False, "message": "An internal error occurred during user deletion. Some operations may have completed.","details": done_msg}


 def delete_user_agents(user_id: str) -> dict:

@@ -169,10 +169,12 @@ class CommonService:
         """
         if "id" not in kwargs:
            kwargs["id"] = get_uuid()
-        kwargs["create_time"] = current_timestamp()
-        kwargs["create_date"] = datetime_format(datetime.now())
-        kwargs["update_time"] = current_timestamp()
-        kwargs["update_date"] = datetime_format(datetime.now())
+        timestamp = current_timestamp()
+        cur_datetime = datetime_format(datetime.now())
+        kwargs["create_time"] = timestamp
+        kwargs["create_date"] = cur_datetime
+        kwargs["update_time"] = timestamp
+        kwargs["update_date"] = cur_datetime
         sample_obj = cls.model(**kwargs).save(force_insert=True)
         return sample_obj

@@ -207,10 +209,14 @@ class CommonService:
             data_list (list): List of dictionaries containing record data to update.
                 Each dictionary must include an 'id' field.
         """
+        timestamp = current_timestamp()
+        cur_datetime = datetime_format(datetime.now())
+        for data in data_list:
+            data["update_time"] = timestamp
+            data["update_date"] = cur_datetime
         with DB.atomic():
             for data in data_list:
-                data["update_time"] = current_timestamp()
-                data["update_date"] = datetime_format(datetime.now())
                 cls.model.update(data).where(cls.model.id == data["id"]).execute()

     @classmethod

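The CommonService hunks above read the clock once per call instead of once per field (or once per row), so create_time always equals update_time on insert and every row in one update batch carries the same stamp. A minimal illustration of the idea, using plain time.time() in place of the project's helpers:

    import time

    def stamp_once(rows):
        # One clock read for the whole batch keeps the rows mutually consistent;
        # calling time.time() inside the loop could straddle a millisecond boundary.
        ts = int(time.time() * 1000)
        for row in rows:
            row["update_time"] = ts
        return rows

    # stamp_once([{"id": 1}, {"id": 2}]) -> both rows share the identical update_time
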
@@ -19,7 +19,7 @@ from common.constants import StatusEnum
 from api.db.db_models import Conversation, DB
 from api.db.services.api_service import API4ConversationService
 from api.db.services.common_service import CommonService
-from api.db.services.dialog_service import DialogService, chat
+from api.db.services.dialog_service import DialogService, async_chat
 from common.misc_utils import get_uuid
 import json

@@ -89,8 +89,7 @@ def structure_answer(conv, ans, message_id, session_id):
         conv.reference[-1] = reference
     return ans

-def completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
+async def async_completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
     assert name, "`name` can not be empty."
     dia = DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
     assert dia, "You do not own the chat."
@@ -148,7 +147,7 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None

     if stream:
         try:
-            for ans in chat(dia, msg, True, **kwargs):
+            async for ans in async_chat(dia, msg, True, **kwargs):
                 ans = structure_answer(conv, ans, message_id, session_id)
                 yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
             ConversationService.update_by_id(conv.id, conv.to_dict())
@@ -160,14 +159,13 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None

     else:
         answer = None
-        for ans in chat(dia, msg, False, **kwargs):
+        async for ans in async_chat(dia, msg, False, **kwargs):
             answer = structure_answer(conv, ans, message_id, session_id)
             ConversationService.update_by_id(conv.id, conv.to_dict())
             break
         yield answer

-def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs):
+async def async_iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs):
     e, dia = DialogService.get_by_id(dialog_id)
     assert e, "Dialog not found"
     if not session_id:
@@ -222,7 +220,7 @@ def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwarg

     if stream:
         try:
-            for ans in chat(dia, msg, True, **kwargs):
+            async for ans in async_chat(dia, msg, True, **kwargs):
                 ans = structure_answer(conv, ans, message_id, session_id)
                 yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
                                            ensure_ascii=False) + "\n\n"
@@ -235,7 +233,7 @@ def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwarg

     else:
         answer = None
-        for ans in chat(dia, msg, False, **kwargs):
+        async for ans in async_chat(dia, msg, False, **kwargs):
             answer = structure_answer(conv, ans, message_id, session_id)
             API4ConversationService.append_message(conv.id, conv.to_dict())
             break

@@ -21,10 +21,10 @@ from copy import deepcopy
 from datetime import datetime
 from functools import partial
 from timeit import default_timer as timer
-import trio
 from langfuse import Langfuse
 from peewee import fn
 from agentic_reasoning import DeepResearcher
+from api.db.services.file_service import FileService
 from common.constants import LLMType, ParserType, StatusEnum
 from api.db.db_models import DB, Dialog
 from api.db.services.common_service import CommonService
@@ -32,6 +32,7 @@ from api.db.services.document_service import DocumentService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.langfuse_service import TenantLangfuseService
 from api.db.services.llm_service import LLMBundle
+from common.metadata_utils import apply_meta_data_filter
 from api.db.services.tenant_llm_service import TenantLLMService
 from common.time_utils import current_timestamp, datetime_format
 from graphrag.general.mind_map_extractor import MindMapExtractor
@@ -39,7 +40,7 @@ from rag.app.resume import forbidden_select_fields4resume
 from rag.app.tag import label_question
 from rag.nlp.search import index_name
 from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \
-    gen_meta_filter, PROMPT_JINJA_ENV, ASK_SUMMARY
+    PROMPT_JINJA_ENV, ASK_SUMMARY
 from common.token_utils import num_tokens_from_string
 from rag.utils.tavily_conn import Tavily
 from common.string_utils import remove_redundant_spaces
@@ -177,7 +178,11 @@ class DialogService(CommonService):
             offset += limit
         return res

-def chat_solo(dialog, messages, stream=True):
+async def async_chat_solo(dialog, messages, stream=True):
+    attachments = ""
+    if "files" in messages[-1]:
+        attachments = "\n\n".join(FileService.get_files(messages[-1]["files"]))
     if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
     else:
@@ -188,10 +193,13 @@ def chat_solo(dialog, messages, stream=True):
     if prompt_config.get("tts"):
         tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
     msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"]
+    if attachments and msg:
+        msg[-1]["content"] += attachments
     if stream:
         last_ans = ""
         delta_ans = ""
-        for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
+        answer = ""
+        async for ans in chat_mdl.async_chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
             answer = ans
             delta_ans = ans[len(last_ans):]
             if num_tokens_from_string(delta_ans) < 16:
@@ -202,7 +210,7 @@ def chat_solo(dialog, messages, stream=True):
         if delta_ans:
             yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
     else:
-        answer = chat_mdl.chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
+        answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
         user_content = msg[-1].get("content", "[content not available]")
         logging.debug("User: {}|Assistant: {}".format(user_content, answer))
         yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}
@@ -270,84 +278,12 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
     return answer, idx


-def convert_conditions(metadata_condition):
-    if metadata_condition is None:
-        metadata_condition = {}
-    op_mapping = {
-        "is": "=",
-        "not is": "≠"
-    }
-    return [
-        {
-            "op": op_mapping.get(cond["comparison_operator"], cond["comparison_operator"]),
-            "key": cond["name"],
-            "value": cond["value"]
-        }
-        for cond in metadata_condition.get("conditions", [])
-    ]
-
-
-def meta_filter(metas: dict, filters: list[dict], logic: str = "and"):
-    doc_ids = set([])
-
-    def filter_out(v2docs, operator, value):
-        ids = []
-        for input, docids in v2docs.items():
-            if operator in ["=", "≠", ">", "<", "≥", "≤"]:
-                try:
-                    input = float(input)
-                    value = float(value)
-                except Exception:
-                    input = str(input)
-                    value = str(value)
-
-            for conds in [
-                (operator == "contains", str(value).lower() in str(input).lower()),
-                (operator == "not contains", str(value).lower() not in str(input).lower()),
-                (operator == "in", str(input).lower() in str(value).lower()),
-                (operator == "not in", str(input).lower() not in str(value).lower()),
-                (operator == "start with", str(input).lower().startswith(str(value).lower())),
-                (operator == "end with", str(input).lower().endswith(str(value).lower())),
-                (operator == "empty", not input),
-                (operator == "not empty", input),
-                (operator == "=", input == value),
-                (operator == "≠", input != value),
-                (operator == ">", input > value),
-                (operator == "<", input < value),
-                (operator == "≥", input >= value),
-                (operator == "≤", input <= value),
-            ]:
-                try:
-                    if all(conds):
-                        ids.extend(docids)
-                        break
-                except Exception:
-                    pass
-        return ids
-
-    for k, v2docs in metas.items():
-        for f in filters:
-            if k != f["key"]:
-                continue
-            ids = filter_out(v2docs, f["op"], f["value"])
-            if not doc_ids:
-                doc_ids = set(ids)
-            else:
-                if logic == "and":
-                    doc_ids = doc_ids & set(ids)
-                else:
-                    doc_ids = doc_ids | set(ids)
-    if not doc_ids:
-        return []
-    return list(doc_ids)
-
-
-def chat(dialog, messages, stream=True, **kwargs):
+async def async_chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
     if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
-        for ans in chat_solo(dialog, messages, stream):
+        async for ans in async_chat_solo(dialog, messages, stream):
             yield ans
-        return None
+        return

     chat_start_ts = timer()

@@ -380,18 +316,21 @@ def chat(dialog, messages, stream=True, **kwargs):
     retriever = settings.retriever
     questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
     attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else []
+    attachments_= ""
     if "doc_ids" in messages[-1]:
         attachments = messages[-1]["doc_ids"]
+    if "files" in messages[-1]:
+        attachments_ = "\n\n".join(FileService.get_files(messages[-1]["files"]))

     prompt_config = dialog.prompt_config
     field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
     # try to use sql if field mapping is good to go
     if field_map:
         logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
-        ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
+        ans = await use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
         if ans:
             yield ans
-            return None
+            return

     for p in prompt_config["parameters"]:
         if p["key"] == "knowledge":
@@ -402,28 +341,25 @@ def chat(dialog, messages, stream=True, **kwargs):
             prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")

     if len(questions) > 1 and prompt_config.get("refine_multiturn"):
-        questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
+        questions = [await full_question(dialog.tenant_id, dialog.llm_id, messages)]
     else:
         questions = questions[-1:]

     if prompt_config.get("cross_languages"):
-        questions = [cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])]
+        questions = [await cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])]

     if dialog.meta_data_filter:
         metas = DocumentService.get_meta_by_kbs(dialog.kb_ids)
-        if dialog.meta_data_filter.get("method") == "auto":
-            filters: dict = gen_meta_filter(chat_mdl, metas, questions[-1])
-            attachments.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
-            if not attachments:
-                attachments = None
-        elif dialog.meta_data_filter.get("method") == "manual":
-            conds = dialog.meta_data_filter["manual"]
-            attachments.extend(meta_filter(metas, conds, dialog.meta_data_filter.get("logic", "and")))
-            if conds and not attachments:
-                attachments = ["-999"]
+        attachments = await apply_meta_data_filter(
+            dialog.meta_data_filter,
+            metas,
+            questions[-1],
+            chat_mdl,
+            attachments,
+        )

     if prompt_config.get("keyword", False):
-        questions[-1] += keyword_extraction(chat_mdl, questions[-1])
+        questions[-1] += await keyword_extraction(chat_mdl, questions[-1])

     refine_question_ts = timer()

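The hunk above collapses the inline auto/manual metadata-filter branches into a single awaited call to common.metadata_utils.apply_meta_data_filter, whose body is not part of this excerpt. A hedged sketch of what such a helper plausibly does, reconstructed from the removed logic (gen_meta_filter and meta_filter are assumed to have moved there; exact signatures and await usage are guesses):

    async def apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids):
        # Hypothetical sketch; the real helper lives in common/metadata_utils.py.
        if meta_data_filter.get("method") == "auto":
            # Let the chat model propose filter conditions, then resolve them to doc ids.
            filters = await gen_meta_filter(chat_mdl, metas, question)
            doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
            if not doc_ids:
                doc_ids = None
        elif meta_data_filter.get("method") == "manual":
            conds = meta_data_filter["manual"]
            doc_ids.extend(meta_filter(metas, conds, meta_data_filter.get("logic", "and")))
            if conds and not doc_ids:
                doc_ids = ["-999"]  # sentinel that matches no document
        return doc_ids
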
@@ -451,7 +387,7 @@ def chat(dialog, messages, stream=True, **kwargs):
             ),
         )

-        for think in reasoner.thinking(kbinfos, " ".join(questions)):
+        async for think in reasoner.thinking(kbinfos, attachments_ + " ".join(questions)):
             if isinstance(think, str):
                 thought = think
                 knowledges = [t for t in think.split("\n") if t]
@@ -478,6 +414,7 @@ def chat(dialog, messages, stream=True, **kwargs):
                 cks = retriever.retrieval_by_toc(" ".join(questions), kbinfos["chunks"], tenant_ids, chat_mdl, dialog.top_n)
                 if cks:
                     kbinfos["chunks"] = cks
+            kbinfos["chunks"] = retriever.retrieval_by_children(kbinfos["chunks"], tenant_ids)
             if prompt_config.get("tavily_api_key"):
                 tav = Tavily(prompt_config["tavily_api_key"])
                 tav_res = tav.retrieve_chunks(" ".join(questions))
@@ -498,12 +435,13 @@ def chat(dialog, messages, stream=True, **kwargs):
             empty_res = prompt_config["empty_response"]
             yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
                    "audio_binary": tts(tts_mdl, empty_res)}
-            return {"answer": prompt_config["empty_response"], "reference": kbinfos}
+            yield {"answer": prompt_config["empty_response"], "reference": kbinfos}
+            return

     kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
     gen_conf = dialog.llm_setting

-    msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
+    msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)+attachments_}]
     prompt4citation = ""
     if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
         prompt4citation = citation_prompt()
@@ -602,7 +540,7 @@ def chat(dialog, messages, stream=True, **kwargs):
     if stream:
         last_ans = ""
         answer = ""
-        for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
+        async for ans in chat_mdl.async_chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
             if thought:
                 ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
             answer = ans
@@ -616,17 +554,17 @@ def chat(dialog, messages, stream=True, **kwargs):
             yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
         yield decorate_answer(thought + answer)
     else:
-        answer = chat_mdl.chat(prompt + prompt4citation, msg[1:], gen_conf)
+        answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf)
         user_content = msg[-1].get("content", "[content not available]")
         logging.debug("User: {}|Assistant: {}".format(user_content, answer))
         res = decorate_answer(answer)
         res["audio_binary"] = tts(tts_mdl, answer)
         yield res

-    return None
+    return

-def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
+async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
     sys_prompt = """
 You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question.
 Ensure that:
@@ -644,9 +582,9 @@ Please write the SQL, only SQL, without any other explanations or text.
 """.format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question)
     tried_times = 0

-    def get_table():
+    async def get_table():
         nonlocal sys_prompt, user_prompt, question, tried_times
-        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
+        sql = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
         sql = re.sub(r"^.*</think>", "", sql, flags=re.DOTALL)
         logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
         sql = re.sub(r"[\r\n]+", " ", sql.lower())
@@ -672,6 +610,10 @@ Please write the SQL, only SQL, without any other explanations or text.
         if kb_ids:
             kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"
             if "where" not in sql.lower():
+                o = sql.lower().split("order by")
+                if len(o) > 1:
+                    sql = o[0] + f" WHERE {kb_filter} order by " + o[1]
+                else:
                     sql += f" WHERE {kb_filter}"
             else:
                 sql += f" AND {kb_filter}"
@@ -680,10 +622,9 @@ Please write the SQL, only SQL, without any other explanations or text.
         tried_times += 1
         return settings.retriever.sql_retrieval(sql, format="json"), sql

-    tbl, sql = get_table()
-    if tbl is None:
-        return None
-    if tbl.get("error") and tried_times <= 2:
+    try:
+        tbl, sql = await get_table()
+    except Exception as e:
         user_prompt = """
 Table name: {};
 Table of database fields are as follows:
@@ -697,16 +638,14 @@ Please write the SQL, only SQL, without any other explanations or text.
 The SQL error you provided last time is as follows:
 {}

-Error issued by database as follows:
-{}

 Please correct the error and write SQL again, only SQL, without any other explanations or text.
-""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, sql, tbl["error"])
-        tbl, sql = get_table()
-        logging.debug("TRY it again: {}".format(sql))
+""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, e)
+        try:
+            tbl, sql = await get_table()
+        except Exception:
+            return

-    logging.debug("GET table: {}".format(tbl))
-    if tbl.get("error") or len(tbl["rows"]) == 0:
+    if len(tbl["rows"]) == 0:
         return None

     docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
@@ -750,17 +689,51 @@ Please write the SQL, only SQL, without any other explanations or text.
         "prompt": sys_prompt,
     }

+def clean_tts_text(text: str) -> str:
+    if not text:
+        return ""
+
+    text = text.encode("utf-8", "ignore").decode("utf-8", "ignore")
+
+    text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text)
+
+    emoji_pattern = re.compile(
+        "[\U0001F600-\U0001F64F"
+        "\U0001F300-\U0001F5FF"
+        "\U0001F680-\U0001F6FF"
+        "\U0001F1E0-\U0001F1FF"
+        "\U00002700-\U000027BF"
+        "\U0001F900-\U0001F9FF"
+        "\U0001FA70-\U0001FAFF"
+        "\U0001FAD0-\U0001FAFF]+",
+        flags=re.UNICODE
+    )
+    text = emoji_pattern.sub("", text)
+
+    text = re.sub(r"\s+", " ", text).strip()
+
+    MAX_LEN = 500
+    if len(text) > MAX_LEN:
+        text = text[:MAX_LEN]
+
+    return text
+
 def tts(tts_mdl, text):
     if not tts_mdl or not text:
         return None
+    text = clean_tts_text(text)
+    if not text:
+        return None
     bin = b""
+    try:
         for chunk in tts_mdl.tts(text):
             bin += chunk
+    except Exception as e:
+        logging.error(f"TTS failed: {e}, text={text!r}")
+        return None
     return binascii.hexlify(bin).decode("utf-8")
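The new clean_tts_text helper above strips control characters and emoji, collapses whitespace, and truncates to 500 characters before text reaches the TTS model, while tts() now swallows provider errors instead of raising mid-stream. A quick illustration of the sanitizer's behaviour (inputs are made up; assumes the function is importable from the dialog service module):

    sample = "Hello\u0007 world 😀   !"          # BEL control char + emoji + extra spaces
    print(clean_tts_text(sample))               # -> "Hello world !"
    print(len(clean_tts_text("x" * 1000)))      # -> 500, i.e. truncated to MAX_LEN
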

-def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
+async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
     doc_ids = search_config.get("doc_ids", [])
     rerank_mdl = None
     kb_ids = search_config.get("kb_ids", kb_ids)
@@ -783,15 +756,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):

     if meta_data_filter:
         metas = DocumentService.get_meta_by_kbs(kb_ids)
-        if meta_data_filter.get("method") == "auto":
-            filters: dict = gen_meta_filter(chat_mdl, metas, question)
-            doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
-            if not doc_ids:
-                doc_ids = None
-        elif meta_data_filter.get("method") == "manual":
-            doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
-            if meta_data_filter["manual"] and not doc_ids:
-                doc_ids = ["-999"]
+        doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids)

     kbinfos = retriever.retrieval(
         question=question,
@@ -834,13 +799,13 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
         return {"answer": answer, "reference": refs}

     answer = ""
-    for ans in chat_mdl.chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
+    async for ans in chat_mdl.async_chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
         answer = ans
         yield {"answer": answer, "reference": {}}
     yield decorate_answer(answer)


-def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
+async def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
     meta_data_filter = search_config.get("meta_data_filter", {})
     doc_ids = search_config.get("doc_ids", [])
     rerank_id = search_config.get("rerank_id", "")
@@ -858,15 +823,7 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}):

     if meta_data_filter:
         metas = DocumentService.get_meta_by_kbs(kb_ids)
-        if meta_data_filter.get("method") == "auto":
-            filters: dict = gen_meta_filter(chat_mdl, metas, question)
-            doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
-            if not doc_ids:
-                doc_ids = None
-        elif meta_data_filter.get("method") == "manual":
-            doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
-            if meta_data_filter["manual"] and not doc_ids:
-                doc_ids = ["-999"]
+        doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids)

     ranks = settings.retriever.retrieval(
         question=question,
@@ -884,5 +841,5 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
         rank_feature=label_question(question, kbs),
     )
     mindmap = MindMapExtractor(chat_mdl)
-    mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
+    mind_map = await mindmap([c["content_with_weight"] for c in ranks["chunks"]])
     return mind_map.output

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import json
 import logging
 import random
@@ -22,7 +23,6 @@ from copy import deepcopy
 from datetime import datetime
 from io import BytesIO

-import trio
 import xxhash
 from peewee import fn, Case, JOIN

@@ -79,7 +79,7 @@ class DocumentService(CommonService):
     @classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_list(cls, kb_id, page_number, items_per_page,
|
def get_list(cls, kb_id, page_number, items_per_page,
|
||||||
orderby, desc, keywords, id, name, suffix=None, run = None):
|
orderby, desc, keywords, id, name, suffix=None, run = None, doc_ids=None):
|
||||||
fields = cls.get_cls_model_fields()
|
fields = cls.get_cls_model_fields()
|
||||||
docs = cls.model.select(*[*fields, UserCanvas.title]).join(File2Document, on = (File2Document.document_id == cls.model.id))\
|
docs = cls.model.select(*[*fields, UserCanvas.title]).join(File2Document, on = (File2Document.document_id == cls.model.id))\
|
||||||
.join(File, on = (File.id == File2Document.file_id))\
|
.join(File, on = (File.id == File2Document.file_id))\
|
||||||
@ -96,6 +96,8 @@ class DocumentService(CommonService):
|
|||||||
docs = docs.where(
|
docs = docs.where(
|
||||||
fn.LOWER(cls.model.name).contains(keywords.lower())
|
fn.LOWER(cls.model.name).contains(keywords.lower())
|
||||||
)
|
)
|
||||||
|
if doc_ids:
|
||||||
|
docs = docs.where(cls.model.id.in_(doc_ids))
|
||||||
if suffix:
|
if suffix:
|
||||||
docs = docs.where(cls.model.suffix.in_(suffix))
|
docs = docs.where(cls.model.suffix.in_(suffix))
|
||||||
if run:
|
if run:
|
||||||
@ -123,7 +125,7 @@ class DocumentService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_by_kb_id(cls, kb_id, page_number, items_per_page,
|
def get_by_kb_id(cls, kb_id, page_number, items_per_page,
|
||||||
orderby, desc, keywords, run_status, types, suffix):
|
orderby, desc, keywords, run_status, types, suffix, doc_ids=None):
|
||||||
fields = cls.get_cls_model_fields()
|
fields = cls.get_cls_model_fields()
|
||||||
if keywords:
|
if keywords:
|
||||||
docs = cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name"), User.nickname])\
|
docs = cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name"), User.nickname])\
|
||||||
@ -143,6 +145,8 @@ class DocumentService(CommonService):
|
|||||||
.join(User, on=(cls.model.created_by == User.id), join_type=JOIN.LEFT_OUTER)\
|
.join(User, on=(cls.model.created_by == User.id), join_type=JOIN.LEFT_OUTER)\
|
||||||
.where(cls.model.kb_id == kb_id)
|
.where(cls.model.kb_id == kb_id)
|
||||||
|
|
||||||
|
if doc_ids:
|
||||||
|
docs = docs.where(cls.model.id.in_(doc_ids))
|
||||||
if run_status:
|
if run_status:
|
||||||
docs = docs.where(cls.model.run.in_(run_status))
|
docs = docs.where(cls.model.run.in_(run_status))
|
||||||
if types:
|
if types:
|
||||||
@ -644,6 +648,13 @@ class DocumentService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_meta_by_kbs(cls, kb_ids):
|
def get_meta_by_kbs(cls, kb_ids):
|
||||||
|
"""
|
||||||
|
Legacy metadata aggregator (backward-compatible).
|
||||||
|
- Does NOT expand list values and a list is kept as one string key.
|
||||||
|
Example: {"tags": ["foo","bar"]} -> meta["tags"]["['foo', 'bar']"] = [doc_id]
|
||||||
|
- Expects meta_fields is a dict.
|
||||||
|
Use when existing callers rely on the old list-as-string semantics.
|
||||||
|
"""
|
||||||
fields = [
|
fields = [
|
||||||
cls.model.id,
|
cls.model.id,
|
||||||
cls.model.meta_fields,
|
cls.model.meta_fields,
|
||||||
@ -660,6 +671,171 @@ class DocumentService(CommonService):
|
|||||||
meta[k][v].append(doc_id)
|
meta[k][v].append(doc_id)
|
||||||
return meta
|
return meta
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def get_flatted_meta_by_kbs(cls, kb_ids):
|
||||||
|
"""
|
||||||
|
- Parses stringified JSON meta_fields when possible and skips non-dict or unparsable values.
|
||||||
|
- Expands list values into individual entries.
|
||||||
|
Example: {"tags": ["foo","bar"], "author": "alice"} ->
|
||||||
|
meta["tags"]["foo"] = [doc_id], meta["tags"]["bar"] = [doc_id], meta["author"]["alice"] = [doc_id]
|
||||||
|
Prefer for metadata_condition filtering and scenarios that must respect list semantics.
|
||||||
|
"""
|
||||||
|
fields = [
|
||||||
|
cls.model.id,
|
||||||
|
cls.model.meta_fields,
|
||||||
|
]
|
||||||
|
meta = {}
|
||||||
|
for r in cls.model.select(*fields).where(cls.model.kb_id.in_(kb_ids)):
|
||||||
|
doc_id = r.id
|
||||||
|
meta_fields = r.meta_fields or {}
|
||||||
|
if isinstance(meta_fields, str):
|
||||||
|
try:
|
||||||
|
meta_fields = json.loads(meta_fields)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
if not isinstance(meta_fields, dict):
|
||||||
|
continue
|
||||||
|
for k, v in meta_fields.items():
|
||||||
|
if k not in meta:
|
||||||
|
meta[k] = {}
|
||||||
|
values = v if isinstance(v, list) else [v]
|
||||||
|
for vv in values:
|
||||||
|
if vv is None:
|
||||||
|
continue
|
||||||
|
sv = str(vv)
|
||||||
|
if sv not in meta[k]:
|
||||||
|
meta[k][sv] = []
|
||||||
|
meta[k][sv].append(doc_id)
|
||||||
|
return meta
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def get_metadata_summary(cls, kb_id):
|
||||||
|
fields = [cls.model.id, cls.model.meta_fields]
|
||||||
|
summary = {}
|
||||||
|
for r in cls.model.select(*fields).where(cls.model.kb_id == kb_id):
|
||||||
|
meta_fields = r.meta_fields or {}
|
||||||
|
if isinstance(meta_fields, str):
|
||||||
|
try:
|
||||||
|
meta_fields = json.loads(meta_fields)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
if not isinstance(meta_fields, dict):
|
||||||
|
continue
|
||||||
|
for k, v in meta_fields.items():
|
||||||
|
values = v if isinstance(v, list) else [v]
|
||||||
|
for vv in values:
|
||||||
|
if not vv:
|
||||||
|
continue
|
||||||
|
sv = str(vv)
|
||||||
|
if k not in summary:
|
||||||
|
summary[k] = {}
|
||||||
|
summary[k][sv] = summary[k].get(sv, 0) + 1
|
||||||
|
return {k: sorted([(val, cnt) for val, cnt in v.items()], key=lambda x: x[1], reverse=True) for k, v in summary.items()}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def batch_update_metadata(cls, kb_id, doc_ids, updates=None, deletes=None):
|
||||||
|
updates = updates or []
|
||||||
|
deletes = deletes or []
|
||||||
|
if not doc_ids:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def _normalize_meta(meta):
|
||||||
|
if isinstance(meta, str):
|
||||||
|
try:
|
||||||
|
meta = json.loads(meta)
|
||||||
|
except Exception:
|
||||||
|
return {}
|
||||||
|
if not isinstance(meta, dict):
|
||||||
|
return {}
|
||||||
|
return deepcopy(meta)
|
||||||
|
|
||||||
|
def _str_equal(a, b):
|
||||||
|
return str(a) == str(b)
|
||||||
|
|
||||||
|
def _apply_updates(meta):
|
||||||
|
changed = False
|
||||||
|
for upd in updates:
|
||||||
|
key = upd.get("key")
|
||||||
|
if not key or key not in meta:
|
||||||
|
continue
|
||||||
|
|
||||||
|
new_value = upd.get("value")
|
||||||
|
match_provided = "match" in upd
|
||||||
|
if isinstance(meta[key], list):
|
||||||
|
if not match_provided:
|
||||||
|
meta[key] = new_value
|
||||||
|
changed = True
|
||||||
|
else:
|
||||||
|
match_value = upd.get("match")
|
||||||
|
replaced = False
|
||||||
|
new_list = []
|
||||||
|
for item in meta[key]:
|
||||||
|
if _str_equal(item, match_value):
|
||||||
|
new_list.append(new_value)
|
||||||
|
replaced = True
|
||||||
|
else:
|
||||||
|
new_list.append(item)
|
||||||
|
if replaced:
|
||||||
|
meta[key] = new_list
|
||||||
|
changed = True
|
||||||
|
else:
|
||||||
|
if not match_provided:
|
||||||
|
meta[key] = new_value
|
||||||
|
changed = True
|
||||||
|
else:
|
||||||
|
match_value = upd.get("match")
|
||||||
|
if _str_equal(meta[key], match_value):
|
||||||
|
meta[key] = new_value
|
||||||
|
changed = True
|
||||||
|
return changed
|
||||||
|
|
||||||
|
def _apply_deletes(meta):
|
||||||
|
changed = False
|
||||||
|
for d in deletes:
|
||||||
|
key = d.get("key")
|
||||||
|
if not key or key not in meta:
|
||||||
|
continue
|
||||||
|
value = d.get("value", None)
|
||||||
|
if isinstance(meta[key], list):
|
||||||
|
if value is None:
|
||||||
|
del meta[key]
|
||||||
|
changed = True
|
||||||
|
continue
|
||||||
|
new_list = [item for item in meta[key] if not _str_equal(item, value)]
|
||||||
|
if len(new_list) != len(meta[key]):
|
||||||
|
if new_list:
|
||||||
|
meta[key] = new_list
|
||||||
|
else:
|
||||||
|
del meta[key]
|
||||||
|
changed = True
|
||||||
|
else:
|
||||||
|
if value is None or _str_equal(meta[key], value):
|
||||||
|
del meta[key]
|
||||||
|
changed = True
|
||||||
|
return changed
|
||||||
|
|
||||||
|
updated_docs = 0
|
||||||
|
with DB.atomic():
|
||||||
|
rows = cls.model.select(cls.model.id, cls.model.meta_fields).where(
|
||||||
|
(cls.model.id.in_(doc_ids)) & (cls.model.kb_id == kb_id)
|
||||||
|
)
|
||||||
|
for r in rows:
|
||||||
|
meta = _normalize_meta(r.meta_fields or {})
|
||||||
|
original_meta = deepcopy(meta)
|
||||||
|
changed = _apply_updates(meta)
|
||||||
|
changed = _apply_deletes(meta) or changed
|
||||||
|
if changed and meta != original_meta:
|
||||||
|
cls.model.update(
|
||||||
|
meta_fields=meta,
|
||||||
|
update_time=current_timestamp(),
|
||||||
|
update_date=get_format_time()
|
||||||
|
).where(cls.model.id == r.id).execute()
|
||||||
|
updated_docs += 1
|
||||||
|
return updated_docs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def update_progress(cls):
|
def update_progress(cls):
|
||||||
@ -719,10 +895,14 @@ class DocumentService(CommonService):
|
|||||||
# only for special task and parsed docs and unfinished
|
# only for special task and parsed docs and unfinished
|
||||||
freeze_progress = special_task_running and doc_progress >= 1 and not finished
|
freeze_progress = special_task_running and doc_progress >= 1 and not finished
|
||||||
msg = "\n".join(sorted(msg))
|
msg = "\n".join(sorted(msg))
|
||||||
|
begin_at = d.get("process_begin_at")
|
||||||
|
if not begin_at:
|
||||||
|
begin_at = datetime.now()
|
||||||
|
# fallback
|
||||||
|
cls.update_by_id(d["id"], {"process_begin_at": begin_at})
|
||||||
|
|
||||||
info = {
|
info = {
|
||||||
"process_duration": datetime.timestamp(
|
"process_duration": max(datetime.timestamp(datetime.now()) - begin_at.timestamp(), 0),
|
||||||
datetime.now()) -
|
|
||||||
d["process_begin_at"].timestamp(),
|
|
||||||
"run": status}
|
"run": status}
|
||||||
if prg != 0 and not freeze_progress:
|
if prg != 0 and not freeze_progress:
|
||||||
info["progress"] = prg
|
info["progress"] = prg
|
||||||
@ -902,12 +1082,12 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|||||||
|
|
||||||
e, dia = DialogService.get_by_id(conv.dialog_id)
|
e, dia = DialogService.get_by_id(conv.dialog_id)
|
||||||
if not dia.kb_ids:
|
if not dia.kb_ids:
|
||||||
raise LookupError("No knowledge base associated with this conversation. "
|
raise LookupError("No dataset associated with this conversation. "
|
||||||
"Please add a knowledge base before uploading documents")
|
"Please add a dataset before uploading documents")
|
||||||
kb_id = dia.kb_ids[0]
|
kb_id = dia.kb_ids[0]
|
||||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||||
if not e:
|
if not e:
|
||||||
raise LookupError("Can't find this knowledgebase!")
|
raise LookupError("Can't find this dataset!")
|
||||||
|
|
||||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)
|
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)
|
||||||
|
|
||||||
@ -923,7 +1103,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|||||||
ParserType.AUDIO.value: audio,
|
ParserType.AUDIO.value: audio,
|
||||||
ParserType.EMAIL.value: email
|
ParserType.EMAIL.value: email
|
||||||
}
|
}
|
||||||
parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"}
|
parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text", "table_context_size": 0, "image_context_size": 0}
|
||||||
exe = ThreadPoolExecutor(max_workers=12)
|
exe = ThreadPoolExecutor(max_workers=12)
|
||||||
threads = []
|
threads = []
|
||||||
doc_nm = {}
|
doc_nm = {}
|
||||||
@ -995,7 +1175,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|||||||
from graphrag.general.mind_map_extractor import MindMapExtractor
|
from graphrag.general.mind_map_extractor import MindMapExtractor
|
||||||
mindmap = MindMapExtractor(llm_bdl)
|
mindmap = MindMapExtractor(llm_bdl)
|
||||||
try:
|
try:
|
||||||
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in docs if c["doc_id"] == doc_id])
|
mind_map = asyncio.run(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]))
|
||||||
mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2)
|
mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2)
|
||||||
if len(mind_map) < 32:
|
if len(mind_map) < 32:
|
||||||
raise Exception("Few content: " + mind_map)
|
raise Exception("Few content: " + mind_map)
|
||||||
|
|||||||
637
api/db/services/evaluation_service.py
Normal file
637
api/db/services/evaluation_service.py
Normal file
@ -0,0 +1,637 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
"""
|
||||||
|
RAG Evaluation Service
|
||||||
|
|
||||||
|
Provides functionality for evaluating RAG system performance including:
|
||||||
|
- Dataset management
|
||||||
|
- Test case management
|
||||||
|
- Evaluation execution
|
||||||
|
- Metrics computation
|
||||||
|
- Configuration recommendations
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import queue
|
||||||
|
import threading
|
||||||
|
from typing import List, Dict, Any, Optional, Tuple
|
||||||
|
from datetime import datetime
|
||||||
|
from timeit import default_timer as timer
|
||||||
|
|
||||||
|
from api.db.db_models import EvaluationDataset, EvaluationCase, EvaluationRun, EvaluationResult
|
||||||
|
from api.db.services.common_service import CommonService
|
||||||
|
from api.db.services.dialog_service import DialogService
|
||||||
|
from common.misc_utils import get_uuid
|
||||||
|
from common.time_utils import current_timestamp
|
||||||
|
from common.constants import StatusEnum
|
||||||
|
|
||||||
|
|
||||||
|
class EvaluationService(CommonService):
|
||||||
|
"""Service for managing RAG evaluations"""
|
||||||
|
|
||||||
|
model = EvaluationDataset
|
||||||
|
|
||||||
|
# ==================== Dataset Management ====================
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_dataset(cls, name: str, description: str, kb_ids: List[str],
|
||||||
|
tenant_id: str, user_id: str) -> Tuple[bool, str]:
|
||||||
|
"""
|
||||||
|
Create a new evaluation dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Dataset name
|
||||||
|
description: Dataset description
|
||||||
|
kb_ids: List of knowledge base IDs to evaluate against
|
||||||
|
tenant_id: Tenant ID
|
||||||
|
user_id: User ID who creates the dataset
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(success, dataset_id or error_message)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
dataset_id = get_uuid()
|
||||||
|
dataset = {
|
||||||
|
"id": dataset_id,
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"name": name,
|
||||||
|
"description": description,
|
||||||
|
"kb_ids": kb_ids,
|
||||||
|
"created_by": user_id,
|
||||||
|
"create_time": current_timestamp(),
|
||||||
|
"update_time": current_timestamp(),
|
||||||
|
"status": StatusEnum.VALID.value
|
||||||
|
}
|
||||||
|
|
||||||
|
if not EvaluationDataset.create(**dataset):
|
||||||
|
return False, "Failed to create dataset"
|
||||||
|
|
||||||
|
return True, dataset_id
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error creating evaluation dataset: {e}")
|
||||||
|
return False, str(e)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_dataset(cls, dataset_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Get dataset by ID"""
|
||||||
|
try:
|
||||||
|
dataset = EvaluationDataset.get_by_id(dataset_id)
|
||||||
|
if dataset:
|
||||||
|
return dataset.to_dict()
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error getting dataset {dataset_id}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_datasets(cls, tenant_id: str, user_id: str,
|
||||||
|
page: int = 1, page_size: int = 20) -> Dict[str, Any]:
|
||||||
|
"""List datasets for a tenant"""
|
||||||
|
try:
|
||||||
|
query = EvaluationDataset.select().where(
|
||||||
|
(EvaluationDataset.tenant_id == tenant_id) &
|
||||||
|
(EvaluationDataset.status == StatusEnum.VALID.value)
|
||||||
|
).order_by(EvaluationDataset.create_time.desc())
|
||||||
|
|
||||||
|
total = query.count()
|
||||||
|
datasets = query.paginate(page, page_size)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total": total,
|
||||||
|
"datasets": [d.to_dict() for d in datasets]
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error listing datasets: {e}")
|
||||||
|
return {"total": 0, "datasets": []}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def update_dataset(cls, dataset_id: str, **kwargs) -> bool:
|
||||||
|
"""Update dataset"""
|
||||||
|
try:
|
||||||
|
kwargs["update_time"] = current_timestamp()
|
||||||
|
return EvaluationDataset.update(**kwargs).where(
|
||||||
|
EvaluationDataset.id == dataset_id
|
||||||
|
).execute() > 0
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error updating dataset {dataset_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def delete_dataset(cls, dataset_id: str) -> bool:
|
||||||
|
"""Soft delete dataset"""
|
||||||
|
try:
|
||||||
|
return EvaluationDataset.update(
|
||||||
|
status=StatusEnum.INVALID.value,
|
||||||
|
update_time=current_timestamp()
|
||||||
|
).where(EvaluationDataset.id == dataset_id).execute() > 0
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error deleting dataset {dataset_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# ==================== Test Case Management ====================
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def add_test_case(cls, dataset_id: str, question: str,
|
||||||
|
reference_answer: Optional[str] = None,
|
||||||
|
relevant_doc_ids: Optional[List[str]] = None,
|
||||||
|
relevant_chunk_ids: Optional[List[str]] = None,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None) -> Tuple[bool, str]:
|
||||||
|
"""
|
||||||
|
Add a test case to a dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_id: Dataset ID
|
||||||
|
question: Test question
|
||||||
|
reference_answer: Optional ground truth answer
|
||||||
|
relevant_doc_ids: Optional list of relevant document IDs
|
||||||
|
relevant_chunk_ids: Optional list of relevant chunk IDs
|
||||||
|
metadata: Optional additional metadata
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(success, case_id or error_message)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
case_id = get_uuid()
|
||||||
|
case = {
|
||||||
|
"id": case_id,
|
||||||
|
"dataset_id": dataset_id,
|
||||||
|
"question": question,
|
||||||
|
"reference_answer": reference_answer,
|
||||||
|
"relevant_doc_ids": relevant_doc_ids,
|
||||||
|
"relevant_chunk_ids": relevant_chunk_ids,
|
||||||
|
"metadata": metadata,
|
||||||
|
"create_time": current_timestamp()
|
||||||
|
}
|
||||||
|
|
||||||
|
if not EvaluationCase.create(**case):
|
||||||
|
return False, "Failed to create test case"
|
||||||
|
|
||||||
|
return True, case_id
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error adding test case: {e}")
|
||||||
|
return False, str(e)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_test_cases(cls, dataset_id: str) -> List[Dict[str, Any]]:
|
||||||
|
"""Get all test cases for a dataset"""
|
||||||
|
try:
|
||||||
|
cases = EvaluationCase.select().where(
|
||||||
|
EvaluationCase.dataset_id == dataset_id
|
||||||
|
).order_by(EvaluationCase.create_time)
|
||||||
|
|
||||||
|
return [c.to_dict() for c in cases]
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error getting test cases for dataset {dataset_id}: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def delete_test_case(cls, case_id: str) -> bool:
|
||||||
|
"""Delete a test case"""
|
||||||
|
try:
|
||||||
|
return EvaluationCase.delete().where(
|
||||||
|
EvaluationCase.id == case_id
|
||||||
|
).execute() > 0
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error deleting test case {case_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def import_test_cases(cls, dataset_id: str, cases: List[Dict[str, Any]]) -> Tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Bulk import test cases from a list.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_id: Dataset ID
|
||||||
|
cases: List of test case dictionaries
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(success_count, failure_count)
|
||||||
|
"""
|
||||||
|
success_count = 0
|
||||||
|
failure_count = 0
|
||||||
|
|
||||||
|
for case_data in cases:
|
||||||
|
success, _ = cls.add_test_case(
|
||||||
|
dataset_id=dataset_id,
|
||||||
|
question=case_data.get("question", ""),
|
||||||
|
reference_answer=case_data.get("reference_answer"),
|
||||||
|
relevant_doc_ids=case_data.get("relevant_doc_ids"),
|
||||||
|
relevant_chunk_ids=case_data.get("relevant_chunk_ids"),
|
||||||
|
metadata=case_data.get("metadata")
|
||||||
|
)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
success_count += 1
|
||||||
|
else:
|
||||||
|
failure_count += 1
|
||||||
|
|
||||||
|
return success_count, failure_count
|
||||||
|
|
||||||
|
# ==================== Evaluation Execution ====================
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def start_evaluation(cls, dataset_id: str, dialog_id: str,
|
||||||
|
user_id: str, name: Optional[str] = None) -> Tuple[bool, str]:
|
||||||
|
"""
|
||||||
|
Start an evaluation run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_id: Dataset ID
|
||||||
|
dialog_id: Dialog configuration to evaluate
|
||||||
|
user_id: User ID who starts the run
|
||||||
|
name: Optional run name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(success, run_id or error_message)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get dialog configuration
|
||||||
|
success, dialog = DialogService.get_by_id(dialog_id)
|
||||||
|
if not success:
|
||||||
|
return False, "Dialog not found"
|
||||||
|
|
||||||
|
# Create evaluation run
|
||||||
|
run_id = get_uuid()
|
||||||
|
if not name:
|
||||||
|
name = f"Evaluation Run {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||||
|
|
||||||
|
run = {
|
||||||
|
"id": run_id,
|
||||||
|
"dataset_id": dataset_id,
|
||||||
|
"dialog_id": dialog_id,
|
||||||
|
"name": name,
|
||||||
|
"config_snapshot": dialog.to_dict(),
|
||||||
|
"metrics_summary": None,
|
||||||
|
"status": "RUNNING",
|
||||||
|
"created_by": user_id,
|
||||||
|
"create_time": current_timestamp(),
|
||||||
|
"complete_time": None
|
||||||
|
}
|
||||||
|
|
||||||
|
if not EvaluationRun.create(**run):
|
||||||
|
return False, "Failed to create evaluation run"
|
||||||
|
|
||||||
|
# Execute evaluation asynchronously (in production, use task queue)
|
||||||
|
# For now, we'll execute synchronously
|
||||||
|
cls._execute_evaluation(run_id, dataset_id, dialog)
|
||||||
|
|
||||||
|
return True, run_id
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error starting evaluation: {e}")
|
||||||
|
return False, str(e)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _execute_evaluation(cls, run_id: str, dataset_id: str, dialog: Any):
|
||||||
|
"""
|
||||||
|
Execute evaluation for all test cases.
|
||||||
|
|
||||||
|
This method runs the RAG pipeline for each test case and computes metrics.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get all test cases
|
||||||
|
test_cases = cls.get_test_cases(dataset_id)
|
||||||
|
|
||||||
|
if not test_cases:
|
||||||
|
EvaluationRun.update(
|
||||||
|
status="FAILED",
|
||||||
|
complete_time=current_timestamp()
|
||||||
|
).where(EvaluationRun.id == run_id).execute()
|
||||||
|
return
|
||||||
|
|
||||||
|
# Execute each test case
|
||||||
|
results = []
|
||||||
|
for case in test_cases:
|
||||||
|
result = cls._evaluate_single_case(run_id, case, dialog)
|
||||||
|
if result:
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
# Compute summary metrics
|
||||||
|
metrics_summary = cls._compute_summary_metrics(results)
|
||||||
|
|
||||||
|
# Update run status
|
||||||
|
EvaluationRun.update(
|
||||||
|
status="COMPLETED",
|
||||||
|
metrics_summary=metrics_summary,
|
||||||
|
complete_time=current_timestamp()
|
||||||
|
).where(EvaluationRun.id == run_id).execute()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error executing evaluation {run_id}: {e}")
|
||||||
|
EvaluationRun.update(
|
||||||
|
status="FAILED",
|
||||||
|
complete_time=current_timestamp()
|
||||||
|
).where(EvaluationRun.id == run_id).execute()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _evaluate_single_case(cls, run_id: str, case: Dict[str, Any],
|
||||||
|
dialog: Any) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Evaluate a single test case.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: Evaluation run ID
|
||||||
|
case: Test case dictionary
|
||||||
|
dialog: Dialog configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Result dictionary or None if failed
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Prepare messages
|
||||||
|
messages = [{"role": "user", "content": case["question"]}]
|
||||||
|
|
||||||
|
# Execute RAG pipeline
|
||||||
|
start_time = timer()
|
||||||
|
answer = ""
|
||||||
|
retrieved_chunks = []
|
||||||
|
|
||||||
|
|
||||||
|
def _sync_from_async_gen(async_gen):
|
||||||
|
result_queue: queue.Queue = queue.Queue()
|
||||||
|
|
||||||
|
def runner():
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
async def consume():
|
||||||
|
try:
|
||||||
|
async for item in async_gen:
|
||||||
|
result_queue.put(item)
|
||||||
|
except Exception as e:
|
||||||
|
result_queue.put(e)
|
||||||
|
finally:
|
||||||
|
result_queue.put(StopIteration)
|
||||||
|
|
||||||
|
loop.run_until_complete(consume())
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
threading.Thread(target=runner, daemon=True).start()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
item = result_queue.get()
|
||||||
|
if item is StopIteration:
|
||||||
|
break
|
||||||
|
if isinstance(item, Exception):
|
||||||
|
raise item
|
||||||
|
yield item
|
||||||
|
|
||||||
|
|
||||||
|
def chat(dialog, messages, stream=True, **kwargs):
|
||||||
|
from api.db.services.dialog_service import async_chat
|
||||||
|
|
||||||
|
return _sync_from_async_gen(async_chat(dialog, messages, stream=stream, **kwargs))
|
||||||
|
|
||||||
|
for ans in chat(dialog, messages, stream=False):
|
||||||
|
if isinstance(ans, dict):
|
||||||
|
answer = ans.get("answer", "")
|
||||||
|
retrieved_chunks = ans.get("reference", {}).get("chunks", [])
|
||||||
|
break
|
||||||
|
|
||||||
|
execution_time = timer() - start_time
|
||||||
|
|
||||||
|
# Compute metrics
|
||||||
|
metrics = cls._compute_metrics(
|
||||||
|
question=case["question"],
|
||||||
|
generated_answer=answer,
|
||||||
|
reference_answer=case.get("reference_answer"),
|
||||||
|
retrieved_chunks=retrieved_chunks,
|
||||||
|
relevant_chunk_ids=case.get("relevant_chunk_ids"),
|
||||||
|
dialog=dialog
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save result
|
||||||
|
result_id = get_uuid()
|
||||||
|
result = {
|
||||||
|
"id": result_id,
|
||||||
|
"run_id": run_id,
|
||||||
|
"case_id": case["id"],
|
||||||
|
"generated_answer": answer,
|
||||||
|
"retrieved_chunks": retrieved_chunks,
|
||||||
|
"metrics": metrics,
|
||||||
|
"execution_time": execution_time,
|
||||||
|
"token_usage": None, # TODO: Track token usage
|
||||||
|
"create_time": current_timestamp()
|
||||||
|
}
|
||||||
|
|
||||||
|
EvaluationResult.create(**result)
|
||||||
|
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error evaluating case {case.get('id')}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _compute_metrics(cls, question: str, generated_answer: str,
|
||||||
|
reference_answer: Optional[str],
|
||||||
|
retrieved_chunks: List[Dict[str, Any]],
|
||||||
|
relevant_chunk_ids: Optional[List[str]],
|
||||||
|
dialog: Any) -> Dict[str, float]:
|
||||||
|
"""
|
||||||
|
Compute evaluation metrics for a single test case.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of metric names to values
|
||||||
|
"""
|
||||||
|
metrics = {}
|
||||||
|
|
||||||
|
# Retrieval metrics (if ground truth chunks provided)
|
||||||
|
if relevant_chunk_ids:
|
||||||
|
retrieved_ids = [c.get("chunk_id") for c in retrieved_chunks]
|
||||||
|
metrics.update(cls._compute_retrieval_metrics(retrieved_ids, relevant_chunk_ids))
|
||||||
|
|
||||||
|
# Generation metrics
|
||||||
|
if generated_answer:
|
||||||
|
# Basic metrics
|
||||||
|
metrics["answer_length"] = len(generated_answer)
|
||||||
|
metrics["has_answer"] = 1.0 if generated_answer.strip() else 0.0
|
||||||
|
|
||||||
|
# TODO: Implement advanced metrics using LLM-as-judge
|
||||||
|
# - Faithfulness (hallucination detection)
|
||||||
|
# - Answer relevance
|
||||||
|
# - Context relevance
|
||||||
|
# - Semantic similarity (if reference answer provided)
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _compute_retrieval_metrics(cls, retrieved_ids: List[str],
|
||||||
|
relevant_ids: List[str]) -> Dict[str, float]:
|
||||||
|
"""
|
||||||
|
Compute retrieval metrics.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
retrieved_ids: List of retrieved chunk IDs
|
||||||
|
relevant_ids: List of relevant chunk IDs (ground truth)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of retrieval metrics
|
||||||
|
"""
|
||||||
|
if not relevant_ids:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
retrieved_set = set(retrieved_ids)
|
||||||
|
relevant_set = set(relevant_ids)
|
||||||
|
|
||||||
|
# Precision: proportion of retrieved that are relevant
|
||||||
|
precision = len(retrieved_set & relevant_set) / len(retrieved_set) if retrieved_set else 0.0
|
||||||
|
|
||||||
|
# Recall: proportion of relevant that were retrieved
|
||||||
|
recall = len(retrieved_set & relevant_set) / len(relevant_set) if relevant_set else 0.0
|
||||||
|
|
||||||
|
# F1 score
|
||||||
|
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
|
||||||
|
|
||||||
|
# Hit rate: whether any relevant chunk was retrieved
|
||||||
|
hit_rate = 1.0 if (retrieved_set & relevant_set) else 0.0
|
||||||
|
|
||||||
|
# MRR (Mean Reciprocal Rank): position of first relevant chunk
|
||||||
|
mrr = 0.0
|
||||||
|
for i, chunk_id in enumerate(retrieved_ids, 1):
|
||||||
|
if chunk_id in relevant_set:
|
||||||
|
mrr = 1.0 / i
|
||||||
|
break
|
||||||
|
|
||||||
|
return {
|
||||||
|
"precision": precision,
|
||||||
|
"recall": recall,
|
||||||
|
"f1_score": f1,
|
||||||
|
"hit_rate": hit_rate,
|
||||||
|
"mrr": mrr
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _compute_summary_metrics(cls, results: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Compute summary metrics across all test cases.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results: List of result dictionaries
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Summary metrics dictionary
|
||||||
|
"""
|
||||||
|
if not results:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Aggregate metrics
|
||||||
|
metric_sums = {}
|
||||||
|
metric_counts = {}
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
metrics = result.get("metrics", {})
|
||||||
|
for key, value in metrics.items():
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
metric_sums[key] = metric_sums.get(key, 0) + value
|
||||||
|
metric_counts[key] = metric_counts.get(key, 0) + 1
|
||||||
|
|
||||||
|
# Compute averages
|
||||||
|
summary = {
|
||||||
|
"total_cases": len(results),
|
||||||
|
"avg_execution_time": sum(r.get("execution_time", 0) for r in results) / len(results)
|
||||||
|
}
|
||||||
|
|
||||||
|
for key in metric_sums:
|
||||||
|
summary[f"avg_{key}"] = metric_sums[key] / metric_counts[key]
|
||||||
|
|
||||||
|
return summary
|
||||||
|
|
||||||
|
# ==================== Results & Analysis ====================
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_run_results(cls, run_id: str) -> Dict[str, Any]:
|
||||||
|
"""Get results for an evaluation run"""
|
||||||
|
try:
|
||||||
|
run = EvaluationRun.get_by_id(run_id)
|
||||||
|
if not run:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
results = EvaluationResult.select().where(
|
||||||
|
EvaluationResult.run_id == run_id
|
||||||
|
).order_by(EvaluationResult.create_time)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"run": run.to_dict(),
|
||||||
|
"results": [r.to_dict() for r in results]
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error getting run results {run_id}: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_recommendations(cls, run_id: str) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Analyze evaluation results and provide configuration recommendations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: Evaluation run ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of recommendation dictionaries
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
run = EvaluationRun.get_by_id(run_id)
|
||||||
|
if not run or not run.metrics_summary:
|
||||||
|
return []
|
||||||
|
|
||||||
|
metrics = run.metrics_summary
|
||||||
|
recommendations = []
|
||||||
|
|
||||||
|
# Low precision: retrieving irrelevant chunks
|
||||||
|
if metrics.get("avg_precision", 1.0) < 0.7:
|
||||||
|
recommendations.append({
|
||||||
|
"issue": "Low Precision",
|
||||||
|
"severity": "high",
|
||||||
|
"description": "System is retrieving many irrelevant chunks",
|
||||||
|
"suggestions": [
|
||||||
|
"Increase similarity_threshold to filter out less relevant chunks",
|
||||||
|
"Enable reranking to improve chunk ordering",
|
||||||
|
"Reduce top_k to return fewer chunks"
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
# Low recall: missing relevant chunks
|
||||||
|
if metrics.get("avg_recall", 1.0) < 0.7:
|
||||||
|
recommendations.append({
|
||||||
|
"issue": "Low Recall",
|
||||||
|
"severity": "high",
|
||||||
|
"description": "System is missing relevant chunks",
|
||||||
|
"suggestions": [
|
||||||
|
"Increase top_k to retrieve more chunks",
|
||||||
|
"Lower similarity_threshold to be more inclusive",
|
||||||
|
"Enable hybrid search (keyword + semantic)",
|
||||||
|
"Check chunk size - may be too large or too small"
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
# Slow response time
|
||||||
|
if metrics.get("avg_execution_time", 0) > 5.0:
|
||||||
|
recommendations.append({
|
||||||
|
"issue": "Slow Response Time",
|
||||||
|
"severity": "medium",
|
||||||
|
"description": f"Average response time is {metrics['avg_execution_time']:.2f}s",
|
||||||
|
"suggestions": [
|
||||||
|
"Reduce top_k to retrieve fewer chunks",
|
||||||
|
"Optimize embedding model selection",
|
||||||
|
"Consider caching frequently asked questions"
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
return recommendations
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error generating recommendations for run {run_id}: {e}")
|
||||||
|
return []
|
||||||
@ -13,10 +13,15 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
from peewee import fn
|
from peewee import fn
|
||||||
|
|
||||||
@ -89,13 +94,13 @@ class FileService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_kb_id_by_file_id(cls, file_id):
|
def get_kb_id_by_file_id(cls, file_id):
|
||||||
# Get knowledge base IDs associated with a file
|
# Get dataset IDs associated with a file
|
||||||
# Args:
|
# Args:
|
||||||
# file_id: File ID
|
# file_id: File ID
|
||||||
# Returns:
|
# Returns:
|
||||||
# List of dictionaries containing knowledge base IDs and names
|
# List of dictionaries containing dataset IDs and names
|
||||||
kbs = (
|
kbs = (
|
||||||
cls.model.select(*[Knowledgebase.id, Knowledgebase.name])
|
cls.model.select(*[Knowledgebase.id, Knowledgebase.name, File2Document.document_id])
|
||||||
.join(File2Document, on=(File2Document.file_id == file_id))
|
.join(File2Document, on=(File2Document.file_id == file_id))
|
||||||
.join(Document, on=(File2Document.document_id == Document.id))
|
.join(Document, on=(File2Document.document_id == Document.id))
|
||||||
.join(Knowledgebase, on=(Knowledgebase.id == Document.kb_id))
|
.join(Knowledgebase, on=(Knowledgebase.id == Document.kb_id))
|
||||||
@ -105,7 +110,7 @@ class FileService(CommonService):
|
|||||||
return []
|
return []
|
||||||
kbs_info_list = []
|
kbs_info_list = []
|
||||||
for kb in list(kbs.dicts()):
|
for kb in list(kbs.dicts()):
|
||||||
kbs_info_list.append({"kb_id": kb["id"], "kb_name": kb["name"]})
|
kbs_info_list.append({"kb_id": kb["id"], "kb_name": kb["name"], "document_id": kb["document_id"]})
|
||||||
return kbs_info_list
|
return kbs_info_list
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -242,7 +247,7 @@ class FileService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_kb_folder(cls, tenant_id):
|
def get_kb_folder(cls, tenant_id):
|
||||||
# Get knowledge base folder for tenant
|
# Get dataset folder for tenant
|
||||||
# Args:
|
# Args:
|
||||||
# tenant_id: Tenant ID
|
# tenant_id: Tenant ID
|
||||||
# Returns:
|
# Returns:
|
||||||
@ -258,7 +263,7 @@ class FileService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def new_a_file_from_kb(cls, tenant_id, name, parent_id, ty=FileType.FOLDER.value, size=0, location=""):
|
def new_a_file_from_kb(cls, tenant_id, name, parent_id, ty=FileType.FOLDER.value, size=0, location=""):
|
||||||
# Create a new file from knowledge base
|
# Create a new file from dataset
|
||||||
# Args:
|
# Args:
|
||||||
# tenant_id: Tenant ID
|
# tenant_id: Tenant ID
|
||||||
# name: File name
|
# name: File name
|
||||||
@ -287,7 +292,7 @@ class FileService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def init_knowledgebase_docs(cls, root_id, tenant_id):
|
def init_knowledgebase_docs(cls, root_id, tenant_id):
|
||||||
# Initialize knowledge base documents
|
# Initialize dataset documents
|
||||||
# Args:
|
# Args:
|
||||||
# root_id: Root folder ID
|
# root_id: Root folder ID
|
||||||
# tenant_id: Tenant ID
|
# tenant_id: Tenant ID
|
||||||
@ -520,7 +525,7 @@ class FileService(CommonService):
|
|||||||
if img_base64 and file_type == FileType.VISUAL.value:
|
if img_base64 and file_type == FileType.VISUAL.value:
|
||||||
return GptV4.image2base64(blob)
|
return GptV4.image2base64(blob)
|
||||||
cks = FACTORY.get(FileService.get_parser(filename_type(filename), filename, ""), naive).chunk(filename, blob, **kwargs)
|
cks = FACTORY.get(FileService.get_parser(filename_type(filename), filename, ""), naive).chunk(filename, blob, **kwargs)
|
||||||
return "\n".join([ck["content_with_weight"] for ck in cks])
|
return f"\n -----------------\nFile: {filename}\nContent as following: \n" + "\n".join([ck["content_with_weight"] for ck in cks])
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_parser(doc_type, filename, default):
|
def get_parser(doc_type, filename, default):
|
||||||
@ -588,3 +593,80 @@ class FileService(CommonService):
|
|||||||
errors += str(e)
|
errors += str(e)
|
||||||
|
|
||||||
return errors
|
return errors
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def upload_info(user_id, file, url: str|None=None):
|
||||||
|
def structured(filename, filetype, blob, content_type):
|
||||||
|
nonlocal user_id
|
||||||
|
if filetype == FileType.PDF.value:
|
||||||
|
blob = read_potential_broken_pdf(blob)
|
||||||
|
|
||||||
|
location = get_uuid()
|
||||||
|
FileService.put_blob(user_id, location, blob)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": location,
|
||||||
|
"name": filename,
|
||||||
|
"size": sys.getsizeof(blob),
|
||||||
|
"extension": filename.split(".")[-1].lower(),
|
||||||
|
"mime_type": content_type,
|
||||||
|
"created_by": user_id,
|
||||||
|
"created_at": time.time(),
|
||||||
|
"preview_url": None
|
||||||
|
}
|
||||||
|
|
||||||
|
if url:
|
||||||
|
from crawl4ai import (
|
||||||
|
AsyncWebCrawler,
|
||||||
|
BrowserConfig,
|
||||||
|
CrawlerRunConfig,
|
||||||
|
DefaultMarkdownGenerator,
|
||||||
|
PruningContentFilter,
|
||||||
|
CrawlResult
|
||||||
|
)
|
||||||
|
filename = re.sub(r"\?.*", "", url.split("/")[-1])
|
||||||
|
async def adownload():
|
||||||
|
browser_config = BrowserConfig(
|
||||||
|
headless=True,
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||||
|
crawler_config = CrawlerRunConfig(
|
||||||
|
markdown_generator=DefaultMarkdownGenerator(
|
||||||
|
content_filter=PruningContentFilter()
|
||||||
|
),
|
||||||
|
pdf=True,
|
||||||
|
screenshot=False
|
||||||
|
)
|
||||||
|
result: CrawlResult = await crawler.arun(
|
||||||
|
url=url,
|
||||||
|
config=crawler_config
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
page = asyncio.run(adownload())
|
||||||
|
if page.pdf:
|
||||||
|
if filename.split(".")[-1].lower() != "pdf":
|
||||||
|
filename += ".pdf"
|
||||||
|
return structured(filename, "pdf", page.pdf, page.response_headers["content-type"])
|
||||||
|
|
||||||
|
return structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"], user_id)
|
||||||
|
|
||||||
|
DocumentService.check_doc_health(user_id, file.filename)
|
||||||
|
return structured(file.filename, filename_type(file.filename), file.read(), file.content_type)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_files(files: Union[None, list[dict]]) -> list[str]:
|
||||||
|
if not files:
|
||||||
|
return []
|
||||||
|
def image_to_base64(file):
|
||||||
|
return "data:{};base64,{}".format(file["mime_type"],
|
||||||
|
base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
|
||||||
|
exe = ThreadPoolExecutor(max_workers=5)
|
||||||
|
threads = []
|
||||||
|
for file in files:
|
||||||
|
if file["mime_type"].find("image") >=0:
|
||||||
|
threads.append(exe.submit(image_to_base64, file))
|
||||||
|
continue
|
||||||
|
threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
|
||||||
|
return [th.result() for th in threads]
|
||||||
|
|
||||||
|
|||||||
@ -30,9 +30,9 @@ from api.utils.api_utils import get_parser_config, get_data_error_result
|
|||||||
|
|
||||||
|
|
||||||
class KnowledgebaseService(CommonService):
|
class KnowledgebaseService(CommonService):
|
||||||
"""Service class for managing knowledge base operations.
|
"""Service class for managing dataset operations.
|
||||||
|
|
||||||
This class extends CommonService to provide specialized functionality for knowledge base
|
This class extends CommonService to provide specialized functionality for dataset
|
||||||
management, including document parsing status tracking, access control, and configuration
|
management, including document parsing status tracking, access control, and configuration
|
||||||
management. It handles operations such as listing, creating, updating, and deleting
|
management. It handles operations such as listing, creating, updating, and deleting
|
||||||
knowledge bases, as well as managing their associated documents and permissions.
|
knowledge bases, as well as managing their associated documents and permissions.
|
||||||
@ -41,7 +41,7 @@ class KnowledgebaseService(CommonService):
|
|||||||
- Document parsing status verification
|
- Document parsing status verification
|
||||||
- Knowledge base access control
|
- Knowledge base access control
|
||||||
- Parser configuration management
|
- Parser configuration management
|
||||||
- Tenant-based knowledge base organization
|
- Tenant-based dataset organization
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
model: The Knowledgebase model class for database operations.
|
model: The Knowledgebase model class for database operations.
|
||||||
@ -51,18 +51,18 @@ class KnowledgebaseService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def accessible4deletion(cls, kb_id, user_id):
|
def accessible4deletion(cls, kb_id, user_id):
|
||||||
"""Check if a knowledge base can be deleted by a specific user.
|
"""Check if a dataset can be deleted by a specific user.
|
||||||
|
|
||||||
This method verifies whether a user has permission to delete a knowledge base
|
This method verifies whether a user has permission to delete a dataset
|
||||||
by checking if they are the creator of that knowledge base.
|
by checking if they are the creator of that dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
kb_id (str): The unique identifier of the knowledge base to check.
|
kb_id (str): The unique identifier of the dataset to check.
|
||||||
user_id (str): The unique identifier of the user attempting the deletion.
|
user_id (str): The unique identifier of the user attempting the deletion.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if the user has permission to delete the knowledge base,
|
bool: True if the user has permission to delete the dataset,
|
||||||
False if the user doesn't have permission or the knowledge base doesn't exist.
|
False if the user doesn't have permission or the dataset doesn't exist.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> KnowledgebaseService.accessible4deletion("kb123", "user456")
|
>>> KnowledgebaseService.accessible4deletion("kb123", "user456")
|
||||||
@ -71,10 +71,10 @@ class KnowledgebaseService(CommonService):
|
|||||||
Note:
|
Note:
|
||||||
- This method only checks creator permissions
|
- This method only checks creator permissions
|
||||||
- A return value of False can mean either:
|
- A return value of False can mean either:
|
||||||
1. The knowledge base doesn't exist
|
1. The dataset doesn't exist
|
||||||
2. The user is not the creator of the knowledge base
|
2. The user is not the creator of the dataset
|
||||||
"""
|
"""
|
||||||
# Check if a knowledge base can be deleted by a user
|
# Check if a dataset can be deleted by a user
|
||||||
docs = cls.model.select(
|
docs = cls.model.select(
|
||||||
cls.model.id).where(cls.model.id == kb_id, cls.model.created_by == user_id).paginate(0, 1)
|
cls.model.id).where(cls.model.id == kb_id, cls.model.created_by == user_id).paginate(0, 1)
|
||||||
docs = docs.dicts()
|
docs = docs.dicts()
|
||||||
@ -85,7 +85,7 @@ class KnowledgebaseService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def is_parsed_done(cls, kb_id):
|
def is_parsed_done(cls, kb_id):
|
||||||
# Check if all documents in the knowledge base have completed parsing
|
# Check if all documents in the dataset have completed parsing
|
||||||
#
|
#
|
||||||
# Args:
|
# Args:
|
||||||
# kb_id: Knowledge base ID
|
# kb_id: Knowledge base ID
|
||||||
@ -96,13 +96,13 @@ class KnowledgebaseService(CommonService):
|
|||||||
from common.constants import TaskStatus
|
from common.constants import TaskStatus
|
||||||
from api.db.services.document_service import DocumentService
|
from api.db.services.document_service import DocumentService
|
||||||
|
|
||||||
# Get knowledge base information
|
# Get dataset information
|
||||||
kbs = cls.query(id=kb_id)
|
kbs = cls.query(id=kb_id)
|
||||||
if not kbs:
|
if not kbs:
|
||||||
return False, "Knowledge base not found"
|
return False, "Knowledge base not found"
|
||||||
kb = kbs[0]
|
kb = kbs[0]
|
||||||
|
|
||||||
# Get all documents in the knowledge base
|
# Get all documents in the dataset
|
||||||
docs, _ = DocumentService.get_by_kb_id(kb_id, 1, 1000, "create_time", True, "", [], [])
|
docs, _ = DocumentService.get_by_kb_id(kb_id, 1, 1000, "create_time", True, "", [], [])
|
||||||
|
|
||||||
     # Check parsing status of each document
@@ -119,9 +119,9 @@ class KnowledgebaseService(CommonService):
     @classmethod
     @DB.connection_context()
     def list_documents_by_ids(cls, kb_ids):
-        # Get document IDs associated with given knowledge base IDs
+        # Get document IDs associated with given dataset IDs
         # Args:
-        #     kb_ids: List of knowledge base IDs
+        #     kb_ids: List of dataset IDs
         # Returns:
         #     List of document IDs
         doc_ids = cls.model.select(Document.id.alias("document_id")).join(Document, on=(cls.model.id == Document.kb_id)).where(
@@ -235,11 +235,11 @@ class KnowledgebaseService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_kb_ids(cls, tenant_id):
-        # Get all knowledge base IDs for a tenant
+        # Get all dataset IDs for a tenant
         # Args:
         #     tenant_id: Tenant ID
         # Returns:
-        #     List of knowledge base IDs
+        #     List of dataset IDs
         fields = [
             cls.model.id,
         ]
@@ -250,11 +250,11 @@ class KnowledgebaseService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_detail(cls, kb_id):
-        # Get detailed information about a knowledge base
+        # Get detailed information about a dataset
         # Args:
         #     kb_id: Knowledge base ID
         # Returns:
-        #     Dictionary containing knowledge base details
+        #     Dictionary containing dataset details
         fields = [
             cls.model.id,
             cls.model.embd_id,
@@ -294,13 +294,13 @@ class KnowledgebaseService(CommonService):
     @classmethod
     @DB.connection_context()
     def update_parser_config(cls, id, config):
-        # Update parser configuration for a knowledge base
+        # Update parser configuration for a dataset
         # Args:
         #     id: Knowledge base ID
         #     config: New parser configuration
         e, m = cls.get_by_id(id)
         if not e:
-            raise LookupError(f"knowledgebase({id}) not found.")
+            raise LookupError(f"dataset({id}) not found.")

         def dfs_update(old, new):
             # Deep update of nested configuration
@@ -325,7 +325,7 @@ class KnowledgebaseService(CommonService):
     def delete_field_map(cls, id):
         e, m = cls.get_by_id(id)
         if not e:
-            raise LookupError(f"knowledgebase({id}) not found.")
+            raise LookupError(f"dataset({id}) not found.")

         m.parser_config.pop("field_map", None)
         cls.update_by_id(id, {"parser_config": m.parser_config})
@@ -335,7 +335,7 @@ class KnowledgebaseService(CommonService):
     def get_field_map(cls, ids):
         # Get field mappings for knowledge bases
         # Args:
-        #     ids: List of knowledge base IDs
+        #     ids: List of dataset IDs
         # Returns:
         #     Dictionary of field mappings
         conf = {}
@@ -347,7 +347,7 @@ class KnowledgebaseService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_by_name(cls, kb_name, tenant_id):
-        # Get knowledge base by name and tenant ID
+        # Get dataset by name and tenant ID
         # Args:
         #     kb_name: Knowledge base name
         #     tenant_id: Tenant ID
@@ -365,9 +365,9 @@ class KnowledgebaseService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_all_ids(cls):
-        # Get all knowledge base IDs
+        # Get all dataset IDs
         # Returns:
-        #     List of all knowledge base IDs
+        #     List of all dataset IDs
         return [m["id"] for m in cls.model.select(cls.model.id).dicts()]


@@ -471,7 +471,7 @@ class KnowledgebaseService(CommonService):
     @classmethod
     @DB.connection_context()
     def accessible(cls, kb_id, user_id):
-        # Check if a knowledge base is accessible by a user
+        # Check if a dataset is accessible by a user
         # Args:
         #     kb_id: Knowledge base ID
         #     user_id: User ID
@@ -488,12 +488,12 @@ class KnowledgebaseService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_kb_by_id(cls, kb_id, user_id):
-        # Get knowledge base by ID and user ID
+        # Get dataset by ID and user ID
         # Args:
         #     kb_id: Knowledge base ID
         #     user_id: User ID
         # Returns:
-        #     List containing knowledge base information
+        #     List containing dataset information
         kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
                                       ).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1)
         kbs = kbs.dicts()
@@ -502,12 +502,12 @@ class KnowledgebaseService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_kb_by_name(cls, kb_name, user_id):
-        # Get knowledge base by name and user ID
+        # Get dataset by name and user ID
         # Args:
         #     kb_name: Knowledge base name
         #     user_id: User ID
         # Returns:
-        #     List containing knowledge base information
+        #     List containing dataset information
         kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
                                       ).where(cls.model.name == kb_name, UserTenant.user_id == user_id).paginate(0, 1)
         kbs = kbs.dicts()

@@ -13,16 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import inspect
 import logging
+import queue
 import re
-from common.token_utils import num_tokens_from_string
+import threading
 from functools import partial
 from typing import Generator
-from common.constants import LLMType
 from api.db.db_models import LLM
 from api.db.services.common_service import CommonService
 from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService
+from common.constants import LLMType
+from common.token_utils import num_tokens_from_string


 class LLMService(CommonService):
@@ -31,6 +35,7 @@ class LLMService(CommonService):

 def get_init_tenant_llm(user_id):
     from common import settings

     tenant_llm = []

     model_configs = {
@@ -104,7 +109,7 @@ class LLMBundle(LLM4Tenant):

         llm_name = getattr(self, "llm_name", None)
         if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, llm_name):
-            logging.error("LLMBundle.encode can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
+            logging.error("LLMBundle.encode can't update token usage for <tenant redacted>/EMBEDDING used_tokens: {}".format(used_tokens))

         if self.langfuse:
             generation.update(usage_details={"total_tokens": used_tokens})
@@ -119,7 +124,7 @@ class LLMBundle(LLM4Tenant):
         emd, used_tokens = self.mdl.encode_queries(query)
         llm_name = getattr(self, "llm_name", None)
         if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, llm_name):
-            logging.error("LLMBundle.encode_queries can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
+            logging.error("LLMBundle.encode_queries can't update token usage for <tenant redacted>/EMBEDDING used_tokens: {}".format(used_tokens))

         if self.langfuse:
             generation.update(usage_details={"total_tokens": used_tokens})
@@ -183,6 +188,68 @@ class LLMBundle(LLM4Tenant):

         return txt

+    def stream_transcription(self, audio):
+        mdl = self.mdl
+        supports_stream = hasattr(mdl, "stream_transcription") and callable(getattr(mdl, "stream_transcription"))
+        if supports_stream:
+            if self.langfuse:
+                generation = self.langfuse.start_generation(
+                    trace_context=self.trace_context,
+                    name="stream_transcription",
+                    metadata={"model": self.llm_name},
+                )
+            final_text = ""
+            used_tokens = 0
+
+            try:
+                for evt in mdl.stream_transcription(audio):
+                    if evt.get("event") == "final":
+                        final_text = evt.get("text", "")
+
+                    yield evt
+
+            except Exception as e:
+                err = {"event": "error", "text": str(e)}
+                yield err
+                final_text = final_text or ""
+            finally:
+                if final_text:
+                    used_tokens = num_tokens_from_string(final_text)
+                    TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens)
+
+                if self.langfuse:
+                    generation.update(
+                        output={"output": final_text},
+                        usage_details={"total_tokens": used_tokens},
+                    )
+                    generation.end()
+
+            return
+
+        if self.langfuse:
+            generation = self.langfuse.start_generation(
+                trace_context=self.trace_context,
+                name="stream_transcription",
+                metadata={"model": self.llm_name},
+            )
+
+        full_text, used_tokens = mdl.transcription(audio)
+        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
+            logging.error(f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}")
+
+        if self.langfuse:
+            generation.update(
+                output={"output": full_text},
+                usage_details={"total_tokens": used_tokens},
+            )
+            generation.end()
+
+        yield {
+            "event": "final",
+            "text": full_text,
+            "streaming": False,
+        }
+
     def tts(self, text: str) -> Generator[bytes, None, None]:
         if self.langfuse:
             generation = self.langfuse.start_generation(trace_context=self.trace_context, name="tts", input={"text": text})
@@ -227,46 +294,132 @@ class LLMBundle(LLM4Tenant):
             return kwargs
         else:
             return {k: v for k, v in kwargs.items() if k in allowed_params}
-    def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str:
+
+    def _run_coroutine_sync(self, coro):
+        try:
+            asyncio.get_running_loop()
+        except RuntimeError:
+            return asyncio.run(coro)
+
+        result_queue: queue.Queue = queue.Queue()
+
+        def runner():
+            try:
+                result_queue.put((True, asyncio.run(coro)))
+            except Exception as e:
+                result_queue.put((False, e))
+
+        thread = threading.Thread(target=runner, daemon=True)
+        thread.start()
+        thread.join()
+
+        success, value = result_queue.get_nowait()
+        if success:
+            return value
+        raise value
+
+    def _sync_from_async_stream(self, async_gen_fn, *args, **kwargs):
+        result_queue: queue.Queue = queue.Queue()
+
+        def runner():
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+
+            async def consume():
+                try:
+                    async for item in async_gen_fn(*args, **kwargs):
+                        result_queue.put(item)
+                except Exception as e:
+                    result_queue.put(e)
+                finally:
+                    result_queue.put(StopIteration)
+
+            loop.run_until_complete(consume())
+            loop.close()
+
+        threading.Thread(target=runner, daemon=True).start()
+
+        while True:
+            item = result_queue.get()
+            if item is StopIteration:
+                break
+            if isinstance(item, Exception):
+                raise item
+            yield item
+
+    def _bridge_sync_stream(self, gen):
+        loop = asyncio.get_running_loop()
+        queue: asyncio.Queue = asyncio.Queue()
+
+        def worker():
+            try:
+                for item in gen:
+                    loop.call_soon_threadsafe(queue.put_nowait, item)
+            except Exception as e:
+                loop.call_soon_threadsafe(queue.put_nowait, e)
+            finally:
+                loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
+
+        threading.Thread(target=worker, daemon=True).start()
+        return queue
+
+    async def async_chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
+        if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_with_tools"):
+            base_fn = self.mdl.async_chat_with_tools
+        elif hasattr(self.mdl, "async_chat"):
+            base_fn = self.mdl.async_chat
+        else:
+            raise RuntimeError(f"Model {self.mdl} does not implement async_chat or async_chat_with_tools")
+
+        generation = None
         if self.langfuse:
             generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})

-        chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs)
-        if self.is_tools and self.mdl.is_tools:
-            chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs)
+        chat_partial = partial(base_fn, system, history, gen_conf)

         use_kwargs = self._clean_param(chat_partial, **kwargs)
-        txt, used_tokens = chat_partial(**use_kwargs)
-        txt = self._remove_reasoning_content(txt)

+        try:
+            txt, used_tokens = await chat_partial(**use_kwargs)
+        except Exception as e:
+            if generation:
+                generation.update(output={"error": str(e)})
+                generation.end()
+            raise
+
+        txt = self._remove_reasoning_content(txt)
         if not self.verbose_tool_use:
             txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)

-        if isinstance(txt, int) and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
-            logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
+        if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
+            logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))

-        if self.langfuse:
+        if generation:
             generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
             generation.end()

         return txt

-    def chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
+    async def async_chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
+        total_tokens = 0
+        ans = ""
+        if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_streamly_with_tools"):
+            stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None)
+        elif hasattr(self.mdl, "async_chat_streamly"):
+            stream_fn = getattr(self.mdl, "async_chat_streamly", None)
+        else:
+            raise RuntimeError(f"Model {self.mdl} does not implement async_chat or async_chat_with_tools")
+
+        generation = None
         if self.langfuse:
             generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})

-        ans = ""
-        chat_partial = partial(self.mdl.chat_streamly, system, history, gen_conf)
-        total_tokens = 0
-        if self.is_tools and self.mdl.is_tools:
-            chat_partial = partial(self.mdl.chat_streamly_with_tools, system, history, gen_conf)
-        use_kwargs = self._clean_param(chat_partial, **kwargs)
-        for txt in chat_partial(**use_kwargs):
-            if isinstance(txt, int):
-                total_tokens = txt
-                if self.langfuse:
-                    generation.update(output={"output": ans})
-                    generation.end()
-                break
-
-            if txt.endswith("</think>"):
+        if stream_fn:
+            chat_partial = partial(stream_fn, system, history, gen_conf)
+            use_kwargs = self._clean_param(chat_partial, **kwargs)
+            try:
+                async for txt in chat_partial(**use_kwargs):
+                    if isinstance(txt, int):
+                        total_tokens = txt
+                        break
+
+                    if txt.endswith("</think>"):
@@ -277,7 +430,14 @@ class LLMBundle(LLM4Tenant):

-                    ans += txt
-                    yield ans
-        if total_tokens > 0:
-            if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, txt, self.llm_name):
-                logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))
+                    ans += txt
+                    yield ans
+            except Exception as e:
+                if generation:
+                    generation.update(output={"error": str(e)})
+                    generation.end()
+                raise
+            if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
+                logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
+            if generation:
+                generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
+                generation.end()
+            return

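A minimal standalone sketch of the sync-to-async bridging pattern the new _run_coroutine_sync and _sync_from_async_stream helpers follow (queue hand-off between a worker thread running its own event loop and the calling thread). The toy async generator and function names here are illustrative, not part of the changeset:

import asyncio
import queue
import threading


async def _toy_stream(n):
    # hypothetical async generator standing in for an async chat stream
    for i in range(n):
        await asyncio.sleep(0)
        yield i


def run_coroutine_sync(coro):
    # Run a coroutine from sync code; if a loop is already running, run it in a helper thread.
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return asyncio.run(coro)
    box: queue.Queue = queue.Queue()
    t = threading.Thread(target=lambda: box.put(asyncio.run(coro)), daemon=True)
    t.start()
    t.join()
    return box.get_nowait()


def iter_async_stream(gen_fn, *args):
    # Drain an async generator from sync code via a worker thread and a queue.
    box: queue.Queue = queue.Queue()

    def runner():
        async def consume():
            async for item in gen_fn(*args):
                box.put(item)
            box.put(StopIteration)
        asyncio.run(consume())

    threading.Thread(target=runner, daemon=True).start()
    while True:
        item = box.get()
        if item is StopIteration:
            break
        yield item


print(run_coroutine_sync(asyncio.sleep(0, result="ok")))  # -> ok
print(list(iter_async_stream(_toy_stream, 3)))            # -> [0, 1, 2]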
api/db/services/memory_service.py (new file, 150 lines)
@@ -0,0 +1,150 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import List
+
+from api.apps import current_user
+from api.db.db_models import DB, Memory, User
+from api.db.services import duplicate_name
+from api.db.services.common_service import CommonService
+from api.utils.memory_utils import calculate_memory_type
+from api.constants import MEMORY_NAME_LIMIT
+from common.misc_utils import get_uuid
+from common.time_utils import get_format_time, current_timestamp
+
+
+class MemoryService(CommonService):
+    # Service class for manage memory operations
+    model = Memory
+
+    @classmethod
+    @DB.connection_context()
+    def get_by_memory_id(cls, memory_id: str):
+        return cls.model.select().where(cls.model.id == memory_id).first()
+
+    @classmethod
+    @DB.connection_context()
+    def get_with_owner_name_by_id(cls, memory_id: str):
+        fields = [
+            cls.model.id,
+            cls.model.name,
+            cls.model.avatar,
+            cls.model.tenant_id,
+            User.nickname.alias("owner_name"),
+            cls.model.memory_type,
+            cls.model.storage_type,
+            cls.model.embd_id,
+            cls.model.llm_id,
+            cls.model.permissions,
+            cls.model.description,
+            cls.model.memory_size,
+            cls.model.forgetting_policy,
+            cls.model.temperature,
+            cls.model.system_prompt,
+            cls.model.user_prompt
+        ]
+        memory = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)).where(
+            cls.model.id == memory_id
+        ).first()
+        return memory
+
+    @classmethod
+    @DB.connection_context()
+    def get_by_filter(cls, filter_dict: dict, keywords: str, page: int = 1, page_size: int = 50):
+        fields = [
+            cls.model.id,
+            cls.model.name,
+            cls.model.avatar,
+            cls.model.tenant_id,
+            User.nickname.alias("owner_name"),
+            cls.model.memory_type,
+            cls.model.storage_type,
+            cls.model.permissions,
+            cls.model.description
+        ]
+        memories = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id))
+        if filter_dict.get("tenant_id"):
+            memories = memories.where(cls.model.tenant_id.in_(filter_dict["tenant_id"]))
+        if filter_dict.get("memory_type"):
+            memory_type_int = calculate_memory_type(filter_dict["memory_type"])
+            memories = memories.where(cls.model.memory_type.bin_and(memory_type_int) > 0)
+        if filter_dict.get("storage_type"):
+            memories = memories.where(cls.model.storage_type == filter_dict["storage_type"])
+        if keywords:
+            memories = memories.where(cls.model.name.contains(keywords))
+        count = memories.count()
+        memories = memories.order_by(cls.model.update_time.desc())
+        memories = memories.paginate(page, page_size)
+
+        return list(memories.dicts()), count
+
+    @classmethod
+    @DB.connection_context()
+    def create_memory(cls, tenant_id: str, name: str, memory_type: List[str], embd_id: str, llm_id: str):
+        # Deduplicate name within tenant
+        memory_name = duplicate_name(
+            cls.query,
+            name=name,
+            tenant_id=tenant_id
+        )
+        if len(memory_name) > MEMORY_NAME_LIMIT:
+            return False, f"Memory name {memory_name} exceeds limit of {MEMORY_NAME_LIMIT}."
+
+        # build create dict
+        memory_info = {
+            "id": get_uuid(),
+            "name": memory_name,
+            "memory_type": calculate_memory_type(memory_type),
+            "tenant_id": tenant_id,
+            "embd_id": embd_id,
+            "llm_id": llm_id,
+            "create_time": current_timestamp(),
+            "create_date": get_format_time(),
+            "update_time": current_timestamp(),
+            "update_date": get_format_time(),
+        }
+        obj = cls.model(**memory_info).save(force_insert=True)
+
+        if not obj:
+            return False, "Could not create new memory."
+
+        db_row = cls.model.select().where(cls.model.id == memory_info["id"]).first()
+
+        return obj, db_row
+
+    @classmethod
+    @DB.connection_context()
+    def update_memory(cls, memory_id: str, update_dict: dict):
+        if not update_dict:
+            return 0
+        if "temperature" in update_dict and isinstance(update_dict["temperature"], str):
+            update_dict["temperature"] = float(update_dict["temperature"])
+        if "name" in update_dict:
+            update_dict["name"] = duplicate_name(
+                cls.query,
+                name=update_dict["name"],
+                tenant_id=current_user.id
+            )
+        update_dict.update({
+            "update_time": current_timestamp(),
+            "update_date": get_format_time()
+        })
+
+        return cls.model.update(update_dict).where(cls.model.id == memory_id).execute()
+
+    @classmethod
+    @DB.connection_context()
+    def delete_memory(cls, memory_id: str):
+        return cls.model.delete().where(cls.model.id == memory_id).execute()
@@ -121,7 +121,7 @@ class PipelineOperationLogService(CommonService):
         else:
             ok, kb_info = KnowledgebaseService.get_by_id(document.kb_id)
             if not ok:
-                raise RuntimeError(f"Cannot find knowledge base {document.kb_id} for referred_document {referred_document_id}")
+                raise RuntimeError(f"Cannot find dataset {document.kb_id} for referred_document {referred_document_id}")

             tenant_id = kb_info.tenant_id
             title = document.parser_id

@@ -76,7 +76,7 @@ class TaskService(CommonService):
        """Retrieve detailed task information by task ID.

        This method fetches comprehensive task details including associated document,
-       knowledge base, and tenant information. It also handles task retry logic and
+       dataset, and tenant information. It also handles task retry logic and
        progress updates.

        Args:
@@ -121,6 +121,13 @@
            .where(cls.model.id == task_id)
        )
        docs = list(docs.dicts())
+       # Assuming docs = list(docs.dicts())
+       if docs:
+           kb_config = docs[0]['kb_parser_config']  # Dict from Knowledgebase.parser_config
+           mineru_method = kb_config.get('mineru_parse_method', 'auto')
+           mineru_formula = kb_config.get('mineru_formula_enable', True)
+           mineru_table = kb_config.get('mineru_table_enable', True)
+           print(mineru_method, mineru_formula, mineru_table)
        if not docs:
            return None

@@ -14,15 +14,17 @@
 # limitations under the License.
 #
 import os
+import json
 import logging
+from peewee import IntegrityError
 from langfuse import Langfuse
 from common import settings
-from common.constants import LLMType
+from common.constants import MINERU_DEFAULT_CONFIG, MINERU_ENV_KEYS, LLMType
 from api.db.db_models import DB, LLMFactories, TenantLLM
 from api.db.services.common_service import CommonService
 from api.db.services.langfuse_service import TenantLangfuseService
 from api.db.services.user_service import TenantService
-from rag.llm import ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel, TTSModel
+from rag.llm import ChatModel, CvModel, EmbeddingModel, OcrModel, RerankModel, Seq2txtModel, TTSModel


 class LLMFactoriesService(CommonService):
@@ -104,6 +106,10 @@ class TenantLLMService(CommonService):
             mdlnm = tenant.rerank_id if not llm_name else llm_name
         elif llm_type == LLMType.TTS:
             mdlnm = tenant.tts_id if not llm_name else llm_name
+        elif llm_type == LLMType.OCR:
+            if not llm_name:
+                raise LookupError("OCR model name is required")
+            mdlnm = llm_name
         else:
             assert False, "LLM type error"

@@ -137,31 +143,31 @@ class TenantLLMService(CommonService):
             return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
                                                                base_url=model_config["api_base"])

-        if llm_type == LLMType.RERANK:
+        elif llm_type == LLMType.RERANK:
             if model_config["llm_factory"] not in RerankModel:
                 return None
             return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
                                                             base_url=model_config["api_base"])

-        if llm_type == LLMType.IMAGE2TEXT.value:
+        elif llm_type == LLMType.IMAGE2TEXT.value:
             if model_config["llm_factory"] not in CvModel:
                 return None
             return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang,
                                                         base_url=model_config["api_base"], **kwargs)

-        if llm_type == LLMType.CHAT.value:
+        elif llm_type == LLMType.CHAT.value:
             if model_config["llm_factory"] not in ChatModel:
                 return None
             return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
                                                           base_url=model_config["api_base"], **kwargs)

-        if llm_type == LLMType.SPEECH2TEXT:
+        elif llm_type == LLMType.SPEECH2TEXT:
             if model_config["llm_factory"] not in Seq2txtModel:
                 return None
             return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"],
                                                              model_name=model_config["llm_name"], lang=lang,
                                                              base_url=model_config["api_base"])
-        if llm_type == LLMType.TTS:
+        elif llm_type == LLMType.TTS:
             if model_config["llm_factory"] not in TTSModel:
                 return None
             return TTSModel[model_config["llm_factory"]](
@@ -169,6 +175,17 @@ class TenantLLMService(CommonService):
                 model_config["llm_name"],
                 base_url=model_config["api_base"],
             )
+
+        elif llm_type == LLMType.OCR:
+            if model_config["llm_factory"] not in OcrModel:
+                return None
+            return OcrModel[model_config["llm_factory"]](
+                key=model_config["api_key"],
+                model_name=model_config["llm_name"],
+                base_url=model_config.get("api_base", ""),
+                **kwargs,
+            )
+
         return None

     @classmethod
@@ -186,6 +203,7 @@ class TenantLLMService(CommonService):
             LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name,
             LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name,
             LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name,
+            LLMType.OCR.value: llm_name,
         }

         mdlnm = llm_map.get(llm_type)
@@ -218,6 +236,68 @@ class TenantLLMService(CommonService):
                                  ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
         return list(objs)

+    @classmethod
+    def _collect_mineru_env_config(cls) -> dict | None:
+        cfg = MINERU_DEFAULT_CONFIG
+        found = False
+        for key in MINERU_ENV_KEYS:
+            val = os.environ.get(key)
+            if val:
+                found = True
+                cfg[key] = val
+        return cfg if found else None
+
+    @classmethod
+    @DB.connection_context()
+    def ensure_mineru_from_env(cls, tenant_id: str) -> str | None:
+        """
+        Ensure a MinerU OCR model exists for the tenant if env variables are present.
+        Return the existing or newly created llm_name, or None if env not set.
+        """
+        cfg = cls._collect_mineru_env_config()
+        if not cfg:
+            return None
+
+        saved_mineru_models = cls.query(tenant_id=tenant_id, llm_factory="MinerU", model_type=LLMType.OCR.value)
+
+        def _parse_api_key(raw: str) -> dict:
+            try:
+                return json.loads(raw or "{}")
+            except Exception:
+                return {}
+
+        for item in saved_mineru_models:
+            api_cfg = _parse_api_key(item.api_key)
+            normalized = {k: api_cfg.get(k, MINERU_DEFAULT_CONFIG.get(k)) for k in MINERU_ENV_KEYS}
+            if normalized == cfg:
+                return item.llm_name
+
+        used_names = {item.llm_name for item in saved_mineru_models}
+        idx = 1
+        base_name = "mineru-from-env"
+        while True:
+            candidate = f"{base_name}-{idx}"
+            if candidate in used_names:
+                idx += 1
+                continue
+
+            try:
+                cls.save(
+                    tenant_id=tenant_id,
+                    llm_factory="MinerU",
+                    llm_name=candidate,
+                    model_type=LLMType.OCR.value,
+                    api_key=json.dumps(cfg),
+                    api_base="",
+                    max_tokens=0,
+                )
+                return candidate
+            except IntegrityError:
+                logging.warning("MinerU env model %s already exists for tenant %s, retry with next name", candidate, tenant_id)
+                used_names.add(candidate)
+                idx += 1
+                continue
+
     @classmethod
     @DB.connection_context()
     def delete_by_tenant_id(cls, tenant_id):

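A rough sketch of the dedup idea behind ensure_mineru_from_env: the env-derived config is normalized against the defaults and compared with what is already stored as the model's api_key JSON, so restarting with unchanged env vars reuses the existing record instead of creating mineru-from-env-2, -3, and so on. The key names and defaults below are placeholders, not the real MINERU_ENV_KEYS / MINERU_DEFAULT_CONFIG values:

import json
import os

# Hypothetical stand-ins; the real key list lives in common.constants.
ENV_KEYS = ["MINERU_ENDPOINT", "MINERU_TOKEN"]
DEFAULTS = {"MINERU_ENDPOINT": "", "MINERU_TOKEN": ""}


def collect_env_config() -> dict | None:
    # Pick up only the keys that are actually set in the environment.
    cfg = dict(DEFAULTS)
    found = False
    for key in ENV_KEYS:
        val = os.environ.get(key)
        if val:
            found = True
            cfg[key] = val
    return cfg if found else None


def already_saved(saved_api_keys: list[str], cfg: dict) -> bool:
    # Compare stored JSON api_key blobs against the normalized env config.
    for raw in saved_api_keys:
        try:
            stored = json.loads(raw or "{}")
        except Exception:
            stored = {}
        normalized = {k: stored.get(k, DEFAULTS.get(k)) for k in ENV_KEYS}
        if normalized == cfg:
            return True
    return False


os.environ["MINERU_ENDPOINT"] = "http://mineru.local"
cfg = collect_env_config()
print(already_saved([json.dumps(cfg)], cfg))  # -> True: the existing record would be reused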
@@ -25,13 +25,12 @@ import logging
 import os
 import signal
 import sys
-import time
 import traceback
 import threading
 import uuid
 import faulthandler

-from api.apps import app, smtp_mail_server
+from api.apps import app
 from api.db.runtime_config import RuntimeConfig
 from api.db.services.document_service import DocumentService
 from common.file_utils import get_project_base_directory
@@ -69,7 +68,7 @@ def signal_handler(sig, frame):
     logging.info("Received interrupt signal, shutting down...")
     shutdown_all_mcp_sessions()
     stop_event.set()
-    time.sleep(1)
+    stop_event.wait(1)
     sys.exit(0)

 if __name__ == '__main__':
@@ -144,18 +143,6 @@ if __name__ == '__main__':
     else:
         threading.Timer(1.0, delayed_start_update_progress).start()

-    # init smtp server
-    if settings.SMTP_CONF:
-        app.config["MAIL_SERVER"] = settings.MAIL_SERVER
-        app.config["MAIL_PORT"] = settings.MAIL_PORT
-        app.config["MAIL_USE_SSL"] = settings.MAIL_USE_SSL
-        app.config["MAIL_USE_TLS"] = settings.MAIL_USE_TLS
-        app.config["MAIL_USERNAME"] = settings.MAIL_USERNAME
-        app.config["MAIL_PASSWORD"] = settings.MAIL_PASSWORD
-        app.config["MAIL_DEFAULT_SENDER"] = settings.MAIL_DEFAULT_SENDER
-        smtp_mail_server.init_app(app)
-
     # start http server
     try:
         logging.info("RAGFlow HTTP server start...")
@@ -163,5 +150,5 @@ if __name__ == '__main__':
     except Exception:
         traceback.print_exc()
         stop_event.set()
-        time.sleep(1)
+        stop_event.wait(1)
         os.kill(os.getpid(), signal.SIGKILL)

@@ -14,6 +14,7 @@
 # limitations under the License.
 #
+import asyncio
 import functools
 import inspect
 import json
@@ -22,9 +23,9 @@ import os
 import time
 from copy import deepcopy
 from functools import wraps
+from typing import Any

 import requests
-import trio
 from quart import (
     Response,
     jsonify,
@@ -45,11 +46,40 @@ from common import settings
 requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)


-async def request_json():
+async def _coerce_request_data() -> dict:
+    """Fetch JSON body with sane defaults; fallback to form data."""
+    payload: Any = None
+    last_error: Exception | None = None
+
     try:
-        return await request.json
-    except Exception:
-        return {}
+        payload = await request.get_json(force=True, silent=True)
+    except Exception as e:
+        last_error = e
+        payload = None
+
+    if payload is None:
+        try:
+            form = await request.form
+            payload = form.to_dict()
+        except Exception as e:
+            last_error = e
+            payload = None
+
+    if payload is None:
+        if last_error is not None:
+            raise last_error
+        raise ValueError("No JSON body or form data found in request.")
+
+    if isinstance(payload, dict):
+        return payload or {}
+
+    if isinstance(payload, str):
+        raise AttributeError("'str' object has no attribute 'get'")
+
+    raise TypeError(f"Unsupported request payload type: {type(payload)!r}")
+
+
+async def get_request_json():
+    return await _coerce_request_data()
+
+
 def serialize_for_json(obj):
     """
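A minimal Quart route sketch of the same fallback order (JSON body first, then form fields); the route path, field name, and the simplified helper below are illustrative, not the repository's actual handler:

from quart import Quart, request

app = Quart(__name__)


async def coerce_request_data() -> dict:
    # JSON body first, then form fields, mirroring the fallback order above.
    payload = await request.get_json(force=True, silent=True)
    if payload is None:
        payload = (await request.form).to_dict()
    if not isinstance(payload, dict):
        raise TypeError(f"Unsupported request payload type: {type(payload)!r}")
    return payload


@app.post("/echo")
async def echo():
    data = await coerce_request_data()
    return {"name": data.get("name", "anonymous")}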
@@ -137,7 +167,7 @@ def validate_request(*args, **kwargs):
     def wrapper(func):
         @wraps(func)
         async def decorated_function(*_args, **_kwargs):
-            errs = process_args(await request.json or (await request.form).to_dict())
+            errs = process_args(await _coerce_request_data())
             if errs:
                 return get_json_result(code=RetCode.ARGUMENT_ERROR, message=errs)
             if inspect.iscoroutinefunction(func):
@@ -152,7 +182,7 @@ def validate_request(*args, **kwargs):
 def not_allowed_parameters(*params):
     def decorator(func):
         async def wrapper(*args, **kwargs):
-            input_arguments = await request.json or (await request.form).to_dict()
+            input_arguments = await _coerce_request_data()
             for param in params:
                 if param in input_arguments:
                     return get_json_result(code=RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed")
@@ -313,6 +343,10 @@ def get_parser_config(chunk_method, parser_config):
         chunk_method = "naive"

     # Define default configurations for each chunking method
+    base_defaults = {
+        "table_context_size": 0,
+        "image_context_size": 0,
+    }
     key_mapping = {
         "naive": {
             "layout_recognize": "DeepDOC",
@@ -365,16 +399,19 @@ def get_parser_config(chunk_method, parser_config):

     default_config = key_mapping[chunk_method]

-    # If no parser_config provided, return default
+    # If no parser_config provided, return default merged with base defaults
     if not parser_config:
-        return default_config
+        if default_config is None:
+            return deep_merge(base_defaults, {})
+        return deep_merge(base_defaults, default_config)

     # If parser_config is provided, merge with defaults to ensure required fields exist
     if default_config is None:
-        return parser_config
+        return deep_merge(base_defaults, parser_config)

     # Ensure raptor and graphrag fields have default values if not provided
-    merged_config = deep_merge(default_config, parser_config)
+    merged_config = deep_merge(base_defaults, default_config)
+    merged_config = deep_merge(merged_config, parser_config)

     return merged_config

@@ -644,18 +681,32 @@ async def is_strong_enough(chat_model, embedding_model):
     async def _is_strong_enough():
         nonlocal chat_model, embedding_model
         if embedding_model:
-            with trio.fail_after(10):
-                _ = await trio.to_thread.run_sync(lambda: embedding_model.encode(["Are you strong enough!?"]))
+            await asyncio.wait_for(
+                asyncio.to_thread(embedding_model.encode, ["Are you strong enough!?"]),
+                timeout=10
+            )
+
         if chat_model:
-            with trio.fail_after(30):
-                res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role": "user", "content": "Are you strong enough!?"}], {}))
-                if res.find("**ERROR**") >= 0:
+            res = await asyncio.wait_for(
+                chat_model.async_chat("Nothing special.", [{"role": "user", "content": "Are you strong enough!?"}]),
+                timeout=30
+            )
+            if "**ERROR**" in res:
                 raise Exception(res)

     # Pressure test for GraphRAG task
-    async with trio.open_nursery() as nursery:
-        for _ in range(count):
-            nursery.start_soon(_is_strong_enough)
+    tasks = [
+        asyncio.create_task(_is_strong_enough())
+        for _ in range(count)
+    ]
+    try:
+        await asyncio.gather(*tasks, return_exceptions=False)
+    except Exception as e:
+        logging.error(f"Pressure test failed: {e}")
+        for t in tasks:
+            t.cancel()
+        await asyncio.gather(*tasks, return_exceptions=True)
+        raise


 def get_allowed_llm_factories() -> list:

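A minimal self-contained sketch of the gather-and-cancel pattern the pressure test now uses: run all probes concurrently, and on the first failure cancel the rest before re-raising. The probe coroutine here is a stand-in for _is_strong_enough:

import asyncio


async def probe(i: int) -> str:
    # stand-in for one strength check against the chat/embedding model
    await asyncio.sleep(0.01)
    return f"ok-{i}"


async def pressure_test(count: int = 3):
    tasks = [asyncio.create_task(probe(i)) for i in range(count)]
    try:
        return await asyncio.gather(*tasks, return_exceptions=False)
    except Exception:
        for t in tasks:
            t.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        raise


print(asyncio.run(pressure_test()))  # -> ['ok-0', 'ok-1', 'ok-2']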
@@ -20,18 +20,18 @@ Reusable HTML email templates and registry.

 # Invitation email template
 INVITE_EMAIL_TMPL = """
-<p>Hi {{email}},</p>
-<p>{{inviter}} has invited you to join their team (ID: {{tenant_id}}).</p>
-<p>Click the link below to complete your registration:<br>
-<a href="{{invite_url}}">{{invite_url}}</a></p>
-<p>If you did not request this, please ignore this email.</p>
+Hi {{email}},
+{{inviter}} has invited you to join their team (ID: {{tenant_id}}).
+Click the link below to complete your registration:
+{{invite_url}}
+If you did not request this, please ignore this email.
 """

 # Password reset code template
 RESET_CODE_EMAIL_TMPL = """
-<p>Hello,</p>
-<p>Your password reset code is: <b>{{ code }}</b></p>
-<p>This code will expire in {{ ttl_min }} minutes.</p>
+Hello,
+Your password reset code is: {{ code }}
+This code will expire in {{ ttl_min }} minutes.
 """

 # Template registry
api/utils/memory_utils.py (new file, 54 lines)
@@ -0,0 +1,54 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import List
+from common.constants import MemoryType
+
+
+def format_ret_data_from_memory(memory):
+    return {
+        "id": memory.id,
+        "name": memory.name,
+        "avatar": memory.avatar,
+        "tenant_id": memory.tenant_id,
+        "owner_name": memory.owner_name if hasattr(memory, "owner_name") else None,
+        "memory_type": get_memory_type_human(memory.memory_type),
+        "storage_type": memory.storage_type,
+        "embd_id": memory.embd_id,
+        "llm_id": memory.llm_id,
+        "permissions": memory.permissions,
+        "description": memory.description,
+        "memory_size": memory.memory_size,
+        "forgetting_policy": memory.forgetting_policy,
+        "temperature": memory.temperature,
+        "system_prompt": memory.system_prompt,
+        "user_prompt": memory.user_prompt,
+        "create_time": memory.create_time,
+        "create_date": memory.create_date,
+        "update_time": memory.update_time,
+        "update_date": memory.update_date
+    }
+
+
+def get_memory_type_human(memory_type: int) -> List[str]:
+    return [mem_type.name.lower() for mem_type in MemoryType if memory_type & mem_type.value]
+
+
+def calculate_memory_type(memory_type_name_list: List[str]) -> int:
+    memory_type = 0
+    type_value_map = {mem_type.name.lower(): mem_type.value for mem_type in MemoryType}
+    for mem_type in memory_type_name_list:
+        if mem_type in type_value_map:
+            memory_type |= type_value_map[mem_type]
+    return memory_type
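A round-trip sketch of the memory_type bitmask helpers: type names are OR-ed into one integer and decoded back to names. The MemoryType enum below is a hypothetical stand-in, not the real enum from common.constants:

from enum import Enum
from typing import List


class MemoryType(Enum):
    # hypothetical members; the real values come from common.constants.MemoryType
    RAW = 1
    SUMMARY = 2
    PROFILE = 4


def calculate_memory_type(names: List[str]) -> int:
    value_map = {m.name.lower(): m.value for m in MemoryType}
    result = 0
    for name in names:
        if name in value_map:
            result |= value_map[name]
    return result


def get_memory_type_human(memory_type: int) -> List[str]:
    return [m.name.lower() for m in MemoryType if memory_type & m.value]


mask = calculate_memory_type(["raw", "profile"])
print(mask)                         # -> 5
print(get_memory_type_human(mask))  # -> ['raw', 'profile']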
@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
|
import string
|
||||||
from typing import Annotated, Any, Literal
|
from typing import Annotated, Any, Literal
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
@ -25,6 +26,7 @@ from pydantic import (
|
|||||||
StringConstraints,
|
StringConstraints,
|
||||||
ValidationError,
|
ValidationError,
|
||||||
field_validator,
|
field_validator,
|
||||||
|
model_validator,
|
||||||
)
|
)
|
||||||
from pydantic_core import PydanticCustomError
|
from pydantic_core import PydanticCustomError
|
||||||
from werkzeug.exceptions import BadRequest, UnsupportedMediaType
|
from werkzeug.exceptions import BadRequest, UnsupportedMediaType
|
||||||
@ -329,6 +331,7 @@ class RaptorConfig(Base):
|
|||||||
threshold: Annotated[float, Field(default=0.1, ge=0.0, le=1.0)]
|
threshold: Annotated[float, Field(default=0.1, ge=0.0, le=1.0)]
|
||||||
max_cluster: Annotated[int, Field(default=64, ge=1, le=1024)]
|
max_cluster: Annotated[int, Field(default=64, ge=1, le=1024)]
|
||||||
random_seed: Annotated[int, Field(default=0, ge=0)]
|
random_seed: Annotated[int, Field(default=0, ge=0)]
|
||||||
|
auto_disable_for_structured_data: Annotated[bool, Field(default=True)]
|
||||||
|
|
||||||
|
|
||||||
class GraphragConfig(Base):
|
class GraphragConfig(Base):
|
||||||
@ -361,10 +364,9 @@ class CreateDatasetReq(Base):
|
|||||||
description: Annotated[str | None, Field(default=None, max_length=65535)]
|
description: Annotated[str | None, Field(default=None, max_length=65535)]
|
||||||
embedding_model: Annotated[str | None, Field(default=None, max_length=255, serialization_alias="embd_id")]
|
embedding_model: Annotated[str | None, Field(default=None, max_length=255, serialization_alias="embd_id")]
|
||||||
permission: Annotated[Literal["me", "team"], Field(default="me", min_length=1, max_length=16)]
|
permission: Annotated[Literal["me", "team"], Field(default="me", min_length=1, max_length=16)]
|
||||||
chunk_method: Annotated[
|
chunk_method: Annotated[str | None, Field(default=None, serialization_alias="parser_id")]
|
||||||
Literal["naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"],
|
parse_type: Annotated[int | None, Field(default=None, ge=0, le=64)]
|
||||||
Field(default="naive", min_length=1, max_length=32, serialization_alias="parser_id"),
|
pipeline_id: Annotated[str | None, Field(default=None, min_length=32, max_length=32, serialization_alias="pipeline_id")]
|
||||||
]
|
|
||||||
parser_config: Annotated[ParserConfig | None, Field(default=None)]
|
parser_config: Annotated[ParserConfig | None, Field(default=None)]
|
||||||
|
|
||||||
@field_validator("avatar", mode="after")
|
@field_validator("avatar", mode="after")
|
||||||
@ -525,6 +527,93 @@ class CreateDatasetReq(Base):
|
|||||||
raise PydanticCustomError("string_too_long", "Parser config exceeds size limit (max 65,535 characters). Current size: {actual}", {"actual": len(json_str)})
|
raise PydanticCustomError("string_too_long", "Parser config exceeds size limit (max 65,535 characters). Current size: {actual}", {"actual": len(json_str)})
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
@field_validator("pipeline_id", mode="after")
|
||||||
|
@classmethod
|
||||||
|
def validate_pipeline_id(cls, v: str | None) -> str | None:
|
||||||
|
"""Validate pipeline_id as 32-char lowercase hex string if provided.
|
||||||
|
|
||||||
|
Rules:
|
||||||
|
- None or empty string: treat as None (not set)
|
||||||
|
- Must be exactly length 32
|
||||||
|
- Must contain only hex digits (0-9a-fA-F); normalized to lowercase
|
||||||
|
"""
|
||||||
|
if v is None:
|
||||||
|
return None
|
||||||
|
if v == "":
|
||||||
|
return None
|
||||||
|
if len(v) != 32:
|
||||||
|
raise PydanticCustomError("format_invalid", "pipeline_id must be 32 hex characters")
|
||||||
|
if any(ch not in string.hexdigits for ch in v):
|
||||||
|
raise PydanticCustomError("format_invalid", "pipeline_id must be hexadecimal")
|
||||||
|
return v.lower()
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_parser_dependency(self) -> "CreateDatasetReq":
|
||||||
|
"""
|
||||||
|
Mixed conditional validation:
|
||||||
|
- If parser_id is omitted (field not set):
|
||||||
|
* If both parse_type and pipeline_id are omitted → default chunk_method = "naive"
|
||||||
|
* If both parse_type and pipeline_id are provided → allow ingestion pipeline mode
|
||||||
|
- If parser_id is provided (valid enum) → parse_type and pipeline_id must be None (disallow mixed usage)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
PydanticCustomError with code 'dependency_error' on violation.
|
||||||
|
"""
|
||||||
|
# Omitted chunk_method (not in fields) logic
|
||||||
|
if self.chunk_method is None and "chunk_method" not in self.model_fields_set:
|
||||||
|
# All three absent → default naive
|
||||||
|
if self.parse_type is None and self.pipeline_id is None:
|
||||||
|
object.__setattr__(self, "chunk_method", "naive")
|
||||||
|
return self
|
||||||
|
# parser_id omitted: require BOTH parse_type & pipeline_id present (no partial allowed)
|
||||||
|
if self.parse_type is None or self.pipeline_id is None:
|
||||||
|
missing = []
|
||||||
|
if self.parse_type is None:
|
||||||
|
missing.append("parse_type")
|
||||||
|
if self.pipeline_id is None:
|
||||||
|
missing.append("pipeline_id")
|
||||||
|
raise PydanticCustomError(
|
||||||
|
"dependency_error",
|
||||||
|
"parser_id omitted → required fields missing: {fields}",
|
||||||
|
{"fields": ", ".join(missing)},
|
||||||
|
)
|
||||||
|
# Both provided → allow pipeline mode
|
||||||
|
return self
|
||||||
|
|
||||||
|
# parser_id provided (valid): MUST NOT have parse_type or pipeline_id
|
||||||
|
if isinstance(self.chunk_method, str):
|
||||||
|
if self.parse_type is not None or self.pipeline_id is not None:
|
||||||
|
invalid = []
|
||||||
|
if self.parse_type is not None:
|
||||||
|
invalid.append("parse_type")
|
||||||
|
if self.pipeline_id is not None:
|
||||||
|
invalid.append("pipeline_id")
|
||||||
|
raise PydanticCustomError(
|
||||||
|
"dependency_error",
|
||||||
|
"parser_id provided → disallowed fields present: {fields}",
|
||||||
|
{"fields": ", ".join(invalid)},
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
@field_validator("chunk_method", mode="wrap")
|
||||||
|
@classmethod
|
||||||
|
def validate_chunk_method(cls, v: Any, handler) -> Any:
|
||||||
|
"""Wrap validation to unify error messages, including type errors (e.g. list)."""
|
||||||
|
allowed = {"naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"}
|
||||||
|
error_msg = "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'"
|
||||||
|
# Omitted field: handler won't be invoked (wrap still gets value); None treated as explicit invalid
|
||||||
|
if v is None:
|
||||||
|
raise PydanticCustomError("literal_error", error_msg)
|
||||||
|
try:
|
||||||
|
# Run inner validation (type checking)
|
||||||
|
result = handler(v)
|
||||||
|
except Exception:
|
||||||
|
raise PydanticCustomError("literal_error", error_msg)
|
||||||
|
# After handler, enforce enumeration
|
||||||
|
if not isinstance(result, str) or result == "" or result not in allowed:
|
||||||
|
raise PydanticCustomError("literal_error", error_msg)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class UpdateDatasetReq(CreateDatasetReq):
    dataset_id: Annotated[str, Field(...)]
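The validators above enforce a three-way contract between chunk_method, parse_type and pipeline_id. For orientation, a self-contained Pydantic sketch of the same mixed-dependency pattern behaves as follows; the model name, field values and simplified error messages are illustrative, not the actual RAGFlow class:

# Hypothetical, simplified sketch of the dependency rule (not the RAGFlow model):
# chunk_method defaults to "naive" when nothing is set, and parse_type/pipeline_id
# must be supplied together when chunk_method is omitted.
from typing import Optional

from pydantic import BaseModel, model_validator
from pydantic_core import PydanticCustomError


class DatasetReqSketch(BaseModel):
    chunk_method: Optional[str] = None
    parse_type: Optional[str] = None
    pipeline_id: Optional[str] = None

    @model_validator(mode="after")
    def check_dependency(self) -> "DatasetReqSketch":
        if self.chunk_method is None and "chunk_method" not in self.model_fields_set:
            if self.parse_type is None and self.pipeline_id is None:
                object.__setattr__(self, "chunk_method", "naive")  # default mode
                return self
            if self.parse_type is None or self.pipeline_id is None:
                raise PydanticCustomError("dependency_error", "parse_type and pipeline_id must be provided together")
            return self
        if self.parse_type is not None or self.pipeline_id is not None:
            raise PydanticCustomError("dependency_error", "chunk_method excludes parse_type/pipeline_id")
        return self


print(DatasetReqSketch().chunk_method)  # "naive"
print(DatasetReqSketch(parse_type="pdf", pipeline_id="a" * 32).parse_type)  # pipeline mode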
@@ -20,9 +20,10 @@ import json
 import re
 import socket
 from urllib.parse import urlparse
+import aiosmtplib
-from api.apps import smtp_mail_server
+from email.mime.text import MIMEText
-from flask_mail import Message
+from email.header import Header
+from common import settings
 from quart import render_template_string
 from api.utils.email_templates import EMAIL_TEMPLATES
 from selenium import webdriver
@@ -35,11 +36,11 @@ from selenium.webdriver.support.ui import WebDriverWait
 from webdriver_manager.chrome import ChromeDriverManager


-OTP_LENGTH = 8
+OTP_LENGTH = 4
-OTP_TTL_SECONDS = 5 * 60
+OTP_TTL_SECONDS = 5 * 60  # valid for 5 minutes
-ATTEMPT_LIMIT = 5
+ATTEMPT_LIMIT = 5  # maximum attempts
-ATTEMPT_LOCK_SECONDS = 30 * 60
+ATTEMPT_LOCK_SECONDS = 30 * 60  # lock for 30 minutes
-RESEND_COOLDOWN_SECONDS = 60
+RESEND_COOLDOWN_SECONDS = 60  # cooldown for 1 minute


 CONTENT_TYPE_MAP = {
@@ -185,25 +186,32 @@ def get_float(req: dict, key: str, default: float | int = 10.0) -> float:
     return default


-def send_email_html(subject: str, to_email: str, template_key: str, **context):
-    """Generic HTML email sender using shared templates.
-
-    template_key must exist in EMAIL_TEMPLATES.
-    """
-    from api.apps import app
-    tmpl = EMAIL_TEMPLATES.get(template_key)
-    if not tmpl:
-        raise ValueError(f"Unknown email template: {template_key}")
-    with app.app_context():
-        msg = Message(subject=subject, recipients=[to_email])
-        msg.html = render_template_string(tmpl, **context)
-        smtp_mail_server.send(msg)
+async def send_email_html(to_email: str, subject: str, template_key: str, **context):
+    body = await render_template_string(EMAIL_TEMPLATES.get(template_key), **context)
+    msg = MIMEText(body, "plain", "utf-8")
+    msg["Subject"] = Header(subject, "utf-8")
+    msg["From"] = f"{settings.MAIL_DEFAULT_SENDER[0]} <{settings.MAIL_DEFAULT_SENDER[1]}>"
+    msg["To"] = to_email
+
+    smtp = aiosmtplib.SMTP(
+        hostname=settings.MAIL_SERVER,
+        port=settings.MAIL_PORT,
+        use_tls=True,
+        timeout=10,
+    )
+
+    await smtp.connect()
+    await smtp.login(settings.MAIL_USERNAME, settings.MAIL_PASSWORD)
+    await smtp.send_message(msg)
+    await smtp.quit()


-def send_invite_email(to_email, invite_url, tenant_id, inviter):
+async def send_invite_email(to_email, invite_url, tenant_id, inviter):
     # Reuse the generic HTML sender with 'invite' template
-    send_email_html(
-        subject="RAGFlow Invitation",
+    await send_email_html(
         to_email=to_email,
+        subject="RAGFlow Invitation",
         template_key="invite",
         email=to_email,
         invite_url=invite_url,
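Because send_email_html and send_invite_email are now coroutines, callers have to await them. A minimal calling-side sketch, assuming an asyncio entry point; the import path and all argument values are placeholders:

# Hypothetical calling-side sketch; the senders are awaited, values are illustrative.
import asyncio

from api.utils.web_utils import send_invite_email  # assumed module path


async def main():
    await send_invite_email(
        to_email="user@example.com",
        invite_url="https://ragflow.example.com/invite?code=...",
        tenant_id="tenant-0001",
        inviter="admin",
    )


if __name__ == "__main__":
    asyncio.run(main())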
@@ -19,7 +19,6 @@ import queue
 import threading
 from typing import Any, Callable, Coroutine, Optional, Type, Union
 import asyncio
-import trio
 from functools import wraps
 from quart import make_response, jsonify
 from common.constants import RetCode
@@ -70,11 +69,10 @@ def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception:
             for a in range(attempts):
                 try:
                     if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
-                        with trio.fail_after(seconds):
-                            return await func(*args, **kwargs)
+                        return await asyncio.wait_for(func(*args, **kwargs), timeout=seconds)
                     else:
                         return await func(*args, **kwargs)
-                except trio.TooSlowError:
+                except asyncio.TimeoutError:
                     if a < attempts - 1:
                         continue
                     if on_timeout is not None:
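The decorator now relies on asyncio.wait_for instead of trio.fail_after. A standalone sketch of the same retry-on-timeout pattern, with illustrative names rather than the RAGFlow decorator itself:

# Standalone sketch of retry-on-timeout built on asyncio.wait_for; seconds and
# attempts mirror the decorator above, everything else is illustrative.
import asyncio
from functools import wraps


def with_timeout(seconds: float, attempts: int = 2):
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            for a in range(attempts):
                try:
                    # wait_for cancels the awaited call and raises TimeoutError on expiry
                    return await asyncio.wait_for(func(*args, **kwargs), timeout=seconds)
                except asyncio.TimeoutError:
                    if a < attempts - 1:
                        continue  # retry
                    raise
        return wrapper
    return decorator


@with_timeout(seconds=0.1, attempts=2)
async def slow():
    await asyncio.sleep(1)

# asyncio.run(slow())  # would raise TimeoutError after two attempts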
@@ -49,6 +49,7 @@ class RetCode(IntEnum, CustomEnum):
     RUNNING = 106
     PERMISSION_ERROR = 108
     AUTHENTICATION_ERROR = 109
+    BAD_REQUEST = 400
     UNAUTHORIZED = 401
     SERVER_ERROR = 500
     FORBIDDEN = 403
@@ -72,6 +73,7 @@ class LLMType(StrEnum):
     IMAGE2TEXT = 'image2text'
     RERANK = 'rerank'
     TTS = 'tts'
+    OCR = 'ocr'


 class TaskStatus(StrEnum):
@@ -121,7 +123,7 @@ class FileSource(StrEnum):
     WEBDAV = "webdav"
     MOODLE = "moodle"
     DROPBOX = "dropbox"
+    BOX = "box"


 class PipelineTaskType(StrEnum):
     PARSE = "Parse"
@@ -147,6 +149,24 @@ class Storage(Enum):
     AWS_S3 = 4
     OSS = 5
     OPENDAL = 6
+    GCS = 7
+
+
+class MemoryType(Enum):
+    RAW = 0b0001  # 1 << 0 = 1 (0b00000001)
+    SEMANTIC = 0b0010  # 1 << 1 = 2 (0b00000010)
+    EPISODIC = 0b0100  # 1 << 2 = 4 (0b00000100)
+    PROCEDURAL = 0b1000  # 1 << 3 = 8 (0b00001000)
+
+
+class MemoryStorageType(StrEnum):
+    TABLE = "table"
+    GRAPH = "graph"
+
+
+class ForgettingPolicy(StrEnum):
+    FIFO = "fifo"
+
+
 # environment
 # ENV_STRONG_TEST_COUNT = "STRONG_TEST_COUNT"
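The new MemoryType members are distinct bits, so several memory kinds can presumably be combined into a single mask and tested with bitwise operators. A small illustrative sketch, not RAGFlow code:

# Sketch: MemoryType values are distinct bits, so a combined mask can be built
# with | and tested with & on the raw .value integers (assumed usage).
from enum import Enum


class MemoryType(Enum):
    RAW = 0b0001
    SEMANTIC = 0b0010
    EPISODIC = 0b0100
    PROCEDURAL = 0b1000


mask = MemoryType.RAW.value | MemoryType.EPISODIC.value  # 0b0101 == 5
print(bool(mask & MemoryType.EPISODIC.value))  # True
print(bool(mask & MemoryType.SEMANTIC.value))  # False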
@@ -197,3 +217,13 @@ PAGERANK_FLD = "pagerank_fea"
 SVR_QUEUE_NAME = "rag_flow_svr_queue"
 SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_task_broker"
 TAG_FLD = "tag_feas"
+
+
+MINERU_ENV_KEYS = ["MINERU_APISERVER", "MINERU_OUTPUT_DIR", "MINERU_BACKEND", "MINERU_SERVER_URL", "MINERU_DELETE_OUTPUT"]
+MINERU_DEFAULT_CONFIG = {
+    "MINERU_APISERVER": "",
+    "MINERU_OUTPUT_DIR": "",
+    "MINERU_BACKEND": "pipeline",
+    "MINERU_SERVER_URL": "",
+    "MINERU_DELETE_OUTPUT": 1,
+}
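A plausible way callers could resolve the MinerU settings is to read each key from the environment and fall back to these defaults; the merge below is an assumption for illustration, not RAGFlow code:

# Sketch (assumed usage): resolve MinerU config from the environment with defaults.
import os

from common.constants import MINERU_DEFAULT_CONFIG, MINERU_ENV_KEYS

mineru_config = {k: os.environ.get(k, MINERU_DEFAULT_CONFIG[k]) for k in MINERU_ENV_KEYS}
print(mineru_config["MINERU_BACKEND"])  # "pipeline" unless overridden in the environment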
common/crypto_utils.py (new file, 374 lines)
@@ -0,0 +1,374 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os

from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives import padding
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from cryptography.hazmat.primitives import hashes

class BaseCrypto:
    """Base class for cryptographic algorithms"""

    # Magic header to identify encrypted data
    ENCRYPTED_MAGIC = b'RAGF'

    def __init__(self, key, iv=None, block_size=16, key_length=32, iv_length=16):
        """
        Initialize cryptographic algorithm

        Args:
            key: Encryption key
            iv: Initialization vector, automatically generated if None
            block_size: Block size
            key_length: Key length
            iv_length: Initialization vector length
        """
        self.block_size = block_size
        self.key_length = key_length
        self.iv_length = iv_length

        # Normalize key
        self.key = self._normalize_key(key)
        self.iv = iv

    def _normalize_key(self, key):
        """Normalize key length"""
        if isinstance(key, str):
            key = key.encode('utf-8')

        # Use PBKDF2 for key derivation to ensure correct key length
        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=self.key_length,
            salt=b"ragflow_crypto_salt",  # Fixed salt to ensure consistent key derivation results
            iterations=100000,
            backend=default_backend()
        )

        return kdf.derive(key)

    def encrypt(self, data):
        """
        Encrypt data (template method)

        Args:
            data: Data to encrypt (bytes)

        Returns:
            Encrypted data (bytes), format: magic_header + iv + encrypted_data
        """
        # Generate random IV
        iv = os.urandom(self.iv_length) if not self.iv else self.iv

        # Use PKCS7 padding
        padder = padding.PKCS7(self.block_size * 8).padder()
        padded_data = padder.update(data) + padder.finalize()

        # Delegate to subclass for specific encryption
        ciphertext = self._encrypt(padded_data, iv)

        # Return Magic Header + IV + encrypted data
        return self.ENCRYPTED_MAGIC + iv + ciphertext

    def decrypt(self, encrypted_data):
        """
        Decrypt data (template method)

        Args:
            encrypted_data: Encrypted data (bytes)

        Returns:
            Decrypted data (bytes)
        """
        # Check if data is encrypted by magic header
        if not encrypted_data.startswith(self.ENCRYPTED_MAGIC):
            # Not encrypted, return as-is
            return encrypted_data

        # Remove magic header
        encrypted_data = encrypted_data[len(self.ENCRYPTED_MAGIC):]

        # Separate IV and encrypted data
        iv = encrypted_data[:self.iv_length]
        ciphertext = encrypted_data[self.iv_length:]

        # Delegate to subclass for specific decryption
        padded_data = self._decrypt(ciphertext, iv)

        # Remove padding
        unpadder = padding.PKCS7(self.block_size * 8).unpadder()
        data = unpadder.update(padded_data) + unpadder.finalize()

        return data

    def _encrypt(self, padded_data, iv):
        """
        Encrypt padded data with specific algorithm

        Args:
            padded_data: Padded data to encrypt
            iv: Initialization vector

        Returns:
            Encrypted data
        """
        raise NotImplementedError("_encrypt method must be implemented by subclass")

    def _decrypt(self, ciphertext, iv):
        """
        Decrypt ciphertext with specific algorithm

        Args:
            ciphertext: Ciphertext to decrypt
            iv: Initialization vector

        Returns:
            Decrypted padded data
        """
        raise NotImplementedError("_decrypt method must be implemented by subclass")

class AESCrypto(BaseCrypto):
    """Base class for AES cryptographic algorithm"""

    def __init__(self, key, iv=None, key_length=32):
        """
        Initialize AES cryptographic algorithm

        Args:
            key: Encryption key
            iv: Initialization vector, automatically generated if None
            key_length: Key length (16 for AES-128, 32 for AES-256)
        """
        super().__init__(key, iv, block_size=16, key_length=key_length, iv_length=16)

    def _encrypt(self, padded_data, iv):
        """AES encryption implementation"""
        # Create encryptor
        cipher = Cipher(
            algorithms.AES(self.key),
            modes.CBC(iv),
            backend=default_backend()
        )
        encryptor = cipher.encryptor()

        # Encrypt data
        return encryptor.update(padded_data) + encryptor.finalize()

    def _decrypt(self, ciphertext, iv):
        """AES decryption implementation"""
        # Create decryptor
        cipher = Cipher(
            algorithms.AES(self.key),
            modes.CBC(iv),
            backend=default_backend()
        )
        decryptor = cipher.decryptor()

        # Decrypt data
        return decryptor.update(ciphertext) + decryptor.finalize()


class AES128CBC(AESCrypto):
    """AES-128-CBC cryptographic algorithm"""

    def __init__(self, key, iv=None):
        """
        Initialize AES-128-CBC cryptographic algorithm

        Args:
            key: Encryption key
            iv: Initialization vector, automatically generated if None
        """
        super().__init__(key, iv, key_length=16)


class AES256CBC(AESCrypto):
    """AES-256-CBC cryptographic algorithm"""

    def __init__(self, key, iv=None):
        """
        Initialize AES-256-CBC cryptographic algorithm

        Args:
            key: Encryption key
            iv: Initialization vector, automatically generated if None
        """
        super().__init__(key, iv, key_length=32)

class SM4CBC(BaseCrypto):
    """SM4-CBC cryptographic algorithm using cryptography library for better performance"""

    def __init__(self, key, iv=None):
        """
        Initialize SM4-CBC cryptographic algorithm

        Args:
            key: Encryption key
            iv: Initialization vector, automatically generated if None
        """
        super().__init__(key, iv, block_size=16, key_length=16, iv_length=16)

    def _encrypt(self, padded_data, iv):
        """SM4 encryption implementation using cryptography library"""
        # Create encryptor
        cipher = Cipher(
            algorithms.SM4(self.key),
            modes.CBC(iv),
            backend=default_backend()
        )
        encryptor = cipher.encryptor()

        # Encrypt data
        return encryptor.update(padded_data) + encryptor.finalize()

    def _decrypt(self, ciphertext, iv):
        """SM4 decryption implementation using cryptography library"""
        # Create decryptor
        cipher = Cipher(
            algorithms.SM4(self.key),
            modes.CBC(iv),
            backend=default_backend()
        )
        decryptor = cipher.decryptor()

        # Decrypt data
        return decryptor.update(ciphertext) + decryptor.finalize()

class CryptoUtil:
    """Cryptographic utility class, using factory pattern to create cryptographic algorithm instances"""

    # Supported cryptographic algorithms mapping
    SUPPORTED_ALGORITHMS = {
        "aes-128-cbc": AES128CBC,
        "aes-256-cbc": AES256CBC,
        "sm4-cbc": SM4CBC
    }

    def __init__(self, algorithm="aes-256-cbc", key=None, iv=None):
        """
        Initialize cryptographic utility

        Args:
            algorithm: Cryptographic algorithm, default is aes-256-cbc
            key: Encryption key, uses RAGFLOW_CRYPTO_KEY environment variable if None
            iv: Initialization vector, automatically generated if None
        """
        if algorithm not in self.SUPPORTED_ALGORITHMS:
            raise ValueError(f"Unsupported algorithm: {algorithm}")

        if not key:
            raise ValueError("Encryption key not provided and RAGFLOW_CRYPTO_KEY environment variable not set")

        # Create cryptographic algorithm instance
        self.algorithm_name = algorithm
        self.crypto = self.SUPPORTED_ALGORITHMS[algorithm](key=key, iv=iv)

    def encrypt(self, data):
        """
        Encrypt data

        Args:
            data: Data to encrypt (bytes)

        Returns:
            Encrypted data (bytes)
        """
        # import time
        # start_time = time.time()
        encrypted = self.crypto.encrypt(data)
        # end_time = time.time()
        # logging.info(f"Encryption completed, data length: {len(data)} bytes, time: {(end_time - start_time)*1000:.2f} ms")
        return encrypted

    def decrypt(self, encrypted_data):
        """
        Decrypt data

        Args:
            encrypted_data: Encrypted data (bytes)

        Returns:
            Decrypted data (bytes)
        """
        # import time
        # start_time = time.time()
        decrypted = self.crypto.decrypt(encrypted_data)
        # end_time = time.time()
        # logging.info(f"Decryption completed, data length: {len(encrypted_data)} bytes, time: {(end_time - start_time)*1000:.2f} ms")
        return decrypted

# Test code
if __name__ == "__main__":
    # Test AES encryption
    crypto = CryptoUtil(algorithm="aes-256-cbc", key="test_key_123456")
    test_data = b"Hello, RAGFlow! This is a test for encryption."

    encrypted = crypto.encrypt(test_data)
    decrypted = crypto.decrypt(encrypted)

    print("AES Test:")
    print(f"Original: {test_data}")
    print(f"Encrypted: {encrypted}")
    print(f"Decrypted: {decrypted}")
    print(f"Success: {test_data == decrypted}")
    print()

    # Test SM4 encryption
    try:
        crypto_sm4 = CryptoUtil(algorithm="sm4-cbc", key="test_key_123456")
        encrypted_sm4 = crypto_sm4.encrypt(test_data)
        decrypted_sm4 = crypto_sm4.decrypt(encrypted_sm4)

        print("SM4 Test:")
        print(f"Original: {test_data}")
        print(f"Encrypted: {encrypted_sm4}")
        print(f"Decrypted: {decrypted_sm4}")
        print(f"Success: {test_data == decrypted_sm4}")
    except Exception as e:
        print(f"SM4 Test Failed: {e}")
        import traceback
        traceback.print_exc()

    # Test with specific algorithm classes directly
    print("\nDirect Algorithm Class Test:")

    # Test AES-128-CBC
    aes128 = AES128CBC(key="test_key_123456")
    encrypted_aes128 = aes128.encrypt(test_data)
    decrypted_aes128 = aes128.decrypt(encrypted_aes128)
    print(f"AES-128-CBC test: {'passed' if decrypted_aes128 == test_data else 'failed'}")

    # Test AES-256-CBC
    aes256 = AES256CBC(key="test_key_123456")
    encrypted_aes256 = aes256.encrypt(test_data)
    decrypted_aes256 = aes256.decrypt(encrypted_aes256)
    print(f"AES-256-CBC test: {'passed' if decrypted_aes256 == test_data else 'failed'}")

    # Test SM4-CBC
    try:
        sm4 = SM4CBC(key="test_key_123456")
        encrypted_sm4 = sm4.encrypt(test_data)
        decrypted_sm4 = sm4.decrypt(encrypted_sm4)
        print(f"SM4-CBC test: {'passed' if decrypted_sm4 == test_data else 'failed'}")
    except Exception as e:
        print(f"SM4-CBC test failed: {e}")
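For reference, the blob produced by BaseCrypto.encrypt is the 4-byte b'RAGF' magic header, then the IV, then the ciphertext, and decrypt passes data without that prefix through unchanged. A small illustrative helper that splits a blob along this layout; it is a hypothetical sketch, not part of the module above:

# Sketch: split an encrypted blob by the documented layout
# (4-byte b'RAGF' magic + 16-byte IV + ciphertext); purely illustrative.
MAGIC = b'RAGF'
IV_LENGTH = 16


def split_blob(blob: bytes):
    if not blob.startswith(MAGIC):
        return None, None, blob  # plaintext passes through unchanged
    body = blob[len(MAGIC):]
    return MAGIC, body[:IV_LENGTH], body[IV_LENGTH:]  # magic, iv, ciphertext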
@@ -1,6 +1,26 @@
 """
 Thanks to https://github.com/onyx-dot-app/onyx
+
+Content of this directory is under the "MIT Expat" license as defined below.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
 """

 from .blob_connector import BlobStorageConnector
Some files were not shown because too many files have changed in this diff.