mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-01-04 03:25:30 +08:00
Compare commits
148 Commits
209b731541
...
pipeline
| Author | SHA1 | Date | |
|---|---|---|---|
| 32dbed36e3 | |||
| 7f62ab8eb3 | |||
| e87987785c | |||
| b3b0be832a | |||
| 20b577a72c | |||
| 4d6ff672eb | |||
| fb19e24f8a | |||
| 9989e06abb | |||
| c49e81882c | |||
| 63cdce660e | |||
| 8bc8126848 | |||
| 71f69cdb75 | |||
| 664bc0b961 | |||
| f4cc4dbd30 | |||
| cce361d774 | |||
| 7a63b6386e | |||
| 4996dcb0eb | |||
| 3521eb61fe | |||
| 6b9b785b5c | |||
| 4c0a89f262 | |||
| 76b1ee2a00 | |||
| 771a38434f | |||
| 886d38620e | |||
| c7efaab30e | |||
| ff49454501 | |||
| 14273b4595 | |||
| abe7132630 | |||
| c1151519a0 | |||
| a1147ce609 | |||
| d907e79893 | |||
| 1b19d302c5 | |||
| 840b2b5809 | |||
| a6039cf563 | |||
| 8be7380b79 | |||
| afb8a84f7b | |||
| 6bf0cda16f | |||
| 5715ca6b74 | |||
| 8f465525f7 | |||
| f20dca2895 | |||
| 0c557e37ad | |||
| d0bfe8b10c | |||
| 28afc7e67d | |||
| 73c33bc8d2 | |||
| 476852e8f1 | |||
| e6cf00cb33 | |||
| d039d1e73d | |||
| d050ef568d | |||
| 028c2d83e9 | |||
| b5d6a6e8f2 | |||
| 5dfdbcce3a | |||
| 4fae40f66a | |||
| a1b947ffd6 | |||
| f9c7404bee | |||
| 5c1791d7f0 | |||
| e82617f6de | |||
| a7abc57f68 | |||
| cf1f523d03 | |||
| ccb255919a | |||
| b68c84b52e | |||
| 93cf0258c3 | |||
| b79fef1ca8 | |||
| 2b50de3186 | |||
| d8ef22db68 | |||
| 592f3b1555 | |||
| 3404469e2a | |||
| 63d7382dc9 | |||
| 179091b1a4 | |||
| d14d92a900 | |||
| 1936ad82d2 | |||
| 8a09f07186 | |||
| df8d31451b | |||
| fc95d113c3 | |||
| 7d14455fbe | |||
| bbe6ed3b90 | |||
| 127af4e45c | |||
| 41cdba19ba | |||
| 0d9c1f1c3c | |||
| e650f0d368 | |||
| 067b4fc012 | |||
| 38ff2ffc01 | |||
| a9cc992d13 | |||
| 5cf2c97908 | |||
| 81fede0041 | |||
| 07a83f93d5 | |||
| 1a904edd94 | |||
| 906969fe4e | |||
| 776ea078a6 | |||
| fcdde26a7f | |||
| 79076ffb5f | |||
| e8dcdfb9f0 | |||
| c4f43a395d | |||
| a255c78b59 | |||
| 936f27e9e5 | |||
| 2616f651c9 | |||
| e8018fde83 | |||
| f514482c0a | |||
| e9ee9269f5 | |||
| cf18231713 | |||
| f48aed6d4a | |||
| b524cf0ec8 | |||
| 994517495f | |||
| 63781bde3f | |||
| 91d6fb8061 | |||
| 45f52e85d7 | |||
| 9aa8cfb73a | |||
| 79ca25ec7e | |||
| 6ff7cfe005 | |||
| 4e16936fa4 | |||
| 677c99b090 | |||
| 8e30a75e5c | |||
| b14052e5a2 | |||
| ddaed541ff | |||
| 1ee9c0b8d9 | |||
| 9b724b3b5e | |||
| 3b1ee769eb | |||
| 41cb94324a | |||
| 982ec24fa7 | |||
| 1f7a035340 | |||
| d04ae3f943 | |||
| abd19b0f48 | |||
| aa1251af9a | |||
| 483f3aa71d | |||
| 72bb79e8dd | |||
| 927a195008 | |||
| d13dc0c24d | |||
| 37ac7576f1 | |||
| c832e0b858 | |||
| 5d015e48c1 | |||
| b58e882eaa | |||
| 1bc33009c7 | |||
| cb731dce34 | |||
| 1595cdc48f | |||
| 4179ecd469 | |||
| cb14dafaca | |||
| c2567844ea | |||
| 757c5376be | |||
| 79968c37a8 | |||
| 2e00d8d3d4 | |||
| 0b456a18a3 | |||
| dd8e660f0a | |||
| 98ee3dee74 | |||
| d4b0cd8599 | |||
| 3398dac906 | |||
| 7eb25e0de6 | |||
| bed77ee28f | |||
| 56cd576876 | |||
| 4fbad2828c | |||
| e997bf6507 |
8
.github/workflows/release.yml
vendored
8
.github/workflows/release.yml
vendored
@ -88,7 +88,9 @@ jobs:
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
tags: infiniflow/ragflow:${{ env.RELEASE_TAG }}
|
||||
tags: |
|
||||
infiniflow/ragflow:${{ env.RELEASE_TAG }}
|
||||
infiniflow/ragflow:latest-full
|
||||
file: Dockerfile
|
||||
platforms: linux/amd64
|
||||
|
||||
@ -98,7 +100,9 @@ jobs:
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
tags: infiniflow/ragflow:${{ env.RELEASE_TAG }}-slim
|
||||
tags: |
|
||||
infiniflow/ragflow:${{ env.RELEASE_TAG }}-slim
|
||||
infiniflow/ragflow:latest-slim
|
||||
file: Dockerfile
|
||||
build-args: LIGHTEN=1
|
||||
platforms: linux/amd64
|
||||
|
||||
1
.github/workflows/tests.yml
vendored
1
.github/workflows/tests.yml
vendored
@ -67,6 +67,7 @@ jobs:
|
||||
|
||||
- name: Start ragflow:nightly-slim
|
||||
run: |
|
||||
sudo docker compose -f docker/docker-compose.yml down --volumes --remove-orphans
|
||||
echo -e "\nRAGFLOW_IMAGE=infiniflow/ragflow:nightly-slim" >> docker/.env
|
||||
sudo docker compose -f docker/docker-compose.yml up -d
|
||||
|
||||
|
||||
17
README.md
17
README.md
@ -22,7 +22,7 @@
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
|
||||
</a>
|
||||
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.20.4">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.20.5">
|
||||
</a>
|
||||
<a href="https://github.com/infiniflow/ragflow/releases/latest">
|
||||
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
|
||||
@ -71,10 +71,7 @@
|
||||
|
||||
## 💡 What is RAGFlow?
|
||||
|
||||
[RAGFlow](https://ragflow.io/) is an open-source RAG (Retrieval-Augmented Generation) engine based on deep document
|
||||
understanding. It offers a streamlined RAG workflow for businesses of any scale, combining LLM (Large Language Models)
|
||||
to provide truthful question-answering capabilities, backed by well-founded citations from various complex formatted
|
||||
data.
|
||||
[RAGFlow](https://ragflow.io/) is a leading open-source Retrieval-Augmented Generation (RAG) engine that fuses cutting-edge RAG with Agent capabilities to create a superior context layer for LLMs. It offers a streamlined RAG workflow adaptable to enterprises of any scale. Powered by a converged context engine and pre-built agent templates, RAGFlow enables developers to transform complex data into high-fidelity, production-ready AI systems with exceptional efficiency and precision.
|
||||
|
||||
## 🎮 Demo
|
||||
|
||||
@ -190,7 +187,7 @@ releases! 🌟
|
||||
> All Docker images are built for x86 platforms. We don't currently offer Docker images for ARM64.
|
||||
> If you are on an ARM64 platform, follow [this guide](https://ragflow.io/docs/dev/build_docker_image) to build a Docker image compatible with your system.
|
||||
|
||||
> The command below downloads the `v0.20.4-slim` edition of the RAGFlow Docker image. See the following table for descriptions of different RAGFlow editions. To download a RAGFlow edition different from `v0.20.4-slim`, update the `RAGFLOW_IMAGE` variable accordingly in **docker/.env** before using `docker compose` to start the server. For example: set `RAGFLOW_IMAGE=infiniflow/ragflow:v0.20.4` for the full edition `v0.20.4`.
|
||||
> The command below downloads the `v0.20.5-slim` edition of the RAGFlow Docker image. See the following table for descriptions of different RAGFlow editions. To download a RAGFlow edition different from `v0.20.5-slim`, update the `RAGFLOW_IMAGE` variable accordingly in **docker/.env** before using `docker compose` to start the server. For example: set `RAGFLOW_IMAGE=infiniflow/ragflow:v0.20.5` for the full edition `v0.20.5`.
|
||||
|
||||
```bash
|
||||
$ cd ragflow/docker
|
||||
@ -203,8 +200,8 @@ releases! 🌟
|
||||
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
|
||||
|-------------------|-----------------|-----------------------|--------------------------|
|
||||
| v0.20.4 | ≈9 | :heavy_check_mark: | Stable release |
|
||||
| v0.20.4-slim | ≈2 | ❌ | Stable release |
|
||||
| v0.20.5 | ≈9 | :heavy_check_mark: | Stable release |
|
||||
| v0.20.5-slim | ≈2 | ❌ | Stable release |
|
||||
| nightly | ≈9 | :heavy_check_mark: | _Unstable_ nightly build |
|
||||
| nightly-slim | ≈2 | ❌ | _Unstable_ nightly build |
|
||||
|
||||
@ -348,8 +345,10 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly
|
||||
sudo apt-get install libjemalloc-dev
|
||||
# centos
|
||||
sudo yum install jemalloc
|
||||
# mac
|
||||
sudo brew install jemalloc
|
||||
```
|
||||
|
||||
|
||||
6. Launch backend service:
|
||||
|
||||
```bash
|
||||
|
||||
12
README_id.md
12
README_id.md
@ -22,7 +22,7 @@
|
||||
<img alt="Lencana Daring" src="https://img.shields.io/badge/Online-Demo-4e6b99">
|
||||
</a>
|
||||
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.20.4">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.20.5">
|
||||
</a>
|
||||
<a href="https://github.com/infiniflow/ragflow/releases/latest">
|
||||
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Rilis%20Terbaru" alt="Rilis Terbaru">
|
||||
@ -67,7 +67,7 @@
|
||||
|
||||
## 💡 Apa Itu RAGFlow?
|
||||
|
||||
[RAGFlow](https://ragflow.io/) adalah mesin RAG (Retrieval-Augmented Generation) open-source berbasis pemahaman dokumen yang mendalam. Platform ini menyediakan alur kerja RAG yang efisien untuk bisnis dengan berbagai skala, menggabungkan LLM (Large Language Models) untuk menyediakan kemampuan tanya-jawab yang benar dan didukung oleh referensi dari data terstruktur kompleks.
|
||||
[RAGFlow](https://ragflow.io/) adalah mesin RAG (Retrieval-Augmented Generation) open-source terkemuka yang mengintegrasikan teknologi RAG mutakhir dengan kemampuan Agent untuk menciptakan lapisan kontekstual superior bagi LLM. Menyediakan alur kerja RAG yang efisien dan dapat diadaptasi untuk perusahaan segala skala. Didukung oleh mesin konteks terkonvergensi dan template Agent yang telah dipra-bangun, RAGFlow memungkinkan pengembang mengubah data kompleks menjadi sistem AI kesetiaan-tinggi dan siap-produksi dengan efisiensi dan presisi yang luar biasa.
|
||||
|
||||
## 🎮 Demo
|
||||
|
||||
@ -181,7 +181,7 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io).
|
||||
> Semua gambar Docker dibangun untuk platform x86. Saat ini, kami tidak menawarkan gambar Docker untuk ARM64.
|
||||
> Jika Anda menggunakan platform ARM64, [silakan gunakan panduan ini untuk membangun gambar Docker yang kompatibel dengan sistem Anda](https://ragflow.io/docs/dev/build_docker_image).
|
||||
|
||||
> Perintah di bawah ini mengunduh edisi v0.20.4-slim dari gambar Docker RAGFlow. Silakan merujuk ke tabel berikut untuk deskripsi berbagai edisi RAGFlow. Untuk mengunduh edisi RAGFlow yang berbeda dari v0.20.4-slim, perbarui variabel RAGFLOW_IMAGE di docker/.env sebelum menggunakan docker compose untuk memulai server. Misalnya, atur RAGFLOW_IMAGE=infiniflow/ragflow:v0.20.4 untuk edisi lengkap v0.20.4.
|
||||
> Perintah di bawah ini mengunduh edisi v0.20.5-slim dari gambar Docker RAGFlow. Silakan merujuk ke tabel berikut untuk deskripsi berbagai edisi RAGFlow. Untuk mengunduh edisi RAGFlow yang berbeda dari v0.20.5-slim, perbarui variabel RAGFLOW_IMAGE di docker/.env sebelum menggunakan docker compose untuk memulai server. Misalnya, atur RAGFLOW_IMAGE=infiniflow/ragflow:v0.20.5 untuk edisi lengkap v0.20.5.
|
||||
|
||||
```bash
|
||||
$ cd ragflow/docker
|
||||
@ -194,8 +194,8 @@ $ docker compose -f docker-compose.yml up -d
|
||||
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
|
||||
| ----------------- | --------------- | --------------------- | ------------------------ |
|
||||
| v0.20.4 | ≈9 | :heavy_check_mark: | Stable release |
|
||||
| v0.20.4-slim | ≈2 | ❌ | Stable release |
|
||||
| v0.20.5 | ≈9 | :heavy_check_mark: | Stable release |
|
||||
| v0.20.5-slim | ≈2 | ❌ | Stable release |
|
||||
| nightly | ≈9 | :heavy_check_mark: | _Unstable_ nightly build |
|
||||
| nightly-slim | ≈2 | ❌ | _Unstable_ nightly build |
|
||||
|
||||
@ -312,6 +312,8 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly
|
||||
sudo apt-get install libjemalloc-dev
|
||||
# centos
|
||||
sudo yum install jemalloc
|
||||
# mac
|
||||
sudo brew install jemalloc
|
||||
```
|
||||
|
||||
6. Jalankan aplikasi backend:
|
||||
|
||||
14
README_ja.md
14
README_ja.md
@ -22,7 +22,7 @@
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
|
||||
</a>
|
||||
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.20.4">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.20.5">
|
||||
</a>
|
||||
<a href="https://github.com/infiniflow/ragflow/releases/latest">
|
||||
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
|
||||
@ -47,7 +47,7 @@
|
||||
|
||||
## 💡 RAGFlow とは?
|
||||
|
||||
[RAGFlow](https://ragflow.io/) は、深い文書理解に基づいたオープンソースの RAG (Retrieval-Augmented Generation) エンジンである。LLM(大規模言語モデル)を組み合わせることで、様々な複雑なフォーマットのデータから根拠のある引用に裏打ちされた、信頼できる質問応答機能を実現し、あらゆる規模のビジネスに適した RAG ワークフローを提供します。
|
||||
[RAGFlow](https://ragflow.io/) は、先進的なRAG(Retrieval-Augmented Generation)技術と Agent 機能を融合し、大規模言語モデル(LLM)に優れたコンテキスト層を構築する最先端のオープンソース RAG エンジンです。あらゆる規模の企業に対応可能な合理化された RAG ワークフローを提供し、統合型コンテキストエンジンと事前構築されたAgentテンプレートにより、開発者が複雑なデータを驚異的な効率性と精度で高精細なプロダクションレディAIシステムへ変換することを可能にします。
|
||||
|
||||
## 🎮 Demo
|
||||
|
||||
@ -160,7 +160,7 @@
|
||||
> 現在、公式に提供されているすべての Docker イメージは x86 アーキテクチャ向けにビルドされており、ARM64 用の Docker イメージは提供されていません。
|
||||
> ARM64 アーキテクチャのオペレーティングシステムを使用している場合は、[このドキュメント](https://ragflow.io/docs/dev/build_docker_image)を参照して Docker イメージを自分でビルドしてください。
|
||||
|
||||
> 以下のコマンドは、RAGFlow Docker イメージの v0.20.4-slim エディションをダウンロードします。異なる RAGFlow エディションの説明については、以下の表を参照してください。v0.20.4-slim とは異なるエディションをダウンロードするには、docker/.env ファイルの RAGFLOW_IMAGE 変数を適宜更新し、docker compose を使用してサーバーを起動してください。例えば、完全版 v0.20.4 をダウンロードするには、RAGFLOW_IMAGE=infiniflow/ragflow:v0.20.4 と設定します。
|
||||
> 以下のコマンドは、RAGFlow Docker イメージの v0.20.5-slim エディションをダウンロードします。異なる RAGFlow エディションの説明については、以下の表を参照してください。v0.20.5-slim とは異なるエディションをダウンロードするには、docker/.env ファイルの RAGFLOW_IMAGE 変数を適宜更新し、docker compose を使用してサーバーを起動してください。例えば、完全版 v0.20.5 をダウンロードするには、RAGFLOW_IMAGE=infiniflow/ragflow:v0.20.5 と設定します。
|
||||
|
||||
```bash
|
||||
$ cd ragflow/docker
|
||||
@ -173,8 +173,8 @@
|
||||
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
|
||||
| ----------------- | --------------- | --------------------- | ------------------------ |
|
||||
| v0.20.4 | ≈9 | :heavy_check_mark: | Stable release |
|
||||
| v0.20.4-slim | ≈2 | ❌ | Stable release |
|
||||
| v0.20.5 | ≈9 | :heavy_check_mark: | Stable release |
|
||||
| v0.20.5-slim | ≈2 | ❌ | Stable release |
|
||||
| nightly | ≈9 | :heavy_check_mark: | _Unstable_ nightly build |
|
||||
| nightly-slim | ≈2 | ❌ | _Unstable_ nightly build |
|
||||
|
||||
@ -301,12 +301,14 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly
|
||||
```
|
||||
|
||||
5. オペレーティングシステムにjemallocがない場合は、次のようにインストールします:
|
||||
|
||||
|
||||
```bash
|
||||
# ubuntu
|
||||
sudo apt-get install libjemalloc-dev
|
||||
# centos
|
||||
sudo yum install jemalloc
|
||||
# mac
|
||||
sudo brew install jemalloc
|
||||
```
|
||||
|
||||
6. バックエンドサービスを起動する:
|
||||
|
||||
14
README_ko.md
14
README_ko.md
@ -22,7 +22,7 @@
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
|
||||
</a>
|
||||
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.20.4">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.20.5">
|
||||
</a>
|
||||
<a href="https://github.com/infiniflow/ragflow/releases/latest">
|
||||
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
|
||||
@ -47,7 +47,7 @@
|
||||
|
||||
## 💡 RAGFlow란?
|
||||
|
||||
[RAGFlow](https://ragflow.io/)는 심층 문서 이해에 기반한 오픈소스 RAG (Retrieval-Augmented Generation) 엔진입니다. 이 엔진은 대규모 언어 모델(LLM)과 결합하여 정확한 질문 응답 기능을 제공하며, 다양한 복잡한 형식의 데이터에서 신뢰할 수 있는 출처를 바탕으로 한 인용을 통해 이를 뒷받침합니다. RAGFlow는 규모에 상관없이 모든 기업에 최적화된 RAG 워크플로우를 제공합니다.
|
||||
[RAGFlow](https://ragflow.io/) 는 최첨단 RAG(Retrieval-Augmented Generation)와 Agent 기능을 융합하여 대규모 언어 모델(LLM)을 위한 우수한 컨텍스트 계층을 생성하는 선도적인 오픈소스 RAG 엔진입니다. 모든 규모의 기업에 적용 가능한 효율적인 RAG 워크플로를 제공하며, 통합 컨텍스트 엔진과 사전 구축된 Agent 템플릿을 통해 개발자들이 복잡한 데이터를 예외적인 효율성과 정밀도로 고급 구현도의 프로덕션 준비 완료 AI 시스템으로 변환할 수 있도록 지원합니다.
|
||||
|
||||
## 🎮 데모
|
||||
|
||||
@ -160,7 +160,7 @@
|
||||
> 모든 Docker 이미지는 x86 플랫폼을 위해 빌드되었습니다. 우리는 현재 ARM64 플랫폼을 위한 Docker 이미지를 제공하지 않습니다.
|
||||
> ARM64 플랫폼을 사용 중이라면, [시스템과 호환되는 Docker 이미지를 빌드하려면 이 가이드를 사용해 주세요](https://ragflow.io/docs/dev/build_docker_image).
|
||||
|
||||
> 아래 명령어는 RAGFlow Docker 이미지의 v0.20.4-slim 버전을 다운로드합니다. 다양한 RAGFlow 버전에 대한 설명은 다음 표를 참조하십시오. v0.20.4-slim과 다른 RAGFlow 버전을 다운로드하려면, docker/.env 파일에서 RAGFLOW_IMAGE 변수를 적절히 업데이트한 후 docker compose를 사용하여 서버를 시작하십시오. 예를 들어, 전체 버전인 v0.20.4을 다운로드하려면 RAGFLOW_IMAGE=infiniflow/ragflow:v0.20.4로 설정합니다.
|
||||
> 아래 명령어는 RAGFlow Docker 이미지의 v0.20.5-slim 버전을 다운로드합니다. 다양한 RAGFlow 버전에 대한 설명은 다음 표를 참조하십시오. v0.20.5-slim과 다른 RAGFlow 버전을 다운로드하려면, docker/.env 파일에서 RAGFLOW_IMAGE 변수를 적절히 업데이트한 후 docker compose를 사용하여 서버를 시작하십시오. 예를 들어, 전체 버전인 v0.20.5을 다운로드하려면 RAGFLOW_IMAGE=infiniflow/ragflow:v0.20.5로 설정합니다.
|
||||
|
||||
```bash
|
||||
$ cd ragflow/docker
|
||||
@ -173,8 +173,8 @@
|
||||
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
|
||||
| ----------------- | --------------- | --------------------- | ------------------------ |
|
||||
| v0.20.4 | ≈9 | :heavy_check_mark: | Stable release |
|
||||
| v0.20.4-slim | ≈2 | ❌ | Stable release |
|
||||
| v0.20.5 | ≈9 | :heavy_check_mark: | Stable release |
|
||||
| v0.20.5-slim | ≈2 | ❌ | Stable release |
|
||||
| nightly | ≈9 | :heavy_check_mark: | _Unstable_ nightly build |
|
||||
| nightly-slim | ≈2 | ❌ | _Unstable_ nightly build |
|
||||
|
||||
@ -306,6 +306,8 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly
|
||||
sudo apt-get install libjemalloc-dev
|
||||
# centos
|
||||
sudo yum install jemalloc
|
||||
# mac
|
||||
sudo brew install jemalloc
|
||||
```
|
||||
|
||||
6. 백엔드 서비스를 시작합니다:
|
||||
@ -339,7 +341,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly
|
||||
```bash
|
||||
pkill -f "ragflow_server.py|task_executor.py"
|
||||
```
|
||||
|
||||
|
||||
|
||||
## 📚 문서
|
||||
|
||||
|
||||
@ -22,7 +22,7 @@
|
||||
<img alt="Badge Estático" src="https://img.shields.io/badge/Online-Demo-4e6b99">
|
||||
</a>
|
||||
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.20.4">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.20.5">
|
||||
</a>
|
||||
<a href="https://github.com/infiniflow/ragflow/releases/latest">
|
||||
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Última%20Relese" alt="Última Versão">
|
||||
@ -67,7 +67,7 @@
|
||||
|
||||
## 💡 O que é o RAGFlow?
|
||||
|
||||
[RAGFlow](https://ragflow.io/) é um mecanismo RAG (Geração Aumentada por Recuperação) de código aberto baseado em entendimento profundo de documentos. Ele oferece um fluxo de trabalho RAG simplificado para empresas de qualquer porte, combinando LLMs (Modelos de Linguagem de Grande Escala) para fornecer capacidades de perguntas e respostas verídicas, respaldadas por citações bem fundamentadas de diversos dados complexos formatados.
|
||||
[RAGFlow](https://ragflow.io/) é um mecanismo de RAG (Retrieval-Augmented Generation) open-source líder que fusiona tecnologias RAG de ponta com funcionalidades Agent para criar uma camada contextual superior para LLMs. Oferece um fluxo de trabalho RAG otimizado adaptável a empresas de qualquer escala. Alimentado por um motor de contexto convergente e modelos Agent pré-construídos, o RAGFlow permite que desenvolvedores transformem dados complexos em sistemas de IA de alta fidelidade e pronto para produção com excepcional eficiência e precisão.
|
||||
|
||||
## 🎮 Demo
|
||||
|
||||
@ -180,7 +180,7 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io).
|
||||
> Todas as imagens Docker são construídas para plataformas x86. Atualmente, não oferecemos imagens Docker para ARM64.
|
||||
> Se você estiver usando uma plataforma ARM64, por favor, utilize [este guia](https://ragflow.io/docs/dev/build_docker_image) para construir uma imagem Docker compatível com o seu sistema.
|
||||
|
||||
> O comando abaixo baixa a edição `v0.20.4-slim` da imagem Docker do RAGFlow. Consulte a tabela a seguir para descrições de diferentes edições do RAGFlow. Para baixar uma edição do RAGFlow diferente da `v0.20.4-slim`, atualize a variável `RAGFLOW_IMAGE` conforme necessário no **docker/.env** antes de usar `docker compose` para iniciar o servidor. Por exemplo: defina `RAGFLOW_IMAGE=infiniflow/ragflow:v0.20.4` para a edição completa `v0.20.4`.
|
||||
> O comando abaixo baixa a edição `v0.20.5-slim` da imagem Docker do RAGFlow. Consulte a tabela a seguir para descrições de diferentes edições do RAGFlow. Para baixar uma edição do RAGFlow diferente da `v0.20.5-slim`, atualize a variável `RAGFLOW_IMAGE` conforme necessário no **docker/.env** antes de usar `docker compose` para iniciar o servidor. Por exemplo: defina `RAGFLOW_IMAGE=infiniflow/ragflow:v0.20.5` para a edição completa `v0.20.5`.
|
||||
|
||||
```bash
|
||||
$ cd ragflow/docker
|
||||
@ -193,8 +193,8 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io).
|
||||
|
||||
| Tag da imagem RAGFlow | Tamanho da imagem (GB) | Possui modelos de incorporação? | Estável? |
|
||||
| --------------------- | ---------------------- | ------------------------------- | ------------------------ |
|
||||
| v0.20.4 | ~9 | :heavy_check_mark: | Lançamento estável |
|
||||
| v0.20.4-slim | ~2 | ❌ | Lançamento estável |
|
||||
| v0.20.5 | ~9 | :heavy_check_mark: | Lançamento estável |
|
||||
| v0.20.5-slim | ~2 | ❌ | Lançamento estável |
|
||||
| nightly | ~9 | :heavy_check_mark: | _Instável_ build noturno |
|
||||
| nightly-slim | ~2 | ❌ | _Instável_ build noturno |
|
||||
|
||||
@ -330,6 +330,8 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly
|
||||
sudo apt-get install libjemalloc-dev
|
||||
# centos
|
||||
sudo yum instalar jemalloc
|
||||
# mac
|
||||
sudo brew install jemalloc
|
||||
```
|
||||
|
||||
6. Lance o serviço de back-end:
|
||||
|
||||
@ -22,7 +22,7 @@
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
|
||||
</a>
|
||||
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.20.4">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.20.5">
|
||||
</a>
|
||||
<a href="https://github.com/infiniflow/ragflow/releases/latest">
|
||||
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
|
||||
@ -70,7 +70,7 @@
|
||||
|
||||
## 💡 RAGFlow 是什麼?
|
||||
|
||||
[RAGFlow](https://ragflow.io/) 是一款基於深度文件理解所建構的開源 RAG(Retrieval-Augmented Generation)引擎。 RAGFlow 可以為各種規模的企業及個人提供一套精簡的 RAG 工作流程,結合大語言模型(LLM)針對用戶各類不同的複雜格式數據提供可靠的問答以及有理有據的引用。
|
||||
[RAGFlow](https://ragflow.io/) 是一款領先的開源 RAG(Retrieval-Augmented Generation)引擎,通過融合前沿的 RAG 技術與 Agent 能力,為大型語言模型提供卓越的上下文層。它提供可適配任意規模企業的端到端 RAG 工作流,憑藉融合式上下文引擎與預置的 Agent 模板,助力開發者以極致效率與精度將複雜數據轉化為高可信、生產級的人工智能系統。
|
||||
|
||||
## 🎮 Demo 試用
|
||||
|
||||
@ -183,7 +183,7 @@
|
||||
> 所有 Docker 映像檔都是為 x86 平台建置的。目前,我們不提供 ARM64 平台的 Docker 映像檔。
|
||||
> 如果您使用的是 ARM64 平台,請使用 [這份指南](https://ragflow.io/docs/dev/build_docker_image) 來建置適合您系統的 Docker 映像檔。
|
||||
|
||||
> 執行以下指令會自動下載 RAGFlow slim Docker 映像 `v0.20.4-slim`。請參考下表查看不同 Docker 發行版的說明。如需下載不同於 `v0.20.4-slim` 的 Docker 映像,請在執行 `docker compose` 啟動服務之前先更新 **docker/.env** 檔案內的 `RAGFLOW_IMAGE` 變數。例如,你可以透過設定 `RAGFLOW_IMAGE=infiniflow/ragflow:v0.20.4` 來下載 RAGFlow 鏡像的 `v0.20.4` 完整發行版。
|
||||
> 執行以下指令會自動下載 RAGFlow slim Docker 映像 `v0.20.5-slim`。請參考下表查看不同 Docker 發行版的說明。如需下載不同於 `v0.20.5-slim` 的 Docker 映像,請在執行 `docker compose` 啟動服務之前先更新 **docker/.env** 檔案內的 `RAGFLOW_IMAGE` 變數。例如,你可以透過設定 `RAGFLOW_IMAGE=infiniflow/ragflow:v0.20.5` 來下載 RAGFlow 鏡像的 `v0.20.5` 完整發行版。
|
||||
|
||||
```bash
|
||||
$ cd ragflow/docker
|
||||
@ -196,8 +196,8 @@
|
||||
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
|
||||
| ----------------- | --------------- | --------------------- | ------------------------ |
|
||||
| v0.20.4 | ≈9 | :heavy_check_mark: | Stable release |
|
||||
| v0.20.4-slim | ≈2 | ❌ | Stable release |
|
||||
| v0.20.5 | ≈9 | :heavy_check_mark: | Stable release |
|
||||
| v0.20.5-slim | ≈2 | ❌ | Stable release |
|
||||
| nightly | ≈9 | :heavy_check_mark: | _Unstable_ nightly build |
|
||||
| nightly-slim | ≈2 | ❌ | _Unstable_ nightly build |
|
||||
|
||||
@ -343,6 +343,8 @@ docker build --platform linux/amd64 --build-arg NEED_MIRROR=1 -f Dockerfile -t i
|
||||
sudo apt-get install libjemalloc-dev
|
||||
# centos
|
||||
sudo yum install jemalloc
|
||||
# mac
|
||||
sudo brew install jemalloc
|
||||
```
|
||||
|
||||
6. 啟動後端服務:
|
||||
|
||||
12
README_zh.md
12
README_zh.md
@ -22,7 +22,7 @@
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
|
||||
</a>
|
||||
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.20.4">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.20.5">
|
||||
</a>
|
||||
<a href="https://github.com/infiniflow/ragflow/releases/latest">
|
||||
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
|
||||
@ -70,7 +70,7 @@
|
||||
|
||||
## 💡 RAGFlow 是什么?
|
||||
|
||||
[RAGFlow](https://ragflow.io/) 是一款基于深度文档理解构建的开源 RAG(Retrieval-Augmented Generation)引擎。RAGFlow 可以为各种规模的企业及个人提供一套精简的 RAG 工作流程,结合大语言模型(LLM)针对用户各类不同的复杂格式数据提供可靠的问答以及有理有据的引用。
|
||||
[RAGFlow](https://ragflow.io/) 是一款领先的开源检索增强生成(RAG)引擎,通过融合前沿的 RAG 技术与 Agent 能力,为大型语言模型提供卓越的上下文层。它提供可适配任意规模企业的端到端 RAG 工作流,凭借融合式上下文引擎与预置的 Agent 模板,助力开发者以极致效率与精度将复杂数据转化为高可信、生产级的人工智能系统。
|
||||
|
||||
## 🎮 Demo 试用
|
||||
|
||||
@ -183,7 +183,7 @@
|
||||
> 请注意,目前官方提供的所有 Docker 镜像均基于 x86 架构构建,并不提供基于 ARM64 的 Docker 镜像。
|
||||
> 如果你的操作系统是 ARM64 架构,请参考[这篇文档](https://ragflow.io/docs/dev/build_docker_image)自行构建 Docker 镜像。
|
||||
|
||||
> 运行以下命令会自动下载 RAGFlow slim Docker 镜像 `v0.20.4-slim`。请参考下表查看不同 Docker 发行版的描述。如需下载不同于 `v0.20.4-slim` 的 Docker 镜像,请在运行 `docker compose` 启动服务之前先更新 **docker/.env** 文件内的 `RAGFLOW_IMAGE` 变量。比如,你可以通过设置 `RAGFLOW_IMAGE=infiniflow/ragflow:v0.20.4` 来下载 RAGFlow 镜像的 `v0.20.4` 完整发行版。
|
||||
> 运行以下命令会自动下载 RAGFlow slim Docker 镜像 `v0.20.5-slim`。请参考下表查看不同 Docker 发行版的描述。如需下载不同于 `v0.20.5-slim` 的 Docker 镜像,请在运行 `docker compose` 启动服务之前先更新 **docker/.env** 文件内的 `RAGFLOW_IMAGE` 变量。比如,你可以通过设置 `RAGFLOW_IMAGE=infiniflow/ragflow:v0.20.5` 来下载 RAGFlow 镜像的 `v0.20.5` 完整发行版。
|
||||
|
||||
```bash
|
||||
$ cd ragflow/docker
|
||||
@ -196,8 +196,8 @@
|
||||
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
|
||||
| ----------------- | --------------- | --------------------- | ------------------------ |
|
||||
| v0.20.4 | ≈9 | :heavy_check_mark: | Stable release |
|
||||
| v0.20.4-slim | ≈2 | ❌ | Stable release |
|
||||
| v0.20.5 | ≈9 | :heavy_check_mark: | Stable release |
|
||||
| v0.20.5-slim | ≈2 | ❌ | Stable release |
|
||||
| nightly | ≈9 | :heavy_check_mark: | _Unstable_ nightly build |
|
||||
| nightly-slim | ≈2 | ❌ | _Unstable_ nightly build |
|
||||
|
||||
@ -342,6 +342,8 @@ docker build --platform linux/amd64 --build-arg NEED_MIRROR=1 -f Dockerfile -t i
|
||||
sudo apt-get install libjemalloc-dev
|
||||
# centos
|
||||
sudo yum install jemalloc
|
||||
# mac
|
||||
sudo brew install jemalloc
|
||||
```
|
||||
|
||||
6. 启动后端服务:
|
||||
|
||||
101
admin/README.md
Normal file
101
admin/README.md
Normal file
@ -0,0 +1,101 @@
|
||||
# RAGFlow Admin Service & CLI
|
||||
|
||||
### Introduction
|
||||
|
||||
Admin Service is a dedicated management component designed to monitor, maintain, and administrate the RAGFlow system. It provides comprehensive tools for ensuring system stability, performing operational tasks, and managing users and permissions efficiently.
|
||||
|
||||
The service offers real-time monitoring of critical components, including the RAGFlow server, Task Executor processes, and dependent services such as MySQL, Elasticsearch, Redis, and MinIO. It automatically checks their health status, resource usage, and uptime, and performs restarts in case of failures to minimize downtime.
|
||||
|
||||
For user and system management, it supports listing, creating, modifying, and deleting users and their associated resources like knowledge bases and Agents.
|
||||
|
||||
Built with scalability and reliability in mind, the Admin Service ensures smooth system operation and simplifies maintenance workflows.
|
||||
|
||||
It consists of a server-side Service and a command-line client (CLI), both implemented in Python. User commands are parsed using the Lark parsing toolkit.
|
||||
|
||||
- **Admin Service**: A backend service that interfaces with the RAGFlow system to execute administrative operations and monitor its status.
|
||||
- **Admin CLI**: A command-line interface that allows users to connect to the Admin Service and issue commands for system management.
|
||||
|
||||
### Starting the Admin Service
|
||||
|
||||
1. Before start Admin Service, please make sure RAGFlow system is already started.
|
||||
|
||||
2. Run the service script:
|
||||
```bash
|
||||
python admin/admin_server.py
|
||||
```
|
||||
The service will start and listen for incoming connections from the CLI on the configured port.
|
||||
|
||||
### Using the Admin CLI
|
||||
|
||||
1. Ensure the Admin Service is running.
|
||||
2. Launch the CLI client:
|
||||
```bash
|
||||
python admin/admin_client.py -h 0.0.0.0 -p 9381
|
||||
|
||||
## Supported Commands
|
||||
|
||||
Commands are case-insensitive and must be terminated with a semicolon (`;`).
|
||||
|
||||
### Service Management Commands
|
||||
|
||||
- `LIST SERVICES;`
|
||||
- Lists all available services within the RAGFlow system.
|
||||
- `SHOW SERVICE <id>;`
|
||||
- Shows detailed status information for the service identified by `<id>`.
|
||||
- `STARTUP SERVICE <id>;`
|
||||
- Attempts to start the service identified by `<id>`.
|
||||
- `SHUTDOWN SERVICE <id>;`
|
||||
- Attempts to gracefully shut down the service identified by `<id>`.
|
||||
- `RESTART SERVICE <id>;`
|
||||
- Attempts to restart the service identified by `<id>`.
|
||||
|
||||
### User Management Commands
|
||||
|
||||
- `LIST USERS;`
|
||||
- Lists all users known to the system.
|
||||
- `SHOW USER '<username>';`
|
||||
- Shows details and permissions for the specified user. The username must be enclosed in single or double quotes.
|
||||
- `DROP USER '<username>';`
|
||||
- Removes the specified user from the system. Use with caution.
|
||||
- `ALTER USER PASSWORD '<username>' '<new_password>';`
|
||||
- Changes the password for the specified user.
|
||||
|
||||
### Data and Agent Commands
|
||||
|
||||
- `LIST DATASETS OF '<username>';`
|
||||
- Lists the datasets associated with the specified user.
|
||||
- `LIST AGENTS OF '<username>';`
|
||||
- Lists the agents associated with the specified user.
|
||||
|
||||
### Meta-Commands
|
||||
|
||||
Meta-commands are prefixed with a backslash (`\`).
|
||||
|
||||
- `\?` or `\help`
|
||||
- Shows help information for the available commands.
|
||||
- `\q` or `\quit`
|
||||
- Exits the CLI application.
|
||||
|
||||
## Examples
|
||||
|
||||
```commandline
|
||||
admin> list users;
|
||||
+-------------------------------+------------------------+-----------+-------------+
|
||||
| create_date | email | is_active | nickname |
|
||||
+-------------------------------+------------------------+-----------+-------------+
|
||||
| Fri, 22 Nov 2024 16:03:41 GMT | jeffery@infiniflow.org | 1 | Jeffery |
|
||||
| Fri, 22 Nov 2024 16:10:55 GMT | aya@infiniflow.org | 1 | Waterdancer |
|
||||
+-------------------------------+------------------------+-----------+-------------+
|
||||
|
||||
admin> list services;
|
||||
+-------------------------------------------------------------------------------------------+-----------+----+---------------+-------+----------------+
|
||||
| extra | host | id | name | port | service_type |
|
||||
+-------------------------------------------------------------------------------------------+-----------+----+---------------+-------+----------------+
|
||||
| {} | 0.0.0.0 | 0 | ragflow_0 | 9380 | ragflow_server |
|
||||
| {'meta_type': 'mysql', 'password': 'infini_rag_flow', 'username': 'root'} | localhost | 1 | mysql | 5455 | meta_data |
|
||||
| {'password': 'infini_rag_flow', 'store_type': 'minio', 'user': 'rag_flow'} | localhost | 2 | minio | 9000 | file_store |
|
||||
| {'password': 'infini_rag_flow', 'retrieval_type': 'elasticsearch', 'username': 'elastic'} | localhost | 3 | elasticsearch | 1200 | retrieval |
|
||||
| {'db_name': 'default_db', 'retrieval_type': 'infinity'} | localhost | 4 | infinity | 23817 | retrieval |
|
||||
| {'database': 1, 'mq_type': 'redis', 'password': 'infini_rag_flow'} | localhost | 5 | redis | 6379 | message_queue |
|
||||
+-------------------------------------------------------------------------------------------+-----------+----+---------------+-------+----------------+
|
||||
```
|
||||
574
admin/admin_client.py
Normal file
574
admin/admin_client.py
Normal file
@ -0,0 +1,574 @@
|
||||
import argparse
|
||||
import base64
|
||||
|
||||
from Cryptodome.PublicKey import RSA
|
||||
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
|
||||
from typing import Dict, List, Any
|
||||
from lark import Lark, Transformer, Tree
|
||||
import requests
|
||||
from requests.auth import HTTPBasicAuth
|
||||
from api.common.base64 import encode_to_base64
|
||||
|
||||
GRAMMAR = r"""
|
||||
start: command
|
||||
|
||||
command: sql_command | meta_command
|
||||
|
||||
sql_command: list_services
|
||||
| show_service
|
||||
| startup_service
|
||||
| shutdown_service
|
||||
| restart_service
|
||||
| list_users
|
||||
| show_user
|
||||
| drop_user
|
||||
| alter_user
|
||||
| create_user
|
||||
| activate_user
|
||||
| list_datasets
|
||||
| list_agents
|
||||
|
||||
// meta command definition
|
||||
meta_command: "\\" meta_command_name [meta_args]
|
||||
|
||||
meta_command_name: /[a-zA-Z?]+/
|
||||
meta_args: (meta_arg)+
|
||||
|
||||
meta_arg: /[^\\s"']+/ | quoted_string
|
||||
|
||||
// command definition
|
||||
|
||||
LIST: "LIST"i
|
||||
SERVICES: "SERVICES"i
|
||||
SHOW: "SHOW"i
|
||||
CREATE: "CREATE"i
|
||||
SERVICE: "SERVICE"i
|
||||
SHUTDOWN: "SHUTDOWN"i
|
||||
STARTUP: "STARTUP"i
|
||||
RESTART: "RESTART"i
|
||||
USERS: "USERS"i
|
||||
DROP: "DROP"i
|
||||
USER: "USER"i
|
||||
ALTER: "ALTER"i
|
||||
ACTIVE: "ACTIVE"i
|
||||
PASSWORD: "PASSWORD"i
|
||||
DATASETS: "DATASETS"i
|
||||
OF: "OF"i
|
||||
AGENTS: "AGENTS"i
|
||||
|
||||
list_services: LIST SERVICES ";"
|
||||
show_service: SHOW SERVICE NUMBER ";"
|
||||
startup_service: STARTUP SERVICE NUMBER ";"
|
||||
shutdown_service: SHUTDOWN SERVICE NUMBER ";"
|
||||
restart_service: RESTART SERVICE NUMBER ";"
|
||||
|
||||
list_users: LIST USERS ";"
|
||||
drop_user: DROP USER quoted_string ";"
|
||||
alter_user: ALTER USER PASSWORD quoted_string quoted_string ";"
|
||||
show_user: SHOW USER quoted_string ";"
|
||||
create_user: CREATE USER quoted_string quoted_string ";"
|
||||
activate_user: ALTER USER ACTIVE quoted_string status ";"
|
||||
|
||||
list_datasets: LIST DATASETS OF quoted_string ";"
|
||||
list_agents: LIST AGENTS OF quoted_string ";"
|
||||
|
||||
identifier: WORD
|
||||
quoted_string: QUOTED_STRING
|
||||
status: WORD
|
||||
|
||||
QUOTED_STRING: /'[^']+'/ | /"[^"]+"/
|
||||
WORD: /[a-zA-Z0-9_\-\.]+/
|
||||
NUMBER: /[0-9]+/
|
||||
|
||||
%import common.WS
|
||||
%ignore WS
|
||||
"""
|
||||
|
||||
|
||||
class AdminTransformer(Transformer):
|
||||
|
||||
def start(self, items):
|
||||
return items[0]
|
||||
|
||||
def command(self, items):
|
||||
return items[0]
|
||||
|
||||
def list_services(self, items):
|
||||
result = {'type': 'list_services'}
|
||||
return result
|
||||
|
||||
def show_service(self, items):
|
||||
service_id = int(items[2])
|
||||
return {"type": "show_service", "number": service_id}
|
||||
|
||||
def startup_service(self, items):
|
||||
service_id = int(items[2])
|
||||
return {"type": "startup_service", "number": service_id}
|
||||
|
||||
def shutdown_service(self, items):
|
||||
service_id = int(items[2])
|
||||
return {"type": "shutdown_service", "number": service_id}
|
||||
|
||||
def restart_service(self, items):
|
||||
service_id = int(items[2])
|
||||
return {"type": "restart_service", "number": service_id}
|
||||
|
||||
def list_users(self, items):
|
||||
return {"type": "list_users"}
|
||||
|
||||
def show_user(self, items):
|
||||
user_name = items[2]
|
||||
return {"type": "show_user", "username": user_name}
|
||||
|
||||
def drop_user(self, items):
|
||||
user_name = items[2]
|
||||
return {"type": "drop_user", "username": user_name}
|
||||
|
||||
def alter_user(self, items):
|
||||
user_name = items[3]
|
||||
new_password = items[4]
|
||||
return {"type": "alter_user", "username": user_name, "password": new_password}
|
||||
|
||||
def create_user(self, items):
|
||||
user_name = items[2]
|
||||
password = items[3]
|
||||
return {"type": "create_user", "username": user_name, "password": password, "role": "user"}
|
||||
|
||||
def activate_user(self, items):
|
||||
user_name = items[3]
|
||||
activate_status = items[4]
|
||||
return {"type": "activate_user", "activate_status": activate_status, "username": user_name}
|
||||
|
||||
def list_datasets(self, items):
|
||||
user_name = items[3]
|
||||
return {"type": "list_datasets", "username": user_name}
|
||||
|
||||
def list_agents(self, items):
|
||||
user_name = items[3]
|
||||
return {"type": "list_agents", "username": user_name}
|
||||
|
||||
def meta_command(self, items):
|
||||
command_name = str(items[0]).lower()
|
||||
args = items[1:] if len(items) > 1 else []
|
||||
|
||||
# handle quoted parameter
|
||||
parsed_args = []
|
||||
for arg in args:
|
||||
if hasattr(arg, 'value'):
|
||||
parsed_args.append(arg.value)
|
||||
else:
|
||||
parsed_args.append(str(arg))
|
||||
|
||||
return {'type': 'meta', 'command': command_name, 'args': parsed_args}
|
||||
|
||||
def meta_command_name(self, items):
|
||||
return items[0]
|
||||
|
||||
def meta_args(self, items):
|
||||
return items
|
||||
|
||||
|
||||
def encrypt(input_string):
|
||||
pub = '-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOOUEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVKRNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs2wIDAQAB\n-----END PUBLIC KEY-----'
|
||||
pub_key = RSA.importKey(pub)
|
||||
cipher = Cipher_pkcs1_v1_5.new(pub_key)
|
||||
cipher_text = cipher.encrypt(base64.b64encode(input_string.encode('utf-8')))
|
||||
return base64.b64encode(cipher_text).decode("utf-8")
|
||||
|
||||
|
||||
class AdminCommandParser:
|
||||
def __init__(self):
|
||||
self.parser = Lark(GRAMMAR, start='start', parser='lalr', transformer=AdminTransformer())
|
||||
self.command_history = []
|
||||
|
||||
def parse_command(self, command_str: str) -> Dict[str, Any]:
|
||||
if not command_str.strip():
|
||||
return {'type': 'empty'}
|
||||
|
||||
self.command_history.append(command_str)
|
||||
|
||||
try:
|
||||
result = self.parser.parse(command_str)
|
||||
return result
|
||||
except Exception as e:
|
||||
return {'type': 'error', 'message': f'Parse error: {str(e)}'}
|
||||
|
||||
|
||||
class AdminCLI:
|
||||
def __init__(self):
|
||||
self.parser = AdminCommandParser()
|
||||
self.is_interactive = False
|
||||
self.admin_account = "admin@ragflow.io"
|
||||
self.admin_password: str = "admin"
|
||||
self.host: str = ""
|
||||
self.port: int = 0
|
||||
|
||||
def verify_admin(self, args):
|
||||
|
||||
conn_info = self._parse_connection_args(args)
|
||||
if 'error' in conn_info:
|
||||
print(f"Error: {conn_info['error']}")
|
||||
return
|
||||
|
||||
self.host = conn_info['host']
|
||||
self.port = conn_info['port']
|
||||
print(f"Attempt to access ip: {self.host}, port: {self.port}")
|
||||
url = f'http://{self.host}:{self.port}/api/v1/admin/auth'
|
||||
|
||||
try_count = 0
|
||||
while True:
|
||||
try_count += 1
|
||||
if try_count > 3:
|
||||
return False
|
||||
|
||||
admin_passwd = input(f"password for {self.admin_account}: ").strip()
|
||||
try:
|
||||
self.admin_password = encode_to_base64(admin_passwd)
|
||||
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
||||
if response.status_code == 200:
|
||||
res_json = response.json()
|
||||
error_code = res_json.get('code', -1)
|
||||
if error_code == 0:
|
||||
print("Authentication successful.")
|
||||
return True
|
||||
else:
|
||||
error_message = res_json.get('message', 'Unknown error')
|
||||
print(f"Authentication failed: {error_message}, try again")
|
||||
continue
|
||||
else:
|
||||
print(f"Bad response,status: {response.status_code}, try again")
|
||||
except Exception:
|
||||
print(f"Can't access {self.host}, port: {self.port}")
|
||||
|
||||
def _print_table_simple(self, data):
|
||||
if not data:
|
||||
print("No data to print")
|
||||
return
|
||||
if isinstance(data, dict):
|
||||
# handle single row data
|
||||
data = [data]
|
||||
|
||||
columns = list(data[0].keys())
|
||||
col_widths = {}
|
||||
|
||||
for col in columns:
|
||||
max_width = len(str(col))
|
||||
for item in data:
|
||||
value_len = len(str(item.get(col, '')))
|
||||
if value_len > max_width:
|
||||
max_width = value_len
|
||||
col_widths[col] = max(2, max_width)
|
||||
|
||||
# Generate delimiter
|
||||
separator = "+" + "+".join(["-" * (col_widths[col] + 2) for col in columns]) + "+"
|
||||
|
||||
# Print header
|
||||
print(separator)
|
||||
header = "|" + "|".join([f" {col:<{col_widths[col]}} " for col in columns]) + "|"
|
||||
print(header)
|
||||
print(separator)
|
||||
|
||||
# Print data
|
||||
for item in data:
|
||||
row = "|"
|
||||
for col in columns:
|
||||
value = str(item.get(col, ''))
|
||||
if len(value) > col_widths[col]:
|
||||
value = value[:col_widths[col] - 3] + "..."
|
||||
row += f" {value:<{col_widths[col]}} |"
|
||||
print(row)
|
||||
|
||||
print(separator)
|
||||
|
||||
def run_interactive(self):
|
||||
|
||||
self.is_interactive = True
|
||||
print("RAGFlow Admin command line interface - Type '\\?' for help, '\\q' to quit")
|
||||
|
||||
while True:
|
||||
try:
|
||||
command = input("admin> ").strip()
|
||||
if not command:
|
||||
continue
|
||||
|
||||
print(f"command: {command}")
|
||||
result = self.parser.parse_command(command)
|
||||
self.execute_command(result)
|
||||
|
||||
if isinstance(result, Tree):
|
||||
continue
|
||||
|
||||
if result.get('type') == 'meta' and result.get('command') in ['q', 'quit', 'exit']:
|
||||
break
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nUse '\\q' to quit")
|
||||
except EOFError:
|
||||
print("\nGoodbye!")
|
||||
break
|
||||
|
||||
def run_single_command(self, args):
|
||||
conn_info = self._parse_connection_args(args)
|
||||
if 'error' in conn_info:
|
||||
print(f"Error: {conn_info['error']}")
|
||||
return
|
||||
|
||||
def _parse_connection_args(self, args: List[str]) -> Dict[str, Any]:
|
||||
parser = argparse.ArgumentParser(description='Admin CLI Client', add_help=False)
|
||||
parser.add_argument('-h', '--host', default='localhost', help='Admin service host')
|
||||
parser.add_argument('-p', '--port', type=int, default=8080, help='Admin service port')
|
||||
|
||||
try:
|
||||
parsed_args, remaining_args = parser.parse_known_args(args)
|
||||
return {
|
||||
'host': parsed_args.host,
|
||||
'port': parsed_args.port,
|
||||
}
|
||||
except SystemExit:
|
||||
return {'error': 'Invalid connection arguments'}
|
||||
|
||||
def execute_command(self, parsed_command: Dict[str, Any]):
|
||||
|
||||
command_dict: dict
|
||||
if isinstance(parsed_command, Tree):
|
||||
command_dict = parsed_command.children[0]
|
||||
else:
|
||||
if parsed_command['type'] == 'error':
|
||||
print(f"Error: {parsed_command['message']}")
|
||||
return
|
||||
else:
|
||||
command_dict = parsed_command
|
||||
|
||||
# print(f"Parsed command: {command_dict}")
|
||||
|
||||
command_type = command_dict['type']
|
||||
|
||||
match command_type:
|
||||
case 'list_services':
|
||||
self._handle_list_services(command_dict)
|
||||
case 'show_service':
|
||||
self._handle_show_service(command_dict)
|
||||
case 'restart_service':
|
||||
self._handle_restart_service(command_dict)
|
||||
case 'shutdown_service':
|
||||
self._handle_shutdown_service(command_dict)
|
||||
case 'startup_service':
|
||||
self._handle_startup_service(command_dict)
|
||||
case 'list_users':
|
||||
self._handle_list_users(command_dict)
|
||||
case 'show_user':
|
||||
self._handle_show_user(command_dict)
|
||||
case 'drop_user':
|
||||
self._handle_drop_user(command_dict)
|
||||
case 'alter_user':
|
||||
self._handle_alter_user(command_dict)
|
||||
case 'create_user':
|
||||
self._handle_create_user(command_dict)
|
||||
case 'activate_user':
|
||||
self._handle_activate_user(command_dict)
|
||||
case 'list_datasets':
|
||||
self._handle_list_datasets(command_dict)
|
||||
case 'list_agents':
|
||||
self._handle_list_agents(command_dict)
|
||||
case 'meta':
|
||||
self._handle_meta_command(command_dict)
|
||||
case _:
|
||||
print(f"Command '{command_type}' would be executed with API")
|
||||
|
||||
def _handle_list_services(self, command):
|
||||
print("Listing all services")
|
||||
|
||||
url = f'http://{self.host}:{self.port}/api/v1/admin/services'
|
||||
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
self._print_table_simple(res_json['data'])
|
||||
else:
|
||||
print(f"Fail to get all users, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _handle_show_service(self, command):
|
||||
service_id: int = command['number']
|
||||
print(f"Showing service: {service_id}")
|
||||
|
||||
def _handle_restart_service(self, command):
|
||||
service_id: int = command['number']
|
||||
print(f"Restart service {service_id}")
|
||||
|
||||
def _handle_shutdown_service(self, command):
|
||||
service_id: int = command['number']
|
||||
print(f"Shutdown service {service_id}")
|
||||
|
||||
def _handle_startup_service(self, command):
|
||||
service_id: int = command['number']
|
||||
print(f"Startup service {service_id}")
|
||||
|
||||
def _handle_list_users(self, command):
|
||||
print("Listing all users")
|
||||
|
||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users'
|
||||
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
self._print_table_simple(res_json['data'])
|
||||
else:
|
||||
print(f"Fail to get all users, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _handle_show_user(self, command):
|
||||
username_tree: Tree = command['username']
|
||||
username: str = username_tree.children[0].strip("'\"")
|
||||
print(f"Showing user: {username}")
|
||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}'
|
||||
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
self._print_table_simple(res_json['data'])
|
||||
else:
|
||||
print(f"Fail to get user {username}, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _handle_drop_user(self, command):
|
||||
username_tree: Tree = command['username']
|
||||
username: str = username_tree.children[0].strip("'\"")
|
||||
print(f"Drop user: {username}")
|
||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}'
|
||||
response = requests.delete(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
print(res_json["message"])
|
||||
else:
|
||||
print(f"Fail to drop user, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _handle_alter_user(self, command):
|
||||
username_tree: Tree = command['username']
|
||||
username: str = username_tree.children[0].strip("'\"")
|
||||
password_tree: Tree = command['password']
|
||||
password: str = password_tree.children[0].strip("'\"")
|
||||
print(f"Alter user: {username}, password: {password}")
|
||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/password'
|
||||
response = requests.put(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password),
|
||||
json={'new_password': encrypt(password)})
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
print(res_json["message"])
|
||||
else:
|
||||
print(f"Fail to alter password, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _handle_create_user(self, command):
|
||||
username_tree: Tree = command['username']
|
||||
username: str = username_tree.children[0].strip("'\"")
|
||||
password_tree: Tree = command['password']
|
||||
password: str = password_tree.children[0].strip("'\"")
|
||||
role: str = command['role']
|
||||
print(f"Create user: {username}, password: {password}, role: {role}")
|
||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users'
|
||||
response = requests.post(
|
||||
url,
|
||||
auth=HTTPBasicAuth(self.admin_account, self.admin_password),
|
||||
json={'username': username, 'password': encrypt(password), 'role': role}
|
||||
)
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
self._print_table_simple(res_json['data'])
|
||||
else:
|
||||
print(f"Fail to create user {username}, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _handle_activate_user(self, command):
|
||||
username_tree: Tree = command['username']
|
||||
username: str = username_tree.children[0].strip("'\"")
|
||||
activate_tree: Tree = command['activate_status']
|
||||
activate_status: str = activate_tree.children[0].strip("'\"")
|
||||
if activate_status.lower() in ['on', 'off']:
|
||||
print(f"Alter user {username} activate status, turn {activate_status.lower()}.")
|
||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/activate'
|
||||
response = requests.put(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password),
|
||||
json={'activate_status': activate_status})
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
print(res_json["message"])
|
||||
else:
|
||||
print(f"Fail to alter activate status, code: {res_json['code']}, message: {res_json['message']}")
|
||||
else:
|
||||
print(f"Unknown activate status: {activate_status}.")
|
||||
|
||||
def _handle_list_datasets(self, command):
|
||||
username_tree: Tree = command['username']
|
||||
username: str = username_tree.children[0].strip("'\"")
|
||||
print(f"Listing all datasets of user: {username}")
|
||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/datasets'
|
||||
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
self._print_table_simple(res_json['data'])
|
||||
else:
|
||||
print(f"Fail to get all datasets of {username}, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _handle_list_agents(self, command):
|
||||
username_tree: Tree = command['username']
|
||||
username: str = username_tree.children[0].strip("'\"")
|
||||
print(f"Listing all agents of user: {username}")
|
||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/agents'
|
||||
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
self._print_table_simple(res_json['data'])
|
||||
else:
|
||||
print(f"Fail to get all agents of {username}, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _handle_meta_command(self, command):
|
||||
meta_command = command['command']
|
||||
args = command.get('args', [])
|
||||
|
||||
if meta_command in ['?', 'h', 'help']:
|
||||
self.show_help()
|
||||
elif meta_command in ['q', 'quit', 'exit']:
|
||||
print("Goodbye!")
|
||||
else:
|
||||
print(f"Meta command '{meta_command}' with args {args}")
|
||||
|
||||
def show_help(self):
|
||||
"""Help info"""
|
||||
help_text = """
|
||||
Commands:
|
||||
LIST SERVICES
|
||||
SHOW SERVICE <service>
|
||||
STARTUP SERVICE <service>
|
||||
SHUTDOWN SERVICE <service>
|
||||
RESTART SERVICE <service>
|
||||
LIST USERS
|
||||
SHOW USER <user>
|
||||
DROP USER <user>
|
||||
CREATE USER <user> <password>
|
||||
ALTER USER PASSWORD <user> <new_password>
|
||||
ALTER USER ACTIVE <user> <on/off>
|
||||
LIST DATASETS OF <user>
|
||||
LIST AGENTS OF <user>
|
||||
|
||||
Meta Commands:
|
||||
\\?, \\h, \\help Show this help
|
||||
\\q, \\quit, \\exit Quit the CLI
|
||||
"""
|
||||
print(help_text)
|
||||
|
||||
|
||||
def main():
|
||||
import sys
|
||||
|
||||
cli = AdminCLI()
|
||||
|
||||
if len(sys.argv) == 1 or (len(sys.argv) > 1 and sys.argv[1] == '-'):
|
||||
print(r"""
|
||||
____ ___ ______________ ___ __ _
|
||||
/ __ \/ | / ____/ ____/ /___ _ __ / | ____/ /___ ___ (_)___
|
||||
/ /_/ / /| |/ / __/ /_ / / __ \ | /| / / / /| |/ __ / __ `__ \/ / __ \
|
||||
/ _, _/ ___ / /_/ / __/ / / /_/ / |/ |/ / / ___ / /_/ / / / / / / / / / /
|
||||
/_/ |_/_/ |_\____/_/ /_/\____/|__/|__/ /_/ |_\__,_/_/ /_/ /_/_/_/ /_/
|
||||
""")
|
||||
if cli.verify_admin(sys.argv):
|
||||
cli.run_interactive()
|
||||
else:
|
||||
if cli.verify_admin(sys.argv):
|
||||
cli.run_interactive()
|
||||
# cli.run_single_command(sys.argv[1:])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
47
admin/admin_server.py
Normal file
47
admin/admin_server.py
Normal file
@ -0,0 +1,47 @@
|
||||
|
||||
import os
|
||||
import signal
|
||||
import logging
|
||||
import time
|
||||
import threading
|
||||
import traceback
|
||||
from werkzeug.serving import run_simple
|
||||
from flask import Flask
|
||||
from routes import admin_bp
|
||||
from api.utils.log_utils import init_root_logger
|
||||
from api.constants import SERVICE_CONF
|
||||
from api import settings
|
||||
from config import load_configurations, SERVICE_CONFIGS
|
||||
|
||||
stop_event = threading.Event()
|
||||
|
||||
if __name__ == '__main__':
|
||||
init_root_logger("admin_service")
|
||||
logging.info(r"""
|
||||
____ ___ ______________ ___ __ _
|
||||
/ __ \/ | / ____/ ____/ /___ _ __ / | ____/ /___ ___ (_)___
|
||||
/ /_/ / /| |/ / __/ /_ / / __ \ | /| / / / /| |/ __ / __ `__ \/ / __ \
|
||||
/ _, _/ ___ / /_/ / __/ / / /_/ / |/ |/ / / ___ / /_/ / / / / / / / / / /
|
||||
/_/ |_/_/ |_\____/_/ /_/\____/|__/|__/ /_/ |_\__,_/_/ /_/ /_/_/_/ /_/
|
||||
""")
|
||||
|
||||
app = Flask(__name__)
|
||||
app.register_blueprint(admin_bp)
|
||||
settings.init_settings()
|
||||
SERVICE_CONFIGS.configs = load_configurations(SERVICE_CONF)
|
||||
|
||||
try:
|
||||
logging.info("RAGFlow Admin service start...")
|
||||
run_simple(
|
||||
hostname="0.0.0.0",
|
||||
port=9381,
|
||||
application=app,
|
||||
threaded=True,
|
||||
use_reloader=True,
|
||||
use_debugger=True,
|
||||
)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
stop_event.set()
|
||||
time.sleep(1)
|
||||
os.kill(os.getpid(), signal.SIGKILL)
|
||||
57
admin/auth.py
Normal file
57
admin/auth.py
Normal file
@ -0,0 +1,57 @@
|
||||
import logging
|
||||
import uuid
|
||||
from functools import wraps
|
||||
from flask import request, jsonify
|
||||
|
||||
from exceptions import AdminException
|
||||
from api.db.init_data import encode_to_base64
|
||||
from api.db.services import UserService
|
||||
|
||||
|
||||
def check_admin(username: str, password: str):
|
||||
users = UserService.query(email=username)
|
||||
if not users:
|
||||
logging.info(f"Username: {username} is not registered!")
|
||||
user_info = {
|
||||
"id": uuid.uuid1().hex,
|
||||
"password": encode_to_base64("admin"),
|
||||
"nickname": "admin",
|
||||
"is_superuser": True,
|
||||
"email": "admin@ragflow.io",
|
||||
"creator": "system",
|
||||
"status": "1",
|
||||
}
|
||||
if not UserService.save(**user_info):
|
||||
raise AdminException("Can't init admin.", 500)
|
||||
|
||||
user = UserService.query_user(username, password)
|
||||
if user:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def login_verify(f):
|
||||
@wraps(f)
|
||||
def decorated(*args, **kwargs):
|
||||
auth = request.authorization
|
||||
if not auth or 'username' not in auth.parameters or 'password' not in auth.parameters:
|
||||
return jsonify({
|
||||
"code": 401,
|
||||
"message": "Authentication required",
|
||||
"data": None
|
||||
}), 200
|
||||
|
||||
username = auth.parameters['username']
|
||||
password = auth.parameters['password']
|
||||
# TODO: to check the username and password from DB
|
||||
if check_admin(username, password) is False:
|
||||
return jsonify({
|
||||
"code": 403,
|
||||
"message": "Access denied",
|
||||
"data": None
|
||||
}), 200
|
||||
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
280
admin/config.py
Normal file
280
admin/config.py
Normal file
@ -0,0 +1,280 @@
|
||||
import logging
|
||||
import threading
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Any
|
||||
from api.utils.configs import read_config
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
class ServiceConfigs:
|
||||
def __init__(self):
|
||||
self.configs = []
|
||||
self.lock = threading.Lock()
|
||||
|
||||
|
||||
SERVICE_CONFIGS = ServiceConfigs
|
||||
|
||||
|
||||
class ServiceType(Enum):
|
||||
METADATA = "metadata"
|
||||
RETRIEVAL = "retrieval"
|
||||
MESSAGE_QUEUE = "message_queue"
|
||||
RAGFLOW_SERVER = "ragflow_server"
|
||||
TASK_EXECUTOR = "task_executor"
|
||||
FILE_STORE = "file_store"
|
||||
|
||||
|
||||
class BaseConfig(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
host: str
|
||||
port: int
|
||||
service_type: str
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {'id': self.id, 'name': self.name, 'host': self.host, 'port': self.port, 'service_type': self.service_type}
|
||||
|
||||
|
||||
class MetaConfig(BaseConfig):
|
||||
meta_type: str
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = super().to_dict()
|
||||
if 'extra' not in result:
|
||||
result['extra'] = dict()
|
||||
extra_dict = result['extra'].copy()
|
||||
extra_dict['meta_type'] = self.meta_type
|
||||
result['extra'] = extra_dict
|
||||
return result
|
||||
|
||||
|
||||
class MySQLConfig(MetaConfig):
|
||||
username: str
|
||||
password: str
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = super().to_dict()
|
||||
if 'extra' not in result:
|
||||
result['extra'] = dict()
|
||||
extra_dict = result['extra'].copy()
|
||||
extra_dict['username'] = self.username
|
||||
extra_dict['password'] = self.password
|
||||
result['extra'] = extra_dict
|
||||
return result
|
||||
|
||||
|
||||
class PostgresConfig(MetaConfig):
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = super().to_dict()
|
||||
if 'extra' not in result:
|
||||
result['extra'] = dict()
|
||||
return result
|
||||
|
||||
|
||||
class RetrievalConfig(BaseConfig):
|
||||
retrieval_type: str
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = super().to_dict()
|
||||
if 'extra' not in result:
|
||||
result['extra'] = dict()
|
||||
extra_dict = result['extra'].copy()
|
||||
extra_dict['retrieval_type'] = self.retrieval_type
|
||||
result['extra'] = extra_dict
|
||||
return result
|
||||
|
||||
|
||||
class InfinityConfig(RetrievalConfig):
|
||||
db_name: str
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = super().to_dict()
|
||||
if 'extra' not in result:
|
||||
result['extra'] = dict()
|
||||
extra_dict = result['extra'].copy()
|
||||
extra_dict['db_name'] = self.db_name
|
||||
result['extra'] = extra_dict
|
||||
return result
|
||||
|
||||
|
||||
class ElasticsearchConfig(RetrievalConfig):
|
||||
username: str
|
||||
password: str
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = super().to_dict()
|
||||
if 'extra' not in result:
|
||||
result['extra'] = dict()
|
||||
extra_dict = result['extra'].copy()
|
||||
extra_dict['username'] = self.username
|
||||
extra_dict['password'] = self.password
|
||||
result['extra'] = extra_dict
|
||||
return result
|
||||
|
||||
|
||||
class MessageQueueConfig(BaseConfig):
|
||||
mq_type: str
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = super().to_dict()
|
||||
if 'extra' not in result:
|
||||
result['extra'] = dict()
|
||||
extra_dict = result['extra'].copy()
|
||||
extra_dict['mq_type'] = self.mq_type
|
||||
result['extra'] = extra_dict
|
||||
return result
|
||||
|
||||
|
||||
class RedisConfig(MessageQueueConfig):
|
||||
database: int
|
||||
password: str
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = super().to_dict()
|
||||
if 'extra' not in result:
|
||||
result['extra'] = dict()
|
||||
extra_dict = result['extra'].copy()
|
||||
extra_dict['database'] = self.database
|
||||
extra_dict['password'] = self.password
|
||||
result['extra'] = extra_dict
|
||||
return result
|
||||
|
||||
|
||||
class RabbitMQConfig(MessageQueueConfig):
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = super().to_dict()
|
||||
if 'extra' not in result:
|
||||
result['extra'] = dict()
|
||||
return result
|
||||
|
||||
|
||||
class RAGFlowServerConfig(BaseConfig):
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = super().to_dict()
|
||||
if 'extra' not in result:
|
||||
result['extra'] = dict()
|
||||
return result
|
||||
|
||||
|
||||
class TaskExecutorConfig(BaseConfig):
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = super().to_dict()
|
||||
if 'extra' not in result:
|
||||
result['extra'] = dict()
|
||||
return result
|
||||
|
||||
|
||||
class FileStoreConfig(BaseConfig):
|
||||
store_type: str
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = super().to_dict()
|
||||
if 'extra' not in result:
|
||||
result['extra'] = dict()
|
||||
extra_dict = result['extra'].copy()
|
||||
extra_dict['store_type'] = self.store_type
|
||||
result['extra'] = extra_dict
|
||||
return result
|
||||
|
||||
|
||||
class MinioConfig(FileStoreConfig):
|
||||
user: str
|
||||
password: str
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = super().to_dict()
|
||||
if 'extra' not in result:
|
||||
result['extra'] = dict()
|
||||
extra_dict = result['extra'].copy()
|
||||
extra_dict['user'] = self.user
|
||||
extra_dict['password'] = self.password
|
||||
result['extra'] = extra_dict
|
||||
return result
|
||||
|
||||
|
||||
def load_configurations(config_path: str) -> list[BaseConfig]:
|
||||
raw_configs = read_config(config_path)
|
||||
configurations = []
|
||||
ragflow_count = 0
|
||||
id_count = 0
|
||||
for k, v in raw_configs.items():
|
||||
match (k):
|
||||
case "ragflow":
|
||||
name: str = f'ragflow_{ragflow_count}'
|
||||
host: str = v['host']
|
||||
http_port: int = v['http_port']
|
||||
config = RAGFlowServerConfig(id=id_count, name=name, host=host, port=http_port, service_type="ragflow_server")
|
||||
configurations.append(config)
|
||||
id_count += 1
|
||||
case "es":
|
||||
name: str = 'elasticsearch'
|
||||
url = v['hosts']
|
||||
parsed = urlparse(url)
|
||||
host: str = parsed.hostname
|
||||
port: int = parsed.port
|
||||
username: str = v.get('username')
|
||||
password: str = v.get('password')
|
||||
config = ElasticsearchConfig(id=id_count, name=name, host=host, port=port, service_type="retrieval",
|
||||
retrieval_type="elasticsearch",
|
||||
username=username, password=password)
|
||||
configurations.append(config)
|
||||
id_count += 1
|
||||
|
||||
case "infinity":
|
||||
name: str = 'infinity'
|
||||
url = v['uri']
|
||||
parts = url.split(':', 1)
|
||||
host = parts[0]
|
||||
port = int(parts[1])
|
||||
database: str = v.get('db_name', 'default_db')
|
||||
config = InfinityConfig(id=id_count, name=name, host=host, port=port, service_type="retrieval", retrieval_type="infinity",
|
||||
db_name=database)
|
||||
configurations.append(config)
|
||||
id_count += 1
|
||||
case "minio":
|
||||
name: str = 'minio'
|
||||
url = v['host']
|
||||
parts = url.split(':', 1)
|
||||
host = parts[0]
|
||||
port = int(parts[1])
|
||||
user = v.get('user')
|
||||
password = v.get('password')
|
||||
config = MinioConfig(id=id_count, name=name, host=host, port=port, user=user, password=password, service_type="file_store",
|
||||
store_type="minio")
|
||||
configurations.append(config)
|
||||
id_count += 1
|
||||
case "redis":
|
||||
name: str = 'redis'
|
||||
url = v['host']
|
||||
parts = url.split(':', 1)
|
||||
host = parts[0]
|
||||
port = int(parts[1])
|
||||
password = v.get('password')
|
||||
db: int = v.get('db')
|
||||
config = RedisConfig(id=id_count, name=name, host=host, port=port, password=password, database=db,
|
||||
service_type="message_queue", mq_type="redis")
|
||||
configurations.append(config)
|
||||
id_count += 1
|
||||
case "mysql":
|
||||
name: str = 'mysql'
|
||||
host: str = v.get('host')
|
||||
port: int = v.get('port')
|
||||
username = v.get('user')
|
||||
password = v.get('password')
|
||||
config = MySQLConfig(id=id_count, name=name, host=host, port=port, username=username, password=password,
|
||||
service_type="meta_data", meta_type="mysql")
|
||||
configurations.append(config)
|
||||
id_count += 1
|
||||
case "admin":
|
||||
pass
|
||||
case _:
|
||||
logging.warning(f"Unknown configuration key: {k}")
|
||||
continue
|
||||
|
||||
return configurations
|
||||
17
admin/exceptions.py
Normal file
17
admin/exceptions.py
Normal file
@ -0,0 +1,17 @@
|
||||
class AdminException(Exception):
|
||||
def __init__(self, message, code=400):
|
||||
super().__init__(message)
|
||||
self.code = code
|
||||
self.message = message
|
||||
|
||||
class UserNotFoundError(AdminException):
|
||||
def __init__(self, username):
|
||||
super().__init__(f"User '{username}' not found", 404)
|
||||
|
||||
class UserAlreadyExistsError(AdminException):
|
||||
def __init__(self, username):
|
||||
super().__init__(f"User '{username}' already exists", 409)
|
||||
|
||||
class CannotDeleteAdminError(AdminException):
|
||||
def __init__(self):
|
||||
super().__init__("Cannot delete admin account", 403)
|
||||
15
admin/responses.py
Normal file
15
admin/responses.py
Normal file
@ -0,0 +1,15 @@
|
||||
from flask import jsonify
|
||||
|
||||
def success_response(data=None, message="Success", code = 0):
|
||||
return jsonify({
|
||||
"code": code,
|
||||
"message": message,
|
||||
"data": data
|
||||
}), 200
|
||||
|
||||
def error_response(message="Error", code=-1, data=None):
|
||||
return jsonify({
|
||||
"code": code,
|
||||
"message": message,
|
||||
"data": data
|
||||
}), 400
|
||||
190
admin/routes.py
Normal file
190
admin/routes.py
Normal file
@ -0,0 +1,190 @@
|
||||
from flask import Blueprint, request
|
||||
|
||||
from auth import login_verify
|
||||
from responses import success_response, error_response
|
||||
from services import UserMgr, ServiceMgr, UserServiceMgr
|
||||
from exceptions import AdminException
|
||||
|
||||
admin_bp = Blueprint('admin', __name__, url_prefix='/api/v1/admin')
|
||||
|
||||
|
||||
@admin_bp.route('/auth', methods=['GET'])
|
||||
@login_verify
|
||||
def auth_admin():
|
||||
try:
|
||||
return success_response(None, "Admin is authorized", 0)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/users', methods=['GET'])
|
||||
@login_verify
|
||||
def list_users():
|
||||
try:
|
||||
users = UserMgr.get_all_users()
|
||||
return success_response(users, "Get all users", 0)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/users', methods=['POST'])
|
||||
@login_verify
|
||||
def create_user():
|
||||
try:
|
||||
data = request.get_json()
|
||||
if not data or 'username' not in data or 'password' not in data:
|
||||
return error_response("Username and password are required", 400)
|
||||
|
||||
username = data['username']
|
||||
password = data['password']
|
||||
role = data.get('role', 'user')
|
||||
|
||||
res = UserMgr.create_user(username, password, role)
|
||||
if res["success"]:
|
||||
user_info = res["user_info"]
|
||||
user_info.pop("password") # do not return password
|
||||
return success_response(user_info, "User created successfully")
|
||||
else:
|
||||
return error_response("create user failed")
|
||||
|
||||
except AdminException as e:
|
||||
return error_response(e.message, e.code)
|
||||
except Exception as e:
|
||||
return error_response(str(e))
|
||||
|
||||
|
||||
@admin_bp.route('/users/<username>', methods=['DELETE'])
|
||||
@login_verify
|
||||
def delete_user(username):
|
||||
try:
|
||||
res = UserMgr.delete_user(username)
|
||||
if res["success"]:
|
||||
return success_response(None, res["message"])
|
||||
else:
|
||||
return error_response(res["message"])
|
||||
|
||||
except AdminException as e:
|
||||
return error_response(e.message, e.code)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/users/<username>/password', methods=['PUT'])
|
||||
@login_verify
|
||||
def change_password(username):
|
||||
try:
|
||||
data = request.get_json()
|
||||
if not data or 'new_password' not in data:
|
||||
return error_response("New password is required", 400)
|
||||
|
||||
new_password = data['new_password']
|
||||
msg = UserMgr.update_user_password(username, new_password)
|
||||
return success_response(None, msg)
|
||||
|
||||
except AdminException as e:
|
||||
return error_response(e.message, e.code)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/users/<username>/activate', methods=['PUT'])
|
||||
@login_verify
|
||||
def alter_user_activate_status(username):
|
||||
try:
|
||||
data = request.get_json()
|
||||
if not data or 'activate_status' not in data:
|
||||
return error_response("Activation status is required", 400)
|
||||
activate_status = data['activate_status']
|
||||
msg = UserMgr.update_user_activate_status(username, activate_status)
|
||||
return success_response(None, msg)
|
||||
except AdminException as e:
|
||||
return error_response(e.message, e.code)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
@admin_bp.route('/users/<username>', methods=['GET'])
|
||||
@login_verify
|
||||
def get_user_details(username):
|
||||
try:
|
||||
user_details = UserMgr.get_user_details(username)
|
||||
return success_response(user_details)
|
||||
|
||||
except AdminException as e:
|
||||
return error_response(e.message, e.code)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
@admin_bp.route('/users/<username>/datasets', methods=['GET'])
|
||||
@login_verify
|
||||
def get_user_datasets(username):
|
||||
try:
|
||||
datasets_list = UserServiceMgr.get_user_datasets(username)
|
||||
return success_response(datasets_list)
|
||||
|
||||
except AdminException as e:
|
||||
return error_response(e.message, e.code)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/users/<username>/agents', methods=['GET'])
|
||||
@login_verify
|
||||
def get_user_agents(username):
|
||||
try:
|
||||
agents_list = UserServiceMgr.get_user_agents(username)
|
||||
return success_response(agents_list)
|
||||
|
||||
except AdminException as e:
|
||||
return error_response(e.message, e.code)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/services', methods=['GET'])
|
||||
@login_verify
|
||||
def get_services():
|
||||
try:
|
||||
services = ServiceMgr.get_all_services()
|
||||
return success_response(services, "Get all services", 0)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/service_types/<service_type>', methods=['GET'])
|
||||
@login_verify
|
||||
def get_services_by_type(service_type_str):
|
||||
try:
|
||||
services = ServiceMgr.get_services_by_type(service_type_str)
|
||||
return success_response(services)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/services/<service_id>', methods=['GET'])
|
||||
@login_verify
|
||||
def get_service(service_id):
|
||||
try:
|
||||
services = ServiceMgr.get_service_details(service_id)
|
||||
return success_response(services)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/services/<service_id>', methods=['DELETE'])
|
||||
@login_verify
|
||||
def shutdown_service(service_id):
|
||||
try:
|
||||
services = ServiceMgr.shutdown_service(service_id)
|
||||
return success_response(services)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/services/<service_id>', methods=['PUT'])
|
||||
@login_verify
|
||||
def restart_service(service_id):
|
||||
try:
|
||||
services = ServiceMgr.restart_service(service_id)
|
||||
return success_response(services)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
175
admin/services.py
Normal file
175
admin/services.py
Normal file
@ -0,0 +1,175 @@
|
||||
import re
|
||||
from werkzeug.security import check_password_hash
|
||||
from api.db import ActiveEnum
|
||||
from api.db.services import UserService
|
||||
from api.db.joint_services.user_account_service import create_new_user, delete_user_data
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.utils.crypt import decrypt
|
||||
from exceptions import AdminException, UserAlreadyExistsError, UserNotFoundError
|
||||
from config import SERVICE_CONFIGS
|
||||
|
||||
class UserMgr:
|
||||
@staticmethod
|
||||
def get_all_users():
|
||||
users = UserService.get_all_users()
|
||||
result = []
|
||||
for user in users:
|
||||
result.append({'email': user.email, 'nickname': user.nickname, 'create_date': user.create_date, 'is_active': user.is_active})
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def get_user_details(username):
|
||||
# use email to query
|
||||
users = UserService.query_user_by_email(username)
|
||||
result = []
|
||||
for user in users:
|
||||
result.append({
|
||||
'email': user.email,
|
||||
'language': user.language,
|
||||
'last_login_time': user.last_login_time,
|
||||
'is_authenticated': user.is_authenticated,
|
||||
'is_active': user.is_active,
|
||||
'is_anonymous': user.is_anonymous,
|
||||
'login_channel': user.login_channel,
|
||||
'status': user.status,
|
||||
'is_superuser': user.is_superuser,
|
||||
'create_date': user.create_date,
|
||||
'update_date': user.update_date
|
||||
})
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def create_user(username, password, role="user") -> dict:
|
||||
# Validate the email address
|
||||
if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$", username):
|
||||
raise AdminException(f"Invalid email address: {username}!")
|
||||
# Check if the email address is already used
|
||||
if UserService.query(email=username):
|
||||
raise UserAlreadyExistsError(username)
|
||||
# Construct user info data
|
||||
user_info_dict = {
|
||||
"email": username,
|
||||
"nickname": "", # ask user to edit it manually in settings.
|
||||
"password": decrypt(password),
|
||||
"login_channel": "password",
|
||||
"is_superuser": role == "admin",
|
||||
}
|
||||
return create_new_user(user_info_dict)
|
||||
|
||||
@staticmethod
|
||||
def delete_user(username):
|
||||
# use email to delete
|
||||
user_list = UserService.query_user_by_email(username)
|
||||
if not user_list:
|
||||
raise UserNotFoundError(username)
|
||||
if len(user_list) > 1:
|
||||
raise AdminException(f"Exist more than 1 user: {username}!")
|
||||
usr = user_list[0]
|
||||
return delete_user_data(usr.id)
|
||||
|
||||
@staticmethod
|
||||
def update_user_password(username, new_password) -> str:
|
||||
# use email to find user. check exist and unique.
|
||||
user_list = UserService.query_user_by_email(username)
|
||||
if not user_list:
|
||||
raise UserNotFoundError(username)
|
||||
elif len(user_list) > 1:
|
||||
raise AdminException(f"Exist more than 1 user: {username}!")
|
||||
# check new_password different from old.
|
||||
usr = user_list[0]
|
||||
psw = decrypt(new_password)
|
||||
if check_password_hash(usr.password, psw):
|
||||
return "Same password, no need to update!"
|
||||
# update password
|
||||
UserService.update_user_password(usr.id, psw)
|
||||
return "Password updated successfully!"
|
||||
|
||||
@staticmethod
|
||||
def update_user_activate_status(username, activate_status: str):
|
||||
# use email to find user. check exist and unique.
|
||||
user_list = UserService.query_user_by_email(username)
|
||||
if not user_list:
|
||||
raise UserNotFoundError(username)
|
||||
elif len(user_list) > 1:
|
||||
raise AdminException(f"Exist more than 1 user: {username}!")
|
||||
# check activate status different from new
|
||||
usr = user_list[0]
|
||||
# format activate_status before handle
|
||||
_activate_status = activate_status.lower()
|
||||
target_status = {
|
||||
'on': ActiveEnum.ACTIVE.value,
|
||||
'off': ActiveEnum.INACTIVE.value,
|
||||
}.get(_activate_status)
|
||||
if not target_status:
|
||||
raise AdminException(f"Invalid activate_status: {activate_status}")
|
||||
if target_status == usr.is_active:
|
||||
return f"User activate status is already {_activate_status}!"
|
||||
# update is_active
|
||||
UserService.update_user(usr.id, {"is_active": target_status})
|
||||
return f"Turn {_activate_status} user activate status successfully!"
|
||||
|
||||
class UserServiceMgr:
|
||||
|
||||
@staticmethod
|
||||
def get_user_datasets(username):
|
||||
# use email to find user.
|
||||
user_list = UserService.query_user_by_email(username)
|
||||
if not user_list:
|
||||
raise UserNotFoundError(username)
|
||||
elif len(user_list) > 1:
|
||||
raise AdminException(f"Exist more than 1 user: {username}!")
|
||||
# find tenants
|
||||
usr = user_list[0]
|
||||
tenants = TenantService.get_joined_tenants_by_user_id(usr.id)
|
||||
tenant_ids = [m["tenant_id"] for m in tenants]
|
||||
# filter permitted kb and owned kb
|
||||
return KnowledgebaseService.get_all_kb_by_tenant_ids(tenant_ids, usr.id)
|
||||
|
||||
@staticmethod
|
||||
def get_user_agents(username):
|
||||
# use email to find user.
|
||||
user_list = UserService.query_user_by_email(username)
|
||||
if not user_list:
|
||||
raise UserNotFoundError(username)
|
||||
elif len(user_list) > 1:
|
||||
raise AdminException(f"Exist more than 1 user: {username}!")
|
||||
# find tenants
|
||||
usr = user_list[0]
|
||||
tenants = TenantService.get_joined_tenants_by_user_id(usr.id)
|
||||
tenant_ids = [m["tenant_id"] for m in tenants]
|
||||
# filter permitted agents and owned agents
|
||||
res = UserCanvasService.get_all_agents_by_tenant_ids(tenant_ids, usr.id)
|
||||
return [{
|
||||
'title': r['title'],
|
||||
'permission': r['permission'],
|
||||
'canvas_type': r['canvas_type'],
|
||||
'canvas_category': r['canvas_category']
|
||||
} for r in res]
|
||||
|
||||
class ServiceMgr:
|
||||
|
||||
@staticmethod
|
||||
def get_all_services():
|
||||
result = []
|
||||
configs = SERVICE_CONFIGS.configs
|
||||
for config in configs:
|
||||
result.append(config.to_dict())
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def get_services_by_type(service_type_str: str):
|
||||
raise AdminException("get_services_by_type: not implemented")
|
||||
|
||||
@staticmethod
|
||||
def get_service_details(service_id: int):
|
||||
raise AdminException("get_service_details: not implemented")
|
||||
|
||||
@staticmethod
|
||||
def shutdown_service(service_id: int):
|
||||
raise AdminException("shutdown_service: not implemented")
|
||||
|
||||
@staticmethod
|
||||
def restart_service(service_id: int):
|
||||
raise AdminException("restart_service: not implemented")
|
||||
@ -16,6 +16,7 @@
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from copy import deepcopy
|
||||
@ -26,7 +27,7 @@ from agent.component import component_class
|
||||
from agent.component.base import ComponentBase
|
||||
from api.db.services.file_service import FileService
|
||||
from api.utils import get_uuid, hash_str2int
|
||||
from rag.prompts.prompts import chunks_format
|
||||
from rag.prompts.generator import chunks_format
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
class Graph:
|
||||
@ -152,6 +153,16 @@ class Graph:
|
||||
def get_tenant_id(self):
|
||||
return self._tenant_id
|
||||
|
||||
def get_variable_value(self, exp: str) -> Any:
|
||||
exp = exp.strip("{").strip("}").strip(" ").strip("{").strip("}")
|
||||
if exp.find("@") < 0:
|
||||
return self.globals[exp]
|
||||
cpn_id, var_nm = exp.split("@")
|
||||
cpn = self.get_component(cpn_id)
|
||||
if not cpn:
|
||||
raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'")
|
||||
return cpn["obj"].output(var_nm)
|
||||
|
||||
|
||||
class Canvas(Graph):
|
||||
|
||||
@ -300,9 +311,11 @@ class Canvas(Graph):
|
||||
yield decorate("message", {"content": m})
|
||||
_m += m
|
||||
cpn_obj.set_output("content", _m)
|
||||
cite = re.search(r"\[ID:[ 0-9]+\]", _m)
|
||||
else:
|
||||
yield decorate("message", {"content": cpn_obj.output("content")})
|
||||
yield decorate("message_end", {"reference": self.get_reference()})
|
||||
cite = re.search(r"\[ID:[ 0-9]+\]", cpn_obj.output("content"))
|
||||
yield decorate("message_end", {"reference": self.get_reference() if cite else None})
|
||||
|
||||
while partials:
|
||||
_cpn_obj = self.get_component_obj(partials[0])
|
||||
@ -403,16 +416,6 @@ class Canvas(Graph):
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_variable_value(self, exp: str) -> Any:
|
||||
exp = exp.strip("{").strip("}").strip(" ").strip("{").strip("}")
|
||||
if exp.find("@") < 0:
|
||||
return self.globals[exp]
|
||||
cpn_id, var_nm = exp.split("@")
|
||||
cpn = self.get_component(cpn_id)
|
||||
if not cpn:
|
||||
raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'")
|
||||
return cpn["obj"].output(var_nm)
|
||||
|
||||
def get_history(self, window_size):
|
||||
convs = []
|
||||
if window_size <= 0:
|
||||
@ -481,13 +484,14 @@ class Canvas(Graph):
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
def add_refernce(self, chunks: list[object], doc_infos: list[object]):
|
||||
def add_reference(self, chunks: list[object], doc_infos: list[object]):
|
||||
if not self.retrieval:
|
||||
self.retrieval = [{"chunks": {}, "doc_aggs": {}}]
|
||||
|
||||
r = self.retrieval[-1]
|
||||
for ck in chunks_format({"chunks": chunks}):
|
||||
cid = hash_str2int(ck["id"], 100)
|
||||
cid = hash_str2int(ck["id"], 500)
|
||||
# cid = uuid.uuid5(uuid.NAMESPACE_DNS, ck["id"])
|
||||
if cid not in r:
|
||||
r["chunks"][cid] = ck
|
||||
|
||||
|
||||
@ -28,9 +28,8 @@ from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.mcp_server_service import MCPServerService
|
||||
from api.utils.api_utils import timeout
|
||||
from rag.prompts import message_fit_in
|
||||
from rag.prompts.prompts import next_step, COMPLETE_TASK, analyze_task, \
|
||||
citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question
|
||||
from rag.prompts.generator import next_step, COMPLETE_TASK, analyze_task, \
|
||||
citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in
|
||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
|
||||
from agent.component.llm import LLMParam, LLM
|
||||
|
||||
@ -138,7 +137,7 @@ class Agent(LLM, ToolBase):
|
||||
res.update(cpn.get_input_form())
|
||||
return res
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if kwargs.get("user_prompt"):
|
||||
usr_pmt = ""
|
||||
@ -155,18 +154,18 @@ class Agent(LLM, ToolBase):
|
||||
if not self.tools:
|
||||
return LLM._invoke(self, **kwargs)
|
||||
|
||||
prompt, msg = self._prepare_prompt_variables()
|
||||
prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
|
||||
|
||||
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
|
||||
ex = self.exception_handler()
|
||||
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not self._param.output_structure and not (ex and ex["goto"]):
|
||||
self.set_output("content", partial(self.stream_output_with_tools, prompt, msg))
|
||||
self.set_output("content", partial(self.stream_output_with_tools, prompt, msg, user_defined_prompt))
|
||||
return
|
||||
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
use_tools = []
|
||||
ans = ""
|
||||
for delta_ans, tk in self._react_with_tools_streamly(prompt, msg, use_tools):
|
||||
for delta_ans, tk in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt):
|
||||
ans += delta_ans
|
||||
|
||||
if ans.find("**ERROR**") >= 0:
|
||||
@ -182,11 +181,11 @@ class Agent(LLM, ToolBase):
|
||||
self.set_output("use_tools", use_tools)
|
||||
return ans
|
||||
|
||||
def stream_output_with_tools(self, prompt, msg):
|
||||
def stream_output_with_tools(self, prompt, msg, user_defined_prompt={}):
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
answer_without_toolcall = ""
|
||||
use_tools = []
|
||||
for delta_ans,_ in self._react_with_tools_streamly(prompt, msg, use_tools):
|
||||
for delta_ans,_ in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt):
|
||||
if delta_ans.find("**ERROR**") >= 0:
|
||||
if self.get_exception_default_value():
|
||||
self.set_output("content", self.get_exception_default_value())
|
||||
@ -209,7 +208,7 @@ class Agent(LLM, ToolBase):
|
||||
]):
|
||||
yield delta_ans
|
||||
|
||||
def _react_with_tools_streamly(self, prompt, history: list[dict], use_tools):
|
||||
def _react_with_tools_streamly(self, prompt, history: list[dict], use_tools, user_defined_prompt={}):
|
||||
token_count = 0
|
||||
tool_metas = self.tool_meta
|
||||
hist = deepcopy(history)
|
||||
@ -230,7 +229,7 @@ class Agent(LLM, ToolBase):
|
||||
# last_calling,
|
||||
# last_calling != name
|
||||
#]):
|
||||
# self.toolcall_session.get_tool_obj(name).add2system_prompt(f"The chat history with other agents are as following: \n" + self.get_useful_memory(user_request, str(args["user_prompt"])))
|
||||
# self.toolcall_session.get_tool_obj(name).add2system_prompt(f"The chat history with other agents are as following: \n" + self.get_useful_memory(user_request, str(args["user_prompt"]),user_defined_prompt))
|
||||
last_calling = name
|
||||
tool_response = self.toolcall_session.tool_call(name, args)
|
||||
use_tools.append({
|
||||
@ -239,7 +238,7 @@ class Agent(LLM, ToolBase):
|
||||
"results": tool_response
|
||||
})
|
||||
# self.callback("add_memory", {}, "...")
|
||||
#self.add_memory(hist[-2]["content"], hist[-1]["content"], name, args, str(tool_response))
|
||||
#self.add_memory(hist[-2]["content"], hist[-1]["content"], name, args, str(tool_response), user_defined_prompt)
|
||||
|
||||
return name, tool_response
|
||||
|
||||
@ -279,10 +278,10 @@ class Agent(LLM, ToolBase):
|
||||
hist.append({"role": "user", "content": content})
|
||||
|
||||
st = timer()
|
||||
task_desc = analyze_task(self.chat_mdl, prompt, user_request, tool_metas)
|
||||
task_desc = analyze_task(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
|
||||
self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st)
|
||||
for _ in range(self._param.max_rounds + 1):
|
||||
response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc)
|
||||
response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
|
||||
# self.callback("next_step", {}, str(response)[:256]+"...")
|
||||
token_count += tk
|
||||
hist.append({"role": "assistant", "content": response})
|
||||
@ -307,7 +306,7 @@ class Agent(LLM, ToolBase):
|
||||
thr.append(executor.submit(use_tool, name, args))
|
||||
|
||||
st = timer()
|
||||
reflection = reflect(self.chat_mdl, hist, [th.result() for th in thr])
|
||||
reflection = reflect(self.chat_mdl, hist, [th.result() for th in thr], user_defined_prompt)
|
||||
append_user_content(hist, reflection)
|
||||
self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)
|
||||
|
||||
@ -334,10 +333,10 @@ Respond immediately with your final comprehensive answer.
|
||||
for txt, tkcnt in complete():
|
||||
yield txt, tkcnt
|
||||
|
||||
def get_useful_memory(self, goal: str, sub_goal:str, topn=3) -> str:
|
||||
def get_useful_memory(self, goal: str, sub_goal:str, topn=3, user_defined_prompt:dict={}) -> str:
|
||||
# self.callback("get_useful_memory", {"topn": 3}, "...")
|
||||
mems = self._canvas.get_memory()
|
||||
rank = rank_memories(self.chat_mdl, goal, sub_goal, [summ for (user, assist, summ) in mems])
|
||||
rank = rank_memories(self.chat_mdl, goal, sub_goal, [summ for (user, assist, summ) in mems], user_defined_prompt)
|
||||
try:
|
||||
rank = json_repair.loads(re.sub(r"```.*", "", rank))[:topn]
|
||||
mems = [mems[r] for r in rank]
|
||||
|
||||
@ -244,7 +244,7 @@ class ComponentParamBase(ABC):
|
||||
|
||||
if not value_legal:
|
||||
raise ValueError(
|
||||
"Plase check runtime conf, {} = {} does not match user-parameter restriction".format(
|
||||
"Please check runtime conf, {} = {} does not match user-parameter restriction".format(
|
||||
variable, value
|
||||
)
|
||||
)
|
||||
@ -431,7 +431,7 @@ class ComponentBase(ABC):
|
||||
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||
return self.output()
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@ -28,7 +28,7 @@ from rag.llm.chat_model import ERROR_PREFIX
|
||||
class CategorizeParam(LLMParam):
|
||||
|
||||
"""
|
||||
Define the Categorize component parameters.
|
||||
Define the categorize component parameters.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -80,7 +80,7 @@ Here's description of each category:
|
||||
- Prioritize the most specific applicable category
|
||||
- Return only the category name without explanations
|
||||
- Use "Other" only when no other category fits
|
||||
|
||||
|
||||
""".format(
|
||||
"\n - ".join(list(self.category_description.keys())),
|
||||
"\n".join(descriptions)
|
||||
@ -96,7 +96,7 @@ Here's description of each category:
|
||||
class Categorize(LLM, ABC):
|
||||
component_name = "Categorize"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
msg = self._canvas.get_history(self._param.message_history_window_size)
|
||||
if not msg:
|
||||
@ -112,7 +112,7 @@ class Categorize(LLM, ABC):
|
||||
|
||||
user_prompt = """
|
||||
---- Real Data ----
|
||||
{} →
|
||||
{} →
|
||||
""".format(" | ".join(["{}: \"{}\"".format(c["role"].upper(), re.sub(r"\n", "", c["content"], flags=re.DOTALL)) for c in msg]))
|
||||
ans = chat_mdl.chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf())
|
||||
logging.info(f"input: {user_prompt}, answer: {str(ans)}")
|
||||
@ -134,4 +134,4 @@ class Categorize(LLM, ABC):
|
||||
self.set_output("_next", cpn_ids)
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Which should it falls into {}? ...".format(",".join([f"`{c}`" for c, _ in self._param.category_description.items()]))
|
||||
return "Which should it falls into {}? ...".format(",".join([f"`{c}`" for c, _ in self._param.category_description.items()]))
|
||||
|
||||
@ -53,7 +53,7 @@ class InvokeParam(ComponentParamBase):
|
||||
class Invoke(ComponentBase, ABC):
|
||||
component_name = "Invoke"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3)))
|
||||
def _invoke(self, **kwargs):
|
||||
args = {}
|
||||
for para in self._param.variables:
|
||||
|
||||
@ -17,6 +17,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from typing import Any, Generator
|
||||
import json_repair
|
||||
from functools import partial
|
||||
@ -25,8 +26,7 @@ from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
from api.utils.api_utils import timeout
|
||||
from rag.prompts import message_fit_in, citation_prompt
|
||||
from rag.prompts.prompts import tool_call_summary
|
||||
from rag.prompts.generator import tool_call_summary, message_fit_in, citation_prompt
|
||||
|
||||
|
||||
class LLMParam(ComponentParamBase):
|
||||
@ -81,9 +81,9 @@ class LLMParam(ComponentParamBase):
|
||||
|
||||
class LLM(ComponentBase):
|
||||
component_name = "LLM"
|
||||
|
||||
def __init__(self, canvas, id, param: ComponentParamBase):
|
||||
super().__init__(canvas, id, param)
|
||||
|
||||
def __init__(self, canvas, component_id, param: ComponentParamBase):
|
||||
super().__init__(canvas, component_id, param)
|
||||
self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id),
|
||||
self._param.llm_id, max_retries=self._param.max_retries,
|
||||
retry_interval=self._param.delay_after_error
|
||||
@ -101,6 +101,8 @@ class LLM(ComponentBase):
|
||||
|
||||
def get_input_elements(self) -> dict[str, Any]:
|
||||
res = self.get_input_elements_from_text(self._param.sys_prompt)
|
||||
if isinstance(self._param.prompts, str):
|
||||
self._param.prompts = [{"role": "user", "content": self._param.prompts}]
|
||||
for prompt in self._param.prompts:
|
||||
d = self.get_input_elements_from_text(prompt["content"])
|
||||
res.update(d)
|
||||
@ -112,6 +114,17 @@ class LLM(ComponentBase):
|
||||
def add2system_prompt(self, txt):
|
||||
self._param.sys_prompt += txt
|
||||
|
||||
def _sys_prompt_and_msg(self, msg, args):
|
||||
if isinstance(self._param.prompts, str):
|
||||
self._param.prompts = [{"role": "user", "content": self._param.prompts}]
|
||||
for p in self._param.prompts:
|
||||
if msg and msg[-1]["role"] == p["role"]:
|
||||
continue
|
||||
p = deepcopy(p)
|
||||
p["content"] = self.string_format(p["content"], args)
|
||||
msg.append(p)
|
||||
return msg, self.string_format(self._param.sys_prompt, args)
|
||||
|
||||
def _prepare_prompt_variables(self):
|
||||
if self._param.visual_files_var:
|
||||
self.imgs = self._canvas.get_variable_value(self._param.visual_files_var)
|
||||
@ -127,7 +140,6 @@ class LLM(ComponentBase):
|
||||
|
||||
args = {}
|
||||
vars = self.get_input_elements() if not self._param.debug_inputs else self._param.debug_inputs
|
||||
sys_prompt = self._param.sys_prompt
|
||||
for k, o in vars.items():
|
||||
args[k] = o["value"]
|
||||
if not isinstance(args[k], str):
|
||||
@ -137,19 +149,22 @@ class LLM(ComponentBase):
|
||||
args[k] = str(args[k])
|
||||
self.set_input_value(k, args[k])
|
||||
|
||||
msg = self._canvas.get_history(self._param.message_history_window_size)[:-1]
|
||||
for p in self._param.prompts:
|
||||
if msg and msg[-1]["role"] == p["role"]:
|
||||
continue
|
||||
msg.append(p)
|
||||
|
||||
sys_prompt = self.string_format(sys_prompt, args)
|
||||
for m in msg:
|
||||
m["content"] = self.string_format(m["content"], args)
|
||||
msg, sys_prompt = self._sys_prompt_and_msg(self._canvas.get_history(self._param.message_history_window_size)[:-1], args)
|
||||
user_defined_prompt, sys_prompt = self._extract_prompts(sys_prompt)
|
||||
if self._param.cite and self._canvas.get_reference()["chunks"]:
|
||||
sys_prompt += citation_prompt()
|
||||
sys_prompt += citation_prompt(user_defined_prompt)
|
||||
|
||||
return sys_prompt, msg
|
||||
return sys_prompt, msg, user_defined_prompt
|
||||
|
||||
def _extract_prompts(self, sys_prompt):
|
||||
pts = {}
|
||||
for tag in ["TASK_ANALYSIS", "PLAN_GENERATION", "REFLECTION", "CONTEXT_SUMMARY", "CONTEXT_RANKING", "CITATION_GUIDELINES"]:
|
||||
r = re.search(rf"<{tag}>(.*?)</{tag}>", sys_prompt, flags=re.DOTALL|re.IGNORECASE)
|
||||
if not r:
|
||||
continue
|
||||
pts[tag.lower()] = r.group(1)
|
||||
sys_prompt = re.sub(rf"<{tag}>(.*?)</{tag}>", "", sys_prompt, flags=re.DOTALL|re.IGNORECASE)
|
||||
return pts, sys_prompt
|
||||
|
||||
def _generate(self, msg:list[dict], **kwargs) -> str:
|
||||
if not self.imgs:
|
||||
@ -190,15 +205,15 @@ class LLM(ComponentBase):
|
||||
for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
|
||||
yield delta(txt)
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
def clean_formated_answer(ans: str) -> str:
|
||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
|
||||
return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)
|
||||
|
||||
prompt, msg = self._prepare_prompt_variables()
|
||||
error = ""
|
||||
prompt, msg, _ = self._prepare_prompt_variables()
|
||||
error: str = ""
|
||||
|
||||
if self._param.output_structure:
|
||||
prompt += "\nThe output MUST follow this JSON format:\n"+json.dumps(self._param.output_structure, ensure_ascii=False, indent=2)
|
||||
@ -261,11 +276,11 @@ class LLM(ComponentBase):
|
||||
answer += ans
|
||||
self.set_output("content", answer)
|
||||
|
||||
def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str):
|
||||
summ = tool_call_summary(self.chat_mdl, func_name, params, results)
|
||||
def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str, user_defined_prompt:dict={}):
|
||||
summ = tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt)
|
||||
logging.info(f"[MEMORY]: {summ}")
|
||||
self._canvas.add_memory(user, assist, summ)
|
||||
|
||||
def thoughts(self) -> str:
|
||||
_, msg = self._prepare_prompt_variables()
|
||||
_, msg,_ = self._prepare_prompt_variables()
|
||||
return "⌛Give me a moment—starting from: \n\n" + re.sub(r"(User's query:|[\\]+)", '', msg[-1]['content'], flags=re.DOTALL) + "\n\nI’ll figure out our best next move."
|
||||
@ -49,7 +49,7 @@ class MessageParam(ComponentParamBase):
|
||||
class Message(ComponentBase):
|
||||
component_name = "Message"
|
||||
|
||||
def get_kwargs(self, script:str, kwargs:dict = {}, delimeter:str=None) -> tuple[str, dict[str, str | list | Any]]:
|
||||
def get_kwargs(self, script:str, kwargs:dict = {}, delimiter:str=None) -> tuple[str, dict[str, str | list | Any]]:
|
||||
for k,v in self.get_input_elements_from_text(script).items():
|
||||
if k in kwargs:
|
||||
continue
|
||||
@ -60,8 +60,8 @@ class Message(ComponentBase):
|
||||
if isinstance(v, partial):
|
||||
for t in v():
|
||||
ans += t
|
||||
elif isinstance(v, list) and delimeter:
|
||||
ans = delimeter.join([str(vv) for vv in v])
|
||||
elif isinstance(v, list) and delimiter:
|
||||
ans = delimiter.join([str(vv) for vv in v])
|
||||
elif not isinstance(v, str):
|
||||
try:
|
||||
ans = json.dumps(v, ensure_ascii=False)
|
||||
@ -127,7 +127,7 @@ class Message(ComponentBase):
|
||||
]
|
||||
return any([re.search(p, content) for p in patt])
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
rand_cnt = random.choice(self._param.content)
|
||||
if self._param.stream and not self._is_jinjia2(rand_cnt):
|
||||
|
||||
@ -56,7 +56,7 @@ class StringTransform(Message, ABC):
|
||||
"type": "line"
|
||||
} for k, o in self.get_input_elements_from_text(self._param.script).items()}
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if self._param.method == "split":
|
||||
self._split(kwargs.get("line"))
|
||||
@ -90,7 +90,7 @@ class StringTransform(Message, ABC):
|
||||
for k,v in kwargs.items():
|
||||
if not v:
|
||||
v = ""
|
||||
script = re.sub(k, v, script)
|
||||
script = re.sub(k, lambda match: v, script)
|
||||
|
||||
self.set_output("result", script)
|
||||
|
||||
|
||||
@ -61,7 +61,7 @@ class SwitchParam(ComponentParamBase):
|
||||
class Switch(ComponentBase, ABC):
|
||||
component_name = "Switch"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3)))
|
||||
def _invoke(self, **kwargs):
|
||||
for cond in self._param.conditions:
|
||||
res = []
|
||||
|
||||
@ -83,7 +83,7 @@
|
||||
},
|
||||
"password": "20010812Yy!",
|
||||
"port": 3306,
|
||||
"sql": "Agent:WickedGoatsDivide@content",
|
||||
"sql": "{Agent:WickedGoatsDivide@content}",
|
||||
"username": "13637682833@163.com"
|
||||
}
|
||||
},
|
||||
@ -114,9 +114,7 @@
|
||||
"params": {
|
||||
"cross_languages": [],
|
||||
"empty_response": "",
|
||||
"kb_ids": [
|
||||
"ed31364c727211f0bdb2bafe6e7908e6"
|
||||
],
|
||||
"kb_ids": [],
|
||||
"keywords_similarity_weight": 0.7,
|
||||
"outputs": {
|
||||
"formalized_content": {
|
||||
@ -124,7 +122,7 @@
|
||||
"value": ""
|
||||
}
|
||||
},
|
||||
"query": "sys.query",
|
||||
"query": "{sys.query}",
|
||||
"rerank_id": "",
|
||||
"similarity_threshold": 0.2,
|
||||
"top_k": 1024,
|
||||
@ -145,9 +143,7 @@
|
||||
"params": {
|
||||
"cross_languages": [],
|
||||
"empty_response": "",
|
||||
"kb_ids": [
|
||||
"0f968106727311f08357bafe6e7908e6"
|
||||
],
|
||||
"kb_ids": [],
|
||||
"keywords_similarity_weight": 0.7,
|
||||
"outputs": {
|
||||
"formalized_content": {
|
||||
@ -155,7 +151,7 @@
|
||||
"value": ""
|
||||
}
|
||||
},
|
||||
"query": "sys.query",
|
||||
"query": "{sys.query}",
|
||||
"rerank_id": "",
|
||||
"similarity_threshold": 0.2,
|
||||
"top_k": 1024,
|
||||
@ -176,9 +172,7 @@
|
||||
"params": {
|
||||
"cross_languages": [],
|
||||
"empty_response": "",
|
||||
"kb_ids": [
|
||||
"4ad1f9d0727311f0827dbafe6e7908e6"
|
||||
],
|
||||
"kb_ids": [],
|
||||
"keywords_similarity_weight": 0.7,
|
||||
"outputs": {
|
||||
"formalized_content": {
|
||||
@ -186,7 +180,7 @@
|
||||
"value": ""
|
||||
}
|
||||
},
|
||||
"query": "sys.query",
|
||||
"query": "{sys.query}",
|
||||
"rerank_id": "",
|
||||
"similarity_threshold": 0.2,
|
||||
"top_k": 1024,
|
||||
@ -347,9 +341,7 @@
|
||||
"form": {
|
||||
"cross_languages": [],
|
||||
"empty_response": "",
|
||||
"kb_ids": [
|
||||
"ed31364c727211f0bdb2bafe6e7908e6"
|
||||
],
|
||||
"kb_ids": [],
|
||||
"keywords_similarity_weight": 0.7,
|
||||
"outputs": {
|
||||
"formalized_content": {
|
||||
@ -357,7 +349,7 @@
|
||||
"value": ""
|
||||
}
|
||||
},
|
||||
"query": "sys.query",
|
||||
"query": "{sys.query}",
|
||||
"rerank_id": "",
|
||||
"similarity_threshold": 0.2,
|
||||
"top_k": 1024,
|
||||
@ -387,9 +379,7 @@
|
||||
"form": {
|
||||
"cross_languages": [],
|
||||
"empty_response": "",
|
||||
"kb_ids": [
|
||||
"0f968106727311f08357bafe6e7908e6"
|
||||
],
|
||||
"kb_ids": [],
|
||||
"keywords_similarity_weight": 0.7,
|
||||
"outputs": {
|
||||
"formalized_content": {
|
||||
@ -397,7 +387,7 @@
|
||||
"value": ""
|
||||
}
|
||||
},
|
||||
"query": "sys.query",
|
||||
"query": "{sys.query}",
|
||||
"rerank_id": "",
|
||||
"similarity_threshold": 0.2,
|
||||
"top_k": 1024,
|
||||
@ -427,9 +417,7 @@
|
||||
"form": {
|
||||
"cross_languages": [],
|
||||
"empty_response": "",
|
||||
"kb_ids": [
|
||||
"4ad1f9d0727311f0827dbafe6e7908e6"
|
||||
],
|
||||
"kb_ids": [],
|
||||
"keywords_similarity_weight": 0.7,
|
||||
"outputs": {
|
||||
"formalized_content": {
|
||||
@ -437,7 +425,7 @@
|
||||
"value": ""
|
||||
}
|
||||
},
|
||||
"query": "sys.query",
|
||||
"query": "{sys.query}",
|
||||
"rerank_id": "",
|
||||
"similarity_threshold": 0.2,
|
||||
"top_k": 1024,
|
||||
@ -539,7 +527,7 @@
|
||||
},
|
||||
"password": "20010812Yy!",
|
||||
"port": 3306,
|
||||
"sql": "Agent:WickedGoatsDivide@content",
|
||||
"sql": "{Agent:WickedGoatsDivide@content}",
|
||||
"username": "13637682833@163.com"
|
||||
},
|
||||
"label": "ExeSQL",
|
||||
|
||||
@ -61,7 +61,7 @@ class ArXivParam(ToolParamBase):
|
||||
class ArXiv(ToolBase, ABC):
|
||||
component_name = "ArXiv"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", "")
|
||||
@ -97,6 +97,6 @@ class ArXiv(ToolBase, ABC):
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return """
|
||||
Keywords: {}
|
||||
Keywords: {}
|
||||
Looking for the most relevant articles.
|
||||
""".format(self.get_input().get("query", "-_-!"))
|
||||
""".format(self.get_input().get("query", "-_-!"))
|
||||
|
||||
@ -22,7 +22,7 @@ from typing import TypedDict, List, Any
|
||||
from agent.component.base import ComponentParamBase, ComponentBase
|
||||
from api.utils import hash_str2int
|
||||
from rag.llm.chat_model import ToolCallSession
|
||||
from rag.prompts.prompts import kb_prompt
|
||||
from rag.prompts.generator import kb_prompt
|
||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession
|
||||
from timeit import default_timer as timer
|
||||
|
||||
@ -166,7 +166,7 @@ class ToolBase(ComponentBase):
|
||||
"count": 1,
|
||||
"url": url
|
||||
})
|
||||
self._canvas.add_refernce(chunks, aggs)
|
||||
self._canvas.add_reference(chunks, aggs)
|
||||
self.set_output("formalized_content", "\n".join(kb_prompt({"chunks": chunks, "doc_aggs": aggs}, 200000, True)))
|
||||
|
||||
def thoughts(self) -> str:
|
||||
|
||||
@ -129,7 +129,7 @@ module.exports = { main };
|
||||
class CodeExec(ToolBase, ABC):
|
||||
component_name = "CodeExec"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
lang = kwargs.get("lang", self._param.lang)
|
||||
script = kwargs.get("script", self._param.script)
|
||||
|
||||
@ -64,5 +64,5 @@ class Crawler(ToolBase, ABC):
|
||||
elif self._param.extract_type == 'markdown':
|
||||
return result.markdown
|
||||
elif self._param.extract_type == 'content':
|
||||
result.extracted_content
|
||||
return result.extracted_content
|
||||
return result.markdown
|
||||
|
||||
@ -43,7 +43,7 @@ class DeepLParam(ComponentParamBase):
|
||||
|
||||
|
||||
class DeepL(ComponentBase, ABC):
|
||||
component_name = "GitHub"
|
||||
component_name = "DeepL"
|
||||
|
||||
def _run(self, history, **kwargs):
|
||||
ans = self.get_input()
|
||||
|
||||
@ -73,7 +73,7 @@ class DuckDuckGoParam(ToolParamBase):
|
||||
class DuckDuckGo(ToolBase, ABC):
|
||||
component_name = "DuckDuckGo"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", "")
|
||||
@ -115,6 +115,6 @@ class DuckDuckGo(ToolBase, ABC):
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return """
|
||||
Keywords: {}
|
||||
Keywords: {}
|
||||
Looking for the most relevant articles.
|
||||
""".format(self.get_input().get("query", "-_-!"))
|
||||
""".format(self.get_input().get("query", "-_-!"))
|
||||
|
||||
@ -98,8 +98,8 @@ class EmailParam(ToolParamBase):
|
||||
|
||||
class Email(ToolBase, ABC):
|
||||
component_name = "Email"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("to_email"):
|
||||
self.set_output("success", False)
|
||||
@ -212,4 +212,4 @@ class Email(ToolBase, ABC):
|
||||
To: {}
|
||||
Subject: {}
|
||||
Your email is on its way—sit tight!
|
||||
""".format(inputs.get("to_email", "-_-!"), inputs.get("subject", "-_-!"))
|
||||
""".format(inputs.get("to_email", "-_-!"), inputs.get("subject", "-_-!"))
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from abc import ABC
|
||||
@ -52,7 +53,7 @@ class ExeSQLParam(ToolParamBase):
|
||||
self.max_records = 1024
|
||||
|
||||
def check(self):
|
||||
self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgresql', 'mariadb', 'mssql'])
|
||||
self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgres', 'mariadb', 'mssql', 'IBM DB2'])
|
||||
self.check_empty(self.database, "Database name")
|
||||
self.check_empty(self.username, "database username")
|
||||
self.check_empty(self.host, "IP Address")
|
||||
@ -77,7 +78,7 @@ class ExeSQLParam(ToolParamBase):
|
||||
class ExeSQL(ToolBase, ABC):
|
||||
component_name = "ExeSQL"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
|
||||
def _invoke(self, **kwargs):
|
||||
|
||||
def convert_decimals(obj):
|
||||
@ -93,12 +94,24 @@ class ExeSQL(ToolBase, ABC):
|
||||
sql = kwargs.get("sql")
|
||||
if not sql:
|
||||
raise Exception("SQL for `ExeSQL` MUST not be empty.")
|
||||
sqls = sql.split(";")
|
||||
|
||||
vars = self.get_input_elements_from_text(sql)
|
||||
args = {}
|
||||
for k, o in vars.items():
|
||||
args[k] = o["value"]
|
||||
if not isinstance(args[k], str):
|
||||
try:
|
||||
args[k] = json.dumps(args[k], ensure_ascii=False)
|
||||
except Exception:
|
||||
args[k] = str(args[k])
|
||||
self.set_input_value(k, args[k])
|
||||
sql = self.string_format(sql, args)
|
||||
|
||||
sqls = sql.split(";")
|
||||
if self._param.db_type in ["mysql", "mariadb"]:
|
||||
db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host,
|
||||
port=self._param.port, password=self._param.password)
|
||||
elif self._param.db_type == 'postgresql':
|
||||
elif self._param.db_type == 'postgres':
|
||||
db = psycopg2.connect(dbname=self._param.database, user=self._param.username, host=self._param.host,
|
||||
port=self._param.port, password=self._param.password)
|
||||
elif self._param.db_type == 'mssql':
|
||||
@ -110,6 +123,55 @@ class ExeSQL(ToolBase, ABC):
|
||||
r'PWD=' + self._param.password
|
||||
)
|
||||
db = pyodbc.connect(conn_str)
|
||||
elif self._param.db_type == 'IBM DB2':
|
||||
import ibm_db
|
||||
conn_str = (
|
||||
f"DATABASE={self._param.database};"
|
||||
f"HOSTNAME={self._param.host};"
|
||||
f"PORT={self._param.port};"
|
||||
f"PROTOCOL=TCPIP;"
|
||||
f"UID={self._param.username};"
|
||||
f"PWD={self._param.password};"
|
||||
)
|
||||
try:
|
||||
conn = ibm_db.connect(conn_str, "", "")
|
||||
except Exception as e:
|
||||
raise Exception("Database Connection Failed! \n" + str(e))
|
||||
|
||||
sql_res = []
|
||||
formalized_content = []
|
||||
for single_sql in sqls:
|
||||
single_sql = single_sql.replace("```", "").strip()
|
||||
if not single_sql:
|
||||
continue
|
||||
single_sql = re.sub(r"\[ID:[0-9]+\]", "", single_sql)
|
||||
|
||||
stmt = ibm_db.exec_immediate(conn, single_sql)
|
||||
rows = []
|
||||
row = ibm_db.fetch_assoc(stmt)
|
||||
while row and len(rows) < self._param.max_records:
|
||||
rows.append(row)
|
||||
row = ibm_db.fetch_assoc(stmt)
|
||||
|
||||
if not rows:
|
||||
sql_res.append({"content": "No record in the database!"})
|
||||
continue
|
||||
|
||||
df = pd.DataFrame(rows)
|
||||
for col in df.columns:
|
||||
if pd.api.types.is_datetime64_any_dtype(df[col]):
|
||||
df[col] = df[col].dt.strftime("%Y-%m-%d")
|
||||
|
||||
df = df.where(pd.notnull(df), None)
|
||||
|
||||
sql_res.append(convert_decimals(df.to_dict(orient="records")))
|
||||
formalized_content.append(df.to_markdown(index=False, floatfmt=".6f"))
|
||||
|
||||
ibm_db.close(conn)
|
||||
|
||||
self.set_output("json", sql_res)
|
||||
self.set_output("formalized_content", "\n\n".join(formalized_content))
|
||||
return self.output("formalized_content")
|
||||
try:
|
||||
cursor = db.cursor()
|
||||
except Exception as e:
|
||||
@ -137,6 +199,8 @@ class ExeSQL(ToolBase, ABC):
|
||||
if pd.api.types.is_datetime64_any_dtype(single_res[col]):
|
||||
single_res[col] = single_res[col].dt.strftime('%Y-%m-%d')
|
||||
|
||||
single_res = single_res.where(pd.notnull(single_res), None)
|
||||
|
||||
sql_res.append(convert_decimals(single_res.to_dict(orient='records')))
|
||||
formalized_content.append(single_res.to_markdown(index=False, floatfmt=".6f"))
|
||||
|
||||
|
||||
@ -57,7 +57,7 @@ class GitHubParam(ToolParamBase):
|
||||
class GitHub(ToolBase, ABC):
|
||||
component_name = "GitHub"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", "")
|
||||
@ -88,4 +88,4 @@ class GitHub(ToolBase, ABC):
|
||||
assert False, self.output()
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Scanning GitHub repos related to `{}`.".format(self.get_input().get("query", "-_-!"))
|
||||
return "Scanning GitHub repos related to `{}`.".format(self.get_input().get("query", "-_-!"))
|
||||
|
||||
@ -116,7 +116,7 @@ class GoogleParam(ToolParamBase):
|
||||
class Google(ToolBase, ABC):
|
||||
component_name = "Google"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("q"):
|
||||
self.set_output("formalized_content", "")
|
||||
@ -154,6 +154,6 @@ class Google(ToolBase, ABC):
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return """
|
||||
Keywords: {}
|
||||
Keywords: {}
|
||||
Looking for the most relevant articles.
|
||||
""".format(self.get_input().get("query", "-_-!"))
|
||||
""".format(self.get_input().get("query", "-_-!"))
|
||||
|
||||
@ -63,7 +63,7 @@ class GoogleScholarParam(ToolParamBase):
|
||||
class GoogleScholar(ToolBase, ABC):
|
||||
component_name = "GoogleScholar"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", "")
|
||||
@ -93,4 +93,4 @@ class GoogleScholar(ToolBase, ABC):
|
||||
assert False, self.output()
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Looking for scholarly papers on `{}`,” prioritising reputable sources.".format(self.get_input().get("query", "-_-!"))
|
||||
return "Looking for scholarly papers on `{}`,” prioritising reputable sources.".format(self.get_input().get("query", "-_-!"))
|
||||
|
||||
@ -33,7 +33,7 @@ class PubMedParam(ToolParamBase):
|
||||
self.meta:ToolMeta = {
|
||||
"name": "pubmed_search",
|
||||
"description": """
|
||||
PubMed is an openly accessible, free database which includes primarily the MEDLINE database of references and abstracts on life sciences and biomedical topics.
|
||||
PubMed is an openly accessible, free database which includes primarily the MEDLINE database of references and abstracts on life sciences and biomedical topics.
|
||||
In addition to MEDLINE, PubMed provides access to:
|
||||
- older references from the print version of Index Medicus, back to 1951 and earlier
|
||||
- references to some journals before they were indexed in Index Medicus and MEDLINE, for instance Science, BMJ, and Annals of Surgery
|
||||
@ -69,7 +69,7 @@ In addition to MEDLINE, PubMed provides access to:
|
||||
class PubMed(ToolBase, ABC):
|
||||
component_name = "PubMed"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", "")
|
||||
@ -105,4 +105,4 @@ class PubMed(ToolBase, ABC):
|
||||
assert False, self.output()
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Looking for scholarly papers on `{}`,” prioritising reputable sources.".format(self.get_input().get("query", "-_-!"))
|
||||
return "Looking for scholarly papers on `{}`,” prioritising reputable sources.".format(self.get_input().get("query", "-_-!"))
|
||||
|
||||
@ -23,8 +23,7 @@ from api.db.services.llm_service import LLMBundle
|
||||
from api import settings
|
||||
from api.utils.api_utils import timeout
|
||||
from rag.app.tag import label_question
|
||||
from rag.prompts import kb_prompt
|
||||
from rag.prompts.prompts import cross_languages
|
||||
from rag.prompts.generator import cross_languages, kb_prompt
|
||||
|
||||
|
||||
class RetrievalParam(ToolParamBase):
|
||||
@ -75,7 +74,7 @@ class RetrievalParam(ToolParamBase):
|
||||
class Retrieval(ToolBase, ABC):
|
||||
component_name = "Retrieval"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", self._param.empty_response)
|
||||
@ -163,13 +162,20 @@ class Retrieval(ToolBase, ABC):
|
||||
self.set_output("formalized_content", self._param.empty_response)
|
||||
return
|
||||
|
||||
self._canvas.add_refernce(kbinfos["chunks"], kbinfos["doc_aggs"])
|
||||
# Format the chunks for JSON output (similar to how other tools do it)
|
||||
json_output = kbinfos["chunks"].copy()
|
||||
|
||||
self._canvas.add_reference(kbinfos["chunks"], kbinfos["doc_aggs"])
|
||||
form_cnt = "\n".join(kb_prompt(kbinfos, 200000, True))
|
||||
|
||||
# Set both formalized content and JSON output
|
||||
self.set_output("formalized_content", form_cnt)
|
||||
self.set_output("json", json_output)
|
||||
|
||||
return form_cnt
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return """
|
||||
Keywords: {}
|
||||
Keywords: {}
|
||||
Looking for the most relevant articles.
|
||||
""".format(self.get_input().get("query", "-_-!"))
|
||||
""".format(self.get_input().get("query", "-_-!"))
|
||||
|
||||
@ -77,7 +77,7 @@ class SearXNGParam(ToolParamBase):
|
||||
class SearXNG(ToolBase, ABC):
|
||||
component_name = "SearXNG"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
# Gracefully handle try-run without inputs
|
||||
query = kwargs.get("query")
|
||||
@ -94,7 +94,6 @@ class SearXNG(ToolBase, ABC):
|
||||
last_e = ""
|
||||
for _ in range(self._param.max_retries+1):
|
||||
try:
|
||||
# 构建搜索参数
|
||||
search_params = {
|
||||
'q': query,
|
||||
'format': 'json',
|
||||
@ -104,33 +103,29 @@ class SearXNG(ToolBase, ABC):
|
||||
'pageno': 1
|
||||
}
|
||||
|
||||
# 发送搜索请求
|
||||
response = requests.get(
|
||||
f"{searxng_url}/search",
|
||||
params=search_params,
|
||||
timeout=10
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
data = response.json()
|
||||
|
||||
# 验证响应数据
|
||||
|
||||
if not data or not isinstance(data, dict):
|
||||
raise ValueError("Invalid response from SearXNG")
|
||||
|
||||
|
||||
results = data.get("results", [])
|
||||
if not isinstance(results, list):
|
||||
raise ValueError("Invalid results format from SearXNG")
|
||||
|
||||
# 限制结果数量
|
||||
|
||||
results = results[:self._param.top_n]
|
||||
|
||||
# 处理搜索结果
|
||||
|
||||
self._retrieve_chunks(results,
|
||||
get_title=lambda r: r.get("title", ""),
|
||||
get_url=lambda r: r.get("url", ""),
|
||||
get_content=lambda r: r.get("content", ""))
|
||||
|
||||
|
||||
self.set_output("json", results)
|
||||
return self.output("formalized_content")
|
||||
|
||||
@ -151,6 +146,6 @@ class SearXNG(ToolBase, ABC):
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return """
|
||||
Keywords: {}
|
||||
Keywords: {}
|
||||
Searching with SearXNG for relevant results...
|
||||
""".format(self.get_input().get("query", "-_-!"))
|
||||
|
||||
@ -31,7 +31,7 @@ class TavilySearchParam(ToolParamBase):
|
||||
self.meta:ToolMeta = {
|
||||
"name": "tavily_search",
|
||||
"description": """
|
||||
Tavily is a search engine optimized for LLMs, aimed at efficient, quick and persistent search results.
|
||||
Tavily is a search engine optimized for LLMs, aimed at efficient, quick and persistent search results.
|
||||
When searching:
|
||||
- Start with specific query which should focus on just a single aspect.
|
||||
- Number of keywords in query should be less than 5.
|
||||
@ -101,7 +101,7 @@ When searching:
|
||||
class TavilySearch(ToolBase, ABC):
|
||||
component_name = "TavilySearch"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", "")
|
||||
@ -136,7 +136,7 @@ class TavilySearch(ToolBase, ABC):
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return """
|
||||
Keywords: {}
|
||||
Keywords: {}
|
||||
Looking for the most relevant articles.
|
||||
""".format(self.get_input().get("query", "-_-!"))
|
||||
|
||||
@ -199,7 +199,7 @@ class TavilyExtractParam(ToolParamBase):
|
||||
class TavilyExtract(ToolBase, ABC):
|
||||
component_name = "TavilyExtract"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
self.tavily_client = TavilyClient(api_key=self._param.api_key)
|
||||
last_e = None
|
||||
@ -224,4 +224,4 @@ class TavilyExtract(ToolBase, ABC):
|
||||
assert False, self.output()
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Opened {}—pulling out the main text…".format(self.get_input().get("urls", "-_-!"))
|
||||
return "Opened {}—pulling out the main text…".format(self.get_input().get("urls", "-_-!"))
|
||||
|
||||
@ -68,7 +68,7 @@ fund selection platform: through AI technology, is committed to providing excell
|
||||
class WenCai(ToolBase, ABC):
|
||||
component_name = "WenCai"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("report", "")
|
||||
@ -111,4 +111,4 @@ class WenCai(ToolBase, ABC):
|
||||
assert False, self.output()
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Pulling live financial data for `{}`.".format(self.get_input().get("query", "-_-!"))
|
||||
return "Pulling live financial data for `{}`.".format(self.get_input().get("query", "-_-!"))
|
||||
|
||||
@ -64,7 +64,7 @@ class WikipediaParam(ToolParamBase):
|
||||
class Wikipedia(ToolBase, ABC):
|
||||
component_name = "Wikipedia"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", "")
|
||||
@ -99,6 +99,6 @@ class Wikipedia(ToolBase, ABC):
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return """
|
||||
Keywords: {}
|
||||
Keywords: {}
|
||||
Looking for the most relevant articles.
|
||||
""".format(self.get_input().get("query", "-_-!"))
|
||||
""".format(self.get_input().get("query", "-_-!"))
|
||||
|
||||
@ -72,7 +72,7 @@ class YahooFinanceParam(ToolParamBase):
|
||||
class YahooFinance(ToolBase, ABC):
|
||||
component_name = "YahooFinance"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("stock_code"):
|
||||
self.set_output("report", "")
|
||||
@ -111,4 +111,4 @@ class YahooFinance(ToolBase, ABC):
|
||||
assert False, self.output()
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Pulling live financial data for `{}`.".format(self.get_input().get("stock_code", "-_-!"))
|
||||
return "Pulling live financial data for `{}`.".format(self.get_input().get("stock_code", "-_-!"))
|
||||
|
||||
@ -27,7 +27,8 @@ from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
from api.db import StatusEnum
|
||||
from api.db.db_models import close_connection
|
||||
from api.db.services import UserService
|
||||
from api.utils import CustomJSONEncoder, commands
|
||||
from api.utils.json import CustomJSONEncoder
|
||||
from api.utils import commands
|
||||
|
||||
from flask_mail import Mail
|
||||
from flask_session import Session
|
||||
|
||||
@ -39,7 +39,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, ge
|
||||
|
||||
from api.utils.file_utils import filename_type, thumbnail
|
||||
from rag.app.tag import label_question
|
||||
from rag.prompts import keyword_extraction
|
||||
from rag.prompts.generator import keyword_extraction
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
|
||||
@ -19,15 +19,19 @@ import re
|
||||
import sys
|
||||
from functools import partial
|
||||
|
||||
import flask
|
||||
import trio
|
||||
from flask import request, Response
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from agent.component import LLM
|
||||
from api.db import FileType
|
||||
from api import settings
|
||||
from api.db import CanvasCategory, FileType
|
||||
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
||||
from api.db.services.task_service import queue_dataflow, CANVAS_DEBUG_DOC_ID, TaskService
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
from api.settings import RetCode
|
||||
@ -35,25 +39,19 @@ from api.utils import get_uuid
|
||||
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
|
||||
from agent.canvas import Canvas
|
||||
from peewee import MySQLDatabase, PostgresqlDatabase
|
||||
from api.db.db_models import APIToken
|
||||
from api.db.db_models import APIToken, Task
|
||||
import time
|
||||
|
||||
from api.utils.file_utils import filename_type, read_potential_broken_pdf
|
||||
from rag.flow.pipeline import Pipeline
|
||||
from rag.nlp import search
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
|
||||
@manager.route('/templates', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def templates():
|
||||
return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.get_all()])
|
||||
|
||||
|
||||
@manager.route('/list', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def canvas_list():
|
||||
return get_json_result(data=sorted([c.to_dict() for c in \
|
||||
UserCanvasService.query(user_id=current_user.id)], key=lambda x: x["update_time"]*-1)
|
||||
)
|
||||
return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.query(canvas_category=CanvasCategory.Agent)])
|
||||
|
||||
|
||||
@manager.route('/rm', methods=['POST']) # noqa: F821
|
||||
@ -77,9 +75,10 @@ def save():
|
||||
if not isinstance(req["dsl"], str):
|
||||
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
|
||||
req["dsl"] = json.loads(req["dsl"])
|
||||
cate = req.get("canvas_category", CanvasCategory.Agent)
|
||||
if "id" not in req:
|
||||
req["user_id"] = current_user.id
|
||||
if UserCanvasService.query(user_id=current_user.id, title=req["title"].strip()):
|
||||
if UserCanvasService.query(user_id=current_user.id, title=req["title"].strip(), canvas_category=cate):
|
||||
return get_data_error_result(message=f"{req['title'].strip()} already exists.")
|
||||
req["id"] = get_uuid()
|
||||
if not UserCanvasService.save(**req):
|
||||
@ -91,7 +90,7 @@ def save():
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
UserCanvasService.update_by_id(req["id"], req)
|
||||
# save version
|
||||
UserCanvasVersionService.insert( user_canvas_id=req["id"], dsl=req["dsl"], title="{0}_{1}".format(req["title"], time.strftime("%Y_%m_%d_%H_%M_%S")))
|
||||
UserCanvasVersionService.insert(user_canvas_id=req["id"], dsl=req["dsl"], title="{0}_{1}".format(req["title"], time.strftime("%Y_%m_%d_%H_%M_%S")))
|
||||
UserCanvasVersionService.delete_all_versions(req["id"])
|
||||
return get_json_result(data=req)
|
||||
|
||||
@ -101,7 +100,7 @@ def save():
|
||||
def get(canvas_id):
|
||||
if not UserCanvasService.accessible(canvas_id, current_user.id):
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
e, c = UserCanvasService.get_by_tenant_id(canvas_id)
|
||||
e, c = UserCanvasService.get_by_canvas_id(canvas_id)
|
||||
return get_json_result(data=c)
|
||||
|
||||
|
||||
@ -148,6 +147,14 @@ def run():
|
||||
if not isinstance(cvs.dsl, str):
|
||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
|
||||
if cvs.canvas_category == CanvasCategory.DataFlow:
|
||||
task_id = get_uuid()
|
||||
Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
|
||||
ok, error_message = queue_dataflow(tenant_id=user_id, flow_id=req["id"], task_id=task_id, file=files[0], priority=0)
|
||||
if not ok:
|
||||
return get_data_error_result(message=error_message)
|
||||
return get_json_result(data={"message_id": task_id})
|
||||
|
||||
try:
|
||||
canvas = Canvas(cvs.dsl, current_user.id, req["id"])
|
||||
except Exception as e:
|
||||
@ -173,6 +180,44 @@ def run():
|
||||
return resp
|
||||
|
||||
|
||||
@manager.route('/rerun', methods=['POST']) # noqa: F821
|
||||
@validate_request("id", "dsl", "component_id")
|
||||
@login_required
|
||||
def rerun():
|
||||
req = request.json
|
||||
doc = PipelineOperationLogService.get_documents_info(req["id"])
|
||||
if not doc:
|
||||
return get_data_error_result(message="Document not found.")
|
||||
doc = doc[0]
|
||||
if 0 < doc["progress"] < 1:
|
||||
return get_data_error_result(message=f"`{doc['name']}` is processing...")
|
||||
|
||||
if settings.docStoreConn.indexExist(search.index_name(current_user.id), doc["kb_id"]):
|
||||
settings.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"])
|
||||
doc["progress_msg"] = ""
|
||||
doc["chunk_num"] = 0
|
||||
doc["token_num"] = 0
|
||||
DocumentService.clear_chunk_num_when_rerun(doc["id"])
|
||||
DocumentService.update_by_id(id, doc)
|
||||
TaskService.filter_delete([Task.doc_id == id])
|
||||
|
||||
dsl = req["dsl"]
|
||||
dsl["path"] = [req["component_id"]]
|
||||
PipelineOperationLogService.update_by_id(req["id"], {"dsl": dsl})
|
||||
queue_dataflow(tenant_id=current_user.id, flow_id=req["id"], task_id=get_uuid(), doc_id=doc["id"], priority=0, rerun=True)
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/cancel/<task_id>', methods=['PUT']) # noqa: F821
|
||||
@login_required
|
||||
def cancel(task_id):
|
||||
try:
|
||||
REDIS_CONN.set(f"{task_id}-cancel", "x")
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/reset', methods=['POST']) # noqa: F821
|
||||
@validate_request("id")
|
||||
@login_required
|
||||
@ -198,7 +243,7 @@ def reset():
|
||||
|
||||
@manager.route("/upload/<canvas_id>", methods=["POST"]) # noqa: F821
|
||||
def upload(canvas_id):
|
||||
e, cvs = UserCanvasService.get_by_tenant_id(canvas_id)
|
||||
e, cvs = UserCanvasService.get_by_canvas_id(canvas_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
|
||||
@ -332,7 +377,7 @@ def test_db_connect():
|
||||
if req["db_type"] in ["mysql", "mariadb"]:
|
||||
db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
|
||||
password=req["password"])
|
||||
elif req["db_type"] == 'postgresql':
|
||||
elif req["db_type"] == 'postgres':
|
||||
db = PostgresqlDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
|
||||
password=req["password"])
|
||||
elif req["db_type"] == 'mssql':
|
||||
@ -348,6 +393,22 @@ def test_db_connect():
|
||||
cursor = db.cursor()
|
||||
cursor.execute("SELECT 1")
|
||||
cursor.close()
|
||||
elif req["db_type"] == 'IBM DB2':
|
||||
import ibm_db
|
||||
conn_str = (
|
||||
f"DATABASE={req['database']};"
|
||||
f"HOSTNAME={req['host']};"
|
||||
f"PORT={req['port']};"
|
||||
f"PROTOCOL=TCPIP;"
|
||||
f"UID={req['username']};"
|
||||
f"PWD={req['password']};"
|
||||
)
|
||||
logging.info(conn_str)
|
||||
conn = ibm_db.connect(conn_str, "", "")
|
||||
stmt = ibm_db.exec_immediate(conn, "SELECT 1 FROM sysibm.sysdummy1")
|
||||
ibm_db.fetch_assoc(stmt)
|
||||
ibm_db.close(conn)
|
||||
return get_json_result(data="Database Connection Successful!")
|
||||
else:
|
||||
return server_error_response("Unsupported database type.")
|
||||
if req["db_type"] != 'mssql':
|
||||
@ -383,22 +444,32 @@ def getversion( version_id):
|
||||
return get_json_result(data=f"Error getting history file: {e}")
|
||||
|
||||
|
||||
@manager.route('/listteam', methods=['GET']) # noqa: F821
|
||||
@manager.route('/list', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def list_canvas():
|
||||
keywords = request.args.get("keywords", "")
|
||||
page_number = int(request.args.get("page", 1))
|
||||
items_per_page = int(request.args.get("page_size", 150))
|
||||
page_number = int(request.args.get("page", 0))
|
||||
items_per_page = int(request.args.get("page_size", 0))
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
desc = request.args.get("desc", True)
|
||||
try:
|
||||
canvas_category = request.args.get("canvas_category")
|
||||
if request.args.get("desc", "true").lower() == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
owner_ids = [id for id in request.args.get("owner_ids", "").strip().split(",") if id]
|
||||
if not owner_ids:
|
||||
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
|
||||
tenants = [m["tenant_id"] for m in tenants]
|
||||
tenants.append(current_user.id)
|
||||
canvas, total = UserCanvasService.get_by_tenant_ids(
|
||||
[m["tenant_id"] for m in tenants], current_user.id, page_number,
|
||||
items_per_page, orderby, desc, keywords)
|
||||
return get_json_result(data={"canvas": canvas, "total": total})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
tenants, current_user.id, page_number,
|
||||
items_per_page, orderby, desc, keywords, canvas_category)
|
||||
else:
|
||||
tenants = owner_ids
|
||||
canvas, total = UserCanvasService.get_by_tenant_ids(
|
||||
tenants, current_user.id, 0,
|
||||
0, orderby, desc, keywords, canvas_category)
|
||||
return get_json_result(data={"canvas": canvas, "total": total})
|
||||
|
||||
|
||||
@manager.route('/setting', methods=['POST']) # noqa: F821
|
||||
@ -418,12 +489,10 @@ def setting():
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
flow = flow.to_dict()
|
||||
flow["title"] = req["title"]
|
||||
if req["description"]:
|
||||
flow["description"] = req["description"]
|
||||
if req["permission"]:
|
||||
flow["permission"] = req["permission"]
|
||||
if req["avatar"]:
|
||||
flow["avatar"] = req["avatar"]
|
||||
|
||||
for key in ["description", "permission", "avatar"]:
|
||||
if value := req.get(key):
|
||||
flow[key] = value
|
||||
|
||||
num= UserCanvasService.update_by_id(req["id"], flow)
|
||||
return get_json_result(data=num)
|
||||
@ -472,3 +541,24 @@ def sessions(canvas_id):
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/prompts', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def prompts():
|
||||
from rag.prompts.generator import ANALYZE_TASK_SYSTEM, ANALYZE_TASK_USER, NEXT_STEP, REFLECT, CITATION_PROMPT_TEMPLATE
|
||||
return get_json_result(data={
|
||||
"task_analysis": ANALYZE_TASK_SYSTEM +"\n\n"+ ANALYZE_TASK_USER,
|
||||
"plan_generation": NEXT_STEP,
|
||||
"reflection": REFLECT,
|
||||
#"context_summary": SUMMARY4MEMORY,
|
||||
#"context_ranking": RANK_MEMORY,
|
||||
"citation_guidelines": CITATION_PROMPT_TEMPLATE
|
||||
})
|
||||
|
||||
|
||||
@manager.route('/download', methods=['GET']) # noqa: F821
|
||||
def download():
|
||||
id = request.args.get("id")
|
||||
created_by = request.args.get("created_by")
|
||||
blob = FileService.get_blob(created_by, id)
|
||||
return flask.make_response(blob)
|
||||
@ -33,8 +33,7 @@ from api.utils.api_utils import get_data_error_result, get_json_result, server_e
|
||||
from rag.app.qa import beAdoc, rmPrefix
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
from rag.prompts import cross_languages, keyword_extraction
|
||||
from rag.prompts.prompts import gen_meta_filter
|
||||
from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extraction
|
||||
from rag.settings import PAGERANK_FLD
|
||||
from rag.utils import rmSpace
|
||||
|
||||
@ -291,6 +290,10 @@ def retrieval_test():
|
||||
kb_ids = req["kb_id"]
|
||||
if isinstance(kb_ids, str):
|
||||
kb_ids = [kb_ids]
|
||||
if not kb_ids:
|
||||
return get_json_result(data=False, message='Please specify dataset firstly.',
|
||||
code=settings.RetCode.DATA_ERROR)
|
||||
|
||||
doc_ids = req.get("doc_ids", [])
|
||||
use_kg = req.get("use_kg", False)
|
||||
top = int(req.get("top_k", 1024))
|
||||
|
||||
@ -15,7 +15,7 @@
|
||||
#
|
||||
import json
|
||||
import re
|
||||
import traceback
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from flask import Response, request
|
||||
from flask_login import current_user, login_required
|
||||
@ -29,8 +29,8 @@ from api.db.services.search_service import SearchService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
|
||||
from rag.prompts.prompt_template import load_prompt
|
||||
from rag.prompts.prompts import chunks_format
|
||||
from rag.prompts.template import load_prompt
|
||||
from rag.prompts.generator import chunks_format
|
||||
|
||||
|
||||
@manager.route("/set", methods=["POST"]) # noqa: F821
|
||||
@ -226,7 +226,7 @@ def completion():
|
||||
if not is_embedded:
|
||||
ConversationService.update_by_id(conv.id, conv.to_dict())
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
logging.exception(e)
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
@ -400,6 +400,8 @@ def related_questions():
|
||||
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, chat_id)
|
||||
|
||||
gen_conf = search_config.get("llm_setting", {"temperature": 0.9})
|
||||
if "parameter" in gen_conf:
|
||||
del gen_conf["parameter"]
|
||||
prompt = load_prompt("related_question")
|
||||
ans = chat_mdl.chat(
|
||||
prompt,
|
||||
|
||||
@ -32,7 +32,7 @@ from api.db.services.document_service import DocumentService, doc_upload_and_par
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.task_service import TaskService, cancel_all_task_of, queue_tasks
|
||||
from api.db.services.task_service import TaskService, cancel_all_task_of, queue_tasks, queue_dataflow
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import (
|
||||
@ -182,6 +182,7 @@ def create():
|
||||
"id": get_uuid(),
|
||||
"kb_id": kb.id,
|
||||
"parser_id": kb.parser_id,
|
||||
"pipeline_id": kb.pipeline_id,
|
||||
"parser_config": kb.parser_config,
|
||||
"created_by": current_user.id,
|
||||
"type": FileType.VIRTUAL,
|
||||
@ -456,8 +457,7 @@ def run():
|
||||
cancel_all_task_of(id)
|
||||
else:
|
||||
return get_data_error_result(message="Cannot cancel a task that is not in RUNNING status")
|
||||
|
||||
if str(req["run"]) == TaskStatus.RUNNING.value and str(doc.run) == TaskStatus.DONE.value:
|
||||
if all([("delete" not in req or req["delete"]), str(req["run"]) == TaskStatus.RUNNING.value, str(doc.run) == TaskStatus.DONE.value]):
|
||||
DocumentService.clear_chunk_num_when_rerun(doc.id)
|
||||
|
||||
DocumentService.update_by_id(id, info)
|
||||
@ -480,8 +480,11 @@ def run():
|
||||
kb_table_num_map[kb_id] = count
|
||||
if kb_table_num_map[kb_id] <= 0:
|
||||
KnowledgebaseService.delete_field_map(kb_id)
|
||||
bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
|
||||
queue_tasks(doc, bucket, name, 0)
|
||||
if doc.get("pipeline_id", ""):
|
||||
queue_dataflow(tenant_id, flow_id=doc["pipeline_id"], task_id=get_uuid(), doc_id=id)
|
||||
else:
|
||||
bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
|
||||
queue_tasks(doc, bucket, name, 0)
|
||||
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
@ -547,31 +550,22 @@ def get(doc_id):
|
||||
|
||||
@manager.route("/change_parser", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id", "parser_id")
|
||||
@validate_request("doc_id")
|
||||
def change_parser():
|
||||
req = request.json
|
||||
|
||||
if not DocumentService.accessible(req["doc_id"], current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
if doc.parser_id.lower() == req["parser_id"].lower():
|
||||
if "parser_config" in req:
|
||||
if req["parser_config"] == doc.parser_config:
|
||||
return get_json_result(data=True)
|
||||
else:
|
||||
return get_json_result(data=True)
|
||||
|
||||
if (doc.type == FileType.VISUAL and req["parser_id"] != "picture") or (re.search(r"\.(ppt|pptx|pages)$", doc.name) and req["parser_id"] != "presentation"):
|
||||
return get_data_error_result(message="Not supported yet!")
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
|
||||
def reset_doc():
|
||||
nonlocal doc
|
||||
e = DocumentService.update_by_id(doc.id, {"parser_id": req["parser_id"], "progress": 0, "progress_msg": "", "run": TaskStatus.UNSTART.value})
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
if "parser_config" in req:
|
||||
DocumentService.update_parser_config(doc.id, req["parser_config"])
|
||||
if doc.token_num > 0:
|
||||
e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, doc.process_duration * -1)
|
||||
if not e:
|
||||
@ -582,6 +576,26 @@ def change_parser():
|
||||
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
||||
|
||||
try:
|
||||
if "pipeline_id" in req:
|
||||
if doc.pipeline_id == req["pipeline_id"]:
|
||||
return get_json_result(data=True)
|
||||
DocumentService.update_by_id(doc.id, {"pipeline_id": req["pipeline_id"]})
|
||||
reset_doc()
|
||||
return get_json_result(data=True)
|
||||
|
||||
if doc.parser_id.lower() == req["parser_id"].lower():
|
||||
if "parser_config" in req:
|
||||
if req["parser_config"] == doc.parser_config:
|
||||
return get_json_result(data=True)
|
||||
else:
|
||||
return get_json_result(data=True)
|
||||
|
||||
if (doc.type == FileType.VISUAL and req["parser_id"] != "picture") or (re.search(r"\.(ppt|pptx|pages)$", doc.name) and req["parser_id"] != "presentation"):
|
||||
return get_data_error_result(message="Not supported yet!")
|
||||
if "parser_config" in req:
|
||||
DocumentService.update_parser_config(doc.id, req["parser_config"])
|
||||
reset_doc()
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -683,7 +697,7 @@ def set_meta():
|
||||
meta = json.loads(req["meta"])
|
||||
if not isinstance(meta, dict):
|
||||
return get_json_result(data=False, message="Only dictionary type supported.", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
for k,v in meta.items():
|
||||
for k, v in meta.items():
|
||||
if not isinstance(v, str) and not isinstance(v, int) and not isinstance(v, float):
|
||||
return get_json_result(data=False, message=f"The type is not supported: {v}", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
except Exception as e:
|
||||
|
||||
@ -246,6 +246,8 @@ def rm():
|
||||
return get_data_error_result(message="File or Folder not found!")
|
||||
if not file.tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
if file.tenant_id != current_user.id:
|
||||
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
if file.source_type == FileSource.KNOWLEDGEBASE:
|
||||
continue
|
||||
|
||||
@ -292,6 +294,8 @@ def rename():
|
||||
e, file = FileService.get_by_id(req["file_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="File not found!")
|
||||
if file.tenant_id != current_user.id:
|
||||
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
if file.type != FileType.FOLDER.value \
|
||||
and pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
|
||||
file.name.lower()).suffix:
|
||||
@ -328,6 +332,8 @@ def get(file_id):
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
if file.tenant_id != current_user.id:
|
||||
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
blob = STORAGE_IMPL.get(file.parent_id, file.location)
|
||||
if not blob:
|
||||
@ -367,6 +373,8 @@ def move():
|
||||
return get_data_error_result(message="File or Folder not found!")
|
||||
if not file.tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
if file.tenant_id != current_user.id:
|
||||
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
fe, _ = FileService.get_by_id(parent_id)
|
||||
if not fe:
|
||||
return get_data_error_result(message="Parent Folder not found!")
|
||||
|
||||
@ -14,18 +14,21 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
||||
from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request, not_allowed_parameters
|
||||
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters
|
||||
from api.utils import get_uuid
|
||||
from api.db import StatusEnum, FileSource
|
||||
from api.db import PipelineTaskType, StatusEnum, FileSource, VALID_FILE_TYPES, VALID_TASK_STATUS
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.db_models import File
|
||||
from api.utils.api_utils import get_json_result
|
||||
@ -35,7 +38,6 @@ from api.constants import DATASET_NAME_LIMIT
|
||||
from rag.settings import PAGERANK_FLD
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
|
||||
|
||||
@manager.route('/create', methods=['post']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name")
|
||||
@ -61,10 +63,39 @@ def create():
|
||||
req["name"] = dataset_name
|
||||
req["tenant_id"] = current_user.id
|
||||
req["created_by"] = current_user.id
|
||||
if not req.get("parser_id"):
|
||||
req["parser_id"] = "naive"
|
||||
e, t = TenantService.get_by_id(current_user.id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Tenant not found.")
|
||||
req["embd_id"] = t.embd_id
|
||||
req["parser_config"] = {
|
||||
"layout_recognize": "DeepDOC",
|
||||
"chunk_token_num": 512,
|
||||
"delimiter": "\n",
|
||||
"auto_keywords": 0,
|
||||
"auto_questions": 0,
|
||||
"html4excel": False,
|
||||
"topn_tags": 3,
|
||||
"raptor": {
|
||||
"use_raptor": True,
|
||||
"prompt": "Please summarize the following paragraphs. Be careful with the numbers, do not make things up. Paragraphs as following:\n {cluster_content}\nThe above is the content you need to summarize.",
|
||||
"max_token": 256,
|
||||
"threshold": 0.1,
|
||||
"max_cluster": 64,
|
||||
"random_seed": 0
|
||||
},
|
||||
"graphrag": {
|
||||
"use_graphrag": True,
|
||||
"entity_types": [
|
||||
"organization",
|
||||
"person",
|
||||
"geo",
|
||||
"event",
|
||||
"category"
|
||||
],
|
||||
"method": "light"
|
||||
}
|
||||
}
|
||||
if not KnowledgebaseService.save(**req):
|
||||
return get_data_error_result()
|
||||
return get_json_result(data={"kb_id": req["id"]})
|
||||
@ -379,3 +410,368 @@ def get_meta():
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
return get_json_result(data=DocumentService.get_meta_by_kbs(kb_ids))
|
||||
|
||||
|
||||
@manager.route("/basic_info", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def get_basic_info():
|
||||
kb_id = request.args.get("kb_id", "")
|
||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
|
||||
basic_info = DocumentService.knowledgebase_basic_info(kb_id)
|
||||
|
||||
return get_json_result(data=basic_info)
|
||||
|
||||
|
||||
@manager.route("/list_pipeline_logs", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def list_pipeline_logs():
|
||||
kb_id = request.args.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
keywords = request.args.get("keywords", "")
|
||||
|
||||
page_number = int(request.args.get("page", 0))
|
||||
items_per_page = int(request.args.get("page_size", 0))
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
if request.args.get("desc", "true").lower() == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
create_date_from = request.args.get("create_date_from", "")
|
||||
create_date_to = request.args.get("create_date_to", "")
|
||||
if create_date_to > create_date_from:
|
||||
return get_data_error_result(message="Create data filter is abnormal.")
|
||||
|
||||
req = request.get_json()
|
||||
|
||||
operation_status = req.get("operation_status", [])
|
||||
if operation_status:
|
||||
invalid_status = {s for s in operation_status if s not in VALID_TASK_STATUS}
|
||||
if invalid_status:
|
||||
return get_data_error_result(message=f"Invalid filter operation_status status conditions: {', '.join(invalid_status)}")
|
||||
|
||||
types = req.get("types", [])
|
||||
if types:
|
||||
invalid_types = {t for t in types if t not in VALID_FILE_TYPES}
|
||||
if invalid_types:
|
||||
return get_data_error_result(message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}")
|
||||
|
||||
suffix = req.get("suffix", [])
|
||||
|
||||
try:
|
||||
logs, tol = PipelineOperationLogService.get_file_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix, create_date_from, create_date_to)
|
||||
return get_json_result(data={"total": tol, "logs": logs})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/list_pipeline_dataset_logs", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def list_pipeline_dataset_logs():
|
||||
kb_id = request.args.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
page_number = int(request.args.get("page", 0))
|
||||
items_per_page = int(request.args.get("page_size", 0))
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
if request.args.get("desc", "true").lower() == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
create_date_from = request.args.get("create_date_from", "")
|
||||
create_date_to = request.args.get("create_date_to", "")
|
||||
if create_date_to > create_date_from:
|
||||
return get_data_error_result(message="Create data filter is abnormal.")
|
||||
|
||||
req = request.get_json()
|
||||
|
||||
operation_status = req.get("operation_status", [])
|
||||
if operation_status:
|
||||
invalid_status = {s for s in operation_status if s not in VALID_TASK_STATUS}
|
||||
if invalid_status:
|
||||
return get_data_error_result(message=f"Invalid filter operation_status status conditions: {', '.join(invalid_status)}")
|
||||
|
||||
try:
|
||||
logs, tol = PipelineOperationLogService.get_dataset_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, operation_status, create_date_from, create_date_to)
|
||||
return get_json_result(data={"total": tol, "logs": logs})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/delete_pipeline_logs", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def delete_pipeline_logs():
|
||||
kb_id = request.args.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
req = request.get_json()
|
||||
log_ids = req.get("log_ids", [])
|
||||
|
||||
PipelineOperationLogService.delete_by_ids(log_ids)
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route("/pipeline_log_detail", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def pipeline_log_detail():
|
||||
log_id = request.args.get("log_id")
|
||||
if not log_id:
|
||||
return get_json_result(data=False, message='Lack of "Pipeline log ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
ok, log = PipelineOperationLogService.get_by_id(log_id)
|
||||
if not ok:
|
||||
return get_data_error_result(message="Invalid pipeline log ID")
|
||||
|
||||
return get_json_result(data=log.to_dict())
|
||||
|
||||
|
||||
@manager.route("/run_graphrag", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def run_graphrag():
|
||||
req = request.json
|
||||
|
||||
kb_id = req.get("kb_id", "")
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||
|
||||
task_id = kb.graphrag_task_id
|
||||
if task_id:
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
logging.warning(f"A valid GraphRAG task id is expected for kb {kb_id}")
|
||||
|
||||
if task and task.progress not in [-1, 1]:
|
||||
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Graph Task is already running.")
|
||||
|
||||
documents, _ = DocumentService.get_by_kb_id(
|
||||
kb_id=kb_id,
|
||||
page_number=0,
|
||||
items_per_page=0,
|
||||
orderby="create_time",
|
||||
desc=False,
|
||||
keywords="",
|
||||
run_status=[],
|
||||
types=[],
|
||||
suffix=[],
|
||||
)
|
||||
if not documents:
|
||||
return get_error_data_result(message=f"No documents in Knowledgebase {kb_id}")
|
||||
|
||||
sample_document = documents[0]
|
||||
document_ids = [document["id"] for document in documents]
|
||||
|
||||
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
|
||||
|
||||
if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}):
|
||||
logging.warning(f"Cannot save graphrag_task_id for kb {kb_id}")
|
||||
|
||||
return get_json_result(data={"graphrag_task_id": task_id})
|
||||
|
||||
|
||||
@manager.route("/trace_graphrag", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def trace_graphrag():
|
||||
kb_id = request.args.get("kb_id", "")
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||
|
||||
task_id = kb.graphrag_task_id
|
||||
if not task_id:
|
||||
return get_json_result(data={})
|
||||
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="GraphRAG Task Not Found or Error Occurred")
|
||||
|
||||
return get_json_result(data=task.to_dict())
|
||||
|
||||
|
||||
@manager.route("/run_raptor", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def run_raptor():
|
||||
req = request.json
|
||||
|
||||
kb_id = req.get("kb_id", "")
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||
|
||||
task_id = kb.raptor_task_id
|
||||
if task_id:
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
logging.warning(f"A valid RAPTOR task id is expected for kb {kb_id}")
|
||||
|
||||
if task and task.progress not in [-1, 1]:
|
||||
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A RAPTOR Task is already running.")
|
||||
|
||||
documents, _ = DocumentService.get_by_kb_id(
|
||||
kb_id=kb_id,
|
||||
page_number=0,
|
||||
items_per_page=0,
|
||||
orderby="create_time",
|
||||
desc=False,
|
||||
keywords="",
|
||||
run_status=[],
|
||||
types=[],
|
||||
suffix=[],
|
||||
)
|
||||
if not documents:
|
||||
return get_error_data_result(message=f"No documents in Knowledgebase {kb_id}")
|
||||
|
||||
sample_document = documents[0]
|
||||
document_ids = [document["id"] for document in documents]
|
||||
|
||||
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
|
||||
|
||||
if not KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": task_id}):
|
||||
logging.warning(f"Cannot save raptor_task_id for kb {kb_id}")
|
||||
|
||||
return get_json_result(data={"raptor_task_id": task_id})
|
||||
|
||||
|
||||
@manager.route("/trace_raptor", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def trace_raptor():
|
||||
kb_id = request.args.get("kb_id", "")
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||
|
||||
task_id = kb.raptor_task_id
|
||||
if not task_id:
|
||||
return get_json_result(data={})
|
||||
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="RAPTOR Task Not Found or Error Occurred")
|
||||
|
||||
return get_json_result(data=task.to_dict())
|
||||
|
||||
|
||||
@manager.route("/run_mindmap", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def run_mindmap():
|
||||
req = request.json
|
||||
|
||||
kb_id = req.get("kb_id", "")
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||
|
||||
task_id = kb.mindmap_task_id
|
||||
if task_id:
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
logging.warning(f"A valid Mindmap task id is expected for kb {kb_id}")
|
||||
|
||||
if task and task.progress not in [-1, 1]:
|
||||
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Mindmap Task is already running.")
|
||||
|
||||
documents, _ = DocumentService.get_by_kb_id(
|
||||
kb_id=kb_id,
|
||||
page_number=0,
|
||||
items_per_page=0,
|
||||
orderby="create_time",
|
||||
desc=False,
|
||||
keywords="",
|
||||
run_status=[],
|
||||
types=[],
|
||||
suffix=[],
|
||||
)
|
||||
if not documents:
|
||||
return get_error_data_result(message=f"No documents in Knowledgebase {kb_id}")
|
||||
|
||||
sample_document = documents[0]
|
||||
document_ids = [document["id"] for document in documents]
|
||||
|
||||
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="mindmap", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
|
||||
|
||||
if not KnowledgebaseService.update_by_id(kb.id, {"mindmap_task_id": task_id}):
|
||||
logging.warning(f"Cannot save mindmap_task_id for kb {kb_id}")
|
||||
|
||||
return get_json_result(data={"mindmap_task_id": task_id})
|
||||
|
||||
|
||||
@manager.route("/trace_mindmap", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def trace_mindmap():
|
||||
kb_id = request.args.get("kb_id", "")
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||
|
||||
task_id = kb.mindmap_task_id
|
||||
if not task_id:
|
||||
return get_json_result(data={})
|
||||
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Mindmap Task Not Found or Error Occurred")
|
||||
|
||||
return get_json_result(data=task.to_dict())
|
||||
|
||||
|
||||
@manager.route("/unbind_task", methods=["DELETE"]) # noqa: F821
|
||||
@login_required
|
||||
def delete_kb_task():
|
||||
kb_id = request.args.get("kb_id", "")
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not ok:
|
||||
return get_json_result(data=True)
|
||||
|
||||
pipeline_task_type = request.args.get("pipeline_task_type", "")
|
||||
if not pipeline_task_type or pipeline_task_type not in [PipelineTaskType.GRAPH_RAG, PipelineTaskType.RAPTOR, PipelineTaskType.MINDMAP]:
|
||||
return get_error_data_result(message="Invalid task type")
|
||||
|
||||
match pipeline_task_type:
|
||||
case PipelineTaskType.GRAPH_RAG:
|
||||
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
|
||||
kb_task_id = "graphrag_task_id"
|
||||
kb_task_finish_at = "graphrag_task_finish_at"
|
||||
case PipelineTaskType.RAPTOR:
|
||||
kb_task_id = "raptor_task_id"
|
||||
kb_task_finish_at = "raptor_task_finish_at"
|
||||
case PipelineTaskType.MINDMAP:
|
||||
kb_task_id = "mindmap_task_id"
|
||||
kb_task_finish_at = "mindmap_task_finish_at"
|
||||
case _:
|
||||
return get_error_data_result(message="Internal Error: Invalid task type")
|
||||
|
||||
ok = KnowledgebaseService.update_by_id(kb_id, {kb_task_id: "", kb_task_finish_at: None})
|
||||
if not ok:
|
||||
return server_error_response(f"Internal error: cannot delete task {pipeline_task_type}")
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
@ -24,7 +24,7 @@ from api.db.services.llm_service import LLMBundle
|
||||
from api import settings
|
||||
from api.utils.api_utils import validate_request, build_error_result, apikey_required
|
||||
from rag.app.tag import label_question
|
||||
from api.db.services.dialog_service import meta_filter
|
||||
from api.db.services.dialog_service import meta_filter, convert_conditions
|
||||
|
||||
|
||||
@manager.route('/dify/retrieval', methods=['POST']) # noqa: F821
|
||||
@ -101,19 +101,4 @@ def retrieval(tenant_id):
|
||||
logging.exception(e)
|
||||
return build_error_result(message=str(e), code=settings.RetCode.SERVER_ERROR)
|
||||
|
||||
def convert_conditions(metadata_condition):
|
||||
if metadata_condition is None:
|
||||
metadata_condition = {}
|
||||
op_mapping = {
|
||||
"is": "=",
|
||||
"not is": "≠"
|
||||
}
|
||||
return [
|
||||
{
|
||||
"op": op_mapping.get(cond["comparison_operator"], cond["comparison_operator"]),
|
||||
"key": cond["name"],
|
||||
"value": cond["value"]
|
||||
}
|
||||
for cond in metadata_condition.get("conditions", [])
|
||||
]
|
||||
|
||||
|
||||
@ -35,11 +35,12 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.task_service import TaskService, queue_tasks
|
||||
from api.db.services.dialog_service import meta_filter, convert_conditions
|
||||
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required
|
||||
from rag.app.qa import beAdoc, rmPrefix
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
from rag.prompts import cross_languages, keyword_extraction
|
||||
from rag.prompts.generator import cross_languages, keyword_extraction
|
||||
from rag.utils import rmSpace
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
|
||||
@ -1350,6 +1351,9 @@ def retrieval_test(tenant_id):
|
||||
highlight:
|
||||
type: boolean
|
||||
description: Whether to highlight matched content.
|
||||
metadata_condition:
|
||||
type: object
|
||||
description: metadata filter condition.
|
||||
- in: header
|
||||
name: Authorization
|
||||
type: string
|
||||
@ -1413,6 +1417,10 @@ def retrieval_test(tenant_id):
|
||||
for doc_id in doc_ids:
|
||||
if doc_id not in doc_ids_list:
|
||||
return get_error_data_result(f"The datasets don't own the document {doc_id}")
|
||||
if not doc_ids:
|
||||
metadata_condition = req.get("metadata_condition", {})
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
doc_ids = meta_filter(metas, convert_conditions(metadata_condition))
|
||||
similarity_threshold = float(req.get("similarity_threshold", 0.2))
|
||||
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
|
||||
top = int(req.get("top_k", 1024))
|
||||
|
||||
@ -3,9 +3,11 @@ import re
|
||||
|
||||
import flask
|
||||
from flask import request
|
||||
from pathlib import Path
|
||||
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.utils.api_utils import server_error_response, token_required
|
||||
from api.utils import get_uuid
|
||||
from api.db import FileType
|
||||
@ -81,16 +83,16 @@ def upload(tenant_id):
|
||||
return get_json_result(data=False, message="Can't find this folder!", code=404)
|
||||
|
||||
for file_obj in file_objs:
|
||||
# 文件路径处理
|
||||
# Handle file path
|
||||
full_path = '/' + file_obj.filename
|
||||
file_obj_names = full_path.split('/')
|
||||
file_len = len(file_obj_names)
|
||||
|
||||
# 获取文件夹路径ID
|
||||
# Get folder path ID
|
||||
file_id_list = FileService.get_id_list_by_id(pf_id, file_obj_names, 1, [pf_id])
|
||||
len_id_list = len(file_id_list)
|
||||
|
||||
# 创建文件夹结构
|
||||
# Crete file folder
|
||||
if file_len != len_id_list:
|
||||
e, file = FileService.get_by_id(file_id_list[len_id_list - 1])
|
||||
if not e:
|
||||
@ -666,3 +668,71 @@ def move(tenant_id):
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@manager.route('/file/convert', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def convert(tenant_id):
|
||||
req = request.json
|
||||
kb_ids = req["kb_ids"]
|
||||
file_ids = req["file_ids"]
|
||||
file2documents = []
|
||||
|
||||
try:
|
||||
files = FileService.get_by_ids(file_ids)
|
||||
files_set = dict({file.id: file for file in files})
|
||||
for file_id in file_ids:
|
||||
file = files_set[file_id]
|
||||
if not file:
|
||||
return get_json_result(message="File not found!", code=404)
|
||||
file_ids_list = [file_id]
|
||||
if file.type == FileType.FOLDER.value:
|
||||
file_ids_list = FileService.get_all_innermost_file_ids(file_id, [])
|
||||
for id in file_ids_list:
|
||||
informs = File2DocumentService.get_by_file_id(id)
|
||||
# delete
|
||||
for inform in informs:
|
||||
doc_id = inform.document_id
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if not e:
|
||||
return get_json_result(message="Document not found!", code=404)
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if not tenant_id:
|
||||
return get_json_result(message="Tenant not found!", code=404)
|
||||
if not DocumentService.remove_document(doc, tenant_id):
|
||||
return get_json_result(
|
||||
message="Database error (Document removal)!", code=404)
|
||||
File2DocumentService.delete_by_file_id(id)
|
||||
|
||||
# insert
|
||||
for kb_id in kb_ids:
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not e:
|
||||
return get_json_result(
|
||||
message="Can't find this knowledgebase!", code=404)
|
||||
e, file = FileService.get_by_id(id)
|
||||
if not e:
|
||||
return get_json_result(
|
||||
message="Can't find this file!", code=404)
|
||||
|
||||
doc = DocumentService.insert({
|
||||
"id": get_uuid(),
|
||||
"kb_id": kb.id,
|
||||
"parser_id": FileService.get_parser(file.type, file.name, kb.parser_id),
|
||||
"parser_config": kb.parser_config,
|
||||
"created_by": tenant_id,
|
||||
"type": file.type,
|
||||
"name": file.name,
|
||||
"suffix": Path(file.name).suffix.lstrip("."),
|
||||
"location": file.location,
|
||||
"size": file.size
|
||||
})
|
||||
file2document = File2DocumentService.insert({
|
||||
"id": get_uuid(),
|
||||
"file_id": id,
|
||||
"document_id": doc.id,
|
||||
})
|
||||
|
||||
file2documents.append(file2document.to_json())
|
||||
return get_json_result(data=file2documents)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -38,9 +38,8 @@ from api.db.services.user_service import UserTenantService
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, get_result, server_error_response, token_required, validate_request
|
||||
from rag.app.tag import label_question
|
||||
from rag.prompts import chunks_format
|
||||
from rag.prompts.prompt_template import load_prompt
|
||||
from rag.prompts.prompts import cross_languages, gen_meta_filter, keyword_extraction
|
||||
from rag.prompts.template import load_prompt
|
||||
from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format
|
||||
|
||||
|
||||
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
|
||||
@ -414,7 +413,7 @@ def agents_completion_openai_compatibility(tenant_id, agent_id):
|
||||
tenant_id,
|
||||
agent_id,
|
||||
question,
|
||||
session_id=req.get("session_id", req.get("id", "") or req.get("metadata", {}).get("id", "")),
|
||||
session_id=req.pop("session_id", req.get("id", "")) or req.get("metadata", {}).get("id", ""),
|
||||
stream=True,
|
||||
**req,
|
||||
),
|
||||
@ -432,7 +431,7 @@ def agents_completion_openai_compatibility(tenant_id, agent_id):
|
||||
tenant_id,
|
||||
agent_id,
|
||||
question,
|
||||
session_id=req.get("session_id", req.get("id", "") or req.get("metadata", {}).get("id", "")),
|
||||
session_id=req.pop("session_id", req.get("id", "")) or req.get("metadata", {}).get("id", ""),
|
||||
stream=False,
|
||||
**req,
|
||||
)
|
||||
@ -941,6 +940,9 @@ def retrieval_test_embedded():
|
||||
kb_ids = req["kb_id"]
|
||||
if isinstance(kb_ids, str):
|
||||
kb_ids = [kb_ids]
|
||||
if not kb_ids:
|
||||
return get_json_result(data=False, message='Please specify dataset firstly.',
|
||||
code=settings.RetCode.DATA_ERROR)
|
||||
doc_ids = req.get("doc_ids", [])
|
||||
similarity_threshold = float(req.get("similarity_threshold", 0.0))
|
||||
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
|
||||
|
||||
@ -36,6 +36,9 @@ from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE
|
||||
from timeit import default_timer as timer
|
||||
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from flask import jsonify
|
||||
from api.utils.health_utils import run_health_checks
|
||||
|
||||
|
||||
@manager.route("/version", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
@ -169,6 +172,12 @@ def status():
|
||||
return get_json_result(data=res)
|
||||
|
||||
|
||||
@manager.route("/healthz", methods=["GET"]) # noqa: F821
|
||||
def healthz():
|
||||
result, all_ok = run_health_checks()
|
||||
return jsonify(result), (200 if all_ok else 500)
|
||||
|
||||
|
||||
@manager.route("/new_token", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def new_token():
|
||||
|
||||
@ -34,7 +34,6 @@ from api.db.services.user_service import TenantService, UserService, UserTenantS
|
||||
from api.utils import (
|
||||
current_timestamp,
|
||||
datetime_format,
|
||||
decrypt,
|
||||
download_img,
|
||||
get_format_time,
|
||||
get_uuid,
|
||||
@ -46,6 +45,7 @@ from api.utils.api_utils import (
|
||||
server_error_response,
|
||||
validate_request,
|
||||
)
|
||||
from api.utils.crypt import decrypt
|
||||
|
||||
|
||||
@manager.route("/login", methods=["POST", "GET"]) # noqa: F821
|
||||
@ -98,7 +98,14 @@ def login():
|
||||
return get_json_result(data=False, code=settings.RetCode.SERVER_ERROR, message="Fail to crypt password")
|
||||
|
||||
user = UserService.query_user(email, password)
|
||||
if user:
|
||||
|
||||
if user and hasattr(user, 'is_active') and user.is_active == "0":
|
||||
return get_json_result(
|
||||
data=False,
|
||||
code=settings.RetCode.FORBIDDEN,
|
||||
message="This account has been disabled, please contact the administrator!",
|
||||
)
|
||||
elif user:
|
||||
response_data = user.to_json()
|
||||
user.access_token = get_uuid()
|
||||
login_user(user)
|
||||
@ -227,6 +234,9 @@ def oauth_callback(channel):
|
||||
# User exists, try to log in
|
||||
user = users[0]
|
||||
user.access_token = get_uuid()
|
||||
if user and hasattr(user, 'is_active') and user.is_active == "0":
|
||||
return redirect("/?error=user_inactive")
|
||||
|
||||
login_user(user)
|
||||
user.save()
|
||||
return redirect(f"/?auth={user.get_id()}")
|
||||
@ -317,6 +327,8 @@ def github_callback():
|
||||
# User has already registered, try to log in
|
||||
user = users[0]
|
||||
user.access_token = get_uuid()
|
||||
if user and hasattr(user, 'is_active') and user.is_active == "0":
|
||||
return redirect("/?error=user_inactive")
|
||||
login_user(user)
|
||||
user.save()
|
||||
return redirect("/?auth=%s" % user.get_id())
|
||||
@ -418,6 +430,8 @@ def feishu_callback():
|
||||
|
||||
# User has already registered, try to log in
|
||||
user = users[0]
|
||||
if user and hasattr(user, 'is_active') and user.is_active == "0":
|
||||
return redirect("/?error=user_inactive")
|
||||
user.access_token = get_uuid()
|
||||
login_user(user)
|
||||
user.save()
|
||||
|
||||
2
api/common/README.md
Normal file
2
api/common/README.md
Normal file
@ -0,0 +1,2 @@
|
||||
The python files in this directory are shared between service. They contain common utilities, models, and functions that can be used across various
|
||||
services to ensure consistency and reduce code duplication.
|
||||
21
api/common/base64.py
Normal file
21
api/common/base64.py
Normal file
@ -0,0 +1,21 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import base64
|
||||
|
||||
def encode_to_base64(input_string):
|
||||
base64_encoded = base64.b64encode(input_string.encode('utf-8'))
|
||||
return base64_encoded.decode('utf-8')
|
||||
@ -23,6 +23,11 @@ class StatusEnum(Enum):
|
||||
INVALID = "0"
|
||||
|
||||
|
||||
class ActiveEnum(Enum):
|
||||
ACTIVE = "1"
|
||||
INACTIVE = "0"
|
||||
|
||||
|
||||
class UserTenantRole(StrEnum):
|
||||
OWNER = 'owner'
|
||||
ADMIN = 'admin'
|
||||
@ -74,8 +79,10 @@ class TaskStatus(StrEnum):
|
||||
DONE = "3"
|
||||
FAIL = "4"
|
||||
|
||||
|
||||
VALID_TASK_STATUS = {TaskStatus.UNSTART, TaskStatus.RUNNING, TaskStatus.CANCEL, TaskStatus.DONE, TaskStatus.FAIL}
|
||||
|
||||
|
||||
class ParserType(StrEnum):
|
||||
PRESENTATION = "presentation"
|
||||
LAWS = "laws"
|
||||
@ -105,10 +112,30 @@ class CanvasType(StrEnum):
|
||||
DocBot = "docbot"
|
||||
|
||||
|
||||
class CanvasCategory(StrEnum):
|
||||
Agent = "agent_canvas"
|
||||
DataFlow = "dataflow_canvas"
|
||||
|
||||
VALID_CANVAS_CATEGORIES = {CanvasCategory.Agent, CanvasCategory.DataFlow}
|
||||
|
||||
|
||||
class MCPServerType(StrEnum):
|
||||
SSE = "sse"
|
||||
STREAMABLE_HTTP = "streamable-http"
|
||||
|
||||
|
||||
VALID_MCP_SERVER_TYPES = {MCPServerType.SSE, MCPServerType.STREAMABLE_HTTP}
|
||||
|
||||
|
||||
class PipelineTaskType(StrEnum):
|
||||
PARSE = "Parse"
|
||||
DOWNLOAD = "Download"
|
||||
RAPTOR = "RAPTOR"
|
||||
GRAPH_RAG = "GraphRAG"
|
||||
MINDMAP = "Mindmap"
|
||||
|
||||
|
||||
VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR, PipelineTaskType.GRAPH_RAG, PipelineTaskType.MINDMAP}
|
||||
|
||||
|
||||
KNOWLEDGEBASE_FOLDER_NAME=".knowledgebase"
|
||||
|
||||
@ -26,12 +26,14 @@ from functools import wraps
|
||||
|
||||
from flask_login import UserMixin
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
from peewee import BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
|
||||
from peewee import InterfaceError, OperationalError, BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
|
||||
from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
|
||||
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
|
||||
|
||||
from api import settings, utils
|
||||
from api.db import ParserType, SerializedType
|
||||
from api.utils.json import json_dumps, json_loads
|
||||
from api.utils.configs import deserialize_b64, serialize_b64
|
||||
|
||||
|
||||
def singleton(cls, *args, **kw):
|
||||
@ -70,12 +72,12 @@ class JSONField(LongTextField):
|
||||
def db_value(self, value):
|
||||
if value is None:
|
||||
value = self.default_value
|
||||
return utils.json_dumps(value)
|
||||
return json_dumps(value)
|
||||
|
||||
def python_value(self, value):
|
||||
if not value:
|
||||
return self.default_value
|
||||
return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
|
||||
return json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
|
||||
|
||||
|
||||
class ListField(JSONField):
|
||||
@ -91,21 +93,21 @@ class SerializedField(LongTextField):
|
||||
|
||||
def db_value(self, value):
|
||||
if self._serialized_type == SerializedType.PICKLE:
|
||||
return utils.serialize_b64(value, to_str=True)
|
||||
return serialize_b64(value, to_str=True)
|
||||
elif self._serialized_type == SerializedType.JSON:
|
||||
if value is None:
|
||||
return None
|
||||
return utils.json_dumps(value, with_type=True)
|
||||
return json_dumps(value, with_type=True)
|
||||
else:
|
||||
raise ValueError(f"the serialized type {self._serialized_type} is not supported")
|
||||
|
||||
def python_value(self, value):
|
||||
if self._serialized_type == SerializedType.PICKLE:
|
||||
return utils.deserialize_b64(value)
|
||||
return deserialize_b64(value)
|
||||
elif self._serialized_type == SerializedType.JSON:
|
||||
if value is None:
|
||||
return {}
|
||||
return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
|
||||
return json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
|
||||
else:
|
||||
raise ValueError(f"the serialized type {self._serialized_type} is not supported")
|
||||
|
||||
@ -245,19 +247,26 @@ class JsonSerializedField(SerializedField):
|
||||
|
||||
class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.max_retries = kwargs.pop('max_retries', 5)
|
||||
self.retry_delay = kwargs.pop('retry_delay', 1)
|
||||
self.max_retries = kwargs.pop("max_retries", 5)
|
||||
self.retry_delay = kwargs.pop("retry_delay", 1)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def execute_sql(self, sql, params=None, commit=True):
|
||||
from peewee import OperationalError
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return super().execute_sql(sql, params, commit)
|
||||
except OperationalError as e:
|
||||
if e.args[0] in (2013, 2006) and attempt < self.max_retries:
|
||||
except (OperationalError, InterfaceError) as e:
|
||||
error_codes = [2013, 2006]
|
||||
error_messages = ['', 'Lost connection']
|
||||
should_retry = (
|
||||
(hasattr(e, 'args') and e.args and e.args[0] in error_codes) or
|
||||
(str(e) in error_messages) or
|
||||
(hasattr(e, '__class__') and e.__class__.__name__ == 'InterfaceError')
|
||||
)
|
||||
|
||||
if should_retry and attempt < self.max_retries:
|
||||
logging.warning(
|
||||
f"Lost connection (attempt {attempt+1}/{self.max_retries}): {e}"
|
||||
f"Database connection issue (attempt {attempt+1}/{self.max_retries}): {e}"
|
||||
)
|
||||
self._handle_connection_loss()
|
||||
time.sleep(self.retry_delay * (2 ** attempt))
|
||||
@ -267,16 +276,34 @@ class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
|
||||
return None
|
||||
|
||||
def _handle_connection_loss(self):
|
||||
self.close_all()
|
||||
self.connect()
|
||||
# self.close_all()
|
||||
# self.connect()
|
||||
try:
|
||||
self.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
self.connect()
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to reconnect: {e}")
|
||||
time.sleep(0.1)
|
||||
self.connect()
|
||||
|
||||
def begin(self):
|
||||
from peewee import OperationalError
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return super().begin()
|
||||
except OperationalError as e:
|
||||
if e.args[0] in (2013, 2006) and attempt < self.max_retries:
|
||||
except (OperationalError, InterfaceError) as e:
|
||||
error_codes = [2013, 2006]
|
||||
error_messages = ['', 'Lost connection']
|
||||
|
||||
should_retry = (
|
||||
(hasattr(e, 'args') and e.args and e.args[0] in error_codes) or
|
||||
(str(e) in error_messages) or
|
||||
(hasattr(e, '__class__') and e.__class__.__name__ == 'InterfaceError')
|
||||
)
|
||||
|
||||
if should_retry and attempt < self.max_retries:
|
||||
logging.warning(
|
||||
f"Lost connection during transaction (attempt {attempt+1}/{self.max_retries})"
|
||||
)
|
||||
@ -301,7 +328,16 @@ class BaseDataBase:
|
||||
def __init__(self):
|
||||
database_config = settings.DATABASE.copy()
|
||||
db_name = database_config.pop("name")
|
||||
self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
|
||||
|
||||
pool_config = {
|
||||
'max_retries': 5,
|
||||
'retry_delay': 1,
|
||||
}
|
||||
database_config.update(pool_config)
|
||||
self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(
|
||||
db_name, **database_config
|
||||
)
|
||||
# self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
|
||||
logging.info("init database on cluster mode successfully")
|
||||
|
||||
|
||||
@ -648,8 +684,17 @@ class Knowledgebase(DataBaseModel):
|
||||
vector_similarity_weight = FloatField(default=0.3, index=True)
|
||||
|
||||
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value, index=True)
|
||||
pipeline_id = CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)
|
||||
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
|
||||
pagerank = IntegerField(default=0, index=False)
|
||||
|
||||
graphrag_task_id = CharField(max_length=32, null=True, help_text="Graph RAG task ID", index=True)
|
||||
graphrag_task_finish_at = DateTimeField(null=True)
|
||||
raptor_task_id = CharField(max_length=32, null=True, help_text="RAPTOR task ID", index=True)
|
||||
raptor_task_finish_at = DateTimeField(null=True)
|
||||
mindmap_task_id = CharField(max_length=32, null=True, help_text="Mindmap task ID", index=True)
|
||||
mindmap_task_finish_at = DateTimeField(null=True)
|
||||
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)
|
||||
|
||||
def __str__(self):
|
||||
@ -664,6 +709,7 @@ class Document(DataBaseModel):
|
||||
thumbnail = TextField(null=True, help_text="thumbnail base64 string")
|
||||
kb_id = CharField(max_length=256, null=False, index=True)
|
||||
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", index=True)
|
||||
pipeline_id = CharField(max_length=32, null=True, help_text="pipleline ID", index=True)
|
||||
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
|
||||
source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document come from", index=True)
|
||||
type = CharField(max_length=32, null=False, help_text="file extension", index=True)
|
||||
@ -815,6 +861,7 @@ class UserCanvas(DataBaseModel):
|
||||
permission = CharField(max_length=16, null=False, help_text="me|team", default="me", index=True)
|
||||
description = TextField(null=True, help_text="Canvas description")
|
||||
canvas_type = CharField(max_length=32, null=True, help_text="Canvas type", index=True)
|
||||
canvas_category = CharField(max_length=32, null=False, default="agent_canvas", help_text="Canvas category: agent_canvas|dataflow_canvas", index=True)
|
||||
dsl = JSONField(null=True, default={})
|
||||
|
||||
class Meta:
|
||||
@ -827,6 +874,7 @@ class CanvasTemplate(DataBaseModel):
|
||||
title = JSONField(null=True, default=dict, help_text="Canvas title")
|
||||
description = JSONField(null=True, default=dict, help_text="Canvas description")
|
||||
canvas_type = CharField(max_length=32, null=True, help_text="Canvas type", index=True)
|
||||
canvas_category = CharField(max_length=32, null=False, default="agent_canvas", help_text="Canvas category: agent_canvas|dataflow_canvas", index=True)
|
||||
dsl = JSONField(null=True, default={})
|
||||
|
||||
class Meta:
|
||||
@ -904,6 +952,32 @@ class Search(DataBaseModel):
|
||||
db_table = "search"
|
||||
|
||||
|
||||
class PipelineOperationLog(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
document_id = CharField(max_length=32, index=True)
|
||||
tenant_id = CharField(max_length=32, null=False, index=True)
|
||||
kb_id = CharField(max_length=32, null=False, index=True)
|
||||
pipeline_id = CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)
|
||||
pipeline_title = CharField(max_length=32, null=True, help_text="Pipeline title", index=True)
|
||||
parser_id = CharField(max_length=32, null=False, help_text="Parser ID", index=True)
|
||||
document_name = CharField(max_length=255, null=False, help_text="File name")
|
||||
document_suffix = CharField(max_length=255, null=False, help_text="File suffix")
|
||||
document_type = CharField(max_length=255, null=False, help_text="Document type")
|
||||
source_from = CharField(max_length=255, null=False, help_text="Source")
|
||||
progress = FloatField(default=0, index=True)
|
||||
progress_msg = TextField(null=True, help_text="process message", default="")
|
||||
process_begin_at = DateTimeField(null=True, index=True)
|
||||
process_duration = FloatField(default=0)
|
||||
dsl = JSONField(null=True, default=dict)
|
||||
task_type = CharField(max_length=32, null=False, default="")
|
||||
operation_status = CharField(max_length=32, null=False, help_text="Operation status")
|
||||
avatar = TextField(null=True, help_text="avatar base64 string")
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)
|
||||
|
||||
class Meta:
|
||||
db_table = "pipeline_operation_log"
|
||||
|
||||
|
||||
def migrate_db():
|
||||
logging.disable(logging.ERROR)
|
||||
migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
|
||||
@ -1020,7 +1094,6 @@ def migrate_db():
|
||||
migrate(migrator.add_column("dialog", "meta_data_filter", JSONField(null=True, default={})))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
migrate(migrator.alter_column_type("canvas_template", "title", JSONField(null=True, default=dict, help_text="Canvas title")))
|
||||
except Exception:
|
||||
@ -1029,4 +1102,44 @@ def migrate_db():
|
||||
migrate(migrator.alter_column_type("canvas_template", "description", JSONField(null=True, default=dict, help_text="Canvas description")))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("user_canvas", "canvas_category", CharField(max_length=32, null=False, default="agent_canvas", help_text="agent_canvas|dataflow_canvas", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("canvas_template", "canvas_category", CharField(max_length=32, null=False, default="agent_canvas", help_text="agent_canvas|dataflow_canvas", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("knowledgebase", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("document", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("knowledgebase", "graphrag_task_id", CharField(max_length=32, null=True, help_text="Gragh RAG task ID", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("knowledgebase", "raptor_task_id", CharField(max_length=32, null=True, help_text="RAPTOR task ID", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("knowledgebase", "graphrag_task_finish_at", DateTimeField(null=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("knowledgebase", "raptor_task_finish_at", CharField(null=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("knowledgebase", "mindmap_task_id", CharField(max_length=32, null=True, help_text="Mindmap task ID", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("knowledgebase", "mindmap_task_finish_at", CharField(null=True)))
|
||||
except Exception:
|
||||
pass
|
||||
logging.disable(logging.NOTSET)
|
||||
|
||||
@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
@ -32,11 +31,7 @@ from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_l
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api import settings
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
|
||||
|
||||
def encode_to_base64(input_string):
|
||||
base64_encoded = base64.b64encode(input_string.encode('utf-8'))
|
||||
return base64_encoded.decode('utf-8')
|
||||
from api.common.base64 import encode_to_base64
|
||||
|
||||
|
||||
def init_superuser():
|
||||
@ -144,8 +139,9 @@ def init_llm_factory():
|
||||
except Exception:
|
||||
pass
|
||||
break
|
||||
doc_count = DocumentService.get_all_kb_doc_count()
|
||||
for kb_id in KnowledgebaseService.get_all_ids():
|
||||
KnowledgebaseService.update_document_number_in_init(kb_id=kb_id, doc_num=DocumentService.get_kb_doc_count(kb_id))
|
||||
KnowledgebaseService.update_document_number_in_init(kb_id=kb_id, doc_num=doc_count.get(kb_id, 0))
|
||||
|
||||
|
||||
|
||||
|
||||
327
api/db/joint_services/user_account_service.py
Normal file
327
api/db/joint_services/user_account_service.py
Normal file
@ -0,0 +1,327 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from api import settings
|
||||
from api.utils.api_utils import group_by
|
||||
from api.db import FileType, UserTenantRole, ActiveEnum
|
||||
from api.db.services.api_service import APITokenService, API4ConversationService
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.conversation_service import ConversationService
|
||||
from api.db.services.dialog_service import DialogService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.langfuse_service import TenantLangfuseService
|
||||
from api.db.services.llm_service import get_init_tenant_llm
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.mcp_server_service import MCPServerService
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.task_service import TaskService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
from api.db.services.user_service import TenantService, UserService, UserTenantService
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
from rag.nlp import search
|
||||
|
||||
|
||||
def create_new_user(user_info: dict) -> dict:
|
||||
"""
|
||||
Add a new user, and create tenant, tenant llm, file folder for new user.
|
||||
:param user_info: {
|
||||
"email": <example@example.com>,
|
||||
"nickname": <str, "name">,
|
||||
"password": <decrypted password>,
|
||||
"login_channel": <enum, "password">,
|
||||
"is_superuser": <bool, role == "admin">,
|
||||
}
|
||||
:return: {
|
||||
"success": <bool>,
|
||||
"user_info": <dict>, # if true, return user_info
|
||||
}
|
||||
"""
|
||||
# generate user_id and access_token for user
|
||||
user_id = uuid.uuid1().hex
|
||||
user_info['id'] = user_id
|
||||
user_info['access_token'] = uuid.uuid1().hex
|
||||
# construct tenant info
|
||||
tenant = {
|
||||
"id": user_id,
|
||||
"name": user_info["nickname"] + "‘s Kingdom",
|
||||
"llm_id": settings.CHAT_MDL,
|
||||
"embd_id": settings.EMBEDDING_MDL,
|
||||
"asr_id": settings.ASR_MDL,
|
||||
"parser_ids": settings.PARSERS,
|
||||
"img2txt_id": settings.IMAGE2TEXT_MDL,
|
||||
"rerank_id": settings.RERANK_MDL,
|
||||
}
|
||||
usr_tenant = {
|
||||
"tenant_id": user_id,
|
||||
"user_id": user_id,
|
||||
"invited_by": user_id,
|
||||
"role": UserTenantRole.OWNER,
|
||||
}
|
||||
# construct file folder info
|
||||
file_id = uuid.uuid1().hex
|
||||
file = {
|
||||
"id": file_id,
|
||||
"parent_id": file_id,
|
||||
"tenant_id": user_id,
|
||||
"created_by": user_id,
|
||||
"name": "/",
|
||||
"type": FileType.FOLDER.value,
|
||||
"size": 0,
|
||||
"location": "",
|
||||
}
|
||||
try:
|
||||
tenant_llm = get_init_tenant_llm(user_id)
|
||||
|
||||
if not UserService.save(**user_info):
|
||||
return {"success": False}
|
||||
|
||||
TenantService.insert(**tenant)
|
||||
UserTenantService.insert(**usr_tenant)
|
||||
TenantLLMService.insert_many(tenant_llm)
|
||||
FileService.insert(file)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"user_info": user_info,
|
||||
}
|
||||
|
||||
except Exception as create_error:
|
||||
logging.exception(create_error)
|
||||
# rollback
|
||||
try:
|
||||
TenantService.delete_by_id(user_id)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
try:
|
||||
u = UserTenantService.query(tenant_id=user_id)
|
||||
if u:
|
||||
UserTenantService.delete_by_id(u[0].id)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
try:
|
||||
TenantLLMService.delete_by_tenant_id(user_id)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
try:
|
||||
FileService.delete_by_id(file["id"])
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
# delete user row finally
|
||||
try:
|
||||
UserService.delete_by_id(user_id)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
# reraise
|
||||
raise create_error
|
||||
|
||||
|
||||
def delete_user_data(user_id: str) -> dict:
|
||||
# use user_id to delete
|
||||
usr = UserService.filter_by_id(user_id)
|
||||
if not usr:
|
||||
return {"success": False, "message": f"{user_id} can't be found."}
|
||||
# check is inactive and not admin
|
||||
if usr.is_active == ActiveEnum.ACTIVE.value:
|
||||
return {"success": False, "message": f"{user_id} is active and can't be deleted."}
|
||||
if usr.is_superuser:
|
||||
return {"success": False, "message": "Can't delete the super user."}
|
||||
# tenant info
|
||||
tenants = UserTenantService.get_user_tenant_relation_by_user_id(usr.id)
|
||||
owned_tenant = [t for t in tenants if t["role"] == UserTenantRole.OWNER.value]
|
||||
|
||||
done_msg = ''
|
||||
try:
|
||||
# step1. delete owned tenant info
|
||||
if owned_tenant:
|
||||
done_msg += "Start to delete owned tenant.\n"
|
||||
tenant_id = owned_tenant[0]["tenant_id"]
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(usr.id)
|
||||
# step1.1 delete knowledgebase related file and info
|
||||
if kb_ids:
|
||||
# step1.1.1 delete files in storage, remove bucket
|
||||
for kb_id in kb_ids:
|
||||
if STORAGE_IMPL.bucket_exists(kb_id):
|
||||
STORAGE_IMPL.remove_bucket(kb_id)
|
||||
done_msg += f"- Removed {len(kb_ids)} dataset's buckets.\n"
|
||||
# step1.1.2 delete file and document info in db
|
||||
doc_ids = DocumentService.get_all_doc_ids_by_kb_ids(kb_ids)
|
||||
if doc_ids:
|
||||
doc_delete_res = DocumentService.delete_by_ids([i["id"] for i in doc_ids])
|
||||
done_msg += f"- Deleted {doc_delete_res} document records.\n"
|
||||
task_delete_res = TaskService.delete_by_doc_ids([i["id"] for i in doc_ids])
|
||||
done_msg += f"- Deleted {task_delete_res} task records.\n"
|
||||
file_ids = FileService.get_all_file_ids_by_tenant_id(usr.id)
|
||||
if file_ids:
|
||||
file_delete_res = FileService.delete_by_ids([f["id"] for f in file_ids])
|
||||
done_msg += f"- Deleted {file_delete_res} file records.\n"
|
||||
if doc_ids or file_ids:
|
||||
file2doc_delete_res = File2DocumentService.delete_by_document_ids_or_file_ids(
|
||||
[i["id"] for i in doc_ids],
|
||||
[f["id"] for f in file_ids]
|
||||
)
|
||||
done_msg += f"- Deleted {file2doc_delete_res} document-file relation records.\n"
|
||||
# step1.1.3 delete chunk in es
|
||||
r = settings.docStoreConn.delete({"kb_id": kb_ids},
|
||||
search.index_name(tenant_id), kb_ids)
|
||||
done_msg += f"- Deleted {r} chunk records.\n"
|
||||
kb_delete_res = KnowledgebaseService.delete_by_ids(kb_ids)
|
||||
done_msg += f"- Deleted {kb_delete_res} knowledgebase records.\n"
|
||||
# step1.1.4 delete agents
|
||||
agent_delete_res = delete_user_agents(usr.id)
|
||||
done_msg += f"- Deleted {agent_delete_res['agents_deleted_count']} agent, {agent_delete_res['version_deleted_count']} versions records.\n"
|
||||
# step1.1.5 delete dialogs
|
||||
dialog_delete_res = delete_user_dialogs(usr.id)
|
||||
done_msg += f"- Deleted {dialog_delete_res['dialogs_deleted_count']} dialogs, {dialog_delete_res['conversations_deleted_count']} conversations, {dialog_delete_res['api_token_deleted_count']} api tokens, {dialog_delete_res['api4conversation_deleted_count']} api4conversations.\n"
|
||||
# step1.1.6 delete mcp server
|
||||
mcp_delete_res = MCPServerService.delete_by_tenant_id(usr.id)
|
||||
done_msg += f"- Deleted {mcp_delete_res} MCP server.\n"
|
||||
# step1.1.7 delete search
|
||||
search_delete_res = SearchService.delete_by_tenant_id(usr.id)
|
||||
done_msg += f"- Deleted {search_delete_res} search records.\n"
|
||||
# step1.2 delete tenant_llm and tenant_langfuse
|
||||
llm_delete_res = TenantLLMService.delete_by_tenant_id(tenant_id)
|
||||
done_msg += f"- Deleted {llm_delete_res} tenant-LLM records.\n"
|
||||
langfuse_delete_res = TenantLangfuseService.delete_ty_tenant_id(tenant_id)
|
||||
done_msg += f"- Deleted {langfuse_delete_res} langfuse records.\n"
|
||||
# step1.3 delete own tenant
|
||||
tenant_delete_res = TenantService.delete_by_id(tenant_id)
|
||||
done_msg += f"- Deleted {tenant_delete_res} tenant.\n"
|
||||
# step2 delete user-tenant relation
|
||||
if tenants:
|
||||
# step2.1 delete docs and files in joined team
|
||||
joined_tenants = [t for t in tenants if t["role"] == UserTenantRole.NORMAL.value]
|
||||
if joined_tenants:
|
||||
done_msg += "Start to delete data in joined tenants.\n"
|
||||
created_documents = DocumentService.get_all_docs_by_creator_id(usr.id)
|
||||
if created_documents:
|
||||
# step2.1.1 delete files
|
||||
doc_file_info = File2DocumentService.get_by_document_ids([d['id'] for d in created_documents])
|
||||
created_files = FileService.get_by_ids([f['file_id'] for f in doc_file_info])
|
||||
if created_files:
|
||||
# step2.1.1.1 delete file in storage
|
||||
for f in created_files:
|
||||
STORAGE_IMPL.rm(f.parent_id, f.location)
|
||||
done_msg += f"- Deleted {len(created_files)} uploaded file.\n"
|
||||
# step2.1.1.2 delete file record
|
||||
file_delete_res = FileService.delete_by_ids([f.id for f in created_files])
|
||||
done_msg += f"- Deleted {file_delete_res} file records.\n"
|
||||
# step2.1.2 delete document-file relation record
|
||||
file2doc_delete_res = File2DocumentService.delete_by_document_ids_or_file_ids(
|
||||
[d['id'] for d in created_documents],
|
||||
[f.id for f in created_files]
|
||||
)
|
||||
done_msg += f"- Deleted {file2doc_delete_res} document-file relation records.\n"
|
||||
# step2.1.3 delete chunks
|
||||
doc_groups = group_by(created_documents, "tenant_id")
|
||||
kb_grouped_doc = {k: group_by(v, "kb_id") for k, v in doc_groups.items()}
|
||||
# chunks in {'tenant_id': {'kb_id': [{'id': doc_id}]}} structure
|
||||
chunk_delete_res = 0
|
||||
kb_doc_info = {}
|
||||
for _tenant_id, kb_doc in kb_grouped_doc.items():
|
||||
for _kb_id, docs in kb_doc.items():
|
||||
chunk_delete_res += settings.docStoreConn.delete(
|
||||
{"doc_id": [d["id"] for d in docs]},
|
||||
search.index_name(_tenant_id), _kb_id
|
||||
)
|
||||
# record doc info
|
||||
if _kb_id in kb_doc_info.keys():
|
||||
kb_doc_info[_kb_id]['doc_num'] += 1
|
||||
kb_doc_info[_kb_id]['token_num'] += sum([d["token_num"] for d in docs])
|
||||
kb_doc_info[_kb_id]['chunk_num'] += sum([d["chunk_num"] for d in docs])
|
||||
else:
|
||||
kb_doc_info[_kb_id] = {
|
||||
'doc_num': 1,
|
||||
'token_num': sum([d["token_num"] for d in docs]),
|
||||
'chunk_num': sum([d["chunk_num"] for d in docs])
|
||||
}
|
||||
done_msg += f"- Deleted {chunk_delete_res} chunks.\n"
|
||||
# step2.1.4 delete tasks
|
||||
task_delete_res = TaskService.delete_by_doc_ids([d['id'] for d in created_documents])
|
||||
done_msg += f"- Deleted {task_delete_res} tasks.\n"
|
||||
# step2.1.5 delete document record
|
||||
doc_delete_res = DocumentService.delete_by_ids([d['id'] for d in created_documents])
|
||||
done_msg += f"- Deleted {doc_delete_res} documents.\n"
|
||||
# step2.1.6 update knowledge base doc&chunk&token cnt
|
||||
for kb_id, doc_num in kb_doc_info.items():
|
||||
KnowledgebaseService.decrease_document_num_in_delete(kb_id, doc_num)
|
||||
|
||||
# step2.2 delete relation
|
||||
user_tenant_delete_res = UserTenantService.delete_by_ids([t["id"] for t in tenants])
|
||||
done_msg += f"- Deleted {user_tenant_delete_res} user-tenant records.\n"
|
||||
# step3 finally delete user
|
||||
user_delete_res = UserService.delete_by_id(usr.id)
|
||||
done_msg += f"- Deleted {user_delete_res} user.\nDelete done!"
|
||||
|
||||
return {"success": True, "message": f"Successfully deleted user. Details:\n{done_msg}"}
|
||||
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return {"success": False, "message": f"Error: {str(e)}. Already done:\n{done_msg}"}
|
||||
|
||||
|
||||
def delete_user_agents(user_id: str) -> dict:
|
||||
"""
|
||||
use user_id to delete
|
||||
:return: {
|
||||
"agents_deleted_count": 1,
|
||||
"version_deleted_count": 2
|
||||
}
|
||||
"""
|
||||
agents_deleted_count, agents_version_deleted_count = 0, 0
|
||||
user_agents = UserCanvasService.get_all_agents_by_tenant_ids([user_id], user_id)
|
||||
if user_agents:
|
||||
agents_version = UserCanvasVersionService.get_all_canvas_version_by_canvas_ids([a['id'] for a in user_agents])
|
||||
agents_version_deleted_count = UserCanvasVersionService.delete_by_ids([v['id'] for v in agents_version])
|
||||
agents_deleted_count = UserCanvasService.delete_by_ids([a['id'] for a in user_agents])
|
||||
return {
|
||||
"agents_deleted_count": agents_deleted_count,
|
||||
"version_deleted_count": agents_version_deleted_count
|
||||
}
|
||||
|
||||
|
||||
def delete_user_dialogs(user_id: str) -> dict:
|
||||
"""
|
||||
use user_id to delete
|
||||
:return: {
|
||||
"dialogs_deleted_count": 1,
|
||||
"conversations_deleted_count": 1,
|
||||
"api_token_deleted_count": 2,
|
||||
"api4conversation_deleted_count": 2
|
||||
}
|
||||
"""
|
||||
dialog_deleted_count, conversations_deleted_count, api_token_deleted_count, api4conversation_deleted_count = 0, 0, 0, 0
|
||||
user_dialogs = DialogService.get_all_dialogs_by_tenant_id(user_id)
|
||||
if user_dialogs:
|
||||
# delete conversation
|
||||
conversations = ConversationService.get_all_conversation_by_dialog_ids([ud['id'] for ud in user_dialogs])
|
||||
conversations_deleted_count = ConversationService.delete_by_ids([c['id'] for c in conversations])
|
||||
# delete api token
|
||||
api_token_deleted_count = APITokenService.delete_by_tenant_id(user_id)
|
||||
# delete api for conversation
|
||||
api4conversation_deleted_count = API4ConversationService.delete_by_dialog_ids([ud['id'] for ud in user_dialogs])
|
||||
# delete dialog at last
|
||||
dialog_deleted_count = DialogService.delete_by_ids([ud['id'] for ud in user_dialogs])
|
||||
return {
|
||||
"dialogs_deleted_count": dialog_deleted_count,
|
||||
"conversations_deleted_count": conversations_deleted_count,
|
||||
"api_token_deleted_count": api_token_deleted_count,
|
||||
"api4conversation_deleted_count": api4conversation_deleted_count
|
||||
}
|
||||
@ -19,7 +19,7 @@ from pathlib import PurePath
|
||||
from .user_service import UserService as UserService
|
||||
|
||||
|
||||
def split_name_counter(filename: str) -> tuple[str, int | None]:
|
||||
def _split_name_counter(filename: str) -> tuple[str, int | None]:
|
||||
"""
|
||||
Splits a filename into main part and counter (if present in parentheses).
|
||||
|
||||
@ -87,7 +87,7 @@ def duplicate_name(query_func, **kwargs) -> str:
|
||||
stem = path.stem
|
||||
suffix = path.suffix
|
||||
|
||||
main_part, counter = split_name_counter(stem)
|
||||
main_part, counter = _split_name_counter(stem)
|
||||
counter = counter + 1 if counter else 1
|
||||
|
||||
new_name = f"{main_part}({counter}){suffix}"
|
||||
|
||||
@ -35,6 +35,11 @@ class APITokenService(CommonService):
|
||||
cls.model.token == token
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_tenant_id(cls, tenant_id):
|
||||
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
|
||||
|
||||
|
||||
class API4ConversationService(CommonService):
|
||||
model = API4Conversation
|
||||
@ -100,3 +105,8 @@ class API4ConversationService(CommonService):
|
||||
cls.model.create_date <= to_date,
|
||||
cls.model.source == source
|
||||
).group_by(cls.model.create_date.truncate("day")).dicts()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_dialog_ids(cls, dialog_ids):
|
||||
return cls.model.delete().where(cls.model.dialog_id.in_(dialog_ids)).execute()
|
||||
|
||||
@ -18,7 +18,7 @@ import logging
|
||||
import time
|
||||
from uuid import uuid4
|
||||
from agent.canvas import Canvas
|
||||
from api.db import TenantPermission
|
||||
from api.db import CanvasCategory, TenantPermission
|
||||
from api.db.db_models import DB, CanvasTemplate, User, UserCanvas, API4Conversation
|
||||
from api.db.services.api_service import API4ConversationService
|
||||
from api.db.services.common_service import CommonService
|
||||
@ -31,6 +31,12 @@ from peewee import fn
|
||||
class CanvasTemplateService(CommonService):
|
||||
model = CanvasTemplate
|
||||
|
||||
class DataFlowTemplateService(CommonService):
|
||||
"""
|
||||
Alias of CanvasTemplateService
|
||||
"""
|
||||
model = CanvasTemplate
|
||||
|
||||
|
||||
class UserCanvasService(CommonService):
|
||||
model = UserCanvas
|
||||
@ -38,13 +44,14 @@ class UserCanvasService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_list(cls, tenant_id,
|
||||
page_number, items_per_page, orderby, desc, id, title):
|
||||
page_number, items_per_page, orderby, desc, id, title, canvas_category=CanvasCategory.Agent):
|
||||
agents = cls.model.select()
|
||||
if id:
|
||||
agents = agents.where(cls.model.id == id)
|
||||
if title:
|
||||
agents = agents.where(cls.model.title == title)
|
||||
agents = agents.where(cls.model.user_id == tenant_id)
|
||||
agents = agents.where(cls.model.canvas_category == canvas_category)
|
||||
if desc:
|
||||
agents = agents.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
@ -56,7 +63,38 @@ class UserCanvasService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_tenant_id(cls, pid):
|
||||
def get_all_agents_by_tenant_ids(cls, tenant_ids, user_id):
|
||||
# will get all permitted agents, be cautious
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.title,
|
||||
cls.model.permission,
|
||||
cls.model.canvas_type,
|
||||
cls.model.canvas_category
|
||||
]
|
||||
# find team agents and owned agents
|
||||
agents = cls.model.select(*fields).where(
|
||||
(cls.model.user_id.in_(tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (
|
||||
cls.model.user_id == user_id
|
||||
)
|
||||
)
|
||||
# sort by create_time, asc
|
||||
agents.order_by(cls.model.create_time.asc())
|
||||
# maybe cause slow query by deep paginate, optimize later
|
||||
offset, limit = 0, 50
|
||||
res = []
|
||||
while True:
|
||||
ag_batch = agents.offset(offset).limit(limit)
|
||||
_temp = list(ag_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_canvas_id(cls, pid):
|
||||
try:
|
||||
|
||||
fields = [
|
||||
@ -71,6 +109,7 @@ class UserCanvasService(CommonService):
|
||||
cls.model.create_time,
|
||||
cls.model.create_date,
|
||||
cls.model.update_date,
|
||||
cls.model.canvas_category,
|
||||
User.nickname,
|
||||
User.avatar.alias('tenant_avatar'),
|
||||
]
|
||||
@ -87,7 +126,7 @@ class UserCanvasService(CommonService):
|
||||
@DB.connection_context()
|
||||
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
|
||||
page_number, items_per_page,
|
||||
orderby, desc, keywords,
|
||||
orderby, desc, keywords, canvas_category=None
|
||||
):
|
||||
fields = [
|
||||
cls.model.id,
|
||||
@ -96,36 +135,41 @@ class UserCanvasService(CommonService):
|
||||
cls.model.dsl,
|
||||
cls.model.description,
|
||||
cls.model.permission,
|
||||
cls.model.user_id.alias("tenant_id"),
|
||||
User.nickname,
|
||||
User.avatar.alias('tenant_avatar'),
|
||||
cls.model.update_time
|
||||
cls.model.update_time,
|
||||
cls.model.canvas_category,
|
||||
]
|
||||
if keywords:
|
||||
agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where(
|
||||
((cls.model.user_id.in_(joined_tenant_ids) & (cls.model.permission ==
|
||||
TenantPermission.TEAM.value)) | (
|
||||
cls.model.user_id == user_id)),
|
||||
(fn.LOWER(cls.model.title).contains(keywords.lower()))
|
||||
cls.model.user_id.in_(joined_tenant_ids),
|
||||
fn.LOWER(cls.model.title).contains(keywords.lower())
|
||||
#(((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id)),
|
||||
#(fn.LOWER(cls.model.title).contains(keywords.lower()))
|
||||
)
|
||||
else:
|
||||
agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where(
|
||||
((cls.model.user_id.in_(joined_tenant_ids) & (cls.model.permission ==
|
||||
TenantPermission.TEAM.value)) | (
|
||||
cls.model.user_id == user_id))
|
||||
cls.model.user_id.in_(joined_tenant_ids)
|
||||
#(((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id))
|
||||
)
|
||||
if canvas_category:
|
||||
agents = agents.where(cls.model.canvas_category == canvas_category)
|
||||
if desc:
|
||||
agents = agents.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
agents = agents.order_by(cls.model.getter_by(orderby).asc())
|
||||
|
||||
count = agents.count()
|
||||
agents = agents.paginate(page_number, items_per_page)
|
||||
if page_number and items_per_page:
|
||||
agents = agents.paginate(page_number, items_per_page)
|
||||
return list(agents.dicts()), count
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def accessible(cls, canvas_id, tenant_id):
|
||||
from api.db.services.user_service import UserTenantService
|
||||
e, c = UserCanvasService.get_by_tenant_id(canvas_id)
|
||||
e, c = UserCanvasService.get_by_canvas_id(canvas_id)
|
||||
if not e:
|
||||
return False
|
||||
|
||||
|
||||
@ -14,12 +14,24 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
from datetime import datetime
|
||||
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
|
||||
import peewee
|
||||
from peewee import InterfaceError, OperationalError
|
||||
|
||||
from api.db.db_models import DB
|
||||
from api.utils import current_timestamp, datetime_format, get_uuid
|
||||
|
||||
def retry_db_operation(func):
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=1, max=5),
|
||||
retry=retry_if_exception_type((InterfaceError, OperationalError)),
|
||||
before_sleep=lambda retry_state: print(f"RETRY {retry_state.attempt_number} TIMES"),
|
||||
reraise=True,
|
||||
)
|
||||
def wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
class CommonService:
|
||||
"""Base service class that provides common database operations.
|
||||
@ -202,6 +214,7 @@ class CommonService:
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
@retry_db_operation
|
||||
def update_by_id(cls, pid, data):
|
||||
# Update a single record by ID
|
||||
# Args:
|
||||
|
||||
@ -23,7 +23,7 @@ from api.db.services.dialog_service import DialogService, chat
|
||||
from api.utils import get_uuid
|
||||
import json
|
||||
|
||||
from rag.prompts import chunks_format
|
||||
from rag.prompts.generator import chunks_format
|
||||
|
||||
|
||||
class ConversationService(CommonService):
|
||||
@ -48,6 +48,21 @@ class ConversationService(CommonService):
|
||||
|
||||
return list(sessions.dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_conversation_by_dialog_ids(cls, dialog_ids):
|
||||
sessions = cls.model.select().where(cls.model.dialog_id.in_(dialog_ids))
|
||||
sessions.order_by(cls.model.create_time.asc())
|
||||
offset, limit = 0, 100
|
||||
res = []
|
||||
while True:
|
||||
s_batch = sessions.offset(offset).limit(limit)
|
||||
_temp = list(s_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
def structure_answer(conv, ans, message_id, session_id):
|
||||
reference = ans["reference"]
|
||||
|
||||
@ -21,11 +21,9 @@ from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from timeit import default_timer as timer
|
||||
|
||||
import trio
|
||||
from langfuse import Langfuse
|
||||
from peewee import fn
|
||||
|
||||
from agentic_reasoning import DeepResearcher
|
||||
from api import settings
|
||||
from api.db import LLMType, ParserType, StatusEnum
|
||||
@ -41,8 +39,8 @@ from graphrag.general.mind_map_extractor import MindMapExtractor
|
||||
from rag.app.resume import forbidden_select_fields4resume
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp.search import index_name
|
||||
from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in
|
||||
from rag.prompts.prompts import gen_meta_filter, PROMPT_JINJA_ENV, ASK_SUMMARY
|
||||
from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \
|
||||
gen_meta_filter, PROMPT_JINJA_ENV, ASK_SUMMARY
|
||||
from rag.utils import num_tokens_from_string, rmSpace
|
||||
from rag.utils.tavily_conn import Tavily
|
||||
|
||||
@ -161,6 +159,22 @@ class DialogService(CommonService):
|
||||
|
||||
return list(dialogs.dicts()), count
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_dialogs_by_tenant_id(cls, tenant_id):
|
||||
fields = [cls.model.id]
|
||||
dialogs = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
|
||||
dialogs.order_by(cls.model.create_time.asc())
|
||||
offset, limit = 0, 100
|
||||
res = []
|
||||
while True:
|
||||
d_batch = dialogs.offset(offset).limit(limit)
|
||||
_temp = list(d_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
def chat_solo(dialog, messages, stream=True):
|
||||
if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
|
||||
@ -178,7 +192,7 @@ def chat_solo(dialog, messages, stream=True):
|
||||
delta_ans = ""
|
||||
for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
|
||||
answer = ans
|
||||
delta_ans = ans[len(last_ans) :]
|
||||
delta_ans = ans[len(last_ans):]
|
||||
if num_tokens_from_string(delta_ans) < 16:
|
||||
continue
|
||||
last_ans = answer
|
||||
@ -255,6 +269,23 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
|
||||
return answer, idx
|
||||
|
||||
|
||||
def convert_conditions(metadata_condition):
|
||||
if metadata_condition is None:
|
||||
metadata_condition = {}
|
||||
op_mapping = {
|
||||
"is": "=",
|
||||
"not is": "≠"
|
||||
}
|
||||
return [
|
||||
{
|
||||
"op": op_mapping.get(cond["comparison_operator"], cond["comparison_operator"]),
|
||||
"key": cond["name"],
|
||||
"value": cond["value"]
|
||||
}
|
||||
for cond in metadata_condition.get("conditions", [])
|
||||
]
|
||||
|
||||
|
||||
def meta_filter(metas: dict, filters: list[dict]):
|
||||
doc_ids = set([])
|
||||
|
||||
@ -269,19 +300,19 @@ def meta_filter(metas: dict, filters: list[dict]):
|
||||
value = str(value)
|
||||
|
||||
for conds in [
|
||||
(operator == "contains", str(value).lower() in str(input).lower()),
|
||||
(operator == "not contains", str(value).lower() not in str(input).lower()),
|
||||
(operator == "start with", str(input).lower().startswith(str(value).lower())),
|
||||
(operator == "end with", str(input).lower().endswith(str(value).lower())),
|
||||
(operator == "empty", not input),
|
||||
(operator == "not empty", input),
|
||||
(operator == "=", input == value),
|
||||
(operator == "≠", input != value),
|
||||
(operator == ">", input > value),
|
||||
(operator == "<", input < value),
|
||||
(operator == "≥", input >= value),
|
||||
(operator == "≤", input <= value),
|
||||
]:
|
||||
(operator == "contains", str(value).lower() in str(input).lower()),
|
||||
(operator == "not contains", str(value).lower() not in str(input).lower()),
|
||||
(operator == "start with", str(input).lower().startswith(str(value).lower())),
|
||||
(operator == "end with", str(input).lower().endswith(str(value).lower())),
|
||||
(operator == "empty", not input),
|
||||
(operator == "not empty", input),
|
||||
(operator == "=", input == value),
|
||||
(operator == "≠", input != value),
|
||||
(operator == ">", input > value),
|
||||
(operator == "<", input < value),
|
||||
(operator == "≥", input >= value),
|
||||
(operator == "≤", input <= value),
|
||||
]:
|
||||
try:
|
||||
if all(conds):
|
||||
ids.extend(docids)
|
||||
@ -350,7 +381,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
# try to use sql if field mapping is good to go
|
||||
if field_map:
|
||||
logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
|
||||
ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
|
||||
ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
|
||||
if ans:
|
||||
yield ans
|
||||
return
|
||||
@ -441,7 +472,8 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
kbinfos["chunks"].extend(tav_res["chunks"])
|
||||
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
|
||||
if prompt_config.get("use_kg"):
|
||||
ck = settings.kg_retrievaler.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, LLMBundle(dialog.tenant_id, LLMType.CHAT))
|
||||
ck = settings.kg_retrievaler.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
|
||||
LLMBundle(dialog.tenant_id, LLMType.CHAT))
|
||||
if ck["content_with_weight"]:
|
||||
kbinfos["chunks"].insert(0, ck)
|
||||
|
||||
@ -452,7 +484,8 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
retrieval_ts = timer()
|
||||
if not knowledges and prompt_config.get("empty_response"):
|
||||
empty_res = prompt_config["empty_response"]
|
||||
yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions), "audio_binary": tts(tts_mdl, empty_res)}
|
||||
yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
|
||||
"audio_binary": tts(tts_mdl, empty_res)}
|
||||
return {"answer": prompt_config["empty_response"], "reference": kbinfos}
|
||||
|
||||
kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
|
||||
@ -550,7 +583,8 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
|
||||
if langfuse_tracer:
|
||||
langfuse_generation = langfuse_tracer.start_generation(
|
||||
trace_context=trace_context, name="chat", model=llm_model_config["llm_name"], input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg}
|
||||
trace_context=trace_context, name="chat", model=llm_model_config["llm_name"],
|
||||
input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg}
|
||||
)
|
||||
|
||||
if stream:
|
||||
@ -560,12 +594,12 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
if thought:
|
||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
answer = ans
|
||||
delta_ans = ans[len(last_ans) :]
|
||||
delta_ans = ans[len(last_ans):]
|
||||
if num_tokens_from_string(delta_ans) < 16:
|
||||
continue
|
||||
last_ans = answer
|
||||
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||
delta_ans = answer[len(last_ans) :]
|
||||
delta_ans = answer[len(last_ans):]
|
||||
if delta_ans:
|
||||
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||
yield decorate_answer(thought + answer)
|
||||
@ -578,7 +612,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
yield res
|
||||
|
||||
|
||||
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
|
||||
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
|
||||
sys_prompt = "You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question."
|
||||
user_prompt = """
|
||||
Table name: {};
|
||||
@ -615,6 +649,13 @@ Please write the SQL, only SQL, without any other explanations or text.
|
||||
flds.append(k)
|
||||
sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
|
||||
|
||||
if kb_ids:
|
||||
kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"
|
||||
if "where" not in sql.lower():
|
||||
sql += f" WHERE {kb_filter}"
|
||||
else:
|
||||
sql += f" AND {kb_filter}"
|
||||
|
||||
logging.debug(f"{question} get SQL(refined): {sql}")
|
||||
tried_times += 1
|
||||
return settings.retrievaler.sql_retrieval(sql, format="json"), sql
|
||||
@ -654,7 +695,9 @@ Please write the SQL, only SQL, without any other explanations or text.
|
||||
|
||||
# compose Markdown table
|
||||
columns = (
|
||||
"|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
|
||||
"|" + "|".join(
|
||||
[re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + (
|
||||
"|Source|" if docid_idx and docid_idx else "|")
|
||||
)
|
||||
|
||||
line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")
|
||||
@ -731,7 +774,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
|
||||
doc_ids = None
|
||||
|
||||
kbinfos = retriever.retrieval(
|
||||
question = question,
|
||||
question=question,
|
||||
embd_mdl=embd_mdl,
|
||||
tenant_ids=tenant_ids,
|
||||
kb_ids=kb_ids,
|
||||
@ -753,7 +796,8 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
|
||||
|
||||
def decorate_answer(answer):
|
||||
nonlocal knowledges, kbinfos, sys_prompt
|
||||
answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3)
|
||||
answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]],
|
||||
embd_mdl, tkweight=0.7, vtweight=0.3)
|
||||
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
||||
recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
||||
if not recall_docs:
|
||||
@ -821,4 +865,4 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
|
||||
)
|
||||
mindmap = MindMapExtractor(chat_mdl)
|
||||
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
|
||||
return mind_map.output
|
||||
return mind_map.output
|
||||
|
||||
@ -24,12 +24,13 @@ from io import BytesIO
|
||||
|
||||
import trio
|
||||
import xxhash
|
||||
from peewee import fn
|
||||
from peewee import fn, Case, JOIN
|
||||
|
||||
from api import settings
|
||||
from api.constants import IMG_BASE64_PREFIX, FILE_NAME_LEN_LIMIT
|
||||
from api.db import FileType, LLMType, ParserType, StatusEnum, TaskStatus, UserTenantRole
|
||||
from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File
|
||||
from api.db import FileType, LLMType, ParserType, StatusEnum, TaskStatus, UserTenantRole, CanvasCategory
|
||||
from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File, UserCanvas, \
|
||||
User
|
||||
from api.db.db_utils import bulk_insert_into_db
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
@ -51,6 +52,7 @@ class DocumentService(CommonService):
|
||||
cls.model.thumbnail,
|
||||
cls.model.kb_id,
|
||||
cls.model.parser_id,
|
||||
cls.model.pipeline_id,
|
||||
cls.model.parser_config,
|
||||
cls.model.source_type,
|
||||
cls.model.type,
|
||||
@ -79,7 +81,10 @@ class DocumentService(CommonService):
|
||||
def get_list(cls, kb_id, page_number, items_per_page,
|
||||
orderby, desc, keywords, id, name):
|
||||
fields = cls.get_cls_model_fields()
|
||||
docs = cls.model.select(*fields).join(File2Document, on = (File2Document.document_id == cls.model.id)).join(File, on = (File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id)
|
||||
docs = cls.model.select(*[*fields, UserCanvas.title]).join(File2Document, on = (File2Document.document_id == cls.model.id))\
|
||||
.join(File, on = (File.id == File2Document.file_id))\
|
||||
.join(UserCanvas, on = ((cls.model.pipeline_id == UserCanvas.id) & (UserCanvas.canvas_category == CanvasCategory.DataFlow.value)), join_type=JOIN.LEFT_OUTER)\
|
||||
.where(cls.model.kb_id == kb_id)
|
||||
if id:
|
||||
docs = docs.where(
|
||||
cls.model.id == id)
|
||||
@ -117,12 +122,22 @@ class DocumentService(CommonService):
|
||||
orderby, desc, keywords, run_status, types, suffix):
|
||||
fields = cls.get_cls_model_fields()
|
||||
if keywords:
|
||||
docs = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(
|
||||
(cls.model.kb_id == kb_id),
|
||||
(fn.LOWER(cls.model.name).contains(keywords.lower()))
|
||||
)
|
||||
docs = cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name"), User.nickname])\
|
||||
.join(File2Document, on=(File2Document.document_id == cls.model.id))\
|
||||
.join(File, on=(File.id == File2Document.file_id))\
|
||||
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
|
||||
.join(User, on=(cls.model.created_by == User.id), join_type=JOIN.LEFT_OUTER)\
|
||||
.where(
|
||||
(cls.model.kb_id == kb_id),
|
||||
(fn.LOWER(cls.model.name).contains(keywords.lower()))
|
||||
)
|
||||
else:
|
||||
docs = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id)
|
||||
docs = cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name"), User.nickname])\
|
||||
.join(File2Document, on=(File2Document.document_id == cls.model.id))\
|
||||
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
|
||||
.join(File, on=(File.id == File2Document.file_id))\
|
||||
.join(User, on=(cls.model.created_by == User.id), join_type=JOIN.LEFT_OUTER)\
|
||||
.where(cls.model.kb_id == kb_id)
|
||||
|
||||
if run_status:
|
||||
docs = docs.where(cls.model.run.in_(run_status))
|
||||
@ -228,6 +243,46 @@ class DocumentService(CommonService):
|
||||
|
||||
return int(query.scalar()) or 0
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_doc_ids_by_kb_ids(cls, kb_ids):
|
||||
fields = [cls.model.id]
|
||||
docs = cls.model.select(*fields).where(cls.model.kb_id.in_(kb_ids))
|
||||
docs.order_by(cls.model.create_time.asc())
|
||||
# maybe cause slow query by deep paginate, optimize later
|
||||
offset, limit = 0, 100
|
||||
res = []
|
||||
while True:
|
||||
doc_batch = docs.offset(offset).limit(limit)
|
||||
_temp = list(doc_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_docs_by_creator_id(cls, creator_id):
|
||||
fields = [
|
||||
cls.model.id, cls.model.kb_id, cls.model.token_num, cls.model.chunk_num, Knowledgebase.tenant_id
|
||||
]
|
||||
docs = cls.model.select(*fields).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
|
||||
cls.model.created_by == creator_id
|
||||
)
|
||||
docs.order_by(cls.model.create_time.asc())
|
||||
# maybe cause slow query by deep paginate, optimize later
|
||||
offset, limit = 0, 100
|
||||
res = []
|
||||
while True:
|
||||
doc_batch = docs.offset(offset).limit(limit)
|
||||
_temp = list(doc_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def insert(cls, doc):
|
||||
@ -330,8 +385,7 @@ class DocumentService(CommonService):
|
||||
process_duration=cls.model.process_duration + duration).where(
|
||||
cls.model.id == doc_id).execute()
|
||||
if num == 0:
|
||||
raise LookupError(
|
||||
"Document not found which is supposed to be there")
|
||||
logging.warning("Document not found which is supposed to be there")
|
||||
num = Knowledgebase.update(
|
||||
token_num=Knowledgebase.token_num +
|
||||
token_num,
|
||||
@ -597,6 +651,22 @@ class DocumentService(CommonService):
|
||||
@DB.connection_context()
|
||||
def update_progress(cls):
|
||||
docs = cls.get_unfinished_docs()
|
||||
|
||||
cls._sync_progress(docs)
|
||||
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_progress_immediately(cls, docs:list[dict]):
|
||||
if not docs:
|
||||
return
|
||||
|
||||
cls._sync_progress(docs)
|
||||
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def _sync_progress(cls, docs:list[dict]):
|
||||
for d in docs:
|
||||
try:
|
||||
tsks = Task.query(doc_id=d["id"], order_by=Task.create_time)
|
||||
@ -606,8 +676,6 @@ class DocumentService(CommonService):
|
||||
prg = 0
|
||||
finished = True
|
||||
bad = 0
|
||||
has_raptor = False
|
||||
has_graphrag = False
|
||||
e, doc = DocumentService.get_by_id(d["id"])
|
||||
status = doc.run # TaskStatus.RUNNING.value
|
||||
priority = 0
|
||||
@ -619,24 +687,14 @@ class DocumentService(CommonService):
|
||||
prg += t.progress if t.progress >= 0 else 0
|
||||
if t.progress_msg.strip():
|
||||
msg.append(t.progress_msg)
|
||||
if t.task_type == "raptor":
|
||||
has_raptor = True
|
||||
elif t.task_type == "graphrag":
|
||||
has_graphrag = True
|
||||
priority = max(priority, t.priority)
|
||||
prg /= len(tsks)
|
||||
if finished and bad:
|
||||
prg = -1
|
||||
status = TaskStatus.FAIL.value
|
||||
elif finished:
|
||||
if (d["parser_config"].get("raptor") or {}).get("use_raptor") and not has_raptor:
|
||||
queue_raptor_o_graphrag_tasks(d, "raptor", priority)
|
||||
prg = 0.98 * len(tsks) / (len(tsks) + 1)
|
||||
elif (d["parser_config"].get("graphrag") or {}).get("use_graphrag") and not has_graphrag:
|
||||
queue_raptor_o_graphrag_tasks(d, "graphrag", priority)
|
||||
prg = 0.98 * len(tsks) / (len(tsks) + 1)
|
||||
else:
|
||||
status = TaskStatus.DONE.value
|
||||
prg = 1
|
||||
status = TaskStatus.DONE.value
|
||||
|
||||
msg = "\n".join(sorted(msg))
|
||||
info = {
|
||||
@ -648,7 +706,7 @@ class DocumentService(CommonService):
|
||||
info["progress"] = prg
|
||||
if msg:
|
||||
info["progress_msg"] = msg
|
||||
if msg.endswith("created task graphrag") or msg.endswith("created task raptor"):
|
||||
if msg.endswith("created task graphrag") or msg.endswith("created task raptor") or msg.endswith("created task mindmap"):
|
||||
info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority)
|
||||
else:
|
||||
info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority)
|
||||
@ -660,8 +718,16 @@ class DocumentService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_kb_doc_count(cls, kb_id):
|
||||
return len(cls.model.select(cls.model.id).where(
|
||||
cls.model.kb_id == kb_id).dicts())
|
||||
return cls.model.select().where(cls.model.kb_id == kb_id).count()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_kb_doc_count(cls):
|
||||
result = {}
|
||||
rows = cls.model.select(cls.model.kb_id, fn.COUNT(cls.model.id).alias('count')).group_by(cls.model.kb_id)
|
||||
for row in rows:
|
||||
result[row.kb_id] = row.count
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
@ -674,7 +740,58 @@ class DocumentService(CommonService):
|
||||
return False
|
||||
|
||||
|
||||
def queue_raptor_o_graphrag_tasks(doc, ty, priority):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def knowledgebase_basic_info(cls, kb_id: str) -> dict[str, int]:
|
||||
# cancelled: run == "2" but progress can vary
|
||||
cancelled = (
|
||||
cls.model.select(fn.COUNT(1))
|
||||
.where((cls.model.kb_id == kb_id) & (cls.model.run == TaskStatus.CANCEL))
|
||||
.scalar()
|
||||
)
|
||||
|
||||
row = (
|
||||
cls.model.select(
|
||||
# finished: progress == 1
|
||||
fn.COALESCE(fn.SUM(Case(None, [(cls.model.progress == 1, 1)], 0)), 0).alias("finished"),
|
||||
|
||||
# failed: progress == -1
|
||||
fn.COALESCE(fn.SUM(Case(None, [(cls.model.progress == -1, 1)], 0)), 0).alias("failed"),
|
||||
|
||||
# processing: 0 <= progress < 1
|
||||
fn.COALESCE(
|
||||
fn.SUM(
|
||||
Case(
|
||||
None,
|
||||
[
|
||||
(((cls.model.progress == 0) | ((cls.model.progress > 0) & (cls.model.progress < 1))), 1),
|
||||
],
|
||||
0,
|
||||
)
|
||||
),
|
||||
0,
|
||||
).alias("processing"),
|
||||
)
|
||||
.where(
|
||||
(cls.model.kb_id == kb_id)
|
||||
& ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL))
|
||||
)
|
||||
.dicts()
|
||||
.get()
|
||||
)
|
||||
|
||||
return {
|
||||
"processing": int(row["processing"]),
|
||||
"finished": int(row["finished"]),
|
||||
"failed": int(row["failed"]),
|
||||
"cancelled": int(cancelled),
|
||||
}
|
||||
|
||||
def queue_raptor_o_graphrag_tasks(doc, ty, priority, fake_doc_id="", doc_ids=[]):
|
||||
"""
|
||||
You can provide a fake_doc_id to bypass the restriction of tasks at the knowledgebase level.
|
||||
Optionally, specify a list of doc_ids to determine which documents participate in the task.
|
||||
"""
|
||||
chunking_config = DocumentService.get_chunking_config(doc["id"])
|
||||
hasher = xxhash.xxh64()
|
||||
for field in sorted(chunking_config.keys()):
|
||||
@ -684,11 +801,12 @@ def queue_raptor_o_graphrag_tasks(doc, ty, priority):
|
||||
nonlocal doc
|
||||
return {
|
||||
"id": get_uuid(),
|
||||
"doc_id": doc["id"],
|
||||
"doc_id": fake_doc_id if fake_doc_id else doc["id"],
|
||||
"from_page": 100000000,
|
||||
"to_page": 100000000,
|
||||
"task_type": ty,
|
||||
"progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty
|
||||
"progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty,
|
||||
"begin_at": datetime.now(),
|
||||
}
|
||||
|
||||
task = new_task()
|
||||
@ -697,11 +815,18 @@ def queue_raptor_o_graphrag_tasks(doc, ty, priority):
|
||||
hasher.update(ty.encode("utf-8"))
|
||||
task["digest"] = hasher.hexdigest()
|
||||
bulk_insert_into_db(Task, [task], True)
|
||||
|
||||
if ty in ["graphrag", "raptor", "mindmap"]:
|
||||
task["doc_ids"] = doc_ids
|
||||
DocumentService.begin2parse(doc["id"])
|
||||
assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status."
|
||||
return task["id"]
|
||||
|
||||
|
||||
def get_queue_length(priority):
|
||||
group_info = REDIS_CONN.queue_info(get_svr_queue_name(priority), SVR_CONSUMER_GROUP_NAME)
|
||||
if not group_info:
|
||||
return 0
|
||||
return int(group_info.get("lag", 0) or 0)
|
||||
|
||||
|
||||
@ -847,3 +972,4 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
|
||||
|
||||
return [d["id"] for d, _ in files]
|
||||
|
||||
|
||||
@ -38,6 +38,12 @@ class File2DocumentService(CommonService):
|
||||
objs = cls.model.select().where(cls.model.document_id == document_id)
|
||||
return objs
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_document_ids(cls, document_ids):
|
||||
objs = cls.model.select().where(cls.model.document_id.in_(document_ids))
|
||||
return list(objs.dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def insert(cls, obj):
|
||||
@ -50,6 +56,15 @@ class File2DocumentService(CommonService):
|
||||
def delete_by_file_id(cls, file_id):
|
||||
return cls.model.delete().where(cls.model.file_id == file_id).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_document_ids_or_file_ids(cls, document_ids, file_ids):
|
||||
if not document_ids:
|
||||
return cls.model.delete().where(cls.model.file_id.in_(file_ids)).execute()
|
||||
elif not file_ids:
|
||||
return cls.model.delete().where(cls.model.document_id.in_(document_ids)).execute()
|
||||
return cls.model.delete().where(cls.model.document_id.in_(document_ids) | cls.model.file_id.in_(file_ids)).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_document_id(cls, doc_id):
|
||||
|
||||
@ -161,6 +161,23 @@ class FileService(CommonService):
|
||||
result_ids.append(folder_id)
|
||||
return result_ids
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_file_ids_by_tenant_id(cls, tenant_id):
|
||||
fields = [cls.model.id]
|
||||
files = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
|
||||
files.order_by(cls.model.create_time.asc())
|
||||
offset, limit = 0, 100
|
||||
res = []
|
||||
while True:
|
||||
file_batch = files.offset(offset).limit(limit)
|
||||
_temp = list(file_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def create_folder(cls, file, parent_id, name, count):
|
||||
@ -440,6 +457,7 @@ class FileService(CommonService):
|
||||
"id": doc_id,
|
||||
"kb_id": kb.id,
|
||||
"parser_id": self.get_parser(filetype, filename, kb.parser_id),
|
||||
"pipeline_id": kb.pipeline_id,
|
||||
"parser_config": kb.parser_config,
|
||||
"created_by": user_id,
|
||||
"type": filetype,
|
||||
@ -495,7 +513,7 @@ class FileService(CommonService):
|
||||
return ParserType.AUDIO.value
|
||||
if re.search(r"\.(ppt|pptx|pages)$", filename):
|
||||
return ParserType.PRESENTATION.value
|
||||
if re.search(r"\.(eml)$", filename):
|
||||
if re.search(r"\.(msg|eml)$", filename):
|
||||
return ParserType.EMAIL.value
|
||||
return default
|
||||
|
||||
|
||||
@ -15,10 +15,10 @@
|
||||
#
|
||||
from datetime import datetime
|
||||
|
||||
from peewee import fn
|
||||
from peewee import fn, JOIN
|
||||
|
||||
from api.db import StatusEnum, TenantPermission
|
||||
from api.db.db_models import DB, Document, Knowledgebase, Tenant, User, UserTenant
|
||||
from api.db.db_models import DB, Document, Knowledgebase, User, UserTenant, UserCanvas
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.utils import current_timestamp, datetime_format
|
||||
|
||||
@ -190,6 +190,41 @@ class KnowledgebaseService(CommonService):
|
||||
|
||||
return list(kbs.dicts()), count
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_kb_by_tenant_ids(cls, tenant_ids, user_id):
|
||||
# will get all permitted kb, be cautious.
|
||||
fields = [
|
||||
cls.model.name,
|
||||
cls.model.language,
|
||||
cls.model.permission,
|
||||
cls.model.doc_num,
|
||||
cls.model.token_num,
|
||||
cls.model.chunk_num,
|
||||
cls.model.status,
|
||||
cls.model.create_date,
|
||||
cls.model.update_date
|
||||
]
|
||||
# find team kb and owned kb
|
||||
kbs = cls.model.select(*fields).where(
|
||||
(cls.model.tenant_id.in_(tenant_ids) & (cls.model.permission ==TenantPermission.TEAM.value)) | (
|
||||
cls.model.tenant_id == user_id
|
||||
)
|
||||
)
|
||||
# sort by create_time asc
|
||||
kbs.order_by(cls.model.create_time.asc())
|
||||
# maybe cause slow query by deep paginate, optimize later.
|
||||
offset, limit = 0, 50
|
||||
res = []
|
||||
while True:
|
||||
kb_batch = kbs.offset(offset).limit(limit)
|
||||
_temp = list(kb_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_kb_ids(cls, tenant_id):
|
||||
@ -225,20 +260,29 @@ class KnowledgebaseService(CommonService):
|
||||
cls.model.token_num,
|
||||
cls.model.chunk_num,
|
||||
cls.model.parser_id,
|
||||
cls.model.pipeline_id,
|
||||
UserCanvas.title.alias("pipeline_name"),
|
||||
UserCanvas.avatar.alias("pipeline_avatar"),
|
||||
cls.model.parser_config,
|
||||
cls.model.pagerank,
|
||||
cls.model.graphrag_task_id,
|
||||
cls.model.graphrag_task_finish_at,
|
||||
cls.model.raptor_task_id,
|
||||
cls.model.raptor_task_finish_at,
|
||||
cls.model.mindmap_task_id,
|
||||
cls.model.mindmap_task_finish_at,
|
||||
cls.model.create_time,
|
||||
cls.model.update_time
|
||||
]
|
||||
kbs = cls.model.select(*fields).join(Tenant, on=(
|
||||
(Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where(
|
||||
kbs = cls.model.select(*fields)\
|
||||
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
|
||||
.where(
|
||||
(cls.model.id == kb_id),
|
||||
(cls.model.status == StatusEnum.VALID.value)
|
||||
)
|
||||
).dicts()
|
||||
if not kbs:
|
||||
return
|
||||
d = kbs[0].to_dict()
|
||||
return d
|
||||
return kbs[0]
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
@ -436,3 +480,17 @@ class KnowledgebaseService(CommonService):
|
||||
else:
|
||||
raise e
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def decrease_document_num_in_delete(cls, kb_id, doc_num_info: dict):
|
||||
kb_row = cls.model.get_by_id(kb_id)
|
||||
if not kb_row:
|
||||
raise RuntimeError(f"kb_id {kb_id} does not exist")
|
||||
update_dict = {
|
||||
'doc_num': kb_row.doc_num - doc_num_info['doc_num'],
|
||||
'chunk_num': kb_row.chunk_num - doc_num_info['chunk_num'],
|
||||
'token_num': kb_row.token_num - doc_num_info['token_num'],
|
||||
'update_time': current_timestamp(),
|
||||
'update_date': datetime_format(datetime.now())
|
||||
}
|
||||
return cls.model.update(update_dict).where(cls.model.id == kb_id).execute()
|
||||
|
||||
@ -51,6 +51,11 @@ class TenantLangfuseService(CommonService):
|
||||
except peewee.DoesNotExist:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_ty_tenant_id(cls, tenant_id):
|
||||
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
|
||||
|
||||
@classmethod
|
||||
def update_by_tenant(cls, tenant_id, langfuse_keys):
|
||||
langfuse_keys["update_time"] = current_timestamp()
|
||||
|
||||
@ -84,3 +84,8 @@ class MCPServerService(CommonService):
|
||||
return bool(mcp_server), mcp_server
|
||||
except Exception:
|
||||
return False, None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_tenant_id(cls, tenant_id: str):
|
||||
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
|
||||
|
||||
263
api/db/services/pipeline_operation_log_service.py
Normal file
263
api/db/services/pipeline_operation_log_service.py
Normal file
@ -0,0 +1,263 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from peewee import fn
|
||||
|
||||
from api.db import VALID_PIPELINE_TASK_TYPES, PipelineTaskType
|
||||
from api.db.db_models import DB, Document, PipelineOperationLog
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID
|
||||
from api.utils import current_timestamp, datetime_format, get_uuid
|
||||
|
||||
|
||||
class PipelineOperationLogService(CommonService):
|
||||
model = PipelineOperationLog
|
||||
|
||||
@classmethod
|
||||
def get_file_logs_fields(cls):
|
||||
return [
|
||||
cls.model.id,
|
||||
cls.model.document_id,
|
||||
cls.model.tenant_id,
|
||||
cls.model.kb_id,
|
||||
cls.model.pipeline_id,
|
||||
cls.model.pipeline_title,
|
||||
cls.model.parser_id,
|
||||
cls.model.document_name,
|
||||
cls.model.document_suffix,
|
||||
cls.model.document_type,
|
||||
cls.model.source_from,
|
||||
cls.model.progress,
|
||||
cls.model.progress_msg,
|
||||
cls.model.process_begin_at,
|
||||
cls.model.process_duration,
|
||||
cls.model.dsl,
|
||||
cls.model.task_type,
|
||||
cls.model.operation_status,
|
||||
cls.model.avatar,
|
||||
cls.model.status,
|
||||
cls.model.create_time,
|
||||
cls.model.create_date,
|
||||
cls.model.update_time,
|
||||
cls.model.update_date,
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_dataset_logs_fields(cls):
|
||||
return [
|
||||
cls.model.id,
|
||||
cls.model.tenant_id,
|
||||
cls.model.kb_id,
|
||||
cls.model.progress,
|
||||
cls.model.progress_msg,
|
||||
cls.model.process_begin_at,
|
||||
cls.model.process_duration,
|
||||
cls.model.task_type,
|
||||
cls.model.operation_status,
|
||||
cls.model.avatar,
|
||||
cls.model.status,
|
||||
cls.model.create_time,
|
||||
cls.model.create_date,
|
||||
cls.model.update_time,
|
||||
cls.model.update_date,
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def save(cls, **kwargs):
|
||||
"""
|
||||
wrap this function in a transaction
|
||||
"""
|
||||
sample_obj = cls.model(**kwargs).save(force_insert=True)
|
||||
return sample_obj
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def create(cls, document_id, pipeline_id, task_type, fake_document_ids=[], dsl: str = "{}"):
|
||||
referred_document_id = document_id
|
||||
|
||||
if referred_document_id == GRAPH_RAPTOR_FAKE_DOC_ID and fake_document_ids:
|
||||
referred_document_id = fake_document_ids[0]
|
||||
ok, document = DocumentService.get_by_id(referred_document_id)
|
||||
if not ok:
|
||||
logging.warning(f"Document for referred_document_id {referred_document_id} not found")
|
||||
return
|
||||
DocumentService.update_progress_immediately([document.to_dict()])
|
||||
ok, document = DocumentService.get_by_id(referred_document_id)
|
||||
if not ok:
|
||||
logging.warning(f"Document for referred_document_id {referred_document_id} not found")
|
||||
return
|
||||
if document.progress not in [1, -1]:
|
||||
return
|
||||
operation_status = document.run
|
||||
|
||||
if pipeline_id:
|
||||
ok, user_pipeline = UserCanvasService.get_by_id(pipeline_id)
|
||||
if not ok:
|
||||
raise RuntimeError(f"Pipeline {pipeline_id} not found")
|
||||
tenant_id = user_pipeline.user_id
|
||||
title = user_pipeline.title
|
||||
avatar = user_pipeline.avatar
|
||||
else:
|
||||
ok, kb_info = KnowledgebaseService.get_by_id(document.kb_id)
|
||||
if not ok:
|
||||
raise RuntimeError(f"Cannot find knowledge base {document.kb_id} for referred_document {referred_document_id}")
|
||||
|
||||
tenant_id = kb_info.tenant_id
|
||||
title = document.parser_id
|
||||
avatar = document.thumbnail
|
||||
|
||||
if task_type not in VALID_PIPELINE_TASK_TYPES:
|
||||
raise ValueError(f"Invalid task type: {task_type}")
|
||||
|
||||
if task_type in [PipelineTaskType.GRAPH_RAG, PipelineTaskType.RAPTOR, PipelineTaskType.MINDMAP]:
|
||||
finish_at = document.process_begin_at + timedelta(seconds=document.process_duration)
|
||||
if task_type == PipelineTaskType.GRAPH_RAG:
|
||||
KnowledgebaseService.update_by_id(
|
||||
document.kb_id,
|
||||
{"graphrag_task_finish_at": finish_at},
|
||||
)
|
||||
elif task_type == PipelineTaskType.RAPTOR:
|
||||
KnowledgebaseService.update_by_id(
|
||||
document.kb_id,
|
||||
{"raptor_task_finish_at": finish_at},
|
||||
)
|
||||
elif task_type == PipelineTaskType.MINDMAP:
|
||||
KnowledgebaseService.update_by_id(
|
||||
document.kb_id,
|
||||
{"mindmap_task_finish_at": finish_at},
|
||||
)
|
||||
|
||||
log = dict(
|
||||
id=get_uuid(),
|
||||
document_id=document_id, # GRAPH_RAPTOR_FAKE_DOC_ID or real document_id
|
||||
tenant_id=tenant_id,
|
||||
kb_id=document.kb_id,
|
||||
pipeline_id=pipeline_id,
|
||||
pipeline_title=title,
|
||||
parser_id=document.parser_id,
|
||||
document_name=document.name,
|
||||
document_suffix=document.suffix,
|
||||
document_type=document.type,
|
||||
source_from="", # TODO: add in the future
|
||||
progress=document.progress,
|
||||
progress_msg=document.progress_msg,
|
||||
process_begin_at=document.process_begin_at,
|
||||
process_duration=document.process_duration,
|
||||
dsl=json.loads(dsl),
|
||||
task_type=task_type,
|
||||
operation_status=operation_status,
|
||||
avatar=avatar,
|
||||
)
|
||||
log["create_time"] = current_timestamp()
|
||||
log["create_date"] = datetime_format(datetime.now())
|
||||
log["update_time"] = current_timestamp()
|
||||
log["update_date"] = datetime_format(datetime.now())
|
||||
|
||||
with DB.atomic():
|
||||
obj = cls.save(**log)
|
||||
|
||||
limit = int(os.getenv("PIPELINE_OPERATION_LOG_LIMIT", 1000))
|
||||
total = cls.model.select().where(cls.model.kb_id == document.kb_id).count()
|
||||
|
||||
if total > limit:
|
||||
keep_ids = [m.id for m in cls.model.select(cls.model.id).where(cls.model.kb_id == document.kb_id).order_by(cls.model.create_time.desc()).limit(limit)]
|
||||
|
||||
deleted = cls.model.delete().where(cls.model.kb_id == document.kb_id, cls.model.id.not_in(keep_ids)).execute()
|
||||
logging.info(f"[PipelineOperationLogService] Cleaned {deleted} old logs, kept latest {limit} for {document.kb_id}")
|
||||
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def record_pipeline_operation(cls, document_id, pipeline_id, task_type, fake_document_ids=[]):
|
||||
return cls.create(document_id=document_id, pipeline_id=pipeline_id, task_type=task_type, fake_document_ids=fake_document_ids)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_file_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix, create_date_from=None, create_date_to=None):
|
||||
fields = cls.get_file_logs_fields()
|
||||
if keywords:
|
||||
logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.document_name).contains(keywords.lower())))
|
||||
else:
|
||||
logs = cls.model.select(*fields).where(cls.model.kb_id == kb_id)
|
||||
|
||||
logs = logs.where(cls.model.document_id != GRAPH_RAPTOR_FAKE_DOC_ID)
|
||||
|
||||
if operation_status:
|
||||
logs = logs.where(cls.model.operation_status.in_(operation_status))
|
||||
if types:
|
||||
logs = logs.where(cls.model.document_type.in_(types))
|
||||
if suffix:
|
||||
logs = logs.where(cls.model.document_suffix.in_(suffix))
|
||||
if create_date_from:
|
||||
logs = logs.where(cls.model.create_date >= create_date_from)
|
||||
if create_date_to:
|
||||
logs = logs.where(cls.model.create_date <= create_date_to)
|
||||
|
||||
count = logs.count()
|
||||
if desc:
|
||||
logs = logs.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
logs = logs.order_by(cls.model.getter_by(orderby).asc())
|
||||
|
||||
if page_number and items_per_page:
|
||||
logs = logs.paginate(page_number, items_per_page)
|
||||
|
||||
return list(logs.dicts()), count
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_documents_info(cls, id):
|
||||
fields = [Document.id, Document.name, Document.progress, Document.kb_id]
|
||||
return (
|
||||
cls.model.select(*fields)
|
||||
.join(Document, on=(cls.model.document_id == Document.id))
|
||||
.where(
|
||||
cls.model.id == id
|
||||
)
|
||||
.dicts()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_dataset_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, operation_status, create_date_from=None, create_date_to=None):
|
||||
fields = cls.get_dataset_logs_fields()
|
||||
logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (cls.model.document_id == GRAPH_RAPTOR_FAKE_DOC_ID))
|
||||
|
||||
if operation_status:
|
||||
logs = logs.where(cls.model.operation_status.in_(operation_status))
|
||||
if create_date_from:
|
||||
logs = logs.where(cls.model.create_date >= create_date_from)
|
||||
if create_date_to:
|
||||
logs = logs.where(cls.model.create_date <= create_date_to)
|
||||
|
||||
count = logs.count()
|
||||
if desc:
|
||||
logs = logs.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
logs = logs.order_by(cls.model.getter_by(orderby).asc())
|
||||
|
||||
if page_number and items_per_page:
|
||||
logs = logs.paginate(page_number, items_per_page)
|
||||
|
||||
return list(logs.dicts()), count
|
||||
@ -110,3 +110,8 @@ class SearchService(CommonService):
|
||||
query = query.paginate(page_number, items_per_page)
|
||||
|
||||
return list(query.dicts()), count
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_tenant_id(cls, tenant_id):
|
||||
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
|
||||
|
||||
@ -35,6 +35,8 @@ from rag.utils.redis_conn import REDIS_CONN
|
||||
from api import settings
|
||||
from rag.nlp import search
|
||||
|
||||
CANVAS_DEBUG_DOC_ID = "dataflow_x"
|
||||
GRAPH_RAPTOR_FAKE_DOC_ID = "graph_raptor_x"
|
||||
|
||||
def trim_header_by_lines(text: str, max_length) -> str:
|
||||
# Trim header text to maximum length while preserving line breaks
|
||||
@ -54,15 +56,15 @@ def trim_header_by_lines(text: str, max_length) -> str:
|
||||
|
||||
class TaskService(CommonService):
|
||||
"""Service class for managing document processing tasks.
|
||||
|
||||
|
||||
This class extends CommonService to provide specialized functionality for document
|
||||
processing task management, including task creation, progress tracking, and chunk
|
||||
management. It handles various document types (PDF, Excel, etc.) and manages their
|
||||
processing lifecycle.
|
||||
|
||||
|
||||
The class implements a robust task queue system with retry mechanisms and progress
|
||||
tracking, supporting both synchronous and asynchronous task execution.
|
||||
|
||||
|
||||
Attributes:
|
||||
model: The Task model class for database operations.
|
||||
"""
|
||||
@ -70,20 +72,24 @@ class TaskService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_task(cls, task_id):
|
||||
def get_task(cls, task_id, doc_ids=[]):
|
||||
"""Retrieve detailed task information by task ID.
|
||||
|
||||
|
||||
This method fetches comprehensive task details including associated document,
|
||||
knowledge base, and tenant information. It also handles task retry logic and
|
||||
progress updates.
|
||||
|
||||
|
||||
Args:
|
||||
task_id (str): The unique identifier of the task to retrieve.
|
||||
|
||||
|
||||
Returns:
|
||||
dict: Task details dictionary containing all task information and related metadata.
|
||||
Returns None if task is not found or has exceeded retry limit.
|
||||
"""
|
||||
doc_id = cls.model.doc_id
|
||||
if doc_id == CANVAS_DEBUG_DOC_ID and doc_ids:
|
||||
doc_id = doc_ids[0]
|
||||
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.doc_id,
|
||||
@ -109,7 +115,7 @@ class TaskService(CommonService):
|
||||
]
|
||||
docs = (
|
||||
cls.model.select(*fields)
|
||||
.join(Document, on=(cls.model.doc_id == Document.id))
|
||||
.join(Document, on=(doc_id == Document.id))
|
||||
.join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id))
|
||||
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
|
||||
.where(cls.model.id == task_id)
|
||||
@ -139,13 +145,13 @@ class TaskService(CommonService):
|
||||
@DB.connection_context()
|
||||
def get_tasks(cls, doc_id: str):
|
||||
"""Retrieve all tasks associated with a document.
|
||||
|
||||
|
||||
This method fetches all processing tasks for a given document, ordered by page
|
||||
number and creation time. It includes task progress and chunk information.
|
||||
|
||||
|
||||
Args:
|
||||
doc_id (str): The unique identifier of the document.
|
||||
|
||||
|
||||
Returns:
|
||||
list[dict]: List of task dictionaries containing task details.
|
||||
Returns None if no tasks are found.
|
||||
@ -170,10 +176,10 @@ class TaskService(CommonService):
|
||||
@DB.connection_context()
|
||||
def update_chunk_ids(cls, id: str, chunk_ids: str):
|
||||
"""Update the chunk IDs associated with a task.
|
||||
|
||||
|
||||
This method updates the chunk_ids field of a task, which stores the IDs of
|
||||
processed document chunks in a space-separated string format.
|
||||
|
||||
|
||||
Args:
|
||||
id (str): The unique identifier of the task.
|
||||
chunk_ids (str): Space-separated string of chunk identifiers.
|
||||
@ -184,11 +190,11 @@ class TaskService(CommonService):
|
||||
@DB.connection_context()
|
||||
def get_ongoing_doc_name(cls):
|
||||
"""Get names of documents that are currently being processed.
|
||||
|
||||
|
||||
This method retrieves information about documents that are in the processing state,
|
||||
including their locations and associated IDs. It uses database locking to ensure
|
||||
thread safety when accessing the task information.
|
||||
|
||||
|
||||
Returns:
|
||||
list[tuple]: A list of tuples, each containing (parent_id/kb_id, location)
|
||||
for documents currently being processed. Returns empty list if
|
||||
@ -238,14 +244,14 @@ class TaskService(CommonService):
|
||||
@DB.connection_context()
|
||||
def do_cancel(cls, id):
|
||||
"""Check if a task should be cancelled based on its document status.
|
||||
|
||||
|
||||
This method determines whether a task should be cancelled by checking the
|
||||
associated document's run status and progress. A task should be cancelled
|
||||
if its document is marked for cancellation or has negative progress.
|
||||
|
||||
|
||||
Args:
|
||||
id (str): The unique identifier of the task to check.
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True if the task should be cancelled, False otherwise.
|
||||
"""
|
||||
@ -292,37 +298,45 @@ class TaskService(CommonService):
|
||||
((prog == -1) | (prog > cls.model.progress))
|
||||
)
|
||||
).execute()
|
||||
return
|
||||
else:
|
||||
with DB.lock("update_progress", -1):
|
||||
if info["progress_msg"]:
|
||||
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000)
|
||||
cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
|
||||
if "progress" in info:
|
||||
prog = info["progress"]
|
||||
cls.model.update(progress=prog).where(
|
||||
(cls.model.id == id) &
|
||||
(
|
||||
(cls.model.progress != -1) &
|
||||
((prog == -1) | (prog > cls.model.progress))
|
||||
)
|
||||
).execute()
|
||||
|
||||
with DB.lock("update_progress", -1):
|
||||
if info["progress_msg"]:
|
||||
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000)
|
||||
cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
|
||||
if "progress" in info:
|
||||
prog = info["progress"]
|
||||
cls.model.update(progress=prog).where(
|
||||
(cls.model.id == id) &
|
||||
(
|
||||
(cls.model.progress != -1) &
|
||||
((prog == -1) | (prog > cls.model.progress))
|
||||
)
|
||||
).execute()
|
||||
process_duration = (datetime.now() - task.begin_at).total_seconds()
|
||||
cls.model.update(process_duration=process_duration).where(cls.model.id == id).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_doc_ids(cls, doc_ids):
|
||||
"""Delete task associated with a document."""
|
||||
return cls.model.delete().where(cls.model.doc_id.in_(doc_ids)).execute()
|
||||
|
||||
|
||||
def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
|
||||
"""Create and queue document processing tasks.
|
||||
|
||||
|
||||
This function creates processing tasks for a document based on its type and configuration.
|
||||
It handles different document types (PDF, Excel, etc.) differently and manages task
|
||||
chunking and configuration. It also implements task reuse optimization by checking
|
||||
for previously completed tasks.
|
||||
|
||||
|
||||
Args:
|
||||
doc (dict): Document dictionary containing metadata and configuration.
|
||||
bucket (str): Storage bucket name where the document is stored.
|
||||
name (str): File name of the document.
|
||||
priority (int, optional): Priority level for task queueing (default is 0).
|
||||
|
||||
|
||||
Note:
|
||||
- For PDF documents, tasks are created per page range based on configuration
|
||||
- For Excel documents, tasks are created per row range
|
||||
@ -330,7 +344,14 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
|
||||
- Previous task chunks may be reused if available
|
||||
"""
|
||||
def new_task():
|
||||
return {"id": get_uuid(), "doc_id": doc["id"], "progress": 0.0, "from_page": 0, "to_page": 100000000}
|
||||
return {
|
||||
"id": get_uuid(),
|
||||
"doc_id": doc["id"],
|
||||
"progress": 0.0,
|
||||
"from_page": 0,
|
||||
"to_page": 100000000,
|
||||
"begin_at": datetime.now(),
|
||||
}
|
||||
|
||||
parse_task_array = []
|
||||
|
||||
@ -343,7 +364,7 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
|
||||
page_size = doc["parser_config"].get("task_page_size") or 12
|
||||
if doc["parser_id"] == "paper":
|
||||
page_size = doc["parser_config"].get("task_page_size") or 22
|
||||
if doc["parser_id"] in ["one", "knowledge_graph"] or do_layout != "DeepDOC":
|
||||
if doc["parser_id"] in ["one", "knowledge_graph"] or do_layout != "DeepDOC" or doc["parser_config"].get("toc", True):
|
||||
page_size = 10 ** 9
|
||||
page_ranges = doc["parser_config"].get("pages") or [(1, 10 ** 5)]
|
||||
for s, e in page_ranges:
|
||||
@ -410,19 +431,19 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
|
||||
|
||||
def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict):
|
||||
"""Attempt to reuse chunks from previous tasks for optimization.
|
||||
|
||||
|
||||
This function checks if chunks from previously completed tasks can be reused for
|
||||
the current task, which can significantly improve processing efficiency. It matches
|
||||
tasks based on page ranges and configuration digests.
|
||||
|
||||
|
||||
Args:
|
||||
task (dict): Current task dictionary to potentially reuse chunks for.
|
||||
prev_tasks (list[dict]): List of previous task dictionaries to check for reuse.
|
||||
chunking_config (dict): Configuration dictionary for chunk processing.
|
||||
|
||||
|
||||
Returns:
|
||||
int: Number of chunks successfully reused. Returns 0 if no chunks could be reused.
|
||||
|
||||
|
||||
Note:
|
||||
Chunks can only be reused if:
|
||||
- A previous task exists with matching page range and configuration digest
|
||||
@ -470,3 +491,32 @@ def has_canceled(task_id):
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return False
|
||||
|
||||
|
||||
def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str=CANVAS_DEBUG_DOC_ID, file:dict=None, priority: int=0, rerun:bool=False) -> tuple[bool, str]:
|
||||
|
||||
task = dict(
|
||||
id=task_id,
|
||||
doc_id=doc_id,
|
||||
from_page=0,
|
||||
to_page=100000000,
|
||||
task_type="dataflow" if not rerun else "dataflow_rerun",
|
||||
priority=priority,
|
||||
begin_at=datetime.now(),
|
||||
)
|
||||
if doc_id not in [CANVAS_DEBUG_DOC_ID, GRAPH_RAPTOR_FAKE_DOC_ID]:
|
||||
TaskService.model.delete().where(TaskService.model.doc_id == doc_id).execute()
|
||||
DocumentService.begin2parse(doc_id)
|
||||
bulk_insert_into_db(model=Task, data_source=[task], replace_on_conflict=True)
|
||||
|
||||
task["kb_id"] = DocumentService.get_knowledgebase_id(doc_id)
|
||||
task["tenant_id"] = tenant_id
|
||||
task["dataflow_id"] = flow_id
|
||||
task["file"] = file
|
||||
|
||||
if not REDIS_CONN.queue_product(
|
||||
get_svr_queue_name(priority), message=task
|
||||
):
|
||||
return False, "Can't access Redis. Please check the Redis' status."
|
||||
|
||||
return True, ""
|
||||
|
||||
@ -209,6 +209,11 @@ class TenantLLMService(CommonService):
|
||||
objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
|
||||
return list(objs)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_tenant_id(cls, tenant_id):
|
||||
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
|
||||
|
||||
@staticmethod
|
||||
def llm_id2llm_type(llm_id: str) -> str | None:
|
||||
from api.db.services.llm_service import LLMService
|
||||
|
||||
@ -24,7 +24,24 @@ class UserCanvasVersionService(CommonService):
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_canvas_version_by_canvas_ids(cls, canvas_ids):
|
||||
fields = [cls.model.id]
|
||||
versions = cls.model.select(*fields).where(cls.model.user_canvas_id.in_(canvas_ids))
|
||||
versions.order_by(cls.model.create_time.asc())
|
||||
offset, limit = 0, 100
|
||||
res = []
|
||||
while True:
|
||||
version_batch = versions.offset(offset).limit(limit)
|
||||
_temp = list(version_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_all_versions(cls, user_canvas_id):
|
||||
|
||||
@ -45,22 +45,22 @@ class UserService(CommonService):
|
||||
def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
|
||||
if 'access_token' in kwargs:
|
||||
access_token = kwargs['access_token']
|
||||
|
||||
|
||||
# Reject empty, None, or whitespace-only access tokens
|
||||
if not access_token or not str(access_token).strip():
|
||||
logging.warning("UserService.query: Rejecting empty access_token query")
|
||||
return cls.model.select().where(cls.model.id == "INVALID_EMPTY_TOKEN") # Returns empty result
|
||||
|
||||
|
||||
# Reject tokens that are too short (should be UUID, 32+ chars)
|
||||
if len(str(access_token).strip()) < 32:
|
||||
logging.warning(f"UserService.query: Rejecting short access_token query: {len(str(access_token))} chars")
|
||||
return cls.model.select().where(cls.model.id == "INVALID_SHORT_TOKEN") # Returns empty result
|
||||
|
||||
|
||||
# Reject tokens that start with "INVALID_" (from logout)
|
||||
if str(access_token).startswith("INVALID_"):
|
||||
logging.warning("UserService.query: Rejecting invalidated access_token")
|
||||
return cls.model.select().where(cls.model.id == "INVALID_LOGOUT_TOKEN") # Returns empty result
|
||||
|
||||
|
||||
# Call parent query method for valid requests
|
||||
return super().query(cols=cols, reverse=reverse, order_by=order_by, **kwargs)
|
||||
|
||||
@ -100,6 +100,12 @@ class UserService(CommonService):
|
||||
else:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def query_user_by_email(cls, email):
|
||||
users = cls.model.select().where((cls.model.email == email))
|
||||
return list(users)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def save(cls, **kwargs):
|
||||
@ -133,6 +139,17 @@ class UserService(CommonService):
|
||||
cls.model.update(user_dict).where(
|
||||
cls.model.id == user_id).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_user_password(cls, user_id, new_password):
|
||||
with DB.atomic():
|
||||
update_dict = {
|
||||
"password": generate_password_hash(str(new_password)),
|
||||
"update_time": current_timestamp(),
|
||||
"update_date": datetime_format(datetime.now())
|
||||
}
|
||||
cls.model.update(update_dict).where(cls.model.id == user_id).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def is_admin(cls, user_id):
|
||||
@ -140,6 +157,12 @@ class UserService(CommonService):
|
||||
cls.model.id == user_id,
|
||||
cls.model.is_superuser == 1).count() > 0
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_users(cls):
|
||||
users = cls.model.select()
|
||||
return list(users)
|
||||
|
||||
|
||||
class TenantService(CommonService):
|
||||
"""Service class for managing tenant-related database operations.
|
||||
@ -265,6 +288,17 @@ class UserTenantService(CommonService):
|
||||
.join(User, on=((cls.model.tenant_id == User.id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value)))
|
||||
.where(cls.model.status == StatusEnum.VALID.value).dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_user_tenant_relation_by_user_id(cls, user_id):
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.user_id,
|
||||
cls.model.tenant_id,
|
||||
cls.model.role
|
||||
]
|
||||
return list(cls.model.select(*fields).where(cls.model.user_id == user_id).dicts().dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_num_members(cls, user_id: str):
|
||||
|
||||
@ -41,7 +41,7 @@ from api import utils
|
||||
from api.db.db_models import init_database_tables as init_web_db
|
||||
from api.db.init_data import init_web_data
|
||||
from api.versions import get_ragflow_version
|
||||
from api.utils import show_configs
|
||||
from api.utils.configs import show_configs
|
||||
from rag.settings import print_rag_settings
|
||||
from rag.utils.mcp_tool_call_conn import shutdown_all_mcp_sessions
|
||||
from rag.utils.redis_conn import RedisDistributedLock
|
||||
|
||||
@ -24,7 +24,7 @@ import rag.utils.es_conn
|
||||
import rag.utils.infinity_conn
|
||||
import rag.utils.opensearch_conn
|
||||
from api.constants import RAG_FLOW_SERVICE_NAME
|
||||
from api.utils import decrypt_database_config, get_base_config
|
||||
from api.utils.configs import decrypt_database_config, get_base_config
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
from rag.nlp import search
|
||||
|
||||
|
||||
@ -16,184 +16,15 @@
|
||||
import base64
|
||||
import datetime
|
||||
import hashlib
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import socket
|
||||
import time
|
||||
import uuid
|
||||
import requests
|
||||
import logging
|
||||
import copy
|
||||
from enum import Enum, IntEnum
|
||||
|
||||
import importlib
|
||||
from Cryptodome.PublicKey import RSA
|
||||
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
|
||||
from filelock import FileLock
|
||||
from api.constants import SERVICE_CONF
|
||||
|
||||
from . import file_utils
|
||||
|
||||
|
||||
def conf_realpath(conf_name):
|
||||
conf_path = f"conf/{conf_name}"
|
||||
return os.path.join(file_utils.get_project_base_directory(), conf_path)
|
||||
|
||||
|
||||
def read_config(conf_name=SERVICE_CONF):
|
||||
local_config = {}
|
||||
local_path = conf_realpath(f'local.{conf_name}')
|
||||
|
||||
# load local config file
|
||||
if os.path.exists(local_path):
|
||||
local_config = file_utils.load_yaml_conf(local_path)
|
||||
if not isinstance(local_config, dict):
|
||||
raise ValueError(f'Invalid config file: "{local_path}".')
|
||||
|
||||
global_config_path = conf_realpath(conf_name)
|
||||
global_config = file_utils.load_yaml_conf(global_config_path)
|
||||
|
||||
if not isinstance(global_config, dict):
|
||||
raise ValueError(f'Invalid config file: "{global_config_path}".')
|
||||
|
||||
global_config.update(local_config)
|
||||
return global_config
|
||||
|
||||
|
||||
CONFIGS = read_config()
|
||||
|
||||
|
||||
def show_configs():
|
||||
msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:"
|
||||
for k, v in CONFIGS.items():
|
||||
if isinstance(v, dict):
|
||||
if "password" in v:
|
||||
v = copy.deepcopy(v)
|
||||
v["password"] = "*" * 8
|
||||
if "access_key" in v:
|
||||
v = copy.deepcopy(v)
|
||||
v["access_key"] = "*" * 8
|
||||
if "secret_key" in v:
|
||||
v = copy.deepcopy(v)
|
||||
v["secret_key"] = "*" * 8
|
||||
if "secret" in v:
|
||||
v = copy.deepcopy(v)
|
||||
v["secret"] = "*" * 8
|
||||
if "sas_token" in v:
|
||||
v = copy.deepcopy(v)
|
||||
v["sas_token"] = "*" * 8
|
||||
if "oauth" in k:
|
||||
v = copy.deepcopy(v)
|
||||
for key, val in v.items():
|
||||
if "client_secret" in val:
|
||||
val["client_secret"] = "*" * 8
|
||||
if "authentication" in k:
|
||||
v = copy.deepcopy(v)
|
||||
for key, val in v.items():
|
||||
if "http_secret_key" in val:
|
||||
val["http_secret_key"] = "*" * 8
|
||||
msg += f"\n\t{k}: {v}"
|
||||
logging.info(msg)
|
||||
|
||||
|
||||
def get_base_config(key, default=None):
|
||||
if key is None:
|
||||
return None
|
||||
if default is None:
|
||||
default = os.environ.get(key.upper())
|
||||
return CONFIGS.get(key, default)
|
||||
|
||||
|
||||
use_deserialize_safe_module = get_base_config(
|
||||
'use_deserialize_safe_module', False)
|
||||
|
||||
|
||||
class BaseType:
|
||||
def to_dict(self):
|
||||
return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()])
|
||||
|
||||
def to_dict_with_type(self):
|
||||
def _dict(obj):
|
||||
module = None
|
||||
if issubclass(obj.__class__, BaseType):
|
||||
data = {}
|
||||
for attr, v in obj.__dict__.items():
|
||||
k = attr.lstrip("_")
|
||||
data[k] = _dict(v)
|
||||
module = obj.__module__
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
data = []
|
||||
for i, vv in enumerate(obj):
|
||||
data.append(_dict(vv))
|
||||
elif isinstance(obj, dict):
|
||||
data = {}
|
||||
for _k, vv in obj.items():
|
||||
data[_k] = _dict(vv)
|
||||
else:
|
||||
data = obj
|
||||
return {"type": obj.__class__.__name__,
|
||||
"data": data, "module": module}
|
||||
|
||||
return _dict(self)
|
||||
|
||||
|
||||
class CustomJSONEncoder(json.JSONEncoder):
|
||||
def __init__(self, **kwargs):
|
||||
self._with_type = kwargs.pop("with_type", False)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def default(self, obj):
|
||||
if isinstance(obj, datetime.datetime):
|
||||
return obj.strftime('%Y-%m-%d %H:%M:%S')
|
||||
elif isinstance(obj, datetime.date):
|
||||
return obj.strftime('%Y-%m-%d')
|
||||
elif isinstance(obj, datetime.timedelta):
|
||||
return str(obj)
|
||||
elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum):
|
||||
return obj.value
|
||||
elif isinstance(obj, set):
|
||||
return list(obj)
|
||||
elif issubclass(type(obj), BaseType):
|
||||
if not self._with_type:
|
||||
return obj.to_dict()
|
||||
else:
|
||||
return obj.to_dict_with_type()
|
||||
elif isinstance(obj, type):
|
||||
return obj.__name__
|
||||
else:
|
||||
return json.JSONEncoder.default(self, obj)
|
||||
|
||||
|
||||
def rag_uuid():
|
||||
return uuid.uuid1().hex
|
||||
|
||||
|
||||
def string_to_bytes(string):
|
||||
return string if isinstance(
|
||||
string, bytes) else string.encode(encoding="utf-8")
|
||||
|
||||
|
||||
def bytes_to_string(byte):
|
||||
return byte.decode(encoding="utf-8")
|
||||
|
||||
|
||||
def json_dumps(src, byte=False, indent=None, with_type=False):
|
||||
dest = json.dumps(
|
||||
src,
|
||||
indent=indent,
|
||||
cls=CustomJSONEncoder,
|
||||
with_type=with_type)
|
||||
if byte:
|
||||
dest = string_to_bytes(dest)
|
||||
return dest
|
||||
|
||||
|
||||
def json_loads(src, object_hook=None, object_pairs_hook=None):
|
||||
if isinstance(src, bytes):
|
||||
src = bytes_to_string(src)
|
||||
return json.loads(src, object_hook=object_hook,
|
||||
object_pairs_hook=object_pairs_hook)
|
||||
from .common import string_to_bytes
|
||||
|
||||
|
||||
def current_timestamp():
|
||||
@ -215,45 +46,6 @@ def date_string_to_timestamp(time_str, format_string="%Y-%m-%d %H:%M:%S"):
|
||||
return time_stamp
|
||||
|
||||
|
||||
def serialize_b64(src, to_str=False):
|
||||
dest = base64.b64encode(pickle.dumps(src))
|
||||
if not to_str:
|
||||
return dest
|
||||
else:
|
||||
return bytes_to_string(dest)
|
||||
|
||||
|
||||
def deserialize_b64(src):
|
||||
src = base64.b64decode(
|
||||
string_to_bytes(src) if isinstance(
|
||||
src, str) else src)
|
||||
if use_deserialize_safe_module:
|
||||
return restricted_loads(src)
|
||||
return pickle.loads(src)
|
||||
|
||||
|
||||
safe_module = {
|
||||
'numpy',
|
||||
'rag_flow'
|
||||
}
|
||||
|
||||
|
||||
class RestrictedUnpickler(pickle.Unpickler):
|
||||
def find_class(self, module, name):
|
||||
import importlib
|
||||
if module.split('.')[0] in safe_module:
|
||||
_module = importlib.import_module(module)
|
||||
return getattr(_module, name)
|
||||
# Forbid everything else.
|
||||
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
|
||||
(module, name))
|
||||
|
||||
|
||||
def restricted_loads(src):
|
||||
"""Helper function analogous to pickle.loads()."""
|
||||
return RestrictedUnpickler(io.BytesIO(src)).load()
|
||||
|
||||
|
||||
def get_lan_ip():
|
||||
if os.name != "nt":
|
||||
import fcntl
|
||||
@ -298,47 +90,6 @@ def from_dict_hook(in_dict: dict):
|
||||
return in_dict
|
||||
|
||||
|
||||
def decrypt_database_password(password):
|
||||
encrypt_password = get_base_config("encrypt_password", False)
|
||||
encrypt_module = get_base_config("encrypt_module", False)
|
||||
private_key = get_base_config("private_key", None)
|
||||
|
||||
if not password or not encrypt_password:
|
||||
return password
|
||||
|
||||
if not private_key:
|
||||
raise ValueError("No private key")
|
||||
|
||||
module_fun = encrypt_module.split("#")
|
||||
pwdecrypt_fun = getattr(
|
||||
importlib.import_module(
|
||||
module_fun[0]),
|
||||
module_fun[1])
|
||||
|
||||
return pwdecrypt_fun(private_key, password)
|
||||
|
||||
|
||||
def decrypt_database_config(
|
||||
database=None, passwd_key="password", name="database"):
|
||||
if not database:
|
||||
database = get_base_config(name, {})
|
||||
|
||||
database[passwd_key] = decrypt_database_password(database[passwd_key])
|
||||
return database
|
||||
|
||||
|
||||
def update_config(key, value, conf_name=SERVICE_CONF):
|
||||
conf_path = conf_realpath(conf_name=conf_name)
|
||||
if not os.path.isabs(conf_path):
|
||||
conf_path = os.path.join(
|
||||
file_utils.get_project_base_directory(), conf_path)
|
||||
|
||||
with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
|
||||
config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
|
||||
config[key] = value
|
||||
file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
|
||||
|
||||
|
||||
def get_uuid():
|
||||
return uuid.uuid1().hex
|
||||
|
||||
@ -363,37 +114,6 @@ def elapsed2time(elapsed):
|
||||
return '%02d:%02d:%02d' % (hour, minuter, second)
|
||||
|
||||
|
||||
def decrypt(line):
|
||||
file_path = os.path.join(
|
||||
file_utils.get_project_base_directory(),
|
||||
"conf",
|
||||
"private.pem")
|
||||
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
|
||||
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
|
||||
return cipher.decrypt(base64.b64decode(
|
||||
line), "Fail to decrypt password!").decode('utf-8')
|
||||
|
||||
|
||||
def decrypt2(crypt_text):
|
||||
from base64 import b64decode, b16decode
|
||||
from Crypto.Cipher import PKCS1_v1_5 as Cipher_PKCS1_v1_5
|
||||
from Crypto.PublicKey import RSA
|
||||
decode_data = b64decode(crypt_text)
|
||||
if len(decode_data) == 127:
|
||||
hex_fixed = '00' + decode_data.hex()
|
||||
decode_data = b16decode(hex_fixed.upper())
|
||||
|
||||
file_path = os.path.join(
|
||||
file_utils.get_project_base_directory(),
|
||||
"conf",
|
||||
"private.pem")
|
||||
pem = open(file_path).read()
|
||||
rsa_key = RSA.importKey(pem, "Welcome")
|
||||
cipher = Cipher_PKCS1_v1_5.new(rsa_key)
|
||||
decrypt_text = cipher.decrypt(decode_data, None)
|
||||
return (b64decode(decrypt_text)).decode()
|
||||
|
||||
|
||||
def download_img(url):
|
||||
if not url:
|
||||
return ""
|
||||
@ -408,5 +128,5 @@ def delta_seconds(date_string: str):
|
||||
return (datetime.datetime.now() - dt).total_seconds()
|
||||
|
||||
|
||||
def hash_str2int(line:str, mod: int=10 ** 8) -> int:
|
||||
return int(hashlib.sha1(line.encode("utf-8")).hexdigest(), 16) % mod
|
||||
def hash_str2int(line: str, mod: int = 10 ** 8) -> int:
|
||||
return int(hashlib.sha1(line.encode("utf-8")).hexdigest(), 16) % mod
|
||||
|
||||
@ -39,6 +39,7 @@ from flask import (
|
||||
make_response,
|
||||
send_file,
|
||||
)
|
||||
from flask_login import current_user
|
||||
from flask import (
|
||||
request as flask_request,
|
||||
)
|
||||
@ -48,14 +49,41 @@ from werkzeug.http import HTTP_STATUS_CODES
|
||||
|
||||
from api import settings
|
||||
from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC
|
||||
from api.db import ActiveEnum
|
||||
from api.db.db_models import APIToken
|
||||
from api.db.services import UserService
|
||||
from api.db.services.llm_service import LLMService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.utils import CustomJSONEncoder, get_uuid, json_dumps
|
||||
from api.utils.json import CustomJSONEncoder, json_dumps
|
||||
from api.utils import get_uuid
|
||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
||||
|
||||
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
|
||||
|
||||
def serialize_for_json(obj):
|
||||
"""
|
||||
Recursively serialize objects to make them JSON serializable.
|
||||
Handles ModelMetaclass and other non-serializable objects.
|
||||
"""
|
||||
if hasattr(obj, '__dict__'):
|
||||
# For objects with __dict__, try to serialize their attributes
|
||||
try:
|
||||
return {key: serialize_for_json(value) for key, value in obj.__dict__.items()
|
||||
if not key.startswith('_')}
|
||||
except (AttributeError, TypeError):
|
||||
return str(obj)
|
||||
elif hasattr(obj, '__name__'):
|
||||
# For classes and metaclasses, return their name
|
||||
return f"<{obj.__module__}.{obj.__name__}>" if hasattr(obj, '__module__') else f"<{obj.__name__}>"
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
return [serialize_for_json(item) for item in obj]
|
||||
elif isinstance(obj, dict):
|
||||
return {key: serialize_for_json(value) for key, value in obj.items()}
|
||||
elif isinstance(obj, (str, int, float, bool)) or obj is None:
|
||||
return obj
|
||||
else:
|
||||
# Fallback: convert to string representation
|
||||
return str(obj)
|
||||
|
||||
def request(**kwargs):
|
||||
sess = requests.Session()
|
||||
@ -128,7 +156,11 @@ def server_error_response(e):
|
||||
except BaseException:
|
||||
pass
|
||||
if len(e.args) > 1:
|
||||
return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
|
||||
try:
|
||||
serialized_data = serialize_for_json(e.args[1])
|
||||
return get_json_result(code= settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=serialized_data)
|
||||
except Exception:
|
||||
return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=None)
|
||||
if repr(e).find("index_not_found_exception") >= 0:
|
||||
return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")
|
||||
|
||||
@ -198,6 +230,18 @@ def not_allowed_parameters(*params):
|
||||
return decorator
|
||||
|
||||
|
||||
def active_required(f):
|
||||
@wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
user_id = current_user.id
|
||||
usr = UserService.filter_by_id(user_id)
|
||||
# check is_active
|
||||
if not usr or not usr.is_active == ActiveEnum.ACTIVE.value:
|
||||
return get_json_result(code=settings.RetCode.FORBIDDEN, message="User isn't active, please activate first.")
|
||||
return f(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
|
||||
def is_localhost(ip):
|
||||
return ip in {"127.0.0.1", "::1", "[::1]", "localhost"}
|
||||
|
||||
@ -292,6 +336,8 @@ def construct_error_response(e):
|
||||
def token_required(func):
|
||||
@wraps(func)
|
||||
def decorated_function(*args, **kwargs):
|
||||
if os.environ.get("DISABLE_SDK"):
|
||||
return get_json_result(data=False, message="`Authorization` can't be empty")
|
||||
authorization_str = flask_request.headers.get("Authorization")
|
||||
if not authorization_str:
|
||||
return get_json_result(data=False, message="`Authorization` can't be empty")
|
||||
@ -613,6 +659,16 @@ def remap_dictionary_keys(source_data: dict, key_aliases: dict = None) -> dict:
|
||||
return transformed_data
|
||||
|
||||
|
||||
def group_by(list_of_dict, key):
|
||||
res = {}
|
||||
for item in list_of_dict:
|
||||
if item[key] in res.keys():
|
||||
res[item[key]].append(item)
|
||||
else:
|
||||
res[item[key]] = [item]
|
||||
return res
|
||||
|
||||
|
||||
def get_mcp_tools(mcp_servers: list, timeout: float | int = 10) -> tuple[dict, str]:
|
||||
results = {}
|
||||
tool_call_sessions = []
|
||||
@ -649,7 +705,9 @@ TimeoutException = Union[Type[BaseException], BaseException]
|
||||
OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]]
|
||||
|
||||
|
||||
def timeout(seconds: float | int = None, attempts: int = 2, *, exception: Optional[TimeoutException] = None, on_timeout: Optional[OnTimeoutCallback] = None):
|
||||
def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception: Optional[TimeoutException] = None, on_timeout: Optional[OnTimeoutCallback] = None):
|
||||
if isinstance(seconds, str):
|
||||
seconds = float(seconds)
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
|
||||
@ -1,3 +1,56 @@
|
||||
import base64
|
||||
import logging
|
||||
from functools import partial
|
||||
from io import BytesIO
|
||||
|
||||
from PIL import Image
|
||||
|
||||
test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3QwQ3AIBDAsIP9d25XIC+EZE8QZc18w5l9O+AlZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBT+IYAHHLHkdEgAAAABJRU5ErkJggg=="
|
||||
test_image = base64.b64decode(test_image_base64)
|
||||
test_image = base64.b64decode(test_image_base64)
|
||||
|
||||
|
||||
async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="imagetemps"):
|
||||
import logging
|
||||
from io import BytesIO
|
||||
import trio
|
||||
from rag.svr.task_executor import minio_limiter
|
||||
if not d.get("image"):
|
||||
return
|
||||
|
||||
with BytesIO() as output_buffer:
|
||||
if isinstance(d["image"], bytes):
|
||||
output_buffer.write(d["image"])
|
||||
output_buffer.seek(0)
|
||||
else:
|
||||
# If the image is in RGBA mode, convert it to RGB mode before saving it in JPEG format.
|
||||
if d["image"].mode in ("RGBA", "P"):
|
||||
converted_image = d["image"].convert("RGB")
|
||||
d["image"] = converted_image
|
||||
try:
|
||||
d["image"].save(output_buffer, format='JPEG')
|
||||
except OSError as e:
|
||||
logging.warning(
|
||||
"Saving image exception, ignore: {}".format(str(e)))
|
||||
|
||||
async with minio_limiter:
|
||||
await trio.to_thread.run_sync(lambda: storage_put_func(bucket=bucket, fnm=objname, binary=output_buffer.getvalue()))
|
||||
d["img_id"] = f"{bucket}-{objname}"
|
||||
if not isinstance(d["image"], bytes):
|
||||
d["image"].close()
|
||||
del d["image"] # Remove image reference
|
||||
|
||||
|
||||
def id2image(image_id:str|None, storage_get_func: partial):
|
||||
if not image_id:
|
||||
return
|
||||
arr = image_id.split("-")
|
||||
if len(arr) != 2:
|
||||
return
|
||||
bkt, nm = image_id.split("-")
|
||||
try:
|
||||
blob = storage_get_func(bucket=bkt, filename=nm)
|
||||
if not blob:
|
||||
return
|
||||
return Image.open(BytesIO(blob))
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
23
api/utils/common.py
Normal file
23
api/utils/common.py
Normal file
@ -0,0 +1,23 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
def string_to_bytes(string):
|
||||
return string if isinstance(
|
||||
string, bytes) else string.encode(encoding="utf-8")
|
||||
|
||||
|
||||
def bytes_to_string(byte):
|
||||
return byte.decode(encoding="utf-8")
|
||||
179
api/utils/configs.py
Normal file
179
api/utils/configs.py
Normal file
@ -0,0 +1,179 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import os
|
||||
import io
|
||||
import copy
|
||||
import logging
|
||||
import base64
|
||||
import pickle
|
||||
import importlib
|
||||
|
||||
from api.utils import file_utils
|
||||
from filelock import FileLock
|
||||
from api.utils.common import bytes_to_string, string_to_bytes
|
||||
from api.constants import SERVICE_CONF
|
||||
|
||||
|
||||
def conf_realpath(conf_name):
|
||||
conf_path = f"conf/{conf_name}"
|
||||
return os.path.join(file_utils.get_project_base_directory(), conf_path)
|
||||
|
||||
|
||||
def read_config(conf_name=SERVICE_CONF):
|
||||
local_config = {}
|
||||
local_path = conf_realpath(f'local.{conf_name}')
|
||||
|
||||
# load local config file
|
||||
if os.path.exists(local_path):
|
||||
local_config = file_utils.load_yaml_conf(local_path)
|
||||
if not isinstance(local_config, dict):
|
||||
raise ValueError(f'Invalid config file: "{local_path}".')
|
||||
|
||||
global_config_path = conf_realpath(conf_name)
|
||||
global_config = file_utils.load_yaml_conf(global_config_path)
|
||||
|
||||
if not isinstance(global_config, dict):
|
||||
raise ValueError(f'Invalid config file: "{global_config_path}".')
|
||||
|
||||
global_config.update(local_config)
|
||||
return global_config
|
||||
|
||||
|
||||
CONFIGS = read_config()
|
||||
|
||||
|
||||
def show_configs():
|
||||
msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:"
|
||||
for k, v in CONFIGS.items():
|
||||
if isinstance(v, dict):
|
||||
if "password" in v:
|
||||
v = copy.deepcopy(v)
|
||||
v["password"] = "*" * 8
|
||||
if "access_key" in v:
|
||||
v = copy.deepcopy(v)
|
||||
v["access_key"] = "*" * 8
|
||||
if "secret_key" in v:
|
||||
v = copy.deepcopy(v)
|
||||
v["secret_key"] = "*" * 8
|
||||
if "secret" in v:
|
||||
v = copy.deepcopy(v)
|
||||
v["secret"] = "*" * 8
|
||||
if "sas_token" in v:
|
||||
v = copy.deepcopy(v)
|
||||
v["sas_token"] = "*" * 8
|
||||
if "oauth" in k:
|
||||
v = copy.deepcopy(v)
|
||||
for key, val in v.items():
|
||||
if "client_secret" in val:
|
||||
val["client_secret"] = "*" * 8
|
||||
if "authentication" in k:
|
||||
v = copy.deepcopy(v)
|
||||
for key, val in v.items():
|
||||
if "http_secret_key" in val:
|
||||
val["http_secret_key"] = "*" * 8
|
||||
msg += f"\n\t{k}: {v}"
|
||||
logging.info(msg)
|
||||
|
||||
|
||||
def get_base_config(key, default=None):
|
||||
if key is None:
|
||||
return None
|
||||
if default is None:
|
||||
default = os.environ.get(key.upper())
|
||||
return CONFIGS.get(key, default)
|
||||
|
||||
|
||||
def decrypt_database_password(password):
|
||||
encrypt_password = get_base_config("encrypt_password", False)
|
||||
encrypt_module = get_base_config("encrypt_module", False)
|
||||
private_key = get_base_config("private_key", None)
|
||||
|
||||
if not password or not encrypt_password:
|
||||
return password
|
||||
|
||||
if not private_key:
|
||||
raise ValueError("No private key")
|
||||
|
||||
module_fun = encrypt_module.split("#")
|
||||
pwdecrypt_fun = getattr(
|
||||
importlib.import_module(
|
||||
module_fun[0]),
|
||||
module_fun[1])
|
||||
|
||||
return pwdecrypt_fun(private_key, password)
|
||||
|
||||
|
||||
def decrypt_database_config(
|
||||
database=None, passwd_key="password", name="database"):
|
||||
if not database:
|
||||
database = get_base_config(name, {})
|
||||
|
||||
database[passwd_key] = decrypt_database_password(database[passwd_key])
|
||||
return database
|
||||
|
||||
|
||||
def update_config(key, value, conf_name=SERVICE_CONF):
|
||||
conf_path = conf_realpath(conf_name=conf_name)
|
||||
if not os.path.isabs(conf_path):
|
||||
conf_path = os.path.join(
|
||||
file_utils.get_project_base_directory(), conf_path)
|
||||
|
||||
with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
|
||||
config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
|
||||
config[key] = value
|
||||
file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
|
||||
|
||||
|
||||
safe_module = {
|
||||
'numpy',
|
||||
'rag_flow'
|
||||
}
|
||||
|
||||
|
||||
class RestrictedUnpickler(pickle.Unpickler):
|
||||
def find_class(self, module, name):
|
||||
import importlib
|
||||
if module.split('.')[0] in safe_module:
|
||||
_module = importlib.import_module(module)
|
||||
return getattr(_module, name)
|
||||
# Forbid everything else.
|
||||
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
|
||||
(module, name))
|
||||
|
||||
|
||||
def restricted_loads(src):
|
||||
"""Helper function analogous to pickle.loads()."""
|
||||
return RestrictedUnpickler(io.BytesIO(src)).load()
|
||||
|
||||
|
||||
def serialize_b64(src, to_str=False):
|
||||
dest = base64.b64encode(pickle.dumps(src))
|
||||
if not to_str:
|
||||
return dest
|
||||
else:
|
||||
return bytes_to_string(dest)
|
||||
|
||||
|
||||
def deserialize_b64(src):
|
||||
src = base64.b64decode(
|
||||
string_to_bytes(src) if isinstance(
|
||||
src, str) else src)
|
||||
use_deserialize_safe_module = get_base_config(
|
||||
'use_deserialize_safe_module', False)
|
||||
if use_deserialize_safe_module:
|
||||
return restricted_loads(src)
|
||||
return pickle.loads(src)
|
||||
64
api/utils/crypt.py
Normal file
64
api/utils/crypt.py
Normal file
@ -0,0 +1,64 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import base64
|
||||
import os
|
||||
import sys
|
||||
from Cryptodome.PublicKey import RSA
|
||||
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
|
||||
from api.utils import file_utils
|
||||
|
||||
|
||||
def crypt(line):
|
||||
"""
|
||||
decrypt(crypt(input_string)) == base64(input_string), which frontend and admin_client use.
|
||||
"""
|
||||
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "public.pem")
|
||||
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
|
||||
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
|
||||
password_base64 = base64.b64encode(line.encode('utf-8')).decode("utf-8")
|
||||
encrypted_password = cipher.encrypt(password_base64.encode())
|
||||
return base64.b64encode(encrypted_password).decode('utf-8')
|
||||
|
||||
|
||||
def decrypt(line):
|
||||
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "private.pem")
|
||||
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
|
||||
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
|
||||
return cipher.decrypt(base64.b64decode(line), "Fail to decrypt password!").decode('utf-8')
|
||||
|
||||
|
||||
def decrypt2(crypt_text):
|
||||
from base64 import b64decode, b16decode
|
||||
from Crypto.Cipher import PKCS1_v1_5 as Cipher_PKCS1_v1_5
|
||||
from Crypto.PublicKey import RSA
|
||||
decode_data = b64decode(crypt_text)
|
||||
if len(decode_data) == 127:
|
||||
hex_fixed = '00' + decode_data.hex()
|
||||
decode_data = b16decode(hex_fixed.upper())
|
||||
|
||||
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "private.pem")
|
||||
pem = open(file_path).read()
|
||||
rsa_key = RSA.importKey(pem, "Welcome")
|
||||
cipher = Cipher_PKCS1_v1_5.new(rsa_key)
|
||||
decrypt_text = cipher.decrypt(decode_data, None)
|
||||
return (b64decode(decrypt_text)).decode()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
passwd = crypt(sys.argv[1])
|
||||
print(passwd)
|
||||
print(decrypt(passwd))
|
||||
@ -155,7 +155,7 @@ def filename_type(filename):
|
||||
if re.match(r".*\.pdf$", filename):
|
||||
return FileType.PDF.value
|
||||
|
||||
if re.match(r".*\.(eml|doc|docx|ppt|pptx|yml|xml|htm|json|jsonl|ldjson|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|html|sql)$", filename):
|
||||
if re.match(r".*\.(msg|eml|doc|docx|ppt|pptx|yml|xml|htm|json|jsonl|ldjson|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|html|sql)$", filename):
|
||||
return FileType.DOC.value
|
||||
|
||||
if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus)$", filename):
|
||||
|
||||
104
api/utils/health.py
Normal file
104
api/utils/health.py
Normal file
@ -0,0 +1,104 @@
|
||||
from timeit import default_timer as timer
|
||||
|
||||
from api import settings
|
||||
from api.db.db_models import DB
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
|
||||
|
||||
def _ok_nok(ok: bool) -> str:
|
||||
return "ok" if ok else "nok"
|
||||
|
||||
|
||||
def check_db() -> tuple[bool, dict]:
|
||||
st = timer()
|
||||
try:
|
||||
# lightweight probe; works for MySQL/Postgres
|
||||
DB.execute_sql("SELECT 1")
|
||||
return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"}
|
||||
except Exception as e:
|
||||
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
|
||||
|
||||
|
||||
def check_redis() -> tuple[bool, dict]:
|
||||
st = timer()
|
||||
try:
|
||||
ok = bool(REDIS_CONN.health())
|
||||
return ok, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"}
|
||||
except Exception as e:
|
||||
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
|
||||
|
||||
|
||||
def check_doc_engine() -> tuple[bool, dict]:
|
||||
st = timer()
|
||||
try:
|
||||
meta = settings.docStoreConn.health()
|
||||
# treat any successful call as ok
|
||||
return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", **(meta or {})}
|
||||
except Exception as e:
|
||||
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
|
||||
|
||||
|
||||
def check_storage() -> tuple[bool, dict]:
|
||||
st = timer()
|
||||
try:
|
||||
STORAGE_IMPL.health()
|
||||
return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"}
|
||||
except Exception as e:
|
||||
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
|
||||
|
||||
|
||||
def check_chat() -> tuple[bool, dict]:
|
||||
st = timer()
|
||||
try:
|
||||
cfg = getattr(settings, "CHAT_CFG", None)
|
||||
ok = bool(cfg and cfg.get("factory"))
|
||||
return ok, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"}
|
||||
except Exception as e:
|
||||
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
|
||||
|
||||
|
||||
def run_health_checks() -> tuple[dict, bool]:
|
||||
result: dict[str, str | dict] = {}
|
||||
|
||||
db_ok, db_meta = check_db()
|
||||
chat_ok, chat_meta = check_chat()
|
||||
|
||||
result["db"] = _ok_nok(db_ok)
|
||||
if not db_ok:
|
||||
result.setdefault("_meta", {})["db"] = db_meta
|
||||
|
||||
result["chat"] = _ok_nok(chat_ok)
|
||||
if not chat_ok:
|
||||
result.setdefault("_meta", {})["chat"] = chat_meta
|
||||
|
||||
# Optional probes (do not change minimal contract but exposed for observability)
|
||||
try:
|
||||
redis_ok, redis_meta = check_redis()
|
||||
result["redis"] = _ok_nok(redis_ok)
|
||||
if not redis_ok:
|
||||
result.setdefault("_meta", {})["redis"] = redis_meta
|
||||
except Exception:
|
||||
result["redis"] = "nok"
|
||||
|
||||
try:
|
||||
doc_ok, doc_meta = check_doc_engine()
|
||||
result["doc_engine"] = _ok_nok(doc_ok)
|
||||
if not doc_ok:
|
||||
result.setdefault("_meta", {})["doc_engine"] = doc_meta
|
||||
except Exception:
|
||||
result["doc_engine"] = "nok"
|
||||
|
||||
try:
|
||||
sto_ok, sto_meta = check_storage()
|
||||
result["storage"] = _ok_nok(sto_ok)
|
||||
if not sto_ok:
|
||||
result.setdefault("_meta", {})["storage"] = sto_meta
|
||||
except Exception:
|
||||
result["storage"] = "nok"
|
||||
|
||||
all_ok = (result.get("db") == "ok") and (result.get("chat") == "ok")
|
||||
result["status"] = "ok" if all_ok else "nok"
|
||||
return result, all_ok
|
||||
|
||||
|
||||
107
api/utils/health_utils.py
Normal file
107
api/utils/health_utils.py
Normal file
@ -0,0 +1,107 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
|
||||
from timeit import default_timer as timer
|
||||
|
||||
from api import settings
|
||||
from api.db.db_models import DB
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
|
||||
|
||||
def _ok_nok(ok: bool) -> str:
|
||||
return "ok" if ok else "nok"
|
||||
|
||||
|
||||
def check_db() -> tuple[bool, dict]:
|
||||
st = timer()
|
||||
try:
|
||||
# lightweight probe; works for MySQL/Postgres
|
||||
DB.execute_sql("SELECT 1")
|
||||
return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"}
|
||||
except Exception as e:
|
||||
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
|
||||
|
||||
|
||||
def check_redis() -> tuple[bool, dict]:
|
||||
st = timer()
|
||||
try:
|
||||
ok = bool(REDIS_CONN.health())
|
||||
return ok, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"}
|
||||
except Exception as e:
|
||||
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
|
||||
|
||||
|
||||
def check_doc_engine() -> tuple[bool, dict]:
|
||||
st = timer()
|
||||
try:
|
||||
meta = settings.docStoreConn.health()
|
||||
# treat any successful call as ok
|
||||
return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", **(meta or {})}
|
||||
except Exception as e:
|
||||
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
|
||||
|
||||
|
||||
def check_storage() -> tuple[bool, dict]:
|
||||
st = timer()
|
||||
try:
|
||||
STORAGE_IMPL.health()
|
||||
return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"}
|
||||
except Exception as e:
|
||||
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
|
||||
|
||||
|
||||
|
||||
|
||||
def run_health_checks() -> tuple[dict, bool]:
|
||||
result: dict[str, str | dict] = {}
|
||||
|
||||
db_ok, db_meta = check_db()
|
||||
result["db"] = _ok_nok(db_ok)
|
||||
if not db_ok:
|
||||
result.setdefault("_meta", {})["db"] = db_meta
|
||||
|
||||
try:
|
||||
redis_ok, redis_meta = check_redis()
|
||||
result["redis"] = _ok_nok(redis_ok)
|
||||
if not redis_ok:
|
||||
result.setdefault("_meta", {})["redis"] = redis_meta
|
||||
except Exception:
|
||||
result["redis"] = "nok"
|
||||
|
||||
try:
|
||||
doc_ok, doc_meta = check_doc_engine()
|
||||
result["doc_engine"] = _ok_nok(doc_ok)
|
||||
if not doc_ok:
|
||||
result.setdefault("_meta", {})["doc_engine"] = doc_meta
|
||||
except Exception:
|
||||
result["doc_engine"] = "nok"
|
||||
|
||||
try:
|
||||
sto_ok, sto_meta = check_storage()
|
||||
result["storage"] = _ok_nok(sto_ok)
|
||||
if not sto_ok:
|
||||
result.setdefault("_meta", {})["storage"] = sto_meta
|
||||
except Exception:
|
||||
result["storage"] = "nok"
|
||||
|
||||
|
||||
all_ok = (result.get("db") == "ok") and (result.get("redis") == "ok") and (result.get("doc_engine") == "ok") and (result.get("storage") == "ok")
|
||||
result["status"] = "ok" if all_ok else "nok"
|
||||
return result, all_ok
|
||||
|
||||
|
||||
78
api/utils/json.py
Normal file
78
api/utils/json.py
Normal file
@ -0,0 +1,78 @@
|
||||
import datetime
|
||||
import json
|
||||
from enum import Enum, IntEnum
|
||||
from api.utils.common import string_to_bytes, bytes_to_string
|
||||
|
||||
|
||||
class BaseType:
|
||||
def to_dict(self):
|
||||
return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()])
|
||||
|
||||
def to_dict_with_type(self):
|
||||
def _dict(obj):
|
||||
module = None
|
||||
if issubclass(obj.__class__, BaseType):
|
||||
data = {}
|
||||
for attr, v in obj.__dict__.items():
|
||||
k = attr.lstrip("_")
|
||||
data[k] = _dict(v)
|
||||
module = obj.__module__
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
data = []
|
||||
for i, vv in enumerate(obj):
|
||||
data.append(_dict(vv))
|
||||
elif isinstance(obj, dict):
|
||||
data = {}
|
||||
for _k, vv in obj.items():
|
||||
data[_k] = _dict(vv)
|
||||
else:
|
||||
data = obj
|
||||
return {"type": obj.__class__.__name__,
|
||||
"data": data, "module": module}
|
||||
|
||||
return _dict(self)
|
||||
|
||||
|
||||
class CustomJSONEncoder(json.JSONEncoder):
|
||||
def __init__(self, **kwargs):
|
||||
self._with_type = kwargs.pop("with_type", False)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def default(self, obj):
|
||||
if isinstance(obj, datetime.datetime):
|
||||
return obj.strftime('%Y-%m-%d %H:%M:%S')
|
||||
elif isinstance(obj, datetime.date):
|
||||
return obj.strftime('%Y-%m-%d')
|
||||
elif isinstance(obj, datetime.timedelta):
|
||||
return str(obj)
|
||||
elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum):
|
||||
return obj.value
|
||||
elif isinstance(obj, set):
|
||||
return list(obj)
|
||||
elif issubclass(type(obj), BaseType):
|
||||
if not self._with_type:
|
||||
return obj.to_dict()
|
||||
else:
|
||||
return obj.to_dict_with_type()
|
||||
elif isinstance(obj, type):
|
||||
return obj.__name__
|
||||
else:
|
||||
return json.JSONEncoder.default(self, obj)
|
||||
|
||||
|
||||
def json_dumps(src, byte=False, indent=None, with_type=False):
|
||||
dest = json.dumps(
|
||||
src,
|
||||
indent=indent,
|
||||
cls=CustomJSONEncoder,
|
||||
with_type=with_type)
|
||||
if byte:
|
||||
dest = string_to_bytes(dest)
|
||||
return dest
|
||||
|
||||
|
||||
def json_loads(src, object_hook=None, object_pairs_hook=None):
|
||||
if isinstance(src, bytes):
|
||||
src = bytes_to_string(src)
|
||||
return json.loads(src, object_hook=object_hook,
|
||||
object_pairs_hook=object_pairs_hook)
|
||||
@ -1,40 +0,0 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import base64
|
||||
import os
|
||||
import sys
|
||||
from Cryptodome.PublicKey import RSA
|
||||
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
|
||||
from api.utils import decrypt, file_utils
|
||||
|
||||
|
||||
def crypt(line):
|
||||
file_path = os.path.join(
|
||||
file_utils.get_project_base_directory(),
|
||||
"conf",
|
||||
"public.pem")
|
||||
rsa_key = RSA.importKey(open(file_path).read(),"Welcome")
|
||||
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
|
||||
password_base64 = base64.b64encode(line.encode('utf-8')).decode("utf-8")
|
||||
encrypted_password = cipher.encrypt(password_base64.encode())
|
||||
return base64.b64encode(encrypted_password).decode('utf-8')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
passwd = crypt(sys.argv[1])
|
||||
print(passwd)
|
||||
print(decrypt(passwd))
|
||||
19
chat_demo/index.html
Normal file
19
chat_demo/index.html
Normal file
@ -0,0 +1,19 @@
|
||||
<iframe src="http://localhost:9222/next-chats/widget?shared_id=9dcfc68696c611f0bb789b9b8b765d12&from=chat&auth=U4MDU3NzkwOTZjNzExZjBiYjc4OWI5Yj&mode=master&streaming=false"
|
||||
style="position:fixed;bottom:0;right:0;width:100px;height:100px;border:none;background:transparent;z-index:9999"
|
||||
frameborder="0" allow="microphone;camera"></iframe>
|
||||
<script>
|
||||
window.addEventListener('message',e=>{
|
||||
if(e.origin!=='http://localhost:9222')return;
|
||||
if(e.data.type==='CREATE_CHAT_WINDOW'){
|
||||
if(document.getElementById('chat-win'))return;
|
||||
const i=document.createElement('iframe');
|
||||
i.id='chat-win';i.src=e.data.src;
|
||||
i.style.cssText='position:fixed;bottom:104px;right:24px;width:380px;height:500px;border:none;background:transparent;z-index:9998;display:none';
|
||||
i.frameBorder='0';i.allow='microphone;camera';
|
||||
document.body.appendChild(i);
|
||||
}else if(e.data.type==='TOGGLE_CHAT'){
|
||||
const w=document.getElementById('chat-win');
|
||||
if(w)w.style.display=e.data.isOpen?'block':'none';
|
||||
}else if(e.data.type==='SCROLL_PASSTHROUGH')window.scrollBy(0,e.data.deltaY);
|
||||
});
|
||||
</script>
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user