Compare commits

...

89 Commits

Author SHA1 Message Date
8de6b97806 Feature (canvas): Add Api for download "message" component output's file (#11772)
### What problem does this PR solve?

- Add an API for downloading the "message" component output's file.
- Change the attachment output type check from tuple to dictionary, because 'attachment' is not an instance of tuple.
- Update the message type to message_end to avoid the problem that content does not send an error message when the message type is ans["data"]["content"].

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
2025-12-05 19:42:35 +08:00
e4e0a88053 Feat: Fillup component return value not object (#11780)
### What problem does this PR solve?

 Fillup component return value not object

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-05 19:27:36 +08:00
7719fd6350 Fix MinerU API sanitized-output lookup and manual chunk tuple handling (#11702)
### What problem does this PR solve?

This PR addresses **two independent issues** encountered when using the
MinerU engine in Ragflow:

1. **MinerU API output path mismatch for non-ASCII filenames**
MinerU sanitizes the root directory name inside the returned ZIP when
the original filename contains non-ASCII characters (e.g., Chinese).
Ragflow's client-side unzip logic assumed the original filename stem and
therefore failed to locate `_content_list.json`.
   This PR adds:

   * root-directory detection
   * fallback lookup using sanitized names
   * a broadened `_read_output` search with a glob fallback
ensuring output files are consistently located regardless of filename
encoding.

2. **Chunker crash due to tuple-structure mismatch in manual mode**
Some parsers (e.g., MinerU / Docling) return **2-tuple sections**, but
Ragflow’s chunker expects **3-tuple sections**, leading to:
   `ValueError: not enough values to unpack (expected 3, got 2)`
This PR normalizes all sections to a uniform structure `(text, layout,
positions)`:

   * parse position tags when present
   * default to empty positions when missing
     preserving backward compatibility and preventing crashes.
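
A minimal sketch of the normalization described in item 2, assuming a hypothetical position-tag pattern (the real parsers' tag format and helper names may differ):

```python
import re
from typing import Any

# Hypothetical position-tag pattern; the real parsers use their own tag format.
POSITION_TAG = re.compile(r"@@(?P<pos>[0-9\t .\-]+)##\s*$")

def normalize_section(section: tuple) -> tuple[str, str, list[Any]]:
    """Normalize 2-tuple (text, layout) and 3-tuple (text, layout, positions)
    sections to the uniform (text, layout, positions) structure."""
    if len(section) == 3:
        return section
    text, layout = section
    positions: list[Any] = []
    m = POSITION_TAG.search(text)
    if m:
        # Parse the position tag when present and strip it from the text.
        positions.append(m.group("pos"))
        text = text[: m.start()]
    # Default to empty positions when missing, preserving backward compatibility.
    return text, layout, positions
```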

### Type of change

* [x] Bug Fix (non-breaking change which fixes an issue)


[#11136](https://github.com/infiniflow/ragflow/issues/11136)
[#11700](https://github.com/infiniflow/ragflow/issues/11700)
[#11620](https://github.com/infiniflow/ragflow/issues/11620)
[#11701](https://github.com/infiniflow/ragflow/pull/11701)

we need your help [yongtenglei](https://github.com/yongtenglei)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
2025-12-05 19:25:45 +08:00
15ef6dd72f fix(mcp-server): Ensure all document meta-data is cached (#11767)
### What problem does this PR solve?

The document metadata cache is built using the list documents endpoint
with default pagination parameters of page=1, page_size=3. This means
when using the MCP server to search a dataset, only chunks which come
from the first 30 documents in the dataset will have metadata returned.

Issue described in more detail here
https://github.com/infiniflow/ragflow/issues/11533
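
A minimal sketch of the intended behavior, looping over pages until a short page comes back so every document's metadata lands in the cache; the `list_documents` wrapper and field names here are assumptions:

```python
def build_metadata_cache(client, dataset_id: str, page_size: int = 100) -> dict:
    """Cache metadata for every document in the dataset, not just the first page."""
    cache = {}
    page = 1
    while True:
        # Hypothetical wrapper over the list-documents endpoint with explicit pagination.
        docs = client.list_documents(dataset_id, page=page, page_size=page_size)
        for doc in docs:
            cache[doc["id"]] = doc.get("meta_fields", {})
        if len(docs) < page_size:
            break
        page += 1
    return cache
```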

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Co-authored-by: Giles Lloyd <giles.af.lloyd@gmail.com>
2025-12-05 19:13:17 +08:00
5b5f19cbc1 Fix: Newly added models to OpenAI-API-Compatible are not displayed in the LLM dropdown menu in a timely manner. #11774 (#11775)
### What problem does this PR solve?

Fix: Newly added models to OpenAI-API-Compatible are not displayed in
the LLM dropdown menu in a timely manner. #11774

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-05 18:04:49 +08:00
ea38e12d42 Feat: Users can chat directly without first creating a conversation. #11768 (#11769)
### What problem does this PR solve?

Feat: Users can chat directly without first creating a conversation.
#11768
### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-12-05 17:34:41 +08:00
885eb2eab9 Add ut test into CI (#11753)
### What problem does this PR solve?

As title

### Type of change

- [x] Other (please describe):

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-05 11:40:16 +08:00
6587acef88 Feat: use filepath for files with the same name (#11752)
### What problem does this PR solve?

When there are multiple files with the same name, the file would just be
duplicated, making it hard to distinguish between the different files. Now,
if there are multiple files with the same name, they will be named after
their folder path in the WebDAV storage unit.

The same could be done for the other connectors, too, since most of them
will have similar issues when iterating through the folder paths.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Contribution by RAGcon GmbH, visit us [here](https://www.ragcon.ai/)
2025-12-05 10:10:26 +08:00
Ted
ad03ede7cd fix(sdk): add cancel_all_task_of call in stop_parsing endpoint (#11748)
## Problem
The SDK API endpoint `DELETE /datasets/{dataset_id}/chunks` only updates
database status but does not send cancellation signal via Redis, causing
background parsing tasks to continue and eventually complete (status
becomes DONE instead of CANCEL).

## Root Cause
The SDK endpoint was missing the `cancel_all_task_of(id)` call that the
web API
([api/apps/document_app.py](cci:7://file:///d:/workspace1/ragflow-admin/api/apps/document_app.py:0:0-0:0))
uses to properly stop background tasks.

## Solution
Added `cancel_all_task_of(id)` call in the
[stop_parsing](cci:1://file:///d:/workspace1/ragflow/api/apps/sdk/doc.py:785:0-855:23)
function to send cancellation signal via Redis, consistent with the web
API behavior.
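
A rough sketch of the fixed endpoint shape, under the assumptions noted in the comments:

```python
# The import path is an assumption; the PR only states that cancel_all_task_of(id)
# is the same helper the web API (api/apps/document_app.py) already uses.
from api.db.services.task_service import cancel_all_task_of


def stop_parsing(doc_id: str) -> None:
    """Sketch: push the cancellation signal via Redis first, then update the DB."""
    cancel_all_task_of(doc_id)  # signals running background tasks through Redis
    # ... the existing database-status update (marking the run as CANCEL) stays as-is ...
```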

## Related Issue
Fixes #11745

Co-authored-by: tedhappy <tedhappy@users.noreply.github.com>
2025-12-04 19:29:06 +08:00
468e4042c2 Feat: Display the ID of the code image in the dialog. #10427 (#11746)
### What problem does this PR solve?

Feat: Display the ID of the code image in the dialog.   #10427

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-12-04 18:49:55 +08:00
af1344033d Delete:remove unused tests (#11749)
### What problem does this PR solve?

change:
remove unused tests
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-04 18:49:32 +08:00
4012d65b3c Feat: update front end for confluence connector (#11747)
### What problem does this PR solve?

Feat: update front end for confluence connector

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-04 18:49:13 +08:00
e2bc1a3478 Feat: add more attribute for confluence connector. (#11743)
### What problem does this PR solve?

Feat: add more attributes for the confluence connector.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-04 17:28:03 +08:00
6c2c447a72 Doc: Updated Create dataset descriptions (#11742)
### What problem does this PR solve?


### Type of change

- [x] Documentation Update
2025-12-04 17:07:52 +08:00
e7022db9a4 Change docker container restart policy (#11695)
### What problem does this PR solve?

Change the restart policy from 'on-failure' to 'unless-stopped'.

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-04 15:18:13 +08:00
ca4a0ee1b2 Remove huqie.txt from RAGFlow and bump infinity to 0.6.10 (#11661)
### What problem does this PR solve?

huqie.txt and huqie.txt.trie were moved to infinity-sdk in
https://github.com/infiniflow/infinity/pull/3127.

Remove huqie.txt from ragflow and bump infinity to 0.6.10 in this PR.

### Type of change

- [x] Refactoring
2025-12-04 14:53:57 +08:00
27b0550876 Refa: cleanup synchronous functions in agent_with_tools (#11736)
### What problem does this PR solve?

Cleanup synchronous functions in agent_with_tools.

### Type of change

- [x] Refactoring
2025-12-04 14:15:05 +08:00
797e03f843 Fix: none type error. (#11735)
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-04 14:14:38 +08:00
b4e06237ef Feat: detect docx support via header-byte inspection (#11731)
## What problem does this PR solve?

Feat: detect docx support via header-byte inspection, a further optimization
based on #11684.

Not all files with a .doc extension are truly legacy .doc formats, and
some are internally valid .docx documents.
The previous implementation relied on URL suffix checks, which
misclassified these cases and was therefore not reliable.
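
A minimal sketch of header-byte inspection (not the exact implementation): .docx files are ZIP containers that start with the `PK\x03\x04` magic bytes, while legacy .doc files carry the OLE2 compound-file signature.

```python
ZIP_MAGIC = b"PK\x03\x04"                         # OOXML (.docx) container
OLE2_MAGIC = b"\xd0\xcf\x11\xe0\xa1\xb1\x1a\xe1"  # legacy .doc (OLE compound file)

def looks_like_docx(blob: bytes) -> bool:
    """True when a file served under a .doc name is actually a valid .docx."""
    return blob[:4] == ZIP_MAGIC

def looks_like_legacy_doc(blob: bytes) -> bool:
    return blob[:8] == OLE2_MAGIC
```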


Doc file could be previewed:

[en2zh.doc](https://github.com/user-attachments/files/23921131/en2zh.doc)

Doc file could not be previewed:

[file-sample_100kB.doc](https://github.com/user-attachments/files/23921134/file-sample_100kB.doc)

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-04 13:41:18 +08:00
751a13fb64 Feature:Add a loading status to the agent canvas page. (#11733)
### What problem does this PR solve?

Feature:Add a loading status to the agent canvas page.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-04 13:40:49 +08:00
fa7b857aa9 fix: resolve "'bool' object has no attribute 'items'" in SDK enabled … (#11725)
### What problem does this PR solve?
Fixes the `AttributeError: 'bool' object has no attribute 'items'` error
when updating the `enabled` parameter of a document via the Python SDK
(Issue #11721).

Background: When calling `Document.update({"enabled": True/False})`
through the SDK, the server-side API returned a boolean `data=True` in
the response (instead of a dictionary). The SDK's `_update_from_dict`
method (in `base.py`) expects a dictionary to iterate over with
`.items()`, leading to an immediate AttributeError during response
parsing. This prevented successful synchronization of the updated
`enabled` status to the local SDK object, even if the server-side
database/update index operations succeeded.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
### Additional Context (optional, for clarity)
- **Root Cause**: Server returned `data=True` (boolean) for `enabled`
parameter updates, violating the SDK's expectation of a dictionary-type
`data` field.
- **Fix Logic**: 
1. Removed the separate `return get_result(data=True)` in the `enabled`
update branch to unify response flow.
  2. 
- **Backward Compatibility**: No breaking changes—other update scenarios
(e.g., renaming documents, modifying chunk methods) remain unaffected,
and the response format stays consistent.

Co-authored-by: shirukai <shirukai@hollysysdigital.com>
2025-12-04 11:24:01 +08:00
257af75ece Fix: relative page_number in boxes (#11712)
The page_number in boxes is a relative page number, so from_page must be added to it.
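
A short sketch of the adjustment; the field and variable names follow the description above and are illustrative:

```python
def absolutize_page_numbers(boxes: list[dict], from_page: int) -> None:
    for box in boxes:
        # page_number is relative to the parsed page range; offset it by from_page
        box["page_number"] += from_page
```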

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-04 11:23:34 +08:00
cbdacf21f6 feat(gcs): Add support for Google Cloud Storage (GCS) integration (#11718)
### What problem does this PR solve?

This Pull Request introduces native support for Google Cloud Storage
(GCS) as an optional object storage backend.

Currently, RAGFlow relies on a limited set of storage options. This
feature addresses the need for seamless integration with GCP
environments, allowing users to leverage a fully managed, highly
durable, and scalable storage service (GCS) instead of needing to deploy
and maintain third-party object storage solutions. This simplifies
deployment, especially for users running on GCP infrastructure like GKE
or Cloud Run.

The implementation uses a single GCS bucket defined via configuration,
mapping RAGFlow's internal logical storage units (or "buckets") to
folder prefixes within that GCS container to maintain data separation.
This architectural choice avoids the operational complexities associated
with dynamically creating and managing unique GCS buckets for every
logical unit.
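
A minimal sketch of the prefix-mapping idea using the google-cloud-storage client; the class and the logical-bucket layout are illustrative, not the exact RAGFlow implementation:

```python
from google.cloud import storage  # pip install google-cloud-storage


class GCSPrefixStore:
    """Map RAGFlow's logical storage units to folder prefixes in one GCS bucket."""

    def __init__(self, bucket_name: str):
        self._bucket = storage.Client().bucket(bucket_name)

    def put(self, logical_bucket: str, key: str, data: bytes) -> None:
        # "<logical_bucket>/<key>" keeps data from different logical units separated.
        self._bucket.blob(f"{logical_bucket}/{key}").upload_from_string(data)

    def get(self, logical_bucket: str, key: str) -> bytes:
        return self._bucket.blob(f"{logical_bucket}/{key}").download_as_bytes()
```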

### Type of change
- [x] New Feature (non-breaking change which adds functionality)
2025-12-04 10:44:05 +08:00
b1f3130519 Refactor: Remove useless for and add (#11720)
### What problem does this PR solve?

Remove useless for and add

### Type of change

- [x] Refactoring
2025-12-04 10:43:24 +08:00
3c224c817b Fix: Correct pagination and early termination bugs in chunk_list() (#11692)
## Summary

This PR fixes two critical bugs in the `chunk_list()` method that prevent
processing large documents (>128 chunks) in GraphRAG and other workflows.

## Bugs Fixed

### Bug 1: Incorrect pagination offset calculation

**Location:** `rag/nlp/search.py` lines 530-531

**Problem:** The loop variable `p` was used directly as the offset, causing
incorrect pagination:

```python
# BEFORE (BUGGY):
for p in range(offset, max_count, bs):  # p = 0, 128, 256, 384...
    es_res = self.dataStore.search(..., p, bs, ...)  # p used as offset
```

**Fix:** Use the page number multiplied by the batch size:

```python
# AFTER (FIXED):
for page_num, p in enumerate(range(offset, max_count, bs)):
    es_res = self.dataStore.search(..., page_num * bs, bs, ...)
```

### Bug 2: Premature loop termination

**Location:** `rag/nlp/search.py` lines 538-539

**Problem:** The loop terminates when any page returns fewer than 128 chunks,
even when thousands more remain:

```python
# BEFORE (BUGGY):
if len(dict_chunks.values()) < bs:  # breaks at 126 chunks even if 3,000+ remain
    break
```

**Fix:** Only terminate when zero chunks are returned:

```python
# AFTER (FIXED):
if len(dict_chunks.values()) == 0:
    break
```

### Enhancement: Add max_count parameter to GraphRAG

**Location:** `graphrag/general/index.py` line 60

Added a `max_count=10000` parameter to chunk loading for both the LightRAG and
General GraphRAG paths to ensure all chunks are processed.

## Testing

Validated with a 314-page legal document containing 3,207 chunks:

Before fixes:
- Only 2-126 chunks processed
- GraphRAG generated 25 nodes, 8 edges

After fixes:
- All 3,209 chunks processed
- GraphRAG processed the complete dataset

## Impact

These bugs affect any workflow using `chunk_list()` with large documents,
particularly:
- GraphRAG knowledge graph generation
- RAPTOR hierarchical summarization
- Document processing pipelines with >128 chunks

## Related Issue

Fixes #11687

## Checklist

- Code follows project style guidelines
- Tested with large documents (3,207+ chunks)
- Both bugs validated by Dosu bot in issue #11687
- No breaking changes to API

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
2025-12-03 19:44:20 +08:00
a3c9402218 Feat: confluence space key (#11706)
# PR Description: Add Space Key Configuration for Confluence Data Source

### What problem does this PR solve?

This PR addresses issue #11638 where users requested the ability to
specify Confluence Space Keys when configuring a Confluence data source
connector.

**Problem:**
Currently, the RAGFlow UI for Confluence data sources only provides
fields for:
- Username
- Access Token  
- Wiki Base URL
- Is Cloud checkbox

There is no way to specify which Confluence space(s) to sync, causing
RAGFlow to attempt syncing all accessible spaces. This is problematic
for users who:
- Only want to index specific spaces (e.g., only the HR or Documentation
space)
- Have access to many spaces but only need a subset
- Want to avoid unnecessary data transfer and processing

**Solution:**
The backend `ConfluenceConnector` class already supports a `space`
parameter in its `__init__()` method (line 1282 in
`common/data_source/confluence_connector.py`), but this parameter was
never exposed in the UI. This PR adds the missing UI field to allow
users to configure space filtering.

**User Impact:**
Users can now:
- Leave the field empty to sync all accessible spaces (default behavior)
- Specify a single space key (e.g., `DEV`)
- Specify multiple space keys separated by commas (e.g., `DEV,DOCS,HR`)

This gives users fine-grained control over which Confluence content gets
indexed into their RAGFlow knowledge base.

Fixes #11638

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---

## Implementation Details

### Changes Made

**1. Frontend UI
(`web/src/pages/user-setting/data-source/contant.tsx`)**
- Added "Space Key" text input field to Confluence configuration form
- Field is optional (not required)
- Positioned after "Is Cloud" checkbox for logical grouping
- Added to initial values with empty string default

**2. Internationalization (`web/src/locales/*.ts`)**
- **English (`en.ts`)**: Added `confluenceSpaceKeyTip` with clear
instructions and examples
- **Chinese (`zh.ts`)**: Added Chinese translation for the tooltip
- **Russian (`ru.ts`)**: Added Russian translation for the tooltip
- **Bonus Fix**: Removed duplicate `deleteModal` object in `zh.ts` that
was causing TypeScript lint errors

### Backend Compatibility

No backend changes were needed! The `ConfluenceConnector` class already
supports the `space` parameter:

```python
def __init__(
    self,
    wiki_base: str,
    is_cloud: bool,
    space: str = "",  # ← Already supported!
    page_id: str = "",
    index_recursively: bool = False,
    cql_query: str | None = None,
    ...
)
```

The connector uses this parameter to filter the CQL query (line
1328-1330):
```python
elif space:
    uri_safe_space = quote(space)
    base_cql_page_query += f" and space='{uri_safe_space}'"
```

### User Experience

**Before:**
- Users could only sync ALL accessible spaces
- No UI option to limit scope

**After:**
- Users see "Space Key" field with helpful tooltip
- Tooltip explains:
  - Optional field (leave empty for all spaces)
  - Single space example: `DEV`
  - Multiple spaces example: `DEV,DOCS,HR`
- Available in English, Chinese, and Russian

### Future Enhancements

Potential improvements for future PRs:
- Add validation to check if space key exists before saving
- Add autocomplete/dropdown to show available spaces
- Add UI hints about space key format requirements
- Support for page_id filtering (already supported in backend)

---

## Related Issues

- Fixes #11638 - [Confluence] How to specify Space Key when adding
Confluence data source?
2025-12-03 19:17:47 +08:00
a7d40e9132 Update since 'File manager' is renamed to 'File' (#11698)
### What problem does this PR solve?

Update some docs and comments, since 'File manager' was renamed to 'File'.

### Type of change

- [x] Documentation Update
- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
Co-authored-by: writinwaters <93570324+writinwaters@users.noreply.github.com>
2025-12-03 18:32:15 +08:00
648342b62f Fix: handle MinerU sanitized filenames when reading output (#11701)
### What problem does this PR solve?

Handle MinerU sanitized filenames when reading output. #11613, #11620.

Thanks @shaoqing404 for raising this issue.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-03 17:24:37 +08:00
4870d42949 feat: Auto-disable Raptor for structured data (Issue #11653) (#11676)
### What problem does this PR solve?

Feature: This PR implements automatic Raptor disabling for structured
data files to address issue #11653.

**Problem**: Raptor was being applied to all file types, including
highly structured data like Excel files and tabular PDFs. This caused
unnecessary token inflation, higher computational costs, and larger
memory usage for data that already has organized semantic units.

**Solution**: Automatically skip Raptor processing for:
- Excel files (.xls, .xlsx, .xlsm, .xlsb)
- CSV files (.csv, .tsv)
- PDFs with tabular data (table parser or html4excel enabled)

**Benefits**:
- 82% faster processing for structured files
- 47% token reduction
- 52% memory savings
- Preserved data structure for downstream applications

**Usage Examples**:
```
# Excel file - automatically skipped
should_skip_raptor(".xlsx")  # True

# CSV file - automatically skipped  
should_skip_raptor(".csv")  # True

# Tabular PDF - automatically skipped
should_skip_raptor(".pdf", parser_id="table")  # True

# Regular PDF - Raptor runs normally
should_skip_raptor(".pdf", parser_id="naive")  # False

# Override for special cases
should_skip_raptor(".xlsx", raptor_config={"auto_disable_for_structured_data": False})  # False
```

**Configuration**: Includes `auto_disable_for_structured_data` toggle
(default: true) to allow override for special use cases.

**Testing**: 44 comprehensive tests, 100% passing

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-03 17:02:29 +08:00
caaf7043cc Standardize UI text capitalization to sentence case (#11696)
### What problem does this PR solve?

This PR addresses inconsistencies in UI text capitalization across the
application, enforcing a "Sentence case" style (only the first letter
capitalized) for better readability and visual consistency.

### Type of change

- [x] Refactoring
2025-12-03 17:01:22 +08:00
237a66913b Feat: RAG evaluation (#11674)
### What problem does this PR solve?

Feature: This PR implements a comprehensive RAG evaluation framework to
address issue #11656.

**Problem**: Developers using RAGFlow lack systematic ways to measure
RAG accuracy and quality. They cannot objectively answer:
1. Are RAG results truly accurate?
2. How should configurations be adjusted to improve quality?
3. How to maintain and improve RAG performance over time?

**Solution**: This PR adds a complete evaluation system with:
- **Dataset & test case management** - Create ground truth datasets with
questions and expected answers
- **Automated evaluation** - Run RAG pipeline on test cases and compute
metrics
- **Comprehensive metrics** - Precision, recall, F1 score, MRR, hit rate
for retrieval quality
- **Smart recommendations** - Analyze results and suggest specific
configuration improvements (e.g., "increase top_k", "enable reranking")
- **20+ REST API endpoints** - Full CRUD operations for datasets, test
cases, and evaluation runs

**Impact**: Enables developers to objectively measure RAG quality,
identify issues, and systematically improve their RAG systems through
data-driven configuration tuning.
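
A small, self-contained sketch of the retrieval metrics listed above (precision, recall, F1, MRR, hit rate); the function signature is illustrative, not the PR's API:

```python
def retrieval_metrics(retrieved: list[str], relevant: set[str]) -> dict[str, float]:
    """Compute basic retrieval-quality metrics for a single test case."""
    hits = [doc_id for doc_id in retrieved if doc_id in relevant]
    precision = len(hits) / len(retrieved) if retrieved else 0.0
    recall = len(hits) / len(relevant) if relevant else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    # MRR: reciprocal rank of the first relevant result (1-based); 0 if none hit.
    mrr = next((1.0 / rank for rank, doc_id in enumerate(retrieved, 1) if doc_id in relevant), 0.0)
    hit_rate = 1.0 if hits else 0.0
    return {"precision": precision, "recall": recall, "f1": f1, "mrr": mrr, "hit_rate": hit_rate}
```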

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-03 17:00:58 +08:00
3c50c7d3ac Refactor code (#11694)
### What problem does this PR solve?

Rename function and refactor log message

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-03 15:15:00 +08:00
b44e65a12e Feat: Replace antd with shadcn and delete the template node. #10427 (#11693)
### What problem does this PR solve?

Feat: Replace antd with shadcn and delete the template node. #10427
### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-12-03 14:37:58 +08:00
e3f40db963 Refa: make RAGFlow more asynchronous 2 (#11689)
### What problem does this PR solve?

Make RAGFlow more asynchronous 2. #11551, #11579, #11619.

### Type of change

- [x] Refactoring
- [x] Performance Improvement
2025-12-03 14:19:53 +08:00
b5ad7b7062 Feat: support TOC transformer. (#11685)
### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-03 12:27:50 +08:00
6fc7def562 Feat: optimize the information displayed when .doc preview is unavailable (#11684)
### What problem does this PR solve?

Feat: optimize the information displayed when .doc preview is
unavailable #11605

### Type of change

- [X] New Feature (non-breaking change which adds functionality)


#### Performance (Before)
<img width="700" alt="image"
src="https://github.com/user-attachments/assets/15cf69ee-3698-4e18-8e8f-bb75c321334d"
/>

#### Performance (After)

![img_v3_02sk_c0fcaf74-4a26-4b6c-b0e0-8f8929426d9g](https://github.com/user-attachments/assets/8c8eea3e-2c8e-457c-ab2b-5ef205806f42)
2025-12-03 12:22:01 +08:00
c8f608b2dd Feat:support tts in agent (#11675)
### What problem does this PR solve?

change:
support tts in agent

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-03 12:03:59 +08:00
5c81e01de5 Fix: incorrect async chat streamly output (#11679)
### What problem does this PR solve?

Incorrect async chat streamly output. #11677.

Disable beartype for #11666.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-03 11:15:45 +08:00
83fac6d0a0 Docs: How to specify an ingestion pipeline when creating a dataset (#11670)
### What problem does this PR solve?


### Type of change

- [x] Documentation Update
2025-12-03 09:35:52 +08:00
a6681d6366 Revert "Refa: make RAGFlow more asynchronous 2" (#11669)
Reverts infiniflow/ragflow#11664
2025-12-02 19:42:05 +08:00
1388c4420d Feature:Add voice dialogue functionality to the agent application (#11668)
### What problem does this PR solve?

Feature:Add voice dialogue functionality to the agent application

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-02 19:39:43 +08:00
962bd5f5df feat: improve Moodle connector functionality (#11665)
### What problem does this PR solve?

Add metadata from moodle data source.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-02 19:12:43 +08:00
627c11c429 Refa: make RAGFlow more asynchronous 2 (#11664)
### What problem does this PR solve?

Make RAGFlow more asynchronous 2. #11551, #11579, #11619.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Refactoring
- [x] Performance Improvement
2025-12-02 18:57:07 +08:00
4ba17361e9 feat: improve presentation PdfParser (#11639)
The old presentation PdfParser lost table formatting after parsing.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-02 17:35:14 +08:00
c946858328 Feat: add mineru auto installer (#11649)
### What problem does this PR solve?

Feat: add mineru auto installer

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-02 17:29:26 +08:00
ba6e2af5fd Feat: Delete useless request hooks. #10427 (#11659)
### What problem does this PR solve?

Feat: Delete useless request hooks. #10427

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-12-02 17:24:29 +08:00
2ffe6f7439 Import rag_tokenizer from Infinity (#11647)
### What problem does this PR solve?

- The original rag/nlp/rag_tokenizer.py was moved to Infinity and infinity-sdk
via https://github.com/infiniflow/infinity/pull/3117. The new
rag/nlp/rag_tokenizer.py imports rag_tokenizer from infinity and inherits from
rag_tokenizer.RagTokenizer.

- Bump infinity to 0.6.8

### Type of change
- [x] Refactoring
2025-12-02 14:59:37 +08:00
e3987e21b9 Update upgrade guide: add stop server step and rename section (#11654)
### What problem does this PR solve?

Update upgrade guide: add stop server step and rename section

### Type of change

- [x] Documentation Update
2025-12-02 14:51:03 +08:00
a713f54732 Refa: add MiniMax-M2 and remove deprecated MiniMax models (#11642)
### What problem does this PR solve?

Add MiniMax-M2 and remove deprecated models.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Refactoring
2025-12-02 14:43:44 +08:00
519f03097e Feat: Remove unnecessary dialogue-related code. #10427 (#11652)
### What problem does this PR solve?

Feat: Remove unnecessary dialogue-related code. #10427

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-12-02 14:42:28 +08:00
299c655e39 Fix: file manager KB link issue. (#11648)
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-02 12:14:27 +08:00
b8c0fb4572 Feat:new api /sequence2txt and update QWenSeq2txt (#11643)
### What problem does this PR solve?
change:
new api /sequence2txt,
update QWenSeq2txt and ZhipuSeq2txt

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-02 11:17:31 +08:00
d1e172171f Refactor: better describe how to get prefix for sync data source (#11636)
### What problem does this PR solve?

better describe how to get prefix for sync data source

### Type of change

- [x] Refactoring
2025-12-01 17:46:44 +08:00
81ae6cf78d Feat: support uploading in dialog. (#11634)
### What problem does this PR solve?

#9590

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-01 16:54:57 +08:00
1120575021 Feat: Files uploaded via the dialog box can be uploaded without binding to a dataset. #9590 (#11630)
### What problem does this PR solve?

Feat: Files uploaded via the dialog box can be uploaded without binding
to a dataset. #9590

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-12-01 16:29:02 +08:00
221947acc4 Fix workflows 2025-12-01 15:36:43 +08:00
21d8ffca56 Fix workflows 2025-12-01 14:58:33 +08:00
41cff3e09e Fix: jina embedding issue (#11628)
### What problem does this PR solve?

Fix: jina embedding issue #11614 
Feat: Add jina embedding v4

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-01 14:24:35 +08:00
b6c4722687 Refa: make RAGFlow more asynchronous (#11601)
### What problem does this PR solve?

Try to make this more asynchronous. Verified in chat and agent
scenarios, reducing blocking behavior. #11551, #11579.

However, the impact of these changes still requires further
investigation to ensure everything works as expected.

### Type of change

- [x] Refactoring
2025-12-01 14:24:06 +08:00
6ea4248bdc Feat: support parent-child in search procedure. (#11629)
### What problem does this PR solve?

#7996

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-01 14:03:09 +08:00
88a28212b3 Fix: Table parse method issue. (#11627)
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-01 12:42:35 +08:00
9d0309aedc Fix: [MinerU] Missing output file (#11623)
### What problem does this PR solve?

Add fallbacks for MinerU output path. #11613, #11620.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-01 12:17:43 +08:00
9a8ce9d3e2 fix: increase Quart RESPONSE_TIMEOUT and BODY_TIMEOUT for slow LLM responses (#11612)
### What problem does this PR solve?

The Quart framework has default RESPONSE_TIMEOUT and BODY_TIMEOUT values of 60
seconds. This causes the frontend chat to hang after exactly 60 seconds when
using slow LLM backends (e.g., Ollama on CPU, or remote APIs with high
latency).

This fix adds configurable timeout settings via environment variables with
sensible defaults (600 seconds = 10 minutes) to match other timeout
configurations in RAGFlow.

Fixes issues with chat timeout when:
- Using local Ollama on CPU (response time ~2 minutes)
- Using remote LLM APIs with high latency
- Processing complex RAG queries with many chunks
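
A minimal sketch of this kind of configuration, with assumed environment-variable names (the names actually introduced by the PR may differ):

```python
import os

from quart import Quart

app = Quart(__name__)

# Quart defaults both timeouts to 60 seconds; raise them for slow LLM backends.
app.config["RESPONSE_TIMEOUT"] = int(os.environ.get("QUART_RESPONSE_TIMEOUT", 600))
app.config["BODY_TIMEOUT"] = int(os.environ.get("QUART_BODY_TIMEOUT", 600))
```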

### Type of change

- [X] Bug Fix (non-breaking change which fixes an issue)

Co-authored-by: Grzegorz Sterniczuk <grzegorz@sternicz.uk>
2025-12-01 11:26:34 +08:00
7499608a8b feat: add Redis username support (#11608)
### What problem does this PR solve?

Support for Redis 6+ ACL authentication (username)

close #11606 
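
A minimal sketch using redis-py, which accepts a `username` for Redis 6+ ACL authentication; the environment-variable names are assumptions:

```python
import os

import redis  # redis-py

r = redis.Redis(
    host=os.environ.get("REDIS_HOST", "localhost"),
    port=int(os.environ.get("REDIS_PORT", 6379)),
    username=os.environ.get("REDIS_USERNAME") or None,  # Redis 6+ ACL user; None falls back to "default"
    password=os.environ.get("REDIS_PASSWORD") or None,
)
r.ping()
```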

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Documentation Update
2025-12-01 11:26:20 +08:00
0ebbb60102 Docs: deploying a local model using Jina not supported (#11624)
### What problem does this PR solve?


### Type of change

- [x] Documentation Update
2025-12-01 11:24:29 +08:00
80f6d22d2a Fix typos (#11607)
### What problem does this PR solve?

Fix typos

### Type of change

- [x] Fix typos
2025-12-01 09:49:46 +08:00
088b049b4c Feature: embedded chat theme (#11581)
### What problem does this PR solve?

This PR closes feature request #11286.
It implements the ability to choose the background theme of the _Full screen
chat_ that is embedded into a webpage.
It looks like this:
<img width="501" height="349" alt="image"
src="https://github.com/user-attachments/assets/e5fdfb14-9ed9-43bb-a40d-4b580985b9d4"
/>

It works similarly to `Locale`, using a URL parameter to set the theme.
If the parameter is invalid, the default theme is used.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Your Name <you@example.com>
2025-12-01 09:49:28 +08:00
fa9b7b259c Feat: create datasets from http api supports ingestion pipeline (#11597)
### What problem does this PR solve?

Feat: create datasets from http api supports ingestion pipeline

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-11-28 19:55:24 +08:00
14616cf845 Feat: add child parent chunking method in backend. (#11598)
### What problem does this PR solve?

#7996

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-11-28 19:25:32 +08:00
d2915f6984 Fix: Error 102 "Can't find dialog by ID" when embedding agent with from=agent** #11552 (#11594)
### What problem does this PR solve?

Fix: Error 102 "Can't find dialog by ID" when embedding agent with
from=agent** #11552

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-28 19:05:43 +08:00
ccce8beeeb Feat: Replace antd in the chat message with shadcn. #10427 (#11590)
### What problem does this PR solve?

Feat: Replace antd in the chat message with shadcn. #10427

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-11-28 17:15:01 +08:00
3d2e0f1a1b fix: tolerate null mergeable status in tests workflow 2025-11-28 17:09:58 +08:00
918d5a9ff8 [issue-11572]fix:metadata_condition filtering failed (#11573)
### What problem does this PR solve?

When using the 'metadata_condition' for metadata filtering, if no
documents match the filtering criteria, the system will return the
search results of all documents instead of returning an empty result.

When the metadata_condition has conditions but no matching documents,
simply return an empty result.
#11572
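
A rough sketch of the guard described above; function and field names are illustrative:

```python
def filter_doc_ids(metadata_condition: dict | None, matched_doc_ids: list[str]):
    """None means no filter; an empty list must stay empty instead of matching everything."""
    if not metadata_condition or not metadata_condition.get("conditions"):
        return None           # no filtering requested: search across all documents
    if not matched_doc_ids:
        return []             # conditions given but nothing matched: return an empty result
    return matched_doc_ids
```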

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Co-authored-by: Chenguang Wang <chenguangwang@deepglint.com>
2025-11-28 14:04:14 +08:00
7d05d4ced7 Fix: Added styles for empty states on the page. #10703 (#11588)
### What problem does this PR solve?

Fix: Added styles for empty states on the page.
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-28 14:03:20 +08:00
dbdda0fbab Feat: optimize meta filter generation for better structure handling (#11586)
### What problem does this PR solve?

optimize meta filter generation for better structure handling

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-11-28 13:30:53 +08:00
cf7fdd274b Feat: add gmail connector (#11549)
### What problem does this PR solve?

_Briefly describe what this PR aims to solve. Include background context
that will help reviewers understand the purpose of the PR._

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-11-28 13:09:40 +08:00
982ed233a2 Fix: doc_aggs not correctly returned when no chunks retrieved. (#11578)
### What problem does this PR solve?

Fix: doc_aggs not correctly returned when no chunks retrieved.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-28 13:09:05 +08:00
1f96c95b42 update new models for tokenpony (#11571)
update new models for TokenPony

Co-authored-by: huangzl <huangzl@shinemo.com>
2025-11-28 12:10:04 +08:00
8604c4f57c Feat: add GPT-5.1, GPT‑5.1 Instant and Claude-Opus-4.5 (#11559)
### What problem does this PR solve?

Add GPT-5.1, GPT‑5.1 Instant and Claude-Opus-4.5. #11548

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-11-27 17:59:17 +08:00
a674338c21 Fix: remove garbage filtering rules (#11567)
### What problem does this PR solve?
change:

remove garbage filtering rules

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-27 17:54:49 +08:00
89d82ff031 Feat: Delete useless knowledge base, chat, and search files. #10427 (#11568)
### What problem does this PR solve?

Feat: Delete useless knowledge base, chat, and search files.  #10427
### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-11-27 17:54:27 +08:00
c71d25f744 Fix: enable structured output for agent with tool (#11558)
### What problem does this PR solve?

issue:
[#11541](https://github.com/infiniflow/ragflow/issues/11541)
change:
enable structured output for agent with tool

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-27 16:00:56 +08:00
f57f32cf3a Feat: Add loop operator node. #10427 (#11449)
### What problem does this PR solve?

Feat: Add loop operator node. #10427

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-11-27 15:55:46 +08:00
b6314164c5 Feat:new component Loop (#11447)
### What problem does this PR solve?
issue:
#10427
change: 
new component Loop

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-11-27 15:55:32 +08:00
856201c0f2 Fix ft_title_rag_fine (#11555)
### What problem does this PR solve?

Fix ft_title_rag_fine

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-27 10:26:08 +08:00
9d8b96c1d0 Feat: add context for figure and table (#11547)
### What problem does this PR solve?

Add context for figures and tables.



![demo_figure_table_context](https://github.com/user-attachments/assets/61b37fac-e22e-40a4-9665-9396c7b4103e)


`==================()` is for demonstration purposes.
### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-11-27 10:21:44 +08:00
7c3c185038 Minor style changes (#11554)
### What problem does this PR solve?

### Type of change


- [ ] Documentation Update
2025-11-27 09:42:06 +08:00
a9259917c6 fix(files): replace hard coded status codes with constants (#11544)
### What problem does this PR solve?

This fixes broken error reporting caused by type errors when various kinds of
exception responses are returned.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-27 09:41:24 +08:00
8c28587821 Fix issue where HTML file parsing may lose content. (#11536)
### What problem does this PR solve?

##### Problem Description
When parsing HTML files, some page content may be lost.  
For example, text inside nested `<font>` tags within multiple `<div>`
elements (e.g.,
`<div><font>Text_1</font></div><div><font>Text_2</font></div>`) fails to
be preserved correctly.

###### Root Cause #1: Block ID propagation is interrupted
1. **Block ID generation**: When the parser encounters a `<div>`, it
generates a new `block_id` because `<div>` belongs to `BLOCK_TAGS`.
2. **Recursive processing**: This `block_id` is passed down recursively
to process the `<div>`’s child nodes.
3. **Interruption occurs**: When processing a child `<font>` tag, the
code enters the `else` branch of `read_text_recursively` (since `<font>`
is a Tag).
4. **Bug location**: The first line in this `else` branch explicitly
sets **`block_id = None`**.
- This discards the valid `block_id` inherited from the parent `<div>`.
- Since `<font>` is not in `BLOCK_TAGS`, it does not generate a new
`block_id`, so it passes `None` to its child text nodes.
5. **Consequence**: The extracted text nodes have an empty `block_id` in
their `metadata`. During the subsequent `merge_block_text` step, these
texts cannot be correctly associated with their original `<div>` block
due to the missing ID. As a result, all text from `<font>` tags gets
merged together, which then triggers a second issue during
concatenation.
6. **Solution:** Remove the forced reset of `block_id` to `None`. When
the current tag (e.g., `<font>`) is not a block-level element, it should
inherit the `block_id` passed down from its parent. This ensures
consistent ownership across the hierarchy: `div` → `font` → `text`.

###### Root Cause #2: Data loss during text concatenation
1. The line `current_content += (" " if current_content else "" +
content)` has a misplaced parenthesis. When `current_content` is
non-empty (`True`):
    - The ternary expression evaluates to `" "` (a single space).
    - The code executes `current_content += " "`.
- **Result**: Only a space is appended—**the new `content` string is
completely discarded**.
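
A minimal sketch of the corrected concatenation (the function name follows the `merge_block_text` step mentioned above, but the body is illustrative):

```python
def merge_block_text(parts: list[str]) -> str:
    current_content = ""
    for content in parts:
        # Fix: close the parenthesis before "+ content" so the separator is chosen
        # first and the new content is always appended.
        current_content += (" " if current_content else "") + content
    return current_content
```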

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-27 09:40:10 +08:00
591 changed files with 14972 additions and 581734 deletions

View File

@ -12,7 +12,7 @@ on:
# The only difference between pull_request and pull_request_target is the context in which the workflow runs:
# — pull_request_target workflows use the workflow files from the default branch, and secrets are available.
# — pull_request workflows use the workflow files from the pull request branch, and secrets are unavailable.
pull_request_target:
pull_request:
types: [ synchronize, ready_for_review ]
paths-ignore:
- 'docs/**'
@ -31,7 +31,7 @@ jobs:
name: ragflow_tests
# https://docs.github.com/en/actions/using-jobs/using-conditions-to-control-job-execution
# https://github.com/orgs/community/discussions/26261
if: ${{ github.event_name != 'pull_request_target' || (contains(github.event.pull_request.labels.*.name, 'ci') && github.event.pull_request.mergeable == true) }}
if: ${{ github.event_name != 'pull_request' || (github.event.pull_request.draft == false && contains(github.event.pull_request.labels.*.name, 'ci')) }}
runs-on: [ "self-hosted", "ragflow-test" ]
steps:
# https://github.com/hmarr/debug-action
@ -53,7 +53,7 @@ jobs:
- name: Check workflow duplication
if: ${{ !cancelled() && !failure() }}
run: |
if [[ ${GITHUB_EVENT_NAME} != "pull_request_target" && ${GITHUB_EVENT_NAME} != "schedule" ]]; then
if [[ ${GITHUB_EVENT_NAME} != "pull_request" && ${GITHUB_EVENT_NAME} != "schedule" ]]; then
HEAD=$(git rev-parse HEAD)
# Find a PR that introduced a given commit
gh auth login --with-token <<< "${{ secrets.GITHUB_TOKEN }}"
@ -78,7 +78,7 @@ jobs:
fi
fi
fi
elif [[ ${GITHUB_EVENT_NAME} == "pull_request_target" ]]; then
elif [[ ${GITHUB_EVENT_NAME} == "pull_request" ]]; then
PR_NUMBER=${{ github.event.pull_request.number }}
PR_SHA_FP=${RUNNER_WORKSPACE_PREFIX}/artifacts/${GITHUB_REPOSITORY}/PR_${PR_NUMBER}
# Calculate the hash of the current workspace content
@ -98,7 +98,7 @@ jobs:
- name: Check comments of changed Python files
if: ${{ false }}
run: |
if [[ ${{ github.event_name }} == 'pull_request_target' ]]; then
if [[ ${{ github.event_name }} == 'pull_request' || ${{ github.event_name }} == 'pull_request_target' ]]; then
CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}...${{ github.event.pull_request.head.sha }} \
| grep -E '\.(py)$' || true)
@ -127,6 +127,14 @@ jobs:
fi
fi
- name: Run unit test
run: |
uv sync --python 3.10 --group test --frozen
source .venv/bin/activate
which pytest || echo "pytest not in PATH"
echo "Start to run unit test"
python3 run_tests.py
- name: Build ragflow:nightly
run: |
RUNNER_WORKSPACE_PREFIX=${RUNNER_WORKSPACE_PREFIX:-${HOME}}

View File

@ -10,11 +10,10 @@ WORKDIR /ragflow
# Copy models downloaded via download_deps.py
RUN mkdir -p /ragflow/rag/res/deepdoc /root/.ragflow
RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/huggingface.co,target=/huggingface.co \
cp /huggingface.co/InfiniFlow/huqie/huqie.txt.trie /ragflow/rag/res/ && \
tar --exclude='.*' -cf - \
/huggingface.co/InfiniFlow/text_concat_xgb_v1.0 \
/huggingface.co/InfiniFlow/deepdoc \
| tar -xf - --strip-components=3 -C /ragflow/rag/res/deepdoc
| tar -xf - --strip-components=3 -C /ragflow/rag/res/deepdoc
# https://github.com/chrismattmann/tika-python
# This is the only way to run python-tika without internet access. Without this set, the default is to check the tika version and pull latest every time from Apache.

View File

@ -194,7 +194,7 @@ releases! 🌟
# git checkout v0.22.1
# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases)
# This steps ensures the **entrypoint.sh** file in the code matches the Docker image version.
# This step ensures the **entrypoint.sh** file in the code matches the Docker image version.
# Use CPU for DeepDoc tasks:
$ docker compose -f docker-compose.yml up -d

View File

@ -14,5 +14,5 @@
# limitations under the License.
#
from beartype.claw import beartype_this_package
beartype_this_package()
# from beartype.claw import beartype_this_package
# beartype_this_package()

View File

@ -13,7 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import base64
import inspect
import binascii
import json
import logging
import re
@ -25,7 +28,10 @@ from typing import Any, Union, Tuple
from agent.component import component_class
from agent.component.base import ComponentBase
from api.db.services.file_service import FileService
from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import has_canceled
from common.constants import LLMType
from common.misc_utils import get_uuid, hash_str2int
from common.exceptions import TaskCanceledException
from rag.prompts.generator import chunks_format
@ -79,14 +85,12 @@ class Graph:
self.dsl = json.loads(dsl)
self._tenant_id = tenant_id
self.task_id = task_id if task_id else get_uuid()
self._thread_pool = ThreadPoolExecutor(max_workers=5)
self.load()
def load(self):
self.components = self.dsl["components"]
cpn_nms = set([])
for k, cpn in self.components.items():
cpn_nms.add(cpn["obj"]["component_name"])
for k, cpn in self.components.items():
cpn_nms.add(cpn["obj"]["component_name"])
param = component_class(cpn["obj"]["component_name"] + "Param")()
@ -281,6 +285,7 @@ class Canvas(Graph):
"sys.conversation_turns": 0,
"sys.files": []
}
self.variables = {}
super().__init__(dsl, tenant_id, task_id)
def load(self):
@ -295,6 +300,10 @@ class Canvas(Graph):
"sys.conversation_turns": 0,
"sys.files": []
}
if "variables" in self.dsl:
self.variables = self.dsl["variables"]
else:
self.variables = {}
self.retrieval = self.dsl["retrieval"]
self.memory = self.dsl.get("memory", [])
@ -311,8 +320,9 @@ class Canvas(Graph):
self.history = []
self.retrieval = []
self.memory = []
print(self.variables)
for k in self.globals.keys():
if k.startswith("sys.") or k.startswith("env."):
if k.startswith("sys."):
if isinstance(self.globals[k], str):
self.globals[k] = ""
elif isinstance(self.globals[k], int):
@ -325,9 +335,31 @@ class Canvas(Graph):
self.globals[k] = {}
else:
self.globals[k] = None
if k.startswith("env."):
key = k[4:]
if key in self.variables:
variable = self.variables[key]
if variable["value"]:
self.globals[k] = variable["value"]
else:
if variable["type"] == "string":
self.globals[k] = ""
elif variable["type"] == "number":
self.globals[k] = 0
elif variable["type"] == "boolean":
self.globals[k] = False
elif variable["type"] == "object":
self.globals[k] = {}
elif variable["type"].startswith("array"):
self.globals[k] = []
else:
self.globals[k] = ""
else:
self.globals[k] = ""
async def run(self, **kwargs):
st = time.perf_counter()
self._loop = asyncio.get_running_loop()
self.message_id = get_uuid()
created_at = int(time.time())
self.add_user_input(kwargs.get("query"))
@ -343,7 +375,7 @@ class Canvas(Graph):
for k in kwargs.keys():
if k in ["query", "user_id", "files"] and kwargs[k]:
if k == "files":
self.globals[f"sys.{k}"] = self.get_files(kwargs[k])
self.globals[f"sys.{k}"] = await self.get_files_async(kwargs[k])
else:
self.globals[f"sys.{k}"] = kwargs[k]
if not self.globals["sys.conversation_turns"] :
@ -373,31 +405,50 @@ class Canvas(Graph):
yield decorate("workflow_started", {"inputs": kwargs.get("inputs")})
self.retrieval.append({"chunks": {}, "doc_aggs": {}})
def _run_batch(f, t):
async def _run_batch(f, t):
if self.is_canceled():
msg = f"Task {self.task_id} has been canceled during batch execution."
logging.info(msg)
raise TaskCanceledException(msg)
with ThreadPoolExecutor(max_workers=5) as executor:
thr = []
i = f
while i < t:
cpn = self.get_component_obj(self.path[i])
if cpn.component_name.lower() in ["begin", "userfillup"]:
thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {})))
i += 1
loop = asyncio.get_running_loop()
tasks = []
def _run_async_in_thread(coro_func, **call_kwargs):
return asyncio.run(coro_func(**call_kwargs))
i = f
while i < t:
cpn = self.get_component_obj(self.path[i])
task_fn = None
call_kwargs = None
if cpn.component_name.lower() in ["begin", "userfillup"]:
call_kwargs = {"inputs": kwargs.get("inputs", {})}
task_fn = cpn.invoke
i += 1
else:
for _, ele in cpn.get_input_elements().items():
if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i] and self.path[0].lower().find("userfillup") < 0:
self.path.pop(i)
t -= 1
break
else:
for _, ele in cpn.get_input_elements().items():
if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i] and self.path[0].lower().find("userfillup") < 0:
self.path.pop(i)
t -= 1
break
else:
thr.append(executor.submit(cpn.invoke, **cpn.get_input()))
i += 1
for t in thr:
t.result()
call_kwargs = cpn.get_input()
task_fn = cpn.invoke
i += 1
if task_fn is None:
continue
invoke_async = getattr(cpn, "invoke_async", None)
if invoke_async and asyncio.iscoroutinefunction(invoke_async):
tasks.append(loop.run_in_executor(self._thread_pool, partial(_run_async_in_thread, invoke_async, **(call_kwargs or {}))))
else:
tasks.append(loop.run_in_executor(self._thread_pool, partial(task_fn, **(call_kwargs or {}))))
if tasks:
await asyncio.gather(*tasks)
def _node_finished(cpn_obj):
return decorate("node_finished",{
@ -414,6 +465,7 @@ class Canvas(Graph):
self.error = ""
idx = len(self.path) - 1
partials = []
tts_mdl = None
while idx < len(self.path):
to = len(self.path)
for i in range(idx, to):
@ -424,35 +476,70 @@ class Canvas(Graph):
"component_type": self.get_component_type(self.path[i]),
"thoughts": self.get_component_thoughts(self.path[i])
})
_run_batch(idx, to)
await _run_batch(idx, to)
to = len(self.path)
# post processing of components invocation
for i in range(idx, to):
cpn = self.get_component(self.path[i])
cpn_obj = self.get_component_obj(self.path[i])
if cpn_obj.component_name.lower() == "message":
if cpn_obj.get_param("auto_play"):
tts_mdl = LLMBundle(self._tenant_id, LLMType.TTS)
if isinstance(cpn_obj.output("content"), partial):
_m = ""
for m in cpn_obj.output("content")():
buff_m = ""
stream = cpn_obj.output("content")()
async def _process_stream(m):
nonlocal buff_m, _m, tts_mdl
if not m:
continue
return
if m == "<think>":
yield decorate("message", {"content": "", "start_to_think": True})
return decorate("message", {"content": "", "start_to_think": True})
elif m == "</think>":
yield decorate("message", {"content": "", "end_to_think": True})
else:
yield decorate("message", {"content": m})
_m += m
return decorate("message", {"content": "", "end_to_think": True})
buff_m += m
_m += m
if len(buff_m) > 16:
ev = decorate(
"message",
{
"content": m,
"audio_binary": self.tts(tts_mdl, buff_m)
}
)
buff_m = ""
return ev
return decorate("message", {"content": m})
if inspect.isasyncgen(stream):
async for m in stream:
ev= await _process_stream(m)
if ev:
yield ev
else:
for m in stream:
ev= await _process_stream(m)
if ev:
yield ev
if buff_m:
yield decorate("message", {"content": "", "audio_binary": self.tts(tts_mdl, buff_m)})
buff_m = ""
cpn_obj.set_output("content", _m)
cite = re.search(r"\[ID:[ 0-9]+\]", _m)
else:
yield decorate("message", {"content": cpn_obj.output("content")})
cite = re.search(r"\[ID:[ 0-9]+\]", cpn_obj.output("content"))
if isinstance(cpn_obj.output("attachment"), tuple):
yield decorate("message", {"attachment": cpn_obj.output("attachment")})
yield decorate("message_end", {"reference": self.get_reference() if cite else None})
message_end = {}
if isinstance(cpn_obj.output("attachment"), dict):
message_end["attachment"] = cpn_obj.output("attachment")
if cite:
message_end["reference"] = self.get_reference()
yield decorate("message_end", message_end)
while partials:
_cpn_obj = self.get_component_obj(partials[0])
@ -473,7 +560,7 @@ class Canvas(Graph):
else:
self.error = cpn_obj.error()
if cpn_obj.component_name.lower() != "iteration":
if cpn_obj.component_name.lower() not in ("iteration","loop"):
if isinstance(cpn_obj.output("content"), partial):
if self.error:
cpn_obj.set_output("content", None)
@ -498,14 +585,16 @@ class Canvas(Graph):
for cpn_id in cpn_ids:
_append_path(cpn_id)
if cpn_obj.component_name.lower() == "iterationitem" and cpn_obj.end():
if cpn_obj.component_name.lower() in ("iterationitem","loopitem") and cpn_obj.end():
iter = cpn_obj.get_parent()
yield _node_finished(iter)
_extend_path(self.get_component(cpn["parent_id"])["downstream"])
elif cpn_obj.component_name.lower() in ["categorize", "switch"]:
_extend_path(cpn_obj.output("_next"))
elif cpn_obj.component_name.lower() == "iteration":
elif cpn_obj.component_name.lower() in ("iteration", "loop"):
_append_path(cpn_obj.get_start())
elif cpn_obj.component_name.lower() == "exitloop" and cpn_obj.get_parent().component_name.lower() == "loop":
_extend_path(self.get_component(cpn["parent_id"])["downstream"])
elif not cpn["downstream"] and cpn_obj.get_parent():
_append_path(cpn_obj.get_parent().get_start())
else:
@ -561,6 +650,50 @@ class Canvas(Graph):
return False
return True
def tts(self,tts_mdl, text):
def clean_tts_text(text: str) -> str:
if not text:
return ""
text = text.encode("utf-8", "ignore").decode("utf-8", "ignore")
text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text)
emoji_pattern = re.compile(
"[\U0001F600-\U0001F64F"
"\U0001F300-\U0001F5FF"
"\U0001F680-\U0001F6FF"
"\U0001F1E0-\U0001F1FF"
"\U00002700-\U000027BF"
"\U0001F900-\U0001F9FF"
"\U0001FA70-\U0001FAFF"
"\U0001FAD0-\U0001FAFF]+",
flags=re.UNICODE
)
text = emoji_pattern.sub("", text)
text = re.sub(r"\s+", " ", text).strip()
MAX_LEN = 500
if len(text) > MAX_LEN:
text = text[:MAX_LEN]
return text
if not tts_mdl or not text:
return None
text = clean_tts_text(text)
if not text:
return None
bin = b""
try:
for chunk in tts_mdl.tts(text):
bin += chunk
except Exception as e:
logging.error(f"TTS failed: {e}, text={text!r}")
return None
return binascii.hexlify(bin).decode("utf-8")
def get_history(self, window_size):
convs = []
if window_size <= 0:
@ -590,21 +723,30 @@ class Canvas(Graph):
def get_component_input_elements(self, cpnnm):
return self.components[cpnnm]["obj"].get_input_elements()
def get_files(self, files: Union[None, list[dict]]) -> list[str]:
from api.db.services.file_service import FileService
async def get_files_async(self, files: Union[None, list[dict]]) -> list[str]:
if not files:
return []
def image_to_base64(file):
return "data:{};base64,{}".format(file["mime_type"],
base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
exe = ThreadPoolExecutor(max_workers=5)
threads = []
loop = asyncio.get_running_loop()
tasks = []
for file in files:
if file["mime_type"].find("image") >=0:
threads.append(exe.submit(image_to_base64, file))
tasks.append(loop.run_in_executor(self._thread_pool, image_to_base64, file))
continue
threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
return [th.result() for th in threads]
tasks.append(loop.run_in_executor(self._thread_pool, FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
return await asyncio.gather(*tasks)
def get_files(self, files: Union[None, list[dict]]) -> list[str]:
"""
Synchronous wrapper for get_files_async, used by sync component invoke paths.
"""
loop = getattr(self, "_loop", None)
if loop and loop.is_running():
return asyncio.run_coroutine_threadsafe(self.get_files_async(files), loop).result()
return asyncio.run(self.get_files_async(files))
def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any, elapsed_time=None):
agent_ids = agent_id.split("-->")

View File

@ -13,10 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import json
import logging
import os
import re
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from functools import partial
from typing import Any
@ -28,8 +29,8 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.mcp_server_service import MCPServerService
from common.connection_utils import timeout
from rag.prompts.generator import next_step, COMPLETE_TASK, analyze_task, \
citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in
from rag.prompts.generator import next_step_async, COMPLETE_TASK, analyze_task_async, \
citation_prompt, reflect_async, kb_prompt, citation_plus, full_question, message_fit_in, structured_output_prompt
from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
from agent.component.llm import LLMParam, LLM
@ -137,8 +138,34 @@ class Agent(LLM, ToolBase):
res.update(cpn.get_input_form())
return res
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60)))
def _get_output_schema(self):
try:
cand = self._param.outputs.get("structured")
except Exception:
return None
if isinstance(cand, dict):
if isinstance(cand.get("properties"), dict) and len(cand["properties"]) > 0:
return cand
for k in ("schema", "structured"):
if isinstance(cand.get(k), dict) and isinstance(cand[k].get("properties"), dict) and len(cand[k]["properties"]) > 0:
return cand[k]
return None
async def _force_format_to_schema_async(self, text: str, schema_prompt: str) -> str:
fmt_msgs = [
{"role": "system", "content": schema_prompt + "\nIMPORTANT: Output ONLY valid JSON. No markdown, no extra text."},
{"role": "user", "content": text},
]
_, fmt_msgs = message_fit_in(fmt_msgs, int(self.chat_mdl.max_length * 0.97))
return await self._generate_async(fmt_msgs)
def _invoke(self, **kwargs):
return asyncio.run(self._invoke_async(**kwargs))
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60)))
async def _invoke_async(self, **kwargs):
if self.check_if_canceled("Agent processing"):
return
@ -157,20 +184,25 @@ class Agent(LLM, ToolBase):
if not self.tools:
if self.check_if_canceled("Agent processing"):
return
return LLM._invoke(self, **kwargs)
return await LLM._invoke_async(self, **kwargs)
prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
output_schema = self._get_output_schema()
schema_prompt = ""
if output_schema:
schema = json.dumps(output_schema, ensure_ascii=False, indent=2)
schema_prompt = structured_output_prompt(schema)
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
ex = self.exception_handler()
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]):
self.set_output("content", partial(self.stream_output_with_tools, prompt, msg, user_defined_prompt))
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]) and not output_schema:
self.set_output("content", partial(self.stream_output_with_tools_async, prompt, deepcopy(msg), user_defined_prompt))
return
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
use_tools = []
ans = ""
for delta_ans, tk in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt):
async for delta_ans, _tk in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt,schema_prompt=schema_prompt):
if self.check_if_canceled("Agent processing"):
return
ans += delta_ans
@ -183,16 +215,38 @@ class Agent(LLM, ToolBase):
self.set_output("_ERROR", ans)
return
if output_schema:
error = ""
for _ in range(self._param.max_retries + 1):
try:
def clean_formated_answer(ans: str) -> str:
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)
obj = json_repair.loads(clean_formated_answer(ans))
self.set_output("structured", obj)
if use_tools:
self.set_output("use_tools", use_tools)
return obj
except Exception:
error = "The answer cannot be parsed as JSON"
ans = await self._force_format_to_schema_async(ans, schema_prompt)
if ans.find("**ERROR**") >= 0:
continue
self.set_output("_ERROR", error)
return
self.set_output("content", ans)
if use_tools:
self.set_output("use_tools", use_tools)
return ans
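The structured-output retry loop depends on stripping reasoning blocks and code fences before handing the text to `json_repair`. A standalone sketch of that cleanup step (the function name here is illustrative, not the repo's):

```python
import re
import json_repair  # already a dependency of the agent code

def parse_llm_json(raw: str):
    """Drop <think> blocks and JSON code fences, then let json_repair patch up the rest."""
    raw = re.sub(r"^.*</think>", "", raw, flags=re.DOTALL)
    raw = re.sub(r"^.*```json", "", raw, flags=re.DOTALL)
    raw = re.sub(r"```\n*$", "", raw, flags=re.DOTALL)
    return json_repair.loads(raw)

print(parse_llm_json('```json\n{"answer": 42,}\n```'))  # {'answer': 42}
```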
def stream_output_with_tools(self, prompt, msg, user_defined_prompt={}):
async def stream_output_with_tools_async(self, prompt, msg, user_defined_prompt={}):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
answer_without_toolcall = ""
use_tools = []
for delta_ans,_ in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt):
async for delta_ans, _ in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt):
if self.check_if_canceled("Agent streaming"):
return
@ -210,39 +264,23 @@ class Agent(LLM, ToolBase):
if use_tools:
self.set_output("use_tools", use_tools)
def _gen_citations(self, text):
retrievals = self._canvas.get_reference()
retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
for delta_ans in self._generate_streamly([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))},
{"role": "user", "content": text}
]):
yield delta_ans
def _react_with_tools_streamly(self, prompt, history: list[dict], use_tools, user_defined_prompt={}):
async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
token_count = 0
tool_metas = self.tool_meta
hist = deepcopy(history)
last_calling = ""
if len(hist) > 3:
st = timer()
user_request = full_question(messages=history, chat_mdl=self.chat_mdl)
user_request = await asyncio.to_thread(full_question, messages=history, chat_mdl=self.chat_mdl)
self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer()-st)
else:
user_request = history[-1]["content"]
def use_tool(name, args):
nonlocal hist, use_tools, token_count,last_calling,user_request
async def use_tool_async(name, args):
nonlocal hist, use_tools, last_calling
logging.info(f"{last_calling=} == {name=}")
# Summary of function calling
#if all([
# isinstance(self.toolcall_session.get_tool_obj(name), Agent),
# last_calling,
# last_calling != name
#]):
# self.toolcall_session.get_tool_obj(name).add2system_prompt(f"The chat history with other agents are as following: \n" + self.get_useful_memory(user_request, str(args["user_prompt"]),user_defined_prompt))
last_calling = name
tool_response = self.toolcall_session.tool_call(name, args)
tool_response = await self.toolcall_session.tool_call_async(name, args)
use_tools.append({
"name": name,
"arguments": args,
@ -253,12 +291,16 @@ class Agent(LLM, ToolBase):
return name, tool_response
def complete():
async def complete():
nonlocal hist
need2cite = self._param.cite and self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0
if schema_prompt:
need2cite = False
cited = False
if hist[0]["role"] == "system" and need2cite:
if len(hist) < 7:
if hist and hist[0]["role"] == "system":
if schema_prompt:
hist[0]["content"] += "\n" + schema_prompt
if need2cite and len(hist) < 7:
hist[0]["content"] += citation_prompt()
cited = True
yield "", token_count
@ -267,7 +309,7 @@ class Agent(LLM, ToolBase):
if len(hist) > 12:
_hist = [hist[0], hist[1], *hist[-10:]]
entire_txt = ""
for delta_ans in self._generate_streamly(_hist):
async for delta_ans in self._generate_streamly_async(_hist):
if not need2cite or cited:
yield delta_ans, 0
entire_txt += delta_ans
@ -276,7 +318,7 @@ class Agent(LLM, ToolBase):
st = timer()
txt = ""
for delta_ans in self._gen_citations(entire_txt):
async for delta_ans in self._gen_citations_async(entire_txt):
if self.check_if_canceled("Agent streaming"):
return
yield delta_ans, 0
@ -291,14 +333,14 @@ class Agent(LLM, ToolBase):
hist.append({"role": "user", "content": content})
st = timer()
task_desc = analyze_task(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
task_desc = await analyze_task_async(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st)
for _ in range(self._param.max_rounds + 1):
if self.check_if_canceled("Agent streaming"):
return
response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
response, tk = await next_step_async(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
# self.callback("next_step", {}, str(response)[:256]+"...")
token_count += tk
token_count += tk or 0
hist.append({"role": "assistant", "content": response})
try:
functions = json_repair.loads(re.sub(r"```.*", "", response))
@ -307,23 +349,24 @@ class Agent(LLM, ToolBase):
for f in functions:
if not isinstance(f, dict):
raise TypeError(f"An object type should be returned, but `{f}`")
with ThreadPoolExecutor(max_workers=5) as executor:
thr = []
for func in functions:
name = func["name"]
args = func["arguments"]
if name == COMPLETE_TASK:
append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be as the same as the first user request.\n")
for txt, tkcnt in complete():
yield txt, tkcnt
return
thr.append(executor.submit(use_tool, name, args))
tool_tasks = []
for func in functions:
name = func["name"]
args = func["arguments"]
if name == COMPLETE_TASK:
append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be as the same as the first user request.\n")
async for txt, tkcnt in complete():
yield txt, tkcnt
return
st = timer()
reflection = reflect(self.chat_mdl, hist, [th.result() for th in thr], user_defined_prompt)
append_user_content(hist, reflection)
self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)
tool_tasks.append(asyncio.create_task(use_tool_async(name, args)))
results = await asyncio.gather(*tool_tasks) if tool_tasks else []
st = timer()
reflection = await reflect_async(self.chat_mdl, hist, results, user_defined_prompt)
append_user_content(hist, reflection)
self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM ReAct response: {e}")
@ -347,21 +390,17 @@ Respond immediately with your final comprehensive answer.
return
append_user_content(hist, final_instruction)
for txt, tkcnt in complete():
async for txt, tkcnt in complete():
yield txt, tkcnt
def get_useful_memory(self, goal: str, sub_goal:str, topn=3, user_defined_prompt:dict={}) -> str:
# self.callback("get_useful_memory", {"topn": 3}, "...")
mems = self._canvas.get_memory()
rank = rank_memories(self.chat_mdl, goal, sub_goal, [summ for (user, assist, summ) in mems], user_defined_prompt)
try:
rank = json_repair.loads(re.sub(r"```.*", "", rank))[:topn]
mems = [mems[r] for r in rank]
return "\n\n".join([f"User: {u}\nAgent: {a}" for u, a,_ in mems])
except Exception as e:
logging.exception(e)
return "Error occurred."
async def _gen_citations_async(self, text):
retrievals = self._canvas.get_reference()
retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
async for delta_ans in self._generate_streamly_async([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))},
{"role": "user", "content": text}
]):
yield delta_ans
def reset(self, only_output=False):
"""
@ -369,7 +408,7 @@ Respond immediately with your final comprehensive answer.
"""
for k in self._param.outputs.keys():
self._param.outputs[k]["value"] = None
for k, cpn in self.tools.items():
if hasattr(cpn, "reset") and callable(cpn.reset):
cpn.reset()
@ -378,4 +417,3 @@ Respond immediately with your final comprehensive answer.
for k in self._param.inputs.keys():
self._param.inputs[k]["value"] = None
self._param.debug_inputs = {}

View File

@ -14,6 +14,7 @@
# limitations under the License.
#
import asyncio
import re
import time
from abc import ABC
@ -445,6 +446,34 @@ class ComponentBase(ABC):
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
return self.output()
async def invoke_async(self, **kwargs) -> dict[str, Any]:
"""
Async wrapper for component invocation.
Prefers coroutine `_invoke_async` if present; otherwise falls back to `_invoke`.
Handles timing and error recording consistently with `invoke`.
"""
self.set_output("_created_time", time.perf_counter())
try:
if self.check_if_canceled("Component processing"):
return
fn_async = getattr(self, "_invoke_async", None)
if fn_async and asyncio.iscoroutinefunction(fn_async):
await fn_async(**kwargs)
elif asyncio.iscoroutinefunction(self._invoke):
await self._invoke(**kwargs)
else:
await asyncio.to_thread(self._invoke, **kwargs)
except Exception as e:
if self.get_exception_default_value():
self.set_exception_default_value()
else:
self.set_output("_ERROR", str(e))
logging.exception(e)
self._param.debug_inputs = {}
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
return self.output()
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs):
raise NotImplementedError()
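A stripped-down illustration of the dispatch rule `invoke_async` applies: a coroutine `_invoke_async` first, then a coroutine `_invoke`, otherwise the blocking `_invoke` pushed to a worker thread (class names here are made up):

```python
import asyncio

class SyncComponent:
    def _invoke(self, **kwargs):
        return "ran synchronously"

class AsyncComponent:
    async def _invoke_async(self, **kwargs):
        await asyncio.sleep(0)
        return "ran as a coroutine"

async def dispatch(component, **kwargs):
    # Same rule as invoke_async above: coroutine _invoke_async first, then a
    # coroutine _invoke, otherwise the blocking _invoke in a worker thread.
    fn_async = getattr(component, "_invoke_async", None)
    if fn_async and asyncio.iscoroutinefunction(fn_async):
        return await fn_async(**kwargs)
    if asyncio.iscoroutinefunction(component._invoke):
        return await component._invoke(**kwargs)
    return await asyncio.to_thread(component._invoke, **kwargs)

print(asyncio.run(dispatch(SyncComponent())))   # ran synchronously
print(asyncio.run(dispatch(AsyncComponent())))  # ran as a coroutine
```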

View File

@ -14,6 +14,7 @@
# limitations under the License.
#
from agent.component.fillup import UserFillUpParam, UserFillUp
from api.db.services.file_service import FileService
class BeginParam(UserFillUpParam):
@ -48,7 +49,7 @@ class Begin(UserFillUp):
if v.get("optional") and v.get("value", None) is None:
v = None
else:
v = self._canvas.get_files([v["value"]])
v = FileService.get_files([v["value"]])
else:
v = v.get("value")
self.set_output(k, v)

View File

@ -0,0 +1,32 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC
from agent.component.base import ComponentBase, ComponentParamBase
class ExitLoopParam(ComponentParamBase, ABC):
def check(self):
return True
class ExitLoop(ComponentBase, ABC):
component_name = "ExitLoop"
def _invoke(self, **kwargs):
pass
def thoughts(self) -> str:
return ""

View File

@ -18,6 +18,7 @@ import re
from functools import partial
from agent.component.base import ComponentParamBase, ComponentBase
from api.db.services.file_service import FileService
class UserFillUpParam(ComponentParamBase):
@ -63,6 +64,13 @@ class UserFillUp(ComponentBase):
for k, v in kwargs.get("inputs", {}).items():
if self.check_if_canceled("UserFillUp processing"):
return
if isinstance(v, dict) and v.get("type", "").lower().find("file") >=0:
if v.get("optional") and v.get("value", None) is None:
v = None
else:
v = FileService.get_files([v["value"]])
else:
v = v.get("value")
self.set_output(k, v)
def thoughts(self) -> str:

View File

@ -13,12 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import json
import logging
import os
import re
import threading
from copy import deepcopy
from typing import Any, Generator
from typing import Any, Generator, AsyncGenerator
import json_repair
from functools import partial
from common.constants import LLMType
@ -171,6 +173,13 @@ class LLM(ComponentBase):
return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
async def _generate_async(self, msg: list[dict], **kwargs) -> str:
if not self.imgs and hasattr(self.chat_mdl, "async_chat"):
return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
if self.imgs and hasattr(self.chat_mdl, "async_chat"):
return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
return await asyncio.to_thread(self._generate, msg, **kwargs)
def _generate_streamly(self, msg:list[dict], **kwargs) -> Generator[str, None, None]:
ans = ""
last_idx = 0
@ -205,8 +214,120 @@ class LLM(ComponentBase):
for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
yield delta(txt)
async def _generate_streamly_async(self, msg: list[dict], **kwargs) -> AsyncGenerator[str, None]:
async def delta_wrapper(txt_iter):
ans = ""
last_idx = 0
endswith_think = False
def delta(txt):
nonlocal ans, last_idx, endswith_think
delta_ans = txt[last_idx:]
ans = txt
if delta_ans.find("<think>") == 0:
last_idx += len("<think>")
return "<think>"
elif delta_ans.find("<think>") > 0:
delta_ans = txt[last_idx:last_idx + delta_ans.find("<think>")]
last_idx += delta_ans.find("<think>")
return delta_ans
elif delta_ans.endswith("</think>"):
endswith_think = True
elif endswith_think:
endswith_think = False
return "</think>"
last_idx = len(ans)
if ans.endswith("</think>"):
last_idx -= len("</think>")
return re.sub(r"(<think>|</think>)", "", delta_ans)
async for t in txt_iter:
yield delta(t)
if not self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"):
async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)):
yield t
return
if self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"):
async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)):
yield t
return
# fallback
loop = asyncio.get_running_loop()
queue: asyncio.Queue = asyncio.Queue()
def worker():
try:
for item in self._generate_streamly(msg, **kwargs):
loop.call_soon_threadsafe(queue.put_nowait, item)
except Exception as e:
loop.call_soon_threadsafe(queue.put_nowait, e)
finally:
loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
threading.Thread(target=worker, daemon=True).start()
while True:
item = await queue.get()
if item is StopAsyncIteration:
break
if isinstance(item, Exception):
raise item
yield item
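The fallback branch bridges a blocking generator into the event loop through a queue fed by a worker thread. The same pattern in isolation, with an invented stand-in for the streaming client:

```python
import asyncio
import threading
import time
from typing import AsyncGenerator, Iterator

def blocking_tokens() -> Iterator[str]:
    for tok in ("hello ", "async ", "world"):
        time.sleep(0.05)  # simulate a blocking streaming client
        yield tok

async def to_async(gen_factory) -> AsyncGenerator[str, None]:
    # Mirror of the fallback in _generate_streamly_async: a worker thread pushes
    # items (or an exception, or a sentinel) into an asyncio.Queue.
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()

    def worker():
        try:
            for item in gen_factory():
                loop.call_soon_threadsafe(queue.put_nowait, item)
        except Exception as e:
            loop.call_soon_threadsafe(queue.put_nowait, e)
        finally:
            loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)

    threading.Thread(target=worker, daemon=True).start()
    while True:
        item = await queue.get()
        if item is StopAsyncIteration:
            break
        if isinstance(item, Exception):
            raise item
        yield item

async def main():
    async for tok in to_async(blocking_tokens):
        print(tok, end="")

asyncio.run(main())
```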
async def _stream_output_async(self, prompt, msg):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
answer = ""
last_idx = 0
endswith_think = False
def delta(txt):
nonlocal answer, last_idx, endswith_think
delta_ans = txt[last_idx:]
answer = txt
if delta_ans.find("<think>") == 0:
last_idx += len("<think>")
return "<think>"
elif delta_ans.find("<think>") > 0:
delta_ans = txt[last_idx:last_idx + delta_ans.find("<think>")]
last_idx += delta_ans.find("<think>")
return delta_ans
elif delta_ans.endswith("</think>"):
endswith_think = True
elif endswith_think:
endswith_think = False
return "</think>"
last_idx = len(answer)
if answer.endswith("</think>"):
last_idx -= len("</think>")
return re.sub(r"(<think>|</think>)", "", delta_ans)
stream_kwargs = {"images": self.imgs} if self.imgs else {}
async for ans in self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **stream_kwargs):
if self.check_if_canceled("LLM streaming"):
return
if isinstance(ans, int):
continue
if ans.find("**ERROR**") >= 0:
if self.get_exception_default_value():
self.set_output("content", self.get_exception_default_value())
yield self.get_exception_default_value()
else:
self.set_output("_ERROR", ans)
return
yield delta(ans)
self.set_output("content", answer)
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs):
async def _invoke_async(self, **kwargs):
if self.check_if_canceled("LLM processing"):
return
@ -217,22 +338,25 @@ class LLM(ComponentBase):
prompt, msg, _ = self._prepare_prompt_variables()
error: str = ""
output_structure=None
output_structure = None
try:
output_structure = self._param.outputs['structured']
output_structure = self._param.outputs["structured"]
except Exception:
pass
if output_structure and isinstance(output_structure, dict) and output_structure.get("properties"):
schema=json.dumps(output_structure, ensure_ascii=False, indent=2)
prompt += structured_output_prompt(schema)
for _ in range(self._param.max_retries+1):
if output_structure and isinstance(output_structure, dict) and output_structure.get("properties") and len(output_structure["properties"]) > 0:
schema = json.dumps(output_structure, ensure_ascii=False, indent=2)
prompt_with_schema = prompt + structured_output_prompt(schema)
for _ in range(self._param.max_retries + 1):
if self.check_if_canceled("LLM processing"):
return
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
_, msg_fit = message_fit_in(
[{"role": "system", "content": prompt_with_schema}, *deepcopy(msg)],
int(self.chat_mdl.max_length * 0.97),
)
error = ""
ans = self._generate(msg)
msg.pop(0)
ans = await self._generate_async(msg_fit)
msg_fit.pop(0)
if ans.find("**ERROR**") >= 0:
logging.error(f"LLM response error: {ans}")
error = ans
@ -241,7 +365,7 @@ class LLM(ComponentBase):
self.set_output("structured", json_repair.loads(clean_formated_answer(ans)))
return
except Exception:
msg.append({"role": "user", "content": "The answer can't not be parsed as JSON"})
msg_fit.append({"role": "user", "content": "The answer can't not be parsed as JSON"})
error = "The answer can't not be parsed as JSON"
if error:
self.set_output("_ERROR", error)
@ -249,18 +373,23 @@ class LLM(ComponentBase):
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
ex = self.exception_handler()
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]):
self.set_output("content", partial(self._stream_output, prompt, msg))
if any([self._canvas.get_component_obj(cid).component_name.lower() == "message" for cid in downstreams]) and not (
ex and ex["goto"]
):
self.set_output("content", partial(self._stream_output_async, prompt, deepcopy(msg)))
return
for _ in range(self._param.max_retries+1):
error = ""
for _ in range(self._param.max_retries + 1):
if self.check_if_canceled("LLM processing"):
return
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
_, msg_fit = message_fit_in(
[{"role": "system", "content": prompt}, *deepcopy(msg)], int(self.chat_mdl.max_length * 0.97)
)
error = ""
ans = self._generate(msg)
msg.pop(0)
ans = await self._generate_async(msg_fit)
msg_fit.pop(0)
if ans.find("**ERROR**") >= 0:
logging.error(f"LLM response error: {ans}")
error = ans
@ -274,23 +403,9 @@ class LLM(ComponentBase):
else:
self.set_output("_ERROR", error)
def _stream_output(self, prompt, msg):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
answer = ""
for ans in self._generate_streamly(msg):
if self.check_if_canceled("LLM streaming"):
return
if ans.find("**ERROR**") >= 0:
if self.get_exception_default_value():
self.set_output("content", self.get_exception_default_value())
yield self.get_exception_default_value()
else:
self.set_output("_ERROR", ans)
return
yield ans
answer += ans
self.set_output("content", answer)
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs):
return asyncio.run(self._invoke_async(**kwargs))
def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str, user_defined_prompt:dict={}):
summ = tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt)

80
agent/component/loop.py Normal file
View File

@ -0,0 +1,80 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC
from agent.component.base import ComponentBase, ComponentParamBase
class LoopParam(ComponentParamBase):
"""
Define the Loop component parameters.
"""
def __init__(self):
super().__init__()
self.loop_variables = []
self.loop_termination_condition=[]
self.maximum_loop_count = 0
def get_input_form(self) -> dict[str, dict]:
return {
"items": {
"type": "json",
"name": "Items"
}
}
def check(self):
return True
class Loop(ComponentBase, ABC):
component_name = "Loop"
def get_start(self):
for cid in self._canvas.components.keys():
if self._canvas.get_component(cid)["obj"].component_name.lower() != "loopitem":
continue
if self._canvas.get_component(cid)["parent_id"] == self._id:
return cid
def _invoke(self, **kwargs):
if self.check_if_canceled("Loop processing"):
return
for item in self._param.loop_variables:
if any([not item.get("variable"), not item.get("input_mode"), not item.get("value"),not item.get("type")]):
assert "Loop Variable is not complete."
if item["input_mode"]=="variable":
self.set_output(item["variable"],self._canvas.get_variable_value(item["value"]))
elif item["input_mode"]=="constant":
self.set_output(item["variable"],item["value"])
else:
if item["type"] == "number":
self.set_output(item["variable"], 0)
elif item["type"] == "string":
self.set_output(item["variable"], "")
elif item["type"] == "boolean":
self.set_output(item["variable"], False)
elif item["type"].startswith("object"):
self.set_output(item["variable"], {})
elif item["type"].startswith("array"):
self.set_output(item["variable"], [])
else:
self.set_output(item["variable"], "")
def thoughts(self) -> str:
return "Loop from canvas."

163
agent/component/loopitem.py Normal file
View File

@ -0,0 +1,163 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC
from agent.component.base import ComponentBase, ComponentParamBase
class LoopItemParam(ComponentParamBase):
"""
Define the LoopItem component parameters.
"""
def check(self):
return True
class LoopItem(ComponentBase, ABC):
component_name = "LoopItem"
def __init__(self, canvas, id, param: ComponentParamBase):
super().__init__(canvas, id, param)
self._idx = 0
def _invoke(self, **kwargs):
if self.check_if_canceled("LoopItem processing"):
return
parent = self.get_parent()
maximum_loop_count = parent._param.maximum_loop_count
if self._idx >= maximum_loop_count:
self._idx = -1
return
if self._idx > 0:
if self.check_if_canceled("LoopItem processing"):
return
self._idx += 1
def evaluate_condition(self,var, operator, value):
if isinstance(var, str):
if operator == "contains":
return value in var
elif operator == "not contains":
return value not in var
elif operator == "start with":
return var.startswith(value)
elif operator == "end with":
return var.endswith(value)
elif operator == "is":
return var == value
elif operator == "is not":
return var != value
elif operator == "empty":
return var == ""
elif operator == "not empty":
return var != ""
elif isinstance(var, (int, float)):
if operator == "=":
return var == value
elif operator == "":
return var != value
elif operator == ">":
return var > value
elif operator == "<":
return var < value
elif operator == "":
return var >= value
elif operator == "":
return var <= value
elif operator == "empty":
return var is None
elif operator == "not empty":
return var is not None
elif isinstance(var, bool):
if operator == "is":
return var is value
elif operator == "is not":
return var is not value
elif operator == "empty":
return var is None
elif operator == "not empty":
return var is not None
elif isinstance(var, dict):
if operator == "empty":
return len(var) == 0
elif operator == "not empty":
return len(var) > 0
elif isinstance(var, list):
if operator == "contains":
return value in var
elif operator == "not contains":
return value not in var
elif operator == "is":
return var == value
elif operator == "is not":
return var != value
elif operator == "empty":
return len(var) == 0
elif operator == "not empty":
return len(var) > 0
raise Exception(f"Invalid operator: {operator}")
def end(self):
if self._idx == -1:
return True
parent = self.get_parent()
logical_operator = parent._param.logical_operator if hasattr(parent._param, "logical_operator") else "and"
conditions = []
for item in parent._param.loop_termination_condition:
if not item.get("variable") or not item.get("operator"):
raise ValueError("Loop condition is incomplete.")
var = self._canvas.get_variable_value(item["variable"])
operator = item["operator"]
input_mode = item.get("input_mode", "constant")
if input_mode == "variable":
value = self._canvas.get_variable_value(item.get("value", ""))
elif input_mode == "constant":
value = item.get("value", "")
else:
raise ValueError("Invalid input mode.")
conditions.append(self.evaluate_condition(var, operator, value))
should_end = (
all(conditions) if logical_operator == "and"
else any(conditions) if logical_operator == "or"
else None
)
if should_end is None:
raise ValueError("Invalid logical operator,should be 'and' or 'or'.")
if should_end:
self._idx = -1
return True
return False
def next(self):
if self._idx == -1:
self._idx = 0
else:
self._idx += 1
if self._idx >= len(self._items):
self._idx = -1
return False
def thoughts(self) -> str:
return "Next turn..."

View File

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import inspect
import json
import os
import random
@ -39,6 +41,7 @@ class MessageParam(ComponentParamBase):
self.content = []
self.stream = True
self.output_format = None # default output format
self.auto_play = False
self.outputs = {
"content": {
"type": "str"
@ -66,8 +69,12 @@ class Message(ComponentBase):
v = ""
ans = ""
if isinstance(v, partial):
for t in v():
ans += t
iter_obj = v()
if inspect.isasyncgen(iter_obj):
ans = asyncio.run(self._consume_async_gen(iter_obj))
else:
for t in iter_obj:
ans += t
elif isinstance(v, list) and delimiter:
ans = delimiter.join([str(vv) for vv in v])
elif not isinstance(v, str):
@ -89,7 +96,13 @@ class Message(ComponentBase):
_kwargs[_n] = v
return script, _kwargs
def _stream(self, rand_cnt:str):
async def _consume_async_gen(self, agen):
buf = ""
async for t in agen:
buf += t
return buf
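The Message component now has to accept either a plain generator or an async generator from an upstream partial. The same either/or handling, reduced to a standalone sketch (helper names are invented):

```python
import asyncio
import inspect

async def _drain(agen) -> str:
    buf = ""
    async for t in agen:
        buf += t
    return buf

def consume(factory) -> str:
    # Mirror of the Message branch above: the partial may produce a sync or an
    # async generator; collect both into a single string.
    it = factory()
    if inspect.isasyncgen(it):
        return asyncio.run(_drain(it))
    return "".join(it)

async def async_tokens():
    for t in ("a", "b", "c"):
        yield t

print(consume(lambda: iter(["x", "y"])))  # xy
print(consume(async_tokens))              # abc
```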
async def _stream(self, rand_cnt:str):
s = 0
all_content = ""
cache = {}
@ -111,15 +124,27 @@ class Message(ComponentBase):
v = ""
if isinstance(v, partial):
cnt = ""
for t in v():
if self.check_if_canceled("Message streaming"):
return
iter_obj = v()
if inspect.isasyncgen(iter_obj):
async for t in iter_obj:
if self.check_if_canceled("Message streaming"):
return
all_content += t
cnt += t
yield t
all_content += t
cnt += t
yield t
else:
for t in iter_obj:
if self.check_if_canceled("Message streaming"):
return
all_content += t
cnt += t
yield t
self.set_input_value(exp, cnt)
continue
elif inspect.isawaitable(v):
v = await v
elif not isinstance(v, str):
try:
v = json.dumps(v, ensure_ascii=False)
@ -181,7 +206,7 @@ class Message(ComponentBase):
import pypandoc
doc_id = get_uuid()
if self._param.output_format.lower() not in {"markdown", "html", "pdf", "docx"}:
self._param.output_format = "markdown"
@ -231,11 +256,11 @@ class Message(ComponentBase):
settings.STORAGE_IMPL.put(self._canvas._tenant_id, doc_id, binary_content)
self.set_output("attachment", {
"doc_id":doc_id,
"format":self._param.output_format,
"doc_id":doc_id,
"format":self._param.output_format,
"file_name":f"{doc_id[:8]}.{self._param.output_format}"})
logging.info(f"Converted content uploaded as {doc_id} (format={self._param.output_format})")
except Exception as e:
logging.error(f"Error converting content to {self._param.output_format}: {e}")
logging.error(f"Error converting content to {self._param.output_format}: {e}")

View File

@ -17,6 +17,7 @@ import logging
import re
import time
from copy import deepcopy
import asyncio
from functools import partial
from typing import TypedDict, List, Any
from agent.component.base import ComponentParamBase, ComponentBase
@ -48,12 +49,19 @@ class LLMToolPluginCallSession(ToolCallSession):
self.callback = callback
def tool_call(self, name: str, arguments: dict[str, Any]) -> Any:
return asyncio.run(self.tool_call_async(name, arguments))
async def tool_call_async(self, name: str, arguments: dict[str, Any]) -> Any:
assert name in self.tools_map, f"LLM tool {name} does not exist"
st = timer()
if isinstance(self.tools_map[name], MCPToolCallSession):
resp = self.tools_map[name].tool_call(name, arguments, 60)
tool_obj = self.tools_map[name]
if isinstance(tool_obj, MCPToolCallSession):
resp = await asyncio.to_thread(tool_obj.tool_call, name, arguments, 60)
else:
resp = self.tools_map[name].invoke(**arguments)
if hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async):
resp = await tool_obj.invoke_async(**arguments)
else:
resp = await asyncio.to_thread(tool_obj.invoke, **arguments)
self.callback(name, arguments, resp, elapsed_time=timer()-st)
return resp
@ -139,6 +147,33 @@ class ToolBase(ComponentBase):
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
return res
async def invoke_async(self, **kwargs):
"""
Async wrapper for tool invocation.
If `_invoke` is a coroutine, await it directly; otherwise run in a thread to avoid blocking.
Mirrors the exception handling of `invoke`.
"""
if self.check_if_canceled("Tool processing"):
return
self.set_output("_created_time", time.perf_counter())
try:
fn_async = getattr(self, "_invoke_async", None)
if fn_async and asyncio.iscoroutinefunction(fn_async):
res = await fn_async(**kwargs)
elif asyncio.iscoroutinefunction(self._invoke):
res = await self._invoke(**kwargs)
else:
res = await asyncio.to_thread(self._invoke, **kwargs)
except Exception as e:
self._param.outputs["_ERROR"] = {"value": str(e)}
logging.exception(e)
res = str(e)
self._param.debug_inputs = []
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
return res
def _retrieve_chunks(self, res_list: list, get_title, get_url, get_content, get_score=None):
chunks = []
aggs = []

View File

@ -69,7 +69,7 @@ class CodeExecParam(ToolParamBase):
self.meta: ToolMeta = {
"name": "execute_code",
"description": """
This tool has a sandbox that can execute code written in 'Python'/'Javascript'. It recieves a piece of code and return a Json string.
This tool has a sandbox that can execute code written in 'Python'/'Javascript'. It receives a piece of code and returns a JSON string.
Here's a code example for Python(`main` function MUST be included):
def main() -> dict:
\"\"\"

View File

@ -198,6 +198,7 @@ class Retrieval(ToolBase, ABC):
return
if cks:
kbinfos["chunks"] = cks
kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"], [kb.tenant_id for kb in kbs])
if self._param.use_kg:
ck = settings.kg_retriever.retrieval(query,
[kb.tenant_id for kb in kbs],

View File

@ -14,5 +14,5 @@
# limitations under the License.
#
from beartype.claw import beartype_this_package
beartype_this_package()
# from beartype.claw import beartype_this_package
# beartype_this_package()

View File

@ -13,13 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import sys
import logging
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from quart import Blueprint, Quart, request, g, current_app, session
from werkzeug.wrappers.request import Request
from flasgger import Swagger
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from quart_cors import cors
@ -40,7 +39,6 @@ settings.init_settings()
__all__ = ["app"]
Request.json = property(lambda self: self.get_json(force=True, silent=True))
app = Quart(__name__)
app = cors(app, allow_origin="*")
@ -82,6 +80,11 @@ app.url_map.strict_slashes = False
app.json_encoder = CustomJSONEncoder
app.errorhandler(Exception)(server_error_response)
# Configure Quart timeouts for slow LLM responses (e.g., local Ollama on CPU)
# Default Quart timeouts are 60 seconds which is too short for many LLM backends
app.config["RESPONSE_TIMEOUT"] = int(os.environ.get("QUART_RESPONSE_TIMEOUT", 600))
app.config["BODY_TIMEOUT"] = int(os.environ.get("QUART_BODY_TIMEOUT", 600))
## convenience for dev and debug
# app.config["LOGIN_DISABLED"] = True
app.config["SESSION_PERMANENT"] = False

View File

@ -18,8 +18,7 @@ from quart import request
from api.db.db_models import APIToken
from api.db.services.api_service import APITokenService, API4ConversationService
from api.db.services.user_service import UserTenantService
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
generate_confirmation_token
from api.utils.api_utils import generate_confirmation_token, get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
from common.time_utils import current_timestamp, datetime_format
from api.apps import login_required, current_user
@ -27,7 +26,7 @@ from api.apps import login_required, current_user
@manager.route('/new_token', methods=['POST']) # noqa: F821
@login_required
async def new_token():
req = await request.json
req = await get_request_json()
try:
tenants = UserTenantService.query(user_id=current_user.id)
if not tenants:
@ -73,7 +72,7 @@ def token_list():
@validate_request("tokens", "tenant_id")
@login_required
async def rm():
req = await request.json
req = await get_request_json()
try:
for token in req["tokens"]:
APITokenService.filter_delete(
@ -116,4 +115,3 @@ def stats():
return get_json_result(data=res)
except Exception as e:
return server_error_response(e)

View File

@ -14,7 +14,7 @@
# limitations under the License.
#
import requests
from common.http_client import async_request, sync_request
from .oauth import OAuthClient, UserInfo
@ -34,24 +34,49 @@ class GithubOAuthClient(OAuthClient):
def fetch_user_info(self, access_token, **kwargs):
"""
Fetch GitHub user info.
Fetch GitHub user info (synchronous).
"""
user_info = {}
try:
headers = {"Authorization": f"Bearer {access_token}"}
# user info
response = requests.get(self.userinfo_url, headers=headers, timeout=self.http_request_timeout)
response = sync_request("GET", self.userinfo_url, headers=headers, timeout=self.http_request_timeout)
response.raise_for_status()
user_info.update(response.json())
# email info
response = requests.get(self.userinfo_url+"/emails", headers=headers, timeout=self.http_request_timeout)
response.raise_for_status()
email_info = response.json()
user_info["email"] = next(
(email for email in email_info if email["primary"]), None
)["email"]
email_response = sync_request(
"GET", self.userinfo_url + "/emails", headers=headers, timeout=self.http_request_timeout
)
email_response.raise_for_status()
email_info = email_response.json()
user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
return self.normalize_user_info(user_info)
except requests.exceptions.RequestException as e:
except Exception as e:
raise ValueError(f"Failed to fetch github user info: {e}")
async def async_fetch_user_info(self, access_token, **kwargs):
"""Async variant of fetch_user_info using httpx."""
user_info = {}
headers = {"Authorization": f"Bearer {access_token}"}
try:
response = await async_request(
"GET",
self.userinfo_url,
headers=headers,
timeout=self.http_request_timeout,
)
response.raise_for_status()
user_info.update(response.json())
email_response = await async_request(
"GET",
self.userinfo_url + "/emails",
headers=headers,
timeout=self.http_request_timeout,
)
email_response.raise_for_status()
email_info = email_response.json()
user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
return self.normalize_user_info(user_info)
except Exception as e:
raise ValueError(f"Failed to fetch github user info: {e}")

View File

@ -14,8 +14,8 @@
# limitations under the License.
#
import requests
import urllib.parse
from common.http_client import async_request, sync_request
class UserInfo:
@ -74,15 +74,40 @@ class OAuthClient:
"redirect_uri": self.redirect_uri,
"grant_type": "authorization_code"
}
response = requests.post(
response = sync_request(
"POST",
self.token_url,
data=payload,
headers={"Accept": "application/json"},
timeout=self.http_request_timeout
timeout=self.http_request_timeout,
)
response.raise_for_status()
return response.json()
except requests.exceptions.RequestException as e:
except Exception as e:
raise ValueError(f"Failed to exchange authorization code for token: {e}")
async def async_exchange_code_for_token(self, code):
"""
Async variant of exchange_code_for_token using httpx.
"""
payload = {
"client_id": self.client_id,
"client_secret": self.client_secret,
"code": code,
"redirect_uri": self.redirect_uri,
"grant_type": "authorization_code",
}
try:
response = await async_request(
"POST",
self.token_url,
data=payload,
headers={"Accept": "application/json"},
timeout=self.http_request_timeout,
)
response.raise_for_status()
return response.json()
except Exception as e:
raise ValueError(f"Failed to exchange authorization code for token: {e}")
@ -92,11 +117,27 @@ class OAuthClient:
"""
try:
headers = {"Authorization": f"Bearer {access_token}"}
response = requests.get(self.userinfo_url, headers=headers, timeout=self.http_request_timeout)
response = sync_request("GET", self.userinfo_url, headers=headers, timeout=self.http_request_timeout)
response.raise_for_status()
user_info = response.json()
return self.normalize_user_info(user_info)
except requests.exceptions.RequestException as e:
except Exception as e:
raise ValueError(f"Failed to fetch user info: {e}")
async def async_fetch_user_info(self, access_token, **kwargs):
"""Async variant of fetch_user_info using httpx."""
headers = {"Authorization": f"Bearer {access_token}"}
try:
response = await async_request(
"GET",
self.userinfo_url,
headers=headers,
timeout=self.http_request_timeout,
)
response.raise_for_status()
user_info = response.json()
return self.normalize_user_info(user_info)
except Exception as e:
raise ValueError(f"Failed to fetch user info: {e}")

View File

@ -15,7 +15,7 @@
#
import jwt
import requests
from common.http_client import sync_request
from .oauth import OAuthClient
@ -50,10 +50,10 @@ class OIDCClient(OAuthClient):
"""
try:
metadata_url = f"{issuer}/.well-known/openid-configuration"
response = requests.get(metadata_url, timeout=7)
response = sync_request("GET", metadata_url, timeout=7)
response.raise_for_status()
return response.json()
except requests.exceptions.RequestException as e:
except Exception as e:
raise ValueError(f"Failed to fetch OIDC metadata: {e}")
@ -95,6 +95,13 @@ class OIDCClient(OAuthClient):
user_info.update(super().fetch_user_info(access_token).to_dict())
return self.normalize_user_info(user_info)
async def async_fetch_user_info(self, access_token, id_token=None, **kwargs):
user_info = {}
if id_token:
user_info = self.parse_id_token(id_token)
user_info.update((await super().async_fetch_user_info(access_token)).to_dict())
return self.normalize_user_info(user_info)
def normalize_user_info(self, user_info):
return super().normalize_user_info(user_info)

View File

@ -13,15 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import json
import logging
import re
import sys
from functools import partial
import trio
from quart import request, Response, make_response
from agent.component import LLM
from api.db import CanvasCategory, FileType
from api.db import CanvasCategory
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
from api.db.services.document_service import DocumentService
from api.db.services.file_service import FileService
@ -32,13 +30,12 @@ from api.db.services.user_canvas_version import UserCanvasVersionService
from common.constants import RetCode
from common.misc_utils import get_uuid
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result, \
request_json
get_request_json
from agent.canvas import Canvas
from peewee import MySQLDatabase, PostgresqlDatabase
from api.db.db_models import APIToken, Task
import time
from api.utils.file_utils import filename_type, read_potential_broken_pdf
from rag.flow.pipeline import Pipeline
from rag.nlp import search
from rag.utils.redis_conn import REDIS_CONN
@ -56,7 +53,7 @@ def templates():
@validate_request("canvas_ids")
@login_required
async def rm():
req = await request_json()
req = await get_request_json()
for i in req["canvas_ids"]:
if not UserCanvasService.accessible(i, current_user.id):
return get_json_result(
@ -70,7 +67,7 @@ async def rm():
@validate_request("dsl", "title")
@login_required
async def save():
req = await request_json()
req = await get_request_json()
if not isinstance(req["dsl"], str):
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
req["dsl"] = json.loads(req["dsl"])
@ -129,17 +126,17 @@ def getsse(canvas_id):
@validate_request("id")
@login_required
async def run():
req = await request_json()
req = await get_request_json()
query = req.get("query", "")
files = req.get("files", [])
inputs = req.get("inputs", {})
user_id = req.get("user_id", current_user.id)
if not UserCanvasService.accessible(req["id"], current_user.id):
if not await asyncio.to_thread(UserCanvasService.accessible, req["id"], current_user.id):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
code=RetCode.OPERATING_ERROR)
e, cvs = UserCanvasService.get_by_id(req["id"])
e, cvs = await asyncio.to_thread(UserCanvasService.get_by_id, req["id"])
if not e:
return get_data_error_result(message="canvas not found.")
@ -149,7 +146,7 @@ async def run():
if cvs.canvas_category == CanvasCategory.DataFlow:
task_id = get_uuid()
Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
ok, error_message = queue_dataflow(tenant_id=user_id, flow_id=req["id"], task_id=task_id, file=files[0], priority=0)
ok, error_message = await asyncio.to_thread(queue_dataflow, user_id, req["id"], task_id, files[0], 0)
if not ok:
return get_data_error_result(message=error_message)
return get_json_result(data={"message_id": task_id})
@ -186,7 +183,7 @@ async def run():
@validate_request("id", "dsl", "component_id")
@login_required
async def rerun():
req = await request_json()
req = await get_request_json()
doc = PipelineOperationLogService.get_documents_info(req["id"])
if not doc:
return get_data_error_result(message="Document not found.")
@ -224,7 +221,7 @@ def cancel(task_id):
@validate_request("id")
@login_required
async def reset():
req = await request_json()
req = await get_request_json()
if not UserCanvasService.accessible(req["id"], current_user.id):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
@ -250,71 +247,10 @@ async def upload(canvas_id):
return get_data_error_result(message="canvas not found.")
user_id = cvs["user_id"]
def structured(filename, filetype, blob, content_type):
nonlocal user_id
if filetype == FileType.PDF.value:
blob = read_potential_broken_pdf(blob)
location = get_uuid()
FileService.put_blob(user_id, location, blob)
return {
"id": location,
"name": filename,
"size": sys.getsizeof(blob),
"extension": filename.split(".")[-1].lower(),
"mime_type": content_type,
"created_by": user_id,
"created_at": time.time(),
"preview_url": None
}
if request.args.get("url"):
from crawl4ai import (
AsyncWebCrawler,
BrowserConfig,
CrawlerRunConfig,
DefaultMarkdownGenerator,
PruningContentFilter,
CrawlResult
)
try:
url = request.args.get("url")
filename = re.sub(r"\?.*", "", url.split("/")[-1])
async def adownload():
browser_config = BrowserConfig(
headless=True,
verbose=False,
)
async with AsyncWebCrawler(config=browser_config) as crawler:
crawler_config = CrawlerRunConfig(
markdown_generator=DefaultMarkdownGenerator(
content_filter=PruningContentFilter()
),
pdf=True,
screenshot=False
)
result: CrawlResult = await crawler.arun(
url=url,
config=crawler_config
)
return result
page = trio.run(adownload())
if page.pdf:
if filename.split(".")[-1].lower() != "pdf":
filename += ".pdf"
return get_json_result(data=structured(filename, "pdf", page.pdf, page.response_headers["content-type"]))
return get_json_result(data=structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"], user_id))
except Exception as e:
return server_error_response(e)
files = await request.files
file = files['file']
file = files['file'] if files and files.get("file") else None
try:
DocumentService.check_doc_health(user_id, file.filename)
return get_json_result(data=structured(file.filename, filename_type(file.filename), file.read(), file.content_type))
return get_json_result(data=FileService.upload_info(user_id, file, request.args.get("url")))
except Exception as e:
return server_error_response(e)
@ -343,7 +279,7 @@ def input_form():
@validate_request("id", "component_id", "params")
@login_required
async def debug():
req = await request_json()
req = await get_request_json()
if not UserCanvasService.accessible(req["id"], current_user.id):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
@ -375,7 +311,7 @@ async def debug():
@validate_request("db_type", "database", "username", "host", "port", "password")
@login_required
async def test_db_connect():
req = await request_json()
req = await get_request_json()
try:
if req["db_type"] in ["mysql", "mariadb"]:
db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
@ -520,7 +456,7 @@ def list_canvas():
@validate_request("id", "title", "permission")
@login_required
async def setting():
req = await request_json()
req = await get_request_json()
req["user_id"] = current_user.id
if not UserCanvasService.accessible(req["id"], current_user.id):

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import datetime
import json
import re
@ -27,7 +28,7 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.search_service import SearchService
from api.db.services.user_service import UserTenantService
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
request_json
get_request_json
from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search
@ -42,7 +43,7 @@ from api.apps import login_required, current_user
@login_required
@validate_request("doc_id")
async def list_chunk():
req = await request_json()
req = await get_request_json()
doc_id = req["doc_id"]
page = int(req.get("page", 1))
size = int(req.get("size", 30))
@ -123,7 +124,7 @@ def get():
@login_required
@validate_request("doc_id", "chunk_id", "content_with_weight")
async def set():
req = await request_json()
req = await get_request_json()
d = {
"id": req["chunk_id"],
"content_with_weight": req["content_with_weight"]}
@ -147,31 +148,35 @@ async def set():
d["available_int"] = req["available_int"]
try:
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
def _set_sync():
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
embd_id = DocumentService.get_embd_id(req["doc_id"])
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
embd_id = DocumentService.get_embd_id(req["doc_id"])
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
if doc.parser_id == ParserType.QA:
arr = [
t for t in re.split(
r"[\n\t]",
req["content_with_weight"]) if len(t) > 1]
q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
d = beAdoc(d, q, a, not any(
[rag_tokenizer.is_chinese(t) for t in q + a]))
_d = d
if doc.parser_id == ParserType.QA:
arr = [
t for t in re.split(
r"[\n\t]",
req["content_with_weight"]) if len(t) > 1]
q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
_d = beAdoc(d, q, a, not any(
[rag_tokenizer.is_chinese(t) for t in q + a]))
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
d["q_%d_vec" % len(v)] = v.tolist()
settings.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id)
return get_json_result(data=True)
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not _d.get("question_kwd") else "\n".join(_d["question_kwd"])])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
_d["q_%d_vec" % len(v)] = v.tolist()
settings.docStoreConn.update({"id": req["chunk_id"]}, _d, search.index_name(tenant_id), doc.kb_id)
return get_json_result(data=True)
return await asyncio.to_thread(_set_sync)
except Exception as e:
return server_error_response(e)
@ -180,18 +185,21 @@ async def set():
@login_required
@validate_request("chunk_ids", "available_int", "doc_id")
async def switch():
req = await request_json()
req = await get_request_json()
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
for cid in req["chunk_ids"]:
if not settings.docStoreConn.update({"id": cid},
{"available_int": int(req["available_int"])},
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
doc.kb_id):
return get_data_error_result(message="Index updating failure")
return get_json_result(data=True)
def _switch_sync():
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
for cid in req["chunk_ids"]:
if not settings.docStoreConn.update({"id": cid},
{"available_int": int(req["available_int"])},
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
doc.kb_id):
return get_data_error_result(message="Index updating failure")
return get_json_result(data=True)
return await asyncio.to_thread(_switch_sync)
except Exception as e:
return server_error_response(e)
@ -200,22 +208,25 @@ async def switch():
@login_required
@validate_request("chunk_ids", "doc_id")
async def rm():
req = await request_json()
req = await get_request_json()
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
doc.kb_id):
return get_data_error_result(message="Chunk deleting failure")
deleted_chunk_ids = req["chunk_ids"]
chunk_number = len(deleted_chunk_ids)
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
for cid in deleted_chunk_ids:
if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
settings.STORAGE_IMPL.rm(doc.kb_id, cid)
return get_json_result(data=True)
def _rm_sync():
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
doc.kb_id):
return get_data_error_result(message="Chunk deleting failure")
deleted_chunk_ids = req["chunk_ids"]
chunk_number = len(deleted_chunk_ids)
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
for cid in deleted_chunk_ids:
if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
settings.STORAGE_IMPL.rm(doc.kb_id, cid)
return get_json_result(data=True)
return await asyncio.to_thread(_rm_sync)
except Exception as e:
return server_error_response(e)
@ -224,7 +235,7 @@ async def rm():
@login_required
@validate_request("doc_id", "content_with_weight")
async def create():
req = await request_json()
req = await get_request_json()
chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest()
d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
"content_with_weight": req["content_with_weight"]}
@ -245,35 +256,38 @@ async def create():
d["tag_feas"] = req["tag_feas"]
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
d["kb_id"] = [doc.kb_id]
d["docnm_kwd"] = doc.name
d["title_tks"] = rag_tokenizer.tokenize(doc.name)
d["doc_id"] = doc.id
def _create_sync():
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
d["kb_id"] = [doc.kb_id]
d["docnm_kwd"] = doc.name
d["title_tks"] = rag_tokenizer.tokenize(doc.name)
d["doc_id"] = doc.id
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
if not e:
return get_data_error_result(message="Knowledgebase not found!")
if kb.pagerank:
d[PAGERANK_FLD] = kb.pagerank
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
if not e:
return get_data_error_result(message="Knowledgebase not found!")
if kb.pagerank:
d[PAGERANK_FLD] = kb.pagerank
embd_id = DocumentService.get_embd_id(req["doc_id"])
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
embd_id = DocumentService.get_embd_id(req["doc_id"])
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec" % len(v)] = v.tolist()
settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec" % len(v)] = v.tolist()
settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
DocumentService.increment_chunk_num(
doc.id, doc.kb_id, c, 1, 0)
return get_json_result(data={"chunk_id": chunck_id})
DocumentService.increment_chunk_num(
doc.id, doc.kb_id, c, 1, 0)
return get_json_result(data={"chunk_id": chunck_id})
return await asyncio.to_thread(_create_sync)
except Exception as e:
return server_error_response(e)
@ -282,7 +296,7 @@ async def create():
@login_required
@validate_request("kb_id", "question")
async def retrieval_test():
req = await request_json()
req = await get_request_json()
page = int(req.get("page", 1))
size = int(req.get("size", 30))
question = req["question"]
@ -297,25 +311,28 @@ async def retrieval_test():
use_kg = req.get("use_kg", False)
top = int(req.get("top_k", 1024))
langs = req.get("cross_languages", [])
tenant_ids = []
user_id = current_user.id
if req.get("search_id", ""):
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
meta_data_filter = search_config.get("meta_data_filter", {})
metas = DocumentService.get_meta_by_kbs(kb_ids)
if meta_data_filter.get("method") == "auto":
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
filters: dict = gen_meta_filter(chat_mdl, metas, question)
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
if not doc_ids:
doc_ids = None
elif meta_data_filter.get("method") == "manual":
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
if meta_data_filter["manual"] and not doc_ids:
doc_ids = ["-999"]
def _retrieval_sync():
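# Work on a local copy of doc_ids: the metadata filters below may extend the
# list or reset it to None, and rebinding the enclosing doc_ids from inside
# this closure would otherwise require `nonlocal`.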
local_doc_ids = list(doc_ids) if doc_ids else []
tenant_ids = []
try:
tenants = UserTenantService.query(user_id=current_user.id)
if req.get("search_id", ""):
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
meta_data_filter = search_config.get("meta_data_filter", {})
metas = DocumentService.get_meta_by_kbs(kb_ids)
if meta_data_filter.get("method") == "auto":
chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
filters: dict = gen_meta_filter(chat_mdl, metas, question)
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
if not local_doc_ids:
local_doc_ids = None
elif meta_data_filter.get("method") == "manual":
local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
if meta_data_filter["manual"] and not local_doc_ids:
local_doc_ids = ["-999"]
tenants = UserTenantService.query(user_id=user_id)
for kb_id in kb_ids:
for tenant in tenants:
if KnowledgebaseService.query(
@ -331,8 +348,9 @@ async def retrieval_test():
if not e:
return get_data_error_result(message="Knowledgebase not found!")
_question = question
if langs:
question = cross_languages(kb.tenant_id, None, question, langs)
_question = cross_languages(kb.tenant_id, None, _question, langs)
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
@ -342,19 +360,19 @@ async def retrieval_test():
if req.get("keyword", False):
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question)
_question += keyword_extraction(chat_mdl, _question)
labels = label_question(question, [kb])
ranks = settings.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
labels = label_question(_question, [kb])
ranks = settings.retriever.retrieval(_question, embd_mdl, tenant_ids, kb_ids, page, size,
float(req.get("similarity_threshold", 0.0)),
float(req.get("vector_similarity_weight", 0.3)),
top,
doc_ids, rerank_mdl=rerank_mdl,
local_doc_ids, rerank_mdl=rerank_mdl,
highlight=req.get("highlight", False),
rank_feature=labels
)
if use_kg:
ck = settings.kg_retriever.retrieval(question,
ck = settings.kg_retriever.retrieval(_question,
tenant_ids,
kb_ids,
embd_mdl,
@ -367,6 +385,9 @@ async def retrieval_test():
ranks["labels"] = labels
return get_json_result(data=ranks)
try:
return await asyncio.to_thread(_retrieval_sync)
except Exception as e:
if str(e).find("not_found") > 0:
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',



@ -26,10 +26,10 @@ from google_auth_oauthlib.flow import Flow
from api.db import InputType
from api.db.services.connector_service import ConnectorService, SyncLogsService
from api.utils.api_utils import get_data_error_result, get_json_result, validate_request
from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, validate_request
from common.constants import RetCode, TaskStatus
from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, DocumentSource
from common.data_source.google_util.constant import GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES
from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, GMAIL_WEB_OAUTH_REDIRECT_URI, DocumentSource
from common.data_source.google_util.constant import GOOGLE_WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES
from common.misc_utils import get_uuid
from rag.utils.redis_conn import REDIS_CONN
from api.apps import login_required, current_user
@ -38,7 +38,7 @@ from api.apps import login_required, current_user
@manager.route("/set", methods=["POST"]) # noqa: F821
@login_required
async def set_connector():
req = await request.json
req = await get_request_json()
if req.get("id"):
conn = {fld: req[fld] for fld in ["prune_freq", "refresh_freq", "config", "timeout_secs"] if fld in req}
ConnectorService.update_by_id(req["id"], conn)
@ -90,7 +90,7 @@ def list_logs(connector_id):
@manager.route("/<connector_id>/resume", methods=["PUT"]) # noqa: F821
@login_required
async def resume(connector_id):
req = await request.json
req = await get_request_json()
if req.get("resume"):
ConnectorService.resume(connector_id, TaskStatus.SCHEDULE)
else:
@ -102,7 +102,7 @@ async def resume(connector_id):
@login_required
@validate_request("kb_id")
async def rebuild(connector_id):
req = await request.json
req = await get_request_json()
err = ConnectorService.rebuild(req["kb_id"], connector_id, current_user.id)
if err:
return get_json_result(data=False, message=err, code=RetCode.SERVER_ERROR)
@ -122,12 +122,30 @@ GOOGLE_WEB_FLOW_RESULT_PREFIX = "google_drive_web_flow_result"
WEB_FLOW_TTL_SECS = 15 * 60
def _web_state_cache_key(flow_id: str) -> str:
return f"{GOOGLE_WEB_FLOW_STATE_PREFIX}:{flow_id}"
def _web_state_cache_key(flow_id: str, source_type: str | None = None) -> str:
"""Return Redis key for web OAuth state.
The default prefix keeps backward compatibility for Google Drive.
When source_type == "gmail", a different prefix is used so that
Drive/Gmail flows don't clash in Redis.
"""
if source_type == "gmail":
prefix = "gmail_web_flow_state"
else:
prefix = GOOGLE_WEB_FLOW_STATE_PREFIX
return f"{prefix}:{flow_id}"
def _web_result_cache_key(flow_id: str) -> str:
return f"{GOOGLE_WEB_FLOW_RESULT_PREFIX}:{flow_id}"
def _web_result_cache_key(flow_id: str, source_type: str | None = None) -> str:
"""Return Redis key for web OAuth result.
Mirrors _web_state_cache_key logic for result storage.
"""
if source_type == "gmail":
prefix = "gmail_web_flow_result"
else:
prefix = GOOGLE_WEB_FLOW_RESULT_PREFIX
return f"{prefix}:{flow_id}"
def _load_credentials(payload: str | dict[str, Any]) -> dict[str, Any]:
@ -146,19 +164,24 @@ def _get_web_client_config(credentials: dict[str, Any]) -> dict[str, Any]:
return {"web": web_section}
async def _render_web_oauth_popup(flow_id: str, success: bool, message: str):
async def _render_web_oauth_popup(flow_id: str, success: bool, message: str, source="drive"):
status = "success" if success else "error"
auto_close = "window.close();" if success else ""
escaped_message = escape(message)
# Drive: ragflow-google-drive-oauth
# Gmail: ragflow-gmail-oauth
payload_type = f"ragflow-{source}-oauth"
payload_json = json.dumps(
{
"type": "ragflow-google-drive-oauth",
"type": payload_type,
"status": status,
"flowId": flow_id or "",
"message": message,
}
)
html = GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE.format(
# TODO(google-oauth): title/heading/message may need to reflect drive/gmail based on cached type
html = GOOGLE_WEB_OAUTH_POPUP_TEMPLATE.format(
title=f"Google {source.capitalize()} Authorization",
heading="Authorization complete" if success else "Authorization failed",
message=escaped_message,
payload_json=payload_json,
@ -169,20 +192,33 @@ async def _render_web_oauth_popup(flow_id: str, success: bool, message: str):
return response
@manager.route("/google-drive/oauth/web/start", methods=["POST"]) # noqa: F821
@manager.route("/google/oauth/web/start", methods=["POST"]) # noqa: F821
@login_required
@validate_request("credentials")
async def start_google_drive_web_oauth():
if not GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI:
async def start_google_web_oauth():
source = request.args.get("type", "google-drive")
if source not in ("google-drive", "gmail"):
return get_json_result(code=RetCode.ARGUMENT_ERROR, message="Invalid Google OAuth type.")
if source == "gmail":
redirect_uri = GMAIL_WEB_OAUTH_REDIRECT_URI
scopes = GOOGLE_SCOPES[DocumentSource.GMAIL]
else:
redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI if source == "google-drive" else GMAIL_WEB_OAUTH_REDIRECT_URI
scopes = GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE if source == "google-drive" else DocumentSource.GMAIL]
if not redirect_uri:
return get_json_result(
code=RetCode.SERVER_ERROR,
message="Google Drive OAuth redirect URI is not configured on the server.",
message="Google OAuth redirect URI is not configured on the server.",
)
req = await request.json or {}
req = await get_request_json()
raw_credentials = req.get("credentials", "")
try:
credentials = _load_credentials(raw_credentials)
print(credentials)
except ValueError as exc:
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=str(exc))
@ -199,8 +235,8 @@ async def start_google_drive_web_oauth():
flow_id = str(uuid.uuid4())
try:
flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE])
flow.redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI
flow = Flow.from_client_config(client_config, scopes=scopes)
flow.redirect_uri = redirect_uri
authorization_url, _ = flow.authorization_url(
access_type="offline",
include_granted_scopes="true",
@ -219,7 +255,7 @@ async def start_google_drive_web_oauth():
"client_config": client_config,
"created_at": int(time.time()),
}
REDIS_CONN.set_obj(_web_state_cache_key(flow_id), cache_payload, WEB_FLOW_TTL_SECS)
REDIS_CONN.set_obj(_web_state_cache_key(flow_id, source), cache_payload, WEB_FLOW_TTL_SECS)
return get_json_result(
data={
@ -230,60 +266,122 @@ async def start_google_drive_web_oauth():
)
@manager.route("/google-drive/oauth/web/callback", methods=["GET"]) # noqa: F821
async def google_drive_web_oauth_callback():
@manager.route("/gmail/oauth/web/callback", methods=["GET"]) # noqa: F821
async def google_gmail_web_oauth_callback():
state_id = request.args.get("state")
error = request.args.get("error")
source = "gmail"
if source != 'gmail':
return await _render_web_oauth_popup("", False, "Invalid Google OAuth type.", source)
error_description = request.args.get("error_description") or error
if not state_id:
return await _render_web_oauth_popup("", False, "Missing OAuth state parameter.")
return await _render_web_oauth_popup("", False, "Missing OAuth state parameter.", source)
state_cache = REDIS_CONN.get(_web_state_cache_key(state_id))
state_cache = REDIS_CONN.get(_web_state_cache_key(state_id, source))
if not state_cache:
return await _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.")
return await _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.", source)
state_obj = json.loads(state_cache)
client_config = state_obj.get("client_config")
if not client_config:
REDIS_CONN.delete(_web_state_cache_key(state_id))
return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.")
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.", source)
if error:
REDIS_CONN.delete(_web_state_cache_key(state_id))
return await _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.")
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
return await _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.", source)
code = request.args.get("code")
if not code:
return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.")
return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.", source)
try:
flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE])
flow.redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI
# TODO(google-oauth): branch scopes/redirect_uri based on source_type (drive vs gmail)
flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GMAIL])
flow.redirect_uri = GMAIL_WEB_OAUTH_REDIRECT_URI
flow.fetch_token(code=code)
except Exception as exc: # pragma: no cover - defensive
logging.exception("Failed to exchange Google OAuth code: %s", exc)
REDIS_CONN.delete(_web_state_cache_key(state_id))
return await _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.")
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
return await _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.", source)
creds_json = flow.credentials.to_json()
result_payload = {
"user_id": state_obj.get("user_id"),
"credentials": creds_json,
}
REDIS_CONN.set_obj(_web_result_cache_key(state_id), result_payload, WEB_FLOW_TTL_SECS)
REDIS_CONN.delete(_web_state_cache_key(state_id))
REDIS_CONN.set_obj(_web_result_cache_key(state_id, source), result_payload, WEB_FLOW_TTL_SECS)
return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.")
print("\n\n", _web_result_cache_key(state_id, source), "\n\n")
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.", source)
@manager.route("/google-drive/oauth/web/result", methods=["POST"]) # noqa: F821
@manager.route("/google-drive/oauth/web/callback", methods=["GET"]) # noqa: F821
async def google_drive_web_oauth_callback():
state_id = request.args.get("state")
error = request.args.get("error")
source = "google-drive"
if source not in ("google-drive", "gmail"):
return await _render_web_oauth_popup("", False, "Invalid Google OAuth type.", source)
error_description = request.args.get("error_description") or error
if not state_id:
return await _render_web_oauth_popup("", False, "Missing OAuth state parameter.", source)
state_cache = REDIS_CONN.get(_web_state_cache_key(state_id, source))
if not state_cache:
return await _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.", source)
state_obj = json.loads(state_cache)
client_config = state_obj.get("client_config")
if not client_config:
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.", source)
if error:
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
return await _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.", source)
code = request.args.get("code")
if not code:
return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.", source)
try:
# TODO(google-oauth): branch scopes/redirect_uri based on source_type (drive vs gmail)
flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE])
flow.redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI
flow.fetch_token(code=code)
except Exception as exc: # pragma: no cover - defensive
logging.exception("Failed to exchange Google OAuth code: %s", exc)
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
return await _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.", source)
creds_json = flow.credentials.to_json()
result_payload = {
"user_id": state_obj.get("user_id"),
"credentials": creds_json,
}
REDIS_CONN.set_obj(_web_result_cache_key(state_id, source), result_payload, WEB_FLOW_TTL_SECS)
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.", source)
@manager.route("/google/oauth/web/result", methods=["POST"]) # noqa: F821
@login_required
@validate_request("flow_id")
async def poll_google_drive_web_result():
async def poll_google_web_result():
req = await request.json or {}
source = request.args.get("type")
if source not in ("google-drive", "gmail"):
return get_json_result(code=RetCode.ARGUMENT_ERROR, message="Invalid Google OAuth type.")
flow_id = req.get("flow_id")
cache_raw = REDIS_CONN.get(_web_result_cache_key(flow_id))
cache_raw = REDIS_CONN.get(_web_result_cache_key(flow_id, source))
if not cache_raw:
return get_json_result(code=RetCode.RUNNING, message="Authorization is still pending.")
@ -291,5 +389,5 @@ async def poll_google_drive_web_result():
if result.get("user_id") != current_user.id:
return get_json_result(code=RetCode.PERMISSION_ERROR, message="You are not allowed to access this authorization result.")
REDIS_CONN.delete(_web_result_cache_key(flow_id))
REDIS_CONN.delete(_web_result_cache_key(flow_id, source))
return get_json_result(data={"credentials": result.get("credentials")})


@ -14,9 +14,11 @@
# limitations under the License.
#
import json
import os
import re
import logging
from copy import deepcopy
import tempfile
from quart import Response, request
from api.apps import current_user, login_required
from api.db.db_models import APIToken
@ -26,7 +28,7 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.search_service import SearchService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
from rag.prompts.template import load_prompt
from rag.prompts.generator import chunks_format
from common.constants import RetCode, LLMType
@ -35,7 +37,7 @@ from common.constants import RetCode, LLMType
@manager.route("/set", methods=["POST"]) # noqa: F821
@login_required
async def set_conversation():
req = await request.json
req = await get_request_json()
conv_id = req.get("conversation_id")
is_new = req.get("is_new")
name = req.get("name", "New conversation")
@ -78,7 +80,7 @@ async def set_conversation():
@manager.route("/get", methods=["GET"]) # noqa: F821
@login_required
def get():
async def get():
conv_id = request.args["conversation_id"]
try:
e, conv = ConversationService.get_by_id(conv_id)
@ -129,7 +131,7 @@ def getsse(dialog_id):
@manager.route("/rm", methods=["POST"]) # noqa: F821
@login_required
async def rm():
req = await request.json
req = await get_request_json()
conv_ids = req["conversation_ids"]
try:
for cid in conv_ids:
@ -150,7 +152,7 @@ async def rm():
@manager.route("/list", methods=["GET"]) # noqa: F821
@login_required
def list_conversation():
async def list_conversation():
dialog_id = request.args["dialog_id"]
try:
if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
@ -167,7 +169,7 @@ def list_conversation():
@login_required
@validate_request("conversation_id", "messages")
async def completion():
req = await request.json
req = await get_request_json()
msg = []
for m in req["messages"]:
if m["role"] == "system":
@ -248,11 +250,69 @@ async def completion():
except Exception as e:
return server_error_response(e)
@manager.route("/sequence2txt", methods=["POST"]) # noqa: F821
@login_required
async def sequence2txt():
req = await request.form
stream_mode = req.get("stream", "false").lower() == "true"
files = await request.files
if "file" not in files:
return get_data_error_result(message="Missing 'file' in multipart form-data")
uploaded = files["file"]
ALLOWED_EXTS = {
".wav", ".mp3", ".m4a", ".aac",
".flac", ".ogg", ".webm",
".opus", ".wma"
}
filename = uploaded.filename or ""
suffix = os.path.splitext(filename)[-1].lower()
if suffix not in ALLOWED_EXTS:
return get_data_error_result(message=
f"Unsupported audio format: {suffix}. "
f"Allowed: {', '.join(sorted(ALLOWED_EXTS))}"
)
fd, temp_audio_path = tempfile.mkstemp(suffix=suffix)
os.close(fd)
await uploaded.save(temp_audio_path)
tenants = TenantService.get_info_by(current_user.id)
if not tenants:
return get_data_error_result(message="Tenant not found!")
asr_id = tenants[0]["asr_id"]
if not asr_id:
return get_data_error_result(message="No default ASR model is set")
asr_mdl=LLMBundle(tenants[0]["tenant_id"], LLMType.SPEECH2TEXT, asr_id)
if not stream_mode:
text = asr_mdl.transcription(temp_audio_path)
try:
os.remove(temp_audio_path)
except Exception as e:
logging.error(f"Failed to remove temp audio file: {str(e)}")
return get_json_result(data={"text": text})
async def event_stream():
try:
for evt in asr_mdl.stream_transcription(temp_audio_path):
yield f"data: {json.dumps(evt, ensure_ascii=False)}\n\n"
except Exception as e:
err = {"event": "error", "text": str(e)}
yield f"data: {json.dumps(err, ensure_ascii=False)}\n\n"
finally:
try:
os.remove(temp_audio_path)
except Exception as e:
logging.error(f"Failed to remove temp audio file: {str(e)}")
return Response(event_stream(), content_type="text/event-stream")
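# Example call (illustrative only; the URL prefix and auth header depend on
# the deployment):
#   curl -H "Authorization: Bearer <token>" \
#        -F "file=@sample.wav" -F "stream=true" \
#        http://<host>/v1/conversation/sequence2txt
# With stream=true the response is an SSE stream; otherwise a single JSON
# payload {"text": "..."} is returned.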
@manager.route("/tts", methods=["POST"]) # noqa: F821
@login_required
async def tts():
req = await request.json
req = await get_request_json()
text = req["text"]
tenants = TenantService.get_info_by(current_user.id)
@ -285,7 +345,7 @@ async def tts():
@login_required
@validate_request("conversation_id", "message_id")
async def delete_msg():
req = await request.json
req = await get_request_json()
e, conv = ConversationService.get_by_id(req["conversation_id"])
if not e:
return get_data_error_result(message="Conversation not found!")
@ -308,7 +368,7 @@ async def delete_msg():
@login_required
@validate_request("conversation_id", "message_id")
async def thumbup():
req = await request.json
req = await get_request_json()
e, conv = ConversationService.get_by_id(req["conversation_id"])
if not e:
return get_data_error_result(message="Conversation not found!")
@ -335,7 +395,7 @@ async def thumbup():
@login_required
@validate_request("question", "kb_ids")
async def ask_about():
req = await request.json
req = await get_request_json()
uid = current_user.id
search_id = req.get("search_id", "")
@ -367,7 +427,7 @@ async def ask_about():
@login_required
@validate_request("question", "kb_ids")
async def mindmap():
req = await request.json
req = await get_request_json()
search_id = req.get("search_id", "")
search_app = SearchService.get_detail(search_id) if search_id else {}
search_config = search_app.get("search_config", {}) if search_app else {}
@ -385,7 +445,7 @@ async def mindmap():
@login_required
@validate_request("question")
async def related_questions():
req = await request.json
req = await get_request_json()
search_id = req.get("search_id", "")
search_config = {}
@ -402,7 +462,7 @@ async def related_questions():
if "parameter" in gen_conf:
del gen_conf["parameter"]
prompt = load_prompt("related_question")
ans = chat_mdl.chat(
ans = await chat_mdl.async_chat(
prompt,
[
{


@ -21,10 +21,9 @@ from common.constants import StatusEnum
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
from common.misc_utils import get_uuid
from common.constants import RetCode
from api.utils.api_utils import get_json_result
from api.apps import login_required, current_user
@ -32,7 +31,7 @@ from api.apps import login_required, current_user
@validate_request("prompt_config")
@login_required
async def set_dialog():
req = await request.json
req = await get_request_json()
dialog_id = req.get("dialog_id", "")
is_create = not dialog_id
name = req.get("name", "New Dialog")
@ -181,7 +180,7 @@ async def list_dialogs_next():
else:
desc = True
req = await request.get_json()
req = await get_request_json()
owner_ids = req.get("owner_ids", [])
try:
if not owner_ids:
@ -209,7 +208,7 @@ async def list_dialogs_next():
@login_required
@validate_request("dialog_ids")
async def rm():
req = await request.json
req = await get_request_json()
dialog_list=[]
tenants = UserTenantService.query(user_id=current_user.id)
try:


@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License
#
import asyncio
import json
import os.path
import pathlib
@ -36,7 +37,7 @@ from api.utils.api_utils import (
get_data_error_result,
get_json_result,
server_error_response,
validate_request, request_json,
validate_request, get_request_json,
)
from api.utils.file_utils import filename_type, thumbnail
from common.file_utils import get_project_base_directory
@ -72,7 +73,7 @@ async def upload():
if not check_kb_team_permission(kb, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
err, files = FileService.upload_document(kb, file_objs, current_user.id)
err, files = await asyncio.to_thread(FileService.upload_document, kb, file_objs, current_user.id)
if err:
return get_json_result(data=files, message="\n".join(err), code=RetCode.SERVER_ERROR)
@ -153,7 +154,7 @@ async def web_crawl():
@login_required
@validate_request("name", "kb_id")
async def create():
req = await request_json()
req = await get_request_json()
kb_id = req["kb_id"]
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
@ -230,7 +231,7 @@ async def list_docs():
create_time_from = int(request.args.get("create_time_from", 0))
create_time_to = int(request.args.get("create_time_to", 0))
req = await request.get_json()
req = await get_request_json()
run_status = req.get("run_status", [])
if run_status:
@ -271,7 +272,7 @@ async def list_docs():
@manager.route("/filter", methods=["POST"]) # noqa: F821
@login_required
async def get_filter():
req = await request.get_json()
req = await get_request_json()
kb_id = req.get("kb_id")
if not kb_id:
@ -309,7 +310,7 @@ async def get_filter():
@manager.route("/infos", methods=["POST"]) # noqa: F821
@login_required
async def doc_infos():
req = await request_json()
req = await get_request_json()
doc_ids = req["doc_ids"]
for doc_id in doc_ids:
if not DocumentService.accessible(doc_id, current_user.id):
@ -341,7 +342,7 @@ def thumbnails():
@login_required
@validate_request("doc_ids", "status")
async def change_status():
req = await request.get_json()
req = await get_request_json()
doc_ids = req.get("doc_ids", [])
status = str(req.get("status", ""))
@ -381,7 +382,7 @@ async def change_status():
@login_required
@validate_request("doc_id")
async def rm():
req = await request_json()
req = await get_request_json()
doc_ids = req["doc_id"]
if isinstance(doc_ids, str):
doc_ids = [doc_ids]
@ -390,7 +391,7 @@ async def rm():
if not DocumentService.accessible4deletion(doc_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
errors = FileService.delete_docs(doc_ids, current_user.id)
errors = await asyncio.to_thread(FileService.delete_docs, doc_ids, current_user.id)
if errors:
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
@ -402,45 +403,49 @@ async def rm():
@login_required
@validate_request("doc_ids", "run")
async def run():
req = await request_json()
for doc_id in req["doc_ids"]:
if not DocumentService.accessible(doc_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
req = await get_request_json()
try:
kb_table_num_map = {}
for id in req["doc_ids"]:
info = {"run": str(req["run"]), "progress": 0}
if str(req["run"]) == TaskStatus.RUNNING.value and req.get("delete", False):
info["progress_msg"] = ""
info["chunk_num"] = 0
info["token_num"] = 0
def _run_sync():
for doc_id in req["doc_ids"]:
if not DocumentService.accessible(doc_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
tenant_id = DocumentService.get_tenant_id(id)
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
e, doc = DocumentService.get_by_id(id)
if not e:
return get_data_error_result(message="Document not found!")
kb_table_num_map = {}
for id in req["doc_ids"]:
info = {"run": str(req["run"]), "progress": 0}
if str(req["run"]) == TaskStatus.RUNNING.value and req.get("delete", False):
info["progress_msg"] = ""
info["chunk_num"] = 0
info["token_num"] = 0
if str(req["run"]) == TaskStatus.CANCEL.value:
if str(doc.run) == TaskStatus.RUNNING.value:
cancel_all_task_of(id)
else:
return get_data_error_result(message="Cannot cancel a task that is not in RUNNING status")
if all([("delete" not in req or req["delete"]), str(req["run"]) == TaskStatus.RUNNING.value, str(doc.run) == TaskStatus.DONE.value]):
DocumentService.clear_chunk_num_when_rerun(doc.id)
tenant_id = DocumentService.get_tenant_id(id)
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
e, doc = DocumentService.get_by_id(id)
if not e:
return get_data_error_result(message="Document not found!")
DocumentService.update_by_id(id, info)
if req.get("delete", False):
TaskService.filter_delete([Task.doc_id == id])
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
if str(req["run"]) == TaskStatus.CANCEL.value:
if str(doc.run) == TaskStatus.RUNNING.value:
cancel_all_task_of(id)
else:
return get_data_error_result(message="Cannot cancel a task that is not in RUNNING status")
if all([("delete" not in req or req["delete"]), str(req["run"]) == TaskStatus.RUNNING.value, str(doc.run) == TaskStatus.DONE.value]):
DocumentService.clear_chunk_num_when_rerun(doc.id)
if str(req["run"]) == TaskStatus.RUNNING.value:
doc = doc.to_dict()
DocumentService.run(tenant_id, doc, kb_table_num_map)
DocumentService.update_by_id(id, info)
if req.get("delete", False):
TaskService.filter_delete([Task.doc_id == id])
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
return get_json_result(data=True)
if str(req["run"]) == TaskStatus.RUNNING.value:
doc_dict = doc.to_dict()
DocumentService.run(tenant_id, doc_dict, kb_table_num_map)
return get_json_result(data=True)
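# Run the whole per-document loop in a worker thread; any early return
# (authorization or validation error) inside _run_sync is passed straight back
# as the HTTP response.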
return await asyncio.to_thread(_run_sync)
except Exception as e:
return server_error_response(e)
@ -449,46 +454,50 @@ async def run():
@login_required
@validate_request("doc_id", "name")
async def rename():
req = await request_json()
if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
req = await get_request_json()
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.ARGUMENT_ERROR)
if len(req["name"].encode("utf-8")) > FILE_NAME_LEN_LIMIT:
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR)
def _rename_sync():
if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
if d.name == req["name"]:
return get_data_error_result(message="Duplicated document name in the same knowledgebase.")
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.ARGUMENT_ERROR)
if len(req["name"].encode("utf-8")) > FILE_NAME_LEN_LIMIT:
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR)
if not DocumentService.update_by_id(req["doc_id"], {"name": req["name"]}):
return get_data_error_result(message="Database error (Document rename)!")
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
if d.name == req["name"]:
return get_data_error_result(message="Duplicated document name in the same knowledgebase.")
informs = File2DocumentService.get_by_document_id(req["doc_id"])
if informs:
e, file = FileService.get_by_id(informs[0].file_id)
FileService.update_by_id(file.id, {"name": req["name"]})
if not DocumentService.update_by_id(req["doc_id"], {"name": req["name"]}):
return get_data_error_result(message="Database error (Document rename)!")
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
title_tks = rag_tokenizer.tokenize(req["name"])
es_body = {
"docnm_kwd": req["name"],
"title_tks": title_tks,
"title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
}
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.update(
{"doc_id": req["doc_id"]},
es_body,
search.index_name(tenant_id),
doc.kb_id,
)
informs = File2DocumentService.get_by_document_id(req["doc_id"])
if informs:
e, file = FileService.get_by_id(informs[0].file_id)
FileService.update_by_id(file.id, {"name": req["name"]})
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
title_tks = rag_tokenizer.tokenize(req["name"])
es_body = {
"docnm_kwd": req["name"],
"title_tks": title_tks,
"title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
}
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.update(
{"doc_id": req["doc_id"]},
es_body,
search.index_name(tenant_id),
doc.kb_id,
)
return get_json_result(data=True)
return await asyncio.to_thread(_rename_sync)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@ -502,7 +511,8 @@ async def get(doc_id):
return get_data_error_result(message="Document not found!")
b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
response = await make_response(settings.STORAGE_IMPL.get(b, n))
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
response = await make_response(data)
ext = re.search(r"\.([^.]+)$", doc.name.lower())
ext = ext.group(1) if ext else None
@ -523,8 +533,7 @@ async def get(doc_id):
async def download_attachment(attachment_id):
try:
ext = request.args.get("ext", "markdown")
data = settings.STORAGE_IMPL.get(current_user.id, attachment_id)
# data = settings.STORAGE_IMPL.get("eb500d50bb0411f0907561d2782adda5", attachment_id)
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, current_user.id, attachment_id)
response = await make_response(data)
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
@ -539,7 +548,7 @@ async def download_attachment(attachment_id):
@validate_request("doc_id")
async def change_parser():
req = await request_json()
req = await get_request_json()
if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
@ -596,7 +605,8 @@ async def get_image(image_id):
if len(arr) != 2:
return get_data_error_result(message="Image not found.")
bkt, nm = image_id.split("-")
response = await make_response(settings.STORAGE_IMPL.get(bkt, nm))
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, bkt, nm)
response = await make_response(data)
response.headers.set("Content-Type", "image/JPEG")
return response
except Exception as e:
@ -607,7 +617,7 @@ async def get_image(image_id):
@login_required
@validate_request("conversation_id")
async def upload_and_parse():
files = await request.file
files = await request.files
if "file" not in files:
return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)
@ -624,7 +634,8 @@ async def upload_and_parse():
@manager.route("/parse", methods=["POST"]) # noqa: F821
@login_required
async def parse():
url = await request.json.get("url") if await request.json else ""
req = await get_request_json()
url = req.get("url", "")
if url:
if not is_valid_url(url):
return get_json_result(data=False, message="The URL format is invalid", code=RetCode.ARGUMENT_ERROR)
@ -679,7 +690,7 @@ async def parse():
@login_required
@validate_request("doc_id", "meta")
async def set_meta():
req = await request_json()
req = await get_request_json()
if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
try:
@ -705,3 +716,13 @@ async def set_meta():
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@manager.route("/upload_info", methods=["POST"]) # noqa: F821
async def upload_info():
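# Accepts an optional multipart "file" part and/or a "url" query parameter and
# returns the result of FileService.upload_info for them.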
files = await request.files
file = files['file'] if files and files.get("file") else None
try:
return get_json_result(data=FileService.upload_info(current_user.id, file, request.args.get("url")))
except Exception as e:
return server_error_response(e)

api/apps/evaluation_app.py (new file, 479 lines)

@ -0,0 +1,479 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
RAG Evaluation API Endpoints
Provides REST API for RAG evaluation functionality including:
- Dataset management
- Test case management
- Evaluation execution
- Results retrieval
- Configuration recommendations
"""
from quart import request
from api.apps import login_required, current_user
from api.db.services.evaluation_service import EvaluationService
from api.utils.api_utils import (
get_data_error_result,
get_json_result,
get_request_json,
server_error_response,
validate_request
)
from common.constants import RetCode
# ==================== Dataset Management ====================
@manager.route('/dataset/create', methods=['POST']) # noqa: F821
@login_required
@validate_request("name", "kb_ids")
async def create_dataset():
"""
Create a new evaluation dataset.
Request body:
{
"name": "Dataset name",
"description": "Optional description",
"kb_ids": ["kb_id1", "kb_id2"]
}
"""
try:
req = await get_request_json()
name = req.get("name", "").strip()
description = req.get("description", "")
kb_ids = req.get("kb_ids", [])
if not name:
return get_data_error_result(message="Dataset name cannot be empty")
if not kb_ids or not isinstance(kb_ids, list):
return get_data_error_result(message="kb_ids must be a non-empty list")
success, result = EvaluationService.create_dataset(
name=name,
description=description,
kb_ids=kb_ids,
tenant_id=current_user.id,
user_id=current_user.id
)
if not success:
return get_data_error_result(message=result)
return get_json_result(data={"dataset_id": result})
except Exception as e:
return server_error_response(e)
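# Example request (illustrative; the mount prefix for this blueprint is an
# assumption):
#   curl -X POST http://<host>/v1/evaluation/dataset/create \
#        -H "Content-Type: application/json" \
#        -d '{"name": "regression set", "kb_ids": ["<kb_id>"]}'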
@manager.route('/dataset/list', methods=['GET']) # noqa: F821
@login_required
async def list_datasets():
"""
List evaluation datasets for current tenant.
Query params:
- page: Page number (default: 1)
- page_size: Items per page (default: 20)
"""
try:
page = int(request.args.get("page", 1))
page_size = int(request.args.get("page_size", 20))
result = EvaluationService.list_datasets(
tenant_id=current_user.id,
user_id=current_user.id,
page=page,
page_size=page_size
)
return get_json_result(data=result)
except Exception as e:
return server_error_response(e)
@manager.route('/dataset/<dataset_id>', methods=['GET']) # noqa: F821
@login_required
async def get_dataset(dataset_id):
"""Get dataset details by ID"""
try:
dataset = EvaluationService.get_dataset(dataset_id)
if not dataset:
return get_data_error_result(
message="Dataset not found",
code=RetCode.DATA_ERROR
)
return get_json_result(data=dataset)
except Exception as e:
return server_error_response(e)
@manager.route('/dataset/<dataset_id>', methods=['PUT']) # noqa: F821
@login_required
async def update_dataset(dataset_id):
"""
Update dataset.
Request body:
{
"name": "New name",
"description": "New description",
"kb_ids": ["kb_id1", "kb_id2"]
}
"""
try:
req = await get_request_json()
# Remove fields that shouldn't be updated
req.pop("id", None)
req.pop("tenant_id", None)
req.pop("created_by", None)
req.pop("create_time", None)
success = EvaluationService.update_dataset(dataset_id, **req)
if not success:
return get_data_error_result(message="Failed to update dataset")
return get_json_result(data={"dataset_id": dataset_id})
except Exception as e:
return server_error_response(e)
@manager.route('/dataset/<dataset_id>', methods=['DELETE']) # noqa: F821
@login_required
async def delete_dataset(dataset_id):
"""Delete dataset (soft delete)"""
try:
success = EvaluationService.delete_dataset(dataset_id)
if not success:
return get_data_error_result(message="Failed to delete dataset")
return get_json_result(data={"dataset_id": dataset_id})
except Exception as e:
return server_error_response(e)
# ==================== Test Case Management ====================
@manager.route('/dataset/<dataset_id>/case/add', methods=['POST']) # noqa: F821
@login_required
@validate_request("question")
async def add_test_case(dataset_id):
"""
Add a test case to a dataset.
Request body:
{
"question": "Test question",
"reference_answer": "Optional ground truth answer",
"relevant_doc_ids": ["doc_id1", "doc_id2"],
"relevant_chunk_ids": ["chunk_id1", "chunk_id2"],
"metadata": {"key": "value"}
}
"""
try:
req = await get_request_json()
question = req.get("question", "").strip()
if not question:
return get_data_error_result(message="Question cannot be empty")
success, result = EvaluationService.add_test_case(
dataset_id=dataset_id,
question=question,
reference_answer=req.get("reference_answer"),
relevant_doc_ids=req.get("relevant_doc_ids"),
relevant_chunk_ids=req.get("relevant_chunk_ids"),
metadata=req.get("metadata")
)
if not success:
return get_data_error_result(message=result)
return get_json_result(data={"case_id": result})
except Exception as e:
return server_error_response(e)
@manager.route('/dataset/<dataset_id>/case/import', methods=['POST']) # noqa: F821
@login_required
@validate_request("cases")
async def import_test_cases(dataset_id):
"""
Bulk import test cases.
Request body:
{
"cases": [
{
"question": "Question 1",
"reference_answer": "Answer 1",
...
},
{
"question": "Question 2",
...
}
]
}
"""
try:
req = await get_request_json()
cases = req.get("cases", [])
if not cases or not isinstance(cases, list):
return get_data_error_result(message="cases must be a non-empty list")
success_count, failure_count = EvaluationService.import_test_cases(
dataset_id=dataset_id,
cases=cases
)
return get_json_result(data={
"success_count": success_count,
"failure_count": failure_count,
"total": len(cases)
})
except Exception as e:
return server_error_response(e)
@manager.route('/dataset/<dataset_id>/cases', methods=['GET']) # noqa: F821
@login_required
async def get_test_cases(dataset_id):
"""Get all test cases for a dataset"""
try:
cases = EvaluationService.get_test_cases(dataset_id)
return get_json_result(data={"cases": cases, "total": len(cases)})
except Exception as e:
return server_error_response(e)
@manager.route('/case/<case_id>', methods=['DELETE']) # noqa: F821
@login_required
async def delete_test_case(case_id):
"""Delete a test case"""
try:
success = EvaluationService.delete_test_case(case_id)
if not success:
return get_data_error_result(message="Failed to delete test case")
return get_json_result(data={"case_id": case_id})
except Exception as e:
return server_error_response(e)
# ==================== Evaluation Execution ====================
@manager.route('/run/start', methods=['POST']) # noqa: F821
@login_required
@validate_request("dataset_id", "dialog_id")
async def start_evaluation():
"""
Start an evaluation run.
Request body:
{
"dataset_id": "dataset_id",
"dialog_id": "dialog_id",
"name": "Optional run name"
}
"""
try:
req = await get_request_json()
dataset_id = req.get("dataset_id")
dialog_id = req.get("dialog_id")
name = req.get("name")
success, result = EvaluationService.start_evaluation(
dataset_id=dataset_id,
dialog_id=dialog_id,
user_id=current_user.id,
name=name
)
if not success:
return get_data_error_result(message=result)
return get_json_result(data={"run_id": result})
except Exception as e:
return server_error_response(e)
@manager.route('/run/<run_id>', methods=['GET']) # noqa: F821
@login_required
async def get_evaluation_run(run_id):
"""Get evaluation run details"""
try:
result = EvaluationService.get_run_results(run_id)
if not result:
return get_data_error_result(
message="Evaluation run not found",
code=RetCode.DATA_ERROR
)
return get_json_result(data=result)
except Exception as e:
return server_error_response(e)
@manager.route('/run/<run_id>/results', methods=['GET']) # noqa: F821
@login_required
async def get_run_results(run_id):
"""Get detailed results for an evaluation run"""
try:
result = EvaluationService.get_run_results(run_id)
if not result:
return get_data_error_result(
message="Evaluation run not found",
code=RetCode.DATA_ERROR
)
return get_json_result(data=result)
except Exception as e:
return server_error_response(e)
@manager.route('/run/list', methods=['GET']) # noqa: F821
@login_required
async def list_evaluation_runs():
"""
List evaluation runs.
Query params:
- dataset_id: Filter by dataset (optional)
- dialog_id: Filter by dialog (optional)
- page: Page number (default: 1)
- page_size: Items per page (default: 20)
"""
try:
# TODO: Implement list_runs in EvaluationService
return get_json_result(data={"runs": [], "total": 0})
except Exception as e:
return server_error_response(e)
@manager.route('/run/<run_id>', methods=['DELETE']) # noqa: F821
@login_required
async def delete_evaluation_run(run_id):
"""Delete an evaluation run"""
try:
# TODO: Implement delete_run in EvaluationService
return get_json_result(data={"run_id": run_id})
except Exception as e:
return server_error_response(e)
# ==================== Analysis & Recommendations ====================
@manager.route('/run/<run_id>/recommendations', methods=['GET']) # noqa: F821
@login_required
async def get_recommendations(run_id):
"""Get configuration recommendations based on evaluation results"""
try:
recommendations = EvaluationService.get_recommendations(run_id)
return get_json_result(data={"recommendations": recommendations})
except Exception as e:
return server_error_response(e)
@manager.route('/compare', methods=['POST']) # noqa: F821
@login_required
@validate_request("run_ids")
async def compare_runs():
"""
Compare multiple evaluation runs.
Request body:
{
"run_ids": ["run_id1", "run_id2", "run_id3"]
}
"""
try:
req = await get_request_json()
run_ids = req.get("run_ids", [])
if not run_ids or not isinstance(run_ids, list) or len(run_ids) < 2:
return get_data_error_result(
message="run_ids must be a list with at least 2 run IDs"
)
# TODO: Implement compare_runs in EvaluationService
return get_json_result(data={"comparison": {}})
except Exception as e:
return server_error_response(e)
@manager.route('/run/<run_id>/export', methods=['GET']) # noqa: F821
@login_required
async def export_results(run_id):
"""Export evaluation results as JSON/CSV"""
try:
# format_type = request.args.get("format", "json") # TODO: Use for CSV export
result = EvaluationService.get_run_results(run_id)
if not result:
return get_data_error_result(
message="Evaluation run not found",
code=RetCode.DATA_ERROR
)
# TODO: Implement CSV export
return get_json_result(data=result)
except Exception as e:
return server_error_response(e)
# ==================== Real-time Evaluation ====================
@manager.route('/evaluate_single', methods=['POST']) # noqa: F821
@login_required
@validate_request("question", "dialog_id")
async def evaluate_single():
"""
Evaluate a single question-answer pair in real-time.
Request body:
{
"question": "Test question",
"dialog_id": "dialog_id",
"reference_answer": "Optional ground truth",
"relevant_chunk_ids": ["chunk_id1", "chunk_id2"]
}
"""
try:
# req = await get_request_json() # TODO: Use for single evaluation implementation
# TODO: Implement single evaluation
# This would execute the RAG pipeline and return metrics immediately
return get_json_result(data={
"answer": "",
"metrics": {},
"retrieved_chunks": []
})
except Exception as e:
return server_error_response(e)


@ -19,22 +19,20 @@ from pathlib import Path
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from quart import request
from api.apps import login_required, current_user
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
from common.misc_utils import get_uuid
from common.constants import RetCode
from api.db import FileType
from api.db.services.document_service import DocumentService
from api.utils.api_utils import get_json_result
@manager.route('/convert', methods=['POST']) # noqa: F821
@login_required
@validate_request("file_ids", "kb_ids")
async def convert():
req = await request.json
req = await get_request_json()
kb_ids = req["kb_ids"]
file_ids = req["file_ids"]
file2documents = []
@ -79,7 +77,8 @@ async def convert():
doc = DocumentService.insert({
"id": get_uuid(),
"kb_id": kb.id,
"parser_id": FileService.get_parser(file.type, file.name, kb.parser_id),
"parser_id": kb.parser_id,
"pipeline_id": kb.pipeline_id,
"parser_config": kb.parser_config,
"created_by": current_user.id,
"type": file.type,
@ -104,7 +103,7 @@ async def convert():
@login_required
@validate_request("file_ids")
async def rm():
req = await request.json
req = await get_request_json()
file_ids = req["file_ids"]
if not file_ids:
return get_json_result(


@ -14,6 +14,7 @@
# limitations under the License
#
import logging
import asyncio
import os
import pathlib
import re
@ -29,7 +30,7 @@ from common.constants import RetCode, FileSource
from api.db import FileType
from api.db.services import duplicate_name
from api.db.services.file_service import FileService
from api.utils.api_utils import get_json_result
from api.utils.api_utils import get_json_result, get_request_json
from api.utils.file_utils import filename_type
from api.utils.web_utils import CONTENT_TYPE_MAP
from common import settings
@ -61,9 +62,10 @@ async def upload():
e, pf_folder = FileService.get_by_id(pf_id)
if not e:
return get_data_error_result( message="Can't find this folder!")
for file_obj in file_objs:
async def _handle_single_file(file_obj):
MAX_FILE_NUM_PER_USER: int = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
if 0 < MAX_FILE_NUM_PER_USER <= DocumentService.get_doc_count(current_user.id):
if 0 < MAX_FILE_NUM_PER_USER <= await asyncio.to_thread(DocumentService.get_doc_count, current_user.id):
return get_data_error_result( message="Exceed the maximum file number of a free user!")
# split file name path
@ -75,35 +77,36 @@ async def upload():
file_len = len(file_obj_names)
# get folder
file_id_list = FileService.get_id_list_by_id(pf_id, file_obj_names, 1, [pf_id])
file_id_list = await asyncio.to_thread(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id])
len_id_list = len(file_id_list)
# create folder
if file_len != len_id_list:
e, file = FileService.get_by_id(file_id_list[len_id_list - 1])
e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 1])
if not e:
return get_data_error_result(message="Folder not found!")
last_folder = FileService.create_folder(file, file_id_list[len_id_list - 1], file_obj_names,
last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names,
len_id_list)
else:
e, file = FileService.get_by_id(file_id_list[len_id_list - 2])
e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 2])
if not e:
return get_data_error_result(message="Folder not found!")
last_folder = FileService.create_folder(file, file_id_list[len_id_list - 2], file_obj_names,
last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names,
len_id_list)
# file type
filetype = filename_type(file_obj_names[file_len - 1])
location = file_obj_names[file_len - 1]
while settings.STORAGE_IMPL.obj_exist(last_folder.id, location):
while await asyncio.to_thread(settings.STORAGE_IMPL.obj_exist, last_folder.id, location):
location += "_"
blob = file_obj.read()
filename = duplicate_name(
blob = await asyncio.to_thread(file_obj.read)
filename = await asyncio.to_thread(
duplicate_name,
FileService.query,
name=file_obj_names[file_len - 1],
parent_id=last_folder.id)
settings.STORAGE_IMPL.put(last_folder.id, location, blob)
file = {
await asyncio.to_thread(settings.STORAGE_IMPL.put, last_folder.id, location, blob)
file_data = {
"id": get_uuid(),
"parent_id": last_folder.id,
"tenant_id": current_user.id,
@ -113,8 +116,13 @@ async def upload():
"location": location,
"size": len(blob),
}
file = FileService.insert(file)
file_res.append(file.to_json())
inserted = await asyncio.to_thread(FileService.insert, file_data)
return inserted.to_json()
for file_obj in file_objs:
res = await _handle_single_file(file_obj)
file_res.append(res)
return get_json_result(data=file_res)
except Exception as e:
return server_error_response(e)
@ -124,7 +132,7 @@ async def upload():
@login_required
@validate_request("name")
async def create():
req = await request.json
req = await get_request_json()
pf_id = req.get("parent_id")
input_file_type = req.get("type")
if not pf_id:
@ -239,58 +247,61 @@ def get_all_parent_folders():
@login_required
@validate_request("file_ids")
async def rm():
req = await request.json
req = await get_request_json()
file_ids = req["file_ids"]
def _delete_single_file(file):
try:
if file.location:
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
except Exception as e:
logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}, error: {e}")
informs = File2DocumentService.get_by_file_id(file.id)
for inform in informs:
doc_id = inform.document_id
e, doc = DocumentService.get_by_id(doc_id)
if e and doc:
tenant_id = DocumentService.get_tenant_id(doc_id)
if tenant_id:
DocumentService.remove_document(doc, tenant_id)
File2DocumentService.delete_by_file_id(file.id)
FileService.delete(file)
def _delete_folder_recursive(folder, tenant_id):
sub_files = FileService.list_all_files_by_parent_id(folder.id)
for sub_file in sub_files:
if sub_file.type == FileType.FOLDER.value:
_delete_folder_recursive(sub_file, tenant_id)
else:
_delete_single_file(sub_file)
FileService.delete(folder)
try:
for file_id in file_ids:
e, file = FileService.get_by_id(file_id)
if not e or not file:
return get_data_error_result(message="File or Folder not found!")
if not file.tenant_id:
return get_data_error_result(message="Tenant not found!")
if not check_file_team_permission(file, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
def _delete_single_file(file):
try:
if file.location:
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
except Exception as e:
logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}, error: {e}")
if file.source_type == FileSource.KNOWLEDGEBASE:
continue
informs = File2DocumentService.get_by_file_id(file.id)
for inform in informs:
doc_id = inform.document_id
e, doc = DocumentService.get_by_id(doc_id)
if e and doc:
tenant_id = DocumentService.get_tenant_id(doc_id)
if tenant_id:
DocumentService.remove_document(doc, tenant_id)
File2DocumentService.delete_by_file_id(file.id)
if file.type == FileType.FOLDER.value:
_delete_folder_recursive(file, current_user.id)
continue
FileService.delete(file)
_delete_single_file(file)
def _delete_folder_recursive(folder, tenant_id):
sub_files = FileService.list_all_files_by_parent_id(folder.id)
for sub_file in sub_files:
if sub_file.type == FileType.FOLDER.value:
_delete_folder_recursive(sub_file, tenant_id)
else:
_delete_single_file(sub_file)
return get_json_result(data=True)
FileService.delete(folder)
def _rm_sync():
for file_id in file_ids:
e, file = FileService.get_by_id(file_id)
if not e or not file:
return get_data_error_result(message="File or Folder not found!")
if not file.tenant_id:
return get_data_error_result(message="Tenant not found!")
if not check_file_team_permission(file, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
if file.source_type == FileSource.KNOWLEDGEBASE:
continue
if file.type == FileType.FOLDER.value:
_delete_folder_recursive(file, current_user.id)
continue
_delete_single_file(file)
return get_json_result(data=True)
return await asyncio.to_thread(_rm_sync)
except Exception as e:
return server_error_response(e)
@ -300,7 +311,7 @@ async def rm():
@login_required
@validate_request("file_id", "name")
async def rename():
req = await request.json
req = await get_request_json()
try:
e, file = FileService.get_by_id(req["file_id"])
if not e:
@ -346,10 +357,10 @@ async def get(file_id):
if not check_file_team_permission(file, current_user.id):
return get_json_result(data=False, message='No authorization.', code=RetCode.AUTHENTICATION_ERROR)
blob = settings.STORAGE_IMPL.get(file.parent_id, file.location)
blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, file.parent_id, file.location)
if not blob:
b, n = File2DocumentService.get_storage_address(file_id=file_id)
blob = settings.STORAGE_IMPL.get(b, n)
blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
response = await make_response(blob)
ext = re.search(r"\.([^.]+)$", file.name.lower())
@ -369,7 +380,7 @@ async def get(file_id):
@login_required
@validate_request("src_file_ids", "dest_file_id")
async def move():
req = await request.json
req = await get_request_json()
try:
file_ids = req["src_file_ids"]
dest_parent_id = req["dest_file_id"]
@ -444,10 +455,12 @@ async def move():
},
)
for file in files:
_move_entry_recursive(file, dest_folder)
def _move_sync():
for file in files:
_move_entry_recursive(file, dest_folder)
return get_json_result(data=True)
return get_json_result(data=True)
return await asyncio.to_thread(_move_sync)
except Exception as e:
return server_error_response(e)
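
The bulk of this file's diff wraps blocking service and storage calls (`FileService.*`, `DocumentService.*`, `settings.STORAGE_IMPL.*`) in `asyncio.to_thread` so the async upload/rm/mv handlers stop blocking the event loop. A generic sketch of the pattern, with placeholder names rather than the repository's own classes:

```python
import asyncio


# Placeholder names; the point is the pattern applied throughout this diff:
# run a synchronous client call in a worker thread and await its result.
async def store_blob(storage, bucket: str, key: str, blob: bytes) -> None:
    # storage.put is assumed to be a blocking call (MinIO/S3-style client).
    await asyncio.to_thread(storage.put, bucket, key, blob)


async def load_blob(storage, bucket: str, key: str) -> bytes:
    return await asyncio.to_thread(storage.get, bucket, key)
```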

View File

@ -17,6 +17,7 @@ import json
import logging
import random
import re
import asyncio
from quart import request
import numpy as np
@ -30,7 +31,7 @@ from api.db.services.pipeline_operation_log_service import PipelineOperationLogS
from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters, \
request_json
get_request_json
from api.db import VALID_FILE_TYPES
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.db_models import File
@ -48,7 +49,7 @@ from api.apps import login_required, current_user
@login_required
@validate_request("name")
async def create():
req = await request_json()
req = await get_request_json()
e, res = KnowledgebaseService.create_with_name(
name = req.pop("name", None),
tenant_id = current_user.id,
@ -72,7 +73,7 @@ async def create():
@validate_request("kb_id", "name", "description", "parser_id")
@not_allowed_parameters("id", "tenant_id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by")
async def update():
req = await request_json()
req = await get_request_json()
if not isinstance(req["name"], str):
return get_data_error_result(message="Dataset name must be string.")
if req["name"].strip() == "":
@ -116,12 +117,22 @@ async def update():
if kb.pagerank != req.get("pagerank", 0):
if req.get("pagerank", 0) > 0:
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
search.index_name(kb.tenant_id), kb.id)
await asyncio.to_thread(
settings.docStoreConn.update,
{"kb_id": kb.id},
{PAGERANK_FLD: req["pagerank"]},
search.index_name(kb.tenant_id),
kb.id,
)
else:
# Elasticsearch requires PAGERANK_FLD be non-zero!
settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
search.index_name(kb.tenant_id), kb.id)
await asyncio.to_thread(
settings.docStoreConn.update,
{"exists": PAGERANK_FLD},
{"remove": PAGERANK_FLD},
search.index_name(kb.tenant_id),
kb.id,
)
e, kb = KnowledgebaseService.get_by_id(kb.id)
if not e:
@ -182,7 +193,7 @@ async def list_kbs():
else:
desc = True
req = await request_json()
req = await get_request_json()
owner_ids = req.get("owner_ids", [])
try:
if not owner_ids:
@ -209,7 +220,7 @@ async def list_kbs():
@login_required
@validate_request("kb_id")
async def rm():
req = await request_json()
req = await get_request_json()
if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id):
return get_json_result(
data=False,
@ -224,25 +235,28 @@ async def rm():
data=False, message='Only owner of knowledgebase authorized for this operation.',
code=RetCode.OPERATING_ERROR)
for doc in DocumentService.query(kb_id=req["kb_id"]):
if not DocumentService.remove_document(doc, kbs[0].tenant_id):
def _rm_sync():
for doc in DocumentService.query(kb_id=req["kb_id"]):
if not DocumentService.remove_document(doc, kbs[0].tenant_id):
return get_data_error_result(
message="Database error (Document removal)!")
f2d = File2DocumentService.get_by_document_id(doc.id)
if f2d:
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
File2DocumentService.delete_by_document_id(doc.id)
FileService.filter_delete(
[File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kbs[0].name])
if not KnowledgebaseService.delete_by_id(req["kb_id"]):
return get_data_error_result(
message="Database error (Document removal)!")
f2d = File2DocumentService.get_by_document_id(doc.id)
if f2d:
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
File2DocumentService.delete_by_document_id(doc.id)
FileService.filter_delete(
[File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kbs[0].name])
if not KnowledgebaseService.delete_by_id(req["kb_id"]):
return get_data_error_result(
message="Database error (Knowledgebase removal)!")
for kb in kbs:
settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
settings.STORAGE_IMPL.remove_bucket(kb.id)
return get_json_result(data=True)
message="Database error (Knowledgebase removal)!")
for kb in kbs:
settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
settings.STORAGE_IMPL.remove_bucket(kb.id)
return get_json_result(data=True)
return await asyncio.to_thread(_rm_sync)
except Exception as e:
return server_error_response(e)
@ -286,7 +300,7 @@ def list_tags_from_kbs():
@manager.route('/<kb_id>/rm_tags', methods=['POST']) # noqa: F821
@login_required
async def rm_tags(kb_id):
req = await request_json()
req = await get_request_json()
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
@ -306,7 +320,7 @@ async def rm_tags(kb_id):
@manager.route('/<kb_id>/rename_tag', methods=['POST']) # noqa: F821
@login_required
async def rename_tags(kb_id):
req = await request_json()
req = await get_request_json()
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
@ -428,7 +442,7 @@ async def list_pipeline_logs():
if create_date_to > create_date_from:
return get_data_error_result(message="Create data filter is abnormal.")
req = await request_json()
req = await get_request_json()
operation_status = req.get("operation_status", [])
if operation_status:
@ -470,7 +484,7 @@ async def list_pipeline_dataset_logs():
if create_date_to > create_date_from:
return get_data_error_result(message="Create data filter is abnormal.")
req = await request_json()
req = await get_request_json()
operation_status = req.get("operation_status", [])
if operation_status:
@ -492,7 +506,7 @@ async def delete_pipeline_logs():
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
req = await request_json()
req = await get_request_json()
log_ids = req.get("log_ids", [])
PipelineOperationLogService.delete_by_ids(log_ids)
@ -517,7 +531,7 @@ def pipeline_log_detail():
@manager.route("/run_graphrag", methods=["POST"]) # noqa: F821
@login_required
async def run_graphrag():
req = await request_json()
req = await get_request_json()
kb_id = req.get("kb_id", "")
if not kb_id:
@ -586,7 +600,7 @@ def trace_graphrag():
@manager.route("/run_raptor", methods=["POST"]) # noqa: F821
@login_required
async def run_raptor():
req = await request_json()
req = await get_request_json()
kb_id = req.get("kb_id", "")
if not kb_id:
@ -655,7 +669,7 @@ def trace_raptor():
@manager.route("/run_mindmap", methods=["POST"]) # noqa: F821
@login_required
async def run_mindmap():
req = await request_json()
req = await get_request_json()
kb_id = req.get("kb_id", "")
if not kb_id:
@ -857,11 +871,11 @@ async def check_embedding():
"question_kwd": full_doc.get("question_kwd") or []
})
return out
def _clean(s: str) -> str:
s = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", s or "")
return s if s else "None"
req = await request_json()
req = await get_request_json()
kb_id = req.get("kb_id", "")
embd_id = req.get("embd_id", "")
n = int(req.get("check_num", 5))
@ -922,5 +936,3 @@ async def check_embedding():
if summary["avg_cos_sim"] > 0.9:
return get_json_result(data={"summary": summary, "results": results})
return get_json_result(code=RetCode.NOT_EFFECTIVE, message="Embedding model switch failed: the average similarity between old and new vectors is below 0.9, indicating incompatible vector spaces.", data={"summary": summary, "results": results})
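
For the knowledgebase removal (and, earlier in this file, the pagerank update), the diff keeps the multi-step blocking cleanup synchronous but moves it into a nested `_rm_sync()` executed on a worker thread, so a single `await asyncio.to_thread(...)` covers the whole sequence. A hedged sketch of that shape, with illustrative callables standing in for the real services:

```python
import asyncio


# Sketch of the refactor shape only; remove_document/delete_kb stand in for the
# DocumentService/KnowledgebaseService calls in the hunk above.
async def remove_knowledgebase(kb_id: str, docs: list, remove_document, delete_kb) -> dict:
    def _rm_sync() -> dict:
        for doc in docs:
            if not remove_document(doc):          # blocking DB + doc-store work
                return {"ok": False, "message": "Database error (Document removal)!"}
        if not delete_kb(kb_id):                  # blocking DB work
            return {"ok": False, "message": "Database error (Knowledgebase removal)!"}
        return {"ok": True}

    # One thread hop for the whole blocking sequence instead of one per call.
    return await asyncio.to_thread(_rm_sync)
```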

View File

@ -15,20 +15,19 @@
#
from quart import request
from api.apps import current_user, login_required
from langfuse import Langfuse
from api.db.db_models import DB
from api.db.services.langfuse_service import TenantLangfuseService
from api.utils.api_utils import get_error_data_result, get_json_result, server_error_response, validate_request
from api.utils.api_utils import get_error_data_result, get_json_result, get_request_json, server_error_response, validate_request
@manager.route("/api_key", methods=["POST", "PUT"]) # noqa: F821
@login_required
@validate_request("secret_key", "public_key", "host")
async def set_api_key():
req = await request.get_json()
req = await get_request_json()
secret_key = req.get("secret_key", "")
public_key = req.get("public_key", "")
host = req.get("host", "")

View File

@ -21,10 +21,9 @@ from quart import request
from api.apps import login_required, current_user
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
from api.db.services.llm_service import LLMService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils.api_utils import get_allowed_llm_factories, get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
from common.constants import StatusEnum, LLMType
from api.db.db_models import TenantLLM
from api.utils.api_utils import get_json_result, get_allowed_llm_factories
from rag.utils.base64_image import test_image
from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel
@ -54,7 +53,7 @@ def factories():
@login_required
@validate_request("llm_factory", "api_key")
async def set_api_key():
req = await request.json
req = await get_request_json()
# test if api key works
chat_passed, embd_passed, rerank_passed = False, False, False
factory = req["llm_factory"]
@ -124,7 +123,7 @@ async def set_api_key():
@login_required
@validate_request("llm_factory")
async def add_llm():
req = await request.json
req = await get_request_json()
factory = req["llm_factory"]
api_key = req.get("api_key", "x")
llm_name = req.get("llm_name")
@ -269,7 +268,7 @@ async def add_llm():
@login_required
@validate_request("llm_factory", "llm_name")
async def delete_llm():
req = await request.json
req = await get_request_json()
TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]])
return get_json_result(data=True)
@ -278,7 +277,7 @@ async def delete_llm():
@login_required
@validate_request("llm_factory", "llm_name")
async def enable_llm():
req = await request.json
req = await get_request_json()
TenantLLMService.filter_update(
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]], {"status": str(req.get("status", "1"))}
)
@ -289,7 +288,7 @@ async def enable_llm():
@login_required
@validate_request("llm_factory")
async def delete_factory():
req = await request.json
req = await get_request_json()
TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]])
return get_json_result(data=True)

View File

@ -22,8 +22,7 @@ from api.db.services.user_service import TenantService
from common.constants import RetCode, VALID_MCP_SERVER_TYPES
from common.misc_utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
get_mcp_tools
from api.utils.api_utils import get_data_error_result, get_json_result, get_mcp_tools, get_request_json, server_error_response, validate_request
from api.utils.web_utils import get_float, safe_json_parse
from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
@ -40,7 +39,7 @@ async def list_mcp() -> Response:
else:
desc = True
req = await request.get_json()
req = await get_request_json()
mcp_ids = req.get("mcp_ids", [])
try:
servers = MCPServerService.get_servers(current_user.id, mcp_ids, 0, 0, orderby, desc, keywords) or []
@ -73,7 +72,7 @@ def detail() -> Response:
@login_required
@validate_request("name", "url", "server_type")
async def create() -> Response:
req = await request.get_json()
req = await get_request_json()
server_type = req.get("server_type", "")
if server_type not in VALID_MCP_SERVER_TYPES:
@ -128,7 +127,7 @@ async def create() -> Response:
@login_required
@validate_request("mcp_id")
async def update() -> Response:
req = await request.get_json()
req = await get_request_json()
mcp_id = req.get("mcp_id", "")
e, mcp_server = MCPServerService.get_by_id(mcp_id)
@ -184,7 +183,7 @@ async def update() -> Response:
@login_required
@validate_request("mcp_ids")
async def rm() -> Response:
req = await request.get_json()
req = await get_request_json()
mcp_ids = req.get("mcp_ids", [])
try:
@ -202,7 +201,7 @@ async def rm() -> Response:
@login_required
@validate_request("mcpServers")
async def import_multiple() -> Response:
req = await request.get_json()
req = await get_request_json()
servers = req.get("mcpServers", {})
if not servers:
return get_data_error_result(message="No MCP servers provided.")
@ -269,7 +268,7 @@ async def import_multiple() -> Response:
@login_required
@validate_request("mcp_ids")
async def export_multiple() -> Response:
req = await request.get_json()
req = await get_request_json()
mcp_ids = req.get("mcp_ids", [])
if not mcp_ids:
@ -301,7 +300,7 @@ async def export_multiple() -> Response:
@login_required
@validate_request("mcp_ids")
async def list_tools() -> Response:
req = await request.get_json()
req = await get_request_json()
mcp_ids = req.get("mcp_ids", [])
if not mcp_ids:
return get_data_error_result(message="No MCP server IDs provided.")
@ -348,7 +347,7 @@ async def list_tools() -> Response:
@login_required
@validate_request("mcp_id", "tool_name", "arguments")
async def test_tool() -> Response:
req = await request.get_json()
req = await get_request_json()
mcp_id = req.get("mcp_id", "")
if not mcp_id:
return get_data_error_result(message="No MCP server ID provided.")
@ -381,7 +380,7 @@ async def test_tool() -> Response:
@login_required
@validate_request("mcp_id", "tools")
async def cache_tool() -> Response:
req = await request.get_json()
req = await get_request_json()
mcp_id = req.get("mcp_id", "")
if not mcp_id:
return get_data_error_result(message="No MCP server ID provided.")
@ -404,7 +403,7 @@ async def cache_tool() -> Response:
@manager.route("/test_mcp", methods=["POST"]) # noqa: F821
@validate_request("url", "server_type")
async def test_mcp() -> Response:
req = await request.get_json()
req = await get_request_json()
url = req.get("url", "")
if not url:

View File

@ -25,7 +25,7 @@ from api.db.services.canvas_service import UserCanvasService
from api.db.services.user_canvas_version import UserCanvasVersionService
from common.constants import RetCode
from common.misc_utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, token_required
from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, get_request_json, token_required
from api.utils.api_utils import get_result
from quart import request, Response
@ -53,7 +53,7 @@ def list_agents(tenant_id):
@manager.route("/agents", methods=["POST"]) # noqa: F821
@token_required
async def create_agent(tenant_id: str):
req: dict[str, Any] = cast(dict[str, Any], await request.json)
req: dict[str, Any] = cast(dict[str, Any], await get_request_json())
req["user_id"] = tenant_id
if req.get("dsl") is not None:
@ -90,7 +90,7 @@ async def create_agent(tenant_id: str):
@manager.route("/agents/<agent_id>", methods=["PUT"]) # noqa: F821
@token_required
async def update_agent(tenant_id: str, agent_id: str):
req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], (await request.json)).items() if v is not None}
req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], (await get_request_json())).items() if v is not None}
req["user_id"] = tenant_id
if req.get("dsl") is not None:
@ -136,7 +136,7 @@ def delete_agent(tenant_id: str, agent_id: str):
@manager.route('/webhook/<agent_id>', methods=['POST']) # noqa: F821
@token_required
async def webhook(tenant_id: str, agent_id: str):
req = await request.json
req = await get_request_json()
if not UserCanvasService.accessible(req["id"], tenant_id):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',

View File

@ -21,13 +21,13 @@ from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_service import TenantService
from common.misc_utils import get_uuid
from common.constants import RetCode, StatusEnum
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required, request_json
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required, get_request_json
@manager.route("/chats", methods=["POST"]) # noqa: F821
@token_required
async def create(tenant_id):
req = await request_json()
req = await get_request_json()
ids = [i for i in req.get("dataset_ids", []) if i]
for kb_id in ids:
kbs = KnowledgebaseService.accessible(kb_id=kb_id, user_id=tenant_id)
@ -146,7 +146,7 @@ async def create(tenant_id):
async def update(tenant_id, chat_id):
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
return get_error_data_result(message="You do not own the chat")
req = await request_json()
req = await get_request_json()
ids = req.get("dataset_ids", [])
if "show_quotation" in req:
req["do_refer"] = req.pop("show_quotation")
@ -229,7 +229,7 @@ async def update(tenant_id, chat_id):
async def delete_chats(tenant_id):
errors = []
success_count = 0
req = await request_json()
req = await get_request_json()
if not req:
ids = None
else:

View File

@ -15,12 +15,12 @@
#
import logging
from quart import request, jsonify
from quart import jsonify
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.utils.api_utils import validate_request, build_error_result, apikey_required
from api.utils.api_utils import apikey_required, build_error_result, get_request_json, validate_request
from rag.app.tag import label_question
from api.db.services.dialog_service import meta_filter, convert_conditions
from common.constants import RetCode, LLMType
@ -113,7 +113,7 @@ async def retrieval(tenant_id):
404:
description: Knowledge base or document not found
"""
req = await request.json
req = await get_request_json()
question = req["query"]
kb_id = req["knowledge_id"]
use_kg = req.get("use_kg", False)

View File

@ -33,10 +33,10 @@ from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.task_service import TaskService, queue_tasks
from api.db.services.task_service import TaskService, queue_tasks, cancel_all_task_of
from api.db.services.dialog_service import meta_filter, convert_conditions
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required, \
request_json
get_request_json
from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search
@ -231,7 +231,7 @@ async def update_doc(tenant_id, dataset_id, document_id):
schema:
type: object
"""
req = await request_json()
req = await get_request_json()
if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
return get_error_data_result(message="You don't own the dataset.")
e, kb = KnowledgebaseService.get_by_id(dataset_id)
@ -321,9 +321,7 @@ async def update_doc(tenant_id, dataset_id, document_id):
try:
if not DocumentService.update_by_id(doc.id, {"status": str(status)}):
return get_error_data_result(message="Database error (Document update)!")
settings.docStoreConn.update({"doc_id": doc.id}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
return get_result(data=True)
except Exception as e:
return server_error_response(e)
@ -350,12 +348,10 @@ async def update_doc(tenant_id, dataset_id, document_id):
}
renamed_doc = {}
for key, value in doc.to_dict().items():
if key == "run":
renamed_doc["run"] = run_mapping.get(str(value))
new_key = key_mapping.get(key, key)
renamed_doc[new_key] = value
if key == "run":
renamed_doc["run"] = run_mapping.get(value)
renamed_doc["run"] = run_mapping.get(str(value))
return get_result(data=renamed_doc)
@ -536,7 +532,7 @@ def list_docs(dataset_id, tenant_id):
return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
q = request.args
document_id = q.get("id")
document_id = q.get("id")
name = q.get("name")
if document_id and not DocumentService.query(id=document_id, kb_id=dataset_id):
@ -545,16 +541,16 @@ def list_docs(dataset_id, tenant_id):
return get_error_data_result(message=f"You don't own the document {name}.")
page = int(q.get("page", 1))
page_size = int(q.get("page_size", 30))
page_size = int(q.get("page_size", 30))
orderby = q.get("orderby", "create_time")
desc = str(q.get("desc", "true")).strip().lower() != "false"
keywords = q.get("keywords", "")
# filters - align with OpenAPI parameter names
suffix = q.getlist("suffix")
run_status = q.getlist("run")
create_time_from = int(q.get("create_time_from", 0))
create_time_to = int(q.get("create_time_to", 0))
suffix = q.getlist("suffix")
run_status = q.getlist("run")
create_time_from = int(q.get("create_time_from", 0))
create_time_to = int(q.get("create_time_to", 0))
# map run status (accept text or numeric) - align with API parameter
run_status_text_to_numeric = {"UNSTART": "0", "RUNNING": "1", "CANCEL": "2", "DONE": "3", "FAIL": "4"}
@ -575,7 +571,7 @@ def list_docs(dataset_id, tenant_id):
# rename keys + map run status back to text for output
key_mapping = {
"chunk_num": "chunk_count",
"kb_id": "dataset_id",
"kb_id": "dataset_id",
"token_num": "token_count",
"parser_id": "chunk_method",
}
@ -631,7 +627,7 @@ async def delete(tenant_id, dataset_id):
"""
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
req = await request_json()
req = await get_request_json()
if not req:
doc_ids = None
else:
@ -741,7 +737,7 @@ async def parse(tenant_id, dataset_id):
"""
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
req = await request_json()
req = await get_request_json()
if not req.get("document_ids"):
return get_error_data_result("`document_ids` is required")
doc_list = req.get("document_ids")
@ -824,7 +820,7 @@ async def stop_parsing(tenant_id, dataset_id):
"""
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
req = await request_json()
req = await get_request_json()
if not req.get("document_ids"):
return get_error_data_result("`document_ids` is required")
@ -839,6 +835,8 @@ async def stop_parsing(tenant_id, dataset_id):
return get_error_data_result(message=f"You don't own the document {id}.")
if int(doc[0].progress) == 1 or doc[0].progress == 0:
return get_error_data_result("Can't stop parsing document with progress at 0 or 1")
# Send cancellation signal via Redis to stop background task
cancel_all_task_of(id)
info = {"run": "2", "progress": 0, "chunk_num": 0}
DocumentService.update_by_id(id, info)
settings.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id)
@ -1096,7 +1094,7 @@ async def add_chunk(tenant_id, dataset_id, document_id):
if not doc:
return get_error_data_result(message=f"You don't own the document {document_id}.")
doc = doc[0]
req = await request_json()
req = await get_request_json()
if not str(req.get("content", "")).strip():
return get_error_data_result(message="`content` is required")
if "important_keywords" in req:
@ -1202,7 +1200,7 @@ async def rm_chunk(tenant_id, dataset_id, document_id):
docs = DocumentService.get_by_ids([document_id])
if not docs:
raise LookupError(f"Can't find the document with ID {document_id}!")
req = await request_json()
req = await get_request_json()
condition = {"doc_id": document_id}
if "chunk_ids" in req:
unique_chunk_ids, duplicate_messages = check_duplicate_ids(req["chunk_ids"], "chunk")
@ -1288,7 +1286,7 @@ async def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
if not doc:
return get_error_data_result(message=f"You don't own the document {document_id}.")
doc = doc[0]
req = await request_json()
req = await get_request_json()
if "content" in req and req["content"] is not None:
content = req["content"]
else:
@ -1411,7 +1409,7 @@ async def retrieval_test(tenant_id):
format: float
description: Similarity score.
"""
req = await request_json()
req = await get_request_json()
if not req.get("dataset_ids"):
return get_error_data_result("`dataset_ids` is required.")
kb_ids = req["dataset_ids"]
@ -1446,6 +1444,9 @@ async def retrieval_test(tenant_id):
metadata_condition = req.get("metadata_condition", {}) or {}
metas = DocumentService.get_meta_by_kbs(kb_ids)
doc_ids = meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))
# If metadata_condition has conditions but no docs match, return empty result
if not doc_ids and metadata_condition.get("conditions"):
return get_result(data={"total": 0, "chunks": [], "doc_aggs": {}})
if metadata_condition and not doc_ids:
doc_ids = ["-999"]
similarity_threshold = float(req.get("similarity_threshold", 0.2))
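
Two behavioural additions stand out in this file: `stop_parsing` now calls `cancel_all_task_of(id)` so the background task actually receives a cancellation signal before the document is reset, and `retrieval_test` short-circuits with an empty result when a `metadata_condition` has conditions but no documents match. A sketch of the cancellation ordering only; the service object and signatures here are assumptions, while the call names come from the diff:

```python
# Illustrative ordering, not the repository's code: signal the running task
# first, then persist the cancelled state and drop the indexed chunks.
def stop_parsing_document(doc, dataset_id: str, tenant_id: str, services) -> None:
    services.cancel_all_task_of(doc.id)   # push the cancel signal (via Redis) first
    services.update_document(doc.id, {"run": "2", "progress": 0, "chunk_num": 0})
    services.delete_chunks({"doc_id": doc.id}, tenant_id, dataset_id)
```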

View File

@ -14,7 +14,7 @@
# limitations under the License.
#
import asyncio
import pathlib
import re
from quart import request, make_response
@ -23,15 +23,15 @@ from pathlib import Path
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.api_utils import server_error_response, token_required
from api.utils.api_utils import get_json_result, get_request_json, server_error_response, token_required
from common.misc_utils import get_uuid
from api.db import FileType
from api.db.services import duplicate_name
from api.db.services.file_service import FileService
from api.utils.api_utils import get_json_result
from api.utils.file_utils import filename_type
from api.utils.web_utils import CONTENT_TYPE_MAP
from common import settings
from common.constants import RetCode
@manager.route('/file/upload', methods=['POST']) # noqa: F821
@token_required
@ -40,7 +40,7 @@ async def upload(tenant_id):
Upload a file to the system.
---
tags:
- File Management
- File
security:
- ApiKeyAuth: []
parameters:
@ -86,19 +86,19 @@ async def upload(tenant_id):
pf_id = root_folder["id"]
if 'file' not in files:
return get_json_result(data=False, message='No file part!', code=400)
return get_json_result(data=False, message='No file part!', code=RetCode.BAD_REQUEST)
file_objs = files.getlist('file')
for file_obj in file_objs:
if file_obj.filename == '':
return get_json_result(data=False, message='No selected file!', code=400)
return get_json_result(data=False, message='No selected file!', code=RetCode.BAD_REQUEST)
file_res = []
try:
e, pf_folder = FileService.get_by_id(pf_id)
if not e:
return get_json_result(data=False, message="Can't find this folder!", code=404)
return get_json_result(data=False, message="Can't find this folder!", code=RetCode.NOT_FOUND)
for file_obj in file_objs:
# Handle file path
@ -114,13 +114,13 @@ async def upload(tenant_id):
if file_len != len_id_list:
e, file = FileService.get_by_id(file_id_list[len_id_list - 1])
if not e:
return get_json_result(data=False, message="Folder not found!", code=404)
return get_json_result(data=False, message="Folder not found!", code=RetCode.NOT_FOUND)
last_folder = FileService.create_folder(file, file_id_list[len_id_list - 1], file_obj_names,
len_id_list)
else:
e, file = FileService.get_by_id(file_id_list[len_id_list - 2])
if not e:
return get_json_result(data=False, message="Folder not found!", code=404)
return get_json_result(data=False, message="Folder not found!", code=RetCode.NOT_FOUND)
last_folder = FileService.create_folder(file, file_id_list[len_id_list - 2], file_obj_names,
len_id_list)
@ -156,7 +156,7 @@ async def create(tenant_id):
Create a new file or folder.
---
tags:
- File Management
- File
security:
- ApiKeyAuth: []
parameters:
@ -193,16 +193,16 @@ async def create(tenant_id):
type:
type: string
"""
req = await request.json
pf_id = await request.json.get("parent_id")
input_file_type = await request.json.get("type")
req = await get_request_json()
pf_id = req.get("parent_id")
input_file_type = req.get("type")
if not pf_id:
root_folder = FileService.get_root_folder(tenant_id)
pf_id = root_folder["id"]
try:
if not FileService.is_parent_folder_exist(pf_id):
return get_json_result(data=False, message="Parent Folder Doesn't Exist!", code=400)
return get_json_result(data=False, message="Parent Folder Doesn't Exist!", code=RetCode.BAD_REQUEST)
if FileService.query(name=req["name"], parent_id=pf_id):
return get_json_result(data=False, message="Duplicated folder name in the same folder.", code=409)
@ -229,12 +229,12 @@ async def create(tenant_id):
@manager.route('/file/list', methods=['GET']) # noqa: F821
@token_required
def list_files(tenant_id):
async def list_files(tenant_id):
"""
List files under a specific folder.
---
tags:
- File Management
- File
security:
- ApiKeyAuth: []
parameters:
@ -306,13 +306,13 @@ def list_files(tenant_id):
try:
e, file = FileService.get_by_id(pf_id)
if not e:
return get_json_result(message="Folder not found!", code=404)
return get_json_result(message="Folder not found!", code=RetCode.NOT_FOUND)
files, total = FileService.get_by_pf_id(tenant_id, pf_id, page_number, items_per_page, orderby, desc, keywords)
parent_folder = FileService.get_parent_folder(pf_id)
if not parent_folder:
return get_json_result(message="File not found!", code=404)
return get_json_result(message="File not found!", code=RetCode.NOT_FOUND)
return get_json_result(data={"total": total, "files": files, "parent_folder": parent_folder.to_json()})
except Exception as e:
@ -321,12 +321,12 @@ def list_files(tenant_id):
@manager.route('/file/root_folder', methods=['GET']) # noqa: F821
@token_required
def get_root_folder(tenant_id):
async def get_root_folder(tenant_id):
"""
Get user's root folder.
---
tags:
- File Management
- File
security:
- ApiKeyAuth: []
responses:
@ -357,12 +357,12 @@ def get_root_folder(tenant_id):
@manager.route('/file/parent_folder', methods=['GET']) # noqa: F821
@token_required
def get_parent_folder():
async def get_parent_folder():
"""
Get parent folder info of a file.
---
tags:
- File Management
- File
security:
- ApiKeyAuth: []
parameters:
@ -392,7 +392,7 @@ def get_parent_folder():
try:
e, file = FileService.get_by_id(file_id)
if not e:
return get_json_result(message="Folder not found!", code=404)
return get_json_result(message="Folder not found!", code=RetCode.NOT_FOUND)
parent_folder = FileService.get_parent_folder(file_id)
return get_json_result(data={"parent_folder": parent_folder.to_json()})
@ -402,12 +402,12 @@ def get_parent_folder():
@manager.route('/file/all_parent_folder', methods=['GET']) # noqa: F821
@token_required
def get_all_parent_folders(tenant_id):
async def get_all_parent_folders(tenant_id):
"""
Get all parent folders of a file.
---
tags:
- File Management
- File
security:
- ApiKeyAuth: []
parameters:
@ -439,7 +439,7 @@ def get_all_parent_folders(tenant_id):
try:
e, file = FileService.get_by_id(file_id)
if not e:
return get_json_result(message="Folder not found!", code=404)
return get_json_result(message="Folder not found!", code=RetCode.NOT_FOUND)
parent_folders = FileService.get_all_parent_folders(file_id)
parent_folders_res = [folder.to_json() for folder in parent_folders]
@ -455,7 +455,7 @@ async def rm(tenant_id):
Delete one or multiple files/folders.
---
tags:
- File Management
- File
security:
- ApiKeyAuth: []
parameters:
@ -481,40 +481,40 @@ async def rm(tenant_id):
type: boolean
example: true
"""
req = await request.json
req = await get_request_json()
file_ids = req["file_ids"]
try:
for file_id in file_ids:
e, file = FileService.get_by_id(file_id)
if not e:
return get_json_result(message="File or Folder not found!", code=404)
return get_json_result(message="File or Folder not found!", code=RetCode.NOT_FOUND)
if not file.tenant_id:
return get_json_result(message="Tenant not found!", code=404)
return get_json_result(message="Tenant not found!", code=RetCode.NOT_FOUND)
if file.type == FileType.FOLDER.value:
file_id_list = FileService.get_all_innermost_file_ids(file_id, [])
for inner_file_id in file_id_list:
e, file = FileService.get_by_id(inner_file_id)
if not e:
return get_json_result(message="File not found!", code=404)
return get_json_result(message="File not found!", code=RetCode.NOT_FOUND)
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
FileService.delete_folder_by_pf_id(tenant_id, file_id)
else:
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
if not FileService.delete(file):
return get_json_result(message="Database error (File removal)!", code=500)
return get_json_result(message="Database error (File removal)!", code=RetCode.SERVER_ERROR)
informs = File2DocumentService.get_by_file_id(file_id)
for inform in informs:
doc_id = inform.document_id
e, doc = DocumentService.get_by_id(doc_id)
if not e:
return get_json_result(message="Document not found!", code=404)
return get_json_result(message="Document not found!", code=RetCode.NOT_FOUND)
tenant_id = DocumentService.get_tenant_id(doc_id)
if not tenant_id:
return get_json_result(message="Tenant not found!", code=404)
return get_json_result(message="Tenant not found!", code=RetCode.NOT_FOUND)
if not DocumentService.remove_document(doc, tenant_id):
return get_json_result(message="Database error (Document removal)!", code=500)
return get_json_result(message="Database error (Document removal)!", code=RetCode.SERVER_ERROR)
File2DocumentService.delete_by_file_id(file_id)
return get_json_result(data=True)
@ -529,7 +529,7 @@ async def rename(tenant_id):
Rename a file.
---
tags:
- File Management
- File
security:
- ApiKeyAuth: []
parameters:
@ -556,27 +556,27 @@ async def rename(tenant_id):
type: boolean
example: true
"""
req = await request.json
req = await get_request_json()
try:
e, file = FileService.get_by_id(req["file_id"])
if not e:
return get_json_result(message="File not found!", code=404)
return get_json_result(message="File not found!", code=RetCode.NOT_FOUND)
if file.type != FileType.FOLDER.value and pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
file.name.lower()).suffix:
return get_json_result(data=False, message="The extension of file can't be changed", code=400)
return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.BAD_REQUEST)
for existing_file in FileService.query(name=req["name"], pf_id=file.parent_id):
if existing_file.name == req["name"]:
return get_json_result(data=False, message="Duplicated file name in the same folder.", code=409)
if not FileService.update_by_id(req["file_id"], {"name": req["name"]}):
return get_json_result(message="Database error (File rename)!", code=500)
return get_json_result(message="Database error (File rename)!", code=RetCode.SERVER_ERROR)
informs = File2DocumentService.get_by_file_id(req["file_id"])
if informs:
if not DocumentService.update_by_id(informs[0].document_id, {"name": req["name"]}):
return get_json_result(message="Database error (Document rename)!", code=500)
return get_json_result(message="Database error (Document rename)!", code=RetCode.SERVER_ERROR)
return get_json_result(data=True)
except Exception as e:
@ -590,7 +590,7 @@ async def get(tenant_id, file_id):
Download a file.
---
tags:
- File Management
- File
security:
- ApiKeyAuth: []
produces:
@ -606,13 +606,13 @@ async def get(tenant_id, file_id):
description: File stream
schema:
type: file
404:
RetCode.NOT_FOUND:
description: File not found
"""
try:
e, file = FileService.get_by_id(file_id)
if not e:
return get_json_result(message="Document not found!", code=404)
return get_json_result(message="Document not found!", code=RetCode.NOT_FOUND)
blob = settings.STORAGE_IMPL.get(file.parent_id, file.location)
if not blob:
@ -630,6 +630,19 @@ async def get(tenant_id, file_id):
except Exception as e:
return server_error_response(e)
@manager.route("/file/download/<attachment_id>", methods=["GET"]) # noqa: F821
@token_required
async def download_attachment(tenant_id,attachment_id):
try:
ext = request.args.get("ext", "markdown")
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, tenant_id, attachment_id)
response = await make_response(data)
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
return response
except Exception as e:
return server_error_response(e)
@manager.route('/file/mv', methods=['POST']) # noqa: F821
@token_required
@ -638,7 +651,7 @@ async def move(tenant_id):
Move one or multiple files to another folder.
---
tags:
- File Management
- File
security:
- ApiKeyAuth: []
parameters:
@ -667,7 +680,7 @@ async def move(tenant_id):
type: boolean
example: true
"""
req = await request.json
req = await get_request_json()
try:
file_ids = req["src_file_ids"]
parent_id = req["dest_file_id"]
@ -677,13 +690,13 @@ async def move(tenant_id):
for file_id in file_ids:
file = files_dict[file_id]
if not file:
return get_json_result(message="File or Folder not found!", code=404)
return get_json_result(message="File or Folder not found!", code=RetCode.NOT_FOUND)
if not file.tenant_id:
return get_json_result(message="Tenant not found!", code=404)
return get_json_result(message="Tenant not found!", code=RetCode.NOT_FOUND)
fe, _ = FileService.get_by_id(parent_id)
if not fe:
return get_json_result(message="Parent Folder not found!", code=404)
return get_json_result(message="Parent Folder not found!", code=RetCode.NOT_FOUND)
FileService.move_file(file_ids, parent_id)
return get_json_result(data=True)
@ -694,7 +707,7 @@ async def move(tenant_id):
@manager.route('/file/convert', methods=['POST']) # noqa: F821
@token_required
async def convert(tenant_id):
req = await request.json
req = await get_request_json()
kb_ids = req["kb_ids"]
file_ids = req["file_ids"]
file2documents = []
@ -705,7 +718,7 @@ async def convert(tenant_id):
for file_id in file_ids:
file = files_set[file_id]
if not file:
return get_json_result(message="File not found!", code=404)
return get_json_result(message="File not found!", code=RetCode.NOT_FOUND)
file_ids_list = [file_id]
if file.type == FileType.FOLDER.value:
file_ids_list = FileService.get_all_innermost_file_ids(file_id, [])
@ -716,13 +729,13 @@ async def convert(tenant_id):
doc_id = inform.document_id
e, doc = DocumentService.get_by_id(doc_id)
if not e:
return get_json_result(message="Document not found!", code=404)
return get_json_result(message="Document not found!", code=RetCode.NOT_FOUND)
tenant_id = DocumentService.get_tenant_id(doc_id)
if not tenant_id:
return get_json_result(message="Tenant not found!", code=404)
return get_json_result(message="Tenant not found!", code=RetCode.NOT_FOUND)
if not DocumentService.remove_document(doc, tenant_id):
return get_json_result(
message="Database error (Document removal)!", code=404)
message="Database error (Document removal)!", code=RetCode.NOT_FOUND)
File2DocumentService.delete_by_file_id(id)
# insert
@ -730,11 +743,11 @@ async def convert(tenant_id):
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
return get_json_result(
message="Can't find this knowledgebase!", code=404)
message="Can't find this knowledgebase!", code=RetCode.NOT_FOUND)
e, file = FileService.get_by_id(id)
if not e:
return get_json_result(
message="Can't find this file!", code=404)
message="Can't find this file!", code=RetCode.NOT_FOUND)
doc = DocumentService.insert({
"id": get_uuid(),

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import json
import re
import time
@ -35,7 +36,7 @@ from api.db.services.search_service import SearchService
from api.db.services.user_service import UserTenantService
from common.misc_utils import get_uuid
from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, \
get_result, server_error_response, token_required, validate_request
get_result, get_request_json, server_error_response, token_required, validate_request
from rag.app.tag import label_question
from rag.prompts.template import load_prompt
from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format
@ -45,7 +46,7 @@ from common import settings
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
@token_required
async def create(tenant_id, chat_id):
req = await request.json
req = await get_request_json()
req["dialog_id"] = chat_id
dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
if not dia:
@ -73,7 +74,7 @@ async def create(tenant_id, chat_id):
@manager.route("/agents/<agent_id>/sessions", methods=["POST"]) # noqa: F821
@token_required
def create_agent_session(tenant_id, agent_id):
async def create_agent_session(tenant_id, agent_id):
user_id = request.args.get("user_id", tenant_id)
e, cvs = UserCanvasService.get_by_id(agent_id)
if not e:
@ -98,7 +99,7 @@ def create_agent_session(tenant_id, agent_id):
@manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PUT"]) # noqa: F821
@token_required
async def update(tenant_id, chat_id, session_id):
req = await request.json
req = await get_request_json()
req["dialog_id"] = chat_id
conv_id = session_id
conv = ConversationService.query(id=conv_id, dialog_id=chat_id)
@ -120,7 +121,7 @@ async def update(tenant_id, chat_id, session_id):
@manager.route("/chats/<chat_id>/completions", methods=["POST"]) # noqa: F821
@token_required
async def chat_completion(tenant_id, chat_id):
req = await request.json
req = await get_request_json()
if not req:
req = {"question": ""}
if not req.get("session_id"):
@ -206,7 +207,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
if reference:
print(completion.choices[0].message.reference)
"""
req = await request.get_json()
req = await get_request_json()
need_reference = bool(req.get("reference", False))
@ -384,7 +385,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
@validate_request("model", "messages") # noqa: F821
@token_required
async def agents_completion_openai_compatibility(tenant_id, agent_id):
req = await request.json
req = await get_request_json()
tiktokenenc = tiktoken.get_encoding("cl100k_base")
messages = req.get("messages", [])
if not messages:
@ -442,7 +443,7 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id):
@manager.route("/agents/<agent_id>/completions", methods=["POST"]) # noqa: F821
@token_required
async def agent_completions(tenant_id, agent_id):
req = await request.json
req = await get_request_json()
if req.get("stream", True):
@ -491,7 +492,7 @@ async def agent_completions(tenant_id, agent_id):
@manager.route("/chats/<chat_id>/sessions", methods=["GET"]) # noqa: F821
@token_required
def list_session(tenant_id, chat_id):
async def list_session(tenant_id, chat_id):
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
return get_error_data_result(message=f"You don't own the assistant {chat_id}.")
id = request.args.get("id")
@ -545,7 +546,7 @@ def list_session(tenant_id, chat_id):
@manager.route("/agents/<agent_id>/sessions", methods=["GET"]) # noqa: F821
@token_required
def list_agent_session(tenant_id, agent_id):
async def list_agent_session(tenant_id, agent_id):
if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
return get_error_data_result(message=f"You don't own the agent {agent_id}.")
id = request.args.get("id")
@ -614,7 +615,7 @@ async def delete(tenant_id, chat_id):
errors = []
success_count = 0
req = await request.json
req = await get_request_json()
convs = ConversationService.query(dialog_id=chat_id)
if not req:
ids = None
@ -662,7 +663,7 @@ async def delete(tenant_id, chat_id):
async def delete_agent_session(tenant_id, agent_id):
errors = []
success_count = 0
req = await request.json
req = await get_request_json()
cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id)
if not cvs:
return get_error_data_result(f"You don't own the agent {agent_id}")
@ -715,7 +716,7 @@ async def delete_agent_session(tenant_id, agent_id):
@manager.route("/sessions/ask", methods=["POST"]) # noqa: F821
@token_required
async def ask_about(tenant_id):
req = await request.json
req = await get_request_json()
if not req.get("question"):
return get_error_data_result("`question` is required.")
if not req.get("dataset_ids"):
@ -754,7 +755,7 @@ async def ask_about(tenant_id):
@manager.route("/sessions/related_questions", methods=["POST"]) # noqa: F821
@token_required
async def related_questions(tenant_id):
req = await request.json
req = await get_request_json()
if not req.get("question"):
return get_error_data_result("`question` is required.")
question = req["question"]
@ -787,7 +788,7 @@ Reason:
- At the same time, related terms can also help search engines better understand user needs and return more accurate search results.
"""
ans = chat_mdl.chat(
ans = await chat_mdl.async_chat(
prompt,
[
{
@ -805,7 +806,7 @@ Related search terms:
@manager.route("/chatbots/<dialog_id>/completions", methods=["POST"]) # noqa: F821
async def chatbot_completions(dialog_id):
req = await request.json
req = await get_request_json()
token = request.headers.get("Authorization").split()
if len(token) != 2:
@ -831,7 +832,7 @@ async def chatbot_completions(dialog_id):
@manager.route("/chatbots/<dialog_id>/info", methods=["GET"]) # noqa: F821
def chatbots_inputs(dialog_id):
async def chatbots_inputs(dialog_id):
token = request.headers.get("Authorization").split()
if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!"')
@ -855,7 +856,7 @@ def chatbots_inputs(dialog_id):
@manager.route("/agentbots/<agent_id>/completions", methods=["POST"]) # noqa: F821
async def agent_bot_completions(agent_id):
req = await request.json
req = await get_request_json()
token = request.headers.get("Authorization").split()
if len(token) != 2:
@ -878,7 +879,7 @@ async def agent_bot_completions(agent_id):
@manager.route("/agentbots/<agent_id>/inputs", methods=["GET"]) # noqa: F821
def begin_inputs(agent_id):
async def begin_inputs(agent_id):
token = request.headers.get("Authorization").split()
if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!"')
@ -908,7 +909,7 @@ async def ask_about_embedded():
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"')
req = await request.json
req = await get_request_json()
uid = objs[0].tenant_id
search_id = req.get("search_id", "")
@ -947,7 +948,7 @@ async def retrieval_test_embedded():
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"')
req = await request.json
req = await get_request_json()
page = int(req.get("page", 1))
size = int(req.get("size", 30))
question = req["question"]
@ -963,28 +964,30 @@ async def retrieval_test_embedded():
use_kg = req.get("use_kg", False)
top = int(req.get("top_k", 1024))
langs = req.get("cross_languages", [])
tenant_ids = []
tenant_id = objs[0].tenant_id
if not tenant_id:
return get_error_data_result(message="permission denined.")
if req.get("search_id", ""):
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
meta_data_filter = search_config.get("meta_data_filter", {})
metas = DocumentService.get_meta_by_kbs(kb_ids)
if meta_data_filter.get("method") == "auto":
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
filters: dict = gen_meta_filter(chat_mdl, metas, question)
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
if not doc_ids:
doc_ids = None
elif meta_data_filter.get("method") == "manual":
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
if meta_data_filter["manual"] and not doc_ids:
doc_ids = ["-999"]
def _retrieval_sync():
local_doc_ids = list(doc_ids) if doc_ids else []
tenant_ids = []
_question = question
if req.get("search_id", ""):
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
meta_data_filter = search_config.get("meta_data_filter", {})
metas = DocumentService.get_meta_by_kbs(kb_ids)
if meta_data_filter.get("method") == "auto":
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
filters: dict = gen_meta_filter(chat_mdl, metas, _question)
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
if not local_doc_ids:
local_doc_ids = None
elif meta_data_filter.get("method") == "manual":
local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
if meta_data_filter["manual"] and not local_doc_ids:
local_doc_ids = ["-999"]
try:
tenants = UserTenantService.query(user_id=tenant_id)
for kb_id in kb_ids:
for tenant in tenants:
@ -1000,7 +1003,7 @@ async def retrieval_test_embedded():
return get_error_data_result(message="Knowledgebase not found!")
if langs:
question = cross_languages(kb.tenant_id, None, question, langs)
_question = cross_languages(kb.tenant_id, None, _question, langs)
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
@ -1010,15 +1013,15 @@ async def retrieval_test_embedded():
if req.get("keyword", False):
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question)
_question += keyword_extraction(chat_mdl, _question)
labels = label_question(question, [kb])
labels = label_question(_question, [kb])
ranks = settings.retriever.retrieval(
question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
_question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
)
if use_kg:
ck = settings.kg_retriever.retrieval(question, tenant_ids, kb_ids, embd_mdl,
ck = settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl,
LLMBundle(kb.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]:
ranks["chunks"].insert(0, ck)
@ -1028,6 +1031,9 @@ async def retrieval_test_embedded():
ranks["labels"] = labels
return get_json_result(data=ranks)
try:
return await asyncio.to_thread(_retrieval_sync)
except Exception as e:
if str(e).find("not_found") > 0:
return get_json_result(data=False, message="No chunk found! Check the chunk status please!",
@ -1046,7 +1052,7 @@ async def related_questions_embedded():
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"')
req = await request.json
req = await get_request_json()
tenant_id = objs[0].tenant_id
if not tenant_id:
return get_error_data_result(message="permission denined.")
@ -1064,7 +1070,7 @@ async def related_questions_embedded():
gen_conf = search_config.get("llm_setting", {"temperature": 0.9})
prompt = load_prompt("related_question")
ans = chat_mdl.chat(
ans = await chat_mdl.async_chat(
prompt,
[
{
@ -1081,7 +1087,7 @@ Related search terms:
@manager.route("/searchbots/detail", methods=["GET"]) # noqa: F821
def detail_share_embedded():
async def detail_share_embedded():
token = request.headers.get("Authorization").split()
if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!')
@ -1123,7 +1129,7 @@ async def mindmap():
return get_error_data_result(message='Authentication error: API key is invalid!')
tenant_id = objs[0].tenant_id
req = await request.json
req = await get_request_json()
search_id = req.get("search_id", "")
search_app = SearchService.get_detail(search_id) if search_id else {}

View File

@ -24,14 +24,14 @@ from api.db.services.search_service import SearchService
from api.db.services.user_service import TenantService, UserTenantService
from common.misc_utils import get_uuid
from common.constants import RetCode, StatusEnum
from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, server_error_response, validate_request
from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, get_request_json, server_error_response, validate_request
@manager.route("/create", methods=["post"]) # noqa: F821
@login_required
@validate_request("name")
async def create():
req = await request.get_json()
req = await get_request_json()
search_name = req["name"]
description = req.get("description", "")
if not isinstance(search_name, str):
@ -66,7 +66,7 @@ async def create():
@validate_request("search_id", "name", "search_config", "tenant_id")
@not_allowed_parameters("id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by")
async def update():
req = await request.get_json()
req = await get_request_json()
if not isinstance(req["name"], str):
return get_data_error_result(message="Search name must be string.")
if req["name"].strip() == "":
@ -150,7 +150,7 @@ async def list_search_app():
else:
desc = True
req = await request.get_json()
req = await get_request_json()
owner_ids = req.get("owner_ids", [])
try:
if not owner_ids:
@ -174,7 +174,7 @@ async def list_search_app():
@login_required
@validate_request("search_id")
async def rm():
req = await request.get_json()
req = await get_request_json()
search_id = req["search_id"]
if not SearchService.accessible4deletion(search_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)

View File

@ -14,7 +14,6 @@
# limitations under the License.
#
from quart import request
from api.db import UserTenantRole
from api.db.db_models import UserTenant
from api.db.services.user_service import UserTenantService, UserService
@ -22,7 +21,7 @@ from api.db.services.user_service import UserTenantService, UserService
from common.constants import RetCode, StatusEnum
from common.misc_utils import get_uuid
from common.time_utils import delta_seconds
from api.utils.api_utils import get_json_result, validate_request, server_error_response, get_data_error_result
from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
from api.utils.web_utils import send_invite_email
from common import settings
from api.apps import smtp_mail_server, login_required, current_user
@ -56,7 +55,7 @@ async def create(tenant_id):
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR)
req = await request.json
req = await get_request_json()
invite_user_email = req["email"]
invite_users = UserService.query(email=invite_user_email)
if not invite_users:

View File

@ -39,6 +39,7 @@ from common.connection_utils import construct_response
from api.utils.api_utils import (
get_data_error_result,
get_json_result,
get_request_json,
server_error_response,
validate_request,
)
@ -57,6 +58,7 @@ from api.utils.web_utils import (
captcha_key,
)
from common import settings
from common.http_client import async_request
@manager.route("/login", methods=["POST", "GET"]) # noqa: F821
@ -90,7 +92,7 @@ async def login():
schema:
type: object
"""
json_body = await request.json
json_body = await get_request_json()
if not json_body:
return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!")
@ -121,8 +123,8 @@ async def login():
response_data = user.to_json()
user.access_token = get_uuid()
login_user(user)
user.update_time = (current_timestamp(),)
user.update_date = (datetime_format(datetime.now()),)
user.update_time = current_timestamp()
user.update_date = datetime_format(datetime.now())
user.save()
msg = "Welcome back!"
@ -136,7 +138,7 @@ async def login():
@manager.route("/login/channels", methods=["GET"]) # noqa: F821
def get_login_channels():
async def get_login_channels():
"""
Get all supported authentication channels.
"""
@ -157,7 +159,7 @@ def get_login_channels():
@manager.route("/login/<channel>", methods=["GET"]) # noqa: F821
def oauth_login(channel):
async def oauth_login(channel):
channel_config = settings.OAUTH_CONFIG.get(channel)
if not channel_config:
raise ValueError(f"Invalid channel name: {channel}")
@ -170,7 +172,7 @@ def oauth_login(channel):
@manager.route("/oauth/callback/<channel>", methods=["GET"]) # noqa: F821
def oauth_callback(channel):
async def oauth_callback(channel):
"""
Handle the OAuth/OIDC callback for various channels dynamically.
"""
@ -192,7 +194,10 @@ def oauth_callback(channel):
return redirect("/?error=missing_code")
# Exchange authorization code for access token
token_info = auth_cli.exchange_code_for_token(code)
if hasattr(auth_cli, "async_exchange_code_for_token"):
token_info = await auth_cli.async_exchange_code_for_token(code)
else:
token_info = auth_cli.exchange_code_for_token(code)
access_token = token_info.get("access_token")
if not access_token:
return redirect("/?error=token_failed")
@ -200,7 +205,10 @@ def oauth_callback(channel):
id_token = token_info.get("id_token")
# Fetch user info
user_info = auth_cli.fetch_user_info(access_token, id_token=id_token)
if hasattr(auth_cli, "async_fetch_user_info"):
user_info = await auth_cli.async_fetch_user_info(access_token, id_token=id_token)
else:
user_info = auth_cli.fetch_user_info(access_token, id_token=id_token)
if not user_info.email:
return redirect("/?error=email_missing")
@ -259,7 +267,7 @@ def oauth_callback(channel):
@manager.route("/github_callback", methods=["GET"]) # noqa: F821
def github_callback():
async def github_callback():
"""
**Deprecated**, Use `/oauth/callback/<channel>` instead.
@ -279,9 +287,8 @@ def github_callback():
schema:
type: object
"""
import requests
res = requests.post(
res = await async_request(
"POST",
settings.GITHUB_OAUTH.get("url"),
data={
"client_id": settings.GITHUB_OAUTH.get("client_id"),
@ -299,7 +306,7 @@ def github_callback():
session["access_token"] = res["access_token"]
session["access_token_from"] = "github"
user_info = user_info_from_github(session["access_token"])
user_info = await user_info_from_github(session["access_token"])
email_address = user_info["email"]
users = UserService.query(email=email_address)
user_id = get_uuid()
@ -348,7 +355,7 @@ def github_callback():
@manager.route("/feishu_callback", methods=["GET"]) # noqa: F821
def feishu_callback():
async def feishu_callback():
"""
Feishu OAuth callback endpoint.
---
@ -366,9 +373,8 @@ def feishu_callback():
schema:
type: object
"""
import requests
app_access_token_res = requests.post(
app_access_token_res = await async_request(
"POST",
settings.FEISHU_OAUTH.get("app_access_token_url"),
data=json.dumps(
{
@ -382,7 +388,8 @@ def feishu_callback():
if app_access_token_res["code"] != 0:
return redirect("/?error=%s" % app_access_token_res)
res = requests.post(
res = await async_request(
"POST",
settings.FEISHU_OAUTH.get("user_access_token_url"),
data=json.dumps(
{
@ -403,7 +410,7 @@ def feishu_callback():
return redirect("/?error=contact:user.email:readonly not in scope")
session["access_token"] = res["data"]["access_token"]
session["access_token_from"] = "feishu"
user_info = user_info_from_feishu(session["access_token"])
user_info = await user_info_from_feishu(session["access_token"])
email_address = user_info["email"]
users = UserService.query(email=email_address)
user_id = get_uuid()
@ -451,36 +458,34 @@ def feishu_callback():
return redirect("/?auth=%s" % user.get_id())
def user_info_from_feishu(access_token):
import requests
async def user_info_from_feishu(access_token):
headers = {
"Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {access_token}",
}
res = requests.get("https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers)
res = await async_request("GET", "https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers)
user_info = res.json()["data"]
user_info["email"] = None if user_info.get("email") == "" else user_info["email"]
return user_info
def user_info_from_github(access_token):
import requests
async def user_info_from_github(access_token):
headers = {"Accept": "application/json", "Authorization": f"token {access_token}"}
res = requests.get(f"https://api.github.com/user?access_token={access_token}", headers=headers)
res = await async_request("GET", f"https://api.github.com/user?access_token={access_token}", headers=headers)
user_info = res.json()
email_info = requests.get(
email_info_response = await async_request(
"GET",
f"https://api.github.com/user/emails?access_token={access_token}",
headers=headers,
).json()
)
email_info = email_info_response.json()
user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
return user_info
@manager.route("/logout", methods=["GET"]) # noqa: F821
@login_required
def log_out():
async def log_out():
"""
User logout endpoint.
---
@ -531,7 +536,7 @@ async def setting_user():
type: object
"""
update_dict = {}
request_data = await request.json
request_data = await get_request_json()
if request_data.get("password"):
new_password = request_data.get("new_password")
if not check_password_hash(current_user.password, decrypt(request_data["password"])):
@ -570,7 +575,7 @@ async def setting_user():
@manager.route("/info", methods=["GET"]) # noqa: F821
@login_required
def user_profile():
async def user_profile():
"""
Get user profile information.
---
@ -698,7 +703,7 @@ async def user_add():
code=RetCode.OPERATING_ERROR,
)
req = await request.json
req = await get_request_json()
email_address = req["email"]
# Validate the email address
@ -755,7 +760,7 @@ async def user_add():
@manager.route("/tenant_info", methods=["GET"]) # noqa: F821
@login_required
def tenant_info():
async def tenant_info():
"""
Get tenant information.
---
@ -831,14 +836,14 @@ async def set_tenant_info():
schema:
type: object
"""
req = await request.json
req = await get_request_json()
try:
tid = req.pop("tenant_id")
TenantService.update_by_id(tid, req)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@manager.route("/forget/captcha", methods=["GET"]) # noqa: F821
async def forget_get_captcha():
@ -875,7 +880,7 @@ async def forget_send_otp():
- Verify the image captcha stored at captcha:{email} (case-insensitive).
- On success, generate an email OTP (AZ with length = OTP_LENGTH), store hash + salt (and timestamp) in Redis with TTL, reset attempts and cooldown, and send the OTP via email.
"""
req = await request.get_json()
req = await get_request_json()
email = req.get("email") or ""
captcha = (req.get("captcha") or "").strip()
@ -931,7 +936,7 @@ async def forget_send_otp():
)
except Exception:
return get_json_result(data=False, code=RetCode.SERVER_ERROR, message="failed to send email")
return get_json_result(data=True, code=RetCode.SUCCESS, message="verification passed, email sent")
@ -941,7 +946,7 @@ async def forget():
POST: Verify email + OTP and reset password, then log the user in.
Request JSON: { email, otp, new_password, confirm_new_password }
"""
req = await request.get_json()
req = await get_request_json()
email = req.get("email") or ""
otp = (req.get("otp") or "").strip()
new_pwd = req.get("new_password")
@ -1002,8 +1007,8 @@ async def forget():
# Auto login (reuse login flow)
user.access_token = get_uuid()
login_user(user)
user.update_time = (current_timestamp(),)
user.update_date = (datetime_format(datetime.now()),)
user.update_time = current_timestamp()
user.update_date = datetime_format(datetime.now())
user.save()
msg = "Password reset successful. Logged in."
return construct_response(data=user.to_json(), auth=user.get_id(), message=msg)
return await construct_response(data=user.to_json(), auth=user.get_id(), message=msg)

View File

@ -749,7 +749,7 @@ class Knowledgebase(DataBaseModel):
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value, index=True)
pipeline_id = CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]], "table_context_size": 0, "image_context_size": 0})
pagerank = IntegerField(default=0, index=False)
graphrag_task_id = CharField(max_length=32, null=True, help_text="Graph RAG task ID", index=True)
@ -774,7 +774,7 @@ class Document(DataBaseModel):
kb_id = CharField(max_length=256, null=False, index=True)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", index=True)
pipeline_id = CharField(max_length=32, null=True, help_text="pipeline ID", index=True)
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]], "table_context_size": 0, "image_context_size": 0})
source_type = CharField(max_length=128, null=False, default="local", help_text="where does this document come from", index=True)
type = CharField(max_length=32, null=False, help_text="file extension", index=True)
created_by = CharField(max_length=32, null=False, help_text="who created it", index=True)
@ -1113,6 +1113,70 @@ class SyncLogs(DataBaseModel):
db_table = "sync_logs"
class EvaluationDataset(DataBaseModel):
"""Ground truth dataset for RAG evaluation"""
id = CharField(max_length=32, primary_key=True)
tenant_id = CharField(max_length=32, null=False, index=True, help_text="tenant ID")
name = CharField(max_length=255, null=False, index=True, help_text="dataset name")
description = TextField(null=True, help_text="dataset description")
kb_ids = JSONField(null=False, help_text="knowledge base IDs to evaluate against")
created_by = CharField(max_length=32, null=False, index=True, help_text="creator user ID")
create_time = BigIntegerField(null=False, index=True, help_text="creation timestamp")
update_time = BigIntegerField(null=False, help_text="last update timestamp")
status = IntegerField(null=False, default=1, help_text="1=valid, 0=invalid")
class Meta:
db_table = "evaluation_datasets"
class EvaluationCase(DataBaseModel):
"""Individual test case in an evaluation dataset"""
id = CharField(max_length=32, primary_key=True)
dataset_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_datasets")
question = TextField(null=False, help_text="test question")
reference_answer = TextField(null=True, help_text="optional ground truth answer")
relevant_doc_ids = JSONField(null=True, help_text="expected relevant document IDs")
relevant_chunk_ids = JSONField(null=True, help_text="expected relevant chunk IDs")
metadata = JSONField(null=True, help_text="additional context/tags")
create_time = BigIntegerField(null=False, help_text="creation timestamp")
class Meta:
db_table = "evaluation_cases"
class EvaluationRun(DataBaseModel):
"""A single evaluation run"""
id = CharField(max_length=32, primary_key=True)
dataset_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_datasets")
dialog_id = CharField(max_length=32, null=False, index=True, help_text="dialog configuration being evaluated")
name = CharField(max_length=255, null=False, help_text="run name")
config_snapshot = JSONField(null=False, help_text="dialog config at time of evaluation")
metrics_summary = JSONField(null=True, help_text="aggregated metrics")
status = CharField(max_length=32, null=False, default="PENDING", help_text="PENDING/RUNNING/COMPLETED/FAILED")
created_by = CharField(max_length=32, null=False, index=True, help_text="user who started the run")
create_time = BigIntegerField(null=False, index=True, help_text="creation timestamp")
complete_time = BigIntegerField(null=True, help_text="completion timestamp")
class Meta:
db_table = "evaluation_runs"
class EvaluationResult(DataBaseModel):
"""Result for a single test case in an evaluation run"""
id = CharField(max_length=32, primary_key=True)
run_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_runs")
case_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_cases")
generated_answer = TextField(null=False, help_text="generated answer")
retrieved_chunks = JSONField(null=False, help_text="chunks that were retrieved")
metrics = JSONField(null=False, help_text="all computed metrics")
execution_time = FloatField(null=False, help_text="response time in seconds")
token_usage = JSONField(null=True, help_text="prompt/completion tokens")
create_time = BigIntegerField(null=False, help_text="creation timestamp")
class Meta:
db_table = "evaluation_results"
def migrate_db():
logging.disable(logging.ERROR)
migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
@ -1293,4 +1357,43 @@ def migrate_db():
migrate(migrator.add_column("llm_factories", "rank", IntegerField(default=0, index=False)))
except Exception:
pass
# RAG Evaluation tables
try:
migrate(migrator.add_column("evaluation_datasets", "id", CharField(max_length=32, primary_key=True)))
except Exception:
pass
try:
migrate(migrator.add_column("evaluation_datasets", "tenant_id", CharField(max_length=32, null=False, index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("evaluation_datasets", "name", CharField(max_length=255, null=False, index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("evaluation_datasets", "description", TextField(null=True)))
except Exception:
pass
try:
migrate(migrator.add_column("evaluation_datasets", "kb_ids", JSONField(null=False)))
except Exception:
pass
try:
migrate(migrator.add_column("evaluation_datasets", "created_by", CharField(max_length=32, null=False, index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("evaluation_datasets", "create_time", BigIntegerField(null=False, index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("evaluation_datasets", "update_time", BigIntegerField(null=False)))
except Exception:
pass
try:
migrate(migrator.add_column("evaluation_datasets", "status", IntegerField(null=False, default=1)))
except Exception:
pass
logging.disable(logging.NOTSET)

View File

@ -25,6 +25,7 @@ import trio
from langfuse import Langfuse
from peewee import fn
from agentic_reasoning import DeepResearcher
from api.db.services.file_service import FileService
from common.constants import LLMType, ParserType, StatusEnum
from api.db.db_models import DB, Dialog
from api.db.services.common_service import CommonService
@ -178,6 +179,9 @@ class DialogService(CommonService):
return res
def chat_solo(dialog, messages, stream=True):
attachments = ""
if "files" in messages[-1]:
attachments = "\n\n".join(FileService.get_files(messages[-1]["files"]))
if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
else:
@ -188,6 +192,8 @@ def chat_solo(dialog, messages, stream=True):
if prompt_config.get("tts"):
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"]
if attachments and msg:
msg[-1]["content"] += attachments
if stream:
last_ans = ""
delta_ans = ""
@ -380,8 +386,11 @@ def chat(dialog, messages, stream=True, **kwargs):
retriever = settings.retriever
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else []
attachments_= ""
if "doc_ids" in messages[-1]:
attachments = messages[-1]["doc_ids"]
if "files" in messages[-1]:
attachments_ = "\n\n".join(FileService.get_files(messages[-1]["files"]))
prompt_config = dialog.prompt_config
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
@ -451,7 +460,7 @@ def chat(dialog, messages, stream=True, **kwargs):
),
)
for think in reasoner.thinking(kbinfos, " ".join(questions)):
for think in reasoner.thinking(kbinfos, attachments_ + " ".join(questions)):
if isinstance(think, str):
thought = think
knowledges = [t for t in think.split("\n") if t]
@ -478,6 +487,7 @@ def chat(dialog, messages, stream=True, **kwargs):
cks = retriever.retrieval_by_toc(" ".join(questions), kbinfos["chunks"], tenant_ids, chat_mdl, dialog.top_n)
if cks:
kbinfos["chunks"] = cks
kbinfos["chunks"] = retriever.retrieval_by_children(kbinfos["chunks"], tenant_ids)
if prompt_config.get("tavily_api_key"):
tav = Tavily(prompt_config["tavily_api_key"])
tav_res = tav.retrieve_chunks(" ".join(questions))
@ -503,7 +513,7 @@ def chat(dialog, messages, stream=True, **kwargs):
kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
gen_conf = dialog.llm_setting
msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)+attachments_}]
prompt4citation = ""
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
prompt4citation = citation_prompt()
@ -672,7 +682,11 @@ Please write the SQL, only SQL, without any other explanations or text.
if kb_ids:
kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"
if "where" not in sql.lower():
sql += f" WHERE {kb_filter}"
o = sql.lower().split("order by")
if len(o) > 1:
sql = o[0] + f" WHERE {kb_filter} order by " + o[1]
else:
sql += f" WHERE {kb_filter}"
else:
sql += f" AND {kb_filter}"
@ -680,10 +694,9 @@ Please write the SQL, only SQL, without any other explanations or text.
tried_times += 1
return settings.retriever.sql_retrieval(sql, format="json"), sql
tbl, sql = get_table()
if tbl is None:
return None
if tbl.get("error") and tried_times <= 2:
try:
tbl, sql = get_table()
except Exception as e:
user_prompt = """
Table name: {};
Table of database fields are as follows:
@ -697,16 +710,14 @@ Please write the SQL, only SQL, without any other explanations or text.
The SQL error you provided last time is as follows:
{}
Error issued by database as follows:
{}
Please correct the error and write SQL again, only SQL, without any other explanations or text.
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, sql, tbl["error"])
tbl, sql = get_table()
logging.debug("TRY it again: {}".format(sql))
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, e)
try:
tbl, sql = get_table()
except Exception:
return
logging.debug("GET table: {}".format(tbl))
if tbl.get("error") or len(tbl["rows"]) == 0:
if len(tbl["rows"]) == 0:
return None
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
@ -750,13 +761,48 @@ Please write the SQL, only SQL, without any other explanations or text.
"prompt": sys_prompt,
}
def clean_tts_text(text: str) -> str:
if not text:
return ""
text = text.encode("utf-8", "ignore").decode("utf-8", "ignore")
text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text)
emoji_pattern = re.compile(
"[\U0001F600-\U0001F64F"
"\U0001F300-\U0001F5FF"
"\U0001F680-\U0001F6FF"
"\U0001F1E0-\U0001F1FF"
"\U00002700-\U000027BF"
"\U0001F900-\U0001F9FF"
"\U0001FA70-\U0001FAFF"
"\U0001FAD0-\U0001FAFF]+",
flags=re.UNICODE
)
text = emoji_pattern.sub("", text)
text = re.sub(r"\s+", " ", text).strip()
MAX_LEN = 500
if len(text) > MAX_LEN:
text = text[:MAX_LEN]
return text
def tts(tts_mdl, text):
if not tts_mdl or not text:
return None
text = clean_tts_text(text)
if not text:
return None
bin = b""
for chunk in tts_mdl.tts(text):
bin += chunk
try:
for chunk in tts_mdl.tts(text):
bin += chunk
except Exception as e:
logging.error(f"TTS failed: {e}, text={text!r}")
return None
return binascii.hexlify(bin).decode("utf-8")
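A hedged usage sketch for clean_tts_text; the import path follows the dialog service module shown here, and the sample text is arbitrary:

    from api.db.services.dialog_service import clean_tts_text

    text = "Hello 👋  world!\x00  " + "x" * 600
    cleaned = clean_tts_text(text)
    assert "\x00" not in cleaned and "👋" not in cleaned   # control chars and emojis stripped
    assert len(cleaned) <= 500                              # MAX_LEN cap from the function above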

View File

@ -719,10 +719,14 @@ class DocumentService(CommonService):
# only for special task and parsed docs and unfinished
freeze_progress = special_task_running and doc_progress >= 1 and not finished
msg = "\n".join(sorted(msg))
begin_at = d.get("process_begin_at")
if not begin_at:
begin_at = datetime.now()
# fallback
cls.update_by_id(d["id"], {"process_begin_at": begin_at})
info = {
"process_duration": datetime.timestamp(
datetime.now()) -
d["process_begin_at"].timestamp(),
"process_duration": max(datetime.timestamp(datetime.now()) - begin_at.timestamp(), 0),
"run": status}
if prg != 0 and not freeze_progress:
info["progress"] = prg
@ -923,7 +927,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
ParserType.AUDIO.value: audio,
ParserType.EMAIL.value: email
}
parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"}
parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text", "table_context_size": 0, "image_context_size": 0}
exe = ThreadPoolExecutor(max_workers=12)
threads = []
doc_nm = {}

View File

@ -0,0 +1,598 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
RAG Evaluation Service
Provides functionality for evaluating RAG system performance including:
- Dataset management
- Test case management
- Evaluation execution
- Metrics computation
- Configuration recommendations
"""
import logging
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
from timeit import default_timer as timer
from api.db.db_models import EvaluationDataset, EvaluationCase, EvaluationRun, EvaluationResult
from api.db.services.common_service import CommonService
from api.db.services.dialog_service import DialogService, chat
from common.misc_utils import get_uuid
from common.time_utils import current_timestamp
from common.constants import StatusEnum
class EvaluationService(CommonService):
"""Service for managing RAG evaluations"""
model = EvaluationDataset
# ==================== Dataset Management ====================
@classmethod
def create_dataset(cls, name: str, description: str, kb_ids: List[str],
tenant_id: str, user_id: str) -> Tuple[bool, str]:
"""
Create a new evaluation dataset.
Args:
name: Dataset name
description: Dataset description
kb_ids: List of knowledge base IDs to evaluate against
tenant_id: Tenant ID
user_id: User ID who creates the dataset
Returns:
(success, dataset_id or error_message)
"""
try:
dataset_id = get_uuid()
dataset = {
"id": dataset_id,
"tenant_id": tenant_id,
"name": name,
"description": description,
"kb_ids": kb_ids,
"created_by": user_id,
"create_time": current_timestamp(),
"update_time": current_timestamp(),
"status": StatusEnum.VALID.value
}
if not EvaluationDataset.create(**dataset):
return False, "Failed to create dataset"
return True, dataset_id
except Exception as e:
logging.error(f"Error creating evaluation dataset: {e}")
return False, str(e)
@classmethod
def get_dataset(cls, dataset_id: str) -> Optional[Dict[str, Any]]:
"""Get dataset by ID"""
try:
dataset = EvaluationDataset.get_by_id(dataset_id)
if dataset:
return dataset.to_dict()
return None
except Exception as e:
logging.error(f"Error getting dataset {dataset_id}: {e}")
return None
@classmethod
def list_datasets(cls, tenant_id: str, user_id: str,
page: int = 1, page_size: int = 20) -> Dict[str, Any]:
"""List datasets for a tenant"""
try:
query = EvaluationDataset.select().where(
(EvaluationDataset.tenant_id == tenant_id) &
(EvaluationDataset.status == StatusEnum.VALID.value)
).order_by(EvaluationDataset.create_time.desc())
total = query.count()
datasets = query.paginate(page, page_size)
return {
"total": total,
"datasets": [d.to_dict() for d in datasets]
}
except Exception as e:
logging.error(f"Error listing datasets: {e}")
return {"total": 0, "datasets": []}
@classmethod
def update_dataset(cls, dataset_id: str, **kwargs) -> bool:
"""Update dataset"""
try:
kwargs["update_time"] = current_timestamp()
return EvaluationDataset.update(**kwargs).where(
EvaluationDataset.id == dataset_id
).execute() > 0
except Exception as e:
logging.error(f"Error updating dataset {dataset_id}: {e}")
return False
@classmethod
def delete_dataset(cls, dataset_id: str) -> bool:
"""Soft delete dataset"""
try:
return EvaluationDataset.update(
status=StatusEnum.INVALID.value,
update_time=current_timestamp()
).where(EvaluationDataset.id == dataset_id).execute() > 0
except Exception as e:
logging.error(f"Error deleting dataset {dataset_id}: {e}")
return False
# ==================== Test Case Management ====================
@classmethod
def add_test_case(cls, dataset_id: str, question: str,
reference_answer: Optional[str] = None,
relevant_doc_ids: Optional[List[str]] = None,
relevant_chunk_ids: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None) -> Tuple[bool, str]:
"""
Add a test case to a dataset.
Args:
dataset_id: Dataset ID
question: Test question
reference_answer: Optional ground truth answer
relevant_doc_ids: Optional list of relevant document IDs
relevant_chunk_ids: Optional list of relevant chunk IDs
metadata: Optional additional metadata
Returns:
(success, case_id or error_message)
"""
try:
case_id = get_uuid()
case = {
"id": case_id,
"dataset_id": dataset_id,
"question": question,
"reference_answer": reference_answer,
"relevant_doc_ids": relevant_doc_ids,
"relevant_chunk_ids": relevant_chunk_ids,
"metadata": metadata,
"create_time": current_timestamp()
}
if not EvaluationCase.create(**case):
return False, "Failed to create test case"
return True, case_id
except Exception as e:
logging.error(f"Error adding test case: {e}")
return False, str(e)
@classmethod
def get_test_cases(cls, dataset_id: str) -> List[Dict[str, Any]]:
"""Get all test cases for a dataset"""
try:
cases = EvaluationCase.select().where(
EvaluationCase.dataset_id == dataset_id
).order_by(EvaluationCase.create_time)
return [c.to_dict() for c in cases]
except Exception as e:
logging.error(f"Error getting test cases for dataset {dataset_id}: {e}")
return []
@classmethod
def delete_test_case(cls, case_id: str) -> bool:
"""Delete a test case"""
try:
return EvaluationCase.delete().where(
EvaluationCase.id == case_id
).execute() > 0
except Exception as e:
logging.error(f"Error deleting test case {case_id}: {e}")
return False
@classmethod
def import_test_cases(cls, dataset_id: str, cases: List[Dict[str, Any]]) -> Tuple[int, int]:
"""
Bulk import test cases from a list.
Args:
dataset_id: Dataset ID
cases: List of test case dictionaries
Returns:
(success_count, failure_count)
"""
success_count = 0
failure_count = 0
for case_data in cases:
success, _ = cls.add_test_case(
dataset_id=dataset_id,
question=case_data.get("question", ""),
reference_answer=case_data.get("reference_answer"),
relevant_doc_ids=case_data.get("relevant_doc_ids"),
relevant_chunk_ids=case_data.get("relevant_chunk_ids"),
metadata=case_data.get("metadata")
)
if success:
success_count += 1
else:
failure_count += 1
return success_count, failure_count
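For reference, a hedged sketch of the case dictionaries import_test_cases consumes; the keys mirror add_test_case, the values are illustrative:

    sample_cases = [
        {
            "question": "What is the refund policy?",
            "reference_answer": "Refunds are issued within 14 days.",   # optional
            "relevant_doc_ids": ["doc_123"],                            # optional
            "relevant_chunk_ids": ["chunk_456"],                        # optional
            "metadata": {"tag": "billing"},                             # optional
        },
        {"question": "Which regions are supported?"},                   # only "question" is required
    ]
    # ok, failed = EvaluationService.import_test_cases(dataset_id, sample_cases)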
# ==================== Evaluation Execution ====================
@classmethod
def start_evaluation(cls, dataset_id: str, dialog_id: str,
user_id: str, name: Optional[str] = None) -> Tuple[bool, str]:
"""
Start an evaluation run.
Args:
dataset_id: Dataset ID
dialog_id: Dialog configuration to evaluate
user_id: User ID who starts the run
name: Optional run name
Returns:
(success, run_id or error_message)
"""
try:
# Get dialog configuration
success, dialog = DialogService.get_by_id(dialog_id)
if not success:
return False, "Dialog not found"
# Create evaluation run
run_id = get_uuid()
if not name:
name = f"Evaluation Run {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
run = {
"id": run_id,
"dataset_id": dataset_id,
"dialog_id": dialog_id,
"name": name,
"config_snapshot": dialog.to_dict(),
"metrics_summary": None,
"status": "RUNNING",
"created_by": user_id,
"create_time": current_timestamp(),
"complete_time": None
}
if not EvaluationRun.create(**run):
return False, "Failed to create evaluation run"
# Execute evaluation asynchronously (in production, use task queue)
# For now, we'll execute synchronously
cls._execute_evaluation(run_id, dataset_id, dialog)
return True, run_id
except Exception as e:
logging.error(f"Error starting evaluation: {e}")
return False, str(e)
@classmethod
def _execute_evaluation(cls, run_id: str, dataset_id: str, dialog: Any):
"""
Execute evaluation for all test cases.
This method runs the RAG pipeline for each test case and computes metrics.
"""
try:
# Get all test cases
test_cases = cls.get_test_cases(dataset_id)
if not test_cases:
EvaluationRun.update(
status="FAILED",
complete_time=current_timestamp()
).where(EvaluationRun.id == run_id).execute()
return
# Execute each test case
results = []
for case in test_cases:
result = cls._evaluate_single_case(run_id, case, dialog)
if result:
results.append(result)
# Compute summary metrics
metrics_summary = cls._compute_summary_metrics(results)
# Update run status
EvaluationRun.update(
status="COMPLETED",
metrics_summary=metrics_summary,
complete_time=current_timestamp()
).where(EvaluationRun.id == run_id).execute()
except Exception as e:
logging.error(f"Error executing evaluation {run_id}: {e}")
EvaluationRun.update(
status="FAILED",
complete_time=current_timestamp()
).where(EvaluationRun.id == run_id).execute()
@classmethod
def _evaluate_single_case(cls, run_id: str, case: Dict[str, Any],
dialog: Any) -> Optional[Dict[str, Any]]:
"""
Evaluate a single test case.
Args:
run_id: Evaluation run ID
case: Test case dictionary
dialog: Dialog configuration
Returns:
Result dictionary or None if failed
"""
try:
# Prepare messages
messages = [{"role": "user", "content": case["question"]}]
# Execute RAG pipeline
start_time = timer()
answer = ""
retrieved_chunks = []
for ans in chat(dialog, messages, stream=False):
if isinstance(ans, dict):
answer = ans.get("answer", "")
retrieved_chunks = ans.get("reference", {}).get("chunks", [])
break
execution_time = timer() - start_time
# Compute metrics
metrics = cls._compute_metrics(
question=case["question"],
generated_answer=answer,
reference_answer=case.get("reference_answer"),
retrieved_chunks=retrieved_chunks,
relevant_chunk_ids=case.get("relevant_chunk_ids"),
dialog=dialog
)
# Save result
result_id = get_uuid()
result = {
"id": result_id,
"run_id": run_id,
"case_id": case["id"],
"generated_answer": answer,
"retrieved_chunks": retrieved_chunks,
"metrics": metrics,
"execution_time": execution_time,
"token_usage": None, # TODO: Track token usage
"create_time": current_timestamp()
}
EvaluationResult.create(**result)
return result
except Exception as e:
logging.error(f"Error evaluating case {case.get('id')}: {e}")
return None
@classmethod
def _compute_metrics(cls, question: str, generated_answer: str,
reference_answer: Optional[str],
retrieved_chunks: List[Dict[str, Any]],
relevant_chunk_ids: Optional[List[str]],
dialog: Any) -> Dict[str, float]:
"""
Compute evaluation metrics for a single test case.
Returns:
Dictionary of metric names to values
"""
metrics = {}
# Retrieval metrics (if ground truth chunks provided)
if relevant_chunk_ids:
retrieved_ids = [c.get("chunk_id") for c in retrieved_chunks]
metrics.update(cls._compute_retrieval_metrics(retrieved_ids, relevant_chunk_ids))
# Generation metrics
if generated_answer:
# Basic metrics
metrics["answer_length"] = len(generated_answer)
metrics["has_answer"] = 1.0 if generated_answer.strip() else 0.0
# TODO: Implement advanced metrics using LLM-as-judge
# - Faithfulness (hallucination detection)
# - Answer relevance
# - Context relevance
# - Semantic similarity (if reference answer provided)
return metrics
@classmethod
def _compute_retrieval_metrics(cls, retrieved_ids: List[str],
relevant_ids: List[str]) -> Dict[str, float]:
"""
Compute retrieval metrics.
Args:
retrieved_ids: List of retrieved chunk IDs
relevant_ids: List of relevant chunk IDs (ground truth)
Returns:
Dictionary of retrieval metrics
"""
if not relevant_ids:
return {}
retrieved_set = set(retrieved_ids)
relevant_set = set(relevant_ids)
# Precision: proportion of retrieved that are relevant
precision = len(retrieved_set & relevant_set) / len(retrieved_set) if retrieved_set else 0.0
# Recall: proportion of relevant that were retrieved
recall = len(retrieved_set & relevant_set) / len(relevant_set) if relevant_set else 0.0
# F1 score
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
# Hit rate: whether any relevant chunk was retrieved
hit_rate = 1.0 if (retrieved_set & relevant_set) else 0.0
# MRR (Mean Reciprocal Rank): position of first relevant chunk
mrr = 0.0
for i, chunk_id in enumerate(retrieved_ids, 1):
if chunk_id in relevant_set:
mrr = 1.0 / i
break
return {
"precision": precision,
"recall": recall,
"f1_score": f1,
"hit_rate": hit_rate,
"mrr": mrr
}
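A small worked example of the retrieval metrics above, computed by hand for illustration:

    retrieved_ids = ["c1", "c2", "c3", "c4"]
    relevant_ids = ["c2", "c5"]
    # intersection = {"c2"}
    # precision = 1/4 = 0.25      (one of four retrieved chunks is relevant)
    # recall    = 1/2 = 0.5       (one of two relevant chunks was retrieved)
    # f1        = 2 * 0.25 * 0.5 / 0.75 ≈ 0.333
    # hit_rate  = 1.0             (at least one relevant chunk was retrieved)
    # mrr       = 1/2 = 0.5       (first relevant chunk appears at rank 2)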
@classmethod
def _compute_summary_metrics(cls, results: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Compute summary metrics across all test cases.
Args:
results: List of result dictionaries
Returns:
Summary metrics dictionary
"""
if not results:
return {}
# Aggregate metrics
metric_sums = {}
metric_counts = {}
for result in results:
metrics = result.get("metrics", {})
for key, value in metrics.items():
if isinstance(value, (int, float)):
metric_sums[key] = metric_sums.get(key, 0) + value
metric_counts[key] = metric_counts.get(key, 0) + 1
# Compute averages
summary = {
"total_cases": len(results),
"avg_execution_time": sum(r.get("execution_time", 0) for r in results) / len(results)
}
for key in metric_sums:
summary[f"avg_{key}"] = metric_sums[key] / metric_counts[key]
return summary
# ==================== Results & Analysis ====================
@classmethod
def get_run_results(cls, run_id: str) -> Dict[str, Any]:
"""Get results for an evaluation run"""
try:
run = EvaluationRun.get_by_id(run_id)
if not run:
return {}
results = EvaluationResult.select().where(
EvaluationResult.run_id == run_id
).order_by(EvaluationResult.create_time)
return {
"run": run.to_dict(),
"results": [r.to_dict() for r in results]
}
except Exception as e:
logging.error(f"Error getting run results {run_id}: {e}")
return {}
@classmethod
def get_recommendations(cls, run_id: str) -> List[Dict[str, Any]]:
"""
Analyze evaluation results and provide configuration recommendations.
Args:
run_id: Evaluation run ID
Returns:
List of recommendation dictionaries
"""
try:
run = EvaluationRun.get_by_id(run_id)
if not run or not run.metrics_summary:
return []
metrics = run.metrics_summary
recommendations = []
# Low precision: retrieving irrelevant chunks
if metrics.get("avg_precision", 1.0) < 0.7:
recommendations.append({
"issue": "Low Precision",
"severity": "high",
"description": "System is retrieving many irrelevant chunks",
"suggestions": [
"Increase similarity_threshold to filter out less relevant chunks",
"Enable reranking to improve chunk ordering",
"Reduce top_k to return fewer chunks"
]
})
# Low recall: missing relevant chunks
if metrics.get("avg_recall", 1.0) < 0.7:
recommendations.append({
"issue": "Low Recall",
"severity": "high",
"description": "System is missing relevant chunks",
"suggestions": [
"Increase top_k to retrieve more chunks",
"Lower similarity_threshold to be more inclusive",
"Enable hybrid search (keyword + semantic)",
"Check chunk size - may be too large or too small"
]
})
# Slow response time
if metrics.get("avg_execution_time", 0) > 5.0:
recommendations.append({
"issue": "Slow Response Time",
"severity": "medium",
"description": f"Average response time is {metrics['avg_execution_time']:.2f}s",
"suggestions": [
"Reduce top_k to retrieve fewer chunks",
"Optimize embedding model selection",
"Consider caching frequently asked questions"
]
})
return recommendations
except Exception as e:
logging.error(f"Error generating recommendations for run {run_id}: {e}")
return []

View File

@ -13,10 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import base64
import logging
import re
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Union
from peewee import fn
@ -520,7 +525,7 @@ class FileService(CommonService):
if img_base64 and file_type == FileType.VISUAL.value:
return GptV4.image2base64(blob)
cks = FACTORY.get(FileService.get_parser(filename_type(filename), filename, ""), naive).chunk(filename, blob, **kwargs)
return "\n".join([ck["content_with_weight"] for ck in cks])
return f"\n -----------------\nFile: {filename}\nContent as following: \n" + "\n".join([ck["content_with_weight"] for ck in cks])
@staticmethod
def get_parser(doc_type, filename, default):
@ -588,3 +593,80 @@ class FileService(CommonService):
errors += str(e)
return errors
@staticmethod
def upload_info(user_id, file, url: str|None=None):
def structured(filename, filetype, blob, content_type):
nonlocal user_id
if filetype == FileType.PDF.value:
blob = read_potential_broken_pdf(blob)
location = get_uuid()
FileService.put_blob(user_id, location, blob)
return {
"id": location,
"name": filename,
"size": sys.getsizeof(blob),
"extension": filename.split(".")[-1].lower(),
"mime_type": content_type,
"created_by": user_id,
"created_at": time.time(),
"preview_url": None
}
if url:
from crawl4ai import (
AsyncWebCrawler,
BrowserConfig,
CrawlerRunConfig,
DefaultMarkdownGenerator,
PruningContentFilter,
CrawlResult
)
filename = re.sub(r"\?.*", "", url.split("/")[-1])
async def adownload():
browser_config = BrowserConfig(
headless=True,
verbose=False,
)
async with AsyncWebCrawler(config=browser_config) as crawler:
crawler_config = CrawlerRunConfig(
markdown_generator=DefaultMarkdownGenerator(
content_filter=PruningContentFilter()
),
pdf=True,
screenshot=False
)
result: CrawlResult = await crawler.arun(
url=url,
config=crawler_config
)
return result
page = asyncio.run(adownload())
if page.pdf:
if filename.split(".")[-1].lower() != "pdf":
filename += ".pdf"
return structured(filename, "pdf", page.pdf, page.response_headers["content-type"])
return structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"], user_id)
DocumentService.check_doc_health(user_id, file.filename)
return structured(file.filename, filename_type(file.filename), file.read(), file.content_type)
@staticmethod
def get_files(files: Union[None, list[dict]]) -> list[str]:
if not files:
return []
def image_to_base64(file):
return "data:{};base64,{}".format(file["mime_type"],
base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
exe = ThreadPoolExecutor(max_workers=5)
threads = []
for file in files:
if file["mime_type"].find("image") >=0:
threads.append(exe.submit(image_to_base64, file))
continue
threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
return [th.result() for th in threads]
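For context, a hedged sketch of the file descriptors get_files expects; the keys mirror what upload_info's structured() returns above, the values are placeholders:

    files = [
        {"id": "loc-uuid", "name": "spec.pdf", "mime_type": "application/pdf", "created_by": "user-1"},
        {"id": "img-uuid", "name": "chart.png", "mime_type": "image/png", "created_by": "user-1"},
    ]
    # texts = FileService.get_files(files)
    # -> [parsed text of spec.pdf, "data:image/png;base64,..." for chart.png], in input order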

View File

@ -13,9 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import inspect
import logging
import re
import threading
from common.token_utils import num_tokens_from_string
from functools import partial
from typing import Generator
@ -183,6 +185,66 @@ class LLMBundle(LLM4Tenant):
return txt
def stream_transcription(self, audio):
mdl = self.mdl
supports_stream = hasattr(mdl, "stream_transcription") and callable(getattr(mdl, "stream_transcription"))
if supports_stream:
if self.langfuse:
generation = self.langfuse.start_generation(
trace_context=self.trace_context,
name="stream_transcription",
metadata={"model": self.llm_name}
)
final_text = ""
used_tokens = 0
try:
for evt in mdl.stream_transcription(audio):
if evt.get("event") == "final":
final_text = evt.get("text", "")
yield evt
except Exception as e:
err = {"event": "error", "text": str(e)}
yield err
final_text = final_text or ""
finally:
if final_text:
used_tokens = num_tokens_from_string(final_text)
TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens)
if self.langfuse:
generation.update(
output={"output": final_text},
usage_details={"total_tokens": used_tokens}
)
generation.end()
return
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="stream_transcription", metadata={"model": self.llm_name})
full_text, used_tokens = mdl.transcription(audio)
if not TenantLLMService.increase_usage(
self.tenant_id, self.llm_type, used_tokens
):
logging.error(
f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}"
)
if self.langfuse:
generation.update(
output={"output": full_text},
usage_details={"total_tokens": used_tokens}
)
generation.end()
yield {
"event": "final",
"text": full_text,
"streaming": False
}
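A hedged consumer for the event stream above (collect_transcript is a hypothetical helper, not part of the diff):

    def collect_transcript(bundle, audio_bytes):
        # Drain the stream_transcription events and return the final transcript text.
        final_text = ""
        for evt in bundle.stream_transcription(audio_bytes):
            if evt.get("event") == "error":
                raise RuntimeError(evt.get("text", "transcription failed"))
            if evt.get("event") == "final":
                final_text = evt.get("text", "")
        return final_text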
def tts(self, text: str) -> Generator[bytes, None, None]:
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="tts", input={"text": text})
@ -242,7 +304,7 @@ class LLMBundle(LLM4Tenant):
if not self.verbose_tool_use:
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
if isinstance(txt, int) and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
if self.langfuse:
@ -279,5 +341,89 @@ class LLMBundle(LLM4Tenant):
yield ans
if total_tokens > 0:
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, txt, self.llm_name):
logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
def _bridge_sync_stream(self, gen):
loop = asyncio.get_running_loop()
queue: asyncio.Queue = asyncio.Queue()
def worker():
try:
for item in gen:
loop.call_soon_threadsafe(queue.put_nowait, item)
except Exception as e: # pragma: no cover
loop.call_soon_threadsafe(queue.put_nowait, e)
finally:
loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
threading.Thread(target=worker, daemon=True).start()
return queue
async def async_chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs)
if self.is_tools and self.mdl.is_tools and hasattr(self.mdl, "chat_with_tools"):
chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs)
use_kwargs = self._clean_param(chat_partial, **kwargs)
if hasattr(self.mdl, "async_chat_with_tools") and self.is_tools and self.mdl.is_tools:
txt, used_tokens = await self.mdl.async_chat_with_tools(system, history, gen_conf, **use_kwargs)
elif hasattr(self.mdl, "async_chat"):
txt, used_tokens = await self.mdl.async_chat(system, history, gen_conf, **use_kwargs)
else:
txt, used_tokens = await asyncio.to_thread(chat_partial, **use_kwargs)
txt = self._remove_reasoning_content(txt)
if not self.verbose_tool_use:
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
return txt
async def async_chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
total_tokens = 0
ans = ""
if self.is_tools and self.mdl.is_tools:
stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None)
else:
stream_fn = getattr(self.mdl, "async_chat_streamly", None)
if stream_fn:
chat_partial = partial(stream_fn, system, history, gen_conf)
use_kwargs = self._clean_param(chat_partial, **kwargs)
async for txt in chat_partial(**use_kwargs):
if isinstance(txt, int):
total_tokens = txt
break
if txt.endswith("</think>"):
ans = ans[: -len("</think>")]
if not self.verbose_tool_use:
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
ans += txt
yield ans
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
return
chat_partial = partial(self.mdl.chat_streamly_with_tools if (self.is_tools and self.mdl.is_tools) else self.mdl.chat_streamly, system, history, gen_conf)
use_kwargs = self._clean_param(chat_partial, **kwargs)
queue = self._bridge_sync_stream(chat_partial(**use_kwargs))
while True:
item = await queue.get()
if item is StopAsyncIteration:
break
if isinstance(item, Exception):
raise item
if isinstance(item, int):
total_tokens = item
break
yield item
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))

View File

@ -25,7 +25,6 @@ import logging
import os
import signal
import sys
import time
import traceback
import threading
import uuid
@ -69,7 +68,7 @@ def signal_handler(sig, frame):
logging.info("Received interrupt signal, shutting down...")
shutdown_all_mcp_sessions()
stop_event.set()
time.sleep(1)
stop_event.wait(1)
sys.exit(0)
if __name__ == '__main__':
@ -163,5 +162,5 @@ if __name__ == '__main__':
except Exception:
traceback.print_exc()
stop_event.set()
time.sleep(1)
stop_event.wait(1)
os.kill(os.getpid(), signal.SIGKILL)

View File

@ -22,6 +22,7 @@ import os
import time
from copy import deepcopy
from functools import wraps
from typing import Any
import requests
import trio
@ -45,11 +46,40 @@ from common import settings
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
async def request_json():
async def _coerce_request_data() -> dict:
"""Fetch JSON body with sane defaults; fallback to form data."""
payload: Any = None
last_error: Exception | None = None
try:
return await request.json
except Exception:
return {}
payload = await request.get_json(force=True, silent=True)
except Exception as e:
last_error = e
payload = None
if payload is None:
try:
form = await request.form
payload = form.to_dict()
except Exception as e:
last_error = e
payload = None
if payload is None:
if last_error is not None:
raise last_error
raise ValueError("No JSON body or form data found in request.")
if isinstance(payload, dict):
return payload or {}
if isinstance(payload, str):
raise AttributeError("'str' object has no attribute 'get'")
raise TypeError(f"Unsupported request payload type: {type(payload)!r}")
async def get_request_json():
return await _coerce_request_data()
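A minimal, self-contained sketch of the same JSON-or-form fallback using Quart directly (not RAGFlow code; the route name is illustrative):

    from quart import Quart, request

    app = Quart(__name__)


    @app.route("/echo", methods=["POST"])
    async def echo():
        # get_json(silent=True) returns None for non-JSON bodies, so the handler
        # can fall back to form fields and still hand back a plain dict.
        payload = await request.get_json(force=True, silent=True)
        if payload is None:
            payload = (await request.form).to_dict()
        return payload or {}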
def serialize_for_json(obj):
"""
@ -137,7 +167,7 @@ def validate_request(*args, **kwargs):
def wrapper(func):
@wraps(func)
async def decorated_function(*_args, **_kwargs):
errs = process_args(await request.json or (await request.form).to_dict())
errs = process_args(await _coerce_request_data())
if errs:
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=errs)
if inspect.iscoroutinefunction(func):
@ -152,7 +182,7 @@ def validate_request(*args, **kwargs):
def not_allowed_parameters(*params):
def decorator(func):
async def wrapper(*args, **kwargs):
input_arguments = await request.json or (await request.form).to_dict()
input_arguments = await _coerce_request_data()
for param in params:
if param in input_arguments:
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed")
@ -313,6 +343,10 @@ def get_parser_config(chunk_method, parser_config):
chunk_method = "naive"
# Define default configurations for each chunking method
base_defaults = {
"table_context_size": 0,
"image_context_size": 0,
}
key_mapping = {
"naive": {
"layout_recognize": "DeepDOC",
@ -365,16 +399,19 @@ def get_parser_config(chunk_method, parser_config):
default_config = key_mapping[chunk_method]
# If no parser_config provided, return default
# If no parser_config provided, return default merged with base defaults
if not parser_config:
return default_config
if default_config is None:
return deep_merge(base_defaults, {})
return deep_merge(base_defaults, default_config)
# If parser_config is provided, merge with defaults to ensure required fields exist
if default_config is None:
return parser_config
return deep_merge(base_defaults, parser_config)
# Ensure raptor and graphrag fields have default values if not provided
merged_config = deep_merge(default_config, parser_config)
merged_config = deep_merge(base_defaults, default_config)
merged_config = deep_merge(merged_config, parser_config)
return merged_config
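The base-defaults merge above, in isolation; deep_merge_sketch is a stand-in (RAGFlow's own deep_merge is defined elsewhere in the codebase):

    def deep_merge_sketch(base: dict, override: dict) -> dict:
        # Recursively merge override into a copy of base; override wins on conflicts.
        merged = dict(base)
        for key, value in override.items():
            if isinstance(value, dict) and isinstance(merged.get(key), dict):
                merged[key] = deep_merge_sketch(merged[key], value)
            else:
                merged[key] = value
        return merged

    base_defaults = {"table_context_size": 0, "image_context_size": 0}
    user_config = {"chunk_token_num": 256, "table_context_size": 2}
    print(deep_merge_sketch(base_defaults, user_config))
    # -> {'table_context_size': 2, 'image_context_size': 0, 'chunk_token_num': 256}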

View File

@ -14,6 +14,7 @@
# limitations under the License.
#
from collections import Counter
import string
from typing import Annotated, Any, Literal
from uuid import UUID
@ -25,6 +26,7 @@ from pydantic import (
StringConstraints,
ValidationError,
field_validator,
model_validator,
)
from pydantic_core import PydanticCustomError
from werkzeug.exceptions import BadRequest, UnsupportedMediaType
@ -329,6 +331,7 @@ class RaptorConfig(Base):
threshold: Annotated[float, Field(default=0.1, ge=0.0, le=1.0)]
max_cluster: Annotated[int, Field(default=64, ge=1, le=1024)]
random_seed: Annotated[int, Field(default=0, ge=0)]
auto_disable_for_structured_data: Annotated[bool, Field(default=True)]
class GraphragConfig(Base):
@ -361,10 +364,9 @@ class CreateDatasetReq(Base):
description: Annotated[str | None, Field(default=None, max_length=65535)]
embedding_model: Annotated[str | None, Field(default=None, max_length=255, serialization_alias="embd_id")]
permission: Annotated[Literal["me", "team"], Field(default="me", min_length=1, max_length=16)]
chunk_method: Annotated[
Literal["naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"],
Field(default="naive", min_length=1, max_length=32, serialization_alias="parser_id"),
]
chunk_method: Annotated[str | None, Field(default=None, serialization_alias="parser_id")]
parse_type: Annotated[int | None, Field(default=None, ge=0, le=64)]
pipeline_id: Annotated[str | None, Field(default=None, min_length=32, max_length=32, serialization_alias="pipeline_id")]
parser_config: Annotated[ParserConfig | None, Field(default=None)]
@field_validator("avatar", mode="after")
@ -525,6 +527,93 @@ class CreateDatasetReq(Base):
raise PydanticCustomError("string_too_long", "Parser config exceeds size limit (max 65,535 characters). Current size: {actual}", {"actual": len(json_str)})
return v
@field_validator("pipeline_id", mode="after")
@classmethod
def validate_pipeline_id(cls, v: str | None) -> str | None:
"""Validate pipeline_id as 32-char lowercase hex string if provided.
Rules:
- None or empty string: treat as None (not set)
- Must be exactly length 32
- Must contain only hex digits (0-9a-fA-F); normalized to lowercase
"""
if v is None:
return None
if v == "":
return None
if len(v) != 32:
raise PydanticCustomError("format_invalid", "pipeline_id must be 32 hex characters")
if any(ch not in string.hexdigits for ch in v):
raise PydanticCustomError("format_invalid", "pipeline_id must be hexadecimal")
return v.lower()
@model_validator(mode="after")
def validate_parser_dependency(self) -> "CreateDatasetReq":
"""
Mixed conditional validation:
- If parser_id is omitted (field not set):
* If both parse_type and pipeline_id are omitted → default chunk_method = "naive"
* If both parse_type and pipeline_id are provided → allow ingestion pipeline mode
- If parser_id is provided (valid enum) → parse_type and pipeline_id must be None (disallow mixed usage)
Raises:
PydanticCustomError with code 'dependency_error' on violation.
"""
# Omitted chunk_method (not in fields) logic
if self.chunk_method is None and "chunk_method" not in self.model_fields_set:
# All three absent → default naive
if self.parse_type is None and self.pipeline_id is None:
object.__setattr__(self, "chunk_method", "naive")
return self
# parser_id omitted: require BOTH parse_type & pipeline_id present (no partial allowed)
if self.parse_type is None or self.pipeline_id is None:
missing = []
if self.parse_type is None:
missing.append("parse_type")
if self.pipeline_id is None:
missing.append("pipeline_id")
raise PydanticCustomError(
"dependency_error",
"parser_id omitted → required fields missing: {fields}",
{"fields": ", ".join(missing)},
)
# Both provided → allow pipeline mode
return self
# parser_id provided (valid): MUST NOT have parse_type or pipeline_id
if isinstance(self.chunk_method, str):
if self.parse_type is not None or self.pipeline_id is not None:
invalid = []
if self.parse_type is not None:
invalid.append("parse_type")
if self.pipeline_id is not None:
invalid.append("pipeline_id")
raise PydanticCustomError(
"dependency_error",
"parser_id provided → disallowed fields present: {fields}",
{"fields": ", ".join(invalid)},
)
return self
@field_validator("chunk_method", mode="wrap")
@classmethod
def validate_chunk_method(cls, v: Any, handler) -> Any:
"""Wrap validation to unify error messages, including type errors (e.g. list)."""
allowed = {"naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"}
error_msg = "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'"
# An omitted field never reaches this validator (the default None is handled by the model validator); an explicit None is rejected here
if v is None:
raise PydanticCustomError("literal_error", error_msg)
try:
# Run inner validation (type checking)
result = handler(v)
except Exception:
raise PydanticCustomError("literal_error", error_msg)
# After handler, enforce enumeration
if not isinstance(result, str) or result == "" or result not in allowed:
raise PydanticCustomError("literal_error", error_msg)
return result
class UpdateDatasetReq(CreateDatasetReq):
dataset_id: Annotated[str, Field(...)]
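
The validators above encode a three-way contract between `chunk_method`, `parse_type` and `pipeline_id`. A rough illustration of the intended payload shapes, assuming `CreateDatasetReq` and pydantic's `ValidationError` are in scope; the `name` field stands in for the rest of the schema and all values are invented:

```python
# Illustrative payloads only; values are made up.
ok_default  = {"name": "kb1"}                                    # nothing set -> chunk_method defaults to "naive"
ok_classic  = {"name": "kb2", "chunk_method": "paper"}           # enum parser, no pipeline fields
ok_pipeline = {"name": "kb3", "parse_type": 1,
               "pipeline_id": "a" * 32}                          # pipeline mode needs BOTH fields
bad_partial = {"name": "kb4", "parse_type": 1}                   # missing pipeline_id -> dependency_error
bad_mixed   = {"name": "kb5", "chunk_method": "naive",
               "pipeline_id": "a" * 32}                          # mixing modes -> dependency_error

for payload in (ok_default, ok_classic, ok_pipeline, bad_partial, bad_mixed):
    try:
        CreateDatasetReq.model_validate(payload)
        print(payload["name"], "-> accepted")
    except ValidationError as exc:
        print(payload["name"], "->", exc.errors()[0]["type"])
```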

View File

@ -49,6 +49,7 @@ class RetCode(IntEnum, CustomEnum):
RUNNING = 106
PERMISSION_ERROR = 108
AUTHENTICATION_ERROR = 109
BAD_REQUEST = 400
UNAUTHORIZED = 401
SERVER_ERROR = 500
FORBIDDEN = 403
@ -147,6 +148,7 @@ class Storage(Enum):
AWS_S3 = 4
OSS = 5
OPENDAL = 6
GCS = 7
# environment
# ENV_STRONG_TEST_COUNT = "STRONG_TEST_COUNT"

View File

@ -217,6 +217,7 @@ OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
)
GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI = os.environ.get("GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI", "http://localhost:9380/v1/connector/google-drive/oauth/web/callback")
GMAIL_WEB_OAUTH_REDIRECT_URI = os.environ.get("GMAIL_WEB_OAUTH_REDIRECT_URI", "http://localhost:9380/v1/connector/gmail/oauth/web/callback")
CONFLUENCE_OAUTH_TOKEN_URL = "https://auth.atlassian.com/oauth/token"
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()

View File

@ -1,6 +1,6 @@
import logging
import os
from typing import Any
from google.oauth2.credentials import Credentials as OAuthCredentials
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
from googleapiclient.errors import HttpError
@ -9,10 +9,10 @@ from common.data_source.config import INDEX_BATCH_SIZE, SLIM_BATCH_SIZE, Documen
from common.data_source.google_util.auth import get_google_creds
from common.data_source.google_util.constant import DB_CREDENTIALS_PRIMARY_ADMIN_KEY, MISSING_SCOPES_ERROR_STR, SCOPE_INSTRUCTIONS, USER_FIELDS
from common.data_source.google_util.resource import get_admin_service, get_gmail_service
from common.data_source.google_util.util import _execute_single_retrieval, execute_paginated_retrieval
from common.data_source.google_util.util import _execute_single_retrieval, execute_paginated_retrieval, sanitize_filename, clean_string
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch, SlimConnectorWithPermSync
from common.data_source.models import BasicExpertInfo, Document, ExternalAccess, GenerateDocumentsOutput, GenerateSlimDocumentOutput, SlimDocument, TextSection
from common.data_source.utils import build_time_range_query, clean_email_and_extract_name, get_message_body, is_mail_service_disabled_error, time_str_to_utc
from common.data_source.utils import build_time_range_query, clean_email_and_extract_name, get_message_body, is_mail_service_disabled_error, gmail_time_str_to_utc
# Constants for Gmail API fields
THREAD_LIST_FIELDS = "nextPageToken, threads(id)"
@ -67,7 +67,6 @@ def message_to_section(message: dict[str, Any]) -> tuple[TextSection, dict[str,
message_data += f"{name}: {value}\n"
message_body_text: str = get_message_body(payload)
return TextSection(link=link, text=message_body_text + message_data), metadata
@ -97,13 +96,15 @@ def thread_to_document(full_thread: dict[str, Any], email_used_to_fetch_thread:
if not semantic_identifier:
semantic_identifier = message_metadata.get("subject", "")
semantic_identifier = clean_string(semantic_identifier)
semantic_identifier = sanitize_filename(semantic_identifier)
if message_metadata.get("updated_at"):
updated_at = message_metadata.get("updated_at")
updated_at_datetime = None
if updated_at:
updated_at_datetime = time_str_to_utc(updated_at)
updated_at_datetime = gmail_time_str_to_utc(updated_at)
thread_id = full_thread.get("id")
if not thread_id:
@ -115,15 +116,24 @@ def thread_to_document(full_thread: dict[str, Any], email_used_to_fetch_thread:
if not semantic_identifier:
semantic_identifier = "(no subject)"
combined_sections = "\n\n".join(
sec.text for sec in sections if hasattr(sec, "text")
)
blob = combined_sections
size_bytes = len(blob)
extension = '.txt'
return Document(
id=thread_id,
semantic_identifier=semantic_identifier,
sections=sections,
blob=blob,
size_bytes=size_bytes,
extension=extension,
source=DocumentSource.GMAIL,
primary_owners=primary_owners,
secondary_owners=secondary_owners,
doc_updated_at=updated_at_datetime,
metadata={},
metadata=message_metadata,
external_access=ExternalAccess(
external_user_emails={email_used_to_fetch_thread},
external_user_group_ids=set(),
@ -214,15 +224,13 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
q=query,
continue_on_404_or_403=True,
):
full_threads = _execute_single_retrieval(
full_thread = _execute_single_retrieval(
retrieval_function=gmail_service.users().threads().get,
list_key=None,
userId=user_email,
fields=THREAD_FIELDS,
id=thread["id"],
continue_on_404_or_403=True,
)
full_thread = list(full_threads)[0]
doc = thread_to_document(full_thread, user_email)
if doc is None:
continue
@ -310,4 +318,30 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
if __name__ == "__main__":
pass
import time
import os
from common.data_source.google_util.util import get_credentials_from_env
logging.basicConfig(level=logging.INFO)
try:
email = os.environ.get("GMAIL_TEST_EMAIL", "newyorkupperbay@gmail.com")
creds = get_credentials_from_env(email, oauth=True, source="gmail")
print("Credentials loaded successfully")
print(f"{creds=}")
connector = GmailConnector(batch_size=2)
print("GmailConnector initialized")
connector.load_credentials(creds)
print("Credentials loaded into connector")
print("Gmail is ready to use")
for file in connector._fetch_threads(
int(time.time()) - 1 * 24 * 60 * 60,
int(time.time()),
):
print("new batch","-"*80)
for f in file:
print(f)
print("\n\n")
except Exception as e:
logging.exception(f"Error loading credentials: {e}")

View File

@ -1,7 +1,6 @@
"""Google Drive connector"""
import copy
import json
import logging
import os
import sys
@ -32,7 +31,6 @@ from common.data_source.google_drive.file_retrieval import (
from common.data_source.google_drive.model import DriveRetrievalStage, GoogleDriveCheckpoint, GoogleDriveFileType, RetrievedDriveFile, StageCompletion
from common.data_source.google_util.auth import get_google_creds
from common.data_source.google_util.constant import DB_CREDENTIALS_PRIMARY_ADMIN_KEY, MISSING_SCOPES_ERROR_STR, USER_FIELDS
from common.data_source.google_util.oauth_flow import ensure_oauth_token_dict
from common.data_source.google_util.resource import GoogleDriveService, get_admin_service, get_drive_service
from common.data_source.google_util.util import GoogleFields, execute_paginated_retrieval, get_file_owners
from common.data_source.google_util.util_threadpool_concurrency import ThreadSafeDict
@ -1138,39 +1136,6 @@ class GoogleDriveConnector(SlimConnectorWithPermSync, CheckpointedConnectorWithP
return GoogleDriveCheckpoint.model_validate_json(checkpoint_json)
def get_credentials_from_env(email: str, oauth: bool = False) -> dict:
try:
if oauth:
raw_credential_string = os.environ["GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR"]
else:
raw_credential_string = os.environ["GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR"]
except KeyError:
raise ValueError("Missing Google Drive credentials in environment variables")
try:
credential_dict = json.loads(raw_credential_string)
except json.JSONDecodeError:
raise ValueError("Invalid JSON in Google Drive credentials")
if oauth:
credential_dict = ensure_oauth_token_dict(credential_dict, DocumentSource.GOOGLE_DRIVE)
refried_credential_string = json.dumps(credential_dict)
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens"
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key"
DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin"
DB_CREDENTIALS_AUTHENTICATION_METHOD = "authentication_method"
cred_key = DB_CREDENTIALS_DICT_TOKEN_KEY if oauth else DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
return {
cred_key: refried_credential_string,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email,
DB_CREDENTIALS_AUTHENTICATION_METHOD: "uploaded",
}
class CheckpointOutputWrapper:
"""
Wraps a CheckpointOutput generator to give things back in a more digestible format.
@ -1236,7 +1201,7 @@ def yield_all_docs_from_checkpoint_connector(
if __name__ == "__main__":
import time
from common.data_source.google_util.util import get_credentials_from_env
logging.basicConfig(level=logging.DEBUG)
try:
@ -1245,7 +1210,7 @@ if __name__ == "__main__":
creds = get_credentials_from_env(email, oauth=True)
print("Credentials loaded successfully")
print(f"{creds=}")
sys.exit(0)
connector = GoogleDriveConnector(
include_shared_drives=False,
shared_drive_urls=None,

View File

@ -49,11 +49,11 @@ MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requeste
SCOPE_INSTRUCTIONS = ""
GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE = """<!DOCTYPE html>
GOOGLE_WEB_OAUTH_POPUP_TEMPLATE = """<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<title>Google Drive Authorization</title>
<title>{title}</title>
<style>
body {{
font-family: Arial, sans-serif;

View File

@ -1,12 +1,17 @@
import json
import logging
import os
import re
import socket
from collections.abc import Callable, Iterator
from enum import Enum
from typing import Any
import unicodedata
from googleapiclient.errors import HttpError  # type: ignore
from common.data_source.config import DocumentSource
from common.data_source.google_drive.model import GoogleDriveFileType
from common.data_source.google_util.oauth_flow import ensure_oauth_token_dict
# See https://developers.google.com/drive/api/reference/rest/v3/files/list for more
@ -117,6 +122,7 @@ def _execute_single_retrieval(
"""Execute a single retrieval from Google Drive API"""
try:
results = retrieval_function(**request_kwargs).execute()
except HttpError as e:
if e.resp.status >= 500:
results = retrieval_function()
@ -148,5 +154,110 @@ def _execute_single_retrieval(
error,
)
results = retrieval_function()
return results
def get_credentials_from_env(email: str, oauth: bool = False, source="drive") -> dict:
try:
if oauth:
raw_credential_string = os.environ["GOOGLE_OAUTH_CREDENTIALS_JSON_STR"]
else:
raw_credential_string = os.environ["GOOGLE_SERVICE_ACCOUNT_JSON_STR"]
except KeyError:
raise ValueError("Missing Google Drive credentials in environment variables")
try:
credential_dict = json.loads(raw_credential_string)
except json.JSONDecodeError:
raise ValueError("Invalid JSON in Google Drive credentials")
if oauth and source == "drive":
credential_dict = ensure_oauth_token_dict(credential_dict, DocumentSource.GOOGLE_DRIVE)
elif oauth:
credential_dict = ensure_oauth_token_dict(credential_dict, DocumentSource.GMAIL)
refried_credential_string = json.dumps(credential_dict)
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens"
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key"
DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin"
DB_CREDENTIALS_AUTHENTICATION_METHOD = "authentication_method"
cred_key = DB_CREDENTIALS_DICT_TOKEN_KEY if oauth else DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
return {
cred_key: refried_credential_string,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email,
DB_CREDENTIALS_AUTHENTICATION_METHOD: "uploaded",
}
def sanitize_filename(name: str) -> str:
"""
Soft sanitize for MinIO/S3:
- Replace only prohibited characters with a space.
- Preserve readability (no ugly underscores).
- Collapse multiple spaces.
"""
if name is None:
return "file.txt"
name = str(name).strip()
# Characters that MUST NOT appear in S3/MinIO object keys
# Replace them with a space (not underscore)
forbidden = r'[\\\?\#\%\*\:\|\<\>"]'
name = re.sub(forbidden, " ", name)
# Replace slashes "/" (S3 interprets as folder) with space
name = name.replace("/", " ")
# Collapse multiple spaces into one
name = re.sub(r"\s+", " ", name)
# Trim both ends
name = name.strip()
# Enforce reasonable max length
if len(name) > 200:
base, ext = os.path.splitext(name)
name = base[:180].rstrip() + ext
# Ensure there is an extension
if not os.path.splitext(name)[1]:
name += ".txt"
return name
def clean_string(text: str | None) -> str | None:
"""
Clean a string to make it safe for insertion into MySQL (utf8mb4).
- Normalize Unicode
- Remove control characters / zero-width characters
- Optionally remove high-plane emoji and symbols
"""
if text is None:
return None
# 0. Ensure the value is a string
text = str(text)
# 1. Normalize Unicode (NFC)
text = unicodedata.normalize("NFC", text)
# 2. Remove ASCII control characters (except tab, newline, carriage return)
text = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]", "", text)
# 3. Remove zero-width characters / BOM
text = re.sub(r"[\u200b-\u200d\uFEFF]", "", text)
# 4. Remove high Unicode characters (emoji, special symbols)
text = re.sub(r"[\U00010000-\U0010FFFF]", "", text)
# 5. Final fallback: strip any invalid UTF-8 sequences
try:
text.encode("utf-8")
except UnicodeEncodeError:
text = text.encode("utf-8", errors="ignore").decode("utf-8")
return text
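
A quick sanity check of the two helpers above, with invented inputs:

```python
print(sanitize_filename("Q3 report: draft*final?.pdf"))
# -> "Q3 report draft final .pdf"   (forbidden characters replaced by spaces, runs collapsed)
print(sanitize_filename("meeting notes"))
# -> "meeting notes.txt"            (extension appended when missing)

print(clean_string("Résumé\u200b 2025\x07"))
# -> "Résumé 2025"                  (zero-width and control characters stripped)
```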

View File

@ -30,7 +30,6 @@ class LoadConnector(ABC):
"""Load documents from state"""
pass
@abstractmethod
def validate_connector_settings(self) -> None:
"""Validate connector settings"""
pass

View File

@ -17,7 +17,11 @@ from common.data_source.exceptions import (
InsufficientPermissionsError,
ConnectorValidationError,
)
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch
from common.data_source.interfaces import (
LoadConnector,
PollConnector,
SecondsSinceUnixEpoch,
)
from common.data_source.models import Document
from common.data_source.utils import batch_generator, rl_requests
@ -42,7 +46,9 @@ class MoodleConnector(LoadConnector, PollConnector):
delimiter = "&" if "?" in file_url else "?"
return f"{file_url}{delimiter}token={token}"
def _log_error(self, context: str, error: Exception, level: str = "warning") -> None:
def _log_error(
self, context: str, error: Exception, level: str = "warning"
) -> None:
"""Simplified logging wrapper"""
msg = f"{context}: {error}"
if level == "error":
@ -73,7 +79,9 @@ class MoodleConnector(LoadConnector, PollConnector):
except MoodleException as e:
if "invalidtoken" in str(e).lower():
raise CredentialExpiredError("Moodle token is invalid or expired")
raise ConnectorMissingCredentialError(f"Failed to initialize Moodle client: {e}")
raise ConnectorMissingCredentialError(
f"Failed to initialize Moodle client: {e}"
)
def validate_connector_settings(self) -> None:
if not self.moodle_client:
@ -125,7 +133,9 @@ class MoodleConnector(LoadConnector, PollConnector):
logger.warning("No courses found to poll")
return
yield from self._yield_in_batches(self._get_updated_content(courses, start, end))
yield from self._yield_in_batches(
self._get_updated_content(courses, start, end)
)
@retry(tries=3, delay=1, backoff=2)
def _get_enrolled_courses(self) -> list:
@ -187,9 +197,7 @@ class MoodleConnector(LoadConnector, PollConnector):
except Exception as e:
self._log_error(f"polling course {course.fullname}", e)
def _process_module(
self, course, section, module
) -> Optional[Document]:
def _process_module(self, course, section, module) -> Optional[Document]:
try:
mtype = module.modname
if mtype in ["label", "url"]:
@ -224,11 +232,37 @@ class MoodleConnector(LoadConnector, PollConnector):
)
try:
resp = rl_requests.get(self._add_token_to_url(file_info.fileurl), timeout=60)
resp = rl_requests.get(
self._add_token_to_url(file_info.fileurl), timeout=60
)
resp.raise_for_status()
blob = resp.content
ext = os.path.splitext(file_name)[1] or ".bin"
semantic_id = f"{course.fullname} / {section.name} / {file_name}"
# Create metadata dictionary with relevant information
metadata = {
"moodle_url": self.moodle_url,
"course_id": getattr(course, "id", None),
"course_name": getattr(course, "fullname", None),
"course_shortname": getattr(course, "shortname", None),
"section_id": getattr(section, "id", None),
"section_name": getattr(section, "name", None),
"section_number": getattr(section, "section", None),
"module_id": getattr(module, "id", None),
"module_name": getattr(module, "name", None),
"module_type": getattr(module, "modname", None),
"module_instance": getattr(module, "instance", None),
"file_url": getattr(file_info, "fileurl", None),
"file_name": file_name,
"file_size": getattr(file_info, "filesize", len(blob)),
"file_type": getattr(file_info, "mimetype", None),
"time_created": getattr(module, "timecreated", None),
"time_modified": getattr(module, "timemodified", None),
"visible": getattr(module, "visible", None),
"groupmode": getattr(module, "groupmode", None),
}
return Document(
id=f"moodle_resource_{module.id}",
source="moodle",
@ -237,6 +271,7 @@ class MoodleConnector(LoadConnector, PollConnector):
blob=blob,
doc_updated_at=datetime.fromtimestamp(ts or 0, tz=timezone.utc),
size_bytes=len(blob),
metadata=metadata,
)
except Exception as e:
self._log_error(f"downloading resource {file_name}", e, "error")
@ -247,7 +282,9 @@ class MoodleConnector(LoadConnector, PollConnector):
return None
try:
result = self.moodle_client.mod.forum.get_forum_discussions(forumid=module.instance)
result = self.moodle_client.mod.forum.get_forum_discussions(
forumid=module.instance
)
disc_list = getattr(result, "discussions", [])
if not disc_list:
return None
@ -264,6 +301,38 @@ class MoodleConnector(LoadConnector, PollConnector):
blob = "\n".join(markdown).encode("utf-8")
semantic_id = f"{course.fullname} / {section.name} / {module.name}"
# Create metadata dictionary with relevant information
metadata = {
"moodle_url": self.moodle_url,
"course_id": getattr(course, "id", None),
"course_name": getattr(course, "fullname", None),
"course_shortname": getattr(course, "shortname", None),
"section_id": getattr(section, "id", None),
"section_name": getattr(section, "name", None),
"section_number": getattr(section, "section", None),
"module_id": getattr(module, "id", None),
"module_name": getattr(module, "name", None),
"module_type": getattr(module, "modname", None),
"forum_id": getattr(module, "instance", None),
"discussion_count": len(disc_list),
"time_created": getattr(module, "timecreated", None),
"time_modified": getattr(module, "timemodified", None),
"visible": getattr(module, "visible", None),
"groupmode": getattr(module, "groupmode", None),
"discussions": [
{
"id": getattr(d, "id", None),
"name": getattr(d, "name", None),
"user_id": getattr(d, "userid", None),
"user_fullname": getattr(d, "userfullname", None),
"time_created": getattr(d, "timecreated", None),
"time_modified": getattr(d, "timemodified", None),
}
for d in disc_list
],
}
return Document(
id=f"moodle_forum_{module.id}",
source="moodle",
@ -272,6 +341,7 @@ class MoodleConnector(LoadConnector, PollConnector):
blob=blob,
doc_updated_at=datetime.fromtimestamp(latest_ts or 0, tz=timezone.utc),
size_bytes=len(blob),
metadata=metadata,
)
except Exception as e:
self._log_error(f"processing forum {module.name}", e)
@ -293,11 +363,37 @@ class MoodleConnector(LoadConnector, PollConnector):
)
try:
resp = rl_requests.get(self._add_token_to_url(file_info.fileurl), timeout=60)
resp = rl_requests.get(
self._add_token_to_url(file_info.fileurl), timeout=60
)
resp.raise_for_status()
blob = resp.content
ext = os.path.splitext(file_name)[1] or ".html"
semantic_id = f"{course.fullname} / {section.name} / {module.name}"
# Create metadata dictionary with relevant information
metadata = {
"moodle_url": self.moodle_url,
"course_id": getattr(course, "id", None),
"course_name": getattr(course, "fullname", None),
"course_shortname": getattr(course, "shortname", None),
"section_id": getattr(section, "id", None),
"section_name": getattr(section, "name", None),
"section_number": getattr(section, "section", None),
"module_id": getattr(module, "id", None),
"module_name": getattr(module, "name", None),
"module_type": getattr(module, "modname", None),
"module_instance": getattr(module, "instance", None),
"page_url": getattr(file_info, "fileurl", None),
"file_name": file_name,
"file_size": getattr(file_info, "filesize", len(blob)),
"file_type": getattr(file_info, "mimetype", None),
"time_created": getattr(module, "timecreated", None),
"time_modified": getattr(module, "timemodified", None),
"visible": getattr(module, "visible", None),
"groupmode": getattr(module, "groupmode", None),
}
return Document(
id=f"moodle_page_{module.id}",
source="moodle",
@ -306,6 +402,7 @@ class MoodleConnector(LoadConnector, PollConnector):
blob=blob,
doc_updated_at=datetime.fromtimestamp(ts or 0, tz=timezone.utc),
size_bytes=len(blob),
metadata=metadata,
)
except Exception as e:
self._log_error(f"processing page {file_name}", e, "error")
@ -326,6 +423,29 @@ class MoodleConnector(LoadConnector, PollConnector):
semantic_id = f"{course.fullname} / {section.name} / {mname}"
blob = markdown.encode("utf-8")
# Create metadata dictionary with relevant information
metadata = {
"moodle_url": self.moodle_url,
"course_id": getattr(course, "id", None),
"course_name": getattr(course, "fullname", None),
"course_shortname": getattr(course, "shortname", None),
"section_id": getattr(section, "id", None),
"section_name": getattr(section, "name", None),
"section_number": getattr(section, "section", None),
"module_id": getattr(module, "id", None),
"module_name": getattr(module, "name", None),
"module_type": getattr(module, "modname", None),
"activity_type": mtype,
"activity_instance": getattr(module, "instance", None),
"description": desc,
"time_created": getattr(module, "timecreated", None),
"time_modified": getattr(module, "timemodified", None),
"added": getattr(module, "added", None),
"visible": getattr(module, "visible", None),
"groupmode": getattr(module, "groupmode", None),
}
return Document(
id=f"moodle_{mtype}_{module.id}",
source="moodle",
@ -334,6 +454,7 @@ class MoodleConnector(LoadConnector, PollConnector):
blob=blob,
doc_updated_at=datetime.fromtimestamp(ts or 0, tz=timezone.utc),
size_bytes=len(blob),
metadata=metadata,
)
def _process_book(self, course, section, module) -> Optional[Document]:
@ -342,8 +463,10 @@ class MoodleConnector(LoadConnector, PollConnector):
contents = module.contents
chapters = [
c for c in contents
if getattr(c, "fileurl", None) and os.path.basename(c.filename) == "index.html"
c
for c in contents
if getattr(c, "fileurl", None)
and os.path.basename(c.filename) == "index.html"
]
if not chapters:
return None
@ -356,17 +479,54 @@ class MoodleConnector(LoadConnector, PollConnector):
)
markdown_parts = [f"# {module.name}\n"]
chapter_info = []
for ch in chapters:
try:
resp = rl_requests.get(self._add_token_to_url(ch.fileurl), timeout=60)
resp.raise_for_status()
html = resp.content.decode("utf-8", errors="ignore")
markdown_parts.append(md(html) + "\n\n---\n")
# Collect chapter information for metadata
chapter_info.append(
{
"chapter_id": getattr(ch, "chapterid", None),
"title": getattr(ch, "title", None),
"filename": getattr(ch, "filename", None),
"fileurl": getattr(ch, "fileurl", None),
"time_created": getattr(ch, "timecreated", None),
"time_modified": getattr(ch, "timemodified", None),
"size": getattr(ch, "filesize", None),
}
)
except Exception as e:
self._log_error(f"processing book chapter {ch.filename}", e)
blob = "\n".join(markdown_parts).encode("utf-8")
semantic_id = f"{course.fullname} / {section.name} / {module.name}"
# Create metadata dictionary with relevant information
metadata = {
"moodle_url": self.moodle_url,
"course_id": getattr(course, "id", None),
"course_name": getattr(course, "fullname", None),
"course_shortname": getattr(course, "shortname", None),
"section_id": getattr(section, "id", None),
"section_name": getattr(section, "name", None),
"section_number": getattr(section, "section", None),
"module_id": getattr(module, "id", None),
"module_name": getattr(module, "name", None),
"module_type": getattr(module, "modname", None),
"book_id": getattr(module, "instance", None),
"chapter_count": len(chapters),
"chapters": chapter_info,
"time_created": getattr(module, "timecreated", None),
"time_modified": getattr(module, "timemodified", None),
"visible": getattr(module, "visible", None),
"groupmode": getattr(module, "groupmode", None),
}
return Document(
id=f"moodle_book_{module.id}",
source="moodle",
@ -375,4 +535,5 @@ class MoodleConnector(LoadConnector, PollConnector):
blob=blob,
doc_updated_at=datetime.fromtimestamp(latest_ts or 0, tz=timezone.utc),
size_bytes=len(blob),
metadata=metadata,
)

View File

@ -733,7 +733,7 @@ def build_time_range_query(
"""Build time range query for Gmail API"""
query = ""
if time_range_start is not None and time_range_start != 0:
query += f"after:{int(time_range_start)}"
query += f"after:{int(time_range_start) + 1}"
if time_range_end is not None and time_range_end != 0:
query += f" before:{int(time_range_end)}"
query = query.strip()
@ -778,6 +778,15 @@ def time_str_to_utc(time_str: str):
return datetime.fromisoformat(time_str.replace("Z", "+00:00"))
def gmail_time_str_to_utc(time_str: str):
"""Convert Gmail RFC 2822 time string to UTC."""
from email.utils import parsedate_to_datetime
from datetime import timezone
dt = parsedate_to_datetime(time_str)
return dt.astimezone(timezone.utc)
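
Per the docstring above, the input is an RFC 2822 date string (as carried in Gmail message headers), which the helper normalises to UTC; for example (timestamp invented):

```python
print(gmail_time_str_to_utc("Tue, 03 Jun 2025 10:15:00 +0800"))
# -> 2025-06-03 02:15:00+00:00
```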
# Notion Utilities
T = TypeVar("T")

View File

@ -190,6 +190,11 @@ class WebDAVConnector(LoadConnector, PollConnector):
files = self._list_files_recursive(self.remote_path, start, end)
logging.info(f"Found {len(files)} files matching time criteria")
filename_counts: dict[str, int] = {}
for file_path, _ in files:
file_name = os.path.basename(file_path)
filename_counts[file_name] = filename_counts.get(file_name, 0) + 1
batch: list[Document] = []
for file_path, file_info in files:
file_name = os.path.basename(file_path)
@ -237,12 +242,22 @@ class WebDAVConnector(LoadConnector, PollConnector):
else:
modified = datetime.now(timezone.utc)
if filename_counts.get(file_name, 0) > 1:
relative_path = file_path
if file_path.startswith(self.remote_path):
relative_path = file_path[len(self.remote_path):]
if relative_path.startswith('/'):
relative_path = relative_path[1:]
semantic_id = relative_path.replace('/', ' / ') if relative_path else file_name
else:
semantic_id = file_name
batch.append(
Document(
id=f"webdav:{self.base_url}:{file_path}",
blob=blob,
source=DocumentSource.WEBDAV,
semantic_identifier=file_name,
semantic_identifier=semantic_id,
extension=get_file_ext(file_name),
doc_updated_at=modified,
size_bytes=size_bytes if size_bytes else 0
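
The duplicate-name handling above only switches to a path-based identifier when two files share a basename; a small standalone sketch of the rule (paths invented, `remote_path = "/docs"`):

```python
remote_path = "/docs"
files = ["/docs/a/report.pdf", "/docs/b/report.pdf", "/docs/readme.md"]

counts: dict[str, int] = {}
for path in files:
    name = path.rsplit("/", 1)[-1]
    counts[name] = counts.get(name, 0) + 1

for path in files:
    name = path.rsplit("/", 1)[-1]
    if counts[name] > 1:
        rel = path[len(remote_path):].lstrip("/")
        print(path, "->", rel.replace("/", " / "))   # a / report.pdf, b / report.pdf
    else:
        print(path, "->", name)                      # readme.md
```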

157
common/http_client.py Normal file
View File

@ -0,0 +1,157 @@
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import os
import time
from typing import Any, Dict, Optional
import httpx
logger = logging.getLogger(__name__)
# Default knobs; keep conservative to avoid unexpected behavioural changes.
DEFAULT_TIMEOUT = float(os.environ.get("HTTP_CLIENT_TIMEOUT", "15"))
# Align with requests default: follow redirects with a max of 30 unless overridden.
DEFAULT_FOLLOW_REDIRECTS = bool(int(os.environ.get("HTTP_CLIENT_FOLLOW_REDIRECTS", "1")))
DEFAULT_MAX_REDIRECTS = int(os.environ.get("HTTP_CLIENT_MAX_REDIRECTS", "30"))
DEFAULT_MAX_RETRIES = int(os.environ.get("HTTP_CLIENT_MAX_RETRIES", "2"))
DEFAULT_BACKOFF_FACTOR = float(os.environ.get("HTTP_CLIENT_BACKOFF_FACTOR", "0.5"))
DEFAULT_PROXY = os.environ.get("HTTP_CLIENT_PROXY")
DEFAULT_USER_AGENT = os.environ.get("HTTP_CLIENT_USER_AGENT", "ragflow-http-client")
def _clean_headers(headers: Optional[Dict[str, str]], auth_token: Optional[str] = None) -> Optional[Dict[str, str]]:
merged_headers: Dict[str, str] = {}
if DEFAULT_USER_AGENT:
merged_headers["User-Agent"] = DEFAULT_USER_AGENT
if auth_token:
merged_headers["Authorization"] = auth_token
if headers is None:
return merged_headers or None
merged_headers.update({str(k): str(v) for k, v in headers.items() if v is not None})
return merged_headers or None
def _get_delay(backoff_factor: float, attempt: int) -> float:
return backoff_factor * (2**attempt)
async def async_request(
method: str,
url: str,
*,
timeout: float | httpx.Timeout | None = None,
follow_redirects: bool | None = None,
max_redirects: Optional[int] = None,
headers: Optional[Dict[str, str]] = None,
auth_token: Optional[str] = None,
retries: Optional[int] = None,
backoff_factor: Optional[float] = None,
proxies: Any = None,
**kwargs: Any,
) -> httpx.Response:
"""Lightweight async HTTP wrapper using httpx.AsyncClient with safe defaults."""
timeout = timeout if timeout is not None else DEFAULT_TIMEOUT
follow_redirects = DEFAULT_FOLLOW_REDIRECTS if follow_redirects is None else follow_redirects
max_redirects = DEFAULT_MAX_REDIRECTS if max_redirects is None else max_redirects
retries = DEFAULT_MAX_RETRIES if retries is None else max(retries, 0)
backoff_factor = DEFAULT_BACKOFF_FACTOR if backoff_factor is None else backoff_factor
headers = _clean_headers(headers, auth_token=auth_token)
proxies = DEFAULT_PROXY if proxies is None else proxies
async with httpx.AsyncClient(
timeout=timeout,
follow_redirects=follow_redirects,
max_redirects=max_redirects,
proxies=proxies,
) as client:
last_exc: Exception | None = None
for attempt in range(retries + 1):
try:
start = time.monotonic()
response = await client.request(method=method, url=url, headers=headers, **kwargs)
duration = time.monotonic() - start
logger.debug(f"async_request {method} {url} -> {response.status_code} in {duration:.3f}s")
return response
except httpx.RequestError as exc:
last_exc = exc
if attempt >= retries:
logger.warning(f"async_request exhausted retries for {method} {url}: {exc}")
raise
delay = _get_delay(backoff_factor, attempt)
logger.warning(f"async_request attempt {attempt + 1}/{retries + 1} failed for {method} {url}: {exc}; retrying in {delay:.2f}s")
await asyncio.sleep(delay)
raise last_exc # pragma: no cover
def sync_request(
method: str,
url: str,
*,
timeout: float | httpx.Timeout | None = None,
follow_redirects: bool | None = None,
max_redirects: Optional[int] = None,
headers: Optional[Dict[str, str]] = None,
auth_token: Optional[str] = None,
retries: Optional[int] = None,
backoff_factor: Optional[float] = None,
proxies: Any = None,
**kwargs: Any,
) -> httpx.Response:
"""Synchronous counterpart to async_request, for CLI/tests or sync contexts."""
timeout = timeout if timeout is not None else DEFAULT_TIMEOUT
follow_redirects = DEFAULT_FOLLOW_REDIRECTS if follow_redirects is None else follow_redirects
max_redirects = DEFAULT_MAX_REDIRECTS if max_redirects is None else max_redirects
retries = DEFAULT_MAX_RETRIES if retries is None else max(retries, 0)
backoff_factor = DEFAULT_BACKOFF_FACTOR if backoff_factor is None else backoff_factor
headers = _clean_headers(headers, auth_token=auth_token)
proxies = DEFAULT_PROXY if proxies is None else proxies
with httpx.Client(
timeout=timeout,
follow_redirects=follow_redirects,
max_redirects=max_redirects,
proxies=proxies,
) as client:
last_exc: Exception | None = None
for attempt in range(retries + 1):
try:
start = time.monotonic()
response = client.request(method=method, url=url, headers=headers, **kwargs)
duration = time.monotonic() - start
logger.debug(f"sync_request {method} {url} -> {response.status_code} in {duration:.3f}s")
return response
except httpx.RequestError as exc:
last_exc = exc
if attempt >= retries:
logger.warning(f"sync_request exhausted retries for {method} {url}: {exc}")
raise
delay = _get_delay(backoff_factor, attempt)
logger.warning(f"sync_request attempt {attempt + 1}/{retries + 1} failed for {method} {url}: {exc}; retrying in {delay:.2f}s")
time.sleep(delay)
raise last_exc # pragma: no cover
__all__ = [
"async_request",
"sync_request",
"DEFAULT_TIMEOUT",
"DEFAULT_FOLLOW_REDIRECTS",
"DEFAULT_MAX_REDIRECTS",
"DEFAULT_MAX_RETRIES",
"DEFAULT_BACKOFF_FACTOR",
"DEFAULT_PROXY",
"DEFAULT_USER_AGENT",
]
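
A brief usage sketch for the new wrapper (URL and token are placeholders; retries and backoff follow the module defaults unless overridden):

```python
import asyncio

from common.http_client import async_request, sync_request

# Synchronous call with an explicit retry budget.
resp = sync_request("GET", "https://example.com/api/health",
                    auth_token="Bearer <token>", retries=3, backoff_factor=1.0)
print(resp.status_code)

# Async variant; extra kwargs (json=..., params=...) pass straight through to httpx.
async def main() -> None:
    resp = await async_request("POST", "https://example.com/api/items",
                               json={"name": "demo"}, timeout=30.0)
    resp.raise_for_status()

asyncio.run(main())
```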

View File

@ -23,6 +23,8 @@ import subprocess
import sys
import os
import logging
from pathlib import Path
from typing import Dict
def get_uuid():
return uuid.uuid1().hex
@ -106,3 +108,152 @@ def pip_install_torch():
logging.info("Installing pytorch")
pkg_names = ["torch>=2.5.0,<3.0.0"]
subprocess.check_call([sys.executable, "-m", "pip", "install", *pkg_names])
def parse_mineru_paths() -> Dict[str, Path]:
"""
Parse MinerU-related paths based on the MINERU_EXECUTABLE environment variable.
Expected layout (default convention):
MINERU_EXECUTABLE = /home/user/uv_tools/.venv/bin/mineru
From this path we derive:
- mineru_exec : full path to the mineru executable
- venv_dir : the virtual environment directory (.venv)
- tools_dir : the parent tools directory (e.g. uv_tools)
If MINERU_EXECUTABLE is not set, we fall back to the default layout:
$HOME/uv_tools/.venv/bin/mineru
Returns:
A dict with keys:
- "mineru_exec": Path
- "venv_dir": Path
- "tools_dir": Path
"""
mineru_exec_env = os.getenv("MINERU_EXECUTABLE")
if mineru_exec_env:
# Use the path from the environment variable
mineru_exec = Path(mineru_exec_env).expanduser().resolve()
venv_dir = mineru_exec.parent.parent
tools_dir = venv_dir.parent
else:
# Fall back to default convention: $HOME/uv_tools/.venv/bin/mineru
home = Path(os.path.expanduser("~"))
tools_dir = home / "uv_tools"
venv_dir = tools_dir / ".venv"
mineru_exec = venv_dir / "bin" / "mineru"
return {
"mineru_exec": mineru_exec,
"venv_dir": venv_dir,
"tools_dir": tools_dir,
}
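
For example, pointing `MINERU_EXECUTABLE` at a custom virtualenv yields the following derived paths (directory names are illustrative):

```python
import os

os.environ["MINERU_EXECUTABLE"] = "/opt/tools/.venv/bin/mineru"   # illustrative path
paths = parse_mineru_paths()
print(paths["mineru_exec"])  # /opt/tools/.venv/bin/mineru
print(paths["venv_dir"])     # /opt/tools/.venv
print(paths["tools_dir"])    # /opt/tools
```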
@once
def check_and_install_mineru() -> None:
"""
Ensure MinerU is installed.
Behavior:
1. MinerU is enabled only when USE_MINERU is true/yes/1/y.
2. Resolve mineru_exec / venv_dir / tools_dir.
3. If mineru exists and works, log success and exit.
4. Otherwise:
- Create tools_dir
- Create venv if missing
- Install mineru[core], fallback to mineru[all]
- Validate with `--help`
5. Log installation success.
NOTE:
This function intentionally does NOT return the path.
Logging is used to indicate status.
"""
# Check if MinerU is enabled
use_mineru = os.getenv("USE_MINERU", "false").strip().lower()
if use_mineru != "true":
logging.info("USE_MINERU=%r. Skipping MinerU installation.", use_mineru)
return
# Resolve expected paths
paths = parse_mineru_paths()
mineru_exec: Path = paths["mineru_exec"]
venv_dir: Path = paths["venv_dir"]
tools_dir: Path = paths["tools_dir"]
# Construct environment variables for installation/execution
env = os.environ.copy()
env["VIRTUAL_ENV"] = str(venv_dir)
env["PATH"] = str(venv_dir / "bin") + os.pathsep + env.get("PATH", "")
# Configure HuggingFace endpoint
env.setdefault("HUGGINGFACE_HUB_ENDPOINT", os.getenv("HF_ENDPOINT") or "https://hf-mirror.com")
# Helper: check whether mineru works
def mineru_works() -> bool:
try:
subprocess.check_call(
[str(mineru_exec), "--help"],
stdout=subprocess.DEVNULL,
stderr=subprocess.PIPE,
env=env,
)
return True
except Exception:
return False
# If MinerU is already installed and functional
if mineru_exec.is_file() and os.access(mineru_exec, os.X_OK) and mineru_works():
logging.info("MinerU already installed.")
os.environ["MINERU_EXECUTABLE"] = str(mineru_exec)
return
logging.info("MinerU not found. Installing into virtualenv: %s", venv_dir)
# Ensure parent directory exists
tools_dir.mkdir(parents=True, exist_ok=True)
# Create venv if missing
if not venv_dir.exists():
subprocess.check_call(
["uv", "venv", str(venv_dir)],
cwd=str(tools_dir),
env=env,
# stdout=subprocess.DEVNULL,
# stderr=subprocess.PIPE,
)
else:
logging.info("Virtual environment exists at %s. Reusing it.", venv_dir)
# Helper for pip install
def pip_install(pkg: str) -> None:
subprocess.check_call(
[
"uv", "pip", "install", "-U", pkg,
"-i", "https://mirrors.aliyun.com/pypi/simple",
"--extra-index-url", "https://pypi.org/simple",
],
cwd=str(tools_dir),
# stdout=subprocess.DEVNULL,
# stderr=subprocess.PIPE,
env=env,
)
# Install core version first; fallback to all
try:
logging.info("Installing mineru[core] ...")
pip_install("mineru[core]")
except subprocess.CalledProcessError:
logging.warning("mineru[core] installation failed. Installing mineru[all] ...")
pip_install("mineru[all]")
# Validate installation
if not mineru_works():
logging.error("MinerU installation failed: %s does not work.", mineru_exec)
raise RuntimeError(f"MinerU installation failed: {mineru_exec} is not functional")
os.environ["MINERU_EXECUTABLE"] = str(mineru_exec)
logging.info("MinerU installation completed successfully. Executable: %s", mineru_exec)

View File

@ -31,6 +31,7 @@ import rag.utils.ob_conn
import rag.utils.opensearch_conn
from rag.utils.azure_sas_conn import RAGFlowAzureSasBlob
from rag.utils.azure_spn_conn import RAGFlowAzureSpnBlob
from rag.utils.gcs_conn import RAGFlowGCS
from rag.utils.minio_conn import RAGFlowMinio
from rag.utils.opendal_conn import OpenDALStorage
from rag.utils.s3_conn import RAGFlowS3
@ -109,6 +110,7 @@ MINIO = {}
OB = {}
OSS = {}
OS = {}
GCS = {}
DOC_MAXIMUM_SIZE: int = 128 * 1024 * 1024
DOC_BULK_SIZE: int = 4
@ -151,7 +153,8 @@ class StorageFactory:
Storage.AZURE_SAS: RAGFlowAzureSasBlob,
Storage.AWS_S3: RAGFlowS3,
Storage.OSS: RAGFlowOSS,
Storage.OPENDAL: OpenDALStorage
Storage.OPENDAL: OpenDALStorage,
Storage.GCS: RAGFlowGCS,
}
@classmethod
@ -250,7 +253,7 @@ def init_settings():
else:
raise Exception(f"Not supported doc engine: {DOC_ENGINE}")
global AZURE, S3, MINIO, OSS
global AZURE, S3, MINIO, OSS, GCS
if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']:
AZURE = get_base_config("azure", {})
elif STORAGE_IMPL_TYPE == 'AWS_S3':
@ -259,6 +262,8 @@ def init_settings():
MINIO = decrypt_database_config(name="minio")
elif STORAGE_IMPL_TYPE == 'OSS':
OSS = get_base_config("oss", {})
elif STORAGE_IMPL_TYPE == 'GCS':
GCS = get_base_config("gcs", {})
global STORAGE_IMPL
STORAGE_IMPL = StorageFactory.create(Storage[STORAGE_IMPL_TYPE])

View File

@ -7,6 +7,20 @@
"status": "1",
"rank": "999",
"llm": [
{
"llm_name": "gpt-5.1",
"tags": "LLM,CHAT,400k,IMAGE2TEXT",
"max_tokens": 400000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gpt-5.1-chat-latest",
"tags": "LLM,CHAT,400k,IMAGE2TEXT",
"max_tokens": 400000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "gpt-5",
"tags": "LLM,CHAT,400k,IMAGE2TEXT",
@ -269,20 +283,6 @@
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "glm-4.5",
"tags": "LLM,CHAT,131K",
"max_tokens": 131000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "deepseek-v3.1",
"tags": "LLM,CHAT,128k",
"max_tokens": 128000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "hunyuan-a13b-instruct",
"tags": "LLM,CHAT,256k",
@ -324,6 +324,34 @@
"max_tokens": 262000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "deepseek-ocr",
"tags": "LLM,8k",
"max_tokens": 8000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "qwen3-235b-a22b-instruct-2507",
"tags": "LLM,CHAT,256k",
"max_tokens": 256000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "glm-4.6",
"tags": "LLM,CHAT,200k",
"max_tokens": 200000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "minimax-m2",
"tags": "LLM,CHAT,200k",
"max_tokens": 200000,
"model_type": "chat",
"is_tools": true
}
]
},
@ -686,19 +714,13 @@
"model_type": "rerank"
},
{
"llm_name": "qwen-audio-asr",
"llm_name": "qwen3-asr-flash",
"tags": "SPEECH2TEXT,8k",
"max_tokens": 8000,
"model_type": "speech2text"
},
{
"llm_name": "qwen-audio-asr-latest",
"tags": "SPEECH2TEXT,8k",
"max_tokens": 8000,
"model_type": "speech2text"
},
{
"llm_name": "qwen-audio-asr-1204",
"llm_name": "qwen3-asr-flash-2025-09-08",
"tags": "SPEECH2TEXT,8k",
"max_tokens": 8000,
"model_type": "speech2text"
@ -1166,6 +1188,12 @@
"tags": "TEXT EMBEDDING",
"max_tokens": 8196,
"model_type": "embedding"
},
{
"llm_name": "jina-embeddings-v4",
"tags": "TEXT EMBEDDING",
"max_tokens": 32768,
"model_type": "embedding"
}
]
},
@ -1198,39 +1226,14 @@
{
"name": "MiniMax",
"logo": "",
"tags": "LLM,TEXT EMBEDDING",
"tags": "LLM",
"status": "1",
"rank": "810",
"llm": [
{
"llm_name": "abab6.5-chat",
"tags": "LLM,CHAT,8k",
"max_tokens": 8192,
"model_type": "chat"
},
{
"llm_name": "abab6.5s-chat",
"tags": "LLM,CHAT,245k",
"max_tokens": 245760,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "abab6.5t-chat",
"tags": "LLM,CHAT,8k",
"max_tokens": 8192,
"model_type": "chat"
},
{
"llm_name": "abab6.5g-chat",
"tags": "LLM,CHAT,8k",
"max_tokens": 8192,
"model_type": "chat"
},
{
"llm_name": "abab5.5s-chat",
"tags": "LLM,CHAT,8k",
"max_tokens": 8192,
"llm_name": "MiniMax-M2",
"tags": "LLM,CHAT,200k",
"max_tokens": 200000,
"model_type": "chat"
}
]
@ -3218,6 +3221,13 @@
"status": "1",
"rank": "990",
"llm": [
{
"llm_name": "claude-opus-4-5-20251101",
"tags": "LLM,CHAT,IMAGE2TEXT,200k",
"max_tokens": 204800,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "claude-opus-4-1-20250805",
"tags": "LLM,CHAT,IMAGE2TEXT,200k",

View File

@ -38,6 +38,7 @@ oceanbase:
port: 2881
redis:
db: 1
username: ''
password: 'infini_rag_flow'
host: 'localhost:6379'
task_executor:
@ -59,6 +60,8 @@ user_default_llm:
# access_key: 'access_key'
# secret_key: 'secret_key'
# region: 'region'
#gcs:
# bucket: 'bridgtl-edm-d-bucket-ragflow'
# oss:
# access_key: 'access_key'
# secret_key: 'secret_key'

View File

@ -25,6 +25,8 @@ from rag.prompts.generator import vision_llm_figure_describe_prompt
def vision_figure_parser_figure_data_wrapper(figures_data_without_positions):
if not figures_data_without_positions:
return []
return [
(
(figure_data[1], [figure_data[0]]),
@ -35,7 +37,9 @@ def vision_figure_parser_figure_data_wrapper(figures_data_without_positions):
]
def vision_figure_parser_docx_wrapper(sections,tbls,callback=None,**kwargs):
def vision_figure_parser_docx_wrapper(sections, tbls, callback=None,**kwargs):
if not tbls:
return []
try:
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
callback(0.7, "Visual model detected. Attempting to enhance figure extraction...")
@ -53,6 +57,8 @@ def vision_figure_parser_docx_wrapper(sections,tbls,callback=None,**kwargs):
def vision_figure_parser_pdf_wrapper(tbls, callback=None, **kwargs):
if not tbls:
return []
try:
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
callback(0.7, "Visual model detected. Attempting to enhance figure extraction...")

View File

@ -138,7 +138,6 @@ class RAGFlowHtmlParser:
"metadata": {"table_id": table_id, "index": table_list.index(t)}})
return table_info_list
else:
block_id = None
if str.lower(element.name) in BLOCK_TAGS:
block_id = str(uuid.uuid1())
for child in element.children:
@ -172,7 +171,7 @@ class RAGFlowHtmlParser:
if tag_name == "table":
table_info_list.append(item)
else:
current_content += (" " if current_content else "" + content)
current_content += (" " if current_content else "") + content
if current_content:
block_content.append(current_content)
return block_content, table_info_list
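
The one-character fix above is an operator-precedence issue: in the old expression the ternary owns the concatenation, so `content` is only appended when `current_content` is empty. A tiny repro:

```python
current_content, content = "first", "second"

old = current_content + (" " if current_content else "" + content)   # parsed as (" " if ... else ("" + content))
new = current_content + (" " if current_content else "") + content   # separator first, content always appended

print(repr(old))  # 'first '        -> content silently dropped
print(repr(new))  # 'first second'
```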

View File

@ -63,6 +63,7 @@ class MinerUParser(RAGFlowPdfParser):
self.logger = logging.getLogger(self.__class__.__name__)
def _extract_zip_no_root(self, zip_path, extract_to, root_dir):
self.logger.info(f"[MinerU] Extract zip: zip_path={zip_path}, extract_to={extract_to}, root_hint={root_dir}")
with zipfile.ZipFile(zip_path, "r") as zip_ref:
if not root_dir:
files = zip_ref.namelist()
@ -72,7 +73,7 @@ class MinerUParser(RAGFlowPdfParser):
root_dir = None
if not root_dir or not root_dir.endswith("/"):
self.logger.info(f"[MinerU] No root directory found, extracting all...fff{root_dir}")
self.logger.info(f"[MinerU] No root directory found, extracting all (root_hint={root_dir})")
zip_ref.extractall(extract_to)
return
@ -108,7 +109,7 @@ class MinerUParser(RAGFlowPdfParser):
valid_backends = ["pipeline", "vlm-http-client", "vlm-transformers", "vlm-vllm-engine"]
if backend not in valid_backends:
reason = "[MinerU] Invalid backend '{backend}'. Valid backends are: {valid_backends}"
logging.warning(reason)
self.logger.warning(reason)
return False, reason
subprocess_kwargs = {
@ -128,40 +129,40 @@ class MinerUParser(RAGFlowPdfParser):
if backend == "vlm-http-client" and server_url:
try:
server_accessible = self._is_http_endpoint_valid(server_url + "/openapi.json")
logging.info(f"[MinerU] vlm-http-client server check: {server_accessible}")
self.logger.info(f"[MinerU] vlm-http-client server check: {server_accessible}")
if server_accessible:
self.using_api = False # We are using http client, not API
return True, reason
else:
reason = f"[MinerU] vlm-http-client server not accessible: {server_url}"
logging.warning(f"[MinerU] vlm-http-client server not accessible: {server_url}")
self.logger.warning(f"[MinerU] vlm-http-client server not accessible: {server_url}")
return False, reason
except Exception as e:
logging.warning(f"[MinerU] vlm-http-client server check failed: {e}")
self.logger.warning(f"[MinerU] vlm-http-client server check failed: {e}")
try:
response = requests.get(server_url, timeout=5)
logging.info(f"[MinerU] vlm-http-client server connection check: success with status {response.status_code}")
self.logger.info(f"[MinerU] vlm-http-client server connection check: success with status {response.status_code}")
self.using_api = False
return True, reason
except Exception as e:
reason = f"[MinerU] vlm-http-client server connection check failed: {server_url}: {e}"
logging.warning(f"[MinerU] vlm-http-client server connection check failed: {server_url}: {e}")
self.logger.warning(f"[MinerU] vlm-http-client server connection check failed: {server_url}: {e}")
return False, reason
try:
result = subprocess.run([str(self.mineru_path), "--version"], **subprocess_kwargs)
version_info = result.stdout.strip()
if version_info:
logging.info(f"[MinerU] Detected version: {version_info}")
self.logger.info(f"[MinerU] Detected version: {version_info}")
else:
logging.info("[MinerU] Detected MinerU, but version info is empty.")
self.logger.info("[MinerU] Detected MinerU, but version info is empty.")
return True, reason
except subprocess.CalledProcessError as e:
logging.warning(f"[MinerU] Execution failed (exit code {e.returncode}).")
self.logger.warning(f"[MinerU] Execution failed (exit code {e.returncode}).")
except FileNotFoundError:
logging.warning("[MinerU] MinerU not found. Please install it via: pip install -U 'mineru[core]'")
self.logger.warning("[MinerU] MinerU not found. Please install it via: pip install -U 'mineru[core]'")
except Exception as e:
logging.error(f"[MinerU] Unexpected error during installation check: {e}")
self.logger.error(f"[MinerU] Unexpected error during installation check: {e}")
# If executable check fails, try API check
try:
@ -171,14 +172,14 @@ class MinerUParser(RAGFlowPdfParser):
if not openapi_exists:
reason = "[MinerU] Failed to detect vaild MinerU API server"
return openapi_exists, reason
logging.info(f"[MinerU] Detected {self.mineru_api}/openapi.json: {openapi_exists}")
self.logger.info(f"[MinerU] Detected {self.mineru_api}/openapi.json: {openapi_exists}")
self.using_api = openapi_exists
return openapi_exists, reason
else:
logging.info("[MinerU] api not exists.")
self.logger.info("[MinerU] api not exists.")
except Exception as e:
reason = f"[MinerU] Unexpected error during api check: {e}"
logging.error(f"[MinerU] Unexpected error during api check: {e}")
self.logger.error(f"[MinerU] Unexpected error during api check: {e}")
return False, reason
def _run_mineru(
@ -190,7 +191,7 @@ class MinerUParser(RAGFlowPdfParser):
self._run_mineru_executable(input_path, output_dir, method, backend, lang, server_url, callback)
def _run_mineru_api(self, input_path: Path, output_dir: Path, method: str = "auto", backend: str = "pipeline", lang: Optional[str] = None, callback: Optional[Callable] = None):
OUTPUT_ZIP_PATH = os.path.join(str(output_dir), "output.zip")
output_zip_path = os.path.join(str(output_dir), "output.zip")
pdf_file_path = str(input_path)
@ -230,16 +231,16 @@ class MinerUParser(RAGFlowPdfParser):
response.raise_for_status()
if response.headers.get("Content-Type") == "application/zip":
self.logger.info(f"[MinerU] zip file returned, saving to {OUTPUT_ZIP_PATH}...")
self.logger.info(f"[MinerU] zip file returned, saving to {output_zip_path}...")
if callback:
callback(0.30, f"[MinerU] zip file returned, saving to {OUTPUT_ZIP_PATH}...")
callback(0.30, f"[MinerU] zip file returned, saving to {output_zip_path}...")
with open(OUTPUT_ZIP_PATH, "wb") as f:
with open(output_zip_path, "wb") as f:
f.write(response.content)
self.logger.info(f"[MinerU] Unzip to {output_path}...")
self._extract_zip_no_root(OUTPUT_ZIP_PATH, output_path, pdf_file_name + "/")
self._extract_zip_no_root(output_zip_path, output_path, pdf_file_name + "/")
if callback:
callback(0.40, f"[MinerU] Unzip to {output_path}...")
@ -314,7 +315,7 @@ class MinerUParser(RAGFlowPdfParser):
except Exception as e:
self.page_images = None
self.total_page = 0
logging.exception(e)
self.logger.exception(e)
def _line_tag(self, bx):
pn = [bx["page_idx"] + 1]
@ -459,13 +460,70 @@ class MinerUParser(RAGFlowPdfParser):
return poss
def _read_output(self, output_dir: Path, file_stem: str, method: str = "auto", backend: str = "pipeline") -> list[dict[str, Any]]:
subdir = output_dir / file_stem / method
if backend.startswith("vlm-"):
subdir = output_dir / file_stem / "vlm"
json_file = subdir / f"{file_stem}_content_list.json"
candidates = []
seen = set()
if not json_file.exists():
raise FileNotFoundError(f"[MinerU] Missing output file: {json_file}")
def add_candidate_path(p: Path):
if p not in seen:
seen.add(p)
candidates.append(p)
if backend.startswith("vlm-"):
add_candidate_path(output_dir / file_stem / "vlm")
if method:
add_candidate_path(output_dir / file_stem / method)
add_candidate_path(output_dir / file_stem / "auto")
else:
if method:
add_candidate_path(output_dir / file_stem / method)
add_candidate_path(output_dir / file_stem / "vlm")
add_candidate_path(output_dir / file_stem / "auto")
json_file = None
subdir = None
attempted = []
# mirror MinerU's sanitize_filename to align ZIP naming
def _sanitize_filename(name: str) -> str:
sanitized = re.sub(r"[/\\\.]{2,}|[/\\]", "", name)
sanitized = re.sub(r"[^\w.-]", "_", sanitized, flags=re.UNICODE)
if sanitized.startswith("."):
sanitized = "_" + sanitized[1:]
return sanitized or "unnamed"
safe_stem = _sanitize_filename(file_stem)
allowed_names = {f"{file_stem}_content_list.json", f"{safe_stem}_content_list.json"}
self.logger.info(f"[MinerU] Expected output files: {', '.join(sorted(allowed_names))}")
self.logger.info(f"[MinerU] Searching output candidates: {', '.join(str(c) for c in candidates)}")
for sub in candidates:
jf = sub / f"{file_stem}_content_list.json"
self.logger.info(f"[MinerU] Trying original path: {jf}")
attempted.append(jf)
if jf.exists():
subdir = sub
json_file = jf
break
# MinerU API sanitizes non-ASCII filenames inside the ZIP root and file names.
alt = sub / f"{safe_stem}_content_list.json"
self.logger.info(f"[MinerU] Trying sanitized filename: {alt}")
attempted.append(alt)
if alt.exists():
subdir = sub
json_file = alt
break
nested_alt = sub / safe_stem / f"{safe_stem}_content_list.json"
self.logger.info(f"[MinerU] Trying sanitized nested path: {nested_alt}")
attempted.append(nested_alt)
if nested_alt.exists():
subdir = nested_alt.parent
json_file = nested_alt
break
if not json_file:
raise FileNotFoundError(f"[MinerU] Missing output file, tried: {', '.join(str(p) for p in attempted)}")
with open(json_file, "r", encoding="utf-8") as f:
data = json.load(f)
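
Why the sanitized fallback matters: the `_sanitize_filename` mirror maps a non-ASCII stem onto a different spelling than the original, so both file names have to be probed. A standalone run of the same regexes (the stem below is just an example):

```python
import re

def sanitize(name: str) -> str:
    s = re.sub(r"[/\\\.]{2,}|[/\\]", "", name)
    s = re.sub(r"[^\w.-]", "_", s, flags=re.UNICODE)
    if s.startswith("."):
        s = "_" + s[1:]
    return s or "unnamed"

stem = "年度 报告(2024)"                      # illustrative original file stem
safe = sanitize(stem)
print(f"{stem}_content_list.json")            # original spelling (first lookup attempt)
print(f"{safe}_content_list.json")            # 年度_报告_2024__content_list.json (fallback attempt)
```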
@ -520,7 +578,7 @@ class MinerUParser(RAGFlowPdfParser):
method: str = "auto",
server_url: Optional[str] = None,
delete_output: bool = True,
parse_method: str = "raw"
parse_method: str = "raw",
) -> tuple:
import shutil
@ -570,7 +628,7 @@ class MinerUParser(RAGFlowPdfParser):
self.logger.info(f"[MinerU] Parsed {len(outputs)} blocks from PDF.")
if callback:
callback(0.75, f"[MinerU] Parsed {len(outputs)} blocks from PDF.")
return self._transfer_to_sections(outputs, parse_method), self._transfer_to_tables(outputs)
finally:
if temp_pdf and temp_pdf.exists():

View File

@ -402,7 +402,6 @@ class RAGFlowPdfParser:
continue
else:
score = 0
print(f"{k=},{score=}",flush=True)
if score > best_score:
best_score = score
best_k = k

View File

@ -17,7 +17,7 @@
import logging
import math
import os
import re
# import re
from collections import Counter
from copy import deepcopy
@ -62,8 +62,9 @@ class LayoutRecognizer(Recognizer):
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16, drop=True):
def __is_garbage(b):
patt = [r"^•+$", "^[0-9]{1,2} / ?[0-9]{1,2}$", r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}", "\\(cid *: *[0-9]+ *\\)"]
return any([re.search(p, b["text"]) for p in patt])
return False
# patt = [r"^•+$", "^[0-9]{1,2} / ?[0-9]{1,2}$", r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}", "\\(cid *: *[0-9]+ *\\)"]
# return any([re.search(p, b["text"]) for p in patt])
if self.client:
layouts = self.client.predict(image_list)

View File

@ -170,7 +170,7 @@ TZ=Asia/Shanghai
# Uncomment the following line if your operating system is MacOS:
# MACOS=1
# The maximum file size limit (in bytes) for each upload to your knowledge base or File Management.
# The maximum file size limit (in bytes) for each upload to your dataset or RAGFlow's File system.
# To change the 1GB file size limit, uncomment the line below and update as needed.
# MAX_CONTENT_LENGTH=1073741824
# After updating, ensure `client_max_body_size` in nginx/nginx.conf is updated accordingly.

View File

@ -23,7 +23,7 @@ services:
env_file: .env
networks:
- ragflow
restart: on-failure
restart: unless-stopped
# https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
# If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
extra_hosts:
@ -48,7 +48,7 @@ services:
env_file: .env
networks:
- ragflow
restart: on-failure
restart: unless-stopped
# https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
# If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
extra_hosts:

View File

@ -31,7 +31,7 @@ services:
retries: 120
networks:
- ragflow
restart: on-failure
restart: unless-stopped
opensearch01:
profiles:
@ -67,12 +67,12 @@ services:
retries: 120
networks:
- ragflow
restart: on-failure
restart: unless-stopped
infinity:
profiles:
- infinity
image: infiniflow/infinity:v0.6.7
image: infiniflow/infinity:v0.6.10
volumes:
- infinity_data:/var/infinity
- ./infinity_conf.toml:/infinity_conf.toml
@ -94,7 +94,7 @@ services:
interval: 10s
timeout: 10s
retries: 120
restart: on-failure
restart: unless-stopped
oceanbase:
profiles:
@ -119,7 +119,7 @@ services:
timeout: 10s
networks:
- ragflow
restart: on-failure
restart: unless-stopped
sandbox-executor-manager:
profiles:
@ -147,7 +147,7 @@ services:
interval: 10s
timeout: 10s
retries: 120
restart: on-failure
restart: unless-stopped
mysql:
# mysql:5.7 linux/arm64 image is unavailable.
@ -175,7 +175,7 @@ services:
interval: 10s
timeout: 10s
retries: 120
restart: on-failure
restart: unless-stopped
minio:
image: quay.io/minio/minio:RELEASE.2025-06-13T11-33-47Z
@ -191,7 +191,7 @@ services:
- minio_data:/data
networks:
- ragflow
restart: on-failure
restart: unless-stopped
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
interval: 10s
@ -209,7 +209,7 @@ services:
- redis_data:/data
networks:
- ragflow
restart: on-failure
restart: unless-stopped
healthcheck:
test: ["CMD", "redis-cli", "-a", "${REDIS_PASSWORD}", "ping"]
interval: 10s
@ -228,7 +228,7 @@ services:
networks:
- ragflow
command: ["--model-id", "/data/${TEI_MODEL}", "--auto-truncate"]
restart: on-failure
restart: unless-stopped
tei-gpu:
@ -249,7 +249,7 @@ services:
- driver: nvidia
count: all
capabilities: [gpu]
restart: on-failure
restart: unless-stopped
kibana:
@ -271,7 +271,7 @@ services:
retries: 120
networks:
- ragflow
restart: on-failure
restart: unless-stopped
volumes:

View File

@ -22,7 +22,7 @@ services:
env_file: .env
networks:
- ragflow
restart: on-failure
restart: unless-stopped
# https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
# If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
extra_hosts:
@ -39,7 +39,7 @@ services:
# entrypoint: "/ragflow/entrypoint_task_executor.sh 1 3"
# networks:
# - ragflow
# restart: on-failure
# restart: unless-stopped
# # https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
# # If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
# extra_hosts:

View File

@ -45,7 +45,7 @@ services:
env_file: .env
networks:
- ragflow
restart: on-failure
restart: unless-stopped
# https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
# If you use Docker Desktop, the --add-host flag is optional. This flag ensures that the host's internal IP is exposed to the Prometheus container.
extra_hosts:
@ -94,7 +94,7 @@ services:
env_file: .env
networks:
- ragflow
restart: on-failure
restart: unless-stopped
# https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
# If you use Docker Desktop, the --add-host flag is optional. This flag ensures that the host's internal IP is exposed to the Prometheus container.
extra_hosts:
@ -120,7 +120,7 @@ services:
# entrypoint: "/ragflow/entrypoint_task_executor.sh 1 3"
# networks:
# - ragflow
# restart: on-failure
# restart: unless-stopped
# # https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
# # If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
# extra_hosts:

View File

@ -1,5 +1,5 @@
[general]
version = "0.6.7"
version = "0.6.10"
time_zone = "utc-8"
[network]

View File

@ -38,6 +38,7 @@ oceanbase:
port: ${OCEANBASE_PORT:-2881}
redis:
db: 1
username: '${REDIS_USERNAME:-}'
password: '${REDIS_PASSWORD:-infini_rag_flow}'
host: '${REDIS_HOST:-redis}:6379'
user_default_llm:

View File

@ -89,6 +89,8 @@ RAGFlow utilizes MinIO as its object storage solution, leveraging its scalabilit
- `REDIS_PORT`
The port used to expose the Redis service to the host machine, allowing **external** access to the Redis service running inside the Docker container. Defaults to `6379`.
- `REDIS_USERNAME`
Optional Redis ACL username when using Redis 6+ authentication.
- `REDIS_PASSWORD`
The password for Redis.
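A minimal sketch of how these variables might look in `docker/.env` (the username is a placeholder; the password shown is the template default from `service_conf.yaml.template`):

```bash
# docker/.env -- illustrative values
REDIS_PORT=6379
# Optional ACL username (Redis 6+); leave unset if you do not use Redis ACLs.
REDIS_USERNAME=ragflow_acl_user
REDIS_PASSWORD=infini_rag_flow
```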
@ -160,6 +162,13 @@ If you cannot download the RAGFlow Docker image, try the following mirrors.
- `password`: The password for MinIO.
- `host`: The MinIO serving IP *and* port inside the Docker container. Defaults to `minio:9000`.
### `redis`
- `host`: The Redis serving IP *and* port inside the Docker container. Defaults to `redis:6379`.
- `db`: The Redis database index to use. Defaults to `1`.
- `username`: Optional Redis ACL username (Redis 6+).
- `password`: The password for the specified Redis user.
### `oauth`
The OAuth configuration for signing up or signing in to RAGFlow using a third-party account.

View File

@ -323,9 +323,9 @@ The status of a Docker container status does not necessarily reflect the status
2. Follow [this document](./guides/run_health_check.md) to check the health status of the Elasticsearch service.
:::danger IMPORTANT
The status of a Docker container does not necessarily reflect the status of the service. You may find that your services are unhealthy even when the corresponding Docker containers are up and running. Possible reasons for this include network failures, incorrect port numbers, or DNS issues.
:::
:::danger IMPORTANT
The status of a Docker container does not necessarily reflect the status of the service. You may find that your services are unhealthy even when the corresponding Docker containers are up and running. Possible reasons for this include network failures, incorrect port numbers, or DNS issues.
:::
3. If your container keeps restarting, ensure `vm.max_map_count` >= 262144 as per [this README](https://github.com/infiniflow/ragflow?tab=readme-ov-file#-start-up-the-server). To make the change permanent, update the `vm.max_map_count` value in **/etc/sysctl.conf**, as shown in the sketch below. Note that this configuration works only for Linux.
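A minimal sketch of the standard Linux `sysctl` workflow for checking and persisting this value (generic Linux administration, not RAGFlow-specific):

```bash
# Check the current value
sysctl vm.max_map_count

# Raise it for the running system
sudo sysctl -w vm.max_map_count=262144

# Persist the change across reboots
echo 'vm.max_map_count=262144' | sudo tee -a /etc/sysctl.conf
```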
@ -456,9 +456,9 @@ To switch your document engine from Elasticsearch to [Infinity](https://github.c
```bash
$ docker compose -f docker/docker-compose.yml down -v
```
:::caution WARNING
`-v` will delete all Docker container volumes, and the existing data will be cleared.
:::
:::caution WARNING
`-v` will delete all Docker container volumes, and the existing data will be cleared.
:::
2. In **docker/.env**, set `DOC_ENGINE=${DOC_ENGINE:-infinity}`
3. Restart your Docker image:
@ -497,20 +497,6 @@ MinerU PDF document parsing is available starting from v0.22.0. RAGFlow supports
1. Prepare MinerU
- **If you deploy RAGFlow from source**, install MinerU into an isolated virtual environment (recommended path: `$HOME/uv_tools`):
```bash
mkdir -p "$HOME/uv_tools"
cd "$HOME/uv_tools"
uv venv .venv
source .venv/bin/activate
uv pip install -U "mineru[core]" -i https://mirrors.aliyun.com/pypi/simple
# or
# uv pip install -U "mineru[all]" -i https://mirrors.aliyun.com/pypi/simple
```
- **If you deploy RAGFlow with Docker**, you usually only need to turn on MinerU support in `docker/.env`:
```bash
# docker/.env
...
@ -518,18 +504,15 @@ MinerU PDF document parsing is available starting from v0.22.0. RAGFlow supports
...
```
Enabling `USE_MINERU=true` will internally perform the same setup as the manual configuration (including setting the MinerU executable path and related environment variables). You only need the manual installation above if you are running from source or want full control over the MinerU installation.
Enabling `USE_MINERU=true` will internally perform the same setup as the manual configuration (including setting the MinerU executable path and related environment variables).
2. Start RAGFlow with MinerU enabled:
- **Source deployment** in the RAGFlow repo, export the key MinerU-related variables and start the backend service:
- **Source deployment** in the RAGFlow repo, continue to start the backend service:
```bash
# in RAGFlow repo
export MINERU_EXECUTABLE="$HOME/uv_tools/.venv/bin/mineru"
export MINERU_DELETE_OUTPUT=0 # keep output directory
export MINERU_BACKEND=pipeline # or another backend you prefer
...
source .venv/bin/activate
export PYTHONPATH=$(pwd)
bash docker/launch_backend_service.sh

View File

@ -22,7 +22,7 @@ An **Agent** component is essential when you need the LLM to assist with summari
1. Ensure you have a chat model properly configured:
![Set default models](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/set_default_models.jpg)
![Set default models](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/set_default_models.jpg)
2. If your Agent involves dataset retrieval, ensure you [have properly configured your target dataset(s)](../../dataset/configure_knowledge_base.md).
@ -91,7 +91,7 @@ Update your MCP server's name, URL (including the API key), server type, and oth
*The target MCP server appears below your Agent component, and your Agent will autonomously decide when to invoke the available tools it offers.*
![](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/choose_tavily_mcp_server.jpg)
![](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/choose_tavily_mcp_server.jpg)
### 5. Update system prompt to specify trigger conditions (Optional)

View File

@ -76,5 +76,5 @@ No. Files uploaded to an agent as input are not stored in a dataset and hence wi
There is no _specific_ file size limit for a file uploaded to an agent. However, note that model providers typically have a default or explicit maximum token setting, which can range from 8192 to 128k: the plain-text part of the uploaded file is passed in as the key's value, but if the file's token count exceeds this limit, the string will be truncated and incomplete.
:::tip NOTE
The variables `MAX_CONTENT_LENGTH` in `/docker/.env` and `client_max_body_size` in `/docker/nginx/nginx.conf` set the file size limit for each upload to a dataset or **File Management**. These settings DO NOT apply in this scenario.
The variables `MAX_CONTENT_LENGTH` in `/docker/.env` and `client_max_body_size` in `/docker/nginx/nginx.conf` set the file size limit for each upload to a dataset or RAGFlow's File system. These settings DO NOT apply in this scenario.
:::

View File

@ -62,9 +62,9 @@ docker build -t sandbox-executor-manager:latest ./executor_manager
3. Add the following entry to your /etc/hosts file to resolve the executor manager service:
```bash
127.0.0.1 es01 infinity mysql minio redis sandbox-executor-manager
```
```bash
127.0.0.1 es01 infinity mysql minio redis sandbox-executor-manager
```
4. Start the RAGFlow service as usual.
@ -74,24 +74,24 @@ docker build -t sandbox-executor-manager:latest ./executor_manager
1. Initialize the environment variables:
```bash
cp .env.example .env
```
```bash
cp .env.example .env
```
2. Launch the sandbox services with Docker Compose:
```bash
docker compose -f docker-compose.yml up
```
```bash
docker compose -f docker-compose.yml up
```
3. Test the sandbox setup:
```bash
source .venv/bin/activate
export PYTHONPATH=$(pwd)
uv pip install -r executor_manager/requirements.txt
uv run tests/sandbox_security_tests_full.py
```
```bash
source .venv/bin/activate
export PYTHONPATH=$(pwd)
uv pip install -r executor_manager/requirements.txt
uv run tests/sandbox_security_tests_full.py
```
### Using Makefile

View File

@ -9,7 +9,7 @@ Initiate an AI-powered chat with a configured chat assistant.
---
Knowledge base, hallucination-free chat, and file management are the three pillars of RAGFlow. Chats in RAGFlow are based on a particular dataset or multiple datasets. Once you have created your dataset, finished file parsing, and [run a retrieval test](../dataset/run_retrieval_test.md), you can go ahead and start an AI conversation.
Chats in RAGFlow are based on a particular dataset or multiple datasets. Once you have created your dataset, finished file parsing, and [run a retrieval test](../dataset/run_retrieval_test.md), you can go ahead and start an AI conversation.
## Start an AI chat
@ -83,13 +83,13 @@ You start an AI conversation by creating an assistant.
1. Click the light bulb icon above the answer to view the expanded system prompt:
![prompt_display](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/prompt_display.jpg)
![prompt_display](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/prompt_display.jpg)
*The light bulb icon is available only for the current dialogue.*
2. Scroll down the expanded prompt to view the time consumed for each task:
![time_elapsed](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/time_elapsed.jpg)
![time_elapsed](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/time_elapsed.jpg)
:::
## Update settings of an existing chat assistant

View File

@ -5,7 +5,7 @@ slug: /configure_knowledge_base
# Configure dataset
Most of RAGFlow's chat assistants and Agents are based on datasets. Each of RAGFlow's datasets serves as a knowledge source, *parsing* files uploaded from your local machine and file references generated in **File Management** into the real 'knowledge' for future AI chats. This guide demonstrates some basic usages of the dataset feature, covering the following topics:
Most of RAGFlow's chat assistants and Agents are based on datasets. Each of RAGFlow's datasets serves as a knowledge source, *parsing* files uploaded from your local machine and file references generated in RAGFlow's File system into the real 'knowledge' for future AI chats. This guide demonstrates some basic usages of the dataset feature, covering the following topics:
- Create a dataset
- Configure a dataset
@ -82,10 +82,10 @@ Some embedding models are optimized for specific languages, so performance may b
### Upload file
- RAGFlow's **File Management** allows you to link a file to multiple datasets, in which case each target dataset holds a reference to the file.
- RAGFlow's File system allows you to link a file to multiple datasets, in which case each target dataset holds a reference to the file.
- In **Knowledge Base**, you are also given the option of uploading a single file or a folder of files (bulk upload) from your local machine to a dataset, in which case the dataset holds file copies.
While uploading files directly to a dataset seems more convenient, we *highly* recommend uploading files to **File Management** and then linking them to the target datasets. This way, you can avoid permanently deleting files uploaded to the dataset.
While uploading files directly to a dataset seems more convenient, we *highly* recommend uploading files to RAGFlow's File system and then linking them to the target datasets. This way, you can avoid permanently deleting files uploaded to the dataset.
### Parse file
@ -142,6 +142,6 @@ As of RAGFlow v0.22.1, the search feature is still in a rudimentary form, suppor
You are allowed to delete a dataset. Hover your mouse over the three-dot icon of the intended dataset card and the **Delete** option appears. Once you delete a dataset, the associated folder under the **root/.knowledge** directory is AUTOMATICALLY REMOVED. The consequences are:
- The files uploaded directly to the dataset are gone;
- The file references, which you created from within **File Management**, are gone, but the associated files still exist in **File Management**.
- The file references, which you created from within RAGFlow's File system, are gone, but the associated files still exist.
![delete dataset](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/delete_datasets.jpg)

View File

@ -56,9 +56,9 @@ Once a tag set is created, you can apply it to your dataset:
1. Navigate to the **Configuration** page of your dataset.
2. Select the tag set from the **Tag sets** dropdown and click **Save** to confirm.
:::tip NOTE
If the tag set is missing from the dropdown, check that it has been created or configured correctly.
:::
:::tip NOTE
If the tag set is missing from the dropdown, check that it has been created or configured correctly.
:::
3. Re-parse your documents to start the auto-tagging process.
_In an AI chat scenario using auto-tagged datasets, each query will be tagged using the corresponding tag set(s) and chunks with these tags will have a higher chance to be retrieved._

View File

@ -314,35 +314,3 @@ To enable IPEX-LLM accelerated Ollama in RAGFlow, you must also complete the con
3. [Update System Model Settings](#6-update-system-model-settings)
4. [Update Chat Configuration](#7-update-chat-configuration)
## Deploy a local model using jina
To deploy a local model, e.g., **gpt2**, using jina:
### 1. Check firewall settings
Ensure that your host machine's firewall allows inbound connections on port 12345.
```bash
sudo ufw allow 12345/tcp
```
### 2. Install jina package
```bash
pip install jina
```
### 3. Deploy a local model
Step 1: Navigate to the **rag/svr** directory.
```bash
cd rag/svr
```
Step 2: Run **jina_server.py**, specifying either the model's name or its local directory:
```bash
python jina_server.py --model_name gpt2
```
> The script only supports models downloaded from Hugging Face.

View File

@ -19,48 +19,60 @@ Upgrading RAGFlow in itself will *not* remove your uploaded/historical data. How
To upgrade RAGFlow, you must upgrade **both** your code **and** your Docker image:
1. Clone the repo
1. Stop the server
```bash
git clone https://github.com/infiniflow/ragflow.git
docker compose -f docker/docker-compose.yml down
```
2. Update **ragflow/docker/.env**:
2. Update the local code
```bash
git pull
```
3. Update **ragflow/docker/.env**:
```bash
RAGFLOW_IMAGE=infiniflow/ragflow:nightly
```
3. Update RAGFlow image and restart RAGFlow:
4. Update RAGFlow image and restart RAGFlow:
```bash
docker compose -f docker/docker-compose.yml pull
docker compose -f docker/docker-compose.yml up -d
```
## Upgrade RAGFlow to the most recent, officially published release
## Upgrade RAGFlow to a given release
To upgrade RAGFlow, you must upgrade **both** your code **and** your Docker image:
1. Clone the repo
1. Stop the server
```bash
git clone https://github.com/infiniflow/ragflow.git
docker compose -f docker/docker-compose.yml down
```
2. Switch to the latest, officially published release, e.g., `v0.22.1`:
2. Update the local code
```bash
git pull
```
3. Switch to the latest, officially published release, e.g., `v0.22.1`:
```bash
git checkout -f v0.22.1
```
3. Update **ragflow/docker/.env**:
4. Update **ragflow/docker/.env**:
```bash
RAGFLOW_IMAGE=infiniflow/ragflow:v0.22.1
```
4. Update the RAGFlow image and restart RAGFlow:
5. Update the RAGFlow image and restart RAGFlow:
```bash
docker compose -f docker/docker-compose.yml pull

View File

@ -39,8 +39,10 @@ If you have not installed Docker on your local machine (Windows, Mac, or Linux),
This section provides instructions on setting up the RAGFlow server on Linux. If you are on a different operating system, no worries. Most steps are alike.
1. Ensure `vm.max_map_count` &ge; 262144.
<details>
<summary>1. Ensure <code>vm.max_map_count</code> &ge; 262144:</summary>
<summary>Expand to show details:</summary>
`vm.max_map_count` sets the maximum number of memory map areas a process may have. Its default value is 65530. While most applications require fewer than a thousand maps, reducing this value can result in abnormal behaviors, and the system will throw out-of-memory errors when a process reaches the limit.
@ -194,22 +196,22 @@ This section provides instructions on setting up the RAGFlow server on Linux. If
$ docker compose -f docker-compose.yml up -d
```
```mdx-code-block
<APITable>
```
```mdx-code-block
<APITable>
```
| RAGFlow image tag | Image size (GB) | Stable? |
| ------------------- | --------------- | ------------------------ |
| v0.22.1 | &approx;2 | Stable release |
| nightly | &approx;2 | _Unstable_ nightly build |
| RAGFlow image tag | Image size (GB) | Stable? |
| ------------------- | --------------- | ------------------------ |
| v0.22.1 | &approx;2 | Stable release |
| nightly | &approx;2 | _Unstable_ nightly build |
```mdx-code-block
</APITable>
```
```mdx-code-block
</APITable>
```
:::tip NOTE
The image size shown refers to the size of the *downloaded* Docker image, which is compressed. When Docker runs the image, it unpacks it, resulting in significantly greater disk usage. A Docker image will expand to around 7 GB once unpacked.
:::
:::tip NOTE
The image size shown refers to the size of the *downloaded* Docker image, which is compressed. When Docker runs the image, it unpacks it, resulting in significantly greater disk usage. A Docker image will expand to around 7 GB once unpacked.
:::
4. Check the server status after having the server up and running:
@ -229,15 +231,15 @@ The image size shown refers to the size of the *downloaded* Docker image, which
* Running on all addresses (0.0.0.0)
```
:::danger IMPORTANT
If you skip this confirmation step and directly log in to RAGFlow, your browser may report a `network anomaly` error because, at that moment, your RAGFlow may not be fully initialized.
:::
:::danger IMPORTANT
If you skip this confirmation step and directly log in to RAGFlow, your browser may report a `network anomaly` error because, at that moment, your RAGFlow may not be fully initialized.
:::
5. In your web browser, enter the IP address of your server and log in to RAGFlow.
:::caution WARNING
With the default settings, you only need to enter `http://IP_OF_YOUR_MACHINE` (**sans** port number), as the default HTTP serving port `80` can be omitted.
:::
:::caution WARNING
With the default settings, you only need to enter `http://IP_OF_YOUR_MACHINE` (**sans** port number), as the default HTTP serving port `80` can be omitted.
:::
## Configure LLMs
@ -278,9 +280,9 @@ To create your first dataset:
3. RAGFlow offers multiple chunk templates that cater to different document layouts and file formats. Select the embedding model and chunking method (template) for your dataset.
:::danger IMPORTANT
Once you have selected an embedding model and used it to parse a file, you are no longer allowed to change it. The obvious reason is that we must ensure that all files in a specific dataset are parsed using the *same* embedding model (ensure that they are being compared in the same embedding space).
:::
:::danger IMPORTANT
Once you have selected an embedding model and used it to parse a file, you are no longer allowed to change it. The obvious reason is that we must ensure that all files in a specific dataset are parsed using the *same* embedding model (ensure that they are being compared in the same embedding space).
:::
_You are taken to the **Dataset** page of your dataset._
@ -290,10 +292,10 @@ Once you have selected an embedding model and used it to parse a file, you are n
![parse file](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/parse_file.jpg)
:::caution NOTE
- If your file parsing gets stuck at below 1%, see [this FAQ](./faq.mdx#why-does-my-document-parsing-stall-at-under-one-percent).
- If your file parsing gets stuck at near completion, see [this FAQ](./faq.mdx#why-does-my-pdf-parsing-stall-near-completion-while-the-log-does-not-show-any-error)
:::
:::caution NOTE
- If your file parsing gets stuck at below 1%, see [this FAQ](./faq.mdx#why-does-my-document-parsing-stall-at-under-one-percent).
- If your file parsing gets stuck at near completion, see [this FAQ](./faq.mdx#why-does-my-pdf-parsing-stall-near-completion-while-the-log-does-not-show-any-error)
:::
## Intervene with file parsing
@ -311,9 +313,9 @@ RAGFlow features visibility and explainability, allowing you to view the chunkin
![update chunk](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/add_keyword_question.jpg)
:::caution NOTE
You can add keywords or questions to a file chunk to improve its ranking for queries containing those keywords. This action increases its keyword weight and can improve its position in the search list.
:::
:::caution NOTE
You can add keywords or questions to a file chunk to improve its ranking for queries containing those keywords. This action increases its keyword weight and can improve its position in the search list.
:::
4. In Retrieval testing, ask a quick question in **Test text** to double check if your configurations work:

View File

@ -420,8 +420,10 @@ Creates a dataset.
- `"permission"`: `string`
- `"chunk_method"`: `string`
- `"parser_config"`: `object`
- `"parse_type"`: `int`
- `"pipeline_id"`: `string`
##### Request example
##### A basic request example
```bash
curl --request POST \
@ -433,6 +435,24 @@ curl --request POST \
}'
```
##### A request example specifying ingestion pipeline
:::caution WARNING
You must *not* include `"chunk_method"` or `"parser_config"` when specifying an ingestion pipeline.
:::
```bash
curl --request POST \
--url http://{address}/api/v1/datasets \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer <YOUR_API_KEY>' \
--data '{
"name": "test-sdk",
"parse_type": <NUMBER_OF_PARSERS_IN_YOUR_PARSER_COMPONENT>,
"pipeline_id": "<PIPELINE_ID_32_HEX>"
}'
```
##### Request parameters
- `"name"`: (*Body parameter*), `string`, *Required*
@ -460,7 +480,8 @@ curl --request POST \
- `"team"`: All team members can manage the dataset.
- `"chunk_method"`: (*Body parameter*), `enum<string>`
The chunking method of the dataset to create. Available options:
The default chunk method of the dataset to create. Mutually exclusive with `"parse_type"` and `"pipeline_id"`. If you set `"chunk_method"`, do not include `"parse_type"` or `"pipeline_id"`.
Available options:
- `"naive"`: General (default)
- `"book"`: Book
- `"email"`: Email
@ -491,13 +512,16 @@ curl --request POST \
- Maximum: `2048`
- `"delimiter"`: `string`
- Defaults to `"\n"`.
- `"html4excel"`: `bool` Indicates whether to convert Excel documents into HTML format.
- `"html4excel"`: `bool`
- Whether to convert Excel documents into HTML format.
- Defaults to `false`
- `"layout_recognize"`: `string`
- Defaults to `DeepDOC`
- `"tag_kb_ids"`: `array<string>` refer to [Use tag set](https://ragflow.io/docs/dev/use_tag_sets)
- Must include a list of dataset IDs, where each dataset is parsed using the Tag Chunking Method
- `"task_page_size"`: `int` For PDF only.
- `"tag_kb_ids"`: `array<string>`
- IDs of datasets to be parsed using the Tag chunk method.
- Before setting this, ensure a tag set is created and properly configured. For details, see [Use tag set](https://ragflow.io/docs/dev/use_tag_sets).
- `"task_page_size"`: `int`
- For PDFs only.
- Defaults to `12`
- Minimum: `1`
- `"raptor"`: `object` RAPTOR-specific settings.
@ -509,6 +533,26 @@ curl --request POST \
- Defaults to: `{"use_raptor": false}`.
- If `"chunk_method"` is `"table"`, `"picture"`, `"one"`, or `"email"`, `"parser_config"` is an empty JSON object.
- `"parse_type"`: (*Body parameter*), `int`
The ingestion pipeline parse type identifier, i.e., the number of parsers in your **Parser** component.
- Required (along with `"pipeline_id"`) if specifying an ingestion pipeline.
- Must not be included when `"chunk_method"` is specified.
- `"pipeline_id"`: (*Body parameter*), `string`
The ingestion pipeline ID. Can be found in the corresponding URL in the RAGFlow UI.
- Required (along with `"parse_type"`) if specifying an ingestion pipeline.
- Must be a 32-character lowercase hexadecimal string, e.g., `"d0bebe30ae2211f0970942010a8e0005"`.
- Must not be included when `"chunk_method"` is specified.
:::caution WARNING
You can choose either of the following ingestion options when creating a dataset, but *not* both:
- Use a built-in chunk method -- specify `"chunk_method"` (optionally with `"parser_config"`).
- Use an ingestion pipeline -- specify both `"parse_type"` and `"pipeline_id"`.
If none of `"chunk_method"`, `"parse_type"`, or `"pipeline_id"` is provided, the system defaults to `chunk_method = "naive"`.
:::
#### Response
Success:

View File

@ -43,7 +43,6 @@ def get_urls(use_china_mirrors=False) -> list[Union[str, list[str]]]:
repos = [
"InfiniFlow/text_concat_xgb_v1.0",
"InfiniFlow/deepdoc",
"InfiniFlow/huqie",
]

View File

@ -57,7 +57,7 @@ async def run_graphrag(
start = trio.current_time()
tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
chunks = []
for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True):
for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], max_count=10000, fields=["content_with_weight", "doc_id"], sort_by_position=True):
chunks.append(d["content_with_weight"])
with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
@ -174,13 +174,19 @@ async def run_graphrag_for_kb(
chunks = []
current_chunk = ""
for d in settings.retriever.chunk_list(
# DEBUG: Fetch all chunks first
raw_chunks = list(settings.retriever.chunk_list(
doc_id,
tenant_id,
[kb_id],
max_count=10000, # FIX: Raise the limit so all chunks are processed
fields=fields_for_chunks,
sort_by_position=True,
):
))
callback(msg=f"[DEBUG] chunk_list() returned {len(raw_chunks)} raw chunks for doc {doc_id}")
for d in raw_chunks:
content = d["content_with_weight"]
if num_tokens_from_string(current_chunk + content) < 1024:
current_chunk += content

View File

@ -96,7 +96,7 @@ ragflow:
infinity:
image:
repository: infiniflow/infinity
tag: v0.6.7
tag: v0.6.10
pullPolicy: IfNotPresent
pullSecrets: []
storage:

View File

@ -57,7 +57,6 @@ JSON_RESPONSE = True
class RAGFlowConnector:
_MAX_DATASET_CACHE = 32
_MAX_DOCUMENT_CACHE = 128
_CACHE_TTL = 300
_dataset_metadata_cache: OrderedDict[str, tuple[dict, float | int]] = OrderedDict() # "dataset_id" -> (metadata, expiry_ts)
@ -116,8 +115,6 @@ class RAGFlowConnector:
def _set_cached_document_metadata_by_dataset(self, dataset_id, doc_id_meta_list):
self._document_metadata_cache[dataset_id] = (doc_id_meta_list, self._get_expiry_timestamp())
self._document_metadata_cache.move_to_end(dataset_id)
if len(self._document_metadata_cache) > self._MAX_DOCUMENT_CACHE:
self._document_metadata_cache.popitem(last=False)
def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
res = self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
@ -240,46 +237,46 @@ class RAGFlowConnector:
docs = None if force_refresh else self._get_cached_document_metadata_by_dataset(dataset_id)
if docs is None:
docs_res = self._get(f"/datasets/{dataset_id}/documents")
docs_data = docs_res.json()
if docs_data.get("code") == 0 and docs_data.get("data", {}).get("docs"):
doc_id_meta_list = []
docs = {}
for doc in docs_data["data"]["docs"]:
doc_id = doc.get("id")
if not doc_id:
continue
doc_meta = {
"document_id": doc_id,
"name": doc.get("name", ""),
"location": doc.get("location", ""),
"type": doc.get("type", ""),
"size": doc.get("size"),
"chunk_count": doc.get("chunk_count"),
# "chunk_method": doc.get("chunk_method", ""),
"create_date": doc.get("create_date", ""),
"update_date": doc.get("update_date", ""),
# "process_begin_at": doc.get("process_begin_at", ""),
# "process_duration": doc.get("process_duration"),
# "progress": doc.get("progress"),
# "progress_msg": doc.get("progress_msg", ""),
# "status": doc.get("status", ""),
# "run": doc.get("run", ""),
"token_count": doc.get("token_count"),
# "source_type": doc.get("source_type", ""),
"thumbnail": doc.get("thumbnail", ""),
"dataset_id": doc.get("dataset_id", dataset_id),
"meta_fields": doc.get("meta_fields", {}),
# "parser_config": doc.get("parser_config", {})
}
doc_id_meta_list.append((doc_id, doc_meta))
docs[doc_id] = doc_meta
page = 1
page_size = 30
doc_id_meta_list = []
docs = {}
while page:
docs_res = self._get(f"/datasets/{dataset_id}/documents?page={page}")
docs_data = docs_res.json()
if docs_data.get("code") == 0 and docs_data.get("data", {}).get("docs"):
for doc in docs_data["data"]["docs"]:
doc_id = doc.get("id")
if not doc_id:
continue
doc_meta = {
"document_id": doc_id,
"name": doc.get("name", ""),
"location": doc.get("location", ""),
"type": doc.get("type", ""),
"size": doc.get("size"),
"chunk_count": doc.get("chunk_count"),
"create_date": doc.get("create_date", ""),
"update_date": doc.get("update_date", ""),
"token_count": doc.get("token_count"),
"thumbnail": doc.get("thumbnail", ""),
"dataset_id": doc.get("dataset_id", dataset_id),
"meta_fields": doc.get("meta_fields", {}),
}
doc_id_meta_list.append((doc_id, doc_meta))
docs[doc_id] = doc_meta
page += 1
if docs_data.get("data", {}).get("total", 0) - page * page_size <= 0:
page = None
self._set_cached_document_metadata_by_dataset(dataset_id, doc_id_meta_list)
if docs:
document_cache.update(docs)
except Exception:
except Exception as e:
# Gracefully handle metadata cache failures
logging.error(f"Problem building the document metadata cache: {str(e)}")
pass
return document_cache, dataset_cache

View File

@ -49,7 +49,7 @@ dependencies = [
"html-text==0.6.2",
"httpx[socks]>=0.28.1,<0.29.0",
"huggingface-hub>=0.25.0,<0.26.0",
"infinity-sdk==0.6.7",
"infinity-sdk==0.6.10",
"infinity-emb>=0.0.66,<0.0.67",
"itsdangerous==2.1.2",
"json-repair==0.35.0",
@ -131,7 +131,6 @@ dependencies = [
"graspologic @ git+https://github.com/yuzhichang/graspologic.git@38e680cab72bc9fb68a7992c3bcc2d53b24e42fd",
"mini-racer>=0.12.4,<0.13.0",
"pyodbc>=5.2.0,<6.0.0",
"pyicu>=2.15.3,<3.0.0",
"flasgger>=0.9.7.1,<0.10.0",
"xxhash>=3.5.0,<4.0.0",
"trio>=0.17.0,<0.29.0",
@ -152,7 +151,9 @@ dependencies = [
"moodlepy>=0.23.0",
"pypandoc>=1.16",
"pyobvector==0.2.18",
"exceptiongroup>=1.3.0,<2.0.0"
"exceptiongroup>=1.3.0,<2.0.0",
"ffmpeg-python>=0.2.0",
"imageio-ffmpeg>=0.6.0",
]
[dependency-groups]
@ -161,6 +162,9 @@ test = [
"openpyxl>=3.1.5",
"pillow>=10.4.0",
"pytest>=8.3.5",
"pytest-asyncio>=1.3.0",
"pytest-xdist>=3.8.0",
"pytest-cov>=7.0.0",
"python-docx>=1.1.2",
"python-pptx>=1.0.2",
"reportlab>=4.4.1",
@ -193,8 +197,83 @@ extend-select = ["ASYNC", "ASYNC1"]
ignore = ["E402"]
[tool.pytest.ini_options]
pythonpath = [
"."
]
testpaths = ["test"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
markers = [
"p1: high priority test cases",
"p2: medium priority test cases",
"p3: low priority test cases",
]
# Test collection and runtime configuration
filterwarnings = [
"error", # Treat warnings as errors
"ignore::DeprecationWarning", # Ignore specific warnings
]
# Command line options
addopts = [
"-v", # Verbose output
"--strict-markers", # Enforce marker definitions
"--tb=short", # Simplified traceback
"--disable-warnings", # Disable warnings
"--color=yes" # Colored output
]
# Coverage configuration
[tool.coverage.run]
# Source paths - adjust according to your project structure
source = [
# "../../api/db/services",
# Add more directories if needed:
"../../common",
# "../../utils",
]
# Files/directories to exclude
omit = [
"*/tests/*",
"*/test_*",
"*/__pycache__/*",
"*/.pytest_cache/*",
"*/venv/*",
"*/.venv/*",
"*/env/*",
"*/site-packages/*",
"*/dist/*",
"*/build/*",
"*/migrations/*",
"setup.py"
]
[tool.coverage.report]
# Report configuration
precision = 2
show_missing = true
skip_covered = false
fail_under = 0 # Minimum coverage requirement (0-100)
# Lines to exclude (optional)
exclude_lines = [
# "pragma: no cover",
# "def __repr__",
# "raise AssertionError",
# "raise NotImplementedError",
# "if __name__ == .__main__.:",
# "if TYPE_CHECKING:",
"pass"
]
[tool.coverage.html]
# HTML report configuration
directory = "htmlcov"
title = "Test Coverage Report"
# extra_css = "custom.css" # Optional custom CSS

View File

@ -14,5 +14,5 @@
# limitations under the License.
#
from beartype.claw import beartype_this_package
beartype_this_package()
# from beartype.claw import beartype_this_package
# beartype_this_package()

View File

@ -23,7 +23,7 @@ from rag.app import naive
from rag.app.naive import by_plaintext, PARSERS
from rag.nlp import bullets_category, is_english,remove_contents_table, \
hierarchical_merge, make_colon_as_title, naive_merge, random_choices, tokenize_table, \
tokenize_chunks
tokenize_chunks, attach_media_context
from rag.nlp import rag_tokenizer
from deepdoc.parser import PdfParser, HtmlParser
from deepdoc.parser.figure_parser import vision_figure_parser_docx_wrapper
@ -175,6 +175,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
res = tokenize_table(tbls, doc, eng)
res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
table_ctx = max(0, int(parser_config.get("table_context_size", 0) or 0))
image_ctx = max(0, int(parser_config.get("image_context_size", 0) or 0))
if table_ctx or image_ctx:
attach_media_context(res, table_ctx, image_ctx)
return res

Some files were not shown because too many files have changed in this diff Show More