Compare commits

...

76 Commits

Author SHA1 Message Date
cfdccebb17 Feat: Fixed an issue where modifying fields in the agent operator caused the loss of structured data. #10427 (#11388)
### What problem does this PR solve?

Feat: Fixed an issue where modifying fields in the agent operator caused
the loss of structured data. #10427

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-11-19 20:11:53 +08:00
980a883033 Docs: minor (#11385)
### What problem does this PR solve?

### Type of change

- [x] Documentation Update
2025-11-19 19:41:21 +08:00
02d429f0ca Doc: Optimize read me (#11386)
### What problem does this PR solve?

Users currently can’t view `git checkout v0.22.1` directly. They need to
scroll the code block all the way to the right to see it.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-19 19:40:55 +08:00
9c24d5d44a Fix some multilingual issues (#11382)
### What problem does this PR solve?

Fix some multilingual issues

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-19 19:14:43 +08:00
0cc5d7a8a6 Feat: If a query variable in a data manipulation operator is deleted, a warning message should be displayed to the user. #10427 #11255 (#11384)
### What problem does this PR solve?

Feat: If a query variable in a data manipulation operator is deleted, a
warning message should be displayed to the user. #10427 #11255

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-11-19 19:10:57 +08:00
c43bf1dcf5 Fix: refine error msg. (#11380)
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-19 19:10:45 +08:00
f76b8279dd Doc: Added v0.22.1 release notes (#11383)
### What problem does this PR solve?


### Type of change


- [x] Documentation Update
2025-11-19 18:40:06 +08:00
db5ec89dc5 Feat: The key for the begin operator can only contain alphanumeric characters and underscores. #10427 (#11377)
### What problem does this PR solve?

Feat: The key for the begin operator can only contain alphanumeric
characters and underscores. #10427
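
A minimal sketch of the validation rule described above; the regex and helper name are assumptions for illustration, not the project's actual code:

```python
import re

# Keys may contain only alphanumeric characters and underscores.
BEGIN_KEY_RE = re.compile(r"^[A-Za-z0-9_]+$")

def is_valid_begin_key(key: str) -> bool:
    return bool(BEGIN_KEY_RE.match(key))
```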

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-11-19 16:16:57 +08:00
1c201c4d54 Fix: circular imports issue. (#11374)
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-19 16:13:21 +08:00
ba78d0f0c2 Feat: Structured data will still be stored in outputs for compatibility with older versions. #10427 (#11368)
### What problem does this PR solve?

Feat: Structured data will still be stored in outputs for compatibility
with older versions. #10427

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-11-19 15:15:51 +08:00
add8c63458 Add release notes (#11372)
### What problem does this PR solve?

As title.

### Type of change

- [x] Documentation Update

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-11-19 14:48:41 +08:00
83661efdaf Update README for supporting Gemini 3 Pro (#11369)
### What problem does this PR solve?

As title

### Type of change

- [x] Documentation Update

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-11-19 14:16:03 +08:00
971197d595 Feat: Set the outputs type of list operation. #10427 (#11366)
### What problem does this PR solve?

Feat: Set the outputs type of list operation. #10427

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-11-19 13:59:43 +08:00
0884e9a4d9 Fix: bbox not included in mineru output (#11365)
### What problem does this PR solve?

Fix: bbox not included in mineru output #11315

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-19 13:59:32 +08:00
2de42f00b8 Fix: component list operation issue. (#11364)
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-19 13:19:44 +08:00
e8fe580d7a Feat: add Gemini 3 Pro preview (#11361)
### What problem does this PR solve?

Add Gemini 3 Pro preview.

Change `GenerativeModel` to `genai`.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-11-19 13:17:22 +08:00
62505164d5 chore(template): introducing variable aggregator to customer service template (#11352)
### What problem does this PR solve?
Update customer service template

### Type of change
- [x] Other (please describe):
2025-11-19 12:28:06 +08:00
d1dcf3b43c Refactor /stats API (#11363)
### What problem does this PR solve?

Use a single loop for better performance.

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-11-19 12:27:45 +08:00
f84662d2ee Fix: Fixed an issue where variable aggregation operators could not be connected to other operators. #10427 (#11358)
### What problem does this PR solve?

Fix: Fixed an issue where variable aggregation operators could not be
connected to other operators. #10427

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-11-19 10:29:26 +08:00
1cb6b7f5dd Update version info to v0.22.1 (#11346)
### What problem does this PR solve?

As title

### Type of change

- [x] Other (please describe): Update version info

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-11-19 09:50:23 +08:00
023f509501 Fix: variable assigner issue. (#11351)
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-19 09:49:40 +08:00
50bc53a1f5 Fix: Modify the personal center style #10703 (#11347)
### What problem does this PR solve?

Fix: Modify the personal center style

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-18 20:07:17 +08:00
8cd4882596 Feat: Display variables in the variable assignment node. #10427 (#11349)
### What problem does this PR solve?

Feat: Display variables in the variable assignment node. #10427

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-11-18 20:07:04 +08:00
35e5fade93 Feat: new component variable assigner (#11050)
### What problem does this PR solve?
issue:
https://github.com/infiniflow/ragflow/issues/10427
change:
new component variable assigner
### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-11-18 19:14:38 +08:00
4942a23290 Feat: Add a switch to control the display of structured output to the agent form. #10427 (#11344)
### What problem does this PR solve?

Feat: Add a switch to control the display of structured output to the
agent form. #10427

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-11-18 18:58:36 +08:00
d1716d865a Feat: Alter flask to Quart for async API serving. (#11275)
### What problem does this PR solve?

#11277
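
For context, a minimal sketch of what the Flask-to-Quart move looks like at the handler level; the route and names here are illustrative, not the project's actual API:

```python
from quart import Quart, jsonify

app = Quart(__name__)

@app.route("/ping")
async def ping():
    # Quart keeps the Flask-style API, but handlers become coroutines,
    # so slow I/O can be awaited instead of blocking a worker thread.
    return jsonify({"status": "ok"})
```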

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-11-18 17:05:16 +08:00
c2b7c305fa Fix: crop index may be out of range (#11341)
### What problem does this PR solve?

Crop index may be out of range. #11323
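
A generic sketch of the kind of bounds clamping such a fix involves; names are illustrative, not the actual patch:

```python
def clamp_crop(x0: int, y0: int, x1: int, y1: int, width: int, height: int) -> tuple[int, int, int, int]:
    # Keep the crop rectangle inside the image so later indexing cannot go out of range.
    return max(0, x0), max(0, y0), min(width, x1), min(height, y1)
```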


### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-18 17:01:54 +08:00
341e5904c8 Fix: No results can be found through the API /api/v1/dify/retrieval (#11338)
### What problem does this PR solve?

No results can be found through the API /api/v1/dify/retrieval. #11307 

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-18 15:42:31 +08:00
ded9bf80c5 Fix: limit random sampling range in check_embedding (#11337)
### What problem does this PR solve?
issue:
[#11319](https://github.com/infiniflow/ragflow/issues/11319)
change:
limit random sampling range in check_embedding
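
A rough sketch of the idea with hypothetical names (the real bounds live in check_embedding):

```python
import random

def sample_chunk_indices(total_chunks: int, k: int) -> list[int]:
    # Clamp the sample size to the population so random.sample never
    # raises ValueError on small datasets.
    k = max(0, min(k, total_chunks))
    return random.sample(range(total_chunks), k)
```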

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-18 15:24:27 +08:00
fea157ba08 Fix: manual parser with mineru (#11336)
### What problem does this PR solve?

Fix: manual parser with mineru #11320
Fix: missing parameter in mineru #11334
Fix: add outlines parameter for pdf parsers

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-18 15:22:52 +08:00
0db00f70b2 Fix: add describe_image_with_prompt for ZHIPU AI (#11317)
### What problem does this PR solve?

Fix: add describe_image_with_prompt for ZHIPU AI  #11289 

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-18 13:09:39 +08:00
701761d119 Feat: Fixed the issue where form data assigned by variables was not updated in real time. #10427 (#11333)
### What problem does this PR solve?

Feat: Fixed the issue where form data assigned by variables was not
updated in real time. #10427
### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-11-18 13:07:52 +08:00
2993fc666b Feat: update version to 0.22.1 (#11331)
### What problem does this PR solve?

Update version to 0.22.1

### Type of change

- [x] Documentation Update
2025-11-18 10:49:36 +08:00
8a6d205df0 fix: entrypoint.sh typo for disable datasync command (#11326)
### What problem does this PR solve?

There's a typo in `entrypoint.sh` on line 74: the case statement uses
`--disable-datasyn)` (missing the 'c'), while the usage function and
documentation correctly show `--disable-datasync` (with the 'c'). This
mismatch causes the `--disable-datasync` flag to be unrecognized,
triggering the usage message and causing containers to restart in a loop
when this flag is used.

**Background:**
- Users following the documentation use `--disable-datasync` in their
docker-compose.yml
- The entrypoint script doesn't recognize this flag due to the typo
- The script calls `usage()` and exits, causing Docker containers to
restart continuously
- This makes it impossible to disable the data sync service as intended

**Example scenario:**
When a user adds `--disable-datasync` to their docker-compose command
(as shown in examples), the container fails to start properly because
the argument isn't recognized.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
### Proposed Solution

Fix the typo on line 74 of `entrypoint.sh` by changing:
```bash
    --disable-datasyn)
```
to:
```bash
    --disable-datasync)
```

This matches the spelling used in the usage function (line 9 and 13) and
allows the flag to work as documented.

### Changes Made

- Fixed typo in `entrypoint.sh` line 74: changed `--disable-datasyn)` to
`--disable-datasync)`
- This ensures the argument matches the documented flag name and usage
function

---

**Code change:**

```bash
# Line 74 in entrypoint.sh
# Before:
    --disable-datasyn)
      ENABLE_DATASYNC=0
      shift
      ;;

# After:
    --disable-datasync)
      ENABLE_DATASYNC=0
      shift
      ;;
```

This is a simple one-character fix that resolves the argument parsing
issue.
2025-11-18 10:28:00 +08:00
912b6b023e fix: update check_embedding failed info (#11321)
### What problem does this PR solve?
change:
update check_embedding failed info

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-18 09:39:45 +08:00
89e8818dda Feat: add s3-compatible storage boxes (#11313)
### What problem does this PR solve?

PR for implementing s3 compatible storage units #11240 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-11-18 09:39:25 +08:00
1dba6b5bf9 Fix: Fixed an issue where adding session variables multiple times would overwrite them. (#11308)
### What problem does this PR solve?

Fix: Fixed an issue where adding session variables multiple times would
overwrite them.
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-18 09:39:02 +08:00
3fcf2ee54c feat: add new LLM provider Jiekou.AI (#11300)
### What problem does this PR solve?


### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: Jason <ggbbddjm@gmail.com>
2025-11-17 19:47:46 +08:00
d8f413a885 Feat: Construct a dynamic variable assignment form #10427 (#11316)
### What problem does this PR solve?

Feat: Construct a dynamic variable assignment form #10427

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-11-17 19:45:58 +08:00
7264fb6978 Fix: concat images in word document. (#11310)
### What problem does this PR solve?

Fix: concat images in word document. Partially solved issues in #11063 

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-17 19:38:26 +08:00
bd4bc57009 Refactor: move mcp connection utilities to common (#11304)
### What problem does this PR solve?

As title

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-11-17 15:34:17 +08:00
0569b50fed Fix: create dataset return type inconsistent (#11272)
### What problem does this PR solve?

Fix: create dataset return type inconsistent #11167 
 
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-17 15:27:19 +08:00
6b64641042 Fix: default model base url extraction logic (#11263)
### What problem does this PR solve?

Fixes an issue where default models which used the same factory but
different base URLs would all be initialised with the default chat
model's base URL and would ignore e.g. the embedding model's base URL
config.

For example, with the following service config, the embedding and
reranker models would end up using the base URL for the default chat
model (i.e. `llm1.example.com`):

```yaml
ragflow:
  service_conf:
    user_default_llm:
      factory: OpenAI-API-Compatible
      api_key: not-used
      default_models:
        chat_model:
          name: llm1
          base_url: https://llm1.example.com/v1
        embedding_model:
          name: llm2
          base_url: https://llm2.example.com/v1
        rerank_model:
          name: llm3
          base_url: https://llm3.example.com/v1/rerank

  llm_factories:
    factory_llm_infos:
    - name: OpenAI-API-Compatible
      logo: ""
      tags: "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION"
      status: "1"
      llm:
        - llm_name: llm1
          base_url: 'https://llm1.example.com/v1'
          api_key: not-used
          tags: "LLM,CHAT,IMAGE2TEXT"
          max_tokens: 100000
          model_type: chat
          is_tools: false

        - llm_name: llm2
          base_url: https://llm2.example.com/v1
          api_key: not-used
          tags: "TEXT EMBEDDING"
          max_tokens: 10000
          model_type: embedding

        - llm_name: llm3
          base_url: https://llm3.example.com/v1/rerank
          api_key: not-used
          tags: "RERANK,1k"
          max_tokens: 10000
          model_type: rerank
```
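
A minimal sketch of the corrected lookup order, with hypothetical helper names: each default model should prefer its own `base_url` and only then fall back to a shared default.

```python
def resolve_base_url(model_conf: dict, fallback: str | None) -> str | None:
    # Prefer the model's own base_url; fall back to the shared default
    # instead of always inheriting the chat model's URL.
    return model_conf.get("base_url") or fallback
```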

### Type of change

- [X] Bug Fix (non-breaking change which fixes an issue)
2025-11-17 14:21:27 +08:00
9cef3a2625 Fix: Fixed the issue of not being able to select the time zone in the user center. (#11298)
### What problem does this PR solve?

Fix: Fixed the issue of not being able to select the time zone in the
user center.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-17 11:16:55 +08:00
e7e89d3ecb Doc: style fix (#11295)
### What problem does this PR solve?

Style fix based on  #11283
### Type of change

- [x] Documentation Update
2025-11-17 11:16:34 +08:00
13e212c856 Feat: add Jira connector (#11285)
### What problem does this PR solve?

Add Jira connector.

<img width="978" height="925" alt="image"
src="https://github.com/user-attachments/assets/78bb5c77-2710-4569-a76e-9087ca23b227"
/>

---

<img width="1903" height="489" alt="image"
src="https://github.com/user-attachments/assets/193bc5c5-f751-4bd5-883a-2173282c2b96"
/>

---

<img width="1035" height="925" alt="image"
src="https://github.com/user-attachments/assets/1a0aec19-30eb-4ada-9283-61d1c915f59d"
/>

---

<img width="1905" height="601" alt="image"
src="https://github.com/user-attachments/assets/3dde1062-3f27-4717-8e09-fd5fd5e64171"
/>

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-11-17 09:38:04 +08:00
61cf430dbb Minor tweaks (#11271)
### What problem does this PR solve?

As title.

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-11-16 19:29:20 +08:00
e841b09d63 Remove unused code and fix performance issue (#11284)
### What problem does this PR solve?

1. remove redundant code
2. fix MinerU performance issue

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-11-14 20:39:54 +08:00
b1a1eedf53 Doc: add default username & pwd (#11283)
### What problem does this PR solve?
Doc: add default username & pwd

### Type of change

- [x] Documentation Update

---------

Co-authored-by: writinwaters <93570324+writinwaters@users.noreply.github.com>
2025-11-14 19:52:58 +08:00
68e3b33ae4 Feat: extract message output to file (#11251)
### What problem does this PR solve?

Feat: extract message output to file

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-11-14 19:52:11 +08:00
cd55f6c1b8 Fix: ListOperations does not support sorting arrays of objects. (#11278)
### What problem does this PR solve?

pr:
#11276
change:
ListOperations does not support sorting arrays of objects.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-14 19:50:29 +08:00
996b5fe14e Fix: Added the ability to download files in the agent message reply function. (#11281)
### What problem does this PR solve?

Fix: Added the ability to download files in the agent message reply
function.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-14 19:50:01 +08:00
db4fd19c82 Feat: new component list operations (#11276)
### What problem does this PR solve?
issue:
https://github.com/infiniflow/ragflow/issues/10427
change:
new component list operations

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-11-14 16:33:20 +08:00
12db62b9c7 Refactor: improve mineru_parser get property logic (#11268)
### What problem does this PR solve?

improve mineru_parser get property logic

### Type of change

- [x] Refactoring
2025-11-14 16:32:35 +08:00
b5f2cf16bc Fix: check task executor alive and display status (#11270)
### What problem does this PR solve?

Correctly check task executor alive and display status.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-14 15:52:28 +08:00
e27ff8d3d4 Fix: rerank algorithm (#11266)
### What problem does this PR solve?

Fix: rerank algorithm #11234

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-14 13:59:54 +08:00
5f59418aba Remove leftover account and password from the code (#11248)
Remove legacy accounts and passwords.

### What problem does this PR solve?

Remove leftover account and password in
agent/templates/sql_assistant.json

### Type of change

- [x] Other (please describe):
2025-11-14 13:59:03 +08:00
87e69868c0 Fixes: Added session variable types and modified configuration (#11269)
### What problem does this PR solve?

Fixes: Added session variable types and modified configuration

- Added more types of session variables
- Modified the embedding model switching logic in the knowledge base
configuration

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-14 13:56:56 +08:00
72c20022f6 Refactor service config fetching in admin server (#11267)
### What problem does this PR solve?

As title

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
Co-authored-by: Zhichang Yu <yuzhichang@gmail.com>
2025-11-14 12:32:08 +08:00
3f2472f1b9 Skip checking python comments 2025-11-14 11:59:15 +08:00
1d4d67daf8 Fix check_comment_ascii.py 2025-11-14 11:45:32 +08:00
7538e218a5 Fix check_comment_ascii.py 2025-11-14 11:32:55 +08:00
6b52f7df5a CI check comments of changed Python files 2025-11-14 10:54:07 +08:00
63131ec9b2 Docs: default admin credentials (#11260)
### What problem does this PR solve?

### Type of change

- [x] Documentation Update
2025-11-14 09:35:56 +08:00
e8f1a245a6 Feat: update check_embedding api (#11254)
### What problem does this PR solve?
pr: 
#10854
change:
update check_embedding api

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-11-13 18:48:25 +08:00
908450509f Feat: add fault-tolerant mechanism to RAPTOR (#11206)
### What problem does this PR solve?

Add fault-tolerant mechanism to RAPTOR.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-11-13 18:48:07 +08:00
70a0f081f6 Minor tweaks (#11249)
### What problem does this PR solve?

Fix some IDE warnings

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-11-13 16:11:07 +08:00
93422fa8cc Fix: Law parser (#11246)
### What problem does this PR solve?

Fix: Law parser
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-13 15:19:02 +08:00
bfc84ba95b Test: handle duplicate names by appending "(1)" (#11244)
### What problem does this PR solve?

- Updated tests to reflect new behavior of handling duplicate dataset
names
- Instead of returning an error, the system now appends "(1)" to
duplicate names
- This problem was introduced by PR #10960

### Type of change

- [x] Testcase update
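
A sketch of the duplicate-name behavior the updated tests expect; the suffix sequence beyond "(1)" is an assumption:

```python
def dedup_dataset_name(name: str, existing: set[str]) -> str:
    # Append "(1)", "(2)", ... until the candidate no longer collides.
    if name not in existing:
        return name
    i = 1
    while f"{name}({i})" in existing:
        i += 1
    return f"{name}({i})"
```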
2025-11-13 15:18:32 +08:00
871055b0fc Feat: support API for generating knowledge graph and raptor (#11229)
### What problem does this PR solve?
issue:
[#11195](https://github.com/infiniflow/ragflow/issues/11195)
change:
support API for generating knowledge graph and raptor

### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- [x] Documentation Update
2025-11-13 15:17:52 +08:00
ba71160b14 Refa: rm useless code. (#11238)
### Type of change

- [x] Refactoring
2025-11-13 09:59:55 +08:00
bd5dda6b10 Feature/doc upload api add parent path 20251112 (#11231)
### What problem does this PR solve?

Add the specified parent_path to the document upload api interface
(#11230)

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: virgilwong <hyhvirgil@gmail.com>
2025-11-13 09:59:39 +08:00
774563970b Fix: update readme (#11212)
### What problem does this PR solve?

Continue update readme #11167 

### Type of change

- [x] Documentation Update
2025-11-13 09:50:47 +08:00
83d84e90ed Fix: Profile picture cropping supported #10703 (#11221)
### What problem does this PR solve?

Fix: Profile picture cropping supported

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-13 09:50:10 +08:00
8ef2f79d0a Fix: reset the agent component’s output (#11222)
### What problem does this PR solve?

change:
“After each dialogue turn, the agent component’s output is not reset.”

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-13 09:49:12 +08:00
296476ab89 Refactor function name (#11210)
### What problem does this PR solve?

As title

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-11-12 19:00:15 +08:00
237 changed files with 12560 additions and 6940 deletions

View File

@@ -95,6 +95,38 @@ jobs:
version: ">=0.11.x"
args: "check"
- name: Check comments of changed Python files
if: ${{ false }}
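# Step disabled by the always-false condition above; see commit 3f2472f1b9 ("Skip checking python comments").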
run: |
if [[ ${{ github.event_name }} == 'pull_request_target' ]]; then
CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}...${{ github.event.pull_request.head.sha }} \
| grep -E '\.(py)$' || true)
if [ -n "$CHANGED_FILES" ]; then
echo "Check comments of changed Python files with check_comment_ascii.py"
readarray -t files <<< "$CHANGED_FILES"
HAS_ERROR=0
for file in "${files[@]}"; do
if [ -f "$file" ]; then
if python3 check_comment_ascii.py "$file"; then
echo "✅ $file"
else
echo "❌ $file"
HAS_ERROR=1
fi
fi
done
if [ $HAS_ERROR -ne 0 ]; then
exit 1
fi
else
echo "No Python files changed"
fi
fi
- name: Build ragflow:nightly
run: |
RUNNER_WORKSPACE_PREFIX=${RUNNER_WORKSPACE_PREFIX:-${HOME}}

View File

@@ -51,7 +51,9 @@ RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \
apt install -y libjemalloc-dev && \
apt install -y python3-pip pipx nginx unzip curl wget git vim less && \
apt install -y ghostscript
apt install -y ghostscript && \
apt install -y pandoc && \
apt install -y texlive
RUN if [ "$NEED_MIRROR" == "1" ]; then \
pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \

View File

@@ -22,7 +22,7 @@
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
</a>
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
</a>
<a href="https://github.com/infiniflow/ragflow/releases/latest">
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
@@ -85,6 +85,7 @@ Try our demo at [https://demo.ragflow.io](https://demo.ragflow.io).
## 🔥 Latest Updates
- 2025-11-19 Supports Gemini 3 Pro.
- 2025-11-12 Supports data synchronization from Confluence, AWS S3, Discord, Google Drive.
- 2025-10-23 Supports MinerU & Docling as document parsing methods.
- 2025-10-15 Supports orchestrable ingestion pipeline.
@@ -93,8 +94,6 @@ Try our demo at [https://demo.ragflow.io](https://demo.ragflow.io).
- 2025-05-23 Adds a Python/JavaScript code executor component to Agent.
- 2025-05-05 Supports cross-language query.
- 2025-03-19 Supports using a multi-modal model to make sense of images within PDF or DOCX files.
- 2024-12-18 Upgrades Document Layout Analysis model in DeepDoc.
- 2024-08-22 Support text to SQL statements through RAG.
## 🎉 Stay Tuned
@@ -188,13 +187,15 @@ releases! 🌟
> All Docker images are built for x86 platforms. We don't currently offer Docker images for ARM64.
> If you are on an ARM64 platform, follow [this guide](https://ragflow.io/docs/dev/build_docker_image) to build a Docker image compatible with your system.
> The command below downloads the `v0.22.0` edition of the RAGFlow Docker image. See the following table for descriptions of different RAGFlow editions. To download a RAGFlow edition different from `v0.22.0`, update the `RAGFLOW_IMAGE` variable accordingly in **docker/.env** before using `docker compose` to start the server.
> The command below downloads the `v0.22.1` edition of the RAGFlow Docker image. See the following table for descriptions of different RAGFlow editions. To download a RAGFlow edition different from `v0.22.1`, update the `RAGFLOW_IMAGE` variable accordingly in **docker/.env** before using `docker compose` to start the server.
```bash
$ cd ragflow/docker
# git checkout v0.22.1
# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases)
# This step ensures the **entrypoint.sh** file in the code matches the Docker image version.
# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases), e.g.: git checkout v0.22.0
# Use CPU for DeepDoc tasks:
$ docker compose -f docker-compose.yml up -d

View File

@@ -22,7 +22,7 @@
<img alt="Lencana Daring" src="https://img.shields.io/badge/Online-Demo-4e6b99">
</a>
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
</a>
<a href="https://github.com/infiniflow/ragflow/releases/latest">
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Rilis%20Terbaru" alt="Rilis Terbaru">
@@ -85,6 +85,7 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io).
## 🔥 Pembaruan Terbaru
- 2025-11-19 Mendukung Gemini 3 Pro.
- 2025-11-12 Mendukung sinkronisasi data dari Confluence, AWS S3, Discord, Google Drive.
- 2025-10-23 Mendukung MinerU & Docling sebagai metode penguraian dokumen.
- 2025-10-15 Dukungan untuk jalur data yang terorkestrasi.
@@ -186,12 +187,14 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io).
> Semua gambar Docker dibangun untuk platform x86. Saat ini, kami tidak menawarkan gambar Docker untuk ARM64.
> Jika Anda menggunakan platform ARM64, [silakan gunakan panduan ini untuk membangun gambar Docker yang kompatibel dengan sistem Anda](https://ragflow.io/docs/dev/build_docker_image).
> Perintah di bawah ini mengunduh edisi v0.22.0 dari gambar Docker RAGFlow. Silakan merujuk ke tabel berikut untuk deskripsi berbagai edisi RAGFlow. Untuk mengunduh edisi RAGFlow yang berbeda dari v0.22.0, perbarui variabel RAGFLOW_IMAGE di docker/.env sebelum menggunakan docker compose untuk memulai server.
> Perintah di bawah ini mengunduh edisi v0.22.1 dari gambar Docker RAGFlow. Silakan merujuk ke tabel berikut untuk deskripsi berbagai edisi RAGFlow. Untuk mengunduh edisi RAGFlow yang berbeda dari v0.22.1, perbarui variabel RAGFLOW_IMAGE di docker/.env sebelum menggunakan docker compose untuk memulai server.
```bash
$ cd ragflow/docker
# Opsional: gunakan tag stabil (lihat releases: https://github.com/infiniflow/ragflow/releases), contoh: git checkout v0.22.0
# git checkout v0.22.1
# Opsional: gunakan tag stabil (lihat releases: https://github.com/infiniflow/ragflow/releases)
# This step ensures the **entrypoint.sh** file in the code matches the Docker image version.
# Use CPU for DeepDoc tasks:
$ docker compose -f docker-compose.yml up -d

View File

@@ -22,7 +22,7 @@
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
</a>
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
</a>
<a href="https://github.com/infiniflow/ragflow/releases/latest">
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
@@ -66,6 +66,7 @@
## 🔥 最新情報
- 2025-11-19 Gemini 3 Proをサポートしています
- 2025-11-12 Confluence、AWS S3、Discord、Google Drive からのデータ同期をサポートします。
- 2025-10-23 ドキュメント解析方法として MinerU と Docling をサポートします。
- 2025-10-15 オーケストレーションされたデータパイプラインのサポート。
@@ -166,12 +167,14 @@
> 現在、公式に提供されているすべての Docker イメージは x86 アーキテクチャ向けにビルドされており、ARM64 用の Docker イメージは提供されていません。
> ARM64 アーキテクチャのオペレーティングシステムを使用している場合は、[このドキュメント](https://ragflow.io/docs/dev/build_docker_image)を参照して Docker イメージを自分でビルドしてください。
> 以下のコマンドは、RAGFlow Docker イメージの v0.22.0 エディションをダウンロードします。異なる RAGFlow エディションの説明については、以下の表を参照してください。v0.22.0 とは異なるエディションをダウンロードするには、docker/.env ファイルの RAGFLOW_IMAGE 変数を適宜更新し、docker compose を使用してサーバーを起動してください。
> 以下のコマンドは、RAGFlow Docker イメージの v0.22.1 エディションをダウンロードします。異なる RAGFlow エディションの説明については、以下の表を参照してください。v0.22.1 とは異なるエディションをダウンロードするには、docker/.env ファイルの RAGFLOW_IMAGE 変数を適宜更新し、docker compose を使用してサーバーを起動してください。
```bash
$ cd ragflow/docker
# 任意: 安定版タグを利用 (一覧: https://github.com/infiniflow/ragflow/releases) 例: git checkout v0.22.0
# git checkout v0.22.1
# 任意: 安定版タグを利用 (一覧: https://github.com/infiniflow/ragflow/releases)
# この手順は、コード内の entrypoint.sh ファイルが Docker イメージのバージョンと一致していることを確認します。
# Use CPU for DeepDoc tasks:
$ docker compose -f docker-compose.yml up -d

View File

@@ -22,7 +22,7 @@
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
</a>
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
</a>
<a href="https://github.com/infiniflow/ragflow/releases/latest">
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
@@ -67,6 +67,7 @@
## 🔥 업데이트
- 2025-11-19 Gemini 3 Pro를 지원합니다.
- 2025-11-12 Confluence, AWS S3, Discord, Google Drive에서 데이터 동기화를 지원합니다.
- 2025-10-23 문서 파싱 방법으로 MinerU 및 Docling을 지원합니다.
- 2025-10-15 조정된 데이터 파이프라인 지원.
@@ -168,12 +169,14 @@
> 모든 Docker 이미지는 x86 플랫폼을 위해 빌드되었습니다. 우리는 현재 ARM64 플랫폼을 위한 Docker 이미지를 제공하지 않습니다.
> ARM64 플랫폼을 사용 중이라면, [시스템과 호환되는 Docker 이미지를 빌드하려면 이 가이드를 사용해 주세요](https://ragflow.io/docs/dev/build_docker_image).
> 아래 명령어는 RAGFlow Docker 이미지의 v0.22.0 버전을 다운로드합니다. 다양한 RAGFlow 버전에 대한 설명은 다음 표를 참조하십시오. v0.22.0과 다른 RAGFlow 버전을 다운로드하려면, docker/.env 파일에서 RAGFLOW_IMAGE 변수를 적절히 업데이트한 후 docker compose를 사용하여 서버를 시작하십시오.
> 아래 명령어는 RAGFlow Docker 이미지의 v0.22.1 버전을 다운로드합니다. 다양한 RAGFlow 버전에 대한 설명은 다음 표를 참조하십시오. v0.22.1과 다른 RAGFlow 버전을 다운로드하려면, docker/.env 파일에서 RAGFLOW_IMAGE 변수를 적절히 업데이트한 후 docker compose를 사용하여 서버를 시작하십시오.
```bash
$ cd ragflow/docker
# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases), e.g.: git checkout v0.22.0
# git checkout v0.22.1
# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases)
# 이 단계는 코드의 entrypoint.sh 파일이 Docker 이미지 버전과 일치하도록 보장합니다.
# Use CPU for DeepDoc tasks:
$ docker compose -f docker-compose.yml up -d

View File

@@ -22,7 +22,7 @@
<img alt="Badge Estático" src="https://img.shields.io/badge/Online-Demo-4e6b99">
</a>
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
</a>
<a href="https://github.com/infiniflow/ragflow/releases/latest">
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Última%20Relese" alt="Última Versão">
@@ -86,6 +86,7 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io).
## 🔥 Últimas Atualizações
- 19-11-2025 Suporta Gemini 3 Pro.
- 12-11-2025 Suporta a sincronização de dados do Confluence, AWS S3, Discord e Google Drive.
- 23-10-2025 Suporta MinerU e Docling como métodos de análise de documentos.
- 15-10-2025 Suporte para pipelines de dados orquestrados.
@@ -186,12 +187,14 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io).
> Todas as imagens Docker são construídas para plataformas x86. Atualmente, não oferecemos imagens Docker para ARM64.
> Se você estiver usando uma plataforma ARM64, por favor, utilize [este guia](https://ragflow.io/docs/dev/build_docker_image) para construir uma imagem Docker compatível com o seu sistema.
> O comando abaixo baixa a edição`v0.22.0` da imagem Docker do RAGFlow. Consulte a tabela a seguir para descrições de diferentes edições do RAGFlow. Para baixar uma edição do RAGFlow diferente da `v0.22.0`, atualize a variável `RAGFLOW_IMAGE` conforme necessário no **docker/.env** antes de usar `docker compose` para iniciar o servidor.
> O comando abaixo baixa a edição`v0.22.1` da imagem Docker do RAGFlow. Consulte a tabela a seguir para descrições de diferentes edições do RAGFlow. Para baixar uma edição do RAGFlow diferente da `v0.22.1`, atualize a variável `RAGFLOW_IMAGE` conforme necessário no **docker/.env** antes de usar `docker compose` para iniciar o servidor.
```bash
$ cd ragflow/docker
# Opcional: use uma tag estável (veja releases: https://github.com/infiniflow/ragflow/releases), ex.: git checkout v0.22.0
# git checkout v0.22.1
# Opcional: use uma tag estável (veja releases: https://github.com/infiniflow/ragflow/releases)
# Esta etapa garante que o arquivo entrypoint.sh no código corresponda à versão da imagem do Docker.
# Use CPU for DeepDoc tasks:
$ docker compose -f docker-compose.yml up -d

View File

@@ -22,7 +22,7 @@
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
</a>
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
</a>
<a href="https://github.com/infiniflow/ragflow/releases/latest">
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
@@ -85,6 +85,7 @@
## 🔥 近期更新
- 2025-11-19 支援 Gemini 3 Pro.
- 2025-11-12 支援從 Confluence、AWS S3、Discord、Google Drive 進行資料同步。
- 2025-10-23 支援 MinerU 和 Docling 作為文件解析方法。
- 2025-10-15 支援可編排的資料管道。
@@ -185,12 +186,14 @@
> 所有 Docker 映像檔都是為 x86 平台建置的。目前,我們不提供 ARM64 平台的 Docker 映像檔。
> 如果您使用的是 ARM64 平台,請使用 [這份指南](https://ragflow.io/docs/dev/build_docker_image) 來建置適合您系統的 Docker 映像檔。
> 執行以下指令會自動下載 RAGFlow Docker 映像 `v0.22.0`。請參考下表查看不同 Docker 發行版的說明。如需下載不同於 `v0.22.0` 的 Docker 映像,請在執行 `docker compose` 啟動服務之前先更新 **docker/.env** 檔案內的 `RAGFLOW_IMAGE` 變數。
> 執行以下指令會自動下載 RAGFlow Docker 映像 `v0.22.1`。請參考下表查看不同 Docker 發行版的說明。如需下載不同於 `v0.22.1` 的 Docker 映像,請在執行 `docker compose` 啟動服務之前先更新 **docker/.env** 檔案內的 `RAGFLOW_IMAGE` 變數。
```bash
$ cd ragflow/docker
# 可選:使用穩定版標籤(查看發佈:https://github.com/infiniflow/ragflow/releases):git checkout v0.22.0
# git checkout v0.22.1
# 可選:使用穩定版標籤(查看發佈:https://github.com/infiniflow/ragflow/releases)
# 此步驟確保程式碼中的 entrypoint.sh 檔案與 Docker 映像版本一致。
# Use CPU for DeepDoc tasks:
$ docker compose -f docker-compose.yml up -d

View File

@@ -22,7 +22,7 @@
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
</a>
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
</a>
<a href="https://github.com/infiniflow/ragflow/releases/latest">
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
@@ -85,6 +85,7 @@
## 🔥 近期更新
- 2025-11-19 支持 Gemini 3 Pro.
- 2025-11-12 支持从 Confluence、AWS S3、Discord、Google Drive 进行数据同步。
- 2025-10-23 支持 MinerU 和 Docling 作为文档解析方法。
- 2025-10-15 支持可编排的数据管道。
@@ -186,12 +187,14 @@
> 请注意,目前官方提供的所有 Docker 镜像均基于 x86 架构构建,并不提供基于 ARM64 的 Docker 镜像。
> 如果你的操作系统是 ARM64 架构,请参考[这篇文档](https://ragflow.io/docs/dev/build_docker_image)自行构建 Docker 镜像。
> 运行以下命令会自动下载 RAGFlow Docker 镜像 `v0.22.0`。请参考下表查看不同 Docker 发行版的描述。如需下载不同于 `v0.22.0` 的 Docker 镜像,请在运行 `docker compose` 启动服务之前先更新 **docker/.env** 文件内的 `RAGFLOW_IMAGE` 变量。
> 运行以下命令会自动下载 RAGFlow Docker 镜像 `v0.22.1`。请参考下表查看不同 Docker 发行版的描述。如需下载不同于 `v0.22.1` 的 Docker 镜像,请在运行 `docker compose` 启动服务之前先更新 **docker/.env** 文件内的 `RAGFLOW_IMAGE` 变量。
```bash
$ cd ragflow/docker
# 可选:使用稳定版本标签(查看发布:https://github.com/infiniflow/ragflow/releases),例如:git checkout v0.22.0
# git checkout v0.22.1
# 可选:使用稳定版本标签(查看发布:https://github.com/infiniflow/ragflow/releases)
# 这一步确保代码中的 entrypoint.sh 文件与 Docker 镜像的版本保持一致。
# Use CPU for DeepDoc tasks:
$ docker compose -f docker-compose.yml up -d

View File

@@ -4,7 +4,7 @@
Admin Service is a dedicated management component designed to monitor, maintain, and administrate the RAGFlow system. It provides comprehensive tools for ensuring system stability, performing operational tasks, and managing users and permissions efficiently.
The service offers real-time monitoring of critical components, including the RAGFlow server, Task Executor processes, and dependent services such as MySQL, Elasticsearch, Redis, and MinIO. It automatically checks their health status, resource usage, and uptime, and performs restarts in case of failures to minimize downtime.
The service offers real-time monitoring of critical components, including the RAGFlow server, Task Executor processes, and dependent services such as MySQL, Infinity, Elasticsearch, Redis, and MinIO. It automatically checks their health status, resource usage, and uptime, and performs restarts in case of failures to minimize downtime.
For user and system management, it supports listing, creating, modifying, and deleting users and their associated resources like knowledge bases and Agents.
@@ -48,7 +48,7 @@ It consists of a server-side Service and a command-line client (CLI), both imple
1. Ensure the Admin Service is running.
2. Install ragflow-cli.
```bash
pip install ragflow-cli==0.22.0
pip install ragflow-cli==0.22.1
```
3. Launch the CLI client:
```bash

View File

@@ -378,7 +378,7 @@ class AdminCLI(Cmd):
self.session.headers.update({
'Content-Type': 'application/json',
'Authorization': response.headers['Authorization'],
'User-Agent': 'RAGFlow-CLI/0.22.0'
'User-Agent': 'RAGFlow-CLI/0.22.1'
})
print("Authentication successful.")
return True
@@ -393,7 +393,9 @@ class AdminCLI(Cmd):
print(f"Can't access {self.host}, port: {self.port}")
def _format_service_detail_table(self, data):
if not any([isinstance(v, list) for v in data.values()]):
if isinstance(data, list):
return data
if not all([isinstance(v, list) for v in data.values()]):
# normal table
return data
# handle task_executor heartbeats map, for example {'name': [{'done': 2, 'now': timestamp1}, {'done': 3, 'now': timestamp2}]
@@ -404,7 +406,7 @@
task_executor_list.append({
"task_executor_name": k,
**heartbeats[0],
})
} if heartbeats else {"task_executor_name": k})
return task_executor_list
def _print_table_simple(self, data):
@@ -415,7 +417,8 @@
# handle single row data
data = [data]
columns = list(data[0].keys())
columns = list(set().union(*(d.keys() for d in data)))
columns.sort()
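# Taking the union of keys across all rows tolerates heterogeneous dicts; sorting keeps the column order stable.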
col_widths = {}
def get_string_width(text):

View File

@@ -1,6 +1,6 @@
[project]
name = "ragflow-cli"
version = "0.22.0"
version = "0.22.1"
description = "Admin Service's client of [RAGFlow](https://github.com/infiniflow/ragflow). The Admin Service provides user management and system monitoring. "
authors = [{ name = "Lynn", email = "lynn_inf@hotmail.com" }]
license = { text = "Apache License, Version 2.0" }

View File

@@ -20,8 +20,10 @@ import logging
import time
import threading
import traceback
from werkzeug.serving import run_simple
from flask import Flask
from flask_login import LoginManager
from werkzeug.serving import run_simple
from routes import admin_bp
from common.log_utils import init_root_logger
from common.constants import SERVICE_CONF
@@ -30,7 +32,6 @@ from common import settings
from config import load_configurations, SERVICE_CONFIGS
from auth import init_default_admin, setup_auth
from flask_session import Session
from flask_login import LoginManager
from common.versions import get_ragflow_version
stop_event = threading.Event()

View File

@@ -19,7 +19,8 @@ import logging
import uuid
from functools import wraps
from datetime import datetime
from flask import request, jsonify
from flask import jsonify, request
from flask_login import current_user, login_user
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
@@ -30,7 +31,7 @@ from common.constants import ActiveEnum, StatusEnum
from api.utils.crypt import decrypt
from common.misc_utils import get_uuid
from common.time_utils import current_timestamp, datetime_format, get_format_time
from common.connection_utils import construct_response
from common.connection_utils import sync_construct_response
from common import settings
@@ -129,7 +130,7 @@ def login_admin(email: str, password: str):
user.last_login_time = get_format_time()
user.save()
msg = "Welcome back!"
return construct_response(data=resp, auth=user.get_id(), message=msg)
return sync_construct_response(data=resp, auth=user.get_id(), message=msg)
def check_admin(username: str, password: str):
@@ -169,7 +170,7 @@ def login_verify(f):
username = auth.parameters['username']
password = auth.parameters['password']
try:
if check_admin(username, password) is False:
if not check_admin(username, password):
return jsonify({
"code": 500,
"message": "Access denied",

View File

@@ -25,8 +25,21 @@ from common.config_utils import read_config
from urllib.parse import urlparse
class BaseConfig(BaseModel):
id: int
name: str
host: str
port: int
service_type: str
detail_func_name: str
def to_dict(self) -> dict[str, Any]:
return {'id': self.id, 'name': self.name, 'host': self.host, 'port': self.port,
'service_type': self.service_type}
class ServiceConfigs:
configs = dict
configs = list[BaseConfig]
def __init__(self):
self.configs = []
@@ -45,19 +58,6 @@ class ServiceType(Enum):
FILE_STORE = "file_store"
class BaseConfig(BaseModel):
id: int
name: str
host: str
port: int
service_type: str
detail_func_name: str
def to_dict(self) -> dict[str, Any]:
return {'id': self.id, 'name': self.name, 'host': self.host, 'port': self.port,
'service_type': self.service_type}
class MetaConfig(BaseConfig):
meta_type: str
@@ -227,7 +227,7 @@ def load_configurations(config_path: str) -> list[BaseConfig]:
ragflow_count = 0
id_count = 0
for k, v in raw_configs.items():
match (k):
match k:
case "ragflow":
name: str = f'ragflow_{ragflow_count}'
host: str = v['host']

View File

@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from flask import jsonify

View File

@@ -17,7 +17,7 @@
import secrets
from flask import Blueprint, request
from flask_login import current_user, logout_user, login_required
from flask_login import current_user, login_required, logout_user
from auth import login_verify, login_admin, check_admin_auth
from responses import success_response, error_response

View File

@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import re
from werkzeug.security import check_password_hash
from common.constants import ActiveEnum
@@ -190,7 +189,8 @@ class ServiceMgr:
config_dict['status'] = service_detail['status']
else:
config_dict['status'] = 'timeout'
except Exception:
except Exception as e:
logging.warning(f"Can't get service details, error: {e}")
config_dict['status'] = 'timeout'
if not config_dict['host']:
config_dict['host'] = '-'
@@ -205,17 +205,13 @@
@staticmethod
def get_service_details(service_id: int):
service_id = int(service_id)
service_idx = int(service_id)
configs = SERVICE_CONFIGS.configs
service_config_mapping = {
c.id: {
'name': c.name,
'detail_func_name': c.detail_func_name
} for c in configs
}
service_info = service_config_mapping.get(service_id, {})
if not service_info:
raise AdminException(f"invalid service_id: {service_id}")
if service_idx < 0 or service_idx >= len(configs):
raise AdminException(f"invalid service_index: {service_idx}")
service_config = configs[service_idx]
service_info = {'name': service_config.name, 'detail_func_name': service_config.detail_func_name}
detail_func = getattr(health_utils, service_info.get('detail_func_name'))
res = detail_func()

View File

@@ -25,7 +25,6 @@ from typing import Any, Union, Tuple
from agent.component import component_class
from agent.component.base import ComponentBase
from api.db.services.file_service import FileService
from api.db.services.task_service import has_canceled
from common.misc_utils import get_uuid, hash_str2int
from common.exceptions import TaskCanceledException
@@ -217,6 +216,38 @@ class Graph:
else:
cur = getattr(cur, key, None)
return cur
def set_variable_value(self, exp: str,value):
exp = exp.strip("{").strip("}").strip(" ").strip("{").strip("}")
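# Strips optional brace wrapping so both "{{cpn@var}}" and bare "cpn@var" references are accepted.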
if exp.find("@") < 0:
self.globals[exp] = value
return
cpn_id, var_nm = exp.split("@")
cpn = self.get_component(cpn_id)
if not cpn:
raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'")
parts = var_nm.split(".", 1)
root_key = parts[0]
rest = parts[1] if len(parts) > 1 else ""
if not rest:
cpn["obj"].set_output(root_key, value)
return
root_val = cpn["obj"].output(root_key)
if not root_val:
root_val = {}
cpn["obj"].set_output(root_key, self.set_variable_param_value(root_val,rest,value))
def set_variable_param_value(self, obj: Any, path: str, value) -> Any:
cur = obj
keys = path.split('.')
if not path:
return value
for key in keys:
if key not in cur or not isinstance(cur[key], dict):
cur[key] = {}
cur = cur[key]
cur[keys[-1]] = value
return obj
def is_canceled(self) -> bool:
return has_canceled(self.task_id)
@@ -270,7 +301,7 @@ class Canvas(Graph):
self.retrieval = []
self.memory = []
for k in self.globals.keys():
if k.startswith("sys."):
if k.startswith("sys.") or k.startswith("env."):
if isinstance(self.globals[k], str):
self.globals[k] = ""
elif isinstance(self.globals[k], int):
@@ -284,7 +315,7 @@
else:
self.globals[k] = None
def run(self, **kwargs):
async def run(self, **kwargs):
st = time.perf_counter()
self.message_id = get_uuid()
created_at = int(time.time())
@@ -298,8 +329,6 @@
for kk, vv in kwargs["webhook_payload"].items():
self.components[k]["obj"].set_output(kk, vv)
self.components[k]["obj"].reset(True)
for k in kwargs.keys():
if k in ["query", "user_id", "files"] and kwargs[k]:
if k == "files":
@@ -408,6 +437,10 @@
else:
yield decorate("message", {"content": cpn_obj.output("content")})
cite = re.search(r"\[ID:[ 0-9]+\]", cpn_obj.output("content"))
if isinstance(cpn_obj.output("attachment"), tuple):
yield decorate("message", {"attachment": cpn_obj.output("attachment")})
yield decorate("message_end", {"reference": self.get_reference() if cite else None})
while partials:
@@ -547,6 +580,7 @@
return self.components[cpnnm]["obj"].get_input_elements()
def get_files(self, files: Union[None, list[dict]]) -> list[str]:
from api.db.services.file_service import FileService
if not files:
return []
def image_to_base64(file):

View File

@@ -30,7 +30,7 @@ from api.db.services.mcp_server_service import MCPServerService
from common.connection_utils import timeout
from rag.prompts.generator import next_step, COMPLETE_TASK, analyze_task, \
citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
from agent.component.llm import LLMParam, LLM
@@ -163,12 +163,7 @@ class Agent(LLM, ToolBase):
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
ex = self.exception_handler()
output_structure=None
try:
output_structure=self._param.outputs['structured']
except Exception:
pass
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not output_structure and not (ex and ex["goto"]):
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]):
self.set_output("content", partial(self.stream_output_with_tools, prompt, msg, user_defined_prompt))
return
@@ -368,11 +363,19 @@
return "Error occurred."
def reset(self, temp=False):
def reset(self, only_output=False):
"""
Reset all tools if they have a reset method. This avoids errors for tools like MCPToolCallSession.
"""
for k in self._param.outputs.keys():
self._param.outputs[k]["value"] = None
for k, cpn in self.tools.items():
if hasattr(cpn, "reset") and callable(cpn.reset):
cpn.reset()
if only_output:
return
for k in self._param.inputs.keys():
self._param.inputs[k]["value"] = None
self._param.debug_inputs = {}

View File

@@ -463,12 +463,15 @@ class ComponentBase(ABC):
return self._param.outputs.get("_ERROR", {}).get("value")
def reset(self, only_output=False):
for k in self._param.outputs.keys():
self._param.outputs[k]["value"] = None
outputs: dict = self._param.outputs # for better performance
for k in outputs.keys():
outputs[k]["value"] = None
if only_output:
return
for k in self._param.inputs.keys():
self._param.inputs[k]["value"] = None
inputs: dict = self._param.inputs # for better performance
for k in inputs.keys():
inputs[k]["value"] = None
self._param.debug_inputs = {}
def get_input(self, key: str=None) -> Union[Any, dict[str, Any]]:

View File

@@ -1,3 +1,18 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC
import ast
import os

View File

@@ -32,6 +32,7 @@ class IterationParam(ComponentParamBase):
def __init__(self):
super().__init__()
self.items_ref = ""
self.veriable={}
def get_input_form(self) -> dict[str, dict]:
return {

View File

@@ -0,0 +1,168 @@
from abc import ABC
import os
from agent.component.base import ComponentBase, ComponentParamBase
from api.utils.api_utils import timeout
class ListOperationsParam(ComponentParamBase):
"""
Define the List Operations component parameters.
"""
def __init__(self):
super().__init__()
self.query = ""
self.operations = "topN"
self.n=0
self.sort_method = "asc"
self.filter = {
"operator": "=",
"value": ""
}
self.outputs = {
"result": {
"value": [],
"type": "Array of ?"
},
"first": {
"value": "",
"type": "?"
},
"last": {
"value": "",
"type": "?"
}
}
def check(self):
self.check_empty(self.query, "query")
self.check_valid_value(self.operations, "Support operations", ["topN","head","tail","filter","sort","drop_duplicates"])
def get_input_form(self) -> dict[str, dict]:
return {}
class ListOperations(ComponentBase,ABC):
component_name = "ListOperations"
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs):
self.input_objects=[]
inputs = getattr(self._param, "query", None)
self.inputs = self._canvas.get_variable_value(inputs)
if not isinstance(self.inputs, list):
raise TypeError("The input of List Operations should be an array.")
self.set_input_value(inputs, self.inputs)
if self._param.operations == "topN":
self._topN()
elif self._param.operations == "head":
self._head()
elif self._param.operations == "tail":
self._tail()
elif self._param.operations == "filter":
self._filter()
elif self._param.operations == "sort":
self._sort()
elif self._param.operations == "drop_duplicates":
self._drop_duplicates()
def _coerce_n(self):
try:
return int(getattr(self._param, "n", 0))
except Exception:
return 0
def _set_outputs(self, outputs):
self._param.outputs["result"]["value"] = outputs
self._param.outputs["first"]["value"] = outputs[0] if outputs else None
self._param.outputs["last"]["value"] = outputs[-1] if outputs else None
def _topN(self):
n = self._coerce_n()
if n < 1:
outputs = []
else:
n = min(n, len(self.inputs))
outputs = self.inputs[:n]
self._set_outputs(outputs)
def _head(self):
n = self._coerce_n()
if 1 <= n <= len(self.inputs):
outputs = [self.inputs[n - 1]]
else:
outputs = []
self._set_outputs(outputs)
def _tail(self):
n = self._coerce_n()
if 1 <= n <= len(self.inputs):
outputs = [self.inputs[-n]]
else:
outputs = []
self._set_outputs(outputs)
def _filter(self):
self._set_outputs([i for i in self.inputs if self._eval(self._norm(i),self._param.filter["operator"],self._param.filter["value"])])
def _norm(self,v):
s = "" if v is None else str(v)
return s
def _eval(self, v, operator, value):
if operator == "=":
return v == value
elif operator == "≠":
return v != value
elif operator == "contains":
return value in v
elif operator == "start with":
return v.startswith(value)
elif operator == "end with":
return v.endswith(value)
else:
return False
def _sort(self):
items = self.inputs or []
method = getattr(self._param, "sort_method", "asc") or "asc"
reverse = method == "desc"
if not items:
self._set_outputs([])
return
first = items[0]
if isinstance(first, dict):
outputs = sorted(
items,
key=lambda x: self._hashable(x),
reverse=reverse,
)
else:
outputs = sorted(items, reverse=reverse)
self._set_outputs(outputs)
def _drop_duplicates(self):
seen = set()
outs = []
for item in self.inputs:
k = self._hashable(item)
if k in seen:
continue
seen.add(k)
outs.append(item)
self._set_outputs(outs)
def _hashable(self,x):
if isinstance(x, dict):
return tuple(sorted((k, self._hashable(v)) for k, v in x.items()))
if isinstance(x, (list, tuple)):
return tuple(self._hashable(v) for v in x)
if isinstance(x, set):
return tuple(sorted(self._hashable(v) for v in x))
return x
def thoughts(self) -> str:
return "ListOperation in progress"
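A standalone sketch (hypothetical names, no canvas dependency) of how the `_hashable()` recursion above lets `_drop_duplicates` dedupe dicts and lists, which `set` cannot hold directly:

```python
# Mirrors ListOperations._hashable: normalize nested containers into
# sorted tuples so structurally-equal items map to the same set key.
def hashable(x):
    if isinstance(x, dict):
        return tuple(sorted((k, hashable(v)) for k, v in x.items()))
    if isinstance(x, (list, tuple)):
        return tuple(hashable(v) for v in x)
    return x

items = [{"a": 1, "b": 2}, {"b": 2, "a": 1}, {"a": 1}]
seen, deduped = set(), []
for item in items:
    key = hashable(item)
    if key not in seen:  # the two key-order variants collapse to one entry
        seen.add(key)
        deduped.append(item)
print(deduped)  # -> [{'a': 1, 'b': 2}, {'a': 1}]
```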

View File

@ -222,7 +222,7 @@ class LLM(ComponentBase):
output_structure = self._param.outputs['structured']
except Exception:
pass
if output_structure:
if output_structure and isinstance(output_structure, dict) and output_structure.get("properties"):
schema=json.dumps(output_structure, ensure_ascii=False, indent=2)
prompt += structured_output_prompt(schema)
for _ in range(self._param.max_retries+1):
@ -249,7 +249,7 @@ class LLM(ComponentBase):
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
ex = self.exception_handler()
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not output_structure and not (ex and ex["goto"]):
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]):
self.set_output("content", partial(self._stream_output, prompt, msg))
return
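The tightened guard above is easy to misread, so here is a minimal sketch of what it accepts and rejects; `wants_structured` is a hypothetical helper name:

```python
# Only a JSON-schema-like dict with a non-empty "properties" map should
# trigger structured-output prompting.
def wants_structured(output_structure) -> bool:
    return bool(
        output_structure
        and isinstance(output_structure, dict)
        and output_structure.get("properties")
    )

print(wants_structured(None))                # False
print(wants_structured({"properties": {}}))  # False: empty schema, skip the prompt
print(wants_structured({"properties": {"name": {"type": "string"}}}))  # True
```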

View File

@ -17,6 +17,8 @@ import json
import os
import random
import re
import logging
import tempfile
from functools import partial
from typing import Any
@ -24,6 +26,8 @@ from agent.component.base import ComponentBase, ComponentParamBase
from jinja2 import Template as Jinja2Template
from common.connection_utils import timeout
from common.misc_utils import get_uuid
from common import settings
class MessageParam(ComponentParamBase):
@ -34,6 +38,7 @@ class MessageParam(ComponentParamBase):
super().__init__()
self.content = []
self.stream = True
self.output_format = None # default output format
self.outputs = {
"content": {
"type": "str"
@ -133,6 +138,7 @@ class Message(ComponentBase):
yield rand_cnt[s: ]
self.set_output("content", all_content)
self._convert_content(all_content)
def _is_jinjia2(self, content:str) -> bool:
patt = [
@ -164,6 +170,72 @@ class Message(ComponentBase):
content = re.sub(n, v, content)
self.set_output("content", content)
self._convert_content(content)
def thoughts(self) -> str:
return ""
def _convert_content(self, content):
if not self._param.output_format:
return
import pypandoc
doc_id = get_uuid()
if self._param.output_format.lower() not in {"markdown", "html", "pdf", "docx"}:
self._param.output_format = "markdown"
try:
if self._param.output_format in {"markdown", "html"}:
if isinstance(content, str):
converted = pypandoc.convert_text(
content,
to=self._param.output_format,
format="markdown",
)
else:
converted = pypandoc.convert_file(
content,
to=self._param.output_format,
format="markdown",
)
binary_content = converted.encode("utf-8")
else: # pdf, docx
with tempfile.NamedTemporaryFile(suffix=f".{self._param.output_format}", delete=False) as tmp:
tmp_name = tmp.name
try:
if isinstance(content, str):
pypandoc.convert_text(
content,
to=self._param.output_format,
format="markdown",
outputfile=tmp_name,
)
else:
pypandoc.convert_file(
content,
to=self._param.output_format,
format="markdown",
outputfile=tmp_name,
)
with open(tmp_name, "rb") as f:
binary_content = f.read()
finally:
if os.path.exists(tmp_name):
os.remove(tmp_name)
settings.STORAGE_IMPL.put(self._canvas._tenant_id, doc_id, binary_content)
self.set_output("attachment", {
"doc_id":doc_id,
"format":self._param.output_format,
"file_name":f"{doc_id[:8]}.{self._param.output_format}"})
logging.info(f"Converted content uploaded as {doc_id} (format={self._param.output_format})")
except Exception as e:
logging.error(f"Error converting content to {self._param.output_format}: {e}")
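A minimal sketch of the `pypandoc` calls `_convert_content` relies on; note that `pypandoc` needs the pandoc binary on the host (`pypandoc.download_pandoc()` can fetch it if missing):

```python
import pypandoc

md = "# Title\n\nSome *markdown* content."

# Text targets come back as a string, as in the markdown/html branch above.
html = pypandoc.convert_text(md, to="html", format="markdown")
print(html)

# Binary targets such as docx must be written to a file, which is why the
# pdf/docx branch above goes through a NamedTemporaryFile.
pypandoc.convert_text(md, to="docx", format="markdown", outputfile="out.docx")
```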

View File

@ -0,0 +1,192 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC
import os
import numbers
from agent.component.base import ComponentBase, ComponentParamBase
from api.utils.api_utils import timeout
class VariableAssignerParam(ComponentParamBase):
"""
Define the Variable Assigner component parameters.
"""
def __init__(self):
super().__init__()
self.variables=[]
def check(self):
return True
def get_input_form(self) -> dict[str, dict]:
return {
"items": {
"type": "json",
"name": "Items"
}
}
class VariableAssigner(ComponentBase,ABC):
component_name = "VariableAssigner"
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs):
if not isinstance(self._param.variables,list):
return
else:
for item in self._param.variables:
if any([not item.get("variable"), not item.get("operator"), not item.get("parameter")]):
raise AssertionError("Variable is not complete.")
variable=item["variable"]
operator=item["operator"]
parameter=item["parameter"]
variable_value=self._canvas.get_variable_value(variable)
new_variable=self._operate(variable_value,operator,parameter)
self._canvas.set_variable_value(variable, new_variable)
def _operate(self,variable,operator,parameter):
if operator == "overwrite":
return self._overwrite(parameter)
elif operator == "clear":
return self._clear(variable)
elif operator == "set":
return self._set(variable,parameter)
elif operator == "append":
return self._append(variable,parameter)
elif operator == "extend":
return self._extend(variable,parameter)
elif operator == "remove_first":
return self._remove_first(variable)
elif operator == "remove_last":
return self._remove_last(variable)
elif operator == "+=":
return self._add(variable,parameter)
elif operator == "-=":
return self._subtract(variable,parameter)
elif operator == "*=":
return self._multiply(variable,parameter)
elif operator == "/=":
return self._divide(variable,parameter)
else:
return
def _overwrite(self,parameter):
return self._canvas.get_variable_value(parameter)
def _clear(self,variable):
if isinstance(variable,list):
return []
elif isinstance(variable,str):
return ""
elif isinstance(variable,dict):
return {}
elif isinstance(variable,int):
return 0
elif isinstance(variable,float):
return 0.0
elif isinstance(variable,bool):
return False
else:
return None
def _set(self,variable,parameter):
if variable is None:
return self._canvas.get_value_with_variable(parameter)
elif isinstance(variable,str):
return self._canvas.get_value_with_variable(parameter)
elif isinstance(variable,bool):
return parameter
elif isinstance(variable,int):
return parameter
elif isinstance(variable,float):
return parameter
else:
return parameter
def _append(self,variable,parameter):
parameter=self._canvas.get_variable_value(parameter)
if variable is None:
variable=[]
if not isinstance(variable,list):
return "ERROR:VARIABLE_NOT_LIST"
elif len(variable)!=0 and not isinstance(parameter,type(variable[0])):
return "ERROR:PARAMETER_NOT_LIST_ELEMENT_TYPE"
else:
variable.append(parameter)
return variable
def _extend(self,variable,parameter):
parameter=self._canvas.get_variable_value(parameter)
if variable is None:
variable=[]
if not isinstance(variable,list):
return "ERROR:VARIABLE_NOT_LIST"
elif not isinstance(parameter,list):
return "ERROR:PARAMETER_NOT_LIST"
elif len(variable)!=0 and len(parameter)!=0 and not isinstance(parameter[0],type(variable[0])):
return "ERROR:PARAMETER_NOT_LIST_ELEMENT_TYPE"
else:
return variable + parameter
def _remove_first(self,variable):
    # check the type first: len() would raise on non-sized values
    if not isinstance(variable,list):
        return "ERROR:VARIABLE_NOT_LIST"
    if len(variable)==0:
        return variable
    return variable[1:]
def _remove_last(self,variable):
    if not isinstance(variable,list):
        return "ERROR:VARIABLE_NOT_LIST"
    if len(variable)==0:
        return variable
    return variable[:-1]
def is_number(self, value):
if isinstance(value, bool):
return False
return isinstance(value, numbers.Number)
def _add(self,variable,parameter):
if self.is_number(variable) and self.is_number(parameter):
return variable + parameter
else:
return "ERROR:VARIABLE_NOT_NUMBER or PARAMETER_NOT_NUMBER"
def _subtract(self,variable,parameter):
if self.is_number(variable) and self.is_number(parameter):
return variable - parameter
else:
return "ERROR:VARIABLE_NOT_NUMBER or PARAMETER_NOT_NUMBER"
def _multiply(self,variable,parameter):
if self.is_number(variable) and self.is_number(parameter):
return variable * parameter
else:
return "ERROR:VARIABLE_NOT_NUMBER or PARAMETER_NOT_NUMBER"
def _divide(self,variable,parameter):
if self.is_number(variable) and self.is_number(parameter):
if parameter==0:
return "ERROR:DIVIDE_BY_ZERO"
else:
return variable/parameter
else:
return "ERROR:VARIABLE_NOT_NUMBER or PARAMETER_NOT_NUMBER"
def thoughts(self) -> str:
return "Assign variables from canvas."
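A short demonstration of why `is_number` above special-cases `bool`: in Python, `bool` is a subclass of `int`, so `True` would otherwise count as a number and arithmetic operators would silently accept it:

```python
import numbers

def is_number(value):
    if isinstance(value, bool):  # bool passes the Number check, so reject it first
        return False
    return isinstance(value, numbers.Number)

print(isinstance(True, numbers.Number))  # True -- the surprising part
print(is_number(True))                   # False
print(is_number(3.5))                    # True
```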

File diff suppressed because one or more lines are too long

View File

@ -83,10 +83,10 @@
"value": []
}
},
"password": "20010812Yy!",
"password": "",
"port": 3306,
"sql": "{Agent:WickedGoatsDivide@content}",
"username": "13637682833@163.com"
"username": ""
}
},
"upstream": [
@ -527,10 +527,10 @@
"value": []
}
},
"password": "20010812Yy!",
"password": "",
"port": 3306,
"sql": "{Agent:WickedGoatsDivide@content}",
"username": "13637682833@163.com"
"username": ""
},
"label": "ExeSQL",
"name": "ExeSQL"

View File

@ -21,9 +21,8 @@ from functools import partial
from typing import TypedDict, List, Any
from agent.component.base import ComponentParamBase, ComponentBase
from common.misc_utils import hash_str2int
from rag.llm.chat_model import ToolCallSession
from rag.prompts.generator import kb_prompt
from rag.utils.mcp_tool_call_conn import MCPToolCallSession
from common.mcp_tool_call_conn import MCPToolCallSession, ToolCallSession
from timeit import default_timer as timer

View File

@ -18,12 +18,11 @@ import sys
import logging
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from flask import Blueprint, Flask
from quart import Blueprint, Quart, request, g, current_app, session
from werkzeug.wrappers.request import Request
from flask_cors import CORS
from flasgger import Swagger
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from quart_cors import cors
from common.constants import StatusEnum
from api.db.db_models import close_connection
from api.db.services import UserService
@ -31,17 +30,20 @@ from api.utils.json_encode import CustomJSONEncoder
from api.utils import commands
from flask_mail import Mail
from flask_session import Session
from flask_login import LoginManager
from quart_auth import Unauthorized
from common import settings
from api.utils.api_utils import server_error_response
from api.constants import API_VERSION
from common.misc_utils import get_uuid
settings.init_settings()
__all__ = ["app"]
Request.json = property(lambda self: self.get_json(force=True, silent=True))
app = Flask(__name__)
app = Quart(__name__)
app = cors(app, allow_origin="*")
smtp_mail_server = Mail()
# Configure Swagger UI
@ -76,7 +78,6 @@ swagger = Swagger(
},
)
CORS(app, supports_credentials=True, max_age=2592000)
app.url_map.strict_slashes = False
app.json_encoder = CustomJSONEncoder
app.errorhandler(Exception)(server_error_response)
@ -84,24 +85,150 @@ app.errorhandler(Exception)(server_error_response)
## convenience for dev and debug
# app.config["LOGIN_DISABLED"] = True
app.config["SESSION_PERMANENT"] = False
app.config["SESSION_TYPE"] = "filesystem"
app.config["SESSION_TYPE"] = "redis"
app.config["SESSION_REDIS"] = settings.decrypt_database_config(name="redis")
app.config["MAX_CONTENT_LENGTH"] = int(
os.environ.get("MAX_CONTENT_LENGTH", 1024 * 1024 * 1024)
)
Session(app)
login_manager = LoginManager()
login_manager.init_app(app)
app.config['SECRET_KEY'] = settings.SECRET_KEY
app.secret_key = settings.SECRET_KEY
commands.register_commands(app)
from functools import wraps
from typing import ParamSpec, TypeVar
from collections.abc import Awaitable, Callable
from werkzeug.local import LocalProxy
def search_pages_path(pages_dir):
T = TypeVar("T")
P = ParamSpec("P")
def _load_user():
jwt = Serializer(secret_key=settings.SECRET_KEY)
authorization = request.headers.get("Authorization")
g.user = None
if not authorization:
return
try:
access_token = str(jwt.loads(authorization))
if not access_token or not access_token.strip():
logging.warning("Authentication attempt with empty access token")
return None
# Access tokens should be UUIDs (32 hex characters)
if len(access_token.strip()) < 32:
logging.warning(f"Authentication attempt with invalid token format: {len(access_token)} chars")
return None
user = UserService.query(
access_token=access_token, status=StatusEnum.VALID.value
)
if user:
if not user[0].access_token or not user[0].access_token.strip():
logging.warning(f"User {user[0].email} has empty access_token in database")
return None
g.user = user[0]
return user[0]
except Exception as e:
logging.warning(f"load_user got exception {e}")
current_user = LocalProxy(_load_user)
def login_required(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
"""A decorator to restrict route access to authenticated users.
This should be used to wrap a route handler (or view function) to
enforce that only authenticated requests can access it. Note that
it is important that this decorator be wrapped by the route
decorator and not vice versa, as below.
.. code-block:: python
@app.route('/')
@login_required
async def index():
...
If the request is not authenticated, a
`quart_auth.Unauthorized` exception will be raised.
"""
@wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
if not current_user:# or not session.get("_user_id"):
raise Unauthorized()
else:
return await current_app.ensure_async(func)(*args, **kwargs)
return wrapper
def login_user(user, remember=False, duration=None, force=False, fresh=True):
"""
Logs a user in. You should pass the actual user object to this. If the
user's `is_active` property is ``False``, they will not be logged in
unless `force` is ``True``.
This will return ``True`` if the log in attempt succeeds, and ``False`` if
it fails (i.e. because the user is inactive).
:param user: The user object to log in.
:type user: object
:param remember: Whether to remember the user after their session expires.
Defaults to ``False``.
:type remember: bool
:param duration: The amount of time before the remember cookie expires. If
``None`` the value set in the settings is used. Defaults to ``None``.
:type duration: :class:`datetime.timedelta`
:param force: If the user is inactive, setting this to ``True`` will log
them in regardless. Defaults to ``False``.
:type force: bool
:param fresh: setting this to ``False`` will log in the user with a session
marked as not "fresh". Defaults to ``True``.
:type fresh: bool
"""
if not force and not user.is_active:
return False
session["_user_id"] = user.id
session["_fresh"] = fresh
session["_id"] = get_uuid()
return True
def logout_user():
"""
Logs a user out. (You do not need to pass the actual user.) This will
also clean up the remember me cookie if it exists.
"""
if "_user_id" in session:
session.pop("_user_id")
if "_fresh" in session:
session.pop("_fresh")
if "_id" in session:
session.pop("_id")
COOKIE_NAME = "remember_token"
cookie_name = current_app.config.get("REMEMBER_COOKIE_NAME", COOKIE_NAME)
if cookie_name in request.cookies:
session["_remember"] = "clear"
if "_remember_seconds" in session:
session.pop("_remember_seconds")
return True
def search_pages_path(page_path):
app_path_list = [
path for path in pages_dir.glob("*_app.py") if not path.name.startswith(".")
path for path in page_path.glob("*_app.py") if not path.name.startswith(".")
]
api_path_list = [
path for path in pages_dir.glob("*sdk/*.py") if not path.name.startswith(".")
path for path in page_path.glob("*sdk/*.py") if not path.name.startswith(".")
]
app_path_list.extend(api_path_list)
return app_path_list
@ -138,44 +265,12 @@ pages_dir = [
]
client_urls_prefix = [
register_page(path) for dir in pages_dir for path in search_pages_path(dir)
register_page(path) for directory in pages_dir for path in search_pages_path(directory)
]
@login_manager.request_loader
def load_user(web_request):
jwt = Serializer(secret_key=settings.SECRET_KEY)
authorization = web_request.headers.get("Authorization")
if authorization:
try:
access_token = str(jwt.loads(authorization))
if not access_token or not access_token.strip():
logging.warning("Authentication attempt with empty access token")
return None
# Access tokens should be UUIDs (32 hex characters)
if len(access_token.strip()) < 32:
logging.warning(f"Authentication attempt with invalid token format: {len(access_token)} chars")
return None
user = UserService.query(
access_token=access_token, status=StatusEnum.VALID.value
)
if user:
if not user[0].access_token or not user[0].access_token.strip():
logging.warning(f"User {user[0].email} has empty access_token in database")
return None
return user[0]
else:
return None
except Exception as e:
logging.warning(f"load_user got exception {e}")
return None
else:
return None
@app.teardown_request
def _db_close(exc):
def _db_close(exception):
if exception:
logging.exception(f"Request failed: {exception}")
close_connection()
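A minimal sketch, assuming `quart`, `quart_auth`, and `werkzeug`, of the flask-login replacement pattern above: a request-scoped `current_user` built on `LocalProxy`, plus a `login_required` decorator that raises when no user was attached:

```python
from functools import wraps

from quart import Quart, g
from quart_auth import Unauthorized
from werkzeug.local import LocalProxy

app = Quart(__name__)

# Resolved lazily per request; g.user is set by an auth hook like _load_user.
current_user = LocalProxy(lambda: getattr(g, "user", None))

def login_required(func):
    @wraps(func)
    async def wrapper(*args, **kwargs):
        if not current_user:  # proxy is falsy when no user was loaded
            raise Unauthorized()
        return await func(*args, **kwargs)
    return wrapper

@app.route("/me")
@login_required
async def me():
    return {"user": str(current_user)}
```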

View File

@ -13,46 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import os
import re
from datetime import datetime, timedelta
from flask import request, Response
from api.db.services.llm_service import LLMBundle
from flask_login import login_required, current_user
from api.db import VALID_FILE_TYPES, FileType
from api.db.db_models import APIToken, Task, File
from api.db.services import duplicate_name
from quart import request
from api.db.db_models import APIToken
from api.db.services.api_service import APITokenService, API4ConversationService
from api.db.services.dialog_service import DialogService, chat
from api.db.services.document_service import DocumentService, doc_upload_and_parse
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import queue_tasks, TaskService
from api.db.services.user_service import UserTenantService
from common.misc_utils import get_uuid
from common.constants import RetCode, VALID_TASK_STATUS, LLMType, ParserType, FileSource
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
generate_confirmation_token
from api.utils.file_utils import filename_type, thumbnail
from rag.app.tag import label_question
from rag.prompts.generator import keyword_extraction
from common.time_utils import current_timestamp, datetime_format
from api.db.services.canvas_service import UserCanvasService
from agent.canvas import Canvas
from functools import partial
from pathlib import Path
from common import settings
from api.apps import login_required, current_user
@manager.route('/new_token', methods=['POST']) # noqa: F821
@login_required
def new_token():
req = request.json
async def new_token():
req = await request.json
try:
tenants = UserTenantService.query(user_id=current_user.id)
if not tenants:
@ -97,8 +72,8 @@ def token_list():
@manager.route('/rm', methods=['POST']) # noqa: F821
@validate_request("tokens", "tenant_id")
@login_required
def rm():
req = request.json
async def rm():
req = await request.json
try:
for token in req["tokens"]:
APITokenService.filter_delete(
@ -126,770 +101,19 @@ def stats():
"to_date",
datetime.now().strftime("%Y-%m-%d %H:%M:%S")),
"agent" if "canvas_id" in request.args else None)
res = {
"pv": [(o["dt"], o["pv"]) for o in objs],
"uv": [(o["dt"], o["uv"]) for o in objs],
"speed": [(o["dt"], float(o["tokens"]) / (float(o["duration"] + 0.1))) for o in objs],
"tokens": [(o["dt"], float(o["tokens"]) / 1000.) for o in objs],
"round": [(o["dt"], o["round"]) for o in objs],
"thumb_up": [(o["dt"], o["thumb_up"]) for o in objs]
}
res = {"pv": [], "uv": [], "speed": [], "tokens": [], "round": [], "thumb_up": []}
for obj in objs:
dt = obj["dt"]
res["pv"].append((dt, obj["pv"]))
res["uv"].append((dt, obj["uv"]))
res["speed"].append((dt, float(obj["tokens"]) / (float(obj["duration"]) + 0.1))) # +0.1 to avoid division by zero
res["tokens"].append((dt, float(obj["tokens"]) / 1000.0)) # convert to thousands
res["round"].append((dt, obj["round"]))
res["thumb_up"].append((dt, obj["thumb_up"]))
return get_json_result(data=res)
except Exception as e:
return server_error_response(e)
@manager.route('/new_conversation', methods=['GET']) # noqa: F821
def set_conversation():
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!', code=RetCode.AUTHENTICATION_ERROR)
try:
if objs[0].source == "agent":
e, cvs = UserCanvasService.get_by_id(objs[0].dialog_id)
if not e:
return server_error_response("canvas not found.")
if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
canvas = Canvas(cvs.dsl, objs[0].tenant_id)
conv = {
"id": get_uuid(),
"dialog_id": cvs.id,
"user_id": request.args.get("user_id", ""),
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
"source": "agent"
}
API4ConversationService.save(**conv)
return get_json_result(data=conv)
else:
e, dia = DialogService.get_by_id(objs[0].dialog_id)
if not e:
return get_data_error_result(message="Dialog not found")
conv = {
"id": get_uuid(),
"dialog_id": dia.id,
"user_id": request.args.get("user_id", ""),
"message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]
}
API4ConversationService.save(**conv)
return get_json_result(data=conv)
except Exception as e:
return server_error_response(e)
@manager.route('/completion', methods=['POST']) # noqa: F821
@validate_request("conversation_id", "messages")
def completion():
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!', code=RetCode.AUTHENTICATION_ERROR)
req = request.json
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
if not e:
return get_data_error_result(message="Conversation not found!")
if "quote" not in req:
req["quote"] = False
msg = []
for m in req["messages"]:
if m["role"] == "system":
continue
if m["role"] == "assistant" and not msg:
continue
msg.append(m)
if not msg[-1].get("id"):
msg[-1]["id"] = get_uuid()
message_id = msg[-1]["id"]
def fillin_conv(ans):
nonlocal conv, message_id
if not conv.reference:
conv.reference.append(ans["reference"])
else:
conv.reference[-1] = ans["reference"]
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
ans["id"] = message_id
def rename_field(ans):
reference = ans['reference']
if not isinstance(reference, dict):
return
for chunk_i in reference.get('chunks', []):
if 'docnm_kwd' in chunk_i:
chunk_i['doc_name'] = chunk_i['docnm_kwd']
chunk_i.pop('docnm_kwd')
try:
if conv.source == "agent":
stream = req.get("stream", True)
conv.message.append(msg[-1])
e, cvs = UserCanvasService.get_by_id(conv.dialog_id)
if not e:
return server_error_response("canvas not found.")
del req["conversation_id"]
del req["messages"]
if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
if not conv.reference:
conv.reference = []
conv.message.append({"role": "assistant", "content": "", "id": message_id})
conv.reference.append({"chunks": [], "doc_aggs": []})
final_ans = {"reference": [], "content": ""}
canvas = Canvas(cvs.dsl, objs[0].tenant_id)
canvas.messages.append(msg[-1])
canvas.add_user_input(msg[-1]["content"])
answer = canvas.run(stream=stream)
assert answer is not None, "Nothing. Is it over?"
if stream:
assert isinstance(answer, partial), "Nothing. Is it over?"
def sse():
nonlocal answer, cvs, conv
try:
for ans in answer():
for k in ans.keys():
final_ans[k] = ans[k]
ans = {"answer": ans["content"], "reference": ans.get("reference", [])}
fillin_conv(ans)
rename_field(ans)
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
ensure_ascii=False) + "\n\n"
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
canvas.history.append(("assistant", final_ans["content"]))
if final_ans.get("reference"):
canvas.reference.append(final_ans["reference"])
cvs.dsl = json.loads(str(canvas))
API4ConversationService.append_message(conv.id, conv.to_dict())
except Exception as e:
yield "data:" + json.dumps({"code": 500, "message": str(e),
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
resp = Response(sse(), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp
final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
if final_ans.get("reference"):
canvas.reference.append(final_ans["reference"])
cvs.dsl = json.loads(str(canvas))
result = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])}
fillin_conv(result)
API4ConversationService.append_message(conv.id, conv.to_dict())
rename_field(result)
return get_json_result(data=result)
# ******************For dialog******************
conv.message.append(msg[-1])
e, dia = DialogService.get_by_id(conv.dialog_id)
if not e:
return get_data_error_result(message="Dialog not found!")
del req["conversation_id"]
del req["messages"]
if not conv.reference:
conv.reference = []
conv.message.append({"role": "assistant", "content": "", "id": message_id})
conv.reference.append({"chunks": [], "doc_aggs": []})
def stream():
nonlocal dia, msg, req, conv
try:
for ans in chat(dia, msg, True, **req):
fillin_conv(ans)
rename_field(ans)
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
ensure_ascii=False) + "\n\n"
API4ConversationService.append_message(conv.id, conv.to_dict())
except Exception as e:
yield "data:" + json.dumps({"code": 500, "message": str(e),
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
if req.get("stream", True):
resp = Response(stream(), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp
answer = None
for ans in chat(dia, msg, **req):
answer = ans
fillin_conv(ans)
API4ConversationService.append_message(conv.id, conv.to_dict())
break
rename_field(answer)
return get_json_result(data=answer)
except Exception as e:
return server_error_response(e)
@manager.route('/conversation/<conversation_id>', methods=['GET']) # noqa: F821
# @login_required
def get_conversation(conversation_id):
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!', code=RetCode.AUTHENTICATION_ERROR)
try:
e, conv = API4ConversationService.get_by_id(conversation_id)
if not e:
return get_data_error_result(message="Conversation not found!")
conv = conv.to_dict()
if token != APIToken.query(dialog_id=conv['dialog_id'])[0].token:
return get_json_result(data=False, message='Authentication error: API key is invalid for this conversation_id!',
code=RetCode.AUTHENTICATION_ERROR)
for reference_i in conv['reference']:
    if reference_i is None or len(reference_i) == 0:
        continue
    for chunk_i in reference_i['chunks']:
if 'docnm_kwd' in chunk_i.keys():
chunk_i['doc_name'] = chunk_i['docnm_kwd']
chunk_i.pop('docnm_kwd')
return get_json_result(data=conv)
except Exception as e:
return server_error_response(e)
@manager.route('/document/upload', methods=['POST']) # noqa: F821
@validate_request("kb_name")
def upload():
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!', code=RetCode.AUTHENTICATION_ERROR)
kb_name = request.form.get("kb_name").strip()
tenant_id = objs[0].tenant_id
try:
e, kb = KnowledgebaseService.get_by_name(kb_name, tenant_id)
if not e:
return get_data_error_result(
message="Can't find this knowledgebase!")
kb_id = kb.id
except Exception as e:
return server_error_response(e)
if 'file' not in request.files:
return get_json_result(
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
file = request.files['file']
if file.filename == '':
return get_json_result(
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
root_folder = FileService.get_root_folder(tenant_id)
pf_id = root_folder["id"]
FileService.init_knowledgebase_docs(pf_id, tenant_id)
kb_root_folder = FileService.get_kb_folder(tenant_id)
kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])
try:
if DocumentService.get_doc_count(kb.tenant_id) >= int(os.environ.get('MAX_FILE_NUM_PER_USER', 8192)):
return get_data_error_result(
message="Exceed the maximum file number of a free user!")
filename = duplicate_name(
DocumentService.query,
name=file.filename,
kb_id=kb_id)
filetype = filename_type(filename)
if not filetype:
return get_data_error_result(
message="This type of file has not been supported yet!")
location = filename
while settings.STORAGE_IMPL.obj_exist(kb_id, location):
location += "_"
blob = request.files['file'].read()
settings.STORAGE_IMPL.put(kb_id, location, blob)
doc = {
"id": get_uuid(),
"kb_id": kb.id,
"parser_id": kb.parser_id,
"parser_config": kb.parser_config,
"created_by": kb.tenant_id,
"type": filetype,
"name": filename,
"location": location,
"size": len(blob),
"thumbnail": thumbnail(filename, blob),
"suffix": Path(filename).suffix.lstrip("."),
}
form_data = request.form
if "parser_id" in form_data.keys():
if request.form.get("parser_id").strip() in list(vars(ParserType).values())[1:-3]:
doc["parser_id"] = request.form.get("parser_id").strip()
if doc["type"] == FileType.VISUAL:
doc["parser_id"] = ParserType.PICTURE.value
if doc["type"] == FileType.AURAL:
doc["parser_id"] = ParserType.AUDIO.value
if re.search(r"\.(ppt|pptx|pages)$", filename):
doc["parser_id"] = ParserType.PRESENTATION.value
if re.search(r"\.(eml)$", filename):
doc["parser_id"] = ParserType.EMAIL.value
doc_result = DocumentService.insert(doc)
FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id)
except Exception as e:
return server_error_response(e)
if "run" in form_data.keys():
if request.form.get("run").strip() == "1":
try:
info = {"run": 1, "progress": 0, "progress_msg": "", "chunk_num": 0, "token_num": 0}
DocumentService.update_by_id(doc["id"], info)
# if str(req["run"]) == TaskStatus.CANCEL.value:
tenant_id = DocumentService.get_tenant_id(doc["id"])
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
# e, doc = DocumentService.get_by_id(doc["id"])
TaskService.filter_delete([Task.doc_id == doc["id"]])
e, doc = DocumentService.get_by_id(doc["id"])
doc = doc.to_dict()
doc["tenant_id"] = tenant_id
bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
queue_tasks(doc, bucket, name, 0)
except Exception as e:
return server_error_response(e)
return get_json_result(data=doc_result.to_json())
@manager.route('/document/upload_and_parse', methods=['POST']) # noqa: F821
@validate_request("conversation_id")
def upload_parse():
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!', code=RetCode.AUTHENTICATION_ERROR)
if 'file' not in request.files:
return get_json_result(
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
file_objs = request.files.getlist('file')
for file_obj in file_objs:
if file_obj.filename == '':
return get_json_result(
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id)
return get_json_result(data=doc_ids)
@manager.route('/list_chunks', methods=['POST']) # noqa: F821
# @login_required
def list_chunks():
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!', code=RetCode.AUTHENTICATION_ERROR)
req = request.json
try:
if "doc_name" in req.keys():
tenant_id = DocumentService.get_tenant_id_by_name(req['doc_name'])
doc_id = DocumentService.get_doc_id_by_doc_name(req['doc_name'])
elif "doc_id" in req.keys():
tenant_id = DocumentService.get_tenant_id(req['doc_id'])
doc_id = req['doc_id']
else:
return get_json_result(
data=False, message="Can't find doc_name or doc_id"
)
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
res = settings.retriever.chunk_list(doc_id, tenant_id, kb_ids)
res = [
{
"content": res_item["content_with_weight"],
"doc_name": res_item["docnm_kwd"],
"image_id": res_item["img_id"]
} for res_item in res
]
except Exception as e:
return server_error_response(e)
return get_json_result(data=res)
@manager.route('/get_chunk/<chunk_id>', methods=['GET']) # noqa: F821
# @login_required
def get_chunk(chunk_id):
from rag.nlp import search
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!', code=RetCode.AUTHENTICATION_ERROR)
try:
tenant_id = objs[0].tenant_id
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
if chunk is None:
return server_error_response(Exception("Chunk not found"))
k = []
for n in chunk.keys():
if re.search(r"(_vec$|_sm_|_tks|_ltks)", n):
k.append(n)
for n in k:
del chunk[n]
return get_json_result(data=chunk)
except Exception as e:
return server_error_response(e)
@manager.route('/list_kb_docs', methods=['POST']) # noqa: F821
# @login_required
def list_kb_docs():
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!', code=RetCode.AUTHENTICATION_ERROR)
req = request.json
tenant_id = objs[0].tenant_id
kb_name = req.get("kb_name", "").strip()
try:
e, kb = KnowledgebaseService.get_by_name(kb_name, tenant_id)
if not e:
return get_data_error_result(
message="Can't find this knowledgebase!")
kb_id = kb.id
except Exception as e:
return server_error_response(e)
page_number = int(req.get("page", 1))
items_per_page = int(req.get("page_size", 15))
orderby = req.get("orderby", "create_time")
desc = req.get("desc", True)
keywords = req.get("keywords", "")
status = req.get("status", [])
if status:
invalid_status = {s for s in status if s not in VALID_TASK_STATUS}
if invalid_status:
return get_data_error_result(
message=f"Invalid filter status conditions: {', '.join(invalid_status)}"
)
types = req.get("types", [])
if types:
invalid_types = {t for t in types if t not in VALID_FILE_TYPES}
if invalid_types:
return get_data_error_result(
message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}"
)
try:
docs, tol = DocumentService.get_by_kb_id(
kb_id, page_number, items_per_page, orderby, desc, keywords, status, types)
docs = [{"doc_id": doc['id'], "doc_name": doc['name']} for doc in docs]
return get_json_result(data={"total": tol, "docs": docs})
except Exception as e:
return server_error_response(e)
@manager.route('/document/infos', methods=['POST']) # noqa: F821
@validate_request("doc_ids")
def docinfos():
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!', code=RetCode.AUTHENTICATION_ERROR)
req = request.json
doc_ids = req["doc_ids"]
docs = DocumentService.get_by_ids(doc_ids)
return get_json_result(data=list(docs.dicts()))
@manager.route('/document', methods=['DELETE']) # noqa: F821
# @login_required
def document_rm():
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!', code=RetCode.AUTHENTICATION_ERROR)
tenant_id = objs[0].tenant_id
req = request.json
try:
doc_ids = DocumentService.get_doc_ids_by_doc_names(req.get("doc_names", []))
for doc_id in req.get("doc_ids", []):
if doc_id not in doc_ids:
doc_ids.append(doc_id)
if not doc_ids:
return get_json_result(
data=False, message="Can't find doc_names or doc_ids"
)
except Exception as e:
return server_error_response(e)
root_folder = FileService.get_root_folder(tenant_id)
pf_id = root_folder["id"]
FileService.init_knowledgebase_docs(pf_id, tenant_id)
errors = ""
docs = DocumentService.get_by_ids(doc_ids)
doc_dic = {}
for doc in docs:
doc_dic[doc.id] = doc
for doc_id in doc_ids:
try:
if doc_id not in doc_dic:
return get_data_error_result(message="Document not found!")
doc = doc_dic[doc_id]
tenant_id = DocumentService.get_tenant_id(doc_id)
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
if not DocumentService.remove_document(doc, tenant_id):
return get_data_error_result(
message="Database error (Document removal)!")
f2d = File2DocumentService.get_by_document_id(doc_id)
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
File2DocumentService.delete_by_document_id(doc_id)
settings.STORAGE_IMPL.rm(b, n)
except Exception as e:
errors += str(e)
if errors:
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
return get_json_result(data=True)
@manager.route('/completion_aibotk', methods=['POST']) # noqa: F821
@validate_request("Authorization", "conversation_id", "word")
def completion_faq():
import base64
req = request.json
token = req["Authorization"]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!', code=RetCode.AUTHENTICATION_ERROR)
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
if not e:
return get_data_error_result(message="Conversation not found!")
if "quote" not in req:
req["quote"] = True
msg = [{"role": "user", "content": req["word"]}]
if not msg[-1].get("id"):
msg[-1]["id"] = get_uuid()
message_id = msg[-1]["id"]
def fillin_conv(ans):
nonlocal conv, message_id
if not conv.reference:
conv.reference.append(ans["reference"])
else:
conv.reference[-1] = ans["reference"]
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
ans["id"] = message_id
try:
if conv.source == "agent":
conv.message.append(msg[-1])
e, cvs = UserCanvasService.get_by_id(conv.dialog_id)
if not e:
return server_error_response("canvas not found.")
if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
if not conv.reference:
conv.reference = []
conv.message.append({"role": "assistant", "content": "", "id": message_id})
conv.reference.append({"chunks": [], "doc_aggs": []})
final_ans = {"reference": [], "doc_aggs": []}
canvas = Canvas(cvs.dsl, objs[0].tenant_id)
canvas.messages.append(msg[-1])
canvas.add_user_input(msg[-1]["content"])
answer = canvas.run(stream=False)
assert answer is not None, "Nothing. Is it over?"
data_type_picture = {
"type": 3,
"url": "base64 content"
}
data = [
{
"type": 1,
"content": ""
}
]
final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
if final_ans.get("reference"):
canvas.reference.append(final_ans["reference"])
cvs.dsl = json.loads(str(canvas))
ans = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])}
data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"])
fillin_conv(ans)
API4ConversationService.append_message(conv.id, conv.to_dict())
chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])]
for chunk_idx in chunk_idxs[:1]:
if ans["reference"]["chunks"][chunk_idx]["img_id"]:
try:
bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
response = settings.STORAGE_IMPL.get(bkt, nm)
data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
data.append(data_type_picture)
break
except Exception as e:
return server_error_response(e)
response = {"code": 200, "msg": "success", "data": data}
return response
# ******************For dialog******************
conv.message.append(msg[-1])
e, dia = DialogService.get_by_id(conv.dialog_id)
if not e:
return get_data_error_result(message="Dialog not found!")
del req["conversation_id"]
if not conv.reference:
conv.reference = []
conv.message.append({"role": "assistant", "content": "", "id": message_id})
conv.reference.append({"chunks": [], "doc_aggs": []})
data_type_picture = {
"type": 3,
"url": "base64 content"
}
data = [
{
"type": 1,
"content": ""
}
]
ans = ""
for a in chat(dia, msg, stream=False, **req):
ans = a
break
data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"])
fillin_conv(ans)
API4ConversationService.append_message(conv.id, conv.to_dict())
chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])]
for chunk_idx in chunk_idxs[:1]:
if ans["reference"]["chunks"][chunk_idx]["img_id"]:
try:
bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
response = settings.STORAGE_IMPL.get(bkt, nm)
data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
data.append(data_type_picture)
break
except Exception as e:
return server_error_response(e)
response = {"code": 200, "msg": "success", "data": data}
return response
except Exception as e:
return server_error_response(e)
@manager.route('/retrieval', methods=['POST']) # noqa: F821
@validate_request("kb_id", "question")
def retrieval():
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!', code=RetCode.AUTHENTICATION_ERROR)
req = request.json
kb_ids = req.get("kb_id", [])
doc_ids = req.get("doc_ids", [])
question = req.get("question")
page = int(req.get("page", 1))
size = int(req.get("page_size", 30))
similarity_threshold = float(req.get("similarity_threshold", 0.2))
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
top = int(req.get("top_k", 1024))
highlight = bool(req.get("highlight", False))
try:
kbs = KnowledgebaseService.get_by_ids(kb_ids)
embd_nms = list(set([kb.embd_id for kb in kbs]))
if len(embd_nms) != 1:
return get_json_result(
data=False, message='Knowledge bases use different embedding models or do not exist.',
code=RetCode.AUTHENTICATION_ERROR)
embd_mdl = LLMBundle(kbs[0].tenant_id, LLMType.EMBEDDING, llm_name=kbs[0].embd_id)
rerank_mdl = None
if req.get("rerank_id"):
rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, llm_name=req["rerank_id"])
if req.get("keyword", False):
chat_mdl = LLMBundle(kbs[0].tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question)
ranks = settings.retriever.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl, highlight= highlight,
rank_feature=label_question(question, kbs))
for c in ranks["chunks"]:
c.pop("vector", None)
return get_json_result(data=ranks)
except Exception as e:
if str(e).find("not_found") > 0:
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
code=RetCode.DATA_ERROR)
return server_error_response(e)
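A framework-free sketch of the SSE framing the `sse()`/`stream()` generators above emit: one JSON payload per `data:` event, then a final `True` sentinel to mark the end of the stream:

```python
import json

def sse_frames(answers):
    # Each answer dict becomes one server-sent event.
    for ans in answers:
        yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
                                   ensure_ascii=False) + "\n\n"
    # Trailing sentinel tells the client the stream is complete.
    yield "data:" + json.dumps({"code": 0, "message": "", "data": True},
                               ensure_ascii=False) + "\n\n"

for frame in sse_frames([{"answer": "Hello", "reference": []}]):
    print(frame, end="")
```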

View File

@ -18,12 +18,8 @@ import logging
import re
import sys
from functools import partial
import flask
import trio
from flask import request, Response
from flask_login import login_required, current_user
from quart import request, Response, make_response
from agent.component import LLM
from api.db import CanvasCategory, FileType
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
@ -35,7 +31,8 @@ from api.db.services.user_service import TenantService
from api.db.services.user_canvas_version import UserCanvasVersionService
from common.constants import RetCode
from common.misc_utils import get_uuid
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result, \
request_json
from agent.canvas import Canvas
from peewee import MySQLDatabase, PostgresqlDatabase
from api.db.db_models import APIToken, Task
@ -46,6 +43,7 @@ from rag.flow.pipeline import Pipeline
from rag.nlp import search
from rag.utils.redis_conn import REDIS_CONN
from common import settings
from api.apps import login_required, current_user
@manager.route('/templates', methods=['GET']) # noqa: F821
@ -57,8 +55,9 @@ def templates():
@manager.route('/rm', methods=['POST']) # noqa: F821
@validate_request("canvas_ids")
@login_required
def rm():
for i in request.json["canvas_ids"]:
async def rm():
req = await request_json()
for i in req["canvas_ids"]:
if not UserCanvasService.accessible(i, current_user.id):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
@ -70,8 +69,8 @@ def rm():
@manager.route('/set', methods=['POST']) # noqa: F821
@validate_request("dsl", "title")
@login_required
def save():
req = request.json
async def save():
req = await request_json()
if not isinstance(req["dsl"], str):
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
req["dsl"] = json.loads(req["dsl"])
@ -129,8 +128,8 @@ def getsse(canvas_id):
@manager.route('/completion', methods=['POST']) # noqa: F821
@validate_request("id")
@login_required
def run():
req = request.json
async def run():
req = await request_json()
query = req.get("query", "")
files = req.get("files", [])
inputs = req.get("inputs", {})
@ -160,10 +159,10 @@ def run():
except Exception as e:
return server_error_response(e)
def sse():
async def sse():
nonlocal canvas, user_id
try:
for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
cvs.dsl = json.loads(str(canvas))
@ -179,15 +178,15 @@ def run():
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
resp.call_on_close(lambda: canvas.cancel_task())
#resp.call_on_close(lambda: canvas.cancel_task())
return resp
@manager.route('/rerun', methods=['POST']) # noqa: F821
@validate_request("id", "dsl", "component_id")
@login_required
def rerun():
req = request.json
async def rerun():
req = await request_json()
doc = PipelineOperationLogService.get_documents_info(req["id"])
if not doc:
return get_data_error_result(message="Document not found.")
@ -224,8 +223,8 @@ def cancel(task_id):
@manager.route('/reset', methods=['POST']) # noqa: F821
@validate_request("id")
@login_required
def reset():
req = request.json
async def reset():
req = await request_json()
if not UserCanvasService.accessible(req["id"], current_user.id):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
@ -245,7 +244,7 @@ def reset():
@manager.route("/upload/<canvas_id>", methods=["POST"]) # noqa: F821
def upload(canvas_id):
async def upload(canvas_id):
e, cvs = UserCanvasService.get_by_canvas_id(canvas_id)
if not e:
return get_data_error_result(message="canvas not found.")
@ -311,7 +310,8 @@ def upload(canvas_id):
except Exception as e:
return server_error_response(e)
file = request.files['file']
files = await request.files
file = files['file']
try:
DocumentService.check_doc_health(user_id, file.filename)
return get_json_result(data=structured(file.filename, filename_type(file.filename), file.read(), file.content_type))
@ -342,8 +342,8 @@ def input_form():
@manager.route('/debug', methods=['POST']) # noqa: F821
@validate_request("id", "component_id", "params")
@login_required
def debug():
req = request.json
async def debug():
req = await request_json()
if not UserCanvasService.accessible(req["id"], current_user.id):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
@ -374,8 +374,8 @@ def debug():
@manager.route('/test_db_connect', methods=['POST']) # noqa: F821
@validate_request("db_type", "database", "username", "host", "port", "password")
@login_required
def test_db_connect():
req = request.json
async def test_db_connect():
req = await request_json()
try:
if req["db_type"] in ["mysql", "mariadb"]:
db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
@ -426,7 +426,6 @@ def test_db_connect():
try:
import trino
import os
from trino.auth import BasicAuthentication
except Exception as e:
return server_error_response(f"Missing dependency 'trino'. Please install: pip install trino, detail: {e}")
@ -438,7 +437,7 @@ def test_db_connect():
auth = None
if http_scheme == "https" and req.get("password"):
auth = BasicAuthentication(req.get("username") or "ragflow", req["password"])
auth = trino.BasicAuthentication(req.get("username") or "ragflow", req["password"])
conn = trino.dbapi.connect(
host=req["host"],
@ -471,8 +470,8 @@ def test_db_connect():
@login_required
def getlistversion(canvas_id):
try:
list =sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"]*-1)
return get_json_result(data=list)
versions =sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"]*-1)
return get_json_result(data=versions)
except Exception as e:
return get_data_error_result(message=f"Error getting history files: {e}")
@ -520,8 +519,8 @@ def list_canvas():
@manager.route('/setting', methods=['POST']) # noqa: F821
@validate_request("id", "title", "permission")
@login_required
def setting():
req = request.json
async def setting():
req = await request_json()
req["user_id"] = current_user.id
if not UserCanvasService.accessible(req["id"], current_user.id):
@ -602,8 +601,8 @@ def prompts():
@manager.route('/download', methods=['GET']) # noqa: F821
def download():
async def download():
id = request.args.get("id")
created_by = request.args.get("created_by")
blob = FileService.get_blob(created_by, id)
return flask.make_response(blob)
return await make_response(blob)
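A minimal sketch, assuming `quart`, of the async streaming pattern the rewritten `run()` endpoint uses: an async generator consumed with `async for` and wrapped in an event-stream response (`fake_canvas_run` is a hypothetical stand-in for `canvas.run()`):

```python
import json

from quart import Quart, Response

app = Quart(__name__)

async def fake_canvas_run():
    for chunk in ("thinking", "answering"):
        yield {"content": chunk}

@app.route("/stream")
async def stream():
    async def sse():
        async for ans in fake_canvas_run():
            yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
    # Quart accepts an async generator as a streaming response body.
    return Response(sse(), mimetype="text/event-stream")
```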

View File

@ -18,8 +18,7 @@ import json
import re
import xxhash
from flask import request
from flask_login import current_user, login_required
from quart import request
from api.db.services.dialog_service import meta_filter
from api.db.services.document_service import DocumentService
@ -27,7 +26,8 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.search_service import SearchService
from api.db.services.user_service import UserTenantService
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
request_json
from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search
@ -35,13 +35,14 @@ from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extr
from common.string_utils import remove_redundant_spaces
from common.constants import RetCode, LLMType, ParserType, PAGERANK_FLD
from common import settings
from api.apps import login_required, current_user
@manager.route('/list', methods=['POST']) # noqa: F821
@login_required
@validate_request("doc_id")
def list_chunk():
req = request.json
async def list_chunk():
req = await request_json()
doc_id = req["doc_id"]
page = int(req.get("page", 1))
size = int(req.get("size", 30))
@ -121,8 +122,8 @@ def get():
@manager.route('/set', methods=['POST']) # noqa: F821
@login_required
@validate_request("doc_id", "chunk_id", "content_with_weight")
def set():
req = request.json
async def set():
req = await request_json()
d = {
"id": req["chunk_id"],
"content_with_weight": req["content_with_weight"]}
@ -178,8 +179,8 @@ def set():
@manager.route('/switch', methods=['POST']) # noqa: F821
@login_required
@validate_request("chunk_ids", "available_int", "doc_id")
def switch():
req = request.json
async def switch():
req = await request_json()
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
@ -198,8 +199,8 @@ def switch():
@manager.route('/rm', methods=['POST']) # noqa: F821
@login_required
@validate_request("chunk_ids", "doc_id")
def rm():
req = request.json
async def rm():
req = await request_json()
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
@ -222,8 +223,8 @@ def rm():
@manager.route('/create', methods=['POST']) # noqa: F821
@login_required
@validate_request("doc_id", "content_with_weight")
def create():
req = request.json
async def create():
req = await request_json()
chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest()
d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
"content_with_weight": req["content_with_weight"]}
@ -280,8 +281,8 @@ def create():
@manager.route('/retrieval_test', methods=['POST']) # noqa: F821
@login_required
@validate_request("kb_id", "question")
def retrieval_test():
req = request.json
async def retrieval_test():
req = await request_json()
page = int(req.get("page", 1))
size = int(req.get("size", 30))
question = req["question"]

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import json
import logging
import time
@ -20,8 +21,7 @@ import uuid
from html import escape
from typing import Any
from flask import make_response, request
from flask_login import current_user, login_required
from quart import request, make_response
from google_auth_oauthlib.flow import Flow
from api.db import InputType
@ -32,12 +32,13 @@ from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, Docum
from common.data_source.google_util.constant import GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES
from common.misc_utils import get_uuid
from rag.utils.redis_conn import REDIS_CONN
from api.apps import login_required, current_user
@manager.route("/set", methods=["POST"]) # noqa: F821
@login_required
def set_connector():
req = request.json
async def set_connector():
req = await request.json
if req.get("id"):
conn = {fld: req[fld] for fld in ["prune_freq", "refresh_freq", "config", "timeout_secs"] if fld in req}
ConnectorService.update_by_id(req["id"], conn)
@ -55,10 +56,9 @@ def set_connector():
"timeout_secs": int(req.get("timeout_secs", 60 * 29)),
"status": TaskStatus.SCHEDULE,
}
conn["status"] = TaskStatus.SCHEDULE
ConnectorService.save(**conn)
time.sleep(1)
await asyncio.sleep(1)
e, conn = ConnectorService.get_by_id(req["id"])
return get_json_result(data=conn.to_dict())
@ -89,8 +89,8 @@ def list_logs(connector_id):
@manager.route("/<connector_id>/resume", methods=["PUT"]) # noqa: F821
@login_required
def resume(connector_id):
req = request.json
async def resume(connector_id):
req = await request.json
if req.get("resume"):
ConnectorService.resume(connector_id, TaskStatus.SCHEDULE)
else:
@ -101,8 +101,8 @@ def resume(connector_id):
@manager.route("/<connector_id>/rebuild", methods=["PUT"]) # noqa: F821
@login_required
@validate_request("kb_id")
def rebuild(connector_id):
req = request.json
async def rebuild(connector_id):
req = await request.json
err = ConnectorService.rebuild(req["kb_id"], connector_id, current_user.id)
if err:
return get_json_result(data=False, message=err, code=RetCode.SERVER_ERROR)
@ -146,7 +146,7 @@ def _get_web_client_config(credentials: dict[str, Any]) -> dict[str, Any]:
return {"web": web_section}
def _render_web_oauth_popup(flow_id: str, success: bool, message: str):
async def _render_web_oauth_popup(flow_id: str, success: bool, message: str):
status = "success" if success else "error"
auto_close = "window.close();" if success else ""
escaped_message = escape(message)
@ -164,7 +164,7 @@ def _render_web_oauth_popup(flow_id: str, success: bool, message: str):
payload_json=payload_json,
auto_close=auto_close,
)
response = make_response(html, 200)
response = await make_response(html, 200)
response.headers["Content-Type"] = "text/html; charset=utf-8"
return response
@ -172,14 +172,14 @@ def _render_web_oauth_popup(flow_id: str, success: bool, message: str):
@manager.route("/google-drive/oauth/web/start", methods=["POST"]) # noqa: F821
@login_required
@validate_request("credentials")
def start_google_drive_web_oauth():
async def start_google_drive_web_oauth():
if not GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI:
return get_json_result(
code=RetCode.SERVER_ERROR,
message="Google Drive OAuth redirect URI is not configured on the server.",
)
req = request.json or {}
req = await request.json or {}
raw_credentials = req.get("credentials", "")
try:
credentials = _load_credentials(raw_credentials)
@ -231,31 +231,31 @@ def start_google_drive_web_oauth():
@manager.route("/google-drive/oauth/web/callback", methods=["GET"]) # noqa: F821
def google_drive_web_oauth_callback():
async def google_drive_web_oauth_callback():
state_id = request.args.get("state")
error = request.args.get("error")
error_description = request.args.get("error_description") or error
if not state_id:
return _render_web_oauth_popup("", False, "Missing OAuth state parameter.")
return await _render_web_oauth_popup("", False, "Missing OAuth state parameter.")
state_cache = REDIS_CONN.get(_web_state_cache_key(state_id))
if not state_cache:
return _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.")
return await _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.")
state_obj = json.loads(state_cache)
client_config = state_obj.get("client_config")
if not client_config:
REDIS_CONN.delete(_web_state_cache_key(state_id))
return _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.")
return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.")
if error:
REDIS_CONN.delete(_web_state_cache_key(state_id))
return _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.")
return await _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.")
code = request.args.get("code")
if not code:
return _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.")
return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.")
try:
flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE])
@ -264,7 +264,7 @@ def google_drive_web_oauth_callback():
except Exception as exc: # pragma: no cover - defensive
logging.exception("Failed to exchange Google OAuth code: %s", exc)
REDIS_CONN.delete(_web_state_cache_key(state_id))
return _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.")
return await _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.")
creds_json = flow.credentials.to_json()
result_payload = {
@ -274,14 +274,14 @@ def google_drive_web_oauth_callback():
REDIS_CONN.set_obj(_web_result_cache_key(state_id), result_payload, WEB_FLOW_TTL_SECS)
REDIS_CONN.delete(_web_state_cache_key(state_id))
return _render_web_oauth_popup(state_id, True, "Authorization completed successfully.")
return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.")
@manager.route("/google-drive/oauth/web/result", methods=["POST"]) # noqa: F821
@login_required
@validate_request("flow_id")
def poll_google_drive_web_result():
req = request.json or {}
async def poll_google_drive_web_result():
req = await request.json or {}
flow_id = req.get("flow_id")
cache_raw = REDIS_CONN.get(_web_result_cache_key(flow_id))
if not cache_raw:

View File
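Two Quart-specific details appear in the connector hunks above: `time.sleep(1)` becomes `await asyncio.sleep(1)`, because a blocking sleep would stall the whole event loop, and `make_response` must now be awaited, since it is a coroutine in Quart. A small sketch, assuming a bare Quart app:

```python
# Sketch of the two event-loop-safe changes made in the file above.
import asyncio
from quart import Quart, make_response

app = Quart(__name__)

@app.route("/demo", methods=["POST"])
async def demo():
    await asyncio.sleep(1)  # yields to the event loop; time.sleep(1) would not
    response = await make_response("<html>ok</html>", 200)  # coroutine in Quart
    response.headers["Content-Type"] = "text/html; charset=utf-8"
    return response
```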

@ -17,8 +17,8 @@ import json
import re
import logging
from copy import deepcopy
from flask import Response, request
from flask_login import current_user, login_required
from quart import Response, request
from api.apps import current_user, login_required
from api.db.db_models import APIToken
from api.db.services.conversation_service import ConversationService, structure_answer
from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap
@ -34,8 +34,8 @@ from common.constants import RetCode, LLMType
@manager.route("/set", methods=["POST"]) # noqa: F821
@login_required
def set_conversation():
req = request.json
async def set_conversation():
req = await request.json
conv_id = req.get("conversation_id")
is_new = req.get("is_new")
name = req.get("name", "New conversation")
@ -85,7 +85,6 @@ def get():
if not e:
return get_data_error_result(message="Conversation not found!")
tenants = UserTenantService.query(user_id=current_user.id)
avatar = None
for tenant in tenants:
dialog = DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id)
if dialog and len(dialog) > 0:
@ -129,8 +128,9 @@ def getsse(dialog_id):
@manager.route("/rm", methods=["POST"]) # noqa: F821
@login_required
def rm():
conv_ids = request.json["conversation_ids"]
async def rm():
req = await request.json
conv_ids = req["conversation_ids"]
try:
for cid in conv_ids:
exist, conv = ConversationService.get_by_id(cid)
@ -166,8 +166,8 @@ def list_conversation():
@manager.route("/completion", methods=["POST"]) # noqa: F821
@login_required
@validate_request("conversation_id", "messages")
def completion():
req = request.json
async def completion():
req = await request.json
msg = []
for m in req["messages"]:
if m["role"] == "system":
@ -251,8 +251,8 @@ def completion():
@manager.route("/tts", methods=["POST"]) # noqa: F821
@login_required
def tts():
req = request.json
async def tts():
req = await request.json
text = req["text"]
tenants = TenantService.get_info_by(current_user.id)
@ -284,8 +284,8 @@ def tts():
@manager.route("/delete_msg", methods=["POST"]) # noqa: F821
@login_required
@validate_request("conversation_id", "message_id")
def delete_msg():
req = request.json
async def delete_msg():
req = await request.json
e, conv = ConversationService.get_by_id(req["conversation_id"])
if not e:
return get_data_error_result(message="Conversation not found!")
@ -307,8 +307,8 @@ def delete_msg():
@manager.route("/thumbup", methods=["POST"]) # noqa: F821
@login_required
@validate_request("conversation_id", "message_id")
def thumbup():
req = request.json
async def thumbup():
req = await request.json
e, conv = ConversationService.get_by_id(req["conversation_id"])
if not e:
return get_data_error_result(message="Conversation not found!")
@ -334,8 +334,8 @@ def thumbup():
@manager.route("/ask", methods=["POST"]) # noqa: F821
@login_required
@validate_request("question", "kb_ids")
def ask_about():
req = request.json
async def ask_about():
req = await request.json
uid = current_user.id
search_id = req.get("search_id", "")
@ -366,8 +366,8 @@ def ask_about():
@manager.route("/mindmap", methods=["POST"]) # noqa: F821
@login_required
@validate_request("question", "kb_ids")
def mindmap():
req = request.json
async def mindmap():
req = await request.json
search_id = req.get("search_id", "")
search_app = SearchService.get_detail(search_id) if search_id else {}
search_config = search_app.get("search_config", {}) if search_app else {}
@ -384,8 +384,8 @@ def mindmap():
@manager.route("/related_questions", methods=["POST"]) # noqa: F821
@login_required
@validate_request("question")
def related_questions():
req = request.json
async def related_questions():
req = await request.json
search_id = req.get("search_id", "")
search_config = {}

View File

@ -14,8 +14,7 @@
# limitations under the License.
#
from flask import request
from flask_login import login_required, current_user
from quart import request
from api.db.services import duplicate_name
from api.db.services.dialog_service import DialogService
from common.constants import StatusEnum
@ -26,13 +25,14 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
from common.misc_utils import get_uuid
from common.constants import RetCode
from api.utils.api_utils import get_json_result
from api.apps import login_required, current_user
@manager.route('/set', methods=['POST']) # noqa: F821
@validate_request("prompt_config")
@login_required
def set_dialog():
req = request.json
async def set_dialog():
req = await request.json
dialog_id = req.get("dialog_id", "")
is_create = not dialog_id
name = req.get("name", "New Dialog")
@ -154,33 +154,34 @@ def get_kb_names(kb_ids):
@login_required
def list_dialogs():
try:
diags = DialogService.query(
conversations = DialogService.query(
tenant_id=current_user.id,
status=StatusEnum.VALID.value,
reverse=True,
order_by=DialogService.model.create_time)
diags = [d.to_dict() for d in diags]
for d in diags:
d["kb_ids"], d["kb_names"] = get_kb_names(d["kb_ids"])
return get_json_result(data=diags)
conversations = [d.to_dict() for d in conversations]
for conversation in conversations:
conversation["kb_ids"], conversation["kb_names"] = get_kb_names(conversation["kb_ids"])
return get_json_result(data=conversations)
except Exception as e:
return server_error_response(e)
@manager.route('/next', methods=['POST']) # noqa: F821
@login_required
def list_dialogs_next():
keywords = request.args.get("keywords", "")
page_number = int(request.args.get("page", 0))
items_per_page = int(request.args.get("page_size", 0))
parser_id = request.args.get("parser_id")
orderby = request.args.get("orderby", "create_time")
if request.args.get("desc", "true").lower() == "false":
async def list_dialogs_next():
args = request.args
keywords = args.get("keywords", "")
page_number = int(args.get("page", 0))
items_per_page = int(args.get("page_size", 0))
parser_id = args.get("parser_id")
orderby = args.get("orderby", "create_time")
if args.get("desc", "true").lower() == "false":
desc = False
else:
desc = True
req = request.get_json()
req = await request.get_json()
owner_ids = req.get("owner_ids", [])
try:
if not owner_ids:
@ -207,8 +208,8 @@ def list_dialogs_next():
@manager.route('/rm', methods=['POST']) # noqa: F821
@login_required
@validate_request("dialog_ids")
def rm():
req = request.json
async def rm():
req = await request.json
dialog_list=[]
tenants = UserTenantService.query(user_id=current_user.id)
try:

View File
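The `list_dialogs_next` hunk above illustrates which request accessors stay synchronous after the migration: `request.args` is parsed from the URL and needs no await, while the JSON body is read from the request stream and does. A minimal sketch of that split:

```python
# Query string is sync, body is awaited -- mirroring list_dialogs_next above.
from quart import Quart, request

app = Quart(__name__)

@app.route("/next", methods=["POST"])
async def list_next():
    args = request.args                    # no await: already parsed
    page = int(args.get("page", 0))
    desc = args.get("desc", "true").lower() != "false"
    req = await request.get_json() or {}   # await: consumes the request body
    return {"page": page, "desc": desc, "owner_ids": req.get("owner_ids", [])}
```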

@ -18,11 +18,8 @@ import os.path
import pathlib
import re
from pathlib import Path
import flask
from flask import request
from flask_login import current_user, login_required
from quart import request, make_response
from api.apps import current_user, login_required
from api.common.check_team_permission import check_kb_team_permission
from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX
from api.db import VALID_FILE_TYPES, FileType
@ -39,7 +36,7 @@ from api.utils.api_utils import (
get_data_error_result,
get_json_result,
server_error_response,
validate_request,
validate_request, request_json,
)
from api.utils.file_utils import filename_type, thumbnail
from common.file_utils import get_project_base_directory
@ -53,14 +50,16 @@ from common import settings
@manager.route("/upload", methods=["POST"]) # noqa: F821
@login_required
@validate_request("kb_id")
def upload():
kb_id = request.form.get("kb_id")
async def upload():
form = await request.form
kb_id = form.get("kb_id")
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
if "file" not in request.files:
files = await request.files
if "file" not in files:
return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)
file_objs = request.files.getlist("file")
file_objs = files.getlist("file")
for file_obj in file_objs:
if file_obj.filename == "":
return get_json_result(data=False, message="No file selected!", code=RetCode.ARGUMENT_ERROR)
@ -87,12 +86,13 @@ def upload():
@manager.route("/web_crawl", methods=["POST"]) # noqa: F821
@login_required
@validate_request("kb_id", "name", "url")
def web_crawl():
kb_id = request.form.get("kb_id")
async def web_crawl():
form = await request.form
kb_id = form.get("kb_id")
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
name = request.form.get("name")
url = request.form.get("url")
name = form.get("name")
url = form.get("url")
if not is_valid_url(url):
return get_json_result(data=False, message="The URL format is invalid", code=RetCode.ARGUMENT_ERROR)
e, kb = KnowledgebaseService.get_by_id(kb_id)
@ -152,8 +152,8 @@ def web_crawl():
@manager.route("/create", methods=["POST"]) # noqa: F821
@login_required
@validate_request("name", "kb_id")
def create():
req = request.json
async def create():
req = await request_json()
kb_id = req["kb_id"]
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
@ -208,7 +208,7 @@ def create():
@manager.route("/list", methods=["POST"]) # noqa: F821
@login_required
def list_docs():
async def list_docs():
kb_id = request.args.get("kb_id")
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
@ -230,7 +230,7 @@ def list_docs():
create_time_from = int(request.args.get("create_time_from", 0))
create_time_to = int(request.args.get("create_time_to", 0))
req = request.get_json()
req = await request.get_json()
run_status = req.get("run_status", [])
if run_status:
@ -270,8 +270,8 @@ def list_docs():
@manager.route("/filter", methods=["POST"]) # noqa: F821
@login_required
def get_filter():
req = request.get_json()
async def get_filter():
req = await request.get_json()
kb_id = req.get("kb_id")
if not kb_id:
@ -308,8 +308,8 @@ def get_filter():
@manager.route("/infos", methods=["POST"]) # noqa: F821
@login_required
def docinfos():
req = request.json
async def doc_infos():
req = await request_json()
doc_ids = req["doc_ids"]
for doc_id in doc_ids:
if not DocumentService.accessible(doc_id, current_user.id):
@ -340,8 +340,8 @@ def thumbnails():
@manager.route("/change_status", methods=["POST"]) # noqa: F821
@login_required
@validate_request("doc_ids", "status")
def change_status():
req = request.get_json()
async def change_status():
req = await request.get_json()
doc_ids = req.get("doc_ids", [])
status = str(req.get("status", ""))
@ -380,8 +380,8 @@ def change_status():
@manager.route("/rm", methods=["POST"]) # noqa: F821
@login_required
@validate_request("doc_id")
def rm():
req = request.json
async def rm():
req = await request_json()
doc_ids = req["doc_id"]
if isinstance(doc_ids, str):
doc_ids = [doc_ids]
@ -401,8 +401,8 @@ def rm():
@manager.route("/run", methods=["POST"]) # noqa: F821
@login_required
@validate_request("doc_ids", "run")
def run():
req = request.json
async def run():
req = await request_json()
for doc_id in req["doc_ids"]:
if not DocumentService.accessible(doc_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
@ -448,8 +448,8 @@ def run():
@manager.route("/rename", methods=["POST"]) # noqa: F821
@login_required
@validate_request("doc_id", "name")
def rename():
req = request.json
async def rename():
req = await request_json()
if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
try:
@ -495,19 +495,20 @@ def rename():
@manager.route("/get/<doc_id>", methods=["GET"]) # noqa: F821
# @login_required
def get(doc_id):
async def get(doc_id):
try:
e, doc = DocumentService.get_by_id(doc_id)
if not e:
return get_data_error_result(message="Document not found!")
b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
response = flask.make_response(settings.STORAGE_IMPL.get(b, n))
response = await make_response(settings.STORAGE_IMPL.get(b, n))
ext = re.search(r"\.([^.]+)$", doc.name.lower())
ext = ext.group(1) if ext else None
if ext:
if doc.type == FileType.VISUAL.value:
content_type = CONTENT_TYPE_MAP.get(ext, f"image/{ext}")
else:
content_type = CONTENT_TYPE_MAP.get(ext, f"application/{ext}")
@ -517,12 +518,28 @@ def get(doc_id):
return server_error_response(e)
@manager.route("/download/<attachment_id>", methods=["GET"]) # noqa: F821
@login_required
async def download_attachment(attachment_id):
try:
ext = request.args.get("ext", "markdown")
data = settings.STORAGE_IMPL.get(current_user.id, attachment_id)
# data = settings.STORAGE_IMPL.get("eb500d50bb0411f0907561d2782adda5", attachment_id)
response = await make_response(data)
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
return response
except Exception as e:
return server_error_response(e)
@manager.route("/change_parser", methods=["POST"]) # noqa: F821
@login_required
@validate_request("doc_id")
def change_parser():
async def change_parser():
req = request.json
req = await request_json()
if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
@ -544,6 +561,7 @@ def change_parser():
return get_data_error_result(message="Tenant not found!")
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
return None
try:
if "pipeline_id" in req and req["pipeline_id"] != "":
@ -572,13 +590,13 @@ def change_parser():
@manager.route("/image/<image_id>", methods=["GET"]) # noqa: F821
# @login_required
def get_image(image_id):
async def get_image(image_id):
try:
arr = image_id.split("-")
if len(arr) != 2:
return get_data_error_result(message="Image not found.")
bkt, nm = image_id.split("-")
response = flask.make_response(settings.STORAGE_IMPL.get(bkt, nm))
response = await make_response(settings.STORAGE_IMPL.get(bkt, nm))
response.headers.set("Content-Type", "image/JPEG")
return response
except Exception as e:
@ -588,24 +606,25 @@ def get_image(image_id):
@manager.route("/upload_and_parse", methods=["POST"]) # noqa: F821
@login_required
@validate_request("conversation_id")
def upload_and_parse():
if "file" not in request.files:
async def upload_and_parse():
files = await request.files
if "file" not in files:
return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)
file_objs = request.files.getlist("file")
file_objs = files.getlist("file")
for file_obj in file_objs:
if file_obj.filename == "":
return get_json_result(data=False, message="No file selected!", code=RetCode.ARGUMENT_ERROR)
doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, current_user.id)
form = await request.form
doc_ids = doc_upload_and_parse(form.get("conversation_id"), file_objs, current_user.id)
return get_json_result(data=doc_ids)
@manager.route("/parse", methods=["POST"]) # noqa: F821
@login_required
def parse():
url = request.json.get("url") if request.json else ""
async def parse():
json_data = await request.get_json(silent=True)
url = json_data.get("url") if json_data else ""
if url:
if not is_valid_url(url):
return get_json_result(data=False, message="The URL format is invalid", code=RetCode.ARGUMENT_ERROR)
@ -646,10 +665,11 @@ def parse():
txt = FileService.parse_docs([f], current_user.id)
return get_json_result(data=txt)
if "file" not in request.files:
files = await request.files
if "file" not in files:
return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)
file_objs = request.files.getlist("file")
file_objs = files.getlist("file")
txt = FileService.parse_docs(file_objs, current_user.id)
return get_json_result(data=txt)
@ -658,8 +678,8 @@ def parse():
@manager.route("/set_meta", methods=["POST"]) # noqa: F821
@login_required
@validate_request("doc_id", "meta")
def set_meta():
req = request.json
async def set_meta():
req = await request_json()
if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
try:

View File
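The upload handlers above show the multipart pattern: in Quart, `request.form` and `request.files` are coroutines because the body has to be consumed before fields exist (writing `await request.json.get(...)` or `request.file`, by contrast, fails at runtime, since the attribute is a coroutine or does not exist). A minimal sketch with illustrative names:

```python
# Quart multipart handling as used in upload()/web_crawl() above.
from quart import Quart, request

app = Quart(__name__)

@app.route("/upload", methods=["POST"])
async def upload():
    form = await request.form     # was: request.form in Flask
    kb_id = form.get("kb_id")
    files = await request.files   # was: request.files in Flask
    if "file" not in files:
        return {"error": "No file part!"}, 400
    file_objs = files.getlist("file")
    return {"kb_id": kb_id, "count": len(file_objs)}
```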

@ -19,8 +19,8 @@ from pathlib import Path
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from flask import request
from flask_login import login_required, current_user
from quart import request
from api.apps import login_required, current_user
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from common.misc_utils import get_uuid
@ -33,8 +33,8 @@ from api.utils.api_utils import get_json_result
@manager.route('/convert', methods=['POST']) # noqa: F821
@login_required
@validate_request("file_ids", "kb_ids")
def convert():
req = request.json
async def convert():
req = await request.json
kb_ids = req["kb_ids"]
file_ids = req["file_ids"]
file2documents = []
@ -103,8 +103,8 @@ def convert():
@manager.route('/rm', methods=['POST']) # noqa: F821
@login_required
@validate_request("file_ids")
def rm():
req = request.json
async def rm():
req = await request.json
file_ids = req["file_ids"]
if not file_ids:
return get_json_result(

View File

@ -17,10 +17,8 @@ import logging
import os
import pathlib
import re
import flask
from flask import request
from flask_login import login_required, current_user
from quart import request, make_response
from api.apps import login_required, current_user
from api.common.check_team_permission import check_file_team_permission
from api.db.services.document_service import DocumentService
@ -40,17 +38,19 @@ from common import settings
@manager.route('/upload', methods=['POST']) # noqa: F821
@login_required
# @validate_request("parent_id")
def upload():
pf_id = request.form.get("parent_id")
async def upload():
form = await request.form
pf_id = form.get("parent_id")
if not pf_id:
root_folder = FileService.get_root_folder(current_user.id)
pf_id = root_folder["id"]
if 'file' not in request.files:
files = await request.files
if 'file' not in files:
return get_json_result(
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
file_objs = request.files.getlist('file')
file_objs = files.getlist('file')
for file_obj in file_objs:
if file_obj.filename == '':
@ -123,10 +123,10 @@ def upload():
@manager.route('/create', methods=['POST']) # noqa: F821
@login_required
@validate_request("name")
def create():
req = request.json
pf_id = request.json.get("parent_id")
input_file_type = request.json.get("type")
async def create():
req = await request.json
pf_id = req.get("parent_id")
input_file_type = req.get("type")
if not pf_id:
root_folder = FileService.get_root_folder(current_user.id)
pf_id = root_folder["id"]
@ -238,16 +238,16 @@ def get_all_parent_folders():
@manager.route("/rm", methods=["POST"]) # noqa: F821
@login_required
@validate_request("file_ids")
def rm():
req = request.json
async def rm():
req = await request.json
file_ids = req["file_ids"]
def _delete_single_file(file):
try:
if file.location:
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
except Exception:
logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}")
except Exception as e:
logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}, error: {e}")
informs = File2DocumentService.get_by_file_id(file.id)
for inform in informs:
@ -299,8 +299,8 @@ def rm():
@manager.route('/rename', methods=['POST']) # noqa: F821
@login_required
@validate_request("file_id", "name")
def rename():
req = request.json
async def rename():
req = await request.json
try:
e, file = FileService.get_by_id(req["file_id"])
if not e:
@ -338,7 +338,7 @@ def rename():
@manager.route('/get/<file_id>', methods=['GET']) # noqa: F821
@login_required
def get(file_id):
async def get(file_id):
try:
e, file = FileService.get_by_id(file_id)
if not e:
@ -351,7 +351,7 @@ def get(file_id):
b, n = File2DocumentService.get_storage_address(file_id=file_id)
blob = settings.STORAGE_IMPL.get(b, n)
response = flask.make_response(blob)
response = await make_response(blob)
ext = re.search(r"\.([^.]+)$", file.name.lower())
ext = ext.group(1) if ext else None
if ext:
@ -368,8 +368,8 @@ def get(file_id):
@manager.route("/mv", methods=["POST"]) # noqa: F821
@login_required
@validate_request("src_file_ids", "dest_file_id")
def move():
req = request.json
async def move():
req = await request.json
try:
file_ids = req["src_file_ids"]
dest_parent_id = req["dest_file_id"]

View File
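The `get(file_id)` change above replaces `flask.make_response(blob)` with an awaited Quart `make_response`. A short sketch of serving a stored blob this way; `STORAGE` is a stand-in for the project's `settings.STORAGE_IMPL`:

```python
# Serving a binary blob with an awaited make_response, as in get() above.
from quart import Quart, make_response

app = Quart(__name__)
STORAGE = {("bucket", "name.png"): b"\x89PNG..."}  # illustrative only

@app.route("/get/<bucket>/<name>")
async def get_blob(bucket, name):
    blob = STORAGE[(bucket, name)]
    response = await make_response(blob)
    response.headers["Content-Type"] = "image/png"
    return response
```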

@ -16,12 +16,11 @@
import json
import logging
import random
import re
from flask import request
from flask_login import login_required, current_user
from quart import request
import numpy as np
from api.db.services.connector_service import Connector2KbService
from api.db.services.llm_service import LLMBundle
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
@ -30,7 +29,8 @@ from api.db.services.file_service import FileService
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters, \
request_json
from api.db import VALID_FILE_TYPES
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.db_models import File
@ -41,23 +41,28 @@ from rag.utils.redis_conn import REDIS_CONN
from rag.utils.doc_store_conn import OrderByExpr
from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType, PAGERANK_FLD
from common import settings
from api.apps import login_required, current_user
@manager.route('/create', methods=['post']) # noqa: F821
@login_required
@validate_request("name")
def create():
req = request.json
req = KnowledgebaseService.create_with_name(
async def create():
req = await request_json()
e, res = KnowledgebaseService.create_with_name(
name=req.pop("name", None),
tenant_id=current_user.id,
parser_id=req.pop("parser_id", None),
**req
)
if not e:
return res
try:
if not KnowledgebaseService.save(**req):
if not KnowledgebaseService.save(**res):
return get_data_error_result()
return get_json_result(data={"kb_id":req["id"]})
return get_json_result(data={"kb_id":res["id"]})
except Exception as e:
return server_error_response(e)
@ -66,8 +71,8 @@ def create():
@login_required
@validate_request("kb_id", "name", "description", "parser_id")
@not_allowed_parameters("id", "tenant_id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by")
def update():
req = request.json
async def update():
req = await request_json()
if not isinstance(req["name"], str):
return get_data_error_result(message="Dataset name must be string.")
if req["name"].strip() == "":
@ -165,18 +170,19 @@ def detail():
@manager.route('/list', methods=['POST']) # noqa: F821
@login_required
def list_kbs():
keywords = request.args.get("keywords", "")
page_number = int(request.args.get("page", 0))
items_per_page = int(request.args.get("page_size", 0))
parser_id = request.args.get("parser_id")
orderby = request.args.get("orderby", "create_time")
if request.args.get("desc", "true").lower() == "false":
async def list_kbs():
args = request.args
keywords = args.get("keywords", "")
page_number = int(args.get("page", 0))
items_per_page = int(args.get("page_size", 0))
parser_id = args.get("parser_id")
orderby = args.get("orderby", "create_time")
if args.get("desc", "true").lower() == "false":
desc = False
else:
desc = True
req = request.get_json()
req = await request_json()
owner_ids = req.get("owner_ids", [])
try:
if not owner_ids:
@ -198,11 +204,12 @@ def list_kbs():
except Exception as e:
return server_error_response(e)
@manager.route('/rm', methods=['post']) # noqa: F821
@login_required
@validate_request("kb_id")
def rm():
req = request.json
async def rm():
req = await request_json()
if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id):
return get_json_result(
data=False,
@ -278,8 +285,8 @@ def list_tags_from_kbs():
@manager.route('/<kb_id>/rm_tags', methods=['POST']) # noqa: F821
@login_required
def rm_tags(kb_id):
req = request.json
async def rm_tags(kb_id):
req = await request_json()
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
@ -298,8 +305,8 @@ def rm_tags(kb_id):
@manager.route('/<kb_id>/rename_tag', methods=['POST']) # noqa: F821
@login_required
def rename_tags(kb_id):
req = request.json
async def rename_tags(kb_id):
req = await request_json()
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
@ -402,7 +409,7 @@ def get_basic_info():
@manager.route("/list_pipeline_logs", methods=["POST"]) # noqa: F821
@login_required
def list_pipeline_logs():
async def list_pipeline_logs():
kb_id = request.args.get("kb_id")
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
@ -421,7 +428,7 @@ def list_pipeline_logs():
if create_date_to > create_date_from:
return get_data_error_result(message="Create data filter is abnormal.")
req = request.get_json()
req = await request_json()
operation_status = req.get("operation_status", [])
if operation_status:
@ -446,7 +453,7 @@ def list_pipeline_logs():
@manager.route("/list_pipeline_dataset_logs", methods=["POST"]) # noqa: F821
@login_required
def list_pipeline_dataset_logs():
async def list_pipeline_dataset_logs():
kb_id = request.args.get("kb_id")
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
@ -463,7 +470,7 @@ def list_pipeline_dataset_logs():
if create_date_to > create_date_from:
return get_data_error_result(message="Create data filter is abnormal.")
req = request.get_json()
req = await request_json()
operation_status = req.get("operation_status", [])
if operation_status:
@ -480,12 +487,12 @@ def list_pipeline_dataset_logs():
@manager.route("/delete_pipeline_logs", methods=["POST"]) # noqa: F821
@login_required
def delete_pipeline_logs():
async def delete_pipeline_logs():
kb_id = request.args.get("kb_id")
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
req = request.get_json()
req = await request_json()
log_ids = req.get("log_ids", [])
PipelineOperationLogService.delete_by_ids(log_ids)
@ -509,8 +516,8 @@ def pipeline_log_detail():
@manager.route("/run_graphrag", methods=["POST"]) # noqa: F821
@login_required
def run_graphrag():
req = request.json
async def run_graphrag():
req = await request_json()
kb_id = req.get("kb_id", "")
if not kb_id:
@ -578,8 +585,8 @@ def trace_graphrag():
@manager.route("/run_raptor", methods=["POST"]) # noqa: F821
@login_required
def run_raptor():
req = request.json
async def run_raptor():
req = await request_json()
kb_id = req.get("kb_id", "")
if not kb_id:
@ -647,8 +654,8 @@ def trace_raptor():
@manager.route("/run_mindmap", methods=["POST"]) # noqa: F821
@login_required
def run_mindmap():
req = request.json
async def run_mindmap():
req = await request_json()
kb_id = req.get("kb_id", "")
if not kb_id:
@ -731,6 +738,8 @@ def delete_kb_task():
def cancel_task(task_id):
REDIS_CONN.set(f"{task_id}-cancel", "x")
kb_task_id_field: str = ""
kb_task_finish_at: str = ""
match pipeline_task_type:
case PipelineTaskType.GRAPH_RAG:
kb_task_id_field = "graphrag_task_id"
@ -761,7 +770,7 @@ def delete_kb_task():
@manager.route("/check_embedding", methods=["post"]) # noqa: F821
@login_required
def check_embedding():
async def check_embedding():
def _guess_vec_field(src: dict) -> str | None:
for k in src or {}:
@ -807,12 +816,12 @@ def check_embedding():
offset=0, limit=1,
indexNames=index_nm, knowledgebaseIds=[kb_id]
)
total = docStoreConn.getTotal(res0)
total = docStoreConn.get_total(res0)
if total <= 0:
return []
n = min(n, total)
offsets = sorted(random.sample(range(total), n))
offsets = sorted(random.sample(range(min(total, 1000)), n))
out = []
for off in offsets:
@ -824,7 +833,7 @@ def check_embedding():
offset=off, limit=1,
indexNames=index_nm, knowledgebaseIds=[kb_id]
)
ids = docStoreConn.getChunkIds(res1)
ids = docStoreConn.get_chunk_ids(res1)
if not ids:
continue
@ -845,9 +854,14 @@ def check_embedding():
"position_int": full_doc.get("position_int"),
"top_int": full_doc.get("top_int"),
"content_with_weight": full_doc.get("content_with_weight") or "",
"question_kwd": full_doc.get("question_kwd") or []
})
return out
req = request.json
def _clean(s: str) -> str:
s = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", s or "")
return s if s else "None"
req = await request_json()
kb_id = req.get("kb_id", "")
embd_id = req.get("embd_id", "")
n = int(req.get("check_num", 5))
@ -859,8 +873,10 @@ def check_embedding():
results, eff_sims = [], []
for ck in samples:
txt = (ck.get("content_with_weight") or "").strip()
if not txt:
title = ck.get("doc_name") or "Title"
txt_in = "\n".join(ck.get("question_kwd") or []) or ck.get("content_with_weight") or ""
txt_in = _clean(txt_in)
if not txt_in:
results.append({"chunk_id": ck["chunk_id"], "reason": "no_text"})
continue
@ -869,10 +885,19 @@ def check_embedding():
continue
try:
qv, _ = emb_mdl.encode_queries(txt)
sim = _cos_sim(qv, ck["vector"])
except Exception:
return get_error_data_result(message="embedding failure")
v, _ = emb_mdl.encode([title, txt_in])
assert len(v[1]) == len(ck["vector"]), f"The dimension ({len(v[1])}) of given embedding model is different from the original ({len(ck['vector'])})"
sim_content = _cos_sim(v[1], ck["vector"])
title_w = 0.1
qv_mix = title_w * v[0] + (1 - title_w) * v[1]
sim_mix = _cos_sim(qv_mix, ck["vector"])
sim = sim_content
mode = "content_only"
if sim_mix > sim:
sim = sim_mix
mode = "title+content"
except Exception as e:
return get_error_data_result(message=f"Embedding failure. {e}")
eff_sims.append(sim)
results.append({
@ -892,9 +917,10 @@ def check_embedding():
"avg_cos_sim": round(float(np.mean(eff_sims)) if eff_sims else 0.0, 6),
"min_cos_sim": round(float(np.min(eff_sims)) if eff_sims else 0.0, 6),
"max_cos_sim": round(float(np.max(eff_sims)) if eff_sims else 0.0, 6),
"match_mode": mode,
}
if summary["avg_cos_sim"] > 0.99:
if summary["avg_cos_sim"] > 0.9:
return get_json_result(data={"summary": summary, "results": results})
return get_json_result(code=RetCode.NOT_EFFECTIVE, message="failed", data={"summary": summary, "results": results})
return get_json_result(code=RetCode.NOT_EFFECTIVE, message="Embedding model switch failed: the average similarity between old and new vectors is below 0.9, indicating incompatible vector spaces.", data={"summary": summary, "results": results})

View File
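The reworked `check_embedding` above compares each stored chunk vector against a fresh content embedding and against a 10%-title / 90%-content mix, keeping whichever similarity is higher. A small numeric sketch of that decision, using plain numpy (function names here are illustrative, not the project's):

```python
# Numeric sketch of the title-weighted similarity check added above.
import numpy as np

def cos_sim(a: np.ndarray, b: np.ndarray) -> float:
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def best_similarity(title_vec, content_vec, stored_vec, title_w=0.1):
    sim_content = cos_sim(content_vec, stored_vec)
    mixed = title_w * title_vec + (1 - title_w) * content_vec
    sim_mix = cos_sim(mixed, stored_vec)
    if sim_mix > sim_content:
        return sim_mix, "title+content"
    return sim_content, "content_only"

# Toy usage: an unchanged content vector scores 1.0 in content_only mode.
v = np.array([0.2, 0.5, 0.8])
print(best_similarity(np.array([0.1, 0.4, 0.9]), v, v))
```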

@ -15,8 +15,8 @@
#
from flask import request
from flask_login import current_user, login_required
from quart import request
from api.apps import current_user, login_required
from langfuse import Langfuse
from api.db.db_models import DB
@ -27,8 +27,8 @@ from api.utils.api_utils import get_error_data_result, get_json_result, server_e
@manager.route("/api_key", methods=["POST", "PUT"]) # noqa: F821
@login_required
@validate_request("secret_key", "public_key", "host")
def set_api_key():
req = request.get_json()
async def set_api_key():
req = await request.get_json()
secret_key = req.get("secret_key", "")
public_key = req.get("public_key", "")
host = req.get("host", "")

View File

@ -16,8 +16,9 @@
import logging
import json
import os
from flask import request
from flask_login import login_required, current_user
from quart import request
from api.apps import login_required, current_user
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
from api.db.services.llm_service import LLMService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
@ -52,8 +53,8 @@ def factories():
@manager.route("/set_api_key", methods=["POST"]) # noqa: F821
@login_required
@validate_request("llm_factory", "api_key")
def set_api_key():
req = request.json
async def set_api_key():
req = await request.json
# test if api key works
chat_passed, embd_passed, rerank_passed = False, False, False
factory = req["llm_factory"]
@ -122,8 +123,8 @@ def set_api_key():
@manager.route("/add_llm", methods=["POST"]) # noqa: F821
@login_required
@validate_request("llm_factory")
def add_llm():
req = request.json
async def add_llm():
req = await request.json
factory = req["llm_factory"]
api_key = req.get("api_key", "x")
llm_name = req.get("llm_name")
@ -142,11 +143,11 @@ def add_llm():
elif factory == "Tencent Hunyuan":
req["api_key"] = apikey_json(["hunyuan_sid", "hunyuan_sk"])
return set_api_key()
return await set_api_key()
elif factory == "Tencent Cloud":
req["api_key"] = apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"])
return set_api_key()
return await set_api_key()
elif factory == "Bedrock":
# For Bedrock, due to its special authentication method
@ -267,8 +268,8 @@ def add_llm():
@manager.route("/delete_llm", methods=["POST"]) # noqa: F821
@login_required
@validate_request("llm_factory", "llm_name")
def delete_llm():
req = request.json
async def delete_llm():
req = await request.json
TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]])
return get_json_result(data=True)
@ -276,8 +277,8 @@ def delete_llm():
@manager.route("/enable_llm", methods=["POST"]) # noqa: F821
@login_required
@validate_request("llm_factory", "llm_name")
def enable_llm():
req = request.json
async def enable_llm():
req = await request.json
TenantLLMService.filter_update(
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]], {"status": str(req.get("status", "1"))}
)
@ -287,8 +288,8 @@ def enable_llm():
@manager.route("/delete_factory", methods=["POST"]) # noqa: F821
@login_required
@validate_request("llm_factory")
def delete_factory():
req = request.json
async def delete_factory():
req = await request.json
TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]])
return get_json_result(data=True)

View File
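In `add_llm` above, the Hunyuan and Tencent Cloud branches change `return set_api_key()` to `return await set_api_key()`: once a view is `async def`, calling it returns a coroutine, which must be awaited before it can be returned as a response. A tiny sketch of that delegation, assuming both functions run in the same request context:

```python
# Delegating from one async view to another, as add_llm -> set_api_key above.
from quart import Quart, request

app = Quart(__name__)

async def set_api_key():
    req = await request.get_json()   # cached, so a second read is cheap
    return {"factory": req.get("llm_factory"), "ok": True}

@app.route("/add_llm", methods=["POST"])
async def add_llm():
    req = await request.get_json()
    if req.get("llm_factory") == "Tencent Hunyuan":
        return await set_api_key()   # without await this returns a coroutine
    return {"ok": True}
```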

@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from flask import Response, request
from flask_login import current_user, login_required
from quart import Response, request
from api.apps import current_user, login_required
from api.db.db_models import MCPServer
from api.db.services.mcp_server_service import MCPServerService
@ -25,12 +25,12 @@ from common.misc_utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
get_mcp_tools
from api.utils.web_utils import get_float, safe_json_parse
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
@manager.route("/list", methods=["POST"]) # noqa: F821
@login_required
def list_mcp() -> Response:
async def list_mcp() -> Response:
keywords = request.args.get("keywords", "")
page_number = int(request.args.get("page", 0))
items_per_page = int(request.args.get("page_size", 0))
@ -40,7 +40,7 @@ def list_mcp() -> Response:
else:
desc = True
req = request.get_json()
req = await request.get_json()
mcp_ids = req.get("mcp_ids", [])
try:
servers = MCPServerService.get_servers(current_user.id, mcp_ids, 0, 0, orderby, desc, keywords) or []
@ -72,8 +72,8 @@ def detail() -> Response:
@manager.route("/create", methods=["POST"]) # noqa: F821
@login_required
@validate_request("name", "url", "server_type")
def create() -> Response:
req = request.get_json()
async def create() -> Response:
req = await request.get_json()
server_type = req.get("server_type", "")
if server_type not in VALID_MCP_SERVER_TYPES:
@ -127,8 +127,8 @@ def create() -> Response:
@manager.route("/update", methods=["POST"]) # noqa: F821
@login_required
@validate_request("mcp_id")
def update() -> Response:
req = request.get_json()
async def update() -> Response:
req = await request.get_json()
mcp_id = req.get("mcp_id", "")
e, mcp_server = MCPServerService.get_by_id(mcp_id)
@ -183,8 +183,8 @@ def update() -> Response:
@manager.route("/rm", methods=["POST"]) # noqa: F821
@login_required
@validate_request("mcp_ids")
def rm() -> Response:
req = request.get_json()
async def rm() -> Response:
req = await request.get_json()
mcp_ids = req.get("mcp_ids", [])
try:
@ -201,8 +201,8 @@ def rm() -> Response:
@manager.route("/import", methods=["POST"]) # noqa: F821
@login_required
@validate_request("mcpServers")
def import_multiple() -> Response:
req = request.get_json()
async def import_multiple() -> Response:
req = await request.get_json()
servers = req.get("mcpServers", {})
if not servers:
return get_data_error_result(message="No MCP servers provided.")
@ -268,8 +268,8 @@ def import_multiple() -> Response:
@manager.route("/export", methods=["POST"]) # noqa: F821
@login_required
@validate_request("mcp_ids")
def export_multiple() -> Response:
req = request.get_json()
async def export_multiple() -> Response:
req = await request.get_json()
mcp_ids = req.get("mcp_ids", [])
if not mcp_ids:
@ -300,8 +300,8 @@ def export_multiple() -> Response:
@manager.route("/list_tools", methods=["POST"]) # noqa: F821
@login_required
@validate_request("mcp_ids")
def list_tools() -> Response:
req = request.get_json()
async def list_tools() -> Response:
req = await request.get_json()
mcp_ids = req.get("mcp_ids", [])
if not mcp_ids:
return get_data_error_result(message="No MCP server IDs provided.")
@ -347,8 +347,8 @@ def list_tools() -> Response:
@manager.route("/test_tool", methods=["POST"]) # noqa: F821
@login_required
@validate_request("mcp_id", "tool_name", "arguments")
def test_tool() -> Response:
req = request.get_json()
async def test_tool() -> Response:
req = await request.get_json()
mcp_id = req.get("mcp_id", "")
if not mcp_id:
return get_data_error_result(message="No MCP server ID provided.")
@ -380,8 +380,8 @@ def test_tool() -> Response:
@manager.route("/cache_tools", methods=["POST"]) # noqa: F821
@login_required
@validate_request("mcp_id", "tools")
def cache_tool() -> Response:
req = request.get_json()
async def cache_tool() -> Response:
req = await request.get_json()
mcp_id = req.get("mcp_id", "")
if not mcp_id:
return get_data_error_result(message="No MCP server ID provided.")
@ -403,8 +403,8 @@ def cache_tool() -> Response:
@manager.route("/test_mcp", methods=["POST"]) # noqa: F821
@validate_request("url", "server_type")
def test_mcp() -> Response:
req = request.get_json()
async def test_mcp() -> Response:
req = await request.get_json()
url = req.get("url", "")
if not url:

View File

@ -15,8 +15,8 @@
#
from flask import Response
from flask_login import login_required
from quart import Response
from api.apps import login_required
from api.utils.api_utils import get_json_result
from plugin import GlobalPluginManager

View File

@ -27,7 +27,7 @@ from common.constants import RetCode
from common.misc_utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, token_required
from api.utils.api_utils import get_result
from flask import request, Response
from quart import request, Response
@manager.route('/agents', methods=['GET']) # noqa: F821
@ -41,19 +41,19 @@ def list_agents(tenant_id):
return get_error_data_result("The agent doesn't exist.")
page_number = int(request.args.get("page", 1))
items_per_page = int(request.args.get("page_size", 30))
orderby = request.args.get("orderby", "update_time")
order_by = request.args.get("orderby", "update_time")
if request.args.get("desc") == "False" or request.args.get("desc") == "false":
desc = False
else:
desc = True
canvas = UserCanvasService.get_list(tenant_id, page_number, items_per_page, orderby, desc, id, title)
canvas = UserCanvasService.get_list(tenant_id, page_number, items_per_page, order_by, desc, id, title)
return get_result(data=canvas)
@manager.route("/agents", methods=["POST"]) # noqa: F821
@token_required
def create_agent(tenant_id: str):
req: dict[str, Any] = cast(dict[str, Any], request.json)
async def create_agent(tenant_id: str):
req: dict[str, Any] = cast(dict[str, Any], await request.json)
req["user_id"] = tenant_id
if req.get("dsl") is not None:
@ -89,8 +89,8 @@ def create_agent(tenant_id: str):
@manager.route("/agents/<agent_id>", methods=["PUT"]) # noqa: F821
@token_required
def update_agent(tenant_id: str, agent_id: str):
req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], request.json).items() if v is not None}
async def update_agent(tenant_id: str, agent_id: str):
req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], (await request.json)).items() if v is not None}
req["user_id"] = tenant_id
if req.get("dsl") is not None:
@ -135,8 +135,8 @@ def delete_agent(tenant_id: str, agent_id: str):
@manager.route('/webhook/<agent_id>', methods=['POST']) # noqa: F821
@token_required
def webhook(tenant_id: str, agent_id: str):
req = request.json
async def webhook(tenant_id: str, agent_id: str):
req = await request.json
if not UserCanvasService.accessible(req["id"], tenant_id):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',

View File

@ -14,22 +14,20 @@
# limitations under the License.
#
import logging
from flask import request
from quart import request
from api.db.services.dialog_service import DialogService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_service import TenantService
from common.misc_utils import get_uuid
from common.constants import RetCode, StatusEnum
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required, request_json
@manager.route("/chats", methods=["POST"]) # noqa: F821
@token_required
def create(tenant_id):
req = request.json
async def create(tenant_id):
req = await request_json()
ids = [i for i in req.get("dataset_ids", []) if i]
for kb_id in ids:
kbs = KnowledgebaseService.accessible(kb_id=kb_id, user_id=tenant_id)
@ -145,10 +143,10 @@ def create(tenant_id):
@manager.route("/chats/<chat_id>", methods=["PUT"]) # noqa: F821
@token_required
def update(tenant_id, chat_id):
async def update(tenant_id, chat_id):
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
return get_error_data_result(message="You do not own the chat")
req = request.json
req = await request_json()
ids = req.get("dataset_ids", [])
if "show_quotation" in req:
req["do_refer"] = req.pop("show_quotation")
@ -228,10 +226,10 @@ def update(tenant_id, chat_id):
@manager.route("/chats", methods=["DELETE"]) # noqa: F821
@token_required
def delete(tenant_id):
async def delete_chats(tenant_id):
errors = []
success_count = 0
req = request.json
req = await request_json()
if not req:
ids = None
else:
@ -251,8 +249,8 @@ def delete(tenant_id):
errors.append(f"Assistant({id}) not found.")
continue
temp_dict = {"status": StatusEnum.INVALID.value}
DialogService.update_by_id(id, temp_dict)
success_count += 1
success_count += DialogService.update_by_id(id, temp_dict)
if errors:
if success_count > 0:

View File
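The counting change in the chat-deletion hunk above replaces an unconditional `success_count += 1` with `success_count += DialogService.update_by_id(...)`, which assumes `update_by_id` returns the number of rows affected (a peewee-style contract; not verified here). A stub sketch of that assumption:

```python
# Counting successful updates via the return value, as in delete_chats above.
def update_by_id(record_id: str, fields: dict, store: dict) -> int:
    # Hypothetical stand-in: returns rows affected, 0 if the id is unknown.
    if record_id not in store:
        return 0
    store[record_id].update(fields)
    return 1

store = {"a": {"status": "1"}, "b": {"status": "1"}}
success_count = 0
for cid in ["a", "b"]:
    success_count += update_by_id(cid, {"status": "0"}, store)
print(success_count)  # 2
```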

@ -18,13 +18,14 @@
import logging
import os
import json
from flask import request
from quart import request
from peewee import OperationalError
from api.db.db_models import File
from api.db.services.document_service import DocumentService
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID, TaskService
from api.db.services.user_service import TenantService
from common.constants import RetCode, FileSource, StatusEnum
from api.utils.api_utils import (
@ -53,7 +54,7 @@ from common import settings
@manager.route("/datasets", methods=["POST"]) # noqa: F821
@token_required
def create(tenant_id):
async def create(tenant_id):
"""
Create a new dataset.
---
@ -115,17 +116,19 @@ def create(tenant_id):
# | embedding_model| embd_id |
# | chunk_method | parser_id |
req, err = validate_and_parse_json_request(request, CreateDatasetReq)
req, err = await validate_and_parse_json_request(request, CreateDatasetReq)
if err is not None:
return get_error_argument_result(err)
req = KnowledgebaseService.create_with_name(
e, req = KnowledgebaseService.create_with_name(
name=req.pop("name", None),
tenant_id=tenant_id,
parser_id=req.pop("parser_id", None),
**req
)
if not e:
return req
# Insert embedding model(embd id)
ok, t = TenantService.get_by_id(tenant_id)
if not ok:
@ -144,7 +147,6 @@ def create(tenant_id):
ok, k = KnowledgebaseService.get_by_id(req["id"])
if not ok:
return get_error_data_result(message="Dataset created failed")
response_data = remap_dictionary_keys(k.to_dict())
return get_result(data=response_data)
except Exception as e:
@ -153,7 +155,7 @@ def create(tenant_id):
@manager.route("/datasets", methods=["DELETE"]) # noqa: F821
@token_required
def delete(tenant_id):
async def delete(tenant_id):
"""
Delete datasets.
---
@ -191,7 +193,7 @@ def delete(tenant_id):
schema:
type: object
"""
req, err = validate_and_parse_json_request(request, DeleteDatasetReq)
req, err = await validate_and_parse_json_request(request, DeleteDatasetReq)
if err is not None:
return get_error_argument_result(err)
@ -251,7 +253,7 @@ def delete(tenant_id):
@manager.route("/datasets/<dataset_id>", methods=["PUT"]) # noqa: F821
@token_required
def update(tenant_id, dataset_id):
async def update(tenant_id, dataset_id):
"""
Update a dataset.
---
@ -317,7 +319,7 @@ def update(tenant_id, dataset_id):
# | embedding_model| embd_id |
# | chunk_method | parser_id |
extras = {"dataset_id": dataset_id}
req, err = validate_and_parse_json_request(request, UpdateDatasetReq, extras=extras, exclude_unset=True)
req, err = await validate_and_parse_json_request(request, UpdateDatasetReq, extras=extras, exclude_unset=True)
if err is not None:
return get_error_argument_result(err)
@ -532,3 +534,157 @@ def delete_knowledge_graph(tenant_id, dataset_id):
search.index_name(kb.tenant_id), dataset_id)
return get_result(data=True)
@manager.route("/datasets/<dataset_id>/run_graphrag", methods=["POST"]) # noqa: F821
@token_required
def run_graphrag(tenant_id, dataset_id):
if not dataset_id:
return get_error_data_result(message='Lack of "Dataset ID"')
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
return get_result(
data=False,
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
)
ok, kb = KnowledgebaseService.get_by_id(dataset_id)
if not ok:
return get_error_data_result(message="Invalid Dataset ID")
task_id = kb.graphrag_task_id
if task_id:
ok, task = TaskService.get_by_id(task_id)
if not ok:
logging.warning(f"A valid GraphRAG task id is expected for Dataset {dataset_id}")
if task and task.progress not in [-1, 1]:
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Graph Task is already running.")
documents, _ = DocumentService.get_by_kb_id(
kb_id=dataset_id,
page_number=0,
items_per_page=0,
orderby="create_time",
desc=False,
keywords="",
run_status=[],
types=[],
suffix=[],
)
if not documents:
return get_error_data_result(message=f"No documents in Dataset {dataset_id}")
sample_document = documents[0]
document_ids = [document["id"] for document in documents]
task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}):
logging.warning(f"Cannot save graphrag_task_id for Dataset {dataset_id}")
return get_result(data={"graphrag_task_id": task_id})
@manager.route("/datasets/<dataset_id>/trace_graphrag", methods=["GET"]) # noqa: F821
@token_required
def trace_graphrag(tenant_id, dataset_id):
if not dataset_id:
return get_error_data_result(message='Lack of "Dataset ID"')
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
return get_result(
data=False,
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
)
ok, kb = KnowledgebaseService.get_by_id(dataset_id)
if not ok:
return get_error_data_result(message="Invalid Dataset ID")
task_id = kb.graphrag_task_id
if not task_id:
return get_result(data={})
ok, task = TaskService.get_by_id(task_id)
if not ok:
return get_result(data={})
return get_result(data=task.to_dict())
@manager.route("/datasets/<dataset_id>/run_raptor", methods=["POST"]) # noqa: F821
@token_required
def run_raptor(tenant_id, dataset_id):
if not dataset_id:
return get_error_data_result(message='Lack of "Dataset ID"')
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
return get_result(
data=False,
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
)
ok, kb = KnowledgebaseService.get_by_id(dataset_id)
if not ok:
return get_error_data_result(message="Invalid Dataset ID")
task_id = kb.raptor_task_id
if task_id:
ok, task = TaskService.get_by_id(task_id)
if not ok:
logging.warning(f"A valid RAPTOR task id is expected for Dataset {dataset_id}")
if task and task.progress not in [-1, 1]:
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A RAPTOR Task is already running.")
documents, _ = DocumentService.get_by_kb_id(
kb_id=dataset_id,
page_number=0,
items_per_page=0,
orderby="create_time",
desc=False,
keywords="",
run_status=[],
types=[],
suffix=[],
)
if not documents:
return get_error_data_result(message=f"No documents in Dataset {dataset_id}")
sample_document = documents[0]
document_ids = [document["id"] for document in documents]
task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
if not KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": task_id}):
logging.warning(f"Cannot save raptor_task_id for Dataset {dataset_id}")
return get_result(data={"raptor_task_id": task_id})
@manager.route("/datasets/<dataset_id>/trace_raptor", methods=["GET"]) # noqa: F821
@token_required
def trace_raptor(tenant_id, dataset_id):
if not dataset_id:
return get_error_data_result(message='Lack of "Dataset ID"')
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
return get_result(
data=False,
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
)
ok, kb = KnowledgebaseService.get_by_id(dataset_id)
if not ok:
return get_error_data_result(message="Invalid Dataset ID")
task_id = kb.raptor_task_id
if not task_id:
return get_result(data={})
ok, task = TaskService.get_by_id(task_id)
if not ok:
return get_error_data_result(message="RAPTOR Task Not Found or Error Occurred")
return get_result(data=task.to_dict())
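Taken together, run_raptor and trace_raptor form a submit-then-poll workflow (run_graphrag/trace_graphrag mirror it). A minimal client sketch under stated assumptions — the base URL, port, prefix, and API key are placeholders; the paths, response keys, and the -1/1 progress sentinels come from the handlers above:

```python
import time

import requests  # any HTTP client would do; requests is assumed here

BASE = "http://localhost:9380/api/v1"  # assumed host, port, and prefix
HEADERS = {"Authorization": "Bearer <YOUR_API_KEY>"}  # token_required expects a bearer token


def run_raptor_and_wait(dataset_id: str, poll_seconds: int = 5) -> dict:
    """Kick off a RAPTOR task, then poll its trace endpoint until it ends."""
    r = requests.post(f"{BASE}/datasets/{dataset_id}/run_raptor", headers=HEADERS)
    r.raise_for_status()
    task_id = r.json()["data"]["raptor_task_id"]
    print(f"RAPTOR task {task_id} queued")
    while True:
        t = requests.get(f"{BASE}/datasets/{dataset_id}/trace_raptor", headers=HEADERS)
        task = t.json().get("data") or {}
        if task.get("progress") in (-1, 1):  # per the handlers above: -1 failed, 1 done
            return task
        time.sleep(poll_seconds)
```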


@ -15,7 +15,7 @@
#
import logging
from flask import request, jsonify
from quart import request, jsonify
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
@ -29,7 +29,7 @@ from common import settings
@manager.route('/dify/retrieval', methods=['POST']) # noqa: F821
@apikey_required
@validate_request("knowledge_id", "query")
def retrieval(tenant_id):
async def retrieval(tenant_id):
"""
Dify-compatible retrieval API
---
@ -113,7 +113,7 @@ def retrieval(tenant_id):
404:
description: Knowledge base or document not found
"""
req = request.json
req = await request.json
question = req["query"]
kb_id = req["knowledge_id"]
use_kg = req.get("use_kg", False)
@ -131,12 +131,10 @@ def retrieval(tenant_id):
return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND)
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
print(metadata_condition)
# print("after", convert_conditions(metadata_condition))
doc_ids.extend(meta_filter(metas, convert_conditions(metadata_condition)))
# print("doc_ids", doc_ids)
if not doc_ids and metadata_condition is not None:
doc_ids = ['-999']
if metadata_condition:
doc_ids.extend(meta_filter(metas, convert_conditions(metadata_condition)))
if not doc_ids and metadata_condition:
doc_ids = ["-999"]
ranks = settings.retriever.retrieval(
question,
embd_mdl,
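In effect, the rewritten filter applies a metadata condition only when one is present, and substitutes an impossible document id when the condition matches nothing, so retrieval comes back empty instead of silently ignoring the filter. A runnable sketch of that rule (the names are illustrative stand-ins, not the project's helpers):

```python
def resolve_doc_ids(doc_ids: list, matched: list, has_condition: bool) -> list:
    """`matched` stands in for meta_filter(metas, convert_conditions(...))."""
    if has_condition:
        doc_ids = doc_ids + matched
        if not doc_ids:
            doc_ids = ["-999"]  # sentinel id that matches no real document
    return doc_ids


assert resolve_doc_ids([], ["doc1"], True) == ["doc1"]  # condition matched something
assert resolve_doc_ids([], [], True) == ["-999"]        # matched nothing -> empty retrieval
assert resolve_doc_ids([], [], False) == []             # no condition -> unfiltered retrieval
```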


@ -20,7 +20,7 @@ import re
from io import BytesIO
import xxhash
from flask import request, send_file
from quart import request, send_file
from peewee import OperationalError
from pydantic import BaseModel, Field, validator
@ -35,7 +35,8 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.task_service import TaskService, queue_tasks
from api.db.services.dialog_service import meta_filter, convert_conditions
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required, \
request_json
from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search
@ -69,7 +70,7 @@ class Chunk(BaseModel):
@manager.route("/datasets/<dataset_id>/documents", methods=["POST"]) # noqa: F821
@token_required
def upload(dataset_id, tenant_id):
async def upload(dataset_id, tenant_id):
"""
Upload documents to a dataset.
---
@ -93,6 +94,10 @@ def upload(dataset_id, tenant_id):
type: file
required: true
description: Document files to upload.
- in: formData
name: parent_path
type: string
description: Optional nested path under the parent folder. Uses '/' separators.
responses:
200:
description: Successfully uploaded documents.
@ -126,9 +131,11 @@ def upload(dataset_id, tenant_id):
type: string
description: Processing status.
"""
if "file" not in request.files:
form = await request.form
files = await request.files
if "file" not in files:
return get_error_data_result(message="No file part!", code=RetCode.ARGUMENT_ERROR)
file_objs = request.files.getlist("file")
file_objs = files.getlist("file")
for file_obj in file_objs:
if file_obj.filename == "":
return get_result(message="No file selected!", code=RetCode.ARGUMENT_ERROR)
@ -151,7 +158,7 @@ def upload(dataset_id, tenant_id):
e, kb = KnowledgebaseService.get_by_id(dataset_id)
if not e:
raise LookupError(f"Can't find the dataset with ID {dataset_id}!")
err, files = FileService.upload_document(kb, file_objs, tenant_id)
err, files = FileService.upload_document(kb, file_objs, tenant_id, parent_path=form.get("parent_path"))
if err:
return get_result(message="\n".join(err), code=RetCode.SERVER_ERROR)
# rename keys in the response
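With the new parent_path form field, a client can file uploads under a nested folder in one call. A hedged sketch — the base URL and key are placeholders; the `file` and `parent_path` field names come from the spec above:

```python
import requests  # assumed HTTP client

BASE = "http://localhost:9380/api/v1"  # assumed host, port, and prefix
HEADERS = {"Authorization": "Bearer <YOUR_API_KEY>"}

with open("report.pdf", "rb") as fh:
    resp = requests.post(
        f"{BASE}/datasets/<dataset_id>/documents",
        headers=HEADERS,
        files={"file": fh},
        data={"parent_path": "projects/2025/q1"},  # optional; '/' separators per the spec
    )
resp.raise_for_status()
```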
@ -175,7 +182,7 @@ def upload(dataset_id, tenant_id):
@manager.route("/datasets/<dataset_id>/documents/<document_id>", methods=["PUT"]) # noqa: F821
@token_required
def update_doc(tenant_id, dataset_id, document_id):
async def update_doc(tenant_id, dataset_id, document_id):
"""
Update a document within a dataset.
---
@ -224,7 +231,7 @@ def update_doc(tenant_id, dataset_id, document_id):
schema:
type: object
"""
req = request.json
req = await request_json()
if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
return get_error_data_result(message="You don't own the dataset.")
e, kb = KnowledgebaseService.get_by_id(dataset_id)
@ -355,7 +362,7 @@ def update_doc(tenant_id, dataset_id, document_id):
@manager.route("/datasets/<dataset_id>/documents/<document_id>", methods=["GET"]) # noqa: F821
@token_required
def download(tenant_id, dataset_id, document_id):
async def download(tenant_id, dataset_id, document_id):
"""
Download a document from a dataset.
---
@ -405,10 +412,10 @@ def download(tenant_id, dataset_id, document_id):
return construct_json_result(message="This file is empty.", code=RetCode.DATA_ERROR)
file = BytesIO(file_stream)
# Use send_file with a proper filename and MIME type
return send_file(
return await send_file(
file,
as_attachment=True,
download_name=doc[0].name,
attachment_filename=doc[0].name,
mimetype="application/octet-stream", # Set a default MIME type
)
@ -585,7 +592,7 @@ def list_docs(dataset_id, tenant_id):
@manager.route("/datasets/<dataset_id>/documents", methods=["DELETE"]) # noqa: F821
@token_required
def delete(tenant_id, dataset_id):
async def delete(tenant_id, dataset_id):
"""
Delete documents from a dataset.
---
@ -624,7 +631,7 @@ def delete(tenant_id, dataset_id):
"""
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
req = request.json
req = await request_json()
if not req:
doc_ids = None
else:
@ -695,7 +702,7 @@ def delete(tenant_id, dataset_id):
@manager.route("/datasets/<dataset_id>/chunks", methods=["POST"]) # noqa: F821
@token_required
def parse(tenant_id, dataset_id):
async def parse(tenant_id, dataset_id):
"""
Start parsing documents into chunks.
---
@ -734,7 +741,7 @@ def parse(tenant_id, dataset_id):
"""
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
req = request.json
req = await request_json()
if not req.get("document_ids"):
return get_error_data_result("`document_ids` is required")
doc_list = req.get("document_ids")
@ -778,7 +785,7 @@ def parse(tenant_id, dataset_id):
@manager.route("/datasets/<dataset_id>/chunks", methods=["DELETE"]) # noqa: F821
@token_required
def stop_parsing(tenant_id, dataset_id):
async def stop_parsing(tenant_id, dataset_id):
"""
Stop parsing documents into chunks.
---
@ -817,7 +824,7 @@ def stop_parsing(tenant_id, dataset_id):
"""
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
req = request.json
req = await request_json()
if not req.get("document_ids"):
return get_error_data_result("`document_ids` is required")
@ -1019,7 +1026,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
"/datasets/<dataset_id>/documents/<document_id>/chunks", methods=["POST"]
)
@token_required
def add_chunk(tenant_id, dataset_id, document_id):
async def add_chunk(tenant_id, dataset_id, document_id):
"""
Add a chunk to a document.
---
@ -1089,7 +1096,7 @@ def add_chunk(tenant_id, dataset_id, document_id):
if not doc:
return get_error_data_result(message=f"You don't own the document {document_id}.")
doc = doc[0]
req = request.json
req = await request_json()
if not str(req.get("content", "")).strip():
return get_error_data_result(message="`content` is required")
if "important_keywords" in req:
@ -1148,7 +1155,7 @@ def add_chunk(tenant_id, dataset_id, document_id):
"datasets/<dataset_id>/documents/<document_id>/chunks", methods=["DELETE"]
)
@token_required
def rm_chunk(tenant_id, dataset_id, document_id):
async def rm_chunk(tenant_id, dataset_id, document_id):
"""
Remove chunks from a document.
---
@ -1195,7 +1202,7 @@ def rm_chunk(tenant_id, dataset_id, document_id):
docs = DocumentService.get_by_ids([document_id])
if not docs:
raise LookupError(f"Can't find the document with ID {document_id}!")
req = request.json
req = await request_json()
condition = {"doc_id": document_id}
if "chunk_ids" in req:
unique_chunk_ids, duplicate_messages = check_duplicate_ids(req["chunk_ids"], "chunk")
@ -1219,7 +1226,7 @@ def rm_chunk(tenant_id, dataset_id, document_id):
"/datasets/<dataset_id>/documents/<document_id>/chunks/<chunk_id>", methods=["PUT"]
)
@token_required
def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
async def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
"""
Update a chunk within a document.
---
@ -1281,7 +1288,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
if not doc:
return get_error_data_result(message=f"You don't own the document {document_id}.")
doc = doc[0]
req = request.json
req = await request_json()
if "content" in req:
content = req["content"]
else:
@ -1323,7 +1330,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
@manager.route("/retrieval", methods=["POST"]) # noqa: F821
@token_required
def retrieval_test(tenant_id):
async def retrieval_test(tenant_id):
"""
Retrieve chunks based on a query.
---
@ -1404,7 +1411,7 @@ def retrieval_test(tenant_id):
format: float
description: Similarity score.
"""
req = request.json
req = await request_json()
if not req.get("dataset_ids"):
return get_error_data_result("`dataset_ids` is required.")
kb_ids = req["dataset_ids"]


@ -17,9 +17,7 @@
import pathlib
import re
import flask
from flask import request
from quart import request, make_response
from pathlib import Path
from api.db.services.document_service import DocumentService
@ -37,7 +35,7 @@ from common import settings
@manager.route('/file/upload', methods=['POST']) # noqa: F821
@token_required
def upload(tenant_id):
async def upload(tenant_id):
"""
Upload a file to the system.
---
@ -79,15 +77,17 @@ def upload(tenant_id):
type: string
description: File type (e.g., document, folder)
"""
pf_id = request.form.get("parent_id")
form = await request.form
files = await request.files
pf_id = form.get("parent_id")
if not pf_id:
root_folder = FileService.get_root_folder(tenant_id)
pf_id = root_folder["id"]
if 'file' not in request.files:
if 'file' not in files:
return get_json_result(data=False, message='No file part!', code=400)
file_objs = request.files.getlist('file')
file_objs = files.getlist('file')
for file_obj in file_objs:
if file_obj.filename == '':
@ -151,7 +151,7 @@ def upload(tenant_id):
@manager.route('/file/create', methods=['POST']) # noqa: F821
@token_required
def create(tenant_id):
async def create(tenant_id):
"""
Create a new file or folder.
---
@ -193,9 +193,9 @@ def create(tenant_id):
type:
type: string
"""
req = request.json
pf_id = request.json.get("parent_id")
input_file_type = request.json.get("type")
req = await request.json
pf_id = req.get("parent_id")
input_file_type = req.get("type")
if not pf_id:
root_folder = FileService.get_root_folder(tenant_id)
pf_id = root_folder["id"]
@ -450,7 +450,7 @@ def get_all_parent_folders(tenant_id):
@manager.route('/file/rm', methods=['POST']) # noqa: F821
@token_required
def rm(tenant_id):
async def rm(tenant_id):
"""
Delete one or multiple files/folders.
---
@ -481,7 +481,7 @@ def rm(tenant_id):
type: boolean
example: true
"""
req = request.json
req = await request.json
file_ids = req["file_ids"]
try:
for file_id in file_ids:
@ -524,7 +524,7 @@ def rm(tenant_id):
@manager.route('/file/rename', methods=['POST']) # noqa: F821
@token_required
def rename(tenant_id):
async def rename(tenant_id):
"""
Rename a file.
---
@ -556,7 +556,7 @@ def rename(tenant_id):
type: boolean
example: true
"""
req = request.json
req = await request.json
try:
e, file = FileService.get_by_id(req["file_id"])
if not e:
@ -585,7 +585,7 @@ def rename(tenant_id):
@manager.route('/file/get/<file_id>', methods=['GET']) # noqa: F821
@token_required
def get(tenant_id, file_id):
async def get(tenant_id, file_id):
"""
Download a file.
---
@ -619,7 +619,7 @@ def get(tenant_id, file_id):
b, n = File2DocumentService.get_storage_address(file_id=file_id)
blob = settings.STORAGE_IMPL.get(b, n)
response = flask.make_response(blob)
response = await make_response(blob)
ext = re.search(r"\.([^.]+)$", file.name)
if ext:
if file.type == FileType.VISUAL.value:
@ -633,7 +633,7 @@ def get(tenant_id, file_id):
@manager.route('/file/mv', methods=['POST']) # noqa: F821
@token_required
def move(tenant_id):
async def move(tenant_id):
"""
Move one or multiple files to another folder.
---
@ -667,7 +667,7 @@ def move(tenant_id):
type: boolean
example: true
"""
req = request.json
req = await request.json
try:
file_ids = req["src_file_ids"]
parent_id = req["dest_file_id"]
@ -693,8 +693,8 @@ def move(tenant_id):
@manager.route('/file/convert', methods=['POST']) # noqa: F821
@token_required
def convert(tenant_id):
req = request.json
async def convert(tenant_id):
req = await request.json
kb_ids = req["kb_ids"]
file_ids = req["file_ids"]
file2documents = []


@ -18,7 +18,7 @@ import re
import time
import tiktoken
from flask import Response, jsonify, request
from quart import Response, jsonify, request
from agent.canvas import Canvas
from api.db.db_models import APIToken
@ -44,8 +44,8 @@ from common import settings
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
@token_required
def create(tenant_id, chat_id):
req = request.json
async def create(tenant_id, chat_id):
req = await request.json
req["dialog_id"] = chat_id
dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
if not dia:
@ -97,8 +97,8 @@ def create_agent_session(tenant_id, agent_id):
@manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PUT"]) # noqa: F821
@token_required
def update(tenant_id, chat_id, session_id):
req = request.json
async def update(tenant_id, chat_id, session_id):
req = await request.json
req["dialog_id"] = chat_id
conv_id = session_id
conv = ConversationService.query(id=conv_id, dialog_id=chat_id)
@ -119,8 +119,8 @@ def update(tenant_id, chat_id, session_id):
@manager.route("/chats/<chat_id>/completions", methods=["POST"]) # noqa: F821
@token_required
def chat_completion(tenant_id, chat_id):
req = request.json
async def chat_completion(tenant_id, chat_id):
req = await request.json
if not req:
req = {"question": ""}
if not req.get("session_id"):
@ -149,7 +149,7 @@ def chat_completion(tenant_id, chat_id):
@manager.route("/chats_openai/<chat_id>/chat/completions", methods=["POST"]) # noqa: F821
@validate_request("model", "messages") # noqa: F821
@token_required
def chat_completion_openai_like(tenant_id, chat_id):
async def chat_completion_openai_like(tenant_id, chat_id):
"""
OpenAI-like chat completion API that simulates the behavior of OpenAI's completions endpoint.
@ -206,7 +206,7 @@ def chat_completion_openai_like(tenant_id, chat_id):
if reference:
print(completion.choices[0].message.reference)
"""
req = request.get_json()
req = await request.get_json()
need_reference = bool(req.get("reference", False))
@ -383,8 +383,8 @@ def chat_completion_openai_like(tenant_id, chat_id):
@manager.route("/agents_openai/<agent_id>/chat/completions", methods=["POST"]) # noqa: F821
@validate_request("model", "messages") # noqa: F821
@token_required
def agents_completion_openai_compatibility(tenant_id, agent_id):
req = request.json
async def agents_completion_openai_compatibility(tenant_id, agent_id):
req = await request.json
tiktokenenc = tiktoken.get_encoding("cl100k_base")
messages = req.get("messages", [])
if not messages:
@ -443,8 +443,8 @@ def agents_completion_openai_compatibility(tenant_id, agent_id):
@manager.route("/agents/<agent_id>/completions", methods=["POST"]) # noqa: F821
@token_required
def agent_completions(tenant_id, agent_id):
req = request.json
async def agent_completions(tenant_id, agent_id):
req = await request.json
if req.get("stream", True):
@ -610,13 +610,13 @@ def list_agent_session(tenant_id, agent_id):
@manager.route("/chats/<chat_id>/sessions", methods=["DELETE"]) # noqa: F821
@token_required
def delete(tenant_id, chat_id):
async def delete(tenant_id, chat_id):
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
return get_error_data_result(message="You don't own the chat")
errors = []
success_count = 0
req = request.json
req = await request.json
convs = ConversationService.query(dialog_id=chat_id)
if not req:
ids = None
@ -661,10 +661,10 @@ def delete(tenant_id, chat_id):
@manager.route("/agents/<agent_id>/sessions", methods=["DELETE"]) # noqa: F821
@token_required
def delete_agent_session(tenant_id, agent_id):
async def delete_agent_session(tenant_id, agent_id):
errors = []
success_count = 0
req = request.json
req = await request.json
cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id)
if not cvs:
return get_error_data_result(f"You don't own the agent {agent_id}")
@ -716,8 +716,8 @@ def delete_agent_session(tenant_id, agent_id):
@manager.route("/sessions/ask", methods=["POST"]) # noqa: F821
@token_required
def ask_about(tenant_id):
req = request.json
async def ask_about(tenant_id):
req = await request.json
if not req.get("question"):
return get_error_data_result("`question` is required.")
if not req.get("dataset_ids"):
@ -755,8 +755,8 @@ def ask_about(tenant_id):
@manager.route("/sessions/related_questions", methods=["POST"]) # noqa: F821
@token_required
def related_questions(tenant_id):
req = request.json
async def related_questions(tenant_id):
req = await request.json
if not req.get("question"):
return get_error_data_result("`question` is required.")
question = req["question"]
@ -806,8 +806,8 @@ Related search terms:
@manager.route("/chatbots/<dialog_id>/completions", methods=["POST"]) # noqa: F821
def chatbot_completions(dialog_id):
req = request.json
async def chatbot_completions(dialog_id):
req = await request.json
token = request.headers.get("Authorization").split()
if len(token) != 2:
@ -856,8 +856,8 @@ def chatbots_inputs(dialog_id):
@manager.route("/agentbots/<agent_id>/completions", methods=["POST"]) # noqa: F821
def agent_bot_completions(agent_id):
req = request.json
async def agent_bot_completions(agent_id):
req = await request.json
token = request.headers.get("Authorization").split()
if len(token) != 2:
@ -901,7 +901,7 @@ def begin_inputs(agent_id):
@manager.route("/searchbots/ask", methods=["POST"]) # noqa: F821
@validate_request("question", "kb_ids")
def ask_about_embedded():
async def ask_about_embedded():
token = request.headers.get("Authorization").split()
if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!"')
@ -910,7 +910,7 @@ def ask_about_embedded():
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"')
req = request.json
req = await request.json
uid = objs[0].tenant_id
search_id = req.get("search_id", "")
@ -940,7 +940,7 @@ def ask_about_embedded():
@manager.route("/searchbots/retrieval_test", methods=["POST"]) # noqa: F821
@validate_request("kb_id", "question")
def retrieval_test_embedded():
async def retrieval_test_embedded():
token = request.headers.get("Authorization").split()
if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!"')
@ -949,7 +949,7 @@ def retrieval_test_embedded():
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"')
req = request.json
req = await request.json
page = int(req.get("page", 1))
size = int(req.get("size", 30))
question = req["question"]
@ -1039,7 +1039,7 @@ def retrieval_test_embedded():
@manager.route("/searchbots/related_questions", methods=["POST"]) # noqa: F821
@validate_request("question")
def related_questions_embedded():
async def related_questions_embedded():
token = request.headers.get("Authorization").split()
if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!"')
@ -1048,7 +1048,7 @@ def related_questions_embedded():
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"')
req = request.json
req = await request.json
tenant_id = objs[0].tenant_id
if not tenant_id:
return get_error_data_result(message="permission denined.")
@ -1115,7 +1115,7 @@ def detail_share_embedded():
@manager.route("/searchbots/mindmap", methods=["POST"]) # noqa: F821
@validate_request("question", "kb_ids")
def mindmap():
async def mindmap():
token = request.headers.get("Authorization").split()
if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!"')
@ -1125,7 +1125,7 @@ def mindmap():
return get_error_data_result(message='Authentication error: API key is invalid!"')
tenant_id = objs[0].tenant_id
req = request.json
req = await request.json
search_id = req.get("search_id", "")
search_app = SearchService.get_detail(search_id) if search_id else {}


@ -14,8 +14,8 @@
# limitations under the License.
#
from flask import request
from flask_login import current_user, login_required
from quart import request
from api.apps import current_user, login_required
from api.constants import DATASET_NAME_LIMIT
from api.db.db_models import DB
@ -30,8 +30,8 @@ from api.utils.api_utils import get_data_error_result, get_json_result, not_allo
@manager.route("/create", methods=["post"]) # noqa: F821
@login_required
@validate_request("name")
def create():
req = request.get_json()
async def create():
req = await request.get_json()
search_name = req["name"]
description = req.get("description", "")
if not isinstance(search_name, str):
@ -65,8 +65,8 @@ def create():
@login_required
@validate_request("search_id", "name", "search_config", "tenant_id")
@not_allowed_parameters("id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by")
def update():
req = request.get_json()
async def update():
req = await request.get_json()
if not isinstance(req["name"], str):
return get_data_error_result(message="Search name must be string.")
if req["name"].strip() == "":
@ -140,7 +140,7 @@ def detail():
@manager.route("/list", methods=["POST"]) # noqa: F821
@login_required
def list_search_app():
async def list_search_app():
keywords = request.args.get("keywords", "")
page_number = int(request.args.get("page", 0))
items_per_page = int(request.args.get("page_size", 0))
@ -150,7 +150,7 @@ def list_search_app():
else:
desc = True
req = request.get_json()
req = await request.get_json()
owner_ids = req.get("owner_ids", [])
try:
if not owner_ids:
@ -173,8 +173,8 @@ def list_search_app():
@manager.route("/rm", methods=["post"]) # noqa: F821
@login_required
@validate_request("search_id")
def rm():
req = request.get_json()
async def rm():
req = await request.get_json()
search_id = req["search_id"]
if not SearchService.accessible4deletion(search_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)


@ -17,7 +17,7 @@ import logging
from datetime import datetime
import json
from flask_login import login_required, current_user
from api.apps import login_required, current_user
from api.db.db_models import APIToken
from api.db.services.api_service import APITokenService
@ -34,7 +34,7 @@ from common.time_utils import current_timestamp, datetime_format
from timeit import default_timer as timer
from rag.utils.redis_conn import REDIS_CONN
from flask import jsonify
from quart import jsonify
from api.utils.health_utils import run_health_checks
from common import settings


@ -14,10 +14,7 @@
# limitations under the License.
#
from flask import request
from flask_login import login_required, current_user
from api.apps import smtp_mail_server
from quart import request
from api.db import UserTenantRole
from api.db.db_models import UserTenant
from api.db.services.user_service import UserTenantService, UserService
@ -28,6 +25,7 @@ from common.time_utils import delta_seconds
from api.utils.api_utils import get_json_result, validate_request, server_error_response, get_data_error_result
from api.utils.web_utils import send_invite_email
from common import settings
from api.apps import smtp_mail_server, login_required, current_user
@manager.route("/<tenant_id>/user/list", methods=["GET"]) # noqa: F821
@ -51,14 +49,14 @@ def user_list(tenant_id):
@manager.route('/<tenant_id>/user', methods=['POST']) # noqa: F821
@login_required
@validate_request("email")
def create(tenant_id):
async def create(tenant_id):
if current_user.id != tenant_id:
return get_json_result(
data=False,
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR)
req = request.json
req = await request.json
invite_user_email = req["email"]
invite_users = UserService.query(email=invite_user_email)
if not invite_users:


@ -22,8 +22,7 @@ import secrets
import time
from datetime import datetime
from flask import redirect, request, session, make_response
from flask_login import current_user, login_required, login_user, logout_user
from quart import redirect, request, session, make_response
from werkzeug.security import check_password_hash, generate_password_hash
from api.apps.auth import get_auth_client
@ -45,7 +44,7 @@ from api.utils.api_utils import (
)
from api.utils.crypt import decrypt
from rag.utils.redis_conn import REDIS_CONN
from api.apps import smtp_mail_server
from api.apps import smtp_mail_server, login_required, current_user, login_user, logout_user
from api.utils.web_utils import (
send_email_html,
OTP_LENGTH,
@ -61,7 +60,7 @@ from common import settings
@manager.route("/login", methods=["POST", "GET"]) # noqa: F821
def login():
async def login():
"""
User login endpoint.
---
@ -91,10 +90,11 @@ def login():
schema:
type: object
"""
if not request.json:
json_body = await request.json
if not json_body:
return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!")
email = request.json.get("email", "")
email = json_body.get("email", "")
users = UserService.query(email=email)
if not users:
return get_json_result(
@ -103,7 +103,7 @@ def login():
message=f"Email: {email} is not registered!",
)
password = request.json.get("password")
password = json_body.get("password")
try:
password = decrypt(password)
except BaseException:
@ -125,7 +125,8 @@ def login():
user.update_date = datetime_format(datetime.now())
user.save()
msg = "Welcome back!"
return construct_response(data=response_data, auth=user.get_id(), message=msg)
return await construct_response(data=response_data, auth=user.get_id(), message=msg)
else:
return get_json_result(
data=False,
@ -501,7 +502,7 @@ def log_out():
@manager.route("/setting", methods=["POST"]) # noqa: F821
@login_required
def setting_user():
async def setting_user():
"""
Update user settings.
---
@ -530,7 +531,7 @@ def setting_user():
type: object
"""
update_dict = {}
request_data = request.json
request_data = await request.json
if request_data.get("password"):
new_password = request_data.get("new_password")
if not check_password_hash(current_user.password, decrypt(request_data["password"])):
@ -660,7 +661,7 @@ def user_register(user_id, user):
@manager.route("/register", methods=["POST"]) # noqa: F821
@validate_request("nickname", "email", "password")
def user_add():
async def user_add():
"""
Register a new user.
---
@ -697,7 +698,7 @@ def user_add():
code=RetCode.OPERATING_ERROR,
)
req = request.json
req = await request.json
email_address = req["email"]
# Validate the email address
@ -737,7 +738,7 @@ def user_add():
raise Exception(f"Same email: {email_address} exists!")
user = users[0]
login_user(user)
return construct_response(
return await construct_response(
data=user.to_json(),
auth=user.get_id(),
message=f"{nickname}, welcome aboard!",
@ -793,7 +794,7 @@ def tenant_info():
@manager.route("/set_tenant_info", methods=["POST"]) # noqa: F821
@login_required
@validate_request("tenant_id", "asr_id", "embd_id", "img2txt_id", "llm_id")
def set_tenant_info():
async def set_tenant_info():
"""
Update tenant information.
---
@ -830,7 +831,7 @@ def set_tenant_info():
schema:
type: object
"""
req = request.json
req = await request.json
try:
tid = req.pop("tenant_id")
TenantService.update_by_id(tid, req)
@ -840,7 +841,7 @@ def set_tenant_info():
@manager.route("/forget/captcha", methods=["GET"]) # noqa: F821
def forget_get_captcha():
async def forget_get_captcha():
"""
GET /forget/captcha?email=<email>
- Generate an image captcha and cache it in Redis under key captcha:{email} with TTL = OTP_TTL_SECONDS.
@ -862,19 +863,19 @@ def forget_get_captcha():
from captcha.image import ImageCaptcha
image = ImageCaptcha(width=300, height=120, font_sizes=[50, 60, 70])
img_bytes = image.generate(captcha_text).read()
response = make_response(img_bytes)
response = await make_response(img_bytes)
response.headers.set("Content-Type", "image/JPEG")
return response
@manager.route("/forget/otp", methods=["POST"]) # noqa: F821
def forget_send_otp():
async def forget_send_otp():
"""
POST /forget/otp
- Verify the image captcha stored at captcha:{email} (case-insensitive).
- On success, generate an email OTP (A-Z, length = OTP_LENGTH), store hash + salt (and timestamp) in Redis with TTL, reset attempts and cooldown, and send the OTP via email.
"""
req = request.get_json()
req = await request.get_json()
email = req.get("email") or ""
captcha = (req.get("captcha") or "").strip()
@ -935,12 +936,12 @@ def forget_send_otp():
@manager.route("/forget", methods=["POST"]) # noqa: F821
def forget():
async def forget():
"""
POST: Verify email + OTP and reset password, then log the user in.
Request JSON: { email, otp, new_password, confirm_new_password }
"""
req = request.get_json()
req = await request.get_json()
email = req.get("email") or ""
otp = (req.get("otp") or "").strip()
new_pwd = req.get("new_password")


@ -25,7 +25,7 @@ from datetime import datetime, timezone
from enum import Enum
from functools import wraps
from flask_login import UserMixin
from quart_auth import AuthUser
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from peewee import InterfaceError, OperationalError, BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
@ -305,6 +305,7 @@ class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
time.sleep(self.retry_delay * (2 ** attempt))
else:
raise
return None
class RetryingPooledPostgresqlDatabase(PooledPostgresqlDatabase):
@ -594,7 +595,7 @@ def fill_db_model_object(model_object, human_model_dict):
return model_object
class User(DataBaseModel, UserMixin):
class User(DataBaseModel, AuthUser):
id = CharField(max_length=32, primary_key=True)
access_token = CharField(max_length=255, null=True, index=True)
nickname = CharField(max_length=100, null=False, help_text="nicky name", index=True)
@ -772,7 +773,7 @@ class Document(DataBaseModel):
thumbnail = TextField(null=True, help_text="thumbnail base64 string")
kb_id = CharField(max_length=256, null=False, index=True)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", index=True)
pipeline_id = CharField(max_length=32, null=True, help_text="pipleline ID", index=True)
pipeline_id = CharField(max_length=32, null=True, help_text="pipeline ID", index=True)
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document come from", index=True)
type = CharField(max_length=32, null=False, help_text="file extension", index=True)
@ -876,7 +877,7 @@ class Dialog(DataBaseModel):
class Conversation(DataBaseModel):
id = CharField(max_length=32, primary_key=True)
dialog_id = CharField(max_length=32, null=False, index=True)
name = CharField(max_length=255, null=True, help_text="converastion name", index=True)
name = CharField(max_length=255, null=True, help_text="conversation name", index=True)
message = JSONField(null=True)
reference = JSONField(null=True, default=[])
user_id = CharField(max_length=255, null=True, help_text="user_id", index=True)


@ -24,7 +24,6 @@ from api.db import InputType
from api.db.db_models import Connector, SyncLogs, Connector2Kb, Knowledgebase
from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService
from api.db.services.file_service import FileService
from common.misc_utils import get_uuid
from common.constants import TaskStatus
from common.time_utils import current_timestamp, timestamp_to_date
@ -68,9 +67,10 @@ class ConnectorService(CommonService):
@classmethod
def rebuild(cls, kb_id:str, connector_id: str, tenant_id:str):
from api.db.services.file_service import FileService
e, conn = cls.get_by_id(connector_id)
if not e:
return
return None
SyncLogsService.filter_delete([SyncLogs.connector_id==connector_id, SyncLogs.kb_id==kb_id])
docs = DocumentService.query(source_type=f"{conn.source}/{conn.id}", kb_id=kb_id)
err = FileService.delete_docs([d.id for d in docs], tenant_id)
@ -125,11 +125,11 @@ class SyncLogsService(CommonService):
)
query = query.distinct().order_by(cls.model.update_time.desc())
totbal = query.count()
total = query.count()
if page_number:
query = query.paginate(page_number, items_per_page)
return list(query.dicts()), totbal
return list(query.dicts()), total
@classmethod
def start(cls, id, connector_id):
@ -191,6 +191,7 @@ class SyncLogsService(CommonService):
@classmethod
def duplicate_and_parse(cls, kb, docs, tenant_id, src, auto_parse=True):
from api.db.services.file_service import FileService
if not docs:
return None
@ -242,7 +243,7 @@ class Connector2KbService(CommonService):
"id": get_uuid(),
"connector_id": conn_id,
"kb_id": kb_id,
"auto_parse": conn.get("auto_parse", "1")
"auto_parse": conn.get("auto_parse", "1")
})
SyncLogsService.schedule(conn_id, kb_id, reindex=True)


@ -342,7 +342,7 @@ def chat(dialog, messages, stream=True, **kwargs):
if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
for ans in chat_solo(dialog, messages, stream):
yield ans
return
return None
chat_start_ts = timer()
@ -386,7 +386,7 @@ def chat(dialog, messages, stream=True, **kwargs):
ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
if ans:
yield ans
return
return None
for p in prompt_config["parameters"]:
if p["key"] == "knowledge":
@ -617,6 +617,8 @@ def chat(dialog, messages, stream=True, **kwargs):
res["audio_binary"] = tts(tts_mdl, answer)
yield res
return None
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
sys_prompt = """
@ -745,7 +747,7 @@ Please write the SQL, only SQL, without any other explanations or text.
def tts(tts_mdl, text):
if not tts_mdl or not text:
return
return None
bin = b""
for chunk in tts_mdl.tts(text):
bin += chunk
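The `return` → `return None` edits in this file change nothing at runtime: a bare `return` already produces `None`, whether in a plain function like `tts` or in a generator like `chat`, where both forms simply end iteration. A two-line demonstration:

```python
def gen():
    yield 1
    return None  # identical to a bare `return` inside a generator


assert list(gen()) == [1]  # the return value never reaches the consumer
```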


@ -41,6 +41,7 @@ from rag.utils.redis_conn import REDIS_CONN
from rag.utils.doc_store_conn import OrderByExpr
from common import settings
class DocumentService(CommonService):
model = Document
@ -113,7 +114,7 @@ class DocumentService(CommonService):
def check_doc_health(cls, tenant_id: str, filename):
import os
MAX_FILE_NUM_PER_USER = int(os.environ.get("MAX_FILE_NUM_PER_USER", 0))
if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(tenant_id) >= MAX_FILE_NUM_PER_USER:
if 0 < MAX_FILE_NUM_PER_USER <= DocumentService.get_doc_count(tenant_id):
raise RuntimeError("Exceed the maximum file number of a free user!")
if len(filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
raise RuntimeError("Exceed the maximum length of file name!")
@ -309,7 +310,7 @@ class DocumentService(CommonService):
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
page * page_size, page_size, search.index_name(tenant_id),
[doc.kb_id])
chunk_ids = settings.docStoreConn.getChunkIds(chunks)
chunk_ids = settings.docStoreConn.get_chunk_ids(chunks)
if not chunk_ids:
break
all_chunk_ids.extend(chunk_ids)
@ -322,7 +323,7 @@ class DocumentService(CommonService):
settings.STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail)
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
graph_source = settings.docStoreConn.getFields(
graph_source = settings.docStoreConn.get_fields(
settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
)
if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]:
@ -464,7 +465,7 @@ class DocumentService(CommonService):
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts()
if not docs:
return
return None
return docs[0]["tenant_id"]
@classmethod
@ -473,7 +474,7 @@ class DocumentService(CommonService):
docs = cls.model.select(cls.model.kb_id).where(cls.model.id == doc_id)
docs = docs.dicts()
if not docs:
return
return None
return docs[0]["kb_id"]
@classmethod
@ -486,7 +487,7 @@ class DocumentService(CommonService):
cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts()
if not docs:
return
return None
return docs[0]["tenant_id"]
@classmethod
@ -533,7 +534,7 @@ class DocumentService(CommonService):
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts()
if not docs:
return
return None
return docs[0]["embd_id"]
@classmethod
@ -569,7 +570,7 @@ class DocumentService(CommonService):
.where(cls.model.name == doc_name)
doc_id = doc_id.dicts()
if not doc_id:
return
return None
return doc_id[0]["id"]
@classmethod
@ -715,7 +716,7 @@ class DocumentService(CommonService):
prg = 1
status = TaskStatus.DONE.value
# only for special task and parsed docs and unfinised
# only for special task and parsed docs and unfinished
freeze_progress = special_task_running and doc_progress >= 1 and not finished
msg = "\n".join(sorted(msg))
info = {
@ -974,13 +975,13 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
def embedding(doc_id, cnts, batch_size=16):
nonlocal embd_mdl, chunk_counts, token_counts
vects = []
vectors = []
for i in range(0, len(cnts), batch_size):
vts, c = embd_mdl.encode(cnts[i: i + batch_size])
vects.extend(vts.tolist())
vectors.extend(vts.tolist())
chunk_counts[doc_id] += len(cnts[i:i + batch_size])
token_counts[doc_id] += c
return vects
return vectors
idxnm = search.index_name(kb.tenant_id)
try_create_idx = True
@ -1011,15 +1012,15 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
except Exception:
logging.exception("Mind map generation error")
vects = embedding(doc_id, [c["content_with_weight"] for c in cks])
assert len(cks) == len(vects)
vectors = embedding(doc_id, [c["content_with_weight"] for c in cks])
assert len(cks) == len(vectors)
for i, d in enumerate(cks):
v = vects[i]
v = vectors[i]
d["q_%d_vec" % len(v)] = v
for b in range(0, len(cks), es_bulk_size):
if try_create_idx:
if not settings.docStoreConn.indexExist(idxnm, kb_id):
settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
settings.docStoreConn.createIdx(idxnm, kb_id, len(vectors[0]))
try_create_idx = False
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
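A detail worth noting in this hunk: the embedding dimension is baked into the chunk's field name, so vectors produced by models of different sizes land in different fields. A toy illustration of the naming:

```python
vector = [0.1, 0.2, 0.3, 0.4]  # toy 4-dimensional embedding
chunk = {"content_with_weight": "some chunk text"}
chunk["q_%d_vec" % len(vector)] = vector
assert "q_4_vec" in chunk  # a 1024-dim model would produce "q_1024_vec" instead
```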


@ -18,7 +18,6 @@ import re
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from flask_login import current_user
from peewee import fn
from api.db import KNOWLEDGEBASE_FOLDER_NAME, FileType
@ -31,7 +30,7 @@ from common.misc_utils import get_uuid
from common.constants import TaskStatus, FileSource, ParserType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import TaskService
from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img
from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img, sanitize_path
from rag.llm.cv_model import GptV4
from common import settings
@ -184,6 +183,7 @@ class FileService(CommonService):
@classmethod
@DB.connection_context()
def create_folder(cls, file, parent_id, name, count):
from api.apps import current_user
# Recursively create folder structure
# Args:
# file: Current file object
@ -329,7 +329,7 @@ class FileService(CommonService):
current_id = start_id
while current_id:
e, file = cls.get_by_id(current_id)
if file.parent_id != file.id and e:
if e and file.parent_id != file.id:
parent_folders.append(file)
current_id = file.parent_id
else:
@ -423,13 +423,15 @@ class FileService(CommonService):
@classmethod
@DB.connection_context()
def upload_document(self, kb, file_objs, user_id, src="local"):
def upload_document(self, kb, file_objs, user_id, src="local", parent_path: str | None = None):
root_folder = self.get_root_folder(user_id)
pf_id = root_folder["id"]
self.init_knowledgebase_docs(pf_id, user_id)
kb_root_folder = self.get_kb_folder(user_id)
kb_folder = self.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])
safe_parent_path = sanitize_path(parent_path)
err, files = [], []
for file in file_objs:
try:
@ -439,7 +441,7 @@ class FileService(CommonService):
if filetype == FileType.OTHER.value:
raise RuntimeError("This type of file has not been supported yet!")
location = filename
location = filename if not safe_parent_path else f"{safe_parent_path}/{filename}"
while settings.STORAGE_IMPL.obj_exist(kb.id, location):
location += "_"
@ -506,6 +508,7 @@ class FileService(CommonService):
@staticmethod
def parse(filename, blob, img_base64=True, tenant_id=None):
from rag.app import audio, email, naive, picture, presentation
from api.apps import current_user
def dummy(prog=None, msg=""):
pass


@ -28,6 +28,7 @@ from common.constants import StatusEnum
from api.constants import DATASET_NAME_LIMIT
from api.utils.api_utils import get_parser_config, get_data_error_result
class KnowledgebaseService(CommonService):
"""Service class for managing knowledge base operations.
@ -391,12 +392,12 @@ class KnowledgebaseService(CommonService):
"""
# Validate name
if not isinstance(name, str):
return get_data_error_result(message="Dataset name must be string.")
return False, get_data_error_result(message="Dataset name must be string.")
dataset_name = name.strip()
if dataset_name == "":
return get_data_error_result(message="Dataset name can't be empty.")
return False, get_data_error_result(message="Dataset name can't be empty.")
if len(dataset_name.encode("utf-8")) > DATASET_NAME_LIMIT:
return get_data_error_result(message=f"Dataset name length is {len(dataset_name)} which is larger than {DATASET_NAME_LIMIT}")
return False, get_data_error_result(message=f"Dataset name length is {len(dataset_name)} which is larger than {DATASET_NAME_LIMIT}")
# Deduplicate name within tenant
dataset_name = duplicate_name(
@ -409,7 +410,7 @@ class KnowledgebaseService(CommonService):
# Verify tenant exists
ok, _t = TenantService.get_by_id(tenant_id)
if not ok:
return False, "Tenant not found."
return False, get_data_error_result(message="Tenant not found.")
# Build payload
kb_id = get_uuid()
@ -419,12 +420,13 @@ class KnowledgebaseService(CommonService):
"tenant_id": tenant_id,
"created_by": tenant_id,
"parser_id": (parser_id or "naive"),
**kwargs
**kwargs # Includes optional fields such as description, language, permission, avatar, parser_config, etc.
}
# Default parser_config (align with kb_app.create) — do not accept external overrides
# Update parser_config (always override with validated default/merged config)
payload["parser_config"] = get_parser_config(parser_id, kwargs.get("parser_config"))
return payload
return True, payload
@classmethod
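Since the helper now always returns an (ok, value) pair, callers can branch on the flag instead of inspecting the value's type. A runnable sketch of the convention (the helper below is a hypothetical stand-in, not the method in this hunk):

```python
def build_payload(name: str) -> tuple[bool, dict]:
    """Hypothetical helper mirroring the (ok, value) convention above."""
    if not name.strip():
        return False, {"error": "Dataset name can't be empty."}
    return True, {"name": name.strip()}


ok, result = build_payload("my_dataset")
assert ok and result["name"] == "my_dataset"   # success: result is the payload
ok, result = build_payload("   ")
assert not ok and "error" in result            # failure: result is the error response
```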


@ -19,6 +19,7 @@ import re
from common.token_utils import num_tokens_from_string
from functools import partial
from typing import Generator
from common.constants import LLMType
from api.db.db_models import LLM
from api.db.services.common_service import CommonService
from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService
@ -32,6 +33,14 @@ def get_init_tenant_llm(user_id):
from common import settings
tenant_llm = []
model_configs = {
LLMType.CHAT: settings.CHAT_CFG,
LLMType.EMBEDDING: settings.EMBEDDING_CFG,
LLMType.SPEECH2TEXT: settings.ASR_CFG,
LLMType.IMAGE2TEXT: settings.IMAGE2TEXT_CFG,
LLMType.RERANK: settings.RERANK_CFG,
}
seen = set()
factory_configs = []
for factory_config in [
@ -54,8 +63,8 @@ def get_init_tenant_llm(user_id):
"llm_factory": factory_config["factory"],
"llm_name": llm.llm_name,
"model_type": llm.model_type,
"api_key": factory_config["api_key"],
"api_base": factory_config["base_url"],
"api_key": model_configs.get(llm.model_type, {}).get("api_key", factory_config["api_key"]),
"api_base": model_configs.get(llm.model_type, {}).get("base_url", factory_config["base_url"]),
"max_tokens": llm.max_tokens if llm.max_tokens else 8192,
}
)
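The new model_configs lookup lets each model type draw credentials from its own settings block, with the factory-wide entry as the fallback. A runnable sketch of that resolution order (values are illustrative):

```python
def resolve_credentials(model_type: str, model_configs: dict, factory_config: dict) -> dict:
    """Type-specific config wins; the factory-wide config is the fallback."""
    cfg = model_configs.get(model_type, {})
    return {
        "api_key": cfg.get("api_key", factory_config["api_key"]),
        "api_base": cfg.get("base_url", factory_config["base_url"]),
    }


factory = {"api_key": "factory-key", "base_url": "https://factory.example"}
per_type = {"chat": {"api_key": "chat-key", "base_url": "https://chat.example"}}
assert resolve_credentials("chat", per_type, factory)["api_key"] == "chat-key"
assert resolve_credentials("embedding", per_type, factory)["api_key"] == "factory-key"
```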
@ -80,8 +89,8 @@ class LLMBundle(LLM4Tenant):
def encode(self, texts: list):
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.llm_name, input={"texts": texts})
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.llm_name, input={"texts": texts})
safe_texts = []
for text in texts:
token_size = num_tokens_from_string(text)
@ -90,7 +99,7 @@ class LLMBundle(LLM4Tenant):
safe_texts.append(text[:target_len])
else:
safe_texts.append(text)
embeddings, used_tokens = self.mdl.encode(safe_texts)
llm_name = getattr(self, "llm_name", None)


@ -31,7 +31,6 @@ import traceback
import threading
import uuid
from werkzeug.serving import run_simple
from api.apps import app, smtp_mail_server
from api.db.runtime_config import RuntimeConfig
from api.db.services.document_service import DocumentService
@ -41,7 +40,7 @@ from api.db.db_models import init_database_tables as init_web_db
from api.db.init_data import init_web_data
from common.versions import get_ragflow_version
from common.config_utils import show_configs
from rag.utils.mcp_tool_call_conn import shutdown_all_mcp_sessions
from common.mcp_tool_call_conn import shutdown_all_mcp_sessions
from rag.utils.redis_conn import RedisDistributedLock
stop_event = threading.Event()
@ -153,14 +152,7 @@ if __name__ == '__main__':
# start http server
try:
logging.info("RAGFlow HTTP server start...")
run_simple(
hostname=settings.HOST_IP,
port=settings.HOST_PORT,
application=app,
threaded=True,
use_reloader=RuntimeConfig.DEBUG,
use_debugger=RuntimeConfig.DEBUG,
)
app.run(host=settings.HOST_IP, port=settings.HOST_PORT)
except Exception:
traceback.print_exc()
stop_event.set()


@ -15,6 +15,7 @@
#
import functools
import inspect
import json
import logging
import os
@ -24,20 +25,18 @@ from functools import wraps
import requests
import trio
from flask import (
from quart import (
Response,
jsonify,
request
)
from flask_login import current_user
from flask import (
request as flask_request,
)
from peewee import OperationalError
from common.constants import ActiveEnum
from api.db.db_models import APIToken
from api.utils.json_encode import CustomJSONEncoder
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
from api.db.services.tenant_llm_service import LLMFactoriesService
from common.connection_utils import timeout
from common.constants import RetCode
@ -46,6 +45,12 @@ from common import settings
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
async def request_json():
try:
return await request.json
except Exception:
return {}
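The helper lets handlers treat a missing or malformed body as an empty dict instead of raising mid-request, which is why the converted endpoints can call `.get()` unconditionally. An illustrative handler (not a real route):

```python
async def example_handler():
    req = await request_json()
    # `req` is {} when the body is absent or invalid JSON, so .get() is always safe
    doc_ids = req.get("document_ids", [])
    return {"count": len(doc_ids)}
```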
def serialize_for_json(obj):
"""
Recursively serialize objects to make them JSON serializable.
@ -105,31 +110,37 @@ def server_error_response(e):
def validate_request(*args, **kwargs):
def process_args(input_arguments):
no_arguments = []
error_arguments = []
for arg in args:
if arg not in input_arguments:
no_arguments.append(arg)
for k, v in kwargs.items():
config_value = input_arguments.get(k, None)
if config_value is None:
no_arguments.append(k)
elif isinstance(v, (tuple, list)):
if config_value not in v:
error_arguments.append((k, set(v)))
elif config_value != v:
error_arguments.append((k, v))
if no_arguments or error_arguments:
error_string = ""
if no_arguments:
error_string += "required argument are missing: {}; ".format(",".join(no_arguments))
if error_arguments:
error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
return error_string
def wrapper(func):
@wraps(func)
def decorated_function(*_args, **_kwargs):
input_arguments = flask_request.json or flask_request.form.to_dict()
no_arguments = []
error_arguments = []
for arg in args:
if arg not in input_arguments:
no_arguments.append(arg)
for k, v in kwargs.items():
config_value = input_arguments.get(k, None)
if config_value is None:
no_arguments.append(k)
elif isinstance(v, (tuple, list)):
if config_value not in v:
error_arguments.append((k, set(v)))
elif config_value != v:
error_arguments.append((k, v))
if no_arguments or error_arguments:
error_string = ""
if no_arguments:
error_string += "required argument are missing: {}; ".format(",".join(no_arguments))
if error_arguments:
error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=error_string)
async def decorated_function(*_args, **_kwargs):
errs = process_args(await request.json or (await request.form).to_dict())
if errs:
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=errs)
if inspect.iscoroutinefunction(func):
return await func(*_args, **_kwargs)
return func(*_args, **_kwargs)
return decorated_function
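Positional arguments to validate_request name required keys; keyword arguments pin required values, where a tuple or list means "one of these". A hedged usage sketch (the route and handler are illustrative):

```python
@manager.route("/example", methods=["POST"])  # illustrative route
@validate_request("name", "dataset_id", mode=("fast", "full"))
async def example(tenant_id):
    req = await request.json
    # reached only when `name` and `dataset_id` are present and mode is "fast" or "full"
    ...
```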
@ -138,30 +149,34 @@ def validate_request(*args, **kwargs):
def not_allowed_parameters(*params):
def decorator(f):
def wrapper(*args, **kwargs):
input_arguments = flask_request.json or flask_request.form.to_dict()
def decorator(func):
async def wrapper(*args, **kwargs):
input_arguments = await request.json or (await request.form).to_dict()
for param in params:
if param in input_arguments:
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed")
return f(*args, **kwargs)
if inspect.iscoroutinefunction(func):
return await func(*args, **kwargs)
return func(*args, **kwargs)
return wrapper
return decorator
def active_required(f):
@wraps(f)
def wrapper(*args, **kwargs):
def active_required(func):
@wraps(func)
async def wrapper(*args, **kwargs):
from api.db.services import UserService
from api.apps import current_user
user_id = current_user.id
usr = UserService.filter_by_id(user_id)
# check is_active
if not usr or not usr.is_active == ActiveEnum.ACTIVE.value:
return get_json_result(code=RetCode.FORBIDDEN, message="User isn't active, please activate first.")
return f(*args, **kwargs)
if inspect.iscoroutinefunction(func):
return await func(*args, **kwargs)
return func(*args, **kwargs)
return wrapper
@ -173,12 +188,15 @@ def get_json_result(code: RetCode = RetCode.SUCCESS, message="success", data=Non
def apikey_required(func):
@wraps(func)
def decorated_function(*args, **kwargs):
token = flask_request.headers.get("Authorization").split()[1]
async def decorated_function(*args, **kwargs):
token = request.headers.get("Authorization").split()[1]
objs = APIToken.query(token=token)
if not objs:
return build_error_result(message="API-KEY is invalid!", code=RetCode.FORBIDDEN)
kwargs["tenant_id"] = objs[0].tenant_id
if inspect.iscoroutinefunction(func):
return await func(*args, **kwargs)
return func(*args, **kwargs)
return decorated_function
@ -199,23 +217,38 @@ def construct_json_result(code: RetCode = RetCode.SUCCESS, message="success", da
def token_required(func):
@wraps(func)
def decorated_function(*args, **kwargs):
def get_tenant_id(**kwargs):
if os.environ.get("DISABLE_SDK"):
return get_json_result(data=False, message="`Authorization` can't be empty")
authorization_str = flask_request.headers.get("Authorization")
return False, get_json_result(data=False, message="`Authorization` can't be empty")
authorization_str = request.headers.get("Authorization")
if not authorization_str:
return get_json_result(data=False, message="`Authorization` can't be empty")
return False, get_json_result(data=False, message="`Authorization` can't be empty")
authorization_list = authorization_str.split()
if len(authorization_list) < 2:
return get_json_result(data=False, message="Please check your authorization format.")
return False, get_json_result(data=False, message="Please check your authorization format.")
token = authorization_list[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(data=False, message="Authentication error: API key is invalid!", code=RetCode.AUTHENTICATION_ERROR)
return False, get_json_result(data=False, message="Authentication error: API key is invalid!", code=RetCode.AUTHENTICATION_ERROR)
kwargs["tenant_id"] = objs[0].tenant_id
return True, kwargs
@wraps(func)
def decorated_function(*args, **kwargs):
e, kwargs = get_tenant_id(**kwargs)
if not e:
return kwargs
return func(*args, **kwargs)
@wraps(func)
async def adecorated_function(*args, **kwargs):
e, kwargs = get_tenant_id(**kwargs)
if not e:
return kwargs
return await func(*args, **kwargs)
if inspect.iscoroutinefunction(func):
return adecorated_function
return decorated_function
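Because token_required now hands back an async wrapper for coroutine handlers and a sync wrapper otherwise, the same decorator can sit on both kinds of endpoint while the migration is in flight:

```python
@token_required
async def new_style(tenant_id):  # dispatched to the awaiting wrapper
    ...


@token_required
def legacy_style(tenant_id):  # dispatched to the plain wrapper
    ...
```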


@ -18,7 +18,7 @@ import base64
import click
import re
from flask import Flask
from quart import Quart
from werkzeug.security import generate_password_hash
from api.db.services import UserService
@ -73,6 +73,7 @@ def reset_email(email, new_email, email_confirm):
UserService.update_user(user[0].id,user_dict)
click.echo(click.style('Congratulations!, email has been reset.', fg='green'))
def register_commands(app: Flask):
def register_commands(app: Quart):
app.cli.add_command(reset_password)
app.cli.add_command(reset_email)


@ -1,3 +1,19 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Reusable HTML email templates and registry.
"""


@ -164,3 +164,23 @@ def read_potential_broken_pdf(blob):
return repaired
return blob
def sanitize_path(raw_path: str | None) -> str:
"""Normalize and sanitize a user-provided path segment.
- Converts backslashes to forward slashes
- Strips leading/trailing slashes
- Removes '.' and '..' segments
- Restricts characters to A-Za-z0-9, underscore, dash, and '/'
"""
if not raw_path:
return ""
backslash_re = re.compile(r"[\\]+")
unsafe_re = re.compile(r"[^A-Za-z0-9_\-/]")
normalized = backslash_re.sub("/", raw_path)
normalized = normalized.strip("/")
parts = [seg for seg in normalized.split("/") if seg and seg not in (".", "..")]
sanitized = "/".join(parts)
sanitized = unsafe_re.sub("", sanitized)
return sanitized
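A few concrete cases, following directly from the rules in the docstring:

```python
assert sanitize_path(None) == ""
assert sanitize_path("reports/2025") == "reports/2025"
assert sanitize_path("..\\..\\etc/passwd") == "etc/passwd"  # traversal segments dropped
assert sanitize_path("/a/b c!/d/") == "a/bc/d"              # unsafe characters stripped
```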


@ -173,7 +173,8 @@ def check_task_executor_alive():
heartbeats = [json.loads(heartbeat) for heartbeat in heartbeats]
task_executor_heartbeats[task_executor_id] = heartbeats
if task_executor_heartbeats:
return {"status": "alive", "message": task_executor_heartbeats}
status = "alive" if any(task_executor_heartbeats.values()) else "timeout"
return {"status": status, "message": task_executor_heartbeats}
else:
return {"status": "timeout", "message": "Not found any task executor."}
except Exception as e:


@ -1,3 +1,19 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import datetime
import json
from enum import Enum, IntEnum

View File

@ -17,7 +17,7 @@ from collections import Counter
from typing import Annotated, Any, Literal
from uuid import UUID
from flask import Request
from quart import Request
from pydantic import (
BaseModel,
ConfigDict,
@ -32,7 +32,7 @@ from werkzeug.exceptions import BadRequest, UnsupportedMediaType
from api.constants import DATASET_NAME_LIMIT
def validate_and_parse_json_request(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None, exclude_unset: bool = False) -> tuple[dict[str, Any] | None, str | None]:
async def validate_and_parse_json_request(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None, exclude_unset: bool = False) -> tuple[dict[str, Any] | None, str | None]:
"""
Validates and parses JSON requests through a multi-stage validation pipeline.
@ -81,7 +81,7 @@ def validate_and_parse_json_request(request: Request, validator: type[BaseModel]
from the final output after validation
"""
try:
payload = request.get_json() or {}
payload = await request.get_json() or {}
except UnsupportedMediaType:
return None, f"Unsupported content type: Expected application/json, got {request.content_type}"
except BadRequest:
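
Because the helper is now a coroutine, Quart route handlers must await it. A minimal sketch, assuming the helper is importable and using a hypothetical `CreateDatasetReq` model:

```
from pydantic import BaseModel
from quart import Quart, request

app = Quart(__name__)

class CreateDatasetReq(BaseModel):  # hypothetical request schema for this sketch
    name: str

@app.post("/datasets")
async def create_dataset():
    # validate_and_parse_json_request is async after the Quart migration.
    payload, err = await validate_and_parse_json_request(request, CreateDatasetReq)
    if err is not None:
        return {"code": 101, "message": err}
    return {"code": 0, "data": payload}
```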

View File

@ -23,7 +23,7 @@ from urllib.parse import urlparse
from api.apps import smtp_mail_server
from flask_mail import Message
from flask import render_template_string
from quart import render_template_string
from api.utils.email_templates import EMAIL_TEMPLATES
from selenium import webdriver
from selenium.common.exceptions import TimeoutException

check_comment_ascii.py Normal file
View File

@ -0,0 +1,48 @@
#!/usr/bin/env python3
"""
Check whether given python files contain non-ASCII comments.
How to check the whole git repo:
```
$ git ls-files -z -- '*.py' | xargs -0 python3 check_comment_ascii.py
```
"""
import sys
import tokenize
import ast
import pathlib
import re
ASCII = re.compile(r"^[\n -~]*\Z") # Printable ASCII + newline
def check(src: str, name: str) -> int:
"""
Return 1 if every comment and docstring in ``src`` is ASCII, else 0.
"""
ok = 1
# A common comment begins with `#`
with tokenize.open(src) as fp:
for tk in tokenize.generate_tokens(fp.readline):
if tk.type == tokenize.COMMENT and not ASCII.fullmatch(tk.string):
print(f"{name}:{tk.start[0]}: non-ASCII comment: {tk.string}")
ok = 0
# A docstring begins and ends with `'''`
for node in ast.walk(ast.parse(pathlib.Path(src).read_text(), filename=name)):
if isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.Module)):
if (doc := ast.get_docstring(node)) and not ASCII.fullmatch(doc):
print(f"{name}:{node.lineno}: non-ASCII docstring: {doc}")
ok = 0
return ok
if __name__ == "__main__":
status = 0
for file in sys.argv[1:]:
if not check(file, file):
status = 1
sys.exit(status)

View File

@ -21,7 +21,7 @@ from typing import Any, Callable, Coroutine, Optional, Type, Union
import asyncio
import trio
from functools import wraps
from flask import make_response, jsonify
from quart import make_response, jsonify
from common.constants import RetCode
TimeoutException = Union[Type[BaseException], BaseException]
@ -103,7 +103,7 @@ def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception:
return decorator
def construct_response(code=RetCode.SUCCESS, message="success", data=None, auth=None):
async def construct_response(code=RetCode.SUCCESS, message="success", data=None, auth=None):
result_dict = {"code": code, "message": message, "data": data}
response_dict = {}
for key, value in result_dict.items():
@ -111,7 +111,27 @@ def construct_response(code=RetCode.SUCCESS, message="success", data=None, auth=
continue
else:
response_dict[key] = value
response = make_response(jsonify(response_dict))
response = await make_response(jsonify(response_dict))
if auth:
response.headers["Authorization"] = auth
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Method"] = "*"
response.headers["Access-Control-Allow-Headers"] = "*"
response.headers["Access-Control-Allow-Headers"] = "*"
response.headers["Access-Control-Expose-Headers"] = "Authorization"
return response
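
Call sites must await the coroutine version under Quart; `sync_construct_response` below stays available for code still running in a Flask context. A sketch, assuming `construct_response` is importable from the module above:

```
from quart import Quart

app = Quart(__name__)

@app.get("/healthz")
async def healthz():
    # construct_response is a coroutine after the Quart migration.
    return await construct_response(data={"ok": True})
```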
def sync_construct_response(code=RetCode.SUCCESS, message="success", data=None, auth=None):
import flask
result_dict = {"code": code, "message": message, "data": data}
response_dict = {}
for key, value in result_dict.items():
if value is None and key != "code":
continue
else:
response_dict[key] = value
response = flask.make_response(flask.jsonify(response_dict))
if auth:
response.headers["Authorization"] = auth
response.headers["Access-Control-Allow-Origin"] = "*"

View File

@ -11,7 +11,7 @@ from .confluence_connector import ConfluenceConnector
from .discord_connector import DiscordConnector
from .dropbox_connector import DropboxConnector
from .google_drive.connector import GoogleDriveConnector
from .jira_connector import JiraConnector
from .jira.connector import JiraConnector
from .sharepoint_connector import SharePointConnector
from .teams_connector import TeamsConnector
from .config import BlobType, DocumentSource

View File

@ -87,6 +87,13 @@ class BlobStorageConnector(LoadConnector, PollConnector):
):
raise ConnectorMissingCredentialError("Oracle Cloud Infrastructure")
elif self.bucket_type == BlobType.S3_COMPATIBLE:
if not all(
credentials.get(key)
for key in ["endpoint_url", "aws_access_key_id", "aws_secret_access_key"]
):
raise ConnectorMissingCredentialError("S3 Compatible Storage")
else:
raise ValueError(f"Unsupported bucket type: {self.bucket_type}")

View File

@ -13,6 +13,7 @@ def get_current_tz_offset() -> int:
return round(time_diff.total_seconds() / 3600)
ONE_MINUTE = 60
ONE_HOUR = 3600
ONE_DAY = ONE_HOUR * 24
@ -31,6 +32,7 @@ class BlobType(str, Enum):
R2 = "r2"
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
OCI_STORAGE = "oci_storage"
S3_COMPATIBLE = "s3_compatible"
class DocumentSource(str, Enum):
@ -42,9 +44,11 @@ class DocumentSource(str, Enum):
OCI_STORAGE = "oci_storage"
SLACK = "slack"
CONFLUENCE = "confluence"
JIRA = "jira"
GOOGLE_DRIVE = "google_drive"
GMAIL = "gmail"
DISCORD = "discord"
S3_COMPATIBLE = "s3_compatible"
class FileOrigin(str, Enum):
@ -178,6 +182,21 @@ GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD = int(
os.environ.get("GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
)
JIRA_CONNECTOR_LABELS_TO_SKIP = [
ignored_tag
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
if ignored_tag
]
JIRA_CONNECTOR_MAX_TICKET_SIZE = int(
os.environ.get("JIRA_CONNECTOR_MAX_TICKET_SIZE", 100 * 1024)
)
JIRA_SYNC_TIME_BUFFER_SECONDS = int(
os.environ.get("JIRA_SYNC_TIME_BUFFER_SECONDS", ONE_MINUTE)
)
JIRA_TIMEZONE_OFFSET = float(
os.environ.get("JIRA_TIMEZONE_OFFSET", get_current_tz_offset())
)
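
All of these Jira knobs are environment-driven and read at import time. A hedged example of overriding them before the connector module is loaded:

```
import os

# Hypothetical overrides; set before importing common.data_source.config.
os.environ["JIRA_CONNECTOR_LABELS_TO_SKIP"] = "wontfix,spam"
os.environ["JIRA_CONNECTOR_MAX_TICKET_SIZE"] = str(200 * 1024)  # 200 KiB cap
os.environ["JIRA_TIMEZONE_OFFSET"] = "5.5"  # hours from UTC, e.g. IST
```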
OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "")
OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "")
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get(

View File

@ -1788,6 +1788,7 @@ class ConfluenceConnector(
cql_url = self.confluence_client.build_cql_url(
page_query, expand=",".join(_PAGE_EXPANSION_FIELDS)
)
logging.info(f"[Confluence Connector] Building CQL URL {cql_url}")
return update_param_in_path(cql_url, "limit", str(limit))
@override

View File

@ -3,15 +3,9 @@ import os
import threading
from typing import Any, Callable
import requests
from common.data_source.config import DocumentSource
from common.data_source.google_util.constant import GOOGLE_SCOPES
GOOGLE_DEVICE_CODE_URL = "https://oauth2.googleapis.com/device/code"
GOOGLE_DEVICE_TOKEN_URL = "https://oauth2.googleapis.com/token"
DEFAULT_DEVICE_INTERVAL = 5
def _get_requested_scopes(source: DocumentSource) -> list[str]:
"""Return the scopes to request, honoring an optional override env var."""
@ -55,62 +49,6 @@ def _run_with_timeout(func: Callable[[], Any], timeout_secs: int, timeout_messag
return result.get("value")
def _extract_client_info(credentials: dict[str, Any]) -> tuple[str, str | None]:
if "client_id" in credentials:
return credentials["client_id"], credentials.get("client_secret")
for key in ("installed", "web"):
if key in credentials and isinstance(credentials[key], dict):
nested = credentials[key]
if "client_id" not in nested:
break
return nested["client_id"], nested.get("client_secret")
raise ValueError("Provided Google OAuth credentials are missing client_id.")
def start_device_authorization_flow(
credentials: dict[str, Any],
source: DocumentSource,
) -> tuple[dict[str, Any], dict[str, Any]]:
client_id, client_secret = _extract_client_info(credentials)
data = {
"client_id": client_id,
"scope": " ".join(_get_requested_scopes(source)),
}
if client_secret:
data["client_secret"] = client_secret
resp = requests.post(GOOGLE_DEVICE_CODE_URL, data=data, timeout=15)
resp.raise_for_status()
payload = resp.json()
state = {
"client_id": client_id,
"client_secret": client_secret,
"device_code": payload.get("device_code"),
"interval": payload.get("interval", DEFAULT_DEVICE_INTERVAL),
}
response_data = {
"user_code": payload.get("user_code"),
"verification_url": payload.get("verification_url") or payload.get("verification_uri"),
"verification_url_complete": payload.get("verification_url_complete")
or payload.get("verification_uri_complete"),
"expires_in": payload.get("expires_in"),
"interval": state["interval"],
}
return state, response_data
def poll_device_authorization_flow(state: dict[str, Any]) -> dict[str, Any]:
data = {
"client_id": state["client_id"],
"device_code": state["device_code"],
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
}
if state.get("client_secret"):
data["client_secret"] = state["client_secret"]
resp = requests.post(GOOGLE_DEVICE_TOKEN_URL, data=data, timeout=20)
resp.raise_for_status()
return resp.json()
def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource) -> dict[str, Any]:
"""Launch the standard Google OAuth local-server flow to mint user tokens."""
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
@ -125,10 +63,7 @@ def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource
preferred_port = os.environ.get("GOOGLE_OAUTH_LOCAL_SERVER_PORT")
port = int(preferred_port) if preferred_port else 0
timeout_secs = _get_oauth_timeout_secs()
timeout_message = (
f"Google OAuth verification timed out after {timeout_secs} seconds. "
"Close any pending consent windows and rerun the connector configuration to try again."
)
timeout_message = f"Google OAuth verification timed out after {timeout_secs} seconds. Close any pending consent windows and rerun the connector configuration to try again."
print("Launching Google OAuth flow. A browser window should open shortly.")
print("If it does not, copy the URL shown in the console into your browser manually.")
@ -153,11 +88,8 @@ def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource
instructions = [
"Google rejected one or more of the requested OAuth scopes.",
"Fix options:",
" 1. In Google Cloud Console, open APIs & Services > OAuth consent screen and add the missing scopes "
" (Drive metadata + Admin Directory read scopes), then re-run the flow.",
" 1. In Google Cloud Console, open APIs & Services > OAuth consent screen and add the missing scopes (Drive metadata + Admin Directory read scopes), then re-run the flow.",
" 2. Set GOOGLE_OAUTH_SCOPE_OVERRIDE to a comma-separated list of scopes you are allowed to request.",
" 3. For quick local testing only, export OAUTHLIB_RELAX_TOKEN_SCOPE=1 to accept the reduced scopes "
" (be aware the connector may lose functionality).",
]
raise RuntimeError("\n".join(instructions)) from warning
raise
@ -184,8 +116,6 @@ def ensure_oauth_token_dict(credentials: dict[str, Any], source: DocumentSource)
client_config = {"web": credentials["web"]}
if client_config is None:
raise ValueError(
"Provided Google OAuth credentials are missing both tokens and a client configuration."
)
raise ValueError("Provided Google OAuth credentials are missing both tokens and a client configuration.")
return _run_local_server_flow(client_config, source)

View File

@ -69,7 +69,7 @@ class SlimConnectorWithPermSync(ABC):
class CheckpointedConnectorWithPermSync(ABC):
"""Checkpointed connector interface (with permission sync)"""
"""Checkpoint connector interface (with permission sync)"""
@abstractmethod
def load_from_checkpoint(
@ -143,7 +143,7 @@ class CredentialsProviderInterface(abc.ABC, Generic[T]):
@abc.abstractmethod
def is_dynamic(self) -> bool:
"""If dynamic, the credentials may change during usage ... maening the client
"""If dynamic, the credentials may change during usage ... meaning the client
needs to use the locking features of the credentials provider to operate
correctly.

View File

@ -0,0 +1,973 @@
"""Checkpointed Jira connector that emits markdown blobs for each issue."""
from __future__ import annotations
import argparse
import copy
import logging
import os
import re
from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
from datetime import datetime, timedelta, timezone
from typing import Any
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
from jira import JIRA
from jira.resources import Issue
from pydantic import Field
from common.data_source.config import (
INDEX_BATCH_SIZE,
JIRA_CONNECTOR_LABELS_TO_SKIP,
JIRA_CONNECTOR_MAX_TICKET_SIZE,
JIRA_TIMEZONE_OFFSET,
ONE_HOUR,
DocumentSource,
)
from common.data_source.exceptions import (
ConnectorMissingCredentialError,
ConnectorValidationError,
InsufficientPermissionsError,
UnexpectedValidationError,
)
from common.data_source.interfaces import (
CheckpointedConnectorWithPermSync,
CheckpointOutputWrapper,
SecondsSinceUnixEpoch,
SlimConnectorWithPermSync,
)
from common.data_source.jira.utils import (
JIRA_CLOUD_API_VERSION,
JIRA_SERVER_API_VERSION,
build_issue_url,
extract_body_text,
extract_named_value,
extract_user,
format_attachments,
format_comments,
parse_jira_datetime,
should_skip_issue,
)
from common.data_source.models import (
ConnectorCheckpoint,
ConnectorFailure,
Document,
DocumentFailure,
SlimDocument,
)
from common.data_source.utils import is_atlassian_cloud_url, is_atlassian_date_error, scoped_url
logger = logging.getLogger(__name__)
_DEFAULT_FIELDS = "summary,description,updated,created,status,priority,assignee,reporter,labels,issuetype,project,comment,attachment"
_SLIM_FIELDS = "key,project"
_MAX_RESULTS_FETCH_IDS = 5000
_JIRA_SLIM_PAGE_SIZE = 500
_JIRA_FULL_PAGE_SIZE = 50
_DEFAULT_ATTACHMENT_SIZE_LIMIT = 10 * 1024 * 1024 # 10MB
class JiraCheckpoint(ConnectorCheckpoint):
"""Checkpoint that tracks which slice of the current JQL result set was emitted."""
start_at: int = 0
cursor: str | None = None
ids_done: bool = False
all_issue_ids: list[list[str]] = Field(default_factory=list)
_TZ_OFFSET_PATTERN = re.compile(r"([+-])(\d{2})(:?)(\d{2})$")
class JiraConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSync):
"""Retrieve Jira issues and emit them as markdown documents."""
def __init__(
self,
jira_base_url: str,
project_key: str | None = None,
jql_query: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
include_comments: bool = True,
include_attachments: bool = False,
labels_to_skip: Sequence[str] | None = None,
comment_email_blacklist: Sequence[str] | None = None,
scoped_token: bool = False,
attachment_size_limit: int | None = None,
timezone_offset: float | None = None,
) -> None:
if not jira_base_url:
raise ConnectorValidationError("Jira base URL must be provided.")
self.jira_base_url = jira_base_url.rstrip("/")
self.project_key = project_key
self.jql_query = jql_query
self.batch_size = batch_size
self.include_comments = include_comments
self.include_attachments = include_attachments
configured_labels = labels_to_skip or JIRA_CONNECTOR_LABELS_TO_SKIP
self.labels_to_skip = {label.lower() for label in configured_labels}
self.comment_email_blacklist = {email.lower() for email in comment_email_blacklist or []}
self.scoped_token = scoped_token
self.jira_client: JIRA | None = None
self.max_ticket_size = JIRA_CONNECTOR_MAX_TICKET_SIZE
self.attachment_size_limit = attachment_size_limit if attachment_size_limit and attachment_size_limit > 0 else _DEFAULT_ATTACHMENT_SIZE_LIMIT
self._fields_param = _DEFAULT_FIELDS
self._slim_fields = _SLIM_FIELDS
tz_offset_value = float(timezone_offset) if timezone_offset is not None else float(JIRA_TIMEZONE_OFFSET)
self.timezone_offset = tz_offset_value
self.timezone = timezone(offset=timedelta(hours=tz_offset_value))
self._timezone_overridden = timezone_offset is not None
# -------------------------------------------------------------------------
# Connector lifecycle helpers
# -------------------------------------------------------------------------
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
"""Instantiate the Jira client using either an API token or username/password."""
jira_url_for_client = self.jira_base_url
if self.scoped_token:
if is_atlassian_cloud_url(self.jira_base_url):
try:
jira_url_for_client = scoped_url(self.jira_base_url, "jira")
except ValueError as exc:
raise ConnectorValidationError(str(exc)) from exc
else:
logger.warning(f"[Jira] Scoped token requested but Jira base URL {self.jira_base_url} does not appear to be an Atlassian Cloud domain; scoped token ignored.")
user_email = credentials.get("jira_user_email") or credentials.get("username")
api_token = credentials.get("jira_api_token") or credentials.get("token") or credentials.get("api_token")
password = credentials.get("jira_password") or credentials.get("password")
rest_api_version = credentials.get("rest_api_version")
if not rest_api_version:
rest_api_version = JIRA_CLOUD_API_VERSION if api_token else JIRA_SERVER_API_VERSION
options: dict[str, Any] = {"rest_api_version": rest_api_version}
try:
if user_email and api_token:
self.jira_client = JIRA(
server=jira_url_for_client,
basic_auth=(user_email, api_token),
options=options,
)
elif api_token:
self.jira_client = JIRA(
server=jira_url_for_client,
token_auth=api_token,
options=options,
)
elif user_email and password:
self.jira_client = JIRA(
server=jira_url_for_client,
basic_auth=(user_email, password),
options=options,
)
else:
raise ConnectorMissingCredentialError("Jira credentials must include either an API token or username/password.")
except Exception as exc: # pragma: no cover - jira lib raises many types
raise ConnectorMissingCredentialError(f"Jira: {exc}") from exc
self._sync_timezone_from_server()
return None
def validate_connector_settings(self) -> None:
"""Validate connectivity by fetching basic Jira info."""
if not self.jira_client:
raise ConnectorMissingCredentialError("Jira")
try:
if self.jql_query:
dummy_checkpoint = self.build_dummy_checkpoint()
checkpoint_callback = self._make_checkpoint_callback(dummy_checkpoint)
iterator = self._perform_jql_search(
jql=self.jql_query,
start=0,
max_results=1,
fields="key",
all_issue_ids=dummy_checkpoint.all_issue_ids,
checkpoint_callback=checkpoint_callback,
next_page_token=dummy_checkpoint.cursor,
ids_done=dummy_checkpoint.ids_done,
)
next(iter(iterator), None)
elif self.project_key:
self.jira_client.project(self.project_key)
else:
self.jira_client.projects()
except Exception as exc: # pragma: no cover - dependent on Jira responses
self._handle_validation_error(exc)
# -------------------------------------------------------------------------
# Checkpointed connector implementation
# -------------------------------------------------------------------------
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: JiraCheckpoint,
) -> Generator[Document | ConnectorFailure, None, JiraCheckpoint]:
"""Load Jira issues, emitting a Document per issue."""
try:
return (yield from self._load_with_retry(start, end, checkpoint))
except Exception as exc:
logger.exception(f"[Jira] Jira query ultimately failed: {exc}")
yield ConnectorFailure(
failure_message=f"Failed to query Jira: {exc}",
exception=exc,
)
return JiraCheckpoint(has_more=False, start_at=checkpoint.start_at)
def load_from_checkpoint_with_perm_sync(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: JiraCheckpoint,
) -> Generator[Document | ConnectorFailure, None, JiraCheckpoint]:
"""Permissions are not synced separately, so reuse the standard loader."""
return (yield from self.load_from_checkpoint(start=start, end=end, checkpoint=checkpoint))
def _load_with_retry(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: JiraCheckpoint,
) -> Generator[Document | ConnectorFailure, None, JiraCheckpoint]:
if not self.jira_client:
raise ConnectorMissingCredentialError("Jira")
attempt_start = start
retried_with_buffer = False
attempt = 0
while True:
attempt += 1
jql = self._build_jql(attempt_start, end)
logger.info(f"[Jira] Executing Jira JQL attempt {attempt} (start={attempt_start}, end={end}, buffered_retry={retried_with_buffer}): {jql}")
try:
return (yield from self._load_from_checkpoint_internal(jql, checkpoint, start_filter=start))
except Exception as exc:
if attempt_start is not None and not retried_with_buffer and is_atlassian_date_error(exc):
attempt_start = attempt_start - ONE_HOUR
retried_with_buffer = True
logger.info(f"[Jira] Atlassian date error detected; retrying with start={attempt_start}.")
continue
raise
def _handle_validation_error(self, exc: Exception) -> None:
status_code = getattr(exc, "status_code", None)
if status_code == 401:
raise InsufficientPermissionsError("Jira credential appears to be invalid or expired (HTTP 401).") from exc
if status_code == 403:
raise InsufficientPermissionsError("Jira token does not have permission to access the requested resources (HTTP 403).") from exc
if status_code == 404:
raise ConnectorValidationError("Jira resource not found (HTTP 404).") from exc
if status_code == 429:
raise ConnectorValidationError("Jira rate limit exceeded during validation (HTTP 429).") from exc
message = getattr(exc, "text", str(exc))
if not message:
raise UnexpectedValidationError("Unexpected Jira validation error.") from exc
raise ConnectorValidationError(f"Jira validation failed: {message}") from exc
def _load_from_checkpoint_internal(
self,
jql: str,
checkpoint: JiraCheckpoint,
start_filter: SecondsSinceUnixEpoch | None = None,
) -> Generator[Document | ConnectorFailure, None, JiraCheckpoint]:
assert self.jira_client, "load_credentials must be called before loading issues."
page_size = self._full_page_size()
new_checkpoint = copy.deepcopy(checkpoint)
starting_offset = new_checkpoint.start_at or 0
current_offset = starting_offset
checkpoint_callback = self._make_checkpoint_callback(new_checkpoint)
issue_iter = self._perform_jql_search(
jql=jql,
start=current_offset,
max_results=page_size,
fields=self._fields_param,
all_issue_ids=new_checkpoint.all_issue_ids,
checkpoint_callback=checkpoint_callback,
next_page_token=new_checkpoint.cursor,
ids_done=new_checkpoint.ids_done,
)
start_cutoff = float(start_filter) if start_filter is not None else None
for issue in issue_iter:
current_offset += 1
issue_key = getattr(issue, "key", "unknown")
if should_skip_issue(issue, self.labels_to_skip):
continue
issue_updated = parse_jira_datetime(issue.raw.get("fields", {}).get("updated"))
if start_cutoff is not None and issue_updated is not None and issue_updated.timestamp() <= start_cutoff:
# Jira JQL only supports minute precision, so we discard already-processed
# issues here based on the original second-level cutoff.
continue
try:
document = self._issue_to_document(issue)
except Exception as exc: # pragma: no cover - defensive
logger.exception(f"[Jira] Failed to convert Jira issue {issue_key}: {exc}")
yield ConnectorFailure(
failure_message=f"Failed to convert Jira issue {issue_key}: {exc}",
failed_document=DocumentFailure(
document_id=issue_key,
document_link=build_issue_url(self.jira_base_url, issue_key),
),
exception=exc,
)
continue
if document is not None:
yield document
if self.include_attachments:
for attachment_document in self._attachment_documents(issue):
if attachment_document is not None:
yield attachment_document
self._update_checkpoint_for_next_run(
checkpoint=new_checkpoint,
current_offset=current_offset,
starting_offset=starting_offset,
page_size=page_size,
)
new_checkpoint.start_at = current_offset
return new_checkpoint
def build_dummy_checkpoint(self) -> JiraCheckpoint:
"""Create an empty checkpoint used to kick off ingestion."""
return JiraCheckpoint(has_more=True, start_at=0)
def validate_checkpoint_json(self, checkpoint_json: str) -> JiraCheckpoint:
"""Validate a serialized checkpoint."""
return JiraCheckpoint.model_validate_json(checkpoint_json)
# -------------------------------------------------------------------------
# Slim connector implementation
# -------------------------------------------------------------------------
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: Any = None, # noqa: ARG002 - maintained for interface compatibility
) -> Generator[list[SlimDocument], None, None]:
"""Return lightweight references to Jira issues (used for permission syncing)."""
if not self.jira_client:
raise ConnectorMissingCredentialError("Jira")
start_ts = start if start is not None else 0
end_ts = end if end is not None else datetime.now(timezone.utc).timestamp()
jql = self._build_jql(start_ts, end_ts)
checkpoint = self.build_dummy_checkpoint()
checkpoint_callback = self._make_checkpoint_callback(checkpoint)
prev_offset = 0
current_offset = 0
slim_batch: list[SlimDocument] = []
while checkpoint.has_more:
for issue in self._perform_jql_search(
jql=jql,
start=current_offset,
max_results=_JIRA_SLIM_PAGE_SIZE,
fields=self._slim_fields,
all_issue_ids=checkpoint.all_issue_ids,
checkpoint_callback=checkpoint_callback,
next_page_token=checkpoint.cursor,
ids_done=checkpoint.ids_done,
):
current_offset += 1
if should_skip_issue(issue, self.labels_to_skip):
continue
doc_id = build_issue_url(self.jira_base_url, issue.key)
slim_batch.append(SlimDocument(id=doc_id))
if len(slim_batch) >= _JIRA_SLIM_PAGE_SIZE:
yield slim_batch
slim_batch = []
self._update_checkpoint_for_next_run(
checkpoint=checkpoint,
current_offset=current_offset,
starting_offset=prev_offset,
page_size=_JIRA_SLIM_PAGE_SIZE,
)
prev_offset = current_offset
if slim_batch:
yield slim_batch
# -------------------------------------------------------------------------
# Internal helpers
# -------------------------------------------------------------------------
def _build_jql(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> str:
clauses: list[str] = []
if self.jql_query:
clauses.append(f"({self.jql_query})")
elif self.project_key:
clauses.append(f'project = "{self.project_key}"')
else:
raise ConnectorValidationError("Either project_key or jql_query must be provided for Jira connector.")
if self.labels_to_skip:
labels = ", ".join(f'"{label}"' for label in self.labels_to_skip)
clauses.append(f"labels NOT IN ({labels})")
if start is not None:
clauses.append(f'updated >= "{self._format_jql_time(start)}"')
if end is not None:
clauses.append(f'updated <= "{self._format_jql_time(end)}"')
if not clauses:
raise ConnectorValidationError("Unable to build Jira JQL query.")
jql = " AND ".join(clauses)
if "order by" not in jql.lower():
jql = f"{jql} ORDER BY updated ASC"
return jql
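
For a project-scoped connector with one skip label and a time window, `_build_jql` would produce a query along these lines (illustrative, not captured from a live run):

```
# project_key="ENG", labels_to_skip={"wontfix"}, start/end as epoch seconds
jql = (
    'project = "ENG" AND labels NOT IN ("wontfix") '
    'AND updated >= "2024-01-01 00:00" AND updated <= "2024-01-31 23:59" '
    "ORDER BY updated ASC"
)
```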
def _format_jql_time(self, timestamp: SecondsSinceUnixEpoch) -> str:
dt_utc = datetime.fromtimestamp(float(timestamp), tz=timezone.utc)
dt_local = dt_utc.astimezone(self.timezone)
# Jira only accepts minute-precision timestamps in JQL, so we format accordingly
# and rely on a post-query second-level filter to avoid duplicates.
return dt_local.strftime("%Y-%m-%d %H:%M")
def _issue_to_document(self, issue: Issue) -> Document | None:
fields = issue.raw.get("fields", {})
summary = fields.get("summary") or ""
description_text = extract_body_text(fields.get("description"))
comments_text = (
format_comments(
fields.get("comment"),
blacklist=self.comment_email_blacklist,
)
if self.include_comments
else ""
)
attachments_text = format_attachments(fields.get("attachment"))
reporter_name, reporter_email = extract_user(fields.get("reporter"))
assignee_name, assignee_email = extract_user(fields.get("assignee"))
status = extract_named_value(fields.get("status"))
priority = extract_named_value(fields.get("priority"))
issue_type = extract_named_value(fields.get("issuetype"))
project = fields.get("project") or {}
issue_url = build_issue_url(self.jira_base_url, issue.key)
metadata_lines = [
f"key: {issue.key}",
f"url: {issue_url}",
f"summary: {summary}",
f"status: {status or 'Unknown'}",
f"priority: {priority or 'Unspecified'}",
f"issue_type: {issue_type or 'Unknown'}",
f"project: {project.get('name') or ''}",
f"project_key: {project.get('key') or self.project_key or ''}",
]
if reporter_name:
metadata_lines.append(f"reporter: {reporter_name}")
if reporter_email:
metadata_lines.append(f"reporter_email: {reporter_email}")
if assignee_name:
metadata_lines.append(f"assignee: {assignee_name}")
if assignee_email:
metadata_lines.append(f"assignee_email: {assignee_email}")
if fields.get("labels"):
metadata_lines.append(f"labels: {', '.join(fields.get('labels'))}")
created_dt = parse_jira_datetime(fields.get("created"))
updated_dt = parse_jira_datetime(fields.get("updated")) or created_dt or datetime.now(timezone.utc)
metadata_lines.append(f"created: {created_dt.isoformat() if created_dt else ''}")
metadata_lines.append(f"updated: {updated_dt.isoformat() if updated_dt else ''}")
sections: list[str] = [
"---",
"\n".join(filter(None, metadata_lines)),
"---",
"",
"## Description",
description_text or "No description provided.",
]
if comments_text:
sections.extend(["", "## Comments", comments_text])
if attachments_text:
sections.extend(["", "## Attachments", attachments_text])
blob_text = "\n".join(sections).strip() + "\n"
blob = blob_text.encode("utf-8")
if len(blob) > self.max_ticket_size:
logger.info(f"[Jira] Skipping {issue.key} because it exceeds the maximum size of {self.max_ticket_size} bytes.")
return None
semantic_identifier = f"{issue.key}: {summary}" if summary else issue.key
return Document(
id=issue_url,
source=DocumentSource.JIRA,
semantic_identifier=semantic_identifier,
extension=".md",
blob=blob,
doc_updated_at=updated_dt,
size_bytes=len(blob),
)
def _attachment_documents(self, issue: Issue) -> Iterable[Document]:
attachments = issue.raw.get("fields", {}).get("attachment") or []
for attachment in attachments:
try:
document = self._attachment_to_document(issue, attachment)
if document is not None:
yield document
except Exception as exc: # pragma: no cover - defensive
failed_id = attachment.get("id") or attachment.get("filename")
issue_key = getattr(issue, "key", "unknown")
logger.warning(f"[Jira] Failed to process attachment {failed_id} for issue {issue_key}: {exc}")
def _attachment_to_document(self, issue: Issue, attachment: dict[str, Any]) -> Document | None:
if not self.include_attachments:
return None
filename = attachment.get("filename")
content_url = attachment.get("content")
if not filename or not content_url:
return None
try:
attachment_size = int(attachment.get("size", 0))
except (TypeError, ValueError):
attachment_size = 0
if attachment_size and attachment_size > self.attachment_size_limit:
logger.info(f"[Jira] Skipping attachment {filename} on {issue.key} because reported size exceeds limit ({self.attachment_size_limit} bytes).")
return None
blob = self._download_attachment(content_url)
if blob is None:
return None
if len(blob) > self.attachment_size_limit:
logger.info(f"[Jira] Skipping attachment {filename} on {issue.key} because it exceeds the size limit ({self.attachment_size_limit} bytes).")
return None
attachment_time = parse_jira_datetime(attachment.get("created")) or parse_jira_datetime(attachment.get("updated"))
updated_dt = attachment_time or parse_jira_datetime(issue.raw.get("fields", {}).get("updated")) or datetime.now(timezone.utc)
extension = os.path.splitext(filename)[1] or ""
document_id = f"{issue.key}::attachment::{attachment.get('id') or filename}"
semantic_identifier = f"{issue.key} attachment: {filename}"
return Document(
id=document_id,
source=DocumentSource.JIRA,
semantic_identifier=semantic_identifier,
extension=extension,
blob=blob,
doc_updated_at=updated_dt,
size_bytes=len(blob),
)
def _download_attachment(self, url: str) -> bytes | None:
if not self.jira_client:
raise ConnectorMissingCredentialError("Jira")
response = self.jira_client._session.get(url)
response.raise_for_status()
return response.content
def _sync_timezone_from_server(self) -> None:
if self._timezone_overridden or not self.jira_client:
return
try:
server_info = self.jira_client.server_info()
except Exception as exc: # pragma: no cover - defensive
logger.info(f"[Jira] Unable to determine timezone from server info; continuing with offset {self.timezone_offset}. Error: {exc}")
return
detected_offset = self._extract_timezone_offset(server_info)
if detected_offset is None or detected_offset == self.timezone_offset:
return
self.timezone_offset = detected_offset
self.timezone = timezone(offset=timedelta(hours=detected_offset))
logger.info(f"[Jira] Timezone offset adjusted to {detected_offset} hours using Jira server info.")
def _extract_timezone_offset(self, server_info: dict[str, Any]) -> float | None:
server_time_raw = server_info.get("serverTime")
if isinstance(server_time_raw, str):
offset = self._parse_offset_from_datetime_string(server_time_raw)
if offset is not None:
return offset
tz_name = server_info.get("timeZone")
if isinstance(tz_name, str):
offset = self._offset_from_zone_name(tz_name)
if offset is not None:
return offset
return None
@staticmethod
def _parse_offset_from_datetime_string(value: str) -> float | None:
normalized = JiraConnector._normalize_datetime_string(value)
try:
dt = datetime.fromisoformat(normalized)
except ValueError:
return None
if dt.tzinfo is None:
return 0.0
offset = dt.tzinfo.utcoffset(dt)
if offset is None:
return None
return offset.total_seconds() / 3600.0
@staticmethod
def _normalize_datetime_string(value: str) -> str:
trimmed = (value or "").strip()
if trimmed.endswith("Z"):
return f"{trimmed[:-1]}+00:00"
match = _TZ_OFFSET_PATTERN.search(trimmed)
if match and match.group(3) != ":":
sign, hours, _, minutes = match.groups()
trimmed = f"{trimmed[: match.start()]}{sign}{hours}:{minutes}"
return trimmed
@staticmethod
def _offset_from_zone_name(name: str) -> float | None:
try:
tz = ZoneInfo(name)
except (ZoneInfoNotFoundError, ValueError):
return None
reference = datetime.now(tz)
offset = reference.utcoffset()
if offset is None:
return None
return offset.total_seconds() / 3600.0
def _is_cloud_client(self) -> bool:
if not self.jira_client:
return False
rest_version = str(self.jira_client._options.get("rest_api_version", "")).strip()
return rest_version == str(JIRA_CLOUD_API_VERSION)
def _full_page_size(self) -> int:
return max(1, min(self.batch_size, _JIRA_FULL_PAGE_SIZE))
def _perform_jql_search(
self,
*,
jql: str,
start: int,
max_results: int,
fields: str | None = None,
all_issue_ids: list[list[str]] | None = None,
checkpoint_callback: Callable[[Iterable[list[str]], str | None], None] | None = None,
next_page_token: str | None = None,
ids_done: bool = False,
) -> Iterable[Issue]:
assert self.jira_client, "Jira client not initialized."
is_cloud = self._is_cloud_client()
if is_cloud:
if all_issue_ids is None:
raise ValueError("all_issue_ids is required for Jira Cloud searches.")
yield from self._perform_jql_search_v3(
jql=jql,
max_results=max_results,
fields=fields,
all_issue_ids=all_issue_ids,
checkpoint_callback=checkpoint_callback,
next_page_token=next_page_token,
ids_done=ids_done,
)
else:
yield from self._perform_jql_search_v2(
jql=jql,
start=start,
max_results=max_results,
fields=fields,
)
def _perform_jql_search_v3(
self,
*,
jql: str,
max_results: int,
all_issue_ids: list[list[str]],
fields: str | None = None,
checkpoint_callback: Callable[[Iterable[list[str]], str | None], None] | None = None,
next_page_token: str | None = None,
ids_done: bool = False,
) -> Iterable[Issue]:
assert self.jira_client, "Jira client not initialized."
if not ids_done:
new_ids, page_token = self._enhanced_search_ids(jql, next_page_token)
if checkpoint_callback is not None and new_ids:
checkpoint_callback(
self._chunk_issue_ids(new_ids, max_results),
page_token,
)
elif checkpoint_callback is not None:
checkpoint_callback([], page_token)
if all_issue_ids:
issue_ids = all_issue_ids.pop()
if issue_ids:
yield from self._bulk_fetch_issues(issue_ids, fields)
def _perform_jql_search_v2(
self,
*,
jql: str,
start: int,
max_results: int,
fields: str | None = None,
) -> Iterable[Issue]:
assert self.jira_client, "Jira client not initialized."
issues = self.jira_client.search_issues(
jql_str=jql,
startAt=start,
maxResults=max_results,
fields=fields or self._fields_param,
expand="renderedFields",
)
for issue in issues:
yield issue
def _enhanced_search_ids(
self,
jql: str,
next_page_token: str | None,
) -> tuple[list[str], str | None]:
assert self.jira_client, "Jira client not initialized."
enhanced_search_path = self.jira_client._get_url("search/jql")
params: dict[str, str | int | None] = {
"jql": jql,
"maxResults": _MAX_RESULTS_FETCH_IDS,
"nextPageToken": next_page_token,
"fields": "id",
}
response = self.jira_client._session.get(enhanced_search_path, params=params)
response.raise_for_status()
data = response.json()
return [str(issue["id"]) for issue in data.get("issues", [])], data.get("nextPageToken")
def _bulk_fetch_issues(
self,
issue_ids: list[str],
fields: str | None,
) -> Iterable[Issue]:
assert self.jira_client, "Jira client not initialized."
if not issue_ids:
return []
bulk_fetch_path = self.jira_client._get_url("issue/bulkfetch")
payload: dict[str, Any] = {"issueIdsOrKeys": issue_ids}
payload["fields"] = fields.split(",") if fields else ["*all"]
response = self.jira_client._session.post(bulk_fetch_path, json=payload)
response.raise_for_status()
data = response.json()
return [Issue(self.jira_client._options, self.jira_client._session, raw=issue) for issue in data.get("issues", [])]
@staticmethod
def _chunk_issue_ids(issue_ids: list[str], chunk_size: int) -> Iterable[list[str]]:
if chunk_size <= 0:
chunk_size = _JIRA_FULL_PAGE_SIZE
for idx in range(0, len(issue_ids), chunk_size):
yield issue_ids[idx : idx + chunk_size]
def _make_checkpoint_callback(self, checkpoint: JiraCheckpoint) -> Callable[[Iterable[list[str]], str | None], None]:
def checkpoint_callback(
issue_ids: Iterable[list[str]] | list[list[str]],
page_token: str | None,
) -> None:
for id_batch in issue_ids:
checkpoint.all_issue_ids.append(list(id_batch))
checkpoint.cursor = page_token
checkpoint.ids_done = page_token is None
return checkpoint_callback
def _update_checkpoint_for_next_run(
self,
*,
checkpoint: JiraCheckpoint,
current_offset: int,
starting_offset: int,
page_size: int,
) -> None:
if self._is_cloud_client():
checkpoint.has_more = bool(checkpoint.all_issue_ids) or not checkpoint.ids_done
else:
checkpoint.has_more = current_offset - starting_offset == page_size
checkpoint.cursor = None
checkpoint.ids_done = True
checkpoint.all_issue_ids = []
def iterate_jira_documents(
connector: "JiraConnector",
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
iteration_limit: int = 100_000,
) -> Iterator[Document]:
"""Yield documents without materializing the entire result set."""
checkpoint = connector.build_dummy_checkpoint()
iterations = 0
while checkpoint.has_more:
wrapper = CheckpointOutputWrapper[JiraCheckpoint]()
generator = wrapper(connector.load_from_checkpoint(start=start, end=end, checkpoint=checkpoint))
for document, failure, next_checkpoint in generator:
if failure is not None:
failure_message = getattr(failure, "failure_message", str(failure))
raise RuntimeError(f"Failed to load Jira documents: {failure_message}")
if document is not None:
yield document
if next_checkpoint is not None:
checkpoint = next_checkpoint
iterations += 1
if iterations > iteration_limit:
raise RuntimeError("Too many iterations while loading Jira documents.")
def test_jira(
*,
base_url: str,
project_key: str | None = None,
jql_query: str | None = None,
credentials: dict[str, Any],
batch_size: int = INDEX_BATCH_SIZE,
start_ts: float | None = None,
end_ts: float | None = None,
connector_options: dict[str, Any] | None = None,
) -> list[Document]:
"""Programmatic entry point that mirrors the CLI workflow."""
connector_kwargs = connector_options.copy() if connector_options else {}
connector = JiraConnector(
jira_base_url=base_url,
project_key=project_key,
jql_query=jql_query,
batch_size=batch_size,
**connector_kwargs,
)
connector.load_credentials(credentials)
connector.validate_connector_settings()
now_ts = datetime.now(timezone.utc).timestamp()
start = start_ts if start_ts is not None else 0.0
end = end_ts if end_ts is not None else now_ts
documents = list(iterate_jira_documents(connector, start=start, end=end))
logger.info(f"[Jira] Fetched {len(documents)} Jira documents.")
for doc in documents[:5]:
logger.info(f"[Jira] Document {doc.semantic_identifier} ({doc.id}) size={doc.size_bytes} bytes")
return documents
def _build_arg_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="Fetch Jira issues and print summary statistics.")
parser.add_argument("--base-url", dest="base_url", default=os.environ.get("JIRA_BASE_URL"))
parser.add_argument("--project", dest="project_key", default=os.environ.get("JIRA_PROJECT_KEY"))
parser.add_argument("--jql", dest="jql_query", default=os.environ.get("JIRA_JQL"))
parser.add_argument("--email", dest="user_email", default=os.environ.get("JIRA_USER_EMAIL"))
parser.add_argument("--token", dest="api_token", default=os.environ.get("JIRA_API_TOKEN"))
parser.add_argument("--password", dest="password", default=os.environ.get("JIRA_PASSWORD"))
parser.add_argument("--batch-size", dest="batch_size", type=int, default=int(os.environ.get("JIRA_BATCH_SIZE", INDEX_BATCH_SIZE)))
parser.add_argument("--include_comments", dest="include_comments", type=bool, default=True)
parser.add_argument("--include_attachments", dest="include_attachments", type=bool, default=True)
parser.add_argument("--attachment_size_limit", dest="attachment_size_limit", type=float, default=_DEFAULT_ATTACHMENT_SIZE_LIMIT)
parser.add_argument("--start-ts", dest="start_ts", type=float, default=None, help="Epoch seconds inclusive lower bound for updated issues.")
parser.add_argument("--end-ts", dest="end_ts", type=float, default=9999999999, help="Epoch seconds inclusive upper bound for updated issues.")
return parser
def main(config: dict[str, Any] | None = None) -> None:
if config is None:
args = _build_arg_parser().parse_args()
config = {
"base_url": args.base_url,
"project_key": args.project_key,
"jql_query": args.jql_query,
"batch_size": args.batch_size,
"start_ts": args.start_ts,
"end_ts": args.end_ts,
"include_comments": args.include_comments,
"include_attachments": args.include_attachments,
"attachment_size_limit": args.attachment_size_limit,
"credentials": {
"jira_user_email": args.user_email,
"jira_api_token": args.api_token,
"jira_password": args.password,
},
}
base_url = config.get("base_url")
credentials = config.get("credentials", {})
print(f"[Jira] {config=}", flush=True)
print(f"[Jira] {credentials=}", flush=True)
if not base_url:
raise RuntimeError("Jira base URL must be provided via config or CLI arguments.")
if not (credentials.get("jira_api_token") or (credentials.get("jira_user_email") and credentials.get("jira_password"))):
raise RuntimeError("Provide either an API token or both email/password for Jira authentication.")
connector_options = {
key: value
for key, value in (
("include_comments", config.get("include_comments")),
("include_attachments", config.get("include_attachments")),
("attachment_size_limit", config.get("attachment_size_limit")),
("labels_to_skip", config.get("labels_to_skip")),
("comment_email_blacklist", config.get("comment_email_blacklist")),
("scoped_token", config.get("scoped_token")),
("timezone_offset", config.get("timezone_offset")),
)
if value is not None
}
documents = test_jira(
base_url=base_url,
project_key=config.get("project_key"),
jql_query=config.get("jql_query"),
credentials=credentials,
batch_size=config.get("batch_size", INDEX_BATCH_SIZE),
start_ts=config.get("start_ts"),
end_ts=config.get("end_ts"),
connector_options=connector_options,
)
preview_count = min(len(documents), 5)
for idx in range(preview_count):
doc = documents[idx]
print(f"[Jira] [Sample {idx + 1}] {doc.semantic_identifier} | id={doc.id} | size={doc.size_bytes} bytes")
print(f"[Jira] Jira connector test completed. Documents fetched: {len(documents)}")
if __name__ == "__main__": # pragma: no cover - manual execution path
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s %(levelname)s %(name)s %(message)s")
main()
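
A hedged example of exercising the harness programmatically rather than through the CLI flags, using placeholder credentials for token auth against a Jira Cloud site:

```
docs = test_jira(
    base_url="https://example.atlassian.net",  # placeholder site
    project_key="ENG",                         # placeholder project
    credentials={
        "jira_user_email": "bot@example.com",
        "jira_api_token": "<api-token>",       # placeholder token
    },
    start_ts=0.0,
)
print(f"fetched {len(docs)} documents")
```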

View File

@ -0,0 +1,149 @@
"""Helper utilities for the Jira connector."""
from __future__ import annotations
import os
from collections.abc import Collection
from datetime import datetime, timezone
from typing import Any, Iterable
from jira.resources import Issue
from common.data_source.utils import datetime_from_string
JIRA_SERVER_API_VERSION = os.environ.get("JIRA_SERVER_API_VERSION", "2")
JIRA_CLOUD_API_VERSION = os.environ.get("JIRA_CLOUD_API_VERSION", "3")
def build_issue_url(base_url: str, issue_key: str) -> str:
"""Return the canonical UI URL for a Jira issue."""
return f"{base_url.rstrip('/')}/browse/{issue_key}"
def parse_jira_datetime(value: Any) -> datetime | None:
"""Best-effort parse of Jira datetime values to aware UTC datetimes."""
if value is None:
return None
if isinstance(value, datetime):
return value.astimezone(timezone.utc) if value.tzinfo else value.replace(tzinfo=timezone.utc)
if isinstance(value, str):
return datetime_from_string(value)
return None
def extract_named_value(value: Any) -> str | None:
"""Extract a readable string out of Jira's typed objects."""
if value is None:
return None
if isinstance(value, str):
return value
if isinstance(value, dict):
return value.get("name") or value.get("value")
return getattr(value, "name", None)
def extract_user(value: Any) -> tuple[str | None, str | None]:
"""Return display name + email tuple for a Jira user blob."""
if value is None:
return None, None
if isinstance(value, dict):
return value.get("displayName"), value.get("emailAddress")
display = getattr(value, "displayName", None)
email = getattr(value, "emailAddress", None)
return display, email
def extract_text_from_adf(adf: Any) -> str:
"""Flatten Atlassian Document Format (ADF) structures to text."""
texts: list[str] = []
def _walk(node: Any) -> None:
if node is None:
return
if isinstance(node, dict):
node_type = node.get("type")
if node_type == "text":
texts.append(node.get("text", ""))
for child in node.get("content", []):
_walk(child)
elif isinstance(node, list):
for child in node:
_walk(child)
_walk(adf)
return "\n".join(part for part in texts if part)
def extract_body_text(value: Any) -> str:
"""Normalize Jira description/comments (raw/adf/str) into plain text."""
if value is None:
return ""
if isinstance(value, str):
return value.strip()
if isinstance(value, dict):
return extract_text_from_adf(value).strip()
return str(value).strip()
def format_comments(
comment_block: Any,
*,
blacklist: Collection[str],
) -> str:
"""Convert Jira comments into a markdown-ish bullet list."""
if not isinstance(comment_block, dict):
return ""
comments = comment_block.get("comments") or []
lines: list[str] = []
normalized_blacklist = {email.lower() for email in blacklist if email}
for comment in comments:
author = comment.get("author") or {}
author_email = (author.get("emailAddress") or "").lower()
if author_email and author_email in normalized_blacklist:
continue
author_name = author.get("displayName") or author.get("name") or author_email or "Unknown"
created = parse_jira_datetime(comment.get("created"))
created_str = created.isoformat() if created else "Unknown time"
body = extract_body_text(comment.get("body"))
if not body:
continue
lines.append(f"- {author_name} ({created_str}):\n{body}")
return "\n\n".join(lines)
def format_attachments(attachments: Any) -> str:
"""List Jira attachments as bullet points."""
if not isinstance(attachments, list):
return ""
attachment_lines: list[str] = []
for attachment in attachments:
filename = attachment.get("filename")
if not filename:
continue
size = attachment.get("size")
size_text = f" ({size} bytes)" if isinstance(size, int) else ""
content_url = attachment.get("content") or ""
url_suffix = f" -> {content_url}" if content_url else ""
attachment_lines.append(f"- {filename}{size_text}{url_suffix}")
return "\n".join(attachment_lines)
def should_skip_issue(issue: Issue, labels_to_skip: set[str]) -> bool:
"""Return True if the issue contains any label from the skip list."""
if not labels_to_skip:
return False
fields = getattr(issue, "raw", {}).get("fields", {})
labels: Iterable[str] = fields.get("labels") or []
for label in labels:
if (label or "").lower() in labels_to_skip:
return True
return False

View File

@ -1,112 +0,0 @@
"""Jira connector"""
from typing import Any
from jira import JIRA
from common.data_source.config import INDEX_BATCH_SIZE
from common.data_source.exceptions import (
ConnectorValidationError,
InsufficientPermissionsError,
UnexpectedValidationError, ConnectorMissingCredentialError
)
from common.data_source.interfaces import (
CheckpointedConnectorWithPermSync,
SecondsSinceUnixEpoch,
SlimConnectorWithPermSync
)
from common.data_source.models import (
ConnectorCheckpoint
)
class JiraConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSync):
"""Jira connector for accessing Jira issues and projects"""
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.batch_size = batch_size
self.jira_client: JIRA | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
"""Load Jira credentials"""
try:
url = credentials.get("url")
username = credentials.get("username")
password = credentials.get("password")
token = credentials.get("token")
if not url:
raise ConnectorMissingCredentialError("Jira URL is required")
if token:
# API token authentication
self.jira_client = JIRA(server=url, token_auth=token)
elif username and password:
# Basic authentication
self.jira_client = JIRA(server=url, basic_auth=(username, password))
else:
raise ConnectorMissingCredentialError("Jira credentials are incomplete")
return None
except Exception as e:
raise ConnectorMissingCredentialError(f"Jira: {e}")
def validate_connector_settings(self) -> None:
"""Validate Jira connector settings"""
if not self.jira_client:
raise ConnectorMissingCredentialError("Jira")
try:
# Test connection by getting server info
self.jira_client.server_info()
except Exception as e:
if "401" in str(e) or "403" in str(e):
raise InsufficientPermissionsError("Invalid credentials or insufficient permissions")
elif "404" in str(e):
raise ConnectorValidationError("Jira instance not found")
else:
raise UnexpectedValidationError(f"Jira validation error: {e}")
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
"""Poll Jira for recent issues"""
# Simplified implementation - in production this would handle actual polling
return []
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> Any:
"""Load documents from checkpoint"""
# Simplified implementation
return []
def load_from_checkpoint_with_perm_sync(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> Any:
"""Load documents from checkpoint with permission sync"""
# Simplified implementation
return []
def build_dummy_checkpoint(self) -> ConnectorCheckpoint:
"""Build dummy checkpoint"""
return ConnectorCheckpoint()
def validate_checkpoint_json(self, checkpoint_json: str) -> ConnectorCheckpoint:
"""Validate checkpoint JSON"""
# Simplified implementation
return ConnectorCheckpoint()
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: Any = None,
) -> Any:
"""Retrieve all simplified documents with permission sync"""
# Simplified implementation
return []

View File

@ -48,17 +48,35 @@ from common.data_source.exceptions import RateLimitTriedTooManyTimesError
from common.data_source.interfaces import CT, CheckpointedConnector, CheckpointOutputWrapper, ConfluenceUser, LoadFunction, OnyxExtensionType, SecondsSinceUnixEpoch, TokenResponse
from common.data_source.models import BasicExpertInfo, Document
_TZ_SUFFIX_PATTERN = re.compile(r"([+-])([\d:]+)$")
def datetime_from_string(datetime_string: str) -> datetime:
datetime_string = datetime_string.strip()
match_jira_format = _TZ_SUFFIX_PATTERN.search(datetime_string)
if match_jira_format:
sign, tz_field = match_jira_format.groups()
digits = tz_field.replace(":", "")
if digits.isdigit() and 1 <= len(digits) <= 4:
if len(digits) >= 3:
hours = digits[:-2].rjust(2, "0")
minutes = digits[-2:]
else:
hours = digits.rjust(2, "0")
minutes = "00"
normalized = f"{sign}{hours}:{minutes}"
datetime_string = f"{datetime_string[: match_jira_format.start()]}{normalized}"
# Handle the case where the datetime string ends with 'Z' (Zulu time)
if datetime_string.endswith('Z'):
datetime_string = datetime_string[:-1] + '+00:00'
if datetime_string.endswith("Z"):
datetime_string = datetime_string[:-1] + "+00:00"
# Handle timezone format "+0000" -> "+00:00"
if datetime_string.endswith('+0000'):
datetime_string = datetime_string[:-5] + '+00:00'
if datetime_string.endswith("+0000"):
datetime_string = datetime_string[:-5] + "+00:00"
datetime_object = datetime.fromisoformat(datetime_string)
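
The added prefix handling widens the accepted offset spellings before `fromisoformat` runs. Expected behavior for a few Jira-style inputs (derived from the regex, not from a test run):

```
datetime_from_string("2024-05-01T10:00:00+0800")  # "+0800" normalized to "+08:00"
datetime_from_string("2024-05-01T10:00:00+800")   # "+800" padded to "+08:00"
datetime_from_string("2024-05-01T10:00:00Z")      # "Z" rewritten to "+00:00"
```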
@ -293,6 +311,13 @@ def create_s3_client(bucket_type: BlobType, credentials: dict[str, Any], europea
aws_secret_access_key=credentials["secret_access_key"],
region_name=credentials["region"],
)
elif bucket_type == BlobType.S3_COMPATIBLE:
return boto3.client(
"s3",
endpoint_url=credentials["endpoint_url"],
aws_access_key_id=credentials["aws_access_key_id"],
aws_secret_access_key=credentials["aws_secret_access_key"],
)
else:
raise ValueError(f"Unsupported bucket type: {bucket_type}")
@ -480,7 +505,7 @@ def get_file_ext(file_name: str) -> str:
def is_accepted_file_ext(file_ext: str, extension_type: OnyxExtensionType) -> bool:
image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}
image_extensions = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".webp"}
text_extensions = {".txt", ".md", ".mdx", ".conf", ".log", ".json", ".csv", ".tsv", ".xml", ".yml", ".yaml", ".sql"}
document_extensions = {".pdf", ".docx", ".pptx", ".xlsx", ".eml", ".epub", ".html"}
@ -902,6 +927,18 @@ def load_all_docs_from_checkpoint_connector(
)
_ATLASSIAN_CLOUD_DOMAINS = (".atlassian.net", ".jira.com", ".jira-dev.com")
def is_atlassian_cloud_url(url: str) -> bool:
try:
host = urlparse(url).hostname or ""
except ValueError:
return False
host = host.lower()
return any(host.endswith(domain) for domain in _ATLASSIAN_CLOUD_DOMAINS)
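
Illustrative checks for `is_atlassian_cloud_url`, following directly from the domain-suffix list above:

```
assert is_atlassian_cloud_url("https://acme.atlassian.net/wiki")
assert not is_atlassian_cloud_url("https://jira.internal.example.com")
```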
def get_cloudId(base_url: str) -> str:
tenant_info_url = urljoin(base_url, "/_edge/tenant_info")
response = requests.get(tenant_info_url, timeout=10)

View File

@ -80,4 +80,4 @@ def log_exception(e, *args):
raise Exception(a.text)
else:
logging.error(str(a))
raise e
raise e

View File

@ -21,7 +21,7 @@ import weakref
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FuturesTimeoutError
from string import Template
from typing import Any, Literal
from typing import Any, Literal, Protocol
from typing_extensions import override
@ -30,12 +30,15 @@ from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool
from rag.llm.chat_model import ToolCallSession
MCPTaskType = Literal["list_tools", "tool_call"]
MCPTask = tuple[MCPTaskType, dict[str, Any], asyncio.Queue[Any]]
class ToolCallSession(Protocol):
def tool_call(self, name: str, arguments: dict[str, Any]) -> str: ...
class MCPToolCallSession(ToolCallSession):
_ALL_INSTANCES: weakref.WeakSet["MCPToolCallSession"] = weakref.WeakSet()
@ -106,7 +109,8 @@ class MCPToolCallSession(ToolCallSession):
await self._process_mcp_tasks(None, msg)
else:
await self._process_mcp_tasks(None, f"Unsupported MCP server type: {self._mcp_server.server_type}, id: {self._mcp_server.id}")
await self._process_mcp_tasks(None,
f"Unsupported MCP server type: {self._mcp_server.server_type}, id: {self._mcp_server.id}")
async def _process_mcp_tasks(self, client_session: ClientSession | None, error_message: str | None = None) -> None:
while not self._close:
@@ -164,7 +168,8 @@ class MCPToolCallSession(ToolCallSession):
raise
async def _call_mcp_tool(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> str:
-result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments, timeout=timeout)
+result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments,
+                                                     timeout=timeout)
if result.isError:
return f"MCP server error: {result.content}"
@@ -283,7 +288,8 @@ def close_multiple_mcp_toolcall_sessions(sessions: list[MCPToolCallSession]) ->
except Exception:
logging.exception("Exception during MCP session cleanup thread management")
logging.info(f"{len(sessions)} MCP sessions has been cleaned up. {len(list(MCPToolCallSession._ALL_INSTANCES))} in global context.")
logging.info(
f"{len(sessions)} MCP sessions has been cleaned up. {len(list(MCPToolCallSession._ALL_INSTANCES))} in global context.")
def shutdown_all_mcp_sessions():
@@ -298,7 +304,7 @@ def shutdown_all_mcp_sessions():
logging.info("All MCPToolCallSession instances have been closed.")
-def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool|dict) -> dict[str, Any]:
+def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool | dict) -> dict[str, Any]:
if isinstance(mcp_tool, dict):
return {
"type": "function",

View File

@@ -1429,6 +1429,13 @@
"status": "1",
"rank": "980",
"llm": [
{
"llm_name": "gemini-3-pro-preview",
"tags": "LLM,CHAT,1M,IMAGE2TEXT",
"max_tokens": 1048576,
"model_type": "image2text",
"is_tools": true
},
{
"llm_name": "gemini-2.5-flash",
"tags": "LLM,CHAT,1024K,IMAGE2TEXT",
@@ -4839,6 +4846,639 @@
"is_tools": false
}
]
},
{
"name": "Jiekou.AI",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,TEXT RE-RANK",
"status": "1",
"llm": [
{
"llm_name": "Sao10K/L3-8B-Stheno-v3.2",
"tags": "LLM,CHAT,8K",
"max_tokens": 8192,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "baichuan/baichuan-m2-32b",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "baidu/ernie-4.5-300b-a47b-paddle",
"tags": "LLM,CHAT,123K",
"max_tokens": 123000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "baidu/ernie-4.5-vl-424b-a47b",
"tags": "LLM,CHAT,123K",
"max_tokens": 123000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "claude-3-5-haiku-20241022",
"tags": "LLM,CHAT,200K",
"max_tokens": 200000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "claude-3-5-sonnet-20241022",
"tags": "LLM,CHAT,200K",
"max_tokens": 200000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "claude-3-7-sonnet-20250219",
"tags": "LLM,CHAT,200K",
"max_tokens": 200000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "claude-3-haiku-20240307",
"tags": "LLM,CHAT,200K",
"max_tokens": 200000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "claude-haiku-4-5-20251001",
"tags": "LLM,CHAT,20K,IMAGE2TEXT",
"max_tokens": 20000,
"model_type": "image2text",
"is_tools": true
},
{
"llm_name": "claude-opus-4-1-20250805",
"tags": "LLM,CHAT,200K",
"max_tokens": 200000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "claude-opus-4-20250514",
"tags": "LLM,CHAT,200K",
"max_tokens": 200000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "claude-sonnet-4-20250514",
"tags": "LLM,CHAT,200K",
"max_tokens": 200000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "claude-sonnet-4-5-20250929",
"tags": "LLM,CHAT,200K,IMAGE2TEXT",
"max_tokens": 200000,
"model_type": "image2text",
"is_tools": true
},
{
"llm_name": "deepseek/deepseek-r1-0528",
"tags": "LLM,CHAT,163K",
"max_tokens": 163840,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "deepseek/deepseek-v3-0324",
"tags": "LLM,CHAT,163K",
"max_tokens": 163840,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "deepseek/deepseek-v3.1",
"tags": "LLM,CHAT,163K",
"max_tokens": 163840,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "doubao-1-5-pro-32k-250115",
"tags": "LLM,CHAT,128K",
"max_tokens": 128000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "doubao-1.5-pro-32k-character-250715",
"tags": "LLM,CHAT,200K",
"max_tokens": 200000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gemini-2.0-flash-20250609",
"tags": "LLM,CHAT,1M",
"max_tokens": 1048576,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gemini-2.0-flash-lite",
"tags": "LLM,CHAT,1M",
"max_tokens": 1048576,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gemini-2.5-flash",
"tags": "LLM,CHAT,1M",
"max_tokens": 1048576,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gemini-2.5-flash-lite",
"tags": "LLM,CHAT,1M",
"max_tokens": 1048576,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gemini-2.5-flash-lite-preview-06-17",
"tags": "LLM,CHAT,1M",
"max_tokens": 1048576,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gemini-2.5-flash-lite-preview-09-2025",
"tags": "LLM,CHAT,1M,IMAGE2TEXT",
"max_tokens": 1048576,
"model_type": "image2text",
"is_tools": true
},
{
"llm_name": "gemini-2.5-flash-preview-05-20",
"tags": "LLM,CHAT,1M",
"max_tokens": 1048576,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gemini-2.5-pro",
"tags": "LLM,CHAT,1M",
"max_tokens": 1048576,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gemini-2.5-pro-preview-06-05",
"tags": "LLM,CHAT,1M",
"max_tokens": 1048576,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "google/gemma-3-12b-it",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "google/gemma-3-27b-it",
"tags": "LLM,CHAT,32K",
"max_tokens": 32768,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gpt-4.1",
"tags": "LLM,CHAT,1M",
"max_tokens": 1047576,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gpt-4.1-mini",
"tags": "LLM,CHAT,1M",
"max_tokens": 1047576,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gpt-4.1-nano",
"tags": "LLM,CHAT,1M",
"max_tokens": 1047576,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gpt-4o",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gpt-4o-mini",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gpt-5",
"tags": "LLM,CHAT,400K",
"max_tokens": 400000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gpt-5-chat-latest",
"tags": "LLM,CHAT,400K",
"max_tokens": 400000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gpt-5-codex",
"tags": "LLM,CHAT,400K,IMAGE2TEXT",
"max_tokens": 400000,
"model_type": "image2text",
"is_tools": true
},
{
"llm_name": "gpt-5-mini",
"tags": "LLM,CHAT,400K",
"max_tokens": 400000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gpt-5-nano",
"tags": "LLM,CHAT,400K",
"max_tokens": 400000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gpt-5-pro",
"tags": "LLM,CHAT,400K,IMAGE2TEXT",
"max_tokens": 400000,
"model_type": "image2text",
"is_tools": true
},
{
"llm_name": "gpt-5.1",
"tags": "LLM,CHAT,400K",
"max_tokens": 400000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gpt-5.1-chat-latest",
"tags": "LLM,CHAT,128K",
"max_tokens": 128000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gpt-5.1-codex",
"tags": "LLM,CHAT,400K",
"max_tokens": 400000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "grok-3",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "grok-3-mini",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "grok-4-0709",
"tags": "LLM,CHAT,256K",
"max_tokens": 256000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "grok-4-fast-non-reasoning",
"tags": "LLM,CHAT,2M,IMAGE2TEXT",
"max_tokens": 2000000,
"model_type": "image2text",
"is_tools": true
},
{
"llm_name": "grok-4-fast-reasoning",
"tags": "LLM,CHAT,2M,IMAGE2TEXT",
"max_tokens": 2000000,
"model_type": "image2text",
"is_tools": true
},
{
"llm_name": "grok-code-fast-1",
"tags": "LLM,CHAT,256K",
"max_tokens": 256000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gryphe/mythomax-l2-13b",
"tags": "LLM,CHAT,4K",
"max_tokens": 4096,
"model_type": "chat",
"is_tools": false
},
{
"llm_name": "meta-llama/llama-3.1-8b-instruct",
"tags": "LLM,CHAT,16K",
"max_tokens": 16384,
"model_type": "chat",
"is_tools": false
},
{
"llm_name": "meta-llama/llama-3.2-3b-instruct",
"tags": "LLM,CHAT,32K",
"max_tokens": 32768,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "meta-llama/llama-3.3-70b-instruct",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "meta-llama/llama-4-maverick-17b-128e-instruct-fp8",
"tags": "LLM,CHAT,1M",
"max_tokens": 1048576,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "meta-llama/llama-4-scout-17b-16e-instruct",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "minimaxai/minimax-m1-80k",
"tags": "LLM,CHAT,1M",
"max_tokens": 1000000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "mistralai/mistral-7b-instruct",
"tags": "LLM,CHAT,32K",
"max_tokens": 32768,
"model_type": "chat",
"is_tools": false
},
{
"llm_name": "mistralai/mistral-nemo",
"tags": "LLM,CHAT,60K",
"max_tokens": 60288,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "moonshotai/kimi-k2-0905",
"tags": "LLM,CHAT,262K",
"max_tokens": 262144,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "moonshotai/kimi-k2-instruct",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "o1",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "o1-mini",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "o3",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "o3-mini",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "openai/gpt-oss-120b",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "openai/gpt-oss-20b",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "qwen/qwen-2.5-72b-instruct",
"tags": "LLM,CHAT,32K",
"max_tokens": 32000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "qwen/qwen-mt-plus",
"tags": "LLM,CHAT,4K",
"max_tokens": 4096,
"model_type": "chat",
"is_tools": false
},
{
"llm_name": "qwen/qwen2.5-7b-instruct",
"tags": "LLM,CHAT,32K",
"max_tokens": 32000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "qwen/qwen2.5-vl-72b-instruct",
"tags": "LLM,CHAT,32K",
"max_tokens": 32768,
"model_type": "chat",
"is_tools": false
},
{
"llm_name": "qwen/qwen3-235b-a22b-fp8",
"tags": "LLM,CHAT,40K",
"max_tokens": 40960,
"model_type": "chat",
"is_tools": false
},
{
"llm_name": "qwen/qwen3-235b-a22b-instruct-2507",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "qwen/qwen3-235b-a22b-thinking-2507",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "qwen/qwen3-30b-a3b-fp8",
"tags": "LLM,CHAT,40K",
"max_tokens": 40960,
"model_type": "chat",
"is_tools": false
},
{
"llm_name": "qwen/qwen3-32b-fp8",
"tags": "LLM,CHAT,40K",
"max_tokens": 40960,
"model_type": "chat",
"is_tools": false
},
{
"llm_name": "qwen/qwen3-8b-fp8",
"tags": "LLM,CHAT,128K",
"max_tokens": 128000,
"model_type": "chat",
"is_tools": false
},
{
"llm_name": "qwen/qwen3-coder-480b-a35b-instruct",
"tags": "LLM,CHAT,262K",
"max_tokens": 262144,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "qwen/qwen3-next-80b-a3b-instruct",
"tags": "LLM,CHAT,65K",
"max_tokens": 65536,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "qwen/qwen3-next-80b-a3b-thinking",
"tags": "LLM,CHAT,65K",
"max_tokens": 65536,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "sao10k/l3-70b-euryale-v2.1",
"tags": "LLM,CHAT,8K",
"max_tokens": 8192,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "sao10k/l3-8b-lunaris",
"tags": "LLM,CHAT,8K",
"max_tokens": 8192,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "sao10k/l31-70b-euryale-v2.2",
"tags": "LLM,CHAT,8K",
"max_tokens": 8192,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "thudm/glm-4.1v-9b-thinking",
"tags": "LLM,CHAT,65K",
"max_tokens": 65536,
"model_type": "chat",
"is_tools": false
},
{
"llm_name": "zai-org/glm-4.5",
"tags": "LLM,CHAT,131K",
"max_tokens": 131072,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "zai-org/glm-4.5v",
"tags": "LLM,CHAT,65K",
"max_tokens": 65536,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "baai/bge-m3",
"tags": "TEXT EMBEDDING,8K",
"max_tokens": 8192,
"model_type": "embedding"
},
{
"llm_name": "qwen/qwen3-embedding-0.6b",
"tags": "TEXT EMBEDDING,32K",
"max_tokens": 32768,
"model_type": "embedding"
},
{
"llm_name": "qwen/qwen3-embedding-8b",
"tags": "TEXT EMBEDDING,32K",
"max_tokens": 32768,
"model_type": "embedding"
},
{
"llm_name": "baai/bge-reranker-v2-m3",
"tags": "RE-RANK,8K",
"max_tokens": 8000,
"model_type": "reranker"
},
{
"llm_name": "qwen/qwen3-reranker-8b",
"tags": "RE-RANK,32K",
"max_tokens": 32768,
"model_type": "reranker"
}
]
}
]
}
}

View File

@@ -61,7 +61,9 @@ class DoclingParser(RAGFlowPdfParser):
self.page_images: list[Image.Image] = []
self.page_from = 0
self.page_to = 10_000
self.outlines = []
def check_installation(self) -> bool:
if DocumentConverter is None:
self.logger.warning("[Docling] 'docling' is not importable, please: pip install docling")
@@ -186,9 +188,6 @@
yield (DoclingContentType.EQUATION.value, text, bbox)
def _transfer_to_sections(self, doc) -> list[tuple[str, str]]:
"""
和 MinerUParser 保持一致:返回 [(section_text, line_tag), ...]
"""
sections: list[tuple[str, str]] = []
for typ, payload, bbox in self._iter_doc_items(doc):
if typ == DoclingContentType.TEXT.value:

View File

@@ -34,6 +34,7 @@ def vision_figure_parser_figure_data_wrapper(figures_data_without_positions):
if isinstance(figure_data[1], Image.Image)
]
def vision_figure_parser_docx_wrapper(sections,tbls,callback=None,**kwargs):
try:
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
@@ -50,7 +51,8 @@ def vision_figure_parser_docx_wrapper(sections,tbls,callback=None,**kwargs):
callback(0.8, f"Visual model error: {e}. Skipping figure parsing enhancement.")
return tbls
-def vision_figure_parser_pdf_wrapper(tbls,callback=None,**kwargs):
+def vision_figure_parser_pdf_wrapper(tbls, callback=None, **kwargs):
try:
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
callback(0.7, "Visual model detected. Attempting to enhance figure extraction...")
@@ -72,6 +74,7 @@ def vision_figure_parser_pdf_wrapper(tbls,callback=None,**kwargs):
callback(0.8, f"Visual model error: {e}. Skipping figure parsing enhancement.")
return tbls
shared_executor = ThreadPoolExecutor(max_workers=10)

View File

@@ -59,6 +59,7 @@ class MinerUParser(RAGFlowPdfParser):
self.mineru_api = mineru_api.rstrip("/")
self.mineru_server_url = mineru_server_url.rstrip("/")
self.using_api = False
self.outlines = []
self.logger = logging.getLogger(self.__class__.__name__)
def _extract_zip_no_root(self, zip_path, extract_to, root_dir):
@@ -317,7 +318,7 @@ class MinerUParser(RAGFlowPdfParser):
def _line_tag(self, bx):
pn = [bx["page_idx"] + 1]
positions = bx["bbox"]
positions = bx.get("bbox", (0, 0, 0, 0))
x0, top, x1, bott = positions
if hasattr(self, "page_images") and self.page_images and len(self.page_images) > bx["page_idx"]:
@@ -337,12 +338,54 @@ class MinerUParser(RAGFlowPdfParser):
return None, None
return
if not getattr(self, "page_images", None):
self.logger.warning("[MinerU] crop called without page images; skipping image generation.")
if need_position:
return None, None
return
page_count = len(self.page_images)
filtered_poss = []
for pns, left, right, top, bottom in poss:
if not pns:
self.logger.warning("[MinerU] Empty page index list in crop; skipping this position.")
continue
valid_pns = [p for p in pns if 0 <= p < page_count]
if not valid_pns:
self.logger.warning(f"[MinerU] All page indices {pns} out of range for {page_count} pages; skipping.")
continue
filtered_poss.append((valid_pns, left, right, top, bottom))
poss = filtered_poss
if not poss:
self.logger.warning("[MinerU] No valid positions after filtering; skip cropping.")
if need_position:
return None, None
return
max_width = max(np.max([right - left for (_, left, right, _, _) in poss]), 6)
GAP = 6
pos = poss[0]
-poss.insert(0, ([pos[0][0]], pos[1], pos[2], max(0, pos[3] - 120), max(pos[3] - GAP, 0)))
+first_page_idx = pos[0][0]
+poss.insert(0, ([first_page_idx], pos[1], pos[2], max(0, pos[3] - 120), max(pos[3] - GAP, 0)))
pos = poss[-1]
-poss.append(([pos[0][-1]], pos[1], pos[2], min(self.page_images[pos[0][-1]].size[1], pos[4] + GAP), min(self.page_images[pos[0][-1]].size[1], pos[4] + 120)))
+last_page_idx = pos[0][-1]
+if not (0 <= last_page_idx < page_count):
+    self.logger.warning(f"[MinerU] Last page index {last_page_idx} out of range for {page_count} pages; skipping crop.")
+    if need_position:
+        return None, None
+    return
+last_page_height = self.page_images[last_page_idx].size[1]
+poss.append(
+    (
+        [last_page_idx],
+        pos[1],
+        pos[2],
+        min(last_page_height, pos[4] + GAP),
+        min(last_page_height, pos[4] + 120),
+    )
+)
positions = []
for ii, (pns, left, right, top, bottom) in enumerate(poss):
@@ -352,7 +395,14 @@ class MinerUParser(RAGFlowPdfParser):
bottom = top + 2
for pn in pns[1:]:
-    bottom += self.page_images[pn - 1].size[1]
+    if 0 <= pn - 1 < page_count:
+        bottom += self.page_images[pn - 1].size[1]
+    else:
+        self.logger.warning(f"[MinerU] Page index {pn}-1 out of range for {page_count} pages during crop; skipping height accumulation.")
+if not (0 <= pns[0] < page_count):
+    self.logger.warning(f"[MinerU] Base page index {pns[0]} out of range for {page_count} pages during crop; skipping this segment.")
+    continue
img0 = self.page_images[pns[0]]
x0, y0, x1, y1 = int(left), int(top), int(right), int(min(bottom, img0.size[1]))
@@ -363,6 +413,9 @@ class MinerUParser(RAGFlowPdfParser):
bottom -= img0.size[1]
for pn in pns[1:]:
if not (0 <= pn < page_count):
self.logger.warning(f"[MinerU] Page index {pn} out of range for {page_count} pages during crop; skipping this page.")
continue
page = self.page_images[pn]
x0, y0, x1, y1 = int(left), 0, int(right), int(min(bottom, page.size[1]))
cimgp = page.crop((x0, y0, x1, y1))
@@ -434,7 +487,7 @@ class MinerUParser(RAGFlowPdfParser):
if not section.strip():
section = "FAILED TO PARSE TABLE"
case MinerUContentType.IMAGE:
section = "".join(output["image_caption"]) + "\n" + "".join(output["image_footnote"])
section = "".join(output.get("image_caption", [])) + "\n" + "".join(output.get("image_footnote", []))
case MinerUContentType.EQUATION:
section = output["text"]
case MinerUContentType.CODE:
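Both changes in this file apply the same defensive pattern: MinerU output blocks do not always carry `bbox`, `image_caption`, or `image_footnote`, so direct indexing raises `KeyError` mid-parse, while `dict.get` with a neutral default keeps going. A self-contained illustration (the sample block is hypothetical):

```python
output = {"type": "image"}  # hypothetical MinerU block missing caption/footnote/bbox

# Direct indexing (output["image_caption"]) would raise KeyError here.
caption = "".join(output.get("image_caption", []))
footnote = "".join(output.get("image_footnote", []))
bbox = output.get("bbox", (0, 0, 0, 0))

print(repr(caption + "\n" + footnote), bbox)  # '\n' (0, 0, 0, 0)
```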

View File

@@ -1252,24 +1252,77 @@ class RAGFlowPdfParser:
return None, None
return
if not getattr(self, "page_images", None):
logging.warning("crop called without page images; skipping image generation.")
if need_position:
return None, None
return
page_count = len(self.page_images)
filtered_poss = []
for pns, left, right, top, bottom in poss:
if not pns:
logging.warning("Empty page index list in crop; skipping this position.")
continue
valid_pns = [p for p in pns if 0 <= p < page_count]
if not valid_pns:
logging.warning(f"All page indices {pns} out of range for {page_count} pages; skipping.")
continue
filtered_poss.append((valid_pns, left, right, top, bottom))
poss = filtered_poss
if not poss:
logging.warning("No valid positions after filtering; skip cropping.")
if need_position:
return None, None
return
max_width = max(np.max([right - left for (_, left, right, _, _) in poss]), 6)
GAP = 6
pos = poss[0]
-poss.insert(0, ([pos[0][0]], pos[1], pos[2], max(0, pos[3] - 120), max(pos[3] - GAP, 0)))
+first_page_idx = pos[0][0]
+poss.insert(0, ([first_page_idx], pos[1], pos[2], max(0, pos[3] - 120), max(pos[3] - GAP, 0)))
pos = poss[-1]
-poss.append(([pos[0][-1]], pos[1], pos[2], min(self.page_images[pos[0][-1]].size[1] / ZM, pos[4] + GAP), min(self.page_images[pos[0][-1]].size[1] / ZM, pos[4] + 120)))
+last_page_idx = pos[0][-1]
+if not (0 <= last_page_idx < page_count):
+    logging.warning(f"Last page index {last_page_idx} out of range for {page_count} pages; skipping crop.")
+    if need_position:
+        return None, None
+    return
+last_page_height = self.page_images[last_page_idx].size[1] / ZM
+poss.append(
+    (
+        [last_page_idx],
+        pos[1],
+        pos[2],
+        min(last_page_height, pos[4] + GAP),
+        min(last_page_height, pos[4] + 120),
+    )
+)
positions = []
for ii, (pns, left, right, top, bottom) in enumerate(poss):
right = left + max_width
bottom *= ZM
for pn in pns[1:]:
-    bottom += self.page_images[pn - 1].size[1]
+    if 0 <= pn - 1 < page_count:
+        bottom += self.page_images[pn - 1].size[1]
+    else:
+        logging.warning(f"Page index {pn}-1 out of range for {page_count} pages during crop; skipping height accumulation.")
+if not (0 <= pns[0] < page_count):
+    logging.warning(f"Base page index {pns[0]} out of range for {page_count} pages during crop; skipping this segment.")
+    continue
imgs.append(self.page_images[pns[0]].crop((left * ZM, top * ZM, right * ZM, min(bottom, self.page_images[pns[0]].size[1]))))
if 0 < ii < len(poss) - 1:
positions.append((pns[0] + self.page_from, left, right, top, min(bottom, self.page_images[pns[0]].size[1]) / ZM))
bottom -= self.page_images[pns[0]].size[1]
for pn in pns[1:]:
if not (0 <= pn < page_count):
logging.warning(f"Page index {pn} out of range for {page_count} pages during crop; skipping this page.")
continue
imgs.append(self.page_images[pn].crop((left * ZM, 0, right * ZM, min(bottom, self.page_images[pn].size[1]))))
if 0 < ii < len(poss) - 1:
positions.append((pn + self.page_from, left, right, 0, min(bottom, self.page_images[pn].size[1]) / ZM))
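This rewrite and the MinerU one above share the same guard: validate every page index against the actual page count before any cropping math runs. The core filtering step, extracted as a hedged standalone sketch:

```python
import logging

def filter_positions(poss, page_count):
    # Drop page indices outside [0, page_count); skip positions left empty.
    filtered = []
    for pns, left, right, top, bottom in poss:
        valid = [p for p in pns if 0 <= p < page_count]
        if not valid:
            logging.warning(f"All page indices {pns} out of range for {page_count} pages; skipping.")
            continue
        filtered.append((valid, left, right, top, bottom))
    return filtered

print(filter_positions([([0, 5], 0, 100, 0, 50)], page_count=3))  # [([0], 0, 100, 0, 50)]
```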

View File

@@ -47,6 +47,7 @@ class TencentCloudAPIClient:
self.secret_id = secret_id
self.secret_key = secret_key
self.region = region
self.outlines = []
# Create credentials
self.cred = credential.Credential(secret_id, secret_key)

View File

@@ -117,7 +117,6 @@ def load_model(model_dir, nm, device_id: int | None = None):
providers=['CUDAExecutionProvider'],
provider_options=[cuda_provider_options]
)
-run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "gpu:" + str(provider_device_id))
logging.info(f"load_model {model_file_path} uses GPU (device {provider_device_id}, gpu_mem_limit={cuda_provider_options['gpu_mem_limit']}, arena_strategy={arena_strategy})")
else:
sess = ort.InferenceSession(
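For orientation, a hedged sketch of the CUDA-provider session setup this hunk trims; the model path and option values are hypothetical, and the removed line had additionally requested per-run GPU arena shrinkage:

```python
import onnxruntime as ort  # assumes onnxruntime-gpu is installed

# Hypothetical provider options mirroring the variables seen above.
cuda_provider_options = {
    "device_id": 0,
    "gpu_mem_limit": 2 * 1024 * 1024 * 1024,  # 2 GiB
    "arena_extend_strategy": "kSameAsRequested",
}

sess = ort.InferenceSession(
    "model.onnx",
    providers=["CUDAExecutionProvider"],
    provider_options=[cuda_provider_options],
)
```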

View File

@@ -106,11 +106,11 @@ ADMIN_SVR_HTTP_PORT=9381
SVR_MCP_PORT=9382
# The RAGFlow Docker image to download. v0.22+ doesn't include embedding models.
-RAGFLOW_IMAGE=infiniflow/ragflow:v0.22.0
+RAGFLOW_IMAGE=infiniflow/ragflow:v0.22.1
# If you cannot download the RAGFlow Docker image:
-# RAGFLOW_IMAGE=swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow:v0.22.0
-# RAGFLOW_IMAGE=registry.cn-hangzhou.aliyuncs.com/infiniflow/ragflow:v0.22.0
+# RAGFLOW_IMAGE=swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow:v0.22.1
+# RAGFLOW_IMAGE=registry.cn-hangzhou.aliyuncs.com/infiniflow/ragflow:v0.22.1
#
# - For the `nightly` edition, uncomment either of the following:
# RAGFLOW_IMAGE=swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow:nightly

View File

@@ -77,7 +77,7 @@ The [.env](./.env) file contains important environment variables for Docker.
- `SVR_HTTP_PORT`
The port used to expose RAGFlow's HTTP API service to the host machine, allowing **external** access to the service running inside the Docker container. Defaults to `9380`.
- `RAGFLOW-IMAGE`
-The Docker image edition. Defaults to `infiniflow/ragflow:v0.22.0`. The RAGFlow Docker image does not include embedding models.
+The Docker image edition. Defaults to `infiniflow/ragflow:v0.22.1`. The RAGFlow Docker image does not include embedding models.
> [!TIP]

View File

@@ -71,7 +71,7 @@ for arg in "$@"; do
ENABLE_TASKEXECUTOR=0
shift
;;
---disable-datasyn)
+--disable-datasync)
ENABLE_DATASYNC=0
shift
;;

View File

@@ -97,7 +97,7 @@ RAGFlow utilizes MinIO as its object storage solution, leveraging its scalabilit
- `SVR_HTTP_PORT`
The port used to expose RAGFlow's HTTP API service to the host machine, allowing **external** access to the service running inside the Docker container. Defaults to `9380`.
- `RAGFLOW-IMAGE`
-The Docker image edition. Defaults to `infiniflow/ragflow:v0.22.0` (the RAGFlow Docker image without embedding models).
+The Docker image edition. Defaults to `infiniflow/ragflow:v0.22.1` (the RAGFlow Docker image without embedding models).
:::tip NOTE
If you cannot download the RAGFlow Docker image, try the following mirrors.

View File

@@ -47,7 +47,7 @@ After building the infiniflow/ragflow:nightly image, you are ready to launch a f
1. Edit Docker Compose Configuration
-Open the `docker/.env` file. Find the `RAGFLOW_IMAGE` setting and change the image reference from `infiniflow/ragflow:v0.22.0` to `infiniflow/ragflow:nightly` to use the pre-built image.
+Open the `docker/.env` file. Find the `RAGFLOW_IMAGE` setting and change the image reference from `infiniflow/ragflow:v0.22.1` to `infiniflow/ragflow:nightly` to use the pre-built image.
2. Launch the Service

View File

@@ -12,6 +12,10 @@ The RAGFlow Admin UI is a web-based interface that provides comprehensive system
To access the RAGFlow admin UI, append `/admin` to the web UI's address, e.g. `http://[RAGFLOW_WEB_UI_ADDR]/admin`, replacing `[RAGFLOW_WEB_UI_ADDR]` with your actual RAGFlow web UI address.
### Default Credentials
| Username | Password |
|----------|----------|
| `admin@ragflow.io` | `admin` |
## Admin UI Overview

Some files were not shown because too many files have changed in this diff.