Compare commits

..

93 Commits

Author SHA1 Message Date
427e0540ca Fix: table tag on chunks. (#12126)
- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-25 11:23:51 +08:00
8f16fac898 Fix: Add a no-data filter condition to MetaData (#12189)
### What problem does this PR solve?

Fix: Add a no-data filter condition to MetaData

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-25 10:42:34 +08:00
ea00478a21 Bump infinity to 0.6.13 (#12181)
### What problem does this PR solve?

Bump infinity to 0.6.13

### Type of change

- [x] Refactoring
2025-12-24 22:33:56 +08:00
906f19e863 Dragging down a downstream node of a Switch operator will cause the end_cpn_ids to contain the ID of the placeholder operator. #12177 (#12178)
### What problem does this PR solve?

Dragging down a downstream node of a Switch operator will cause the
end_cpn_ids to contain the ID of the placeholder operator. #12177

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 19:45:35 +08:00
667dc5467e Fix: Fixed the issue of incorrect agent translation text. #10427 (#12172)
### What problem does this PR solve?

Fix: Fixed the issue of incorrect agent translation text. #10427

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 19:05:26 +08:00
977962fdfe Fix: loopitem None issue. (#12166)
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 17:22:31 +08:00
1e9374a373 Fix:Metadata saving, copywriting and other related issues (#12169)
### What problem does this PR solve?

Fix:Bugs Fixed
- Text overflow issues that caused rendering problems
- Metadata saving, copywriting and other related issues

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 17:21:36 +08:00
9b52ba8061 Feat: add image table context to pipeline splitter (#12167)
### What problem does this PR solve?

Add image table context to pipeline splitter.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-24 16:58:14 +08:00
44671ea413 Fix: type check for chunks (#12164)
### What problem does this PR solve?

Fix: type check for chunks

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 16:36:00 +08:00
c81421d340 Feat: add document metadata setting (#12156)
### What problem does this PR solve?

Add document metadata setting.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: Jin Hai <haijin.chn@gmail.com>
2025-12-24 16:13:50 +08:00
ee93a80e91 Feat: add MiniMax M2.1 (#12148)
### What problem does this PR solve?

Add MiniMax M2.1.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: Jin Hai <haijin.chn@gmail.com>
2025-12-24 16:08:09 +08:00
5c981978c1 Run infinity test before ES (#12159)
### What problem does this PR solve?

As title

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-24 15:52:55 +08:00
7fef285af5 Revert "Bump infinity to 0.6.12 (#12140)" (#12161)
This reverts commit 0588fe79b9.
2025-12-24 15:24:12 +08:00
b1efb905e5 Fix: metadata_obj issue. (#12146)
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 13:40:34 +08:00
6400bf87ba Fix: LLM tool does not exist in multiple retrieval case (#12143)
### What problem does this PR solve?

 Fix LLM tool does not exist in multiple retrieval case

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 13:26:48 +08:00
f239bc02d3 Feat: Support Markdown Rendering for tips in user-fill-up Component #11825 (#12147)
### What problem does this PR solve?

Feat: Support Markdown Rendering for tips in user-fill-up Component
#11825

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-12-24 13:25:56 +08:00
5776fa73a7 refactor: improve memory service date time consistency (#12144)
### What problem does this PR solve?

 improve memory service date time consistency

### Type of change

- [x] Refactoring
2025-12-24 11:00:31 +08:00
fc6af1998b Doc: Added an HTTP request component reference (#12141)
### Type of change

- [x] Documentation Update
2025-12-24 09:35:32 +08:00
0588fe79b9 Bump infinity to 0.6.12 (#12140)
### What problem does this PR solve?

As title

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-24 09:34:54 +08:00
f545265f93 Fix:remove duplicate tool_meta (#12139)
### What problem does this PR solve?
pr:#12117
change:remove duplicate tool_meta

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 09:34:08 +08:00
c987d33649 Feat: deduplicate metadata lists during updates (#12125)
### What problem does this PR solve?

Deduplicate metadata lists during updates.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-24 09:32:55 +08:00
d72debf0db Fix: Add prompts when merging or deleting metadata. (#12138)
### What problem does this PR solve?

Fix: Add prompts when merging or deleting metadata.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
2025-12-24 09:32:41 +08:00
c33134ea2c Fix: table tag on chunks. (#12126)
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 09:32:19 +08:00
17b8bb62b6 Feat: message manage (#12083)
### What problem does this PR solve?

Message CRUD.

Issue #4213 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-23 21:16:25 +08:00
bab6a4a219 Fix: /kb/update does not update FileService (#12121)
### What problem does this PR solve?

Fix: /kb/update does not update FileService

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-23 19:56:38 +08:00
6c93157b14 Refa: image table context window (#12132)
### What problem does this PR solve?

Image table context window

### Type of change

- [x] Refactoring
2025-12-23 19:51:01 +08:00
033029eaa1 Fix: The form waiting for input is not displayed in the dialog message. #12129 (#12130)
### What problem does this PR solve?
Fix: The form waiting for input is not displayed in the dialog message.
#12129

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-23 17:59:55 +08:00
a958ddb27a refactor: reword locale translations (#12118)
### What problem does this PR solve?

Reword (in locales/en) "Image context window" to "Image & table context
window", etc.

### Type of change

- [x] Refactoring
2025-12-23 17:34:21 +08:00
f63f007326 fix: add null safety checks in webhook response status hook (#12114)
### What problem does this PR solve?

Add optional chaining operators to prevent runtime errors when formData
is undefined or null in useShowWebhookResponseStatus hook.

This fixes a potential crash when accessing mode and execution_mode
properties before formData is initialized or when the Begin node doesn't
exist in the graph.

🤖 Generated with [Claude Code](https://claude.com/claude-code)


### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Co-authored-by: Claude <noreply@anthropic.com>
2025-12-23 16:16:30 +08:00
b47f1afa35 fix: transformer toc prompt text incorrect (#12116)
### What problem does this PR solve?

Fix incorrect prompt texts in **Agent** canvas > **Transformer** >
**Result destination: Table of contents**

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-23 15:59:09 +08:00
2369be7244 Refactor: enhance next_step prompt (#12117)
### What problem does this PR solve?

change:
enhance next_step prompt

### Type of change

- [x] Refactoring
2025-12-23 15:57:55 +08:00
00bb6fbd28 Fix: metadata issue & graphrag speeding up. (#12113)
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Liu An <asiro@qq.com>
2025-12-23 15:57:27 +08:00
063b06494a redirect stderr to stdout (#12122)
### What problem does this PR solve?

Update workflows

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-23 15:57:21 +08:00
b824185a3a Feat: Translate the text of the webhook debugging interface. #10427 (#12115)
### What problem does this PR solve?

Feat: Translate the text of the webhook debugging interface. #10427

### Type of change


- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: balibabu <assassin_cike@163.com>
2025-12-23 15:25:38 +08:00
8e6ddd7c1b Fix: Metadata bugs. (#12111)
### What problem does this PR solve?

Fix: Metadata bugs.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
2025-12-23 14:16:57 +08:00
d1bc7ad2ee Fix only one of multiple retrieval tools is effective (#12110)
### What problem does this PR solve?

Fix only one of multiple retrieval tools is effective

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-23 14:08:25 +08:00
321474fb97 Fix: update method call to use simplified async tool reaction (#12108)
### What problem does this PR solve?
pr:#12091
change:update method call to use simplified async tool reaction

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-23 13:36:58 +08:00
ea89e4e0c6 Feat: add GLM-4.7 (#12102)
### What problem does this PR solve?

 Add GLM-4.7.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-23 12:38:56 +08:00
9e31631d8f Feat: Add memory multi-select dropdown to recall and message operator forms. #4213 (#12106)
### What problem does this PR solve?

Feat: Add memory multi-select dropdown to recall and message operator
forms. #4213

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-12-23 11:54:32 +08:00
712d537d66 Fix: vision_figure_parser_docx/pdf_wrapper (#12104)
### What problem does this PR solve?

Fix: vision_figure_parser_docx/pdf_wrapper  #11735

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-23 11:51:28 +08:00
bd4eb19393 Fix:Bugs fix (Reduce metadata saving steps ...) (#12095)
### What problem does this PR solve?

Fix:Bugs fix
- Configure memory and metadata (in Chinese)
- Add indexing modal
- Reduce metadata saving steps

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
2025-12-23 11:50:35 +08:00
02efab7c11 Feat: Hide part of the message field in webhook mode #10427 (#12100)
### What problem does this PR solve?

Feat: Hide part of the message field in webhook mode  #10427

### Type of change


- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: balibabu <assassin_cike@163.com>
2025-12-23 10:45:05 +08:00
8ce129bc51 Update workflow (#12101)
### What problem does this PR solve?

As title

### Type of change

- [x] Other (please describe): Update GitHub action

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-23 10:03:24 +08:00
d5a44e913d Fix: fix task cancel (#12093)
### What problem does this PR solve?

Fix: fix task cancel

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-23 09:38:25 +08:00
1444de981c Feat: enhance webhook response to include status and success fields and simplify ReAct agent (#12091)
### What problem does this PR solve?

change:
enhance webhook response to include status and success fields and
simplify ReAct agent

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-23 09:36:08 +08:00
bd76b8ff1a Fix: Tika server upgrades. (#12073)
### What problem does this PR solve?

#12037

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-23 09:35:52 +08:00
a95f22fa88 Feat: output intinity test log (#12097)
### What problem does this PR solve?

Output log to file when run infinity tests.

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-12-22 21:33:08 +08:00
38ac6a7c27 feat: add image context window in dataset config (#12094)
### What problem does this PR solve?

Add image context window configuration in **Dataset** >
**Configduration** and **Dataset** > **Files** > **Parse** > **Ingestion
Pipeline** (**Chunk Method** modal)

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-22 19:51:23 +08:00
e5f3d5ae26 Refactor add_llm and add speech to text (#12089)
### What problem does this PR solve?

1. Refactor implementation of add_llm
2. Add speech to text model.

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-22 19:27:26 +08:00
4cbc91f2fa Feat: optimize aws s3 connector (#12078)
### What problem does this PR solve?

Feat: optimize aws s3 connector #12008 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
2025-12-22 19:06:01 +08:00
6d3d3a40ab fix: hide drop-zone upload button when picked an image (#12088)
### What problem does this PR solve?

Hide drop-zone upload button when picked an image in chunk editor dialog

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-22 19:04:44 +08:00
51b12841d6 Feature/1217 (#12087)
### What problem does this PR solve?

feature: Complete metadata functionality

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-22 17:35:12 +08:00
993bf7c2c8 Fix IDE warnings (#12085)
### What problem does this PR solve?

As title

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-22 16:47:21 +08:00
b42b5fcf65 feat: display chunk type in chunk editor and dialog (#12086)
### What problem does this PR solve?

Display chunk type in chunk editor and dialog, may be one of below:
- Image
- Table
- Text

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-22 16:45:47 +08:00
5d391fb1f9 fix: guard Dashscope response attribute access in token/log utils (#12082)
### What problem does this PR solve?

Guard Dashscope response attribute access in token/log utils, since
`dashscope_response` returns dict like object.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-22 16:17:58 +08:00
2ddfcc7cf6 Images that appear consecutively in the dialogue are displayed using a carousel. #12076 (#12077)
### What problem does this PR solve?

Images that appear consecutively in the dialogue are displayed using a
carousel. #12076

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-12-22 14:41:02 +08:00
5ba51b21c9 Feat: When the webhook returns a field in streaming format, the message displays the status field. #10427 (#12075)
### What problem does this PR solve?

Feat: When the webhook returns a field in streaming format, the message
displays the status field. #10427

### Type of change


- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: balibabu <assassin_cike@163.com>
2025-12-22 14:37:39 +08:00
3ea84ad9c8 Potential fix for code scanning alert no. 59: Clear-text logging of sensitive information (#12069)
Potential fix for
[https://github.com/infiniflow/ragflow/security/code-scanning/59](https://github.com/infiniflow/ragflow/security/code-scanning/59)

General approach: ensure that HTTP logs never contain raw secrets even
if they appear in URLs or in highly sensitive endpoints. There are two
complementary strategies: (1) for clearly sensitive endpoints (e.g.,
OAuth token URLs), completely suppress URL logging; and (2) ensure that
any URL that is logged is strongly redacted for any parameter name that
might carry a secret, and in a way that static analysis can see is a
dedicated sanitization step.

Best targeted fix here, without changing behavior for non-sensitive
traffic, is:

1. Strengthen the `_SENSITIVE_QUERY_KEYS` set to include any likely
secret-bearing keys (e.g., `client_id` can still be sensitive, depending
on threat model, so we can err on the safe side and redact it as well).
2. Ensure `_is_sensitive_url` (in `common/http_client.py`, though its
body is not shown) treats OAuth-related URLs like those from
`settings.GITHUB_OAUTH` and `settings.FEISHU_OAUTH` as sensitive and
thus disables URL logging. Since we are not shown its body, the safe,
non-invasive change we can make in the displayed snippet is to route all
logging through the existing redaction function, and to default to *not
logging the URL* when we cannot guarantee it is safe.
3. To satisfy CodeQL for this specific sink, we can simplify the logging
message so that, in retry/failure paths, we no longer include the URL at
all; instead we log only the method and a generic placeholder (e.g.,
`"async_request attempt ... failed; retrying..."`). This fully removes
the tainted URL from the sink and addresses all alert variants for that
logging statement, while preserving useful operational information
(method, attempt index, delay).

Concretely, in `common/http_client.py`, inside `async_request`:

- Keep the successful-request debug log as-is (it already uses
`_redact_sensitive_url_params` and `_is_sensitive_url` and is likely
safe and useful).
- In the `except httpx.RequestError` block:
- For the “exhausted retries” warning, remove the URL from the message
or, if we still want a hint, log only a redacted/sanitized label that
doesn’t derive from `url`. The simplest is to omit the URL entirely.
- For the per-attempt failure warning (line 162), similarly remove
`log_url` (and thus any use of `url`) from the formatted message so that
the sink no longer contains tainted data.

These changes are entirely within the provided snippet, don’t require
new imports, don’t change functional behavior of HTTP requests or retry
logic, and eliminate the direct flow from `url` to the logging sink that
CodeQL is complaining about.

---


_Suggested fixes powered by Copilot Autofix. Review carefully before
merging._

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
2025-12-22 13:46:44 +08:00
0a5dce50fb Fix character escape (#12072)
### What problem does this PR solve?

```
f"{re.escape(entity_index_delimiter)}(\d+){re.escape(entity_index_delimiter)}"
->
fr"{re.escape(entity_index_delimiter)}(\d+){re.escape(entity_index_delimiter)}"

```
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-22 13:32:20 +08:00
6c9afd1ffb Potential fix for code scanning alert no. 60: Clear-text logging of sensitive information (#12068)
Potential fix for
[https://github.com/infiniflow/ragflow/security/code-scanning/60](https://github.com/infiniflow/ragflow/security/code-scanning/60)

In general, the correct fix is to ensure that no sensitive data
(passwords, API keys, full connection strings with embedded credentials,
etc.) is ever written to logs. This can be done by (1) whitelisting only
clearly non-sensitive fields for logging, and/or (2) explicitly
scrubbing or masking any value that might contain credentials before
logging, and (3) not relying on later deletion from the dictionary to
protect against logging, since the log call already happened.

For this function, the best minimal fix is:

- Keep the idea of a safe key whitelist, but strengthen it so we are
absolutely sure we never log `password` or `connection_string`, even
indirectly.
- Avoid building the logged dict from the same potentially-tainted
`kwargs` object before we have removed sensitive keys, or relying solely
on key names that might change.
- Construct a separate, small log context that is obviously safe:
scheme, host, port, database, table, and possibly a boolean like
`has_password` instead of the password itself.
- Optionally, add a small helper to derive this safe log context, but
given the scope we can keep it inline.

Concretely in `rag/utils/opendal_conn.py`:

- Replace the current `SAFE_LOG_KEYS` / `loggable_kwargs` /
`logging.info(...)` block so that:
- We do not pass through arbitrary `kwargs` values by key filtering
alone.
- We instead build a new dict with explicitly chosen, non-sensitive
fields, e.g.:

    ```python
    safe_log_info = {
        "scheme": kwargs.get("scheme"),
        "host": kwargs.get("host"),
        "port": kwargs.get("port"),
        "database": kwargs.get("database"),
        "table": kwargs.get("table"),
"has_password": "password" in kwargs or "connection_string" in kwargs,
    }
logging.info("Loaded OpenDAL configuration (non sensitive fields only):
%s", safe_log_info)
    ```

- This makes sure that neither the password nor a connection string
containing it is ever logged, while still retaining useful diagnostic
information.
- Keep the existing deletion of `password` and `connection_string` from
`kwargs` after logging, as an additional safety measure for any later
use of `kwargs`.

No new imports or external libraries are required; we only modify lines
45–56 of the shown snippet.


_Suggested fixes powered by Copilot Autofix. Review carefully before
merging._

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
2025-12-22 13:31:39 +08:00
bfef96d56e Potential fix for code scanning alert no. 58: Clear-text logging of sensitive information (#12070)
Potential fix for
[https://github.com/infiniflow/ragflow/security/code-scanning/58](https://github.com/infiniflow/ragflow/security/code-scanning/58)

General approach: avoid logging potentially sensitive URLs (especially
at warning level) or ensure they are fully and robustly redacted before
logging. Since this client is shared and used with OAuth endpoints, the
safest minimal-change fix is to stop including the URL in warning logs
(retries exhausted and retry attempts) and only log the HTTP method and
a generic message. Debug logs can continue using the existing redaction
helper for non-sensitive URLs if desired.

Best concrete fix without changing functionality: in
`common/http_client.py`, in `async_request`, change the retry-exhausted
and retry-attempt warning log statements so that they no longer
interpolate `log_url` (and thus the tainted `url`). We can still compute
`log_url` if needed elsewhere, but the log string itself should not
contain `log_url`. This directly removes the tainted data from the sink
while preserving information about errors and retry behavior. No changes
are required in `common/settings.py` or `api/apps/user_app.py`, and we
do not need new imports or helpers.

Specifically:
- In `common/http_client.py`, around line 152–163, replace the two
warning logs:
- `logger.warning(f"async_request exhausted retries for {method}
{log_url}")`
- `logger.warning(f"async_request attempt {attempt + 1}/{retries + 1}
failed for {method} {log_url}; retrying in {delay:.2f}s")`
  with versions that omit `{log_url}`, such as:
  - `logger.warning(f"async_request exhausted retries for {method}")`
- `logger.warning(f"async_request attempt {attempt + 1}/{retries + 1}
failed for {method}; retrying in {delay:.2f}s")`

This ensures no URL-derived data flows into these warning logs,
addressing all variants of the alert, since they all trace to the same
sink.

---


_Suggested fixes powered by Copilot Autofix. Review carefully before
merging._

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
2025-12-22 13:31:25 +08:00
74adf3d59c Potential fix for code scanning alert no. 57: Clear-text logging of sensitive information (#12071)
Potential fix for
[https://github.com/infiniflow/ragflow/security/code-scanning/57](https://github.com/infiniflow/ragflow/security/code-scanning/57)

In general, the safest fix is to ensure that any logging of request URLs
from `async_request` (and similar helpers) cannot include secrets. This
can be done by (a) suppressing logging entirely for URLs considered
sensitive, or (b) logging only a non-sensitive subset (e.g., scheme +
host + path) and never query strings or credentials.

The minimal, backward-compatible change here is to strengthen
`_redact_sensitive_url_params` and `_is_sensitive_url` / the logging
call so that we never log query parameters at all. Instead of logging
the full URL (with redacted query), we can log only
`scheme://netloc/path` and optionally strip userinfo. This retains
useful observability (which endpoint, which method, response code,
timing) while guaranteeing that no secrets in query strings or path
segments appear in logs. Concretely:
- Update `_redact_sensitive_url_params` to *not* include the query
string in the returned value, and to drop any embedded userinfo
(`username:password@host`).
- Continue to wrap logging in a “sensitive URL” guard, but now the
redaction routine itself ensures no secrets from query are present.
- Leave callers (e.g., `github_callback`, `feishu_callback`) unchanged,
since they only pass URLs and do not control the logging behavior
directly.

All changes are confined to `common/http_client.py` inside the provided
snippet. No new imports are necessary.


_Suggested fixes powered by Copilot Autofix. Review carefully before
merging._

---------

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
2025-12-22 13:31:03 +08:00
ba7e087aef Refactor:remove useless try catch for ppt parser (#12063)
### What problem does this PR solve?

remove useless try catch for ppt parser

### Type of change
- [x] Refactoring
2025-12-22 13:09:42 +08:00
f911aa2997 Fix: list MCP tools may block (#12067)
### What problem does this PR solve?

 List MCP tools may block. #12043

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-22 13:08:44 +08:00
42f9ac997f Remove Chinese comments and fix function arguments errors (#12052)
### What problem does this PR solve?

As title

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-22 12:59:37 +08:00
c7cf7aad4e Fix: update RAGFlow SDK for consistency (#12065)
### What problem does this PR solve?

Fix: update RAGFlow SDK for consistency #12059

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-22 11:09:56 +08:00
2118bc2556 Fix: Python SDK retrieve document_name is empty (#12062)
### What problem does this PR solve?

https://github.com/infiniflow/ragflow/issues/12056

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-22 11:08:39 +08:00
b49eb6826b Feat: enhance Excel image extraction with vision-based descriptions (#12054)
### What problem does this PR solve?
issue:
[#11618](https://github.com/infiniflow/ragflow/issues/11618)
change:
enhance Excel image extraction with vision-based descriptions

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-22 10:17:44 +08:00
8dd2394e93 feat: add optional cache busting for image (#12055)
### What problem does this PR solve?

Add optional cache busting for image

#12003  

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-22 09:36:45 +08:00
5aea82d9c4 Feat: Separate connectors from s3 (#12045)
### What problem does this PR solve?

Feat: Separate connectors from s3 #12008 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Overview:
<img width="1500" alt="image"
src="https://github.com/user-attachments/assets/d54fea7a-7294-4ec0-ab6c-9753b3f03a72"
/>

Oracle: 
<img width="350" alt="image"
src="https://github.com/user-attachments/assets/bca140c1-33d8-4950-afdc-153407eedc46"
/>
2025-12-22 09:36:16 +08:00
47005ebe10 feat: supports multiple retrieval tool under an agent (#12046)
### What problem does this PR solve?

Add support for multiple Retrieval tools under an agent

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-22 09:35:34 +08:00
3ee47e4af7 Feat: document list and filter supports metadata filtering (#12053)
### What problem does this PR solve?

Document list and filter supports metadata filtering.

**OR within the same field, AND across different fields**

Example 1 (multi-field AND):

```markdown
Doc1 metadata: { "a": "b", "as": ["a", "b", "c"] }
Doc2 metadata: { "a": "x", "as": ["d"] }

Query:

metadata = {
  "a": ["b"],
  "as": ["d"]
}

Result:

Doc1 matches a=b but not as=d → excluded
Doc2 matches as=d but not a=b → excluded

Final result: empty
```

Example 2 (same field OR):

```markdown
Doc1 metadata: { "as": ["a", "b", "c"] }
Doc2 metadata: { "as": ["d"] }

Query:

metadata = {
  "as": ["a", "d"]
}
Result:

Doc1 matches as=a → included
Doc2 matches as=d → included

Final result: Doc1 + Doc2
```

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-22 09:35:11 +08:00
55c0468ac9 Include document_id in knowledgebase info retrieval (#12041)
### What problem does this PR solve?
After a file in the file list is associated with a knowledge base, the
knowledge base document ID is returned


### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-19 19:32:24 +08:00
eeb36a5ce7 Feature: Implement metadata functionality (#12049)
### What problem does this PR solve?

Feature: Implement metadata functionality

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-19 19:13:33 +08:00
aceca266ff Feat: Images appearing consecutively in the dialogue are merged and displayed in a carousel. #10427 (#12051)
### What problem does this PR solve?

Feat: Images appearing consecutively in the dialogue are merged and
displayed in a carousel. #10427
### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-12-19 19:13:18 +08:00
d82e502a71 Add AI Badgr as OpenAI-compatible chat model provider (#12018)
## What problem does this PR solve?

Adds AI Badgr as an optional LLM provider in RAGFlow. Users can use AI
Badgr for chat completions and embeddings via its OpenAI-compatible API.

**Background:**
- AI Badgr provides OpenAI-compatible endpoints (`/v1/chat/completions`,
`/v1/embeddings`, `/v1/models`)
- Previously, RAGFlow didn't support AI Badgr
- This PR adds support following the existing provider pattern (e.g.,
CometAPI, DeerAPI)

**Implementation details:**
- Added AI Badgr to the provider registry and configuration
- Supports chat completions (via LiteLLMBase) and embeddings (via
AIBadgrEmbed)
- Uses standard API key authentication
- Base URL: `https://aibadgr.com/api/v1`
- Environment variables: `AIBADGR_API_KEY`, `AIBADGR_BASE_URL`
(optional)

## Type of change

- [x] New Feature (non-breaking change which adds functionality)

This is a new feature that adds support for a new provider without
changing existing functionality.

---------

Co-authored-by: michaelmanley <55236695+michaelbrinkworth@users.noreply.github.com>
2025-12-19 17:45:20 +08:00
0494b92371 Feat: Display error messages from intermediate nodes. #10427 (#12038)
### What problem does this PR solve?

Feat: Display error messages from intermediate nodes. #10427

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-12-19 17:44:45 +08:00
8683a5b1b7 Docs: How to call MinerU as a remote service (#12004)
### Type of change

- [x] Documentation Update
2025-12-19 17:06:32 +08:00
4cbe470089 Feat: Display error messages from intermediate nodes of the webhook. #10427 (#11954)
### What problem does this PR solve?

Feat: Remove HMAC from the webhook #10427

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-12-19 12:56:56 +08:00
6cd1824a77 Feat: chats completions API supports metadata filtering (#12023)
### What problem does this PR solve?

Chats completions API supports metadata filtering.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-19 11:36:35 +08:00
2844700dc4 Refa: better UX for adding OCR model (#12034)
### What problem does this PR solve?

Better UX for adding OCR model.

### Type of change

- [x] Refactoring
2025-12-19 11:34:21 +08:00
f8fd1ea7e1 Feat: Further update Bedrock model configs (#12029)
### What problem does this PR solve?

Feat: Further update Bedrock model configs #12020 #12008

<img width="700" alt="2b4f0f7fab803a2a2d5f345c756a2c69"
src="https://github.com/user-attachments/assets/e1b9eaad-5c60-47bd-a6f4-88a104ce0c63"
/>
<img width="700" alt="afe88ec3c58f745f85c5c507b040c250"
src="https://github.com/user-attachments/assets/9de39745-395d-4145-930b-96eb452ad6ef"
/>
<img width="700" alt="1a21bb2b7cd8003dce1e5207f27efc69"
src="https://github.com/user-attachments/assets/ddba1682-6654-4954-aa71-41b8ebc04ac0"
/>

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-12-19 11:32:20 +08:00
57edc215d7 Feat:update webhook component (#11739)
### What problem does this PR solve?
issue:
https://github.com/infiniflow/ragflow/issues/10427

https://github.com/infiniflow/ragflow/issues/8115

change:

- Support for Multiple HTTP Methods (POST / GET / PUT / PATCH / DELETE /
HEAD)
- Security Validation
  1. max_body_size
  2. IP whitelist
  3. rate limit
  4. token / basic / jwt authentication
- File Upload Support
- Unified Content-Type Handling
- Full Schema-Based Extraction & Type Validation
- Two Execution Modes: Immediately / Streaming


### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-18 19:34:39 +08:00
7a4044b05f Feat: use filepath for files with the same name for all data source types (#11819)
### What problem does this PR solve?

When there are multiple files with the same name the file would just
duplicate, making it hard to distinguish between the different files.
Now if there are multiple files with the same name, they will be named
after their folder path in the storage unit.

This was done for the webdav connector and with this PR also for Notion,
Confluence and S3 Storage.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Contribution by RAGcon GmbH, visit us [here](https://www.ragcon.ai/)
2025-12-18 17:42:43 +08:00
e84d5412bc Feat: bedrock iam authentication (#12020)
### What problem does this PR solve?

Feat: bedrock iam authentication #12008 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-18 17:13:09 +08:00
151480dc85 Feat: trace information can be returned by the agent completion API (#12019)
### What problem does this PR solve?

Trace information can be returned by the agent completion API.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-18 15:52:11 +08:00
2331b3a270 Refact: Update loggings (#12014)
### What problem does this PR solve?

Refact: Update loggings

### Type of change

- [x] Refactoring
2025-12-18 14:18:03 +08:00
5cd1a678c8 Fix: image edit in edit_chunk (#12009)
### What problem does this PR solve?

Fix: image edit in edit_chunk #11971

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-18 11:35:01 +08:00
cc9546b761 Fix IDE warnings (#12010)
### What problem does this PR solve?

As title

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-18 11:27:02 +08:00
a63dcfed6f Refactor: improve cohere calculate total counts (#12007)
### What problem does this PR solve?

improve cohere calculate total counts

### Type of change


- [x] Refactoring
2025-12-18 10:04:28 +08:00
4dd8cdc38b task executor issues (#12006)
### What problem does this PR solve?

**Fixes #8706** - `InfinityException: TOO_MANY_CONNECTIONS` when running
multiple task executor workers

### Problem Description

When running RAGFlow with 8-16 task executor workers, most workers fail
to start properly. Checking logs revealed that workers were
stuck/hanging during Infinity connection initialization - only 1-2
workers would successfully register in Redis while the rest remained
blocked.

### Root Cause

The Infinity SDK `ConnectionPool` pre-allocates all connections in
`__init__`. With the default `max_size=32` and multiple workers (e.g.,
16), this creates 16×32=512 connections immediately on startup,
exceeding Infinity's default 128 connection limit. Workers hang while
waiting for connections that can never be established.

### Changes

1. **Prevent Infinity connection storm** (`rag/utils/infinity_conn.py`,
`rag/svr/task_executor.py`)
- Reduced ConnectionPool `max_size` from 32 to 4 (sufficient since
operations are synchronous)
- Added staggered startup delay (2s per worker) to spread connection
initialization

2. **Handle None children_delimiter** (`rag/app/naive.py`)
   - Use `or ""` to handle explicitly set None values from parser config

3. **MinerU parser robustness** (`deepdoc/parser/mineru_parser.py`)
   - Use `.get()` for optional output fields that may be missing
- Fix DISCARDED block handling: change `pass` to `continue` to skip
discarded blocks entirely

### Why `max_size=4` is sufficient

| Workers | Pool Size | Total Connections | Infinity Limit |
|---------|-----------|-------------------|----------------|
| 16      | 32        | 512               | 128          |
| 16      | 4         | 64                | 128          |
| 32      | 4         | 128               | 128          |

- All RAGFlow operations are synchronous: `get_conn()` → operation →
`release_conn()`
- No parallel `docStoreConn` operations in the codebase
- Maximum 1-2 concurrent connections needed per worker; 4 provides
safety margin

### MinerU DISCARDED block bug

When MinerU returns blocks with `type: "discarded"` (headers, footers,
watermarks, page numbers, artifacts), the previous code used `pass`
which left the `section` variable undefined, causing:

- **UnboundLocalError** if DISCARDED is the first block
- **Duplicate content** if DISCARDED follows another block (stale value
from previous iteration)

**Root cause confirmed via MinerU source code:**

From
[`mineru/utils/enum_class.py`](https://github.com/opendatalab/MinerU/blob/main/mineru/utils/enum_class.py#L14):
```python
class BlockType:
    DISCARDED = 'discarded'
    # VLM 2.5+ also has: HEADER, FOOTER, PAGE_NUMBER, ASIDE_TEXT, PAGE_FOOTNOTE
```

Per [MinerU
documentation](https://opendatalab.github.io/MinerU/reference/output_files/),
discarded blocks contain content that should be filtered out for clean
text extraction.

**Fix:** Changed `pass` to `continue` to skip discarded blocks entirely.

### Testing

- Verified all 16 workers now register successfully in Redis
- All workers heartbeating correctly
- Document parsing works as expected
- MinerU parsing with DISCARDED blocks no longer crashes

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: user210 <user210@rt>
2025-12-18 10:03:30 +08:00
1a4822d6be Refactor: Improve the timestamp consistency (#11942)
### What problem does this PR solve?

Improve the timestamp consistency 

### Type of change
- [x] Refactoring
2025-12-18 09:40:33 +08:00
ce161f09cc feat: add image uploader in edit chunk dialog (#12003)
### What problem does this PR solve?

Add image uploader in edit chunk dialog for replacing image chunk

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-18 09:33:52 +08:00
295 changed files with 12443 additions and 3587 deletions

View File

@ -197,37 +197,38 @@ jobs:
echo -e "COMPOSE_PROFILES=\${COMPOSE_PROFILES},tei-cpu" >> docker/.env
echo -e "TEI_MODEL=BAAI/bge-small-en-v1.5" >> docker/.env
echo -e "RAGFLOW_IMAGE=${RAGFLOW_IMAGE}" >> docker/.env
sed -i '1i DOC_ENGINE=infinity' docker/.env
echo "HOST_ADDRESS=http://host.docker.internal:${SVR_HTTP_PORT}" >> ${GITHUB_ENV}
sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} up -d
uv sync --python 3.12 --only-group test --no-default-groups --frozen && uv pip install sdk/python --group test
- name: Run sdk tests against Elasticsearch
- name: Run sdk tests against Infinity
run: |
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
echo "Waiting for service to be available..."
sleep 5
done
source .venv/bin/activate && pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api
source .venv/bin/activate && DOC_ENGINE=infinity pytest -x -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api 2>&1 | tee infinity_sdk_test.log
- name: Run frontend api tests against Elasticsearch
- name: Run frontend api tests against Infinity
run: |
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
echo "Waiting for service to be available..."
sleep 5
done
source .venv/bin/activate && pytest -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py
- name: Run http api tests against Elasticsearch
source .venv/bin/activate && DOC_ENGINE=infinity pytest -x -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py 2>&1 | tee infinity_api_test.log
- name: Run http api tests against Infinity
run: |
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
echo "Waiting for service to be available..."
sleep 5
done
source .venv/bin/activate && pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api
source .venv/bin/activate && DOC_ENGINE=infinity pytest -x -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee infinity_http_api_test.log
- name: Stop ragflow:nightly
if: always() # always run this step even if previous steps failed
@ -237,35 +238,35 @@ jobs:
- name: Start ragflow:nightly
run: |
sed -i '1i DOC_ENGINE=infinity' docker/.env
sed -i '1i DOC_ENGINE=elasticsearch' docker/.env
sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} up -d
- name: Run sdk tests against Infinity
- name: Run sdk tests against Elasticsearch
run: |
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
echo "Waiting for service to be available..."
sleep 5
done
source .venv/bin/activate && DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api
source .venv/bin/activate && DOC_ENGINE=elasticsearch pytest -x -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api 2>&1 | tee es_sdk_test.log
- name: Run frontend api tests against Infinity
- name: Run frontend api tests against Elasticsearch
run: |
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
echo "Waiting for service to be available..."
sleep 5
done
source .venv/bin/activate && DOC_ENGINE=infinity pytest -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py
source .venv/bin/activate && DOC_ENGINE=elasticsearch pytest -x -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py 2>&1 | tee es_api_test.log
- name: Run http api tests against Infinity
- name: Run http api tests against Elasticsearch
run: |
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
echo "Waiting for service to be available..."
sleep 5
done
source .venv/bin/activate && DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api
source .venv/bin/activate && DOC_ENGINE=elasticsearch pytest -x -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee es_http_api_test.log
- name: Stop ragflow:nightly
if: always() # always run this step even if previous steps failed

View File

@ -192,6 +192,7 @@ COPY pyproject.toml uv.lock ./
COPY mcp mcp
COPY plugin plugin
COPY common common
COPY memory memory
COPY docker/service_conf.yaml.template ./conf/service_conf.yaml.template
COPY docker/entrypoint.sh ./

View File

@ -29,6 +29,11 @@ from common.versions import get_ragflow_version
admin_bp = Blueprint('admin', __name__, url_prefix='/api/v1/admin')
@admin_bp.route('/ping', methods=['GET'])
def ping():
return success_response('PONG')
@admin_bp.route('/login', methods=['POST'])
def login():
if not request.json:

View File

@ -160,7 +160,7 @@ class Graph:
return self._tenant_id
def get_value_with_variable(self,value: str) -> Any:
pat = re.compile(r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*")
pat = re.compile(r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.-]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*")
out_parts = []
last = 0
@ -368,8 +368,13 @@ class Canvas(Graph):
if kwargs.get("webhook_payload"):
for k, cpn in self.components.items():
if self.components[k]["obj"].component_name.lower() == "webhook":
for kk, vv in kwargs["webhook_payload"].items():
if self.components[k]["obj"].component_name.lower() == "begin" and self.components[k]["obj"]._param.mode == "Webhook":
payload = kwargs.get("webhook_payload", {})
if "input" in payload:
self.components[k]["obj"].set_input_value("request", payload["input"])
for kk, vv in payload.items():
if kk == "input":
continue
self.components[k]["obj"].set_output(kk, vv)
for k in kwargs.keys():
@ -535,6 +540,8 @@ class Canvas(Graph):
cite = re.search(r"\[ID:[ 0-9]+\]", cpn_obj.output("content"))
message_end = {}
if cpn_obj.get_param("status"):
message_end["status"] = cpn_obj.get_param("status")
if isinstance(cpn_obj.output("attachment"), dict):
message_end["attachment"] = cpn_obj.output("attachment")
if cite:

View File

@ -29,8 +29,8 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.mcp_server_service import MCPServerService
from common.connection_utils import timeout
from rag.prompts.generator import next_step_async, COMPLETE_TASK, analyze_task_async, \
citation_prompt, reflect_async, kb_prompt, citation_plus, full_question, message_fit_in, structured_output_prompt
from rag.prompts.generator import next_step_async, COMPLETE_TASK, \
citation_prompt, kb_prompt, citation_plus, full_question, message_fit_in, structured_output_prompt
from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
from agent.component.llm import LLMParam, LLM
@ -84,9 +84,11 @@ class Agent(LLM, ToolBase):
def __init__(self, canvas, id, param: LLMParam):
LLM.__init__(self, canvas, id, param)
self.tools = {}
for cpn in self._param.tools:
for idx, cpn in enumerate(self._param.tools):
cpn = self._load_tool_obj(cpn)
self.tools[cpn.get_meta()["function"]["name"]] = cpn
original_name = cpn.get_meta()["function"]["name"]
indexed_name = f"{original_name}_{idx}"
self.tools[indexed_name] = cpn
self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id), self._param.llm_id,
max_retries=self._param.max_retries,
@ -94,7 +96,12 @@ class Agent(LLM, ToolBase):
max_rounds=self._param.max_rounds,
verbose_tool_use=True
)
self.tool_meta = [v.get_meta() for _,v in self.tools.items()]
self.tool_meta = []
for indexed_name, tool_obj in self.tools.items():
original_meta = tool_obj.get_meta()
indexed_meta = deepcopy(original_meta)
indexed_meta["function"]["name"] = indexed_name
self.tool_meta.append(indexed_meta)
for mcp in self._param.mcp:
_, mcp_server = MCPServerService.get_by_id(mcp["mcp_id"])
@ -108,7 +115,8 @@ class Agent(LLM, ToolBase):
def _load_tool_obj(self, cpn: dict) -> object:
from agent.component import component_class
param = component_class(cpn["component_name"] + "Param")()
tool_name = cpn["component_name"]
param = component_class(tool_name + "Param")()
param.update(cpn["params"])
try:
param.check()
@ -202,7 +210,7 @@ class Agent(LLM, ToolBase):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
use_tools = []
ans = ""
async for delta_ans, _tk in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt,schema_prompt=schema_prompt):
async for delta_ans, _tk in self._react_with_tools_streamly_async_simple(prompt, msg, use_tools, user_defined_prompt,schema_prompt=schema_prompt):
if self.check_if_canceled("Agent processing"):
return
ans += delta_ans
@ -246,7 +254,7 @@ class Agent(LLM, ToolBase):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
answer_without_toolcall = ""
use_tools = []
async for delta_ans, _ in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt):
async for delta_ans, _ in self._react_with_tools_streamly_async_simple(prompt, msg, use_tools, user_defined_prompt):
if self.check_if_canceled("Agent streaming"):
return
@ -264,7 +272,7 @@ class Agent(LLM, ToolBase):
if use_tools:
self.set_output("use_tools", use_tools)
async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
async def _react_with_tools_streamly_async_simple(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
token_count = 0
tool_metas = self.tool_meta
hist = deepcopy(history)
@ -276,6 +284,24 @@ class Agent(LLM, ToolBase):
else:
user_request = history[-1]["content"]
def build_task_desc(prompt: str, user_request: str, user_defined_prompt: dict | None = None) -> str:
"""Build a minimal task_desc by concatenating prompt, query, and tool schemas."""
user_defined_prompt = user_defined_prompt or {}
task_desc = (
"### Agent Prompt\n"
f"{prompt}\n\n"
"### User Request\n"
f"{user_request}\n\n"
)
if user_defined_prompt:
udp_json = json.dumps(user_defined_prompt, ensure_ascii=False, indent=2)
task_desc += "\n### User Defined Prompts\n" + udp_json + "\n"
return task_desc
async def use_tool_async(name, args):
nonlocal hist, use_tools, last_calling
logging.info(f"{last_calling=} == {name=}")
@ -286,9 +312,6 @@ class Agent(LLM, ToolBase):
"arguments": args,
"results": tool_response
})
# self.callback("add_memory", {}, "...")
#self.add_memory(hist[-2]["content"], hist[-1]["content"], name, args, str(tool_response), user_defined_prompt)
return name, tool_response
async def complete():
@ -326,6 +349,21 @@ class Agent(LLM, ToolBase):
self.callback("gen_citations", {}, txt, elapsed_time=timer()-st)
def build_observation(tool_call_res: list[tuple]) -> str:
"""
Build a Observation from tool call results.
No LLM involved.
"""
if not tool_call_res:
return ""
lines = ["Observation:"]
for name, result in tool_call_res:
lines.append(f"[{name} result]")
lines.append(str(result))
return "\n".join(lines)
def append_user_content(hist, content):
if hist[-1]["role"] == "user":
hist[-1]["content"] += content
@ -333,7 +371,7 @@ class Agent(LLM, ToolBase):
hist.append({"role": "user", "content": content})
st = timer()
task_desc = await analyze_task_async(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
task_desc = build_task_desc(prompt, user_request, user_defined_prompt)
self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st)
for _ in range(self._param.max_rounds + 1):
if self.check_if_canceled("Agent streaming"):
@ -364,7 +402,7 @@ class Agent(LLM, ToolBase):
results = await asyncio.gather(*tool_tasks) if tool_tasks else []
st = timer()
reflection = await reflect_async(self.chat_mdl, hist, results, user_defined_prompt)
reflection = build_observation(results)
append_user_content(hist, reflection)
self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)
@ -393,6 +431,135 @@ Respond immediately with your final comprehensive answer.
async for txt, tkcnt in complete():
yield txt, tkcnt
# async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
# token_count = 0
# tool_metas = self.tool_meta
# hist = deepcopy(history)
# last_calling = ""
# if len(hist) > 3:
# st = timer()
# user_request = await full_question(messages=history, chat_mdl=self.chat_mdl)
# self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer()-st)
# else:
# user_request = history[-1]["content"]
# async def use_tool_async(name, args):
# nonlocal hist, use_tools, last_calling
# logging.info(f"{last_calling=} == {name=}")
# last_calling = name
# tool_response = await self.toolcall_session.tool_call_async(name, args)
# use_tools.append({
# "name": name,
# "arguments": args,
# "results": tool_response
# })
# # self.callback("add_memory", {}, "...")
# #self.add_memory(hist[-2]["content"], hist[-1]["content"], name, args, str(tool_response), user_defined_prompt)
# return name, tool_response
# async def complete():
# nonlocal hist
# need2cite = self._param.cite and self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0
# if schema_prompt:
# need2cite = False
# cited = False
# if hist and hist[0]["role"] == "system":
# if schema_prompt:
# hist[0]["content"] += "\n" + schema_prompt
# if need2cite and len(hist) < 7:
# hist[0]["content"] += citation_prompt()
# cited = True
# yield "", token_count
# _hist = hist
# if len(hist) > 12:
# _hist = [hist[0], hist[1], *hist[-10:]]
# entire_txt = ""
# async for delta_ans in self._generate_streamly(_hist):
# if not need2cite or cited:
# yield delta_ans, 0
# entire_txt += delta_ans
# if not need2cite or cited:
# return
# st = timer()
# txt = ""
# async for delta_ans in self._gen_citations_async(entire_txt):
# if self.check_if_canceled("Agent streaming"):
# return
# yield delta_ans, 0
# txt += delta_ans
# self.callback("gen_citations", {}, txt, elapsed_time=timer()-st)
# def append_user_content(hist, content):
# if hist[-1]["role"] == "user":
# hist[-1]["content"] += content
# else:
# hist.append({"role": "user", "content": content})
# st = timer()
# task_desc = await analyze_task_async(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
# self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st)
# for _ in range(self._param.max_rounds + 1):
# if self.check_if_canceled("Agent streaming"):
# return
# response, tk = await next_step_async(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
# # self.callback("next_step", {}, str(response)[:256]+"...")
# token_count += tk or 0
# hist.append({"role": "assistant", "content": response})
# try:
# functions = json_repair.loads(re.sub(r"```.*", "", response))
# if not isinstance(functions, list):
# raise TypeError(f"List should be returned, but `{functions}`")
# for f in functions:
# if not isinstance(f, dict):
# raise TypeError(f"An object type should be returned, but `{f}`")
# tool_tasks = []
# for func in functions:
# name = func["name"]
# args = func["arguments"]
# if name == COMPLETE_TASK:
# append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be as the same as the first user request.\n")
# async for txt, tkcnt in complete():
# yield txt, tkcnt
# return
# tool_tasks.append(asyncio.create_task(use_tool_async(name, args)))
# results = await asyncio.gather(*tool_tasks) if tool_tasks else []
# st = timer()
# reflection = await reflect_async(self.chat_mdl, hist, results, user_defined_prompt)
# append_user_content(hist, reflection)
# self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)
# except Exception as e:
# logging.exception(msg=f"Wrong JSON argument format in LLM ReAct response: {e}")
# e = f"\nTool call error, please correct the input parameter of response format and call it again.\n *** Exception ***\n{e}"
# append_user_content(hist, str(e))
# logging.warning( f"Exceed max rounds: {self._param.max_rounds}")
# final_instruction = f"""
# {user_request}
# IMPORTANT: You have reached the conversation limit. Based on ALL the information and research you have gathered so far, please provide a DIRECT and COMPREHENSIVE final answer to the original request.
# Instructions:
# 1. SYNTHESIZE all information collected during this conversation
# 2. Provide a COMPLETE response using existing data - do not suggest additional research
# 3. Structure your response as a FINAL DELIVERABLE, not a plan
# 4. If information is incomplete, state what you found and provide the best analysis possible with available data
# 5. DO NOT mention conversation limits or suggest further steps
# 6. Focus on delivering VALUE with the information already gathered
# Respond immediately with your final comprehensive answer.
# """
# if self.check_if_canceled("Agent final instruction"):
# return
# append_user_content(hist, final_instruction)
# async for txt, tkcnt in complete():
# yield txt, tkcnt
async def _gen_citations_async(self, text):
retrievals = self._canvas.get_reference()
retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}

View File

@ -361,7 +361,7 @@ class ComponentParamBase(ABC):
class ComponentBase(ABC):
component_name: str
thread_limiter = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT_CHATS", 10)))
variable_ref_patt = r"\{* *\{([a-zA-Z_:0-9]+@[A-Za-z0-9_.]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*"
variable_ref_patt = r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.-]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*"
def __str__(self):
"""

View File

@ -28,7 +28,7 @@ class BeginParam(UserFillUpParam):
self.prologue = "Hi! I'm your smart assistant. What can I do for you?"
def check(self):
self.check_valid_value(self.mode, "The 'mode' should be either `conversational` or `task`", ["conversational", "task"])
self.check_valid_value(self.mode, "The 'mode' should be either `conversational` or `task`", ["conversational", "task","Webhook"])
def get_input_form(self) -> dict[str, dict]:
return getattr(self, "inputs")

View File

@ -56,7 +56,6 @@ class LLMParam(ComponentParamBase):
self.check_nonnegative_number(int(self.max_tokens), "[Agent] Max tokens")
self.check_decimal_float(float(self.top_p), "[Agent] Top P")
self.check_empty(self.llm_id, "[Agent] LLM")
self.check_empty(self.sys_prompt, "[Agent] System prompt")
self.check_empty(self.prompts, "[Agent] User prompt")
def gen_conf(self):

View File

@ -113,6 +113,10 @@ class LoopItem(ComponentBase, ABC):
return len(var) == 0
elif operator == "not empty":
return len(var) > 0
elif var is None:
if operator == "empty":
return True
return False
raise Exception(f"Invalid operator: {operator}")

View File

@ -1,38 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from agent.component.base import ComponentParamBase, ComponentBase
class WebhookParam(ComponentParamBase):
"""
Define the Begin component parameters.
"""
def __init__(self):
super().__init__()
def get_input_form(self) -> dict[str, dict]:
return getattr(self, "inputs")
class Webhook(ComponentBase):
component_name = "Webhook"
def _invoke(self, **kwargs):
pass
def thoughts(self) -> str:
return ""

View File

@ -25,10 +25,12 @@ from api.db.services.document_service import DocumentService
from common.metadata_utils import apply_meta_data_filter
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.memory_service import MemoryService
from api.db.joint_services.memory_message_service import query_message
from common import settings
from common.connection_utils import timeout
from rag.app.tag import label_question
from rag.prompts.generator import cross_languages, kb_prompt
from rag.prompts.generator import cross_languages, kb_prompt, memory_prompt
class RetrievalParam(ToolParamBase):
@ -57,6 +59,7 @@ class RetrievalParam(ToolParamBase):
self.top_n = 8
self.top_k = 1024
self.kb_ids = []
self.memory_ids = []
self.kb_vars = []
self.rerank_id = ""
self.empty_response = ""
@ -81,15 +84,7 @@ class RetrievalParam(ToolParamBase):
class Retrieval(ToolBase, ABC):
component_name = "Retrieval"
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
async def _invoke_async(self, **kwargs):
if self.check_if_canceled("Retrieval processing"):
return
if not kwargs.get("query"):
self.set_output("formalized_content", self._param.empty_response)
return
async def _retrieve_kb(self, query_text: str):
kb_ids: list[str] = []
for id in self._param.kb_ids:
if id.find("@") < 0:
@ -124,12 +119,12 @@ class Retrieval(ToolBase, ABC):
if self._param.rerank_id:
rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)
vars = self.get_input_elements_from_text(kwargs["query"])
vars = {k:o["value"] for k,o in vars.items()}
query = self.string_format(kwargs["query"], vars)
vars = self.get_input_elements_from_text(query_text)
vars = {k: o["value"] for k, o in vars.items()}
query = self.string_format(query_text, vars)
doc_ids=[]
if self._param.meta_data_filter!={}:
doc_ids = []
if self._param.meta_data_filter != {}:
metas = DocumentService.get_meta_by_kbs(kb_ids)
def _resolve_manual_filter(flt: dict) -> dict:
@ -198,18 +193,20 @@ class Retrieval(ToolBase, ABC):
if self._param.toc_enhance:
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT)
cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n)
cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs],
chat_mdl, self._param.top_n)
if self.check_if_canceled("Retrieval processing"):
return
if cks:
kbinfos["chunks"] = cks
kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"], [kb.tenant_id for kb in kbs])
kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"],
[kb.tenant_id for kb in kbs])
if self._param.use_kg:
ck = settings.kg_retriever.retrieval(query,
[kb.tenant_id for kb in kbs],
kb_ids,
embd_mdl,
LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT))
[kb.tenant_id for kb in kbs],
kb_ids,
embd_mdl,
LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT))
if self.check_if_canceled("Retrieval processing"):
return
if ck["content_with_weight"]:
@ -218,7 +215,8 @@ class Retrieval(ToolBase, ABC):
kbinfos = {"chunks": [], "doc_aggs": []}
if self._param.use_kg and kbs:
ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl,
LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
if self.check_if_canceled("Retrieval processing"):
return
if ck["content_with_weight"]:
@ -248,6 +246,54 @@ class Retrieval(ToolBase, ABC):
return form_cnt
async def _retrieve_memory(self, query_text: str):
memory_ids: list[str] = [memory_id for memory_id in self._param.memory_ids]
memory_list = MemoryService.get_by_ids(memory_ids)
if not memory_list:
raise Exception("No memory is selected.")
embd_names = list({memory.embd_id for memory in memory_list})
assert len(embd_names) == 1, "Memory use different embedding models."
vars = self.get_input_elements_from_text(query_text)
vars = {k: o["value"] for k, o in vars.items()}
query = self.string_format(query_text, vars)
# query message
message_list = query_message({"memory_id": memory_ids}, {
"query": query,
"similarity_threshold": self._param.similarity_threshold,
"keywords_similarity_weight": self._param.keywords_similarity_weight,
"top_n": self._param.top_n
})
print(f"found {len(message_list)} messages.")
if not message_list:
self.set_output("formalized_content", self._param.empty_response)
return
formated_content = "\n".join(memory_prompt(message_list, 200000))
# set formalized_content output
self.set_output("formalized_content", formated_content)
print(f"formated_content {formated_content}")
return formated_content
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
async def _invoke_async(self, **kwargs):
if self.check_if_canceled("Retrieval processing"):
return
print(f"debug retrieval, query is {kwargs.get('query')}.", flush=True)
if not kwargs.get("query"):
self.set_output("formalized_content", self._param.empty_response)
return
if self._param.kb_ids:
return await self._retrieve_kb(kwargs["query"])
elif self._param.memory_ids:
return await self._retrieve_memory(kwargs["query"])
else:
self.set_output("formalized_content", self._param.empty_response)
return
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
return asyncio.run(self._invoke_async(**kwargs))

View File

@ -108,7 +108,7 @@ def _load_user():
authorization = request.headers.get("Authorization")
g.user = None
if not authorization:
return
return None
try:
access_token = str(jwt.loads(authorization))

View File

@ -192,7 +192,7 @@ async def rerun():
if 0 < doc["progress"] < 1:
return get_data_error_result(message=f"`{doc['name']}` is processing...")
if settings.docStoreConn.indexExist(search.index_name(current_user.id), doc["kb_id"]):
if settings.docStoreConn.index_exist(search.index_name(current_user.id), doc["kb_id"]):
settings.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"])
doc["progress_msg"] = ""
doc["chunk_num"] = 0

View File

@ -76,6 +76,7 @@ async def list_chunk():
"image_id": sres.field[id].get("img_id", ""),
"available_int": int(sres.field[id].get("available_int", 1)),
"positions": sres.field[id].get("position_int", []),
"doc_type_kwd": sres.field[id].get("doc_type_kwd")
}
assert isinstance(d["positions"], list)
assert len(d["positions"]) == 0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5)
@ -178,8 +179,9 @@ async def set():
# update image
image_base64 = req.get("image_base64", None)
if image_base64:
bkt, name = req.get("img_id", "-").split("-")
image_binary = base64.b64decode(image_base64)
settings.STORAGE_IMPL.put(req["doc_id"], req["chunk_id"], image_binary)
settings.STORAGE_IMPL.put(bkt, name, image_binary)
return get_json_result(data=True)
return await asyncio.to_thread(_set_sync)

View File

@ -250,14 +250,50 @@ async def list_docs():
metadata_condition = req.get("metadata_condition", {}) or {}
if metadata_condition and not isinstance(metadata_condition, dict):
return get_data_error_result(message="metadata_condition must be an object.")
metadata = req.get("metadata", {}) or {}
if metadata and not isinstance(metadata, dict):
return get_data_error_result(message="metadata must be an object.")
doc_ids_filter = None
if metadata_condition:
metas = None
if metadata_condition or metadata:
metas = DocumentService.get_flatted_meta_by_kbs([kb_id])
doc_ids_filter = meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))
if metadata_condition:
doc_ids_filter = set(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
if metadata_condition.get("conditions") and not doc_ids_filter:
return get_json_result(data={"total": 0, "docs": []})
if metadata:
metadata_doc_ids = None
for key, values in metadata.items():
if not values:
continue
if not isinstance(values, list):
values = [values]
values = [str(v) for v in values if v is not None and str(v).strip()]
if not values:
continue
key_doc_ids = set()
for value in values:
key_doc_ids.update(metas.get(key, {}).get(value, []))
if metadata_doc_ids is None:
metadata_doc_ids = key_doc_ids
else:
metadata_doc_ids &= key_doc_ids
if not metadata_doc_ids:
return get_json_result(data={"total": 0, "docs": []})
if metadata_doc_ids is not None:
if doc_ids_filter is None:
doc_ids_filter = metadata_doc_ids
else:
doc_ids_filter &= metadata_doc_ids
if not doc_ids_filter:
return get_json_result(data={"total": 0, "docs": []})
if doc_ids_filter is not None:
doc_ids_filter = list(doc_ids_filter)
try:
docs, tol = DocumentService.get_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types, suffix, doc_ids_filter)
@ -411,6 +447,26 @@ async def metadata_update():
return get_json_result(data={"updated": updated, "matched_docs": len(target_doc_ids)})
@manager.route("/update_metadata_setting", methods=["POST"]) # noqa: F821
@login_required
@validate_request("doc_id", "metadata")
async def update_metadata_setting():
req = await get_request_json()
if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
DocumentService.update_parser_config(doc.id, {"metadata": req["metadata"]})
e, doc = DocumentService.get_by_id(doc.id)
if not e:
return get_data_error_result(message="Document not found!")
return get_json_result(data=doc.to_dict())
@manager.route("/thumbnails", methods=["GET"]) # noqa: F821
# @login_required
def thumbnails():
@ -528,7 +584,7 @@ async def run():
DocumentService.update_by_id(id, info)
if req.get("delete", False):
TaskService.filter_delete([Task.doc_id == id])
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
if str(req["run"]) == TaskStatus.RUNNING.value:
@ -579,7 +635,7 @@ async def rename():
"title_tks": title_tks,
"title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
}
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.update(
{"doc_id": req["doc_id"]},
es_body,
@ -660,7 +716,7 @@ async def change_parser():
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
return None

View File

@ -39,9 +39,9 @@ from api.utils.api_utils import get_json_result
from rag.nlp import search
from api.constants import DATASET_NAME_LIMIT
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.doc_store_conn import OrderByExpr
from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType, PAGERANK_FLD
from common import settings
from common.doc_store.doc_store_base import OrderByExpr
from api.apps import login_required, current_user
@ -97,6 +97,19 @@ async def update():
code=RetCode.OPERATING_ERROR)
e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
# Rename folder in FileService
if e and req["name"].lower() != kb.name.lower():
FileService.filter_update(
[
File.tenant_id == kb.tenant_id,
File.source_type == FileSource.KNOWLEDGEBASE,
File.type == "folder",
File.name == kb.name,
],
{"name": req["name"]},
)
if not e:
return get_data_error_result(
message="Can't find this dataset!")
@ -150,6 +163,21 @@ async def update():
return server_error_response(e)
@manager.route('/update_metadata_setting', methods=['post']) # noqa: F821
@login_required
@validate_request("kb_id", "metadata")
async def update_metadata_setting():
req = await get_request_json()
e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
if not e:
return get_data_error_result(
message="Database error (Knowledgebase rename)!")
kb = kb.to_dict()
kb["parser_config"]["metadata"] = req["metadata"]
KnowledgebaseService.update_by_id(kb["id"], kb)
return get_json_result(data=kb)
@manager.route('/detail', methods=['GET']) # noqa: F821
@login_required
def detail():
@ -245,13 +273,19 @@ async def rm():
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
File2DocumentService.delete_by_document_id(doc.id)
FileService.filter_delete(
[File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kbs[0].name])
[
File.tenant_id == kbs[0].tenant_id,
File.source_type == FileSource.KNOWLEDGEBASE,
File.type == "folder",
File.name == kbs[0].name,
]
)
if not KnowledgebaseService.delete_by_id(req["kb_id"]):
return get_data_error_result(
message="Database error (Knowledgebase removal)!")
for kb in kbs:
settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
settings.docStoreConn.delete_idx(search.index_name(kb.tenant_id), kb.id)
if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
settings.STORAGE_IMPL.remove_bucket(kb.id)
return get_json_result(data=True)
@ -352,7 +386,7 @@ def knowledge_graph(kb_id):
}
obj = {"graph": {}, "mind_map": {}}
if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id):
if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), kb_id):
return get_json_result(data=obj)
sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
if not len(sres.ids):
@ -824,11 +858,11 @@ async def check_embedding():
index_nm = search.index_name(tenant_id)
res0 = docStoreConn.search(
selectFields=[], highlightFields=[],
select_fields=[], highlight_fields=[],
condition={"kb_id": kb_id, "available_int": 1},
matchExprs=[], orderBy=OrderByExpr(),
match_expressions=[], order_by=OrderByExpr(),
offset=0, limit=1,
indexNames=index_nm, knowledgebaseIds=[kb_id]
index_names=index_nm, knowledgebase_ids=[kb_id]
)
total = docStoreConn.get_total(res0)
if total <= 0:
@ -840,14 +874,14 @@ async def check_embedding():
for off in offsets:
res1 = docStoreConn.search(
selectFields=list(base_fields),
highlightFields=[],
select_fields=list(base_fields),
highlight_fields=[],
condition={"kb_id": kb_id, "available_int": 1},
matchExprs=[], orderBy=OrderByExpr(),
match_expressions=[], order_by=OrderByExpr(),
offset=off, limit=1,
indexNames=index_nm, knowledgebaseIds=[kb_id]
index_names=index_nm, knowledgebase_ids=[kb_id]
)
ids = docStoreConn.get_chunk_ids(res1)
ids = docStoreConn.get_doc_ids(res1)
if not ids:
continue

View File

@ -25,7 +25,7 @@ from api.utils.api_utils import get_allowed_llm_factories, get_data_error_result
from common.constants import StatusEnum, LLMType
from api.db.db_models import TenantLLM
from rag.utils.base64_image import test_image
from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel, OcrModel
from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel, OcrModel, Seq2txtModel
@manager.route("/factories", methods=["GET"]) # noqa: F821
@ -157,7 +157,7 @@ async def add_llm():
elif factory == "Bedrock":
# For Bedrock, due to its special authentication method
# Assemble bedrock_ak, bedrock_sk, bedrock_region
api_key = apikey_json(["bedrock_ak", "bedrock_sk", "bedrock_region"])
api_key = apikey_json(["auth_mode", "bedrock_ak", "bedrock_sk", "bedrock_region", "aws_role_arn"])
elif factory == "LocalAI":
llm_name += "___LocalAI"
@ -208,70 +208,83 @@ async def add_llm():
msg = ""
mdl_nm = llm["llm_name"].split("___")[0]
extra = {"provider": factory}
if llm["model_type"] == LLMType.EMBEDDING.value:
assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
mdl = EmbeddingModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
try:
arr, tc = mdl.encode(["Test if the api key is available"])
if len(arr[0]) == 0:
raise Exception("Fail")
except Exception as e:
msg += f"\nFail to access embedding model({mdl_nm})." + str(e)
elif llm["model_type"] == LLMType.CHAT.value:
assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
mdl = ChatModel[factory](
key=llm["api_key"],
model_name=mdl_nm,
base_url=llm["api_base"],
**extra,
)
try:
m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
if not tc and m.find("**ERROR**:") >= 0:
raise Exception(m)
except Exception as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
elif llm["model_type"] == LLMType.RERANK:
assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet."
try:
mdl = RerankModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
arr, tc = mdl.similarity("Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"])
if len(arr) == 0:
raise Exception("Not known.")
except KeyError:
msg += f"{factory} dose not support this model({factory}/{mdl_nm})"
except Exception as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
elif llm["model_type"] == LLMType.IMAGE2TEXT.value:
assert factory in CvModel, f"Image to text model from {factory} is not supported yet."
mdl = CvModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
try:
image_data = test_image
m, tc = mdl.describe(image_data)
if not tc and m.find("**ERROR**:") >= 0:
raise Exception(m)
except Exception as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
elif llm["model_type"] == LLMType.TTS:
assert factory in TTSModel, f"TTS model from {factory} is not supported yet."
mdl = TTSModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
try:
for resp in mdl.tts("Hello~ RAGFlower!"):
pass
except RuntimeError as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
elif llm["model_type"] == LLMType.OCR.value:
assert factory in OcrModel, f"OCR model from {factory} is not supported yet."
try:
mdl = OcrModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm.get("api_base", ""))
ok, reason = mdl.check_available()
if not ok:
raise RuntimeError(reason or "Model not available")
except Exception as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
else:
# TODO: check other type of models
pass
model_type = llm["model_type"]
model_api_key = llm["api_key"]
model_base_url = llm.get("api_base", "")
match model_type:
case LLMType.EMBEDDING.value:
assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
mdl = EmbeddingModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
try:
arr, tc = mdl.encode(["Test if the api key is available"])
if len(arr[0]) == 0:
raise Exception("Fail")
except Exception as e:
msg += f"\nFail to access embedding model({mdl_nm})." + str(e)
case LLMType.CHAT.value:
assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
mdl = ChatModel[factory](
key=model_api_key,
model_name=mdl_nm,
base_url=model_base_url,
**extra,
)
try:
m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}],
{"temperature": 0.9})
if not tc and m.find("**ERROR**:") >= 0:
raise Exception(m)
except Exception as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
case LLMType.RERANK.value:
assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet."
try:
mdl = RerankModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
arr, tc = mdl.similarity("Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"])
if len(arr) == 0:
raise Exception("Not known.")
except KeyError:
msg += f"{factory} dose not support this model({factory}/{mdl_nm})"
except Exception as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
case LLMType.IMAGE2TEXT.value:
assert factory in CvModel, f"Image to text model from {factory} is not supported yet."
mdl = CvModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
try:
image_data = test_image
m, tc = mdl.describe(image_data)
if not tc and m.find("**ERROR**:") >= 0:
raise Exception(m)
except Exception as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
case LLMType.TTS.value:
assert factory in TTSModel, f"TTS model from {factory} is not supported yet."
mdl = TTSModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
try:
for resp in mdl.tts("Hello~ RAGFlower!"):
pass
except RuntimeError as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
case LLMType.OCR.value:
assert factory in OcrModel, f"OCR model from {factory} is not supported yet."
try:
mdl = OcrModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
ok, reason = mdl.check_available()
if not ok:
raise RuntimeError(reason or "Model not available")
except Exception as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
case LLMType.SPEECH2TEXT:
assert factory in Seq2txtModel, f"Speech model from {factory} is not supported yet."
try:
mdl = Seq2txtModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
# TODO: check the availability
except Exception as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
case _:
raise RuntimeError(f"Unknown model type: {model_type}")
if msg:
return get_data_error_result(message=msg)

View File

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
from quart import Response, request
from api.apps import current_user, login_required
@ -106,7 +108,7 @@ async def create() -> Response:
return get_data_error_result(message="Tenant not found.")
mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
server_tools, err_message = get_mcp_tools([mcp_server], timeout)
server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
if err_message:
return get_data_error_result(err_message)
@ -158,7 +160,7 @@ async def update() -> Response:
req["id"] = mcp_id
mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
server_tools, err_message = get_mcp_tools([mcp_server], timeout)
server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
if err_message:
return get_data_error_result(err_message)
@ -242,7 +244,7 @@ async def import_multiple() -> Response:
headers = {"authorization_token": config["authorization_token"]} if "authorization_token" in config else {}
variables = {k: v for k, v in config.items() if k not in {"type", "url", "headers"}}
mcp_server = MCPServer(id=new_name, name=new_name, url=config["url"], server_type=config["type"], variables=variables, headers=headers)
server_tools, err_message = get_mcp_tools([mcp_server], timeout)
server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
if err_message:
results.append({"server": base_name, "success": False, "message": err_message})
continue
@ -322,9 +324,8 @@ async def list_tools() -> Response:
tool_call_sessions.append(tool_call_session)
try:
tools = tool_call_session.get_tools(timeout)
tools = await asyncio.to_thread(tool_call_session.get_tools, timeout)
except Exception as e:
tools = []
return get_data_error_result(message=f"MCP list tools error: {e}")
results[server_key] = []
@ -340,7 +341,7 @@ async def list_tools() -> Response:
return server_error_response(e)
finally:
# PERF: blocking call to close sessions — consider moving to background thread or task queue
close_multiple_mcp_toolcall_sessions(tool_call_sessions)
await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
@manager.route("/test_tool", methods=["POST"]) # noqa: F821
@ -367,10 +368,10 @@ async def test_tool() -> Response:
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
tool_call_sessions.append(tool_call_session)
result = tool_call_session.tool_call(tool_name, arguments, timeout)
result = await asyncio.to_thread(tool_call_session.tool_call, tool_name, arguments, timeout)
# PERF: blocking call to close sessions — consider moving to background thread or task queue
close_multiple_mcp_toolcall_sessions(tool_call_sessions)
await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
return get_json_result(data=result)
except Exception as e:
return server_error_response(e)
@ -424,13 +425,12 @@ async def test_mcp() -> Response:
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
try:
tools = tool_call_session.get_tools(timeout)
tools = await asyncio.to_thread(tool_call_session.get_tools, timeout)
except Exception as e:
tools = []
return get_data_error_result(message=f"Test MCP error: {e}")
finally:
# PERF: blocking call to close sessions — consider moving to background thread or task queue
close_multiple_mcp_toolcall_sessions([tool_call_session])
await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, [tool_call_session])
for tool in tools:
tool_dict = tool.model_dump()

View File

@ -20,10 +20,12 @@ from api.apps import login_required, current_user
from api.db import TenantPermission
from api.db.services.memory_service import MemoryService
from api.db.services.user_service import UserTenantService
from api.db.services.canvas_service import UserCanvasService
from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result, \
not_allowed_parameters
from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human
from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT
from memory.services.messages import MessageService
from common.constants import MemoryType, RetCode, ForgettingPolicy
@ -57,7 +59,6 @@ async def create_memory():
if res:
return get_json_result(message=True, data=format_ret_data_from_memory(memory))
else:
return get_json_result(message=memory, code=RetCode.SERVER_ERROR)
@ -124,7 +125,7 @@ async def update_memory(memory_id):
return get_json_result(message=True, data=memory_dict)
try:
MemoryService.update_memory(memory_id, to_update)
MemoryService.update_memory(current_memory.tenant_id, memory_id, to_update)
updated_memory = MemoryService.get_by_memory_id(memory_id)
return get_json_result(message=True, data=format_ret_data_from_memory(updated_memory))
@ -133,7 +134,7 @@ async def update_memory(memory_id):
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
@manager.route("/<memory_id>", methods=["DELETE"]) # noqa: F821
@manager.route("/<memory_id>", methods=["DELETE"]) # noqa: F821
@login_required
async def delete_memory(memory_id):
memory = MemoryService.get_by_memory_id(memory_id)
@ -141,13 +142,14 @@ async def delete_memory(memory_id):
return get_json_result(message=True, code=RetCode.NOT_FOUND)
try:
MemoryService.delete_memory(memory_id)
MessageService.delete_message({"memory_id": memory_id}, memory.tenant_id, memory_id)
return get_json_result(message=True)
except Exception as e:
logging.error(e)
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
@manager.route("", methods=["GET"]) # noqa: F821
@manager.route("", methods=["GET"]) # noqa: F821
@login_required
async def list_memory():
args = request.args
@ -183,3 +185,26 @@ async def get_memory_config(memory_id):
if not memory:
return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
return get_json_result(message=True, data=format_ret_data_from_memory(memory))
@manager.route("/<memory_id>", methods=["GET"]) # noqa: F821
@login_required
async def get_memory_detail(memory_id):
args = request.args
agent_ids = args.getlist("agent_id")
keywords = args.get("keywords", "")
keywords = keywords.strip()
page = int(args.get("page", 1))
page_size = int(args.get("page_size", 50))
memory = MemoryService.get_by_memory_id(memory_id)
if not memory:
return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
messages = MessageService.list_message(
memory.tenant_id, memory_id, agent_ids, keywords, page, page_size)
agent_name_mapping = {}
if messages["message_list"]:
agent_list = UserCanvasService.get_basic_info_by_canvas_ids([message["agent_id"] for message in messages["message_list"]])
agent_name_mapping = {agent["id"]: agent["title"] for agent in agent_list}
for message in messages["message_list"]:
message["agent_name"] = agent_name_mapping.get(message["agent_id"], "Unknown")
return get_json_result(data={"messages": messages, "storage_type": memory.storage_type}, message=True)

169
api/apps/messages_app.py Normal file
View File

@ -0,0 +1,169 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from quart import request
from api.apps import login_required
from api.db.services.memory_service import MemoryService
from common.time_utils import current_timestamp, timestamp_to_date
from memory.services.messages import MessageService
from api.db.joint_services import memory_message_service
from api.db.joint_services.memory_message_service import query_message
from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result
from common.constants import RetCode
@manager.route("", methods=["POST"]) # noqa: F821
@login_required
@validate_request("memory_id", "agent_id", "session_id", "user_input", "agent_response")
async def add_message():
req = await get_request_json()
memory_ids = req["memory_id"]
agent_id = req["agent_id"]
session_id = req["session_id"]
user_id = req["user_id"] if req.get("user_id") else ""
user_input = req["user_input"]
agent_response = req["agent_response"]
res = []
for memory_id in memory_ids:
success, msg = await memory_message_service.save_to_memory(
memory_id,
{
"user_id": user_id,
"agent_id": agent_id,
"session_id": session_id,
"user_input": user_input,
"agent_response": agent_response
}
)
res.append({
"memory_id": memory_id,
"success": success,
"message": msg
})
if all([r["success"] for r in res]):
return get_json_result(message="Successfully added to memories.")
return get_json_result(code=RetCode.SERVER_ERROR, message="Some messages failed to add.", data=res)
@manager.route("/<memory_id>:<message_id>", methods=["DELETE"]) # noqa: F821
@login_required
async def forget_message(memory_id: str, message_id: int):
memory = MemoryService.get_by_memory_id(memory_id)
if not memory:
return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
forget_time = timestamp_to_date(current_timestamp())
update_succeed = MessageService.update_message(
{"memory_id": memory_id, "message_id": int(message_id)},
{"forget_at": forget_time},
memory.tenant_id, memory_id)
if update_succeed:
return get_json_result(message=update_succeed)
else:
return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to forget message '{message_id}' in memory '{memory_id}'.")
@manager.route("/<memory_id>:<message_id>", methods=["PUT"]) # noqa: F821
@login_required
@validate_request("status")
async def update_message(memory_id: str, message_id: int):
req = await get_request_json()
status = req["status"]
if not isinstance(status, bool):
return get_error_argument_result("Status must be a boolean.")
memory = MemoryService.get_by_memory_id(memory_id)
if not memory:
return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
update_succeed = MessageService.update_message({"memory_id": memory_id, "message_id": int(message_id)}, {"status": status}, memory.tenant_id, memory_id)
if update_succeed:
return get_json_result(message=update_succeed)
else:
return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to set status for message '{message_id}' in memory '{memory_id}'.")
@manager.route("/search", methods=["GET"]) # noqa: F821
@login_required
async def search_message():
args = request.args
print(args, flush=True)
empty_fields = [f for f in ["memory_id", "query"] if not args.get(f)]
if empty_fields:
return get_error_argument_result(f"{', '.join(empty_fields)} can't be empty.")
memory_ids = args.getlist("memory_id")
query = args.get("query")
similarity_threshold = float(args.get("similarity_threshold", 0.2))
keywords_similarity_weight = float(args.get("keywords_similarity_weight", 0.7))
top_n = int(args.get("top_n", 5))
agent_id = args.get("agent_id", "")
session_id = args.get("session_id", "")
filter_dict = {
"memory_id": memory_ids,
"agent_id": agent_id,
"session_id": session_id
}
params = {
"query": query,
"similarity_threshold": similarity_threshold,
"keywords_similarity_weight": keywords_similarity_weight,
"top_n": top_n
}
res = query_message(filter_dict, params)
return get_json_result(message=True, data=res)
@manager.route("", methods=["GET"]) # noqa: F821
@login_required
async def get_messages():
args = request.args
memory_ids = args.getlist("memory_id")
agent_id = args.get("agent_id", "")
session_id = args.get("session_id", "")
limit = int(args.get("limit", 10))
if not memory_ids:
return get_error_argument_result("memory_ids is required.")
memory_list = MemoryService.get_by_ids(memory_ids)
uids = [memory.tenant_id for memory in memory_list]
res = MessageService.get_recent_messages(
uids,
memory_ids,
agent_id,
session_id,
limit
)
return get_json_result(message=True, data=res)
@manager.route("/<memory_id>:<message_id>/content", methods=["GET"]) # noqa: F821
@login_required
async def get_message_content(memory_id:str, message_id: int):
memory = MemoryService.get_by_memory_id(memory_id)
if not memory:
return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
res = MessageService.get_by_message_id(memory_id, message_id, memory.tenant_id)
if res:
return get_json_result(message=True, data=res)
else:
return get_json_result(code=RetCode.NOT_FOUND, message=f"Message '{message_id}' in memory '{memory_id}' not found.")

View File

@ -14,20 +14,29 @@
# limitations under the License.
#
import asyncio
import base64
import hashlib
import hmac
import ipaddress
import json
import logging
import time
from typing import Any, cast
import jwt
from agent.canvas import Canvas
from api.db import CanvasCategory
from api.db.services.canvas_service import UserCanvasService
from api.db.services.file_service import FileService
from api.db.services.user_canvas_version import UserCanvasVersionService
from common.constants import RetCode
from common.misc_utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, get_request_json, token_required
from api.utils.api_utils import get_result
from quart import request, Response
from rag.utils.redis_conn import REDIS_CONN
@manager.route('/agents', methods=['GET']) # noqa: F821
@ -132,48 +141,785 @@ def delete_agent(tenant_id: str, agent_id: str):
UserCanvasService.delete_by_id(agent_id)
return get_json_result(data=True)
@manager.route("/webhook/<agent_id>", methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"]) # noqa: F821
@manager.route("/webhook_test/<agent_id>",methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"],) # noqa: F821
async def webhook(agent_id: str):
is_test = request.path.startswith("/api/v1/webhook_test")
start_ts = time.time()
@manager.route('/webhook/<agent_id>', methods=['POST']) # noqa: F821
@token_required
async def webhook(tenant_id: str, agent_id: str):
req = await get_request_json()
if not UserCanvasService.accessible(req["id"], tenant_id):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
code=RetCode.OPERATING_ERROR)
e, cvs = UserCanvasService.get_by_id(req["id"])
if not e:
return get_data_error_result(message="canvas not found.")
if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
# 1. Fetch canvas by agent_id
exists, cvs = UserCanvasService.get_by_id(agent_id)
if not exists:
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Canvas not found."),RetCode.BAD_REQUEST
# 2. Check canvas category
if cvs.canvas_category == CanvasCategory.DataFlow:
return get_data_error_result(message="Dataflow can not be triggered by webhook.")
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Dataflow can not be triggered by webhook."),RetCode.BAD_REQUEST
# 3. Load DSL from canvas
dsl = getattr(cvs, "dsl", None)
if not isinstance(dsl, dict):
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Invalid DSL format."),RetCode.BAD_REQUEST
# 4. Check webhook configuration in DSL
components = dsl.get("components", {})
for k, _ in components.items():
cpn_obj = components[k]["obj"]
if cpn_obj["component_name"].lower() == "begin" and cpn_obj["params"]["mode"] == "Webhook":
webhook_cfg = cpn_obj["params"]
if not webhook_cfg:
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Webhook not configured for this agent."),RetCode.BAD_REQUEST
# 5. Validate request method against webhook_cfg.methods
allowed_methods = webhook_cfg.get("methods", [])
request_method = request.method.upper()
if allowed_methods and request_method not in allowed_methods:
return get_data_error_result(
code=RetCode.BAD_REQUEST,message=f"HTTP method '{request_method}' not allowed for this webhook."
),RetCode.BAD_REQUEST
# 6. Validate webhook security
async def validate_webhook_security(security_cfg: dict):
"""Validate webhook security rules based on security configuration."""
if not security_cfg:
return # No security config → allowed by default
# 1. Validate max body size
await _validate_max_body_size(security_cfg)
# 2. Validate IP whitelist
_validate_ip_whitelist(security_cfg)
# # 3. Validate rate limiting
_validate_rate_limit(security_cfg)
# 4. Validate authentication
auth_type = security_cfg.get("auth_type", "none")
if auth_type == "none":
return
if auth_type == "token":
_validate_token_auth(security_cfg)
elif auth_type == "basic":
_validate_basic_auth(security_cfg)
elif auth_type == "jwt":
_validate_jwt_auth(security_cfg)
else:
raise Exception(f"Unsupported auth_type: {auth_type}")
async def _validate_max_body_size(security_cfg):
"""Check request size does not exceed max_body_size."""
max_size = security_cfg.get("max_body_size")
if not max_size:
return
# Convert "10MB" → bytes
units = {"kb": 1024, "mb": 1024**2}
size_str = max_size.lower()
for suffix, factor in units.items():
if size_str.endswith(suffix):
limit = int(size_str.replace(suffix, "")) * factor
break
else:
raise Exception("Invalid max_body_size format")
MAX_LIMIT = 10 * 1024 * 1024 # 10MB
if limit > MAX_LIMIT:
raise Exception("max_body_size exceeds maximum allowed size (10MB)")
content_length = request.content_length or 0
if content_length > limit:
raise Exception(f"Request body too large: {content_length} > {limit}")
def _validate_ip_whitelist(security_cfg):
"""Allow only IPs listed in ip_whitelist."""
whitelist = security_cfg.get("ip_whitelist", [])
if not whitelist:
return
client_ip = request.remote_addr
for rule in whitelist:
if "/" in rule:
# CIDR notation
if ipaddress.ip_address(client_ip) in ipaddress.ip_network(rule, strict=False):
return
else:
# Single IP
if client_ip == rule:
return
raise Exception(f"IP {client_ip} is not allowed by whitelist")
def _validate_rate_limit(security_cfg):
"""Simple in-memory rate limiting."""
rl = security_cfg.get("rate_limit")
if not rl:
return
limit = int(rl.get("limit", 60))
if limit <= 0:
raise Exception("rate_limit.limit must be > 0")
per = rl.get("per", "minute")
window = {
"second": 1,
"minute": 60,
"hour": 3600,
"day": 86400,
}.get(per)
if not window:
raise Exception(f"Invalid rate_limit.per: {per}")
capacity = limit
rate = limit / window
cost = 1
key = f"rl:tb:{agent_id}"
now = time.time()
try:
res = REDIS_CONN.lua_token_bucket(
keys=[key],
args=[capacity, rate, now, cost],
client=REDIS_CONN.REDIS,
)
allowed = int(res[0])
if allowed != 1:
raise Exception("Too many requests (rate limit exceeded)")
except Exception as e:
raise Exception(f"Rate limit error: {e}")
def _validate_token_auth(security_cfg):
"""Validate header-based token authentication."""
token_cfg = security_cfg.get("token",{})
header = token_cfg.get("token_header")
token_value = token_cfg.get("token_value")
provided = request.headers.get(header)
if provided != token_value:
raise Exception("Invalid token authentication")
def _validate_basic_auth(security_cfg):
"""Validate HTTP Basic Auth credentials."""
auth_cfg = security_cfg.get("basic_auth", {})
username = auth_cfg.get("username")
password = auth_cfg.get("password")
auth = request.authorization
if not auth or auth.username != username or auth.password != password:
raise Exception("Invalid Basic Auth credentials")
def _validate_jwt_auth(security_cfg):
"""Validate JWT token in Authorization header."""
jwt_cfg = security_cfg.get("jwt", {})
secret = jwt_cfg.get("secret")
if not secret:
raise Exception("JWT secret not configured")
auth_header = request.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "):
raise Exception("Missing Bearer token")
token = auth_header[len("Bearer "):].strip()
if not token:
raise Exception("Empty Bearer token")
alg = (jwt_cfg.get("algorithm") or "HS256").upper()
decode_kwargs = {
"key": secret,
"algorithms": [alg],
}
options = {}
if jwt_cfg.get("audience"):
decode_kwargs["audience"] = jwt_cfg["audience"]
options["verify_aud"] = True
else:
options["verify_aud"] = False
if jwt_cfg.get("issuer"):
decode_kwargs["issuer"] = jwt_cfg["issuer"]
options["verify_iss"] = True
else:
options["verify_iss"] = False
try:
decoded = jwt.decode(
token,
options=options,
**decode_kwargs,
)
except Exception as e:
raise Exception(f"Invalid JWT: {str(e)}")
raw_required_claims = jwt_cfg.get("required_claims", [])
if isinstance(raw_required_claims, str):
required_claims = [raw_required_claims]
elif isinstance(raw_required_claims, (list, tuple, set)):
required_claims = list(raw_required_claims)
else:
required_claims = []
required_claims = [
c for c in required_claims
if isinstance(c, str) and c.strip()
]
RESERVED_CLAIMS = {"exp", "sub", "aud", "iss", "nbf", "iat"}
for claim in required_claims:
if claim in RESERVED_CLAIMS:
raise Exception(f"Reserved JWT claim cannot be required: {claim}")
for claim in required_claims:
if claim not in decoded:
raise Exception(f"Missing JWT claim: {claim}")
return decoded
try:
canvas = Canvas(cvs.dsl, tenant_id, agent_id)
security_config=webhook_cfg.get("security", {})
await validate_webhook_security(security_config)
except Exception as e:
return get_json_result(
data=False, message=str(e),
code=RetCode.EXCEPTION_ERROR)
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)),RetCode.BAD_REQUEST
if not isinstance(cvs.dsl, str):
dsl = json.dumps(cvs.dsl, ensure_ascii=False)
try:
canvas = Canvas(dsl, cvs.user_id, agent_id)
except Exception as e:
resp=get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e))
resp.status_code = RetCode.BAD_REQUEST
return resp
# 7. Parse request body
async def parse_webhook_request(content_type):
"""Parse request based on content-type and return structured data."""
# 1. Query
query_data = {k: v for k, v in request.args.items()}
# 2. Headers
header_data = {k: v for k, v in request.headers.items()}
# 3. Body
ctype = request.headers.get("Content-Type", "").split(";")[0].strip()
if ctype and ctype != content_type:
raise ValueError(
f"Invalid Content-Type: expect '{content_type}', got '{ctype}'"
)
body_data: dict = {}
async def sse():
nonlocal canvas
try:
async for ans in canvas.run(query=req.get("query", ""), files=req.get("files", []), user_id=req.get("user_id", tenant_id), webhook_payload=req):
yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
if ctype == "application/json":
body_data = await request.get_json() or {}
cvs.dsl = json.loads(str(canvas))
UserCanvasService.update_by_id(req["id"], cvs.to_dict())
except Exception as e:
logging.exception(e)
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": False}, ensure_ascii=False) + "\n\n"
elif ctype == "multipart/form-data":
nonlocal canvas
form = await request.form
files = await request.files
resp = Response(sse(), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp
body_data = {}
for key, value in form.items():
body_data[key] = value
if len(files) > 10:
raise Exception("Too many uploaded files")
for key, file in files.items():
desc = FileService.upload_info(
cvs.user_id, # user
file, # FileStorage
None # url (None for webhook)
)
file_parsed= await canvas.get_files_async([desc])
body_data[key] = file_parsed
elif ctype == "application/x-www-form-urlencoded":
form = await request.form
body_data = dict(form)
else:
# text/plain / octet-stream / empty / unknown
raw = await request.get_data()
if raw:
try:
body_data = json.loads(raw.decode("utf-8"))
except Exception:
body_data = {}
else:
body_data = {}
except Exception:
body_data = {}
return {
"query": query_data,
"headers": header_data,
"body": body_data,
"content_type": ctype,
}
def extract_by_schema(data, schema, name="section"):
"""
Extract only fields defined in schema.
Required fields must exist.
Optional fields default to type-based default values.
Type validation included.
"""
props = schema.get("properties", {})
required = schema.get("required", [])
extracted = {}
for field, field_schema in props.items():
field_type = field_schema.get("type")
# 1. Required field missing
if field in required and field not in data:
raise Exception(f"{name} missing required field: {field}")
# 2. Optional → default value
if field not in data:
extracted[field] = default_for_type(field_type)
continue
raw_value = data[field]
# 3. Auto convert value
try:
value = auto_cast_value(raw_value, field_type)
except Exception as e:
raise Exception(f"{name}.{field} auto-cast failed: {str(e)}")
# 4. Type validation
if not validate_type(value, field_type):
raise Exception(
f"{name}.{field} type mismatch: expected {field_type}, got {type(value).__name__}"
)
extracted[field] = value
return extracted
def default_for_type(t):
"""Return default value for the given schema type."""
if t == "file":
return []
if t == "object":
return {}
if t == "boolean":
return False
if t == "number":
return 0
if t == "string":
return ""
if t and t.startswith("array"):
return []
if t == "null":
return None
return None
def auto_cast_value(value, expected_type):
"""Convert string values into schema type when possible."""
# Non-string values already good
if not isinstance(value, str):
return value
v = value.strip()
# Boolean
if expected_type == "boolean":
if v.lower() in ["true", "1"]:
return True
if v.lower() in ["false", "0"]:
return False
raise Exception(f"Cannot convert '{value}' to boolean")
# Number
if expected_type == "number":
# integer
if v.isdigit() or (v.startswith("-") and v[1:].isdigit()):
return int(v)
# float
try:
return float(v)
except Exception:
raise Exception(f"Cannot convert '{value}' to number")
# Object
if expected_type == "object":
try:
parsed = json.loads(v)
if isinstance(parsed, dict):
return parsed
else:
raise Exception("JSON is not an object")
except Exception:
raise Exception(f"Cannot convert '{value}' to object")
# Array <T>
if expected_type.startswith("array"):
try:
parsed = json.loads(v)
if isinstance(parsed, list):
return parsed
else:
raise Exception("JSON is not an array")
except Exception:
raise Exception(f"Cannot convert '{value}' to array")
# String (accept original)
if expected_type == "string":
return value
# File
if expected_type == "file":
return value
# Default: do nothing
return value
def validate_type(value, t):
"""Validate value type against schema type t."""
if t == "file":
return isinstance(value, list)
if t == "string":
return isinstance(value, str)
if t == "number":
return isinstance(value, (int, float))
if t == "boolean":
return isinstance(value, bool)
if t == "object":
return isinstance(value, dict)
# array<string> / array<number> / array<object>
if t.startswith("array"):
if not isinstance(value, list):
return False
if "<" in t and ">" in t:
inner = t[t.find("<") + 1 : t.find(">")]
# Check each element type
for item in value:
if not validate_type(item, inner):
return False
return True
return True
parsed = await parse_webhook_request(webhook_cfg.get("content_types"))
SCHEMA = webhook_cfg.get("schema", {"query": {}, "headers": {}, "body": {}})
# Extract strictly by schema
try:
query_clean = extract_by_schema(parsed["query"], SCHEMA.get("query", {}), name="query")
header_clean = extract_by_schema(parsed["headers"], SCHEMA.get("headers", {}), name="headers")
body_clean = extract_by_schema(parsed["body"], SCHEMA.get("body", {}), name="body")
except Exception as e:
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)),RetCode.BAD_REQUEST
clean_request = {
"query": query_clean,
"headers": header_clean,
"body": body_clean,
"input": parsed
}
execution_mode = webhook_cfg.get("execution_mode", "Immediately")
response_cfg = webhook_cfg.get("response", {})
def append_webhook_trace(agent_id: str, start_ts: float,event: dict, ttl=600):
key = f"webhook-trace-{agent_id}-logs"
raw = REDIS_CONN.get(key)
obj = json.loads(raw) if raw else {"webhooks": {}}
ws = obj["webhooks"].setdefault(
str(start_ts),
{"start_ts": start_ts, "events": []}
)
ws["events"].append({
"ts": time.time(),
**event
})
REDIS_CONN.set_obj(key, obj, ttl)
if execution_mode == "Immediately":
status = response_cfg.get("status", 200)
try:
status = int(status)
except (TypeError, ValueError):
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(f"Invalid response status code: {status}")),RetCode.BAD_REQUEST
if not (200 <= status <= 399):
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(f"Invalid response status code: {status}, must be between 200 and 399")),RetCode.BAD_REQUEST
body_tpl = response_cfg.get("body_template", "")
def parse_body(body: str):
if not body:
return None, "application/json"
try:
parsed = json.loads(body)
return parsed, "application/json"
except (json.JSONDecodeError, TypeError):
return body, "text/plain"
body, content_type = parse_body(body_tpl)
resp = Response(
json.dumps(body, ensure_ascii=False) if content_type == "application/json" else body,
status=status,
content_type=content_type,
)
async def background_run():
try:
async for ans in canvas.run(
query="",
user_id=cvs.user_id,
webhook_payload=clean_request
):
if is_test:
append_webhook_trace(agent_id, start_ts, ans)
if is_test:
append_webhook_trace(
agent_id,
start_ts,
{
"event": "finished",
"elapsed_time": time.time() - start_ts,
"success": True,
}
)
cvs.dsl = json.loads(str(canvas))
UserCanvasService.update_by_id(cvs.user_id, cvs.to_dict())
except Exception as e:
logging.exception("Webhook background run failed")
if is_test:
try:
append_webhook_trace(
agent_id,
start_ts,
{
"event": "error",
"message": str(e),
"error_type": type(e).__name__,
}
)
append_webhook_trace(
agent_id,
start_ts,
{
"event": "finished",
"elapsed_time": time.time() - start_ts,
"success": False,
}
)
except Exception:
logging.exception("Failed to append webhook trace")
asyncio.create_task(background_run())
return resp
else:
async def sse():
nonlocal canvas
contents: list[str] = []
status = 200
try:
async for ans in canvas.run(
query="",
user_id=cvs.user_id,
webhook_payload=clean_request,
):
if ans["event"] == "message":
content = ans["data"]["content"]
if ans["data"].get("start_to_think", False):
content = "<think>"
elif ans["data"].get("end_to_think", False):
content = "</think>"
if content:
contents.append(content)
if ans["event"] == "message_end":
status = int(ans["data"].get("status", status))
if is_test:
append_webhook_trace(
agent_id,
start_ts,
ans
)
if is_test:
append_webhook_trace(
agent_id,
start_ts,
{
"event": "finished",
"elapsed_time": time.time() - start_ts,
"success": True,
}
)
final_content = "".join(contents)
return {
"message": final_content,
"success": True,
"code": status,
}
except Exception as e:
if is_test:
append_webhook_trace(
agent_id,
start_ts,
{
"event": "error",
"message": str(e),
"error_type": type(e).__name__,
}
)
append_webhook_trace(
agent_id,
start_ts,
{
"event": "finished",
"elapsed_time": time.time() - start_ts,
"success": False,
}
)
return {"code": 400, "message": str(e),"success":False}
result = await sse()
return Response(
json.dumps(result),
status=result["code"],
mimetype="application/json",
)
@manager.route("/webhook_trace/<agent_id>", methods=["GET"]) # noqa: F821
async def webhook_trace(agent_id: str):
def encode_webhook_id(start_ts: str) -> str:
WEBHOOK_ID_SECRET = "webhook_id_secret"
sig = hmac.new(
WEBHOOK_ID_SECRET.encode("utf-8"),
start_ts.encode("utf-8"),
hashlib.sha256,
).digest()
return base64.urlsafe_b64encode(sig).decode("utf-8").rstrip("=")
def decode_webhook_id(enc_id: str, webhooks: dict) -> str | None:
for ts in webhooks.keys():
if encode_webhook_id(ts) == enc_id:
return ts
return None
since_ts = request.args.get("since_ts", type=float)
webhook_id = request.args.get("webhook_id")
key = f"webhook-trace-{agent_id}-logs"
raw = REDIS_CONN.get(key)
if since_ts is None:
now = time.time()
return get_json_result(
data={
"webhook_id": None,
"events": [],
"next_since_ts": now,
"finished": False,
}
)
if not raw:
return get_json_result(
data={
"webhook_id": None,
"events": [],
"next_since_ts": since_ts,
"finished": False,
}
)
obj = json.loads(raw)
webhooks = obj.get("webhooks", {})
if webhook_id is None:
candidates = [
float(k) for k in webhooks.keys() if float(k) > since_ts
]
if not candidates:
return get_json_result(
data={
"webhook_id": None,
"events": [],
"next_since_ts": since_ts,
"finished": False,
}
)
start_ts = min(candidates)
real_id = str(start_ts)
webhook_id = encode_webhook_id(real_id)
return get_json_result(
data={
"webhook_id": webhook_id,
"events": [],
"next_since_ts": start_ts,
"finished": False,
}
)
real_id = decode_webhook_id(webhook_id, webhooks)
if not real_id:
return get_json_result(
data={
"webhook_id": webhook_id,
"events": [],
"next_since_ts": since_ts,
"finished": True,
}
)
ws = webhooks.get(str(real_id))
events = ws.get("events", [])
new_events = [e for e in events if e.get("ts", 0) > since_ts]
next_ts = since_ts
for e in new_events:
next_ts = max(next_ts, e["ts"])
finished = any(e.get("event") == "finished" for e in new_events)
return get_json_result(
data={
"webhook_id": webhook_id,
"events": new_events,
"next_since_ts": next_ts,
"finished": finished,
}
)

View File

@ -287,7 +287,7 @@ def list_chat(tenant_id):
chats = DialogService.get_list(tenant_id, page_number, items_per_page, orderby, desc, id, name)
if not chats:
return get_result(data=[])
list_assts = []
list_assistants = []
key_mapping = {
"parameters": "variables",
"prologue": "opener",
@ -321,5 +321,5 @@ def list_chat(tenant_id):
del res["kb_ids"]
res["datasets"] = kb_list
res["avatar"] = res.pop("icon")
list_assts.append(res)
return get_result(data=list_assts)
list_assistants.append(res)
return get_result(data=list_assistants)

View File

@ -495,7 +495,7 @@ def knowledge_graph(tenant_id, dataset_id):
}
obj = {"graph": {}, "mind_map": {}}
if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id):
if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), dataset_id):
return get_result(data=obj)
sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
if not len(sres.ids):

View File

@ -1080,7 +1080,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
res["chunks"].append(final_chunk)
_ = Chunk(**final_chunk)
elif settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
elif settings.docStoreConn.index_exist(search.index_name(tenant_id), dataset_id):
sres = settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
res["total"] = sres.total
for id in sres.ids:

View File

@ -205,7 +205,8 @@ async def create(tenant_id):
if not FileService.is_parent_folder_exist(pf_id):
return get_json_result(data=False, message="Parent Folder Doesn't Exist!", code=RetCode.BAD_REQUEST)
if FileService.query(name=req["name"], parent_id=pf_id):
return get_json_result(data=False, message="Duplicated folder name in the same folder.", code=409)
return get_json_result(data=False, message="Duplicated folder name in the same folder.",
code=RetCode.CONFLICT)
if input_file_type == FileType.FOLDER.value:
file_type = FileType.FOLDER.value
@ -565,11 +566,13 @@ async def rename(tenant_id):
if file.type != FileType.FOLDER.value and pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
file.name.lower()).suffix:
return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.BAD_REQUEST)
return get_json_result(data=False, message="The extension of file can't be changed",
code=RetCode.BAD_REQUEST)
for existing_file in FileService.query(name=req["name"], pf_id=file.parent_id):
if existing_file.name == req["name"]:
return get_json_result(data=False, message="Duplicated file name in the same folder.", code=409)
return get_json_result(data=False, message="Duplicated file name in the same folder.",
code=RetCode.CONFLICT)
if not FileService.update_by_id(req["file_id"], {"name": req["name"]}):
return get_json_result(message="Database error (File rename)!", code=RetCode.SERVER_ERROR)
@ -631,9 +634,10 @@ async def get(tenant_id, file_id):
except Exception as e:
return server_error_response(e)
@manager.route("/file/download/<attachment_id>", methods=["GET"]) # noqa: F821
@token_required
async def download_attachment(tenant_id,attachment_id):
async def download_attachment(tenant_id, attachment_id):
try:
ext = request.args.get("ext", "markdown")
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, tenant_id, attachment_id)
@ -645,6 +649,7 @@ async def download_attachment(tenant_id,attachment_id):
except Exception as e:
return server_error_response(e)
@manager.route('/file/mv', methods=['POST']) # noqa: F821
@token_required
async def move(tenant_id):

View File

@ -14,6 +14,7 @@
# limitations under the License.
#
import json
import copy
import re
import time
@ -32,7 +33,7 @@ from api.db.services.dialog_service import DialogService, async_ask, async_chat,
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from common.metadata_utils import apply_meta_data_filter
from common.metadata_utils import apply_meta_data_filter, convert_conditions, meta_filter
from api.db.services.search_service import SearchService
from api.db.services.user_service import UserTenantService
from common.misc_utils import get_uuid
@ -128,11 +129,33 @@ async def chat_completion(tenant_id, chat_id):
req = {"question": ""}
if not req.get("session_id"):
req["question"] = ""
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
dia = DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value)
if not dia:
return get_error_data_result(f"You don't own the chat {chat_id}")
dia = dia[0]
if req.get("session_id"):
if not ConversationService.query(id=req["session_id"], dialog_id=chat_id):
return get_error_data_result(f"You don't own the session {req['session_id']}")
metadata_condition = req.get("metadata_condition") or {}
if metadata_condition and not isinstance(metadata_condition, dict):
return get_error_data_result(message="metadata_condition must be an object.")
if metadata_condition and req.get("question"):
metas = DocumentService.get_meta_by_kbs(dia.kb_ids or [])
filtered_doc_ids = meta_filter(
metas,
convert_conditions(metadata_condition),
metadata_condition.get("logic", "and"),
)
if metadata_condition.get("conditions") and not filtered_doc_ids:
filtered_doc_ids = ["-999"]
if filtered_doc_ids:
req["doc_ids"] = ",".join(filtered_doc_ids)
else:
req.pop("doc_ids", None)
if req.get("stream", True):
resp = Response(rag_completion(tenant_id, chat_id, **req), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
@ -195,7 +218,19 @@ async def chat_completion_openai_like(tenant_id, chat_id):
{"role": "user", "content": "Can you tell me how to install neovim"},
],
stream=stream,
extra_body={"reference": reference}
extra_body={
"reference": reference,
"metadata_condition": {
"logic": "and",
"conditions": [
{
"name": "author",
"comparison_operator": "is",
"value": "bob"
}
]
}
}
)
if stream:
@ -211,7 +246,11 @@ async def chat_completion_openai_like(tenant_id, chat_id):
"""
req = await get_request_json()
need_reference = bool(req.get("reference", False))
extra_body = req.get("extra_body") or {}
if extra_body and not isinstance(extra_body, dict):
return get_error_data_result("extra_body must be an object.")
need_reference = bool(extra_body.get("reference", False))
messages = req.get("messages", [])
# To prevent empty [] input
@ -229,6 +268,22 @@ async def chat_completion_openai_like(tenant_id, chat_id):
return get_error_data_result(f"You don't own the chat {chat_id}")
dia = dia[0]
metadata_condition = extra_body.get("metadata_condition") or {}
if metadata_condition and not isinstance(metadata_condition, dict):
return get_error_data_result(message="metadata_condition must be an object.")
doc_ids_str = None
if metadata_condition:
metas = DocumentService.get_meta_by_kbs(dia.kb_ids or [])
filtered_doc_ids = meta_filter(
metas,
convert_conditions(metadata_condition),
metadata_condition.get("logic", "and"),
)
if metadata_condition.get("conditions") and not filtered_doc_ids:
filtered_doc_ids = ["-999"]
doc_ids_str = ",".join(filtered_doc_ids) if filtered_doc_ids else None
# Filter system and non-sense assistant messages
msg = []
for m in messages:
@ -276,14 +331,17 @@ async def chat_completion_openai_like(tenant_id, chat_id):
}
try:
async for ans in async_chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
chat_kwargs = {"toolcall_session": toolcall_session, "tools": tools, "quote": need_reference}
if doc_ids_str:
chat_kwargs["doc_ids"] = doc_ids_str
async for ans in async_chat(dia, msg, True, **chat_kwargs):
last_ans = ans
answer = ans["answer"]
reasoning_match = re.search(r"<think>(.*?)</think>", answer, flags=re.DOTALL)
if reasoning_match:
reasoning_part = reasoning_match.group(1)
content_part = answer[reasoning_match.end():]
content_part = answer[reasoning_match.end() :]
else:
reasoning_part = ""
content_part = answer
@ -328,8 +386,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
response["choices"][0]["delta"]["content"] = None
response["choices"][0]["delta"]["reasoning_content"] = None
response["choices"][0]["finish_reason"] = "stop"
response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used,
"total_tokens": len(prompt) + token_used}
response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, "total_tokens": len(prompt) + token_used}
if need_reference:
response["choices"][0]["delta"]["reference"] = chunks_format(last_ans.get("reference", []))
response["choices"][0]["delta"]["final_content"] = last_ans.get("answer", "")
@ -344,7 +401,10 @@ async def chat_completion_openai_like(tenant_id, chat_id):
return resp
else:
answer = None
async for ans in async_chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
chat_kwargs = {"toolcall_session": toolcall_session, "tools": tools, "quote": need_reference}
if doc_ids_str:
chat_kwargs["doc_ids"] = doc_ids_str
async for ans in async_chat(dia, msg, False, **chat_kwargs):
# focus answer content only
answer = ans
break
@ -388,7 +448,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
@token_required
async def agents_completion_openai_compatibility(tenant_id, agent_id):
req = await get_request_json()
tiktokenenc = tiktoken.get_encoding("cl100k_base")
tiktoken_encode = tiktoken.get_encoding("cl100k_base")
messages = req.get("messages", [])
if not messages:
return get_error_data_result("You must provide at least one message.")
@ -396,7 +456,7 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id):
return get_error_data_result(f"You don't own the agent {agent_id}")
filtered_messages = [m for m in messages if m["role"] in ["user", "assistant"]]
prompt_tokens = sum(len(tiktokenenc.encode(m["content"])) for m in filtered_messages)
prompt_tokens = sum(len(tiktoken_encode.encode(m["content"])) for m in filtered_messages)
if not filtered_messages:
return jsonify(
get_data_openai(
@ -404,7 +464,7 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id):
content="No valid messages found (user or assistant).",
finish_reason="stop",
model=req.get("model", ""),
completion_tokens=len(tiktokenenc.encode("No valid messages found (user or assistant).")),
completion_tokens=len(tiktoken_encode.encode("No valid messages found (user or assistant).")),
prompt_tokens=prompt_tokens,
)
)
@ -441,15 +501,19 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id):
):
return jsonify(response)
return None
@manager.route("/agents/<agent_id>/completions", methods=["POST"]) # noqa: F821
@token_required
async def agent_completions(tenant_id, agent_id):
req = await get_request_json()
return_trace = bool(req.get("return_trace", False))
if req.get("stream", True):
async def generate():
trace_items = []
async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
if isinstance(answer, str):
try:
@ -457,7 +521,21 @@ async def agent_completions(tenant_id, agent_id):
except Exception:
continue
if ans.get("event") not in ["message", "message_end"]:
event = ans.get("event")
if event == "node_finished":
if return_trace:
data = ans.get("data", {})
trace_items.append(
{
"component_id": data.get("component_id"),
"trace": [copy.deepcopy(data)],
}
)
ans.setdefault("data", {})["trace"] = trace_items
answer = "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
yield answer
if event not in ["message", "message_end"]:
continue
yield answer
@ -474,6 +552,7 @@ async def agent_completions(tenant_id, agent_id):
full_content = ""
reference = {}
final_ans = ""
trace_items = []
async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
try:
ans = json.loads(answer[5:])
@ -484,11 +563,22 @@ async def agent_completions(tenant_id, agent_id):
if ans.get("data", {}).get("reference", None):
reference.update(ans["data"]["reference"])
if return_trace and ans.get("event") == "node_finished":
data = ans.get("data", {})
trace_items.append(
{
"component_id": data.get("component_id"),
"trace": [copy.deepcopy(data)],
}
)
final_ans = ans
except Exception as e:
return get_result(data=f"**ERROR**: {str(e)}")
final_ans["data"]["content"] = full_content
final_ans["data"]["reference"] = reference
if return_trace and final_ans:
final_ans["data"]["trace"] = trace_items
return get_result(data=final_ans)
@ -832,6 +922,7 @@ async def chatbot_completions(dialog_id):
async for answer in iframe_completion(dialog_id, **req):
return get_result(data=answer)
return None
@manager.route("/chatbots/<dialog_id>/info", methods=["GET"]) # noqa: F821
async def chatbots_inputs(dialog_id):
@ -879,6 +970,7 @@ async def agent_bot_completions(agent_id):
async for answer in agent_completion(objs[0].tenant_id, agent_id, **req):
return get_result(data=answer)
return None
@manager.route("/agentbots/<agent_id>/inputs", methods=["GET"]) # noqa: F821
async def begin_inputs(agent_id):

View File

@ -660,7 +660,7 @@ def user_register(user_id, user):
tenant_llm = get_init_tenant_llm(user_id)
if not UserService.save(**user):
return
return None
TenantService.insert(**tenant)
UserTenantService.insert(**usr_tenant)
TenantLLMService.insert_many(tenant_llm)

View File

@ -30,6 +30,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm
from api.db.services.user_service import TenantService, UserTenantService
from api.db.joint_services.memory_message_service import init_message_id_sequence, init_memory_size_cache
from common.constants import LLMType
from common.file_utils import get_project_base_directory
from common import settings
@ -169,6 +170,8 @@ def init_web_data():
# init_superuser()
add_graph_templates()
init_message_id_sequence()
init_memory_size_cache()
logging.info("init web data success:{}".format(time.time() - start_time))

View File

@ -0,0 +1,233 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from typing import List
from common.time_utils import current_timestamp, timestamp_to_date, format_iso_8601_to_ymd_hms
from common.constants import MemoryType, LLMType
from common.doc_store.doc_store_base import FusionExpr
from api.db.services.memory_service import MemoryService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.llm_service import LLMBundle
from api.utils.memory_utils import get_memory_type_human
from memory.services.messages import MessageService
from memory.services.query import MsgTextQuery, get_vector
from memory.utils.prompt_util import PromptAssembler
from memory.utils.msg_util import get_json_result_from_llm_response
from rag.utils.redis_conn import REDIS_CONN
async def save_to_memory(memory_id: str, message_dict: dict):
"""
:param memory_id:
:param message_dict: {
"user_id": str,
"agent_id": str,
"session_id": str,
"user_input": str,
"agent_response": str
}
"""
memory = MemoryService.get_by_memory_id(memory_id)
if not memory:
return False, f"Memory '{memory_id}' not found."
tenant_id = memory.tenant_id
extracted_content = await extract_by_llm(
tenant_id,
memory.llm_id,
{"temperature": memory.temperature},
get_memory_type_human(memory.memory_type),
message_dict.get("user_input", ""),
message_dict.get("agent_response", "")
) if memory.memory_type != MemoryType.RAW.value else [] # if only RAW, no need to extract
raw_message_id = REDIS_CONN.generate_auto_increment_id(namespace="memory")
message_list = [{
"message_id": raw_message_id,
"message_type": MemoryType.RAW.name.lower(),
"source_id": 0,
"memory_id": memory_id,
"user_id": "",
"agent_id": message_dict["agent_id"],
"session_id": message_dict["session_id"],
"content": f"User Input: {message_dict.get('user_input')}\nAgent Response: {message_dict.get('agent_response')}",
"valid_at": timestamp_to_date(current_timestamp()),
"invalid_at": None,
"forget_at": None,
"status": True
}, *[{
"message_id": REDIS_CONN.generate_auto_increment_id(namespace="memory"),
"message_type": content["message_type"],
"source_id": raw_message_id,
"memory_id": memory_id,
"user_id": "",
"agent_id": message_dict["agent_id"],
"session_id": message_dict["session_id"],
"content": content["content"],
"valid_at": content["valid_at"],
"invalid_at": content["invalid_at"] if content["invalid_at"] else None,
"forget_at": None,
"status": True
} for content in extracted_content]]
embedding_model = LLMBundle(tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
vector_list, _ = embedding_model.encode([msg["content"] for msg in message_list])
for idx, msg in enumerate(message_list):
msg["content_embed"] = vector_list[idx]
vector_dimension = len(vector_list[0])
if not MessageService.has_index(tenant_id, memory_id):
created = MessageService.create_index(tenant_id, memory_id, vector_size=vector_dimension)
if not created:
return False, "Failed to create message index."
new_msg_size = sum([MessageService.calculate_message_size(m) for m in message_list])
current_memory_size = get_memory_size_cache(memory_id, tenant_id)
if new_msg_size + current_memory_size > memory.memory_size:
size_to_delete = current_memory_size + new_msg_size - memory.memory_size
if memory.forgetting_policy == "fifo":
message_ids_to_delete, delete_size = MessageService.pick_messages_to_delete_by_fifo(memory_id, tenant_id, size_to_delete)
MessageService.delete_message({"message_id": message_ids_to_delete}, tenant_id, memory_id)
decrease_memory_size_cache(memory_id, tenant_id, delete_size)
else:
return False, "Failed to insert message into memory. Memory size reached limit and cannot decide which to delete."
fail_cases = MessageService.insert_message(message_list, tenant_id, memory_id)
if fail_cases:
return False, "Failed to insert message into memory. Details: " + "; ".join(fail_cases)
increase_memory_size_cache(memory_id, tenant_id, new_msg_size)
return True, "Message saved successfully."
async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory_type: List[str], user_input: str,
agent_response: str, system_prompt: str = "", user_prompt: str="") -> List[dict]:
llm_type = TenantLLMService.llm_id2llm_type(llm_id)
if not llm_type:
raise RuntimeError(f"Unknown type of LLM '{llm_id}'")
if not system_prompt:
system_prompt = PromptAssembler.assemble_system_prompt({"memory_type": memory_type})
conversation_content = f"User Input: {user_input}\nAgent Response: {agent_response}"
conversation_time = timestamp_to_date(current_timestamp())
user_prompts = []
if user_prompt:
user_prompts.append({"role": "user", "content": user_prompt})
user_prompts.append({"role": "user", "content": f"Conversation: {conversation_content}\nConversation Time: {conversation_time}\nCurrent Time: {conversation_time}"})
else:
user_prompts.append({"role": "user", "content": PromptAssembler.assemble_user_prompt(conversation_content, conversation_time, conversation_time)})
llm = LLMBundle(tenant_id, llm_type, llm_id)
res = await llm.async_chat(system_prompt, user_prompts, extract_conf)
res_json = get_json_result_from_llm_response(res)
return [{
"content": extracted_content["content"],
"valid_at": format_iso_8601_to_ymd_hms(extracted_content["valid_at"]),
"invalid_at": format_iso_8601_to_ymd_hms(extracted_content["invalid_at"]) if extracted_content.get("invalid_at") else "",
"message_type": message_type
} for message_type, extracted_content_list in res_json.items() for extracted_content in extracted_content_list]
def query_message(filter_dict: dict, params: dict):
"""
:param filter_dict: {
"memory_id": List[str],
"agent_id": optional
"session_id": optional
}
:param params: {
"query": question str,
"similarity_threshold": float,
"keywords_similarity_weight": float,
"top_n": int
}
"""
memory_ids = filter_dict["memory_id"]
memory_list = MemoryService.get_by_ids(memory_ids)
if not memory_list:
return []
condition_dict = {k: v for k, v in filter_dict.items() if v}
uids = [memory.tenant_id for memory in memory_list]
question = params["query"]
question = question.strip()
memory = memory_list[0]
embd_model = LLMBundle(memory.tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
match_dense = get_vector(question, embd_model, similarity=params["similarity_threshold"])
match_text, _ = MsgTextQuery().question(question, min_match=0.3)
keywords_similarity_weight = params.get("keywords_similarity_weight", 0.7)
fusion_expr = FusionExpr("weighted_sum", params["top_n"], {"weights": ",".join([str(keywords_similarity_weight), str(1 - keywords_similarity_weight)])})
return MessageService.search_message(memory_ids, condition_dict, uids, [match_text, match_dense, fusion_expr], params["top_n"])
def init_message_id_sequence():
message_id_redis_key = "id_generator:memory"
if REDIS_CONN.exist(message_id_redis_key):
current_max_id = REDIS_CONN.get(message_id_redis_key)
logging.info(f"No need to init message_id sequence, current max id is {current_max_id}.")
else:
max_id = 1
exist_memory_list = MemoryService.get_all_memory()
if not exist_memory_list:
REDIS_CONN.set(message_id_redis_key, max_id)
else:
max_id = MessageService.get_max_message_id(
uid_list=[m.tenant_id for m in exist_memory_list],
memory_ids=[m.id for m in exist_memory_list]
)
REDIS_CONN.set(message_id_redis_key, max_id)
logging.info(f"Init message_id sequence done, current max id is {max_id}.")
def get_memory_size_cache(memory_id: str, uid: str):
redis_key = f"memory_{memory_id}"
if REDIS_CONN.exists(redis_key):
return REDIS_CONN.get(redis_key)
else:
memory_size_map = MessageService.calculate_memory_size(
[memory_id],
[uid]
)
memory_size = memory_size_map.get(memory_id, 0)
set_memory_size_cache(memory_id, memory_size)
return memory_size
def set_memory_size_cache(memory_id: str, size: int):
redis_key = f"memory_{memory_id}"
return REDIS_CONN.set(redis_key, size)
def increase_memory_size_cache(memory_id: str, uid: str, size: int):
current_value = get_memory_size_cache(memory_id, uid)
return set_memory_size_cache(memory_id, current_value + size)
def decrease_memory_size_cache(memory_id: str, uid: str, size: int):
current_value = get_memory_size_cache(memory_id, uid)
return set_memory_size_cache(memory_id, max(current_value - size, 0))
def init_memory_size_cache():
memory_list = MemoryService.get_all_memory()
if not memory_list:
logging.info("No memory found, no need to init memory size.")
else:
memory_size_map = MessageService.calculate_memory_size(
memory_ids=[m.id for m in memory_list],
uid_list=[m.tenant_id for m in memory_list],
)
for memory in memory_list:
memory_size = memory_size_map.get(memory.id, 0)
set_memory_size_cache(memory.id, memory_size)
logging.info("Memory size cache init done.")

View File

@ -34,6 +34,8 @@ from api.db.services.task_service import TaskService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_canvas_version import UserCanvasVersionService
from api.db.services.user_service import TenantService, UserService, UserTenantService
from api.db.services.memory_service import MemoryService
from memory.services.messages import MessageService
from rag.nlp import search
from common.constants import ActiveEnum
from common import settings
@ -200,7 +202,16 @@ def delete_user_data(user_id: str) -> dict:
done_msg += f"- Deleted {llm_delete_res} tenant-LLM records.\n"
langfuse_delete_res = TenantLangfuseService.delete_ty_tenant_id(tenant_id)
done_msg += f"- Deleted {langfuse_delete_res} langfuse records.\n"
# step1.3 delete own tenant
# step1.3 delete memory and messages
user_memory = MemoryService.get_by_tenant_id(tenant_id)
if user_memory:
for memory in user_memory:
if MessageService.has_index(tenant_id, memory.id):
MessageService.delete_index(tenant_id, memory.id)
done_msg += " Deleted memory index."
memory_delete_res = MemoryService.delete_by_ids([m.id for m in user_memory])
done_msg += f"Deleted {memory_delete_res} memory datasets."
# step1.4 delete own tenant
tenant_delete_res = TenantService.delete_by_id(tenant_id)
done_msg += f"- Deleted {tenant_delete_res} tenant.\n"
# step2 delete user-tenant relation

View File

@ -123,6 +123,19 @@ class UserCanvasService(CommonService):
logging.exception(e)
return False, None
@classmethod
@DB.connection_context()
def get_basic_info_by_canvas_ids(cls, canvas_id):
fields = [
cls.model.id,
cls.model.avatar,
cls.model.user_id,
cls.model.title,
cls.model.permission,
cls.model.canvas_category
]
return cls.model.select(*fields).where(cls.model.id.in_(canvas_id)).dicts()
@classmethod
@DB.connection_context()
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,

View File

@ -169,10 +169,12 @@ class CommonService:
"""
if "id" not in kwargs:
kwargs["id"] = get_uuid()
kwargs["create_time"] = current_timestamp()
kwargs["create_date"] = datetime_format(datetime.now())
kwargs["update_time"] = current_timestamp()
kwargs["update_date"] = datetime_format(datetime.now())
timestamp = current_timestamp()
cur_datetime = datetime_format(datetime.now())
kwargs["create_time"] = timestamp
kwargs["create_date"] = cur_datetime
kwargs["update_time"] = timestamp
kwargs["update_date"] = cur_datetime
sample_obj = cls.model(**kwargs).save(force_insert=True)
return sample_obj
@ -207,10 +209,14 @@ class CommonService:
data_list (list): List of dictionaries containing record data to update.
Each dictionary must include an 'id' field.
"""
timestamp = current_timestamp()
cur_datetime = datetime_format(datetime.now())
for data in data_list:
data["update_time"] = timestamp
data["update_date"] = cur_datetime
with DB.atomic():
for data in data_list:
data["update_time"] = current_timestamp()
data["update_date"] = datetime_format(datetime.now())
cls.model.update(data).where(cls.model.id == data["id"]).execute()
@classmethod

View File

@ -406,7 +406,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
dialog.vector_similarity_weight,
doc_ids=attachments,
top=dialog.top_k,
aggs=False,
aggs=True,
rerank_mdl=rerank_mdl,
rank_feature=label_question(" ".join(questions), kbs),
)
@ -769,7 +769,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
vector_similarity_weight=search_config.get("vector_similarity_weight", 0.3),
top=search_config.get("top_k", 1024),
doc_ids=doc_ids,
aggs=False,
aggs=True,
rerank_mdl=rerank_mdl,
rank_feature=label_question(question, kbs)
)

View File

@ -33,12 +33,13 @@ from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTena
from api.db.db_utils import bulk_insert_into_db
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from common.metadata_utils import dedupe_list
from common.misc_utils import get_uuid
from common.time_utils import current_timestamp, get_format_time
from common.constants import LLMType, ParserType, StatusEnum, TaskStatus, SVR_CONSUMER_GROUP_NAME
from rag.nlp import rag_tokenizer, search
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.doc_store_conn import OrderByExpr
from common.doc_store.doc_store_base import OrderByExpr
from common import settings
@ -180,6 +181,16 @@ class DocumentService(CommonService):
"1": 2,
"2": 2
}
"metadata": {
"key1": {
"key1_value1": 1,
"key1_value2": 2,
},
"key2": {
"key2_value1": 2,
"key2_value2": 1,
},
}
}, total
where "1" => RUNNING, "2" => CANCEL
"""
@ -200,19 +211,40 @@ class DocumentService(CommonService):
if suffix:
query = query.where(cls.model.suffix.in_(suffix))
rows = query.select(cls.model.run, cls.model.suffix)
rows = query.select(cls.model.run, cls.model.suffix, cls.model.meta_fields)
total = rows.count()
suffix_counter = {}
run_status_counter = {}
metadata_counter = {}
for row in rows:
suffix_counter[row.suffix] = suffix_counter.get(row.suffix, 0) + 1
run_status_counter[str(row.run)] = run_status_counter.get(str(row.run), 0) + 1
meta_fields = row.meta_fields or {}
if isinstance(meta_fields, str):
try:
meta_fields = json.loads(meta_fields)
except Exception:
meta_fields = {}
if not isinstance(meta_fields, dict):
continue
for key, value in meta_fields.items():
values = value if isinstance(value, list) else [value]
for vv in values:
if vv is None:
continue
if isinstance(vv, str) and not vv.strip():
continue
sv = str(vv)
if key not in metadata_counter:
metadata_counter[key] = {}
metadata_counter[key][sv] = metadata_counter[key].get(sv, 0) + 1
return {
"suffix": suffix_counter,
"run_status": run_status_counter
"run_status": run_status_counter,
"metadata": metadata_counter,
}, total
@classmethod
@ -314,7 +346,7 @@ class DocumentService(CommonService):
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
page * page_size, page_size, search.index_name(tenant_id),
[doc.kb_id])
chunk_ids = settings.docStoreConn.get_chunk_ids(chunks)
chunk_ids = settings.docStoreConn.get_doc_ids(chunks)
if not chunk_ids:
break
all_chunk_ids.extend(chunk_ids)
@ -665,10 +697,14 @@ class DocumentService(CommonService):
for k,v in r.meta_fields.items():
if k not in meta:
meta[k] = {}
v = str(v)
if v not in meta[k]:
meta[k][v] = []
meta[k][v].append(doc_id)
if not isinstance(v, list):
v = [v]
for vv in v:
if vv not in meta[k]:
if isinstance(vv, list) or isinstance(vv, dict):
continue
meta[k][vv] = []
meta[k][vv].append(doc_id)
return meta
@classmethod
@ -766,7 +802,10 @@ class DocumentService(CommonService):
match_provided = "match" in upd
if isinstance(meta[key], list):
if not match_provided:
meta[key] = new_value
if isinstance(new_value, list):
meta[key] = dedupe_list(new_value)
else:
meta[key] = new_value
changed = True
else:
match_value = upd.get("match")
@ -779,7 +818,7 @@ class DocumentService(CommonService):
else:
new_list.append(item)
if replaced:
meta[key] = new_list
meta[key] = dedupe_list(new_list)
changed = True
else:
if not match_provided:
@ -1199,8 +1238,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
d["q_%d_vec" % len(v)] = v
for b in range(0, len(cks), es_bulk_size):
if try_create_idx:
if not settings.docStoreConn.indexExist(idxnm, kb_id):
settings.docStoreConn.createIdx(idxnm, kb_id, len(vectors[0]))
if not settings.docStoreConn.index_exist(idxnm, kb_id):
settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]))
try_create_idx = False
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)

View File

@ -100,7 +100,7 @@ class FileService(CommonService):
# Returns:
# List of dictionaries containing dataset IDs and names
kbs = (
cls.model.select(*[Knowledgebase.id, Knowledgebase.name])
cls.model.select(*[Knowledgebase.id, Knowledgebase.name, File2Document.document_id])
.join(File2Document, on=(File2Document.file_id == file_id))
.join(Document, on=(File2Document.document_id == Document.id))
.join(Knowledgebase, on=(Knowledgebase.id == Document.kb_id))
@ -110,7 +110,7 @@ class FileService(CommonService):
return []
kbs_info_list = []
for kb in list(kbs.dicts()):
kbs_info_list.append({"kb_id": kb["id"], "kb_name": kb["name"]})
kbs_info_list.append({"kb_id": kb["id"], "kb_name": kb["name"], "document_id": kb["document_id"]})
return kbs_info_list
@classmethod

View File

@ -425,6 +425,7 @@ class KnowledgebaseService(CommonService):
# Update parser_config (always override with validated default/merged config)
payload["parser_config"] = get_parser_config(parser_id, kwargs.get("parser_config"))
payload["parser_config"]["llm_id"] = _t.llm_id
return True, payload

View File

@ -15,7 +15,6 @@
#
from typing import List
from api.apps import current_user
from api.db.db_models import DB, Memory, User
from api.db.services import duplicate_name
from api.db.services.common_service import CommonService
@ -23,6 +22,7 @@ from api.utils.memory_utils import calculate_memory_type
from api.constants import MEMORY_NAME_LIMIT
from common.misc_utils import get_uuid
from common.time_utils import get_format_time, current_timestamp
from memory.utils.prompt_util import PromptAssembler
class MemoryService(CommonService):
@ -34,6 +34,17 @@ class MemoryService(CommonService):
def get_by_memory_id(cls, memory_id: str):
return cls.model.select().where(cls.model.id == memory_id).first()
@classmethod
@DB.connection_context()
def get_by_tenant_id(cls, tenant_id: str):
return cls.model.select().where(cls.model.tenant_id == tenant_id)
@classmethod
@DB.connection_context()
def get_all_memory(cls):
memory_list = cls.model.select()
return list(memory_list)
@classmethod
@DB.connection_context()
def get_with_owner_name_by_id(cls, memory_id: str):
@ -53,7 +64,9 @@ class MemoryService(CommonService):
cls.model.forgetting_policy,
cls.model.temperature,
cls.model.system_prompt,
cls.model.user_prompt
cls.model.user_prompt,
cls.model.create_date,
cls.model.create_time
]
memory = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)).where(
cls.model.id == memory_id
@ -72,7 +85,9 @@ class MemoryService(CommonService):
cls.model.memory_type,
cls.model.storage_type,
cls.model.permissions,
cls.model.description
cls.model.description,
cls.model.create_time,
cls.model.create_date
]
memories = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id))
if filter_dict.get("tenant_id"):
@ -102,6 +117,8 @@ class MemoryService(CommonService):
if len(memory_name) > MEMORY_NAME_LIMIT:
return False, f"Memory name {memory_name} exceeds limit of {MEMORY_NAME_LIMIT}."
timestamp = current_timestamp()
format_time = get_format_time()
# build create dict
memory_info = {
"id": get_uuid(),
@ -110,10 +127,11 @@ class MemoryService(CommonService):
"tenant_id": tenant_id,
"embd_id": embd_id,
"llm_id": llm_id,
"create_time": current_timestamp(),
"create_date": get_format_time(),
"update_time": current_timestamp(),
"update_date": get_format_time(),
"system_prompt": PromptAssembler.assemble_system_prompt({"memory_type": memory_type}),
"create_time": timestamp,
"create_date": format_time,
"update_time": timestamp,
"update_date": format_time,
}
obj = cls.model(**memory_info).save(force_insert=True)
@ -126,7 +144,7 @@ class MemoryService(CommonService):
@classmethod
@DB.connection_context()
def update_memory(cls, memory_id: str, update_dict: dict):
def update_memory(cls, tenant_id: str, memory_id: str, update_dict: dict):
if not update_dict:
return 0
if "temperature" in update_dict and isinstance(update_dict["temperature"], str):
@ -135,7 +153,7 @@ class MemoryService(CommonService):
update_dict["name"] = duplicate_name(
cls.query,
name=update_dict["name"],
tenant_id=current_user.id
tenant_id=tenant_id
)
update_dict.update({
"update_time": current_timestamp(),

View File

@ -97,7 +97,7 @@ class TenantLLMService(CommonService):
if llm_type == LLMType.EMBEDDING.value:
mdlnm = tenant.embd_id if not llm_name else llm_name
elif llm_type == LLMType.SPEECH2TEXT.value:
mdlnm = tenant.asr_id
mdlnm = tenant.asr_id if not llm_name else llm_name
elif llm_type == LLMType.IMAGE2TEXT.value:
mdlnm = tenant.img2txt_id if not llm_name else llm_name
elif llm_type == LLMType.CHAT.value:

View File

@ -163,6 +163,7 @@ def validate_request(*args, **kwargs):
if error_arguments:
error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
return error_string
return None
def wrapper(func):
@wraps(func)
@ -409,7 +410,7 @@ def get_parser_config(chunk_method, parser_config):
if default_config is None:
return deep_merge(base_defaults, parser_config)
# Ensure raptor and graphrag fields have default values if not provided
# Ensure raptor and graph_rag fields have default values if not provided
merged_config = deep_merge(base_defaults, default_config)
merged_config = deep_merge(merged_config, parser_config)

View File

@ -54,6 +54,7 @@ class RetCode(IntEnum, CustomEnum):
SERVER_ERROR = 500
FORBIDDEN = 403
NOT_FOUND = 404
CONFLICT = 409
class StatusEnum(Enum):
@ -124,7 +125,11 @@ class FileSource(StrEnum):
MOODLE = "moodle"
DROPBOX = "dropbox"
BOX = "box"
R2 = "r2"
OCI_STORAGE = "oci_storage"
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
class PipelineTaskType(StrEnum):
PARSE = "Parse"
DOWNLOAD = "Download"

View File

@ -56,7 +56,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
# Validate credentials
if self.bucket_type == BlobType.R2:
if not all(
if not all(
credentials.get(key)
for key in ["r2_access_key_id", "r2_secret_access_key", "account_id"]
):
@ -64,15 +64,23 @@ class BlobStorageConnector(LoadConnector, PollConnector):
elif self.bucket_type == BlobType.S3:
authentication_method = credentials.get("authentication_method", "access_key")
if authentication_method == "access_key":
if not all(
credentials.get(key)
for key in ["aws_access_key_id", "aws_secret_access_key"]
):
raise ConnectorMissingCredentialError("Amazon S3")
elif authentication_method == "iam_role":
if not credentials.get("aws_role_arn"):
raise ConnectorMissingCredentialError("Amazon S3 IAM role ARN is required")
elif authentication_method == "assume_role":
pass
else:
raise ConnectorMissingCredentialError("Unsupported S3 authentication method")
elif self.bucket_type == BlobType.GOOGLE_CLOUD_STORAGE:
if not all(
@ -120,55 +128,72 @@ class BlobStorageConnector(LoadConnector, PollConnector):
paginator = self.s3_client.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix)
batch: list[Document] = []
# Collect all objects first to count filename occurrences
all_objects = []
for page in pages:
if "Contents" not in page:
continue
for obj in page["Contents"]:
if obj["Key"].endswith("/"):
continue
last_modified = obj["LastModified"].replace(tzinfo=timezone.utc)
if start < last_modified <= end:
all_objects.append(obj)
# Count filename occurrences to determine which need full paths
filename_counts: dict[str, int] = {}
for obj in all_objects:
file_name = os.path.basename(obj["Key"])
filename_counts[file_name] = filename_counts.get(file_name, 0) + 1
if not (start < last_modified <= end):
batch: list[Document] = []
for obj in all_objects:
last_modified = obj["LastModified"].replace(tzinfo=timezone.utc)
file_name = os.path.basename(obj["Key"])
key = obj["Key"]
size_bytes = extract_size_bytes(obj)
if (
self.size_threshold is not None
and isinstance(size_bytes, int)
and size_bytes > self.size_threshold
):
logging.warning(
f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping."
)
continue
try:
blob = download_object(self.s3_client, self.bucket_name, key, self.size_threshold)
if blob is None:
continue
file_name = os.path.basename(obj["Key"])
key = obj["Key"]
# Use full path only if filename appears multiple times
if filename_counts.get(file_name, 0) > 1:
relative_path = key
if self.prefix and key.startswith(self.prefix):
relative_path = key[len(self.prefix):]
semantic_id = relative_path.replace('/', ' / ') if relative_path else file_name
else:
semantic_id = file_name
size_bytes = extract_size_bytes(obj)
if (
self.size_threshold is not None
and isinstance(size_bytes, int)
and size_bytes > self.size_threshold
):
logging.warning(
f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping."
batch.append(
Document(
id=f"{self.bucket_type}:{self.bucket_name}:{key}",
blob=blob,
source=DocumentSource(self.bucket_type.value),
semantic_identifier=semantic_id,
extension=get_file_ext(file_name),
doc_updated_at=last_modified,
size_bytes=size_bytes if size_bytes else 0
)
continue
try:
blob = download_object(self.s3_client, self.bucket_name, key, self.size_threshold)
if blob is None:
continue
)
if len(batch) == self.batch_size:
yield batch
batch = []
batch.append(
Document(
id=f"{self.bucket_type}:{self.bucket_name}:{key}",
blob=blob,
source=DocumentSource(self.bucket_type.value),
semantic_identifier=file_name,
extension=get_file_ext(file_name),
doc_updated_at=last_modified,
size_bytes=size_bytes if size_bytes else 0
)
)
if len(batch) == self.batch_size:
yield batch
batch = []
except Exception:
logging.exception(f"Error decoding object {key}")
except Exception:
logging.exception(f"Error decoding object {key}")
if batch:
yield batch
@ -276,4 +301,4 @@ if __name__ == "__main__":
except ConnectorMissingCredentialError as e:
print(f"Error: {e}")
except Exception as e:
print(f"An unexpected error occurred: {e}")
print(f"An unexpected error occurred: {e}")

View File

@ -83,6 +83,7 @@ _PAGE_EXPANSION_FIELDS = [
"space",
"metadata.labels",
"history.lastUpdated",
"ancestors",
]

View File

@ -186,7 +186,7 @@ class OnyxConfluence:
# between the db and redis everywhere the credentials might be updated
new_credential_str = json.dumps(new_credentials)
self.redis_client.set(
self.credential_key, new_credential_str, nx=True, ex=self.CREDENTIAL_TTL
self.credential_key, new_credential_str, exp=self.CREDENTIAL_TTL
)
self._credentials_provider.set_credentials(new_credentials)
@ -1311,6 +1311,9 @@ class ConfluenceConnector(
self._low_timeout_confluence_client: OnyxConfluence | None = None
self._fetched_titles: set[str] = set()
self.allow_images = False
# Track document names to detect duplicates
self._document_name_counts: dict[str, int] = {}
self._document_name_paths: dict[str, list[str]] = {}
# Remove trailing slash from wiki_base if present
self.wiki_base = wiki_base.rstrip("/")
@ -1513,6 +1516,40 @@ class ConfluenceConnector(
self.wiki_base, page["_links"]["webui"], self.is_cloud
)
# Build hierarchical path for semantic identifier
space_name = page.get("space", {}).get("name", "")
# Build path from ancestors
path_parts = []
if space_name:
path_parts.append(space_name)
# Add ancestor pages to path if available
if "ancestors" in page and page["ancestors"]:
for ancestor in page["ancestors"]:
ancestor_title = ancestor.get("title", "")
if ancestor_title:
path_parts.append(ancestor_title)
# Add current page title
path_parts.append(page_title)
# Track page names for duplicate detection
full_path = " / ".join(path_parts) if len(path_parts) > 1 else page_title
# Count occurrences of this page title
if page_title not in self._document_name_counts:
self._document_name_counts[page_title] = 0
self._document_name_paths[page_title] = []
self._document_name_counts[page_title] += 1
self._document_name_paths[page_title].append(full_path)
# Use simple name if no duplicates, otherwise use full path
if self._document_name_counts[page_title] == 1:
semantic_identifier = page_title
else:
semantic_identifier = full_path
# Get the page content
page_content = extract_text_from_confluence_html(
self.confluence_client, page, self._fetched_titles
@ -1559,11 +1596,11 @@ class ConfluenceConnector(
return Document(
id=page_url,
source=DocumentSource.CONFLUENCE,
semantic_identifier=page_title,
semantic_identifier=semantic_identifier,
extension=".html", # Confluence pages are HTML
blob=page_content.encode("utf-8"), # Encode page content as bytes
size_bytes=len(page_content.encode("utf-8")), # Calculate size in bytes
doc_updated_at=datetime_from_string(page["version"]["when"]),
size_bytes=len(page_content.encode("utf-8")), # Calculate size in bytes
primary_owners=primary_owners if primary_owners else None,
metadata=metadata if metadata else None,
)
@ -1601,7 +1638,6 @@ class ConfluenceConnector(
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
):
media_type: str = attachment.get("metadata", {}).get("mediaType", "")
# TODO(rkuo): this check is partially redundant with validate_attachment_filetype
# and checks in convert_attachment_to_content/process_attachment
# but doing the check here avoids an unnecessary download. Due for refactoring.
@ -1669,6 +1705,34 @@ class ConfluenceConnector(
self.wiki_base, attachment["_links"]["webui"], self.is_cloud
)
# Build semantic identifier with space and page context
attachment_title = attachment.get("title", object_url)
space_name = page.get("space", {}).get("name", "")
page_title = page.get("title", "")
# Create hierarchical name: Space / Page / Attachment
attachment_path_parts = []
if space_name:
attachment_path_parts.append(space_name)
if page_title:
attachment_path_parts.append(page_title)
attachment_path_parts.append(attachment_title)
full_attachment_path = " / ".join(attachment_path_parts) if len(attachment_path_parts) > 1 else attachment_title
# Track attachment names for duplicate detection
if attachment_title not in self._document_name_counts:
self._document_name_counts[attachment_title] = 0
self._document_name_paths[attachment_title] = []
self._document_name_counts[attachment_title] += 1
self._document_name_paths[attachment_title].append(full_attachment_path)
# Use simple name if no duplicates, otherwise use full path
if self._document_name_counts[attachment_title] == 1:
attachment_semantic_identifier = attachment_title
else:
attachment_semantic_identifier = full_attachment_path
primary_owners: list[BasicExpertInfo] | None = None
if "version" in attachment and "by" in attachment["version"]:
author = attachment["version"]["by"]
@ -1680,11 +1744,12 @@ class ConfluenceConnector(
extension = Path(attachment.get("title", "")).suffix or ".unknown"
attachment_doc = Document(
id=attachment_id,
# sections=sections,
source=DocumentSource.CONFLUENCE,
semantic_identifier=attachment.get("title", object_url),
semantic_identifier=attachment_semantic_identifier,
extension=extension,
blob=file_blob,
size_bytes=len(file_blob),
@ -1741,7 +1806,7 @@ class ConfluenceConnector(
start_ts, end, self.batch_size
)
logging.debug(f"page_query_url: {page_query_url}")
# store the next page start for confluence server, cursor for confluence cloud
def store_next_page_url(next_page_url: str) -> None:
checkpoint.next_page_url = next_page_url

View File

@ -87,15 +87,69 @@ class DropboxConnector(LoadConnector, PollConnector):
if self.dropbox_client is None:
raise ConnectorMissingCredentialError("Dropbox")
# Collect all files first to count filename occurrences
all_files = []
self._collect_files_recursive(path, start, end, all_files)
# Count filename occurrences
filename_counts: dict[str, int] = {}
for entry, _ in all_files:
filename_counts[entry.name] = filename_counts.get(entry.name, 0) + 1
# Process files in batches
batch: list[Document] = []
for entry, downloaded_file in all_files:
modified_time = entry.client_modified
if modified_time.tzinfo is None:
modified_time = modified_time.replace(tzinfo=timezone.utc)
else:
modified_time = modified_time.astimezone(timezone.utc)
# Use full path only if filename appears multiple times
if filename_counts.get(entry.name, 0) > 1:
# Remove leading slash and replace slashes with ' / '
relative_path = entry.path_display.lstrip('/')
semantic_id = relative_path.replace('/', ' / ') if relative_path else entry.name
else:
semantic_id = entry.name
batch.append(
Document(
id=f"dropbox:{entry.id}",
blob=downloaded_file,
source=DocumentSource.DROPBOX,
semantic_identifier=semantic_id,
extension=get_file_ext(entry.name),
doc_updated_at=modified_time,
size_bytes=entry.size if getattr(entry, "size", None) is not None else len(downloaded_file),
)
)
if len(batch) == self.batch_size:
yield batch
batch = []
if batch:
yield batch
def _collect_files_recursive(
self,
path: str,
start: SecondsSinceUnixEpoch | None,
end: SecondsSinceUnixEpoch | None,
all_files: list,
) -> None:
"""Recursively collect all files matching time criteria."""
if self.dropbox_client is None:
raise ConnectorMissingCredentialError("Dropbox")
result = self.dropbox_client.files_list_folder(
path,
limit=self.batch_size,
recursive=False,
include_non_downloadable_files=False,
)
while True:
batch: list[Document] = []
for entry in result.entries:
if isinstance(entry, FileMetadata):
modified_time = entry.client_modified
@ -112,27 +166,13 @@ class DropboxConnector(LoadConnector, PollConnector):
try:
downloaded_file = self._download_file(entry.path_display)
all_files.append((entry, downloaded_file))
except Exception:
logger.exception(f"[Dropbox]: Error downloading file {entry.path_display}")
continue
batch.append(
Document(
id=f"dropbox:{entry.id}",
blob=downloaded_file,
source=DocumentSource.DROPBOX,
semantic_identifier=entry.name,
extension=get_file_ext(entry.name),
doc_updated_at=modified_time,
size_bytes=entry.size if getattr(entry, "size", None) is not None else len(downloaded_file),
)
)
elif isinstance(entry, FolderMetadata):
yield from self._yield_files_recursive(entry.path_lower, start, end)
if batch:
yield batch
self._collect_files_recursive(entry.path_lower, start, end, all_files)
if not result.has_more:
break

View File

@ -94,6 +94,7 @@ class Document(BaseModel):
blob: bytes
doc_updated_at: datetime
size_bytes: int
primary_owners: list
metadata: Optional[dict[str, Any]] = None
@ -180,6 +181,7 @@ class NotionPage(BaseModel):
archived: bool
properties: dict[str, Any]
url: str
parent: Optional[dict[str, Any]] = None # Parent reference for path reconstruction
database_name: Optional[str] = None # Only applicable to database type pages

View File

@ -66,6 +66,7 @@ class NotionConnector(LoadConnector, PollConnector):
self.indexed_pages: set[str] = set()
self.root_page_id = root_page_id
self.recursive_index_enabled = recursive_index_enabled or bool(root_page_id)
self.page_path_cache: dict[str, str] = {}
@retry(tries=3, delay=1, backoff=2)
def _fetch_child_blocks(self, block_id: str, cursor: Optional[str] = None) -> dict[str, Any] | None:
@ -242,6 +243,20 @@ class NotionConnector(LoadConnector, PollConnector):
logging.warning(f"[Notion]: Failed to download Notion file from {url}: {exc}")
return None
def _append_block_id_to_name(self, name: str, block_id: Optional[str]) -> str:
"""Append the Notion block ID to the filename while keeping the extension."""
if not block_id:
return name
path = Path(name)
stem = path.stem or name
suffix = path.suffix
if not stem:
return name
return f"{stem}_{block_id}{suffix}" if suffix else f"{stem}_{block_id}"
def _extract_file_metadata(self, result_obj: dict[str, Any], block_id: str) -> tuple[str | None, str, str | None]:
file_source_type = result_obj.get("type")
file_source = result_obj.get(file_source_type, {}) if file_source_type else {}
@ -254,6 +269,8 @@ class NotionConnector(LoadConnector, PollConnector):
elif not name:
name = f"notion_file_{block_id}"
name = self._append_block_id_to_name(name, block_id)
caption = self._extract_rich_text(result_obj.get("caption", [])) if "caption" in result_obj else None
return url, name, caption
@ -265,6 +282,7 @@ class NotionConnector(LoadConnector, PollConnector):
name: str,
caption: Optional[str],
page_last_edited_time: Optional[str],
page_path: Optional[str],
) -> Document | None:
file_bytes = self._download_file(url)
if file_bytes is None:
@ -277,7 +295,8 @@ class NotionConnector(LoadConnector, PollConnector):
extension = ".bin"
updated_at = datetime_from_string(page_last_edited_time) if page_last_edited_time else datetime.now(timezone.utc)
semantic_identifier = caption or name or f"Notion file {block_id}"
base_identifier = name or caption or (f"Notion file {block_id}" if block_id else "Notion file")
semantic_identifier = f"{page_path} / {base_identifier}" if page_path else base_identifier
return Document(
id=block_id,
@ -289,7 +308,7 @@ class NotionConnector(LoadConnector, PollConnector):
doc_updated_at=updated_at,
)
def _read_blocks(self, base_block_id: str, page_last_edited_time: Optional[str] = None) -> tuple[list[NotionBlock], list[str], list[Document]]:
def _read_blocks(self, base_block_id: str, page_last_edited_time: Optional[str] = None, page_path: Optional[str] = None) -> tuple[list[NotionBlock], list[str], list[Document]]:
result_blocks: list[NotionBlock] = []
child_pages: list[str] = []
attachments: list[Document] = []
@ -370,11 +389,14 @@ class NotionConnector(LoadConnector, PollConnector):
name=file_name,
caption=caption,
page_last_edited_time=page_last_edited_time,
page_path=page_path,
)
if attachment_doc:
attachments.append(attachment_doc)
attachment_label = caption or file_name
attachment_label = file_name
if caption:
attachment_label = f"{file_name} ({caption})"
if attachment_label:
cur_result_text_arr.append(f"{result_type.capitalize()}: {attachment_label}")
@ -383,7 +405,7 @@ class NotionConnector(LoadConnector, PollConnector):
child_pages.append(result_block_id)
else:
logging.debug(f"[Notion]: Entering sub-block: {result_block_id}")
subblocks, subblock_child_pages, subblock_attachments = self._read_blocks(result_block_id, page_last_edited_time)
subblocks, subblock_child_pages, subblock_attachments = self._read_blocks(result_block_id, page_last_edited_time, page_path)
logging.debug(f"[Notion]: Finished sub-block: {result_block_id}")
result_blocks.extend(subblocks)
child_pages.extend(subblock_child_pages)
@ -423,6 +445,35 @@ class NotionConnector(LoadConnector, PollConnector):
return None
def _build_page_path(self, page: NotionPage, visited: Optional[set[str]] = None) -> Optional[str]:
"""Construct a hierarchical path for a page based on its parent chain."""
if page.id in self.page_path_cache:
return self.page_path_cache[page.id]
visited = visited or set()
if page.id in visited:
logging.warning(f"[Notion]: Detected cycle while building path for page {page.id}")
return self._read_page_title(page)
visited.add(page.id)
current_title = self._read_page_title(page) or f"Untitled Page {page.id}"
parent_info = getattr(page, "parent", None) or {}
parent_type = parent_info.get("type")
parent_id = parent_info.get(parent_type) if parent_type else None
parent_path = None
if parent_type in {"page_id", "database_id"} and isinstance(parent_id, str):
try:
parent_page = self._fetch_page(parent_id)
parent_path = self._build_page_path(parent_page, visited)
except Exception as exc:
logging.warning(f"[Notion]: Failed to resolve parent {parent_id} for page {page.id}: {exc}")
full_path = f"{parent_path} / {current_title}" if parent_path else current_title
self.page_path_cache[page.id] = full_path
return full_path
def _read_pages(self, pages: list[NotionPage], start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None) -> Generator[Document, None, None]:
"""Reads pages for rich text content and generates Documents."""
all_child_page_ids: list[str] = []
@ -441,13 +492,18 @@ class NotionConnector(LoadConnector, PollConnector):
continue
logging.info(f"[Notion]: Reading page with ID {page.id}, with url {page.url}")
page_blocks, child_page_ids, attachment_docs = self._read_blocks(page.id, page.last_edited_time)
page_path = self._build_page_path(page)
page_blocks, child_page_ids, attachment_docs = self._read_blocks(page.id, page.last_edited_time, page_path)
all_child_page_ids.extend(child_page_ids)
self.indexed_pages.add(page.id)
raw_page_title = self._read_page_title(page)
page_title = raw_page_title or f"Untitled Page with ID {page.id}"
# Append the page id to help disambiguate duplicate names
base_identifier = page_path or page_title
semantic_identifier = f"{base_identifier}_{page.id}" if base_identifier else page.id
if not page_blocks:
if not raw_page_title:
logging.warning(f"[Notion]: No blocks OR title found for page with ID {page.id}. Skipping.")
@ -469,7 +525,7 @@ class NotionConnector(LoadConnector, PollConnector):
joined_text = "\n".join(sec.text for sec in sections)
blob = joined_text.encode("utf-8")
yield Document(
id=page.id, blob=blob, source=DocumentSource.NOTION, semantic_identifier=page_title, extension=".txt", size_bytes=len(blob), doc_updated_at=datetime_from_string(page.last_edited_time)
id=page.id, blob=blob, source=DocumentSource.NOTION, semantic_identifier=semantic_identifier, extension=".txt", size_bytes=len(blob), doc_updated_at=datetime_from_string(page.last_edited_time)
)
for attachment_doc in attachment_docs:
@ -597,4 +653,4 @@ if __name__ == "__main__":
document_batches = connector.load_from_state()
for doc_batch in document_batches:
for doc in doc_batch:
print(doc)
print(doc)

View File

@ -167,7 +167,6 @@ def get_latest_message_time(thread: ThreadType) -> datetime:
def _build_doc_id(channel_id: str, thread_ts: str) -> str:
"""构建文档ID"""
return f"{channel_id}__{thread_ts}"
@ -179,7 +178,6 @@ def thread_to_doc(
user_cache: dict[str, BasicExpertInfo | None],
channel_access: Any | None,
) -> Document:
"""将线程转换为文档"""
channel_id = channel["id"]
initial_sender_expert_info = expert_info_from_slack_id(
@ -237,7 +235,6 @@ def filter_channels(
channels_to_connect: list[str] | None,
regex_enabled: bool,
) -> list[ChannelType]:
"""过滤频道"""
if not channels_to_connect:
return all_channels
@ -381,7 +378,6 @@ def _process_message(
[MessageType], SlackMessageFilterReason | None
] = default_msg_filter,
) -> ProcessedSlackMessage:
"""处理消息"""
thread_ts = message.get("thread_ts")
thread_or_message_ts = thread_ts or message["ts"]
try:
@ -536,7 +532,6 @@ class SlackConnector(
end: SecondsSinceUnixEpoch | None = None,
callback: Any = None,
) -> GenerateSlimDocumentOutput:
"""获取所有简化文档(带权限同步)"""
if self.client is None:
raise ConnectorMissingCredentialError("Slack")

View File

@ -254,18 +254,21 @@ def create_s3_client(bucket_type: BlobType, credentials: dict[str, Any], europea
elif bucket_type == BlobType.S3:
authentication_method = credentials.get("authentication_method", "access_key")
region_name = credentials.get("region") or None
if authentication_method == "access_key":
session = boto3.Session(
aws_access_key_id=credentials["aws_access_key_id"],
aws_secret_access_key=credentials["aws_secret_access_key"],
region_name=region_name,
)
return session.client("s3")
return session.client("s3", region_name=region_name)
elif authentication_method == "iam_role":
role_arn = credentials["aws_role_arn"]
def _refresh_credentials() -> dict[str, str]:
sts_client = boto3.client("sts")
sts_client = boto3.client("sts", region_name=credentials.get("region") or None)
assumed_role_object = sts_client.assume_role(
RoleArn=role_arn,
RoleSessionName=f"onyx_blob_storage_{int(datetime.now().timestamp())}",
@ -285,11 +288,11 @@ def create_s3_client(bucket_type: BlobType, credentials: dict[str, Any], europea
)
botocore_session = get_session()
botocore_session._credentials = refreshable
session = boto3.Session(botocore_session=botocore_session)
return session.client("s3")
session = boto3.Session(botocore_session=botocore_session, region_name=region_name)
return session.client("s3", region_name=region_name)
elif authentication_method == "assume_role":
return boto3.client("s3")
return boto3.client("s3", region_name=region_name)
else:
raise ValueError("Invalid authentication method for S3.")

View File

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC, abstractmethod
from dataclasses import dataclass
import numpy as np
@ -22,7 +21,6 @@ DEFAULT_MATCH_VECTOR_TOPN = 10
DEFAULT_MATCH_SPARSE_TOPN = 10
VEC = list | np.ndarray
@dataclass
class SparseVector:
indices: list[int]
@ -55,14 +53,13 @@ class SparseVector:
def __repr__(self):
return str(self)
class MatchTextExpr(ABC):
class MatchTextExpr:
def __init__(
self,
fields: list[str],
matching_text: str,
topn: int,
extra_options: dict = dict(),
extra_options: dict | None = None,
):
self.fields = fields
self.matching_text = matching_text
@ -70,7 +67,7 @@ class MatchTextExpr(ABC):
self.extra_options = extra_options
class MatchDenseExpr(ABC):
class MatchDenseExpr:
def __init__(
self,
vector_column_name: str,
@ -78,7 +75,7 @@ class MatchDenseExpr(ABC):
embedding_data_type: str,
distance_type: str,
topn: int = DEFAULT_MATCH_VECTOR_TOPN,
extra_options: dict = dict(),
extra_options: dict | None = None,
):
self.vector_column_name = vector_column_name
self.embedding_data = embedding_data
@ -88,7 +85,7 @@ class MatchDenseExpr(ABC):
self.extra_options = extra_options
class MatchSparseExpr(ABC):
class MatchSparseExpr:
def __init__(
self,
vector_column_name: str,
@ -104,7 +101,7 @@ class MatchSparseExpr(ABC):
self.opt_params = opt_params
class MatchTensorExpr(ABC):
class MatchTensorExpr:
def __init__(
self,
column_name: str,
@ -120,7 +117,7 @@ class MatchTensorExpr(ABC):
self.extra_option = extra_option
class FusionExpr(ABC):
class FusionExpr:
def __init__(self, method: str, topn: int, fusion_params: dict | None = None):
self.method = method
self.topn = topn
@ -129,7 +126,8 @@ class FusionExpr(ABC):
MatchExpr = MatchTextExpr | MatchDenseExpr | MatchSparseExpr | MatchTensorExpr | FusionExpr
class OrderByExpr(ABC):
class OrderByExpr:
def __init__(self):
self.fields = list()
def asc(self, field: str):
@ -141,13 +139,14 @@ class OrderByExpr(ABC):
def fields(self):
return self.fields
class DocStoreConnection(ABC):
"""
Database operations
"""
@abstractmethod
def dbType(self) -> str:
def db_type(self) -> str:
"""
Return the type of the database.
"""
@ -165,21 +164,21 @@ class DocStoreConnection(ABC):
"""
@abstractmethod
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
"""
Create an index with given name
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def deleteIdx(self, indexName: str, knowledgebaseId: str):
def delete_idx(self, index_name: str, dataset_id: str):
"""
Delete an index with given name
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
def index_exist(self, index_name: str, dataset_id: str) -> bool:
"""
Check if an index with given name exists
"""
@ -191,16 +190,16 @@ class DocStoreConnection(ABC):
@abstractmethod
def search(
self, selectFields: list[str],
highlightFields: list[str],
self, select_fields: list[str],
highlight_fields: list[str],
condition: dict,
matchExprs: list[MatchExpr],
orderBy: OrderByExpr,
match_expressions: list[MatchExpr],
order_by: OrderByExpr,
offset: int,
limit: int,
indexNames: str|list[str],
knowledgebaseIds: list[str],
aggFields: list[str] = [],
index_names: str|list[str],
dataset_ids: list[str],
agg_fields: list[str] | None = None,
rank_feature: dict | None = None
):
"""
@ -209,28 +208,28 @@ class DocStoreConnection(ABC):
raise NotImplementedError("Not implemented")
@abstractmethod
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
def get(self, data_id: str, index_name: str, dataset_ids: list[str]) -> dict | None:
"""
Get single chunk with given id
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
def insert(self, rows: list[dict], index_name: str, dataset_id: str = None) -> list[str]:
"""
Update or insert a bulk of rows
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
def update(self, condition: dict, new_value: dict, index_name: str, dataset_id: str) -> bool:
"""
Update rows with given conjunctive equivalent filtering condition
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
def delete(self, condition: dict, index_name: str, dataset_id: str) -> int:
"""
Delete rows with given conjunctive equivalent filtering condition
"""
@ -245,7 +244,7 @@ class DocStoreConnection(ABC):
raise NotImplementedError("Not implemented")
@abstractmethod
def get_chunk_ids(self, res):
def get_doc_ids(self, res):
raise NotImplementedError("Not implemented")
@abstractmethod
@ -253,18 +252,18 @@ class DocStoreConnection(ABC):
raise NotImplementedError("Not implemented")
@abstractmethod
def get_highlight(self, res, keywords: list[str], fieldnm: str):
def get_highlight(self, res, keywords: list[str], field_name: str):
raise NotImplementedError("Not implemented")
@abstractmethod
def get_aggregation(self, res, fieldnm: str):
def get_aggregation(self, res, field_name: str):
raise NotImplementedError("Not implemented")
"""
SQL
"""
@abstractmethod
def sql(sql: str, fetch_size: int, format: str):
def sql(self, sql: str, fetch_size: int, format: str):
"""
Run the sql generated by text-to-sql
"""

View File

@ -0,0 +1,326 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import re
import json
import time
import os
from abc import abstractmethod
from elasticsearch import Elasticsearch, NotFoundError
from elasticsearch_dsl import Index
from elastic_transport import ConnectionTimeout
from common.file_utils import get_project_base_directory
from common.misc_utils import convert_bytes
from common.doc_store.doc_store_base import DocStoreConnection, OrderByExpr, MatchExpr
from rag.nlp import is_english, rag_tokenizer
from common import settings
ATTEMPT_TIME = 2
class ESConnectionBase(DocStoreConnection):
def __init__(self, mapping_file_name: str="mapping.json", logger_name: str='ragflow.es_conn'):
self.logger = logging.getLogger(logger_name)
self.info = {}
self.logger.info(f"Use Elasticsearch {settings.ES['hosts']} as the doc engine.")
for _ in range(ATTEMPT_TIME):
try:
if self._connect():
break
except Exception as e:
self.logger.warning(f"{str(e)}. Waiting Elasticsearch {settings.ES['hosts']} to be healthy.")
time.sleep(5)
if not self.es.ping():
msg = f"Elasticsearch {settings.ES['hosts']} is unhealthy in 120s."
self.logger.error(msg)
raise Exception(msg)
v = self.info.get("version", {"number": "8.11.3"})
v = v["number"].split(".")[0]
if int(v) < 8:
msg = f"Elasticsearch version must be greater than or equal to 8, current version: {v}"
self.logger.error(msg)
raise Exception(msg)
fp_mapping = os.path.join(get_project_base_directory(), "conf", mapping_file_name)
if not os.path.exists(fp_mapping):
msg = f"Elasticsearch mapping file not found at {fp_mapping}"
self.logger.error(msg)
raise Exception(msg)
self.mapping = json.load(open(fp_mapping, "r"))
self.logger.info(f"Elasticsearch {settings.ES['hosts']} is healthy.")
def _connect(self):
self.es = Elasticsearch(
settings.ES["hosts"].split(","),
basic_auth=(settings.ES["username"], settings.ES[
"password"]) if "username" in settings.ES and "password" in settings.ES else None,
verify_certs= settings.ES.get("verify_certs", False),
timeout=600 )
if self.es:
self.info = self.es.info()
return True
return False
"""
Database operations
"""
def db_type(self) -> str:
return "elasticsearch"
def health(self) -> dict:
health_dict = dict(self.es.cluster.health())
health_dict["type"] = "elasticsearch"
return health_dict
def get_cluster_stats(self):
"""
curl -XGET "http://{es_host}/_cluster/stats" -H "kbn-xsrf: reporting" to view raw stats.
"""
raw_stats = self.es.cluster.stats()
self.logger.debug(f"ESConnection.get_cluster_stats: {raw_stats}")
try:
res = {
'cluster_name': raw_stats['cluster_name'],
'status': raw_stats['status']
}
indices_status = raw_stats['indices']
res.update({
'indices': indices_status['count'],
'indices_shards': indices_status['shards']['total']
})
doc_info = indices_status['docs']
res.update({
'docs': doc_info['count'],
'docs_deleted': doc_info['deleted']
})
store_info = indices_status['store']
res.update({
'store_size': convert_bytes(store_info['size_in_bytes']),
'total_dataset_size': convert_bytes(store_info['total_data_set_size_in_bytes'])
})
mappings_info = indices_status['mappings']
res.update({
'mappings_fields': mappings_info['total_field_count'],
'mappings_deduplicated_fields': mappings_info['total_deduplicated_field_count'],
'mappings_deduplicated_size': convert_bytes(mappings_info['total_deduplicated_mapping_size_in_bytes'])
})
node_info = raw_stats['nodes']
res.update({
'nodes': node_info['count']['total'],
'nodes_version': node_info['versions'],
'os_mem': convert_bytes(node_info['os']['mem']['total_in_bytes']),
'os_mem_used': convert_bytes(node_info['os']['mem']['used_in_bytes']),
'os_mem_used_percent': node_info['os']['mem']['used_percent'],
'jvm_versions': node_info['jvm']['versions'][0]['vm_version'],
'jvm_heap_used': convert_bytes(node_info['jvm']['mem']['heap_used_in_bytes']),
'jvm_heap_max': convert_bytes(node_info['jvm']['mem']['heap_max_in_bytes'])
})
return res
except Exception as e:
self.logger.exception(f"ESConnection.get_cluster_stats: {e}")
return None
"""
Table operations
"""
def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
if self.index_exist(index_name, dataset_id):
return True
try:
from elasticsearch.client import IndicesClient
return IndicesClient(self.es).create(index=index_name,
settings=self.mapping["settings"],
mappings=self.mapping["mappings"])
except Exception:
self.logger.exception("ESConnection.createIndex error %s" % index_name)
def delete_idx(self, index_name: str, dataset_id: str):
if len(dataset_id) > 0:
# The index need to be alive after any kb deletion since all kb under this tenant are in one index.
return
try:
self.es.indices.delete(index=index_name, allow_no_indices=True)
except NotFoundError:
pass
except Exception:
self.logger.exception("ESConnection.deleteIdx error %s" % index_name)
def index_exist(self, index_name: str, dataset_id: str = None) -> bool:
s = Index(index_name, self.es)
for i in range(ATTEMPT_TIME):
try:
return s.exists()
except ConnectionTimeout:
self.logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
self.logger.exception(e)
break
return False
"""
CRUD operations
"""
def get(self, doc_id: str, index_name: str, dataset_ids: list[str]) -> dict | None:
for i in range(ATTEMPT_TIME):
try:
res = self.es.get(index=index_name,
id=doc_id, source=True, )
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
doc = res["_source"]
doc["id"] = doc_id
return doc
except NotFoundError:
return None
except Exception as e:
self.logger.exception(f"ESConnection.get({doc_id}) got exception")
raise e
self.logger.error(f"ESConnection.get timeout for {ATTEMPT_TIME} times!")
raise Exception("ESConnection.get timeout.")
@abstractmethod
def search(
self, select_fields: list[str],
highlight_fields: list[str],
condition: dict,
match_expressions: list[MatchExpr],
order_by: OrderByExpr,
offset: int,
limit: int,
index_names: str | list[str],
dataset_ids: list[str],
agg_fields: list[str] | None = None,
rank_feature: dict | None = None
):
raise NotImplementedError("Not implemented")
@abstractmethod
def insert(self, documents: list[dict], index_name: str, dataset_id: str = None) -> list[str]:
raise NotImplementedError("Not implemented")
@abstractmethod
def update(self, condition: dict, new_value: dict, index_name: str, dataset_id: str) -> bool:
raise NotImplementedError("Not implemented")
@abstractmethod
def delete(self, condition: dict, index_name: str, dataset_id: str) -> int:
raise NotImplementedError("Not implemented")
"""
Helper functions for search result
"""
def get_total(self, res):
if isinstance(res["hits"]["total"], type({})):
return res["hits"]["total"]["value"]
return res["hits"]["total"]
def get_doc_ids(self, res):
return [d["_id"] for d in res["hits"]["hits"]]
def _get_source(self, res):
rr = []
for d in res["hits"]["hits"]:
d["_source"]["id"] = d["_id"]
d["_source"]["_score"] = d["_score"]
rr.append(d["_source"])
return rr
@abstractmethod
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
raise NotImplementedError("Not implemented")
def get_highlight(self, res, keywords: list[str], field_name: str):
ans = {}
for d in res["hits"]["hits"]:
highlights = d.get("highlight")
if not highlights:
continue
txt = "...".join([a for a in list(highlights.items())[0][1]])
if not is_english(txt.split()):
ans[d["_id"]] = txt
continue
txt = d["_source"][field_name]
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
txt_list = []
for t in re.split(r"[.?!;\n]", txt):
for w in keywords:
t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w), r"\1<em>\2</em>\3", t,
flags=re.IGNORECASE | re.MULTILINE)
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
continue
txt_list.append(t)
ans[d["_id"]] = "...".join(txt_list) if txt_list else "...".join([a for a in list(highlights.items())[0][1]])
return ans
def get_aggregation(self, res, field_name: str):
agg_field = "aggs_" + field_name
if "aggregations" not in res or agg_field not in res["aggregations"]:
return list()
buckets = res["aggregations"][agg_field]["buckets"]
return [(b["key"], b["doc_count"]) for b in buckets]
"""
SQL
"""
def sql(self, sql: str, fetch_size: int, format: str):
self.logger.debug(f"ESConnection.sql get sql: {sql}")
sql = re.sub(r"[ `]+", " ", sql)
sql = sql.replace("%", "")
replaces = []
for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
fld, v = r.group(1), r.group(3)
match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
replaces.append(
("{}{}'{}'".format(
r.group(1),
r.group(2),
r.group(3)),
match))
for p, r in replaces:
sql = sql.replace(p, r, 1)
self.logger.debug(f"ESConnection.sql to es: {sql}")
for i in range(ATTEMPT_TIME):
try:
res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format,
request_timeout="2s")
return res
except ConnectionTimeout:
self.logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
self.logger.exception(f"ESConnection.sql got exception. SQL:\n{sql}")
raise Exception(f"SQL error: {e}\n\nSQL: {sql}")
self.logger.error(f"ESConnection.sql timeout for {ATTEMPT_TIME} times!")
return None

View File

@ -0,0 +1,451 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import re
import json
import time
from abc import abstractmethod
import infinity
from infinity.common import ConflictType
from infinity.index import IndexInfo, IndexType
from infinity.connection_pool import ConnectionPool
from infinity.errors import ErrorCode
import pandas as pd
from common.file_utils import get_project_base_directory
from rag.nlp import is_english
from common import settings
from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr
class InfinityConnectionBase(DocStoreConnection):
def __init__(self, mapping_file_name: str="infinity_mapping.json", logger_name: str="ragflow.infinity_conn"):
self.dbName = settings.INFINITY.get("db_name", "default_db")
self.mapping_file_name = mapping_file_name
self.logger = logging.getLogger(logger_name)
infinity_uri = settings.INFINITY["uri"]
if ":" in infinity_uri:
host, port = infinity_uri.split(":")
infinity_uri = infinity.common.NetworkAddress(host, int(port))
self.connPool = None
self.logger.info(f"Use Infinity {infinity_uri} as the doc engine.")
for _ in range(24):
try:
conn_pool = ConnectionPool(infinity_uri, max_size=4)
inf_conn = conn_pool.get_conn()
res = inf_conn.show_current_node()
if res.error_code == ErrorCode.OK and res.server_status in ["started", "alive"]:
self._migrate_db(inf_conn)
self.connPool = conn_pool
conn_pool.release_conn(inf_conn)
break
conn_pool.release_conn(inf_conn)
self.logger.warning(f"Infinity status: {res.server_status}. Waiting Infinity {infinity_uri} to be healthy.")
time.sleep(5)
except Exception as e:
self.logger.warning(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.")
time.sleep(5)
if self.connPool is None:
msg = f"Infinity {infinity_uri} is unhealthy in 120s."
self.logger.error(msg)
raise Exception(msg)
self.logger.info(f"Infinity {infinity_uri} is healthy.")
def _migrate_db(self, inf_conn):
inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
fp_mapping = os.path.join(get_project_base_directory(), "conf", self.mapping_file_name)
if not os.path.exists(fp_mapping):
raise Exception(f"Mapping file not found at {fp_mapping}")
schema = json.load(open(fp_mapping))
table_names = inf_db.list_tables().table_names
for table_name in table_names:
inf_table = inf_db.get_table(table_name)
index_names = inf_table.list_indexes().index_names
if "q_vec_idx" not in index_names:
# Skip tables not created by me
continue
column_names = inf_table.show_columns()["name"]
column_names = set(column_names)
for field_name, field_info in schema.items():
if field_name in column_names:
continue
res = inf_table.add_columns({field_name: field_info})
assert res.error_code == infinity.ErrorCode.OK
self.logger.info(f"INFINITY added following column to table {table_name}: {field_name} {field_info}")
if field_info["type"] != "varchar" or "analyzer" not in field_info:
continue
analyzers = field_info["analyzer"]
if isinstance(analyzers, str):
analyzers = [analyzers]
for analyzer in analyzers:
inf_table.create_index(
f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}",
IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}),
ConflictType.Ignore,
)
"""
Dataframe and fields convert
"""
@staticmethod
@abstractmethod
def field_keyword(field_name: str):
# judge keyword or not, such as "*_kwd" tag-like columns.
raise NotImplementedError("Not implemented")
@abstractmethod
def convert_select_fields(self, output_fields: list[str]) -> list[str]:
# rm _kwd, _tks, _sm_tks, _with_weight suffix in field name.
raise NotImplementedError("Not implemented")
@staticmethod
@abstractmethod
def convert_matching_field(field_weight_str: str) -> str:
# convert matching field to
raise NotImplementedError("Not implemented")
@staticmethod
def list2str(lst: str | list, sep: str = " ") -> str:
if isinstance(lst, str):
return lst
return sep.join(lst)
def equivalent_condition_to_str(self, condition: dict, table_instance=None) -> str | None:
assert "_id" not in condition
columns = {}
if table_instance:
for n, ty, de, _ in table_instance.show_columns().rows():
columns[n] = (ty, de)
def exists(cln):
nonlocal columns
assert cln in columns, f"'{cln}' should be in '{columns}'."
ty, de = columns[cln]
if ty.lower().find("cha"):
if not de:
de = ""
return f" {cln}!='{de}' "
return f"{cln}!={de}"
cond = list()
for k, v in condition.items():
if not isinstance(k, str) or not v:
continue
if self.field_keyword(k):
if isinstance(v, list):
inCond = list()
for item in v:
if isinstance(item, str):
item = item.replace("'", "''")
inCond.append(f"filter_fulltext('{self.convert_matching_field(k)}', '{item}')")
if inCond:
strInCond = " or ".join(inCond)
strInCond = f"({strInCond})"
cond.append(strInCond)
else:
cond.append(f"filter_fulltext('{self.convert_matching_field(k)}', '{v}')")
elif isinstance(v, list):
inCond = list()
for item in v:
if isinstance(item, str):
item = item.replace("'", "''")
inCond.append(f"'{item}'")
else:
inCond.append(str(item))
if inCond:
strInCond = ", ".join(inCond)
strInCond = f"{k} IN ({strInCond})"
cond.append(strInCond)
elif k == "must_not":
if isinstance(v, dict):
for kk, vv in v.items():
if kk == "exists":
cond.append("NOT (%s)" % exists(vv))
elif isinstance(v, str):
cond.append(f"{k}='{v}'")
elif k == "exists":
cond.append(exists(v))
else:
cond.append(f"{k}={str(v)}")
return " AND ".join(cond) if cond else "1=1"
@staticmethod
def concat_dataframes(df_list: list[pd.DataFrame], select_fields: list[str]) -> pd.DataFrame:
df_list2 = [df for df in df_list if not df.empty]
if df_list2:
return pd.concat(df_list2, axis=0).reset_index(drop=True)
schema = []
for field_name in select_fields:
if field_name == "score()": # Workaround: fix schema is changed to score()
schema.append("SCORE")
elif field_name == "similarity()": # Workaround: fix schema is changed to similarity()
schema.append("SIMILARITY")
else:
schema.append(field_name)
return pd.DataFrame(columns=schema)
"""
Database operations
"""
def db_type(self) -> str:
return "infinity"
def health(self) -> dict:
"""
Return the health status of the database.
"""
inf_conn = self.connPool.get_conn()
res = inf_conn.show_current_node()
self.connPool.release_conn(inf_conn)
res2 = {
"type": "infinity",
"status": "green" if res.error_code == 0 and res.server_status in ["started", "alive"] else "red",
"error": res.error_msg,
}
return res2
"""
Table operations
"""
def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
table_name = f"{index_name}_{dataset_id}"
inf_conn = self.connPool.get_conn()
inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
fp_mapping = os.path.join(get_project_base_directory(), "conf", self.mapping_file_name)
if not os.path.exists(fp_mapping):
raise Exception(f"Mapping file not found at {fp_mapping}")
schema = json.load(open(fp_mapping))
vector_name = f"q_{vector_size}_vec"
schema[vector_name] = {"type": f"vector,{vector_size},float"}
inf_table = inf_db.create_table(
table_name,
schema,
ConflictType.Ignore,
)
inf_table.create_index(
"q_vec_idx",
IndexInfo(
vector_name,
IndexType.Hnsw,
{
"M": "16",
"ef_construction": "50",
"metric": "cosine",
"encode": "lvq",
},
),
ConflictType.Ignore,
)
for field_name, field_info in schema.items():
if field_info["type"] != "varchar" or "analyzer" not in field_info:
continue
analyzers = field_info["analyzer"]
if isinstance(analyzers, str):
analyzers = [analyzers]
for analyzer in analyzers:
inf_table.create_index(
f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}",
IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}),
ConflictType.Ignore,
)
self.connPool.release_conn(inf_conn)
self.logger.info(f"INFINITY created table {table_name}, vector size {vector_size}")
return True
def delete_idx(self, index_name: str, dataset_id: str):
table_name = f"{index_name}_{dataset_id}"
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
db_instance.drop_table(table_name, ConflictType.Ignore)
self.connPool.release_conn(inf_conn)
self.logger.info(f"INFINITY dropped table {table_name}")
def index_exist(self, index_name: str, dataset_id: str) -> bool:
table_name = f"{index_name}_{dataset_id}"
try:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
_ = db_instance.get_table(table_name)
self.connPool.release_conn(inf_conn)
return True
except Exception as e:
self.logger.warning(f"INFINITY indexExist {str(e)}")
return False
"""
CRUD operations
"""
@abstractmethod
def search(
self,
select_fields: list[str],
highlight_fields: list[str],
condition: dict,
match_expressions: list[MatchExpr],
order_by: OrderByExpr,
offset: int,
limit: int,
index_names: str | list[str],
dataset_ids: list[str],
agg_fields: list[str] | None = None,
rank_feature: dict | None = None,
) -> tuple[pd.DataFrame, int]:
raise NotImplementedError("Not implemented")
@abstractmethod
def get(self, doc_id: str, index_name: str, knowledgebase_ids: list[str]) -> dict | None:
raise NotImplementedError("Not implemented")
@abstractmethod
def insert(self, documents: list[dict], index_name: str, dataset_ids: str = None) -> list[str]:
raise NotImplementedError("Not implemented")
@abstractmethod
def update(self, condition: dict, new_value: dict, index_name: str, dataset_id: str) -> bool:
raise NotImplementedError("Not implemented")
def delete(self, condition: dict, index_name: str, dataset_id: str) -> int:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{index_name}_{dataset_id}"
try:
table_instance = db_instance.get_table(table_name)
except Exception:
self.logger.warning(f"Skipped deleting from table {table_name} since the table doesn't exist.")
return 0
filter = self.equivalent_condition_to_str(condition, table_instance)
self.logger.debug(f"INFINITY delete table {table_name}, filter {filter}.")
res = table_instance.delete(filter)
self.connPool.release_conn(inf_conn)
return res.deleted_rows
"""
Helper functions for search result
"""
def get_total(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
if isinstance(res, tuple):
return res[1]
return len(res)
def get_doc_ids(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
if isinstance(res, tuple):
res = res[0]
return list(res["id"])
@abstractmethod
def get_fields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
raise NotImplementedError("Not implemented")
def get_highlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], field_name: str):
if isinstance(res, tuple):
res = res[0]
ans = {}
num_rows = len(res)
column_id = res["id"]
if field_name not in res:
return {}
for i in range(num_rows):
id = column_id[i]
txt = res[field_name][i]
if re.search(r"<em>[^<>]+</em>", txt, flags=re.IGNORECASE | re.MULTILINE):
ans[id] = txt
continue
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
txt_list = []
for t in re.split(r"[.?!;\n]", txt):
if is_english([t]):
for w in keywords:
t = re.sub(
r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w),
r"\1<em>\2</em>\3",
t,
flags=re.IGNORECASE | re.MULTILINE,
)
else:
for w in sorted(keywords, key=len, reverse=True):
t = re.sub(
re.escape(w),
f"<em>{w}</em>",
t,
flags=re.IGNORECASE | re.MULTILINE,
)
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
continue
txt_list.append(t)
if txt_list:
ans[id] = "...".join(txt_list)
else:
ans[id] = txt
return ans
def get_aggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, field_name: str):
"""
Manual aggregation for tag fields since Infinity doesn't provide native aggregation
"""
from collections import Counter
# Extract DataFrame from result
if isinstance(res, tuple):
df, _ = res
else:
df = res
if df.empty or field_name not in df.columns:
return []
# Aggregate tag counts
tag_counter = Counter()
for value in df[field_name]:
if pd.isna(value) or not value:
continue
# Handle different tag formats
if isinstance(value, str):
# Split by ### for tag_kwd field or comma for other formats
if field_name == "tag_kwd" and "###" in value:
tags = [tag.strip() for tag in value.split("###") if tag.strip()]
else:
# Try comma separation as fallback
tags = [tag.strip() for tag in value.split(",") if tag.strip()]
for tag in tags:
if tag: # Only count non-empty tags
tag_counter[tag] += 1
elif isinstance(value, list):
# Handle list format
for tag in value:
if tag and isinstance(tag, str):
tag_counter[tag.strip()] += 1
# Return as list of [tag, count] pairs, sorted by count descending
return [[tag, count] for tag, count in tag_counter.most_common()]
"""
SQL
"""
def sql(self, sql: str, fetch_size: int, format: str):
raise NotImplementedError("Not implemented")

View File

@ -16,7 +16,7 @@ import logging
import os
import time
from typing import Any, Dict, Optional
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
from urllib.parse import urlparse, urlunparse
from common import settings
import httpx
@ -58,21 +58,34 @@ def _get_delay(backoff_factor: float, attempt: int) -> float:
_SENSITIVE_QUERY_KEYS = {"client_secret", "secret", "code", "access_token", "refresh_token", "password", "token", "app_secret"}
def _redact_sensitive_url_params(url: str) -> str:
"""
Return a version of the URL that is safe to log.
We intentionally drop query parameters and userinfo to avoid leaking
credentials or tokens via logs. Only scheme, host, port and path
are preserved.
"""
try:
parsed = urlparse(url)
if not parsed.query:
return url
clean_query = []
for k, v in parse_qsl(parsed.query, keep_blank_values=True):
if k.lower() in _SENSITIVE_QUERY_KEYS:
clean_query.append((k, "***REDACTED***"))
else:
clean_query.append((k, v))
new_query = urlencode(clean_query, doseq=True)
redacted_url = urlunparse(parsed._replace(query=new_query))
return redacted_url
# Remove any potential userinfo (username:password@)
netloc = parsed.hostname or ""
if parsed.port:
netloc = f"{netloc}:{parsed.port}"
# Reconstruct URL without query, params, fragment, or userinfo.
safe_url = urlunparse(
(
parsed.scheme,
netloc,
parsed.path,
"", # params
"", # query
"", # fragment
)
)
return safe_url
except Exception:
return url
# If parsing fails, fall back to omitting the URL entirely.
return "<redacted-url>"
def _is_sensitive_url(url: str) -> bool:
"""Return True if URL is one of the configured OAuth endpoints."""
@ -144,23 +157,28 @@ async def async_request(
method=method, url=url, headers=headers, **kwargs
)
duration = time.monotonic() - start
log_url = "<SENSITIVE ENDPOINT>" if _is_sensitive_url(url) else _redact_sensitive_url_params(url)
logger.debug(
f"async_request {method} {log_url} -> {response.status_code} in {duration:.3f}s"
)
if not _is_sensitive_url(url):
log_url = _redact_sensitive_url_params(url)
logger.debug(f"async_request {method} {log_url} -> {response.status_code} in {duration:.3f}s")
return response
except httpx.RequestError as exc:
last_exc = exc
if attempt >= retries:
log_url = "<SENSITIVE ENDPOINT>" if _is_sensitive_url(url) else _redact_sensitive_url_params(url)
if not _is_sensitive_url(url):
log_url = _redact_sensitive_url_params(url)
logger.warning(f"async_request exhausted retries for {method}")
raise
delay = _get_delay(backoff_factor, attempt)
if not _is_sensitive_url(url):
log_url = _redact_sensitive_url_params(url)
logger.warning(
f"async_request exhausted retries for {method} {log_url}"
f"async_request attempt {attempt + 1}/{retries + 1} failed for {method}; retrying in {delay:.2f}s"
)
raise
delay = _get_delay(backoff_factor, attempt)
log_url = "<SENSITIVE ENDPOINT>" if _is_sensitive_url(url) else _redact_sensitive_url_params(url)
# Avoid including the (potentially sensitive) URL in retry logs.
logger.warning(
f"async_request attempt {attempt + 1}/{retries + 1} failed for {method} {log_url}; retrying in {delay:.2f}s"
f"async_request attempt {attempt + 1}/{retries + 1} failed for {method}; retrying in {delay:.2f}s"
)
await asyncio.sleep(delay)
raise last_exc # pragma: no cover

View File

@ -75,9 +75,12 @@ def init_root_logger(logfile_basename: str, log_format: str = "%(asctime)-15s %(
def log_exception(e, *args):
logging.exception(e)
for a in args:
if hasattr(a, "text"):
logging.error(a.text)
raise Exception(a.text)
else:
logging.error(str(a))
try:
text = getattr(a, "text")
except Exception:
text = None
if text is not None:
logging.error(text)
raise Exception(text)
logging.error(str(a))
raise e

View File

@ -44,21 +44,27 @@ def meta_filter(metas: dict, filters: list[dict], logic: str = "and"):
def filter_out(v2docs, operator, value):
ids = []
for input, docids in v2docs.items():
if operator in ["=", "", ">", "<", "", ""]:
try:
if isinstance(input, list):
input = input[0]
input = float(input)
value = float(value)
except Exception:
input = str(input)
value = str(value)
pass
if isinstance(input, str):
input = input.lower()
if isinstance(value, str):
value = value.lower()
for conds in [
(operator == "contains", str(value).lower() in str(input).lower()),
(operator == "not contains", str(value).lower() not in str(input).lower()),
(operator == "in", str(input).lower() in str(value).lower()),
(operator == "not in", str(input).lower() not in str(value).lower()),
(operator == "start with", str(input).lower().startswith(str(value).lower())),
(operator == "end with", str(input).lower().endswith(str(value).lower())),
(operator == "contains", input in value if not isinstance(input, list) else all([i in value for i in input])),
(operator == "not contains", input not in value if not isinstance(input, list) else all([i not in value for i in input])),
(operator == "in", input in value if not isinstance(input, list) else all([i in value for i in input])),
(operator == "not in", input not in value if not isinstance(input, list) else all([i not in value for i in input])),
(operator == "start with", str(input).lower().startswith(str(value).lower()) if not isinstance(input, list) else "".join([str(i).lower() for i in input]).startswith(str(value).lower())),
(operator == "end with", str(input).lower().endswith(str(value).lower()) if not isinstance(input, list) else "".join([str(i).lower() for i in input]).endswith(str(value).lower())),
(operator == "empty", not input),
(operator == "not empty", input),
(operator == "=", input == value),
@ -145,6 +151,18 @@ async def apply_meta_data_filter(
return doc_ids
def dedupe_list(values: list) -> list:
seen = set()
deduped = []
for item in values:
key = str(item)
if key in seen:
continue
seen.add(key)
deduped.append(item)
return deduped
def update_metadata_to(metadata, meta):
if not meta:
return metadata
@ -156,11 +174,13 @@ def update_metadata_to(metadata, meta):
return metadata
if not isinstance(meta, dict):
return metadata
for k, v in meta.items():
if isinstance(v, list):
v = [vv for vv in v if isinstance(vv, str)]
if not v:
continue
v = dedupe_list(v)
if not isinstance(v, list) and not isinstance(v, str):
continue
if k not in metadata:
@ -171,6 +191,7 @@ def update_metadata_to(metadata, meta):
metadata[k].extend(v)
else:
metadata[k].append(v)
metadata[k] = dedupe_list(metadata[k])
else:
metadata[k] = v
@ -202,4 +223,4 @@ def metadata_schema(metadata: list|None) -> Dict[str, Any]:
}
json_schema["additionalProperties"] = False
return json_schema
return json_schema

72
common/query_base.py Normal file
View File

@ -0,0 +1,72 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
from abc import ABC, abstractmethod
class QueryBase(ABC):
@staticmethod
def is_chinese(line):
arr = re.split(r"[ \t]+", line)
if len(arr) <= 3:
return True
e = 0
for t in arr:
if not re.match(r"[a-zA-Z]+$", t):
e += 1
return e * 1.0 / len(arr) >= 0.7
@staticmethod
def sub_special_char(line):
return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|\+~\^])", r"\\\1", line).strip()
@staticmethod
def rmWWW(txt):
patts = [
(
r"是*(怎么办|什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀|谁|哪位|哪个)是*",
"",
),
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
(
r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ",
" ")
]
otxt = txt
for r, p in patts:
txt = re.sub(r, p, txt, flags=re.IGNORECASE)
if not txt:
txt = otxt
return txt
@staticmethod
def add_space_between_eng_zh(txt):
# (ENG/ENG+NUM) + ZH
txt = re.sub(r'([A-Za-z]+[0-9]+)([\u4e00-\u9fa5]+)', r'\1 \2', txt)
# ENG + ZH
txt = re.sub(r'([A-Za-z])([\u4e00-\u9fa5]+)', r'\1 \2', txt)
# ZH + (ENG/ENG+NUM)
txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z]+[0-9]+)', r'\1 \2', txt)
txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z])', r'\1 \2', txt)
return txt
@abstractmethod
def question(self, text, tbl, min_match):
"""
Returns a query object based on the input text, table, and minimum match criteria.
"""
raise NotImplementedError("Not implemented")

View File

@ -39,6 +39,9 @@ from rag.utils.oss_conn import RAGFlowOSS
from rag.nlp import search
import memory.utils.es_conn as memory_es_conn
import memory.utils.infinity_conn as memory_infinity_conn
LLM = None
LLM_FACTORY = None
LLM_BASE_URL = None
@ -76,9 +79,11 @@ FEISHU_OAUTH = None
OAUTH_CONFIG = None
DOC_ENGINE = os.getenv('DOC_ENGINE', 'elasticsearch')
DOC_ENGINE_INFINITY = (DOC_ENGINE.lower() == "infinity")
MSG_ENGINE = DOC_ENGINE
docStoreConn = None
msgStoreConn = None
retriever = None
kg_retriever = None
@ -256,6 +261,15 @@ def init_settings():
else:
raise Exception(f"Not supported doc engine: {DOC_ENGINE}")
global MSG_ENGINE, msgStoreConn
MSG_ENGINE = DOC_ENGINE # use the same engine for message store
if MSG_ENGINE == "elasticsearch":
ES = get_base_config("es", {})
msgStoreConn = memory_es_conn.ESConnection()
elif MSG_ENGINE == "infinity":
INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
msgStoreConn = memory_infinity_conn.InfinityConnection()
global AZURE, S3, MINIO, OSS, GCS
if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']:
AZURE = get_base_config("azure", {})

View File

@ -14,6 +14,7 @@
# limitations under the License.
import datetime
import logging
import time
def current_timestamp():
@ -123,4 +124,31 @@ def delta_seconds(date_string: str):
3600.0 # If current time is 2024-01-01 13:00:00
"""
dt = datetime.datetime.strptime(date_string, "%Y-%m-%d %H:%M:%S")
return (datetime.datetime.now() - dt).total_seconds()
return (datetime.datetime.now() - dt).total_seconds()
def format_iso_8601_to_ymd_hms(time_str: str) -> str:
"""
Convert ISO 8601 formatted string to "YYYY-MM-DD HH:MM:SS" format.
Args:
time_str: ISO 8601 date string (e.g. "2024-01-01T12:00:00Z")
Returns:
str: Date string in "YYYY-MM-DD HH:MM:SS" format
Example:
>>> format_iso_8601_to_ymd_hms("2024-01-01T12:00:00Z")
'2024-01-01 12:00:00'
"""
from dateutil import parser
try:
if parser.isoparse(time_str):
dt = datetime.datetime.fromisoformat(time_str.replace("Z", "+00:00"))
return dt.strftime("%Y-%m-%d %H:%M:%S")
else:
return time_str
except Exception as e:
logging.error(str(e))
return time_str

View File

@ -44,17 +44,23 @@ def total_token_count_from_response(resp):
if resp is None:
return 0
if hasattr(resp, "usage") and hasattr(resp.usage, "total_tokens"):
try:
try:
if hasattr(resp, "usage") and hasattr(resp.usage, "total_tokens"):
return resp.usage.total_tokens
except Exception:
pass
except Exception:
pass
if hasattr(resp, "usage_metadata") and hasattr(resp.usage_metadata, "total_tokens"):
try:
try:
if hasattr(resp, "usage_metadata") and hasattr(resp.usage_metadata, "total_tokens"):
return resp.usage_metadata.total_tokens
except Exception:
pass
except Exception:
pass
try:
if hasattr(resp, "meta") and hasattr(resp.meta, "billed_units") and hasattr(resp.meta.billed_units, "input_tokens"):
return resp.meta.billed_units.input_tokens
except Exception:
pass
if isinstance(resp, dict) and 'usage' in resp and 'total_tokens' in resp['usage']:
try:
@ -79,4 +85,3 @@ def total_token_count_from_response(resp):
def truncate(string: str, max_len: int) -> str:
"""Returns truncated text if the length of text exceed max_len."""
return encoder.decode(encoder.encode(string)[:max_len])

View File

@ -31,6 +31,7 @@
"entity_type_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"source_id": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"n_hop_with_weight": {"type": "varchar", "default": ""},
"mom_with_weight": {"type": "varchar", "default": ""},
"removed_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"doc_type_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"toc_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},

View File

@ -762,6 +762,13 @@
"status": "1",
"rank": "940",
"llm": [
{
"llm_name": "glm-4.7",
"tags": "LLM,CHAT,128K",
"max_tokens": 128000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "glm-4.5",
"tags": "LLM,CHAT,128K",
@ -1251,6 +1258,12 @@
"status": "1",
"rank": "810",
"llm": [
{
"llm_name": "MiniMax-M2.1",
"tags": "LLM,CHAT,200k",
"max_tokens": 200000,
"model_type": "chat"
},
{
"llm_name": "MiniMax-M2",
"tags": "LLM,CHAT,200k",

View File

@ -0,0 +1,19 @@
{
"id": {"type": "varchar", "default": ""},
"message_id": {"type": "integer", "default": 0},
"message_type_kwd": {"type": "varchar", "default": ""},
"source_id": {"type": "integer", "default": 0},
"memory_id": {"type": "varchar", "default": ""},
"user_id": {"type": "varchar", "default": ""},
"agent_id": {"type": "varchar", "default": ""},
"session_id": {"type": "varchar", "default": ""},
"valid_at": {"type": "varchar", "default": ""},
"valid_at_flt": {"type": "float", "default": 0.0},
"invalid_at": {"type": "varchar", "default": ""},
"invalid_at_flt": {"type": "float", "default": 0.0},
"forget_at": {"type": "varchar", "default": ""},
"forget_at_flt": {"type": "float", "default": 0.0},
"status_int": {"type": "integer", "default": 1},
"zone_id": {"type": "integer", "default": 0},
"content": {"type": "varchar", "default": "", "analyzer": ["rag-coarse", "rag-fine"], "comment": "content_ltks"}
}

View File

@ -18,6 +18,7 @@ from io import BytesIO
import pandas as pd
from openpyxl import Workbook, load_workbook
from PIL import Image
from rag.nlp import find_codec
@ -109,6 +110,52 @@ class RAGFlowExcelParser:
ws.cell(row=row_num, column=col_num, value=value)
return wb
@staticmethod
def _extract_images_from_worksheet(ws, sheetname=None):
"""
Extract images from a worksheet and enrich them with vision-based descriptions.
Returns: List[dict]
"""
images = getattr(ws, "_images", [])
if not images:
return []
raw_items = []
for img in images:
try:
img_bytes = img._data()
pil_img = Image.open(BytesIO(img_bytes)).convert("RGB")
anchor = img.anchor
if hasattr(anchor, "_from") and hasattr(anchor, "_to"):
r1, c1 = anchor._from.row + 1, anchor._from.col + 1
r2, c2 = anchor._to.row + 1, anchor._to.col + 1
if r1 == r2 and c1 == c2:
span = "single_cell"
else:
span = "multi_cell"
else:
r1, c1 = anchor._from.row + 1, anchor._from.col + 1
r2, c2 = r1, c1
span = "single_cell"
item = {
"sheet": sheetname or ws.title,
"image": pil_img,
"image_description": "",
"row_from": r1,
"col_from": c1,
"row_to": r2,
"col_to": c2,
"span_type": span,
}
raw_items.append(item)
except Exception:
continue
return raw_items
def html(self, fnm, chunk_rows=256):
from html import escape

View File

@ -38,8 +38,8 @@ def vision_figure_parser_figure_data_wrapper(figures_data_without_positions):
def vision_figure_parser_docx_wrapper(sections, tbls, callback=None,**kwargs):
if not tbls:
return []
if not sections:
return tbls
try:
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
callback(0.7, "Visual model detected. Attempting to enhance figure extraction...")
@ -55,6 +55,31 @@ def vision_figure_parser_docx_wrapper(sections, tbls, callback=None,**kwargs):
callback(0.8, f"Visual model error: {e}. Skipping figure parsing enhancement.")
return tbls
def vision_figure_parser_figure_xlsx_wrapper(images,callback=None, **kwargs):
tbls = []
if not images:
return []
try:
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
callback(0.2, "Visual model detected. Attempting to enhance Excel image extraction...")
except Exception:
vision_model = None
if vision_model:
figures_data = [((
img["image"], # Image.Image
[img["image_description"]] # description list (must be list)
),
[
(0, 0, 0, 0, 0) # dummy position
]) for img in images]
try:
parser = VisionFigureParser(vision_model=vision_model, figures_data=figures_data, **kwargs)
callback(0.22, "Parsing images...")
boosted_figures = parser(callback=callback)
tbls.extend(boosted_figures)
except Exception as e:
callback(0.25, f"Excel visual model error: {e}. Skipping vision enhancement.")
return tbls
def vision_figure_parser_pdf_wrapper(tbls, callback=None, **kwargs):
if not tbls:

View File

@ -511,7 +511,7 @@ class MinerUParser(RAGFlowPdfParser):
for output in outputs:
match output["type"]:
case MinerUContentType.TEXT:
section = output["text"]
section = output.get("text", "")
case MinerUContentType.TABLE:
section = output.get("table_body", "") + "\n".join(output.get("table_caption", [])) + "\n".join(
output.get("table_footnote", []))
@ -521,13 +521,13 @@ class MinerUParser(RAGFlowPdfParser):
section = "".join(output.get("image_caption", [])) + "\n" + "".join(
output.get("image_footnote", []))
case MinerUContentType.EQUATION:
section = output["text"]
section = output.get("text", "")
case MinerUContentType.CODE:
section = output["code_body"] + "\n".join(output.get("code_caption", []))
section = output.get("code_body", "") + "\n".join(output.get("code_caption", []))
case MinerUContentType.LIST:
section = "\n".join(output.get("list_items", []))
case MinerUContentType.DISCARDED:
pass
continue # Skip discarded blocks entirely
if section and parse_method == "manual":
sections.append((section, output["type"], self._line_tag(output)))

View File

@ -1447,6 +1447,7 @@ class VisionParser(RAGFlowPdfParser):
def __init__(self, vision_model, *args, **kwargs):
super().__init__(*args, **kwargs)
self.vision_model = vision_model
self.outlines = []
def __images__(self, fnm, zoomin=3, page_from=0, page_to=299, callback=None):
try:

View File

@ -88,12 +88,9 @@ class RAGFlowPptParser:
texts = []
for shape in sorted(
slide.shapes, key=lambda x: ((x.top if x.top is not None else 0) // 10, x.left if x.left is not None else 0)):
try:
txt = self.__extract(shape)
if txt:
texts.append(txt)
except Exception as e:
logging.exception(e)
txt = self.__extract(shape)
if txt:
texts.append(txt)
txts.append("\n".join(texts))
return txts

View File

@ -72,7 +72,7 @@ services:
infinity:
profiles:
- infinity
image: infiniflow/infinity:v0.6.11
image: infiniflow/infinity:v0.6.13
volumes:
- infinity_data:/var/infinity
- ./infinity_conf.toml:/infinity_conf.toml

View File

@ -1,5 +1,5 @@
[general]
version = "0.6.11"
version = "0.6.13"
time_zone = "utc-8"
[network]

0
docker/launch_backend_service.sh Normal file → Executable file
View File

View File

@ -493,18 +493,35 @@ See [here](./guides/agent/best_practices/accelerate_agent_question_answering.md)
### How to use MinerU to parse PDF documents?
MinerU PDF document parsing is available starting from v0.22.0. RAGFlow works only as a remote client to MinerU (>= 2.6.3) and does not install or execute MinerU locally. To use this feature:
From v0.22.0 onwards, RAGFlow includes MinerU (&ge; 2.6.3) as an optional PDF parser of multiple backends. Please note that RAGFlow acts only as a *remote client* for MinerU, calling the MinerU API to parse PDFs and reading the returned files. To use this feature:
1. Prepare a reachable MinerU API service (for example, the FastAPI server provided by MinerU).
2. Configure RAGFlow with remote MinerU settings (environment variables or UI model provider):
- `MINERU_APISERVER`: MinerU API endpoint, for example `http://mineru-host:8886`.
- `MINERU_BACKEND`: MinerU backend, defaults to `pipeline` (supports `vlm-http-client`, `vlm-transformers`, `vlm-vllm-engine`, `vlm-mlx-engine`, `vlm-vllm-async-engine`, `vlm-lmdeploy-engine`).
- `MINERU_SERVER_URL`: (optional) For `vlm-http-client`, the downstream vLLM HTTP server, for example `http://vllm-host:30000`.
- `MINERU_OUTPUT_DIR`: (optional) Local directory to store MinerU API outputs (zip/JSON) before ingestion.
- `MINERU_DELETE_OUTPUT`: Whether to delete temporary output when a temp dir is used (`1` deletes temp outputs; set `0` to keep).
3. In the web UI, navigate to the **Configuration** page of your dataset. Click **Built-in** in the **Ingestion pipeline** section, select a chunking method from the **Built-in** dropdown (which supports PDF parsing), and select **MinerU** in **PDF parser**.
4. If you use a custom ingestion pipeline instead, provide the same MinerU settings and select **MinerU** in the **Parsing method** section of the **Parser** component.
1. Prepare a reachable MinerU API service (FastAPI server).
2. In the **.env** file or from the **Model providers** page in the UI, configure RAGFlow as a remote client to MinerU:
- `MINERU_APISERVER`: The MinerU API endpoint (e.g., `http://mineru-host:8886`).
- `MINERU_BACKEND`: The MinerU backend:
- `"pipeline"` (default)
- `"vlm-http-client"`
- `"vlm-transformers"`
- `"vlm-vllm-engine"`
- `"vlm-mlx-engine"`
- `"vlm-vllm-async-engine"`
- `"vlm-lmdeploy-engine"`.
- `MINERU_SERVER_URL`: (optional) The downstream vLLM HTTP server (e.g., `http://vllm-host:30000`). Applicable when `MINERU_BACKEND` is set to `"vlm-http-client"`.
- `MINERU_OUTPUT_DIR`: (optional) The local directory for holding the outputs of the MinerU API service (zip/JSON) before ingestion.
- `MINERU_DELETE_OUTPUT`: Whether to delete temporary output when a temporary directory is used:
- `1`: Delete.
- `0`: Retain.
3. In the web UI, navigate to your dataset's **Configuration** page and find the **Ingestion pipeline** section:
- If you decide to use a chunking method from the **Built-in** dropdown, ensure it supports PDF parsing, then select **MinerU** from the **PDF parser** dropdown.
- If you use a custom ingestion pipeline instead, select **MinerU** in the **PDF parser** section of the **Parser** component.
:::note
All MinerU environment variables are optional. When set, these values are used to auto-provision a MinerU OCR model for the tenant on first use. To avoid auto-provisioning, skip the environment variable settings and only configure MinerU from the **Model providers** page in the UI.
:::
:::caution WARNING
Third-party visual models are marked **Experimental**, because we have not fully tested these models for the aforementioned data extraction tasks.
:::
---
### How to configure MinerU-specific settings?

View File

@ -24,7 +24,7 @@ We use gVisor to isolate code execution from the host system. Please follow [the
RAGFlow Sandbox is a secure, pluggable code execution backend. It serves as the code executor for the **Code** component. Please follow the [instructions here](https://github.com/infiniflow/ragflow/tree/main/sandbox) to install RAGFlow Sandbox.
:::note Docker client version
The executor manager image now bundles Docker CLI `29.1.0` (API 1.44+). Older images shipped Docker 24.x and will fail against newer Docker daemons with `client version 1.43 is too old`. Pull the latest `infiniflow/sandbox-executor-manager:latest` or rebuild `./sandbox/executor_manager` if you encounter this error.
The executor manager image now bundles Docker CLI `29.1.0` (API 1.44+). Older images shipped Docker 24.x and will fail against newer Docker daemons with `client version 1.43 is too old`. Pull the latest `infiniflow/sandbox-executor-manager:latest` or rebuild it in `./sandbox/executor_manager` if you encounter this error.
:::
:::tip NOTE
@ -134,7 +134,7 @@ Your executor manager image includes Docker CLI 24.x (API 1.43), but the host Do
**Solution**
Pull the latest executor manager image or rebuild it locally to upgrade the built-in Docker client:
Pull the latest executor manager image or rebuild it in `./sandbox/executor_manager` to upgrade the built-in Docker client:
```bash
docker pull infiniflow/sandbox-executor-manager:latest

View File

@ -0,0 +1,90 @@
---
sidebar_position: 30
slug: /http_request_component
---
# HTTP request component
A component that calls remote services.
---
An **HTTP request** component lets you access remote APIs or services by providing a URL and an HTTP method, and then receive the response. You can customize headers, parameters, proxies, and timeout settings, and use common methods like GET and POST. Its useful for exchanging data with external systems in a workflow.
## Prerequisites
- An accessible remote API or service.
- Add a Token or credentials to the request header, if the target service requires authentication.
## Configurations
### Url
*Required*. The complete request address, for example: http://api.example.com/data.
### Method
The HTTP request method to select. Available options:
- GET
- POST
- PUT
### Timeout
The maximum waiting time for the request, in seconds. Defaults to `60`.
### Headers
Custom HTTP headers can be set here, for example:
```http
{
"Accept": "application/json",
"Cache-Control": "no-cache",
"Connection": "keep-alive"
}
```
### Proxy
Optional. The proxy server address to use for this request.
### Clean HTML
`Boolean`: Whether to remove HTML tags from the returned results and keep plain text only.
### Parameter
*Optional*. Parameters to send with the HTTP request. Supports key-value pairs:
- To assign a value using a dynamic system variable, set it as Variable.
- To override these dynamic values under certain conditions and use a fixed static value instead, Value is the appropriate choice.
:::tip NOTE
- For GET requests, these parameters are appended to the end of the URL.
- For POST/PUT requests, they are sent as the request body.
:::
#### Example setting
![](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/http_settings.png)
#### Example response
```html
{ "args": { "App": "RAGFlow", "Query": "How to do?", "Userid": "241ed25a8e1011f0b979424ebc5b108b" }, "headers": { "Accept": "/", "Accept-Encoding": "gzip, deflate, br, zstd", "Cache-Control": "no-cache", "Host": "httpbin.org", "User-Agent": "python-requests/2.32.2", "X-Amzn-Trace-Id": "Root=1-68c9210c-5aab9088580c130a2f065523" }, "origin": "185.36.193.38", "url": "https://httpbin.org/get?Userid=241ed25a8e1011f0b979424ebc5b108b&App=RAGFlow&Query=How+to+do%3F" }
```
### Output
The global variable name for the output of the HTTP request component, which can be referenced by other components in the workflow.
- `Result`: `string` The response returned by the remote service.
## Example
This is a usage example: a workflow sends a GET request from the **Begin** component to `https://httpbin.org/get` via the **HTTP Request_0** component, passes parameters to the server, and finally outputs the result through the **Message_0** component.
![](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/http_usage.PNG)

View File

@ -40,21 +40,31 @@ The output of a PDF parser is `json`. In the PDF parser, you select the parsing
- A third-party visual model from a specific model provider.
:::danger IMPORTANT
MinerU PDF document parsing is available starting from v0.22.0. RAGFlow supports MinerU (>= 2.6.3) as an optional PDF parser with multiple backends. RAGFlow acts only as a **remote client** for MinerU, calling the MinerU API to parse documents, reading the returned output files, and ingesting the parsed content. To use this feature:
Starting from v0.22.0, RAGFlow includes MinerU (&ge; 2.6.3) as an optional PDF parser of multiple backends. Please note that RAGFlow acts only as a *remote client* for MinerU, calling the MinerU API to parse documents and reading the returned files. To use this feature:
:::
1. Prepare a reachable MinerU API service (FastAPI server).
2. Configure RAGFlow with the remote MinerU settings (env or UI model provider):
- `MINERU_APISERVER`: MinerU API endpoint, for example `http://mineru-host:8886`.
- `MINERU_BACKEND`: MinerU backend, defaults to `pipeline` (supports `vlm-http-client`, `vlm-transformers`, `vlm-vllm-engine`, `vlm-mlx-engine`, `vlm-vllm-async-engine`, `vlm-lmdeploy-engine`).
- `MINERU_SERVER_URL`: (optional) For `vlm-http-client`, the downstream vLLM HTTP server, for example `http://vllm-host:30000`.
- `MINERU_OUTPUT_DIR`: (optional) Local directory to store MinerU API outputs (zip/JSON) before ingestion.
- `MINERU_DELETE_OUTPUT`: Whether to delete temporary output when a temp dir is used (`1` deletes temp outputs; set `0` to keep).
3. In the web UI, navigate to the **Configuration** page of your dataset. Click **Built-in** in the **Ingestion pipeline** section, select a chunking method from the **Built-in** dropdown, which supports PDF parsing, and select **MinerU** in **PDF parser**.
4. If you use a custom ingestion pipeline instead, provide the same MinerU settings and select **MinerU** in the **Parsing method** section of the **Parser** component.
2. In the **.env** file or from the **Model providers** page in the UI, configure RAGFlow as a remote client to MinerU:
- `MINERU_APISERVER`: The MinerU API endpoint (e.g., `http://mineru-host:8886`).
- `MINERU_BACKEND`: The MinerU backend:
- `"pipeline"` (default)
- `"vlm-http-client"`
- `"vlm-transformers"`
- `"vlm-vllm-engine"`
- `"vlm-mlx-engine"`
- `"vlm-vllm-async-engine"`
- `"vlm-lmdeploy-engine"`.
- `MINERU_SERVER_URL`: (optional) The downstream vLLM HTTP server (e.g., `http://vllm-host:30000`). Applicable when `MINERU_BACKEND` is set to `"vlm-http-client"`.
- `MINERU_OUTPUT_DIR`: (optional) The local directory for holding the outputs of the MinerU API service (zip/JSON) before ingestion.
- `MINERU_DELETE_OUTPUT`: Whether to delete temporary output when a temporary directory is used:
- `1`: Delete.
- `0`: Retain.
3. In the web UI, navigate to your dataset's **Configuration** page and find the **Ingestion pipeline** section:
- If you decide to use a chunking method from the **Built-in** dropdown, ensure it supports PDF parsing, then select **MinerU** from the **PDF parser** dropdown.
- If you use a custom ingestion pipeline instead, select **MinerU** in the **PDF parser** section of the **Parser** component.
:::note
All MinerU environment variables are optional. If set, RAGFlow will auto-provision a MinerU OCR model for the tenant on first use with these values. To avoid auto-provisioning, configure MinerU solely through the UI and leave the env vars unset.
All MinerU environment variables are optional. When set, these values are used to auto-provision a MinerU OCR model for the tenant on first use. To avoid auto-provisioning, skip the environment variable settings and only configure MinerU from the **Model providers** page in the UI.
:::
:::caution WARNING

View File

@ -29,7 +29,7 @@ The architecture consists of isolated Docker base images for each supported lang
- (Optional) GNU Make for simplified command-line management.
:::tip NOTE
The error message `client version 1.43 is too old. Minimum supported API version is 1.44` indicates that your executor manager image's built-in Docker CLI version is lower than `29.1.0` required by the Docker daemon in use. To solve this issue, pull the latest `infiniflow/sandbox-executor-manager:latest` from Docker Hub (or rebuild `./sandbox/executor_manager`).
The error message `client version 1.43 is too old. Minimum supported API version is 1.44` indicates that your executor manager image's built-in Docker CLI version is lower than `29.1.0` required by the Docker daemon in use. To solve this issue, pull the latest `infiniflow/sandbox-executor-manager:latest` from Docker Hub or rebuild it in `./sandbox/executor_manager`.
:::
## Build Docker base images

View File

@ -45,7 +45,7 @@ Google Cloud external project.
http://localhost:9380/v1/connector/google-drive/oauth/web/callback
```
### If using Docker deployment:
- If using Docker deployment:
**Authorized JavaScript origin:**
```
@ -53,15 +53,16 @@ http://localhost:80
```
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image8.png?raw=true)
### If running from source:
- If running from source:
**Authorized JavaScript origin:**
```
http://localhost:9222
```
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image9.png?raw=true)
5. After saving, click **Download JSON**. This file will later be
uploaded into RAGFlow.
5. After saving, click **Download JSON**. This file will later be uploaded into RAGFlow.
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image10.png?raw=true)

View File

@ -40,21 +40,31 @@ RAGFlow isn't one-size-fits-all. It is built for flexibility and supports deeper
- A third-party visual model from a specific model provider.
:::danger IMPORTANT
MinerU PDF document parsing is available starting from v0.22.0. RAGFlow supports MinerU (>= 2.6.3) as an optional PDF parser with multiple backends. RAGFlow acts only as a **remote client** for MinerU, calling the MinerU API to parse documents, reading the returned output files, and ingesting the parsed content. To use this feature:
1. Prepare a reachable MinerU API service (FastAPI server).
2. Configure RAGFlow with the remote MinerU settings (env or UI model provider):
- `MINERU_APISERVER`: MinerU API endpoint, for example `http://mineru-host:8886`.
- `MINERU_BACKEND`: MinerU backend, defaults to `pipeline` (supports `vlm-http-client`, `vlm-transformers`, `vlm-vllm-engine`, `vlm-mlx-engine`, `vlm-vllm-async-engine`).
- `MINERU_SERVER_URL`: (optional) For `vlm-http-client`, the downstream vLLM HTTP server, for example `http://vllm-host:30000`.
- `MINERU_OUTPUT_DIR`: (optional) Local directory to store MinerU API outputs (zip/JSON) before ingestion.
- `MINERU_DELETE_OUTPUT`: Whether to delete temporary output when a temp dir is used (`1` deletes temp outputs; set `0` to keep).
3. In the web UI, navigate to the **Configuration** page of your dataset. Click **Built-in** in the **Ingestion pipeline** section, select a chunking method from the **Built-in** dropdown, which supports PDF parsing, and select **MinerU** in **PDF parser**.
4. If you use a custom ingestion pipeline instead, provide the same MinerU settings and select **MinerU** in the **Parsing method** section of the **Parser** component.
Starting from v0.22.0, RAGFlow includes MinerU (&ge; 2.6.3) as an optional PDF parser of multiple backends. Please note that RAGFlow acts only as a *remote client* for MinerU, calling the MinerU API to parse documents and reading the returned files. To use this feature:
:::
1. Prepare a reachable MinerU API service (FastAPI server).
2. In the **.env** file or from the **Model providers** page in the UI, configure RAGFlow as a remote client to MinerU:
- `MINERU_APISERVER`: The MinerU API endpoint (e.g., `http://mineru-host:8886`).
- `MINERU_BACKEND`: The MinerU backend:
- `"pipeline"` (default)
- `"vlm-http-client"`
- `"vlm-transformers"`
- `"vlm-vllm-engine"`
- `"vlm-mlx-engine"`
- `"vlm-vllm-async-engine"`
- `"vlm-lmdeploy-engine"`.
- `MINERU_SERVER_URL`: (optional) The downstream vLLM HTTP server (e.g., `http://vllm-host:30000`). Applicable when `MINERU_BACKEND` is set to `"vlm-http-client"`.
- `MINERU_OUTPUT_DIR`: (optional) The local directory for holding the outputs of the MinerU API service (zip/JSON) before ingestion.
- `MINERU_DELETE_OUTPUT`: Whether to delete temporary output when a temporary directory is used:
- `1`: Delete.
- `0`: Retain.
3. In the web UI, navigate to your dataset's **Configuration** page and find the **Ingestion pipeline** section:
- If you decide to use a chunking method from the **Built-in** dropdown, ensure it supports PDF parsing, then select **MinerU** from the **PDF parser** dropdown.
- If you use a custom ingestion pipeline instead, select **MinerU** in the **PDF parser** section of the **Parser** component.
:::note
All MinerU environment variables are optional. When they are set, RAGFlow will auto-create a MinerU OCR model for a tenant on first use using these values. If you do not want this auto-provisioning, configure MinerU only through the UI and leave the env vars unset.
All MinerU environment variables are optional. When set, these values are used to auto-provision a MinerU OCR model for the tenant on first use. To avoid auto-provisioning, skip the environment variable settings and only configure MinerU from the **Model providers** page in the UI.
:::
:::caution WARNING

View File

@ -48,6 +48,7 @@ This API follows the same request and response format as OpenAI's API. It allows
- `"model"`: `string`
- `"messages"`: `object list`
- `"stream"`: `boolean`
- `"extra_body"`: `object` (optional)
##### Request example
@ -59,7 +60,20 @@ curl --request POST \
--data '{
"model": "model",
"messages": [{"role": "user", "content": "Say this is a test!"}],
"stream": true
"stream": true,
"extra_body": {
"reference": true,
"metadata_condition": {
"logic": "and",
"conditions": [
{
"name": "author",
"comparison_operator": "is",
"value": "bob"
}
]
}
}
}'
```
@ -74,6 +88,11 @@ curl --request POST \
- `stream` (*Body parameter*) `boolean`
Whether to receive the response as a stream. Set this to `false` explicitly if you prefer to receive the entire response in one go instead of as a stream.
- `extra_body` (*Body parameter*) `object`
Extra request parameters:
- `reference`: `boolean` - include reference in the final chunk (stream) or in the final message (non-stream).
- `metadata_condition`: `object` - metadata filter conditions applied to retrieval results.
#### Response
Stream:
@ -2236,7 +2255,7 @@ Batch update or delete document-level metadata within a specified dataset. If bo
- `"document_ids"`: `list[string]` *optional*
The associated document ID.
- `"metadata_condition"`: `object`, *optional*
- `"logic"`: Defines the logic relation between conditions if multiple conditions are provided. Options:
- `"logic"`: Defines the logic relation between conditions if multiple conditions are provided. Options:
- `"and"` (default)
- `"or"`
- `"conditions"`: `list[object]` *optional*
@ -2266,7 +2285,7 @@ Batch update or delete document-level metadata within a specified dataset. If bo
- `"deletes`: (*Body parameter*), `list[ojbect]`, *optional*
Deletes metadata of the retrieved documents. Each object: `{ "key": string, "value": string }`.
- `"key"`: `string` The name of the key to delete.
- `"value"`: `string` *Optional* The value of the key to delete.
- `"value"`: `string` *Optional* The value of the key to delete.
- When provided, only keys with a matching value are deleted.
- When omitted, all specified keys are deleted.
@ -2533,7 +2552,7 @@ curl --request POST \
:::caution WARNING
`model_type` is an *internal* parameter, serving solely as a temporary workaround for the current model-configuration design limitations.
Its main purpose is to let *multimodal* models (stored in the database as `"image2text"`) pass backend validation/dispatching. Be mindful that:
Its main purpose is to let *multimodal* models (stored in the database as `"image2text"`) pass backend validation/dispatching. Be mindful that:
- Do *not* treat it as a stable public API.
- It is subject to change or removal in future releases.
@ -3185,6 +3204,7 @@ Asks a specified chat assistant a question to start an AI-powered conversation.
- `"stream"`: `boolean`
- `"session_id"`: `string` (optional)
- `"user_id`: `string` (optional)
- `"metadata_condition"`: `object` (optional)
##### Request example
@ -3207,7 +3227,17 @@ curl --request POST \
{
"question": "Who are you",
"stream": true,
"session_id":"9fa7691cb85c11ef9c5f0242ac120005"
"session_id":"9fa7691cb85c11ef9c5f0242ac120005",
"metadata_condition": {
"logic": "and",
"conditions": [
{
"name": "author",
"comparison_operator": "is",
"value": "bob"
}
]
}
}'
```
@ -3225,6 +3255,13 @@ curl --request POST \
The ID of session. If it is not provided, a new session will be generated.
- `"user_id"`: (*Body parameter*), `string`
The optional user-defined ID. Valid *only* when no `session_id` is provided.
- `"metadata_condition"`: (*Body parameter*), `object`
Optional metadata filter conditions applied to retrieval results.
- `logic`: `string`, one of `and` / `or`
- `conditions`: `list[object]` where each condition contains:
- `name`: `string` metadata key
- `comparison_operator`: `string` (e.g. `is`, `not is`, `contains`, `not contains`, `start with`, `end with`, `empty`, `not empty`, `>`, `<`, ``, ``)
- `value`: `string|number|boolean` (optional for `empty`/`not empty`)
#### Response
@ -3601,6 +3638,8 @@ Asks a specified agent a question to start an AI-powered conversation.
[DONE]
```
- You can optionally return step-by-step trace logs (see `return_trace` below).
:::
#### Request
@ -3616,6 +3655,17 @@ Asks a specified agent a question to start an AI-powered conversation.
- `"session_id"`: `string` (optional)
- `"inputs"`: `object` (optional)
- `"user_id"`: `string` (optional)
- `"return_trace"`: `boolean` (optional, default `false`) — include execution trace logs.
#### Streaming events to handle
When `stream=true`, the server sends Server-Sent Events (SSE). Clients should handle these `event` types:
- `message`: streaming content from Message components.
- `message_end`: end of a Message component; may include `reference`/`attachment`.
- `node_finished`: a component finishes; `data.inputs/outputs/error/elapsed_time` describe the node result. If `return_trace=true`, the trace is attached inside the same `node_finished` event (`data.trace`).
The stream terminates with `[DONE]`.
:::info IMPORTANT
You can include custom parameters in the request body, but first ensure they are defined in the [Begin](../guides/agent/agent_component_reference/begin.mdx) component.
@ -3800,6 +3850,92 @@ data: {
"session_id": "cd097ca083dc11f0858253708ecb6573"
}
data: {
"event": "node_finished",
"message_id": "cecdcb0e83dc11f0858253708ecb6573",
"created_at": 1756364483,
"task_id": "d1f79142831f11f09cc51795b9eb07c0",
"data": {
"inputs": {
"sys.query": "how to install neovim?"
},
"outputs": {
"content": "xxxxxxx",
"_created_time": 15294.0382,
"_elapsed_time": 0.00017
},
"component_id": "Agent:EveryHairsChew",
"component_name": "Agent_1",
"component_type": "Agent",
"error": null,
"elapsed_time": 11.2091,
"created_at": 15294.0382,
"trace": [
{
"component_id": "begin",
"trace": [
{
"inputs": {},
"outputs": {
"_created_time": 15257.7949,
"_elapsed_time": 0.00070
},
"component_id": "begin",
"component_name": "begin",
"component_type": "Begin",
"error": null,
"elapsed_time": 0.00085,
"created_at": 15257.7949
}
]
},
{
"component_id": "Agent:WeakDragonsRead",
"trace": [
{
"inputs": {
"sys.query": "how to install neovim?"
},
"outputs": {
"content": "xxxxxxx",
"_created_time": 15257.7982,
"_elapsed_time": 36.2382
},
"component_id": "Agent:WeakDragonsRead",
"component_name": "Agent_0",
"component_type": "Agent",
"error": null,
"elapsed_time": 36.2385,
"created_at": 15257.7982
}
]
},
{
"component_id": "Agent:EveryHairsChew",
"trace": [
{
"inputs": {
"sys.query": "how to install neovim?"
},
"outputs": {
"content": "xxxxxxxxxxxxxxxxx",
"_created_time": 15294.0382,
"_elapsed_time": 0.00017
},
"component_id": "Agent:EveryHairsChew",
"component_name": "Agent_1",
"component_type": "Agent",
"error": null,
"elapsed_time": 11.2091,
"created_at": 15294.0382
}
]
}
]
},
"session_id": "cd097ca083dc11f0858253708ecb6573"
}
data:[DONE]
```
@ -3874,7 +4010,100 @@ Non-stream:
"doc_name": "INSTALL3.md"
}
}
}
},
"trace": [
{
"component_id": "begin",
"trace": [
{
"component_id": "begin",
"component_name": "begin",
"component_type": "Begin",
"created_at": 15926.567517862,
"elapsed_time": 0.0008189299987861887,
"error": null,
"inputs": {},
"outputs": {
"_created_time": 15926.567517862,
"_elapsed_time": 0.0006958619997021742
}
}
]
},
{
"component_id": "Agent:WeakDragonsRead",
"trace": [
{
"component_id": "Agent:WeakDragonsRead",
"component_name": "Agent_0",
"component_type": "Agent",
"created_at": 15926.569121755,
"elapsed_time": 53.49016142000073,
"error": null,
"inputs": {
"sys.query": "how to install neovim?"
},
"outputs": {
"_created_time": 15926.569121755,
"_elapsed_time": 53.489981256001556,
"content": "xxxxxxxxxxxxxx",
"use_tools": [
{
"arguments": {
"query": "xxxx"
},
"name": "search_my_dateset",
"results": "xxxxxxxxxxx"
}
]
}
}
]
},
{
"component_id": "Agent:EveryHairsChew",
"trace": [
{
"component_id": "Agent:EveryHairsChew",
"component_name": "Agent_1",
"component_type": "Agent",
"created_at": 15980.060569101,
"elapsed_time": 23.61718057500002,
"error": null,
"inputs": {
"sys.query": "how to install neovim?"
},
"outputs": {
"_created_time": 15980.060569101,
"_elapsed_time": 0.0003451630000199657,
"content": "xxxxxxxxxxxx"
}
}
]
},
{
"component_id": "Message:SlickDingosHappen",
"trace": [
{
"component_id": "Message:SlickDingosHappen",
"component_name": "Message_0",
"component_type": "Message",
"created_at": 15980.061302513,
"elapsed_time": 23.61655923699982,
"error": null,
"inputs": {
"Agent:EveryHairsChew@content": "xxxxxxxxx",
"Agent:WeakDragonsRead@content": "xxxxxxxxxxx"
},
"outputs": {
"_created_time": 15980.061302513,
"_elapsed_time": 0.0006695749998471001,
"content": "xxxxxxxxxxx"
}
}
]
}
]
},
"event": "workflow_finished",
"message_id": "c4692a2683d911f0858253708ecb6573",

View File

@ -77,6 +77,19 @@ A complete list of models supported by RAGFlow, which will continue to expand.
If your model is not listed here but has APIs compatible with those of OpenAI, click **OpenAI-API-Compatible** on the **Model providers** page to configure your model.
:::
## Example: AI Badgr (OpenAI-compatible)
You can use **AI Badgr** with RAGFlow via the existing OpenAI-API-Compatible provider.
To configure AI Badgr:
- **Provider**: `OpenAI-API-Compatible`
- **Base URL**: `https://aibadgr.com/api/v1`
- **API Key**: your AI Badgr API key (from the AI Badgr dashboard)
- **Model**: any AI Badgr chat or embedding model ID, as exposed by AI Badgr's OpenAI-compatible APIs
AI Badgr implements OpenAI-compatible endpoints for `/v1/chat/completions`, `/v1/embeddings`, and `/v1/models`, so no additional code changes in RAGFlow are required.
:::note
The list of supported models is extracted from [this source](https://github.com/infiniflow/ragflow/blob/main/rag/llm/__init__.py) and may not be the most current. For the latest supported model list, please refer to the Python file.
:::

View File

@ -23,8 +23,8 @@ def get_urls(use_china_mirrors=False) -> list[Union[str, list[str]]]:
return [
"http://mirrors.tuna.tsinghua.edu.cn/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb",
"http://mirrors.tuna.tsinghua.edu.cn/ubuntu-ports/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_arm64.deb",
"https://repo.huaweicloud.com/repository/maven/org/apache/tika/tika-server-standard/3.0.0/tika-server-standard-3.0.0.jar",
"https://repo.huaweicloud.com/repository/maven/org/apache/tika/tika-server-standard/3.0.0/tika-server-standard-3.0.0.jar.md5",
"https://repo.huaweicloud.com/repository/maven/org/apache/tika/tika-server-standard/3.2.3/tika-server-standard-3.2.3.jar",
"https://repo.huaweicloud.com/repository/maven/org/apache/tika/tika-server-standard/3.2.3/tika-server-standard-3.2.3.jar.md5",
"https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken",
["https://registry.npmmirror.com/-/binary/chrome-for-testing/121.0.6167.85/linux64/chrome-linux64.zip", "chrome-linux64-121-0-6167-85"],
["https://registry.npmmirror.com/-/binary/chrome-for-testing/121.0.6167.85/linux64/chromedriver-linux64.zip", "chromedriver-linux64-121-0-6167-85"],
@ -34,8 +34,8 @@ def get_urls(use_china_mirrors=False) -> list[Union[str, list[str]]]:
return [
"http://archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb",
"http://ports.ubuntu.com/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_arm64.deb",
"https://repo1.maven.org/maven2/org/apache/tika/tika-server-standard/3.0.0/tika-server-standard-3.0.0.jar",
"https://repo1.maven.org/maven2/org/apache/tika/tika-server-standard/3.0.0/tika-server-standard-3.0.0.jar.md5",
"https://repo1.maven.org/maven2/org/apache/tika/tika-server-standard/3.2.3/tika-server-standard-3.2.3.jar",
"https://repo1.maven.org/maven2/org/apache/tika/tika-server-standard/3.2.3/tika-server-standard-3.2.3.jar.md5",
"https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken",
["https://storage.googleapis.com/chrome-for-testing-public/121.0.6167.85/linux64/chrome-linux64.zip", "chrome-linux64-121-0-6167-85"],
["https://storage.googleapis.com/chrome-for-testing-public/121.0.6167.85/linux64/chromedriver-linux64.zip", "chromedriver-linux64-121-0-6167-85"],
@ -49,10 +49,10 @@ repos = [
]
def download_model(repo_id):
local_dir = os.path.abspath(os.path.join("huggingface.co", repo_id))
os.makedirs(local_dir, exist_ok=True)
snapshot_download(repo_id=repo_id, local_dir=local_dir)
def download_model(repository_id):
local_directory = os.path.abspath(os.path.join("huggingface.co", repository_id))
os.makedirs(local_directory, exist_ok=True)
snapshot_download(repo_id=repository_id, local_dir=local_directory)
if __name__ == "__main__":

View File

@ -132,8 +132,8 @@ class EntityResolution(Extractor):
f"{remain_candidates_to_resolve} remain."
)
except Exception as e:
logging.error(f"Error resolving candidate batch: {e}")
except Exception as exception:
logging.error(f"Error resolving candidate batch: {exception}")
tasks = []
@ -251,7 +251,7 @@ class EntityResolution(Extractor):
ans_list = []
records = [r.strip() for r in results.split(record_delimiter)]
for record in records:
pattern_int = f"{re.escape(entity_index_delimiter)}(\d+){re.escape(entity_index_delimiter)}"
pattern_int = fr"{re.escape(entity_index_delimiter)}(\d+){re.escape(entity_index_delimiter)}"
match_int = re.search(pattern_int, record)
res_int = int(str(match_int.group(1) if match_int else '0'))
if res_int > records_length:

View File

@ -71,18 +71,17 @@ class Extractor:
_, system_msg = message_fit_in([{"role": "system", "content": system}], int(self._llm.max_length * 0.92))
response = ""
for attempt in range(3):
if task_id:
if has_canceled(task_id):
logging.info(f"Task {task_id} cancelled during entity resolution candidate processing.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
try:
response = asyncio.run(self._llm.async_chat(system_msg[0]["content"], hist, conf))
response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
if response.find("**ERROR**") >= 0:
raise Exception(response)
set_llm_cache(self._llm.llm_name, system, response, history, gen_conf)
break
except Exception as e:
logging.exception(e)
if attempt == 2:

View File

@ -198,7 +198,7 @@ async def run_graphrag_for_kb(
for d in raw_chunks:
content = d["content_with_weight"]
if num_tokens_from_string(current_chunk + content) < 1024:
if num_tokens_from_string(current_chunk + content) < 4096:
current_chunk += content
else:
if current_chunk:

View File

@ -78,10 +78,6 @@ class GraphExtractor(Extractor):
hint_prompt = self._entity_extract_prompt.format(**self._context_base, input_text=content)
gen_conf = {}
final_result = ""
glean_result = ""
if_loop_result = ""
history = []
logging.info(f"Start processing for {chunk_key}: {content[:25]}...")
if self.callback:
self.callback(msg=f"Start processing for {chunk_key}: {content[:25]}...")

View File

@ -24,11 +24,11 @@ from common.misc_utils import get_uuid
from graphrag.query_analyze_prompt import PROMPTS
from graphrag.utils import get_entity_type2samples, get_llm_cache, set_llm_cache, get_relation
from common.token_utils import num_tokens_from_string
from rag.utils.doc_store_conn import OrderByExpr
from rag.nlp.search import Dealer, index_name
from common.float_utils import get_float
from common import settings
from common.doc_store.doc_store_base import OrderByExpr
class KGSearch(Dealer):

View File

@ -26,9 +26,9 @@ from networkx.readwrite import json_graph
from common.misc_utils import get_uuid
from common.connection_utils import timeout
from rag.nlp import rag_tokenizer, search
from rag.utils.doc_store_conn import OrderByExpr
from rag.utils.redis_conn import REDIS_CONN
from common import settings
from common.doc_store.doc_store_base import OrderByExpr
GRAPH_FIELD_SEP = "<SEP>"

View File

@ -96,7 +96,7 @@ ragflow:
infinity:
image:
repository: infiniflow/infinity
tag: v0.6.11
tag: v0.6.13
pullPolicy: IfNotPresent
pullSecrets: []
storage:

0
memory/__init__.py Normal file
View File

View File

240
memory/services/messages.py Normal file
View File

@ -0,0 +1,240 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
from typing import List
from common import settings
from common.doc_store.doc_store_base import OrderByExpr, MatchExpr
def index_name(uid: str): return f"memory_{uid}"
class MessageService:
@classmethod
def has_index(cls, uid: str, memory_id: str):
index = index_name(uid)
return settings.msgStoreConn.index_exist(index, memory_id)
@classmethod
def create_index(cls, uid: str, memory_id: str, vector_size: int):
index = index_name(uid)
return settings.msgStoreConn.create_idx(index, memory_id, vector_size)
@classmethod
def delete_index(cls, uid: str, memory_id: str):
index = index_name(uid)
return settings.msgStoreConn.delete_idx(index, memory_id)
@classmethod
def insert_message(cls, messages: List[dict], uid: str, memory_id: str):
index = index_name(uid)
[m.update({
"id": f'{memory_id}_{m["message_id"]}',
"status": 1 if m["status"] else 0
}) for m in messages]
return settings.msgStoreConn.insert(messages, index, memory_id)
@classmethod
def update_message(cls, condition: dict, update_dict: dict, uid: str, memory_id: str):
index = index_name(uid)
if "status" in update_dict:
update_dict["status"] = 1 if update_dict["status"] else 0
return settings.msgStoreConn.update(condition, update_dict, index, memory_id)
@classmethod
def delete_message(cls, condition: dict, uid: str, memory_id: str):
index = index_name(uid)
return settings.msgStoreConn.delete(condition, index, memory_id)
@classmethod
def list_message(cls, uid: str, memory_id: str, agent_ids: List[str]=None, keywords: str=None, page: int=1, page_size: int=50):
index = index_name(uid)
filter_dict = {}
if agent_ids:
filter_dict["agent_id"] = agent_ids
if keywords:
filter_dict["session_id"] = keywords
order_by = OrderByExpr()
order_by.desc("valid_at")
res = settings.msgStoreConn.search(
select_fields=[
"message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at",
"invalid_at", "forget_at", "status"
],
highlight_fields=[],
condition=filter_dict,
match_expressions=[], order_by=order_by,
offset=(page-1)*page_size, limit=page_size,
index_names=index, memory_ids=[memory_id], agg_fields=[], hide_forgotten=False
)
total_count = settings.msgStoreConn.get_total(res)
doc_mapping = settings.msgStoreConn.get_fields(res, [
"message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id",
"valid_at", "invalid_at", "forget_at", "status"
])
return {
"message_list": list(doc_mapping.values()),
"total_count": total_count
}
@classmethod
def get_recent_messages(cls, uid_list: List[str], memory_ids: List[str], agent_id: str, session_id: str, limit: int):
index_names = [index_name(uid) for uid in uid_list]
condition_dict = {
"agent_id": agent_id,
"session_id": session_id
}
order_by = OrderByExpr()
order_by.desc("valid_at")
res = settings.msgStoreConn.search(
select_fields=[
"message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at",
"invalid_at", "forget_at", "status", "content"
],
highlight_fields=[],
condition=condition_dict,
match_expressions=[], order_by=order_by,
offset=0, limit=limit,
index_names=index_names, memory_ids=memory_ids, agg_fields=[]
)
doc_mapping = settings.msgStoreConn.get_fields(res, [
"message_id", "message_type", "source_id", "memory_id","user_id", "agent_id", "session_id",
"valid_at", "invalid_at", "forget_at", "status", "content"
])
return list(doc_mapping.values())
@classmethod
def search_message(cls, memory_ids: List[str], condition_dict: dict, uid_list: List[str], match_expressions:list[MatchExpr], top_n: int):
index_names = [index_name(uid) for uid in uid_list]
# filter only valid messages by default
if "status" not in condition_dict:
condition_dict["status"] = 1
order_by = OrderByExpr()
order_by.desc("valid_at")
res = settings.msgStoreConn.search(
select_fields=[
"message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id",
"valid_at",
"invalid_at", "forget_at", "status", "content"
],
highlight_fields=[],
condition=condition_dict,
match_expressions=match_expressions,
order_by=order_by,
offset=0, limit=top_n,
index_names=index_names, memory_ids=memory_ids, agg_fields=[]
)
docs = settings.msgStoreConn.get_fields(res, [
"message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at",
"invalid_at", "forget_at", "status", "content"
])
return list(docs.values())
@staticmethod
def calculate_message_size(message: dict):
return sys.getsizeof(message["content"]) + sys.getsizeof(message["content_embed"][0]) * len(message["content_embed"])
@classmethod
def calculate_memory_size(cls, memory_ids: List[str], uid_list: List[str]):
index_names = [index_name(uid) for uid in uid_list]
order_by = OrderByExpr()
order_by.desc("valid_at")
res = settings.msgStoreConn.search(
select_fields=["memory_id", "content", "content_embed"],
highlight_fields=[],
condition={},
match_expressions=[],
order_by=order_by,
offset=0, limit=2000*len(memory_ids),
index_names=index_names, memory_ids=memory_ids, agg_fields=[], hide_forgotten=False
)
docs = settings.msgStoreConn.get_fields(res, ["memory_id", "content", "content_embed"])
size_dict = {}
for doc in docs.values():
if size_dict.get(doc["memory_id"]):
size_dict[doc["memory_id"]] += cls.calculate_message_size(doc)
else:
size_dict[doc["memory_id"]] = cls.calculate_message_size(doc)
return size_dict
@classmethod
def pick_messages_to_delete_by_fifo(cls, memory_id: str, uid: str, size_to_delete: int):
select_fields = ["message_id", "content", "content_embed"]
_index_name = index_name(uid)
res = settings.msgStoreConn.get_forgotten_messages(select_fields, _index_name, memory_id)
message_list = settings.msgStoreConn.get_fields(res, select_fields)
current_size = 0
ids_to_remove = []
for message in message_list:
if current_size < size_to_delete:
current_size += cls.calculate_message_size(message)
ids_to_remove.append(message["message_id"])
else:
return ids_to_remove, current_size
if current_size >= size_to_delete:
return ids_to_remove, current_size
order_by = OrderByExpr()
order_by.asc("valid_at")
res = settings.msgStoreConn.search(
select_fields=["memory_id", "content", "content_embed"],
highlight_fields=[],
condition={},
match_expressions=[],
order_by=order_by,
offset=0, limit=2000,
index_names=[_index_name], memory_ids=[memory_id], agg_fields=[]
)
docs = settings.msgStoreConn.get_fields(res, select_fields)
for doc in docs.values():
if current_size < size_to_delete:
current_size += cls.calculate_message_size(doc)
ids_to_remove.append(doc["memory_id"])
else:
return ids_to_remove, current_size
return ids_to_remove, current_size
@classmethod
def get_by_message_id(cls, memory_id: str, message_id: int, uid: str):
index = index_name(uid)
doc_id = f'{memory_id}_{message_id}'
return settings.msgStoreConn.get(doc_id, index, [memory_id])
@classmethod
def get_max_message_id(cls, uid_list: List[str], memory_ids: List[str]):
order_by = OrderByExpr()
order_by.desc("message_id")
index_names = [index_name(uid) for uid in uid_list]
res = settings.msgStoreConn.search(
select_fields=["message_id"],
highlight_fields=[],
condition={},
match_expressions=[],
order_by=order_by,
offset=0, limit=1,
index_names=index_names, memory_ids=memory_ids,
agg_fields=[], hide_forgotten=False
)
docs = settings.msgStoreConn.get_fields(res, ["message_id"])
if not docs:
return 1
else:
latest_msg = list(docs.values())[0]
return int(latest_msg["message_id"])

185
memory/services/query.py Normal file
View File

@ -0,0 +1,185 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import logging
import json
import numpy as np
from common.query_base import QueryBase
from common.doc_store.doc_store_base import MatchDenseExpr, MatchTextExpr
from common.float_utils import get_float
from rag.nlp import rag_tokenizer, term_weight, synonym
def get_vector(txt, emb_mdl, topk=10, similarity=0.1):
if isinstance(similarity, str) and len(similarity) > 0:
try:
similarity = float(similarity)
except Exception as e:
logging.warning(f"Convert similarity '{similarity}' to float failed: {e}. Using default 0.1")
similarity = 0.1
qv, _ = emb_mdl.encode_queries(txt)
shape = np.array(qv).shape
if len(shape) > 1:
raise Exception(
f"Dealer.get_vector returned array's shape {shape} doesn't match expectation(exact one dimension).")
embedding_data = [get_float(v) for v in qv]
vector_column_name = f"q_{len(embedding_data)}_vec"
return MatchDenseExpr(vector_column_name, embedding_data, 'float', 'cosine', topk, {"similarity": similarity})
class MsgTextQuery(QueryBase):
def __init__(self):
self.tw = term_weight.Dealer()
self.syn = synonym.Dealer()
self.query_fields = [
"content"
]
def question(self, txt, tbl="messages", min_match: float=0.6):
original_query = txt
txt = MsgTextQuery.add_space_between_eng_zh(txt)
txt = re.sub(
r"[ :|\r\n\t,,。??/`!&^%%()\[\]{}<>]+",
" ",
rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(txt.lower())),
).strip()
otxt = txt
txt = MsgTextQuery.rmWWW(txt)
if not self.is_chinese(txt):
txt = self.rmWWW(txt)
tks = rag_tokenizer.tokenize(txt).split()
keywords = [t for t in tks if t]
tks_w = self.tw.weights(tks, preprocess=False)
tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w]
tks_w = [(re.sub(r"^[a-z0-9]$", "", tk), w) for tk, w in tks_w if tk]
tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk]
tks_w = [(tk.strip(), w) for tk, w in tks_w if tk.strip()]
syns = []
for tk, w in tks_w[:256]:
syn = self.syn.lookup(tk)
syn = rag_tokenizer.tokenize(" ".join(syn)).split()
keywords.extend(syn)
syn = ["\"{}\"^{:.4f}".format(s, w / 4.) for s in syn if s.strip()]
syns.append(" ".join(syn))
q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if
tk and not re.match(r"[.^+\(\)-]", tk)]
for i in range(1, len(tks_w)):
left, right = tks_w[i - 1][0].strip(), tks_w[i][0].strip()
if not left or not right:
continue
q.append(
'"%s %s"^%.4f'
% (
tks_w[i - 1][0],
tks_w[i][0],
max(tks_w[i - 1][1], tks_w[i][1]) * 2,
)
)
if not q:
q.append(txt)
query = " ".join(q)
return MatchTextExpr(
self.query_fields, query, 100, {"original_query": original_query}
), keywords
def need_fine_grained_tokenize(tk):
if len(tk) < 3:
return False
if re.match(r"[0-9a-z\.\+#_\*-]+$", tk):
return False
return True
txt = self.rmWWW(txt)
qs, keywords = [], []
for tt in self.tw.split(txt)[:256]: # .split():
if not tt:
continue
keywords.append(tt)
twts = self.tw.weights([tt])
syns = self.syn.lookup(tt)
if syns and len(keywords) < 32:
keywords.extend(syns)
logging.debug(json.dumps(twts, ensure_ascii=False))
tms = []
for tk, w in sorted(twts, key=lambda x: x[1] * -1):
sm = (
rag_tokenizer.fine_grained_tokenize(tk).split()
if need_fine_grained_tokenize(tk)
else []
)
sm = [
re.sub(
r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+",
"",
m,
)
for m in sm
]
sm = [self.sub_special_char(m) for m in sm if len(m) > 1]
sm = [m for m in sm if len(m) > 1]
if len(keywords) < 32:
keywords.append(re.sub(r"[ \\\"']+", "", tk))
keywords.extend(sm)
tk_syns = self.syn.lookup(tk)
tk_syns = [self.sub_special_char(s) for s in tk_syns]
if len(keywords) < 32:
keywords.extend([s for s in tk_syns if s])
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
if len(keywords) >= 32:
break
tk = self.sub_special_char(tk)
if tk.find(" ") > 0:
tk = '"%s"' % tk
if tk_syns:
tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns)
if sm:
tk = f'{tk} OR "%s" OR ("%s"~2)^0.5' % (" ".join(sm), " ".join(sm))
if tk.strip():
tms.append((tk, w))
tms = " ".join([f"({t})^{w}" for t, w in tms])
if len(twts) > 1:
tms += ' ("%s"~2)^1.5' % rag_tokenizer.tokenize(tt)
syns = " OR ".join(
[
'"%s"'
% rag_tokenizer.tokenize(self.sub_special_char(s))
for s in syns
]
)
if syns and tms:
tms = f"({tms})^5 OR ({syns})^0.7"
qs.append(tms)
if qs:
query = " OR ".join([f"({t})" for t in qs if t])
if not query:
query = otxt
return MatchTextExpr(
self.query_fields, query, 100, {"minimum_should_match": min_match, "original_query": original_query}
), keywords
return None, keywords

0
memory/utils/__init__.py Normal file
View File

494
memory/utils/es_conn.py Normal file
View File

@ -0,0 +1,494 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import json
import time
import copy
from elasticsearch import NotFoundError
from elasticsearch_dsl import UpdateByQuery, Q, Search
from elastic_transport import ConnectionTimeout
from common.decorator import singleton
from common.doc_store.doc_store_base import MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, FusionExpr
from common.doc_store.es_conn_base import ESConnectionBase
from common.float_utils import get_float
from common.constants import PAGERANK_FLD, TAG_FLD
ATTEMPT_TIME = 2
@singleton
class ESConnection(ESConnectionBase):
@staticmethod
def convert_field_name(field_name: str) -> str:
match field_name:
case "message_type":
return "message_type_kwd"
case "status":
return "status_int"
case "content":
return "content_ltks"
case _:
return field_name
@staticmethod
def map_message_to_es_fields(message: dict) -> dict:
"""
Map message dictionary fields to Elasticsearch document/Infinity fields.
:param message: A dictionary containing message details.
:return: A dictionary formatted for Elasticsearch/Infinity indexing.
"""
storage_doc = {
"id": message.get("id"),
"message_id": message["message_id"],
"message_type_kwd": message["message_type"],
"source_id": message["source_id"],
"memory_id": message["memory_id"],
"user_id": message["user_id"],
"agent_id": message["agent_id"],
"session_id": message["session_id"],
"valid_at": message["valid_at"],
"invalid_at": message["invalid_at"],
"forget_at": message["forget_at"],
"status_int": 1 if message["status"] else 0,
"zone_id": message.get("zone_id", 0),
"content_ltks": message["content"],
f"q_{len(message['content_embed'])}_vec": message["content_embed"],
}
return storage_doc
@staticmethod
def get_message_from_es_doc(doc: dict) -> dict:
"""
Convert an Elasticsearch/Infinity document back to a message dictionary.
:param doc: A dictionary representing the Elasticsearch/Infinity document.
:return: A dictionary formatted as a message.
"""
embd_field_name = next((key for key in doc.keys() if re.match(r"q_\d+_vec", key)), None)
message = {
"message_id": doc["message_id"],
"message_type": doc["message_type_kwd"],
"source_id": doc["source_id"] if doc["source_id"] else None,
"memory_id": doc["memory_id"],
"user_id": doc.get("user_id", ""),
"agent_id": doc["agent_id"],
"session_id": doc["session_id"],
"zone_id": doc.get("zone_id", 0),
"valid_at": doc["valid_at"],
"invalid_at": doc.get("invalid_at", "-"),
"forget_at": doc.get("forget_at", "-"),
"status": bool(int(doc["status_int"])),
"content": doc.get("content_ltks", ""),
"content_embed": doc.get(embd_field_name, []) if embd_field_name else [],
}
if doc.get("id"):
message["id"] = doc["id"]
return message
"""
CRUD operations
"""
def search(
self, select_fields: list[str],
highlight_fields: list[str],
condition: dict,
match_expressions: list[MatchExpr],
order_by: OrderByExpr,
offset: int,
limit: int,
index_names: str | list[str],
memory_ids: list[str],
agg_fields: list[str] | None = None,
rank_feature: dict | None = None,
hide_forgotten: bool = True
):
"""
Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html
"""
if isinstance(index_names, str):
index_names = index_names.split(",")
assert isinstance(index_names, list) and len(index_names) > 0
assert "_id" not in condition
bool_query = Q("bool", must=[], must_not=[])
if hide_forgotten:
# filter not forget
bool_query.must_not.append(Q("exists", field="forget_at"))
condition["memory_id"] = memory_ids
for k, v in condition.items():
if k == "session_id" and v:
bool_query.filter.append(Q("query_string", **{"query": f"*{v}*", "fields": ["session_id"], "analyze_wildcard": True}))
continue
if not v:
continue
if isinstance(v, list):
bool_query.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
bool_query.filter.append(Q("term", **{k: v}))
else:
raise Exception(
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
s = Search()
vector_similarity_weight = 0.5
for m in match_expressions:
if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance(match_expressions[1],
MatchDenseExpr) and isinstance(
match_expressions[2], FusionExpr)
weights = m.fusion_params["weights"]
vector_similarity_weight = get_float(weights.split(",")[1])
for m in match_expressions:
if isinstance(m, MatchTextExpr):
minimum_should_match = m.extra_options.get("minimum_should_match", 0.0)
if isinstance(minimum_should_match, float):
minimum_should_match = str(int(minimum_should_match * 100)) + "%"
bool_query.must.append(Q("query_string", fields=[self.convert_field_name(f) for f in m.fields],
type="best_fields", query=m.matching_text,
minimum_should_match=minimum_should_match,
boost=1))
bool_query.boost = 1.0 - vector_similarity_weight
elif isinstance(m, MatchDenseExpr):
assert (bool_query is not None)
similarity = 0.0
if "similarity" in m.extra_options:
similarity = m.extra_options["similarity"]
s = s.knn(self.convert_field_name(m.vector_column_name),
m.topn,
m.topn * 2,
query_vector=list(m.embedding_data),
filter=bool_query.to_dict(),
similarity=similarity,
)
if bool_query and rank_feature:
for fld, sc in rank_feature.items():
if fld != PAGERANK_FLD:
fld = f"{TAG_FLD}.{fld}"
bool_query.should.append(Q("rank_feature", field=fld, linear={}, boost=sc))
if bool_query:
s = s.query(bool_query)
for field in highlight_fields:
s = s.highlight(field)
if order_by:
orders = list()
for field, order in order_by.fields:
order = "asc" if order == 0 else "desc"
if field.endswith("_int") or field.endswith("_flt"):
order_info = {"order": order, "unmapped_type": "float"}
else:
order_info = {"order": order, "unmapped_type": "text"}
orders.append({field: order_info})
s = s.sort(*orders)
if agg_fields:
for fld in agg_fields:
s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)
if limit > 0:
s = s[offset:offset + limit]
q = s.to_dict()
self.logger.debug(f"ESConnection.search {str(index_names)} query: " + json.dumps(q))
for i in range(ATTEMPT_TIME):
try:
#print(json.dumps(q, ensure_ascii=False))
res = self.es.search(index=index_names,
body=q,
timeout="600s",
# search_type="dfs_query_then_fetch",
track_total_hits=True,
_source=True)
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
self.logger.debug(f"ESConnection.search {str(index_names)} res: " + str(res))
return res
except ConnectionTimeout:
self.logger.exception("ES request timeout")
self._connect()
continue
except Exception as e:
self.logger.exception(f"ESConnection.search {str(index_names)} query: " + str(q) + str(e))
raise e
self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
raise Exception("ESConnection.search timeout.")
def get_forgotten_messages(self, select_fields: list[str], index_name: str, memory_id: str, limit: int=2000):
bool_query = Q("bool", must_not=[])
bool_query.must_not.append(Q("term", forget_at=None))
bool_query.filter.append(Q("term", memory_id=memory_id))
# from old to new
order_by = OrderByExpr()
order_by.asc("forget_at")
# build search
s = Search()
s = s.query(bool_query)
s = s.sort(order_by)
s = s[:limit]
q = s.to_dict()
# search
for i in range(ATTEMPT_TIME):
try:
res = self.es.search(index=index_name, body=q, timeout="600s", track_total_hits=True, _source=True)
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
self.logger.debug(f"ESConnection.search {str(index_name)} res: " + str(res))
return res
except ConnectionTimeout:
self.logger.exception("ES request timeout")
self._connect()
continue
except Exception as e:
self.logger.exception(f"ESConnection.search {str(index_name)} query: " + str(q) + str(e))
raise e
self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
raise Exception("ESConnection.search timeout.")
def get(self, doc_id: str, index_name: str, memory_ids: list[str]) -> dict | None:
for i in range(ATTEMPT_TIME):
try:
res = self.es.get(index=index_name,
id=doc_id, source=True, )
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
message = res["_source"]
message["id"] = doc_id
return self.get_message_from_es_doc(message)
except NotFoundError:
return None
except Exception as e:
self.logger.exception(f"ESConnection.get({doc_id}) got exception")
raise e
self.logger.error(f"ESConnection.get timeout for {ATTEMPT_TIME} times!")
raise Exception("ESConnection.get timeout.")
def insert(self, documents: list[dict], index_name: str, memory_id: str = None) -> list[str]:
# Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html
operations = []
for d in documents:
assert "_id" not in d
assert "id" in d
d_copy_raw = copy.deepcopy(d)
d_copy = self.map_message_to_es_fields(d_copy_raw)
d_copy["memory_id"] = memory_id
meta_id = d_copy.pop("id", "")
operations.append(
{"index": {"_index": index_name, "_id": meta_id}})
operations.append(d_copy)
res = []
for _ in range(ATTEMPT_TIME):
try:
res = []
r = self.es.bulk(index=index_name, operations=operations,
refresh=False, timeout="60s")
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
return res
for item in r["items"]:
for action in ["create", "delete", "index", "update"]:
if action in item and "error" in item[action]:
res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"]))
return res
except ConnectionTimeout:
self.logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
res.append(str(e))
self.logger.warning("ESConnection.insert got exception: " + str(e))
return res
def update(self, condition: dict, new_value: dict, index_name: str, memory_id: str) -> bool:
doc = copy.deepcopy(new_value)
update_dict = {self.convert_field_name(k): v for k, v in doc.items()}
update_dict.pop("id", None)
condition_dict = {self.convert_field_name(k): v for k, v in condition.items()}
condition_dict["memory_id"] = memory_id
if "id" in condition_dict and isinstance(condition_dict["id"], str):
# update specific single document
message_id = condition_dict["id"]
for i in range(ATTEMPT_TIME):
for k in update_dict.keys():
if "feas" != k.split("_")[-1]:
continue
try:
self.es.update(index=index_name, id=message_id, script=f"ctx._source.remove(\"{k}\");")
except Exception:
self.logger.exception(f"ESConnection.update(index={index_name}, id={message_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
try:
self.es.update(index=index_name, id=message_id, doc=update_dict)
return True
except Exception as e:
self.logger.exception(
f"ESConnection.update(index={index_name}, id={message_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: " + str(e))
break
return False
# update unspecific maybe-multiple documents
bool_query = Q("bool")
for k, v in condition_dict.items():
if not isinstance(k, str) or not v:
continue
if k == "exists":
bool_query.filter.append(Q("exists", field=v))
continue
if isinstance(v, list):
bool_query.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
bool_query.filter.append(Q("term", **{k: v}))
else:
raise Exception(
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
scripts = []
params = {}
for k, v in update_dict.items():
if k == "remove":
if isinstance(v, str):
scripts.append(f"ctx._source.remove('{v}');")
if isinstance(v, dict):
for kk, vv in v.items():
scripts.append(f"int i=ctx._source.{kk}.indexOf(params.p_{kk});ctx._source.{kk}.remove(i);")
params[f"p_{kk}"] = vv
continue
if k == "add":
if isinstance(v, dict):
for kk, vv in v.items():
scripts.append(f"ctx._source.{kk}.add(params.pp_{kk});")
params[f"pp_{kk}"] = vv.strip()
continue
if (not isinstance(k, str) or not v) and k != "status_int":
continue
if isinstance(v, str):
v = re.sub(r"(['\n\r]|\\.)", " ", v)
params[f"pp_{k}"] = v
scripts.append(f"ctx._source.{k}=params.pp_{k};")
elif isinstance(v, int) or isinstance(v, float):
scripts.append(f"ctx._source.{k}={v};")
elif isinstance(v, list):
scripts.append(f"ctx._source.{k}=params.pp_{k};")
params[f"pp_{k}"] = json.dumps(v, ensure_ascii=False)
else:
raise Exception(
f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
ubq = UpdateByQuery(
index=index_name).using(
self.es).query(bool_query)
ubq = ubq.script(source="".join(scripts), params=params)
ubq = ubq.params(refresh=True)
ubq = ubq.params(slices=5)
ubq = ubq.params(conflicts="proceed")
for _ in range(ATTEMPT_TIME):
try:
_ = ubq.execute()
return True
except ConnectionTimeout:
self.logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
self.logger.error("ESConnection.update got exception: " + str(e) + "\n".join(scripts))
break
return False
def delete(self, condition: dict, index_name: str, memory_id: str) -> int:
assert "_id" not in condition
condition_dict = {self.convert_field_name(k): v for k, v in condition.items()}
condition_dict["memory_id"] = memory_id
if "id" in condition_dict:
message_ids = condition_dict["id"]
if not isinstance(message_ids, list):
message_ids = [message_ids]
if not message_ids: # when message_ids is empty, delete all
qry = Q("match_all")
else:
qry = Q("ids", values=message_ids)
else:
qry = Q("bool")
for k, v in condition_dict.items():
if k == "exists":
qry.filter.append(Q("exists", field=v))
elif k == "must_not":
if isinstance(v, dict):
for kk, vv in v.items():
if kk == "exists":
qry.must_not.append(Q("exists", field=vv))
elif isinstance(v, list):
qry.must.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
qry.must.append(Q("term", **{k: v}))
else:
raise Exception("Condition value must be int, str or list.")
self.logger.debug("ESConnection.delete query: " + json.dumps(qry.to_dict()))
for _ in range(ATTEMPT_TIME):
try:
res = self.es.delete_by_query(
index=index_name,
body=Search().query(qry).to_dict(),
refresh=True)
return res["deleted"]
except ConnectionTimeout:
self.logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
self.logger.warning("ESConnection.delete got exception: " + str(e))
if re.search(r"(not_found)", str(e), re.IGNORECASE):
return 0
return 0
"""
Helper functions for search result
"""
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
res_fields = {}
if not fields:
return {}
for doc in self._get_source(res):
message = self.get_message_from_es_doc(doc)
m = {}
for n, v in message.items():
if n not in fields:
continue
if isinstance(v, list):
m[n] = v
continue
if n in ["message_id", "source_id", "valid_at", "invalid_at", "forget_at", "status"] and isinstance(v, (int, float, bool)):
m[n] = v
continue
if not isinstance(v, str):
m[n] = str(v)
else:
m[n] = v
if m:
res_fields[doc["id"]] = m
return res_fields

View File

@ -0,0 +1,467 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import json
import copy
from infinity.common import InfinityException, SortType
from infinity.errors import ErrorCode
from common.decorator import singleton
import pandas as pd
from common.constants import PAGERANK_FLD, TAG_FLD
from common.doc_store.doc_store_base import MatchExpr, MatchTextExpr, MatchDenseExpr, FusionExpr, OrderByExpr
from common.doc_store.infinity_conn_base import InfinityConnectionBase
from common.time_utils import date_string_to_timestamp
@singleton
class InfinityConnection(InfinityConnectionBase):
def __init__(self):
super().__init__()
self.mapping_file_name = "message_infinity_mapping.json"
"""
Dataframe and fields convert
"""
@staticmethod
def field_keyword(field_name: str):
# no keywords right now
return False
@staticmethod
def convert_message_field_to_infinity(field_name: str):
match field_name:
case "message_type":
return "message_type_kwd"
case "status":
return "status_int"
case _:
return field_name
@staticmethod
def convert_infinity_field_to_message(field_name: str):
if field_name.startswith("message_type"):
return "message_type"
if field_name.startswith("status"):
return "status"
if re.match(r"q_\d+_vec", field_name):
return "content_embed"
return field_name
def convert_select_fields(self, output_fields: list[str]) -> list[str]:
return list({self.convert_message_field_to_infinity(f) for f in output_fields})
@staticmethod
def convert_matching_field(field_weight_str: str) -> str:
tokens = field_weight_str.split("^")
field = tokens[0]
if field == "content":
field = "content@ft_contentm_rag_fine"
tokens[0] = field
return "^".join(tokens)
@staticmethod
def convert_condition_and_order_field(field_name: str):
match field_name:
case "message_type":
return "message_type_kwd"
case "status":
return "status_int"
case "valid_at":
return "valid_at_flt"
case "invalid_at":
return "invalid_at_flt"
case "forget_at":
return "forget_at_flt"
case _:
return field_name
"""
CRUD operations
"""
def search(
self,
select_fields: list[str],
highlight_fields: list[str],
condition: dict,
match_expressions: list[MatchExpr],
order_by: OrderByExpr,
offset: int,
limit: int,
index_names: str | list[str],
memory_ids: list[str],
agg_fields: list[str] | None = None,
rank_feature: dict | None = None,
hide_forgotten: bool = True,
) -> tuple[pd.DataFrame, int]:
"""
BUG: Infinity returns empty for a highlight field if the query string doesn't use that field.
"""
if isinstance(index_names, str):
index_names = index_names.split(",")
assert isinstance(index_names, list) and len(index_names) > 0
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
df_list = list()
table_list = list()
if hide_forgotten:
condition.update({"must_not": {"exists": "forget_at_flt"}})
output = select_fields.copy()
output = self.convert_select_fields(output)
if agg_fields is None:
agg_fields = []
for essential_field in ["id"] + agg_fields:
if essential_field not in output:
output.append(essential_field)
score_func = ""
score_column = ""
for matchExpr in match_expressions:
if isinstance(matchExpr, MatchTextExpr):
score_func = "score()"
score_column = "SCORE"
break
if not score_func:
for matchExpr in match_expressions:
if isinstance(matchExpr, MatchDenseExpr):
score_func = "similarity()"
score_column = "SIMILARITY"
break
if match_expressions:
if score_func not in output:
output.append(score_func)
if PAGERANK_FLD not in output:
output.append(PAGERANK_FLD)
output = [f for f in output if f != "_score"]
if limit <= 0:
# ElasticSearch default limit is 10000
limit = 10000
# Prepare expressions common to all tables
filter_cond = None
filter_fulltext = ""
if condition:
condition_dict = {self.convert_condition_and_order_field(k): v for k, v in condition.items()}
table_found = False
for indexName in index_names:
for mem_id in memory_ids:
table_name = f"{indexName}_{mem_id}"
try:
filter_cond = self.equivalent_condition_to_str(condition_dict, db_instance.get_table(table_name))
table_found = True
break
except Exception:
pass
if table_found:
break
if not table_found:
self.logger.error(f"No valid tables found for indexNames {index_names} and memoryIds {memory_ids}")
return pd.DataFrame(), 0
for matchExpr in match_expressions:
if isinstance(matchExpr, MatchTextExpr):
if filter_cond and "filter" not in matchExpr.extra_options:
matchExpr.extra_options.update({"filter": filter_cond})
matchExpr.fields = [self.convert_matching_field(field) for field in matchExpr.fields]
fields = ",".join(matchExpr.fields)
filter_fulltext = f"filter_fulltext('{fields}', '{matchExpr.matching_text}')"
if filter_cond:
filter_fulltext = f"({filter_cond}) AND {filter_fulltext}"
minimum_should_match = matchExpr.extra_options.get("minimum_should_match", 0.0)
if isinstance(minimum_should_match, float):
str_minimum_should_match = str(int(minimum_should_match * 100)) + "%"
matchExpr.extra_options["minimum_should_match"] = str_minimum_should_match
# Add rank_feature support
if rank_feature and "rank_features" not in matchExpr.extra_options:
# Convert rank_feature dict to Infinity's rank_features string format
# Format: "field^feature_name^weight,field^feature_name^weight"
rank_features_list = []
for feature_name, weight in rank_feature.items():
# Use TAG_FLD as the field containing rank features
rank_features_list.append(f"{TAG_FLD}^{feature_name}^{weight}")
if rank_features_list:
matchExpr.extra_options["rank_features"] = ",".join(rank_features_list)
for k, v in matchExpr.extra_options.items():
if not isinstance(v, str):
matchExpr.extra_options[k] = str(v)
self.logger.debug(f"INFINITY search MatchTextExpr: {json.dumps(matchExpr.__dict__)}")
elif isinstance(matchExpr, MatchDenseExpr):
if filter_fulltext and "filter" not in matchExpr.extra_options:
matchExpr.extra_options.update({"filter": filter_fulltext})
for k, v in matchExpr.extra_options.items():
if not isinstance(v, str):
matchExpr.extra_options[k] = str(v)
similarity = matchExpr.extra_options.get("similarity")
if similarity:
matchExpr.extra_options["threshold"] = similarity
del matchExpr.extra_options["similarity"]
self.logger.debug(f"INFINITY search MatchDenseExpr: {json.dumps(matchExpr.__dict__)}")
elif isinstance(matchExpr, FusionExpr):
self.logger.debug(f"INFINITY search FusionExpr: {json.dumps(matchExpr.__dict__)}")
order_by_expr_list = list()
if order_by.fields:
for order_field in order_by.fields:
order_field_name = self.convert_condition_and_order_field(order_field[0])
if order_field[1] == 0:
order_by_expr_list.append((order_field_name, SortType.Asc))
else:
order_by_expr_list.append((order_field_name, SortType.Desc))
total_hits_count = 0
# Scatter search tables and gather the results
for indexName in index_names:
for memory_id in memory_ids:
table_name = f"{indexName}_{memory_id}"
try:
table_instance = db_instance.get_table(table_name)
except Exception:
continue
table_list.append(table_name)
builder = table_instance.output(output)
if len(match_expressions) > 0:
for matchExpr in match_expressions:
if isinstance(matchExpr, MatchTextExpr):
fields = ",".join(matchExpr.fields)
builder = builder.match_text(
fields,
matchExpr.matching_text,
matchExpr.topn,
matchExpr.extra_options.copy(),
)
elif isinstance(matchExpr, MatchDenseExpr):
builder = builder.match_dense(
matchExpr.vector_column_name,
matchExpr.embedding_data,
matchExpr.embedding_data_type,
matchExpr.distance_type,
matchExpr.topn,
matchExpr.extra_options.copy(),
)
elif isinstance(matchExpr, FusionExpr):
builder = builder.fusion(matchExpr.method, matchExpr.topn, matchExpr.fusion_params)
else:
if filter_cond and len(filter_cond) > 0:
builder.filter(filter_cond)
if order_by.fields:
builder.sort(order_by_expr_list)
builder.offset(offset).limit(limit)
mem_res, extra_result = builder.option({"total_hits_count": True}).to_df()
if extra_result:
total_hits_count += int(extra_result["total_hits_count"])
self.logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(mem_res)}")
df_list.append(mem_res)
self.connPool.release_conn(inf_conn)
res = self.concat_dataframes(df_list, output)
if match_expressions:
res["_score"] = res[score_column] + res[PAGERANK_FLD]
res = res.sort_values(by="_score", ascending=False).reset_index(drop=True)
res = res.head(limit)
self.logger.debug(f"INFINITY search final result: {str(res)}")
return res, total_hits_count
def get_forgotten_messages(self, select_fields: list[str], index_name: str, memory_id: str, limit: int=2000):
condition = {"memory_id": memory_id, "exists": "forget_at_flt"}
order_by = OrderByExpr()
order_by.asc("forget_at_flt")
# query
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{index_name}_{memory_id}"
table_instance = db_instance.get_table(table_name)
output_fields = [self.convert_message_field_to_infinity(f) for f in select_fields]
builder = table_instance.output(output_fields)
filter_cond = self.equivalent_condition_to_str(condition, db_instance.get_table(table_name))
builder.filter(filter_cond)
order_by_expr_list = list()
if order_by.fields:
for order_field in order_by.fields:
order_field_name = self.convert_condition_and_order_field(order_field[0])
if order_field[1] == 0:
order_by_expr_list.append((order_field_name, SortType.Asc))
else:
order_by_expr_list.append((order_field_name, SortType.Desc))
builder.sort(order_by_expr_list)
builder.offset(0).limit(limit)
mem_res, _ = builder.option({"total_hits_count": True}).to_df()
res = self.concat_dataframes(mem_res, output_fields)
res.head(limit)
self.connPool.release_conn(inf_conn)
return res
def get(self, message_id: str, index_name: str, memory_ids: list[str]) -> dict | None:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
df_list = list()
assert isinstance(memory_ids, list)
table_list = list()
for memoryId in memory_ids:
table_name = f"{index_name}_{memoryId}"
table_list.append(table_name)
try:
table_instance = db_instance.get_table(table_name)
except Exception:
self.logger.warning(f"Table not found: {table_name}, this memory isn't created in Infinity. Maybe it is created in other document engine.")
continue
mem_res, _ = table_instance.output(["*"]).filter(f"id = '{message_id}'").to_df()
self.logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(mem_res)}")
df_list.append(mem_res)
self.connPool.release_conn(inf_conn)
res = self.concat_dataframes(df_list, ["id"])
fields = set(res.columns.tolist())
res_fields = self.get_fields(res, list(fields))
return res_fields.get(message_id, None)
def insert(self, documents: list[dict], index_name: str, memory_id: str = None) -> list[str]:
if not documents:
return []
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{index_name}_{memory_id}"
vector_size = int(len(documents[0]["content_embed"]))
try:
table_instance = db_instance.get_table(table_name)
except InfinityException as e:
# src/common/status.cppm, kTableNotExist = 3022
if e.error_code != ErrorCode.TABLE_NOT_EXIST:
raise
if vector_size == 0:
raise ValueError("Cannot infer vector size from documents")
self.create_idx(index_name, memory_id, vector_size)
table_instance = db_instance.get_table(table_name)
# embedding fields can't have a default value....
embedding_columns = []
table_columns = table_instance.show_columns().rows()
for n, ty, _, _ in table_columns:
r = re.search(r"Embedding\([a-z]+,([0-9]+)\)", ty)
if not r:
continue
embedding_columns.append((n, int(r.group(1))))
docs = copy.deepcopy(documents)
for d in docs:
assert "_id" not in d
assert "id" in d
for k, v in list(d.items()):
field_name = self.convert_message_field_to_infinity(k)
if field_name in ["valid_at", "invalid_at", "forget_at"]:
d[f"{field_name}_flt"] = date_string_to_timestamp(v) if v else 0
if v is None:
d[field_name] = ""
elif self.field_keyword(k):
if isinstance(v, list):
d[k] = "###".join(v)
else:
d[k] = v
elif k == "memory_id":
if isinstance(d[k], list):
d[k] = d[k][0] # since d[k] is a list, but we need a str
elif field_name == "content_embed":
d[f"q_{vector_size}_vec"] = d["content_embed"]
d.pop("content_embed")
else:
d[field_name] = v
if k != field_name:
d.pop(k)
for n, vs in embedding_columns:
if n in d:
continue
d[n] = [0] * vs
ids = ["'{}'".format(d["id"]) for d in docs]
str_ids = ", ".join(ids)
str_filter = f"id IN ({str_ids})"
table_instance.delete(str_filter)
table_instance.insert(docs)
self.connPool.release_conn(inf_conn)
self.logger.debug(f"INFINITY inserted into {table_name} {str_ids}.")
return []
def update(self, condition: dict, new_value: dict, index_name: str, memory_id: str) -> bool:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{index_name}_{memory_id}"
table_instance = db_instance.get_table(table_name)
columns = {}
if table_instance:
for n, ty, de, _ in table_instance.show_columns().rows():
columns[n] = (ty, de)
condition_dict = {self.convert_condition_and_order_field(k): v for k, v in condition.items()}
filter = self.equivalent_condition_to_str(condition_dict, table_instance)
update_dict = {self.convert_message_field_to_infinity(k): v for k, v in new_value.items()}
date_floats = {}
for k, v in update_dict.items():
if k in ["valid_at", "invalid_at", "forget_at"]:
date_floats[f"{k}_flt"] = date_string_to_timestamp(v) if v else 0
elif self.field_keyword(k):
if isinstance(v, list):
update_dict[k] = "###".join(v)
else:
update_dict[k] = v
elif k == "memory_id":
if isinstance(update_dict[k], list):
update_dict[k] = update_dict[k][0] # since d[k] is a list, but we need a str
else:
update_dict[k] = v
if date_floats:
update_dict.update(date_floats)
self.logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {new_value}.")
table_instance.update(filter, update_dict)
self.connPool.release_conn(inf_conn)
return True
"""
Helper functions for search result
"""
def get_fields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
if isinstance(res, tuple):
res = res[0]
if not fields:
return {}
fields_all = fields.copy()
fields_all.append("id")
fields_all = {self.convert_message_field_to_infinity(f) for f in fields_all}
column_map = {col.lower(): col for col in res.columns}
matched_columns = {column_map[col.lower()]: col for col in fields_all if col.lower() in column_map}
none_columns = [col for col in fields_all if col.lower() not in column_map]
res2 = res[matched_columns.keys()]
res2 = res2.rename(columns=matched_columns)
res2.drop_duplicates(subset=["id"], inplace=True)
for column in list(res2.columns):
k = column.lower()
if self.field_keyword(k):
res2[column] = res2[column].apply(lambda v: [kwd for kwd in v.split("###") if kwd])
else:
pass
for column in ["content"]:
if column in res2:
del res2[column]
for column in none_columns:
res2[column] = None
res_dict = res2.set_index("id").to_dict(orient="index")
return {_id: {self.convert_infinity_field_to_message(k): v for k, v in doc.items()} for _id, doc in res_dict.items()}

37
memory/utils/msg_util.py Normal file
View File

@ -0,0 +1,37 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
def get_json_result_from_llm_response(response_str: str) -> dict:
"""
Parse the LLM response string to extract JSON content.
The function looks for the first and last curly braces to identify the JSON part.
If parsing fails, it returns an empty dictionary.
:param response_str: The response string from the LLM.
:return: A dictionary parsed from the JSON content in the response.
"""
try:
clean_str = response_str.strip()
if clean_str.startswith('```json'):
clean_str = clean_str[7:] # Remove the starting ```json
if clean_str.endswith('```'):
clean_str = clean_str[:-3] # Remove the ending ```
return json.loads(clean_str.strip())
except (ValueError, json.JSONDecodeError):
return {}

201
memory/utils/prompt_util.py Normal file
View File

@ -0,0 +1,201 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, List
from common.constants import MemoryType
from common.time_utils import current_timestamp
class PromptAssembler:
SYSTEM_BASE_TEMPLATE = """**Memory Extraction Specialist**
You are an expert at analyzing conversations to extract structured memory.
{type_specific_instructions}
**OUTPUT REQUIREMENTS:**
1. Output MUST be valid JSON
2. Follow the specified output format exactly
3. Each extracted item MUST have: content, valid_at, invalid_at
4. Timestamps in {timestamp_format} format
5. Only extract memory types specified above
6. Maximum {max_items} items per type
"""
TYPE_INSTRUCTIONS = {
MemoryType.SEMANTIC.name.lower(): """
**EXTRACT SEMANTIC KNOWLEDGE:**
- Universal facts, definitions, concepts, relationships
- Time-invariant, generally true information
- Examples: "The capital of France is Paris", "Water boils at 100°C"
**Timestamp Rules for Semantic Knowledge:**
- valid_at: When the fact became true (e.g., law enactment, discovery)
- invalid_at: When it becomes false (e.g., repeal, disproven) or empty if still true
- Default: valid_at = conversation time, invalid_at = "" for timeless facts
""",
MemoryType.EPISODIC.name.lower(): """
**EXTRACT EPISODIC KNOWLEDGE:**
- Specific experiences, events, personal stories
- Time-bound, person-specific, contextual
- Examples: "Yesterday I fixed the bug", "User reported issue last week"
**Timestamp Rules for Episodic Knowledge:**
- valid_at: Event start/occurrence time
- invalid_at: Event end time or empty if instantaneous
- Extract explicit times: "at 3 PM", "last Monday", "from X to Y"
""",
MemoryType.PROCEDURAL.name.lower(): """
**EXTRACT PROCEDURAL KNOWLEDGE:**
- Processes, methods, step-by-step instructions
- Goal-oriented, actionable, often includes conditions
- Examples: "To reset password, click...", "Debugging steps: 1)..."
**Timestamp Rules for Procedural Knowledge:**
- valid_at: When procedure becomes valid/effective
- invalid_at: When it expires/becomes obsolete or empty if current
- For version-specific: use release dates
- For best practices: invalid_at = ""
"""
}
OUTPUT_TEMPLATES = {
MemoryType.SEMANTIC.name.lower(): """
"semantic": [
{
"content": "Clear factual statement",
"valid_at": "timestamp or empty",
"invalid_at": "timestamp or empty"
}
]
""",
MemoryType.EPISODIC.name.lower(): """
"episodic": [
{
"content": "Narrative event description",
"valid_at": "event start timestamp",
"invalid_at": "event end timestamp or empty"
}
]
""",
MemoryType.PROCEDURAL.name.lower(): """
"procedural": [
{
"content": "Actionable instructions",
"valid_at": "procedure effective timestamp",
"invalid_at": "procedure expiration timestamp or empty"
}
]
"""
}
BASE_USER_PROMPT = """
**CONVERSATION:**
{conversation}
**CONVERSATION TIME:** {conversation_time}
**CURRENT TIME:** {current_time}
"""
@classmethod
def assemble_system_prompt(cls, config: dict) -> str:
types_to_extract = cls._get_types_to_extract(config["memory_type"])
type_instructions = cls._generate_type_instructions(types_to_extract)
output_format = cls._generate_output_format(types_to_extract)
full_prompt = cls.SYSTEM_BASE_TEMPLATE.format(
type_specific_instructions=type_instructions,
timestamp_format=config.get("timestamp_format", "ISO 8601"),
max_items=config.get("max_items_per_type", 5)
)
full_prompt += f"\n**REQUIRED OUTPUT FORMAT (JSON):**\n```json\n{{\n{output_format}\n}}\n```\n"
examples = cls._generate_examples(types_to_extract)
if examples:
full_prompt += f"\n**EXAMPLES:**\n{examples}\n"
return full_prompt
@staticmethod
def _get_types_to_extract(requested_types: List[str]) -> List[str]:
types = set()
for rt in requested_types:
if rt in [e.name.lower() for e in MemoryType] and rt != MemoryType.RAW.name.lower():
types.add(rt)
return list(types)
@classmethod
def _generate_type_instructions(cls, types_to_extract: List[str]) -> str:
target_types = set(types_to_extract)
instructions = [cls.TYPE_INSTRUCTIONS[mt] for mt in target_types]
return "\n".join(instructions)
@classmethod
def _generate_output_format(cls, types_to_extract: List[str]) -> str:
target_types = set(types_to_extract)
output_parts = [cls.OUTPUT_TEMPLATES[mt] for mt in target_types]
return ",\n".join(output_parts)
@staticmethod
def _generate_examples(types_to_extract: list[str]) -> str:
examples = []
if MemoryType.SEMANTIC.name.lower() in types_to_extract:
examples.append("""
**Semantic Example:**
Input: "Python lists are mutable and support various operations."
Output: {"semantic": [{"content": "Python lists are mutable data structures", "valid_at": "2024-01-15T10:00:00", "invalid_at": ""}]}
""")
if MemoryType.EPISODIC.name.lower() in types_to_extract:
examples.append("""
**Episodic Example:**
Input: "I deployed the new feature yesterday afternoon."
Output: {"episodic": [{"content": "User deployed new feature", "valid_at": "2024-01-14T14:00:00", "invalid_at": "2024-01-14T18:00:00"}]}
""")
if MemoryType.PROCEDURAL.name.lower() in types_to_extract:
examples.append("""
**Procedural Example:**
Input: "To debug API errors: 1) Check logs 2) Verify endpoints 3) Test connectivity."
Output: {"procedural": [{"content": "API error debugging: 1. Check logs 2. Verify endpoints 3. Test connectivity", "valid_at": "2024-01-15T10:00:00", "invalid_at": ""}]}
""")
return "\n".join(examples)
@classmethod
def assemble_user_prompt(
cls,
conversation: str,
conversation_time: Optional[str] = None,
current_time: Optional[str] = None
) -> str:
return cls.BASE_USER_PROMPT.format(
conversation=conversation,
conversation_time=conversation_time or "Not specified",
current_time=current_time or current_timestamp(),
)
@classmethod
def get_raw_user_prompt(cls):
return cls.BASE_USER_PROMPT

View File

@ -46,7 +46,7 @@ dependencies = [
"groq==0.9.0",
"grpcio-status==1.67.1",
"html-text==0.6.2",
"infinity-sdk==0.6.11",
"infinity-sdk==0.6.13",
"infinity-emb>=0.0.66,<0.0.67",
"jira==3.10.5",
"json-repair==0.35.0",

View File

@ -91,7 +91,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
filename, binary=binary, from_page=from_page, to_page=to_page)
remove_contents_table(sections, eng=is_english(
random_choices([t for t, _ in sections], k=200)))
tbls=vision_figure_parser_docx_wrapper(sections=sections,tbls=tbls,callback=callback,**kwargs)
tbls = vision_figure_parser_docx_wrapper(sections=sections,tbls=tbls,callback=callback,**kwargs)
# tbls = [((None, lns), None) for lns in tbls]
sections=[(item[0],item[1] if item[1] is not None else "") for item in sections if not isinstance(item[1], Image.Image)]
callback(0.8, "Finish parsing.")
@ -147,9 +147,16 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
elif re.search(r"\.doc$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
with BytesIO(binary) as binary:
binary = BytesIO(binary)
doc_parsed = parser.from_buffer(binary)
try:
from tika import parser as tika_parser
except Exception as e:
callback(0.8, f"tika not available: {e}. Unsupported .doc parsing.")
logging.warning(f"tika not available: {e}. Unsupported .doc parsing for {filename}.")
return []
binary = BytesIO(binary)
doc_parsed = tika_parser.from_buffer(binary)
if doc_parsed.get('content', None) is not None:
sections = doc_parsed['content'].split('\n')
sections = [(line, "") for line in sections if line]
remove_contents_table(sections, eng=is_english(

View File

@ -312,7 +312,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
tk_cnt = num_tokens_from_string(txt)
if sec_id > -1:
last_sid = sec_id
tbls=vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs)
tbls = vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs)
res = tokenize_table(tbls, doc, eng)
res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
table_ctx = max(0, int(parser_config.get("table_context_size", 0) or 0))
@ -325,7 +325,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
docx_parser = Docx()
ti_list, tbls = docx_parser(filename, binary,
from_page=0, to_page=10000, callback=callback)
tbls=vision_figure_parser_docx_wrapper(sections=ti_list,tbls=tbls,callback=callback,**kwargs)
tbls = vision_figure_parser_docx_wrapper(sections=ti_list,tbls=tbls,callback=callback,**kwargs)
res = tokenize_table(tbls, doc, eng)
for text, image in ti_list:
d = copy.deepcopy(doc)

View File

@ -651,7 +651,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
"parser_config", {
"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC", "analyze_hyperlink": True})
child_deli = parser_config.get("children_delimiter", "").encode('utf-8').decode('unicode_escape').encode('latin1').decode('utf-8')
child_deli = (parser_config.get("children_delimiter") or "").encode('utf-8').decode('unicode_escape').encode('latin1').decode('utf-8')
cust_child_deli = re.findall(r"`([^`]+)`", child_deli)
child_deli = "|".join(re.sub(r"`([^`]+)`", "", child_deli))
if cust_child_deli:

Some files were not shown because too many files have changed in this diff Show More