Compare commits

...

46 Commits

Author SHA1 Message Date
0db00f70b2 Fix: add describe_image_with_prompt for ZHIPU AI (#11317)
### What problem does this PR solve?

Fix: add describe_image_with_prompt for ZHIPU AI  #11289 

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-18 13:09:39 +08:00
701761d119 Feat: Fixed the issue where form data assigned by variables was not updated in real time. #10427 (#11333)
### What problem does this PR solve?

Feat: Fixed the issue where form data assigned by variables was not
updated in real time. #10427
### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-11-18 13:07:52 +08:00
2993fc666b Feat: update version to 0.22.1 (#11331)
### What problem does this PR solve?

Update version to 0.22.1

### Type of change

- [x] Documentation Update
2025-11-18 10:49:36 +08:00
8a6d205df0 fix: entrypoint.sh typo for disable datasync command (#11326)
### What problem does this PR solve?

There's a typo in `entrypoint.sh` on line 74: the case statement uses
`--disable-datasyn)` (missing the 'c'), while the usage function and
documentation correctly show `--disable-datasync` (with the 'c'). This
mismatch causes the `--disable-datasync` flag to be unrecognized,
triggering the usage message and causing containers to restart in a loop
when this flag is used.

**Background:**
- Users following the documentation use `--disable-datasync` in their
docker-compose.yml
- The entrypoint script doesn't recognize this flag due to the typo
- The script calls `usage()` and exits, causing Docker containers to
restart continuously
- This makes it impossible to disable the data sync service as intended

**Example scenario:**
When a user adds `--disable-datasync` to their docker-compose command
(as shown in examples), the container fails to start properly because
the argument isn't recognized.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
### Proposed Solution

Fix the typo on line 74 of `entrypoint.sh` by changing:
```bash
    --disable-datasyn)
```
to:
```bash
    --disable-datasync)
```

This matches the spelling used in the usage function (line 9 and 13) and
allows the flag to work as documented.

### Changes Made

- Fixed typo in `entrypoint.sh` line 74: changed `--disable-datasyn)` to
`--disable-datasync)`
- This ensures the argument matches the documented flag name and usage
function

---

**Code change:**

```bash
# Line 74 in entrypoint.sh
# Before:
    --disable-datasyn)
      ENABLE_DATASYNC=0
      shift
      ;;

# After:
    --disable-datasync)
      ENABLE_DATASYNC=0
      shift
      ;;
```

This is a simple one-character fix that resolves the argument parsing
issue.
2025-11-18 10:28:00 +08:00
912b6b023e fix: update check_embedding failed info (#11321)
### What problem does this PR solve?
change:
update check_embedding failed info

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-18 09:39:45 +08:00
89e8818dda Feat: add s3-compatible storage boxes (#11313)
### What problem does this PR solve?

PR for implementing s3 compatible storage units #11240 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-11-18 09:39:25 +08:00
1dba6b5bf9 Fix: Fixed an issue where adding session variables multiple times would overwrite them. (#11308)
### What problem does this PR solve?

Fix: Fixed an issue where adding session variables multiple times would
overwrite them.
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-18 09:39:02 +08:00
3fcf2ee54c feat: add new LLM provider Jiekou.AI (#11300)
### What problem does this PR solve?

_Briefly describe what this PR aims to solve. Include background context
that will help reviewers understand the purpose of the PR._

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: Jason <ggbbddjm@gmail.com>
2025-11-17 19:47:46 +08:00
d8f413a885 Feat: Construct a dynamic variable assignment form #10427 (#11316)
### What problem does this PR solve?

Feat: Construct a dynamic variable assignment form #10427

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-11-17 19:45:58 +08:00
7264fb6978 Fix: concat images in word document. (#11310)
### What problem does this PR solve?

Fix: concat images in word document. Partially solved issues in #11063 

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-17 19:38:26 +08:00
bd4bc57009 Refactor: move mcp connection utilities to common (#11304)
### What problem does this PR solve?

As title

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-11-17 15:34:17 +08:00
0569b50fed Fix: create dataset return type inconsistent (#11272)
### What problem does this PR solve?

Fix: create dataset return type inconsistent #11167 
 
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-17 15:27:19 +08:00
6b64641042 Fix: default model base url extraction logic (#11263)
### What problem does this PR solve?

Fixes an issue where default models which used the same factory but
different base URLs would all be initialised with the default chat
model's base URL and would ignore e.g. the embedding model's base URL
config.

For example, with the following service config, the embedding and
reranker models would end up using the base URL for the default chat
model (i.e. `llm1.example.com`):

```yaml
ragflow:
  service_conf:
    user_default_llm:
      factory: OpenAI-API-Compatible
      api_key: not-used
      default_models:
        chat_model:
          name: llm1
          base_url: https://llm1.example.com/v1
        embedding_model:
          name: llm2
          base_url: https://llm2.example.com/v1
        rerank_model:
          name: llm3
          base_url: https://llm3.example.com/v1/rerank

  llm_factories:
    factory_llm_infos:
    - name: OpenAI-API-Compatible
      logo: ""
      tags: "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION"
      status: "1"
      llm:
        - llm_name: llm1
          base_url: 'https://llm1.example.com/v1'
          api_key: not-used
          tags: "LLM,CHAT,IMAGE2TEXT"
          max_tokens: 100000
          model_type: chat
          is_tools: false

        - llm_name: llm2
          base_url: https://llm2.example.com/v1
          api_key: not-used
          tags: "TEXT EMBEDDING"
          max_tokens: 10000
          model_type: embedding

        - llm_name: llm3
          base_url: https://llm3.example.com/v1/rerank
          api_key: not-used
          tags: "RERANK,1k"
          max_tokens: 10000
          model_type: rerank
```

### Type of change

- [X] Bug Fix (non-breaking change which fixes an issue)
2025-11-17 14:21:27 +08:00
9cef3a2625 Fix: Fixed the issue of not being able to select the time zone in the user center. (#11298)
… user center.

### What problem does this PR solve?

Fix: Fixed the issue of not being able to select the time zone in the
user center.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-17 11:16:55 +08:00
e7e89d3ecb Doc: style fix (#11295)
### What problem does this PR solve?

Style fix based on  #11283
### Type of change

- [x] Documentation Update
2025-11-17 11:16:34 +08:00
13e212c856 Feat: add Jira connector (#11285)
### What problem does this PR solve?

Add Jira connector.

<img width="978" height="925" alt="image"
src="https://github.com/user-attachments/assets/78bb5c77-2710-4569-a76e-9087ca23b227"
/>

---

<img width="1903" height="489" alt="image"
src="https://github.com/user-attachments/assets/193bc5c5-f751-4bd5-883a-2173282c2b96"
/>

---

<img width="1035" height="925" alt="image"
src="https://github.com/user-attachments/assets/1a0aec19-30eb-4ada-9283-61d1c915f59d"
/>

---

<img width="1905" height="601" alt="image"
src="https://github.com/user-attachments/assets/3dde1062-3f27-4717-8e09-fd5fd5e64171"
/>

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-11-17 09:38:04 +08:00
61cf430dbb Minor tweats (#11271)
### What problem does this PR solve?

As title.

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-11-16 19:29:20 +08:00
e841b09d63 Remove unused code and fix performance issue (#11284)
### What problem does this PR solve?

1. remove redundant code
2. fix miner performance issue

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-11-14 20:39:54 +08:00
b1a1eedf53 Doc: add default username & pwd (#11283)
### What problem does this PR solve?
Doc: add default username & pwd

### Type of change

- [x] Documentation Update

---------

Co-authored-by: writinwaters <93570324+writinwaters@users.noreply.github.com>
2025-11-14 19:52:58 +08:00
68e3b33ae4 Feat: extract message output to file (#11251)
### What problem does this PR solve?

Feat: extract message output to file

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-11-14 19:52:11 +08:00
cd55f6c1b8 Fix:ListOperations does not support sorting arrays of objects. (#11278)
### What problem does this PR solve?

pr:
#11276
change:
ListOperations does not support sorting arrays of objects.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-14 19:50:29 +08:00
996b5fe14e Fix: Added the ability to download files in the agent message reply function. (#11281)
### What problem does this PR solve?

Fix: Added the ability to download files in the agent message reply
function.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-14 19:50:01 +08:00
db4fd19c82 Feat:new component list operations (#11276)
### What problem does this PR solve?
issue:
https://github.com/infiniflow/ragflow/issues/10427
change:
new component list operations

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-11-14 16:33:20 +08:00
12db62b9c7 Refactor: improve mineru_parser get property logic (#11268)
### What problem does this PR solve?

improve mineru_parser get property logic

### Type of change

- [x] Refactoring
2025-11-14 16:32:35 +08:00
b5f2cf16bc Fix: check task executor alive and display status (#11270)
### What problem does this PR solve?

Correctly check task executor alive and display status.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-14 15:52:28 +08:00
e27ff8d3d4 Fix: rerank algorithm (#11266)
### What problem does this PR solve?

Fix: rerank algorithm #11234

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-14 13:59:54 +08:00
5f59418aba Remove leftover account and password from the code (#11248)
Remove legacy accounts and passwords.

### What problem does this PR solve?

Remove leftover account and password in
agent/templates/sql_assistant.json

### Type of change

- [x] Other (please describe):
2025-11-14 13:59:03 +08:00
87e69868c0 Fixes: Added session variable types and modified configuration (#11269)
### What problem does this PR solve?

Fixes: Added session variable types and modified configuration

- Added more types of session variables
- Modified the embedding model switching logic in the knowledge base
configuration

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-14 13:56:56 +08:00
72c20022f6 Refactor service config fetching in admin server (#11267)
### What problem does this PR solve?

As title

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
Co-authored-by: Zhichang Yu <yuzhichang@gmail.com>
2025-11-14 12:32:08 +08:00
3f2472f1b9 Skip checking python comments 2025-11-14 11:59:15 +08:00
1d4d67daf8 Fix check_comment_ascii.py 2025-11-14 11:45:32 +08:00
7538e218a5 Fix check_comment_ascii.py 2025-11-14 11:32:55 +08:00
6b52f7df5a CI check comments of cheanged Python files 2025-11-14 10:54:07 +08:00
63131ec9b2 Docs: default admin credentials (#11260)
### What problem does this PR solve?

### Type of change

- [x] Documentation Update
2025-11-14 09:35:56 +08:00
e8f1a245a6 Feat:update check_embedding api (#11254)
### What problem does this PR solve?
pr: 
#10854
change:
update check_embedding api

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-11-13 18:48:25 +08:00
908450509f Feat: add fault-tolerant mechanism to RAPTOR (#11206)
### What problem does this PR solve?

Add fault-tolerant mechanism to RAPTOR.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-11-13 18:48:07 +08:00
70a0f081f6 Minor tweaks (#11249)
### What problem does this PR solve?

Fix some IDE warnings

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-11-13 16:11:07 +08:00
93422fa8cc Fix: Law parser (#11246)
### What problem does this PR solve?

Fix: Law parser
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-13 15:19:02 +08:00
bfc84ba95b Test: handle duplicate names by appending "(1)" (#11244)
### What problem does this PR solve?

- Updated tests to reflect new behavior of handling duplicate dataset
names
- Instead of returning an error, the system now appends "(1)" to
duplicate names
- This problem was introduced by PR #10960

### Type of change

- [x] Testcase update
2025-11-13 15:18:32 +08:00
871055b0fc Feat:support API for generating knowledge graph and raptor (#11229)
### What problem does this PR solve?
issue:
[#11195](https://github.com/infiniflow/ragflow/issues/11195)
change:
support API for generating knowledge graph and raptor

### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- [x] Documentation Update
2025-11-13 15:17:52 +08:00
ba71160b14 Refa: rm useless code. (#11238)
### Type of change

- [x] Refactoring
2025-11-13 09:59:55 +08:00
bd5dda6b10 Feature/doc upload api add parent path 20251112 (#11231)
### What problem does this PR solve?

Add the specified parent_path to the document upload api interface
(#11230)

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: virgilwong <hyhvirgil@gmail.com>
2025-11-13 09:59:39 +08:00
774563970b Fix: update readme (#11212)
### What problem does this PR solve?

Continue update readme #11167 

### Type of change

- [x] Documentation Update
2025-11-13 09:50:47 +08:00
83d84e90ed Fix: Profile picture cropping supported #10703 (#11221)
### What problem does this PR solve?

Fix: Profile picture cropping supported

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-13 09:50:10 +08:00
8ef2f79d0a Fix:reset the agent component’s output (#11222)
### What problem does this PR solve?

change:
“After each dialogue turn, the agent component’s output is not reset.”

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-13 09:49:12 +08:00
296476ab89 Refactor function name (#11210)
### What problem does this PR solve?

As title

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-11-12 19:00:15 +08:00
166 changed files with 6379 additions and 1997 deletions

View File

@ -95,6 +95,38 @@ jobs:
version: ">=0.11.x"
args: "check"
- name: Check comments of changed Python files
if: ${{ false }}
run: |
if [[ ${{ github.event_name }} == 'pull_request_target' ]]; then
CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}...${{ github.event.pull_request.head.sha }} \
| grep -E '\.(py)$' || true)
if [ -n "$CHANGED_FILES" ]; then
echo "Check comments of changed Python files with check_comment_ascii.py"
readarray -t files <<< "$CHANGED_FILES"
HAS_ERROR=0
for file in "${files[@]}"; do
if [ -f "$file" ]; then
if python3 check_comment_ascii.py "$file"; then
echo "✅ $file"
else
echo "❌ $file"
HAS_ERROR=1
fi
fi
done
if [ $HAS_ERROR -ne 0 ]; then
exit 1
fi
else
echo "No Python files changed"
fi
fi
- name: Build ragflow:nightly
run: |
RUNNER_WORKSPACE_PREFIX=${RUNNER_WORKSPACE_PREFIX:-${HOME}}

View File

@ -51,7 +51,9 @@ RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \
apt install -y libjemalloc-dev && \
apt install -y python3-pip pipx nginx unzip curl wget git vim less && \
apt install -y ghostscript
apt install -y ghostscript && \
apt install -y pandoc && \
apt install -y texlive
RUN if [ "$NEED_MIRROR" == "1" ]; then \
pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \

View File

@ -192,9 +192,10 @@ releases! 🌟
```bash
$ cd ragflow/docker
# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases), e.g.: git checkout v0.22.0
# This steps ensures the **entrypoint.sh** file in the code matches the Docker image version.
# Use CPU for DeepDoc tasks:
$ docker compose -f docker-compose.yml up -d

View File

@ -192,6 +192,7 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io).
$ cd ragflow/docker
# Opsional: gunakan tag stabil (lihat releases: https://github.com/infiniflow/ragflow/releases), contoh: git checkout v0.22.0
# This steps ensures the **entrypoint.sh** file in the code matches the Docker image version.
# Use CPU for DeepDoc tasks:
$ docker compose -f docker-compose.yml up -d

View File

@ -172,6 +172,7 @@
$ cd ragflow/docker
# 任意: 安定版タグを利用 (一覧: https://github.com/infiniflow/ragflow/releases) 例: git checkout v0.22.0
# この手順は、コード内の entrypoint.sh ファイルが Docker イメージのバージョンと一致していることを確認します。
# Use CPU for DeepDoc tasks:
$ docker compose -f docker-compose.yml up -d

View File

@ -174,6 +174,7 @@
$ cd ragflow/docker
# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases), e.g.: git checkout v0.22.0
# 이 단계는 코드의 entrypoint.sh 파일이 Docker 이미지 버전과 일치하도록 보장합니다.
# Use CPU for DeepDoc tasks:
$ docker compose -f docker-compose.yml up -d

View File

@ -192,6 +192,7 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io).
$ cd ragflow/docker
# Opcional: use uma tag estável (veja releases: https://github.com/infiniflow/ragflow/releases), ex.: git checkout v0.22.0
# Esta etapa garante que o arquivo entrypoint.sh no código corresponda à versão da imagem do Docker.
# Use CPU for DeepDoc tasks:
$ docker compose -f docker-compose.yml up -d

View File

@ -191,6 +191,7 @@
$ cd ragflow/docker
# 可選使用穩定版標籤查看發佈https://github.com/infiniflow/ragflow/releasesgit checkout v0.22.0
# 此步驟確保程式碼中的 entrypoint.sh 檔案與 Docker 映像版本一致。
# Use CPU for DeepDoc tasks:
$ docker compose -f docker-compose.yml up -d

View File

@ -192,6 +192,7 @@
$ cd ragflow/docker
# 可选使用稳定版本标签查看发布https://github.com/infiniflow/ragflow/releases例如git checkout v0.22.0
# 这一步确保代码中的 entrypoint.sh 文件与 Docker 镜像的版本保持一致。
# Use CPU for DeepDoc tasks:
$ docker compose -f docker-compose.yml up -d

View File

@ -4,7 +4,7 @@
Admin Service is a dedicated management component designed to monitor, maintain, and administrate the RAGFlow system. It provides comprehensive tools for ensuring system stability, performing operational tasks, and managing users and permissions efficiently.
The service offers real-time monitoring of critical components, including the RAGFlow server, Task Executor processes, and dependent services such as MySQL, Elasticsearch, Redis, and MinIO. It automatically checks their health status, resource usage, and uptime, and performs restarts in case of failures to minimize downtime.
The service offers real-time monitoring of critical components, including the RAGFlow server, Task Executor processes, and dependent services such as MySQL, Infinity, Elasticsearch, Redis, and MinIO. It automatically checks their health status, resource usage, and uptime, and performs restarts in case of failures to minimize downtime.
For user and system management, it supports listing, creating, modifying, and deleting users and their associated resources like knowledge bases and Agents.

View File

@ -378,7 +378,7 @@ class AdminCLI(Cmd):
self.session.headers.update({
'Content-Type': 'application/json',
'Authorization': response.headers['Authorization'],
'User-Agent': 'RAGFlow-CLI/0.22.0'
'User-Agent': 'RAGFlow-CLI/0.22.1'
})
print("Authentication successful.")
return True
@ -393,7 +393,9 @@ class AdminCLI(Cmd):
print(f"Can't access {self.host}, port: {self.port}")
def _format_service_detail_table(self, data):
if not any([isinstance(v, list) for v in data.values()]):
if isinstance(data, list):
return data
if not all([isinstance(v, list) for v in data.values()]):
# normal table
return data
# handle task_executor heartbeats map, for example {'name': [{'done': 2, 'now': timestamp1}, {'done': 3, 'now': timestamp2}]
@ -404,7 +406,7 @@ class AdminCLI(Cmd):
task_executor_list.append({
"task_executor_name": k,
**heartbeats[0],
})
} if heartbeats else {"task_executor_name": k})
return task_executor_list
def _print_table_simple(self, data):
@ -415,7 +417,8 @@ class AdminCLI(Cmd):
# handle single row data
data = [data]
columns = list(data[0].keys())
columns = list(set().union(*(d.keys() for d in data)))
columns.sort()
col_widths = {}
def get_string_width(text):

View File

@ -1,6 +1,6 @@
[project]
name = "ragflow-cli"
version = "0.22.0"
version = "0.22.1"
description = "Admin Service's client of [RAGFlow](https://github.com/infiniflow/ragflow). The Admin Service provides user management and system monitoring. "
authors = [{ name = "Lynn", email = "lynn_inf@hotmail.com" }]
license = { text = "Apache License, Version 2.0" }

View File

@ -169,7 +169,7 @@ def login_verify(f):
username = auth.parameters['username']
password = auth.parameters['password']
try:
if check_admin(username, password) is False:
if not check_admin(username, password):
return jsonify({
"code": 500,
"message": "Access denied",

View File

@ -25,8 +25,21 @@ from common.config_utils import read_config
from urllib.parse import urlparse
class BaseConfig(BaseModel):
id: int
name: str
host: str
port: int
service_type: str
detail_func_name: str
def to_dict(self) -> dict[str, Any]:
return {'id': self.id, 'name': self.name, 'host': self.host, 'port': self.port,
'service_type': self.service_type}
class ServiceConfigs:
configs = dict
configs = list[BaseConfig]
def __init__(self):
self.configs = []
@ -45,19 +58,6 @@ class ServiceType(Enum):
FILE_STORE = "file_store"
class BaseConfig(BaseModel):
id: int
name: str
host: str
port: int
service_type: str
detail_func_name: str
def to_dict(self) -> dict[str, Any]:
return {'id': self.id, 'name': self.name, 'host': self.host, 'port': self.port,
'service_type': self.service_type}
class MetaConfig(BaseConfig):
meta_type: str
@ -227,7 +227,7 @@ def load_configurations(config_path: str) -> list[BaseConfig]:
ragflow_count = 0
id_count = 0
for k, v in raw_configs.items():
match (k):
match k:
case "ragflow":
name: str = f'ragflow_{ragflow_count}'
host: str = v['host']

View File

@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import re
from werkzeug.security import check_password_hash
from common.constants import ActiveEnum
@ -190,7 +189,8 @@ class ServiceMgr:
config_dict['status'] = service_detail['status']
else:
config_dict['status'] = 'timeout'
except Exception:
except Exception as e:
logging.warning(f"Can't get service details, error: {e}")
config_dict['status'] = 'timeout'
if not config_dict['host']:
config_dict['host'] = '-'
@ -205,17 +205,13 @@ class ServiceMgr:
@staticmethod
def get_service_details(service_id: int):
service_id = int(service_id)
service_idx = int(service_id)
configs = SERVICE_CONFIGS.configs
service_config_mapping = {
c.id: {
'name': c.name,
'detail_func_name': c.detail_func_name
} for c in configs
}
service_info = service_config_mapping.get(service_id, {})
if not service_info:
raise AdminException(f"invalid service_id: {service_id}")
if service_idx < 0 or service_idx >= len(configs):
raise AdminException(f"invalid service_index: {service_idx}")
service_config = configs[service_idx]
service_info = {'name': service_config.name, 'detail_func_name': service_config.detail_func_name}
detail_func = getattr(health_utils, service_info.get('detail_func_name'))
res = detail_func()

View File

@ -298,8 +298,6 @@ class Canvas(Graph):
for kk, vv in kwargs["webhook_payload"].items():
self.components[k]["obj"].set_output(kk, vv)
self.components[k]["obj"].reset(True)
for k in kwargs.keys():
if k in ["query", "user_id", "files"] and kwargs[k]:
if k == "files":
@ -408,6 +406,10 @@ class Canvas(Graph):
else:
yield decorate("message", {"content": cpn_obj.output("content")})
cite = re.search(r"\[ID:[ 0-9]+\]", cpn_obj.output("content"))
if isinstance(cpn_obj.output("attachment"), tuple):
yield decorate("message", {"attachment": cpn_obj.output("attachment")})
yield decorate("message_end", {"reference": self.get_reference() if cite else None})
while partials:

View File

@ -30,7 +30,7 @@ from api.db.services.mcp_server_service import MCPServerService
from common.connection_utils import timeout
from rag.prompts.generator import next_step, COMPLETE_TASK, analyze_task, \
citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
from agent.component.llm import LLMParam, LLM
@ -368,11 +368,19 @@ Respond immediately with your final comprehensive answer.
return "Error occurred."
def reset(self, temp=False):
def reset(self, only_output=False):
"""
Reset all tools if they have a reset method. This avoids errors for tools like MCPToolCallSession.
"""
for k in self._param.outputs.keys():
self._param.outputs[k]["value"] = None
for k, cpn in self.tools.items():
if hasattr(cpn, "reset") and callable(cpn.reset):
cpn.reset()
if only_output:
return
for k in self._param.inputs.keys():
self._param.inputs[k]["value"] = None
self._param.debug_inputs = {}

View File

@ -463,12 +463,15 @@ class ComponentBase(ABC):
return self._param.outputs.get("_ERROR", {}).get("value")
def reset(self, only_output=False):
for k in self._param.outputs.keys():
self._param.outputs[k]["value"] = None
outputs: dict = self._param.outputs # for better performance
for k in outputs.keys():
outputs[k]["value"] = None
if only_output:
return
for k in self._param.inputs.keys():
self._param.inputs[k]["value"] = None
inputs: dict = self._param.inputs # for better performance
for k in inputs.keys():
inputs[k]["value"] = None
self._param.debug_inputs = {}
def get_input(self, key: str=None) -> Union[Any, dict[str, Any]]:

View File

@ -0,0 +1,166 @@
from abc import ABC
import os
from agent.component.base import ComponentBase, ComponentParamBase
from api.utils.api_utils import timeout
class ListOperationsParam(ComponentParamBase):
"""
Define the List Operations component parameters.
"""
def __init__(self):
super().__init__()
self.query = ""
self.operations = "topN"
self.n=0
self.sort_method = "asc"
self.filter = {
"operator": "=",
"value": ""
}
self.outputs = {
"result": {
"value": [],
"type": "Array of ?"
},
"first": {
"value": "",
"type": "?"
},
"last": {
"value": "",
"type": "?"
}
}
def check(self):
self.check_empty(self.query, "query")
self.check_valid_value(self.operations, "Support operations", ["topN","head","tail","filter","sort","drop_duplicates"])
def get_input_form(self) -> dict[str, dict]:
return {}
class ListOperations(ComponentBase,ABC):
component_name = "ListOperations"
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs):
self.input_objects=[]
inputs = getattr(self._param, "query", None)
self.inputs=self._canvas.get_variable_value(inputs)
self.set_input_value(inputs, self.inputs)
if self._param.operations == "topN":
self._topN()
elif self._param.operations == "head":
self._head()
elif self._param.operations == "tail":
self._tail()
elif self._param.operations == "filter":
self._filter()
elif self._param.operations == "sort":
self._sort()
elif self._param.operations == "drop_duplicates":
self._drop_duplicates()
def _coerce_n(self):
try:
return int(getattr(self._param, "n", 0))
except Exception:
return 0
def _set_outputs(self, outputs):
self._param.outputs["result"]["value"] = outputs
self._param.outputs["first"]["value"] = outputs[0] if outputs else None
self._param.outputs["last"]["value"] = outputs[-1] if outputs else None
def _topN(self):
n = self._coerce_n()
if n < 1:
outputs = []
else:
n = min(n, len(self.inputs))
outputs = self.inputs[:n]
self._set_outputs(outputs)
def _head(self):
n = self._coerce_n()
if 1 <= n <= len(self.inputs):
outputs = [self.inputs[n - 1]]
else:
outputs = []
self._set_outputs(outputs)
def _tail(self):
n = self._coerce_n()
if 1 <= n <= len(self.inputs):
outputs = [self.inputs[-n]]
else:
outputs = []
self._set_outputs(outputs)
def _filter(self):
self._set_outputs([i for i in self.inputs if self._eval(self._norm(i),self._param.filter["operator"],self._param.filter["value"])])
def _norm(self,v):
s = "" if v is None else str(v)
return s
def _eval(self, v, operator, value):
if operator == "=":
return v == value
elif operator == "":
return v != value
elif operator == "contains":
return value in v
elif operator == "start with":
return v.startswith(value)
elif operator == "end with":
return v.endswith(value)
else:
return False
def _sort(self):
items = self.inputs or []
method = getattr(self._param, "sort_method", "asc") or "asc"
reverse = method == "desc"
if not items:
self._set_outputs([])
return
first = items[0]
if isinstance(first, dict):
outputs = sorted(
items,
key=lambda x: self._hashable(x),
reverse=reverse,
)
else:
outputs = sorted(items, reverse=reverse)
self._set_outputs(outputs)
def _drop_duplicates(self):
seen = set()
outs = []
for item in self.inputs:
k = self._hashable(item)
if k in seen:
continue
seen.add(k)
outs.append(item)
self._set_outputs(outs)
def _hashable(self,x):
if isinstance(x, dict):
return tuple(sorted((k, self._hashable(v)) for k, v in x.items()))
if isinstance(x, (list, tuple)):
return tuple(self._hashable(v) for v in x)
if isinstance(x, set):
return tuple(sorted(self._hashable(v) for v in x))
return x
def thoughts(self) -> str:
return "ListOperation in progress"

View File

@ -222,7 +222,7 @@ class LLM(ComponentBase):
output_structure = self._param.outputs['structured']
except Exception:
pass
if output_structure:
if output_structure and isinstance(output_structure, dict) and output_structure.get("properties"):
schema=json.dumps(output_structure, ensure_ascii=False, indent=2)
prompt += structured_output_prompt(schema)
for _ in range(self._param.max_retries+1):

View File

@ -17,6 +17,9 @@ import json
import os
import random
import re
import pypandoc
import logging
import tempfile
from functools import partial
from typing import Any
@ -24,7 +27,8 @@ from agent.component.base import ComponentBase, ComponentParamBase
from jinja2 import Template as Jinja2Template
from common.connection_utils import timeout
from common.misc_utils import get_uuid
from common import settings
class MessageParam(ComponentParamBase):
"""
@ -34,6 +38,7 @@ class MessageParam(ComponentParamBase):
super().__init__()
self.content = []
self.stream = True
self.output_format = None # default output format
self.outputs = {
"content": {
"type": "str"
@ -133,6 +138,7 @@ class Message(ComponentBase):
yield rand_cnt[s: ]
self.set_output("content", all_content)
self._convert_content(all_content)
def _is_jinjia2(self, content:str) -> bool:
patt = [
@ -164,6 +170,68 @@ class Message(ComponentBase):
content = re.sub(n, v, content)
self.set_output("content", content)
self._convert_content(content)
def thoughts(self) -> str:
return ""
def _convert_content(self, content):
doc_id = get_uuid()
if self._param.output_format.lower() not in {"markdown", "html", "pdf", "docx"}:
self._param.output_format = "markdown"
try:
if self._param.output_format in {"markdown", "html"}:
if isinstance(content, str):
converted = pypandoc.convert_text(
content,
to=self._param.output_format,
format="markdown",
)
else:
converted = pypandoc.convert_file(
content,
to=self._param.output_format,
format="markdown",
)
binary_content = converted.encode("utf-8")
else: # pdf, docx
with tempfile.NamedTemporaryFile(suffix=f".{self._param.output_format}", delete=False) as tmp:
tmp_name = tmp.name
try:
if isinstance(content, str):
pypandoc.convert_text(
content,
to=self._param.output_format,
format="markdown",
outputfile=tmp_name,
)
else:
pypandoc.convert_file(
content,
to=self._param.output_format,
format="markdown",
outputfile=tmp_name,
)
with open(tmp_name, "rb") as f:
binary_content = f.read()
finally:
if os.path.exists(tmp_name):
os.remove(tmp_name)
settings.STORAGE_IMPL.put(self._canvas._tenant_id, doc_id, binary_content)
self.set_output("attachment", {
"doc_id":doc_id,
"format":self._param.output_format,
"file_name":f"{doc_id[:8]}.{self._param.output_format}"})
logging.info(f"Converted content uploaded as {doc_id} (format={self._param.output_format})")
except Exception as e:
logging.error(f"Error converting content to {self._param.output_format}: {e}")

View File

@ -83,10 +83,10 @@
"value": []
}
},
"password": "20010812Yy!",
"password": "",
"port": 3306,
"sql": "{Agent:WickedGoatsDivide@content}",
"username": "13637682833@163.com"
"username": ""
}
},
"upstream": [
@ -527,10 +527,10 @@
"value": []
}
},
"password": "20010812Yy!",
"password": "",
"port": 3306,
"sql": "{Agent:WickedGoatsDivide@content}",
"username": "13637682833@163.com"
"username": ""
},
"label": "ExeSQL",
"name": "ExeSQL"

View File

@ -21,9 +21,8 @@ from functools import partial
from typing import TypedDict, List, Any
from agent.component.base import ComponentParamBase, ComponentBase
from common.misc_utils import hash_str2int
from rag.llm.chat_model import ToolCallSession
from rag.prompts.generator import kb_prompt
from rag.utils.mcp_tool_call_conn import MCPToolCallSession
from common.mcp_tool_call_conn import MCPToolCallSession, ToolCallSession
from timeit import default_timer as timer

View File

@ -96,12 +96,12 @@ login_manager.init_app(app)
commands.register_commands(app)
def search_pages_path(pages_dir):
def search_pages_path(page_path):
app_path_list = [
path for path in pages_dir.glob("*_app.py") if not path.name.startswith(".")
path for path in page_path.glob("*_app.py") if not path.name.startswith(".")
]
api_path_list = [
path for path in pages_dir.glob("*sdk/*.py") if not path.name.startswith(".")
path for path in page_path.glob("*sdk/*.py") if not path.name.startswith(".")
]
app_path_list.extend(api_path_list)
return app_path_list
@ -138,7 +138,7 @@ pages_dir = [
]
client_urls_prefix = [
register_page(path) for dir in pages_dir for path in search_pages_path(dir)
register_page(path) for directory in pages_dir for path in search_pages_path(directory)
]
@ -177,5 +177,7 @@ def load_user(web_request):
@app.teardown_request
def _db_close(exc):
def _db_close(exception):
if exception:
logging.exception(f"Request failed: {exception}")
close_connection()

View File

@ -13,41 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import os
import re
from datetime import datetime, timedelta
from flask import request, Response
from api.db.services.llm_service import LLMBundle
from flask import request
from flask_login import login_required, current_user
from api.db import VALID_FILE_TYPES, FileType
from api.db.db_models import APIToken, Task, File
from api.db.services import duplicate_name
from api.db.db_models import APIToken
from api.db.services.api_service import APITokenService, API4ConversationService
from api.db.services.dialog_service import DialogService, chat
from api.db.services.document_service import DocumentService, doc_upload_and_parse
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import queue_tasks, TaskService
from api.db.services.user_service import UserTenantService
from common.misc_utils import get_uuid
from common.constants import RetCode, VALID_TASK_STATUS, LLMType, ParserType, FileSource
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
generate_confirmation_token
from api.utils.file_utils import filename_type, thumbnail
from rag.app.tag import label_question
from rag.prompts.generator import keyword_extraction
from common.time_utils import current_timestamp, datetime_format
from api.db.services.canvas_service import UserCanvasService
from agent.canvas import Canvas
from functools import partial
from pathlib import Path
from common import settings
@manager.route('/new_token', methods=['POST']) # noqa: F821
@login_required
@ -138,758 +113,3 @@ def stats():
except Exception as e:
return server_error_response(e)
@manager.route('/new_conversation', methods=['GET']) # noqa: F821
def set_conversation():
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
try:
if objs[0].source == "agent":
e, cvs = UserCanvasService.get_by_id(objs[0].dialog_id)
if not e:
return server_error_response("canvas not found.")
if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
canvas = Canvas(cvs.dsl, objs[0].tenant_id)
conv = {
"id": get_uuid(),
"dialog_id": cvs.id,
"user_id": request.args.get("user_id", ""),
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
"source": "agent"
}
API4ConversationService.save(**conv)
return get_json_result(data=conv)
else:
e, dia = DialogService.get_by_id(objs[0].dialog_id)
if not e:
return get_data_error_result(message="Dialog not found")
conv = {
"id": get_uuid(),
"dialog_id": dia.id,
"user_id": request.args.get("user_id", ""),
"message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]
}
API4ConversationService.save(**conv)
return get_json_result(data=conv)
except Exception as e:
return server_error_response(e)
@manager.route('/completion', methods=['POST']) # noqa: F821
@validate_request("conversation_id", "messages")
def completion():
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
req = request.json
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
if not e:
return get_data_error_result(message="Conversation not found!")
if "quote" not in req:
req["quote"] = False
msg = []
for m in req["messages"]:
if m["role"] == "system":
continue
if m["role"] == "assistant" and not msg:
continue
msg.append(m)
if not msg[-1].get("id"):
msg[-1]["id"] = get_uuid()
message_id = msg[-1]["id"]
def fillin_conv(ans):
nonlocal conv, message_id
if not conv.reference:
conv.reference.append(ans["reference"])
else:
conv.reference[-1] = ans["reference"]
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
ans["id"] = message_id
def rename_field(ans):
reference = ans['reference']
if not isinstance(reference, dict):
return
for chunk_i in reference.get('chunks', []):
if 'docnm_kwd' in chunk_i:
chunk_i['doc_name'] = chunk_i['docnm_kwd']
chunk_i.pop('docnm_kwd')
try:
if conv.source == "agent":
stream = req.get("stream", True)
conv.message.append(msg[-1])
e, cvs = UserCanvasService.get_by_id(conv.dialog_id)
if not e:
return server_error_response("canvas not found.")
del req["conversation_id"]
del req["messages"]
if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
if not conv.reference:
conv.reference = []
conv.message.append({"role": "assistant", "content": "", "id": message_id})
conv.reference.append({"chunks": [], "doc_aggs": []})
final_ans = {"reference": [], "content": ""}
canvas = Canvas(cvs.dsl, objs[0].tenant_id)
canvas.messages.append(msg[-1])
canvas.add_user_input(msg[-1]["content"])
answer = canvas.run(stream=stream)
assert answer is not None, "Nothing. Is it over?"
if stream:
assert isinstance(answer, partial), "Nothing. Is it over?"
def sse():
nonlocal answer, cvs, conv
try:
for ans in answer():
for k in ans.keys():
final_ans[k] = ans[k]
ans = {"answer": ans["content"], "reference": ans.get("reference", [])}
fillin_conv(ans)
rename_field(ans)
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
ensure_ascii=False) + "\n\n"
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
canvas.history.append(("assistant", final_ans["content"]))
if final_ans.get("reference"):
canvas.reference.append(final_ans["reference"])
cvs.dsl = json.loads(str(canvas))
API4ConversationService.append_message(conv.id, conv.to_dict())
except Exception as e:
yield "data:" + json.dumps({"code": 500, "message": str(e),
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
resp = Response(sse(), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp
final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
if final_ans.get("reference"):
canvas.reference.append(final_ans["reference"])
cvs.dsl = json.loads(str(canvas))
result = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])}
fillin_conv(result)
API4ConversationService.append_message(conv.id, conv.to_dict())
rename_field(result)
return get_json_result(data=result)
# ******************For dialog******************
conv.message.append(msg[-1])
e, dia = DialogService.get_by_id(conv.dialog_id)
if not e:
return get_data_error_result(message="Dialog not found!")
del req["conversation_id"]
del req["messages"]
if not conv.reference:
conv.reference = []
conv.message.append({"role": "assistant", "content": "", "id": message_id})
conv.reference.append({"chunks": [], "doc_aggs": []})
def stream():
nonlocal dia, msg, req, conv
try:
for ans in chat(dia, msg, True, **req):
fillin_conv(ans)
rename_field(ans)
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
ensure_ascii=False) + "\n\n"
API4ConversationService.append_message(conv.id, conv.to_dict())
except Exception as e:
yield "data:" + json.dumps({"code": 500, "message": str(e),
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
if req.get("stream", True):
resp = Response(stream(), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp
answer = None
for ans in chat(dia, msg, **req):
answer = ans
fillin_conv(ans)
API4ConversationService.append_message(conv.id, conv.to_dict())
break
rename_field(answer)
return get_json_result(data=answer)
except Exception as e:
return server_error_response(e)
@manager.route('/conversation/<conversation_id>', methods=['GET']) # noqa: F821
# @login_required
def get_conversation(conversation_id):
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
try:
e, conv = API4ConversationService.get_by_id(conversation_id)
if not e:
return get_data_error_result(message="Conversation not found!")
conv = conv.to_dict()
if token != APIToken.query(dialog_id=conv['dialog_id'])[0].token:
return get_json_result(data=False, message='Authentication error: API key is invalid for this conversation_id!"',
code=RetCode.AUTHENTICATION_ERROR)
for referenct_i in conv['reference']:
if referenct_i is None or len(referenct_i) == 0:
continue
for chunk_i in referenct_i['chunks']:
if 'docnm_kwd' in chunk_i.keys():
chunk_i['doc_name'] = chunk_i['docnm_kwd']
chunk_i.pop('docnm_kwd')
return get_json_result(data=conv)
except Exception as e:
return server_error_response(e)
@manager.route('/document/upload', methods=['POST']) # noqa: F821
@validate_request("kb_name")
def upload():
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
kb_name = request.form.get("kb_name").strip()
tenant_id = objs[0].tenant_id
try:
e, kb = KnowledgebaseService.get_by_name(kb_name, tenant_id)
if not e:
return get_data_error_result(
message="Can't find this knowledgebase!")
kb_id = kb.id
except Exception as e:
return server_error_response(e)
if 'file' not in request.files:
return get_json_result(
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
file = request.files['file']
if file.filename == '':
return get_json_result(
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
root_folder = FileService.get_root_folder(tenant_id)
pf_id = root_folder["id"]
FileService.init_knowledgebase_docs(pf_id, tenant_id)
kb_root_folder = FileService.get_kb_folder(tenant_id)
kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])
try:
if DocumentService.get_doc_count(kb.tenant_id) >= int(os.environ.get('MAX_FILE_NUM_PER_USER', 8192)):
return get_data_error_result(
message="Exceed the maximum file number of a free user!")
filename = duplicate_name(
DocumentService.query,
name=file.filename,
kb_id=kb_id)
filetype = filename_type(filename)
if not filetype:
return get_data_error_result(
message="This type of file has not been supported yet!")
location = filename
while settings.STORAGE_IMPL.obj_exist(kb_id, location):
location += "_"
blob = request.files['file'].read()
settings.STORAGE_IMPL.put(kb_id, location, blob)
doc = {
"id": get_uuid(),
"kb_id": kb.id,
"parser_id": kb.parser_id,
"parser_config": kb.parser_config,
"created_by": kb.tenant_id,
"type": filetype,
"name": filename,
"location": location,
"size": len(blob),
"thumbnail": thumbnail(filename, blob),
"suffix": Path(filename).suffix.lstrip("."),
}
form_data = request.form
if "parser_id" in form_data.keys():
if request.form.get("parser_id").strip() in list(vars(ParserType).values())[1:-3]:
doc["parser_id"] = request.form.get("parser_id").strip()
if doc["type"] == FileType.VISUAL:
doc["parser_id"] = ParserType.PICTURE.value
if doc["type"] == FileType.AURAL:
doc["parser_id"] = ParserType.AUDIO.value
if re.search(r"\.(ppt|pptx|pages)$", filename):
doc["parser_id"] = ParserType.PRESENTATION.value
if re.search(r"\.(eml)$", filename):
doc["parser_id"] = ParserType.EMAIL.value
doc_result = DocumentService.insert(doc)
FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id)
except Exception as e:
return server_error_response(e)
if "run" in form_data.keys():
if request.form.get("run").strip() == "1":
try:
info = {"run": 1, "progress": 0, "progress_msg": "", "chunk_num": 0, "token_num": 0}
DocumentService.update_by_id(doc["id"], info)
# if str(req["run"]) == TaskStatus.CANCEL.value:
tenant_id = DocumentService.get_tenant_id(doc["id"])
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
# e, doc = DocumentService.get_by_id(doc["id"])
TaskService.filter_delete([Task.doc_id == doc["id"]])
e, doc = DocumentService.get_by_id(doc["id"])
doc = doc.to_dict()
doc["tenant_id"] = tenant_id
bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
queue_tasks(doc, bucket, name, 0)
except Exception as e:
return server_error_response(e)
return get_json_result(data=doc_result.to_json())
@manager.route('/document/upload_and_parse', methods=['POST']) # noqa: F821
@validate_request("conversation_id")
def upload_parse():
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
if 'file' not in request.files:
return get_json_result(
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
file_objs = request.files.getlist('file')
for file_obj in file_objs:
if file_obj.filename == '':
return get_json_result(
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id)
return get_json_result(data=doc_ids)
@manager.route('/list_chunks', methods=['POST']) # noqa: F821
# @login_required
def list_chunks():
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
req = request.json
try:
if "doc_name" in req.keys():
tenant_id = DocumentService.get_tenant_id_by_name(req['doc_name'])
doc_id = DocumentService.get_doc_id_by_doc_name(req['doc_name'])
elif "doc_id" in req.keys():
tenant_id = DocumentService.get_tenant_id(req['doc_id'])
doc_id = req['doc_id']
else:
return get_json_result(
data=False, message="Can't find doc_name or doc_id"
)
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
res = settings.retriever.chunk_list(doc_id, tenant_id, kb_ids)
res = [
{
"content": res_item["content_with_weight"],
"doc_name": res_item["docnm_kwd"],
"image_id": res_item["img_id"]
} for res_item in res
]
except Exception as e:
return server_error_response(e)
return get_json_result(data=res)
@manager.route('/get_chunk/<chunk_id>', methods=['GET']) # noqa: F821
# @login_required
def get_chunk(chunk_id):
from rag.nlp import search
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
try:
tenant_id = objs[0].tenant_id
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
if chunk is None:
return server_error_response(Exception("Chunk not found"))
k = []
for n in chunk.keys():
if re.search(r"(_vec$|_sm_|_tks|_ltks)", n):
k.append(n)
for n in k:
del chunk[n]
return get_json_result(data=chunk)
except Exception as e:
return server_error_response(e)
@manager.route('/list_kb_docs', methods=['POST']) # noqa: F821
# @login_required
def list_kb_docs():
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
req = request.json
tenant_id = objs[0].tenant_id
kb_name = req.get("kb_name", "").strip()
try:
e, kb = KnowledgebaseService.get_by_name(kb_name, tenant_id)
if not e:
return get_data_error_result(
message="Can't find this knowledgebase!")
kb_id = kb.id
except Exception as e:
return server_error_response(e)
page_number = int(req.get("page", 1))
items_per_page = int(req.get("page_size", 15))
orderby = req.get("orderby", "create_time")
desc = req.get("desc", True)
keywords = req.get("keywords", "")
status = req.get("status", [])
if status:
invalid_status = {s for s in status if s not in VALID_TASK_STATUS}
if invalid_status:
return get_data_error_result(
message=f"Invalid filter status conditions: {', '.join(invalid_status)}"
)
types = req.get("types", [])
if types:
invalid_types = {t for t in types if t not in VALID_FILE_TYPES}
if invalid_types:
return get_data_error_result(
message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}"
)
try:
docs, tol = DocumentService.get_by_kb_id(
kb_id, page_number, items_per_page, orderby, desc, keywords, status, types)
docs = [{"doc_id": doc['id'], "doc_name": doc['name']} for doc in docs]
return get_json_result(data={"total": tol, "docs": docs})
except Exception as e:
return server_error_response(e)
@manager.route('/document/infos', methods=['POST']) # noqa: F821
@validate_request("doc_ids")
def docinfos():
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
req = request.json
doc_ids = req["doc_ids"]
docs = DocumentService.get_by_ids(doc_ids)
return get_json_result(data=list(docs.dicts()))
@manager.route('/document', methods=['DELETE']) # noqa: F821
# @login_required
def document_rm():
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
tenant_id = objs[0].tenant_id
req = request.json
try:
doc_ids = DocumentService.get_doc_ids_by_doc_names(req.get("doc_names", []))
for doc_id in req.get("doc_ids", []):
if doc_id not in doc_ids:
doc_ids.append(doc_id)
if not doc_ids:
return get_json_result(
data=False, message="Can't find doc_names or doc_ids"
)
except Exception as e:
return server_error_response(e)
root_folder = FileService.get_root_folder(tenant_id)
pf_id = root_folder["id"]
FileService.init_knowledgebase_docs(pf_id, tenant_id)
errors = ""
docs = DocumentService.get_by_ids(doc_ids)
doc_dic = {}
for doc in docs:
doc_dic[doc.id] = doc
for doc_id in doc_ids:
try:
if doc_id not in doc_dic:
return get_data_error_result(message="Document not found!")
doc = doc_dic[doc_id]
tenant_id = DocumentService.get_tenant_id(doc_id)
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
if not DocumentService.remove_document(doc, tenant_id):
return get_data_error_result(
message="Database error (Document removal)!")
f2d = File2DocumentService.get_by_document_id(doc_id)
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
File2DocumentService.delete_by_document_id(doc_id)
settings.STORAGE_IMPL.rm(b, n)
except Exception as e:
errors += str(e)
if errors:
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
return get_json_result(data=True)
@manager.route('/completion_aibotk', methods=['POST']) # noqa: F821
@validate_request("Authorization", "conversation_id", "word")
def completion_faq():
import base64
req = request.json
token = req["Authorization"]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
if not e:
return get_data_error_result(message="Conversation not found!")
if "quote" not in req:
req["quote"] = True
msg = [{"role": "user", "content": req["word"]}]
if not msg[-1].get("id"):
msg[-1]["id"] = get_uuid()
message_id = msg[-1]["id"]
def fillin_conv(ans):
nonlocal conv, message_id
if not conv.reference:
conv.reference.append(ans["reference"])
else:
conv.reference[-1] = ans["reference"]
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
ans["id"] = message_id
try:
if conv.source == "agent":
conv.message.append(msg[-1])
e, cvs = UserCanvasService.get_by_id(conv.dialog_id)
if not e:
return server_error_response("canvas not found.")
if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
if not conv.reference:
conv.reference = []
conv.message.append({"role": "assistant", "content": "", "id": message_id})
conv.reference.append({"chunks": [], "doc_aggs": []})
final_ans = {"reference": [], "doc_aggs": []}
canvas = Canvas(cvs.dsl, objs[0].tenant_id)
canvas.messages.append(msg[-1])
canvas.add_user_input(msg[-1]["content"])
answer = canvas.run(stream=False)
assert answer is not None, "Nothing. Is it over?"
data_type_picture = {
"type": 3,
"url": "base64 content"
}
data = [
{
"type": 1,
"content": ""
}
]
final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
if final_ans.get("reference"):
canvas.reference.append(final_ans["reference"])
cvs.dsl = json.loads(str(canvas))
ans = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])}
data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"])
fillin_conv(ans)
API4ConversationService.append_message(conv.id, conv.to_dict())
chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])]
for chunk_idx in chunk_idxs[:1]:
if ans["reference"]["chunks"][chunk_idx]["img_id"]:
try:
bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
response = settings.STORAGE_IMPL.get(bkt, nm)
data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
data.append(data_type_picture)
break
except Exception as e:
return server_error_response(e)
response = {"code": 200, "msg": "success", "data": data}
return response
# ******************For dialog******************
conv.message.append(msg[-1])
e, dia = DialogService.get_by_id(conv.dialog_id)
if not e:
return get_data_error_result(message="Dialog not found!")
del req["conversation_id"]
if not conv.reference:
conv.reference = []
conv.message.append({"role": "assistant", "content": "", "id": message_id})
conv.reference.append({"chunks": [], "doc_aggs": []})
data_type_picture = {
"type": 3,
"url": "base64 content"
}
data = [
{
"type": 1,
"content": ""
}
]
ans = ""
for a in chat(dia, msg, stream=False, **req):
ans = a
break
data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"])
fillin_conv(ans)
API4ConversationService.append_message(conv.id, conv.to_dict())
chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])]
for chunk_idx in chunk_idxs[:1]:
if ans["reference"]["chunks"][chunk_idx]["img_id"]:
try:
bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
response = settings.STORAGE_IMPL.get(bkt, nm)
data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
data.append(data_type_picture)
break
except Exception as e:
return server_error_response(e)
response = {"code": 200, "msg": "success", "data": data}
return response
except Exception as e:
return server_error_response(e)
@manager.route('/retrieval', methods=['POST']) # noqa: F821
@validate_request("kb_id", "question")
def retrieval():
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
req = request.json
kb_ids = req.get("kb_id", [])
doc_ids = req.get("doc_ids", [])
question = req.get("question")
page = int(req.get("page", 1))
size = int(req.get("page_size", 30))
similarity_threshold = float(req.get("similarity_threshold", 0.2))
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
top = int(req.get("top_k", 1024))
highlight = bool(req.get("highlight", False))
try:
kbs = KnowledgebaseService.get_by_ids(kb_ids)
embd_nms = list(set([kb.embd_id for kb in kbs]))
if len(embd_nms) != 1:
return get_json_result(
data=False, message='Knowledge bases use different embedding models or does not exist."',
code=RetCode.AUTHENTICATION_ERROR)
embd_mdl = LLMBundle(kbs[0].tenant_id, LLMType.EMBEDDING, llm_name=kbs[0].embd_id)
rerank_mdl = None
if req.get("rerank_id"):
rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, llm_name=req["rerank_id"])
if req.get("keyword", False):
chat_mdl = LLMBundle(kbs[0].tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question)
ranks = settings.retriever.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl, highlight= highlight,
rank_feature=label_question(question, kbs))
for c in ranks["chunks"]:
c.pop("vector", None)
return get_json_result(data=ranks)
except Exception as e:
if str(e).find("not_found") > 0:
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
code=RetCode.DATA_ERROR)
return server_error_response(e)

View File

@ -426,7 +426,6 @@ def test_db_connect():
try:
import trino
import os
from trino.auth import BasicAuthentication
except Exception as e:
return server_error_response(f"Missing dependency 'trino'. Please install: pip install trino, detail: {e}")
@ -438,7 +437,7 @@ def test_db_connect():
auth = None
if http_scheme == "https" and req.get("password"):
auth = BasicAuthentication(req.get("username") or "ragflow", req["password"])
auth = trino.BasicAuthentication(req.get("username") or "ragflow", req["password"])
conn = trino.dbapi.connect(
host=req["host"],
@ -471,8 +470,8 @@ def test_db_connect():
@login_required
def getlistversion(canvas_id):
try:
list =sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"]*-1)
return get_json_result(data=list)
versions =sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"]*-1)
return get_json_result(data=versions)
except Exception as e:
return get_data_error_result(message=f"Error getting history files: {e}")

View File

@ -55,7 +55,6 @@ def set_connector():
"timeout_secs": int(req.get("timeout_secs", 60 * 29)),
"status": TaskStatus.SCHEDULE,
}
conn["status"] = TaskStatus.SCHEDULE
ConnectorService.save(**conn)
time.sleep(1)

View File

@ -85,7 +85,6 @@ def get():
if not e:
return get_data_error_result(message="Conversation not found!")
tenants = UserTenantService.query(user_id=current_user.id)
avatar = None
for tenant in tenants:
dialog = DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id)
if dialog and len(dialog) > 0:

View File

@ -154,15 +154,15 @@ def get_kb_names(kb_ids):
@login_required
def list_dialogs():
try:
diags = DialogService.query(
conversations = DialogService.query(
tenant_id=current_user.id,
status=StatusEnum.VALID.value,
reverse=True,
order_by=DialogService.model.create_time)
diags = [d.to_dict() for d in diags]
for d in diags:
d["kb_ids"], d["kb_names"] = get_kb_names(d["kb_ids"])
return get_json_result(data=diags)
conversations = [d.to_dict() for d in conversations]
for conversation in conversations:
conversation["kb_ids"], conversation["kb_names"] = get_kb_names(conversation["kb_ids"])
return get_json_result(data=conversations)
except Exception as e:
return server_error_response(e)

View File

@ -308,7 +308,7 @@ def get_filter():
@manager.route("/infos", methods=["POST"]) # noqa: F821
@login_required
def docinfos():
def doc_infos():
req = request.json
doc_ids = req["doc_ids"]
for doc_id in doc_ids:
@ -508,6 +508,7 @@ def get(doc_id):
ext = ext.group(1) if ext else None
if ext:
if doc.type == FileType.VISUAL.value:
content_type = CONTENT_TYPE_MAP.get(ext, f"image/{ext}")
else:
content_type = CONTENT_TYPE_MAP.get(ext, f"application/{ext}")
@ -517,6 +518,22 @@ def get(doc_id):
return server_error_response(e)
@manager.route("/download/<attachment_id>", methods=["GET"]) # noqa: F821
@login_required
def download_attachment(attachment_id):
try:
ext = request.args.get("ext", "markdown")
data = settings.STORAGE_IMPL.get(current_user.id, attachment_id)
# data = settings.STORAGE_IMPL.get("eb500d50bb0411f0907561d2782adda5", attachment_id)
response = flask.make_response(data)
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
return response
except Exception as e:
return server_error_response(e)
@manager.route("/change_parser", methods=["POST"]) # noqa: F821
@login_required
@validate_request("doc_id")
@ -544,6 +561,7 @@ def change_parser():
return get_data_error_result(message="Tenant not found!")
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
return None
try:
if "pipeline_id" in req and req["pipeline_id"] != "":

View File

@ -246,8 +246,8 @@ def rm():
try:
if file.location:
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
except Exception:
logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}")
except Exception as e:
logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}, error: {e}")
informs = File2DocumentService.get_by_file_id(file.id)
for inform in informs:

View File

@ -16,6 +16,7 @@
import json
import logging
import random
import re
from flask import request
from flask_login import login_required, current_user
@ -54,6 +55,10 @@ def create():
**req
)
code = req.get("code")
if code:
return get_data_error_result(code=code, message=req.get("message"))
try:
if not KnowledgebaseService.save(**req):
return get_data_error_result()
@ -731,6 +736,8 @@ def delete_kb_task():
def cancel_task(task_id):
REDIS_CONN.set(f"{task_id}-cancel", "x")
kb_task_id_field: str = ""
kb_task_finish_at: str = ""
match pipeline_task_type:
case PipelineTaskType.GRAPH_RAG:
kb_task_id_field = "graphrag_task_id"
@ -807,7 +814,7 @@ def check_embedding():
offset=0, limit=1,
indexNames=index_nm, knowledgebaseIds=[kb_id]
)
total = docStoreConn.getTotal(res0)
total = docStoreConn.get_total(res0)
if total <= 0:
return []
@ -824,7 +831,7 @@ def check_embedding():
offset=off, limit=1,
indexNames=index_nm, knowledgebaseIds=[kb_id]
)
ids = docStoreConn.getChunkIds(res1)
ids = docStoreConn.get_chunk_ids(res1)
if not ids:
continue
@ -845,8 +852,13 @@ def check_embedding():
"position_int": full_doc.get("position_int"),
"top_int": full_doc.get("top_int"),
"content_with_weight": full_doc.get("content_with_weight") or "",
"question_kwd": full_doc.get("question_kwd") or []
})
return out
def _clean(s: str) -> str:
s = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", s or "")
return s if s else "None"
req = request.json
kb_id = req.get("kb_id", "")
embd_id = req.get("embd_id", "")
@ -859,8 +871,10 @@ def check_embedding():
results, eff_sims = [], []
for ck in samples:
txt = (ck.get("content_with_weight") or "").strip()
if not txt:
title = ck.get("doc_name") or "Title"
txt_in = "\n".join(ck.get("question_kwd") or []) or ck.get("content_with_weight") or ""
txt_in = _clean(txt_in)
if not txt_in:
results.append({"chunk_id": ck["chunk_id"], "reason": "no_text"})
continue
@ -869,8 +883,16 @@ def check_embedding():
continue
try:
qv, _ = emb_mdl.encode_queries(txt)
sim = _cos_sim(qv, ck["vector"])
v, _ = emb_mdl.encode([title, txt_in])
sim_content = _cos_sim(v[1], ck["vector"])
title_w = 0.1
qv_mix = title_w * v[0] + (1 - title_w) * v[1]
sim_mix = _cos_sim(qv_mix, ck["vector"])
sim = sim_content
mode = "content_only"
if sim_mix > sim:
sim = sim_mix
mode = "title+content"
except Exception:
return get_error_data_result(message="embedding failure")
@ -892,9 +914,10 @@ def check_embedding():
"avg_cos_sim": round(float(np.mean(eff_sims)) if eff_sims else 0.0, 6),
"min_cos_sim": round(float(np.min(eff_sims)) if eff_sims else 0.0, 6),
"max_cos_sim": round(float(np.max(eff_sims)) if eff_sims else 0.0, 6),
"match_mode": mode,
}
if summary["avg_cos_sim"] > 0.99:
if summary["avg_cos_sim"] > 0.9:
return get_json_result(data={"summary": summary, "results": results})
return get_json_result(code=RetCode.NOT_EFFECTIVE, message="failed", data={"summary": summary, "results": results})
return get_json_result(code=RetCode.NOT_EFFECTIVE, message="Embedding model switch failed: the average similarity between old and new vectors is below 0.9, indicating incompatible vector spaces.", data={"summary": summary, "results": results})

View File

@ -25,7 +25,7 @@ from common.misc_utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
get_mcp_tools
from api.utils.web_utils import get_float, safe_json_parse
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
@manager.route("/list", methods=["POST"]) # noqa: F821

View File

@ -41,12 +41,12 @@ def list_agents(tenant_id):
return get_error_data_result("The agent doesn't exist.")
page_number = int(request.args.get("page", 1))
items_per_page = int(request.args.get("page_size", 30))
orderby = request.args.get("orderby", "update_time")
order_by = request.args.get("orderby", "update_time")
if request.args.get("desc") == "False" or request.args.get("desc") == "false":
desc = False
else:
desc = True
canvas = UserCanvasService.get_list(tenant_id, page_number, items_per_page, orderby, desc, id, title)
canvas = UserCanvasService.get_list(tenant_id, page_number, items_per_page, order_by, desc, id, title)
return get_result(data=canvas)

View File

@ -21,10 +21,11 @@ import json
from flask import request
from peewee import OperationalError
from api.db.db_models import File
from api.db.services.document_service import DocumentService
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID, TaskService
from api.db.services.user_service import TenantService
from common.constants import RetCode, FileSource, StatusEnum
from api.utils.api_utils import (
@ -118,7 +119,6 @@ def create(tenant_id):
req, err = validate_and_parse_json_request(request, CreateDatasetReq)
if err is not None:
return get_error_argument_result(err)
req = KnowledgebaseService.create_with_name(
name = req.pop("name", None),
tenant_id = tenant_id,
@ -144,7 +144,6 @@ def create(tenant_id):
ok, k = KnowledgebaseService.get_by_id(req["id"])
if not ok:
return get_error_data_result(message="Dataset created failed")
response_data = remap_dictionary_keys(k.to_dict())
return get_result(data=response_data)
except Exception as e:
@ -532,3 +531,157 @@ def delete_knowledge_graph(tenant_id, dataset_id):
search.index_name(kb.tenant_id), dataset_id)
return get_result(data=True)
@manager.route("/datasets/<dataset_id>/run_graphrag", methods=["POST"]) # noqa: F821
@token_required
def run_graphrag(tenant_id,dataset_id):
if not dataset_id:
return get_error_data_result(message='Lack of "Dataset ID"')
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
return get_result(
data=False,
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
)
ok, kb = KnowledgebaseService.get_by_id(dataset_id)
if not ok:
return get_error_data_result(message="Invalid Dataset ID")
task_id = kb.graphrag_task_id
if task_id:
ok, task = TaskService.get_by_id(task_id)
if not ok:
logging.warning(f"A valid GraphRAG task id is expected for Dataset {dataset_id}")
if task and task.progress not in [-1, 1]:
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Graph Task is already running.")
documents, _ = DocumentService.get_by_kb_id(
kb_id=dataset_id,
page_number=0,
items_per_page=0,
orderby="create_time",
desc=False,
keywords="",
run_status=[],
types=[],
suffix=[],
)
if not documents:
return get_error_data_result(message=f"No documents in Dataset {dataset_id}")
sample_document = documents[0]
document_ids = [document["id"] for document in documents]
task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}):
logging.warning(f"Cannot save graphrag_task_id for Dataset {dataset_id}")
return get_result(data={"graphrag_task_id": task_id})
@manager.route("/datasets/<dataset_id>/trace_graphrag", methods=["GET"]) # noqa: F821
@token_required
def trace_graphrag(tenant_id,dataset_id):
if not dataset_id:
return get_error_data_result(message='Lack of "Dataset ID"')
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
return get_result(
data=False,
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
)
ok, kb = KnowledgebaseService.get_by_id(dataset_id)
if not ok:
return get_error_data_result(message="Invalid Dataset ID")
task_id = kb.graphrag_task_id
if not task_id:
return get_result(data={})
ok, task = TaskService.get_by_id(task_id)
if not ok:
return get_result(data={})
return get_result(data=task.to_dict())
@manager.route("/datasets/<dataset_id>/run_raptor", methods=["POST"]) # noqa: F821
@token_required
def run_raptor(tenant_id,dataset_id):
if not dataset_id:
return get_error_data_result(message='Lack of "Dataset ID"')
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
return get_result(
data=False,
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
)
ok, kb = KnowledgebaseService.get_by_id(dataset_id)
if not ok:
return get_error_data_result(message="Invalid Dataset ID")
task_id = kb.raptor_task_id
if task_id:
ok, task = TaskService.get_by_id(task_id)
if not ok:
logging.warning(f"A valid RAPTOR task id is expected for Dataset {dataset_id}")
if task and task.progress not in [-1, 1]:
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A RAPTOR Task is already running.")
documents, _ = DocumentService.get_by_kb_id(
kb_id=dataset_id,
page_number=0,
items_per_page=0,
orderby="create_time",
desc=False,
keywords="",
run_status=[],
types=[],
suffix=[],
)
if not documents:
return get_error_data_result(message=f"No documents in Dataset {dataset_id}")
sample_document = documents[0]
document_ids = [document["id"] for document in documents]
task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
if not KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": task_id}):
logging.warning(f"Cannot save raptor_task_id for Dataset {dataset_id}")
return get_result(data={"raptor_task_id": task_id})
@manager.route("/datasets/<dataset_id>/trace_raptor", methods=["GET"]) # noqa: F821
@token_required
def trace_raptor(tenant_id,dataset_id):
if not dataset_id:
return get_error_data_result(message='Lack of "Dataset ID"')
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
return get_result(
data=False,
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
)
ok, kb = KnowledgebaseService.get_by_id(dataset_id)
if not ok:
return get_error_data_result(message="Invalid Dataset ID")
task_id = kb.raptor_task_id
if not task_id:
return get_result(data={})
ok, task = TaskService.get_by_id(task_id)
if not ok:
return get_error_data_result(message="RAPTOR Task Not Found or Error Occurred")
return get_result(data=task.to_dict())

View File

@ -93,6 +93,10 @@ def upload(dataset_id, tenant_id):
type: file
required: true
description: Document files to upload.
- in: formData
name: parent_path
type: string
description: Optional nested path under the parent folder. Uses '/' separators.
responses:
200:
description: Successfully uploaded documents.
@ -151,7 +155,7 @@ def upload(dataset_id, tenant_id):
e, kb = KnowledgebaseService.get_by_id(dataset_id)
if not e:
raise LookupError(f"Can't find the dataset with ID {dataset_id}!")
err, files = FileService.upload_document(kb, file_objs, tenant_id)
err, files = FileService.upload_document(kb, file_objs, tenant_id, parent_path=request.form.get("parent_path"))
if err:
return get_result(message="\n".join(err), code=RetCode.SERVER_ERROR)
# rename key's name

View File

@ -305,6 +305,7 @@ class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
time.sleep(self.retry_delay * (2 ** attempt))
else:
raise
return None
class RetryingPooledPostgresqlDatabase(PooledPostgresqlDatabase):
@ -772,7 +773,7 @@ class Document(DataBaseModel):
thumbnail = TextField(null=True, help_text="thumbnail base64 string")
kb_id = CharField(max_length=256, null=False, index=True)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", index=True)
pipeline_id = CharField(max_length=32, null=True, help_text="pipleline ID", index=True)
pipeline_id = CharField(max_length=32, null=True, help_text="pipeline ID", index=True)
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document come from", index=True)
type = CharField(max_length=32, null=False, help_text="file extension", index=True)
@ -876,7 +877,7 @@ class Dialog(DataBaseModel):
class Conversation(DataBaseModel):
id = CharField(max_length=32, primary_key=True)
dialog_id = CharField(max_length=32, null=False, index=True)
name = CharField(max_length=255, null=True, help_text="converastion name", index=True)
name = CharField(max_length=255, null=True, help_text="conversation name", index=True)
message = JSONField(null=True)
reference = JSONField(null=True, default=[])
user_id = CharField(max_length=255, null=True, help_text="user_id", index=True)

View File

@ -70,7 +70,7 @@ class ConnectorService(CommonService):
def rebuild(cls, kb_id:str, connector_id: str, tenant_id:str):
e, conn = cls.get_by_id(connector_id)
if not e:
return
return None
SyncLogsService.filter_delete([SyncLogs.connector_id==connector_id, SyncLogs.kb_id==kb_id])
docs = DocumentService.query(source_type=f"{conn.source}/{conn.id}", kb_id=kb_id)
err = FileService.delete_docs([d.id for d in docs], tenant_id)
@ -125,11 +125,11 @@ class SyncLogsService(CommonService):
)
query = query.distinct().order_by(cls.model.update_time.desc())
totbal = query.count()
total = query.count()
if page_number:
query = query.paginate(page_number, items_per_page)
return list(query.dicts()), totbal
return list(query.dicts()), total
@classmethod
def start(cls, id, connector_id):
@ -242,7 +242,7 @@ class Connector2KbService(CommonService):
"id": get_uuid(),
"connector_id": conn_id,
"kb_id": kb_id,
"auto_parse": conn.get("auto_parse", "1")
"auto_parse": conn.get("auto_parse", "1")
})
SyncLogsService.schedule(conn_id, kb_id, reindex=True)

View File

@ -342,7 +342,7 @@ def chat(dialog, messages, stream=True, **kwargs):
if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
for ans in chat_solo(dialog, messages, stream):
yield ans
return
return None
chat_start_ts = timer()
@ -386,7 +386,7 @@ def chat(dialog, messages, stream=True, **kwargs):
ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
if ans:
yield ans
return
return None
for p in prompt_config["parameters"]:
if p["key"] == "knowledge":
@ -617,6 +617,8 @@ def chat(dialog, messages, stream=True, **kwargs):
res["audio_binary"] = tts(tts_mdl, answer)
yield res
return None
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
sys_prompt = """
@ -745,7 +747,7 @@ Please write the SQL, only SQL, without any other explanations or text.
def tts(tts_mdl, text):
if not tts_mdl or not text:
return
return None
bin = b""
for chunk in tts_mdl.tts(text):
bin += chunk

View File

@ -113,7 +113,7 @@ class DocumentService(CommonService):
def check_doc_health(cls, tenant_id: str, filename):
import os
MAX_FILE_NUM_PER_USER = int(os.environ.get("MAX_FILE_NUM_PER_USER", 0))
if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(tenant_id) >= MAX_FILE_NUM_PER_USER:
if 0 < MAX_FILE_NUM_PER_USER <= DocumentService.get_doc_count(tenant_id):
raise RuntimeError("Exceed the maximum file number of a free user!")
if len(filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
raise RuntimeError("Exceed the maximum length of file name!")
@ -309,7 +309,7 @@ class DocumentService(CommonService):
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
page * page_size, page_size, search.index_name(tenant_id),
[doc.kb_id])
chunk_ids = settings.docStoreConn.getChunkIds(chunks)
chunk_ids = settings.docStoreConn.get_chunk_ids(chunks)
if not chunk_ids:
break
all_chunk_ids.extend(chunk_ids)
@ -322,7 +322,7 @@ class DocumentService(CommonService):
settings.STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail)
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
graph_source = settings.docStoreConn.getFields(
graph_source = settings.docStoreConn.get_fields(
settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
)
if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]:
@ -464,7 +464,7 @@ class DocumentService(CommonService):
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts()
if not docs:
return
return None
return docs[0]["tenant_id"]
@classmethod
@ -473,7 +473,7 @@ class DocumentService(CommonService):
docs = cls.model.select(cls.model.kb_id).where(cls.model.id == doc_id)
docs = docs.dicts()
if not docs:
return
return None
return docs[0]["kb_id"]
@classmethod
@ -486,7 +486,7 @@ class DocumentService(CommonService):
cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts()
if not docs:
return
return None
return docs[0]["tenant_id"]
@classmethod
@ -533,7 +533,7 @@ class DocumentService(CommonService):
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts()
if not docs:
return
return None
return docs[0]["embd_id"]
@classmethod
@ -569,7 +569,7 @@ class DocumentService(CommonService):
.where(cls.model.name == doc_name)
doc_id = doc_id.dicts()
if not doc_id:
return
return None
return doc_id[0]["id"]
@classmethod
@ -715,7 +715,7 @@ class DocumentService(CommonService):
prg = 1
status = TaskStatus.DONE.value
# only for special task and parsed docs and unfinised
# only for special task and parsed docs and unfinished
freeze_progress = special_task_running and doc_progress >= 1 and not finished
msg = "\n".join(sorted(msg))
info = {
@ -974,13 +974,13 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
def embedding(doc_id, cnts, batch_size=16):
nonlocal embd_mdl, chunk_counts, token_counts
vects = []
vectors = []
for i in range(0, len(cnts), batch_size):
vts, c = embd_mdl.encode(cnts[i: i + batch_size])
vects.extend(vts.tolist())
vectors.extend(vts.tolist())
chunk_counts[doc_id] += len(cnts[i:i + batch_size])
token_counts[doc_id] += c
return vects
return vectors
idxnm = search.index_name(kb.tenant_id)
try_create_idx = True
@ -1011,15 +1011,15 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
except Exception:
logging.exception("Mind map generation error")
vects = embedding(doc_id, [c["content_with_weight"] for c in cks])
assert len(cks) == len(vects)
vectors = embedding(doc_id, [c["content_with_weight"] for c in cks])
assert len(cks) == len(vectors)
for i, d in enumerate(cks):
v = vects[i]
v = vectors[i]
d["q_%d_vec" % len(v)] = v
for b in range(0, len(cks), es_bulk_size):
if try_create_idx:
if not settings.docStoreConn.indexExist(idxnm, kb_id):
settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
settings.docStoreConn.createIdx(idxnm, kb_id, len(vectors[0]))
try_create_idx = False
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)

View File

@ -31,7 +31,7 @@ from common.misc_utils import get_uuid
from common.constants import TaskStatus, FileSource, ParserType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import TaskService
from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img
from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img, sanitize_path
from rag.llm.cv_model import GptV4
from common import settings
@ -329,7 +329,7 @@ class FileService(CommonService):
current_id = start_id
while current_id:
e, file = cls.get_by_id(current_id)
if file.parent_id != file.id and e:
if e and file.parent_id != file.id:
parent_folders.append(file)
current_id = file.parent_id
else:
@ -423,13 +423,15 @@ class FileService(CommonService):
@classmethod
@DB.connection_context()
def upload_document(self, kb, file_objs, user_id, src="local"):
def upload_document(self, kb, file_objs, user_id, src="local", parent_path: str | None = None):
root_folder = self.get_root_folder(user_id)
pf_id = root_folder["id"]
self.init_knowledgebase_docs(pf_id, user_id)
kb_root_folder = self.get_kb_folder(user_id)
kb_folder = self.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])
safe_parent_path = sanitize_path(parent_path)
err, files = [], []
for file in file_objs:
try:
@ -439,7 +441,7 @@ class FileService(CommonService):
if filetype == FileType.OTHER.value:
raise RuntimeError("This type of file has not been supported yet!")
location = filename
location = filename if not safe_parent_path else f"{safe_parent_path}/{filename}"
while settings.STORAGE_IMPL.obj_exist(kb.id, location):
location += "_"

View File

@ -24,9 +24,9 @@ from common.time_utils import current_timestamp, datetime_format
from api.db.services import duplicate_name
from api.db.services.user_service import TenantService
from common.misc_utils import get_uuid
from common.constants import StatusEnum
from common.constants import StatusEnum, RetCode
from api.constants import DATASET_NAME_LIMIT
from api.utils.api_utils import get_parser_config, get_data_error_result
from api.utils.api_utils import get_parser_config
class KnowledgebaseService(CommonService):
"""Service class for managing knowledge base operations.
@ -391,12 +391,12 @@ class KnowledgebaseService(CommonService):
"""
# Validate name
if not isinstance(name, str):
return get_data_error_result(message="Dataset name must be string.")
return {"code": RetCode.DATA_ERROR, "message": "Dataset name must be string."}
dataset_name = name.strip()
if dataset_name == "":
return get_data_error_result(message="Dataset name can't be empty.")
if len(dataset_name) == 0:
return {"code": RetCode.DATA_ERROR, "message": "Dataset name can't be empty."}
if len(dataset_name.encode("utf-8")) > DATASET_NAME_LIMIT:
return get_data_error_result(message=f"Dataset name length is {len(dataset_name)} which is larger than {DATASET_NAME_LIMIT}")
return {"code": RetCode.DATA_ERROR, "message": f"Dataset name length is {len(dataset_name)} which is larger than {DATASET_NAME_LIMIT}"}
# Deduplicate name within tenant
dataset_name = duplicate_name(
@ -409,7 +409,7 @@ class KnowledgebaseService(CommonService):
# Verify tenant exists
ok, _t = TenantService.get_by_id(tenant_id)
if not ok:
return False, "Tenant not found."
return {"code": RetCode.DATA_ERROR, "message": "Tenant does not exist."}
# Build payload
kb_id = get_uuid()
@ -419,11 +419,12 @@ class KnowledgebaseService(CommonService):
"tenant_id": tenant_id,
"created_by": tenant_id,
"parser_id": (parser_id or "naive"),
**kwargs
**kwargs # Includes optional fields such as description, language, permission, avatar, parser_config, etc.
}
# Default parser_config (align with kb_app.create) — do not accept external overrides
# Update parser_config (always override with validated default/merged config)
payload["parser_config"] = get_parser_config(parser_id, kwargs.get("parser_config"))
return payload

View File

@ -19,6 +19,7 @@ import re
from common.token_utils import num_tokens_from_string
from functools import partial
from typing import Generator
from common.constants import LLMType
from api.db.db_models import LLM
from api.db.services.common_service import CommonService
from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService
@ -32,6 +33,14 @@ def get_init_tenant_llm(user_id):
from common import settings
tenant_llm = []
model_configs = {
LLMType.CHAT: settings.CHAT_CFG,
LLMType.EMBEDDING: settings.EMBEDDING_CFG,
LLMType.SPEECH2TEXT: settings.ASR_CFG,
LLMType.IMAGE2TEXT: settings.IMAGE2TEXT_CFG,
LLMType.RERANK: settings.RERANK_CFG,
}
seen = set()
factory_configs = []
for factory_config in [
@ -54,8 +63,8 @@ def get_init_tenant_llm(user_id):
"llm_factory": factory_config["factory"],
"llm_name": llm.llm_name,
"model_type": llm.model_type,
"api_key": factory_config["api_key"],
"api_base": factory_config["base_url"],
"api_key": model_configs.get(llm.model_type, {}).get("api_key", factory_config["api_key"]),
"api_base": model_configs.get(llm.model_type, {}).get("base_url", factory_config["base_url"]),
"max_tokens": llm.max_tokens if llm.max_tokens else 8192,
}
)
@ -80,8 +89,8 @@ class LLMBundle(LLM4Tenant):
def encode(self, texts: list):
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.llm_name, input={"texts": texts})
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.llm_name, input={"texts": texts})
safe_texts = []
for text in texts:
token_size = num_tokens_from_string(text)
@ -90,7 +99,7 @@ class LLMBundle(LLM4Tenant):
safe_texts.append(text[:target_len])
else:
safe_texts.append(text)
embeddings, used_tokens = self.mdl.encode(safe_texts)
llm_name = getattr(self, "llm_name", None)

View File

@ -41,7 +41,7 @@ from api.db.db_models import init_database_tables as init_web_db
from api.db.init_data import init_web_data
from common.versions import get_ragflow_version
from common.config_utils import show_configs
from rag.utils.mcp_tool_call_conn import shutdown_all_mcp_sessions
from common.mcp_tool_call_conn import shutdown_all_mcp_sessions
from rag.utils.redis_conn import RedisDistributedLock
stop_event = threading.Event()

View File

@ -37,7 +37,7 @@ from peewee import OperationalError
from common.constants import ActiveEnum
from api.db.db_models import APIToken
from api.utils.json_encode import CustomJSONEncoder
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
from api.db.services.tenant_llm_service import LLMFactoriesService
from common.connection_utils import timeout
from common.constants import RetCode

View File

@ -1,3 +1,19 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Reusable HTML email templates and registry.
"""

View File

@ -164,3 +164,23 @@ def read_potential_broken_pdf(blob):
return repaired
return blob
def sanitize_path(raw_path: str | None) -> str:
"""Normalize and sanitize a user-provided path segment.
- Converts backslashes to forward slashes
- Strips leading/trailing slashes
- Removes '.' and '..' segments
- Restricts characters to A-Za-z0-9, underscore, dash, and '/'
"""
if not raw_path:
return ""
backslash_re = re.compile(r"[\\]+")
unsafe_re = re.compile(r"[^A-Za-z0-9_\-/]")
normalized = backslash_re.sub("/", raw_path)
normalized = normalized.strip("/")
parts = [seg for seg in normalized.split("/") if seg and seg not in (".", "..")]
sanitized = "/".join(parts)
sanitized = unsafe_re.sub("", sanitized)
return sanitized

View File

@ -173,7 +173,8 @@ def check_task_executor_alive():
heartbeats = [json.loads(heartbeat) for heartbeat in heartbeats]
task_executor_heartbeats[task_executor_id] = heartbeats
if task_executor_heartbeats:
return {"status": "alive", "message": task_executor_heartbeats}
status = "alive" if any(task_executor_heartbeats.values()) else "timeout"
return {"status": status, "message": task_executor_heartbeats}
else:
return {"status": "timeout", "message": "Not found any task executor."}
except Exception as e:

View File

@ -1,3 +1,19 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import datetime
import json
from enum import Enum, IntEnum

48
check_comment_ascii.py Normal file
View File

@ -0,0 +1,48 @@
#!/usr/bin/env python3
"""
Check whether given python files contain non-ASCII comments.
How to check the whole git repo:
```
$ git ls-files -z -- '*.py' | xargs -0 python3 check_comment_ascii.py
```
"""
import sys
import tokenize
import ast
import pathlib
import re
ASCII = re.compile(r"^[\n -~]*\Z") # Printable ASCII + newline
def check(src: str, name: str) -> int:
"""
docstring line 1
docstring line 2
"""
ok = 1
# A common comment begins with `#`
with tokenize.open(src) as fp:
for tk in tokenize.generate_tokens(fp.readline):
if tk.type == tokenize.COMMENT and not ASCII.fullmatch(tk.string):
print(f"{name}:{tk.start[0]}: non-ASCII comment: {tk.string}")
ok = 0
# A docstring begins and ends with `'''`
for node in ast.walk(ast.parse(pathlib.Path(src).read_text(), filename=name)):
if isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.Module)):
if (doc := ast.get_docstring(node)) and not ASCII.fullmatch(doc):
print(f"{name}:{node.lineno}: non-ASCII docstring: {doc}")
ok = 0
return ok
if __name__ == "__main__":
status = 0
for file in sys.argv[1:]:
if not check(file, file):
status = 1
sys.exit(status)

View File

@ -11,7 +11,7 @@ from .confluence_connector import ConfluenceConnector
from .discord_connector import DiscordConnector
from .dropbox_connector import DropboxConnector
from .google_drive.connector import GoogleDriveConnector
from .jira_connector import JiraConnector
from .jira.connector import JiraConnector
from .sharepoint_connector import SharePointConnector
from .teams_connector import TeamsConnector
from .config import BlobType, DocumentSource

View File

@ -87,6 +87,13 @@ class BlobStorageConnector(LoadConnector, PollConnector):
):
raise ConnectorMissingCredentialError("Oracle Cloud Infrastructure")
elif self.bucket_type == BlobType.S3_COMPATIBLE:
if not all(
credentials.get(key)
for key in ["endpoint_url", "aws_access_key_id", "aws_secret_access_key"]
):
raise ConnectorMissingCredentialError("S3 Compatible Storage")
else:
raise ValueError(f"Unsupported bucket type: {self.bucket_type}")

View File

@ -13,6 +13,7 @@ def get_current_tz_offset() -> int:
return round(time_diff.total_seconds() / 3600)
ONE_MINUTE = 60
ONE_HOUR = 3600
ONE_DAY = ONE_HOUR * 24
@ -31,6 +32,7 @@ class BlobType(str, Enum):
R2 = "r2"
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
OCI_STORAGE = "oci_storage"
S3_COMPATIBLE = "s3_compatible"
class DocumentSource(str, Enum):
@ -42,9 +44,11 @@ class DocumentSource(str, Enum):
OCI_STORAGE = "oci_storage"
SLACK = "slack"
CONFLUENCE = "confluence"
JIRA = "jira"
GOOGLE_DRIVE = "google_drive"
GMAIL = "gmail"
DISCORD = "discord"
S3_COMPATIBLE = "s3_compatible"
class FileOrigin(str, Enum):
@ -178,6 +182,21 @@ GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD = int(
os.environ.get("GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
)
JIRA_CONNECTOR_LABELS_TO_SKIP = [
ignored_tag
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
if ignored_tag
]
JIRA_CONNECTOR_MAX_TICKET_SIZE = int(
os.environ.get("JIRA_CONNECTOR_MAX_TICKET_SIZE", 100 * 1024)
)
JIRA_SYNC_TIME_BUFFER_SECONDS = int(
os.environ.get("JIRA_SYNC_TIME_BUFFER_SECONDS", ONE_MINUTE)
)
JIRA_TIMEZONE_OFFSET = float(
os.environ.get("JIRA_TIMEZONE_OFFSET", get_current_tz_offset())
)
OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "")
OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "")
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get(

View File

@ -1788,6 +1788,7 @@ class ConfluenceConnector(
cql_url = self.confluence_client.build_cql_url(
page_query, expand=",".join(_PAGE_EXPANSION_FIELDS)
)
logging.info(f"[Confluence Connector] Building CQL URL {cql_url}")
return update_param_in_path(cql_url, "limit", str(limit))
@override

View File

@ -3,15 +3,9 @@ import os
import threading
from typing import Any, Callable
import requests
from common.data_source.config import DocumentSource
from common.data_source.google_util.constant import GOOGLE_SCOPES
GOOGLE_DEVICE_CODE_URL = "https://oauth2.googleapis.com/device/code"
GOOGLE_DEVICE_TOKEN_URL = "https://oauth2.googleapis.com/token"
DEFAULT_DEVICE_INTERVAL = 5
def _get_requested_scopes(source: DocumentSource) -> list[str]:
"""Return the scopes to request, honoring an optional override env var."""
@ -55,62 +49,6 @@ def _run_with_timeout(func: Callable[[], Any], timeout_secs: int, timeout_messag
return result.get("value")
def _extract_client_info(credentials: dict[str, Any]) -> tuple[str, str | None]:
if "client_id" in credentials:
return credentials["client_id"], credentials.get("client_secret")
for key in ("installed", "web"):
if key in credentials and isinstance(credentials[key], dict):
nested = credentials[key]
if "client_id" not in nested:
break
return nested["client_id"], nested.get("client_secret")
raise ValueError("Provided Google OAuth credentials are missing client_id.")
def start_device_authorization_flow(
credentials: dict[str, Any],
source: DocumentSource,
) -> tuple[dict[str, Any], dict[str, Any]]:
client_id, client_secret = _extract_client_info(credentials)
data = {
"client_id": client_id,
"scope": " ".join(_get_requested_scopes(source)),
}
if client_secret:
data["client_secret"] = client_secret
resp = requests.post(GOOGLE_DEVICE_CODE_URL, data=data, timeout=15)
resp.raise_for_status()
payload = resp.json()
state = {
"client_id": client_id,
"client_secret": client_secret,
"device_code": payload.get("device_code"),
"interval": payload.get("interval", DEFAULT_DEVICE_INTERVAL),
}
response_data = {
"user_code": payload.get("user_code"),
"verification_url": payload.get("verification_url") or payload.get("verification_uri"),
"verification_url_complete": payload.get("verification_url_complete")
or payload.get("verification_uri_complete"),
"expires_in": payload.get("expires_in"),
"interval": state["interval"],
}
return state, response_data
def poll_device_authorization_flow(state: dict[str, Any]) -> dict[str, Any]:
data = {
"client_id": state["client_id"],
"device_code": state["device_code"],
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
}
if state.get("client_secret"):
data["client_secret"] = state["client_secret"]
resp = requests.post(GOOGLE_DEVICE_TOKEN_URL, data=data, timeout=20)
resp.raise_for_status()
return resp.json()
def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource) -> dict[str, Any]:
"""Launch the standard Google OAuth local-server flow to mint user tokens."""
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
@ -125,10 +63,7 @@ def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource
preferred_port = os.environ.get("GOOGLE_OAUTH_LOCAL_SERVER_PORT")
port = int(preferred_port) if preferred_port else 0
timeout_secs = _get_oauth_timeout_secs()
timeout_message = (
f"Google OAuth verification timed out after {timeout_secs} seconds. "
"Close any pending consent windows and rerun the connector configuration to try again."
)
timeout_message = f"Google OAuth verification timed out after {timeout_secs} seconds. Close any pending consent windows and rerun the connector configuration to try again."
print("Launching Google OAuth flow. A browser window should open shortly.")
print("If it does not, copy the URL shown in the console into your browser manually.")
@ -153,11 +88,8 @@ def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource
instructions = [
"Google rejected one or more of the requested OAuth scopes.",
"Fix options:",
" 1. In Google Cloud Console, open APIs & Services > OAuth consent screen and add the missing scopes "
" (Drive metadata + Admin Directory read scopes), then re-run the flow.",
" 1. In Google Cloud Console, open APIs & Services > OAuth consent screen and add the missing scopes (Drive metadata + Admin Directory read scopes), then re-run the flow.",
" 2. Set GOOGLE_OAUTH_SCOPE_OVERRIDE to a comma-separated list of scopes you are allowed to request.",
" 3. For quick local testing only, export OAUTHLIB_RELAX_TOKEN_SCOPE=1 to accept the reduced scopes "
" (be aware the connector may lose functionality).",
]
raise RuntimeError("\n".join(instructions)) from warning
raise
@ -184,8 +116,6 @@ def ensure_oauth_token_dict(credentials: dict[str, Any], source: DocumentSource)
client_config = {"web": credentials["web"]}
if client_config is None:
raise ValueError(
"Provided Google OAuth credentials are missing both tokens and a client configuration."
)
raise ValueError("Provided Google OAuth credentials are missing both tokens and a client configuration.")
return _run_local_server_flow(client_config, source)

View File

@ -69,7 +69,7 @@ class SlimConnectorWithPermSync(ABC):
class CheckpointedConnectorWithPermSync(ABC):
"""Checkpointed connector interface (with permission sync)"""
"""Checkpoint connector interface (with permission sync)"""
@abstractmethod
def load_from_checkpoint(
@ -143,7 +143,7 @@ class CredentialsProviderInterface(abc.ABC, Generic[T]):
@abc.abstractmethod
def is_dynamic(self) -> bool:
"""If dynamic, the credentials may change during usage ... maening the client
"""If dynamic, the credentials may change during usage ... meaning the client
needs to use the locking features of the credentials provider to operate
correctly.

View File

View File

@ -0,0 +1,973 @@
"""Checkpointed Jira connector that emits markdown blobs for each issue."""
from __future__ import annotations
import argparse
import copy
import logging
import os
import re
from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
from datetime import datetime, timedelta, timezone
from typing import Any
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
from jira import JIRA
from jira.resources import Issue
from pydantic import Field
from common.data_source.config import (
INDEX_BATCH_SIZE,
JIRA_CONNECTOR_LABELS_TO_SKIP,
JIRA_CONNECTOR_MAX_TICKET_SIZE,
JIRA_TIMEZONE_OFFSET,
ONE_HOUR,
DocumentSource,
)
from common.data_source.exceptions import (
ConnectorMissingCredentialError,
ConnectorValidationError,
InsufficientPermissionsError,
UnexpectedValidationError,
)
from common.data_source.interfaces import (
CheckpointedConnectorWithPermSync,
CheckpointOutputWrapper,
SecondsSinceUnixEpoch,
SlimConnectorWithPermSync,
)
from common.data_source.jira.utils import (
JIRA_CLOUD_API_VERSION,
JIRA_SERVER_API_VERSION,
build_issue_url,
extract_body_text,
extract_named_value,
extract_user,
format_attachments,
format_comments,
parse_jira_datetime,
should_skip_issue,
)
from common.data_source.models import (
ConnectorCheckpoint,
ConnectorFailure,
Document,
DocumentFailure,
SlimDocument,
)
from common.data_source.utils import is_atlassian_cloud_url, is_atlassian_date_error, scoped_url
logger = logging.getLogger(__name__)
_DEFAULT_FIELDS = "summary,description,updated,created,status,priority,assignee,reporter,labels,issuetype,project,comment,attachment"
_SLIM_FIELDS = "key,project"
_MAX_RESULTS_FETCH_IDS = 5000
_JIRA_SLIM_PAGE_SIZE = 500
_JIRA_FULL_PAGE_SIZE = 50
_DEFAULT_ATTACHMENT_SIZE_LIMIT = 10 * 1024 * 1024 # 10MB
class JiraCheckpoint(ConnectorCheckpoint):
"""Checkpoint that tracks which slice of the current JQL result set was emitted."""
start_at: int = 0
cursor: str | None = None
ids_done: bool = False
all_issue_ids: list[list[str]] = Field(default_factory=list)
_TZ_OFFSET_PATTERN = re.compile(r"([+-])(\d{2})(:?)(\d{2})$")
class JiraConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSync):
"""Retrieve Jira issues and emit them as markdown documents."""
def __init__(
self,
jira_base_url: str,
project_key: str | None = None,
jql_query: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
include_comments: bool = True,
include_attachments: bool = False,
labels_to_skip: Sequence[str] | None = None,
comment_email_blacklist: Sequence[str] | None = None,
scoped_token: bool = False,
attachment_size_limit: int | None = None,
timezone_offset: float | None = None,
) -> None:
if not jira_base_url:
raise ConnectorValidationError("Jira base URL must be provided.")
self.jira_base_url = jira_base_url.rstrip("/")
self.project_key = project_key
self.jql_query = jql_query
self.batch_size = batch_size
self.include_comments = include_comments
self.include_attachments = include_attachments
configured_labels = labels_to_skip or JIRA_CONNECTOR_LABELS_TO_SKIP
self.labels_to_skip = {label.lower() for label in configured_labels}
self.comment_email_blacklist = {email.lower() for email in comment_email_blacklist or []}
self.scoped_token = scoped_token
self.jira_client: JIRA | None = None
self.max_ticket_size = JIRA_CONNECTOR_MAX_TICKET_SIZE
self.attachment_size_limit = attachment_size_limit if attachment_size_limit and attachment_size_limit > 0 else _DEFAULT_ATTACHMENT_SIZE_LIMIT
self._fields_param = _DEFAULT_FIELDS
self._slim_fields = _SLIM_FIELDS
tz_offset_value = float(timezone_offset) if timezone_offset is not None else float(JIRA_TIMEZONE_OFFSET)
self.timezone_offset = tz_offset_value
self.timezone = timezone(offset=timedelta(hours=tz_offset_value))
self._timezone_overridden = timezone_offset is not None
# -------------------------------------------------------------------------
# Connector lifecycle helpers
# -------------------------------------------------------------------------
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
"""Instantiate the Jira client using either an API token or username/password."""
jira_url_for_client = self.jira_base_url
if self.scoped_token:
if is_atlassian_cloud_url(self.jira_base_url):
try:
jira_url_for_client = scoped_url(self.jira_base_url, "jira")
except ValueError as exc:
raise ConnectorValidationError(str(exc)) from exc
else:
logger.warning(f"[Jira] Scoped token requested but Jira base URL {self.jira_base_url} does not appear to be an Atlassian Cloud domain; scoped token ignored.")
user_email = credentials.get("jira_user_email") or credentials.get("username")
api_token = credentials.get("jira_api_token") or credentials.get("token") or credentials.get("api_token")
password = credentials.get("jira_password") or credentials.get("password")
rest_api_version = credentials.get("rest_api_version")
if not rest_api_version:
rest_api_version = JIRA_CLOUD_API_VERSION if api_token else JIRA_SERVER_API_VERSION
options: dict[str, Any] = {"rest_api_version": rest_api_version}
try:
if user_email and api_token:
self.jira_client = JIRA(
server=jira_url_for_client,
basic_auth=(user_email, api_token),
options=options,
)
elif api_token:
self.jira_client = JIRA(
server=jira_url_for_client,
token_auth=api_token,
options=options,
)
elif user_email and password:
self.jira_client = JIRA(
server=jira_url_for_client,
basic_auth=(user_email, password),
options=options,
)
else:
raise ConnectorMissingCredentialError("Jira credentials must include either an API token or username/password.")
except Exception as exc: # pragma: no cover - jira lib raises many types
raise ConnectorMissingCredentialError(f"Jira: {exc}") from exc
self._sync_timezone_from_server()
return None
def validate_connector_settings(self) -> None:
"""Validate connectivity by fetching basic Jira info."""
if not self.jira_client:
raise ConnectorMissingCredentialError("Jira")
try:
if self.jql_query:
dummy_checkpoint = self.build_dummy_checkpoint()
checkpoint_callback = self._make_checkpoint_callback(dummy_checkpoint)
iterator = self._perform_jql_search(
jql=self.jql_query,
start=0,
max_results=1,
fields="key",
all_issue_ids=dummy_checkpoint.all_issue_ids,
checkpoint_callback=checkpoint_callback,
next_page_token=dummy_checkpoint.cursor,
ids_done=dummy_checkpoint.ids_done,
)
next(iter(iterator), None)
elif self.project_key:
self.jira_client.project(self.project_key)
else:
self.jira_client.projects()
except Exception as exc: # pragma: no cover - dependent on Jira responses
self._handle_validation_error(exc)
# -------------------------------------------------------------------------
# Checkpointed connector implementation
# -------------------------------------------------------------------------
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: JiraCheckpoint,
) -> Generator[Document | ConnectorFailure, None, JiraCheckpoint]:
"""Load Jira issues, emitting a Document per issue."""
try:
return (yield from self._load_with_retry(start, end, checkpoint))
except Exception as exc:
logger.exception(f"[Jira] Jira query ultimately failed: {exc}")
yield ConnectorFailure(
failure_message=f"Failed to query Jira: {exc}",
exception=exc,
)
return JiraCheckpoint(has_more=False, start_at=checkpoint.start_at)
def load_from_checkpoint_with_perm_sync(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: JiraCheckpoint,
) -> Generator[Document | ConnectorFailure, None, JiraCheckpoint]:
"""Permissions are not synced separately, so reuse the standard loader."""
return (yield from self.load_from_checkpoint(start=start, end=end, checkpoint=checkpoint))
def _load_with_retry(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: JiraCheckpoint,
) -> Generator[Document | ConnectorFailure, None, JiraCheckpoint]:
if not self.jira_client:
raise ConnectorMissingCredentialError("Jira")
attempt_start = start
retried_with_buffer = False
attempt = 0
while True:
attempt += 1
jql = self._build_jql(attempt_start, end)
logger.info(f"[Jira] Executing Jira JQL attempt {attempt} (start={attempt_start}, end={end}, buffered_retry={retried_with_buffer}): {jql}")
try:
return (yield from self._load_from_checkpoint_internal(jql, checkpoint, start_filter=start))
except Exception as exc:
if attempt_start is not None and not retried_with_buffer and is_atlassian_date_error(exc):
attempt_start = attempt_start - ONE_HOUR
retried_with_buffer = True
logger.info(f"[Jira] Atlassian date error detected; retrying with start={attempt_start}.")
continue
raise
def _handle_validation_error(self, exc: Exception) -> None:
status_code = getattr(exc, "status_code", None)
if status_code == 401:
raise InsufficientPermissionsError("Jira credential appears to be invalid or expired (HTTP 401).") from exc
if status_code == 403:
raise InsufficientPermissionsError("Jira token does not have permission to access the requested resources (HTTP 403).") from exc
if status_code == 404:
raise ConnectorValidationError("Jira resource not found (HTTP 404).") from exc
if status_code == 429:
raise ConnectorValidationError("Jira rate limit exceeded during validation (HTTP 429).") from exc
message = getattr(exc, "text", str(exc))
if not message:
raise UnexpectedValidationError("Unexpected Jira validation error.") from exc
raise ConnectorValidationError(f"Jira validation failed: {message}") from exc
def _load_from_checkpoint_internal(
self,
jql: str,
checkpoint: JiraCheckpoint,
start_filter: SecondsSinceUnixEpoch | None = None,
) -> Generator[Document | ConnectorFailure, None, JiraCheckpoint]:
assert self.jira_client, "load_credentials must be called before loading issues."
page_size = self._full_page_size()
new_checkpoint = copy.deepcopy(checkpoint)
starting_offset = new_checkpoint.start_at or 0
current_offset = starting_offset
checkpoint_callback = self._make_checkpoint_callback(new_checkpoint)
issue_iter = self._perform_jql_search(
jql=jql,
start=current_offset,
max_results=page_size,
fields=self._fields_param,
all_issue_ids=new_checkpoint.all_issue_ids,
checkpoint_callback=checkpoint_callback,
next_page_token=new_checkpoint.cursor,
ids_done=new_checkpoint.ids_done,
)
start_cutoff = float(start_filter) if start_filter is not None else None
for issue in issue_iter:
current_offset += 1
issue_key = getattr(issue, "key", "unknown")
if should_skip_issue(issue, self.labels_to_skip):
continue
issue_updated = parse_jira_datetime(issue.raw.get("fields", {}).get("updated"))
if start_cutoff is not None and issue_updated is not None and issue_updated.timestamp() <= start_cutoff:
# Jira JQL only supports minute precision, so we discard already-processed
# issues here based on the original second-level cutoff.
continue
try:
document = self._issue_to_document(issue)
except Exception as exc: # pragma: no cover - defensive
logger.exception(f"[Jira] Failed to convert Jira issue {issue_key}: {exc}")
yield ConnectorFailure(
failure_message=f"Failed to convert Jira issue {issue_key}: {exc}",
failed_document=DocumentFailure(
document_id=issue_key,
document_link=build_issue_url(self.jira_base_url, issue_key),
),
exception=exc,
)
continue
if document is not None:
yield document
if self.include_attachments:
for attachment_document in self._attachment_documents(issue):
if attachment_document is not None:
yield attachment_document
self._update_checkpoint_for_next_run(
checkpoint=new_checkpoint,
current_offset=current_offset,
starting_offset=starting_offset,
page_size=page_size,
)
new_checkpoint.start_at = current_offset
return new_checkpoint
def build_dummy_checkpoint(self) -> JiraCheckpoint:
"""Create an empty checkpoint used to kick off ingestion."""
return JiraCheckpoint(has_more=True, start_at=0)
def validate_checkpoint_json(self, checkpoint_json: str) -> JiraCheckpoint:
"""Validate a serialized checkpoint."""
return JiraCheckpoint.model_validate_json(checkpoint_json)
# -------------------------------------------------------------------------
# Slim connector implementation
# -------------------------------------------------------------------------
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: Any = None, # noqa: ARG002 - maintained for interface compatibility
) -> Generator[list[SlimDocument], None, None]:
"""Return lightweight references to Jira issues (used for permission syncing)."""
if not self.jira_client:
raise ConnectorMissingCredentialError("Jira")
start_ts = start if start is not None else 0
end_ts = end if end is not None else datetime.now(timezone.utc).timestamp()
jql = self._build_jql(start_ts, end_ts)
checkpoint = self.build_dummy_checkpoint()
checkpoint_callback = self._make_checkpoint_callback(checkpoint)
prev_offset = 0
current_offset = 0
slim_batch: list[SlimDocument] = []
while checkpoint.has_more:
for issue in self._perform_jql_search(
jql=jql,
start=current_offset,
max_results=_JIRA_SLIM_PAGE_SIZE,
fields=self._slim_fields,
all_issue_ids=checkpoint.all_issue_ids,
checkpoint_callback=checkpoint_callback,
next_page_token=checkpoint.cursor,
ids_done=checkpoint.ids_done,
):
current_offset += 1
if should_skip_issue(issue, self.labels_to_skip):
continue
doc_id = build_issue_url(self.jira_base_url, issue.key)
slim_batch.append(SlimDocument(id=doc_id))
if len(slim_batch) >= _JIRA_SLIM_PAGE_SIZE:
yield slim_batch
slim_batch = []
self._update_checkpoint_for_next_run(
checkpoint=checkpoint,
current_offset=current_offset,
starting_offset=prev_offset,
page_size=_JIRA_SLIM_PAGE_SIZE,
)
prev_offset = current_offset
if slim_batch:
yield slim_batch
# -------------------------------------------------------------------------
# Internal helpers
# -------------------------------------------------------------------------
def _build_jql(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> str:
clauses: list[str] = []
if self.jql_query:
clauses.append(f"({self.jql_query})")
elif self.project_key:
clauses.append(f'project = "{self.project_key}"')
else:
raise ConnectorValidationError("Either project_key or jql_query must be provided for Jira connector.")
if self.labels_to_skip:
labels = ", ".join(f'"{label}"' for label in self.labels_to_skip)
clauses.append(f"labels NOT IN ({labels})")
if start is not None:
clauses.append(f'updated >= "{self._format_jql_time(start)}"')
if end is not None:
clauses.append(f'updated <= "{self._format_jql_time(end)}"')
if not clauses:
raise ConnectorValidationError("Unable to build Jira JQL query.")
jql = " AND ".join(clauses)
if "order by" not in jql.lower():
jql = f"{jql} ORDER BY updated ASC"
return jql
def _format_jql_time(self, timestamp: SecondsSinceUnixEpoch) -> str:
dt_utc = datetime.fromtimestamp(float(timestamp), tz=timezone.utc)
dt_local = dt_utc.astimezone(self.timezone)
# Jira only accepts minute-precision timestamps in JQL, so we format accordingly
# and rely on a post-query second-level filter to avoid duplicates.
return dt_local.strftime("%Y-%m-%d %H:%M")
def _issue_to_document(self, issue: Issue) -> Document | None:
fields = issue.raw.get("fields", {})
summary = fields.get("summary") or ""
description_text = extract_body_text(fields.get("description"))
comments_text = (
format_comments(
fields.get("comment"),
blacklist=self.comment_email_blacklist,
)
if self.include_comments
else ""
)
attachments_text = format_attachments(fields.get("attachment"))
reporter_name, reporter_email = extract_user(fields.get("reporter"))
assignee_name, assignee_email = extract_user(fields.get("assignee"))
status = extract_named_value(fields.get("status"))
priority = extract_named_value(fields.get("priority"))
issue_type = extract_named_value(fields.get("issuetype"))
project = fields.get("project") or {}
issue_url = build_issue_url(self.jira_base_url, issue.key)
metadata_lines = [
f"key: {issue.key}",
f"url: {issue_url}",
f"summary: {summary}",
f"status: {status or 'Unknown'}",
f"priority: {priority or 'Unspecified'}",
f"issue_type: {issue_type or 'Unknown'}",
f"project: {project.get('name') or ''}",
f"project_key: {project.get('key') or self.project_key or ''}",
]
if reporter_name:
metadata_lines.append(f"reporter: {reporter_name}")
if reporter_email:
metadata_lines.append(f"reporter_email: {reporter_email}")
if assignee_name:
metadata_lines.append(f"assignee: {assignee_name}")
if assignee_email:
metadata_lines.append(f"assignee_email: {assignee_email}")
if fields.get("labels"):
metadata_lines.append(f"labels: {', '.join(fields.get('labels'))}")
created_dt = parse_jira_datetime(fields.get("created"))
updated_dt = parse_jira_datetime(fields.get("updated")) or created_dt or datetime.now(timezone.utc)
metadata_lines.append(f"created: {created_dt.isoformat() if created_dt else ''}")
metadata_lines.append(f"updated: {updated_dt.isoformat() if updated_dt else ''}")
sections: list[str] = [
"---",
"\n".join(filter(None, metadata_lines)),
"---",
"",
"## Description",
description_text or "No description provided.",
]
if comments_text:
sections.extend(["", "## Comments", comments_text])
if attachments_text:
sections.extend(["", "## Attachments", attachments_text])
blob_text = "\n".join(sections).strip() + "\n"
blob = blob_text.encode("utf-8")
if len(blob) > self.max_ticket_size:
logger.info(f"[Jira] Skipping {issue.key} because it exceeds the maximum size of {self.max_ticket_size} bytes.")
return None
semantic_identifier = f"{issue.key}: {summary}" if summary else issue.key
return Document(
id=issue_url,
source=DocumentSource.JIRA,
semantic_identifier=semantic_identifier,
extension=".md",
blob=blob,
doc_updated_at=updated_dt,
size_bytes=len(blob),
)
def _attachment_documents(self, issue: Issue) -> Iterable[Document]:
attachments = issue.raw.get("fields", {}).get("attachment") or []
for attachment in attachments:
try:
document = self._attachment_to_document(issue, attachment)
if document is not None:
yield document
except Exception as exc: # pragma: no cover - defensive
failed_id = attachment.get("id") or attachment.get("filename")
issue_key = getattr(issue, "key", "unknown")
logger.warning(f"[Jira] Failed to process attachment {failed_id} for issue {issue_key}: {exc}")
def _attachment_to_document(self, issue: Issue, attachment: dict[str, Any]) -> Document | None:
if not self.include_attachments:
return None
filename = attachment.get("filename")
content_url = attachment.get("content")
if not filename or not content_url:
return None
try:
attachment_size = int(attachment.get("size", 0))
except (TypeError, ValueError):
attachment_size = 0
if attachment_size and attachment_size > self.attachment_size_limit:
logger.info(f"[Jira] Skipping attachment {filename} on {issue.key} because reported size exceeds limit ({self.attachment_size_limit} bytes).")
return None
blob = self._download_attachment(content_url)
if blob is None:
return None
if len(blob) > self.attachment_size_limit:
logger.info(f"[Jira] Skipping attachment {filename} on {issue.key} because it exceeds the size limit ({self.attachment_size_limit} bytes).")
return None
attachment_time = parse_jira_datetime(attachment.get("created")) or parse_jira_datetime(attachment.get("updated"))
updated_dt = attachment_time or parse_jira_datetime(issue.raw.get("fields", {}).get("updated")) or datetime.now(timezone.utc)
extension = os.path.splitext(filename)[1] or ""
document_id = f"{issue.key}::attachment::{attachment.get('id') or filename}"
semantic_identifier = f"{issue.key} attachment: {filename}"
return Document(
id=document_id,
source=DocumentSource.JIRA,
semantic_identifier=semantic_identifier,
extension=extension,
blob=blob,
doc_updated_at=updated_dt,
size_bytes=len(blob),
)
def _download_attachment(self, url: str) -> bytes | None:
if not self.jira_client:
raise ConnectorMissingCredentialError("Jira")
response = self.jira_client._session.get(url)
response.raise_for_status()
return response.content
def _sync_timezone_from_server(self) -> None:
if self._timezone_overridden or not self.jira_client:
return
try:
server_info = self.jira_client.server_info()
except Exception as exc: # pragma: no cover - defensive
logger.info(f"[Jira] Unable to determine timezone from server info; continuing with offset {self.timezone_offset}. Error: {exc}")
return
detected_offset = self._extract_timezone_offset(server_info)
if detected_offset is None or detected_offset == self.timezone_offset:
return
self.timezone_offset = detected_offset
self.timezone = timezone(offset=timedelta(hours=detected_offset))
logger.info(f"[Jira] Timezone offset adjusted to {detected_offset} hours using Jira server info.")
def _extract_timezone_offset(self, server_info: dict[str, Any]) -> float | None:
server_time_raw = server_info.get("serverTime")
if isinstance(server_time_raw, str):
offset = self._parse_offset_from_datetime_string(server_time_raw)
if offset is not None:
return offset
tz_name = server_info.get("timeZone")
if isinstance(tz_name, str):
offset = self._offset_from_zone_name(tz_name)
if offset is not None:
return offset
return None
@staticmethod
def _parse_offset_from_datetime_string(value: str) -> float | None:
normalized = JiraConnector._normalize_datetime_string(value)
try:
dt = datetime.fromisoformat(normalized)
except ValueError:
return None
if dt.tzinfo is None:
return 0.0
offset = dt.tzinfo.utcoffset(dt)
if offset is None:
return None
return offset.total_seconds() / 3600.0
@staticmethod
def _normalize_datetime_string(value: str) -> str:
trimmed = (value or "").strip()
if trimmed.endswith("Z"):
return f"{trimmed[:-1]}+00:00"
match = _TZ_OFFSET_PATTERN.search(trimmed)
if match and match.group(3) != ":":
sign, hours, _, minutes = match.groups()
trimmed = f"{trimmed[: match.start()]}{sign}{hours}:{minutes}"
return trimmed
@staticmethod
def _offset_from_zone_name(name: str) -> float | None:
try:
tz = ZoneInfo(name)
except (ZoneInfoNotFoundError, ValueError):
return None
reference = datetime.now(tz)
offset = reference.utcoffset()
if offset is None:
return None
return offset.total_seconds() / 3600.0
def _is_cloud_client(self) -> bool:
if not self.jira_client:
return False
rest_version = str(self.jira_client._options.get("rest_api_version", "")).strip()
return rest_version == str(JIRA_CLOUD_API_VERSION)
def _full_page_size(self) -> int:
return max(1, min(self.batch_size, _JIRA_FULL_PAGE_SIZE))
def _perform_jql_search(
self,
*,
jql: str,
start: int,
max_results: int,
fields: str | None = None,
all_issue_ids: list[list[str]] | None = None,
checkpoint_callback: Callable[[Iterable[list[str]], str | None], None] | None = None,
next_page_token: str | None = None,
ids_done: bool = False,
) -> Iterable[Issue]:
assert self.jira_client, "Jira client not initialized."
is_cloud = self._is_cloud_client()
if is_cloud:
if all_issue_ids is None:
raise ValueError("all_issue_ids is required for Jira Cloud searches.")
yield from self._perform_jql_search_v3(
jql=jql,
max_results=max_results,
fields=fields,
all_issue_ids=all_issue_ids,
checkpoint_callback=checkpoint_callback,
next_page_token=next_page_token,
ids_done=ids_done,
)
else:
yield from self._perform_jql_search_v2(
jql=jql,
start=start,
max_results=max_results,
fields=fields,
)
def _perform_jql_search_v3(
self,
*,
jql: str,
max_results: int,
all_issue_ids: list[list[str]],
fields: str | None = None,
checkpoint_callback: Callable[[Iterable[list[str]], str | None], None] | None = None,
next_page_token: str | None = None,
ids_done: bool = False,
) -> Iterable[Issue]:
assert self.jira_client, "Jira client not initialized."
if not ids_done:
new_ids, page_token = self._enhanced_search_ids(jql, next_page_token)
if checkpoint_callback is not None and new_ids:
checkpoint_callback(
self._chunk_issue_ids(new_ids, max_results),
page_token,
)
elif checkpoint_callback is not None:
checkpoint_callback([], page_token)
if all_issue_ids:
issue_ids = all_issue_ids.pop()
if issue_ids:
yield from self._bulk_fetch_issues(issue_ids, fields)
def _perform_jql_search_v2(
self,
*,
jql: str,
start: int,
max_results: int,
fields: str | None = None,
) -> Iterable[Issue]:
assert self.jira_client, "Jira client not initialized."
issues = self.jira_client.search_issues(
jql_str=jql,
startAt=start,
maxResults=max_results,
fields=fields or self._fields_param,
expand="renderedFields",
)
for issue in issues:
yield issue
def _enhanced_search_ids(
self,
jql: str,
next_page_token: str | None,
) -> tuple[list[str], str | None]:
assert self.jira_client, "Jira client not initialized."
enhanced_search_path = self.jira_client._get_url("search/jql")
params: dict[str, str | int | None] = {
"jql": jql,
"maxResults": _MAX_RESULTS_FETCH_IDS,
"nextPageToken": next_page_token,
"fields": "id",
}
response = self.jira_client._session.get(enhanced_search_path, params=params)
response.raise_for_status()
data = response.json()
return [str(issue["id"]) for issue in data.get("issues", [])], data.get("nextPageToken")
def _bulk_fetch_issues(
self,
issue_ids: list[str],
fields: str | None,
) -> Iterable[Issue]:
assert self.jira_client, "Jira client not initialized."
if not issue_ids:
return []
bulk_fetch_path = self.jira_client._get_url("issue/bulkfetch")
payload: dict[str, Any] = {"issueIdsOrKeys": issue_ids}
payload["fields"] = fields.split(",") if fields else ["*all"]
response = self.jira_client._session.post(bulk_fetch_path, json=payload)
response.raise_for_status()
data = response.json()
return [Issue(self.jira_client._options, self.jira_client._session, raw=issue) for issue in data.get("issues", [])]
@staticmethod
def _chunk_issue_ids(issue_ids: list[str], chunk_size: int) -> Iterable[list[str]]:
if chunk_size <= 0:
chunk_size = _JIRA_FULL_PAGE_SIZE
for idx in range(0, len(issue_ids), chunk_size):
yield issue_ids[idx : idx + chunk_size]
def _make_checkpoint_callback(self, checkpoint: JiraCheckpoint) -> Callable[[Iterable[list[str]], str | None], None]:
def checkpoint_callback(
issue_ids: Iterable[list[str]] | list[list[str]],
page_token: str | None,
) -> None:
for id_batch in issue_ids:
checkpoint.all_issue_ids.append(list(id_batch))
checkpoint.cursor = page_token
checkpoint.ids_done = page_token is None
return checkpoint_callback
def _update_checkpoint_for_next_run(
self,
*,
checkpoint: JiraCheckpoint,
current_offset: int,
starting_offset: int,
page_size: int,
) -> None:
if self._is_cloud_client():
checkpoint.has_more = bool(checkpoint.all_issue_ids) or not checkpoint.ids_done
else:
checkpoint.has_more = current_offset - starting_offset == page_size
checkpoint.cursor = None
checkpoint.ids_done = True
checkpoint.all_issue_ids = []
def iterate_jira_documents(
connector: "JiraConnector",
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
iteration_limit: int = 100_000,
) -> Iterator[Document]:
"""Yield documents without materializing the entire result set."""
checkpoint = connector.build_dummy_checkpoint()
iterations = 0
while checkpoint.has_more:
wrapper = CheckpointOutputWrapper[JiraCheckpoint]()
generator = wrapper(connector.load_from_checkpoint(start=start, end=end, checkpoint=checkpoint))
for document, failure, next_checkpoint in generator:
if failure is not None:
failure_message = getattr(failure, "failure_message", str(failure))
raise RuntimeError(f"Failed to load Jira documents: {failure_message}")
if document is not None:
yield document
if next_checkpoint is not None:
checkpoint = next_checkpoint
iterations += 1
if iterations > iteration_limit:
raise RuntimeError("Too many iterations while loading Jira documents.")
def test_jira(
*,
base_url: str,
project_key: str | None = None,
jql_query: str | None = None,
credentials: dict[str, Any],
batch_size: int = INDEX_BATCH_SIZE,
start_ts: float | None = None,
end_ts: float | None = None,
connector_options: dict[str, Any] | None = None,
) -> list[Document]:
"""Programmatic entry point that mirrors the CLI workflow."""
connector_kwargs = connector_options.copy() if connector_options else {}
connector = JiraConnector(
jira_base_url=base_url,
project_key=project_key,
jql_query=jql_query,
batch_size=batch_size,
**connector_kwargs,
)
connector.load_credentials(credentials)
connector.validate_connector_settings()
now_ts = datetime.now(timezone.utc).timestamp()
start = start_ts if start_ts is not None else 0.0
end = end_ts if end_ts is not None else now_ts
documents = list(iterate_jira_documents(connector, start=start, end=end))
logger.info(f"[Jira] Fetched {len(documents)} Jira documents.")
for doc in documents[:5]:
logger.info(f"[Jira] Document {doc.semantic_identifier} ({doc.id}) size={doc.size_bytes} bytes")
return documents
def _build_arg_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="Fetch Jira issues and print summary statistics.")
parser.add_argument("--base-url", dest="base_url", default=os.environ.get("JIRA_BASE_URL"))
parser.add_argument("--project", dest="project_key", default=os.environ.get("JIRA_PROJECT_KEY"))
parser.add_argument("--jql", dest="jql_query", default=os.environ.get("JIRA_JQL"))
parser.add_argument("--email", dest="user_email", default=os.environ.get("JIRA_USER_EMAIL"))
parser.add_argument("--token", dest="api_token", default=os.environ.get("JIRA_API_TOKEN"))
parser.add_argument("--password", dest="password", default=os.environ.get("JIRA_PASSWORD"))
parser.add_argument("--batch-size", dest="batch_size", type=int, default=int(os.environ.get("JIRA_BATCH_SIZE", INDEX_BATCH_SIZE)))
parser.add_argument("--include_comments", dest="include_comments", type=bool, default=True)
parser.add_argument("--include_attachments", dest="include_attachments", type=bool, default=True)
parser.add_argument("--attachment_size_limit", dest="attachment_size_limit", type=float, default=_DEFAULT_ATTACHMENT_SIZE_LIMIT)
parser.add_argument("--start-ts", dest="start_ts", type=float, default=None, help="Epoch seconds inclusive lower bound for updated issues.")
parser.add_argument("--end-ts", dest="end_ts", type=float, default=9999999999, help="Epoch seconds inclusive upper bound for updated issues.")
return parser
def main(config: dict[str, Any] | None = None) -> None:
if config is None:
args = _build_arg_parser().parse_args()
config = {
"base_url": args.base_url,
"project_key": args.project_key,
"jql_query": args.jql_query,
"batch_size": args.batch_size,
"start_ts": args.start_ts,
"end_ts": args.end_ts,
"include_comments": args.include_comments,
"include_attachments": args.include_attachments,
"attachment_size_limit": args.attachment_size_limit,
"credentials": {
"jira_user_email": args.user_email,
"jira_api_token": args.api_token,
"jira_password": args.password,
},
}
base_url = config.get("base_url")
credentials = config.get("credentials", {})
print(f"[Jira] {config=}", flush=True)
print(f"[Jira] {credentials=}", flush=True)
if not base_url:
raise RuntimeError("Jira base URL must be provided via config or CLI arguments.")
if not (credentials.get("jira_api_token") or (credentials.get("jira_user_email") and credentials.get("jira_password"))):
raise RuntimeError("Provide either an API token or both email/password for Jira authentication.")
connector_options = {
key: value
for key, value in (
("include_comments", config.get("include_comments")),
("include_attachments", config.get("include_attachments")),
("attachment_size_limit", config.get("attachment_size_limit")),
("labels_to_skip", config.get("labels_to_skip")),
("comment_email_blacklist", config.get("comment_email_blacklist")),
("scoped_token", config.get("scoped_token")),
("timezone_offset", config.get("timezone_offset")),
)
if value is not None
}
documents = test_jira(
base_url=base_url,
project_key=config.get("project_key"),
jql_query=config.get("jql_query"),
credentials=credentials,
batch_size=config.get("batch_size", INDEX_BATCH_SIZE),
start_ts=config.get("start_ts"),
end_ts=config.get("end_ts"),
connector_options=connector_options,
)
preview_count = min(len(documents), 5)
for idx in range(preview_count):
doc = documents[idx]
print(f"[Jira] [Sample {idx + 1}] {doc.semantic_identifier} | id={doc.id} | size={doc.size_bytes} bytes")
print(f"[Jira] Jira connector test completed. Documents fetched: {len(documents)}")
if __name__ == "__main__": # pragma: no cover - manual execution path
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s %(levelname)s %(name)s %(message)s")
main()

View File

@ -0,0 +1,149 @@
"""Helper utilities for the Jira connector."""
from __future__ import annotations
import os
from collections.abc import Collection
from datetime import datetime, timezone
from typing import Any, Iterable
from jira.resources import Issue
from common.data_source.utils import datetime_from_string
JIRA_SERVER_API_VERSION = os.environ.get("JIRA_SERVER_API_VERSION", "2")
JIRA_CLOUD_API_VERSION = os.environ.get("JIRA_CLOUD_API_VERSION", "3")
def build_issue_url(base_url: str, issue_key: str) -> str:
"""Return the canonical UI URL for a Jira issue."""
return f"{base_url.rstrip('/')}/browse/{issue_key}"
def parse_jira_datetime(value: Any) -> datetime | None:
"""Best-effort parse of Jira datetime values to aware UTC datetimes."""
if value is None:
return None
if isinstance(value, datetime):
return value.astimezone(timezone.utc) if value.tzinfo else value.replace(tzinfo=timezone.utc)
if isinstance(value, str):
return datetime_from_string(value)
return None
def extract_named_value(value: Any) -> str | None:
"""Extract a readable string out of Jira's typed objects."""
if value is None:
return None
if isinstance(value, str):
return value
if isinstance(value, dict):
return value.get("name") or value.get("value")
return getattr(value, "name", None)
def extract_user(value: Any) -> tuple[str | None, str | None]:
"""Return display name + email tuple for a Jira user blob."""
if value is None:
return None, None
if isinstance(value, dict):
return value.get("displayName"), value.get("emailAddress")
display = getattr(value, "displayName", None)
email = getattr(value, "emailAddress", None)
return display, email
def extract_text_from_adf(adf: Any) -> str:
"""Flatten Atlassian Document Format (ADF) structures to text."""
texts: list[str] = []
def _walk(node: Any) -> None:
if node is None:
return
if isinstance(node, dict):
node_type = node.get("type")
if node_type == "text":
texts.append(node.get("text", ""))
for child in node.get("content", []):
_walk(child)
elif isinstance(node, list):
for child in node:
_walk(child)
_walk(adf)
return "\n".join(part for part in texts if part)
def extract_body_text(value: Any) -> str:
"""Normalize Jira description/comments (raw/adf/str) into plain text."""
if value is None:
return ""
if isinstance(value, str):
return value.strip()
if isinstance(value, dict):
return extract_text_from_adf(value).strip()
return str(value).strip()
def format_comments(
comment_block: Any,
*,
blacklist: Collection[str],
) -> str:
"""Convert Jira comments into a markdown-ish bullet list."""
if not isinstance(comment_block, dict):
return ""
comments = comment_block.get("comments") or []
lines: list[str] = []
normalized_blacklist = {email.lower() for email in blacklist if email}
for comment in comments:
author = comment.get("author") or {}
author_email = (author.get("emailAddress") or "").lower()
if author_email and author_email in normalized_blacklist:
continue
author_name = author.get("displayName") or author.get("name") or author_email or "Unknown"
created = parse_jira_datetime(comment.get("created"))
created_str = created.isoformat() if created else "Unknown time"
body = extract_body_text(comment.get("body"))
if not body:
continue
lines.append(f"- {author_name} ({created_str}):\n{body}")
return "\n\n".join(lines)
def format_attachments(attachments: Any) -> str:
"""List Jira attachments as bullet points."""
if not isinstance(attachments, list):
return ""
attachment_lines: list[str] = []
for attachment in attachments:
filename = attachment.get("filename")
if not filename:
continue
size = attachment.get("size")
size_text = f" ({size} bytes)" if isinstance(size, int) else ""
content_url = attachment.get("content") or ""
url_suffix = f" -> {content_url}" if content_url else ""
attachment_lines.append(f"- {filename}{size_text}{url_suffix}")
return "\n".join(attachment_lines)
def should_skip_issue(issue: Issue, labels_to_skip: set[str]) -> bool:
"""Return True if the issue contains any label from the skip list."""
if not labels_to_skip:
return False
fields = getattr(issue, "raw", {}).get("fields", {})
labels: Iterable[str] = fields.get("labels") or []
for label in labels:
if (label or "").lower() in labels_to_skip:
return True
return False

View File

@ -1,112 +0,0 @@
"""Jira connector"""
from typing import Any
from jira import JIRA
from common.data_source.config import INDEX_BATCH_SIZE
from common.data_source.exceptions import (
ConnectorValidationError,
InsufficientPermissionsError,
UnexpectedValidationError, ConnectorMissingCredentialError
)
from common.data_source.interfaces import (
CheckpointedConnectorWithPermSync,
SecondsSinceUnixEpoch,
SlimConnectorWithPermSync
)
from common.data_source.models import (
ConnectorCheckpoint
)
class JiraConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSync):
"""Jira connector for accessing Jira issues and projects"""
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.batch_size = batch_size
self.jira_client: JIRA | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
"""Load Jira credentials"""
try:
url = credentials.get("url")
username = credentials.get("username")
password = credentials.get("password")
token = credentials.get("token")
if not url:
raise ConnectorMissingCredentialError("Jira URL is required")
if token:
# API token authentication
self.jira_client = JIRA(server=url, token_auth=token)
elif username and password:
# Basic authentication
self.jira_client = JIRA(server=url, basic_auth=(username, password))
else:
raise ConnectorMissingCredentialError("Jira credentials are incomplete")
return None
except Exception as e:
raise ConnectorMissingCredentialError(f"Jira: {e}")
def validate_connector_settings(self) -> None:
"""Validate Jira connector settings"""
if not self.jira_client:
raise ConnectorMissingCredentialError("Jira")
try:
# Test connection by getting server info
self.jira_client.server_info()
except Exception as e:
if "401" in str(e) or "403" in str(e):
raise InsufficientPermissionsError("Invalid credentials or insufficient permissions")
elif "404" in str(e):
raise ConnectorValidationError("Jira instance not found")
else:
raise UnexpectedValidationError(f"Jira validation error: {e}")
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
"""Poll Jira for recent issues"""
# Simplified implementation - in production this would handle actual polling
return []
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> Any:
"""Load documents from checkpoint"""
# Simplified implementation
return []
def load_from_checkpoint_with_perm_sync(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> Any:
"""Load documents from checkpoint with permission sync"""
# Simplified implementation
return []
def build_dummy_checkpoint(self) -> ConnectorCheckpoint:
"""Build dummy checkpoint"""
return ConnectorCheckpoint()
def validate_checkpoint_json(self, checkpoint_json: str) -> ConnectorCheckpoint:
"""Validate checkpoint JSON"""
# Simplified implementation
return ConnectorCheckpoint()
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: Any = None,
) -> Any:
"""Retrieve all simplified documents with permission sync"""
# Simplified implementation
return []

View File

@ -48,17 +48,35 @@ from common.data_source.exceptions import RateLimitTriedTooManyTimesError
from common.data_source.interfaces import CT, CheckpointedConnector, CheckpointOutputWrapper, ConfluenceUser, LoadFunction, OnyxExtensionType, SecondsSinceUnixEpoch, TokenResponse
from common.data_source.models import BasicExpertInfo, Document
_TZ_SUFFIX_PATTERN = re.compile(r"([+-])([\d:]+)$")
def datetime_from_string(datetime_string: str) -> datetime:
datetime_string = datetime_string.strip()
match_jira_format = _TZ_SUFFIX_PATTERN.search(datetime_string)
if match_jira_format:
sign, tz_field = match_jira_format.groups()
digits = tz_field.replace(":", "")
if digits.isdigit() and 1 <= len(digits) <= 4:
if len(digits) >= 3:
hours = digits[:-2].rjust(2, "0")
minutes = digits[-2:]
else:
hours = digits.rjust(2, "0")
minutes = "00"
normalized = f"{sign}{hours}:{minutes}"
datetime_string = f"{datetime_string[: match_jira_format.start()]}{normalized}"
# Handle the case where the datetime string ends with 'Z' (Zulu time)
if datetime_string.endswith('Z'):
datetime_string = datetime_string[:-1] + '+00:00'
if datetime_string.endswith("Z"):
datetime_string = datetime_string[:-1] + "+00:00"
# Handle timezone format "+0000" -> "+00:00"
if datetime_string.endswith('+0000'):
datetime_string = datetime_string[:-5] + '+00:00'
if datetime_string.endswith("+0000"):
datetime_string = datetime_string[:-5] + "+00:00"
datetime_object = datetime.fromisoformat(datetime_string)
@ -293,6 +311,13 @@ def create_s3_client(bucket_type: BlobType, credentials: dict[str, Any], europea
aws_secret_access_key=credentials["secret_access_key"],
region_name=credentials["region"],
)
elif bucket_type == BlobType.S3_COMPATIBLE:
return boto3.client(
"s3",
endpoint_url=credentials["endpoint_url"],
aws_access_key_id=credentials["aws_access_key_id"],
aws_secret_access_key=credentials["aws_secret_access_key"],
)
else:
raise ValueError(f"Unsupported bucket type: {bucket_type}")
@ -480,7 +505,7 @@ def get_file_ext(file_name: str) -> str:
def is_accepted_file_ext(file_ext: str, extension_type: OnyxExtensionType) -> bool:
image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}
image_extensions = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".webp"}
text_extensions = {".txt", ".md", ".mdx", ".conf", ".log", ".json", ".csv", ".tsv", ".xml", ".yml", ".yaml", ".sql"}
document_extensions = {".pdf", ".docx", ".pptx", ".xlsx", ".eml", ".epub", ".html"}
@ -902,6 +927,18 @@ def load_all_docs_from_checkpoint_connector(
)
_ATLASSIAN_CLOUD_DOMAINS = (".atlassian.net", ".jira.com", ".jira-dev.com")
def is_atlassian_cloud_url(url: str) -> bool:
try:
host = urlparse(url).hostname or ""
except ValueError:
return False
host = host.lower()
return any(host.endswith(domain) for domain in _ATLASSIAN_CLOUD_DOMAINS)
def get_cloudId(base_url: str) -> str:
tenant_info_url = urljoin(base_url, "/_edge/tenant_info")
response = requests.get(tenant_info_url, timeout=10)

View File

@ -80,4 +80,4 @@ def log_exception(e, *args):
raise Exception(a.text)
else:
logging.error(str(a))
raise e
raise e

View File

@ -21,7 +21,7 @@ import weakref
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FuturesTimeoutError
from string import Template
from typing import Any, Literal
from typing import Any, Literal, Protocol
from typing_extensions import override
@ -30,12 +30,15 @@ from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool
from rag.llm.chat_model import ToolCallSession
MCPTaskType = Literal["list_tools", "tool_call"]
MCPTask = tuple[MCPTaskType, dict[str, Any], asyncio.Queue[Any]]
class ToolCallSession(Protocol):
def tool_call(self, name: str, arguments: dict[str, Any]) -> str: ...
class MCPToolCallSession(ToolCallSession):
_ALL_INSTANCES: weakref.WeakSet["MCPToolCallSession"] = weakref.WeakSet()
@ -106,7 +109,8 @@ class MCPToolCallSession(ToolCallSession):
await self._process_mcp_tasks(None, msg)
else:
await self._process_mcp_tasks(None, f"Unsupported MCP server type: {self._mcp_server.server_type}, id: {self._mcp_server.id}")
await self._process_mcp_tasks(None,
f"Unsupported MCP server type: {self._mcp_server.server_type}, id: {self._mcp_server.id}")
async def _process_mcp_tasks(self, client_session: ClientSession | None, error_message: str | None = None) -> None:
while not self._close:
@ -164,7 +168,8 @@ class MCPToolCallSession(ToolCallSession):
raise
async def _call_mcp_tool(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> str:
result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments, timeout=timeout)
result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments,
timeout=timeout)
if result.isError:
return f"MCP server error: {result.content}"
@ -283,7 +288,8 @@ def close_multiple_mcp_toolcall_sessions(sessions: list[MCPToolCallSession]) ->
except Exception:
logging.exception("Exception during MCP session cleanup thread management")
logging.info(f"{len(sessions)} MCP sessions has been cleaned up. {len(list(MCPToolCallSession._ALL_INSTANCES))} in global context.")
logging.info(
f"{len(sessions)} MCP sessions has been cleaned up. {len(list(MCPToolCallSession._ALL_INSTANCES))} in global context.")
def shutdown_all_mcp_sessions():
@ -298,7 +304,7 @@ def shutdown_all_mcp_sessions():
logging.info("All MCPToolCallSession instances have been closed.")
def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool|dict) -> dict[str, Any]:
def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool | dict) -> dict[str, Any]:
if isinstance(mcp_tool, dict):
return {
"type": "function",

View File

@ -4839,6 +4839,639 @@
"is_tools": false
}
]
},
{
"name": "JieKou.AI",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,TEXT RE-RANK",
"status": "1",
"llm": [
{
"llm_name": "Sao10K/L3-8B-Stheno-v3.2",
"tags": "LLM,CHAT,8K",
"max_tokens": 8192,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "baichuan/baichuan-m2-32b",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "baidu/ernie-4.5-300b-a47b-paddle",
"tags": "LLM,CHAT,123K",
"max_tokens": 123000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "baidu/ernie-4.5-vl-424b-a47b",
"tags": "LLM,CHAT,123K",
"max_tokens": 123000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "claude-3-5-haiku-20241022",
"tags": "LLM,CHAT,200K",
"max_tokens": 200000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "claude-3-5-sonnet-20241022",
"tags": "LLM,CHAT,200K",
"max_tokens": 200000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "claude-3-7-sonnet-20250219",
"tags": "LLM,CHAT,200K",
"max_tokens": 200000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "claude-3-haiku-20240307",
"tags": "LLM,CHAT,200K",
"max_tokens": 200000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "claude-haiku-4-5-20251001",
"tags": "LLM,CHAT,20K,IMAGE2TEXT",
"max_tokens": 20000,
"model_type": "image2text",
"is_tools": true
},
{
"llm_name": "claude-opus-4-1-20250805",
"tags": "LLM,CHAT,200K",
"max_tokens": 200000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "claude-opus-4-20250514",
"tags": "LLM,CHAT,200K",
"max_tokens": 200000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "claude-sonnet-4-20250514",
"tags": "LLM,CHAT,200K",
"max_tokens": 200000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "claude-sonnet-4-5-20250929",
"tags": "LLM,CHAT,200K,IMAGE2TEXT",
"max_tokens": 200000,
"model_type": "image2text",
"is_tools": true
},
{
"llm_name": "deepseek/deepseek-r1-0528",
"tags": "LLM,CHAT,163K",
"max_tokens": 163840,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "deepseek/deepseek-v3-0324",
"tags": "LLM,CHAT,163K",
"max_tokens": 163840,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "deepseek/deepseek-v3.1",
"tags": "LLM,CHAT,163K",
"max_tokens": 163840,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "doubao-1-5-pro-32k-250115",
"tags": "LLM,CHAT,128K",
"max_tokens": 128000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "doubao-1.5-pro-32k-character-250715",
"tags": "LLM,CHAT,200K",
"max_tokens": 200000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gemini-2.0-flash-20250609",
"tags": "LLM,CHAT,1M",
"max_tokens": 1048576,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gemini-2.0-flash-lite",
"tags": "LLM,CHAT,1M",
"max_tokens": 1048576,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gemini-2.5-flash",
"tags": "LLM,CHAT,1M",
"max_tokens": 1048576,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gemini-2.5-flash-lite",
"tags": "LLM,CHAT,1M",
"max_tokens": 1048576,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gemini-2.5-flash-lite-preview-06-17",
"tags": "LLM,CHAT,1M",
"max_tokens": 1048576,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gemini-2.5-flash-lite-preview-09-2025",
"tags": "LLM,CHAT,1M,IMAGE2TEXT",
"max_tokens": 1048576,
"model_type": "image2text",
"is_tools": true
},
{
"llm_name": "gemini-2.5-flash-preview-05-20",
"tags": "LLM,CHAT,1M",
"max_tokens": 1048576,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gemini-2.5-pro",
"tags": "LLM,CHAT,1M",
"max_tokens": 1048576,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gemini-2.5-pro-preview-06-05",
"tags": "LLM,CHAT,1M",
"max_tokens": 1048576,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "google/gemma-3-12b-it",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "google/gemma-3-27b-it",
"tags": "LLM,CHAT,32K",
"max_tokens": 32768,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gpt-4.1",
"tags": "LLM,CHAT,1M",
"max_tokens": 1047576,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gpt-4.1-mini",
"tags": "LLM,CHAT,1M",
"max_tokens": 1047576,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gpt-4.1-nano",
"tags": "LLM,CHAT,1M",
"max_tokens": 1047576,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gpt-4o",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gpt-4o-mini",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gpt-5",
"tags": "LLM,CHAT,400K",
"max_tokens": 400000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gpt-5-chat-latest",
"tags": "LLM,CHAT,400K",
"max_tokens": 400000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gpt-5-codex",
"tags": "LLM,CHAT,400K,IMAGE2TEXT",
"max_tokens": 400000,
"model_type": "image2text",
"is_tools": true
},
{
"llm_name": "gpt-5-mini",
"tags": "LLM,CHAT,400K",
"max_tokens": 400000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gpt-5-nano",
"tags": "LLM,CHAT,400K",
"max_tokens": 400000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gpt-5-pro",
"tags": "LLM,CHAT,400K,IMAGE2TEXT",
"max_tokens": 400000,
"model_type": "image2text",
"is_tools": true
},
{
"llm_name": "gpt-5.1",
"tags": "LLM,CHAT,400K",
"max_tokens": 400000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gpt-5.1-chat-latest",
"tags": "LLM,CHAT,128K",
"max_tokens": 128000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gpt-5.1-codex",
"tags": "LLM,CHAT,400K",
"max_tokens": 400000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "grok-3",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "grok-3-mini",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "grok-4-0709",
"tags": "LLM,CHAT,256K",
"max_tokens": 256000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "grok-4-fast-non-reasoning",
"tags": "LLM,CHAT,2M,IMAGE2TEXT",
"max_tokens": 2000000,
"model_type": "image2text",
"is_tools": true
},
{
"llm_name": "grok-4-fast-reasoning",
"tags": "LLM,CHAT,2M,IMAGE2TEXT",
"max_tokens": 2000000,
"model_type": "image2text",
"is_tools": true
},
{
"llm_name": "grok-code-fast-1",
"tags": "LLM,CHAT,256K",
"max_tokens": 256000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gryphe/mythomax-l2-13b",
"tags": "LLM,CHAT,4K",
"max_tokens": 4096,
"model_type": "chat",
"is_tools": false
},
{
"llm_name": "meta-llama/llama-3.1-8b-instruct",
"tags": "LLM,CHAT,16K",
"max_tokens": 16384,
"model_type": "chat",
"is_tools": false
},
{
"llm_name": "meta-llama/llama-3.2-3b-instruct",
"tags": "LLM,CHAT,32K",
"max_tokens": 32768,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "meta-llama/llama-3.3-70b-instruct",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "meta-llama/llama-4-maverick-17b-128e-instruct-fp8",
"tags": "LLM,CHAT,1M",
"max_tokens": 1048576,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "meta-llama/llama-4-scout-17b-16e-instruct",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "minimaxai/minimax-m1-80k",
"tags": "LLM,CHAT,1M",
"max_tokens": 1000000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "mistralai/mistral-7b-instruct",
"tags": "LLM,CHAT,32K",
"max_tokens": 32768,
"model_type": "chat",
"is_tools": false
},
{
"llm_name": "mistralai/mistral-nemo",
"tags": "LLM,CHAT,60K",
"max_tokens": 60288,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "moonshotai/kimi-k2-0905",
"tags": "LLM,CHAT,262K",
"max_tokens": 262144,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "moonshotai/kimi-k2-instruct",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "o1",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "o1-mini",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "o3",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "o3-mini",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "openai/gpt-oss-120b",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "openai/gpt-oss-20b",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "qwen/qwen-2.5-72b-instruct",
"tags": "LLM,CHAT,32K",
"max_tokens": 32000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "qwen/qwen-mt-plus",
"tags": "LLM,CHAT,4K",
"max_tokens": 4096,
"model_type": "chat",
"is_tools": false
},
{
"llm_name": "qwen/qwen2.5-7b-instruct",
"tags": "LLM,CHAT,32K",
"max_tokens": 32000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "qwen/qwen2.5-vl-72b-instruct",
"tags": "LLM,CHAT,32K",
"max_tokens": 32768,
"model_type": "chat",
"is_tools": false
},
{
"llm_name": "qwen/qwen3-235b-a22b-fp8",
"tags": "LLM,CHAT,40K",
"max_tokens": 40960,
"model_type": "chat",
"is_tools": false
},
{
"llm_name": "qwen/qwen3-235b-a22b-instruct-2507",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "qwen/qwen3-235b-a22b-thinking-2507",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "qwen/qwen3-30b-a3b-fp8",
"tags": "LLM,CHAT,40K",
"max_tokens": 40960,
"model_type": "chat",
"is_tools": false
},
{
"llm_name": "qwen/qwen3-32b-fp8",
"tags": "LLM,CHAT,40K",
"max_tokens": 40960,
"model_type": "chat",
"is_tools": false
},
{
"llm_name": "qwen/qwen3-8b-fp8",
"tags": "LLM,CHAT,128K",
"max_tokens": 128000,
"model_type": "chat",
"is_tools": false
},
{
"llm_name": "qwen/qwen3-coder-480b-a35b-instruct",
"tags": "LLM,CHAT,262K",
"max_tokens": 262144,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "qwen/qwen3-next-80b-a3b-instruct",
"tags": "LLM,CHAT,65K",
"max_tokens": 65536,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "qwen/qwen3-next-80b-a3b-thinking",
"tags": "LLM,CHAT,65K",
"max_tokens": 65536,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "sao10k/l3-70b-euryale-v2.1",
"tags": "LLM,CHAT,8K",
"max_tokens": 8192,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "sao10k/l3-8b-lunaris",
"tags": "LLM,CHAT,8K",
"max_tokens": 8192,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "sao10k/l31-70b-euryale-v2.2",
"tags": "LLM,CHAT,8K",
"max_tokens": 8192,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "thudm/glm-4.1v-9b-thinking",
"tags": "LLM,CHAT,65K",
"max_tokens": 65536,
"model_type": "chat",
"is_tools": false
},
{
"llm_name": "zai-org/glm-4.5",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "zai-org/glm-4.5v",
"tags": "LLM,CHAT,65K",
"max_tokens": 65536,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "baai/bge-m3",
"tags": "TEXT EMBEDDING,8K",
"max_tokens": 8192,
"model_type": "embedding"
},
{
"llm_name": "qwen/qwen3-embedding-0.6b",
"tags": "TEXT EMBEDDING,32K",
"max_tokens": 32768,
"model_type": "embedding"
},
{
"llm_name": "qwen/qwen3-embedding-8b",
"tags": "TEXT EMBEDDING,32K",
"max_tokens": 32768,
"model_type": "embedding"
},
{
"llm_name": "baai/bge-reranker-v2-m3",
"tags": "RE-RANK,8K",
"max_tokens": 8000,
"model_type": "reranker"
},
{
"llm_name": "qwen/qwen3-reranker-8b",
"tags": "RE-RANK,32K",
"max_tokens": 32768,
"model_type": "reranker"
}
]
}
]
}

View File

@ -186,9 +186,6 @@ class DoclingParser(RAGFlowPdfParser):
yield (DoclingContentType.EQUATION.value, text, bbox)
def _transfer_to_sections(self, doc) -> list[tuple[str, str]]:
"""
和 MinerUParser 保持一致:返回 [(section_text, line_tag), ...]
"""
sections: list[tuple[str, str]] = []
for typ, payload, bbox in self._iter_doc_items(doc):
if typ == DoclingContentType.TEXT.value:

View File

@ -34,6 +34,7 @@ def vision_figure_parser_figure_data_wrapper(figures_data_without_positions):
if isinstance(figure_data[1], Image.Image)
]
def vision_figure_parser_docx_wrapper(sections,tbls,callback=None,**kwargs):
try:
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
@ -50,7 +51,8 @@ def vision_figure_parser_docx_wrapper(sections,tbls,callback=None,**kwargs):
callback(0.8, f"Visual model error: {e}. Skipping figure parsing enhancement.")
return tbls
def vision_figure_parser_pdf_wrapper(tbls,callback=None,**kwargs):
def vision_figure_parser_pdf_wrapper(tbls, callback=None, **kwargs):
try:
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
callback(0.7, "Visual model detected. Attempting to enhance figure extraction...")
@ -72,6 +74,7 @@ def vision_figure_parser_pdf_wrapper(tbls,callback=None,**kwargs):
callback(0.8, f"Visual model error: {e}. Skipping figure parsing enhancement.")
return tbls
shared_executor = ThreadPoolExecutor(max_workers=10)

View File

@ -434,7 +434,7 @@ class MinerUParser(RAGFlowPdfParser):
if not section.strip():
section = "FAILED TO PARSE TABLE"
case MinerUContentType.IMAGE:
section = "".join(output["image_caption"]) + "\n" + "".join(output["image_footnote"])
section = "".join(output.get("image_caption", [])) + "\n" + "".join(output.get("image_footnote", []))
case MinerUContentType.EQUATION:
section = output["text"]
case MinerUContentType.CODE:

View File

@ -117,7 +117,6 @@ def load_model(model_dir, nm, device_id: int | None = None):
providers=['CUDAExecutionProvider'],
provider_options=[cuda_provider_options]
)
run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "gpu:" + str(provider_device_id))
logging.info(f"load_model {model_file_path} uses GPU (device {provider_device_id}, gpu_mem_limit={cuda_provider_options['gpu_mem_limit']}, arena_strategy={arena_strategy})")
else:
sess = ort.InferenceSession(

View File

@ -71,7 +71,7 @@ for arg in "$@"; do
ENABLE_TASKEXECUTOR=0
shift
;;
--disable-datasyn)
--disable-datasync)
ENABLE_DATASYNC=0
shift
;;

View File

@ -12,6 +12,10 @@ The RAGFlow Admin UI is a web-based interface that provides comprehensive system
To access the RAGFlow admin UI, append `/admin` to the web UI's address, e.g. `http://[RAGFLOW_WEB_UI_ADDR]/admin`, replace `[RAGFLOW_WEB_UI_ADDR]` with real RAGFlow web UI address.
### Default Credentials
| Username | Password |
|----------|----------|
| `admin@ragflow.io` | `admin` |
## Admin UI Overview

View File

@ -0,0 +1,8 @@
{
"label": "Add data source",
"position": 18,
"link": {
"type": "generated-index",
"description": "Add various data sources"
}
}

View File

@ -0,0 +1,137 @@
---
sidebar_position: 3
slug: /add_google_drive
---
# Add Google Drive
## 1. Create a Google Cloud Project
You can either create a dedicated project for RAGFlow or use an existing
Google Cloud external project.
**Steps:**
1. Open the project creation page\
`https://console.cloud.google.com/projectcreate`
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image1.jpeg?raw=true)
2. Select **External** as the Audience
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image2.png?raw=true)
3. Click **Create**
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image3.jpeg?raw=true)
------------------------------------------------------------------------
## 2. Configure OAuth Consent Screen
1. Go to **APIs & Services → OAuth consent screen**
2. Ensure **User Type = External**
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image4.jpeg?raw=true)
3. Add your test users under **Test Users** by entering email addresses
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image5.jpeg?raw=true)
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image6.jpeg?raw=true)
------------------------------------------------------------------------
## 3. Create OAuth Client Credentials
1. Navigate to:\
`https://console.cloud.google.com/auth/clients`
2. Create a **Web Application**
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image7.png?raw=true)
3. Enter a name for the client
4. Add the following **Authorized Redirect URIs**:
```
http://localhost:9380/v1/connector/google-drive/oauth/web/callback
```
### If using Docker deployment:
**Authorized JavaScript origin:**
```
http://localhost:80
```
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image8.png?raw=true)
### If running from source:
**Authorized JavaScript origin:**
```
http://localhost:9222
```
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image9.png?raw=true)
5. After saving, click **Download JSON**. This file will later be
uploaded into RAGFlow.
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image10.png?raw=true)
------------------------------------------------------------------------
## 4. Add Scopes
1. Open **Data Access → Add or remove scopes**
2. Paste and add the following entries:
```
https://www.googleapis.com/auth/drive.readonly
https://www.googleapis.com/auth/drive.metadata.readonly
https://www.googleapis.com/auth/admin.directory.group.readonly
https://www.googleapis.com/auth/admin.directory.user.readonly
```
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image11.jpeg?raw=true)
3. Update and Save changes
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image12.jpeg?raw=true)
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image13.jpeg?raw=true)
------------------------------------------------------------------------
## 5. Enable Required APIs
Navigate to the Google API Library:\
`https://console.cloud.google.com/apis/library`
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image14.png?raw=true)
Enable the following APIs:
- Google Drive API
- Admin SDK API
- Google Sheets API
- Google Docs API
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image15.png?raw=true)
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image16.png?raw=true)
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image17.png?raw=true)
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image18.png?raw=true)
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image19.png?raw=true)
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image21.png?raw=true)
------------------------------------------------------------------------
## 6. Add Google Drive As a Data Source in RAGFlow
1. Go to **Data Sources** inside RAGFlow
2. Select **Google Drive**
3. Upload the previously downloaded JSON credentials
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image22.jpeg?raw=true)
4. Enter the shared Google Drive folder link (https://drive.google.com/drive), such as:
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image23.png?raw=true)
5. Click **Authorize with Google**
A browser window will appear.
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image25.jpeg?raw=true)
Click: - **Continue** - **Select All → Continue** - Authorization should
succeed - Select **OK** to add the data source
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image26.jpeg?raw=true)
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image27.jpeg?raw=true)
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image28.png?raw=true)
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image29.png?raw=true)

