mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Compare commits
192 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 3fd7db40ea | |||
| 5650442b0b | |||
| 5b013da4d6 | |||
| fe797bcc66 | |||
| 9542f4484c | |||
| 2452c5624f | |||
| a5c03ccd4c | |||
| d2213141e0 | |||
| 3da3260eb5 | |||
| 07f283b73e | |||
| 29509ff69d | |||
| 216f6495c4 | |||
| f60a249fe1 | |||
| 152072f900 | |||
| 80032b1fc0 | |||
| 5d55e6a049 | |||
| 418700b455 | |||
| eea6565472 | |||
| 3f21603558 | |||
| 3a739e3dd7 | |||
| 4ba1ba973a | |||
| e8b9871fb9 | |||
| e37b0d217d | |||
| 50e9df4c76 | |||
| b9a50ef4b8 | |||
| da11a20c92 | |||
| 955619c8ac | |||
| ad2e116367 | |||
| ccbd4365be | |||
| 9169643157 | |||
| 5cff780ec4 | |||
| ceb0419fe5 | |||
| 74ebc497c1 | |||
| 161cb08bbd | |||
| ff8702f7de | |||
| a973b9e01f | |||
| 5e19423d82 | |||
| 29f7f8b81e | |||
| 6012f376ca | |||
| 8468031e39 | |||
| aac460ad29 | |||
| 753c13d76f | |||
| 0cb588f7bf | |||
| ebdd71ce68 | |||
| 013856b604 | |||
| 61096596bc | |||
| 549d67e281 | |||
| 79c873344b | |||
| 548f01850f | |||
| 3f495b2d22 | |||
| c943517932 | |||
| 935687998e | |||
| 375f621405 | |||
| a99d19bdea | |||
| 906c0c5c89 | |||
| c92d334b29 | |||
| d38f995ba6 | |||
| bc50f68127 | |||
| b24abee364 | |||
| 6fee2962cb | |||
| e67bfca552 | |||
| d5f87a5498 | |||
| d7426d86d5 | |||
| 7ca98848ac | |||
| 32d5885b68 | |||
| f4d182e4ee | |||
| 69b9581417 | |||
| 1e21056364 | |||
| fdfa5d0ad4 | |||
| d96348eb22 | |||
| 100b3165d8 | |||
| 7e60800c95 | |||
| 4b195cc14c | |||
| 7034dc8dea | |||
| 71f2ba1452 | |||
| 1ec84a589e | |||
| eb40377700 | |||
| bbf9d6d786 | |||
| 8c2b91d3db | |||
| 55028b2db7 | |||
| daf86dbf74 | |||
| b2ef6a05a1 | |||
| 6bc3a2d58a | |||
| d69f4ec829 | |||
| ef45526700 | |||
| 79034bd194 | |||
| 60356b52c6 | |||
| 80d703f9c2 | |||
| 022afbb39d | |||
| 792a1a9d91 | |||
| d2b70e73dd | |||
| 37b0829e28 | |||
| b4a281eca1 | |||
| 95821f6fb6 | |||
| bf2ea04d02 | |||
| ac7a0d4fbf | |||
| 9f109adf28 | |||
| cf12c3cc1f | |||
| 29a7b7a040 | |||
| eb42adc818 | |||
| a4d230f12b | |||
| 9352a09c53 | |||
| a0c1d83ddc | |||
| 657019a5a9 | |||
| 58df013722 | |||
| 347cb61f26 | |||
| 264303ba98 | |||
| 1c90c39897 | |||
| 3fcdba1683 | |||
| 915354bec9 | |||
| c0090a1b4f | |||
| be6d5b76c3 | |||
| fb21efd77d | |||
| cf4fff64f8 | |||
| 0b94376cd4 | |||
| 2b5812d0a9 | |||
| 4da3ee400b | |||
| f8602b5286 | |||
| fc8a752cd5 | |||
| 478cd006d6 | |||
| 4d10dbcf95 | |||
| 43cd455b52 | |||
| b54d5807f3 | |||
| 58e95f76c1 | |||
| 06fd35d420 | |||
| 4df75ca84e | |||
| 701e5be535 | |||
| 9ae57eb370 | |||
| fe5dd5b70a | |||
| 1015436691 | |||
| 83c9f1ed39 | |||
| e4f4b30ae3 | |||
| 9bf6f7c9a0 | |||
| b06957e561 | |||
| baeedc699d | |||
| 00943dc04a | |||
| f43cf7c2b0 | |||
| 9e1421b77c | |||
| 13389be3f4 | |||
| a5306e6345 | |||
| 99adeabc85 | |||
| 6a5e1d597c | |||
| 266119bf62 | |||
| 75086f41a9 | |||
| 3657b1f2a2 | |||
| 975798c643 | |||
| 607de74ace | |||
| 2a647162a8 | |||
| d4332643c4 | |||
| 2ea696934b | |||
| 5a6a34cef9 | |||
| 1daa0b4d46 | |||
| 60d406acaa | |||
| 1a6bd437f5 | |||
| 258a10fb74 | |||
| fdc21ec853 | |||
| c2693d2f46 | |||
| ca9c9c4e1e | |||
| bafe137502 | |||
| 2dea8448a6 | |||
| d9868d0229 | |||
| 38a90c32b2 | |||
| eecec7b119 | |||
| 4eeb535946 | |||
| 26de9adb41 | |||
| 0c9a7caa9d | |||
| a5a617b7a3 | |||
| d5618749c9 | |||
| de8267cfd7 | |||
| 740714b79d | |||
| 013db9410f | |||
| b96ba6f831 | |||
| d29fd52e14 | |||
| 99f7bbaaa2 | |||
| 575099df2d | |||
| ddeac9ab3d | |||
| 009e18f094 | |||
| 9c023b6d8c | |||
| 2c2b2e0779 | |||
| 8d7fb12305 | |||
| 7f4c63d102 | |||
| 3e9f444e6b | |||
| 2290c2a2f0 | |||
| dbb8f7b77b | |||
| 8964817d72 | |||
| 0b950da73f | |||
| 30b88e2b91 | |||
| fb66b1e726 | |||
| 198a8b6592 | |||
| 56e3fa2d6a | |||
| 24f9b17ff6 | |||
| 427fb97562 |
@ -10,7 +10,8 @@ ADD ./api ./api
|
||||
ADD ./conf ./conf
|
||||
ADD ./deepdoc ./deepdoc
|
||||
ADD ./rag ./rag
|
||||
ADD ./graph ./graph
|
||||
ADD ./agent ./agent
|
||||
ADD ./graphrag ./graphrag
|
||||
|
||||
ENV PYTHONPATH=/ragflow/
|
||||
ENV HF_ENDPOINT=https://hf-mirror.com
|
||||
|
||||
@ -21,7 +21,8 @@ ADD ./api ./api
|
||||
ADD ./conf ./conf
|
||||
ADD ./deepdoc ./deepdoc
|
||||
ADD ./rag ./rag
|
||||
ADD ./graph ./graph
|
||||
ADD ./agent ./agent
|
||||
ADD ./graphrag ./graphrag
|
||||
|
||||
ENV PYTHONPATH=/ragflow/
|
||||
ENV HF_ENDPOINT=https://hf-mirror.com
|
||||
|
||||
@ -4,8 +4,8 @@ USER root
|
||||
WORKDIR /ragflow
|
||||
|
||||
## for cuda > 12.0
|
||||
RUN /root/miniconda3/envs/py11/bin/pip uninstall -y onnxruntime-gpu
|
||||
RUN /root/miniconda3/envs/py11/bin/pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
|
||||
RUN pip uninstall -y onnxruntime-gpu
|
||||
RUN pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
|
||||
|
||||
|
||||
ADD ./web ./web
|
||||
@ -15,6 +15,8 @@ ADD ./api ./api
|
||||
ADD ./conf ./conf
|
||||
ADD ./deepdoc ./deepdoc
|
||||
ADD ./rag ./rag
|
||||
ADD ./agent ./agent
|
||||
ADD ./graphrag ./graphrag
|
||||
|
||||
ENV PYTHONPATH=/ragflow/
|
||||
ENV HF_ENDPOINT=https://hf-mirror.com
|
||||
|
||||
@ -30,7 +30,8 @@ ADD ./conf ./conf
|
||||
ADD ./deepdoc ./deepdoc
|
||||
ADD ./rag ./rag
|
||||
ADD ./requirements.txt ./requirements.txt
|
||||
ADD ./graph ./graph
|
||||
ADD ./agent ./agent
|
||||
ADD ./graphrag ./graphrag
|
||||
|
||||
RUN apt install openmpi-bin openmpi-common libopenmpi-dev
|
||||
ENV LD_LIBRARY_PATH /usr/lib/x86_64-linux-gnu/openmpi/lib:$LD_LIBRARY_PATH
|
||||
|
||||
@ -30,7 +30,8 @@ ADD ./conf ./conf
|
||||
ADD ./deepdoc ./deepdoc
|
||||
ADD ./rag ./rag
|
||||
ADD ./requirements.txt ./requirements.txt
|
||||
ADD ./graph ./graph
|
||||
ADD ./agent ./agent
|
||||
ADD ./graphrag ./graphrag
|
||||
|
||||
RUN dnf install -y openmpi openmpi-devel python3-openmpi
|
||||
ENV C_INCLUDE_PATH /usr/include/openmpi-x86_64:$C_INCLUDE_PATH
|
||||
|
||||
21
README.md
21
README.md
@ -17,7 +17,7 @@
|
||||
<a href="https://demo.ragflow.io" target="_blank">
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99"></a>
|
||||
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
|
||||
<img src="https://img.shields.io/badge/docker_pull-ragflow:v0.8.0-brightgreen" alt="docker pull infiniflow/ragflow:v0.8.0"></a>
|
||||
<img src="https://img.shields.io/badge/docker_pull-ragflow:v0.9.0-brightgreen" alt="docker pull infiniflow/ragflow:v0.9.0"></a>
|
||||
<a href="https://github.com/infiniflow/ragflow/blob/main/LICENSE">
|
||||
<img height="21" src="https://img.shields.io/badge/License-Apache--2.0-ffffff?labelColor=d4eaf7&color=2e6cc4" alt="license">
|
||||
</a>
|
||||
@ -59,20 +59,27 @@
|
||||
Try our demo at [https://demo.ragflow.io](https://demo.ragflow.io).
|
||||
<div align="center" style="margin-top:20px;margin-bottom:20px;">
|
||||
<img src="https://github.com/infiniflow/ragflow/assets/7248/2f6baa3e-1092-4f11-866d-36f6a9d075e5" width="1200"/>
|
||||
<img src="https://github.com/infiniflow/ragflow/assets/12318111/b083d173-dadc-4ea9-bdeb-180d7df514eb" width="1200"/>
|
||||
</div>
|
||||
|
||||
|
||||
## 📌 Latest Updates
|
||||
## 🔥 Latest Updates
|
||||
|
||||
- 2024-07-08 Supports [Graph](./graph/README.md).
|
||||
- 2024-08-02 Supports GraphRAG inspired by [graphrag](https://github.com/microsoft/graphrag) , and mind map.
|
||||
|
||||
- 2024-07-23 Supports audio file parsing.
|
||||
|
||||
- 2024-06-27 Supports Markdown and Docx in the Q&A parsing method. Supports extracting images from Docx files. Supports extracting tables from Markdown files.
|
||||
- 2024-06-14 Supports PDF in the Q&A parsing method.
|
||||
- 2024-07-21 Supports more LLMs (LocalAI, OpenRouter, StepFun, and Nvidia).
|
||||
|
||||
- 2024-07-18 Adds more components (Wikipedia, PubMed, Baidu, and Duckduckgo) to the graph.
|
||||
|
||||
- 2024-07-08 Supports workflow based on [Graph](./graph/README.md).
|
||||
- 2024-06-27 Supports Markdown and Docx in the Q&A parsing method.
|
||||
- 2024-06-27 Supports extracting images from Docx files.
|
||||
- 2024-06-27 Supports extracting tables from Markdown files.
|
||||
- 2024-06-06 Supports [Self-RAG](https://huggingface.co/papers/2310.11511), which is enabled by default in dialog settings.
|
||||
- 2024-05-30 Integrates [BCE](https://github.com/netease-youdao/BCEmbedding) and [BGE](https://github.com/FlagOpen/FlagEmbedding) reranker models.
|
||||
- 2024-05-28 Supports LLM Baichuan and VolcanoArk.
|
||||
- 2024-05-23 Supports [RAPTOR](https://arxiv.org/html/2401.18059v1) for better text retrieval.
|
||||
- 2024-05-21 Supports streaming output and text chunk retrieval API.
|
||||
- 2024-05-15 Integrates OpenAI GPT-4o.
|
||||
|
||||
## 🌟 Key Features
|
||||
|
||||
23
README_ja.md
23
README_ja.md
@ -17,8 +17,8 @@
|
||||
<a href="https://demo.ragflow.io" target="_blank">
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99"></a>
|
||||
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
|
||||
<img src="https://img.shields.io/badge/docker_pull-ragflow:v0.8.0-brightgreen"
|
||||
alt="docker pull infiniflow/ragflow:v0.8.0"></a>
|
||||
<img src="https://img.shields.io/badge/docker_pull-ragflow:v0.9.0-brightgreen"
|
||||
alt="docker pull infiniflow/ragflow:v0.9.0"></a>
|
||||
<a href="https://github.com/infiniflow/ragflow/blob/main/LICENSE">
|
||||
<img height="21" src="https://img.shields.io/badge/License-Apache--2.0-ffffff?labelColor=d4eaf7&color=2e6cc4" alt="license">
|
||||
</a>
|
||||
@ -41,18 +41,23 @@
|
||||
デモをお試しください:[https://demo.ragflow.io](https://demo.ragflow.io)。
|
||||
<div align="center" style="margin-top:20px;margin-bottom:20px;">
|
||||
<img src="https://github.com/infiniflow/ragflow/assets/7248/2f6baa3e-1092-4f11-866d-36f6a9d075e5" width="1200"/>
|
||||
<img src="https://github.com/infiniflow/ragflow/assets/12318111/b083d173-dadc-4ea9-bdeb-180d7df514eb" width="1200"/>
|
||||
</div>
|
||||
|
||||
|
||||
## 📌 最新情報
|
||||
- 2024-07-08 [Graph](./graph/README.md) に対応しました。.
|
||||
- 2024-06-27 Q&A解析方式はMarkdownファイルとDocxファイルをサポートしています。Docxファイルからの画像の抽出をサポートします。Markdownファイルからテーブルを抽出することをサポートします。
|
||||
- 2024-06-14 Q&A 解析メソッドは PDF ファイルをサポートしています。
|
||||
## 🔥 最新情報
|
||||
|
||||
- 2024-08-02 [graphrag](https://github.com/microsoft/graphrag) からインスピレーションを得た GraphRAG とマインド マップをサポートします。
|
||||
- 2024-07-23 音声ファイルの解析をサポートしました。
|
||||
- 2024-07-21 より多くの LLM サプライヤー (LocalAI/OpenRouter/StepFun/Nvidia) をサポートします。
|
||||
- 2024-07-18 グラフにコンポーネント(Wikipedia/PubMed/Baidu/Duckduckgo)を追加しました。
|
||||
- 2024-07-08 [Graph](./graph/README.md) ベースのワークフローをサポート
|
||||
- 2024-06-27 Q&A解析方式はMarkdownファイルとDocxファイルをサポートしています。
|
||||
- 2024-06-27 Docxファイルからの画像の抽出をサポートします。
|
||||
- 2024-06-27 Markdownファイルからテーブルを抽出することをサポートします。
|
||||
- 2024-06-06 会話設定でデフォルトでチェックされている [Self-RAG](https://huggingface.co/papers/2310.11511) をサポートします。
|
||||
- 2024-05-30 [BCE](https://github.com/netease-youdao/BCEmbedding) 、[BGE](https://github.com/FlagOpen/FlagEmbedding) reranker を統合。
|
||||
- 2024-05-28 LLM BaichuanとVolcanoArkを統合しました。
|
||||
- 2024-05-23 より良いテキスト検索のために [RAPTOR](https://arxiv.org/html/2401.18059v1) をサポート。
|
||||
- 2024-05-21 ストリーミング出力とテキストチャンク取得APIをサポート。
|
||||
- 2024-05-15 OpenAI GPT-4oを統合しました。
|
||||
|
||||
## 🌟 主な特徴
|
||||
@ -136,7 +141,7 @@
|
||||
$ docker compose up -d
|
||||
```
|
||||
|
||||
> 上記のコマンドを実行すると、RAGFlowの開発版dockerイメージが自動的にダウンロードされます。 特定のバージョンのDockerイメージをダウンロードして実行したい場合は、docker/.envファイルのRAGFLOW_VERSION変数を見つけて、対応するバージョンに変更してください。 例えば、RAGFLOW_VERSION=v0.8.0として、上記のコマンドを実行してください。
|
||||
> 上記のコマンドを実行すると、RAGFlowの開発版dockerイメージが自動的にダウンロードされます。 特定のバージョンのDockerイメージをダウンロードして実行したい場合は、docker/.envファイルのRAGFLOW_VERSION変数を見つけて、対応するバージョンに変更してください。 例えば、RAGFLOW_VERSION=v0.9.0として、上記のコマンドを実行してください。
|
||||
|
||||
> コアイメージのサイズは約 9 GB で、ロードに時間がかかる場合があります。
|
||||
|
||||
|
||||
26
README_zh.md
26
README_zh.md
@ -17,7 +17,7 @@
|
||||
<a href="https://demo.ragflow.io" target="_blank">
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99"></a>
|
||||
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
|
||||
<img src="https://img.shields.io/badge/docker_pull-ragflow:v0.8.0-brightgreen" alt="docker pull infiniflow/ragflow:v0.8.0"></a>
|
||||
<img src="https://img.shields.io/badge/docker_pull-ragflow:v0.9.0-brightgreen" alt="docker pull infiniflow/ragflow:v0.9.0"></a>
|
||||
<a href="https://github.com/infiniflow/ragflow/blob/main/LICENSE">
|
||||
<img height="21" src="https://img.shields.io/badge/License-Apache--2.0-ffffff?labelColor=d4eaf7&color=2e6cc4" alt="license">
|
||||
</a>
|
||||
@ -40,19 +40,23 @@
|
||||
请登录网址 [https://demo.ragflow.io](https://demo.ragflow.io) 试用 demo。
|
||||
<div align="center" style="margin-top:20px;margin-bottom:20px;">
|
||||
<img src="https://github.com/infiniflow/ragflow/assets/7248/2f6baa3e-1092-4f11-866d-36f6a9d075e5" width="1200"/>
|
||||
<img src="https://github.com/infiniflow/ragflow/assets/12318111/b083d173-dadc-4ea9-bdeb-180d7df514eb" width="1200"/>
|
||||
</div>
|
||||
|
||||
|
||||
## 📌 近期更新
|
||||
## 🔥 近期更新
|
||||
|
||||
- 2024-07-08 支持 [Graph](./graph/README.md)。
|
||||
- 2024-06-27 Q&A 解析方式支持 Markdown 文件和 Docx 文件。支持提取出 Docx 文件中的图片。支持提取出 Markdown 文件中的表格。
|
||||
- 2024-06-14 Q&A 解析方式支持 PDF 文件。
|
||||
- 2024-08-02 支持 GraphRAG 启发于 [graphrag](https://github.com/microsoft/graphrag) 和思维导图。
|
||||
- 2024-07-23 支持解析音频文件。
|
||||
- 2024-07-21 支持更多的大模型供应商(LocalAI/OpenRouter/StepFun/Nvidia)。
|
||||
- 2024-07-18 在Graph中支持算子:Wikipedia、PubMed、Baidu和Duckduckgo。
|
||||
- 2024-07-08 支持 Agentic RAG: 基于 [Graph](./graph/README.md) 的工作流。
|
||||
- 2024-06-27 Q&A 解析方式支持 Markdown 文件和 Docx 文件。
|
||||
- 2024-06-27 支持提取出 Docx 文件中的图片。
|
||||
- 2024-06-27 支持提取出 Markdown 文件中的表格。
|
||||
- 2024-06-06 支持 [Self-RAG](https://huggingface.co/papers/2310.11511) ,在对话设置里面默认勾选。
|
||||
- 2024-05-30 集成 [BCE](https://github.com/netease-youdao/BCEmbedding) 和 [BGE](https://github.com/FlagOpen/FlagEmbedding) 重排序模型。
|
||||
- 2024-05-28 集成大模型 Baichuan 和火山方舟。
|
||||
- 2024-05-23 实现 [RAPTOR](https://arxiv.org/html/2401.18059v1) 提供更好的文本检索。
|
||||
- 2024-05-21 支持流式结果输出和文本块获取API。
|
||||
- 2024-05-15 集成大模型 OpenAI GPT-4o。
|
||||
|
||||
## 🌟 主要功能
|
||||
@ -136,7 +140,7 @@
|
||||
$ docker compose -f docker-compose-CN.yml up -d
|
||||
```
|
||||
|
||||
> 请注意,运行上述命令会自动下载 RAGFlow 的开发版本 docker 镜像。如果你想下载并运行特定版本的 docker 镜像,请在 docker/.env 文件中找到 RAGFLOW_VERSION 变量,将其改为对应版本。例如 RAGFLOW_VERSION=v0.8.0,然后运行上述命令。
|
||||
> 请注意,运行上述命令会自动下载 RAGFlow 的开发版本 docker 镜像。如果你想下载并运行特定版本的 docker 镜像,请在 docker/.env 文件中找到 RAGFLOW_VERSION 变量,将其改为对应版本。例如 RAGFLOW_VERSION=v0.9.0,然后运行上述命令。
|
||||
|
||||
> 核心镜像文件大约 9 GB,可能需要一定时间拉取。请耐心等待。
|
||||
|
||||
@ -198,7 +202,7 @@
|
||||
```bash
|
||||
$ git clone https://github.com/infiniflow/ragflow.git
|
||||
$ cd ragflow/
|
||||
$ docker build -t infiniflow/ragflow:v0.8.0 .
|
||||
$ docker build -t infiniflow/ragflow:v0.9.0 .
|
||||
$ cd ragflow/docker
|
||||
$ chmod +x ./entrypoint.sh
|
||||
$ docker compose up -d
|
||||
@ -306,6 +310,10 @@ $ systemctl start nginx
|
||||
|
||||
RAGFlow 只有通过开源协作才能蓬勃发展。秉持这一精神,我们欢迎来自社区的各种贡献。如果您有意参与其中,请查阅我们的[贡献者指南](./docs/references/CONTRIBUTING.md) 。
|
||||
|
||||
## 🤝 商务合作
|
||||
|
||||
- [预约咨询](https://aao615odquw.feishu.cn/share/base/form/shrcnjw7QleretCLqh1nuPo1xxh)
|
||||
|
||||
## 👥 加入社区
|
||||
|
||||
扫二维码添加 RAGFlow 小助手,进 RAGFlow 交流群。
|
||||
|
||||
@ -22,9 +22,9 @@ from functools import partial
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graph.component import component_class
|
||||
from graph.component.base import ComponentBase
|
||||
from graph.settings import flow_logger, DEBUG
|
||||
from agent.component import component_class
|
||||
from agent.component.base import ComponentBase
|
||||
from agent.settings import flow_logger, DEBUG
|
||||
|
||||
|
||||
class Canvas(ABC):
|
||||
@ -188,14 +188,19 @@ class Canvas(ABC):
|
||||
def prepare2run(cpns):
|
||||
nonlocal ran, ans
|
||||
for c in cpns:
|
||||
if self.path[-1] and c == self.path[-1][-1]: continue
|
||||
cpn = self.components[c]["obj"]
|
||||
if cpn.component_name == "Answer":
|
||||
self.answer.append(c)
|
||||
else:
|
||||
if DEBUG: print("RUN: ", c)
|
||||
if cpn.component_name == "Generate":
|
||||
cpids = cpn.get_dependent_components()
|
||||
if any([c not in self.path[-1] for c in cpids]):
|
||||
continue
|
||||
ans = cpn.run(self.history, **kwargs)
|
||||
self.path[-1].append(c)
|
||||
ran += 1
|
||||
ran += 1
|
||||
|
||||
prepare2run(self.components[self.path[-2][-1]]["downstream"])
|
||||
while 0 <= ran < len(self.path[-1]):
|
||||
@ -220,6 +225,7 @@ class Canvas(ABC):
|
||||
prepare2run([p])
|
||||
break
|
||||
traceback.print_exc()
|
||||
break
|
||||
continue
|
||||
|
||||
try:
|
||||
@ -231,6 +237,7 @@ class Canvas(ABC):
|
||||
prepare2run([p])
|
||||
break
|
||||
traceback.print_exc()
|
||||
break
|
||||
|
||||
if self.answer:
|
||||
cpn_id = self.answer[0]
|
||||
@ -10,7 +10,13 @@ from .message import Message, MessageParam
|
||||
from .rewrite import RewriteQuestion, RewriteQuestionParam
|
||||
from .keyword import KeywordExtract, KeywordExtractParam
|
||||
from .baidu import Baidu, BaiduParam
|
||||
from .duckduckgosearch import DuckDuckGoSearch, DuckDuckGoSearchParam
|
||||
from .duckduckgo import DuckDuckGo, DuckDuckGoParam
|
||||
from .wikipedia import Wikipedia, WikipediaParam
|
||||
from .pubmed import PubMed, PubMedParam
|
||||
from .arxiv import ArXiv, ArXivParam
|
||||
from .google import Google, GoogleParam
|
||||
from .bing import Bing, BingParam
|
||||
from .googlescholar import GoogleScholar, GoogleScholarParam
|
||||
|
||||
|
||||
def component_class(class_name):
|
||||
@ -19,7 +19,7 @@ from functools import partial
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graph.component.base import ComponentBase, ComponentParamBase
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
|
||||
|
||||
class AnswerParam(ComponentParamBase):
|
||||
@ -59,8 +59,10 @@ class Answer(ComponentBase, ABC):
|
||||
stream = self.get_stream_input()
|
||||
if isinstance(stream, pd.DataFrame):
|
||||
res = stream
|
||||
answer = ""
|
||||
for ii, row in stream.iterrows():
|
||||
yield row.to_dict()
|
||||
answer += row.to_dict()["content"]
|
||||
yield {"content": answer}
|
||||
else:
|
||||
for st in stream():
|
||||
res = st
|
||||
69
agent/component/arxiv.py
Normal file
69
agent/component/arxiv.py
Normal file
@ -0,0 +1,69 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
import arxiv
|
||||
import pandas as pd
|
||||
from agent.settings import DEBUG
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
|
||||
|
||||
class ArXivParam(ComponentParamBase):
|
||||
"""
|
||||
Define the ArXiv component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.top_n = 6
|
||||
self.sort_by = 'submittedDate'
|
||||
|
||||
def check(self):
|
||||
self.check_positive_integer(self.top_n, "Top N")
|
||||
self.check_valid_value(self.sort_by, "ArXiv Search Sort_by",
|
||||
['submittedDate', 'lastUpdatedDate', 'relevance'])
|
||||
|
||||
|
||||
class ArXiv(ComponentBase, ABC):
|
||||
component_name = "ArXiv"
|
||||
|
||||
def _run(self, history, **kwargs):
|
||||
ans = self.get_input()
|
||||
ans = " - ".join(ans["content"]) if "content" in ans else ""
|
||||
if not ans:
|
||||
return ArXiv.be_output("")
|
||||
|
||||
try:
|
||||
sort_choices = {"relevance": arxiv.SortCriterion.Relevance,
|
||||
"lastUpdatedDate": arxiv.SortCriterion.LastUpdatedDate,
|
||||
'submittedDate': arxiv.SortCriterion.SubmittedDate}
|
||||
arxiv_client = arxiv.Client()
|
||||
search = arxiv.Search(
|
||||
query=ans,
|
||||
max_results=self._param.top_n,
|
||||
sort_by=sort_choices[self._param.sort_by]
|
||||
)
|
||||
arxiv_res = [
|
||||
{"content": 'Title: ' + i.title + '\nPdf_Url: <a href="' + i.pdf_url + '"></a> \nSummary: ' + i.summary} for
|
||||
i in list(arxiv_client.results(search))]
|
||||
except Exception as e:
|
||||
return ArXiv.be_output("**ERROR**: " + str(e))
|
||||
|
||||
if not arxiv_res:
|
||||
return ArXiv.be_output("")
|
||||
|
||||
df = pd.DataFrame(arxiv_res)
|
||||
if DEBUG: print(df, ":::::::::::::::::::::::::::::::::")
|
||||
return df
|
||||
69
agent/component/baidu.py
Normal file
69
agent/component/baidu.py
Normal file
@ -0,0 +1,69 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import random
|
||||
from abc import ABC
|
||||
from functools import partial
|
||||
import pandas as pd
|
||||
import requests
|
||||
import re
|
||||
from agent.settings import DEBUG
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
|
||||
|
||||
class BaiduParam(ComponentParamBase):
|
||||
"""
|
||||
Define the Baidu component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.top_n = 10
|
||||
|
||||
def check(self):
|
||||
self.check_positive_integer(self.top_n, "Top N")
|
||||
|
||||
|
||||
class Baidu(ComponentBase, ABC):
|
||||
component_name = "Baidu"
|
||||
|
||||
def _run(self, history, **kwargs):
|
||||
ans = self.get_input()
|
||||
ans = " - ".join(ans["content"]) if "content" in ans else ""
|
||||
if not ans:
|
||||
return Baidu.be_output("")
|
||||
|
||||
try:
|
||||
url = 'https://www.baidu.com/s?wd=' + ans + '&rn=' + str(self._param.top_n)
|
||||
headers = {
|
||||
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.104 Safari/537.36'}
|
||||
response = requests.get(url=url, headers=headers)
|
||||
|
||||
url_res = re.findall(r"'url': \\\"(.*?)\\\"}", response.text)
|
||||
title_res = re.findall(r"'title': \\\"(.*?)\\\",\\n", response.text)
|
||||
body_res = re.findall(r"\"contentText\":\"(.*?)\"", response.text)
|
||||
baidu_res = [{"content": re.sub('<em>|</em>', '', '<a href="' + url + '">' + title + '</a> ' + body)} for
|
||||
url, title, body in zip(url_res, title_res, body_res)]
|
||||
del body_res, url_res, title_res
|
||||
except Exception as e:
|
||||
return Baidu.be_output("**ERROR**: " + str(e))
|
||||
|
||||
if not baidu_res:
|
||||
return Baidu.be_output("")
|
||||
|
||||
df = pd.DataFrame(baidu_res)
|
||||
if DEBUG: print(df, ":::::::::::::::::::::::::::::::::")
|
||||
return df
|
||||
|
||||
@ -23,8 +23,8 @@ from typing import List, Dict, Tuple, Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graph import settings
|
||||
from graph.settings import flow_logger, DEBUG
|
||||
from agent import settings
|
||||
from agent.settings import flow_logger, DEBUG
|
||||
|
||||
_FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
|
||||
_DEPRECATED_PARAMS = "_deprecated_params"
|
||||
@ -35,7 +35,7 @@ _IS_RAW_CONF = "_is_raw_conf"
|
||||
class ComponentParamBase(ABC):
|
||||
def __init__(self):
|
||||
self.output_var_name = "output"
|
||||
self.message_history_window_size = 4
|
||||
self.message_history_window_size = 22
|
||||
|
||||
def set_name(self, name: str):
|
||||
self._name = name
|
||||
@ -445,6 +445,12 @@ class ComponentBase(ABC):
|
||||
if DEBUG: print(self.component_name, reversed_cpnts[::-1])
|
||||
for u in reversed_cpnts[::-1]:
|
||||
if self.get_component_name(u) in ["switch"]: continue
|
||||
if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
|
||||
o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
|
||||
if o is not None:
|
||||
upstream_outs.append(o)
|
||||
continue
|
||||
if u not in self._canvas.get_component(self._id)["upstream"]: continue
|
||||
if self.component_name.lower().find("switch") < 0 \
|
||||
and self.get_component_name(u) in ["relevant", "categorize"]:
|
||||
continue
|
||||
@ -455,12 +461,20 @@ class ComponentBase(ABC):
|
||||
break
|
||||
break
|
||||
if self.component_name.lower().find("answer") >= 0:
|
||||
if self.get_component_name(u) in ["relevant"]: continue
|
||||
|
||||
else: upstream_outs.append(self._canvas.get_component(u)["obj"].output(allow_partial=False)[1])
|
||||
if self.get_component_name(u) in ["relevant"]:
|
||||
continue
|
||||
else:
|
||||
o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
|
||||
if o is not None:
|
||||
upstream_outs.append(o)
|
||||
break
|
||||
|
||||
return pd.concat(upstream_outs, ignore_index=False)
|
||||
if upstream_outs:
|
||||
df = pd.concat(upstream_outs, ignore_index=True)
|
||||
if "content" in df:
|
||||
df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
|
||||
return df
|
||||
return pd.DataFrame()
|
||||
|
||||
def get_stream_input(self):
|
||||
reversed_cpnts = []
|
||||
@ -13,11 +13,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
from functools import partial
|
||||
|
||||
import pandas as pd
|
||||
from graph.component.base import ComponentBase, ComponentParamBase
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
|
||||
|
||||
class BeginParam(ComponentParamBase):
|
||||
|
||||
85
agent/component/bing.py
Normal file
85
agent/component/bing.py
Normal file
@ -0,0 +1,85 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
import requests
|
||||
import pandas as pd
|
||||
from agent.settings import DEBUG
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
|
||||
|
||||
class BingParam(ComponentParamBase):
|
||||
"""
|
||||
Define the Bing component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.top_n = 10
|
||||
self.channel = "Webpages"
|
||||
self.api_key = "YOUR_ACCESS_KEY"
|
||||
self.country = "CN"
|
||||
self.language = "en"
|
||||
|
||||
def check(self):
|
||||
self.check_positive_integer(self.top_n, "Top N")
|
||||
self.check_valid_value(self.channel, "Bing Web Search or Bing News", ["Webpages", "News"])
|
||||
self.check_empty(self.api_key, "Bing subscription key")
|
||||
self.check_valid_value(self.country, "Bing Country",
|
||||
['AR', 'AU', 'AT', 'BE', 'BR', 'CA', 'CL', 'DK', 'FI', 'FR', 'DE', 'HK', 'IN', 'ID',
|
||||
'IT', 'JP', 'KR', 'MY', 'MX', 'NL', 'NZ', 'NO', 'CN', 'PL', 'PT', 'PH', 'RU', 'SA',
|
||||
'ZA', 'ES', 'SE', 'CH', 'TW', 'TR', 'GB', 'US'])
|
||||
self.check_valid_value(self.language, "Bing Languages",
|
||||
['ar', 'eu', 'bn', 'bg', 'ca', 'ns', 'nt', 'hr', 'cs', 'da', 'nl', 'en', 'gb', 'et',
|
||||
'fi', 'fr', 'gl', 'de', 'gu', 'he', 'hi', 'hu', 'is', 'it', 'jp', 'kn', 'ko', 'lv',
|
||||
'lt', 'ms', 'ml', 'mr', 'nb', 'pl', 'br', 'pt', 'pa', 'ro', 'ru', 'sr', 'sk', 'sl',
|
||||
'es', 'sv', 'ta', 'te', 'th', 'tr', 'uk', 'vi'])
|
||||
|
||||
|
||||
class Bing(ComponentBase, ABC):
|
||||
component_name = "Bing"
|
||||
|
||||
def _run(self, history, **kwargs):
|
||||
ans = self.get_input()
|
||||
ans = " - ".join(ans["content"]) if "content" in ans else ""
|
||||
if not ans:
|
||||
return Bing.be_output("")
|
||||
|
||||
try:
|
||||
headers = {"Ocp-Apim-Subscription-Key": self._param.api_key, 'Accept-Language': self._param.language}
|
||||
params = {"q": ans, "textDecorations": True, "textFormat": "HTML", "cc": self._param.country,
|
||||
"answerCount": 1, "promote": self._param.channel}
|
||||
if self._param.channel == "Webpages":
|
||||
response = requests.get("https://api.bing.microsoft.com/v7.0/search", headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
search_results = response.json()
|
||||
bing_res = [{"content": '<a href="' + i["url"] + '">' + i["name"] + '</a> ' + i["snippet"]} for i in
|
||||
search_results["webPages"]["value"]]
|
||||
elif self._param.channel == "News":
|
||||
response = requests.get("https://api.bing.microsoft.com/v7.0/news/search", headers=headers,
|
||||
params=params)
|
||||
response.raise_for_status()
|
||||
search_results = response.json()
|
||||
bing_res = [{"content": '<a href="' + i["url"] + '">' + i["name"] + '</a> ' + i["description"]} for i
|
||||
in search_results['news']['value']]
|
||||
except Exception as e:
|
||||
return Bing.be_output("**ERROR**: " + str(e))
|
||||
|
||||
if not bing_res:
|
||||
return Bing.be_output("")
|
||||
|
||||
df = pd.DataFrame(bing_res)
|
||||
if DEBUG: print(df, ":::::::::::::::::::::::::::::::::")
|
||||
return df
|
||||
@ -14,13 +14,10 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from graph.component import GenerateParam, Generate
|
||||
from graph.settings import DEBUG
|
||||
from agent.component import GenerateParam, Generate
|
||||
from agent.settings import DEBUG
|
||||
|
||||
|
||||
class CategorizeParam(GenerateParam):
|
||||
@ -21,7 +21,7 @@ from api.db import LLMType
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.settings import retrievaler
|
||||
from graph.component.base import ComponentBase, ComponentParamBase
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
|
||||
|
||||
class CiteParam(ComponentParamBase):
|
||||
66
agent/component/duckduckgo.py
Normal file
66
agent/component/duckduckgo.py
Normal file
@ -0,0 +1,66 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
from duckduckgo_search import DDGS
|
||||
import pandas as pd
|
||||
from agent.settings import DEBUG
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
|
||||
|
||||
class DuckDuckGoParam(ComponentParamBase):
|
||||
"""
|
||||
Define the DuckDuckGo component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.top_n = 10
|
||||
self.channel = "text"
|
||||
|
||||
def check(self):
|
||||
self.check_positive_integer(self.top_n, "Top N")
|
||||
self.check_valid_value(self.channel, "Web Search or News", ["text", "news"])
|
||||
|
||||
|
||||
class DuckDuckGo(ComponentBase, ABC):
|
||||
component_name = "DuckDuckGo"
|
||||
|
||||
def _run(self, history, **kwargs):
|
||||
ans = self.get_input()
|
||||
ans = " - ".join(ans["content"]) if "content" in ans else ""
|
||||
if not ans:
|
||||
return DuckDuckGo.be_output("")
|
||||
|
||||
try:
|
||||
if self._param.channel == "text":
|
||||
with DDGS() as ddgs:
|
||||
# {'title': '', 'href': '', 'body': ''}
|
||||
duck_res = [{"content": '<a href="' + i["href"] + '">' + i["title"] + '</a> ' + i["body"]} for i
|
||||
in ddgs.text(ans, max_results=self._param.top_n)]
|
||||
elif self._param.channel == "news":
|
||||
with DDGS() as ddgs:
|
||||
# {'date': '', 'title': '', 'body': '', 'url': '', 'image': '', 'source': ''}
|
||||
duck_res = [{"content": '<a href="' + i["url"] + '">' + i["title"] + '</a> ' + i["body"]} for i
|
||||
in ddgs.news(ans, max_results=self._param.top_n)]
|
||||
except Exception as e:
|
||||
return DuckDuckGo.be_output("**ERROR**: " + str(e))
|
||||
|
||||
if not duck_res:
|
||||
return DuckDuckGo.be_output("")
|
||||
|
||||
df = pd.DataFrame(duck_res)
|
||||
if DEBUG: print(df, ":::::::::::::::::::::::::::::::::")
|
||||
return df
|
||||
@ -15,13 +15,11 @@
|
||||
#
|
||||
import re
|
||||
from functools import partial
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.settings import retrievaler
|
||||
from graph.component.base import ComponentBase, ComponentParamBase
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
|
||||
|
||||
class GenerateParam(ComponentParamBase):
|
||||
@ -63,26 +61,60 @@ class GenerateParam(ComponentParamBase):
|
||||
class Generate(ComponentBase):
|
||||
component_name = "Generate"
|
||||
|
||||
def get_dependent_components(self):
|
||||
cpnts = [para["component_id"] for para in self._param.parameters]
|
||||
return cpnts
|
||||
|
||||
def set_cite(self, retrieval_res, answer):
|
||||
answer, idx = retrievaler.insert_citations(answer, [ck["content_ltks"] for _, ck in retrieval_res.iterrows()],
|
||||
[ck["vector"] for _, ck in retrieval_res.iterrows()],
|
||||
LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
|
||||
self._canvas.get_embedding_model()), tkweight=0.7,
|
||||
vtweight=0.3)
|
||||
doc_ids = set([])
|
||||
recall_docs = []
|
||||
for i in idx:
|
||||
did = retrieval_res.loc[int(i), "doc_id"]
|
||||
if did in doc_ids: continue
|
||||
doc_ids.add(did)
|
||||
recall_docs.append({"doc_id": did, "doc_name": retrieval_res.loc[int(i), "docnm_kwd"]})
|
||||
|
||||
del retrieval_res["vector"]
|
||||
del retrieval_res["content_ltks"]
|
||||
|
||||
reference = {
|
||||
"chunks": [ck.to_dict() for _, ck in retrieval_res.iterrows()],
|
||||
"doc_aggs": recall_docs
|
||||
}
|
||||
|
||||
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
|
||||
answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
|
||||
res = {"content": answer, "reference": reference}
|
||||
|
||||
return res
|
||||
|
||||
def _run(self, history, **kwargs):
|
||||
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
|
||||
prompt = self._param.prompt
|
||||
|
||||
retrieval_res = self.get_input()
|
||||
input = "\n- ".join(retrieval_res["content"])
|
||||
input = (" - " + "\n - ".join(retrieval_res["content"])) if "content" in retrieval_res else ""
|
||||
for para in self._param.parameters:
|
||||
cpn = self._canvas.get_component(para["component_id"])["obj"]
|
||||
_, out = cpn.output(allow_partial=False)
|
||||
if "content" not in out.columns:
|
||||
kwargs[para["key"]] = "Nothing"
|
||||
else:
|
||||
kwargs[para["key"]] = "\n - ".join(out["content"])
|
||||
kwargs[para["key"]] = " - " + "\n - ".join(out["content"])
|
||||
|
||||
kwargs["input"] = input
|
||||
for n, v in kwargs.items():
|
||||
# prompt = re.sub(r"\{%s\}"%n, re.escape(str(v)), prompt)
|
||||
prompt = re.sub(r"\{%s\}" % n, str(v), prompt)
|
||||
|
||||
if kwargs.get("stream"):
|
||||
downstreams = self._canvas.get_component(self._id)["downstream"]
|
||||
if kwargs.get("stream") and len(downstreams) == 1 and self._canvas.get_component(downstreams[0])[
|
||||
"obj"].component_name.lower() == "answer":
|
||||
return partial(self.stream_output, chat_mdl, prompt, retrieval_res)
|
||||
|
||||
if "empty_response" in retrieval_res.columns:
|
||||
@ -90,27 +122,8 @@ class Generate(ComponentBase):
|
||||
|
||||
ans = chat_mdl.chat(prompt, self._canvas.get_history(self._param.message_history_window_size),
|
||||
self._param.gen_conf())
|
||||
|
||||
if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
|
||||
ans, idx = retrievaler.insert_citations(ans,
|
||||
[ck["content_ltks"]
|
||||
for _, ck in retrieval_res.iterrows()],
|
||||
[ck["vector"]
|
||||
for _, ck in retrieval_res.iterrows()],
|
||||
LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
|
||||
self._canvas.get_embedding_model()),
|
||||
tkweight=0.7,
|
||||
vtweight=0.3)
|
||||
del retrieval_res["vector"]
|
||||
retrieval_res = retrieval_res.to_dict("records")
|
||||
df = []
|
||||
for i in idx:
|
||||
df.append(retrieval_res[int(i)])
|
||||
r = re.search(r"^((.|[\r\n])*? ##%s\$\$)" % str(i), ans)
|
||||
assert r, f"{i} => {ans}"
|
||||
df[-1]["content"] = r.group(1)
|
||||
ans = re.sub(r"^((.|[\r\n])*? ##%s\$\$)" % str(i), "", ans)
|
||||
if ans: df.append({"content": ans})
|
||||
df = self.set_cite(retrieval_res, ans)
|
||||
return pd.DataFrame(df)
|
||||
|
||||
return Generate.be_output(ans)
|
||||
@ -131,34 +144,7 @@ class Generate(ComponentBase):
|
||||
yield res
|
||||
|
||||
if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
|
||||
answer, idx = retrievaler.insert_citations(answer,
|
||||
[ck["content_ltks"]
|
||||
for _, ck in retrieval_res.iterrows()],
|
||||
[ck["vector"]
|
||||
for _, ck in retrieval_res.iterrows()],
|
||||
LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
|
||||
self._canvas.get_embedding_model()),
|
||||
tkweight=0.7,
|
||||
vtweight=0.3)
|
||||
doc_ids = set([])
|
||||
recall_docs = []
|
||||
for i in idx:
|
||||
did = retrieval_res.loc[int(i), "doc_id"]
|
||||
if did in doc_ids: continue
|
||||
doc_ids.add(did)
|
||||
recall_docs.append({"doc_id": did, "doc_name": retrieval_res.loc[int(i), "docnm_kwd"]})
|
||||
|
||||
del retrieval_res["vector"]
|
||||
del retrieval_res["content_ltks"]
|
||||
|
||||
reference = {
|
||||
"chunks": [ck.to_dict() for _, ck in retrieval_res.iterrows()],
|
||||
"doc_aggs": recall_docs
|
||||
}
|
||||
|
||||
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
|
||||
answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
|
||||
res = {"content": answer, "reference": reference}
|
||||
res = self.set_cite(retrieval_res, answer)
|
||||
yield res
|
||||
|
||||
self.set_output(res)
|
||||
96
agent/component/google.py
Normal file
96
agent/component/google.py
Normal file
@ -0,0 +1,96 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
from serpapi import GoogleSearch
|
||||
import pandas as pd
|
||||
from agent.settings import DEBUG
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
|
||||
|
||||
class GoogleParam(ComponentParamBase):
|
||||
"""
|
||||
Define the Google component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.top_n = 10
|
||||
self.api_key = "xxx"
|
||||
self.country = "cn"
|
||||
self.language = "en"
|
||||
|
||||
def check(self):
|
||||
self.check_positive_integer(self.top_n, "Top N")
|
||||
self.check_empty(self.api_key, "SerpApi API key")
|
||||
self.check_valid_value(self.country, "Google Country",
|
||||
['af', 'al', 'dz', 'as', 'ad', 'ao', 'ai', 'aq', 'ag', 'ar', 'am', 'aw', 'au', 'at',
|
||||
'az', 'bs', 'bh', 'bd', 'bb', 'by', 'be', 'bz', 'bj', 'bm', 'bt', 'bo', 'ba', 'bw',
|
||||
'bv', 'br', 'io', 'bn', 'bg', 'bf', 'bi', 'kh', 'cm', 'ca', 'cv', 'ky', 'cf', 'td',
|
||||
'cl', 'cn', 'cx', 'cc', 'co', 'km', 'cg', 'cd', 'ck', 'cr', 'ci', 'hr', 'cu', 'cy',
|
||||
'cz', 'dk', 'dj', 'dm', 'do', 'ec', 'eg', 'sv', 'gq', 'er', 'ee', 'et', 'fk', 'fo',
|
||||
'fj', 'fi', 'fr', 'gf', 'pf', 'tf', 'ga', 'gm', 'ge', 'de', 'gh', 'gi', 'gr', 'gl',
|
||||
'gd', 'gp', 'gu', 'gt', 'gn', 'gw', 'gy', 'ht', 'hm', 'va', 'hn', 'hk', 'hu', 'is',
|
||||
'in', 'id', 'ir', 'iq', 'ie', 'il', 'it', 'jm', 'jp', 'jo', 'kz', 'ke', 'ki', 'kp',
|
||||
'kr', 'kw', 'kg', 'la', 'lv', 'lb', 'ls', 'lr', 'ly', 'li', 'lt', 'lu', 'mo', 'mk',
|
||||
'mg', 'mw', 'my', 'mv', 'ml', 'mt', 'mh', 'mq', 'mr', 'mu', 'yt', 'mx', 'fm', 'md',
|
||||
'mc', 'mn', 'ms', 'ma', 'mz', 'mm', 'na', 'nr', 'np', 'nl', 'an', 'nc', 'nz', 'ni',
|
||||
'ne', 'ng', 'nu', 'nf', 'mp', 'no', 'om', 'pk', 'pw', 'ps', 'pa', 'pg', 'py', 'pe',
|
||||
'ph', 'pn', 'pl', 'pt', 'pr', 'qa', 're', 'ro', 'ru', 'rw', 'sh', 'kn', 'lc', 'pm',
|
||||
'vc', 'ws', 'sm', 'st', 'sa', 'sn', 'rs', 'sc', 'sl', 'sg', 'sk', 'si', 'sb', 'so',
|
||||
'za', 'gs', 'es', 'lk', 'sd', 'sr', 'sj', 'sz', 'se', 'ch', 'sy', 'tw', 'tj', 'tz',
|
||||
'th', 'tl', 'tg', 'tk', 'to', 'tt', 'tn', 'tr', 'tm', 'tc', 'tv', 'ug', 'ua', 'ae',
|
||||
'uk', 'gb', 'us', 'um', 'uy', 'uz', 'vu', 've', 'vn', 'vg', 'vi', 'wf', 'eh', 'ye',
|
||||
'zm', 'zw'])
|
||||
self.check_valid_value(self.language, "Google languages",
|
||||
['af', 'ak', 'sq', 'ws', 'am', 'ar', 'hy', 'az', 'eu', 'be', 'bem', 'bn', 'bh',
|
||||
'xx-bork', 'bs', 'br', 'bg', 'bt', 'km', 'ca', 'chr', 'ny', 'zh-cn', 'zh-tw', 'co',
|
||||
'hr', 'cs', 'da', 'nl', 'xx-elmer', 'en', 'eo', 'et', 'ee', 'fo', 'tl', 'fi', 'fr',
|
||||
'fy', 'gaa', 'gl', 'ka', 'de', 'el', 'kl', 'gn', 'gu', 'xx-hacker', 'ht', 'ha', 'haw',
|
||||
'iw', 'hi', 'hu', 'is', 'ig', 'id', 'ia', 'ga', 'it', 'ja', 'jw', 'kn', 'kk', 'rw',
|
||||
'rn', 'xx-klingon', 'kg', 'ko', 'kri', 'ku', 'ckb', 'ky', 'lo', 'la', 'lv', 'ln', 'lt',
|
||||
'loz', 'lg', 'ach', 'mk', 'mg', 'ms', 'ml', 'mt', 'mv', 'mi', 'mr', 'mfe', 'mo', 'mn',
|
||||
'sr-me', 'my', 'ne', 'pcm', 'nso', 'no', 'nn', 'oc', 'or', 'om', 'ps', 'fa',
|
||||
'xx-pirate', 'pl', 'pt', 'pt-br', 'pt-pt', 'pa', 'qu', 'ro', 'rm', 'nyn', 'ru', 'gd',
|
||||
'sr', 'sh', 'st', 'tn', 'crs', 'sn', 'sd', 'si', 'sk', 'sl', 'so', 'es', 'es-419', 'su',
|
||||
'sw', 'sv', 'tg', 'ta', 'tt', 'te', 'th', 'ti', 'to', 'lua', 'tum', 'tr', 'tk', 'tw',
|
||||
'ug', 'uk', 'ur', 'uz', 'vu', 'vi', 'cy', 'wo', 'xh', 'yi', 'yo', 'zu']
|
||||
)
|
||||
|
||||
|
||||
class Google(ComponentBase, ABC):
|
||||
component_name = "Google"
|
||||
|
||||
def _run(self, history, **kwargs):
|
||||
ans = self.get_input()
|
||||
ans = " - ".join(ans["content"]) if "content" in ans else ""
|
||||
if not ans:
|
||||
return Google.be_output("")
|
||||
|
||||
try:
|
||||
client = GoogleSearch(
|
||||
{"engine": "google", "q": ans, "api_key": self._param.api_key, "gl": self._param.country,
|
||||
"hl": self._param.language, "num": self._param.top_n})
|
||||
google_res = [{"content": '<a href="' + i["link"] + '">' + i["title"] + '</a> ' + i["snippet"]} for i in
|
||||
client.get_dict()["organic_results"]]
|
||||
except Exception as e:
|
||||
return Google.be_output("**ERROR**: Existing Unavailable Parameters!")
|
||||
|
||||
if not google_res:
|
||||
return Google.be_output("")
|
||||
|
||||
df = pd.DataFrame(google_res)
|
||||
if DEBUG: print(df, ":::::::::::::::::::::::::::::::::")
|
||||
return df
|
||||
70
agent/component/googlescholar.py
Normal file
70
agent/component/googlescholar.py
Normal file
@ -0,0 +1,70 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
import pandas as pd
|
||||
from agent.settings import DEBUG
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
from scholarly import scholarly
|
||||
|
||||
|
||||
class GoogleScholarParam(ComponentParamBase):
|
||||
"""
|
||||
Define the GoogleScholar component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.top_n = 6
|
||||
self.sort_by = 'relevance'
|
||||
self.year_low = None
|
||||
self.year_high = None
|
||||
self.patents = True
|
||||
|
||||
def check(self):
|
||||
self.check_positive_integer(self.top_n, "Top N")
|
||||
self.check_valid_value(self.sort_by, "GoogleScholar Sort_by", ['date', 'relevance'])
|
||||
self.check_boolean(self.patents, "Whether or not to include patents, defaults to True")
|
||||
|
||||
|
||||
class GoogleScholar(ComponentBase, ABC):
|
||||
component_name = "GoogleScholar"
|
||||
|
||||
def _run(self, history, **kwargs):
|
||||
ans = self.get_input()
|
||||
ans = " - ".join(ans["content"]) if "content" in ans else ""
|
||||
if not ans:
|
||||
return GoogleScholar.be_output("")
|
||||
|
||||
scholar_client = scholarly.search_pubs(ans, patents=self._param.patents, year_low=self._param.year_low,
|
||||
year_high=self._param.year_high, sort_by=self._param.sort_by)
|
||||
scholar_res = []
|
||||
for i in range(self._param.top_n):
|
||||
try:
|
||||
pub = next(scholar_client)
|
||||
scholar_res.append({"content": 'Title: ' + pub['bib']['title'] + '\n_Url: <a href="' + pub[
|
||||
'pub_url'] + '"></a> ' + "\n author: " + ",".join(pub['bib']['author']) + '\n Abstract: ' + pub[
|
||||
'bib'].get('abstract', 'no abstract')})
|
||||
|
||||
except StopIteration or Exception as e:
|
||||
print("**ERROR** " + str(e))
|
||||
break
|
||||
|
||||
if not scholar_res:
|
||||
return GoogleScholar.be_output("")
|
||||
|
||||
df = pd.DataFrame(scholar_res)
|
||||
if DEBUG: print(df, ":::::::::::::::::::::::::::::::::")
|
||||
return df
|
||||
@ -17,8 +17,8 @@ import re
|
||||
from abc import ABC
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from graph.component import GenerateParam, Generate
|
||||
from graph.settings import DEBUG
|
||||
from agent.component import GenerateParam, Generate
|
||||
from agent.settings import DEBUG
|
||||
|
||||
|
||||
class KeywordExtractParam(GenerateParam):
|
||||
@ -16,10 +16,7 @@
|
||||
import random
|
||||
from abc import ABC
|
||||
from functools import partial
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graph.component.base import ComponentBase, ComponentParamBase
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
|
||||
|
||||
class MessageParam(ComponentParamBase):
|
||||
@ -46,7 +43,11 @@ class Message(ComponentBase, ABC):
|
||||
return Message.be_output(random.choice(self._param.messages))
|
||||
|
||||
def stream_output(self):
|
||||
res = None
|
||||
if self._param.messages:
|
||||
yield {"content": random.choice(self._param.messages)}
|
||||
res = {"content": random.choice(self._param.messages)}
|
||||
yield res
|
||||
|
||||
self.set_output(res)
|
||||
|
||||
|
||||
65
agent/component/pubmed.py
Normal file
65
agent/component/pubmed.py
Normal file
@ -0,0 +1,65 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
from Bio import Entrez
|
||||
import pandas as pd
|
||||
import xml.etree.ElementTree as ET
|
||||
from agent.settings import DEBUG
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
|
||||
|
||||
class PubMedParam(ComponentParamBase):
|
||||
"""
|
||||
Define the PubMed component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.top_n = 5
|
||||
self.email = "A.N.Other@example.com"
|
||||
|
||||
def check(self):
|
||||
self.check_positive_integer(self.top_n, "Top N")
|
||||
|
||||
|
||||
class PubMed(ComponentBase, ABC):
|
||||
component_name = "PubMed"
|
||||
|
||||
def _run(self, history, **kwargs):
|
||||
ans = self.get_input()
|
||||
ans = " - ".join(ans["content"]) if "content" in ans else ""
|
||||
if not ans:
|
||||
return PubMed.be_output("")
|
||||
|
||||
try:
|
||||
Entrez.email = self._param.email
|
||||
pubmedids = Entrez.read(Entrez.esearch(db='pubmed', retmax=self._param.top_n, term=ans))['IdList']
|
||||
pubmedcnt = ET.fromstring(
|
||||
Entrez.efetch(db='pubmed', id=",".join(pubmedids), retmode="xml").read().decode("utf-8"))
|
||||
pubmed_res = [{"content": 'Title:' + child.find("MedlineCitation").find("Article").find(
|
||||
"ArticleTitle").text + '\nUrl:<a href=" https://pubmed.ncbi.nlm.nih.gov/' + child.find(
|
||||
"MedlineCitation").find("PMID").text + '">' + '</a>\n' + 'Abstract:' + child.find(
|
||||
"MedlineCitation").find("Article").find("Abstract").find("AbstractText").text} for child in
|
||||
pubmedcnt.findall("PubmedArticle")]
|
||||
except Exception as e:
|
||||
return PubMed.be_output("**ERROR**: " + str(e))
|
||||
|
||||
if not pubmed_res:
|
||||
return PubMed.be_output("")
|
||||
|
||||
df = pd.DataFrame(pubmed_res)
|
||||
if DEBUG: print(df, ":::::::::::::::::::::::::::::::::")
|
||||
return df
|
||||
@ -16,7 +16,7 @@
|
||||
from abc import ABC
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from graph.component import GenerateParam, Generate
|
||||
from agent.component import GenerateParam, Generate
|
||||
from rag.utils import num_tokens_from_string, encoder
|
||||
|
||||
|
||||
@ -21,7 +21,7 @@ from api.db import LLMType
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.settings import retrievaler
|
||||
from graph.component.base import ComponentBase, ComponentParamBase
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
|
||||
|
||||
class RetrievalParam(ComponentParamBase):
|
||||
@ -16,7 +16,7 @@
|
||||
from abc import ABC
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from graph.component import GenerateParam, Generate
|
||||
from agent.component import GenerateParam, Generate
|
||||
|
||||
|
||||
class RewriteQuestionParam(GenerateParam):
|
||||
@ -16,12 +16,7 @@
|
||||
from abc import ABC
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from api.db import LLMType
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.settings import retrievaler
|
||||
from graph.component.base import ComponentBase, ComponentParamBase
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
|
||||
|
||||
class SwitchParam(ComponentParamBase):
|
||||
69
agent/component/wikipedia.py
Normal file
69
agent/component/wikipedia.py
Normal file
@ -0,0 +1,69 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import random
|
||||
from abc import ABC
|
||||
from functools import partial
|
||||
import wikipedia
|
||||
import pandas as pd
|
||||
from agent.settings import DEBUG
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
|
||||
|
||||
class WikipediaParam(ComponentParamBase):
|
||||
"""
|
||||
Define the Wikipedia component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.top_n = 10
|
||||
self.language = "en"
|
||||
|
||||
def check(self):
|
||||
self.check_positive_integer(self.top_n, "Top N")
|
||||
self.check_valid_value(self.language, "Wikipedia languages",
|
||||
['af', 'pl', 'ar', 'ast', 'az', 'bg', 'nan', 'bn', 'be', 'ca', 'cs', 'cy', 'da', 'de',
|
||||
'et', 'el', 'en', 'es', 'eo', 'eu', 'fa', 'fr', 'gl', 'ko', 'hy', 'hi', 'hr', 'id',
|
||||
'it', 'he', 'ka', 'lld', 'la', 'lv', 'lt', 'hu', 'mk', 'arz', 'ms', 'min', 'my', 'nl',
|
||||
'ja', 'nb', 'nn', 'ce', 'uz', 'pt', 'kk', 'ro', 'ru', 'ceb', 'sk', 'sl', 'sr', 'sh',
|
||||
'fi', 'sv', 'ta', 'tt', 'th', 'tg', 'azb', 'tr', 'uk', 'ur', 'vi', 'war', 'zh', 'yue'])
|
||||
|
||||
|
||||
class Wikipedia(ComponentBase, ABC):
|
||||
component_name = "Wikipedia"
|
||||
|
||||
def _run(self, history, **kwargs):
|
||||
ans = self.get_input()
|
||||
ans = " - ".join(ans["content"]) if "content" in ans else ""
|
||||
if not ans:
|
||||
return Wikipedia.be_output("")
|
||||
|
||||
try:
|
||||
wiki_res = []
|
||||
wikipedia.set_lang(self._param.language)
|
||||
wiki_engine = wikipedia
|
||||
for wiki_key in wiki_engine.search(ans, results=self._param.top_n):
|
||||
page = wiki_engine.page(title=wiki_key, auto_suggest=False)
|
||||
wiki_res.append({"content": '<a href="' + page.url + '">' + page.title + '</a> ' + page.summary})
|
||||
except Exception as e:
|
||||
return Wikipedia.be_output("**ERROR**: " + str(e))
|
||||
|
||||
if not wiki_res:
|
||||
return Wikipedia.be_output("")
|
||||
|
||||
df = pd.DataFrame(wiki_res)
|
||||
if DEBUG: print(df, ":::::::::::::::::::::::::::::::::")
|
||||
return df
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
572
agent/templates/websearch_assistant.json
Normal file
572
agent/templates/websearch_assistant.json
Normal file
File diff suppressed because one or more lines are too long
@ -16,9 +16,8 @@
|
||||
import argparse
|
||||
import os
|
||||
from functools import partial
|
||||
import readline
|
||||
from graph.canvas import Canvas
|
||||
from graph.settings import DEBUG
|
||||
from agent.canvas import Canvas
|
||||
from agent.settings import DEBUG
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
62
agent/test/dsl_examples/keyword_wikipedia_and_generate.json
Normal file
62
agent/test/dsl_examples/keyword_wikipedia_and_generate.json
Normal file
@ -0,0 +1,62 @@
|
||||
{
|
||||
"components": {
|
||||
"begin": {
|
||||
"obj":{
|
||||
"component_name": "Begin",
|
||||
"params": {
|
||||
"prologue": "Hi there!"
|
||||
}
|
||||
},
|
||||
"downstream": ["answer:0"],
|
||||
"upstream": []
|
||||
},
|
||||
"answer:0": {
|
||||
"obj": {
|
||||
"component_name": "Answer",
|
||||
"params": {}
|
||||
},
|
||||
"downstream": ["keyword:0"],
|
||||
"upstream": ["begin"]
|
||||
},
|
||||
"keyword:0": {
|
||||
"obj": {
|
||||
"component_name": "KeywordExtract",
|
||||
"params": {
|
||||
"llm_id": "deepseek-chat",
|
||||
"prompt": "- Role: You're a question analyzer.\n - Requirements:\n - Summarize user's question, and give top %s important keyword/phrase.\n - Use comma as a delimiter to separate keywords/phrases.\n - Answer format: (in language of user's question)\n - keyword: ",
|
||||
"temperature": 0.2,
|
||||
"top_n": 1
|
||||
}
|
||||
},
|
||||
"downstream": ["wikipedia:0"],
|
||||
"upstream": ["answer:0"]
|
||||
},
|
||||
"wikipedia:0": {
|
||||
"obj":{
|
||||
"component_name": "Wikipedia",
|
||||
"params": {
|
||||
"top_n": 10
|
||||
}
|
||||
},
|
||||
"downstream": ["generate:0"],
|
||||
"upstream": ["keyword:0"]
|
||||
},
|
||||
"generate:1": {
|
||||
"obj": {
|
||||
"component_name": "Generate",
|
||||
"params": {
|
||||
"llm_id": "deepseek-chat",
|
||||
"prompt": "You are an intelligent assistant. Please answer the question based on content from Wikipedia. When the answer from Wikipedia is incomplete, you need to output the URL link of the corresponding content as well. When all the content searched from Wikipedia is irrelevant to the question, your answer must include the sentence, \"The answer you are looking for is not found in the Wikipedia!\". Answers need to consider chat history.\n The content of Wikipedia is as follows:\n {input}\n The above is the content of Wikipedia.",
|
||||
"temperature": 0.2
|
||||
}
|
||||
},
|
||||
"downstream": ["answer:0"],
|
||||
"upstream": ["wikipedia:0"]
|
||||
}
|
||||
},
|
||||
"history": [],
|
||||
"path": [],
|
||||
"messages": [],
|
||||
"reference": {},
|
||||
"answer": []
|
||||
}
|
||||
@ -25,7 +25,7 @@ from flask_cors import CORS
|
||||
from api.db import StatusEnum
|
||||
from api.db.db_models import close_connection
|
||||
from api.db.services import UserService
|
||||
from api.utils import CustomJSONEncoder
|
||||
from api.utils import CustomJSONEncoder, commands
|
||||
|
||||
from flask_session import Session
|
||||
from flask_login import LoginManager
|
||||
@ -60,6 +60,7 @@ Session(app)
|
||||
login_manager = LoginManager()
|
||||
login_manager.init_app(app)
|
||||
|
||||
commands.register_commands(app)
|
||||
|
||||
|
||||
def search_pages_path(pages_dir):
|
||||
|
||||
@ -199,9 +199,13 @@ def completion():
|
||||
conv.message[-1] = {"role": "assistant", "content": ans["answer"]}
|
||||
|
||||
def rename_field(ans):
|
||||
for chunk_i in ans['reference'].get('chunks', []):
|
||||
chunk_i['doc_name'] = chunk_i['docnm_kwd']
|
||||
chunk_i.pop('docnm_kwd')
|
||||
reference = ans['reference']
|
||||
if not isinstance(reference, dict):
|
||||
return
|
||||
for chunk_i in reference.get('chunks', []):
|
||||
if 'docnm_kwd' in chunk_i:
|
||||
chunk_i['doc_name'] = chunk_i['docnm_kwd']
|
||||
chunk_i.pop('docnm_kwd')
|
||||
|
||||
def stream():
|
||||
nonlocal dia, msg, req, conv
|
||||
@ -232,10 +236,7 @@ def completion():
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
break
|
||||
|
||||
for chunk_i in answer['reference'].get('chunks',[]):
|
||||
chunk_i['doc_name'] = chunk_i['docnm_kwd']
|
||||
chunk_i.pop('docnm_kwd')
|
||||
|
||||
rename_field(answer)
|
||||
return get_json_result(data=answer)
|
||||
|
||||
except Exception as e:
|
||||
@ -252,6 +253,8 @@ def get(conversation_id):
|
||||
|
||||
conv = conv.to_dict()
|
||||
for referenct_i in conv['reference']:
|
||||
if referenct_i is None or len(referenct_i) == 0:
|
||||
continue
|
||||
for chunk_i in referenct_i['chunks']:
|
||||
if 'docnm_kwd' in chunk_i.keys():
|
||||
chunk_i['doc_name'] = chunk_i['docnm_kwd']
|
||||
@ -335,6 +338,8 @@ def upload():
|
||||
doc["parser_id"] = request.form.get("parser_id").strip()
|
||||
if doc["type"] == FileType.VISUAL:
|
||||
doc["parser_id"] = ParserType.PICTURE.value
|
||||
if doc["type"] == FileType.AURAL:
|
||||
doc["parser_id"] = ParserType.AUDIO.value
|
||||
if re.search(r"\.(ppt|pptx|pages)$", filename):
|
||||
doc["parser_id"] = ParserType.PRESENTATION.value
|
||||
|
||||
@ -573,6 +578,7 @@ def completion_faq():
|
||||
response = MINIO.get(bkt, nm)
|
||||
data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
|
||||
data.append(data_type_picture)
|
||||
break
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -580,4 +586,56 @@ def completion_faq():
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/retrieval', methods=['POST'])
|
||||
@validate_request("kb_id", "question")
|
||||
def retrieval():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
req = request.json
|
||||
kb_id = req.get("kb_id")
|
||||
doc_ids = req.get("doc_ids", [])
|
||||
question = req.get("question")
|
||||
page = int(req.get("page", 1))
|
||||
size = int(req.get("size", 30))
|
||||
similarity_threshold = float(req.get("similarity_threshold", 0.2))
|
||||
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
|
||||
top = int(req.get("top_k", 1024))
|
||||
|
||||
try:
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Knowledgebase not found!")
|
||||
|
||||
embd_mdl = TenantLLMService.model_instance(
|
||||
kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
||||
|
||||
rerank_mdl = None
|
||||
if req.get("rerank_id"):
|
||||
rerank_mdl = TenantLLMService.model_instance(
|
||||
kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
|
||||
|
||||
if req.get("keyword", False):
|
||||
chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
|
||||
question += keyword_extraction(chat_mdl, question)
|
||||
|
||||
ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size,
|
||||
similarity_threshold, vector_similarity_weight, top,
|
||||
doc_ids, rerank_mdl=rerank_mdl)
|
||||
for c in ranks["chunks"]:
|
||||
if "vector" in c:
|
||||
del c["vector"]
|
||||
|
||||
return get_json_result(data=ranks)
|
||||
except Exception as e:
|
||||
if str(e).find("not_found") > 0:
|
||||
return get_json_result(data=False, retmsg=f'No chunk found! Check the chunk status please!',
|
||||
retcode=RetCode.DATA_ERROR)
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@ -15,15 +15,12 @@
|
||||
#
|
||||
import json
|
||||
from functools import partial
|
||||
|
||||
from flask import request, Response
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from api.db.db_models import UserCanvas
|
||||
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import get_json_result, server_error_response, validate_request
|
||||
from graph.canvas import Canvas
|
||||
from agent.canvas import Canvas
|
||||
|
||||
|
||||
@manager.route('/templates', methods=['GET'])
|
||||
@ -103,10 +100,10 @@ def run():
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
assert answer, "Nothing. Is it over?"
|
||||
assert answer is not None, "Nothing. Is it over?"
|
||||
|
||||
if stream:
|
||||
assert isinstance(answer, partial)
|
||||
assert isinstance(answer, partial), "Nothing. Is it over?"
|
||||
|
||||
def sse():
|
||||
nonlocal answer, cvs
|
||||
@ -135,12 +132,13 @@ def run():
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
return resp
|
||||
|
||||
final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
|
||||
canvas.messages.append({"role": "assistant", "content": final_ans["content"]})
|
||||
if final_ans.get("reference"):
|
||||
canvas.reference.append(final_ans["reference"])
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
UserCanvasService.update_by_id(req["id"], cvs.to_dict())
|
||||
return get_json_result(data=req["dsl"])
|
||||
return get_json_result(data={"answer": final_ans["content"], "reference": final_ans.get("reference", [])})
|
||||
|
||||
|
||||
@manager.route('/reset', methods=['POST'])
|
||||
|
||||
@ -14,6 +14,8 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import datetime
|
||||
import json
|
||||
import traceback
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
@ -29,7 +31,7 @@ from api.db.services.llm_service import TenantLLMService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.settings import RetCode, retrievaler
|
||||
from api.settings import RetCode, retrievaler, kg_retrievaler
|
||||
from api.utils.api_utils import get_json_result
|
||||
import hashlib
|
||||
import re
|
||||
@ -61,7 +63,8 @@ def list_chunk():
|
||||
for id in sres.ids:
|
||||
d = {
|
||||
"chunk_id": id,
|
||||
"content_with_weight": rmSpace(sres.highlight[id]) if question and id in sres.highlight else sres.field[id].get(
|
||||
"content_with_weight": rmSpace(sres.highlight[id]) if question and id in sres.highlight else sres.field[
|
||||
id].get(
|
||||
"content_with_weight", ""),
|
||||
"doc_id": sres.field[id]["doc_id"],
|
||||
"docnm_kwd": sres.field[id]["docnm_kwd"],
|
||||
@ -136,11 +139,11 @@ def set():
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(retmsg="Tenant not found!")
|
||||
|
||||
|
||||
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
||||
embd_mdl = TenantLLMService.model_instance(
|
||||
tenant_id, LLMType.EMBEDDING.value, embd_id)
|
||||
|
||||
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Document not found!")
|
||||
@ -185,13 +188,19 @@ def switch():
|
||||
|
||||
@manager.route('/rm', methods=['POST'])
|
||||
@login_required
|
||||
@validate_request("chunk_ids")
|
||||
@validate_request("chunk_ids", "doc_id")
|
||||
def rm():
|
||||
req = request.json
|
||||
try:
|
||||
if not ELASTICSEARCH.deleteByQuery(
|
||||
Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)):
|
||||
return get_data_error_result(retmsg="Index updating failure")
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Document not found!")
|
||||
deleted_chunk_ids = req["chunk_ids"]
|
||||
chunk_number = len(deleted_chunk_ids)
|
||||
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -224,13 +233,12 @@ def create():
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(retmsg="Tenant not found!")
|
||||
|
||||
|
||||
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
||||
embd_mdl = TenantLLMService.model_instance(
|
||||
tenant_id, LLMType.EMBEDDING.value, embd_id)
|
||||
|
||||
|
||||
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
|
||||
DocumentService.increment_chunk_num(req["doc_id"], doc.kb_id, c, 1, 0)
|
||||
v = 0.1 * v[0] + 0.9 * v[1]
|
||||
d["q_%d_vec" % len(v)] = v.tolist()
|
||||
ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
|
||||
@ -272,9 +280,10 @@ def retrieval_test():
|
||||
chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
|
||||
question += keyword_extraction(chat_mdl, question)
|
||||
|
||||
ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size,
|
||||
similarity_threshold, vector_similarity_weight, top,
|
||||
doc_ids, rerank_mdl=rerank_mdl)
|
||||
retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
|
||||
ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size,
|
||||
similarity_threshold, vector_similarity_weight, top,
|
||||
doc_ids, rerank_mdl=rerank_mdl)
|
||||
for c in ranks["chunks"]:
|
||||
if "vector" in c:
|
||||
del c["vector"]
|
||||
@ -285,3 +294,25 @@ def retrieval_test():
|
||||
return get_json_result(data=False, retmsg=f'No chunk found! Check the chunk status please!',
|
||||
retcode=RetCode.DATA_ERROR)
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/knowledge_graph', methods=['GET'])
|
||||
@login_required
|
||||
def knowledge_graph():
|
||||
doc_id = request.args["doc_id"]
|
||||
req = {
|
||||
"doc_ids":[doc_id],
|
||||
"knowledge_graph_kwd": ["graph", "mind_map"]
|
||||
}
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
sres = retrievaler.search(req, search.index_name(tenant_id))
|
||||
obj = {"graph": {}, "mind_map": {}}
|
||||
for id in sres.ids[:2]:
|
||||
ty = sres.field[id]["knowledge_graph_kwd"]
|
||||
try:
|
||||
obj[ty] = json.loads(sres.field[id]["content_with_weight"])
|
||||
except Exception as e:
|
||||
print(traceback.format_exc(), flush=True)
|
||||
|
||||
return get_json_result(data=obj)
|
||||
|
||||
|
||||
@ -16,15 +16,16 @@ import os
|
||||
import pathlib
|
||||
import re
|
||||
import warnings
|
||||
from functools import partial
|
||||
from io import BytesIO
|
||||
|
||||
from elasticsearch_dsl import Q
|
||||
from flask import request, send_file
|
||||
from flask_login import login_required, current_user
|
||||
from httpx import HTTPError
|
||||
from minio import S3Error
|
||||
|
||||
from api.contants import NAME_LENGTH_LIMIT
|
||||
from api.db import FileType, ParserType, FileSource
|
||||
from api.db import FileType, ParserType, FileSource, TaskStatus
|
||||
from api.db import StatusEnum
|
||||
from api.db.db_models import File
|
||||
from api.db.services import duplicate_name
|
||||
@ -38,10 +39,14 @@ from api.utils import get_uuid
|
||||
from api.utils.api_utils import construct_json_result, construct_error_response
|
||||
from api.utils.api_utils import construct_result, validate_request
|
||||
from api.utils.file_utils import filename_type, thumbnail
|
||||
from rag.app import book, laws, manual, naive, one, paper, presentation, qa, resume, table, picture, audio
|
||||
from rag.nlp import search
|
||||
from rag.utils.es_conn import ELASTICSEARCH
|
||||
from rag.utils.minio_conn import MINIO
|
||||
|
||||
MAXIMUM_OF_UPLOADING_FILES = 256
|
||||
|
||||
|
||||
# ------------------------------ create a dataset ---------------------------------------
|
||||
|
||||
@manager.route("/", methods=["POST"])
|
||||
@ -116,6 +121,7 @@ def create_dataset():
|
||||
except Exception as e:
|
||||
return construct_error_response(e)
|
||||
|
||||
|
||||
# -----------------------------list datasets-------------------------------------------------------
|
||||
|
||||
@manager.route("/", methods=["GET"])
|
||||
@ -135,6 +141,7 @@ def list_datasets():
|
||||
except HTTPError as http_err:
|
||||
return construct_json_result(http_err)
|
||||
|
||||
|
||||
# ---------------------------------delete a dataset ----------------------------
|
||||
|
||||
@manager.route("/<dataset_id>", methods=["DELETE"])
|
||||
@ -162,13 +169,15 @@ def remove_dataset(dataset_id):
|
||||
|
||||
# delete the dataset
|
||||
if not KnowledgebaseService.delete_by_id(dataset_id):
|
||||
return construct_json_result(code=RetCode.DATA_ERROR, message="There was an error during the dataset removal process. "
|
||||
"Please check the status of the RAGFlow server and try the removal again.")
|
||||
return construct_json_result(code=RetCode.DATA_ERROR,
|
||||
message="There was an error during the dataset removal process. "
|
||||
"Please check the status of the RAGFlow server and try the removal again.")
|
||||
# success
|
||||
return construct_json_result(code=RetCode.SUCCESS, message=f"Remove dataset: {dataset_id} successfully")
|
||||
except Exception as e:
|
||||
return construct_error_response(e)
|
||||
|
||||
|
||||
# ------------------------------ get details of a dataset ----------------------------------------
|
||||
|
||||
@manager.route("/<dataset_id>", methods=["GET"])
|
||||
@ -182,6 +191,7 @@ def get_dataset(dataset_id):
|
||||
except Exception as e:
|
||||
return construct_json_result(e)
|
||||
|
||||
|
||||
# ------------------------------ update a dataset --------------------------------------------
|
||||
|
||||
@manager.route("/<dataset_id>", methods=["PUT"])
|
||||
@ -209,8 +219,9 @@ def update_dataset(dataset_id):
|
||||
if name.lower() != dataset.name.lower() \
|
||||
and len(KnowledgebaseService.query(name=name, tenant_id=current_user.id,
|
||||
status=StatusEnum.VALID.value)) > 1:
|
||||
return construct_json_result(code=RetCode.DATA_ERROR, message=f"The name: {name.lower()} is already used by other "
|
||||
f"datasets. Please choose a different name.")
|
||||
return construct_json_result(code=RetCode.DATA_ERROR,
|
||||
message=f"The name: {name.lower()} is already used by other "
|
||||
f"datasets. Please choose a different name.")
|
||||
|
||||
dataset_updating_data = {}
|
||||
chunk_num = req.get("chunk_num")
|
||||
@ -222,17 +233,22 @@ def update_dataset(dataset_id):
|
||||
if chunk_num == 0:
|
||||
dataset_updating_data["embd_id"] = req["embedding_model_id"]
|
||||
else:
|
||||
construct_json_result(code=RetCode.DATA_ERROR, message="You have already parsed the document in this "
|
||||
"dataset, so you cannot change the embedding "
|
||||
"model.")
|
||||
return construct_json_result(code=RetCode.DATA_ERROR,
|
||||
message="You have already parsed the document in this "
|
||||
"dataset, so you cannot change the embedding "
|
||||
"model.")
|
||||
# only if chunk_num is 0, the user can update the chunk_method
|
||||
if req.get("chunk_method"):
|
||||
if chunk_num == 0:
|
||||
dataset_updating_data['parser_id'] = req["chunk_method"]
|
||||
else:
|
||||
if "chunk_method" in req:
|
||||
type_value = req["chunk_method"]
|
||||
if is_illegal_value_for_enum(type_value, ParserType):
|
||||
return construct_json_result(message=f"Illegal value {type_value} for 'chunk_method' field.",
|
||||
code=RetCode.DATA_ERROR)
|
||||
if chunk_num != 0:
|
||||
construct_json_result(code=RetCode.DATA_ERROR, message="You have already parsed the document "
|
||||
"in this dataset, so you cannot "
|
||||
"change the chunk method.")
|
||||
dataset_updating_data["parser_id"] = req["template_type"]
|
||||
|
||||
# convert the photo parameter to avatar
|
||||
if req.get("photo"):
|
||||
dataset_updating_data["avatar"] = req["photo"]
|
||||
@ -265,6 +281,7 @@ def update_dataset(dataset_id):
|
||||
except Exception as e:
|
||||
return construct_error_response(e)
|
||||
|
||||
|
||||
# --------------------------------content management ----------------------------------------------
|
||||
|
||||
# ----------------------------upload files-----------------------------------------------------
|
||||
@ -339,9 +356,10 @@ def upload_documents(dataset_id):
|
||||
location += "_"
|
||||
|
||||
blob = file.read()
|
||||
|
||||
# the content is empty, raising a warning
|
||||
if blob == b'':
|
||||
warnings.warn(f"[WARNING]: The file {filename} is empty.")
|
||||
warnings.warn(f"[WARNING]: The content of the file {filename} is empty.")
|
||||
|
||||
MINIO.put(dataset_id, location, blob)
|
||||
|
||||
@ -359,6 +377,8 @@ def upload_documents(dataset_id):
|
||||
}
|
||||
if doc["type"] == FileType.VISUAL:
|
||||
doc["parser_id"] = ParserType.PICTURE.value
|
||||
if doc["type"] == FileType.AURAL:
|
||||
doc["parser_id"] = ParserType.AUDIO.value
|
||||
if re.search(r"\.(ppt|pptx|pages)$", filename):
|
||||
doc["parser_id"] = ParserType.PRESENTATION.value
|
||||
DocumentService.insert(doc)
|
||||
@ -453,6 +473,7 @@ def list_documents(dataset_id):
|
||||
except Exception as e:
|
||||
return construct_error_response(e)
|
||||
|
||||
|
||||
# ----------------------------update: enable rename-----------------------------------------------------
|
||||
@manager.route("/<dataset_id>/documents/<document_id>", methods=["PUT"])
|
||||
@login_required
|
||||
@ -555,6 +576,7 @@ def update_document(dataset_id, document_id):
|
||||
def is_illegal_value_for_enum(value, enum_class):
|
||||
return value not in enum_class.__members__.values()
|
||||
|
||||
|
||||
# ----------------------------download a file-----------------------------------------------------
|
||||
@manager.route("/<dataset_id>/documents/<document_id>", methods=["GET"])
|
||||
@login_required
|
||||
@ -563,7 +585,8 @@ def download_document(dataset_id, document_id):
|
||||
# Check whether there is this dataset
|
||||
exist, _ = KnowledgebaseService.get_by_id(dataset_id)
|
||||
if not exist:
|
||||
return construct_json_result(code=RetCode.DATA_ERROR, message=f"This dataset '{dataset_id}' cannot be found!")
|
||||
return construct_json_result(code=RetCode.DATA_ERROR,
|
||||
message=f"This dataset '{dataset_id}' cannot be found!")
|
||||
|
||||
# Check whether there is this document
|
||||
exist, document = DocumentService.get_by_id(document_id)
|
||||
@ -591,11 +614,252 @@ def download_document(dataset_id, document_id):
|
||||
except Exception as e:
|
||||
return construct_error_response(e)
|
||||
|
||||
# ----------------------------start parsing-----------------------------------------------------
|
||||
|
||||
# ----------------------------stop parsing-----------------------------------------------------
|
||||
# ----------------------------start parsing a document-----------------------------------------------------
|
||||
# helper method for parsing
|
||||
# callback method
|
||||
def doc_parse_callback(doc_id, prog=None, msg=""):
|
||||
cancel = DocumentService.do_cancel(doc_id)
|
||||
if cancel:
|
||||
raise Exception("The parsing process has been cancelled!")
|
||||
|
||||
"""
|
||||
def doc_parse(binary, doc_name, parser_name, tenant_id, doc_id):
|
||||
match parser_name:
|
||||
case "book":
|
||||
book.chunk(doc_name, binary=binary, callback=partial(doc_parse_callback, doc_id))
|
||||
case "laws":
|
||||
laws.chunk(doc_name, binary=binary, callback=partial(doc_parse_callback, doc_id))
|
||||
case "manual":
|
||||
manual.chunk(doc_name, binary=binary, callback=partial(doc_parse_callback, doc_id))
|
||||
case "naive":
|
||||
# It's the mode by default, which is general in the front-end
|
||||
naive.chunk(doc_name, binary=binary, callback=partial(doc_parse_callback, doc_id))
|
||||
case "one":
|
||||
one.chunk(doc_name, binary=binary, callback=partial(doc_parse_callback, doc_id))
|
||||
case "paper":
|
||||
paper.chunk(doc_name, binary=binary, callback=partial(doc_parse_callback, doc_id))
|
||||
case "picture":
|
||||
picture.chunk(doc_name, binary=binary, tenant_id=tenant_id, lang="Chinese",
|
||||
callback=partial(doc_parse_callback, doc_id))
|
||||
case "presentation":
|
||||
presentation.chunk(doc_name, binary=binary, callback=partial(doc_parse_callback, doc_id))
|
||||
case "qa":
|
||||
qa.chunk(doc_name, binary=binary, callback=partial(doc_parse_callback, doc_id))
|
||||
case "resume":
|
||||
resume.chunk(doc_name, binary=binary, callback=partial(doc_parse_callback, doc_id))
|
||||
case "table":
|
||||
table.chunk(doc_name, binary=binary, callback=partial(doc_parse_callback, doc_id))
|
||||
case "audio":
|
||||
audio.chunk(doc_name, binary=binary, callback=partial(doc_parse_callback, doc_id))
|
||||
case _:
|
||||
return False
|
||||
|
||||
return True
|
||||
"""
|
||||
|
||||
|
||||
@manager.route("/<dataset_id>/documents/<document_id>/status", methods=["POST"])
|
||||
@login_required
|
||||
def parse_document(dataset_id, document_id):
|
||||
try:
|
||||
# valid dataset
|
||||
exist, _ = KnowledgebaseService.get_by_id(dataset_id)
|
||||
if not exist:
|
||||
return construct_json_result(code=RetCode.DATA_ERROR,
|
||||
message=f"This dataset '{dataset_id}' cannot be found!")
|
||||
|
||||
return parsing_document_internal(document_id)
|
||||
|
||||
except Exception as e:
|
||||
return construct_error_response(e)
|
||||
|
||||
|
||||
# ----------------------------start parsing documents-----------------------------------------------------
|
||||
@manager.route("/<dataset_id>/documents/status", methods=["POST"])
|
||||
@login_required
|
||||
def parse_documents(dataset_id):
|
||||
doc_ids = request.json["doc_ids"]
|
||||
try:
|
||||
exist, _ = KnowledgebaseService.get_by_id(dataset_id)
|
||||
if not exist:
|
||||
return construct_json_result(code=RetCode.DATA_ERROR,
|
||||
message=f"This dataset '{dataset_id}' cannot be found!")
|
||||
# two conditions
|
||||
if not doc_ids:
|
||||
# documents inside the dataset
|
||||
docs, total = DocumentService.list_documents_in_dataset(dataset_id, 0, -1, "create_time",
|
||||
True, "")
|
||||
doc_ids = [doc["id"] for doc in docs]
|
||||
|
||||
message = ""
|
||||
# for loop
|
||||
for id in doc_ids:
|
||||
res = parsing_document_internal(id)
|
||||
res_body = res.json
|
||||
if res_body["code"] == RetCode.SUCCESS:
|
||||
message += res_body["message"]
|
||||
else:
|
||||
return res
|
||||
return construct_json_result(data=True, code=RetCode.SUCCESS, message=message)
|
||||
|
||||
except Exception as e:
|
||||
return construct_error_response(e)
|
||||
|
||||
|
||||
# helper method for parsing the document
|
||||
def parsing_document_internal(id):
|
||||
message = ""
|
||||
try:
|
||||
# Check whether there is this document
|
||||
exist, document = DocumentService.get_by_id(id)
|
||||
if not exist:
|
||||
return construct_json_result(message=f"This document '{id}' cannot be found!",
|
||||
code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
tenant_id = DocumentService.get_tenant_id(id)
|
||||
if not tenant_id:
|
||||
return construct_json_result(message="Tenant not found!", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
info = {"run": "1", "progress": 0}
|
||||
info["progress_msg"] = ""
|
||||
info["chunk_num"] = 0
|
||||
info["token_num"] = 0
|
||||
|
||||
DocumentService.update_by_id(id, info)
|
||||
|
||||
ELASTICSEARCH.deleteByQuery(Q("match", doc_id=id), idxnm=search.index_name(tenant_id))
|
||||
|
||||
_, doc_attributes = DocumentService.get_by_id(id)
|
||||
doc_attributes = doc_attributes.to_dict()
|
||||
doc_id = doc_attributes["id"]
|
||||
|
||||
bucket, doc_name = File2DocumentService.get_minio_address(doc_id=doc_id)
|
||||
binary = MINIO.get(bucket, doc_name)
|
||||
parser_name = doc_attributes["parser_id"]
|
||||
if binary:
|
||||
res = doc_parse(binary, doc_name, parser_name, tenant_id, doc_id)
|
||||
if res is False:
|
||||
message += f"The parser id: {parser_name} of the document {doc_id} is not supported; "
|
||||
else:
|
||||
message += f"Empty data in the document: {doc_name}; "
|
||||
# failed in parsing
|
||||
if doc_attributes["status"] == TaskStatus.FAIL.value:
|
||||
message += f"Failed in parsing the document: {doc_id}; "
|
||||
return construct_json_result(code=RetCode.SUCCESS, message=message)
|
||||
except Exception as e:
|
||||
return construct_error_response(e)
|
||||
|
||||
|
||||
# ----------------------------stop parsing a doc-----------------------------------------------------
|
||||
@manager.route("<dataset_id>/documents/<document_id>/status", methods=["DELETE"])
|
||||
@login_required
|
||||
def stop_parsing_document(dataset_id, document_id):
|
||||
try:
|
||||
# valid dataset
|
||||
exist, _ = KnowledgebaseService.get_by_id(dataset_id)
|
||||
if not exist:
|
||||
return construct_json_result(code=RetCode.DATA_ERROR,
|
||||
message=f"This dataset '{dataset_id}' cannot be found!")
|
||||
|
||||
return stop_parsing_document_internal(document_id)
|
||||
|
||||
except Exception as e:
|
||||
return construct_error_response(e)
|
||||
|
||||
|
||||
# ----------------------------stop parsing docs-----------------------------------------------------
|
||||
@manager.route("<dataset_id>/documents/status", methods=["DELETE"])
|
||||
@login_required
|
||||
def stop_parsing_documents(dataset_id):
|
||||
doc_ids = request.json["doc_ids"]
|
||||
try:
|
||||
# valid dataset?
|
||||
exist, _ = KnowledgebaseService.get_by_id(dataset_id)
|
||||
if not exist:
|
||||
return construct_json_result(code=RetCode.DATA_ERROR,
|
||||
message=f"This dataset '{dataset_id}' cannot be found!")
|
||||
if not doc_ids:
|
||||
# documents inside the dataset
|
||||
docs, total = DocumentService.list_documents_in_dataset(dataset_id, 0, -1, "create_time",
|
||||
True, "")
|
||||
doc_ids = [doc["id"] for doc in docs]
|
||||
|
||||
message = ""
|
||||
# for loop
|
||||
for id in doc_ids:
|
||||
res = stop_parsing_document_internal(id)
|
||||
res_body = res.json
|
||||
if res_body["code"] == RetCode.SUCCESS:
|
||||
message += res_body["message"]
|
||||
else:
|
||||
return res
|
||||
return construct_json_result(data=True, code=RetCode.SUCCESS, message=message)
|
||||
|
||||
except Exception as e:
|
||||
return construct_error_response(e)
|
||||
|
||||
|
||||
# Helper method
|
||||
def stop_parsing_document_internal(document_id):
|
||||
try:
|
||||
# valid doc?
|
||||
exist, doc = DocumentService.get_by_id(document_id)
|
||||
if not exist:
|
||||
return construct_json_result(message=f"This document '{document_id}' cannot be found!",
|
||||
code=RetCode.ARGUMENT_ERROR)
|
||||
doc_attributes = doc.to_dict()
|
||||
|
||||
# only when the status is parsing, we need to stop it
|
||||
if doc_attributes["status"] == TaskStatus.RUNNING.value:
|
||||
tenant_id = DocumentService.get_tenant_id(document_id)
|
||||
if not tenant_id:
|
||||
return construct_json_result(message="Tenant not found!", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
# update successfully?
|
||||
if not DocumentService.update_by_id(document_id, {"status": "2"}): # cancel
|
||||
return construct_json_result(
|
||||
code=RetCode.OPERATING_ERROR,
|
||||
message="There was an error during the stopping parsing the document process. "
|
||||
"Please check the status of the RAGFlow server and try the update again."
|
||||
)
|
||||
|
||||
_, doc_attributes = DocumentService.get_by_id(document_id)
|
||||
doc_attributes = doc_attributes.to_dict()
|
||||
|
||||
# failed in stop parsing
|
||||
if doc_attributes["status"] == TaskStatus.RUNNING.value:
|
||||
return construct_json_result(message=f"Failed in parsing the document: {document_id}; ", code=RetCode.SUCCESS)
|
||||
return construct_json_result(code=RetCode.SUCCESS, message="")
|
||||
except Exception as e:
|
||||
return construct_error_response(e)
|
||||
|
||||
|
||||
# ----------------------------show the status of the file-----------------------------------------------------
|
||||
@manager.route("/<dataset_id>/documents/<document_id>/status", methods=["GET"])
|
||||
@login_required
|
||||
def show_parsing_status(dataset_id, document_id):
|
||||
try:
|
||||
# valid dataset
|
||||
exist, _ = KnowledgebaseService.get_by_id(dataset_id)
|
||||
if not exist:
|
||||
return construct_json_result(code=RetCode.DATA_ERROR,
|
||||
message=f"This dataset: '{dataset_id}' cannot be found!")
|
||||
# valid document
|
||||
exist, _ = DocumentService.get_by_id(document_id)
|
||||
if not exist:
|
||||
return construct_json_result(code=RetCode.DATA_ERROR,
|
||||
message=f"This document: '{document_id}' is not a valid document.")
|
||||
|
||||
_, doc = DocumentService.get_by_id(document_id) # get doc object
|
||||
doc_attributes = doc.to_dict()
|
||||
|
||||
return construct_json_result(
|
||||
data={"progress": doc_attributes["progress"], "status": TaskStatus(doc_attributes["status"]).name},
|
||||
code=RetCode.SUCCESS
|
||||
)
|
||||
except Exception as e:
|
||||
return construct_error_response(e)
|
||||
|
||||
# ----------------------------list the chunks of the file-----------------------------------------------------
|
||||
|
||||
@ -610,6 +874,3 @@ def download_document(dataset_id, document_id):
|
||||
# ----------------------------get a specific chunk-----------------------------------------------------
|
||||
|
||||
# ----------------------------retrieval test-----------------------------------------------------
|
||||
|
||||
|
||||
|
||||
|
||||
@ -105,6 +105,8 @@ def upload():
|
||||
}
|
||||
if doc["type"] == FileType.VISUAL:
|
||||
doc["parser_id"] = ParserType.PICTURE.value
|
||||
if doc["type"] == FileType.AURAL:
|
||||
doc["parser_id"] = ParserType.AUDIO.value
|
||||
if re.search(r"\.(ppt|pptx|pages)$", filename):
|
||||
doc["parser_id"] = ParserType.PRESENTATION.value
|
||||
DocumentService.insert(doc)
|
||||
@ -171,6 +173,8 @@ def web_crawl():
|
||||
}
|
||||
if doc["type"] == FileType.VISUAL:
|
||||
doc["parser_id"] = ParserType.PICTURE.value
|
||||
if doc["type"] == FileType.AURAL:
|
||||
doc["parser_id"] = ParserType.AUDIO.value
|
||||
if re.search(r"\.(ppt|pptx|pages)$", filename):
|
||||
doc["parser_id"] = ParserType.PRESENTATION.value
|
||||
DocumentService.insert(doc)
|
||||
|
||||
@ -20,8 +20,8 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
|
||||
from api.db import StatusEnum, LLMType
|
||||
from api.db.db_models import TenantLLM
|
||||
from api.utils.api_utils import get_json_result
|
||||
from rag.llm import EmbeddingModel, ChatModel, RerankModel
|
||||
|
||||
from rag.llm import EmbeddingModel, ChatModel, RerankModel,CvModel
|
||||
import requests
|
||||
|
||||
@manager.route('/factories', methods=['GET'])
|
||||
@login_required
|
||||
@ -57,8 +57,8 @@ def set_api_key():
|
||||
mdl = ChatModel[factory](
|
||||
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
|
||||
try:
|
||||
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {
|
||||
"temperature": 0.9})
|
||||
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}],
|
||||
{"temperature": 0.9,'max_tokens':50})
|
||||
if not tc:
|
||||
raise Exception(m)
|
||||
except Exception as e:
|
||||
@ -126,6 +126,9 @@ def add_llm():
|
||||
api_key = '{' + f'"bedrock_ak": "{req.get("bedrock_ak", "")}", ' \
|
||||
f'"bedrock_sk": "{req.get("bedrock_sk", "")}", ' \
|
||||
f'"bedrock_region": "{req.get("bedrock_region", "")}", ' + '}'
|
||||
elif factory == "LocalAI":
|
||||
llm_name = req["llm_name"]+"___LocalAI"
|
||||
api_key = "xxxxxxxxxxxxxxx"
|
||||
else:
|
||||
llm_name = req["llm_name"]
|
||||
api_key = "xxxxxxxxxxxxxxx"
|
||||
@ -165,6 +168,36 @@ def add_llm():
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access model({llm['llm_name']})." + str(
|
||||
e)
|
||||
elif llm["model_type"] == LLMType.RERANK:
|
||||
mdl = RerankModel[factory](
|
||||
key=None, model_name=llm["llm_name"], base_url=llm["api_base"]
|
||||
)
|
||||
try:
|
||||
arr, tc = mdl.similarity("Hello~ Ragflower!", ["Hi, there!"])
|
||||
if len(arr) == 0 or tc == 0:
|
||||
raise Exception("Not known.")
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access model({llm['llm_name']})." + str(
|
||||
e)
|
||||
elif llm["model_type"] == LLMType.IMAGE2TEXT.value:
|
||||
mdl = CvModel[factory](
|
||||
key=None, model_name=llm["llm_name"], base_url=llm["api_base"]
|
||||
)
|
||||
try:
|
||||
img_url = (
|
||||
"https://upload.wikimedia.org/wikipedia/comm"
|
||||
"ons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/256"
|
||||
"0px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
)
|
||||
res = requests.get(img_url)
|
||||
if res.status_code == 200:
|
||||
m, tc = mdl.describe(res.content)
|
||||
if not tc:
|
||||
raise Exception(m)
|
||||
else:
|
||||
pass
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access model({llm['llm_name']})." + str(e)
|
||||
else:
|
||||
# TODO: check other type of models
|
||||
pass
|
||||
|
||||
@ -59,9 +59,9 @@ def status():
|
||||
|
||||
st = timer()
|
||||
try:
|
||||
qinfo = REDIS_CONN.health(SVR_QUEUE_NAME)
|
||||
res["redis"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st)*1000.),
|
||||
"pending": qinfo.get("pending", 0)}
|
||||
if not REDIS_CONN.health():
|
||||
raise Exception("Lost connection!")
|
||||
res["redis"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st)*1000.)}
|
||||
except Exception as e:
|
||||
res["redis"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st)*1000.), "error": str(e)}
|
||||
|
||||
|
||||
@ -84,6 +84,8 @@ class ParserType(StrEnum):
|
||||
NAIVE = "naive"
|
||||
PICTURE = "picture"
|
||||
ONE = "one"
|
||||
AUDIO = "audio"
|
||||
KG = "knowledge_graph"
|
||||
|
||||
|
||||
class FileSource(StrEnum):
|
||||
@ -96,4 +98,4 @@ class CanvasType(StrEnum):
|
||||
ChatBot = "chatbot"
|
||||
DocBot = "docbot"
|
||||
|
||||
KNOWLEDGEBASE_FOLDER_NAME=".knowledgebase"
|
||||
KNOWLEDGEBASE_FOLDER_NAME=".knowledgebase"
|
||||
|
||||
@ -144,10 +144,10 @@ def remove_field_name_prefix(field_name):
|
||||
|
||||
|
||||
class BaseModel(Model):
|
||||
create_time = BigIntegerField(null=True)
|
||||
create_date = DateTimeField(null=True)
|
||||
update_time = BigIntegerField(null=True)
|
||||
update_date = DateTimeField(null=True)
|
||||
create_time = BigIntegerField(null=True, index=True)
|
||||
create_date = DateTimeField(null=True, index=True)
|
||||
update_time = BigIntegerField(null=True, index=True)
|
||||
update_date = DateTimeField(null=True, index=True)
|
||||
|
||||
def to_json(self):
|
||||
# This function is obsolete
|
||||
@ -234,7 +234,7 @@ class BaseModel(Model):
|
||||
def insert(cls, __data=None, **insert):
|
||||
if isinstance(__data, dict) and __data:
|
||||
__data[cls._meta.combined["create_time"]
|
||||
] = utils.current_timestamp()
|
||||
] = utils.current_timestamp()
|
||||
if insert:
|
||||
insert["create_time"] = utils.current_timestamp()
|
||||
|
||||
@ -248,7 +248,7 @@ class BaseModel(Model):
|
||||
return {}
|
||||
|
||||
normalized[cls._meta.combined["update_time"]
|
||||
] = utils.current_timestamp()
|
||||
] = utils.current_timestamp()
|
||||
|
||||
for f_n in AUTO_DATE_TIMESTAMP_FIELD_PREFIX:
|
||||
if {f"{f_n}_time", f"{f_n}_date"}.issubset(cls._meta.combined.keys()) and \
|
||||
@ -332,7 +332,7 @@ DB.lock = DatabaseLock
|
||||
def close_connection():
|
||||
try:
|
||||
if DB:
|
||||
DB.close()
|
||||
DB.close_stale(age=30)
|
||||
except Exception as e:
|
||||
LOGGER.exception(e)
|
||||
|
||||
@ -373,9 +373,9 @@ def fill_db_model_object(model_object, human_model_dict):
|
||||
|
||||
class User(DataBaseModel, UserMixin):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
access_token = CharField(max_length=255, null=True)
|
||||
nickname = CharField(max_length=100, null=False, help_text="nicky name")
|
||||
password = CharField(max_length=255, null=True, help_text="password")
|
||||
access_token = CharField(max_length=255, null=True, index=True)
|
||||
nickname = CharField(max_length=100, null=False, help_text="nicky name", index=True)
|
||||
password = CharField(max_length=255, null=True, help_text="password", index=True)
|
||||
email = CharField(
|
||||
max_length=255,
|
||||
null=False,
|
||||
@ -386,28 +386,32 @@ class User(DataBaseModel, UserMixin):
|
||||
max_length=32,
|
||||
null=True,
|
||||
help_text="English|Chinese",
|
||||
default="Chinese" if "zh_CN" in os.getenv("LANG", "") else "English")
|
||||
default="Chinese" if "zh_CN" in os.getenv("LANG", "") else "English",
|
||||
index=True)
|
||||
color_schema = CharField(
|
||||
max_length=32,
|
||||
null=True,
|
||||
help_text="Bright|Dark",
|
||||
default="Bright")
|
||||
default="Bright",
|
||||
index=True)
|
||||
timezone = CharField(
|
||||
max_length=64,
|
||||
null=True,
|
||||
help_text="Timezone",
|
||||
default="UTC+8\tAsia/Shanghai")
|
||||
last_login_time = DateTimeField(null=True)
|
||||
is_authenticated = CharField(max_length=1, null=False, default="1")
|
||||
is_active = CharField(max_length=1, null=False, default="1")
|
||||
is_anonymous = CharField(max_length=1, null=False, default="0")
|
||||
login_channel = CharField(null=True, help_text="from which user login")
|
||||
default="UTC+8\tAsia/Shanghai",
|
||||
index=True)
|
||||
last_login_time = DateTimeField(null=True, index=True)
|
||||
is_authenticated = CharField(max_length=1, null=False, default="1", index=True)
|
||||
is_active = CharField(max_length=1, null=False, default="1", index=True)
|
||||
is_anonymous = CharField(max_length=1, null=False, default="0", index=True)
|
||||
login_channel = CharField(null=True, help_text="from which user login", index=True)
|
||||
status = CharField(
|
||||
max_length=1,
|
||||
null=True,
|
||||
help_text="is it validate(0: wasted,1: validate)",
|
||||
default="1")
|
||||
is_superuser = BooleanField(null=True, help_text="is root", default=False)
|
||||
default="1",
|
||||
index=True)
|
||||
is_superuser = BooleanField(null=True, help_text="is root", default=False, index=True)
|
||||
|
||||
def __str__(self):
|
||||
return self.email
|
||||
@ -422,35 +426,41 @@ class User(DataBaseModel, UserMixin):
|
||||
|
||||
class Tenant(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
name = CharField(max_length=100, null=True, help_text="Tenant name")
|
||||
public_key = CharField(max_length=255, null=True)
|
||||
llm_id = CharField(max_length=128, null=False, help_text="default llm ID")
|
||||
name = CharField(max_length=100, null=True, help_text="Tenant name", index=True)
|
||||
public_key = CharField(max_length=255, null=True, index=True)
|
||||
llm_id = CharField(max_length=128, null=False, help_text="default llm ID", index=True)
|
||||
embd_id = CharField(
|
||||
max_length=128,
|
||||
null=False,
|
||||
help_text="default embedding model ID")
|
||||
help_text="default embedding model ID",
|
||||
index=True)
|
||||
asr_id = CharField(
|
||||
max_length=128,
|
||||
null=False,
|
||||
help_text="default ASR model ID")
|
||||
help_text="default ASR model ID",
|
||||
index=True)
|
||||
img2txt_id = CharField(
|
||||
max_length=128,
|
||||
null=False,
|
||||
help_text="default image to text model ID")
|
||||
help_text="default image to text model ID",
|
||||
index=True)
|
||||
rerank_id = CharField(
|
||||
max_length=128,
|
||||
null=False,
|
||||
help_text="default rerank model ID")
|
||||
help_text="default rerank model ID",
|
||||
index=True)
|
||||
parser_ids = CharField(
|
||||
max_length=256,
|
||||
null=False,
|
||||
help_text="document processors")
|
||||
credit = IntegerField(default=512)
|
||||
help_text="document processors",
|
||||
index=True)
|
||||
credit = IntegerField(default=512, index=True)
|
||||
status = CharField(
|
||||
max_length=1,
|
||||
null=True,
|
||||
help_text="is it validate(0: wasted,1: validate)",
|
||||
default="1")
|
||||
default="1",
|
||||
index=True)
|
||||
|
||||
class Meta:
|
||||
db_table = "tenant"
|
||||
@ -458,15 +468,16 @@ class Tenant(DataBaseModel):
|
||||
|
||||
class UserTenant(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
user_id = CharField(max_length=32, null=False)
|
||||
tenant_id = CharField(max_length=32, null=False)
|
||||
role = CharField(max_length=32, null=False, help_text="UserTenantRole")
|
||||
invited_by = CharField(max_length=32, null=False)
|
||||
user_id = CharField(max_length=32, null=False, index=True)
|
||||
tenant_id = CharField(max_length=32, null=False, index=True)
|
||||
role = CharField(max_length=32, null=False, help_text="UserTenantRole", index=True)
|
||||
invited_by = CharField(max_length=32, null=False, index=True)
|
||||
status = CharField(
|
||||
max_length=1,
|
||||
null=True,
|
||||
help_text="is it validate(0: wasted,1: validate)",
|
||||
default="1")
|
||||
default="1",
|
||||
index=True)
|
||||
|
||||
class Meta:
|
||||
db_table = "user_tenant"
|
||||
@ -474,15 +485,16 @@ class UserTenant(DataBaseModel):
|
||||
|
||||
class InvitationCode(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
code = CharField(max_length=32, null=False)
|
||||
visit_time = DateTimeField(null=True)
|
||||
user_id = CharField(max_length=32, null=True)
|
||||
tenant_id = CharField(max_length=32, null=True)
|
||||
code = CharField(max_length=32, null=False, index=True)
|
||||
visit_time = DateTimeField(null=True, index=True)
|
||||
user_id = CharField(max_length=32, null=True, index=True)
|
||||
tenant_id = CharField(max_length=32, null=True, index=True)
|
||||
status = CharField(
|
||||
max_length=1,
|
||||
null=True,
|
||||
help_text="is it validate(0: wasted,1: validate)",
|
||||
default="1")
|
||||
default="1",
|
||||
index=True)
|
||||
|
||||
class Meta:
|
||||
db_table = "invitation_code"
|
||||
@ -498,12 +510,14 @@ class LLMFactories(DataBaseModel):
|
||||
tags = CharField(
|
||||
max_length=255,
|
||||
null=False,
|
||||
help_text="LLM, Text Embedding, Image2Text, ASR")
|
||||
help_text="LLM, Text Embedding, Image2Text, ASR",
|
||||
index=True)
|
||||
status = CharField(
|
||||
max_length=1,
|
||||
null=True,
|
||||
help_text="is it validate(0: wasted,1: validate)",
|
||||
default="1")
|
||||
default="1",
|
||||
index=True)
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
@ -523,18 +537,22 @@ class LLM(DataBaseModel):
|
||||
model_type = CharField(
|
||||
max_length=128,
|
||||
null=False,
|
||||
help_text="LLM, Text Embedding, Image2Text, ASR")
|
||||
fid = CharField(max_length=128, null=False, help_text="LLM factory id")
|
||||
help_text="LLM, Text Embedding, Image2Text, ASR",
|
||||
index=True)
|
||||
fid = CharField(max_length=128, null=False, help_text="LLM factory id", index=True)
|
||||
max_tokens = IntegerField(default=0)
|
||||
|
||||
tags = CharField(
|
||||
max_length=255,
|
||||
null=False,
|
||||
help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
|
||||
help_text="LLM, Text Embedding, Image2Text, Chat, 32k...",
|
||||
index=True)
|
||||
status = CharField(
|
||||
max_length=1,
|
||||
null=True,
|
||||
help_text="is it validate(0: wasted,1: validate)",
|
||||
default="1")
|
||||
default="1",
|
||||
index=True)
|
||||
|
||||
def __str__(self):
|
||||
return self.llm_name
|
||||
@ -544,23 +562,27 @@ class LLM(DataBaseModel):
|
||||
|
||||
|
||||
class TenantLLM(DataBaseModel):
|
||||
tenant_id = CharField(max_length=32, null=False)
|
||||
tenant_id = CharField(max_length=32, null=False, index=True)
|
||||
llm_factory = CharField(
|
||||
max_length=128,
|
||||
null=False,
|
||||
help_text="LLM factory name")
|
||||
help_text="LLM factory name",
|
||||
index=True)
|
||||
model_type = CharField(
|
||||
max_length=128,
|
||||
null=True,
|
||||
help_text="LLM, Text Embedding, Image2Text, ASR")
|
||||
help_text="LLM, Text Embedding, Image2Text, ASR",
|
||||
index=True)
|
||||
llm_name = CharField(
|
||||
max_length=128,
|
||||
null=True,
|
||||
help_text="LLM name",
|
||||
default="")
|
||||
api_key = CharField(max_length=255, null=True, help_text="API KEY")
|
||||
default="",
|
||||
index=True)
|
||||
api_key = CharField(max_length=1024, null=True, help_text="API KEY", index=True)
|
||||
api_base = CharField(max_length=255, null=True, help_text="API Base")
|
||||
used_tokens = IntegerField(default=0)
|
||||
|
||||
used_tokens = IntegerField(default=0, index=True)
|
||||
|
||||
def __str__(self):
|
||||
return self.llm_name
|
||||
@ -573,7 +595,7 @@ class TenantLLM(DataBaseModel):
|
||||
class Knowledgebase(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
avatar = TextField(null=True, help_text="avatar base64 string")
|
||||
tenant_id = CharField(max_length=32, null=False)
|
||||
tenant_id = CharField(max_length=32, null=False, index=True)
|
||||
name = CharField(
|
||||
max_length=128,
|
||||
null=False,
|
||||
@ -583,35 +605,40 @@ class Knowledgebase(DataBaseModel):
|
||||
max_length=32,
|
||||
null=True,
|
||||
default="Chinese" if "zh_CN" in os.getenv("LANG", "") else "English",
|
||||
help_text="English|Chinese")
|
||||
help_text="English|Chinese",
|
||||
index=True)
|
||||
description = TextField(null=True, help_text="KB description")
|
||||
embd_id = CharField(
|
||||
max_length=128,
|
||||
null=False,
|
||||
help_text="default embedding model ID")
|
||||
help_text="default embedding model ID",
|
||||
index=True)
|
||||
permission = CharField(
|
||||
max_length=16,
|
||||
null=False,
|
||||
help_text="me|team",
|
||||
default="me")
|
||||
created_by = CharField(max_length=32, null=False)
|
||||
doc_num = IntegerField(default=0)
|
||||
token_num = IntegerField(default=0)
|
||||
chunk_num = IntegerField(default=0)
|
||||
similarity_threshold = FloatField(default=0.2)
|
||||
vector_similarity_weight = FloatField(default=0.3)
|
||||
default="me",
|
||||
index=True)
|
||||
created_by = CharField(max_length=32, null=False, index=True)
|
||||
doc_num = IntegerField(default=0, index=True)
|
||||
token_num = IntegerField(default=0, index=True)
|
||||
chunk_num = IntegerField(default=0, index=True)
|
||||
similarity_threshold = FloatField(default=0.2, index=True)
|
||||
vector_similarity_weight = FloatField(default=0.3, index=True)
|
||||
|
||||
parser_id = CharField(
|
||||
max_length=32,
|
||||
null=False,
|
||||
help_text="default parser ID",
|
||||
default=ParserType.NAIVE.value)
|
||||
default=ParserType.NAIVE.value,
|
||||
index=True)
|
||||
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
|
||||
status = CharField(
|
||||
max_length=1,
|
||||
null=True,
|
||||
help_text="is it validate(0: wasted,1: validate)",
|
||||
default="1")
|
||||
default="1",
|
||||
index=True)
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
@ -627,18 +654,22 @@ class Document(DataBaseModel):
|
||||
parser_id = CharField(
|
||||
max_length=32,
|
||||
null=False,
|
||||
help_text="default parser ID")
|
||||
help_text="default parser ID",
|
||||
index=True)
|
||||
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
|
||||
source_type = CharField(
|
||||
max_length=128,
|
||||
null=False,
|
||||
default="local",
|
||||
help_text="where dose this document come from")
|
||||
type = CharField(max_length=32, null=False, help_text="file extension")
|
||||
help_text="where dose this document come from",
|
||||
index=True)
|
||||
type = CharField(max_length=32, null=False, help_text="file extension",
|
||||
index=True)
|
||||
created_by = CharField(
|
||||
max_length=32,
|
||||
null=False,
|
||||
help_text="who created it")
|
||||
help_text="who created it",
|
||||
index=True)
|
||||
name = CharField(
|
||||
max_length=255,
|
||||
null=True,
|
||||
@ -647,27 +678,31 @@ class Document(DataBaseModel):
|
||||
location = CharField(
|
||||
max_length=255,
|
||||
null=True,
|
||||
help_text="where dose it store")
|
||||
size = IntegerField(default=0)
|
||||
token_num = IntegerField(default=0)
|
||||
chunk_num = IntegerField(default=0)
|
||||
progress = FloatField(default=0)
|
||||
help_text="where dose it store",
|
||||
index=True)
|
||||
size = IntegerField(default=0, index=True)
|
||||
token_num = IntegerField(default=0, index=True)
|
||||
chunk_num = IntegerField(default=0, index=True)
|
||||
progress = FloatField(default=0, index=True)
|
||||
progress_msg = TextField(
|
||||
null=True,
|
||||
help_text="process message",
|
||||
default="")
|
||||
process_begin_at = DateTimeField(null=True)
|
||||
process_begin_at = DateTimeField(null=True, index=True)
|
||||
process_duation = FloatField(default=0)
|
||||
|
||||
run = CharField(
|
||||
max_length=1,
|
||||
null=True,
|
||||
help_text="start to run processing or cancel.(1: run it; 2: cancel)",
|
||||
default="0")
|
||||
default="0",
|
||||
index=True)
|
||||
status = CharField(
|
||||
max_length=1,
|
||||
null=True,
|
||||
help_text="is it validate(0: wasted,1: validate)",
|
||||
default="1")
|
||||
default="1",
|
||||
index=True)
|
||||
|
||||
class Meta:
|
||||
db_table = "document"
|
||||
@ -676,8 +711,7 @@ class Document(DataBaseModel):
|
||||
class File(DataBaseModel):
|
||||
id = CharField(
|
||||
max_length=32,
|
||||
primary_key=True,
|
||||
)
|
||||
primary_key=True)
|
||||
parent_id = CharField(
|
||||
max_length=32,
|
||||
null=False,
|
||||
@ -691,7 +725,8 @@ class File(DataBaseModel):
|
||||
created_by = CharField(
|
||||
max_length=32,
|
||||
null=False,
|
||||
help_text="who created it")
|
||||
help_text="who created it",
|
||||
index=True)
|
||||
name = CharField(
|
||||
max_length=255,
|
||||
null=False,
|
||||
@ -700,14 +735,15 @@ class File(DataBaseModel):
|
||||
location = CharField(
|
||||
max_length=255,
|
||||
null=True,
|
||||
help_text="where dose it store")
|
||||
size = IntegerField(default=0)
|
||||
type = CharField(max_length=32, null=False, help_text="file extension")
|
||||
help_text="where dose it store",
|
||||
index=True)
|
||||
size = IntegerField(default=0, index=True)
|
||||
type = CharField(max_length=32, null=False, help_text="file extension", index=True)
|
||||
source_type = CharField(
|
||||
max_length=128,
|
||||
null=False,
|
||||
default="",
|
||||
help_text="where dose this document come from")
|
||||
help_text="where dose this document come from", index=True)
|
||||
|
||||
class Meta:
|
||||
db_table = "file"
|
||||
@ -716,8 +752,7 @@ class File(DataBaseModel):
|
||||
class File2Document(DataBaseModel):
|
||||
id = CharField(
|
||||
max_length=32,
|
||||
primary_key=True,
|
||||
)
|
||||
primary_key=True)
|
||||
file_id = CharField(
|
||||
max_length=32,
|
||||
null=True,
|
||||
@ -737,10 +772,13 @@ class Task(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
doc_id = CharField(max_length=32, null=False, index=True)
|
||||
from_page = IntegerField(default=0)
|
||||
|
||||
to_page = IntegerField(default=-1)
|
||||
begin_at = DateTimeField(null=True)
|
||||
|
||||
begin_at = DateTimeField(null=True, index=True)
|
||||
process_duation = FloatField(default=0)
|
||||
progress = FloatField(default=0)
|
||||
|
||||
progress = FloatField(default=0, index=True)
|
||||
progress_msg = TextField(
|
||||
null=True,
|
||||
help_text="process message",
|
||||
@ -749,38 +787,45 @@ class Task(DataBaseModel):
|
||||
|
||||
class Dialog(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
tenant_id = CharField(max_length=32, null=False)
|
||||
tenant_id = CharField(max_length=32, null=False, index=True)
|
||||
name = CharField(
|
||||
max_length=255,
|
||||
null=True,
|
||||
help_text="dialog application name")
|
||||
help_text="dialog application name",
|
||||
index=True)
|
||||
description = TextField(null=True, help_text="Dialog description")
|
||||
icon = TextField(null=True, help_text="icon base64 string")
|
||||
language = CharField(
|
||||
max_length=32,
|
||||
null=True,
|
||||
default="Chinese" if "zh_CN" in os.getenv("LANG", "") else "English",
|
||||
help_text="English|Chinese")
|
||||
help_text="English|Chinese",
|
||||
index=True)
|
||||
llm_id = CharField(max_length=128, null=False, help_text="default llm ID")
|
||||
|
||||
llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7,
|
||||
"presence_penalty": 0.4, "max_tokens": 512})
|
||||
prompt_type = CharField(
|
||||
max_length=16,
|
||||
null=False,
|
||||
default="simple",
|
||||
help_text="simple|advanced")
|
||||
help_text="simple|advanced",
|
||||
index=True)
|
||||
prompt_config = JSONField(null=False, default={"system": "", "prologue": "您好,我是您的助手小樱,长得可爱又善良,can I help you?",
|
||||
"parameters": [], "empty_response": "Sorry! 知识库中未找到相关内容!"})
|
||||
|
||||
similarity_threshold = FloatField(default=0.2)
|
||||
vector_similarity_weight = FloatField(default=0.3)
|
||||
|
||||
top_n = IntegerField(default=6)
|
||||
|
||||
top_k = IntegerField(default=1024)
|
||||
|
||||
do_refer = CharField(
|
||||
max_length=1,
|
||||
null=False,
|
||||
help_text="it needs to insert reference index into answer or not",
|
||||
default="1")
|
||||
help_text="it needs to insert reference index into answer or not")
|
||||
|
||||
rerank_id = CharField(
|
||||
max_length=128,
|
||||
null=False,
|
||||
@ -791,7 +836,8 @@ class Dialog(DataBaseModel):
|
||||
max_length=1,
|
||||
null=True,
|
||||
help_text="is it validate(0: wasted,1: validate)",
|
||||
default="1")
|
||||
default="1",
|
||||
index=True)
|
||||
|
||||
class Meta:
|
||||
db_table = "dialog"
|
||||
@ -800,7 +846,7 @@ class Dialog(DataBaseModel):
|
||||
class Conversation(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
dialog_id = CharField(max_length=32, null=False, index=True)
|
||||
name = CharField(max_length=255, null=True, help_text="converastion name")
|
||||
name = CharField(max_length=255, null=True, help_text="converastion name", index=True)
|
||||
message = JSONField(null=True)
|
||||
reference = JSONField(null=True, default=[])
|
||||
|
||||
@ -809,8 +855,8 @@ class Conversation(DataBaseModel):
|
||||
|
||||
|
||||
class APIToken(DataBaseModel):
|
||||
tenant_id = CharField(max_length=32, null=False)
|
||||
token = CharField(max_length=255, null=False)
|
||||
tenant_id = CharField(max_length=32, null=False, index=True)
|
||||
token = CharField(max_length=255, null=False, index=True)
|
||||
dialog_id = CharField(max_length=32, null=False, index=True)
|
||||
|
||||
class Meta:
|
||||
@ -821,13 +867,14 @@ class APIToken(DataBaseModel):
|
||||
class API4Conversation(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
dialog_id = CharField(max_length=32, null=False, index=True)
|
||||
user_id = CharField(max_length=255, null=False, help_text="user_id")
|
||||
user_id = CharField(max_length=255, null=False, help_text="user_id", index=True)
|
||||
message = JSONField(null=True)
|
||||
reference = JSONField(null=True, default=[])
|
||||
tokens = IntegerField(default=0)
|
||||
duration = FloatField(default=0)
|
||||
round = IntegerField(default=0)
|
||||
thumb_up = IntegerField(default=0)
|
||||
|
||||
duration = FloatField(default=0, index=True)
|
||||
round = IntegerField(default=0, index=True)
|
||||
thumb_up = IntegerField(default=0, index=True)
|
||||
|
||||
class Meta:
|
||||
db_table = "api_4_conversation"
|
||||
@ -836,10 +883,11 @@ class API4Conversation(DataBaseModel):
|
||||
class UserCanvas(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
avatar = TextField(null=True, help_text="avatar base64 string")
|
||||
user_id = CharField(max_length=255, null=False, help_text="user_id")
|
||||
user_id = CharField(max_length=255, null=False, help_text="user_id", index=True)
|
||||
title = CharField(max_length=255, null=True, help_text="Canvas title")
|
||||
|
||||
description = TextField(null=True, help_text="Canvas description")
|
||||
canvas_type = CharField(max_length=32, null=True, help_text="Canvas type")
|
||||
canvas_type = CharField(max_length=32, null=True, help_text="Canvas type", index=True)
|
||||
dsl = JSONField(null=True, default={})
|
||||
|
||||
class Meta:
|
||||
@ -850,8 +898,9 @@ class CanvasTemplate(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
avatar = TextField(null=True, help_text="avatar base64 string")
|
||||
title = CharField(max_length=255, null=True, help_text="Canvas title")
|
||||
|
||||
description = TextField(null=True, help_text="Canvas description")
|
||||
canvas_type = CharField(max_length=32, null=True, help_text="Canvas type")
|
||||
canvas_type = CharField(max_length=32, null=True, help_text="Canvas type", index=True)
|
||||
dsl = JSONField(null=True, default={})
|
||||
|
||||
class Meta:
|
||||
@ -859,29 +908,44 @@ class CanvasTemplate(DataBaseModel):
|
||||
|
||||
|
||||
def migrate_db():
|
||||
with DB.transaction():
|
||||
migrator = MySQLMigrator(DB)
|
||||
try:
|
||||
migrate(
|
||||
migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="", help_text="where dose this document come from"))
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
try:
|
||||
migrate(
|
||||
migrator.add_column('tenant', 'rerank_id', CharField(max_length=128, null=False, default="BAAI/bge-reranker-v2-m3", help_text="default rerank model ID"))
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
try:
|
||||
migrate(
|
||||
migrator.add_column('dialog', 'rerank_id', CharField(max_length=128, null=False, default="", help_text="default rerank model ID"))
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
try:
|
||||
migrate(
|
||||
migrator.add_column('dialog', 'top_k', IntegerField(default=1024))
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
with DB.transaction():
|
||||
migrator = MySQLMigrator(DB)
|
||||
try:
|
||||
migrate(
|
||||
migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="",
|
||||
help_text="where dose this document come from",
|
||||
index=True))
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
try:
|
||||
migrate(
|
||||
migrator.add_column('tenant', 'rerank_id',
|
||||
CharField(max_length=128, null=False, default="BAAI/bge-reranker-v2-m3",
|
||||
help_text="default rerank model ID"))
|
||||
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
try:
|
||||
migrate(
|
||||
migrator.add_column('dialog', 'rerank_id', CharField(max_length=128, null=False, default="",
|
||||
help_text="default rerank model ID"))
|
||||
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
try:
|
||||
migrate(
|
||||
migrator.add_column('dialog', 'top_k', IntegerField(default=1024))
|
||||
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
try:
|
||||
migrate(
|
||||
migrator.alter_column_type('tenant_llm', 'api_key',
|
||||
CharField(max_length=1024, null=True, help_text="API KEY", index=True))
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
@ -89,827 +89,30 @@ def init_superuser():
|
||||
tenant["embd_id"]))
|
||||
|
||||
|
||||
factory_infos = [{
|
||||
"name": "OpenAI",
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
||||
"status": "1",
|
||||
}, {
|
||||
"name": "Tongyi-Qianwen",
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
||||
"status": "1",
|
||||
}, {
|
||||
"name": "ZHIPU-AI",
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
||||
"status": "1",
|
||||
},
|
||||
{
|
||||
"name": "Ollama",
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
||||
"status": "1",
|
||||
}, {
|
||||
"name": "Moonshot",
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING",
|
||||
"status": "1",
|
||||
}, {
|
||||
"name": "FastEmbed",
|
||||
"logo": "",
|
||||
"tags": "TEXT EMBEDDING",
|
||||
"status": "1",
|
||||
}, {
|
||||
"name": "Xinference",
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
||||
"status": "1",
|
||||
},{
|
||||
"name": "Youdao",
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
||||
"status": "1",
|
||||
},{
|
||||
"name": "DeepSeek",
|
||||
"logo": "",
|
||||
"tags": "LLM",
|
||||
"status": "1",
|
||||
},{
|
||||
"name": "VolcEngine",
|
||||
"logo": "",
|
||||
"tags": "LLM, TEXT EMBEDDING",
|
||||
"status": "1",
|
||||
},{
|
||||
"name": "BaiChuan",
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING",
|
||||
"status": "1",
|
||||
},{
|
||||
"name": "Jina",
|
||||
"logo": "",
|
||||
"tags": "TEXT EMBEDDING, TEXT RE-RANK",
|
||||
"status": "1",
|
||||
},{
|
||||
"name": "BAAI",
|
||||
"logo": "",
|
||||
"tags": "TEXT EMBEDDING, TEXT RE-RANK",
|
||||
"status": "1",
|
||||
},{
|
||||
"name": "MiniMax",
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING",
|
||||
"status": "1",
|
||||
},{
|
||||
"name": "Mistral",
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING",
|
||||
"status": "1",
|
||||
},{
|
||||
"name": "Azure-OpenAI",
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
||||
"status": "1",
|
||||
},{
|
||||
"name": "Bedrock",
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING",
|
||||
"status": "1",
|
||||
}
|
||||
# {
|
||||
# "name": "文心一言",
|
||||
# "logo": "",
|
||||
# "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
||||
# "status": "1",
|
||||
# },
|
||||
]
|
||||
|
||||
|
||||
def init_llm_factory():
|
||||
llm_infos = [
|
||||
# ---------------------- OpenAI ------------------------
|
||||
{
|
||||
"fid": factory_infos[0]["name"],
|
||||
"llm_name": "gpt-4o",
|
||||
"tags": "LLM,CHAT,128K",
|
||||
"max_tokens": 128000,
|
||||
"model_type": LLMType.CHAT.value + "," + LLMType.IMAGE2TEXT.value
|
||||
}, {
|
||||
"fid": factory_infos[0]["name"],
|
||||
"llm_name": "gpt-3.5-turbo",
|
||||
"tags": "LLM,CHAT,4K",
|
||||
"max_tokens": 4096,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[0]["name"],
|
||||
"llm_name": "gpt-3.5-turbo-16k-0613",
|
||||
"tags": "LLM,CHAT,16k",
|
||||
"max_tokens": 16385,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[0]["name"],
|
||||
"llm_name": "text-embedding-ada-002",
|
||||
"tags": "TEXT EMBEDDING,8K",
|
||||
"max_tokens": 8191,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
}, {
|
||||
"fid": factory_infos[0]["name"],
|
||||
"llm_name": "text-embedding-3-small",
|
||||
"tags": "TEXT EMBEDDING,8K",
|
||||
"max_tokens": 8191,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
}, {
|
||||
"fid": factory_infos[0]["name"],
|
||||
"llm_name": "text-embedding-3-large",
|
||||
"tags": "TEXT EMBEDDING,8K",
|
||||
"max_tokens": 8191,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
}, {
|
||||
"fid": factory_infos[0]["name"],
|
||||
"llm_name": "whisper-1",
|
||||
"tags": "SPEECH2TEXT",
|
||||
"max_tokens": 25 * 1024 * 1024,
|
||||
"model_type": LLMType.SPEECH2TEXT.value
|
||||
}, {
|
||||
"fid": factory_infos[0]["name"],
|
||||
"llm_name": "gpt-4",
|
||||
"tags": "LLM,CHAT,8K",
|
||||
"max_tokens": 8191,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[0]["name"],
|
||||
"llm_name": "gpt-4-turbo",
|
||||
"tags": "LLM,CHAT,8K",
|
||||
"max_tokens": 8191,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},{
|
||||
"fid": factory_infos[0]["name"],
|
||||
"llm_name": "gpt-4-32k",
|
||||
"tags": "LLM,CHAT,32K",
|
||||
"max_tokens": 32768,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[0]["name"],
|
||||
"llm_name": "gpt-4-vision-preview",
|
||||
"tags": "LLM,CHAT,IMAGE2TEXT",
|
||||
"max_tokens": 765,
|
||||
"model_type": LLMType.IMAGE2TEXT.value
|
||||
},
|
||||
# ----------------------- Qwen -----------------------
|
||||
{
|
||||
"fid": factory_infos[1]["name"],
|
||||
"llm_name": "qwen-turbo",
|
||||
"tags": "LLM,CHAT,8K",
|
||||
"max_tokens": 8191,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[1]["name"],
|
||||
"llm_name": "qwen-plus",
|
||||
"tags": "LLM,CHAT,32K",
|
||||
"max_tokens": 32768,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[1]["name"],
|
||||
"llm_name": "qwen-max-1201",
|
||||
"tags": "LLM,CHAT,6K",
|
||||
"max_tokens": 5899,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[1]["name"],
|
||||
"llm_name": "text-embedding-v2",
|
||||
"tags": "TEXT EMBEDDING,2K",
|
||||
"max_tokens": 2048,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
}, {
|
||||
"fid": factory_infos[1]["name"],
|
||||
"llm_name": "paraformer-realtime-8k-v1",
|
||||
"tags": "SPEECH2TEXT",
|
||||
"max_tokens": 25 * 1024 * 1024,
|
||||
"model_type": LLMType.SPEECH2TEXT.value
|
||||
}, {
|
||||
"fid": factory_infos[1]["name"],
|
||||
"llm_name": "qwen-vl-max",
|
||||
"tags": "LLM,CHAT,IMAGE2TEXT",
|
||||
"max_tokens": 765,
|
||||
"model_type": LLMType.IMAGE2TEXT.value
|
||||
},
|
||||
# ---------------------- ZhipuAI ----------------------
|
||||
{
|
||||
"fid": factory_infos[2]["name"],
|
||||
"llm_name": "glm-3-turbo",
|
||||
"tags": "LLM,CHAT,",
|
||||
"max_tokens": 128 * 1000,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[2]["name"],
|
||||
"llm_name": "glm-4",
|
||||
"tags": "LLM,CHAT,",
|
||||
"max_tokens": 128 * 1000,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[2]["name"],
|
||||
"llm_name": "glm-4v",
|
||||
"tags": "LLM,CHAT,IMAGE2TEXT",
|
||||
"max_tokens": 2000,
|
||||
"model_type": LLMType.IMAGE2TEXT.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[2]["name"],
|
||||
"llm_name": "embedding-2",
|
||||
"tags": "TEXT EMBEDDING",
|
||||
"max_tokens": 512,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
},
|
||||
# ------------------------ Moonshot -----------------------
|
||||
{
|
||||
"fid": factory_infos[4]["name"],
|
||||
"llm_name": "moonshot-v1-8k",
|
||||
"tags": "LLM,CHAT,",
|
||||
"max_tokens": 7900,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[4]["name"],
|
||||
"llm_name": "moonshot-v1-32k",
|
||||
"tags": "LLM,CHAT,",
|
||||
"max_tokens": 32768,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[4]["name"],
|
||||
"llm_name": "moonshot-v1-128k",
|
||||
"tags": "LLM,CHAT",
|
||||
"max_tokens": 128 * 1000,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},
|
||||
# ------------------------ FastEmbed -----------------------
|
||||
{
|
||||
"fid": factory_infos[5]["name"],
|
||||
"llm_name": "BAAI/bge-small-en-v1.5",
|
||||
"tags": "TEXT EMBEDDING,",
|
||||
"max_tokens": 512,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
}, {
|
||||
"fid": factory_infos[5]["name"],
|
||||
"llm_name": "BAAI/bge-small-zh-v1.5",
|
||||
"tags": "TEXT EMBEDDING,",
|
||||
"max_tokens": 512,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
}, {
|
||||
}, {
|
||||
"fid": factory_infos[5]["name"],
|
||||
"llm_name": "BAAI/bge-base-en-v1.5",
|
||||
"tags": "TEXT EMBEDDING,",
|
||||
"max_tokens": 512,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
}, {
|
||||
}, {
|
||||
"fid": factory_infos[5]["name"],
|
||||
"llm_name": "BAAI/bge-large-en-v1.5",
|
||||
"tags": "TEXT EMBEDDING,",
|
||||
"max_tokens": 512,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
}, {
|
||||
"fid": factory_infos[5]["name"],
|
||||
"llm_name": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
"tags": "TEXT EMBEDDING,",
|
||||
"max_tokens": 512,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
}, {
|
||||
"fid": factory_infos[5]["name"],
|
||||
"llm_name": "nomic-ai/nomic-embed-text-v1.5",
|
||||
"tags": "TEXT EMBEDDING,",
|
||||
"max_tokens": 8192,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
}, {
|
||||
"fid": factory_infos[5]["name"],
|
||||
"llm_name": "jinaai/jina-embeddings-v2-small-en",
|
||||
"tags": "TEXT EMBEDDING,",
|
||||
"max_tokens": 2147483648,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
}, {
|
||||
"fid": factory_infos[5]["name"],
|
||||
"llm_name": "jinaai/jina-embeddings-v2-base-en",
|
||||
"tags": "TEXT EMBEDDING,",
|
||||
"max_tokens": 2147483648,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
},
|
||||
# ------------------------ Youdao -----------------------
|
||||
{
|
||||
"fid": factory_infos[7]["name"],
|
||||
"llm_name": "maidalun1020/bce-embedding-base_v1",
|
||||
"tags": "TEXT EMBEDDING,",
|
||||
"max_tokens": 512,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[7]["name"],
|
||||
"llm_name": "maidalun1020/bce-reranker-base_v1",
|
||||
"tags": "RE-RANK, 512",
|
||||
"max_tokens": 512,
|
||||
"model_type": LLMType.RERANK.value
|
||||
},
|
||||
# ------------------------ DeepSeek -----------------------
|
||||
{
|
||||
"fid": factory_infos[8]["name"],
|
||||
"llm_name": "deepseek-chat",
|
||||
"tags": "LLM,CHAT,",
|
||||
"max_tokens": 32768,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[8]["name"],
|
||||
"llm_name": "deepseek-coder",
|
||||
"tags": "LLM,CHAT,",
|
||||
"max_tokens": 16385,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},
|
||||
# ------------------------ VolcEngine -----------------------
|
||||
{
|
||||
"fid": factory_infos[9]["name"],
|
||||
"llm_name": "Skylark2-pro-32k",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32768,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[9]["name"],
|
||||
"llm_name": "Skylark2-pro-4k",
|
||||
"tags": "LLM,CHAT,4k",
|
||||
"max_tokens": 4096,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},
|
||||
# ------------------------ BaiChuan -----------------------
|
||||
{
|
||||
"fid": factory_infos[10]["name"],
|
||||
"llm_name": "Baichuan2-Turbo",
|
||||
"tags": "LLM,CHAT,32K",
|
||||
"max_tokens": 32768,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[10]["name"],
|
||||
"llm_name": "Baichuan2-Turbo-192k",
|
||||
"tags": "LLM,CHAT,192K",
|
||||
"max_tokens": 196608,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[10]["name"],
|
||||
"llm_name": "Baichuan3-Turbo",
|
||||
"tags": "LLM,CHAT,32K",
|
||||
"max_tokens": 32768,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[10]["name"],
|
||||
"llm_name": "Baichuan3-Turbo-128k",
|
||||
"tags": "LLM,CHAT,128K",
|
||||
"max_tokens": 131072,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[10]["name"],
|
||||
"llm_name": "Baichuan4",
|
||||
"tags": "LLM,CHAT,128K",
|
||||
"max_tokens": 131072,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[10]["name"],
|
||||
"llm_name": "Baichuan-Text-Embedding",
|
||||
"tags": "TEXT EMBEDDING",
|
||||
"max_tokens": 512,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
},
|
||||
# ------------------------ Jina -----------------------
|
||||
{
|
||||
"fid": factory_infos[11]["name"],
|
||||
"llm_name": "jina-reranker-v1-base-en",
|
||||
"tags": "RE-RANK,8k",
|
||||
"max_tokens": 8196,
|
||||
"model_type": LLMType.RERANK.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[11]["name"],
|
||||
"llm_name": "jina-reranker-v1-turbo-en",
|
||||
"tags": "RE-RANK,8k",
|
||||
"max_tokens": 8196,
|
||||
"model_type": LLMType.RERANK.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[11]["name"],
|
||||
"llm_name": "jina-reranker-v1-tiny-en",
|
||||
"tags": "RE-RANK,8k",
|
||||
"max_tokens": 8196,
|
||||
"model_type": LLMType.RERANK.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[11]["name"],
|
||||
"llm_name": "jina-colbert-v1-en",
|
||||
"tags": "RE-RANK,8k",
|
||||
"max_tokens": 8196,
|
||||
"model_type": LLMType.RERANK.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[11]["name"],
|
||||
"llm_name": "jina-embeddings-v2-base-en",
|
||||
"tags": "TEXT EMBEDDING",
|
||||
"max_tokens": 8196,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[11]["name"],
|
||||
"llm_name": "jina-embeddings-v2-base-de",
|
||||
"tags": "TEXT EMBEDDING",
|
||||
"max_tokens": 8196,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[11]["name"],
|
||||
"llm_name": "jina-embeddings-v2-base-es",
|
||||
"tags": "TEXT EMBEDDING",
|
||||
"max_tokens": 8196,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[11]["name"],
|
||||
"llm_name": "jina-embeddings-v2-base-code",
|
||||
"tags": "TEXT EMBEDDING",
|
||||
"max_tokens": 8196,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[11]["name"],
|
||||
"llm_name": "jina-embeddings-v2-base-zh",
|
||||
"tags": "TEXT EMBEDDING",
|
||||
"max_tokens": 8196,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
},
|
||||
# ------------------------ BAAI -----------------------
|
||||
{
|
||||
"fid": factory_infos[12]["name"],
|
||||
"llm_name": "BAAI/bge-large-zh-v1.5",
|
||||
"tags": "TEXT EMBEDDING,",
|
||||
"max_tokens": 1024,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[12]["name"],
|
||||
"llm_name": "BAAI/bge-reranker-v2-m3",
|
||||
"tags": "RE-RANK,2k",
|
||||
"max_tokens": 2048,
|
||||
"model_type": LLMType.RERANK.value
|
||||
},
|
||||
# ------------------------ Minimax -----------------------
|
||||
{
|
||||
"fid": factory_infos[13]["name"],
|
||||
"llm_name": "abab6.5-chat",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[13]["name"],
|
||||
"llm_name": "abab6.5s-chat",
|
||||
"tags": "LLM,CHAT,245k",
|
||||
"max_tokens": 245760,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[13]["name"],
|
||||
"llm_name": "abab6.5t-chat",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[13]["name"],
|
||||
"llm_name": "abab6.5g-chat",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[13]["name"],
|
||||
"llm_name": "abab5.5-chat",
|
||||
"tags": "LLM,CHAT,16k",
|
||||
"max_tokens": 16384,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[13]["name"],
|
||||
"llm_name": "abab5.5s-chat",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},
|
||||
# ------------------------ Mistral -----------------------
|
||||
{
|
||||
"fid": factory_infos[14]["name"],
|
||||
"llm_name": "open-mixtral-8x22b",
|
||||
"tags": "LLM,CHAT,64k",
|
||||
"max_tokens": 64000,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[14]["name"],
|
||||
"llm_name": "open-mixtral-8x7b",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32000,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[14]["name"],
|
||||
"llm_name": "open-mistral-7b",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32000,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[14]["name"],
|
||||
"llm_name": "mistral-large-latest",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32000,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[14]["name"],
|
||||
"llm_name": "mistral-small-latest",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32000,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[14]["name"],
|
||||
"llm_name": "mistral-medium-latest",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32000,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[14]["name"],
|
||||
"llm_name": "codestral-latest",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32000,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[14]["name"],
|
||||
"llm_name": "mistral-embed",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": LLMType.EMBEDDING
|
||||
},
|
||||
# ------------------------ Azure OpenAI -----------------------
|
||||
# Please ensure the llm_name is the same as the name in Azure
|
||||
# OpenAI deployment name (e.g., azure-gpt-4o). And the llm_name
|
||||
# must different from the OpenAI llm_name
|
||||
#
|
||||
# Each model must be deployed in the Azure OpenAI service, otherwise,
|
||||
# you will receive an error message 'The API deployment for
|
||||
# this resource does not exist'
|
||||
{
|
||||
"fid": factory_infos[15]["name"],
|
||||
"llm_name": "azure-gpt-4o",
|
||||
"tags": "LLM,CHAT,128K",
|
||||
"max_tokens": 128000,
|
||||
"model_type": LLMType.CHAT.value + "," + LLMType.IMAGE2TEXT.value
|
||||
}, {
|
||||
"fid": factory_infos[15]["name"],
|
||||
"llm_name": "azure-gpt-35-turbo",
|
||||
"tags": "LLM,CHAT,4K",
|
||||
"max_tokens": 4096,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[15]["name"],
|
||||
"llm_name": "azure-gpt-35-turbo-16k",
|
||||
"tags": "LLM,CHAT,16k",
|
||||
"max_tokens": 16385,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[15]["name"],
|
||||
"llm_name": "azure-text-embedding-ada-002",
|
||||
"tags": "TEXT EMBEDDING,8K",
|
||||
"max_tokens": 8191,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
}, {
|
||||
"fid": factory_infos[15]["name"],
|
||||
"llm_name": "azure-text-embedding-3-small",
|
||||
"tags": "TEXT EMBEDDING,8K",
|
||||
"max_tokens": 8191,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
}, {
|
||||
"fid": factory_infos[15]["name"],
|
||||
"llm_name": "azure-text-embedding-3-large",
|
||||
"tags": "TEXT EMBEDDING,8K",
|
||||
"max_tokens": 8191,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
},{
|
||||
"fid": factory_infos[15]["name"],
|
||||
"llm_name": "azure-whisper-1",
|
||||
"tags": "SPEECH2TEXT",
|
||||
"max_tokens": 25 * 1024 * 1024,
|
||||
"model_type": LLMType.SPEECH2TEXT.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[15]["name"],
|
||||
"llm_name": "azure-gpt-4",
|
||||
"tags": "LLM,CHAT,8K",
|
||||
"max_tokens": 8191,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[15]["name"],
|
||||
"llm_name": "azure-gpt-4-turbo",
|
||||
"tags": "LLM,CHAT,8K",
|
||||
"max_tokens": 8191,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[15]["name"],
|
||||
"llm_name": "azure-gpt-4-32k",
|
||||
"tags": "LLM,CHAT,32K",
|
||||
"max_tokens": 32768,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[15]["name"],
|
||||
"llm_name": "azure-gpt-4-vision-preview",
|
||||
"tags": "LLM,CHAT,IMAGE2TEXT",
|
||||
"max_tokens": 765,
|
||||
"model_type": LLMType.IMAGE2TEXT.value
|
||||
},
|
||||
# ------------------------ Bedrock -----------------------
|
||||
{
|
||||
"fid": factory_infos[16]["name"],
|
||||
"llm_name": "ai21.j2-ultra-v1",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8191,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[16]["name"],
|
||||
"llm_name": "ai21.j2-mid-v1",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8191,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[16]["name"],
|
||||
"llm_name": "cohere.command-text-v14",
|
||||
"tags": "LLM,CHAT,4k",
|
||||
"max_tokens": 4096,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[16]["name"],
|
||||
"llm_name": "cohere.command-light-text-v14",
|
||||
"tags": "LLM,CHAT,4k",
|
||||
"max_tokens": 4096,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[16]["name"],
|
||||
"llm_name": "cohere.command-r-v1:0",
|
||||
"tags": "LLM,CHAT,128k",
|
||||
"max_tokens": 128 * 1024,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[16]["name"],
|
||||
"llm_name": "cohere.command-r-plus-v1:0",
|
||||
"tags": "LLM,CHAT,128k",
|
||||
"max_tokens": 128000,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[16]["name"],
|
||||
"llm_name": "anthropic.claude-v2",
|
||||
"tags": "LLM,CHAT,100k",
|
||||
"max_tokens": 100 * 1024,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[16]["name"],
|
||||
"llm_name": "anthropic.claude-v2:1",
|
||||
"tags": "LLM,CHAT,200k",
|
||||
"max_tokens": 200 * 1024,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[16]["name"],
|
||||
"llm_name": "anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
"tags": "LLM,CHAT,200k",
|
||||
"max_tokens": 200 * 1024,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[16]["name"],
|
||||
"llm_name": "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
"tags": "LLM,CHAT,200k",
|
||||
"max_tokens": 200 * 1024,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[16]["name"],
|
||||
"llm_name": "anthropic.claude-3-haiku-20240307-v1:0",
|
||||
"tags": "LLM,CHAT,200k",
|
||||
"max_tokens": 200 * 1024,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[16]["name"],
|
||||
"llm_name": "anthropic.claude-3-opus-20240229-v1:0",
|
||||
"tags": "LLM,CHAT,200k",
|
||||
"max_tokens": 200 * 1024,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[16]["name"],
|
||||
"llm_name": "anthropic.claude-instant-v1",
|
||||
"tags": "LLM,CHAT,100k",
|
||||
"max_tokens": 100 * 1024,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[16]["name"],
|
||||
"llm_name": "amazon.titan-text-express-v1",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[16]["name"],
|
||||
"llm_name": "amazon.titan-text-premier-v1:0",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32 * 1024,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[16]["name"],
|
||||
"llm_name": "amazon.titan-text-lite-v1",
|
||||
"tags": "LLM,CHAT,4k",
|
||||
"max_tokens": 4096,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[16]["name"],
|
||||
"llm_name": "meta.llama2-13b-chat-v1",
|
||||
"tags": "LLM,CHAT,4k",
|
||||
"max_tokens": 4096,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[16]["name"],
|
||||
"llm_name": "meta.llama2-70b-chat-v1",
|
||||
"tags": "LLM,CHAT,4k",
|
||||
"max_tokens": 4096,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[16]["name"],
|
||||
"llm_name": "meta.llama3-8b-instruct-v1:0",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[16]["name"],
|
||||
"llm_name": "meta.llama3-70b-instruct-v1:0",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[16]["name"],
|
||||
"llm_name": "mistral.mistral-7b-instruct-v0:2",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[16]["name"],
|
||||
"llm_name": "mistral.mixtral-8x7b-instruct-v0:1",
|
||||
"tags": "LLM,CHAT,4k",
|
||||
"max_tokens": 4096,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[16]["name"],
|
||||
"llm_name": "mistral.mistral-large-2402-v1:0",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[16]["name"],
|
||||
"llm_name": "mistral.mistral-small-2402-v1:0",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[16]["name"],
|
||||
"llm_name": "amazon.titan-embed-text-v2:0",
|
||||
"tags": "TEXT EMBEDDING",
|
||||
"max_tokens": 8192,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
}, {
|
||||
"fid": factory_infos[16]["name"],
|
||||
"llm_name": "cohere.embed-english-v3",
|
||||
"tags": "TEXT EMBEDDING",
|
||||
"max_tokens": 2048,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
}, {
|
||||
"fid": factory_infos[16]["name"],
|
||||
"llm_name": "cohere.embed-multilingual-v3",
|
||||
"tags": "TEXT EMBEDDING",
|
||||
"max_tokens": 2048,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
},
|
||||
]
|
||||
for info in factory_infos:
|
||||
try:
|
||||
LLMService.filter_delete([(LLM.fid == "MiniMax" or LLM.fid == "Minimax")])
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
factory_llm_infos = json.load(
|
||||
open(
|
||||
os.path.join(get_project_base_directory(), "conf", "llm_factories.json"),
|
||||
"r",
|
||||
)
|
||||
)
|
||||
for factory_llm_info in factory_llm_infos["factory_llm_infos"]:
|
||||
llm_infos = factory_llm_info.pop("llm")
|
||||
try:
|
||||
LLMFactoriesService.save(**info)
|
||||
except Exception as e:
|
||||
pass
|
||||
for info in llm_infos:
|
||||
try:
|
||||
LLMService.save(**info)
|
||||
LLMFactoriesService.save(**factory_llm_info)
|
||||
except Exception as e:
|
||||
pass
|
||||
for llm_info in llm_infos:
|
||||
llm_info["fid"] = factory_llm_info["name"]
|
||||
try:
|
||||
LLMService.save(**llm_info)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
LLMFactoriesService.filter_delete([LLMFactories.name == "Local"])
|
||||
LLMService.filter_delete([LLM.fid == "Local"])
|
||||
@ -918,6 +121,8 @@ def init_llm_factory():
|
||||
LLMFactoriesService.filter_delete([LLMFactoriesService.model.name == "QAnything"])
|
||||
LLMService.filter_delete([LLMService.model.fid == "QAnything"])
|
||||
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
|
||||
TenantService.filter_update([1 == 1], {
|
||||
"parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph"})
|
||||
## insert openai two embedding models to the current openai user.
|
||||
print("Start to insert 2 OpenAI embedding models...")
|
||||
tenant_ids = set([row["tenant_id"] for row in TenantLLMService.get_openai_models()])
|
||||
@ -940,7 +145,7 @@ def init_llm_factory():
|
||||
"""
|
||||
drop table llm;
|
||||
drop table llm_factories;
|
||||
update tenant set parser_ids='naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One';
|
||||
update tenant set parser_ids='naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph';
|
||||
alter table knowledgebase modify avatar longtext;
|
||||
alter table user modify avatar longtext;
|
||||
alter table dialog modify icon longtext;
|
||||
@ -948,7 +153,7 @@ def init_llm_factory():
|
||||
|
||||
|
||||
def add_graph_templates():
|
||||
dir = os.path.join(get_project_base_directory(), "graph", "templates")
|
||||
dir = os.path.join(get_project_base_directory(), "agent", "templates")
|
||||
for fnm in os.listdir(dir):
|
||||
try:
|
||||
cnvs = json.load(open(os.path.join(dir, fnm), "r"))
|
||||
|
||||
@ -13,19 +13,22 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
from copy import deepcopy
|
||||
|
||||
from api.db import LLMType
|
||||
from api.db import LLMType, ParserType
|
||||
from api.db.db_models import Dialog, Conversation
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
|
||||
from api.settings import chat_logger, retrievaler
|
||||
from api.settings import chat_logger, retrievaler, kg_retrievaler
|
||||
from rag.app.resume import forbidden_select_fields4resume
|
||||
from rag.nlp import keyword_extraction
|
||||
from rag.nlp.search import index_name
|
||||
from rag.utils import rmSpace, num_tokens_from_string, encoder
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
|
||||
|
||||
class DialogService(CommonService):
|
||||
@ -73,6 +76,15 @@ def message_fit_in(msg, max_length=4000):
|
||||
return max_length, msg
|
||||
|
||||
|
||||
def llm_id2llm_type(llm_id):
|
||||
fnm = os.path.join(get_project_base_directory(), "conf")
|
||||
llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r"))
|
||||
for llm_factory in llm_factories["factory_llm_infos"]:
|
||||
for llm in llm_factory["llm"]:
|
||||
if llm_id == llm["llm_name"]:
|
||||
return llm["model_type"].strip(",")[-1]
|
||||
|
||||
|
||||
def chat(dialog, messages, stream=True, **kwargs):
|
||||
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
||||
llm = LLMService.query(llm_name=dialog.llm_id)
|
||||
@ -80,7 +92,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=dialog.llm_id)
|
||||
if not llm:
|
||||
raise LookupError("LLM(%s) not found" % dialog.llm_id)
|
||||
max_tokens = 1024
|
||||
max_tokens = 8192
|
||||
else:
|
||||
max_tokens = llm[0].max_tokens
|
||||
kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
|
||||
@ -89,9 +101,15 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
|
||||
return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
|
||||
|
||||
is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
|
||||
retr = retrievaler if not is_kg else kg_retrievaler
|
||||
|
||||
questions = [m["content"] for m in messages if m["role"] == "user"]
|
||||
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||||
if llm_id2llm_type(dialog.llm_id) == "image2text":
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
|
||||
else:
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||||
|
||||
prompt_config = dialog.prompt_config
|
||||
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
|
||||
@ -123,7 +141,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
else:
|
||||
if prompt_config.get("keyword", False):
|
||||
questions[-1] += keyword_extraction(chat_mdl, questions[-1])
|
||||
kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
|
||||
kbinfos = retr.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
|
||||
dialog.similarity_threshold,
|
||||
dialog.vector_similarity_weight,
|
||||
doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
|
||||
@ -132,7 +150,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
#self-rag
|
||||
if dialog.prompt_config.get("self_rag") and not relevant(dialog.tenant_id, dialog.llm_id, questions[-1], knowledges):
|
||||
questions[-1] = rewrite(dialog.tenant_id, dialog.llm_id, questions[-1])
|
||||
kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
|
||||
kbinfos = retr.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
|
||||
dialog.similarity_threshold,
|
||||
dialog.vector_similarity_weight,
|
||||
doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
|
||||
@ -162,8 +180,9 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
|
||||
def decorate_answer(answer):
|
||||
nonlocal prompt_config, knowledges, kwargs, kbinfos
|
||||
refs = []
|
||||
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
|
||||
answer, idx = retrievaler.insert_citations(answer,
|
||||
answer, idx = retr.insert_citations(answer,
|
||||
[ck["content_ltks"]
|
||||
for ck in kbinfos["chunks"]],
|
||||
[ck["vector"]
|
||||
@ -177,10 +196,11 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
if not recall_docs: recall_docs = kbinfos["doc_aggs"]
|
||||
kbinfos["doc_aggs"] = recall_docs
|
||||
|
||||
refs = deepcopy(kbinfos)
|
||||
for c in refs["chunks"]:
|
||||
if c.get("vector"):
|
||||
del c["vector"]
|
||||
refs = deepcopy(kbinfos)
|
||||
for c in refs["chunks"]:
|
||||
if c.get("vector"):
|
||||
del c["vector"]
|
||||
|
||||
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
|
||||
answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
|
||||
return {"answer": answer, "reference": refs}
|
||||
@ -326,7 +346,10 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
|
||||
|
||||
|
||||
def relevant(tenant_id, llm_id, question, contents: list):
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
|
||||
if llm_id2llm_type(llm_id) == "image2text":
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
|
||||
else:
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
|
||||
prompt = """
|
||||
You are a grader assessing relevance of a retrieved document to a user question.
|
||||
It does not need to be a stringent test. The goal is to filter out erroneous retrievals.
|
||||
@ -345,7 +368,10 @@ def relevant(tenant_id, llm_id, question, contents: list):
|
||||
|
||||
|
||||
def rewrite(tenant_id, llm_id, question):
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
|
||||
if llm_id2llm_type(llm_id) == "image2text":
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
|
||||
else:
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
|
||||
prompt = """
|
||||
You are an expert at query expansion to generate a paraphrasing of a question.
|
||||
I can't retrieval relevant information from the knowledge base by using user's question directly.
|
||||
|
||||
@ -142,7 +142,7 @@ class DocumentService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_unfinished_docs(cls):
|
||||
fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg]
|
||||
fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg, cls.model.run]
|
||||
docs = cls.model.select(*fields) \
|
||||
.where(
|
||||
cls.model.status == StatusEnum.VALID.value,
|
||||
@ -168,7 +168,26 @@ class DocumentService(CommonService):
|
||||
chunk_num).where(
|
||||
Knowledgebase.id == kb_id).execute()
|
||||
return num
|
||||
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
|
||||
num = cls.model.update(token_num=cls.model.token_num - token_num,
|
||||
chunk_num=cls.model.chunk_num - chunk_num,
|
||||
process_duation=cls.model.process_duation + duation).where(
|
||||
cls.model.id == doc_id).execute()
|
||||
if num == 0:
|
||||
raise LookupError(
|
||||
"Document not found which is supposed to be there")
|
||||
num = Knowledgebase.update(
|
||||
token_num=Knowledgebase.token_num -
|
||||
token_num,
|
||||
chunk_num=Knowledgebase.chunk_num -
|
||||
chunk_num
|
||||
).where(
|
||||
Knowledgebase.id == kb_id).execute()
|
||||
return num
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def clear_chunk_num(cls, doc_id):
|
||||
@ -292,7 +311,8 @@ class DocumentService(CommonService):
|
||||
prg = 0
|
||||
finished = True
|
||||
bad = 0
|
||||
status = TaskStatus.RUNNING.value
|
||||
e, doc = DocumentService.get_by_id(d["id"])
|
||||
status = doc.run#TaskStatus.RUNNING.value
|
||||
for t in tsks:
|
||||
if 0 <= t.progress < 1:
|
||||
finished = False
|
||||
@ -333,6 +353,17 @@ class DocumentService(CommonService):
|
||||
cls.model.kb_id == kb_id).dicts())
|
||||
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def do_cancel(cls, doc_id):
|
||||
try:
|
||||
_, doc = DocumentService.get_by_id(doc_id)
|
||||
return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
|
||||
except Exception as e:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def queue_raptor_tasks(doc):
|
||||
def new_task():
|
||||
nonlocal doc
|
||||
@ -347,4 +378,4 @@ def queue_raptor_tasks(doc):
|
||||
task = new_task()
|
||||
bulk_insert_into_db(Task, [task], True)
|
||||
task["type"] = "raptor"
|
||||
assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
|
||||
assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
|
||||
|
||||
@ -15,7 +15,7 @@
|
||||
#
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.settings import database_logger
|
||||
from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel
|
||||
from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel
|
||||
from api.db import LLMType
|
||||
from api.db.db_models import DB, UserTenant
|
||||
from api.db.db_models import LLMFactories, LLM, TenantLLM
|
||||
@ -70,7 +70,7 @@ class TenantLLMService(CommonService):
|
||||
elif llm_type == LLMType.SPEECH2TEXT.value:
|
||||
mdlnm = tenant.asr_id
|
||||
elif llm_type == LLMType.IMAGE2TEXT.value:
|
||||
mdlnm = tenant.img2txt_id
|
||||
mdlnm = tenant.img2txt_id if not llm_name else llm_name
|
||||
elif llm_type == LLMType.CHAT.value:
|
||||
mdlnm = tenant.llm_id if not llm_name else llm_name
|
||||
elif llm_type == LLMType.RERANK:
|
||||
@ -120,6 +120,14 @@ class TenantLLMService(CommonService):
|
||||
return ChatModel[model_config["llm_factory"]](
|
||||
model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
||||
|
||||
if llm_type == LLMType.SPEECH2TEXT:
|
||||
if model_config["llm_factory"] not in Seq2txtModel:
|
||||
return
|
||||
return Seq2txtModel[model_config["llm_factory"]](
|
||||
model_config["api_key"], model_config["llm_name"], lang,
|
||||
base_url=model_config["api_base"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
|
||||
@ -207,6 +215,14 @@ class LLMBundle(object):
|
||||
"Can't update token usage for {}/IMAGE2TEXT".format(self.tenant_id))
|
||||
return txt
|
||||
|
||||
def transcription(self, audio):
|
||||
txt, used_tokens = self.mdl.transcription(audio)
|
||||
if not TenantLLMService.increase_usage(
|
||||
self.tenant_id, self.llm_type, used_tokens):
|
||||
database_logger.error(
|
||||
"Can't update token usage for {}/SEQUENCE2TXT".format(self.tenant_id))
|
||||
return txt
|
||||
|
||||
def chat(self, system, history, gen_conf):
|
||||
txt, used_tokens = self.mdl.chat(system, history, gen_conf)
|
||||
if not TenantLLMService.increase_usage(
|
||||
|
||||
@ -139,6 +139,8 @@ def queue_tasks(doc, bucket, name):
|
||||
page_size = doc["parser_config"].get("task_page_size", 22)
|
||||
if doc["parser_id"] == "one":
|
||||
page_size = 1000000000
|
||||
if doc["parser_id"] == "knowledge_graph":
|
||||
page_size = 1000000000
|
||||
if not do_layout:
|
||||
page_size = 1000000000
|
||||
page_ranges = doc["parser_config"].get("pages")
|
||||
|
||||
@ -34,6 +34,7 @@ chat_logger = getLogger("chat")
|
||||
|
||||
from rag.utils.es_conn import ELASTICSEARCH
|
||||
from rag.nlp import search
|
||||
from graphrag import search as kg_search
|
||||
from api.utils import get_base_config, decrypt_database_config
|
||||
|
||||
API_VERSION = "v1"
|
||||
@ -131,7 +132,7 @@ IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
|
||||
API_KEY = LLM.get("api_key", "")
|
||||
PARSERS = LLM.get(
|
||||
"parsers",
|
||||
"naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One")
|
||||
"naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph")
|
||||
|
||||
# distribution
|
||||
DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)
|
||||
@ -204,6 +205,7 @@ PRIVILEGE_COMMAND_WHITELIST = []
|
||||
CHECK_NODES_IDENTITY = False
|
||||
|
||||
retrievaler = search.Dealer(ELASTICSEARCH)
|
||||
kg_retrievaler = kg_search.KGSearch(ELASTICSEARCH)
|
||||
|
||||
|
||||
class CustomEnum(Enum):
|
||||
|
||||
78
api/utils/commands.py
Normal file
78
api/utils/commands.py
Normal file
@ -0,0 +1,78 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import base64
|
||||
import click
|
||||
import re
|
||||
|
||||
from flask import Flask
|
||||
from werkzeug.security import generate_password_hash
|
||||
|
||||
from api.db.services import UserService
|
||||
|
||||
|
||||
@click.command('reset-password', help='Reset the account password.')
|
||||
@click.option('--email', prompt=True, help='The email address of the account whose password you need to reset')
|
||||
@click.option('--new-password', prompt=True, help='the new password.')
|
||||
@click.option('--password-confirm', prompt=True, help='the new password confirm.')
|
||||
def reset_password(email, new_password, password_confirm):
|
||||
if str(new_password).strip() != str(password_confirm).strip():
|
||||
click.echo(click.style('sorry. The two passwords do not match.', fg='red'))
|
||||
return
|
||||
user = UserService.query(email=email)
|
||||
if not user:
|
||||
click.echo(click.style('sorry. The Email is not registered!.', fg='red'))
|
||||
return
|
||||
encode_password = base64.b64encode(new_password.encode('utf-8')).decode('utf-8')
|
||||
password_hash = generate_password_hash(encode_password)
|
||||
user_dict = {
|
||||
'password': password_hash
|
||||
}
|
||||
UserService.update_user(user[0].id,user_dict)
|
||||
click.echo(click.style('Congratulations! Password has been reset.', fg='green'))
|
||||
|
||||
|
||||
@click.command('reset-email', help='Reset the account email.')
|
||||
@click.option('--email', prompt=True, help='The old email address of the account whose email you need to reset')
|
||||
@click.option('--new-email', prompt=True, help='the new email.')
|
||||
@click.option('--email-confirm', prompt=True, help='the new email confirm.')
|
||||
def reset_email(email, new_email, email_confirm):
|
||||
if str(new_email).strip() != str(email_confirm).strip():
|
||||
click.echo(click.style('Sorry, new email and confirm email do not match.', fg='red'))
|
||||
return
|
||||
if str(new_email).strip() == str(email).strip():
|
||||
click.echo(click.style('Sorry, new email and old email are the same.', fg='red'))
|
||||
return
|
||||
user = UserService.query(email=email)
|
||||
if not user:
|
||||
click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red'))
|
||||
return
|
||||
if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,4}$", new_email):
|
||||
click.echo(click.style('sorry. {} is not a valid email. '.format(new_email), fg='red'))
|
||||
return
|
||||
new_user = UserService.query(email=new_email)
|
||||
if new_user:
|
||||
click.echo(click.style('sorry. the account: [{}] is exist .'.format(new_email), fg='red'))
|
||||
return
|
||||
user_dict = {
|
||||
'email': new_email
|
||||
}
|
||||
UserService.update_user(user[0].id,user_dict)
|
||||
click.echo(click.style('Congratulations!, email has been reset.', fg='green'))
|
||||
|
||||
def register_commands(app: Flask):
|
||||
app.cli.add_command(reset_password)
|
||||
app.cli.add_command(reset_email)
|
||||
@ -11,10 +11,11 @@ def crypt(line):
|
||||
file_utils.get_project_base_directory(),
|
||||
"conf",
|
||||
"public.pem")
|
||||
rsa_key = RSA.importKey(open(file_path).read())
|
||||
rsa_key = RSA.importKey(open(file_path).read(),"Welcome")
|
||||
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
|
||||
return base64.b64encode(cipher.encrypt(
|
||||
line.encode('utf-8'))).decode("utf-8")
|
||||
password_base64 = base64.b64encode(line.encode('utf-8')).decode("utf-8")
|
||||
encrypted_password = cipher.encrypt(password_base64.encode())
|
||||
return base64.b64encode(encrypted_password).decode('utf-8')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
2214
conf/llm_factories.json
Normal file
2214
conf/llm_factories.json
Normal file
File diff suppressed because it is too large
Load Diff
@ -127,7 +127,7 @@ class RAGFlowDocxParser:
|
||||
runs_within_single_paragraph.append(run.text) # append run.text first
|
||||
|
||||
# wrap page break checker into a static method
|
||||
if RAGFlowDocxParser.has_page_break(run._element.xml):
|
||||
if 'lastRenderedPageBreak' in run._element.xml:
|
||||
pn += 1
|
||||
|
||||
secs.append(("".join(runs_within_single_paragraph), p.style.name)) # then concat run.text as part of the paragraph
|
||||
|
||||
@ -23,7 +23,7 @@ import logging
|
||||
from PIL import Image, ImageDraw
|
||||
import numpy as np
|
||||
from timeit import default_timer as timer
|
||||
from PyPDF2 import PdfReader as pdf2_read
|
||||
from pypdf import PdfReader as pdf2_read
|
||||
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer
|
||||
@ -285,10 +285,10 @@ class RAGFlowPdfParser:
|
||||
"page_number": pagenum} for b, t in bxs if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]],
|
||||
self.mean_height[-1] / 3
|
||||
)
|
||||
|
||||
|
||||
# merge chars in the same rect
|
||||
for c in Recognizer.sort_X_firstly(
|
||||
chars, self.mean_width[pagenum - 1] // 4):
|
||||
for c in Recognizer.sort_Y_firstly(
|
||||
chars, self.mean_height[pagenum - 1] // 4):
|
||||
ii = Recognizer.find_overlapped(c, bxs)
|
||||
if ii is None:
|
||||
self.lefted_chars.append(c)
|
||||
@ -952,7 +952,7 @@ class RAGFlowPdfParser:
|
||||
fnm, str) else pdfplumber.open(BytesIO(fnm))
|
||||
self.page_images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
|
||||
enumerate(self.pdf.pages[page_from:page_to])]
|
||||
self.page_chars = [[{**c, 'top': max(0, c['top'] - 10), 'bottom': max(0, c['bottom'] - 10)} for c in page.chars if self._has_color(c)] for page in
|
||||
self.page_chars = [[{**c, 'top': c['top'], 'bottom': c['bottom']} for c in page.dedupe_chars().chars if self._has_color(c)] for page in
|
||||
self.pdf.pages[page_from:page_to]]
|
||||
self.total_page = len(self.pdf.pages)
|
||||
except Exception as e:
|
||||
@ -985,7 +985,6 @@ class RAGFlowPdfParser:
|
||||
self.is_english = True
|
||||
else:
|
||||
self.is_english = False
|
||||
self.is_english = False
|
||||
|
||||
st = timer()
|
||||
for i, img in enumerate(self.page_images):
|
||||
|
||||
@ -52,7 +52,7 @@ class RAGFlowPptParser(object):
|
||||
break
|
||||
texts = []
|
||||
for shape in sorted(
|
||||
slide.shapes, key=lambda x: (x.top // 10, x.left)):
|
||||
slide.shapes, key=lambda x: ((x.top if x.top is not None else 0) // 10, x.left)):
|
||||
txt = self.__extract(shape)
|
||||
if txt:
|
||||
texts.append(txt)
|
||||
|
||||
@ -75,7 +75,7 @@ class LayoutRecognizer(Recognizer):
|
||||
"x0": b["bbox"][0] / scale_factor, "x1": b["bbox"][2] / scale_factor,
|
||||
"top": b["bbox"][1] / scale_factor, "bottom": b["bbox"][-1] / scale_factor,
|
||||
"page_number": pn,
|
||||
} for b in lts]
|
||||
} for b in lts if float(b["score"]) >= 0.8 or b["type"] not in self.garbage_layouts]
|
||||
lts = self.sort_Y_firstly(lts, np.mean(
|
||||
[l["bottom"] - l["top"] for l in lts]) / 2)
|
||||
lts = self.layouts_cleanup(bxs, lts)
|
||||
|
||||
@ -23,7 +23,6 @@ import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
from .postprocess import build_post_process
|
||||
from rag.settings import cron_logger
|
||||
|
||||
|
||||
def transform(data, ops=None):
|
||||
@ -566,9 +565,6 @@ class OCR(object):
|
||||
end = time.time()
|
||||
time_dict['all'] = end - start
|
||||
return None, None, time_dict
|
||||
else:
|
||||
cron_logger.debug("dt_boxes num : {}, elapsed : {}".format(
|
||||
len(dt_boxes), elapse))
|
||||
|
||||
return zip(self.sorted_boxes(dt_boxes), [
|
||||
("", 0) for _ in range(len(dt_boxes))])
|
||||
@ -597,9 +593,7 @@ class OCR(object):
|
||||
end = time.time()
|
||||
time_dict['all'] = end - start
|
||||
return None, None, time_dict
|
||||
else:
|
||||
cron_logger.debug("dt_boxes num : {}, elapsed : {}".format(
|
||||
len(dt_boxes), elapse))
|
||||
|
||||
img_crop_list = []
|
||||
|
||||
dt_boxes = self.sorted_boxes(dt_boxes)
|
||||
@ -612,8 +606,6 @@ class OCR(object):
|
||||
rec_res, elapse = self.text_recognizer(img_crop_list)
|
||||
|
||||
time_dict['rec'] = elapse
|
||||
cron_logger.debug("rec_res num : {}, elapsed : {}".format(
|
||||
len(rec_res), elapse))
|
||||
|
||||
filter_boxes, filter_rec_res = [], []
|
||||
for box, rec_result in zip(dt_boxes, rec_res):
|
||||
|
||||
@ -19,7 +19,6 @@ from huggingface_hub import snapshot_download
|
||||
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
from .operators import *
|
||||
from rag.settings import cron_logger
|
||||
|
||||
|
||||
class Recognizer(object):
|
||||
@ -340,7 +339,6 @@ class Recognizer(object):
|
||||
if score < thr:
|
||||
continue
|
||||
if clsid >= len(self.label_list):
|
||||
cron_logger.warning(f"bad category id")
|
||||
continue
|
||||
bb.append({
|
||||
"type": self.label_list[clsid].lower(),
|
||||
|
||||
@ -10,6 +10,8 @@ ELASTIC_PASSWORD=infini_rag_flow
|
||||
|
||||
# Port to expose Kibana to the host
|
||||
KIBANA_PORT=6601
|
||||
KIBANA_USER=rag_flow
|
||||
KIBANA_PASSWORD=infini_rag_flow
|
||||
|
||||
# Increase or decrease based on the available host memory (in bytes)
|
||||
|
||||
|
||||
33
docker/docker-compose-admin-tool.yml
Normal file
33
docker/docker-compose-admin-tool.yml
Normal file
@ -0,0 +1,33 @@
|
||||
services:
|
||||
kibana:
|
||||
image: kibana:${STACK_VERSION}
|
||||
container_name: ragflow-kibana
|
||||
environment:
|
||||
ELASTICSEARCH_USERNAME: ${KIBANA_USER}
|
||||
ELASTICSEARCH_PASSWORD: ${KIBANA_PASSWORD}
|
||||
ELASTICSEARCH_HOSTS: "http://es01:9200"
|
||||
ports:
|
||||
- ${KIBANA_PORT}:5601
|
||||
depends_on:
|
||||
es01:
|
||||
condition: service_healthy
|
||||
kibana-user-init:
|
||||
condition: service_completed_successfully
|
||||
|
||||
networks:
|
||||
- ragflow
|
||||
kibana-user-init:
|
||||
image: appropriate/curl
|
||||
depends_on:
|
||||
es01:
|
||||
condition: service_healthy
|
||||
volumes:
|
||||
- ./init-kibana.sh:/app/init-kibana.sh
|
||||
environment:
|
||||
- ELASTIC_PASSWORD=${ELASTIC_PASSWORD}
|
||||
- KIBANA_USER=${KIBANA_USER}
|
||||
- KIBANA_PASSWORD=${KIBANA_PASSWORD}
|
||||
command: /bin/sh -c "sh /app/init-kibana.sh"
|
||||
networks:
|
||||
- ragflow
|
||||
restart: 'no'
|
||||
37
docker/docker-compose-gpu-CN-oc9.yml
Normal file
37
docker/docker-compose-gpu-CN-oc9.yml
Normal file
@ -0,0 +1,37 @@
|
||||
include:
|
||||
- path: ./docker-compose-base.yml
|
||||
env_file: ./.env
|
||||
|
||||
services:
|
||||
ragflow:
|
||||
depends_on:
|
||||
mysql:
|
||||
condition: service_healthy
|
||||
es01:
|
||||
condition: service_healthy
|
||||
image: edwardelric233/ragflow:oc9
|
||||
container_name: ragflow-server
|
||||
ports:
|
||||
- ${SVR_HTTP_PORT}:9380
|
||||
- 80:80
|
||||
- 443:443
|
||||
volumes:
|
||||
- ./service_conf.yaml:/ragflow/conf/service_conf.yaml
|
||||
- ./ragflow-logs:/ragflow/logs
|
||||
- ./nginx/ragflow.conf:/etc/nginx/conf.d/ragflow.conf
|
||||
- ./nginx/proxy.conf:/etc/nginx/proxy.conf
|
||||
- ./nginx/nginx.conf:/etc/nginx/nginx.conf
|
||||
environment:
|
||||
- TZ=${TIMEZONE}
|
||||
- HF_ENDPOINT=https://hf-mirror.com
|
||||
- MACOS=${MACOS}
|
||||
networks:
|
||||
- ragflow
|
||||
restart: always
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: all
|
||||
capabilities: [gpu]
|
||||
37
docker/docker-compose-gpu-CN.yml
Normal file
37
docker/docker-compose-gpu-CN.yml
Normal file
@ -0,0 +1,37 @@
|
||||
include:
|
||||
- path: ./docker-compose-base.yml
|
||||
env_file: ./.env
|
||||
|
||||
services:
|
||||
ragflow:
|
||||
depends_on:
|
||||
mysql:
|
||||
condition: service_healthy
|
||||
es01:
|
||||
condition: service_healthy
|
||||
image: swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow:${RAGFLOW_VERSION}
|
||||
container_name: ragflow-server
|
||||
ports:
|
||||
- ${SVR_HTTP_PORT}:9380
|
||||
- 80:80
|
||||
- 443:443
|
||||
volumes:
|
||||
- ./service_conf.yaml:/ragflow/conf/service_conf.yaml
|
||||
- ./ragflow-logs:/ragflow/logs
|
||||
- ./nginx/ragflow.conf:/etc/nginx/conf.d/ragflow.conf
|
||||
- ./nginx/proxy.conf:/etc/nginx/proxy.conf
|
||||
- ./nginx/nginx.conf:/etc/nginx/nginx.conf
|
||||
environment:
|
||||
- TZ=${TIMEZONE}
|
||||
- HF_ENDPOINT=https://hf-mirror.com
|
||||
- MACOS=${MACOS}
|
||||
networks:
|
||||
- ragflow
|
||||
restart: always
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: all
|
||||
capabilities: [gpu]
|
||||
37
docker/docker-compose-gpu.yml
Normal file
37
docker/docker-compose-gpu.yml
Normal file
@ -0,0 +1,37 @@
|
||||
include:
|
||||
- path: ./docker-compose-base.yml
|
||||
env_file: ./.env
|
||||
|
||||
services:
|
||||
ragflow:
|
||||
depends_on:
|
||||
mysql:
|
||||
condition: service_healthy
|
||||
es01:
|
||||
condition: service_healthy
|
||||
image: infiniflow/ragflow:${RAGFLOW_VERSION}
|
||||
container_name: ragflow-server
|
||||
ports:
|
||||
- ${SVR_HTTP_PORT}:9380
|
||||
- 80:80
|
||||
- 443:443
|
||||
volumes:
|
||||
- ./service_conf.yaml:/ragflow/conf/service_conf.yaml
|
||||
- ./ragflow-logs:/ragflow/logs
|
||||
- ./nginx/ragflow.conf:/etc/nginx/conf.d/ragflow.conf
|
||||
- ./nginx/proxy.conf:/etc/nginx/proxy.conf
|
||||
- ./nginx/nginx.conf:/etc/nginx/nginx.conf
|
||||
environment:
|
||||
- TZ=${TIMEZONE}
|
||||
- HF_ENDPOINT=https://huggingface.co
|
||||
- MACOS=${MACOS}
|
||||
networks:
|
||||
- ragflow
|
||||
restart: always
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: all
|
||||
capabilities: [gpu]
|
||||
30
docker/init-kibana.sh
Executable file
30
docker/init-kibana.sh
Executable file
@ -0,0 +1,30 @@
|
||||
#!/bin/bash
|
||||
|
||||
# 等待 Elasticsearch 啟動
|
||||
until curl -u "elastic:${ELASTIC_PASSWORD}" -s http://es01:9200 >/dev/null; do
|
||||
echo "等待 Elasticsearch 啟動..."
|
||||
sleep 5
|
||||
done
|
||||
|
||||
|
||||
echo "使用者: elastic:${ELASTIC_PASSWORD}"
|
||||
|
||||
|
||||
|
||||
PAYLOAD="{
|
||||
\"password\" : \"${KIBANA_PASSWORD}\",
|
||||
\"roles\" : [ \"kibana_admin\",\"kibana_system\" ],
|
||||
\"full_name\" : \"${KIBANA_USER}\",
|
||||
\"email\" : \"${KIBANA_USER}@example.com\"
|
||||
}"
|
||||
echo "新用戶帳戶: $PAYLOAD"
|
||||
|
||||
# 創建新用戶帳戶
|
||||
curl -X POST "http://es01:9200/_security/user/${KIBANA_USER}" \
|
||||
-u "elastic:${ELASTIC_PASSWORD}" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "$PAYLOAD"s
|
||||
|
||||
echo "新用戶帳戶已創建"
|
||||
|
||||
exit 0
|
||||
@ -107,6 +107,10 @@ RAGFlow features visibility and explainability, allowing you to view the chunkin
|
||||
|
||||

|
||||
|
||||
:::caution NOTE
|
||||
You can add keywords to a file chunk to increase its relevance. This action increases its keyword weight and can improve its position in search list.
|
||||
:::
|
||||
|
||||
4. In Retrieval testing, ask a quick question in **Test text** to double check if your configurations work:
|
||||
|
||||
_As you can tell from the following, RAGFlow responds with truthful citations._
|
||||
@ -124,7 +128,7 @@ RAGFlow uses multiple recall of both full-text search and vector search in its c
|
||||
|
||||
## Search for knowledge base
|
||||
|
||||
As of RAGFlow v0.8.0, the search feature is still in a rudimentary form, supporting only knowledge base search by name.
|
||||
As of RAGFlow v0.9.0, the search feature is still in a rudimentary form, supporting only knowledge base search by name.
|
||||
|
||||

|
||||
|
||||
|
||||
@ -4,6 +4,8 @@ slug: /deploy_local_llm
|
||||
---
|
||||
|
||||
# Deploy a local LLM
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
RAGFlow supports deploying models locally using Ollama or Xinference. If you have locally deployed models to leverage or wish to enable GPU or CUDA for inference acceleration, you can bind Ollama or Xinference into RAGFlow and use either of them as a local "server" for interacting with your local models.
|
||||
|
||||
@ -108,7 +110,7 @@ Update your chat model accordingly in **Chat Configuration**:
|
||||
|
||||
## Deploy a local model using Xinference
|
||||
|
||||
Xorbits Inference([Xinference](https://github.com/xorbitsai/inference)) enables you to unleash the full potential of cutting-edge AI models.
|
||||
Xorbits Inference ([Xinference](https://github.com/xorbitsai/inference)) enables you to unleash the full potential of cutting-edge AI models.
|
||||
|
||||
:::note
|
||||
- For information about installing Xinference Ollama, see [here](https://inference.readthedocs.io/en/latest/getting_started/).
|
||||
@ -129,8 +131,8 @@ $ xinference-local --host 0.0.0.0 --port 9997
|
||||
|
||||
### 3. Launch your local model
|
||||
|
||||
Launch your local model (**Mistral**), ensuring that you replace `${quantization}` with your chosen quantization method
|
||||
:
|
||||
Launch your local model (**Mistral**), ensuring that you replace `${quantization}` with your chosen quantization method:
|
||||
|
||||
```bash
|
||||
$ xinference launch -u mistral --model-name mistral-v0.1 --size-in-billions 7 --model-format pytorch --quantization ${quantization}
|
||||
```
|
||||
@ -143,6 +145,7 @@ In RAGFlow, click on your logo on the top right of the page **>** **Model Provid
|
||||
### 5. Complete basic Xinference settings
|
||||
|
||||
Enter an accessible base URL, such as `http://<your-xinference-endpoint-domain>:9997/v1`.
|
||||
> For rerank model, please use the `http://<your-xinference-endpoint-domain>:9997/v1/rerank` as the base URL.
|
||||
|
||||
### 6. Update System Model Settings
|
||||
|
||||
@ -160,9 +163,9 @@ Update your chat model accordingly in **Chat Configuration**:
|
||||
|
||||
## Deploy a local model using IPEX-LLM
|
||||
|
||||
IPEX-LLM([IPEX-LLM](https://github.com/intel-analytics/ipex-llm)) is a PyTorch library for running LLM on Intel CPU and GPU (e.g., local PC with iGPU, discrete GPU such as Arc, Flex and Max) with very low latency
|
||||
[IPEX-LLM](https://github.com/intel-analytics/ipex-llm) is a PyTorch library for running LLMs on local Intel CPUs or GPUs (including iGPU or discrete GPUs like Arc, Flex, and Max) with low latency. It supports Ollama on Linux and Windows systems.
|
||||
|
||||
To deploy a local model, eg., **Qwen2**, using IPEX-LLM, follow the steps below:
|
||||
To deploy a local model, e.g., **Qwen2**, using IPEX-LLM-accelerated Ollama:
|
||||
|
||||
### 1. Check firewall settings
|
||||
|
||||
@ -172,46 +175,69 @@ Ensure that your host machine's firewall allows inbound connections on port 1143
|
||||
sudo ufw allow 11434/tcp
|
||||
```
|
||||
|
||||
### 2. Install and Start Ollama serve using IPEX-LLM
|
||||
### 2. Launch Ollama service using IPEX-LLM
|
||||
|
||||
#### 2.1 Install IPEX-LLM for Ollama
|
||||
|
||||
IPEX-LLM's support for `ollama` now is available for Linux system and Windows system.
|
||||
:::tip NOTE
|
||||
IPEX-LLM's supports Ollama on Linux and Windows systems.
|
||||
:::
|
||||
|
||||
Visit [Run llama.cpp with IPEX-LLM on Intel GPU Guide](https://github.com/intel-analytics/ipex-llm/blob/main/docs/mddocs/Quickstart/llama_cpp_quickstart.md), and follow the instructions in section [Prerequisites](https://github.com/intel-analytics/ipex-llm/blob/main/docs/mddocs/Quickstart/llama_cpp_quickstart.md#0-prerequisites) to setup and section [Install IPEX-LLM cpp](https://github.com/intel-analytics/ipex-llm/blob/main/docs/mddocs/Quickstart/llama_cpp_quickstart.md#1-install-ipex-llm-for-llamacpp) to install the IPEX-LLM with Ollama binaries.
|
||||
For detailed information about installing IPEX-LLM for Ollama, see [Run llama.cpp with IPEX-LLM on Intel GPU Guide](https://github.com/intel-analytics/ipex-llm/blob/main/docs/mddocs/Quickstart/llama_cpp_quickstart.md):
|
||||
- [Prerequisites](https://github.com/intel-analytics/ipex-llm/blob/main/docs/mddocs/Quickstart/llama_cpp_quickstart.md#0-prerequisites)
|
||||
- [Install IPEX-LLM cpp with Ollama binaries](https://github.com/intel-analytics/ipex-llm/blob/main/docs/mddocs/Quickstart/llama_cpp_quickstart.md#1-install-ipex-llm-for-llamacpp)
|
||||
|
||||
**After the installation, you should have created a conda environment, named `llm-cpp` for instance, for running `ollama` commands with IPEX-LLM.**
|
||||
*After the installation, you should have created a Conda environment, e.g., `llm-cpp`, for running Ollama commands with IPEX-LLM.*
|
||||
|
||||
#### 2.2 Initialize Ollama
|
||||
|
||||
Activate the `llm-cpp` conda environment and initialize Ollama by executing the commands below. A symbolic link to `ollama` will appear in your current directory.
|
||||
1. Activate the `llm-cpp` Conda environment and initialize Ollama:
|
||||
|
||||
- For **Linux users**:
|
||||
<Tabs
|
||||
defaultValue="linux"
|
||||
values={[
|
||||
{label: 'Linux', value: 'linux'},
|
||||
{label: 'Windows', value: 'windows'},
|
||||
]}>
|
||||
<TabItem value="linux">
|
||||
|
||||
```bash
|
||||
conda activate llm-cpp
|
||||
init-ollama
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="windows">
|
||||
|
||||
- For **Windows users**:
|
||||
|
||||
Please run the following command with **administrator privilege in Miniforge Prompt**.
|
||||
Run these commands with *administrator privileges in Miniforge Prompt*:
|
||||
|
||||
```cmd
|
||||
conda activate llm-cpp
|
||||
init-ollama.bat
|
||||
```
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
> [!NOTE]
|
||||
> If you have installed higher version `ipex-llm[cpp]` and want to upgrade your ollama binary file, don't forget to remove old binary files first and initialize again with `init-ollama` or `init-ollama.bat`.
|
||||
2. If the installed `ipex-llm[cpp]` requires an upgrade to the Ollama binary files, remove the old binary files and reinitialize Ollama using `init-ollama` (Linux) or `init-ollama.bat` (Windows).
|
||||
|
||||
*A symbolic link to Ollama appears in your current directory, and you can use this executable file following standard Ollama commands.*
|
||||
|
||||
**Now you can use this executable file by standard ollama's usage.**
|
||||
#### 2.3 Launch Ollama service
|
||||
|
||||
#### 2.3 Run Ollama Serve
|
||||
1. Set the environment variable `OLLAMA_NUM_GPU` to `999` to ensure that all layers of your model run on the Intel GPU; otherwise, some layers may default to CPU.
|
||||
2. For optimal performance on Intel Arc™ A-Series Graphics with Linux OS (Kernel 6.2), set the following environment variable before launching the Ollama service:
|
||||
|
||||
You may launch the Ollama service as below:
|
||||
```bash
|
||||
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
|
||||
```
|
||||
3. Launch the Ollama service:
|
||||
|
||||
- For **Linux users**:
|
||||
<Tabs
|
||||
defaultValue="linux"
|
||||
values={[
|
||||
{label: 'Linux', value: 'linux'},
|
||||
{label: 'Windows', value: 'windows'},
|
||||
]}>
|
||||
<TabItem value="linux">
|
||||
|
||||
```bash
|
||||
export OLLAMA_NUM_GPU=999
|
||||
@ -223,9 +249,10 @@ You may launch the Ollama service as below:
|
||||
./ollama serve
|
||||
```
|
||||
|
||||
- For **Windows users**:
|
||||
</TabItem>
|
||||
<TabItem value="windows">
|
||||
|
||||
Please run the following command in Miniforge Prompt.
|
||||
Run the following command *in Miniforge Prompt*:
|
||||
|
||||
```cmd
|
||||
set OLLAMA_NUM_GPU=999
|
||||
@ -235,49 +262,54 @@ You may launch the Ollama service as below:
|
||||
|
||||
ollama serve
|
||||
```
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
||||
> Please set environment variable `OLLAMA_NUM_GPU` to `999` to make sure all layers of your model are running on Intel GPU, otherwise, some layers may run on CPU.
|
||||
:::tip NOTE
|
||||
To enable the Ollama service to accept connections from all IP addresses, use `OLLAMA_HOST=0.0.0.0 ./ollama serve` rather than simply `./ollama serve`.
|
||||
:::
|
||||
|
||||
|
||||
> If your local LLM is running on Intel Arc™ A-Series Graphics with Linux OS (Kernel 6.2), it is recommended to additionaly set the following environment variable for optimal performance before executing `ollama serve`:
|
||||
>
|
||||
> ```bash
|
||||
> export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
|
||||
> ```
|
||||
|
||||
|
||||
> To allow the service to accept connections from all IP addresses, use `OLLAMA_HOST=0.0.0.0 ./ollama serve` instead of just `./ollama serve`.
|
||||
|
||||
The console will display messages similar to the following:
|
||||
*The console displays messages similar to the following:*
|
||||
|
||||

|
||||
|
||||
### 3. Pull and Run Ollama Model
|
||||
### 3. Pull and Run Ollama model
|
||||
|
||||
Keep the Ollama service on and open another terminal and run `./ollama pull <model_name>` in Linux (`ollama.exe pull <model_name>` in Windows) to automatically pull a model. e.g. `qwen2:latest`:
|
||||
#### 3.1 Pull Ollama model
|
||||
|
||||
With the Ollama service running, open a new terminal and run `./ollama pull <model_name>` (Linux) or `ollama.exe pull <model_name>` (Windows) to pull the desired model. e.g., `qwen2:latest`:
|
||||
|
||||

|
||||
|
||||
#### Run Ollama Model
|
||||
#### 3.2 Run Ollama model
|
||||
|
||||
<Tabs
|
||||
defaultValue="linux"
|
||||
values={[
|
||||
{label: 'Linux', value: 'linux'},
|
||||
{label: 'Windows', value: 'windows'},
|
||||
]}>
|
||||
<TabItem value="linux">
|
||||
|
||||
- For **Linux users**:
|
||||
```bash
|
||||
./ollama run qwen2:latest
|
||||
```
|
||||
|
||||
- For **Windows users**:
|
||||
</TabItem>
|
||||
<TabItem value="windows">
|
||||
|
||||
```cmd
|
||||
ollama run qwen2:latest
|
||||
```
|
||||
### 4. Configure RAGflow to use IPEX-LLM accelerated Ollama
|
||||
|
||||
The confiugraiton follows the steps in
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
Ollama Section 4 [Add Ollama](#4-add-ollama),
|
||||
### 4. Configure RAGflow
|
||||
|
||||
Section 5 [Complete basic Ollama settings](#5-complete-basic-ollama-settings),
|
||||
To enable IPEX-LLM accelerated Ollama in RAGFlow, you must also complete the configurations in RAGFlow. The steps are identical to those outlined in the *Deploy a local model using Ollama* section:
|
||||
|
||||
Section 6 [Update System Model Settings](#6-update-system-model-settings),
|
||||
|
||||
Section 7 [Update Chat Configuration](#7-update-chat-configuration)
|
||||
1. [Add Ollama](#4-add-ollama)
|
||||
2. [Complete basic Ollama settings](#5-complete-basic-ollama-settings)
|
||||
3. [Update System Model Settings](#6-update-system-model-settings)
|
||||
4. [Update Chat Configuration](#7-update-chat-configuration)
|
||||
@ -11,16 +11,25 @@ An API key is required for RAGFlow to interact with an online AI model. This gui
|
||||
|
||||
For now, RAGFlow supports the following online LLMs. Click the corresponding link to apply for your API key. Most LLM providers grant newly-created accounts trial credit, which will expire in a couple of months, or a promotional amount of free quota.
|
||||
|
||||
- [OpenAI](https://platform.openai.com/login?launch),
|
||||
- [Tongyi-Qianwen](https://dashscope.console.aliyun.com/model),
|
||||
- [ZHIPU-AI](https://open.bigmodel.cn/),
|
||||
- [Moonshot](https://platform.moonshot.cn/docs),
|
||||
- [DeepSeek](https://platform.deepseek.com/api-docs/),
|
||||
- [Baichuan](https://www.baichuan-ai.com/home),
|
||||
- [VolcEngine](https://www.volcengine.com/docs/82379).
|
||||
- [OpenAI](https://platform.openai.com/login?launch)
|
||||
- [Azure-OpenAI](https://ai.azure.com/)
|
||||
- [Gemini](https://aistudio.google.com/)
|
||||
- [Groq](https://console.groq.com/)
|
||||
- [Mistral](https://mistral.ai/)
|
||||
- [Bedrock](https://aws.amazon.com/cn/bedrock/)
|
||||
- [Tongyi-Qianwen](https://dashscope.console.aliyun.com/model)
|
||||
- [ZHIPU-AI](https://open.bigmodel.cn/)
|
||||
- [MiniMax](https://platform.minimaxi.com/)
|
||||
- [Moonshot](https://platform.moonshot.cn/docs)
|
||||
- [DeepSeek](https://platform.deepseek.com/api-docs/)
|
||||
- [Baichuan](https://www.baichuan-ai.com/home)
|
||||
- [VolcEngine](https://www.volcengine.com/docs/82379)
|
||||
- [Jina](https://jina.ai/reader/)
|
||||
- [OpenRouter](https://openrouter.ai/)
|
||||
- [StepFun](https://platform.stepfun.com/)
|
||||
|
||||
:::note
|
||||
If you find your online LLM is not on the list, don't feel disheartened. The list is expanding, and you can [file a feature request](https://github.com/infiniflow/ragflow/issues/new?assignees=&labels=feature+request&projects=&template=feature_request.yml&title=%5BFeature+Request%5D%3A+) with us! Alternatively, if you have customized or locally-deployed models, you can [bind them to RAGFlow using Ollama or Xinference](./deploy_local_llm.md).
|
||||
If you find your online LLM is not on the list, don't feel disheartened. The list is expanding, and you can [file a feature request](https://github.com/infiniflow/ragflow/issues/new?assignees=&labels=feature+request&projects=&template=feature_request.yml&title=%5BFeature+Request%5D%3A+) with us! Alternatively, if you have customized or locally-deployed models, you can [bind them to RAGFlow using Ollama, Xinferenc, or LocalAI](./deploy_local_llm.md).
|
||||
:::
|
||||
|
||||
## Configure your API key
|
||||
|
||||
@ -43,13 +43,13 @@ You can link your file to one knowledge base or multiple knowledge bases at one
|
||||
|
||||

|
||||
|
||||
## Move file to specified folder
|
||||
## Move file to a specific folder
|
||||
|
||||
As of RAGFlow v0.8.0, this feature is *not* available.
|
||||
As of RAGFlow v0.9.0, this feature is *not* available.
|
||||
|
||||
## Search files or folders
|
||||
|
||||
As of RAGFlow v0.8.0, the search feature is still in a rudimentary form, supporting only file and folder search in the current directory by name (files or folders in the child directory will not be retrieved).
|
||||
As of RAGFlow v0.9.0, the search feature is still in a rudimentary form, supporting only file and folder search in the current directory by name (files or folders in the child directory will not be retrieved).
|
||||
|
||||

|
||||
|
||||
@ -81,4 +81,4 @@ RAGFlow's file management allows you to download an uploaded file:
|
||||
|
||||

|
||||
|
||||
> As of RAGFlow v0.8.0, bulk download is not supported, nor can you download an entire folder.
|
||||
> As of RAGFlow v0.9.0, bulk download is not supported, nor can you download an entire folder.
|
||||
|
||||
@ -34,7 +34,7 @@ This section provides instructions on setting up the RAGFlow server on Linux. If
|
||||
|
||||
`vm.max_map_count`. This value sets the maximum number of memory map areas a process may have. Its default value is 65530. While most applications require fewer than a thousand maps, reducing this value can result in abmornal behaviors, and the system will throw out-of-memory errors when a process reaches the limitation.
|
||||
|
||||
RAGFlow v0.8.0 uses Elasticsearch for multiple recall. Setting the value of `vm.max_map_count` correctly is crucial to the proper functioning of the Elasticsearch component.
|
||||
RAGFlow v0.9.0 uses Elasticsearch for multiple recall. Setting the value of `vm.max_map_count` correctly is crucial to the proper functioning of the Elasticsearch component.
|
||||
|
||||
<Tabs
|
||||
defaultValue="linux"
|
||||
@ -132,7 +132,7 @@ This section provides instructions on setting up the RAGFlow server on Linux. If
|
||||
|
||||
3. Build the pre-built Docker images and start up the server:
|
||||
|
||||
> Running the following commands automatically downloads the *dev* version RAGFlow Docker image. To download and run a specified Docker version, update `RAGFLOW_VERSION` in **docker/.env** to the intended version, for example `RAGFLOW_VERSION=v0.8.0`, before running the following commands.
|
||||
> Running the following commands automatically downloads the *dev* version RAGFlow Docker image. To download and run a specified Docker version, update `RAGFLOW_VERSION` in **docker/.env** to the intended version, for example `RAGFLOW_VERSION=v0.9.0`, before running the following commands.
|
||||
|
||||
```bash
|
||||
$ cd ragflow/docker
|
||||
@ -176,15 +176,26 @@ With the default settings, you only need to enter `http://IP_OF_YOUR_MACHINE` (*
|
||||
|
||||
RAGFlow is a RAG engine, and it needs to work with an LLM to offer grounded, hallucination-free question-answering capabilities. For now, RAGFlow supports the following LLMs, and the list is expanding:
|
||||
|
||||
- OpenAI
|
||||
- Tongyi-Qianwen
|
||||
- ZHIPU-AI
|
||||
- Moonshot
|
||||
- DeepSeek-V2
|
||||
- Baichuan
|
||||
- VolcEngine
|
||||
- [OpenAI](https://platform.openai.com/login?launch)
|
||||
- [Azure-OpenAI](https://ai.azure.com/)
|
||||
- [Gemini](https://aistudio.google.com/)
|
||||
- [Groq](https://console.groq.com/)
|
||||
- [Mistral](https://mistral.ai/)
|
||||
- [Bedrock](https://aws.amazon.com/cn/bedrock/)
|
||||
- [Tongyi-Qianwen](https://dashscope.console.aliyun.com/model)
|
||||
- [ZHIPU-AI](https://open.bigmodel.cn/)
|
||||
- [MiniMax](https://platform.minimaxi.com/)
|
||||
- [Moonshot](https://platform.moonshot.cn/docs)
|
||||
- [DeepSeek](https://platform.deepseek.com/api-docs/)
|
||||
- [Baichuan](https://www.baichuan-ai.com/home)
|
||||
- [VolcEngine](https://www.volcengine.com/docs/82379)
|
||||
- [Jina](https://jina.ai/reader/)
|
||||
- [OpenRouter](https://openrouter.ai/)
|
||||
- [StepFun](https://platform.stepfun.com/)
|
||||
|
||||
> RAGFlow also supports deploying LLMs locally using Ollama or Xinference, but this part is not covered in this quick start guide.
|
||||
:::note
|
||||
RAGFlow also supports deploying LLMs locally using Ollama, Xinference, or LocalAI, but this part is not covered in this quick start guide.
|
||||
:::
|
||||
|
||||
To add and configure an LLM:
|
||||
|
||||
@ -192,7 +203,7 @@ To add and configure an LLM:
|
||||
|
||||

|
||||
|
||||
> Each RAGFlow account is able to use **text-embedding-v2** for free, a embedding model of Tongyi-Qianwen. This is why you can see Tongyi-Qianwen in the **Added models** list. And you may need to update your Tongyi-Qianwen API key at a later point.
|
||||
> Each RAGFlow account is able to use **text-embedding-v2** for free, an embedding model of Tongyi-Qianwen. This is why you can see Tongyi-Qianwen in the **Added models** list. And you may need to update your Tongyi-Qianwen API key at a later point.
|
||||
|
||||
2. Click on the desired LLM and update the API key accordingly (DeepSeek-V2 in this case):
|
||||
|
||||
@ -228,7 +239,9 @@ To create your first knowledge base:
|
||||
|
||||
3. RAGFlow offers multiple chunk templates that cater to different document layouts and file formats. Select the embedding model and chunk method (template) for your knowledge base.
|
||||
|
||||
> IMPORTANT: Once you have selected an embedding model and used it to parse a file, you are no longer allowed to change it. The obvious reason is that we must ensure that all files in a specific knowledge base are parsed using the *same* embedding model (ensure that they are being compared in the same embedding space).
|
||||
:::danger IMPORTANT
|
||||
Once you have selected an embedding model and used it to parse a file, you are no longer allowed to change it. The obvious reason is that we must ensure that all files in a specific knowledge base are parsed using the *same* embedding model (ensure that they are being compared in the same embedding space).
|
||||
:::
|
||||
|
||||
_You are taken to the **Dataset** page of your knowledge base._
|
||||
|
||||
@ -240,6 +253,11 @@ To create your first knowledge base:
|
||||
|
||||
_When the file parsing completes, its parsing status changes to **SUCCESS**._
|
||||
|
||||
:::alert NOTE
|
||||
- If your file parsing gets stuck at below 1%, see [FAQ 4.3](https://ragflow.io/docs/dev/faq#43-why-does-my-document-parsing-stall-at-under-one-percent).
|
||||
- If your file parsing gets stuck at near completion, see [FAQ 4.4](https://ragflow.io/docs/dev/faq#44-why-does-my-pdf-parsing-stall-near-completion-while-the-log-does-not-show-any-error)
|
||||
:::
|
||||
|
||||
## Intervene with file parsing
|
||||
|
||||
RAGFlow features visibility and explainability, allowing you to view the chunking results and intervene where necessary. To do so:
|
||||
@ -256,6 +274,10 @@ RAGFlow features visibility and explainability, allowing you to view the chunkin
|
||||
|
||||

|
||||
|
||||
:::caution NOTE
|
||||
You can add keywords to a file chunk to increase its relevance. This action increases its keyword weight and can improve its position in search list.
|
||||
:::
|
||||
|
||||
4. In Retrieval testing, ask a quick question in **Test text** to double check if your configurations work:
|
||||
|
||||
_As you can tell from the following, RAGFlow responds with truthful citations._
|
||||
|
||||
@ -198,7 +198,7 @@ Ignore this warning and continue. All system warnings can be ignored.
|
||||
|
||||

|
||||
|
||||
If your RAGFlow is deployed *locally*, try the following:
|
||||
Click the red cross beside the 'parsing status' bar, then restart the parsing process to see if the issue remains. If the issue persists and your RAGFlow is deployed locally, try the following:
|
||||
|
||||
1. Check the log of your RAGFlow server to see if it is running properly:
|
||||
```bash
|
||||
@ -209,15 +209,17 @@ docker logs -f ragflow-server
|
||||
|
||||
#### 4.4 Why does my pdf parsing stall near completion, while the log does not show any error?
|
||||
|
||||
If your RAGFlow is deployed *locally*, the parsing process is likely killed due to insufficient RAM. Try increasing your memory allocation by increasing the `MEM_LIMIT` value in **docker/.env**.
|
||||
Click the red cross beside the 'parsing status' bar, then restart the parsing process to see if the issue remains. If the issue persists and your RAGFlow is deployed locally, the parsing process is likely killed due to insufficient RAM. Try increasing your memory allocation by increasing the `MEM_LIMIT` value in **docker/.env**.
|
||||
|
||||
> Ensure that you restart up your RAGFlow server for your changes to take effect!
|
||||
> ```bash
|
||||
> docker compose stop
|
||||
> ```
|
||||
> ```bash
|
||||
> docker compose up -d
|
||||
> ```
|
||||
:::note
|
||||
Ensure that you restart up your RAGFlow server for your changes to take effect!
|
||||
```bash
|
||||
docker compose stop
|
||||
```
|
||||
```bash
|
||||
docker compose up -d
|
||||
```
|
||||
:::
|
||||
|
||||

|
||||
|
||||
|
||||
@ -8,7 +8,7 @@ RAGFlow offers RESTful APIs for you to integrate its capabilities into third-par
|
||||
|
||||
## Base URL
|
||||
```
|
||||
http://<host_address>/api/v1/
|
||||
http://<host_address>/v1/api/
|
||||
```
|
||||
|
||||
## Dataset URL
|
||||
@ -428,7 +428,7 @@ This method deletes documents for a specific user.
|
||||
|
||||
## List documents
|
||||
|
||||
This method deletes documents for a specific user.
|
||||
This method lists documents for a specific user.
|
||||
|
||||
### Request
|
||||
|
||||
@ -532,4 +532,350 @@ This method deletes documents for a specific user.
|
||||
"message": "IndexError('Offset is out of the valid range.')"
|
||||
}
|
||||
```
|
||||
## Update the details of the document
|
||||
|
||||
This method updates the details, including the name, enable and template type of a specific document for a specific user.
|
||||
|
||||
### Request
|
||||
|
||||
#### Request URI
|
||||
|
||||
| Method | Request URI |
|
||||
|--------|-------------------------------------------------|
|
||||
| PUT | `/dataset/{dataset_id}/documents/{document_id}` |
|
||||
|
||||
|
||||
#### Request parameter
|
||||
|
||||
| Name | Type | Required | Description |
|
||||
|--------------|--------|----------|------------------------------------------------------------------------------------------------------------|
|
||||
| `dataset_id` | string | Yes | The ID of the dataset. Call ['GET' /dataset](#create-dataset) to retrieve the ID. |
|
||||
| `document_id` | string | Yes | The ID of the document. Call ['GET' /document](#list-documents) to retrieve the ID. |
|
||||
|
||||
### Response
|
||||
|
||||
### Successful Response
|
||||
|
||||
```json
|
||||
{
|
||||
"code": 0,
|
||||
"data": {
|
||||
"chunk_num": 0,
|
||||
"create_date": "Mon, 15 Jul 2024 16:55:03 GMT",
|
||||
"create_time": 1721033703914,
|
||||
"created_by": "b48110a0286411ef994a3043d7ee537e",
|
||||
"id": "ed30167a428711efab193043d7ee537e",
|
||||
"kb_id": "ed2d8770428711efaf583043d7ee537e",
|
||||
"location": "test.txt",
|
||||
"name": "new_name.txt",
|
||||
"parser_config": {
|
||||
"pages": [
|
||||
[1, 1000000]
|
||||
]
|
||||
},
|
||||
"parser_id": "naive",
|
||||
"process_begin_at": null,
|
||||
"process_duration": 0.0,
|
||||
"progress": 0.0,
|
||||
"progress_msg": "",
|
||||
"run": "0",
|
||||
"size": 14,
|
||||
"source_type": "local",
|
||||
"status": "1",
|
||||
"thumbnail": null,
|
||||
"token_num": 0,
|
||||
"type": "doc",
|
||||
"update_date": "Mon, 15 Jul 2024 16:55:03 GMT",
|
||||
"update_time": 1721033703934
|
||||
},
|
||||
"message": "Success"
|
||||
}
|
||||
```
|
||||
|
||||
### Response for updating a document which does not exist.
|
||||
|
||||
```json
|
||||
{
|
||||
"code": 101,
|
||||
"message": "This document weird_doc_id cannot be found!"
|
||||
}
|
||||
```
|
||||
|
||||
### Response for updating a document without giving parameters.
|
||||
```json
|
||||
{
|
||||
"code": 102,
|
||||
"message": "Please input at least one parameter that you want to update!"
|
||||
}
|
||||
```
|
||||
|
||||
### Response for updating a document in the nonexistent dataset.
|
||||
```json
|
||||
{
|
||||
"code": 102,
|
||||
"message": "This dataset fake_dataset_id cannot be found!"
|
||||
}
|
||||
```
|
||||
|
||||
### Response for updating a document with an extension name that differs from its original.
|
||||
```json
|
||||
{
|
||||
"code": 101,
|
||||
"data": false,
|
||||
"message": "The extension of file cannot be changed"
|
||||
}
|
||||
```
|
||||
|
||||
### Response for updating a document with a duplicate name.
|
||||
```json
|
||||
{
|
||||
"code": 101,
|
||||
"message": "Duplicated document name in the same dataset."
|
||||
}
|
||||
```
|
||||
|
||||
### Response for updating a document's illegal parameter.
|
||||
```json
|
||||
{
|
||||
"code": 101,
|
||||
"message": "illegal_parameter is an illegal parameter."
|
||||
}
|
||||
```
|
||||
|
||||
### Response for updating a document's name without its name value.
|
||||
```json
|
||||
{
|
||||
"code": 102,
|
||||
"message": "There is no new name."
|
||||
}
|
||||
```
|
||||
|
||||
### Response for updating a document's with giving illegal enable's value.
|
||||
```json
|
||||
{
|
||||
"code": 102,
|
||||
"message": "Illegal value '?' for 'enable' field."
|
||||
}
|
||||
```
|
||||
|
||||
## Download the document
|
||||
|
||||
This method downloads a specific document for a specific user.
|
||||
|
||||
### Request
|
||||
|
||||
#### Request URI
|
||||
|
||||
| Method | Request URI |
|
||||
|--------|-------------------------------------------------|
|
||||
| GET | `/dataset/{dataset_id}/documents/{document_id}` |
|
||||
|
||||
|
||||
#### Request parameter
|
||||
|
||||
| Name | Type | Required | Description |
|
||||
|--------------|--------|----------|------------------------------------------------------------------------------------------------------------|
|
||||
| `dataset_id` | string | Yes | The ID of the dataset. Call ['GET' /dataset](#create-dataset) to retrieve the ID. |
|
||||
| `document_id` | string | Yes | The ID of the document. Call ['GET' /document](#list-documents) to retrieve the ID. |
|
||||
|
||||
### Response
|
||||
|
||||
### Successful Response
|
||||
|
||||
```json
|
||||
{
|
||||
"code": "0",
|
||||
"data": "b'test\\ntest\\ntest'"
|
||||
}
|
||||
```
|
||||
|
||||
### Response for downloading a document which does not exist.
|
||||
|
||||
```json
|
||||
{
|
||||
"code": 101,
|
||||
"message": "This document 'imagination.txt' cannot be found!"
|
||||
}
|
||||
```
|
||||
|
||||
### Response for downloading a document in the nonexistent dataset.
|
||||
```json
|
||||
{
|
||||
"code": 102,
|
||||
"message": "This dataset 'imagination' cannot be found!"
|
||||
}
|
||||
```
|
||||
|
||||
### Response for downloading an empty document.
|
||||
```json
|
||||
{
|
||||
"code": 102,
|
||||
"message": "This file is empty."
|
||||
}
|
||||
```
|
||||
|
||||
## Start parsing a document
|
||||
|
||||
This method enables a specific document to start parsing for a specific user.
|
||||
|
||||
### Request
|
||||
|
||||
#### Request URI
|
||||
|
||||
| Method | Request URI |
|
||||
|--------|--------------------------------------------------------|
|
||||
| POST | `/dataset/{dataset_id}/documents/{document_id}/status` |
|
||||
|
||||
|
||||
#### Request parameter
|
||||
|
||||
| Name | Type | Required | Description |
|
||||
|--------------|--------|----------|------------------------------------------------------------------------------------------------------------|
|
||||
| `dataset_id` | string | Yes | The ID of the dataset. Call ['GET' /dataset](#create-dataset) to retrieve the ID. |
|
||||
| `document_id` | string | Yes | The ID of the document. Call ['GET' /document](#list-documents) to retrieve the ID. |
|
||||
|
||||
### Response
|
||||
|
||||
### Successful Response
|
||||
|
||||
```json
|
||||
{
|
||||
"code": 0,
|
||||
"message": ""
|
||||
}
|
||||
```
|
||||
|
||||
### Response for parsing a document which does not exist.
|
||||
|
||||
```json
|
||||
{
|
||||
"code": 101,
|
||||
"message": "This document 'imagination.txt' cannot be found!"
|
||||
}
|
||||
```
|
||||
|
||||
### Response for parsing a document in the nonexistent dataset.
|
||||
```json
|
||||
{
|
||||
"code": 102,
|
||||
"message": "This dataset 'imagination' cannot be found!"
|
||||
}
|
||||
```
|
||||
|
||||
### Response for parsing an empty document.
|
||||
```json
|
||||
{
|
||||
"code": 0,
|
||||
"message": "Empty data in the document: empty.txt;"
|
||||
}
|
||||
```
|
||||
|
||||
## Start parsing multiple documents
|
||||
|
||||
This method enables multiple documents, including all documents in the specific dataset or specified documents, to start parsing for a specific user.
|
||||
|
||||
### Request
|
||||
|
||||
#### Request URI
|
||||
|
||||
| Method | Request URI |
|
||||
|--------|-------------------------------------------------------|
|
||||
| POST | `/dataset/{dataset_id}/documents/status` |
|
||||
|
||||
|
||||
#### Request parameter
|
||||
|
||||
| Name | Type | Required | Description |
|
||||
|--------------|--------|----------|-----------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `dataset_id` | string | Yes | The ID of the dataset. Call ['GET' /dataset](#create-dataset) to retrieve the ID. |
|
||||
| `document_id` | string | Yes | The ID of the document. Call ['GET' /document](#list-documents) to retrieve the ID. |
|
||||
| `doc_ids` | list | No | The document IDs of the documents that the user would like to parse. Default: None, means all documents in the specified dataset. |
|
||||
### Response
|
||||
|
||||
### Successful Response
|
||||
|
||||
```json
|
||||
{
|
||||
"code": 0,
|
||||
"data": true,
|
||||
"message": ""
|
||||
}
|
||||
```
|
||||
|
||||
### Response for parsing documents which does not exist.
|
||||
|
||||
```json
|
||||
{
|
||||
"code": 101,
|
||||
"message": "This document 'imagination.txt' cannot be found!"
|
||||
}
|
||||
```
|
||||
|
||||
### Response for parsing documents in the nonexistent dataset.
|
||||
```json
|
||||
{
|
||||
"code": 102,
|
||||
"message": "This dataset 'imagination' cannot be found!"
|
||||
}
|
||||
```
|
||||
|
||||
### Response for parsing documents, one of which is empty.
|
||||
```json
|
||||
{
|
||||
"code": 0,
|
||||
"data": true,
|
||||
"message": "Empty data in the document: empty.txt; "
|
||||
}
|
||||
```
|
||||
|
||||
## Show the parsing status of the document
|
||||
|
||||
This method shows the parsing status of the document for a specific user.
|
||||
|
||||
### Request
|
||||
|
||||
#### Request URI
|
||||
|
||||
| Method | Request URI |
|
||||
|--------|-------------------------------------------------------|
|
||||
| GET | `/dataset/{dataset_id}/documents/status` |
|
||||
|
||||
|
||||
#### Request parameter
|
||||
|
||||
| Name | Type | Required | Description |
|
||||
|--------------|--------|----------|-----------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `dataset_id` | string | Yes | The ID of the dataset. Call ['GET' /dataset](#create-dataset) to retrieve the ID. |
|
||||
| `document_id` | string | Yes | The ID of the document. Call ['GET' /document](#list-documents) to retrieve the ID. |
|
||||
|
||||
### Response
|
||||
|
||||
### Successful Response
|
||||
|
||||
```json
|
||||
{
|
||||
"code": 0,
|
||||
"data": {
|
||||
"progress": 0.0,
|
||||
"status": "RUNNING"
|
||||
},
|
||||
"message": "success"
|
||||
}
|
||||
```
|
||||
|
||||
### Response for showing the parsing status of a document which does not exist.
|
||||
|
||||
```json
|
||||
{
|
||||
"code": 102,
|
||||
"message": "This document: 'imagination.txt' is not a valid document."
|
||||
}
|
||||
```
|
||||
|
||||
### Response for showing the parsing status of a document in the nonexistent dataset.
|
||||
```json
|
||||
{
|
||||
"code": 102,
|
||||
"message": "This dataset 'imagination' cannot be found!"
|
||||
}
|
||||
```
|
||||
|
||||
@ -1,62 +0,0 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import random
|
||||
from abc import ABC
|
||||
from functools import partial
|
||||
import pandas as pd
|
||||
import requests
|
||||
import re
|
||||
|
||||
from graph.component.base import ComponentBase, ComponentParamBase
|
||||
|
||||
|
||||
class BaiduParam(ComponentParamBase):
|
||||
"""
|
||||
Define the Baidu component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.top_n = 10
|
||||
|
||||
def check(self):
|
||||
self.check_positive_integer(self.top_n, "Top N")
|
||||
|
||||
|
||||
class Baidu(ComponentBase, ABC):
|
||||
component_name = "Baidu"
|
||||
|
||||
def _run(self, history, **kwargs):
|
||||
ans = self.get_input()
|
||||
ans = " - ".join(ans["content"]) if "content" in ans else ""
|
||||
if not ans:
|
||||
return Baidu.be_output(self._param.no)
|
||||
|
||||
url = 'https://www.baidu.com/s?wd=' + ans + '&rn=' + str(self._param.top_n)
|
||||
headers = {
|
||||
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.104 Safari/537.36'}
|
||||
response = requests.get(url=url, headers=headers)
|
||||
|
||||
url_res = re.findall(r"'url': \\\"(.*?)\\\"}", response.text)
|
||||
title_res = re.findall(r"'title': \\\"(.*?)\\\",\\n", response.text)
|
||||
body_res = re.findall(r"\"contentText\":\"(.*?)\"", response.text)
|
||||
baidu_res = [re.sub('<em>|</em>', '', '<a href="' + url + '">' + title + '</a> ' + body) for url, title, body
|
||||
in zip(url_res, title_res, body_res)]
|
||||
del body_res, url_res, title_res
|
||||
|
||||
br = pd.DataFrame(baidu_res, columns=['content'])
|
||||
print(">>>>>>>>>>>>>>>>>>>>>>>>>>\n", br)
|
||||
return br
|
||||
@ -1,62 +0,0 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import random
|
||||
from abc import ABC
|
||||
from functools import partial
|
||||
from duckduckgo_search import DDGS
|
||||
import pandas as pd
|
||||
|
||||
from graph.component.base import ComponentBase, ComponentParamBase
|
||||
|
||||
|
||||
class DuckDuckGoSearchParam(ComponentParamBase):
|
||||
"""
|
||||
Define the DuckDuckGoSearch component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.top_n = 10
|
||||
self.channel = "text"
|
||||
|
||||
def check(self):
|
||||
self.check_positive_integer(self.top_n, "Top N")
|
||||
self.check_valid_value(self.channel, "Web Search or News", ["text", "news"])
|
||||
|
||||
|
||||
class DuckDuckGoSearch(ComponentBase, ABC):
|
||||
component_name = "DuckDuckGoSearch"
|
||||
|
||||
def _run(self, history, **kwargs):
|
||||
ans = self.get_input()
|
||||
ans = " - ".join(ans["content"]) if "content" in ans else ""
|
||||
if not ans:
|
||||
return Baidu.be_output(self._param.no)
|
||||
|
||||
if self.channel == "text":
|
||||
with DDGS() as ddgs:
|
||||
# {'title': '', 'href': '', 'body': ''}
|
||||
duck_res = ['<a href="' + i["href"] + '">' + i["title"] + '</a> ' + i["body"] for i in
|
||||
ddgs.text(ans, max_results=self._param.top_n)]
|
||||
elif self.channel == "news":
|
||||
with DDGS() as ddgs:
|
||||
# {'date': '', 'title': '', 'body': '', 'url': '', 'image': '', 'source': ''}
|
||||
duck_res = ['<a href="' + i["url"] + '">' + i["title"] + '</a> ' + i["body"] for i in
|
||||
ddgs.news(ans, max_results=self._param.top_n)]
|
||||
|
||||
dr = pd.DataFrame(duck_res, columns=['content'])
|
||||
print(">>>>>>>>>>>>>>>>>>>>>>>>>>\n", dr)
|
||||
return dr
|
||||
278
graphrag/claim_extractor.py
Normal file
278
graphrag/claim_extractor.py
Normal file
@ -0,0 +1,278 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""
|
||||
Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import tiktoken
|
||||
|
||||
from graphrag.claim_prompt import CLAIM_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
|
||||
|
||||
DEFAULT_TUPLE_DELIMITER = "<|>"
|
||||
DEFAULT_RECORD_DELIMITER = "##"
|
||||
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
|
||||
CLAIM_MAX_GLEANINGS = 1
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClaimExtractorResult:
|
||||
"""Claim extractor result class definition."""
|
||||
|
||||
output: list[dict]
|
||||
source_docs: dict[str, Any]
|
||||
|
||||
|
||||
class ClaimExtractor:
|
||||
"""Claim extractor class definition."""
|
||||
|
||||
_llm: CompletionLLM
|
||||
_extraction_prompt: str
|
||||
_summary_prompt: str
|
||||
_output_formatter_prompt: str
|
||||
_input_text_key: str
|
||||
_input_entity_spec_key: str
|
||||
_input_claim_description_key: str
|
||||
_tuple_delimiter_key: str
|
||||
_record_delimiter_key: str
|
||||
_completion_delimiter_key: str
|
||||
_max_gleanings: int
|
||||
_on_error: ErrorHandlerFn
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_invoker: CompletionLLM,
|
||||
extraction_prompt: str | None = None,
|
||||
input_text_key: str | None = None,
|
||||
input_entity_spec_key: str | None = None,
|
||||
input_claim_description_key: str | None = None,
|
||||
input_resolved_entities_key: str | None = None,
|
||||
tuple_delimiter_key: str | None = None,
|
||||
record_delimiter_key: str | None = None,
|
||||
completion_delimiter_key: str | None = None,
|
||||
encoding_model: str | None = None,
|
||||
max_gleanings: int | None = None,
|
||||
on_error: ErrorHandlerFn | None = None,
|
||||
):
|
||||
"""Init method definition."""
|
||||
self._llm = llm_invoker
|
||||
self._extraction_prompt = extraction_prompt or CLAIM_EXTRACTION_PROMPT
|
||||
self._input_text_key = input_text_key or "input_text"
|
||||
self._input_entity_spec_key = input_entity_spec_key or "entity_specs"
|
||||
self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
|
||||
self._record_delimiter_key = record_delimiter_key or "record_delimiter"
|
||||
self._completion_delimiter_key = (
|
||||
completion_delimiter_key or "completion_delimiter"
|
||||
)
|
||||
self._input_claim_description_key = (
|
||||
input_claim_description_key or "claim_description"
|
||||
)
|
||||
self._input_resolved_entities_key = (
|
||||
input_resolved_entities_key or "resolved_entities"
|
||||
)
|
||||
self._max_gleanings = (
|
||||
max_gleanings if max_gleanings is not None else CLAIM_MAX_GLEANINGS
|
||||
)
|
||||
self._on_error = on_error or (lambda _e, _s, _d: None)
|
||||
|
||||
# Construct the looping arguments
|
||||
encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
|
||||
yes = encoding.encode("YES")
|
||||
no = encoding.encode("NO")
|
||||
self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
|
||||
|
||||
def __call__(
|
||||
self, inputs: dict[str, Any], prompt_variables: dict | None = None
|
||||
) -> ClaimExtractorResult:
|
||||
"""Call method definition."""
|
||||
if prompt_variables is None:
|
||||
prompt_variables = {}
|
||||
texts = inputs[self._input_text_key]
|
||||
entity_spec = str(inputs[self._input_entity_spec_key])
|
||||
claim_description = inputs[self._input_claim_description_key]
|
||||
resolved_entities = inputs.get(self._input_resolved_entities_key, {})
|
||||
source_doc_map = {}
|
||||
|
||||
prompt_args = {
|
||||
self._input_entity_spec_key: entity_spec,
|
||||
self._input_claim_description_key: claim_description,
|
||||
self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key)
|
||||
or DEFAULT_TUPLE_DELIMITER,
|
||||
self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
|
||||
or DEFAULT_RECORD_DELIMITER,
|
||||
self._completion_delimiter_key: prompt_variables.get(
|
||||
self._completion_delimiter_key
|
||||
)
|
||||
or DEFAULT_COMPLETION_DELIMITER,
|
||||
}
|
||||
|
||||
all_claims: list[dict] = []
|
||||
for doc_index, text in enumerate(texts):
|
||||
document_id = f"d{doc_index}"
|
||||
try:
|
||||
claims = self._process_document(prompt_args, text, doc_index)
|
||||
all_claims += [
|
||||
self._clean_claim(c, document_id, resolved_entities) for c in claims
|
||||
]
|
||||
source_doc_map[document_id] = text
|
||||
except Exception as e:
|
||||
log.exception("error extracting claim")
|
||||
self._on_error(
|
||||
e,
|
||||
traceback.format_exc(),
|
||||
{"doc_index": doc_index, "text": text},
|
||||
)
|
||||
continue
|
||||
|
||||
return ClaimExtractorResult(
|
||||
output=all_claims,
|
||||
source_docs=source_doc_map,
|
||||
)
|
||||
|
||||
def _clean_claim(
|
||||
self, claim: dict, document_id: str, resolved_entities: dict
|
||||
) -> dict:
|
||||
# clean the parsed claims to remove any claims with status = False
|
||||
obj = claim.get("object_id", claim.get("object"))
|
||||
subject = claim.get("subject_id", claim.get("subject"))
|
||||
|
||||
# If subject or object in resolved entities, then replace with resolved entity
|
||||
obj = resolved_entities.get(obj, obj)
|
||||
subject = resolved_entities.get(subject, subject)
|
||||
claim["object_id"] = obj
|
||||
claim["subject_id"] = subject
|
||||
claim["doc_id"] = document_id
|
||||
return claim
|
||||
|
||||
def _process_document(
|
||||
self, prompt_args: dict, doc, doc_index: int
|
||||
) -> list[dict]:
|
||||
record_delimiter = prompt_args.get(
|
||||
self._record_delimiter_key, DEFAULT_RECORD_DELIMITER
|
||||
)
|
||||
completion_delimiter = prompt_args.get(
|
||||
self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER
|
||||
)
|
||||
variables = {
|
||||
self._input_text_key: doc,
|
||||
**prompt_args,
|
||||
}
|
||||
text = perform_variable_replacements(self._extraction_prompt, variables=variables)
|
||||
gen_conf = {"temperature": 0.5}
|
||||
results = self._llm.chat(text, [], gen_conf)
|
||||
claims = results.strip().removesuffix(completion_delimiter)
|
||||
history = [{"role": "system", "content": text}, {"role": "assistant", "content": results}]
|
||||
|
||||
# Repeat to ensure we maximize entity count
|
||||
for i in range(self._max_gleanings):
|
||||
text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
|
||||
history.append({"role": "user", "content": text})
|
||||
extension = self._llm.chat("", history, gen_conf)
|
||||
claims += record_delimiter + extension.strip().removesuffix(
|
||||
completion_delimiter
|
||||
)
|
||||
|
||||
# If this isn't the last loop, check to see if we should continue
|
||||
if i >= self._max_gleanings - 1:
|
||||
break
|
||||
|
||||
history.append({"role": "assistant", "content": extension})
|
||||
history.append({"role": "user", "content": LOOP_PROMPT})
|
||||
continuation = self._llm.chat("", history, self._loop_args)
|
||||
if continuation != "YES":
|
||||
break
|
||||
|
||||
result = self._parse_claim_tuples(claims, prompt_args)
|
||||
for r in result:
|
||||
r["doc_id"] = f"{doc_index}"
|
||||
return result
|
||||
|
||||
def _parse_claim_tuples(
|
||||
self, claims: str, prompt_variables: dict
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Parse claim tuples."""
|
||||
record_delimiter = prompt_variables.get(
|
||||
self._record_delimiter_key, DEFAULT_RECORD_DELIMITER
|
||||
)
|
||||
completion_delimiter = prompt_variables.get(
|
||||
self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER
|
||||
)
|
||||
tuple_delimiter = prompt_variables.get(
|
||||
self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER
|
||||
)
|
||||
|
||||
def pull_field(index: int, fields: list[str]) -> str | None:
|
||||
return fields[index].strip() if len(fields) > index else None
|
||||
|
||||
result: list[dict[str, Any]] = []
|
||||
claims_values = (
|
||||
claims.strip().removesuffix(completion_delimiter).split(record_delimiter)
|
||||
)
|
||||
for claim in claims_values:
|
||||
claim = claim.strip().removeprefix("(").removesuffix(")")
|
||||
claim = re.sub(r".*Output:", "", claim)
|
||||
|
||||
# Ignore the completion delimiter
|
||||
if claim == completion_delimiter:
|
||||
continue
|
||||
|
||||
claim_fields = claim.split(tuple_delimiter)
|
||||
o = {
|
||||
"subject_id": pull_field(0, claim_fields),
|
||||
"object_id": pull_field(1, claim_fields),
|
||||
"type": pull_field(2, claim_fields),
|
||||
"status": pull_field(3, claim_fields),
|
||||
"start_date": pull_field(4, claim_fields),
|
||||
"end_date": pull_field(5, claim_fields),
|
||||
"description": pull_field(6, claim_fields),
|
||||
"source_text": pull_field(7, claim_fields),
|
||||
"doc_id": pull_field(8, claim_fields),
|
||||
}
|
||||
if any([not o["subject_id"], not o["object_id"], o["subject_id"].lower() == "none", o["object_id"] == "none"]):
|
||||
continue
|
||||
result.append(o)
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
|
||||
parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.settings import retrievaler
|
||||
|
||||
ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
|
||||
docs = [d["content_with_weight"] for d in retrievaler.chunk_list(args.doc_id, args.tenant_id, max_count=12, fields=["content_with_weight"])]
|
||||
info = {
|
||||
"input_text": docs,
|
||||
"entity_specs": "organization, person",
|
||||
"claim_description": ""
|
||||
}
|
||||
claim = ex(info)
|
||||
print(json.dumps(claim.output, ensure_ascii=False, indent=2))
|
||||
84
graphrag/claim_prompt.py
Normal file
84
graphrag/claim_prompt.py
Normal file
@ -0,0 +1,84 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""
|
||||
Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
"""
|
||||
|
||||
CLAIM_EXTRACTION_PROMPT = """
|
||||
################
|
||||
-Target activity-
|
||||
################
|
||||
You are an intelligent assistant that helps a human analyst to analyze claims against certain entities presented in a text document.
|
||||
|
||||
################
|
||||
-Goal-
|
||||
################
|
||||
Given a text document that is potentially relevant to this activity, an entity specification, and a claim description, extract all entities that match the entity specification and all claims against those entities.
|
||||
|
||||
################
|
||||
-Steps-
|
||||
################
|
||||
- 1. Extract all named entities that match the predefined entity specification. Entity specification can either be a list of entity names or a list of entity types.
|
||||
- 2. For each entity identified in step 1, extract all claims associated with the entity. Claims need to match the specified claim description, and the entity should be the subject of the claim.
|
||||
For each claim, extract the following information:
|
||||
- Subject: name of the entity that is subject of the claim, capitalized. The subject entity is one that committed the action described in the claim. Subject needs to be one of the named entities identified in step 1.
|
||||
- Object: name of the entity that is object of the claim, capitalized. The object entity is one that either reports/handles or is affected by the action described in the claim. If object entity is unknown, use **NONE**.
|
||||
- Claim Type: overall category of the claim, capitalized. Name it in a way that can be repeated across multiple text inputs, so that similar claims share the same claim type
|
||||
- Claim Status: **TRUE**, **FALSE**, or **SUSPECTED**. TRUE means the claim is confirmed, FALSE means the claim is found to be False, SUSPECTED means the claim is not verified.
|
||||
- Claim Description: Detailed description explaining the reasoning behind the claim, together with all the related evidence and references.
|
||||
- Claim Date: Period (start_date, end_date) when the claim was made. Both start_date and end_date should be in ISO-8601 format. If the claim was made on a single date rather than a date range, set the same date for both start_date and end_date. If date is unknown, return **NONE**.
|
||||
- Claim Source Text: List of **all** quotes from the original text that are relevant to the claim.
|
||||
|
||||
- 3. Format each claim as (<subject_entity>{tuple_delimiter}<object_entity>{tuple_delimiter}<claim_type>{tuple_delimiter}<claim_status>{tuple_delimiter}<claim_start_date>{tuple_delimiter}<claim_end_date>{tuple_delimiter}<claim_description>{tuple_delimiter}<claim_source>)
|
||||
- 4. Return output in language of the 'Text' as a single list of all the claims identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
|
||||
- 5. If there's nothing satisfy the above requirements, just keep output empty.
|
||||
- 6. When finished, output {completion_delimiter}
|
||||
|
||||
################
|
||||
-Examples-
|
||||
################
|
||||
Example 1:
|
||||
Entity specification: organization
|
||||
Claim description: red flags associated with an entity
|
||||
Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015.
|
||||
Output:
|
||||
(COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.)
|
||||
{completion_delimiter}
|
||||
|
||||
###########################
|
||||
Example 2:
|
||||
Entity specification: Company A, Person C
|
||||
Claim description: red flags associated with an entity
|
||||
Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015.
|
||||
Output:
|
||||
(COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.)
|
||||
{record_delimiter}
|
||||
(PERSON C{tuple_delimiter}NONE{tuple_delimiter}CORRUPTION{tuple_delimiter}SUSPECTED{tuple_delimiter}2015-01-01T00:00:00{tuple_delimiter}2015-12-30T00:00:00{tuple_delimiter}Person C was suspected of engaging in corruption activities in 2015{tuple_delimiter}The company is owned by Person C who was suspected of engaging in corruption activities in 2015)
|
||||
{completion_delimiter}
|
||||
|
||||
################
|
||||
-Real Data-
|
||||
################
|
||||
Use the following input for your answer.
|
||||
Entity specification: {entity_specs}
|
||||
Claim description: {claim_description}
|
||||
Text: {input_text}
|
||||
Output:"""
|
||||
|
||||
|
||||
CONTINUE_PROMPT = "MANY entities were missed in the last extraction. Add them below using the same format(see 'Steps', start with the 'Output').\nOutput: "
|
||||
LOOP_PROMPT = "It appears some entities may have still been missed. Answer YES {tuple_delimiter} NO if there are still entities that need to be added.\n"
|
||||
171
graphrag/community_report_prompt.py
Normal file
171
graphrag/community_report_prompt.py
Normal file
@ -0,0 +1,171 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""
|
||||
Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
"""
|
||||
|
||||
COMMUNITY_REPORT_PROMPT = """
|
||||
You are an AI assistant that helps a human analyst to perform general information discovery. Information discovery is the process of identifying and assessing relevant information associated with certain entities (e.g., organizations and individuals) within a network.
|
||||
|
||||
# Goal
|
||||
Write a comprehensive report of a community, given a list of entities that belong to the community as well as their relationships and optional associated claims. The report will be used to inform decision-makers about information associated with the community and their potential impact. The content of this report includes an overview of the community's key entities, their legal compliance, technical capabilities, reputation, and noteworthy claims.
|
||||
|
||||
# Report Structure
|
||||
|
||||
The report should include the following sections:
|
||||
|
||||
- TITLE: community's name that represents its key entities - title should be short but specific. When possible, include representative named entities in the title.
|
||||
- SUMMARY: An executive summary of the community's overall structure, how its entities are related to each other, and significant information associated with its entities.
|
||||
- IMPACT SEVERITY RATING: a float score between 0-10 that represents the severity of IMPACT posed by entities within the community. IMPACT is the scored importance of a community.
|
||||
- RATING EXPLANATION: Give a single sentence explanation of the IMPACT severity rating.
|
||||
- DETAILED FINDINGS: A list of 5-10 key insights about the community. Each insight should have a short summary followed by multiple paragraphs of explanatory text grounded according to the grounding rules below. Be comprehensive.
|
||||
|
||||
Return output as a well-formed JSON-formatted string with the following format(in language of 'Text' content):
|
||||
{{
|
||||
"title": <report_title>,
|
||||
"summary": <executive_summary>,
|
||||
"rating": <impact_severity_rating>,
|
||||
"rating_explanation": <rating_explanation>,
|
||||
"findings": [
|
||||
{{
|
||||
"summary":<insight_1_summary>,
|
||||
"explanation": <insight_1_explanation>
|
||||
}},
|
||||
{{
|
||||
"summary":<insight_2_summary>,
|
||||
"explanation": <insight_2_explanation>
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
# Grounding Rules
|
||||
|
||||
Points supported by data should list their data references as follows:
|
||||
|
||||
"This is an example sentence supported by multiple data references [Data: <dataset name> (record ids); <dataset name> (record ids)]."
|
||||
|
||||
Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
|
||||
|
||||
For example:
|
||||
"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (1), Entities (5, 7); Relationships (23); Claims (7, 2, 34, 64, 46, +more)]."
|
||||
|
||||
where 1, 5, 7, 23, 2, 34, 46, and 64 represent the id (not the index) of the relevant data record.
|
||||
|
||||
Do not include information where the supporting evidence for it is not provided.
|
||||
|
||||
|
||||
# Example Input
|
||||
-----------
|
||||
Text:
|
||||
|
||||
-Entities-
|
||||
|
||||
id,entity,description
|
||||
5,VERDANT OASIS PLAZA,Verdant Oasis Plaza is the location of the Unity March
|
||||
6,HARMONY ASSEMBLY,Harmony Assembly is an organization that is holding a march at Verdant Oasis Plaza
|
||||
|
||||
-Relationships-
|
||||
|
||||
id,source,target,description
|
||||
37,VERDANT OASIS PLAZA,UNITY MARCH,Verdant Oasis Plaza is the location of the Unity March
|
||||
38,VERDANT OASIS PLAZA,HARMONY ASSEMBLY,Harmony Assembly is holding a march at Verdant Oasis Plaza
|
||||
39,VERDANT OASIS PLAZA,UNITY MARCH,The Unity March is taking place at Verdant Oasis Plaza
|
||||
40,VERDANT OASIS PLAZA,TRIBUNE SPOTLIGHT,Tribune Spotlight is reporting on the Unity march taking place at Verdant Oasis Plaza
|
||||
41,VERDANT OASIS PLAZA,BAILEY ASADI,Bailey Asadi is speaking at Verdant Oasis Plaza about the march
|
||||
43,HARMONY ASSEMBLY,UNITY MARCH,Harmony Assembly is organizing the Unity March
|
||||
|
||||
Output:
|
||||
{{
|
||||
"title": "Verdant Oasis Plaza and Unity March",
|
||||
"summary": "The community revolves around the Verdant Oasis Plaza, which is the location of the Unity March. The plaza has relationships with the Harmony Assembly, Unity March, and Tribune Spotlight, all of which are associated with the march event.",
|
||||
"rating": 5.0,
|
||||
"rating_explanation": "The impact severity rating is moderate due to the potential for unrest or conflict during the Unity March.",
|
||||
"findings": [
|
||||
{{
|
||||
"summary": "Verdant Oasis Plaza as the central location",
|
||||
"explanation": "Verdant Oasis Plaza is the central entity in this community, serving as the location for the Unity March. This plaza is the common link between all other entities, suggesting its significance in the community. The plaza's association with the march could potentially lead to issues such as public disorder or conflict, depending on the nature of the march and the reactions it provokes. [Data: Entities (5), Relationships (37, 38, 39, 40, 41,+more)]"
|
||||
}},
|
||||
{{
|
||||
"summary": "Harmony Assembly's role in the community",
|
||||
"explanation": "Harmony Assembly is another key entity in this community, being the organizer of the march at Verdant Oasis Plaza. The nature of Harmony Assembly and its march could be a potential source of threat, depending on their objectives and the reactions they provoke. The relationship between Harmony Assembly and the plaza is crucial in understanding the dynamics of this community. [Data: Entities(6), Relationships (38, 43)]"
|
||||
}},
|
||||
{{
|
||||
"summary": "Unity March as a significant event",
|
||||
"explanation": "The Unity March is a significant event taking place at Verdant Oasis Plaza. This event is a key factor in the community's dynamics and could be a potential source of threat, depending on the nature of the march and the reactions it provokes. The relationship between the march and the plaza is crucial in understanding the dynamics of this community. [Data: Relationships (39)]"
|
||||
}},
|
||||
{{
|
||||
"summary": "Role of Tribune Spotlight",
|
||||
"explanation": "Tribune Spotlight is reporting on the Unity March taking place in Verdant Oasis Plaza. This suggests that the event has attracted media attention, which could amplify its impact on the community. The role of Tribune Spotlight could be significant in shaping public perception of the event and the entities involved. [Data: Relationships (40)]"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
|
||||
# Real Data
|
||||
|
||||
Use the following text for your answer. Do not make anything up in your answer.
|
||||
|
||||
Text:
|
||||
|
||||
-Entities-
|
||||
{entity_df}
|
||||
|
||||
-Relationships-
|
||||
{relation_df}
|
||||
|
||||
The report should include the following sections:
|
||||
|
||||
- TITLE: community's name that represents its key entities - title should be short but specific. When possible, include representative named entities in the title.
|
||||
- SUMMARY: An executive summary of the community's overall structure, how its entities are related to each other, and significant information associated with its entities.
|
||||
- IMPACT SEVERITY RATING: a float score between 0-10 that represents the severity of IMPACT posed by entities within the community. IMPACT is the scored importance of a community.
|
||||
- RATING EXPLANATION: Give a single sentence explanation of the IMPACT severity rating.
|
||||
- DETAILED FINDINGS: A list of 5-10 key insights about the community. Each insight should have a short summary followed by multiple paragraphs of explanatory text grounded according to the grounding rules below. Be comprehensive.
|
||||
|
||||
Return output as a well-formed JSON-formatted string with the following format(in language of 'Text' content):
|
||||
{{
|
||||
"title": <report_title>,
|
||||
"summary": <executive_summary>,
|
||||
"rating": <impact_severity_rating>,
|
||||
"rating_explanation": <rating_explanation>,
|
||||
"findings": [
|
||||
{{
|
||||
"summary":<insight_1_summary>,
|
||||
"explanation": <insight_1_explanation>
|
||||
}},
|
||||
{{
|
||||
"summary":<insight_2_summary>,
|
||||
"explanation": <insight_2_explanation>
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
# Grounding Rules
|
||||
|
||||
Points supported by data should list their data references as follows:
|
||||
|
||||
"This is an example sentence supported by multiple data references [Data: <dataset name> (record ids); <dataset name> (record ids)]."
|
||||
|
||||
Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
|
||||
|
||||
For example:
|
||||
"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (1), Entities (5, 7); Relationships (23); Claims (7, 2, 34, 64, 46, +more)]."
|
||||
|
||||
where 1, 5, 7, 23, 2, 34, 46, and 64 represent the id (not the index) of the relevant data record.
|
||||
|
||||
Do not include information where the supporting evidence for it is not provided.
|
||||
|
||||
Output:"""
|
||||
135
graphrag/community_reports_extractor.py
Normal file
135
graphrag/community_reports_extractor.py
Normal file
@ -0,0 +1,135 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""
|
||||
Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List
|
||||
|
||||
import networkx as nx
|
||||
import pandas as pd
|
||||
|
||||
from graphrag import leiden
|
||||
from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT
|
||||
from graphrag.leiden import add_community_info2graph
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommunityReportsResult:
|
||||
"""Community reports result class definition."""
|
||||
|
||||
output: List[str]
|
||||
structured_output: List[dict]
|
||||
|
||||
|
||||
class CommunityReportsExtractor:
|
||||
"""Community reports extractor class definition."""
|
||||
|
||||
_llm: CompletionLLM
|
||||
_extraction_prompt: str
|
||||
_output_formatter_prompt: str
|
||||
_on_error: ErrorHandlerFn
|
||||
_max_report_length: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_invoker: CompletionLLM,
|
||||
extraction_prompt: str | None = None,
|
||||
on_error: ErrorHandlerFn | None = None,
|
||||
max_report_length: int | None = None,
|
||||
):
|
||||
"""Init method definition."""
|
||||
self._llm = llm_invoker
|
||||
self._extraction_prompt = extraction_prompt or COMMUNITY_REPORT_PROMPT
|
||||
self._on_error = on_error or (lambda _e, _s, _d: None)
|
||||
self._max_report_length = max_report_length or 1500
|
||||
|
||||
def __call__(self, graph: nx.Graph):
|
||||
communities: dict[str, dict[str, List]] = leiden.run(graph, {})
|
||||
relations_df = pd.DataFrame([{"source":s, "target": t, **attr} for s, t, attr in graph.edges(data=True)])
|
||||
res_str = []
|
||||
res_dict = []
|
||||
for level, comm in communities.items():
|
||||
for cm_id, ents in comm.items():
|
||||
weight = ents["weight"]
|
||||
ents = ents["nodes"]
|
||||
ent_df = pd.DataFrame([{"entity": n, **graph.nodes[n]} for n in ents])
|
||||
rela_df = relations_df[(relations_df["source"].isin(ents)) | (relations_df["target"].isin(ents))].reset_index(drop=True)
|
||||
|
||||
prompt_variables = {
|
||||
"entity_df": ent_df.to_csv(index_label="id"),
|
||||
"relation_df": rela_df.to_csv(index_label="id")
|
||||
}
|
||||
text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
|
||||
gen_conf = {"temperature": 0.5}
|
||||
try:
|
||||
response = self._llm.chat(text, [], gen_conf)
|
||||
response = re.sub(r"^[^\{]*", "", response)
|
||||
response = re.sub(r"[^\}]*$", "", response)
|
||||
print(response)
|
||||
response = json.loads(response)
|
||||
if not dict_has_keys_with_types(response, [
|
||||
("title", str),
|
||||
("summary", str),
|
||||
("findings", list),
|
||||
("rating", float),
|
||||
("rating_explanation", str),
|
||||
]): continue
|
||||
response["weight"] = weight
|
||||
response["entities"] = ents
|
||||
except Exception as e:
|
||||
print("ERROR: ", traceback.format_exc())
|
||||
self._on_error(e, traceback.format_exc(), None)
|
||||
continue
|
||||
|
||||
add_community_info2graph(graph, ents, response["title"])
|
||||
res_str.append(self._get_text_output(response))
|
||||
res_dict.append(response)
|
||||
|
||||
return CommunityReportsResult(
|
||||
structured_output=res_dict,
|
||||
output=res_str,
|
||||
)
|
||||
|
||||
def _get_text_output(self, parsed_output: dict) -> str:
|
||||
title = parsed_output.get("title", "Report")
|
||||
summary = parsed_output.get("summary", "")
|
||||
findings = parsed_output.get("findings", [])
|
||||
|
||||
def finding_summary(finding: dict):
|
||||
if isinstance(finding, str):
|
||||
return finding
|
||||
return finding.get("summary")
|
||||
|
||||
def finding_explanation(finding: dict):
|
||||
if isinstance(finding, str):
|
||||
return ""
|
||||
return finding.get("explanation")
|
||||
|
||||
report_sections = "\n\n".join(
|
||||
f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings
|
||||
)
|
||||
return f"# {title}\n\n{summary}\n\n{report_sections}"
|
||||
167
graphrag/description_summary.py
Normal file
167
graphrag/description_summary.py
Normal file
@ -0,0 +1,167 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""
|
||||
Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import html
|
||||
import json
|
||||
import logging
|
||||
import numbers
|
||||
import re
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
import networkx as nx
|
||||
|
||||
from rag.utils import num_tokens_from_string
|
||||
|
||||
SUMMARIZE_PROMPT = """
|
||||
You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
|
||||
Given one or two entities, and a list of descriptions, all related to the same entity or group of entities.
|
||||
Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
|
||||
If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
|
||||
Make sure it is written in third person, and include the entity names so we the have full context.
|
||||
|
||||
#######
|
||||
-Data-
|
||||
Entities: {entity_name}
|
||||
Description List: {description_list}
|
||||
#######
|
||||
Output:
|
||||
"""
|
||||
|
||||
# Max token size for input prompts
|
||||
DEFAULT_MAX_INPUT_TOKENS = 4_000
|
||||
# Max token count for LLM answers
|
||||
DEFAULT_MAX_SUMMARY_LENGTH = 128
|
||||
|
||||
|
||||
@dataclass
|
||||
class SummarizationResult:
|
||||
"""Unipartite graph extraction result class definition."""
|
||||
|
||||
items: str | tuple[str, str]
|
||||
description: str
|
||||
|
||||
|
||||
class SummarizeExtractor:
|
||||
"""Unipartite graph extractor class definition."""
|
||||
|
||||
_llm: CompletionLLM
|
||||
_entity_name_key: str
|
||||
_input_descriptions_key: str
|
||||
_summarization_prompt: str
|
||||
_on_error: ErrorHandlerFn
|
||||
_max_summary_length: int
|
||||
_max_input_tokens: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_invoker: CompletionLLM,
|
||||
entity_name_key: str | None = None,
|
||||
input_descriptions_key: str | None = None,
|
||||
summarization_prompt: str | None = None,
|
||||
on_error: ErrorHandlerFn | None = None,
|
||||
max_summary_length: int | None = None,
|
||||
max_input_tokens: int | None = None,
|
||||
):
|
||||
"""Init method definition."""
|
||||
# TODO: streamline construction
|
||||
self._llm = llm_invoker
|
||||
self._entity_name_key = entity_name_key or "entity_name"
|
||||
self._input_descriptions_key = input_descriptions_key or "description_list"
|
||||
|
||||
self._summarization_prompt = summarization_prompt or SUMMARIZE_PROMPT
|
||||
self._on_error = on_error or (lambda _e, _s, _d: None)
|
||||
self._max_summary_length = max_summary_length or DEFAULT_MAX_SUMMARY_LENGTH
|
||||
self._max_input_tokens = max_input_tokens or DEFAULT_MAX_INPUT_TOKENS
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
items: str | tuple[str, str],
|
||||
descriptions: list[str],
|
||||
) -> SummarizationResult:
|
||||
"""Call method definition."""
|
||||
result = ""
|
||||
if len(descriptions) == 0:
|
||||
result = ""
|
||||
if len(descriptions) == 1:
|
||||
result = descriptions[0]
|
||||
else:
|
||||
result = self._summarize_descriptions(items, descriptions)
|
||||
|
||||
return SummarizationResult(
|
||||
items=items,
|
||||
description=result or "",
|
||||
)
|
||||
|
||||
def _summarize_descriptions(
|
||||
self, items: str | tuple[str, str], descriptions: list[str]
|
||||
) -> str:
|
||||
"""Summarize descriptions into a single description."""
|
||||
sorted_items = sorted(items) if isinstance(items, list) else items
|
||||
|
||||
# Safety check, should always be a list
|
||||
if not isinstance(descriptions, list):
|
||||
descriptions = [descriptions]
|
||||
|
||||
# Iterate over descriptions, adding all until the max input tokens is reached
|
||||
usable_tokens = self._max_input_tokens - num_tokens_from_string(
|
||||
self._summarization_prompt
|
||||
)
|
||||
descriptions_collected = []
|
||||
result = ""
|
||||
|
||||
for i, description in enumerate(descriptions):
|
||||
usable_tokens -= num_tokens_from_string(description)
|
||||
descriptions_collected.append(description)
|
||||
|
||||
# If buffer is full, or all descriptions have been added, summarize
|
||||
if (usable_tokens < 0 and len(descriptions_collected) > 1) or (
|
||||
i == len(descriptions) - 1
|
||||
):
|
||||
# Calculate result (final or partial)
|
||||
result = await self._summarize_descriptions_with_llm(
|
||||
sorted_items, descriptions_collected
|
||||
)
|
||||
|
||||
# If we go for another loop, reset values to new
|
||||
if i != len(descriptions) - 1:
|
||||
descriptions_collected = [result]
|
||||
usable_tokens = (
|
||||
self._max_input_tokens
|
||||
- num_tokens_from_string(self._summarization_prompt)
|
||||
- num_tokens_from_string(result)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _summarize_descriptions_with_llm(
|
||||
self, items: str | tuple[str, str] | list[str], descriptions: list[str]
|
||||
):
|
||||
"""Summarize descriptions using the LLM."""
|
||||
variables = {
|
||||
self._entity_name_key: json.dumps(items),
|
||||
self._input_descriptions_key: json.dumps(sorted(descriptions)),
|
||||
}
|
||||
text = perform_variable_replacements(self._summarization_prompt, variables=variables)
|
||||
return self._llm.chat("", [{"role": "user", "content": text}])
|
||||
78
graphrag/entity_embedding.py
Normal file
78
graphrag/entity_embedding.py
Normal file
@ -0,0 +1,78 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""
|
||||
Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import networkx as nx
|
||||
from graphrag.leiden import stable_largest_connected_component
|
||||
|
||||
|
||||
@dataclass
|
||||
class NodeEmbeddings:
|
||||
"""Node embeddings class definition."""
|
||||
|
||||
nodes: list[str]
|
||||
embeddings: np.ndarray
|
||||
|
||||
|
||||
def embed_nod2vec(
|
||||
graph: nx.Graph | nx.DiGraph,
|
||||
dimensions: int = 1536,
|
||||
num_walks: int = 10,
|
||||
walk_length: int = 40,
|
||||
window_size: int = 2,
|
||||
iterations: int = 3,
|
||||
random_seed: int = 86,
|
||||
) -> NodeEmbeddings:
|
||||
"""Generate node embeddings using Node2Vec."""
|
||||
# generate embedding
|
||||
lcc_tensors = gc.embed.node2vec_embed( # type: ignore
|
||||
graph=graph,
|
||||
dimensions=dimensions,
|
||||
window_size=window_size,
|
||||
iterations=iterations,
|
||||
num_walks=num_walks,
|
||||
walk_length=walk_length,
|
||||
random_seed=random_seed,
|
||||
)
|
||||
return NodeEmbeddings(embeddings=lcc_tensors[0], nodes=lcc_tensors[1])
|
||||
|
||||
|
||||
def run(graph: nx.Graph, args: dict[str, Any]) -> NodeEmbeddings:
|
||||
"""Run method definition."""
|
||||
if args.get("use_lcc", True):
|
||||
graph = stable_largest_connected_component(graph)
|
||||
|
||||
# create graph embedding using node2vec
|
||||
embeddings = embed_nod2vec(
|
||||
graph=graph,
|
||||
dimensions=args.get("dimensions", 1536),
|
||||
num_walks=args.get("num_walks", 10),
|
||||
walk_length=args.get("walk_length", 40),
|
||||
window_size=args.get("window_size", 2),
|
||||
iterations=args.get("iterations", 3),
|
||||
random_seed=args.get("random_seed", 86),
|
||||
)
|
||||
|
||||
pairs = zip(embeddings.nodes, embeddings.embeddings.tolist(), strict=True)
|
||||
sorted_pairs = sorted(pairs, key=lambda x: x[0])
|
||||
|
||||
return dict(sorted_pairs)
|
||||
212
graphrag/entity_resolution.py
Normal file
212
graphrag/entity_resolution.py
Normal file
@ -0,0 +1,212 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import re
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import networkx as nx
|
||||
from rag.nlp import is_english
|
||||
import editdistance
|
||||
from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
|
||||
|
||||
DEFAULT_RECORD_DELIMITER = "##"
|
||||
DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
|
||||
DEFAULT_RESOLUTION_RESULT_DELIMITER = "&&"
|
||||
|
||||
|
||||
@dataclass
|
||||
class EntityResolutionResult:
|
||||
"""Entity resolution result class definition."""
|
||||
|
||||
output: nx.Graph
|
||||
|
||||
|
||||
class EntityResolution:
|
||||
"""Entity resolution class definition."""
|
||||
|
||||
_llm: CompletionLLM
|
||||
_resolution_prompt: str
|
||||
_output_formatter_prompt: str
|
||||
_on_error: ErrorHandlerFn
|
||||
_record_delimiter_key: str
|
||||
_entity_index_delimiter_key: str
|
||||
_resolution_result_delimiter_key: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_invoker: CompletionLLM,
|
||||
resolution_prompt: str | None = None,
|
||||
on_error: ErrorHandlerFn | None = None,
|
||||
record_delimiter_key: str | None = None,
|
||||
entity_index_delimiter_key: str | None = None,
|
||||
resolution_result_delimiter_key: str | None = None,
|
||||
input_text_key: str | None = None
|
||||
):
|
||||
"""Init method definition."""
|
||||
self._llm = llm_invoker
|
||||
self._resolution_prompt = resolution_prompt or ENTITY_RESOLUTION_PROMPT
|
||||
self._on_error = on_error or (lambda _e, _s, _d: None)
|
||||
self._record_delimiter_key = record_delimiter_key or "record_delimiter"
|
||||
self._entity_index_dilimiter_key = entity_index_delimiter_key or "entity_index_delimiter"
|
||||
self._resolution_result_delimiter_key = resolution_result_delimiter_key or "resolution_result_delimiter"
|
||||
self._input_text_key = input_text_key or "input_text"
|
||||
|
||||
def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
|
||||
"""Call method definition."""
|
||||
if prompt_variables is None:
|
||||
prompt_variables = {}
|
||||
|
||||
# Wire defaults into the prompt variables
|
||||
prompt_variables = {
|
||||
**prompt_variables,
|
||||
self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
|
||||
or DEFAULT_RECORD_DELIMITER,
|
||||
self._entity_index_dilimiter_key: prompt_variables.get(self._entity_index_dilimiter_key)
|
||||
or DEFAULT_ENTITY_INDEX_DELIMITER,
|
||||
self._resolution_result_delimiter_key: prompt_variables.get(self._resolution_result_delimiter_key)
|
||||
or DEFAULT_RESOLUTION_RESULT_DELIMITER,
|
||||
}
|
||||
|
||||
nodes = graph.nodes
|
||||
entity_types = list(set(graph.nodes[node]['entity_type'] for node in nodes))
|
||||
node_clusters = {entity_type: [] for entity_type in entity_types}
|
||||
|
||||
for node in nodes:
|
||||
node_clusters[graph.nodes[node]['entity_type']].append(node)
|
||||
|
||||
candidate_resolution = {entity_type: [] for entity_type in entity_types}
|
||||
for node_cluster in node_clusters.items():
|
||||
candidate_resolution_tmp = []
|
||||
for a in node_cluster[1]:
|
||||
for b in node_cluster[1]:
|
||||
if a == b:
|
||||
continue
|
||||
if self.is_similarity(a, b) and (b, a) not in candidate_resolution_tmp:
|
||||
candidate_resolution_tmp.append((a, b))
|
||||
if candidate_resolution_tmp:
|
||||
candidate_resolution[node_cluster[0]] = candidate_resolution_tmp
|
||||
|
||||
gen_conf = {"temperature": 0.5}
|
||||
resolution_result = set()
|
||||
for candidate_resolution_i in candidate_resolution.items():
|
||||
if candidate_resolution_i[1]:
|
||||
try:
|
||||
pair_txt = [
|
||||
f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n']
|
||||
for index, candidate in enumerate(candidate_resolution_i[1]):
|
||||
pair_txt.append(
|
||||
f'Question {index + 1}: name of{candidate_resolution_i[0]} A is {candidate[0]} ,name of{candidate_resolution_i[0]} B is {candidate[1]}')
|
||||
sent = 'question above' if len(pair_txt) == 1 else f'above {len(pair_txt)} questions'
|
||||
pair_txt.append(
|
||||
f'\nUse domain knowledge of {candidate_resolution_i[0]}s to help understand the text and answer the {sent} in the format: For Question i, Yes, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are the same {candidate_resolution_i[0]}./No, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are different {candidate_resolution_i[0]}s. For Question i+1, (repeat the above procedures)')
|
||||
pair_prompt = '\n'.join(pair_txt)
|
||||
|
||||
variables = {
|
||||
**prompt_variables,
|
||||
self._input_text_key: pair_prompt
|
||||
}
|
||||
text = perform_variable_replacements(self._resolution_prompt, variables=variables)
|
||||
|
||||
response = self._llm.chat(text, [], gen_conf)
|
||||
result = self._process_results(len(candidate_resolution_i[1]), response,
|
||||
prompt_variables.get(self._record_delimiter_key,
|
||||
DEFAULT_RECORD_DELIMITER),
|
||||
prompt_variables.get(self._entity_index_dilimiter_key,
|
||||
DEFAULT_ENTITY_INDEX_DELIMITER),
|
||||
prompt_variables.get(self._resolution_result_delimiter_key,
|
||||
DEFAULT_RESOLUTION_RESULT_DELIMITER))
|
||||
for result_i in result:
|
||||
resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
|
||||
except Exception as e:
|
||||
logging.exception("error entity resolution")
|
||||
self._on_error(e, traceback.format_exc(), None)
|
||||
|
||||
connect_graph = nx.Graph()
|
||||
connect_graph.add_edges_from(resolution_result)
|
||||
for sub_connect_graph in nx.connected_components(connect_graph):
|
||||
sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
|
||||
remove_nodes = list(sub_connect_graph.nodes)
|
||||
keep_node = remove_nodes.pop()
|
||||
for remove_node in remove_nodes:
|
||||
remove_node_neighbors = graph[remove_node]
|
||||
graph.nodes[keep_node]['description'] += graph.nodes[remove_node]['description']
|
||||
graph.nodes[keep_node]['weight'] += graph.nodes[remove_node]['weight']
|
||||
remove_node_neighbors = list(remove_node_neighbors)
|
||||
for remove_node_neighbor in remove_node_neighbors:
|
||||
if remove_node_neighbor == keep_node:
|
||||
graph.remove_edge(keep_node, remove_node)
|
||||
continue
|
||||
if graph.has_edge(keep_node, remove_node_neighbor):
|
||||
graph[keep_node][remove_node_neighbor]['weight'] += graph[remove_node][remove_node_neighbor][
|
||||
'weight']
|
||||
graph[keep_node][remove_node_neighbor]['description'] += \
|
||||
graph[remove_node][remove_node_neighbor]['description']
|
||||
graph.remove_edge(remove_node, remove_node_neighbor)
|
||||
else:
|
||||
graph.add_edge(keep_node, remove_node_neighbor,
|
||||
weight=graph[remove_node][remove_node_neighbor]['weight'],
|
||||
description=graph[remove_node][remove_node_neighbor]['description'],
|
||||
source_id="")
|
||||
graph.remove_edge(remove_node, remove_node_neighbor)
|
||||
graph.remove_node(remove_node)
|
||||
|
||||
for node_degree in graph.degree:
|
||||
graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
|
||||
|
||||
return EntityResolutionResult(
|
||||
output=graph,
|
||||
)
|
||||
|
||||
def _process_results(
|
||||
self,
|
||||
records_length: int,
|
||||
results: str,
|
||||
record_delimiter: str,
|
||||
entity_index_delimiter: str,
|
||||
resolution_result_delimiter: str
|
||||
) -> list:
|
||||
ans_list = []
|
||||
records = [r.strip() for r in results.split(record_delimiter)]
|
||||
for record in records:
|
||||
pattern_int = f"{re.escape(entity_index_delimiter)}(\d+){re.escape(entity_index_delimiter)}"
|
||||
match_int = re.search(pattern_int, record)
|
||||
res_int = int(str(match_int.group(1) if match_int else '0'))
|
||||
if res_int > records_length:
|
||||
continue
|
||||
|
||||
pattern_bool = f"{re.escape(resolution_result_delimiter)}([a-zA-Z]+){re.escape(resolution_result_delimiter)}"
|
||||
match_bool = re.search(pattern_bool, record)
|
||||
res_bool = str(match_bool.group(1) if match_bool else '')
|
||||
|
||||
if res_int and res_bool:
|
||||
if res_bool.lower() == 'yes':
|
||||
ans_list.append((res_int, "yes"))
|
||||
|
||||
return ans_list
|
||||
|
||||
def is_similarity(self, a, b):
|
||||
if is_english(a) and is_english(b):
|
||||
if editdistance.eval(a, b) <= min(len(a), len(b)) // 2:
|
||||
return True
|
||||
|
||||
if len(set(a) & set(b)) > 0:
|
||||
return True
|
||||
|
||||
return False
|
||||
74
graphrag/entity_resolution_prompt.py
Normal file
74
graphrag/entity_resolution_prompt.py
Normal file
@ -0,0 +1,74 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
ENTITY_RESOLUTION_PROMPT = """
|
||||
-Goal-
|
||||
Please answer the following Question as required
|
||||
|
||||
-Steps-
|
||||
1. Identify each line of questioning as required
|
||||
|
||||
2. Return output in English as a single list of each line answer in steps 1. Use **{record_delimiter}** as the list delimiter.
|
||||
|
||||
######################
|
||||
-Examples-
|
||||
######################
|
||||
Example 1:
|
||||
|
||||
Question:
|
||||
When determining whether two Products are the same, you should only focus on critical properties and overlook noisy factors.
|
||||
|
||||
Demonstration 1: name of Product A is : "computer", name of Product B is :"phone" No, Product A and Product B are different products.
|
||||
Question 1: name of Product A is : "television", name of Product B is :"TV"
|
||||
Question 2: name of Product A is : "cup", name of Product B is :"mug"
|
||||
Question 3: name of Product A is : "soccer", name of Product B is :"football"
|
||||
Question 4: name of Product A is : "pen", name of Product B is :"eraser"
|
||||
|
||||
Use domain knowledge of Products to help understand the text and answer the above 4 questions in the format: For Question i, Yes, Product A and Product B are the same product. or No, Product A and Product B are different products. For Question i+1, (repeat the above procedures)
|
||||
################
|
||||
Output:
|
||||
(For question {entity_index_delimiter}1{entity_index_delimiter}, {resolution_result_delimiter}no{resolution_result_delimiter}, Product A and Product B are different products.){record_delimiter}
|
||||
(For question {entity_index_delimiter}2{entity_index_delimiter}, {resolution_result_delimiter}no{resolution_result_delimiter}, Product A and Product B are different products.){record_delimiter}
|
||||
(For question {entity_index_delimiter}3{entity_index_delimiter}, {resolution_result_delimiter}yes{resolution_result_delimiter}, Product A and Product B are the same product.){record_delimiter}
|
||||
(For question {entity_index_delimiter}4{entity_index_delimiter}, {resolution_result_delimiter}no{resolution_result_delimiter}, Product A and Product B are different products.){record_delimiter}
|
||||
#############################
|
||||
|
||||
Example 2:
|
||||
|
||||
Question:
|
||||
When determining whether two toponym are the same, you should only focus on critical properties and overlook noisy factors.
|
||||
|
||||
Demonstration 1: name of toponym A is : "nanjing", name of toponym B is :"nanjing city" No, toponym A and toponym B are same toponym.
|
||||
Question 1: name of toponym A is : "Chicago", name of toponym B is :"ChiTown"
|
||||
Question 2: name of toponym A is : "Shanghai", name of toponym B is :"Zhengzhou"
|
||||
Question 3: name of toponym A is : "Beijing", name of toponym B is :"Peking"
|
||||
Question 4: name of toponym A is : "Los Angeles", name of toponym B is :"Cleveland"
|
||||
|
||||
Use domain knowledge of toponym to help understand the text and answer the above 4 questions in the format: For Question i, Yes, toponym A and toponym B are the same toponym. or No, toponym A and toponym B are different toponym. For Question i+1, (repeat the above procedures)
|
||||
################
|
||||
Output:
|
||||
(For question {entity_index_delimiter}1{entity_index_delimiter}, {resolution_result_delimiter}yes{resolution_result_delimiter}, toponym A and toponym B are same toponym.){record_delimiter}
|
||||
(For question {entity_index_delimiter}2{entity_index_delimiter}, {resolution_result_delimiter}no{resolution_result_delimiter}, toponym A and toponym B are different toponym.){record_delimiter}
|
||||
(For question {entity_index_delimiter}3{entity_index_delimiter}, {resolution_result_delimiter}yes{resolution_result_delimiter}, toponym A and toponym B are the same toponym.){record_delimiter}
|
||||
(For question {entity_index_delimiter}4{entity_index_delimiter}, {resolution_result_delimiter}no{resolution_result_delimiter}, toponym A and toponym B are different toponym.){record_delimiter}
|
||||
#############################
|
||||
|
||||
-Real Data-
|
||||
######################
|
||||
Question:{input_text}
|
||||
######################
|
||||
Output:
|
||||
"""
|
||||
319
graphrag/graph_extractor.py
Normal file
319
graphrag/graph_extractor.py
Normal file
@ -0,0 +1,319 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
"""
|
||||
import logging
|
||||
import numbers
|
||||
import re
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Mapping
|
||||
import tiktoken
|
||||
from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
|
||||
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
import networkx as nx
|
||||
from rag.utils import num_tokens_from_string
|
||||
|
||||
DEFAULT_TUPLE_DELIMITER = "<|>"
|
||||
DEFAULT_RECORD_DELIMITER = "##"
|
||||
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
|
||||
DEFAULT_ENTITY_TYPES = ["organization", "person", "location", "event", "time"]
|
||||
ENTITY_EXTRACTION_MAX_GLEANINGS = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphExtractionResult:
|
||||
"""Unipartite graph extraction result class definition."""
|
||||
|
||||
output: nx.Graph
|
||||
source_docs: dict[Any, Any]
|
||||
|
||||
|
||||
class GraphExtractor:
|
||||
"""Unipartite graph extractor class definition."""
|
||||
|
||||
_llm: CompletionLLM
|
||||
_join_descriptions: bool
|
||||
_tuple_delimiter_key: str
|
||||
_record_delimiter_key: str
|
||||
_entity_types_key: str
|
||||
_input_text_key: str
|
||||
_completion_delimiter_key: str
|
||||
_entity_name_key: str
|
||||
_input_descriptions_key: str
|
||||
_extraction_prompt: str
|
||||
_summarization_prompt: str
|
||||
_loop_args: dict[str, Any]
|
||||
_max_gleanings: int
|
||||
_on_error: ErrorHandlerFn
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_invoker: CompletionLLM,
|
||||
prompt: str | None = None,
|
||||
tuple_delimiter_key: str | None = None,
|
||||
record_delimiter_key: str | None = None,
|
||||
input_text_key: str | None = None,
|
||||
entity_types_key: str | None = None,
|
||||
completion_delimiter_key: str | None = None,
|
||||
join_descriptions=True,
|
||||
encoding_model: str | None = None,
|
||||
max_gleanings: int | None = None,
|
||||
on_error: ErrorHandlerFn | None = None,
|
||||
):
|
||||
"""Init method definition."""
|
||||
# TODO: streamline construction
|
||||
self._llm = llm_invoker
|
||||
self._join_descriptions = join_descriptions
|
||||
self._input_text_key = input_text_key or "input_text"
|
||||
self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
|
||||
self._record_delimiter_key = record_delimiter_key or "record_delimiter"
|
||||
self._completion_delimiter_key = (
|
||||
completion_delimiter_key or "completion_delimiter"
|
||||
)
|
||||
self._entity_types_key = entity_types_key or "entity_types"
|
||||
self._extraction_prompt = prompt or GRAPH_EXTRACTION_PROMPT
|
||||
self._max_gleanings = (
|
||||
max_gleanings
|
||||
if max_gleanings is not None
|
||||
else ENTITY_EXTRACTION_MAX_GLEANINGS
|
||||
)
|
||||
self._on_error = on_error or (lambda _e, _s, _d: None)
|
||||
self.prompt_token_count = num_tokens_from_string(self._extraction_prompt)
|
||||
|
||||
# Construct the looping arguments
|
||||
encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
|
||||
yes = encoding.encode("YES")
|
||||
no = encoding.encode("NO")
|
||||
self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
|
||||
|
||||
def __call__(
|
||||
self, texts: list[str], prompt_variables: dict[str, Any] | None = None
|
||||
) -> GraphExtractionResult:
|
||||
"""Call method definition."""
|
||||
if prompt_variables is None:
|
||||
prompt_variables = {}
|
||||
all_records: dict[int, str] = {}
|
||||
source_doc_map: dict[int, str] = {}
|
||||
|
||||
# Wire defaults into the prompt variables
|
||||
prompt_variables = {
|
||||
**prompt_variables,
|
||||
self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key)
|
||||
or DEFAULT_TUPLE_DELIMITER,
|
||||
self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
|
||||
or DEFAULT_RECORD_DELIMITER,
|
||||
self._completion_delimiter_key: prompt_variables.get(
|
||||
self._completion_delimiter_key
|
||||
)
|
||||
or DEFAULT_COMPLETION_DELIMITER,
|
||||
self._entity_types_key: ",".join(
|
||||
prompt_variables.get(self._entity_types_key) or DEFAULT_ENTITY_TYPES
|
||||
),
|
||||
}
|
||||
|
||||
for doc_index, text in enumerate(texts):
|
||||
try:
|
||||
# Invoke the entity extraction
|
||||
result = self._process_document(text, prompt_variables)
|
||||
source_doc_map[doc_index] = text
|
||||
all_records[doc_index] = result
|
||||
except Exception as e:
|
||||
logging.exception("error extracting graph")
|
||||
self._on_error(
|
||||
e,
|
||||
traceback.format_exc(),
|
||||
{
|
||||
"doc_index": doc_index,
|
||||
"text": text,
|
||||
},
|
||||
)
|
||||
|
||||
output = self._process_results(
|
||||
all_records,
|
||||
prompt_variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER),
|
||||
prompt_variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER),
|
||||
)
|
||||
|
||||
return GraphExtractionResult(
|
||||
output=output,
|
||||
source_docs=source_doc_map,
|
||||
)
|
||||
|
||||
def _process_document(
|
||||
self, text: str, prompt_variables: dict[str, str]
|
||||
) -> str:
|
||||
variables = {
|
||||
**prompt_variables,
|
||||
self._input_text_key: text,
|
||||
}
|
||||
text = perform_variable_replacements(self._extraction_prompt, variables=variables)
|
||||
gen_conf = {"temperature": 0.5}
|
||||
response = self._llm.chat(text, [], gen_conf)
|
||||
|
||||
results = response or ""
|
||||
history = [{"role": "system", "content": text}, {"role": "assistant", "content": response}]
|
||||
|
||||
# Repeat to ensure we maximize entity count
|
||||
for i in range(self._max_gleanings):
|
||||
text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
|
||||
history.append({"role": "user", "content": text})
|
||||
response = self._llm.chat("", history, gen_conf)
|
||||
results += response or ""
|
||||
|
||||
# if this is the final glean, don't bother updating the continuation flag
|
||||
if i >= self._max_gleanings - 1:
|
||||
break
|
||||
history.append({"role": "assistant", "content": response})
|
||||
history.append({"role": "user", "content": LOOP_PROMPT})
|
||||
continuation = self._llm.chat("", history, self._loop_args)
|
||||
if continuation != "YES":
|
||||
break
|
||||
|
||||
return results
|
||||
|
||||
def _process_results(
|
||||
self,
|
||||
results: dict[int, str],
|
||||
tuple_delimiter: str,
|
||||
record_delimiter: str,
|
||||
) -> nx.Graph:
|
||||
"""Parse the result string to create an undirected unipartite graph.
|
||||
|
||||
Args:
|
||||
- results - dict of results from the extraction chain
|
||||
- tuple_delimiter - delimiter between tuples in an output record, default is '<|>'
|
||||
- record_delimiter - delimiter between records, default is '##'
|
||||
Returns:
|
||||
- output - unipartite graph in graphML format
|
||||
"""
|
||||
graph = nx.Graph()
|
||||
for source_doc_id, extracted_data in results.items():
|
||||
records = [r.strip() for r in extracted_data.split(record_delimiter)]
|
||||
|
||||
for record in records:
|
||||
record = re.sub(r"^\(|\)$", "", record.strip())
|
||||
record_attributes = record.split(tuple_delimiter)
|
||||
|
||||
if record_attributes[0] == '"entity"' and len(record_attributes) >= 4:
|
||||
# add this record as a node in the G
|
||||
entity_name = clean_str(record_attributes[1].upper())
|
||||
entity_type = clean_str(record_attributes[2].upper())
|
||||
entity_description = clean_str(record_attributes[3])
|
||||
|
||||
if entity_name in graph.nodes():
|
||||
node = graph.nodes[entity_name]
|
||||
if self._join_descriptions:
|
||||
node["description"] = "\n".join(
|
||||
list({
|
||||
*_unpack_descriptions(node),
|
||||
entity_description,
|
||||
})
|
||||
)
|
||||
else:
|
||||
if len(entity_description) > len(node["description"]):
|
||||
node["description"] = entity_description
|
||||
node["source_id"] = ", ".join(
|
||||
list({
|
||||
*_unpack_source_ids(node),
|
||||
str(source_doc_id),
|
||||
})
|
||||
)
|
||||
node["entity_type"] = (
|
||||
entity_type if entity_type != "" else node["entity_type"]
|
||||
)
|
||||
else:
|
||||
graph.add_node(
|
||||
entity_name,
|
||||
entity_type=entity_type,
|
||||
description=entity_description,
|
||||
source_id=str(source_doc_id),
|
||||
weight=1
|
||||
)
|
||||
|
||||
if (
|
||||
record_attributes[0] == '"relationship"'
|
||||
and len(record_attributes) >= 5
|
||||
):
|
||||
# add this record as edge
|
||||
source = clean_str(record_attributes[1].upper())
|
||||
target = clean_str(record_attributes[2].upper())
|
||||
edge_description = clean_str(record_attributes[3])
|
||||
edge_source_id = clean_str(str(source_doc_id))
|
||||
weight = (
|
||||
float(record_attributes[-1])
|
||||
if isinstance(record_attributes[-1], numbers.Number)
|
||||
else 1.0
|
||||
)
|
||||
if source not in graph.nodes():
|
||||
graph.add_node(
|
||||
source,
|
||||
entity_type="",
|
||||
description="",
|
||||
source_id=edge_source_id,
|
||||
weight=1
|
||||
)
|
||||
if target not in graph.nodes():
|
||||
graph.add_node(
|
||||
target,
|
||||
entity_type="",
|
||||
description="",
|
||||
source_id=edge_source_id,
|
||||
weight=1
|
||||
)
|
||||
if graph.has_edge(source, target):
|
||||
edge_data = graph.get_edge_data(source, target)
|
||||
if edge_data is not None:
|
||||
weight += edge_data["weight"]
|
||||
if self._join_descriptions:
|
||||
edge_description = "\n".join(
|
||||
list({
|
||||
*_unpack_descriptions(edge_data),
|
||||
edge_description,
|
||||
})
|
||||
)
|
||||
edge_source_id = ", ".join(
|
||||
list({
|
||||
*_unpack_source_ids(edge_data),
|
||||
str(source_doc_id),
|
||||
})
|
||||
)
|
||||
graph.add_edge(
|
||||
source,
|
||||
target,
|
||||
weight=weight,
|
||||
description=edge_description,
|
||||
source_id=edge_source_id,
|
||||
)
|
||||
|
||||
for node_degree in graph.degree:
|
||||
graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
|
||||
return graph
|
||||
|
||||
|
||||
def _unpack_descriptions(data: Mapping) -> list[str]:
|
||||
value = data.get("description", None)
|
||||
return [] if value is None else value.split("\n")
|
||||
|
||||
|
||||
def _unpack_source_ids(data: Mapping) -> list[str]:
|
||||
value = data.get("source_id", None)
|
||||
return [] if value is None else value.split(", ")
|
||||
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user