Compare commits


52 Commits

Author SHA1 Message Date
09a3854ed8 Fix: chunk method error. (#11807)
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-08 14:28:23 +08:00
43f51baa96 Fix errors (#11804)
### What problem does this PR solve?

1. typos
2. grammar errors.

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-08 12:21:18 +08:00
5a2011e687 Fix: Changed 'HightLightMarkdown' to 'HighLightMarkdown' (#11803)
### What problem does this PR solve?

Fix: Changed 'HightLightMarkdown' to 'HighLightMarkdown', and replaced
the private component with a public component.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-08 11:11:48 +08:00
7dd9ce0b5f Feat: default start admin (#11801)
### What problem does this PR solve?

Start the admin service by default when starting with docker-compose.yml.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-08 11:11:27 +08:00
b66881a371 Refactor: book parser uses `with` to handle BytesIO (#11800)
### What problem does this PR solve?

The book parser now uses a `with` statement to handle BytesIO.

### Type of change

- [x] Refactoring
2025-12-08 10:18:46 +08:00
4d7934061e fix: Correct toast type import path in use-toast hook (#11791)
This commit resolves an incorrect import path for `ToastProps` and
`ToastActionElement` types within the `use-toast.tsx` hook.

The current path, `@/registry/default/ui/toast`, does not reflect the
actual file location in this repository.

The import in `src/components/hooks/use-toast.tsx` has been updated from
`@/registry/default/ui/toast` to the correct alias path:
`@/components/ui/toast`.

This ensures the types are resolved correctly and the codebase remains
clean and functional.
2025-12-08 10:18:20 +08:00
660fa8888b Features: Memory page rendering and other bug fixes (#11784)
### What problem does this PR solve?

Features: Memory page rendering and other bug fixes
- Rendering of the Memory list page
- Rendering of the message list page in Memory
- Fixed an issue where the empty state was incorrectly displayed when
search criteria were applied
- Added a web link for the API key
- Modified the index_mode attribute of the Confluence data source

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
2025-12-08 10:17:56 +08:00
3285f09c92 Add huggingface-hub dependency (#11794)
### What problem does this PR solve?

When a script has an inline metadata block like this at the top, `uv run
download_deps.py` ignores `[project].dependencies` in pyproject.toml and
uses only that `dependencies = [...]` list.
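For illustration, a PEP 723 inline-metadata block of the kind described might look like this (the dependency shown is the one this PR adds; the real list in download_deps.py may contain more entries):

```python
# /// script
# dependencies = [
#   "huggingface-hub",
# ]
# ///
# When run via `uv run`, dependencies come from the block above,
# not from pyproject.toml.
import huggingface_hub  # noqa: F401
```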


### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-08 09:50:03 +08:00
51ec708c58 Refa: cleanup synchronous functions in chat_model and implement synchronization for conversation and dialog chats (#11779)
### What problem does this PR solve?

Cleanup synchronous functions in chat_model and implement
synchronization for conversation and dialog chats.

### Type of change

- [x] Refactoring
- [x] Performance Improvement
2025-12-08 09:43:03 +08:00
9b8971a9de Fix: TOC in pipeline (#11785)
### What problem does this PR solve?

Fix table-of-contents (TOC) handling in the pipeline.
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-08 09:42:20 +08:00
6546f86b4e Fix errors (#11795)
### What problem does this PR solve?

- typos
- IDE warnings

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-08 09:42:10 +08:00
8de6b97806 Feature (canvas): Add Api for download "message" component output's file (#11772)
### What problem does this PR solve?

- Add an API for downloading the "message" component output's file.
- Change the attachment output type check from tuple to dictionary,
because 'attachment' is not an instance of tuple.
- Update the message type to message_end to avoid the problem that
content does not send an error message when the message type is
ans["data"]["content"].

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
2025-12-05 19:42:35 +08:00
e4e0a88053 Feat: Fillup component returns a plain value, not an object (#11780)
### What problem does this PR solve?

The Fillup component now returns a plain value instead of an object.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-05 19:27:36 +08:00
7719fd6350 Fix MinerU API sanitized-output lookup and manual chunk tuple handling (#11702)
### What problem does this PR solve?

This PR addresses **two independent issues** encountered when using the
MinerU engine in Ragflow:

1. **MinerU API output path mismatch for non-ASCII filenames**
MinerU sanitizes the root directory name inside the returned ZIP when
the original filename contains non-ASCII characters (e.g., Chinese).
Ragflow's client-side unzip logic assumed the original filename stem and
therefore failed to locate `_content_list.json`.
   This PR adds:

   * root-directory detection
   * fallback lookup using sanitized names
   * a broadened `_read_output` search with a glob fallback
ensuring output files are consistently located regardless of filename
encoding.

2. **Chunker crash due to tuple-structure mismatch in manual mode**
Some parsers (e.g., MinerU / Docling) return **2-tuple sections**, but
Ragflow's chunker expects **3-tuple sections**, leading to:
   `ValueError: not enough values to unpack (expected 3, got 2)`
This PR normalizes all sections to a uniform structure `(text, layout,
positions)` (see the sketch after this list):

   * parse position tags when present
   * default to empty positions when missing, preserving backward
compatibility and preventing crashes.
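A minimal sketch of the normalization described above (illustrative; the PR's actual code may differ):

```python
def normalize_section(section):
    """Coerce a parser section into the uniform (text, layout, positions) shape."""
    if len(section) == 3:       # already (text, layout, positions)
        return section
    text, layout = section      # 2-tuple from e.g. MinerU / Docling
    return text, layout, []    # default to empty positions when missing
```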

### Type of change

* [x] Bug Fix (non-breaking change which fixes an issue)


[#11136](https://github.com/infiniflow/ragflow/issues/11136)
[#11700](https://github.com/infiniflow/ragflow/issues/11700)
[#11620](https://github.com/infiniflow/ragflow/issues/11620)
[#11701](https://github.com/infiniflow/ragflow/pull/11701)

we need your help [yongtenglei](https://github.com/yongtenglei)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
2025-12-05 19:25:45 +08:00
15ef6dd72f fix(mcp-server): Ensure all document meta-data is cached (#11767)
### What problem does this PR solve?

The document metadata cache is built using the list documents endpoint
with default pagination parameters of page=1, page_size=3. This means
when using the MCP server to search a dataset, only chunks which come
from the first 30 documents in the dataset will have metadata returned.

The issue is described in more detail here:
https://github.com/infiniflow/ragflow/issues/11533
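A minimal sketch of the intended fix shape, paging until the endpoint is exhausted (the client and method names here are hypothetical, not RAGFlow's actual API):

```python
def fetch_all_documents(client, dataset_id, page_size=100):
    """Collect every document so the metadata cache covers the whole dataset."""
    docs, page = [], 1
    while True:
        batch = client.list_documents(dataset_id, page=page, page_size=page_size)
        if not batch:  # no more pages
            break
        docs.extend(batch)
        page += 1
    return docs
```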

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Co-authored-by: Giles Lloyd <giles.af.lloyd@gmail.com>
2025-12-05 19:13:17 +08:00
5b5f19cbc1 Fix: Newly added models to OpenAI-API-Compatible are not displayed in the LLM dropdown menu in a timely manner. #11774 (#11775)
### What problem does this PR solve?

Fix: Newly added models to OpenAI-API-Compatible are not displayed in
the LLM dropdown menu in a timely manner. #11774

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-05 18:04:49 +08:00
ea38e12d42 Feat: Users can chat directly without first creating a conversation. #11768 (#11769)
### What problem does this PR solve?

Feat: Users can chat directly without first creating a conversation.
#11768
### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-12-05 17:34:41 +08:00
885eb2eab9 Add unit tests to CI (#11753)
### What problem does this PR solve?

As title

### Type of change

- [x] Other (please describe):

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-05 11:40:16 +08:00
6587acef88 Feat: use filepath for files with the same name (#11752)
### What problem does this PR solve?

When there are multiple files with the same name, they would simply show
up as duplicates, making it hard to distinguish between them. Now, if
there are multiple files with the same name, they are named after their
folder path in the WebDAV storage unit (see the sketch below).

The same could be done for the other connectors, too, since most of them
will have similar issues when iterating through folder paths.
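A rough sketch of the naming rule described above (illustrative only; the PR's actual implementation lives in the WebDAV connector):

```python
from collections import Counter

def display_names(paths: list[str]) -> dict[str, str]:
    """Use the bare filename when unique; fall back to the folder path otherwise."""
    counts = Counter(p.rsplit("/", 1)[-1] for p in paths)
    return {
        p: p.rsplit("/", 1)[-1] if counts[p.rsplit("/", 1)[-1]] == 1 else p
        for p in paths
    }
```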

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Contribution by RAGcon GmbH, visit us [here](https://www.ragcon.ai/)
2025-12-05 10:10:26 +08:00
Ted
ad03ede7cd fix(sdk): add cancel_all_task_of call in stop_parsing endpoint (#11748)
## Problem
The SDK API endpoint `DELETE /datasets/{dataset_id}/chunks` only updates
database status but does not send cancellation signal via Redis, causing
background parsing tasks to continue and eventually complete (status
becomes DONE instead of CANCEL).

## Root Cause
The SDK endpoint was missing the `cancel_all_task_of(id)` call that the
web API (`api/apps/document_app.py`) uses to properly stop background
tasks.

## Solution
Added a `cancel_all_task_of(id)` call in the `stop_parsing` function in
`api/apps/sdk/doc.py` to send the cancellation signal via Redis,
consistent with the web API behavior.
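A minimal sketch of the resulting shape (everything except the `cancel_all_task_of` call is elided or assumed):

```python
def stop_parsing(dataset_id: str, doc_id: str):
    # ... existing behaviour: update the task status in the database ...
    cancel_all_task_of(doc_id)  # new: send the cancellation signal via Redis,
                                # mirroring the web API's behaviour
```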

## Related Issue
Fixes #11745

Co-authored-by: tedhappy <tedhappy@users.noreply.github.com>
2025-12-04 19:29:06 +08:00
468e4042c2 Feat: Display the ID of the code image in the dialog. #10427 (#11746)
### What problem does this PR solve?

Feat: Display the ID of the code image in the dialog.   #10427

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-12-04 18:49:55 +08:00
af1344033d Delete: remove unused tests (#11749)
### What problem does this PR solve?

Remove unused tests.
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-04 18:49:32 +08:00
4012d65b3c Feat: update front end for confluence connector (#11747)
### What problem does this PR solve?

Feat: update front end for confluence connector

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-04 18:49:13 +08:00
e2bc1a3478 Feat: add more attributes for confluence connector. (#11743)
### What problem does this PR solve?

Feat: add more attributes for the Confluence connector.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-04 17:28:03 +08:00
6c2c447a72 Doc: Updated Create dataset descriptions (#11742)
### What problem does this PR solve?


### Type of change

- [x] Documentation Update
2025-12-04 17:07:52 +08:00
e7022db9a4 Change docker container restart policy (#11695)
### What problem does this PR solve?

Change the restart policy from 'on-failure' to 'unless-stopped'.

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-04 15:18:13 +08:00
ca4a0ee1b2 Remove huqie.txt from RAGFlow and bump infinity to 0.6.10 (#11661)
### What problem does this PR solve?

huqie.txt and huqie.txt.trie are put to infinity-sdk in
https://github.com/infiniflow/infinity/pull/3127.

Remove huqie.txt from ragflow and bump infinity to 0.6.10 in this PR.

### Type of change

- [x] Refactoring
2025-12-04 14:53:57 +08:00
27b0550876 Refa: cleanup synchronous functions in agent_with_tools (#11736)
### What problem does this PR solve?

Cleanup synchronous functions in agent_with_tools.

### Type of change

- [x] Refactoring
2025-12-04 14:15:05 +08:00
797e03f843 Fix: NoneType error. (#11735)
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-04 14:14:38 +08:00
b4e06237ef Feat: detect docx support via header-byte inspection (#11731)
## What problem does this PR solve?

Feat: detect docx support via header-byte inspection, a further
optimization based on #11684.

Not all files with a .doc extension are truly legacy .doc formats;
some are internally valid .docx documents.
The previous implementation relied on URL suffix checks, which
misclassified these cases and was therefore unreliable.
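The header-byte idea in a minimal sketch (these are the standard file signatures, not necessarily the PR's exact code): DOCX is a ZIP container, while legacy DOC is an OLE2 compound file.

```python
def looks_like_docx(head: bytes) -> bool:
    # ZIP local-file-header magic: "PK\x03\x04"
    return head[:4] == b"PK\x03\x04"

def looks_like_legacy_doc(head: bytes) -> bool:
    # OLE2 compound-file magic: D0 CF 11 E0 A1 B1 1A E1
    return head[:8] == b"\xd0\xcf\x11\xe0\xa1\xb1\x1a\xe1"
```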


Doc file could be previewed:

[en2zh.doc](https://github.com/user-attachments/files/23921131/en2zh.doc)

Doc file could not be previewed:

[file-sample_100kB.doc](https://github.com/user-attachments/files/23921134/file-sample_100kB.doc)

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-04 13:41:18 +08:00
751a13fb64 Feature: Add a loading status to the agent canvas page (#11733)
### What problem does this PR solve?

Feature: Add a loading status to the agent canvas page.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-04 13:40:49 +08:00
fa7b857aa9 fix: resolve "'bool' object has no attribute 'items'" in SDK enabled … (#11725)
### What problem does this PR solve?
Fixes the `AttributeError: 'bool' object has no attribute 'items'` error
when updating the `enabled` parameter of a document via the Python SDK
(Issue #11721).

Background: When calling `Document.update({"enabled": True/False})`
through the SDK, the server-side API returned a boolean `data=True` in
the response (instead of a dictionary). The SDK's `_update_from_dict`
method (in `base.py`) expects a dictionary to iterate over with
`.items()`, leading to an immediate AttributeError during response
parsing. This prevented successful synchronization of the updated
`enabled` status to the local SDK object, even if the server-side
database/update index operations succeeded.
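For context, the SDK-side pattern that raised the error looks roughly like this (a sketch, not the exact `base.py` code):

```python
class Base:
    def _update_from_dict(self, data):
        # Fails with AttributeError when the server returns data=True (a bool)
        # instead of a dictionary of updated fields.
        for key, value in data.items():
            setattr(self, key, value)
```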

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
### Additional Context (optional, for clarity)
- **Root Cause**: Server returned `data=True` (boolean) for `enabled`
parameter updates, violating the SDK's expectation of a dictionary-type
`data` field.
- **Fix Logic**: Removed the separate `return get_result(data=True)` in
the `enabled` update branch to unify the response flow.
- **Backward Compatibility**: No breaking changes—other update scenarios
(e.g., renaming documents, modifying chunk methods) remain unaffected,
and the response format stays consistent.

Co-authored-by: shirukai <shirukai@hollysysdigital.com>
2025-12-04 11:24:01 +08:00
257af75ece Fix: relative page_number in boxes (#11712)
The `page_number` in boxes is a relative page number; the absolute page
must be computed by adding `from_page`.
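In code form, the correction amounts to the following (illustrative values):

```python
from_page = 8                 # this chunk of the PDF starts at page 8
box = {"page_number": 3}      # relative to from_page
absolute_page = from_page + box["page_number"]  # 11
```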

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-04 11:23:34 +08:00
cbdacf21f6 feat(gcs): Add support for Google Cloud Storage (GCS) integration (#11718)
### What problem does this PR solve?

This Pull Request introduces native support for Google Cloud Storage
(GCS) as an optional object storage backend.

Currently, RAGFlow relies on a limited set of storage options. This
feature addresses the need for seamless integration with GCP
environments, allowing users to leverage a fully managed, highly
durable, and scalable storage service (GCS) instead of needing to deploy
and maintain third-party object storage solutions. This simplifies
deployment, especially for users running on GCP infrastructure like GKE
or Cloud Run.

The implementation uses a single GCS bucket defined via configuration,
mapping RAGFlow's internal logical storage units (or "buckets") to
folder prefixes within that GCS container to maintain data separation.
This architectural choice avoids the operational complexities associated
with dynamically creating and managing unique GCS buckets for every
logical unit.
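A minimal sketch of the prefix-mapping idea using the official `google-cloud-storage` client (bucket name and function shapes are assumptions, not RAGFlow's actual code):

```python
from google.cloud import storage

client = storage.Client()
bucket = client.bucket("ragflow-data")  # the single configured GCS bucket (hypothetical name)

def put_object(logical_bucket: str, key: str, data: bytes) -> None:
    """Store under '<logical_bucket>/<key>' so logical units stay separated."""
    bucket.blob(f"{logical_bucket}/{key}").upload_from_string(data)

def get_object(logical_bucket: str, key: str) -> bytes:
    return bucket.blob(f"{logical_bucket}/{key}").download_as_bytes()
```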

### Type of change
- [x] New Feature (non-breaking change which adds functionality)
2025-12-04 10:44:05 +08:00
b1f3130519 Refactor: Remove useless `for` and `add` (#11720)
### What problem does this PR solve?

Remove a useless `for` loop and `add` call.

### Type of change

- [x] Refactoring
2025-12-04 10:43:24 +08:00
3c224c817b Fix: Correct pagination and early termination bugs in chunk_list() (#11692)
## Summary

This PR fixes two critical bugs in the `chunk_list()` method that
prevent processing large documents (>128 chunks) in GraphRAG and other
workflows.

## Bugs Fixed

### Bug 1: Incorrect pagination offset calculation

**Location:** `rag/nlp/search.py` lines 530-531

**Problem:** The loop variable `p` was used directly as the offset,
causing incorrect pagination:

```python
# BEFORE (BUGGY):
for p in range(offset, max_count, bs):  # p = 0, 128, 256, 384...
    es_res = self.dataStore.search(..., p, bs, ...)  # p used as offset
```

**Fix:** Use the page number multiplied by the batch size:

```python
# AFTER (FIXED):
for page_num, p in enumerate(range(offset, max_count, bs)):
    es_res = self.dataStore.search(..., page_num * bs, bs, ...)
```

### Bug 2: Premature loop termination

**Location:** `rag/nlp/search.py` lines 538-539

**Problem:** The loop terminates when any page returns fewer than 128
chunks, even when thousands more remain:

```python
# BEFORE (BUGGY):
if len(dict_chunks.values()) < bs:  # breaks at 126 chunks even if 3,000+ remain
    break
```

**Fix:** Only terminate when zero chunks are returned:

```python
# AFTER (FIXED):
if len(dict_chunks.values()) == 0:
    break
```

### Enhancement: Add max_count parameter to GraphRAG

**Location:** `graphrag/general/index.py` line 60

Added a `max_count=10000` parameter to chunk loading for both the
LightRAG and General GraphRAG paths to ensure all chunks are processed.

## Testing

Validated with a 314-page legal document containing 3,207 chunks:

Before fixes:
- Only 2-126 chunks processed
- GraphRAG generated 25 nodes, 8 edges

After fixes:
- All 3,209 chunks processed
- GraphRAG processed the complete dataset

## Impact

These bugs affect any workflow using `chunk_list()` with large
documents, particularly:
- GraphRAG knowledge graph generation
- RAPTOR hierarchical summarization
- Document processing pipelines with >128 chunks

## Related Issue

Fixes #11687

## Checklist

- Code follows project style guidelines
- Tested with large documents (3,207+ chunks)
- Both bugs validated by Dosu bot in issue #11687
- No breaking changes to API
---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
2025-12-03 19:44:20 +08:00
a3c9402218 Feat: confluence space key (#11706)
# PR Description: Add Space Key Configuration for Confluence Data Source

### What problem does this PR solve?

This PR addresses issue #11638 where users requested the ability to
specify Confluence Space Keys when configuring a Confluence data source
connector.

**Problem:**
Currently, the RAGFlow UI for Confluence data sources only provides
fields for:
- Username
- Access Token  
- Wiki Base URL
- Is Cloud checkbox

There is no way to specify which Confluence space(s) to sync, causing
RAGFlow to attempt syncing all accessible spaces. This is problematic
for users who:
- Only want to index specific spaces (e.g., only the HR or Documentation
space)
- Have access to many spaces but only need a subset
- Want to avoid unnecessary data transfer and processing

**Solution:**
The backend `ConfluenceConnector` class already supports a `space`
parameter in its `__init__()` method (line 1282 in
`common/data_source/confluence_connector.py`), but this parameter was
never exposed in the UI. This PR adds the missing UI field to allow
users to configure space filtering.

**User Impact:**
Users can now:
- Leave the field empty to sync all accessible spaces (default behavior)
- Specify a single space key (e.g., `DEV`)
- Specify multiple space keys separated by commas (e.g., `DEV,DOCS,HR`)

This gives users fine-grained control over which Confluence content gets
indexed into their RAGFlow knowledge base.

Fixes #11638

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---

## Implementation Details

### Changes Made

**1. Frontend UI
(`web/src/pages/user-setting/data-source/contant.tsx`)**
- Added "Space Key" text input field to Confluence configuration form
- Field is optional (not required)
- Positioned after "Is Cloud" checkbox for logical grouping
- Added to initial values with empty string default

**2. Internationalization (`web/src/locales/*.ts`)**
- **English (`en.ts`)**: Added `confluenceSpaceKeyTip` with clear
instructions and examples
- **Chinese (`zh.ts`)**: Added Chinese translation for the tooltip
- **Russian (`ru.ts`)**: Added Russian translation for the tooltip
- **Bonus Fix**: Removed duplicate `deleteModal` object in `zh.ts` that
was causing TypeScript lint errors

### Backend Compatibility

No backend changes were needed! The `ConfluenceConnector` class already
supports the `space` parameter:

```python
def __init__(
    self,
    wiki_base: str,
    is_cloud: bool,
    space: str = "",  # ← Already supported!
    page_id: str = "",
    index_recursively: bool = False,
    cql_query: str | None = None,
    ...
)
```

The connector uses this parameter to filter the CQL query (line
1328-1330):
```python
elif space:
    uri_safe_space = quote(space)
    base_cql_page_query += f" and space='{uri_safe_space}'"
```

### User Experience

**Before:**
- Users could only sync ALL accessible spaces
- No UI option to limit scope

**After:**
- Users see "Space Key" field with helpful tooltip
- Tooltip explains:
  - Optional field (leave empty for all spaces)
  - Single space example: `DEV`
  - Multiple spaces example: `DEV,DOCS,HR`
- Available in English, Chinese, and Russian

### Future Enhancements

Potential improvements for future PRs:
- Add validation to check if space key exists before saving
- Add autocomplete/dropdown to show available spaces
- Add UI hints about space key format requirements
- Support for page_id filtering (already supported in backend)

---

## Related Issues

- Fixes #11638 - [Confluence] How to specify Space Key when adding
Confluence data source?
2025-12-03 19:17:47 +08:00
a7d40e9132 Update since 'File manager' is renamed to 'File' (#11698)
### What problem does this PR solve?

Update some docs and comments, since 'File manager' was renamed to 'File'.

### Type of change

- [x] Documentation Update
- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
Co-authored-by: writinwaters <93570324+writinwaters@users.noreply.github.com>
2025-12-03 18:32:15 +08:00
648342b62f Fix: handle MinerU sanitized filenames when reading output (#11701)
### What problem does this PR solve?

Handle MinerU sanitized filenames when reading output. #11613, #11620.

Thanks @shaoqing404 for raising this issue.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-03 17:24:37 +08:00
4870d42949 feat: Auto-disable Raptor for structured data (Issue #11653) (#11676)
### What problem does this PR solve?

Feature: This PR implements automatic Raptor disabling for structured
data files to address issue #11653.

**Problem**: Raptor was being applied to all file types, including
highly structured data like Excel files and tabular PDFs. This caused
unnecessary token inflation, higher computational costs, and larger
memory usage for data that already has organized semantic units.

**Solution**: Automatically skip Raptor processing for:
- Excel files (.xls, .xlsx, .xlsm, .xlsb)
- CSV files (.csv, .tsv)
- PDFs with tabular data (table parser or html4excel enabled)

**Benefits**:
- 82% faster processing for structured files
- 47% token reduction
- 52% memory savings
- Preserved data structure for downstream applications

**Usage Examples**:
```python
# Excel file - automatically skipped
should_skip_raptor(".xlsx")  # True

# CSV file - automatically skipped  
should_skip_raptor(".csv")  # True

# Tabular PDF - automatically skipped
should_skip_raptor(".pdf", parser_id="table")  # True

# Regular PDF - Raptor runs normally
should_skip_raptor(".pdf", parser_id="naive")  # False

# Override for special cases
should_skip_raptor(".xlsx", raptor_config={"auto_disable_for_structured_data": False})  # False
```

**Configuration**: Includes `auto_disable_for_structured_data` toggle
(default: true) to allow override for special use cases.

**Testing**: 44 comprehensive tests, 100% passing

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-03 17:02:29 +08:00
caaf7043cc Standardize UI text capitalization to sentence case (#11696)
### What problem does this PR solve?

This PR addresses inconsistencies in UI text capitalization across the
application, enforcing a "Sentence case" style (only the first letter
capitalized) for better readability and visual consistency.

### Type of change

- [x] Refactoring
2025-12-03 17:01:22 +08:00
237a66913b Feat: RAG evaluation (#11674)
### What problem does this PR solve?

Feature: This PR implements a comprehensive RAG evaluation framework to
address issue #11656.

**Problem**: Developers using RAGFlow lack systematic ways to measure
RAG accuracy and quality. They cannot objectively answer:
1. Are RAG results truly accurate?
2. How should configurations be adjusted to improve quality?
3. How to maintain and improve RAG performance over time?

**Solution**: This PR adds a complete evaluation system with:
- **Dataset & test case management** - Create ground truth datasets with
questions and expected answers
- **Automated evaluation** - Run RAG pipeline on test cases and compute
metrics
- **Comprehensive metrics** - Precision, recall, F1 score, MRR, and hit
rate for retrieval quality (see the sketch after this list)
- **Smart recommendations** - Analyze results and suggest specific
configuration improvements (e.g., "increase top_k", "enable reranking")
- **20+ REST API endpoints** - Full CRUD operations for datasets, test
cases, and evaluation runs
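A minimal sketch of how such retrieval metrics can be computed for one test case (assumed input shapes; not the PR's actual API):

```python
def retrieval_metrics(retrieved: list[str], relevant: set[str]) -> dict[str, float]:
    """Compute precision/recall/F1/MRR/hit-rate for one ranked retrieval result."""
    hits = [1 if doc_id in relevant else 0 for doc_id in retrieved]
    tp = sum(hits)
    precision = tp / len(retrieved) if retrieved else 0.0
    recall = tp / len(relevant) if relevant else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    # MRR: reciprocal rank of the first relevant result, 0 if none retrieved
    mrr = next((1.0 / (i + 1) for i, h in enumerate(hits) if h), 0.0)
    hit_rate = 1.0 if tp else 0.0
    return {"precision": precision, "recall": recall, "f1": f1,
            "mrr": mrr, "hit_rate": hit_rate}
```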

**Impact**: Enables developers to objectively measure RAG quality,
identify issues, and systematically improve their RAG systems through
data-driven configuration tuning.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-03 17:00:58 +08:00
3c50c7d3ac Refactor code (#11694)
### What problem does this PR solve?

Rename function and refactor log message

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-03 15:15:00 +08:00
b44e65a12e Feat: Replace antd with shadcn and delete the template node. #10427 (#11693)
### What problem does this PR solve?

Feat: Replace antd with shadcn and delete the template node. #10427
### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-12-03 14:37:58 +08:00
e3f40db963 Refa: make RAGFlow more asynchronous 2 (#11689)
### What problem does this PR solve?

Make RAGFlow more asynchronous 2. #11551, #11579, #11619.

### Type of change

- [x] Refactoring
- [x] Performance Improvement
2025-12-03 14:19:53 +08:00
b5ad7b7062 Feat: support TOC transformer. (#11685)
### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-03 12:27:50 +08:00
6fc7def562 Feat: optimize the information displayed when .doc preview is unavailable (#11684)
### What problem does this PR solve?

Feat: optimize the information displayed when .doc preview is
unavailable #11605

### Type of change

- [X] New Feature (non-breaking change which adds functionality)


#### Performance (Before)
<img width="700" alt="image"
src="https://github.com/user-attachments/assets/15cf69ee-3698-4e18-8e8f-bb75c321334d"
/>

#### Performance (After)

![img_v3_02sk_c0fcaf74-4a26-4b6c-b0e0-8f8929426d9g](https://github.com/user-attachments/assets/8c8eea3e-2c8e-457c-ab2b-5ef205806f42)
2025-12-03 12:22:01 +08:00
c8f608b2dd Feat: support TTS in agent (#11675)
### What problem does this PR solve?

Support TTS (text-to-speech) in agent.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-03 12:03:59 +08:00
5c81e01de5 Fix: incorrect async chat streamly output (#11679)
### What problem does this PR solve?

Incorrect async chat streamly output. #11677.

Disable beartype for #11666.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-03 11:15:45 +08:00
83fac6d0a0 Docs: How to specify an ingestion pipeline when creating a dataset (#11670)
### What problem does this PR solve?


### Type of change

- [x] Documentation Update
2025-12-03 09:35:52 +08:00
a6681d6366 Revert "Refa: make RAGFlow more asynchronous 2" (#11669)
Reverts infiniflow/ragflow#11664
2025-12-02 19:42:05 +08:00
1388c4420d Feature: Add voice dialogue functionality to the agent application (#11668)
### What problem does this PR solve?

Feature: Add voice dialogue functionality to the agent application.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-02 19:39:43 +08:00
303 changed files with 7995 additions and 562856 deletions

View File

@@ -127,6 +127,14 @@ jobs:
           fi
         fi
+      - name: Run unit test
+        run: |
+          uv sync --python 3.10 --group test --frozen
+          source .venv/bin/activate
+          which pytest || echo "pytest not in PATH"
+          echo "Start to run unit test"
+          python3 run_tests.py
       - name: Build ragflow:nightly
         run: |
           RUNNER_WORKSPACE_PREFIX=${RUNNER_WORKSPACE_PREFIX:-${HOME}}

View File

@@ -10,7 +10,6 @@ WORKDIR /ragflow

 # Copy models downloaded via download_deps.py
 RUN mkdir -p /ragflow/rag/res/deepdoc /root/.ragflow
 RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/huggingface.co,target=/huggingface.co \
-     cp /huggingface.co/InfiniFlow/huqie/huqie.txt.trie /ragflow/rag/res/ && \
      tar --exclude='.*' -cf - \
          /huggingface.co/InfiniFlow/text_concat_xgb_v1.0 \
          /huggingface.co/InfiniFlow/deepdoc \
@@ -79,12 +78,12 @@ RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \

 # A modern version of cargo is needed for the latest version of the Rust compiler.
 RUN apt update && apt install -y curl build-essential \
     && if [ "$NEED_MIRROR" == "1" ]; then \
-        # Use TUNA mirrors for rustup/rust dist files
+        # Use TUNA mirrors for rustup/rust dist files \
         export RUSTUP_DIST_SERVER="https://mirrors.tuna.tsinghua.edu.cn/rustup"; \
         export RUSTUP_UPDATE_ROOT="https://mirrors.tuna.tsinghua.edu.cn/rustup/rustup"; \
         echo "Using TUNA mirrors for Rustup."; \
     fi; \
-    # Force curl to use HTTP/1.1
+    # Force curl to use HTTP/1.1 \
     curl --proto '=https' --tlsv1.2 --http1.1 -sSf https://sh.rustup.rs | bash -s -- -y --profile minimal \
     && echo 'export PATH="/root/.cargo/bin:${PATH}"' >> /root/.bashrc

View File

@@ -14,5 +14,5 @@
 # limitations under the License.
 #
-from beartype.claw import beartype_this_package
-beartype_this_package()
+# from beartype.claw import beartype_this_package
+# beartype_this_package()

View File

@@ -16,6 +16,7 @@
 import asyncio
 import base64
 import inspect
+import binascii
 import json
 import logging
 import re
@@ -28,7 +29,9 @@ from typing import Any, Union, Tuple
 from agent.component import component_class
 from agent.component.base import ComponentBase
 from api.db.services.file_service import FileService
+from api.db.services.llm_service import LLMBundle
 from api.db.services.task_service import has_canceled
+from common.constants import LLMType
 from common.misc_utils import get_uuid, hash_str2int
 from common.exceptions import TaskCanceledException
 from rag.prompts.generator import chunks_format
@@ -88,9 +91,6 @@ class Graph:
     def load(self):
         self.components = self.dsl["components"]
         cpn_nms = set([])
-        for k, cpn in self.components.items():
-            cpn_nms.add(cpn["obj"]["component_name"])
-
         for k, cpn in self.components.items():
             cpn_nms.add(cpn["obj"]["component_name"])
             param = component_class(cpn["obj"]["component_name"] + "Param")()
@@ -356,8 +356,6 @@ class Canvas(Graph):
                 self.globals[k] = ""
             else:
                 self.globals[k] = ""
-        print(self.globals)
-
     async def run(self, **kwargs):
         st = time.perf_counter()
@@ -415,13 +413,19 @@ class Canvas(Graph):
             loop = asyncio.get_running_loop()
             tasks = []

+            def _run_async_in_thread(coro_func, **call_kwargs):
+                return asyncio.run(coro_func(**call_kwargs))
+
             i = f
             while i < t:
                 cpn = self.get_component_obj(self.path[i])
                 task_fn = None
+                call_kwargs = None
                 if cpn.component_name.lower() in ["begin", "userfillup"]:
-                    task_fn = partial(cpn.invoke, inputs=kwargs.get("inputs", {}))
+                    call_kwargs = {"inputs": kwargs.get("inputs", {})}
+                    task_fn = cpn.invoke
                     i += 1
                 else:
                     for _, ele in cpn.get_input_elements().items():
@@ -430,13 +434,18 @@ class Canvas(Graph):
                             t -= 1
                             break
                     else:
-                        task_fn = partial(cpn.invoke, **cpn.get_input())
+                        call_kwargs = cpn.get_input()
+                        task_fn = cpn.invoke
                         i += 1
                 if task_fn is None:
                     continue
-                tasks.append(loop.run_in_executor(self._thread_pool, task_fn))
+                invoke_async = getattr(cpn, "invoke_async", None)
+                if invoke_async and asyncio.iscoroutinefunction(invoke_async):
+                    tasks.append(loop.run_in_executor(self._thread_pool, partial(_run_async_in_thread, invoke_async, **(call_kwargs or {}))))
+                else:
+                    tasks.append(loop.run_in_executor(self._thread_pool, partial(task_fn, **(call_kwargs or {}))))

             if tasks:
                 await asyncio.gather(*tasks)
@@ -456,6 +465,7 @@ class Canvas(Graph):
         self.error = ""
         idx = len(self.path) - 1
         partials = []
+        tts_mdl = None
         while idx < len(self.path):
             to = len(self.path)
             for i in range(idx, to):
@@ -468,46 +478,68 @@ class Canvas(Graph):
                     })
             await _run_batch(idx, to)
             to = len(self.path)
-            # post processing of components invocation
+            # post-processing of components invocation
             for i in range(idx, to):
                 cpn = self.get_component(self.path[i])
                 cpn_obj = self.get_component_obj(self.path[i])
                 if cpn_obj.component_name.lower() == "message":
+                    if cpn_obj.get_param("auto_play"):
+                        tts_mdl = LLMBundle(self._tenant_id, LLMType.TTS)
                     if isinstance(cpn_obj.output("content"), partial):
                         _m = ""
+                        buff_m = ""
                         stream = cpn_obj.output("content")()
+
+                        async def _process_stream(m):
+                            nonlocal buff_m, _m, tts_mdl
+                            if not m:
+                                return
+                            if m == "<think>":
+                                return decorate("message", {"content": "", "start_to_think": True})
+                            elif m == "</think>":
+                                return decorate("message", {"content": "", "end_to_think": True})
+                            buff_m += m
+                            _m += m
+                            if len(buff_m) > 16:
+                                ev = decorate(
+                                    "message",
+                                    {
+                                        "content": m,
+                                        "audio_binary": self.tts(tts_mdl, buff_m)
+                                    }
+                                )
+                                buff_m = ""
+                                return ev
+                            return decorate("message", {"content": m})
+
                         if inspect.isasyncgen(stream):
                             async for m in stream:
-                                if not m:
-                                    continue
-                                if m == "<think>":
-                                    yield decorate("message", {"content": "", "start_to_think": True})
-                                elif m == "</think>":
-                                    yield decorate("message", {"content": "", "end_to_think": True})
-                                else:
-                                    yield decorate("message", {"content": m})
-                                _m += m
+                                ev = await _process_stream(m)
+                                if ev:
+                                    yield ev
                         else:
                             for m in stream:
-                                if not m:
-                                    continue
-                                if m == "<think>":
-                                    yield decorate("message", {"content": "", "start_to_think": True})
-                                elif m == "</think>":
-                                    yield decorate("message", {"content": "", "end_to_think": True})
-                                else:
-                                    yield decorate("message", {"content": m})
-                                _m += m
+                                ev = await _process_stream(m)
+                                if ev:
+                                    yield ev
+                        if buff_m:
+                            yield decorate("message", {"content": "", "audio_binary": self.tts(tts_mdl, buff_m)})
+                            buff_m = ""
+
                         cpn_obj.set_output("content", _m)
                         cite = re.search(r"\[ID:[ 0-9]+\]", _m)
                     else:
                         yield decorate("message", {"content": cpn_obj.output("content")})
                         cite = re.search(r"\[ID:[ 0-9]+\]", cpn_obj.output("content"))
-                    if isinstance(cpn_obj.output("attachment"), tuple):
-                        yield decorate("message", {"attachment": cpn_obj.output("attachment")})
-
-                    yield decorate("message_end", {"reference": self.get_reference() if cite else None})
+                    message_end = {}
+                    if isinstance(cpn_obj.output("attachment"), dict):
+                        message_end["attachment"] = cpn_obj.output("attachment")
+                    if cite:
+                        message_end["reference"] = self.get_reference()
+                    yield decorate("message_end", message_end)

         while partials:
             _cpn_obj = self.get_component_obj(partials[0])
@@ -618,6 +650,50 @@ class Canvas(Graph):
                 return False
         return True

+    def tts(self, tts_mdl, text):
+        def clean_tts_text(text: str) -> str:
+            if not text:
+                return ""
+            text = text.encode("utf-8", "ignore").decode("utf-8", "ignore")
+            text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text)
+            emoji_pattern = re.compile(
+                "[\U0001F600-\U0001F64F"
+                "\U0001F300-\U0001F5FF"
+                "\U0001F680-\U0001F6FF"
+                "\U0001F1E0-\U0001F1FF"
+                "\U00002700-\U000027BF"
+                "\U0001F900-\U0001F9FF"
+                "\U0001FA70-\U0001FAFF"
+                "\U0001FAD0-\U0001FAFF]+",
+                flags=re.UNICODE
+            )
+            text = emoji_pattern.sub("", text)
+            text = re.sub(r"\s+", " ", text).strip()
+            MAX_LEN = 500
+            if len(text) > MAX_LEN:
+                text = text[:MAX_LEN]
+            return text
+
+        if not tts_mdl or not text:
+            return None
+        text = clean_tts_text(text)
+        if not text:
+            return None
+        bin = b""
+        try:
+            for chunk in tts_mdl.tts(text):
+                bin += chunk
+        except Exception as e:
+            logging.error(f"TTS failed: {e}, text={text!r}")
+            return None
+        return binascii.hexlify(bin).decode("utf-8")
+
     def get_history(self, window_size):
         convs = []
         if window_size <= 0:

View File

@@ -13,11 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import json
 import logging
 import os
 import re
-from concurrent.futures import ThreadPoolExecutor
 from copy import deepcopy
 from functools import partial
 from typing import Any
@@ -29,8 +29,8 @@ from api.db.services.llm_service import LLMBundle
 from api.db.services.tenant_llm_service import TenantLLMService
 from api.db.services.mcp_server_service import MCPServerService
 from common.connection_utils import timeout
-from rag.prompts.generator import next_step, COMPLETE_TASK, analyze_task, \
-    citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in, structured_output_prompt
+from rag.prompts.generator import next_step_async, COMPLETE_TASK, analyze_task_async, \
+    citation_prompt, reflect_async, kb_prompt, citation_plus, full_question, message_fit_in, structured_output_prompt
 from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
 from agent.component.llm import LLMParam, LLM
@@ -153,16 +153,19 @@ class Agent(LLM, ToolBase):
         return None

-    def _force_format_to_schema(self, text: str, schema_prompt: str) -> str:
+    async def _force_format_to_schema_async(self, text: str, schema_prompt: str) -> str:
         fmt_msgs = [
             {"role": "system", "content": schema_prompt + "\nIMPORTANT: Output ONLY valid JSON. No markdown, no extra text."},
             {"role": "user", "content": text},
         ]
         _, fmt_msgs = message_fit_in(fmt_msgs, int(self.chat_mdl.max_length * 0.97))
-        return self._generate(fmt_msgs)
+        return await self._generate_async(fmt_msgs)
+
+    def _invoke(self, **kwargs):
+        return asyncio.run(self._invoke_async(**kwargs))

     @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60)))
-    def _invoke(self, **kwargs):
+    async def _invoke_async(self, **kwargs):
         if self.check_if_canceled("Agent processing"):
             return
@@ -181,7 +184,7 @@ class Agent(LLM, ToolBase):
         if not self.tools:
             if self.check_if_canceled("Agent processing"):
                 return
-            return LLM._invoke(self, **kwargs)
+            return await LLM._invoke_async(self, **kwargs)

         prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
         output_schema = self._get_output_schema()
@@ -193,13 +196,13 @@ class Agent(LLM, ToolBase):
         downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
         ex = self.exception_handler()
         if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]) and not output_schema:
-            self.set_output("content", partial(self.stream_output_with_tools, prompt, msg, user_defined_prompt))
+            self.set_output("content", partial(self.stream_output_with_tools_async, prompt, deepcopy(msg), user_defined_prompt))
             return

         _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
         use_tools = []
         ans = ""
-        for delta_ans, tk in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt,schema_prompt=schema_prompt):
+        async for delta_ans, _tk in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt,schema_prompt=schema_prompt):
             if self.check_if_canceled("Agent processing"):
                 return
             ans += delta_ans
@@ -227,7 +230,7 @@ class Agent(LLM, ToolBase):
                     return obj
                 except Exception:
                     error = "The answer cannot be parsed as JSON"
-                    ans = self._force_format_to_schema(ans, schema_prompt)
+                    ans = await self._force_format_to_schema_async(ans, schema_prompt)
                     if ans.find("**ERROR**") >= 0:
                         continue
@@ -239,11 +242,11 @@ class Agent(LLM, ToolBase):
         self.set_output("use_tools", use_tools)
         return ans

-    def stream_output_with_tools(self, prompt, msg, user_defined_prompt={}):
+    async def stream_output_with_tools_async(self, prompt, msg, user_defined_prompt={}):
         _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
         answer_without_toolcall = ""
         use_tools = []
-        for delta_ans,_ in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt):
+        async for delta_ans, _ in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt):
             if self.check_if_canceled("Agent streaming"):
                 return
@@ -261,39 +264,23 @@ class Agent(LLM, ToolBase):
         if use_tools:
             self.set_output("use_tools", use_tools)

-    def _gen_citations(self, text):
-        retrievals = self._canvas.get_reference()
-        retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
-        formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
-        for delta_ans in self._generate_streamly([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))},
-                                                  {"role": "user", "content": text}
-                                                  ]):
-            yield delta_ans
-
-    def _react_with_tools_streamly(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
+    async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
         token_count = 0
         tool_metas = self.tool_meta
         hist = deepcopy(history)
         last_calling = ""

         if len(hist) > 3:
             st = timer()
-            user_request = full_question(messages=history, chat_mdl=self.chat_mdl)
+            user_request = await asyncio.to_thread(full_question, messages=history, chat_mdl=self.chat_mdl)
             self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer()-st)
         else:
             user_request = history[-1]["content"]

-        def use_tool(name, args):
-            nonlocal hist, use_tools, token_count,last_calling,user_request
+        async def use_tool_async(name, args):
+            nonlocal hist, use_tools, last_calling
             logging.info(f"{last_calling=} == {name=}")
-            # Summarize of function calling
-            #if all([
-            #    isinstance(self.toolcall_session.get_tool_obj(name), Agent),
-            #    last_calling,
-            #    last_calling != name
-            #]):
-            #    self.toolcall_session.get_tool_obj(name).add2system_prompt(f"The chat history with other agents are as following: \n" + self.get_useful_memory(user_request, str(args["user_prompt"]),user_defined_prompt))
             last_calling = name
-            tool_response = self.toolcall_session.tool_call(name, args)
+            tool_response = await self.toolcall_session.tool_call_async(name, args)
             use_tools.append({
                 "name": name,
                 "arguments": args,
@@ -304,7 +291,7 @@ class Agent(LLM, ToolBase):
             return name, tool_response

-        def complete():
+        async def complete():
             nonlocal hist
             need2cite = self._param.cite and self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0
             if schema_prompt:
@@ -322,7 +309,7 @@ class Agent(LLM, ToolBase):
             if len(hist) > 12:
                 _hist = [hist[0], hist[1], *hist[-10:]]
             entire_txt = ""
-            for delta_ans in self._generate_streamly(_hist):
+            async for delta_ans in self._generate_streamly_async(_hist):
                 if not need2cite or cited:
                     yield delta_ans, 0
                 entire_txt += delta_ans
@@ -331,7 +318,7 @@ class Agent(LLM, ToolBase):
             st = timer()
             txt = ""
-            for delta_ans in self._gen_citations(entire_txt):
+            async for delta_ans in self._gen_citations_async(entire_txt):
                 if self.check_if_canceled("Agent streaming"):
                     return
                 yield delta_ans, 0
@@ -346,14 +333,14 @@ class Agent(LLM, ToolBase):
             hist.append({"role": "user", "content": content})

         st = timer()
-        task_desc = analyze_task(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
+        task_desc = await analyze_task_async(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
         self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st)
         for _ in range(self._param.max_rounds + 1):
             if self.check_if_canceled("Agent streaming"):
                 return
-            response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
+            response, tk = await next_step_async(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
             # self.callback("next_step", {}, str(response)[:256]+"...")
-            token_count += tk
+            token_count += tk or 0
             hist.append({"role": "assistant", "content": response})
             try:
                 functions = json_repair.loads(re.sub(r"```.*", "", response))
@@ -362,21 +349,22 @@ class Agent(LLM, ToolBase):
                 for f in functions:
                     if not isinstance(f, dict):
                         raise TypeError(f"An object type should be returned, but `{f}`")
-                with ThreadPoolExecutor(max_workers=5) as executor:
-                    thr = []
+                tool_tasks = []
                 for func in functions:
                     name = func["name"]
                     args = func["arguments"]
                     if name == COMPLETE_TASK:
                         append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be as the same as the first user request.\n")
-                        for txt, tkcnt in complete():
+                        async for txt, tkcnt in complete():
                             yield txt, tkcnt
                         return
-                    thr.append(executor.submit(use_tool, name, args))
+                    tool_tasks.append(asyncio.create_task(use_tool_async(name, args)))
+                results = await asyncio.gather(*tool_tasks) if tool_tasks else []

                 st = timer()
-                reflection = reflect(self.chat_mdl, hist, [th.result() for th in thr], user_defined_prompt)
+                reflection = await reflect_async(self.chat_mdl, hist, results, user_defined_prompt)
                 append_user_content(hist, reflection)
                 self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)
@@ -402,21 +390,17 @@ Respond immediately with your final comprehensive answer.
                 return
             append_user_content(hist, final_instruction)

-        for txt, tkcnt in complete():
+        async for txt, tkcnt in complete():
             yield txt, tkcnt

-    def get_useful_memory(self, goal: str, sub_goal:str, topn=3, user_defined_prompt:dict={}) -> str:
-        # self.callback("get_useful_memory", {"topn": 3}, "...")
-        mems = self._canvas.get_memory()
-        rank = rank_memories(self.chat_mdl, goal, sub_goal, [summ for (user, assist, summ) in mems], user_defined_prompt)
-        try:
-            rank = json_repair.loads(re.sub(r"```.*", "", rank))[:topn]
-            mems = [mems[r] for r in rank]
-            return "\n\n".join([f"User: {u}\nAgent: {a}" for u, a,_ in mems])
-        except Exception as e:
-            logging.exception(e)
-            return "Error occurred."
+    async def _gen_citations_async(self, text):
+        retrievals = self._canvas.get_reference()
+        retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
+        formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
+        async for delta_ans in self._generate_streamly_async([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))},
+                                                              {"role": "user", "content": text}
+                                                              ]):
+            yield delta_ans

     def reset(self, only_output=False):
         """
@@ -433,4 +417,3 @@ Respond immediately with your final comprehensive answer.
         for k in self._param.inputs.keys():
             self._param.inputs[k]["value"] = None
         self._param.debug_inputs = {}
-

View File

@@ -14,6 +14,7 @@
 # limitations under the License.
 #
+import asyncio
 import re
 import time
 from abc import ABC
@@ -445,6 +446,34 @@ class ComponentBase(ABC):
             self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
         return self.output()

+    async def invoke_async(self, **kwargs) -> dict[str, Any]:
+        """
+        Async wrapper for component invocation.
+        Prefers coroutine `_invoke_async` if present; otherwise falls back to `_invoke`.
+        Handles timing and error recording consistently with `invoke`.
+        """
+        self.set_output("_created_time", time.perf_counter())
+        try:
+            if self.check_if_canceled("Component processing"):
+                return
+
+            fn_async = getattr(self, "_invoke_async", None)
+            if fn_async and asyncio.iscoroutinefunction(fn_async):
+                await fn_async(**kwargs)
+            elif asyncio.iscoroutinefunction(self._invoke):
+                await self._invoke(**kwargs)
+            else:
+                await asyncio.to_thread(self._invoke, **kwargs)
+        except Exception as e:
+            if self.get_exception_default_value():
+                self.set_exception_default_value()
+            else:
+                self.set_output("_ERROR", str(e))
+            logging.exception(e)
+        self._param.debug_inputs = {}
+        self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
+        return self.output()
+
     @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
     def _invoke(self, **kwargs):
         raise NotImplementedError()

View File

@@ -18,6 +18,7 @@ import re
 from functools import partial

 from agent.component.base import ComponentParamBase, ComponentBase
+from api.db.services.file_service import FileService


 class UserFillUpParam(ComponentParamBase):
@@ -63,6 +64,13 @@ class UserFillUp(ComponentBase):
         for k, v in kwargs.get("inputs", {}).items():
             if self.check_if_canceled("UserFillUp processing"):
                 return
+            if isinstance(v, dict) and v.get("type", "").lower().find("file") >= 0:
+                if v.get("optional") and v.get("value", None) is None:
+                    v = None
+                else:
+                    v = FileService.get_files([v["value"]])
+            else:
+                v = v.get("value")
             self.set_output(k, v)

     def thoughts(self) -> str:

View File

@ -13,12 +13,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import asyncio
import json import json
import logging import logging
import os import os
import re import re
import threading
from copy import deepcopy from copy import deepcopy
from typing import Any, Generator from typing import Any, Generator, AsyncGenerator
import json_repair import json_repair
from functools import partial from functools import partial
from common.constants import LLMType from common.constants import LLMType
@@ -171,6 +173,13 @@ class LLM(ComponentBase):
             return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
         return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)

+    async def _generate_async(self, msg: list[dict], **kwargs) -> str:
+        if not self.imgs and hasattr(self.chat_mdl, "async_chat"):
+            return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
+        if self.imgs and hasattr(self.chat_mdl, "async_chat"):
+            return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
+        return await asyncio.to_thread(self._generate, msg, **kwargs)
+
     def _generate_streamly(self, msg:list[dict], **kwargs) -> Generator[str, None, None]:
         ans = ""
         last_idx = 0
@@ -205,6 +214,69 @@ class LLM(ComponentBase):
             for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
                 yield delta(txt)

+    async def _generate_streamly_async(self, msg: list[dict], **kwargs) -> AsyncGenerator[str, None]:
+        async def delta_wrapper(txt_iter):
+            ans = ""
+            last_idx = 0
+            endswith_think = False
+
+            def delta(txt):
+                nonlocal ans, last_idx, endswith_think
+                delta_ans = txt[last_idx:]
+                ans = txt
+                if delta_ans.find("<think>") == 0:
+                    last_idx += len("<think>")
+                    return "<think>"
+                elif delta_ans.find("<think>") > 0:
+                    delta_ans = txt[last_idx:last_idx + delta_ans.find("<think>")]
+                    last_idx += delta_ans.find("<think>")
+                    return delta_ans
+                elif delta_ans.endswith("</think>"):
+                    endswith_think = True
+                elif endswith_think:
+                    endswith_think = False
+                    return "</think>"
+                last_idx = len(ans)
+                if ans.endswith("</think>"):
+                    last_idx -= len("</think>")
+                return re.sub(r"(<think>|</think>)", "", delta_ans)
+
+            async for t in txt_iter:
+                yield delta(t)
+
+        if not self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"):
+            async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)):
+                yield t
+            return
+        if self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"):
+            async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)):
+                yield t
+            return
+
+        # fallback
+        loop = asyncio.get_running_loop()
+        queue: asyncio.Queue = asyncio.Queue()
+
+        def worker():
+            try:
+                for item in self._generate_streamly(msg, **kwargs):
+                    loop.call_soon_threadsafe(queue.put_nowait, item)
+            except Exception as e:
+                loop.call_soon_threadsafe(queue.put_nowait, e)
+            finally:
+                loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
+
+        threading.Thread(target=worker, daemon=True).start()
+        while True:
+            item = await queue.get()
+            if item is StopAsyncIteration:
+                break
+            if isinstance(item, Exception):
+                raise item
+            yield item
+
     async def _stream_output_async(self, prompt, msg):
         _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
         answer = ""
@@ -255,7 +327,7 @@ class LLM(ComponentBase):
         self.set_output("content", answer)

     @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
-    def _invoke(self, **kwargs):
+    async def _invoke_async(self, **kwargs):
         if self.check_if_canceled("LLM processing"):
             return
@@ -268,20 +340,23 @@ class LLM(ComponentBase):
         error: str = ""
         output_structure = None
         try:
-            output_structure = self._param.outputs['structured']
+            output_structure = self._param.outputs["structured"]
         except Exception:
             pass
         if output_structure and isinstance(output_structure, dict) and output_structure.get("properties") and len(output_structure["properties"]) > 0:
             schema = json.dumps(output_structure, ensure_ascii=False, indent=2)
-            prompt += structured_output_prompt(schema)
+            prompt_with_schema = prompt + structured_output_prompt(schema)
             for _ in range(self._param.max_retries + 1):
                 if self.check_if_canceled("LLM processing"):
                     return
-                _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
+                _, msg_fit = message_fit_in(
+                    [{"role": "system", "content": prompt_with_schema}, *deepcopy(msg)],
+                    int(self.chat_mdl.max_length * 0.97),
+                )
                 error = ""
-                ans = self._generate(msg)
-                msg.pop(0)
+                ans = await self._generate_async(msg_fit)
+                msg_fit.pop(0)
                 if ans.find("**ERROR**") >= 0:
                     logging.error(f"LLM response error: {ans}")
                     error = ans
@@ -290,7 +365,7 @@ class LLM(ComponentBase):
                     self.set_output("structured", json_repair.loads(clean_formated_answer(ans)))
                     return
                 except Exception:
-                    msg.append({"role": "user", "content": "The answer can't not be parsed as JSON"})
+                    msg_fit.append({"role": "user", "content": "The answer can't not be parsed as JSON"})
                     error = "The answer can't not be parsed as JSON"
             if error:
                 self.set_output("_ERROR", error)
@@ -298,18 +373,23 @@ class LLM(ComponentBase):
         downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
         ex = self.exception_handler()
-        if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]):
-            self.set_output("content", partial(self._stream_output_async, prompt, msg))
+        if any([self._canvas.get_component_obj(cid).component_name.lower() == "message" for cid in downstreams]) and not (
+            ex and ex["goto"]
+        ):
+            self.set_output("content", partial(self._stream_output_async, prompt, deepcopy(msg)))
             return

+        error = ""
         for _ in range(self._param.max_retries + 1):
             if self.check_if_canceled("LLM processing"):
                 return
-            _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
+            _, msg_fit = message_fit_in(
+                [{"role": "system", "content": prompt}, *deepcopy(msg)], int(self.chat_mdl.max_length * 0.97)
+            )
             error = ""
-            ans = self._generate(msg)
-            msg.pop(0)
+            ans = await self._generate_async(msg_fit)
+            msg_fit.pop(0)
             if ans.find("**ERROR**") >= 0:
                 logging.error(f"LLM response error: {ans}")
                 error = ans
@@ -323,23 +403,9 @@ class LLM(ComponentBase):
             else:
                 self.set_output("_ERROR", error)

-    def _stream_output(self, prompt, msg):
-        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
-        answer = ""
-        for ans in self._generate_streamly(msg):
-            if self.check_if_canceled("LLM streaming"):
-                return
-            if ans.find("**ERROR**") >= 0:
-                if self.get_exception_default_value():
-                    self.set_output("content", self.get_exception_default_value())
-                    yield self.get_exception_default_value()
-                else:
-                    self.set_output("_ERROR", ans)
-                return
-            yield ans
-            answer += ans
-        self.set_output("content", answer)
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
+    def _invoke(self, **kwargs):
+        return asyncio.run(self._invoke_async(**kwargs))

     def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str, user_defined_prompt:dict={}):
         summ = tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt)
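
The fallback at the end of `_generate_streamly_async` is a general pattern worth isolating: a daemon thread drains a blocking generator into an `asyncio.Queue`, using `StopAsyncIteration` as a completion sentinel and re-raising exceptions on the consumer side. A self-contained sketch, with a sleep standing in for a blocking LLM stream:

```python
# Standalone sketch of the thread-to-async bridge used by _generate_streamly_async.
import asyncio
import threading
import time


def blocking_tokens():
    for t in ["Hel", "lo ", "wor", "ld"]:
        time.sleep(0.05)  # stands in for a blocking chat_streamly call
        yield t


async def aiter_blocking(gen_factory):
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()

    def worker():
        try:
            for item in gen_factory():
                loop.call_soon_threadsafe(queue.put_nowait, item)
        except Exception as e:
            loop.call_soon_threadsafe(queue.put_nowait, e)      # ship the error across
        finally:
            loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)

    threading.Thread(target=worker, daemon=True).start()
    while True:
        item = await queue.get()
        if item is StopAsyncIteration:
            break
        if isinstance(item, Exception):
            raise item                                          # re-raise in the consumer
        yield item


async def main():
    async for tok in aiter_blocking(blocking_tokens):
        print(tok, end="", flush=True)
    print()


asyncio.run(main())
```

`call_soon_threadsafe` is the key detail: `queue.put_nowait` must be scheduled onto the loop's thread, since `asyncio.Queue` itself is not thread-safe.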

View File

@@ -17,6 +17,7 @@ import logging
 import re
 import time
 from copy import deepcopy
+import asyncio
 from functools import partial
 from typing import TypedDict, List, Any
 from agent.component.base import ComponentParamBase, ComponentBase
@@ -48,12 +49,19 @@ class LLMToolPluginCallSession(ToolCallSession):
         self.callback = callback

     def tool_call(self, name: str, arguments: dict[str, Any]) -> Any:
+        return asyncio.run(self.tool_call_async(name, arguments))
+
+    async def tool_call_async(self, name: str, arguments: dict[str, Any]) -> Any:
         assert name in self.tools_map, f"LLM tool {name} does not exist"
         st = timer()
-        if isinstance(self.tools_map[name], MCPToolCallSession):
-            resp = self.tools_map[name].tool_call(name, arguments, 60)
+        tool_obj = self.tools_map[name]
+        if isinstance(tool_obj, MCPToolCallSession):
+            resp = await asyncio.to_thread(tool_obj.tool_call, name, arguments, 60)
         else:
-            resp = self.tools_map[name].invoke(**arguments)
+            if hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async):
+                resp = await tool_obj.invoke_async(**arguments)
+            else:
+                resp = await asyncio.to_thread(tool_obj.invoke, **arguments)
         self.callback(name, arguments, resp, elapsed_time=timer()-st)
         return resp
@@ -139,6 +147,33 @@ class ToolBase(ComponentBase):
         self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
         return res

+    async def invoke_async(self, **kwargs):
+        """
+        Async wrapper for tool invocation.
+        If `_invoke` is a coroutine, await it directly; otherwise run in a thread to avoid blocking.
+        Mirrors the exception handling of `invoke`.
+        """
+        if self.check_if_canceled("Tool processing"):
+            return
+
+        self.set_output("_created_time", time.perf_counter())
+        try:
+            fn_async = getattr(self, "_invoke_async", None)
+            if fn_async and asyncio.iscoroutinefunction(fn_async):
+                res = await fn_async(**kwargs)
+            elif asyncio.iscoroutinefunction(self._invoke):
+                res = await self._invoke(**kwargs)
+            else:
+                res = await asyncio.to_thread(self._invoke, **kwargs)
+        except Exception as e:
+            self._param.outputs["_ERROR"] = {"value": str(e)}
+            logging.exception(e)
+            res = str(e)
+        self._param.debug_inputs = []
+        self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
+        return res
+
     def _retrieve_chunks(self, res_list: list, get_title, get_url, get_content, get_score=None):
         chunks = []
         aggs = []
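
A condensed sketch of the sync/async dispatch this file introduces: the synchronous `tool_call` delegates via `asyncio.run` (so it must only be invoked from non-async code; inside a running event loop `asyncio.run` raises `RuntimeError`), and the coroutine prefers a tool's `invoke_async` when one exists. The tool classes here are illustrative:

```python
# Illustrative sync/async tool dispatch, mirroring LLMToolPluginCallSession above.
import asyncio


class EchoTool:
    def invoke(self, **kwargs):
        return f"sync echo: {kwargs}"


class AsyncEchoTool(EchoTool):
    async def invoke_async(self, **kwargs):
        await asyncio.sleep(0)
        return f"async echo: {kwargs}"


class Session:
    def __init__(self, tools):
        self.tools_map = tools

    def tool_call(self, name, arguments):
        # sync entry point: only safe when no event loop is already running
        return asyncio.run(self.tool_call_async(name, arguments))

    async def tool_call_async(self, name, arguments):
        tool_obj = self.tools_map[name]
        fn = getattr(tool_obj, "invoke_async", None)
        if fn and asyncio.iscoroutinefunction(fn):
            return await fn(**arguments)                         # native async tool
        return await asyncio.to_thread(tool_obj.invoke, **arguments)  # blocking tool


session = Session({"echo": EchoTool(), "aecho": AsyncEchoTool()})
print(session.tool_call("echo", {"x": 1}))   # runs in a worker thread
print(session.tool_call("aecho", {"x": 2}))  # awaits the coroutine directly
```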

View File

@@ -198,6 +198,7 @@ class Retrieval(ToolBase, ABC):
             return
         if cks:
             kbinfos["chunks"] = cks
+        kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"], [kb.tenant_id for kb in kbs])
         if self._param.use_kg:
             ck = settings.kg_retriever.retrieval(query,
                                                  [kb.tenant_id for kb in kbs],

View File

@@ -75,7 +75,7 @@ class YahooFinance(ToolBase, ABC):
     @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
     def _invoke(self, **kwargs):
         if self.check_if_canceled("YahooFinance processing"):
-            return
+            return None

         if not kwargs.get("stock_code"):
             self.set_output("report", "")
@@ -84,33 +84,33 @@ class YahooFinance(ToolBase, ABC):
         last_e = ""
         for _ in range(self._param.max_retries+1):
             if self.check_if_canceled("YahooFinance processing"):
-                return
-            yohoo_res = []
+                return None
+            yahoo_res = []
             try:
                 msft = yf.Ticker(kwargs["stock_code"])
                 if self.check_if_canceled("YahooFinance processing"):
-                    return
+                    return None
                 if self._param.info:
-                    yohoo_res.append("# Information:\n" + pd.Series(msft.info).to_markdown() + "\n")
+                    yahoo_res.append("# Information:\n" + pd.Series(msft.info).to_markdown() + "\n")
                 if self._param.history:
-                    yohoo_res.append("# History:\n" + msft.history().to_markdown() + "\n")
+                    yahoo_res.append("# History:\n" + msft.history().to_markdown() + "\n")
                 if self._param.financials:
-                    yohoo_res.append("# Calendar:\n" + pd.DataFrame(msft.calendar).to_markdown() + "\n")
+                    yahoo_res.append("# Calendar:\n" + pd.DataFrame(msft.calendar).to_markdown() + "\n")
                 if self._param.balance_sheet:
-                    yohoo_res.append("# Balance sheet:\n" + msft.balance_sheet.to_markdown() + "\n")
-                    yohoo_res.append("# Quarterly balance sheet:\n" + msft.quarterly_balance_sheet.to_markdown() + "\n")
+                    yahoo_res.append("# Balance sheet:\n" + msft.balance_sheet.to_markdown() + "\n")
+                    yahoo_res.append("# Quarterly balance sheet:\n" + msft.quarterly_balance_sheet.to_markdown() + "\n")
                 if self._param.cash_flow_statement:
-                    yohoo_res.append("# Cash flow statement:\n" + msft.cashflow.to_markdown() + "\n")
-                    yohoo_res.append("# Quarterly cash flow statement:\n" + msft.quarterly_cashflow.to_markdown() + "\n")
+                    yahoo_res.append("# Cash flow statement:\n" + msft.cashflow.to_markdown() + "\n")
+                    yahoo_res.append("# Quarterly cash flow statement:\n" + msft.quarterly_cashflow.to_markdown() + "\n")
                 if self._param.news:
-                    yohoo_res.append("# News:\n" + pd.DataFrame(msft.news).to_markdown() + "\n")
-                self.set_output("report", "\n\n".join(yohoo_res))
+                    yahoo_res.append("# News:\n" + pd.DataFrame(msft.news).to_markdown() + "\n")
+                self.set_output("report", "\n\n".join(yahoo_res))
                 return self.output("report")
             except Exception as e:
                 if self.check_if_canceled("YahooFinance processing"):
-                    return
+                    return None
                 last_e = e
                 logging.exception(f"YahooFinance error: {e}")

View File

@@ -14,5 +14,5 @@
 # limitations under the License.
 #
-from beartype.claw import beartype_this_package
-beartype_this_package()
+# from beartype.claw import beartype_this_package
+# beartype_this_package()

View File

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import datetime
 import json
 import re
@@ -147,6 +148,7 @@ async def set():
         d["available_int"] = req["available_int"]

     try:
+        def _set_sync():
             tenant_id = DocumentService.get_tenant_id(req["doc_id"])
             if not tenant_id:
                 return get_data_error_result(message="Tenant not found!")
@@ -158,20 +160,23 @@ async def set():
             if not e:
                 return get_data_error_result(message="Document not found!")

+            _d = d
             if doc.parser_id == ParserType.QA:
                 arr = [
                     t for t in re.split(
                         r"[\n\t]",
                         req["content_with_weight"]) if len(t) > 1]
                 q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
-                d = beAdoc(d, q, a, not any(
+                _d = beAdoc(d, q, a, not any(
                     [rag_tokenizer.is_chinese(t) for t in q + a]))

-            v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
+            v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not _d.get("question_kwd") else "\n".join(_d["question_kwd"])])
             v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
-            d["q_%d_vec" % len(v)] = v.tolist()
-            settings.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id)
+            _d["q_%d_vec" % len(v)] = v.tolist()
+            settings.docStoreConn.update({"id": req["chunk_id"]}, _d, search.index_name(tenant_id), doc.kb_id)
             return get_json_result(data=True)
+
+        return await asyncio.to_thread(_set_sync)
     except Exception as e:
         return server_error_response(e)
@@ -182,6 +187,7 @@ async def set():
 async def switch():
     req = await get_request_json()
     try:
+        def _switch_sync():
             e, doc = DocumentService.get_by_id(req["doc_id"])
             if not e:
                 return get_data_error_result(message="Document not found!")

@@ -192,6 +198,8 @@ async def switch():
                     doc.kb_id):
                 return get_data_error_result(message="Index updating failure")
             return get_json_result(data=True)
+
+        return await asyncio.to_thread(_switch_sync)
     except Exception as e:
         return server_error_response(e)
@@ -202,6 +210,7 @@ async def switch():
 async def rm():
     req = await get_request_json()
     try:
+        def _rm_sync():
             e, doc = DocumentService.get_by_id(req["doc_id"])
             if not e:
                 return get_data_error_result(message="Document not found!")

@@ -216,6 +225,8 @@ async def rm():
             if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
                 settings.STORAGE_IMPL.rm(doc.kb_id, cid)
             return get_json_result(data=True)
+
+        return await asyncio.to_thread(_rm_sync)
     except Exception as e:
         return server_error_response(e)
@@ -245,6 +256,7 @@ async def create():
         d["tag_feas"] = req["tag_feas"]

     try:
+        def _create_sync():
             e, doc = DocumentService.get_by_id(req["doc_id"])
             if not e:
                 return get_data_error_result(message="Document not found!")

@@ -274,6 +286,8 @@ async def create():
             DocumentService.increment_chunk_num(
                 doc.id, doc.kb_id, c, 1, 0)
             return get_json_result(data={"chunk_id": chunck_id})
+
+        return await asyncio.to_thread(_create_sync)
     except Exception as e:
         return server_error_response(e)
@@ -297,6 +311,10 @@ async def retrieval_test():
     use_kg = req.get("use_kg", False)
     top = int(req.get("top_k", 1024))
     langs = req.get("cross_languages", [])
+    user_id = current_user.id
+
+    def _retrieval_sync():
+        local_doc_ids = list(doc_ids) if doc_ids else []
         tenant_ids = []
         if req.get("search_id", ""):
@@ -304,18 +322,17 @@ async def retrieval_test():
             meta_data_filter = search_config.get("meta_data_filter", {})
             metas = DocumentService.get_meta_by_kbs(kb_ids)
             if meta_data_filter.get("method") == "auto":
-                chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
+                chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
                 filters: dict = gen_meta_filter(chat_mdl, metas, question)
-                doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
-                if not doc_ids:
-                    doc_ids = None
+                local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
+                if not local_doc_ids:
+                    local_doc_ids = None
             elif meta_data_filter.get("method") == "manual":
-                doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
-                if meta_data_filter["manual"] and not doc_ids:
-                    doc_ids = ["-999"]
-    try:
-        tenants = UserTenantService.query(user_id=current_user.id)
+                local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
+                if meta_data_filter["manual"] and not local_doc_ids:
+                    local_doc_ids = ["-999"]
+
+        tenants = UserTenantService.query(user_id=user_id)
         for kb_id in kb_ids:
             for tenant in tenants:
                 if KnowledgebaseService.query(
@@ -331,8 +348,9 @@ async def retrieval_test():
         if not e:
             return get_data_error_result(message="Knowledgebase not found!")

+        _question = question
         if langs:
-            question = cross_languages(kb.tenant_id, None, question, langs)
+            _question = cross_languages(kb.tenant_id, None, _question, langs)

         embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
@@ -342,19 +360,19 @@ async def retrieval_test():
         if req.get("keyword", False):
             chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
-            question += keyword_extraction(chat_mdl, question)
+            _question += keyword_extraction(chat_mdl, _question)

-        labels = label_question(question, [kb])
-        ranks = settings.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
+        labels = label_question(_question, [kb])
+        ranks = settings.retriever.retrieval(_question, embd_mdl, tenant_ids, kb_ids, page, size,
                                              float(req.get("similarity_threshold", 0.0)),
                                              float(req.get("vector_similarity_weight", 0.3)),
                                              top,
-                                             doc_ids, rerank_mdl=rerank_mdl,
+                                             local_doc_ids, rerank_mdl=rerank_mdl,
                                              highlight=req.get("highlight", False),
                                              rank_feature=labels
                                              )
         if use_kg:
-            ck = settings.kg_retriever.retrieval(question,
+            ck = settings.kg_retriever.retrieval(_question,
                                                  tenant_ids,
                                                  kb_ids,
                                                  embd_mdl,
@@ -367,6 +385,9 @@ async def retrieval_test():
         ranks["labels"] = labels

         return get_json_result(data=ranks)
+
+    try:
+        return await asyncio.to_thread(_retrieval_sync)
     except Exception as e:
         if str(e).find("not_found") > 0:
             return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
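
All of the handlers touched above follow the same shape: the blocking body is captured in a `_xxx_sync` closure and pushed onto the default thread pool with `asyncio.to_thread`, so the event loop stays responsive while ORM and search-engine calls block. A minimal sketch with illustrative names:

```python
# Minimal sketch of the "wrap the blocking body in a closure" pattern used here.
import asyncio
import time


def slow_db_lookup(chunk_id: str) -> dict:
    time.sleep(0.2)  # stands in for Peewee / doc-store calls
    return {"chunk_id": chunk_id, "ok": True}


async def set_chunk(req: dict):
    def _set_sync():
        # everything in here may block; it runs off the event loop
        return slow_db_lookup(req["chunk_id"])

    try:
        return await asyncio.to_thread(_set_sync)
    except Exception as e:
        return {"error": str(e)}


print(asyncio.run(set_chunk({"chunk_id": "c-42"})))
```

Because the closure captures `req` (and, in the real handlers, `current_user.id` is copied out beforehand), no request-context access happens on the worker thread.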

View File

@@ -168,10 +168,12 @@ async def _render_web_oauth_popup(flow_id: str, success: bool, message: str, sou
     status = "success" if success else "error"
     auto_close = "window.close();" if success else ""
     escaped_message = escape(message)
+    # Drive: ragflow-google-drive-oauth
+    # Gmail: ragflow-gmail-oauth
+    payload_type = f"ragflow-{source}-oauth"
     payload_json = json.dumps(
         {
-            # TODO(google-oauth): include connector type (drive/gmail) in payload type if needed
-            "type": f"ragflow-google-{source}-oauth",
+            "type": payload_type,
             "status": status,
             "flowId": flow_id or "",
             "message": message,

View File

@@ -23,7 +23,7 @@ from quart import Response, request
 from api.apps import current_user, login_required
 from api.db.db_models import APIToken
 from api.db.services.conversation_service import ConversationService, structure_answer
-from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap
+from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap
 from api.db.services.llm_service import LLMBundle
 from api.db.services.search_service import SearchService
 from api.db.services.tenant_llm_service import TenantLLMService
@@ -218,10 +218,10 @@ async def completion():
         dia.llm_setting = chat_model_config
     is_embedded = bool(chat_model_id)

-    def stream():
+    async def stream():
         nonlocal dia, msg, req, conv
         try:
-            for ans in chat(dia, msg, True, **req):
+            async for ans in async_chat(dia, msg, True, **req):
                 ans = structure_answer(conv, ans, message_id, conv.id)
                 yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
             if not is_embedded:
@@ -241,7 +241,7 @@ async def completion():
     else:
         answer = None
-        for ans in chat(dia, msg, **req):
+        async for ans in async_chat(dia, msg, **req):
             answer = structure_answer(conv, ans, message_id, conv.id)
             if not is_embedded:
                 ConversationService.update_by_id(conv.id, conv.to_dict())
@@ -406,10 +406,10 @@ async def ask_about():
     if search_app:
         search_config = search_app.get("search_config", {})

-    def stream():
+    async def stream():
         nonlocal req, uid
         try:
-            for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
+            async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config):
                 yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
         except Exception as e:
             yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
@@ -462,7 +462,7 @@ async def related_questions():
     if "parameter" in gen_conf:
         del gen_conf["parameter"]
     prompt = load_prompt("related_question")
-    ans = chat_mdl.chat(
+    ans = await chat_mdl.async_chat(
         prompt,
         [
             {
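
The streaming handlers above now yield from an async generator rather than a sync one. A hedged sketch of that SSE shape, with a stub standing in for `async_chat` and the Quart response wiring omitted:

```python
# Sketch of the async SSE generator shape; fake_async_chat stands in for async_chat.
import asyncio
import json


async def fake_async_chat(question: str):
    for tok in ["An", "swer ", "to: ", question]:
        await asyncio.sleep(0.05)
        yield {"answer": tok}


async def stream(question: str):
    try:
        async for ans in fake_async_chat(question):
            yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
    except Exception as e:
        yield "data:" + json.dumps({"code": 500, "message": str(e)}, ensure_ascii=False) + "\n\n"
    yield "data:" + json.dumps({"code": 0, "data": True}, ensure_ascii=False) + "\n\n"


async def main():
    async for chunk in stream("hello"):
        print(chunk, end="")


asyncio.run(main())
```

The practical gain is that token chunks are emitted as soon as the model produces them, instead of a worker thread handing completed strings back to the loop.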

View File

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 #
+import asyncio
 import json
 import os.path
 import pathlib
@@ -72,7 +73,7 @@ async def upload():
     if not check_kb_team_permission(kb, current_user.id):
         return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)

-    err, files = FileService.upload_document(kb, file_objs, current_user.id)
+    err, files = await asyncio.to_thread(FileService.upload_document, kb, file_objs, current_user.id)
     if err:
         return get_json_result(data=files, message="\n".join(err), code=RetCode.SERVER_ERROR)
@@ -390,7 +391,7 @@ async def rm():
     if not DocumentService.accessible4deletion(doc_id, current_user.id):
         return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)

-    errors = FileService.delete_docs(doc_ids, current_user.id)
+    errors = await asyncio.to_thread(FileService.delete_docs, doc_ids, current_user.id)
     if errors:
         return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
@@ -403,10 +404,12 @@ async def rm():
 @validate_request("doc_ids", "run")
 async def run():
     req = await get_request_json()
-    try:
+
+    def _run_sync():
         for doc_id in req["doc_ids"]:
             if not DocumentService.accessible(doc_id, current_user.id):
                 return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
+        try:
             kb_table_num_map = {}
             for id in req["doc_ids"]:
                 info = {"run": str(req["run"]), "progress": 0}

@@ -437,10 +440,12 @@ async def run():
                 settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)

                 if str(req["run"]) == TaskStatus.RUNNING.value:
-                    doc = doc.to_dict()
-                    DocumentService.run(tenant_id, doc, kb_table_num_map)
+                    doc_dict = doc.to_dict()
+                    DocumentService.run(tenant_id, doc_dict, kb_table_num_map)

             return get_json_result(data=True)
         except Exception as e:
             return server_error_response(e)
+
+    return await asyncio.to_thread(_run_sync)
@@ -450,9 +455,11 @@ async def run():
 @validate_request("doc_id", "name")
 async def rename():
     req = await get_request_json()
-    try:
+
+    def _rename_sync():
         if not DocumentService.accessible(req["doc_id"], current_user.id):
             return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
+        try:
             e, doc = DocumentService.get_by_id(req["doc_id"])
             if not e:
                 return get_data_error_result(message="Document not found!")

@@ -487,8 +494,10 @@ async def rename():
                 search.index_name(tenant_id),
                 doc.kb_id,
             )

             return get_json_result(data=True)
         except Exception as e:
             return server_error_response(e)
+
+    return await asyncio.to_thread(_rename_sync)
@@ -502,7 +511,8 @@ async def get(doc_id):
             return get_data_error_result(message="Document not found!")

         b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
-        response = await make_response(settings.STORAGE_IMPL.get(b, n))
+        data = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
+        response = await make_response(data)

         ext = re.search(r"\.([^.]+)$", doc.name.lower())
         ext = ext.group(1) if ext else None
@@ -523,8 +533,7 @@ async def download_attachment(attachment_id):
 async def download_attachment(attachment_id):
     try:
         ext = request.args.get("ext", "markdown")
-        data = settings.STORAGE_IMPL.get(current_user.id, attachment_id)
-        # data = settings.STORAGE_IMPL.get("eb500d50bb0411f0907561d2782adda5", attachment_id)
+        data = await asyncio.to_thread(settings.STORAGE_IMPL.get, current_user.id, attachment_id)
         response = await make_response(data)

         response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
@@ -596,7 +605,8 @@ async def get_image(image_id):
         if len(arr) != 2:
             return get_data_error_result(message="Image not found.")
         bkt, nm = image_id.split("-")
-        response = await make_response(settings.STORAGE_IMPL.get(bkt, nm))
+        data = await asyncio.to_thread(settings.STORAGE_IMPL.get, bkt, nm)
+        response = await make_response(data)
         response.headers.set("Content-Type", "image/JPEG")
         return response
     except Exception as e:

api/apps/evaluation_app.py (new file, 479 lines)
View File

@@ -0,0 +1,479 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
RAG Evaluation API Endpoints
Provides REST API for RAG evaluation functionality including:
- Dataset management
- Test case management
- Evaluation execution
- Results retrieval
- Configuration recommendations
"""
from quart import request
from api.apps import login_required, current_user
from api.db.services.evaluation_service import EvaluationService
from api.utils.api_utils import (
get_data_error_result,
get_json_result,
get_request_json,
server_error_response,
validate_request
)
from common.constants import RetCode
# ==================== Dataset Management ====================
@manager.route('/dataset/create', methods=['POST']) # noqa: F821
@login_required
@validate_request("name", "kb_ids")
async def create_dataset():
"""
Create a new evaluation dataset.
Request body:
{
"name": "Dataset name",
"description": "Optional description",
"kb_ids": ["kb_id1", "kb_id2"]
}
"""
try:
req = await get_request_json()
name = req.get("name", "").strip()
description = req.get("description", "")
kb_ids = req.get("kb_ids", [])
if not name:
return get_data_error_result(message="Dataset name cannot be empty")
if not kb_ids or not isinstance(kb_ids, list):
return get_data_error_result(message="kb_ids must be a non-empty list")
success, result = EvaluationService.create_dataset(
name=name,
description=description,
kb_ids=kb_ids,
tenant_id=current_user.id,
user_id=current_user.id
)
if not success:
return get_data_error_result(message=result)
return get_json_result(data={"dataset_id": result})
except Exception as e:
return server_error_response(e)
@manager.route('/dataset/list', methods=['GET']) # noqa: F821
@login_required
async def list_datasets():
"""
List evaluation datasets for current tenant.
Query params:
- page: Page number (default: 1)
- page_size: Items per page (default: 20)
"""
try:
page = int(request.args.get("page", 1))
page_size = int(request.args.get("page_size", 20))
result = EvaluationService.list_datasets(
tenant_id=current_user.id,
user_id=current_user.id,
page=page,
page_size=page_size
)
return get_json_result(data=result)
except Exception as e:
return server_error_response(e)
@manager.route('/dataset/<dataset_id>', methods=['GET']) # noqa: F821
@login_required
async def get_dataset(dataset_id):
"""Get dataset details by ID"""
try:
dataset = EvaluationService.get_dataset(dataset_id)
if not dataset:
return get_data_error_result(
message="Dataset not found",
code=RetCode.DATA_ERROR
)
return get_json_result(data=dataset)
except Exception as e:
return server_error_response(e)
@manager.route('/dataset/<dataset_id>', methods=['PUT']) # noqa: F821
@login_required
async def update_dataset(dataset_id):
"""
Update dataset.
Request body:
{
"name": "New name",
"description": "New description",
"kb_ids": ["kb_id1", "kb_id2"]
}
"""
try:
req = await get_request_json()
# Remove fields that shouldn't be updated
req.pop("id", None)
req.pop("tenant_id", None)
req.pop("created_by", None)
req.pop("create_time", None)
success = EvaluationService.update_dataset(dataset_id, **req)
if not success:
return get_data_error_result(message="Failed to update dataset")
return get_json_result(data={"dataset_id": dataset_id})
except Exception as e:
return server_error_response(e)
@manager.route('/dataset/<dataset_id>', methods=['DELETE']) # noqa: F821
@login_required
async def delete_dataset(dataset_id):
"""Delete dataset (soft delete)"""
try:
success = EvaluationService.delete_dataset(dataset_id)
if not success:
return get_data_error_result(message="Failed to delete dataset")
return get_json_result(data={"dataset_id": dataset_id})
except Exception as e:
return server_error_response(e)
# ==================== Test Case Management ====================
@manager.route('/dataset/<dataset_id>/case/add', methods=['POST']) # noqa: F821
@login_required
@validate_request("question")
async def add_test_case(dataset_id):
"""
Add a test case to a dataset.
Request body:
{
"question": "Test question",
"reference_answer": "Optional ground truth answer",
"relevant_doc_ids": ["doc_id1", "doc_id2"],
"relevant_chunk_ids": ["chunk_id1", "chunk_id2"],
"metadata": {"key": "value"}
}
"""
try:
req = await get_request_json()
question = req.get("question", "").strip()
if not question:
return get_data_error_result(message="Question cannot be empty")
success, result = EvaluationService.add_test_case(
dataset_id=dataset_id,
question=question,
reference_answer=req.get("reference_answer"),
relevant_doc_ids=req.get("relevant_doc_ids"),
relevant_chunk_ids=req.get("relevant_chunk_ids"),
metadata=req.get("metadata")
)
if not success:
return get_data_error_result(message=result)
return get_json_result(data={"case_id": result})
except Exception as e:
return server_error_response(e)
@manager.route('/dataset/<dataset_id>/case/import', methods=['POST']) # noqa: F821
@login_required
@validate_request("cases")
async def import_test_cases(dataset_id):
"""
Bulk import test cases.
Request body:
{
"cases": [
{
"question": "Question 1",
"reference_answer": "Answer 1",
...
},
{
"question": "Question 2",
...
}
]
}
"""
try:
req = await get_request_json()
cases = req.get("cases", [])
if not cases or not isinstance(cases, list):
return get_data_error_result(message="cases must be a non-empty list")
success_count, failure_count = EvaluationService.import_test_cases(
dataset_id=dataset_id,
cases=cases
)
return get_json_result(data={
"success_count": success_count,
"failure_count": failure_count,
"total": len(cases)
})
except Exception as e:
return server_error_response(e)
@manager.route('/dataset/<dataset_id>/cases', methods=['GET']) # noqa: F821
@login_required
async def get_test_cases(dataset_id):
"""Get all test cases for a dataset"""
try:
cases = EvaluationService.get_test_cases(dataset_id)
return get_json_result(data={"cases": cases, "total": len(cases)})
except Exception as e:
return server_error_response(e)
@manager.route('/case/<case_id>', methods=['DELETE']) # noqa: F821
@login_required
async def delete_test_case(case_id):
"""Delete a test case"""
try:
success = EvaluationService.delete_test_case(case_id)
if not success:
return get_data_error_result(message="Failed to delete test case")
return get_json_result(data={"case_id": case_id})
except Exception as e:
return server_error_response(e)
# ==================== Evaluation Execution ====================
@manager.route('/run/start', methods=['POST']) # noqa: F821
@login_required
@validate_request("dataset_id", "dialog_id")
async def start_evaluation():
"""
Start an evaluation run.
Request body:
{
"dataset_id": "dataset_id",
"dialog_id": "dialog_id",
"name": "Optional run name"
}
"""
try:
req = await get_request_json()
dataset_id = req.get("dataset_id")
dialog_id = req.get("dialog_id")
name = req.get("name")
success, result = EvaluationService.start_evaluation(
dataset_id=dataset_id,
dialog_id=dialog_id,
user_id=current_user.id,
name=name
)
if not success:
return get_data_error_result(message=result)
return get_json_result(data={"run_id": result})
except Exception as e:
return server_error_response(e)
@manager.route('/run/<run_id>', methods=['GET']) # noqa: F821
@login_required
async def get_evaluation_run(run_id):
"""Get evaluation run details"""
try:
result = EvaluationService.get_run_results(run_id)
if not result:
return get_data_error_result(
message="Evaluation run not found",
code=RetCode.DATA_ERROR
)
return get_json_result(data=result)
except Exception as e:
return server_error_response(e)
@manager.route('/run/<run_id>/results', methods=['GET']) # noqa: F821
@login_required
async def get_run_results(run_id):
"""Get detailed results for an evaluation run"""
try:
result = EvaluationService.get_run_results(run_id)
if not result:
return get_data_error_result(
message="Evaluation run not found",
code=RetCode.DATA_ERROR
)
return get_json_result(data=result)
except Exception as e:
return server_error_response(e)
@manager.route('/run/list', methods=['GET']) # noqa: F821
@login_required
async def list_evaluation_runs():
"""
List evaluation runs.
Query params:
- dataset_id: Filter by dataset (optional)
- dialog_id: Filter by dialog (optional)
- page: Page number (default: 1)
- page_size: Items per page (default: 20)
"""
try:
# TODO: Implement list_runs in EvaluationService
return get_json_result(data={"runs": [], "total": 0})
except Exception as e:
return server_error_response(e)
@manager.route('/run/<run_id>', methods=['DELETE']) # noqa: F821
@login_required
async def delete_evaluation_run(run_id):
"""Delete an evaluation run"""
try:
# TODO: Implement delete_run in EvaluationService
return get_json_result(data={"run_id": run_id})
except Exception as e:
return server_error_response(e)
# ==================== Analysis & Recommendations ====================
@manager.route('/run/<run_id>/recommendations', methods=['GET']) # noqa: F821
@login_required
async def get_recommendations(run_id):
"""Get configuration recommendations based on evaluation results"""
try:
recommendations = EvaluationService.get_recommendations(run_id)
return get_json_result(data={"recommendations": recommendations})
except Exception as e:
return server_error_response(e)
@manager.route('/compare', methods=['POST']) # noqa: F821
@login_required
@validate_request("run_ids")
async def compare_runs():
"""
Compare multiple evaluation runs.
Request body:
{
"run_ids": ["run_id1", "run_id2", "run_id3"]
}
"""
try:
req = await get_request_json()
run_ids = req.get("run_ids", [])
if not run_ids or not isinstance(run_ids, list) or len(run_ids) < 2:
return get_data_error_result(
message="run_ids must be a list with at least 2 run IDs"
)
# TODO: Implement compare_runs in EvaluationService
return get_json_result(data={"comparison": {}})
except Exception as e:
return server_error_response(e)
@manager.route('/run/<run_id>/export', methods=['GET']) # noqa: F821
@login_required
async def export_results(run_id):
"""Export evaluation results as JSON/CSV"""
try:
# format_type = request.args.get("format", "json") # TODO: Use for CSV export
result = EvaluationService.get_run_results(run_id)
if not result:
return get_data_error_result(
message="Evaluation run not found",
code=RetCode.DATA_ERROR
)
# TODO: Implement CSV export
return get_json_result(data=result)
except Exception as e:
return server_error_response(e)
# ==================== Real-time Evaluation ====================
@manager.route('/evaluate_single', methods=['POST']) # noqa: F821
@login_required
@validate_request("question", "dialog_id")
async def evaluate_single():
"""
Evaluate a single question-answer pair in real-time.
Request body:
{
"question": "Test question",
"dialog_id": "dialog_id",
"reference_answer": "Optional ground truth",
"relevant_chunk_ids": ["chunk_id1", "chunk_id2"]
}
"""
try:
# req = await get_request_json() # TODO: Use for single evaluation implementation
# TODO: Implement single evaluation
# This would execute the RAG pipeline and return metrics immediately
return get_json_result(data={
"answer": "",
"metrics": {},
"retrieved_chunks": []
})
except Exception as e:
return server_error_response(e)
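
A hypothetical client-side walkthrough of the new endpoints; the base URL, route prefix, and auth header below are assumptions for illustration, not taken from the diff:

```python
# Hypothetical usage of the evaluation API; BASE and HEADERS are assumptions.
import requests

BASE = "http://localhost:9380/v1/evaluation"   # route prefix is an assumption
HEADERS = {"Authorization": "Bearer <your-session-token>"}

# 1. create a dataset bound to one or more knowledge bases
r = requests.post(f"{BASE}/dataset/create", headers=HEADERS,
                  json={"name": "smoke-test", "kb_ids": ["kb-1"]})
dataset_id = r.json()["data"]["dataset_id"]

# 2. bulk-import test cases
requests.post(f"{BASE}/dataset/{dataset_id}/case/import", headers=HEADERS,
              json={"cases": [{"question": "What is RAGFlow?"}]})

# 3. start a run against a dialog, then fetch its results
r = requests.post(f"{BASE}/run/start", headers=HEADERS,
                  json={"dataset_id": dataset_id, "dialog_id": "dlg-1"})
run_id = r.json()["data"]["run_id"]
print(requests.get(f"{BASE}/run/{run_id}/results", headers=HEADERS).json())
```

Note that several endpoints (`/run/list`, `/compare`, CSV export, `/evaluate_single`) are stubs with TODOs in this commit and currently return empty placeholders.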

View File

@@ -14,6 +14,7 @@
 # limitations under the License
 #
 import logging
+import asyncio
 import os
 import pathlib
 import re
@@ -61,9 +62,10 @@ async def upload():
         e, pf_folder = FileService.get_by_id(pf_id)
         if not e:
             return get_data_error_result(message="Can't find this folder!")
-        for file_obj in file_objs:
+
+        async def _handle_single_file(file_obj):
             MAX_FILE_NUM_PER_USER: int = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
-            if 0 < MAX_FILE_NUM_PER_USER <= DocumentService.get_doc_count(current_user.id):
+            if 0 < MAX_FILE_NUM_PER_USER <= await asyncio.to_thread(DocumentService.get_doc_count, current_user.id):
                 return get_data_error_result(message="Exceed the maximum file number of a free user!")

             # split file name path
@@ -75,35 +77,36 @@ async def upload():
             file_len = len(file_obj_names)

             # get folder
-            file_id_list = FileService.get_id_list_by_id(pf_id, file_obj_names, 1, [pf_id])
+            file_id_list = await asyncio.to_thread(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id])
             len_id_list = len(file_id_list)

             # create folder
             if file_len != len_id_list:
-                e, file = FileService.get_by_id(file_id_list[len_id_list - 1])
+                e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 1])
                 if not e:
                     return get_data_error_result(message="Folder not found!")
-                last_folder = FileService.create_folder(file, file_id_list[len_id_list - 1], file_obj_names,
-                                                        len_id_list)
+                last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names,
+                                                      len_id_list)
             else:
-                e, file = FileService.get_by_id(file_id_list[len_id_list - 2])
+                e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 2])
                 if not e:
                     return get_data_error_result(message="Folder not found!")
-                last_folder = FileService.create_folder(file, file_id_list[len_id_list - 2], file_obj_names,
-                                                        len_id_list)
+                last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names,
+                                                      len_id_list)

             # file type
             filetype = filename_type(file_obj_names[file_len - 1])
             location = file_obj_names[file_len - 1]
-            while settings.STORAGE_IMPL.obj_exist(last_folder.id, location):
+            while await asyncio.to_thread(settings.STORAGE_IMPL.obj_exist, last_folder.id, location):
                 location += "_"
-            blob = file_obj.read()
-            filename = duplicate_name(
+            blob = await asyncio.to_thread(file_obj.read)
+            filename = await asyncio.to_thread(
+                duplicate_name,
                 FileService.query,
                 name=file_obj_names[file_len - 1],
                 parent_id=last_folder.id)
-            settings.STORAGE_IMPL.put(last_folder.id, location, blob)
-            file = {
+            await asyncio.to_thread(settings.STORAGE_IMPL.put, last_folder.id, location, blob)
+            file_data = {
                 "id": get_uuid(),
                 "parent_id": last_folder.id,
                 "tenant_id": current_user.id,

@@ -113,8 +116,13 @@ async def upload():
                 "location": location,
                 "size": len(blob),
             }
-            file = FileService.insert(file)
-            file_res.append(file.to_json())
+            inserted = await asyncio.to_thread(FileService.insert, file_data)
+            return inserted.to_json()
+
+        for file_obj in file_objs:
+            res = await _handle_single_file(file_obj)
+            file_res.append(res)
+
         return get_json_result(data=file_res)
     except Exception as e:
         return server_error_response(e)
@@ -242,6 +250,7 @@ async def rm():
     req = await get_request_json()
     file_ids = req["file_ids"]
+    try:
         def _delete_single_file(file):
             try:
                 if file.location:
@@ -271,7 +280,7 @@ async def rm():
                 FileService.delete(folder)

-        try:
+        def _rm_sync():
             for file_id in file_ids:
                 e, file = FileService.get_by_id(file_id)
                 if not e or not file:

@@ -292,6 +301,8 @@ async def rm():
             return get_json_result(data=True)

+        return await asyncio.to_thread(_rm_sync)
     except Exception as e:
         return server_error_response(e)
@@ -346,10 +357,10 @@ async def get(file_id):
         if not check_file_team_permission(file, current_user.id):
             return get_json_result(data=False, message='No authorization.', code=RetCode.AUTHENTICATION_ERROR)

-        blob = settings.STORAGE_IMPL.get(file.parent_id, file.location)
+        blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, file.parent_id, file.location)
         if not blob:
             b, n = File2DocumentService.get_storage_address(file_id=file_id)
-            blob = settings.STORAGE_IMPL.get(b, n)
+            blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
         response = await make_response(blob)

         ext = re.search(r"\.([^.]+)$", file.name.lower())
@@ -444,10 +455,12 @@ async def move():
             },
         )

+        def _move_sync():
             for file in files:
                 _move_entry_recursive(file, dest_folder)

             return get_json_result(data=True)
+
+        return await asyncio.to_thread(_move_sync)
     except Exception as e:
         return server_error_response(e)
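
One detail worth noting in the `duplicate_name` change above: `asyncio.to_thread` forwards both positional and keyword arguments to the callable, which is what lets `name=` and `parent_id=` pass straight through unchanged. A tiny demonstration with stand-in functions:

```python
# asyncio.to_thread(func, *args, **kwargs) forwards kwargs to func in the worker thread.
import asyncio


def duplicate_name(query_fn, name, parent_id):
    # stand-in for the real helper; just proves the kwargs arrived intact
    return f"{name} (checked under {parent_id} via {query_fn.__name__})"


def fake_query(**kwargs):
    return []


async def main():
    out = await asyncio.to_thread(duplicate_name, fake_query,
                                  name="report.pdf", parent_id="folder-1")
    print(out)


asyncio.run(main())
```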

View File

@@ -17,6 +17,7 @@ import json
 import logging
 import random
 import re
+import asyncio

 from quart import request
 import numpy as np
@@ -116,12 +117,22 @@ async def update():
         if kb.pagerank != req.get("pagerank", 0):
             if req.get("pagerank", 0) > 0:
-                settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
-                                             search.index_name(kb.tenant_id), kb.id)
+                await asyncio.to_thread(
+                    settings.docStoreConn.update,
+                    {"kb_id": kb.id},
+                    {PAGERANK_FLD: req["pagerank"]},
+                    search.index_name(kb.tenant_id),
+                    kb.id,
+                )
             else:
                 # Elasticsearch requires PAGERANK_FLD be non-zero!
-                settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
-                                             search.index_name(kb.tenant_id), kb.id)
+                await asyncio.to_thread(
+                    settings.docStoreConn.update,
+                    {"exists": PAGERANK_FLD},
+                    {"remove": PAGERANK_FLD},
+                    search.index_name(kb.tenant_id),
+                    kb.id,
+                )

         e, kb = KnowledgebaseService.get_by_id(kb.id)
         if not e:
@@ -224,6 +235,7 @@ async def rm():
                 data=False, message='Only owner of knowledgebase authorized for this operation.',
                 code=RetCode.OPERATING_ERROR)

+        def _rm_sync():
             for doc in DocumentService.query(kb_id=req["kb_id"]):
                 if not DocumentService.remove_document(doc, kbs[0].tenant_id):
                     return get_data_error_result(

@@ -243,6 +255,8 @@ async def rm():
             if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
                 settings.STORAGE_IMPL.remove_bucket(kb.id)
             return get_json_result(data=True)
+
+        return await asyncio.to_thread(_rm_sync)
     except Exception as e:
         return server_error_response(e)
@@ -922,5 +936,3 @@ async def check_embedding():
     if summary["avg_cos_sim"] > 0.9:
         return get_json_result(data={"summary": summary, "results": results})
     return get_json_result(code=RetCode.NOT_EFFECTIVE, message="Embedding model switch failed: the average similarity between old and new vectors is below 0.9, indicating incompatible vector spaces.", data={"summary": summary, "results": results})

View File

@@ -34,8 +34,9 @@ async def set_api_key():
     if not all([secret_key, public_key, host]):
         return get_error_data_result(message="Missing required fields")

+    current_user_id = current_user.id
     langfuse_keys = dict(
-        tenant_id=current_user.id,
+        tenant_id=current_user_id,
         secret_key=secret_key,
         public_key=public_key,
         host=host,
@@ -45,23 +46,24 @@ async def set_api_key():
     if not langfuse.auth_check():
         return get_error_data_result(message="Invalid Langfuse keys")

-    langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user.id)
+    langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user_id)
     with DB.atomic():
         try:
             if not langfuse_entry:
                 TenantLangfuseService.save(**langfuse_keys)
             else:
-                TenantLangfuseService.update_by_tenant(tenant_id=current_user.id, langfuse_keys=langfuse_keys)
+                TenantLangfuseService.update_by_tenant(tenant_id=current_user_id, langfuse_keys=langfuse_keys)
             return get_json_result(data=langfuse_keys)
         except Exception as e:
-            server_error_response(e)
+            return server_error_response(e)
@manager.route("/api_key", methods=["GET"]) # noqa: F821 @manager.route("/api_key", methods=["GET"]) # noqa: F821
@login_required @login_required
@validate_request() @validate_request()
def get_api_key(): def get_api_key():
langfuse_entry = TenantLangfuseService.filter_by_tenant_with_info(tenant_id=current_user.id) current_user_id = current_user.id
langfuse_entry = TenantLangfuseService.filter_by_tenant_with_info(tenant_id=current_user_id)
if not langfuse_entry: if not langfuse_entry:
return get_json_result(message="Have not record any Langfuse keys.") return get_json_result(message="Have not record any Langfuse keys.")
@ -72,7 +74,7 @@ def get_api_key():
except langfuse.api.core.api_error.ApiError as api_err: except langfuse.api.core.api_error.ApiError as api_err:
return get_json_result(message=f"Error from Langfuse: {api_err}") return get_json_result(message=f"Error from Langfuse: {api_err}")
except Exception as e: except Exception as e:
server_error_response(e) return server_error_response(e)
langfuse_entry["project_id"] = langfuse.api.projects.get().dict()["data"][0]["id"] langfuse_entry["project_id"] = langfuse.api.projects.get().dict()["data"][0]["id"]
langfuse_entry["project_name"] = langfuse.api.projects.get().dict()["data"][0]["name"] langfuse_entry["project_name"] = langfuse.api.projects.get().dict()["data"][0]["name"]
@ -84,7 +86,8 @@ def get_api_key():
@login_required @login_required
@validate_request() @validate_request()
def delete_api_key(): def delete_api_key():
langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user.id) current_user_id = current_user.id
langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user_id)
if not langfuse_entry: if not langfuse_entry:
return get_json_result(message="Have not record any Langfuse keys.") return get_json_result(message="Have not record any Langfuse keys.")
@ -93,4 +96,4 @@ def delete_api_key():
TenantLangfuseService.delete_model(langfuse_entry) TenantLangfuseService.delete_model(langfuse_entry)
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
server_error_response(e) return server_error_response(e)

View File

@@ -74,7 +74,7 @@ async def set_api_key():
     assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
     mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra)
     try:
-        m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50})
+        m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50})
         if m.find("**ERROR**") >= 0:
             raise Exception(m)
         chat_passed = True
@@ -217,7 +217,7 @@ async def add_llm():
         **extra,
     )
     try:
-        m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
+        m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
         if not tc and m.find("**ERROR**:") >= 0:
             raise Exception(m)
     except Exception as e:

View File

@@ -33,7 +33,7 @@ from api.db.services.file_service import FileService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
 from api.db.services.tenant_llm_service import TenantLLMService
-from api.db.services.task_service import TaskService, queue_tasks
+from api.db.services.task_service import TaskService, queue_tasks, cancel_all_task_of
 from api.db.services.dialog_service import meta_filter, convert_conditions
 from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required, \
     get_request_json
@@ -321,9 +321,7 @@ async def update_doc(tenant_id, dataset_id, document_id):
     try:
         if not DocumentService.update_by_id(doc.id, {"status": str(status)}):
             return get_error_data_result(message="Database error (Document update)!")
         settings.docStoreConn.update({"doc_id": doc.id}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
-        return get_result(data=True)
     except Exception as e:
         return server_error_response(e)
@@ -350,12 +348,10 @@ async def update_doc(tenant_id, dataset_id, document_id):
     }
     renamed_doc = {}
     for key, value in doc.to_dict().items():
-        if key == "run":
-            renamed_doc["run"] = run_mapping.get(str(value))
         new_key = key_mapping.get(key, key)
         renamed_doc[new_key] = value
         if key == "run":
-            renamed_doc["run"] = run_mapping.get(value)
+            renamed_doc["run"] = run_mapping.get(str(value))
     return get_result(data=renamed_doc)
@@ -556,7 +552,7 @@ def list_docs(dataset_id, tenant_id):
     create_time_from = int(q.get("create_time_from", 0))
     create_time_to = int(q.get("create_time_to", 0))
-    # map run status (accept text or numeric) - align with API parameter
+    # map run status (text or numeric) - align with API parameter
     run_status_text_to_numeric = {"UNSTART": "0", "RUNNING": "1", "CANCEL": "2", "DONE": "3", "FAIL": "4"}
     run_status_converted = [run_status_text_to_numeric.get(v, v) for v in run_status]
@@ -839,6 +835,8 @@ async def stop_parsing(tenant_id, dataset_id):
             return get_error_data_result(message=f"You don't own the document {id}.")
         if int(doc[0].progress) == 1 or doc[0].progress == 0:
             return get_error_data_result("Can't stop parsing document with progress at 0 or 1")
+        # Send cancellation signal via Redis to stop background task
+        cancel_all_task_of(id)
         info = {"run": "2", "progress": 0, "chunk_num": 0}
         DocumentService.update_by_id(id, info)
         settings.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id)
@@ -892,7 +890,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
         type: string
         required: false
         default: ""
-        description: Chunk Id.
+        description: Chunk id.
       - in: header
         name: Authorization
         type: string

View File

@@ -14,7 +14,7 @@
 # limitations under the License.
 #
+import asyncio
 import pathlib
 import re
 from quart import request, make_response
@@ -29,6 +29,7 @@ from api.db import FileType
 from api.db.services import duplicate_name
 from api.db.services.file_service import FileService
 from api.utils.file_utils import filename_type
+from api.utils.web_utils import CONTENT_TYPE_MAP
 from common import settings
 from common.constants import RetCode
@@ -39,7 +40,7 @@ async def upload(tenant_id):
     Upload a file to the system.
     ---
     tags:
-      - File Management
+      - File
     security:
       - ApiKeyAuth: []
     parameters:
@@ -155,7 +156,7 @@ async def create(tenant_id):
     Create a new file or folder.
     ---
     tags:
-      - File Management
+      - File
     security:
      - ApiKeyAuth: []
     parameters:
@@ -233,7 +234,7 @@ async def list_files(tenant_id):
     List files under a specific folder.
     ---
     tags:
-      - File Management
+      - File
     security:
       - ApiKeyAuth: []
     parameters:
@@ -325,7 +326,7 @@ async def get_root_folder(tenant_id):
     Get user's root folder.
     ---
     tags:
-      - File Management
+      - File
     security:
       - ApiKeyAuth: []
     responses:
@@ -361,7 +362,7 @@ async def get_parent_folder():
     Get parent folder info of a file.
     ---
     tags:
-      - File Management
+      - File
     security:
       - ApiKeyAuth: []
     parameters:
@@ -406,7 +407,7 @@ async def get_all_parent_folders(tenant_id):
     Get all parent folders of a file.
     ---
     tags:
-      - File Management
+      - File
     security:
       - ApiKeyAuth: []
     parameters:
@@ -454,7 +455,7 @@ async def rm(tenant_id):
     Delete one or multiple files/folders.
     ---
     tags:
-      - File Management
+      - File
     security:
       - ApiKeyAuth: []
     parameters:
@@ -528,7 +529,7 @@ async def rename(tenant_id):
     Rename a file.
     ---
     tags:
-      - File Management
+      - File
     security:
       - ApiKeyAuth: []
     parameters:
@@ -589,7 +590,7 @@ async def get(tenant_id, file_id):
     Download a file.
     ---
     tags:
-      - File Management
+      - File
     security:
       - ApiKeyAuth: []
     produces:
@@ -629,6 +630,19 @@ async def get(tenant_id, file_id):
     except Exception as e:
         return server_error_response(e)
+@manager.route("/file/download/<attachment_id>", methods=["GET"])  # noqa: F821
+@token_required
+async def download_attachment(tenant_id, attachment_id):
+    try:
+        ext = request.args.get("ext", "markdown")
+        data = await asyncio.to_thread(settings.STORAGE_IMPL.get, tenant_id, attachment_id)
+        response = await make_response(data)
+        response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
+        return response
+    except Exception as e:
+        return server_error_response(e)
 @manager.route('/file/mv', methods=['POST'])  # noqa: F821
 @token_required
@@ -637,7 +651,7 @@ async def move(tenant_id):
     Move one or multiple files to another folder.
     ---
     tags:
-      - File Management
+      - File
     security:
       - ApiKeyAuth: []
     parameters:

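Reviewer note: the new `download_attachment` route picks a Content-Type from an `ext` query parameter. A small sketch of that lookup-with-fallback; the map below is a hypothetical subset of the real `api.utils.web_utils.CONTENT_TYPE_MAP`:

```python
# Hypothetical subset; the real map lives in api.utils.web_utils.
CONTENT_TYPE_MAP = {
    "markdown": "text/markdown",
    "pdf": "application/pdf",
    "json": "application/json",
}

def content_type_for(ext: str) -> str:
    # Mirrors the route: known extensions map explicitly,
    # everything else falls back to application/<ext>.
    return CONTENT_TYPE_MAP.get(ext, f"application/{ext}")

assert content_type_for("pdf") == "application/pdf"
assert content_type_for("csv") == "application/csv"  # fallback path
```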
View File

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import json
 import re
 import time
@@ -25,9 +26,10 @@ from api.db.db_models import APIToken
 from api.db.services.api_service import API4ConversationService
 from api.db.services.canvas_service import UserCanvasService, completion_openai
 from api.db.services.canvas_service import completion as agent_completion
-from api.db.services.conversation_service import ConversationService, iframe_completion
-from api.db.services.conversation_service import completion as rag_completion
-from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap, meta_filter
+from api.db.services.conversation_service import ConversationService
+from api.db.services.conversation_service import async_iframe_completion as iframe_completion
+from api.db.services.conversation_service import async_completion as rag_completion
+from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap, meta_filter
 from api.db.services.document_service import DocumentService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
@@ -140,7 +142,7 @@ async def chat_completion(tenant_id, chat_id):
         return resp
     else:
         answer = None
-        for ans in rag_completion(tenant_id, chat_id, **req):
+        async for ans in rag_completion(tenant_id, chat_id, **req):
             answer = ans
             break
         return get_result(data=answer)
@@ -244,7 +246,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
     # The value for the usage field on all chunks except for the last one will be null.
     # The usage field on the last chunk contains token usage statistics for the entire request.
     # The choices field on the last chunk will always be an empty array [].
-    def streamed_response_generator(chat_id, dia, msg):
+    async def streamed_response_generator(chat_id, dia, msg):
         token_used = 0
         answer_cache = ""
         reasoning_cache = ""
@@ -273,7 +275,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
         }
         try:
-            for ans in chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
+            async for ans in async_chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
                 last_ans = ans
                 answer = ans["answer"]
@@ -341,7 +343,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
         return resp
     else:
         answer = None
-        for ans in chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
+        async for ans in async_chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
             # focus answer content only
             answer = ans
             break
@@ -732,10 +734,10 @@ async def ask_about(tenant_id):
             return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
     uid = tenant_id
-    def stream():
+    async def stream():
         nonlocal req, uid
         try:
-            for ans in ask(req["question"], req["kb_ids"], uid):
+            async for ans in async_ask(req["question"], req["kb_ids"], uid):
                 yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
         except Exception as e:
             yield "data:" + json.dumps(
@@ -787,7 +789,7 @@ Reason:
 - At the same time, related terms can also help search engines better understand user needs and return more accurate search results.
 """
-    ans = chat_mdl.chat(
+    ans = await chat_mdl.async_chat(
         prompt,
         [
             {
@@ -826,7 +828,7 @@ async def chatbot_completions(dialog_id):
         resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
         return resp
-    for answer in iframe_completion(dialog_id, **req):
+    async for answer in iframe_completion(dialog_id, **req):
         return get_result(data=answer)
@@ -917,10 +919,10 @@ async def ask_about_embedded():
     if search_app := SearchService.get_detail(search_id):
         search_config = search_app.get("search_config", {})
-    def stream():
+    async def stream():
         nonlocal req, uid
         try:
-            for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
+            async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config):
                 yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
         except Exception as e:
             yield "data:" + json.dumps(
@@ -963,28 +965,30 @@ async def retrieval_test_embedded():
     use_kg = req.get("use_kg", False)
     top = int(req.get("top_k", 1024))
     langs = req.get("cross_languages", [])
-    tenant_ids = []
     tenant_id = objs[0].tenant_id
     if not tenant_id:
         return get_error_data_result(message="permission denined.")
+    def _retrieval_sync():
+        local_doc_ids = list(doc_ids) if doc_ids else []
+        tenant_ids = []
+        _question = question
         if req.get("search_id", ""):
             search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
             meta_data_filter = search_config.get("meta_data_filter", {})
             metas = DocumentService.get_meta_by_kbs(kb_ids)
             if meta_data_filter.get("method") == "auto":
                 chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
-                filters: dict = gen_meta_filter(chat_mdl, metas, question)
-                doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
-                if not doc_ids:
-                    doc_ids = None
+                filters: dict = gen_meta_filter(chat_mdl, metas, _question)
+                local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
+                if not local_doc_ids:
+                    local_doc_ids = None
             elif meta_data_filter.get("method") == "manual":
-                doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
-                if meta_data_filter["manual"] and not doc_ids:
-                    doc_ids = ["-999"]
-    try:
+                local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
+                if meta_data_filter["manual"] and not local_doc_ids:
+                    local_doc_ids = ["-999"]
         tenants = UserTenantService.query(user_id=tenant_id)
         for kb_id in kb_ids:
             for tenant in tenants:
@@ -1000,7 +1004,7 @@ async def retrieval_test_embedded():
             return get_error_data_result(message="Knowledgebase not found!")
         if langs:
-            question = cross_languages(kb.tenant_id, None, question, langs)
+            _question = cross_languages(kb.tenant_id, None, _question, langs)
         embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
@@ -1010,15 +1014,15 @@ async def retrieval_test_embedded():
         if req.get("keyword", False):
             chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
-            question += keyword_extraction(chat_mdl, question)
+            _question += keyword_extraction(chat_mdl, _question)
-        labels = label_question(question, [kb])
+        labels = label_question(_question, [kb])
         ranks = settings.retriever.retrieval(
-            question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
-            doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
+            _question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
+            local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
         )
         if use_kg:
-            ck = settings.kg_retriever.retrieval(question, tenant_ids, kb_ids, embd_mdl,
+            ck = settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl,
                                                  LLMBundle(kb.tenant_id, LLMType.CHAT))
             if ck["content_with_weight"]:
                 ranks["chunks"].insert(0, ck)
@@ -1028,6 +1032,9 @@ async def retrieval_test_embedded():
         ranks["labels"] = labels
         return get_json_result(data=ranks)
+    try:
+        return await asyncio.to_thread(_retrieval_sync)
     except Exception as e:
         if str(e).find("not_found") > 0:
             return get_json_result(data=False, message="No chunk found! Check the chunk status please!",

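Reviewer note: most hunks in this file are the same mechanical migration — a sync generator consumed with `for` becomes an async generator consumed with `async for`. A runnable toy version of the pattern (names are illustrative only, not the real service API):

```python
import asyncio

# Illustrative async generator standing in for async_chat/async_ask.
async def fake_async_chat(question: str):
    for token in question.split():
        await asyncio.sleep(0)  # yield control, as a streaming LLM client would
        yield {"answer": token}

async def main():
    answer = None
    # Old style: `for ans in chat(...)`. New style: `async for ans in async_chat(...)`.
    async for ans in fake_async_chat("hello world"):
        answer = ans
        break  # non-streaming callers keep only the first chunk, as above
    print(answer)  # {'answer': 'hello'}

asyncio.run(main())
```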
View File

@@ -1113,6 +1113,70 @@ class SyncLogs(DataBaseModel):
         db_table = "sync_logs"
+class EvaluationDataset(DataBaseModel):
+    """Ground truth dataset for RAG evaluation"""
+    id = CharField(max_length=32, primary_key=True)
+    tenant_id = CharField(max_length=32, null=False, index=True, help_text="tenant ID")
+    name = CharField(max_length=255, null=False, index=True, help_text="dataset name")
+    description = TextField(null=True, help_text="dataset description")
+    kb_ids = JSONField(null=False, help_text="knowledge base IDs to evaluate against")
+    created_by = CharField(max_length=32, null=False, index=True, help_text="creator user ID")
+    create_time = BigIntegerField(null=False, index=True, help_text="creation timestamp")
+    update_time = BigIntegerField(null=False, help_text="last update timestamp")
+    status = IntegerField(null=False, default=1, help_text="1=valid, 0=invalid")
+
+    class Meta:
+        db_table = "evaluation_datasets"
+
+class EvaluationCase(DataBaseModel):
+    """Individual test case in an evaluation dataset"""
+    id = CharField(max_length=32, primary_key=True)
+    dataset_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_datasets")
+    question = TextField(null=False, help_text="test question")
+    reference_answer = TextField(null=True, help_text="optional ground truth answer")
+    relevant_doc_ids = JSONField(null=True, help_text="expected relevant document IDs")
+    relevant_chunk_ids = JSONField(null=True, help_text="expected relevant chunk IDs")
+    metadata = JSONField(null=True, help_text="additional context/tags")
+    create_time = BigIntegerField(null=False, help_text="creation timestamp")
+
+    class Meta:
+        db_table = "evaluation_cases"
+
+class EvaluationRun(DataBaseModel):
+    """A single evaluation run"""
+    id = CharField(max_length=32, primary_key=True)
+    dataset_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_datasets")
+    dialog_id = CharField(max_length=32, null=False, index=True, help_text="dialog configuration being evaluated")
+    name = CharField(max_length=255, null=False, help_text="run name")
+    config_snapshot = JSONField(null=False, help_text="dialog config at time of evaluation")
+    metrics_summary = JSONField(null=True, help_text="aggregated metrics")
+    status = CharField(max_length=32, null=False, default="PENDING", help_text="PENDING/RUNNING/COMPLETED/FAILED")
+    created_by = CharField(max_length=32, null=False, index=True, help_text="user who started the run")
+    create_time = BigIntegerField(null=False, index=True, help_text="creation timestamp")
+    complete_time = BigIntegerField(null=True, help_text="completion timestamp")
+
+    class Meta:
+        db_table = "evaluation_runs"
+
+class EvaluationResult(DataBaseModel):
+    """Result for a single test case in an evaluation run"""
+    id = CharField(max_length=32, primary_key=True)
+    run_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_runs")
+    case_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_cases")
+    generated_answer = TextField(null=False, help_text="generated answer")
+    retrieved_chunks = JSONField(null=False, help_text="chunks that were retrieved")
+    metrics = JSONField(null=False, help_text="all computed metrics")
+    execution_time = FloatField(null=False, help_text="response time in seconds")
+    token_usage = JSONField(null=True, help_text="prompt/completion tokens")
+    create_time = BigIntegerField(null=False, help_text="creation timestamp")
+
+    class Meta:
+        db_table = "evaluation_results"
+
 def migrate_db():
     logging.disable(logging.ERROR)
     migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
@@ -1293,4 +1357,43 @@ def migrate_db():
         migrate(migrator.add_column("llm_factories", "rank", IntegerField(default=0, index=False)))
     except Exception:
         pass
+    # RAG Evaluation tables
+    try:
+        migrate(migrator.add_column("evaluation_datasets", "id", CharField(max_length=32, primary_key=True)))
+    except Exception:
+        pass
+    try:
+        migrate(migrator.add_column("evaluation_datasets", "tenant_id", CharField(max_length=32, null=False, index=True)))
+    except Exception:
+        pass
+    try:
+        migrate(migrator.add_column("evaluation_datasets", "name", CharField(max_length=255, null=False, index=True)))
+    except Exception:
+        pass
+    try:
+        migrate(migrator.add_column("evaluation_datasets", "description", TextField(null=True)))
+    except Exception:
+        pass
+    try:
+        migrate(migrator.add_column("evaluation_datasets", "kb_ids", JSONField(null=False)))
+    except Exception:
+        pass
+    try:
+        migrate(migrator.add_column("evaluation_datasets", "created_by", CharField(max_length=32, null=False, index=True)))
+    except Exception:
+        pass
+    try:
+        migrate(migrator.add_column("evaluation_datasets", "create_time", BigIntegerField(null=False, index=True)))
+    except Exception:
+        pass
+    try:
+        migrate(migrator.add_column("evaluation_datasets", "update_time", BigIntegerField(null=False)))
+    except Exception:
+        pass
+    try:
+        migrate(migrator.add_column("evaluation_datasets", "status", IntegerField(null=False, default=1)))
+    except Exception:
+        pass
     logging.disable(logging.NOTSET)

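Reviewer note: the migration block repeats one idiom — each `add_column` is wrapped in `try/except` so reruns against an already-migrated schema are harmless. A minimal sketch of the same idiom with peewee's `playhouse.migrate` against an in-memory SQLite database (table and column names borrowed for illustration):

```python
from peewee import CharField, SqliteDatabase
from playhouse.migrate import SqliteMigrator, migrate

db = SqliteDatabase(":memory:")
db.execute_sql("CREATE TABLE evaluation_datasets (id TEXT PRIMARY KEY)")
migrator = SqliteMigrator(db)

for _ in range(2):  # the second pass is a no-op, like rerunning migrate_db()
    try:
        migrate(migrator.add_column("evaluation_datasets", "name", CharField(default="")))
    except Exception:
        pass  # column already exists; swallow, exactly as migrate_db() does
```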
View File

@@ -19,7 +19,7 @@ from common.constants import StatusEnum
 from api.db.db_models import Conversation, DB
 from api.db.services.api_service import API4ConversationService
 from api.db.services.common_service import CommonService
-from api.db.services.dialog_service import DialogService, chat
+from api.db.services.dialog_service import DialogService, async_chat
 from common.misc_utils import get_uuid
 import json
@@ -89,8 +89,7 @@ def structure_answer(conv, ans, message_id, session_id):
         conv.reference[-1] = reference
     return ans
-def completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
+async def async_completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
     assert name, "`name` can not be empty."
     dia = DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
     assert dia, "You do not own the chat."
@@ -148,7 +147,7 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None
     if stream:
         try:
-            for ans in chat(dia, msg, True, **kwargs):
+            async for ans in async_chat(dia, msg, True, **kwargs):
                 ans = structure_answer(conv, ans, message_id, session_id)
                 yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
             ConversationService.update_by_id(conv.id, conv.to_dict())
@@ -160,14 +159,13 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None
     else:
         answer = None
-        for ans in chat(dia, msg, False, **kwargs):
+        async for ans in async_chat(dia, msg, False, **kwargs):
             answer = structure_answer(conv, ans, message_id, session_id)
             ConversationService.update_by_id(conv.id, conv.to_dict())
             break
         yield answer
-def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs):
+async def async_iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs):
     e, dia = DialogService.get_by_id(dialog_id)
     assert e, "Dialog not found"
     if not session_id:
@@ -222,7 +220,7 @@ def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwarg
     if stream:
         try:
-            for ans in chat(dia, msg, True, **kwargs):
+            async for ans in async_chat(dia, msg, True, **kwargs):
                 ans = structure_answer(conv, ans, message_id, session_id)
                 yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
                                            ensure_ascii=False) + "\n\n"
@@ -235,7 +233,7 @@ def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwarg
     else:
         answer = None
-        for ans in chat(dia, msg, False, **kwargs):
+        async for ans in async_chat(dia, msg, False, **kwargs):
             answer = structure_answer(conv, ans, message_id, session_id)
             API4ConversationService.append_message(conv.id, conv.to_dict())
             break

View File

@@ -178,7 +178,8 @@ class DialogService(CommonService):
             offset += limit
         return res
-def chat_solo(dialog, messages, stream=True):
+async def async_chat_solo(dialog, messages, stream=True):
     attachments = ""
     if "files" in messages[-1]:
         attachments = "\n\n".join(FileService.get_files(messages[-1]["files"]))
@@ -197,7 +198,8 @@ def chat_solo(dialog, messages, stream=True):
     if stream:
         last_ans = ""
         delta_ans = ""
-        for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
+        answer = ""
+        async for ans in chat_mdl.async_chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
             answer = ans
             delta_ans = ans[len(last_ans):]
             if num_tokens_from_string(delta_ans) < 16:
@@ -208,7 +210,7 @@ def chat_solo(dialog, messages, stream=True):
         if delta_ans:
             yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
     else:
-        answer = chat_mdl.chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
+        answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
         user_content = msg[-1].get("content", "[content not available]")
         logging.debug("User: {}|Assistant: {}".format(user_content, answer))
         yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}
@@ -347,13 +349,12 @@ def meta_filter(metas: dict, filters: list[dict], logic: str = "and"):
         return []
     return list(doc_ids)
-def chat(dialog, messages, stream=True, **kwargs):
+async def async_chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
     if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
-        for ans in chat_solo(dialog, messages, stream):
+        async for ans in async_chat_solo(dialog, messages, stream):
             yield ans
-        return None
+        return
     chat_start_ts = timer()
@@ -400,7 +401,7 @@ def chat(dialog, messages, stream=True, **kwargs):
         ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
         if ans:
             yield ans
-            return None
+            return
     for p in prompt_config["parameters"]:
         if p["key"] == "knowledge":
@@ -508,7 +509,8 @@ def chat(dialog, messages, stream=True, **kwargs):
             empty_res = prompt_config["empty_response"]
             yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
                    "audio_binary": tts(tts_mdl, empty_res)}
-        return {"answer": prompt_config["empty_response"], "reference": kbinfos}
+        yield {"answer": prompt_config["empty_response"], "reference": kbinfos}
+        return
     kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
     gen_conf = dialog.llm_setting
@@ -612,7 +614,7 @@ def chat(dialog, messages, stream=True, **kwargs):
     if stream:
         last_ans = ""
         answer = ""
-        for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
+        async for ans in chat_mdl.async_chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
             if thought:
                 ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
             answer = ans
@@ -626,14 +628,14 @@ def chat(dialog, messages, stream=True, **kwargs):
             yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
         yield decorate_answer(thought + answer)
     else:
-        answer = chat_mdl.chat(prompt + prompt4citation, msg[1:], gen_conf)
+        answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf)
         user_content = msg[-1].get("content", "[content not available]")
         logging.debug("User: {}|Assistant: {}".format(user_content, answer))
         res = decorate_answer(answer)
         res["audio_binary"] = tts(tts_mdl, answer)
         yield res
-    return None
+    return
 def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
@@ -761,17 +763,51 @@ Please write the SQL, only SQL, without any other explanations or text.
         "prompt": sys_prompt,
     }
+def clean_tts_text(text: str) -> str:
+    if not text:
+        return ""
+    text = text.encode("utf-8", "ignore").decode("utf-8", "ignore")
+    text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text)
+    emoji_pattern = re.compile(
+        "[\U0001F600-\U0001F64F"
+        "\U0001F300-\U0001F5FF"
+        "\U0001F680-\U0001F6FF"
+        "\U0001F1E0-\U0001F1FF"
+        "\U00002700-\U000027BF"
+        "\U0001F900-\U0001F9FF"
+        "\U0001FA70-\U0001FAFF"
+        "\U0001FAD0-\U0001FAFF]+",
+        flags=re.UNICODE
+    )
+    text = emoji_pattern.sub("", text)
+    text = re.sub(r"\s+", " ", text).strip()
+    MAX_LEN = 500
+    if len(text) > MAX_LEN:
+        text = text[:MAX_LEN]
+    return text
 def tts(tts_mdl, text):
     if not tts_mdl or not text:
         return None
+    text = clean_tts_text(text)
+    if not text:
+        return None
     bin = b""
+    try:
         for chunk in tts_mdl.tts(text):
             bin += chunk
+    except Exception as e:
+        logging.error(f"TTS failed: {e}, text={text!r}")
+        return None
     return binascii.hexlify(bin).decode("utf-8")
-def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
+async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
     doc_ids = search_config.get("doc_ids", [])
     rerank_mdl = None
     kb_ids = search_config.get("kb_ids", kb_ids)
@@ -845,7 +881,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
         return {"answer": answer, "reference": refs}
     answer = ""
-    for ans in chat_mdl.chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
+    async for ans in chat_mdl.async_chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
         answer = ans
         yield {"answer": answer, "reference": {}}
     yield decorate_answer(answer)

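Reviewer note: `clean_tts_text` guards the TTS engine against control characters, emoji, and unbounded input. A condensed, runnable sketch of the same steps — the emoji ranges below are a simplified subset of the ones added above:

```python
import re

def clean_tts_text_sketch(text: str, max_len: int = 500) -> str:
    # Drop ASCII control characters the synthesizer may choke on.
    text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text)
    # Strip emoji (simplified ranges; the real function enumerates more blocks).
    text = re.sub("[\U0001F300-\U0001FAFF\U00002700-\U000027BF]+", "", text)
    # Collapse whitespace and cap the length handed to the TTS engine.
    return re.sub(r"\s+", " ", text).strip()[:max_len]

print(clean_tts_text_sketch("Hello \x07 world 🚀🚀   again"))  # Hello world again
```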
View File

@@ -719,10 +719,14 @@ class DocumentService(CommonService):
             # only for special task and parsed docs and unfinished
             freeze_progress = special_task_running and doc_progress >= 1 and not finished
             msg = "\n".join(sorted(msg))
+            begin_at = d.get("process_begin_at")
+            if not begin_at:
+                begin_at = datetime.now()
+                # fallback
+                cls.update_by_id(d["id"], {"process_begin_at": begin_at})
             info = {
-                "process_duration": datetime.timestamp(
-                    datetime.now()) -
-                    d["process_begin_at"].timestamp(),
+                "process_duration": max(datetime.timestamp(datetime.now()) - begin_at.timestamp(), 0),
                 "run": status}
             if prg != 0 and not freeze_progress:
                 info["progress"] = prg

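Reviewer note: the `process_duration` fix has two parts — fall back to the current time when `process_begin_at` was never stamped, and clamp at zero so a late stamp cannot yield a negative duration. A minimal sketch:

```python
from datetime import datetime
from typing import Optional

def process_duration(begin_at: Optional[datetime]) -> float:
    if not begin_at:
        begin_at = datetime.now()  # fallback when the stamp is missing
    # Clamp at zero so clock skew or a late stamp never goes negative.
    return max(datetime.now().timestamp() - begin_at.timestamp(), 0)

print(process_duration(None))                      # ~0.0
print(process_duration(datetime(2024, 1, 1)) > 0)  # True
```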
View File

@@ -0,0 +1,637 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
RAG Evaluation Service
Provides functionality for evaluating RAG system performance including:
- Dataset management
- Test case management
- Evaluation execution
- Metrics computation
- Configuration recommendations
"""
import asyncio
import logging
import queue
import threading
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
from timeit import default_timer as timer
from api.db.db_models import EvaluationDataset, EvaluationCase, EvaluationRun, EvaluationResult
from api.db.services.common_service import CommonService
from api.db.services.dialog_service import DialogService
from common.misc_utils import get_uuid
from common.time_utils import current_timestamp
from common.constants import StatusEnum
class EvaluationService(CommonService):
"""Service for managing RAG evaluations"""
model = EvaluationDataset
# ==================== Dataset Management ====================
@classmethod
def create_dataset(cls, name: str, description: str, kb_ids: List[str],
tenant_id: str, user_id: str) -> Tuple[bool, str]:
"""
Create a new evaluation dataset.
Args:
name: Dataset name
description: Dataset description
kb_ids: List of knowledge base IDs to evaluate against
tenant_id: Tenant ID
user_id: User ID who creates the dataset
Returns:
(success, dataset_id or error_message)
"""
try:
dataset_id = get_uuid()
dataset = {
"id": dataset_id,
"tenant_id": tenant_id,
"name": name,
"description": description,
"kb_ids": kb_ids,
"created_by": user_id,
"create_time": current_timestamp(),
"update_time": current_timestamp(),
"status": StatusEnum.VALID.value
}
if not EvaluationDataset.create(**dataset):
return False, "Failed to create dataset"
return True, dataset_id
except Exception as e:
logging.error(f"Error creating evaluation dataset: {e}")
return False, str(e)
@classmethod
def get_dataset(cls, dataset_id: str) -> Optional[Dict[str, Any]]:
"""Get dataset by ID"""
try:
dataset = EvaluationDataset.get_by_id(dataset_id)
if dataset:
return dataset.to_dict()
return None
except Exception as e:
logging.error(f"Error getting dataset {dataset_id}: {e}")
return None
@classmethod
def list_datasets(cls, tenant_id: str, user_id: str,
page: int = 1, page_size: int = 20) -> Dict[str, Any]:
"""List datasets for a tenant"""
try:
query = EvaluationDataset.select().where(
(EvaluationDataset.tenant_id == tenant_id) &
(EvaluationDataset.status == StatusEnum.VALID.value)
).order_by(EvaluationDataset.create_time.desc())
total = query.count()
datasets = query.paginate(page, page_size)
return {
"total": total,
"datasets": [d.to_dict() for d in datasets]
}
except Exception as e:
logging.error(f"Error listing datasets: {e}")
return {"total": 0, "datasets": []}
@classmethod
def update_dataset(cls, dataset_id: str, **kwargs) -> bool:
"""Update dataset"""
try:
kwargs["update_time"] = current_timestamp()
return EvaluationDataset.update(**kwargs).where(
EvaluationDataset.id == dataset_id
).execute() > 0
except Exception as e:
logging.error(f"Error updating dataset {dataset_id}: {e}")
return False
@classmethod
def delete_dataset(cls, dataset_id: str) -> bool:
"""Soft delete dataset"""
try:
return EvaluationDataset.update(
status=StatusEnum.INVALID.value,
update_time=current_timestamp()
).where(EvaluationDataset.id == dataset_id).execute() > 0
except Exception as e:
logging.error(f"Error deleting dataset {dataset_id}: {e}")
return False
# ==================== Test Case Management ====================
@classmethod
def add_test_case(cls, dataset_id: str, question: str,
reference_answer: Optional[str] = None,
relevant_doc_ids: Optional[List[str]] = None,
relevant_chunk_ids: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None) -> Tuple[bool, str]:
"""
Add a test case to a dataset.
Args:
dataset_id: Dataset ID
question: Test question
reference_answer: Optional ground truth answer
relevant_doc_ids: Optional list of relevant document IDs
relevant_chunk_ids: Optional list of relevant chunk IDs
metadata: Optional additional metadata
Returns:
(success, case_id or error_message)
"""
try:
case_id = get_uuid()
case = {
"id": case_id,
"dataset_id": dataset_id,
"question": question,
"reference_answer": reference_answer,
"relevant_doc_ids": relevant_doc_ids,
"relevant_chunk_ids": relevant_chunk_ids,
"metadata": metadata,
"create_time": current_timestamp()
}
if not EvaluationCase.create(**case):
return False, "Failed to create test case"
return True, case_id
except Exception as e:
logging.error(f"Error adding test case: {e}")
return False, str(e)
@classmethod
def get_test_cases(cls, dataset_id: str) -> List[Dict[str, Any]]:
"""Get all test cases for a dataset"""
try:
cases = EvaluationCase.select().where(
EvaluationCase.dataset_id == dataset_id
).order_by(EvaluationCase.create_time)
return [c.to_dict() for c in cases]
except Exception as e:
logging.error(f"Error getting test cases for dataset {dataset_id}: {e}")
return []
@classmethod
def delete_test_case(cls, case_id: str) -> bool:
"""Delete a test case"""
try:
return EvaluationCase.delete().where(
EvaluationCase.id == case_id
).execute() > 0
except Exception as e:
logging.error(f"Error deleting test case {case_id}: {e}")
return False
@classmethod
def import_test_cases(cls, dataset_id: str, cases: List[Dict[str, Any]]) -> Tuple[int, int]:
"""
Bulk import test cases from a list.
Args:
dataset_id: Dataset ID
cases: List of test case dictionaries
Returns:
(success_count, failure_count)
"""
success_count = 0
failure_count = 0
for case_data in cases:
success, _ = cls.add_test_case(
dataset_id=dataset_id,
question=case_data.get("question", ""),
reference_answer=case_data.get("reference_answer"),
relevant_doc_ids=case_data.get("relevant_doc_ids"),
relevant_chunk_ids=case_data.get("relevant_chunk_ids"),
metadata=case_data.get("metadata")
)
if success:
success_count += 1
else:
failure_count += 1
return success_count, failure_count
# ==================== Evaluation Execution ====================
@classmethod
def start_evaluation(cls, dataset_id: str, dialog_id: str,
user_id: str, name: Optional[str] = None) -> Tuple[bool, str]:
"""
Start an evaluation run.
Args:
dataset_id: Dataset ID
dialog_id: Dialog configuration to evaluate
user_id: User ID who starts the run
name: Optional run name
Returns:
(success, run_id or error_message)
"""
try:
# Get dialog configuration
success, dialog = DialogService.get_by_id(dialog_id)
if not success:
return False, "Dialog not found"
# Create evaluation run
run_id = get_uuid()
if not name:
name = f"Evaluation Run {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
run = {
"id": run_id,
"dataset_id": dataset_id,
"dialog_id": dialog_id,
"name": name,
"config_snapshot": dialog.to_dict(),
"metrics_summary": None,
"status": "RUNNING",
"created_by": user_id,
"create_time": current_timestamp(),
"complete_time": None
}
if not EvaluationRun.create(**run):
return False, "Failed to create evaluation run"
# Execute evaluation asynchronously (in production, use task queue)
# For now, we'll execute synchronously
cls._execute_evaluation(run_id, dataset_id, dialog)
return True, run_id
except Exception as e:
logging.error(f"Error starting evaluation: {e}")
return False, str(e)
@classmethod
def _execute_evaluation(cls, run_id: str, dataset_id: str, dialog: Any):
"""
Execute evaluation for all test cases.
This method runs the RAG pipeline for each test case and computes metrics.
"""
try:
# Get all test cases
test_cases = cls.get_test_cases(dataset_id)
if not test_cases:
EvaluationRun.update(
status="FAILED",
complete_time=current_timestamp()
).where(EvaluationRun.id == run_id).execute()
return
# Execute each test case
results = []
for case in test_cases:
result = cls._evaluate_single_case(run_id, case, dialog)
if result:
results.append(result)
# Compute summary metrics
metrics_summary = cls._compute_summary_metrics(results)
# Update run status
EvaluationRun.update(
status="COMPLETED",
metrics_summary=metrics_summary,
complete_time=current_timestamp()
).where(EvaluationRun.id == run_id).execute()
except Exception as e:
logging.error(f"Error executing evaluation {run_id}: {e}")
EvaluationRun.update(
status="FAILED",
complete_time=current_timestamp()
).where(EvaluationRun.id == run_id).execute()
@classmethod
def _evaluate_single_case(cls, run_id: str, case: Dict[str, Any],
dialog: Any) -> Optional[Dict[str, Any]]:
"""
Evaluate a single test case.
Args:
run_id: Evaluation run ID
case: Test case dictionary
dialog: Dialog configuration
Returns:
Result dictionary or None if failed
"""
try:
# Prepare messages
messages = [{"role": "user", "content": case["question"]}]
# Execute RAG pipeline
start_time = timer()
answer = ""
retrieved_chunks = []
def _sync_from_async_gen(async_gen):
result_queue: queue.Queue = queue.Queue()
def runner():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
async def consume():
try:
async for item in async_gen:
result_queue.put(item)
except Exception as e:
result_queue.put(e)
finally:
result_queue.put(StopIteration)
loop.run_until_complete(consume())
loop.close()
threading.Thread(target=runner, daemon=True).start()
while True:
item = result_queue.get()
if item is StopIteration:
break
if isinstance(item, Exception):
raise item
yield item
def chat(dialog, messages, stream=True, **kwargs):
from api.db.services.dialog_service import async_chat
return _sync_from_async_gen(async_chat(dialog, messages, stream=stream, **kwargs))
for ans in chat(dialog, messages, stream=False):
if isinstance(ans, dict):
answer = ans.get("answer", "")
retrieved_chunks = ans.get("reference", {}).get("chunks", [])
break
execution_time = timer() - start_time
# Compute metrics
metrics = cls._compute_metrics(
question=case["question"],
generated_answer=answer,
reference_answer=case.get("reference_answer"),
retrieved_chunks=retrieved_chunks,
relevant_chunk_ids=case.get("relevant_chunk_ids"),
dialog=dialog
)
# Save result
result_id = get_uuid()
result = {
"id": result_id,
"run_id": run_id,
"case_id": case["id"],
"generated_answer": answer,
"retrieved_chunks": retrieved_chunks,
"metrics": metrics,
"execution_time": execution_time,
"token_usage": None, # TODO: Track token usage
"create_time": current_timestamp()
}
EvaluationResult.create(**result)
return result
except Exception as e:
logging.error(f"Error evaluating case {case.get('id')}: {e}")
return None
@classmethod
def _compute_metrics(cls, question: str, generated_answer: str,
reference_answer: Optional[str],
retrieved_chunks: List[Dict[str, Any]],
relevant_chunk_ids: Optional[List[str]],
dialog: Any) -> Dict[str, float]:
"""
Compute evaluation metrics for a single test case.
Returns:
Dictionary of metric names to values
"""
metrics = {}
# Retrieval metrics (if ground truth chunks provided)
if relevant_chunk_ids:
retrieved_ids = [c.get("chunk_id") for c in retrieved_chunks]
metrics.update(cls._compute_retrieval_metrics(retrieved_ids, relevant_chunk_ids))
# Generation metrics
if generated_answer:
# Basic metrics
metrics["answer_length"] = len(generated_answer)
metrics["has_answer"] = 1.0 if generated_answer.strip() else 0.0
# TODO: Implement advanced metrics using LLM-as-judge
# - Faithfulness (hallucination detection)
# - Answer relevance
# - Context relevance
# - Semantic similarity (if reference answer provided)
return metrics
@classmethod
def _compute_retrieval_metrics(cls, retrieved_ids: List[str],
relevant_ids: List[str]) -> Dict[str, float]:
"""
Compute retrieval metrics.
Args:
retrieved_ids: List of retrieved chunk IDs
relevant_ids: List of relevant chunk IDs (ground truth)
Returns:
Dictionary of retrieval metrics
"""
if not relevant_ids:
return {}
retrieved_set = set(retrieved_ids)
relevant_set = set(relevant_ids)
# Precision: proportion of retrieved that are relevant
precision = len(retrieved_set & relevant_set) / len(retrieved_set) if retrieved_set else 0.0
# Recall: proportion of relevant that were retrieved
recall = len(retrieved_set & relevant_set) / len(relevant_set) if relevant_set else 0.0
# F1 score
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
# Hit rate: whether any relevant chunk was retrieved
hit_rate = 1.0 if (retrieved_set & relevant_set) else 0.0
# MRR (Mean Reciprocal Rank): position of first relevant chunk
mrr = 0.0
for i, chunk_id in enumerate(retrieved_ids, 1):
if chunk_id in relevant_set:
mrr = 1.0 / i
break
return {
"precision": precision,
"recall": recall,
"f1_score": f1,
"hit_rate": hit_rate,
"mrr": mrr
}
@classmethod
def _compute_summary_metrics(cls, results: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Compute summary metrics across all test cases.
Args:
results: List of result dictionaries
Returns:
Summary metrics dictionary
"""
if not results:
return {}
# Aggregate metrics
metric_sums = {}
metric_counts = {}
for result in results:
metrics = result.get("metrics", {})
for key, value in metrics.items():
if isinstance(value, (int, float)):
metric_sums[key] = metric_sums.get(key, 0) + value
metric_counts[key] = metric_counts.get(key, 0) + 1
# Compute averages
summary = {
"total_cases": len(results),
"avg_execution_time": sum(r.get("execution_time", 0) for r in results) / len(results)
}
for key in metric_sums:
summary[f"avg_{key}"] = metric_sums[key] / metric_counts[key]
return summary
# ==================== Results & Analysis ====================
@classmethod
def get_run_results(cls, run_id: str) -> Dict[str, Any]:
"""Get results for an evaluation run"""
try:
run = EvaluationRun.get_by_id(run_id)
if not run:
return {}
results = EvaluationResult.select().where(
EvaluationResult.run_id == run_id
).order_by(EvaluationResult.create_time)
return {
"run": run.to_dict(),
"results": [r.to_dict() for r in results]
}
except Exception as e:
logging.error(f"Error getting run results {run_id}: {e}")
return {}
@classmethod
def get_recommendations(cls, run_id: str) -> List[Dict[str, Any]]:
"""
Analyze evaluation results and provide configuration recommendations.
Args:
run_id: Evaluation run ID
Returns:
List of recommendation dictionaries
"""
try:
run = EvaluationRun.get_by_id(run_id)
if not run or not run.metrics_summary:
return []
metrics = run.metrics_summary
recommendations = []
# Low precision: retrieving irrelevant chunks
if metrics.get("avg_precision", 1.0) < 0.7:
recommendations.append({
"issue": "Low Precision",
"severity": "high",
"description": "System is retrieving many irrelevant chunks",
"suggestions": [
"Increase similarity_threshold to filter out less relevant chunks",
"Enable reranking to improve chunk ordering",
"Reduce top_k to return fewer chunks"
]
})
# Low recall: missing relevant chunks
if metrics.get("avg_recall", 1.0) < 0.7:
recommendations.append({
"issue": "Low Recall",
"severity": "high",
"description": "System is missing relevant chunks",
"suggestions": [
"Increase top_k to retrieve more chunks",
"Lower similarity_threshold to be more inclusive",
"Enable hybrid search (keyword + semantic)",
"Check chunk size - may be too large or too small"
]
})
# Slow response time
if metrics.get("avg_execution_time", 0) > 5.0:
recommendations.append({
"issue": "Slow Response Time",
"severity": "medium",
"description": f"Average response time is {metrics['avg_execution_time']:.2f}s",
"suggestions": [
"Reduce top_k to retrieve fewer chunks",
"Optimize embedding model selection",
"Consider caching frequently asked questions"
]
})
return recommendations
except Exception as e:
logging.error(f"Error generating recommendations for run {run_id}: {e}")
return []
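To see which of the thresholds above would fire, here is a standalone sketch with a hypothetical `metrics_summary` (the numbers are invented):

```python
# Made-up summary metrics for one evaluation run.
metrics = {"avg_precision": 0.62, "avg_recall": 0.81, "avg_execution_time": 6.4}

issues = []
if metrics.get("avg_precision", 1.0) < 0.7:      # irrelevant chunks retrieved
    issues.append("Low Precision")
if metrics.get("avg_recall", 1.0) < 0.7:         # relevant chunks missed
    issues.append("Low Recall")
if metrics.get("avg_execution_time", 0) > 5.0:   # slow retrieval
    issues.append("Slow Response Time")
print(issues)  # ['Low Precision', 'Slow Response Time']
```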

View File

@ -16,15 +16,17 @@
import asyncio
import inspect
import logging
import queue
import re
import threading
from functools import partial
from typing import Generator

from api.db.db_models import LLM
from api.db.services.common_service import CommonService
from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService
from common.constants import LLMType
from common.token_utils import num_tokens_from_string


class LLMService(CommonService):
@ -33,6 +35,7 @@ class LLMService(CommonService):
    def get_init_tenant_llm(user_id):
        from common import settings

        tenant_llm = []
        model_configs = {
@ -193,7 +196,7 @@ class LLMBundle(LLM4Tenant):
            generation = self.langfuse.start_generation(
                trace_context=self.trace_context,
                name="stream_transcription",
                metadata={"model": self.llm_name},
            )
            final_text = ""
            used_tokens = 0
@ -217,32 +220,34 @@ class LLMBundle(LLM4Tenant):
                if self.langfuse:
                    generation.update(
                        output={"output": final_text},
                        usage_details={"total_tokens": used_tokens},
                    )
                    generation.end()
                return

        if self.langfuse:
            generation = self.langfuse.start_generation(
                trace_context=self.trace_context,
                name="stream_transcription",
                metadata={"model": self.llm_name},
            )
        full_text, used_tokens = mdl.transcription(audio)
        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
            logging.error(f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}")
        if self.langfuse:
            generation.update(
                output={"output": full_text},
                usage_details={"total_tokens": used_tokens},
            )
            generation.end()

        yield {
            "event": "final",
            "text": full_text,
            "streaming": False,
        }

    def tts(self, text: str) -> Generator[bytes, None, None]:
@ -289,46 +294,152 @@ class LLMBundle(LLM4Tenant):
            return kwargs
        else:
            return {k: v for k, v in kwargs.items() if k in allowed_params}

    def _run_coroutine_sync(self, coro):
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(coro)

        result_queue: queue.Queue = queue.Queue()

        def runner():
            try:
                result_queue.put((True, asyncio.run(coro)))
            except Exception as e:
                result_queue.put((False, e))

        thread = threading.Thread(target=runner, daemon=True)
        thread.start()
        thread.join()
        success, value = result_queue.get_nowait()
        if success:
            return value
        raise value

    def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str:
        return self._run_coroutine_sync(self.async_chat(system, history, gen_conf, **kwargs))

    def _sync_from_async_stream(self, async_gen_fn, *args, **kwargs):
        result_queue: queue.Queue = queue.Queue()

        def runner():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

            async def consume():
                try:
                    async for item in async_gen_fn(*args, **kwargs):
                        result_queue.put(item)
                except Exception as e:
                    result_queue.put(e)
                finally:
                    result_queue.put(StopIteration)

            loop.run_until_complete(consume())
            loop.close()

        threading.Thread(target=runner, daemon=True).start()
        while True:
            item = result_queue.get()
            if item is StopIteration:
                break
            if isinstance(item, Exception):
                raise item
            yield item

    def chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
        ans = ""
        for txt in self._sync_from_async_stream(self.async_chat_streamly, system, history, gen_conf, **kwargs):
            if isinstance(txt, int):
                break
            if txt.endswith("</think>"):
                ans = txt[: -len("</think>")]
                continue
            if not self.verbose_tool_use:
                txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
            # concatenation has already been done in async_chat_streamly
            ans = txt
            yield ans

    def _bridge_sync_stream(self, gen):
        loop = asyncio.get_running_loop()
        queue: asyncio.Queue = asyncio.Queue()

        def worker():
            try:
                for item in gen:
                    loop.call_soon_threadsafe(queue.put_nowait, item)
            except Exception as e:
                loop.call_soon_threadsafe(queue.put_nowait, e)
            finally:
                loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)

        threading.Thread(target=worker, daemon=True).start()
        return queue

    async def async_chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
        if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_with_tools"):
            base_fn = self.mdl.async_chat_with_tools
        elif hasattr(self.mdl, "async_chat"):
            base_fn = self.mdl.async_chat
        else:
            raise RuntimeError(f"Model {self.mdl} does not implement async_chat or async_chat_with_tools")
        generation = None
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})
        chat_partial = partial(base_fn, system, history, gen_conf)
        use_kwargs = self._clean_param(chat_partial, **kwargs)
        try:
            txt, used_tokens = await chat_partial(**use_kwargs)
        except Exception as e:
            if generation:
                generation.update(output={"error": str(e)})
                generation.end()
            raise
        txt = self._remove_reasoning_content(txt)
        if not self.verbose_tool_use:
            txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
        if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
            logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
        if generation:
            generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
            generation.end()
        return txt

    async def async_chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
        total_tokens = 0
        ans = ""
        if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_streamly_with_tools"):
            stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None)
        elif hasattr(self.mdl, "async_chat_streamly"):
            stream_fn = getattr(self.mdl, "async_chat_streamly", None)
        else:
            raise RuntimeError(f"Model {self.mdl} does not implement async_chat_streamly or async_chat_streamly_with_tools")
        generation = None
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
        if stream_fn:
            chat_partial = partial(stream_fn, system, history, gen_conf)
            use_kwargs = self._clean_param(chat_partial, **kwargs)
            try:
                async for txt in chat_partial(**use_kwargs):
                    if isinstance(txt, int):
                        total_tokens = txt
                        break
                    if txt.endswith("</think>"):
@ -339,82 +450,14 @@ class LLMBundle(LLM4Tenant):
                    ans += txt
                    yield ans
            except Exception as e:
                if generation:
                    generation.update(output={"error": str(e)})
                    generation.end()
                raise
            if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
                logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
            if generation:
                generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
                generation.end()
            return
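The `_run_coroutine_sync` helper added above lets synchronous callers drive a coroutine even when the calling thread already hosts a running event loop. A minimal standalone sketch of the same pattern, independent of LLMBundle:

```python
import asyncio
import queue
import threading

def run_coroutine_sync(coro):
    """Run a coroutine to completion from sync code, spawning a helper
    thread (with its own event loop) if a loop is already running here."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return asyncio.run(coro)  # no loop on this thread: safe to block

    result_queue: queue.Queue = queue.Queue()

    def runner():
        try:
            result_queue.put((True, asyncio.run(coro)))
        except Exception as e:
            result_queue.put((False, e))

    t = threading.Thread(target=runner, daemon=True)
    t.start()
    t.join()
    ok, value = result_queue.get_nowait()
    if ok:
        return value
    raise value

async def double(x):
    await asyncio.sleep(0)
    return 2 * x

print(run_coroutine_sync(double(21)))  # 42
```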

View File

@ -331,6 +331,7 @@ class RaptorConfig(Base):
    threshold: Annotated[float, Field(default=0.1, ge=0.0, le=1.0)]
    max_cluster: Annotated[int, Field(default=64, ge=1, le=1024)]
    random_seed: Annotated[int, Field(default=0, ge=0)]
    auto_disable_for_structured_data: Annotated[bool, Field(default=True)]

class GraphragConfig(Base):

View File

@ -148,6 +148,7 @@ class Storage(Enum):
    AWS_S3 = 4
    OSS = 5
    OPENDAL = 6
    GCS = 7

# environment
# ENV_STRONG_TEST_COUNT = "STRONG_TEST_COUNT"

View File

@ -126,7 +126,7 @@ class OnyxConfluence:
    def _renew_credentials(self) -> tuple[dict[str, Any], bool]:
        """credential_json - the current json credentials
        Returns a tuple
        1. The up-to-date credentials
        2. True if the credentials were updated

        This method is intended to be used within a distributed lock.
@ -179,8 +179,8 @@ class OnyxConfluence:
            credential_json["confluence_refresh_token"],
        )

        # store the new credentials to redis and to the db through the provider
        # redis: we use a 5 min TTL because we are given a 10-minute grace period
        # when keys are rotated. it's easier to expire the cached credentials
        # reasonably frequently rather than trying to handle strong synchronization
        # between the db and redis everywhere the credentials might be updated
@ -690,7 +690,7 @@ class OnyxConfluence:
    ) -> Iterator[dict[str, Any]]:
        """
        This function will paginate through the top level query first, then
        paginate through all the expansions.
        """

        def _traverse_and_update(data: dict | list) -> None:
@ -863,7 +863,7 @@ def get_user_email_from_username__server(
    # For now, we'll just return None and log a warning. This means
    # we will keep retrying to get the email every group sync.
    email = None
    # We may want to just return a string that indicates failure so we don't
    # keep retrying
    # email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
    _USER_EMAIL_CACHE[user_name] = email
@ -912,7 +912,7 @@ def extract_text_from_confluence_html(
    confluence_object: dict[str, Any],
    fetched_titles: set[str],
) -> str:
    """Parse a Confluence html page and replace the 'user id' by the real
    User Display Name

    Args:

View File

@ -33,7 +33,7 @@ def _convert_message_to_document(
    metadata: dict[str, str | list[str]] = {}
    semantic_substring = ""

    # Only messages from TextChannels will make it here, but we have to check for it anyway
    if isinstance(message.channel, TextChannel) and (channel_name := message.channel.name):
        metadata["Channel"] = channel_name
        semantic_substring += f" in Channel: #{channel_name}"
@ -176,7 +176,7 @@ def _manage_async_retrieval(
    # parse requested_start_date_string to datetime
    pull_date: datetime | None = datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(tzinfo=timezone.utc) if requested_start_date_string else None

    # Set start_time to the most recent of start and pull_date, or whichever is provided
    start_time = max(filter(None, [start, pull_date])) if start or pull_date else None

    end_time: datetime | None = end

View File

@ -76,7 +76,7 @@ ALL_ACCEPTED_FILE_EXTENSIONS = ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS + ACCEPTED_DO
MAX_RETRIEVER_EMAILS = 20
CHUNK_SIZE_BUFFER = 64  # extra bytes past the limit to read

# This is not a standard valid Unicode char, it is used by the docs advanced API to
# represent smart chips (elements like dates and doc links).
SMART_CHIP_CHAR = "\ue907"
WEB_VIEW_LINK_KEY = "webViewLink"

View File

@ -141,7 +141,7 @@ def crawl_folders_for_files(
    # Only mark a folder as done if it was fully traversed without errors
    # This usually indicates that the owner of the folder was impersonated.
    # In cases where this never happens, most likely the folder owner is
    # not part of the Google Workspace in question (or for oauth, the authenticated
    # user doesn't own the folder)
    if found_files:
        update_traversed_ids_func(parent_id)
@ -232,7 +232,7 @@ def get_files_in_shared_drive(
        **kwargs,
    ):
        # If we found any files, mark this drive as traversed. When a user has access to a drive,
        # they have access to all the files in the drive. Also, not a huge deal if we re-traverse
        # empty drives.
        # NOTE: ^^ the above is not actually true due to folder restrictions:
        # https://support.google.com/a/users/answer/12380484?hl=en

View File

@ -22,7 +22,7 @@ class GDriveMimeType(str, Enum):
    MARKDOWN = "text/markdown"

# These correspond to the major stages of retrieval for Google Drive.
# The stages for the oauth flow are:
#   get_all_files_for_oauth(),
#   get_all_drive_ids(),
@ -117,7 +117,7 @@ class GoogleDriveCheckpoint(ConnectorCheckpoint):
class RetrievedDriveFile(BaseModel):
    """
    Describes a file that has been retrieved from Google Drive.

    user_email is the email of the user that the file was retrieved
    by impersonating. If an error worthy of being reported is encountered,
    error should be set and later propagated as a ConnectorFailure.

View File

@ -29,8 +29,8 @@ class GmailService(Resource):
class RefreshableDriveObject:
    """
    Running Google Drive service retrieval functions
    involves accessing methods of the service object (i.e. files().list())
    which can raise a RefreshError if the access token is expired.
    This class is a wrapper that propagates the ability to refresh the access token
    and retry the final retrieval function until execute() is called.

View File

@ -120,7 +120,7 @@ def format_document_soup(
        # table is standard HTML element
        if e.name == "table":
            in_table = True
        # TR is for rows
        elif e.name == "tr" and in_table:
            text += "\n"
        # td for data cell, th for header

View File

@ -395,8 +395,7 @@ class AttachmentProcessingResult(BaseModel):
class IndexingHeartbeatInterface(ABC):
    """Defines a callback interface to be passed to run_indexing_entrypoint."""

    @abstractmethod
    def should_stop(self) -> bool:

View File

@ -80,7 +80,7 @@ _TZ_OFFSET_PATTERN = re.compile(r"([+-])(\d{2})(:?)(\d{2})$")
class JiraConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSync):
    """Retrieve Jira issues and emit them as Markdown documents."""

    def __init__(
        self,

View File

@ -54,8 +54,8 @@ class ExternalAccess:
        A helper function that returns an *empty* set of external user-emails and group-ids, and sets `is_public` to `False`.
        This effectively makes the document in question "private" or inaccessible to anyone else.

        This is especially helpful to use when you are performing permission-syncing, and some document's permissions can't
        be determined (for whatever reason). Setting its `ExternalAccess` to "private" is a feasible fallback.
        """
        return cls(

View File

@ -190,6 +190,11 @@ class WebDAVConnector(LoadConnector, PollConnector):
        files = self._list_files_recursive(self.remote_path, start, end)
        logging.info(f"Found {len(files)} files matching time criteria")

        filename_counts: dict[str, int] = {}
        for file_path, _ in files:
            file_name = os.path.basename(file_path)
            filename_counts[file_name] = filename_counts.get(file_name, 0) + 1

        batch: list[Document] = []
        for file_path, file_info in files:
            file_name = os.path.basename(file_path)
@ -237,12 +242,22 @@ class WebDAVConnector(LoadConnector, PollConnector):
                else:
                    modified = datetime.now(timezone.utc)

                if filename_counts.get(file_name, 0) > 1:
                    relative_path = file_path
                    if file_path.startswith(self.remote_path):
                        relative_path = file_path[len(self.remote_path):]
                    if relative_path.startswith('/'):
                        relative_path = relative_path[1:]
                    semantic_id = relative_path.replace('/', ' / ') if relative_path else file_name
                else:
                    semantic_id = file_name

                batch.append(
                    Document(
                        id=f"webdav:{self.base_url}:{file_path}",
                        blob=blob,
                        source=DocumentSource.WEBDAV,
                        semantic_identifier=semantic_id,
                        extension=get_file_ext(file_name),
                        doc_updated_at=modified,
                        size_bytes=size_bytes if size_bytes else 0
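The intent of the added block: when two files share a basename, fall back to a readable relative path so their semantic identifiers stay distinct. A standalone sketch of that logic with made-up paths:

```python
import os

remote_path = "/dav"
files = ["/dav/a/report.pdf", "/dav/b/report.pdf", "/dav/notes.txt"]

# count basenames first, exactly as the connector does
counts: dict[str, int] = {}
for p in files:
    name = os.path.basename(p)
    counts[name] = counts.get(name, 0) + 1

for p in files:
    name = os.path.basename(p)
    if counts[name] > 1:
        rel = p[len(remote_path):].lstrip('/')
        semantic_id = rel.replace('/', ' / ') if rel else name
    else:
        semantic_id = name
    print(semantic_id)
# a / report.pdf
# b / report.pdf
# notes.txt
```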

View File

@ -153,7 +153,7 @@ def parse_mineru_paths() -> Dict[str, Path]:
@once
def check_and_install_mineru() -> None:
    """
    Ensure MinerU is installed.
@ -173,8 +173,8 @@ def install_mineru() -> None:
    Logging is used to indicate status.
    """
    # Check if MinerU is enabled
    use_mineru = os.getenv("USE_MINERU", "false").strip().lower()
    if use_mineru != "true":
        logging.info("USE_MINERU=%r. Skipping MinerU installation.", use_mineru)
        return
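This change flips the gate from opt-out to opt-in: previously only the literal string "false" skipped installation, while now anything other than "true" does. A tiny sketch of the new behavior (env var name as in the diff):

```python
import os

def mineru_enabled() -> bool:
    # opt-in: unset, empty, or any value other than "true" disables MinerU
    return os.getenv("USE_MINERU", "false").strip().lower() == "true"

os.environ["USE_MINERU"] = "1"  # previously enabled (not the string "false")
print(mineru_enabled())         # now False: only "true" enables
```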

View File

@ -31,6 +31,7 @@ import rag.utils.ob_conn
import rag.utils.opensearch_conn
from rag.utils.azure_sas_conn import RAGFlowAzureSasBlob
from rag.utils.azure_spn_conn import RAGFlowAzureSpnBlob
from rag.utils.gcs_conn import RAGFlowGCS
from rag.utils.minio_conn import RAGFlowMinio
from rag.utils.opendal_conn import OpenDALStorage
from rag.utils.s3_conn import RAGFlowS3
@ -109,6 +110,7 @@ MINIO = {}
OB = {}
OSS = {}
OS = {}
GCS = {}

DOC_MAXIMUM_SIZE: int = 128 * 1024 * 1024
DOC_BULK_SIZE: int = 4
@ -151,7 +153,8 @@ class StorageFactory:
        Storage.AZURE_SAS: RAGFlowAzureSasBlob,
        Storage.AWS_S3: RAGFlowS3,
        Storage.OSS: RAGFlowOSS,
        Storage.OPENDAL: OpenDALStorage,
        Storage.GCS: RAGFlowGCS,
    }

    @classmethod
@ -250,7 +253,7 @@ def init_settings():
    else:
        raise Exception(f"Not supported doc engine: {DOC_ENGINE}")

    global AZURE, S3, MINIO, OSS, GCS
    if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']:
        AZURE = get_base_config("azure", {})
    elif STORAGE_IMPL_TYPE == 'AWS_S3':
@ -259,6 +262,8 @@ def init_settings():
        MINIO = decrypt_database_config(name="minio")
    elif STORAGE_IMPL_TYPE == 'OSS':
        OSS = get_base_config("oss", {})
    elif STORAGE_IMPL_TYPE == 'GCS':
        GCS = get_base_config("gcs", {})

    global STORAGE_IMPL
    STORAGE_IMPL = StorageFactory.create(Storage[STORAGE_IMPL_TYPE])
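Wiring a new backend into this factory only requires the enum member and a map entry; the configured name then selects the class by enum lookup. A minimal sketch of that dispatch, with stub classes standing in for the real RAGFlow storage adapters (names hypothetical):

```python
from enum import Enum

class Storage(Enum):
    MINIO = 1
    GCS = 7

class StubMinio: ...
class StubGCS: ...

STORAGE_IMPL_MAP = {Storage.MINIO: StubMinio, Storage.GCS: StubGCS}

storage_impl_type = "GCS"  # e.g. read from the storage config / environment
impl_cls = STORAGE_IMPL_MAP[Storage[storage_impl_type]]
print(impl_cls().__class__.__name__)  # StubGCS
```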

View File

@ -61,7 +61,7 @@ def clean_markdown_block(text):
        str: Cleaned text with Markdown code block syntax removed, and stripped of surrounding whitespace
    """
    # Remove opening ```markdown tag with optional whitespace and newlines
    # Matches: optional whitespace + ```markdown + optional whitespace + optional newline
    text = re.sub(r'^\s*```markdown\s*\n?', '', text)

View File

@ -60,6 +60,8 @@ user_default_llm:
#   access_key: 'access_key'
#   secret_key: 'secret_key'
#   region: 'region'
# gcs:
#   bucket: 'bridgtl-edm-d-bucket-ragflow'
# oss:
#   access_key: 'access_key'
#   secret_key: 'secret_key'

View File

@ -51,7 +51,7 @@ We use vision information to resolve problems as human being.
```bash
python deepdoc/vision/t_ocr.py --inputs=path_to_images_or_pdfs --output_dir=path_to_store_result
```
The inputs could be a directory of images or PDFs, or an individual image or PDF.
You can look into the folder 'path_to_store_result', which contains images demonstrating the positions of the results and
txt files containing the OCR text.
<div align="center" style="margin-top:20px;margin-bottom:20px;">
@ -78,7 +78,7 @@ We use vision information to resolve problems as human being.
```bash
python deepdoc/vision/t_recognizer.py --inputs=path_to_images_or_pdfs --threshold=0.2 --mode=layout --output_dir=path_to_store_result
```
The inputs could be a directory of images or PDFs, or an individual image or PDF.
You can look into the folder 'path_to_store_result', which contains images demonstrating the detection results, as follows:
<div align="center" style="margin-top:20px;margin-bottom:20px;">
<img src="https://github.com/infiniflow/ragflow/assets/12318111/07e0f625-9b28-43d0-9fbb-5bf586cd286f" width="1000"/>

View File

@ -41,7 +41,7 @@ class RAGFlowExcelParser:
        try:
            file_like_object.seek(0)
            df = pd.read_csv(file_like_object, on_bad_lines='skip')
            return RAGFlowExcelParser._dataframe_to_workbook(df)
        except Exception as e_csv:
@ -164,7 +164,7 @@ class RAGFlowExcelParser:
        except Exception as e:
            logging.warning(f"Parse spreadsheet error: {e}, trying to interpret as CSV file")
            file_like_object.seek(0)
            df = pd.read_csv(file_like_object, on_bad_lines='skip')
            df = df.replace(r"^\s*$", "", regex=True)
            return df.to_markdown(index=False)
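`on_bad_lines='skip'` (available in pandas 1.3 and later) drops rows whose field count does not match the header instead of raising a parser error, which is why the fallback CSV path above no longer dies on ragged files. A quick demonstration:

```python
import io
import pandas as pd

csv = "a,b\n1,2\n3,4,5\n6,7\n"  # second data row has an extra field
df = pd.read_csv(io.StringIO(csv), on_bad_lines="skip")
print(df)
#    a  b
# 0  1  2
# 1  6  7
```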

View File

@ -25,6 +25,8 @@ from rag.prompts.generator import vision_llm_figure_describe_prompt
def vision_figure_parser_figure_data_wrapper(figures_data_without_positions):
    if not figures_data_without_positions:
        return []
    return [
        (
            (figure_data[1], [figure_data[0]]),
@ -36,6 +38,8 @@ def vision_figure_parser_figure_data_wrapper(figures_data_without_positions):
def vision_figure_parser_docx_wrapper(sections, tbls, callback=None, **kwargs):
    if not tbls:
        return []
    try:
        vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
        callback(0.7, "Visual model detected. Attempting to enhance figure extraction...")
@ -53,6 +57,8 @@ def vision_figure_parser_docx_wrapper(sections, tbls, callback=None, **kwargs):
def vision_figure_parser_pdf_wrapper(tbls, callback=None, **kwargs):
    if not tbls:
        return []
    try:
        vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
        callback(0.7, "Visual model detected. Attempting to enhance figure extraction...")

View File

@ -151,7 +151,7 @@ class RAGFlowHtmlParser:
        block_content = []
        current_content = ""
        table_info_list = []
        last_block_id = None
        for item in parser_result:
            content = item.get("content")
            tag_name = item.get("tag_name")
@ -160,11 +160,11 @@ class RAGFlowHtmlParser:
            if block_id:
                if title_flag:
                    content = f"{TITLE_TAGS[tag_name]} {content}"
                if last_block_id != block_id:
                    if last_block_id is not None:
                        block_content.append(current_content)
                    current_content = content
                    last_block_id = block_id
                else:
                    current_content += (" " if current_content else "") + content
            else:

View File

@ -63,6 +63,7 @@ class MinerUParser(RAGFlowPdfParser):
        self.logger = logging.getLogger(self.__class__.__name__)

    def _extract_zip_no_root(self, zip_path, extract_to, root_dir):
        self.logger.info(f"[MinerU] Extract zip: zip_path={zip_path}, extract_to={extract_to}, root_hint={root_dir}")
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            if not root_dir:
                files = zip_ref.namelist()
@ -72,7 +73,7 @@ class MinerUParser(RAGFlowPdfParser):
                    root_dir = None

            if not root_dir or not root_dir.endswith("/"):
                self.logger.info(f"[MinerU] No root directory found, extracting all (root_hint={root_dir})")
                zip_ref.extractall(extract_to)
                return
@ -108,7 +109,7 @@ class MinerUParser(RAGFlowPdfParser):
        valid_backends = ["pipeline", "vlm-http-client", "vlm-transformers", "vlm-vllm-engine"]
        if backend not in valid_backends:
            reason = f"[MinerU] Invalid backend '{backend}'. Valid backends are: {valid_backends}"
            self.logger.warning(reason)
            return False, reason

        subprocess_kwargs = {
@ -128,40 +129,40 @@ class MinerUParser(RAGFlowPdfParser):
        if backend == "vlm-http-client" and server_url:
            try:
                server_accessible = self._is_http_endpoint_valid(server_url + "/openapi.json")
                self.logger.info(f"[MinerU] vlm-http-client server check: {server_accessible}")
                if server_accessible:
                    self.using_api = False  # We are using http client, not API
                    return True, reason
                else:
                    reason = f"[MinerU] vlm-http-client server not accessible: {server_url}"
                    self.logger.warning(f"[MinerU] vlm-http-client server not accessible: {server_url}")
                    return False, reason
            except Exception as e:
                self.logger.warning(f"[MinerU] vlm-http-client server check failed: {e}")
                try:
                    response = requests.get(server_url, timeout=5)
                    self.logger.info(f"[MinerU] vlm-http-client server connection check: success with status {response.status_code}")
                    self.using_api = False
                    return True, reason
                except Exception as e:
                    reason = f"[MinerU] vlm-http-client server connection check failed: {server_url}: {e}"
                    self.logger.warning(f"[MinerU] vlm-http-client server connection check failed: {server_url}: {e}")
                    return False, reason

        try:
            result = subprocess.run([str(self.mineru_path), "--version"], **subprocess_kwargs)
            version_info = result.stdout.strip()
            if version_info:
                self.logger.info(f"[MinerU] Detected version: {version_info}")
            else:
                self.logger.info("[MinerU] Detected MinerU, but version info is empty.")
            return True, reason
        except subprocess.CalledProcessError as e:
            self.logger.warning(f"[MinerU] Execution failed (exit code {e.returncode}).")
        except FileNotFoundError:
            self.logger.warning("[MinerU] MinerU not found. Please install it via: pip install -U 'mineru[core]'")
        except Exception as e:
            self.logger.error(f"[MinerU] Unexpected error during installation check: {e}")

        # If executable check fails, try API check
        try:
@ -171,14 +172,14 @@ class MinerUParser(RAGFlowPdfParser):
            if not openapi_exists:
                reason = "[MinerU] Failed to detect valid MinerU API server"
                return openapi_exists, reason
            self.logger.info(f"[MinerU] Detected {self.mineru_api}/openapi.json: {openapi_exists}")
            self.using_api = openapi_exists
            return openapi_exists, reason
        else:
            self.logger.info("[MinerU] API not available.")
        except Exception as e:
            reason = f"[MinerU] Unexpected error during api check: {e}"
            self.logger.error(f"[MinerU] Unexpected error during api check: {e}")
        return False, reason
    def _run_mineru(
@ -314,7 +315,7 @@ class MinerUParser(RAGFlowPdfParser):
        except Exception as e:
            self.page_images = None
            self.total_page = 0
            self.logger.exception(e)

    def _line_tag(self, bx):
        pn = [bx["page_idx"] + 1]
@ -480,15 +481,49 @@ class MinerUParser(RAGFlowPdfParser):
        json_file = None
        subdir = None
        attempted = []

        # mirror MinerU's sanitize_filename to align ZIP naming
        def _sanitize_filename(name: str) -> str:
            sanitized = re.sub(r"[/\\\.]{2,}|[/\\]", "", name)
            sanitized = re.sub(r"[^\w.-]", "_", sanitized, flags=re.UNICODE)
            if sanitized.startswith("."):
                sanitized = "_" + sanitized[1:]
            return sanitized or "unnamed"

        safe_stem = _sanitize_filename(file_stem)
        allowed_names = {f"{file_stem}_content_list.json", f"{safe_stem}_content_list.json"}
        self.logger.info(f"[MinerU] Expected output files: {', '.join(sorted(allowed_names))}")
        self.logger.info(f"[MinerU] Searching output candidates: {', '.join(str(c) for c in candidates)}")

        for sub in candidates:
            jf = sub / f"{file_stem}_content_list.json"
            self.logger.info(f"[MinerU] Trying original path: {jf}")
            attempted.append(jf)
            if jf.exists():
                subdir = sub
                json_file = jf
                break

            # MinerU API sanitizes non-ASCII filenames inside the ZIP root and file names.
            alt = sub / f"{safe_stem}_content_list.json"
            self.logger.info(f"[MinerU] Trying sanitized filename: {alt}")
            attempted.append(alt)
            if alt.exists():
                subdir = sub
                json_file = alt
                break

            nested_alt = sub / safe_stem / f"{safe_stem}_content_list.json"
            self.logger.info(f"[MinerU] Trying sanitized nested path: {nested_alt}")
            attempted.append(nested_alt)
            if nested_alt.exists():
                subdir = nested_alt.parent
                json_file = nested_alt
                break

        if not json_file:
            raise FileNotFoundError(f"[MinerU] Missing output file, tried: {', '.join(str(p) for p in attempted)}")

        with open(json_file, "r", encoding="utf-8") as f:
            data = json.load(f)
View File

@ -582,7 +582,7 @@ class OCR:
        self.crop_image_res_index = 0

    def get_rotate_crop_image(self, img, points):
        """
        img_height, img_width = img.shape[0:2]
        left = int(np.min(points[:, 0]))
        right = int(np.max(points[:, 0]))
@ -591,7 +591,7 @@ class OCR:
        img_crop = img[top:bottom, left:right, :].copy()
        points[:, 0] = points[:, 0] - left
        points[:, 1] = points[:, 1] - top
        """
        assert len(points) == 4, "shape of points must be 4*2"
        img_crop_width = int(
            max(

View File

@ -67,10 +67,10 @@ class DBPostProcess:
            [[1, 1], [1, 1]])

    def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
        """
        _bitmap: single map with shape (1, H, W),
        whose values are binarized as {0, 1}
        """

        bitmap = _bitmap
        height, width = bitmap.shape
@ -114,10 +114,10 @@ class DBPostProcess:
        return boxes, scores

    def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
        """
        _bitmap: single map with shape (1, H, W),
        whose values are binarized as {0, 1}
        """

        bitmap = _bitmap
        height, width = bitmap.shape
@ -192,9 +192,9 @@ class DBPostProcess:
        return box, min(bounding_box[1])

    def box_score_fast(self, bitmap, _box):
        """
        box_score_fast: use bbox mean score as the mean score
        """
        h, w = bitmap.shape[:2]
        box = _box.copy()
        xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1)
@ -209,9 +209,9 @@ class DBPostProcess:
        return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]

    def box_score_slow(self, bitmap, contour):
        """
        box_score_slow: use polygon mean score as the mean score
        """
        h, w = bitmap.shape[:2]
        contour = contour.copy()
        contour = np.reshape(contour, (-1, 2))

View File

@ -155,7 +155,7 @@ class TableStructureRecognizer(Recognizer):
        while i < len(boxes):
            if TableStructureRecognizer.is_caption(boxes[i]):
                if is_english:
                    cap += " "
                cap += boxes[i]["text"]
                boxes.pop(i)
                i -= 1

View File

@ -170,7 +170,7 @@ TZ=Asia/Shanghai
# Uncomment the following line if your operating system is MacOS:
# MACOS=1

# The maximum file size limit (in bytes) for each upload to your dataset or RAGFlow's File system.
# To change the 1GB file size limit, uncomment the line below and update as needed.
# MAX_CONTENT_LENGTH=1073741824
# After updating, ensure `client_max_body_size` in nginx/nginx.conf is updated accordingly.

View File

@ -23,7 +23,7 @@ services:
    env_file: .env
    networks:
      - ragflow
    restart: unless-stopped
    # https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
    # If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
    extra_hosts:
@ -48,7 +48,7 @@ services:
    env_file: .env
    networks:
      - ragflow
    restart: unless-stopped
    # https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
    # If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
    extra_hosts:

View File

@ -31,7 +31,7 @@ services:
      retries: 120
    networks:
      - ragflow
    restart: unless-stopped

  opensearch01:
    profiles:
@ -67,12 +67,12 @@ services:
      retries: 120
    networks:
      - ragflow
    restart: unless-stopped

  infinity:
    profiles:
      - infinity
    image: infiniflow/infinity:v0.6.10
    volumes:
      - infinity_data:/var/infinity
      - ./infinity_conf.toml:/infinity_conf.toml
@ -94,7 +94,7 @@ services:
      interval: 10s
      timeout: 10s
      retries: 120
    restart: unless-stopped

  oceanbase:
    profiles:
@ -119,7 +119,7 @@ services:
      timeout: 10s
    networks:
      - ragflow
    restart: unless-stopped

  sandbox-executor-manager:
    profiles:
@ -147,7 +147,7 @@ services:
      interval: 10s
      timeout: 10s
      retries: 120
    restart: unless-stopped

  mysql:
    # mysql:5.7 linux/arm64 image is unavailable.
@ -175,7 +175,7 @@ services:
      interval: 10s
      timeout: 10s
      retries: 120
    restart: unless-stopped

  minio:
    image: quay.io/minio/minio:RELEASE.2025-06-13T11-33-47Z
@ -191,7 +191,7 @@ services:
      - minio_data:/data
    networks:
      - ragflow
    restart: unless-stopped
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
      interval: 10s
@ -209,7 +209,7 @@ services:
      - redis_data:/data
    networks:
      - ragflow
    restart: unless-stopped
    healthcheck:
      test: ["CMD", "redis-cli", "-a", "${REDIS_PASSWORD}", "ping"]
      interval: 10s
@ -228,7 +228,7 @@ services:
    networks:
      - ragflow
    command: ["--model-id", "/data/${TEI_MODEL}", "--auto-truncate"]
    restart: unless-stopped

  tei-gpu:
@ -249,7 +249,7 @@ services:
            - driver: nvidia
              count: all
              capabilities: [gpu]
    restart: unless-stopped

  kibana:
@ -271,7 +271,7 @@ services:
      retries: 120
    networks:
      - ragflow
    restart: unless-stopped

volumes:

View File

@ -22,7 +22,7 @@ services:
    env_file: .env
    networks:
      - ragflow
    restart: unless-stopped
    # https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
    # If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
    extra_hosts:
@ -39,7 +39,7 @@ services:
    # entrypoint: "/ragflow/entrypoint_task_executor.sh 1 3"
    # networks:
    #   - ragflow
    # restart: unless-stopped
    # # https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
    # # If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
    # extra_hosts:

View File

@ -25,9 +25,9 @@ services:
      # - --no-transport-streamable-http-enabled  # Disable Streamable HTTP transport (/mcp endpoint)
      # - --no-json-response                      # Disable JSON response mode in Streamable HTTP transport (instead of SSE over HTTP)
      # Example configuration to start Admin server:
    command:
      - --enable-adminserver
    ports:
      - ${SVR_WEB_HTTP_PORT}:80
      - ${SVR_WEB_HTTPS_PORT}:443
@ -45,7 +45,7 @@ services:
    env_file: .env
    networks:
      - ragflow
    restart: unless-stopped
    # https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
    # If you use Docker Desktop, the --add-host flag is optional. This flag ensures that the host's internal IP is exposed to the Prometheus container.
    extra_hosts:
@ -74,9 +74,9 @@ services:
      # - --no-transport-streamable-http-enabled  # Disable Streamable HTTP transport (/mcp endpoint)
      # - --no-json-response                      # Disable JSON response mode in Streamable HTTP transport (instead of SSE over HTTP)
      # Example configuration to start Admin server:
    command:
      - --enable-adminserver
    ports:
      - ${SVR_WEB_HTTP_PORT}:80
      - ${SVR_WEB_HTTPS_PORT}:443
@ -94,7 +94,7 @@ services:
    env_file: .env
    networks:
      - ragflow
    restart: unless-stopped
    # https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
    # If you use Docker Desktop, the --add-host flag is optional. This flag ensures that the host's internal IP is exposed to the Prometheus container.
    extra_hosts:
@ -120,7 +120,7 @@ services:
    # entrypoint: "/ragflow/entrypoint_task_executor.sh 1 3"
    # networks:
    #   - ragflow
    # restart: unless-stopped
    # # https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
    # # If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
    # extra_hosts:

View File

@ -1,5 +1,5 @@
[general] [general]
version = "0.6.8" version = "0.6.10"
time_zone = "utc-8" time_zone = "utc-8"
[network] [network]

View File

@ -151,7 +151,7 @@ See [Build a RAGFlow Docker image](./develop/build_docker_image.mdx).
### Cannot access https://huggingface.co ### Cannot access https://huggingface.co
A locally deployed RAGflow downloads OCR models from [Huggingface website](https://huggingface.co) by default. If your machine is unable to access this site, the following error occurs and PDF parsing fails: A locally deployed RAGFlow downloads OCR models from [Huggingface website](https://huggingface.co) by default. If your machine is unable to access this site, the following error occurs and PDF parsing fails:
``` ```
FileNotFoundError: [Errno 2] No such file or directory: '/root/.cache/huggingface/hub/models--InfiniFlow--deepdoc/snapshots/be0c1e50eef6047b412d1800aa89aba4d275f997/ocr.res' FileNotFoundError: [Errno 2] No such file or directory: '/root/.cache/huggingface/hub/models--InfiniFlow--deepdoc/snapshots/be0c1e50eef6047b412d1800aa89aba4d275f997/ocr.res'

View File

@ -76,5 +76,5 @@ No. Files uploaded to an agent as input are not stored in a dataset and hence wi
There is no _specific_ file size limit for a file uploaded to an agent. However, note that model providers typically have a default or explicit maximum token setting, which can range from 8192 to 128k: The plain text part of the uploaded file will be passed in as the key value, but if the file's token count exceeds this limit, the string will be truncated and incomplete. There is no _specific_ file size limit for a file uploaded to an agent. However, note that model providers typically have a default or explicit maximum token setting, which can range from 8192 to 128k: The plain text part of the uploaded file will be passed in as the key value, but if the file's token count exceeds this limit, the string will be truncated and incomplete.
:::tip NOTE :::tip NOTE
The variables `MAX_CONTENT_LENGTH` in `/docker/.env` and `client_max_body_size` in `/docker/nginx/nginx.conf` set the file size limit for each upload to a dataset or **File Management**. These settings DO NOT apply in this scenario. The variables `MAX_CONTENT_LENGTH` in `/docker/.env` and `client_max_body_size` in `/docker/nginx/nginx.conf` set the file size limit for each upload to a dataset or RAGFlow's File system. These settings DO NOT apply in this scenario.
::: :::
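For reference, a sketch of where those two limits live; the values below are illustrative, not recommendations:

```
# docker/.env -- upload size limit in bytes (here: 128 MB)
MAX_CONTENT_LENGTH=134217728

# docker/nginx/nginx.conf -- keep in sync with the value above
client_max_body_size 128M;
```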

View File

@ -45,13 +45,13 @@ Click the light bulb icon above the *current* dialogue and scroll down the popup
| Item name | Description | | Item name | Description |
| ----------------- | --------------------------------------------------------------------------------------------- | | ----------------- |-----------------------------------------------------------------------------------------------|
| Total | Total time spent on this conversation round, including chunk retrieval and answer generation. | | Total | Total time spent on this conversation round, including chunk retrieval and answer generation. |
| Check LLM | Time to validate the specified LLM. | | Check LLM | Time to validate the specified LLM. |
| Create retriever | Time to create a chunk retriever. | | Create retriever | Time to create a chunk retriever. |
| Bind embedding | Time to initialize an embedding model instance. | | Bind embedding | Time to initialize an embedding model instance. |
| Bind LLM | Time to initialize an LLM instance. | | Bind LLM | Time to initialize an LLM instance. |
| Tune question | Time to optimize the user query using the context of the mult-turn conversation. | | Tune question | Time to optimize the user query using the context of the multi-turn conversation. |
| Bind reranker | Time to initialize a reranker model instance for chunk retrieval. | | Bind reranker | Time to initialize a reranker model instance for chunk retrieval. |
| Generate keywords | Time to extract keywords from the user query. | | Generate keywords | Time to extract keywords from the user query. |
| Retrieval | Time to retrieve the chunks. | | Retrieval | Time to retrieve the chunks. |

View File

@ -37,7 +37,7 @@ Please note that rerank models are essential in certain scenarios. There is alwa
| Create retriever | Time to create a chunk retriever. | | Create retriever | Time to create a chunk retriever. |
| Bind embedding | Time to initialize an embedding model instance. | | Bind embedding | Time to initialize an embedding model instance. |
| Bind LLM | Time to initialize an LLM instance. | | Bind LLM | Time to initialize an LLM instance. |
| Tune question | Time to optimize the user query using the context of the mult-turn conversation. | | Tune question | Time to optimize the user query using the context of the multi-turn conversation. |
| Bind reranker | Time to initialize a reranker model instance for chunk retrieval. | | Bind reranker | Time to initialize a reranker model instance for chunk retrieval. |
| Generate keywords | Time to extract keywords from the user query. | | Generate keywords | Time to extract keywords from the user query. |
| Retrieval | Time to retrieve the chunks. | | Retrieval | Time to retrieve the chunks. |

View File

@ -9,7 +9,7 @@ Initiate an AI-powered chat with a configured chat assistant.
--- ---
Knowledge base, hallucination-free chat, and file management are the three pillars of RAGFlow. Chats in RAGFlow are based on a particular dataset or multiple datasets. Once you have created your dataset, finished file parsing, and [run a retrieval test](../dataset/run_retrieval_test.md), you can go ahead and start an AI conversation. Chats in RAGFlow are based on a particular dataset or multiple datasets. Once you have created your dataset, finished file parsing, and [run a retrieval test](../dataset/run_retrieval_test.md), you can go ahead and start an AI conversation.
## Start an AI chat ## Start an AI chat

View File

@ -5,7 +5,7 @@ slug: /configure_knowledge_base
# Configure dataset # Configure dataset
Most of RAGFlow's chat assistants and Agents are based on datasets. Each of RAGFlow's datasets serves as a knowledge source, *parsing* files uploaded from your local machine and file references generated in **File Management** into the real 'knowledge' for future AI chats. This guide demonstrates some basic usages of the dataset feature, covering the following topics: Most of RAGFlow's chat assistants and Agents are based on datasets. Each of RAGFlow's datasets serves as a knowledge source, *parsing* files uploaded from your local machine and file references generated in RAGFlow's File system into the real 'knowledge' for future AI chats. This guide demonstrates some basic usages of the dataset feature, covering the following topics:
- Create a dataset - Create a dataset
- Configure a dataset - Configure a dataset
@ -82,10 +82,10 @@ Some embedding models are optimized for specific languages, so performance may b
### Upload file ### Upload file
- RAGFlow's **File Management** allows you to link a file to multiple datasets, in which case each target dataset holds a reference to the file. - RAGFlow's File system allows you to link a file to multiple datasets, in which case each target dataset holds a reference to the file.
- In **Knowledge Base**, you are also given the option of uploading a single file or a folder of files (bulk upload) from your local machine to a dataset, in which case the dataset holds file copies. - In **Knowledge Base**, you are also given the option of uploading a single file or a folder of files (bulk upload) from your local machine to a dataset, in which case the dataset holds file copies.
While uploading files directly to a dataset seems more convenient, we *highly* recommend uploading files to **File Management** and then linking them to the target datasets. This way, you can avoid permanently deleting files uploaded to the dataset. While uploading files directly to a dataset seems more convenient, we *highly* recommend uploading files to RAGFlow's File system and then linking them to the target datasets. This way, you can avoid permanently deleting files uploaded to the dataset.
### Parse file ### Parse file
@ -142,6 +142,6 @@ As of RAGFlow v0.22.1, the search feature is still in a rudimentary form, suppor
You are allowed to delete a dataset. Hover your mouse over the three-dot icon of the intended dataset card and the **Delete** option appears. Once you delete a dataset, the associated folder under the **root/.knowledge** directory is AUTOMATICALLY REMOVED. The consequence is: You are allowed to delete a dataset. Hover your mouse over the three-dot icon of the intended dataset card and the **Delete** option appears. Once you delete a dataset, the associated folder under the **root/.knowledge** directory is AUTOMATICALLY REMOVED. The consequence is:
- The files uploaded directly to the dataset are gone; - The files uploaded directly to the dataset are gone;
- The file references, which you created from within **File Management**, are gone, but the associated files still exist in **File Management**. - The file references, which you created from within RAGFlow's File system, are gone, but the associated files still exist.
![delete dataset](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/delete_datasets.jpg) ![delete dataset](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/delete_datasets.jpg)

View File

@ -8,7 +8,7 @@ slug: /manage_users_and_services
The Admin CLI and Admin Service form a client-server architectural suite for RAGflow system administration. The Admin CLI serves as an interactive command-line interface that receives instructions and displays execution results from the Admin Service in real-time. This duo enables real-time monitoring of system operational status, supporting visibility into RAGflow Server services and dependent components including MySQL, Elasticsearch, Redis, and MinIO. In administrator mode, they provide user management capabilities that allow viewing users and performing critical operations—such as user creation, password updates, activation status changes, and comprehensive user data deletion—even when corresponding web interface functionalities are disabled. The Admin CLI and Admin Service form a client-server architectural suite for RAGFlow system administration. The Admin CLI serves as an interactive command-line interface that receives instructions and displays execution results from the Admin Service in real-time. This duo enables real-time monitoring of system operational status, supporting visibility into RAGFlow Server services and dependent components including MySQL, Elasticsearch, Redis, and MinIO. In administrator mode, they provide user management capabilities that allow viewing users and performing critical operations—such as user creation, password updates, activation status changes, and comprehensive user data deletion—even when corresponding web interface functionalities are disabled.

View File

@ -305,7 +305,7 @@ With the Ollama service running, open a new terminal and run `./ollama pull <mod
</TabItem> </TabItem>
</Tabs> </Tabs>
### 4. Configure RAGflow ### 4. Configure RAGFlow
To enable IPEX-LLM accelerated Ollama in RAGFlow, you must also complete the configurations in RAGFlow. The steps are identical to those outlined in the *Deploy a local model using Ollama* section: To enable IPEX-LLM accelerated Ollama in RAGFlow, you must also complete the configurations in RAGFlow. The steps are identical to those outlined in the *Deploy a local model using Ollama* section:

View File

@ -419,17 +419,11 @@ Creates a dataset.
- `"embedding_model"`: `string` - `"embedding_model"`: `string`
- `"permission"`: `string` - `"permission"`: `string`
- `"chunk_method"`: `string` - `"chunk_method"`: `string`
- "parser_config": `object` - `"parser_config"`: `object`
- "parse_type": `int` - `"parse_type"`: `int`
- "pipeline_id": `string` - `"pipeline_id"`: `string`
Note: Choose exactly one ingestion mode when creating a dataset. ##### A basic request example
- Chunking method: provide `"chunk_method"` (optionally with `"parser_config"`).
- Ingestion pipeline: provide both `"parse_type"` and `"pipeline_id"` and do not provide `"chunk_method"`.
These options are mutually exclusive. If all three of `chunk_method`, `parse_type`, and `pipeline_id` are omitted, the system defaults to `chunk_method = "naive"`.
##### Request example
```bash ```bash
curl --request POST \ curl --request POST \
@ -441,9 +435,11 @@ curl --request POST \
}' }'
``` ```
##### Request example (ingestion pipeline) ##### A request example specifying ingestion pipeline
Use this form when specifying an ingestion pipeline (do not include `chunk_method`). :::caution WARNING
You must *not* include `"chunk_method"` or `"parser_config"` when specifying an ingestion pipeline.
:::
```bash ```bash
curl --request POST \ curl --request POST \
@ -452,15 +448,11 @@ curl --request POST \
--header 'Authorization: Bearer <YOUR_API_KEY>' \ --header 'Authorization: Bearer <YOUR_API_KEY>' \
--data '{ --data '{
"name": "test-sdk", "name": "test-sdk",
"parse_type": <NUMBER_OF_FORMATS_IN_PARSE>, "parse_type": <NUMBER_OF_PARSERS_IN_YOUR_PARSER_COMPONENT>,
"pipeline_id": "<PIPELINE_ID_32_HEX>" "pipeline_id": "<PIPELINE_ID_32_HEX>"
}' }'
``` ```
Notes:
- `parse_type` is an integer. Replace `<NUMBER_OF_FORMATS_IN_PARSE>` with your pipeline's parse-type value.
- `pipeline_id` must be a 32-character lowercase hexadecimal string.
##### Request parameters ##### Request parameters
- `"name"`: (*Body parameter*), `string`, *Required* - `"name"`: (*Body parameter*), `string`, *Required*
@ -488,7 +480,8 @@ Notes:
- `"team"`: All team members can manage the dataset. - `"team"`: All team members can manage the dataset.
- `"chunk_method"`: (*Body parameter*), `enum<string>` - `"chunk_method"`: (*Body parameter*), `enum<string>`
The chunking method of the dataset to create. Available options: The default chunk method of the dataset to create. Mutually exclusive with `"parse_type"` and `"pipeline_id"`. If you set `"chunk_method"`, do not include `"parse_type"` or `"pipeline_id"`.
Available options:
- `"naive"`: General (default) - `"naive"`: General (default)
- `"book"`: Book - `"book"`: Book
- `"email"`: Email - `"email"`: Email
@ -501,7 +494,6 @@ Notes:
- `"qa"`: Q&A - `"qa"`: Q&A
- `"table"`: Table - `"table"`: Table
- `"tag"`: Tag - `"tag"`: Tag
- Mutually exclusive with `parse_type` and `pipeline_id`. If you set `chunk_method`, do not include `parse_type` or `pipeline_id`.
- `"parser_config"`: (*Body parameter*), `object` - `"parser_config"`: (*Body parameter*), `object`
The configuration settings for the dataset parser. The attributes in this JSON object vary with the selected `"chunk_method"`: The configuration settings for the dataset parser. The attributes in this JSON object vary with the selected `"chunk_method"`:
@ -520,13 +512,16 @@ Notes:
- Maximum: `2048` - Maximum: `2048`
- `"delimiter"`: `string` - `"delimiter"`: `string`
- Defaults to `"\n"`. - Defaults to `"\n"`.
- `"html4excel"`: `bool` Indicates whether to convert Excel documents into HTML format. - `"html4excel"`: `bool`
- Whether to convert Excel documents into HTML format.
- Defaults to `false` - Defaults to `false`
- `"layout_recognize"`: `string` - `"layout_recognize"`: `string`
- Defaults to `DeepDOC` - Defaults to `DeepDOC`
- `"tag_kb_ids"`: `array<string>` refer to [Use tag set](https://ragflow.io/docs/dev/use_tag_sets) - `"tag_kb_ids"`: `array<string>`
- Must include a list of dataset IDs, where each dataset is parsed using the Tag Chunking Method - IDs of datasets to be parsed using the Tag chunk method.
- `"task_page_size"`: `int` For PDF only. - Before setting this, ensure a tag set is created and properly configured. For details, see [Use tag set](https://ragflow.io/docs/dev/use_tag_sets).
- `"task_page_size"`: `int`
- For PDFs only.
- Defaults to `12` - Defaults to `12`
- Minimum: `1` - Minimum: `1`
- `"raptor"`: `object` RAPTOR-specific settings. - `"raptor"`: `object` RAPTOR-specific settings.
@ -538,14 +533,25 @@ Notes:
- Defaults to: `{"use_raptor": false}`. - Defaults to: `{"use_raptor": false}`.
- If `"chunk_method"` is `"table"`, `"picture"`, `"one"`, or `"email"`, `"parser_config"` is an empty JSON object. - If `"chunk_method"` is `"table"`, `"picture"`, `"one"`, or `"email"`, `"parser_config"` is an empty JSON object.
- "parse_type": (*Body parameter*), `int` - `"parse_type"`: (*Body parameter*), `int`
The ingestion pipeline parse type identifier. Required if and only if you are using an ingestion pipeline (together with `"pipeline_id"`). Must not be provided when `"chunk_method"` is set. The ingestion pipeline parse type identifier, i.e., the number of parsers in your **Parser** component.
- Required (along with `"pipeline_id"`) if specifying an ingestion pipeline.
- Must not be included when `"chunk_method"` is specified.
- "pipeline_id": (*Body parameter*), `string` - `"pipeline_id"`: (*Body parameter*), `string`
The ingestion pipeline ID. Required if and only if you are using an ingestion pipeline (together with `"parse_type"`). The ingestion pipeline ID. Can be found in the corresponding URL in the RAGFlow UI.
- Must not be provided when `"chunk_method"` is set. - Required (along with `"parse_type"`) if specifying an ingestion pipeline.
- Must be a 32-character lowercase hexadecimal string, e.g., `"d0bebe30ae2211f0970942010a8e0005"`.
- Must not be included when `"chunk_method"` is specified.
Note: If none of `chunk_method`, `parse_type`, and `pipeline_id` are provided, the system will default to `chunk_method = "naive"`. :::caution WARNING
You can choose either of the following ingestion options when creating a dataset, but *not* both:
- Use a built-in chunk method -- specify `"chunk_method"` (optionally with `"parser_config"`).
- Use an ingestion pipeline -- specify both `"parse_type"` and `"pipeline_id"`.
If none of `"chunk_method"`, `"parse_type"`, or `"pipeline_id"` are provided, the system defaults to `chunk_method = "naive"`.
:::
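A hedged Python sketch of the two mutually exclusive request shapes described above. The base URL and the `parse_type` value are placeholders; the `pipeline_id` is the example value from this section:

```python
import requests

API_KEY = "<YOUR_API_KEY>"
BASE_URL = "http://localhost:9380"  # assumption: adjust to your RAGFlow address
HEADERS = {"Authorization": f"Bearer {API_KEY}"}

# Option 1: built-in chunk method, optionally with parser_config.
requests.post(f"{BASE_URL}/api/v1/datasets", headers=HEADERS,
              json={"name": "kb_chunked", "chunk_method": "naive"})

# Option 2: ingestion pipeline -- no chunk_method or parser_config allowed.
requests.post(f"{BASE_URL}/api/v1/datasets", headers=HEADERS,
              json={"name": "kb_pipeline",
                    "parse_type": 2,  # number of parsers in your Parser component
                    "pipeline_id": "d0bebe30ae2211f0970942010a8e0005"})
```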
#### Response #### Response
@ -4007,7 +4013,7 @@ Failure:
**DELETE** `/api/v1/agents/{agent_id}/sessions` **DELETE** `/api/v1/agents/{agent_id}/sessions`
Deletes sessions of a agent by ID. Deletes sessions of an agent by ID.
#### Request #### Request
@ -4066,7 +4072,7 @@ Failure:
Generates five to ten alternative question strings from the user's original query to retrieve more relevant search results. Generates five to ten alternative question strings from the user's original query to retrieve more relevant search results.
This operation requires a `Bearer Login Token`, which typically expires within 24 hours. You can find the it in the Request Headers in your browser easily as shown below: This operation requires a `Bearer Login Token`, which typically expires within 24 hours. You can find it in the Request Headers in your browser easily as shown below:
![Image](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/login_token.jpg) ![Image](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/login_token.jpg)
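A minimal sketch of supplying that header from Python; the endpoint path is an assumption based on this section and may differ in your version:

```python
import requests

LOGIN_TOKEN = "<TOKEN_COPIED_FROM_BROWSER_REQUEST_HEADERS>"
resp = requests.post(
    "http://localhost:9380/v1/sessions/related_questions",  # assumed path
    headers={"Authorization": f"Bearer {LOGIN_TOKEN}"},
    json={"question": "What is RAGFlow?"},
)
print(resp.json())
```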

View File

@ -1740,7 +1740,7 @@ for session in sessions:
Agent.delete_sessions(ids: list[str] = None) Agent.delete_sessions(ids: list[str] = None)
``` ```
Deletes sessions of a agent by ID. Deletes sessions of an agent by ID.
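A hedged usage sketch with the Python SDK; `list_agents` and the placeholder IDs are assumptions based on the SDK's conventions elsewhere in this reference:

```python
from ragflow_sdk import RAGFlow

rag = RAGFlow(api_key="<YOUR_API_KEY>", base_url="http://localhost:9380")
agent = rag.list_agents(id="<AGENT_ID>")[0]  # assumption: fetch the agent first
agent.delete_sessions(ids=["<SESSION_ID_1>", "<SESSION_ID_2>"])
```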
#### Parameters #### Parameters

View File

@ -5,6 +5,7 @@
# requires-python = ">=3.10" # requires-python = ">=3.10"
# dependencies = [ # dependencies = [
# "nltk", # "nltk",
# "huggingface-hub"
# ] # ]
# /// # ///
@ -43,7 +44,6 @@ def get_urls(use_china_mirrors=False) -> list[Union[str, list[str]]]:
repos = [ repos = [
"InfiniFlow/text_concat_xgb_v1.0", "InfiniFlow/text_concat_xgb_v1.0",
"InfiniFlow/deepdoc", "InfiniFlow/deepdoc",
"InfiniFlow/huqie",
] ]
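A minimal sketch of fetching these repos once `huggingface-hub` is importable; the `local_dir` layout here is illustrative, not necessarily the script's actual behavior:

```python
from huggingface_hub import snapshot_download

for repo in ["InfiniFlow/text_concat_xgb_v1.0", "InfiniFlow/deepdoc"]:
    # Downloads the full repo snapshot to a local directory.
    snapshot_download(repo_id=repo, local_dir=f"huggingface.co/{repo}")
```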

View File

@ -14,9 +14,9 @@
# limitations under the License. # limitations under the License.
# #
''' """
The example is about CRUD operations (Create, Read, Update, Delete) on a dataset. The example is about CRUD operations (Create, Read, Update, Delete) on a dataset.
''' """
from ragflow_sdk import RAGFlow from ragflow_sdk import RAGFlow
import sys import sys

View File

@ -57,7 +57,7 @@ async def run_graphrag(
start = trio.current_time() start = trio.current_time()
tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"] tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
chunks = [] chunks = []
for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True): for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], max_count=10000, fields=["content_with_weight", "doc_id"], sort_by_position=True):
chunks.append(d["content_with_weight"]) chunks.append(d["content_with_weight"])
with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000): with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
@ -174,13 +174,19 @@ async def run_graphrag_for_kb(
chunks = [] chunks = []
current_chunk = "" current_chunk = ""
for d in settings.retriever.chunk_list( # DEBUG: Fetch all chunks first
raw_chunks = list(settings.retriever.chunk_list(
doc_id, doc_id,
tenant_id, tenant_id,
[kb_id], [kb_id],
max_count=10000, # FIX: Raise the limit so all chunks are processed
fields=fields_for_chunks, fields=fields_for_chunks,
sort_by_position=True, sort_by_position=True,
): ))
callback(msg=f"[DEBUG] chunk_list() returned {len(raw_chunks)} raw chunks for doc {doc_id}")
for d in raw_chunks:
content = d["content_with_weight"] content = d["content_with_weight"]
if num_tokens_from_string(current_chunk + content) < 1024: if num_tokens_from_string(current_chunk + content) < 1024:
current_chunk += content current_chunk += content
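The loop above packs chunk texts greedily up to roughly 1024 tokens; a standalone sketch of that packing pattern, assuming the import path of `num_tokens_from_string` (it may differ):

```python
from rag.utils import num_tokens_from_string  # assumption: path may differ

def pack_chunks(texts: list[str], limit: int = 1024) -> list[str]:
    """Greedily concatenate texts until the token estimate reaches `limit`."""
    batches, current = [], ""
    for text in texts:
        if num_tokens_from_string(current + text) < limit:
            current += text
        else:
            if current:
                batches.append(current)
            current = text
    if current:
        batches.append(current)
    return batches
```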

View File

@ -96,7 +96,7 @@ ragflow:
infinity: infinity:
image: image:
repository: infiniflow/infinity repository: infiniflow/infinity
tag: v0.6.8 tag: v0.6.10
pullPolicy: IfNotPresent pullPolicy: IfNotPresent
pullSecrets: [] pullSecrets: []
storage: storage:

View File

@ -57,7 +57,6 @@ JSON_RESPONSE = True
class RAGFlowConnector: class RAGFlowConnector:
_MAX_DATASET_CACHE = 32 _MAX_DATASET_CACHE = 32
_MAX_DOCUMENT_CACHE = 128
_CACHE_TTL = 300 _CACHE_TTL = 300
_dataset_metadata_cache: OrderedDict[str, tuple[dict, float | int]] = OrderedDict() # "dataset_id" -> (metadata, expiry_ts) _dataset_metadata_cache: OrderedDict[str, tuple[dict, float | int]] = OrderedDict() # "dataset_id" -> (metadata, expiry_ts)
@ -116,8 +115,6 @@ class RAGFlowConnector:
def _set_cached_document_metadata_by_dataset(self, dataset_id, doc_id_meta_list): def _set_cached_document_metadata_by_dataset(self, dataset_id, doc_id_meta_list):
self._document_metadata_cache[dataset_id] = (doc_id_meta_list, self._get_expiry_timestamp()) self._document_metadata_cache[dataset_id] = (doc_id_meta_list, self._get_expiry_timestamp())
self._document_metadata_cache.move_to_end(dataset_id) self._document_metadata_cache.move_to_end(dataset_id)
if len(self._document_metadata_cache) > self._MAX_DOCUMENT_CACHE:
self._document_metadata_cache.popitem(last=False)
def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None): def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
res = self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name}) res = self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
@ -240,11 +237,14 @@ class RAGFlowConnector:
docs = None if force_refresh else self._get_cached_document_metadata_by_dataset(dataset_id) docs = None if force_refresh else self._get_cached_document_metadata_by_dataset(dataset_id)
if docs is None: if docs is None:
docs_res = self._get(f"/datasets/{dataset_id}/documents") page = 1
docs_data = docs_res.json() page_size = 30
if docs_data.get("code") == 0 and docs_data.get("data", {}).get("docs"):
doc_id_meta_list = [] doc_id_meta_list = []
docs = {} docs = {}
while page:
docs_res = self._get(f"/datasets/{dataset_id}/documents?page={page}")
docs_data = docs_res.json()
if docs_data.get("code") == 0 and docs_data.get("data", {}).get("docs"):
for doc in docs_data["data"]["docs"]: for doc in docs_data["data"]["docs"]:
doc_id = doc.get("id") doc_id = doc.get("id")
if not doc_id: if not doc_id:
@ -256,30 +256,27 @@ class RAGFlowConnector:
"type": doc.get("type", ""), "type": doc.get("type", ""),
"size": doc.get("size"), "size": doc.get("size"),
"chunk_count": doc.get("chunk_count"), "chunk_count": doc.get("chunk_count"),
# "chunk_method": doc.get("chunk_method", ""),
"create_date": doc.get("create_date", ""), "create_date": doc.get("create_date", ""),
"update_date": doc.get("update_date", ""), "update_date": doc.get("update_date", ""),
# "process_begin_at": doc.get("process_begin_at", ""),
# "process_duration": doc.get("process_duration"),
# "progress": doc.get("progress"),
# "progress_msg": doc.get("progress_msg", ""),
# "status": doc.get("status", ""),
# "run": doc.get("run", ""),
"token_count": doc.get("token_count"), "token_count": doc.get("token_count"),
# "source_type": doc.get("source_type", ""),
"thumbnail": doc.get("thumbnail", ""), "thumbnail": doc.get("thumbnail", ""),
"dataset_id": doc.get("dataset_id", dataset_id), "dataset_id": doc.get("dataset_id", dataset_id),
"meta_fields": doc.get("meta_fields", {}), "meta_fields": doc.get("meta_fields", {}),
# "parser_config": doc.get("parser_config", {})
} }
doc_id_meta_list.append((doc_id, doc_meta)) doc_id_meta_list.append((doc_id, doc_meta))
docs[doc_id] = doc_meta docs[doc_id] = doc_meta
page += 1
if docs_data.get("data", {}).get("total", 0) - page * page_size <= 0:
page = None
self._set_cached_document_metadata_by_dataset(dataset_id, doc_id_meta_list) self._set_cached_document_metadata_by_dataset(dataset_id, doc_id_meta_list)
if docs: if docs:
document_cache.update(docs) document_cache.update(docs)
except Exception: except Exception as e:
# Gracefully handle metadata cache failures # Gracefully handle metadata cache failures
logging.error(f"Problem building the document metadata cache: {str(e)}")
pass pass
return document_cache, dataset_cache return document_cache, dataset_cache
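The pagination contract used above, extracted into a minimal standalone sketch. It assumes the same response shape, `{"code": 0, "data": {"docs": [...], "total": N}}`, and the default page size of 30:

```python
def iter_documents(get, dataset_id: str, page_size: int = 30):
    """Yield documents page by page until `total` is exhausted."""
    page = 1
    while True:
        data = get(f"/datasets/{dataset_id}/documents?page={page}").json()
        if data.get("code") != 0:
            break
        yield from data.get("data", {}).get("docs", [])
        if data.get("data", {}).get("total", 0) <= page * page_size:
            break
        page += 1
```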

View File

@ -92,6 +92,6 @@ def get_metadata(cls) -> LLMToolMetadata:
The `get_metadata` method is a `classmethod`. It provides the description of this tool to the LLM. The `get_metadata` method is a `classmethod`. It provides the description of this tool to the LLM.
The fields starts with `display` can use a special notation: `$t:xxx`, which will use the i18n mechanism in the RAGFlow frontend, getting text from the `llmTools` category. The frontend will display what you put here if you don't use this notation. The fields starting with `display` can use a special notation: `$t:xxx`, which uses the i18n mechanism in the RAGFlow frontend, getting text from the `llmTools` category. The frontend displays what you put here verbatim if you don't use this notation.
Now our tool is ready. You can select it in the `Generate` component and try it out. Now our tool is ready. You can select it in the `Generate` component and try it out.

View File

@ -5,7 +5,7 @@ from plugin.llm_tool_plugin import LLMToolMetadata, LLMToolPlugin
class BadCalculatorPlugin(LLMToolPlugin): class BadCalculatorPlugin(LLMToolPlugin):
""" """
A sample LLM tool plugin, will add two numbers with 100. A sample LLM tool plugin, will add two numbers with 100.
It only present for demo purpose. Do not use it in production. It is for demo purposes only. Do not use it in production.
""" """
_version_ = "1.0.0" _version_ = "1.0.0"

View File

@ -49,7 +49,7 @@ dependencies = [
"html-text==0.6.2", "html-text==0.6.2",
"httpx[socks]>=0.28.1,<0.29.0", "httpx[socks]>=0.28.1,<0.29.0",
"huggingface-hub>=0.25.0,<0.26.0", "huggingface-hub>=0.25.0,<0.26.0",
"infinity-sdk==0.6.8", "infinity-sdk==0.6.10",
"infinity-emb>=0.0.66,<0.0.67", "infinity-emb>=0.0.66,<0.0.67",
"itsdangerous==2.1.2", "itsdangerous==2.1.2",
"json-repair==0.35.0", "json-repair==0.35.0",
@ -131,7 +131,6 @@ dependencies = [
"graspologic @ git+https://github.com/yuzhichang/graspologic.git@38e680cab72bc9fb68a7992c3bcc2d53b24e42fd", "graspologic @ git+https://github.com/yuzhichang/graspologic.git@38e680cab72bc9fb68a7992c3bcc2d53b24e42fd",
"mini-racer>=0.12.4,<0.13.0", "mini-racer>=0.12.4,<0.13.0",
"pyodbc>=5.2.0,<6.0.0", "pyodbc>=5.2.0,<6.0.0",
"pyicu>=2.15.3,<3.0.0",
"flasgger>=0.9.7.1,<0.10.0", "flasgger>=0.9.7.1,<0.10.0",
"xxhash>=3.5.0,<4.0.0", "xxhash>=3.5.0,<4.0.0",
"trio>=0.17.0,<0.29.0", "trio>=0.17.0,<0.29.0",
@ -163,6 +162,9 @@ test = [
"openpyxl>=3.1.5", "openpyxl>=3.1.5",
"pillow>=10.4.0", "pillow>=10.4.0",
"pytest>=8.3.5", "pytest>=8.3.5",
"pytest-asyncio>=1.3.0",
"pytest-xdist>=3.8.0",
"pytest-cov>=7.0.0",
"python-docx>=1.1.2", "python-docx>=1.1.2",
"python-pptx>=1.0.2", "python-pptx>=1.0.2",
"reportlab>=4.4.1", "reportlab>=4.4.1",
@ -195,8 +197,83 @@ extend-select = ["ASYNC", "ASYNC1"]
ignore = ["E402"] ignore = ["E402"]
[tool.pytest.ini_options] [tool.pytest.ini_options]
pythonpath = [
"."
]
testpaths = ["test"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
markers = [ markers = [
"p1: high priority test cases", "p1: high priority test cases",
"p2: medium priority test cases", "p2: medium priority test cases",
"p3: low priority test cases", "p3: low priority test cases",
] ]
# Test collection and runtime configuration
filterwarnings = [
"error", # Treat warnings as errors
"ignore::DeprecationWarning", # Ignore specific warnings
]
# Command line options
addopts = [
"-v", # Verbose output
"--strict-markers", # Enforce marker definitions
"--tb=short", # Simplified traceback
"--disable-warnings", # Disable warnings
"--color=yes" # Colored output
]
# Coverage configuration
[tool.coverage.run]
# Source paths - adjust according to your project structure
source = [
# "../../api/db/services",
# Add more directories if needed:
"../../common",
# "../../utils",
]
# Files/directories to exclude
omit = [
"*/tests/*",
"*/test_*",
"*/__pycache__/*",
"*/.pytest_cache/*",
"*/venv/*",
"*/.venv/*",
"*/env/*",
"*/site-packages/*",
"*/dist/*",
"*/build/*",
"*/migrations/*",
"setup.py"
]
[tool.coverage.report]
# Report configuration
precision = 2
show_missing = true
skip_covered = false
fail_under = 0 # Minimum coverage requirement (0-100)
# Lines to exclude (optional)
exclude_lines = [
# "pragma: no cover",
# "def __repr__",
# "raise AssertionError",
# "raise NotImplementedError",
# "if __name__ == .__main__.:",
# "if TYPE_CHECKING:",
"pass"
]
[tool.coverage.html]
# HTML report configuration
directory = "htmlcov"
title = "Test Coverage Report"
# extra_css = "custom.css" # Optional custom CSS
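An invocation sketch tying the marker and coverage settings together; `pytest-cov` comes from the test extras added above, and the marker expression is illustrative:

```python
import pytest

# Run high- and medium-priority cases with an HTML coverage report.
raise SystemExit(pytest.main(["-m", "p1 or p2", "--cov", "--cov-report=html"]))
```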

View File

@ -14,5 +14,5 @@
# limitations under the License. # limitations under the License.
# #
from beartype.claw import beartype_this_package # from beartype.claw import beartype_this_package
beartype_this_package() # beartype_this_package()

View File

@ -143,6 +143,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
elif re.search(r"\.doc$", filename, re.IGNORECASE): elif re.search(r"\.doc$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.") callback(0.1, "Start to parse.")
with BytesIO(binary) as binary:
binary = BytesIO(binary) binary = BytesIO(binary)
doc_parsed = parser.from_buffer(binary) doc_parsed = parser.from_buffer(binary)
sections = doc_parsed['content'].split('\n') sections = doc_parsed['content'].split('\n')

View File

@ -201,12 +201,23 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
elif re.search(r"\.doc$", filename, re.IGNORECASE): elif re.search(r"\.doc$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.") callback(0.1, "Start to parse.")
try:
from tika import parser as tika_parser
except Exception as e:
callback(0.8, f"tika not available: {e}. Unsupported .doc parsing.")
logging.warning(f"tika not available: {e}. Unsupported .doc parsing for {filename}.")
return []
binary = BytesIO(binary) binary = BytesIO(binary)
doc_parsed = parser.from_buffer(binary) doc_parsed = tika_parser.from_buffer(binary)
if doc_parsed.get('content', None) is not None:
sections = doc_parsed['content'].split('\n') sections = doc_parsed['content'].split('\n')
sections = [s for s in sections if s] sections = [s for s in sections if s]
callback(0.8, "Finish parsing.") callback(0.8, "Finish parsing.")
else:
callback(0.8, f"tika.parser got empty content from {filename}.")
logging.warning(f"tika.parser got empty content from {filename}.")
return []
else: else:
raise NotImplementedError( raise NotImplementedError(
"file type not supported yet(doc, docx, pdf, txt supported)") "file type not supported yet(doc, docx, pdf, txt supported)")

View File

@ -219,23 +219,27 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
) )
def _normalize_section(section): def _normalize_section(section):
# pad section to length 3: (txt, sec_id, poss) # Pad/normalize to (txt, layout, positions)
if len(section) == 1: if not isinstance(section, (list, tuple)):
section = (section, "", [])
elif len(section) == 1:
section = (section[0], "", []) section = (section[0], "", [])
elif len(section) == 2: elif len(section) == 2:
section = (section[0], "", section[1]) section = (section[0], "", section[1])
elif len(section) != 3: else:
raise ValueError(f"Unexpected section length: {len(section)} (value={section!r})") section = (section[0], section[1], section[2])
txt, layoutno, poss = section txt, layoutno, poss = section
if isinstance(poss, str): if isinstance(poss, str):
poss = pdf_parser.extract_positions(poss) poss = pdf_parser.extract_positions(poss)
if poss:
first = poss[0] # tuple: ([pn], x1, x2, y1, y2) first = poss[0] # tuple: ([pn], x1, x2, y1, y2)
pn = first[0] pn = first[0]
if isinstance(pn, list) and pn:
if isinstance(pn, list):
pn = pn[0] # [pn] -> pn pn = pn[0] # [pn] -> pn
poss[0] = (pn, *first[1:]) poss[0] = (pn, *first[1:])
if not poss:
poss = []
return (txt, layoutno, poss) return (txt, layoutno, poss)
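For orientation, the normalization contract implemented above, expressed as a few hedged checks. The values are hypothetical, and positions are assumed to arrive as tuples rather than encoded strings:

```python
# Each shape is padded/normalized to (txt, layout, positions); the leading
# [pn] list in the first position tuple is unwrapped to a bare page number.
assert _normalize_section("intro") == ("intro", "", [])
assert _normalize_section(("intro",)) == ("intro", "", [])
assert _normalize_section(("intro", [([0], 1, 2, 3, 4)])) == ("intro", "", [(0, 1, 2, 3, 4)])
assert _normalize_section(("intro", "text@layout", [])) == ("intro", "text@layout", [])
```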

View File

@ -86,9 +86,11 @@ class Pdf(PdfParser):
# (A) Add text # (A) Add text
for b in self.boxes: for b in self.boxes:
if not (from_page < b["page_number"] <= to_page + from_page): # b["page_number"] is relative page numbermust + from_page
global_page_num = b["page_number"] + from_page
if not (from_page < global_page_num <= to_page + from_page):
continue continue
page_items[b["page_number"]].append({ page_items[global_page_num].append({
"top": b["top"], "top": b["top"],
"x0": b["x0"], "x0": b["x0"],
"text": b["text"], "text": b["text"],
@ -100,7 +102,6 @@ class Pdf(PdfParser):
if not positions: if not positions:
continue continue
# Handle content type (list vs str)
if isinstance(content, list): if isinstance(content, list):
final_text = "\n".join(content) final_text = "\n".join(content)
elif isinstance(content, str): elif isinstance(content, str):
@ -109,10 +110,11 @@ class Pdf(PdfParser):
final_text = str(content) final_text = str(content)
try: try:
# Parse positions
pn_index = positions[0][0] pn_index = positions[0][0]
if isinstance(pn_index, list): if isinstance(pn_index, list):
pn_index = pn_index[0] pn_index = pn_index[0]
# pn_index in tbls is absolute page number
current_page_num = int(pn_index) + 1 current_page_num = int(pn_index) + 1
except Exception as e: except Exception as e:
print(f"Error parsing position: {e}") print(f"Error parsing position: {e}")

View File

@ -313,7 +313,7 @@ def mdQuestionLevel(s):
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs): def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
""" """
Excel and csv(txt) format files are supported. Excel and csv(txt) format files are supported.
If the file is in excel format, there should be 2 column question and answer without header. If the file is in Excel format, there should be two columns, question and answer, with no header row.
And the question column precedes the answer column. And the question column precedes the answer column.
Multiple sheets are fine as long as the columns are correctly composed. Multiple sheets are fine as long as the columns are correctly composed.
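A hedged sketch of producing a conforming Q&A workbook with `openpyxl`; the file name and contents are placeholders:

```python
from openpyxl import Workbook

wb = Workbook()
ws = wb.active
# Two columns, question first, no header row.
ws.append(["What is RAGFlow?", "An open-source RAG engine."])
ws.append(["Which file formats does this chunker accept?", "Excel and csv(txt)."])
wb.save("qa_pairs.xlsx")
```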

View File

@ -37,7 +37,7 @@ def beAdoc(d, q, a, eng, row_num=-1):
def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs): def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
""" """
Excel and csv(txt) format files are supported. Excel and csv(txt) format files are supported.
If the file is in excel format, there should be 2 column content and tags without header. If the file is in Excel format, there should be two columns, content and tags, with no header row.
And the content column precedes the tags column. And the content column precedes the tags column.
Multiple sheets are fine as long as the columns are correctly composed. Multiple sheets are fine as long as the columns are correctly composed.

View File

@ -12,10 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json
import logging
import random import random
from copy import deepcopy from copy import deepcopy
import xxhash
from agent.component.llm import LLMParam, LLM from agent.component.llm import LLMParam, LLM
from rag.flow.base import ProcessBase, ProcessParamBase from rag.flow.base import ProcessBase, ProcessParamBase
from rag.prompts.generator import run_toc_from_text
class ExtractorParam(ProcessParamBase, LLMParam): class ExtractorParam(ProcessParamBase, LLMParam):
@ -31,6 +37,39 @@ class ExtractorParam(ProcessParamBase, LLMParam):
class Extractor(ProcessBase, LLM): class Extractor(ProcessBase, LLM):
component_name = "Extractor" component_name = "Extractor"
async def _build_TOC(self, docs):
self.callback(0.2, message="Start to generate table of contents ...")
docs = sorted(docs, key=lambda d:(
d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
))
toc = await run_toc_from_text([d["text"] for d in docs], self.chat_mdl)
logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' '))
ii = 0
while ii < len(toc):
try:
idx = int(toc[ii]["chunk_id"])
del toc[ii]["chunk_id"]
toc[ii]["ids"] = [docs[idx]["id"]]
if ii == len(toc) -1:
break
for jj in range(idx+1, int(toc[ii+1]["chunk_id"])+1):
toc[ii]["ids"].append(docs[jj]["id"])
except Exception as e:
logging.exception(e)
ii += 1
if toc:
d = deepcopy(docs[-1])
d["doc_id"] = self._canvas._doc_id
d["content_with_weight"] = json.dumps(toc, ensure_ascii=False)
d["toc_kwd"] = "toc"
d["available_int"] = 0
d["page_num_int"] = [100000000]
d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
return d
return None
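For orientation, the assumed shape of the `run_toc_from_text` output, inferred from the loop above (titles are hypothetical):

```python
# Each entry arrives with a "chunk_id" index into `docs`; the loop rewrites it
# into the list of covered chunk "ids".
toc = [
    {"title": "1 Introduction", "chunk_id": "0"},
    {"title": "2 Methods", "chunk_id": "3"},
]
# After the rewrite: entry 0 covers docs[0..3] -> "ids": [id0, id1, id2, id3];
# the final entry keeps only its own chunk -> "ids": [id3].
```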
async def _invoke(self, **kwargs): async def _invoke(self, **kwargs):
self.set_output("output_format", "chunks") self.set_output("output_format", "chunks")
self.callback(random.randint(1, 5) / 100.0, "Start to generate.") self.callback(random.randint(1, 5) / 100.0, "Start to generate.")
@ -45,6 +84,15 @@ class Extractor(ProcessBase, LLM):
chunks_key = k chunks_key = k
if chunks: if chunks:
if self._param.field_name == "toc":
for ck in chunks:
ck["doc_id"] = self._canvas._doc_id
ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest()
toc = await self._build_TOC(chunks)
if toc: # _build_TOC may return None; don't append it
chunks.append(toc)
self.set_output("chunks", chunks)
return
prog = 0 prog = 0
for i, ck in enumerate(chunks): for i, ck in enumerate(chunks):
args[chunks_key] = ck["text"] args[chunks_key] = ck["text"]

View File

@ -125,7 +125,7 @@ class Splitter(ProcessBase):
{ {
"text": RAGFlowPdfParser.remove_tag(c), "text": RAGFlowPdfParser.remove_tag(c),
"image": img, "image": img,
"positions": [[pos[0][-1]+1, *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)] "positions": [[pos[0][-1], *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)]
} }
for c, img in zip(chunks, images) if c.strip() for c, img in zip(chunks, images) if c.strip()
] ]

View File

@ -52,6 +52,8 @@ class SupportedLiteLLMProvider(StrEnum):
JiekouAI = "Jiekou.AI" JiekouAI = "Jiekou.AI"
ZHIPU_AI = "ZHIPU-AI" ZHIPU_AI = "ZHIPU-AI"
MiniMax = "MiniMax" MiniMax = "MiniMax"
DeerAPI = "DeerAPI"
GPUStack = "GPUStack"
FACTORY_DEFAULT_BASE_URL = { FACTORY_DEFAULT_BASE_URL = {
@ -75,6 +77,7 @@ FACTORY_DEFAULT_BASE_URL = {
SupportedLiteLLMProvider.JiekouAI: "https://api.jiekou.ai/openai", SupportedLiteLLMProvider.JiekouAI: "https://api.jiekou.ai/openai",
SupportedLiteLLMProvider.ZHIPU_AI: "https://open.bigmodel.cn/api/paas/v4", SupportedLiteLLMProvider.ZHIPU_AI: "https://open.bigmodel.cn/api/paas/v4",
SupportedLiteLLMProvider.MiniMax: "https://api.minimaxi.com/v1", SupportedLiteLLMProvider.MiniMax: "https://api.minimaxi.com/v1",
SupportedLiteLLMProvider.DeerAPI: "https://api.deerapi.com/v1",
} }
@ -108,6 +111,8 @@ LITELLM_PROVIDER_PREFIX = {
SupportedLiteLLMProvider.JiekouAI: "openai/", SupportedLiteLLMProvider.JiekouAI: "openai/",
SupportedLiteLLMProvider.ZHIPU_AI: "openai/", SupportedLiteLLMProvider.ZHIPU_AI: "openai/",
SupportedLiteLLMProvider.MiniMax: "openai/", SupportedLiteLLMProvider.MiniMax: "openai/",
SupportedLiteLLMProvider.DeerAPI: "openai/",
SupportedLiteLLMProvider.GPUStack: "openai/",
} }
ChatModel = globals().get("ChatModel", {}) ChatModel = globals().get("ChatModel", {})
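A hedged sketch of how a prefix and default base URL above combine in a LiteLLM call; the model name and key are placeholders:

```python
import litellm

resp = litellm.completion(
    model="openai/gpt-4o-mini",             # LITELLM_PROVIDER_PREFIX + model name
    api_base="https://api.deerapi.com/v1",  # FACTORY_DEFAULT_BASE_URL for DeerAPI
    api_key="<YOUR_DEERAPI_KEY>",
    messages=[{"role": "user", "content": "ping"}],
)
print(resp.choices[0].message.content)
```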

View File

@ -19,7 +19,6 @@ import logging
import os import os
import random import random
import re import re
import threading
import time import time
from abc import ABC from abc import ABC
from copy import deepcopy from copy import deepcopy
@ -78,11 +77,9 @@ class Base(ABC):
self.toolcall_sessions = {} self.toolcall_sessions = {}
def _get_delay(self): def _get_delay(self):
"""Calculate retry delay time"""
return self.base_delay * random.uniform(10, 150) return self.base_delay * random.uniform(10, 150)
def _classify_error(self, error): def _classify_error(self, error):
"""Classify error based on error message content"""
error_str = str(error).lower() error_str = str(error).lower()
keywords_mapping = [ keywords_mapping = [
@ -139,89 +136,7 @@ class Base(ABC):
return gen_conf return gen_conf
def _bridge_sync_stream(self, gen): async def _async_chat_streamly(self, history, gen_conf, **kwargs):
"""Run a sync generator in a thread and yield asynchronously."""
loop = asyncio.get_running_loop()
queue: asyncio.Queue = asyncio.Queue()
def worker():
try:
for item in gen:
loop.call_soon_threadsafe(queue.put_nowait, item)
except Exception as exc: # pragma: no cover - defensive
loop.call_soon_threadsafe(queue.put_nowait, exc)
finally:
loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
threading.Thread(target=worker, daemon=True).start()
return queue
def _chat(self, history, gen_conf, **kwargs):
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
if self.model_name.lower().find("qwq") >= 0:
logging.info(f"[INFO] {self.model_name} detected as reasoning model, using _chat_streamly")
final_ans = ""
tol_token = 0
for delta, tol in self._chat_streamly(history, gen_conf, with_reasoning=False, **kwargs):
if delta.startswith("<think>") or delta.endswith("</think>"):
continue
final_ans += delta
tol_token = tol
if len(final_ans.strip()) == 0:
final_ans = "**ERROR**: Empty response from reasoning model"
return final_ans.strip(), tol_token
if self.model_name.lower().find("qwen3") >= 0:
kwargs["extra_body"] = {"enable_thinking": False}
response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
if not response.choices or not response.choices[0].message or not response.choices[0].message.content:
return "", 0
ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length":
ans = self._length_stop(ans)
return ans, total_token_count_from_response(response)
def _chat_streamly(self, history, gen_conf, **kwargs):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
reasoning_start = False
if kwargs.get("stop") or "stop" in gen_conf:
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf, stop=kwargs.get("stop"))
else:
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
for resp in response:
if not resp.choices:
continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
if kwargs.get("with_reasoning", True) and hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += resp.choices[0].delta.reasoning_content + "</think>"
else:
reasoning_start = False
ans = resp.choices[0].delta.content
tol = total_token_count_from_response(resp)
if not tol:
tol = num_tokens_from_string(resp.choices[0].delta.content)
if resp.choices[0].finish_reason == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
yield ans, tol
async def _async_chat_stream(self, history, gen_conf, **kwargs):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4)) logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
reasoning_start = False reasoning_start = False
@ -265,13 +180,19 @@ class Base(ABC):
gen_conf = self._clean_conf(gen_conf) gen_conf = self._clean_conf(gen_conf)
ans = "" ans = ""
total_tokens = 0 total_tokens = 0
for attempt in range(self.max_retries + 1):
try: try:
async for delta_ans, tol in self._async_chat_stream(history, gen_conf, **kwargs): async for delta_ans, tol in self._async_chat_streamly(history, gen_conf, **kwargs):
ans = delta_ans ans = delta_ans
total_tokens += tol total_tokens += tol
yield delta_ans yield ans
except openai.APIError as e: except Exception as e:
yield ans + "\n**ERROR**: " + str(e) e = await self._exceptions_async(e, attempt)
if e:
yield e
yield total_tokens
return
yield total_tokens yield total_tokens
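A consumption sketch under the contract visible above, where string deltas are yielded first and the final yield is the integer total token count. The public wrapper is assumed to be `async_chat_streamly`, as the log message elsewhere in this diff suggests:

```python
async def consume(model, system: str, history: list) -> int:
    total_tokens = 0
    async for piece in model.async_chat_streamly(system, history, {}):
        if isinstance(piece, int):       # final yield: total token count
            total_tokens = piece
        elif isinstance(piece, str):     # streamed answer (or error) text
            print(piece, end="", flush=True)
    return total_tokens
```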
@ -307,7 +228,7 @@ class Base(ABC):
logging.error(f"sync base giving up: {msg}") logging.error(f"sync base giving up: {msg}")
return msg return msg
async def _exceptions_async(self, e, attempt) -> str | None: async def _exceptions_async(self, e, attempt):
logging.exception("OpenAI async completion") logging.exception("OpenAI async completion")
error_code = self._classify_error(e) error_code = self._classify_error(e)
if attempt == self.max_retries: if attempt == self.max_retries:
@ -357,61 +278,6 @@ class Base(ABC):
self.toolcall_session = toolcall_session self.toolcall_session = toolcall_session
self.tools = tools self.tools = tools
def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
ans = ""
tk_count = 0
hist = deepcopy(history)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
history = hist
try:
for _ in range(self.max_rounds + 1):
logging.info(f"{self.tools=}")
response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf)
tk_count += total_token_count_from_response(response)
if any([not response.choices, not response.choices[0].message]):
raise Exception(f"500 response structure error. Response: {response}")
if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content:
ans += "<think>" + response.choices[0].message.reasoning_content + "</think>"
ans += response.choices[0].message.content
if response.choices[0].finish_reason == "length":
ans = self._length_stop(ans)
return ans, tk_count
for tool_call in response.choices[0].message.tool_calls:
logging.info(f"Response {tool_call=}")
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
tool_response = self.toolcall_session.tool_call(name, args)
history = self._append_history(history, tool_call, tool_response)
ans += self._verbose_tool_use(name, args, tool_response)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
ans += self._verbose_tool_use(name, {}, str(e))
logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
response, token_count = self._chat(history, gen_conf)
ans += response
tk_count += token_count
return ans, tk_count
except Exception as e:
e = self._exceptions(e, attempt)
if e:
return e, tk_count
assert False, "Shouldn't be here."
async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}): async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf) gen_conf = self._clean_conf(gen_conf)
if system and history and history[0].get("role") != "system": if system and history and history[0].get("role") != "system":
@ -466,140 +332,6 @@ class Base(ABC):
assert False, "Shouldn't be here." assert False, "Shouldn't be here."
def chat(self, system, history, gen_conf={}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
try:
return self._chat(history, gen_conf, **kwargs)
except Exception as e:
e = self._exceptions(e, attempt)
if e:
return e, 0
assert False, "Shouldn't be here."
def _wrap_toolcall_message(self, stream):
final_tool_calls = {}
for chunk in stream:
for tool_call in chunk.choices[0].delta.tool_calls or []:
index = tool_call.index
if index not in final_tool_calls:
final_tool_calls[index] = tool_call
final_tool_calls[index].function.arguments += tool_call.function.arguments
return final_tool_calls
def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
tools = self.tools
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
total_tokens = 0
hist = deepcopy(history)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
history = hist
try:
for _ in range(self.max_rounds + 1):
reasoning_start = False
logging.info(f"{tools=}")
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)
final_tool_calls = {}
answer = ""
for resp in response:
if resp.choices[0].delta.tool_calls:
for tool_call in resp.choices[0].delta.tool_calls or []:
index = tool_call.index
if index not in final_tool_calls:
if not tool_call.function.arguments:
tool_call.function.arguments = ""
final_tool_calls[index] = tool_call
else:
final_tool_calls[index].function.arguments += tool_call.function.arguments if tool_call.function.arguments else ""
continue
if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
raise Exception("500 response structure error.")
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += resp.choices[0].delta.reasoning_content + "</think>"
yield ans
else:
reasoning_start = False
answer += resp.choices[0].delta.content
yield resp.choices[0].delta.content
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
else:
total_tokens = tol
finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
if finish_reason == "length":
yield self._length_stop("")
if answer:
yield total_tokens
return
for tool_call in final_tool_calls.values():
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
yield self._verbose_tool_use(name, args, "Begin to call...")
tool_response = self.toolcall_session.tool_call(name, args)
history = self._append_history(history, tool_call, tool_response)
yield self._verbose_tool_use(name, args, tool_response)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
yield self._verbose_tool_use(name, {}, str(e))
logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
for resp in response:
if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
raise Exception("500 response structure error.")
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
continue
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
else:
total_tokens = tol
answer += resp.choices[0].delta.content
yield resp.choices[0].delta.content
yield total_tokens
return
except Exception as e:
e = self._exceptions(e, attempt)
if e:
yield e
yield total_tokens
return
assert False, "Shouldn't be here."
async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
tools = self.tools
@ -715,9 +447,10 @@ class Base(ABC):
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
if self.model_name.lower().find("qwq") >= 0:
logging.info(f"[INFO] {self.model_name} detected as reasoning model, using async_chat_streamly")
final_ans = ""
tol_token = 0
async for delta, tol in self._async_chat_streamly(history, gen_conf, with_reasoning=False, **kwargs):
if delta.startswith("<think>") or delta.endswith("</think>"):
continue
final_ans += delta
@ -754,57 +487,6 @@ class Base(ABC):
return e, 0
assert False, "Shouldn't be here."
def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
ans = ""
total_tokens = 0
try:
for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs):
yield delta_ans
total_tokens += tol
except openai.APIError as e:
yield ans + "\n**ERROR**: " + str(e)
yield total_tokens
def _calculate_dynamic_ctx(self, history):
"""Calculate dynamic context window size"""
def count_tokens(text):
"""Calculate token count for text"""
# Simple calculation: 1 token per ASCII character
# 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.)
total = 0
for char in text:
if ord(char) < 128: # ASCII characters
total += 1
else: # Non-ASCII characters (Chinese, Japanese, Korean, etc.)
total += 2
return total
# Calculate total tokens for all messages
total_tokens = 0
for message in history:
content = message.get("content", "")
# Calculate content tokens
content_tokens = count_tokens(content)
# Add role marker token overhead
role_tokens = 4
total_tokens += content_tokens + role_tokens
# Apply 1.2x buffer ratio
total_tokens_with_buffer = int(total_tokens * 1.2)
if total_tokens_with_buffer <= 8192:
ctx_size = 8192
else:
ctx_multiplier = (total_tokens_with_buffer // 8192) + 1
ctx_size = ctx_multiplier * 8192
return ctx_size
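The sizing rule above rounds a buffered token estimate up to a multiple of 8192; a self-contained sketch of the same arithmetic (the standalone function name is illustrative):

# Sketch of the _calculate_dynamic_ctx heuristic.
def estimate_ctx_size(history: list[dict]) -> int:
    def count_tokens(text: str) -> int:
        # 1 token per ASCII char, 2 per non-ASCII char (CJK etc.)
        return sum(1 if ord(c) < 128 else 2 for c in text)
    total = sum(count_tokens(m.get("content", "")) + 4 for m in history)  # +4 role overhead
    buffered = int(total * 1.2)  # 1.2x safety buffer
    if buffered <= 8192:
        return 8192
    return ((buffered // 8192) + 1) * 8192  # round up to the next 8192 multiple

# e.g. a buffered estimate of 10000 tokens yields (10000 // 8192 + 1) * 8192 = 16384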
class GptTurbo(Base):
_FACTORY_NAME = "OpenAI"
@ -1504,16 +1186,6 @@ class GoogleChat(Base):
yield total_tokens
class GPUStackChat(Base):
_FACTORY_NAME = "GPUStack"
def __init__(self, key=None, model_name="", base_url="", **kwargs):
if not base_url:
raise ValueError("Local llm url cannot be None")
base_url = urljoin(base_url, "v1")
super().__init__(key, model_name, base_url, **kwargs)
class TokenPonyChat(Base):
_FACTORY_NAME = "TokenPony"
@ -1523,15 +1195,6 @@ class TokenPonyChat(Base):
super().__init__(key, model_name, base_url, **kwargs)
class DeerAPIChat(Base):
_FACTORY_NAME = "DeerAPI"
def __init__(self, key, model_name, base_url="https://api.deerapi.com/v1", **kwargs):
if not base_url:
base_url = "https://api.deerapi.com/v1"
super().__init__(key, model_name, base_url, **kwargs)
class LiteLLMBase(ABC):
_FACTORY_NAME = [
"Tongyi-Qianwen",
@ -1562,6 +1225,8 @@ class LiteLLMBase(ABC):
"Jiekou.AI",
"ZHIPU-AI",
"MiniMax",
"DeerAPI",
"GPUStack",
]
def __init__(self, key, model_name, base_url=None, **kwargs):
@ -1589,11 +1254,9 @@ class LiteLLMBase(ABC):
self.provider_order = json.loads(key).get("provider_order", "")
def _get_delay(self):
"""Calculate retry delay time"""
return self.base_delay * random.uniform(10, 150)
def _classify_error(self, error):
"""Classify error based on error message content"""
error_str = str(error).lower()
keywords_mapping = [
@ -1619,78 +1282,17 @@ class LiteLLMBase(ABC):
del gen_conf["max_tokens"]
return gen_conf
async def async_chat(self, system, history, gen_conf, **kwargs):
hist = list(history) if history else []
if system:
if not hist or hist[0].get("role") != "system":
hist.insert(0, {"role": "system", "content": system})
logging.info("[HISTORY]" + json.dumps(hist, ensure_ascii=False, indent=2))
if self.model_name.lower().find("qwen3") >= 0:
kwargs["extra_body"] = {"enable_thinking": False}
completion_args = self._construct_completion_args(history=hist, stream=False, tools=False, **gen_conf)
response = litellm.completion(
**completion_args,
drop_params=True,
timeout=self.timeout,
)
# response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
return "", 0
ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length":
ans = self._length_stop(ans)
return ans, total_token_count_from_response(response)
def _chat_streamly(self, history, gen_conf, **kwargs):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
gen_conf = self._clean_conf(gen_conf)
reasoning_start = False
completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf)
stop = kwargs.get("stop")
if stop:
completion_args["stop"] = stop
response = litellm.completion(
**completion_args,
drop_params=True,
timeout=self.timeout,
)
for resp in response:
if not hasattr(resp, "choices") or not resp.choices:
continue
delta = resp.choices[0].delta
if not hasattr(delta, "content") or delta.content is None:
delta.content = ""
if kwargs.get("with_reasoning", True) and hasattr(delta, "reasoning_content") and delta.reasoning_content:
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += delta.reasoning_content + "</think>"
else:
reasoning_start = False
ans = delta.content
tol = total_token_count_from_response(resp)
if not tol:
tol = num_tokens_from_string(delta.content)
finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
if finish_reason == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
yield ans, tol
async def async_chat(self, history, gen_conf, **kwargs):
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
if self.model_name.lower().find("qwen3") >= 0:
kwargs["extra_body"] = {"enable_thinking": False}
completion_args = self._construct_completion_args(history=history, stream=False, tools=False, **gen_conf)
for attempt in range(self.max_retries + 1):
try:
@ -1790,22 +1392,7 @@ class LiteLLMBase(ABC):
def _should_retry(self, error_code: str) -> bool:
return error_code in self._retryable_errors
def _exceptions(self, e, attempt) -> str | None:
logging.exception("OpenAI chat_with_tools")
# Classify the error
error_code = self._classify_error(e)
if attempt == self.max_retries:
error_code = LLMErrorCode.ERROR_MAX_RETRIES
if self._should_retry(error_code):
delay = self._get_delay()
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
time.sleep(delay)
return None
return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
async def _exceptions_async(self, e, attempt):
logging.exception("LiteLLMBase async completion")
error_code = self._classify_error(e)
if attempt == self.max_retries:
@ -1854,71 +1441,7 @@ class LiteLLMBase(ABC):
self.toolcall_session = toolcall_session
self.tools = tools
def _construct_completion_args(self, history, stream: bool, tools: bool, **kwargs):
completion_args = {
"model": self.model_name,
"messages": history,
"api_key": self.api_key,
"num_retries": self.max_retries,
**kwargs,
}
if stream:
completion_args.update(
{
"stream": stream,
}
)
if tools and self.tools:
completion_args.update(
{
"tools": self.tools,
"tool_choice": "auto",
}
)
if self.provider in FACTORY_DEFAULT_BASE_URL:
completion_args.update({"api_base": self.base_url})
elif self.provider == SupportedLiteLLMProvider.Bedrock:
completion_args.pop("api_key", None)
completion_args.pop("api_base", None)
completion_args.update(
{
"aws_access_key_id": self.bedrock_ak,
"aws_secret_access_key": self.bedrock_sk,
"aws_region_name": self.bedrock_region,
}
)
if self.provider == SupportedLiteLLMProvider.OpenRouter:
if self.provider_order:
def _to_order_list(x):
if x is None:
return []
if isinstance(x, str):
return [s.strip() for s in x.split(",") if s.strip()]
if isinstance(x, (list, tuple)):
return [str(s).strip() for s in x if str(s).strip()]
return []
extra_body = {}
provider_cfg = {}
provider_order = _to_order_list(self.provider_order)
provider_cfg["order"] = provider_order
provider_cfg["allow_fallbacks"] = False
extra_body["provider"] = provider_cfg
completion_args.update({"extra_body": extra_body})
# Ollama deployments commonly sit behind a reverse proxy that enforces
# Bearer auth. Ensure the Authorization header is set when an API key
# is provided, while respecting any user-supplied headers. #11350
extra_headers = deepcopy(completion_args.get("extra_headers") or {})
if self.provider == SupportedLiteLLMProvider.Ollama and self.api_key and "Authorization" not in extra_headers:
extra_headers["Authorization"] = f"Bearer {self.api_key}"
if extra_headers:
completion_args["extra_headers"] = extra_headers
return completion_args
async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
@ -1926,16 +1449,14 @@ class LiteLLMBase(ABC):
ans = ""
tk_count = 0
hist = deepcopy(history)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
history = deepcopy(hist)
try:
for _ in range(self.max_rounds + 1):
logging.info(f"{self.tools=}")
completion_args = self._construct_completion_args(history=history, stream=False, tools=True, **gen_conf)
response = await litellm.acompletion(
**completion_args,
drop_params=True,
timeout=self.timeout,
@ -1961,7 +1482,7 @@ class LiteLLMBase(ABC):
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
history = self._append_history(history, tool_call, tool_response)
ans += self._verbose_tool_use(name, args, tool_response)
except Exception as e:
@ -1972,49 +1493,19 @@ class LiteLLMBase(ABC):
logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
response, token_count = await self.async_chat("", history, gen_conf)
ans += response
tk_count += token_count
return ans, tk_count
except Exception as e:
e = await self._exceptions_async(e, attempt)
if e:
return e, tk_count
assert False, "Shouldn't be here."
def chat(self, system, history, gen_conf={}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
try:
response = self._chat(history, gen_conf, **kwargs)
return response
except Exception as e:
e = self._exceptions(e, attempt)
if e:
return e, 0
assert False, "Shouldn't be here."
def _wrap_toolcall_message(self, stream):
final_tool_calls = {}
for chunk in stream:
for tool_call in chunk.choices[0].delta.tool_calls or []:
index = tool_call.index
if index not in final_tool_calls:
final_tool_calls[index] = tool_call
final_tool_calls[index].function.arguments += tool_call.function.arguments
return final_tool_calls
async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
tools = self.tools
if system and history and history[0].get("role") != "system":
@ -2023,16 +1514,15 @@ class LiteLLMBase(ABC):
total_tokens = 0
hist = deepcopy(history)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
history = deepcopy(hist)
try:
for _ in range(self.max_rounds + 1):
reasoning_start = False
logging.info(f"{tools=}")
completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
response = await litellm.acompletion(
**completion_args,
drop_params=True,
timeout=self.timeout,
@ -2041,7 +1531,7 @@ class LiteLLMBase(ABC):
final_tool_calls = {}
answer = ""
async for resp in response:
if not hasattr(resp, "choices") or not resp.choices:
continue
@ -2077,7 +1567,7 @@ class LiteLLMBase(ABC):
if not tol:
total_tokens += num_tokens_from_string(delta.content)
else:
total_tokens = tol
finish_reason = getattr(resp.choices[0], "finish_reason", "")
if finish_reason == "length":
@ -2092,31 +1582,25 @@ class LiteLLMBase(ABC):
try:
args = json_repair.loads(tool_call.function.arguments)
yield self._verbose_tool_use(name, args, "Begin to call...")
tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
history = self._append_history(history, tool_call, tool_response)
yield self._verbose_tool_use(name, args, tool_response)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
yield self._verbose_tool_use(name, {}, str(e))
logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
response = await litellm.acompletion(
**completion_args,
drop_params=True,
timeout=self.timeout,
)
async for resp in response:
if not hasattr(resp, "choices") or not resp.choices:
continue
delta = resp.choices[0].delta
@ -2126,14 +1610,14 @@ class LiteLLMBase(ABC):
if not tol:
total_tokens += num_tokens_from_string(delta.content)
else:
total_tokens = tol
yield delta.content
yield total_tokens
return
except Exception as e:
e = await self._exceptions_async(e, attempt)
if e:
yield e
yield total_tokens
@ -2141,53 +1625,71 @@ class LiteLLMBase(ABC):
assert False, "Shouldn't be here."
def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
ans = ""
total_tokens = 0
try:
for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs):
yield delta_ans
total_tokens += tol
except openai.APIError as e:
yield ans + "\n**ERROR**: " + str(e)
yield total_tokens
def _calculate_dynamic_ctx(self, history):
"""Calculate dynamic context window size"""
def count_tokens(text):
"""Calculate token count for text"""
# Simple calculation: 1 token per ASCII character
# 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.)
total = 0
for char in text:
if ord(char) < 128: # ASCII characters
total += 1
else: # Non-ASCII characters (Chinese, Japanese, Korean, etc.)
total += 2
return total
# Calculate total tokens for all messages
total_tokens = 0
for message in history:
content = message.get("content", "")
# Calculate content tokens
content_tokens = count_tokens(content)
# Add role marker token overhead
role_tokens = 4
total_tokens += content_tokens + role_tokens
# Apply 1.2x buffer ratio
total_tokens_with_buffer = int(total_tokens * 1.2)
if total_tokens_with_buffer <= 8192:
ctx_size = 8192
else:
ctx_multiplier = (total_tokens_with_buffer // 8192) + 1
ctx_size = ctx_multiplier * 8192
return ctx_size
def _construct_completion_args(self, history, stream: bool, tools: bool, **kwargs):
completion_args = {
"model": self.model_name,
"messages": history,
"api_key": self.api_key,
"num_retries": self.max_retries,
**kwargs,
}
if stream:
completion_args.update(
{
"stream": stream,
}
)
if tools and self.tools:
completion_args.update(
{
"tools": self.tools,
"tool_choice": "auto",
}
)
if self.provider in FACTORY_DEFAULT_BASE_URL:
completion_args.update({"api_base": self.base_url})
elif self.provider == SupportedLiteLLMProvider.Bedrock:
completion_args.pop("api_key", None)
completion_args.pop("api_base", None)
completion_args.update(
{
"aws_access_key_id": self.bedrock_ak,
"aws_secret_access_key": self.bedrock_sk,
"aws_region_name": self.bedrock_region,
}
)
elif self.provider == SupportedLiteLLMProvider.OpenRouter:
if self.provider_order:
def _to_order_list(x):
if x is None:
return []
if isinstance(x, str):
return [s.strip() for s in x.split(",") if s.strip()]
if isinstance(x, (list, tuple)):
return [str(s).strip() for s in x if str(s).strip()]
return []
extra_body = {}
provider_cfg = {}
provider_order = _to_order_list(self.provider_order)
provider_cfg["order"] = provider_order
provider_cfg["allow_fallbacks"] = False
extra_body["provider"] = provider_cfg
completion_args.update({"extra_body": extra_body})
elif self.provider == SupportedLiteLLMProvider.GPUStack:
completion_args.update(
{
"api_base": self.base_url,
}
)
# Ollama deployments commonly sit behind a reverse proxy that enforces
# Bearer auth. Ensure the Authorization header is set when an API key
# is provided, while respecting any user-supplied headers. #11350
extra_headers = deepcopy(completion_args.get("extra_headers") or {})
if self.provider == SupportedLiteLLMProvider.Ollama and self.api_key and "Authorization" not in extra_headers:
extra_headers["Authorization"] = f"Bearer {self.api_key}"
if extra_headers:
completion_args["extra_headers"] = extra_headers
return completion_args

View File

@ -537,7 +537,8 @@ class Dealer:
doc["id"] = id doc["id"] = id
if dict_chunks: if dict_chunks:
res.extend(dict_chunks.values()) res.extend(dict_chunks.values())
if len(dict_chunks.values()) < bs: # FIX: Solo terminar si no hay chunks, no si hay menos de bs
if len(dict_chunks.values()) == 0:
break break
return res return res
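The fix matters because a backend page can legitimately return fewer than bs chunks before the data is exhausted; a sketch of the corrected loop shape (fetch_page is hypothetical):

# Sketch: keep paging until a page comes back empty, not merely short.
res, page, bs = [], 0, 64
while True:
    batch = fetch_page(page, bs)  # hypothetical backend call, returns up to bs chunks
    res.extend(batch)
    if len(batch) == 0:  # only an empty page proves the scroll is finished
        break
    page += 1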

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import datetime
import json
import logging
@ -342,7 +343,8 @@ def form_history(history, limit=-6):
return context
def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
tools_desc = tool_schema(tools_description)
context = ""
@ -351,7 +353,7 @@ def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], use
else:
template = PROMPT_JINJA_ENV.from_string(ANALYZE_TASK_SYSTEM + "\n\n" + ANALYZE_TASK_USER)
context = template.render(task=task_name, context=context, agent_prompt=prompt, tools_desc=tools_desc)
kwd = await _chat_async(chat_mdl, context, [{"role": "user", "content": "Please analyze it."}])
if isinstance(kwd, tuple):
kwd = kwd[0]
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
@ -360,9 +362,17 @@ def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], use
return kwd
async def _chat_async(chat_mdl, system: str, history: list, **kwargs):
chat_async = getattr(chat_mdl, "async_chat", None)
if chat_async and asyncio.iscoroutinefunction(chat_async):
return await chat_async(system, history, **kwargs)
return await asyncio.to_thread(chat_mdl.chat, system, history, **kwargs)
async def next_step_async(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
if not tools_description:
return "", 0
desc = tool_schema(tools_description)
template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("plan_generation", NEXT_STEP))
user_prompt = "\nWhat's the next tool to call? If ready OR IMPOSSIBLE TO BE READY, then call `complete_task`."
@ -371,14 +381,18 @@ def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc,
hist[-1]["content"] += user_prompt
else:
hist.append({"role": "user", "content": user_prompt})
json_str = await _chat_async(
chat_mdl,
template.render(task_analysis=task_desc, desc=desc, today=datetime.datetime.now().strftime("%Y-%m-%d")),
hist[1:],
stop=["<|stop|>"],
)
tk_cnt = num_tokens_from_string(json_str)
json_str = re.sub(r"^.*</think>", "", json_str, flags=re.DOTALL)
return json_str, tk_cnt
async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
tool_calls = [{"name": p[0], "result": p[1]} for p in tool_call_res]
goal = history[1]["content"]
template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("reflection", REFLECT))
@ -389,7 +403,7 @@ def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defi
else:
hist.append({"role": "user", "content": user_prompt})
_, msg = message_fit_in(hist, chat_mdl.max_length)
ans = await _chat_async(chat_mdl, msg[0]["content"], msg[1:])
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
return """
**Observation**
@ -420,12 +434,12 @@ def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defin
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
async def rank_memories_async(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
template = PROMPT_JINJA_ENV.from_string(RANK_MEMORY)
system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i,s in enumerate(tool_call_summaries)])
user_prompt = " → rank: "
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
ans = await _chat_async(chat_mdl, msg[0]["content"], msg[1:], stop="<|stop|>")
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
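The _chat_async helper above is what lets these prompt utilities accept both async-native and sync-only chat models; a minimal usage sketch (the model object is hypothetical):

import asyncio

async def demo(chat_mdl):
    # Dispatches to chat_mdl.async_chat when it exists and is a coroutine
    # function; otherwise runs the blocking chat_mdl.chat in a worker thread.
    ans = await _chat_async(chat_mdl, "You are terse.", [{"role": "user", "content": "ping"}])
    print(ans)

# asyncio.run(demo(my_chat_model))  # my_chat_model is a placeholder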

File diff suppressed because it is too large

View File

@ -157,11 +157,30 @@ class Confluence(SyncBase):
from common.data_source.config import DocumentSource
from common.data_source.interfaces import StaticCredentialsProvider
index_mode = (self.conf.get("index_mode") or "everything").lower()
if index_mode not in {"everything", "space", "page"}:
index_mode = "everything"
space = ""
page_id = ""
index_recursively = False
if index_mode == "space":
space = (self.conf.get("space") or "").strip()
if not space:
raise ValueError("Space Key is required when indexing a specific Confluence space.")
elif index_mode == "page":
page_id = (self.conf.get("page_id") or "").strip()
if not page_id:
raise ValueError("Page ID is required when indexing a specific Confluence page.")
index_recursively = bool(self.conf.get("index_recursively", False))
self.connector = ConfluenceConnector(
wiki_base=self.conf["wiki_base"],
is_cloud=self.conf.get("is_cloud", True),
space=space,
page_id=page_id,
index_recursively=index_recursively,
)
credentials_provider = StaticCredentialsProvider(tenant_id=task["tenant_id"], connector_name=DocumentSource.CONFLUENCE, credential_json=self.conf["credentials"])
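Concretely, the three index modes above expect connector configurations along these lines (values illustrative; credentials omitted):

conf_everything = {"wiki_base": "https://example.atlassian.net/wiki", "index_mode": "everything"}
conf_space = {"wiki_base": "https://example.atlassian.net/wiki", "index_mode": "space", "space": "ENG"}
conf_page = {"wiki_base": "https://example.atlassian.net/wiki", "index_mode": "page",
             "page_id": "123456", "index_recursively": True}
# "space" mode without a space key, or "page" mode without a page_id, raises ValueError;
# unknown index_mode values fall back to "everything".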

View File

@ -29,6 +29,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
from common.connection_utils import timeout
from rag.utils.base64_image import image2id
from rag.utils.raptor_utils import should_skip_raptor, get_skip_reason
from common.log_utils import init_root_logger
from common.config_utils import show_configs
from graphrag.general.index import run_graphrag_for_kb
@ -68,7 +69,7 @@ from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
from common.exceptions import TaskCanceledException
from common import settings
from common.constants import PAGERANK_FLD, TAG_FLD, SVR_CONSUMER_GROUP_NAME
from common.misc_utils import check_and_install_mineru
BATCH_SIZE = 64
@ -591,6 +592,7 @@ async def run_dataflow(task: dict):
ck["docnm_kwd"] = task["name"] ck["docnm_kwd"] = task["name"]
ck["create_time"] = str(datetime.now()).replace("T", " ")[:19] ck["create_time"] = str(datetime.now()).replace("T", " ")[:19]
ck["create_timestamp_flt"] = datetime.now().timestamp() ck["create_timestamp_flt"] = datetime.now().timestamp()
if not ck.get("id"):
ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest() ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest()
if "questions" in ck: if "questions" in ck:
if "question_tks" not in ck: if "question_tks" not in ck:
@ -853,6 +855,17 @@ async def do_handle_task(task):
progress_callback(prog=-1.0, msg="Internal error: Invalid RAPTOR configuration")
return
# Check if Raptor should be skipped for structured data
file_type = task.get("type", "")
parser_id = task.get("parser_id", "")
raptor_config = kb_parser_config.get("raptor", {})
if should_skip_raptor(file_type, parser_id, task_parser_config, raptor_config):
skip_reason = get_skip_reason(file_type, parser_id, task_parser_config)
logging.info(f"Skipping Raptor for document {task_document_name}: {skip_reason}")
progress_callback(prog=1.0, msg=f"Raptor skipped: {skip_reason}")
return
# bind LLM for raptor
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
# run RAPTOR
@ -1101,8 +1114,8 @@ async def main():
show_configs()
settings.init_settings()
settings.check_and_install_torch()
check_and_install_mineru()
logging.info(f'default embedding config: {settings.EMBEDDING_CFG}')
settings.print_rag_settings()
if sys.platform != "win32":
signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)
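The id guard added above keeps any chunk id assigned earlier in the pipeline and only derives one from content when it is missing; a small sketch of that rule (assuming the xxhash package):

import xxhash

ck = {"text": "Q3 revenue grew 12%.", "doc_id": "d42"}
if not ck.get("id"):
    ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest()
# Re-running the block leaves an existing id untouched, so re-ingested
# documents keep stable chunk ids.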

207
rag/utils/gcs_conn.py Normal file
View File

@ -0,0 +1,207 @@
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import time
import datetime
from io import BytesIO
from google.cloud import storage
from google.api_core.exceptions import NotFound
from common.decorator import singleton
from common import settings
@singleton
class RAGFlowGCS:
def __init__(self):
self.client = None
self.bucket_name = None
self.__open__()
def __open__(self):
try:
if self.client:
self.client = None
except Exception:
pass
try:
self.client = storage.Client()
self.bucket_name = settings.GCS["bucket"]
except Exception:
logging.exception("Fail to connect to GCS")
def _get_blob_path(self, folder, filename):
"""Helper to construct the path: folder/filename"""
if not folder:
return filename
return f"{folder}/{filename}"
def health(self):
folder, fnm, binary = "ragflow-health", "health_check", b"_t@@@1"
try:
bucket_obj = self.client.bucket(self.bucket_name)
if not bucket_obj.exists():
logging.error(f"Health check failed: Main bucket '{self.bucket_name}' does not exist.")
return False
blob_path = self._get_blob_path(folder, fnm)
blob = bucket_obj.blob(blob_path)
blob.upload_from_file(BytesIO(binary), content_type='application/octet-stream')
return True
except Exception as e:
logging.exception(f"Health check failed: {e}")
return False
def put(self, bucket, fnm, binary, tenant_id=None):
# RENAMED PARAMETER: bucket_name -> bucket (to match interface)
for _ in range(3):
try:
bucket_obj = self.client.bucket(self.bucket_name)
blob_path = self._get_blob_path(bucket, fnm)
blob = bucket_obj.blob(blob_path)
blob.upload_from_file(BytesIO(binary), content_type='application/octet-stream')
return True
except NotFound:
logging.error(f"Fail to put: Main bucket {self.bucket_name} does not exist.")
return False
except Exception:
logging.exception(f"Fail to put {bucket}/{fnm}:")
self.__open__()
time.sleep(1)
return False
def rm(self, bucket, fnm, tenant_id=None):
# RENAMED PARAMETER: bucket_name -> bucket
try:
bucket_obj = self.client.bucket(self.bucket_name)
blob_path = self._get_blob_path(bucket, fnm)
blob = bucket_obj.blob(blob_path)
blob.delete()
except NotFound:
pass
except Exception:
logging.exception(f"Fail to remove {bucket}/{fnm}:")
def get(self, bucket, filename, tenant_id=None):
# RENAMED PARAMETER: bucket_name -> bucket
for _ in range(1):
try:
bucket_obj = self.client.bucket(self.bucket_name)
blob_path = self._get_blob_path(bucket, filename)
blob = bucket_obj.blob(blob_path)
return blob.download_as_bytes()
except NotFound:
logging.warning(f"File not found {bucket}/{filename} in {self.bucket_name}")
return None
except Exception:
logging.exception(f"Fail to get {bucket}/{filename}")
self.__open__()
time.sleep(1)
return None
def obj_exist(self, bucket, filename, tenant_id=None):
# RENAMED PARAMETER: bucket_name -> bucket
try:
bucket_obj = self.client.bucket(self.bucket_name)
blob_path = self._get_blob_path(bucket, filename)
blob = bucket_obj.blob(blob_path)
return blob.exists()
except Exception:
logging.exception(f"obj_exist {bucket}/{filename} got exception")
return False
def bucket_exists(self, bucket):
# RENAMED PARAMETER: bucket_name -> bucket
try:
bucket_obj = self.client.bucket(self.bucket_name)
return bucket_obj.exists()
except Exception:
logging.exception(f"bucket_exist check for {self.bucket_name} got exception")
return False
def get_presigned_url(self, bucket, fnm, expires, tenant_id=None):
# RENAMED PARAMETER: bucket_name -> bucket
for _ in range(10):
try:
bucket_obj = self.client.bucket(self.bucket_name)
blob_path = self._get_blob_path(bucket, fnm)
blob = bucket_obj.blob(blob_path)
expiration = expires
if isinstance(expires, int):
expiration = datetime.timedelta(seconds=expires)
url = blob.generate_signed_url(
version="v4",
expiration=expiration,
method="GET"
)
return url
except Exception:
logging.exception(f"Fail to get_presigned {bucket}/{fnm}:")
self.__open__()
time.sleep(1)
return None
def remove_bucket(self, bucket):
# RENAMED PARAMETER: bucket_name -> bucket
try:
bucket_obj = self.client.bucket(self.bucket_name)
prefix = f"{bucket}/"
blobs = list(self.client.list_blobs(self.bucket_name, prefix=prefix))
if blobs:
bucket_obj.delete_blobs(blobs)
except Exception:
logging.exception(f"Fail to remove virtual bucket (folder) {bucket}")
def copy(self, src_bucket, src_path, dest_bucket, dest_path):
# RENAMED PARAMETERS to match original interface
try:
bucket_obj = self.client.bucket(self.bucket_name)
src_blob_path = self._get_blob_path(src_bucket, src_path)
dest_blob_path = self._get_blob_path(dest_bucket, dest_path)
src_blob = bucket_obj.blob(src_blob_path)
if not src_blob.exists():
logging.error(f"Source object not found: {src_blob_path}")
return False
bucket_obj.copy_blob(src_blob, bucket_obj, dest_blob_path)
return True
except NotFound:
logging.error(f"Copy failed: Main bucket {self.bucket_name} does not exist.")
return False
except Exception:
logging.exception(f"Fail to copy {src_bucket}/{src_path} -> {dest_bucket}/{dest_path}")
return False
def move(self, src_bucket, src_path, dest_bucket, dest_path):
try:
if self.copy(src_bucket, src_path, dest_bucket, dest_path):
self.rm(src_bucket, src_path)
return True
else:
logging.error(f"Copy failed, move aborted: {src_bucket}/{src_path}")
return False
except Exception:
logging.exception(f"Fail to move {src_bucket}/{src_path} -> {dest_bucket}/{dest_path}")
return False
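Because every logical bucket maps to a folder inside the single configured GCS bucket, callers use the connector roughly like this (bucket and file names are made up; settings.GCS must be configured):

store = RAGFlowGCS()
store.put("tenant-123", "report.pdf", b"%PDF-1.7 ...")  # stored under key "tenant-123/report.pdf"
data = store.get("tenant-123", "report.pdf")
url = store.get_presigned_url("tenant-123", "report.pdf", expires=3600)  # 1-hour signed URL
store.remove_bucket("tenant-123")  # deletes every blob under the "tenant-123/" prefix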

145
rag/utils/raptor_utils.py Normal file
View File

@ -0,0 +1,145 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Utility functions for Raptor processing decisions.
"""
import logging
from typing import Optional
# File extensions for structured data types
EXCEL_EXTENSIONS = {".xls", ".xlsx", ".xlsm", ".xlsb"}
CSV_EXTENSIONS = {".csv", ".tsv"}
STRUCTURED_EXTENSIONS = EXCEL_EXTENSIONS | CSV_EXTENSIONS
def is_structured_file_type(file_type: Optional[str]) -> bool:
"""
Check if a file type is structured data (Excel, CSV, etc.)
Args:
file_type: File extension (e.g., ".xlsx", ".csv")
Returns:
True if file is structured data type
"""
if not file_type:
return False
# Normalize to lowercase and ensure leading dot
file_type = file_type.lower()
if not file_type.startswith("."):
file_type = f".{file_type}"
return file_type in STRUCTURED_EXTENSIONS
def is_tabular_pdf(parser_id: str = "", parser_config: Optional[dict] = None) -> bool:
"""
Check if a PDF is being parsed as tabular data.
Args:
parser_id: Parser ID (e.g., "table", "naive")
parser_config: Parser configuration dict
Returns:
True if PDF is being parsed as tabular data
"""
parser_config = parser_config or {}
# If using table parser, it's tabular
if parser_id and parser_id.lower() == "table":
return True
# Check if html4excel is enabled (Excel-like table parsing)
if parser_config.get("html4excel", False):
return True
return False
def should_skip_raptor(
file_type: Optional[str] = None,
parser_id: str = "",
parser_config: Optional[dict] = None,
raptor_config: Optional[dict] = None
) -> bool:
"""
Determine if Raptor should be skipped for a given document.
This function implements the logic to automatically disable Raptor for:
1. Excel files (.xls, .xlsx, .csv, etc.)
2. PDFs with tabular data (using table parser or html4excel)
Args:
file_type: File extension (e.g., ".xlsx", ".pdf")
parser_id: Parser ID being used
parser_config: Parser configuration dict
raptor_config: Raptor configuration dict (can override with auto_disable_for_structured_data)
Returns:
True if Raptor should be skipped, False otherwise
"""
parser_config = parser_config or {}
raptor_config = raptor_config or {}
# Check if auto-disable is explicitly disabled in config
if raptor_config.get("auto_disable_for_structured_data", True) is False:
logging.info("Raptor auto-disable is turned off via configuration")
return False
# Check for Excel/CSV files
if is_structured_file_type(file_type):
logging.info(f"Skipping Raptor for structured file type: {file_type}")
return True
# Check for tabular PDFs
if file_type and file_type.lower() in [".pdf", "pdf"]:
if is_tabular_pdf(parser_id, parser_config):
logging.info(f"Skipping Raptor for tabular PDF (parser_id={parser_id})")
return True
return False
def get_skip_reason(
file_type: Optional[str] = None,
parser_id: str = "",
parser_config: Optional[dict] = None
) -> str:
"""
Get a human-readable reason why Raptor was skipped.
Args:
file_type: File extension
parser_id: Parser ID being used
parser_config: Parser configuration dict
Returns:
Reason string, or empty string if Raptor should not be skipped
"""
parser_config = parser_config or {}
if is_structured_file_type(file_type):
return f"Structured data file ({file_type}) - Raptor auto-disabled"
if file_type and file_type.lower() in [".pdf", "pdf"]:
if is_tabular_pdf(parser_id, parser_config):
return f"Tabular PDF (parser={parser_id}) - Raptor auto-disabled"
return ""

275
run_tests.py Executable file
View File

@ -0,0 +1,275 @@
#!/usr/bin/env python3
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import argparse
import subprocess
from pathlib import Path
from typing import List
class Colors:
"""ANSI color codes for terminal output"""
RED = '\033[0;31m'
GREEN = '\033[0;32m'
YELLOW = '\033[1;33m'
BLUE = '\033[0;34m'
NC = '\033[0m' # No Color
class TestRunner:
"""RAGFlow Unit Test Runner"""
def __init__(self):
self.project_root = Path(__file__).parent.resolve()
self.ut_dir = Path(self.project_root / 'test' / 'unit_test')
# Default options
self.coverage = False
self.parallel = False
self.verbose = False
self.markers = ""
# Python interpreter path
self.python = sys.executable
@staticmethod
def print_info(message: str) -> None:
"""Print informational message"""
print(f"{Colors.BLUE}[INFO]{Colors.NC} {message}")
@staticmethod
def print_error(message: str) -> None:
"""Print error message"""
print(f"{Colors.RED}[ERROR]{Colors.NC} {message}")
@staticmethod
def show_usage() -> None:
"""Display usage information"""
usage = """
RAGFlow Unit Test Runner
Usage: python run_tests.py [OPTIONS]
OPTIONS:
-h, --help Show this help message
-c, --coverage Run tests with coverage report
-p, --parallel Run tests in parallel (requires pytest-xdist)
-v, --verbose Verbose output
-t, --test FILE Run specific test file or directory
-m, --markers MARKERS Run tests with specific markers (e.g., "unit", "integration")
EXAMPLES:
# Run all tests
python run_tests.py
# Run with coverage
python run_tests.py --coverage
# Run in parallel
python run_tests.py --parallel
# Run specific test file
python run_tests.py --test services/test_dialog_service.py
# Run only unit tests
python run_tests.py --markers "unit"
# Run tests with coverage and parallel execution
python run_tests.py --coverage --parallel
"""
print(usage)
def build_pytest_command(self) -> List[str]:
"""Build the pytest command arguments"""
cmd = ["pytest", str(self.ut_dir)]
# Add test path
# Add markers
if self.markers:
cmd.extend(["-m", self.markers])
# Add verbose flag
if self.verbose:
cmd.extend(["-vv"])
else:
cmd.append("-v")
# Add coverage
if self.coverage:
# Relative path from test directory to source code
source_path = str(self.project_root / "common")
cmd.extend([
"--cov", source_path,
"--cov-report", "html",
"--cov-report", "term"
])
# Add parallel execution
if self.parallel:
# Try to get number of CPU cores
try:
import multiprocessing
cpu_count = multiprocessing.cpu_count()
cmd.extend(["-n", str(cpu_count)])
except ImportError:
# Fallback to auto if multiprocessing not available
cmd.extend(["-n", "auto"])
# Add default options from pyproject.toml if it exists
pyproject_path = self.project_root / "pyproject.toml"
if pyproject_path.exists():
cmd.extend(["--config-file", str(pyproject_path)])
return cmd
def run_tests(self) -> bool:
"""Execute the pytest command"""
# Change to test directory
os.chdir(self.project_root)
# Build command
cmd = self.build_pytest_command()
# Print test configuration
self.print_info("Running RAGFlow Unit Tests")
self.print_info("=" * 40)
self.print_info(f"Test Directory: {self.ut_dir}")
self.print_info(f"Coverage: {self.coverage}")
self.print_info(f"Parallel: {self.parallel}")
self.print_info(f"Verbose: {self.verbose}")
if self.markers:
self.print_info(f"Markers: {self.markers}")
print(f"\n{Colors.BLUE}[EXECUTING]{Colors.NC} {' '.join(cmd)}\n")
# Run pytest
try:
result = subprocess.run(cmd, check=False)
if result.returncode == 0:
print(f"\n{Colors.GREEN}[SUCCESS]{Colors.NC} All tests passed!")
if self.coverage:
coverage_dir = self.ut_dir / "htmlcov"
if coverage_dir.exists():
index_file = coverage_dir / "index.html"
print(f"\n{Colors.BLUE}[INFO]{Colors.NC} Coverage report generated:")
print(f" {index_file}")
print("\nOpen with:")
print(f" - Windows: start {index_file}")
print(f" - macOS: open {index_file}")
print(f" - Linux: xdg-open {index_file}")
return True
else:
print(f"\n{Colors.RED}[FAILURE]{Colors.NC} Some tests failed!")
return False
except KeyboardInterrupt:
print(f"\n{Colors.YELLOW}[INTERRUPTED]{Colors.NC} Test execution interrupted by user")
return False
except Exception as e:
self.print_error(f"Failed to execute tests: {e}")
return False
def parse_arguments(self) -> bool:
"""Parse command line arguments"""
parser = argparse.ArgumentParser(
description="RAGFlow Unit Test Runner",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
python run_tests.py # Run all tests
python run_tests.py --coverage # Run with coverage
python run_tests.py --parallel # Run in parallel
python run_tests.py --test services/test_dialog_service.py # Run specific test
python run_tests.py --markers "unit" # Run only unit tests
"""
)
parser.add_argument(
"-c", "--coverage",
action="store_true",
help="Run tests with coverage report"
)
parser.add_argument(
"-p", "--parallel",
action="store_true",
help="Run tests in parallel (requires pytest-xdist)"
)
parser.add_argument(
"-v", "--verbose",
action="store_true",
help="Verbose output"
)
parser.add_argument(
"-t", "--test",
type=str,
default="",
help="Run specific test file or directory"
)
parser.add_argument(
"-m", "--markers",
type=str,
default="",
help="Run tests with specific markers (e.g., 'unit', 'integration')"
)
try:
args = parser.parse_args()
# Set options
self.coverage = args.coverage
self.parallel = args.parallel
self.verbose = args.verbose
self.markers = args.markers
self.test = args.test
return True
except SystemExit:
# argparse already printed help, just exit
return False
except Exception as e:
self.print_error(f"Error parsing arguments: {e}")
return False
def run(self) -> int:
"""Main execution method"""
# Parse command line arguments
if not self.parse_arguments():
return 1
# Run tests
success = self.run_tests()
return 0 if success else 1
def main():
"""Entry point"""
runner = TestRunner()
return runner.run()
if __name__ == "__main__":
sys.exit(main())

View File

@ -122,15 +122,15 @@ async def create_container(name: str, language: SupportLanguage) -> bool:
logger.info(f"Sandbox config:\n\t {create_args}") logger.info(f"Sandbox config:\n\t {create_args}")
try: try:
returncode, _, stderr = await async_run_command(*create_args, timeout=10) return_code, _, stderr = await async_run_command(*create_args, timeout=10)
if returncode != 0: if return_code != 0:
logger.error(f"❌ Container creation failed {name}: {stderr}") logger.error(f"❌ Container creation failed {name}: {stderr}")
return False return False
if language == SupportLanguage.NODEJS: if language == SupportLanguage.NODEJS:
copy_cmd = ["docker", "exec", name, "bash", "-c", "cp -a /app/node_modules /workspace/"] copy_cmd = ["docker", "exec", name, "bash", "-c", "cp -a /app/node_modules /workspace/"]
returncode, _, stderr = await async_run_command(*copy_cmd, timeout=10) return_code, _, stderr = await async_run_command(*copy_cmd, timeout=10)
if returncode != 0: if return_code != 0:
logger.error(f"❌ Failed to prepare dependencies for {name}: {stderr}") logger.error(f"❌ Failed to prepare dependencies for {name}: {stderr}")
return False return False
@ -185,7 +185,7 @@ async def allocate_container_blocking(language: SupportLanguage, timeout=10) ->
async def container_is_running(name: str) -> bool: async def container_is_running(name: str) -> bool:
"""Asynchronously check the container status""" """Asynchronously check the container status"""
try: try:
returncode, stdout, _ = await async_run_command("docker", "inspect", "-f", "{{.State.Running}}", name, timeout=2) return_code, stdout, _ = await async_run_command("docker", "inspect", "-f", "{{.State.Running}}", name, timeout=2)
return returncode == 0 and stdout.strip() == "true" return return_code == 0 and stdout.strip() == "true"
except Exception: except Exception:
return False return False

View File

@ -0,0 +1,287 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Unit tests for Raptor utility functions.
"""
import pytest
from rag.utils.raptor_utils import (
is_structured_file_type,
is_tabular_pdf,
should_skip_raptor,
get_skip_reason,
EXCEL_EXTENSIONS,
CSV_EXTENSIONS,
STRUCTURED_EXTENSIONS
)
class TestIsStructuredFileType:
"""Test file type detection for structured data"""
@pytest.mark.parametrize("file_type,expected", [
(".xlsx", True),
(".xls", True),
(".xlsm", True),
(".xlsb", True),
(".csv", True),
(".tsv", True),
("xlsx", True), # Without leading dot
("XLSX", True), # Uppercase
(".pdf", False),
(".docx", False),
(".txt", False),
("", False),
(None, False),
])
def test_file_type_detection(self, file_type, expected):
"""Test detection of various file types"""
assert is_structured_file_type(file_type) == expected
def test_excel_extensions_defined(self):
"""Test that Excel extensions are properly defined"""
assert ".xlsx" in EXCEL_EXTENSIONS
assert ".xls" in EXCEL_EXTENSIONS
assert len(EXCEL_EXTENSIONS) >= 4
def test_csv_extensions_defined(self):
"""Test that CSV extensions are properly defined"""
assert ".csv" in CSV_EXTENSIONS
assert ".tsv" in CSV_EXTENSIONS
def test_structured_extensions_combined(self):
"""Test that structured extensions include both Excel and CSV"""
assert EXCEL_EXTENSIONS.issubset(STRUCTURED_EXTENSIONS)
assert CSV_EXTENSIONS.issubset(STRUCTURED_EXTENSIONS)
class TestIsTabularPDF:
"""Test tabular PDF detection"""
def test_table_parser_detected(self):
"""Test that table parser is detected as tabular"""
assert is_tabular_pdf("table", {}) is True
assert is_tabular_pdf("TABLE", {}) is True
def test_html4excel_detected(self):
"""Test that html4excel config is detected as tabular"""
assert is_tabular_pdf("naive", {"html4excel": True}) is True
assert is_tabular_pdf("", {"html4excel": True}) is True
def test_non_tabular_pdf(self):
"""Test that non-tabular PDFs are not detected"""
assert is_tabular_pdf("naive", {}) is False
assert is_tabular_pdf("naive", {"html4excel": False}) is False
assert is_tabular_pdf("", {}) is False
def test_combined_conditions(self):
"""Test combined table parser and html4excel"""
assert is_tabular_pdf("table", {"html4excel": True}) is True
assert is_tabular_pdf("table", {"html4excel": False}) is True


class TestShouldSkipRaptor:
    """Test Raptor skip logic"""

    def test_skip_excel_files(self):
        """Test that Excel files skip Raptor"""
        assert should_skip_raptor(".xlsx") is True
        assert should_skip_raptor(".xls") is True
        assert should_skip_raptor(".xlsm") is True

    def test_skip_csv_files(self):
        """Test that CSV files skip Raptor"""
        assert should_skip_raptor(".csv") is True
        assert should_skip_raptor(".tsv") is True

    def test_skip_tabular_pdf_with_table_parser(self):
        """Test that tabular PDFs skip Raptor"""
        assert should_skip_raptor(".pdf", parser_id="table") is True
        assert should_skip_raptor("pdf", parser_id="TABLE") is True

    def test_skip_tabular_pdf_with_html4excel(self):
        """Test that PDFs with html4excel skip Raptor"""
        assert should_skip_raptor(".pdf", parser_config={"html4excel": True}) is True

    def test_dont_skip_regular_pdf(self):
        """Test that regular PDFs don't skip Raptor"""
        assert should_skip_raptor(".pdf", parser_id="naive") is False
        assert should_skip_raptor(".pdf", parser_config={}) is False

    def test_dont_skip_text_files(self):
        """Test that text files don't skip Raptor"""
        assert should_skip_raptor(".txt") is False
        assert should_skip_raptor(".docx") is False
        assert should_skip_raptor(".md") is False

    def test_override_with_config(self):
        """Test that auto-disable can be overridden"""
        raptor_config = {"auto_disable_for_structured_data": False}
        # Should not skip even for Excel files
        assert should_skip_raptor(".xlsx", raptor_config=raptor_config) is False
        assert should_skip_raptor(".csv", raptor_config=raptor_config) is False
        assert should_skip_raptor(".pdf", parser_id="table", raptor_config=raptor_config) is False

    def test_default_auto_disable_enabled(self):
        """Test that auto-disable is enabled by default"""
        # An empty or missing raptor_config should default to auto_disable=True
        assert should_skip_raptor(".xlsx", raptor_config={}) is True
        assert should_skip_raptor(".xlsx", raptor_config=None) is True

    def test_explicit_auto_disable_enabled(self):
        """Test explicitly enabled auto-disable"""
        raptor_config = {"auto_disable_for_structured_data": True}
        assert should_skip_raptor(".xlsx", raptor_config=raptor_config) is True


class TestGetSkipReason:
    """Test skip reason generation"""

    def test_excel_skip_reason(self):
        """Test skip reason for Excel files"""
        reason = get_skip_reason(".xlsx")
        assert "Structured data file" in reason
        assert ".xlsx" in reason
        assert "auto-disabled" in reason.lower()

    def test_csv_skip_reason(self):
        """Test skip reason for CSV files"""
        reason = get_skip_reason(".csv")
        assert "Structured data file" in reason
        assert ".csv" in reason

    def test_tabular_pdf_skip_reason(self):
        """Test skip reason for tabular PDFs"""
        reason = get_skip_reason(".pdf", parser_id="table")
        assert "Tabular PDF" in reason
        assert "table" in reason.lower()
        assert "auto-disabled" in reason.lower()

    def test_html4excel_skip_reason(self):
        """Test skip reason for html4excel PDFs"""
        reason = get_skip_reason(".pdf", parser_config={"html4excel": True})
        assert "Tabular PDF" in reason

    def test_no_skip_reason_for_regular_files(self):
        """Test that regular files have no skip reason"""
        assert get_skip_reason(".txt") == ""
        assert get_skip_reason(".docx") == ""
        assert get_skip_reason(".pdf", parser_id="naive") == ""


class TestEdgeCases:
    """Test edge cases and error handling"""

    def test_none_values(self):
        """Test handling of None values"""
        assert should_skip_raptor(None) is False
        assert should_skip_raptor("") is False
        assert get_skip_reason(None) == ""

    def test_empty_strings(self):
        """Test handling of empty strings"""
        assert should_skip_raptor("") is False
        assert get_skip_reason("") == ""

    def test_case_insensitivity(self):
        """Test case-insensitive handling"""
        assert is_structured_file_type("XLSX") is True
        assert is_structured_file_type("XlSx") is True
        assert is_tabular_pdf("TABLE", {}) is True
        assert is_tabular_pdf("TaBlE", {}) is True

    def test_with_and_without_dot(self):
        """Test file extensions with and without a leading dot"""
        assert should_skip_raptor(".xlsx") is True
        assert should_skip_raptor("xlsx") is True
        assert should_skip_raptor(".CSV") is True
        assert should_skip_raptor("csv") is True


class TestIntegrationScenarios:
    """Test real-world integration scenarios"""

    def test_financial_excel_report(self):
        """Test scenario: quarterly financial Excel report"""
        file_type = ".xlsx"
        parser_id = "naive"
        parser_config = {}
        raptor_config = {"use_raptor": True}
        # Should skip Raptor
        assert should_skip_raptor(file_type, parser_id, parser_config, raptor_config) is True
        reason = get_skip_reason(file_type, parser_id, parser_config)
        assert "Structured data file" in reason

    def test_scientific_csv_data(self):
        """Test scenario: CSV of scientific experiment results"""
        file_type = ".csv"
        # Should skip Raptor
        assert should_skip_raptor(file_type) is True
        reason = get_skip_reason(file_type)
        assert ".csv" in reason

    def test_legal_contract_with_tables(self):
        """Test scenario: legal contract PDF with tables"""
        file_type = ".pdf"
        parser_id = "table"
        parser_config = {}
        # Should skip Raptor
        assert should_skip_raptor(file_type, parser_id, parser_config) is True
        reason = get_skip_reason(file_type, parser_id, parser_config)
        assert "Tabular PDF" in reason

    def test_text_heavy_pdf_document(self):
        """Test scenario: text-heavy PDF document"""
        file_type = ".pdf"
        parser_id = "naive"
        parser_config = {}
        # Should NOT skip Raptor
        assert should_skip_raptor(file_type, parser_id, parser_config) is False
        reason = get_skip_reason(file_type, parser_id, parser_config)
        assert reason == ""

    def test_mixed_dataset_processing(self):
        """Test scenario: mixed dataset with various file types"""
        files = [
            (".xlsx", "naive", {}, True),   # Excel - skip
            (".csv", "naive", {}, True),    # CSV - skip
            (".pdf", "table", {}, True),    # Tabular PDF - skip
            (".pdf", "naive", {}, False),   # Regular PDF - don't skip
            (".docx", "naive", {}, False),  # Word doc - don't skip
            (".txt", "naive", {}, False),   # Text file - don't skip
        ]
        for file_type, parser_id, parser_config, expected_skip in files:
            result = should_skip_raptor(file_type, parser_id, parser_config)
            assert result == expected_skip, f"Failed for {file_type}"

    def test_override_for_special_excel(self):
        """Test scenario: override auto-disable for special Excel processing"""
        file_type = ".xlsx"
        raptor_config = {"auto_disable_for_structured_data": False}
        # Should NOT skip when auto-disable is explicitly turned off
        assert should_skip_raptor(file_type, raptor_config=raptor_config) is False


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
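
# Usage note: the __main__ guard above means this module runs directly with
# "python <this file>" as well as under pytest; both invocations execute the
# same suite, with "-v" giving per-test verbose output.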