View File

@ -1,6 +1,6 @@
{
"label": "Best practices",
"position": 11,
"position": 19,
"link": {
"type": "generated-index",
"description": "Best practices on configuring a dataset."

View File

@ -64,7 +64,10 @@ The Admin CLI and Admin Service form a client-server architectural suite for RAG
- -p: RAGFlow admin server port
## Default administrative account
- Username: admin@ragflow.io
- Password: admin
## Supported Commands

View File

@ -974,6 +974,237 @@ Failure:
---
### Construct knowledge graph
**POST** `/api/v1/datasets/{dataset_id}/run_graphrag`
Constructs a knowledge graph from a specified dataset.
#### Request
- Method: POST
- URL: `/api/v1/datasets/{dataset_id}/run_graphrag`
- Headers:
- `'Authorization: Bearer <YOUR_API_KEY>'`
##### Request example
```bash
curl --request POST \
--url http://{address}/api/v1/datasets/{dataset_id}/run_graphrag \
--header 'Authorization: Bearer <YOUR_API_KEY>'
```
##### Request parameters
- `dataset_id`: (*Path parameter*)
The ID of the target dataset.
#### Response
Success:
```json
{
"code":0,
"data":{
"graphrag_task_id":"e498de54bfbb11f0ba028f704583b57b"
}
}
```
Failure:
```json
{
"code": 102,
"message": "Invalid Dataset ID"
}
```
---
### Get knowledge graph construction status
**GET** `/api/v1/datasets/{dataset_id}/trace_graphrag`
Retrieves the knowledge graph construction status for a specified dataset.
#### Request
- Method: GET
- URL: `/api/v1/datasets/{dataset_id}/trace_graphrag`
- Headers:
- `'Authorization: Bearer <YOUR_API_KEY>'`
##### Request example
```bash
curl --request GET \
--url http://{address}/api/v1/datasets/{dataset_id}/trace_graphrag \
--header 'Authorization: Bearer <YOUR_API_KEY>'
```
##### Request parameters
- `dataset_id`: (*Path parameter*)
The ID of the target dataset.
#### Response
Success:
```json
{
"code":0,
"data":{
"begin_at":"Wed, 12 Nov 2025 19:36:56 GMT",
"chunk_ids":"",
"create_date":"Wed, 12 Nov 2025 19:36:56 GMT",
"create_time":1762947416350,
"digest":"39e43572e3dcd84f",
"doc_id":"44661c10bde211f0bc93c164a47ffc40",
"from_page":100000000,
"id":"e498de54bfbb11f0ba028f704583b57b",
"priority":0,
"process_duration":2.45419,
"progress":1.0,
"progress_msg":"19:36:56 created task graphrag\n19:36:57 Task has been received.\n19:36:58 [GraphRAG] doc:083661febe2411f0bc79456921e5745f has no available chunks, skip generation.\n19:36:58 [GraphRAG] build_subgraph doc:44661c10bde211f0bc93c164a47ffc40 start (chunks=1, timeout=10000000000s)\n19:36:58 Graph already contains 44661c10bde211f0bc93c164a47ffc40\n19:36:58 [GraphRAG] build_subgraph doc:44661c10bde211f0bc93c164a47ffc40 empty\n19:36:58 [GraphRAG] kb:33137ed0bde211f0bc93c164a47ffc40 no subgraphs generated successfully, end.\n19:36:58 Knowledge Graph done (0.72s)","retry_count":1,
"task_type":"graphrag",
"to_page":100000000,
"update_date":"Wed, 12 Nov 2025 19:36:58 GMT",
"update_time":1762947418454
}
}
```
Failure:
```json
{
"code": 102,
"message": "Invalid Dataset ID"
}
```
---
### Construct RAPTOR
**POST** `/api/v1/datasets/{dataset_id}/run_raptor`
Construct a RAPTOR from a specified dataset.
#### Request
- Method: POST
- URL: `/api/v1/datasets/{dataset_id}/run_raptor`
- Headers:
- `'Authorization: Bearer <YOUR_API_KEY>'`
##### Request example
```bash
curl --request POST \
--url http://{address}/api/v1/datasets/{dataset_id}/run_raptor \
--header 'Authorization: Bearer <YOUR_API_KEY>'
```
##### Request parameters
- `dataset_id`: (*Path parameter*)
The ID of the target dataset.
#### Response
Success:
```json
{
"code":0,
"data":{
"raptor_task_id":"50d3c31cbfbd11f0ba028f704583b57b"
}
}
```
Failure:
```json
{
"code": 102,
"message": "Invalid Dataset ID"
}
```
---
### Get RAPTOR construction status
**GET** `/api/v1/datasets/{dataset_id}/trace_raptor`
Retrieves the RAPTOR construction status for a specified dataset.
#### Request
- Method: GET
- URL: `/api/v1/datasets/{dataset_id}/trace_raptor`
- Headers:
- `'Authorization: Bearer <YOUR_API_KEY>'`
##### Request example
```bash
curl --request GET \
--url http://{address}/api/v1/datasets/{dataset_id}/trace_raptor \
--header 'Authorization: Bearer <YOUR_API_KEY>'
```
##### Request parameters
- `dataset_id`: (*Path parameter*)
The ID of the target dataset.
#### Response
Success:
```json
{
"code":0,
"data":{
"begin_at":"Wed, 12 Nov 2025 19:47:07 GMT",
"chunk_ids":"",
"create_date":"Wed, 12 Nov 2025 19:47:07 GMT",
"create_time":1762948027427,
"digest":"8b279a6248cb8fc6",
"doc_id":"44661c10bde211f0bc93c164a47ffc40",
"from_page":100000000,
"id":"50d3c31cbfbd11f0ba028f704583b57b",
"priority":0,
"process_duration":0.948244,
"progress":1.0,
"progress_msg":"19:47:07 created task raptor\n19:47:07 Task has been received.\n19:47:07 Processing...\n19:47:07 Processing...\n19:47:07 Indexing done (0.01s).\n19:47:07 Task done (0.29s)",
"retry_count":1,
"task_type":"raptor",
"to_page":100000000,
"update_date":"Wed, 12 Nov 2025 19:47:07 GMT",
"update_time":1762948027948
}
}
```
Failure:
```json
{
"code": 102,
"message": "Invalid Dataset ID"
}
```
---
## FILE MANAGEMENT WITHIN DATASET
---

View File

@ -67,6 +67,7 @@ A complete list of models supported by RAGFlow, which will continue to expand.
| 302.AI | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | |
| CometAPI | :heavy_check_mark: | :heavy_check_mark: | | | | |
| DeerAPI | :heavy_check_mark: | :heavy_check_mark: | | :heavy_check_mark: | | :heavy_check_mark: |
| Jiekou.AI | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | | |
```mdx-code-block
</APITable>

View File

@ -693,7 +693,7 @@ Released on August 26, 2024.
- Incorporates monitoring for the task executor.
- Introduces Agent tools **GitHub**, **DeepL**, **BaiduFanyi**, **QWeather**, and **GoogleScholar**.
- Supports chunking of EML files.
- Supports more LLMs or model services: **GPT-4o-mini**, **PerfXCloud**, **TogetherAI**, **Upstage**, **Novita AI**, **01.AI**, **SiliconFlow**, **PPIO**, **XunFei Spark**, **Baidu Yiyan**, and **Tencent Hunyuan**.
- Supports more LLMs or model services: **GPT-4o-mini**, **PerfXCloud**, **TogetherAI**, **Upstage**, **Novita AI**, **01.AI**, **SiliconFlow**, **PPIO**, **XunFei Spark**, **Jiekou.AI**, **Baidu Yiyan**, and **Tencent Hunyuan**.
## v0.9.0

View File

@ -114,7 +114,7 @@ class Extractor:
async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""):
out_results = []
error_count = 0
max_errors = 3
max_errors = int(os.environ.get("GRAPHRAG_MAX_ERRORS", 3))
limiter = trio.Semaphore(max_concurrency)

View File

@ -69,7 +69,7 @@ class KGSearch(Dealer):
def _ent_info_from_(self, es_res, sim_thr=0.3):
res = {}
flds = ["content_with_weight", "_score", "entity_kwd", "rank_flt", "n_hop_with_weight"]
es_res = self.dataStore.getFields(es_res, flds)
es_res = self.dataStore.get_fields(es_res, flds)
for _, ent in es_res.items():
for f in flds:
if f in ent and ent[f] is None:
@ -88,7 +88,7 @@ class KGSearch(Dealer):
def _relation_info_from_(self, es_res, sim_thr=0.3):
res = {}
es_res = self.dataStore.getFields(es_res, ["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd",
es_res = self.dataStore.get_fields(es_res, ["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd",
"weight_int"])
for _, ent in es_res.items():
if get_float(ent["_score"]) < sim_thr:
@ -300,7 +300,7 @@ class KGSearch(Dealer):
fltr["entities_kwd"] = entities
comm_res = self.dataStore.search(fields, [], fltr, [],
OrderByExpr(), 0, topn, idxnms, kb_ids)
comm_res_fields = self.dataStore.getFields(comm_res, fields)
comm_res_fields = self.dataStore.get_fields(comm_res, fields)
txts = []
for ii, (_, row) in enumerate(comm_res_fields.items()):
obj = json.loads(row["content_with_weight"])

View File

@ -382,7 +382,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
"removed_kwd": "N",
}
res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
fields2 = settings.docStoreConn.getFields(res, fields)
fields2 = settings.docStoreConn.get_fields(res, fields)
graph_doc_ids = set()
for chunk_id in fields2.keys():
graph_doc_ids = set(fields2[chunk_id]["source_id"])
@ -591,8 +591,8 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None):
es_res = await trio.to_thread.run_sync(
lambda: settings.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id])
)
# tot = settings.docStoreConn.getTotal(es_res)
es_res = settings.docStoreConn.getFields(es_res, flds)
# tot = settings.docStoreConn.get_total(es_res)
es_res = settings.docStoreConn.get_fields(es_res, flds)
if len(es_res) == 0:
break

View File

@ -145,6 +145,7 @@ dependencies = [
"markdownify>=1.2.0",
"captcha>=0.7.1",
"pip>=25.2",
"pypandoc>=1.16",
]
[dependency-groups]

View File

@ -49,6 +49,7 @@ class SupportedLiteLLMProvider(StrEnum):
Lingyi_AI = "01.AI"
GiteeAI = "GiteeAI"
AI_302 = "302.AI"
JiekouAI = "Jiekou.AI"
FACTORY_DEFAULT_BASE_URL = {
@ -69,6 +70,7 @@ FACTORY_DEFAULT_BASE_URL = {
SupportedLiteLLMProvider.GiteeAI: "https://ai.gitee.com/v1/",
SupportedLiteLLMProvider.AI_302: "https://api.302.ai/v1",
SupportedLiteLLMProvider.Anthropic: "https://api.anthropic.com/",
SupportedLiteLLMProvider.JiekouAI: "https://api.jiekou.ai/openai",
}
@ -99,6 +101,7 @@ LITELLM_PROVIDER_PREFIX = {
SupportedLiteLLMProvider.Lingyi_AI: "openai/",
SupportedLiteLLMProvider.GiteeAI: "openai/",
SupportedLiteLLMProvider.AI_302: "openai/",
SupportedLiteLLMProvider.JiekouAI: "openai/",
}
ChatModel = globals().get("ChatModel", {})

View File

@ -22,7 +22,6 @@ import re
import time
from abc import ABC
from copy import deepcopy
from typing import Any, Protocol
from urllib.parse import urljoin
import json_repair
@ -65,10 +64,6 @@ LENGTH_NOTIFICATION_CN = "······\n由于大模型的上下文窗口大小
LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to its limitation on context length."
class ToolCallSession(Protocol):
def tool_call(self, name: str, arguments: dict[str, Any]) -> str: ...
class Base(ABC):
def __init__(self, key, model_name, base_url, **kwargs):
timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
@ -1402,6 +1397,7 @@ class LiteLLMBase(ABC):
"01.AI",
"GiteeAI",
"302.AI",
"Jiekou.AI",
]
def __init__(self, key, model_name, base_url=None, **kwargs):

View File

@ -14,6 +14,7 @@
# limitations under the License.
#
import re
import base64
import json
import os
@ -32,7 +33,6 @@ from rag.nlp import is_english
from rag.prompts.generator import vision_llm_describe_prompt
from common.token_utils import num_tokens_from_string, total_token_count_from_response
class Base(ABC):
def __init__(self, **kwargs):
# Configure retry parameters
@ -208,6 +208,7 @@ class GptV4(Base):
model=self.model_name,
messages=self.prompt(b64),
extra_body=self.extra_body,
unused = None,
)
return res.choices[0].message.content.strip(), total_token_count_from_response(res)
@ -324,6 +325,122 @@ class Zhipu4V(GptV4):
Base.__init__(self, **kwargs)
def _clean_conf(self, gen_conf):
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
gen_conf = self._clean_conf_plealty(gen_conf)
return gen_conf
def _clean_conf_plealty(self, gen_conf):
if "presence_penalty" in gen_conf:
del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
del gen_conf["frequency_penalty"]
return gen_conf
def _request(self, msg, stream, gen_conf={}):
response = requests.post(
self.base_url,
json={
"model": self.model_name,
"messages": msg,
"stream": stream,
**gen_conf
},
headers= {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
)
return response.json()
def chat(self, system, history, gen_conf, images=None, stream=False, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
logging.info(json.dumps(history, ensure_ascii=False, indent=2))
response = self.client.chat.completions.create(model=self.model_name, messages=self._form_history(system, history, images), stream=False, **gen_conf)
content = response.choices[0].message.content.strip()
cleaned = re.sub(r"<\|(begin_of_box|end_of_box)\|>", "", content).strip()
return cleaned, total_token_count_from_response(response)
def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
from rag.llm.chat_model import LENGTH_NOTIFICATION_CN, LENGTH_NOTIFICATION_EN
from rag.nlp import is_chinese
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
ans = ""
tk_count = 0
try:
logging.info(json.dumps(history, ensure_ascii=False, indent=2))
response = self.client.chat.completions.create(model=self.model_name, messages=self._form_history(system, history, images), stream=True, **gen_conf)
for resp in response:
if not resp.choices[0].delta.content:
continue
delta = resp.choices[0].delta.content
ans = delta
if resp.choices[0].finish_reason == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
tk_count = total_token_count_from_response(resp)
if resp.choices[0].finish_reason == "stop":
tk_count = total_token_count_from_response(resp)
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
yield tk_count
def describe(self, image):
return self.describe_with_prompt(image)
def describe_with_prompt(self, image, prompt=None):
b64 = self.image2base64(image)
if prompt is None:
prompt = "Describe this image."
# Chat messages
messages = [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": { "url": b64 }
},
{
"type": "text",
"text": prompt
}
]
}
]
resp = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
stream=False
)
content = resp.choices[0].message.content.strip()
cleaned = re.sub(r"<\|(begin_of_box|end_of_box)\|>", "", content).strip()
return cleaned, num_tokens_from_string(cleaned)
class StepFunCV(GptV4):
_FACTORY_NAME = "StepFun"

View File

@ -931,3 +931,12 @@ class DeerAPIEmbed(OpenAIEmbed):
if not base_url:
base_url = "https://api.deerapi.com/v1"
super().__init__(key, model_name, base_url)
class JiekouAIEmbed(OpenAIEmbed):
_FACTORY_NAME = "Jiekou.AI"
def __init__(self, key, model_name, base_url="https://api.jiekou.ai/openai/v1/embeddings"):
if not base_url:
base_url = "https://api.jiekou.ai/openai/v1/embeddings"
super().__init__(key, model_name, base_url)

View File

@ -489,3 +489,12 @@ class Ai302Rerank(Base):
if not base_url:
base_url = "https://api.302.ai/v1/rerank"
super().__init__(key, model_name, base_url)
class JiekouAIRerank(JinaRerank):
_FACTORY_NAME = "Jiekou.AI"
def __init__(self, key, model_name, base_url="https://api.jiekou.ai/openai/v1/rerank"):
if not base_url:
base_url = "https://api.jiekou.ai/openai/v1/rerank"
super().__init__(key, model_name, base_url)

View File

@ -155,13 +155,13 @@ def qbullets_category(sections):
if re.match(pro, sec) and not not_bullet(sec):
hits[i] += 1
break
maxium = 0
maximum = 0
res = -1
for i, h in enumerate(hits):
if h <= maxium:
if h <= maximum:
continue
res = i
maxium = h
maximum = h
return res, QUESTION_PATTERN[res]
@ -222,13 +222,13 @@ def bullets_category(sections):
if re.match(p, sec) and not not_bullet(sec):
hits[i] += 1
break
maxium = 0
maximum = 0
res = -1
for i, h in enumerate(hits):
if h <= maxium:
if h <= maximum:
continue
res = i
maxium = h
maximum = h
return res
@ -482,7 +482,7 @@ def tree_merge(bull, sections, depth):
root = Node(level=0, depth=target_level, texts=[])
root.build_tree(lines)
return [("\n").join(element) for element in root.get_tree() if element]
return [element for element in root.get_tree() if element]
def hierarchical_merge(bull, sections, depth):
@ -723,47 +723,40 @@ def naive_merge_docx(sections, chunk_token_num=128, delimiter="\n。"):
if not sections:
return [], []
cks = [""]
images = [None]
tk_nums = [0]
cks = []
images = []
tk_nums = []
def add_chunk(t, image, pos=""):
nonlocal cks, tk_nums, delimiter
nonlocal cks, images, tk_nums
tnum = num_tokens_from_string(t)
if tnum < 8:
pos = ""
if cks[-1] == "" or tk_nums[-1] > chunk_token_num:
if t.find(pos) < 0:
if not cks or tk_nums[-1] > chunk_token_num:
# new chunk
if pos and t.find(pos) < 0:
t += pos
cks.append(t)
images.append(image)
tk_nums.append(tnum)
else:
if cks[-1].find(pos) < 0:
# add to last chunk
if pos and cks[-1].find(pos) < 0:
t += pos
cks[-1] += t
images[-1] = concat_img(images[-1], image)
tk_nums[-1] += tnum
dels = get_delimiters(delimiter)
line = ""
for sec, image in sections:
if not image:
line += sec + "\n"
continue
split_sec = re.split(r"(%s)" % dels, line + sec)
for sub_sec in split_sec:
if re.match(f"^{dels}$", sub_sec):
continue
add_chunk("\n"+sub_sec, image,"")
line = ""
pattern = r"(%s)" % dels
if line:
split_sec = re.split(r"(%s)" % dels, line)
for sec, image in sections:
split_sec = re.split(pattern, sec)
for sub_sec in split_sec:
if re.match(f"^{dels}$", sub_sec):
if not sub_sec or re.match(f"^{dels}$", sub_sec):
continue
add_chunk("\n"+sub_sec, image,"")
add_chunk("\n" + sub_sec, image, "")
return cks, images

View File

@ -38,11 +38,11 @@ class FulltextQueryer:
]
@staticmethod
def subSpecialChar(line):
def sub_special_char(line):
return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|\+~\^])", r"\\\1", line).strip()
@staticmethod
def isChinese(line):
def is_chinese(line):
arr = re.split(r"[ \t]+", line)
if len(arr) <= 3:
return True
@ -92,7 +92,7 @@ class FulltextQueryer:
otxt = txt
txt = FulltextQueryer.rmWWW(txt)
if not self.isChinese(txt):
if not self.is_chinese(txt):
txt = FulltextQueryer.rmWWW(txt)
tks = rag_tokenizer.tokenize(txt).split()
keywords = [t for t in tks if t]
@ -163,7 +163,7 @@ class FulltextQueryer:
)
for m in sm
]
sm = [FulltextQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
sm = [FulltextQueryer.sub_special_char(m) for m in sm if len(m) > 1]
sm = [m for m in sm if len(m) > 1]
if len(keywords) < 32:
@ -171,7 +171,7 @@ class FulltextQueryer:
keywords.extend(sm)
tk_syns = self.syn.lookup(tk)
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
tk_syns = [FulltextQueryer.sub_special_char(s) for s in tk_syns]
if len(keywords) < 32:
keywords.extend([s for s in tk_syns if s])
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
@ -180,7 +180,7 @@ class FulltextQueryer:
if len(keywords) >= 32:
break
tk = FulltextQueryer.subSpecialChar(tk)
tk = FulltextQueryer.sub_special_char(tk)
if tk.find(" ") > 0:
tk = '"%s"' % tk
if tk_syns:
@ -198,7 +198,7 @@ class FulltextQueryer:
syns = " OR ".join(
[
'"%s"'
% rag_tokenizer.tokenize(FulltextQueryer.subSpecialChar(s))
% rag_tokenizer.tokenize(FulltextQueryer.sub_special_char(s))
for s in syns
]
)
@ -217,17 +217,17 @@ class FulltextQueryer:
return None, keywords
def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, vtweight=0.7):
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
sims = CosineSimilarity([avec], bvecs)
sims = cosine_similarity([avec], bvecs)
tksim = self.token_similarity(atks, btkss)
if np.sum(sims[0]) == 0:
return np.array(tksim), tksim, sims[0]
return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0]
def token_similarity(self, atks, btkss):
def toDict(tks):
def to_dict(tks):
if isinstance(tks, str):
tks = tks.split()
d = defaultdict(int)
@ -236,8 +236,8 @@ class FulltextQueryer:
d[t] += c
return d
atks = toDict(atks)
btkss = [toDict(tks) for tks in btkss]
atks = to_dict(atks)
btkss = [to_dict(tks) for tks in btkss]
return [self.similarity(atks, btks) for btks in btkss]
def similarity(self, qtwt, dtwt):
@ -262,10 +262,10 @@ class FulltextQueryer:
keywords = [f'"{k.strip()}"' for k in keywords]
for tk, w in sorted(tks_w, key=lambda x: x[1] * -1)[:keywords_topn]:
tk_syns = self.syn.lookup(tk)
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
tk_syns = [FulltextQueryer.sub_special_char(s) for s in tk_syns]
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
tk = FulltextQueryer.subSpecialChar(tk)
tk = FulltextQueryer.sub_special_char(tk)
if tk.find(" ") > 0:
tk = '"%s"' % tk
if tk_syns:

View File

@ -35,7 +35,7 @@ class RagTokenizer:
def rkey_(self, line):
return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]
def loadDict_(self, fnm):
def _load_dict(self, fnm):
logging.info(f"[HUQIE]:Build trie from {fnm}")
try:
of = open(fnm, "r", encoding='utf-8')
@ -85,18 +85,18 @@ class RagTokenizer:
self.trie_ = datrie.Trie(string.printable)
# load data from dict file and save to trie file
self.loadDict_(self.DIR_ + ".txt")
self._load_dict(self.DIR_ + ".txt")
def loadUserDict(self, fnm):
def load_user_dict(self, fnm):
try:
self.trie_ = datrie.Trie.load(fnm + ".trie")
return
except Exception:
self.trie_ = datrie.Trie(string.printable)
self.loadDict_(fnm)
self._load_dict(fnm)
def addUserDict(self, fnm):
self.loadDict_(fnm)
def add_user_dict(self, fnm):
self._load_dict(fnm)
def _strQ2B(self, ustring):
"""Convert full-width characters to half-width characters"""
@ -221,7 +221,7 @@ class RagTokenizer:
logging.debug("[SC] {} {} {} {} {}".format(tks, len(tks), L, F, B / len(tks) + L + F))
return tks, B / len(tks) + L + F
def sortTks_(self, tkslist):
def _sort_tokens(self, tkslist):
res = []
for tfts in tkslist:
tks, s = self.score_(tfts)
@ -246,7 +246,7 @@ class RagTokenizer:
return " ".join(res)
def maxForward_(self, line):
def _max_forward(self, line):
res = []
s = 0
while s < len(line):
@ -270,7 +270,7 @@ class RagTokenizer:
return self.score_(res)
def maxBackward_(self, line):
def _max_backward(self, line):
res = []
s = len(line) - 1
while s >= 0:
@ -336,8 +336,8 @@ class RagTokenizer:
continue
# use maxforward for the first time
tks, s = self.maxForward_(L)
tks1, s1 = self.maxBackward_(L)
tks, s = self._max_forward(L)
tks1, s1 = self._max_backward(L)
if self.DEBUG:
logging.debug("[FW] {} {}".format(tks, s))
logging.debug("[BW] {} {}".format(tks1, s1))
@ -369,7 +369,7 @@ class RagTokenizer:
# backward tokens from_i to i are different from forward tokens from _j to j.
tkslist = []
self.dfs_("".join(tks[_j:j]), 0, [], tkslist)
res.append(" ".join(self.sortTks_(tkslist)[0][0]))
res.append(" ".join(self._sort_tokens(tkslist)[0][0]))
same = 1
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
@ -385,7 +385,7 @@ class RagTokenizer:
assert "".join(tks1[_i:]) == "".join(tks[_j:])
tkslist = []
self.dfs_("".join(tks[_j:]), 0, [], tkslist)
res.append(" ".join(self.sortTks_(tkslist)[0][0]))
res.append(" ".join(self._sort_tokens(tkslist)[0][0]))
res = " ".join(res)
logging.debug("[TKS] {}".format(self.merge_(res)))
@ -413,7 +413,7 @@ class RagTokenizer:
if len(tkslist) < 2:
res.append(tk)
continue
stk = self.sortTks_(tkslist)[1][0]
stk = self._sort_tokens(tkslist)[1][0]
if len(stk) == len(tk):
stk = tk
else:
@ -447,14 +447,13 @@ def is_number(s):
def is_alphabet(s):
if (s >= u'\u0041' and s <= u'\u005a') or (
s >= u'\u0061' and s <= u'\u007a'):
if (u'\u0041' <= s <= u'\u005a') or (u'\u0061' <= s <= u'\u007a'):
return True
else:
return False
def naiveQie(txt):
def naive_qie(txt):
tks = []
for t in txt.split():
if tks and re.match(r".*[a-zA-Z]$", tks[-1]
@ -469,14 +468,14 @@ tokenize = tokenizer.tokenize
fine_grained_tokenize = tokenizer.fine_grained_tokenize
tag = tokenizer.tag
freq = tokenizer.freq
loadUserDict = tokenizer.loadUserDict
addUserDict = tokenizer.addUserDict
load_user_dict = tokenizer.load_user_dict
add_user_dict = tokenizer.add_user_dict
tradi2simp = tokenizer._tradi2simp
strQ2B = tokenizer._strQ2B
if __name__ == '__main__':
tknzr = RagTokenizer(debug=True)
# huqie.addUserDict("/tmp/tmp.new.tks.dict")
# huqie.add_user_dict("/tmp/tmp.new.tks.dict")
tks = tknzr.tokenize(
"哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
logging.info(tknzr.fine_grained_tokenize(tks))
@ -506,7 +505,7 @@ if __name__ == '__main__':
if len(sys.argv) < 2:
sys.exit()
tknzr.DEBUG = False
tknzr.loadUserDict(sys.argv[1])
tknzr.load_user_dict(sys.argv[1])
of = open(sys.argv[2], "r")
while True:
line = of.readline()

View File

@ -102,7 +102,7 @@ class Dealer:
orderBy.asc("top_int")
orderBy.desc("create_timestamp_flt")
res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
total = self.dataStore.getTotal(res)
total = self.dataStore.get_total(res)
logging.debug("Dealer.search TOTAL: {}".format(total))
else:
highlightFields = ["content_ltks", "title_tks"]
@ -115,7 +115,7 @@ class Dealer:
matchExprs = [matchText]
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
idx_names, kb_ids, rank_feature=rank_feature)
total = self.dataStore.getTotal(res)
total = self.dataStore.get_total(res)
logging.debug("Dealer.search TOTAL: {}".format(total))
else:
matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
@ -127,20 +127,20 @@ class Dealer:
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
idx_names, kb_ids, rank_feature=rank_feature)
total = self.dataStore.getTotal(res)
total = self.dataStore.get_total(res)
logging.debug("Dealer.search TOTAL: {}".format(total))
# If result is empty, try again with lower min_match
if total == 0:
if filters.get("doc_id"):
res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
total = self.dataStore.getTotal(res)
total = self.dataStore.get_total(res)
else:
matchText, _ = self.qryr.question(qst, min_match=0.1)
matchDense.extra_options["similarity"] = 0.17
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
total = self.dataStore.getTotal(res)
total = self.dataStore.get_total(res)
logging.debug("Dealer.search 2 TOTAL: {}".format(total))
for k in keywords:
@ -153,17 +153,17 @@ class Dealer:
kwds.add(kk)
logging.debug(f"TOTAL: {total}")
ids = self.dataStore.getChunkIds(res)
ids = self.dataStore.get_chunk_ids(res)
keywords = list(kwds)
highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight")
aggs = self.dataStore.getAggregation(res, "docnm_kwd")
highlight = self.dataStore.get_highlight(res, keywords, "content_with_weight")
aggs = self.dataStore.get_aggregation(res, "docnm_kwd")
return self.SearchResult(
total=total,
ids=ids,
query_vector=q_vec,
aggregation=aggs,
highlight=highlight,
field=self.dataStore.getFields(res, src + ["_score"]),
field=self.dataStore.get_fields(res, src + ["_score"]),
keywords=keywords
)
@ -347,7 +347,7 @@ class Dealer:
## For rank feature(tag_fea) scores.
rank_fea = self._rank_feature_scores(rank_feature, sres)
return tkweight * (np.array(tksim)+rank_fea) + vtweight * vtsim, tksim, vtsim
return tkweight * np.array(tksim) + vtweight * vtsim + rank_fea, tksim, vtsim
def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
return self.qryr.hybrid_similarity(ans_embd,
@ -488,7 +488,7 @@ class Dealer:
for p in range(offset, max_count, bs):
es_res = self.dataStore.search(fields, [], condition, [], orderBy, p, bs, index_name(tenant_id),
kb_ids)
dict_chunks = self.dataStore.getFields(es_res, fields)
dict_chunks = self.dataStore.get_fields(es_res, fields)
for id, doc in dict_chunks.items():
doc["id"] = id
if dict_chunks:
@ -501,11 +501,11 @@ class Dealer:
if not self.dataStore.indexExist(index_name(tenant_id), kb_ids[0]):
return []
res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"])
return self.dataStore.getAggregation(res, "tag_kwd")
return self.dataStore.get_aggregation(res, "tag_kwd")
def all_tags_in_portion(self, tenant_id: str, kb_ids: list[str], S=1000):
res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"])
res = self.dataStore.getAggregation(res, "tag_kwd")
res = self.dataStore.get_aggregation(res, "tag_kwd")
total = np.sum([c for _, c in res])
return {t: (c + 1) / (total + S) for t, c in res}
@ -513,7 +513,7 @@ class Dealer:
idx_nm = index_name(tenant_id)
match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []), keywords_topn)
res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nm, kb_ids, ["tag_kwd"])
aggs = self.dataStore.getAggregation(res, "tag_kwd")
aggs = self.dataStore.get_aggregation(res, "tag_kwd")
if not aggs:
return False
cnt = np.sum([c for _, c in aggs])
@ -529,7 +529,7 @@ class Dealer:
idx_nms = [index_name(tid) for tid in tenant_ids]
match_txt, _ = self.qryr.question(question, min_match=0.0)
res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nms, kb_ids, ["tag_kwd"])
aggs = self.dataStore.getAggregation(res, "tag_kwd")
aggs = self.dataStore.get_aggregation(res, "tag_kwd")
if not aggs:
return {}
cnt = np.sum([c for _, c in aggs])
@ -552,7 +552,7 @@ class Dealer:
es_res = self.dataStore.search(["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [], OrderByExpr(), 0, 128, idx_nms,
kb_ids)
toc = []
dict_chunks = self.dataStore.getFields(es_res, ["content_with_weight"])
dict_chunks = self.dataStore.get_fields(es_res, ["content_with_weight"])
for _, doc in dict_chunks.items():
try:
toc.extend(json.loads(doc["content_with_weight"]))

View File

@ -113,20 +113,20 @@ class Dealer:
res.append(tk)
return res
def tokenMerge(self, tks):
def oneTerm(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
def token_merge(self, tks):
def one_term(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
res, i = [], 0
while i < len(tks):
j = i
if i == 0 and oneTerm(tks[i]) and len(
if i == 0 and one_term(tks[i]) and len(
tks) > 1 and (len(tks[i + 1]) > 1 and not re.match(r"[0-9a-zA-Z]", tks[i + 1])): # 多 工位
res.append(" ".join(tks[0:2]))
i = 2
continue
while j < len(
tks) and tks[j] and tks[j] not in self.stop_words and oneTerm(tks[j]):
tks) and tks[j] and tks[j] not in self.stop_words and one_term(tks[j]):
j += 1
if j - i > 1:
if j - i < 5:
@ -232,7 +232,7 @@ class Dealer:
tw = list(zip(tks, wts))
else:
for tk in tks:
tt = self.tokenMerge(self.pretoken(tk, True))
tt = self.token_merge(self.pretoken(tk, True))
idf1 = np.array([idf(freq(t), 10000000) for t in tt])
idf2 = np.array([idf(df(t), 1000000000) for t in tt])
wts = (0.3 * idf1 + 0.7 * idf2) * \

View File

@ -15,27 +15,35 @@
#
import logging
import re
import umap
import numpy as np
from sklearn.mixture import GaussianMixture
import trio
import umap
from sklearn.mixture import GaussianMixture
from api.db.services.task_service import has_canceled
from common.connection_utils import timeout
from common.exceptions import TaskCanceledException
from common.token_utils import truncate
from graphrag.utils import (
get_llm_cache,
chat_limiter,
get_embed_cache,
get_llm_cache,
set_embed_cache,
set_llm_cache,
chat_limiter,
)
from common.token_utils import truncate
class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
def __init__(
self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1
self,
max_cluster,
llm_model,
embd_model,
prompt,
max_token=512,
threshold=0.1,
max_errors=3,
):
self._max_cluster = max_cluster
self._llm_model = llm_model
@ -43,31 +51,35 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
self._threshold = threshold
self._prompt = prompt
self._max_token = max_token
self._max_errors = max(1, max_errors)
self._error_count = 0
@timeout(60*20)
@timeout(60 * 20)
async def _chat(self, system, history, gen_conf):
response = await trio.to_thread.run_sync(
lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)
)
cached = await trio.to_thread.run_sync(lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf))
if cached:
return cached
if response:
return response
response = await trio.to_thread.run_sync(
lambda: self._llm_model.chat(system, history, gen_conf)
)
response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
if response.find("**ERROR**") >= 0:
raise Exception(response)
await trio.to_thread.run_sync(
lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
)
return response
last_exc = None
for attempt in range(3):
try:
response = await trio.to_thread.run_sync(lambda: self._llm_model.chat(system, history, gen_conf))
response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
if response.find("**ERROR**") >= 0:
raise Exception(response)
await trio.to_thread.run_sync(lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf))
return response
except Exception as exc:
last_exc = exc
logging.warning("RAPTOR LLM call failed on attempt %d/3: %s", attempt + 1, exc)
if attempt < 2:
await trio.sleep(1 + attempt)
raise last_exc if last_exc else Exception("LLM chat failed without exception")
@timeout(20)
async def _embedding_encode(self, txt):
response = await trio.to_thread.run_sync(
lambda: get_embed_cache(self._embd_model.llm_name, txt)
)
response = await trio.to_thread.run_sync(lambda: get_embed_cache(self._embd_model.llm_name, txt))
if response is not None:
return response
embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt]))
@ -82,7 +94,6 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
n_clusters = np.arange(1, max_clusters)
bics = []
for n in n_clusters:
if task_id:
if has_canceled(task_id):
logging.info(f"Task {task_id} cancelled during get optimal clusters.")
@ -101,7 +112,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
layers = [(0, len(chunks))]
start, end = 0, len(chunks)
@timeout(60*20)
@timeout(60 * 20)
async def summarize(ck_idx: list[int]):
nonlocal chunks
@ -111,47 +122,50 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
raise TaskCanceledException(f"Task {task_id} was cancelled")
texts = [chunks[i][0] for i in ck_idx]
len_per_chunk = int(
(self._llm_model.max_length - self._max_token) / len(texts)
)
cluster_content = "\n".join(
[truncate(t, max(1, len_per_chunk)) for t in texts]
)
async with chat_limiter:
len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
try:
async with chat_limiter:
if task_id and has_canceled(task_id):
logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
if task_id and has_canceled(task_id):
logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
cnt = await self._chat(
"You're a helpful assistant.",
[
{
"role": "user",
"content": self._prompt.format(cluster_content=cluster_content),
}
],
{"max_tokens": max(self._max_token, 512)}, # fix issue: #10235
)
cnt = re.sub(
"(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
"",
cnt,
)
logging.debug(f"SUM: {cnt}")
cnt = await self._chat(
"You're a helpful assistant.",
[
{
"role": "user",
"content": self._prompt.format(
cluster_content=cluster_content
),
}
],
{"max_tokens": max(self._max_token, 512)}, # fix issue: #10235
)
cnt = re.sub(
"(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
"",
cnt,
)
logging.debug(f"SUM: {cnt}")
if task_id and has_canceled(task_id):
logging.info(f"Task {task_id} cancelled before RAPTOR embedding.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
if task_id and has_canceled(task_id):
logging.info(f"Task {task_id} cancelled before RAPTOR embedding.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
embds = await self._embedding_encode(cnt)
chunks.append((cnt, embds))
embds = await self._embedding_encode(cnt)
chunks.append((cnt, embds))
except TaskCanceledException:
raise
except Exception as exc:
self._error_count += 1
warn_msg = f"[RAPTOR] Skip cluster ({len(ck_idx)} chunks) due to error: {exc}"
logging.warning(warn_msg)
if callback:
callback(msg=warn_msg)
if self._error_count >= self._max_errors:
raise RuntimeError(f"RAPTOR aborted after {self._error_count} errors. Last error: {exc}") from exc
labels = []
while end - start > 1:
if task_id:
if has_canceled(task_id):
logging.info(f"Task {task_id} cancelled during RAPTOR layer processing.")
@ -161,11 +175,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
if len(embeddings) == 2:
await summarize([start, start + 1])
if callback:
callback(
msg="Cluster one layer: {} -> {}".format(
end - start, len(chunks) - end
)
)
callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
labels.extend([0, 0])
layers.append((end, len(chunks)))
start = end
@ -199,17 +209,11 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
nursery.start_soon(summarize, ck_idx)
assert len(chunks) - end == n_clusters, "{} vs. {}".format(
len(chunks) - end, n_clusters
)
assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
labels.extend(lbls)
layers.append((end, len(chunks)))
if callback:
callback(
msg="Cluster one layer: {} -> {}".format(
end - start, len(chunks) - end
)
)
callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
start = end
end = len(chunks)

View File

@ -28,7 +28,7 @@ def collect():
logging.debug(doc_locations)
if len(doc_locations) == 0:
time.sleep(1)
return
return None
return doc_locations

View File

@ -20,33 +20,40 @@
import copy
import faulthandler
import logging
import os
import signal
import sys
import threading
import time
import traceback
from datetime import datetime, timezone
from typing import Any
import trio
from api.db.services.connector_service import ConnectorService, SyncLogsService
from api.db.services.knowledgebase_service import KnowledgebaseService
from common.log_utils import init_root_logger
from common.config_utils import show_configs
from common.data_source import BlobStorageConnector, NotionConnector, DiscordConnector, GoogleDriveConnector
import logging
import os
from datetime import datetime, timezone
import signal
import trio
import faulthandler
from common.constants import FileSource, TaskStatus
from common import settings
from common.versions import get_ragflow_version
from common.config_utils import show_configs
from common.constants import FileSource, TaskStatus
from common.data_source import (
BlobStorageConnector,
DiscordConnector,
GoogleDriveConnector,
JiraConnector,
NotionConnector,
)
from common.data_source.config import INDEX_BATCH_SIZE
from common.data_source.confluence_connector import ConfluenceConnector
from common.data_source.interfaces import CheckpointOutputWrapper
from common.data_source.utils import load_all_docs_from_checkpoint_connector
from common.data_source.config import INDEX_BATCH_SIZE
from common.log_utils import init_root_logger
from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
from common.versions import get_ragflow_version
MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
MAX_CONCURRENT_TASKS = int(os.environ.get("MAX_CONCURRENT_TASKS", "5"))
task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS)
@ -72,31 +79,32 @@ class SyncBase:
min_update = min([doc.doc_updated_at for doc in document_batch])
max_update = max([doc.doc_updated_at for doc in document_batch])
next_update = max([next_update, max_update])
docs = [{
"id": doc.id,
"connector_id": task["connector_id"],
"source": self.SOURCE_NAME,
"semantic_identifier": doc.semantic_identifier,
"extension": doc.extension,
"size_bytes": doc.size_bytes,
"doc_updated_at": doc.doc_updated_at,
"blob": doc.blob
} for doc in document_batch]
docs = [
{
"id": doc.id,
"connector_id": task["connector_id"],
"source": self.SOURCE_NAME,
"semantic_identifier": doc.semantic_identifier,
"extension": doc.extension,
"size_bytes": doc.size_bytes,
"doc_updated_at": doc.doc_updated_at,
"blob": doc.blob,
}
for doc in document_batch
]
e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{self.SOURCE_NAME}/{task['connector_id']}", task["auto_parse"])
SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
doc_num += len(docs)
logging.info("{} docs synchronized till {}".format(doc_num, next_update))
prefix = "[Jira] " if self.SOURCE_NAME == FileSource.JIRA else ""
logging.info(f"{prefix}{doc_num} docs synchronized till {next_update}")
SyncLogsService.done(task["id"], task["connector_id"])
task["poll_range_start"] = next_update
except Exception as ex:
msg = '\n'.join([
''.join(traceback.format_exception_only(None, ex)).strip(),
''.join(traceback.format_exception(None, ex, ex.__traceback__)).strip()
])
msg = "\n".join(["".join(traceback.format_exception_only(None, ex)).strip(), "".join(traceback.format_exception(None, ex, ex.__traceback__)).strip()])
SyncLogsService.update_by_id(task["id"], {"status": TaskStatus.FAIL, "full_exception_trace": msg, "error_msg": str(ex)})
SyncLogsService.schedule(task["connector_id"], task["kb_id"], task["poll_range_start"])
@ -109,21 +117,16 @@ class S3(SyncBase):
SOURCE_NAME: str = FileSource.S3
async def _generate(self, task: dict):
self.connector = BlobStorageConnector(
bucket_type=self.conf.get("bucket_type", "s3"),
bucket_name=self.conf["bucket_name"],
prefix=self.conf.get("prefix", "")
)
self.connector = BlobStorageConnector(bucket_type=self.conf.get("bucket_type", "s3"), bucket_name=self.conf["bucket_name"], prefix=self.conf.get("prefix", ""))
self.connector.load_credentials(self.conf["credentials"])
document_batch_generator = self.connector.load_from_state() if task["reindex"]=="1" or not task["poll_range_start"] \
else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())
document_batch_generator = (
self.connector.load_from_state()
if task["reindex"] == "1" or not task["poll_range_start"]
else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())
)
begin_info = "totally" if task["reindex"]=="1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
logging.info("Connect to {}: {}(prefix/{}) {}".format(self.conf.get("bucket_type", "s3"),
self.conf["bucket_name"],
self.conf.get("prefix", ""),
begin_info
))
begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
logging.info("Connect to {}: {}(prefix/{}) {}".format(self.conf.get("bucket_type", "s3"), self.conf["bucket_name"], self.conf.get("prefix", ""), begin_info))
return document_batch_generator
@ -131,8 +134,8 @@ class Confluence(SyncBase):
SOURCE_NAME: str = FileSource.CONFLUENCE
async def _generate(self, task: dict):
from common.data_source.interfaces import StaticCredentialsProvider
from common.data_source.config import DocumentSource
from common.data_source.interfaces import StaticCredentialsProvider
self.connector = ConfluenceConnector(
wiki_base=self.conf["wiki_base"],
@ -141,11 +144,7 @@ class Confluence(SyncBase):
# page_id=self.conf.get("page_id", ""),
)
credentials_provider = StaticCredentialsProvider(
tenant_id=task["tenant_id"],
connector_name=DocumentSource.CONFLUENCE,
credential_json=self.conf["credentials"]
)
credentials_provider = StaticCredentialsProvider(tenant_id=task["tenant_id"], connector_name=DocumentSource.CONFLUENCE, credential_json=self.conf["credentials"])
self.connector.set_credentials_provider(credentials_provider)
# Determine the time range for synchronization based on reindex or poll_range_start
@ -174,10 +173,13 @@ class Notion(SyncBase):
async def _generate(self, task: dict):
self.connector = NotionConnector(root_page_id=self.conf["root_page_id"])
self.connector.load_credentials(self.conf["credentials"])
document_generator = self.connector.load_from_state() if task["reindex"]=="1" or not task["poll_range_start"] \
else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())
document_generator = (
self.connector.load_from_state()
if task["reindex"] == "1" or not task["poll_range_start"]
else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())
)
begin_info = "totally" if task["reindex"]=="1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
logging.info("Connect to Notion: root({}) {}".format(self.conf["root_page_id"], begin_info))
return document_generator
@ -194,13 +196,16 @@ class Discord(SyncBase):
server_ids=server_ids.split(",") if server_ids else [],
channel_names=channel_names.split(",") if channel_names else [],
start_date=datetime(1970, 1, 1, tzinfo=timezone.utc).strftime("%Y-%m-%d"),
batch_size=self.conf.get("batch_size", 1024)
batch_size=self.conf.get("batch_size", 1024),
)
self.connector.load_credentials(self.conf["credentials"])
document_generator = self.connector.load_from_state() if task["reindex"]=="1" or not task["poll_range_start"] \
else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())
document_generator = (
self.connector.load_from_state()
if task["reindex"] == "1" or not task["poll_range_start"]
else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())
)
begin_info = "totally" if task["reindex"]=="1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
logging.info("Connect to Discord: servers({}), channel({}) {}".format(server_ids, channel_names, begin_info))
return document_generator
@ -285,7 +290,7 @@ class GoogleDrive(SyncBase):
admin_email = self.connector.primary_admin_email
except RuntimeError:
admin_email = "unknown"
logging.info("Connect to Google Drive as %s %s", admin_email, begin_info)
logging.info(f"Connect to Google Drive as {admin_email} {begin_info}")
return document_batches()
def _persist_rotated_credentials(self, connector_id: str, credentials: dict[str, Any]) -> None:
@ -303,7 +308,93 @@ class Jira(SyncBase):
SOURCE_NAME: str = FileSource.JIRA
async def _generate(self, task: dict):
pass
connector_kwargs = {
"jira_base_url": self.conf["base_url"],
"project_key": self.conf.get("project_key"),
"jql_query": self.conf.get("jql_query"),
"batch_size": self.conf.get("batch_size", INDEX_BATCH_SIZE),
"include_comments": self.conf.get("include_comments", True),
"include_attachments": self.conf.get("include_attachments", False),
"labels_to_skip": self._normalize_list(self.conf.get("labels_to_skip")),
"comment_email_blacklist": self._normalize_list(self.conf.get("comment_email_blacklist")),
"scoped_token": self.conf.get("scoped_token", False),
"attachment_size_limit": self.conf.get("attachment_size_limit"),
"timezone_offset": self.conf.get("timezone_offset"),
}
self.connector = JiraConnector(**connector_kwargs)
credentials = self.conf.get("credentials")
if not credentials:
raise ValueError("Jira connector is missing credentials.")
self.connector.load_credentials(credentials)
self.connector.validate_connector_settings()
if task["reindex"] == "1" or not task["poll_range_start"]:
start_time = 0.0
begin_info = "totally"
else:
start_time = task["poll_range_start"].timestamp()
begin_info = f"from {task['poll_range_start']}"
end_time = datetime.now(timezone.utc).timestamp()
raw_batch_size = self.conf.get("sync_batch_size") or self.conf.get("batch_size") or INDEX_BATCH_SIZE
try:
batch_size = int(raw_batch_size)
except (TypeError, ValueError):
batch_size = INDEX_BATCH_SIZE
if batch_size <= 0:
batch_size = INDEX_BATCH_SIZE
def document_batches():
checkpoint = self.connector.build_dummy_checkpoint()
pending_docs = []
iterations = 0
iteration_limit = 100_000
while checkpoint.has_more:
wrapper = CheckpointOutputWrapper()
generator = wrapper(
self.connector.load_from_checkpoint(
start_time,
end_time,
checkpoint,
)
)
for document, failure, next_checkpoint in generator:
if failure is not None:
logging.warning(
f"[Jira] Jira connector failure: {getattr(failure, 'failure_message', failure)}"
)
continue
if document is not None:
pending_docs.append(document)
if len(pending_docs) >= batch_size:
yield pending_docs
pending_docs = []
if next_checkpoint is not None:
checkpoint = next_checkpoint
iterations += 1
if iterations > iteration_limit:
logging.error(f"[Jira] Task {task.get('id')} exceeded iteration limit ({iteration_limit}).")
raise RuntimeError("Too many iterations while loading Jira documents.")
if pending_docs:
yield pending_docs
logging.info(f"[Jira] Connect to Jira {connector_kwargs['jira_base_url']} {begin_info}")
return document_batches()
@staticmethod
def _normalize_list(values: Any) -> list[str] | None:
if values is None:
return None
if isinstance(values, str):
values = [item.strip() for item in values.split(",")]
return [str(value).strip() for value in values if value is not None and str(value).strip()]
class SharePoint(SyncBase):
@ -337,9 +428,10 @@ func_factory = {
FileSource.JIRA: Jira,
FileSource.SHAREPOINT: SharePoint,
FileSource.SLACK: Slack,
FileSource.TEAMS: Teams
FileSource.TEAMS: Teams,
}
async def dispatch_tasks():
async with trio.open_nursery() as nursery:
while True:
@ -385,7 +477,7 @@ async def main():
__/ |
|___/
""")
logging.info(f'RAGFlow version: {get_ragflow_version()}')
logging.info(f"RAGFlow version: {get_ragflow_version()}")
show_configs()
settings.init_settings()
if sys.platform != "win32":

View File

@ -359,7 +359,7 @@ async def build_chunks(task, progress_callback):
task_canceled = has_canceled(task["id"])
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return
return None
if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
else:
@ -417,6 +417,7 @@ def build_TOC(task, docs, progress_callback):
d["page_num_int"] = [100000000]
d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
return d
return None
def init_kb(row, vector_size: int):
@ -441,7 +442,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
tk_count = 0
if len(tts) == len(cnts):
vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(tts[0: 1]))
tts = np.concatenate([vts[0] for _ in range(len(tts))], axis=0)
tts = np.tile(vts[0], (len(cnts), 1))
tk_count += c
@timeout(60)
@ -464,8 +465,10 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
if not filename_embd_weight:
filename_embd_weight = 0.1
title_w = float(filename_embd_weight)
vects = (title_w * tts + (1 - title_w) *
cnts) if len(tts) == len(cnts) else cnts
if tts.ndim == 2 and cnts.ndim == 2 and tts.shape == cnts.shape:
vects = title_w * tts + (1 - title_w) * cnts
else:
vects = cnts
assert len(vects) == len(docs)
vector_size = 0
@ -648,6 +651,8 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
res = []
tk_count = 0
max_errors = int(os.environ.get("RAPTOR_MAX_ERRORS", 3))
async def generate(chunks, did):
nonlocal tk_count, res
raptor = Raptor(
@ -657,6 +662,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
raptor_config["prompt"],
raptor_config["max_token"],
raptor_config["threshold"],
max_errors=max_errors,
)
original_length = len(chunks)
chunks = await raptor(chunks, kb_parser_config["raptor"]["random_seed"], callback, row["id"])
@ -719,7 +725,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return
return False
if b % 128 == 0:
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
if doc_store_result:
@ -737,7 +743,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
for chunk_id in chunk_ids:
nursery.start_soon(delete_image, task_dataset_id, chunk_id)
progress_callback(-1, msg=f"Chunk updates failed since task {task_id} is unknown.")
return
return False
return True

View File

@ -67,6 +67,8 @@ class RAGFlowAzureSpnBlob:
logging.exception(f"Fail put {bucket}/{fnm}")
self.__open__()
time.sleep(1)
return None
return None
def rm(self, bucket, fnm):
try:
@ -84,7 +86,7 @@ class RAGFlowAzureSpnBlob:
logging.exception(f"fail get {bucket}/{fnm}")
self.__open__()
time.sleep(1)
return
return None
def obj_exist(self, bucket, fnm):
try:
@ -102,4 +104,4 @@ class RAGFlowAzureSpnBlob:
logging.exception(f"fail get {bucket}/{fnm}")
self.__open__()
time.sleep(1)
return
return None

View File

@ -241,23 +241,23 @@ class DocStoreConnection(ABC):
"""
@abstractmethod
def getTotal(self, res):
def get_total(self, res):
raise NotImplementedError("Not implemented")
@abstractmethod
def getChunkIds(self, res):
def get_chunk_ids(self, res):
raise NotImplementedError("Not implemented")
@abstractmethod
def getFields(self, res, fields: list[str]) -> dict[str, dict]:
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
raise NotImplementedError("Not implemented")
@abstractmethod
def getHighlight(self, res, keywords: list[str], fieldnm: str):
def get_highlight(self, res, keywords: list[str], fieldnm: str):
raise NotImplementedError("Not implemented")
@abstractmethod
def getAggregation(self, res, fieldnm: str):
def get_aggregation(self, res, fieldnm: str):
raise NotImplementedError("Not implemented")
"""

View File

@ -471,12 +471,12 @@ class ESConnection(DocStoreConnection):
Helper functions for search result
"""
def getTotal(self, res):
def get_total(self, res):
if isinstance(res["hits"]["total"], type({})):
return res["hits"]["total"]["value"]
return res["hits"]["total"]
def getChunkIds(self, res):
def get_chunk_ids(self, res):
return [d["_id"] for d in res["hits"]["hits"]]
def __getSource(self, res):
@ -487,7 +487,7 @@ class ESConnection(DocStoreConnection):
rr.append(d["_source"])
return rr
def getFields(self, res, fields: list[str]) -> dict[str, dict]:
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
res_fields = {}
if not fields:
return {}
@ -509,7 +509,7 @@ class ESConnection(DocStoreConnection):
res_fields[d["id"]] = m
return res_fields
def getHighlight(self, res, keywords: list[str], fieldnm: str):
def get_highlight(self, res, keywords: list[str], fieldnm: str):
ans = {}
for d in res["hits"]["hits"]:
hlts = d.get("highlight")
@ -534,7 +534,7 @@ class ESConnection(DocStoreConnection):
return ans
def getAggregation(self, res, fieldnm: str):
def get_aggregation(self, res, fieldnm: str):
agg_field = "aggs_" + fieldnm
if "aggregations" not in res or agg_field not in res["aggregations"]:
return list()

View File

@ -470,7 +470,7 @@ class InfinityConnection(DocStoreConnection):
df_list.append(kb_res)
self.connPool.release_conn(inf_conn)
res = concat_dataframes(df_list, ["id"])
res_fields = self.getFields(res, res.columns.tolist())
res_fields = self.get_fields(res, res.columns.tolist())
return res_fields.get(chunkId, None)
def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
@ -599,7 +599,7 @@ class InfinityConnection(DocStoreConnection):
col_to_remove = list(removeValue.keys())
row_to_opt = table_instance.output(col_to_remove + ["id"]).filter(filter).to_df()
logger.debug(f"INFINITY search table {str(table_name)}, filter {filter}, result: {str(row_to_opt[0])}")
row_to_opt = self.getFields(row_to_opt, col_to_remove)
row_to_opt = self.get_fields(row_to_opt, col_to_remove)
for id, old_v in row_to_opt.items():
for k, remove_v in removeValue.items():
if remove_v in old_v[k]:
@ -639,17 +639,17 @@ class InfinityConnection(DocStoreConnection):
Helper functions for search result
"""
def getTotal(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
def get_total(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
if isinstance(res, tuple):
return res[1]
return len(res)
def getChunkIds(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
def get_chunk_ids(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
if isinstance(res, tuple):
res = res[0]
return list(res["id"])
def getFields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
def get_fields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
if isinstance(res, tuple):
res = res[0]
if not fields:
@ -690,7 +690,7 @@ class InfinityConnection(DocStoreConnection):
return res2.set_index("id").to_dict(orient="index")
def getHighlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], fieldnm: str):
def get_highlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], fieldnm: str):
if isinstance(res, tuple):
res = res[0]
ans = {}
@ -732,7 +732,7 @@ class InfinityConnection(DocStoreConnection):
ans[id] = txt
return ans
def getAggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str):
def get_aggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str):
"""
Manual aggregation for tag fields since Infinity doesn't provide native aggregation
"""

View File

@ -92,7 +92,7 @@ class RAGFlowMinio:
logging.exception(f"Fail to get {bucket}/{filename}")
self.__open__()
time.sleep(1)
return
return None
def obj_exist(self, bucket, filename, tenant_id=None):
try:
@ -130,7 +130,7 @@ class RAGFlowMinio:
logging.exception(f"Fail to get_presigned {bucket}/{fnm}:")
self.__open__()
time.sleep(1)
return
return None
def remove_bucket(self, bucket):
try:

Some files were not shown because too many files have changed in this diff Show More