Compare commits

..

33 Commits

Author SHA1 Message Date
ff2c70608d Fix: judge index exist before delete memory message. (#12318)
### What problem does this PR solve?

Check that the index exists before deleting memory messages.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-30 15:54:07 +08:00
5903d1c8f1 Feat: GitHub connector (#12314)
### What problem does this PR solve?

Feat: GitHub connector

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-30 15:09:52 +08:00
f0392e7501 Fix IDE warnings (#12315)
### What problem does this PR solve?

As title.

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-30 15:04:09 +08:00
4037788e0c Fix: Dataset parse error (#12310)
### What problem does this PR solve?

Fix: Dataset parse error

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-30 13:08:20 +08:00
59884ab0fb Fix TypeError in meta_filter when using numeric metadata (#12286)
The filter_out function in metadata_utils.py was using a list of tuples
to evaluate conditions. Python eagerly evaluates all tuple elements when
constructing the list, causing "input in value" to be evaluated even
when the operator is "=". When input and value are floats (after numeric
conversion), this causes TypeError: "argument of type 'float' is not
iterable".

This change replaces the tuple list with an if-elif chain, ensuring that only
the matching condition is evaluated.

### What problem does this PR solve?

Fixes #12285

When using comparison operators like `=`, `>`, `<` with numeric
metadata, the `filter_out` function throws `TypeError("argument of type
'float' is not iterable")`. This is because Python eagerly evaluates all
tuple elements when constructing a list, causing `input in value` to be
evaluated even when the operator is `=`.
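
The pitfall is easy to reproduce on its own. The snippet below is a minimal, self-contained sketch (not the actual `filter_out` signature from `metadata_utils.py`), contrasting the eager tuple-list pattern with the if-elif fix:

```python
def buggy_match(op, input, value):
    # Building the list evaluates every branch up front, so "input in value"
    # runs even when op == "=", and floats are not iterable.
    conditions = [
        ("=", input == value),
        ("in", input in value),  # TypeError for numeric values
    ]
    return dict(conditions)[op]


def fixed_match(op, input, value):
    # Only the branch that matches the operator is evaluated.
    if op == "=":
        return input == value
    elif op == "in":
        return input in value
    raise ValueError(f"Unsupported operator: {op}")


print(fixed_match("=", 3.0, 3.0))  # True
# buggy_match("=", 3.0, 3.0)       # raises TypeError while the list is built
```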

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [ ] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [ ] Performance Improvement
- [ ] Other (please describe):
2025-12-30 11:56:48 +08:00
4a6d37f0e8 Fix: use async task to save memory (#12308)
### What problem does this PR solve?

Use an async task to save memory.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Jin Hai <haijin.chn@gmail.com>
2025-12-30 11:41:38 +08:00
731e2d5f26 api key delete bug - Bug #3045 (#12299)
Description:
Fixed an issue where deleting an API token would fail because it was
incorrectly using current_user.id as the tenant_id instead of querying
the actual tenant ID from UserTenantService.

Changes:

- Updated the `rm()` endpoint to fetch the correct `tenant_id` from `UserTenantService` before deleting the API token
- Added proper error handling with a try/except block
- Code style cleanup: consistent quote usage and formatting
Related Issue: #3045

https://github.com/infiniflow/ragflow/issues/3045

Co-authored-by: Mardani, Ramin <ramin.mardani@sscinc.com>
2025-12-30 11:27:04 +08:00
df3cbb9b9e Refactor code (#12305)
### What problem does this PR solve?

as title

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-30 11:09:18 +08:00
5402666b19 docs: fix typos (#12301)
### What problem does this PR solve?

fix typos

### Type of change

- [x] Documentation Update
2025-12-30 09:39:28 +08:00
4ec6a4e493 Feat: Remove the code that outputs jsonschema from the webhook.#10427 (#12297)
### What problem does this PR solve?

Feat: Remove the code that outputs jsonschema from the webhook. #10427

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-12-29 17:46:05 +08:00
2d5ad42128 docs: add optional proxy arguments for Docker build instructions (#12272)
### What problem does this PR solve?

Adds instructions for passing optional HTTP/HTTPS proxy arguments when
building the Docker image.

This helps users behind a proxy to successfully build the RAGFlow Docker
image without modifying the Dockerfile itself.

### Type of change

- [x] Documentation Update
2025-12-29 17:43:55 +08:00
dccda35f65 Fix: S3 parameter error (#12290)
### What problem does this PR solve?

Fix: S3 parameter error

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
2025-12-29 17:38:01 +08:00
d142b9095e Fix: pick message to delete (#12295)
### What problem does this PR solve?

Pick an unforgotten message to delete when no forgotten message is found.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-29 17:10:46 +08:00
c2c079886f Revert "Feat: github connector" (#12296)
Reverts infiniflow/ragflow#12292
2025-12-29 17:06:40 +08:00
c3ae1aaecd Feat: Gitlab connector (#12248)
### What problem does this PR solve?

Feat: Gitlab connector
Fix: submit button in dark mode

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
2025-12-29 17:05:20 +08:00
f099bc1236 Feat: github connector (#12292)
### What problem does this PR solve?

Feat: github connector

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-29 16:57:20 +08:00
0b5d1ebefa refactor: docling parser will close bytes io (#12280)
### What problem does this PR solve?

The docling parser now closes its BytesIO input.

### Type of change

- [x] Refactoring
2025-12-29 13:33:27 +08:00
082c2ed11c helm: improvements (#10976)
- fix(ingress): use root context ($) for fullname inside range
- fix(statefulset): use updateStrategy instead of strategy for
mysql/infinity/elasticsearch/opensearch
- feat(mysql): add external mode via mysql.enabled=false with env
MYSQL_HOST/PORT and MYSQL_USER (default root)
- feat(minio/redis): add external mode via *.enabled=false with env
*_HOST/PORT
- feat(global): add global.repo for image registry prefix and
global.imagePullSecrets for all pods
- feat: helper template ragflow.imageRepo to render image with global
repo
- chore(env): allow optional MINIO_HOST, MINIO_PASSWORD, REDIS_PASSWORD
(remove required); keep MYSQL_PASSWORD required
- docs(helm): add helm/README.md and update usage
- refactor(images): apply global repo to all components and init
containers
- test: align test busybox image with global repo helper

### What problem does this PR solve?


### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
- [x] Documentation Update
- [ ] Refactoring
- [ ] Performance Improvement
- [ ] Other (please describe):
2025-12-29 13:29:47 +08:00
a764f0a5b2 Feat: Add Asana data source integration and configuration options (#12239)
### What problem does this PR solve?

change: Add Asana data source integration and configuration options

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
2025-12-29 13:28:37 +08:00
Rin
651d9fff9f security: replace unsafe eval with ast.literal_eval in vision operators (#12236)
Addresses a potential RCE vulnerability in NormalizeImage by using
ast.literal_eval for safer string parsing.
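
For context: `ast.literal_eval` only parses Python literals (numbers, strings, lists, tuples, dicts, sets, booleans, `None`) and rejects anything else instead of executing it, which is what removes the code-execution path. A minimal illustration, unrelated to the NormalizeImage code itself:

```python
import ast

# Safe: literal containers are parsed without executing any code.
print(ast.literal_eval("[0.485, 0.456, 0.406]"))  # [0.485, 0.456, 0.406]

# Anything that is not a plain literal is rejected rather than evaluated.
try:
    ast.literal_eval("__import__('os').system('id')")
except ValueError as exc:
    print("rejected:", exc)
```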

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
2025-12-29 13:28:09 +08:00
fddfce303c Fix (sdk): ensure variables defined in rm_chunk API (#12274)
### What problem does this PR solve?

Fixes a bug in the `rm_chunk` SDK interface where an `UnboundLocalError` could
occur if `chunk_ids` is not provided in the request.

- `unique_chunk_ids` and `duplicate_messages` are now always initialized
  in the `else` branch when `chunk_ids` is missing.
- API behavior remains unchanged when `chunk_ids` is present.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-29 13:18:23 +08:00
a24fc8291b Fix: If there is an error message on the chat page, the subsequent message references will not display correctly. #12252 (#12283)
### What problem does this PR solve?

Fix: if an error message appears on the chat page, the references in subsequent
messages do not display correctly. #12252

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-29 12:58:12 +08:00
37e4485415 feat: add MDX file support (#12261)
Feat: add MDX file support  #12057 
### What problem does this PR solve?

<img width="1055" height="270" alt="image"
src="https://github.com/user-attachments/assets/a0ab49f9-7806-41cd-8a96-f593591ab36b"
/>

The page states that MDX files are supported, but uploading fails with
the error: "x.mdx: This type of file has not been supported yet!"
<img width="381" height="110" alt="image"
src="https://github.com/user-attachments/assets/4bbb7d08-cb47-416a-95fc-bc90b90fcc39"
/>


### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-29 12:54:31 +08:00
8d3f9d61da Fix: Delete chunk images on document parser config change. (#12262)
### What problem does this PR solve?

Modifying a document’s parser config previously left behind obsolete
chunk images. If the dataset isn’t manually deleted, these images
accumulate and waste storage. This PR fixes the issue by automatically
removing associated images when the parser config changes.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-29 12:54:11 +08:00
27c55f6514 Fix the consistency of ts and datetime (#12288)
### What problem does this PR solve?

#12279
#11942 
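
For background (the pattern is visible in the service diffs further down): `create_time` and `update_time` were each filled by a separate `current_timestamp()` call, so the two fields could differ slightly on a fresh insert; capturing the value once keeps them identical. A minimal sketch, assuming a millisecond-resolution helper like the one in `common.time_utils`:

```python
import time


def current_timestamp() -> int:
    # Stand-in for common.time_utils.current_timestamp (assumed millisecond resolution).
    return int(time.time() * 1000)


# Before: two separate calls can straddle a tick, so the fields may disagree.
record = {"create_time": current_timestamp(), "update_time": current_timestamp()}

# After: capture once and reuse, so create_time == update_time on insert.
ts = current_timestamp()
record = {"create_time": ts, "update_time": ts}
```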

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-29 12:37:13 +08:00
9883c572cd Refactor: keep timestamp consistency (#12279)
### What problem does this PR solve?

keep timestamp consistency

### Type of change

- [x] Refactoring
2025-12-29 12:02:43 +08:00
f9619defcc Fix: init memory size from es (#12282)
### What problem does this PR solve?

Return early when the index does not exist.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-29 12:01:45 +08:00
01f0ced1e6 Fix IDE warnings (#12281)
### What problem does this PR solve?

As title

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-29 12:01:18 +08:00
647fb115a0 Fix: Data-source S3 page style (#12255)
### What problem does this PR solve?

Fix: Data-source S3 page style

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-29 09:46:35 +08:00
2114b9e3ad Update deploy_local_llm.mdx (#12276)
### Type of change

- [x] Documentation Update
2025-12-28 19:46:50 +08:00
45b96acf6b Update deploy_local_llm.mdx vllm guide picture (#12275)
### Type of change
- [x] Documentation Update
2025-12-28 19:29:33 +08:00
Rin
3305215144 docs: add security warnings for default passwords in .env (#12250)
Enhances security by adding explicit warnings in the environment
template about changing default passwords for MySQL, Elasticsearch, and
MinIO before deployment.
2025-12-28 14:02:17 +08:00
86b03f399a Fix error in docs (#12269)
### What problem does this PR solve?

As title

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-28 11:55:52 +08:00
149 changed files with 4659 additions and 1249 deletions

View File

@@ -233,7 +233,7 @@ releases! 🌟
 * Running on all addresses (0.0.0.0)
 ```
-> If you skip this confirmation step and directly log in to RAGFlow, your browser may prompt a `network anormal`
+> If you skip this confirmation step and directly log in to RAGFlow, your browser may prompt a `network abnormal`
 > error because, at that moment, your RAGFlow may not be fully initialized.
 >
 5. In your web browser, enter the IP address of your server and log in to RAGFlow.
@@ -303,6 +303,15 @@ cd ragflow/
 docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly .
 ```
+Or if you are behind a proxy, you can pass proxy arguments:
+```bash
+docker build --platform linux/amd64 \
+--build-arg http_proxy=http://YOUR_PROXY:PORT \
+--build-arg https_proxy=http://YOUR_PROXY:PORT \
+-f Dockerfile -t infiniflow/ragflow:nightly .
+```
 ## 🔨 Launch service from source for development
 1. Install `uv` and `pre-commit`, or skip this step if they are already installed:

View File

@@ -233,7 +233,7 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io).
 * Running on all addresses (0.0.0.0)
 ```
-> Jika Anda melewatkan langkah ini dan langsung login ke RAGFlow, browser Anda mungkin menampilkan error `network anormal`
+> Jika Anda melewatkan langkah ini dan langsung login ke RAGFlow, browser Anda mungkin menampilkan error `network abnormal`
 > karena RAGFlow mungkin belum sepenuhnya siap.
 >
 2. Buka browser web Anda, masukkan alamat IP server Anda, dan login ke RAGFlow.
@@ -277,6 +277,15 @@ cd ragflow/
 docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly .
 ```
+Jika berada di belakang proxy, Anda dapat melewatkan argumen proxy:
+```bash
+docker build --platform linux/amd64 \
+--build-arg http_proxy=http://YOUR_PROXY:PORT \
+--build-arg https_proxy=http://YOUR_PROXY:PORT \
+-f Dockerfile -t infiniflow/ragflow:nightly .
+```
 ## 🔨 Menjalankan Aplikasi dari untuk Pengembangan
 1. Instal `uv` dan `pre-commit`, atau lewati langkah ini jika sudah terinstal:

View File

@@ -277,6 +277,15 @@ cd ragflow/
 docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly .
 ```
+プロキシ環境下にいる場合は、プロキシ引数を指定できます:
+```bash
+docker build --platform linux/amd64 \
+--build-arg http_proxy=http://YOUR_PROXY:PORT \
+--build-arg https_proxy=http://YOUR_PROXY:PORT \
+-f Dockerfile -t infiniflow/ragflow:nightly .
+```
 ## 🔨 ソースコードからサービスを起動する方法
 1. `uv` と `pre-commit` をインストールする。すでにインストールされている場合は、このステップをスキップしてください:

View File

@@ -214,7 +214,7 @@
 * Running on all addresses (0.0.0.0)
 ```
-> 만약 확인 단계를 건너뛰고 바로 RAGFlow에 로그인하면, RAGFlow가 완전히 초기화되지 않았기 때문에 브라우저에서 `network anormal` 오류가 발생할 수 있습니다.
+> 만약 확인 단계를 건너뛰고 바로 RAGFlow에 로그인하면, RAGFlow가 완전히 초기화되지 않았기 때문에 브라우저에서 `network abnormal` 오류가 발생할 수 있습니다.
 2. 웹 브라우저에 서버의 IP 주소를 입력하고 RAGFlow에 로그인하세요.
 > 기본 설정을 사용할 경우, `http://IP_OF_YOUR_MACHINE`만 입력하면 됩니다 (포트 번호는 제외). 기본 HTTP 서비스 포트 `80`은 기본 구성으로 사용할 때 생략할 수 있습니다.
@@ -271,6 +271,15 @@ cd ragflow/
 docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly .
 ```
+프록시 환경인 경우, 프록시 인수를 전달할 수 있습니다:
+```bash
+docker build --platform linux/amd64 \
+--build-arg http_proxy=http://YOUR_PROXY:PORT \
+--build-arg https_proxy=http://YOUR_PROXY:PORT \
+-f Dockerfile -t infiniflow/ragflow:nightly .
+```
 ## 🔨 소스 코드로 서비스를 시작합니다.
 1. `uv` 와 `pre-commit` 을 설치하거나, 이미 설치된 경우 이 단계를 건너뜁니다:

View File

@@ -232,7 +232,7 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io).
 * Rodando em todos os endereços (0.0.0.0)
 ```
-> Se você pular essa etapa de confirmação e acessar diretamente o RAGFlow, seu navegador pode exibir um erro `network anormal`, pois, nesse momento, seu RAGFlow pode não estar totalmente inicializado.
+> Se você pular essa etapa de confirmação e acessar diretamente o RAGFlow, seu navegador pode exibir um erro `network abnormal`, pois, nesse momento, seu RAGFlow pode não estar totalmente inicializado.
 >
 5. No seu navegador, insira o endereço IP do seu servidor e faça login no RAGFlow.
@@ -294,6 +294,15 @@ cd ragflow/
 docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly .
 ```
+Se você estiver atrás de um proxy, pode passar argumentos de proxy:
+```bash
+docker build --platform linux/amd64 \
+--build-arg http_proxy=http://YOUR_PROXY:PORT \
+--build-arg https_proxy=http://YOUR_PROXY:PORT \
+-f Dockerfile -t infiniflow/ragflow:nightly .
+```
 ## 🔨 Lançar o serviço a partir do código-fonte para desenvolvimento
 1. Instale o `uv` e o `pre-commit`, ou pule esta etapa se eles já estiverem instalados:

View File

@@ -125,7 +125,7 @@
 ### 🍔 **相容各類異質資料來源**
 - 支援豐富的文件類型,包括 Word 文件、PPT、excel 表格、txt 檔案、圖片、PDF、影印件、印件、結構化資料、網頁等。
 ### 🛀 **全程無憂、自動化的 RAG 工作流程**
@@ -237,7 +237,7 @@
 * Running on all addresses (0.0.0.0)
 ```
-> 如果您跳過這一步驟系統確認步驟就登入 RAGFlow你的瀏覽器有可能會提示 `network anormal` 或 `網路異常`,因為 RAGFlow 可能並未完全啟動成功。
+> 如果您跳過這一步驟系統確認步驟就登入 RAGFlow你的瀏覽器有可能會提示 `network abnormal` 或 `網路異常`,因為 RAGFlow 可能並未完全啟動成功。
 >
 5. 在你的瀏覽器中輸入你的伺服器對應的 IP 位址並登入 RAGFlow。
@@ -303,6 +303,15 @@ cd ragflow/
 docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly .
 ```
+若您位於代理環境,可傳遞代理參數:
+```bash
+docker build --platform linux/amd64 \
+--build-arg http_proxy=http://YOUR_PROXY:PORT \
+--build-arg https_proxy=http://YOUR_PROXY:PORT \
+-f Dockerfile -t infiniflow/ragflow:nightly .
+```
 ## 🔨 以原始碼啟動服務
 1. 安裝 `uv` 和 `pre-commit`。如已安裝,可跳過此步驟:

View File

@@ -238,7 +238,7 @@
 * Running on all addresses (0.0.0.0)
 ```
-> 如果您在没有看到上面的提示信息出来之前,就尝试登录 RAGFlow你的浏览器有可能会提示 `network anormal` 或 `网络异常`。
+> 如果您在没有看到上面的提示信息出来之前,就尝试登录 RAGFlow你的浏览器有可能会提示 `network abnormal` 或 `网络异常`。
 5. 在你的浏览器中输入你的服务器对应的 IP 地址并登录 RAGFlow。
 > 上面这个例子中,您只需输入 http://IP_OF_YOUR_MACHINE 即可:未改动过配置则无需输入端口(默认的 HTTP 服务端口 80
@@ -302,6 +302,15 @@ cd ragflow/
 docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly .
 ```
+如果您处在代理环境下,可以传递代理参数:
+```bash
+docker build --platform linux/amd64 \
+--build-arg http_proxy=http://YOUR_PROXY:PORT \
+--build-arg https_proxy=http://YOUR_PROXY:PORT \
+-f Dockerfile -t infiniflow/ragflow:nightly .
+```
 ## 🔨 以源代码启动服务
 1. 安装 `uv` 和 `pre-commit`。如已经安装,可跳过本步骤:

View File

@@ -33,7 +33,7 @@ from common.connection_utils import timeout
 from common.misc_utils import get_uuid
 from common import settings
-from api.db.joint_services.memory_message_service import save_to_memory
+from api.db.joint_services.memory_message_service import queue_save_to_memory_task
 class MessageParam(ComponentParamBase):
@@ -437,17 +437,4 @@ class Message(ComponentBase):
 "user_input": self._canvas.get_sys_query(),
 "agent_response": content
 }
-res = []
-for memory_id in self._param.memory_ids:
-success, msg = await save_to_memory(memory_id, message_dict)
-res.append({
-"memory_id": memory_id,
-"success": success,
-"msg": msg
-})
-if all([r["success"] for r in res]):
-return True, "Successfully added to memories."
-error_text = "Some messages failed to add. " + " ".join([f"Add to memory {r['memory_id']} failed, detail: {r['msg']}" for r in res if not r["success"]])
-logging.error(error_text)
-return False, error_text
+return await queue_save_to_memory_task(self._param.memory_ids, message_dict)

View File

@@ -746,6 +746,7 @@ async def change_parser():
 tenant_id = DocumentService.get_tenant_id(req["doc_id"])
 if not tenant_id:
 return get_data_error_result(message="Tenant not found!")
+DocumentService.delete_chunk_images(doc, tenant_id)
 if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id):
 settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
 return None

View File

@@ -159,7 +159,8 @@ async def delete_memory(memory_id):
 return get_json_result(message=True, code=RetCode.NOT_FOUND)
 try:
 MemoryService.delete_memory(memory_id)
-MessageService.delete_message({"memory_id": memory_id}, memory.tenant_id, memory_id)
+if MessageService.has_index(memory.tenant_id, memory_id):
+MessageService.delete_message({"memory_id": memory_id}, memory.tenant_id, memory_id)
 return get_json_result(message=True)
 except Exception as e:
 logging.error(e)

View File

@@ -1286,6 +1286,9 @@ async def rm_chunk(tenant_id, dataset_id, document_id):
 if "chunk_ids" in req:
 unique_chunk_ids, duplicate_messages = check_duplicate_ids(req["chunk_ids"], "chunk")
 condition["id"] = unique_chunk_ids
+else:
+unique_chunk_ids = []
+duplicate_messages = []
 chunk_number = settings.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
 if chunk_number != 0:
 DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)

View File

@@ -177,7 +177,7 @@ def healthz():
 return jsonify(result), (200 if all_ok else 500)
 @manager.route("/ping", methods=["GET"]) # noqa: F821
 def ping():
 return "pong", 200
@@ -213,7 +213,7 @@ def new_token():
 if not tenants:
 return get_data_error_result(message="Tenant not found!")
-tenant_id = [tenant for tenant in tenants if tenant.role == 'owner'][0].tenant_id
+tenant_id = [tenant for tenant in tenants if tenant.role == "owner"][0].tenant_id
 obj = {
 "tenant_id": tenant_id,
 "token": generate_confirmation_token(),
@@ -268,13 +268,12 @@ def token_list():
 if not tenants:
 return get_data_error_result(message="Tenant not found!")
-tenant_id = [tenant for tenant in tenants if tenant.role == 'owner'][0].tenant_id
+tenant_id = [tenant for tenant in tenants if tenant.role == "owner"][0].tenant_id
 objs = APITokenService.query(tenant_id=tenant_id)
 objs = [o.to_dict() for o in objs]
 for o in objs:
 if not o["beta"]:
-o["beta"] = generate_confirmation_token().replace(
-"ragflow-", "")[:32]
+o["beta"] = generate_confirmation_token().replace("ragflow-", "")[:32]
 APITokenService.filter_update([APIToken.tenant_id == tenant_id, APIToken.token == o["token"]], o)
 return get_json_result(data=objs)
 except Exception as e:
@@ -307,13 +306,19 @@ def rm(token):
 type: boolean
 description: Deletion status.
 """
-APITokenService.filter_delete(
-[APIToken.tenant_id == current_user.id, APIToken.token == token]
-)
-return get_json_result(data=True)
+try:
+tenants = UserTenantService.query(user_id=current_user.id)
+if not tenants:
+return get_data_error_result(message="Tenant not found!")
+tenant_id = tenants[0].tenant_id
+APITokenService.filter_delete([APIToken.tenant_id == tenant_id, APIToken.token == token])
+return get_json_result(data=True)
+except Exception as e:
+return server_error_response(e)
-@manager.route('/config', methods=['GET']) # noqa: F821
+@manager.route("/config", methods=["GET"]) # noqa: F821
 def get_config():
 """
 Get system configuration.
@@ -330,6 +335,4 @@ def get_config():
 type: integer 0 means disabled, 1 means enabled
 description: Whether user registration is enabled
 """
-return get_json_result(data={
-"registerEnabled": settings.REGISTER_ENABLED
-})
+return get_json_result(data={"registerEnabled": settings.REGISTER_ENABLED})

View File

@@ -16,9 +16,14 @@
 import logging
 from typing import List
+from api.db.services.task_service import TaskService
+from common import settings
 from common.time_utils import current_timestamp, timestamp_to_date, format_iso_8601_to_ymd_hms
 from common.constants import MemoryType, LLMType
 from common.doc_store.doc_store_base import FusionExpr
+from common.misc_utils import get_uuid
+from api.db.db_utils import bulk_insert_into_db
+from api.db.db_models import Task
 from api.db.services.memory_service import MemoryService
 from api.db.services.tenant_llm_service import TenantLLMService
 from api.db.services.llm_service import LLMBundle
@@ -82,32 +87,44 @@ async def save_to_memory(memory_id: str, message_dict: dict):
 "forget_at": None,
 "status": True
 } for content in extracted_content]]
-embedding_model = LLMBundle(tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
-vector_list, _ = embedding_model.encode([msg["content"] for msg in message_list])
-for idx, msg in enumerate(message_list):
-msg["content_embed"] = vector_list[idx]
-vector_dimension = len(vector_list[0])
-if not MessageService.has_index(tenant_id, memory_id):
-created = MessageService.create_index(tenant_id, memory_id, vector_size=vector_dimension)
-if not created:
-return False, "Failed to create message index."
-new_msg_size = sum([MessageService.calculate_message_size(m) for m in message_list])
-current_memory_size = get_memory_size_cache(memory_id, tenant_id)
-if new_msg_size + current_memory_size > memory.memory_size:
-size_to_delete = current_memory_size + new_msg_size - memory.memory_size
-if memory.forgetting_policy == "FIFO":
-message_ids_to_delete, delete_size = MessageService.pick_messages_to_delete_by_fifo(memory_id, tenant_id, size_to_delete)
-MessageService.delete_message({"message_id": message_ids_to_delete}, tenant_id, memory_id)
-decrease_memory_size_cache(memory_id, delete_size)
-else:
-return False, "Failed to insert message into memory. Memory size reached limit and cannot decide which to delete."
-fail_cases = MessageService.insert_message(message_list, tenant_id, memory_id)
-if fail_cases:
-return False, "Failed to insert message into memory. Details: " + "; ".join(fail_cases)
-increase_memory_size_cache(memory_id, new_msg_size)
-return True, "Message saved successfully."
+return await embed_and_save(memory, message_list)
+async def save_extracted_to_memory_only(memory_id: str, message_dict, source_message_id: int):
+memory = MemoryService.get_by_memory_id(memory_id)
if not memory:
return False, f"Memory '{memory_id}' not found."
if memory.memory_type == MemoryType.RAW.value:
return True, f"Memory '{memory_id}' don't need to extract."
tenant_id = memory.tenant_id
extracted_content = await extract_by_llm(
tenant_id,
memory.llm_id,
{"temperature": memory.temperature},
get_memory_type_human(memory.memory_type),
message_dict.get("user_input", ""),
message_dict.get("agent_response", "")
)
message_list = [{
"message_id": REDIS_CONN.generate_auto_increment_id(namespace="memory"),
"message_type": content["message_type"],
"source_id": source_message_id,
"memory_id": memory_id,
"user_id": "",
"agent_id": message_dict["agent_id"],
"session_id": message_dict["session_id"],
"content": content["content"],
"valid_at": content["valid_at"],
"invalid_at": content["invalid_at"] if content["invalid_at"] else None,
"forget_at": None,
"status": True
} for content in extracted_content]
if not message_list:
return True, "No memory extracted from raw message."
return await embed_and_save(memory, message_list)
 async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory_type: List[str], user_input: str,
@@ -136,6 +153,36 @@ async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory
 } for message_type, extracted_content_list in res_json.items() for extracted_content in extracted_content_list]
async def embed_and_save(memory, message_list: list[dict]):
embedding_model = LLMBundle(memory.tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
vector_list, _ = embedding_model.encode([msg["content"] for msg in message_list])
for idx, msg in enumerate(message_list):
msg["content_embed"] = vector_list[idx]
vector_dimension = len(vector_list[0])
if not MessageService.has_index(memory.tenant_id, memory.id):
created = MessageService.create_index(memory.tenant_id, memory.id, vector_size=vector_dimension)
if not created:
return False, "Failed to create message index."
new_msg_size = sum([MessageService.calculate_message_size(m) for m in message_list])
current_memory_size = get_memory_size_cache(memory.tenant_id, memory.id)
if new_msg_size + current_memory_size > memory.memory_size:
size_to_delete = current_memory_size + new_msg_size - memory.memory_size
if memory.forgetting_policy == "FIFO":
message_ids_to_delete, delete_size = MessageService.pick_messages_to_delete_by_fifo(memory.id, memory.tenant_id,
size_to_delete)
MessageService.delete_message({"message_id": message_ids_to_delete}, memory.tenant_id, memory.id)
decrease_memory_size_cache(memory.id, delete_size)
else:
return False, "Failed to insert message into memory. Memory size reached limit and cannot decide which to delete."
fail_cases = MessageService.insert_message(message_list, memory.tenant_id, memory.id)
if fail_cases:
return False, "Failed to insert message into memory. Details: " + "; ".join(fail_cases)
increase_memory_size_cache(memory.id, new_msg_size)
return True, "Message saved successfully."
 def query_message(filter_dict: dict, params: dict):
 """
 :param filter_dict: {
@@ -231,3 +278,112 @@ def init_memory_size_cache():
 def judge_system_prompt_is_default(system_prompt: str, memory_type: int|list[str]):
 memory_type_list = memory_type if isinstance(memory_type, list) else get_memory_type_human(memory_type)
 return system_prompt == PromptAssembler.assemble_system_prompt({"memory_type": memory_type_list})
async def queue_save_to_memory_task(memory_ids: list[str], message_dict: dict):
"""
:param memory_ids:
:param message_dict: {
"user_id": str,
"agent_id": str,
"session_id": str,
"user_input": str,
"agent_response": str
}
"""
def new_task(_memory_id: str, _source_id: int):
return {
"id": get_uuid(),
"doc_id": _memory_id,
"task_type": "memory",
"progress": 0.0,
"digest": str(_source_id)
}
not_found_memory = []
failed_memory = []
for memory_id in memory_ids:
memory = MemoryService.get_by_memory_id(memory_id)
if not memory:
not_found_memory.append(memory_id)
continue
raw_message_id = REDIS_CONN.generate_auto_increment_id(namespace="memory")
raw_message = {
"message_id": raw_message_id,
"message_type": MemoryType.RAW.name.lower(),
"source_id": 0,
"memory_id": memory_id,
"user_id": "",
"agent_id": message_dict["agent_id"],
"session_id": message_dict["session_id"],
"content": f"User Input: {message_dict.get('user_input')}\nAgent Response: {message_dict.get('agent_response')}",
"valid_at": timestamp_to_date(current_timestamp()),
"invalid_at": None,
"forget_at": None,
"status": True
}
res, msg = await embed_and_save(memory, [raw_message])
if not res:
failed_memory.append({"memory_id": memory_id, "fail_msg": msg})
continue
task = new_task(memory_id, raw_message_id)
bulk_insert_into_db(Task, [task], replace_on_conflict=True)
task_message = {
"id": task["id"],
"task_id": task["id"],
"task_type": task["task_type"],
"memory_id": memory_id,
"source_id": raw_message_id,
"message_dict": message_dict
}
if not REDIS_CONN.queue_product(settings.get_svr_queue_name(priority=0), message=task_message):
failed_memory.append({"memory_id": memory_id, "fail_msg": "Can't access Redis."})
error_msg = ""
if not_found_memory:
error_msg = f"Memory {not_found_memory} not found."
if failed_memory:
error_msg += "".join([f"Memory {fm['memory_id']} failed. Detail: {fm['fail_msg']}" for fm in failed_memory])
if error_msg:
return False, error_msg
return True, "All add to task."
async def handle_save_to_memory_task(task_param: dict):
"""
:param task_param: {
"id": task_id
"memory_id": id
"source_id": id
"message_dict": {
"user_id": str,
"agent_id": str,
"session_id": str,
"user_input": str,
"agent_response": str
}
}
"""
_, task = TaskService.get_by_id(task_param["id"])
if not task:
return False, f"Task {task_param['id']} is not found."
if task.progress == -1:
return False, f"Task {task_param['id']} is already failed."
now_time = current_timestamp()
TaskService.update_by_id(task_param["id"], {"begin_at": timestamp_to_date(now_time)})
memory_id = task_param["memory_id"]
source_id = task_param["source_id"]
message_dict = task_param["message_dict"]
success, msg = await save_extracted_to_memory_only(memory_id, message_dict, source_id)
if success:
TaskService.update_progress(task.id, {"progress": 1.0, "progress_msg": msg})
return True, msg
logging.error(msg)
TaskService.update_progress(task.id, {"progress": -1, "progress_msg": None})
return False, msg

View File

@@ -342,21 +342,7 @@ class DocumentService(CommonService):
 cls.clear_chunk_num(doc.id)
 try:
 TaskService.filter_delete([Task.doc_id == doc.id])
-page = 0
-page_size = 1000
-all_chunk_ids = []
-while True:
-chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
-page * page_size, page_size, search.index_name(tenant_id),
-[doc.kb_id])
-chunk_ids = settings.docStoreConn.get_doc_ids(chunks)
-if not chunk_ids:
-break
-all_chunk_ids.extend(chunk_ids)
-page += 1
-for cid in all_chunk_ids:
-if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
-settings.STORAGE_IMPL.rm(doc.kb_id, cid)
+cls.delete_chunk_images(doc, tenant_id)
 if doc.thumbnail and not doc.thumbnail.startswith(IMG_BASE64_PREFIX):
 if settings.STORAGE_IMPL.obj_exist(doc.kb_id, doc.thumbnail):
 settings.STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail)
@@ -378,6 +364,23 @@
 pass
 return cls.delete_by_id(doc.id)
@classmethod
@DB.connection_context()
def delete_chunk_images(cls, doc, tenant_id):
page = 0
page_size = 1000
while True:
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
page * page_size, page_size, search.index_name(tenant_id),
[doc.kb_id])
chunk_ids = settings.docStoreConn.get_doc_ids(chunks)
if not chunk_ids:
break
for cid in chunk_ids:
if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
settings.STORAGE_IMPL.rm(doc.kb_id, cid)
page += 1
 @classmethod
 @DB.connection_context()
 def get_newly_uploaded(cls):

View File

@@ -65,6 +65,7 @@ class EvaluationService(CommonService):
 (success, dataset_id or error_message)
 """
 try:
+timestamp = current_timestamp()
 dataset_id = get_uuid()
 dataset = {
 "id": dataset_id,
@@ -73,8 +74,8 @@
 "description": description,
 "kb_ids": kb_ids,
 "created_by": user_id,
-"create_time": current_timestamp(),
-"update_time": current_timestamp(),
+"create_time": timestamp,
+"update_time": timestamp,
 "status": StatusEnum.VALID.value
 }

View File

@@ -64,10 +64,13 @@ class TenantLangfuseService(CommonService):
 @classmethod
 def save(cls, **kwargs):
-kwargs["create_time"] = current_timestamp()
-kwargs["create_date"] = datetime_format(datetime.now())
-kwargs["update_time"] = current_timestamp()
-kwargs["update_date"] = datetime_format(datetime.now())
+current_ts = current_timestamp()
+current_date = datetime_format(datetime.now())
+kwargs["create_time"] = current_ts
+kwargs["create_date"] = current_date
+kwargs["update_time"] = current_ts
+kwargs["update_date"] = current_date
 obj = cls.model.create(**kwargs)
 return obj

View File

@@ -169,11 +169,12 @@ class PipelineOperationLogService(CommonService):
 operation_status=operation_status,
 avatar=avatar,
 )
-log["create_time"] = current_timestamp()
-log["create_date"] = datetime_format(datetime.now())
-log["update_time"] = current_timestamp()
-log["update_date"] = datetime_format(datetime.now())
+timestamp = current_timestamp()
+datetime_now = datetime_format(datetime.now())
+log["create_time"] = timestamp
+log["create_date"] = datetime_now
+log["update_time"] = timestamp
+log["update_date"] = datetime_now
 with DB.atomic():
 obj = cls.save(**log)

View File

@@ -28,10 +28,13 @@ class SearchService(CommonService):
 @classmethod
 def save(cls, **kwargs):
-kwargs["create_time"] = current_timestamp()
-kwargs["create_date"] = datetime_format(datetime.now())
-kwargs["update_time"] = current_timestamp()
-kwargs["update_date"] = datetime_format(datetime.now())
+current_ts = current_timestamp()
+current_date = datetime_format(datetime.now())
+kwargs["create_time"] = current_ts
+kwargs["create_date"] = current_date
+kwargs["update_time"] = current_ts
+kwargs["update_date"] = current_date
 obj = cls.model.create(**kwargs)
 return obj

View File

@@ -116,10 +116,13 @@ class UserService(CommonService):
 kwargs["password"] = generate_password_hash(
 str(kwargs["password"]))
-kwargs["create_time"] = current_timestamp()
-kwargs["create_date"] = datetime_format(datetime.now())
-kwargs["update_time"] = current_timestamp()
-kwargs["update_date"] = datetime_format(datetime.now())
+current_ts = current_timestamp()
+current_date = datetime_format(datetime.now())
+kwargs["create_time"] = current_ts
+kwargs["create_date"] = current_date
+kwargs["update_time"] = current_ts
+kwargs["update_date"] = current_date
 obj = cls.model(**kwargs).save(force_insert=True)
 return obj

View File

@@ -42,7 +42,7 @@ def filename_type(filename):
 if re.match(r".*\.pdf$", filename):
 return FileType.PDF.value
-if re.match(r".*\.(msg|eml|doc|docx|ppt|pptx|yml|xml|htm|json|jsonl|ldjson|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|html|sql)$", filename):
+if re.match(r".*\.(msg|eml|doc|docx|ppt|pptx|yml|xml|htm|json|jsonl|ldjson|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md|mdx|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|html|sql)$", filename):
 return FileType.DOC.value
 if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus)$", filename):

View File

@@ -69,6 +69,7 @@ CONTENT_TYPE_MAP = {
 # Web
 "md": "text/markdown",
 "markdown": "text/markdown",
+"mdx": "text/markdown",
 "htm": "text/html",
 "html": "text/html",
 "json": "application/json",

View File

@@ -129,7 +129,9 @@ class FileSource(StrEnum):
 OCI_STORAGE = "oci_storage"
 GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
 AIRTABLE = "airtable"
+ASANA = "asana"
+GITHUB = "github"
+GITLAB = "gitlab"
 class PipelineTaskType(StrEnum):
 PARSE = "Parse"
@@ -137,6 +139,7 @@
 RAPTOR = "RAPTOR"
 GRAPH_RAG = "GraphRAG"
 MINDMAP = "Mindmap"
+MEMORY = "Memory"
 VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR,

View File

@@ -37,6 +37,7 @@ from .teams_connector import TeamsConnector
 from .webdav_connector import WebDAVConnector
 from .moodle_connector import MoodleConnector
 from .airtable_connector import AirtableConnector
+from .asana_connector import AsanaConnector
 from .config import BlobType, DocumentSource
 from .models import Document, TextSection, ImageSection, BasicExpertInfo
 from .exceptions import (
@@ -73,4 +74,5 @@ __all__ = [
 "InsufficientPermissionsError",
 "UnexpectedValidationError",
 "AirtableConnector",
+"AsanaConnector",
 ]

View File

@@ -0,0 +1,454 @@
from collections.abc import Iterator
import time
from datetime import datetime
import logging
from typing import Any, Dict
import asana
import requests
from common.data_source.config import CONTINUE_ON_CONNECTOR_FAILURE, INDEX_BATCH_SIZE, DocumentSource
from common.data_source.interfaces import LoadConnector, PollConnector
from common.data_source.models import Document, GenerateDocumentsOutput, SecondsSinceUnixEpoch
from common.data_source.utils import extract_size_bytes, get_file_ext
# https://github.com/Asana/python-asana/tree/master?tab=readme-ov-file#documentation-for-api-endpoints
class AsanaTask:
def __init__(
self,
id: str,
title: str,
text: str,
link: str,
last_modified: datetime,
project_gid: str,
project_name: str,
) -> None:
self.id = id
self.title = title
self.text = text
self.link = link
self.last_modified = last_modified
self.project_gid = project_gid
self.project_name = project_name
def __str__(self) -> str:
return f"ID: {self.id}\nTitle: {self.title}\nLast modified: {self.last_modified}\nText: {self.text}"
class AsanaAPI:
def __init__(
self, api_token: str, workspace_gid: str, team_gid: str | None
) -> None:
self._user = None
self.workspace_gid = workspace_gid
self.team_gid = team_gid
self.configuration = asana.Configuration()
self.api_client = asana.ApiClient(self.configuration)
self.tasks_api = asana.TasksApi(self.api_client)
self.attachments_api = asana.AttachmentsApi(self.api_client)
self.stories_api = asana.StoriesApi(self.api_client)
self.users_api = asana.UsersApi(self.api_client)
self.project_api = asana.ProjectsApi(self.api_client)
self.project_memberships_api = asana.ProjectMembershipsApi(self.api_client)
self.workspaces_api = asana.WorkspacesApi(self.api_client)
self.api_error_count = 0
self.configuration.access_token = api_token
self.task_count = 0
def get_tasks(
self, project_gids: list[str] | None, start_date: str
) -> Iterator[AsanaTask]:
"""Get all tasks from the projects with the given gids that were modified since the given date.
If project_gids is None, get all tasks from all projects in the workspace."""
logging.info("Starting to fetch Asana projects")
projects = self.project_api.get_projects(
opts={
"workspace": self.workspace_gid,
"opt_fields": "gid,name,archived,modified_at",
}
)
start_seconds = int(time.mktime(datetime.now().timetuple()))
projects_list = []
project_count = 0
for project_info in projects:
project_gid = project_info["gid"]
if project_gids is None or project_gid in project_gids:
projects_list.append(project_gid)
else:
logging.debug(
f"Skipping project: {project_gid} - not in accepted project_gids"
)
project_count += 1
if project_count % 100 == 0:
logging.info(f"Processed {project_count} projects")
logging.info(f"Found {len(projects_list)} projects to process")
for project_gid in projects_list:
for task in self._get_tasks_for_project(
project_gid, start_date, start_seconds
):
yield task
logging.info(f"Completed fetching {self.task_count} tasks from Asana")
if self.api_error_count > 0:
logging.warning(
f"Encountered {self.api_error_count} API errors during task fetching"
)
def _get_tasks_for_project(
self, project_gid: str, start_date: str, start_seconds: int
) -> Iterator[AsanaTask]:
project = self.project_api.get_project(project_gid, opts={})
project_name = project.get("name", project_gid)
team = project.get("team") or {}
team_gid = team.get("gid")
if project.get("archived"):
logging.info(f"Skipping archived project: {project_name} ({project_gid})")
return
if not team_gid:
logging.info(
f"Skipping project without a team: {project_name} ({project_gid})"
)
return
if project.get("privacy_setting") == "private":
if self.team_gid and team_gid != self.team_gid:
logging.info(
f"Skipping private project not in configured team: {project_name} ({project_gid})"
)
return
logging.info(
f"Processing private project in configured team: {project_name} ({project_gid})"
)
simple_start_date = start_date.split(".")[0].split("+")[0]
logging.info(
f"Fetching tasks modified since {simple_start_date} for project: {project_name} ({project_gid})"
)
opts = {
"opt_fields": "name,memberships,memberships.project,completed_at,completed_by,created_at,"
"created_by,custom_fields,dependencies,due_at,due_on,external,html_notes,liked,likes,"
"modified_at,notes,num_hearts,parent,projects,resource_subtype,resource_type,start_on,"
"workspace,permalink_url",
"modified_since": start_date,
}
tasks_from_api = self.tasks_api.get_tasks_for_project(project_gid, opts)
for data in tasks_from_api:
self.task_count += 1
if self.task_count % 10 == 0:
end_seconds = time.mktime(datetime.now().timetuple())
runtime_seconds = end_seconds - start_seconds
if runtime_seconds > 0:
logging.info(
f"Processed {self.task_count} tasks in {runtime_seconds:.0f} seconds "
f"({self.task_count / runtime_seconds:.2f} tasks/second)"
)
logging.debug(f"Processing Asana task: {data['name']}")
text = self._construct_task_text(data)
try:
text += self._fetch_and_add_comments(data["gid"])
last_modified_date = self.format_date(data["modified_at"])
text += f"Last modified: {last_modified_date}\n"
task = AsanaTask(
id=data["gid"],
title=data["name"],
text=text,
link=data["permalink_url"],
last_modified=datetime.fromisoformat(data["modified_at"]),
project_gid=project_gid,
project_name=project_name,
)
yield task
except Exception:
logging.error(
f"Error processing task {data['gid']} in project {project_gid}",
exc_info=True,
)
self.api_error_count += 1
def _construct_task_text(self, data: Dict) -> str:
text = f"{data['name']}\n\n"
if data["notes"]:
text += f"{data['notes']}\n\n"
if data["created_by"] and data["created_by"]["gid"]:
creator = self.get_user(data["created_by"]["gid"])["name"]
created_date = self.format_date(data["created_at"])
text += f"Created by: {creator} on {created_date}\n"
if data["due_on"]:
due_date = self.format_date(data["due_on"])
text += f"Due date: {due_date}\n"
if data["completed_at"]:
completed_date = self.format_date(data["completed_at"])
text += f"Completed on: {completed_date}\n"
text += "\n"
return text
def _fetch_and_add_comments(self, task_gid: str) -> str:
text = ""
stories_opts: Dict[str, str] = {}
story_start = time.time()
stories = self.stories_api.get_stories_for_task(task_gid, stories_opts)
story_count = 0
comment_count = 0
for story in stories:
story_count += 1
if story["resource_subtype"] == "comment_added":
comment = self.stories_api.get_story(
story["gid"], opts={"opt_fields": "text,created_by,created_at"}
)
commenter = self.get_user(comment["created_by"]["gid"])["name"]
text += f"Comment by {commenter}: {comment['text']}\n\n"
comment_count += 1
story_duration = time.time() - story_start
logging.debug(
f"Processed {story_count} stories (including {comment_count} comments) in {story_duration:.2f} seconds"
)
return text
def get_attachments(self, task_gid: str) -> list[dict]:
"""
Fetch full attachment info (including download_url) for a task.
"""
attachments: list[dict] = []
try:
# Step 1: list attachment compact records
for att in self.attachments_api.get_attachments_for_object(
parent=task_gid,
opts={}
):
gid = att.get("gid")
if not gid:
continue
try:
# Step 2: expand to full attachment
full = self.attachments_api.get_attachment(
attachment_gid=gid,
opts={
"opt_fields": "name,download_url,size,created_at"
}
)
if full.get("download_url"):
attachments.append(full)
except Exception:
logging.exception(
f"Failed to fetch attachment detail {gid} for task {task_gid}"
)
self.api_error_count += 1
except Exception:
logging.exception(f"Failed to list attachments for task {task_gid}")
self.api_error_count += 1
return attachments
def get_accessible_emails(
self,
workspace_id: str,
project_ids: list[str] | None,
team_id: str | None,
):
ws_users = self.users_api.get_users(
opts={
"workspace": workspace_id,
"opt_fields": "gid,name,email"
}
)
workspace_users = {
u["gid"]: u.get("email")
for u in ws_users
if u.get("email")
}
if not project_ids:
return set(workspace_users.values())
project_emails = set()
for pid in project_ids:
project = self.project_api.get_project(
pid,
opts={"opt_fields": "team,privacy_setting"}
)
if project["privacy_setting"] == "private":
if team_id and project.get("team", {}).get("gid") != team_id:
continue
memberships = self.project_memberships_api.get_project_membership(
pid,
opts={"opt_fields": "user.gid,user.email"}
)
for m in memberships:
email = m["user"].get("email")
if email:
project_emails.add(email)
return project_emails
def get_user(self, user_gid: str) -> Dict:
if self._user is not None:
return self._user
self._user = self.users_api.get_user(user_gid, {"opt_fields": "name,email"})
if not self._user:
logging.warning(f"Unable to fetch user information for user_gid: {user_gid}")
return {"name": "Unknown"}
return self._user
def format_date(self, date_str: str) -> str:
date = datetime.fromisoformat(date_str)
return time.strftime("%Y-%m-%d", date.timetuple())
def get_time(self) -> str:
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
class AsanaConnector(LoadConnector, PollConnector):
def __init__(
self,
asana_workspace_id: str,
asana_project_ids: str | None = None,
asana_team_id: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
) -> None:
self.workspace_id = asana_workspace_id
self.project_ids_to_index: list[str] | None = (
asana_project_ids.split(",") if asana_project_ids else None
)
self.asana_team_id = asana_team_id if asana_team_id else None
self.batch_size = batch_size
self.continue_on_failure = continue_on_failure
self.size_threshold = None
logging.info(
f"AsanaConnector initialized with workspace_id: {asana_workspace_id}"
)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.api_token = credentials["asana_api_token_secret"]
self.asana_client = AsanaAPI(
api_token=self.api_token,
workspace_gid=self.workspace_id,
team_gid=self.asana_team_id,
)
self.workspace_users_email = self.asana_client.get_accessible_emails(self.workspace_id, self.project_ids_to_index, self.asana_team_id)
logging.info("Asana credentials loaded and API client initialized")
return None
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch | None
) -> GenerateDocumentsOutput:
start_time = datetime.fromtimestamp(start).isoformat()
logging.info(f"Starting Asana poll from {start_time}")
docs_batch: list[Document] = []
tasks = self.asana_client.get_tasks(self.project_ids_to_index, start_time)
for task in tasks:
docs = self._task_to_documents(task)
docs_batch.extend(docs)
if len(docs_batch) >= self.batch_size:
logging.info(f"Yielding batch of {len(docs_batch)} documents")
yield docs_batch
docs_batch = []
if docs_batch:
logging.info(f"Yielding final batch of {len(docs_batch)} documents")
yield docs_batch
logging.info("Asana poll completed")
def load_from_state(self) -> GenerateDocumentsOutput:
logging.info("Starting full index of all Asana tasks")
return self.poll_source(start=0, end=None)
def _task_to_documents(self, task: AsanaTask) -> list[Document]:
docs: list[Document] = []
attachments = self.asana_client.get_attachments(task.id)
for att in attachments:
try:
resp = requests.get(att["download_url"], timeout=30)
resp.raise_for_status()
file_blob = resp.content
filename = att.get("name", "attachment")
size_bytes = extract_size_bytes(att)
if (
self.size_threshold is not None
and isinstance(size_bytes, int)
and size_bytes > self.size_threshold
):
logging.warning(
f"{filename} exceeds size threshold of {self.size_threshold}. Skipping."
)
continue
docs.append(
Document(
id=f"asana:{task.id}:{att['gid']}",
blob=file_blob,
extension=get_file_ext(filename) or "",
size_bytes=size_bytes,
doc_updated_at=task.last_modified,
source=DocumentSource.ASANA,
semantic_identifier=filename,
primary_owners=list(self.workspace_users_email),
)
)
except Exception:
logging.exception(
f"Failed to download attachment {att.get('gid')} for task {task.id}"
)
return docs
if __name__ == "__main__":
import time
import os
logging.info("Starting Asana connector test")
connector = AsanaConnector(
os.environ["WORKSPACE_ID"],
os.environ["PROJECT_IDS"],
os.environ["TEAM_ID"],
)
connector.load_credentials(
{
"asana_api_token_secret": os.environ["API_TOKEN"],
}
)
logging.info("Loading all documents from Asana")
all_docs = connector.load_from_state()
current = time.time()
one_day_ago = current - 24 * 60 * 60 # 1 day
logging.info("Polling for documents updated in the last 24 hours")
latest_docs = connector.poll_source(one_day_ago, current)
for docs in all_docs:
for doc in docs:
print(doc.id)
logging.info("Asana connector test completed")

View File

@ -54,6 +54,10 @@ class DocumentSource(str, Enum):
     DROPBOX = "dropbox"
     BOX = "box"
     AIRTABLE = "airtable"
+    ASANA = "asana"
+    GITHUB = "github"
+    GITLAB = "gitlab"

 class FileOrigin(str, Enum):
     """File origins"""
@ -231,6 +235,8 @@ _REPLACEMENT_EXPANSIONS = "body.view.value"
 BOX_WEB_OAUTH_REDIRECT_URI = os.environ.get("BOX_WEB_OAUTH_REDIRECT_URI", "http://localhost:9380/v1/connector/box/oauth/web/callback")
+GITHUB_CONNECTOR_BASE_URL = os.environ.get("GITHUB_CONNECTOR_BASE_URL") or None

 class HtmlBasedConnectorTransformLinksStrategy(str, Enum):
     # remove links entirely
     STRIP = "strip"
@ -256,6 +262,10 @@ AIRTABLE_CONNECTOR_SIZE_THRESHOLD = int(
     os.environ.get("AIRTABLE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
 )
+ASANA_CONNECTOR_SIZE_THRESHOLD = int(
+    os.environ.get("ASANA_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
+)

 _USER_NOT_FOUND = "Unknown Confluence User"
 _COMMENT_EXPANSION_FIELDS = ["body.storage.value"]

View File

@ -0,0 +1,217 @@
import sys
import time
import logging
from collections.abc import Generator
from datetime import datetime
from typing import Generic
from typing import TypeVar
from common.data_source.interfaces import (
BaseConnector,
CheckpointedConnector,
CheckpointedConnectorWithPermSync,
CheckpointOutput,
LoadConnector,
PollConnector,
)
from common.data_source.models import ConnectorCheckpoint, ConnectorFailure, Document
TimeRange = tuple[datetime, datetime]
CT = TypeVar("CT", bound=ConnectorCheckpoint)
def batched_doc_ids(
checkpoint_connector_generator: CheckpointOutput[CT],
batch_size: int,
) -> Generator[set[str], None, None]:
batch: set[str] = set()
for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()(
checkpoint_connector_generator
):
if document is not None:
batch.add(document.id)
elif (
failure and failure.failed_document and failure.failed_document.document_id
):
batch.add(failure.failed_document.document_id)
if len(batch) >= batch_size:
yield batch
batch = set()
if len(batch) > 0:
yield batch
class CheckpointOutputWrapper(Generic[CT]):
"""
Wraps a CheckpointOutput generator to give things back in a more digestible format,
specifically for Document outputs.
The connector format is easier for the connector implementor (e.g. it enforces exactly
one new checkpoint is returned AND that the checkpoint is at the end), thus the different
formats.
"""
def __init__(self) -> None:
self.next_checkpoint: CT | None = None
def __call__(
self,
checkpoint_connector_generator: CheckpointOutput[CT],
) -> Generator[
tuple[Document | None, ConnectorFailure | None, CT | None],
None,
None,
]:
# grabs the final return value and stores it in the `next_checkpoint` variable
def _inner_wrapper(
checkpoint_connector_generator: CheckpointOutput[CT],
) -> CheckpointOutput[CT]:
self.next_checkpoint = yield from checkpoint_connector_generator
return self.next_checkpoint # not used
for document_or_failure in _inner_wrapper(checkpoint_connector_generator):
if isinstance(document_or_failure, Document):
yield document_or_failure, None, None
elif isinstance(document_or_failure, ConnectorFailure):
yield None, document_or_failure, None
else:
raise ValueError(
f"Invalid document_or_failure type: {type(document_or_failure)}"
)
if self.next_checkpoint is None:
raise RuntimeError(
"Checkpoint is None. This should never happen - the connector should always return a checkpoint."
)
yield None, None, self.next_checkpoint
class ConnectorRunner(Generic[CT]):
"""
Handles:
- Batching
- Additional exception logging
- Combining different connector types to a single interface
"""
def __init__(
self,
connector: BaseConnector,
batch_size: int,
# cannot be True for non-checkpointed connectors
include_permissions: bool,
time_range: TimeRange | None = None,
):
if not isinstance(connector, CheckpointedConnector) and include_permissions:
raise ValueError(
"include_permissions cannot be True for non-checkpointed connectors"
)
self.connector = connector
self.time_range = time_range
self.batch_size = batch_size
self.include_permissions = include_permissions
self.doc_batch: list[Document] = []
def run(self, checkpoint: CT) -> Generator[
tuple[list[Document] | None, ConnectorFailure | None, CT | None],
None,
None,
]:
"""Adds additional exception logging to the connector."""
try:
if isinstance(self.connector, CheckpointedConnector):
if self.time_range is None:
raise ValueError("time_range is required for CheckpointedConnector")
start = time.monotonic()
if self.include_permissions:
if not isinstance(
self.connector, CheckpointedConnectorWithPermSync
):
raise ValueError(
"Connector does not support permission syncing"
)
load_from_checkpoint = (
self.connector.load_from_checkpoint_with_perm_sync
)
else:
load_from_checkpoint = self.connector.load_from_checkpoint
checkpoint_connector_generator = load_from_checkpoint(
start=self.time_range[0].timestamp(),
end=self.time_range[1].timestamp(),
checkpoint=checkpoint,
)
next_checkpoint: CT | None = None
# this is guaranteed to always run at least once with next_checkpoint being non-None
for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()(
checkpoint_connector_generator
):
if document is not None and isinstance(document, Document):
self.doc_batch.append(document)
if failure is not None:
yield None, failure, None
if len(self.doc_batch) >= self.batch_size:
yield self.doc_batch, None, None
self.doc_batch = []
# yield remaining documents
if len(self.doc_batch) > 0:
yield self.doc_batch, None, None
self.doc_batch = []
yield None, None, next_checkpoint
logging.debug(
f"Connector took {time.monotonic() - start} seconds to get to the next checkpoint."
)
else:
finished_checkpoint = self.connector.build_dummy_checkpoint()
finished_checkpoint.has_more = False
if isinstance(self.connector, PollConnector):
if self.time_range is None:
raise ValueError("time_range is required for PollConnector")
for document_batch in self.connector.poll_source(
start=self.time_range[0].timestamp(),
end=self.time_range[1].timestamp(),
):
yield document_batch, None, None
yield None, None, finished_checkpoint
elif isinstance(self.connector, LoadConnector):
for document_batch in self.connector.load_from_state():
yield document_batch, None, None
yield None, None, finished_checkpoint
else:
raise ValueError(f"Invalid connector. type: {type(self.connector)}")
except Exception:
exc_type, _, exc_traceback = sys.exc_info()
# Traverse the traceback to find the last frame where the exception was raised
tb = exc_traceback
if tb is None:
logging.error("No traceback found for exception")
raise
while tb.tb_next:
tb = tb.tb_next # Move to the next frame in the traceback
# Get the local variables from the frame where the exception occurred
local_vars = tb.tb_frame.f_locals
local_vars_str = "\n".join(
f"{key}: {value}" for key, value in local_vars.items()
)
logging.error(
f"Error in connector. type: {exc_type};\n"
f"local_vars below -> \n{local_vars_str[:1024]}"
)
raise
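
For reference, the wrapper's contract is easiest to see in isolation: documents and failures stream out as (doc, failure, None) tuples, and the connector's final checkpoint arrives last as (None, None, checkpoint). A minimal sketch (not part of the diff, assuming ConnectorCheckpoint is a pydantic model exposing has_more):

from common.data_source.models import ConnectorCheckpoint

class ToyCheckpoint(ConnectorCheckpoint):
    cursor: int = 0

def empty_run():
    # A checkpointed connector yields Documents/ConnectorFailures and
    # *returns* the next checkpoint; here nothing is yielded at all.
    return ToyCheckpoint(has_more=False, cursor=42)
    yield  # unreachable, but makes this function a generator

for doc, failure, checkpoint in CheckpointOutputWrapper[ToyCheckpoint]()(empty_run()):
    assert doc is None and failure is None
    print("final checkpoint cursor:", checkpoint.cursor)  # -> 42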

View File

@ -18,6 +18,7 @@ class UploadMimeTypes:
         "text/plain",
         "text/markdown",
         "text/x-markdown",
+        "text/mdx",
         "text/x-config",
         "text/tab-separated-values",
         "application/json",

View File

View File

@ -0,0 +1,973 @@
import copy
import logging
from collections.abc import Callable
from collections.abc import Generator
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from enum import Enum
from typing import Any
from typing import cast
from github import Github, Auth
from github import RateLimitExceededException
from github import Repository
from github.GithubException import GithubException
from github.Issue import Issue
from github.NamedUser import NamedUser
from github.PaginatedList import PaginatedList
from github.PullRequest import PullRequest
from pydantic import BaseModel
from typing_extensions import override
from common.data_source.google_util.util import sanitize_filename
from common.data_source.config import DocumentSource, GITHUB_CONNECTOR_BASE_URL
from common.data_source.exceptions import (
ConnectorMissingCredentialError,
ConnectorValidationError,
CredentialExpiredError,
InsufficientPermissionsError,
UnexpectedValidationError,
)
from common.data_source.interfaces import CheckpointedConnectorWithPermSyncGH, CheckpointOutput
from common.data_source.models import (
ConnectorCheckpoint,
ConnectorFailure,
Document,
DocumentFailure,
ExternalAccess,
SecondsSinceUnixEpoch,
)
from common.data_source.connector_runner import ConnectorRunner
from .models import SerializedRepository
from .rate_limit_utils import sleep_after_rate_limit_exception
from .utils import deserialize_repository
from .utils import get_external_access_permission
ITEMS_PER_PAGE = 100
CURSOR_LOG_FREQUENCY = 50
_MAX_NUM_RATE_LIMIT_RETRIES = 5
ONE_DAY = timedelta(days=1)
SLIM_BATCH_SIZE = 100
# Cases
# X (from start) standard run, no fallback to cursor-based pagination
# X (from start) standard run errors, fallback to cursor-based pagination
# X error in the middle of a page
# X no errors: run to completion
# X (from checkpoint) standard run, no fallback to cursor-based pagination
# X (from checkpoint) continue from cursor-based pagination
# - retrying
# - no retrying
# things to check:
# checkpoint state on return
# checkpoint progress (no infinite loop)
class DocMetadata(BaseModel):
repo: str
def get_nextUrl_key(pag_list: PaginatedList[PullRequest | Issue]) -> str:
if "_PaginatedList__nextUrl" in pag_list.__dict__:
return "_PaginatedList__nextUrl"
for key in pag_list.__dict__:
if "__nextUrl" in key:
return key
for key in pag_list.__dict__:
if "nextUrl" in key:
return key
return ""
def get_nextUrl(
pag_list: PaginatedList[PullRequest | Issue], nextUrl_key: str
) -> str | None:
return getattr(pag_list, nextUrl_key) if nextUrl_key else None
def set_nextUrl(
pag_list: PaginatedList[PullRequest | Issue], nextUrl_key: str, nextUrl: str
) -> None:
if nextUrl_key:
setattr(pag_list, nextUrl_key, nextUrl)
elif nextUrl:
raise ValueError("Next URL key not found: " + str(pag_list.__dict__))
def _paginate_until_error(
git_objs: Callable[[], PaginatedList[PullRequest | Issue]],
cursor_url: str | None,
prev_num_objs: int,
cursor_url_callback: Callable[[str | None, int], None],
retrying: bool = False,
) -> Generator[PullRequest | Issue, None, None]:
num_objs = prev_num_objs
pag_list = git_objs()
nextUrl_key = get_nextUrl_key(pag_list)
if cursor_url:
set_nextUrl(pag_list, nextUrl_key, cursor_url)
elif retrying:
# if we are retrying, we want to skip the objects retrieved
# over previous calls. Unfortunately, this WILL retrieve all
# pages before the one we are resuming from, so we really
# don't want this case to be hit often
logging.warning(
"Retrying from a previous cursor-based pagination call. "
"This will retrieve all pages before the one we are resuming from, "
"which may take a while and consume many API calls."
)
pag_list = cast(PaginatedList[PullRequest | Issue], pag_list[prev_num_objs:])
num_objs = 0
try:
# this for loop handles cursor-based pagination
for issue_or_pr in pag_list:
num_objs += 1
yield issue_or_pr
# used to store the current cursor url in the checkpoint. This value
# is updated during iteration over pag_list.
cursor_url_callback(get_nextUrl(pag_list, nextUrl_key), num_objs)
if num_objs % CURSOR_LOG_FREQUENCY == 0:
logging.info(
f"Retrieved {num_objs} objects with current cursor url: {get_nextUrl(pag_list, nextUrl_key)}"
)
except Exception as e:
logging.exception(f"Error during cursor-based pagination: {e}")
if num_objs - prev_num_objs > 0:
raise
if get_nextUrl(pag_list, nextUrl_key) is not None and not retrying:
logging.info(
"Assuming that this error is due to cursor "
"expiration because no objects were retrieved. "
"Retrying from the first page."
)
yield from _paginate_until_error(
git_objs, None, prev_num_objs, cursor_url_callback, retrying=True
)
return
# for no cursor url or if we reach this point after a retry, raise the error
raise
def _get_batch_rate_limited(
# We pass in a callable because we want git_objs to produce a fresh
# PaginatedList each time it's called to avoid using the same object for cursor-based pagination
# from a partial offset-based pagination call.
git_objs: Callable[[], PaginatedList],
page_num: int,
cursor_url: str | None,
prev_num_objs: int,
cursor_url_callback: Callable[[str | None, int], None],
github_client: Github,
attempt_num: int = 0,
) -> Generator[PullRequest | Issue, None, None]:
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
raise RuntimeError(
"Re-tried fetching batch too many times. Something is going wrong with fetching objects from Github"
)
try:
if cursor_url:
# when this is set, we are resuming from an earlier
# cursor-based pagination call.
yield from _paginate_until_error(
git_objs, cursor_url, prev_num_objs, cursor_url_callback
)
return
objs = list(git_objs().get_page(page_num))
# fetch all data here to disable lazy loading later
# this is needed to capture the rate limit exception here (if one occurs)
for obj in objs:
if hasattr(obj, "raw_data"):
getattr(obj, "raw_data")
yield from objs
except RateLimitExceededException:
sleep_after_rate_limit_exception(github_client)
yield from _get_batch_rate_limited(
git_objs,
page_num,
cursor_url,
prev_num_objs,
cursor_url_callback,
github_client,
attempt_num + 1,
)
except GithubException as e:
if not (
e.status == 422
and (
"cursor" in (e.message or "")
or "cursor" in (e.data or {}).get("message", "")
)
):
raise
# Fallback to a cursor-based pagination strategy
        # This can happen for "large datasets", but as far as we can tell
        # the error is not documented anywhere on the web.
# Error message:
# "Pagination with the page parameter is not supported for large datasets,
# please use cursor based pagination (after/before)"
yield from _paginate_until_error(
git_objs, cursor_url, prev_num_objs, cursor_url_callback
)
def _get_userinfo(user: NamedUser) -> dict[str, str]:
def _safe_get(attr_name: str) -> str | None:
try:
return cast(str | None, getattr(user, attr_name))
except GithubException:
logging.debug(f"Error getting {attr_name} for user")
return None
return {
k: v
for k, v in {
"login": _safe_get("login"),
"name": _safe_get("name"),
"email": _safe_get("email"),
}.items()
if v is not None
}
def _convert_pr_to_document(
pull_request: PullRequest, repo_external_access: ExternalAccess | None
) -> Document:
repo_name = pull_request.base.repo.full_name if pull_request.base else ""
doc_metadata = DocMetadata(repo=repo_name)
file_content_byte = pull_request.body.encode('utf-8') if pull_request.body else b""
name = sanitize_filename(pull_request.title, "md")
return Document(
id=pull_request.html_url,
blob= file_content_byte,
source=DocumentSource.GITHUB,
external_access=repo_external_access,
semantic_identifier=f"{pull_request.number}:{name}",
# updated_at is UTC time but is timezone unaware, explicitly add UTC
# as there is logic in indexing to prevent wrong timestamped docs
# due to local time discrepancies with UTC
doc_updated_at=(
pull_request.updated_at.replace(tzinfo=timezone.utc)
if pull_request.updated_at
else None
),
extension=".md",
# this metadata is used in perm sync
size_bytes=len(file_content_byte) if file_content_byte else 0,
primary_owners=[],
doc_metadata=doc_metadata.model_dump(),
metadata={
k: [str(vi) for vi in v] if isinstance(v, list) else str(v)
for k, v in {
"object_type": "PullRequest",
"id": pull_request.number,
"merged": pull_request.merged,
"state": pull_request.state,
"user": _get_userinfo(pull_request.user) if pull_request.user else None,
"assignees": [
_get_userinfo(assignee) for assignee in pull_request.assignees
],
"repo": (
pull_request.base.repo.full_name if pull_request.base else None
),
"num_commits": str(pull_request.commits),
"num_files_changed": str(pull_request.changed_files),
"labels": [label.name for label in pull_request.labels],
"created_at": (
pull_request.created_at.replace(tzinfo=timezone.utc)
if pull_request.created_at
else None
),
"updated_at": (
pull_request.updated_at.replace(tzinfo=timezone.utc)
if pull_request.updated_at
else None
),
"closed_at": (
pull_request.closed_at.replace(tzinfo=timezone.utc)
if pull_request.closed_at
else None
),
"merged_at": (
pull_request.merged_at.replace(tzinfo=timezone.utc)
if pull_request.merged_at
else None
),
"merged_by": (
_get_userinfo(pull_request.merged_by)
if pull_request.merged_by
else None
),
}.items()
if v is not None
},
)
def _fetch_issue_comments(issue: Issue) -> str:
comments = issue.get_comments()
return "\nComment: ".join(comment.body for comment in comments)
def _convert_issue_to_document(
issue: Issue, repo_external_access: ExternalAccess | None
) -> Document:
repo_name = issue.repository.full_name if issue.repository else ""
doc_metadata = DocMetadata(repo=repo_name)
file_content_byte = issue.body.encode('utf-8') if issue.body else b""
name = sanitize_filename(issue.title, "md")
return Document(
id=issue.html_url,
blob=file_content_byte,
source=DocumentSource.GITHUB,
extension=".md",
external_access=repo_external_access,
semantic_identifier=f"{issue.number}:{name}",
# updated_at is UTC time but is timezone unaware
doc_updated_at=issue.updated_at.replace(tzinfo=timezone.utc),
# this metadata is used in perm sync
doc_metadata=doc_metadata.model_dump(),
size_bytes=len(file_content_byte) if file_content_byte else 0,
primary_owners=[_get_userinfo(issue.user) if issue.user else None],
metadata={
k: [str(vi) for vi in v] if isinstance(v, list) else str(v)
for k, v in {
"object_type": "Issue",
"id": issue.number,
"state": issue.state,
"user": _get_userinfo(issue.user) if issue.user else None,
"assignees": [_get_userinfo(assignee) for assignee in issue.assignees],
"repo": issue.repository.full_name if issue.repository else None,
"labels": [label.name for label in issue.labels],
"created_at": (
issue.created_at.replace(tzinfo=timezone.utc)
if issue.created_at
else None
),
"updated_at": (
issue.updated_at.replace(tzinfo=timezone.utc)
if issue.updated_at
else None
),
"closed_at": (
issue.closed_at.replace(tzinfo=timezone.utc)
if issue.closed_at
else None
),
"closed_by": (
_get_userinfo(issue.closed_by) if issue.closed_by else None
),
}.items()
if v is not None
},
)
class GithubConnectorStage(Enum):
START = "start"
PRS = "prs"
ISSUES = "issues"
class GithubConnectorCheckpoint(ConnectorCheckpoint):
stage: GithubConnectorStage
curr_page: int
cached_repo_ids: list[int] | None = None
cached_repo: SerializedRepository | None = None
# Used for the fallback cursor-based pagination strategy
num_retrieved: int
cursor_url: str | None = None
def reset(self) -> None:
"""
Resets curr_page, num_retrieved, and cursor_url to their initial values (0, 0, None)
"""
self.curr_page = 0
self.num_retrieved = 0
self.cursor_url = None
def make_cursor_url_callback(
checkpoint: GithubConnectorCheckpoint,
) -> Callable[[str | None, int], None]:
def cursor_url_callback(cursor_url: str | None, num_objs: int) -> None:
# we want to maintain the old cursor url so code after retrieval
# can determine that we are using the fallback cursor-based pagination strategy
if cursor_url:
checkpoint.cursor_url = cursor_url
checkpoint.num_retrieved = num_objs
return cursor_url_callback
class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpoint]):
def __init__(
self,
repo_owner: str,
repositories: str | None = None,
state_filter: str = "all",
include_prs: bool = True,
include_issues: bool = False,
) -> None:
self.repo_owner = repo_owner
self.repositories = repositories
self.state_filter = state_filter
self.include_prs = include_prs
self.include_issues = include_issues
self.github_client: Github | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
# defaults to 30 items per page, can be set to as high as 100
token = credentials["github_access_token"]
auth = Auth.Token(token)
if GITHUB_CONNECTOR_BASE_URL:
self.github_client = Github(
auth=auth,
base_url=GITHUB_CONNECTOR_BASE_URL,
per_page=ITEMS_PER_PAGE,
)
else:
self.github_client = Github(
auth=auth,
per_page=ITEMS_PER_PAGE,
)
return None
def get_github_repo(
self, github_client: Github, attempt_num: int = 0
) -> Repository.Repository:
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
raise RuntimeError(
"Re-tried fetching repo too many times. Something is going wrong with fetching objects from Github"
)
try:
return github_client.get_repo(f"{self.repo_owner}/{self.repositories}")
except RateLimitExceededException:
sleep_after_rate_limit_exception(github_client)
return self.get_github_repo(github_client, attempt_num + 1)
def get_github_repos(
self, github_client: Github, attempt_num: int = 0
) -> list[Repository.Repository]:
"""Get specific repositories based on comma-separated repo_name string."""
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
raise RuntimeError(
"Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github"
)
try:
repos = []
# Split repo_name by comma and strip whitespace
repo_names = [
name.strip() for name in (cast(str, self.repositories)).split(",")
]
for repo_name in repo_names:
if repo_name: # Skip empty strings
try:
repo = github_client.get_repo(f"{self.repo_owner}/{repo_name}")
repos.append(repo)
except GithubException as e:
logging.warning(
f"Could not fetch repo {self.repo_owner}/{repo_name}: {e}"
)
return repos
except RateLimitExceededException:
sleep_after_rate_limit_exception(github_client)
return self.get_github_repos(github_client, attempt_num + 1)
def get_all_repos(
self, github_client: Github, attempt_num: int = 0
) -> list[Repository.Repository]:
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
raise RuntimeError(
"Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github"
)
try:
# Try to get organization first
try:
org = github_client.get_organization(self.repo_owner)
return list(org.get_repos())
except GithubException:
# If not an org, try as a user
user = github_client.get_user(self.repo_owner)
return list(user.get_repos())
except RateLimitExceededException:
sleep_after_rate_limit_exception(github_client)
return self.get_all_repos(github_client, attempt_num + 1)
def _pull_requests_func(
self, repo: Repository.Repository
) -> Callable[[], PaginatedList[PullRequest]]:
return lambda: repo.get_pulls(
state=self.state_filter, sort="updated", direction="desc"
)
def _issues_func(
self, repo: Repository.Repository
) -> Callable[[], PaginatedList[Issue]]:
return lambda: repo.get_issues(
state=self.state_filter, sort="updated", direction="desc"
)
def _fetch_from_github(
self,
checkpoint: GithubConnectorCheckpoint,
start: datetime | None = None,
end: datetime | None = None,
include_permissions: bool = False,
) -> Generator[Document | ConnectorFailure, None, GithubConnectorCheckpoint]:
if self.github_client is None:
raise ConnectorMissingCredentialError("GitHub")
checkpoint = copy.deepcopy(checkpoint)
# First run of the connector, fetch all repos and store in checkpoint
if checkpoint.cached_repo_ids is None:
repos = []
if self.repositories:
if "," in self.repositories:
# Multiple repositories specified
repos = self.get_github_repos(self.github_client)
else:
# Single repository (backward compatibility)
repos = [self.get_github_repo(self.github_client)]
else:
# All repositories
repos = self.get_all_repos(self.github_client)
if not repos:
checkpoint.has_more = False
return checkpoint
curr_repo = repos.pop()
checkpoint.cached_repo_ids = [repo.id for repo in repos]
checkpoint.cached_repo = SerializedRepository(
id=curr_repo.id,
headers=curr_repo.raw_headers,
raw_data=curr_repo.raw_data,
)
checkpoint.stage = GithubConnectorStage.PRS
checkpoint.curr_page = 0
# save checkpoint with repo ids retrieved
return checkpoint
if checkpoint.cached_repo is None:
raise ValueError("No repo saved in checkpoint")
# Deserialize the repository from the checkpoint
repo = deserialize_repository(checkpoint.cached_repo, self.github_client)
cursor_url_callback = make_cursor_url_callback(checkpoint)
repo_external_access: ExternalAccess | None = None
if include_permissions:
repo_external_access = get_external_access_permission(
repo, self.github_client
)
if self.include_prs and checkpoint.stage == GithubConnectorStage.PRS:
logging.info(f"Fetching PRs for repo: {repo.name}")
pr_batch = _get_batch_rate_limited(
self._pull_requests_func(repo),
checkpoint.curr_page,
checkpoint.cursor_url,
checkpoint.num_retrieved,
cursor_url_callback,
self.github_client,
)
checkpoint.curr_page += 1 # NOTE: not used for cursor-based fallback
done_with_prs = False
num_prs = 0
pr = None
print("start: ", start)
for pr in pr_batch:
num_prs += 1
print("-"*40)
print("PR name", pr.title)
print("updated at", pr.updated_at)
print("-"*40)
print("\n")
# we iterate backwards in time, so at this point we stop processing prs
if (
start is not None
and pr.updated_at
and pr.updated_at.replace(tzinfo=timezone.utc) <= start
):
done_with_prs = True
break
# Skip PRs updated after the end date
if (
end is not None
and pr.updated_at
and pr.updated_at.replace(tzinfo=timezone.utc) > end
):
continue
try:
yield _convert_pr_to_document(
cast(PullRequest, pr), repo_external_access
)
except Exception as e:
error_msg = f"Error converting PR to document: {e}"
logging.exception(error_msg)
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=str(pr.id), document_link=pr.html_url
),
failure_message=error_msg,
exception=e,
)
continue
# If we reach this point with a cursor url in the checkpoint, we were using
# the fallback cursor-based pagination strategy. That strategy tries to get all
            # PRs, so having cursor_url set means we are done with PRs. However, we need to
# return AFTER the checkpoint reset to avoid infinite loops.
# if we found any PRs on the page and there are more PRs to get, return the checkpoint.
# In offset mode, while indexing without time constraints, the pr batch
# will be empty when we're done.
used_cursor = checkpoint.cursor_url is not None
if num_prs > 0 and not done_with_prs and not used_cursor:
return checkpoint
# if we went past the start date during the loop or there are no more
# prs to get, we move on to issues
checkpoint.stage = GithubConnectorStage.ISSUES
checkpoint.reset()
if used_cursor:
# save the checkpoint after changing stage; next run will continue from issues
return checkpoint
checkpoint.stage = GithubConnectorStage.ISSUES
if self.include_issues and checkpoint.stage == GithubConnectorStage.ISSUES:
logging.info(f"Fetching issues for repo: {repo.name}")
issue_batch = list(
_get_batch_rate_limited(
self._issues_func(repo),
checkpoint.curr_page,
checkpoint.cursor_url,
checkpoint.num_retrieved,
cursor_url_callback,
self.github_client,
)
)
checkpoint.curr_page += 1
done_with_issues = False
num_issues = 0
for issue in issue_batch:
num_issues += 1
issue = cast(Issue, issue)
                # we iterate backwards in time, so at this point we stop processing issues
if (
start is not None
and issue.updated_at.replace(tzinfo=timezone.utc) <= start
):
done_with_issues = True
break
                # Skip issues updated after the end date
if (
end is not None
and issue.updated_at.replace(tzinfo=timezone.utc) > end
):
continue
if issue.pull_request is not None:
# PRs are handled separately
continue
try:
yield _convert_issue_to_document(issue, repo_external_access)
except Exception as e:
error_msg = f"Error converting issue to document: {e}"
logging.exception(error_msg)
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=str(issue.id),
document_link=issue.html_url,
),
failure_message=error_msg,
exception=e,
)
continue
# if we found any issues on the page, and we're not done, return the checkpoint.
# don't return if we're using cursor-based pagination to avoid infinite loops
if num_issues > 0 and not done_with_issues and not checkpoint.cursor_url:
return checkpoint
# if we went past the start date during the loop or there are no more
# issues to get, we move on to the next repo
checkpoint.stage = GithubConnectorStage.PRS
checkpoint.reset()
checkpoint.has_more = len(checkpoint.cached_repo_ids) > 0
if checkpoint.cached_repo_ids:
next_id = checkpoint.cached_repo_ids.pop()
next_repo = self.github_client.get_repo(next_id)
checkpoint.cached_repo = SerializedRepository(
id=next_id,
headers=next_repo.raw_headers,
raw_data=next_repo.raw_data,
)
checkpoint.stage = GithubConnectorStage.PRS
checkpoint.reset()
if checkpoint.cached_repo_ids:
logging.info(
f"{len(checkpoint.cached_repo_ids)} repos remaining (IDs: {checkpoint.cached_repo_ids})"
)
else:
logging.info("No more repos remaining")
return checkpoint
def _load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: GithubConnectorCheckpoint,
include_permissions: bool = False,
) -> CheckpointOutput[GithubConnectorCheckpoint]:
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
# add a day for timezone safety
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc) + ONE_DAY
        # Moving the start time back by 3 hours could catch Issues/PRs that GitHub
        # processes with a delay; issues not updated since the last poll would be
        # shortcut-ed and not embedded. The adjustment is currently disabled:
        # adjusted_start_datetime = start_datetime - timedelta(hours=3)
        adjusted_start_datetime = start_datetime
epoch = datetime.fromtimestamp(0, tz=timezone.utc)
if adjusted_start_datetime < epoch:
adjusted_start_datetime = epoch
return self._fetch_from_github(
checkpoint,
start=adjusted_start_datetime,
end=end_datetime,
include_permissions=include_permissions,
)
@override
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: GithubConnectorCheckpoint,
) -> CheckpointOutput[GithubConnectorCheckpoint]:
return self._load_from_checkpoint(
start, end, checkpoint, include_permissions=False
)
@override
def load_from_checkpoint_with_perm_sync(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: GithubConnectorCheckpoint,
) -> CheckpointOutput[GithubConnectorCheckpoint]:
return self._load_from_checkpoint(
start, end, checkpoint, include_permissions=True
)
def validate_connector_settings(self) -> None:
if self.github_client is None:
raise ConnectorMissingCredentialError("GitHub credentials not loaded.")
if not self.repo_owner:
raise ConnectorValidationError(
"Invalid connector settings: 'repo_owner' must be provided."
)
try:
if self.repositories:
if "," in self.repositories:
# Multiple repositories specified
repo_names = [name.strip() for name in self.repositories.split(",")]
if not repo_names:
raise ConnectorValidationError(
"Invalid connector settings: No valid repository names provided."
)
# Validate at least one repository exists and is accessible
valid_repos = False
validation_errors = []
for repo_name in repo_names:
if not repo_name:
continue
try:
test_repo = self.github_client.get_repo(
f"{self.repo_owner}/{repo_name}"
)
logging.info(
f"Successfully accessed repository: {self.repo_owner}/{repo_name}"
)
test_repo.get_contents("")
valid_repos = True
# If at least one repo is valid, we can proceed
break
except GithubException as e:
validation_errors.append(
f"Repository '{repo_name}': {e.data.get('message', str(e))}"
)
if not valid_repos:
error_msg = (
"None of the specified repositories could be accessed: "
)
error_msg += ", ".join(validation_errors)
raise ConnectorValidationError(error_msg)
else:
# Single repository (backward compatibility)
test_repo = self.github_client.get_repo(
f"{self.repo_owner}/{self.repositories}"
)
test_repo.get_contents("")
else:
# Try to get organization first
try:
org = self.github_client.get_organization(self.repo_owner)
total_count = org.get_repos().totalCount
if total_count == 0:
raise ConnectorValidationError(
f"Found no repos for organization: {self.repo_owner}. "
"Does the credential have the right scopes?"
)
except GithubException as e:
# Check for missing SSO
MISSING_SSO_ERROR_MESSAGE = "You must grant your Personal Access token access to this organization".lower()
if MISSING_SSO_ERROR_MESSAGE in str(e).lower():
SSO_GUIDE_LINK = (
"https://docs.github.com/en/enterprise-cloud@latest/authentication/"
"authenticating-with-saml-single-sign-on/"
"authorizing-a-personal-access-token-for-use-with-saml-single-sign-on"
)
raise ConnectorValidationError(
f"Your GitHub token is missing authorization to access the "
f"`{self.repo_owner}` organization. Please follow the guide to "
f"authorize your token: {SSO_GUIDE_LINK}"
)
# If not an org, try as a user
user = self.github_client.get_user(self.repo_owner)
# Check if we can access any repos
total_count = user.get_repos().totalCount
if total_count == 0:
raise ConnectorValidationError(
f"Found no repos for user: {self.repo_owner}. "
"Does the credential have the right scopes?"
)
except RateLimitExceededException:
raise UnexpectedValidationError(
"Validation failed due to GitHub rate-limits being exceeded. Please try again later."
)
except GithubException as e:
if e.status == 401:
raise CredentialExpiredError(
"GitHub credential appears to be invalid or expired (HTTP 401)."
)
elif e.status == 403:
raise InsufficientPermissionsError(
"Your GitHub token does not have sufficient permissions for this repository (HTTP 403)."
)
elif e.status == 404:
if self.repositories:
if "," in self.repositories:
raise ConnectorValidationError(
f"None of the specified GitHub repositories could be found for owner: {self.repo_owner}"
)
else:
raise ConnectorValidationError(
f"GitHub repository not found with name: {self.repo_owner}/{self.repositories}"
)
else:
raise ConnectorValidationError(
f"GitHub user or organization not found: {self.repo_owner}"
)
else:
raise ConnectorValidationError(
f"Unexpected GitHub error (status={e.status}): {e.data}"
)
except Exception as exc:
raise Exception(
f"Unexpected error during GitHub settings validation: {exc}"
)
def validate_checkpoint_json(
self, checkpoint_json: str
) -> GithubConnectorCheckpoint:
return GithubConnectorCheckpoint.model_validate_json(checkpoint_json)
def build_dummy_checkpoint(self) -> GithubConnectorCheckpoint:
return GithubConnectorCheckpoint(
stage=GithubConnectorStage.PRS, curr_page=0, has_more=True, num_retrieved=0
)
if __name__ == "__main__":
# Initialize the connector
connector = GithubConnector(
repo_owner="EvoAgentX",
repositories="EvoAgentX",
include_issues=True,
include_prs=False,
)
connector.load_credentials(
{"github_access_token": "<Your_GitHub_Access_Token>"}
)
if connector.github_client:
get_external_access_permission(
connector.get_github_repos(connector.github_client).pop(),
connector.github_client,
)
# Create a time range from epoch to now
end_time = datetime.now(timezone.utc)
start_time = datetime.fromtimestamp(0, tz=timezone.utc)
time_range = (start_time, end_time)
# Initialize the runner with a batch size of 10
runner: ConnectorRunner[GithubConnectorCheckpoint] = ConnectorRunner(
connector, batch_size=10, include_permissions=False, time_range=time_range
)
# Get initial checkpoint
checkpoint = connector.build_dummy_checkpoint()
# Run the connector
while checkpoint.has_more:
for doc_batch, failure, next_checkpoint in runner.run(checkpoint):
if doc_batch:
print(f"Retrieved batch of {len(doc_batch)} documents")
for doc in doc_batch:
print(f"Document: {doc.semantic_identifier}")
if failure:
print(f"Failure: {failure.failure_message}")
if next_checkpoint:
checkpoint = next_checkpoint
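
Since GithubConnectorCheckpoint is a pydantic model, it can be persisted between runs as JSON and restored with validate_checkpoint_json; a small illustrative sketch (not part of the diff, placeholder owner/repo names):

connector = GithubConnector(repo_owner="infiniflow", repositories="ragflow")  # placeholders
checkpoint = connector.build_dummy_checkpoint()

saved = checkpoint.model_dump_json()                  # str, safe to store in a DB or file
restored = connector.validate_checkpoint_json(saved)  # -> GithubConnectorCheckpoint
assert restored.stage == GithubConnectorStage.PRS and restored.has_more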

View File

@ -0,0 +1,17 @@
from typing import Any
from github import Repository
from github.Requester import Requester
from pydantic import BaseModel
class SerializedRepository(BaseModel):
# id is part of the raw_data as well, just pulled out for convenience
id: int
headers: dict[str, str | int]
raw_data: dict[str, Any]
def to_Repository(self, requester: Requester) -> Repository.Repository:
return Repository.Repository(
requester, self.headers, self.raw_data, completed=True
)

View File

@ -0,0 +1,24 @@
import time
import logging
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from github import Github
def sleep_after_rate_limit_exception(github_client: Github) -> None:
"""
Sleep until the GitHub rate limit resets.
Args:
github_client: The GitHub client that hit the rate limit
"""
sleep_time = github_client.get_rate_limit().core.reset.replace(
tzinfo=timezone.utc
) - datetime.now(tz=timezone.utc)
sleep_time += timedelta(minutes=1) # add an extra minute just to be safe
logging.info(
"Ran into Github rate-limit. Sleeping %s seconds.", sleep_time.seconds
)
time.sleep(sleep_time.total_seconds())

View File

@ -0,0 +1,44 @@
import logging
from github import Github
from github.Repository import Repository
from common.data_source.models import ExternalAccess
from .models import SerializedRepository
def get_external_access_permission(
repo: Repository, github_client: Github
) -> ExternalAccess:
"""
Get the external access permission for a repository.
This functionality requires Enterprise Edition.
"""
# RAGFlow doesn't implement the Onyx EE external-permissions system.
# Default to private/unknown permissions.
return ExternalAccess.empty()
def deserialize_repository(
cached_repo: SerializedRepository, github_client: Github
) -> Repository:
"""
Deserialize a SerializedRepository back into a Repository object.
"""
# Try to access the requester - different PyGithub versions may use different attribute names
try:
# Try to get the requester using getattr to avoid linter errors
requester = getattr(github_client, "_requester", None)
if requester is None:
requester = getattr(github_client, "_Github__requester", None)
if requester is None:
# If we can't find the requester attribute, we need to fall back to recreating the repo
raise AttributeError("Could not find requester attribute")
return cached_repo.to_Repository(requester)
except Exception as e:
# If all else fails, re-fetch the repo directly
logging.warning("Failed to deserialize repository: %s. Attempting to re-fetch.", e)
repo_id = cached_repo.id
return github_client.get_repo(repo_id)
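
Taken together with SerializedRepository above, the intended round trip looks roughly like this (illustrative sketch; the token and repository name are placeholders):

from github import Github, Auth
from .models import SerializedRepository
from .utils import deserialize_repository

client = Github(auth=Auth.Token("<token>"))   # placeholder credentials
repo = client.get_repo("owner/repo")          # placeholder repository

# What the connector caches in its checkpoint...
cached = SerializedRepository(id=repo.id, headers=repo.raw_headers, raw_data=repo.raw_data)

# ...and how the next run turns it back into a Repository, re-fetching the
# repo if the PyGithub requester attribute cannot be located.
restored = deserialize_repository(cached, client)
print(restored.full_name)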

View File

@ -0,0 +1,340 @@
import fnmatch
import itertools
from collections import deque
from collections.abc import Iterable
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import TypeVar
import gitlab
from gitlab.v4.objects import Project
from common.data_source.config import DocumentSource, INDEX_BATCH_SIZE
from common.data_source.exceptions import ConnectorMissingCredentialError
from common.data_source.exceptions import ConnectorValidationError
from common.data_source.exceptions import CredentialExpiredError
from common.data_source.exceptions import InsufficientPermissionsError
from common.data_source.exceptions import UnexpectedValidationError
from common.data_source.interfaces import GenerateDocumentsOutput
from common.data_source.interfaces import LoadConnector
from common.data_source.interfaces import PollConnector
from common.data_source.interfaces import SecondsSinceUnixEpoch
from common.data_source.models import BasicExpertInfo
from common.data_source.models import Document
from common.data_source.utils import get_file_ext
T = TypeVar("T")
# List of directories/Files to exclude
exclude_patterns = [
"logs",
".github/",
".gitlab/",
".pre-commit-config.yaml",
]
def _batch_gitlab_objects(git_objs: Iterable[T], batch_size: int) -> Iterator[list[T]]:
it = iter(git_objs)
while True:
batch = list(itertools.islice(it, batch_size))
if not batch:
break
yield batch
def get_author(author: Any) -> BasicExpertInfo:
return BasicExpertInfo(
display_name=author.get("name"),
)
def _convert_merge_request_to_document(mr: Any) -> Document:
mr_text = mr.description or ""
doc = Document(
id=mr.web_url,
blob=mr_text,
source=DocumentSource.GITLAB,
semantic_identifier=mr.title,
extension=".md",
# updated_at is UTC time but is timezone unaware, explicitly add UTC
# as there is logic in indexing to prevent wrong timestamped docs
# due to local time discrepancies with UTC
doc_updated_at=mr.updated_at.replace(tzinfo=timezone.utc),
size_bytes=len(mr_text.encode("utf-8")),
primary_owners=[get_author(mr.author)],
metadata={"state": mr.state, "type": "MergeRequest", "web_url": mr.web_url},
)
return doc
def _convert_issue_to_document(issue: Any) -> Document:
issue_text = issue.description or ""
doc = Document(
id=issue.web_url,
blob=issue_text,
source=DocumentSource.GITLAB,
semantic_identifier=issue.title,
extension=".md",
# updated_at is UTC time but is timezone unaware, explicitly add UTC
# as there is logic in indexing to prevent wrong timestamped docs
# due to local time discrepancies with UTC
doc_updated_at=issue.updated_at.replace(tzinfo=timezone.utc),
size_bytes=len(issue_text.encode("utf-8")),
primary_owners=[get_author(issue.author)],
metadata={
"state": issue.state,
"type": issue.type if issue.type else "Issue",
"web_url": issue.web_url,
},
)
return doc
def _convert_code_to_document(
project: Project, file: Any, url: str, projectName: str, projectOwner: str
) -> Document:
# Dynamically get the default branch from the project object
default_branch = project.default_branch
# Fetch the file content using the correct branch
file_content_obj = project.files.get(
file_path=file["path"], ref=default_branch # Use the default branch
)
# BoxConnector uses raw bytes for blob. Keep the same here.
file_content_bytes = file_content_obj.decode()
file_url = f"{url}/{projectOwner}/{projectName}/-/blob/{default_branch}/{file['path']}"
# Try to use the last commit timestamp for incremental sync.
# Falls back to "now" if the commit lookup fails.
last_commit_at = None
try:
# Query commit history for this file on the default branch.
commits = project.commits.list(
ref_name=default_branch,
path=file["path"],
per_page=1,
)
if commits:
# committed_date is ISO string like "2024-01-01T00:00:00.000+00:00"
committed_date = commits[0].committed_date
if isinstance(committed_date, str):
last_commit_at = datetime.strptime(
committed_date, "%Y-%m-%dT%H:%M:%S.%f%z"
).astimezone(timezone.utc)
elif isinstance(committed_date, datetime):
last_commit_at = committed_date.astimezone(timezone.utc)
except Exception:
last_commit_at = None
# Create and return a Document object
doc = Document(
# Use a stable ID so reruns don't create duplicates.
id=file_url,
blob=file_content_bytes,
source=DocumentSource.GITLAB,
semantic_identifier=file.get("name"),
extension=get_file_ext(file.get("name")),
doc_updated_at=last_commit_at or datetime.now(tz=timezone.utc),
size_bytes=len(file_content_bytes) if file_content_bytes is not None else 0,
primary_owners=[], # Add owners if needed
metadata={
"type": "CodeFile",
"path": file.get("path"),
"ref": default_branch,
"project": f"{projectOwner}/{projectName}",
"web_url": file_url,
},
)
return doc
def _should_exclude(path: str) -> bool:
"""Check if a path matches any of the exclude patterns."""
return any(fnmatch.fnmatch(path, pattern) for pattern in exclude_patterns)
class GitlabConnector(LoadConnector, PollConnector):
def __init__(
self,
project_owner: str,
project_name: str,
batch_size: int = INDEX_BATCH_SIZE,
state_filter: str = "all",
include_mrs: bool = True,
include_issues: bool = True,
include_code_files: bool = False,
) -> None:
self.project_owner = project_owner
self.project_name = project_name
self.batch_size = batch_size
self.state_filter = state_filter
self.include_mrs = include_mrs
self.include_issues = include_issues
self.include_code_files = include_code_files
self.gitlab_client: gitlab.Gitlab | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.gitlab_client = gitlab.Gitlab(
credentials["gitlab_url"], private_token=credentials["gitlab_access_token"]
)
return None
def validate_connector_settings(self) -> None:
if self.gitlab_client is None:
raise ConnectorMissingCredentialError("GitLab")
try:
self.gitlab_client.auth()
self.gitlab_client.projects.get(
f"{self.project_owner}/{self.project_name}",
lazy=True,
)
except gitlab.exceptions.GitlabAuthenticationError as e:
raise CredentialExpiredError(
"Invalid or expired GitLab credentials."
) from e
except gitlab.exceptions.GitlabAuthorizationError as e:
raise InsufficientPermissionsError(
"Insufficient permissions to access GitLab resources."
) from e
except gitlab.exceptions.GitlabGetError as e:
raise ConnectorValidationError(
"GitLab project not found or not accessible."
) from e
except Exception as e:
raise UnexpectedValidationError(
f"Unexpected error while validating GitLab settings: {e}"
) from e
def _fetch_from_gitlab(
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
if self.gitlab_client is None:
raise ConnectorMissingCredentialError("Gitlab")
project: Project = self.gitlab_client.projects.get(
f"{self.project_owner}/{self.project_name}"
)
start_utc = start.astimezone(timezone.utc) if start else None
end_utc = end.astimezone(timezone.utc) if end else None
# Fetch code files
if self.include_code_files:
            # Traverse the tree with BFS; project.repository_tree with recursion enabled loads slowly on large repos
queue = deque([""]) # Start with the root directory
while queue:
current_path = queue.popleft()
files = project.repository_tree(path=current_path, all=True)
for file_batch in _batch_gitlab_objects(files, self.batch_size):
code_doc_batch: list[Document] = []
for file in file_batch:
if _should_exclude(file["path"]):
continue
if file["type"] == "blob":
doc = _convert_code_to_document(
project,
file,
self.gitlab_client.url,
self.project_name,
self.project_owner,
)
# Apply incremental window filtering for code files too.
if start_utc is not None and doc.doc_updated_at <= start_utc:
continue
if end_utc is not None and doc.doc_updated_at > end_utc:
continue
code_doc_batch.append(doc)
elif file["type"] == "tree":
queue.append(file["path"])
if code_doc_batch:
yield code_doc_batch
if self.include_mrs:
merge_requests = project.mergerequests.list(
state=self.state_filter,
order_by="updated_at",
sort="desc",
iterator=True,
)
for mr_batch in _batch_gitlab_objects(merge_requests, self.batch_size):
mr_doc_batch: list[Document] = []
for mr in mr_batch:
mr.updated_at = datetime.strptime(
mr.updated_at, "%Y-%m-%dT%H:%M:%S.%f%z"
)
if start_utc is not None and mr.updated_at <= start_utc:
yield mr_doc_batch
return
if end_utc is not None and mr.updated_at > end_utc:
continue
mr_doc_batch.append(_convert_merge_request_to_document(mr))
yield mr_doc_batch
if self.include_issues:
issues = project.issues.list(state=self.state_filter, iterator=True)
for issue_batch in _batch_gitlab_objects(issues, self.batch_size):
issue_doc_batch: list[Document] = []
for issue in issue_batch:
issue.updated_at = datetime.strptime(
issue.updated_at, "%Y-%m-%dT%H:%M:%S.%f%z"
)
# Avoid re-syncing the last-seen item.
if start_utc is not None and issue.updated_at <= start_utc:
yield issue_doc_batch
return
if end_utc is not None and issue.updated_at > end_utc:
continue
issue_doc_batch.append(_convert_issue_to_document(issue))
yield issue_doc_batch
def load_from_state(self) -> GenerateDocumentsOutput:
return self._fetch_from_gitlab()
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
return self._fetch_from_gitlab(start_datetime, end_datetime)
if __name__ == "__main__":
import os
connector = GitlabConnector(
# gitlab_url="https://gitlab.com/api/v4",
project_owner=os.environ["PROJECT_OWNER"],
project_name=os.environ["PROJECT_NAME"],
batch_size=INDEX_BATCH_SIZE,
state_filter="all",
include_mrs=True,
include_issues=True,
include_code_files=True,
)
connector.load_credentials(
{
"gitlab_access_token": os.environ["GITLAB_ACCESS_TOKEN"],
"gitlab_url": os.environ["GITLAB_URL"],
}
)
document_batches = connector.load_from_state()
for f in document_batches:
print("Batch:", f)
print("Finished loading from state.")

View File

@ -191,7 +191,7 @@ def get_credentials_from_env(email: str, oauth: bool = False, source="drive") ->
         DB_CREDENTIALS_AUTHENTICATION_METHOD: "uploaded",
     }

-def sanitize_filename(name: str) -> str:
+def sanitize_filename(name: str, extension: str = "txt") -> str:
     """
     Soft sanitize for MinIO/S3:
     - Replace only prohibited characters with a space.
@ -199,7 +199,7 @@ def sanitize_filename(name: str) -> str:
     - Collapse multiple spaces.
     """
     if name is None:
-        return "file.txt"
+        return f"file.{extension}"
     name = str(name).strip()
@ -222,9 +222,8 @@ def sanitize_filename(name: str) -> str:
     base, ext = os.path.splitext(name)
     name = base[:180].rstrip() + ext

-    # Ensure there is an extension (your original logic)
     if not os.path.splitext(name)[1]:
-        name += ".txt"
+        name += f".{extension}"

     return name
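
With the new extension parameter the default behaviour is unchanged, while callers such as the GitHub connector can request a different suffix; a quick illustration:

print(sanitize_filename("report"))          # -> "report.txt" (default extension, as before)
print(sanitize_filename("summary", "md"))   # -> "summary.md"
print(sanitize_filename(None, "md"))        # -> "file.md"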

View File

@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
 from enum import IntFlag, auto
 from types import TracebackType
 from typing import Any, Dict, Generator, TypeVar, Generic, Callable, TypeAlias
+from collections.abc import Iterator

 from anthropic import BaseModel

 from common.data_source.models import (
@ -16,6 +16,7 @@ from common.data_source.models import (
     SecondsSinceUnixEpoch, GenerateSlimDocumentOutput
 )

+GenerateDocumentsOutput = Iterator[list[Document]]

 class LoadConnector(ABC):
     """Load connector interface"""
@ -236,16 +237,13 @@ class BaseConnector(abc.ABC, Generic[CT]):
     def validate_perm_sync(self) -> None:
         """
-        Don't override this; add a function to perm_sync_valid.py in the ee package
-        to do permission sync validation
+        Permission-sync validation hook.
+        RAGFlow doesn't ship the Onyx EE permission-sync validation package.
+        Connectors that support permission sync should override
+        `validate_connector_settings()` as needed.
         """
-        """
-        validate_connector_settings_fn = fetch_ee_implementation_or_noop(
-            "onyx.connectors.perm_sync_valid",
-            "validate_perm_sync",
-            noop_return_value=None,
-        )
-        validate_connector_settings_fn(self)"""
+        return None

     def set_allow_images(self, value: bool) -> None:
         """Implement if the underlying connector wants to skip/allow image downloading
@ -344,6 +342,17 @@ class CheckpointOutputWrapper(Generic[CT]):
         yield None, None, self.next_checkpoint

+class CheckpointedConnectorWithPermSyncGH(CheckpointedConnector[CT]):
+    @abc.abstractmethod
+    def load_from_checkpoint_with_perm_sync(
+        self,
+        start: SecondsSinceUnixEpoch,
+        end: SecondsSinceUnixEpoch,
+        checkpoint: CT,
+    ) -> CheckpointOutput[CT]:
+        raise NotImplementedError

 # Slim connectors retrieve just the ids of documents
 class SlimConnector(BaseConnector):
     @abc.abstractmethod

View File

@ -94,8 +94,10 @@ class Document(BaseModel):
     blob: bytes
     doc_updated_at: datetime
     size_bytes: int
+    external_access: Optional[ExternalAccess] = None
     primary_owners: Optional[list] = None
     metadata: Optional[dict[str, Any]] = None
+    doc_metadata: Optional[dict[str, Any]] = None

 class BasicExpertInfo(BaseModel):

View File

@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import ast
 import logging
 from typing import Any, Callable, Dict
@ -49,8 +50,8 @@ def meta_filter(metas: dict, filters: list[dict], logic: str = "and"):
         try:
             if isinstance(input, list):
                 input = input[0]
-            input = float(input)
-            value = float(value)
+            input = ast.literal_eval(input)
+            value = ast.literal_eval(value)
         except Exception:
             pass
         if isinstance(input, str):
@ -58,28 +59,41 @@ def meta_filter(metas: dict, filters: list[dict], logic: str = "and"):
         if isinstance(value, str):
             value = value.lower()
-        for conds in [
-            (operator == "contains", input in value if not isinstance(input, list) else all([i in value for i in input])),
-            (operator == "not contains", input not in value if not isinstance(input, list) else all([i not in value for i in input])),
-            (operator == "in", input in value if not isinstance(input, list) else all([i in value for i in input])),
-            (operator == "not in", input not in value if not isinstance(input, list) else all([i not in value for i in input])),
-            (operator == "start with", str(input).lower().startswith(str(value).lower()) if not isinstance(input, list) else "".join([str(i).lower() for i in input]).startswith(str(value).lower())),
-            (operator == "end with", str(input).lower().endswith(str(value).lower()) if not isinstance(input, list) else "".join([str(i).lower() for i in input]).endswith(str(value).lower())),
-            (operator == "empty", not input),
-            (operator == "not empty", input),
-            (operator == "=", input == value),
-            (operator == "≠", input != value),
-            (operator == ">", input > value),
-            (operator == "<", input < value),
-            (operator == "≥", input >= value),
-            (operator == "≤", input <= value),
-        ]:
-            try:
-                if all(conds):
-                    ids.extend(docids)
-                    break
-            except Exception:
-                pass
+        matched = False
+        try:
+            if operator == "contains":
+                matched = input in value if not isinstance(input, list) else all(i in value for i in input)
+            elif operator == "not contains":
+                matched = input not in value if not isinstance(input, list) else all(i not in value for i in input)
+            elif operator == "in":
+                matched = input in value if not isinstance(input, list) else all(i in value for i in input)
+            elif operator == "not in":
+                matched = input not in value if not isinstance(input, list) else all(i not in value for i in input)
+            elif operator == "start with":
+                matched = str(input).lower().startswith(str(value).lower()) if not isinstance(input, list) else "".join([str(i).lower() for i in input]).startswith(str(value).lower())
+            elif operator == "end with":
+                matched = str(input).lower().endswith(str(value).lower()) if not isinstance(input, list) else "".join([str(i).lower() for i in input]).endswith(str(value).lower())
+            elif operator == "empty":
+                matched = not input
+            elif operator == "not empty":
+                matched = bool(input)
+            elif operator == "=":
+                matched = input == value
+            elif operator == "≠":
+                matched = input != value
+            elif operator == ">":
+                matched = input > value
+            elif operator == "<":
+                matched = input < value
+            elif operator == "≥":
+                matched = input >= value
+            elif operator == "≤":
+                matched = input <= value
+        except Exception:
+            pass
+        if matched:
+            ids.extend(docids)
         return ids

     for k, v2docs in metas.items():
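
The reason the if-elif chain fixes #12285 is visible in a tiny sketch (illustrative only): building the old tuple list evaluates every condition eagerly, so `input in value` runs even when the operator is "=", while the chain only evaluates the matching branch.

input, value = 3.5, 2.0   # numeric metadata after conversion

# Old style: every tuple element is evaluated while the list is built,
# so `input in value` runs even though operator == "=".
try:
    conds = [("contains", input in value), ("=", input == value)]
except TypeError as e:
    print(e)  # argument of type 'float' is not iterable

# New style: only the branch for the active operator is evaluated.
operator = "="
if operator == "contains":
    matched = input in value
elif operator == "=":
    matched = input == value
print(matched)  # False (3.5 != 2.0), and no TypeError is raised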

View File

@ -334,6 +334,9 @@ def init_settings():
     DOC_BULK_SIZE = int(os.environ.get("DOC_BULK_SIZE", 4))
     EMBEDDING_BATCH_SIZE = int(os.environ.get("EMBEDDING_BATCH_SIZE", 16))
+    os.environ["DOTNET_SYSTEM_GLOBALIZATION_INVARIANT"] = "1"

 def check_and_install_torch():
     global PARALLEL_DEVICES
     try:

View File

@ -78,14 +78,21 @@ class DoclingParser(RAGFlowPdfParser):
def __images__(self, fnm, zoomin: int = 1, page_from=0, page_to=600, callback=None): def __images__(self, fnm, zoomin: int = 1, page_from=0, page_to=600, callback=None):
self.page_from = page_from self.page_from = page_from
self.page_to = page_to self.page_to = page_to
bytes_io = None
try: try:
Removed:
            opener = pdfplumber.open(fnm) if isinstance(fnm, (str, PathLike)) else pdfplumber.open(BytesIO(fnm))
Added:
            if not isinstance(fnm, (str, PathLike)):
                bytes_io = BytesIO(fnm)
            opener = pdfplumber.open(fnm) if isinstance(fnm, (str, PathLike)) else pdfplumber.open(bytes_io)
with opener as pdf: with opener as pdf:
pages = pdf.pages[page_from:page_to] pages = pdf.pages[page_from:page_to]
self.page_images = [p.to_image(resolution=72 * zoomin, antialias=True).original for p in pages] self.page_images = [p.to_image(resolution=72 * zoomin, antialias=True).original for p in pages]
except Exception as e: except Exception as e:
self.page_images = [] self.page_images = []
self.logger.exception(e) self.logger.exception(e)
finally:
if bytes_io:
bytes_io.close()
def _make_line_tag(self,bbox: _BBox) -> str: def _make_line_tag(self,bbox: _BBox) -> str:
if bbox is None: if bbox is None:
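The added `bytes_io` variable and `finally` block ensure the in-memory buffer is released even when parsing fails. A small standalone sketch of the same try/finally shape, with a plain byte read standing in for the parser (the function name and payload are made up for illustration):

```python
from io import BytesIO

def read_pages(source):
    # `source` may be a filesystem path or raw bytes. Keep a reference to the
    # BytesIO wrapper so it can be closed even if the work below raises.
    bytes_io = None
    try:
        if isinstance(source, (bytes, bytearray)):
            bytes_io = BytesIO(source)
            data = bytes_io.read()
        else:
            with open(source, "rb") as f:
                data = f.read()
        return data[:8]  # stand-in for the real parsing work
    finally:
        if bytes_io:
            bytes_io.close()

print(read_pages(b"%PDF-1.7 fake header"))  # b'%PDF-1.7'
```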

View File

@ -1061,8 +1061,8 @@ class RAGFlowPdfParser:
self.total_page = len(self.pdf.pages) self.total_page = len(self.pdf.pages)
except Exception: except Exception as e:
logging.exception("RAGFlowPdfParser __images__") logging.exception(f"RAGFlowPdfParser __images__, exception: {e}")
logging.info(f"__images__ dedupe_chars cost {timer() - start}s") logging.info(f"__images__ dedupe_chars cost {timer() - start}s")
self.outlines = [] self.outlines = []

View File

@ -16,6 +16,7 @@
import logging import logging
import sys import sys
import ast
import six import six
import cv2 import cv2
import numpy as np import numpy as np
@ -108,7 +109,14 @@ class NormalizeImage:
def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs): def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
if isinstance(scale, str): if isinstance(scale, str):
Removed:
            scale = eval(scale)
Added:
            try:
                scale = float(scale)
            except ValueError:
                if '/' in scale:
                    parts = scale.split('/')
                    scale = ast.literal_eval(parts[0]) / ast.literal_eval(parts[1])
                else:
                    scale = ast.literal_eval(scale)
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0) self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
mean = mean if mean is not None else [0.485, 0.456, 0.406] mean = mean if mean is not None else [0.485, 0.456, 0.406]
std = std if std is not None else [0.229, 0.224, 0.225] std = std if std is not None else [0.229, 0.224, 0.225]
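Swapping `eval` for `float` plus `ast.literal_eval` keeps the accepted inputs (plain numeric strings and simple `a/b` fractions such as `1./255.`) while refusing arbitrary expressions. A self-contained sketch of that fallback order; the `parse_scale` wrapper is illustrative, not the class's actual method:

```python
import ast

def parse_scale(scale):
    # Try a plain float first; fall back to a literal or an "a/b" fraction,
    # both parsed with ast.literal_eval instead of eval.
    if isinstance(scale, str):
        try:
            scale = float(scale)
        except ValueError:
            if '/' in scale:
                num, den = scale.split('/', 1)
                scale = ast.literal_eval(num) / ast.literal_eval(den)
            else:
                scale = ast.literal_eval(scale)
    return scale if scale is not None else 1.0 / 255.0

print(parse_scale("0.00392157"))  # 0.00392157
print(parse_scale("1./255."))     # 0.00392156862745098
```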

View File

@ -1,3 +1,10 @@
# -----------------------------------------------------------------------------
# SECURITY WARNING: DO NOT DEPLOY WITH DEFAULT PASSWORDS
# For non-local deployments, please change all passwords (ELASTIC_PASSWORD,
# MYSQL_PASSWORD, MINIO_PASSWORD, etc.) to strong, unique values.
# You can generate a random string using: openssl rand -hex 32
# -----------------------------------------------------------------------------
# ------------------------------ # ------------------------------
# docker env var for specifying vector db type at startup # docker env var for specifying vector db type at startup
# (based on the vector db type, the corresponding docker # (based on the vector db type, the corresponding docker
@ -30,6 +37,7 @@ ES_HOST=es01
ES_PORT=1200 ES_PORT=1200
# The password for Elasticsearch. # The password for Elasticsearch.
# WARNING: Change this for production!
ELASTIC_PASSWORD=infini_rag_flow ELASTIC_PASSWORD=infini_rag_flow
# the hostname where OpenSearch service is exposed, set it not the same as elasticsearch # the hostname where OpenSearch service is exposed, set it not the same as elasticsearch
@ -85,6 +93,7 @@ OB_DATAFILE_SIZE=${OB_DATAFILE_SIZE:-20G}
OB_LOG_DISK_SIZE=${OB_LOG_DISK_SIZE:-20G} OB_LOG_DISK_SIZE=${OB_LOG_DISK_SIZE:-20G}
# The password for MySQL. # The password for MySQL.
# WARNING: Change this for production!
MYSQL_PASSWORD=infini_rag_flow MYSQL_PASSWORD=infini_rag_flow
# The hostname where the MySQL service is exposed # The hostname where the MySQL service is exposed
MYSQL_HOST=mysql MYSQL_HOST=mysql

View File

@ -34,7 +34,7 @@ Enabling TOC extraction requires significant memory, computational resources, an
![](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/auto_metadata_settings.png) ![](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/auto_metadata_settings.png)
3. Click **+** to add new fields and enter the congiruation page. 3. Click **+** to add new fields and enter the configuration page.
![](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/metadata_field_settings.png) ![](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/metadata_field_settings.png)

View File

@ -340,13 +340,13 @@ Application startup complete.
setting->model providers->search->vllm->add ,configure as follow: setting->model providers->search->vllm->add ,configure as follow:
![add vllm](https://github.com/user-attachments/assets/6f1d9f1a-3507-465b-87a3-4427254fff86) ![add vllm](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/ragflow_vllm.png)
select vllm chat model as default llm model as follow: select vllm chat model as default llm model as follow:
![chat](https://github.com/user-attachments/assets/05efbd4b-2c18-4c6b-8d1c-52bae712372d) ![chat](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/ragflow_vllm1.png)
### 5.3 chat with vllm chat model ### 5.3 chat with vllm chat model
create chat->create conversations-chat as follow: create chat->create conversations-chat as follow:
![chat](https://github.com/user-attachments/assets/dc1885f6-23a9-48f1-8850-d5f59b5e8f67) ![chat](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/ragflow_vllm2.png)

View File

@ -1603,7 +1603,7 @@ In streaming mode, not all responses include a reference, as this depends on the
##### question: `str` ##### question: `str`
The question to start an AI-powered conversation. Ifthe **Begin** component takes parameters, a question is not required. The question to start an AI-powered conversation. If the **Begin** component takes parameters, a question is not required.
##### stream: `bool` ##### stream: `bool`

View File

@ -9,7 +9,7 @@ Key features, improvements and bug fixes in the latest releases.
## v0.23.0 ## v0.23.0
Released on December 29, 2025. Released on December 27, 2025.
### New features ### New features
@ -32,7 +32,7 @@ Released on December 29, 2025.
### Improvements ### Improvements
- Bumps RAGFlow's document engine, [Infinity](https://github.com/infiniflow/infinity) to v0.6.13 (backward compatible). - Bumps RAGFlow's document engine, [Infinity](https://github.com/infiniflow/infinity) to v0.6.15 (backward compatible).
### Data sources ### Data sources

helm/README.md (new file, 133 lines)
View File

@ -0,0 +1,133 @@
# RAGFlow Helm Chart
A Helm chart to deploy RAGFlow and its dependencies on Kubernetes.
- Components: RAGFlow (web/api) and optional dependencies (Infinity/Elasticsearch/OpenSearch, MySQL, MinIO, Redis)
- Requirements: Kubernetes >= 1.24, Helm >= 3.10
## Install
```bash
helm upgrade --install ragflow ./ \
--namespace ragflow --create-namespace
```
Uninstall:
```bash
helm uninstall ragflow -n ragflow
```
## Global Settings
- `global.repo`: Prepend a global image registry prefix for all images.
- Behavior: Replaces the registry part and keeps the image path (e.g., `quay.io/minio/minio` -> `registry.example.com/myproj/minio/minio`).
- Example: `global.repo: "registry.example.com/myproj"`
- `global.imagePullSecrets`: List of image pull secrets applied to all Pods.
- Example:
```yaml
global:
imagePullSecrets:
- name: regcred
```
## External Services (MySQL / MinIO / Redis)
The chart can deploy in-cluster services or connect to external ones. Toggle with `*.enabled`. When disabled, provide host/port via `env.*`.
- MySQL
- `mysql.enabled`: default `true`
- If `false`, set:
- `env.MYSQL_HOST` (required), `env.MYSQL_PORT` (default `3306`)
- `env.MYSQL_DBNAME` (default `rag_flow`), `env.MYSQL_PASSWORD` (required)
- `env.MYSQL_USER` (default `root` if omitted)
- MinIO
- `minio.enabled`: default `true`
- Configure:
- `env.MINIO_HOST` (optional external host), `env.MINIO_PORT` (default `9000`)
- `env.MINIO_ROOT_USER` (default `rag_flow`), `env.MINIO_PASSWORD` (optional)
- Redis (Valkey)
- `redis.enabled`: default `true`
- If `false`, set:
- `env.REDIS_HOST` (required), `env.REDIS_PORT` (default `6379`)
- `env.REDIS_PASSWORD` (optional; empty disables auth if server allows)
Notes:
- When `*.enabled=true`, the chart renders in-cluster resources and injects corresponding `*_HOST`/`*_PORT` automatically.
- Sensitive variables like `MYSQL_PASSWORD` are required; `MINIO_PASSWORD` and `REDIS_PASSWORD` are optional. All secrets are stored in a Secret.
### Example: use external MySQL, MinIO, and Redis
```yaml
# values.override.yaml
mysql:
enabled: false # use external MySQL
minio:
enabled: false # use external MinIO (S3 compatible)
redis:
enabled: false # use external Redis/Valkey
env:
# MySQL
MYSQL_HOST: mydb.example.com
MYSQL_PORT: "3306"
MYSQL_USER: root
MYSQL_DBNAME: rag_flow
MYSQL_PASSWORD: "<your-mysql-password>"
# MinIO
MINIO_HOST: s3.example.com
MINIO_PORT: "9000"
MINIO_ROOT_USER: rag_flow
MINIO_PASSWORD: "<your-minio-secret>"
# Redis
REDIS_HOST: redis.example.com
REDIS_PORT: "6379"
REDIS_PASSWORD: "<your-redis-pass>"
```
Apply:
```bash
helm upgrade --install ragflow ./helm -n ragflow -f values.override.yaml
```
## Document Engine Selection
Choose one of `infinity` (default), `elasticsearch`, or `opensearch` via `env.DOC_ENGINE`. The chart renders only the selected engine and sets the appropriate host variables.
```yaml
env:
DOC_ENGINE: infinity # or: elasticsearch | opensearch
# For elasticsearch
ELASTIC_PASSWORD: "<es-pass>"
# For opensearch
OPENSEARCH_PASSWORD: "<os-pass>"
```
## Ingress
Expose the web UI via Ingress:
```yaml
ingress:
enabled: true
className: nginx
hosts:
- host: ragflow.example.com
paths:
- path: /
pathType: Prefix
```
## Validate the Chart
```bash
helm lint ./helm
helm template ragflow ./helm > rendered.yaml
```
## Notes
- By default, the chart uses `DOC_ENGINE: infinity` and deploys in-cluster MySQL, MinIO, and Redis.
- The chart injects derived `*_HOST`/`*_PORT` and required secrets into a single Secret (`<release>-ragflow-env-config`).
- `global.repo` and `global.imagePullSecrets` apply to all Pods; per-component `*.image.pullSecrets` still work and are merged with global settings.

View File

@ -42,6 +42,31 @@ app.kubernetes.io/version: {{ .Chart.AppVersion | quote }}
app.kubernetes.io/managed-by: {{ .Release.Service }} app.kubernetes.io/managed-by: {{ .Release.Service }}
{{- end }} {{- end }}
{{/*
Resolve image repository with optional global repo prefix.
If .Values.global.repo is set, replace registry part and keep image path.
Detect existing registry by first segment containing '.' or ':' or being 'localhost'.
Usage: {{ include "ragflow.imageRepo" (dict "root" . "repo" .Values.foo.image.repository) }}
*/}}
{{- define "ragflow.imageRepo" -}}
{{- $root := .root -}}
{{- $repo := .repo -}}
{{- $global := $root.Values.global -}}
{{- if and $global $global.repo }}
{{- $parts := splitList "/" $repo -}}
{{- $first := index $parts 0 -}}
{{- $hasRegistry := or (regexMatch "\\." $first) (regexMatch ":" $first) (eq $first "localhost") -}}
{{- if $hasRegistry -}}
{{- $path := join "/" (rest $parts) -}}
{{- printf "%s/%s" $global.repo $path -}}
{{- else -}}
{{- printf "%s/%s" $global.repo $repo -}}
{{- end -}}
{{- else -}}
{{- $repo -}}
{{- end -}}
{{- end }}
{{/* {{/*
Selector labels Selector labels
*/}} */}}
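The helper's rule is easier to read outside Go templating: when the first path segment looks like a registry (it contains `.` or `:`, or equals `localhost`), that segment is replaced by `global.repo`; otherwise the prefix is simply prepended. A rough Python rendering of the same rule, for illustration only; the chart uses the template above, not this function:

```python
def resolve_image_repo(repo, global_repo=""):
    # Mirrors the ragflow.imageRepo helper: swap the registry segment for
    # global_repo when one is present, otherwise prepend global_repo.
    if not global_repo:
        return repo
    parts = repo.split("/")
    first = parts[0]
    has_registry = "." in first or ":" in first or first == "localhost"
    if has_registry:
        return f"{global_repo}/{'/'.join(parts[1:])}"
    return f"{global_repo}/{repo}"

print(resolve_image_repo("quay.io/minio/minio", "registry.example.com/myproj"))
# registry.example.com/myproj/minio/minio
print(resolve_image_repo("mysql", "registry.example.com/myproj"))
# registry.example.com/myproj/mysql
```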

View File

@ -32,7 +32,7 @@ spec:
{{- include "ragflow.selectorLabels" . | nindent 6 }} {{- include "ragflow.selectorLabels" . | nindent 6 }}
app.kubernetes.io/component: elasticsearch app.kubernetes.io/component: elasticsearch
{{- with .Values.elasticsearch.deployment.strategy }} {{- with .Values.elasticsearch.deployment.strategy }}
strategy: updateStrategy:
{{- . | toYaml | nindent 4 }} {{- . | toYaml | nindent 4 }}
{{- end }} {{- end }}
template: template:
@ -44,9 +44,9 @@ spec:
checksum/config-es: {{ include (print $.Template.BasePath "/elasticsearch-config.yaml") . | sha256sum }} checksum/config-es: {{ include (print $.Template.BasePath "/elasticsearch-config.yaml") . | sha256sum }}
checksum/config-env: {{ include (print $.Template.BasePath "/env.yaml") . | sha256sum }} checksum/config-env: {{ include (print $.Template.BasePath "/env.yaml") . | sha256sum }}
spec: spec:
{{- if or .Values.imagePullSecrets .Values.elasticsearch.image.pullSecrets }} {{- if or .Values.global.imagePullSecrets .Values.elasticsearch.image.pullSecrets }}
imagePullSecrets: imagePullSecrets:
{{- with .Values.imagePullSecrets }} {{- with .Values.global.imagePullSecrets }}
{{- toYaml . | nindent 8 }} {{- toYaml . | nindent 8 }}
{{- end }} {{- end }}
{{- with .Values.elasticsearch.image.pullSecrets }} {{- with .Values.elasticsearch.image.pullSecrets }}
@ -55,7 +55,7 @@ spec:
{{- end }} {{- end }}
initContainers: initContainers:
- name: fix-data-volume-permissions - name: fix-data-volume-permissions
image: {{ .Values.elasticsearch.initContainers.alpine.repository }}:{{ .Values.elasticsearch.initContainers.alpine.tag }} image: {{ include "ragflow.imageRepo" (dict "root" . "repo" .Values.elasticsearch.initContainers.alpine.repository) }}:{{ .Values.elasticsearch.initContainers.alpine.tag }}
{{- with .Values.elasticsearch.initContainers.alpine.pullPolicy }} {{- with .Values.elasticsearch.initContainers.alpine.pullPolicy }}
imagePullPolicy: {{ . }} imagePullPolicy: {{ . }}
{{- end }} {{- end }}
@ -67,7 +67,7 @@ spec:
- mountPath: /usr/share/elasticsearch/data - mountPath: /usr/share/elasticsearch/data
name: es-data name: es-data
- name: sysctl - name: sysctl
image: {{ .Values.elasticsearch.initContainers.busybox.repository }}:{{ .Values.elasticsearch.initContainers.busybox.tag }} image: {{ include "ragflow.imageRepo" (dict "root" . "repo" .Values.elasticsearch.initContainers.busybox.repository) }}:{{ .Values.elasticsearch.initContainers.busybox.tag }}
{{- with .Values.elasticsearch.initContainers.busybox.pullPolicy }} {{- with .Values.elasticsearch.initContainers.busybox.pullPolicy }}
imagePullPolicy: {{ . }} imagePullPolicy: {{ . }}
{{- end }} {{- end }}
@ -77,7 +77,7 @@ spec:
command: ["sysctl", "-w", "vm.max_map_count=262144"] command: ["sysctl", "-w", "vm.max_map_count=262144"]
containers: containers:
- name: elasticsearch - name: elasticsearch
image: {{ .Values.elasticsearch.image.repository }}:{{ .Values.elasticsearch.image.tag }} image: {{ include "ragflow.imageRepo" (dict "root" . "repo" .Values.elasticsearch.image.repository) }}:{{ .Values.elasticsearch.image.tag }}
{{- with .Values.elasticsearch.image.pullPolicy }} {{- with .Values.elasticsearch.image.pullPolicy }}
imagePullPolicy: {{ . }} imagePullPolicy: {{ . }}
{{- end }} {{- end }}

View File

@ -9,20 +9,39 @@ metadata:
type: Opaque type: Opaque
stringData: stringData:
{{- range $key, $val := .Values.env }} {{- range $key, $val := .Values.env }}
{{- if $val }} {{- if and $val (ne $key "MYSQL_HOST") (ne $key "MYSQL_PORT") (ne $key "MYSQL_USER") (ne $key "MINIO_HOST") (ne $key "MINIO_PORT") (ne $key "REDIS_HOST") (ne $key "REDIS_PORT") }}
{{ $key }}: {{ quote $val }} {{ $key }}: {{ quote $val }}
{{- end }} {{- end }}
{{- end }} {{- end }}
{{- /* {{- /*
Use host names derived from internal cluster DNS Use host names derived from internal cluster DNS
*/}} */}}
{{- if .Values.redis.enabled }}
REDIS_HOST: {{ printf "%s-redis.%s.svc" (include "ragflow.fullname" .) .Release.Namespace }} REDIS_HOST: {{ printf "%s-redis.%s.svc" (include "ragflow.fullname" .) .Release.Namespace }}
REDIS_PORT: "6379"
{{- else }}
REDIS_HOST: {{ required "env.REDIS_HOST is required when redis.enabled=false" .Values.env.REDIS_HOST | quote }}
REDIS_PORT: {{ default "6379" .Values.env.REDIS_PORT | quote }}
{{- end }}
{{- if .Values.mysql.enabled }}
MYSQL_HOST: {{ printf "%s-mysql.%s.svc" (include "ragflow.fullname" .) .Release.Namespace }} MYSQL_HOST: {{ printf "%s-mysql.%s.svc" (include "ragflow.fullname" .) .Release.Namespace }}
MYSQL_PORT: "3306"
{{- else }}
MYSQL_HOST: {{ required "env.MYSQL_HOST is required when mysql.enabled=false" .Values.env.MYSQL_HOST | quote }}
MYSQL_PORT: {{ default "3306" .Values.env.MYSQL_PORT | quote }}
MYSQL_USER: {{ default "root" .Values.env.MYSQL_USER | quote }}
{{- end }}
{{- if .Values.minio.enabled }}
MINIO_HOST: {{ printf "%s-minio.%s.svc" (include "ragflow.fullname" .) .Release.Namespace }} MINIO_HOST: {{ printf "%s-minio.%s.svc" (include "ragflow.fullname" .) .Release.Namespace }}
MINIO_PORT: "9000"
{{- else }}
MINIO_HOST: {{ default "" .Values.env.MINIO_HOST | quote }}
MINIO_PORT: {{ default "9000" .Values.env.MINIO_PORT | quote }}
{{- end }}
{{- /* {{- /*
Fail if passwords are not provided in release values Fail if passwords are not provided in release values
*/}} */}}
REDIS_PASSWORD: {{ .Values.env.REDIS_PASSWORD | required "REDIS_PASSWORD is required" }} REDIS_PASSWORD: {{ default "" .Values.env.REDIS_PASSWORD }}
{{- /* {{- /*
NOTE: MySQL uses MYSQL_ROOT_PASSWORD env var but Ragflow container expects NOTE: MySQL uses MYSQL_ROOT_PASSWORD env var but Ragflow container expects
MYSQL_PASSWORD so we need to define both as the same value here. MYSQL_PASSWORD so we need to define both as the same value here.
@ -31,10 +50,9 @@ stringData:
MYSQL_PASSWORD: {{ . }} MYSQL_PASSWORD: {{ . }}
MYSQL_ROOT_PASSWORD: {{ . }} MYSQL_ROOT_PASSWORD: {{ . }}
{{- end }} {{- end }}
{{- with .Values.env.MINIO_PASSWORD | required "MINIO_PASSWORD is required" }} {{- $minioPass := default "" .Values.env.MINIO_PASSWORD }}
MINIO_PASSWORD: {{ . }} MINIO_PASSWORD: {{ $minioPass }}
MINIO_ROOT_PASSWORD: {{ . }} MINIO_ROOT_PASSWORD: {{ $minioPass }}
{{- end }}
{{- /* {{- /*
Only provide env vars for enabled doc engine Only provide env vars for enabled doc engine
*/}} */}}

View File

@ -32,7 +32,7 @@ spec:
{{- include "ragflow.selectorLabels" . | nindent 6 }} {{- include "ragflow.selectorLabels" . | nindent 6 }}
app.kubernetes.io/component: infinity app.kubernetes.io/component: infinity
{{- with .Values.infinity.deployment.strategy }} {{- with .Values.infinity.deployment.strategy }}
strategy: updateStrategy:
{{- . | toYaml | nindent 4 }} {{- . | toYaml | nindent 4 }}
{{- end }} {{- end }}
template: template:
@ -43,9 +43,9 @@ spec:
annotations: annotations:
checksum/config: {{ include (print $.Template.BasePath "/env.yaml") . | sha256sum }} checksum/config: {{ include (print $.Template.BasePath "/env.yaml") . | sha256sum }}
spec: spec:
{{- if or .Values.imagePullSecrets .Values.infinity.image.pullSecrets }} {{- if or .Values.global.imagePullSecrets .Values.infinity.image.pullSecrets }}
imagePullSecrets: imagePullSecrets:
{{- with .Values.imagePullSecrets }} {{- with .Values.global.imagePullSecrets }}
{{- toYaml . | nindent 8 }} {{- toYaml . | nindent 8 }}
{{- end }} {{- end }}
{{- with .Values.infinity.image.pullSecrets }} {{- with .Values.infinity.image.pullSecrets }}
@ -54,7 +54,7 @@ spec:
{{- end }} {{- end }}
containers: containers:
- name: infinity - name: infinity
image: {{ .Values.infinity.image.repository }}:{{ .Values.infinity.image.tag }} image: {{ include "ragflow.imageRepo" (dict "root" . "repo" .Values.infinity.image.repository) }}:{{ .Values.infinity.image.tag }}
{{- with .Values.infinity.image.pullPolicy }} {{- with .Values.infinity.image.pullPolicy }}
imagePullPolicy: {{ . }} imagePullPolicy: {{ . }}
{{- end }} {{- end }}

View File

@ -35,7 +35,7 @@ spec:
{{- end }} {{- end }}
backend: backend:
service: service:
name: {{ $.Release.Name }} name: {{ include "ragflow.fullname" $ }}
port: port:
name: http name: http
{{- end }} {{- end }}

View File

@ -1,3 +1,4 @@
{{- if .Values.minio.enabled }}
--- ---
apiVersion: v1 apiVersion: v1
kind: PersistentVolumeClaim kind: PersistentVolumeClaim
@ -43,9 +44,9 @@ spec:
{{- include "ragflow.labels" . | nindent 8 }} {{- include "ragflow.labels" . | nindent 8 }}
app.kubernetes.io/component: minio app.kubernetes.io/component: minio
spec: spec:
{{- if or .Values.imagePullSecrets .Values.minio.image.pullSecrets }} {{- if or .Values.global.imagePullSecrets .Values.minio.image.pullSecrets }}
imagePullSecrets: imagePullSecrets:
{{- with .Values.imagePullSecrets }} {{- with .Values.global.imagePullSecrets }}
{{- toYaml . | nindent 8 }} {{- toYaml . | nindent 8 }}
{{- end }} {{- end }}
{{- with .Values.minio.image.pullSecrets }} {{- with .Values.minio.image.pullSecrets }}
@ -54,7 +55,7 @@ spec:
{{- end }} {{- end }}
containers: containers:
- name: minio - name: minio
image: {{ .Values.minio.image.repository }}:{{ .Values.minio.image.tag }} image: {{ include "ragflow.imageRepo" (dict "root" . "repo" .Values.minio.image.repository) }}:{{ .Values.minio.image.tag }}
{{- with .Values.minio.image.pullPolicy }} {{- with .Values.minio.image.pullPolicy }}
imagePullPolicy: {{ . }} imagePullPolicy: {{ . }}
{{- end }} {{- end }}
@ -103,3 +104,4 @@ spec:
port: 9001 port: 9001
targetPort: console targetPort: console
type: {{ .Values.minio.service.type }} type: {{ .Values.minio.service.type }}
{{- end }}

View File

@ -1,3 +1,4 @@
{{- if .Values.mysql.enabled }}
--- ---
apiVersion: v1 apiVersion: v1
kind: ConfigMap kind: ConfigMap
@ -7,3 +8,4 @@ data:
init.sql: |- init.sql: |-
CREATE DATABASE IF NOT EXISTS rag_flow; CREATE DATABASE IF NOT EXISTS rag_flow;
USE rag_flow; USE rag_flow;
{{- end }}

View File

@ -1,3 +1,4 @@
{{- if .Values.mysql.enabled }}
--- ---
apiVersion: v1 apiVersion: v1
kind: PersistentVolumeClaim kind: PersistentVolumeClaim
@ -32,7 +33,7 @@ spec:
{{- include "ragflow.selectorLabels" . | nindent 6 }} {{- include "ragflow.selectorLabels" . | nindent 6 }}
app.kubernetes.io/component: mysql app.kubernetes.io/component: mysql
{{- with .Values.mysql.deployment.strategy }} {{- with .Values.mysql.deployment.strategy }}
strategy: updateStrategy:
{{- . | toYaml | nindent 4 }} {{- . | toYaml | nindent 4 }}
{{- end }} {{- end }}
template: template:
@ -44,9 +45,9 @@ spec:
checksum/config-mysql: {{ include (print $.Template.BasePath "/mysql-config.yaml") . | sha256sum }} checksum/config-mysql: {{ include (print $.Template.BasePath "/mysql-config.yaml") . | sha256sum }}
checksum/config-env: {{ include (print $.Template.BasePath "/env.yaml") . | sha256sum }} checksum/config-env: {{ include (print $.Template.BasePath "/env.yaml") . | sha256sum }}
spec: spec:
{{- if or .Values.imagePullSecrets .Values.mysql.image.pullSecrets }} {{- if or .Values.global.imagePullSecrets .Values.mysql.image.pullSecrets }}
imagePullSecrets: imagePullSecrets:
{{- with .Values.imagePullSecrets }} {{- with .Values.global.imagePullSecrets }}
{{- toYaml . | nindent 8 }} {{- toYaml . | nindent 8 }}
{{- end }} {{- end }}
{{- with .Values.mysql.image.pullSecrets }} {{- with .Values.mysql.image.pullSecrets }}
@ -55,7 +56,7 @@ spec:
{{- end }} {{- end }}
containers: containers:
- name: mysql - name: mysql
image: {{ .Values.mysql.image.repository }}:{{ .Values.mysql.image.tag }} image: {{ include "ragflow.imageRepo" (dict "root" . "repo" .Values.mysql.image.repository) }}:{{ .Values.mysql.image.tag }}
{{- with .Values.mysql.image.pullPolicy }} {{- with .Values.mysql.image.pullPolicy }}
imagePullPolicy: {{ . }} imagePullPolicy: {{ . }}
{{- end }} {{- end }}
@ -108,3 +109,4 @@ spec:
port: 3306 port: 3306
targetPort: mysql targetPort: mysql
type: {{ .Values.mysql.service.type }} type: {{ .Values.mysql.service.type }}
{{- end }}

View File

@ -32,7 +32,7 @@ spec:
{{- include "ragflow.selectorLabels" . | nindent 6 }} {{- include "ragflow.selectorLabels" . | nindent 6 }}
app.kubernetes.io/component: opensearch app.kubernetes.io/component: opensearch
{{- with .Values.opensearch.deployment.strategy }} {{- with .Values.opensearch.deployment.strategy }}
strategy: updateStrategy:
{{- . | toYaml | nindent 4 }} {{- . | toYaml | nindent 4 }}
{{- end }} {{- end }}
template: template:
@ -44,9 +44,9 @@ spec:
checksum/config-opensearch: {{ include (print $.Template.BasePath "/opensearch-config.yaml") . | sha256sum }} checksum/config-opensearch: {{ include (print $.Template.BasePath "/opensearch-config.yaml") . | sha256sum }}
checksum/config-env: {{ include (print $.Template.BasePath "/env.yaml") . | sha256sum }} checksum/config-env: {{ include (print $.Template.BasePath "/env.yaml") . | sha256sum }}
spec: spec:
{{- if or .Values.imagePullSecrets .Values.opensearch.image.pullSecrets }} {{- if or .Values.global.imagePullSecrets .Values.opensearch.image.pullSecrets }}
imagePullSecrets: imagePullSecrets:
{{- with .Values.imagePullSecrets }} {{- with .Values.global.imagePullSecrets }}
{{- toYaml . | nindent 8 }} {{- toYaml . | nindent 8 }}
{{- end }} {{- end }}
{{- with .Values.opensearch.image.pullSecrets }} {{- with .Values.opensearch.image.pullSecrets }}
@ -55,7 +55,7 @@ spec:
{{- end }} {{- end }}
initContainers: initContainers:
- name: fix-data-volume-permissions - name: fix-data-volume-permissions
image: {{ .Values.opensearch.initContainers.alpine.repository }}:{{ .Values.opensearch.initContainers.alpine.tag }} image: {{ include "ragflow.imageRepo" (dict "root" . "repo" .Values.opensearch.initContainers.alpine.repository) }}:{{ .Values.opensearch.initContainers.alpine.tag }}
{{- with .Values.opensearch.initContainers.alpine.pullPolicy }} {{- with .Values.opensearch.initContainers.alpine.pullPolicy }}
imagePullPolicy: {{ . }} imagePullPolicy: {{ . }}
{{- end }} {{- end }}
@ -67,7 +67,7 @@ spec:
- mountPath: /usr/share/opensearch/data - mountPath: /usr/share/opensearch/data
name: opensearch-data name: opensearch-data
- name: sysctl - name: sysctl
image: {{ .Values.opensearch.initContainers.busybox.repository }}:{{ .Values.opensearch.initContainers.busybox.tag }} image: {{ include "ragflow.imageRepo" (dict "root" . "repo" .Values.opensearch.initContainers.busybox.repository) }}:{{ .Values.opensearch.initContainers.busybox.tag }}
{{- with .Values.opensearch.initContainers.busybox.pullPolicy }} {{- with .Values.opensearch.initContainers.busybox.pullPolicy }}
imagePullPolicy: {{ . }} imagePullPolicy: {{ . }}
{{- end }} {{- end }}
@ -77,7 +77,7 @@ spec:
command: ["sysctl", "-w", "vm.max_map_count=262144"] command: ["sysctl", "-w", "vm.max_map_count=262144"]
containers: containers:
- name: opensearch - name: opensearch
image: {{ .Values.opensearch.image.repository }}:{{ .Values.opensearch.image.tag }} image: {{ include "ragflow.imageRepo" (dict "root" . "repo" .Values.opensearch.image.repository) }}:{{ .Values.opensearch.image.tag }}
{{- with .Values.opensearch.image.pullPolicy }} {{- with .Values.opensearch.image.pullPolicy }}
imagePullPolicy: {{ . }} imagePullPolicy: {{ . }}
{{- end }} {{- end }}

View File

@ -25,9 +25,9 @@ spec:
checksum/config-env: {{ include (print $.Template.BasePath "/env.yaml") . | sha256sum }} checksum/config-env: {{ include (print $.Template.BasePath "/env.yaml") . | sha256sum }}
checksum/config-ragflow: {{ include (print $.Template.BasePath "/ragflow_config.yaml") . | sha256sum }} checksum/config-ragflow: {{ include (print $.Template.BasePath "/ragflow_config.yaml") . | sha256sum }}
spec: spec:
{{- if or .Values.imagePullSecrets .Values.ragflow.image.pullSecrets }} {{- if or .Values.global.imagePullSecrets .Values.ragflow.image.pullSecrets }}
imagePullSecrets: imagePullSecrets:
{{- with .Values.imagePullSecrets }} {{- with .Values.global.imagePullSecrets }}
{{- toYaml . | nindent 8 }} {{- toYaml . | nindent 8 }}
{{- end }} {{- end }}
{{- with .Values.ragflow.image.pullSecrets }} {{- with .Values.ragflow.image.pullSecrets }}
@ -36,7 +36,7 @@ spec:
{{- end }} {{- end }}
containers: containers:
- name: ragflow - name: ragflow
image: {{ .Values.ragflow.image.repository }}:{{ .Values.ragflow.image.tag }} image: {{ include "ragflow.imageRepo" (dict "root" . "repo" .Values.ragflow.image.repository) }}:{{ .Values.ragflow.image.tag }}
{{- with .Values.ragflow.image.pullPolicy }} {{- with .Values.ragflow.image.pullPolicy }}
imagePullPolicy: {{ . }} imagePullPolicy: {{ . }}
{{- end }} {{- end }}

View File

@ -1,3 +1,4 @@
{{- if .Values.redis.enabled }}
--- ---
apiVersion: v1 apiVersion: v1
kind: Service kind: Service
@ -40,9 +41,9 @@ spec:
annotations: annotations:
checksum/config-env: {{ include (print $.Template.BasePath "/env.yaml") . | sha256sum }} checksum/config-env: {{ include (print $.Template.BasePath "/env.yaml") . | sha256sum }}
spec: spec:
{{- if or .Values.imagePullSecrets .Values.redis.image.pullSecrets }} {{- if or .Values.global.imagePullSecrets .Values.redis.image.pullSecrets }}
imagePullSecrets: imagePullSecrets:
{{- with .Values.imagePullSecrets }} {{- with .Values.global.imagePullSecrets }}
{{- toYaml . | nindent 8 }} {{- toYaml . | nindent 8 }}
{{- end }} {{- end }}
{{- with .Values.redis.image.pullSecrets }} {{- with .Values.redis.image.pullSecrets }}
@ -52,7 +53,7 @@ spec:
terminationGracePeriodSeconds: 60 terminationGracePeriodSeconds: 60
containers: containers:
- name: redis - name: redis
image: {{ .Values.redis.image.repository }}:{{ .Values.redis.image.tag }} image: {{ include "ragflow.imageRepo" (dict "root" . "repo" .Values.redis.image.repository) }}:{{ .Values.redis.image.tag }}
{{- with .Values.redis.image.pullPolicy }} {{- with .Values.redis.image.pullPolicy }}
imagePullPolicy: {{ . }} imagePullPolicy: {{ . }}
{{- end }} {{- end }}
@ -131,3 +132,4 @@ spec:
matchLabels: matchLabels:
{{- include "ragflow.selectorLabels" . | nindent 6 }} {{- include "ragflow.selectorLabels" . | nindent 6 }}
app.kubernetes.io/component: redis app.kubernetes.io/component: redis
{{- end }}

View File

@ -9,7 +9,7 @@ metadata:
spec: spec:
containers: containers:
- name: wget - name: wget
image: busybox image: {{ include "ragflow.imageRepo" (dict "root" . "repo" "busybox") }}
command: command:
- 'wget' - 'wget'
args: args:

View File

@ -1,7 +1,14 @@
# Based on docker compose .env file # Based on docker compose .env file
# Global image pull secrets configuration # Global image pull secrets configuration
imagePullSecrets: [] global:
# Global image repo prefix to render all images from a mirror/registry.
# Example: "registry.example.com/myproj"
# When set, template will replace the registry part of each image and keep the path.
# Leave empty to use per-image repositories as-is.
repo: ""
# Global image pull secrets for all pods
imagePullSecrets: []
env: env:
# The type of doc engine to use. # The type of doc engine to use.
@ -27,14 +34,28 @@ env:
MYSQL_PASSWORD: infini_rag_flow_helm MYSQL_PASSWORD: infini_rag_flow_helm
# The database of the MySQL service to use # The database of the MySQL service to use
MYSQL_DBNAME: rag_flow MYSQL_DBNAME: rag_flow
# External MySQL host (only required when mysql.enabled=false)
# MYSQL_HOST: ""
# External MySQL port (defaults to 3306 if not set)
# MYSQL_PORT: "3306"
# External MySQL user (only when mysql.enabled=false), default is root if omitted
# MYSQL_USER: "root"
# The username for MinIO. # The username for MinIO.
MINIO_ROOT_USER: rag_flow MINIO_ROOT_USER: rag_flow
# The password for MinIO # The password for MinIO
MINIO_PASSWORD: infini_rag_flow_helm MINIO_PASSWORD: infini_rag_flow_helm
# External MinIO host
# MINIO_HOST: ""
# External MinIO port (defaults to 9000 if not set)
# MINIO_PORT: "9000"
# The password for Redis # The password for Redis
REDIS_PASSWORD: infini_rag_flow_helm REDIS_PASSWORD: infini_rag_flow_helm
# External Redis host (only required when redis.enabled=false)
# REDIS_HOST: ""
# External Redis port (defaults to 6379 if not set)
# REDIS_PORT: "6379"
# The local time zone. # The local time zone.
TZ: "Asia/Shanghai" TZ: "Asia/Shanghai"
@ -163,6 +184,7 @@ opensearch:
type: ClusterIP type: ClusterIP
minio: minio:
enabled: true
image: image:
repository: quay.io/minio/minio repository: quay.io/minio/minio
tag: RELEASE.2023-12-20T01-00-02Z tag: RELEASE.2023-12-20T01-00-02Z
@ -178,6 +200,7 @@ minio:
type: ClusterIP type: ClusterIP
mysql: mysql:
enabled: true
image: image:
repository: mysql repository: mysql
tag: 8.0.39 tag: 8.0.39
@ -193,6 +216,7 @@ mysql:
type: ClusterIP type: ClusterIP
redis: redis:
enabled: true
image: image:
repository: valkey/valkey repository: valkey/valkey
tag: 8 tag: 8

View File

@ -71,7 +71,7 @@ class MessageService:
filter_dict["session_id"] = keywords filter_dict["session_id"] = keywords
order_by = OrderByExpr() order_by = OrderByExpr()
order_by.desc("valid_at") order_by.desc("valid_at")
res = settings.msgStoreConn.search( res, total_count = settings.msgStoreConn.search(
select_fields=[ select_fields=[
"message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at", "message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at",
"invalid_at", "forget_at", "status" "invalid_at", "forget_at", "status"
@ -82,13 +82,12 @@ class MessageService:
offset=(page-1)*page_size, limit=page_size, offset=(page-1)*page_size, limit=page_size,
index_names=index, memory_ids=[memory_id], agg_fields=[], hide_forgotten=False index_names=index, memory_ids=[memory_id], agg_fields=[], hide_forgotten=False
) )
if not res: if not total_count:
return { return {
"message_list": [], "message_list": [],
"total_count": 0 "total_count": 0
} }
total_count = settings.msgStoreConn.get_total(res)
doc_mapping = settings.msgStoreConn.get_fields(res, [ doc_mapping = settings.msgStoreConn.get_fields(res, [
"message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id",
"valid_at", "invalid_at", "forget_at", "status" "valid_at", "invalid_at", "forget_at", "status"
@ -107,7 +106,7 @@ class MessageService:
} }
order_by = OrderByExpr() order_by = OrderByExpr()
order_by.desc("valid_at") order_by.desc("valid_at")
res = settings.msgStoreConn.search( res, total_count = settings.msgStoreConn.search(
select_fields=[ select_fields=[
"message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at", "message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at",
"invalid_at", "forget_at", "status", "content" "invalid_at", "forget_at", "status", "content"
@ -118,7 +117,7 @@ class MessageService:
offset=0, limit=limit, offset=0, limit=limit,
index_names=index_names, memory_ids=memory_ids, agg_fields=[] index_names=index_names, memory_ids=memory_ids, agg_fields=[]
) )
if not res: if not total_count:
return [] return []
doc_mapping = settings.msgStoreConn.get_fields(res, [ doc_mapping = settings.msgStoreConn.get_fields(res, [
@ -136,7 +135,7 @@ class MessageService:
order_by = OrderByExpr() order_by = OrderByExpr()
order_by.desc("valid_at") order_by.desc("valid_at")
res = settings.msgStoreConn.search( res, total_count = settings.msgStoreConn.search(
select_fields=[ select_fields=[
"message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id",
"valid_at", "valid_at",
@ -149,7 +148,7 @@ class MessageService:
offset=0, limit=top_n, offset=0, limit=top_n,
index_names=index_names, memory_ids=memory_ids, agg_fields=[] index_names=index_names, memory_ids=memory_ids, agg_fields=[]
) )
if not res: if not total_count:
return [] return []
docs = settings.msgStoreConn.get_fields(res, [ docs = settings.msgStoreConn.get_fields(res, [
@ -195,23 +194,22 @@ class MessageService:
select_fields = ["message_id", "content", "content_embed"] select_fields = ["message_id", "content", "content_embed"]
_index_name = index_name(uid) _index_name = index_name(uid)
res = settings.msgStoreConn.get_forgotten_messages(select_fields, _index_name, memory_id) res = settings.msgStoreConn.get_forgotten_messages(select_fields, _index_name, memory_id)
Before (early return when nothing is forgotten):

    if not res:
        return []
    message_list = settings.msgStoreConn.get_fields(res, select_fields)
    current_size = 0
    ids_to_remove = []
    for message in message_list.values():
        if current_size < size_to_delete:
            current_size += cls.calculate_message_size(message)
            ids_to_remove.append(message["message_id"])
        else:
            return ids_to_remove, current_size
    if current_size >= size_to_delete:
        return ids_to_remove, current_size

After (falls through to the search below when nothing is forgotten):

    current_size = 0
    ids_to_remove = []
    if res:
        message_list = settings.msgStoreConn.get_fields(res, select_fields)
        for message in message_list.values():
            if current_size < size_to_delete:
                current_size += cls.calculate_message_size(message)
                ids_to_remove.append(message["message_id"])
            else:
                return ids_to_remove, current_size
        if current_size >= size_to_delete:
            return ids_to_remove, current_size
order_by = OrderByExpr() order_by = OrderByExpr()
order_by.asc("valid_at") order_by.asc("valid_at")
res = settings.msgStoreConn.search( res, total_count = settings.msgStoreConn.search(
select_fields=select_fields, select_fields=select_fields,
highlight_fields=[], highlight_fields=[],
condition={}, condition={},
@ -240,7 +238,7 @@ class MessageService:
order_by = OrderByExpr() order_by = OrderByExpr()
order_by.desc("message_id") order_by.desc("message_id")
index_names = [index_name(uid) for uid in uid_list] index_names = [index_name(uid) for uid in uid_list]
res = settings.msgStoreConn.search( res, total_count = settings.msgStoreConn.search(
select_fields=["message_id"], select_fields=["message_id"],
highlight_fields=[], highlight_fields=[],
condition={}, condition={},
@ -250,7 +248,7 @@ class MessageService:
index_names=index_names, memory_ids=memory_ids, index_names=index_names, memory_ids=memory_ids,
agg_fields=[], hide_forgotten=False agg_fields=[], hide_forgotten=False
) )
if not res: if not total_count:
return 1 return 1
docs = settings.msgStoreConn.get_fields(res, ["message_id"]) docs = settings.msgStoreConn.get_fields(res, ["message_id"])

View File

@ -130,7 +130,7 @@ class ESConnection(ESConnectionBase):
exist_index_list = [idx for idx in index_names if self.index_exist(idx)] exist_index_list = [idx for idx in index_names if self.index_exist(idx)]
if not exist_index_list: if not exist_index_list:
return None return None, 0
bool_query = Q("bool", must=[], must_not=[]) bool_query = Q("bool", must=[], must_not=[])
if hide_forgotten: if hide_forgotten:
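The MessageService call sites above now unpack a pair from `search`, and this connection-layer change returns `(None, 0)` when no index exists, so `total_count` alone drives the empty-result branches. A small sketch of the contract those callers rely on; `FakeMsgStore` is a stand-in, not the real ESConnection:

```python
from typing import Any, Tuple

class FakeMsgStore:
    def __init__(self, rows):
        self.rows = rows

    def search(self, index_exists=True) -> Tuple[Any, int]:
        # Mirrors the updated connection: missing index -> (None, 0),
        # otherwise the raw result plus its total count.
        if not index_exists:
            return None, 0
        return {"hits": self.rows}, len(self.rows)

store = FakeMsgStore(rows=[{"message_id": "m1"}])
res, total_count = store.search()
if not total_count:
    print("empty")
else:
    print(total_count, res["hits"][0]["message_id"])  # 1 m1
```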

View File

@ -149,6 +149,9 @@ dependencies = [
# "cryptography==46.0.3", # "cryptography==46.0.3",
# "jinja2>=3.1.0", # "jinja2>=3.1.0",
"pyairtable>=3.3.0", "pyairtable>=3.3.0",
"pygithub>=2.8.1",
"asana>=5.2.2",
"python-gitlab>=7.0.0",
] ]
[dependency-groups] [dependency-groups]

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import logging
import os import os
import re import re
import tempfile import tempfile
@ -28,13 +28,14 @@ def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs):
doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"]) doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
# is it English # is it English
eng = lang.lower() == "english" # is_english(sections) is_english = lang.lower() == "english" # is_english(sections)
try: try:
_, ext = os.path.splitext(filename) _, ext = os.path.splitext(filename)
if not ext: if not ext:
raise RuntimeError("No extension detected.") raise RuntimeError("No extension detected.")
if ext not in [".da", ".wave", ".wav", ".mp3", ".wav", ".aac", ".flac", ".ogg", ".aiff", ".au", ".midi", ".wma", ".realaudio", ".vqf", ".oggvorbis", ".aac", ".ape"]: if ext not in [".da", ".wave", ".wav", ".mp3", ".wav", ".aac", ".flac", ".ogg", ".aiff", ".au", ".midi", ".wma",
".realaudio", ".vqf", ".oggvorbis", ".aac", ".ape"]:
raise RuntimeError(f"Extension {ext} is not supported yet.") raise RuntimeError(f"Extension {ext} is not supported yet.")
tmp_path = "" tmp_path = ""
@ -48,7 +49,7 @@ def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs):
ans = seq2txt_mdl.transcription(tmp_path) ans = seq2txt_mdl.transcription(tmp_path)
callback(0.8, "Sequence2Txt LLM respond: %s ..." % ans[:32]) callback(0.8, "Sequence2Txt LLM respond: %s ..." % ans[:32])
tokenize(doc, ans, eng) tokenize(doc, ans, is_english)
return [doc] return [doc]
except Exception as e: except Exception as e:
callback(prog=-1, msg=str(e)) callback(prog=-1, msg=str(e))
@ -56,6 +57,7 @@ def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs):
if tmp_path and os.path.exists(tmp_path): if tmp_path and os.path.exists(tmp_path):
try: try:
os.unlink(tmp_path) os.unlink(tmp_path)
except Exception: except Exception as e:
logging.exception(f"Failed to remove temporary file: {tmp_path}, exception: {e}")
pass pass
return [] return []

View File

@ -22,7 +22,7 @@ from deepdoc.parser.utils import get_text
from rag.app import naive from rag.app import naive
from rag.app.naive import by_plaintext, PARSERS from rag.app.naive import by_plaintext, PARSERS
from common.parser_config_utils import normalize_layout_recognizer from common.parser_config_utils import normalize_layout_recognizer
from rag.nlp import bullets_category, is_english,remove_contents_table, \ from rag.nlp import bullets_category, is_english, remove_contents_table, \
hierarchical_merge, make_colon_as_title, naive_merge, random_choices, tokenize_table, \ hierarchical_merge, make_colon_as_title, naive_merge, random_choices, tokenize_table, \
tokenize_chunks, attach_media_context tokenize_chunks, attach_media_context
from rag.nlp import rag_tokenizer from rag.nlp import rag_tokenizer
@ -91,9 +91,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
filename, binary=binary, from_page=from_page, to_page=to_page) filename, binary=binary, from_page=from_page, to_page=to_page)
remove_contents_table(sections, eng=is_english( remove_contents_table(sections, eng=is_english(
random_choices([t for t, _ in sections], k=200))) random_choices([t for t, _ in sections], k=200)))
tbls = vision_figure_parser_docx_wrapper(sections=sections,tbls=tbls,callback=callback,**kwargs) tbls = vision_figure_parser_docx_wrapper(sections=sections, tbls=tbls, callback=callback, **kwargs)
# tbls = [((None, lns), None) for lns in tbls] # tbls = [((None, lns), None) for lns in tbls]
sections=[(item[0],item[1] if item[1] is not None else "") for item in sections if not isinstance(item[1], Image.Image)] sections = [(item[0], item[1] if item[1] is not None else "") for item in sections if
not isinstance(item[1], Image.Image)]
callback(0.8, "Finish parsing.") callback(0.8, "Finish parsing.")
elif re.search(r"\.pdf$", filename, re.IGNORECASE): elif re.search(r"\.pdf$", filename, re.IGNORECASE):
@ -109,14 +110,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
callback(0.1, "Start to parse.") callback(0.1, "Start to parse.")
sections, tables, pdf_parser = parser( sections, tables, pdf_parser = parser(
filename = filename, filename=filename,
binary = binary, binary=binary,
from_page = from_page, from_page=from_page,
to_page = to_page, to_page=to_page,
lang = lang, lang=lang,
callback = callback, callback=callback,
pdf_cls = Pdf, pdf_cls=Pdf,
layout_recognizer = layout_recognizer, layout_recognizer=layout_recognizer,
mineru_llm_name=parser_model_name, mineru_llm_name=parser_model_name,
**kwargs **kwargs
) )
@ -175,7 +176,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
for ck in hierarchical_merge(bull, sections, 5)] for ck in hierarchical_merge(bull, sections, 5)]
else: else:
sections = [s.split("@") for s, _ in sections] sections = [s.split("@") for s, _ in sections]
sections = [(pr[0], "@" + pr[1]) if len(pr) == 2 else (pr[0], '') for pr in sections ] sections = [(pr[0], "@" + pr[1]) if len(pr) == 2 else (pr[0], '') for pr in sections]
chunks = naive_merge( chunks = naive_merge(
sections, sections,
parser_config.get("chunk_token_num", 256), parser_config.get("chunk_token_num", 256),
@ -199,6 +200,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
def dummy(prog=None, msg=""): def dummy(prog=None, msg=""):
pass pass
chunk(sys.argv[1], from_page=1, to_page=10, callback=dummy) chunk(sys.argv[1], from_page=1, to_page=10, callback=dummy)

View File

@ -26,13 +26,13 @@ import io
def chunk( def chunk(
filename, filename,
binary=None, binary=None,
from_page=0, from_page=0,
to_page=100000, to_page=100000,
lang="Chinese", lang="Chinese",
callback=None, callback=None,
**kwargs, **kwargs,
): ):
""" """
Only eml is supported Only eml is supported
@ -93,7 +93,8 @@ def chunk(
_add_content(msg, msg.get_content_type()) _add_content(msg, msg.get_content_type())
sections = TxtParser.parser_txt("\n".join(text_txt)) + [ sections = TxtParser.parser_txt("\n".join(text_txt)) + [
(line, "") for line in HtmlParser.parser_txt("\n".join(html_txt), chunk_token_num=parser_config["chunk_token_num"]) if line (line, "") for line in
HtmlParser.parser_txt("\n".join(html_txt), chunk_token_num=parser_config["chunk_token_num"]) if line
] ]
st = timer() st = timer()
@ -126,7 +127,9 @@ def chunk(
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
def dummy(prog=None, msg=""): def dummy(prog=None, msg=""):
pass pass
chunk(sys.argv[1], callback=dummy) chunk(sys.argv[1], callback=dummy)

View File

@ -29,8 +29,6 @@ from rag.app.naive import by_plaintext, PARSERS
from common.parser_config_utils import normalize_layout_recognizer from common.parser_config_utils import normalize_layout_recognizer
class Docx(DocxParser): class Docx(DocxParser):
def __init__(self): def __init__(self):
pass pass
@ -58,37 +56,36 @@ class Docx(DocxParser):
return [line for line in lines if line] return [line for line in lines if line]
def __call__(self, filename, binary=None, from_page=0, to_page=100000): def __call__(self, filename, binary=None, from_page=0, to_page=100000):
self.doc = Document( self.doc = Document(
filename) if not binary else Document(BytesIO(binary)) filename) if not binary else Document(BytesIO(binary))
pn = 0 pn = 0
lines = [] lines = []
level_set = set() level_set = set()
bull = bullets_category([p.text for p in self.doc.paragraphs]) bull = bullets_category([p.text for p in self.doc.paragraphs])
for p in self.doc.paragraphs: for p in self.doc.paragraphs:
if pn > to_page: if pn > to_page:
break break
question_level, p_text = docx_question_level(p, bull) question_level, p_text = docx_question_level(p, bull)
if not p_text.strip("\n"): if not p_text.strip("\n"):
(indentation-only change: these lines keep the same text and are re-indented)
        continue
    lines.append((question_level, p_text))
    level_set.add(question_level)
    for run in p.runs:
        if 'lastRenderedPageBreak' in run._element.xml:
            pn += 1
            continue
        if 'w:br' in run._element.xml and 'type="page"' in run._element.xml:
            pn += 1
sorted_levels = sorted(level_set) sorted_levels = sorted(level_set)
h2_level = sorted_levels[1] if len(sorted_levels) > 1 else 1 h2_level = sorted_levels[1] if len(sorted_levels) > 1 else 1
h2_level = sorted_levels[-2] if h2_level == sorted_levels[-1] and len(sorted_levels) > 2 else h2_level h2_level = sorted_levels[-2] if h2_level == sorted_levels[-1] and len(sorted_levels) > 2 else h2_level
root = Node(level=0, depth=h2_level, texts=[]) root = Node(level=0, depth=h2_level, texts=[])
root.build_tree(lines) root.build_tree(lines)
return [element for element in root.get_tree() if element]
return [element for element in root.get_tree() if element]
def __str__(self) -> str: def __str__(self) -> str:
return f''' return f'''
@ -121,8 +118,7 @@ class Pdf(PdfParser):
start = timer() start = timer()
self._layouts_rec(zoomin) self._layouts_rec(zoomin)
callback(0.67, "Layout analysis ({:.2f}s)".format(timer() - start)) callback(0.67, "Layout analysis ({:.2f}s)".format(timer() - start))
logging.debug("layouts:".format( logging.debug("layouts: {}".format((timer() - start)))
))
self._naive_vertical_merge() self._naive_vertical_merge()
callback(0.8, "Text extraction ({:.2f}s)".format(timer() - start)) callback(0.8, "Text extraction ({:.2f}s)".format(timer() - start))
@ -168,14 +164,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
callback(0.1, "Start to parse.") callback(0.1, "Start to parse.")
raw_sections, tables, pdf_parser = parser( raw_sections, tables, pdf_parser = parser(
filename = filename, filename=filename,
binary = binary, binary=binary,
from_page = from_page, from_page=from_page,
to_page = to_page, to_page=to_page,
lang = lang, lang=lang,
callback = callback, callback=callback,
pdf_cls = Pdf, pdf_cls=Pdf,
layout_recognizer = layout_recognizer, layout_recognizer=layout_recognizer,
mineru_llm_name=parser_model_name, mineru_llm_name=parser_model_name,
**kwargs **kwargs
) )
@ -226,7 +222,6 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
raise NotImplementedError( raise NotImplementedError(
"file type not supported yet(doc, docx, pdf, txt supported)") "file type not supported yet(doc, docx, pdf, txt supported)")
# Remove 'Contents' part # Remove 'Contents' part
remove_contents_table(sections, eng) remove_contents_table(sections, eng)
@ -234,7 +229,6 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
bull = bullets_category(sections) bull = bullets_category(sections)
res = tree_merge(bull, sections, 2) res = tree_merge(bull, sections, 2)
if not res: if not res:
callback(0.99, "No chunk parsed out.") callback(0.99, "No chunk parsed out.")
@ -243,9 +237,13 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
# chunks = hierarchical_merge(bull, sections, 5) # chunks = hierarchical_merge(bull, sections, 5)
# return tokenize_chunks(["\n".join(ck)for ck in chunks], doc, eng, pdf_parser) # return tokenize_chunks(["\n".join(ck)for ck in chunks], doc, eng, pdf_parser)
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
def dummy(prog=None, msg=""): def dummy(prog=None, msg=""):
pass pass
chunk(sys.argv[1], callback=dummy) chunk(sys.argv[1], callback=dummy)

View File

@ -20,15 +20,17 @@ import re
from common.constants import ParserType from common.constants import ParserType
from io import BytesIO from io import BytesIO
from rag.nlp import rag_tokenizer, tokenize, tokenize_table, bullets_category, title_frequency, tokenize_chunks, docx_question_level, attach_media_context from rag.nlp import rag_tokenizer, tokenize, tokenize_table, bullets_category, title_frequency, tokenize_chunks, \
docx_question_level, attach_media_context
from common.token_utils import num_tokens_from_string from common.token_utils import num_tokens_from_string
from deepdoc.parser import PdfParser, DocxParser from deepdoc.parser import PdfParser, DocxParser
from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper,vision_figure_parser_docx_wrapper from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper, vision_figure_parser_docx_wrapper
from docx import Document from docx import Document
from PIL import Image from PIL import Image
from rag.app.naive import by_plaintext, PARSERS from rag.app.naive import by_plaintext, PARSERS
from common.parser_config_utils import normalize_layout_recognizer from common.parser_config_utils import normalize_layout_recognizer
class Pdf(PdfParser): class Pdf(PdfParser):
def __init__(self): def __init__(self):
self.model_speciess = ParserType.MANUAL.value self.model_speciess = ParserType.MANUAL.value
@ -129,11 +131,11 @@ class Docx(DocxParser):
question_level, p_text = 0, '' question_level, p_text = 0, ''
if from_page <= pn < to_page and p.text.strip(): if from_page <= pn < to_page and p.text.strip():
question_level, p_text = docx_question_level(p) question_level, p_text = docx_question_level(p)
if not question_level or question_level > 6: # not a question if not question_level or question_level > 6: # not a question
last_answer = f'{last_answer}\n{p_text}' last_answer = f'{last_answer}\n{p_text}'
current_image = self.get_picture(self.doc, p) current_image = self.get_picture(self.doc, p)
last_image = self.concat_img(last_image, current_image) last_image = self.concat_img(last_image, current_image)
else: # is a question else: # is a question
if last_answer or last_image: if last_answer or last_image:
sum_question = '\n'.join(question_stack) sum_question = '\n'.join(question_stack)
if sum_question: if sum_question:
@ -159,14 +161,14 @@ class Docx(DocxParser):
tbls = [] tbls = []
for tb in self.doc.tables: for tb in self.doc.tables:
html= "<table>" html = "<table>"
for r in tb.rows: for r in tb.rows:
html += "<tr>" html += "<tr>"
i = 0 i = 0
while i < len(r.cells): while i < len(r.cells):
span = 1 span = 1
c = r.cells[i] c = r.cells[i]
for j in range(i+1, len(r.cells)): for j in range(i + 1, len(r.cells)):
if c.text == r.cells[j].text: if c.text == r.cells[j].text:
span += 1 span += 1
i = j i = j
@ -211,16 +213,16 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
kwargs.pop("parse_method", None) kwargs.pop("parse_method", None)
kwargs.pop("mineru_llm_name", None) kwargs.pop("mineru_llm_name", None)
sections, tbls, pdf_parser = pdf_parser( sections, tbls, pdf_parser = pdf_parser(
filename = filename, filename=filename,
binary = binary, binary=binary,
from_page = from_page, from_page=from_page,
to_page = to_page, to_page=to_page,
lang = lang, lang=lang,
callback = callback, callback=callback,
pdf_cls = Pdf, pdf_cls=Pdf,
layout_recognizer = layout_recognizer, layout_recognizer=layout_recognizer,
mineru_llm_name=parser_model_name, mineru_llm_name=parser_model_name,
parse_method = "manual", parse_method="manual",
**kwargs **kwargs
) )
@ -237,10 +239,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if isinstance(poss, str): if isinstance(poss, str):
poss = pdf_parser.extract_positions(poss) poss = pdf_parser.extract_positions(poss)
if poss: if poss:
first = poss[0] # tuple: ([pn], x1, x2, y1, y2) first = poss[0] # tuple: ([pn], x1, x2, y1, y2)
pn = first[0] pn = first[0]
if isinstance(pn, list) and pn: if isinstance(pn, list) and pn:
pn = pn[0] # [pn] -> pn pn = pn[0] # [pn] -> pn
poss[0] = (pn, *first[1:]) poss[0] = (pn, *first[1:])
return (txt, layoutno, poss) return (txt, layoutno, poss)
@ -289,7 +291,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if not rows: if not rows:
continue continue
sections.append((rows if isinstance(rows, str) else rows[0], -1, sections.append((rows if isinstance(rows, str) else rows[0], -1,
[(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss])) [(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
def tag(pn, left, right, top, bottom): def tag(pn, left, right, top, bottom):
if pn + left + right + top + bottom == 0: if pn + left + right + top + bottom == 0:
@ -312,7 +314,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
tk_cnt = num_tokens_from_string(txt) tk_cnt = num_tokens_from_string(txt)
if sec_id > -1: if sec_id > -1:
last_sid = sec_id last_sid = sec_id
tbls = vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs) tbls = vision_figure_parser_pdf_wrapper(tbls=tbls, callback=callback, **kwargs)
res = tokenize_table(tbls, doc, eng) res = tokenize_table(tbls, doc, eng)
res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser)) res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
table_ctx = max(0, int(parser_config.get("table_context_size", 0) or 0)) table_ctx = max(0, int(parser_config.get("table_context_size", 0) or 0))
@ -325,7 +327,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
docx_parser = Docx() docx_parser = Docx()
ti_list, tbls = docx_parser(filename, binary, ti_list, tbls = docx_parser(filename, binary,
from_page=0, to_page=10000, callback=callback) from_page=0, to_page=10000, callback=callback)
tbls = vision_figure_parser_docx_wrapper(sections=ti_list,tbls=tbls,callback=callback,**kwargs) tbls = vision_figure_parser_docx_wrapper(sections=ti_list, tbls=tbls, callback=callback, **kwargs)
res = tokenize_table(tbls, doc, eng) res = tokenize_table(tbls, doc, eng)
for text, image in ti_list: for text, image in ti_list:
d = copy.deepcopy(doc) d = copy.deepcopy(doc)

View File

@ -31,16 +31,20 @@ from common.token_utils import num_tokens_from_string
from common.constants import LLMType from common.constants import LLMType
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from rag.utils.file_utils import extract_embed_file, extract_links_from_pdf, extract_links_from_docx, extract_html from rag.utils.file_utils import extract_embed_file, extract_links_from_pdf, extract_links_from_docx, extract_html
-from deepdoc.parser import DocxParser, ExcelParser, HtmlParser, JsonParser, MarkdownElementExtractor, MarkdownParser, PdfParser, TxtParser
+from deepdoc.parser import DocxParser, ExcelParser, HtmlParser, JsonParser, MarkdownElementExtractor, MarkdownParser, \
+    PdfParser, TxtParser
-from deepdoc.parser.figure_parser import VisionFigureParser,vision_figure_parser_docx_wrapper,vision_figure_parser_pdf_wrapper
+from deepdoc.parser.figure_parser import VisionFigureParser, vision_figure_parser_docx_wrapper, \
+    vision_figure_parser_pdf_wrapper
from deepdoc.parser.pdf_parser import PlainParser, VisionParser from deepdoc.parser.pdf_parser import PlainParser, VisionParser
from deepdoc.parser.docling_parser import DoclingParser from deepdoc.parser.docling_parser import DoclingParser
from deepdoc.parser.tcadp_parser import TCADPParser from deepdoc.parser.tcadp_parser import TCADPParser
from common.parser_config_utils import normalize_layout_recognizer from common.parser_config_utils import normalize_layout_recognizer
-from rag.nlp import concat_img, find_codec, naive_merge, naive_merge_with_images, naive_merge_docx, rag_tokenizer, tokenize_chunks, tokenize_chunks_with_images, tokenize_table, attach_media_context
+from rag.nlp import concat_img, find_codec, naive_merge, naive_merge_with_images, naive_merge_docx, rag_tokenizer, \
+    tokenize_chunks, tokenize_chunks_with_images, tokenize_table, attach_media_context
-def by_deepdoc(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls = None, **kwargs):
+def by_deepdoc(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls=None,
+               **kwargs):
callback = callback callback = callback
binary = binary binary = binary
pdf_parser = pdf_cls() if pdf_cls else Pdf() pdf_parser = pdf_cls() if pdf_cls else Pdf()
@ -58,17 +62,17 @@ def by_deepdoc(filename, binary=None, from_page=0, to_page=100000, lang="Chinese
def by_mineru( def by_mineru(
filename, filename,
binary=None, binary=None,
from_page=0, from_page=0,
to_page=100000, to_page=100000,
lang="Chinese", lang="Chinese",
callback=None, callback=None,
pdf_cls=None, pdf_cls=None,
parse_method: str = "raw", parse_method: str = "raw",
mineru_llm_name: str | None = None, mineru_llm_name: str | None = None,
tenant_id: str | None = None, tenant_id: str | None = None,
**kwargs, **kwargs,
): ):
pdf_parser = None pdf_parser = None
if tenant_id: if tenant_id:
@ -106,7 +110,8 @@ def by_mineru(
return None, None, None return None, None, None
-def by_docling(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls = None, **kwargs):
+def by_docling(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls=None,
+               **kwargs):
pdf_parser = DoclingParser() pdf_parser = DoclingParser()
parse_method = kwargs.get("parse_method", "raw") parse_method = kwargs.get("parse_method", "raw")
@ -125,7 +130,7 @@ def by_docling(filename, binary=None, from_page=0, to_page=100000, lang="Chinese
return sections, tables, pdf_parser return sections, tables, pdf_parser
def by_tcadp(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls = None, **kwargs): def by_tcadp(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls=None, **kwargs):
tcadp_parser = TCADPParser() tcadp_parser = TCADPParser()
if not tcadp_parser.check_installation(): if not tcadp_parser.check_installation():
@ -168,10 +173,10 @@ def by_plaintext(filename, binary=None, from_page=0, to_page=100000, callback=No
PARSERS = { PARSERS = {
"deepdoc": by_deepdoc, "deepdoc": by_deepdoc,
"mineru": by_mineru, "mineru": by_mineru,
"docling": by_docling, "docling": by_docling,
"tcadp": by_tcadp, "tcadp": by_tcadp,
"plaintext": by_plaintext, # default "plaintext": by_plaintext, # default
} }
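For readers following the refactor, `PARSERS` above is a name-to-callable registry with `plaintext` as the fallback. A rough sketch of dispatching through such a registry; the entries and the lookup rule are stand-ins, not taken from the code:

```python
# Illustrative dispatch through a PARSERS-style registry (real entries are
# by_deepdoc, by_mineru, by_docling, by_tcadp and by_plaintext).
def by_plaintext_stub(filename, binary=None, **kwargs):
    return [f"plain sections of {filename}"], [], None

PARSERS_DEMO = {
    "deepdoc": by_plaintext_stub,    # stand-in for by_deepdoc
    "plaintext": by_plaintext_stub,  # default
}

def pick_parser(layout_recognizer):
    key = (layout_recognizer or "plaintext").lower()   # lookup rule assumed for the demo
    return PARSERS_DEMO.get(key, PARSERS_DEMO["plaintext"])

sections, tables, pdf_parser = pick_parser("DeepDOC")("a.pdf")
```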
@ -205,8 +210,8 @@ class Docx(DocxParser):
except UnicodeDecodeError: except UnicodeDecodeError:
logging.info("The recognized image stream appears to be corrupted. Skipping image.") logging.info("The recognized image stream appears to be corrupted. Skipping image.")
continue continue
-                except Exception:
-                    logging.info("The recognized image stream appears to be corrupted. Skipping image.")
+                except Exception as e:
+                    logging.warning(f"The recognized image stream appears to be corrupted. Skipping image, exception: {e}")
                     continue
try: try:
image = Image.open(BytesIO(image_blob)).convert('RGB') image = Image.open(BytesIO(image_blob)).convert('RGB')
@ -214,7 +219,8 @@ class Docx(DocxParser):
res_img = image res_img = image
else: else:
res_img = concat_img(res_img, image) res_img = concat_img(res_img, image)
-            except Exception:
+            except Exception as e:
+                logging.warning(f"Fail to open or concat images, exception: {e}")
                 continue
return res_img return res_img
@ -264,7 +270,7 @@ class Docx(DocxParser):
# Find the nearest heading paragraph in reverse order # Find the nearest heading paragraph in reverse order
nearest_title = None nearest_title = None
for i in range(len(blocks)-1, -1, -1): for i in range(len(blocks) - 1, -1, -1):
block_type, pos, block = blocks[i] block_type, pos, block = blocks[i]
if pos >= target_table_pos: # Skip blocks after the table if pos >= target_table_pos: # Skip blocks after the table
continue continue
@ -293,7 +299,7 @@ class Docx(DocxParser):
# Find all parent headings, allowing cross-level search # Find all parent headings, allowing cross-level search
while current_level > 1: while current_level > 1:
found = False found = False
for i in range(len(blocks)-1, -1, -1): for i in range(len(blocks) - 1, -1, -1):
block_type, pos, block = blocks[i] block_type, pos, block = blocks[i]
if pos >= target_table_pos: # Skip blocks after the table if pos >= target_table_pos: # Skip blocks after the table
continue continue
@ -426,7 +432,8 @@ class Docx(DocxParser):
try: try:
if inline_images: if inline_images:
result = mammoth.convert_to_html(docx_file, convert_image=mammoth.images.img_element(_convert_image_to_base64)) result = mammoth.convert_to_html(docx_file,
convert_image=mammoth.images.img_element(_convert_image_to_base64))
else: else:
result = mammoth.convert_to_html(docx_file) result = mammoth.convert_to_html(docx_file)
@ -547,7 +554,8 @@ class Markdown(MarkdownParser):
if (src, line_no) not in seen: if (src, line_no) not in seen:
urls.append({"url": src, "line": line_no}) urls.append({"url": src, "line": line_no})
seen.add((src, line_no)) seen.add((src, line_no))
-        except Exception:
+        except Exception as e:
+            logging.error("Failed to extract image urls: {}".format(e))
             pass
return urls return urls
@ -621,6 +629,7 @@ class Markdown(MarkdownParser):
return sections, tbls, section_images return sections, tbls, section_images
return sections, tbls return sections, tbls
def load_from_xml_v2(baseURI, rels_item_xml): def load_from_xml_v2(baseURI, rels_item_xml):
""" """
Return |_SerializedRelationships| instance loaded with the Return |_SerializedRelationships| instance loaded with the
@ -636,6 +645,7 @@ def load_from_xml_v2(baseURI, rels_item_xml):
srels._srels.append(_SerializedRelationship(baseURI, rel_elm)) srels._srels.append(_SerializedRelationship(baseURI, rel_elm))
return srels return srels
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs): def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
""" """
Supported file formats are docx, pdf, excel, txt. Supported file formats are docx, pdf, excel, txt.
@ -651,7 +661,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
"parser_config", { "parser_config", {
"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC", "analyze_hyperlink": True}) "chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC", "analyze_hyperlink": True})
child_deli = (parser_config.get("children_delimiter") or "").encode('utf-8').decode('unicode_escape').encode('latin1').decode('utf-8') child_deli = (parser_config.get("children_delimiter") or "").encode('utf-8').decode('unicode_escape').encode(
'latin1').decode('utf-8')
cust_child_deli = re.findall(r"`([^`]+)`", child_deli) cust_child_deli = re.findall(r"`([^`]+)`", child_deli)
child_deli = "|".join(re.sub(r"`([^`]+)`", "", child_deli)) child_deli = "|".join(re.sub(r"`([^`]+)`", "", child_deli))
if cust_child_deli: if cust_child_deli:
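This hunk only re-wraps the first line, but the delimiter handling is dense enough to deserve a worked example. The sketch below runs the same decode-and-extract steps on a made-up `children_delimiter` value; the real config value and whatever follows the `if cust_child_deli:` check are not shown in the diff and are not guessed here:

```python
import re

# Worked example of the children_delimiter handling above, on a fabricated value.
raw = r"\n`##`;"                       # hypothetical parser_config["children_delimiter"]
child_deli = raw.encode('utf-8').decode('unicode_escape').encode('latin1').decode('utf-8')
cust_child_deli = re.findall(r"`([^`]+)`", child_deli)       # ['##'] - backtick-quoted custom delimiters
child_deli = "|".join(re.sub(r"`([^`]+)`", "", child_deli))  # '\n|;' - remaining chars joined into an alternation
print(cust_child_deli, repr(child_deli))
```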
@ -685,11 +696,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
# Recursively chunk each embedded file and collect results # Recursively chunk each embedded file and collect results
for embed_filename, embed_bytes in embeds: for embed_filename, embed_bytes in embeds:
try: try:
-                sub_res = chunk(embed_filename, binary=embed_bytes, lang=lang, callback=callback, is_root=False, **kwargs) or []
+                sub_res = chunk(embed_filename, binary=embed_bytes, lang=lang, callback=callback, is_root=False,
+                                **kwargs) or []
                 embed_res.extend(sub_res)
             except Exception as e:
+                error_msg = f"Failed to chunk embed {embed_filename}: {e}"
+                logging.error(error_msg)
                 if callback:
-                    callback(0.05, f"Failed to chunk embed {embed_filename}: {e}")
+                    callback(0.05, error_msg)
                 continue
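The pattern here is: recursively chunk each embedded file, log the failure, surface it through the progress callback, and keep going. A self-contained sketch of the same pattern with stand-in names (`chunk_fn`, `report` are not the project's identifiers):

```python
import logging

# Recover-and-continue chunking of embedded files, in the spirit of the loop above.
def chunk_embedded(embeds, chunk_fn, report=None):
    results = []
    for name, payload in embeds:
        try:
            results.extend(chunk_fn(name, binary=payload) or [])
        except Exception as e:
            msg = f"Failed to chunk embed {name}: {e}"
            logging.error(msg)
            if report:
                report(0.05, msg)   # progress callback still hears about the failure
            continue
    return results
```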
if re.search(r"\.docx$", filename, re.IGNORECASE): if re.search(r"\.docx$", filename, re.IGNORECASE):
@ -704,7 +718,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
sub_url_res = chunk(url, html_bytes, callback=callback, lang=lang, is_root=False, **kwargs) sub_url_res = chunk(url, html_bytes, callback=callback, lang=lang, is_root=False, **kwargs)
except Exception as e: except Exception as e:
logging.info(f"Failed to chunk url in registered file type {url}: {e}") logging.info(f"Failed to chunk url in registered file type {url}: {e}")
sub_url_res = chunk(f"{index}.html", html_bytes, callback=callback, lang=lang, is_root=False, **kwargs) sub_url_res = chunk(f"{index}.html", html_bytes, callback=callback, lang=lang, is_root=False,
**kwargs)
url_res.extend(sub_url_res) url_res.extend(sub_url_res)
# fix "There is no item named 'word/NULL' in the archive", referring to https://github.com/python-openxml/python-docx/issues/1105#issuecomment-1298075246 # fix "There is no item named 'word/NULL' in the archive", referring to https://github.com/python-openxml/python-docx/issues/1105#issuecomment-1298075246
@ -747,14 +762,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
callback(0.1, "Start to parse.") callback(0.1, "Start to parse.")
     sections, tables, pdf_parser = parser(
-        filename = filename,
-        binary = binary,
-        from_page = from_page,
-        to_page = to_page,
-        lang = lang,
-        callback = callback,
-        layout_recognizer = layout_recognizer,
-        mineru_llm_name = parser_model_name,
+        filename=filename,
+        binary=binary,
+        from_page=from_page,
+        to_page=to_page,
+        lang=lang,
+        callback=callback,
+        layout_recognizer=layout_recognizer,
+        mineru_llm_name=parser_model_name,
         **kwargs
     )
@ -812,7 +827,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
parser_config.get("delimiter", "\n!?;。;!?")) parser_config.get("delimiter", "\n!?;。;!?"))
callback(0.8, "Finish parsing.") callback(0.8, "Finish parsing.")
elif re.search(r"\.(md|markdown)$", filename, re.IGNORECASE): elif re.search(r"\.(md|markdown|mdx)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.") callback(0.1, "Start to parse.")
markdown_parser = Markdown(int(parser_config.get("chunk_token_num", 128))) markdown_parser = Markdown(int(parser_config.get("chunk_token_num", 128)))
sections, tables, section_images = markdown_parser( sections, tables, section_images = markdown_parser(
@ -828,7 +843,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
try: try:
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT) vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
callback(0.2, "Visual model detected. Attempting to enhance figure extraction...") callback(0.2, "Visual model detected. Attempting to enhance figure extraction...")
except Exception: except Exception as e:
logging.warning(f"Failed to detect figure extraction: {e}")
vision_model = None vision_model = None
if vision_model: if vision_model:
@ -846,9 +862,11 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
else: else:
section_images = [None] * len(sections) section_images = [None] * len(sections)
section_images[idx] = combined_image section_images[idx] = combined_image
-                    markdown_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data= [((combined_image, ["markdown image"]), [(0, 0, 0, 0, 0)])], **kwargs)
+                    markdown_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data=[
+                        ((combined_image, ["markdown image"]), [(0, 0, 0, 0, 0)])], **kwargs)
                     boosted_figures = markdown_vision_parser(callback=callback)
-                    sections[idx] = (section_text + "\n\n" + "\n\n".join([fig[0][1] for fig in boosted_figures]), sections[idx][1])
+                    sections[idx] = (section_text + "\n\n" + "\n\n".join([fig[0][1] for fig in boosted_figures]),
+                                     sections[idx][1])
else: else:
logging.warning("No visual model detected. Skipping figure parsing enhancement.") logging.warning("No visual model detected. Skipping figure parsing enhancement.")
@ -892,8 +910,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
sections = [(_, "") for _ in sections if _] sections = [(_, "") for _ in sections if _]
callback(0.8, "Finish parsing.") callback(0.8, "Finish parsing.")
else: else:
-            callback(0.8, f"tika.parser got empty content from {filename}.")
-            logging.warning(f"tika.parser got empty content from {filename}.")
+            error_msg = f"tika.parser got empty content from {filename}."
+            callback(0.8, error_msg)
+            logging.warning(error_msg)
             return []
else: else:
raise NotImplementedError( raise NotImplementedError(
@ -945,7 +964,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
has_images = merged_images and any(img is not None for img in merged_images) has_images = merged_images and any(img is not None for img in merged_images)
if has_images: if has_images:
res.extend(tokenize_chunks_with_images(chunks, doc, is_english, merged_images, child_delimiters_pattern=child_deli)) res.extend(tokenize_chunks_with_images(chunks, doc, is_english, merged_images,
child_delimiters_pattern=child_deli))
else: else:
res.extend(tokenize_chunks(chunks, doc, is_english, pdf_parser, child_delimiters_pattern=child_deli)) res.extend(tokenize_chunks(chunks, doc, is_english, pdf_parser, child_delimiters_pattern=child_deli))
else: else:
@ -955,10 +975,11 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
if section_images: if section_images:
chunks, images = naive_merge_with_images(sections, section_images, chunks, images = naive_merge_with_images(sections, section_images,
int(parser_config.get( int(parser_config.get(
"chunk_token_num", 128)), parser_config.get( "chunk_token_num", 128)), parser_config.get(
"delimiter", "\n!?。;!?")) "delimiter", "\n!?。;!?"))
res.extend(tokenize_chunks_with_images(chunks, doc, is_english, images, child_delimiters_pattern=child_deli)) res.extend(
tokenize_chunks_with_images(chunks, doc, is_english, images, child_delimiters_pattern=child_deli))
else: else:
chunks = naive_merge( chunks = naive_merge(
sections, int(parser_config.get( sections, int(parser_config.get(
@ -993,7 +1014,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
def dummy(prog=None, msg=""): def dummy(prog=None, msg=""):
pass pass
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy) chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)

View File

@ -26,6 +26,7 @@ from deepdoc.parser.figure_parser import vision_figure_parser_docx_wrapper
from rag.app.naive import by_plaintext, PARSERS from rag.app.naive import by_plaintext, PARSERS
from common.parser_config_utils import normalize_layout_recognizer from common.parser_config_utils import normalize_layout_recognizer
class Pdf(PdfParser): class Pdf(PdfParser):
def __call__(self, filename, binary=None, from_page=0, def __call__(self, filename, binary=None, from_page=0,
to_page=100000, zoomin=3, callback=None): to_page=100000, zoomin=3, callback=None):
@ -95,14 +96,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
callback(0.1, "Start to parse.") callback(0.1, "Start to parse.")
sections, tbls, pdf_parser = parser( sections, tbls, pdf_parser = parser(
filename = filename, filename=filename,
binary = binary, binary=binary,
from_page = from_page, from_page=from_page,
to_page = to_page, to_page=to_page,
lang = lang, lang=lang,
callback = callback, callback=callback,
pdf_cls = Pdf, pdf_cls=Pdf,
layout_recognizer = layout_recognizer, layout_recognizer=layout_recognizer,
mineru_llm_name=parser_model_name, mineru_llm_name=parser_model_name,
**kwargs **kwargs
) )
@ -127,7 +128,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
excel_parser = ExcelParser() excel_parser = ExcelParser()
sections = excel_parser.html(binary, 1000000000) sections = excel_parser.html(binary, 1000000000)
elif re.search(r"\.(txt|md|markdown)$", filename, re.IGNORECASE): elif re.search(r"\.(txt|md|markdown|mdx)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.") callback(0.1, "Start to parse.")
txt = get_text(filename, binary) txt = get_text(filename, binary)
sections = txt.split("\n") sections = txt.split("\n")
@ -172,7 +173,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
def dummy(prog=None, msg=""): def dummy(prog=None, msg=""):
pass pass
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy) chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)

View File

@ -20,7 +20,8 @@ import re
from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper
from common.constants import ParserType from common.constants import ParserType
from rag.nlp import rag_tokenizer, tokenize, tokenize_table, add_positions, bullets_category, title_frequency, tokenize_chunks, attach_media_context from rag.nlp import rag_tokenizer, tokenize, tokenize_table, add_positions, bullets_category, title_frequency, \
tokenize_chunks, attach_media_context
from deepdoc.parser import PdfParser from deepdoc.parser import PdfParser
import numpy as np import numpy as np
from rag.app.naive import by_plaintext, PARSERS from rag.app.naive import by_plaintext, PARSERS
@ -66,7 +67,7 @@ class Pdf(PdfParser):
# clean mess # clean mess
if column_width < self.page_images[0].size[0] / zoomin / 2: if column_width < self.page_images[0].size[0] / zoomin / 2:
logging.debug("two_column................... {} {}".format(column_width, logging.debug("two_column................... {} {}".format(column_width,
self.page_images[0].size[0] / zoomin / 2)) self.page_images[0].size[0] / zoomin / 2))
self.boxes = self.sort_X_by_page(self.boxes, column_width / 2) self.boxes = self.sort_X_by_page(self.boxes, column_width / 2)
for b in self.boxes: for b in self.boxes:
b["text"] = re.sub(r"([\t  ]|\u3000){2,}", " ", b["text"].strip()) b["text"] = re.sub(r"([\t  ]|\u3000){2,}", " ", b["text"].strip())
@ -89,7 +90,7 @@ class Pdf(PdfParser):
title = "" title = ""
authors = [] authors = []
i = 0 i = 0
while i < min(32, len(self.boxes)-1): while i < min(32, len(self.boxes) - 1):
b = self.boxes[i] b = self.boxes[i]
i += 1 i += 1
if b.get("layoutno", "").find("title") >= 0: if b.get("layoutno", "").find("title") >= 0:
@ -190,8 +191,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
"tables": tables "tables": tables
} }
tbls=paper["tables"] tbls = paper["tables"]
tbls=vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs) tbls = vision_figure_parser_pdf_wrapper(tbls=tbls, callback=callback, **kwargs)
paper["tables"] = tbls paper["tables"] = tbls
else: else:
raise NotImplementedError("file type not supported yet(pdf supported)") raise NotImplementedError("file type not supported yet(pdf supported)")
@ -329,6 +330,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
def dummy(prog=None, msg=""): def dummy(prog=None, msg=""):
pass pass
chunk(sys.argv[1], callback=dummy) chunk(sys.argv[1], callback=dummy)

View File

@ -51,7 +51,8 @@ def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs):
} }
) )
cv_mdl = LLMBundle(tenant_id, llm_type=LLMType.IMAGE2TEXT, lang=lang) cv_mdl = LLMBundle(tenant_id, llm_type=LLMType.IMAGE2TEXT, lang=lang)
ans = asyncio.run(cv_mdl.async_chat(system="", history=[], gen_conf={}, video_bytes=binary, filename=filename)) ans = asyncio.run(
cv_mdl.async_chat(system="", history=[], gen_conf={}, video_bytes=binary, filename=filename))
callback(0.8, "CV LLM respond: %s ..." % ans[:32]) callback(0.8, "CV LLM respond: %s ..." % ans[:32])
ans += "\n" + ans ans += "\n" + ans
tokenize(doc, ans, eng) tokenize(doc, ans, eng)

View File

@ -249,7 +249,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
def dummy(a, b): def dummy(a, b):
pass pass
chunk(sys.argv[1], callback=dummy) chunk(sys.argv[1], callback=dummy)

View File

@ -102,9 +102,9 @@ class Pdf(PdfParser):
self._text_merge() self._text_merge()
callback(0.67, "Text merged ({:.2f}s)".format(timer() - start)) callback(0.67, "Text merged ({:.2f}s)".format(timer() - start))
tbls = self._extract_table_figure(True, zoomin, True, True) tbls = self._extract_table_figure(True, zoomin, True, True)
#self._naive_vertical_merge() # self._naive_vertical_merge()
# self._concat_downward() # self._concat_downward()
#self._filter_forpages() # self._filter_forpages()
logging.debug("layouts: {}".format(timer() - start)) logging.debug("layouts: {}".format(timer() - start))
sections = [b["text"] for b in self.boxes] sections = [b["text"] for b in self.boxes]
bull_x0_list = [] bull_x0_list = []
@ -114,12 +114,14 @@ class Pdf(PdfParser):
qai_list = [] qai_list = []
last_q, last_a, last_tag = '', '', '' last_q, last_a, last_tag = '', '', ''
last_index = -1 last_index = -1
last_box = {'text':''} last_box = {'text': ''}
last_bull = None last_bull = None
def sort_key(element): def sort_key(element):
tbls_pn = element[1][0][0] tbls_pn = element[1][0][0]
tbls_top = element[1][0][3] tbls_top = element[1][0][3]
return tbls_pn, tbls_top return tbls_pn, tbls_top
tbls.sort(key=sort_key) tbls.sort(key=sort_key)
tbl_index = 0 tbl_index = 0
last_pn, last_bottom = 0, 0 last_pn, last_bottom = 0, 0
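`sort_key` orders extracted tables and figures by page and vertical position so they can be interleaved with the Q&A text in reading order. A runnable illustration with made-up data, assuming each element is shaped like `(payload, [(page, left, right, top, bottom), ...])` as indexed above:

```python
# Sort table/figure crops by (page, top coordinate), mirroring sort_key above.
def sort_key(element):
    tbls_pn = element[1][0][0]
    tbls_top = element[1][0][3]
    return tbls_pn, tbls_top

tbls = [
    ("tableB", [(2, 0, 100, 50, 80)]),
    ("tableA", [(1, 0, 100, 300, 350)]),
    ("tableC", [(1, 0, 100, 20, 60)]),
]
tbls.sort(key=sort_key)
print([t[0] for t in tbls])   # ['tableC', 'tableA', 'tableB']
```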
@ -133,28 +135,32 @@ class Pdf(PdfParser):
tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index) tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index)
if not has_bull: # No question bullet if not has_bull: # No question bullet
if not last_q: if not last_q:
if tbl_pn < line_pn or (tbl_pn == line_pn and tbl_top <= line_top): # image passed if tbl_pn < line_pn or (tbl_pn == line_pn and tbl_top <= line_top): # image passed
tbl_index += 1 tbl_index += 1
continue continue
else: else:
sum_tag = line_tag sum_tag = line_tag
sum_section = section sum_section = section
-            while ((tbl_pn == last_pn and tbl_top>= last_bottom) or (tbl_pn > last_pn)) \
-                and ((tbl_pn == line_pn and tbl_top <= line_top) or (tbl_pn < line_pn)):  # add image at the middle of current answer
+            while ((tbl_pn == last_pn and tbl_top >= last_bottom) or (tbl_pn > last_pn)) \
+                    and ((tbl_pn == line_pn and tbl_top <= line_top) or (
+                    tbl_pn < line_pn)):  # add image at the middle of current answer
                 sum_tag = f'{tbl_tag}{sum_tag}'
                 sum_section = f'{tbl_text}{sum_section}'
                 tbl_index += 1
-                tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index)
+                tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls,
+                                                                                                         tbl_index)
             last_a = f'{last_a}{sum_section}'
             last_tag = f'{last_tag}{sum_tag}'
         else:
             if last_q:
-                while ((tbl_pn == last_pn and tbl_top>= last_bottom) or (tbl_pn > last_pn)) \
-                    and ((tbl_pn == line_pn and tbl_top <= line_top) or (tbl_pn < line_pn)):  # add image at the end of last answer
+                while ((tbl_pn == last_pn and tbl_top >= last_bottom) or (tbl_pn > last_pn)) \
+                        and ((tbl_pn == line_pn and tbl_top <= line_top) or (
+                        tbl_pn < line_pn)):  # add image at the end of last answer
                     last_tag = f'{last_tag}{tbl_tag}'
                     last_a = f'{last_a}{tbl_text}'
                     tbl_index += 1
-                    tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index)
+                    tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls,
+                                                                                                             tbl_index)
image, poss = self.crop(last_tag, need_position=True) image, poss = self.crop(last_tag, need_position=True)
qai_list.append((last_q, last_a, image, poss)) qai_list.append((last_q, last_a, image, poss))
last_q, last_a, last_tag = '', '', '' last_q, last_a, last_tag = '', '', ''
@ -171,7 +177,7 @@ class Pdf(PdfParser):
def get_tbls_info(self, tbls, tbl_index): def get_tbls_info(self, tbls, tbl_index):
if tbl_index >= len(tbls): if tbl_index >= len(tbls):
return 1, 0, 0, 0, 0, '@@0\t0\t0\t0\t0##', '' return 1, 0, 0, 0, 0, '@@0\t0\t0\t0\t0##', ''
tbl_pn = tbls[tbl_index][1][0][0]+1 tbl_pn = tbls[tbl_index][1][0][0] + 1
tbl_left = tbls[tbl_index][1][0][1] tbl_left = tbls[tbl_index][1][0][1]
tbl_right = tbls[tbl_index][1][0][2] tbl_right = tbls[tbl_index][1][0][2]
tbl_top = tbls[tbl_index][1][0][3] tbl_top = tbls[tbl_index][1][0][3]
@ -210,11 +216,11 @@ class Docx(DocxParser):
question_level, p_text = 0, '' question_level, p_text = 0, ''
if from_page <= pn < to_page and p.text.strip(): if from_page <= pn < to_page and p.text.strip():
question_level, p_text = docx_question_level(p) question_level, p_text = docx_question_level(p)
if not question_level or question_level > 6: # not a question if not question_level or question_level > 6: # not a question
last_answer = f'{last_answer}\n{p_text}' last_answer = f'{last_answer}\n{p_text}'
current_image = self.get_picture(self.doc, p) current_image = self.get_picture(self.doc, p)
last_image = concat_img(last_image, current_image) last_image = concat_img(last_image, current_image)
else: # is a question else: # is a question
if last_answer or last_image: if last_answer or last_image:
sum_question = '\n'.join(question_stack) sum_question = '\n'.join(question_stack)
if sum_question: if sum_question:
@ -240,14 +246,14 @@ class Docx(DocxParser):
tbls = [] tbls = []
for tb in self.doc.tables: for tb in self.doc.tables:
html= "<table>" html = "<table>"
for r in tb.rows: for r in tb.rows:
html += "<tr>" html += "<tr>"
i = 0 i = 0
while i < len(r.cells): while i < len(r.cells):
span = 1 span = 1
c = r.cells[i] c = r.cells[i]
for j in range(i+1, len(r.cells)): for j in range(i + 1, len(r.cells)):
if c.text == r.cells[j].text: if c.text == r.cells[j].text:
span += 1 span += 1
i = j i = j
@ -356,7 +362,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
if question: if question:
answer += "\n" + lines[i] answer += "\n" + lines[i]
else: else:
fails.append(str(i+1)) fails.append(str(i + 1))
elif len(arr) == 2: elif len(arr) == 2:
if question and answer: if question and answer:
res.append(beAdoc(deepcopy(doc), question, answer, eng, i)) res.append(beAdoc(deepcopy(doc), question, answer, eng, i))
@ -415,7 +421,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
res.append(beAdocPdf(deepcopy(doc), q, a, eng, image, poss)) res.append(beAdocPdf(deepcopy(doc), q, a, eng, image, poss))
return res return res
elif re.search(r"\.(md|markdown)$", filename, re.IGNORECASE): elif re.search(r"\.(md|markdown|mdx)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.") callback(0.1, "Start to parse.")
txt = get_text(filename, binary) txt = get_text(filename, binary)
lines = txt.split("\n") lines = txt.split("\n")
@ -429,13 +435,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
if not code_block: if not code_block:
question_level, question = mdQuestionLevel(line) question_level, question = mdQuestionLevel(line)
if not question_level or question_level > 6: # not a question if not question_level or question_level > 6: # not a question
last_answer = f'{last_answer}\n{line}' last_answer = f'{last_answer}\n{line}'
else: # is a question else: # is a question
if last_answer.strip(): if last_answer.strip():
sum_question = '\n'.join(question_stack) sum_question = '\n'.join(question_stack)
if sum_question: if sum_question:
res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index)) res.append(beAdoc(deepcopy(doc), sum_question,
markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
last_answer = '' last_answer = ''
i = question_level i = question_level
@ -447,13 +454,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
if last_answer.strip(): if last_answer.strip():
sum_question = '\n'.join(question_stack) sum_question = '\n'.join(question_stack)
if sum_question: if sum_question:
res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index)) res.append(beAdoc(deepcopy(doc), sum_question,
markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
return res return res
elif re.search(r"\.docx$", filename, re.IGNORECASE): elif re.search(r"\.docx$", filename, re.IGNORECASE):
docx_parser = Docx() docx_parser = Docx()
qai_list, tbls = docx_parser(filename, binary, qai_list, tbls = docx_parser(filename, binary,
from_page=0, to_page=10000, callback=callback) from_page=0, to_page=10000, callback=callback)
res = tokenize_table(tbls, doc, eng) res = tokenize_table(tbls, doc, eng)
for i, (q, a, image) in enumerate(qai_list): for i, (q, a, image) in enumerate(qai_list):
res.append(beAdocDocx(deepcopy(doc), q, a, eng, image, i)) res.append(beAdocDocx(deepcopy(doc), q, a, eng, image, i))
@ -466,6 +474,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
def dummy(prog=None, msg=""): def dummy(prog=None, msg=""):
pass pass
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy) chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)

View File

@ -64,7 +64,8 @@ def remote_call(filename, binary):
del resume[k] del resume[k]
resume = step_one.refactor(pd.DataFrame([{"resume_content": json.dumps(resume), "tob_resume_id": "x", resume = step_one.refactor(pd.DataFrame([{"resume_content": json.dumps(resume), "tob_resume_id": "x",
"updated_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}])) "updated_at": datetime.datetime.now().strftime(
"%Y-%m-%d %H:%M:%S")}]))
resume = step_two.parse(resume) resume = step_two.parse(resume)
return resume return resume
except Exception: except Exception:
@ -171,6 +172,9 @@ def chunk(filename, binary=None, callback=None, **kwargs):
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
def dummy(a, b): def dummy(a, b):
pass pass
chunk(sys.argv[1], callback=dummy) chunk(sys.argv[1], callback=dummy)

View File

@ -42,23 +42,24 @@ class Excel(ExcelParser):
else: else:
wb = Excel._load_excel_to_workbook(BytesIO(binary)) wb = Excel._load_excel_to_workbook(BytesIO(binary))
total = 0 total = 0
-        for sheetname in wb.sheetnames:
-            total += len(list(wb[sheetname].rows))
+        for sheet_name in wb.sheetnames:
+            total += len(list(wb[sheet_name].rows))
         res, fails, done = [], [], 0
         rn = 0
         flow_images = []
         pending_cell_images = []
         tables = []
-        for sheetname in wb.sheetnames:
-            ws = wb[sheetname]
-            images = Excel._extract_images_from_worksheet(ws,sheetname=sheetname)
+        for sheet_name in wb.sheetnames:
+            ws = wb[sheet_name]
+            images = Excel._extract_images_from_worksheet(ws, sheetname=sheet_name)
             if images:
-                image_descriptions = vision_figure_parser_figure_xlsx_wrapper(images=images, callback=callback, **kwargs)
+                image_descriptions = vision_figure_parser_figure_xlsx_wrapper(images=images, callback=callback,
+                                                                              **kwargs)
                 if image_descriptions and len(image_descriptions) == len(images):
                     for i, bf in enumerate(image_descriptions):
                         images[i]["image_description"] = "\n".join(bf[0][1])
                 for img in images:
-                    if (img["span_type"] == "single_cell"and img.get("image_description")):
+                    if img["span_type"] == "single_cell" and img.get("image_description"):
pending_cell_images.append(img) pending_cell_images.append(img)
else: else:
flow_images.append(img) flow_images.append(img)
@ -66,7 +67,7 @@ class Excel(ExcelParser):
try: try:
rows = list(ws.rows) rows = list(ws.rows)
except Exception as e: except Exception as e:
logging.warning(f"Skip sheet '{sheetname}' due to rows access error: {e}") logging.warning(f"Skip sheet '{sheet_name}' due to rows access error: {e}")
continue continue
if not rows: if not rows:
continue continue
@ -113,16 +114,17 @@ class Excel(ExcelParser):
tables.append( tables.append(
( (
( (
img["image"], # Image.Image img["image"], # Image.Image
[img["image_description"]] # description list (must be list) [img["image_description"]] # description list (must be list)
), ),
[ [
(0, 0, 0, 0, 0) # dummy position (0, 0, 0, 0, 0) # dummy position
] ]
) )
) )
-        callback(0.3, ("Extract records: {}~{}".format(from_page + 1, min(to_page, from_page + rn)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
-        return res,tables
+        callback(0.3, ("Extract records: {}~{}".format(from_page + 1, min(to_page, from_page + rn)) + (
+            f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
+        return res, tables
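Beyond the rename and re-wrapping, the flow above pairs each extracted worksheet image with a vision-model description and then routes anchored single-cell images away from free-floating ones. A small sketch of that routing step with fabricated data (field names mirror the dicts used above):

```python
# Described images anchored to a single cell are queued for inline merging;
# everything else is treated as a free-floating figure.
images = [
    {"span_type": "single_cell", "image_description": "sales chart", "image": None},
    {"span_type": "floating", "image_description": "org diagram", "image": None},
]
pending_cell_images, flow_images = [], []
for img in images:
    if img["span_type"] == "single_cell" and img.get("image_description"):
        pending_cell_images.append(img)
    else:
        flow_images.append(img)
print(len(pending_cell_images), len(flow_images))   # 1 1
```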
def _parse_headers(self, ws, rows): def _parse_headers(self, ws, rows):
if len(rows) == 0: if len(rows) == 0:
@ -301,7 +303,8 @@ class Excel(ExcelParser):
def trans_datatime(s): def trans_datatime(s):
try: try:
return datetime_parse(s.strip()).strftime("%Y-%m-%d %H:%M:%S") return datetime_parse(s.strip()).strftime("%Y-%m-%d %H:%M:%S")
-    except Exception:
+    except Exception as e:
+        logging.warning(f"Failed to parse date from {s}, error: {e}")
         pass
@ -310,19 +313,21 @@ def trans_bool(s):
return "yes" return "yes"
if re.match(r"(false|no|否|⍻|×)$", str(s).strip(), flags=re.IGNORECASE): if re.match(r"(false|no|否|⍻|×)$", str(s).strip(), flags=re.IGNORECASE):
return "no" return "no"
return None
def column_data_type(arr): def column_data_type(arr):
arr = list(arr) arr = list(arr)
counts = {"int": 0, "float": 0, "text": 0, "datetime": 0, "bool": 0} counts = {"int": 0, "float": 0, "text": 0, "datetime": 0, "bool": 0}
trans = {t: f for f, t in [(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]} trans = {t: f for f, t in
[(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]}
float_flag = False float_flag = False
for a in arr: for a in arr:
if a is None: if a is None:
continue continue
if re.match(r"[+-]?[0-9]+$", str(a).replace("%%", "")) and not str(a).replace("%%", "").startswith("0"): if re.match(r"[+-]?[0-9]+$", str(a).replace("%%", "")) and not str(a).replace("%%", "").startswith("0"):
counts["int"] += 1 counts["int"] += 1
if int(str(a)) > 2**63 - 1: if int(str(a)) > 2 ** 63 - 1:
float_flag = True float_flag = True
break break
elif re.match(r"[+-]?[0-9.]{,19}$", str(a).replace("%%", "")) and not str(a).replace("%%", "").startswith("0"): elif re.match(r"[+-]?[0-9.]{,19}$", str(a).replace("%%", "")) and not str(a).replace("%%", "").startswith("0"):
@ -343,8 +348,9 @@ def column_data_type(arr):
continue continue
try: try:
arr[i] = trans[ty](str(arr[i])) arr[i] = trans[ty](str(arr[i]))
-        except Exception:
+        except Exception as e:
             arr[i] = None
+            logging.warning(f"Column {i}: {e}")
# if ty == "text": # if ty == "text":
# if len(arr) > 128 and uni / len(arr) < 0.1: # if len(arr) > 128 and uni / len(arr) < 0.1:
# ty = "keyword" # ty = "keyword"
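`column_data_type` infers each column's type by counting which converter pattern matches the most cells, then coerces the column and nulls out anything that fails, now with a warning. A toy version with only int/float/text, to show the shape of the idea rather than the exact regexes used above:

```python
import re

# Majority-vote column type inference, simplified.
def infer_column(arr):
    counts = {"int": 0, "float": 0, "text": 0}
    for a in arr:
        if a is None:
            continue
        s = str(a)
        if re.match(r"[+-]?[0-9]+$", s):
            counts["int"] += 1
        elif re.match(r"[+-]?[0-9]*\.[0-9]+$", s):
            counts["float"] += 1
        else:
            counts["text"] += 1
    ty = max(counts, key=counts.get)
    cast = {"int": int, "float": float, "text": str}[ty]
    out = []
    for a in arr:
        try:
            out.append(cast(str(a)) if a is not None else None)
        except Exception:
            out.append(None)   # mirrors the arr[i] = None fallback above
    return out, ty

print(infer_column(["1", "2", "x", None]))   # ([1, 2, None, None], 'int')
```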
@ -370,7 +376,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
if re.search(r"\.xlsx?$", filename, re.IGNORECASE): if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.") callback(0.1, "Start to parse.")
excel_parser = Excel() excel_parser = Excel()
dfs,tbls = excel_parser(filename, binary, from_page=from_page, to_page=to_page, callback=callback, **kwargs) dfs, tbls = excel_parser(filename, binary, from_page=from_page, to_page=to_page, callback=callback, **kwargs)
elif re.search(r"\.txt$", filename, re.IGNORECASE): elif re.search(r"\.txt$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.") callback(0.1, "Start to parse.")
txt = get_text(filename, binary) txt = get_text(filename, binary)
@ -389,7 +395,8 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
continue continue
rows.append(row) rows.append(row)
callback(0.3, ("Extract records: {}~{}".format(from_page, min(len(lines), to_page)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) callback(0.3, ("Extract records: {}~{}".format(from_page, min(len(lines), to_page)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
dfs = [pd.DataFrame(np.array(rows), columns=headers)] dfs = [pd.DataFrame(np.array(rows), columns=headers)]
elif re.search(r"\.csv$", filename, re.IGNORECASE): elif re.search(r"\.csv$", filename, re.IGNORECASE):
@ -406,7 +413,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
fails = [] fails = []
rows = [] rows = []
for i, row in enumerate(all_rows[1 + from_page : 1 + to_page]): for i, row in enumerate(all_rows[1 + from_page: 1 + to_page]):
if len(row) != len(headers): if len(row) != len(headers):
fails.append(str(i + from_page)) fails.append(str(i + from_page))
continue continue
@ -415,7 +422,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
callback( callback(
0.3, 0.3,
(f"Extract records: {from_page}~{from_page + len(rows)}" + (f"Extract records: {from_page}~{from_page + len(rows)}" +
(f"{len(fails)} failure, line: {','.join(fails[:3])}..." if fails else "")) (f"{len(fails)} failure, line: {','.join(fails[:3])}..." if fails else ""))
) )
dfs = [pd.DataFrame(rows, columns=headers)] dfs = [pd.DataFrame(rows, columns=headers)]
@ -445,7 +452,8 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
df[clmns[j]] = cln df[clmns[j]] = cln
if ty == "text": if ty == "text":
txts.extend([str(c) for c in cln if c]) txts.extend([str(c) for c in cln if c])
clmns_map = [(py_clmns[i].lower() + fieds_map[clmn_tys[i]], str(clmns[i]).replace("_", " ")) for i in range(len(clmns))] clmns_map = [(py_clmns[i].lower() + fieds_map[clmn_tys[i]], str(clmns[i]).replace("_", " ")) for i in
range(len(clmns))]
eng = lang.lower() == "english" # is_english(txts) eng = lang.lower() == "english" # is_english(txts)
for ii, row in df.iterrows(): for ii, row in df.iterrows():
@ -477,7 +485,9 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
def dummy(prog=None, msg=""): def dummy(prog=None, msg=""):
pass pass
chunk(sys.argv[1], callback=dummy) chunk(sys.argv[1], callback=dummy)

View File

@ -141,17 +141,20 @@ def label_question(question, kbs):
if not tag_kbs: if not tag_kbs:
return tags return tags
tags = settings.retriever.tag_query(question, tags = settings.retriever.tag_query(question,
list(set([kb.tenant_id for kb in tag_kbs])), list(set([kb.tenant_id for kb in tag_kbs])),
tag_kb_ids, tag_kb_ids,
all_tags, all_tags,
kb.parser_config.get("topn_tags", 3) kb.parser_config.get("topn_tags", 3)
) )
return tags return tags
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
def dummy(prog=None, msg=""): def dummy(prog=None, msg=""):
pass pass
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy) chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)

View File

@ -263,7 +263,7 @@ class SparkTTS(Base):
raise Exception(error) raise Exception(error)
def on_close(self, ws, close_status_code, close_msg): def on_close(self, ws, close_status_code, close_msg):
-        self.audio_queue.put(None)  # 放入 None 作为结束标志
+        self.audio_queue.put(None)  # None is terminator
def on_open(self, ws): def on_open(self, ws):
def run(*args): def run(*args):

View File

@ -273,7 +273,7 @@ def tokenize(d, txt, eng):
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
def split_with_pattern(d, pattern:str, content:str, eng) -> list: def split_with_pattern(d, pattern: str, content: str, eng) -> list:
docs = [] docs = []
txts = [txt for txt in re.split(r"(%s)" % pattern, content, flags=re.DOTALL)] txts = [txt for txt in re.split(r"(%s)" % pattern, content, flags=re.DOTALL)]
for j in range(0, len(txts), 2): for j in range(0, len(txts), 2):
@ -281,7 +281,7 @@ def split_with_pattern(d, pattern:str, content:str, eng) -> list:
if not txt: if not txt:
continue continue
if j + 1 < len(txts): if j + 1 < len(txts):
txt += txts[j+1] txt += txts[j + 1]
dd = copy.deepcopy(d) dd = copy.deepcopy(d)
tokenize(dd, txt, eng) tokenize(dd, txt, eng)
docs.append(dd) docs.append(dd)
@ -304,7 +304,7 @@ def tokenize_chunks(chunks, doc, eng, pdf_parser=None, child_delimiters_pattern=
except NotImplementedError: except NotImplementedError:
pass pass
else: else:
add_positions(d, [[ii]*5]) add_positions(d, [[ii] * 5])
if child_delimiters_pattern: if child_delimiters_pattern:
d["mom_with_weight"] = ck d["mom_with_weight"] = ck
@ -325,7 +325,7 @@ def tokenize_chunks_with_images(chunks, doc, eng, images, child_delimiters_patte
logging.debug("-- {}".format(ck)) logging.debug("-- {}".format(ck))
d = copy.deepcopy(doc) d = copy.deepcopy(doc)
d["image"] = image d["image"] = image
add_positions(d, [[ii]*5]) add_positions(d, [[ii] * 5])
if child_delimiters_pattern: if child_delimiters_pattern:
d["mom_with_weight"] = ck d["mom_with_weight"] = ck
res.extend(split_with_pattern(d, child_delimiters_pattern, ck, eng)) res.extend(split_with_pattern(d, child_delimiters_pattern, ck, eng))
@ -658,7 +658,8 @@ def attach_media_context(chunks, table_context_size=0, image_context_size=0):
if "content_ltks" in ck: if "content_ltks" in ck:
ck["content_ltks"] = rag_tokenizer.tokenize(combined) ck["content_ltks"] = rag_tokenizer.tokenize(combined)
if "content_sm_ltks" in ck: if "content_sm_ltks" in ck:
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck.get("content_ltks", rag_tokenizer.tokenize(combined))) ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(
ck.get("content_ltks", rag_tokenizer.tokenize(combined)))
if positioned_indices: if positioned_indices:
chunks[:] = [chunks[i] for i in ordered_indices] chunks[:] = [chunks[i] for i in ordered_indices]
@ -764,8 +765,8 @@ def not_title(txt):
return True return True
return re.search(r"[,;,。;!!]", txt) return re.search(r"[,;,。;!!]", txt)
def tree_merge(bull, sections, depth):
def tree_merge(bull, sections, depth):
if not sections or bull < 0: if not sections or bull < 0:
return sections return sections
if isinstance(sections[0], type("")): if isinstance(sections[0], type("")):
@ -777,16 +778,17 @@ def tree_merge(bull, sections, depth):
def get_level(bull, section): def get_level(bull, section):
text, layout = section text, layout = section
text = re.sub(r"\u3000", " ", text).strip() text = re.sub(r"\u3000", " ", text).strip()
for i, title in enumerate(BULLET_PATTERN[bull]): for i, title in enumerate(BULLET_PATTERN[bull]):
if re.match(title, text.strip()): if re.match(title, text.strip()):
return i+1, text return i + 1, text
else: else:
if re.search(r"(title|head)", layout) and not not_title(text): if re.search(r"(title|head)", layout) and not not_title(text):
return len(BULLET_PATTERN[bull])+1, text return len(BULLET_PATTERN[bull]) + 1, text
else: else:
return len(BULLET_PATTERN[bull])+2, text return len(BULLET_PATTERN[bull]) + 2, text
level_set = set() level_set = set()
lines = [] lines = []
for section in sections: for section in sections:
@ -812,8 +814,8 @@ def tree_merge(bull, sections, depth):
return [element for element in root.get_tree() if element] return [element for element in root.get_tree() if element]
def hierarchical_merge(bull, sections, depth):
def hierarchical_merge(bull, sections, depth):
if not sections or bull < 0: if not sections or bull < 0:
return [] return []
if isinstance(sections[0], type("")): if isinstance(sections[0], type("")):
@ -922,10 +924,10 @@ def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。
if tnum < 8: if tnum < 8:
pos = "" pos = ""
# Ensure that the length of the merged chunk does not exceed chunk_token_num # Ensure that the length of the merged chunk does not exceed chunk_token_num
if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.: if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent) / 100.:
if cks: if cks:
overlapped = RAGFlowPdfParser.remove_tag(cks[-1]) overlapped = RAGFlowPdfParser.remove_tag(cks[-1])
t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t t = overlapped[int(len(overlapped) * (100 - overlapped_percent) / 100.):] + t
if t.find(pos) < 0: if t.find(pos) < 0:
t += pos t += pos
cks.append(t) cks.append(t)
@ -957,7 +959,7 @@ def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。
return cks return cks
for sec, pos in sections: for sec, pos in sections:
add_chunk("\n"+sec, pos) add_chunk("\n" + sec, pos)
return cks return cks
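The `naive_merge` hunks above show the chunking rule: start a new chunk once the running token count exceeds `chunk_token_num * (100 - overlapped_percent) / 100`, and prepend the tail of the previous chunk as overlap. A toy sketch in the same spirit (whitespace tokens, no position tags, simplified bookkeeping), not the project's implementation:

```python
def merge_sections(sections, chunk_token_num=8, overlapped_percent=25):
    cks, tk_nums = [""], [0]
    budget = chunk_token_num * (100 - overlapped_percent) / 100.
    for sec in sections:
        t, tnum = sec, len(sec.split())
        if cks[-1] == "" or tk_nums[-1] > budget:
            if cks[-1]:
                prev = cks[-1]
                # carry the trailing overlapped_percent of the previous chunk as overlap
                t = prev[int(len(prev) * (100 - overlapped_percent) / 100.):] + t
            cks.append(t)
            tk_nums.append(tnum)
        else:
            cks[-1] += t
            tk_nums[-1] += tnum
    return [c for c in cks if c]

print(merge_sections(["alpha beta gamma ", "delta epsilon zeta ", "eta theta iota "], 4, 25))
```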
@ -978,10 +980,10 @@ def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。
if tnum < 8: if tnum < 8:
pos = "" pos = ""
# Ensure that the length of the merged chunk does not exceed chunk_token_num # Ensure that the length of the merged chunk does not exceed chunk_token_num
if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.: if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent) / 100.:
if cks: if cks:
overlapped = RAGFlowPdfParser.remove_tag(cks[-1]) overlapped = RAGFlowPdfParser.remove_tag(cks[-1])
t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t t = overlapped[int(len(overlapped) * (100 - overlapped_percent) / 100.):] + t
if t.find(pos) < 0: if t.find(pos) < 0:
t += pos t += pos
cks.append(t) cks.append(t)
@ -1025,9 +1027,9 @@ def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。
if isinstance(text, tuple): if isinstance(text, tuple):
text_str = text[0] text_str = text[0]
text_pos = text[1] if len(text) > 1 else "" text_pos = text[1] if len(text) > 1 else ""
add_chunk("\n"+text_str, image, text_pos) add_chunk("\n" + text_str, image, text_pos)
else: else:
add_chunk("\n"+text, image) add_chunk("\n" + text, image)
return cks, result_images return cks, result_images
@ -1042,7 +1044,7 @@ def docx_question_level(p, bull=-1):
for j, title in enumerate(BULLET_PATTERN[bull]): for j, title in enumerate(BULLET_PATTERN[bull]):
if re.match(title, txt): if re.match(title, txt):
return j + 1, txt return j + 1, txt
return len(BULLET_PATTERN[bull])+1, txt return len(BULLET_PATTERN[bull]) + 1, txt
def concat_img(img1, img2): def concat_img(img1, img2):
@ -1211,7 +1213,7 @@ class Node:
child = node.get_children() child = node.get_children()
if level == 0 and texts: if level == 0 and texts:
tree_list.append("\n".join(titles+texts)) tree_list.append("\n".join(titles + texts))
# Titles within configured depth are accumulated into the current path # Titles within configured depth are accumulated into the current path
if 1 <= level <= self.depth: if 1 <= level <= self.depth:

View File

@ -205,11 +205,11 @@ class FulltextQueryer(QueryBase):
s = 1e-9 s = 1e-9
for k, v in qtwt.items(): for k, v in qtwt.items():
if k in dtwt: if k in dtwt:
s += v #* dtwt[k] s += v # * dtwt[k]
q = 1e-9 q = 1e-9
for k, v in qtwt.items(): for k, v in qtwt.items():
q += v #* v q += v # * v
return s/q #math.sqrt(3. * (s / q / math.log10( len(dtwt.keys()) + 512 ))) return s / q # math.sqrt(3. * (s / q / math.log10( len(dtwt.keys()) + 512 )))
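The similarity reformatted here is a weight-overlap ratio: the summed weight of query tokens that also occur in the document, divided by the total query weight. Extracted into a standalone toy function:

```python
# Fraction of query term weight that also appears in the document's term set.
def token_weight_similarity(qtwt: dict, dtwt: dict) -> float:
    s = 1e-9
    for k, v in qtwt.items():
        if k in dtwt:
            s += v
    q = 1e-9
    for v in qtwt.values():
        q += v
    return s / q

print(token_weight_similarity({"rag": 0.6, "flow": 0.4}, {"rag": 1.0}))  # ~0.6
```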
def paragraph(self, content_tks: str, keywords: list = [], keywords_topn=30): def paragraph(self, content_tks: str, keywords: list = [], keywords_topn=30):
if isinstance(content_tks, str): if isinstance(content_tks, str):
@ -232,4 +232,5 @@ class FulltextQueryer(QueryBase):
keywords.append(f"{tk}^{w}") keywords.append(f"{tk}^{w}")
return MatchTextExpr(self.query_fields, " ".join(keywords), 100, return MatchTextExpr(self.query_fields, " ".join(keywords), 100,
{"minimum_should_match": min(3, len(keywords) / 10), "original_query": " ".join(origin_keywords)}) {"minimum_should_match": min(3, len(keywords) / 10),
"original_query": " ".join(origin_keywords)})

View File

@ -66,7 +66,8 @@ class Dealer:
if key in req and req[key] is not None: if key in req and req[key] is not None:
condition[field] = req[key] condition[field] = req[key]
# TODO(yzc): `available_int` is nullable however infinity doesn't support nullable columns. # TODO(yzc): `available_int` is nullable however infinity doesn't support nullable columns.
for key in ["knowledge_graph_kwd", "available_int", "entity_kwd", "from_entity_kwd", "to_entity_kwd", "removed_kwd"]: for key in ["knowledge_graph_kwd", "available_int", "entity_kwd", "from_entity_kwd", "to_entity_kwd",
"removed_kwd"]:
if key in req and req[key] is not None: if key in req and req[key] is not None:
condition[key] = req[key] condition[key] = req[key]
return condition return condition
@ -141,7 +142,8 @@ class Dealer:
matchText, _ = self.qryr.question(qst, min_match=0.1) matchText, _ = self.qryr.question(qst, min_match=0.1)
matchDense.extra_options["similarity"] = 0.17 matchDense.extra_options["similarity"] = 0.17
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr], res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature) orderBy, offset, limit, idx_names, kb_ids,
rank_feature=rank_feature)
total = self.dataStore.get_total(res) total = self.dataStore.get_total(res)
logging.debug("Dealer.search 2 TOTAL: {}".format(total)) logging.debug("Dealer.search 2 TOTAL: {}".format(total))
@ -218,8 +220,9 @@ class Dealer:
ans_v, _ = embd_mdl.encode(pieces_) ans_v, _ = embd_mdl.encode(pieces_)
for i in range(len(chunk_v)): for i in range(len(chunk_v)):
if len(ans_v[0]) != len(chunk_v[i]): if len(ans_v[0]) != len(chunk_v[i]):
chunk_v[i] = [0.0]*len(ans_v[0]) chunk_v[i] = [0.0] * len(ans_v[0])
logging.warning("The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[i]))) logging.warning(
"The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[i])))
assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format( assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
len(ans_v[0]), len(chunk_v[0])) len(ans_v[0]), len(chunk_v[0]))
@ -273,7 +276,7 @@ class Dealer:
if not query_rfea: if not query_rfea:
return np.array([0 for _ in range(len(search_res.ids))]) + pageranks return np.array([0 for _ in range(len(search_res.ids))]) + pageranks
q_denor = np.sqrt(np.sum([s*s for t,s in query_rfea.items() if t != PAGERANK_FLD])) q_denor = np.sqrt(np.sum([s * s for t, s in query_rfea.items() if t != PAGERANK_FLD]))
for i in search_res.ids: for i in search_res.ids:
nor, denor = 0, 0 nor, denor = 0, 0
if not search_res.field[i].get(TAG_FLD): if not search_res.field[i].get(TAG_FLD):
@ -286,8 +289,8 @@ class Dealer:
if denor == 0: if denor == 0:
rank_fea.append(0) rank_fea.append(0)
else: else:
rank_fea.append(nor/np.sqrt(denor)/q_denor) rank_fea.append(nor / np.sqrt(denor) / q_denor)
return np.array(rank_fea)*10. + pageranks return np.array(rank_fea) * 10. + pageranks
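The scoring above is a cosine-style match between the query's rank features and each chunk's stored tag weights, scaled by 10 and added to pagerank. A simplified, runnable rendering; names are illustrative and the TAG_FLD/PAGERANK_FLD bookkeeping is omitted:

```python
import numpy as np

# Cosine-style tag-feature score per document, plus pagerank.
def rank_feature_scores(query_rfea: dict, doc_tags: list, pageranks: np.ndarray) -> np.ndarray:
    q_denor = np.sqrt(sum(s * s for s in query_rfea.values()))
    scores = []
    for tags in doc_tags:
        if not tags or q_denor == 0:
            scores.append(0.0)
            continue
        nor = sum(query_rfea.get(t, 0) * v for t, v in tags.items())
        denor = sum(v * v for v in tags.values())
        scores.append(0.0 if denor == 0 else nor / np.sqrt(denor) / q_denor)
    return np.array(scores) * 10. + pageranks

print(rank_feature_scores({"nlp": 2.0}, [{"nlp": 3.0, "db": 1.0}, {}], np.array([0.0, 0.5])))
```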
def rerank(self, sres, query, tkweight=0.3, def rerank(self, sres, query, tkweight=0.3,
vtweight=0.7, cfield="content_ltks", vtweight=0.7, cfield="content_ltks",
@ -358,21 +361,21 @@ class Dealer:
rag_tokenizer.tokenize(inst).split()) rag_tokenizer.tokenize(inst).split())
def retrieval( def retrieval(
self, self,
question, question,
embd_mdl, embd_mdl,
tenant_ids, tenant_ids,
kb_ids, kb_ids,
page, page,
page_size, page_size,
similarity_threshold=0.2, similarity_threshold=0.2,
vector_similarity_weight=0.3, vector_similarity_weight=0.3,
top=1024, top=1024,
doc_ids=None, doc_ids=None,
aggs=True, aggs=True,
rerank_mdl=None, rerank_mdl=None,
highlight=False, highlight=False,
rank_feature: dict | None = {PAGERANK_FLD: 10}, rank_feature: dict | None = {PAGERANK_FLD: 10},
): ):
ranks = {"total": 0, "chunks": [], "doc_aggs": {}} ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
if not question: if not question:
@ -395,7 +398,8 @@ class Dealer:
if isinstance(tenant_ids, str): if isinstance(tenant_ids, str):
tenant_ids = tenant_ids.split(",") tenant_ids = tenant_ids.split(",")
sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight, rank_feature=rank_feature) sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight,
rank_feature=rank_feature)
if rerank_mdl and sres.total > 0: if rerank_mdl and sres.total > 0:
sim, tsim, vsim = self.rerank_by_model( sim, tsim, vsim = self.rerank_by_model(
@ -558,13 +562,14 @@ class Dealer:
def tag_content(self, tenant_id: str, kb_ids: list[str], doc, all_tags, topn_tags=3, keywords_topn=30, S=1000): def tag_content(self, tenant_id: str, kb_ids: list[str], doc, all_tags, topn_tags=3, keywords_topn=30, S=1000):
idx_nm = index_name(tenant_id) idx_nm = index_name(tenant_id)
match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []), keywords_topn) match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []),
keywords_topn)
res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nm, kb_ids, ["tag_kwd"]) res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nm, kb_ids, ["tag_kwd"])
aggs = self.dataStore.get_aggregation(res, "tag_kwd") aggs = self.dataStore.get_aggregation(res, "tag_kwd")
if not aggs: if not aggs:
return False return False
cnt = np.sum([c for _, c in aggs]) cnt = np.sum([c for _, c in aggs])
tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs], tag_fea = sorted([(a, round(0.1 * (c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
key=lambda x: x[1] * -1)[:topn_tags] key=lambda x: x[1] * -1)[:topn_tags]
doc[TAG_FLD] = {a.replace(".", "_"): c for a, c in tag_fea if c > 0} doc[TAG_FLD] = {a.replace(".", "_"): c for a, c in tag_fea if c > 0}
return True return True
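The `tag_fea` expression reads more easily outside the diff: each tag's smoothed share of the aggregation is divided by its corpus-wide prior, so tags that are unusually frequent for this document rise to the top. A small self-contained version using the same formula (the aggregation and prior values here are made up):

```python
import numpy as np

# Smoothed tag share over corpus prior, keep the strongest topn_tags.
def top_tags(aggs, all_tags, topn_tags=3, S=1000):
    cnt = np.sum([c for _, c in aggs])
    tag_fea = sorted(
        [(a, round(0.1 * (c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
        key=lambda x: x[1] * -1)[:topn_tags]
    return {a.replace(".", "_"): c for a, c in tag_fea if c > 0}

print(top_tags([("nlp", 40), ("misc", 2)], {"nlp": 0.001, "misc": 0.01}))   # {'nlp': 4}
```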
@ -580,11 +585,11 @@ class Dealer:
if not aggs: if not aggs:
return {} return {}
cnt = np.sum([c for _, c in aggs]) cnt = np.sum([c for _, c in aggs])
tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs], tag_fea = sorted([(a, round(0.1 * (c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
key=lambda x: x[1] * -1)[:topn_tags] key=lambda x: x[1] * -1)[:topn_tags]
return {a.replace(".", "_"): max(1, c) for a, c in tag_fea} return {a.replace(".", "_"): max(1, c) for a, c in tag_fea}
def retrieval_by_toc(self, query:str, chunks:list[dict], tenant_ids:list[str], chat_mdl, topn: int=6): def retrieval_by_toc(self, query: str, chunks: list[dict], tenant_ids: list[str], chat_mdl, topn: int = 6):
if not chunks: if not chunks:
return [] return []
idx_nms = [index_name(tid) for tid in tenant_ids] idx_nms = [index_name(tid) for tid in tenant_ids]
@ -594,9 +599,10 @@ class Dealer:
ranks[ck["doc_id"]] = 0 ranks[ck["doc_id"]] = 0
ranks[ck["doc_id"]] += ck["similarity"] ranks[ck["doc_id"]] += ck["similarity"]
doc_id2kb_id[ck["doc_id"]] = ck["kb_id"] doc_id2kb_id[ck["doc_id"]] = ck["kb_id"]
doc_id = sorted(ranks.items(), key=lambda x: x[1]*-1.)[0][0] doc_id = sorted(ranks.items(), key=lambda x: x[1] * -1.)[0][0]
kb_ids = [doc_id2kb_id[doc_id]] kb_ids = [doc_id2kb_id[doc_id]]
es_res = self.dataStore.search(["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [], OrderByExpr(), 0, 128, idx_nms, es_res = self.dataStore.search(["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [],
OrderByExpr(), 0, 128, idx_nms,
kb_ids) kb_ids)
toc = [] toc = []
dict_chunks = self.dataStore.get_fields(es_res, ["content_with_weight"]) dict_chunks = self.dataStore.get_fields(es_res, ["content_with_weight"])
@ -608,7 +614,7 @@ class Dealer:
if not toc: if not toc:
return chunks return chunks
ids = asyncio.run(relevant_chunks_with_toc(query, toc, chat_mdl, topn*2)) ids = asyncio.run(relevant_chunks_with_toc(query, toc, chat_mdl, topn * 2))
if not ids: if not ids:
return chunks return chunks
@ -644,9 +650,9 @@ class Dealer:
break break
chunks.append(d) chunks.append(d)
return sorted(chunks, key=lambda x:x["similarity"]*-1)[:topn] return sorted(chunks, key=lambda x: x["similarity"] * -1)[:topn]
def retrieval_by_children(self, chunks:list[dict], tenant_ids:list[str]): def retrieval_by_children(self, chunks: list[dict], tenant_ids: list[str]):
if not chunks: if not chunks:
return [] return []
idx_nms = [index_name(tid) for tid in tenant_ids] idx_nms = [index_name(tid) for tid in tenant_ids]
@ -692,4 +698,4 @@ class Dealer:
break break
chunks.append(d) chunks.append(d)
return sorted(chunks, key=lambda x:x["similarity"]*-1) return sorted(chunks, key=lambda x: x["similarity"] * -1)
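For reference, the tag weighting touched in the hunks above boosts tags that are frequent in the current document but rare corpus-wide, with S acting as a smoothing constant. A minimal standalone sketch, using made-up counts and corpus frequencies (aggs, all_tags and all numbers below are illustrative, not real data):

import numpy as np

# Illustrative inputs: (tag, count-in-document) pairs and corpus-wide tag frequencies.
aggs = [("finance", 40), ("sports", 3), ("health", 12)]
all_tags = {"finance": 0.2, "health": 0.05}  # "sports" unseen -> falls back to 0.0001
S, topn_tags = 1000, 3

cnt = np.sum([c for _, c in aggs])
tag_fea = sorted([(a, round(0.1 * (c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
                 key=lambda x: x[1] * -1)[:topn_tags]
print(tag_fea)  # [('sports', 4), ('finance', 0), ('health', 0)]: only the corpus-rare tag keeps a positive weight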

View File

@@ -14,129 +14,131 @@
# limitations under the License.
#

m = set(["", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "羿", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "宿", "", "怀",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "寿", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"广", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "",
"", "", "", "西", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "",
"", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "",
"", "", "", "", "", "", "", "鹿", "",
"万俟", "司马", "上官", "欧阳",
"夏侯", "诸葛", "闻人", "东方",
"赫连", "皇甫", "尉迟", "公羊",
"澹台", "公冶", "宗政", "濮阳",
"淳于", "单于", "太叔", "申屠",
"公孙", "仲孙", "轩辕", "令狐",
"钟离", "宇文", "长孙", "慕容",
"鲜于", "闾丘", "司徒", "司空",
"亓官", "司寇", "仉督", "子车",
"颛孙", "端木", "巫马", "公西",
"漆雕", "乐正", "壤驷", "公良",
"拓跋", "夹谷", "宰父", "榖梁",
"", "", "", "", "", "", "", "",
"段干", "百里", "东郭", "南门",
"呼延", "", "", "羊舌", "", "",
"", "", "", "", "", "", "", "",
"梁丘", "左丘", "东门", "西门",
"", "", "", "", "", "", "南宫",
"", "", "", "", "", "", "", "",
"第五", "", ""])

def isit(n): return n.strip() in m
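A quick usage sketch of the surname check above; note that many single-character entries of m were lost in this rendering and show up as empty strings, so the examples stick to compound surnames that survived:

print(isit("欧阳"))    # True: compound surname present in m
print(isit(" 司马 "))  # True: surrounding whitespace is stripped before the lookup
print(isit("Smith"))   # False: not in the surname set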

View File

@@ -1,4 +1,4 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -108,13 +114,14 @@ class Dealer:
if re.match(p, t):
tk = "#"
break
# tk = re.sub(r"([\+\\-])", r"\\\1", tk)
if tk != "#" and tk:
res.append(tk)
return res

def token_merge(self, tks):
def one_term(t):
return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)

res, i = [], 0
while i < len(tks):
@@ -152,8 +153,8 @@ class Dealer:
tks = []
for t in re.sub(r"[ \t]+", " ", txt).split():
if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and \
re.match(r".*[a-zA-Z]$", t) and tks and \
self.ne.get(t, "") != "func" and self.ne.get(tks[-1], "") != "func":
tks[-1] = tks[-1] + " " + t
else:
tks.append(t)
@@ -220,14 +221,15 @@ class Dealer:
return 3

def idf(s, N):
return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))

tw = []
if not preprocess:
idf1 = np.array([idf(freq(t), 10000000) for t in tks])
idf2 = np.array([idf(df(t), 1000000000) for t in tks])
wts = (0.3 * idf1 + 0.7 * idf2) * \
np.array([ner(t) * postag(t) for t in tks])
wts = [s for s in wts]
tw = list(zip(tks, wts))
else:
@@ -236,7 +238,7 @@ class Dealer:
idf1 = np.array([idf(freq(t), 10000000) for t in tt])
idf2 = np.array([idf(df(t), 1000000000) for t in tt])
wts = (0.3 * idf1 + 0.7 * idf2) * \
np.array([ner(t) * postag(t) for t in tt])
wts = [s for s in wts]
tw.extend(zip(tt, wts))
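The smoothed IDF and the 0.3/0.7 blend reformatted above can be exercised on their own. A small sketch with invented frequency and document-frequency lookups (freq, df and the numbers are stand-ins; ner and postag are assumed to return 1.0 here):

import math
import numpy as np

def idf(s, N):
    return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))

tks = ["ragflow", "the", "retrieval"]
freq = {"ragflow": 5, "the": 9000000, "retrieval": 12000}.get    # term-frequency stand-in
df = {"ragflow": 3, "the": 800000000, "retrieval": 50000}.get    # document-frequency stand-in

idf1 = np.array([idf(freq(t), 10000000) for t in tks])
idf2 = np.array([idf(df(t), 1000000000) for t in tks])
wts = (0.3 * idf1 + 0.7 * idf2) * np.array([1.0 for _ in tks])   # ner(t) * postag(t) assumed to be 1.0
print(dict(zip(tks, wts)))  # rare terms ("ragflow") get far larger weights than stopword-like terms ("the")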

View File

@@ -28,17 +28,16 @@ from rag.prompts.template import load_prompt
from common.constants import TAG_FLD
from common.token_utils import encoder, num_tokens_from_string

STOP_TOKEN = "<|STOP|>"
COMPLETE_TASK = "complete_task"
INPUT_UTILIZATION = 0.5

def get_value(d, k1, k2):
return d.get(k1, d.get(k2))

def chunks_format(reference):
return [
{
"id": get_value(chunk, "chunk_id", "id"),
@@ -126,7 +125,7 @@ def kb_prompt(kbinfos, max_tokens, hash_id=False):
for i, ck in enumerate(kbinfos["chunks"][:chunks_num]):
cnt = "\nID: {}".format(i if not hash_id else hash_str2int(get_value(ck, "id", "chunk_id"), 500))
cnt += draw_node("Title", get_value(ck, "docnm_kwd", "document_name"))
cnt += draw_node("URL", ck['url']) if "url" in ck else ""
for k, v in docs.get(get_value(ck, "doc_id", "document_id"), {}).items():
cnt += draw_node(k, v)
cnt += "\n└── Content:\n"
@@ -173,7 +172,7 @@ ASK_SUMMARY = load_prompt("ask_summary")
PROMPT_JINJA_ENV = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip_blocks=True)

def citation_prompt(user_defined_prompts: dict = {}) -> str:
template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("citation_guidelines", CITATION_PROMPT_TEMPLATE))
return template.render()
@@ -258,9 +257,11 @@ async def cross_languages(tenant_id, llm_id, query, languages=[]):
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
rendered_sys_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE).render()
rendered_user_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query,
languages=languages)
ans = await chat_mdl.async_chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}],
{"temperature": 0.2})
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
if ans.find("**ERROR**") >= 0:
return query
@@ -332,7 +333,8 @@ def tool_schema(tools_description: list[dict], complete_task=False):
"description": "When you have the final answer and are ready to complete the task, call this function with your answer",
"parameters": {
"type": "object",
"properties": {
"answer": {"type": "string", "description": "The final answer to the user's question"}},
"required": ["answer"]
}
}
@@ -341,7 +343,8 @@ def tool_schema(tools_description: list[dict], complete_task=False):
name = tool["function"]["name"]
desc[name] = tool
return "\n\n".join([f"## {i + 1}. {fnm}\n{json.dumps(des, ensure_ascii=False, indent=4)}" for i, (fnm, des) in
enumerate(desc.items())])
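To make the joined string above concrete: each registered tool becomes a numbered markdown-style section containing its pretty-printed JSON schema. A rough illustration with two invented tool definitions:

import json

desc = {
    "web_search": {"function": {"name": "web_search", "parameters": {"q": "string"}}},
    "complete_task": {"function": {"name": "complete_task", "parameters": {"answer": "string"}}},
}
print("\n\n".join(f"## {i + 1}. {fnm}\n{json.dumps(des, ensure_ascii=False, indent=4)}"
                  for i, (fnm, des) in enumerate(desc.items())))
# -> a "## 1. web_search" section and a "## 2. complete_task" section, each with its JSON body
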
def form_history(history, limit=-6):
@@ -350,14 +353,14 @@ def form_history(history, limit=-6):
if h["role"] == "system":
continue
role = "USER"
if h["role"].upper() != role:
role = "AGENT"
context += f"\n{role}: {h['content'][:2048] + ('...' if len(h['content']) > 2048 else '')}"
return context

async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: list[dict],
user_defined_prompts: dict = {}):
tools_desc = tool_schema(tools_description)
context = ""
@@ -375,7 +378,8 @@ async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: lis
return kwd

async def next_step_async(chat_mdl, history: list, tools_description: list[dict], task_desc,
user_defined_prompts: dict = {}):
if not tools_description:
return "", 0
desc = tool_schema(tools_description)
@@ -396,7 +400,7 @@ async def next_step_async(chat_mdl, history:list, tools_description: list[dict],
return json_str, tk_cnt

async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict = {}):
tool_calls = [{"name": p[0], "result": p[1]} for p in tool_call_res]
goal = history[1]["content"]
template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("reflection", REFLECT))
@@ -419,7 +423,7 @@ async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple
def form_message(system_prompt, user_prompt):
return [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]

def structured_output_prompt(schema=None) -> str:
@@ -427,27 +431,29 @@ def structured_output_prompt(schema=None) -> str:
return template.render(schema=schema)

async def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defined_prompts: dict = {}) -> str:
template = PROMPT_JINJA_ENV.from_string(SUMMARY4MEMORY)
system_prompt = template.render(name=name,
params=json.dumps(params, ensure_ascii=False, indent=2),
result=result)
user_prompt = "→ Summary: "
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:])
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)

async def rank_memories_async(chat_mdl, goal: str, sub_goal: str, tool_call_summaries: list[str],
user_defined_prompts: dict = {}):
template = PROMPT_JINJA_ENV.from_string(RANK_MEMORY)
system_prompt = template.render(goal=goal, sub_goal=sub_goal,
results=[{"i": i, "content": s} for i, s in enumerate(tool_call_summaries)])
user_prompt = " → rank: "
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:], stop="<|stop|>")
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)

async def gen_meta_filter(chat_mdl, meta_data: dict, query: str) -> dict:
meta_data_structure = {}
for key, values in meta_data.items():
meta_data_structure[key] = list(values.keys()) if isinstance(values, dict) else values
@@ -471,13 +477,13 @@ async def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict:
return {"conditions": []}

async def gen_json(system_prompt: str, user_prompt: str, chat_mdl, gen_conf=None):
from graphrag.utils import get_llm_cache, set_llm_cache
cached = get_llm_cache(chat_mdl.llm_name, system_prompt, user_prompt, gen_conf)
if cached:
return json_repair.loads(cached)
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:], gen_conf=gen_conf)
ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL)
try:
res = json_repair.loads(ans)
@@ -488,10 +494,13 @@ async def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None
TOC_DETECTION = load_prompt("toc_detection")

async def detect_table_of_contents(page_1024: list[str], chat_mdl):
toc_secs = []
for i, sec in enumerate(page_1024[:22]):
ans = await gen_json(PROMPT_JINJA_ENV.from_string(TOC_DETECTION).render(page_txt=sec), "Only JSON please.",
chat_mdl)
if toc_secs and not ans["exists"]:
break
toc_secs.append(sec)
@@ -500,14 +509,17 @@ async def detect_table_of_contents(page_1024:list[str], chat_mdl):
TOC_EXTRACTION = load_prompt("toc_extraction")
TOC_EXTRACTION_CONTINUE = load_prompt("toc_extraction_continue")

async def extract_table_of_contents(toc_pages, chat_mdl):
if not toc_pages:
return []
return await gen_json(PROMPT_JINJA_ENV.from_string(TOC_EXTRACTION).render(toc_page="\n".join(toc_pages)),
"Only JSON please.", chat_mdl)

async def toc_index_extractor(toc: list[dict], content: str, chat_mdl):
tob_extractor_prompt = """
You are given a table of contents in a json format and several pages of a document, your job is to add the physical_index to the table of contents in the json format.
@@ -529,18 +541,21 @@ async def toc_index_extractor(toc:list[dict], content:str, chat_mdl):
If the title of the section are not in the provided pages, do not add the physical_index to it.
Directly return the final JSON structure. Do not output anything else."""
prompt = tob_extractor_prompt + '\nTable of contents:\n' + json.dumps(toc, ensure_ascii=False,
indent=2) + '\nDocument pages:\n' + content
return await gen_json(prompt, "Only JSON please.", chat_mdl)

TOC_INDEX = load_prompt("toc_index")

async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl):
if not toc_arr or not sections:
return []
toc_map = {}
for i, it in enumerate(toc_arr):
k1 = (it["structure"] + it["title"]).replace(" ", "")
k2 = it["title"].strip()
if k1 not in toc_map:
toc_map[k1] = []
@@ -558,6 +573,7 @@ async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat
toc_arr[j]["indices"].append(i)
all_pathes = []

def dfs(start, path):
nonlocal all_pathes
if start >= len(toc_arr):
@@ -565,7 +581,7 @@ async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat
all_pathes.append(path)
return
if not toc_arr[start]["indices"]:
dfs(start + 1, path)
return
added = False
for j in toc_arr[start]["indices"]:
@@ -574,12 +590,12 @@ async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat
_path = deepcopy(path)
_path.append((j, start))
added = True
dfs(start + 1, _path)
if not added and path:
all_pathes.append(path)
dfs(0, [])
path = max(all_pathes, key=lambda x: len(x))
for it in toc_arr:
it["indices"] = []
for j, i in path:
@@ -588,24 +604,24 @@ async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat
i = 0
while i < len(toc_arr):
it = toc_arr[i]
if it["indices"]:
i += 1
continue
if i > 0 and toc_arr[i - 1]["indices"]:
st_i = toc_arr[i - 1]["indices"][-1]
else:
st_i = 0
e = i + 1
while e < len(toc_arr) and not toc_arr[e]["indices"]:
e += 1
if e >= len(toc_arr):
e = len(sections)
else:
e = toc_arr[e]["indices"][0]
for j in range(st_i, min(e + 1, len(sections))):
ans = await gen_json(PROMPT_JINJA_ENV.from_string(TOC_INDEX).render(
structure=it["structure"],
title=it["title"],
@@ -656,11 +672,15 @@ async def toc_transformer(toc_pages, chat_mdl):
toc_content = "\n".join(toc_pages)
prompt = init_prompt + '\n Given table of contents\n:' + toc_content

def clean_toc(arr):
for a in arr:
a["title"] = re.sub(r"[.·….]{2,}", "", a["title"])

last_complete = await gen_json(prompt, "Only JSON please.", chat_mdl)
if_complete = await check_if_toc_transformation_is_complete(toc_content,
json.dumps(last_complete, ensure_ascii=False, indent=2),
chat_mdl)
clean_toc(last_complete)
if if_complete == "yes":
return last_complete
@@ -682,13 +702,17 @@ async def toc_transformer(toc_pages, chat_mdl):
break
clean_toc(new_complete)
last_complete.extend(new_complete)
if_complete = await check_if_toc_transformation_is_complete(toc_content,
json.dumps(last_complete, ensure_ascii=False,
indent=2), chat_mdl)
return last_complete

TOC_LEVELS = load_prompt("assign_toc_levels")

async def assign_toc_levels(toc_secs, chat_mdl, gen_conf={"temperature": 0.2}):
if not toc_secs:
return []
return await gen_json(
@@ -701,12 +725,15 @@ async def assign_toc_levels(toc_secs, chat_mdl, gen_conf = {"temperature": 0.2})
TOC_FROM_TEXT_SYSTEM = load_prompt("toc_from_text_system")
TOC_FROM_TEXT_USER = load_prompt("toc_from_text_user")

# Generate TOC from text chunks with text llms
async def gen_toc_from_text(txt_info: dict, chat_mdl, callback=None):
try:
ans = await gen_json(
PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_SYSTEM).render(),
PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_USER).render(
text="\n".join([json.dumps(d, ensure_ascii=False) for d in txt_info["chunks"]])),
chat_mdl,
gen_conf={"temperature": 0.0, "top_p": 0.9}
)
@@ -743,7 +770,7 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
TOC_FROM_TEXT_USER + TOC_FROM_TEXT_SYSTEM
)
input_budget = 1024 if input_budget > 1024 else input_budget
chunk_sections = split_chunks(chunks, input_budget)
titles = []
@@ -798,7 +825,7 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
if sorted_list:
max_lvl = sorted_list[-1]
merged = []
for _, (toc_item, src_item) in enumerate(zip(toc_with_levels, filtered)):
if prune and toc_item.get("level", "0") >= max_lvl:
continue
merged.append({
@@ -812,12 +839,15 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
TOC_RELEVANCE_SYSTEM = load_prompt("toc_relevance_system")
TOC_RELEVANCE_USER = load_prompt("toc_relevance_user")

async def relevant_chunks_with_toc(query: str, toc: list[dict], chat_mdl, topn: int = 6):
import numpy as np
try:
ans = await gen_json(
PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_SYSTEM).render(),
PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_USER).render(query=query, toc_json="[\n%s\n]\n" % "\n".join(
[json.dumps({"level": d["level"], "title": d["title"]}, ensure_ascii=False) for d in toc])),
chat_mdl,
gen_conf={"temperature": 0.0, "top_p": 0.9}
)
@@ -828,17 +858,19 @@ async def relevant_chunks_with_toc(query: str, toc:list[dict], chat_mdl, topn: i
for id in ti.get("ids", []):
if id not in id2score:
id2score[id] = []
id2score[id].append(sc["score"] / 5.)
for id in id2score.keys():
id2score[id] = np.mean(id2score[id])
return [(id, sc) for id, sc in list(id2score.items()) if sc >= 0.3][:topn]
except Exception as e:
logging.exception(e)
return []
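The scoring tail of relevant_chunks_with_toc normalizes each LLM-assigned score (0-5) to 0-1, averages the scores per chunk id, and keeps ids scoring at least 0.3. The outer loop is cut off by the hunk, so the shape of ans below is a guess made up for illustration:

import numpy as np

ans = [{"score": 5, "ids": ["c1", "c2"]},   # invented LLM output: per-TOC-item score plus covered chunk ids
       {"score": 2, "ids": ["c2", "c3"]},
       {"score": 0, "ids": ["c3"]}]
id2score = {}
for sc in ans:
    for id in sc.get("ids", []):
        id2score.setdefault(id, []).append(sc["score"] / 5.)
id2score = {id: np.mean(v) for id, v in id2score.items()}
print([(id, s) for id, s in id2score.items() if s >= 0.3])  # keeps c1 (1.0) and c2 (0.7); drops c3 (0.2)
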
META_DATA = load_prompt("meta_data")

async def gen_metadata(chat_mdl, schema: dict, content: str):
template = PROMPT_JINJA_ENV.from_string(META_DATA)
for k, desc in schema["properties"].items():
if "enum" in desc and not desc.get("enum"):

View File

@@ -1,6 +1,5 @@
import os

PROMPT_DIR = os.path.dirname(__file__)
_loaded_prompts = {}

View File

@@ -48,9 +48,11 @@ def main():
REDIS_CONN.transaction(key, file_bin, 12 * 60)
logging.info("CACHE: {}".format(loc))
except Exception as e:
logging.error(f"Error to get data from REDIS: {e}")
traceback.print_stack()
except Exception as e:
logging.error(f"Error to check REDIS connection: {e}")
traceback.print_stack()

if __name__ == "__main__":

View File

@@ -19,16 +19,15 @@ import requests
import base64
import asyncio

URL = '{YOUR_IP_ADDRESS:PORT}/v1/api/completion_aibotk'  # Default: https://demo.ragflow.io/v1/api/completion_aibotk
JSON_DATA = {
"conversation_id": "xxxxxxxxxxxxxxxxxxxxxxxxxxx",  # Get conversation id from /api/new_conversation
"Authorization": "ragflow-xxxxxxxxxxxxxxxxxxxxxxxxxxxxx",  # RAGFlow Assistant Chat Bot API Key
"word": ""  # User question, don't need to initialize
}
DISCORD_BOT_KEY = "xxxxxxxxxxxxxxxxxxxxxxxxxx"  # Get DISCORD_BOT_KEY from Discord Application

intents = discord.Intents.default()
intents.message_content = True
@@ -50,7 +49,7 @@ async def on_message(message):
if len(message.content.split('> ')) == 1:
await message.channel.send("Hi~ How can I help you? ")
else:
JSON_DATA['word'] = message.content.split('> ')[1]
response = requests.post(URL, json=JSON_DATA)
response_data = response.json().get('data', [])
image_bool = False
@@ -61,9 +60,9 @@ async def on_message(message):
if i['type'] == 3:
image_bool = True
image_data = base64.b64decode(i['url'])
with open('tmp_image.png', 'wb') as file:
file.write(image_data)
image = discord.File('tmp_image.png')
await message.channel.send(f"{message.author.mention}{res}")

View File

@@ -38,12 +38,25 @@ from api.db.services.connector_service import ConnectorService, SyncLogsService
from api.db.services.knowledgebase_service import KnowledgebaseService
from common import settings
from common.config_utils import show_configs
from common.data_source import (
BlobStorageConnector,
NotionConnector,
DiscordConnector,
GoogleDriveConnector,
MoodleConnector,
JiraConnector,
DropboxConnector,
WebDAVConnector,
AirtableConnector,
AsanaConnector,
)
from common.constants import FileSource, TaskStatus
from common.data_source.config import INDEX_BATCH_SIZE
from common.data_source.confluence_connector import ConfluenceConnector
from common.data_source.gmail_connector import GmailConnector
from common.data_source.box_connector import BoxConnector
from common.data_source.github.connector import GithubConnector
from common.data_source.gitlab_connector import GitlabConnector
from common.data_source.interfaces import CheckpointOutputWrapper
from common.log_utils import init_root_logger
from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
@@ -96,7 +109,7 @@ class SyncBase:
if task["poll_range_start"]:
next_update = task["poll_range_start"]
async for document_batch in document_batch_generator:
if not document_batch:
continue
@@ -161,6 +174,7 @@ class SyncBase:
def _get_source_prefix(self):
return ""

class _BlobLikeBase(SyncBase):
DEFAULT_BUCKET_TYPE: str = "s3"
@@ -199,22 +213,27 @@ class _BlobLikeBase(SyncBase):
)
return document_batch_generator

class S3(_BlobLikeBase):
SOURCE_NAME: str = FileSource.S3
DEFAULT_BUCKET_TYPE: str = "s3"

class R2(_BlobLikeBase):
SOURCE_NAME: str = FileSource.R2
DEFAULT_BUCKET_TYPE: str = "r2"

class OCI_STORAGE(_BlobLikeBase):
SOURCE_NAME: str = FileSource.OCI_STORAGE
DEFAULT_BUCKET_TYPE: str = "oci_storage"

class GOOGLE_CLOUD_STORAGE(_BlobLikeBase):
SOURCE_NAME: str = FileSource.GOOGLE_CLOUD_STORAGE
DEFAULT_BUCKET_TYPE: str = "google_cloud_storage"

class Confluence(SyncBase):
SOURCE_NAME: str = FileSource.CONFLUENCE
@@ -248,7 +267,9 @@ class Confluence(SyncBase):
index_recursively=index_recursively,
)
credentials_provider = StaticCredentialsProvider(tenant_id=task["tenant_id"],
connector_name=DocumentSource.CONFLUENCE,
credential_json=self.conf["credentials"])
self.connector.set_credentials_provider(credentials_provider)
# Determine the time range for synchronization based on reindex or poll_range_start
@@ -280,7 +301,8 @@ class Confluence(SyncBase):
doc_generator = wrapper(self.connector.load_from_checkpoint(start_time, end_time, checkpoint))
for document, failure, next_checkpoint in doc_generator:
if failure is not None:
logging.warning("Confluence connector failure: %s",
getattr(failure, "failure_message", failure))
continue
if document is not None:
pending_docs.append(document)
@@ -314,10 +336,12 @@ class Notion(SyncBase):
document_generator = (
self.connector.load_from_state()
if task["reindex"] == "1" or not task["poll_range_start"]
else self.connector.poll_source(task["poll_range_start"].timestamp(),
datetime.now(timezone.utc).timestamp())
)
begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(
task["poll_range_start"])
logging.info("Connect to Notion: root({}) {}".format(self.conf["root_page_id"], begin_info))
return document_generator
@@ -340,10 +364,12 @@ class Discord(SyncBase):
document_generator = (
self.connector.load_from_state()
if task["reindex"] == "1" or not task["poll_range_start"]
else self.connector.poll_source(task["poll_range_start"].timestamp(),
datetime.now(timezone.utc).timestamp())
)
begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(
task["poll_range_start"])
logging.info("Connect to Discord: servers({}), channel({}) {}".format(server_ids, channel_names, begin_info))
return document_generator
@@ -485,7 +511,8 @@ class GoogleDrive(SyncBase):
doc_generator = wrapper(self.connector.load_from_checkpoint(start_time, end_time, checkpoint))
for document, failure, next_checkpoint in doc_generator:
if failure is not None:
logging.warning("Google Drive connector failure: %s",
getattr(failure, "failure_message", failure))
continue
if document is not None:
pending_docs.append(document)
@@ -649,7 +676,7 @@ class WebDAV(SyncBase):
logging.info(f"Task info: reindex={task['reindex']}, poll_range_start={task['poll_range_start']}")
if task["reindex"] == "1" or not task["poll_range_start"]:
logging.info("Using load_from_state (full sync)")
document_batch_generator = self.connector.load_from_state()
begin_info = "totally"
@@ -667,6 +694,7 @@ class WebDAV(SyncBase):
))
return document_batch_generator

class Moodle(SyncBase):
SOURCE_NAME: str = FileSource.MOODLE
@@ -679,20 +707,17 @@ class Moodle(SyncBase):
self.connector.load_credentials(self.conf["credentials"])
# Determine the time range for synchronization based on reindex or poll_range_start
poll_start = task.get("poll_range_start")
if task["reindex"] == "1" or poll_start is None:
document_generator = self.connector.load_from_state()
begin_info = "totally"
else:
document_generator = self.connector.poll_source(
poll_start.timestamp(),
datetime.now(timezone.utc).timestamp(),
)
begin_info = f"from {poll_start}"
logging.info("Connect to Moodle: {} {}".format(self.conf["moodle_url"], begin_info))
return document_generator
@@ -722,23 +747,21 @@ class BOX(SyncBase):
auth.token_storage.store(token)
self.connector.load_credentials(auth)
poll_start = task["poll_range_start"]
if task["reindex"] == "1" or poll_start is None:
document_generator = self.connector.load_from_state()
begin_info = "totally"
else:
document_generator = self.connector.poll_source(
poll_start.timestamp(),
datetime.now(timezone.utc).timestamp(),
)
begin_info = f"from {poll_start}"
logging.info("Connect to Box: folder_id({}) {}".format(self.conf["folder_id"], begin_info))
return document_generator

class Airtable(SyncBase):
SOURCE_NAME: str = FileSource.AIRTABLE
@@ -760,6 +783,44 @@ class Airtable(SyncBase):
{"airtable_access_token": credentials["airtable_access_token"]}
)
poll_start = task.get("poll_range_start")
if task.get("reindex") == "1" or poll_start is None:
document_generator = self.connector.load_from_state()
begin_info = "totally"
else:
document_generator = self.connector.poll_source(
poll_start.timestamp(),
datetime.now(timezone.utc).timestamp(),
)
begin_info = f"from {poll_start}"
logging.info(
"Connect to Airtable: base_id(%s), table(%s) %s",
self.conf.get("base_id"),
self.conf.get("table_name_or_id"),
begin_info,
)
return document_generator
class Asana(SyncBase):
SOURCE_NAME: str = FileSource.ASANA
async def _generate(self, task: dict):
self.connector = AsanaConnector(
self.conf.get("asana_workspace_id"),
self.conf.get("asana_project_ids"),
self.conf.get("asana_team_id"),
)
credentials = self.conf.get("credentials", {})
if "asana_api_token_secret" not in credentials:
raise ValueError("Missing asana_api_token_secret in credentials")
self.connector.load_credentials(
{"asana_api_token_secret": credentials["asana_api_token_secret"]}
)
if task.get("reindex") == "1" or not task.get("poll_range_start"): if task.get("reindex") == "1" or not task.get("poll_range_start"):
document_generator = self.connector.load_from_state() document_generator = self.connector.load_from_state()
begin_info = "totally" begin_info = "totally"
@ -776,14 +837,126 @@ class Airtable(SyncBase):
begin_info = f"from {poll_start}" begin_info = f"from {poll_start}"
logging.info( logging.info(
"Connect to Airtable: base_id(%s), table(%s) %s", "Connect to Asana: workspace_id(%s), project_ids(%s), team_id(%s) %s",
self.conf.get("base_id"), self.conf.get("asana_workspace_id"),
self.conf.get("table_name_or_id"), self.conf.get("asana_project_ids"),
self.conf.get("asana_team_id"),
begin_info, begin_info,
) )
return document_generator return document_generator
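The Moodle, Box, Airtable, and Asana classes above all make the same choice: a full load when reindexing or when no previous poll window exists, otherwise an incremental poll from the last window start to now. A condensed sketch of just that branch, with connector and task as illustrative stand-ins:

from datetime import datetime, timezone

def choose_generator(connector, task):
    # Full sync on reindex or first run; otherwise poll the window [poll_range_start, now].
    poll_start = task.get("poll_range_start")
    if task.get("reindex") == "1" or poll_start is None:
        return connector.load_from_state(), "totally"
    return connector.poll_source(poll_start.timestamp(),
                                 datetime.now(timezone.utc).timestamp()), f"from {poll_start}"
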
class Github(SyncBase):
SOURCE_NAME: str = FileSource.GITHUB
async def _generate(self, task: dict):
"""
Sync files from Github repositories.
"""
from common.data_source.connector_runner import ConnectorRunner
self.connector = GithubConnector(
repo_owner=self.conf.get("repository_owner"),
repositories=self.conf.get("repository_name"),
include_prs=self.conf.get("include_pull_requests", False),
include_issues=self.conf.get("include_issues", False),
)
credentials = self.conf.get("credentials", {})
if "github_access_token" not in credentials:
raise ValueError("Missing github_access_token in credentials")
self.connector.load_credentials(
{"github_access_token": credentials["github_access_token"]}
)
if task.get("reindex") == "1" or not task.get("poll_range_start"):
start_time = datetime.fromtimestamp(0, tz=timezone.utc)
begin_info = "totally"
else:
start_time = task.get("poll_range_start")
begin_info = f"from {start_time}"
end_time = datetime.now(timezone.utc)
runner = ConnectorRunner(
connector=self.connector,
batch_size=self.conf.get("batch_size", INDEX_BATCH_SIZE),
include_permissions=False,
time_range=(start_time, end_time)
)
def document_batches():
checkpoint = self.connector.build_dummy_checkpoint()
while checkpoint.has_more:
for doc_batch, failure, next_checkpoint in runner.run(checkpoint):
if failure is not None:
logging.warning(
"Github connector failure: %s",
getattr(failure, "failure_message", failure),
)
continue
if doc_batch is not None:
yield doc_batch
if next_checkpoint is not None:
checkpoint = next_checkpoint
async def async_wrapper():
for batch in document_batches():
yield batch
logging.info(
"Connect to Github: org_name(%s), repo_names(%s) for %s",
self.conf.get("repository_owner"),
self.conf.get("repository_name"),
begin_info,
)
return async_wrapper()
class Gitlab(SyncBase):
SOURCE_NAME: str = FileSource.GITLAB
async def _generate(self, task: dict):
"""
Sync files from GitLab attachments.
"""
self.connector = GitlabConnector(
project_owner= self.conf.get("project_owner"),
project_name= self.conf.get("project_name"),
include_mrs = self.conf.get("include_mrs", False),
include_issues = self.conf.get("include_issues", False),
include_code_files= self.conf.get("include_code_files", False),
)
self.connector.load_credentials(
{
"gitlab_access_token": self.conf.get("credentials", {}).get("gitlab_access_token"),
"gitlab_url": self.conf.get("credentials", {}).get("gitlab_url"),
}
)
if task["reindex"] == "1" or not task["poll_range_start"]:
document_generator = self.connector.load_from_state()
begin_info = "totally"
else:
poll_start = task["poll_range_start"]
if poll_start is None:
document_generator = self.connector.load_from_state()
begin_info = "totally"
else:
document_generator = self.connector.poll_source(
poll_start.timestamp(),
datetime.now(timezone.utc).timestamp()
)
begin_info = "from {}".format(poll_start)
logging.info("Connect to Gitlab: ({}) {}".format(self.conf["project_name"], begin_info))
return document_generator
func_factory = {
FileSource.S3: S3,
FileSource.R2: R2,
@@ -803,6 +976,9 @@ func_factory = {
FileSource.WEBDAV: WebDAV,
FileSource.BOX: BOX,
FileSource.AIRTABLE: Airtable,
FileSource.ASANA: Asana,
FileSource.GITHUB: Github,
FileSource.GITLAB: Gitlab,
}
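Since the SyncBase loop now consumes batches with async for, sources that only expose a synchronous batch generator (as the Github class does with its checkpoint loop) wrap it in an async generator. A minimal sketch of that bridging pattern, independent of the real connector classes (sync_batches stands in for the checkpoint runner):

import asyncio

def sync_batches():
    # Stand-in for a synchronous checkpoint/batch loop.
    for batch in (["doc1", "doc2"], ["doc3"]):
        yield batch

async def async_wrapper():
    for batch in sync_batches():
        yield batch  # note: blocking work inside sync_batches would still block the event loop

async def main():
    async for batch in async_wrapper():
        print(batch)

asyncio.run(main())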

View File

@@ -26,6 +26,7 @@ import time
from api.db import PIPELINE_SPECIAL_PROGRESS_FREEZE_TASK_TYPES
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
from api.db.joint_services.memory_message_service import handle_save_to_memory_task
from common.connection_utils import timeout
from common.metadata_utils import update_metadata_to, metadata_schema
from rag.utils.base64_image import image2id
@@ -92,10 +93,11 @@ FACTORY = {
}

TASK_TYPE_TO_PIPELINE_TASK_TYPE = {
"dataflow": PipelineTaskType.PARSE,
"raptor": PipelineTaskType.RAPTOR,
"graphrag": PipelineTaskType.GRAPH_RAG,
"mindmap": PipelineTaskType.MINDMAP,
"memory": PipelineTaskType.MEMORY,
}

UNACKED_ITERATOR = None
@@ -157,8 +159,8 @@ def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing...
logging.info(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}")
except DoesNotExist:
logging.warning(f"set_progress({task_id}) got exception DoesNotExist")
except Exception as e:
logging.exception(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}, got exception: {e}")

async def collect():
@@ -166,6 +168,7 @@ async def collect():
global UNACKED_ITERATOR
svr_queue_names = settings.get_svr_queue_names()
redis_msg = None
try:
if not UNACKED_ITERATOR:
UNACKED_ITERATOR = REDIS_CONN.get_unacked_iterator(svr_queue_names, SVR_CONSUMER_GROUP_NAME, CONSUMER_NAME)
@@ -176,8 +179,8 @@ async def collect():
redis_msg = REDIS_CONN.queue_consumer(svr_queue_name, SVR_CONSUMER_GROUP_NAME, CONSUMER_NAME)
if redis_msg:
break
except Exception as e:
logging.exception(f"collect got exception: {e}")
return None, None
if not redis_msg:
@@ -196,6 +199,9 @@ async def collect():
if task:
task["doc_id"] = msg["doc_id"]
task["doc_ids"] = msg.get("doc_ids", []) or []
elif msg.get("task_type") == PipelineTaskType.MEMORY.lower():
_, task_obj = TaskService.get_by_id(msg["id"])
task = task_obj.to_dict()
else:
task = TaskService.get_task(msg["id"])
@@ -214,6 +220,10 @@ async def collect():
task["tenant_id"] = msg["tenant_id"]
task["dataflow_id"] = msg["dataflow_id"]
task["kb_id"] = msg.get("kb_id", "")
if task_type[:6] == "memory":
task["memory_id"] = msg["memory_id"]
task["source_id"] = msg["source_id"]
task["message_dict"] = msg["message_dict"]
return redis_msg, task
@@ -221,7 +231,7 @@ async def get_storage_binary(bucket, name):
return await asyncio.to_thread(settings.STORAGE_IMPL.get, bucket, name)

@timeout(60 * 80, 1)
async def build_chunks(task, progress_callback):
if task["size"] > settings.DOC_MAXIMUM_SIZE:
set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
@ -283,7 +293,8 @@ async def build_chunks(task, progress_callback):
try: try:
d = copy.deepcopy(document) d = copy.deepcopy(document)
d.update(chunk) d.update(chunk)
d["id"] = xxhash.xxh64((chunk["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest() d["id"] = xxhash.xxh64(
(chunk["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
d["create_time"] = str(datetime.now()).replace("T", " ")[:19] d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
d["create_timestamp_flt"] = datetime.now().timestamp() d["create_timestamp_flt"] = datetime.now().timestamp()
if not d.get("image"): if not d.get("image"):
@@ -328,9 +339,11 @@ async def build_chunks(task, progress_callback):
d["important_kwd"] = cached.split(",")
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
return
tasks = []
for d in docs:
- tasks.append(asyncio.create_task(doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"])))
+ tasks.append(
+     asyncio.create_task(doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"])))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:

@@ -355,9 +368,11 @@ async def build_chunks(task, progress_callback):
if cached:
d["question_kwd"] = cached.split("\n")
d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
tasks = []
for d in docs:
- tasks.append(asyncio.create_task(doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"])))
+ tasks.append(
+     asyncio.create_task(doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"])))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
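Both hunks reuse the same fan-out pattern: one asyncio task per chunk, gathered at the end, with a shared limiter guarding the LLM calls. A generic sketch of that pattern (the semaphore size and worker body are placeholders, not the module's real objects):

import asyncio

chat_limiter = asyncio.Semaphore(10)  # stand-in for the module-level limiter

async def enrich(doc: dict) -> None:
    async with chat_limiter:
        await asyncio.sleep(0)                 # placeholder for the cached LLM call
        doc["important_kwd"] = ["example"]

async def enrich_all(docs: list[dict]) -> None:
    tasks = [asyncio.create_task(enrich(d)) for d in docs]
    await asyncio.gather(*tasks, return_exceptions=False)

# asyncio.run(enrich_all([{"content_with_weight": "..."}]))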
@@ -374,15 +389,18 @@ async def build_chunks(task, progress_callback):
chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
async def gen_metadata_task(chat_mdl, d):
- cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "metadata", task["parser_config"]["metadata"])
+ cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "metadata",
+     task["parser_config"]["metadata"])
if not cached:
async with chat_limiter:
cached = await gen_metadata(chat_mdl,
metadata_schema(task["parser_config"]["metadata"]),
d["content_with_weight"])
- set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "metadata", task["parser_config"]["metadata"])
+ set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "metadata",
+     task["parser_config"]["metadata"])
if cached:
d["metadata_obj"] = cached
tasks = []
for d in docs:
tasks.append(asyncio.create_task(gen_metadata_task(chat_mdl, d)))

@@ -430,7 +448,8 @@ async def build_chunks(task, progress_callback):
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return None
- if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
+ if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(
+     d[TAG_FLD]) > 0:
examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
else:
docs_to_tag.append(d)

@@ -438,7 +457,7 @@ async def build_chunks(task, progress_callback):
async def doc_content_tagging(chat_mdl, d, topn_tags):
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], all_tags, {"topn": topn_tags})
if not cached:
- picked_examples = random.choices(examples, k=2) if len(examples)>2 else examples
+ picked_examples = random.choices(examples, k=2) if len(examples) > 2 else examples
if not picked_examples:
picked_examples.append({"content": "This is an example", TAG_FLD: {'example': 1}})
async with chat_limiter:

@@ -454,6 +473,7 @@ async def build_chunks(task, progress_callback):
if cached:
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
d[TAG_FLD] = json.loads(cached)
tasks = []
for d in docs_to_tag:
tasks.append(asyncio.create_task(doc_content_tagging(chat_mdl, d, topn_tags)))

@@ -473,21 +493,22 @@ async def build_chunks(task, progress_callback):
def build_TOC(task, docs, progress_callback):
progress_callback(msg="Start to generate table of content ...")
chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
- docs = sorted(docs, key=lambda d:(
+ docs = sorted(docs, key=lambda d: (
d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
))
- toc: list[dict] = asyncio.run(run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl, progress_callback))
- logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' '))
+ toc: list[dict] = asyncio.run(
+     run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl, progress_callback))
+ logging.info("------------ T O C -------------\n" + json.dumps(toc, ensure_ascii=False, indent=' '))
ii = 0
while ii < len(toc):
try:
idx = int(toc[ii]["chunk_id"])
del toc[ii]["chunk_id"]
toc[ii]["ids"] = [docs[idx]["id"]]
- if ii == len(toc) -1:
+ if ii == len(toc) - 1:
break
- for jj in range(idx+1, int(toc[ii+1]["chunk_id"])+1):
+ for jj in range(idx + 1, int(toc[ii + 1]["chunk_id"]) + 1):
toc[ii]["ids"].append(docs[jj]["id"])
except Exception as e:
logging.exception(e)

@@ -499,7 +520,8 @@ def build_TOC(task, docs, progress_callback):
d["toc_kwd"] = "toc"
d["available_int"] = 0
d["page_num_int"] = [100000000]
- d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
+ d["id"] = xxhash.xxh64(
+     (d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
return d
return None

@@ -532,12 +554,12 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
@timeout(60)
def batch_encode(txts):
nonlocal mdl
- return mdl.encode([truncate(c, mdl.max_length-10) for c in txts])
+ return mdl.encode([truncate(c, mdl.max_length - 10) for c in txts])
cnts_ = np.array([])
for i in range(0, len(cnts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter:
- vts, c = await asyncio.to_thread(batch_encode, cnts[i : i + settings.EMBEDDING_BATCH_SIZE])
+ vts, c = await asyncio.to_thread(batch_encode, cnts[i: i + settings.EMBEDDING_BATCH_SIZE])
if len(cnts_) == 0:
cnts_ = vts
else:

@@ -545,7 +567,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
tk_count += c
callback(prog=0.7 + 0.2 * (i + 1) / len(cnts), msg="")
cnts = cnts_
- filename_embd_weight = parser_config.get("filename_embd_weight", 0.1) # due to the db support none value
+ filename_embd_weight = parser_config.get("filename_embd_weight", 0.1)  # due to the db support none value
if not filename_embd_weight:
filename_embd_weight = 0.1
title_w = float(filename_embd_weight)
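The embedding hunks keep the same behavior: texts are encoded in EMBEDDING_BATCH_SIZE batches under a limiter, the vectors are concatenated, and a missing or None filename_embd_weight falls back to 0.1. A sketch of that batching and fallback with a fake encoder (the batch size and encoder are assumptions, not the real model):

import numpy as np

BATCH = 16  # stand-in for settings.EMBEDDING_BATCH_SIZE

def fake_encode(texts):
    # Returns (vectors, token_count); the real code calls mdl.encode(...) in a worker thread.
    return np.ones((len(texts), 4)), sum(len(t.split()) for t in texts)

def embed_texts(texts, filename_embd_weight=None):
    title_w = float(filename_embd_weight or 0.1)  # None or 0 falls back to 0.1, as in the diff
    vecs, tokens = np.array([]), 0
    for i in range(0, len(texts), BATCH):
        vts, c = fake_encode(texts[i: i + BATCH])
        vecs = vts if len(vecs) == 0 else np.concatenate((vecs, vts), axis=0)
        tokens += c
    return vecs, tokens, title_w

# embed_texts(["chunk one", "chunk two"]) -> (2x4 array, token count, 0.1)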
@@ -588,7 +610,8 @@ async def run_dataflow(task: dict):
return
if not chunks:
- PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
+ PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id,
+     task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
return
embedding_token_consumption = chunks.get("embedding_token_consumption", 0)

@@ -610,25 +633,27 @@ async def run_dataflow(task: dict):
e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
embedding_id = kb.embd_id
embedding_model = LLMBundle(task["tenant_id"], LLMType.EMBEDDING, llm_name=embedding_id)
@timeout(60)
def batch_encode(txts):
nonlocal embedding_model
return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts])
vects = np.array([])
texts = [o.get("questions", o.get("summary", o["text"])) for o in chunks]
- delta = 0.20/(len(texts)//settings.EMBEDDING_BATCH_SIZE+1)
+ delta = 0.20 / (len(texts) // settings.EMBEDDING_BATCH_SIZE + 1)
prog = 0.8
for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter:
- vts, c = await asyncio.to_thread(batch_encode, texts[i : i + settings.EMBEDDING_BATCH_SIZE])
+ vts, c = await asyncio.to_thread(batch_encode, texts[i: i + settings.EMBEDDING_BATCH_SIZE])
if len(vects) == 0:
vects = vts
else:
vects = np.concatenate((vects, vts), axis=0)
embedding_token_consumption += c
prog += delta
- if i % (len(texts)//settings.EMBEDDING_BATCH_SIZE/100+1) == 1:
-     set_progress(task_id, prog=prog, msg=f"{i+1} / {len(texts)//settings.EMBEDDING_BATCH_SIZE}")
+ if i % (len(texts) // settings.EMBEDDING_BATCH_SIZE / 100 + 1) == 1:
+     set_progress(task_id, prog=prog, msg=f"{i + 1} / {len(texts) // settings.EMBEDDING_BATCH_SIZE}")
assert len(vects) == len(chunks)
for i, ck in enumerate(chunks):

@@ -636,10 +661,10 @@ async def run_dataflow(task: dict):
ck["q_%d_vec" % len(v)] = v
except Exception as e:
set_progress(task_id, prog=-1, msg=f"[ERROR]: {e}")
- PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
+ PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id,
+     task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
return
metadata = {}
for ck in chunks:
ck["doc_id"] = doc_id

@@ -686,15 +711,19 @@ async def run_dataflow(task: dict):
set_progress(task_id, prog=0.82, msg="[DOC Engine]:\nStart to index...")
e = await insert_es(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000))
if not e:
- PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
+ PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id,
+     task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
return
time_cost = timer() - start_ts
task_time_cost = timer() - task_start_ts
set_progress(task_id, prog=1., msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
- DocumentService.increment_chunk_num(doc_id, task_dataset_id, embedding_token_consumption, len(chunks), task_time_cost)
- logging.info("[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, task_time_cost))
- PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
+ DocumentService.increment_chunk_num(doc_id, task_dataset_id, embedding_token_consumption, len(chunks),
+     task_time_cost)
+ logging.info("[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption,
+     task_time_cost))
+ PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE,
+     dsl=str(pipeline))

@timeout(3600)
@@ -702,7 +731,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
fake_doc_id = GRAPH_RAPTOR_FAKE_DOC_ID
raptor_config = kb_parser_config.get("raptor", {})
- vctr_nm = "q_%d_vec"%vector_size
+ vctr_nm = "q_%d_vec" % vector_size
res = []
tk_count = 0

@@ -747,17 +776,17 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
for x, doc_id in enumerate(doc_ids):
chunks = []
for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
fields=["content_with_weight", vctr_nm],
sort_by_position=True):
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
await generate(chunks, doc_id)
- callback(prog=(x+1.)/len(doc_ids))
+ callback(prog=(x + 1.) / len(doc_ids))
else:
chunks = []
for doc_id in doc_ids:
for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
fields=["content_with_weight", vctr_nm],
sort_by_position=True):
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
await generate(chunks, fake_doc_id)

@@ -792,19 +821,22 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
mom_ck["available_int"] = 0
flds = list(mom_ck.keys())
for fld in flds:
- if fld not in ["id", "content_with_weight", "doc_id", "docnm_kwd", "kb_id", "available_int", "position_int"]:
+ if fld not in ["id", "content_with_weight", "doc_id", "docnm_kwd", "kb_id", "available_int",
+     "position_int"]:
del mom_ck[fld]
mothers.append(mom_ck)
for b in range(0, len(mothers), settings.DOC_BULK_SIZE):
- await asyncio.to_thread(settings.docStoreConn.insert,mothers[b:b + settings.DOC_BULK_SIZE],search.index_name(task_tenant_id),task_dataset_id,)
+ await asyncio.to_thread(settings.docStoreConn.insert, mothers[b:b + settings.DOC_BULK_SIZE],
+     search.index_name(task_tenant_id), task_dataset_id, )
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return False
for b in range(0, len(chunks), settings.DOC_BULK_SIZE):
- doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert,chunks[b:b + settings.DOC_BULK_SIZE],search.index_name(task_tenant_id),task_dataset_id,)
+ doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert, chunks[b:b + settings.DOC_BULK_SIZE],
+     search.index_name(task_tenant_id), task_dataset_id, )
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
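insert_es pushes parent ("mother") chunks and regular chunks to the document store in DOC_BULK_SIZE batches, off-loading each synchronous bulk call to a worker thread and checking for cancellation between batches. A stripped-down sketch of that loop (the insert function, batch size, and cancellation check are stand-ins):

import asyncio

DOC_BULK_SIZE = 4  # stand-in for settings.DOC_BULK_SIZE

async def bulk_insert(insert_fn, rows, index_name, dataset_id, is_canceled=lambda: False):
    for b in range(0, len(rows), DOC_BULK_SIZE):
        # insert_fn is synchronous (an ES/Infinity client call in the real code), so run it in a thread.
        await asyncio.to_thread(insert_fn, rows[b:b + DOC_BULK_SIZE], index_name, dataset_id)
        if is_canceled():
            return False
    return True

# asyncio.run(bulk_insert(lambda batch, idx, ds: None, list(range(10)), "ragflow_idx", "kb1"))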
@@ -821,7 +853,8 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
TaskService.update_chunk_ids(task_id, chunk_ids_str)
except DoesNotExist:
logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
- doc_store_result = await asyncio.to_thread(settings.docStoreConn.delete,{"id": chunk_ids},search.index_name(task_tenant_id),task_dataset_id,)
+ doc_store_result = await asyncio.to_thread(settings.docStoreConn.delete, {"id": chunk_ids},
+     search.index_name(task_tenant_id), task_dataset_id, )
tasks = []
for chunk_id in chunk_ids:
tasks.append(asyncio.create_task(delete_image(task_dataset_id, chunk_id)))

@@ -838,10 +871,14 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
return True

- @timeout(60*60*3, 1)
+ @timeout(60 * 60 * 3, 1)
async def do_handle_task(task):
task_type = task.get("task_type", "")
+ if task_type == "memory":
+     await handle_save_to_memory_task(task)
+     return
if task_type == "dataflow" and task.get("doc_id", "") == CANVAS_DEBUG_DOC_ID:
await run_dataflow(task)
return

@@ -914,7 +951,7 @@ async def do_handle_task(task):
},
}
)
- if not KnowledgebaseService.update_by_id(kb.id, {"parser_config":kb_parser_config}):
+ if not KnowledgebaseService.update_by_id(kb.id, {"parser_config": kb_parser_config}):
progress_callback(prog=-1.0, msg="Internal error: Invalid RAPTOR configuration")
return

@@ -943,7 +980,7 @@ async def do_handle_task(task):
doc_ids=task.get("doc_ids", []),
)
if fake_doc_ids := task.get("doc_ids", []):
- task_doc_id = fake_doc_ids[0] # use the first document ID to represent this task for logging purposes
+ task_doc_id = fake_doc_ids[0]  # use the first document ID to represent this task for logging purposes
# Either using graphrag or Standard chunking methods
elif task_type == "graphrag":
ok, kb = KnowledgebaseService.get_by_id(task_dataset_id)

@@ -968,11 +1005,10 @@ async def do_handle_task(task):
}
}
)
- if not KnowledgebaseService.update_by_id(kb.id, {"parser_config":kb_parser_config}):
+ if not KnowledgebaseService.update_by_id(kb.id, {"parser_config": kb_parser_config}):
progress_callback(prog=-1.0, msg="Internal error: Invalid GraphRAG configuration")
return
graphrag_conf = kb_parser_config.get("graphrag", {})
start_ts = timer()
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)

@@ -1028,8 +1064,8 @@ async def do_handle_task(task):
async def _maybe_insert_es(_chunks):
if has_canceled(task_id):
return True
- e = await insert_es(task_id, task_tenant_id, task_dataset_id, _chunks, progress_callback)
- return bool(e)
+ insert_result = await insert_es(task_id, task_tenant_id, task_dataset_id, _chunks, progress_callback)
+ return bool(insert_result)
try:
if not await _maybe_insert_es(chunks):

@@ -1079,13 +1115,12 @@ async def do_handle_task(task):
search.index_name(task_tenant_id),
task_dataset_id,
)
- except Exception:
+ except Exception as e:
logging.exception(
- f"Remove doc({task_doc_id}) from docStore failed when task({task_id}) canceled."
- )
+ f"Remove doc({task_doc_id}) from docStore failed when task({task_id}) canceled, exception: {e}")

async def handle_task():
global DONE_TASKS, FAILED_TASKS
redis_msg, task = await collect()
if not task:

@@ -1093,25 +1128,27 @@ async def handle_task():
return
task_type = task["task_type"]
- pipeline_task_type = TASK_TYPE_TO_PIPELINE_TASK_TYPE.get(task_type, PipelineTaskType.PARSE) or PipelineTaskType.PARSE
+ pipeline_task_type = TASK_TYPE_TO_PIPELINE_TASK_TYPE.get(task_type,
+     PipelineTaskType.PARSE) or PipelineTaskType.PARSE
+ task_id = task["id"]
try:
logging.info(f"handle_task begin for task {json.dumps(task)}")
CURRENT_TASKS[task["id"]] = copy.deepcopy(task)
await do_handle_task(task)
DONE_TASKS += 1
- CURRENT_TASKS.pop(task["id"], None)
+ CURRENT_TASKS.pop(task_id, None)
logging.info(f"handle_task done for task {json.dumps(task)}")
except Exception as e:
FAILED_TASKS += 1
- CURRENT_TASKS.pop(task["id"], None)
+ CURRENT_TASKS.pop(task_id, None)
try:
err_msg = str(e)
while isinstance(e, exceptiongroup.ExceptionGroup):
e = e.exceptions[0]
err_msg += ' -- ' + str(e)
- set_progress(task["id"], prog=-1, msg=f"[Exception]: {err_msg}")
- except Exception:
+ set_progress(task_id, prog=-1, msg=f"[Exception]: {err_msg}")
+ except Exception as e:
+     logging.exception(f"[Exception]: {str(e)}")
pass
logging.exception(f"handle_task got exception for task {json.dumps(task)}")
finally:
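When do_handle_task raises, handle_task flattens nested exception groups into one progress message by walking into the first sub-exception of each group. A minimal sketch of that flattening (assumes the exceptiongroup package the module already imports; on Python 3.11+ ExceptionGroup is built in):

import exceptiongroup

def flatten_error(e: BaseException) -> str:
    # Mirrors the loop in handle_task: append the first nested exception of each group.
    err_msg = str(e)
    while isinstance(e, exceptiongroup.ExceptionGroup):
        e = e.exceptions[0]
        err_msg += ' -- ' + str(e)
    return err_msg

# flatten_error(exceptiongroup.ExceptionGroup("batch failed", [ValueError("bad chunk")]))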
@@ -1119,7 +1156,9 @@ async def handle_task():
if task_type in ["graphrag", "raptor", "mindmap"]:
task_document_ids = task["doc_ids"]
if not task.get("dataflow_id", ""):
- PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id="", task_type=pipeline_task_type, fake_document_ids=task_document_ids)
+ PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id="",
+     task_type=pipeline_task_type,
+     fake_document_ids=task_document_ids)
redis_msg.ack()

@@ -1182,8 +1221,8 @@ async def report_status():
logging.info(f"{consumer_name} expired, removed")
REDIS_CONN.srem("TASKEXE", consumer_name)
REDIS_CONN.delete(consumer_name)
- except Exception:
-     logging.exception("report_status got exception")
+ except Exception as e:
+     logging.exception(f"report_status got exception: {e}")
finally:
redis_lock.release()
await asyncio.sleep(30)

@@ -1249,6 +1288,7 @@ async def main():
await asyncio.gather(report_task, return_exceptions=True)
logging.error("BUG!!! You should not reach here!!!")

if __name__ == "__main__":
faulthandler.enable()
init_root_logger(CONSUMER_NAME)


@@ -42,8 +42,10 @@ class RAGFlowAzureSpnBlob:
pass
try:
- credentials = ClientSecretCredential(tenant_id=self.tenant_id, client_id=self.client_id, client_secret=self.secret, authority=AzureAuthorityHosts.AZURE_CHINA)
- self.conn = FileSystemClient(account_url=self.account_url, file_system_name=self.container_name, credential=credentials)
+ credentials = ClientSecretCredential(tenant_id=self.tenant_id, client_id=self.client_id,
+     client_secret=self.secret, authority=AzureAuthorityHosts.AZURE_CHINA)
+ self.conn = FileSystemClient(account_url=self.account_url, file_system_name=self.container_name,
+     credential=credentials)
except Exception:
logging.exception("Fail to connect %s" % self.account_url)


@@ -25,7 +25,8 @@ from PIL import Image
test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3QwQ3AIBDAsIP9d25XIC+EZE8QZc18w5l9O+AlZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBT+IYAHHLHkdEgAAAABJRU5ErkJggg==" test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3QwQ3AIBDAsIP9d25XIC+EZE8QZc18w5l9O+AlZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBT+IYAHHLHkdEgAAAABJRU5ErkJggg=="
test_image = base64.b64decode(test_image_base64)

- async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="imagetemps"):
+ async def image2id(d: dict, storage_put_func: partial, objname: str, bucket: str = "imagetemps"):
import logging
from io import BytesIO
from rag.svr.task_executor import minio_limiter

@@ -74,7 +75,7 @@ async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="
del d["image"]

- def id2image(image_id:str|None, storage_get_func: partial):
+ def id2image(image_id: str | None, storage_get_func: partial):
if not image_id:
return
arr = image_id.split("-")


@@ -16,6 +16,8 @@
import logging
from common.crypto_utils import CryptoUtil
# from common.decorator import singleton

class EncryptedStorageWrapper:

@@ -240,6 +242,7 @@ class EncryptedStorageWrapper:
self.encryption_enabled = False
logging.info("Encryption disabled")

# Create singleton wrapper function
def create_encrypted_storage(storage_impl, algorithm=None, key=None, encryption_enabled=True):
"""


@@ -32,7 +32,6 @@ ATTEMPT_TIME = 2
@singleton
class ESConnection(ESConnectionBase):
"""
CRUD operations
"""

@@ -82,8 +81,9 @@ class ESConnection(ESConnectionBase):
vector_similarity_weight = 0.5
for m in match_expressions:
if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
- assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance(match_expressions[1],
-     MatchDenseExpr) and isinstance(
+ assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance(
+     match_expressions[1],
+     MatchDenseExpr) and isinstance(
match_expressions[2], FusionExpr)
weights = m.fusion_params["weights"]
vector_similarity_weight = get_float(weights.split(",")[1])

@@ -93,9 +93,9 @@ class ESConnection(ESConnectionBase):
if isinstance(minimum_should_match, float):
minimum_should_match = str(int(minimum_should_match * 100)) + "%"
bool_query.must.append(Q("query_string", fields=m.fields,
type="best_fields", query=m.matching_text,
minimum_should_match=minimum_should_match,
boost=1))
bool_query.boost = 1.0 - vector_similarity_weight
elif isinstance(m, MatchDenseExpr):

@@ -146,7 +146,7 @@ class ESConnection(ESConnectionBase):
for i in range(ATTEMPT_TIME):
try:
- #print(json.dumps(q, ensure_ascii=False))
+ # print(json.dumps(q, ensure_ascii=False))
res = self.es.search(index=index_names,
body=q,
timeout="600s",

@@ -220,13 +220,15 @@ class ESConnection(ESConnectionBase):
try:
self.es.update(index=index_name, id=chunk_id, script=f"ctx._source.remove(\"{k}\");")
except Exception:
- self.logger.exception(f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
+ self.logger.exception(
+     f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
try:
self.es.update(index=index_name, id=chunk_id, doc=doc)
return True
except Exception as e:
self.logger.exception(
- f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: " + str(e))
+ f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: " + str(
+     e))
break
return False


@@ -25,18 +25,23 @@ import PyPDF2
from docx import Document
import olefile

def _is_zip(h: bytes) -> bool:
return h.startswith(b"PK\x03\x04") or h.startswith(b"PK\x05\x06") or h.startswith(b"PK\x07\x08")

def _is_pdf(h: bytes) -> bool:
return h.startswith(b"%PDF-")

def _is_ole(h: bytes) -> bool:
return h.startswith(b"\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1")

def _sha10(b: bytes) -> str:
return hashlib.sha256(b).hexdigest()[:10]

def _guess_ext(b: bytes) -> str:
h = b[:8]
if _is_zip(h):

@@ -58,13 +63,14 @@ def _guess_ext(b: bytes) -> str:
return ".doc"
return ".bin"

# Try to extract the real embedded payload from OLE's Ole10Native
def _extract_ole10native_payload(data: bytes) -> bytes:
try:
pos = 0
if len(data) < 4:
return data
- _ = int.from_bytes(data[pos:pos+4], "little")
+ _ = int.from_bytes(data[pos:pos + 4], "little")
pos += 4
# filename/src/tmp (NUL-terminated ANSI)
for _ in range(3):

@@ -74,14 +80,15 @@ def _extract_ole10native_payload(data: bytes) -> bytes:
pos += 4
if pos + 4 > len(data):
return data
- size = int.from_bytes(data[pos:pos+4], "little")
+ size = int.from_bytes(data[pos:pos + 4], "little")
pos += 4
if pos + size <= len(data):
- return data[pos:pos+size]
+ return data[pos:pos + size]
except Exception:
pass
return data

def extract_embed_file(target: Union[bytes, bytearray]) -> List[Tuple[str, bytes]]:
"""
Only extract the 'first layer' of embedding, returning raw (filename, bytes).

@@ -163,7 +170,7 @@ def extract_links_from_docx(docx_bytes: bytes):
# Each relationship may represent a hyperlink, image, footer, etc.
for rel in document.part.rels.values():
if rel.reltype == (
"http://schemas.openxmlformats.org/officeDocument/2006/relationships/hyperlink"
):
links.add(rel.target_ref)

@@ -198,6 +205,8 @@ def extract_links_from_pdf(pdf_bytes: bytes):
_GLOBAL_SESSION: Optional[requests.Session] = None

def _get_session(headers: Optional[Dict[str, str]] = None) -> requests.Session:
"""Get or create a global reusable session."""
global _GLOBAL_SESSION

@@ -216,10 +225,10 @@ def _get_session(headers: Optional[Dict[str, str]] = None) -> requests.Session:
def extract_html(
url: str,
timeout: float = 60.0,
headers: Optional[Dict[str, str]] = None,
max_retries: int = 2,
) -> Tuple[Optional[bytes], Dict[str, str]]:
"""
Extract the full HTML page as raw bytes from a given URL.


@@ -28,7 +28,6 @@ from common.doc_store.infinity_conn_base import InfinityConnectionBase
@singleton
class InfinityConnection(InfinityConnectionBase):
"""
Dataframe and fields convert
"""

@@ -83,24 +82,23 @@ class InfinityConnection(InfinityConnectionBase):
tokens[0] = field
return "^".join(tokens)

"""
CRUD operations
"""

def search(
self,
select_fields: list[str],
highlight_fields: list[str],
condition: dict,
match_expressions: list[MatchExpr],
order_by: OrderByExpr,
offset: int,
limit: int,
index_names: str | list[str],
knowledgebase_ids: list[str],
agg_fields: list[str] | None = None,
rank_feature: dict | None = None,
) -> tuple[pd.DataFrame, int]:
"""
BUG: Infinity returns empty for a highlight field if the query string doesn't use that field.

@@ -159,7 +157,8 @@ class InfinityConnection(InfinityConnectionBase):
if table_found:
break
if not table_found:
- self.logger.error(f"No valid tables found for indexNames {index_names} and knowledgebaseIds {knowledgebase_ids}")
+ self.logger.error(
+     f"No valid tables found for indexNames {index_names} and knowledgebaseIds {knowledgebase_ids}")
return pd.DataFrame(), 0
for matchExpr in match_expressions:

@@ -280,7 +279,8 @@ class InfinityConnection(InfinityConnectionBase):
try:
table_instance = db_instance.get_table(table_name)
except Exception:
- self.logger.warning(f"Table not found: {table_name}, this dataset isn't created in Infinity. Maybe it is created in other document engine.")
+ self.logger.warning(
+     f"Table not found: {table_name}, this dataset isn't created in Infinity. Maybe it is created in other document engine.")
continue
kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunk_id}'").to_df()
self.logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}")

@@ -288,7 +288,9 @@ class InfinityConnection(InfinityConnectionBase):
self.connPool.release_conn(inf_conn)
res = self.concat_dataframes(df_list, ["id"])
fields = set(res.columns.tolist())
- for field in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "question_kwd", "question_tks","content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks"]:
+ for field in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "question_kwd",
+     "question_tks", "content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks",
+     "authors_sm_tks"]:
fields.add(field)
res_fields = self.get_fields(res, list(fields))
return res_fields.get(chunk_id, None)

@@ -379,7 +381,9 @@ class InfinityConnection(InfinityConnectionBase):
d[k] = "_".join(f"{num:08x}" for num in v)
else:
d[k] = v
- for k in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", "question_kwd", "question_tks"]:
+ for k in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "content_with_weight",
+     "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", "question_kwd",
+     "question_tks"]:
if k in d:
del d[k]

@@ -478,7 +482,8 @@ class InfinityConnection(InfinityConnectionBase):
del new_value[k]
else:
new_value[k] = v
- for k in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", "question_kwd", "question_tks"]:
+ for k in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "content_with_weight",
+     "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", "question_kwd", "question_tks"]:
if k in new_value:
del new_value[k]

@@ -502,7 +507,8 @@ class InfinityConnection(InfinityConnectionBase):
self.logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {new_value}.")
for update_kv, ids in remove_opt.items():
k, v = json.loads(update_kv)
- table_instance.update(filter + " AND id in ({0})".format(",".join([f"'{id}'" for id in ids])), {k: "###".join(v)})
+ table_instance.update(filter + " AND id in ({0})".format(",".join([f"'{id}'" for id in ids])),
+     {k: "###".join(v)})
table_instance.update(filter, new_value)
self.connPool.release_conn(inf_conn)

@@ -561,7 +567,7 @@ class InfinityConnection(InfinityConnectionBase):
def to_position_int(v):
if v:
arr = [int(hex_val, 16) for hex_val in v.split("_")]
- v = [arr[i : i + 5] for i in range(0, len(arr), 5)]
+ v = [arr[i: i + 5] for i in range(0, len(arr), 5)]
else:
v = []
return v


@@ -46,6 +46,7 @@ class RAGFlowMinio:
# pass original identifier forward for use by other decorators
kwargs['_orig_bucket'] = original_bucket
return method(self, actual_bucket, *args, **kwargs)

return wrapper

@staticmethod

@@ -71,6 +72,7 @@ class RAGFlowMinio:
fnm = f"{orig_bucket}/{fnm}"
return method(self, bucket, fnm, *args, **kwargs)

return wrapper

def __open__(self):


@@ -37,7 +37,8 @@ from common import settings
from common.constants import PAGERANK_FLD, TAG_FLD
from common.decorator import singleton
from common.float_utils import get_float
- from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, MatchDenseExpr
+ from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, \
+     MatchDenseExpr
from rag.nlp import rag_tokenizer

ATTEMPT_TIME = 2

@@ -719,19 +720,19 @@ class OBConnection(DocStoreConnection):
"""
def search(
self,
selectFields: list[str],
highlightFields: list[str],
condition: dict,
matchExprs: list[MatchExpr],
orderBy: OrderByExpr,
offset: int,
limit: int,
indexNames: str | list[str],
knowledgebaseIds: list[str],
aggFields: list[str] = [],
rank_feature: dict | None = None,
**kwargs,
):
if isinstance(indexNames, str):
indexNames = indexNames.split(",")

@@ -1546,7 +1547,7 @@ class OBConnection(DocStoreConnection):
flags=re.IGNORECASE | re.MULTILINE,
)
if len(re.findall(r'</em><em>', highlighted_txt)) > 0 or len(
re.findall(r'</em>\s*<em>', highlighted_txt)) > 0:
return highlighted_txt
else:
return None

@@ -1565,9 +1566,9 @@ class OBConnection(DocStoreConnection):
if token_pos != -1:
if token in keywords:
highlighted_txt = (
highlighted_txt[:token_pos] +
f'<em>{token}</em>' +
highlighted_txt[token_pos + len(token):]
)
last_pos = token_pos
return re.sub(r'</em><em>', '', highlighted_txt)
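The highlighting helper finally collapses back-to-back <em> tags so adjacent highlighted tokens read as one span. A tiny standalone sketch of that post-processing step, using only the regex shown in the diff:

import re

def merge_adjacent_highlights(txt: str) -> str:
    # "<em>rag</em><em>flow</em>" -> "<em>ragflow</em>"
    return re.sub(r'</em><em>', '', txt)

# merge_adjacent_highlights("<em>rag</em><em>flow</em> engine")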


@@ -6,7 +6,6 @@ from urllib.parse import quote_plus
from common.config_utils import get_base_config
from common.decorator import singleton

CREATE_TABLE_SQL = """
CREATE TABLE IF NOT EXISTS `{}` (
`key` VARCHAR(255) PRIMARY KEY,

@@ -36,7 +35,8 @@ def get_opendal_config():
"table": opendal_config.get("config", {}).get("oss_table", "opendal_storage"),
"max_allowed_packet": str(max_packet)
}
- kwargs["connection_string"] = f"mysql://{kwargs['user']}:{quote_plus(kwargs['password'])}@{kwargs['host']}:{kwargs['port']}/{kwargs['database']}?max_allowed_packet={max_packet}"
+ kwargs[
+     "connection_string"] = f"mysql://{kwargs['user']}:{quote_plus(kwargs['password'])}@{kwargs['host']}:{kwargs['port']}/{kwargs['database']}?max_allowed_packet={max_packet}"
else:
scheme = opendal_config.get("scheme")
config_data = opendal_config.get("config", {})
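The MySQL connection string for OpenDAL is assembled from the base config, with quote_plus protecting special characters in the password. A small sketch of that assembly (the parameter names mirror the diff; the function itself is illustrative):

from urllib.parse import quote_plus

def mysql_connection_string(user, password, host, port, database, max_packet=1073741824):
    # quote_plus keeps characters like '@' or ':' in the password from breaking the URI.
    return (f"mysql://{user}:{quote_plus(password)}@{host}:{port}/"
            f"{database}?max_allowed_packet={max_packet}")

# mysql_connection_string("ragflow", "p@ss:word", "localhost", 3306, "opendal_storage")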
@@ -99,7 +99,6 @@ class OpenDALStorage:
def obj_exist(self, bucket, fnm, tenant_id=None):
return self._operator.exists(f"{bucket}/{fnm}")

def init_db_config(self):
try:
conn = pymysql.connect(


@@ -26,7 +26,8 @@ from opensearchpy import UpdateByQuery, Q, Search, Index
from opensearchpy import ConnectionTimeout
from common.decorator import singleton
from common.file_utils import get_project_base_directory
- from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, FusionExpr
+ from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
+     FusionExpr
from rag.nlp import is_english, rag_tokenizer
from common.constants import PAGERANK_FLD, TAG_FLD
from common import settings

@@ -216,7 +217,7 @@ class OSConnection(DocStoreConnection):
if bqry:
s = s.query(bqry)
for field in highlightFields:
- s = s.highlight(field,force_source=True,no_match_size=30,require_field_match=False)
+ s = s.highlight(field, force_source=True, no_match_size=30, require_field_match=False)
if orderBy:
orders = list()

@@ -242,7 +243,7 @@ class OSConnection(DocStoreConnection):
if use_knn:
del q["query"]
- q["query"] = {"knn" : knn_query}
+ q["query"] = {"knn": knn_query}
for i in range(ATTEMPT_TIME):
try:

@@ -328,7 +329,7 @@ class OSConnection(DocStoreConnection):
chunkId = condition["id"]
for i in range(ATTEMPT_TIME):
try:
- self.os.update(index=indexName, id=chunkId, body={"doc":doc})
+ self.os.update(index=indexName, id=chunkId, body={"doc": doc})
return True
except Exception as e:
logger.exception(

@@ -435,7 +436,7 @@ class OSConnection(DocStoreConnection):
logger.debug("OSConnection.delete query: " + json.dumps(qry.to_dict()))
for _ in range(ATTEMPT_TIME):
try:
- #print(Search().query(qry).to_dict(), flush=True)
+ # print(Search().query(qry).to_dict(), flush=True)
res = self.os.delete_by_query(
index=indexName,
body=Search().query(qry).to_dict(),


@@ -42,6 +42,7 @@ class RAGFlowOSS:
# If there is a default bucket, use the default bucket
actual_bucket = self.bucket if self.bucket else bucket
return method(self, actual_bucket, *args, **kwargs)

return wrapper

@staticmethod

@@ -50,6 +51,7 @@ class RAGFlowOSS:
# If the prefix path is set, use the prefix path
fnm = f"{self.prefix_path}/{fnm}" if self.prefix_path else fnm
return method(self, bucket, fnm, *args, **kwargs)

return wrapper

def __open__(self):

@@ -171,4 +173,3 @@ class RAGFlowOSS:
self.__open__()
time.sleep(1)
return None

Some files were not shown because too many files have changed in this diff.