Compare commits

...

217 Commits

Author SHA1 Message Date
7c9b6e032b Fix: The minimum size of the historical message window for the classification operator is 1. #12778 (#12779)
### What problem does this PR solve?

Fix: The minimum size of the historical message window for the
classification operator is 1. #12778

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-22 19:45:25 +08:00
3beb85efa0 Feat: enhance metadata arranging. (#12745)
### What problem does this PR solve?
#11564

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2026-01-22 15:34:08 +08:00
bc7b864a6c top_k parameter ignored, always returned page_size results (#12753)
### What problem does this PR solve?
**Backend**
rag/nlp/search.py
*Before the fix*
The top_k parameter was not applied to limit the total number of chunks, and the rerank model used the entire valid_idx rather than first truncating it with valid_idx = valid_idx[:top].
*After the fix*
The top_k limit is applied to the total results before pagination, using a default of top = 1024 when top_k is not specified.
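In effect (a hypothetical helper capturing the described behavior, not the verbatim search.py code):

```python
def apply_top_k(valid_idx: list, top_k: int | None) -> list:
    top = top_k or 1024      # default when top_k is not supplied
    return valid_idx[:top]   # truncate before reranking and pagination
```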

session.py
*Before the fix:*
When the frontend calls the retrieval API with `search_id`, the backend
only reads `meta_data_filter` from the saved `search_config`. The
`rerank_id`, `top_k`, `similarity_threshold`, and
`vector_similarity_weight` parameters are only taken from the direct
request body. Since the frontend doesn't pass these parameters
explicitly (it only passes `search_id`), they always fall back to
default values:
- `similarity_threshold` = 0.0
- `vector_similarity_weight` = 0.3
- `top_k` = 1024
- `rerank_id` = "" (no rerank)
This means user settings saved in the Search Settings page have no
effect on actual search results.

*After the fix:*
When a `search_id` is provided, the backend now reads all relevant
configuration from the saved `search_config`, including `rerank_id`,
`top_k`, `similarity_threshold`, and `vector_similarity_weight`. Request
parameters can still override these values if explicitly provided,
allowing flexibility. The rerank model is now properly instantiated
using the configured `rerank_id`, making the rerank feature actually
work.
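
A minimal sketch of the resolution order described, assuming `req` is the request body and `search_config` holds the saved settings (names are illustrative):

```python
def resolve_retrieval_params(req: dict, search_config: dict) -> dict:
    """Request body overrides saved search_config, which overrides defaults."""
    defaults = {
        "similarity_threshold": 0.0,
        "vector_similarity_weight": 0.3,
        "top_k": 1024,
        "rerank_id": "",  # empty string means no rerank
    }
    return {
        key: req.get(key, search_config.get(key, default))
        for key, default in defaults.items()
    }
```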



**Frontend** 
web/src/pages/next-search/search-setting.tsx
*Before the fix*
In search-setting.tsx, the top_k input box is only displayed when rerank is enabled (it is wrapped in the rerankModelDisabled condition). If the rerank switch is turned off, the top_k input field is hidden, but the form value remains unchanged. In other words:
- When rerank is enabled, users can modify top_k (default 1024).
- When rerank is disabled, top_k retains its previous value but is not visible in the interface.

The backend therefore always receives the top_k parameter; the frontend UI merely ties this configuration item to the rerank switch, and turning rerank off does not reset top_k to 1024.
*After the fix*
Switching off the rerank model now resets top_k to 1024. In addition, top_k is handled in its own method rather than inside the retrieval method, so it can be controlled separately.



Now all methods behave correctly.
Using rerank

<img width="2378" height="1565" alt="Screenshot 2026-01-21 190206"
src="https://github.com/user-attachments/assets/fa2b0df0-1334-4ca3-b169-da6c5fd59935"
/>

Not using rerank
<img width="2596" height="1559" alt="Screenshot 2026-01-21 190229"
src="https://github.com/user-attachments/assets/c5a80522-a0e1-40e7-b349-42fe86df3138"
/>




Before the fix, both results were identical.

### Type of change
- Bug Fix (non-breaking change which fixes an issue)
2026-01-22 15:33:42 +08:00
93091f4551 [Feat]Automatic table orientation detection and correction (#12719)
### What problem does this PR solve?
This PR introduces automatic table orientation detection and correction
within the PDF parser. This ensures that tables in PDFs are correctly
oriented before structure recognition, improving overall parsing
accuracy.

### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- [x] Documentation Update
2026-01-22 12:47:55 +08:00
2d9e7b4acd Fix: aliyun oss need to use s3 signature_version (#12766)
### What problem does this PR solve?

Aliyun OSS does not support boto3's s3v4 signature_version, which leads to
an error:

```
botocore.exceptions.ClientError: An error occurred (InvalidArgument) when calling the PutObject operation: aws-chunked encoding is not supported with the specified x-amz-content-sha256 value
```

According to the Aliyun OSS docs, oss_conn needs to use the s3 signature_version.
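
For reference, a minimal boto3 client configured this way might look like the following (endpoint and credentials are placeholders, not the actual oss_conn code):

```python
import boto3
from botocore.client import Config

# Aliyun OSS rejects v4 signing for this flow, so force the legacy "s3" signer
oss_conn = boto3.client(
    "s3",
    endpoint_url="https://oss-cn-hangzhou.aliyuncs.com",  # placeholder endpoint
    aws_access_key_id="YOUR_ACCESS_KEY",
    aws_secret_access_key="YOUR_SECRET_KEY",
    config=Config(signature_version="s3"),
)
```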

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-22 11:43:55 +08:00
6f3f69b62e Feat: API adds audio to text and text to speech functions (#12764)
### What problem does this PR solve?

API adds audio to text and text to speech functions

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2026-01-22 11:20:26 +08:00
bfd5435087 Fix: After deleting metadata in batches, the selected items need to be cleared. (#12767)
### What problem does this PR solve?

Fix: After deleting metadata in batches, the selected items need to be
cleared.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-22 11:20:11 +08:00
0e9fe68110 Feat: Adjust the icons in the chat page's collapsible panel. (#12755)
### What problem does this PR solve?

Feat: Adjust the icons in the chat page's collapsible panel.

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2026-01-22 09:48:44 +08:00
89f438fe45 Add ping command to test ping API (#12757)
### What problem does this PR solve?

As title.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2026-01-22 00:18:29 +08:00
2e2c8f6ca9 Add more commands to RAGFlow CLI (#12731)
### What problem does this PR solve?

This PR makes the RAGFlow CLI access RAGFlow as a normal user and act as a
testing tool for the RAGFlow server.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2026-01-21 18:49:52 +08:00
6cd4fd91e6 Fix: Allow classification operators to be followed by other classification operators. #9082 (#12744)
### What problem does this PR solve?

Fix: Allow classification operators to be followed by other
classification operators. #9082

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-21 16:24:39 +08:00
83e17d8c4a Fix: Optimize the metadata code structure to implement metadata list structure functionality. (#12741)
### What problem does this PR solve?

Fix: Optimize the metadata code structure to implement metadata list
structure functionality.

#11564

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-21 16:15:43 +08:00
e1143d40bc Feat: Add a think button to the chat box. #12742 (#12743)
### What problem does this PR solve?

Feat: Add a think button to the chat box. #12742
### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2026-01-21 15:39:18 +08:00
f98abf14a8 Refa(test): improve code formatting and remove debug prints (#12739)
### What problem does this PR solve?

- Improving code formatting and consistency
- Removing debug print statements

### Type of change

- [x] Refactoring
2026-01-21 14:53:17 +08:00
2a87778e10 Chore(ci): use new Web API test cases in CI (#12738)
### What problem does this PR solve?

- Update pytest commands to use new test directory structure

### Type of change

- [x] chore(ci)
2026-01-21 14:53:05 +08:00
5836823187 Refactor:better handle list agent api desc param (#12733)
### What problem does this PR solve?
better handle list agent api desc param

### Type of change

- [x] Refactoring
2026-01-21 13:09:27 +08:00
5a7026cf55 Feat: Improve metadata logic (#12730)
### What problem does this PR solve?

Feat: Improve metadata logic

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2026-01-21 11:31:26 +08:00
bc7935d627 feat: add batch delete for conversations in chat(web) (#12584)
Resolves #12572

## What problem does this PR solve?
The conversation list in chat sessions previously only supported
deleting conversations one by one. This was inefficient when users
needed to clean up multiple conversations. This PR adds batch delete
functionality to improve user experience.

## Type of change
 - [x] New Feature (non-breaking change which adds functionality)

## Specific changes
  - Add selection mode with checkboxes for conversation list
  - Add batch delete functionality with custom icons
  - Add internationalization support (en/zh)
  - Use existing removeConversation API which supports batch deletion

## UI modification status
  - Default: Show [+] and [batch delete icon]
  - Selection mode: Show checkboxes, keep [+] and [select all icon]
  - Items selected: Show [return icon] and [red trash icon]

### Repair Comparison
**1.Before Repair**
<img width="982" height="1221" alt="image"
src="https://github.com/user-attachments/assets/8a80f7c0-7da6-41ec-9d1a-ac887ede96ba"
/>


**2.After Repair**
<img width="1273" height="919" alt="新增批量删除效果图"
src="https://github.com/user-attachments/assets/e179bdf3-3779-4bd5-84b6-8e24780a22ea"
/>

---
Co-authored-by: Gongzi

---------

Co-authored-by: Liu An <asiro@qq.com>
2026-01-20 19:13:53 +08:00
7787085664 Doc: add README for test (#12728)
### What problem does this PR solve?

We added instructions on how to test RAGFlow in test/README.md.

### Type of change

- [x] Documentation Update
2026-01-20 19:12:35 +08:00
960ecd3158 Feat: update and add new tests for web api apps (#12714)
### What problem does this PR solve?

This PR adds missing web API tests (system, search, KB, LLM, plugin,
connector). It also addresses a contract mismatch that was causing test
failures: metadata updates did not persist new keys (update‑only
behavior).

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
- [x] Other (please describe): Test coverage expansion and test helper
instrumentation
2026-01-20 19:12:15 +08:00
aee9860970 Make document change-status idempotent for Infinity doc store (#12717)
### What problem does this PR solve?

This PR makes the document change‑status endpoint idempotent under the
Infinity doc store. If a document already has the requested status, the
handler returns success without touching the engine, preventing
unnecessary updates and avoiding missing‑table errors while keeping
responses consistent.
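
A minimal sketch of this guard, with hypothetical names (`Doc`, `engine_update`) standing in for the real handler pieces:

```python
from dataclasses import dataclass

@dataclass
class Doc:
    id: str
    status: str

def change_status(doc: Doc, requested_status: str, engine_update) -> dict:
    """Skip the doc-store call when the document already has the status."""
    if doc.status == requested_status:
        return {"code": 0, "data": True}      # idempotent no-op
    engine_update(doc.id, requested_status)   # only touch the engine on change
    doc.status = requested_status
    return {"code": 0, "data": True}
```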

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-20 19:11:21 +08:00
9ebbc5a74d chore: redirect to login page if api reports unauthorized in admin page (#12726)
### What problem does this PR solve?

Auto redirect to the login page if an API reports `401: Unauthorized` in ANY
**Admin** page.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-20 18:58:13 +08:00
1c65f64bda fix: missing route for user detail page (#12725)
### What problem does this PR solve?

Add missing route for navigating to `/admin/users/:id`

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-20 18:55:44 +08:00
32841549c1 Fix: Not within a request context (#12723)
### What problem does this PR solve?

```
ERROR    1819426 Unhandled exception during request
Traceback (most recent call last):
  File "/home/qinling/github.com/infiniflow/ragflow/api/apps/document_app.py", line 639, in run
    return await thread_pool_exec(_run_sync)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/qinling/github.com/infiniflow/ragflow/common/misc_utils.py", line 132, in thread_pool_exec
    return await loop.run_in_executor(_thread_pool_executor(), func, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/asyncio/futures.py", line 287, in __await__
    yield self  # This tells Task to wait for completion.
    ^^^^^^^^^^
  File "/usr/lib/python3.12/asyncio/tasks.py", line 385, in __wakeup
    future.result()
  File "/usr/lib/python3.12/asyncio/futures.py", line 203, in result
    raise self._exception.with_traceback(self._exception_tb)
  File "/usr/lib/python3.12/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/qinling/github.com/infiniflow/ragflow/api/apps/document_app.py", line 593, in _run_sync
    if not DocumentService.accessible(doc_id, current_user.id):
                                              ^^^^^^^^^^^^^^^
  File "/home/qinling/github.com/infiniflow/ragflow/.venv/lib/python3.12/site-packages/werkzeug/local.py", line 318, in __get__
    obj = instance._get_current_object()
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/qinling/github.com/infiniflow/ragflow/.venv/lib/python3.12/site-packages/werkzeug/local.py", line 526, in _get_current_object
    return get_name(local())
                    ^^^^^^^
  File "/home/qinling/github.com/infiniflow/ragflow/api/apps/__init__.py", line 97, in _load_user
    authorization = request.headers.get("Authorization")
                    ^^^^^^^^^^^^^^^
  File "/home/qinling/github.com/infiniflow/ragflow/.venv/lib/python3.12/site-packages/werkzeug/local.py", line 318, in __get__
    obj = instance._get_current_object()
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/qinling/github.com/infiniflow/ragflow/.venv/lib/python3.12/site-packages/werkzeug/local.py", line 519, in _get_current_object
    raise RuntimeError(unbound_message) from None
RuntimeError: Not within a request context
```

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-20 16:56:41 +08:00
046d4ffdef Docs: Updated configuration file name (#12720)
### What problem does this PR solve?

### Type of change

- [x] Documentation Update
2026-01-20 15:40:03 +08:00
4c4d434bc1 Unify MySQL configuration (#12644)
### What problem does this PR solve?

Align MySQL defaults between docker/.env and
docker/service_conf.yaml.template
close #12645

### Type of change

- [x] Other (please describe):Unify MySQL configuration
2026-01-20 13:42:22 +08:00
80612bc992 Refactor: Replace antd with shadcn (#12718)
### What problem does this PR solve?

Refactor: Replace antd with shadcn
### Type of change

- [x] Refactoring
2026-01-20 13:38:54 +08:00
927db0b373 Refa: asyncio.to_thread to ThreadPoolExecutor to break thread limitat… (#12716)
### Type of change

- [x] Refactoring
2026-01-20 13:29:37 +08:00
120648ac81 fix: inaccurate error message when uploading multiple files containing an unsupported file type (#12711)
### What problem does this PR solve?

When uploading multiple files at once, if any of the files are of an
unsupported type and the blob is not removed, it triggers a
TypeError('Object of type bytes is not JSON serializable') exception.
This prevents the frontend from responding properly.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-20 12:24:54 +08:00
f367189703 fix(raptor): handle missing vector fields gracefully (#12713)
## Summary

This PR fixes a `KeyError` crash when running RAPTOR tasks on documents
that don't have the expected vector field.

## Related Issue

Fixes https://github.com/infiniflow/ragflow/issues/12675

## Problem

When running RAPTOR tasks, the code assumes all chunks have the vector
field `q_<size>_vec` (e.g., `q_1024_vec`). However, chunks may not have
this field if:
1. They were indexed with a **different embedding model** (different
vector size)
2. The embedding step **failed silently** during initial parsing
3. The document was parsed before the current embedding model was
configured

This caused a crash:
```
KeyError: 'q_1024_vec'
```

## Solution

Added defensive validation in `run_raptor_for_kb()`:

1. **Check for vector field existence** before accessing it
2. **Skip chunks** that don't have the required vector field instead of
crashing
3. **Log warnings** for skipped chunks with actionable guidance
4. **Provide informative error messages** suggesting users re-parse
documents with the current embedding model
5. **Handle both scopes** (`file` and `kb` modes)
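
A sketch of the kind of guard described above, assuming chunks are dicts keyed by field name and `vec_field` is the expected `q_<size>_vec` name (illustrative, not the exact task_executor.py code):

```python
import logging

def filter_chunks_with_vectors(chunks: list[dict], vec_field: str) -> list[dict]:
    """Keep only chunks that actually carry the expected vector field."""
    usable = [chunk for chunk in chunks if vec_field in chunk]
    skipped = len(chunks) - len(usable)
    if skipped:
        logging.warning(
            "Skipped %d chunk(s) missing %s; re-parse the documents with the "
            "current embedding model to regenerate vectors.",
            skipped, vec_field,
        )
    return usable
```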

## Changes

- `rag/svr/task_executor.py`: Added validation and error handling in
`run_raptor_for_kb()`

## Testing

1. Create a knowledge base with an embedding model
2. Parse documents
3. Change the embedding model to one with a different vector size
4. Run RAPTOR task
5. **Before**: Crashes with `KeyError`
6. **After**: Gracefully skips incompatible chunks with informative
warnings

---


Co-authored-by: GlobalStar117 <GlobalStar117@users.noreply.github.com>
2026-01-20 12:24:20 +08:00
1b1554c563 Docs: Added ingestion pipeline quickstart (#12708)
### What problem does this PR solve?

Added ingestion pipeline quickstart

### Type of change

- [x] Documentation Update
2026-01-20 09:48:32 +08:00
59f3da2bdf Fix: The time zone is unable to update properly in the database #12696 (#12704)
### What problem does this PR solve?

Fix: The time zone is unable to update properly in the database #12696

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-20 09:47:16 +08:00
b40d639fdb Add dataset with table parser type for Infinity and answer question in chat using SQL (#12541)
### What problem does this PR solve?

1) Create a dataset using the table parser for Infinity
2) Answer questions in chat using SQL

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2026-01-19 19:35:14 +08:00
05da2a5872 Fix: When large models output data rapidly, the scrollbar cannot remain at the bottom. #12701 (#12702)
### What problem does this PR solve?



### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-19 19:09:41 +08:00
4fbaa4aae9 Bump to infinity v0.7.0-dev1 (#12699)
### What problem does this PR solve?

Bump to infinity v0.7.0-dev1

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2026-01-19 16:36:03 +08:00
3188cd2659 fix: Ensure pip is available in venv for runtime installation (#12667)
## Summary

Fixes #12651

The Docker container was failing at startup with:
```
/ragflow/.venv/bin/python3: No module named pip
```

This occurred when `USE_DOCLING=true` because the `entrypoint.sh` tries
to use `uv pip install` to install docling at runtime.

## Root Cause

As explained in the issue:
1. `uv sync` creates a minimal, production-focused environment **without
pip**
2. The production stage copies the venv from builder
3. Runtime commands using `uv pip install` fail because pip is not
present

## Solution

Added `python -m ensurepip --upgrade` after `uv sync` in the Dockerfile
to ensure pip is available in the virtual environment:

```dockerfile
uv sync --python 3.12 --frozen && \
# Ensure pip is available in the venv for runtime package installation (fixes #12651)
.venv/bin/python3 -m ensurepip --upgrade
```

This is a minimal change that:
- Ensures pip is installed during build time
- Doesn't change any other behavior
- Allows runtime package installation via `uv pip install` to work

---
This is a Gittensor contribution.
gittensor:user:GlobalStar117

Co-authored-by: GlobalStar117 <GlobalStar117@users.noreply.github.com>
2026-01-19 16:08:14 +08:00
c4a982e9fa feat: add seekdb which is lite version of oceanbase (#12692)
### What problem does this PR solve?

Add seekdb, the lite version of oceanbase, as a doc_engine.
close #12691
### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2026-01-19 16:07:43 +08:00
b27dc26be3 fix: Update answer concatenation logic to handle overlapping values (#12676)
### What problem does this PR solve?

Update answer concatenation logic to handle overlapping values

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-19 16:06:36 +08:00
ab1836f216 An issue involving node.js OOM happened (#12690)
### What problem does this PR solve?
The Node.js memory issue occurred because the JavaScript heap was sometimes exhausted during the Vite build process. Here's what happened:

Root Cause:

When building the web frontend with npm run build, Vite needs to bundle, transform, and optimize all JavaScript/TypeScript code. Node.js has a default maximum heap size of ~2GB, and the RAGFlow web application is large enough that the build process exceeded this limit. This triggered garbage collection failures ("Ineffective mark-compacts near heap limit") and eventually crashed with exit code 134 (SIGABRT).

The solution I attempted:
I did not find a simple way to reduce Node.js memory usage, so I added `export NODE_OPTIONS="--max-old-space-size=4096"` to allocate a 4GB heap for Node.js during the build.

### Type of change
- Bug Fix (non-breaking change which fixes an issue)

```
=> ERROR [builder 6/8] RUN --mount=type=cache,id=ragflow_npm,target=/ro  53.3s
[builder 6/8] RUN --mount=type=cache,id=ragflow_npm,target=/root/.npm,sharing=locked cd web && npm install && npm run build:
4.551
4.551 > prepare
4.551 > cd .. && husky web/.husky
4.551
4.810 .git can't be found
4.833 added 7 packages in 4s
4.833
4.833 499 packages are looking for funding
4.833 run npm fund for details
5.206
5.206 > build
5.206 > vite build --mode production
5.206
5.939 vite v7.3.0 building client environment for production...
6.169 transforming...
6.472
6.472 WARN
6.472
6.472 WARN warn - As of Tailwind CSS v3.3, the @tailwindcss/line-clamp plugin is now included by default.
6.472
6.472 WARN warn - Remove it from the plugins array in your configuration to eliminate this warning.
6.472
53.14
53.14 <--- Last few GCs --->
53.14
53.14 [41:0x55f82d0] 47673 ms: Scavenge (reduce) 2041.5 (2086.0) -> 2038.7 (2079.7) MB, 6.11 / 0.00 ms (average mu = 0.330, current mu = 0.319) allocation failure;
53.14 [41:0x55f82d0] 47727 ms: Scavenge (reduce) 2039.4 (2079.7) -> 2038.7 (2080.2) MB, 5.34 / 0.00 ms (average mu = 0.330, current mu = 0.319) allocation failure;
53.14 [41:0x55f82d0] 47809 ms: Scavenge (reduce) 2039.6 (2080.2) -> 2038.7 (2080.2) MB, 4.59 / 0.00 ms (average mu = 0.330, current mu = 0.319) allocation failure;
53.14
53.14 <--- JS stacktrace --->
53.14
53.14 FATAL ERROR: Ineffective mark-compacts near heap limit Allocation failed - JavaScript heap out of memory
53.14 ----- Native stack trace -----
53.14
53.14  1: 0xb76db1 node::OOMErrorHandler(char const*, v8::OOMDetails const&) [node]
53.14  2: 0xee62f0 v8::Utils::ReportOOMFailure(v8::internal::Isolate*, char const*, v8::OOMDetails const&) [node]
53.14  3: 0xee65d7 v8::internal::V8::FatalProcessOutOfMemory(v8::internal::Isolate*, char const*, v8::OOMDetails const&) [node]
53.14  4: 0x10f82d5 [node]
53.14  5: 0x10f8864 v8::internal::Heap::RecomputeLimits(v8::internal::GarbageCollector) [node]
53.14  6: 0x110f754 v8::internal::Heap::PerformGarbageCollection(v8::internal::GarbageCollector, v8::internal::GarbageCollectionReason, char const*) [node]
53.14  7: 0x110ff6c v8::internal::Heap::CollectGarbage(v8::internal::AllocationSpace, v8::internal::GarbageCollectionReason, v8::GCCallbackFlags) [node]
53.14  8: 0x11120ca v8::internal::Heap::HandleGCRequest() [node]
53.14  9: 0x107d737 v8::internal::StackGuard::HandleInterrupts() [node]
53.15 10: 0x151fb9a v8::internal::Runtime_StackGuard(int, unsigned long*, v8::internal::Isolate*) [node]
53.15 11: 0x1959ef6 [node]
53.22 Aborted

[+] up 0/1
 ⠙ Image docker-ragflow  Building  58.0s
Dockerfile:161

160 |     COPY docs docs
161 | >>> RUN --mount=type=cache,id=ragflow_npm,target=/root/.npm,sharing=locked \
162 | >>>     cd web && npm install && npm run build
163 |

failed to solve: process "/bin/bash -c cd web && npm install && npm run build" did not complete successfully: exit code: 134

View build details: docker-desktop://dashboard/build/default/default/j68n2ke32cd8bte4y8fs471au
```
2026-01-19 14:28:38 +08:00
7a53d2dd97 Fix CVE-2025-59466 (#12679)
### What problem does this PR solve?


https://nodejs.org/en/blog/vulnerability/january-2026-dos-mitigation-async-hooks


### Type of change

- [X] Bug Fix (non-breaking change which fixes an issue)
2026-01-19 13:15:15 +08:00
f3d347f55f feat: Add n1n provider (#12680)
This PR adds n1n as an LLM provider to RAGFlow.

Co-authored-by: Qun <qun@ip-10-5-5-38.us-west-2.compute.internal>
2026-01-19 13:12:42 +08:00
9da48ab0bd fix: Handle NaN/Infinity values in ExeSQL JSON response (#12666)
## Summary

Fixes #12631

When SQL query results contain NaN (Not a Number) or Infinity values
(e.g., from division by zero or other calculations), the JSON
serialization would fail because **NaN and Infinity are not valid JSON
values**.

This caused the agent interface to show 'undefined' error, as described
in the issue where `EXAMINE_TIMES` became `NaN` and broke the JSON
parsing.

## Root Cause

The `convert_decimals` function in `exesql.py` was only handling
`Decimal` types, but not `float` values that could be `NaN` or
`Infinity`.

When these invalid JSON values were serialized:
```json
{"EXAMINE_TIMES": NaN}  // Invalid JSON!
```

The frontend JSON parser would fail, causing the 'undefined' error.

## Solution

Extended `convert_decimals` to detect `float` values and convert
`NaN`/`Infinity` to `null` before JSON serialization:

```python
if isinstance(obj, float):
    if math.isnan(obj) or math.isinf(obj):
        return None
    return obj
```

This ensures all SQL results can be properly serialized to valid JSON.
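
In context, the extended converter might look roughly like this (a sketch; the real `convert_decimals` in `exesql.py` may differ in how it walks containers):

```python
import math
from decimal import Decimal

def convert_decimals(obj):
    """Make SQL result values JSON-safe (sketch)."""
    if isinstance(obj, Decimal):
        return float(obj)
    if isinstance(obj, float):
        # NaN/Infinity are not valid JSON; map them to null
        if math.isnan(obj) or math.isinf(obj):
            return None
        return obj
    if isinstance(obj, dict):
        return {k: convert_decimals(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [convert_decimals(v) for v in obj]
    return obj
```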

---
This is a Gittensor contribution.
gittensor:user:GlobalStar117

Co-authored-by: GlobalStar117 <GlobalStar117@users.noreply.github.com>
Co-authored-by: Jin Hai <haijin.chn@gmail.com>
Co-authored-by: Zhichang Yu <yuzhichang@gmail.com>
2026-01-19 12:46:06 +08:00
4a7e40630b Refactor:memory delete will re-use super method (#12684)
### What problem does this PR solve?
The memory delete operation will reuse the superclass method.

### Type of change

- [x] Refactoring
2026-01-19 12:45:37 +08:00
d6897b6054 Fix chat error (#12693)
### What problem does this PR solve?

As title.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2026-01-19 12:45:14 +08:00
828ae1e82f Round float value of minimum_should_match (#12688)
### What problem does this PR solve?

In paragraph() of class FulltextQueryer, "len(keywords) / 10" should be
rounded to an integer before being set to minimum_should_match.
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-19 11:39:33 +08:00
57d189b483 fix: Correct gitlab_url access in sync_data_source.py (#12681)
### What problem does this PR solve?

Correct gitlab_url access. See
https://github.com/infiniflow/ragflow/blob/main/web/src/pages/user-setting/data-source/constant/index.tsx#L660-L666

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-19 11:01:34 +08:00
0a8eb11c3d fix: Add proper error handling for database reconnection attempts (#12650)
## Problem
When database connection is lost, the reconnection logic had a bug: if
the first reconnect attempt failed, the second attempt was not wrapped
in error handling, causing unhandled exceptions.

## Solution
Added proper try-except blocks around the second reconnect attempt in
both MySQL and PostgreSQL database classes to ensure errors are properly
logged and handled.
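
A schematic of the guarded retry, with `reconnect` standing in for the class's actual reconnection method:

```python
import logging

def reconnect_with_retry(reconnect) -> None:
    """Sketch of the double attempt, with both attempts now guarded."""
    try:
        reconnect()                          # first attempt
    except Exception as first_err:
        logging.warning("Reconnect failed, retrying once: %s", first_err)
        try:
            reconnect()                      # second attempt, also guarded now
        except Exception as second_err:
            logging.error("Second reconnect attempt failed: %s", second_err)
            raise
```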

## Changes
- Fixed `_handle_connection_loss()` in `RetryingPooledMySQLDatabase`
- Fixed `_handle_connection_loss()` in
`RetryingPooledPostgresqlDatabase`

Fixes #12294

---

Contribution by Gittensor, see my contribution statistics at
https://gittensor.io/miners/details?githubId=158349177

Co-authored-by: SID <158349177+0xsid0703@users.noreply.github.com>
2026-01-19 09:48:10 +08:00
38f0a92da9 Use RAGFlow CLI to replace RAGFlow Admin CLI (#12653)
### What problem does this PR solve?

```
$ python admin/client/ragflow_cli.py -t user -u aaa@aaa.com -p 9380

ragflow> list datasets;
ragflow> list default models;
ragflow> show version;

```


### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2026-01-17 17:52:38 +08:00
067ddcbf23 Docs: Added configure memory (#12665)
### What problem does this PR solve?

As title.

### Type of change

- [x] Documentation Update
2026-01-17 17:49:19 +08:00
46305ef35e Add User API Token Management to Admin API and CLI (#12595)
## Summary

This PR extends the RAGFlow Admin API and CLI with comprehensive user
API token management capabilities. Administrators can now generate,
list, and delete API tokens for users through both the REST API and the
Admin CLI interface.

## Changes

### Backend API (`admin/server/`)

#### New Endpoints
- **POST `/api/v1/admin/users/<username>/new_token`** - Generate a new
API token for a user
- **GET `/api/v1/admin/users/<username>/token_list`** - List all API
tokens for a user
- **DELETE `/api/v1/admin/users/<username>/token/<token>`** - Delete a
specific API token for a user

#### Service Layer Updates (`services.py`)
- Added `get_user_api_key(username)` - Retrieves all API tokens for a
user
- Added `save_api_token(api_token)` - Saves a new API token to the
database
- Added `delete_api_token(username, token)` - Deletes an API token for a
user

### Admin CLI (`admin/client/`)

#### New Commands
- **`GENERATE TOKEN FOR USER <username>;`** - Generate a new API token
for the specified user
- **`LIST TOKENS OF <username>;`** - List all API tokens associated with
a user
- **`DROP TOKEN <token> OF <username>;`** - Delete a specific API token
for a user

### Testing

Added comprehensive test suite in `test/testcases/test_admin_api/`:
- **`test_generate_user_api_key.py`** - Tests for API token generation
- **`test_get_user_api_key.py`** - Tests for listing user API tokens
- **`test_delete_user_api_key.py`** - Tests for deleting API tokens
- **`conftest.py`** - Shared test fixtures and utilities

## Technical Details

### Token Generation
- Tokens are generated using `generate_confirmation_token()` utility
- Each token includes metadata: `tenant_id`, `token`, `beta`,
`create_time`, `create_date`
- Tokens are associated with user tenants automatically

### Security Considerations
- All endpoints require admin authentication (`@check_admin_auth`)
- Tokens are URL-encoded when passed in DELETE requests to handle
special characters
- Proper error handling for unauthorized access and missing resources

### API Response Format
All endpoints follow the standard RAGFlow response format:
```json
{
  "code": 0,
  "data": {...},
  "message": "Success message"
}
```

## Files Changed

- `admin/client/admin_client.py` - CLI token management commands
- `admin/server/routes.py` - New API endpoints
- `admin/server/services.py` - Token management service methods
- `docs/guides/admin/admin_cli.md` - CLI documentation updates
- `test/testcases/test_admin_api/conftest.py` - Test fixtures
- `test/testcases/test_admin_api/test_user_api_key_management/*` - Test
suites

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Alexander Strasser <alexander.strasser@ondewo.com>
Co-authored-by: Hetavi Shah <your.email@example.com>
2026-01-17 15:21:00 +08:00
bd9163904a fix(ob_conn): ignore duplicate errors when executing 'create_idx' (#12661)
### What problem does this PR solve?

Skip duplicate errors to avoid 'create_idx' failures caused by slow
metadata refresh or external modifications.


### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-16 20:46:37 +08:00
b6d7733058 Feat: metadata settings in KB. (#12662)
### What problem does this PR solve?

#11910

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2026-01-16 20:14:02 +08:00
4f036a881d Fix: Infinity keyword round-trip, highlight fallback, and KB update guards (#12660)
### What problem does this PR solve?

Fixes Infinity-specific API regressions: preserves ```important_kwd```
round‑trip for ```[""]```, restores required highlight key in retrieval
responses, and enforces Infinity guards for unsupported
```parser_id=tag``` and pagerank in ```/v1/kb/update```. Also removes a
slow/buggy pandas row-wise apply that was throwing ```ValueError``` and
causing flakiness.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-16 20:03:52 +08:00
59075a0b58 Fix : p3 level sdk test error for update chat (#12654)
### What problem does this PR solve?

fix for update chat failing

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-16 17:47:12 +08:00
30bd25716b Fix PDF Generator output variables not appearing in subsequent agent steps (#12619)
This commit fixes multiple issues preventing PDF Generator (Docs
Generator) output variables from being visible in the Output section and
available to downstream nodes.

### What problem does this PR solve?

Issues Fixed:
1. PDF Generator nodes initialized with empty object instead of proper
initial values
2. Output structure mismatch (had 'value' property that system doesn't
expect)
3. Missing 'download' output in form schema
4. Output list computed from static values instead of form state
5. Added null/undefined guard to transferOutputs function

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
Changes:
- web/src/pages/agent/constant/index.tsx: Fixed output structure in
initialPDFGeneratorValues
- web/src/pages/agent/hooks/use-add-node.ts: Initialize PDF Generator
with proper values
- web/src/pages/agent/form/pdf-generator-form/index.tsx: Fixed schema
and use form.watch
- web/src/pages/agent/form/components/output.tsx: Added null guard and
spacing
2026-01-16 16:50:53 +08:00
99dae3c64c Fix: In the agent loop, if the await response is selected as the variable, the operator cannot be selected. #12656 (#12657)
### What problem does this PR solve?

Fix: In the agent loop, if the await response is selected as the
variable, the operator cannot be selected. #12656

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-16 16:49:48 +08:00
045314a1aa Fix: duplicate content in chunk (#12655)
### What problem does this PR solve?

Fix: duplicate content in chunk #12336

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-16 15:32:04 +08:00
2b20d0b3bb Fix : Web API tests by normalizing errors, validation, and uploads (#12620)
### What problem does this PR solve?

Fixes web API behavior mismatches that caused test failures by
normalizing error responses, tightening validations, correcting error
messages, and closing upload file handles.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-16 11:09:22 +08:00
59f4c51222 fix(entrypoint): Preserve $ in passwords during template expansion (#12509)
### What problem does this PR solve?

Fix shell variable expansion to preserve $ in password defaults when 
env vars are unset. Fixes Azure RDS auto-rotated passwords (that contain
$) being
truncated during template processing.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-15 19:30:33 +08:00
8c1fbfb130 Fix:Some bugs (#12648)
### What problem does this PR solve?

Fix: Modified and optimized the metadata condition card component.
Fix: Use startOfDay and endOfDay to ensure the date range includes a
full day.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-15 19:28:22 +08:00
cec06bfb5d Fix: empty chunk issue. (#12638)
#12570

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-15 17:46:21 +08:00
2167e3a3c0 Docs: Added share memory (#12647)
### Type of change

- [x] Documentation Update
2026-01-15 17:21:36 +08:00
2ea8dddef6 fix(infinity): Use comma separator for important_kwd to preserve mult… (#12618)
## Problem

The `important_kwd` field in the Infinity connector was using mismatched
separators:
- **Storage**: `list2str(v)` uses space as the default separator
- **Reading**: `v.split()` splits by all whitespace

This causes multi-word keywords like `"Senior Fund Manager"` to be
incorrectly split into `["Senior", "Fund", "Manager"]`.

## Solution

Use comma `,` as the separator for both storing and reading, consistent
with:
1. The LLM output format in `keyword_prompt.md` ("delimited by ENGLISH COMMA")
2. The `cached.split(",")` in `task_executor.py`

## Changes

- `insert()`: `list2str(v)` → `list2str(v, ",")`
- `update()`: `list2str(v)` → `list2str(v, ",")`
- `get_fields()`: `v.split()` → `v.split(",") if v else []`

## Impact

This bug affects:
- Python-level reranking weight calculation (`important_kwd * 5`)
- API response keyword display
- Search precision due to fragmented keywords
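
To illustrate the mismatch with a plain `join`/`split` (standing in for the `list2str` helper named above):

```python
kwds = ["Senior Fund Manager", "AI"]

stored_old = " ".join(kwds)    # 'Senior Fund Manager AI'
read_old = stored_old.split()  # ['Senior', 'Fund', 'Manager', 'AI']  (wrong)

stored_new = ",".join(kwds)    # 'Senior Fund Manager,AI'
read_new = stored_new.split(",") if stored_new else []  # list preserved intact
```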
2026-01-15 15:32:40 +08:00
18867daba7 chore: bump pyobvector from 0.2.18 to 0.2.22 (#12640)
### What problem does this PR solve?

Update ob client

### Type of change

- [x] Other (please describe):dependency upgrade
2026-01-15 15:21:34 +08:00
d68176326d feat: add oceanbase mount to gitignore (#12642)
### What problem does this PR solve?

feat: add oceanbase mount to .gitignore

### Type of change

- [x] Refactoring
2026-01-15 15:20:40 +08:00
d531bd4f1a Fix: Editing the agent greeting causes the greeting to be continuously added to the message list. #12635 (#12636)
### What problem does this PR solve?

Fix: Editing the agent greeting causes the greeting to be continuously
added to the message list. #12635
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-15 14:55:19 +08:00
ac936005e6 fix: ensure deleted chunks are not returned in retrieval (#12520) (#12546)
## Summary
Fixes #12520 - Deleted chunks should not appear in retrieval/reference
results.

## Changes

### Core Fix
- **api/apps/chunk_app.py**: Include `doc_id` in the delete condition to
properly scope the delete operation

### Improved Error Handling
- **api/db/services/document_service.py**: Better separation of concerns
with individual try-catch blocks and proper logging for each cleanup
operation

### Doc Store Updates
- **rag/utils/es_conn.py**: Updated delete query construction to support
compound conditions
- **rag/utils/opensearch_conn.py**: Same updates for OpenSearch
compatibility

### Tests
- **test/testcases/.../test_retrieval_chunks.py**: Added
`TestDeletedChunksNotRetrievable` class with regression tests
- **test/unit/test_delete_query_construction.py**: Unit tests for delete
query construction

## Testing
- Added regression tests that verify deleted chunks are not returned by
retrieval API
- Tests cover single chunk deletion and batch deletion scenarios
2026-01-15 14:45:55 +08:00
d8192f8f17 Fix: validate regex pattern in split_with_pattern to prevent crash (#12633)
### What problem does this PR solve?

Fix regex pattern validation in split_with_pattern (#12605)

- Add try-except block to validate user-provided regex patterns before
use
- Gracefully fallback to single chunk when invalid regex is provided
- Prevent server crash during DOCX parsing with malformed delimiters

## Problem

Parsing DOCX files with custom regex delimiters crashes with `re.error:
nothing to repeat at position 9` when users provide invalid regex
patterns.

Closes #12605 

## Solution

Validate and compile regex pattern before use. On invalid pattern, log
warning and return content as single chunk instead of crashing.
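
A minimal sketch of the validation described (the real `split_with_pattern()` in `rag/nlp/__init__.py` has more responsibilities):

```python
import logging
import re

def split_with_pattern(content: str, pattern: str) -> list[str]:
    """Split by a user-provided regex, falling back to a single chunk."""
    try:
        compiled = re.compile(pattern)
    except re.error as e:
        logging.warning("Invalid delimiter regex %r: %s; returning one chunk", pattern, e)
        return [content]
    return compiled.split(content)
```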

## Changes

- `rag/nlp/__init__.py`: Add regex validation in `split_with_pattern()`
function

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Contribution by Gittensor, see my contribution statistics at
https://gittensor.io/miners/details?githubId=42954461
2026-01-15 14:24:51 +08:00
eb35e2b89f Fix: async invocation issue. (#12634)
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-15 14:22:16 +08:00
97b983fd0b fix: add fallback parser list for empty parser_ids (#12632)
### What problem does this PR solve?

Fixes #12570 - The slicing method dropdown was empty when deploying
RAGFlow v0.23.1 from source code.

The issue occurred because `parser_ids` from the tenant info was empty
or undefined, causing `useSelectParserList` to return an empty array.
This PR adds a fallback to a default parser list when `parser_ids` is
empty, ensuring the dropdown always has options.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---
Contribution by Gittensor, see my contribution statistics at
https://gittensor.io/miners/details?githubId=94194147
2026-01-15 14:05:25 +08:00
b40a7b2e7d Feat: Hash doc id to avoid duplicate name. (#12573)
### What problem does this PR solve?

Feat: Hash doc id to avoid duplicate name. 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2026-01-15 14:02:15 +08:00
9a10558f80 Refa: async retrieval process. (#12629)
### Type of change

- [x] Refactoring
- [x] Performance Improvement
2026-01-15 12:28:49 +08:00
f82628c40c Fix: langfuse connection error handling #12621 (#12626)
## Description

Fixes connection error handling when langfuse service is unavailable.
The application now gracefully handles connection failures instead of
crashing.

## Changes

- Wrapped `langfuse.auth_check()` calls in try-except blocks in:
  - `api/db/services/dialog_service.py`
  - `api/db/services/tenant_llm_service.py`

## Problem

When langfuse service is unavailable or connection is refused,
`langfuse.auth_check()` throws `httpx.ConnectError: [Errno 111]
Connection refused`, causing the application to crash during document
parsing or dialog operations.

## Solution

Added try-except blocks around `langfuse.auth_check()` calls to catch
connection errors and gracefully skip langfuse tracing instead of
crashing. The application continues functioning normally even when
langfuse is unavailable.
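
A sketch of the guard, assuming `langfuse_client` is an initialized Langfuse instance:

```python
import logging

def langfuse_is_usable(langfuse_client) -> bool:
    """Treat any connection failure as 'tracing disabled' rather than a crash."""
    try:
        return bool(langfuse_client.auth_check())
    except Exception as e:  # e.g. httpx.ConnectError when the service is down
        logging.warning("Langfuse unavailable, skipping tracing: %s", e)
        return False
```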

## Related Issue

Fixes #12621

---

Contribution by Gittensor, see my contribution statistics at
https://gittensor.io/miners/details?githubId=158349177
2026-01-15 11:23:15 +08:00
7af98328f5 Fix: the styles of the multi-select component and the filter pop-up. (#12628)
### What problem does this PR solve?

Fix: Fix the styles of the multi-select component and the filter pop-up.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-15 10:53:18 +08:00
678a4f959c Fix: skip internal bookmark references in DOCX parsing (#12604) (#12611)
### What problem does this PR solve?

Fixes #12604 - DOCX files containing hyperlinks to internal bookmarks
(e.g., `#_文档目录`) cause a `KeyError` during parsing:

```
KeyError: "There is no item named 'word/#_文档目录' in the archive"
```

This happens because python-docx incorrectly tries to read internal
bookmark references as files from the ZIP archive. Internal bookmarks
are relationship targets starting with `#` and are not actual files.

This PR extends the existing `load_from_xml_v2` workaround (which
already handles `NULL` targets) to also skip relationship targets
starting with `#`.
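
A sketch of the extended target check (illustrative; the actual workaround patches relationship loading inside python-docx):

```python
def is_loadable_rel_target(target: str | None) -> bool:
    """Only real archive members should be read from the DOCX zip."""
    if not target or target == "NULL":   # existing NULL-target workaround
        return False
    return not target.startswith("#")    # new: skip internal bookmark references
```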

Related upstream issue:
https://github.com/python-openxml/python-docx/issues/902

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---
Contribution by Gittensor, see my contribution statistics at
https://gittensor.io/miners/details?githubId=94194147
2026-01-14 19:08:46 +08:00
15a8bb2e9c Fix: chunk list async issue. (#12615)
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-14 17:32:07 +08:00
b091ff2730 Fix enable_thinking parameter for Qwen3 models (#12603)
### Issue

When using Qwen3 models (`qwen3-32b`, `qwen3-max`) through the
Tongyi-Qianwen provider for non-streaming calls (e.g., knowledge graph
generation), the API fails with:

Closes #12424

```
parameter.enable_thinking must be set to false for non-streaming calls
```

### Root Cause

In `LiteLLMBase.async_chat()`, the `extra_body={"enable_thinking":
False}` was set in `kwargs` but never forwarded to
`_construct_completion_args()`.

### What problem does this PR solve?

Pass merged kwargs to `_construct_completion_args()` using
`**{**gen_conf, **kwargs}` to safely handle potential duplicate
parameters.
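
A sketch of the merge, with hypothetical names mirroring the description:

```python
def build_completion_args(gen_conf: dict, **kwargs) -> dict:
    """Merge generation config with call kwargs; later keys win on duplicates."""
    kwargs.setdefault("extra_body", {"enable_thinking": False})
    return {**gen_conf, **kwargs}  # kwargs (incl. extra_body) override gen_conf
```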

### Changes

- `rag/llm/chat_model.py`: Forward kwargs containing `extra_body` to
`_construct_completion_args()` in `async_chat()`



### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Contribution by Gittensor, see my contribution statistics at
https://gittensor.io/miners/details?githubId=42954461
2026-01-14 16:35:46 +08:00
5b22f94502 Feat: Benchmark CLI additions and documentation (#12536)
### What problem does this PR solve?

This PR adds a dedicated HTTP benchmark CLI for RAGFlow chat and
retrieval endpoints so we can measure latency/QPS.

### Type of change

- [x] Documentation Update
- [x] Other (please describe): Adds a CLI benchmarking tool for
chat/retrieval latency/QPS

---------

Co-authored-by: Liu An <asiro@qq.com>
2026-01-14 13:49:16 +08:00
a7671583b3 Feat: add CN regions for AWS (#12610)
### What problem does this PR solve?

Add CN regions for AWS.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2026-01-14 12:34:55 +08:00
d32fa02d97 Fix: Unable to copy category node. #12607 (#12609)
### What problem does this PR solve?

Fix: Unable to copy category node. #12607

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-14 11:45:31 +08:00
f72a35188d refactor: remove debug print statements (#12598)
### What problem does this PR solve?

This PR eliminates unnecessary debug print statements that were left in
hot paths of the codebase.

### Type of change

- [x] Refactoring
2026-01-14 10:05:34 +08:00
ea619dba3b Added to the HTTP API test suite (#12556)
### What problem does this PR solve?

This PR adds missing HTTP API test coverage for dataset
graph/GraphRAG/RAPTOR tasks, metadata summary, chat completions, agent
sessions/completions, and related questions. It also introduces minimal
HTTP test helpers to exercise these endpoints consistently with the
existing suite.

### Type of change

- [x]  Other (please describe): Test coverage (HTTP API tests)

---------

Co-authored-by: Liu An <asiro@qq.com>
2026-01-14 10:02:30 +08:00
36b0835740 Docs: Use memory (#12599)
### What problem does this PR solve?


### Type of change


- [x] Documentation Update
2026-01-14 09:40:31 +08:00
0795616b34 Align p3 HTTP/SDK tests with current backend behavior (#12563)
### What problem does this PR solve?

Updates pre-existing HTTP API and SDK tests to align with current
backend behavior (validation errors, 404s, and schema defaults). This
ensures p3 regression coverage is accurate without changing production
code.

### Type of change

- [x] Other (please describe): align p3 HTTP/SDK tests with current
backend behavior

---------

Co-authored-by: Liu An <asiro@qq.com>
2026-01-13 19:22:47 +08:00
941651a16f Fix: wrong input trace in Category component (#12590)
### What problem does this PR solve?

Wrong input trace in Category component

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-13 17:54:57 +08:00
360114ed42 fix(ob_conn): avoid reusing SQLAlchemy Column objects in DDL (#12588)
### What problem does this PR solve?

When there are multiple users, parsing a document for a new user can
trigger the reuse of column objects, leading to the error
`sqlalchemy.exc.ArgumentError: Column object 'id' already assigned to
Table xxx`.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-13 17:39:20 +08:00
ffedb2c6d3 Feat: The MetadataFilterConditions component supports adding values via search. (#12585)
### What problem does this PR solve?

Feat: The MetadataFilterConditions component supports adding values via search.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2026-01-13 17:03:25 +08:00
947e63ca14 Fixed typos and added pptx preview for frontend (#12577)
### What problem does this PR solve?
Previously, we added support for previewing PPT and PPTX files in the
backend. Now we are adding it to the frontend, so slides referenced in
the chat interface will no longer be blank.
### Type of change

- Bug Fix (non-breaking change which fixes an issue)
2026-01-13 17:02:36 +08:00
34d74d9928 fix: add uv-aarch64-unknown-linux-gnu.tar.gz to deps image (#12516)
### What problem does this PR solve?

Add uv-aarch64-unknown-linux-gnu.tar.gz to support building ARM64 Docker
images.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Co-authored-by: Liu An <asiro@qq.com>
2026-01-13 15:37:32 +08:00
accae95126 Feat: Exported Agent JSON Should Include Conversation Variables Configuration #11796 (#12579)
### What problem does this PR solve?

Feat: Exported Agent JSON Should Include Conversation Variables
Configuration #11796

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2026-01-13 15:35:45 +08:00
68e5c86e9c Fix: image not displaying thumbnails when using pipeline (#12574)
### What problem does this PR solve?

Fix image not displaying thumbnails when using pipeline.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-13 12:54:13 +08:00
64c75d558e Fix: zip extraction vulnerabilities in MinerU and TCADP (#12527)
### What problem does this PR solve?

Fix zip extraction vulnerabilities:
   - Block symlink entries in zip files.
   - Reject encrypted zip entries.
   - Prevent absolute path attacks (including Windows paths).
   - Block path traversal attempts (../).
   - Stop zip slip exploits (directory escape).
   - Use streaming for memory-safe file handling.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-13 12:24:50 +08:00
41c84fd78f Add MIME types for PPT and PPTX files (#12562)
Otherwise, slide files cannot be opened in the Chat module.

### What problem does this PR solve?

Backend reason (API): in `api/utils/web_utils.py`, the `CONTENT_TYPE_MAP`
dictionary is missing the ppt and pptx MIME type mappings. When the
frontend requests a PPTX file, the backend cannot tell the browser the
correct content type, so the file type is misidentified and the file is
displayed incorrectly.
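
The missing entries would look something like this (standard PPT/PPTX MIME types; the surrounding dictionary contents are assumptions):

```python
# Illustrative additions to CONTENT_TYPE_MAP in api/utils/web_utils.py
CONTENT_TYPE_MAP = {
    # ...existing mappings...
    "ppt": "application/vnd.ms-powerpoint",
    "pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
}
```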

### Type of change

-  Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-01-13 12:17:49 +08:00
d76912ab15 Fix: Use uv pip install for Docling installation (#12567)
Fixes #12440
### What problem does this PR solve?
The current implementation uses `python3 -m pip` which can fail in
certain environments. This change leverages `uv pip install` instead,
which aligns with the project's existing tooling.

### Type of change
- Removed the ensurepip line (not needed since uv manages pip)
- Changed python3 to "$PY" for consistency with the rest of the script
- Changed python3 -m pip install to uv pip install

Co-authored-by: Gongzi <gongzi@192.168.0.100>
2026-01-13 11:48:42 +08:00
4fe3c24198 feat: PaddleOCR PDF parser supports thumbnails and positions (#12565)
### What problem does this PR solve?

1. PaddleOCR PDF parser supports thumbnails and positions.
2. Add FAQ documentation for PaddleOCR PDF parser.


### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2026-01-13 09:51:08 +08:00
44bada64c9 Feat: support tree structured deep-research policy. (#12559)
### What problem does this PR solve?

#12558
### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2026-01-13 09:41:35 +08:00
867ec94258 revert white-space changes in docs (#12557)
### What problem does this PR solve?

Trailing whitespace in commit 6814ace1aa was automatically trimmed by a
code editor, which may break the documentation typesetting.

These are mostly the double spaces used for soft line breaks.

### Type of change

- [x] Documentation Update
2026-01-13 09:41:02 +08:00
fd0a1fde6b Feat: Enhanced metadata functionality (#12560)
### What problem does this PR solve?

Feat: Enhanced metadata functionality
- Metadata filtering supports searching.
- Values can be directly modified.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2026-01-12 19:05:33 +08:00
653001b14f Doc: python sdk document (#12554)
### What problem does this PR solve?

Add python sdk document for memory api.

### Type of change

- [x] Documentation Update
2026-01-12 15:31:02 +08:00
d4f8c724ed Fix:Automatically enable metadata and optimize parser dialog logic (#12553)
### What problem does this PR solve?

Fix:Automatically enable metadata and optimize parser dialog logic

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-12 15:29:50 +08:00
a7dd3b7e9e Add time cost when start servers (#12552)
### What problem does this PR solve?

- API server
- Ingestion server
- Data sync server
- Admin server

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2026-01-12 12:48:23 +08:00
638c510468 refactor: introduce common normalize method in rerank base class (#12550)
### What problem does this PR solve?

introduce common normalize method in rerank base class

### Type of change

- [x] Refactoring
2026-01-12 11:07:11 +08:00
ff11e3171e Feat: SandBox docker CLI error in ARM CPU #12433 (#12434)
### What problem does this PR solve?

Add multi-architecture support for Sandbox

Updated Dockerfile to support multiple architectures for Docker Sandbox
installation.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2026-01-12 11:06:33 +08:00
030d6ba004 CI collect ragflow log (#12543)
### What problem does this PR solve?

As title

### Type of change

- [ ] Bug Fix (non-breaking change which fixes an issue)
- [ ] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [ ] Performance Improvement
- [x] Other (please describe): CI
2026-01-10 09:52:32 +08:00
b226e06e2d refactor: remove debug print statements (#12534)
### What problem does this PR solve?

refactor: remove debug print statements

### Type of change

- [x] Refactoring
2026-01-09 19:23:50 +08:00
2e09db02f3 feat: add paddleocr parser (#12513)
### What problem does this PR solve?

Add PaddleOCR as a new PDF parser.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2026-01-09 17:48:45 +08:00
6abf55c048 Feat: support openapi (#12521)
### What problem does this PR solve?
Support OpenAPI interface description.

The issue of not supporting the Swagger interface after upgrading the
system framework from Flask to Quart has been resolved.

Resolved https://github.com/infiniflow/ragflow/issues/5264

### Type of change
- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: puhaiyang <761396462@qq.com>
2026-01-09 17:48:20 +08:00
f9d4179bf2 Feat:memory sdk (#12538)
### What problem does this PR solve?

Move memory and message apis to /api, and add sdk support.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2026-01-09 17:45:58 +08:00
64b1e0b4c3 Feat: The translation model type options should be consistent with the model's labels. #1036 (#12537)
### What problem does this PR solve?

Feat: The translation model type options should be consistent with the
model's labels. #1036

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2026-01-09 17:39:40 +08:00
b65daeb945 Fix: Baiduyiyan key invalid (#12531)
### What problem does this PR solve?

Fix: Baiduyiyan key invaild

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-09 17:37:17 +08:00
fbe55cef05 fix: keep password in opendal config to fix connection initialization (#12529)
### What problem does this PR solve?

If we delete the password in kwargs, func 'init_db_config' will fail, so
we need to keep this field.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-09 14:19:32 +08:00
0878526ba8 Refactor: Refactoring OllamaModal using shadcn. #1036 (#12530)
### What problem does this PR solve?

Refactor: Refactoring OllamaModal using shadcn.  #1036

### Type of change

- [x] Refactoring
2026-01-09 13:42:28 +08:00
a2db3e3292 Fix: Bugs fixed (#12524)
### What problem does this PR solve?

Fix: Bugs fixed
- The issue of filter conditions not being able to be deleted on the
knowledge base file page
- The issue of metadata filter conditions not working.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-09 13:41:24 +08:00
f522391d1e Fix: "AttributeError(\"'list' object has no attribute 'get'\")" (#12518)
### What problem does this PR solve?
https://github.com/infiniflow/ragflow/issues/12515

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-09 10:19:51 +08:00
9562762af2 docs: fix embedding model switching tooltip (#12517)
### What problem does this PR solve?
After version 0.22.1, the embedding model supports switching; the
corresponding tooltip needs to be updated.

### Type of change

- [x] Documentation Update
2026-01-09 10:19:40 +08:00
455fd04050 Refactor: Replace Ant Design with shadcn in SparkModal, TencentCloudModal, HunyuanModal, and GoogleModal. #1036 (#12510)
### What problem does this PR solve?

Refactor: Replace Ant Design with shadcn in SparkModal,
TencentCloudModal, HunyuanModal, and GoogleModal. #1036
### Type of change

- [x] Refactoring
2026-01-08 19:42:45 +08:00
14c250e3d7 Fix adding column error (#12503)
### What problem does this PR solve?

1. Fix redundant column adding
2. Refactor the code

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2026-01-08 16:44:53 +08:00
a093e616cf Fix: add multimodel models in chat api (#12496)
### What problem does this PR solve?

Fix: add multimodel models in chat api #11986
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
2026-01-08 16:12:08 +08:00
696397ebba Fix: apply kb setting while re-parsing.... (#12501)
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-08 16:11:50 +08:00
6f1a555d5f Refa(sdk/python/test): remove unused testcases and utilities (#12505)
### What problem does this PR solve?

Removed the following dir:
- sdk/python/test/libs/
- sdk/python/test/test_http_api/
- sdk/python/test/test_sdk_api/

### Type of change

- [x] Refactoring
2026-01-08 16:11:35 +08:00
1996aa0dac Refactor: Enhance delta streaming in chat functions for improved reasoning and content handling (#12453)
### What problem does this PR solve?

change:
Enhance delta streaming in chat functions for improved reasoning and
content handling

### Type of change


- [x] Refactoring
2026-01-08 13:34:16 +08:00
f4e2783eb4 optimize doc id check: do not query db when doc id to validate is empty (#12500)
### What problem does this PR solve?
When a kb contains many documents, say 50,000, and retrieval is made against the kb without specifying any doc ids, querying all docs from the db is unnecessary and can be omitted to improve performance.
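
A minimal sketch of such a guard (the query method is a hypothetical stand-in for the actual RAGFlow data layer):

```python
def validate_doc_ids(db, kb_id: str, doc_ids: list[str] | None) -> list[str]:
    # Retrieval against the whole kb passes no doc ids; skip the
    # expensive all-documents query in that case.
    if not doc_ids:
        return []
    # Only hit the database when there is something to validate.
    known = {d.id for d in db.list_documents(kb_id=kb_id, ids=doc_ids)}  # hypothetical query API
    return [i for i in doc_ids if i in known]
```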

### Type of change

- [x] Performance Improvement
2026-01-08 13:22:58 +08:00
2fd4a3134d Doc: memory http api (#12499)
### What problem does this PR solve?

Use the task save function for the add_message api, and add the HTTP API document.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Documentation Update
2026-01-08 12:54:10 +08:00
f1dc2df23c Fix:Bedrock assume_role auth mode fails with LiteLLM "Extra inputs are not permitted" error (#12495)
### What problem does this PR solve?
https://github.com/infiniflow/ragflow/issues/12489

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-08 12:53:41 +08:00
de27c006d8 Feat: The chat feature supports streaming output, displaying results one by one. #12490 (#12493)
### What problem does this PR solve?

Feat: The chat feature supports streaming output, displaying results one
by one.

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2026-01-08 09:43:57 +08:00
23a9544b73 Fix: toc async issue. (#12485)
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-07 15:35:30 +08:00
011bbe9556 Feat: support context window for docx (#12455)
### What problem does this PR solve?

Feat: support context window for docx

#12303

Done:
- [x] naive.py
- [x] one.py

TODO:
- [ ] book.py
- [ ] manual.py

Fix: incorrect image position
Fix: incorrect chunk type tag

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
2026-01-07 15:08:17 +08:00
a442c9cac6 Fix: Fixed an issue where ESLint suggestions were not working in VS Code after upgrading to Vite. #12483 (#12484)
### What problem does this PR solve?

Fix: Fixed an issue where ESLint suggestions were not working in VS Code
after upgrading to Vite. #12483

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-07 14:32:17 +08:00
671e719d75 Feat: Memory-message supports categorized display (#12482)
### What problem does this PR solve?

Feat: Memory-message supports categorized display

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-07 13:48:40 +08:00
07845be5bd Fix: display agent name for extract messages (#12480)
### What problem does this PR solve?

Display agent name for extract messages

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-07 13:19:54 +08:00
8d406bd2e6 fix: prevent MinIO health check failure in multi-bucket mode (#12446)
### What problem does this PR solve?

- Fixes the health check failure in multi-bucket MinIO environments.
Previously, health checks would fail because the default
"ragflow-bucket" did not exist. This caused false negatives for system
health.

- Also removes the _health_check write in single-bucket mode to avoid
side effects (minor optimization).
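
As a rough illustration, a multi-bucket-safe, read-only probe with the minio client could look like this (the bucket list and function name are assumptions, not the actual RAGFlow code):

```python
from minio import Minio

def health_check(client: Minio, configured_buckets: list[str]) -> bool:
    # Probe the buckets that are actually configured instead of a
    # hard-coded "ragflow-bucket"; bucket_exists is read-only, so the
    # check itself causes no side-effect writes.
    try:
        return all(client.bucket_exists(b) for b in configured_buckets)
    except Exception:
        return False
```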

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-07 10:07:18 +08:00
2a4627d9a0 Fix: Issues and style fixes related to the 'Memory' page (#12469)
### What problem does this PR solve?

Fix:  Some bugs
- Issues and style fixes related to the 'Memory' page
- Data source icon replacement
- Build optimization

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-07 10:03:54 +08:00
6814ace1aa docs: update docs icons (#12465)
### What problem does this PR solve?

Update icons for docs.
Trailing spaces are auto-truncated by the editor; this does not affect the real content.

### Type of change

- [x] Documentation Update
2026-01-07 10:00:09 +08:00
ca9645f39b Feat: adapt to ','-joined arg list (#12468)
### What problem does this PR solve?

Adapt to comma-joined argument lists in GET method URLs.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2026-01-07 09:59:08 +08:00
8e03843145 fix: task executor with status "timeout" corrupts page when checking its details (#12467)
### What problem does this PR solve?

In **Admin UI** > **Service Status**, clicking "Show details" on a task
executor with status "Timeout" may corrupt the page.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-07 09:58:16 +08:00
51ece37db2 refactor: migrate env prefix to VITE_* (#12466)
### What problem does this PR solve?

`UMI_APP_*` to `VITE_*`

### Type of change

- [x] Refactoring
2026-01-07 09:39:18 +08:00
45fb2719cf Fix: update uv python installation to version 3.12 in Dockerfile (#12464)
### What problem does this PR solve?
issue:
https://github.com/infiniflow/ragflow/issues/12440
change:
update uv python installation to version 3.12 in Dockerfile

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-06 19:27:46 +08:00
bdd9f3d4d1 Fix: try handling authorization as api-token (#12462)
### What problem does this PR solve?

Try handling the authorization header as an API token when JWT loading fails.
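
Sketch of the fallback order (the two auth callables are injected here to keep the sketch self-contained; the real code calls its own auth services):

```python
def resolve_user(authorization: str, decode_jwt, lookup_api_token):
    # Prefer the JWT session interpretation; only when decoding fails,
    # retry the same header value as a long-lived API token.
    try:
        return decode_jwt(authorization)
    except Exception:
        return lookup_api_token(authorization)
```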

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-06 19:25:42 +08:00
1f60863f60 Docs: Fixed a display issue. (#12463)
### What problem does this PR solve?


### Type of change

- [x] Documentation Update
2026-01-06 17:40:53 +08:00
02e6870755 Refactor: import_test_cases use bulk_create (#12456)
### What problem does this PR solve?

import_test_cases use bulk_create

### Type of change

- [x] Refactoring
2026-01-06 11:39:07 +08:00
aa08920e51 Fix: The avatar and greeting message no longer appear in the Agent iFrame. [#12410] (#12459)
### What problem does this PR solve?
Fix: The avatar and greeting message no longer appear in the Agent
iFrame. [#12410]

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-06 11:01:16 +08:00
7818644129 Fix: add uv binary archive to ignored files (#12451)
### What problem does this PR solve?

After I ran this command, 

```bash
uv run ./download_deps.py 
```

a file was not ignored.

```bash
❯ git status
On branch feat/ignore-uv
Untracked files:
  (use "git add <file>..." to include in what will be committed)
        uv-x86_64-unknown-linux-gnu.tar.gz

nothing added to commit but untracked files present (use "git add" to track)
```

Add this file name to `.gitignore`

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-05 20:22:35 +08:00
55c9fc0017 fix: add 'mom_id' column to OBConnection chunk table (#12444)
### What problem does this PR solve?

Fix #12428

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-05 19:31:44 +08:00
140dd2c8cc Refactor: Refactor FishAudioModal and BedrockModal using shadcn. #1036 (#12449)
### What problem does this PR solve?

Refactor: Refactor FishAudioModal and BedrockModal using shadcn. #1036

### Type of change

- [x] Refactoring
2026-01-05 19:27:56 +08:00
fada223249 Feat: process memory (#12445)
### What problem does this PR solve?

Add a task status for raw messages, and move the extract message to a nested
property under the raw message.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2026-01-05 17:58:32 +08:00
00f8a80ca4 Fix: Some bugs (#12441)
### What problem does this PR solve?

Fix: Some bugs
- In a production environment, a second-level page refresh results in a
white screen.
- The knowledge graph cannot be opened.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-05 15:28:57 +08:00
4e9407b4ae Refactor: Refactoring AzureOpenAIModal using shadcn. #10427 (#12436)
### What problem does this PR solve?

Refactor: Refactoring AzureOpenAIModal using shadcn. #10427

### Type of change

- [x] Refactoring
2026-01-05 14:09:55 +08:00
42461bc378 Update admin doc (#12439)
### What problem does this PR solve?

update for 'list configs' and 'list envs'

### Type of change

- [x] Documentation Update

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2026-01-05 13:26:33 +08:00
92780c486a Add list configs and environments (#12438)
### What problem does this PR solve?

1. list configs;
2. list envs;

```
admin> list configs;
+-------------------------------------------------------------------------------------------+-----------+----+---------------+-------+----------------+
| extra                                                                                     | host      | id | name          | port  | service_type   |
+-------------------------------------------------------------------------------------------+-----------+----+---------------+-------+----------------+
| {}                                                                                        | 0.0.0.0   | 0  | ragflow_0     | 9380  | ragflow_server |
| {'meta_type': 'mysql', 'password': 'infini_rag_flow', 'username': 'root'}                 | localhost | 1  | mysql         | 5455  | meta_data      |
| {'password': 'infini_rag_flow', 'store_type': 'minio', 'user': 'rag_flow'}                | localhost | 2  | minio         | 9000  | file_store     |
| {'password': 'infini_rag_flow', 'retrieval_type': 'elasticsearch', 'username': 'elastic'} | localhost | 3  | elasticsearch | 1200  | retrieval      |
| {'db_name': 'default_db', 'retrieval_type': 'infinity'}                                   | localhost | 4  | infinity      | 23817 | retrieval      |
| {'database': 1, 'mq_type': 'redis', 'password': 'infini_rag_flow'}                        | localhost | 5  | redis         | 6379  | message_queue  |
| {'message_queue_type': 'redis'}                                                           |           | 6  | task_executor | 0     | task_executor  |
+-------------------------------------------------------------------------------------------+-----------+----+---------------+-------+----------------+
admin> list envs;
+-------------------------+------------------+
| env                     | value            |
+-------------------------+------------------+
| DOC_ENGINE              | elasticsearch    |
| DEFAULT_SUPERUSER_EMAIL | admin@ragflow.io |
| DB_TYPE                 | mysql            |
| DEVICE                  | cpu              |
| STORAGE_IMPL            | MINIO            |
+-------------------------+------------------+
admin> 
```

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2026-01-05 13:26:22 +08:00
lif
81f9296d79 Fix: handle invalid img_id format in chunk update (#12422)
## Summary
- Fix ValueError when updating chunk with invalid/empty `img_id` format
- Add validation before splitting `img_id` by hyphen
- Use `split("-", 1)` to handle object names containing hyphens
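
A minimal sketch of the hardened parsing (the function name is illustrative):

```python
def parse_img_id(img_id: str) -> tuple[str, str] | None:
    # Empty or hyphen-less ids previously raised ValueError on unpacking.
    if not img_id or "-" not in img_id:
        return None
    # split("-", 1) keeps hyphens inside the object name intact,
    # e.g. "bucket-my-file.png" -> ("bucket", "my-file.png").
    bucket, object_name = img_id.split("-", 1)
    return bucket, object_name
```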

## Test plan
- [x] Verify chunk update works with valid `img_id` (format:
`bucket-objectname`)
- [x] Verify chunk update doesn't crash with empty `img_id`
- [x] Verify chunk update doesn't crash when `img_id` has no hyphen
- [x] Verify ruff check passes

Fixes #12035

Signed-off-by: majiayu000 <1835304752@qq.com>
2026-01-05 11:27:19 +08:00
606f4e6c9e Refa: improve TOC building with better error handling (#12427)
### What problem does this PR solve?

Refactor TOC building logic to use enumerate instead of while loop, add
comprehensive error handling for missing/invalid chunk_id values, and
improve logging with more specific error messages. The changes make the
code more robust against malformed TOC data while maintaining the same
functionality for valid inputs.

### Type of change

- [x] Refactoring
2026-01-05 10:02:42 +08:00
4cd4526492 Feat: PDF vision figure parser supports reading context (#12416)
### What problem does this PR solve?

PDF vision figure parser supports reading context.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2026-01-05 09:55:43 +08:00
cc8a10376a Refactor: Refactoring VolcEngine and Yiyan modal using shadcn. #10427 (#12426)
### What problem does this PR solve?

Refactor: Refactoring VolcEngine and Yiyan modal using shadcn. #10427
### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2026-01-05 09:53:47 +08:00
5ebe334a2f Refactor setting type (#12425)
### What problem does this PR solve?

Refactor setting type

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2026-01-04 20:26:12 +08:00
932496a8ec Fix: bug fixes (#12423)
### What problem does this PR solve?
Changes:
- Initialize the webhook configuration in the webhook function
- Remove a debug print statement from airtable_connector
- Remove the redundant uuid import in imap_connector

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-04 19:16:29 +08:00
a8a060676a Refactor: UmiJs -> Vite (#12410)
### What problem does this PR solve?

Refactor: UmiJs -> Vite+React

### Type of change

- [x] Refactoring

---------

Co-authored-by: Liu An <asiro@qq.com>
2026-01-04 19:14:20 +08:00
2c10ccd622 Chore(compose): remove unnecessary history_data_agent volume mount (#12418)
### What problem does this PR solve?

Removed the volume mount mapping
../history_data_agent:/ragflow/history_data_agent from
docker-compose.yml as it appears to be no longer in use

### Type of change

- [x] Chore
2026-01-04 16:58:23 +08:00
a2211c200d Feat: message write testcase (#12417)
### What problem does this PR solve?

Write testcase for message web apis.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
2026-01-04 16:52:44 +08:00
21ba9e6d72 doc: update admin CLI (#12413)
### What problem does this PR solve?

`SHOW VERSION;`
- Display the current RAGFlow version.

`GRANT ADMIN <username>`
- Grant administrator privileges to the specified user.

`REVOKE ADMIN <username>`
- Revoke administrator privileges from the specified user.

`LIST VARS`
- List all system configurations and settings.

`SHOW VAR <var_name>`
- Display the content of a specific system configuration/setting by its
name or name prefix.

`SET VAR <var_name> <var_value>`
- Set the value for a specified configuration item.

related to: #12409 

### Type of change

- [x] Documentation Update

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2026-01-04 15:22:01 +08:00
ac9113b0ef feature: add system setting service (#12408)
### What problem does this PR solve?

#12409 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2026-01-04 14:21:39 +08:00
11779697de Test: get message content testcase (#12403)
### What problem does this PR solve?

Testcase for get_message_content api.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
2026-01-04 11:25:24 +08:00
d6e006f086 Improve task executor heartbeat handling and cleanup (#12390)
Improve task executor heartbeat handling and cleanup.

### What problem does this PR solve?

- **Reduce lock contention during executor cleanup**: The cleanup lock
is acquired only when removing expired executors, not during regular
heartbeat reporting, reducing potential lock contention.

- **Optimize own heartbeat cleanup**: Each executor removes its own
expired heartbeat using `zremrangebyscore` instead of `zcount` +
`zpopmin`, reducing Redis operations and improving efficiency.

- **Improve cleanup of other executors' heartbeats**: Expired executors
are detected by checking their latest heartbeat, and stale entries are
removed safely.

- **Other improvements**: IP address and PID are captured once at
startup, and unnecessary global declarations are removed.
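
For instance, the self-cleanup can collapse to one ranged delete with redis-py (the key name and expiry window below are assumptions):

```python
import time

import redis

HEARTBEAT_KEY = "task_executor_heartbeats"  # assumed key name
EXPIRE_SECS = 120                           # assumed expiry window

def report_heartbeat(r: redis.Redis, executor_id: str) -> None:
    now = time.time()
    # One sorted-set write records the fresh heartbeat...
    r.zadd(HEARTBEAT_KEY, {f"{executor_id}:{now}": now})
    # ...and one zremrangebyscore drops everything stale, replacing
    # the previous zcount + zpopmin loop with a single Redis call.
    r.zremrangebyscore(HEARTBEAT_KEY, "-inf", now - EXPIRE_SECS)
```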

### Type of change

- [x] Performance Improvement

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
2026-01-04 11:24:05 +08:00
d39fa75d36 Fix: Not able to add MCP Server [#12394](https://github.com/infiniflow/ragflow/issues/12394) (#12406)
### What problem does this PR solve?

Fix: Not able to add MCP Server
[#12394](https://github.com/infiniflow/ragflow/issues/12394)

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-04 11:22:34 +08:00
f56bceb2a9 Fix: remove async wrappers (#12405)
### What problem does this PR solve?

Fix: remove async wrappers #12396

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-01-04 11:19:48 +08:00
Rin
bbaf918d74 security: harden OpenDAL SQL initialization against injection (#12393)
Eliminates SQL injection vectors in the OpenDAL MySQL initialization
logic by implementing strict input validation and explicit type casting.

**Modifications:**
1. **`init_db_config`**: Enforced integer casting for
`max_allowed_packet` before formatting it into the SQL string.
2. **`init_opendal_mysql_table`**: Implemented regex-based validation
for `table_name` to ensure only alphanumeric characters and underscores
are permitted, preventing arbitrary SQL command injection through
configuration parameters.

These changes ensure that even if configuration values are sourced from
untrusted environments, the database initialization remains secure.
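
Roughly, the two guards amount to this (a standalone sketch; names mirror the description above, not the exact code):

```python
import re

def cast_max_allowed_packet(value) -> int:
    # Explicit int() cast: a non-numeric config value raises here
    # instead of being interpolated into the SQL statement.
    return int(value)

def validate_table_name(table_name: str) -> str:
    # Identifiers cannot be bound as query parameters, so allow-list
    # the characters instead: letters, digits, and underscores only.
    if not re.fullmatch(r"[A-Za-z0-9_]+", table_name):
        raise ValueError(f"invalid table name: {table_name!r}")
    return table_name
```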
2026-01-04 11:19:26 +08:00
KKM
89a97be2c5 Remove duplicated tag_feas assignment in create route (api/apps/chunk_app.py) (#12392)
### What problem does this PR solve?

This PR removes a duplicated assignment of `tag_feas` in the
`@manager.route('/create')` API handler located in
`api/apps/chunk_app.py`.

The same conditional block was unintentionally repeated twice, which had
no
functional impact but reduced code readability and maintainability.
This change eliminates the redundancy while preserving existing
behavior.

### Type of change

- [x] Refactoring

Co-authored-by: 김경만 <kmkim7@humaxit.com>
2026-01-04 10:32:36 +08:00
6f2fc2f1cb refactor: reorder logic in clean_gen_conf (#12391)
### What problem does this PR solve?

Reorder logic in clean_gen_conf
#12388

### Type of change
- [x] Refactoring
2026-01-04 10:31:56 +08:00
42da080d89 Fix: Fixed the issue where the upload DSL dialog box was too narrow. #10427 (#12384)
### What problem does this PR solve?

Fix: Fixed the issue where the upload DSL dialog box was too narrow.
#10427

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
2026-01-04 09:40:59 +08:00
1f4a17863f Feat: read web api testcases (#12383)
### What problem does this PR solve?

Web API testcase for list_messages, get_recent_message.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
2026-01-01 12:52:40 +08:00
4d3a3a97ef Update HELP command of ADMIN CLI (#12387)
### What problem does this PR solve?

As title.

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2026-01-01 12:52:13 +08:00
ff1020ccfb ADMIN CLI: support grant/revoke user admin authorization (#12381)
### What problem does this PR solve?

```
admin> grant admin 'aaa@aaa1.com';
Fail to grant aaa@aaa1.com admin authorization, code: 404, message: User 'aaa@aaa1.com' not found
admin> grant admin 'aaa@aaa.com';
Grant successfully!
admin> revoke admin 'aaa1@aaa.com';
Fail to revoke aaa1@aaa.com admin authorization, code: 404, message: User 'aaa1@aaa.com' not found
admin> revoke admin 'aaa@aaa.com';
Revoke successfully!
admin> revoke admin 'aaa@aaa.com';
aaa@aaa.com isn't superuser, yet!
admin> grant admin 'aaa@aaa.com';
Grant successfully!
admin> grant admin 'aaa@aaa.com';
aaa@aaa.com is already superuser!
admin> revoke admin 'aaa@aaa.com';
Revoke successfully!

```

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
2026-01-01 12:49:34 +08:00
ca3bd2cf9f Update README (#12386)
### Type of change

- [x] Documentation Update
2025-12-31 20:07:40 +08:00
eb661c028d Fix Tika version mismatch in Dockerfile.deps (3.0.0 → 3.2.3) (#12267)
Fixes #12266 

Dockerfile.deps still referenced `tika-server-standard-3.0.0.jar` even
after
the project moved to Tika 3.2.3 for security reasons.

This caused Docker builds to fail due to a version mismatch and missing
artifact.

Changes:
- Update Dockerfile.deps to consistently use Tika 3.2.3

No functional changes beyond dependency alignment.

Co-authored-by: Liu An <asiro@qq.com>
2025-12-31 19:55:39 +08:00
10c28c5ecd Feat: Refactoring the documentation page using shadcn. #10427 (#12376)
### What problem does this PR solve?

Feat: Refactoring the documentation page using shadcn. #10427

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-12-31 19:00:37 +08:00
96810b7d97 Fix: webdav connector (#12380)
### What problem does this PR solve?

fix webdav #11422

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-31 19:00:00 +08:00
365f9b01ae Fix: metadata data synchronization issues; add memory tab in home page (#12368)
### What problem does this PR solve?

fix: metadata data synchronization issues; add memory tab in home page

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-31 17:19:04 +08:00
7d4d687dde Feat: Bitbucket connector (#12332)
### What problem does this PR solve?

Feat: Bitbucket connector NOT READY TO MERGE

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-12-31 17:18:30 +08:00
6a664fea3b Docs: Updated v0.23.0 release notes (#12374)
### What problem does this PR solve?


### Type of change


- [x] Documentation Update
2025-12-31 17:10:15 +08:00
dcdc1b0ec7 Fix urls for basic docs (#12372)
### Type of change

- [x] Documentation Update
2025-12-31 17:02:34 +08:00
4af4c36e60 Docs: Added v0.23.1 release notes (#12371)
### What problem does this PR solve?


### Type of change

- [x] Documentation Update
2025-12-31 16:43:56 +08:00
05e5244d94 Refactor docs of RAGFlow admin (#12361)
### What problem does this PR solve?

as title

### Type of change

- [x] Documentation Update

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-31 14:42:53 +08:00
c2ee2bf7fe Feat: add Zendesk data source integration with configuration and sync capabilities (#12344)
### What problem does this PR solve?
issue:
#12313
change:
add Zendesk data source integration with configuration and sync
capabilities

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-31 14:40:49 +08:00
461c81e14a Fix: KG search issue. (#12364)
### What problem does this PR solve?

Close #12347

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-31 14:40:27 +08:00
675d18d359 Add docs category file (#12359)
### What problem does this PR solve?

As title.

### Type of change

- [x] Documentation Update

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-31 13:56:11 +08:00
750335978c Fix: Batch parsing problem (#12358)
### What problem does this PR solve?

Fix: Batch parsing problem

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-31 13:52:50 +08:00
ae7c623a35 fix(rag/prompts): Restructure metadata extraction rules for precision (#12360)
### What problem does this PR solve?

- Simplified and consolidated extraction rules
- Emphasized strict evidence-based extraction only
- Strengthened enum handling and hallucination prevention
- Clarified output requirements for empty results

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-31 13:52:33 +08:00
f24bdc0f83 Remove doc of health check (#12363)
### What problem does this PR solve?

System web page is disabled since v0.22.0, and the health check API is
also described in API reference. This document is obsolete.

### Type of change

- [x] Documentation Update

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-31 13:49:38 +08:00
07ef35b7e6 Docs: Update version references to v0.23.1 in READMEs and docs (#12349)
### What problem does this PR solve?

- Update version tags in README files (including translations) from
v0.23.0 to v0.23.1
- Modify Docker image references and documentation to reflect new
version
- Update version badges and image descriptions
- Maintain consistency across all language variants of README files

### Type of change

- [x] Documentation Update
2025-12-31 12:49:42 +08:00
7c9823a1ff Update release notes (#12356)
### What problem does this PR solve?

As title

### Type of change

- [x] Documentation Update

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-31 12:49:09 +08:00
a0c3bcf798 [Bug] Don't display not used component status in admin status dashboard (#12355)
### What problem does this PR solve?

Currently, all components in the configs are displayed even if they are not
used. This PR fixes that.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-31 12:40:52 +08:00
1a4a7d1705 Fix: apply kb configured llm issue. (#12354)
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-31 12:40:28 +08:00
f141947085 [Feature] Admin: sort user list by email (#12350)
### What problem does this PR solve?

Sort the user list by user email address.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-31 11:55:18 +08:00
a07e947644 Feat: Fixed the issue where the newly created agent begin node displayed "undefined". #10427 (#12348)
### What problem does this PR solve?
Feat: Fixed the issue where the newly created agent begin node displayed
"undefined". #10427

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-12-31 11:19:33 +08:00
ae4692a845 Fix: Bug fixed (#12345)
### What problem does this PR solve?

Fix: Bug fixed
- Memory type multilingual display
- Name modification is prohibited in Data source
- Jump directly to Metadata settings

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-31 10:43:57 +08:00
7dac269429 fix: correct session reference initialization to prevent dialogue misalignment (#12343)
## Summary

Fixes #12311

Changes the `reference` field initialization from `[{}]` to `[]` in
session creation.

### Problem

When creating a session via the SDK API, the `reference` field was
incorrectly initialized as `[{}]`. This caused:
- First dialogue round: Empty reference
- Second dialogue round: Reference pointing to first round's data
- Overall misalignment between dialogue rounds and their references

### Solution

Changed the initialization to `[]` (empty list), which:
- Matches the `Conversation` model's expected default
- Ensures references grow correctly one-to-one with assistant responses
- Aligns with the service layer's expectations
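
The misalignment is easy to reproduce in a toy model (a hedged sketch; only the `reference` field mirrors the real schema):

```python
# With reference = [{}], the first assistant answer's chunks land at
# index 1, so round N reads the references stored for round N-1.
conversation = {"message": [], "reference": []}  # fixed: start empty

def add_round(question: str, answer: str, chunks: list) -> None:
    conversation["message"].append({"role": "user", "content": question})
    conversation["message"].append({"role": "assistant", "content": answer})
    # One reference entry per assistant turn keeps the two lists aligned.
    conversation["reference"].append(chunks)
```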

### Testing

After applying this fix:
1. Create a session via `POST /api/v1/chats/{conversation_id}/sessions`
2. Send multiple questions via `POST
/api/v1/chats/{conversation_id}/completions`
3. View the conversation on web - references should now align correctly
with each dialogue round
2025-12-31 10:25:49 +08:00
ec5575dce2 fix(admin-ui): pagination auto-resets to first page after refetching data (#12339)
### What problem does this PR solve?

When an admin enables or disables a user, the user list resets to the first page.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-31 09:39:13 +08:00
6fee60e110 Docs: What is RAG & What is Agent context engine (#12341)
### Type of change

- [x] Documentation Update
2025-12-30 21:29:21 +08:00
52f91c2388 Refine: image/table context. (#12336)
### What problem does this PR solve?

#12303

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-30 20:24:27 +08:00
348265afc1 Feat: Display mode at the begin node #10427 (#12326)
### What problem does this PR solve?

Feat: Display mode at the begin node #10427
### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-12-30 19:53:24 +08:00
a7e466142d Fix: Dataset parse logic (#12330)
### What problem does this PR solve?

Fix: Dataset parser logic

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-30 19:53:00 +08:00
2fccf3924d Feat: Adapt the theme of the documentation page. #10427 (#12337)
### What problem does this PR solve?

Feat: Adapt the theme of the documentation page. #10427
### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-12-30 19:35:44 +08:00
4705d07e11 fix: malformed dynamic translation key chunk.docType.${chunkType} (#12329)
### What problem does this PR solve?

The back-end may return an empty array for the `"doc_type_kwd"` property,
which produces a malformed translation key.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-30 19:33:20 +08:00
68be3b9a3d Update release workflow (#12335)
### What problem does this PR solve?

As title

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-30 19:06:27 +08:00
e2d17d808b Potential fix for code scanning alert no. 62: Workflow does not contain permissions (#12334)
Potential fix for
[https://github.com/infiniflow/ragflow/security/code-scanning/62](https://github.com/infiniflow/ragflow/security/code-scanning/62)

In general, the fix is to explicitly declare a `permissions:` block so
the GITHUB_TOKEN used by this workflow only has the scopes required:
read access to repository contents and write access to
contents/releases. Since this workflow creates or moves tags and
creates/overwrites releases via `softprops/action-gh-release`, it needs
`contents: write`. There is no evidence that it needs other elevated
scopes (issues, pull-requests, actions, etc.), so these should remain at
their default of `none` by omission.

The best minimal fix without changing existing functionality is to add a
workflow-level `permissions:` block near the top of
`.github/workflows/release.yml`, after `name:` and before `on:` (or
anywhere at the root level, but this is conventional). This will apply
to all jobs (there is only `jobs.release`) and ensure that the
GITHUB_TOKEN has only `contents: write`. No additional imports or
methods are needed because this is a YAML configuration change only.

Concretely:
- Edit `.github/workflows/release.yml`.
- Insert:

```yaml
permissions:
  contents: write
```

between line 2 (empty line after `name: release`) and line 3 (`on:`). No
other lines need to be changed.


_Suggested fixes powered by Copilot Autofix. Review carefully before
merging._

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
2025-12-30 18:59:51 +08:00
95edbd43ba Update model providers (#12333)
### What problem does this PR solve?

As title

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-30 18:51:35 +08:00
b96d553cd8 Update release workflow (#12327)
### What problem does this PR solve?

As title.

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-30 17:25:27 +08:00
bffdb5fb11 Feat: add IMAP data source integration with configuration and sync capabilities (#12316)
### What problem does this PR solve?
issue:
#12217 [#12313](https://github.com/infiniflow/ragflow/issues/12313)
change:
add IMAP data source integration with configuration and sync
capabilities

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-30 17:09:13 +08:00
109e782493 Feat: On the agent page and chat page, you can only select knowledge bases that use the same embedding model. #12320 (#12321)
### What problem does this PR solve?

Feat: On the agent page and chat page, you can only select knowledge
bases that use the same embedding model. #12320

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-12-30 17:08:30 +08:00
ff2c70608d Fix: judge index exist before delete memory message. (#12318)
### What problem does this PR solve?

Judge index exist before delete memory message.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-30 15:54:07 +08:00
5903d1c8f1 Feat: GitHub connector (#12314)
### What problem does this PR solve?

Feat: GitHub connector

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-30 15:09:52 +08:00
f0392e7501 Fix IDE warnings (#12315)
### What problem does this PR solve?

As title.

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-30 15:04:09 +08:00
4037788e0c Fix: Dataset parse error (#12310)
### What problem does this PR solve?

Fix: Dataset parse error

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-30 13:08:20 +08:00
59884ab0fb Fix TypeError in meta_filter when using numeric metadata (#12286)
The filter_out function in metadata_utils.py was using a list of tuples
to evaluate conditions. Python eagerly evaluates all tuple elements when
constructing the list, causing "input in value" to be evaluated even
when the operator is "=". When input and value are floats (after numeric
conversion), this causes TypeError: "argument of type 'float' is not
iterable".

This change replaces the tuple list with if-elif chain, ensuring only
the matching condition is evaluated.

### What problem does this PR solve?

Fixes #12285

When using comparison operators like `=`, `>`, `<` with numeric
metadata, the `filter_out` function throws `TypeError("argument of type
'float' is not iterable")`. This is because Python eagerly evaluates all
tuple elements when constructing a list, causing `input in value` to be
evaluated even when the operator is `=`.
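
In miniature, the difference is lazy vs. eager evaluation (a hedged sketch, not the exact `filter_out` body):

```python
def matches(op: str, doc_value, query_value) -> bool:
    # The old tuple-list form computed every candidate result up front,
    # so "doc_value in query_value" ran even for numeric "=" filters.
    # An if-elif chain evaluates only the branch that was selected.
    if op == "=":
        return doc_value == query_value
    elif op == ">":
        return doc_value > query_value
    elif op == "<":
        return doc_value < query_value
    elif op == "contains":
        return doc_value in query_value
    raise ValueError(f"unknown operator: {op}")
```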

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [ ] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [ ] Performance Improvement
- [ ] Other (please describe):
2025-12-30 11:56:48 +08:00
4a6d37f0e8 Fix: use async task to save memory (#12308)
### What problem does this PR solve?

Use async task to save memory.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Jin Hai <haijin.chn@gmail.com>
2025-12-30 11:41:38 +08:00
731e2d5f26 api key delete bug - Bug #3045 (#12299)
Description:
Fixed an issue where deleting an API token would fail because it was
incorrectly using current_user.id as the tenant_id instead of querying
the actual tenant ID from UserTenantService.

Changes:

- Updated the rm() endpoint to fetch the correct tenant_id from UserTenantService before deleting the API token
- Added proper error handling with a try/except block
- Code style cleanup: consistent quote usage and formatting
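
A self-contained toy of the corrected lookup (in-memory stand-ins for UserTenantService and the token table):

```python
# Illustrative data only; the real code queries UserTenantService.
user_tenants = {"user-1": "tenant-42"}   # user_id -> tenant_id
api_tokens = {("tenant-42", "tok-abc")}  # (tenant_id, token) pairs

def delete_api_token(user_id: str, token: str) -> bool:
    # The bug: tenant_id was taken to be current_user.id itself.
    tenant_id = user_tenants.get(user_id)  # the fix: resolve it first
    if tenant_id is None:
        return False
    try:
        api_tokens.remove((tenant_id, token))
        return True
    except KeyError:
        return False
```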
Related Issue: #3045

https://github.com/infiniflow/ragflow/issues/3045

Co-authored-by: Mardani, Ramin <ramin.mardani@sscinc.com>
2025-12-30 11:27:04 +08:00
df3cbb9b9e Refactor code (#12305)
### What problem does this PR solve?

as title

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-30 11:09:18 +08:00
5402666b19 docs: fix typos (#12301)
### What problem does this PR solve?

fix typos

### Type of change

- [x] Documentation Update
2025-12-30 09:39:28 +08:00
744 changed files with 46745 additions and 42333 deletions

View File

@@ -10,6 +10,12 @@ on:
   tags:
     - "v*.*.*" # normal release
+permissions:
+  contents: write
+  actions: read
+  checks: read
+  statuses: read
 # https://docs.github.com/en/actions/using-jobs/using-concurrency
 concurrency:
   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@@ -76,6 +82,14 @@ jobs:
          # The body field does not support environment variable substitution directly.
          body_path: release_body.md
+      - name: Build and push image
+        run: |
+          sudo docker login --username infiniflow --password-stdin <<< ${{ secrets.DOCKERHUB_TOKEN }}
+          sudo docker build --build-arg NEED_MIRROR=1 --build-arg HTTPS_PROXY=${HTTPS_PROXY} --build-arg HTTP_PROXY=${HTTP_PROXY} -t infiniflow/ragflow:${RELEASE_TAG} -f Dockerfile .
+          sudo docker tag infiniflow/ragflow:${RELEASE_TAG} infiniflow/ragflow:latest
+          sudo docker push infiniflow/ragflow:${RELEASE_TAG}
+          sudo docker push infiniflow/ragflow:latest
       - name: Build and push ragflow-sdk
         if: startsWith(github.ref, 'refs/tags/v')
         run: |
@@ -85,11 +99,3 @@ jobs:
         if: startsWith(github.ref, 'refs/tags/v')
         run: |
           cd admin/client && uv build && uv publish --token ${{ secrets.PYPI_API_TOKEN }}
-      - name: Build and push image
-        run: |
-          sudo docker login --username infiniflow --password-stdin <<< ${{ secrets.DOCKERHUB_TOKEN }}
-          sudo docker build --build-arg NEED_MIRROR=1 --build-arg HTTPS_PROXY=${HTTPS_PROXY} --build-arg HTTP_PROXY=${HTTP_PROXY} -t infiniflow/ragflow:${RELEASE_TAG} -f Dockerfile .
-          sudo docker tag infiniflow/ragflow:${RELEASE_TAG} infiniflow/ragflow:latest
-          sudo docker push infiniflow/ragflow:${RELEASE_TAG}
-          sudo docker push infiniflow/ragflow:latest

View File

@@ -86,6 +86,9 @@ jobs:
            mkdir -p ${RUNNER_WORKSPACE_PREFIX}/artifacts/${GITHUB_REPOSITORY}
            echo "${PR_SHA} ${GITHUB_RUN_ID}" > ${PR_SHA_FP}
          fi
+         ARTIFACTS_DIR=${RUNNER_WORKSPACE_PREFIX}/artifacts/${GITHUB_REPOSITORY}/${GITHUB_RUN_ID}
+         echo "ARTIFACTS_DIR=${ARTIFACTS_DIR}" >> ${GITHUB_ENV}
+         rm -rf ${ARTIFACTS_DIR} && mkdir -p ${ARTIFACTS_DIR}
      # https://github.com/astral-sh/ruff-action
      - name: Static check with Ruff
@@ -161,7 +164,7 @@ jobs:
          INFINITY_THRIFT_PORT=$((23817 + RUNNER_NUM * 10))
          INFINITY_HTTP_PORT=$((23820 + RUNNER_NUM * 10))
          INFINITY_PSQL_PORT=$((5432 + RUNNER_NUM * 10))
-         MYSQL_PORT=$((5455 + RUNNER_NUM * 10))
+         EXPOSE_MYSQL_PORT=$((5455 + RUNNER_NUM * 10))
          MINIO_PORT=$((9000 + RUNNER_NUM * 10))
          MINIO_CONSOLE_PORT=$((9001 + RUNNER_NUM * 10))
          REDIS_PORT=$((6379 + RUNNER_NUM * 10))
@@ -181,7 +184,7 @@ jobs:
          echo -e "INFINITY_THRIFT_PORT=${INFINITY_THRIFT_PORT}" >> docker/.env
          echo -e "INFINITY_HTTP_PORT=${INFINITY_HTTP_PORT}" >> docker/.env
          echo -e "INFINITY_PSQL_PORT=${INFINITY_PSQL_PORT}" >> docker/.env
-         echo -e "MYSQL_PORT=${MYSQL_PORT}" >> docker/.env
+         echo -e "EXPOSE_MYSQL_PORT=${EXPOSE_MYSQL_PORT}" >> docker/.env
          echo -e "MINIO_PORT=${MINIO_PORT}" >> docker/.env
          echo -e "MINIO_CONSOLE_PORT=${MINIO_CONSOLE_PORT}" >> docker/.env
          echo -e "REDIS_PORT=${REDIS_PORT}" >> docker/.env
@@ -211,14 +214,14 @@ jobs:
          done
          source .venv/bin/activate && set -o pipefail; pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api 2>&1 | tee es_sdk_test.log
-     - name: Run frontend api tests against Elasticsearch
+     - name: Run web api tests against Elasticsearch
        run: |
          export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
          until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null; do
            echo "Waiting for service to be available..."
            sleep 5
          done
-         source .venv/bin/activate && set -o pipefail; pytest -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py 2>&1 | tee es_api_test.log
+         source .venv/bin/activate && set -o pipefail; pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_web_api/ 2>&1 | tee es_web_api_test.log
      - name: Run http api tests against Elasticsearch
        run: |
@@ -229,6 +232,13 @@ jobs:
          done
          source .venv/bin/activate && set -o pipefail; pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee es_http_api_test.log
+     - name: Collect ragflow log
+       if: ${{ !cancelled() }}
+       run: |
+         cp -r docker/ragflow-logs ${ARTIFACTS_DIR}/ragflow-logs-es
+         echo "ragflow log" && tail -n 200 docker/ragflow-logs/ragflow_server.log
+         sudo rm -rf docker/ragflow-logs
      - name: Stop ragflow:nightly
        if: always() # always run this step even if previous steps failed
        run: |
@@ -249,14 +259,14 @@ jobs:
          done
          source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api 2>&1 | tee infinity_sdk_test.log
-     - name: Run frontend api tests against Infinity
+     - name: Run web api tests against Infinity
        run: |
          export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
          until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null; do
            echo "Waiting for service to be available..."
            sleep 5
          done
-         source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py 2>&1 | tee infinity_api_test.log
+         source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_web_api/ 2>&1 | tee infinity_web_api_test.log
      - name: Run http api tests against Infinity
        run: |
@@ -267,6 +277,12 @@ jobs:
          done
          source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee infinity_http_api_test.log
+     - name: Collect ragflow log
+       if: ${{ !cancelled() }}
+       run: |
+         cp -r docker/ragflow-logs ${ARTIFACTS_DIR}/ragflow-logs-infinity
+         echo "ragflow log" && tail -n 200 docker/ragflow-logs/ragflow_server.log
+         sudo rm -rf docker/ragflow-logs
      - name: Stop ragflow:nightly
        if: always() # always run this step even if previous steps failed
        run: |

.gitignore vendored
View File

@@ -44,6 +44,7 @@ cl100k_base.tiktoken
 chrome*
 huggingface.co/
 nltk_data/
+uv-x86_64*.tar.gz
 # Exclude hash-like temporary files like 9b5ad71b2ce5302211f9c61530b329a4922fc6a4
 *[0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f]*
@@ -51,6 +52,13 @@ nltk_data/
 .venv
 docker/data
+# OceanBase data and conf
+docker/oceanbase/conf
+docker/oceanbase/data
+# SeekDB data and conf
+docker/seekdb
 #--------------------------------------------------#
 # The following was generated with gitignore.nvim: #
@@ -198,3 +206,8 @@ backup
 .hypothesis
+# Added by cargo
+/target

View File

@@ -19,17 +19,17 @@ RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/huggingface.co
 # This is the only way to run python-tika without internet access. Without this set, the default is to check the tika version and pull latest every time from Apache.
 RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/,target=/deps \
     cp -r /deps/nltk_data /root/ && \
-    cp /deps/tika-server-standard-3.0.0.jar /deps/tika-server-standard-3.0.0.jar.md5 /ragflow/ && \
+    cp /deps/tika-server-standard-3.2.3.jar /deps/tika-server-standard-3.2.3.jar.md5 /ragflow/ && \
     cp /deps/cl100k_base.tiktoken /ragflow/9b5ad71b2ce5302211f9c61530b329a4922fc6a4
-ENV TIKA_SERVER_JAR="file:///ragflow/tika-server-standard-3.0.0.jar"
+ENV TIKA_SERVER_JAR="file:///ragflow/tika-server-standard-3.2.3.jar"
 ENV DEBIAN_FRONTEND=noninteractive
 # Setup apt
 # Python package and implicit dependencies:
 # opencv-python: libglib2.0-0 libglx-mesa0 libgl1
 # aspose-slides: pkg-config libicu-dev libgdiplus libssl1.1_1.1.1f-1ubuntu2_amd64.deb
-# python-pptx: default-jdk tika-server-standard-3.0.0.jar
+# python-pptx: default-jdk tika-server-standard-3.2.3.jar
 # selenium: libatk-bridge2.0-0 chrome-linux64-121-0-6167-85
 # Building C extensions: libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev
 RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
@@ -53,7 +53,8 @@ RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
     apt install -y ghostscript && \
     apt install -y pandoc && \
     apt install -y texlive && \
-    apt install -y fonts-freefont-ttf fonts-noto-cjk
+    apt install -y fonts-freefont-ttf fonts-noto-cjk && \
+    apt install -y postgresql-client
 # Install uv
 RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/,target=/deps \
@@ -64,10 +65,12 @@ RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/,target=/deps
         echo 'url = "https://pypi.tuna.tsinghua.edu.cn/simple"' >> /etc/uv/uv.toml && \
         echo 'default = true' >> /etc/uv/uv.toml; \
     fi; \
-    tar xzf /deps/uv-x86_64-unknown-linux-gnu.tar.gz \
-    && cp uv-x86_64-unknown-linux-gnu/* /usr/local/bin/ \
-    && rm -rf uv-x86_64-unknown-linux-gnu \
-    && uv python install 3.11
+    arch="$(uname -m)"; \
+    if [ "$arch" = "x86_64" ]; then uv_arch="x86_64"; else uv_arch="aarch64"; fi; \
+    tar xzf "/deps/uv-${uv_arch}-unknown-linux-gnu.tar.gz" \
+    && cp "uv-${uv_arch}-unknown-linux-gnu/"* /usr/local/bin/ \
+    && rm -rf "uv-${uv_arch}-unknown-linux-gnu" \
+    && uv python install 3.12
 ENV PYTHONDONTWRITEBYTECODE=1 DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=1
 ENV PATH=/root/.local/bin:$PATH
@@ -152,11 +155,14 @@ RUN --mount=type=cache,id=ragflow_uv,target=/root/.cache/uv,sharing=locked \
     else \
         sed -i 's|pypi.tuna.tsinghua.edu.cn|pypi.org|g' uv.lock; \
     fi; \
-    uv sync --python 3.12 --frozen
+    uv sync --python 3.12 --frozen && \
+    # Ensure pip is available in the venv for runtime package installation (fixes #12651)
+    .venv/bin/python3 -m ensurepip --upgrade
 COPY web web
 COPY docs docs
 RUN --mount=type=cache,id=ragflow_npm,target=/root/.npm,sharing=locked \
+    export NODE_OPTIONS="--max-old-space-size=4096" && \
     cd web && npm install && npm run build
 COPY .git /ragflow/.git
@@ -187,7 +193,6 @@ COPY deepdoc deepdoc
 COPY rag rag
 COPY agent agent
 COPY graphrag graphrag
-COPY agentic_reasoning agentic_reasoning
 COPY pyproject.toml uv.lock ./
 COPY mcp mcp
 COPY plugin plugin

View File

@@ -3,7 +3,7 @@
 FROM scratch
 # Copy resources downloaded via download_deps.py
-COPY chromedriver-linux64-121-0-6167-85 chrome-linux64-121-0-6167-85 cl100k_base.tiktoken libssl1.1_1.1.1f-1ubuntu2_amd64.deb libssl1.1_1.1.1f-1ubuntu2_arm64.deb tika-server-standard-3.0.0.jar tika-server-standard-3.0.0.jar.md5 libssl*.deb uv-x86_64-unknown-linux-gnu.tar.gz /
+COPY chromedriver-linux64-121-0-6167-85 chrome-linux64-121-0-6167-85 cl100k_base.tiktoken libssl1.1_1.1.1f-1ubuntu2_amd64.deb libssl1.1_1.1.1f-1ubuntu2_arm64.deb tika-server-standard-3.2.3.jar tika-server-standard-3.2.3.jar.md5 libssl*.deb uv-x86_64-unknown-linux-gnu.tar.gz uv-aarch64-unknown-linux-gnu.tar.gz /
 COPY nltk_data /nltk_data

View File

@ -22,7 +22,7 @@
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99"> <img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
</a> </a>
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank"> <a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.23.0"> <img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.23.1">
</a> </a>
<a href="https://github.com/infiniflow/ragflow/releases/latest"> <a href="https://github.com/infiniflow/ragflow/releases/latest">
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release"> <img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
@ -37,7 +37,7 @@
<h4 align="center"> <h4 align="center">
<a href="https://ragflow.io/docs/dev/">Document</a> | <a href="https://ragflow.io/docs/dev/">Document</a> |
<a href="https://github.com/infiniflow/ragflow/issues/4214">Roadmap</a> | <a href="https://github.com/infiniflow/ragflow/issues/12241">Roadmap</a> |
<a href="https://twitter.com/infiniflowai">Twitter</a> | <a href="https://twitter.com/infiniflowai">Twitter</a> |
<a href="https://discord.gg/NjYzJD3GM3">Discord</a> | <a href="https://discord.gg/NjYzJD3GM3">Discord</a> |
<a href="https://demo.ragflow.io">Demo</a> <a href="https://demo.ragflow.io">Demo</a>
@ -72,7 +72,7 @@
## 💡 What is RAGFlow?
-[RAGFlow](https://ragflow.io/) is a leading open-source Retrieval-Augmented Generation (RAG) engine that fuses cutting-edge RAG with Agent capabilities to create a superior context layer for LLMs. It offers a streamlined RAG workflow adaptable to enterprises of any scale. Powered by a converged context engine and pre-built agent templates, RAGFlow enables developers to transform complex data into high-fidelity, production-ready AI systems with exceptional efficiency and precision.
+[RAGFlow](https://ragflow.io/) is a leading open-source Retrieval-Augmented Generation ([RAG](https://ragflow.io/basics/what-is-rag)) engine that fuses cutting-edge RAG with Agent capabilities to create a superior context layer for LLMs. It offers a streamlined RAG workflow adaptable to enterprises of any scale. Powered by a converged [context engine](https://ragflow.io/basics/what-is-agent-context-engine) and pre-built agent templates, RAGFlow enables developers to transform complex data into high-fidelity, production-ready AI systems with exceptional efficiency and precision.
## 🎮 Demo
@ -188,12 +188,12 @@ releases! 🌟
> All Docker images are built for x86 platforms. We don't currently offer Docker images for ARM64.
> If you are on an ARM64 platform, follow [this guide](https://ragflow.io/docs/dev/build_docker_image) to build a Docker image compatible with your system.
-> The command below downloads the `v0.23.0` edition of the RAGFlow Docker image. See the following table for descriptions of different RAGFlow editions. To download a RAGFlow edition different from `v0.23.0`, update the `RAGFLOW_IMAGE` variable accordingly in **docker/.env** before using `docker compose` to start the server.
+> The command below downloads the `v0.23.1` edition of the RAGFlow Docker image. See the following table for descriptions of different RAGFlow editions. To download a RAGFlow edition different from `v0.23.1`, update the `RAGFLOW_IMAGE` variable accordingly in **docker/.env** before using `docker compose` to start the server.
```bash
$ cd ragflow/docker
-# git checkout v0.23.0
+# git checkout v0.23.1
# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases)
# This step ensures the **entrypoint.sh** file in the code matches the Docker image version.
@ -233,7 +233,7 @@ releases! 🌟
* Running on all addresses (0.0.0.0)
```
-> If you skip this confirmation step and directly log in to RAGFlow, your browser may prompt a `network anormal`
+> If you skip this confirmation step and directly log in to RAGFlow, your browser may prompt a `network abnormal`
> error because, at that moment, your RAGFlow may not be fully initialized.
>
5. In your web browser, enter the IP address of your server and log in to RAGFlow.
@ -396,7 +396,7 @@ docker build --platform linux/amd64 \
## 📜 Roadmap
-See the [RAGFlow Roadmap 2025](https://github.com/infiniflow/ragflow/issues/4214)
+See the [RAGFlow Roadmap 2026](https://github.com/infiniflow/ragflow/issues/12241)
## 🏄 Community

View File

@ -22,7 +22,7 @@
<img alt="Lencana Daring" src="https://img.shields.io/badge/Online-Demo-4e6b99"> <img alt="Lencana Daring" src="https://img.shields.io/badge/Online-Demo-4e6b99">
</a> </a>
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank"> <a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.23.0"> <img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.23.1">
</a> </a>
<a href="https://github.com/infiniflow/ragflow/releases/latest"> <a href="https://github.com/infiniflow/ragflow/releases/latest">
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Rilis%20Terbaru" alt="Rilis Terbaru"> <img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Rilis%20Terbaru" alt="Rilis Terbaru">
@ -37,7 +37,7 @@
<h4 align="center"> <h4 align="center">
<a href="https://ragflow.io/docs/dev/">Dokumentasi</a> | <a href="https://ragflow.io/docs/dev/">Dokumentasi</a> |
<a href="https://github.com/infiniflow/ragflow/issues/4214">Peta Jalan</a> | <a href="https://github.com/infiniflow/ragflow/issues/12241">Peta Jalan</a> |
<a href="https://twitter.com/infiniflowai">Twitter</a> | <a href="https://twitter.com/infiniflowai">Twitter</a> |
<a href="https://discord.gg/NjYzJD3GM3">Discord</a> | <a href="https://discord.gg/NjYzJD3GM3">Discord</a> |
<a href="https://demo.ragflow.io">Demo</a> <a href="https://demo.ragflow.io">Demo</a>
@ -72,7 +72,7 @@
## 💡 What is RAGFlow?
-[RAGFlow](https://ragflow.io/) is a leading open-source RAG (Retrieval-Augmented Generation) engine that integrates cutting-edge RAG technology with Agent capabilities to create a superior contextual layer for LLMs. It provides an efficient RAG workflow adaptable to enterprises of any scale. Powered by a converged context engine and pre-built Agent templates, RAGFlow lets developers turn complex data into high-fidelity, production-ready AI systems with exceptional efficiency and precision.
+[RAGFlow](https://ragflow.io/) is a leading open-source [RAG](https://ragflow.io/basics/what-is-rag) (Retrieval-Augmented Generation) engine that integrates cutting-edge RAG technology with Agent capabilities to create a superior contextual layer for LLMs. It provides an efficient RAG workflow adaptable to enterprises of any scale. Powered by a converged context engine and pre-built Agent templates, RAGFlow lets developers turn complex data into high-fidelity, production-ready AI systems with exceptional efficiency and precision.
## 🎮 Demo
@ -188,12 +188,12 @@ Try our demo at [https://demo.ragflow.io](https://demo.ragflow.io).
> All Docker images are built for x86 platforms. We do not currently offer Docker images for ARM64.
> If you are on an ARM64 platform, [please use this guide to build a Docker image compatible with your system](https://ragflow.io/docs/dev/build_docker_image).
-> The command below downloads the v0.23.0 edition of the RAGFlow Docker image. Refer to the following table for descriptions of the different RAGFlow editions. To download a RAGFlow edition different from v0.23.0, update the RAGFLOW_IMAGE variable in docker/.env before using docker compose to start the server.
+> The command below downloads the v0.23.1 edition of the RAGFlow Docker image. Refer to the following table for descriptions of the different RAGFlow editions. To download a RAGFlow edition different from v0.23.1, update the RAGFLOW_IMAGE variable in docker/.env before using docker compose to start the server.
```bash
$ cd ragflow/docker
-# git checkout v0.23.0
+# git checkout v0.23.1
# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases)
# This step ensures the **entrypoint.sh** file in the code matches the Docker image version.
@ -233,7 +233,7 @@ Try our demo at [https://demo.ragflow.io](https://demo.ragflow.io).
* Running on all addresses (0.0.0.0)
```
-> If you skip this step and log in to RAGFlow directly, your browser may show a `network anormal` error
+> If you skip this step and log in to RAGFlow directly, your browser may show a `network abnormal` error
> because RAGFlow may not be fully ready.
>
2. Open your web browser, enter your server's IP address, and log in to RAGFlow.
@ -368,7 +368,7 @@ docker build --platform linux/amd64 \
## 📜 Roadmap
-See the [RAGFlow Roadmap 2025](https://github.com/infiniflow/ragflow/issues/4214)
+See the [RAGFlow Roadmap 2026](https://github.com/infiniflow/ragflow/issues/12241)
## 🏄 Community

View File

@ -22,7 +22,7 @@
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99"> <img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
</a> </a>
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank"> <a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.23.0"> <img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.23.1">
</a> </a>
<a href="https://github.com/infiniflow/ragflow/releases/latest"> <a href="https://github.com/infiniflow/ragflow/releases/latest">
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release"> <img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
@ -37,7 +37,7 @@
<h4 align="center"> <h4 align="center">
<a href="https://ragflow.io/docs/dev/">Document</a> | <a href="https://ragflow.io/docs/dev/">Document</a> |
<a href="https://github.com/infiniflow/ragflow/issues/4214">Roadmap</a> | <a href="https://github.com/infiniflow/ragflow/issues/12241">Roadmap</a> |
<a href="https://twitter.com/infiniflowai">Twitter</a> | <a href="https://twitter.com/infiniflowai">Twitter</a> |
<a href="https://discord.gg/NjYzJD3GM3">Discord</a> | <a href="https://discord.gg/NjYzJD3GM3">Discord</a> |
<a href="https://demo.ragflow.io">Demo</a> <a href="https://demo.ragflow.io">Demo</a>
@ -53,7 +53,7 @@
## 💡 What is RAGFlow?
-[RAGFlow](https://ragflow.io/) is a cutting-edge open-source RAG engine that fuses advanced RAG (Retrieval-Augmented Generation) technology with Agent capabilities to build a superior context layer for large language models (LLMs). It offers a streamlined RAG workflow suited to enterprises of any scale; with its converged context engine and pre-built Agent templates, it lets developers turn complex data into high-fidelity, production-ready AI systems with remarkable efficiency and precision.
+[RAGFlow](https://ragflow.io/) is a cutting-edge open-source RAG engine that fuses advanced [RAG](https://ragflow.io/basics/what-is-rag) (Retrieval-Augmented Generation) technology with Agent capabilities to build a superior context layer for large language models (LLMs). It offers a streamlined RAG workflow suited to enterprises of any scale; with its converged [context engine](https://ragflow.io/basics/what-is-agent-context-engine) and pre-built Agent templates, it lets developers turn complex data into high-fidelity, production-ready AI systems with remarkable efficiency and precision.
## 🎮 Demo
@ -168,12 +168,12 @@
> All officially provided Docker images are currently built for the x86 architecture; Docker images for ARM64 are not provided.
> If you are on an ARM64 operating system, refer to [this document](https://ragflow.io/docs/dev/build_docker_image) to build a Docker image yourself.
-> The command below downloads the v0.23.0 edition of the RAGFlow Docker image. See the table below for descriptions of the different RAGFlow editions. To download an edition other than v0.23.0, update the RAGFLOW_IMAGE variable in the docker/.env file accordingly, then use docker compose to start the server.
+> The command below downloads the v0.23.1 edition of the RAGFlow Docker image. See the table below for descriptions of the different RAGFlow editions. To download an edition other than v0.23.1, update the RAGFLOW_IMAGE variable in the docker/.env file accordingly, then use docker compose to start the server.
```bash
$ cd ragflow/docker
-# git checkout v0.23.0
+# git checkout v0.23.1
# Optional: use a stable tag (list: https://github.com/infiniflow/ragflow/releases)
# This step ensures the entrypoint.sh file in the code matches the Docker image version.
@ -368,7 +368,7 @@ docker build --platform linux/amd64 \
## 📜 Roadmap
-See the [RAGFlow Roadmap 2025](https://github.com/infiniflow/ragflow/issues/4214)
+See the [RAGFlow Roadmap 2026](https://github.com/infiniflow/ragflow/issues/12241)
## 🏄 Community

View File

@ -22,7 +22,7 @@
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99"> <img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
</a> </a>
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank"> <a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.23.0"> <img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.23.1">
</a> </a>
<a href="https://github.com/infiniflow/ragflow/releases/latest"> <a href="https://github.com/infiniflow/ragflow/releases/latest">
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release"> <img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
@ -37,7 +37,7 @@
<h4 align="center"> <h4 align="center">
<a href="https://ragflow.io/docs/dev/">Document</a> | <a href="https://ragflow.io/docs/dev/">Document</a> |
<a href="https://github.com/infiniflow/ragflow/issues/4214">Roadmap</a> | <a href="https://github.com/infiniflow/ragflow/issues/12241">Roadmap</a> |
<a href="https://twitter.com/infiniflowai">Twitter</a> | <a href="https://twitter.com/infiniflowai">Twitter</a> |
<a href="https://discord.gg/NjYzJD3GM3">Discord</a> | <a href="https://discord.gg/NjYzJD3GM3">Discord</a> |
<a href="https://demo.ragflow.io">Demo</a> <a href="https://demo.ragflow.io">Demo</a>
@ -54,7 +54,7 @@
## 💡 What is RAGFlow?
-[RAGFlow](https://ragflow.io/) is a leading open-source RAG engine that fuses state-of-the-art RAG (Retrieval-Augmented Generation) with Agent capabilities to create a superior context layer for large language models (LLMs). It provides an efficient RAG workflow applicable to enterprises of any scale, and with its unified context engine and pre-built Agent templates it helps developers turn complex data into high-fidelity, production-ready AI systems with exceptional efficiency and precision.
+[RAGFlow](https://ragflow.io/) is a leading open-source RAG engine that fuses state-of-the-art [RAG](https://ragflow.io/basics/what-is-rag) (Retrieval-Augmented Generation) with Agent capabilities to create a superior context layer for large language models (LLMs). It provides an efficient RAG workflow applicable to enterprises of any scale, and with its unified [context engine](https://ragflow.io/basics/what-is-agent-context-engine) and pre-built Agent templates it helps developers turn complex data into high-fidelity, production-ready AI systems with exceptional efficiency and precision.
## 🎮 Demo
@ -170,12 +170,12 @@
> All Docker images are built for x86 platforms. We do not currently offer Docker images for the ARM64 platform.
> If you are on an ARM64 platform, [please use this guide to build a Docker image compatible with your system](https://ragflow.io/docs/dev/build_docker_image).
-> The command below downloads the v0.23.0 edition of the RAGFlow Docker image. Refer to the following table for descriptions of the different RAGFlow editions. To download a RAGFlow edition other than v0.23.0, update the RAGFLOW_IMAGE variable in the docker/.env file accordingly, then start the server with docker compose.
+> The command below downloads the v0.23.1 edition of the RAGFlow Docker image. Refer to the following table for descriptions of the different RAGFlow editions. To download a RAGFlow edition other than v0.23.1, update the RAGFLOW_IMAGE variable in the docker/.env file accordingly, then start the server with docker compose.
```bash
$ cd ragflow/docker
-# git checkout v0.23.0
+# git checkout v0.23.1
# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases)
# This step ensures the entrypoint.sh file in the code matches the Docker image version.
@ -214,7 +214,7 @@
* Running on all addresses (0.0.0.0)
```
-> If you skip this confirmation step and log in to RAGFlow right away, your browser may report a `network anormal` error because RAGFlow may not be fully initialized.
+> If you skip this confirmation step and log in to RAGFlow right away, your browser may report a `network abnormal` error because RAGFlow may not be fully initialized.
2. Enter your server's IP address in a web browser and log in to RAGFlow.
> With the default settings, you only need to enter `http://IP_OF_YOUR_MACHINE` (no port number), since the default HTTP serving port `80` can be omitted under the default configuration.
@ -372,7 +372,7 @@ docker build --platform linux/amd64 \
## 📜 Roadmap
-See the [RAGFlow Roadmap 2025](https://github.com/infiniflow/ragflow/issues/4214).
+See the [RAGFlow Roadmap 2026](https://github.com/infiniflow/ragflow/issues/12241).
## 🏄 Community

View File

@ -22,7 +22,7 @@
<img alt="Badge Estático" src="https://img.shields.io/badge/Online-Demo-4e6b99"> <img alt="Badge Estático" src="https://img.shields.io/badge/Online-Demo-4e6b99">
</a> </a>
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank"> <a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.23.0"> <img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.23.1">
</a> </a>
<a href="https://github.com/infiniflow/ragflow/releases/latest"> <a href="https://github.com/infiniflow/ragflow/releases/latest">
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Última%20Relese" alt="Última Versão"> <img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Última%20Relese" alt="Última Versão">
@ -37,7 +37,7 @@
<h4 align="center"> <h4 align="center">
<a href="https://ragflow.io/docs/dev/">Documentação</a> | <a href="https://ragflow.io/docs/dev/">Documentação</a> |
<a href="https://github.com/infiniflow/ragflow/issues/4214">Roadmap</a> | <a href="https://github.com/infiniflow/ragflow/issues/12241">Roadmap</a> |
<a href="https://twitter.com/infiniflowai">Twitter</a> | <a href="https://twitter.com/infiniflowai">Twitter</a> |
<a href="https://discord.gg/NjYzJD3GM3">Discord</a> | <a href="https://discord.gg/NjYzJD3GM3">Discord</a> |
<a href="https://demo.ragflow.io">Demo</a> <a href="https://demo.ragflow.io">Demo</a>
@ -73,7 +73,7 @@
## 💡 What is RAGFlow?
-[RAGFlow](https://ragflow.io/) is a leading open-source RAG (Retrieval-Augmented Generation) engine that fuses cutting-edge RAG technology with Agent capabilities to create a superior contextual layer for LLMs. It offers an optimized RAG workflow adaptable to companies of any scale. Powered by a converged context engine and pre-built Agent templates, RAGFlow lets developers turn complex data into high-fidelity, production-ready AI systems with exceptional efficiency and precision.
+[RAGFlow](https://ragflow.io/) is a leading open-source [RAG](https://ragflow.io/basics/what-is-rag) (Retrieval-Augmented Generation) engine that fuses cutting-edge RAG technology with Agent capabilities to create a superior contextual layer for LLMs. It offers an optimized RAG workflow adaptable to companies of any scale. Powered by a converged [context engine](https://ragflow.io/basics/what-is-agent-context-engine) and pre-built Agent templates, RAGFlow lets developers turn complex data into high-fidelity, production-ready AI systems with exceptional efficiency and precision.
## 🎮 Demo
@ -188,12 +188,12 @@ Try our demo at [https://demo.ragflow.io](https://demo.ragflow.io).
> All Docker images are built for x86 platforms. We do not currently offer Docker images for ARM64.
> If you are on an ARM64 platform, please use [this guide](https://ragflow.io/docs/dev/build_docker_image) to build a Docker image compatible with your system.
-> The command below downloads the `v0.23.0` edition of the RAGFlow Docker image. See the following table for descriptions of the different RAGFlow editions. To download a RAGFlow edition other than `v0.23.0`, update the `RAGFLOW_IMAGE` variable as needed in **docker/.env** before using `docker compose` to start the server.
+> The command below downloads the `v0.23.1` edition of the RAGFlow Docker image. See the following table for descriptions of the different RAGFlow editions. To download a RAGFlow edition other than `v0.23.1`, update the `RAGFLOW_IMAGE` variable as needed in **docker/.env** before using `docker compose` to start the server.
```bash
$ cd ragflow/docker
-# git checkout v0.23.0
+# git checkout v0.23.1
# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases)
# This step ensures the entrypoint.sh file in the code matches the Docker image version.
@ -232,7 +232,7 @@ Try our demo at [https://demo.ragflow.io](https://demo.ragflow.io).
* Running on all addresses (0.0.0.0)
```
-> If you skip this confirmation step and access RAGFlow directly, your browser may display a `network anormal` error because, at that moment, RAGFlow may not be fully initialized.
+> If you skip this confirmation step and access RAGFlow directly, your browser may display a `network abnormal` error because, at that moment, RAGFlow may not be fully initialized.
>
5. In your browser, enter your server's IP address and log in to RAGFlow.
@ -385,7 +385,7 @@ docker build --platform linux/amd64 \
## 📜 Roadmap
-See the [RAGFlow Roadmap 2025](https://github.com/infiniflow/ragflow/issues/4214)
+See the [RAGFlow Roadmap 2026](https://github.com/infiniflow/ragflow/issues/12241)
## 🏄 Community

View File

@ -22,7 +22,7 @@
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99"> <img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
</a> </a>
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank"> <a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.23.0"> <img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.23.1">
</a> </a>
<a href="https://github.com/infiniflow/ragflow/releases/latest"> <a href="https://github.com/infiniflow/ragflow/releases/latest">
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release"> <img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
@ -37,7 +37,7 @@
<h4 align="center"> <h4 align="center">
<a href="https://ragflow.io/docs/dev/">Document</a> | <a href="https://ragflow.io/docs/dev/">Document</a> |
<a href="https://github.com/infiniflow/ragflow/issues/4214">Roadmap</a> | <a href="https://github.com/infiniflow/ragflow/issues/12241">Roadmap</a> |
<a href="https://twitter.com/infiniflowai">Twitter</a> | <a href="https://twitter.com/infiniflowai">Twitter</a> |
<a href="https://discord.gg/NjYzJD3GM3">Discord</a> | <a href="https://discord.gg/NjYzJD3GM3">Discord</a> |
<a href="https://demo.ragflow.io">Demo</a> <a href="https://demo.ragflow.io">Demo</a>
@ -72,7 +72,7 @@
## 💡 What is RAGFlow?
-[RAGFlow](https://ragflow.io/) is a leading open-source RAG (Retrieval-Augmented Generation) engine that fuses cutting-edge RAG technology with Agent capabilities to provide a superior context layer for large language models. It offers an end-to-end RAG workflow adaptable to enterprises of any scale; with its converged context engine and pre-built Agent templates, it helps developers turn complex data into trustworthy, production-grade AI systems with extreme efficiency and precision.
+[RAGFlow](https://ragflow.io/) is a leading open-source [RAG](https://ragflow.io/basics/what-is-rag) (Retrieval-Augmented Generation) engine that fuses cutting-edge RAG technology with Agent capabilities to provide a superior context layer for large language models. It offers an end-to-end RAG workflow adaptable to enterprises of any scale; with its converged [context engine](https://ragflow.io/basics/what-is-agent-context-engine) and pre-built Agent templates, it helps developers turn complex data into trustworthy, production-grade AI systems with extreme efficiency and precision.
## 🎮 Demo
@ -125,7 +125,7 @@
### 🍔 **Compatibility with heterogeneous data sources**
- Supports a rich variety of document types, including Word documents, PPT, Excel spreadsheets, txt files, images, PDFs, photocopies, printed copies, structured data, web pages, and more.
### 🛀 **Streamlined, fully automated RAG workflow**
@ -187,12 +187,12 @@
> All Docker images are built for x86 platforms. We do not currently provide Docker images for ARM64.
> If you are on an ARM64 platform, please use [this guide](https://ragflow.io/docs/dev/build_docker_image) to build a Docker image suited to your system.
-> Running the commands below automatically downloads the `v0.23.0` RAGFlow Docker image. See the table below for descriptions of the different Docker distributions. To download a Docker image other than `v0.23.0`, update the `RAGFLOW_IMAGE` variable in **docker/.env** before starting the service with `docker compose`.
+> Running the commands below automatically downloads the `v0.23.1` RAGFlow Docker image. See the table below for descriptions of the different Docker distributions. To download a Docker image other than `v0.23.1`, update the `RAGFLOW_IMAGE` variable in **docker/.env** before starting the service with `docker compose`.
```bash
$ cd ragflow/docker
-# git checkout v0.23.0
+# git checkout v0.23.1
# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases)
# This step ensures the entrypoint.sh file in the code matches the Docker image version.
@ -237,7 +237,7 @@
* Running on all addresses (0.0.0.0)
```
-> If you skip this system confirmation step and log in to RAGFlow, your browser may prompt a `network anormal` or `網路異常` error, because RAGFlow may not have fully started.
+> If you skip this system confirmation step and log in to RAGFlow, your browser may prompt a `network abnormal` or `網路異常` error, because RAGFlow may not have fully started.
>
5. Enter your server's IP address in your browser and log in to RAGFlow.
@ -399,7 +399,7 @@ docker build --platform linux/amd64 \
## 📜 Roadmap
-See the [RAGFlow Roadmap 2025](https://github.com/infiniflow/ragflow/issues/4214).
+See the [RAGFlow Roadmap 2026](https://github.com/infiniflow/ragflow/issues/12241).
## 🏄 Community

View File

@ -22,7 +22,7 @@
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99"> <img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
</a> </a>
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank"> <a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.23.0"> <img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.23.1">
</a> </a>
<a href="https://github.com/infiniflow/ragflow/releases/latest"> <a href="https://github.com/infiniflow/ragflow/releases/latest">
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release"> <img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
@ -37,7 +37,7 @@
<h4 align="center"> <h4 align="center">
<a href="https://ragflow.io/docs/dev/">Document</a> | <a href="https://ragflow.io/docs/dev/">Document</a> |
<a href="https://github.com/infiniflow/ragflow/issues/4214">Roadmap</a> | <a href="https://github.com/infiniflow/ragflow/issues/12241">Roadmap</a> |
<a href="https://twitter.com/infiniflowai">Twitter</a> | <a href="https://twitter.com/infiniflowai">Twitter</a> |
<a href="https://discord.gg/NjYzJD3GM3">Discord</a> | <a href="https://discord.gg/NjYzJD3GM3">Discord</a> |
<a href="https://demo.ragflow.io">Demo</a> <a href="https://demo.ragflow.io">Demo</a>
@ -72,7 +72,7 @@
## 💡 What is RAGFlow?
-[RAGFlow](https://ragflow.io/) is a leading open-source Retrieval-Augmented Generation (RAG) engine that fuses cutting-edge RAG technology with Agent capabilities to provide a superior context layer for large language models. It offers an end-to-end RAG workflow adaptable to enterprises of any scale; with its converged context engine and pre-built Agent templates, it helps developers turn complex data into trustworthy, production-grade AI systems with extreme efficiency and precision.
+[RAGFlow](https://ragflow.io/) is a leading open-source Retrieval-Augmented Generation ([RAG](https://ragflow.io/basics/what-is-rag)) engine that fuses cutting-edge RAG technology with Agent capabilities to provide a superior context layer for large language models. It offers an end-to-end RAG workflow adaptable to enterprises of any scale; with its converged [context engine](https://ragflow.io/basics/what-is-agent-context-engine) and pre-built Agent templates, it helps developers turn complex data into trustworthy, production-grade AI systems with extreme efficiency and precision.
## 🎮 Demo
@ -188,12 +188,12 @@
> Note that all officially provided Docker images are currently built for the x86 architecture; ARM64-based Docker images are not provided.
> If your operating system is on the ARM64 architecture, please refer to [this document](https://ragflow.io/docs/dev/build_docker_image) to build a Docker image yourself.
-> Running the commands below automatically downloads the `v0.23.0` RAGFlow Docker image. See the table below for descriptions of the different Docker distributions. To download a Docker image other than `v0.23.0`, update the `RAGFLOW_IMAGE` variable in **docker/.env** before starting the service with `docker compose`.
+> Running the commands below automatically downloads the `v0.23.1` RAGFlow Docker image. See the table below for descriptions of the different Docker distributions. To download a Docker image other than `v0.23.1`, update the `RAGFLOW_IMAGE` variable in **docker/.env** before starting the service with `docker compose`.
```bash
$ cd ragflow/docker
-# git checkout v0.23.0
+# git checkout v0.23.1
# Optional: use a stable version tag (see releases: https://github.com/infiniflow/ragflow/releases)
# This step ensures the entrypoint.sh file in the code matches the Docker image version.
@ -238,7 +238,7 @@
* Running on all addresses (0.0.0.0)
```
-> If you try to log in to RAGFlow before the message above appears, your browser may prompt a `network anormal` or `网络异常` error.
+> If you try to log in to RAGFlow before the message above appears, your browser may prompt a `network abnormal` or `网络异常` error.
5. Enter your server's IP address in your browser and log in to RAGFlow.
> In the example above, you only need to enter http://IP_OF_YOUR_MACHINE; if the configuration is unchanged, there is no need to enter the port (the default HTTP serving port 80).
@ -402,7 +402,7 @@ docker build --platform linux/amd64 \
## 📜 Roadmap
-See the [RAGFlow Roadmap 2025](https://github.com/infiniflow/ragflow/issues/4214).
+See the [RAGFlow Roadmap 2026](https://github.com/infiniflow/ragflow/issues/12241).
## 🏄 Community

View File

@ -21,7 +21,7 @@ cp pyproject.toml release/$PROJECT_NAME/pyproject.toml
cp README.md release/$PROJECT_NAME/README.md
mkdir release/$PROJECT_NAME/$SOURCE_DIR/$PACKAGE_DIR -p
-cp admin_client.py release/$PROJECT_NAME/$SOURCE_DIR/$PACKAGE_DIR/admin_client.py
+cp ragflow_cli.py release/$PROJECT_NAME/$SOURCE_DIR/$PACKAGE_DIR/ragflow_cli.py
if [ -d "release/$PROJECT_NAME/$SOURCE_DIR" ]; then
echo "✅ source dir: release/$PROJECT_NAME/$SOURCE_DIR"

View File

@ -48,7 +48,7 @@ It consists of a server-side Service and a command-line client (CLI), both imple
1. Ensure the Admin Service is running.
2. Install ragflow-cli.
```bash
-pip install ragflow-cli==0.23.0
+pip install ragflow-cli==0.23.1
```
3. Launch the CLI client:
```bash

View File

@ -1,938 +0,0 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import base64
import getpass
from cmd import Cmd
from typing import Any, Dict, List
import requests
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
from Cryptodome.PublicKey import RSA
from lark import Lark, Transformer, Tree
GRAMMAR = r"""
start: command
command: sql_command | meta_command
sql_command: list_services
| show_service
| startup_service
| shutdown_service
| restart_service
| list_users
| show_user
| drop_user
| alter_user
| create_user
| activate_user
| list_datasets
| list_agents
| create_role
| drop_role
| alter_role
| list_roles
| show_role
| grant_permission
| revoke_permission
| alter_user_role
| show_user_permission
| show_version
// meta command definition
meta_command: "\\" meta_command_name [meta_args]
meta_command_name: /[a-zA-Z?]+/
meta_args: (meta_arg)+
meta_arg: /[^\s"']+/ | quoted_string
// command definition
LIST: "LIST"i
SERVICES: "SERVICES"i
SHOW: "SHOW"i
CREATE: "CREATE"i
SERVICE: "SERVICE"i
SHUTDOWN: "SHUTDOWN"i
STARTUP: "STARTUP"i
RESTART: "RESTART"i
USERS: "USERS"i
DROP: "DROP"i
USER: "USER"i
ALTER: "ALTER"i
ACTIVE: "ACTIVE"i
PASSWORD: "PASSWORD"i
DATASETS: "DATASETS"i
OF: "OF"i
AGENTS: "AGENTS"i
ROLE: "ROLE"i
ROLES: "ROLES"i
DESCRIPTION: "DESCRIPTION"i
GRANT: "GRANT"i
REVOKE: "REVOKE"i
ALL: "ALL"i
PERMISSION: "PERMISSION"i
TO: "TO"i
FROM: "FROM"i
FOR: "FOR"i
RESOURCES: "RESOURCES"i
ON: "ON"i
SET: "SET"i
VERSION: "VERSION"i
list_services: LIST SERVICES ";"
show_service: SHOW SERVICE NUMBER ";"
startup_service: STARTUP SERVICE NUMBER ";"
shutdown_service: SHUTDOWN SERVICE NUMBER ";"
restart_service: RESTART SERVICE NUMBER ";"
list_users: LIST USERS ";"
drop_user: DROP USER quoted_string ";"
alter_user: ALTER USER PASSWORD quoted_string quoted_string ";"
show_user: SHOW USER quoted_string ";"
create_user: CREATE USER quoted_string quoted_string ";"
activate_user: ALTER USER ACTIVE quoted_string status ";"
list_datasets: LIST DATASETS OF quoted_string ";"
list_agents: LIST AGENTS OF quoted_string ";"
create_role: CREATE ROLE identifier [DESCRIPTION quoted_string] ";"
drop_role: DROP ROLE identifier ";"
alter_role: ALTER ROLE identifier SET DESCRIPTION quoted_string ";"
list_roles: LIST ROLES ";"
show_role: SHOW ROLE identifier ";"
grant_permission: GRANT action_list ON identifier TO ROLE identifier ";"
revoke_permission: REVOKE action_list ON identifier FROM ROLE identifier ";"
alter_user_role: ALTER USER quoted_string SET ROLE identifier ";"
show_user_permission: SHOW USER PERMISSION quoted_string ";"
show_version: SHOW VERSION ";"
action_list: identifier ("," identifier)*
identifier: WORD
quoted_string: QUOTED_STRING
status: WORD
QUOTED_STRING: /'[^']+'/ | /"[^"]+"/
WORD: /[a-zA-Z0-9_\-\.]+/
NUMBER: /[0-9]+/
%import common.WS
%ignore WS
"""
class AdminTransformer(Transformer):
def start(self, items):
return items[0]
def command(self, items):
return items[0]
def list_services(self, items):
result = {"type": "list_services"}
return result
def show_service(self, items):
service_id = int(items[2])
return {"type": "show_service", "number": service_id}
def startup_service(self, items):
service_id = int(items[2])
return {"type": "startup_service", "number": service_id}
def shutdown_service(self, items):
service_id = int(items[2])
return {"type": "shutdown_service", "number": service_id}
def restart_service(self, items):
service_id = int(items[2])
return {"type": "restart_service", "number": service_id}
def list_users(self, items):
return {"type": "list_users"}
def show_user(self, items):
user_name = items[2]
return {"type": "show_user", "user_name": user_name}
def drop_user(self, items):
user_name = items[2]
return {"type": "drop_user", "user_name": user_name}
def alter_user(self, items):
user_name = items[3]
new_password = items[4]
return {"type": "alter_user", "user_name": user_name, "password": new_password}
def create_user(self, items):
user_name = items[2]
password = items[3]
return {"type": "create_user", "user_name": user_name, "password": password, "role": "user"}
def activate_user(self, items):
user_name = items[3]
activate_status = items[4]
return {"type": "activate_user", "activate_status": activate_status, "user_name": user_name}
def list_datasets(self, items):
user_name = items[3]
return {"type": "list_datasets", "user_name": user_name}
def list_agents(self, items):
user_name = items[3]
return {"type": "list_agents", "user_name": user_name}
def create_role(self, items):
role_name = items[2]
if len(items) > 4:
description = items[4]
return {"type": "create_role", "role_name": role_name, "description": description}
else:
return {"type": "create_role", "role_name": role_name}
def drop_role(self, items):
role_name = items[2]
return {"type": "drop_role", "role_name": role_name}
def alter_role(self, items):
role_name = items[2]
description = items[5]
return {"type": "alter_role", "role_name": role_name, "description": description}
def list_roles(self, items):
return {"type": "list_roles"}
def show_role(self, items):
role_name = items[2]
return {"type": "show_role", "role_name": role_name}
def grant_permission(self, items):
action_list = items[1]
resource = items[3]
role_name = items[6]
return {"type": "grant_permission", "role_name": role_name, "resource": resource, "actions": action_list}
def revoke_permission(self, items):
action_list = items[1]
resource = items[3]
role_name = items[6]
return {"type": "revoke_permission", "role_name": role_name, "resource": resource, "actions": action_list}
def alter_user_role(self, items):
user_name = items[2]
role_name = items[5]
return {"type": "alter_user_role", "user_name": user_name, "role_name": role_name}
def show_user_permission(self, items):
user_name = items[3]
return {"type": "show_user_permission", "user_name": user_name}
def show_version(self, items):
return {"type": "show_version"}
def action_list(self, items):
return items
def meta_command(self, items):
command_name = str(items[0]).lower()
args = items[1:] if len(items) > 1 else []
# handle quoted parameter
parsed_args = []
for arg in args:
if hasattr(arg, "value"):
parsed_args.append(arg.value)
else:
parsed_args.append(str(arg))
return {"type": "meta", "command": command_name, "args": parsed_args}
def meta_command_name(self, items):
return items[0]
def meta_args(self, items):
return items
def encrypt(input_string):
pub = "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOOUEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVKRNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs2wIDAQAB\n-----END PUBLIC KEY-----"
pub_key = RSA.importKey(pub)
cipher = Cipher_pkcs1_v1_5.new(pub_key)
cipher_text = cipher.encrypt(base64.b64encode(input_string.encode("utf-8")))
return base64.b64encode(cipher_text).decode("utf-8")
def encode_to_base64(input_string):
base64_encoded = base64.b64encode(input_string.encode("utf-8"))
return base64_encoded.decode("utf-8")
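A hedged usage sketch of the two helpers above (hypothetical values; `encrypt` needs the `pycryptodomex` package, which provides the `Cryptodome` imports):

```python
# The password is base64-encoded, RSA-encrypted with the hard-coded
# 2048-bit public key (PKCS#1 v1.5), then base64-encoded again.
token = encrypt("admin")
print(len(token))  # 344: base64 of a 256-byte RSA ciphertext; differs on every call

print(encode_to_base64("admin@ragflow.io"))  # YWRtaW5AcmFnZmxvdy5pbw==
```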
class AdminCLI(Cmd):
def __init__(self):
super().__init__()
self.parser = Lark(GRAMMAR, start="start", parser="lalr", transformer=AdminTransformer())
self.command_history = []
self.is_interactive = False
self.admin_account = "admin@ragflow.io"
self.admin_password: str = "admin"
self.session = requests.Session()
self.access_token: str = ""
self.host: str = ""
self.port: int = 0
intro = r"""Type "\h" for help."""
prompt = "admin> "
def onecmd(self, command: str) -> bool:
try:
result = self.parse_command(command)
# Execute both dicts and raw parse trees, mirroring run_interactive() below.
if isinstance(result, dict) and result.get("type") == "empty":
return False
self.execute_command(result)
if isinstance(result, Tree):
return False
if result.get("type") == "meta" and result.get("command") in ["q", "quit", "exit"]:
return True
except KeyboardInterrupt:
print("\nUse '\\q' to quit")
except EOFError:
print("\nGoodbye!")
return True
return False
def emptyline(self) -> bool:
return False
def default(self, line: str) -> bool:
return self.onecmd(line)
def parse_command(self, command_str: str) -> Dict[str, Any] | Tree:
if not command_str.strip():
return {"type": "empty"}
self.command_history.append(command_str)
try:
result = self.parser.parse(command_str)
return result
except Exception as e:
return {"type": "error", "message": f"Parse error: {str(e)}"}
def verify_admin(self, arguments: dict, single_command: bool):
self.host = arguments["host"]
self.port = arguments["port"]
print("Attempt to access server for admin login")
url = f"http://{self.host}:{self.port}/api/v1/admin/login"
attempt_count = 3
if single_command:
attempt_count = 1
try_count = 0
while True:
try_count += 1
if try_count > attempt_count:
return False
if single_command:
admin_passwd = arguments["password"]
else:
admin_passwd = getpass.getpass(f"password for {self.admin_account}: ").strip()
try:
self.admin_password = encrypt(admin_passwd)
response = self.session.post(url, json={"email": self.admin_account, "password": self.admin_password})
if response.status_code == 200:
res_json = response.json()
error_code = res_json.get("code", -1)
if error_code == 0:
self.session.headers.update({"Content-Type": "application/json", "Authorization": response.headers["Authorization"], "User-Agent": "RAGFlow-CLI/0.23.0"})
print("Authentication successful.")
return True
else:
error_message = res_json.get("message", "Unknown error")
print(f"Authentication failed: {error_message}, try again")
continue
else:
print(f"Bad responsestatus: {response.status_code}, password is wrong")
except Exception as e:
print(str(e))
print("Can't access server for admin login (connection failed)")
def _format_service_detail_table(self, data):
if isinstance(data, list):
return data
if not all([isinstance(v, list) for v in data.values()]):
# normal table
return data
# handle task_executor heartbeats map, for example {'name': [{'done': 2, 'now': timestamp1}, {'done': 3, 'now': timestamp2}]}
task_executor_list = []
for k, v in data.items():
# display latest status
heartbeats = sorted(v, key=lambda x: x["now"], reverse=True)
task_executor_list.append(
{
"task_executor_name": k,
**heartbeats[0],
}
if heartbeats
else {"task_executor_name": k}
)
return task_executor_list
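For example (hypothetical heartbeat payload; instantiating `AdminCLI` requires `lark` and `requests`), the flattening above keeps only the most recent heartbeat per executor:

```python
cli = AdminCLI()
beats = {"executor_0": [{"done": 2, "now": 100}, {"done": 3, "now": 200}]}
print(cli._format_service_detail_table(beats))
# -> [{'task_executor_name': 'executor_0', 'done': 3, 'now': 200}]
```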
def _print_table_simple(self, data):
if not data:
print("No data to print")
return
if isinstance(data, dict):
# handle single row data
data = [data]
columns = list(set().union(*(d.keys() for d in data)))
columns.sort()
col_widths = {}
def get_string_width(text):
half_width_chars = " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\t\n\r"
width = 0
for char in text:
if char in half_width_chars:
width += 1
else:
width += 2
return width
for col in columns:
max_width = get_string_width(str(col))
for item in data:
value_len = get_string_width(str(item.get(col, "")))
if value_len > max_width:
max_width = value_len
col_widths[col] = max(2, max_width)
# Generate delimiter
separator = "+" + "+".join(["-" * (col_widths[col] + 2) for col in columns]) + "+"
# Print header
print(separator)
header = "|" + "|".join([f" {col:<{col_widths[col]}} " for col in columns]) + "|"
print(header)
print(separator)
# Print data
for item in data:
row = "|"
for col in columns:
value = str(item.get(col, ""))
if get_string_width(value) > col_widths[col]:
value = value[: col_widths[col] - 3] + "..."
row += f" {value:<{col_widths[col] - (get_string_width(value) - len(value))}} |"
print(row)
print(separator)
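With a one-row payload (hypothetical data), `_print_table_simple` renders:

```python
cli = AdminCLI()
cli._print_table_simple([{"name": "ragflow_server", "status": "alive"}])
# +----------------+--------+
# | name           | status |
# +----------------+--------+
# | ragflow_server | alive  |
# +----------------+--------+
```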
def run_interactive(self):
self.is_interactive = True
print("RAGFlow Admin command line interface - Type '\\?' for help, '\\q' to quit")
while True:
try:
command = input("admin> ").strip()
if not command:
continue
print(f"command: {command}")
result = self.parse_command(command)
self.execute_command(result)
if isinstance(result, Tree):
continue
if result.get("type") == "meta" and result.get("command") in ["q", "quit", "exit"]:
break
except KeyboardInterrupt:
print("\nUse '\\q' to quit")
except EOFError:
print("\nGoodbye!")
break
def run_single_command(self, command: str):
result = self.parse_command(command)
self.execute_command(result)
def parse_connection_args(self, args: List[str]) -> Dict[str, Any]:
parser = argparse.ArgumentParser(description="Admin CLI Client", add_help=False)
parser.add_argument("-h", "--host", default="localhost", help="Admin service host")
parser.add_argument("-p", "--port", type=int, default=9381, help="Admin service port")
parser.add_argument("-w", "--password", default="admin", type=str, help="Superuser password")
parser.add_argument("command", nargs="?", help="Single command")
try:
parsed_args, remaining_args = parser.parse_known_args(args)
# The single command may land in the positional argument or in remaining_args.
command = parsed_args.command or (remaining_args[0] if remaining_args else None)
if command:
return {"host": parsed_args.host, "port": parsed_args.port, "password": parsed_args.password, "command": command}
else:
return {
"host": parsed_args.host,
"port": parsed_args.port,
}
except SystemExit:
return {"error": "Invalid connection arguments"}
def execute_command(self, parsed_command: Dict[str, Any]):
command_dict: dict
if isinstance(parsed_command, Tree):
command_dict = parsed_command.children[0]
else:
if parsed_command["type"] == "error":
print(f"Error: {parsed_command['message']}")
return
else:
command_dict = parsed_command
# print(f"Parsed command: {command_dict}")
command_type = command_dict["type"]
match command_type:
case "list_services":
self._handle_list_services(command_dict)
case "show_service":
self._handle_show_service(command_dict)
case "restart_service":
self._handle_restart_service(command_dict)
case "shutdown_service":
self._handle_shutdown_service(command_dict)
case "startup_service":
self._handle_startup_service(command_dict)
case "list_users":
self._handle_list_users(command_dict)
case "show_user":
self._handle_show_user(command_dict)
case "drop_user":
self._handle_drop_user(command_dict)
case "alter_user":
self._handle_alter_user(command_dict)
case "create_user":
self._handle_create_user(command_dict)
case "activate_user":
self._handle_activate_user(command_dict)
case "list_datasets":
self._handle_list_datasets(command_dict)
case "list_agents":
self._handle_list_agents(command_dict)
case "create_role":
self._create_role(command_dict)
case "drop_role":
self._drop_role(command_dict)
case "alter_role":
self._alter_role(command_dict)
case "list_roles":
self._list_roles(command_dict)
case "show_role":
self._show_role(command_dict)
case "grant_permission":
self._grant_permission(command_dict)
case "revoke_permission":
self._revoke_permission(command_dict)
case "alter_user_role":
self._alter_user_role(command_dict)
case "show_user_permission":
self._show_user_permission(command_dict)
case "show_version":
self._show_version(command_dict)
case "meta":
self._handle_meta_command(command_dict)
case _:
print(f"Command '{command_type}' would be executed with API")
def _handle_list_services(self, command):
print("Listing all services")
url = f"http://{self.host}:{self.port}/api/v1/admin/services"
response = self.session.get(url)
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json["data"])
else:
print(f"Fail to get all services, code: {res_json['code']}, message: {res_json['message']}")
def _handle_show_service(self, command):
service_id: int = command["number"]
print(f"Showing service: {service_id}")
url = f"http://{self.host}:{self.port}/api/v1/admin/services/{service_id}"
response = self.session.get(url)
res_json = response.json()
if response.status_code == 200:
res_data = res_json["data"]
if "status" in res_data and res_data["status"] == "alive":
print(f"Service {res_data['service_name']} is alive, ")
if isinstance(res_data["message"], str):
print(res_data["message"])
else:
data = self._format_service_detail_table(res_data["message"])
self._print_table_simple(data)
else:
print(f"Service {res_data['service_name']} is down, {res_data['message']}")
else:
print(f"Fail to show service, code: {res_json['code']}, message: {res_json['message']}")
def _handle_restart_service(self, command):
service_id: int = command["number"]
print(f"Restart service {service_id}")
def _handle_shutdown_service(self, command):
service_id: int = command["number"]
print(f"Shutdown service {service_id}")
def _handle_startup_service(self, command):
service_id: int = command["number"]
print(f"Startup service {service_id}")
def _handle_list_users(self, command):
print("Listing all users")
url = f"http://{self.host}:{self.port}/api/v1/admin/users"
response = self.session.get(url)
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json["data"])
else:
print(f"Fail to get all users, code: {res_json['code']}, message: {res_json['message']}")
def _handle_show_user(self, command):
username_tree: Tree = command["user_name"]
user_name: str = username_tree.children[0].strip("'\"")
print(f"Showing user: {user_name}")
url = f"http://{self.host}:{self.port}/api/v1/admin/users/{user_name}"
response = self.session.get(url)
res_json = response.json()
if response.status_code == 200:
table_data = res_json["data"]
table_data.pop("avatar")
self._print_table_simple(table_data)
else:
print(f"Fail to get user {user_name}, code: {res_json['code']}, message: {res_json['message']}")
def _handle_drop_user(self, command):
username_tree: Tree = command["user_name"]
user_name: str = username_tree.children[0].strip("'\"")
print(f"Drop user: {user_name}")
url = f"http://{self.host}:{self.port}/api/v1/admin/users/{user_name}"
response = self.session.delete(url)
res_json = response.json()
if response.status_code == 200:
print(res_json["message"])
else:
print(f"Fail to drop user, code: {res_json['code']}, message: {res_json['message']}")
def _handle_alter_user(self, command):
user_name_tree: Tree = command["user_name"]
user_name: str = user_name_tree.children[0].strip("'\"")
password_tree: Tree = command["password"]
password: str = password_tree.children[0].strip("'\"")
print(f"Alter user: {user_name}, password: ******")
url = f"http://{self.host}:{self.port}/api/v1/admin/users/{user_name}/password"
response = self.session.put(url, json={"new_password": encrypt(password)})
res_json = response.json()
if response.status_code == 200:
print(res_json["message"])
else:
print(f"Fail to alter password, code: {res_json['code']}, message: {res_json['message']}")
def _handle_create_user(self, command):
user_name_tree: Tree = command["user_name"]
user_name: str = user_name_tree.children[0].strip("'\"")
password_tree: Tree = command["password"]
password: str = password_tree.children[0].strip("'\"")
role: str = command["role"]
print(f"Create user: {user_name}, password: ******, role: {role}")
url = f"http://{self.host}:{self.port}/api/v1/admin/users"
response = self.session.post(url, json={"user_name": user_name, "password": encrypt(password), "role": role})
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json["data"])
else:
print(f"Fail to create user {user_name}, code: {res_json['code']}, message: {res_json['message']}")
def _handle_activate_user(self, command):
user_name_tree: Tree = command["user_name"]
user_name: str = user_name_tree.children[0].strip("'\"")
activate_tree: Tree = command["activate_status"]
activate_status: str = activate_tree.children[0].strip("'\"")
if activate_status.lower() in ["on", "off"]:
print(f"Alter user {user_name} activate status, turn {activate_status.lower()}.")
url = f"http://{self.host}:{self.port}/api/v1/admin/users/{user_name}/activate"
response = self.session.put(url, json={"activate_status": activate_status})
res_json = response.json()
if response.status_code == 200:
print(res_json["message"])
else:
print(f"Fail to alter activate status, code: {res_json['code']}, message: {res_json['message']}")
else:
print(f"Unknown activate status: {activate_status}.")
def _handle_list_datasets(self, command):
username_tree: Tree = command["user_name"]
user_name: str = username_tree.children[0].strip("'\"")
print(f"Listing all datasets of user: {user_name}")
url = f"http://{self.host}:{self.port}/api/v1/admin/users/{user_name}/datasets"
response = self.session.get(url)
res_json = response.json()
if response.status_code == 200:
table_data = res_json["data"]
for t in table_data:
t.pop("avatar")
self._print_table_simple(table_data)
else:
print(f"Fail to get all datasets of {user_name}, code: {res_json['code']}, message: {res_json['message']}")
def _handle_list_agents(self, command):
username_tree: Tree = command["user_name"]
user_name: str = username_tree.children[0].strip("'\"")
print(f"Listing all agents of user: {user_name}")
url = f"http://{self.host}:{self.port}/api/v1/admin/users/{user_name}/agents"
response = self.session.get(url)
res_json = response.json()
if response.status_code == 200:
table_data = res_json["data"]
for t in table_data:
t.pop("avatar")
self._print_table_simple(table_data)
else:
print(f"Fail to get all agents of {user_name}, code: {res_json['code']}, message: {res_json['message']}")
def _create_role(self, command):
role_name_tree: Tree = command["role_name"]
role_name: str = role_name_tree.children[0].strip("'\"")
desc_str: str = ""
if "description" in command:
desc_tree: Tree = command["description"]
desc_str = desc_tree.children[0].strip("'\"")
print(f"create role name: {role_name}, description: {desc_str}")
url = f"http://{self.host}:{self.port}/api/v1/admin/roles"
response = self.session.post(url, json={"role_name": role_name, "description": desc_str})
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json["data"])
else:
print(f"Fail to create role {role_name}, code: {res_json['code']}, message: {res_json['message']}")
def _drop_role(self, command):
role_name_tree: Tree = command["role_name"]
role_name: str = role_name_tree.children[0].strip("'\"")
print(f"drop role name: {role_name}")
url = f"http://{self.host}:{self.port}/api/v1/admin/roles/{role_name}"
response = self.session.delete(url)
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json["data"])
else:
print(f"Fail to drop role {role_name}, code: {res_json['code']}, message: {res_json['message']}")
def _alter_role(self, command):
role_name_tree: Tree = command["role_name"]
role_name: str = role_name_tree.children[0].strip("'\"")
desc_tree: Tree = command["description"]
desc_str: str = desc_tree.children[0].strip("'\"")
print(f"alter role name: {role_name}, description: {desc_str}")
url = f"http://{self.host}:{self.port}/api/v1/admin/roles/{role_name}"
response = self.session.put(url, json={"description": desc_str})
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json["data"])
else:
print(f"Fail to update role {role_name} with description: {desc_str}, code: {res_json['code']}, message: {res_json['message']}")
def _list_roles(self, command):
print("Listing all roles")
url = f"http://{self.host}:{self.port}/api/v1/admin/roles"
response = self.session.get(url)
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json["data"])
else:
print(f"Fail to list roles, code: {res_json['code']}, message: {res_json['message']}")
def _show_role(self, command):
role_name_tree: Tree = command["role_name"]
role_name: str = role_name_tree.children[0].strip("'\"")
print(f"show role: {role_name}")
url = f"http://{self.host}:{self.port}/api/v1/admin/roles/{role_name}/permission"
response = self.session.get(url)
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json["data"])
else:
print(f"Fail to list roles, code: {res_json['code']}, message: {res_json['message']}")
def _grant_permission(self, command):
role_name_tree: Tree = command["role_name"]
role_name_str: str = role_name_tree.children[0].strip("'\"")
resource_tree: Tree = command["resource"]
resource_str: str = resource_tree.children[0].strip("'\"")
action_tree_list: list = command["actions"]
actions: list = []
for action_tree in action_tree_list:
action_str: str = action_tree.children[0].strip("'\"")
actions.append(action_str)
print(f"grant role_name: {role_name_str}, resource: {resource_str}, actions: {actions}")
url = f"http://{self.host}:{self.port}/api/v1/admin/roles/{role_name_str}/permission"
response = self.session.post(url, json={"actions": actions, "resource": resource_str})
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json["data"])
else:
print(f"Fail to grant role {role_name_str} with {actions} on {resource_str}, code: {res_json['code']}, message: {res_json['message']}")
def _revoke_permission(self, command):
role_name_tree: Tree = command["role_name"]
role_name_str: str = role_name_tree.children[0].strip("'\"")
resource_tree: Tree = command["resource"]
resource_str: str = resource_tree.children[0].strip("'\"")
action_tree_list: list = command["actions"]
actions: list = []
for action_tree in action_tree_list:
action_str: str = action_tree.children[0].strip("'\"")
actions.append(action_str)
print(f"revoke role_name: {role_name_str}, resource: {resource_str}, actions: {actions}")
url = f"http://{self.host}:{self.port}/api/v1/admin/roles/{role_name_str}/permission"
response = self.session.delete(url, json={"actions": actions, "resource": resource_str})
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json["data"])
else:
print(f"Fail to revoke role {role_name_str} with {actions} on {resource_str}, code: {res_json['code']}, message: {res_json['message']}")
def _alter_user_role(self, command):
role_name_tree: Tree = command["role_name"]
role_name_str: str = role_name_tree.children[0].strip("'\"")
user_name_tree: Tree = command["user_name"]
user_name_str: str = user_name_tree.children[0].strip("'\"")
print(f"alter_user_role user_name: {user_name_str}, role_name: {role_name_str}")
url = f"http://{self.host}:{self.port}/api/v1/admin/users/{user_name_str}/role"
response = self.session.put(url, json={"role_name": role_name_str})
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json["data"])
else:
print(f"Fail to alter user: {user_name_str} to role {role_name_str}, code: {res_json['code']}, message: {res_json['message']}")
def _show_user_permission(self, command):
user_name_tree: Tree = command["user_name"]
user_name_str: str = user_name_tree.children[0].strip("'\"")
print(f"show_user_permission user_name: {user_name_str}")
url = f"http://{self.host}:{self.port}/api/v1/admin/users/{user_name_str}/permission"
response = self.session.get(url)
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json["data"])
else:
print(f"Fail to show user: {user_name_str} permission, code: {res_json['code']}, message: {res_json['message']}")
def _show_version(self, command):
print("show_version")
url = f"http://{self.host}:{self.port}/api/v1/admin/version"
response = self.session.get(url)
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json["data"])
else:
print(f"Fail to show version, code: {res_json['code']}, message: {res_json['message']}")
def _handle_meta_command(self, command):
meta_command = command["command"]
args = command.get("args", [])
if meta_command in ["?", "h", "help"]:
self.show_help()
elif meta_command in ["q", "quit", "exit"]:
print("Goodbye!")
else:
print(f"Meta command '{meta_command}' with args {args}")
def show_help(self):
"""Help info"""
help_text = """
Commands:
LIST SERVICES
SHOW SERVICE <service>
STARTUP SERVICE <service>
SHUTDOWN SERVICE <service>
RESTART SERVICE <service>
LIST USERS
SHOW USER <user>
DROP USER <user>
CREATE USER <user> <password>
ALTER USER PASSWORD <user> <new_password>
ALTER USER ACTIVE <user> <on/off>
LIST DATASETS OF <user>
LIST AGENTS OF <user>
CREATE ROLE <role> [DESCRIPTION <description>]
DROP ROLE <role>
ALTER ROLE <role> SET DESCRIPTION <description>
LIST ROLES
SHOW ROLE <role>
GRANT <action>[,<action>...] ON <resource> TO ROLE <role>
REVOKE <action>[,<action>...] ON <resource> FROM ROLE <role>
ALTER USER <user> SET ROLE <role>
SHOW USER PERMISSION <user>
SHOW VERSION
Meta Commands:
\\?, \\h, \\help Show this help
\\q, \\quit, \\exit Quit the CLI
"""
print(help_text)
def main():
import sys
cli = AdminCLI()
args = cli.parse_connection_args(sys.argv)
if "error" in args:
print("Error: Invalid connection arguments")
return
if "command" in args:
if "password" not in args:
print("Error: password is missing")
return
if cli.verify_admin(args, single_command=True):
command: str = args["command"]
# print(f"Run single command: {command}")
cli.run_single_command(command)
else:
if cli.verify_admin(args, single_command=False):
print(r"""
____ ___ ______________ ___ __ _
/ __ \/ | / ____/ ____/ /___ _ __ / | ____/ /___ ___ (_)___
/ /_/ / /| |/ / __/ /_ / / __ \ | /| / / / /| |/ __ / __ `__ \/ / __ \
/ _, _/ ___ / /_/ / __/ / / /_/ / |/ |/ / / ___ / /_/ / / / / / / / / / /
/_/ |_/_/ |_\____/_/ /_/\____/|__/|__/ /_/ |_\__,_/_/ /_/ /_/_/_/ /_/
""")
cli.cmdloop()
if __name__ == "__main__":
main()
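The role commands above are thin wrappers over the admin REST endpoints. A rough sketch of what the CLI sends on the wire (role name, resource, and actions are illustrative, and the session is assumed to already carry the admin login token established by the CLI's login flow):

import requests

session = requests.Session()  # the CLI's login flow attaches the auth token to this session
base = "http://127.0.0.1:9381/api/v1/admin"

# CREATE ROLE auditor DESCRIPTION 'read-only reviewer';
resp = session.post(f"{base}/roles", json={"role_name": "auditor", "description": "read-only reviewer"})
print(resp.json())

# GRANT read ON datasets TO ROLE auditor;
resp = session.post(f"{base}/roles/auditor/permission", json={"actions": ["read"], "resource": "datasets"})
print(resp.json())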

admin/client/http_client.py
@ -0,0 +1,166 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
import json
from typing import Any, Dict, Optional
import requests
class HttpClient:
def __init__(
self,
host: str = "127.0.0.1",
port: int = 9381,
api_version: str = "v1",
api_key: Optional[str] = None,
connect_timeout: float = 5.0,
read_timeout: float = 60.0,
verify_ssl: bool = False,
) -> None:
self.host = host
self.port = port
self.api_version = api_version
self.api_key = api_key
self.login_token: str | None = None
self.connect_timeout = connect_timeout
self.read_timeout = read_timeout
self.verify_ssl = verify_ssl
def api_base(self) -> str:
return f"{self.host}:{self.port}/api/{self.api_version}"
def non_api_base(self) -> str:
return f"{self.host}:{self.port}/{self.api_version}"
def build_url(self, path: str, use_api_base: bool = True) -> str:
base = self.api_base() if use_api_base else self.non_api_base()
if self.verify_ssl:
return f"https://{base}/{path.lstrip('/')}"
else:
return f"http://{base}/{path.lstrip('/')}"
def _headers(self, auth_kind: Optional[str], extra: Optional[Dict[str, str]]) -> Dict[str, str]:
headers = {}
if auth_kind == "api" and self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
elif auth_kind == "web" and self.login_token:
headers["Authorization"] = self.login_token
elif auth_kind == "admin" and self.login_token:
headers["Authorization"] = self.login_token
else:
pass
if extra:
headers.update(extra)
return headers
def request(
self,
method: str,
path: str,
*,
use_api_base: bool = True,
auth_kind: Optional[str] = "api",
headers: Optional[Dict[str, str]] = None,
json_body: Optional[Dict[str, Any]] = None,
data: Any = None,
files: Any = None,
params: Optional[Dict[str, Any]] = None,
stream: bool = False,
iterations: int = 1,
) -> requests.Response | dict:
url = self.build_url(path, use_api_base=use_api_base)
merged_headers = self._headers(auth_kind, headers)
timeout = (self.connect_timeout, self.read_timeout)
session = requests.Session()
if iterations > 1:
# Benchmark mode: repeat the request and accumulate wall-clock time.
response_list = []
total_duration = 0.0
for _ in range(iterations):
start_time = time.perf_counter()
response = session.request(
method=method,
url=url,
headers=merged_headers,
json=json_body,
data=data,
files=files,
params=params,
timeout=timeout,
stream=stream,
verify=self.verify_ssl,
)
end_time = time.perf_counter()
total_duration += end_time - start_time
response_list.append(response)
return {"duration": total_duration, "response_list": response_list}
else:
# Honor the requested HTTP verb instead of hard-coding GET, and apply
# the configured timeouts and TLS verification setting.
return session.request(
method=method,
url=url,
headers=merged_headers,
json=json_body,
data=data,
files=files,
params=params,
timeout=timeout,
stream=stream,
verify=self.verify_ssl,
)
def request_json(
self,
method: str,
path: str,
*,
use_api_base: bool = True,
auth_kind: Optional[str] = "api",
headers: Optional[Dict[str, str]] = None,
json_body: Optional[Dict[str, Any]] = None,
data: Any = None,
files: Any = None,
params: Optional[Dict[str, Any]] = None,
stream: bool = False,
) -> Dict[str, Any]:
response = self.request(
method,
path,
use_api_base=use_api_base,
auth_kind=auth_kind,
headers=headers,
json_body=json_body,
data=data,
files=files,
params=params,
stream=stream,
)
try:
return response.json()
except Exception as exc:
raise ValueError(f"Non-JSON response from {path}: {exc}") from exc
@staticmethod
def parse_json_bytes(raw: bytes) -> Dict[str, Any]:
try:
return json.loads(raw.decode("utf-8"))
except Exception as exc:
raise ValueError(f"Invalid JSON payload: {exc}") from exc

admin/client/parser.py
@ -0,0 +1,623 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from lark import Transformer
GRAMMAR = r"""
start: command
command: sql_command | meta_command
sql_command: login_user
| ping_server
| list_services
| show_service
| startup_service
| shutdown_service
| restart_service
| register_user
| list_users
| show_user
| drop_user
| alter_user
| create_user
| activate_user
| list_datasets
| list_agents
| create_role
| drop_role
| alter_role
| list_roles
| show_role
| grant_permission
| revoke_permission
| alter_user_role
| show_user_permission
| show_version
| grant_admin
| revoke_admin
| set_variable
| show_variable
| list_variables
| list_configs
| list_environments
| generate_key
| list_keys
| drop_key
| show_current_user
| set_default_llm
| set_default_vlm
| set_default_embedding
| set_default_reranker
| set_default_asr
| set_default_tts
| reset_default_llm
| reset_default_vlm
| reset_default_embedding
| reset_default_reranker
| reset_default_asr
| reset_default_tts
| create_model_provider
| drop_model_provider
| create_user_dataset_with_parser
| create_user_dataset_with_pipeline
| drop_user_dataset
| list_user_datasets
| list_user_dataset_files
| list_user_agents
| list_user_chats
| create_user_chat
| drop_user_chat
| list_user_model_providers
| list_user_default_models
| parse_dataset_docs
| parse_dataset_sync
| parse_dataset_async
| import_docs_into_dataset
| search_on_datasets
| benchmark
// meta command definition
meta_command: "\\" meta_command_name [meta_args]
meta_command_name: /[a-zA-Z?]+/
meta_args: (meta_arg)+
meta_arg: /[^\s"']+/ | quoted_string
// command definition
LOGIN: "LOGIN"i
REGISTER: "REGISTER"i
LIST: "LIST"i
SERVICES: "SERVICES"i
SHOW: "SHOW"i
CREATE: "CREATE"i
SERVICE: "SERVICE"i
SHUTDOWN: "SHUTDOWN"i
STARTUP: "STARTUP"i
RESTART: "RESTART"i
USERS: "USERS"i
DROP: "DROP"i
USER: "USER"i
ALTER: "ALTER"i
ACTIVE: "ACTIVE"i
ADMIN: "ADMIN"i
PASSWORD: "PASSWORD"i
DATASET: "DATASET"i
DATASETS: "DATASETS"i
OF: "OF"i
AGENTS: "AGENTS"i
ROLE: "ROLE"i
ROLES: "ROLES"i
DESCRIPTION: "DESCRIPTION"i
GRANT: "GRANT"i
REVOKE: "REVOKE"i
ALL: "ALL"i
PERMISSION: "PERMISSION"i
TO: "TO"i
FROM: "FROM"i
FOR: "FOR"i
RESOURCES: "RESOURCES"i
ON: "ON"i
SET: "SET"i
RESET: "RESET"i
VERSION: "VERSION"i
VAR: "VAR"i
VARS: "VARS"i
CONFIGS: "CONFIGS"i
ENVS: "ENVS"i
KEY: "KEY"i
KEYS: "KEYS"i
GENERATE: "GENERATE"i
MODEL: "MODEL"i
MODELS: "MODELS"i
PROVIDER: "PROVIDER"i
PROVIDERS: "PROVIDERS"i
DEFAULT: "DEFAULT"i
CHATS: "CHATS"i
CHAT: "CHAT"i
FILES: "FILES"i
AS: "AS"i
PARSE: "PARSE"i
IMPORT: "IMPORT"i
INTO: "INTO"i
WITH: "WITH"i
PARSER: "PARSER"i
PIPELINE: "PIPELINE"i
SEARCH: "SEARCH"i
CURRENT: "CURRENT"i
LLM: "LLM"i
VLM: "VLM"i
EMBEDDING: "EMBEDDING"i
RERANKER: "RERANKER"i
ASR: "ASR"i
TTS: "TTS"i
ASYNC: "ASYNC"i
SYNC: "SYNC"i
BENCHMARK: "BENCHMARK"i
PING: "PING"i
login_user: LOGIN USER quoted_string ";"
list_services: LIST SERVICES ";"
show_service: SHOW SERVICE NUMBER ";"
startup_service: STARTUP SERVICE NUMBER ";"
shutdown_service: SHUTDOWN SERVICE NUMBER ";"
restart_service: RESTART SERVICE NUMBER ";"
register_user: REGISTER USER quoted_string AS quoted_string PASSWORD quoted_string ";"
list_users: LIST USERS ";"
drop_user: DROP USER quoted_string ";"
alter_user: ALTER USER PASSWORD quoted_string quoted_string ";"
show_user: SHOW USER quoted_string ";"
create_user: CREATE USER quoted_string quoted_string ";"
activate_user: ALTER USER ACTIVE quoted_string status ";"
list_datasets: LIST DATASETS OF quoted_string ";"
list_agents: LIST AGENTS OF quoted_string ";"
create_role: CREATE ROLE identifier [DESCRIPTION quoted_string] ";"
drop_role: DROP ROLE identifier ";"
alter_role: ALTER ROLE identifier SET DESCRIPTION quoted_string ";"
list_roles: LIST ROLES ";"
show_role: SHOW ROLE identifier ";"
grant_permission: GRANT identifier_list ON identifier TO ROLE identifier ";"
revoke_permission: REVOKE identifier_list ON identifier FROM ROLE identifier ";"
alter_user_role: ALTER USER quoted_string SET ROLE identifier ";"
show_user_permission: SHOW USER PERMISSION quoted_string ";"
show_version: SHOW VERSION ";"
grant_admin: GRANT ADMIN quoted_string ";"
revoke_admin: REVOKE ADMIN quoted_string ";"
generate_key: GENERATE KEY FOR USER quoted_string ";"
list_keys: LIST KEYS OF quoted_string ";"
drop_key: DROP KEY quoted_string OF quoted_string ";"
set_variable: SET VAR identifier identifier ";"
show_variable: SHOW VAR identifier ";"
list_variables: LIST VARS ";"
list_configs: LIST CONFIGS ";"
list_environments: LIST ENVS ";"
benchmark: BENCHMARK NUMBER NUMBER user_statement
user_statement: ping_server
| show_current_user
| create_model_provider
| drop_model_provider
| set_default_llm
| set_default_vlm
| set_default_embedding
| set_default_reranker
| set_default_asr
| set_default_tts
| reset_default_llm
| reset_default_vlm
| reset_default_embedding
| reset_default_reranker
| reset_default_asr
| reset_default_tts
| create_user_dataset_with_parser
| create_user_dataset_with_pipeline
| drop_user_dataset
| list_user_datasets
| list_user_dataset_files
| list_user_agents
| list_user_chats
| create_user_chat
| drop_user_chat
| list_user_model_providers
| list_user_default_models
| import_docs_into_dataset
| search_on_datasets
ping_server: PING ";"
show_current_user: SHOW CURRENT USER ";"
create_model_provider: CREATE MODEL PROVIDER quoted_string quoted_string ";"
drop_model_provider: DROP MODEL PROVIDER quoted_string ";"
set_default_llm: SET DEFAULT LLM quoted_string ";"
set_default_vlm: SET DEFAULT VLM quoted_string ";"
set_default_embedding: SET DEFAULT EMBEDDING quoted_string ";"
set_default_reranker: SET DEFAULT RERANKER quoted_string ";"
set_default_asr: SET DEFAULT ASR quoted_string ";"
set_default_tts: SET DEFAULT TTS quoted_string ";"
reset_default_llm: RESET DEFAULT LLM ";"
reset_default_vlm: RESET DEFAULT VLM ";"
reset_default_embedding: RESET DEFAULT EMBEDDING ";"
reset_default_reranker: RESET DEFAULT RERANKER ";"
reset_default_asr: RESET DEFAULT ASR ";"
reset_default_tts: RESET DEFAULT TTS ";"
list_user_datasets: LIST DATASETS ";"
create_user_dataset_with_parser: CREATE DATASET quoted_string WITH EMBEDDING quoted_string PARSER quoted_string ";"
create_user_dataset_with_pipeline: CREATE DATASET quoted_string WITH EMBEDDING quoted_string PIPELINE quoted_string ";"
drop_user_dataset: DROP DATASET quoted_string ";"
list_user_dataset_files: LIST FILES OF DATASET quoted_string ";"
list_user_agents: LIST AGENTS ";"
list_user_chats: LIST CHATS ";"
create_user_chat: CREATE CHAT quoted_string ";"
drop_user_chat: DROP CHAT quoted_string ";"
list_user_model_providers: LIST MODEL PROVIDERS ";"
list_user_default_models: LIST DEFAULT MODELS ";"
import_docs_into_dataset: IMPORT quoted_string INTO DATASET quoted_string ";"
search_on_datasets: SEARCH quoted_string ON DATASETS quoted_string ";"
parse_dataset_docs: PARSE quoted_string OF DATASET quoted_string ";"
parse_dataset_sync: PARSE DATASET quoted_string SYNC ";"
parse_dataset_async: PARSE DATASET quoted_string ASYNC ";"
identifier_list: identifier ("," identifier)*
identifier: WORD
quoted_string: QUOTED_STRING
status: WORD
QUOTED_STRING: /'[^']+'/ | /"[^"]+"/
WORD: /[a-zA-Z0-9_\-\.]+/
NUMBER: /[0-9]+/
%import common.WS
%ignore WS
"""
class RAGFlowCLITransformer(Transformer):
def start(self, items):
return items[0]
def command(self, items):
return items[0]
def login_user(self, items):
email = items[2].children[0].strip("'\"")
return {"type": "login_user", "email": email}
def ping_server(self, items):
return {"type": "ping_server"}
def list_services(self, items):
result = {"type": "list_services"}
return result
def show_service(self, items):
service_id = int(items[2])
return {"type": "show_service", "number": service_id}
def startup_service(self, items):
service_id = int(items[2])
return {"type": "startup_service", "number": service_id}
def shutdown_service(self, items):
service_id = int(items[2])
return {"type": "shutdown_service", "number": service_id}
def restart_service(self, items):
service_id = int(items[2])
return {"type": "restart_service", "number": service_id}
def register_user(self, items):
user_name: str = items[2].children[0].strip("'\"")
nickname: str = items[4].children[0].strip("'\"")
password: str = items[6].children[0].strip("'\"")
return {"type": "register_user", "user_name": user_name, "nickname": nickname, "password": password}
def list_users(self, items):
return {"type": "list_users"}
def show_user(self, items):
user_name = items[2]
return {"type": "show_user", "user_name": user_name}
def drop_user(self, items):
user_name = items[2]
return {"type": "drop_user", "user_name": user_name}
def alter_user(self, items):
user_name = items[3]
new_password = items[4]
return {"type": "alter_user", "user_name": user_name, "password": new_password}
def create_user(self, items):
user_name = items[2]
password = items[3]
return {"type": "create_user", "user_name": user_name, "password": password, "role": "user"}
def activate_user(self, items):
user_name = items[3]
activate_status = items[4]
return {"type": "activate_user", "activate_status": activate_status, "user_name": user_name}
def list_datasets(self, items):
user_name = items[3]
return {"type": "list_datasets", "user_name": user_name}
def list_agents(self, items):
user_name = items[3]
return {"type": "list_agents", "user_name": user_name}
def create_role(self, items):
role_name = items[2]
if len(items) > 4:
description = items[4]
return {"type": "create_role", "role_name": role_name, "description": description}
else:
return {"type": "create_role", "role_name": role_name}
def drop_role(self, items):
role_name = items[2]
return {"type": "drop_role", "role_name": role_name}
def alter_role(self, items):
role_name = items[2]
description = items[5]
return {"type": "alter_role", "role_name": role_name, "description": description}
def list_roles(self, items):
return {"type": "list_roles"}
def show_role(self, items):
role_name = items[2]
return {"type": "show_role", "role_name": role_name}
def grant_permission(self, items):
action_list = items[1]
resource = items[3]
role_name = items[6]
return {"type": "grant_permission", "role_name": role_name, "resource": resource, "actions": action_list}
def revoke_permission(self, items):
action_list = items[1]
resource = items[3]
role_name = items[6]
return {"type": "revoke_permission", "role_name": role_name, "resource": resource, "actions": action_list}
def alter_user_role(self, items):
user_name = items[2]
role_name = items[5]
return {"type": "alter_user_role", "user_name": user_name, "role_name": role_name}
def show_user_permission(self, items):
user_name = items[3]
return {"type": "show_user_permission", "user_name": user_name}
def show_version(self, items):
return {"type": "show_version"}
def grant_admin(self, items):
user_name = items[2]
return {"type": "grant_admin", "user_name": user_name}
def revoke_admin(self, items):
user_name = items[2]
return {"type": "revoke_admin", "user_name": user_name}
def generate_key(self, items):
user_name = items[4]
return {"type": "generate_key", "user_name": user_name}
def list_keys(self, items):
user_name = items[3]
return {"type": "list_keys", "user_name": user_name}
def drop_key(self, items):
key = items[2]
user_name = items[4]
return {"type": "drop_key", "key": key, "user_name": user_name}
def set_variable(self, items):
var_name = items[2]
var_value = items[3]
return {"type": "set_variable", "var_name": var_name, "var_value": var_value}
def show_variable(self, items):
var_name = items[2]
return {"type": "show_variable", "var_name": var_name}
def list_variables(self, items):
return {"type": "list_variables"}
def list_configs(self, items):
return {"type": "list_configs"}
def list_environments(self, items):
return {"type": "list_environments"}
def create_model_provider(self, items):
provider_name = items[3].children[0].strip("'\"")
provider_key = items[4].children[0].strip("'\"")
return {"type": "create_model_provider", "provider_name": provider_name, "provider_key": provider_key}
def drop_model_provider(self, items):
provider_name = items[3].children[0].strip("'\"")
return {"type": "drop_model_provider", "provider_name": provider_name}
def show_current_user(self, items):
return {"type": "show_current_user"}
def set_default_llm(self, items):
llm_id = items[3].children[0].strip("'\"")
return {"type": "set_default_model", "model_type": "llm_id", "model_id": llm_id}
def set_default_vlm(self, items):
vlm_id = items[3].children[0].strip("'\"")
return {"type": "set_default_model", "model_type": "img2txt_id", "model_id": vlm_id}
def set_default_embedding(self, items):
embedding_id = items[3].children[0].strip("'\"")
return {"type": "set_default_model", "model_type": "embd_id", "model_id": embedding_id}
def set_default_reranker(self, items):
reranker_id = items[3].children[0].strip("'\"")
return {"type": "set_default_model", "model_type": "reranker_id", "model_id": reranker_id}
def set_default_asr(self, items):
asr_id = items[3].children[0].strip("'\"")
return {"type": "set_default_model", "model_type": "asr_id", "model_id": asr_id}
def set_default_tts(self, items):
tts_id = items[3].children[0].strip("'\"")
return {"type": "set_default_model", "model_type": "tts_id", "model_id": tts_id}
def reset_default_llm(self, items):
return {"type": "reset_default_model", "model_type": "llm_id"}
def reset_default_vlm(self, items):
return {"type": "reset_default_model", "model_type": "img2txt_id"}
def reset_default_embedding(self, items):
return {"type": "reset_default_model", "model_type": "embd_id"}
def reset_default_reranker(self, items):
return {"type": "reset_default_model", "model_type": "reranker_id"}
def reset_default_asr(self, items):
return {"type": "reset_default_model", "model_type": "asr_id"}
def reset_default_tts(self, items):
return {"type": "reset_default_model", "model_type": "tts_id"}
def list_user_datasets(self, items):
return {"type": "list_user_datasets"}
def create_user_dataset_with_parser(self, items):
dataset_name = items[2].children[0].strip("'\"")
embedding = items[5].children[0].strip("'\"")
parser_type = items[7].children[0].strip("'\"")
return {"type": "create_user_dataset", "dataset_name": dataset_name, "embedding": embedding,
"parser_type": parser_type}
def create_user_dataset_with_pipeline(self, items):
dataset_name = items[2].children[0].strip("'\"")
embedding = items[5].children[0].strip("'\"")
pipeline = items[7].children[0].strip("'\"")
return {"type": "create_user_dataset", "dataset_name": dataset_name, "embedding": embedding,
"pipeline": pipeline}
def drop_user_dataset(self, items):
dataset_name = items[2].children[0].strip("'\"")
return {"type": "drop_user_dataset", "dataset_name": dataset_name}
def list_user_dataset_files(self, items):
dataset_name = items[4].children[0].strip("'\"")
return {"type": "list_user_dataset_files", "dataset_name": dataset_name}
def list_user_agents(self, items):
return {"type": "list_user_agents"}
def list_user_chats(self, items):
return {"type": "list_user_chats"}
def create_user_chat(self, items):
chat_name = items[2].children[0].strip("'\"")
return {"type": "create_user_chat", "chat_name": chat_name}
def drop_user_chat(self, items):
chat_name = items[2].children[0].strip("'\"")
return {"type": "drop_user_chat", "chat_name": chat_name}
def list_user_model_providers(self, items):
return {"type": "list_user_model_providers"}
def list_user_default_models(self, items):
return {"type": "list_user_default_models"}
def parse_dataset_docs(self, items):
document_list_str = items[1].children[0].strip("'\"")
document_names = document_list_str.split(",")
if len(document_names) == 1:
document_names = document_names[0]
document_names = document_names.split(" ")
dataset_name = items[4].children[0].strip("'\"")
return {"type": "parse_dataset_docs", "dataset_name": dataset_name, "document_names": document_names}
def parse_dataset_sync(self, items):
dataset_name = items[2].children[0].strip("'\"")
return {"type": "parse_dataset", "dataset_name": dataset_name, "method": "sync"}
def parse_dataset_async(self, items):
dataset_name = items[2].children[0].strip("'\"")
return {"type": "parse_dataset", "dataset_name": dataset_name, "method": "async"}
def import_docs_into_dataset(self, items):
document_list_str = items[1].children[0].strip("'\"")
document_paths = document_list_str.split(",")
if len(document_paths) == 1:
document_paths = document_paths[0]
document_paths = document_paths.split(" ")
dataset_name = items[4].children[0].strip("'\"")
return {"type": "import_docs_into_dataset", "dataset_name": dataset_name, "document_paths": document_paths}
def search_on_datasets(self, items):
question = items[1].children[0].strip("'\"")
datasets_str = items[4].children[0].strip("'\"")
datasets = datasets_str.split(",")
if len(datasets) == 1:
datasets = datasets[0]
datasets = datasets.split(" ")
return {"type": "search_on_datasets", "datasets": datasets, "question": question}
def benchmark(self, items):
concurrency: int = int(items[1])
iterations: int = int(items[2])
command = items[3].children[0]
return {"type": "benchmark", "concurrency": concurrency, "iterations": iterations, "command": command}
def identifier_list(self, items):
return items
def meta_command(self, items):
command_name = str(items[0]).lower()
args = items[1:] if len(items) > 1 else []
# handle quoted parameter
parsed_args = []
for arg in args:
if hasattr(arg, "value"):
parsed_args.append(arg.value)
else:
parsed_args.append(str(arg))
return {"type": "meta", "command": command_name, "args": parsed_args}
def meta_command_name(self, items):
return items[0]
def meta_args(self, items):
return items
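Because the Lark parser is built with parser="lalr" and an inline transformer, each statement is reduced to a plain dict as it is parsed. Note that the sql_command wrapper rule has no transformer method, so the dict sits one level down inside a Tree, which is why execute_command() unwraps children[0]. A small sketch:

from lark import Lark

parser = Lark(GRAMMAR, start="start", parser="lalr", transformer=RAGFlowCLITransformer())
tree = parser.parse("LIST USERS;")
print(tree.children[0])  # -> {'type': 'list_users'}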


@ -1,6 +1,6 @@
[project]
name = "ragflow-cli"
-version = "0.23.0"
+version = "0.23.1"
description = "Admin Service's client of [RAGFlow](https://github.com/infiniflow/ragflow). The Admin Service provides user management and system monitoring. "
authors = [{ name = "Lynn", email = "lynn_inf@hotmail.com" }]
license = { text = "Apache License, Version 2.0" }
@ -20,5 +20,8 @@ test = [
"requests-toolbelt>=1.0.0",
]
+[tool.setuptools]
+py-modules = ["ragflow_cli", "parser"]
[project.scripts]
-ragflow-cli = "admin_client:main"
+ragflow-cli = "ragflow_cli:main"

admin/client/ragflow_cli.py

@ -0,0 +1,322 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
import argparse
import base64
import getpass
from cmd import Cmd
from typing import Any, Dict, List
import requests
import warnings
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
from Cryptodome.PublicKey import RSA
from lark import Lark, Tree
from parser import GRAMMAR, RAGFlowCLITransformer
from http_client import HttpClient
from ragflow_client import RAGFlowClient, run_command
from user import login_user
warnings.filterwarnings("ignore", category=getpass.GetPassWarning)
def encrypt(input_string):
pub = "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOOUEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVKRNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs2wIDAQAB\n-----END PUBLIC KEY-----"
pub_key = RSA.importKey(pub)
cipher = Cipher_pkcs1_v1_5.new(pub_key)
cipher_text = cipher.encrypt(base64.b64encode(input_string.encode("utf-8")))
return base64.b64encode(cipher_text).decode("utf-8")
def encode_to_base64(input_string):
base64_encoded = base64.b64encode(input_string.encode("utf-8"))
return base64_encoded.decode("utf-8")
class RAGFlowCLI(Cmd):
def __init__(self):
super().__init__()
self.parser = Lark(GRAMMAR, start="start", parser="lalr", transformer=RAGFlowCLITransformer())
self.command_history = []
self.account = "admin@ragflow.io"
self.account_password: str = "admin"
self.session = requests.Session()
self.host: str = ""
self.port: int = 0
self.mode: str = "admin"
self.ragflow_client = None
intro = r"""Type "\h" for help."""
prompt = "ragflow> "
def onecmd(self, command: str) -> bool:
try:
result = self.parse_command(command)
if isinstance(result, Tree):
# LALR with an inline transformer leaves sql commands wrapped in a Tree;
# execute_command() unwraps and runs them.
self.execute_command(result)
return False
if isinstance(result, dict):
if result.get("type") == "empty":
return False
self.execute_command(result)
if result.get("type") == "meta" and result.get("command") in ["q", "quit", "exit"]:
return True
except KeyboardInterrupt:
print("\nUse '\\q' to quit")
except EOFError:
print("\nGoodbye!")
return True
return False
def emptyline(self) -> bool:
return False
def default(self, line: str) -> bool:
return self.onecmd(line)
def parse_command(self, command_str: str) -> dict[str, str]:
if not command_str.strip():
return {"type": "empty"}
self.command_history.append(command_str)
try:
result = self.parser.parse(command_str)
return result
except Exception as e:
return {"type": "error", "message": f"Parse error: {str(e)}"}
def verify_auth(self, arguments: dict, single_command: bool, auth: bool):
server_type = arguments.get("type", "admin")
http_client = HttpClient(arguments["host"], arguments["port"])
if not auth:
self.ragflow_client = RAGFlowClient(http_client, server_type)
return True
user_name = arguments["username"]
attempt_count = 3
if single_command:
attempt_count = 1
try_count = 0
while True:
try_count += 1
if try_count > attempt_count:
return False
if single_command:
user_password = arguments["password"]
else:
user_password = getpass.getpass(f"password for {user_name}: ").strip()
try:
token = login_user(http_client, server_type, user_name, user_password)
http_client.login_token = token
self.ragflow_client = RAGFlowClient(http_client, server_type)
return True
except Exception as e:
print(str(e))
print("Can't access server for login (connection failed)")
def _format_service_detail_table(self, data):
if isinstance(data, list):
return data
if not all([isinstance(v, list) for v in data.values()]):
# normal table
return data
# handle task_executor heartbeats map, for example {'name': [{'done': 2, 'now': timestamp1}, {'done': 3, 'now': timestamp2}]
task_executor_list = []
for k, v in data.items():
# display latest status
heartbeats = sorted(v, key=lambda x: x["now"], reverse=True)
task_executor_list.append(
{
"task_executor_name": k,
**heartbeats[0],
}
if heartbeats
else {"task_executor_name": k}
)
return task_executor_list
def _print_table_simple(self, data):
if not data:
print("No data to print")
return
if isinstance(data, dict):
# handle single row data
data = [data]
columns = list(set().union(*(d.keys() for d in data)))
columns.sort()
col_widths = {}
def get_string_width(text):
half_width_chars = " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\t\n\r"
width = 0
for char in text:
if char in half_width_chars:
width += 1
else:
width += 2
return width
for col in columns:
max_width = get_string_width(str(col))
for item in data:
value_len = get_string_width(str(item.get(col, "")))
if value_len > max_width:
max_width = value_len
col_widths[col] = max(2, max_width)
# Generate delimiter
separator = "+" + "+".join(["-" * (col_widths[col] + 2) for col in columns]) + "+"
# Print header
print(separator)
header = "|" + "|".join([f" {col:<{col_widths[col]}} " for col in columns]) + "|"
print(header)
print(separator)
# Print data
for item in data:
row = "|"
for col in columns:
value = str(item.get(col, ""))
if get_string_width(value) > col_widths[col]:
value = value[: col_widths[col] - 3] + "..."
row += f" {value:<{col_widths[col] - (get_string_width(value) - len(value))}} |"
print(row)
print(separator)
def run_interactive(self, args):
if self.verify_auth(args, single_command=False, auth=args["auth"]):
print(r"""
____ ___ ______________ ________ ____
/ __ \/ | / ____/ ____/ /___ _ __ / ____/ / / _/
/ /_/ / /| |/ / __/ /_ / / __ \ | /| / / / / / / / /
/ _, _/ ___ / /_/ / __/ / / /_/ / |/ |/ / / /___/ /____/ /
/_/ |_/_/ |_\____/_/ /_/\____/|__/|__/ \____/_____/___/
""")
print("RAGFlow command line interface - Type '\\?' for help, '\\q' to quit")
self.cmdloop()
def run_single_command(self, args):
if self.verify_auth(args, single_command=True, auth=args["auth"]):
command = args["command"]
result = self.parse_command(command)
self.execute_command(result)
def parse_connection_args(self, args: List[str]) -> Dict[str, Any]:
parser = argparse.ArgumentParser(description="RAGFlow CLI Client", add_help=False)
parser.add_argument("-h", "--host", default="127.0.0.1", help="Admin or RAGFlow service host")
parser.add_argument("-p", "--port", type=int, default=9381, help="Admin or RAGFlow service port")
parser.add_argument("-w", "--password", default="admin", type=str, help="Superuser password")
parser.add_argument("-t", "--type", default="admin", type=str, help="CLI mode, admin or user")
parser.add_argument("-u", "--username", default=None,
help="Username (email). In admin mode defaults to admin@ragflow.io, in user mode required.")
parser.add_argument("command", nargs="?", help="Single command")
try:
parsed_args, remaining_args = parser.parse_known_args(args)
# Determine username based on mode
username = parsed_args.username
if parsed_args.type == "admin":
if username is None:
username = "admin@ragflow.io"
if remaining_args:
if remaining_args[0] == "command":
command_str = ' '.join(remaining_args[1:]) + ';'
auth = True
if len(remaining_args) > 1 and remaining_args[1] == "register":
auth = False
else:
if username is None:
print("Error: username (-u) is required in user mode")
return {"error": "Username required"}
return {
"host": parsed_args.host,
"port": parsed_args.port,
"password": parsed_args.password,
"type": parsed_args.type,
"username": username,
"command": command_str,
"auth": auth
}
else:
return {"error": "Invalid command"}
else:
auth = True
if username is None:
auth = False
return {
"host": parsed_args.host,
"port": parsed_args.port,
"type": parsed_args.type,
"username": username,
"auth": auth
}
except SystemExit:
return {"error": "Invalid connection arguments"}
def execute_command(self, parsed_command: Dict[str, Any]):
command_dict: dict
if isinstance(parsed_command, Tree):
command_dict = parsed_command.children[0]
else:
if parsed_command["type"] == "error":
print(f"Error: {parsed_command['message']}")
return
else:
command_dict = parsed_command
# print(f"Parsed command: {command_dict}")
run_command(self.ragflow_client, command_dict)
def main():
cli = RAGFlowCLI()
args = cli.parse_connection_args(sys.argv)
if "error" in args:
print("Error: Invalid connection arguments")
return
if "command" in args:
# single command mode
# for user mode, api key or password is ok
# for admin mode, only password
if "password" not in args:
print("Error: password is missing")
return
cli.run_single_command(args)
else:
cli.run_interactive(args)
if __name__ == "__main__":
main()
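parse_connection_args() is fed sys.argv directly, so argv[0] (the script path) is absorbed by the positional slot and the literal word "command" introduces a single-statement invocation. A sketch with illustrative values:

cli = RAGFlowCLI()
# Interactive user mode: connection arguments only.
args = cli.parse_connection_args(["ragflow_cli.py", "-t", "user", "-u", "someone@example.com"])
# -> {'host': '127.0.0.1', 'port': 9381, 'type': 'user', 'username': 'someone@example.com', 'auth': True}
# Single-command mode: everything after the "command" keyword becomes one statement.
args = cli.parse_connection_args(["ragflow_cli.py", "command", "LIST", "USERS"])
# -> args["command"] == "LIST USERS;"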

File diff suppressed because it is too large

admin/client/user.py

@ -0,0 +1,65 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from http_client import HttpClient
class AuthException(Exception):
def __init__(self, message, code=401):
super().__init__(message)
self.code = code
self.message = message
def encrypt_password(password_plain: str) -> str:
try:
from api.utils.crypt import crypt
except Exception as exc:
raise AuthException(
"Password encryption unavailable; install pycryptodomex (uv sync --python 3.12 --group test)."
) from exc
return crypt(password_plain)
def register_user(client: HttpClient, email: str, nickname: str, password: str) -> None:
password_enc = encrypt_password(password)
payload = {"email": email, "nickname": nickname, "password": password_enc}
res = client.request_json("POST", "/user/register", use_api_base=False, auth_kind=None, json_body=payload)
if res.get("code") == 0:
return
msg = res.get("message", "")
if "has already registered" in msg:
return
raise AuthException(f"Register failed: {msg}")
def login_user(client: HttpClient, server_type: str, email: str, password: str) -> str:
password_enc = encrypt_password(password)
payload = {"email": email, "password": password_enc}
if server_type == "admin":
response = client.request("POST", "/admin/login", use_api_base=True, auth_kind=None, json_body=payload)
else:
response = client.request("POST", "/user/login", use_api_base=False, auth_kind=None, json_body=payload)
try:
res = response.json()
except Exception as exc:
raise AuthException(f"Login failed: invalid JSON response ({exc})") from exc
if res.get("code") != 0:
raise AuthException(f"Login failed: {res.get('message')}")
token = response.headers.get("Authorization")
if not token:
raise AuthException("Login failed: missing Authorization header")
return token
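login_user() mirrors the web login flow: it encrypts the password the same way the frontend does, posts to /admin/login or /user/login depending on the server type, and lifts the session token from the Authorization response header. A minimal sketch using the defaults the CLI ships with (credentials illustrative):

client = HttpClient(host="127.0.0.1", port=9381)
try:
    token = login_user(client, "admin", "admin@ragflow.io", "admin")
    client.login_token = token  # subsequent auth_kind="admin"/"web" requests reuse it
except AuthException as e:
    print(f"login rejected ({e.code}): {e.message}")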

admin/client/uv.lock

@ -196,7 +196,7 @@ wheels = [
[[package]]
name = "ragflow-cli"
-version = "0.23.0"
+version = "0.23.1"
source = { virtual = "." }
dependencies = [
{ name = "beartype" },

@ -14,10 +14,12 @@
# limitations under the License.
#
+import time
+start_ts = time.time()
import os
import signal
import logging
-import time
import threading
import traceback
import faulthandler
@ -66,7 +68,7 @@ if __name__ == '__main__':
SERVICE_CONFIGS.configs = load_configurations(SERVICE_CONF)
try:
-logging.info("RAGFlow Admin service start...")
+logging.info(f"RAGFlow admin is ready after {time.time() - start_ts}s initialization.")
run_simple(
hostname="0.0.0.0",
port=9381,


@ -15,29 +15,33 @@
#
import secrets
-from flask import Blueprint, request
+from typing import Any
+from common.time_utils import current_timestamp, datetime_format
+from datetime import datetime
+from flask import Blueprint, Response, request
from flask_login import current_user, login_required, logout_user
from auth import login_verify, login_admin, check_admin_auth
from responses import success_response, error_response
-from services import UserMgr, ServiceMgr, UserServiceMgr
+from services import UserMgr, ServiceMgr, UserServiceMgr, SettingsMgr, ConfigMgr, EnvironmentsMgr
from roles import RoleMgr
from api.common.exceptions import AdminException
from common.versions import get_ragflow_version
+from api.utils.api_utils import generate_confirmation_token
-admin_bp = Blueprint('admin', __name__, url_prefix='/api/v1/admin')
+admin_bp = Blueprint("admin", __name__, url_prefix="/api/v1/admin")
-@admin_bp.route('/ping', methods=['GET'])
+@admin_bp.route("/ping", methods=["GET"])
def ping():
-return success_response('PONG')
+return success_response("PONG")
-@admin_bp.route('/login', methods=['POST'])
+@admin_bp.route("/login", methods=["POST"])
def login():
if not request.json:
-return error_response('Authorize admin failed.' ,400)
+return error_response("Authorize admin failed.", 400)
try:
email = request.json.get("email", "")
password = request.json.get("password", "")
@ -46,7 +50,7 @@ def login():
return error_response(str(e), 500)
-@admin_bp.route('/logout', methods=['GET'])
+@admin_bp.route("/logout", methods=["GET"])
@login_required
def logout():
try:
@ -58,7 +62,7 @@ def logout():
return error_response(str(e), 500)
-@admin_bp.route('/auth', methods=['GET'])
+@admin_bp.route("/auth", methods=["GET"])
@login_verify
def auth_admin():
try:
@ -67,7 +71,7 @@ def auth_admin():
return error_response(str(e), 500)
-@admin_bp.route('/users', methods=['GET'])
+@admin_bp.route("/users", methods=["GET"])
@login_required
@check_admin_auth
def list_users():
@ -78,18 +82,18 @@ def list_users():
return error_response(str(e), 500)
-@admin_bp.route('/users', methods=['POST'])
+@admin_bp.route("/users", methods=["POST"])
@login_required
@check_admin_auth
def create_user():
try:
data = request.get_json()
-if not data or 'username' not in data or 'password' not in data:
+if not data or "username" not in data or "password" not in data:
return error_response("Username and password are required", 400)
-username = data['username']
+username = data["username"]
-password = data['password']
+password = data["password"]
-role = data.get('role', 'user')
+role = data.get("role", "user")
res = UserMgr.create_user(username, password, role)
if res["success"]:
@ -105,7 +109,7 @@ def create_user():
return error_response(str(e))
-@admin_bp.route('/users/<username>', methods=['DELETE'])
+@admin_bp.route("/users/<username>", methods=["DELETE"])
@login_required
@check_admin_auth
def delete_user(username):
@ -122,16 +126,16 @@ def delete_user(username):
return error_response(str(e), 500)
-@admin_bp.route('/users/<username>/password', methods=['PUT'])
+@admin_bp.route("/users/<username>/password", methods=["PUT"])
@login_required
@check_admin_auth
def change_password(username):
try:
data = request.get_json()
-if not data or 'new_password' not in data:
+if not data or "new_password" not in data:
return error_response("New password is required", 400)
-new_password = data['new_password']
+new_password = data["new_password"]
msg = UserMgr.update_user_password(username, new_password)
return success_response(None, msg)
@ -141,15 +145,15 @@ def change_password(username):
return error_response(str(e), 500)
-@admin_bp.route('/users/<username>/activate', methods=['PUT'])
+@admin_bp.route("/users/<username>/activate", methods=["PUT"])
@login_required
@check_admin_auth
def alter_user_activate_status(username):
try:
data = request.get_json()
-if not data or 'activate_status' not in data:
+if not data or "activate_status" not in data:
return error_response("Activation status is required", 400)
-activate_status = data['activate_status']
+activate_status = data["activate_status"]
msg = UserMgr.update_user_activate_status(username, activate_status)
return success_response(None, msg)
except AdminException as e:
@ -158,7 +162,39 @@ def alter_user_activate_status(username):
return error_response(str(e), 500)
-@admin_bp.route('/users/<username>', methods=['GET'])
+@admin_bp.route("/users/<username>/admin", methods=["PUT"])
+@login_required
+@check_admin_auth
+def grant_admin(username):
+try:
+if current_user.email == username:
+return error_response(f"can't grant current user: {username}", 409)
+msg = UserMgr.grant_admin(username)
+return success_response(None, msg)
+except AdminException as e:
+return error_response(e.message, e.code)
+except Exception as e:
+return error_response(str(e), 500)
+@admin_bp.route("/users/<username>/admin", methods=["DELETE"])
+@login_required
+@check_admin_auth
+def revoke_admin(username):
+try:
+if current_user.email == username:
+return error_response(f"can't revoke current user: {username}", 409)
+msg = UserMgr.revoke_admin(username)
+return success_response(None, msg)
+except AdminException as e:
+return error_response(e.message, e.code)
+except Exception as e:
+return error_response(str(e), 500)
+@admin_bp.route("/users/<username>", methods=["GET"])
@login_required
@check_admin_auth
def get_user_details(username):
@ -172,7 +208,7 @@ def get_user_details(username):
return error_response(str(e), 500)
-@admin_bp.route('/users/<username>/datasets', methods=['GET'])
+@admin_bp.route("/users/<username>/datasets", methods=["GET"])
@login_required
@check_admin_auth
def get_user_datasets(username):
@ -186,7 +222,7 @@ def get_user_datasets(username):
return error_response(str(e), 500)
-@admin_bp.route('/users/<username>/agents', methods=['GET'])
+@admin_bp.route("/users/<username>/agents", methods=["GET"])
@login_required
@check_admin_auth
def get_user_agents(username):
@ -200,7 +236,7 @@ def get_user_agents(username):
return error_response(str(e), 500)
-@admin_bp.route('/services', methods=['GET'])
+@admin_bp.route("/services", methods=["GET"])
@login_required
@check_admin_auth
def get_services():
@ -211,7 +247,7 @@ def get_services():
return error_response(str(e), 500)
-@admin_bp.route('/service_types/<service_type>', methods=['GET'])
+@admin_bp.route("/service_types/<service_type>", methods=["GET"])
@login_required
@check_admin_auth
def get_services_by_type(service_type_str):
@ -222,7 +258,7 @@ def get_services_by_type(service_type_str):
return error_response(str(e), 500)
-@admin_bp.route('/services/<service_id>', methods=['GET'])
+@admin_bp.route("/services/<service_id>", methods=["GET"])
@login_required
@check_admin_auth
def get_service(service_id):
@ -233,7 +269,7 @@ def get_service(service_id):
return error_response(str(e), 500)
-@admin_bp.route('/services/<service_id>', methods=['DELETE'])
+@admin_bp.route("/services/<service_id>", methods=["DELETE"])
@login_required
@check_admin_auth
def shutdown_service(service_id):
@ -244,7 +280,7 @@ def shutdown_service(service_id):
return error_response(str(e), 500)
-@admin_bp.route('/services/<service_id>', methods=['PUT'])
+@admin_bp.route("/services/<service_id>", methods=["PUT"])
@login_required
@check_admin_auth
def restart_service(service_id):
@ -255,38 +291,38 @@ def restart_service(service_id):
return error_response(str(e), 500)
-@admin_bp.route('/roles', methods=['POST'])
+@admin_bp.route("/roles", methods=["POST"])
@login_required
@check_admin_auth
def create_role():
try:
data = request.get_json()
-if not data or 'role_name' not in data:
+if not data or "role_name" not in data:
return error_response("Role name is required", 400)
-role_name: str = data['role_name']
+role_name: str = data["role_name"]
-description: str = data['description']
+description: str = data["description"]
res = RoleMgr.create_role(role_name, description)
return success_response(res)
except Exception as e:
return error_response(str(e), 500)
-@admin_bp.route('/roles/<role_name>', methods=['PUT'])
+@admin_bp.route("/roles/<role_name>", methods=["PUT"])
@login_required
@check_admin_auth
def update_role(role_name: str):
try:
data = request.get_json()
-if not data or 'description' not in data:
+if not data or "description" not in data:
return error_response("Role description is required", 400)
-description: str = data['description']
+description: str = data["description"]
res = RoleMgr.update_role_description(role_name, description)
return success_response(res)
except Exception as e:
return error_response(str(e), 500)
-@admin_bp.route('/roles/<role_name>', methods=['DELETE'])
+@admin_bp.route("/roles/<role_name>", methods=["DELETE"])
@login_required
@check_admin_auth
def delete_role(role_name: str):
@ -297,7 +333,7 @@ def delete_role(role_name: str):
return error_response(str(e), 500)
-@admin_bp.route('/roles', methods=['GET'])
+@admin_bp.route("/roles", methods=["GET"])
@login_required
@check_admin_auth
def list_roles():
@ -308,7 +344,7 @@ def list_roles():
return error_response(str(e), 500)
-@admin_bp.route('/roles/<role_name>/permission', methods=['GET'])
+@admin_bp.route("/roles/<role_name>/permission", methods=["GET"])
@login_required
@check_admin_auth
def get_role_permission(role_name: str):
@ -319,54 +355,54 @@ def get_role_permission(role_name: str):
return error_response(str(e), 500)
-@admin_bp.route('/roles/<role_name>/permission', methods=['POST'])
+@admin_bp.route("/roles/<role_name>/permission", methods=["POST"])
@login_required
@check_admin_auth
def grant_role_permission(role_name: str):
try:
data = request.get_json()
-if not data or 'actions' not in data or 'resource' not in data:
+if not data or "actions" not in data or "resource" not in data:
return error_response("Permission is required", 400)
-actions: list = data['actions']
+actions: list = data["actions"]
-resource: str = data['resource']
+resource: str = data["resource"]
res = RoleMgr.grant_role_permission(role_name, actions, resource)
return success_response(res)
except Exception as e:
return error_response(str(e), 500)
-@admin_bp.route('/roles/<role_name>/permission', methods=['DELETE'])
+@admin_bp.route("/roles/<role_name>/permission", methods=["DELETE"])
@login_required
@check_admin_auth
def revoke_role_permission(role_name: str):
try:
data = request.get_json()
-if not data or 'actions' not in data or 'resource' not in data:
+if not data or "actions" not in data or "resource" not in data:
return error_response("Permission is required", 400)
-actions: list = data['actions']
+actions: list = data["actions"]
-resource: str = data['resource']
+resource: str = data["resource"]
res = RoleMgr.revoke_role_permission(role_name, actions, resource)
return success_response(res)
except Exception as e:
return error_response(str(e), 500)
-@admin_bp.route('/users/<user_name>/role', methods=['PUT'])
+@admin_bp.route("/users/<user_name>/role", methods=["PUT"])
@login_required
@check_admin_auth
def update_user_role(user_name: str):
try:
data = request.get_json()
-if not data or 'role_name' not in data:
+if not data or "role_name" not in data:
return error_response("Role name is required", 400)
-role_name: str = data['role_name']
+role_name: str = data["role_name"]
res = RoleMgr.update_user_role(user_name, role_name)
return success_response(res)
except Exception as e:
return error_response(str(e), 500)
-@admin_bp.route('/users/<user_name>/permission', methods=['GET'])
+@admin_bp.route("/users/<user_name>/permission", methods=["GET"])
@login_required
@check_admin_auth
def get_user_permission(user_name: str):
@ -376,7 +412,140 @@ def get_user_permission(user_name: str):
except Exception as e:
return error_response(str(e), 500)
-@admin_bp.route('/version', methods=['GET'])
@admin_bp.route("/variables", methods=["PUT"])
@login_required
@check_admin_auth
def set_variable():
try:
data = request.get_json()
if not data and "var_name" not in data:
return error_response("Var name is required", 400)
if "var_value" not in data:
return error_response("Var value is required", 400)
var_name: str = data["var_name"]
var_value: str = data["var_value"]
SettingsMgr.update_by_name(var_name, var_value)
return success_response(None, "Set variable successfully")
except AdminException as e:
return error_response(str(e), 400)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route("/variables", methods=["GET"])
@login_required
@check_admin_auth
def get_variable():
try:
if request.content_length is None or request.content_length == 0:
# list variables
res = list(SettingsMgr.get_all())
return success_response(res)
# get var
data = request.get_json()
if not data and "var_name" not in data:
return error_response("Var name is required", 400)
var_name: str = data["var_name"]
res = SettingsMgr.get_by_name(var_name)
return success_response(res)
except AdminException as e:
return error_response(str(e), 400)
except Exception as e:
return error_response(str(e), 500)
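A quick client-side check of the two /variables verbs. Everything below is illustrative and not taken from this diff: the base URL, the auth header scheme, and the variable name are assumptions. Note that the GET handler dispatches on content_length, so a body-less GET lists everything while a JSON body fetches one entry (GET requests carrying a body are unusual and some proxies drop them).

import requests  # hypothetical client sketch, not part of the PR

BASE = "http://localhost:9380/api/v1/admin"  # assumed mount point
HEADERS = {"Authorization": "Bearer <admin-token>"}  # assumed scheme

# No body: the handler sees an empty content_length and lists all variables.
print(requests.get(f"{BASE}/variables", headers=HEADERS).json())

# JSON body: the handler fetches one variable by name ("max_doc_size" is made up).
print(requests.get(f"{BASE}/variables", headers=HEADERS, json={"var_name": "max_doc_size"}).json())

# PUT updates the value.
print(requests.put(f"{BASE}/variables", headers=HEADERS, json={"var_name": "max_doc_size", "var_value": "128"}).json())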
@admin_bp.route("/configs", methods=["GET"])
@login_required
@check_admin_auth
def get_config():
try:
res = list(ConfigMgr.get_all())
return success_response(res)
except AdminException as e:
return error_response(str(e), 400)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route("/environments", methods=["GET"])
@login_required
@check_admin_auth
def get_environments():
try:
res = list(EnvironmentsMgr.get_all())
return success_response(res)
except AdminException as e:
return error_response(str(e), 400)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route("/users/<username>/keys", methods=["POST"])
@login_required
@check_admin_auth
def generate_user_api_key(username: str) -> tuple[Response, int]:
try:
user_details: list[dict[str, Any]] = UserMgr.get_user_details(username)
if not user_details:
return error_response("User not found!", 404)
tenants: list[dict[str, Any]] = UserServiceMgr.get_user_tenants(username)
if not tenants:
return error_response("Tenant not found!", 404)
tenant_id: str = tenants[0]["tenant_id"]
key: str = generate_confirmation_token()
obj: dict[str, Any] = {
"tenant_id": tenant_id,
"token": key,
"beta": generate_confirmation_token().replace("ragflow-", "")[:32],
"create_time": current_timestamp(),
"create_date": datetime_format(datetime.now()),
"update_time": None,
"update_date": None,
}
if not UserMgr.save_api_key(obj):
return error_response("Failed to generate API key!", 500)
return success_response(obj, "API key generated successfully")
except AdminException as e:
return error_response(e.message, e.code)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route("/users/<username>/keys", methods=["GET"])
@login_required
@check_admin_auth
def get_user_api_keys(username: str) -> tuple[Response, int]:
try:
api_keys: list[dict[str, Any]] = UserMgr.get_user_api_key(username)
return success_response(api_keys, "Get user API keys")
except AdminException as e:
return error_response(e.message, e.code)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route("/users/<username>/keys/<key>", methods=["DELETE"])
@login_required
@check_admin_auth
def delete_user_api_key(username: str, key: str) -> tuple[Response, int]:
try:
deleted = UserMgr.delete_api_key(username, key)
if deleted:
return success_response(None, "API key deleted successfully")
else:
return error_response("API key not found or could not be deleted", 404)
except AdminException as e:
return error_response(e.message, e.code)
except Exception as e:
return error_response(str(e), 500)
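The three key endpoints compose naturally. A minimal walkthrough, assuming the admin blueprint is mounted at the URL below, that success_response nests its payload under "data", and that the auth header scheme matches the server's; none of these assumptions come from the diff itself.

import requests  # illustrative sketch only

BASE = "http://localhost:9380/api/v1/admin"  # assumed mount point
HEADERS = {"Authorization": "Bearer <admin-token>"}  # assumed scheme
EMAIL = "user@example.com"  # the routes look users up by email

# Generate a key bound to the user's first tenant, then list and revoke it.
created = requests.post(f"{BASE}/users/{EMAIL}/keys", headers=HEADERS).json()
token = created["data"]["token"]  # assumes the payload sits under "data"
print(requests.get(f"{BASE}/users/{EMAIL}/keys", headers=HEADERS).json())
print(requests.delete(f"{BASE}/users/{EMAIL}/keys/{token}", headers=HEADERS).json())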
@admin_bp.route("/version", methods=["GET"])
@login_required @login_required
@check_admin_auth @check_admin_auth
def show_version(): def show_version():

View File

@@ -13,15 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import logging
import re
from typing import Any
from werkzeug.security import check_password_hash
from common.constants import ActiveEnum
from api.db.services import UserService
from api.db.joint_services.user_account_service import create_new_user, delete_user_data
from api.db.services.canvas_service import UserCanvasService
from api.db.services.user_service import TenantService, UserTenantService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.system_settings_service import SystemSettingsService
from api.db.services.api_service import APITokenService
from api.db.db_models import APIToken
from api.utils.crypt import decrypt
from api.utils import health_utils
@@ -35,13 +42,15 @@ class UserMgr:
        users = UserService.get_all_users()
        result = []
        for user in users:
            result.append(
                {
                    "email": user.email,
                    "nickname": user.nickname,
                    "create_date": user.create_date,
                    "is_active": user.is_active,
                    "is_superuser": user.is_superuser,
                }
            )
        return result
    @staticmethod

@@ -50,19 +59,21 @@ class UserMgr:
        users = UserService.query_user_by_email(username)
        result = []
        for user in users:
            result.append(
                {
                    "avatar": user.avatar,
                    "email": user.email,
                    "language": user.language,
                    "last_login_time": user.last_login_time,
                    "is_active": user.is_active,
                    "is_anonymous": user.is_anonymous,
                    "login_channel": user.login_channel,
                    "status": user.status,
                    "is_superuser": user.is_superuser,
                    "create_date": user.create_date,
                    "update_date": user.update_date,
                }
            )
        return result
    @staticmethod

@@ -124,8 +135,8 @@ class UserMgr:
        # normalize activate_status before handling it
        _activate_status = activate_status.lower()
        target_status = {
            "on": ActiveEnum.ACTIVE.value,
            "off": ActiveEnum.INACTIVE.value,
        }.get(_activate_status)
        if not target_status:
            raise AdminException(f"Invalid activate_status: {activate_status}")
@@ -135,9 +146,84 @@ class UserMgr:
        UserService.update_user(usr.id, {"is_active": target_status})
        return f"Turn {_activate_status} user activate status successfully!"
    @staticmethod
    def get_user_api_key(username: str) -> list[dict[str, Any]]:
        # use email to find the user; it must exist and be unique
        user_list: list[Any] = UserService.query_user_by_email(username)
        if not user_list:
            raise UserNotFoundError(username)
        elif len(user_list) > 1:
            raise AdminException(f"More than one user with username '{username}' found!")
        usr: Any = user_list[0]
        # tenant_id is typically the same as user_id for the owner tenant
        tenant_id: str = usr.id
        # Query all API keys for this tenant
        api_keys: Any = APITokenService.query(tenant_id=tenant_id)
        result: list[dict[str, Any]] = []
        for key in api_keys:
            result.append(key.to_dict())
        return result

    @staticmethod
    def save_api_key(api_key: dict[str, Any]) -> bool:
        return APITokenService.save(**api_key)

    @staticmethod
    def delete_api_key(username: str, key: str) -> bool:
        # use email to find the user; it must exist and be unique
        user_list: list[Any] = UserService.query_user_by_email(username)
        if not user_list:
            raise UserNotFoundError(username)
        elif len(user_list) > 1:
            raise AdminException(f"More than one user with username '{username}' found!")
        usr: Any = user_list[0]
        # tenant_id is typically the same as user_id for the owner tenant
        tenant_id: str = usr.id
        # Delete the API key
        deleted_count: int = APITokenService.filter_delete([APIToken.tenant_id == tenant_id, APIToken.token == key])
        return deleted_count > 0

    @staticmethod
    def grant_admin(username: str):
        # use email to find the user; it must exist and be unique
        user_list = UserService.query_user_by_email(username)
        if not user_list:
            raise UserNotFoundError(username)
        elif len(user_list) > 1:
            raise AdminException(f"More than one user with username '{username}' found!")
        # no-op if the user is already a superuser
        usr = user_list[0]
        if usr.is_superuser:
            return f"{usr.email} is already a superuser!"
        # update is_superuser
        UserService.update_user(usr.id, {"is_superuser": True})
        return "Grant successfully!"

    @staticmethod
    def revoke_admin(username: str):
        # use email to find the user; it must exist and be unique
        user_list = UserService.query_user_by_email(username)
        if not user_list:
            raise UserNotFoundError(username)
        elif len(user_list) > 1:
            raise AdminException(f"More than one user with username '{username}' found!")
        # no-op if the user is not a superuser
        usr = user_list[0]
        if not usr.is_superuser:
            return f"{usr.email} isn't a superuser yet!"
        # update is_superuser
        UserService.update_user(usr.id, {"is_superuser": False})
        return "Revoke successfully!"
class UserServiceMgr:
    @staticmethod
    def get_user_datasets(username):
        # use email to find user.
@ -167,35 +253,43 @@ class UserServiceMgr:
tenant_ids = [m["tenant_id"] for m in tenants] tenant_ids = [m["tenant_id"] for m in tenants]
# filter permitted agents and owned agents # filter permitted agents and owned agents
res = UserCanvasService.get_all_agents_by_tenant_ids(tenant_ids, usr.id) res = UserCanvasService.get_all_agents_by_tenant_ids(tenant_ids, usr.id)
return [{ return [{"title": r["title"], "permission": r["permission"], "canvas_category": r["canvas_category"].split("_")[0], "avatar": r["avatar"]} for r in res]
'title': r['title'],
'permission': r['permission'], @staticmethod
'canvas_category': r['canvas_category'].split('_')[0], def get_user_tenants(email: str) -> list[dict[str, Any]]:
'avatar': r['avatar'] users: list[Any] = UserService.query_user_by_email(email)
} for r in res] if not users:
raise UserNotFoundError(email)
user: Any = users[0]
tenants: list[dict[str, Any]] = UserTenantService.get_tenants_by_user_id(user.id)
return tenants
class ServiceMgr:
    @staticmethod
    def get_all_services():
        doc_engine = os.getenv("DOC_ENGINE", "elasticsearch")
        result = []
        configs = SERVICE_CONFIGS.configs
        for service_id, config in enumerate(configs):
            config_dict = config.to_dict()
            if config_dict["service_type"] == "retrieval":
                if config_dict["extra"]["retrieval_type"] != doc_engine:
                    continue
            try:
                service_detail = ServiceMgr.get_service_details(service_id)
                if "status" in service_detail:
                    config_dict["status"] = service_detail["status"]
                else:
                    config_dict["status"] = "timeout"
            except Exception as e:
                logging.warning(f"Can't get service details, error: {e}")
                config_dict["status"] = "timeout"
            if not config_dict["host"]:
                config_dict["host"] = "-"
            if not config_dict["port"]:
                config_dict["port"] = "-"
            result.append(config_dict)
        return result
@@ -211,11 +305,18 @@ class ServiceMgr:
            raise AdminException(f"invalid service_index: {service_idx}")
        service_config = configs[service_idx]

        # exclude retrieval service if retrieval_type is not matched
        doc_engine = os.getenv("DOC_ENGINE", "elasticsearch")
        if service_config.service_type == "retrieval":
            if service_config.retrieval_type != doc_engine:
                raise AdminException(f"invalid service_index: {service_idx}")

        service_info = {"name": service_config.name, "detail_func_name": service_config.detail_func_name}
        detail_func = getattr(health_utils, service_info.get("detail_func_name"))
        res = detail_func()
        res.update({"service_name": service_info.get("name")})
        return res
    @staticmethod

@@ -225,3 +326,84 @@ class ServiceMgr:
    @staticmethod
    def restart_service(service_id: int):
        raise AdminException("restart_service: not implemented")
class SettingsMgr:
    @staticmethod
    def get_all():
        settings = SystemSettingsService.get_all()
        result = []
        for setting in settings:
            result.append(
                {
                    "name": setting.name,
                    "source": setting.source,
                    "data_type": setting.data_type,
                    "value": setting.value,
                }
            )
        return result

    @staticmethod
    def get_by_name(name: str):
        settings = SystemSettingsService.get_by_name(name)
        if len(settings) == 0:
            raise AdminException(f"Can't get setting: {name}")
        result = []
        for setting in settings:
            result.append(
                {
                    "name": setting.name,
                    "source": setting.source,
                    "data_type": setting.data_type,
                    "value": setting.value,
                }
            )
        return result

    @staticmethod
    def update_by_name(name: str, value: str):
        settings = SystemSettingsService.get_by_name(name)
        if len(settings) == 1:
            setting = settings[0]
            setting.value = value
            setting_dict = setting.to_dict()
            SystemSettingsService.update_by_name(name, setting_dict)
        elif len(settings) > 1:
            raise AdminException(f"Can't update more than 1 setting: {name}")
        else:
            raise AdminException(f"No setting: {name}")


class ConfigMgr:
    @staticmethod
    def get_all():
        result = []
        configs = SERVICE_CONFIGS.configs
        for config in configs:
            config_dict = config.to_dict()
            result.append(config_dict)
        return result


class EnvironmentsMgr:
    @staticmethod
    def get_all():
        result = []
        env_kv = {"env": "DOC_ENGINE", "value": os.getenv("DOC_ENGINE")}
        result.append(env_kv)
        env_kv = {"env": "DEFAULT_SUPERUSER_EMAIL", "value": os.getenv("DEFAULT_SUPERUSER_EMAIL", "admin@ragflow.io")}
        result.append(env_kv)
        env_kv = {"env": "DB_TYPE", "value": os.getenv("DB_TYPE", "mysql")}
        result.append(env_kv)
        env_kv = {"env": "DEVICE", "value": os.getenv("DEVICE", "cpu")}
        result.append(env_kv)
        env_kv = {"env": "STORAGE_IMPL", "value": os.getenv("STORAGE_IMPL", "MINIO")}
        result.append(env_kv)
        return result
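EnvironmentsMgr.get_all repeats the same two lines per variable. A table-driven equivalent, as a sketch rather than the PR's code, keeps the list of exported variables in one place (os is already imported in this module):

_EXPORTED_ENVS = [
    ("DOC_ENGINE", None),
    ("DEFAULT_SUPERUSER_EMAIL", "admin@ragflow.io"),
    ("DB_TYPE", "mysql"),
    ("DEVICE", "cpu"),
    ("STORAGE_IMPL", "MINIO"),
]


def get_all_environments():
    # Same output shape as EnvironmentsMgr.get_all above.
    return [{"env": name, "value": os.getenv(name, default)} for name, default in _EXPORTED_ENVS]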

View File

@@ -27,6 +27,10 @@ import pandas as pd
from agent import settings
from common.connection_utils import timeout
from common.misc_utils import thread_pool_exec

_FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
_DEPRECATED_PARAMS = "_deprecated_params"
_USER_FEEDED_PARAMS = "_user_feeded_params"
@@ -379,6 +383,7 @@ class ComponentBase(ABC):
    def __init__(self, canvas, id, param: ComponentParamBase):
        from agent.canvas import Graph  # Local import to avoid cyclic dependency

        assert isinstance(canvas, Graph), "canvas must be an instance of Canvas"
        self._canvas = canvas
        self._id = id
@@ -430,7 +435,7 @@ class ComponentBase(ABC):
            elif asyncio.iscoroutinefunction(self._invoke):
                await self._invoke(**kwargs)
            else:
                await thread_pool_exec(self._invoke, **kwargs)
        except Exception as e:
            if self.get_exception_default_value():
                self.set_exception_default_value()
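common/misc_utils.thread_pool_exec itself is not shown in this diff. One plausible shape for it, assuming the point of the swap is to route blocking calls through one shared, size-capped pool rather than the event loop's default executor that asyncio.to_thread uses:

import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial

_EXECUTOR = ThreadPoolExecutor(max_workers=32)  # assumed bound; the real value isn't in the diff


async def thread_pool_exec(fn, *args, **kwargs):
    # Run a blocking callable on the shared pool and await its result.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(_EXECUTOR, partial(fn, *args, **kwargs))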

View File

@@ -97,6 +97,13 @@ Here's description of each category:
class Categorize(LLM, ABC):
    component_name = "Categorize"

    def get_input_elements(self) -> dict[str, dict]:
        query_key = self._param.query or "sys.query"
        elements = self.get_input_elements_from_text(f"{{{query_key}}}")
        if not elements:
            logging.warning(f"[Categorize] input element not detected for query key: {query_key}")
        return elements
    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
    async def _invoke_async(self, **kwargs):
        if self.check_if_canceled("Categorize processing"):

@@ -105,12 +112,15 @@ class Categorize(LLM, ABC):
        msg = self._canvas.get_history(self._param.message_history_window_size)
        if not msg:
            msg = [{"role": "user", "content": ""}]
        query_key = self._param.query or "sys.query"
        if query_key in kwargs:
            query_value = kwargs[query_key]
        else:
            query_value = self._canvas.get_variable_value(query_key)
        if query_value is None:
            query_value = ""
        msg[-1]["content"] = query_value
        self.set_input_value(query_key, msg[-1]["content"])

        self._param.update_prompt()
        chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)

View File

@@ -33,7 +33,7 @@ from common.connection_utils import timeout
from common.misc_utils import get_uuid
from common import settings
from api.db.joint_services.memory_message_service import queue_save_to_memory_task


class MessageParam(ComponentParamBase):

@@ -437,17 +437,4 @@ class Message(ComponentBase):
            "user_input": self._canvas.get_sys_query(),
            "agent_response": content
        }
        return await queue_save_to_memory_task(self._param.memory_ids, message_dict)
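queue_save_to_memory_task's body is not part of this diff. Judging from the per-memory loop it replaces, it plausibly fans the writes out concurrently and aggregates failures, along these lines; this is an assumption, not the PR's code.

import asyncio
import logging

from api.db.joint_services.memory_message_service import save_to_memory  # the coroutine the old loop called


async def queue_save_to_memory_task_sketch(memory_ids, message_dict):
    # Fan out one save per memory and gather (success, msg) pairs.
    results = await asyncio.gather(*(save_to_memory(mid, message_dict) for mid in memory_ids))
    failures = [
        f"Add to memory {mid} failed, detail: {msg}"
        for mid, (success, msg) in zip(memory_ids, results)
        if not success
    ]
    if not failures:
        return True, "Successfully added to memories."
    error_text = "Some messages failed to add. " + " ".join(failures)
    logging.error(error_text)
    return False, error_text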

View File

@@ -27,6 +27,10 @@ from common.mcp_tool_call_conn import MCPToolCallSession, ToolCallSession
from timeit import default_timer as timer

from common.misc_utils import thread_pool_exec


class ToolParameter(TypedDict):
    type: str
    description: str
@@ -56,12 +60,12 @@ class LLMToolPluginCallSession(ToolCallSession):
        st = timer()
        tool_obj = self.tools_map[name]
        if isinstance(tool_obj, MCPToolCallSession):
            resp = await thread_pool_exec(tool_obj.tool_call, name, arguments, 60)
        else:
            if hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async):
                resp = await tool_obj.invoke_async(**arguments)
            else:
                resp = await thread_pool_exec(tool_obj.invoke, **arguments)
        self.callback(name, arguments, resp, elapsed_time=timer() - st)
        return resp
@@ -122,6 +126,7 @@ class ToolParamBase(ComponentParamBase):
class ToolBase(ComponentBase):
    def __init__(self, canvas, id, param: ComponentParamBase):
        from agent.canvas import Canvas  # Local import to avoid cyclic dependency

        assert isinstance(canvas, Canvas), "canvas must be an instance of Canvas"
        self._canvas = canvas
        self._id = id
@@ -164,7 +169,7 @@ class ToolBase(ComponentBase):
            elif asyncio.iscoroutinefunction(self._invoke):
                res = await self._invoke(**kwargs)
            else:
                res = await thread_pool_exec(self._invoke, **kwargs)
        except Exception as e:
            self._param.outputs["_ERROR"] = {"value": str(e)}
            logging.exception(e)

View File

@@ -86,6 +86,12 @@ class ExeSQL(ToolBase, ABC):
        def convert_decimals(obj):
            from decimal import Decimal
            import math

            if isinstance(obj, float):
                # Handle NaN and Infinity, which are not valid JSON values
                if math.isnan(obj) or math.isinf(obj):
                    return None
                return obj
            if isinstance(obj, Decimal):
                return float(obj)  # or str(obj)
            elif isinstance(obj, dict):
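The float guard exists because Python's json module emits NaN/Infinity by default, which is not valid JSON and breaks strict parsers downstream:

import json
import math

row = {"price": float("nan"), "qty": math.inf}
print(json.dumps(row))  # '{"price": NaN, "qty": Infinity}' -- not valid JSON
# json.dumps(row, allow_nan=False) would raise ValueError instead.

cleaned = {k: (None if isinstance(v, float) and (math.isnan(v) or math.isinf(v)) else v) for k, v in row.items()}
print(json.dumps(cleaned))  # '{"price": null, "qty": null}'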

View File

@@ -174,7 +174,7 @@ class Retrieval(ToolBase, ABC):
        if kbs:
            query = re.sub(r"^user[:\s]*", "", query, flags=re.IGNORECASE)
            kbinfos = await settings.retriever.retrieval(
                query,
                embd_mdl,
                [kb.tenant_id for kb in kbs],

@@ -193,7 +193,7 @@ class Retrieval(ToolBase, ABC):
            if self._param.toc_enhance:
                chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT)
                cks = await settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs],
                                                                chat_mdl, self._param.top_n)
            if self.check_if_canceled("Retrieval processing"):
                return

@@ -202,7 +202,7 @@ class Retrieval(ToolBase, ABC):
                kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"],
                                                                             [kb.tenant_id for kb in kbs])
            if self._param.use_kg:
                ck = await settings.kg_retriever.retrieval(query,
                                                           [kb.tenant_id for kb in kbs],
                                                           kb_ids,
                                                           embd_mdl,

@@ -215,7 +215,7 @@ class Retrieval(ToolBase, ABC):
            kbinfos = {"chunks": [], "doc_aggs": []}
        if self._param.use_kg and kbs:
            ck = await settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl,
                                                       LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
        if self.check_if_canceled("Retrieval processing"):
            return

View File

@@ -1 +0,0 @@
from .deep_research import DeepResearcher as DeepResearcher

View File

@ -1,238 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import re
from functools import partial
from agentic_reasoning.prompts import BEGIN_SEARCH_QUERY, BEGIN_SEARCH_RESULT, END_SEARCH_RESULT, MAX_SEARCH_LIMIT, \
END_SEARCH_QUERY, REASON_PROMPT, RELEVANT_EXTRACTION_PROMPT
from api.db.services.llm_service import LLMBundle
from rag.nlp import extract_between
from rag.prompts import kb_prompt
from rag.utils.tavily_conn import Tavily
class DeepResearcher:
def __init__(self,
chat_mdl: LLMBundle,
prompt_config: dict,
kb_retrieve: partial = None,
kg_retrieve: partial = None
):
self.chat_mdl = chat_mdl
self.prompt_config = prompt_config
self._kb_retrieve = kb_retrieve
self._kg_retrieve = kg_retrieve
def _remove_tags(text: str, start_tag: str, end_tag: str) -> str:
"""General Tag Removal Method"""
pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag)
return re.sub(pattern, "", text)
@staticmethod
def _remove_query_tags(text: str) -> str:
"""Remove Query Tags"""
return DeepResearcher._remove_tags(text, BEGIN_SEARCH_QUERY, END_SEARCH_QUERY)
@staticmethod
def _remove_result_tags(text: str) -> str:
"""Remove Result Tags"""
return DeepResearcher._remove_tags(text, BEGIN_SEARCH_RESULT, END_SEARCH_RESULT)
async def _generate_reasoning(self, msg_history):
"""Generate reasoning steps"""
query_think = ""
if msg_history[-1]["role"] != "user":
msg_history.append({"role": "user", "content": "Continues reasoning with the new information.\n"})
else:
msg_history[-1]["content"] += "\n\nContinues reasoning with the new information.\n"
async for ans in self.chat_mdl.async_chat_streamly(REASON_PROMPT, msg_history, {"temperature": 0.7}):
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
if not ans:
continue
query_think = ans
yield query_think
query_think = ""
yield query_think
def _extract_search_queries(self, query_think, question, step_index):
"""Extract search queries from thinking"""
queries = extract_between(query_think, BEGIN_SEARCH_QUERY, END_SEARCH_QUERY)
if not queries and step_index == 0:
# If this is the first step and no queries are found, use the original question as the query
queries = [question]
return queries
def _truncate_previous_reasoning(self, all_reasoning_steps):
"""Truncate previous reasoning steps to maintain a reasonable length"""
truncated_prev_reasoning = ""
for i, step in enumerate(all_reasoning_steps):
truncated_prev_reasoning += f"Step {i + 1}: {step}\n\n"
prev_steps = truncated_prev_reasoning.split('\n\n')
if len(prev_steps) <= 5:
truncated_prev_reasoning = '\n\n'.join(prev_steps)
else:
truncated_prev_reasoning = ''
for i, step in enumerate(prev_steps):
if i == 0 or i >= len(prev_steps) - 4 or BEGIN_SEARCH_QUERY in step or BEGIN_SEARCH_RESULT in step:
truncated_prev_reasoning += step + '\n\n'
else:
if truncated_prev_reasoning[-len('\n\n...\n\n'):] != '\n\n...\n\n':
truncated_prev_reasoning += '...\n\n'
return truncated_prev_reasoning.strip('\n')
def _retrieve_information(self, search_query):
"""Retrieve information from different sources"""
# 1. Knowledge base retrieval
kbinfos = []
try:
kbinfos = self._kb_retrieve(question=search_query) if self._kb_retrieve else {"chunks": [], "doc_aggs": []}
except Exception as e:
logging.error(f"Knowledge base retrieval error: {e}")
# 2. Web retrieval (if Tavily API is configured)
try:
if self.prompt_config.get("tavily_api_key"):
tav = Tavily(self.prompt_config["tavily_api_key"])
tav_res = tav.retrieve_chunks(search_query)
kbinfos["chunks"].extend(tav_res["chunks"])
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
except Exception as e:
logging.error(f"Web retrieval error: {e}")
# 3. Knowledge graph retrieval (if configured)
try:
if self.prompt_config.get("use_kg") and self._kg_retrieve:
ck = self._kg_retrieve(question=search_query)
if ck["content_with_weight"]:
kbinfos["chunks"].insert(0, ck)
except Exception as e:
logging.error(f"Knowledge graph retrieval error: {e}")
return kbinfos
def _update_chunk_info(self, chunk_info, kbinfos):
"""Update chunk information for citations"""
if not chunk_info["chunks"]:
# If this is the first retrieval, use the retrieval results directly
for k in chunk_info.keys():
chunk_info[k] = kbinfos[k]
else:
# Merge newly retrieved information, avoiding duplicates
cids = [c["chunk_id"] for c in chunk_info["chunks"]]
for c in kbinfos["chunks"]:
if c["chunk_id"] not in cids:
chunk_info["chunks"].append(c)
dids = [d["doc_id"] for d in chunk_info["doc_aggs"]]
for d in kbinfos["doc_aggs"]:
if d["doc_id"] not in dids:
chunk_info["doc_aggs"].append(d)
async def _extract_relevant_info(self, truncated_prev_reasoning, search_query, kbinfos):
"""Extract and summarize relevant information"""
summary_think = ""
async for ans in self.chat_mdl.async_chat_streamly(
RELEVANT_EXTRACTION_PROMPT.format(
prev_reasoning=truncated_prev_reasoning,
search_query=search_query,
document="\n".join(kb_prompt(kbinfos, 4096))
),
[{"role": "user",
"content": f'Now you should analyze each web page and find helpful information based on the current search query "{search_query}" and previous reasoning steps.'}],
{"temperature": 0.7}):
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
if not ans:
continue
summary_think = ans
yield summary_think
summary_think = ""
yield summary_think
async def thinking(self, chunk_info: dict, question: str):
executed_search_queries = []
msg_history = [{"role": "user", "content": f'Question:\"{question}\"\n'}]
all_reasoning_steps = []
think = "<think>"
for step_index in range(MAX_SEARCH_LIMIT + 1):
# Check if the maximum search limit has been reached
if step_index == MAX_SEARCH_LIMIT - 1:
summary_think = f"\n{BEGIN_SEARCH_RESULT}\nThe maximum search limit is exceeded. You are not allowed to search.\n{END_SEARCH_RESULT}\n"
yield {"answer": think + summary_think + "</think>", "reference": {}, "audio_binary": None}
all_reasoning_steps.append(summary_think)
msg_history.append({"role": "assistant", "content": summary_think})
break
# Step 1: Generate reasoning
query_think = ""
async for ans in self._generate_reasoning(msg_history):
query_think = ans
yield {"answer": think + self._remove_query_tags(query_think) + "</think>", "reference": {}, "audio_binary": None}
think += self._remove_query_tags(query_think)
all_reasoning_steps.append(query_think)
# Step 2: Extract search queries
queries = self._extract_search_queries(query_think, question, step_index)
if not queries and step_index > 0:
# If not the first step and no queries, end the search process
break
# Process each search query
for search_query in queries:
logging.info(f"[THINK]Query: {step_index}. {search_query}")
msg_history.append({"role": "assistant", "content": search_query})
think += f"\n\n> {step_index + 1}. {search_query}\n\n"
yield {"answer": think + "</think>", "reference": {}, "audio_binary": None}
# Check if the query has already been executed
if search_query in executed_search_queries:
summary_think = f"\n{BEGIN_SEARCH_RESULT}\nYou have searched this query. Please refer to previous results.\n{END_SEARCH_RESULT}\n"
yield {"answer": think + summary_think + "</think>", "reference": {}, "audio_binary": None}
all_reasoning_steps.append(summary_think)
msg_history.append({"role": "user", "content": summary_think})
think += summary_think
continue
executed_search_queries.append(search_query)
# Step 3: Truncate previous reasoning steps
truncated_prev_reasoning = self._truncate_previous_reasoning(all_reasoning_steps)
# Step 4: Retrieve information
kbinfos = self._retrieve_information(search_query)
# Step 5: Update chunk information
self._update_chunk_info(chunk_info, kbinfos)
# Step 6: Extract relevant information
think += "\n\n"
summary_think = ""
async for ans in self._extract_relevant_info(truncated_prev_reasoning, search_query, kbinfos):
summary_think = ans
yield {"answer": think + self._remove_result_tags(summary_think) + "</think>", "reference": {}, "audio_binary": None}
all_reasoning_steps.append(summary_think)
msg_history.append(
{"role": "user", "content": f"\n\n{BEGIN_SEARCH_RESULT}{summary_think}{END_SEARCH_RESULT}\n\n"})
think += self._remove_result_tags(summary_think)
logging.info(f"[THINK]Summary: {step_index}. {summary_think}")
yield think + "</think>"

View File

@@ -1,147 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
BEGIN_SEARCH_QUERY = "<|begin_search_query|>"
END_SEARCH_QUERY = "<|end_search_query|>"
BEGIN_SEARCH_RESULT = "<|begin_search_result|>"
END_SEARCH_RESULT = "<|end_search_result|>"
MAX_SEARCH_LIMIT = 6

REASON_PROMPT = f"""You are an advanced reasoning agent. Your goal is to answer the user's question by breaking it down into a series of verifiable steps.
You have access to a powerful search tool to find information.
**Your Task:**
1. Analyze the user's question.
2. If you need information, issue a search query to find a specific fact.
3. Review the search results.
4. Repeat the search process until you have all the facts needed to answer the question.
5. Once you have gathered sufficient information, synthesize the facts and provide the final answer directly.
**Tool Usage:**
- To search, you MUST write your query between the special tokens: {BEGIN_SEARCH_QUERY}your query{END_SEARCH_QUERY}.
- The system will provide results between {BEGIN_SEARCH_RESULT}search results{END_SEARCH_RESULT}.
- You have a maximum of {MAX_SEARCH_LIMIT} search attempts.
---
**Example 1: Multi-hop Question**
**Question:** "Are both the directors of Jaws and Casino Royale from the same country?"
**Your Thought Process & Actions:**
First, I need to identify the director of Jaws.
{BEGIN_SEARCH_QUERY}who is the director of Jaws?{END_SEARCH_QUERY}
[System returns search results]
{BEGIN_SEARCH_RESULT}
Jaws is a 1975 American thriller film directed by Steven Spielberg.
{END_SEARCH_RESULT}
Okay, the director of Jaws is Steven Spielberg. Now I need to find out his nationality.
{BEGIN_SEARCH_QUERY}where is Steven Spielberg from?{END_SEARCH_QUERY}
[System returns search results]
{BEGIN_SEARCH_RESULT}
Steven Allan Spielberg is an American filmmaker. Born in Cincinnati, Ohio...
{END_SEARCH_RESULT}
So, Steven Spielberg is from the USA. Next, I need to find the director of Casino Royale.
{BEGIN_SEARCH_QUERY}who is the director of Casino Royale 2006?{END_SEARCH_QUERY}
[System returns search results]
{BEGIN_SEARCH_RESULT}
Casino Royale is a 2006 spy film directed by Martin Campbell.
{END_SEARCH_RESULT}
The director of Casino Royale is Martin Campbell. Now I need his nationality.
{BEGIN_SEARCH_QUERY}where is Martin Campbell from?{END_SEARCH_QUERY}
[System returns search results]
{BEGIN_SEARCH_RESULT}
Martin Campbell (born 24 October 1943) is a New Zealand film and television director.
{END_SEARCH_RESULT}
I have all the information. Steven Spielberg is from the USA, and Martin Campbell is from New Zealand. They are not from the same country.
Final Answer: No, the directors of Jaws and Casino Royale are not from the same country. Steven Spielberg is from the USA, and Martin Campbell is from New Zealand.
---
**Example 2: Simple Fact Retrieval**
**Question:** "When was the founder of craigslist born?"
**Your Thought Process & Actions:**
First, I need to know who founded craigslist.
{BEGIN_SEARCH_QUERY}who founded craigslist?{END_SEARCH_QUERY}
[System returns search results]
{BEGIN_SEARCH_RESULT}
Craigslist was founded in 1995 by Craig Newmark.
{END_SEARCH_RESULT}
The founder is Craig Newmark. Now I need his birth date.
{BEGIN_SEARCH_QUERY}when was Craig Newmark born?{END_SEARCH_QUERY}
[System returns search results]
{BEGIN_SEARCH_RESULT}
Craig Newmark was born on December 6, 1952.
{END_SEARCH_RESULT}
I have found the answer.
Final Answer: The founder of craigslist, Craig Newmark, was born on December 6, 1952.
---
**Important Rules:**
- **One Fact at a Time:** Decompose the problem and issue one search query at a time to find a single, specific piece of information.
- **Be Precise:** Formulate clear and precise search queries. If a search fails, rephrase it.
- **Synthesize at the End:** Do not provide the final answer until you have completed all necessary searches.
- **Language Consistency:** Your search queries should be in the same language as the user's question.
Now, begin your work. Please answer the following question by thinking step-by-step.
"""
RELEVANT_EXTRACTION_PROMPT = """You are a highly efficient information extraction module. Your sole purpose is to extract the single most relevant piece of information from the provided `Searched Web Pages` that directly answers the `Current Search Query`.
**Your Task:**
1. Read the `Current Search Query` to understand what specific information is needed.
2. Scan the `Searched Web Pages` to find the answer to that query.
3. Extract only the essential, factual information that answers the query. Be concise.
**Context (For Your Information Only):**
The `Previous Reasoning Steps` are provided to give you context on the overall goal, but your primary focus MUST be on answering the `Current Search Query`. Do not use information from the previous steps in your output.
**Output Format:**
Your response must follow one of two formats precisely.
1. **If a direct and relevant answer is found:**
- Start your response immediately with `Final Information`.
- Provide only the extracted fact(s). Do not add any extra conversational text.
*Example:*
`Current Search Query`: Where is Martin Campbell from?
`Searched Web Pages`: [Long article snippet about Martin Campbell's career, which includes the sentence "Martin Campbell (born 24 October 1943) is a New Zealand film and television director..."]
*Your Output:*
Final Information
Martin Campbell is a New Zealand film and television director.
2. **If no relevant answer that directly addresses the query is found in the web pages:**
- Start your response immediately with `Final Information`.
- Write the exact phrase: `No helpful information found.`
---
**BEGIN TASK**
**Inputs:**
- **Previous Reasoning Steps:**
{prev_reasoning}
- **Current Search Query:**
{search_query}
- **Searched Web Pages:**
{document}
"""

View File

@@ -16,21 +16,23 @@
import logging
import os
import sys
import time
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path

from quart import Blueprint, Quart, request, g, current_app, session, jsonify
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from quart_cors import cors

from common.constants import StatusEnum, RetCode
from api.db.db_models import close_connection, APIToken
from api.db.services import UserService
from api.utils.json_encode import CustomJSONEncoder
from api.utils import commands

from quart_auth import Unauthorized as QuartAuthUnauthorized
from werkzeug.exceptions import Unauthorized as WerkzeugUnauthorized
from quart_schema import QuartSchema

from common import settings
from api.utils.api_utils import server_error_response, get_json_result
from api.constants import API_VERSION
from common.misc_utils import get_uuid
@@ -38,40 +40,27 @@ settings.init_settings()
__all__ = ["app"]

UNAUTHORIZED_MESSAGE = "<Unauthorized '401: Unauthorized'>"


def _unauthorized_message(error):
    if error is None:
        return UNAUTHORIZED_MESSAGE
    try:
        msg = repr(error)
    except Exception:
        return UNAUTHORIZED_MESSAGE
    if msg == UNAUTHORIZED_MESSAGE:
        return msg
    if "Unauthorized" in msg and "401" in msg:
        return msg
    return UNAUTHORIZED_MESSAGE
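By inspection, the normalizer above only lets through reprs that already look like a 401 and collapses everything else to the fixed sentinel. A quick check; the Fake401 class is invented for illustration:

print(_unauthorized_message(None))                # <Unauthorized '401: Unauthorized'>
print(_unauthorized_message(ValueError("boom")))  # repr lacks "401" -> sentinel


class Fake401(Exception):
    def __repr__(self):
        return "Unauthorized('401: token expired')"


print(_unauthorized_message(Fake401()))           # passed through unchanged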
app = Quart(__name__)
app = cors(app, allow_origin="*")

# openapi supported
QuartSchema(app)

app.url_map.strict_slashes = False
app.json_encoder = CustomJSONEncoder
@@ -125,18 +114,28 @@ def _load_user():
        user = UserService.query(
            access_token=access_token, status=StatusEnum.VALID.value
        )
        if user:
            if not user[0].access_token or not user[0].access_token.strip():
                logging.warning(f"User {user[0].email} has empty access_token in database")
                return None
            g.user = user[0]
            return user[0]
    except Exception as e_auth:
        logging.warning(f"load_user got exception {e_auth}")

    try:
        authorization = request.headers.get("Authorization")
        if len(authorization.split()) == 2:
            objs = APIToken.query(token=authorization.split()[1])
            if objs:
                user = UserService.query(id=objs[0].tenant_id, status=StatusEnum.VALID.value)
                if user:
                    if not user[0].access_token or not user[0].access_token.strip():
                        logging.warning(f"User {user[0].email} has empty access_token in database")
                        return None
                    g.user = user[0]
                    return user[0]
    except Exception as e_api_token:
        logging.warning(f"load_user got exception {e_api_token}")
current_user = LocalProxy(_load_user)
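The loader now tries the session access_token first and only then falls back to a raw API token taken from the second word of the Authorization header. Client-side, the fallback path looks roughly like this; the URL and token format are illustrative, not from the diff:

import requests  # illustrative sketch

resp = requests.get(
    "http://localhost:9380/v1/...",  # any login_required endpoint
    # _load_user splits on whitespace and takes the second token, so any
    # two-word scheme works; "Bearer" is the conventional choice.
    headers={"Authorization": "Bearer <api-token>"},
)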
@@ -164,10 +163,18 @@ def login_required(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
    @wraps(func)
    async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        timing_enabled = os.getenv("RAGFLOW_API_TIMING")
        t_start = time.perf_counter() if timing_enabled else None
        user = current_user
        if timing_enabled:
            logging.info(
                "api_timing login_required auth_ms=%.2f path=%s",
                (time.perf_counter() - t_start) * 1000,
                request.path,
            )
        if not user:  # or not session.get("_user_id"):
            raise QuartAuthUnauthorized()
        return await current_app.ensure_async(func)(*args, **kwargs)

    return wrapper
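Note the timing hook measures only the resolution of the current_user LocalProxy (which is what triggers _load_user), and any non-empty value of the env var enables it; for example:

import os

os.environ["RAGFLOW_API_TIMING"] = "1"  # any non-empty string turns the log line on
# Expected log shape (numbers illustrative):
#   api_timing login_required auth_ms=3.42 path=/v1/chunk/list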
@@ -277,14 +284,34 @@ client_urls_prefix = [

@app.errorhandler(404)
async def not_found(error):
    logging.error(f"The requested URL {request.path} was not found")
    message = f"Not Found: {request.path}"
    response = {
        "code": RetCode.NOT_FOUND,
        "message": message,
        "data": None,
        "error": "Not Found",
    }
    return jsonify(response), RetCode.NOT_FOUND


@app.errorhandler(401)
async def unauthorized(error):
    logging.warning("Unauthorized request")
    return get_json_result(code=RetCode.UNAUTHORIZED, message=_unauthorized_message(error)), RetCode.UNAUTHORIZED


@app.errorhandler(QuartAuthUnauthorized)
async def unauthorized_quart_auth(error):
    logging.warning("Unauthorized request (quart_auth)")
    return get_json_result(code=RetCode.UNAUTHORIZED, message=repr(error)), RetCode.UNAUTHORIZED


@app.errorhandler(WerkzeugUnauthorized)
async def unauthorized_werkzeug(error):
    logging.warning("Unauthorized request (werkzeug)")
    return get_json_result(code=RetCode.UNAUTHORIZED, message=_unauthorized_message(error)), RetCode.UNAUTHORIZED
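With the new handler, an unknown path now yields a structured JSON body. Assuming RetCode.NOT_FOUND is the integer 404 (the handler uses it both as the body code and the HTTP status), a client sees roughly:

# Shape of the 404 response body (values illustrative):
expected_404_body = {
    "code": 404,
    "message": "Not Found: /v1/nonexistent",
    "data": None,
    "error": "Not Found",
}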
@app.teardown_request
def _db_close(exception):
    if exception:

View File

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import inspect
import json
import logging

@@ -29,9 +28,14 @@ from api.db.services.task_service import queue_dataflow, CANVAS_DEBUG_DOC_ID, Ta
from api.db.services.user_service import TenantService
from api.db.services.user_canvas_version import UserCanvasVersionService
from common.constants import RetCode
from common.misc_utils import get_uuid, thread_pool_exec
from api.utils.api_utils import (
    get_json_result,
    server_error_response,
    validate_request,
    get_data_error_result,
    get_request_json,
)
from agent.canvas import Canvas
from peewee import MySQLDatabase, PostgresqlDatabase
from api.db.db_models import APIToken, Task
@@ -132,12 +136,12 @@ async def run():
    files = req.get("files", [])
    inputs = req.get("inputs", {})
    user_id = req.get("user_id", current_user.id)
    if not await thread_pool_exec(UserCanvasService.accessible, req["id"], current_user.id):
        return get_json_result(
            data=False, message='Only owner of canvas authorized for this operation.',
            code=RetCode.OPERATING_ERROR)

    e, cvs = await thread_pool_exec(UserCanvasService.get_by_id, req["id"])
    if not e:
        return get_data_error_result(message="canvas not found.")

@@ -147,7 +151,7 @@ async def run():
    if cvs.canvas_category == CanvasCategory.DataFlow:
        task_id = get_uuid()
        Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
        ok, error_message = await thread_pool_exec(queue_dataflow, user_id, req["id"], task_id, CANVAS_DEBUG_DOC_ID, files[0], 0)
        if not ok:
            return get_data_error_result(message=error_message)
        return get_json_result(data={"message_id": task_id})
@@ -540,6 +544,7 @@ def sessions(canvas_id):
@login_required
def prompts():
    from rag.prompts.generator import ANALYZE_TASK_SYSTEM, ANALYZE_TASK_USER, NEXT_STEP, REFLECT, CITATION_PROMPT_TEMPLATE

    return get_json_result(data={
        "task_analysis": ANALYZE_TASK_SYSTEM + "\n\n" + ANALYZE_TASK_USER,
        "plan_generation": NEXT_STEP,

View File

@@ -13,11 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import datetime
import json
import logging
import re

import xxhash
from quart import request

@@ -27,8 +27,14 @@ from api.db.services.llm_service import LLMBundle
from common.metadata_utils import apply_meta_data_filter
from api.db.services.search_service import SearchService
from api.db.services.user_service import UserTenantService
from api.utils.api_utils import (
    get_data_error_result,
    get_json_result,
    server_error_response,
    validate_request,
    get_request_json,
)
from common.misc_utils import thread_pool_exec
from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search

@@ -38,7 +44,6 @@ from common.constants import RetCode, LLMType, ParserType, PAGERANK_FLD
from common import settings
from api.apps import login_required, current_user
@manager.route('/list', methods=['POST'])  # noqa: F821
@login_required
@validate_request("doc_id")

@@ -61,7 +66,7 @@ async def list_chunk():
        }
        if "available_int" in req:
            query["available_int"] = int(req["available_int"])
        sres = await settings.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
        res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
        for id in sres.ids:
            d = {
@@ -126,10 +131,15 @@ def get():
@validate_request("doc_id", "chunk_id", "content_with_weight")
async def set():
    req = await get_request_json()
    content_with_weight = req["content_with_weight"]
    if not isinstance(content_with_weight, (str, bytes)):
        raise TypeError("expected string or bytes-like object")
    if isinstance(content_with_weight, bytes):
        content_with_weight = content_with_weight.decode("utf-8", errors="ignore")
    d = {
        "id": req["chunk_id"],
        "content_with_weight": content_with_weight}
    d["content_ltks"] = rag_tokenizer.tokenize(content_with_weight)
    d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
    if "important_kwd" in req:
        if not isinstance(req["important_kwd"], list):

@@ -171,20 +181,21 @@ async def set():
            _d = beAdoc(d, q, a, not any(
                [rag_tokenizer.is_chinese(t) for t in q + a]))

            v, c = embd_mdl.encode([doc.name, content_with_weight if not _d.get("question_kwd") else "\n".join(_d["question_kwd"])])
            v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
            _d["q_%d_vec" % len(v)] = v.tolist()
            settings.docStoreConn.update({"id": req["chunk_id"]}, _d, search.index_name(tenant_id), doc.kb_id)

            # update image
            image_base64 = req.get("image_base64", None)
            img_id = req.get("img_id", "")
            if image_base64 and img_id and "-" in img_id:
                bkt, name = img_id.split("-", 1)
                image_binary = base64.b64decode(image_base64)
                settings.STORAGE_IMPL.put(bkt, name, image_binary)

            return get_json_result(data=True)

        return await thread_pool_exec(_set_sync)
    except Exception as e:
        return server_error_response(e)
@@ -207,7 +218,7 @@ async def switch():
                return get_data_error_result(message="Index updating failure")
            return get_json_result(data=True)

        return await thread_pool_exec(_switch_sync)
    except Exception as e:
        return server_error_response(e)
@@ -222,19 +233,34 @@ async def rm():
             e, doc = DocumentService.get_by_id(req["doc_id"])
             if not e:
                 return get_data_error_result(message="Document not found!")
-            if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
-                                                search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
-                                                doc.kb_id):
+            condition = {"id": req["chunk_ids"], "doc_id": req["doc_id"]}
+            try:
+                deleted_count = settings.docStoreConn.delete(condition,
+                                                             search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
+                                                             doc.kb_id)
+            except Exception:
                 return get_data_error_result(message="Chunk deleting failure")
             deleted_chunk_ids = req["chunk_ids"]
-            chunk_number = len(deleted_chunk_ids)
+            if isinstance(deleted_chunk_ids, list):
+                unique_chunk_ids = list(dict.fromkeys(deleted_chunk_ids))
+                has_ids = len(unique_chunk_ids) > 0
+            else:
+                unique_chunk_ids = [deleted_chunk_ids]
+                has_ids = deleted_chunk_ids not in (None, "")
+            if has_ids and deleted_count == 0:
+                return get_data_error_result(message="Index updating failure")
+            if deleted_count > 0 and deleted_count < len(unique_chunk_ids):
+                deleted_count += settings.docStoreConn.delete({"doc_id": req["doc_id"]},
+                                                              search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
+                                                              doc.kb_id)
+            chunk_number = deleted_count
             DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
             for cid in deleted_chunk_ids:
                 if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
                     settings.STORAGE_IMPL.rm(doc.kb_id, cid)
             return get_json_result(data=True)
-        return await asyncio.to_thread(_rm_sync)
+        return await thread_pool_exec(_rm_sync)
     except Exception as e:
         return server_error_response(e)
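The rewritten `rm()` trusts the doc store's reported delete count instead of the requested list length, after deduplicating the requested IDs. A standalone sketch of just that accounting, with names illustrative:

```python
# Sketch of the delete-count accounting in the hunk above.
def verify_deletion(chunk_ids, deleted_count: int) -> int:
    if isinstance(chunk_ids, list):
        # dict.fromkeys dedupes while preserving order, unlike set().
        unique_ids = list(dict.fromkeys(chunk_ids))
        has_ids = len(unique_ids) > 0
    else:
        unique_ids = [chunk_ids]
        has_ids = chunk_ids not in (None, "")
    if has_ids and deleted_count == 0:
        raise RuntimeError("Index updating failure")
    return len(unique_ids)

assert verify_deletion(["a", "b", "a"], 2) == 2  # duplicates collapse to two
```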
@@ -244,6 +270,7 @@ async def rm():
 @validate_request("doc_id", "content_with_weight")
 async def create():
     req = await get_request_json()
+    req_id = request.headers.get("X-Request-ID")
     chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest()
     d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
          "content_with_weight": req["content_with_weight"]}
@@ -260,14 +287,23 @@ async def create():
     d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
     if "tag_feas" in req:
         d["tag_feas"] = req["tag_feas"]
if "tag_feas" in req:
d["tag_feas"] = req["tag_feas"]
     try:
+        def _log_response(resp, code, message):
+            logging.info(
+                "chunk_create response req_id=%s status=%s code=%s message=%s",
+                req_id,
+                getattr(resp, "status_code", None),
+                code,
+                message,
+            )
         def _create_sync():
             e, doc = DocumentService.get_by_id(req["doc_id"])
             if not e:
-                return get_data_error_result(message="Document not found!")
+                resp = get_data_error_result(message="Document not found!")
+                _log_response(resp, RetCode.DATA_ERROR, "Document not found!")
+                return resp
             d["kb_id"] = [doc.kb_id]
             d["docnm_kwd"] = doc.name
             d["title_tks"] = rag_tokenizer.tokenize(doc.name)
@@ -275,11 +311,15 @@ async def create():
             tenant_id = DocumentService.get_tenant_id(req["doc_id"])
             if not tenant_id:
-                return get_data_error_result(message="Tenant not found!")
+                resp = get_data_error_result(message="Tenant not found!")
+                _log_response(resp, RetCode.DATA_ERROR, "Tenant not found!")
+                return resp
             e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
             if not e:
-                return get_data_error_result(message="Knowledgebase not found!")
+                resp = get_data_error_result(message="Knowledgebase not found!")
+                _log_response(resp, RetCode.DATA_ERROR, "Knowledgebase not found!")
+                return resp
             if kb.pagerank:
                 d[PAGERANK_FLD] = kb.pagerank
@@ -293,10 +333,13 @@ async def create():
             DocumentService.increment_chunk_num(
                 doc.id, doc.kb_id, c, 1, 0)
-            return get_json_result(data={"chunk_id": chunck_id})
+            resp = get_json_result(data={"chunk_id": chunck_id})
+            _log_response(resp, RetCode.SUCCESS, "success")
+            return resp
-        return await asyncio.to_thread(_create_sync)
+        return await thread_pool_exec(_create_sync)
     except Exception as e:
+        logging.info("chunk_create exception req_id=%s error=%r", req_id, e)
         return server_error_response(e)
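The pattern added here is request correlation: capture the caller-supplied `X-Request-ID` once, then emit one structured log line per outcome (error, success, or exception) so a single request can be traced through the worker thread. A compact standalone form, names illustrative:

```python
# Sketch of the X-Request-ID correlation used in the hunk above.
import logging

def log_response(req_id, resp, code, message):
    # One line per response; grep the req_id to reconstruct a request's path.
    logging.info(
        "chunk_create response req_id=%s status=%s code=%s message=%s",
        req_id, getattr(resp, "status_code", None), code, message,
    )
```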
@@ -372,16 +415,23 @@ async def retrieval_test():
                 _question += await keyword_extraction(chat_mdl, _question)
             labels = label_question(_question, [kb])
-            ranks = settings.retriever.retrieval(_question, embd_mdl, tenant_ids, kb_ids, page, size,
-                                                 float(req.get("similarity_threshold", 0.0)),
-                                                 float(req.get("vector_similarity_weight", 0.3)),
-                                                 top,
-                                                 local_doc_ids, rerank_mdl=rerank_mdl,
-                                                 highlight=req.get("highlight", False),
-                                                 rank_feature=labels
-                                                 )
+            ranks = await settings.retriever.retrieval(
+                _question,
+                embd_mdl,
+                tenant_ids,
+                kb_ids,
+                page,
+                size,
+                float(req.get("similarity_threshold", 0.0)),
+                float(req.get("vector_similarity_weight", 0.3)),
+                doc_ids=local_doc_ids,
+                top=top,
+                rerank_mdl=rerank_mdl,
+                rank_feature=labels
+            )
             if use_kg:
-                ck = settings.kg_retriever.retrieval(_question,
+                ck = await settings.kg_retriever.retrieval(_question,
                                                      tenant_ids,
                                                      kb_ids,
                                                      embd_mdl,
@@ -407,7 +457,7 @@ async def retrieval_test():
 @manager.route('/knowledge_graph', methods=['GET'])  # noqa: F821
 @login_required
-def knowledge_graph():
+async def knowledge_graph():
     doc_id = request.args["doc_id"]
     tenant_id = DocumentService.get_tenant_id(doc_id)
     kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
@@ -415,7 +465,7 @@ def knowledge_graph():
         "doc_ids": [doc_id],
         "knowledge_graph_kwd": ["graph", "mind_map"]
     }
-    sres = settings.retriever.search(req, search.index_name(tenant_id), kb_ids)
+    sres = await settings.retriever.search(req, search.index_name(tenant_id), kb_ids)
     obj = {"graph": {}, "mind_map": {}}
     for id in sres.ids[:2]:
         ty = sres.field[id]["knowledge_graph_kwd"]
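Besides awaiting the now-async retriever, the retrieval call above switches `doc_ids` and `top` to keyword arguments. With that many positional slots, a swapped pair still type-checks but silently changes meaning, which is exactly the class of bug the top_k fix addressed. A toy illustration with hypothetical names:

```python
# Why keyword arguments: a positional mix-up is legal Python but wrong.
def retrieval(question, page, size, top=1024, doc_ids=None):
    return {"top": top, "doc_ids": doc_ids}

bad = retrieval("q", 1, 30, ["doc1"], 1024)          # top gets the doc-id list
good = retrieval("q", 1, 30, doc_ids=["doc1"], top=1024)  # unambiguous
assert bad["top"] == ["doc1"] and good["top"] == 1024
```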

View File

@@ -25,6 +25,7 @@ from api.utils.api_utils import get_data_error_result, get_json_result, get_requ
 from common.misc_utils import get_uuid
 from common.constants import RetCode
 from api.apps import login_required, current_user
+import logging
 @manager.route('/set', methods=['POST'])  # noqa: F821
@@ -42,13 +43,19 @@ async def set_dialog():
     if len(name.encode("utf-8")) > 255:
         return get_data_error_result(message=f"Dialog name length is {len(name)} which is larger than 255")
-    if is_create and DialogService.query(tenant_id=current_user.id, name=name.strip()):
-        name = name.strip()
-        name = duplicate_name(
-            DialogService.query,
-            name=name,
-            tenant_id=current_user.id,
-            status=StatusEnum.VALID.value)
+    name = name.strip()
+    if is_create:
+        # only for chat creating
+        existing_names = {
+            d.name.casefold()
+            for d in DialogService.query(tenant_id=current_user.id, status=StatusEnum.VALID.value)
+            if d.name
+        }
+        if name.casefold() in existing_names:
+            def _name_exists(name: str, **_kwargs) -> bool:
+                return name.casefold() in existing_names
+            name = duplicate_name(_name_exists, name=name)
     description = req.get("description", "A helpful dialog")
     icon = req.get("icon", "")
@@ -63,16 +70,30 @@ async def set_dialog():
     meta_data_filter = req.get("meta_data_filter", {})
     prompt_config = req["prompt_config"]
+    # Set default parameters for datasets with knowledge retrieval
+    # All datasets with {knowledge} in system prompt need "knowledge" parameter to enable retrieval
+    kb_ids = req.get("kb_ids", [])
+    parameters = prompt_config.get("parameters")
+    logging.debug(f"set_dialog: kb_ids={kb_ids}, parameters={parameters}, is_create={not is_create}")
+    # Check if parameters is missing, None, or empty list
+    if kb_ids and not parameters:
+        # Check if system prompt uses {knowledge} placeholder
+        if "{knowledge}" in prompt_config.get("system", ""):
+            # Set default parameters for any dataset with knowledge placeholder
+            prompt_config["parameters"] = [{"key": "knowledge", "optional": False}]
+            logging.debug(f"Set default parameters for datasets with knowledge placeholder: {kb_ids}")
     if not is_create:
-        if not req.get("kb_ids", []) and not prompt_config.get("tavily_api_key") and "{knowledge}" in prompt_config['system']:
+        # only for chat updating
+        if not req.get("kb_ids", []) and not prompt_config.get("tavily_api_key") and "{knowledge}" in prompt_config.get("system", ""):
             return get_data_error_result(message="Please remove `{knowledge}` in system prompt since no dataset / Tavily used here.")
-    for p in prompt_config["parameters"]:
+    for p in prompt_config.get("parameters", []):
         if p["optional"]:
             continue
-        if prompt_config["system"].find("{%s}" % p["key"]) < 0:
+        if prompt_config.get("system", "").find("{%s}" % p["key"]) < 0:
             return get_data_error_result(
                 message="Parameter '{}' is not used".format(p["key"]))
     try:
         e, tenant = TenantService.get_by_id(current_user.id)
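The dialog-name change above swaps a per-name DB query for one prefetched, casefolded set, then feeds `duplicate_name` a predicate over that set. `casefold()` is the stricter Unicode-aware lowering (it also folds characters like ß). A standalone sketch; `duplicate_name`'s real signature lives in `api.db.services` and is assumed here to call the predicate with the candidate name, and the suffixing policy shown is illustrative:

```python
# Sketch of the case-insensitive duplicate check added above.
def duplicate_name(exists, name: str) -> str:
    while exists(name=name):
        name += "_1"  # illustrative suffix policy
    return name

existing = {"sales bot", "faq"}

def _name_exists(name: str, **_kwargs) -> bool:
    return name.casefold() in existing

assert duplicate_name(_name_exists, name="Sales Bot") == "Sales Bot_1"
```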

View File

@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 #
-import asyncio
 import json
 import os.path
 import pathlib
@@ -27,18 +26,19 @@ from api.db import VALID_FILE_TYPES, FileType
 from api.db.db_models import Task
 from api.db.services import duplicate_name
 from api.db.services.document_service import DocumentService, doc_upload_and_parse
-from common.metadata_utils import meta_filter, convert_conditions
+from common.metadata_utils import meta_filter, convert_conditions, turn2jsonschema
 from api.db.services.file2document_service import File2DocumentService
 from api.db.services.file_service import FileService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.task_service import TaskService, cancel_all_task_of
 from api.db.services.user_service import UserTenantService
-from common.misc_utils import get_uuid
+from common.misc_utils import get_uuid, thread_pool_exec
 from api.utils.api_utils import (
     get_data_error_result,
     get_json_result,
     server_error_response,
-    validate_request, get_request_json,
+    validate_request,
+    get_request_json,
 )
 from api.utils.file_utils import filename_type, thumbnail
 from common.file_utils import get_project_base_directory
@@ -62,10 +62,21 @@ async def upload():
         return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)
     file_objs = files.getlist("file")
+    def _close_file_objs(objs):
+        for obj in objs:
+            try:
+                obj.close()
+            except Exception:
+                try:
+                    obj.stream.close()
+                except Exception:
+                    pass
     for file_obj in file_objs:
         if file_obj.filename == "":
+            _close_file_objs(file_objs)
             return get_json_result(data=False, message="No file selected!", code=RetCode.ARGUMENT_ERROR)
         if len(file_obj.filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
+            _close_file_objs(file_objs)
             return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR)
     e, kb = KnowledgebaseService.get_by_id(kb_id)
@@ -74,8 +85,9 @@ async def upload():
     if not check_kb_team_permission(kb, current_user.id):
         return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
-    err, files = await asyncio.to_thread(FileService.upload_document, kb, file_objs, current_user.id)
+    err, files = await thread_pool_exec(FileService.upload_document, kb, file_objs, current_user.id)
     if err:
+        files = [f[0] for f in files] if files else []
        return get_json_result(data=files, message="\n".join(err), code=RetCode.SERVER_ERROR)
     if not files:
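The `_close_file_objs` helper above makes the early-return paths release every uploaded file handle before the request is rejected, so validation failures don't leak temp-file descriptors. A standalone sketch of the same best-effort cleanup, assuming werkzeug-style `FileStorage` objects with a `.stream` attribute:

```python
# Best-effort cleanup on early return; never let a close() failure mask
# the real error response being sent to the client.
def close_all(file_objs):
    for obj in file_objs:
        try:
            obj.close()
        except Exception:
            try:
                obj.stream.close()  # fall back to the underlying stream
            except Exception:
                pass
```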
@@ -214,6 +226,7 @@ async def list_docs():
     kb_id = request.args.get("kb_id")
     if not kb_id:
         return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
     tenants = UserTenantService.query(user_id=current_user.id)
     for tenant in tenants:
         if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
@@ -333,6 +346,8 @@ async def list_docs():
                 doc_item["thumbnail"] = f"/v1/document/image/{kb_id}-{doc_item['thumbnail']}"
             if doc_item.get("source_type"):
                 doc_item["source_type"] = doc_item["source_type"].split("/")[0]
+            if doc_item["parser_config"].get("metadata"):
+                doc_item["parser_config"]["metadata"] = turn2jsonschema(doc_item["parser_config"]["metadata"])
         return get_json_result(data={"total": tol, "docs": docs})
     except Exception as e:
@@ -394,6 +409,7 @@ async def doc_infos():
 async def metadata_summary():
     req = await get_request_json()
     kb_id = req.get("kb_id")
+    doc_ids = req.get("doc_ids")
     if not kb_id:
         return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
@@ -405,7 +421,7 @@ async def metadata_summary():
         return get_json_result(data=False, message="Only owner of dataset authorized for this operation.", code=RetCode.OPERATING_ERROR)
     try:
-        summary = DocumentService.get_metadata_summary(kb_id)
+        summary = DocumentService.get_metadata_summary(kb_id, doc_ids)
         return get_json_result(data={"summary": summary})
     except Exception as e:
         return server_error_response(e)
@@ -413,36 +429,16 @@ async def metadata_summary():
 @manager.route("/metadata/update", methods=["POST"])  # noqa: F821
 @login_required
+@validate_request("doc_ids")
 async def metadata_update():
     req = await get_request_json()
-    kb_id = req.get("kb_id")
-    if not kb_id:
-        return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
-    tenants = UserTenantService.query(user_id=current_user.id)
-    for tenant in tenants:
-        if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
-            break
-    else:
-        return get_json_result(data=False, message="Only owner of dataset authorized for this operation.", code=RetCode.OPERATING_ERROR)
-    selector = req.get("selector", {}) or {}
+    document_ids = req.get("doc_ids")
     updates = req.get("updates", []) or []
     deletes = req.get("deletes", []) or []
-    if not isinstance(selector, dict):
-        return get_json_result(data=False, message="selector must be an object.", code=RetCode.ARGUMENT_ERROR)
     if not isinstance(updates, list) or not isinstance(deletes, list):
         return get_json_result(data=False, message="updates and deletes must be lists.", code=RetCode.ARGUMENT_ERROR)
-    metadata_condition = selector.get("metadata_condition", {}) or {}
-    if metadata_condition and not isinstance(metadata_condition, dict):
-        return get_json_result(data=False, message="metadata_condition must be an object.", code=RetCode.ARGUMENT_ERROR)
-    document_ids = selector.get("document_ids", []) or []
-    if document_ids and not isinstance(document_ids, list):
-        return get_json_result(data=False, message="document_ids must be a list.", code=RetCode.ARGUMENT_ERROR)
     for upd in updates:
         if not isinstance(upd, dict) or not upd.get("key") or "value" not in upd:
             return get_json_result(data=False, message="Each update requires key and value.", code=RetCode.ARGUMENT_ERROR)
@@ -450,24 +446,8 @@ async def metadata_update():
         if not isinstance(d, dict) or not d.get("key"):
             return get_json_result(data=False, message="Each delete requires key.", code=RetCode.ARGUMENT_ERROR)
-    kb_doc_ids = KnowledgebaseService.list_documents_by_ids([kb_id])
-    target_doc_ids = set(kb_doc_ids)
-    if document_ids:
-        invalid_ids = set(document_ids) - set(kb_doc_ids)
-        if invalid_ids:
-            return get_json_result(data=False, message=f"These documents do not belong to dataset {kb_id}: {', '.join(invalid_ids)}", code=RetCode.ARGUMENT_ERROR)
-        target_doc_ids = set(document_ids)
-    if metadata_condition:
-        metas = DocumentService.get_flatted_meta_by_kbs([kb_id])
-        filtered_ids = set(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
-        target_doc_ids = target_doc_ids & filtered_ids
-        if metadata_condition.get("conditions") and not target_doc_ids:
-            return get_json_result(data={"updated": 0, "matched_docs": 0})
-    target_doc_ids = list(target_doc_ids)
-    updated = DocumentService.batch_update_metadata(kb_id, target_doc_ids, updates, deletes)
-    return get_json_result(data={"updated": updated, "matched_docs": len(target_doc_ids)})
+    updated = DocumentService.batch_update_metadata(None, document_ids, updates, deletes)
+    return get_json_result(data={"updated": updated})
@manager.route("/update_metadata_setting", methods=["POST"]) # noqa: F821 @manager.route("/update_metadata_setting", methods=["POST"]) # noqa: F821
@@ -521,31 +501,61 @@ async def change_status():
         return get_json_result(data=False, message='"Status" must be either 0 or 1!', code=RetCode.ARGUMENT_ERROR)
     result = {}
+    has_error = False
     for doc_id in doc_ids:
         if not DocumentService.accessible(doc_id, current_user.id):
             result[doc_id] = {"error": "No authorization."}
+            has_error = True
             continue
         try:
             e, doc = DocumentService.get_by_id(doc_id)
             if not e:
                 result[doc_id] = {"error": "No authorization."}
+                has_error = True
                 continue
             e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
             if not e:
                 result[doc_id] = {"error": "Can't find this dataset!"}
+                has_error = True
+                continue
+            current_status = str(doc.status)
+            if current_status == status:
+                result[doc_id] = {"status": status}
                 continue
             if not DocumentService.update_by_id(doc_id, {"status": str(status)}):
                 result[doc_id] = {"error": "Database error (Document update)!"}
+                has_error = True
                 continue
             status_int = int(status)
-            if not settings.docStoreConn.update({"doc_id": doc_id}, {"available_int": status_int}, search.index_name(kb.tenant_id), doc.kb_id):
-                result[doc_id] = {"error": "Database error (docStore update)!"}
+            if getattr(doc, "chunk_num", 0) > 0:
+                try:
+                    ok = settings.docStoreConn.update(
+                        {"doc_id": doc_id},
+                        {"available_int": status_int},
+                        search.index_name(kb.tenant_id),
+                        doc.kb_id,
+                    )
+                except Exception as exc:
+                    msg = str(exc)
+                    if "3022" in msg:
+                        result[doc_id] = {"error": "Document store table missing."}
+                    else:
+                        result[doc_id] = {"error": f"Document store update failed: {msg}"}
+                    has_error = True
+                    continue
+                if not ok:
+                    result[doc_id] = {"error": "Database error (docStore update)!"}
+                    has_error = True
+                    continue
             result[doc_id] = {"status": status}
         except Exception as e:
             result[doc_id] = {"error": f"Internal server error: {str(e)}"}
+            has_error = True
+    if has_error:
+        return get_json_result(data=result, message="Partial failure", code=RetCode.SERVER_ERROR)
     return get_json_result(data=result)
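The `change_status` rewrite is a batch endpoint with per-item error aggregation: keep iterating on failures, record one entry per document, and downgrade the overall status code once anything failed. A minimal standalone sketch of that shape, names illustrative:

```python
# Per-item aggregation: full result payload either way, error code on any failure.
def apply_all(doc_ids, apply_one):
    result, has_error = {}, False
    for doc_id in doc_ids:
        try:
            result[doc_id] = {"status": apply_one(doc_id)}
        except Exception as e:
            result[doc_id] = {"error": str(e)}
            has_error = True
    return result, ("SERVER_ERROR" if has_error else "OK")

def flaky(doc_id):
    if doc_id == "b":
        raise ValueError("boom")
    return "1"

res, code = apply_all(["a", "b"], flaky)
assert code == "SERVER_ERROR" and "error" in res["b"] and res["a"] == {"status": "1"}
```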
@@ -562,7 +572,7 @@ async def rm():
     if not DocumentService.accessible4deletion(doc_id, current_user.id):
         return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
-    errors = await asyncio.to_thread(FileService.delete_docs, doc_ids, current_user.id)
+    errors = await thread_pool_exec(FileService.delete_docs, doc_ids, current_user.id)
     if errors:
         return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
@@ -575,10 +585,11 @@ async def rm():
 @validate_request("doc_ids", "run")
 async def run():
     req = await get_request_json()
+    uid = current_user.id
     try:
         def _run_sync():
             for doc_id in req["doc_ids"]:
-                if not DocumentService.accessible(doc_id, current_user.id):
+                if not DocumentService.accessible(doc_id, uid):
                     return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
             kb_table_num_map = {}
@@ -615,6 +626,7 @@ async def run():
                 e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
                 if not e:
                     raise LookupError("Can't find this dataset!")
+                doc.parser_config["llm_id"] = kb.parser_config.get("llm_id")
                 doc.parser_config["enable_metadata"] = kb.parser_config.get("enable_metadata", False)
                 doc.parser_config["metadata"] = kb.parser_config.get("metadata", {})
                 DocumentService.update_parser_config(doc.id, doc.parser_config)
@@ -623,7 +635,7 @@ async def run():
             return get_json_result(data=True)
-        return await asyncio.to_thread(_run_sync)
+        return await thread_pool_exec(_run_sync)
     except Exception as e:
         return server_error_response(e)
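Note the recurring `uid = current_user.id` capture in `run()`, `rename()`, and the file/kb `rm()` handlers before work is handed to `thread_pool_exec`. A plausible reason, assuming `current_user` is a context-local proxy: values bound via contextvars on the request's context are not visible on plain pool threads (`run_in_executor` does not copy context, unlike `asyncio.to_thread`), so the plain value is resolved while it is still safe. A toy demonstration:

```python
import asyncio
import contextvars

_user = contextvars.ContextVar("user", default=None)  # stand-in for current_user

async def handler():
    _user.set("tenant-42")
    uid = _user.get()  # capture the plain value on the request context

    def _work():
        # On the pool thread the ContextVar is not visible, so only the
        # captured local is reliable here.
        return _user.get(), uid

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _work)

print(asyncio.run(handler()))  # -> (None, 'tenant-42')
```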
@@ -633,9 +645,10 @@ async def run():
 @validate_request("doc_id", "name")
 async def rename():
     req = await get_request_json()
+    uid = current_user.id
     try:
         def _rename_sync():
-            if not DocumentService.accessible(req["doc_id"], current_user.id):
+            if not DocumentService.accessible(req["doc_id"], uid):
                 return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
             e, doc = DocumentService.get_by_id(req["doc_id"])
@@ -674,7 +687,7 @@ async def rename():
             )
             return get_json_result(data=True)
-        return await asyncio.to_thread(_rename_sync)
+        return await thread_pool_exec(_rename_sync)
     except Exception as e:
         return server_error_response(e)
@@ -689,7 +702,7 @@ async def get(doc_id):
             return get_data_error_result(message="Document not found!")
         b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
-        data = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
+        data = await thread_pool_exec(settings.STORAGE_IMPL.get, b, n)
         response = await make_response(data)
         ext = re.search(r"\.([^.]+)$", doc.name.lower())
@@ -711,7 +724,7 @@ async def get(doc_id):
 async def download_attachment(attachment_id):
     try:
         ext = request.args.get("ext", "markdown")
-        data = await asyncio.to_thread(settings.STORAGE_IMPL.get, current_user.id, attachment_id)
+        data = await thread_pool_exec(settings.STORAGE_IMPL.get, current_user.id, attachment_id)
         response = await make_response(data)
         response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
@@ -784,7 +797,7 @@ async def get_image(image_id):
         if len(arr) != 2:
             return get_data_error_result(message="Image not found.")
         bkt, nm = image_id.split("-")
-        data = await asyncio.to_thread(settings.STORAGE_IMPL.get, bkt, nm)
+        data = await thread_pool_exec(settings.STORAGE_IMPL.get, bkt, nm)
         response = await make_response(data)
         response.headers.set("Content-Type", "image/JPEG")
         return response

View File

@@ -14,7 +14,6 @@
 # limitations under the License
 #
 import logging
-import asyncio
 import os
 import pathlib
 import re
@@ -25,7 +24,7 @@ from api.common.check_team_permission import check_file_team_permission
 from api.db.services.document_service import DocumentService
 from api.db.services.file2document_service import File2DocumentService
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
-from common.misc_utils import get_uuid
+from common.misc_utils import get_uuid, thread_pool_exec
 from common.constants import RetCode, FileSource
 from api.db import FileType
 from api.db.services import duplicate_name
@@ -35,7 +34,6 @@ from api.utils.file_utils import filename_type
 from api.utils.web_utils import CONTENT_TYPE_MAP
 from common import settings
 @manager.route('/upload', methods=['POST'])  # noqa: F821
 @login_required
 # @validate_request("parent_id")
@@ -65,7 +63,7 @@ async def upload():
     async def _handle_single_file(file_obj):
         MAX_FILE_NUM_PER_USER: int = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
-        if 0 < MAX_FILE_NUM_PER_USER <= await asyncio.to_thread(DocumentService.get_doc_count, current_user.id):
+        if 0 < MAX_FILE_NUM_PER_USER <= await thread_pool_exec(DocumentService.get_doc_count, current_user.id):
             return get_data_error_result(message="Exceed the maximum file number of a free user!")
         # split file name path
@@ -77,35 +75,35 @@ async def upload():
         file_len = len(file_obj_names)
         # get folder
-        file_id_list = await asyncio.to_thread(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id])
+        file_id_list = await thread_pool_exec(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id])
         len_id_list = len(file_id_list)
         # create folder
         if file_len != len_id_list:
-            e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 1])
+            e, file = await thread_pool_exec(FileService.get_by_id, file_id_list[len_id_list - 1])
             if not e:
                 return get_data_error_result(message="Folder not found!")
-            last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names,
-                                                  len_id_list)
+            last_folder = await thread_pool_exec(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names,
+                                                 len_id_list)
         else:
-            e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 2])
+            e, file = await thread_pool_exec(FileService.get_by_id, file_id_list[len_id_list - 2])
             if not e:
                 return get_data_error_result(message="Folder not found!")
-            last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names,
-                                                  len_id_list)
+            last_folder = await thread_pool_exec(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names,
+                                                 len_id_list)
         # file type
         filetype = filename_type(file_obj_names[file_len - 1])
         location = file_obj_names[file_len - 1]
-        while await asyncio.to_thread(settings.STORAGE_IMPL.obj_exist, last_folder.id, location):
+        while await thread_pool_exec(settings.STORAGE_IMPL.obj_exist, last_folder.id, location):
             location += "_"
-        blob = await asyncio.to_thread(file_obj.read)
-        filename = await asyncio.to_thread(
+        blob = await thread_pool_exec(file_obj.read)
+        filename = await thread_pool_exec(
             duplicate_name,
             FileService.query,
             name=file_obj_names[file_len - 1],
             parent_id=last_folder.id)
-        await asyncio.to_thread(settings.STORAGE_IMPL.put, last_folder.id, location, blob)
+        await thread_pool_exec(settings.STORAGE_IMPL.put, last_folder.id, location, blob)
         file_data = {
             "id": get_uuid(),
             "parent_id": last_folder.id,
@@ -116,7 +114,7 @@ async def upload():
             "location": location,
             "size": len(blob),
         }
-        inserted = await asyncio.to_thread(FileService.insert, file_data)
+        inserted = await thread_pool_exec(FileService.insert, file_data)
         return inserted.to_json()
     for file_obj in file_objs:
@@ -249,6 +247,7 @@ def get_all_parent_folders():
 async def rm():
     req = await get_request_json()
     file_ids = req["file_ids"]
+    uid = current_user.id
     try:
         def _delete_single_file(file):
@@ -287,21 +286,21 @@ async def rm():
                     return get_data_error_result(message="File or Folder not found!")
                 if not file.tenant_id:
                     return get_data_error_result(message="Tenant not found!")
-                if not check_file_team_permission(file, current_user.id):
+                if not check_file_team_permission(file, uid):
                     return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
                 if file.source_type == FileSource.KNOWLEDGEBASE:
                     continue
                 if file.type == FileType.FOLDER.value:
-                    _delete_folder_recursive(file, current_user.id)
+                    _delete_folder_recursive(file, uid)
                     continue
                 _delete_single_file(file)
             return get_json_result(data=True)
-        return await asyncio.to_thread(_rm_sync)
+        return await thread_pool_exec(_rm_sync)
     except Exception as e:
         return server_error_response(e)
@@ -357,10 +356,10 @@ async def get(file_id):
         if not check_file_team_permission(file, current_user.id):
             return get_json_result(data=False, message='No authorization.', code=RetCode.AUTHENTICATION_ERROR)
-        blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, file.parent_id, file.location)
+        blob = await thread_pool_exec(settings.STORAGE_IMPL.get, file.parent_id, file.location)
         if not blob:
             b, n = File2DocumentService.get_storage_address(file_id=file_id)
-            blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
+            blob = await thread_pool_exec(settings.STORAGE_IMPL.get, b, n)
         response = await make_response(blob)
         ext = re.search(r"\.([^.]+)$", file.name.lower())
@@ -460,7 +459,7 @@ async def move():
             _move_entry_recursive(file, dest_folder)
             return get_json_result(data=True)
-        return await asyncio.to_thread(_move_sync)
+        return await thread_pool_exec(_move_sync)
     except Exception as e:
         return server_error_response(e)

View File

@@ -17,8 +17,8 @@ import json
 import logging
 import random
 import re
-import asyncio
+from common.metadata_utils import turn2jsonschema
 from quart import request
 import numpy as np
@@ -30,8 +30,15 @@ from api.db.services.file_service import FileService
 from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
 from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
 from api.db.services.user_service import TenantService, UserTenantService
-from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters, \
-    get_request_json
+from api.utils.api_utils import (
+    get_error_data_result,
+    server_error_response,
+    get_data_error_result,
+    validate_request,
+    not_allowed_parameters,
+    get_request_json,
+)
+from common.misc_utils import thread_pool_exec
 from api.db import VALID_FILE_TYPES
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.db_models import File
@@ -44,7 +51,6 @@ from common import settings
 from common.doc_store.doc_store_base import OrderByExpr
 from api.apps import login_required, current_user
 @manager.route('/create', methods=['post'])  # noqa: F821
 @login_required
 @validate_request("name")
@@ -82,6 +88,20 @@ async def update():
         return get_data_error_result(
             message=f"Dataset name length is {len(req['name'])} which is large than {DATASET_NAME_LIMIT}")
     req["name"] = req["name"].strip()
+    if settings.DOC_ENGINE_INFINITY:
+        parser_id = req.get("parser_id")
+        if isinstance(parser_id, str) and parser_id.lower() == "tag":
+            return get_json_result(
+                code=RetCode.OPERATING_ERROR,
+                message="The chunking method Tag has not been supported by Infinity yet.",
+                data=False,
+            )
+        if "pagerank" in req and req["pagerank"] > 0:
+            return get_json_result(
+                code=RetCode.DATA_ERROR,
+                message="'pagerank' can only be set when doc_engine is elasticsearch",
+                data=False,
+            )
     if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id):
         return get_json_result(
@@ -130,7 +150,7 @@ async def update():
         if kb.pagerank != req.get("pagerank", 0):
             if req.get("pagerank", 0) > 0:
-                await asyncio.to_thread(
+                await thread_pool_exec(
                     settings.docStoreConn.update,
                     {"kb_id": kb.id},
                     {PAGERANK_FLD: req["pagerank"]},
@@ -139,7 +159,7 @@ async def update():
                 )
             else:
                 # Elasticsearch requires PAGERANK_FLD be non-zero!
-                await asyncio.to_thread(
+                await thread_pool_exec(
                     settings.docStoreConn.update,
                     {"exists": PAGERANK_FLD},
                     {"remove": PAGERANK_FLD},
@@ -174,6 +194,7 @@ async def update_metadata_setting():
                 message="Database error (Knowledgebase rename)!")
         kb = kb.to_dict()
         kb["parser_config"]["metadata"] = req["metadata"]
+        kb["parser_config"]["enable_metadata"] = req.get("enable_metadata", True)
         KnowledgebaseService.update_by_id(kb["id"], kb)
         return get_json_result(data=kb)
@@ -198,6 +219,8 @@ def detail():
                 message="Can't find this dataset!")
         kb["size"] = DocumentService.get_total_size_by_kb_id(kb_id=kb["id"], keywords="", run_status=[], types=[])
         kb["connectors"] = Connector2KbService.list_connectors(kb_id)
+        if kb["parser_config"].get("metadata"):
+            kb["parser_config"]["metadata"] = turn2jsonschema(kb["parser_config"]["metadata"])
         for key in ["graphrag_task_finish_at", "raptor_task_finish_at", "mindmap_task_finish_at"]:
             if finish_at := kb.get(key):
@@ -249,7 +272,8 @@ async def list_kbs():
 @validate_request("kb_id")
 async def rm():
     req = await get_request_json()
-    if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id):
+    uid = current_user.id
+    if not KnowledgebaseService.accessible4deletion(req["kb_id"], uid):
         return get_json_result(
             data=False,
             message='No authorization.',
@@ -257,7 +281,7 @@ async def rm():
         )
     try:
         kbs = KnowledgebaseService.query(
-            created_by=current_user.id, id=req["kb_id"])
+            created_by=uid, id=req["kb_id"])
         if not kbs:
             return get_json_result(
                 data=False, message='Only owner of dataset authorized for this operation.',
@@ -280,17 +304,24 @@ async def rm():
                     File.name == kbs[0].name,
                 ]
             )
+            # Delete the table BEFORE deleting the database record
+            for kb in kbs:
+                try:
+                    settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
+                    settings.docStoreConn.delete_idx(search.index_name(kb.tenant_id), kb.id)
+                    logging.info(f"Dropped index for dataset {kb.id}")
+                except Exception as e:
+                    logging.error(f"Failed to drop index for dataset {kb.id}: {e}")
             if not KnowledgebaseService.delete_by_id(req["kb_id"]):
                 return get_data_error_result(
                     message="Database error (Knowledgebase removal)!")
             for kb in kbs:
-                settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
-                settings.docStoreConn.delete_idx(search.index_name(kb.tenant_id), kb.id)
                 if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
                     settings.STORAGE_IMPL.remove_bucket(kb.id)
             return get_json_result(data=True)
-        return await asyncio.to_thread(_rm_sync)
+        return await thread_pool_exec(_rm_sync)
     except Exception as e:
         return server_error_response(e)
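The reordering above matters: the doc-store index is a derived artifact, so dropping it before `KnowledgebaseService.delete_by_id` (and wrapping it in try/except) means a doc-store failure is logged but can no longer strand an orphaned index behind an already-deleted DB row. A minimal sketch of the same ordering, names hypothetical:

```python
# Derived data first (best effort), authoritative record last.
import logging

def remove_dataset(kb, doc_store, db):
    try:
        doc_store.delete_idx(f"idx_{kb.tenant_id}", kb.id)
    except Exception as e:
        logging.error("Failed to drop index for dataset %s: %s", kb.id, e)
    if not db.delete_by_id(kb.id):
        raise RuntimeError("Database error (Knowledgebase removal)!")
```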
@@ -372,7 +403,7 @@ async def rename_tags(kb_id):
 @manager.route('/<kb_id>/knowledge_graph', methods=['GET'])  # noqa: F821
 @login_required
-def knowledge_graph(kb_id):
+async def knowledge_graph(kb_id):
     if not KnowledgebaseService.accessible(kb_id, current_user.id):
         return get_json_result(
             data=False,
@@ -388,7 +419,7 @@ def knowledge_graph(kb_id):
     obj = {"graph": {}, "mind_map": {}}
     if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), kb_id):
         return get_json_result(data=obj)
-    sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
+    sres = await settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
     if not len(sres.ids):
         return get_json_result(data=obj)

View File

@@ -195,6 +195,9 @@ async def add_llm():
     elif factory == "MinerU":
         api_key = apikey_json(["api_key", "provider_order"])
+    elif factory == "PaddleOCR":
+        api_key = apikey_json(["api_key", "provider_order"])
     llm = {
         "tenant_id": current_user.id,
         "llm_factory": factory,
@@ -230,8 +233,7 @@ async def add_llm():
             **extra,
         )
         try:
-            m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}],
-                                         {"temperature": 0.9})
+            m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
             if not tc and m.find("**ERROR**:") >= 0:
                 raise Exception(m)
         except Exception as e:
@@ -371,17 +373,18 @@ def my_llms():
 @manager.route("/list", methods=["GET"])  # noqa: F821
 @login_required
-def list_app():
+async def list_app():
     self_deployed = ["FastEmbed", "Ollama", "Xinference", "LocalAI", "LM-Studio", "GPUStack"]
     weighted = []
     model_type = request.args.get("model_type")
+    tenant_id = current_user.id
     try:
-        TenantLLMService.ensure_mineru_from_env(current_user.id)
-        objs = TenantLLMService.query(tenant_id=current_user.id)
+        TenantLLMService.ensure_mineru_from_env(tenant_id)
+        objs = TenantLLMService.query(tenant_id=tenant_id)
         facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key and o.status == StatusEnum.VALID.value])
         status = {(o.llm_name + "@" + o.llm_factory) for o in objs if o.status == StatusEnum.VALID.value}
         llms = LLMService.get_all()
-        llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted and (m.fid == 'Builtin' or (m.llm_name + "@" + m.fid) in status)]
+        llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted and (m.fid == "Builtin" or (m.llm_name + "@" + m.fid) in status)]
         for m in llms:
             m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in self_deployed
             if "tei-" in os.getenv("COMPOSE_PROFILES", "") and m["model_type"] == LLMType.EMBEDDING and m["fid"] == "Builtin" and m["llm_name"] == os.getenv("TEI_MODEL", ""):

View File

@@ -13,8 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import asyncio
 from quart import Response, request
 from api.apps import current_user, login_required
@@ -23,12 +21,11 @@ from api.db.services.mcp_server_service import MCPServerService
 from api.db.services.user_service import TenantService
 from common.constants import RetCode, VALID_MCP_SERVER_TYPES
-from common.misc_utils import get_uuid
+from common.misc_utils import get_uuid, thread_pool_exec
 from api.utils.api_utils import get_data_error_result, get_json_result, get_mcp_tools, get_request_json, server_error_response, validate_request
 from api.utils.web_utils import get_float, safe_json_parse
 from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
 @manager.route("/list", methods=["POST"])  # noqa: F821
 @login_required
 async def list_mcp() -> Response:
@@ -108,7 +105,7 @@ async def create() -> Response:
         return get_data_error_result(message="Tenant not found.")
     mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
-    server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
+    server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout)
     if err_message:
         return get_data_error_result(err_message)
@@ -160,7 +157,7 @@ async def update() -> Response:
     req["id"] = mcp_id
     mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
-    server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
+    server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout)
     if err_message:
         return get_data_error_result(err_message)
@@ -244,7 +241,7 @@ async def import_multiple() -> Response:
             headers = {"authorization_token": config["authorization_token"]} if "authorization_token" in config else {}
             variables = {k: v for k, v in config.items() if k not in {"type", "url", "headers"}}
             mcp_server = MCPServer(id=new_name, name=new_name, url=config["url"], server_type=config["type"], variables=variables, headers=headers)
-            server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
+            server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout)
             if err_message:
                 results.append({"server": base_name, "success": False, "message": err_message})
                 continue
@@ -324,7 +321,7 @@ async def list_tools() -> Response:
             tool_call_sessions.append(tool_call_session)
             try:
-                tools = await asyncio.to_thread(tool_call_session.get_tools, timeout)
+                tools = await thread_pool_exec(tool_call_session.get_tools, timeout)
             except Exception as e:
                 return get_data_error_result(message=f"MCP list tools error: {e}")
@@ -341,7 +338,7 @@ async def list_tools() -> Response:
         return server_error_response(e)
     finally:
         # PERF: blocking call to close sessions — consider moving to background thread or task queue
-        await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
+        await thread_pool_exec(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
 @manager.route("/test_tool", methods=["POST"])  # noqa: F821
@@ -368,10 +365,10 @@ async def test_tool() -> Response:
         tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
         tool_call_sessions.append(tool_call_session)
-        result = await asyncio.to_thread(tool_call_session.tool_call, tool_name, arguments, timeout)
+        result = await thread_pool_exec(tool_call_session.tool_call, tool_name, arguments, timeout)
         # PERF: blocking call to close sessions — consider moving to background thread or task queue
-        await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
+        await thread_pool_exec(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
         return get_json_result(data=result)
     except Exception as e:
         return server_error_response(e)
@@ -425,12 +422,12 @@ async def test_mcp() -> Response:
         tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
         try:
-            tools = await asyncio.to_thread(tool_call_session.get_tools, timeout)
+            tools = await thread_pool_exec(tool_call_session.get_tools, timeout)
         except Exception as e:
             return get_data_error_result(message=f"Test MCP error: {e}")
         finally:
             # PERF: blocking call to close sessions — consider moving to background thread or task queue
-            await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, [tool_call_session])
+            await thread_pool_exec(close_multiple_mcp_toolcall_sessions, [tool_call_session])
         for tool in tools:
             tool_dict = tool.model_dump()

View File

@@ -51,7 +51,7 @@ def list_agents(tenant_id):
     page_number = int(request.args.get("page", 1))
     items_per_page = int(request.args.get("page_size", 30))
     order_by = request.args.get("orderby", "update_time")
-    if request.args.get("desc") == "False" or request.args.get("desc") == "false":
+    if str(request.args.get("desc", "false")).lower() == "false":
         desc = False
     else:
         desc = True
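The `desc` change above replaces two exact string comparisons with a single case-insensitive one, so `FALSE`, `False`, and `false` all behave alike, and a missing parameter falls back to the `"false"` default. A standalone equivalent:

```python
# Tolerant flag parsing mirroring the hunk above; an absent param defaults
# to "false", and any casing of "false" disables descending order.
def parse_desc(raw=None) -> bool:
    return str(raw if raw is not None else "false").lower() != "false"

assert parse_desc("False") is False and parse_desc("TRUE") is True
assert parse_desc() is False
```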
@ -162,6 +162,7 @@ async def webhook(agent_id: str):
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Invalid DSL format."),RetCode.BAD_REQUEST return get_data_error_result(code=RetCode.BAD_REQUEST,message="Invalid DSL format."),RetCode.BAD_REQUEST
# 4. Check webhook configuration in DSL # 4. Check webhook configuration in DSL
webhook_cfg = {}
components = dsl.get("components", {}) components = dsl.get("components", {})
for k, _ in components.items(): for k, _ in components.items():
cpn_obj = components[k]["obj"] cpn_obj = components[k]["obj"]
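The new webhook_cfg = {} line initializes the configuration before the component scan, so a DSL without any webhook component leaves an empty dict rather than an unbound name for the code that follows. A hypothetical sketch of the pattern; the component shape ("obj", "component_name", "params") is assumed, not taken from this diff:

```python
def find_webhook_cfg(components: dict) -> dict:
    webhook_cfg = {}  # initialized first: "no webhook found" is a valid outcome
    for k in components:
        obj = components[k].get("obj", {})
        if obj.get("component_name") == "Webhook":  # assumed discriminator
            webhook_cfg = obj.get("params", {})     # assumed payload location
    return webhook_cfg
```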


@@ -51,7 +51,9 @@ async def create(tenant_id):
req["llm_id"] = llm.pop("model_name") req["llm_id"] = llm.pop("model_name")
if req.get("llm_id") is not None: if req.get("llm_id") is not None:
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["llm_id"]) llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["llm_id"])
-        if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type="chat"):
+        model_type = llm.get("model_type")
+        model_type = model_type if model_type in ["chat", "image2text"] else "chat"
+        if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type=model_type):
return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist") return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist")
req["llm_setting"] = req.pop("llm") req["llm_setting"] = req.pop("llm")
e, tenant = TenantService.get_by_id(tenant_id) e, tenant = TenantService.get_by_id(tenant_id)
@@ -174,7 +176,7 @@ async def update(tenant_id, chat_id):
req["llm_id"] = llm.pop("model_name") req["llm_id"] = llm.pop("model_name")
if req.get("llm_id") is not None: if req.get("llm_id") is not None:
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["llm_id"]) llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["llm_id"])
model_type = llm.pop("model_type") model_type = llm.get("model_type")
model_type = model_type if model_type in ["chat", "image2text"] else "chat" model_type = model_type if model_type in ["chat", "image2text"] else "chat"
if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type=model_type): if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type=model_type):
return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist") return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist")
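Both create and update now resolve the model type identically: read it with .get() so the llm dict keeps its keys for later use, accept only "chat" or "image2text", and fall back to "chat" for anything else. Distilled into a hypothetical helper:

```python
def resolve_model_type(llm: dict) -> str:
    model_type = llm.get("model_type")  # .get(), not .pop(): leave the dict intact
    return model_type if model_type in ("chat", "image2text") else "chat"

assert resolve_model_type({"model_type": "image2text"}) == "image2text"
assert resolve_model_type({"model_type": "embedding"}) == "chat"
assert resolve_model_type({}) == "chat"
```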


@@ -233,6 +233,15 @@ async def delete(tenant_id):
File2DocumentService.delete_by_document_id(doc.id) File2DocumentService.delete_by_document_id(doc.id)
FileService.filter_delete( FileService.filter_delete(
[File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name]) [File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name])
# Drop index for this dataset
try:
from rag.nlp import search
idxnm = search.index_name(kb.tenant_id)
settings.docStoreConn.delete_idx(idxnm, kb_id)
except Exception as e:
logging.warning(f"Failed to drop index for dataset {kb_id}: {e}")
if not KnowledgebaseService.delete_by_id(kb_id): if not KnowledgebaseService.delete_by_id(kb_id):
errors.append(f"Delete dataset error for {kb_id}") errors.append(f"Delete dataset error for {kb_id}")
continue continue
@@ -481,7 +490,7 @@ def list_datasets(tenant_id):
@manager.route('/datasets/<dataset_id>/knowledge_graph', methods=['GET']) # noqa: F821 @manager.route('/datasets/<dataset_id>/knowledge_graph', methods=['GET']) # noqa: F821
@token_required @token_required
def knowledge_graph(tenant_id, dataset_id): async def knowledge_graph(tenant_id, dataset_id):
if not KnowledgebaseService.accessible(dataset_id, tenant_id): if not KnowledgebaseService.accessible(dataset_id, tenant_id):
return get_result( return get_result(
data=False, data=False,
@@ -497,7 +506,7 @@ def knowledge_graph(tenant_id, dataset_id):
obj = {"graph": {}, "mind_map": {}} obj = {"graph": {}, "mind_map": {}}
if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), dataset_id): if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), dataset_id):
return get_result(data=obj) return get_result(data=obj)
sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id]) sres = await settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
if not len(sres.ids): if not len(sres.ids):
return get_result(data=obj) return get_result(data=obj)


@@ -135,7 +135,7 @@ async def retrieval(tenant_id):
doc_ids.extend(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))) doc_ids.extend(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
if not doc_ids and metadata_condition: if not doc_ids and metadata_condition:
doc_ids = ["-999"] doc_ids = ["-999"]
ranks = settings.retriever.retrieval( ranks = await settings.retriever.retrieval(
question, question,
embd_mdl, embd_mdl,
kb.tenant_id, kb.tenant_id,
@@ -150,7 +150,7 @@ async def retrieval(tenant_id):
) )
if use_kg: if use_kg:
ck = settings.kg_retriever.retrieval(question, ck = await settings.kg_retriever.retrieval(question,
[tenant_id], [tenant_id],
[kb_id], [kb_id],
embd_mdl, embd_mdl,


@@ -606,12 +606,12 @@ def list_docs(dataset_id, tenant_id):
@manager.route("/datasets/<dataset_id>/metadata/summary", methods=["GET"]) # noqa: F821 @manager.route("/datasets/<dataset_id>/metadata/summary", methods=["GET"]) # noqa: F821
@token_required @token_required
def metadata_summary(dataset_id, tenant_id): async def metadata_summary(dataset_id, tenant_id):
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ") return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
req = await get_request_json()
try: try:
summary = DocumentService.get_metadata_summary(dataset_id) summary = DocumentService.get_metadata_summary(dataset_id, req.get("doc_ids"))
return get_result(data={"summary": summary}) return get_result(data={"summary": summary})
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@@ -648,9 +648,9 @@ async def metadata_batch_update(dataset_id, tenant_id):
if not isinstance(d, dict) or not d.get("key"): if not isinstance(d, dict) or not d.get("key"):
return get_error_data_result(message="Each delete requires key.") return get_error_data_result(message="Each delete requires key.")
kb_doc_ids = KnowledgebaseService.list_documents_by_ids([dataset_id])
target_doc_ids = set(kb_doc_ids)
if document_ids: if document_ids:
kb_doc_ids = KnowledgebaseService.list_documents_by_ids([dataset_id])
target_doc_ids = set(kb_doc_ids)
invalid_ids = set(document_ids) - set(kb_doc_ids) invalid_ids = set(document_ids) - set(kb_doc_ids)
if invalid_ids: if invalid_ids:
return get_error_data_result(message=f"These documents do not belong to dataset {dataset_id}: {', '.join(invalid_ids)}") return get_error_data_result(message=f"These documents do not belong to dataset {dataset_id}: {', '.join(invalid_ids)}")
@@ -935,7 +935,7 @@ async def stop_parsing(tenant_id, dataset_id):
@manager.route("/datasets/<dataset_id>/documents/<document_id>/chunks", methods=["GET"]) # noqa: F821 @manager.route("/datasets/<dataset_id>/documents/<document_id>/chunks", methods=["GET"]) # noqa: F821
@token_required @token_required
def list_chunks(tenant_id, dataset_id, document_id): async def list_chunks(tenant_id, dataset_id, document_id):
""" """
List chunks of a document. List chunks of a document.
--- ---
@@ -1081,7 +1081,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
_ = Chunk(**final_chunk) _ = Chunk(**final_chunk)
elif settings.docStoreConn.index_exist(search.index_name(tenant_id), dataset_id): elif settings.docStoreConn.index_exist(search.index_name(tenant_id), dataset_id):
sres = settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True) sres = await settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
res["total"] = sres.total res["total"] = sres.total
for id in sres.ids: for id in sres.ids:
d = { d = {
@@ -1520,10 +1520,11 @@ async def retrieval_test(tenant_id):
langs = req.get("cross_languages", []) langs = req.get("cross_languages", [])
if not isinstance(doc_ids, list): if not isinstance(doc_ids, list):
return get_error_data_result("`documents` should be a list") return get_error_data_result("`documents` should be a list")
-        doc_ids_list = KnowledgebaseService.list_documents_by_ids(kb_ids)
-        for doc_id in doc_ids:
-            if doc_id not in doc_ids_list:
-                return get_error_data_result(f"The datasets don't own the document {doc_id}")
+        if doc_ids:
+            doc_ids_list = KnowledgebaseService.list_documents_by_ids(kb_ids)
+            for doc_id in doc_ids:
+                if doc_id not in doc_ids_list:
+                    return get_error_data_result(f"The datasets don't own the document {doc_id}")
if not doc_ids: if not doc_ids:
metadata_condition = req.get("metadata_condition", {}) or {} metadata_condition = req.get("metadata_condition", {}) or {}
metas = DocumentService.get_meta_by_kbs(kb_ids) metas = DocumentService.get_meta_by_kbs(kb_ids)
@@ -1558,7 +1559,7 @@ async def retrieval_test(tenant_id):
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT) chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
question += await keyword_extraction(chat_mdl, question) question += await keyword_extraction(chat_mdl, question)
ranks = settings.retriever.retrieval( ranks = await settings.retriever.retrieval(
question, question,
embd_mdl, embd_mdl,
tenant_ids, tenant_ids,
@@ -1575,11 +1576,11 @@ async def retrieval_test(tenant_id):
) )
if toc_enhance: if toc_enhance:
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT) chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
cks = settings.retriever.retrieval_by_toc(question, ranks["chunks"], tenant_ids, chat_mdl, size) cks = await settings.retriever.retrieval_by_toc(question, ranks["chunks"], tenant_ids, chat_mdl, size)
if cks: if cks:
ranks["chunks"] = cks ranks["chunks"] = cks
if use_kg: if use_kg:
ck = settings.kg_retriever.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT)) ck = await settings.kg_retriever.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]: if ck["content_with_weight"]:
ranks["chunks"].insert(0, ck) ranks["chunks"].insert(0, ck)


@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
# #
import asyncio
import pathlib import pathlib
import re import re
from quart import request, make_response from quart import request, make_response
@@ -24,7 +23,7 @@ from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService from api.db.services.file2document_service import File2DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.api_utils import get_json_result, get_request_json, server_error_response, token_required from api.utils.api_utils import get_json_result, get_request_json, server_error_response, token_required
from common.misc_utils import get_uuid from common.misc_utils import get_uuid, thread_pool_exec
from api.db import FileType from api.db import FileType
from api.db.services import duplicate_name from api.db.services import duplicate_name
from api.db.services.file_service import FileService from api.db.services.file_service import FileService
@@ -33,7 +32,6 @@ from api.utils.web_utils import CONTENT_TYPE_MAP
from common import settings from common import settings
from common.constants import RetCode from common.constants import RetCode
@manager.route('/file/upload', methods=['POST']) # noqa: F821 @manager.route('/file/upload', methods=['POST']) # noqa: F821
@token_required @token_required
async def upload(tenant_id): async def upload(tenant_id):
@@ -640,7 +638,7 @@ async def get(tenant_id, file_id):
async def download_attachment(tenant_id, attachment_id): async def download_attachment(tenant_id, attachment_id):
try: try:
ext = request.args.get("ext", "markdown") ext = request.args.get("ext", "markdown")
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, tenant_id, attachment_id) data = await thread_pool_exec(settings.STORAGE_IMPL.get, tenant_id, attachment_id)
response = await make_response(data) response = await make_response(data)
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}")) response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))


@@ -14,6 +14,8 @@
# limitations under the License. # limitations under the License.
# #
import logging import logging
import os
import time
from quart import request from quart import request
from api.apps import login_required, current_user from api.apps import login_required, current_user
@@ -21,6 +23,7 @@ from api.db import TenantPermission
from api.db.services.memory_service import MemoryService from api.db.services.memory_service import MemoryService
from api.db.services.user_service import UserTenantService from api.db.services.user_service import UserTenantService
from api.db.services.canvas_service import UserCanvasService from api.db.services.canvas_service import UserCanvasService
from api.db.services.task_service import TaskService
from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default
from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result
from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human
@@ -30,26 +33,60 @@ from memory.utils.prompt_util import PromptAssembler
from common.constants import MemoryType, RetCode, ForgettingPolicy from common.constants import MemoryType, RetCode, ForgettingPolicy
@manager.route("", methods=["POST"]) # noqa: F821 @manager.route("/memories", methods=["POST"]) # noqa: F821
@login_required @login_required
@validate_request("name", "memory_type", "embd_id", "llm_id") @validate_request("name", "memory_type", "embd_id", "llm_id")
async def create_memory(): async def create_memory():
timing_enabled = os.getenv("RAGFLOW_API_TIMING")
t_start = time.perf_counter() if timing_enabled else None
req = await get_request_json() req = await get_request_json()
t_parsed = time.perf_counter() if timing_enabled else None
# check name length # check name length
name = req["name"] name = req["name"]
memory_name = name.strip() memory_name = name.strip()
if len(memory_name) == 0: if len(memory_name) == 0:
if timing_enabled:
logging.info(
"api_timing create_memory invalid_name parse_ms=%.2f total_ms=%.2f path=%s",
(t_parsed - t_start) * 1000,
(time.perf_counter() - t_start) * 1000,
request.path,
)
return get_error_argument_result("Memory name cannot be empty or whitespace.") return get_error_argument_result("Memory name cannot be empty or whitespace.")
if len(memory_name) > MEMORY_NAME_LIMIT: if len(memory_name) > MEMORY_NAME_LIMIT:
if timing_enabled:
logging.info(
"api_timing create_memory invalid_name parse_ms=%.2f total_ms=%.2f path=%s",
(t_parsed - t_start) * 1000,
(time.perf_counter() - t_start) * 1000,
request.path,
)
return get_error_argument_result(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.") return get_error_argument_result(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.")
# check memory_type valid # check memory_type valid
if not isinstance(req["memory_type"], list):
if timing_enabled:
logging.info(
"api_timing create_memory invalid_memory_type parse_ms=%.2f total_ms=%.2f path=%s",
(t_parsed - t_start) * 1000,
(time.perf_counter() - t_start) * 1000,
request.path,
)
return get_error_argument_result("Memory type must be a list.")
memory_type = set(req["memory_type"]) memory_type = set(req["memory_type"])
invalid_type = memory_type - {e.name.lower() for e in MemoryType} invalid_type = memory_type - {e.name.lower() for e in MemoryType}
if invalid_type: if invalid_type:
if timing_enabled:
logging.info(
"api_timing create_memory invalid_memory_type parse_ms=%.2f total_ms=%.2f path=%s",
(t_parsed - t_start) * 1000,
(time.perf_counter() - t_start) * 1000,
request.path,
)
return get_error_argument_result(f"Memory type '{invalid_type}' is not supported.") return get_error_argument_result(f"Memory type '{invalid_type}' is not supported.")
memory_type = list(memory_type) memory_type = list(memory_type)
try: try:
t_before_db = time.perf_counter() if timing_enabled else None
res, memory = MemoryService.create_memory( res, memory = MemoryService.create_memory(
tenant_id=current_user.id, tenant_id=current_user.id,
name=memory_name, name=memory_name,
@@ -57,6 +94,15 @@ async def create_memory():
embd_id=req["embd_id"], embd_id=req["embd_id"],
llm_id=req["llm_id"] llm_id=req["llm_id"]
) )
if timing_enabled:
logging.info(
"api_timing create_memory parse_ms=%.2f validate_ms=%.2f db_ms=%.2f total_ms=%.2f path=%s",
(t_parsed - t_start) * 1000,
(t_before_db - t_parsed) * 1000,
(time.perf_counter() - t_before_db) * 1000,
(time.perf_counter() - t_start) * 1000,
request.path,
)
if res: if res:
return get_json_result(message=True, data=format_ret_data_from_memory(memory)) return get_json_result(message=True, data=format_ret_data_from_memory(memory))
@@ -67,7 +113,7 @@ async def create_memory():
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR) return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
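The instrumentation above gates per-phase timing (parse, validate, db) behind the RAGFLOW_API_TIMING environment variable. When only a total is needed, the same gate fits in a decorator; this sketch is illustrative, not code from the PR (the env-var name comes from the diff, everything else is assumed):

```python
import functools
import logging
import os
import time

def api_timing(label: str):
    """Log total handler time when RAGFLOW_API_TIMING is set; hypothetical helper."""
    def wrap(fn):
        @functools.wraps(fn)
        async def inner(*args, **kwargs):
            if not os.getenv("RAGFLOW_API_TIMING"):
                return await fn(*args, **kwargs)
            t0 = time.perf_counter()
            try:
                return await fn(*args, **kwargs)
            finally:
                logging.info("api_timing %s total_ms=%.2f", label, (time.perf_counter() - t0) * 1000)
        return inner
    return wrap
```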
@manager.route("/<memory_id>", methods=["PUT"]) # noqa: F821 @manager.route("/memories/<memory_id>", methods=["PUT"]) # noqa: F821
@login_required @login_required
async def update_memory(memory_id): async def update_memory(memory_id):
req = await get_request_json() req = await get_request_json()
@@ -151,7 +197,7 @@ async def update_memory(memory_id):
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR) return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
@manager.route("/<memory_id>", methods=["DELETE"]) # noqa: F821 @manager.route("/memories/<memory_id>", methods=["DELETE"]) # noqa: F821
@login_required @login_required
async def delete_memory(memory_id): async def delete_memory(memory_id):
memory = MemoryService.get_by_memory_id(memory_id) memory = MemoryService.get_by_memory_id(memory_id)
@@ -159,14 +205,15 @@ async def delete_memory(memory_id):
return get_json_result(message=True, code=RetCode.NOT_FOUND) return get_json_result(message=True, code=RetCode.NOT_FOUND)
try: try:
MemoryService.delete_memory(memory_id) MemoryService.delete_memory(memory_id)
-        MessageService.delete_message({"memory_id": memory_id}, memory.tenant_id, memory_id)
+        if MessageService.has_index(memory.tenant_id, memory_id):
+            MessageService.delete_message({"memory_id": memory_id}, memory.tenant_id, memory_id)
return get_json_result(message=True) return get_json_result(message=True)
except Exception as e: except Exception as e:
logging.error(e) logging.error(e)
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR) return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
@manager.route("", methods=["GET"]) # noqa: F821 @manager.route("/memories", methods=["GET"]) # noqa: F821
@login_required @login_required
async def list_memory(): async def list_memory():
args = request.args args = request.args
@@ -178,13 +225,18 @@ async def list_memory():
page = int(args.get("page", 1)) page = int(args.get("page", 1))
page_size = int(args.get("page_size", 50)) page_size = int(args.get("page_size", 50))
# make filter dict # make filter dict
filter_dict = {"memory_type": memory_types, "storage_type": storage_type} filter_dict: dict = {"storage_type": storage_type}
if not tenant_ids: if not tenant_ids:
# restrict to current user's tenants # restrict to current user's tenants
user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(current_user.id) user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(current_user.id)
filter_dict["tenant_id"] = [tenant["tenant_id"] for tenant in user_tenants] filter_dict["tenant_id"] = [tenant["tenant_id"] for tenant in user_tenants]
else: else:
if len(tenant_ids) == 1 and ',' in tenant_ids[0]:
tenant_ids = tenant_ids[0].split(',')
filter_dict["tenant_id"] = tenant_ids filter_dict["tenant_id"] = tenant_ids
if memory_types and len(memory_types) == 1 and ',' in memory_types[0]:
memory_types = memory_types[0].split(',')
filter_dict["memory_type"] = memory_types
memory_list, count = MemoryService.get_by_filter(filter_dict, keywords, page, page_size) memory_list, count = MemoryService.get_by_filter(filter_dict, keywords, page, page_size)
[memory.update({"memory_type": get_memory_type_human(memory["memory_type"])}) for memory in memory_list] [memory.update({"memory_type": get_memory_type_human(memory["memory_type"])}) for memory in memory_list]
@@ -195,7 +247,7 @@ async def list_memory():
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR) return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
@manager.route("/<memory_id>/config", methods=["GET"]) # noqa: F821 @manager.route("/memories/<memory_id>/config", methods=["GET"]) # noqa: F821
@login_required @login_required
async def get_memory_config(memory_id): async def get_memory_config(memory_id):
memory = MemoryService.get_with_owner_name_by_id(memory_id) memory = MemoryService.get_with_owner_name_by_id(memory_id)
@@ -204,11 +256,13 @@ async def get_memory_config(memory_id):
return get_json_result(message=True, data=format_ret_data_from_memory(memory)) return get_json_result(message=True, data=format_ret_data_from_memory(memory))
@manager.route("/<memory_id>", methods=["GET"]) # noqa: F821 @manager.route("/memories/<memory_id>", methods=["GET"]) # noqa: F821
@login_required @login_required
async def get_memory_detail(memory_id): async def get_memory_detail(memory_id):
args = request.args args = request.args
agent_ids = args.getlist("agent_id") agent_ids = args.getlist("agent_id")
if len(agent_ids) == 1 and ',' in agent_ids[0]:
agent_ids = agent_ids[0].split(',')
keywords = args.get("keywords", "") keywords = args.get("keywords", "")
keywords = keywords.strip() keywords = keywords.strip()
page = int(args.get("page", 1)) page = int(args.get("page", 1))
@@ -219,9 +273,19 @@ async def get_memory_detail(memory_id):
messages = MessageService.list_message( messages = MessageService.list_message(
memory.tenant_id, memory_id, agent_ids, keywords, page, page_size) memory.tenant_id, memory_id, agent_ids, keywords, page, page_size)
agent_name_mapping = {} agent_name_mapping = {}
extract_task_mapping = {}
if messages["message_list"]: if messages["message_list"]:
agent_list = UserCanvasService.get_basic_info_by_canvas_ids([message["agent_id"] for message in messages["message_list"]]) agent_list = UserCanvasService.get_basic_info_by_canvas_ids([message["agent_id"] for message in messages["message_list"]])
agent_name_mapping = {agent["id"]: agent["title"] for agent in agent_list} agent_name_mapping = {agent["id"]: agent["title"] for agent in agent_list}
task_list = TaskService.get_tasks_progress_by_doc_ids([memory_id])
if task_list:
task_list.sort(key=lambda t: t["create_time"])  # ascending, so the newest task wins when more than one exists
for task in task_list:
# the 'digest' field carries the source_id when a task is created, so use 'digest' as key
extract_task_mapping.update({int(task["digest"]): task})
for message in messages["message_list"]: for message in messages["message_list"]:
message["agent_name"] = agent_name_mapping.get(message["agent_id"], "Unknown") message["agent_name"] = agent_name_mapping.get(message["agent_id"], "Unknown")
message["task"] = extract_task_mapping.get(message["message_id"], {})
for extract_msg in message["extract"]:
extract_msg["agent_name"] = agent_name_mapping.get(extract_msg["agent_id"], "Unknown")
return get_json_result(data={"messages": messages, "storage_type": memory.storage_type}, message=True) return get_json_result(data={"messages": messages, "storage_type": memory.storage_type}, message=True)
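Several handlers in this file, and in the messages file below, now accept repeated query parameters either as ?memory_id=a&memory_id=b or as a single comma-joined ?memory_id=a,b. The inline normalization the hunks repeat could live in one place; an illustrative helper, not part of the PR:

```python
def split_listish(values: list) -> list:
    """Expand a single comma-joined value into a list; pass real lists through."""
    if len(values) == 1 and "," in values[0]:
        return values[0].split(",")
    return list(values)

assert split_listish(["a", "b"]) == ["a", "b"]
assert split_listish(["a,b,c"]) == ["a", "b", "c"]
assert split_listish(["a"]) == ["a"]
```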


@@ -24,44 +24,31 @@ from api.utils.api_utils import validate_request, get_request_json, get_error_ar
from common.constants import RetCode from common.constants import RetCode
@manager.route("", methods=["POST"]) # noqa: F821 @manager.route("/messages", methods=["POST"]) # noqa: F821
@login_required @login_required
@validate_request("memory_id", "agent_id", "session_id", "user_input", "agent_response") @validate_request("memory_id", "agent_id", "session_id", "user_input", "agent_response")
async def add_message(): async def add_message():
req = await get_request_json() req = await get_request_json()
memory_ids = req["memory_id"] memory_ids = req["memory_id"]
-    agent_id = req["agent_id"]
-    session_id = req["session_id"]
-    user_id = req["user_id"] if req.get("user_id") else ""
-    user_input = req["user_input"]
-    agent_response = req["agent_response"]
-    res = []
-    for memory_id in memory_ids:
-        success, msg = await memory_message_service.save_to_memory(
-            memory_id,
-            {
-                "user_id": user_id,
-                "agent_id": agent_id,
-                "session_id": session_id,
-                "user_input": user_input,
-                "agent_response": agent_response
-            }
-        )
-        res.append({
-            "memory_id": memory_id,
-            "success": success,
-            "message": msg
-        })
-    if all([r["success"] for r in res]):
-        return get_json_result(message="Successfully added to memories.")
-    return get_json_result(code=RetCode.SERVER_ERROR, message="Some messages failed to add.", data=res)
+    message_dict = {
+        "user_id": req.get("user_id"),
+        "agent_id": req["agent_id"],
+        "session_id": req["session_id"],
+        "user_input": req["user_input"],
+        "agent_response": req["agent_response"],
+    }
+    res, msg = await memory_message_service.queue_save_to_memory_task(memory_ids, message_dict)
+    if res:
+        return get_json_result(message=msg)
+    return get_json_result(code=RetCode.SERVER_ERROR, message="Some messages failed to add. Detail:" + msg)
@manager.route("/<memory_id>:<message_id>", methods=["DELETE"]) # noqa: F821 @manager.route("/messages/<memory_id>:<message_id>", methods=["DELETE"]) # noqa: F821
@login_required @login_required
async def forget_message(memory_id: str, message_id: int): async def forget_message(memory_id: str, message_id: int):
@@ -80,7 +67,7 @@ async def forget_message(memory_id: str, message_id: int):
return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to forget message '{message_id}' in memory '{memory_id}'.") return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to forget message '{message_id}' in memory '{memory_id}'.")
@manager.route("/<memory_id>:<message_id>", methods=["PUT"]) # noqa: F821 @manager.route("/messages/<memory_id>:<message_id>", methods=["PUT"]) # noqa: F821
@login_required @login_required
@validate_request("status") @validate_request("status")
async def update_message(memory_id: str, message_id: int): async def update_message(memory_id: str, message_id: int):
@@ -100,16 +87,17 @@ async def update_message(memory_id: str, message_id: int):
return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to set status for message '{message_id}' in memory '{memory_id}'.") return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to set status for message '{message_id}' in memory '{memory_id}'.")
@manager.route("/search", methods=["GET"]) # noqa: F821 @manager.route("/messages/search", methods=["GET"]) # noqa: F821
@login_required @login_required
async def search_message(): async def search_message():
args = request.args args = request.args
print(args, flush=True)
empty_fields = [f for f in ["memory_id", "query"] if not args.get(f)] empty_fields = [f for f in ["memory_id", "query"] if not args.get(f)]
if empty_fields: if empty_fields:
return get_error_argument_result(f"{', '.join(empty_fields)} can't be empty.") return get_error_argument_result(f"{', '.join(empty_fields)} can't be empty.")
memory_ids = args.getlist("memory_id") memory_ids = args.getlist("memory_id")
if len(memory_ids) == 1 and ',' in memory_ids[0]:
memory_ids = memory_ids[0].split(',')
query = args.get("query") query = args.get("query")
similarity_threshold = float(args.get("similarity_threshold", 0.2)) similarity_threshold = float(args.get("similarity_threshold", 0.2))
keywords_similarity_weight = float(args.get("keywords_similarity_weight", 0.7)) keywords_similarity_weight = float(args.get("keywords_similarity_weight", 0.7))
@@ -132,11 +120,13 @@ async def search_message():
return get_json_result(message=True, data=res) return get_json_result(message=True, data=res)
@manager.route("", methods=["GET"]) # noqa: F821 @manager.route("/messages", methods=["GET"]) # noqa: F821
@login_required @login_required
async def get_messages(): async def get_messages():
args = request.args args = request.args
memory_ids = args.getlist("memory_id") memory_ids = args.getlist("memory_id")
if len(memory_ids) == 1 and ',' in memory_ids[0]:
memory_ids = memory_ids[0].split(',')
agent_id = args.get("agent_id", "") agent_id = args.get("agent_id", "")
session_id = args.get("session_id", "") session_id = args.get("session_id", "")
limit = int(args.get("limit", 10)) limit = int(args.get("limit", 10))
@@ -154,7 +144,7 @@ async def get_messages():
return get_json_result(message=True, data=res) return get_json_result(message=True, data=res)
@manager.route("/<memory_id>:<message_id>/content", methods=["GET"]) # noqa: F821 @manager.route("/messages/<memory_id>:<message_id>/content", methods=["GET"]) # noqa: F821
@login_required @login_required
async def get_message_content(memory_id:str, message_id: int): async def get_message_content(memory_id:str, message_id: int):
memory = MemoryService.get_by_memory_id(memory_id) memory = MemoryService.get_by_memory_id(memory_id)


@@ -19,6 +19,10 @@ import re
import time import time
import tiktoken import tiktoken
import os
import tempfile
import logging
from quart import Response, jsonify, request from quart import Response, jsonify, request
from agent.canvas import Canvas from agent.canvas import Canvas
@@ -35,7 +39,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from common.metadata_utils import apply_meta_data_filter, convert_conditions, meta_filter from common.metadata_utils import apply_meta_data_filter, convert_conditions, meta_filter
from api.db.services.search_service import SearchService from api.db.services.search_service import SearchService
from api.db.services.user_service import UserTenantService from api.db.services.user_service import TenantService,UserTenantService
from common.misc_utils import get_uuid from common.misc_utils import get_uuid
from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, \ from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, \
get_result, get_request_json, server_error_response, token_required, validate_request get_result, get_request_json, server_error_response, token_required, validate_request
@@ -60,7 +64,7 @@ async def create(tenant_id, chat_id):
"name": req.get("name", "New session"), "name": req.get("name", "New session"),
"message": [{"role": "assistant", "content": dia[0].prompt_config.get("prologue")}], "message": [{"role": "assistant", "content": dia[0].prompt_config.get("prologue")}],
"user_id": req.get("user_id", ""), "user_id": req.get("user_id", ""),
"reference": [{}], "reference": [],
} }
if not conv.get("name"): if not conv.get("name"):
return get_error_data_result(message="`name` can not be empty.") return get_error_data_result(message="`name` can not be empty.")
@@ -304,9 +308,12 @@ async def chat_completion_openai_like(tenant_id, chat_id):
# The choices field on the last chunk will always be an empty array []. # The choices field on the last chunk will always be an empty array [].
async def streamed_response_generator(chat_id, dia, msg): async def streamed_response_generator(chat_id, dia, msg):
token_used = 0 token_used = 0
answer_cache = ""
reasoning_cache = ""
last_ans = {} last_ans = {}
full_content = ""
full_reasoning = ""
final_answer = None
final_reference = None
in_think = False
response = { response = {
"id": f"chatcmpl-{chat_id}", "id": f"chatcmpl-{chat_id}",
"choices": [ "choices": [
@@ -336,47 +343,30 @@ async def chat_completion_openai_like(tenant_id, chat_id):
chat_kwargs["doc_ids"] = doc_ids_str chat_kwargs["doc_ids"] = doc_ids_str
async for ans in async_chat(dia, msg, True, **chat_kwargs): async for ans in async_chat(dia, msg, True, **chat_kwargs):
last_ans = ans last_ans = ans
-    answer = ans["answer"]
-    reasoning_match = re.search(r"<think>(.*?)</think>", answer, flags=re.DOTALL)
-    if reasoning_match:
-        reasoning_part = reasoning_match.group(1)
-        content_part = answer[reasoning_match.end() :]
-    else:
-        reasoning_part = ""
-        content_part = answer
-    reasoning_incremental = ""
-    if reasoning_part:
-        if reasoning_part.startswith(reasoning_cache):
-            reasoning_incremental = reasoning_part.replace(reasoning_cache, "", 1)
-        else:
-            reasoning_incremental = reasoning_part
-        reasoning_cache = reasoning_part
-    content_incremental = ""
-    if content_part:
-        if content_part.startswith(answer_cache):
-            content_incremental = content_part.replace(answer_cache, "", 1)
-        else:
-            content_incremental = content_part
-        answer_cache = content_part
-    token_used += len(reasoning_incremental) + len(content_incremental)
-    if not any([reasoning_incremental, content_incremental]):
-        continue
-    if reasoning_incremental:
-        response["choices"][0]["delta"]["reasoning_content"] = reasoning_incremental
-    else:
-        response["choices"][0]["delta"]["reasoning_content"] = None
-    if content_incremental:
-        response["choices"][0]["delta"]["content"] = content_incremental
-    else:
-        response["choices"][0]["delta"]["content"] = None
+    if ans.get("final"):
+        if ans.get("answer"):
+            full_content = ans["answer"]
+        final_answer = ans.get("answer") or full_content
+        final_reference = ans.get("reference", {})
+        continue
+    if ans.get("start_to_think"):
+        in_think = True
+        continue
+    if ans.get("end_to_think"):
+        in_think = False
+        continue
+    delta = ans.get("answer") or ""
+    if not delta:
+        continue
+    token_used += len(delta)
+    if in_think:
+        full_reasoning += delta
+        response["choices"][0]["delta"]["reasoning_content"] = delta
+        response["choices"][0]["delta"]["content"] = None
+    else:
+        full_content += delta
+        response["choices"][0]["delta"]["content"] = delta
+        response["choices"][0]["delta"]["reasoning_content"] = None
yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n" yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
except Exception as e: except Exception as e:
response["choices"][0]["delta"]["content"] = "**ERROR**: " + str(e) response["choices"][0]["delta"]["content"] = "**ERROR**: " + str(e)
@@ -388,8 +378,9 @@ async def chat_completion_openai_like(tenant_id, chat_id):
response["choices"][0]["finish_reason"] = "stop" response["choices"][0]["finish_reason"] = "stop"
response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, "total_tokens": len(prompt) + token_used} response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, "total_tokens": len(prompt) + token_used}
if need_reference: if need_reference:
response["choices"][0]["delta"]["reference"] = chunks_format(last_ans.get("reference", [])) reference_payload = final_reference if final_reference is not None else last_ans.get("reference", [])
response["choices"][0]["delta"]["final_content"] = last_ans.get("answer", "") response["choices"][0]["delta"]["reference"] = chunks_format(reference_payload)
response["choices"][0]["delta"]["final_content"] = final_answer if final_answer is not None else full_content
yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n" yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
yield "data:[DONE]\n\n" yield "data:[DONE]\n\n"
@@ -438,7 +429,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
], ],
} }
if need_reference: if need_reference:
response["choices"][0]["message"]["reference"] = chunks_format(answer.get("reference", [])) response["choices"][0]["message"]["reference"] = chunks_format(answer.get("reference", {}))
return jsonify(response) return jsonify(response)
@@ -1058,11 +1049,13 @@ async def retrieval_test_embedded():
use_kg = req.get("use_kg", False) use_kg = req.get("use_kg", False)
top = int(req.get("top_k", 1024)) top = int(req.get("top_k", 1024))
langs = req.get("cross_languages", []) langs = req.get("cross_languages", [])
rerank_id = req.get("rerank_id", "")
tenant_id = objs[0].tenant_id tenant_id = objs[0].tenant_id
if not tenant_id: if not tenant_id:
return get_error_data_result(message="permission denined.") return get_error_data_result(message="permission denined.")
async def _retrieval(): async def _retrieval():
nonlocal similarity_threshold, vector_similarity_weight, top, rerank_id
local_doc_ids = list(doc_ids) if doc_ids else [] local_doc_ids = list(doc_ids) if doc_ids else []
tenant_ids = [] tenant_ids = []
_question = question _question = question
@@ -1074,6 +1067,15 @@ async def retrieval_test_embedded():
meta_data_filter = search_config.get("meta_data_filter", {}) meta_data_filter = search_config.get("meta_data_filter", {})
if meta_data_filter.get("method") in ["auto", "semi_auto"]: if meta_data_filter.get("method") in ["auto", "semi_auto"]:
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", "")) chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
# Apply search_config settings if not explicitly provided in request
if not req.get("similarity_threshold"):
similarity_threshold = float(search_config.get("similarity_threshold", similarity_threshold))
if not req.get("vector_similarity_weight"):
vector_similarity_weight = float(search_config.get("vector_similarity_weight", vector_similarity_weight))
if not req.get("top_k"):
top = int(search_config.get("top_k", top))
if not req.get("rerank_id"):
rerank_id = search_config.get("rerank_id", "")
else: else:
meta_data_filter = req.get("meta_data_filter") or {} meta_data_filter = req.get("meta_data_filter") or {}
if meta_data_filter.get("method") in ["auto", "semi_auto"]: if meta_data_filter.get("method") in ["auto", "semi_auto"]:
@@ -1103,20 +1105,20 @@ async def retrieval_test_embedded():
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
rerank_mdl = None rerank_mdl = None
if req.get("rerank_id"): if rerank_id:
rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"]) rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=rerank_id)
if req.get("keyword", False): if req.get("keyword", False):
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT) chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
_question += await keyword_extraction(chat_mdl, _question) _question += await keyword_extraction(chat_mdl, _question)
labels = label_question(_question, [kb]) labels = label_question(_question, [kb])
ranks = settings.retriever.retrieval( ranks = await settings.retriever.retrieval(
_question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top, _question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
) )
if use_kg: if use_kg:
ck = settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl, ck = await settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl,
LLMBundle(kb.tenant_id, LLMType.CHAT)) LLMBundle(kb.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]: if ck["content_with_weight"]:
ranks["chunks"].insert(0, ck) ranks["chunks"].insert(0, ck)
@@ -1233,3 +1235,93 @@ async def mindmap():
if "error" in mind_map: if "error" in mind_map:
return server_error_response(Exception(mind_map["error"])) return server_error_response(Exception(mind_map["error"]))
return get_json_result(data=mind_map) return get_json_result(data=mind_map)
@manager.route("/sequence2txt", methods=["POST"]) # noqa: F821
@token_required
async def sequence2txt(tenant_id):
req = await request.form
stream_mode = req.get("stream", "false").lower() == "true"
files = await request.files
if "file" not in files:
return get_error_data_result(message="Missing 'file' in multipart form-data")
uploaded = files["file"]
ALLOWED_EXTS = {
".wav", ".mp3", ".m4a", ".aac",
".flac", ".ogg", ".webm",
".opus", ".wma"
}
filename = uploaded.filename or ""
suffix = os.path.splitext(filename)[-1].lower()
if suffix not in ALLOWED_EXTS:
return get_error_data_result(message=
f"Unsupported audio format: {suffix}. "
f"Allowed: {', '.join(sorted(ALLOWED_EXTS))}"
)
fd, temp_audio_path = tempfile.mkstemp(suffix=suffix)
os.close(fd)
await uploaded.save(temp_audio_path)
tenants = TenantService.get_info_by(tenant_id)
if not tenants:
return get_error_data_result(message="Tenant not found!")
asr_id = tenants[0]["asr_id"]
if not asr_id:
return get_error_data_result(message="No default ASR model is set")
asr_mdl=LLMBundle(tenants[0]["tenant_id"], LLMType.SPEECH2TEXT, asr_id)
if not stream_mode:
text = asr_mdl.transcription(temp_audio_path)
try:
os.remove(temp_audio_path)
except Exception as e:
logging.error(f"Failed to remove temp audio file: {str(e)}")
return get_json_result(data={"text": text})
async def event_stream():
try:
for evt in asr_mdl.stream_transcription(temp_audio_path):
yield f"data: {json.dumps(evt, ensure_ascii=False)}\n\n"
except Exception as e:
err = {"event": "error", "text": str(e)}
yield f"data: {json.dumps(err, ensure_ascii=False)}\n\n"
finally:
try:
os.remove(temp_audio_path)
except Exception as e:
logging.error(f"Failed to remove temp audio file: {str(e)}")
return Response(event_stream(), content_type="text/event-stream")
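A hypothetical client call for the new ASR endpoint. The host and URL prefix depend on how this blueprint is mounted, so treat the path as an assumption; the multipart field names (file, stream) come from the handler above:

```python
import requests

# Assumed mount point; adjust to wherever this blueprint is registered.
url = "http://localhost:9380/api/v1/sequence2txt"
with open("meeting.wav", "rb") as f:
    resp = requests.post(
        url,
        headers={"Authorization": "Bearer <API_KEY>"},
        files={"file": ("meeting.wav", f, "audio/wav")},
        data={"stream": "false"},  # "true" switches to the SSE event stream
        timeout=300,
    )
print(resp.json()["data"]["text"])  # shape follows get_json_result(data={"text": ...})
```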
@manager.route("/tts", methods=["POST"]) # noqa: F821
@token_required
async def tts(tenant_id):
req = await get_request_json()
text = req["text"]
tenants = TenantService.get_info_by(tenant_id)
if not tenants:
return get_error_data_result(message="Tenant not found!")
tts_id = tenants[0]["tts_id"]
if not tts_id:
return get_error_data_result(message="No default TTS model is set")
tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id)
def stream_audio():
try:
for txt in re.split(r"[,。/《》?;:!\n\r:;]+", text):
for chunk in tts_mdl.tts(txt):
yield chunk
except Exception as e:
yield ("data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e)}}, ensure_ascii=False)).encode("utf-8")
resp = Response(stream_audio(), mimetype="audio/mpeg")
resp.headers.add_header("Cache-Control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
return resp
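And a matching hypothetical client for /tts, writing the streamed audio/mpeg bytes to disk (again, the mount path is an assumption):

```python
import requests

with requests.post(
    "http://localhost:9380/api/v1/tts",  # assumed mount point
    headers={"Authorization": "Bearer <API_KEY>"},
    json={"text": "Hello from RAGFlow."},
    stream=True,
    timeout=300,
) as resp:
    with open("out.mp3", "wb") as out:
        for chunk in resp.iter_content(chunk_size=8192):
            out.write(chunk)  # raw audio bytes as the server yields them
```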


@@ -177,8 +177,8 @@ def healthz():
return jsonify(result), (200 if all_ok else 500) return jsonify(result), (200 if all_ok else 500)
@manager.route("/ping", methods=["GET"]) # noqa: F821 @manager.route("/ping", methods=["GET"]) # noqa: F821
def ping(): async def ping():
return "pong", 200 return "pong", 200
@@ -213,7 +213,7 @@ def new_token():
if not tenants: if not tenants:
return get_data_error_result(message="Tenant not found!") return get_data_error_result(message="Tenant not found!")
tenant_id = [tenant for tenant in tenants if tenant.role == 'owner'][0].tenant_id tenant_id = [tenant for tenant in tenants if tenant.role == "owner"][0].tenant_id
obj = { obj = {
"tenant_id": tenant_id, "tenant_id": tenant_id,
"token": generate_confirmation_token(), "token": generate_confirmation_token(),
@@ -268,13 +268,12 @@ def token_list():
if not tenants: if not tenants:
return get_data_error_result(message="Tenant not found!") return get_data_error_result(message="Tenant not found!")
tenant_id = [tenant for tenant in tenants if tenant.role == 'owner'][0].tenant_id tenant_id = [tenant for tenant in tenants if tenant.role == "owner"][0].tenant_id
objs = APITokenService.query(tenant_id=tenant_id) objs = APITokenService.query(tenant_id=tenant_id)
objs = [o.to_dict() for o in objs] objs = [o.to_dict() for o in objs]
for o in objs: for o in objs:
if not o["beta"]: if not o["beta"]:
o["beta"] = generate_confirmation_token().replace( o["beta"] = generate_confirmation_token().replace("ragflow-", "")[:32]
"ragflow-", "")[:32]
APITokenService.filter_update([APIToken.tenant_id == tenant_id, APIToken.token == o["token"]], o) APITokenService.filter_update([APIToken.tenant_id == tenant_id, APIToken.token == o["token"]], o)
return get_json_result(data=objs) return get_json_result(data=objs)
except Exception as e: except Exception as e:
@@ -307,13 +306,19 @@ def rm(token):
type: boolean type: boolean
description: Deletion status. description: Deletion status.
""" """
-    APITokenService.filter_delete(
-        [APIToken.tenant_id == current_user.id, APIToken.token == token]
-    )
-    return get_json_result(data=True)
+    try:
+        tenants = UserTenantService.query(user_id=current_user.id)
+        if not tenants:
+            return get_data_error_result(message="Tenant not found!")
+        tenant_id = tenants[0].tenant_id
+        APITokenService.filter_delete([APIToken.tenant_id == tenant_id, APIToken.token == token])
+        return get_json_result(data=True)
+    except Exception as e:
+        return server_error_response(e)
@manager.route('/config', methods=['GET']) # noqa: F821 @manager.route("/config", methods=["GET"]) # noqa: F821
def get_config(): def get_config():
""" """
Get system configuration. Get system configuration.
@@ -330,6 +335,4 @@ def get_config():
type: integer 0 means disabled, 1 means enabled type: integer 0 means disabled, 1 means enabled
description: Whether user registration is enabled description: Whether user registration is enabled
""" """
-    return get_json_result(data={
-        "registerEnabled": settings.REGISTER_ENABLED
-    })
+    return get_json_result(data={"registerEnabled": settings.REGISTER_ENABLED})


@@ -281,7 +281,11 @@ class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
except Exception as e: except Exception as e:
logging.error(f"Failed to reconnect: {e}") logging.error(f"Failed to reconnect: {e}")
time.sleep(0.1) time.sleep(0.1)
-            self.connect()
+            try:
+                self.connect()
+            except Exception as e2:
+                logging.error(f"Failed to reconnect on second attempt: {e2}")
+                raise
def begin(self): def begin(self):
for attempt in range(self.max_retries + 1): for attempt in range(self.max_retries + 1):
@@ -352,7 +356,11 @@ class RetryingPooledPostgresqlDatabase(PooledPostgresqlDatabase):
except Exception as e: except Exception as e:
logging.error(f"Failed to reconnect to PostgreSQL: {e}") logging.error(f"Failed to reconnect to PostgreSQL: {e}")
time.sleep(0.1) time.sleep(0.1)
-            self.connect()
+            try:
+                self.connect()
+            except Exception as e2:
+                logging.error(f"Failed to reconnect to PostgreSQL on second attempt: {e2}")
+                raise
def begin(self): def begin(self):
for attempt in range(self.max_retries + 1): for attempt in range(self.max_retries + 1):
@@ -1197,224 +1205,93 @@ class Memory(DataBaseModel):
class Meta: class Meta:
db_table = "memory" db_table = "memory"
class SystemSettings(DataBaseModel):
name = CharField(max_length=128, primary_key=True)
source = CharField(max_length=32, null=False, index=False)
data_type = CharField(max_length=32, null=False, index=False)
value = CharField(max_length=1024, null=False, index=False)
class Meta:
db_table = "system_settings"
def alter_db_add_column(migrator, table_name, column_name, column_type):
try:
migrate(migrator.add_column(table_name, column_name, column_type))
except OperationalError as ex:
error_codes = [1060]
error_messages = ['Duplicate column name']
should_skip_error = (
(hasattr(ex, 'args') and ex.args and ex.args[0] in error_codes) or
(str(ex) in error_messages)
)
if not should_skip_error:
logging.critical(f"Failed to add {settings.DATABASE_TYPE.upper()}.{table_name} column {column_name}, operation error: {ex}")
except Exception as ex:
logging.critical(f"Failed to add {settings.DATABASE_TYPE.upper()}.{table_name} column {column_name}, error: {ex}")
pass
def alter_db_column_type(migrator, table_name, column_name, new_column_type):
try:
migrate(migrator.alter_column_type(table_name, column_name, new_column_type))
except Exception as ex:
logging.critical(f"Failed to alter {settings.DATABASE_TYPE.upper()}.{table_name} column {column_name} type, error: {ex}")
pass
def alter_db_rename_column(migrator, table_name, old_column_name, new_column_name):
try:
migrate(migrator.rename_column(table_name, old_column_name, new_column_name))
except Exception:
# a failed rename would otherwise surface a weird error.
# logging.critical(f"Failed to rename {settings.DATABASE_TYPE.upper()}.{table_name} column {old_column_name} to {new_column_name}, error: {ex}")
pass
def migrate_db(): def migrate_db():
logging.disable(logging.ERROR) logging.disable(logging.ERROR)
migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB) migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
-    try:
-        migrate(migrator.add_column("file", "source_type", CharField(max_length=128, null=False, default="", help_text="where dose this document come from", index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("tenant", "rerank_id", CharField(max_length=128, null=False, default="BAAI/bge-reranker-v2-m3", help_text="default rerank model ID")))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("dialog", "rerank_id", CharField(max_length=128, null=False, default="", help_text="default rerank model ID")))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("dialog", "top_k", IntegerField(default=1024)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.alter_column_type("tenant_llm", "api_key", CharField(max_length=2048, null=True, help_text="API KEY", index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("api_token", "source", CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("tenant", "tts_id", CharField(max_length=256, null=True, help_text="default tts model ID", index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("api_4_conversation", "source", CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("task", "retry_count", IntegerField(default=0)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.alter_column_type("api_token", "dialog_id", CharField(max_length=32, null=True, index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("tenant_llm", "max_tokens", IntegerField(default=8192, index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("api_4_conversation", "dsl", JSONField(null=True, default={})))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("knowledgebase", "pagerank", IntegerField(default=0, index=False)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("api_token", "beta", CharField(max_length=255, null=True, index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("task", "digest", TextField(null=True, help_text="task digest", default="")))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("task", "chunk_ids", LongTextField(null=True, help_text="chunk ids", default="")))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("conversation", "user_id", CharField(max_length=255, null=True, help_text="user_id", index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("document", "meta_fields", JSONField(null=True, default={})))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("task", "task_type", CharField(max_length=32, null=False, default="")))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("task", "priority", IntegerField(default=0)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("user_canvas", "permission", CharField(max_length=16, null=False, help_text="me|team", default="me", index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("llm", "is_tools", BooleanField(null=False, help_text="support tools", default=False)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("mcp_server", "variables", JSONField(null=True, help_text="MCP Server variables", default=dict)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.rename_column("task", "process_duation", "process_duration"))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.rename_column("document", "process_duation", "process_duration"))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("document", "suffix", CharField(max_length=32, null=False, default="", help_text="The real file extension suffix", index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("api_4_conversation", "errors", TextField(null=True, help_text="errors")))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("dialog", "meta_data_filter", JSONField(null=True, default={})))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.alter_column_type("canvas_template", "title", JSONField(null=True, default=dict, help_text="Canvas title")))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.alter_column_type("canvas_template", "description", JSONField(null=True, default=dict, help_text="Canvas description")))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("user_canvas", "canvas_category", CharField(max_length=32, null=False, default="agent_canvas", help_text="agent_canvas|dataflow_canvas", index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("canvas_template", "canvas_category", CharField(max_length=32, null=False, default="agent_canvas", help_text="agent_canvas|dataflow_canvas", index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("knowledgebase", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("document", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("knowledgebase", "graphrag_task_id", CharField(max_length=32, null=True, help_text="Gragh RAG task ID", index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("knowledgebase", "raptor_task_id", CharField(max_length=32, null=True, help_text="RAPTOR task ID", index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("knowledgebase", "graphrag_task_finish_at", DateTimeField(null=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("knowledgebase", "raptor_task_finish_at", CharField(null=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("knowledgebase", "mindmap_task_id", CharField(max_length=32, null=True, help_text="Mindmap task ID", index=True)))
+    alter_db_add_column(migrator, "file", "source_type", CharField(max_length=128, null=False, default="", help_text="where dose this document come from", index=True))
+    alter_db_add_column(migrator, "tenant", "rerank_id", CharField(max_length=128, null=False, default="BAAI/bge-reranker-v2-m3", help_text="default rerank model ID"))
+    alter_db_add_column(migrator, "dialog", "rerank_id", CharField(max_length=128, null=False, default="", help_text="default rerank model ID"))
+    alter_db_column_type(migrator, "dialog", "top_k", IntegerField(default=1024))
+    alter_db_add_column(migrator, "tenant_llm", "api_key", CharField(max_length=2048, null=True, help_text="API KEY", index=True))
+    alter_db_add_column(migrator, "api_token", "source", CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
+    alter_db_add_column(migrator, "tenant", "tts_id", CharField(max_length=256, null=True, help_text="default tts model ID", index=True))
+    alter_db_add_column(migrator, "api_4_conversation", "source", CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
+    alter_db_add_column(migrator, "task", "retry_count", IntegerField(default=0))
+    alter_db_column_type(migrator, "api_token", "dialog_id", CharField(max_length=32, null=True, index=True))
+    alter_db_add_column(migrator, "tenant_llm", "max_tokens", IntegerField(default=8192, index=True))
+    alter_db_add_column(migrator, "api_4_conversation", "dsl", JSONField(null=True, default={}))
+    alter_db_add_column(migrator, "knowledgebase", "pagerank", IntegerField(default=0, index=False))
+    alter_db_add_column(migrator, "api_token", "beta", CharField(max_length=255, null=True, index=True))
+    alter_db_add_column(migrator, "task", "digest", TextField(null=True, help_text="task digest", default=""))
+    alter_db_add_column(migrator, "task", "chunk_ids", LongTextField(null=True, help_text="chunk ids", default=""))
+    alter_db_add_column(migrator, "conversation", "user_id", CharField(max_length=255, null=True, help_text="user_id", index=True))
+    alter_db_add_column(migrator, "document", "meta_fields", JSONField(null=True, default={}))
+    alter_db_add_column(migrator, "task", "task_type", CharField(max_length=32, null=False, default=""))
+    alter_db_add_column(migrator, "task", "priority", IntegerField(default=0))
+    alter_db_add_column(migrator, "user_canvas", "permission", CharField(max_length=16, null=False, help_text="me|team", default="me", index=True))
+    alter_db_add_column(migrator, "llm", "is_tools", BooleanField(null=False, help_text="support tools", default=False))
+    alter_db_add_column(migrator, "mcp_server", "variables", JSONField(null=True, help_text="MCP Server variables", default=dict))
+    alter_db_rename_column(migrator, "task", "process_duation", "process_duration")
+    alter_db_rename_column(migrator, "document", "process_duation", "process_duration")
+    alter_db_add_column(migrator, "document", "suffix", CharField(max_length=32, null=False, default="", help_text="The real file extension suffix", index=True))
+    alter_db_add_column(migrator, "api_4_conversation", "errors", TextField(null=True, help_text="errors"))
+    alter_db_add_column(migrator, "dialog", "meta_data_filter", JSONField(null=True, default={}))
+    alter_db_column_type(migrator, "canvas_template", "title", JSONField(null=True, default=dict, help_text="Canvas title"))
+    alter_db_column_type(migrator, "canvas_template", "description", JSONField(null=True, default=dict, help_text="Canvas description"))
+    alter_db_add_column(migrator, "user_canvas", "canvas_category", CharField(max_length=32, null=False, default="agent_canvas", help_text="agent_canvas|dataflow_canvas", index=True))
+    alter_db_add_column(migrator, "canvas_template", "canvas_category", CharField(max_length=32, null=False, default="agent_canvas", help_text="agent_canvas|dataflow_canvas", index=True))
+    alter_db_add_column(migrator, "knowledgebase", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True))
+    alter_db_add_column(migrator, "document", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True))
+    alter_db_add_column(migrator, "knowledgebase", "graphrag_task_id", CharField(max_length=32, null=True, help_text="Gragh RAG task ID", index=True))
+    alter_db_add_column(migrator, "knowledgebase", "raptor_task_id", CharField(max_length=32, null=True, help_text="RAPTOR task ID", index=True))
+    alter_db_add_column(migrator, "knowledgebase", "graphrag_task_finish_at", DateTimeField(null=True))
+    alter_db_add_column(migrator, "knowledgebase", "raptor_task_finish_at", CharField(null=True))
+    alter_db_add_column(migrator, "knowledgebase", "mindmap_task_id", CharField(max_length=32, null=True, help_text="Mindmap task ID", index=True))
+    alter_db_add_column(migrator, "knowledgebase", "mindmap_task_finish_at", CharField(null=True))
+    alter_db_column_type(migrator, "tenant_llm", "api_key", TextField(null=True, help_text="API KEY"))
+    alter_db_add_column(migrator, "tenant_llm", "status", CharField(max_length=1, null=False, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True))
+    alter_db_add_column(migrator, "connector2kb", "auto_parse", CharField(max_length=1, null=False, default="1", index=False))
+    alter_db_add_column(migrator, "llm_factories", "rank", IntegerField(default=0, index=False))
except Exception:
pass
try:
migrate(migrator.add_column("knowledgebase", "mindmap_task_finish_at", CharField(null=True)))
except Exception:
pass
try:
migrate(migrator.alter_column_type("tenant_llm", "api_key", TextField(null=True, help_text="API KEY")))
except Exception:
pass
try:
migrate(migrator.add_column("tenant_llm", "status", CharField(max_length=1, null=False, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("connector2kb", "auto_parse", CharField(max_length=1, null=False, default="1", index=False)))
except Exception:
pass
try:
migrate(migrator.add_column("llm_factories", "rank", IntegerField(default=0, index=False)))
except Exception:
pass
# RAG Evaluation tables
try:
migrate(migrator.add_column("evaluation_datasets", "id", CharField(max_length=32, primary_key=True)))
except Exception:
pass
try:
migrate(migrator.add_column("evaluation_datasets", "tenant_id", CharField(max_length=32, null=False, index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("evaluation_datasets", "name", CharField(max_length=255, null=False, index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("evaluation_datasets", "description", TextField(null=True)))
except Exception:
pass
try:
migrate(migrator.add_column("evaluation_datasets", "kb_ids", JSONField(null=False)))
except Exception:
pass
try:
migrate(migrator.add_column("evaluation_datasets", "created_by", CharField(max_length=32, null=False, index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("evaluation_datasets", "create_time", BigIntegerField(null=False, index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("evaluation_datasets", "update_time", BigIntegerField(null=False)))
except Exception:
pass
try:
migrate(migrator.add_column("evaluation_datasets", "status", IntegerField(null=False, default=1)))
except Exception:
pass
logging.disable(logging.NOTSET) logging.disable(logging.NOTSET)
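The refactor above replaces every inline try/except-pass migration block with one-line helper calls. The helpers themselves are not shown in this diff; a minimal sketch of what `alter_db_add_column` and `alter_db_column_type` presumably do, assuming the same `playhouse.migrate` API the removed code used (the bodies here are an assumption, not the verbatim implementation):

```python
from playhouse.migrate import migrate

def alter_db_add_column(migrator, table, column, field):
    # Hypothetical sketch: wraps the peewee migration in try/except,
    # mirroring the removed boilerplate. Failures (e.g. column already
    # exists) are swallowed, exactly as the old inline blocks did.
    try:
        migrate(migrator.add_column(table, column, field))
    except Exception:
        pass

def alter_db_column_type(migrator, table, column, field):
    try:
        migrate(migrator.alter_column_type(table, column, field))
    except Exception:
        pass
```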

View File

@@ -30,6 +30,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
 from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm
 from api.db.services.user_service import TenantService, UserTenantService
+from api.db.services.system_settings_service import SystemSettingsService
 from api.db.joint_services.memory_message_service import init_message_id_sequence, init_memory_size_cache
 from common.constants import LLMType
 from common.file_utils import get_project_base_directory
@@ -158,13 +159,15 @@ def add_graph_templates():
                 CanvasTemplateService.save(**cnvs)
             except Exception:
                 CanvasTemplateService.update_by_id(cnvs["id"], cnvs)
-    except Exception:
-        logging.exception("Add agent templates error: ")
+    except Exception as e:
+        logging.exception(f"Add agent templates error: {e}")

 def init_web_data():
     start_time = time.time()
+    init_table()
     init_llm_factory()
     # if not UserService.get_all().count():
     #     init_superuser()
@@ -174,6 +177,31 @@ def init_web_data():
     init_memory_size_cache()
     logging.info("init web data success:{}".format(time.time() - start_time))

+def init_table():
+    # init system_settings
+    with open(os.path.join(get_project_base_directory(), "conf", "system_settings.json"), "r") as f:
+        records_from_file = json.load(f)["system_settings"]
+    record_index = {}
+    records_from_db = SystemSettingsService.get_all()
+    for index, record in enumerate(records_from_db):
+        record_index[record.name] = index
+    to_save = []
+    for record in records_from_file:
+        setting_name = record["name"]
+        if setting_name not in record_index:
+            to_save.append(record)
+    len_to_save = len(to_save)
+    if len_to_save > 0:
+        # not initialized
+        try:
+            SystemSettingsService.insert_many(to_save, len_to_save)
+        except Exception as e:
+            logging.exception("System settings init error: {}".format(e))
+            raise e

 if __name__ == '__main__':
     init_web_db()
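The new `init_table()` seeds the `system_settings` table from `conf/system_settings.json`, inserting only records whose `name` is not already in the database. The code implies the file carries a top-level `"system_settings"` list whose entries have at least a `"name"` key; the fields beyond that are an assumption here:

```python
import json

# Hypothetical shape of conf/system_settings.json implied by init_table();
# only the top-level key and the "name" field are grounded in the diff.
sample = {
    "system_settings": [
        {"name": "some_setting", "value": "..."},  # extra fields assumed
    ]
}
print(json.dumps(sample, indent=2))
```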

View File

@@ -16,9 +16,14 @@
 import logging
 from typing import List

+from common import settings
 from common.time_utils import current_timestamp, timestamp_to_date, format_iso_8601_to_ymd_hms
 from common.constants import MemoryType, LLMType
 from common.doc_store.doc_store_base import FusionExpr
+from common.misc_utils import get_uuid
+from api.db.db_utils import bulk_insert_into_db
+from api.db.db_models import Task
+from api.db.services.task_service import TaskService
 from api.db.services.memory_service import MemoryService
 from api.db.services.tenant_llm_service import TenantLLMService
 from api.db.services.llm_service import LLMBundle
@@ -82,36 +87,60 @@ async def save_to_memory(memory_id: str, message_dict: dict):
         "forget_at": None,
         "status": True
     } for content in extracted_content]]
-    embedding_model = LLMBundle(tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
-    vector_list, _ = embedding_model.encode([msg["content"] for msg in message_list])
-    for idx, msg in enumerate(message_list):
-        msg["content_embed"] = vector_list[idx]
-    vector_dimension = len(vector_list[0])
-    if not MessageService.has_index(tenant_id, memory_id):
-        created = MessageService.create_index(tenant_id, memory_id, vector_size=vector_dimension)
-        if not created:
-            return False, "Failed to create message index."
-    new_msg_size = sum([MessageService.calculate_message_size(m) for m in message_list])
-    current_memory_size = get_memory_size_cache(memory_id, tenant_id)
-    if new_msg_size + current_memory_size > memory.memory_size:
-        size_to_delete = current_memory_size + new_msg_size - memory.memory_size
-        if memory.forgetting_policy == "FIFO":
-            message_ids_to_delete, delete_size = MessageService.pick_messages_to_delete_by_fifo(memory_id, tenant_id, size_to_delete)
-            MessageService.delete_message({"message_id": message_ids_to_delete}, tenant_id, memory_id)
-            decrease_memory_size_cache(memory_id, delete_size)
-        else:
-            return False, "Failed to insert message into memory. Memory size reached limit and cannot decide which to delete."
-    fail_cases = MessageService.insert_message(message_list, tenant_id, memory_id)
-    if fail_cases:
-        return False, "Failed to insert message into memory. Details: " + "; ".join(fail_cases)
-    increase_memory_size_cache(memory_id, new_msg_size)
-    return True, "Message saved successfully."
+    return await embed_and_save(memory, message_list)

+async def save_extracted_to_memory_only(memory_id: str, message_dict, source_message_id: int, task_id: str=None):
+    memory = MemoryService.get_by_memory_id(memory_id)
+    if not memory:
+        msg = f"Memory '{memory_id}' not found."
+        if task_id:
+            TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp()) + " " + msg})
+        return False, msg
+    if memory.memory_type == MemoryType.RAW.value:
+        msg = f"Memory '{memory_id}' don't need to extract."
+        if task_id:
+            TaskService.update_progress(task_id, {"progress": 1.0, "progress_msg": timestamp_to_date(current_timestamp()) + " " + msg})
+        return True, msg
+    tenant_id = memory.tenant_id
+    extracted_content = await extract_by_llm(
+        tenant_id,
+        memory.llm_id,
+        {"temperature": memory.temperature},
+        get_memory_type_human(memory.memory_type),
+        message_dict.get("user_input", ""),
+        message_dict.get("agent_response", ""),
+        task_id=task_id
+    )
+    message_list = [{
+        "message_id": REDIS_CONN.generate_auto_increment_id(namespace="memory"),
+        "message_type": content["message_type"],
+        "source_id": source_message_id,
+        "memory_id": memory_id,
+        "user_id": "",
+        "agent_id": message_dict["agent_id"],
+        "session_id": message_dict["session_id"],
+        "content": content["content"],
+        "valid_at": content["valid_at"],
+        "invalid_at": content["invalid_at"] if content["invalid_at"] else None,
+        "forget_at": None,
+        "status": True
+    } for content in extracted_content]
+    if not message_list:
+        msg = "No memory extracted from raw message."
+        if task_id:
+            TaskService.update_progress(task_id, {"progress": 1.0, "progress_msg": timestamp_to_date(current_timestamp()) + " " + msg})
+        return True, msg
+    if task_id:
+        TaskService.update_progress(task_id, {"progress": 0.5, "progress_msg": timestamp_to_date(current_timestamp()) + " " + f"Extracted {len(message_list)} messages from raw dialogue."})
+    return await embed_and_save(memory, message_list, task_id)
 async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory_type: List[str], user_input: str,
-                         agent_response: str, system_prompt: str = "", user_prompt: str="") -> List[dict]:
+                         agent_response: str, system_prompt: str = "", user_prompt: str="", task_id: str=None) -> List[dict]:
     llm_type = TenantLLMService.llm_id2llm_type(llm_id)
     if not llm_type:
         raise RuntimeError(f"Unknown type of LLM '{llm_id}'")
@@ -126,8 +155,12 @@ async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory
         else:
             user_prompts.append({"role": "user", "content": PromptAssembler.assemble_user_prompt(conversation_content, conversation_time, conversation_time)})
     llm = LLMBundle(tenant_id, llm_type, llm_id)
+    if task_id:
+        TaskService.update_progress(task_id, {"progress": 0.15, "progress_msg": timestamp_to_date(current_timestamp()) + " " + "Prepared prompts and LLM."})
     res = await llm.async_chat(system_prompt, user_prompts, extract_conf)
     res_json = get_json_result_from_llm_response(res)
+    if task_id:
+        TaskService.update_progress(task_id, {"progress": 0.35, "progress_msg": timestamp_to_date(current_timestamp()) + " " + "Get extracted result from LLM."})
     return [{
         "content": extracted_content["content"],
         "valid_at": format_iso_8601_to_ymd_hms(extracted_content["valid_at"]),
@@ -136,6 +169,51 @@ async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory
     } for message_type, extracted_content_list in res_json.items() for extracted_content in extracted_content_list]
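Taken together, the `TaskService.update_progress` calls scattered through these hunks trace the whole memory-extraction pipeline. Collected in one place for reference (milestones and messages come from the diff; the grouping is editorial):

```python
# Progress checkpoints reported via TaskService.update_progress in this PR
MEMORY_TASK_PROGRESS = {
    0.15: "Prepared prompts and LLM.",               # extract_by_llm
    0.35: "Get extracted result from LLM.",          # extract_by_llm
    0.50: "Extracted N messages from raw dialogue.", # save_extracted_to_memory_only
    0.65: "Prepared embedding model.",               # embed_and_save
    0.85: "Embedded extracted content.",             # embed_and_save
    0.95: "Saved messages to storage.",              # embed_and_save
    1.00: "Done, or nothing to extract.",
    -1:   "Failure: index creation, eviction, or insert failed.",
}
```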
+async def embed_and_save(memory, message_list: list[dict], task_id: str=None):
+    embedding_model = LLMBundle(memory.tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
+    if task_id:
+        TaskService.update_progress(task_id, {"progress": 0.65, "progress_msg": timestamp_to_date(current_timestamp()) + " " + "Prepared embedding model."})
+    vector_list, _ = embedding_model.encode([msg["content"] for msg in message_list])
+    for idx, msg in enumerate(message_list):
+        msg["content_embed"] = vector_list[idx]
+    if task_id:
+        TaskService.update_progress(task_id, {"progress": 0.85, "progress_msg": timestamp_to_date(current_timestamp()) + " " + "Embedded extracted content."})
+    vector_dimension = len(vector_list[0])
+    if not MessageService.has_index(memory.tenant_id, memory.id):
+        created = MessageService.create_index(memory.tenant_id, memory.id, vector_size=vector_dimension)
+        if not created:
+            error_msg = "Failed to create message index."
+            if task_id:
+                TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp()) + " " + error_msg})
+            return False, error_msg
+    new_msg_size = sum([MessageService.calculate_message_size(m) for m in message_list])
+    current_memory_size = get_memory_size_cache(memory.tenant_id, memory.id)
+    if new_msg_size + current_memory_size > memory.memory_size:
+        size_to_delete = current_memory_size + new_msg_size - memory.memory_size
+        if memory.forgetting_policy == "FIFO":
+            message_ids_to_delete, delete_size = MessageService.pick_messages_to_delete_by_fifo(memory.id, memory.tenant_id,
+                                                                                                size_to_delete)
+            MessageService.delete_message({"message_id": message_ids_to_delete}, memory.tenant_id, memory.id)
+            decrease_memory_size_cache(memory.id, delete_size)
+        else:
+            error_msg = "Failed to insert message into memory. Memory size reached limit and cannot decide which to delete."
+            if task_id:
+                TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp()) + " " + error_msg})
+            return False, error_msg
+    fail_cases = MessageService.insert_message(message_list, memory.tenant_id, memory.id)
+    if fail_cases:
+        error_msg = "Failed to insert message into memory. Details: " + "; ".join(fail_cases)
+        if task_id:
+            TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp()) + " " + error_msg})
+        return False, error_msg
+    if task_id:
+        TaskService.update_progress(task_id, {"progress": 0.95, "progress_msg": timestamp_to_date(current_timestamp()) + " " + "Saved messages to storage."})
+    increase_memory_size_cache(memory.id, new_msg_size)
+    return True, "Message saved successfully."
 def query_message(filter_dict: dict, params: dict):
     """
     :param filter_dict: {
@@ -231,3 +309,112 @@ def init_memory_size_cache():
 def judge_system_prompt_is_default(system_prompt: str, memory_type: int|list[str]):
     memory_type_list = memory_type if isinstance(memory_type, list) else get_memory_type_human(memory_type)
     return system_prompt == PromptAssembler.assemble_system_prompt({"memory_type": memory_type_list})
+async def queue_save_to_memory_task(memory_ids: list[str], message_dict: dict):
+    """
+    :param memory_ids:
+    :param message_dict: {
+        "user_id": str,
+        "agent_id": str,
+        "session_id": str,
+        "user_input": str,
+        "agent_response": str
+    }
+    """
+    def new_task(_memory_id: str, _source_id: int):
+        return {
+            "id": get_uuid(),
+            "doc_id": _memory_id,
+            "task_type": "memory",
+            "progress": 0.0,
+            "digest": str(_source_id)
+        }
+    not_found_memory = []
+    failed_memory = []
+    for memory_id in memory_ids:
+        memory = MemoryService.get_by_memory_id(memory_id)
+        if not memory:
+            not_found_memory.append(memory_id)
+            continue
+        raw_message_id = REDIS_CONN.generate_auto_increment_id(namespace="memory")
+        raw_message = {
+            "message_id": raw_message_id,
+            "message_type": MemoryType.RAW.name.lower(),
+            "source_id": 0,
+            "memory_id": memory_id,
+            "user_id": "",
+            "agent_id": message_dict["agent_id"],
+            "session_id": message_dict["session_id"],
+            "content": f"User Input: {message_dict.get('user_input')}\nAgent Response: {message_dict.get('agent_response')}",
+            "valid_at": timestamp_to_date(current_timestamp()),
+            "invalid_at": None,
+            "forget_at": None,
+            "status": True
+        }
+        res, msg = await embed_and_save(memory, [raw_message])
+        if not res:
+            failed_memory.append({"memory_id": memory_id, "fail_msg": msg})
+            continue
+        task = new_task(memory_id, raw_message_id)
+        bulk_insert_into_db(Task, [task], replace_on_conflict=True)
+        task_message = {
+            "id": task["id"],
+            "task_id": task["id"],
+            "task_type": task["task_type"],
+            "memory_id": memory_id,
+            "source_id": raw_message_id,
+            "message_dict": message_dict
+        }
+        if not REDIS_CONN.queue_product(settings.get_svr_queue_name(priority=0), message=task_message):
+            failed_memory.append({"memory_id": memory_id, "fail_msg": "Can't access Redis."})
+    error_msg = ""
+    if not_found_memory:
+        error_msg = f"Memory {not_found_memory} not found."
+    if failed_memory:
+        error_msg += "".join([f"Memory {fm['memory_id']} failed. Detail: {fm['fail_msg']}" for fm in failed_memory])
+    if error_msg:
+        return False, error_msg
+    return True, "All add to task."
+async def handle_save_to_memory_task(task_param: dict):
+    """
+    :param task_param: {
+        "id": task_id
+        "memory_id": id
+        "source_id": id
+        "message_dict": {
+            "user_id": str,
+            "agent_id": str,
+            "session_id": str,
+            "user_input": str,
+            "agent_response": str
+        }
+    }
+    """
+    _, task = TaskService.get_by_id(task_param["id"])
+    if not task:
+        return False, f"Task {task_param['id']} is not found."
+    if task.progress == -1:
+        return False, f"Task {task_param['id']} is already failed."
+    now_time = current_timestamp()
+    TaskService.update_by_id(task_param["id"], {"begin_at": timestamp_to_date(now_time)})
+    memory_id = task_param["memory_id"]
+    source_id = task_param["source_id"]
+    message_dict = task_param["message_dict"]
+    success, msg = await save_extracted_to_memory_only(memory_id, message_dict, source_id, task.id)
+    if success:
+        TaskService.update_progress(task.id, {"progress": 1.0, "progress_msg": timestamp_to_date(current_timestamp()) + " " + msg})
+        return True, msg
+    logging.error(msg)
+    TaskService.update_progress(task.id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp()) + " " + msg})
+    return False, msg
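`queue_save_to_memory_task` is the producer half (persist the RAW message, register a Task row, push a `task_message` onto the Redis queue) and `handle_save_to_memory_task` is the consumer half. A worker loop tying them together might look like the sketch below; the queue-consume call is an assumption, since only `REDIS_CONN.queue_product` appears in this diff:

```python
import asyncio

async def memory_task_worker(redis_conn, queue_name: str):
    # Hypothetical consumer loop: redis_conn.queue_consume is an assumed
    # API, named here only for illustration.
    while True:
        task_param = redis_conn.queue_consume(queue_name)  # assumed call
        if not task_param:
            await asyncio.sleep(1)
            continue
        ok, msg = await handle_save_to_memory_task(task_param)
        if not ok:
            print(f"memory task failed: {msg}")
```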

View File

@@ -190,10 +190,15 @@ class CommonService:
             data_list (list): List of dictionaries containing record data to insert.
             batch_size (int, optional): Number of records to insert in each batch. Defaults to 100.
         """
+        current_ts = current_timestamp()
+        current_datetime = datetime_format(datetime.now())
         with DB.atomic():
             for d in data_list:
-                d["create_time"] = current_timestamp()
-                d["create_date"] = datetime_format(datetime.now())
+                d["create_time"] = current_ts
+                d["create_date"] = current_datetime
+                d["update_time"] = current_ts
+                d["update_date"] = current_datetime
             for i in range(0, len(data_list), batch_size):
                 cls.model.insert_many(data_list[i : i + batch_size]).execute()

View File

@@ -29,7 +29,6 @@ from common.misc_utils import get_uuid
 from common.constants import TaskStatus
 from common.time_utils import current_timestamp, timestamp_to_date

 class ConnectorService(CommonService):
     model = Connector
@@ -202,6 +201,7 @@ class SyncLogsService(CommonService):
         return None

     class FileObj(BaseModel):
+        id: str
         filename: str
         blob: bytes
@@ -209,7 +209,7 @@ class SyncLogsService(CommonService):
             return self.blob

     errs = []
-    files = [FileObj(filename=d["semantic_identifier"]+(f"{d['extension']}" if d["semantic_identifier"][::-1].find(d['extension'][::-1])<0 else ""), blob=d["blob"]) for d in docs]
+    files = [FileObj(id=d["id"], filename=d["semantic_identifier"]+(f"{d['extension']}" if d["semantic_identifier"][::-1].find(d['extension'][::-1])<0 else ""), blob=d["blob"]) for d in docs]
     doc_ids = []
     err, doc_blob_pairs = FileService.upload_document(kb, files, tenant_id, src)
     errs.extend(err)

View File

@@ -64,11 +64,13 @@ class ConversationService(CommonService):
             offset += limit
         return res

 def structure_answer(conv, ans, message_id, session_id):
     reference = ans["reference"]
     if not isinstance(reference, dict):
         reference = {}
         ans["reference"] = {}
+    is_final = ans.get("final", True)
     chunk_list = chunks_format(reference)
@@ -81,14 +83,32 @@ def structure_answer(conv, ans, message_id, session_id):
     if not conv.message:
         conv.message = []
+    content = ans["answer"]
+    if ans.get("start_to_think"):
+        content = "<think>"
+    elif ans.get("end_to_think"):
+        content = "</think>"
     if not conv.message or conv.message[-1].get("role", "") != "assistant":
-        conv.message.append({"role": "assistant", "content": ans["answer"], "created_at": time.time(), "id": message_id})
+        conv.message.append({"role": "assistant", "content": content, "created_at": time.time(), "id": message_id})
     else:
-        conv.message[-1] = {"role": "assistant", "content": ans["answer"], "created_at": time.time(), "id": message_id}
+        if is_final:
+            if ans.get("answer"):
+                conv.message[-1] = {"role": "assistant", "content": ans["answer"], "created_at": time.time(), "id": message_id}
+            else:
+                conv.message[-1]["created_at"] = time.time()
+                conv.message[-1]["id"] = message_id
+        else:
+            conv.message[-1]["content"] = (conv.message[-1].get("content") or "") + content
+            conv.message[-1]["created_at"] = time.time()
+            conv.message[-1]["id"] = message_id
     if conv.reference:
-        conv.reference[-1] = reference
+        should_update_reference = is_final or bool(reference.get("chunks")) or bool(reference.get("doc_aggs"))
+        if should_update_reference:
+            conv.reference[-1] = reference
     return ans

 async def async_completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
     assert name, "`name` can not be empty."
     dia = DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
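With the new `final`/`start_to_think`/`end_to_think` flags, non-final chunks append to the last assistant message (wrapping reasoning in `<think>` tags) and the final chunk replaces it wholesale. A minimal re-implementation of just that flag protocol, for illustration only (the real logic, including reference bookkeeping, lives in `structure_answer` above):

```python
def apply_chunk(last_content: str, ans: dict) -> str:
    # Sketch of the accumulation rules introduced in this PR.
    content = ans.get("answer", "")
    if ans.get("start_to_think"):
        content = "<think>"
    elif ans.get("end_to_think"):
        content = "</think>"
    if ans.get("final", True):
        return ans.get("answer") or last_content  # final chunk replaces
    return last_content + content                 # streaming chunk appends

acc = ""
for ans in [
    {"final": False, "start_to_think": True},
    {"answer": "step 1...", "final": False},
    {"final": False, "end_to_think": True},
    {"answer": "Paris.", "final": False},
    {"answer": "<think>step 1...</think>Paris.", "final": True},
]:
    acc = apply_chunk(acc, ans)
print(acc)  # <think>step 1...</think>Paris.
```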

View File

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import binascii
 import logging
 import re
@@ -23,7 +24,6 @@ from functools import partial
 from timeit import default_timer as timer
 from langfuse import Langfuse
 from peewee import fn
-from agentic_reasoning import DeepResearcher
 from api.db.services.file_service import FileService
 from common.constants import LLMType, ParserType, StatusEnum
 from api.db.db_models import DB, Dialog
@@ -36,7 +36,7 @@ from common.metadata_utils import apply_meta_data_filter
 from api.db.services.tenant_llm_service import TenantLLMService
 from common.time_utils import current_timestamp, datetime_format
 from graphrag.general.mind_map_extractor import MindMapExtractor
-from rag.app.resume import forbidden_select_fields4resume
+from rag.advanced_rag import DeepResearcher
 from rag.app.tag import label_question
 from rag.nlp.search import index_name
 from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \
@@ -196,19 +196,13 @@ async def async_chat_solo(dialog, messages, stream=True):
     if attachments and msg:
         msg[-1]["content"] += attachments
     if stream:
-        last_ans = ""
-        delta_ans = ""
-        answer = ""
-        async for ans in chat_mdl.async_chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
-            answer = ans
-            delta_ans = ans[len(last_ans):]
-            if num_tokens_from_string(delta_ans) < 16:
-                continue
-            last_ans = answer
-            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
-            delta_ans = ""
-        if delta_ans:
-            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
+        stream_iter = chat_mdl.async_chat_streamly_delta(prompt_config.get("system", ""), msg, dialog.llm_setting)
+        async for kind, value, state in _stream_with_think_delta(stream_iter):
+            if kind == "marker":
+                flags = {"start_to_think": True} if value == "<think>" else {"end_to_think": True}
+                yield {"answer": "", "reference": {}, "audio_binary": None, "prompt": "", "created_at": time.time(), "final": False, **flags}
+                continue
+            yield {"answer": value, "reference": {}, "audio_binary": tts(tts_mdl, value), "prompt": "", "created_at": time.time(), "final": False}
     else:
         answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
         user_content = msg[-1].get("content", "[content not available]")
@@ -279,6 +273,7 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):

 async def async_chat(dialog, messages, stream=True, **kwargs):
+    logging.debug("Begin async_chat")
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
     if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
         async for ans in async_chat_solo(dialog, messages, stream):
@@ -301,10 +296,14 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
     langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=dialog.tenant_id)
     if langfuse_keys:
         langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
-        if langfuse.auth_check():
-            langfuse_tracer = langfuse
-            trace_id = langfuse_tracer.create_trace_id()
-            trace_context = {"trace_id": trace_id}
+        try:
+            if langfuse.auth_check():
+                langfuse_tracer = langfuse
+                trace_id = langfuse_tracer.create_trace_id()
+                trace_context = {"trace_id": trace_id}
+        except Exception:
+            # Skip langfuse tracing if connection fails
+            pass
     check_langfuse_tracer_ts = timer()

     kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl = get_models(dialog)
@@ -324,13 +323,20 @@
     prompt_config = dialog.prompt_config
     field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
+    logging.debug(f"field_map retrieved: {field_map}")
     # try to use sql if field mapping is good to go
     if field_map:
         logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
         ans = await use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
-        if ans:
+        # For aggregate queries (COUNT, SUM, etc.), chunks may be empty but answer is still valid
+        if ans and (ans.get("reference", {}).get("chunks") or ans.get("answer")):
             yield ans
             return
+        else:
+            logging.debug("SQL failed or returned no results, falling back to vector search")
+
+    param_keys = [p["key"] for p in prompt_config.get("parameters", [])]
+    logging.debug(f"attachments={attachments}, param_keys={param_keys}, embd_mdl={embd_mdl}")

     for p in prompt_config["parameters"]:
         if p["key"] == "knowledge":
@@ -367,10 +373,11 @@
     kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
     knowledges = []
-    if attachments is not None and "knowledge" in [p["key"] for p in prompt_config["parameters"]]:
+    if attachments is not None and "knowledge" in param_keys:
+        logging.debug("Proceeding with retrieval")
         tenant_ids = list(set([kb.tenant_id for kb in kbs]))
         knowledges = []
-        if prompt_config.get("reasoning", False):
+        if prompt_config.get("reasoning", False) or kwargs.get("reasoning"):
             reasoner = DeepResearcher(
                 chat_mdl,
                 prompt_config,
@@ -386,16 +393,28 @@
                     doc_ids=attachments,
                 ),
             )
+            queue = asyncio.Queue()
+            async def callback(msg: str):
+                nonlocal queue
+                await queue.put(msg + "<br/>")
+            await callback("<START_DEEP_RESEARCH>")
+            task = asyncio.create_task(reasoner.research(kbinfos, questions[-1], questions[-1], callback=callback))
+            while True:
+                msg = await queue.get()
+                if msg.find("<START_DEEP_RESEARCH>") == 0:
+                    yield {"answer": "", "reference": {}, "audio_binary": None, "final": False, "start_to_think": True}
+                elif msg.find("<END_DEEP_RESEARCH>") == 0:
+                    yield {"answer": "", "reference": {}, "audio_binary": None, "final": False, "end_to_think": True}
+                    break
+                else:
+                    yield {"answer": msg, "reference": {}, "audio_binary": None, "final": False}
+            await task
-            async for think in reasoner.thinking(kbinfos, attachments_ + " ".join(questions)):
-                if isinstance(think, str):
-                    thought = think
-                    knowledges = [t for t in think.split("\n") if t]
-                elif stream:
-                    yield think
         else:
             if embd_mdl:
-                kbinfos = retriever.retrieval(
+                kbinfos = await retriever.retrieval(
                     " ".join(questions),
                     embd_mdl,
                     tenant_ids,
@@ -411,7 +430,7 @@
                     rank_feature=label_question(" ".join(questions), kbs),
                 )
             if prompt_config.get("toc_enhance"):
-                cks = retriever.retrieval_by_toc(" ".join(questions), kbinfos["chunks"], tenant_ids, chat_mdl, dialog.top_n)
+                cks = await retriever.retrieval_by_toc(" ".join(questions), kbinfos["chunks"], tenant_ids, chat_mdl, dialog.top_n)
                 if cks:
                     kbinfos["chunks"] = cks
             kbinfos["chunks"] = retriever.retrieval_by_children(kbinfos["chunks"], tenant_ids)
@@ -421,21 +440,19 @@
                 kbinfos["chunks"].extend(tav_res["chunks"])
                 kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
             if prompt_config.get("use_kg"):
-                ck = settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
+                ck = await settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
                                                      LLMBundle(dialog.tenant_id, LLMType.CHAT))
                 if ck["content_with_weight"]:
                     kbinfos["chunks"].insert(0, ck)
             knowledges = kb_prompt(kbinfos, max_tokens)
     logging.debug("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
     retrieval_ts = timer()
     if not knowledges and prompt_config.get("empty_response"):
         empty_res = prompt_config["empty_response"]
         yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
-               "audio_binary": tts(tts_mdl, empty_res)}
-        yield {"answer": prompt_config["empty_response"], "reference": kbinfos}
+               "audio_binary": tts(tts_mdl, empty_res), "final": True}
         return
     kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
@@ -538,21 +555,22 @@
         )
     if stream:
-        last_ans = ""
-        answer = ""
-        async for ans in chat_mdl.async_chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
-            if thought:
-                ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
-            answer = ans
-            delta_ans = ans[len(last_ans):]
-            if num_tokens_from_string(delta_ans) < 16:
-                continue
-            last_ans = answer
-            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
-        delta_ans = answer[len(last_ans):]
-        if delta_ans:
-            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
-        yield decorate_answer(thought + answer)
+        stream_iter = chat_mdl.async_chat_streamly_delta(prompt + prompt4citation, msg[1:], gen_conf)
+        last_state = None
+        async for kind, value, state in _stream_with_think_delta(stream_iter):
+            last_state = state
+            if kind == "marker":
+                flags = {"start_to_think": True} if value == "<think>" else {"end_to_think": True}
+                yield {"answer": "", "reference": {}, "audio_binary": None, "final": False, **flags}
+                continue
+            yield {"answer": value, "reference": {}, "audio_binary": tts(tts_mdl, value), "final": False}
+        full_answer = last_state.full_text if last_state else ""
+        if full_answer:
+            final = decorate_answer(thought + full_answer)
+            final["final"] = True
+            final["audio_binary"] = None
+            final["answer"] = ""
+            yield final
     else:
         answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf)
         user_content = msg[-1].get("content", "[content not available]")
@@ -565,112 +583,306 @@ async def async_chat(dialog, messages, stream=True, **kwargs):

 async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
-    sys_prompt = """
-You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question.
-Ensure that:
-1. Field names should not start with a digit. If any field name starts with a digit, use double quotes around it.
-2. Write only the SQL, no explanations or additional text.
-"""
-    user_prompt = """
-Table name: {};
-Table of database fields are as follows:
-{}
-Question are as follows:
-{}
-Please write the SQL, only SQL, without any other explanations or text.
-""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question)
+    logging.debug(f"use_sql: Question: {question}")
+
+    # Determine which document engine we're using
+    doc_engine = "infinity" if settings.DOC_ENGINE_INFINITY else "es"
+
+    # Construct the full table name
+    # For Elasticsearch: ragflow_{tenant_id} (kb_id is in WHERE clause)
+    # For Infinity: ragflow_{tenant_id}_{kb_id} (each KB has its own table)
+    base_table = index_name(tenant_id)
+    if doc_engine == "infinity" and kb_ids and len(kb_ids) == 1:
+        # Infinity: append kb_id to table name
+        table_name = f"{base_table}_{kb_ids[0]}"
+        logging.debug(f"use_sql: Using Infinity table name: {table_name}")
+    else:
+        # Elasticsearch/OpenSearch: use base index name
+        table_name = base_table
+        logging.debug(f"use_sql: Using ES/OS table name: {table_name}")
+
+    # Generate engine-specific SQL prompts
+    if doc_engine == "infinity":
+        # Build Infinity prompts with JSON extraction context
+        json_field_names = list(field_map.keys())
+        sys_prompt = """You are a Database Administrator. Write SQL for a table with JSON 'chunk_data' column.
+JSON Extraction: json_extract_string(chunk_data, '$.FieldName')
+Numeric Cast: CAST(json_extract_string(chunk_data, '$.FieldName') AS INTEGER/FLOAT)
+NULL Check: json_extract_isnull(chunk_data, '$.FieldName') == false
+RULES:
+1. Use EXACT field names (case-sensitive) from the list below
+2. For SELECT: include doc_id, docnm, and json_extract_string() for requested fields
+3. For COUNT: use COUNT(*) or COUNT(DISTINCT json_extract_string(...))
+4. Add AS alias for extracted field names
+5. DO NOT select 'content' field
+6. Only add NULL check (json_extract_isnull() == false) in WHERE clause when:
+   - Question asks to "show me" or "display" specific columns
+   - Question mentions "not null" or "excluding null"
+   - Add NULL check for count specific column
+   - DO NOT add NULL check for COUNT(*) queries (COUNT(*) counts all rows including nulls)
+7. Output ONLY the SQL, no explanations"""
+        user_prompt = """Table: {}
+Fields (EXACT case): {}
+{}
+Question: {}
+Write SQL using json_extract_string() with exact field names. Include doc_id, docnm for data queries. Only SQL.""".format(
+            table_name,
+            ", ".join(json_field_names),
+            "\n".join([f"  - {field}" for field in json_field_names]),
+            question
+        )
+    else:
+        # Build ES/OS prompts with direct field access
+        sys_prompt = """You are a Database Administrator. Write SQL queries.
+RULES:
+1. Use EXACT field names from the schema below (e.g., product_tks, not product)
+2. Quote field names starting with digit: "123_field"
+3. Add IS NOT NULL in WHERE clause when:
+   - Question asks to "show me" or "display" specific columns
+4. Include doc_id/docnm in non-aggregate statement
+5. Output ONLY the SQL, no explanations"""
+        user_prompt = """Table: {}
+Available fields:
+{}
+Question: {}
+Write SQL using exact field names above. Include doc_id, docnm_kwd for data queries. Only SQL.""".format(
+            table_name,
+            "\n".join([f"  - {k} ({v})" for k, v in field_map.items()]),
+            question
+        )
     tried_times = 0

     async def get_table():
         nonlocal sys_prompt, user_prompt, question, tried_times
         sql = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
-        sql = re.sub(r"^.*</think>", "", sql, flags=re.DOTALL)
-        logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
-        sql = re.sub(r"[\r\n]+", " ", sql.lower())
-        sql = re.sub(r".*select ", "select ", sql.lower())
-        sql = re.sub(r" +", " ", sql)
-        sql = re.sub(r"([;]|```).*", "", sql)
-        sql = re.sub(r"&", "and", sql)
-        if sql[: len("select ")] != "select ":
-            return None, None
-        if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
-            if sql[: len("select *")] != "select *":
-                sql = "select doc_id,docnm_kwd," + sql[6:]
-            else:
-                flds = []
-                for k in field_map.keys():
-                    if k in forbidden_select_fields4resume:
-                        continue
-                    if len(flds) > 11:
-                        break
-                    flds.append(k)
-                sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
-        if kb_ids:
-            kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"
-            if "where" not in sql.lower():
+        logging.debug(f"use_sql: Raw SQL from LLM: {repr(sql[:500])}")
+        # Remove think blocks if present (format: </think>...)
+        sql = re.sub(r"</think>\n.*?\n\s*", "", sql, flags=re.DOTALL)
+        sql = re.sub(r"思考\n.*?\n", "", sql, flags=re.DOTALL)
+        # Remove markdown code blocks (```sql ... ```)
+        sql = re.sub(r"```(?:sql)?\s*", "", sql, flags=re.IGNORECASE)
+        sql = re.sub(r"```\s*$", "", sql, flags=re.IGNORECASE)
+        # Remove trailing semicolon that ES SQL parser doesn't like
+        sql = sql.rstrip().rstrip(';').strip()
+        # Add kb_id filter for ES/OS only (Infinity already has it in table name)
+        if doc_engine != "infinity" and kb_ids:
+            # Build kb_filter: single KB or multiple KBs with OR
+            if len(kb_ids) == 1:
+                kb_filter = f"kb_id = '{kb_ids[0]}'"
+            else:
+                kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"
+            if "where " not in sql.lower():
                 o = sql.lower().split("order by")
                 if len(o) > 1:
                     sql = o[0] + f" WHERE {kb_filter} order by " + o[1]
                 else:
                     sql += f" WHERE {kb_filter}"
-            else:
-                sql += f" AND {kb_filter}"
+            elif "kb_id =" not in sql.lower() and "kb_id=" not in sql.lower():
+                sql = re.sub(r"\bwhere\b ", f"where {kb_filter} and ", sql, flags=re.IGNORECASE)
         logging.debug(f"{question} get SQL(refined): {sql}")
         tried_times += 1
-        return settings.retriever.sql_retrieval(sql, format="json"), sql
+        logging.debug(f"use_sql: Executing SQL retrieval (attempt {tried_times})")
+        tbl = settings.retriever.sql_retrieval(sql, format="json")
+        if tbl is None:
+            logging.debug("use_sql: SQL retrieval returned None")
+            return None, sql
+        logging.debug(f"use_sql: SQL retrieval completed, got {len(tbl.get('rows', []))} rows")
+        return tbl, sql
     try:
         tbl, sql = await get_table()
+        logging.debug(f"use_sql: Initial SQL execution SUCCESS. SQL: {sql}")
+        logging.debug(f"use_sql: Retrieved {len(tbl.get('rows', []))} rows, columns: {[c['name'] for c in tbl.get('columns', [])]}")
     except Exception as e:
-        user_prompt = """
-Table name: {};
-Table of database fields are as follows:
-{}
-Question are as follows:
-{}
-Please write the SQL, only SQL, without any other explanations or text.
-The SQL error you provided last time is as follows:
-{}
-Please correct the error and write SQL again, only SQL, without any other explanations or text.
-""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, e)
+        logging.warning(f"use_sql: Initial SQL execution FAILED with error: {e}")
+        # Build retry prompt with error information
+        if doc_engine == "infinity":
+            # Build Infinity error retry prompt
+            json_field_names = list(field_map.keys())
+            user_prompt = """
+Table name: {};
+JSON fields available in 'chunk_data' column (use these exact names in json_extract_string):
+{}
+Question: {}
+Please write the SQL using json_extract_string(chunk_data, '$.field_name') with the field names from the list above. Only SQL, no explanations.
+The SQL error you provided last time is as follows:
+{}
+Please correct the error and write SQL again using json_extract_string(chunk_data, '$.field_name') syntax with the correct field names. Only SQL, no explanations.
+""".format(table_name, "\n".join([f"  - {field}" for field in json_field_names]), question, e)
+        else:
+            # Build ES/OS error retry prompt
+            user_prompt = """
+Table name: {};
+Table of database fields are as follows (use the field names directly in SQL):
+{}
+Question are as follows:
+{}
+Please write the SQL using the exact field names above, only SQL, without any other explanations or text.
+The SQL error you provided last time is as follows:
+{}
+Please correct the error and write SQL again using the exact field names above, only SQL, without any other explanations or text.
+""".format(table_name, "\n".join([f"{k} ({v})" for k, v in field_map.items()]), question, e)
         try:
             tbl, sql = await get_table()
+            logging.debug(f"use_sql: Retry SQL execution SUCCESS. SQL: {sql}")
+            logging.debug(f"use_sql: Retrieved {len(tbl.get('rows', []))} rows on retry")
         except Exception:
+            logging.error("use_sql: Retry SQL execution also FAILED, returning None")
             return

     if len(tbl["rows"]) == 0:
+        logging.warning(f"use_sql: No rows returned from SQL query, returning None. SQL: {sql}")
         return None
+    logging.debug(f"use_sql: Proceeding with {len(tbl['rows'])} rows to build answer")
-    docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
-    doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
+    docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() == "doc_id"])
+    doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() in ["docnm_kwd", "docnm"]])
+    logging.debug(f"use_sql: All columns: {[(i, c['name']) for i, c in enumerate(tbl['columns'])]}")
+    logging.debug(f"use_sql: docid_idx={docid_idx}, doc_name_idx={doc_name_idx}")
     column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]
+    logging.debug(f"use_sql: column_idx={column_idx}")
+    logging.debug(f"use_sql: field_map={field_map}")
+
+    # Helper function to map column names to display names
+    def map_column_name(col_name):
+        if col_name.lower() == "count(star)":
+            return "COUNT(*)"
+        # First, try to extract AS alias from any expression (aggregate functions, json_extract_string, etc.)
+        # Pattern: anything AS alias_name
+        as_match = re.search(r'\s+AS\s+([^\s,)]+)', col_name, re.IGNORECASE)
+        if as_match:
+            alias = as_match.group(1).strip('"\'')
+            # Use the alias for display name lookup
+            if alias in field_map:
+                display = field_map[alias]
+                return re.sub(r"(/.*|（[^（）]+）)", "", display)
+            # If alias not in field_map, try to match case-insensitively
+            for field_key, display_value in field_map.items():
+                if field_key.lower() == alias.lower():
+                    return re.sub(r"(/.*|（[^（）]+）)", "", display_value)
+            # Return alias as-is if no mapping found
+            return alias
+        # Try direct mapping first (for simple column names)
+        if col_name in field_map:
+            display = field_map[col_name]
+            # Clean up any suffix patterns
+            return re.sub(r"(/.*|（[^（）]+）)", "", display)
+        # Try case-insensitive match for simple column names
+        col_lower = col_name.lower()
+        for field_key, display_value in field_map.items():
+            if field_key.lower() == col_lower:
+                return re.sub(r"(/.*|（[^（）]+）)", "", display_value)
+        # For aggregate expressions or complex expressions without AS alias,
+        # try to replace field names with display names
+        result = col_name
+        for field_name, display_name in field_map.items():
+            # Replace field_name with display_name in the expression
+            result = result.replace(field_name, display_name)
+        # Clean up any suffix patterns
+        result = re.sub(r"(/.*|（[^（）]+）)", "", result)
+        return result
     # compose Markdown table
     columns = (
         "|" + "|".join(
-            [re.sub(r"(/.*|（[^（）]+）)", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + (
-            "|Source|" if docid_idx and docid_idx else "|")
+            [map_column_name(tbl["columns"][i]["name"]) for i in column_idx]) + (
+            "|Source|" if docid_idx and doc_name_idx else "|")
     )
     line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")
-    rows = ["|" + "|".join([remove_redundant_spaces(str(r[i])) for i in column_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
-    rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
+    # Build rows ensuring column names match values - create a dict for each row
+    # keyed by column name to handle any SQL column order
+    rows = []
+    for row_idx, r in enumerate(tbl["rows"]):
+        row_dict = {tbl["columns"][i]["name"]: r[i] for i in range(len(tbl["columns"])) if i < len(r)}
+        if row_idx == 0:
+            logging.debug(f"use_sql: First row data: {row_dict}")
+        row_values = []
+        for col_idx in column_idx:
+            col_name = tbl["columns"][col_idx]["name"]
+            value = row_dict.get(col_name, " ")
+            row_values.append(remove_redundant_spaces(str(value)).replace("None", " "))
+        # Add Source column with citation marker if Source column exists
+        if docid_idx and doc_name_idx:
+            row_values.append(f" ##{row_idx}$$")
+        row_str = "|" + "|".join(row_values) + "|"
+        if re.sub(r"[ |]+", "", row_str):
+            rows.append(row_str)
     if quota:
-        rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
+        rows = "\n".join(rows)
     else:
-        rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
+        rows = "\n".join(rows)
     rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
     if not docid_idx or not doc_name_idx:
-        logging.warning("SQL missing field: " + sql)
+        logging.warning(f"use_sql: SQL missing required doc_id or docnm_kwd field. docid_idx={docid_idx}, doc_name_idx={doc_name_idx}. SQL: {sql}")
+        # For aggregate queries (COUNT, SUM, AVG, MAX, MIN, DISTINCT), fetch doc_id, docnm_kwd separately
+        # to provide source chunks, but keep the original table format answer
+        if re.search(r"(count|sum|avg|max|min|distinct)\s*\(", sql.lower()):
+            # Keep original table format as answer
+            answer = "\n".join([columns, line, rows])
+            # Now fetch doc_id, docnm_kwd to provide source chunks
+            # Extract WHERE clause from the original SQL
+            where_match = re.search(r"\bwhere\b(.+?)(?:\bgroup by\b|\border by\b|\blimit\b|$)", sql, re.IGNORECASE)
+            if where_match:
+                where_clause = where_match.group(1).strip()
+                # Build a query to get doc_id and docnm_kwd with the same WHERE clause
+                chunks_sql = f"select doc_id, docnm_kwd from {table_name} where {where_clause}"
+                # Add LIMIT to avoid fetching too many chunks
+                if "limit" not in chunks_sql.lower():
+                    chunks_sql += " limit 20"
+                logging.debug(f"use_sql: Fetching chunks with SQL: {chunks_sql}")
+                try:
+                    chunks_tbl = settings.retriever.sql_retrieval(chunks_sql, format="json")
+                    if chunks_tbl.get("rows") and len(chunks_tbl["rows"]) > 0:
+                        # Build chunks reference - use case-insensitive matching
+                        chunks_did_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() == "doc_id"), None)
+                        chunks_dn_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() in ["docnm_kwd", "docnm"]), None)
+                        if chunks_did_idx is not None and chunks_dn_idx is not None:
+                            chunks = [{"doc_id": r[chunks_did_idx], "docnm_kwd": r[chunks_dn_idx]} for r in chunks_tbl["rows"]]
+                            # Build doc_aggs
+                            doc_aggs = {}
+                            for r in chunks_tbl["rows"]:
+                                doc_id = r[chunks_did_idx]
+                                doc_name = r[chunks_dn_idx]
+                                if doc_id not in doc_aggs:
+                                    doc_aggs[doc_id] = {"doc_name": doc_name, "count": 0}
+                                doc_aggs[doc_id]["count"] += 1
+                            doc_aggs_list = [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]
+                            logging.debug(f"use_sql: Returning aggregate answer with {len(chunks)} chunks from {len(doc_aggs)} documents")
+                            return {"answer": answer, "reference": {"chunks": chunks, "doc_aggs": doc_aggs_list}, "prompt": sys_prompt}
+                except Exception as e:
+                    logging.warning(f"use_sql: Failed to fetch chunks: {e}")
+            # Fallback: return answer without chunks
+            return {"answer": answer, "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}
+        # Fallback to table format for other cases
         return {"answer": "\n".join([columns, line, rows]), "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}
     docid_idx = list(docid_idx)[0]
@@ -680,7 +892,8 @@ Please write the SQL, only SQL, without any other explanations or text.
         if r[docid_idx] not in doc_aggs:
             doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0}
         doc_aggs[r[docid_idx]]["count"] += 1
-    return {
+    result = {
         "answer": "\n".join([columns, line, rows]),
         "reference": {
             "chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
@@ -688,6 +901,8 @@ Please write the SQL, only SQL, without any other explanations or text.
         },
         "prompt": sys_prompt,
     }
+    logging.debug(f"use_sql: Returning answer with {len(result['reference']['chunks'])} chunks from {len(doc_aggs)} documents")
+    return result
def clean_tts_text(text: str) -> str: def clean_tts_text(text: str) -> str:
if not text: if not text:
@ -733,6 +948,84 @@ def tts(tts_mdl, text):
return None return None
return binascii.hexlify(bin).decode("utf-8") return binascii.hexlify(bin).decode("utf-8")
class _ThinkStreamState:
def __init__(self) -> None:
self.full_text = ""
self.last_idx = 0
self.endswith_think = False
self.last_full = ""
self.last_model_full = ""
self.in_think = False
self.buffer = ""
def _next_think_delta(state: _ThinkStreamState) -> str:
full_text = state.full_text
if full_text == state.last_full:
return ""
state.last_full = full_text
delta_ans = full_text[state.last_idx:]
if delta_ans.find("<think>") == 0:
state.last_idx += len("<think>")
return "<think>"
if delta_ans.find("<think>") > 0:
delta_text = full_text[state.last_idx:state.last_idx + delta_ans.find("<think>")]
state.last_idx += delta_ans.find("<think>")
return delta_text
if delta_ans.endswith("</think>"):
state.endswith_think = True
elif state.endswith_think:
state.endswith_think = False
return "</think>"
state.last_idx = len(full_text)
if full_text.endswith("</think>"):
state.last_idx -= len("</think>")
return re.sub(r"(<think>|</think>)", "", delta_ans)
async def _stream_with_think_delta(stream_iter, min_tokens: int = 16):
state = _ThinkStreamState()
async for chunk in stream_iter:
if not chunk:
continue
if chunk.startswith(state.last_model_full):
new_part = chunk[len(state.last_model_full):]
state.last_model_full = chunk
else:
new_part = chunk
state.last_model_full += chunk
if not new_part:
continue
state.full_text += new_part
delta = _next_think_delta(state)
if not delta:
continue
if delta in ("<think>", "</think>"):
if delta == "<think>" and state.in_think:
continue
if delta == "</think>" and not state.in_think:
continue
if state.buffer:
yield ("text", state.buffer, state)
state.buffer = ""
state.in_think = delta == "<think>"
yield ("marker", delta, state)
continue
state.buffer += delta
if num_tokens_from_string(state.buffer) < min_tokens:
continue
yield ("text", state.buffer, state)
state.buffer = ""
if state.buffer:
yield ("text", state.buffer, state)
state.buffer = ""
if state.endswith_think:
yield ("marker", "</think>", state)
async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
doc_ids = search_config.get("doc_ids", [])
rerank_mdl = None
@@ -758,7 +1051,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
metas = DocumentService.get_meta_by_kbs(kb_ids)
doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids)
kbinfos = await retriever.retrieval(
question=question,
embd_mdl=embd_mdl,
tenant_ids=tenant_ids,
@@ -798,11 +1091,20 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
refs["chunks"] = chunks_format(refs)
return {"answer": answer, "reference": refs}
stream_iter = chat_mdl.async_chat_streamly_delta(sys_prompt, msg, {"temperature": 0.1})
last_state = None
async for kind, value, state in _stream_with_think_delta(stream_iter):
last_state = state
if kind == "marker":
flags = {"start_to_think": True} if value == "<think>" else {"end_to_think": True}
yield {"answer": "", "reference": {}, "final": False, **flags}
continue
yield {"answer": value, "reference": {}, "final": False}
full_answer = last_state.full_text if last_state else ""
final = decorate_answer(full_answer)
final["final"] = True
final["answer"] = ""
yield final
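The rewritten async_ask now streams incremental frames ({"answer": delta, "final": False}), flags think-phase boundaries with start_to_think/end_to_think, and closes with a single final frame carrying the references. A sketch of a caller assembling the stream, assuming async_ask is importable from this module and the frame keys follow the protocol above:

async def collect(question, kb_ids, tenant_id):
    answer_parts, reference = [], {}
    async for frame in async_ask(question, kb_ids, tenant_id):
        if frame.get("final"):
            reference = frame.get("reference", {})
            break
        if frame.get("start_to_think") or frame.get("end_to_think"):
            continue  # reasoning-phase boundaries; render however the UI prefers
        answer_parts.append(frame["answer"])
    return "".join(answer_parts), reference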
async def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
@@ -825,7 +1127,7 @@ async def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
metas = DocumentService.get_meta_by_kbs(kb_ids)
doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids)
ranks = await settings.retriever.retrieval(
question=question,
embd_mdl=embd_mdl,
tenant_ids=tenant_ids,

View File

@@ -340,14 +340,35 @@ class DocumentService(CommonService):
def remove_document(cls, doc, tenant_id):
from api.db.services.task_service import TaskService
cls.clear_chunk_num(doc.id)
# Delete tasks first
try:
TaskService.filter_delete([Task.doc_id == doc.id])
except Exception as e:
logging.warning(f"Failed to delete tasks for document {doc.id}: {e}")
# Delete chunk images (non-critical, log and continue)
try:
cls.delete_chunk_images(doc, tenant_id)
except Exception as e:
logging.warning(f"Failed to delete chunk images for document {doc.id}: {e}")
# Delete thumbnail (non-critical, log and continue)
try:
if doc.thumbnail and not doc.thumbnail.startswith(IMG_BASE64_PREFIX):
if settings.STORAGE_IMPL.obj_exist(doc.kb_id, doc.thumbnail):
settings.STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail)
except Exception as e:
logging.warning(f"Failed to delete thumbnail for document {doc.id}: {e}")
# Delete chunks from doc store - this is critical, log errors
try:
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
except Exception as e:
logging.error(f"Failed to delete chunks from doc store for document {doc.id}: {e}")
# Cleanup knowledge graph references (non-critical, log and continue)
try:
graph_source = settings.docStoreConn.get_fields(
settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
)
@@ -360,8 +381,9 @@ class DocumentService(CommonService):
search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
search.index_name(tenant_id), doc.kb_id)
except Exception as e:
logging.warning(f"Failed to cleanup knowledge graph for document {doc.id}: {e}")
return cls.delete_by_id(doc.id)
@classmethod
@@ -423,6 +445,7 @@ class DocumentService(CommonService):
.where(
cls.model.status == StatusEnum.VALID.value,
~(cls.model.type == FileType.VIRTUAL.value),
((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value)),
(((cls.model.progress < 1) & (cls.model.progress > 0)) |
(cls.model.id.in_(unfinished_task_query)))) # including unfinished tasks like GraphRAG, RAPTOR and Mindmap
return list(docs.dicts())
@@ -753,10 +776,25 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_metadata_summary(cls, kb_id, document_ids=None):
def _meta_value_type(value):
if value is None:
return None
if isinstance(value, list):
return "list"
if isinstance(value, bool):
return "string"
if isinstance(value, (int, float)):
return "number"
return "string"
fields = [cls.model.id, cls.model.meta_fields]
summary = {}
type_counter = {}
query = cls.model.select(*fields).where(cls.model.kb_id == kb_id)
if document_ids:
query = query.where(cls.model.id.in_(document_ids))
for r in query:
meta_fields = r.meta_fields or {}
if isinstance(meta_fields, str):
try:
@@ -766,6 +804,11 @@ class DocumentService(CommonService):
if not isinstance(meta_fields, dict):
continue
for k, v in meta_fields.items():
value_type = _meta_value_type(v)
if value_type:
if k not in type_counter:
type_counter[k] = {}
type_counter[k][value_type] = type_counter[k].get(value_type, 0) + 1
values = v if isinstance(v, list) else [v]
for vv in values:
if not vv:
@@ -774,11 +817,19 @@ class DocumentService(CommonService):
if k not in summary:
summary[k] = {}
summary[k][sv] = summary[k].get(sv, 0) + 1
result = {}
for k, v in summary.items():
values = sorted([(val, cnt) for val, cnt in v.items()], key=lambda x: x[1], reverse=True)
type_counts = type_counter.get(k, {})
value_type = "string"
if type_counts:
value_type = max(type_counts.items(), key=lambda item: item[1])[0]
result[k] = {"type": value_type, "values": values}
return result
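The summary shape therefore changes from {key: [(value, count), ...]} to {key: {"type": ..., "values": [(value, count), ...]}}, with the type elected by majority vote over _meta_value_type. A standalone sketch mirroring that aggregation; the sample metadata is illustrative only:

from collections import Counter

docs_meta = [
    {"year": 2024, "tags": ["legal", "hr"]},
    {"year": 2025, "tags": ["legal"]},
    {"year": "2025"},
]

def value_type(v):
    if v is None: return None
    if isinstance(v, list): return "list"
    if isinstance(v, bool): return "string"
    if isinstance(v, (int, float)): return "number"
    return "string"

values, types = {}, {}
for meta in docs_meta:
    for k, v in meta.items():
        types.setdefault(k, Counter())[value_type(v)] += 1
        for vv in (v if isinstance(v, list) else [v]):
            values.setdefault(k, Counter())[str(vv)] += 1

summary = {k: {"type": types[k].most_common(1)[0][0], "values": c.most_common()}
           for k, c in values.items()}
print(summary)  # e.g. {"year": {"type": "number", "values": [("2025", 2), ("2024", 1)]}, ...}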
@classmethod
@DB.connection_context()
def batch_update_metadata(cls, kb_id, doc_ids, updates=None, deletes=None, adds=None):
updates = updates or []
deletes = deletes or []
if not doc_ids:
@@ -801,11 +852,20 @@ class DocumentService(CommonService):
changed = False
for upd in updates:
key = upd.get("key")
if not key:
continue
new_value = upd.get("value")
match_provided = "match" in upd
if key not in meta:
if match_provided:
continue
meta[key] = dedupe_list(new_value) if isinstance(new_value, list) else new_value
changed = True
continue
if isinstance(meta[key], list):
if not match_provided:
if isinstance(new_value, list):
@@ -865,7 +925,7 @@ class DocumentService(CommonService):
updated_docs = 0
with DB.atomic():
rows = cls.model.select(cls.model.id, cls.model.meta_fields).where(
cls.model.id.in_(doc_ids)
)
for r in rows:
meta = _normalize_meta(r.meta_fields or {})
@@ -914,6 +974,8 @@ class DocumentService(CommonService):
bad = 0
e, doc = DocumentService.get_by_id(d["id"])
status = doc.run # TaskStatus.RUNNING.value
if status == TaskStatus.CANCEL.value:
continue
doc_progress = doc.progress if doc and doc.progress else 0.0
special_task_running = False
priority = 0
@@ -957,7 +1019,16 @@ class DocumentService(CommonService):
info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority)
else:
info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority)
info["update_time"] = current_timestamp()
info["update_date"] = get_format_time()
(
cls.model.update(info)
.where(
(cls.model.id == d["id"])
& ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value))
)
.execute()
)
except Exception as e:
if str(e).find("'0'") < 0:
logging.exception("fetch task exception")
@@ -990,7 +1061,7 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def knowledgebase_basic_info(cls, kb_id: str) -> dict[str, int]:
# cancelled: run == "2"
cancelled = (
cls.model.select(fn.COUNT(1))
.where((cls.model.kb_id == kb_id) & (cls.model.run == TaskStatus.CANCEL))
@@ -1245,7 +1316,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
for b in range(0, len(cks), es_bulk_size):
if try_create_idx:
if not settings.docStoreConn.index_exist(idxnm, kb_id):
settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]), kb.parser_id)
try_create_idx = False
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
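The progress update above replaces the unconditional update_by_id with an UPDATE guarded by run != TaskStatus.CANCEL, so a racing progress writer cannot resurrect a document the user just cancelled. A self-contained sketch of that pattern with peewee and an in-memory SQLite table; Doc and its fields are stand-ins for the real model:

from peewee import SqliteDatabase, Model, CharField, FloatField

db = SqliteDatabase(":memory:")

class Doc(Model):
    run = CharField(null=True)      # "2" means cancelled (TaskStatus.CANCEL)
    progress = FloatField(default=0.0)
    class Meta:
        database = db

db.create_tables([Doc])
d = Doc.create(run="2", progress=0.3)   # already cancelled

# The WHERE guard keeps the stale progress write from landing.
n = (Doc.update(progress=0.9)
        .where((Doc.id == d.id) & (Doc.run.is_null(True) | (Doc.run != "2")))
        .execute())
print(n)  # 0 rows updated: the cancel wins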

View File

@@ -225,21 +225,36 @@ class EvaluationService(CommonService):
"""
success_count = 0
failure_count = 0
case_instances = []
if not cases:
return success_count, failure_count
cur_timestamp = current_timestamp()
try:
for case_data in cases:
case_id = get_uuid()
case_info = {
"id": case_id,
"dataset_id": dataset_id,
"question": case_data.get("question", ""),
"reference_answer": case_data.get("reference_answer"),
"relevant_doc_ids": case_data.get("relevant_doc_ids"),
"relevant_chunk_ids": case_data.get("relevant_chunk_ids"),
"metadata": case_data.get("metadata"),
"create_time": cur_timestamp
}
case_instances.append(EvaluationCase(**case_info))
EvaluationCase.bulk_create(case_instances, batch_size=300)
success_count = len(case_instances)
failure_count = 0
except Exception as e:
logging.error(f"Error bulk importing test cases: {str(e)}")
failure_count = len(cases)
success_count = 0
return success_count, failure_count
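The per-row add_test_case loop is replaced by a single bulk_create in batches of 300, making the import all-or-nothing. A minimal peewee bulk_create sketch; Case is a stand-in for EvaluationCase:

from peewee import SqliteDatabase, Model, CharField

db = SqliteDatabase(":memory:")

class Case(Model):
    question = CharField()
    class Meta:
        database = db

db.create_tables([Case])
rows = [Case(question=f"q{i}") for i in range(1000)]
Case.bulk_create(rows, batch_size=300)   # 4 INSERT statements instead of 1000
print(Case.select().count())  # 1000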

View File

@@ -439,6 +439,15 @@ class FileService(CommonService):
err, files = [], []
for file in file_objs:
doc_id = file.id if hasattr(file, "id") else get_uuid()
e, doc = DocumentService.get_by_id(doc_id)
if e:
blob = file.read()
settings.STORAGE_IMPL.put(kb.id, doc.location, blob, kb.tenant_id)
doc.size = len(blob)
doc = doc.to_dict()
DocumentService.update_by_id(doc["id"], doc)
continue
try:
DocumentService.check_doc_health(kb.tenant_id, file.filename)
filename = duplicate_name(DocumentService.query, name=file.filename, kb_id=kb.id)
@@ -455,7 +464,6 @@ class FileService(CommonService):
blob = read_potential_broken_pdf(blob)
settings.STORAGE_IMPL.put(kb.id, location, blob)
img = thumbnail_img(filename, blob)
thumbnail_location = ""

View File

@@ -397,7 +397,7 @@ class KnowledgebaseService(CommonService):
if dataset_name == "":
return False, get_data_error_result(message="Dataset name can't be empty.")
if len(dataset_name.encode("utf-8")) > DATASET_NAME_LIMIT:
return False, get_data_error_result(message=f"Dataset name length is {len(dataset_name)} which is larger than {DATASET_NAME_LIMIT}")
# Deduplicate name within tenant
dataset_name = duplicate_name(

View File

@@ -441,3 +441,46 @@ class LLMBundle(LLM4Tenant):
generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
generation.end()
return
async def async_chat_streamly_delta(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
total_tokens = 0
ans = ""
if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_streamly_with_tools"):
stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None)
elif hasattr(self.mdl, "async_chat_streamly"):
stream_fn = getattr(self.mdl, "async_chat_streamly", None)
else:
raise RuntimeError(f"Model {self.mdl} does not implement async_chat or async_chat_with_tools")
generation = None
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
if stream_fn:
chat_partial = partial(stream_fn, system, history, gen_conf)
use_kwargs = self._clean_param(chat_partial, **kwargs)
try:
async for txt in chat_partial(**use_kwargs):
if isinstance(txt, int):
total_tokens = txt
break
if txt.endswith("</think>"):
ans = ans[: -len("</think>")]
if not self.verbose_tool_use:
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
ans += txt
yield txt
except Exception as e:
if generation:
generation.update(output={"error": str(e)})
generation.end()
raise
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
if generation:
generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
generation.end()
return

View File

@@ -167,4 +167,4 @@ class MemoryService(CommonService):
@classmethod
@DB.connection_context()
def delete_memory(cls, memory_id: str):
return cls.delete_by_id(memory_id)

View File

@@ -0,0 +1,44 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datetime import datetime
from common.time_utils import current_timestamp, datetime_format
from api.db.db_models import DB
from api.db.db_models import SystemSettings
from api.db.services.common_service import CommonService
class SystemSettingsService(CommonService):
model = SystemSettings
@classmethod
@DB.connection_context()
def get_by_name(cls, name):
objs = cls.model.select().where(cls.model.name.startswith(name))
return objs
@classmethod
@DB.connection_context()
def update_by_name(cls, name, obj):
obj["update_time"] = current_timestamp()
obj["update_date"] = datetime_format(datetime.now())
cls.model.update(obj).where(cls.model.name.startswith(name)).execute()
return SystemSettings(**obj)
@classmethod
@DB.connection_context()
def get_record_count(cls):
count = cls.model.select().count()
return count
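Note both get_by_name and update_by_name match rows by prefix (name.startswith), so a short key can hit several settings at once. A usage sketch, assuming a SystemSettings row exposes a name column and a hypothetical value column:

rows = SystemSettingsService.get_by_name("smtp")  # matches "smtp_host", "smtp_port", ...
for row in rows:
    print(row.name)

# "value" is an assumed column name for illustration only.
SystemSettingsService.update_by_name("smtp_host", {"value": "mail.example.com"})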

View File

@@ -121,13 +121,6 @@ class TaskService(CommonService):
.where(cls.model.id == task_id)
)
docs = list(docs.dicts())
if not docs:
return None
@@ -179,6 +172,40 @@ class TaskService(CommonService):
return None
return tasks
@classmethod
@DB.connection_context()
def get_tasks_progress_by_doc_ids(cls, doc_ids: list[str]):
"""Retrieve all tasks associated with specific documents.
This method fetches all processing tasks for given document ids, ordered by
creation time. It includes task progress and chunk information.
Args:
doc_ids (list[str]): The unique identifiers of the documents.
Returns:
list[dict]: List of task dictionaries containing task details.
Returns None if no tasks are found.
"""
fields = [
cls.model.id,
cls.model.doc_id,
cls.model.from_page,
cls.model.progress,
cls.model.progress_msg,
cls.model.digest,
cls.model.chunk_ids,
cls.model.create_time
]
tasks = (
cls.model.select(*fields).order_by(cls.model.create_time.desc())
.where(cls.model.doc_id.in_(doc_ids))
)
tasks = list(tasks.dicts())
if not tasks:
return None
return tasks
@classmethod
@DB.connection_context()
def update_chunk_ids(cls, id: str, chunk_ids: str):
@@ -495,6 +522,7 @@ def cancel_all_task_of(doc_id):
def has_canceled(task_id):
try:
if REDIS_CONN.get(f"{task_id}-cancel"):
logging.info(f"Task: {task_id} has been canceled")
return True
except Exception as e:
logging.exception(e)

View File

@@ -19,7 +19,7 @@ import logging
from peewee import IntegrityError
from langfuse import Langfuse
from common import settings
from common.constants import MINERU_DEFAULT_CONFIG, MINERU_ENV_KEYS, PADDLEOCR_DEFAULT_CONFIG, PADDLEOCR_ENV_KEYS, LLMType
from api.db.db_models import DB, LLMFactories, TenantLLM
from api.db.services.common_service import CommonService
from api.db.services.langfuse_service import TenantLangfuseService
@@ -60,10 +60,8 @@ class TenantLLMService(CommonService):
@classmethod
@DB.connection_context()
def get_my_llms(cls, tenant_id):
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens, cls.model.status]
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
return list(objs)
@@ -90,6 +88,7 @@ class TenantLLMService(CommonService):
@DB.connection_context()
def get_model_config(cls, tenant_id, llm_type, llm_name=None):
from api.db.services.llm_service import LLMService
e, tenant = TenantService.get_by_id(tenant_id)
if not e:
raise LookupError("Tenant not found")
@@ -119,9 +118,9 @@ class TenantLLMService(CommonService):
model_config = cls.get_api_key(tenant_id, mdlnm)
if model_config:
model_config = model_config.to_dict()
elif llm_type == LLMType.EMBEDDING and fid == "Builtin" and "tei-" in os.getenv("COMPOSE_PROFILES", "") and mdlnm == os.getenv("TEI_MODEL", ""):
embedding_cfg = settings.EMBEDDING_CFG
model_config = {"llm_factory": "Builtin", "api_key": embedding_cfg["api_key"], "llm_name": mdlnm, "api_base": embedding_cfg["base_url"]}
else:
raise LookupError(f"Model({mdlnm}@{fid}) not authorized")
@@ -140,33 +139,27 @@ class TenantLLMService(CommonService):
if llm_type == LLMType.EMBEDDING.value:
if model_config["llm_factory"] not in EmbeddingModel:
return None
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
elif llm_type == LLMType.RERANK:
if model_config["llm_factory"] not in RerankModel:
return None
return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
elif llm_type == LLMType.IMAGE2TEXT.value:
if model_config["llm_factory"] not in CvModel:
return None
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)
elif llm_type == LLMType.CHAT.value:
if model_config["llm_factory"] not in ChatModel:
return None
return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs)
elif llm_type == LLMType.SPEECH2TEXT:
if model_config["llm_factory"] not in Seq2txtModel:
return None
return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
elif llm_type == LLMType.TTS:
if model_config["llm_factory"] not in TTSModel:
return None
@@ -216,14 +209,11 @@ class TenantLLMService(CommonService):
try:
num = (
cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)
.where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True)
.execute()
)
except Exception:
logging.exception("TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name)
return 0
return num
@@ -231,9 +221,7 @@ class TenantLLMService(CommonService):
@classmethod
@DB.connection_context()
def get_openai_models(cls):
objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
return list(objs)
@classmethod
@@ -298,6 +286,68 @@ class TenantLLMService(CommonService):
idx += 1
continue
@classmethod
def _collect_paddleocr_env_config(cls) -> dict | None:
cfg = PADDLEOCR_DEFAULT_CONFIG
found = False
for key in PADDLEOCR_ENV_KEYS:
val = os.environ.get(key)
if val:
found = True
cfg[key] = val
return cfg if found else None
@classmethod
@DB.connection_context()
def ensure_paddleocr_from_env(cls, tenant_id: str) -> str | None:
"""
Ensure a PaddleOCR model exists for the tenant if env variables are present.
Return the existing or newly created llm_name, or None if env not set.
"""
cfg = cls._collect_paddleocr_env_config()
if not cfg:
return None
saved_paddleocr_models = cls.query(tenant_id=tenant_id, llm_factory="PaddleOCR", model_type=LLMType.OCR.value)
def _parse_api_key(raw: str) -> dict:
try:
return json.loads(raw or "{}")
except Exception:
return {}
for item in saved_paddleocr_models:
api_cfg = _parse_api_key(item.api_key)
normalized = {k: api_cfg.get(k, PADDLEOCR_DEFAULT_CONFIG.get(k)) for k in PADDLEOCR_ENV_KEYS}
if normalized == cfg:
return item.llm_name
used_names = {item.llm_name for item in saved_paddleocr_models}
idx = 1
base_name = "paddleocr-from-env"
while True:
candidate = f"{base_name}-{idx}"
if candidate in used_names:
idx += 1
continue
try:
cls.save(
tenant_id=tenant_id,
llm_factory="PaddleOCR",
llm_name=candidate,
model_type=LLMType.OCR.value,
api_key=json.dumps(cfg),
api_base="",
max_tokens=0,
)
return candidate
except IntegrityError:
logging.warning("PaddleOCR env model %s already exists for tenant %s, retry with next name", candidate, tenant_id)
used_names.add(candidate)
idx += 1
continue
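ensure_paddleocr_from_env collects PADDLEOCR_* environment variables, normalizes each saved api_key blob against the defaults, and reuses an existing llm_name on a match. A standalone sketch of that collect-and-compare step, with the constants from common.constants inlined and sample values that are illustrative only:

import os

PADDLEOCR_ENV_KEYS = ["PADDLEOCR_API_URL", "PADDLEOCR_ACCESS_TOKEN", "PADDLEOCR_ALGORITHM"]
PADDLEOCR_DEFAULT_CONFIG = {"PADDLEOCR_API_URL": "", "PADDLEOCR_ACCESS_TOKEN": None, "PADDLEOCR_ALGORITHM": "PaddleOCR-VL"}

os.environ["PADDLEOCR_API_URL"] = "http://ocr.local:8080"

cfg = dict(PADDLEOCR_DEFAULT_CONFIG)   # copy, so the module-level default stays pristine
found = False
for key in PADDLEOCR_ENV_KEYS:
    if os.environ.get(key):
        found = True
        cfg[key] = os.environ[key]

saved = {"PADDLEOCR_API_URL": "http://ocr.local:8080"}  # a stored api_key blob
normalized = {k: saved.get(k, PADDLEOCR_DEFAULT_CONFIG.get(k)) for k in PADDLEOCR_ENV_KEYS}
print(found, normalized == cfg)  # True True -> reuse the existing llm_name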
@classmethod
@DB.connection_context()
def delete_by_tenant_id(cls, tenant_id):
@@ -306,6 +356,7 @@ class TenantLLMService(CommonService):
@staticmethod
def llm_id2llm_type(llm_id: str) -> str | None:
from api.db.services.llm_service import LLMService
llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
llm_factories = settings.FACTORY_LLM_INFOS
for llm_factory in llm_factories:
@@ -340,9 +391,12 @@ class LLM4Tenant:
langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id)
self.langfuse = None
if langfuse_keys:
langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
try:
if langfuse.auth_check():
self.langfuse = langfuse
trace_id = self.langfuse.create_trace_id()
self.trace_context = {"trace_id": trace_id}
except Exception:
# Skip langfuse tracing if connection fails
pass

View File

@@ -164,7 +164,7 @@ class UserService(CommonService):
@classmethod
@DB.connection_context()
def get_all_users(cls):
users = cls.model.select().order_by(cls.model.email)
return list(users)

View File

@@ -18,8 +18,8 @@
# from beartype.claw import beartype_all # <-- you didn't sign up for this
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
import time
start_ts = time.time()
import logging
import os
@@ -40,6 +40,8 @@ from api.db.init_data import init_web_data, init_superuser
from common.versions import get_ragflow_version
from common.config_utils import show_configs
from common.mcp_tool_call_conn import shutdown_all_mcp_sessions
from common.log_utils import init_root_logger
from plugin import GlobalPluginManager
from rag.utils.redis_conn import RedisDistributedLock
stop_event = threading.Event()
@@ -145,7 +147,7 @@ if __name__ == '__main__':
# start http server
try:
logging.info(f"RAGFlow server is ready after {time.time() - start_ts}s initialization.")
app.run(host=settings.HOST_IP, port=settings.HOST_PORT)
except Exception:
traceback.print_exc()

View File

@@ -29,8 +29,15 @@ import requests
from quart import (
Response,
jsonify,
request,
has_app_context,
)
from werkzeug.exceptions import BadRequest as WerkzeugBadRequest
try:
from quart.exceptions import BadRequest as QuartBadRequest
except ImportError: # pragma: no cover - optional dependency
QuartBadRequest = None
from peewee import OperationalError
@@ -42,41 +49,45 @@ from api.db.services.tenant_llm_service import LLMFactoriesService
from common.connection_utils import timeout
from common.constants import RetCode
from common import settings
from common.misc_utils import thread_pool_exec
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
def _safe_jsonify(payload: dict):
if has_app_context():
return jsonify(payload)
return payload
async def _coerce_request_data() -> dict:
"""Fetch JSON body with sane defaults; fallback to form data."""
if hasattr(request, "_cached_payload"):
return request._cached_payload
payload: Any = None
body_bytes = await request.get_data()
has_body = bool(body_bytes)
content_type = (request.content_type or "").lower()
is_json = content_type.startswith("application/json")
if not has_body:
payload = {}
elif is_json:
payload = await request.get_json(force=False, silent=False)
if isinstance(payload, dict):
payload = payload or {}
elif isinstance(payload, str):
raise AttributeError("'str' object has no attribute 'get'")
else:
raise TypeError("JSON payload must be an object.")
else:
form = await request.form
payload = form.to_dict() if form else None
if payload is None:
raise TypeError("Request body is not a valid form payload.")
request._cached_payload = payload
return payload
async def get_request_json():
return await _coerce_request_data()
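The reworked coercion reads the body once, caches the parsed payload on the request object, parses JSON strictly only for an application/json content type, and falls back to form data otherwise. A behavior sketch using Quart's test client, assuming this module's get_request_json is in scope; the route and bodies are illustrative:

import asyncio
from quart import Quart

app = Quart(__name__)

@app.post("/echo")
async def echo():
    data = await get_request_json()  # cached after the first call in a request
    return {"got": data}

async def main():
    client = app.test_client()
    r = await client.post("/echo", json={"a": 1})
    print(await r.get_json())  # {"got": {"a": 1}}
    r = await client.post("/echo", form={"b": "2"})
    print(await r.get_json())  # {"got": {"b": "2"}}

asyncio.run(main())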
@@ -115,7 +126,7 @@ def get_data_error_result(code=RetCode.DATA_ERROR, message="Sorry! Data missing!
continue
else:
response[key] = value
return _safe_jsonify(response)
def server_error_response(e):
@@ -124,16 +135,12 @@ def server_error_response(e):
try:
msg = repr(e).lower()
if getattr(e, "code", None) == 401 or ("unauthorized" in msg) or ("401" in msg):
resp = get_json_result(code=RetCode.UNAUTHORIZED, message="Unauthorized")
resp.status_code = RetCode.UNAUTHORIZED
return resp
except Exception as ex:
logging.warning(f"error checking authorization: {ex}")
if repr(e).find("index_not_found_exception") >= 0:
return get_json_result(code=RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")
@@ -168,7 +175,17 @@ def validate_request(*args, **kwargs):
def wrapper(func):
@wraps(func)
async def decorated_function(*_args, **_kwargs):
exception_types = (AttributeError, TypeError, WerkzeugBadRequest)
if QuartBadRequest is not None:
exception_types = exception_types + (QuartBadRequest,)
if args or kwargs:
try:
input_arguments = await _coerce_request_data()
except exception_types:
input_arguments = {}
else:
input_arguments = await _coerce_request_data()
errs = process_args(input_arguments)
if errs:
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=errs)
if inspect.iscoroutinefunction(func):
@@ -215,7 +232,7 @@ def active_required(func):
def get_json_result(code: RetCode = RetCode.SUCCESS, message="success", data=None):
response = {"code": code, "message": message, "data": data}
return _safe_jsonify(response)
def apikey_required(func):
@@ -236,16 +253,16 @@ def apikey_required(func):
def build_error_result(code=RetCode.FORBIDDEN, message="success"):
response = {"code": code, "message": message}
response = _safe_jsonify(response)
if hasattr(response, "status_code"):
response.status_code = code
return response
def construct_json_result(code: RetCode = RetCode.SUCCESS, message="success", data=None):
if data is None:
return _safe_jsonify({"code": code, "message": message})
return _safe_jsonify({"code": code, "message": message, "data": data})
def token_required(func):
@@ -304,7 +321,7 @@ def get_result(code=RetCode.SUCCESS, message="", data=None, total=None):
else:
response["message"] = message or "Error"
return _safe_jsonify(response)
def get_error_data_result(
@@ -318,7 +335,7 @@ def get_error_data_result(
continue
else:
response[key] = value
return _safe_jsonify(response)
def get_error_argument_result(message="Invalid arguments"):
@@ -683,7 +700,7 @@ async def is_strong_enough(chat_model, embedding_model):
nonlocal chat_model, embedding_model
if embedding_model:
await asyncio.wait_for(
thread_pool_exec(embedding_model.encode, ["Are you strong enough!?"]),
timeout=10
)

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import xxhash
def string_to_bytes(string):
return string if isinstance(
@@ -22,3 +24,6 @@ def string_to_bytes(string):
def bytes_to_string(byte):
return byte.decode(encoding="utf-8")
# 128 bit = 32 character
def hash128(data: str) -> str:
return xxhash.xxh128(data).hexdigest()
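A quick check of the new helper: xxh128 yields a 32-hex-character digest, as the comment above states.

import xxhash

digest = xxhash.xxh128("hello world").hexdigest()
print(len(digest), digest)  # 32 <hex digest>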

View File

@@ -24,7 +24,7 @@ from common.file_utils import get_project_base_directory
def crypt(line):
"""
decrypt(crypt(input_string)) == base64(input_string), which frontend and ragflow_cli use.
"""
file_path = os.path.join(get_project_base_directory(), "conf", "public.pem")
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")

View File

@@ -82,6 +82,8 @@ async def validate_and_parse_json_request(request: Request, validator: type[Base
2. Extra fields added via `extras` parameter are automatically removed
from the final output after validation
"""
if request.mimetype != "application/json":
return None, f"Unsupported content type: Expected application/json, got {request.content_type}"
try:
payload = await request.get_json() or {}
except UnsupportedMediaType:

View File

@@ -86,6 +86,9 @@ CONTENT_TYPE_MAP = {
"ico": "image/x-icon",
"avif": "image/avif",
"heic": "image/heic",
# PPTX
"ppt": "application/vnd.ms-powerpoint",
"pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
}

View File

@@ -20,6 +20,7 @@ from strenum import StrEnum
SERVICE_CONF = "service_conf.yaml"
RAG_FLOW_SERVICE_NAME = "ragflow"
class CustomEnum(Enum):
@classmethod
def valid(cls, value):
@@ -68,13 +69,13 @@ class ActiveEnum(Enum):
class LLMType(StrEnum):
CHAT = "chat"
EMBEDDING = "embedding"
SPEECH2TEXT = "speech2text"
IMAGE2TEXT = "image2text"
RERANK = "rerank"
TTS = "tts"
OCR = "ocr"
class TaskStatus(StrEnum):
@@ -86,8 +87,7 @@ class TaskStatus(StrEnum):
SCHEDULE = "5"
VALID_TASK_STATUS = {TaskStatus.UNSTART, TaskStatus.RUNNING, TaskStatus.CANCEL, TaskStatus.DONE, TaskStatus.FAIL, TaskStatus.SCHEDULE}
class ParserType(StrEnum):
@@ -130,7 +130,12 @@ class FileSource(StrEnum):
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
AIRTABLE = "airtable"
ASANA = "asana"
GITHUB = "github"
GITLAB = "gitlab" GITLAB = "gitlab"
IMAP = "imap"
BITBUCKET = "bitbucket"
ZENDESK = "zendesk"
class PipelineTaskType(StrEnum):
PARSE = "Parse"
@@ -138,17 +143,20 @@ class PipelineTaskType(StrEnum):
RAPTOR = "RAPTOR"
GRAPH_RAG = "GraphRAG"
MINDMAP = "Mindmap"
MEMORY = "Memory"
VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR, PipelineTaskType.GRAPH_RAG, PipelineTaskType.MINDMAP}
class MCPServerType(StrEnum):
SSE = "sse"
STREAMABLE_HTTP = "streamable-http"
VALID_MCP_SERVER_TYPES = {MCPServerType.SSE, MCPServerType.STREAMABLE_HTTP}
class Storage(Enum):
MINIO = 1
AZURE_SPN = 2
@@ -160,10 +168,10 @@ class Storage(Enum):
class MemoryType(Enum):
RAW = 0b0001 # 1 << 0 = 1 (0b00000001)
SEMANTIC = 0b0010 # 1 << 1 = 2 (0b00000010)
EPISODIC = 0b0100 # 1 << 2 = 4 (0b00000100)
PROCEDURAL = 0b1000 # 1 << 3 = 8 (0b00001000)
class MemoryStorageType(StrEnum):
@@ -234,3 +242,10 @@ MINERU_DEFAULT_CONFIG = {
"MINERU_SERVER_URL": "",
"MINERU_DELETE_OUTPUT": 1,
}
PADDLEOCR_ENV_KEYS = ["PADDLEOCR_API_URL", "PADDLEOCR_ACCESS_TOKEN", "PADDLEOCR_ALGORITHM"]
PADDLEOCR_DEFAULT_CONFIG = {
"PADDLEOCR_API_URL": "",
"PADDLEOCR_ACCESS_TOKEN": None,
"PADDLEOCR_ALGORITHM": "PaddleOCR-VL",
}

View File

@@ -34,10 +34,11 @@ from .google_drive.connector import GoogleDriveConnector
from .jira.connector import JiraConnector
from .sharepoint_connector import SharePointConnector
from .teams_connector import TeamsConnector
from .moodle_connector import MoodleConnector
from .airtable_connector import AirtableConnector
from .asana_connector import AsanaConnector
from .imap_connector import ImapConnector
from .zendesk_connector import ZendeskConnector
from .imap_connector import ImapConnector
from .zendesk_connector import ZendeskConnector
from .config import BlobType, DocumentSource
from .models import Document, TextSection, ImageSection, BasicExpertInfo
from .exceptions import (
@@ -60,7 +61,6 @@ __all__ = [
"JiraConnector",
"SharePointConnector",
"TeamsConnector",
"MoodleConnector",
"BlobType",
"DocumentSource",
@@ -75,4 +75,6 @@ __all__ = [
"UnexpectedValidationError",
"AirtableConnector",
"AsanaConnector",
"ImapConnector",
"ZendeskConnector",
]

View File

@@ -1,6 +1,6 @@
from datetime import datetime, timezone
import logging
from typing import Any, Generator
import requests
@@ -8,8 +8,8 @@ from pyairtable import Api as AirtableApi
from common.data_source.config import AIRTABLE_CONNECTOR_SIZE_THRESHOLD, INDEX_BATCH_SIZE, DocumentSource
from common.data_source.exceptions import ConnectorMissingCredentialError
from common.data_source.interfaces import LoadConnector, PollConnector
from common.data_source.models import Document, GenerateDocumentsOutput, SecondsSinceUnixEpoch
from common.data_source.utils import extract_size_bytes, get_file_ext
class AirtableClientNotSetUpError(PermissionError):
@@ -19,7 +19,7 @@ class AirtableClientNotSetUpError(PermissionError):
)
class AirtableConnector(LoadConnector, PollConnector):
"""
Lightweight Airtable connector.
@@ -75,7 +75,6 @@ class AirtableConnector(LoadConnector):
batch: list[Document] = []
for record in records:
record_id = record.get("id")
fields = record.get("fields", {})
created_time = record.get("createdTime")
@@ -132,6 +131,26 @@ class AirtableConnector(LoadConnector):
if batch:
yield batch
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Generator[list[Document], None, None]:
"""Poll source to get documents"""
start_dt = datetime.fromtimestamp(start, tz=timezone.utc)
end_dt = datetime.fromtimestamp(end, tz=timezone.utc)
for batch in self.load_from_state():
filtered: list[Document] = []
for doc in batch:
if not doc.doc_updated_at:
continue
doc_dt = doc.doc_updated_at.astimezone(timezone.utc)
if start_dt <= doc_dt < end_dt:
filtered.append(doc)
if filtered:
yield filtered
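poll_source simply reuses load_from_state and keeps only documents whose doc_updated_at falls in the half-open UTC window [start, end). A standalone sketch of that filter; Doc is a stand-in carrying only the attribute the filter touches:

from dataclasses import dataclass
from datetime import datetime, timezone

@dataclass
class Doc:
    doc_updated_at: datetime | None

start = datetime(2026, 1, 1, tzinfo=timezone.utc).timestamp()
end = datetime(2026, 2, 1, tzinfo=timezone.utc).timestamp()
start_dt = datetime.fromtimestamp(start, tz=timezone.utc)
end_dt = datetime.fromtimestamp(end, tz=timezone.utc)

batch = [Doc(datetime(2026, 1, 15, tzinfo=timezone.utc)), Doc(None)]
kept = [d for d in batch
        if d.doc_updated_at and start_dt <= d.doc_updated_at.astimezone(timezone.utc) < end_dt]
print(len(kept))  # 1: undated docs are skipped, in-window docs pass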
if __name__ == "__main__": if __name__ == "__main__":
import os import os

View File

View File

@@ -0,0 +1,388 @@
from __future__ import annotations
import copy
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import TYPE_CHECKING
from typing_extensions import override
from common.data_source.config import INDEX_BATCH_SIZE
from common.data_source.config import DocumentSource
from common.data_source.config import REQUEST_TIMEOUT_SECONDS
from common.data_source.exceptions import (
ConnectorMissingCredentialError,
CredentialExpiredError,
InsufficientPermissionsError,
UnexpectedValidationError,
)
from common.data_source.interfaces import CheckpointedConnector
from common.data_source.interfaces import CheckpointOutput
from common.data_source.interfaces import IndexingHeartbeatInterface
from common.data_source.interfaces import SecondsSinceUnixEpoch
from common.data_source.interfaces import SlimConnectorWithPermSync
from common.data_source.models import ConnectorCheckpoint
from common.data_source.models import ConnectorFailure
from common.data_source.models import DocumentFailure
from common.data_source.models import SlimDocument
from common.data_source.bitbucket.utils import (
build_auth_client,
list_repositories,
map_pr_to_document,
paginate,
PR_LIST_RESPONSE_FIELDS,
SLIM_PR_LIST_RESPONSE_FIELDS,
)
if TYPE_CHECKING:
import httpx
class BitbucketConnectorCheckpoint(ConnectorCheckpoint):
"""Checkpoint state for resumable Bitbucket PR indexing.
Fields:
repos_queue: Materialized list of repository slugs to process.
current_repo_index: Index of the repository currently being processed.
next_url: Bitbucket "next" URL for continuing pagination within the current repo.
"""
repos_queue: list[str] = []
current_repo_index: int = 0
next_url: str | None = None
class BitbucketConnector(
CheckpointedConnector[BitbucketConnectorCheckpoint],
SlimConnectorWithPermSync,
):
"""Connector for indexing Bitbucket Cloud pull requests.
Args:
workspace: Bitbucket workspace ID.
repositories: Comma-separated list of repository slugs to index.
projects: Comma-separated list of project keys to index all repositories within.
batch_size: Max number of documents to yield per batch.
"""
def __init__(
self,
workspace: str,
repositories: str | None = None,
projects: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.workspace = workspace
self._repositories = (
[s.strip() for s in repositories.split(",") if s.strip()]
if repositories
else None
)
self._projects: list[str] | None = (
[s.strip() for s in projects.split(",") if s.strip()] if projects else None
)
self.batch_size = batch_size
self.email: str | None = None
self.api_token: str | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
"""Load API token-based credentials.
Expects a dict with keys: `bitbucket_email`, `bitbucket_api_token`.
"""
self.email = credentials.get("bitbucket_email")
self.api_token = credentials.get("bitbucket_api_token")
if not self.email or not self.api_token:
raise ConnectorMissingCredentialError("Bitbucket")
return None
def _client(self) -> httpx.Client:
"""Build an authenticated HTTP client or raise if credentials missing."""
if not self.email or not self.api_token:
raise ConnectorMissingCredentialError("Bitbucket")
return build_auth_client(self.email, self.api_token)
def _iter_pull_requests_for_repo(
self,
client: httpx.Client,
repo_slug: str,
params: dict[str, Any] | None = None,
start_url: str | None = None,
on_page: Callable[[str | None], None] | None = None,
) -> Iterator[dict[str, Any]]:
base = f"https://api.bitbucket.org/2.0/repositories/{self.workspace}/{repo_slug}/pullrequests"
yield from paginate(
client,
base,
params,
start_url=start_url,
on_page=on_page,
)
def _build_params(
self,
fields: str = PR_LIST_RESPONSE_FIELDS,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> dict[str, Any]:
"""Build Bitbucket fetch params.
Always include OPEN, MERGED, and DECLINED PRs. If both ``start`` and
``end`` are provided, apply a single updated_on time window.
"""
def _iso(ts: SecondsSinceUnixEpoch) -> str:
return datetime.fromtimestamp(ts, tz=timezone.utc).isoformat()
def _tc_epoch(
lower_epoch: SecondsSinceUnixEpoch | None,
upper_epoch: SecondsSinceUnixEpoch | None,
) -> str | None:
if lower_epoch is not None and upper_epoch is not None:
lower_iso = _iso(lower_epoch)
upper_iso = _iso(upper_epoch)
return f'(updated_on > "{lower_iso}" AND updated_on <= "{upper_iso}")'
return None
params: dict[str, Any] = {"fields": fields, "pagelen": 50}
time_clause = _tc_epoch(start, end)
q = '(state = "OPEN" OR state = "MERGED" OR state = "DECLINED")'
if time_clause:
q = f"{q} AND {time_clause}"
params["q"] = q
return params
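# Editor's note: an illustration of the query built above. For start=0 and
# end=86400 the resulting params["q"] would be, e.g.:
#   (state = "OPEN" OR state = "MERGED" OR state = "DECLINED")
#   AND (updated_on > "1970-01-01T00:00:00+00:00" AND updated_on <= "1970-01-02T00:00:00+00:00")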
def _iter_target_repositories(self, client: httpx.Client) -> Iterator[str]:
"""Yield repository slugs based on configuration.
Priority:
- repositories list
- projects list (list repos by project key)
- workspace (all repos)
"""
if self._repositories:
for slug in self._repositories:
yield slug
return
if self._projects:
for project_key in self._projects:
for repo in list_repositories(client, self.workspace, project_key):
slug_val = repo.get("slug")
if isinstance(slug_val, str) and slug_val:
yield slug_val
return
for repo in list_repositories(client, self.workspace, None):
slug_val = repo.get("slug")
if isinstance(slug_val, str) and slug_val:
yield slug_val
@override
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: BitbucketConnectorCheckpoint,
) -> CheckpointOutput[BitbucketConnectorCheckpoint]:
"""Resumable PR ingestion across repos and pages within a time window.
Yields Documents (or ConnectorFailure for per-PR mapping failures) and returns
an updated checkpoint that records repo position and next page URL.
"""
new_checkpoint = copy.deepcopy(checkpoint)
with self._client() as client:
# Materialize target repositories once
if not new_checkpoint.repos_queue:
# Deduplicate and sort for deterministic ordering across runs
repos_list = list(self._iter_target_repositories(client))
new_checkpoint.repos_queue = sorted(set(repos_list))
new_checkpoint.current_repo_index = 0
new_checkpoint.next_url = None
repos = new_checkpoint.repos_queue
if not repos or new_checkpoint.current_repo_index >= len(repos):
new_checkpoint.has_more = False
return new_checkpoint
repo_slug = repos[new_checkpoint.current_repo_index]
first_page_params = self._build_params(
fields=PR_LIST_RESPONSE_FIELDS,
start=start,
end=end,
)
def _on_page(next_url: str | None) -> None:
new_checkpoint.next_url = next_url
for pr in self._iter_pull_requests_for_repo(
client,
repo_slug,
params=first_page_params,
start_url=new_checkpoint.next_url,
on_page=_on_page,
):
try:
document = map_pr_to_document(pr, self.workspace, repo_slug)
yield document
except Exception as e:
pr_id = pr.get("id")
pr_link = (
f"https://bitbucket.org/{self.workspace}/{repo_slug}/pull-requests/{pr_id}"
if pr_id is not None
else None
)
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=(
f"{DocumentSource.BITBUCKET.value}:{self.workspace}:{repo_slug}:pr:{pr_id}"
if pr_id is not None
else f"{DocumentSource.BITBUCKET.value}:{self.workspace}:{repo_slug}:pr:unknown"
),
document_link=pr_link,
),
failure_message=f"Failed to process Bitbucket PR: {e}",
exception=e,
)
# Advance to next repository (if any) and set has_more accordingly
new_checkpoint.current_repo_index += 1
new_checkpoint.next_url = None
new_checkpoint.has_more = new_checkpoint.current_repo_index < len(repos)
return new_checkpoint
@override
def build_dummy_checkpoint(self) -> BitbucketConnectorCheckpoint:
"""Create an initial checkpoint with work remaining."""
return BitbucketConnectorCheckpoint(has_more=True)
@override
def validate_checkpoint_json(
self, checkpoint_json: str
) -> BitbucketConnectorCheckpoint:
"""Validate and deserialize a checkpoint instance from JSON."""
return BitbucketConnectorCheckpoint.model_validate_json(checkpoint_json)
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> Iterator[list[SlimDocument]]:
"""Return only document IDs for all existing pull requests."""
batch: list[SlimDocument] = []
params = self._build_params(
fields=SLIM_PR_LIST_RESPONSE_FIELDS,
start=start,
end=end,
)
with self._client() as client:
for slug in self._iter_target_repositories(client):
for pr in self._iter_pull_requests_for_repo(
client, slug, params=params
):
pr_id = pr["id"]
doc_id = f"{DocumentSource.BITBUCKET.value}:{self.workspace}:{slug}:pr:{pr_id}"
batch.append(SlimDocument(id=doc_id))
if len(batch) >= self.batch_size:
yield batch
batch = []
if callback:
if callback.should_stop():
# Note: this is not actually used for permission sync yet, just pruning
raise RuntimeError(
"bitbucket_pr_sync: Stop signal detected"
)
callback.progress("bitbucket_pr_sync", len(batch))
if batch:
yield batch
def validate_connector_settings(self) -> None:
"""Validate Bitbucket credentials and workspace access by probing a lightweight endpoint.
Raises:
CredentialExpiredError: on HTTP 401
InsufficientPermissionsError: on HTTP 403
UnexpectedValidationError: on any other failure
"""
try:
with self._client() as client:
url = f"https://api.bitbucket.org/2.0/repositories/{self.workspace}"
resp = client.get(
url,
params={"pagelen": 1, "fields": "pagelen"},
timeout=REQUEST_TIMEOUT_SECONDS,
)
if resp.status_code == 401:
raise CredentialExpiredError(
"Invalid or expired Bitbucket credentials (HTTP 401)."
)
if resp.status_code == 403:
raise InsufficientPermissionsError(
"Insufficient permissions to access Bitbucket workspace (HTTP 403)."
)
if resp.status_code < 200 or resp.status_code >= 300:
raise UnexpectedValidationError(
f"Unexpected Bitbucket error (status={resp.status_code})."
)
except Exception as e:
# Network or other unexpected errors
if isinstance(
e,
(
CredentialExpiredError,
InsufficientPermissionsError,
UnexpectedValidationError,
ConnectorMissingCredentialError,
),
):
raise
raise UnexpectedValidationError(
f"Unexpected error while validating Bitbucket settings: {e}"
)
if __name__ == "__main__":
bitbucket = BitbucketConnector(
workspace="<YOUR_WORKSPACE>"
)
bitbucket.load_credentials({
"bitbucket_email": "<YOUR_EMAIL>",
"bitbucket_api_token": "<YOUR_API_TOKEN>",
})
bitbucket.validate_connector_settings()
print("Credentials validated successfully.")
start_time = datetime.fromtimestamp(0, tz=timezone.utc)
end_time = datetime.now(timezone.utc)
for doc_batch in bitbucket.retrieve_all_slim_docs_perm_sync(
start=start_time.timestamp(),
end=end_time.timestamp(),
):
for doc in doc_batch:
print(doc)
bitbucket_checkpoint = bitbucket.build_dummy_checkpoint()
while bitbucket_checkpoint.has_more:
gen = bitbucket.load_from_checkpoint(
start=start_time.timestamp(),
end=end_time.timestamp(),
checkpoint=bitbucket_checkpoint,
)
while True:
try:
doc = next(gen)
print(doc)
except StopIteration as e:
bitbucket_checkpoint = e.value
break
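# Editor's note: a minimal sketch (not part of the connector) that factors the
# StopIteration pattern above into a reusable helper; per PEP 380, a
# generator's return value rides on StopIteration.value.
def drain_checkpoint_output(gen):
    """Collect all yielded items plus the returned checkpoint from a CheckpointOutput."""
    items = []
    while True:
        try:
            items.append(next(gen))
        except StopIteration as e:
            return items, e.value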

View File

@ -0,0 +1,288 @@
from __future__ import annotations
import time
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import Any
import httpx
from common.data_source.config import REQUEST_TIMEOUT_SECONDS, DocumentSource
from common.data_source.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from common.data_source.utils import sanitize_filename
from common.data_source.models import BasicExpertInfo, Document
from common.data_source.cross_connector_utils.retry_wrapper import retry_builder
# Fields requested from Bitbucket PR list endpoint to ensure rich PR data
PR_LIST_RESPONSE_FIELDS: str = ",".join(
[
"next",
"page",
"pagelen",
"values.author",
"values.close_source_branch",
"values.closed_by",
"values.comment_count",
"values.created_on",
"values.description",
"values.destination",
"values.draft",
"values.id",
"values.links",
"values.merge_commit",
"values.participants",
"values.reason",
"values.rendered",
"values.reviewers",
"values.source",
"values.state",
"values.summary",
"values.task_count",
"values.title",
"values.type",
"values.updated_on",
]
)
# Minimal fields for slim retrieval (IDs only)
SLIM_PR_LIST_RESPONSE_FIELDS: str = ",".join(
[
"next",
"page",
"pagelen",
"values.id",
]
)
# Minimal fields for repository list calls
REPO_LIST_RESPONSE_FIELDS: str = ",".join(
[
"next",
"page",
"pagelen",
"values.slug",
"values.full_name",
"values.project.key",
]
)
class BitbucketRetriableError(Exception):
"""Raised for retriable Bitbucket conditions (429, 5xx)."""
class BitbucketNonRetriableError(Exception):
"""Raised for non-retriable Bitbucket client errors (4xx except 429)."""
@retry_builder(
tries=6,
delay=1,
backoff=2,
max_delay=30,
exceptions=(BitbucketRetriableError, httpx.RequestError),
)
@rate_limit_builder(max_calls=60, period=60)
def bitbucket_get(
client: httpx.Client, url: str, params: dict[str, Any] | None = None
) -> httpx.Response:
"""Perform a GET against Bitbucket with retry and rate limiting.
Retries on 429 and 5xx responses, and on transport errors. Honors
`Retry-After` header for 429 when present by sleeping before retrying.
"""
try:
response = client.get(url, params=params, timeout=REQUEST_TIMEOUT_SECONDS)
except httpx.RequestError:
# Allow retry_builder to handle retries of transport errors
raise
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
status = e.response.status_code if e.response is not None else None
if status == 429:
retry_after = e.response.headers.get("Retry-After") if e.response else None
if retry_after is not None:
try:
time.sleep(int(retry_after))
except (TypeError, ValueError):
pass
raise BitbucketRetriableError("Bitbucket rate limit exceeded (429)") from e
if status is not None and 500 <= status < 600:
raise BitbucketRetriableError(f"Bitbucket server error: {status}") from e
if status is not None and 400 <= status < 500:
raise BitbucketNonRetriableError(f"Bitbucket client error: {status}") from e
# Unknown status, propagate
raise
return response
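# Editor's note: decorator order above matters. retry_builder is the outer
# decorator, so every retry attempt passes back through the rate limiter, and
# any Retry-After sleep for a 429 happens in addition to the retry backoff.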
def build_auth_client(email: str, api_token: str) -> httpx.Client:
"""Create an authenticated httpx client for Bitbucket Cloud API."""
return httpx.Client(auth=(email, api_token), http2=True)
def paginate(
client: httpx.Client,
url: str,
params: dict[str, Any] | None = None,
start_url: str | None = None,
on_page: Callable[[str | None], None] | None = None,
) -> Iterator[dict[str, Any]]:
"""Iterate over paginated Bitbucket API responses yielding individual values.
Args:
client: Authenticated HTTP client.
url: Base collection URL (first page when start_url is None).
params: Query params for the first page.
start_url: If provided, start from this absolute URL (ignores params).
on_page: Optional callback invoked after each page with the next page URL.
"""
next_url = start_url or url
# If resuming from a next URL, do not pass params again
query = params.copy() if params else None
query = None if start_url else query
while next_url:
resp = bitbucket_get(client, next_url, params=query)
data = resp.json()
values = data.get("values", [])
for item in values:
yield item
next_url = data.get("next")
if on_page is not None:
on_page(next_url)
# only include params on first call, next_url will contain all necessary params
query = None
def list_repositories(
client: httpx.Client, workspace: str, project_key: str | None = None
) -> Iterator[dict[str, Any]]:
"""List repositories in a workspace, optionally filtered by project key."""
base_url = f"https://api.bitbucket.org/2.0/repositories/{workspace}"
params: dict[str, Any] = {
"fields": REPO_LIST_RESPONSE_FIELDS,
"pagelen": 100,
# Ensure deterministic ordering
"sort": "full_name",
}
if project_key:
params["q"] = f'project.key="{project_key}"'
yield from paginate(client, base_url, params)
def map_pr_to_document(pr: dict[str, Any], workspace: str, repo_slug: str) -> Document:
"""Map a Bitbucket pull request JSON to Onyx Document."""
pr_id = pr["id"]
title = pr.get("title") or f"PR {pr_id}"
description = pr.get("description") or ""
state = pr.get("state")
draft = pr.get("draft", False)
author = pr.get("author", {})
reviewers = pr.get("reviewers", [])
participants = pr.get("participants", [])
link = pr.get("links", {}).get("html", {}).get("href") or (
f"https://bitbucket.org/{workspace}/{repo_slug}/pull-requests/{pr_id}"
)
created_on = pr.get("created_on")
updated_on = pr.get("updated_on")
updated_dt = (
datetime.fromisoformat(updated_on.replace("Z", "+00:00")).astimezone(
timezone.utc
)
if isinstance(updated_on, str)
else None
)
source_branch = pr.get("source", {}).get("branch", {}).get("name", "")
destination_branch = pr.get("destination", {}).get("branch", {}).get("name", "")
approved_by = [
_get_user_name(p.get("user", {})) for p in participants if p.get("approved")
]
primary_owner = None
if author:
primary_owner = BasicExpertInfo(
display_name=_get_user_name(author),
)
# secondary_owners = [
# BasicExpertInfo(display_name=_get_user_name(r)) for r in reviewers
# ] or None
reviewer_names = [_get_user_name(r) for r in reviewers]
# Create a concise summary of key PR info
created_date = created_on.split("T")[0] if created_on else "N/A"
updated_date = updated_on.split("T")[0] if updated_on else "N/A"
content_text = (
"Pull Request Information:\n"
f"- Pull Request ID: {pr_id}\n"
f"- Title: {title}\n"
f"- State: {state or 'N/A'} {'(Draft)' if draft else ''}\n"
)
if state == "DECLINED":
content_text += f"- Reason: {pr.get('reason', 'N/A')}\n"
content_text += (
f"- Author: {_get_user_name(author) if author else 'N/A'}\n"
f"- Reviewers: {', '.join(reviewer_names) if reviewer_names else 'N/A'}\n"
f"- Branch: {source_branch} -> {destination_branch}\n"
f"- Created: {created_date}\n"
f"- Updated: {updated_date}"
)
if description:
content_text += f"\n\nDescription:\n{description}"
metadata: dict[str, str | list[str]] = {
"object_type": "PullRequest",
"workspace": workspace,
"repository": repo_slug,
"pr_key": f"{workspace}/{repo_slug}#{pr_id}",
"id": str(pr_id),
"title": title,
"state": state or "",
"draft": str(bool(draft)),
"link": link,
"author": _get_user_name(author) if author else "",
"reviewers": reviewer_names,
"approved_by": approved_by,
"comment_count": str(pr.get("comment_count", "")),
"task_count": str(pr.get("task_count", "")),
"created_on": created_on or "",
"updated_on": updated_on or "",
"source_branch": source_branch,
"destination_branch": destination_branch,
"closed_by": (
_get_user_name(pr.get("closed_by", {})) if pr.get("closed_by") else ""
),
"close_source_branch": str(bool(pr.get("close_source_branch", False))),
}
name = sanitize_filename(title, "md")
return Document(
id=f"{DocumentSource.BITBUCKET.value}:{workspace}:{repo_slug}:pr:{pr_id}",
blob=content_text.encode("utf-8"),
source=DocumentSource.BITBUCKET,
extension=".md",
semantic_identifier=f"#{pr_id}: {name}",
size_bytes=len(content_text.encode("utf-8")),
doc_updated_at=updated_dt,
primary_owners=[primary_owner] if primary_owner else None,
# secondary_owners=secondary_owners,
metadata=metadata,
)
def _get_user_name(user: dict[str, Any]) -> str:
return user.get("display_name") or user.get("nickname") or "unknown"
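# Editor's note: a hypothetical usage sketch of the helpers above; the email,
# token, and workspace placeholders are illustrative, not real values.
if __name__ == "__main__":
    with build_auth_client("<YOUR_EMAIL>", "<YOUR_API_TOKEN>") as client:
        # paginate() inside list_repositories follows Bitbucket's "next" links
        for repo in list_repositories(client, "<YOUR_WORKSPACE>"):
            print(repo["full_name"])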

View File

@ -13,6 +13,9 @@ def get_current_tz_offset() -> int:
return round(time_diff.total_seconds() / 3600)
# Default request timeout, mostly used by connectors
REQUEST_TIMEOUT_SECONDS = int(os.environ.get("REQUEST_TIMEOUT_SECONDS") or 60)
ONE_MINUTE = 60
ONE_HOUR = 3600
ONE_DAY = ONE_HOUR * 24
@ -57,6 +60,10 @@ class DocumentSource(str, Enum):
ASANA = "asana" ASANA = "asana"
GITHUB = "github" GITHUB = "github"
GITLAB = "gitlab" GITLAB = "gitlab"
IMAP = "imap"
BITBUCKET = "bitbucket"
ZENDESK = "zendesk"
class FileOrigin(str, Enum):
"""File origins"""
@ -234,6 +241,8 @@ _REPLACEMENT_EXPANSIONS = "body.view.value"
BOX_WEB_OAUTH_REDIRECT_URI = os.environ.get("BOX_WEB_OAUTH_REDIRECT_URI", "http://localhost:9380/v1/connector/box/oauth/web/callback")
GITHUB_CONNECTOR_BASE_URL = os.environ.get("GITHUB_CONNECTOR_BASE_URL") or None
class HtmlBasedConnectorTransformLinksStrategy(str, Enum):
# remove links entirely
STRIP = "strip"
@ -263,6 +272,14 @@ ASANA_CONNECTOR_SIZE_THRESHOLD = int(
os.environ.get("ASANA_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
)
IMAP_CONNECTOR_SIZE_THRESHOLD = int(
os.environ.get("IMAP_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
)
ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS = os.environ.get(
"ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS", ""
).split(",")
_USER_NOT_FOUND = "Unknown Confluence User"
_COMMENT_EXPANSION_FIELDS = ["body.storage.value"]

View File

@ -0,0 +1,217 @@
import sys
import time
import logging
from collections.abc import Generator
from datetime import datetime
from typing import Generic
from typing import TypeVar
from common.data_source.interfaces import (
BaseConnector,
CheckpointedConnector,
CheckpointedConnectorWithPermSync,
CheckpointOutput,
LoadConnector,
PollConnector,
)
from common.data_source.models import ConnectorCheckpoint, ConnectorFailure, Document
TimeRange = tuple[datetime, datetime]
CT = TypeVar("CT", bound=ConnectorCheckpoint)
def batched_doc_ids(
checkpoint_connector_generator: CheckpointOutput[CT],
batch_size: int,
) -> Generator[set[str], None, None]:
batch: set[str] = set()
for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()(
checkpoint_connector_generator
):
if document is not None:
batch.add(document.id)
elif (
failure and failure.failed_document and failure.failed_document.document_id
):
batch.add(failure.failed_document.document_id)
if len(batch) >= batch_size:
yield batch
batch = set()
if len(batch) > 0:
yield batch
class CheckpointOutputWrapper(Generic[CT]):
"""
Wraps a CheckpointOutput generator to give things back in a more digestible format,
specifically for Document outputs.
The connector format is easier for the connector implementor (e.g. it enforces exactly
one new checkpoint is returned AND that the checkpoint is at the end), thus the different
formats.
"""
def __init__(self) -> None:
self.next_checkpoint: CT | None = None
def __call__(
self,
checkpoint_connector_generator: CheckpointOutput[CT],
) -> Generator[
tuple[Document | None, ConnectorFailure | None, CT | None],
None,
None,
]:
# grabs the final return value and stores it in the `next_checkpoint` variable
def _inner_wrapper(
checkpoint_connector_generator: CheckpointOutput[CT],
) -> CheckpointOutput[CT]:
self.next_checkpoint = yield from checkpoint_connector_generator
return self.next_checkpoint # not used
for document_or_failure in _inner_wrapper(checkpoint_connector_generator):
if isinstance(document_or_failure, Document):
yield document_or_failure, None, None
elif isinstance(document_or_failure, ConnectorFailure):
yield None, document_or_failure, None
else:
raise ValueError(
f"Invalid document_or_failure type: {type(document_or_failure)}"
)
if self.next_checkpoint is None:
raise RuntimeError(
"Checkpoint is None. This should never happen - the connector should always return a checkpoint."
)
yield None, None, self.next_checkpoint
class ConnectorRunner(Generic[CT]):
"""
Handles:
- Batching
- Additional exception logging
- Combining different connector types to a single interface
"""
def __init__(
self,
connector: BaseConnector,
batch_size: int,
# cannot be True for non-checkpointed connectors
include_permissions: bool,
time_range: TimeRange | None = None,
):
if not isinstance(connector, CheckpointedConnector) and include_permissions:
raise ValueError(
"include_permissions cannot be True for non-checkpointed connectors"
)
self.connector = connector
self.time_range = time_range
self.batch_size = batch_size
self.include_permissions = include_permissions
self.doc_batch: list[Document] = []
def run(self, checkpoint: CT) -> Generator[
tuple[list[Document] | None, ConnectorFailure | None, CT | None],
None,
None,
]:
"""Adds additional exception logging to the connector."""
try:
if isinstance(self.connector, CheckpointedConnector):
if self.time_range is None:
raise ValueError("time_range is required for CheckpointedConnector")
start = time.monotonic()
if self.include_permissions:
if not isinstance(
self.connector, CheckpointedConnectorWithPermSync
):
raise ValueError(
"Connector does not support permission syncing"
)
load_from_checkpoint = (
self.connector.load_from_checkpoint_with_perm_sync
)
else:
load_from_checkpoint = self.connector.load_from_checkpoint
checkpoint_connector_generator = load_from_checkpoint(
start=self.time_range[0].timestamp(),
end=self.time_range[1].timestamp(),
checkpoint=checkpoint,
)
next_checkpoint: CT | None = None
# this is guaranteed to always run at least once with next_checkpoint being non-None
for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()(
checkpoint_connector_generator
):
if document is not None and isinstance(document, Document):
self.doc_batch.append(document)
if failure is not None:
yield None, failure, None
if len(self.doc_batch) >= self.batch_size:
yield self.doc_batch, None, None
self.doc_batch = []
# yield remaining documents
if len(self.doc_batch) > 0:
yield self.doc_batch, None, None
self.doc_batch = []
yield None, None, next_checkpoint
logging.debug(
f"Connector took {time.monotonic() - start} seconds to get to the next checkpoint."
)
else:
finished_checkpoint = self.connector.build_dummy_checkpoint()
finished_checkpoint.has_more = False
if isinstance(self.connector, PollConnector):
if self.time_range is None:
raise ValueError("time_range is required for PollConnector")
for document_batch in self.connector.poll_source(
start=self.time_range[0].timestamp(),
end=self.time_range[1].timestamp(),
):
yield document_batch, None, None
yield None, None, finished_checkpoint
elif isinstance(self.connector, LoadConnector):
for document_batch in self.connector.load_from_state():
yield document_batch, None, None
yield None, None, finished_checkpoint
else:
raise ValueError(f"Invalid connector. type: {type(self.connector)}")
except Exception:
exc_type, _, exc_traceback = sys.exc_info()
# Traverse the traceback to find the last frame where the exception was raised
tb = exc_traceback
if tb is None:
logging.error("No traceback found for exception")
raise
while tb.tb_next:
tb = tb.tb_next # Move to the next frame in the traceback
# Get the local variables from the frame where the exception occurred
local_vars = tb.tb_frame.f_locals
local_vars_str = "\n".join(
f"{key}: {value}" for key, value in local_vars.items()
)
logging.error(
f"Error in connector. type: {exc_type};\n"
f"local_vars below -> \n{local_vars_str[:1024]}"
)
raise
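# Editor's note: a hypothetical driver for batched_doc_ids above; `connector`
# stands in for any CheckpointedConnector and `checkpoint` for its current
# checkpoint state (both assumptions, not part of this module).
def collect_doc_ids(connector, checkpoint, start, end, batch_size=100):
    gen = connector.load_from_checkpoint(start=start, end=end, checkpoint=checkpoint)
    seen: set[str] = set()
    for id_batch in batched_doc_ids(gen, batch_size=batch_size):
        seen.update(id_batch)
    return seen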

View File

@ -0,0 +1,126 @@
import time
import logging
from collections.abc import Callable
from functools import wraps
from typing import Any
from typing import cast
from typing import TypeVar
import requests
F = TypeVar("F", bound=Callable[..., Any])
class RateLimitTriedTooManyTimesError(Exception):
pass
class _RateLimitDecorator:
"""Builds a generic wrapper/decorator for calls to external APIs that
prevents making more than `max_calls` requests per `period`
Implementation inspired by the `ratelimit` library:
https://github.com/tomasbasham/ratelimit.
NOTE: this is not thread safe.
"""
def __init__(
self,
max_calls: int,
period: float, # in seconds
sleep_time: float = 2, # in seconds
sleep_backoff: float = 2, # applies exponential backoff
max_num_sleep: int = 0,
):
self.max_calls = max_calls
self.period = period
self.sleep_time = sleep_time
self.sleep_backoff = sleep_backoff
self.max_num_sleep = max_num_sleep
self.call_history: list[float] = []
self.curr_calls = 0
def __call__(self, func: F) -> F:
@wraps(func)
def wrapped_func(*args: list, **kwargs: dict[str, Any]) -> Any:
# cleanup calls which are no longer relevant
self._cleanup()
# check if we've exceeded the rate limit
sleep_cnt = 0
while len(self.call_history) == self.max_calls:
sleep_time = self.sleep_time * (self.sleep_backoff**sleep_cnt)
logging.warning(
f"Rate limit exceeded for function {func.__name__}. "
f"Waiting {sleep_time} seconds before retrying."
)
time.sleep(sleep_time)
sleep_cnt += 1
if self.max_num_sleep != 0 and sleep_cnt >= self.max_num_sleep:
raise RateLimitTriedTooManyTimesError(
f"Exceeded '{self.max_num_sleep}' retries for function '{func.__name__}'"
)
self._cleanup()
# add the current call to the call history
self.call_history.append(time.monotonic())
return func(*args, **kwargs)
return cast(F, wrapped_func)
def _cleanup(self) -> None:
curr_time = time.monotonic()
time_to_expire_before = curr_time - self.period
self.call_history = [
call_time
for call_time in self.call_history
if call_time > time_to_expire_before
]
rate_limit_builder = _RateLimitDecorator
"""If you want to allow the external service to tell you when you've hit the rate limit,
use the following instead"""
R = TypeVar("R", bound=Callable[..., requests.Response])
def wrap_request_to_handle_ratelimiting(
request_fn: R, default_wait_time_sec: int = 30, max_waits: int = 30
) -> R:
def wrapped_request(*args: list, **kwargs: dict[str, Any]) -> requests.Response:
for _ in range(max_waits):
response = request_fn(*args, **kwargs)
if response.status_code == 429:
try:
wait_time = int(
response.headers.get("Retry-After", default_wait_time_sec)
)
except ValueError:
wait_time = default_wait_time_sec
time.sleep(wait_time)
continue
return response
raise RateLimitTriedTooManyTimesError(f"Exceeded '{max_waits}' retries")
return cast(R, wrapped_request)
_rate_limited_get = wrap_request_to_handle_ratelimiting(requests.get)
_rate_limited_post = wrap_request_to_handle_ratelimiting(requests.post)
class _RateLimitedRequest:
get = _rate_limited_get
post = _rate_limited_post
rl_requests = _RateLimitedRequest
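# Editor's note: a minimal sketch of the sliding-window limiter above. With
# max_calls=5 and period=1.0, a sixth call inside any one-second window sleeps
# (with exponential backoff) until an earlier call expires from the window.
@rate_limit_builder(max_calls=5, period=1.0)
def ping(i: int) -> int:
    return i

for n in range(12):
    ping(n)  # calls beyond the limit block inside the wrapper, then proceed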

View File

@ -0,0 +1,88 @@
from collections.abc import Callable
import logging
from logging import Logger
from typing import Any
from typing import cast
from typing import TypeVar
import requests
from retry import retry
from common.data_source.config import REQUEST_TIMEOUT_SECONDS
F = TypeVar("F", bound=Callable[..., Any])
logger = logging.getLogger(__name__)
def retry_builder(
tries: int = 20,
delay: float = 0.1,
max_delay: float | None = 60,
backoff: float = 2,
jitter: tuple[float, float] | float = 1,
exceptions: type[Exception] | tuple[type[Exception], ...] = (Exception,),
) -> Callable[[F], F]:
"""Builds a generic wrapper/decorator for calls to external APIs that
may fail due to rate limiting, flakes, or other reasons. Applies exponential
backoff with jitter to retry the call."""
def retry_with_default(func: F) -> F:
@retry(
tries=tries,
delay=delay,
max_delay=max_delay,
backoff=backoff,
jitter=jitter,
logger=cast(Logger, logger),
exceptions=exceptions,
)
def wrapped_func(*args: list, **kwargs: dict[str, Any]) -> Any:
return func(*args, **kwargs)
return cast(F, wrapped_func)
return retry_with_default
def request_with_retries(
method: str,
url: str,
*,
data: dict[str, Any] | None = None,
headers: dict[str, Any] | None = None,
params: dict[str, Any] | None = None,
timeout: int = REQUEST_TIMEOUT_SECONDS,
stream: bool = False,
tries: int = 8,
delay: float = 1,
backoff: float = 2,
) -> requests.Response:
@retry(tries=tries, delay=delay, backoff=backoff, logger=cast(Logger, logger))
def _make_request() -> requests.Response:
response = requests.request(
method=method,
url=url,
data=data,
headers=headers,
params=params,
timeout=timeout,
stream=stream,
)
try:
response.raise_for_status()
except requests.exceptions.HTTPError:
logging.exception(
"Request failed:\n%s",
{
"method": method,
"url": url,
"data": data,
"headers": headers,
"params": params,
"timeout": timeout,
"stream": stream,
},
)
raise
return response
return _make_request()
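# Editor's note: a minimal sketch of retry_builder. The function below fails
# twice with ValueError and is retried with exponential backoff until the
# third attempt succeeds; other exception types would propagate immediately.
attempts = {"n": 0}

@retry_builder(tries=3, delay=0.1, max_delay=1, exceptions=(ValueError,))
def flaky() -> str:
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise ValueError("transient failure")
    return "ok"

assert flaky() == "ok" and attempts["n"] == 3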

View File

View File

@ -0,0 +1,973 @@
import copy
import logging
from collections.abc import Callable
from collections.abc import Generator
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from enum import Enum
from typing import Any
from typing import cast
from github import Github, Auth
from github import RateLimitExceededException
from github import Repository
from github.GithubException import GithubException
from github.Issue import Issue
from github.NamedUser import NamedUser
from github.PaginatedList import PaginatedList
from github.PullRequest import PullRequest
from pydantic import BaseModel
from typing_extensions import override
from common.data_source.utils import sanitize_filename
from common.data_source.config import DocumentSource, GITHUB_CONNECTOR_BASE_URL
from common.data_source.exceptions import (
ConnectorMissingCredentialError,
ConnectorValidationError,
CredentialExpiredError,
InsufficientPermissionsError,
UnexpectedValidationError,
)
from common.data_source.interfaces import CheckpointedConnectorWithPermSyncGH, CheckpointOutput
from common.data_source.models import (
ConnectorCheckpoint,
ConnectorFailure,
Document,
DocumentFailure,
ExternalAccess,
SecondsSinceUnixEpoch,
)
from common.data_source.connector_runner import ConnectorRunner
from .models import SerializedRepository
from .rate_limit_utils import sleep_after_rate_limit_exception
from .utils import deserialize_repository
from .utils import get_external_access_permission
ITEMS_PER_PAGE = 100
CURSOR_LOG_FREQUENCY = 50
_MAX_NUM_RATE_LIMIT_RETRIES = 5
ONE_DAY = timedelta(days=1)
SLIM_BATCH_SIZE = 100
# Cases
# X (from start) standard run, no fallback to cursor-based pagination
# X (from start) standard run errors, fallback to cursor-based pagination
# X error in the middle of a page
# X no errors: run to completion
# X (from checkpoint) standard run, no fallback to cursor-based pagination
# X (from checkpoint) continue from cursor-based pagination
# - retrying
# - no retrying
# things to check:
# checkpoint state on return
# checkpoint progress (no infinite loop)
class DocMetadata(BaseModel):
repo: str
def get_nextUrl_key(pag_list: PaginatedList[PullRequest | Issue]) -> str:
if "_PaginatedList__nextUrl" in pag_list.__dict__:
return "_PaginatedList__nextUrl"
for key in pag_list.__dict__:
if "__nextUrl" in key:
return key
for key in pag_list.__dict__:
if "nextUrl" in key:
return key
return ""
def get_nextUrl(
pag_list: PaginatedList[PullRequest | Issue], nextUrl_key: str
) -> str | None:
return getattr(pag_list, nextUrl_key) if nextUrl_key else None
def set_nextUrl(
pag_list: PaginatedList[PullRequest | Issue], nextUrl_key: str, nextUrl: str
) -> None:
if nextUrl_key:
setattr(pag_list, nextUrl_key, nextUrl)
elif nextUrl:
raise ValueError("Next URL key not found: " + str(pag_list.__dict__))
def _paginate_until_error(
git_objs: Callable[[], PaginatedList[PullRequest | Issue]],
cursor_url: str | None,
prev_num_objs: int,
cursor_url_callback: Callable[[str | None, int], None],
retrying: bool = False,
) -> Generator[PullRequest | Issue, None, None]:
num_objs = prev_num_objs
pag_list = git_objs()
nextUrl_key = get_nextUrl_key(pag_list)
if cursor_url:
set_nextUrl(pag_list, nextUrl_key, cursor_url)
elif retrying:
# if we are retrying, we want to skip the objects retrieved
# over previous calls. Unfortunately, this WILL retrieve all
# pages before the one we are resuming from, so we really
# don't want this case to be hit often
logging.warning(
"Retrying from a previous cursor-based pagination call. "
"This will retrieve all pages before the one we are resuming from, "
"which may take a while and consume many API calls."
)
pag_list = cast(PaginatedList[PullRequest | Issue], pag_list[prev_num_objs:])
num_objs = 0
try:
# this for loop handles cursor-based pagination
for issue_or_pr in pag_list:
num_objs += 1
yield issue_or_pr
# used to store the current cursor url in the checkpoint. This value
# is updated during iteration over pag_list.
cursor_url_callback(get_nextUrl(pag_list, nextUrl_key), num_objs)
if num_objs % CURSOR_LOG_FREQUENCY == 0:
logging.info(
f"Retrieved {num_objs} objects with current cursor url: {get_nextUrl(pag_list, nextUrl_key)}"
)
except Exception as e:
logging.exception(f"Error during cursor-based pagination: {e}")
if num_objs - prev_num_objs > 0:
raise
if get_nextUrl(pag_list, nextUrl_key) is not None and not retrying:
logging.info(
"Assuming that this error is due to cursor "
"expiration because no objects were retrieved. "
"Retrying from the first page."
)
yield from _paginate_until_error(
git_objs, None, prev_num_objs, cursor_url_callback, retrying=True
)
return
# for no cursor url or if we reach this point after a retry, raise the error
raise
def _get_batch_rate_limited(
# We pass in a callable because we want git_objs to produce a fresh
# PaginatedList each time it's called to avoid using the same object for cursor-based pagination
# from a partial offset-based pagination call.
git_objs: Callable[[], PaginatedList],
page_num: int,
cursor_url: str | None,
prev_num_objs: int,
cursor_url_callback: Callable[[str | None, int], None],
github_client: Github,
attempt_num: int = 0,
) -> Generator[PullRequest | Issue, None, None]:
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
raise RuntimeError(
"Re-tried fetching batch too many times. Something is going wrong with fetching objects from Github"
)
try:
if cursor_url:
# when this is set, we are resuming from an earlier
# cursor-based pagination call.
yield from _paginate_until_error(
git_objs, cursor_url, prev_num_objs, cursor_url_callback
)
return
objs = list(git_objs().get_page(page_num))
# fetch all data here to disable lazy loading later
# this is needed to capture the rate limit exception here (if one occurs)
for obj in objs:
if hasattr(obj, "raw_data"):
getattr(obj, "raw_data")
yield from objs
except RateLimitExceededException:
sleep_after_rate_limit_exception(github_client)
yield from _get_batch_rate_limited(
git_objs,
page_num,
cursor_url,
prev_num_objs,
cursor_url_callback,
github_client,
attempt_num + 1,
)
except GithubException as e:
if not (
e.status == 422
and (
"cursor" in (e.message or "")
or "cursor" in (e.data or {}).get("message", "")
)
):
raise
# Fallback to a cursor-based pagination strategy
# This can happen for "large datasets," but as far as we can tell the
# error is not documented anywhere on the web.
# Error message:
# "Pagination with the page parameter is not supported for large datasets,
# please use cursor based pagination (after/before)"
yield from _paginate_until_error(
git_objs, cursor_url, prev_num_objs, cursor_url_callback
)
def _get_userinfo(user: NamedUser) -> dict[str, str]:
def _safe_get(attr_name: str) -> str | None:
try:
return cast(str | None, getattr(user, attr_name))
except GithubException:
logging.debug(f"Error getting {attr_name} for user")
return None
return {
k: v
for k, v in {
"login": _safe_get("login"),
"name": _safe_get("name"),
"email": _safe_get("email"),
}.items()
if v is not None
}
def _convert_pr_to_document(
pull_request: PullRequest, repo_external_access: ExternalAccess | None
) -> Document:
repo_name = pull_request.base.repo.full_name if pull_request.base else ""
doc_metadata = DocMetadata(repo=repo_name)
file_content_byte = pull_request.body.encode('utf-8') if pull_request.body else b""
name = sanitize_filename(pull_request.title, "md")
return Document(
id=pull_request.html_url,
blob=file_content_byte,
source=DocumentSource.GITHUB,
external_access=repo_external_access,
semantic_identifier=f"{pull_request.number}:{name}",
# updated_at is UTC time but is timezone unaware, explicitly add UTC
# as there is logic in indexing to prevent wrong timestamped docs
# due to local time discrepancies with UTC
doc_updated_at=(
pull_request.updated_at.replace(tzinfo=timezone.utc)
if pull_request.updated_at
else None
),
extension=".md",
# this metadata is used in perm sync
size_bytes=len(file_content_byte) if file_content_byte else 0,
primary_owners=[],
doc_metadata=doc_metadata.model_dump(),
metadata={
k: [str(vi) for vi in v] if isinstance(v, list) else str(v)
for k, v in {
"object_type": "PullRequest",
"id": pull_request.number,
"merged": pull_request.merged,
"state": pull_request.state,
"user": _get_userinfo(pull_request.user) if pull_request.user else None,
"assignees": [
_get_userinfo(assignee) for assignee in pull_request.assignees
],
"repo": (
pull_request.base.repo.full_name if pull_request.base else None
),
"num_commits": str(pull_request.commits),
"num_files_changed": str(pull_request.changed_files),
"labels": [label.name for label in pull_request.labels],
"created_at": (
pull_request.created_at.replace(tzinfo=timezone.utc)
if pull_request.created_at
else None
),
"updated_at": (
pull_request.updated_at.replace(tzinfo=timezone.utc)
if pull_request.updated_at
else None
),
"closed_at": (
pull_request.closed_at.replace(tzinfo=timezone.utc)
if pull_request.closed_at
else None
),
"merged_at": (
pull_request.merged_at.replace(tzinfo=timezone.utc)
if pull_request.merged_at
else None
),
"merged_by": (
_get_userinfo(pull_request.merged_by)
if pull_request.merged_by
else None
),
}.items()
if v is not None
},
)
def _fetch_issue_comments(issue: Issue) -> str:
comments = issue.get_comments()
return "\nComment: ".join(comment.body for comment in comments)
def _convert_issue_to_document(
issue: Issue, repo_external_access: ExternalAccess | None
) -> Document:
repo_name = issue.repository.full_name if issue.repository else ""
doc_metadata = DocMetadata(repo=repo_name)
file_content_byte = issue.body.encode('utf-8') if issue.body else b""
name = sanitize_filename(issue.title, "md")
return Document(
id=issue.html_url,
blob=file_content_byte,
source=DocumentSource.GITHUB,
extension=".md",
external_access=repo_external_access,
semantic_identifier=f"{issue.number}:{name}",
# updated_at is UTC time but is timezone unaware
doc_updated_at=issue.updated_at.replace(tzinfo=timezone.utc),
# this metadata is used in perm sync
doc_metadata=doc_metadata.model_dump(),
size_bytes=len(file_content_byte) if file_content_byte else 0,
primary_owners=[_get_userinfo(issue.user) if issue.user else None],
metadata={
k: [str(vi) for vi in v] if isinstance(v, list) else str(v)
for k, v in {
"object_type": "Issue",
"id": issue.number,
"state": issue.state,
"user": _get_userinfo(issue.user) if issue.user else None,
"assignees": [_get_userinfo(assignee) for assignee in issue.assignees],
"repo": issue.repository.full_name if issue.repository else None,
"labels": [label.name for label in issue.labels],
"created_at": (
issue.created_at.replace(tzinfo=timezone.utc)
if issue.created_at
else None
),
"updated_at": (
issue.updated_at.replace(tzinfo=timezone.utc)
if issue.updated_at
else None
),
"closed_at": (
issue.closed_at.replace(tzinfo=timezone.utc)
if issue.closed_at
else None
),
"closed_by": (
_get_userinfo(issue.closed_by) if issue.closed_by else None
),
}.items()
if v is not None
},
)
class GithubConnectorStage(Enum):
START = "start"
PRS = "prs"
ISSUES = "issues"
class GithubConnectorCheckpoint(ConnectorCheckpoint):
stage: GithubConnectorStage
curr_page: int
cached_repo_ids: list[int] | None = None
cached_repo: SerializedRepository | None = None
# Used for the fallback cursor-based pagination strategy
num_retrieved: int
cursor_url: str | None = None
def reset(self) -> None:
"""
Resets curr_page, num_retrieved, and cursor_url to their initial values (0, 0, None)
"""
self.curr_page = 0
self.num_retrieved = 0
self.cursor_url = None
def make_cursor_url_callback(
checkpoint: GithubConnectorCheckpoint,
) -> Callable[[str | None, int], None]:
def cursor_url_callback(cursor_url: str | None, num_objs: int) -> None:
# we want to maintain the old cursor url so code after retrieval
# can determine that we are using the fallback cursor-based pagination strategy
if cursor_url:
checkpoint.cursor_url = cursor_url
checkpoint.num_retrieved = num_objs
return cursor_url_callback
class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpoint]):
def __init__(
self,
repo_owner: str,
repositories: str | None = None,
state_filter: str = "all",
include_prs: bool = True,
include_issues: bool = False,
) -> None:
self.repo_owner = repo_owner
self.repositories = repositories
self.state_filter = state_filter
self.include_prs = include_prs
self.include_issues = include_issues
self.github_client: Github | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
# defaults to 30 items per page, can be set to as high as 100
token = credentials["github_access_token"]
auth = Auth.Token(token)
if GITHUB_CONNECTOR_BASE_URL:
self.github_client = Github(
auth=auth,
base_url=GITHUB_CONNECTOR_BASE_URL,
per_page=ITEMS_PER_PAGE,
)
else:
self.github_client = Github(
auth=auth,
per_page=ITEMS_PER_PAGE,
)
return None
def get_github_repo(
self, github_client: Github, attempt_num: int = 0
) -> Repository.Repository:
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
raise RuntimeError(
"Re-tried fetching repo too many times. Something is going wrong with fetching objects from Github"
)
try:
return github_client.get_repo(f"{self.repo_owner}/{self.repositories}")
except RateLimitExceededException:
sleep_after_rate_limit_exception(github_client)
return self.get_github_repo(github_client, attempt_num + 1)
def get_github_repos(
self, github_client: Github, attempt_num: int = 0
) -> list[Repository.Repository]:
"""Get specific repositories based on comma-separated repo_name string."""
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
raise RuntimeError(
"Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github"
)
try:
repos = []
# Split repo_name by comma and strip whitespace
repo_names = [
name.strip() for name in (cast(str, self.repositories)).split(",")
]
for repo_name in repo_names:
if repo_name: # Skip empty strings
try:
repo = github_client.get_repo(f"{self.repo_owner}/{repo_name}")
repos.append(repo)
except GithubException as e:
logging.warning(
f"Could not fetch repo {self.repo_owner}/{repo_name}: {e}"
)
return repos
except RateLimitExceededException:
sleep_after_rate_limit_exception(github_client)
return self.get_github_repos(github_client, attempt_num + 1)
def get_all_repos(
self, github_client: Github, attempt_num: int = 0
) -> list[Repository.Repository]:
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
raise RuntimeError(
"Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github"
)
try:
# Try to get organization first
try:
org = github_client.get_organization(self.repo_owner)
return list(org.get_repos())
except GithubException:
# If not an org, try as a user
user = github_client.get_user(self.repo_owner)
return list(user.get_repos())
except RateLimitExceededException:
sleep_after_rate_limit_exception(github_client)
return self.get_all_repos(github_client, attempt_num + 1)
def _pull_requests_func(
self, repo: Repository.Repository
) -> Callable[[], PaginatedList[PullRequest]]:
return lambda: repo.get_pulls(
state=self.state_filter, sort="updated", direction="desc"
)
def _issues_func(
self, repo: Repository.Repository
) -> Callable[[], PaginatedList[Issue]]:
return lambda: repo.get_issues(
state=self.state_filter, sort="updated", direction="desc"
)
def _fetch_from_github(
self,
checkpoint: GithubConnectorCheckpoint,
start: datetime | None = None,
end: datetime | None = None,
include_permissions: bool = False,
) -> Generator[Document | ConnectorFailure, None, GithubConnectorCheckpoint]:
if self.github_client is None:
raise ConnectorMissingCredentialError("GitHub")
checkpoint = copy.deepcopy(checkpoint)
# First run of the connector, fetch all repos and store in checkpoint
if checkpoint.cached_repo_ids is None:
repos = []
if self.repositories:
if "," in self.repositories:
# Multiple repositories specified
repos = self.get_github_repos(self.github_client)
else:
# Single repository (backward compatibility)
repos = [self.get_github_repo(self.github_client)]
else:
# All repositories
repos = self.get_all_repos(self.github_client)
if not repos:
checkpoint.has_more = False
return checkpoint
curr_repo = repos.pop()
checkpoint.cached_repo_ids = [repo.id for repo in repos]
checkpoint.cached_repo = SerializedRepository(
id=curr_repo.id,
headers=curr_repo.raw_headers,
raw_data=curr_repo.raw_data,
)
checkpoint.stage = GithubConnectorStage.PRS
checkpoint.curr_page = 0
# save checkpoint with repo ids retrieved
return checkpoint
if checkpoint.cached_repo is None:
raise ValueError("No repo saved in checkpoint")
# Deserialize the repository from the checkpoint
repo = deserialize_repository(checkpoint.cached_repo, self.github_client)
cursor_url_callback = make_cursor_url_callback(checkpoint)
repo_external_access: ExternalAccess | None = None
if include_permissions:
repo_external_access = get_external_access_permission(
repo, self.github_client
)
if self.include_prs and checkpoint.stage == GithubConnectorStage.PRS:
logging.info(f"Fetching PRs for repo: {repo.name}")
pr_batch = _get_batch_rate_limited(
self._pull_requests_func(repo),
checkpoint.curr_page,
checkpoint.cursor_url,
checkpoint.num_retrieved,
cursor_url_callback,
self.github_client,
)
checkpoint.curr_page += 1 # NOTE: not used for cursor-based fallback
done_with_prs = False
num_prs = 0
pr = None
print("start: ", start)
for pr in pr_batch:
num_prs += 1
print("-"*40)
print("PR name", pr.title)
print("updated at", pr.updated_at)
print("-"*40)
print("\n")
# we iterate backwards in time, so at this point we stop processing prs
if (
start is not None
and pr.updated_at
and pr.updated_at.replace(tzinfo=timezone.utc) <= start
):
done_with_prs = True
break
# Skip PRs updated after the end date
if (
end is not None
and pr.updated_at
and pr.updated_at.replace(tzinfo=timezone.utc) > end
):
continue
try:
yield _convert_pr_to_document(
cast(PullRequest, pr), repo_external_access
)
except Exception as e:
error_msg = f"Error converting PR to document: {e}"
logging.exception(error_msg)
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=str(pr.id), document_link=pr.html_url
),
failure_message=error_msg,
exception=e,
)
continue
# If we reach this point with a cursor url in the checkpoint, we were using
# the fallback cursor-based pagination strategy. That strategy tries to get all
# PRs, so having cursor_url set means we are done with PRs. However, we need to
# return AFTER the checkpoint reset to avoid infinite loops.
# if we found any PRs on the page and there are more PRs to get, return the checkpoint.
# In offset mode, while indexing without time constraints, the pr batch
# will be empty when we're done.
used_cursor = checkpoint.cursor_url is not None
if num_prs > 0 and not done_with_prs and not used_cursor:
return checkpoint
# if we went past the start date during the loop or there are no more
# prs to get, we move on to issues
checkpoint.stage = GithubConnectorStage.ISSUES
checkpoint.reset()
if used_cursor:
# save the checkpoint after changing stage; next run will continue from issues
return checkpoint
if self.include_issues and checkpoint.stage == GithubConnectorStage.ISSUES:
logging.info(f"Fetching issues for repo: {repo.name}")
issue_batch = list(
_get_batch_rate_limited(
self._issues_func(repo),
checkpoint.curr_page,
checkpoint.cursor_url,
checkpoint.num_retrieved,
cursor_url_callback,
self.github_client,
)
)
checkpoint.curr_page += 1
done_with_issues = False
num_issues = 0
for issue in issue_batch:
num_issues += 1
issue = cast(Issue, issue)
# we iterate backwards in time, so at this point we stop processing issues
if (
start is not None
and issue.updated_at.replace(tzinfo=timezone.utc) <= start
):
done_with_issues = True
break
# Skip issues updated after the end date
if (
end is not None
and issue.updated_at.replace(tzinfo=timezone.utc) > end
):
continue
if issue.pull_request is not None:
# PRs are handled separately
continue
try:
yield _convert_issue_to_document(issue, repo_external_access)
except Exception as e:
error_msg = f"Error converting issue to document: {e}"
logging.exception(error_msg)
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=str(issue.id),
document_link=issue.html_url,
),
failure_message=error_msg,
exception=e,
)
continue
# if we found any issues on the page, and we're not done, return the checkpoint.
# don't return if we're using cursor-based pagination to avoid infinite loops
if num_issues > 0 and not done_with_issues and not checkpoint.cursor_url:
return checkpoint
# if we went past the start date during the loop or there are no more
# issues to get, we move on to the next repo
checkpoint.stage = GithubConnectorStage.PRS
checkpoint.reset()
checkpoint.has_more = len(checkpoint.cached_repo_ids) > 0
if checkpoint.cached_repo_ids:
next_id = checkpoint.cached_repo_ids.pop()
next_repo = self.github_client.get_repo(next_id)
checkpoint.cached_repo = SerializedRepository(
id=next_id,
headers=next_repo.raw_headers,
raw_data=next_repo.raw_data,
)
checkpoint.stage = GithubConnectorStage.PRS
checkpoint.reset()
if checkpoint.cached_repo_ids:
logging.info(
f"{len(checkpoint.cached_repo_ids)} repos remaining (IDs: {checkpoint.cached_repo_ids})"
)
else:
logging.info("No more repos remaining")
return checkpoint
def _load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: GithubConnectorCheckpoint,
include_permissions: bool = False,
) -> CheckpointOutput[GithubConnectorCheckpoint]:
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
# add a day for timezone safety
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc) + ONE_DAY
# Move start time back by 3 hours, since some Issues/PRs are getting dropped
# Could be due to delayed processing on GitHub side
# The non-updated issues since last poll will be shortcut-ed and not embedded
# adjusted_start_datetime = start_datetime - timedelta(hours=3)
adjusted_start_datetime = start_datetime
epoch = datetime.fromtimestamp(0, tz=timezone.utc)
if adjusted_start_datetime < epoch:
adjusted_start_datetime = epoch
return self._fetch_from_github(
checkpoint,
start=adjusted_start_datetime,
end=end_datetime,
include_permissions=include_permissions,
)
@override
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: GithubConnectorCheckpoint,
) -> CheckpointOutput[GithubConnectorCheckpoint]:
return self._load_from_checkpoint(
start, end, checkpoint, include_permissions=False
)
@override
def load_from_checkpoint_with_perm_sync(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: GithubConnectorCheckpoint,
) -> CheckpointOutput[GithubConnectorCheckpoint]:
return self._load_from_checkpoint(
start, end, checkpoint, include_permissions=True
)
def validate_connector_settings(self) -> None:
if self.github_client is None:
raise ConnectorMissingCredentialError("GitHub credentials not loaded.")
if not self.repo_owner:
raise ConnectorValidationError(
"Invalid connector settings: 'repo_owner' must be provided."
)
try:
if self.repositories:
if "," in self.repositories:
# Multiple repositories specified
repo_names = [name.strip() for name in self.repositories.split(",")]
if not repo_names:
raise ConnectorValidationError(
"Invalid connector settings: No valid repository names provided."
)
# Validate at least one repository exists and is accessible
valid_repos = False
validation_errors = []
for repo_name in repo_names:
if not repo_name:
continue
try:
test_repo = self.github_client.get_repo(
f"{self.repo_owner}/{repo_name}"
)
logging.info(
f"Successfully accessed repository: {self.repo_owner}/{repo_name}"
)
test_repo.get_contents("")
valid_repos = True
# If at least one repo is valid, we can proceed
break
except GithubException as e:
validation_errors.append(
f"Repository '{repo_name}': {e.data.get('message', str(e))}"
)
if not valid_repos:
error_msg = (
"None of the specified repositories could be accessed: "
)
error_msg += ", ".join(validation_errors)
raise ConnectorValidationError(error_msg)
else:
# Single repository (backward compatibility)
test_repo = self.github_client.get_repo(
f"{self.repo_owner}/{self.repositories}"
)
test_repo.get_contents("")
else:
# Try to get organization first
try:
org = self.github_client.get_organization(self.repo_owner)
total_count = org.get_repos().totalCount
if total_count == 0:
raise ConnectorValidationError(
f"Found no repos for organization: {self.repo_owner}. "
"Does the credential have the right scopes?"
)
except GithubException as e:
# Check for missing SSO
MISSING_SSO_ERROR_MESSAGE = "You must grant your Personal Access token access to this organization".lower()
if MISSING_SSO_ERROR_MESSAGE in str(e).lower():
SSO_GUIDE_LINK = (
"https://docs.github.com/en/enterprise-cloud@latest/authentication/"
"authenticating-with-saml-single-sign-on/"
"authorizing-a-personal-access-token-for-use-with-saml-single-sign-on"
)
raise ConnectorValidationError(
f"Your GitHub token is missing authorization to access the "
f"`{self.repo_owner}` organization. Please follow the guide to "
f"authorize your token: {SSO_GUIDE_LINK}"
)
# If not an org, try as a user
user = self.github_client.get_user(self.repo_owner)
# Check if we can access any repos
total_count = user.get_repos().totalCount
if total_count == 0:
raise ConnectorValidationError(
f"Found no repos for user: {self.repo_owner}. "
"Does the credential have the right scopes?"
)
except RateLimitExceededException:
raise UnexpectedValidationError(
"Validation failed due to GitHub rate-limits being exceeded. Please try again later."
)
except GithubException as e:
if e.status == 401:
raise CredentialExpiredError(
"GitHub credential appears to be invalid or expired (HTTP 401)."
)
elif e.status == 403:
raise InsufficientPermissionsError(
"Your GitHub token does not have sufficient permissions for this repository (HTTP 403)."
)
elif e.status == 404:
if self.repositories:
if "," in self.repositories:
raise ConnectorValidationError(
f"None of the specified GitHub repositories could be found for owner: {self.repo_owner}"
)
else:
raise ConnectorValidationError(
f"GitHub repository not found with name: {self.repo_owner}/{self.repositories}"
)
else:
raise ConnectorValidationError(
f"GitHub user or organization not found: {self.repo_owner}"
)
else:
raise ConnectorValidationError(
f"Unexpected GitHub error (status={e.status}): {e.data}"
)
except ConnectorValidationError:
# let the specific validation errors raised above propagate unwrapped
raise
except Exception as exc:
raise UnexpectedValidationError(
f"Unexpected error during GitHub settings validation: {exc}"
)
def validate_checkpoint_json(
self, checkpoint_json: str
) -> GithubConnectorCheckpoint:
return GithubConnectorCheckpoint.model_validate_json(checkpoint_json)
def build_dummy_checkpoint(self) -> GithubConnectorCheckpoint:
return GithubConnectorCheckpoint(
stage=GithubConnectorStage.PRS, curr_page=0, has_more=True, num_retrieved=0
)
if __name__ == "__main__":
# Initialize the connector
connector = GithubConnector(
repo_owner="EvoAgentX",
repositories="EvoAgentX",
include_issues=True,
include_prs=False,
)
connector.load_credentials(
{"github_access_token": "<Your_GitHub_Access_Token>"}
)
if connector.github_client:
get_external_access_permission(
connector.get_github_repos(connector.github_client).pop(),
connector.github_client,
)
# Create a time range from epoch to now
end_time = datetime.now(timezone.utc)
start_time = datetime.fromtimestamp(0, tz=timezone.utc)
time_range = (start_time, end_time)
# Initialize the runner with a batch size of 10
runner: ConnectorRunner[GithubConnectorCheckpoint] = ConnectorRunner(
connector, batch_size=10, include_permissions=False, time_range=time_range
)
# Get initial checkpoint
checkpoint = connector.build_dummy_checkpoint()
# Run the connector
while checkpoint.has_more:
for doc_batch, failure, next_checkpoint in runner.run(checkpoint):
if doc_batch:
print(f"Retrieved batch of {len(doc_batch)} documents")
for doc in doc_batch:
print(f"Document: {doc.semantic_identifier}")
if failure:
print(f"Failure: {failure.failure_message}")
if next_checkpoint:
checkpoint = next_checkpoint

View File

@ -0,0 +1,17 @@
from typing import Any
from github import Repository
from github.Requester import Requester
from pydantic import BaseModel
class SerializedRepository(BaseModel):
# id is part of the raw_data as well, just pulled out for convenience
id: int
headers: dict[str, str | int]
raw_data: dict[str, Any]
def to_Repository(self, requester: Requester) -> Repository.Repository:
return Repository.Repository(
requester, self.headers, self.raw_data, completed=True
)
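# Illustrative round-trip (a hypothetical sketch, not part of the source file;
# assumes a PyGithub `repo` and a `requester` obtained as in deserialize_repository):
#     serialized = SerializedRepository(
#         id=repo.id, headers=dict(repo.raw_headers), raw_data=repo.raw_data
#     )
#     cached = serialized.model_dump_json()  # persist anywhere as JSON
#     repo_again = SerializedRepository.model_validate_json(cached).to_Repository(requester)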

View File

@ -0,0 +1,24 @@
import time
import logging
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from github import Github
def sleep_after_rate_limit_exception(github_client: Github) -> None:
"""
Sleep until the GitHub rate limit resets.
Args:
github_client: The GitHub client that hit the rate limit
"""
sleep_time = github_client.get_rate_limit().core.reset.replace(
tzinfo=timezone.utc
) - datetime.now(tz=timezone.utc)
sleep_time += timedelta(minutes=1) # add an extra minute just to be safe
logging.info(
    "Ran into Github rate-limit. Sleeping %.0f seconds.", sleep_time.total_seconds()
)
time.sleep(sleep_time.total_seconds())

View File

@ -0,0 +1,44 @@
import logging
from github import Github
from github.Repository import Repository
from common.data_source.models import ExternalAccess
from .models import SerializedRepository
def get_external_access_permission(
repo: Repository, github_client: Github
) -> ExternalAccess:
"""
Get the external access permission for a repository.
This functionality requires Enterprise Edition.
"""
# RAGFlow doesn't implement the Onyx EE external-permissions system.
# Default to private/unknown permissions.
return ExternalAccess.empty()
def deserialize_repository(
cached_repo: SerializedRepository, github_client: Github
) -> Repository:
"""
Deserialize a SerializedRepository back into a Repository object.
"""
# Try to access the requester - different PyGithub versions may use different attribute names
try:
# Try to get the requester using getattr to avoid linter errors
requester = getattr(github_client, "_requester", None)
if requester is None:
requester = getattr(github_client, "_Github__requester", None)
if requester is None:
# If we can't find the requester attribute, we need to fall back to recreating the repo
raise AttributeError("Could not find requester attribute")
return cached_repo.to_Repository(requester)
except Exception as e:
# If all else fails, re-fetch the repo directly
logging.warning("Failed to deserialize repository: %s. Attempting to re-fetch.", e)
repo_id = cached_repo.id
return github_client.get_repo(repo_id)

View File

@ -8,10 +8,10 @@ from common.data_source.config import INDEX_BATCH_SIZE, SLIM_BATCH_SIZE, Documen
 from common.data_source.google_util.auth import get_google_creds
 from common.data_source.google_util.constant import DB_CREDENTIALS_PRIMARY_ADMIN_KEY, MISSING_SCOPES_ERROR_STR, SCOPE_INSTRUCTIONS, USER_FIELDS
 from common.data_source.google_util.resource import get_admin_service, get_gmail_service
-from common.data_source.google_util.util import _execute_single_retrieval, execute_paginated_retrieval, sanitize_filename, clean_string
+from common.data_source.google_util.util import _execute_single_retrieval, execute_paginated_retrieval, clean_string
 from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch, SlimConnectorWithPermSync
 from common.data_source.models import BasicExpertInfo, Document, ExternalAccess, GenerateDocumentsOutput, GenerateSlimDocumentOutput, SlimDocument, TextSection
-from common.data_source.utils import build_time_range_query, clean_email_and_extract_name, get_message_body, is_mail_service_disabled_error, gmail_time_str_to_utc
+from common.data_source.utils import build_time_range_query, clean_email_and_extract_name, get_message_body, is_mail_service_disabled_error, gmail_time_str_to_utc, sanitize_filename
 # Constants for Gmail API fields
 THREAD_LIST_FIELDS = "nextPageToken, threads(id)"

View File

@ -191,43 +191,6 @@ def get_credentials_from_env(email: str, oauth: bool = False, source="drive") ->
     DB_CREDENTIALS_AUTHENTICATION_METHOD: "uploaded",
 }
-def sanitize_filename(name: str) -> str:
-    """
-    Soft sanitize for MinIO/S3:
-    - Replace only prohibited characters with a space.
-    - Preserve readability (no ugly underscores).
-    - Collapse multiple spaces.
-    """
-    if name is None:
-        return "file.txt"
-    name = str(name).strip()
-    # Characters that MUST NOT appear in S3/MinIO object keys
-    # Replace them with a space (not underscore)
-    forbidden = r'[\\\?\#\%\*\:\|\<\>"]'
-    name = re.sub(forbidden, " ", name)
-    # Replace slashes "/" (S3 interprets as folder) with space
-    name = name.replace("/", " ")
-    # Collapse multiple spaces into one
-    name = re.sub(r"\s+", " ", name)
-    # Trim both ends
-    name = name.strip()
-    # Enforce reasonable max length
-    if len(name) > 200:
-        base, ext = os.path.splitext(name)
-        name = base[:180].rstrip() + ext
-    # Ensure there is an extension (your original logic)
-    if not os.path.splitext(name)[1]:
-        name += ".txt"
-    return name
 def clean_string(text: str | None) -> str | None:
     """

View File

@ -0,0 +1,724 @@
import copy
import email
from email.header import decode_header
import imaplib
import logging
import os
import re
from datetime import datetime, timedelta
from datetime import timezone
from email.message import Message
from email.utils import collapse_rfc2231_value, parseaddr
from enum import Enum
from typing import Any
from typing import cast
import uuid
import bs4
from pydantic import BaseModel
from common.data_source.config import IMAP_CONNECTOR_SIZE_THRESHOLD, DocumentSource
from common.data_source.interfaces import CheckpointOutput, CheckpointedConnectorWithPermSync, CredentialsConnector, CredentialsProviderInterface
from common.data_source.models import BasicExpertInfo, ConnectorCheckpoint, Document, ExternalAccess, SecondsSinceUnixEpoch
_DEFAULT_IMAP_PORT_NUMBER = int(os.environ.get("IMAP_PORT", 993))
_IMAP_OKAY_STATUS = "OK"
_PAGE_SIZE = 100
_USERNAME_KEY = "imap_username"
_PASSWORD_KEY = "imap_password"
class Header(str, Enum):
SUBJECT_HEADER = "subject"
FROM_HEADER = "from"
TO_HEADER = "to"
CC_HEADER = "cc"
DELIVERED_TO_HEADER = (
"Delivered-To" # Used in mailing lists instead of the "to" header.
)
DATE_HEADER = "date"
MESSAGE_ID_HEADER = "Message-ID"
class EmailHeaders(BaseModel):
"""
Model for email headers extracted from IMAP messages.
"""
id: str
subject: str
sender: str
recipients: str | None
cc: str | None
date: datetime
@classmethod
def from_email_msg(cls, email_msg: Message) -> "EmailHeaders":
def _decode(header: str, default: str | None = None) -> str | None:
value = email_msg.get(header, default)
if not value:
return None
decoded_fragments = decode_header(value)
decoded_strings: list[str] = []
for decoded_value, encoding in decoded_fragments:
if isinstance(decoded_value, bytes):
try:
decoded_strings.append(
decoded_value.decode(encoding or "utf-8", errors="replace")
)
except LookupError:
decoded_strings.append(
decoded_value.decode("utf-8", errors="replace")
)
elif isinstance(decoded_value, str):
decoded_strings.append(decoded_value)
else:
decoded_strings.append(str(decoded_value))
return "".join(decoded_strings)
def _parse_date(date_str: str | None) -> datetime | None:
if not date_str:
return None
try:
return email.utils.parsedate_to_datetime(date_str)
except (TypeError, ValueError):
return None
message_id = _decode(header=Header.MESSAGE_ID_HEADER)
if not message_id:
message_id = f"<generated-{uuid.uuid4()}@imap.local>"
# It's possible for the subject line to not exist or be an empty string.
subject = _decode(header=Header.SUBJECT_HEADER) or "Unknown Subject"
from_ = _decode(header=Header.FROM_HEADER)
to = _decode(header=Header.TO_HEADER)
if not to:
to = _decode(header=Header.DELIVERED_TO_HEADER)
cc = _decode(header=Header.CC_HEADER)
date_str = _decode(header=Header.DATE_HEADER)
date = _parse_date(date_str=date_str)
if not date:
date = datetime.now(tz=timezone.utc)
# If any of the above are `None`, model validation will fail.
# Therefore, no guards (i.e.: `if <header> is None: raise RuntimeError(..)`) were written.
return cls.model_validate(
{
"id": message_id,
"subject": subject,
"sender": from_,
"recipients": to,
"cc": cc,
"date": date,
}
)
class CurrentMailbox(BaseModel):
mailbox: str
todo_email_ids: list[str]
# An email has a list of mailboxes.
# Each mailbox has a list of email-ids inside of it.
#
# Usage:
# To use this checkpointer, first fetch all the mailboxes.
# Then, pop a mailbox and fetch all of its email-ids.
# Then, pop each email-id and fetch its content (and parse it, etc.).
# When you have popped all email-ids for this mailbox, pop the next mailbox and repeat the above process until you're done.
#
# For initial checkpointing, set both fields to `None`.
class ImapCheckpoint(ConnectorCheckpoint):
todo_mailboxes: list[str] | None = None
current_mailbox: CurrentMailbox | None = None
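# A minimal driving-loop sketch for the scheme above (assumes `connector`, `start`,
# and `end` exist; `load_from_checkpoint` yields documents and *returns* the next
# checkpoint via StopIteration):
#     checkpoint = ImapCheckpoint(has_more=True)
#     while checkpoint.has_more:
#         gen = connector.load_from_checkpoint(start, end, checkpoint)
#         try:
#             while True:
#                 print(next(gen).semantic_identifier)
#         except StopIteration as stop:
#             checkpoint = stop.value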
class LoginState(str, Enum):
LoggedIn = "logged_in"
LoggedOut = "logged_out"
class ImapConnector(
CredentialsConnector,
CheckpointedConnectorWithPermSync,
):
def __init__(
self,
host: str,
port: int = _DEFAULT_IMAP_PORT_NUMBER,
mailboxes: list[str] | None = None,
) -> None:
self._host = host
self._port = port
self._mailboxes = mailboxes
self._credentials: dict[str, Any] | None = None
@property
def credentials(self) -> dict[str, Any]:
if not self._credentials:
raise RuntimeError(
"Credentials have not been initialized; call `set_credentials_provider` first"
)
return self._credentials
def _get_mail_client(self) -> imaplib.IMAP4_SSL:
"""
Returns a new `imaplib.IMAP4_SSL` instance.
The `imaplib.IMAP4_SSL` object is meant to be "ephemeral"; it's not something you can
log into, log out of, and then log back into. I.e., the following will fail:
```py
mail_client.login(..)
mail_client.logout()
mail_client.login(..)
```
Therefore, you need a fresh, new instance in order to operate with IMAP. This function gives one to you.
# Notes
This function will throw an error if the credentials have not yet been set.
"""
def get_or_raise(name: str) -> str:
value = self.credentials.get(name)
if not value:
raise RuntimeError(f"Credential item {name=} was not found")
if not isinstance(value, str):
    raise RuntimeError(
        f"Credential item {name=} must be of type str, instead received {type(value)=}"
    )
return value
username = get_or_raise(_USERNAME_KEY)
password = get_or_raise(_PASSWORD_KEY)
mail_client = imaplib.IMAP4_SSL(host=self._host, port=self._port)
status, _data = mail_client.login(user=username, password=password)
if status != _IMAP_OKAY_STATUS:
raise RuntimeError(f"Failed to log into imap server; {status=}")
return mail_client
def _load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ImapCheckpoint,
include_perm_sync: bool,
) -> CheckpointOutput[ImapCheckpoint]:
checkpoint = cast(ImapCheckpoint, copy.deepcopy(checkpoint))
checkpoint.has_more = True
mail_client = self._get_mail_client()
if checkpoint.todo_mailboxes is None:
# This is the dummy checkpoint.
# Fill it with mailboxes first.
if self._mailboxes:
checkpoint.todo_mailboxes = _sanitize_mailbox_names(self._mailboxes)
else:
fetched_mailboxes = _fetch_all_mailboxes_for_email_account(
mail_client=mail_client
)
if not fetched_mailboxes:
raise RuntimeError(
"Failed to find any mailboxes for this email account"
)
checkpoint.todo_mailboxes = _sanitize_mailbox_names(fetched_mailboxes)
return checkpoint
if (
not checkpoint.current_mailbox
or not checkpoint.current_mailbox.todo_email_ids
):
if not checkpoint.todo_mailboxes:
checkpoint.has_more = False
return checkpoint
mailbox = checkpoint.todo_mailboxes.pop()
email_ids = _fetch_email_ids_in_mailbox(
mail_client=mail_client,
mailbox=mailbox,
start=start,
end=end,
)
checkpoint.current_mailbox = CurrentMailbox(
mailbox=mailbox,
todo_email_ids=email_ids,
)
_select_mailbox(
mail_client=mail_client, mailbox=checkpoint.current_mailbox.mailbox
)
current_todos = cast(
list, copy.deepcopy(checkpoint.current_mailbox.todo_email_ids[:_PAGE_SIZE])
)
checkpoint.current_mailbox.todo_email_ids = (
checkpoint.current_mailbox.todo_email_ids[_PAGE_SIZE:]
)
for email_id in current_todos:
email_msg = _fetch_email(mail_client=mail_client, email_id=email_id)
if not email_msg:
logging.warning(f"Failed to fetch message {email_id=}; skipping")
continue
email_headers = EmailHeaders.from_email_msg(email_msg=email_msg)
msg_dt = email_headers.date
if msg_dt.tzinfo is None:
msg_dt = msg_dt.replace(tzinfo=timezone.utc)
else:
msg_dt = msg_dt.astimezone(timezone.utc)
start_dt = datetime.fromtimestamp(start, tz=timezone.utc)
end_dt = datetime.fromtimestamp(end, tz=timezone.utc)
if not (start_dt < msg_dt <= end_dt):
continue
email_doc = _convert_email_headers_and_body_into_document(
email_msg=email_msg,
email_headers=email_headers,
include_perm_sync=include_perm_sync,
)
yield email_doc
attachments = extract_attachments(email_msg)
for att in attachments:
yield attachment_to_document(email_doc, att, email_headers)
return checkpoint
# impls for BaseConnector
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self._credentials = credentials
return None
def validate_connector_settings(self) -> None:
self._get_mail_client()
# impls for CredentialsConnector
def set_credentials_provider(
self, credentials_provider: CredentialsProviderInterface
) -> None:
self._credentials = credentials_provider.get_credentials()
# impls for CheckpointedConnector
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ImapCheckpoint,
) -> CheckpointOutput[ImapCheckpoint]:
return self._load_from_checkpoint(
start=start, end=end, checkpoint=checkpoint, include_perm_sync=False
)
def build_dummy_checkpoint(self) -> ImapCheckpoint:
return ImapCheckpoint(has_more=True)
def validate_checkpoint_json(self, checkpoint_json: str) -> ImapCheckpoint:
return ImapCheckpoint.model_validate_json(json_data=checkpoint_json)
# impls for CheckpointedConnectorWithPermSync
def load_from_checkpoint_with_perm_sync(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ImapCheckpoint,
) -> CheckpointOutput[ImapCheckpoint]:
return self._load_from_checkpoint(
start=start, end=end, checkpoint=checkpoint, include_perm_sync=True
)
def _fetch_all_mailboxes_for_email_account(mail_client: imaplib.IMAP4_SSL) -> list[str]:
status, mailboxes_data = mail_client.list('""', "*")
if status != _IMAP_OKAY_STATUS:
raise RuntimeError(f"Failed to fetch mailboxes; {status=}")
mailboxes = []
for mailboxes_raw in mailboxes_data:
if isinstance(mailboxes_raw, bytes):
mailboxes_str = mailboxes_raw.decode()
elif isinstance(mailboxes_raw, str):
mailboxes_str = mailboxes_raw
else:
logging.warning(
f"Expected the mailbox data to be of type str, instead got {type(mailboxes_raw)=} {mailboxes_raw}; skipping"
)
continue
# The mailbox LIST response output can be found here:
# https://www.rfc-editor.org/rfc/rfc3501.html#section-7.2.2
#
# The general format is:
# `(<name-attributes>) <hierarchy-delimiter> <mailbox-name>`
#
# The regex below matches that pattern; we then take the second capture group (match.group(2)), which is the mailbox-name.
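# For example, the LIST line `(\HasNoChildren) "/" "INBOX/Receipts"` yields
# group(1) == "/" (the hierarchy delimiter) and group(2) == "INBOX/Receipts".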
match = re.match(r'\([^)]*\)\s+"([^"]+)"\s+"?(.+?)"?$', mailboxes_str)
if not match:
logging.warning(
f"Invalid mailbox-data formatting structure: {mailboxes_str=}; skipping"
)
continue
mailbox = match.group(2)
mailboxes.append(mailbox)
if not mailboxes:
logging.warning(
"No mailboxes parsed from LIST response; falling back to INBOX"
)
return ["INBOX"]
return mailboxes
def _select_mailbox(mail_client: imaplib.IMAP4_SSL, mailbox: str) -> bool:
try:
status, _ = mail_client.select(mailbox=mailbox, readonly=True)
if status != _IMAP_OKAY_STATUS:
return False
return True
except Exception:
return False
def _fetch_email_ids_in_mailbox(
mail_client: imaplib.IMAP4_SSL,
mailbox: str,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
) -> list[str]:
if not _select_mailbox(mail_client, mailbox):
logging.warning(f"Skip mailbox: {mailbox}")
return []
start_dt = datetime.fromtimestamp(start, tz=timezone.utc)
end_dt = datetime.fromtimestamp(end, tz=timezone.utc) + timedelta(days=1)
start_str = start_dt.strftime("%d-%b-%Y")
end_str = end_dt.strftime("%d-%b-%Y")
search_criteria = f'(SINCE "{start_str}" BEFORE "{end_str}")'
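# Example (hypothetical values): start = 2026-01-01 00:00 UTC and end = 2026-01-02
# 00:00 UTC produce `(SINCE "01-Jan-2026" BEFORE "03-Jan-2026")`; IMAP SEARCH dates
# are day-granular, hence the one-day padding on `end` above.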
status, email_ids_byte_array = mail_client.search(None, search_criteria)
if status != _IMAP_OKAY_STATUS or not email_ids_byte_array:
raise RuntimeError(f"Failed to fetch email ids; {status=}")
email_ids: bytes = email_ids_byte_array[0]
return [email_id.decode() for email_id in email_ids.split()]
def _fetch_email(mail_client: imaplib.IMAP4_SSL, email_id: str) -> Message | None:
status, msg_data = mail_client.fetch(message_set=email_id, message_parts="(RFC822)")
if status != _IMAP_OKAY_STATUS or not msg_data:
return None
data = msg_data[0]
if not isinstance(data, tuple):
raise RuntimeError(
f"Message data should be a tuple; instead got a {type(data)=} {data=}"
)
_, raw_email = data
return email.message_from_bytes(raw_email)
def _convert_email_headers_and_body_into_document(
email_msg: Message,
email_headers: EmailHeaders,
include_perm_sync: bool,
) -> Document:
sender_name, sender_addr = _parse_singular_addr(raw_header=email_headers.sender)
to_addrs = (
_parse_addrs(email_headers.recipients)
if email_headers.recipients
else []
)
cc_addrs = (
_parse_addrs(email_headers.cc)
if email_headers.cc
else []
)
all_participants = to_addrs + cc_addrs
expert_info_map = {
recipient_addr: BasicExpertInfo(
display_name=recipient_name, email=recipient_addr
)
for recipient_name, recipient_addr in all_participants
}
if sender_addr not in expert_info_map:
expert_info_map[sender_addr] = BasicExpertInfo(
display_name=sender_name, email=sender_addr
)
email_body = _parse_email_body(email_msg=email_msg, email_headers=email_headers)
# The Document model stores raw bytes; encode the parsed body once.
email_body_bytes = email_body.encode("utf-8", errors="replace")
primary_owners = list(expert_info_map.values())
external_access = (
ExternalAccess(
external_user_emails=set(expert_info_map.keys()),
external_user_group_ids=set(),
is_public=False,
)
if include_perm_sync
else None
)
return Document(
id=email_headers.id,
title=email_headers.subject,
blob=email_body_bytes,
size_bytes=len(email_body_bytes),
semantic_identifier=email_headers.subject,
metadata={},
extension='.txt',
doc_updated_at=email_headers.date,
source=DocumentSource.IMAP,
primary_owners=primary_owners,
external_access=external_access,
)
def extract_attachments(email_msg: Message, max_bytes: int = IMAP_CONNECTOR_SIZE_THRESHOLD) -> list[dict[str, Any]]:
attachments = []
if not email_msg.is_multipart():
return attachments
for part in email_msg.walk():
if part.get_content_maintype() == "multipart":
continue
disposition = (part.get("Content-Disposition") or "").lower()
filename = part.get_filename()
if not (
disposition.startswith("attachment")
or (disposition.startswith("inline") and filename)
):
continue
payload = part.get_payload(decode=True)
if not payload:
continue
if len(payload) > max_bytes:
continue
attachments.append({
"filename": filename or "attachment.bin",
"content_type": part.get_content_type(),
"content_bytes": payload,
"size_bytes": len(payload),
})
return attachments
def decode_mime_filename(raw: str | None) -> str | None:
if not raw:
return None
try:
raw = collapse_rfc2231_value(raw)
except Exception:
pass
parts = decode_header(raw)
decoded = []
for value, encoding in parts:
if isinstance(value, bytes):
decoded.append(value.decode(encoding or "utf-8", errors="replace"))
else:
decoded.append(value)
return "".join(decoded)
def attachment_to_document(
parent_doc: Document,
att: dict,
email_headers: EmailHeaders,
) -> Document:
raw_filename = att["filename"]
filename = decode_mime_filename(raw_filename) or "attachment.bin"
ext = "." + filename.split(".")[-1] if "." in filename else ""
return Document(
id=f"{parent_doc.id}#att:{filename}",
source=DocumentSource.IMAP,
semantic_identifier=filename,
extension=ext,
blob=att["content_bytes"],
size_bytes=att["size_bytes"],
doc_updated_at=email_headers.date,
primary_owners=parent_doc.primary_owners,
metadata={
"parent_email_id": parent_doc.id,
"parent_subject": email_headers.subject,
"attachment_filename": filename,
"attachment_content_type": att["content_type"],
},
)
def _parse_email_body(
email_msg: Message,
email_headers: EmailHeaders,
) -> str:
body = None
for part in email_msg.walk():
if part.is_multipart():
# Multipart parts are *containers* for other parts, not the actual content itself.
# Therefore, we skip until we find the individual parts instead.
continue
charset = part.get_content_charset() or "utf-8"
try:
raw_payload = part.get_payload(decode=True)
if not isinstance(raw_payload, bytes):
logging.warning(
"Payload section from email was expected to be an array of bytes, instead got "
f"{type(raw_payload)=}, {raw_payload=}"
)
continue
body = raw_payload.decode(charset)
break
except (UnicodeDecodeError, LookupError) as e:
logging.warning(f"Could not decode part with charset {charset}. Error: {e}")
continue
if not body:
logging.warning(
f"Email with {email_headers.id=} has an empty body; returning an empty string"
)
return ""
soup = bs4.BeautifulSoup(markup=body, features="html.parser")
return " ".join(str_section for str_section in soup.stripped_strings)
def _sanitize_mailbox_names(mailboxes: list[str]) -> list[str]:
"""
Mailboxes with special characters in them must be enclosed by double-quotes, as per the IMAP protocol.
Just to be safe, we wrap *all* mailboxes with double-quotes.
"""
return [f'"{mailbox}"' for mailbox in mailboxes if mailbox]
def _parse_addrs(raw_header: str) -> list[tuple[str, str]]:
addrs = raw_header.split(",")
name_addr_pairs = [parseaddr(addr=addr) for addr in addrs if addr]
return [(name, addr) for name, addr in name_addr_pairs if addr]
def _parse_singular_addr(raw_header: str) -> tuple[str, str]:
addrs = _parse_addrs(raw_header=raw_header)
if not addrs:
return ("Unknown", "unknown@example.com")
elif len(addrs) >= 2:
raise RuntimeError(
f"Expected a singular address, but instead got multiple; {raw_header=} {addrs=}"
)
return addrs[0]
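# Illustrative behavior of the address helpers above:
#   _parse_addrs("Alice <a@x.com>, b@y.com") -> [("Alice", "a@x.com"), ("", "b@y.com")]
#   _parse_singular_addr("Alice <a@x.com>")  -> ("Alice", "a@x.com")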
if __name__ == "__main__":
import time
from types import TracebackType
from common.data_source.utils import load_all_docs_from_checkpoint_connector
class OnyxStaticCredentialsProvider(
CredentialsProviderInterface["OnyxStaticCredentialsProvider"]
):
"""Implementation (a very simple one!) to handle static credentials."""
def __init__(
self,
tenant_id: str | None,
connector_name: str,
credential_json: dict[str, Any],
):
self._tenant_id = tenant_id
self._connector_name = connector_name
self._credential_json = credential_json
self._provider_key = str(uuid.uuid4())
def __enter__(self) -> "OnyxStaticCredentialsProvider":
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
pass
def get_tenant_id(self) -> str | None:
return self._tenant_id
def get_provider_key(self) -> str:
return self._provider_key
def get_credentials(self) -> dict[str, Any]:
return self._credential_json
def set_credentials(self, credential_json: dict[str, Any]) -> None:
self._credential_json = credential_json
def is_dynamic(self) -> bool:
return False
# from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector
# from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider
host = os.environ.get("IMAP_HOST")
mailboxes_str = os.environ.get("IMAP_MAILBOXES", "INBOX")
username = os.environ.get("IMAP_USERNAME")
password = os.environ.get("IMAP_PASSWORD")
mailboxes = (
[mailbox.strip() for mailbox in mailboxes_str.split(",")]
if mailboxes_str
else []
)
if not host:
raise RuntimeError("`IMAP_HOST` must be set")
imap_connector = ImapConnector(
host=host,
mailboxes=mailboxes,
)
imap_connector.set_credentials_provider(
OnyxStaticCredentialsProvider(
tenant_id=None,
connector_name=DocumentSource.IMAP,
credential_json={
_USERNAME_KEY: username,
_PASSWORD_KEY: password,
},
)
)
END = time.time()
START = END - 1 * 24 * 60 * 60
for doc in load_all_docs_from_checkpoint_connector(
connector=imap_connector,
start=START,
end=END,
):
print(doc.id, doc.extension)

View File

@ -237,16 +237,13 @@ class BaseConnector(abc.ABC, Generic[CT]):
     def validate_perm_sync(self) -> None:
         """
-        Don't override this; add a function to perm_sync_valid.py in the ee package
-        to do permission sync validation
+        Permission-sync validation hook.
+        RAGFlow doesn't ship the Onyx EE permission-sync validation package.
+        Connectors that support permission sync should override
+        `validate_connector_settings()` as needed.
         """
-        """
-        validate_connector_settings_fn = fetch_ee_implementation_or_noop(
-            "onyx.connectors.perm_sync_valid",
-            "validate_perm_sync",
-            noop_return_value=None,
-        )
-        validate_connector_settings_fn(self)"""
+        return None
     def set_allow_images(self, value: bool) -> None:
         """Implement if the underlying connector wants to skip/allow image downloading
@ -345,6 +342,17 @@ class CheckpointOutputWrapper(Generic[CT]):
             yield None, None, self.next_checkpoint
+class CheckpointedConnectorWithPermSyncGH(CheckpointedConnector[CT]):
+    @abc.abstractmethod
+    def load_from_checkpoint_with_perm_sync(
+        self,
+        start: SecondsSinceUnixEpoch,
+        end: SecondsSinceUnixEpoch,
+        checkpoint: CT,
+    ) -> CheckpointOutput[CT]:
+        raise NotImplementedError
 # Slim connectors retrieve just the ids of documents
 class SlimConnector(BaseConnector):
     @abc.abstractmethod

View File

@ -94,8 +94,10 @@ class Document(BaseModel):
     blob: bytes
     doc_updated_at: datetime
     size_bytes: int
+    external_access: Optional[ExternalAccess] = None
     primary_owners: Optional[list] = None
     metadata: Optional[dict[str, Any]] = None
+    doc_metadata: Optional[dict[str, Any]] = None
 class BasicExpertInfo(BaseModel):

View File

@ -1149,3 +1149,137 @@ def parallel_yield(gens: list[Iterator[R]], max_workers: int = 10) -> Iterator[R
             future_to_index[executor.submit(_next_or_none, ind, gens[ind])] = next_ind
             next_ind += 1
         del future_to_index[future]
def sanitize_filename(name: str, extension: str = "txt") -> str:
"""
Soft sanitize for MinIO/S3:
- Replace only prohibited characters with a space.
- Preserve readability (no ugly underscores).
- Collapse multiple spaces.
"""
if name is None:
return f"file.{extension}"
name = str(name).strip()
# Characters that MUST NOT appear in S3/MinIO object keys
# Replace them with a space (not underscore)
forbidden = r'[\\\?\#\%\*\:\|\<\>"]'
name = re.sub(forbidden, " ", name)
# Replace slashes "/" (S3 interprets as folder) with space
name = name.replace("/", " ")
# Collapse multiple spaces into one
name = re.sub(r"\s+", " ", name)
# Trim both ends
name = name.strip()
# Enforce reasonable max length
if len(name) > 200:
base, ext = os.path.splitext(name)
name = base[:180].rstrip() + ext
if not os.path.splitext(name)[1]:
name += f".{extension}"
return name
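# Illustrative results of the rules above:
#   sanitize_filename('report: Q1/Q2 *final*') -> 'report Q1 Q2 final.txt'
#   sanitize_filename(None)                    -> 'file.txt'
#   sanitize_filename('notes.md')              -> 'notes.md'  (existing extension kept)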
F = TypeVar("F", bound=Callable[..., Any])
class _RateLimitDecorator:
"""Builds a generic wrapper/decorator for calls to external APIs that
prevents making more than `max_calls` requests per `period`
Implementation inspired by the `ratelimit` library:
https://github.com/tomasbasham/ratelimit.
NOTE: is not thread safe.
"""
def __init__(
self,
max_calls: int,
period: float, # in seconds
sleep_time: float = 2, # in seconds
sleep_backoff: float = 2, # applies exponential backoff
max_num_sleep: int = 0,
):
self.max_calls = max_calls
self.period = period
self.sleep_time = sleep_time
self.sleep_backoff = sleep_backoff
self.max_num_sleep = max_num_sleep
self.call_history: list[float] = []
self.curr_calls = 0
def __call__(self, func: F) -> F:
@wraps(func)
def wrapped_func(*args: list, **kwargs: dict[str, Any]) -> Any:
# cleanup calls which are no longer relevant
self._cleanup()
# check if we've exceeded the rate limit
sleep_cnt = 0
while len(self.call_history) == self.max_calls:
sleep_time = self.sleep_time * (self.sleep_backoff**sleep_cnt)
logging.warning(
f"Rate limit exceeded for function {func.__name__}. "
f"Waiting {sleep_time} seconds before retrying."
)
time.sleep(sleep_time)
sleep_cnt += 1
if self.max_num_sleep != 0 and sleep_cnt >= self.max_num_sleep:
raise RateLimitTriedTooManyTimesError(
f"Exceeded '{self.max_num_sleep}' retries for function '{func.__name__}'"
)
self._cleanup()
# add the current call to the call history
self.call_history.append(time.monotonic())
return func(*args, **kwargs)
return cast(F, wrapped_func)
def _cleanup(self) -> None:
curr_time = time.monotonic()
time_to_expire_before = curr_time - self.period
self.call_history = [
call_time
for call_time in self.call_history
if call_time > time_to_expire_before
]
rate_limit_builder = _RateLimitDecorator
def retry_builder(
tries: int = 20,
delay: float = 0.1,
max_delay: float | None = 60,
backoff: float = 2,
jitter: tuple[float, float] | float = 1,
exceptions: type[Exception] | tuple[type[Exception], ...] = (Exception,),
) -> Callable[[F], F]:
"""Builds a generic wrapper/decorator for calls to external APIs that
may fail due to rate limiting, flakes, or other reasons. Applies exponential
backoff with jitter to retry the call."""
def retry_with_default(func: F) -> F:
@retry(
tries=tries,
delay=delay,
max_delay=max_delay,
backoff=backoff,
jitter=jitter,
logger=logging.getLogger(__name__),
exceptions=exceptions,
)
def wrapped_func(*args: list, **kwargs: dict[str, Any]) -> Any:
return func(*args, **kwargs)
return cast(F, wrapped_func)
return retry_with_default
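# A minimal combined-usage sketch (`fetch_page` is hypothetical; assumes `requests`
# is imported): cap calls at 30 per 60-second window and retry transient failures
# with exponential backoff:
#     @retry_builder(tries=5, delay=0.5)
#     @rate_limit_builder(max_calls=30, period=60)
#     def fetch_page(url: str) -> bytes:
#         resp = requests.get(url, timeout=30)
#         resp.raise_for_status()
#         return resp.content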

View File

@ -82,10 +82,6 @@ class WebDAVConnector(LoadConnector, PollConnector):
                 base_url=self.base_url,
                 auth=(username, password)
             )
-            # Test connection
-            self.client.exists(self.remote_path)
         except Exception as e:
             logging.error(f"Failed to connect to WebDAV server: {e}")
             raise ConnectorMissingCredentialError(
@ -308,60 +304,79 @@
             yield batch
     def validate_connector_settings(self) -> None:
-        """Validate WebDAV connector settings
-        Raises:
-            ConnectorMissingCredentialError: If credentials are not loaded
-            ConnectorValidationError: If settings are invalid
+        """Validate WebDAV connector settings.
+        Validation should exercise the same code-paths used by the connector
+        (directory listing / PROPFIND), avoiding exists() which may probe with
+        methods that differ across servers.
         """
         if self.client is None:
-            raise ConnectorMissingCredentialError(
-                "WebDAV credentials not loaded."
-            )
+            raise ConnectorMissingCredentialError("WebDAV credentials not loaded.")
         if not self.base_url:
-            raise ConnectorValidationError(
-                "No base URL was provided in connector settings."
-            )
+            raise ConnectorValidationError("No base URL was provided in connector settings.")
+        # Normalize directory path: for collections, many servers behave better with trailing '/'
+        test_path = self.remote_path or "/"
+        if not test_path.startswith("/"):
+            test_path = f"/{test_path}"
+        if test_path != "/" and not test_path.endswith("/"):
+            test_path = f"{test_path}/"
         try:
-            if not self.client.exists(self.remote_path):
-                raise ConnectorValidationError(
-                    f"Remote path '{self.remote_path}' does not exist on WebDAV server."
-                )
+            # Use the same behavior as real sync: list directory with details (PROPFIND)
+            self.client.ls(test_path, detail=True)
         except Exception as e:
-            error_message = str(e)
+            # Prefer structured status codes if present on the exception/response
+            status = None
+            for attr in ("status_code", "code"):
+                v = getattr(e, attr, None)
+                if isinstance(v, int):
+                    status = v
+                    break
+            if status is None:
+                resp = getattr(e, "response", None)
+                v = getattr(resp, "status_code", None)
+                if isinstance(v, int):
+                    status = v
-            if "401" in error_message or "unauthorized" in error_message.lower():
-                raise CredentialExpiredError(
-                    "WebDAV credentials appear invalid or expired."
-                )
-            if "403" in error_message or "forbidden" in error_message.lower():
+            # If we can classify by status code, do it
+            if status == 401:
+                raise CredentialExpiredError("WebDAV credentials appear invalid or expired.")
+            if status == 403:
                 raise InsufficientPermissionsError(
                     f"Insufficient permissions to access path '{self.remote_path}' on WebDAV server."
                 )
-            if "404" in error_message or "not found" in error_message.lower():
+            if status == 404:
                 raise ConnectorValidationError(
                     f"Remote path '{self.remote_path}' does not exist on WebDAV server."
                 )
+            # Fallback: avoid brittle substring matching that caused false positives.
+            # Provide the original exception for diagnosis.
             raise ConnectorValidationError(
-                f"Unexpected WebDAV client error: {e}"
+                f"WebDAV validation failed for path '{test_path}': {repr(e)}"
             )
 if __name__ == "__main__":
     credentials_dict = {
         "username": os.environ.get("WEBDAV_USERNAME"),
         "password": os.environ.get("WEBDAV_PASSWORD"),
     }
+    credentials_dict = {
+        "username": "user",
+        "password": "pass",
+    }
     connector = WebDAVConnector(
-        base_url=os.environ.get("WEBDAV_URL") or "https://webdav.example.com",
-        remote_path=os.environ.get("WEBDAV_PATH") or "/",
+        base_url="http://172.17.0.1:8080/",
+        remote_path="/",
     )
     try:

View File

@ -0,0 +1,667 @@
import copy
import logging
import time
from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
import requests
from pydantic import BaseModel
from requests.exceptions import HTTPError
from typing_extensions import override
from common.data_source.config import ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS, DocumentSource
from common.data_source.exceptions import ConnectorValidationError, CredentialExpiredError, InsufficientPermissionsError
from common.data_source.html_utils import parse_html_page_basic
from common.data_source.interfaces import CheckpointOutput, CheckpointOutputWrapper, CheckpointedConnector, IndexingHeartbeatInterface, SlimConnectorWithPermSync
from common.data_source.models import BasicExpertInfo, ConnectorCheckpoint, ConnectorFailure, Document, DocumentFailure, GenerateSlimDocumentOutput, SecondsSinceUnixEpoch, SlimDocument
from common.data_source.utils import retry_builder, time_str_to_utc, rate_limit_builder
MAX_PAGE_SIZE = 30 # Zendesk API maximum
MAX_AUTHOR_MAP_SIZE = 50_000 # Reset author map cache if it gets too large
_SLIM_BATCH_SIZE = 1000
class ZendeskCredentialsNotSetUpError(PermissionError):
def __init__(self) -> None:
super().__init__(
"Zendesk Credentials are not set up, was load_credentials called?"
)
class ZendeskClient:
def __init__(
self,
subdomain: str,
email: str,
token: str,
calls_per_minute: int | None = None,
):
self.base_url = f"https://{subdomain}.zendesk.com/api/v2"
self.auth = (f"{email}/token", token)
self.make_request = request_with_rate_limit(self, calls_per_minute)
def request_with_rate_limit(
client: ZendeskClient, max_calls_per_minute: int | None = None
) -> Callable[[str, dict[str, Any]], dict[str, Any]]:
@retry_builder()
@(
rate_limit_builder(max_calls=max_calls_per_minute, period=60)
if max_calls_per_minute
else lambda x: x
)
def make_request(endpoint: str, params: dict[str, Any]) -> dict[str, Any]:
response = requests.get(
f"{client.base_url}/{endpoint}", auth=client.auth, params=params
)
if response.status_code == 429:
retry_after = response.headers.get("Retry-After")
if retry_after is not None:
# Sleep for the duration indicated by the Retry-After header
time.sleep(int(retry_after))
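# Note: raise_for_status() below still raises for the 429; the surrounding
# @retry_builder() is what actually performs the retry after this sleep.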
elif (
response.status_code == 403
and response.json().get("error") == "SupportProductInactive"
):
return response.json()
response.raise_for_status()
return response.json()
return make_request
class ZendeskPageResponse(BaseModel):
data: list[dict[str, Any]]
meta: dict[str, Any]
has_more: bool
def _get_content_tag_mapping(client: ZendeskClient) -> dict[str, str]:
content_tags: dict[str, str] = {}
params = {"page[size]": MAX_PAGE_SIZE}
try:
while True:
data = client.make_request("guide/content_tags", params)
for tag in data.get("records", []):
content_tags[tag["id"]] = tag["name"]
# Check if there are more pages
if data.get("meta", {}).get("has_more", False):
params["page[after]"] = data["meta"]["after_cursor"]
else:
break
return content_tags
except Exception as e:
raise Exception(f"Error fetching content tags: {str(e)}")
def _get_articles(
client: ZendeskClient, start_time: int | None = None, page_size: int = MAX_PAGE_SIZE
) -> Iterator[dict[str, Any]]:
params = {"page[size]": page_size, "sort_by": "updated_at", "sort_order": "asc"}
if start_time is not None:
params["start_time"] = start_time
while True:
data = client.make_request("help_center/articles", params)
for article in data["articles"]:
yield article
if not data.get("meta", {}).get("has_more"):
break
params["page[after]"] = data["meta"]["after_cursor"]
def _get_article_page(
client: ZendeskClient,
start_time: int | None = None,
after_cursor: str | None = None,
page_size: int = MAX_PAGE_SIZE,
) -> ZendeskPageResponse:
params = {"page[size]": page_size, "sort_by": "updated_at", "sort_order": "asc"}
if start_time is not None:
params["start_time"] = start_time
if after_cursor is not None:
params["page[after]"] = after_cursor
data = client.make_request("help_center/articles", params)
return ZendeskPageResponse(
data=data["articles"],
meta=data["meta"],
has_more=bool(data["meta"].get("has_more", False)),
)
def _get_tickets(
client: ZendeskClient, start_time: int | None = None
) -> Iterator[dict[str, Any]]:
params = {"start_time": start_time or 0}
while True:
data = client.make_request("incremental/tickets.json", params)
for ticket in data["tickets"]:
yield ticket
if not data.get("end_of_stream", False):
params["start_time"] = data["end_time"]
else:
break
# TODO: maybe these don't need to be their own functions?
def _get_tickets_page(
client: ZendeskClient, start_time: int | None = None
) -> ZendeskPageResponse:
params = {"start_time": start_time or 0}
# NOTE: for some reason zendesk doesn't seem to be respecting the start_time param
# in my local testing with very few tickets. We'll look into it if this becomes an
# issue in larger deployments
data = client.make_request("incremental/tickets.json", params)
if data.get("error") == "SupportProductInactive":
raise ValueError(
"Zendesk Support Product is not active for this account, No tickets to index"
)
return ZendeskPageResponse(
data=data["tickets"],
meta={"end_time": data["end_time"]},
has_more=not bool(data.get("end_of_stream", False)),
)
def _fetch_author(
client: ZendeskClient, author_id: str | int
) -> BasicExpertInfo | None:
# Skip fetching if author_id is invalid
# cast to str to avoid issues with zendesk changing their types
if not author_id or str(author_id) == "-1":
return None
try:
author_data = client.make_request(f"users/{author_id}", {})
user = author_data.get("user")
return (
BasicExpertInfo(display_name=user.get("name"), email=user.get("email"))
if user and user.get("name") and user.get("email")
else None
)
except requests.exceptions.HTTPError:
# Handle any API errors gracefully
return None
def _article_to_document(
article: dict[str, Any],
content_tags: dict[str, str],
author_map: dict[str, BasicExpertInfo],
client: ZendeskClient,
) -> tuple[dict[str, BasicExpertInfo] | None, Document]:
author_id = article.get("author_id")
if not author_id:
author = None
else:
author = (
author_map.get(author_id)
if author_id in author_map
else _fetch_author(client, author_id)
)
new_author_mapping = {author_id: author} if author_id and author else None
updated_at = article.get("updated_at")
update_time = time_str_to_utc(updated_at) if updated_at else None
text = parse_html_page_basic(article.get("body") or "")
blob = text.encode("utf-8", errors="replace")
# Build metadata
metadata: dict[str, str | list[str]] = {
"labels": [str(label) for label in article.get("label_names", []) if label],
"content_tags": [
content_tags[tag_id]
for tag_id in article.get("content_tag_ids", [])
if tag_id in content_tags
],
}
# Remove empty values
metadata = {k: v for k, v in metadata.items() if v}
return new_author_mapping, Document(
id=f"article:{article['id']}",
source=DocumentSource.ZENDESK,
semantic_identifier=article["title"],
extension=".txt",
blob=blob,
size_bytes=len(blob),
doc_updated_at=update_time,
primary_owners=[author] if author else None,
metadata=metadata,
)
def _get_comment_text(
comment: dict[str, Any],
author_map: dict[str, BasicExpertInfo],
client: ZendeskClient,
) -> tuple[dict[str, BasicExpertInfo] | None, str]:
author_id = comment.get("author_id")
if not author_id:
author = None
else:
author = (
author_map.get(author_id)
if author_id in author_map
else _fetch_author(client, author_id)
)
new_author_mapping = {author_id: author} if author_id and author else None
comment_text = f"Comment{' by ' + author.display_name if author and author.display_name else ''}"
comment_text += f"{' at ' + comment['created_at'] if comment.get('created_at') else ''}:\n{comment['body']}"
return new_author_mapping, comment_text
def _ticket_to_document(
ticket: dict[str, Any],
author_map: dict[str, BasicExpertInfo],
client: ZendeskClient,
) -> tuple[dict[str, BasicExpertInfo] | None, Document]:
submitter_id = ticket.get("submitter")
if not submitter_id:
submitter = None
else:
submitter = (
author_map.get(submitter_id)
if submitter_id in author_map
else _fetch_author(client, submitter_id)
)
new_author_mapping = (
{submitter_id: submitter} if submitter_id and submitter else None
)
updated_at = ticket.get("updated_at")
update_time = time_str_to_utc(updated_at) if updated_at else None
metadata: dict[str, str | list[str]] = {}
if status := ticket.get("status"):
metadata["status"] = status
if priority := ticket.get("priority"):
metadata["priority"] = priority
if tags := ticket.get("tags"):
metadata["tags"] = tags
if ticket_type := ticket.get("type"):
metadata["ticket_type"] = ticket_type
# Fetch comments for the ticket
comments_data = client.make_request(f"tickets/{ticket.get('id')}/comments", {})
comments = comments_data.get("comments", [])
comment_texts = []
for comment in comments:
    # Use a distinct name so the submitter's mapping (returned below) isn't clobbered.
    comment_author_mapping, comment_text = _get_comment_text(
        comment, author_map, client
    )
    if comment_author_mapping:
        author_map.update(comment_author_mapping)
    comment_texts.append(comment_text)
comments_text = "\n\n".join(comment_texts)
subject = ticket.get("subject")
full_text = f"Ticket Subject:\n{subject}\n\nComments:\n{comments_text}"
blob = full_text.encode("utf-8", errors="replace")
return new_author_mapping, Document(
id=f"zendesk_ticket_{ticket['id']}",
blob=blob,
extension=".txt",
size_bytes=len(blob),
source=DocumentSource.ZENDESK,
semantic_identifier=f"Ticket #{ticket['id']}: {subject or 'No Subject'}",
doc_updated_at=update_time,
primary_owners=[submitter] if submitter else None,
metadata=metadata,
)
class ZendeskConnectorCheckpoint(ConnectorCheckpoint):
# We use cursor-based paginated retrieval for articles
after_cursor_articles: str | None
# We use timestamp-based paginated retrieval for tickets
next_start_time_tickets: int | None
cached_author_map: dict[str, BasicExpertInfo] | None
cached_content_tags: dict[str, str] | None
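# Illustrative checkpoint progression (cursor/timestamp values are made up):
#   articles: after_cursor_articles None -> "abc123" -> ... until meta.has_more is False
#   tickets:  next_start_time_tickets None -> 1737000000 -> ... until end_of_stream is True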
class ZendeskConnector(
SlimConnectorWithPermSync, CheckpointedConnector[ZendeskConnectorCheckpoint]
):
def __init__(
self,
content_type: str = "articles",
calls_per_minute: int | None = None,
) -> None:
    self.content_type = content_type
    self.subdomain = ""
    # The client is created in load_credentials(); default to None so the
    # `self.client is None` guards don't raise AttributeError before then.
    self.client: ZendeskClient | None = None
    # Fetch all tags ahead of time
    self.content_tags: dict[str, str] = {}
    self.calls_per_minute = calls_per_minute
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
# The "subdomain" credential may be a full URL; reduce it to the bare subdomain
subdomain = (
credentials["zendesk_subdomain"]
.replace("https://", "")
.split(".zendesk.com")[0]
)
self.subdomain = subdomain
self.client = ZendeskClient(
subdomain,
credentials["zendesk_email"],
credentials["zendesk_token"],
calls_per_minute=self.calls_per_minute,
)
return None
@override
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ZendeskConnectorCheckpoint,
) -> CheckpointOutput[ZendeskConnectorCheckpoint]:
if self.client is None:
raise ZendeskCredentialsNotSetUpError()
if checkpoint.cached_content_tags is None:
checkpoint.cached_content_tags = _get_content_tag_mapping(self.client)
return checkpoint # save the content tags to the checkpoint
self.content_tags = checkpoint.cached_content_tags
if self.content_type == "articles":
checkpoint = yield from self._retrieve_articles(start, end, checkpoint)
return checkpoint
elif self.content_type == "tickets":
checkpoint = yield from self._retrieve_tickets(start, end, checkpoint)
return checkpoint
else:
raise ValueError(f"Unsupported content_type: {self.content_type}")
def _retrieve_articles(
self,
start: SecondsSinceUnixEpoch | None,
end: SecondsSinceUnixEpoch | None,
checkpoint: ZendeskConnectorCheckpoint,
) -> CheckpointOutput[ZendeskConnectorCheckpoint]:
checkpoint = copy.deepcopy(checkpoint)
# This one is built on the fly, as there may be many more authors than tags
author_map: dict[str, BasicExpertInfo] = checkpoint.cached_author_map or {}
after_cursor = checkpoint.after_cursor_articles
doc_batch: list[Document] = []
response = _get_article_page(
self.client,
start_time=int(start) if start else None,
after_cursor=after_cursor,
)
articles = response.data
has_more = response.has_more
after_cursor = response.meta.get("after_cursor")
for article in articles:
if (
article.get("body") is None
or article.get("draft")
or any(
label in ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
for label in article.get("label_names", [])
)
):
continue
try:
new_author_map, document = _article_to_document(
article, self.content_tags, author_map, self.client
)
except Exception as e:
logging.error(f"Error processing article {article['id']}: {e}")
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=f"{article.get('id')}",
document_link=article.get("html_url", ""),
),
failure_message=str(e),
exception=e,
)
continue
if new_author_map:
author_map.update(new_author_map)
updated_at = document.doc_updated_at
updated_ts = updated_at.timestamp() if updated_at else None
if updated_ts is not None:
if start is not None and updated_ts <= start:
continue
if end is not None and updated_ts > end:
continue
doc_batch.append(document)
if not has_more:
yield from doc_batch
checkpoint.has_more = False
return checkpoint
# Sometimes no documents are retrieved, but the cursor
# is still updated so the connector makes progress.
yield from doc_batch
checkpoint.after_cursor_articles = after_cursor
last_doc_updated_at = doc_batch[-1].doc_updated_at if doc_batch else None
checkpoint.has_more = bool(
end is None
or last_doc_updated_at is None
or last_doc_updated_at.timestamp() <= end
)
checkpoint.cached_author_map = (
author_map if len(author_map) <= MAX_AUTHOR_MAP_SIZE else None
)
return checkpoint
def _retrieve_tickets(
self,
start: SecondsSinceUnixEpoch | None,
end: SecondsSinceUnixEpoch | None,
checkpoint: ZendeskConnectorCheckpoint,
) -> CheckpointOutput[ZendeskConnectorCheckpoint]:
checkpoint = copy.deepcopy(checkpoint)
if self.client is None:
raise ZendeskCredentialsNotSetUpError()
author_map: dict[str, BasicExpertInfo] = checkpoint.cached_author_map or {}
doc_batch: list[Document] = []
next_start_time = int(checkpoint.next_start_time_tickets or start or 0)
ticket_response = _get_tickets_page(self.client, start_time=next_start_time)
tickets = ticket_response.data
has_more = ticket_response.has_more
next_start_time = ticket_response.meta["end_time"]
for ticket in tickets:
if ticket.get("status") == "deleted":
continue
try:
new_author_map, document = _ticket_to_document(
ticket=ticket,
author_map=author_map,
client=self.client,
)
except Exception as e:
logging.error(f"Error processing ticket {ticket['id']}: {e}")
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=f"{ticket.get('id')}",
document_link=ticket.get("url", ""),
),
failure_message=str(e),
exception=e,
)
continue
if new_author_map:
author_map.update(new_author_map)
updated_at = document.doc_updated_at
updated_ts = updated_at.timestamp() if updated_at else None
if updated_ts is not None:
if start is not None and updated_ts <= start:
continue
if end is not None and updated_ts > end:
continue
doc_batch.append(document)
if not has_more:
yield from doc_batch
checkpoint.has_more = False
return checkpoint
yield from doc_batch
checkpoint.next_start_time_tickets = next_start_time
last_doc_updated_at = doc_batch[-1].doc_updated_at if doc_batch else None
checkpoint.has_more = bool(
end is None
or last_doc_updated_at is None
or last_doc_updated_at.timestamp() <= end
)
checkpoint.cached_author_map = (
author_map if len(author_map) <= MAX_AUTHOR_MAP_SIZE else None
)
return checkpoint
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
slim_doc_batch: list[SlimDocument] = []
if self.content_type == "articles":
articles = _get_articles(
self.client, start_time=int(start) if start else None
)
for article in articles:
slim_doc_batch.append(
SlimDocument(
id=f"article:{article['id']}",
)
)
if len(slim_doc_batch) >= _SLIM_BATCH_SIZE:
yield slim_doc_batch
slim_doc_batch = []
elif self.content_type == "tickets":
tickets = _get_tickets(
self.client, start_time=int(start) if start else None
)
for ticket in tickets:
slim_doc_batch.append(
SlimDocument(
id=f"zendesk_ticket_{ticket['id']}",
)
)
if len(slim_doc_batch) >= _SLIM_BATCH_SIZE:
yield slim_doc_batch
slim_doc_batch = []
else:
raise ValueError(f"Unsupported content_type: {self.content_type}")
if slim_doc_batch:
yield slim_doc_batch
@override
def validate_connector_settings(self) -> None:
if self.client is None:
raise ZendeskCredentialsNotSetUpError()
try:
_get_article_page(self.client, start_time=0)
except HTTPError as e:
# Check for HTTP status codes
if e.response.status_code == 401:
raise CredentialExpiredError(
"Your Zendesk credentials appear to be invalid or expired (HTTP 401)."
) from e
elif e.response.status_code == 403:
raise InsufficientPermissionsError(
"Your Zendesk token does not have sufficient permissions (HTTP 403)."
) from e
elif e.response.status_code == 404:
raise ConnectorValidationError(
"Zendesk resource not found (HTTP 404)."
) from e
else:
raise ConnectorValidationError(
f"Unexpected Zendesk error (status={e.response.status_code}): {e}"
) from e
@override
def validate_checkpoint_json(
self, checkpoint_json: str
) -> ZendeskConnectorCheckpoint:
return ZendeskConnectorCheckpoint.model_validate_json(checkpoint_json)
@override
def build_dummy_checkpoint(self) -> ZendeskConnectorCheckpoint:
return ZendeskConnectorCheckpoint(
after_cursor_articles=None,
next_start_time_tickets=None,
cached_author_map=None,
cached_content_tags=None,
has_more=True,
)
if __name__ == "__main__":
import os
connector = ZendeskConnector(content_type="articles")
connector.load_credentials(
{
"zendesk_subdomain": os.environ["ZENDESK_SUBDOMAIN"],
"zendesk_email": os.environ["ZENDESK_EMAIL"],
"zendesk_token": os.environ["ZENDESK_TOKEN"],
}
)
current = time.time()
one_day_ago = current - 24 * 60 * 60 # 1 day
checkpoint = connector.build_dummy_checkpoint()
while checkpoint.has_more:
gen = connector.load_from_checkpoint(
one_day_ago, current, checkpoint
)
wrapper = CheckpointOutputWrapper()
any_doc = False
for document, failure, next_checkpoint in wrapper(gen):
if document:
print("got document:", document.id)
any_doc = True
checkpoint = next_checkpoint
if any_doc:
break

Some files were not shown because too many files have changed in this diff.