Compare commits

...

57 Commits

Author SHA1 Message Date
427e0540ca Fix: table tag on chunks. (#12126)
- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-25 11:23:51 +08:00
8f16fac898 Fix: Add a no-data filter condition to MetaData (#12189)
### What problem does this PR solve?

Fix: Add a no-data filter condition to MetaData

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-25 10:42:34 +08:00
ea00478a21 Bump infinity to 0.6.13 (#12181)
### What problem does this PR solve?

Bump infinity to 0.6.13

### Type of change

- [x] Refactoring
2025-12-24 22:33:56 +08:00
906f19e863 Dragging down a downstream node of a Switch operator will cause the end_cpn_ids to contain the ID of the placeholder operator. #12177 (#12178)
### What problem does this PR solve?

Dragging down a downstream node of a Switch operator will cause the
end_cpn_ids to contain the ID of the placeholder operator. #12177

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 19:45:35 +08:00
667dc5467e Fix: Fixed the issue of incorrect agent translation text. #10427 (#12172)
### What problem does this PR solve?

Fix the incorrect agent translation text. #10427

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 19:05:26 +08:00
977962fdfe Fix: loopitem None issue. (#12166)
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 17:22:31 +08:00
1e9374a373 Fix:Metadata saving, copywriting and other related issues (#12169)
### What problem does this PR solve?

Bugs fixed:
- Text overflow issues that caused rendering problems
- Metadata saving, copywriting, and other related issues

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 17:21:36 +08:00
9b52ba8061 Feat: add image table context to pipeline splitter (#12167)
### What problem does this PR solve?

Add image table context to pipeline splitter.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-24 16:58:14 +08:00
44671ea413 Fix: type check for chunks (#12164)
### What problem does this PR solve?

Fix: type check for chunks

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 16:36:00 +08:00
c81421d340 Feat: add document metadata setting (#12156)
### What problem does this PR solve?

Add document metadata setting.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: Jin Hai <haijin.chn@gmail.com>
2025-12-24 16:13:50 +08:00
ee93a80e91 Feat: add MiniMax M2.1 (#12148)
### What problem does this PR solve?

Add MiniMax M2.1.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: Jin Hai <haijin.chn@gmail.com>
2025-12-24 16:08:09 +08:00
5c981978c1 Run infinity test before ES (#12159)
### What problem does this PR solve?

As title

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-24 15:52:55 +08:00
7fef285af5 Revert "Bump infinity to 0.6.12 (#12140)" (#12161)
This reverts commit 0588fe79b9.
2025-12-24 15:24:12 +08:00
b1efb905e5 Fix: metadata_obj issue. (#12146)
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 13:40:34 +08:00
6400bf87ba Fix: LLM tool does not exist in multiple retrieval case (#12143)
### What problem does this PR solve?

Fix the "LLM tool does not exist" error in the multiple-retrieval case.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 13:26:48 +08:00
f239bc02d3 Feat: Support Markdown Rendering for tips in user-fill-up Component #11825 (#12147)
### What problem does this PR solve?

Feat: Support Markdown Rendering for tips in user-fill-up Component
#11825

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-12-24 13:25:56 +08:00
5776fa73a7 refactor: improve memory service date time consistency (#12144)
### What problem does this PR solve?

Improve date/time consistency in the memory service.

### Type of change

- [x] Refactoring
2025-12-24 11:00:31 +08:00
fc6af1998b Doc: Added an HTTP request component reference (#12141)
### Type of change

- [x] Documentation Update
2025-12-24 09:35:32 +08:00
0588fe79b9 Bump infinity to 0.6.12 (#12140)
### What problem does this PR solve?

As title

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-24 09:34:54 +08:00
f545265f93 Fix:remove duplicate tool_meta (#12139)
### What problem does this PR solve?
PR: #12117
Change: remove duplicate tool_meta.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 09:34:08 +08:00
c987d33649 Feat: deduplicate metadata lists during updates (#12125)
### What problem does this PR solve?

Deduplicate metadata lists during updates.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-24 09:32:55 +08:00
d72debf0db Fix: Add prompts when merging or deleting metadata. (#12138)
### What problem does this PR solve?

Fix: Add prompts when merging or deleting metadata.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
2025-12-24 09:32:41 +08:00
c33134ea2c Fix: table tag on chunks. (#12126)
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 09:32:19 +08:00
17b8bb62b6 Feat: message manage (#12083)
### What problem does this PR solve?

Message CRUD.

Issue #4213 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-23 21:16:25 +08:00
bab6a4a219 Fix: /kb/update does not update FileService (#12121)
### What problem does this PR solve?

Fix: /kb/update does not update FileService

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-23 19:56:38 +08:00
6c93157b14 Refa: image table context window (#12132)
### What problem does this PR solve?

Image table context window

### Type of change

- [x] Refactoring
2025-12-23 19:51:01 +08:00
033029eaa1 Fix: The form waiting for input is not displayed in the dialog message. #12129 (#12130)
### What problem does this PR solve?
Fix: The form waiting for input is not displayed in the dialog message.
#12129

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-23 17:59:55 +08:00
a958ddb27a refactor: reword locale translations (#12118)
### What problem does this PR solve?

Reword (in locales/en) "Image context window" to "Image & table context
window", etc.

### Type of change

- [x] Refactoring
2025-12-23 17:34:21 +08:00
f63f007326 fix: add null safety checks in webhook response status hook (#12114)
### What problem does this PR solve?

Add optional chaining operators to prevent runtime errors when formData
is undefined or null in useShowWebhookResponseStatus hook.

This fixes a potential crash when accessing mode and execution_mode
properties before formData is initialized or when the Begin node doesn't
exist in the graph.

🤖 Generated with [Claude Code](https://claude.com/claude-code)


### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Co-authored-by: Claude <noreply@anthropic.com>
2025-12-23 16:16:30 +08:00
b47f1afa35 fix: transformer toc prompt text incorrect (#12116)
### What problem does this PR solve?

Fix incorrect prompt texts in **Agent** canvas > **Transformer** >
**Result destination: Table of contents**

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-23 15:59:09 +08:00
2369be7244 Refactor: enhance next_step prompt (#12117)
### What problem does this PR solve?

Change: enhance the next_step prompt.

### Type of change

- [x] Refactoring
2025-12-23 15:57:55 +08:00
00bb6fbd28 Fix: metadata issue & graphrag speeding up. (#12113)
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Liu An <asiro@qq.com>
2025-12-23 15:57:27 +08:00
063b06494a redirect stderr to stdout (#12122)
### What problem does this PR solve?

Update workflows

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-23 15:57:21 +08:00
b824185a3a Feat: Translate the text of the webhook debugging interface. #10427 (#12115)
### What problem does this PR solve?

Feat: Translate the text of the webhook debugging interface. #10427

### Type of change


- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: balibabu <assassin_cike@163.com>
2025-12-23 15:25:38 +08:00
8e6ddd7c1b Fix: Metadata bugs. (#12111)
### What problem does this PR solve?

Fix: Metadata bugs.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
2025-12-23 14:16:57 +08:00
d1bc7ad2ee Fix only one of multiple retrieval tools is effective (#12110)
### What problem does this PR solve?

Fix the issue where only one of multiple retrieval tools is effective.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-23 14:08:25 +08:00
321474fb97 Fix: update method call to use simplified async tool reaction (#12108)
### What problem does this PR solve?
PR: #12091
Change: update the method call to use the simplified async tool reaction.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-23 13:36:58 +08:00
ea89e4e0c6 Feat: add GLM-4.7 (#12102)
### What problem does this PR solve?

 Add GLM-4.7.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-23 12:38:56 +08:00
9e31631d8f Feat: Add memory multi-select dropdown to recall and message operator forms. #4213 (#12106)
### What problem does this PR solve?

Feat: Add memory multi-select dropdown to recall and message operator
forms. #4213

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-12-23 11:54:32 +08:00
712d537d66 Fix: vision_figure_parser_docx/pdf_wrapper (#12104)
### What problem does this PR solve?

Fix: vision_figure_parser_docx/pdf_wrapper  #11735

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-23 11:51:28 +08:00
bd4eb19393 Fix:Bugs fix (Reduce metadata saving steps ...) (#12095)
### What problem does this PR solve?

Bugs fixed:
- Configure memory and metadata (in Chinese)
- Add an indexing modal
- Reduce metadata saving steps

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
2025-12-23 11:50:35 +08:00
02efab7c11 Feat: Hide part of the message field in webhook mode #10427 (#12100)
### What problem does this PR solve?

Feat: Hide part of the message field in webhook mode  #10427

### Type of change


- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: balibabu <assassin_cike@163.com>
2025-12-23 10:45:05 +08:00
8ce129bc51 Update workflow (#12101)
### What problem does this PR solve?

As title

### Type of change

- [x] Other (please describe): Update GitHub action

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-23 10:03:24 +08:00
d5a44e913d Fix: fix task cancel (#12093)
### What problem does this PR solve?

Fix: fix task cancel

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-23 09:38:25 +08:00
1444de981c Feat: enhance webhook response to include status and success fields and simplify ReAct agent (#12091)
### What problem does this PR solve?

Change: enhance the webhook response to include status and success fields, and simplify the ReAct agent.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-23 09:36:08 +08:00
bd76b8ff1a Fix: Tika server upgrades. (#12073)
### What problem does this PR solve?

#12037

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-23 09:35:52 +08:00
a95f22fa88 Feat: output intinity test log (#12097)
### What problem does this PR solve?

Output logs to a file when running Infinity tests.

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-12-22 21:33:08 +08:00
38ac6a7c27 feat: add image context window in dataset config (#12094)
### What problem does this PR solve?

Add image context window configuration in **Dataset** > **Configuration** and **Dataset** > **Files** > **Parse** > **Ingestion Pipeline** (the **Chunk Method** modal)

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-22 19:51:23 +08:00
e5f3d5ae26 Refactor add_llm and add speech to text (#12089)
### What problem does this PR solve?

1. Refactor implementation of add_llm
2. Add speech to text model.

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-22 19:27:26 +08:00
4cbc91f2fa Feat: optimize aws s3 connector (#12078)
### What problem does this PR solve?

Feat: optimize aws s3 connector #12008 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
2025-12-22 19:06:01 +08:00
6d3d3a40ab fix: hide drop-zone upload button when picked an image (#12088)
### What problem does this PR solve?

Hide the drop-zone upload button when an image is picked in the chunk editor dialog.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-22 19:04:44 +08:00
51b12841d6 Feature/1217 (#12087)
### What problem does this PR solve?

feature: Complete metadata functionality

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-22 17:35:12 +08:00
993bf7c2c8 Fix IDE warnings (#12085)
### What problem does this PR solve?

As title

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-22 16:47:21 +08:00
b42b5fcf65 feat: display chunk type in chunk editor and dialog (#12086)
### What problem does this PR solve?

Display the chunk type in the chunk editor and dialog; it may be one of the following:
- Image
- Table
- Text

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-22 16:45:47 +08:00
5d391fb1f9 fix: guard Dashscope response attribute access in token/log utils (#12082)
### What problem does this PR solve?

Guard Dashscope response attribute access in token/log utils, since
`dashscope_response` returns a dict-like object.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-22 16:17:58 +08:00
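A minimal sketch of the guard described in #12082, assuming the response may expose usage either as attributes or as dict keys; the helper and field names below are illustrative, not the actual Dashscope schema:

```python
# Sketch: tolerate both attribute-style and dict-like response objects.
def safe_get(resp, name, default=None):
    """Return resp.<name> if present, else resp[name] for dict-like objects, else default."""
    value = getattr(resp, name, None)
    if value is not None:
        return value
    if isinstance(resp, dict):
        return resp.get(name, default)
    getter = getattr(resp, "get", None)
    return getter(name, default) if callable(getter) else default


def total_tokens(resp) -> int:
    # Hypothetical field names for illustration only.
    usage = safe_get(resp, "usage", {}) or {}
    return int(safe_get(usage, "total_tokens", 0) or 0)
```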
2ddfcc7cf6 Images that appear consecutively in the dialogue are displayed using a carousel. #12076 (#12077)
### What problem does this PR solve?

Images that appear consecutively in the dialogue are displayed using a
carousel. #12076

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-12-22 14:41:02 +08:00
5ba51b21c9 Feat: When the webhook returns a field in streaming format, the message displays the status field. #10427 (#12075)
### What problem does this PR solve?

Feat: When the webhook returns a field in streaming format, the message
displays the status field. #10427

### Type of change


- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: balibabu <assassin_cike@163.com>
2025-12-22 14:37:39 +08:00
180 changed files with 6845 additions and 2423 deletions

View File

@ -197,37 +197,38 @@ jobs:
echo -e "COMPOSE_PROFILES=\${COMPOSE_PROFILES},tei-cpu" >> docker/.env
echo -e "TEI_MODEL=BAAI/bge-small-en-v1.5" >> docker/.env
echo -e "RAGFLOW_IMAGE=${RAGFLOW_IMAGE}" >> docker/.env
sed -i '1i DOC_ENGINE=infinity' docker/.env
echo "HOST_ADDRESS=http://host.docker.internal:${SVR_HTTP_PORT}" >> ${GITHUB_ENV}
sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} up -d
uv sync --python 3.12 --only-group test --no-default-groups --frozen && uv pip install sdk/python --group test
- name: Run sdk tests against Elasticsearch
- name: Run sdk tests against Infinity
run: |
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
echo "Waiting for service to be available..."
sleep 5
done
source .venv/bin/activate && pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api
source .venv/bin/activate && DOC_ENGINE=infinity pytest -x -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api 2>&1 | tee infinity_sdk_test.log
- name: Run frontend api tests against Elasticsearch
- name: Run frontend api tests against Infinity
run: |
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
echo "Waiting for service to be available..."
sleep 5
done
source .venv/bin/activate && pytest -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py
- name: Run http api tests against Elasticsearch
source .venv/bin/activate && DOC_ENGINE=infinity pytest -x -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py 2>&1 | tee infinity_api_test.log
- name: Run http api tests against Infinity
run: |
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
echo "Waiting for service to be available..."
sleep 5
done
source .venv/bin/activate && pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api
source .venv/bin/activate && DOC_ENGINE=infinity pytest -x -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee infinity_http_api_test.log
- name: Stop ragflow:nightly
if: always() # always run this step even if previous steps failed
@ -237,35 +238,35 @@ jobs:
- name: Start ragflow:nightly
run: |
sed -i '1i DOC_ENGINE=infinity' docker/.env
sed -i '1i DOC_ENGINE=elasticsearch' docker/.env
sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} up -d
- name: Run sdk tests against Infinity
- name: Run sdk tests against Elasticsearch
run: |
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
echo "Waiting for service to be available..."
sleep 5
done
source .venv/bin/activate && DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api
source .venv/bin/activate && DOC_ENGINE=elasticsearch pytest -x -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api 2>&1 | tee es_sdk_test.log
- name: Run frontend api tests against Infinity
- name: Run frontend api tests against Elasticsearch
run: |
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
echo "Waiting for service to be available..."
sleep 5
done
source .venv/bin/activate && DOC_ENGINE=infinity pytest -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py
source .venv/bin/activate && DOC_ENGINE=elasticsearch pytest -x -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py 2>&1 | tee es_api_test.log
- name: Run http api tests against Infinity
- name: Run http api tests against Elasticsearch
run: |
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
echo "Waiting for service to be available..."
sleep 5
done
source .venv/bin/activate && DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api
source .venv/bin/activate && DOC_ENGINE=elasticsearch pytest -x -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee es_http_api_test.log
- name: Stop ragflow:nightly
if: always() # always run this step even if previous steps failed
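The reordered steps above run each suite twice, once with `DOC_ENGINE=infinity` and once with `DOC_ENGINE=elasticsearch`, teeing stdout and stderr into per-engine log files. A hedged sketch of how the test suite might consume that switch; the fixture is hypothetical, not taken from the actual conftest:

```python
# Hypothetical conftest-style fixture: pick the document engine from the
# DOC_ENGINE environment variable that each workflow step exports.
import os

import pytest


@pytest.fixture(scope="session")
def doc_engine() -> str:
    engine = os.environ.get("DOC_ENGINE", "elasticsearch").lower()
    assert engine in {"elasticsearch", "infinity"}, f"Unsupported DOC_ENGINE: {engine}"
    return engine
```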

View File

@ -192,6 +192,7 @@ COPY pyproject.toml uv.lock ./
COPY mcp mcp
COPY plugin plugin
COPY common common
COPY memory memory
COPY docker/service_conf.yaml.template ./conf/service_conf.yaml.template
COPY docker/entrypoint.sh ./

View File

@ -29,6 +29,11 @@ from common.versions import get_ragflow_version
admin_bp = Blueprint('admin', __name__, url_prefix='/api/v1/admin')
@admin_bp.route('/ping', methods=['GET'])
def ping():
return success_response('PONG')
@admin_bp.route('/login', methods=['POST'])
def login():
if not request.json:
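A small client-side sketch of probing the new admin health-check route shown above; the host and port are assumptions, while the `/api/v1/admin/ping` path and the `PONG` payload follow the blueprint in the diff:

```python
# Hypothetical probe of the new admin ping route.
import requests

BASE_URL = "http://localhost:9380"  # assumed RAGFlow API address

resp = requests.get(f"{BASE_URL}/api/v1/admin/ping", timeout=5)
resp.raise_for_status()
print(resp.json())  # expected to carry "PONG" per success_response('PONG')
```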

View File

@ -540,6 +540,8 @@ class Canvas(Graph):
cite = re.search(r"\[ID:[ 0-9]+\]", cpn_obj.output("content"))
message_end = {}
if cpn_obj.get_param("status"):
message_end["status"] = cpn_obj.get_param("status")
if isinstance(cpn_obj.output("attachment"), dict):
message_end["attachment"] = cpn_obj.output("attachment")
if cite:

View File

@ -29,8 +29,8 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.mcp_server_service import MCPServerService
from common.connection_utils import timeout
from rag.prompts.generator import next_step_async, COMPLETE_TASK, analyze_task_async, \
citation_prompt, reflect_async, kb_prompt, citation_plus, full_question, message_fit_in, structured_output_prompt
from rag.prompts.generator import next_step_async, COMPLETE_TASK, \
citation_prompt, kb_prompt, citation_plus, full_question, message_fit_in, structured_output_prompt
from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
from agent.component.llm import LLMParam, LLM
@ -84,9 +84,11 @@ class Agent(LLM, ToolBase):
def __init__(self, canvas, id, param: LLMParam):
LLM.__init__(self, canvas, id, param)
self.tools = {}
for cpn in self._param.tools:
for idx, cpn in enumerate(self._param.tools):
cpn = self._load_tool_obj(cpn)
self.tools[cpn.get_meta()["function"]["name"]] = cpn
original_name = cpn.get_meta()["function"]["name"]
indexed_name = f"{original_name}_{idx}"
self.tools[indexed_name] = cpn
self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id), self._param.llm_id,
max_retries=self._param.max_retries,
@ -94,7 +96,12 @@ class Agent(LLM, ToolBase):
max_rounds=self._param.max_rounds,
verbose_tool_use=True
)
self.tool_meta = [v.get_meta() for _,v in self.tools.items()]
self.tool_meta = []
for indexed_name, tool_obj in self.tools.items():
original_meta = tool_obj.get_meta()
indexed_meta = deepcopy(original_meta)
indexed_meta["function"]["name"] = indexed_name
self.tool_meta.append(indexed_meta)
for mcp in self._param.mcp:
_, mcp_server = MCPServerService.get_by_id(mcp["mcp_id"])
@ -108,7 +115,8 @@ class Agent(LLM, ToolBase):
def _load_tool_obj(self, cpn: dict) -> object:
from agent.component import component_class
param = component_class(cpn["component_name"] + "Param")()
tool_name = cpn["component_name"]
param = component_class(tool_name + "Param")()
param.update(cpn["params"])
try:
param.check()
@ -202,7 +210,7 @@ class Agent(LLM, ToolBase):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
use_tools = []
ans = ""
async for delta_ans, _tk in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt,schema_prompt=schema_prompt):
async for delta_ans, _tk in self._react_with_tools_streamly_async_simple(prompt, msg, use_tools, user_defined_prompt,schema_prompt=schema_prompt):
if self.check_if_canceled("Agent processing"):
return
ans += delta_ans
@ -246,7 +254,7 @@ class Agent(LLM, ToolBase):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
answer_without_toolcall = ""
use_tools = []
async for delta_ans, _ in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt):
async for delta_ans, _ in self._react_with_tools_streamly_async_simple(prompt, msg, use_tools, user_defined_prompt):
if self.check_if_canceled("Agent streaming"):
return
@ -264,7 +272,7 @@ class Agent(LLM, ToolBase):
if use_tools:
self.set_output("use_tools", use_tools)
async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
async def _react_with_tools_streamly_async_simple(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
token_count = 0
tool_metas = self.tool_meta
hist = deepcopy(history)
@ -276,6 +284,24 @@ class Agent(LLM, ToolBase):
else:
user_request = history[-1]["content"]
def build_task_desc(prompt: str, user_request: str, user_defined_prompt: dict | None = None) -> str:
"""Build a minimal task_desc by concatenating prompt, query, and tool schemas."""
user_defined_prompt = user_defined_prompt or {}
task_desc = (
"### Agent Prompt\n"
f"{prompt}\n\n"
"### User Request\n"
f"{user_request}\n\n"
)
if user_defined_prompt:
udp_json = json.dumps(user_defined_prompt, ensure_ascii=False, indent=2)
task_desc += "\n### User Defined Prompts\n" + udp_json + "\n"
return task_desc
async def use_tool_async(name, args):
nonlocal hist, use_tools, last_calling
logging.info(f"{last_calling=} == {name=}")
@ -286,9 +312,6 @@ class Agent(LLM, ToolBase):
"arguments": args,
"results": tool_response
})
# self.callback("add_memory", {}, "...")
#self.add_memory(hist[-2]["content"], hist[-1]["content"], name, args, str(tool_response), user_defined_prompt)
return name, tool_response
async def complete():
@ -326,6 +349,21 @@ class Agent(LLM, ToolBase):
self.callback("gen_citations", {}, txt, elapsed_time=timer()-st)
def build_observation(tool_call_res: list[tuple]) -> str:
"""
Build a Observation from tool call results.
No LLM involved.
"""
if not tool_call_res:
return ""
lines = ["Observation:"]
for name, result in tool_call_res:
lines.append(f"[{name} result]")
lines.append(str(result))
return "\n".join(lines)
def append_user_content(hist, content):
if hist[-1]["role"] == "user":
hist[-1]["content"] += content
@ -333,7 +371,7 @@ class Agent(LLM, ToolBase):
hist.append({"role": "user", "content": content})
st = timer()
task_desc = await analyze_task_async(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
task_desc = build_task_desc(prompt, user_request, user_defined_prompt)
self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st)
for _ in range(self._param.max_rounds + 1):
if self.check_if_canceled("Agent streaming"):
@ -364,7 +402,7 @@ class Agent(LLM, ToolBase):
results = await asyncio.gather(*tool_tasks) if tool_tasks else []
st = timer()
reflection = await reflect_async(self.chat_mdl, hist, results, user_defined_prompt)
reflection = build_observation(results)
append_user_content(hist, reflection)
self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)
@ -393,6 +431,135 @@ Respond immediately with your final comprehensive answer.
async for txt, tkcnt in complete():
yield txt, tkcnt
# async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
# token_count = 0
# tool_metas = self.tool_meta
# hist = deepcopy(history)
# last_calling = ""
# if len(hist) > 3:
# st = timer()
# user_request = await full_question(messages=history, chat_mdl=self.chat_mdl)
# self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer()-st)
# else:
# user_request = history[-1]["content"]
# async def use_tool_async(name, args):
# nonlocal hist, use_tools, last_calling
# logging.info(f"{last_calling=} == {name=}")
# last_calling = name
# tool_response = await self.toolcall_session.tool_call_async(name, args)
# use_tools.append({
# "name": name,
# "arguments": args,
# "results": tool_response
# })
# # self.callback("add_memory", {}, "...")
# #self.add_memory(hist[-2]["content"], hist[-1]["content"], name, args, str(tool_response), user_defined_prompt)
# return name, tool_response
# async def complete():
# nonlocal hist
# need2cite = self._param.cite and self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0
# if schema_prompt:
# need2cite = False
# cited = False
# if hist and hist[0]["role"] == "system":
# if schema_prompt:
# hist[0]["content"] += "\n" + schema_prompt
# if need2cite and len(hist) < 7:
# hist[0]["content"] += citation_prompt()
# cited = True
# yield "", token_count
# _hist = hist
# if len(hist) > 12:
# _hist = [hist[0], hist[1], *hist[-10:]]
# entire_txt = ""
# async for delta_ans in self._generate_streamly(_hist):
# if not need2cite or cited:
# yield delta_ans, 0
# entire_txt += delta_ans
# if not need2cite or cited:
# return
# st = timer()
# txt = ""
# async for delta_ans in self._gen_citations_async(entire_txt):
# if self.check_if_canceled("Agent streaming"):
# return
# yield delta_ans, 0
# txt += delta_ans
# self.callback("gen_citations", {}, txt, elapsed_time=timer()-st)
# def append_user_content(hist, content):
# if hist[-1]["role"] == "user":
# hist[-1]["content"] += content
# else:
# hist.append({"role": "user", "content": content})
# st = timer()
# task_desc = await analyze_task_async(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
# self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st)
# for _ in range(self._param.max_rounds + 1):
# if self.check_if_canceled("Agent streaming"):
# return
# response, tk = await next_step_async(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
# # self.callback("next_step", {}, str(response)[:256]+"...")
# token_count += tk or 0
# hist.append({"role": "assistant", "content": response})
# try:
# functions = json_repair.loads(re.sub(r"```.*", "", response))
# if not isinstance(functions, list):
# raise TypeError(f"List should be returned, but `{functions}`")
# for f in functions:
# if not isinstance(f, dict):
# raise TypeError(f"An object type should be returned, but `{f}`")
# tool_tasks = []
# for func in functions:
# name = func["name"]
# args = func["arguments"]
# if name == COMPLETE_TASK:
# append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be as the same as the first user request.\n")
# async for txt, tkcnt in complete():
# yield txt, tkcnt
# return
# tool_tasks.append(asyncio.create_task(use_tool_async(name, args)))
# results = await asyncio.gather(*tool_tasks) if tool_tasks else []
# st = timer()
# reflection = await reflect_async(self.chat_mdl, hist, results, user_defined_prompt)
# append_user_content(hist, reflection)
# self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)
# except Exception as e:
# logging.exception(msg=f"Wrong JSON argument format in LLM ReAct response: {e}")
# e = f"\nTool call error, please correct the input parameter of response format and call it again.\n *** Exception ***\n{e}"
# append_user_content(hist, str(e))
# logging.warning( f"Exceed max rounds: {self._param.max_rounds}")
# final_instruction = f"""
# {user_request}
# IMPORTANT: You have reached the conversation limit. Based on ALL the information and research you have gathered so far, please provide a DIRECT and COMPREHENSIVE final answer to the original request.
# Instructions:
# 1. SYNTHESIZE all information collected during this conversation
# 2. Provide a COMPLETE response using existing data - do not suggest additional research
# 3. Structure your response as a FINAL DELIVERABLE, not a plan
# 4. If information is incomplete, state what you found and provide the best analysis possible with available data
# 5. DO NOT mention conversation limits or suggest further steps
# 6. Focus on delivering VALUE with the information already gathered
# Respond immediately with your final comprehensive answer.
# """
# if self.check_if_canceled("Agent final instruction"):
# return
# append_user_content(hist, final_instruction)
# async for txt, tkcnt in complete():
# yield txt, tkcnt
async def _gen_citations_async(self, text):
retrievals = self._canvas.get_reference()
retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}

View File

@ -56,7 +56,6 @@ class LLMParam(ComponentParamBase):
self.check_nonnegative_number(int(self.max_tokens), "[Agent] Max tokens")
self.check_decimal_float(float(self.top_p), "[Agent] Top P")
self.check_empty(self.llm_id, "[Agent] LLM")
self.check_empty(self.sys_prompt, "[Agent] System prompt")
self.check_empty(self.prompts, "[Agent] User prompt")
def gen_conf(self):

View File

@ -113,6 +113,10 @@ class LoopItem(ComponentBase, ABC):
return len(var) == 0
elif operator == "not empty":
return len(var) > 0
elif var is None:
if operator == "empty":
return True
return False
raise Exception(f"Invalid operator: {operator}")

View File

@ -25,10 +25,12 @@ from api.db.services.document_service import DocumentService
from common.metadata_utils import apply_meta_data_filter
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.memory_service import MemoryService
from api.db.joint_services.memory_message_service import query_message
from common import settings
from common.connection_utils import timeout
from rag.app.tag import label_question
from rag.prompts.generator import cross_languages, kb_prompt
from rag.prompts.generator import cross_languages, kb_prompt, memory_prompt
class RetrievalParam(ToolParamBase):
@ -57,6 +59,7 @@ class RetrievalParam(ToolParamBase):
self.top_n = 8
self.top_k = 1024
self.kb_ids = []
self.memory_ids = []
self.kb_vars = []
self.rerank_id = ""
self.empty_response = ""
@ -81,15 +84,7 @@ class RetrievalParam(ToolParamBase):
class Retrieval(ToolBase, ABC):
component_name = "Retrieval"
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
async def _invoke_async(self, **kwargs):
if self.check_if_canceled("Retrieval processing"):
return
if not kwargs.get("query"):
self.set_output("formalized_content", self._param.empty_response)
return
async def _retrieve_kb(self, query_text: str):
kb_ids: list[str] = []
for id in self._param.kb_ids:
if id.find("@") < 0:
@ -124,12 +119,12 @@ class Retrieval(ToolBase, ABC):
if self._param.rerank_id:
rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)
vars = self.get_input_elements_from_text(kwargs["query"])
vars = {k:o["value"] for k,o in vars.items()}
query = self.string_format(kwargs["query"], vars)
vars = self.get_input_elements_from_text(query_text)
vars = {k: o["value"] for k, o in vars.items()}
query = self.string_format(query_text, vars)
doc_ids=[]
if self._param.meta_data_filter!={}:
doc_ids = []
if self._param.meta_data_filter != {}:
metas = DocumentService.get_meta_by_kbs(kb_ids)
def _resolve_manual_filter(flt: dict) -> dict:
@ -198,18 +193,20 @@ class Retrieval(ToolBase, ABC):
if self._param.toc_enhance:
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT)
cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n)
cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs],
chat_mdl, self._param.top_n)
if self.check_if_canceled("Retrieval processing"):
return
if cks:
kbinfos["chunks"] = cks
kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"], [kb.tenant_id for kb in kbs])
kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"],
[kb.tenant_id for kb in kbs])
if self._param.use_kg:
ck = settings.kg_retriever.retrieval(query,
[kb.tenant_id for kb in kbs],
kb_ids,
embd_mdl,
LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT))
[kb.tenant_id for kb in kbs],
kb_ids,
embd_mdl,
LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT))
if self.check_if_canceled("Retrieval processing"):
return
if ck["content_with_weight"]:
@ -218,7 +215,8 @@ class Retrieval(ToolBase, ABC):
kbinfos = {"chunks": [], "doc_aggs": []}
if self._param.use_kg and kbs:
ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl,
LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
if self.check_if_canceled("Retrieval processing"):
return
if ck["content_with_weight"]:
@ -248,6 +246,54 @@ class Retrieval(ToolBase, ABC):
return form_cnt
async def _retrieve_memory(self, query_text: str):
memory_ids: list[str] = [memory_id for memory_id in self._param.memory_ids]
memory_list = MemoryService.get_by_ids(memory_ids)
if not memory_list:
raise Exception("No memory is selected.")
embd_names = list({memory.embd_id for memory in memory_list})
assert len(embd_names) == 1, "Memory use different embedding models."
vars = self.get_input_elements_from_text(query_text)
vars = {k: o["value"] for k, o in vars.items()}
query = self.string_format(query_text, vars)
# query message
message_list = query_message({"memory_id": memory_ids}, {
"query": query,
"similarity_threshold": self._param.similarity_threshold,
"keywords_similarity_weight": self._param.keywords_similarity_weight,
"top_n": self._param.top_n
})
print(f"found {len(message_list)} messages.")
if not message_list:
self.set_output("formalized_content", self._param.empty_response)
return
formated_content = "\n".join(memory_prompt(message_list, 200000))
# set formalized_content output
self.set_output("formalized_content", formated_content)
print(f"formated_content {formated_content}")
return formated_content
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
async def _invoke_async(self, **kwargs):
if self.check_if_canceled("Retrieval processing"):
return
print(f"debug retrieval, query is {kwargs.get('query')}.", flush=True)
if not kwargs.get("query"):
self.set_output("formalized_content", self._param.empty_response)
return
if self._param.kb_ids:
return await self._retrieve_kb(kwargs["query"])
elif self._param.memory_ids:
return await self._retrieve_memory(kwargs["query"])
else:
self.set_output("formalized_content", self._param.empty_response)
return
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
return asyncio.run(self._invoke_async(**kwargs))
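The refactor above splits retrieval into `_retrieve_kb` and `_retrieve_memory` and dispatches on whichever ID list is populated. A minimal, standalone sketch of that dispatch with the component internals stubbed out; the class and method bodies here are placeholders, not the real implementation:

```python
# Sketch of the dispatch in _invoke_async above: KB IDs win, then memory IDs,
# otherwise fall back to the configured empty response.
import asyncio


class RetrievalSketch:
    def __init__(self, kb_ids=None, memory_ids=None, empty_response=""):
        self.kb_ids = kb_ids or []
        self.memory_ids = memory_ids or []
        self.empty_response = empty_response

    async def _retrieve_kb(self, query: str) -> str:
        return f"kb results for {query!r}"      # stand-in for the real KB search

    async def _retrieve_memory(self, query: str) -> str:
        return f"memory results for {query!r}"  # stand-in for query_message(...)

    async def invoke(self, query: str) -> str:
        if not query:
            return self.empty_response
        if self.kb_ids:
            return await self._retrieve_kb(query)
        if self.memory_ids:
            return await self._retrieve_memory(query)
        return self.empty_response


print(asyncio.run(RetrievalSketch(kb_ids=["kb1"]).invoke("hello")))
```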

View File

@ -108,7 +108,7 @@ def _load_user():
authorization = request.headers.get("Authorization")
g.user = None
if not authorization:
return
return None
try:
access_token = str(jwt.loads(authorization))

View File

@ -192,7 +192,7 @@ async def rerun():
if 0 < doc["progress"] < 1:
return get_data_error_result(message=f"`{doc['name']}` is processing...")
if settings.docStoreConn.indexExist(search.index_name(current_user.id), doc["kb_id"]):
if settings.docStoreConn.index_exist(search.index_name(current_user.id), doc["kb_id"]):
settings.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"])
doc["progress_msg"] = ""
doc["chunk_num"] = 0

View File

@ -76,6 +76,7 @@ async def list_chunk():
"image_id": sres.field[id].get("img_id", ""),
"available_int": int(sres.field[id].get("available_int", 1)),
"positions": sres.field[id].get("position_int", []),
"doc_type_kwd": sres.field[id].get("doc_type_kwd")
}
assert isinstance(d["positions"], list)
assert len(d["positions"]) == 0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5)
@ -176,10 +177,9 @@ async def set():
settings.docStoreConn.update({"id": req["chunk_id"]}, _d, search.index_name(tenant_id), doc.kb_id)
# update image
image_id = req.get("img_id")
bkt, name = image_id.split("-")
image_base64 = req.get("image_base64", None)
if image_base64:
bkt, name = req.get("img_id", "-").split("-")
image_binary = base64.b64decode(image_base64)
settings.STORAGE_IMPL.put(bkt, name, image_binary)
return get_json_result(data=True)
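The chunk `set` change above defers splitting `img_id` until an image payload is actually present, so updates without an image no longer fail on the split. A minimal sketch of that guard; `storage` stands in for `settings.STORAGE_IMPL`:

```python
# Sketch of the guarded image update from the diff above.
import base64


def maybe_store_image(req: dict, storage) -> None:
    """Only parse img_id and write to storage when image_base64 is provided."""
    image_base64 = req.get("image_base64")
    if not image_base64:
        return
    bucket, name = req.get("img_id", "-").split("-")
    storage.put(bucket, name, base64.b64decode(image_base64))
```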

View File

@ -447,6 +447,26 @@ async def metadata_update():
return get_json_result(data={"updated": updated, "matched_docs": len(target_doc_ids)})
@manager.route("/update_metadata_setting", methods=["POST"]) # noqa: F821
@login_required
@validate_request("doc_id", "metadata")
async def update_metadata_setting():
req = await get_request_json()
if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
DocumentService.update_parser_config(doc.id, {"metadata": req["metadata"]})
e, doc = DocumentService.get_by_id(doc.id)
if not e:
return get_data_error_result(message="Document not found!")
return get_json_result(data=doc.to_dict())
@manager.route("/thumbnails", methods=["GET"]) # noqa: F821
# @login_required
def thumbnails():
@ -564,7 +584,7 @@ async def run():
DocumentService.update_by_id(id, info)
if req.get("delete", False):
TaskService.filter_delete([Task.doc_id == id])
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
if str(req["run"]) == TaskStatus.RUNNING.value:
@ -615,7 +635,7 @@ async def rename():
"title_tks": title_tks,
"title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
}
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.update(
{"doc_id": req["doc_id"]},
es_body,
@ -696,7 +716,7 @@ async def change_parser():
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
return None
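A hedged client sketch of the new `/update_metadata_setting` route in the document app; the base URL, route prefix, auth header, and metadata keys are assumptions, only the route name and the `doc_id`/`metadata` fields come from the diff:

```python
# Hypothetical call to the new document metadata-setting endpoint.
import requests

BASE_URL = "http://localhost:9380/v1/document"      # assumed prefix for document_app routes
headers = {"Authorization": "Bearer <your-token>"}  # assumed auth scheme

payload = {
    "doc_id": "<doc-id>",
    "metadata": {"department": "sales", "year": "2025"},  # illustrative metadata keys
}
resp = requests.post(f"{BASE_URL}/update_metadata_setting", json=payload, headers=headers, timeout=10)
print(resp.json())
```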

View File

@ -39,9 +39,9 @@ from api.utils.api_utils import get_json_result
from rag.nlp import search
from api.constants import DATASET_NAME_LIMIT
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.doc_store_conn import OrderByExpr
from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType, PAGERANK_FLD
from common import settings
from common.doc_store.doc_store_base import OrderByExpr
from api.apps import login_required, current_user
@ -97,6 +97,19 @@ async def update():
code=RetCode.OPERATING_ERROR)
e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
# Rename folder in FileService
if e and req["name"].lower() != kb.name.lower():
FileService.filter_update(
[
File.tenant_id == kb.tenant_id,
File.source_type == FileSource.KNOWLEDGEBASE,
File.type == "folder",
File.name == kb.name,
],
{"name": req["name"]},
)
if not e:
return get_data_error_result(
message="Can't find this dataset!")
@ -150,6 +163,21 @@ async def update():
return server_error_response(e)
@manager.route('/update_metadata_setting', methods=['post']) # noqa: F821
@login_required
@validate_request("kb_id", "metadata")
async def update_metadata_setting():
req = await get_request_json()
e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
if not e:
return get_data_error_result(
message="Database error (Knowledgebase rename)!")
kb = kb.to_dict()
kb["parser_config"]["metadata"] = req["metadata"]
KnowledgebaseService.update_by_id(kb["id"], kb)
return get_json_result(data=kb)
@manager.route('/detail', methods=['GET']) # noqa: F821
@login_required
def detail():
@ -245,13 +273,19 @@ async def rm():
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
File2DocumentService.delete_by_document_id(doc.id)
FileService.filter_delete(
[File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kbs[0].name])
[
File.tenant_id == kbs[0].tenant_id,
File.source_type == FileSource.KNOWLEDGEBASE,
File.type == "folder",
File.name == kbs[0].name,
]
)
if not KnowledgebaseService.delete_by_id(req["kb_id"]):
return get_data_error_result(
message="Database error (Knowledgebase removal)!")
for kb in kbs:
settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
settings.docStoreConn.delete_idx(search.index_name(kb.tenant_id), kb.id)
if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
settings.STORAGE_IMPL.remove_bucket(kb.id)
return get_json_result(data=True)
@ -352,7 +386,7 @@ def knowledge_graph(kb_id):
}
obj = {"graph": {}, "mind_map": {}}
if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id):
if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), kb_id):
return get_json_result(data=obj)
sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
if not len(sres.ids):
@ -824,11 +858,11 @@ async def check_embedding():
index_nm = search.index_name(tenant_id)
res0 = docStoreConn.search(
selectFields=[], highlightFields=[],
select_fields=[], highlight_fields=[],
condition={"kb_id": kb_id, "available_int": 1},
matchExprs=[], orderBy=OrderByExpr(),
match_expressions=[], order_by=OrderByExpr(),
offset=0, limit=1,
indexNames=index_nm, knowledgebaseIds=[kb_id]
index_names=index_nm, knowledgebase_ids=[kb_id]
)
total = docStoreConn.get_total(res0)
if total <= 0:
@ -840,14 +874,14 @@ async def check_embedding():
for off in offsets:
res1 = docStoreConn.search(
selectFields=list(base_fields),
highlightFields=[],
select_fields=list(base_fields),
highlight_fields=[],
condition={"kb_id": kb_id, "available_int": 1},
matchExprs=[], orderBy=OrderByExpr(),
match_expressions=[], order_by=OrderByExpr(),
offset=off, limit=1,
indexNames=index_nm, knowledgebaseIds=[kb_id]
index_names=index_nm, knowledgebase_ids=[kb_id]
)
ids = docStoreConn.get_chunk_ids(res1)
ids = docStoreConn.get_doc_ids(res1)
if not ids:
continue
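Besides the new endpoints, kb_app tracks the doc-store API rename to snake_case (`select_fields`, `highlight_fields`, `match_expressions`, `order_by`, `index_names`, `knowledgebase_ids`, and `get_doc_ids`). A sketch of the call shape after the rename, mirroring the `check_embedding` probe above; it assumes it runs inside the RAGFlow codebase, where `OrderByExpr` and a doc-store connection are available:

```python
# Sketch of a one-chunk probe using the renamed snake_case doc-store API.
from common.doc_store.doc_store_base import OrderByExpr  # import path as shown in the diff


def probe_one_chunk(doc_store_conn, index_nm: str, kb_id: str):
    res = doc_store_conn.search(
        select_fields=[], highlight_fields=[],
        condition={"kb_id": kb_id, "available_int": 1},
        match_expressions=[], order_by=OrderByExpr(),
        offset=0, limit=1,
        index_names=index_nm, knowledgebase_ids=[kb_id],
    )
    # get_doc_ids replaces the former get_chunk_ids accessor.
    return doc_store_conn.get_total(res), doc_store_conn.get_doc_ids(res)
```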

View File

@ -25,7 +25,7 @@ from api.utils.api_utils import get_allowed_llm_factories, get_data_error_result
from common.constants import StatusEnum, LLMType
from api.db.db_models import TenantLLM
from rag.utils.base64_image import test_image
from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel, OcrModel
from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel, OcrModel, Seq2txtModel
@manager.route("/factories", methods=["GET"]) # noqa: F821
@ -208,70 +208,83 @@ async def add_llm():
msg = ""
mdl_nm = llm["llm_name"].split("___")[0]
extra = {"provider": factory}
if llm["model_type"] == LLMType.EMBEDDING.value:
assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
mdl = EmbeddingModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
try:
arr, tc = mdl.encode(["Test if the api key is available"])
if len(arr[0]) == 0:
raise Exception("Fail")
except Exception as e:
msg += f"\nFail to access embedding model({mdl_nm})." + str(e)
elif llm["model_type"] == LLMType.CHAT.value:
assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
mdl = ChatModel[factory](
key=llm["api_key"],
model_name=mdl_nm,
base_url=llm["api_base"],
**extra,
)
try:
m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
if not tc and m.find("**ERROR**:") >= 0:
raise Exception(m)
except Exception as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
elif llm["model_type"] == LLMType.RERANK:
assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet."
try:
mdl = RerankModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
arr, tc = mdl.similarity("Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"])
if len(arr) == 0:
raise Exception("Not known.")
except KeyError:
msg += f"{factory} dose not support this model({factory}/{mdl_nm})"
except Exception as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
elif llm["model_type"] == LLMType.IMAGE2TEXT.value:
assert factory in CvModel, f"Image to text model from {factory} is not supported yet."
mdl = CvModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
try:
image_data = test_image
m, tc = mdl.describe(image_data)
if not tc and m.find("**ERROR**:") >= 0:
raise Exception(m)
except Exception as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
elif llm["model_type"] == LLMType.TTS:
assert factory in TTSModel, f"TTS model from {factory} is not supported yet."
mdl = TTSModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
try:
for resp in mdl.tts("Hello~ RAGFlower!"):
pass
except RuntimeError as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
elif llm["model_type"] == LLMType.OCR.value:
assert factory in OcrModel, f"OCR model from {factory} is not supported yet."
try:
mdl = OcrModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm.get("api_base", ""))
ok, reason = mdl.check_available()
if not ok:
raise RuntimeError(reason or "Model not available")
except Exception as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
else:
# TODO: check other type of models
pass
model_type = llm["model_type"]
model_api_key = llm["api_key"]
model_base_url = llm.get("api_base", "")
match model_type:
case LLMType.EMBEDDING.value:
assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
mdl = EmbeddingModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
try:
arr, tc = mdl.encode(["Test if the api key is available"])
if len(arr[0]) == 0:
raise Exception("Fail")
except Exception as e:
msg += f"\nFail to access embedding model({mdl_nm})." + str(e)
case LLMType.CHAT.value:
assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
mdl = ChatModel[factory](
key=model_api_key,
model_name=mdl_nm,
base_url=model_base_url,
**extra,
)
try:
m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}],
{"temperature": 0.9})
if not tc and m.find("**ERROR**:") >= 0:
raise Exception(m)
except Exception as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
case LLMType.RERANK.value:
assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet."
try:
mdl = RerankModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
arr, tc = mdl.similarity("Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"])
if len(arr) == 0:
raise Exception("Not known.")
except KeyError:
msg += f"{factory} dose not support this model({factory}/{mdl_nm})"
except Exception as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
case LLMType.IMAGE2TEXT.value:
assert factory in CvModel, f"Image to text model from {factory} is not supported yet."
mdl = CvModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
try:
image_data = test_image
m, tc = mdl.describe(image_data)
if not tc and m.find("**ERROR**:") >= 0:
raise Exception(m)
except Exception as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
case LLMType.TTS.value:
assert factory in TTSModel, f"TTS model from {factory} is not supported yet."
mdl = TTSModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
try:
for resp in mdl.tts("Hello~ RAGFlower!"):
pass
except RuntimeError as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
case LLMType.OCR.value:
assert factory in OcrModel, f"OCR model from {factory} is not supported yet."
try:
mdl = OcrModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
ok, reason = mdl.check_available()
if not ok:
raise RuntimeError(reason or "Model not available")
except Exception as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
case LLMType.SPEECH2TEXT:
assert factory in Seq2txtModel, f"Speech model from {factory} is not supported yet."
try:
mdl = Seq2txtModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
# TODO: check the availability
except Exception as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
case _:
raise RuntimeError(f"Unknown model type: {model_type}")
if msg:
return get_data_error_result(message=msg)

View File

@ -326,7 +326,6 @@ async def list_tools() -> Response:
try:
tools = await asyncio.to_thread(tool_call_session.get_tools, timeout)
except Exception as e:
tools = []
return get_data_error_result(message=f"MCP list tools error: {e}")
results[server_key] = []
@ -428,7 +427,6 @@ async def test_mcp() -> Response:
try:
tools = await asyncio.to_thread(tool_call_session.get_tools, timeout)
except Exception as e:
tools = []
return get_data_error_result(message=f"Test MCP error: {e}")
finally:
# PERF: blocking call to close sessions — consider moving to background thread or task queue

View File

@ -20,10 +20,12 @@ from api.apps import login_required, current_user
from api.db import TenantPermission
from api.db.services.memory_service import MemoryService
from api.db.services.user_service import UserTenantService
from api.db.services.canvas_service import UserCanvasService
from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result, \
not_allowed_parameters
from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human
from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT
from memory.services.messages import MessageService
from common.constants import MemoryType, RetCode, ForgettingPolicy
@ -57,7 +59,6 @@ async def create_memory():
if res:
return get_json_result(message=True, data=format_ret_data_from_memory(memory))
else:
return get_json_result(message=memory, code=RetCode.SERVER_ERROR)
@ -124,7 +125,7 @@ async def update_memory(memory_id):
return get_json_result(message=True, data=memory_dict)
try:
MemoryService.update_memory(memory_id, to_update)
MemoryService.update_memory(current_memory.tenant_id, memory_id, to_update)
updated_memory = MemoryService.get_by_memory_id(memory_id)
return get_json_result(message=True, data=format_ret_data_from_memory(updated_memory))
@ -133,7 +134,7 @@ async def update_memory(memory_id):
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
@manager.route("/<memory_id>", methods=["DELETE"]) # noqa: F821
@manager.route("/<memory_id>", methods=["DELETE"]) # noqa: F821
@login_required
async def delete_memory(memory_id):
memory = MemoryService.get_by_memory_id(memory_id)
@ -141,13 +142,14 @@ async def delete_memory(memory_id):
return get_json_result(message=True, code=RetCode.NOT_FOUND)
try:
MemoryService.delete_memory(memory_id)
MessageService.delete_message({"memory_id": memory_id}, memory.tenant_id, memory_id)
return get_json_result(message=True)
except Exception as e:
logging.error(e)
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
@manager.route("", methods=["GET"]) # noqa: F821
@manager.route("", methods=["GET"]) # noqa: F821
@login_required
async def list_memory():
args = request.args
@ -183,3 +185,26 @@ async def get_memory_config(memory_id):
if not memory:
return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
return get_json_result(message=True, data=format_ret_data_from_memory(memory))
@manager.route("/<memory_id>", methods=["GET"]) # noqa: F821
@login_required
async def get_memory_detail(memory_id):
args = request.args
agent_ids = args.getlist("agent_id")
keywords = args.get("keywords", "")
keywords = keywords.strip()
page = int(args.get("page", 1))
page_size = int(args.get("page_size", 50))
memory = MemoryService.get_by_memory_id(memory_id)
if not memory:
return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
messages = MessageService.list_message(
memory.tenant_id, memory_id, agent_ids, keywords, page, page_size)
agent_name_mapping = {}
if messages["message_list"]:
agent_list = UserCanvasService.get_basic_info_by_canvas_ids([message["agent_id"] for message in messages["message_list"]])
agent_name_mapping = {agent["id"]: agent["title"] for agent in agent_list}
for message in messages["message_list"]:
message["agent_name"] = agent_name_mapping.get(message["agent_id"], "Unknown")
return get_json_result(data={"messages": messages, "storage_type": memory.storage_type}, message=True)

api/apps/messages_app.py (new file, 169 lines)
View File

@ -0,0 +1,169 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from quart import request
from api.apps import login_required
from api.db.services.memory_service import MemoryService
from common.time_utils import current_timestamp, timestamp_to_date
from memory.services.messages import MessageService
from api.db.joint_services import memory_message_service
from api.db.joint_services.memory_message_service import query_message
from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result
from common.constants import RetCode
@manager.route("", methods=["POST"]) # noqa: F821
@login_required
@validate_request("memory_id", "agent_id", "session_id", "user_input", "agent_response")
async def add_message():
req = await get_request_json()
memory_ids = req["memory_id"]
agent_id = req["agent_id"]
session_id = req["session_id"]
user_id = req["user_id"] if req.get("user_id") else ""
user_input = req["user_input"]
agent_response = req["agent_response"]
res = []
for memory_id in memory_ids:
success, msg = await memory_message_service.save_to_memory(
memory_id,
{
"user_id": user_id,
"agent_id": agent_id,
"session_id": session_id,
"user_input": user_input,
"agent_response": agent_response
}
)
res.append({
"memory_id": memory_id,
"success": success,
"message": msg
})
if all([r["success"] for r in res]):
return get_json_result(message="Successfully added to memories.")
return get_json_result(code=RetCode.SERVER_ERROR, message="Some messages failed to add.", data=res)
@manager.route("/<memory_id>:<message_id>", methods=["DELETE"]) # noqa: F821
@login_required
async def forget_message(memory_id: str, message_id: int):
memory = MemoryService.get_by_memory_id(memory_id)
if not memory:
return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
forget_time = timestamp_to_date(current_timestamp())
update_succeed = MessageService.update_message(
{"memory_id": memory_id, "message_id": int(message_id)},
{"forget_at": forget_time},
memory.tenant_id, memory_id)
if update_succeed:
return get_json_result(message=update_succeed)
else:
return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to forget message '{message_id}' in memory '{memory_id}'.")
@manager.route("/<memory_id>:<message_id>", methods=["PUT"]) # noqa: F821
@login_required
@validate_request("status")
async def update_message(memory_id: str, message_id: int):
req = await get_request_json()
status = req["status"]
if not isinstance(status, bool):
return get_error_argument_result("Status must be a boolean.")
memory = MemoryService.get_by_memory_id(memory_id)
if not memory:
return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
update_succeed = MessageService.update_message({"memory_id": memory_id, "message_id": int(message_id)}, {"status": status}, memory.tenant_id, memory_id)
if update_succeed:
return get_json_result(message=update_succeed)
else:
return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to set status for message '{message_id}' in memory '{memory_id}'.")
@manager.route("/search", methods=["GET"]) # noqa: F821
@login_required
async def search_message():
args = request.args
print(args, flush=True)
empty_fields = [f for f in ["memory_id", "query"] if not args.get(f)]
if empty_fields:
return get_error_argument_result(f"{', '.join(empty_fields)} can't be empty.")
memory_ids = args.getlist("memory_id")
query = args.get("query")
similarity_threshold = float(args.get("similarity_threshold", 0.2))
keywords_similarity_weight = float(args.get("keywords_similarity_weight", 0.7))
top_n = int(args.get("top_n", 5))
agent_id = args.get("agent_id", "")
session_id = args.get("session_id", "")
filter_dict = {
"memory_id": memory_ids,
"agent_id": agent_id,
"session_id": session_id
}
params = {
"query": query,
"similarity_threshold": similarity_threshold,
"keywords_similarity_weight": keywords_similarity_weight,
"top_n": top_n
}
res = query_message(filter_dict, params)
return get_json_result(message=True, data=res)
@manager.route("", methods=["GET"]) # noqa: F821
@login_required
async def get_messages():
args = request.args
memory_ids = args.getlist("memory_id")
agent_id = args.get("agent_id", "")
session_id = args.get("session_id", "")
limit = int(args.get("limit", 10))
if not memory_ids:
return get_error_argument_result("memory_ids is required.")
memory_list = MemoryService.get_by_ids(memory_ids)
uids = [memory.tenant_id for memory in memory_list]
res = MessageService.get_recent_messages(
uids,
memory_ids,
agent_id,
session_id,
limit
)
return get_json_result(message=True, data=res)
@manager.route("/<memory_id>:<message_id>/content", methods=["GET"]) # noqa: F821
@login_required
async def get_message_content(memory_id: str, message_id: int):
memory = MemoryService.get_by_memory_id(memory_id)
if not memory:
return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
res = MessageService.get_by_message_id(memory_id, message_id, memory.tenant_id)
if res:
return get_json_result(message=True, data=res)
else:
return get_json_result(code=RetCode.NOT_FOUND, message=f"Message '{message_id}' in memory '{memory_id}' not found.")

View File

@ -326,7 +326,6 @@ async def webhook(agent_id: str):
secret = jwt_cfg.get("secret")
if not secret:
raise Exception("JWT secret not configured")
required_claims = jwt_cfg.get("required_claims", [])
auth_header = request.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "):
@ -750,7 +749,7 @@ async def webhook(agent_id: str):
async def sse():
nonlocal canvas
contents: list[str] = []
status = 200
try:
async for ans in canvas.run(
query="",
@ -765,6 +764,8 @@ async def webhook(agent_id: str):
content = "</think>"
if content:
contents.append(content)
if ans["event"] == "message_end":
status = int(ans["data"].get("status", status))
if is_test:
append_webhook_trace(
agent_id,
@ -782,7 +783,11 @@ async def webhook(agent_id: str):
}
)
final_content = "".join(contents)
yield json.dumps(final_content, ensure_ascii=False)
return {
"message": final_content,
"success": True,
"code": status,
}
except Exception as e:
if is_test:
@ -804,10 +809,14 @@ async def webhook(agent_id: str):
"success": False,
}
)
yield json.dumps({"code": 500, "message": str(e)}, ensure_ascii=False)
return {"code": 400, "message": str(e),"success":False}
resp = Response(sse(), mimetype="application/json")
return resp
result = await sse()
return Response(
json.dumps(result),
status=result["code"],
mimetype="application/json",
)
@manager.route("/webhook_trace/<agent_id>", methods=["GET"]) # noqa: F821

View File

@ -287,7 +287,7 @@ def list_chat(tenant_id):
chats = DialogService.get_list(tenant_id, page_number, items_per_page, orderby, desc, id, name)
if not chats:
return get_result(data=[])
list_assts = []
list_assistants = []
key_mapping = {
"parameters": "variables",
"prologue": "opener",
@ -321,5 +321,5 @@ def list_chat(tenant_id):
del res["kb_ids"]
res["datasets"] = kb_list
res["avatar"] = res.pop("icon")
list_assts.append(res)
return get_result(data=list_assts)
list_assistants.append(res)
return get_result(data=list_assistants)

View File

@ -495,7 +495,7 @@ def knowledge_graph(tenant_id, dataset_id):
}
obj = {"graph": {}, "mind_map": {}}
if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id):
if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), dataset_id):
return get_result(data=obj)
sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
if not len(sres.ids):

View File

@ -1080,7 +1080,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
res["chunks"].append(final_chunk)
_ = Chunk(**final_chunk)
elif settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
elif settings.docStoreConn.index_exist(search.index_name(tenant_id), dataset_id):
sres = settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
res["total"] = sres.total
for id in sres.ids:

View File

@ -205,7 +205,8 @@ async def create(tenant_id):
if not FileService.is_parent_folder_exist(pf_id):
return get_json_result(data=False, message="Parent Folder Doesn't Exist!", code=RetCode.BAD_REQUEST)
if FileService.query(name=req["name"], parent_id=pf_id):
return get_json_result(data=False, message="Duplicated folder name in the same folder.", code=409)
return get_json_result(data=False, message="Duplicated folder name in the same folder.",
code=RetCode.CONFLICT)
if input_file_type == FileType.FOLDER.value:
file_type = FileType.FOLDER.value
@ -565,11 +566,13 @@ async def rename(tenant_id):
if file.type != FileType.FOLDER.value and pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
file.name.lower()).suffix:
return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.BAD_REQUEST)
return get_json_result(data=False, message="The extension of file can't be changed",
code=RetCode.BAD_REQUEST)
for existing_file in FileService.query(name=req["name"], pf_id=file.parent_id):
if existing_file.name == req["name"]:
return get_json_result(data=False, message="Duplicated file name in the same folder.", code=409)
return get_json_result(data=False, message="Duplicated file name in the same folder.",
code=RetCode.CONFLICT)
if not FileService.update_by_id(req["file_id"], {"name": req["name"]}):
return get_json_result(message="Database error (File rename)!", code=RetCode.SERVER_ERROR)
@ -631,9 +634,10 @@ async def get(tenant_id, file_id):
except Exception as e:
return server_error_response(e)
@manager.route("/file/download/<attachment_id>", methods=["GET"]) # noqa: F821
@token_required
async def download_attachment(tenant_id,attachment_id):
async def download_attachment(tenant_id, attachment_id):
try:
ext = request.args.get("ext", "markdown")
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, tenant_id, attachment_id)
@ -645,6 +649,7 @@ async def download_attachment(tenant_id,attachment_id):
except Exception as e:
return server_error_response(e)
@manager.route('/file/mv', methods=['POST']) # noqa: F821
@token_required
async def move(tenant_id):

View File

@ -448,7 +448,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
@token_required
async def agents_completion_openai_compatibility(tenant_id, agent_id):
req = await get_request_json()
tiktokenenc = tiktoken.get_encoding("cl100k_base")
tiktoken_encode = tiktoken.get_encoding("cl100k_base")
messages = req.get("messages", [])
if not messages:
return get_error_data_result("You must provide at least one message.")
@ -456,7 +456,7 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id):
return get_error_data_result(f"You don't own the agent {agent_id}")
filtered_messages = [m for m in messages if m["role"] in ["user", "assistant"]]
prompt_tokens = sum(len(tiktokenenc.encode(m["content"])) for m in filtered_messages)
prompt_tokens = sum(len(tiktoken_encode.encode(m["content"])) for m in filtered_messages)
if not filtered_messages:
return jsonify(
get_data_openai(
@ -464,7 +464,7 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id):
content="No valid messages found (user or assistant).",
finish_reason="stop",
model=req.get("model", ""),
completion_tokens=len(tiktokenenc.encode("No valid messages found (user or assistant).")),
completion_tokens=len(tiktoken_encode.encode("No valid messages found (user or assistant).")),
prompt_tokens=prompt_tokens,
)
)
@ -501,6 +501,8 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id):
):
return jsonify(response)
return None
@manager.route("/agents/<agent_id>/completions", methods=["POST"]) # noqa: F821
@token_required
@ -920,6 +922,7 @@ async def chatbot_completions(dialog_id):
async for answer in iframe_completion(dialog_id, **req):
return get_result(data=answer)
return None
@manager.route("/chatbots/<dialog_id>/info", methods=["GET"]) # noqa: F821
async def chatbots_inputs(dialog_id):
@ -967,6 +970,7 @@ async def agent_bot_completions(agent_id):
async for answer in agent_completion(objs[0].tenant_id, agent_id, **req):
return get_result(data=answer)
return None
@manager.route("/agentbots/<agent_id>/inputs", methods=["GET"]) # noqa: F821
async def begin_inputs(agent_id):

View File

@ -660,7 +660,7 @@ def user_register(user_id, user):
tenant_llm = get_init_tenant_llm(user_id)
if not UserService.save(**user):
return
return None
TenantService.insert(**tenant)
UserTenantService.insert(**usr_tenant)
TenantLLMService.insert_many(tenant_llm)

View File

@ -30,6 +30,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm
from api.db.services.user_service import TenantService, UserTenantService
from api.db.joint_services.memory_message_service import init_message_id_sequence, init_memory_size_cache
from common.constants import LLMType
from common.file_utils import get_project_base_directory
from common import settings
@ -169,6 +170,8 @@ def init_web_data():
# init_superuser()
add_graph_templates()
init_message_id_sequence()
init_memory_size_cache()
logging.info("init web data success:{}".format(time.time() - start_time))

View File

@ -0,0 +1,233 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from typing import List
from common.time_utils import current_timestamp, timestamp_to_date, format_iso_8601_to_ymd_hms
from common.constants import MemoryType, LLMType
from common.doc_store.doc_store_base import FusionExpr
from api.db.services.memory_service import MemoryService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.llm_service import LLMBundle
from api.utils.memory_utils import get_memory_type_human
from memory.services.messages import MessageService
from memory.services.query import MsgTextQuery, get_vector
from memory.utils.prompt_util import PromptAssembler
from memory.utils.msg_util import get_json_result_from_llm_response
from rag.utils.redis_conn import REDIS_CONN
async def save_to_memory(memory_id: str, message_dict: dict):
"""
:param memory_id:
:param message_dict: {
"user_id": str,
"agent_id": str,
"session_id": str,
"user_input": str,
"agent_response": str
}
"""
memory = MemoryService.get_by_memory_id(memory_id)
if not memory:
return False, f"Memory '{memory_id}' not found."
tenant_id = memory.tenant_id
extracted_content = await extract_by_llm(
tenant_id,
memory.llm_id,
{"temperature": memory.temperature},
get_memory_type_human(memory.memory_type),
message_dict.get("user_input", ""),
message_dict.get("agent_response", "")
) if memory.memory_type != MemoryType.RAW.value else [] # if only RAW, no need to extract
raw_message_id = REDIS_CONN.generate_auto_increment_id(namespace="memory")
message_list = [{
"message_id": raw_message_id,
"message_type": MemoryType.RAW.name.lower(),
"source_id": 0,
"memory_id": memory_id,
"user_id": "",
"agent_id": message_dict["agent_id"],
"session_id": message_dict["session_id"],
"content": f"User Input: {message_dict.get('user_input')}\nAgent Response: {message_dict.get('agent_response')}",
"valid_at": timestamp_to_date(current_timestamp()),
"invalid_at": None,
"forget_at": None,
"status": True
}, *[{
"message_id": REDIS_CONN.generate_auto_increment_id(namespace="memory"),
"message_type": content["message_type"],
"source_id": raw_message_id,
"memory_id": memory_id,
"user_id": "",
"agent_id": message_dict["agent_id"],
"session_id": message_dict["session_id"],
"content": content["content"],
"valid_at": content["valid_at"],
"invalid_at": content["invalid_at"] if content["invalid_at"] else None,
"forget_at": None,
"status": True
} for content in extracted_content]]
embedding_model = LLMBundle(tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
vector_list, _ = embedding_model.encode([msg["content"] for msg in message_list])
for idx, msg in enumerate(message_list):
msg["content_embed"] = vector_list[idx]
vector_dimension = len(vector_list[0])
if not MessageService.has_index(tenant_id, memory_id):
created = MessageService.create_index(tenant_id, memory_id, vector_size=vector_dimension)
if not created:
return False, "Failed to create message index."
new_msg_size = sum([MessageService.calculate_message_size(m) for m in message_list])
current_memory_size = get_memory_size_cache(memory_id, tenant_id)
if new_msg_size + current_memory_size > memory.memory_size:
size_to_delete = current_memory_size + new_msg_size - memory.memory_size
if memory.forgetting_policy == "fifo":
message_ids_to_delete, delete_size = MessageService.pick_messages_to_delete_by_fifo(memory_id, tenant_id, size_to_delete)
MessageService.delete_message({"message_id": message_ids_to_delete}, tenant_id, memory_id)
decrease_memory_size_cache(memory_id, tenant_id, delete_size)
else:
return False, "Failed to insert message into memory. Memory size reached limit and cannot decide which to delete."
fail_cases = MessageService.insert_message(message_list, tenant_id, memory_id)
if fail_cases:
return False, "Failed to insert message into memory. Details: " + "; ".join(fail_cases)
increase_memory_size_cache(memory_id, tenant_id, new_msg_size)
return True, "Message saved successfully."
async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory_type: List[str], user_input: str,
agent_response: str, system_prompt: str = "", user_prompt: str = "") -> List[dict]:
llm_type = TenantLLMService.llm_id2llm_type(llm_id)
if not llm_type:
raise RuntimeError(f"Unknown type of LLM '{llm_id}'")
if not system_prompt:
system_prompt = PromptAssembler.assemble_system_prompt({"memory_type": memory_type})
conversation_content = f"User Input: {user_input}\nAgent Response: {agent_response}"
conversation_time = timestamp_to_date(current_timestamp())
user_prompts = []
if user_prompt:
user_prompts.append({"role": "user", "content": user_prompt})
user_prompts.append({"role": "user", "content": f"Conversation: {conversation_content}\nConversation Time: {conversation_time}\nCurrent Time: {conversation_time}"})
else:
user_prompts.append({"role": "user", "content": PromptAssembler.assemble_user_prompt(conversation_content, conversation_time, conversation_time)})
llm = LLMBundle(tenant_id, llm_type, llm_id)
res = await llm.async_chat(system_prompt, user_prompts, extract_conf)
res_json = get_json_result_from_llm_response(res)
return [{
"content": extracted_content["content"],
"valid_at": format_iso_8601_to_ymd_hms(extracted_content["valid_at"]),
"invalid_at": format_iso_8601_to_ymd_hms(extracted_content["invalid_at"]) if extracted_content.get("invalid_at") else "",
"message_type": message_type
} for message_type, extracted_content_list in res_json.items() for extracted_content in extracted_content_list]
def query_message(filter_dict: dict, params: dict):
"""
:param filter_dict: {
"memory_id": List[str],
"agent_id": optional
"session_id": optional
}
:param params: {
"query": question str,
"similarity_threshold": float,
"keywords_similarity_weight": float,
"top_n": int
}
"""
memory_ids = filter_dict["memory_id"]
memory_list = MemoryService.get_by_ids(memory_ids)
if not memory_list:
return []
condition_dict = {k: v for k, v in filter_dict.items() if v}
uids = [memory.tenant_id for memory in memory_list]
question = params["query"]
question = question.strip()
memory = memory_list[0]
embd_model = LLMBundle(memory.tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
match_dense = get_vector(question, embd_model, similarity=params["similarity_threshold"])
match_text, _ = MsgTextQuery().question(question, min_match=0.3)
keywords_similarity_weight = params.get("keywords_similarity_weight", 0.7)
fusion_expr = FusionExpr("weighted_sum", params["top_n"], {"weights": ",".join([str(keywords_similarity_weight), str(1 - keywords_similarity_weight)])})
return MessageService.search_message(memory_ids, condition_dict, uids, [match_text, match_dense, fusion_expr], params["top_n"])
def init_message_id_sequence():
message_id_redis_key = "id_generator:memory"
if REDIS_CONN.exist(message_id_redis_key):
current_max_id = REDIS_CONN.get(message_id_redis_key)
logging.info(f"No need to init message_id sequence, current max id is {current_max_id}.")
else:
max_id = 1
exist_memory_list = MemoryService.get_all_memory()
if not exist_memory_list:
REDIS_CONN.set(message_id_redis_key, max_id)
else:
max_id = MessageService.get_max_message_id(
uid_list=[m.tenant_id for m in exist_memory_list],
memory_ids=[m.id for m in exist_memory_list]
)
REDIS_CONN.set(message_id_redis_key, max_id)
logging.info(f"Init message_id sequence done, current max id is {max_id}.")
def get_memory_size_cache(memory_id: str, uid: str):
redis_key = f"memory_{memory_id}"
if REDIS_CONN.exists(redis_key):
return REDIS_CONN.get(redis_key)
else:
memory_size_map = MessageService.calculate_memory_size(
[memory_id],
[uid]
)
memory_size = memory_size_map.get(memory_id, 0)
set_memory_size_cache(memory_id, memory_size)
return memory_size
def set_memory_size_cache(memory_id: str, size: int):
redis_key = f"memory_{memory_id}"
return REDIS_CONN.set(redis_key, size)
def increase_memory_size_cache(memory_id: str, uid: str, size: int):
current_value = get_memory_size_cache(memory_id, uid)
return set_memory_size_cache(memory_id, current_value + size)
def decrease_memory_size_cache(memory_id: str, uid: str, size: int):
current_value = get_memory_size_cache(memory_id, uid)
return set_memory_size_cache(memory_id, max(current_value - size, 0))
def init_memory_size_cache():
memory_list = MemoryService.get_all_memory()
if not memory_list:
logging.info("No memory found, no need to init memory size.")
else:
memory_size_map = MessageService.calculate_memory_size(
memory_ids=[m.id for m in memory_list],
uid_list=[m.tenant_id for m in memory_list],
)
for memory in memory_list:
memory_size = memory_size_map.get(memory.id, 0)
set_memory_size_cache(memory.id, memory_size)
logging.info("Memory size cache init done.")

View File

@ -34,6 +34,8 @@ from api.db.services.task_service import TaskService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_canvas_version import UserCanvasVersionService
from api.db.services.user_service import TenantService, UserService, UserTenantService
from api.db.services.memory_service import MemoryService
from memory.services.messages import MessageService
from rag.nlp import search
from common.constants import ActiveEnum
from common import settings
@ -200,7 +202,16 @@ def delete_user_data(user_id: str) -> dict:
done_msg += f"- Deleted {llm_delete_res} tenant-LLM records.\n"
langfuse_delete_res = TenantLangfuseService.delete_ty_tenant_id(tenant_id)
done_msg += f"- Deleted {langfuse_delete_res} langfuse records.\n"
# step1.3 delete own tenant
# step1.3 delete memory and messages
user_memory = MemoryService.get_by_tenant_id(tenant_id)
if user_memory:
for memory in user_memory:
if MessageService.has_index(tenant_id, memory.id):
MessageService.delete_index(tenant_id, memory.id)
done_msg += " Deleted memory index."
memory_delete_res = MemoryService.delete_by_ids([m.id for m in user_memory])
done_msg += f"Deleted {memory_delete_res} memory datasets."
# step1.4 delete own tenant
tenant_delete_res = TenantService.delete_by_id(tenant_id)
done_msg += f"- Deleted {tenant_delete_res} tenant.\n"
# step2 delete user-tenant relation

View File

@ -123,6 +123,19 @@ class UserCanvasService(CommonService):
logging.exception(e)
return False, None
@classmethod
@DB.connection_context()
def get_basic_info_by_canvas_ids(cls, canvas_id):
fields = [
cls.model.id,
cls.model.avatar,
cls.model.user_id,
cls.model.title,
cls.model.permission,
cls.model.canvas_category
]
return cls.model.select(*fields).where(cls.model.id.in_(canvas_id)).dicts()
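This helper backs the agent-name lookup in get_memory_detail above; note that the parameter is named canvas_id but receives a list of ids. A short usage sketch with illustrative ids:

# Illustrative ids; each returned dict carries id, avatar, user_id, title, permission and canvas_category.
from api.db.services.canvas_service import UserCanvasService

canvas_ids = ["agent_a", "agent_b"]
agents = UserCanvasService.get_basic_info_by_canvas_ids(canvas_ids)
agent_title_by_id = {agent["id"]: agent["title"] for agent in agents}
print(agent_title_by_id.get("agent_a", "Unknown"))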
@classmethod
@DB.connection_context()
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,

View File

@ -406,7 +406,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
dialog.vector_similarity_weight,
doc_ids=attachments,
top=dialog.top_k,
aggs=False,
aggs=True,
rerank_mdl=rerank_mdl,
rank_feature=label_question(" ".join(questions), kbs),
)
@ -769,7 +769,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
vector_similarity_weight=search_config.get("vector_similarity_weight", 0.3),
top=search_config.get("top_k", 1024),
doc_ids=doc_ids,
aggs=False,
aggs=True,
rerank_mdl=rerank_mdl,
rank_feature=label_question(question, kbs)
)

View File

@ -33,12 +33,13 @@ from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTena
from api.db.db_utils import bulk_insert_into_db
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from common.metadata_utils import dedupe_list
from common.misc_utils import get_uuid
from common.time_utils import current_timestamp, get_format_time
from common.constants import LLMType, ParserType, StatusEnum, TaskStatus, SVR_CONSUMER_GROUP_NAME
from rag.nlp import rag_tokenizer, search
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.doc_store_conn import OrderByExpr
from common.doc_store.doc_store_base import OrderByExpr
from common import settings
@ -345,7 +346,7 @@ class DocumentService(CommonService):
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
page * page_size, page_size, search.index_name(tenant_id),
[doc.kb_id])
chunk_ids = settings.docStoreConn.get_chunk_ids(chunks)
chunk_ids = settings.docStoreConn.get_doc_ids(chunks)
if not chunk_ids:
break
all_chunk_ids.extend(chunk_ids)
@ -696,10 +697,14 @@ class DocumentService(CommonService):
for k,v in r.meta_fields.items():
if k not in meta:
meta[k] = {}
v = str(v)
if v not in meta[k]:
meta[k][v] = []
meta[k][v].append(doc_id)
if not isinstance(v, list):
v = [v]
for vv in v:
if isinstance(vv, list) or isinstance(vv, dict):
continue
if vv not in meta[k]:
meta[k][vv] = []
meta[k][vv].append(doc_id)
return meta
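With this change a metadata value may itself be a list: each scalar element becomes its own bucket keyed by value, and nested lists or dicts inside a list are skipped. A worked example of the resulting structure (document ids and values are illustrative):

# Illustrative input/output for the aggregation loop above.
# doc_1 meta_fields: {"author": "alice", "tags": ["finance", "q3"]}
# doc_2 meta_fields: {"author": "bob",   "tags": ["finance", {"nested": "skipped"}]}
expected_meta = {
    "author": {"alice": ["doc_1"], "bob": ["doc_2"]},
    "tags": {"finance": ["doc_1", "doc_2"], "q3": ["doc_1"]},  # the nested dict element is dropped
}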
@classmethod
@ -797,7 +802,10 @@ class DocumentService(CommonService):
match_provided = "match" in upd
if isinstance(meta[key], list):
if not match_provided:
meta[key] = new_value
if isinstance(new_value, list):
meta[key] = dedupe_list(new_value)
else:
meta[key] = new_value
changed = True
else:
match_value = upd.get("match")
@ -810,7 +818,7 @@ class DocumentService(CommonService):
else:
new_list.append(item)
if replaced:
meta[key] = new_list
meta[key] = dedupe_list(new_list)
changed = True
else:
if not match_provided:
@ -1230,8 +1238,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
d["q_%d_vec" % len(v)] = v
for b in range(0, len(cks), es_bulk_size):
if try_create_idx:
if not settings.docStoreConn.indexExist(idxnm, kb_id):
settings.docStoreConn.createIdx(idxnm, kb_id, len(vectors[0]))
if not settings.docStoreConn.index_exist(idxnm, kb_id):
settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]))
try_create_idx = False
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
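dedupe_list is imported from common.metadata_utils, which is not part of this diff; the sketch below shows the order-preserving de-duplication it is assumed to perform, based only on how it is called in the metadata-update paths above.

# Assumption: the real common.metadata_utils.dedupe_list may differ; this is a minimal
# order-preserving de-duplication consistent with the call sites above.
def dedupe_list(values: list) -> list:
    seen = set()
    result = []
    for value in values:
        try:
            if value in seen:
                continue
            seen.add(value)
        except TypeError:
            pass  # unhashable items (lists/dicts) are kept as-is here
        result.append(value)
    return result

assert dedupe_list(["a", "b", "a", 3, 3]) == ["a", "b", 3]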

View File

@ -425,6 +425,7 @@ class KnowledgebaseService(CommonService):
# Update parser_config (always override with validated default/merged config)
payload["parser_config"] = get_parser_config(parser_id, kwargs.get("parser_config"))
payload["parser_config"]["llm_id"] = _t.llm_id
return True, payload

View File

@ -15,7 +15,6 @@
#
from typing import List
from api.apps import current_user
from api.db.db_models import DB, Memory, User
from api.db.services import duplicate_name
from api.db.services.common_service import CommonService
@ -23,6 +22,7 @@ from api.utils.memory_utils import calculate_memory_type
from api.constants import MEMORY_NAME_LIMIT
from common.misc_utils import get_uuid
from common.time_utils import get_format_time, current_timestamp
from memory.utils.prompt_util import PromptAssembler
class MemoryService(CommonService):
@ -34,6 +34,17 @@ class MemoryService(CommonService):
def get_by_memory_id(cls, memory_id: str):
return cls.model.select().where(cls.model.id == memory_id).first()
@classmethod
@DB.connection_context()
def get_by_tenant_id(cls, tenant_id: str):
return cls.model.select().where(cls.model.tenant_id == tenant_id)
@classmethod
@DB.connection_context()
def get_all_memory(cls):
memory_list = cls.model.select()
return list(memory_list)
@classmethod
@DB.connection_context()
def get_with_owner_name_by_id(cls, memory_id: str):
@ -53,7 +64,9 @@ class MemoryService(CommonService):
cls.model.forgetting_policy,
cls.model.temperature,
cls.model.system_prompt,
cls.model.user_prompt
cls.model.user_prompt,
cls.model.create_date,
cls.model.create_time
]
memory = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)).where(
cls.model.id == memory_id
@ -72,7 +85,9 @@ class MemoryService(CommonService):
cls.model.memory_type,
cls.model.storage_type,
cls.model.permissions,
cls.model.description
cls.model.description,
cls.model.create_time,
cls.model.create_date
]
memories = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id))
if filter_dict.get("tenant_id"):
@ -102,6 +117,8 @@ class MemoryService(CommonService):
if len(memory_name) > MEMORY_NAME_LIMIT:
return False, f"Memory name {memory_name} exceeds limit of {MEMORY_NAME_LIMIT}."
timestamp = current_timestamp()
format_time = get_format_time()
# build create dict
memory_info = {
"id": get_uuid(),
@ -110,10 +127,11 @@ class MemoryService(CommonService):
"tenant_id": tenant_id,
"embd_id": embd_id,
"llm_id": llm_id,
"create_time": current_timestamp(),
"create_date": get_format_time(),
"update_time": current_timestamp(),
"update_date": get_format_time(),
"system_prompt": PromptAssembler.assemble_system_prompt({"memory_type": memory_type}),
"create_time": timestamp,
"create_date": format_time,
"update_time": timestamp,
"update_date": format_time,
}
obj = cls.model(**memory_info).save(force_insert=True)
@ -126,7 +144,7 @@ class MemoryService(CommonService):
@classmethod
@DB.connection_context()
def update_memory(cls, memory_id: str, update_dict: dict):
def update_memory(cls, tenant_id: str, memory_id: str, update_dict: dict):
if not update_dict:
return 0
if "temperature" in update_dict and isinstance(update_dict["temperature"], str):
@ -135,7 +153,7 @@ class MemoryService(CommonService):
update_dict["name"] = duplicate_name(
cls.query,
name=update_dict["name"],
tenant_id=current_user.id
tenant_id=tenant_id
)
update_dict.update({
"update_time": current_timestamp(),

View File

@ -97,7 +97,7 @@ class TenantLLMService(CommonService):
if llm_type == LLMType.EMBEDDING.value:
mdlnm = tenant.embd_id if not llm_name else llm_name
elif llm_type == LLMType.SPEECH2TEXT.value:
mdlnm = tenant.asr_id
mdlnm = tenant.asr_id if not llm_name else llm_name
elif llm_type == LLMType.IMAGE2TEXT.value:
mdlnm = tenant.img2txt_id if not llm_name else llm_name
elif llm_type == LLMType.CHAT.value:

View File

@ -54,6 +54,7 @@ class RetCode(IntEnum, CustomEnum):
SERVER_ERROR = 500
FORBIDDEN = 403
NOT_FOUND = 404
CONFLICT = 409
class StatusEnum(Enum):

View File

@ -64,15 +64,23 @@ class BlobStorageConnector(LoadConnector, PollConnector):
elif self.bucket_type == BlobType.S3:
authentication_method = credentials.get("authentication_method", "access_key")
if authentication_method == "access_key":
if not all(
credentials.get(key)
for key in ["aws_access_key_id", "aws_secret_access_key"]
):
raise ConnectorMissingCredentialError("Amazon S3")
elif authentication_method == "iam_role":
if not credentials.get("aws_role_arn"):
raise ConnectorMissingCredentialError("Amazon S3 IAM role ARN is required")
elif authentication_method == "assume_role":
pass
else:
raise ConnectorMissingCredentialError("Unsupported S3 authentication method")
elif self.bucket_type == BlobType.GOOGLE_CLOUD_STORAGE:
if not all(
@ -293,4 +301,4 @@ if __name__ == "__main__":
except ConnectorMissingCredentialError as e:
print(f"Error: {e}")
except Exception as e:
print(f"An unexpected error occurred: {e}")
print(f"An unexpected error occurred: {e}")

View File

@ -254,18 +254,21 @@ def create_s3_client(bucket_type: BlobType, credentials: dict[str, Any], europea
elif bucket_type == BlobType.S3:
authentication_method = credentials.get("authentication_method", "access_key")
region_name = credentials.get("region") or None
if authentication_method == "access_key":
session = boto3.Session(
aws_access_key_id=credentials["aws_access_key_id"],
aws_secret_access_key=credentials["aws_secret_access_key"],
region_name=region_name,
)
return session.client("s3")
return session.client("s3", region_name=region_name)
elif authentication_method == "iam_role":
role_arn = credentials["aws_role_arn"]
def _refresh_credentials() -> dict[str, str]:
sts_client = boto3.client("sts")
sts_client = boto3.client("sts", region_name=credentials.get("region") or None)
assumed_role_object = sts_client.assume_role(
RoleArn=role_arn,
RoleSessionName=f"onyx_blob_storage_{int(datetime.now().timestamp())}",
@ -285,11 +288,11 @@ def create_s3_client(bucket_type: BlobType, credentials: dict[str, Any], europea
)
botocore_session = get_session()
botocore_session._credentials = refreshable
session = boto3.Session(botocore_session=botocore_session)
return session.client("s3")
session = boto3.Session(botocore_session=botocore_session, region_name=region_name)
return session.client("s3", region_name=region_name)
elif authentication_method == "assume_role":
return boto3.client("s3")
return boto3.client("s3", region_name=region_name)
else:
raise ValueError("Invalid authentication method for S3.")
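The change above threads region_name through all three S3 construction paths (static access keys, an assumed IAM role with refreshable credentials, and ambient assume_role credentials). For reference, the credential dict shapes each branch expects; the key names follow the code above, the values are placeholders:

# Placeholder values; the keys mirror the branches of create_s3_client above.
access_key_credentials = {
    "authentication_method": "access_key",
    "aws_access_key_id": "AKIA<redacted>",
    "aws_secret_access_key": "<secret-key>",
    "region": "us-east-1",  # optional; falls back to None and boto3's own region resolution
}
iam_role_credentials = {
    "authentication_method": "iam_role",
    "aws_role_arn": "arn:aws:iam::123456789012:role/example-role",
    "region": "eu-central-1",
}
assume_role_credentials = {
    "authentication_method": "assume_role",  # relies on ambient credentials (env vars / instance profile)
    "region": "us-west-2",
}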

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC, abstractmethod
from dataclasses import dataclass
import numpy as np
@ -22,7 +21,6 @@ DEFAULT_MATCH_VECTOR_TOPN = 10
DEFAULT_MATCH_SPARSE_TOPN = 10
VEC = list | np.ndarray
@dataclass
class SparseVector:
indices: list[int]
@ -55,14 +53,13 @@ class SparseVector:
def __repr__(self):
return str(self)
class MatchTextExpr(ABC):
class MatchTextExpr:
def __init__(
self,
fields: list[str],
matching_text: str,
topn: int,
extra_options: dict = dict(),
extra_options: dict | None = None,
):
self.fields = fields
self.matching_text = matching_text
@ -70,7 +67,7 @@ class MatchTextExpr(ABC):
self.extra_options = extra_options
class MatchDenseExpr(ABC):
class MatchDenseExpr:
def __init__(
self,
vector_column_name: str,
@ -78,7 +75,7 @@ class MatchDenseExpr(ABC):
embedding_data_type: str,
distance_type: str,
topn: int = DEFAULT_MATCH_VECTOR_TOPN,
extra_options: dict = dict(),
extra_options: dict | None = None,
):
self.vector_column_name = vector_column_name
self.embedding_data = embedding_data
@ -88,7 +85,7 @@ class MatchDenseExpr(ABC):
self.extra_options = extra_options
class MatchSparseExpr(ABC):
class MatchSparseExpr:
def __init__(
self,
vector_column_name: str,
@ -104,7 +101,7 @@ class MatchSparseExpr(ABC):
self.opt_params = opt_params
class MatchTensorExpr(ABC):
class MatchTensorExpr:
def __init__(
self,
column_name: str,
@ -120,7 +117,7 @@ class MatchTensorExpr(ABC):
self.extra_option = extra_option
class FusionExpr(ABC):
class FusionExpr:
def __init__(self, method: str, topn: int, fusion_params: dict | None = None):
self.method = method
self.topn = topn
@ -129,7 +126,8 @@ class FusionExpr(ABC):
MatchExpr = MatchTextExpr | MatchDenseExpr | MatchSparseExpr | MatchTensorExpr | FusionExpr
class OrderByExpr(ABC):
class OrderByExpr:
def __init__(self):
self.fields = list()
def asc(self, field: str):
@ -141,13 +139,14 @@ class OrderByExpr(ABC):
def fields(self):
return self.fields
class DocStoreConnection(ABC):
"""
Database operations
"""
@abstractmethod
def dbType(self) -> str:
def db_type(self) -> str:
"""
Return the type of the database.
"""
@ -165,21 +164,21 @@ class DocStoreConnection(ABC):
"""
@abstractmethod
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
"""
Create an index with given name
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def deleteIdx(self, indexName: str, knowledgebaseId: str):
def delete_idx(self, index_name: str, dataset_id: str):
"""
Delete an index with given name
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
def index_exist(self, index_name: str, dataset_id: str) -> bool:
"""
Check if an index with given name exists
"""
@ -191,16 +190,16 @@ class DocStoreConnection(ABC):
@abstractmethod
def search(
self, selectFields: list[str],
highlightFields: list[str],
self, select_fields: list[str],
highlight_fields: list[str],
condition: dict,
matchExprs: list[MatchExpr],
orderBy: OrderByExpr,
match_expressions: list[MatchExpr],
order_by: OrderByExpr,
offset: int,
limit: int,
indexNames: str|list[str],
knowledgebaseIds: list[str],
aggFields: list[str] = [],
index_names: str|list[str],
dataset_ids: list[str],
agg_fields: list[str] | None = None,
rank_feature: dict | None = None
):
"""
@ -209,28 +208,28 @@ class DocStoreConnection(ABC):
raise NotImplementedError("Not implemented")
@abstractmethod
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
def get(self, data_id: str, index_name: str, dataset_ids: list[str]) -> dict | None:
"""
Get single chunk with given id
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
def insert(self, rows: list[dict], index_name: str, dataset_id: str = None) -> list[str]:
"""
Update or insert a bulk of rows
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
def update(self, condition: dict, new_value: dict, index_name: str, dataset_id: str) -> bool:
"""
Update rows with given conjunctive equivalent filtering condition
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
def delete(self, condition: dict, index_name: str, dataset_id: str) -> int:
"""
Delete rows with given conjunctive equivalent filtering condition
"""
@ -245,7 +244,7 @@ class DocStoreConnection(ABC):
raise NotImplementedError("Not implemented")
@abstractmethod
def get_chunk_ids(self, res):
def get_doc_ids(self, res):
raise NotImplementedError("Not implemented")
@abstractmethod
@ -253,18 +252,18 @@ class DocStoreConnection(ABC):
raise NotImplementedError("Not implemented")
@abstractmethod
def get_highlight(self, res, keywords: list[str], fieldnm: str):
def get_highlight(self, res, keywords: list[str], field_name: str):
raise NotImplementedError("Not implemented")
@abstractmethod
def get_aggregation(self, res, fieldnm: str):
def get_aggregation(self, res, field_name: str):
raise NotImplementedError("Not implemented")
"""
SQL
"""
@abstractmethod
def sql(sql: str, fetch_size: int, format: str):
def sql(self, sql: str, fetch_size: int, format: str):
"""
Run the sql generated by text-to-sql
"""

View File

@ -0,0 +1,326 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import re
import json
import time
import os
from abc import abstractmethod
from elasticsearch import Elasticsearch, NotFoundError
from elasticsearch_dsl import Index
from elastic_transport import ConnectionTimeout
from common.file_utils import get_project_base_directory
from common.misc_utils import convert_bytes
from common.doc_store.doc_store_base import DocStoreConnection, OrderByExpr, MatchExpr
from rag.nlp import is_english, rag_tokenizer
from common import settings
ATTEMPT_TIME = 2
class ESConnectionBase(DocStoreConnection):
def __init__(self, mapping_file_name: str="mapping.json", logger_name: str='ragflow.es_conn'):
self.logger = logging.getLogger(logger_name)
self.info = {}
self.logger.info(f"Use Elasticsearch {settings.ES['hosts']} as the doc engine.")
for _ in range(ATTEMPT_TIME):
try:
if self._connect():
break
except Exception as e:
self.logger.warning(f"{str(e)}. Waiting Elasticsearch {settings.ES['hosts']} to be healthy.")
time.sleep(5)
if not self.es.ping():
msg = f"Elasticsearch {settings.ES['hosts']} is unhealthy in 120s."
self.logger.error(msg)
raise Exception(msg)
v = self.info.get("version", {"number": "8.11.3"})
v = v["number"].split(".")[0]
if int(v) < 8:
msg = f"Elasticsearch version must be greater than or equal to 8, current version: {v}"
self.logger.error(msg)
raise Exception(msg)
fp_mapping = os.path.join(get_project_base_directory(), "conf", mapping_file_name)
if not os.path.exists(fp_mapping):
msg = f"Elasticsearch mapping file not found at {fp_mapping}"
self.logger.error(msg)
raise Exception(msg)
self.mapping = json.load(open(fp_mapping, "r"))
self.logger.info(f"Elasticsearch {settings.ES['hosts']} is healthy.")
def _connect(self):
self.es = Elasticsearch(
settings.ES["hosts"].split(","),
basic_auth=(settings.ES["username"], settings.ES[
"password"]) if "username" in settings.ES and "password" in settings.ES else None,
verify_certs=settings.ES.get("verify_certs", False),
timeout=600)
if self.es:
self.info = self.es.info()
return True
return False
"""
Database operations
"""
def db_type(self) -> str:
return "elasticsearch"
def health(self) -> dict:
health_dict = dict(self.es.cluster.health())
health_dict["type"] = "elasticsearch"
return health_dict
def get_cluster_stats(self):
"""
curl -XGET "http://{es_host}/_cluster/stats" -H "kbn-xsrf: reporting" to view raw stats.
"""
raw_stats = self.es.cluster.stats()
self.logger.debug(f"ESConnection.get_cluster_stats: {raw_stats}")
try:
res = {
'cluster_name': raw_stats['cluster_name'],
'status': raw_stats['status']
}
indices_status = raw_stats['indices']
res.update({
'indices': indices_status['count'],
'indices_shards': indices_status['shards']['total']
})
doc_info = indices_status['docs']
res.update({
'docs': doc_info['count'],
'docs_deleted': doc_info['deleted']
})
store_info = indices_status['store']
res.update({
'store_size': convert_bytes(store_info['size_in_bytes']),
'total_dataset_size': convert_bytes(store_info['total_data_set_size_in_bytes'])
})
mappings_info = indices_status['mappings']
res.update({
'mappings_fields': mappings_info['total_field_count'],
'mappings_deduplicated_fields': mappings_info['total_deduplicated_field_count'],
'mappings_deduplicated_size': convert_bytes(mappings_info['total_deduplicated_mapping_size_in_bytes'])
})
node_info = raw_stats['nodes']
res.update({
'nodes': node_info['count']['total'],
'nodes_version': node_info['versions'],
'os_mem': convert_bytes(node_info['os']['mem']['total_in_bytes']),
'os_mem_used': convert_bytes(node_info['os']['mem']['used_in_bytes']),
'os_mem_used_percent': node_info['os']['mem']['used_percent'],
'jvm_versions': node_info['jvm']['versions'][0]['vm_version'],
'jvm_heap_used': convert_bytes(node_info['jvm']['mem']['heap_used_in_bytes']),
'jvm_heap_max': convert_bytes(node_info['jvm']['mem']['heap_max_in_bytes'])
})
return res
except Exception as e:
self.logger.exception(f"ESConnection.get_cluster_stats: {e}")
return None
"""
Table operations
"""
def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
if self.index_exist(index_name, dataset_id):
return True
try:
from elasticsearch.client import IndicesClient
return IndicesClient(self.es).create(index=index_name,
settings=self.mapping["settings"],
mappings=self.mapping["mappings"])
except Exception:
self.logger.exception("ESConnection.createIndex error %s" % index_name)
def delete_idx(self, index_name: str, dataset_id: str):
if len(dataset_id) > 0:
# The index needs to stay alive after any KB deletion, since all KBs under this tenant share one index.
return
try:
self.es.indices.delete(index=index_name, allow_no_indices=True)
except NotFoundError:
pass
except Exception:
self.logger.exception("ESConnection.deleteIdx error %s" % index_name)
def index_exist(self, index_name: str, dataset_id: str = None) -> bool:
s = Index(index_name, self.es)
for i in range(ATTEMPT_TIME):
try:
return s.exists()
except ConnectionTimeout:
self.logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
self.logger.exception(e)
break
return False
"""
CRUD operations
"""
def get(self, doc_id: str, index_name: str, dataset_ids: list[str]) -> dict | None:
for i in range(ATTEMPT_TIME):
try:
res = self.es.get(index=index_name,
id=doc_id, source=True, )
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
doc = res["_source"]
doc["id"] = doc_id
return doc
except NotFoundError:
return None
except Exception as e:
self.logger.exception(f"ESConnection.get({doc_id}) got exception")
raise e
self.logger.error(f"ESConnection.get timeout for {ATTEMPT_TIME} times!")
raise Exception("ESConnection.get timeout.")
@abstractmethod
def search(
self, select_fields: list[str],
highlight_fields: list[str],
condition: dict,
match_expressions: list[MatchExpr],
order_by: OrderByExpr,
offset: int,
limit: int,
index_names: str | list[str],
dataset_ids: list[str],
agg_fields: list[str] | None = None,
rank_feature: dict | None = None
):
raise NotImplementedError("Not implemented")
@abstractmethod
def insert(self, documents: list[dict], index_name: str, dataset_id: str = None) -> list[str]:
raise NotImplementedError("Not implemented")
@abstractmethod
def update(self, condition: dict, new_value: dict, index_name: str, dataset_id: str) -> bool:
raise NotImplementedError("Not implemented")
@abstractmethod
def delete(self, condition: dict, index_name: str, dataset_id: str) -> int:
raise NotImplementedError("Not implemented")
"""
Helper functions for search result
"""
def get_total(self, res):
if isinstance(res["hits"]["total"], type({})):
return res["hits"]["total"]["value"]
return res["hits"]["total"]
def get_doc_ids(self, res):
return [d["_id"] for d in res["hits"]["hits"]]
def _get_source(self, res):
rr = []
for d in res["hits"]["hits"]:
d["_source"]["id"] = d["_id"]
d["_source"]["_score"] = d["_score"]
rr.append(d["_source"])
return rr
@abstractmethod
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
raise NotImplementedError("Not implemented")
def get_highlight(self, res, keywords: list[str], field_name: str):
ans = {}
for d in res["hits"]["hits"]:
highlights = d.get("highlight")
if not highlights:
continue
txt = "...".join([a for a in list(highlights.items())[0][1]])
if not is_english(txt.split()):
ans[d["_id"]] = txt
continue
txt = d["_source"][field_name]
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
txt_list = []
for t in re.split(r"[.?!;\n]", txt):
for w in keywords:
t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w), r"\1<em>\2</em>\3", t,
flags=re.IGNORECASE | re.MULTILINE)
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
continue
txt_list.append(t)
ans[d["_id"]] = "...".join(txt_list) if txt_list else "...".join([a for a in list(highlights.items())[0][1]])
return ans
def get_aggregation(self, res, field_name: str):
agg_field = "aggs_" + field_name
if "aggregations" not in res or agg_field not in res["aggregations"]:
return list()
buckets = res["aggregations"][agg_field]["buckets"]
return [(b["key"], b["doc_count"]) for b in buckets]
"""
SQL
"""
def sql(self, sql: str, fetch_size: int, format: str):
self.logger.debug(f"ESConnection.sql get sql: {sql}")
sql = re.sub(r"[ `]+", " ", sql)
sql = sql.replace("%", "")
replaces = []
for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
fld, v = r.group(1), r.group(3)
match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
replaces.append(
("{}{}'{}'".format(
r.group(1),
r.group(2),
r.group(3)),
match))
for p, r in replaces:
sql = sql.replace(p, r, 1)
self.logger.debug(f"ESConnection.sql to es: {sql}")
for i in range(ATTEMPT_TIME):
try:
res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format,
request_timeout="2s")
return res
except ConnectionTimeout:
self.logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
self.logger.exception(f"ESConnection.sql got exception. SQL:\n{sql}")
raise Exception(f"SQL error: {e}\n\nSQL: {sql}")
self.logger.error(f"ESConnection.sql timeout for {ATTEMPT_TIME} times!")
return None
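sql() above rewrites equality/LIKE predicates on tokenized *_tks / *_ltks columns into Elasticsearch MATCH() calls before hitting the SQL endpoint; the literal is additionally run through rag_tokenizer, which is omitted in the sketch below (the token text is shown unchanged).

# Illustrative before/after of the predicate rewrite performed by sql() above.
original_sql = "SELECT doc_id FROM ragflow_idx WHERE content_ltks = 'machine learning'"
rewritten_sql = (
    "SELECT doc_id FROM ragflow_idx WHERE "
    "MATCH(content_ltks, 'machine learning', 'operator=OR;minimum_should_match=30%')"
)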

View File

@ -0,0 +1,451 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import re
import json
import time
from abc import abstractmethod
import infinity
from infinity.common import ConflictType
from infinity.index import IndexInfo, IndexType
from infinity.connection_pool import ConnectionPool
from infinity.errors import ErrorCode
import pandas as pd
from common.file_utils import get_project_base_directory
from rag.nlp import is_english
from common import settings
from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr
class InfinityConnectionBase(DocStoreConnection):
def __init__(self, mapping_file_name: str="infinity_mapping.json", logger_name: str="ragflow.infinity_conn"):
self.dbName = settings.INFINITY.get("db_name", "default_db")
self.mapping_file_name = mapping_file_name
self.logger = logging.getLogger(logger_name)
infinity_uri = settings.INFINITY["uri"]
if ":" in infinity_uri:
host, port = infinity_uri.split(":")
infinity_uri = infinity.common.NetworkAddress(host, int(port))
self.connPool = None
self.logger.info(f"Use Infinity {infinity_uri} as the doc engine.")
for _ in range(24):
try:
conn_pool = ConnectionPool(infinity_uri, max_size=4)
inf_conn = conn_pool.get_conn()
res = inf_conn.show_current_node()
if res.error_code == ErrorCode.OK and res.server_status in ["started", "alive"]:
self._migrate_db(inf_conn)
self.connPool = conn_pool
conn_pool.release_conn(inf_conn)
break
conn_pool.release_conn(inf_conn)
self.logger.warning(f"Infinity status: {res.server_status}. Waiting Infinity {infinity_uri} to be healthy.")
time.sleep(5)
except Exception as e:
self.logger.warning(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.")
time.sleep(5)
if self.connPool is None:
msg = f"Infinity {infinity_uri} is unhealthy in 120s."
self.logger.error(msg)
raise Exception(msg)
self.logger.info(f"Infinity {infinity_uri} is healthy.")
def _migrate_db(self, inf_conn):
inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
fp_mapping = os.path.join(get_project_base_directory(), "conf", self.mapping_file_name)
if not os.path.exists(fp_mapping):
raise Exception(f"Mapping file not found at {fp_mapping}")
schema = json.load(open(fp_mapping))
table_names = inf_db.list_tables().table_names
for table_name in table_names:
inf_table = inf_db.get_table(table_name)
index_names = inf_table.list_indexes().index_names
if "q_vec_idx" not in index_names:
# Skip tables not created by me
continue
column_names = inf_table.show_columns()["name"]
column_names = set(column_names)
for field_name, field_info in schema.items():
if field_name in column_names:
continue
res = inf_table.add_columns({field_name: field_info})
assert res.error_code == infinity.ErrorCode.OK
self.logger.info(f"INFINITY added following column to table {table_name}: {field_name} {field_info}")
if field_info["type"] != "varchar" or "analyzer" not in field_info:
continue
analyzers = field_info["analyzer"]
if isinstance(analyzers, str):
analyzers = [analyzers]
for analyzer in analyzers:
inf_table.create_index(
f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}",
IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}),
ConflictType.Ignore,
)
"""
Dataframe and fields convert
"""
@staticmethod
@abstractmethod
def field_keyword(field_name: str):
# Decide whether the field is a keyword-style field, e.g. "*_kwd" tag-like columns.
raise NotImplementedError("Not implemented")
@abstractmethod
def convert_select_fields(self, output_fields: list[str]) -> list[str]:
# Strip the _kwd, _tks, _sm_tks and _with_weight suffixes from field names.
raise NotImplementedError("Not implemented")
@staticmethod
@abstractmethod
def convert_matching_field(field_weight_str: str) -> str:
# Convert a matching field name to the column reference used for full-text filtering.
raise NotImplementedError("Not implemented")
@staticmethod
def list2str(lst: str | list, sep: str = " ") -> str:
if isinstance(lst, str):
return lst
return sep.join(lst)
def equivalent_condition_to_str(self, condition: dict, table_instance=None) -> str | None:
assert "_id" not in condition
columns = {}
if table_instance:
for n, ty, de, _ in table_instance.show_columns().rows():
columns[n] = (ty, de)
def exists(cln):
nonlocal columns
assert cln in columns, f"'{cln}' should be in '{columns}'."
ty, de = columns[cln]
if ty.lower().find("cha"):
if not de:
de = ""
return f" {cln}!='{de}' "
return f"{cln}!={de}"
cond = list()
for k, v in condition.items():
if not isinstance(k, str) or not v:
continue
if self.field_keyword(k):
if isinstance(v, list):
inCond = list()
for item in v:
if isinstance(item, str):
item = item.replace("'", "''")
inCond.append(f"filter_fulltext('{self.convert_matching_field(k)}', '{item}')")
if inCond:
strInCond = " or ".join(inCond)
strInCond = f"({strInCond})"
cond.append(strInCond)
else:
cond.append(f"filter_fulltext('{self.convert_matching_field(k)}', '{v}')")
elif isinstance(v, list):
inCond = list()
for item in v:
if isinstance(item, str):
item = item.replace("'", "''")
inCond.append(f"'{item}'")
else:
inCond.append(str(item))
if inCond:
strInCond = ", ".join(inCond)
strInCond = f"{k} IN ({strInCond})"
cond.append(strInCond)
elif k == "must_not":
if isinstance(v, dict):
for kk, vv in v.items():
if kk == "exists":
cond.append("NOT (%s)" % exists(vv))
elif isinstance(v, str):
cond.append(f"{k}='{v}'")
elif k == "exists":
cond.append(exists(v))
else:
cond.append(f"{k}={str(v)}")
return " AND ".join(cond) if cond else "1=1"
@staticmethod
def concat_dataframes(df_list: list[pd.DataFrame], select_fields: list[str]) -> pd.DataFrame:
df_list2 = [df for df in df_list if not df.empty]
if df_list2:
return pd.concat(df_list2, axis=0).reset_index(drop=True)
schema = []
for field_name in select_fields:
if field_name == "score()": # Workaround: fix schema is changed to score()
schema.append("SCORE")
elif field_name == "similarity()": # Workaround: fix schema is changed to similarity()
schema.append("SIMILARITY")
else:
schema.append(field_name)
return pd.DataFrame(columns=schema)
"""
Database operations
"""
def db_type(self) -> str:
return "infinity"
def health(self) -> dict:
"""
Return the health status of the database.
"""
inf_conn = self.connPool.get_conn()
res = inf_conn.show_current_node()
self.connPool.release_conn(inf_conn)
res2 = {
"type": "infinity",
"status": "green" if res.error_code == 0 and res.server_status in ["started", "alive"] else "red",
"error": res.error_msg,
}
return res2
"""
Table operations
"""
def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
table_name = f"{index_name}_{dataset_id}"
inf_conn = self.connPool.get_conn()
inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
fp_mapping = os.path.join(get_project_base_directory(), "conf", self.mapping_file_name)
if not os.path.exists(fp_mapping):
raise Exception(f"Mapping file not found at {fp_mapping}")
schema = json.load(open(fp_mapping))
vector_name = f"q_{vector_size}_vec"
schema[vector_name] = {"type": f"vector,{vector_size},float"}
inf_table = inf_db.create_table(
table_name,
schema,
ConflictType.Ignore,
)
inf_table.create_index(
"q_vec_idx",
IndexInfo(
vector_name,
IndexType.Hnsw,
{
"M": "16",
"ef_construction": "50",
"metric": "cosine",
"encode": "lvq",
},
),
ConflictType.Ignore,
)
for field_name, field_info in schema.items():
if field_info["type"] != "varchar" or "analyzer" not in field_info:
continue
analyzers = field_info["analyzer"]
if isinstance(analyzers, str):
analyzers = [analyzers]
for analyzer in analyzers:
inf_table.create_index(
f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}",
IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}),
ConflictType.Ignore,
)
self.connPool.release_conn(inf_conn)
self.logger.info(f"INFINITY created table {table_name}, vector size {vector_size}")
return True
def delete_idx(self, index_name: str, dataset_id: str):
table_name = f"{index_name}_{dataset_id}"
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
db_instance.drop_table(table_name, ConflictType.Ignore)
self.connPool.release_conn(inf_conn)
self.logger.info(f"INFINITY dropped table {table_name}")
def index_exist(self, index_name: str, dataset_id: str) -> bool:
table_name = f"{index_name}_{dataset_id}"
try:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
_ = db_instance.get_table(table_name)
self.connPool.release_conn(inf_conn)
return True
except Exception as e:
self.logger.warning(f"INFINITY indexExist {str(e)}")
return False
"""
CRUD operations
"""
@abstractmethod
def search(
self,
select_fields: list[str],
highlight_fields: list[str],
condition: dict,
match_expressions: list[MatchExpr],
order_by: OrderByExpr,
offset: int,
limit: int,
index_names: str | list[str],
dataset_ids: list[str],
agg_fields: list[str] | None = None,
rank_feature: dict | None = None,
) -> tuple[pd.DataFrame, int]:
raise NotImplementedError("Not implemented")
@abstractmethod
def get(self, doc_id: str, index_name: str, knowledgebase_ids: list[str]) -> dict | None:
raise NotImplementedError("Not implemented")
@abstractmethod
def insert(self, documents: list[dict], index_name: str, dataset_ids: str = None) -> list[str]:
raise NotImplementedError("Not implemented")
@abstractmethod
def update(self, condition: dict, new_value: dict, index_name: str, dataset_id: str) -> bool:
raise NotImplementedError("Not implemented")
def delete(self, condition: dict, index_name: str, dataset_id: str) -> int:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{index_name}_{dataset_id}"
try:
table_instance = db_instance.get_table(table_name)
except Exception:
self.logger.warning(f"Skipped deleting from table {table_name} since the table doesn't exist.")
return 0
filter = self.equivalent_condition_to_str(condition, table_instance)
self.logger.debug(f"INFINITY delete table {table_name}, filter {filter}.")
res = table_instance.delete(filter)
self.connPool.release_conn(inf_conn)
return res.deleted_rows
"""
Helper functions for search result
"""
def get_total(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
if isinstance(res, tuple):
return res[1]
return len(res)
def get_doc_ids(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
if isinstance(res, tuple):
res = res[0]
return list(res["id"])
@abstractmethod
def get_fields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
raise NotImplementedError("Not implemented")
def get_highlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], field_name: str):
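# Returns {chunk_id: highlighted_text}. Text that already carries <em> markup from the
# store is kept as-is; otherwise keyword matches are wrapped in <em> tags sentence by
# sentence, matching sentences are joined with "...", and the plain text is the fallback.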
if isinstance(res, tuple):
res = res[0]
ans = {}
num_rows = len(res)
column_id = res["id"]
if field_name not in res:
return {}
for i in range(num_rows):
id = column_id[i]
txt = res[field_name][i]
if re.search(r"<em>[^<>]+</em>", txt, flags=re.IGNORECASE | re.MULTILINE):
ans[id] = txt
continue
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
txt_list = []
for t in re.split(r"[.?!;\n]", txt):
if is_english([t]):
for w in keywords:
t = re.sub(
r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w),
r"\1<em>\2</em>\3",
t,
flags=re.IGNORECASE | re.MULTILINE,
)
else:
for w in sorted(keywords, key=len, reverse=True):
t = re.sub(
re.escape(w),
f"<em>{w}</em>",
t,
flags=re.IGNORECASE | re.MULTILINE,
)
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
continue
txt_list.append(t)
if txt_list:
ans[id] = "...".join(txt_list)
else:
ans[id] = txt
return ans
def get_aggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, field_name: str):
"""
Manual aggregation for tag fields since Infinity doesn't provide native aggregation
"""
from collections import Counter
# Extract DataFrame from result
if isinstance(res, tuple):
df, _ = res
else:
df = res
if df.empty or field_name not in df.columns:
return []
# Aggregate tag counts
tag_counter = Counter()
for value in df[field_name]:
if pd.isna(value) or not value:
continue
# Handle different tag formats
if isinstance(value, str):
# Split by ### for tag_kwd field or comma for other formats
if field_name == "tag_kwd" and "###" in value:
tags = [tag.strip() for tag in value.split("###") if tag.strip()]
else:
# Try comma separation as fallback
tags = [tag.strip() for tag in value.split(",") if tag.strip()]
for tag in tags:
if tag: # Only count non-empty tags
tag_counter[tag] += 1
elif isinstance(value, list):
# Handle list format
for tag in value:
if tag and isinstance(tag, str):
tag_counter[tag.strip()] += 1
# Return as list of [tag, count] pairs, sorted by count descending
return [[tag, count] for tag, count in tag_counter.most_common()]
"""
SQL
"""
def sql(self, sql: str, fetch_size: int, format: str):
raise NotImplementedError("Not implemented")

View File

@ -75,9 +75,12 @@ def init_root_logger(logfile_basename: str, log_format: str = "%(asctime)-15s %(
def log_exception(e, *args):
logging.exception(e)
for a in args:
if hasattr(a, "text"):
logging.error(a.text)
raise Exception(a.text)
else:
logging.error(str(a))
try:
text = getattr(a, "text")
except Exception:
text = None
if text is not None:
logging.error(text)
raise Exception(text)
logging.error(str(a))
raise e

View File

@ -44,21 +44,27 @@ def meta_filter(metas: dict, filters: list[dict], logic: str = "and"):
def filter_out(v2docs, operator, value):
ids = []
for input, docids in v2docs.items():
if operator in ["=", "≠", ">", "<", "≥", "≤"]:
try:
if isinstance(input, list):
input = input[0]
input = float(input)
value = float(value)
except Exception:
input = str(input)
value = str(value)
pass
if isinstance(input, str):
input = input.lower()
if isinstance(value, str):
value = value.lower()
for conds in [
(operator == "contains", str(value).lower() in str(input).lower()),
(operator == "not contains", str(value).lower() not in str(input).lower()),
(operator == "in", str(input).lower() in str(value).lower()),
(operator == "not in", str(input).lower() not in str(value).lower()),
(operator == "start with", str(input).lower().startswith(str(value).lower())),
(operator == "end with", str(input).lower().endswith(str(value).lower())),
(operator == "contains", input in value if not isinstance(input, list) else all([i in value for i in input])),
(operator == "not contains", input not in value if not isinstance(input, list) else all([i not in value for i in input])),
(operator == "in", input in value if not isinstance(input, list) else all([i in value for i in input])),
(operator == "not in", input not in value if not isinstance(input, list) else all([i not in value for i in input])),
(operator == "start with", str(input).lower().startswith(str(value).lower()) if not isinstance(input, list) else "".join([str(i).lower() for i in input]).startswith(str(value).lower())),
(operator == "end with", str(input).lower().endswith(str(value).lower()) if not isinstance(input, list) else "".join([str(i).lower() for i in input]).endswith(str(value).lower())),
(operator == "empty", not input),
(operator == "not empty", input),
(operator == "=", input == value),
@ -145,6 +151,18 @@ async def apply_meta_data_filter(
return doc_ids
def dedupe_list(values: list) -> list:
seen = set()
deduped = []
for item in values:
key = str(item)
if key in seen:
continue
seen.add(key)
deduped.append(item)
return deduped
def update_metadata_to(metadata, meta):
if not meta:
return metadata
@ -156,11 +174,13 @@ def update_metadata_to(metadata, meta):
return metadata
if not isinstance(meta, dict):
return metadata
for k, v in meta.items():
if isinstance(v, list):
v = [vv for vv in v if isinstance(vv, str)]
if not v:
continue
v = dedupe_list(v)
if not isinstance(v, list) and not isinstance(v, str):
continue
if k not in metadata:
@ -171,6 +191,7 @@ def update_metadata_to(metadata, meta):
metadata[k].extend(v)
else:
metadata[k].append(v)
metadata[k] = dedupe_list(metadata[k])
else:
metadata[k] = v
@ -202,4 +223,4 @@ def metadata_schema(metadata: list|None) -> Dict[str, Any]:
}
json_schema["additionalProperties"] = False
return json_schema
return json_schema

72
common/query_base.py Normal file
View File

@ -0,0 +1,72 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
from abc import ABC, abstractmethod
class QueryBase(ABC):
@staticmethod
def is_chinese(line):
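# Heuristic: a query of three or fewer whitespace-separated tokens, or one where at
# least 70% of tokens are not purely alphabetic, is treated as Chinese.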
arr = re.split(r"[ \t]+", line)
if len(arr) <= 3:
return True
e = 0
for t in arr:
if not re.match(r"[a-zA-Z]+$", t):
e += 1
return e * 1.0 / len(arr) >= 0.7
@staticmethod
def sub_special_char(line):
return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|\+~\^])", r"\\\1", line).strip()
@staticmethod
def rmWWW(txt):
patts = [
(
r"是*(怎么办|什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀|谁|哪位|哪个)是*",
"",
),
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
(
r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ",
" ")
]
otxt = txt
for r, p in patts:
txt = re.sub(r, p, txt, flags=re.IGNORECASE)
if not txt:
txt = otxt
return txt
@staticmethod
def add_space_between_eng_zh(txt):
# (ENG/ENG+NUM) + ZH
txt = re.sub(r'([A-Za-z]+[0-9]+)([\u4e00-\u9fa5]+)', r'\1 \2', txt)
# ENG + ZH
txt = re.sub(r'([A-Za-z])([\u4e00-\u9fa5]+)', r'\1 \2', txt)
# ZH + (ENG/ENG+NUM)
txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z]+[0-9]+)', r'\1 \2', txt)
txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z])', r'\1 \2', txt)
return txt
@abstractmethod
def question(self, text, tbl, min_match):
"""
Returns a query object based on the input text, table, and minimum match criteria.
"""
raise NotImplementedError("Not implemented")

View File

@ -39,6 +39,9 @@ from rag.utils.oss_conn import RAGFlowOSS
from rag.nlp import search
import memory.utils.es_conn as memory_es_conn
import memory.utils.infinity_conn as memory_infinity_conn
LLM = None
LLM_FACTORY = None
LLM_BASE_URL = None
@ -76,9 +79,11 @@ FEISHU_OAUTH = None
OAUTH_CONFIG = None
DOC_ENGINE = os.getenv('DOC_ENGINE', 'elasticsearch')
DOC_ENGINE_INFINITY = (DOC_ENGINE.lower() == "infinity")
MSG_ENGINE = DOC_ENGINE
docStoreConn = None
msgStoreConn = None
retriever = None
kg_retriever = None
@ -256,6 +261,15 @@ def init_settings():
else:
raise Exception(f"Not supported doc engine: {DOC_ENGINE}")
global MSG_ENGINE, msgStoreConn
MSG_ENGINE = DOC_ENGINE # use the same engine for message store
if MSG_ENGINE == "elasticsearch":
ES = get_base_config("es", {})
msgStoreConn = memory_es_conn.ESConnection()
elif MSG_ENGINE == "infinity":
INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
msgStoreConn = memory_infinity_conn.InfinityConnection()
global AZURE, S3, MINIO, OSS, GCS
if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']:
AZURE = get_base_config("azure", {})

View File

@ -14,6 +14,7 @@
# limitations under the License.
import datetime
import logging
import time
def current_timestamp():
@ -123,4 +124,31 @@ def delta_seconds(date_string: str):
3600.0 # If current time is 2024-01-01 13:00:00
"""
dt = datetime.datetime.strptime(date_string, "%Y-%m-%d %H:%M:%S")
return (datetime.datetime.now() - dt).total_seconds()
return (datetime.datetime.now() - dt).total_seconds()
def format_iso_8601_to_ymd_hms(time_str: str) -> str:
"""
Convert ISO 8601 formatted string to "YYYY-MM-DD HH:MM:SS" format.
Args:
time_str: ISO 8601 date string (e.g. "2024-01-01T12:00:00Z")
Returns:
str: Date string in "YYYY-MM-DD HH:MM:SS" format
Example:
>>> format_iso_8601_to_ymd_hms("2024-01-01T12:00:00Z")
'2024-01-01 12:00:00'
"""
from dateutil import parser
try:
if parser.isoparse(time_str):
dt = datetime.datetime.fromisoformat(time_str.replace("Z", "+00:00"))
return dt.strftime("%Y-%m-%d %H:%M:%S")
else:
return time_str
except Exception as e:
logging.error(str(e))
return time_str

View File

@ -44,23 +44,23 @@ def total_token_count_from_response(resp):
if resp is None:
return 0
if hasattr(resp, "usage") and hasattr(resp.usage, "total_tokens"):
try:
try:
if hasattr(resp, "usage") and hasattr(resp.usage, "total_tokens"):
return resp.usage.total_tokens
except Exception:
pass
except Exception:
pass
if hasattr(resp, "usage_metadata") and hasattr(resp.usage_metadata, "total_tokens"):
try:
try:
if hasattr(resp, "usage_metadata") and hasattr(resp.usage_metadata, "total_tokens"):
return resp.usage_metadata.total_tokens
except Exception:
pass
except Exception:
pass
if hasattr(resp, "meta") and hasattr(resp.meta, "billed_units") and hasattr(resp.meta.billed_units, "input_tokens"):
try:
return resp.meta.billed_units.input_tokens
except Exception:
pass
try:
if hasattr(resp, "meta") and hasattr(resp.meta, "billed_units") and hasattr(resp.meta.billed_units, "input_tokens"):
return resp.meta.billed_units.input_tokens
except Exception:
pass
if isinstance(resp, dict) and 'usage' in resp and 'total_tokens' in resp['usage']:
try:
@ -85,4 +85,3 @@ def total_token_count_from_response(resp):
def truncate(string: str, max_len: int) -> str:
"""Returns truncated text if the length of text exceed max_len."""
return encoder.decode(encoder.encode(string)[:max_len])

View File

@ -31,6 +31,7 @@
"entity_type_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"source_id": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"n_hop_with_weight": {"type": "varchar", "default": ""},
"mom_with_weight": {"type": "varchar", "default": ""},
"removed_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"doc_type_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"toc_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},

View File

@ -762,6 +762,13 @@
"status": "1",
"rank": "940",
"llm": [
{
"llm_name": "glm-4.7",
"tags": "LLM,CHAT,128K",
"max_tokens": 128000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "glm-4.5",
"tags": "LLM,CHAT,128K",
@ -1251,6 +1258,12 @@
"status": "1",
"rank": "810",
"llm": [
{
"llm_name": "MiniMax-M2.1",
"tags": "LLM,CHAT,200k",
"max_tokens": 200000,
"model_type": "chat"
},
{
"llm_name": "MiniMax-M2",
"tags": "LLM,CHAT,200k",

View File

@ -0,0 +1,19 @@
{
"id": {"type": "varchar", "default": ""},
"message_id": {"type": "integer", "default": 0},
"message_type_kwd": {"type": "varchar", "default": ""},
"source_id": {"type": "integer", "default": 0},
"memory_id": {"type": "varchar", "default": ""},
"user_id": {"type": "varchar", "default": ""},
"agent_id": {"type": "varchar", "default": ""},
"session_id": {"type": "varchar", "default": ""},
"valid_at": {"type": "varchar", "default": ""},
"valid_at_flt": {"type": "float", "default": 0.0},
"invalid_at": {"type": "varchar", "default": ""},
"invalid_at_flt": {"type": "float", "default": 0.0},
"forget_at": {"type": "varchar", "default": ""},
"forget_at_flt": {"type": "float", "default": 0.0},
"status_int": {"type": "integer", "default": 1},
"zone_id": {"type": "integer", "default": 0},
"content": {"type": "varchar", "default": "", "analyzer": ["rag-coarse", "rag-fine"], "comment": "content_ltks"}
}

View File

@ -38,8 +38,8 @@ def vision_figure_parser_figure_data_wrapper(figures_data_without_positions):
def vision_figure_parser_docx_wrapper(sections, tbls, callback=None,**kwargs):
if not tbls:
return []
if not sections:
return tbls
try:
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
callback(0.7, "Visual model detected. Attempting to enhance figure extraction...")

View File

@ -1447,6 +1447,7 @@ class VisionParser(RAGFlowPdfParser):
def __init__(self, vision_model, *args, **kwargs):
super().__init__(*args, **kwargs)
self.vision_model = vision_model
self.outlines = []
def __images__(self, fnm, zoomin=3, page_from=0, page_to=299, callback=None):
try:

View File

@ -72,7 +72,7 @@ services:
infinity:
profiles:
- infinity
image: infiniflow/infinity:v0.6.11
image: infiniflow/infinity:v0.6.13
volumes:
- infinity_data:/var/infinity
- ./infinity_conf.toml:/infinity_conf.toml

View File

@ -1,5 +1,5 @@
[general]
version = "0.6.11"
version = "0.6.13"
time_zone = "utc-8"
[network]

0
docker/launch_backend_service.sh Normal file → Executable file
View File

View File

@ -0,0 +1,90 @@
---
sidebar_position: 30
slug: /http_request_component
---
# HTTP request component
A component that calls remote services.
---
An **HTTP request** component lets you access remote APIs or services by providing a URL and an HTTP method, and then receive the response. You can customize headers, parameters, proxies, and timeout settings, and use common methods such as GET and POST. It's useful for exchanging data with external systems in a workflow.
## Prerequisites
- An accessible remote API or service.
- Add a Token or credentials to the request header, if the target service requires authentication.
## Configurations
### Url
*Required*. The complete request address, for example: `http://api.example.com/data`.
### Method
The HTTP request method to select. Available options:
- GET
- POST
- PUT
### Timeout
The maximum waiting time for the request, in seconds. Defaults to `60`.
### Headers
Custom HTTP headers can be set here, for example:
```json
{
"Accept": "application/json",
"Cache-Control": "no-cache",
"Connection": "keep-alive"
}
```
### Proxy
*Optional*. The proxy server address to use for this request.
### Clean HTML
`Boolean`: Whether to remove HTML tags from the returned results and keep plain text only.
### Parameter
*Optional*. Parameters to send with the HTTP request, specified as key-value pairs (see the request sketch after the note below):
- To assign a value from a dynamic system variable, set the pair to Variable.
- To use a fixed static value that overrides any dynamic value, set the pair to Value.
:::tip NOTE
- For GET requests, these parameters are appended to the end of the URL.
- For POST/PUT requests, they are sent as the request body.
:::
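The note above maps directly onto standard HTTP semantics. As a minimal, illustrative sketch (not the component's internal implementation), the following Python snippet uses the `requests` library against the public `httpbin.org` echo service to show how the same key-value parameters travel in the URL for a GET request and in the request body for a POST request:

```python
import requests

params = {"App": "RAGFlow", "Query": "How to do?"}

# GET: the parameters are appended to the URL as a query string.
get_resp = requests.get("https://httpbin.org/get", params=params, timeout=60)
print(get_resp.url)               # https://httpbin.org/get?App=RAGFlow&Query=How+to+do%3F
print(get_resp.json()["args"])    # {'App': 'RAGFlow', 'Query': 'How to do?'}

# POST: the same pairs are sent as the request body instead.
post_resp = requests.post("https://httpbin.org/post", json=params, timeout=60)
print(post_resp.json()["json"])   # {'App': 'RAGFlow', 'Query': 'How to do?'}
```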
#### Example setting
![](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/http_settings.png)
#### Example response
```json
{
  "args": { "App": "RAGFlow", "Query": "How to do?", "Userid": "241ed25a8e1011f0b979424ebc5b108b" },
  "headers": {
    "Accept": "*/*",
    "Accept-Encoding": "gzip, deflate, br, zstd",
    "Cache-Control": "no-cache",
    "Host": "httpbin.org",
    "User-Agent": "python-requests/2.32.2",
    "X-Amzn-Trace-Id": "Root=1-68c9210c-5aab9088580c130a2f065523"
  },
  "origin": "185.36.193.38",
  "url": "https://httpbin.org/get?Userid=241ed25a8e1011f0b979424ebc5b108b&App=RAGFlow&Query=How+to+do%3F"
}
```
### Output
The global variable name for the output of the HTTP request component, which can be referenced by other components in the workflow.
- `Result`: `string` The response returned by the remote service.
## Example
This is a usage example: a workflow sends a GET request from the **Begin** component to `https://httpbin.org/get` via the **HTTP Request_0** component, passes parameters to the server, and finally outputs the result through the **Message_0** component.
![](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/http_usage.PNG)

View File

@ -23,8 +23,8 @@ def get_urls(use_china_mirrors=False) -> list[Union[str, list[str]]]:
return [
"http://mirrors.tuna.tsinghua.edu.cn/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb",
"http://mirrors.tuna.tsinghua.edu.cn/ubuntu-ports/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_arm64.deb",
"https://repo.huaweicloud.com/repository/maven/org/apache/tika/tika-server-standard/3.0.0/tika-server-standard-3.0.0.jar",
"https://repo.huaweicloud.com/repository/maven/org/apache/tika/tika-server-standard/3.0.0/tika-server-standard-3.0.0.jar.md5",
"https://repo.huaweicloud.com/repository/maven/org/apache/tika/tika-server-standard/3.2.3/tika-server-standard-3.2.3.jar",
"https://repo.huaweicloud.com/repository/maven/org/apache/tika/tika-server-standard/3.2.3/tika-server-standard-3.2.3.jar.md5",
"https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken",
["https://registry.npmmirror.com/-/binary/chrome-for-testing/121.0.6167.85/linux64/chrome-linux64.zip", "chrome-linux64-121-0-6167-85"],
["https://registry.npmmirror.com/-/binary/chrome-for-testing/121.0.6167.85/linux64/chromedriver-linux64.zip", "chromedriver-linux64-121-0-6167-85"],
@ -34,8 +34,8 @@ def get_urls(use_china_mirrors=False) -> list[Union[str, list[str]]]:
return [
"http://archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb",
"http://ports.ubuntu.com/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_arm64.deb",
"https://repo1.maven.org/maven2/org/apache/tika/tika-server-standard/3.0.0/tika-server-standard-3.0.0.jar",
"https://repo1.maven.org/maven2/org/apache/tika/tika-server-standard/3.0.0/tika-server-standard-3.0.0.jar.md5",
"https://repo1.maven.org/maven2/org/apache/tika/tika-server-standard/3.2.3/tika-server-standard-3.2.3.jar",
"https://repo1.maven.org/maven2/org/apache/tika/tika-server-standard/3.2.3/tika-server-standard-3.2.3.jar.md5",
"https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken",
["https://storage.googleapis.com/chrome-for-testing-public/121.0.6167.85/linux64/chrome-linux64.zip", "chrome-linux64-121-0-6167-85"],
["https://storage.googleapis.com/chrome-for-testing-public/121.0.6167.85/linux64/chromedriver-linux64.zip", "chromedriver-linux64-121-0-6167-85"],

View File

@ -71,18 +71,17 @@ class Extractor:
_, system_msg = message_fit_in([{"role": "system", "content": system}], int(self._llm.max_length * 0.92))
response = ""
for attempt in range(3):
if task_id:
if has_canceled(task_id):
logging.info(f"Task {task_id} cancelled during entity resolution candidate processing.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
try:
response = asyncio.run(self._llm.async_chat(system_msg[0]["content"], hist, conf))
response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
if response.find("**ERROR**") >= 0:
raise Exception(response)
set_llm_cache(self._llm.llm_name, system, response, history, gen_conf)
break
except Exception as e:
logging.exception(e)
if attempt == 2:

View File

@ -198,7 +198,7 @@ async def run_graphrag_for_kb(
for d in raw_chunks:
content = d["content_with_weight"]
if num_tokens_from_string(current_chunk + content) < 1024:
if num_tokens_from_string(current_chunk + content) < 4096:
current_chunk += content
else:
if current_chunk:

View File

@ -78,10 +78,6 @@ class GraphExtractor(Extractor):
hint_prompt = self._entity_extract_prompt.format(**self._context_base, input_text=content)
gen_conf = {}
final_result = ""
glean_result = ""
if_loop_result = ""
history = []
logging.info(f"Start processing for {chunk_key}: {content[:25]}...")
if self.callback:
self.callback(msg=f"Start processing for {chunk_key}: {content[:25]}...")

View File

@ -24,11 +24,11 @@ from common.misc_utils import get_uuid
from graphrag.query_analyze_prompt import PROMPTS
from graphrag.utils import get_entity_type2samples, get_llm_cache, set_llm_cache, get_relation
from common.token_utils import num_tokens_from_string
from rag.utils.doc_store_conn import OrderByExpr
from rag.nlp.search import Dealer, index_name
from common.float_utils import get_float
from common import settings
from common.doc_store.doc_store_base import OrderByExpr
class KGSearch(Dealer):

View File

@ -26,9 +26,9 @@ from networkx.readwrite import json_graph
from common.misc_utils import get_uuid
from common.connection_utils import timeout
from rag.nlp import rag_tokenizer, search
from rag.utils.doc_store_conn import OrderByExpr
from rag.utils.redis_conn import REDIS_CONN
from common import settings
from common.doc_store.doc_store_base import OrderByExpr
GRAPH_FIELD_SEP = "<SEP>"

View File

@ -96,7 +96,7 @@ ragflow:
infinity:
image:
repository: infiniflow/infinity
tag: v0.6.11
tag: v0.6.13
pullPolicy: IfNotPresent
pullSecrets: []
storage:

0
memory/__init__.py Normal file
View File

View File

240
memory/services/messages.py Normal file
View File

@ -0,0 +1,240 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
from typing import List
from common import settings
from common.doc_store.doc_store_base import OrderByExpr, MatchExpr
def index_name(uid: str): return f"memory_{uid}"
class MessageService:
@classmethod
def has_index(cls, uid: str, memory_id: str):
index = index_name(uid)
return settings.msgStoreConn.index_exist(index, memory_id)
@classmethod
def create_index(cls, uid: str, memory_id: str, vector_size: int):
index = index_name(uid)
return settings.msgStoreConn.create_idx(index, memory_id, vector_size)
@classmethod
def delete_index(cls, uid: str, memory_id: str):
index = index_name(uid)
return settings.msgStoreConn.delete_idx(index, memory_id)
@classmethod
def insert_message(cls, messages: List[dict], uid: str, memory_id: str):
index = index_name(uid)
[m.update({
"id": f'{memory_id}_{m["message_id"]}',
"status": 1 if m["status"] else 0
}) for m in messages]
return settings.msgStoreConn.insert(messages, index, memory_id)
@classmethod
def update_message(cls, condition: dict, update_dict: dict, uid: str, memory_id: str):
index = index_name(uid)
if "status" in update_dict:
update_dict["status"] = 1 if update_dict["status"] else 0
return settings.msgStoreConn.update(condition, update_dict, index, memory_id)
@classmethod
def delete_message(cls, condition: dict, uid: str, memory_id: str):
index = index_name(uid)
return settings.msgStoreConn.delete(condition, index, memory_id)
@classmethod
def list_message(cls, uid: str, memory_id: str, agent_ids: List[str]=None, keywords: str=None, page: int=1, page_size: int=50):
index = index_name(uid)
filter_dict = {}
if agent_ids:
filter_dict["agent_id"] = agent_ids
if keywords:
filter_dict["session_id"] = keywords
order_by = OrderByExpr()
order_by.desc("valid_at")
res = settings.msgStoreConn.search(
select_fields=[
"message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at",
"invalid_at", "forget_at", "status"
],
highlight_fields=[],
condition=filter_dict,
match_expressions=[], order_by=order_by,
offset=(page-1)*page_size, limit=page_size,
index_names=index, memory_ids=[memory_id], agg_fields=[], hide_forgotten=False
)
total_count = settings.msgStoreConn.get_total(res)
doc_mapping = settings.msgStoreConn.get_fields(res, [
"message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id",
"valid_at", "invalid_at", "forget_at", "status"
])
return {
"message_list": list(doc_mapping.values()),
"total_count": total_count
}
@classmethod
def get_recent_messages(cls, uid_list: List[str], memory_ids: List[str], agent_id: str, session_id: str, limit: int):
index_names = [index_name(uid) for uid in uid_list]
condition_dict = {
"agent_id": agent_id,
"session_id": session_id
}
order_by = OrderByExpr()
order_by.desc("valid_at")
res = settings.msgStoreConn.search(
select_fields=[
"message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at",
"invalid_at", "forget_at", "status", "content"
],
highlight_fields=[],
condition=condition_dict,
match_expressions=[], order_by=order_by,
offset=0, limit=limit,
index_names=index_names, memory_ids=memory_ids, agg_fields=[]
)
doc_mapping = settings.msgStoreConn.get_fields(res, [
"message_id", "message_type", "source_id", "memory_id","user_id", "agent_id", "session_id",
"valid_at", "invalid_at", "forget_at", "status", "content"
])
return list(doc_mapping.values())
@classmethod
def search_message(cls, memory_ids: List[str], condition_dict: dict, uid_list: List[str], match_expressions:list[MatchExpr], top_n: int):
index_names = [index_name(uid) for uid in uid_list]
# filter only valid messages by default
if "status" not in condition_dict:
condition_dict["status"] = 1
order_by = OrderByExpr()
order_by.desc("valid_at")
res = settings.msgStoreConn.search(
select_fields=[
"message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id",
"valid_at",
"invalid_at", "forget_at", "status", "content"
],
highlight_fields=[],
condition=condition_dict,
match_expressions=match_expressions,
order_by=order_by,
offset=0, limit=top_n,
index_names=index_names, memory_ids=memory_ids, agg_fields=[]
)
docs = settings.msgStoreConn.get_fields(res, [
"message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at",
"invalid_at", "forget_at", "status", "content"
])
return list(docs.values())
@staticmethod
def calculate_message_size(message: dict):
return sys.getsizeof(message["content"]) + sys.getsizeof(message["content_embed"][0]) * len(message["content_embed"])
@classmethod
def calculate_memory_size(cls, memory_ids: List[str], uid_list: List[str]):
index_names = [index_name(uid) for uid in uid_list]
order_by = OrderByExpr()
order_by.desc("valid_at")
res = settings.msgStoreConn.search(
select_fields=["memory_id", "content", "content_embed"],
highlight_fields=[],
condition={},
match_expressions=[],
order_by=order_by,
offset=0, limit=2000*len(memory_ids),
index_names=index_names, memory_ids=memory_ids, agg_fields=[], hide_forgotten=False
)
docs = settings.msgStoreConn.get_fields(res, ["memory_id", "content", "content_embed"])
size_dict = {}
for doc in docs.values():
if size_dict.get(doc["memory_id"]):
size_dict[doc["memory_id"]] += cls.calculate_message_size(doc)
else:
size_dict[doc["memory_id"]] = cls.calculate_message_size(doc)
return size_dict
@classmethod
def pick_messages_to_delete_by_fifo(cls, memory_id: str, uid: str, size_to_delete: int):
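# Reclaim space FIFO-style: first count already-forgotten messages (oldest forget_at
# first); if that is still short of size_to_delete, continue with the oldest remaining
# messages ordered by valid_at ascending.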
select_fields = ["message_id", "content", "content_embed"]
_index_name = index_name(uid)
res = settings.msgStoreConn.get_forgotten_messages(select_fields, _index_name, memory_id)
message_list = settings.msgStoreConn.get_fields(res, select_fields)
current_size = 0
ids_to_remove = []
for message in message_list:
if current_size < size_to_delete:
current_size += cls.calculate_message_size(message)
ids_to_remove.append(message["message_id"])
else:
return ids_to_remove, current_size
if current_size >= size_to_delete:
return ids_to_remove, current_size
order_by = OrderByExpr()
order_by.asc("valid_at")
res = settings.msgStoreConn.search(
select_fields=["memory_id", "content", "content_embed"],
highlight_fields=[],
condition={},
match_expressions=[],
order_by=order_by,
offset=0, limit=2000,
index_names=[_index_name], memory_ids=[memory_id], agg_fields=[]
)
docs = settings.msgStoreConn.get_fields(res, select_fields)
for doc in docs.values():
if current_size < size_to_delete:
current_size += cls.calculate_message_size(doc)
ids_to_remove.append(doc["memory_id"])
else:
return ids_to_remove, current_size
return ids_to_remove, current_size
@classmethod
def get_by_message_id(cls, memory_id: str, message_id: int, uid: str):
index = index_name(uid)
doc_id = f'{memory_id}_{message_id}'
return settings.msgStoreConn.get(doc_id, index, [memory_id])
@classmethod
def get_max_message_id(cls, uid_list: List[str], memory_ids: List[str]):
order_by = OrderByExpr()
order_by.desc("message_id")
index_names = [index_name(uid) for uid in uid_list]
res = settings.msgStoreConn.search(
select_fields=["message_id"],
highlight_fields=[],
condition={},
match_expressions=[],
order_by=order_by,
offset=0, limit=1,
index_names=index_names, memory_ids=memory_ids,
agg_fields=[], hide_forgotten=False
)
docs = settings.msgStoreConn.get_fields(res, ["message_id"])
if not docs:
return 1
else:
latest_msg = list(docs.values())[0]
return int(latest_msg["message_id"])

185
memory/services/query.py Normal file
View File

@ -0,0 +1,185 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import logging
import json
import numpy as np
from common.query_base import QueryBase
from common.doc_store.doc_store_base import MatchDenseExpr, MatchTextExpr
from common.float_utils import get_float
from rag.nlp import rag_tokenizer, term_weight, synonym
def get_vector(txt, emb_mdl, topk=10, similarity=0.1):
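# Encode the query text with the embedding model and wrap it as a dense-vector match:
# e.g. a 1024-dim embedding targets the "q_1024_vec" column using cosine similarity,
# the given top-k, and the (possibly string-valued) similarity threshold.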
if isinstance(similarity, str) and len(similarity) > 0:
try:
similarity = float(similarity)
except Exception as e:
logging.warning(f"Convert similarity '{similarity}' to float failed: {e}. Using default 0.1")
similarity = 0.1
qv, _ = emb_mdl.encode_queries(txt)
shape = np.array(qv).shape
if len(shape) > 1:
raise Exception(
f"Dealer.get_vector returned array's shape {shape} doesn't match expectation(exact one dimension).")
embedding_data = [get_float(v) for v in qv]
vector_column_name = f"q_{len(embedding_data)}_vec"
return MatchDenseExpr(vector_column_name, embedding_data, 'float', 'cosine', topk, {"similarity": similarity})
class MsgTextQuery(QueryBase):
def __init__(self):
self.tw = term_weight.Dealer()
self.syn = synonym.Dealer()
self.query_fields = [
"content"
]
def question(self, txt, tbl="messages", min_match: float=0.6):
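# Normalize the question (full-width to half-width, traditional to simplified Chinese,
# strip punctuation and question/stop words), then build a weighted full-text query
# over the content field and return it together with the extracted keywords.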
original_query = txt
txt = MsgTextQuery.add_space_between_eng_zh(txt)
txt = re.sub(
r"[ :|\r\n\t,,。??/`!&^%%()\[\]{}<>]+",
" ",
rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(txt.lower())),
).strip()
otxt = txt
txt = MsgTextQuery.rmWWW(txt)
if not self.is_chinese(txt):
txt = self.rmWWW(txt)
tks = rag_tokenizer.tokenize(txt).split()
keywords = [t for t in tks if t]
tks_w = self.tw.weights(tks, preprocess=False)
tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w]
tks_w = [(re.sub(r"^[a-z0-9]$", "", tk), w) for tk, w in tks_w if tk]
tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk]
tks_w = [(tk.strip(), w) for tk, w in tks_w if tk.strip()]
syns = []
for tk, w in tks_w[:256]:
syn = self.syn.lookup(tk)
syn = rag_tokenizer.tokenize(" ".join(syn)).split()
keywords.extend(syn)
syn = ["\"{}\"^{:.4f}".format(s, w / 4.) for s in syn if s.strip()]
syns.append(" ".join(syn))
q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if
tk and not re.match(r"[.^+\(\)-]", tk)]
for i in range(1, len(tks_w)):
left, right = tks_w[i - 1][0].strip(), tks_w[i][0].strip()
if not left or not right:
continue
q.append(
'"%s %s"^%.4f'
% (
tks_w[i - 1][0],
tks_w[i][0],
max(tks_w[i - 1][1], tks_w[i][1]) * 2,
)
)
if not q:
q.append(txt)
query = " ".join(q)
return MatchTextExpr(
self.query_fields, query, 100, {"original_query": original_query}
), keywords
def need_fine_grained_tokenize(tk):
if len(tk) < 3:
return False
if re.match(r"[0-9a-z\.\+#_\*-]+$", tk):
return False
return True
txt = self.rmWWW(txt)
qs, keywords = [], []
for tt in self.tw.split(txt)[:256]: # .split():
if not tt:
continue
keywords.append(tt)
twts = self.tw.weights([tt])
syns = self.syn.lookup(tt)
if syns and len(keywords) < 32:
keywords.extend(syns)
logging.debug(json.dumps(twts, ensure_ascii=False))
tms = []
for tk, w in sorted(twts, key=lambda x: x[1] * -1):
sm = (
rag_tokenizer.fine_grained_tokenize(tk).split()
if need_fine_grained_tokenize(tk)
else []
)
sm = [
re.sub(
r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+",
"",
m,
)
for m in sm
]
sm = [self.sub_special_char(m) for m in sm if len(m) > 1]
sm = [m for m in sm if len(m) > 1]
if len(keywords) < 32:
keywords.append(re.sub(r"[ \\\"']+", "", tk))
keywords.extend(sm)
tk_syns = self.syn.lookup(tk)
tk_syns = [self.sub_special_char(s) for s in tk_syns]
if len(keywords) < 32:
keywords.extend([s for s in tk_syns if s])
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
if len(keywords) >= 32:
break
tk = self.sub_special_char(tk)
if tk.find(" ") > 0:
tk = '"%s"' % tk
if tk_syns:
tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns)
if sm:
tk = f'{tk} OR "%s" OR ("%s"~2)^0.5' % (" ".join(sm), " ".join(sm))
if tk.strip():
tms.append((tk, w))
tms = " ".join([f"({t})^{w}" for t, w in tms])
if len(twts) > 1:
tms += ' ("%s"~2)^1.5' % rag_tokenizer.tokenize(tt)
syns = " OR ".join(
[
'"%s"'
% rag_tokenizer.tokenize(self.sub_special_char(s))
for s in syns
]
)
if syns and tms:
tms = f"({tms})^5 OR ({syns})^0.7"
qs.append(tms)
if qs:
query = " OR ".join([f"({t})" for t in qs if t])
if not query:
query = otxt
return MatchTextExpr(
self.query_fields, query, 100, {"minimum_should_match": min_match, "original_query": original_query}
), keywords
return None, keywords

0
memory/utils/__init__.py Normal file
View File

494
memory/utils/es_conn.py Normal file
View File

@ -0,0 +1,494 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import json
import time
import copy
from elasticsearch import NotFoundError
from elasticsearch_dsl import UpdateByQuery, Q, Search
from elastic_transport import ConnectionTimeout
from common.decorator import singleton
from common.doc_store.doc_store_base import MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, FusionExpr
from common.doc_store.es_conn_base import ESConnectionBase
from common.float_utils import get_float
from common.constants import PAGERANK_FLD, TAG_FLD
ATTEMPT_TIME = 2
@singleton
class ESConnection(ESConnectionBase):
@staticmethod
def convert_field_name(field_name: str) -> str:
match field_name:
case "message_type":
return "message_type_kwd"
case "status":
return "status_int"
case "content":
return "content_ltks"
case _:
return field_name
@staticmethod
def map_message_to_es_fields(message: dict) -> dict:
"""
Map message dictionary fields to Elasticsearch document/Infinity fields.
:param message: A dictionary containing message details.
:return: A dictionary formatted for Elasticsearch/Infinity indexing.
"""
storage_doc = {
"id": message.get("id"),
"message_id": message["message_id"],
"message_type_kwd": message["message_type"],
"source_id": message["source_id"],
"memory_id": message["memory_id"],
"user_id": message["user_id"],
"agent_id": message["agent_id"],
"session_id": message["session_id"],
"valid_at": message["valid_at"],
"invalid_at": message["invalid_at"],
"forget_at": message["forget_at"],
"status_int": 1 if message["status"] else 0,
"zone_id": message.get("zone_id", 0),
"content_ltks": message["content"],
f"q_{len(message['content_embed'])}_vec": message["content_embed"],
}
return storage_doc
@staticmethod
def get_message_from_es_doc(doc: dict) -> dict:
"""
Convert an Elasticsearch/Infinity document back to a message dictionary.
:param doc: A dictionary representing the Elasticsearch/Infinity document.
:return: A dictionary formatted as a message.
"""
embd_field_name = next((key for key in doc.keys() if re.match(r"q_\d+_vec", key)), None)
message = {
"message_id": doc["message_id"],
"message_type": doc["message_type_kwd"],
"source_id": doc["source_id"] if doc["source_id"] else None,
"memory_id": doc["memory_id"],
"user_id": doc.get("user_id", ""),
"agent_id": doc["agent_id"],
"session_id": doc["session_id"],
"zone_id": doc.get("zone_id", 0),
"valid_at": doc["valid_at"],
"invalid_at": doc.get("invalid_at", "-"),
"forget_at": doc.get("forget_at", "-"),
"status": bool(int(doc["status_int"])),
"content": doc.get("content_ltks", ""),
"content_embed": doc.get(embd_field_name, []) if embd_field_name else [],
}
if doc.get("id"):
message["id"] = doc["id"]
return message
"""
CRUD operations
"""
def search(
self, select_fields: list[str],
highlight_fields: list[str],
condition: dict,
match_expressions: list[MatchExpr],
order_by: OrderByExpr,
offset: int,
limit: int,
index_names: str | list[str],
memory_ids: list[str],
agg_fields: list[str] | None = None,
rank_feature: dict | None = None,
hide_forgotten: bool = True
):
"""
Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html
"""
if isinstance(index_names, str):
index_names = index_names.split(",")
assert isinstance(index_names, list) and len(index_names) > 0
assert "_id" not in condition
bool_query = Q("bool", must=[], must_not=[])
if hide_forgotten:
# filter not forget
bool_query.must_not.append(Q("exists", field="forget_at"))
condition["memory_id"] = memory_ids
for k, v in condition.items():
if k == "session_id" and v:
bool_query.filter.append(Q("query_string", **{"query": f"*{v}*", "fields": ["session_id"], "analyze_wildcard": True}))
continue
if not v:
continue
if isinstance(v, list):
bool_query.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
bool_query.filter.append(Q("term", **{k: v}))
else:
raise Exception(
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
s = Search()
vector_similarity_weight = 0.5
for m in match_expressions:
if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance(match_expressions[1],
MatchDenseExpr) and isinstance(
match_expressions[2], FusionExpr)
weights = m.fusion_params["weights"]
vector_similarity_weight = get_float(weights.split(",")[1])
for m in match_expressions:
if isinstance(m, MatchTextExpr):
minimum_should_match = m.extra_options.get("minimum_should_match", 0.0)
if isinstance(minimum_should_match, float):
minimum_should_match = str(int(minimum_should_match * 100)) + "%"
bool_query.must.append(Q("query_string", fields=[self.convert_field_name(f) for f in m.fields],
type="best_fields", query=m.matching_text,
minimum_should_match=minimum_should_match,
boost=1))
bool_query.boost = 1.0 - vector_similarity_weight
elif isinstance(m, MatchDenseExpr):
assert (bool_query is not None)
similarity = 0.0
if "similarity" in m.extra_options:
similarity = m.extra_options["similarity"]
s = s.knn(self.convert_field_name(m.vector_column_name),
m.topn,
m.topn * 2,
query_vector=list(m.embedding_data),
filter=bool_query.to_dict(),
similarity=similarity,
)
if bool_query and rank_feature:
for fld, sc in rank_feature.items():
if fld != PAGERANK_FLD:
fld = f"{TAG_FLD}.{fld}"
bool_query.should.append(Q("rank_feature", field=fld, linear={}, boost=sc))
if bool_query:
s = s.query(bool_query)
for field in highlight_fields:
s = s.highlight(field)
if order_by:
orders = list()
for field, order in order_by.fields:
order = "asc" if order == 0 else "desc"
if field.endswith("_int") or field.endswith("_flt"):
order_info = {"order": order, "unmapped_type": "float"}
else:
order_info = {"order": order, "unmapped_type": "text"}
orders.append({field: order_info})
s = s.sort(*orders)
if agg_fields:
for fld in agg_fields:
s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)
if limit > 0:
s = s[offset:offset + limit]
q = s.to_dict()
self.logger.debug(f"ESConnection.search {str(index_names)} query: " + json.dumps(q))
for i in range(ATTEMPT_TIME):
try:
#print(json.dumps(q, ensure_ascii=False))
res = self.es.search(index=index_names,
body=q,
timeout="600s",
# search_type="dfs_query_then_fetch",
track_total_hits=True,
_source=True)
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
self.logger.debug(f"ESConnection.search {str(index_names)} res: " + str(res))
return res
except ConnectionTimeout:
self.logger.exception("ES request timeout")
self._connect()
continue
except Exception as e:
self.logger.exception(f"ESConnection.search {str(index_names)} query: " + str(q) + str(e))
raise e
self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
raise Exception("ESConnection.search timeout.")
def get_forgotten_messages(self, select_fields: list[str], index_name: str, memory_id: str, limit: int=2000):
bool_query = Q("bool", must_not=[])
bool_query.must_not.append(Q("term", forget_at=None))
bool_query.filter.append(Q("term", memory_id=memory_id))
# build the search, sorted from the oldest forget_at to the newest
s = Search()
s = s.query(bool_query)
s = s.sort({"forget_at": {"order": "asc"}})
s = s[:limit]
q = s.to_dict()
# search
for i in range(ATTEMPT_TIME):
try:
res = self.es.search(index=index_name, body=q, timeout="600s", track_total_hits=True, _source=True)
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
self.logger.debug(f"ESConnection.search {str(index_name)} res: " + str(res))
return res
except ConnectionTimeout:
self.logger.exception("ES request timeout")
self._connect()
continue
except Exception as e:
self.logger.exception(f"ESConnection.search {str(index_name)} query: " + str(q) + str(e))
raise e
self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
raise Exception("ESConnection.search timeout.")
def get(self, doc_id: str, index_name: str, memory_ids: list[str]) -> dict | None:
for i in range(ATTEMPT_TIME):
try:
res = self.es.get(index=index_name,
id=doc_id, source=True, )
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
message = res["_source"]
message["id"] = doc_id
return self.get_message_from_es_doc(message)
except NotFoundError:
return None
except Exception as e:
self.logger.exception(f"ESConnection.get({doc_id}) got exception")
raise e
self.logger.error(f"ESConnection.get timeout for {ATTEMPT_TIME} times!")
raise Exception("ESConnection.get timeout.")
def insert(self, documents: list[dict], index_name: str, memory_id: str = None) -> list[str]:
# Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html
operations = []
for d in documents:
assert "_id" not in d
assert "id" in d
d_copy_raw = copy.deepcopy(d)
d_copy = self.map_message_to_es_fields(d_copy_raw)
d_copy["memory_id"] = memory_id
meta_id = d_copy.pop("id", "")
operations.append(
{"index": {"_index": index_name, "_id": meta_id}})
operations.append(d_copy)
res = []
for _ in range(ATTEMPT_TIME):
try:
res = []
r = self.es.bulk(index=index_name, operations=operations,
refresh=False, timeout="60s")
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
return res
for item in r["items"]:
for action in ["create", "delete", "index", "update"]:
if action in item and "error" in item[action]:
res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"]))
return res
except ConnectionTimeout:
self.logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
res.append(str(e))
self.logger.warning("ESConnection.insert got exception: " + str(e))
return res
def update(self, condition: dict, new_value: dict, index_name: str, memory_id: str) -> bool:
doc = copy.deepcopy(new_value)
update_dict = {self.convert_field_name(k): v for k, v in doc.items()}
update_dict.pop("id", None)
condition_dict = {self.convert_field_name(k): v for k, v in condition.items()}
condition_dict["memory_id"] = memory_id
if "id" in condition_dict and isinstance(condition_dict["id"], str):
# update specific single document
message_id = condition_dict["id"]
for i in range(ATTEMPT_TIME):
for k in update_dict.keys():
if "feas" != k.split("_")[-1]:
continue
try:
self.es.update(index=index_name, id=message_id, script=f"ctx._source.remove(\"{k}\");")
except Exception:
self.logger.exception(f"ESConnection.update(index={index_name}, id={message_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
try:
self.es.update(index=index_name, id=message_id, doc=update_dict)
return True
except Exception as e:
self.logger.exception(
f"ESConnection.update(index={index_name}, id={message_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: " + str(e))
break
return False
# update unspecific maybe-multiple documents
bool_query = Q("bool")
for k, v in condition_dict.items():
if not isinstance(k, str) or not v:
continue
if k == "exists":
bool_query.filter.append(Q("exists", field=v))
continue
if isinstance(v, list):
bool_query.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
bool_query.filter.append(Q("term", **{k: v}))
else:
raise Exception(
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
scripts = []
params = {}
for k, v in update_dict.items():
if k == "remove":
if isinstance(v, str):
scripts.append(f"ctx._source.remove('{v}');")
if isinstance(v, dict):
for kk, vv in v.items():
scripts.append(f"int i=ctx._source.{kk}.indexOf(params.p_{kk});ctx._source.{kk}.remove(i);")
params[f"p_{kk}"] = vv
continue
if k == "add":
if isinstance(v, dict):
for kk, vv in v.items():
scripts.append(f"ctx._source.{kk}.add(params.pp_{kk});")
params[f"pp_{kk}"] = vv.strip()
continue
if (not isinstance(k, str) or not v) and k != "status_int":
continue
if isinstance(v, str):
v = re.sub(r"(['\n\r]|\\.)", " ", v)
params[f"pp_{k}"] = v
scripts.append(f"ctx._source.{k}=params.pp_{k};")
elif isinstance(v, int) or isinstance(v, float):
scripts.append(f"ctx._source.{k}={v};")
elif isinstance(v, list):
scripts.append(f"ctx._source.{k}=params.pp_{k};")
params[f"pp_{k}"] = json.dumps(v, ensure_ascii=False)
else:
raise Exception(
f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
ubq = UpdateByQuery(
index=index_name).using(
self.es).query(bool_query)
ubq = ubq.script(source="".join(scripts), params=params)
ubq = ubq.params(refresh=True)
ubq = ubq.params(slices=5)
ubq = ubq.params(conflicts="proceed")
for _ in range(ATTEMPT_TIME):
try:
_ = ubq.execute()
return True
except ConnectionTimeout:
self.logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
self.logger.error("ESConnection.update got exception: " + str(e) + "\n".join(scripts))
break
return False
def delete(self, condition: dict, index_name: str, memory_id: str) -> int:
assert "_id" not in condition
condition_dict = {self.convert_field_name(k): v for k, v in condition.items()}
condition_dict["memory_id"] = memory_id
if "id" in condition_dict:
message_ids = condition_dict["id"]
if not isinstance(message_ids, list):
message_ids = [message_ids]
if not message_ids: # when message_ids is empty, delete all
qry = Q("match_all")
else:
qry = Q("ids", values=message_ids)
else:
qry = Q("bool")
for k, v in condition_dict.items():
if k == "exists":
qry.filter.append(Q("exists", field=v))
elif k == "must_not":
if isinstance(v, dict):
for kk, vv in v.items():
if kk == "exists":
qry.must_not.append(Q("exists", field=vv))
elif isinstance(v, list):
qry.must.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
qry.must.append(Q("term", **{k: v}))
else:
raise Exception("Condition value must be int, str or list.")
self.logger.debug("ESConnection.delete query: " + json.dumps(qry.to_dict()))
for _ in range(ATTEMPT_TIME):
try:
res = self.es.delete_by_query(
index=index_name,
body=Search().query(qry).to_dict(),
refresh=True)
return res["deleted"]
except ConnectionTimeout:
self.logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
self.logger.warning("ESConnection.delete got exception: " + str(e))
if re.search(r"(not_found)", str(e), re.IGNORECASE):
return 0
return 0
"""
Helper functions for search result
"""
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
res_fields = {}
if not fields:
return {}
for doc in self._get_source(res):
message = self.get_message_from_es_doc(doc)
m = {}
for n, v in message.items():
if n not in fields:
continue
if isinstance(v, list):
m[n] = v
continue
if n in ["message_id", "source_id", "valid_at", "invalid_at", "forget_at", "status"] and isinstance(v, (int, float, bool)):
m[n] = v
continue
if not isinstance(v, str):
m[n] = str(v)
else:
m[n] = v
if m:
res_fields[doc["id"]] = m
return res_fields

View File

@ -0,0 +1,467 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import json
import copy
from infinity.common import InfinityException, SortType
from infinity.errors import ErrorCode
from common.decorator import singleton
import pandas as pd
from common.constants import PAGERANK_FLD, TAG_FLD
from common.doc_store.doc_store_base import MatchExpr, MatchTextExpr, MatchDenseExpr, FusionExpr, OrderByExpr
from common.doc_store.infinity_conn_base import InfinityConnectionBase
from common.time_utils import date_string_to_timestamp
@singleton
class InfinityConnection(InfinityConnectionBase):
def __init__(self):
super().__init__()
self.mapping_file_name = "message_infinity_mapping.json"
"""
Dataframe and fields convert
"""
@staticmethod
def field_keyword(field_name: str):
# no keywords right now
return False
@staticmethod
def convert_message_field_to_infinity(field_name: str):
match field_name:
case "message_type":
return "message_type_kwd"
case "status":
return "status_int"
case _:
return field_name
@staticmethod
def convert_infinity_field_to_message(field_name: str):
if field_name.startswith("message_type"):
return "message_type"
if field_name.startswith("status"):
return "status"
if re.match(r"q_\d+_vec", field_name):
return "content_embed"
return field_name
def convert_select_fields(self, output_fields: list[str]) -> list[str]:
return list({self.convert_message_field_to_infinity(f) for f in output_fields})
@staticmethod
def convert_matching_field(field_weight_str: str) -> str:
tokens = field_weight_str.split("^")
field = tokens[0]
if field == "content":
field = "content@ft_contentm_rag_fine"
tokens[0] = field
return "^".join(tokens)
@staticmethod
def convert_condition_and_order_field(field_name: str):
match field_name:
case "message_type":
return "message_type_kwd"
case "status":
return "status_int"
case "valid_at":
return "valid_at_flt"
case "invalid_at":
return "invalid_at_flt"
case "forget_at":
return "forget_at_flt"
case _:
return field_name
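# Examples of the field-name mapping above (illustrative; the vector column name depends on the embedding size):
#   convert_message_field_to_infinity("status")        -> "status_int"
#   convert_condition_and_order_field("valid_at")      -> "valid_at_flt"
#   convert_infinity_field_to_message("q_1024_vec")    -> "content_embed"
#   convert_matching_field("content^10")               -> "content@ft_contentm_rag_fine^10"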
"""
CRUD operations
"""
def search(
self,
select_fields: list[str],
highlight_fields: list[str],
condition: dict,
match_expressions: list[MatchExpr],
order_by: OrderByExpr,
offset: int,
limit: int,
index_names: str | list[str],
memory_ids: list[str],
agg_fields: list[str] | None = None,
rank_feature: dict | None = None,
hide_forgotten: bool = True,
) -> tuple[pd.DataFrame, int]:
"""
BUG: Infinity returns empty for a highlight field if the query string doesn't use that field.
"""
if isinstance(index_names, str):
index_names = index_names.split(",")
assert isinstance(index_names, list) and len(index_names) > 0
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
df_list = list()
table_list = list()
if hide_forgotten:
condition.update({"must_not": {"exists": "forget_at_flt"}})
output = select_fields.copy()
output = self.convert_select_fields(output)
if agg_fields is None:
agg_fields = []
for essential_field in ["id"] + agg_fields:
if essential_field not in output:
output.append(essential_field)
score_func = ""
score_column = ""
for matchExpr in match_expressions:
if isinstance(matchExpr, MatchTextExpr):
score_func = "score()"
score_column = "SCORE"
break
if not score_func:
for matchExpr in match_expressions:
if isinstance(matchExpr, MatchDenseExpr):
score_func = "similarity()"
score_column = "SIMILARITY"
break
if match_expressions:
if score_func not in output:
output.append(score_func)
if PAGERANK_FLD not in output:
output.append(PAGERANK_FLD)
output = [f for f in output if f != "_score"]
if limit <= 0:
# ElasticSearch default limit is 10000
limit = 10000
# Prepare expressions common to all tables
filter_cond = None
filter_fulltext = ""
if condition:
condition_dict = {self.convert_condition_and_order_field(k): v for k, v in condition.items()}
table_found = False
for indexName in index_names:
for mem_id in memory_ids:
table_name = f"{indexName}_{mem_id}"
try:
filter_cond = self.equivalent_condition_to_str(condition_dict, db_instance.get_table(table_name))
table_found = True
break
except Exception:
pass
if table_found:
break
if not table_found:
self.logger.error(f"No valid tables found for indexNames {index_names} and memoryIds {memory_ids}")
return pd.DataFrame(), 0
for matchExpr in match_expressions:
if isinstance(matchExpr, MatchTextExpr):
if filter_cond and "filter" not in matchExpr.extra_options:
matchExpr.extra_options.update({"filter": filter_cond})
matchExpr.fields = [self.convert_matching_field(field) for field in matchExpr.fields]
fields = ",".join(matchExpr.fields)
filter_fulltext = f"filter_fulltext('{fields}', '{matchExpr.matching_text}')"
if filter_cond:
filter_fulltext = f"({filter_cond}) AND {filter_fulltext}"
minimum_should_match = matchExpr.extra_options.get("minimum_should_match", 0.0)
if isinstance(minimum_should_match, float):
str_minimum_should_match = str(int(minimum_should_match * 100)) + "%"
matchExpr.extra_options["minimum_should_match"] = str_minimum_should_match
# Add rank_feature support
if rank_feature and "rank_features" not in matchExpr.extra_options:
# Convert rank_feature dict to Infinity's rank_features string format
# Format: "field^feature_name^weight,field^feature_name^weight"
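# e.g. rank_feature={"urgent": 2, "vip": 1} -> f"{TAG_FLD}^urgent^2,{TAG_FLD}^vip^1"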
rank_features_list = []
for feature_name, weight in rank_feature.items():
# Use TAG_FLD as the field containing rank features
rank_features_list.append(f"{TAG_FLD}^{feature_name}^{weight}")
if rank_features_list:
matchExpr.extra_options["rank_features"] = ",".join(rank_features_list)
for k, v in matchExpr.extra_options.items():
if not isinstance(v, str):
matchExpr.extra_options[k] = str(v)
self.logger.debug(f"INFINITY search MatchTextExpr: {json.dumps(matchExpr.__dict__)}")
elif isinstance(matchExpr, MatchDenseExpr):
if filter_fulltext and "filter" not in matchExpr.extra_options:
matchExpr.extra_options.update({"filter": filter_fulltext})
for k, v in matchExpr.extra_options.items():
if not isinstance(v, str):
matchExpr.extra_options[k] = str(v)
similarity = matchExpr.extra_options.get("similarity")
if similarity:
matchExpr.extra_options["threshold"] = similarity
del matchExpr.extra_options["similarity"]
self.logger.debug(f"INFINITY search MatchDenseExpr: {json.dumps(matchExpr.__dict__)}")
elif isinstance(matchExpr, FusionExpr):
self.logger.debug(f"INFINITY search FusionExpr: {json.dumps(matchExpr.__dict__)}")
order_by_expr_list = list()
if order_by.fields:
for order_field in order_by.fields:
order_field_name = self.convert_condition_and_order_field(order_field[0])
if order_field[1] == 0:
order_by_expr_list.append((order_field_name, SortType.Asc))
else:
order_by_expr_list.append((order_field_name, SortType.Desc))
total_hits_count = 0
# Scatter search tables and gather the results
for indexName in index_names:
for memory_id in memory_ids:
table_name = f"{indexName}_{memory_id}"
try:
table_instance = db_instance.get_table(table_name)
except Exception:
continue
table_list.append(table_name)
builder = table_instance.output(output)
if len(match_expressions) > 0:
for matchExpr in match_expressions:
if isinstance(matchExpr, MatchTextExpr):
fields = ",".join(matchExpr.fields)
builder = builder.match_text(
fields,
matchExpr.matching_text,
matchExpr.topn,
matchExpr.extra_options.copy(),
)
elif isinstance(matchExpr, MatchDenseExpr):
builder = builder.match_dense(
matchExpr.vector_column_name,
matchExpr.embedding_data,
matchExpr.embedding_data_type,
matchExpr.distance_type,
matchExpr.topn,
matchExpr.extra_options.copy(),
)
elif isinstance(matchExpr, FusionExpr):
builder = builder.fusion(matchExpr.method, matchExpr.topn, matchExpr.fusion_params)
else:
if filter_cond and len(filter_cond) > 0:
builder.filter(filter_cond)
if order_by.fields:
builder.sort(order_by_expr_list)
builder.offset(offset).limit(limit)
mem_res, extra_result = builder.option({"total_hits_count": True}).to_df()
if extra_result:
total_hits_count += int(extra_result["total_hits_count"])
self.logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(mem_res)}")
df_list.append(mem_res)
self.connPool.release_conn(inf_conn)
res = self.concat_dataframes(df_list, output)
if match_expressions:
res["_score"] = res[score_column] + res[PAGERANK_FLD]
res = res.sort_values(by="_score", ascending=False).reset_index(drop=True)
res = res.head(limit)
self.logger.debug(f"INFINITY search final result: {str(res)}")
return res, total_hits_count
def get_forgotten_messages(self, select_fields: list[str], index_name: str, memory_id: str, limit: int=2000):
condition = {"memory_id": memory_id, "exists": "forget_at_flt"}
order_by = OrderByExpr()
order_by.asc("forget_at_flt")
# query
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{index_name}_{memory_id}"
table_instance = db_instance.get_table(table_name)
output_fields = [self.convert_message_field_to_infinity(f) for f in select_fields]
builder = table_instance.output(output_fields)
filter_cond = self.equivalent_condition_to_str(condition, db_instance.get_table(table_name))
builder.filter(filter_cond)
order_by_expr_list = list()
if order_by.fields:
for order_field in order_by.fields:
order_field_name = self.convert_condition_and_order_field(order_field[0])
if order_field[1] == 0:
order_by_expr_list.append((order_field_name, SortType.Asc))
else:
order_by_expr_list.append((order_field_name, SortType.Desc))
builder.sort(order_by_expr_list)
builder.offset(0).limit(limit)
mem_res, _ = builder.option({"total_hits_count": True}).to_df()
res = self.concat_dataframes(mem_res, output_fields)
res = res.head(limit)
self.connPool.release_conn(inf_conn)
return res
def get(self, message_id: str, index_name: str, memory_ids: list[str]) -> dict | None:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
df_list = list()
assert isinstance(memory_ids, list)
table_list = list()
for memoryId in memory_ids:
table_name = f"{index_name}_{memoryId}"
table_list.append(table_name)
try:
table_instance = db_instance.get_table(table_name)
except Exception:
self.logger.warning(f"Table not found: {table_name}, this memory isn't created in Infinity. Maybe it is created in other document engine.")
continue
mem_res, _ = table_instance.output(["*"]).filter(f"id = '{message_id}'").to_df()
self.logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(mem_res)}")
df_list.append(mem_res)
self.connPool.release_conn(inf_conn)
res = self.concat_dataframes(df_list, ["id"])
fields = set(res.columns.tolist())
res_fields = self.get_fields(res, list(fields))
return res_fields.get(message_id, None)
def insert(self, documents: list[dict], index_name: str, memory_id: str = None) -> list[str]:
if not documents:
return []
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{index_name}_{memory_id}"
vector_size = int(len(documents[0]["content_embed"]))
try:
table_instance = db_instance.get_table(table_name)
except InfinityException as e:
# src/common/status.cppm, kTableNotExist = 3022
if e.error_code != ErrorCode.TABLE_NOT_EXIST:
raise
if vector_size == 0:
raise ValueError("Cannot infer vector size from documents")
self.create_idx(index_name, memory_id, vector_size)
table_instance = db_instance.get_table(table_name)
# embedding fields can't have a default value....
embedding_columns = []
table_columns = table_instance.show_columns().rows()
for n, ty, _, _ in table_columns:
r = re.search(r"Embedding\([a-z]+,([0-9]+)\)", ty)
if not r:
continue
embedding_columns.append((n, int(r.group(1))))
docs = copy.deepcopy(documents)
for d in docs:
assert "_id" not in d
assert "id" in d
for k, v in list(d.items()):
field_name = self.convert_message_field_to_infinity(k)
if field_name in ["valid_at", "invalid_at", "forget_at"]:
d[f"{field_name}_flt"] = date_string_to_timestamp(v) if v else 0
if v is None:
d[field_name] = ""
elif self.field_keyword(k):
if isinstance(v, list):
d[k] = "###".join(v)
else:
d[k] = v
elif k == "memory_id":
if isinstance(d[k], list):
d[k] = d[k][0] # since d[k] is a list, but we need a str
elif field_name == "content_embed":
d[f"q_{vector_size}_vec"] = d["content_embed"]
d.pop("content_embed")
else:
d[field_name] = v
if k != field_name:
d.pop(k)
for n, vs in embedding_columns:
if n in d:
continue
d[n] = [0] * vs
ids = ["'{}'".format(d["id"]) for d in docs]
str_ids = ", ".join(ids)
str_filter = f"id IN ({str_ids})"
table_instance.delete(str_filter)
table_instance.insert(docs)
self.connPool.release_conn(inf_conn)
self.logger.debug(f"INFINITY inserted into {table_name} {str_ids}.")
return []
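A sketch of what the per-document rewriting above produces for one message; since the existing row with the same id is deleted first, the call behaves like an upsert (timestamp float and vector values are illustrative):

# input document:
# {"id": "m1", "message_type": "semantic", "status": 1,
#  "valid_at": "2024-01-15T10:00:00", "invalid_at": None, "content_embed": [0.1, 0.2]}
# row written to Infinity (vector size 2 -> q_2_vec; date fields mirrored into *_flt floats,
# plus zero-filled defaults for any other embedding columns of the table):
# {"id": "m1", "message_type_kwd": "semantic", "status_int": 1,
#  "valid_at": "2024-01-15T10:00:00", "valid_at_flt": <float timestamp>,
#  "invalid_at": "", "invalid_at_flt": 0, "q_2_vec": [0.1, 0.2]}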
def update(self, condition: dict, new_value: dict, index_name: str, memory_id: str) -> bool:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{index_name}_{memory_id}"
table_instance = db_instance.get_table(table_name)
columns = {}
if table_instance:
for n, ty, de, _ in table_instance.show_columns().rows():
columns[n] = (ty, de)
condition_dict = {self.convert_condition_and_order_field(k): v for k, v in condition.items()}
filter = self.equivalent_condition_to_str(condition_dict, table_instance)
update_dict = {self.convert_message_field_to_infinity(k): v for k, v in new_value.items()}
date_floats = {}
for k, v in update_dict.items():
if k in ["valid_at", "invalid_at", "forget_at"]:
date_floats[f"{k}_flt"] = date_string_to_timestamp(v) if v else 0
elif self.field_keyword(k):
if isinstance(v, list):
update_dict[k] = "###".join(v)
else:
update_dict[k] = v
elif k == "memory_id":
if isinstance(update_dict[k], list):
update_dict[k] = update_dict[k][0] # since d[k] is a list, but we need a str
else:
update_dict[k] = v
if date_floats:
update_dict.update(date_floats)
self.logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {new_value}.")
table_instance.update(filter, update_dict)
self.connPool.release_conn(inf_conn)
return True
"""
Helper functions for search result
"""
def get_fields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
if isinstance(res, tuple):
res = res[0]
if not fields:
return {}
fields_all = fields.copy()
fields_all.append("id")
fields_all = {self.convert_message_field_to_infinity(f) for f in fields_all}
column_map = {col.lower(): col for col in res.columns}
matched_columns = {column_map[col.lower()]: col for col in fields_all if col.lower() in column_map}
none_columns = [col for col in fields_all if col.lower() not in column_map]
res2 = res[matched_columns.keys()]
res2 = res2.rename(columns=matched_columns)
res2.drop_duplicates(subset=["id"], inplace=True)
for column in list(res2.columns):
k = column.lower()
if self.field_keyword(k):
res2[column] = res2[column].apply(lambda v: [kwd for kwd in v.split("###") if kwd])
else:
pass
for column in ["content"]:
if column in res2:
del res2[column]
for column in none_columns:
res2[column] = None
res_dict = res2.set_index("id").to_dict(orient="index")
return {_id: {self.convert_infinity_field_to_message(k): v for k, v in doc.items()} for _id, doc in res_dict.items()}

memory/utils/msg_util.py Normal file

@ -0,0 +1,37 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
def get_json_result_from_llm_response(response_str: str) -> dict:
"""
Parse the LLM response string to extract JSON content.
The function strips a surrounding ```json code fence, if present, and then parses the remainder as JSON.
If parsing fails, it returns an empty dictionary.
:param response_str: The response string from the LLM.
:return: A dictionary parsed from the JSON content in the response.
"""
try:
clean_str = response_str.strip()
if clean_str.startswith('```json'):
clean_str = clean_str[7:] # Remove the starting ```json
if clean_str.endswith('```'):
clean_str = clean_str[:-3] # Remove the ending ```
return json.loads(clean_str.strip())
except (ValueError, json.JSONDecodeError):
return {}
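A brief usage sketch (the response strings are invented for illustration):

raw = '```json\n{"semantic": [{"content": "Paris is the capital of France", "valid_at": "", "invalid_at": ""}]}\n```'
get_json_result_from_llm_response(raw)
# -> {"semantic": [{"content": "Paris is the capital of France", "valid_at": "", "invalid_at": ""}]}
get_json_result_from_llm_response("Sorry, I could not produce JSON.")
# -> {}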

memory/utils/prompt_util.py Normal file

@ -0,0 +1,201 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, List
from common.constants import MemoryType
from common.time_utils import current_timestamp
class PromptAssembler:
SYSTEM_BASE_TEMPLATE = """**Memory Extraction Specialist**
You are an expert at analyzing conversations to extract structured memory.
{type_specific_instructions}
**OUTPUT REQUIREMENTS:**
1. Output MUST be valid JSON
2. Follow the specified output format exactly
3. Each extracted item MUST have: content, valid_at, invalid_at
4. Timestamps in {timestamp_format} format
5. Only extract memory types specified above
6. Maximum {max_items} items per type
"""
TYPE_INSTRUCTIONS = {
MemoryType.SEMANTIC.name.lower(): """
**EXTRACT SEMANTIC KNOWLEDGE:**
- Universal facts, definitions, concepts, relationships
- Time-invariant, generally true information
- Examples: "The capital of France is Paris", "Water boils at 100°C"
**Timestamp Rules for Semantic Knowledge:**
- valid_at: When the fact became true (e.g., law enactment, discovery)
- invalid_at: When it becomes false (e.g., repeal, disproven) or empty if still true
- Default: valid_at = conversation time, invalid_at = "" for timeless facts
""",
MemoryType.EPISODIC.name.lower(): """
**EXTRACT EPISODIC KNOWLEDGE:**
- Specific experiences, events, personal stories
- Time-bound, person-specific, contextual
- Examples: "Yesterday I fixed the bug", "User reported issue last week"
**Timestamp Rules for Episodic Knowledge:**
- valid_at: Event start/occurrence time
- invalid_at: Event end time or empty if instantaneous
- Extract explicit times: "at 3 PM", "last Monday", "from X to Y"
""",
MemoryType.PROCEDURAL.name.lower(): """
**EXTRACT PROCEDURAL KNOWLEDGE:**
- Processes, methods, step-by-step instructions
- Goal-oriented, actionable, often includes conditions
- Examples: "To reset password, click...", "Debugging steps: 1)..."
**Timestamp Rules for Procedural Knowledge:**
- valid_at: When procedure becomes valid/effective
- invalid_at: When it expires/becomes obsolete or empty if current
- For version-specific: use release dates
- For best practices: invalid_at = ""
"""
}
OUTPUT_TEMPLATES = {
MemoryType.SEMANTIC.name.lower(): """
"semantic": [
{
"content": "Clear factual statement",
"valid_at": "timestamp or empty",
"invalid_at": "timestamp or empty"
}
]
""",
MemoryType.EPISODIC.name.lower(): """
"episodic": [
{
"content": "Narrative event description",
"valid_at": "event start timestamp",
"invalid_at": "event end timestamp or empty"
}
]
""",
MemoryType.PROCEDURAL.name.lower(): """
"procedural": [
{
"content": "Actionable instructions",
"valid_at": "procedure effective timestamp",
"invalid_at": "procedure expiration timestamp or empty"
}
]
"""
}
BASE_USER_PROMPT = """
**CONVERSATION:**
{conversation}
**CONVERSATION TIME:** {conversation_time}
**CURRENT TIME:** {current_time}
"""
@classmethod
def assemble_system_prompt(cls, config: dict) -> str:
types_to_extract = cls._get_types_to_extract(config["memory_type"])
type_instructions = cls._generate_type_instructions(types_to_extract)
output_format = cls._generate_output_format(types_to_extract)
full_prompt = cls.SYSTEM_BASE_TEMPLATE.format(
type_specific_instructions=type_instructions,
timestamp_format=config.get("timestamp_format", "ISO 8601"),
max_items=config.get("max_items_per_type", 5)
)
full_prompt += f"\n**REQUIRED OUTPUT FORMAT (JSON):**\n```json\n{{\n{output_format}\n}}\n```\n"
examples = cls._generate_examples(types_to_extract)
if examples:
full_prompt += f"\n**EXAMPLES:**\n{examples}\n"
return full_prompt
@staticmethod
def _get_types_to_extract(requested_types: List[str]) -> List[str]:
types = set()
for rt in requested_types:
if rt in [e.name.lower() for e in MemoryType] and rt != MemoryType.RAW.name.lower():
types.add(rt)
return list(types)
@classmethod
def _generate_type_instructions(cls, types_to_extract: List[str]) -> str:
target_types = set(types_to_extract)
instructions = [cls.TYPE_INSTRUCTIONS[mt] for mt in target_types]
return "\n".join(instructions)
@classmethod
def _generate_output_format(cls, types_to_extract: List[str]) -> str:
target_types = set(types_to_extract)
output_parts = [cls.OUTPUT_TEMPLATES[mt] for mt in target_types]
return ",\n".join(output_parts)
@staticmethod
def _generate_examples(types_to_extract: list[str]) -> str:
examples = []
if MemoryType.SEMANTIC.name.lower() in types_to_extract:
examples.append("""
**Semantic Example:**
Input: "Python lists are mutable and support various operations."
Output: {"semantic": [{"content": "Python lists are mutable data structures", "valid_at": "2024-01-15T10:00:00", "invalid_at": ""}]}
""")
if MemoryType.EPISODIC.name.lower() in types_to_extract:
examples.append("""
**Episodic Example:**
Input: "I deployed the new feature yesterday afternoon."
Output: {"episodic": [{"content": "User deployed new feature", "valid_at": "2024-01-14T14:00:00", "invalid_at": "2024-01-14T18:00:00"}]}
""")
if MemoryType.PROCEDURAL.name.lower() in types_to_extract:
examples.append("""
**Procedural Example:**
Input: "To debug API errors: 1) Check logs 2) Verify endpoints 3) Test connectivity."
Output: {"procedural": [{"content": "API error debugging: 1. Check logs 2. Verify endpoints 3. Test connectivity", "valid_at": "2024-01-15T10:00:00", "invalid_at": ""}]}
""")
return "\n".join(examples)
@classmethod
def assemble_user_prompt(
cls,
conversation: str,
conversation_time: Optional[str] = None,
current_time: Optional[str] = None
) -> str:
return cls.BASE_USER_PROMPT.format(
conversation=conversation,
conversation_time=conversation_time or "Not specified",
current_time=current_time or current_timestamp(),
)
@classmethod
def get_raw_user_prompt(cls):
return cls.BASE_USER_PROMPT
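A minimal usage sketch, assuming a config dict that carries the keys read above:

config = {
    "memory_type": ["semantic", "episodic"],  # memory types to extract
    "timestamp_format": "ISO 8601",
    "max_items_per_type": 5,
}
system_prompt = PromptAssembler.assemble_system_prompt(config)
user_prompt = PromptAssembler.assemble_user_prompt(
    conversation="User: I deployed the new feature yesterday afternoon.",
    conversation_time="2024-01-15T10:00:00",
)
# system_prompt contains the type instructions, the JSON output format and the examples;
# user_prompt wraps the conversation together with the conversation time and current time.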


@ -46,7 +46,7 @@ dependencies = [
"groq==0.9.0",
"grpcio-status==1.67.1",
"html-text==0.6.2",
"infinity-sdk==0.6.11",
"infinity-sdk==0.6.13",
"infinity-emb>=0.0.66,<0.0.67",
"jira==3.10.5",
"json-repair==0.35.0",


@ -91,7 +91,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
filename, binary=binary, from_page=from_page, to_page=to_page)
remove_contents_table(sections, eng=is_english(
random_choices([t for t, _ in sections], k=200)))
tbls=vision_figure_parser_docx_wrapper(sections=sections,tbls=tbls,callback=callback,**kwargs)
tbls = vision_figure_parser_docx_wrapper(sections=sections,tbls=tbls,callback=callback,**kwargs)
# tbls = [((None, lns), None) for lns in tbls]
sections=[(item[0],item[1] if item[1] is not None else "") for item in sections if not isinstance(item[1], Image.Image)]
callback(0.8, "Finish parsing.")
@ -147,9 +147,16 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
elif re.search(r"\.doc$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
with BytesIO(binary) as binary:
binary = BytesIO(binary)
doc_parsed = parser.from_buffer(binary)
try:
from tika import parser as tika_parser
except Exception as e:
callback(0.8, f"tika not available: {e}. Unsupported .doc parsing.")
logging.warning(f"tika not available: {e}. Unsupported .doc parsing for {filename}.")
return []
binary = BytesIO(binary)
doc_parsed = tika_parser.from_buffer(binary)
if doc_parsed.get('content', None) is not None:
sections = doc_parsed['content'].split('\n')
sections = [(line, "") for line in sections if line]
remove_contents_table(sections, eng=is_english(


@ -312,7 +312,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
tk_cnt = num_tokens_from_string(txt)
if sec_id > -1:
last_sid = sec_id
tbls=vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs)
tbls = vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs)
res = tokenize_table(tbls, doc, eng)
res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
table_ctx = max(0, int(parser_config.get("table_context_size", 0) or 0))
@ -325,7 +325,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
docx_parser = Docx()
ti_list, tbls = docx_parser(filename, binary,
from_page=0, to_page=10000, callback=callback)
tbls=vision_figure_parser_docx_wrapper(sections=ti_list,tbls=tbls,callback=callback,**kwargs)
tbls = vision_figure_parser_docx_wrapper(sections=ti_list,tbls=tbls,callback=callback,**kwargs)
res = tokenize_table(tbls, doc, eng)
for text, image in ti_list:
d = copy.deepcopy(doc)


@ -76,7 +76,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if re.search(r"\.docx$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
sections, tbls = naive.Docx()(filename, binary)
tbls=vision_figure_parser_docx_wrapper(sections=sections,tbls=tbls,callback=callback,**kwargs)
tbls = vision_figure_parser_docx_wrapper(sections=sections, tbls=tbls, callback=callback, **kwargs)
sections = [s for s, _ in sections if s]
for (_, html), _ in tbls:
sections.append(html)
@ -142,10 +142,18 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
elif re.search(r"\.doc$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
try:
from tika import parser as tika_parser
except Exception as e:
callback(0.8, f"tika not available: {e}. Unsupported .doc parsing.")
logging.warning(f"tika not available: {e}. Unsupported .doc parsing for {filename}.")
return []
binary = BytesIO(binary)
doc_parsed = parser.from_buffer(binary)
sections = doc_parsed['content'].split('\n')
sections = [s for s in sections if s]
doc_parsed = tika_parser.from_buffer(binary)
if doc_parsed.get('content', None) is not None:
sections = doc_parsed['content'].split('\n')
sections = [s for s in sections if s]
callback(0.8, "Finish parsing.")
else:


@ -77,9 +77,9 @@ class Benchmark:
def init_index(self, vector_size: int):
if self.initialized_index:
return
if settings.docStoreConn.indexExist(self.index_name, self.kb_id):
settings.docStoreConn.deleteIdx(self.index_name, self.kb_id)
settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
if settings.docStoreConn.index_exist(self.index_name, self.kb_id):
settings.docStoreConn.delete_idx(self.index_name, self.kb_id)
settings.docStoreConn.create_idx(self.index_name, self.kb_id, vector_size)
self.initialized_index = True
def ms_marco_index(self, file_path, index_name):


@ -37,7 +37,6 @@ from rag.app.naive import Docx
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.parser.schema import ParserFromUpstream
from rag.llm.cv_model import Base as VLM
from rag.nlp import attach_media_context
from rag.utils.base64_image import image2id
@ -86,8 +85,6 @@ class ParserParam(ProcessParamBase):
"pdf",
],
"output_format": "json",
"table_context_size": 0,
"image_context_size": 0,
},
"spreadsheet": {
"parse_method": "deepdoc", # deepdoc/tcadp_parser
@ -97,8 +94,6 @@ class ParserParam(ProcessParamBase):
"xlsx",
"csv",
],
"table_context_size": 0,
"image_context_size": 0,
},
"word": {
"suffix": [
@ -106,14 +101,10 @@ class ParserParam(ProcessParamBase):
"docx",
],
"output_format": "json",
"table_context_size": 0,
"image_context_size": 0,
},
"text&markdown": {
"suffix": ["md", "markdown", "mdx", "txt"],
"output_format": "json",
"table_context_size": 0,
"image_context_size": 0,
},
"slides": {
"parse_method": "deepdoc", # deepdoc/tcadp_parser
@ -122,8 +113,6 @@ class ParserParam(ProcessParamBase):
"ppt",
],
"output_format": "json",
"table_context_size": 0,
"image_context_size": 0,
},
"image": {
"parse_method": "ocr",
@ -357,11 +346,6 @@ class Parser(ProcessBase):
elif layout == "table":
b["doc_type_kwd"] = "table"
table_ctx = conf.get("table_context_size", 0) or 0
image_ctx = conf.get("image_context_size", 0) or 0
if table_ctx or image_ctx:
bboxes = attach_media_context(bboxes, table_ctx, image_ctx)
if conf.get("output_format") == "json":
self.set_output("json", bboxes)
if conf.get("output_format") == "markdown":
@ -436,11 +420,6 @@ class Parser(ProcessBase):
if table:
result.append({"text": table, "doc_type_kwd": "table"})
table_ctx = conf.get("table_context_size", 0) or 0
image_ctx = conf.get("image_context_size", 0) or 0
if table_ctx or image_ctx:
result = attach_media_context(result, table_ctx, image_ctx)
self.set_output("json", result)
elif output_format == "markdown":
@ -476,11 +455,6 @@ class Parser(ProcessBase):
sections = [{"text": section[0], "image": section[1]} for section in sections if section]
sections.extend([{"text": tb, "image": None, "doc_type_kwd": "table"} for ((_, tb), _) in tbls])
table_ctx = conf.get("table_context_size", 0) or 0
image_ctx = conf.get("image_context_size", 0) or 0
if table_ctx or image_ctx:
sections = attach_media_context(sections, table_ctx, image_ctx)
self.set_output("json", sections)
elif conf.get("output_format") == "markdown":
markdown_text = docx_parser.to_markdown(name, binary=blob)
@ -536,11 +510,6 @@ class Parser(ProcessBase):
if table:
result.append({"text": table, "doc_type_kwd": "table"})
table_ctx = conf.get("table_context_size", 0) or 0
image_ctx = conf.get("image_context_size", 0) or 0
if table_ctx or image_ctx:
result = attach_media_context(result, table_ctx, image_ctx)
self.set_output("json", result)
else:
# Default DeepDOC parser (supports .pptx format)
@ -554,10 +523,6 @@ class Parser(ProcessBase):
# json
assert conf.get("output_format") == "json", "have to be json for ppt"
if conf.get("output_format") == "json":
table_ctx = conf.get("table_context_size", 0) or 0
image_ctx = conf.get("image_context_size", 0) or 0
if table_ctx or image_ctx:
sections = attach_media_context(sections, table_ctx, image_ctx)
self.set_output("json", sections)
def _markdown(self, name, blob):
@ -597,11 +562,6 @@ class Parser(ProcessBase):
json_results.append(json_result)
table_ctx = conf.get("table_context_size", 0) or 0
image_ctx = conf.get("image_context_size", 0) or 0
if table_ctx or image_ctx:
json_results = attach_media_context(json_results, table_ctx, image_ctx)
self.set_output("json", json_results)
else:
self.set_output("text", "\n".join([section_text for section_text, _ in sections]))
@ -650,7 +610,7 @@ class Parser(ProcessBase):
tmpf.flush()
tmp_path = os.path.abspath(tmpf.name)
seq2txt_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.SPEECH2TEXT)
seq2txt_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.SPEECH2TEXT, llm_name=conf["llm_id"])
txt = seq2txt_mdl.transcription(tmp_path)
self.set_output("text", txt)


@ -23,7 +23,7 @@ from rag.utils.base64_image import id2image, image2id
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.splitter.schema import SplitterFromUpstream
from rag.nlp import naive_merge, naive_merge_with_images
from rag.nlp import attach_media_context, naive_merge, naive_merge_with_images
from common import settings
@ -34,11 +34,15 @@ class SplitterParam(ProcessParamBase):
self.delimiters = ["\n"]
self.overlapped_percent = 0
self.children_delimiters = []
self.table_context_size = 0
self.image_context_size = 0
def check(self):
self.check_empty(self.delimiters, "Delimiters.")
self.check_positive_integer(self.chunk_token_size, "Chunk token size.")
self.check_decimal_float(self.overlapped_percent, "Overlapped percentage: [0, 1)")
self.check_nonnegative_number(self.table_context_size, "Table context size.")
self.check_nonnegative_number(self.image_context_size, "Image context size.")
def get_input_form(self) -> dict[str, dict]:
return {}
@ -103,8 +107,18 @@ class Splitter(ProcessBase):
return
# json
json_result = from_upstream.json_result or []
if self._param.table_context_size or self._param.image_context_size:
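# Temporarily flag image-only chunks (img_id set, no usable text) before attaching media context; the flag is removed again right below.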
for ck in json_result:
if "image" not in ck and ck.get("img_id") and not (isinstance(ck.get("text"), str) and ck.get("text").strip()):
ck["image"] = True
attach_media_context(json_result, self._param.table_context_size, self._param.image_context_size)
for ck in json_result:
if ck.get("image") is True:
del ck["image"]
sections, section_images = [], []
for o in from_upstream.json_result or []:
for o in json_result:
sections.append((o.get("text", ""), o.get("position_tag", "")))
section_images.append(id2image(o.get("img_id"), partial(settings.STORAGE_IMPL.get, tenant_id=self._canvas._tenant_id)))


@ -105,6 +105,9 @@ class Tokenizer(ProcessBase):
async def _invoke(self, **kwargs):
try:
chunks = kwargs.get("chunks")
kwargs["chunks"] = [c for c in chunks if c is not None]
from_upstream = TokenizerFromUpstream.model_validate(kwargs)
except Exception as e:
self.set_output("_ERROR", f"Input error: {str(e)}")


@ -348,7 +348,8 @@ def tokenize_table(tbls, doc, eng, batch_size=10):
d["doc_type_kwd"] = "table"
if img:
d["image"] = img
d["doc_type_kwd"] = "image"
if d["content_with_weight"].find("<tr>") < 0:
d["doc_type_kwd"] = "image"
if poss:
add_positions(d, poss)
res.append(d)
@ -361,7 +362,8 @@ def tokenize_table(tbls, doc, eng, batch_size=10):
d["doc_type_kwd"] = "table"
if img:
d["image"] = img
d["doc_type_kwd"] = "image"
if d["content_with_weight"].find("<tr>") < 0:
d["doc_type_kwd"] = "image"
add_positions(d, poss)
res.append(d)
return res


@ -19,11 +19,12 @@ import json
import re
from collections import defaultdict
from rag.utils.doc_store_conn import MatchTextExpr
from common.query_base import QueryBase
from common.doc_store.doc_store_base import MatchTextExpr
from rag.nlp import rag_tokenizer, term_weight, synonym
class FulltextQueryer:
class FulltextQueryer(QueryBase):
def __init__(self):
self.tw = term_weight.Dealer()
self.syn = synonym.Dealer()
@ -37,64 +38,19 @@ class FulltextQueryer:
"content_sm_ltks",
]
@staticmethod
def sub_special_char(line):
return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|\+~\^])", r"\\\1", line).strip()
@staticmethod
def is_chinese(line):
arr = re.split(r"[ \t]+", line)
if len(arr) <= 3:
return True
e = 0
for t in arr:
if not re.match(r"[a-zA-Z]+$", t):
e += 1
return e * 1.0 / len(arr) >= 0.7
@staticmethod
def rmWWW(txt):
patts = [
(
r"是*(怎么办|什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀|谁|哪位|哪个)是*",
"",
),
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
(
r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ",
" ")
]
otxt = txt
for r, p in patts:
txt = re.sub(r, p, txt, flags=re.IGNORECASE)
if not txt:
txt = otxt
return txt
@staticmethod
def add_space_between_eng_zh(txt):
# (ENG/ENG+NUM) + ZH
txt = re.sub(r'([A-Za-z]+[0-9]+)([\u4e00-\u9fa5]+)', r'\1 \2', txt)
# ENG + ZH
txt = re.sub(r'([A-Za-z])([\u4e00-\u9fa5]+)', r'\1 \2', txt)
# ZH + (ENG/ENG+NUM)
txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z]+[0-9]+)', r'\1 \2', txt)
txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z])', r'\1 \2', txt)
return txt
def question(self, txt, tbl="qa", min_match: float = 0.6):
original_query = txt
txt = FulltextQueryer.add_space_between_eng_zh(txt)
txt = self.add_space_between_eng_zh(txt)
txt = re.sub(
r"[ :|\r\n\t,,。??/`!&^%%()\[\]{}<>]+",
" ",
rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(txt.lower())),
).strip()
otxt = txt
txt = FulltextQueryer.rmWWW(txt)
txt = self.rmWWW(txt)
if not self.is_chinese(txt):
txt = FulltextQueryer.rmWWW(txt)
txt = self.rmWWW(txt)
tks = rag_tokenizer.tokenize(txt).split()
keywords = [t for t in tks if t]
tks_w = self.tw.weights(tks, preprocess=False)
@ -138,7 +94,7 @@ class FulltextQueryer:
return False
return True
txt = FulltextQueryer.rmWWW(txt)
txt = self.rmWWW(txt)
qs, keywords = [], []
for tt in self.tw.split(txt)[:256]: # .split():
if not tt:
@ -164,7 +120,7 @@ class FulltextQueryer:
)
for m in sm
]
sm = [FulltextQueryer.sub_special_char(m) for m in sm if len(m) > 1]
sm = [self.sub_special_char(m) for m in sm if len(m) > 1]
sm = [m for m in sm if len(m) > 1]
if len(keywords) < 32:
@ -172,7 +128,7 @@ class FulltextQueryer:
keywords.extend(sm)
tk_syns = self.syn.lookup(tk)
tk_syns = [FulltextQueryer.sub_special_char(s) for s in tk_syns]
tk_syns = [self.sub_special_char(s) for s in tk_syns]
if len(keywords) < 32:
keywords.extend([s for s in tk_syns if s])
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
@ -181,7 +137,7 @@ class FulltextQueryer:
if len(keywords) >= 32:
break
tk = FulltextQueryer.sub_special_char(tk)
tk = self.sub_special_char(tk)
if tk.find(" ") > 0:
tk = '"%s"' % tk
if tk_syns:
@ -199,7 +155,7 @@ class FulltextQueryer:
syns = " OR ".join(
[
'"%s"'
% rag_tokenizer.tokenize(FulltextQueryer.sub_special_char(s))
% rag_tokenizer.tokenize(self.sub_special_char(s))
for s in syns
]
)
@ -264,10 +220,10 @@ class FulltextQueryer:
keywords = [f'"{k.strip()}"' for k in keywords]
for tk, w in sorted(tks_w, key=lambda x: x[1] * -1)[:keywords_topn]:
tk_syns = self.syn.lookup(tk)
tk_syns = [FulltextQueryer.sub_special_char(s) for s in tk_syns]
tk_syns = [self.sub_special_char(s) for s in tk_syns]
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
tk = FulltextQueryer.sub_special_char(tk)
tk = self.sub_special_char(tk)
if tk.find(" ") > 0:
tk = '"%s"' % tk
if tk_syns:


@ -24,7 +24,7 @@ from dataclasses import dataclass
from rag.prompts.generator import relevant_chunks_with_toc
from rag.nlp import rag_tokenizer, query
import numpy as np
from rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr, FusionExpr, OrderByExpr
from common.doc_store.doc_store_base import MatchDenseExpr, FusionExpr, OrderByExpr, DocStoreConnection
from common.string_utils import remove_redundant_spaces
from common.float_utils import get_float
from common.constants import PAGERANK_FLD, TAG_FLD
@ -155,7 +155,7 @@ class Dealer:
kwds.add(kk)
logging.debug(f"TOTAL: {total}")
ids = self.dataStore.get_chunk_ids(res)
ids = self.dataStore.get_doc_ids(res)
keywords = list(kwds)
highlight = self.dataStore.get_highlight(res, keywords, "content_with_weight")
aggs = self.dataStore.get_aggregation(res, "docnm_kwd")
@ -545,7 +545,7 @@ class Dealer:
return res
def all_tags(self, tenant_id: str, kb_ids: list[str], S=1000):
if not self.dataStore.indexExist(index_name(tenant_id), kb_ids[0]):
if not self.dataStore.index_exist(index_name(tenant_id), kb_ids[0]):
return []
res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"])
return self.dataStore.get_aggregation(res, "tag_kwd")


@ -136,6 +136,19 @@ def kb_prompt(kbinfos, max_tokens, hash_id=False):
return knowledges
def memory_prompt(message_list, max_tokens):
used_token_count = 0
content_list = []
for message in message_list:
current_content_tokens = num_tokens_from_string(message["content"])
if used_token_count + current_content_tokens > max_tokens * 0.97:
logging.warning(f"Not all the retrieval into prompt: {len(content_list)}/{len(message_list)}")
break
content_list.append(message["content"])
used_token_count += current_content_tokens
return content_list
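For example (token counts depend on the tokenizer; values here are illustrative):

messages = [{"content": "User prefers dark mode."}, {"content": "User deployed feature X on Jan 14."}]
memory_prompt(messages, max_tokens=4096)
# -> ["User prefers dark mode.", "User deployed feature X on Jan 14."]
# (stops early and logs a warning once adding the next message would exceed ~97% of the budget)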
CITATION_PROMPT_TEMPLATE = load_prompt("citation_prompt")
CITATION_PLUS_TEMPLATE = load_prompt("citation_plus")
CONTENT_TAGGING_PROMPT_TEMPLATE = load_prompt("content_tagging_prompt")
@ -324,8 +337,9 @@ def tool_schema(tools_description: list[dict], complete_task=False):
}
}
}
for tool in tools_description:
desc[tool["function"]["name"]] = tool
for idx, tool in enumerate(tools_description):
name = tool["function"]["name"]
desc[name] = tool
return "\n\n".join([f"## {i+1}. {fnm}\n{json.dumps(des, ensure_ascii=False, indent=4)}" for i, (fnm, des) in enumerate(desc.items())])


@ -8,9 +8,9 @@ Your job is:
{{ task_analysis }}
# ========== TOOLS (JSON-Schema) ==========
You may invoke only the tools listed below.
Return a JSON array of objects in which item is with exactly two top-level keys:
• "name": the tool to call
You may invoke only the tools listed below.
Return a JSON array of objects in which item is with exactly two top-level keys:
• "name": the tool to call
• "arguments": an object whose keys/values satisfy the schema
{{ desc }}
@ -82,11 +82,57 @@ If you encounter issues:
⚠️ Any output that is not valid JSON or that contains extra fields will be rejected.
# ========== REASONING & REFLECTION ==========
You may think privately (not shown to the user) before producing each JSON object.
Internal guideline:
1. **Reason**: Analyse the user question; decide which tools (if any) are needed.
2. **Act**: Emit the JSON object to call the tool.
# ========== PRIVATE REASONING & REFLECTION ==========
You may think privately inside `<think>` tags.
This content will NOT be shown to the user.
## Step 1: Core Reasoning
- Analyze the task requirements
- Decide whether tools are required
- Decide if parallel execution is appropriate
## Step 2: Structured Reflection (MANDATORY before `complete_task`)
### Context
- Goal: Reflect on the current task based on the full conversation context
- Executed tool calls so far (if any): reflect from conversation history
### Task Complexity Assessment
Evaluate the task along these dimensions:
- Scope Breadth: Single-step (1) | Multi-step (2) | Multi-domain (3)
- Data Dependency: Self-contained (1) | External inputs (2) | Multiple sources (3)
- Decision Points: Linear (1) | Few branches (2) | Complex logic (3)
- Risk Level: Low (1) | Medium (2) | High (3)
Compute the **Complexity Score (4-12)**.
### Reflection Depth Control
- 4-5: Brief sanity check
- 6-8: Check completeness + risks
- 9-12: Full reflection with alternatives
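(Illustration: Multi-step 2 + External inputs 2 + Few branches 2 + Medium risk 2 = 8 → check completeness + risks.)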
### Reflection Checklist
- Goal alignment: Is the objective truly satisfied?
- Step completion: Any planned step missing?
- Information adequacy: Is evidence sufficient?
- Errors or uncertainty: Any low-confidence result?
- Tool misuse risk: Wrong tool / missing tool?
### Decision Gate
Ask yourself explicitly:
> “If I stop now and call `complete_task`, would a downstream agent or user reasonably say something is missing or wrong?”
If YES → continue with tools
If NO → safe to call `complete_task`
---
# ========== FINAL ACTION ==========
After reflection, emit ONLY ONE of the following:
- A JSON array of tool calls
- OR a single `complete_task` call
Today is {{ today }}. Remember that success in answering questions accurately is paramount - take all necessary steps to ensure your answer is correct.


@ -374,13 +374,13 @@ async def build_chunks(task, progress_callback):
chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
async def gen_metadata_task(chat_mdl, d):
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "metadata")
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "metadata", {})
if not cached:
async with chat_limiter:
cached = await gen_metadata(chat_mdl,
metadata_schema(task["parser_config"]["metadata"]),
d["content_with_weight"])
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "metadata")
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "metadata", {})
if cached:
d["metadata_obj"] = cached
tasks = []
@ -395,9 +395,9 @@ async def build_chunks(task, progress_callback):
await asyncio.gather(*tasks, return_exceptions=True)
raise
metadata = {}
for ck in cks:
metadata = update_metadata_to(metadata, ck["metadata_obj"])
del ck["metadata_obj"]
for doc in docs:
metadata = update_metadata_to(metadata, doc["metadata_obj"])
del doc["metadata_obj"]
if metadata:
e, doc = DocumentService.get_by_id(task["doc_id"])
if e:
@ -506,7 +506,7 @@ def build_TOC(task, docs, progress_callback):
def init_kb(row, vector_size: int):
idxnm = search.index_name(row["tenant_id"])
return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
return settings.docStoreConn.create_idx(idxnm, row.get("kb_id", ""), vector_size)
async def embedding(docs, mdl, parser_config=None, callback=None):
@ -852,7 +852,7 @@ async def do_handle_task(task):
task_tenant_id = task["tenant_id"]
task_embedding_id = task["embd_id"]
task_language = task["language"]
task_llm_id = task["llm_id"]
task_llm_id = task["parser_config"].get("llm_id") or task["llm_id"]
task_dataset_id = task["kb_id"]
task_doc_id = task["doc_id"]
task_document_name = task["name"]
@ -1024,33 +1024,65 @@ async def do_handle_task(task):
chunk_count = len(set([chunk["id"] for chunk in chunks]))
start_ts = timer()
e = await insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback)
if not e:
return
logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page,
task_to_page, len(chunks),
timer() - start_ts))
async def _maybe_insert_es(_chunks):
if has_canceled(task_id):
return True
e = await insert_es(task_id, task_tenant_id, task_dataset_id, _chunks, progress_callback)
return bool(e)
try:
if not await _maybe_insert_es(chunks):
return
DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
logging.info(
"Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(
task_document_name, task_from_page, task_to_page, len(chunks), timer() - start_ts
)
)
time_cost = timer() - start_ts
progress_callback(msg="Indexing done ({:.2f}s).".format(time_cost))
if toc_thread:
d = toc_thread.result()
if d:
e = await insert_es(task_id, task_tenant_id, task_dataset_id, [d], progress_callback)
if not e:
return
DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, 0, 1, 0)
DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
task_time_cost = timer() - task_start_ts
progress_callback(prog=1.0, msg="Task done ({:.2f}s)".format(task_time_cost))
logging.info(
"Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page,
task_to_page, len(chunks),
token_count, task_time_cost))
progress_callback(msg="Indexing done ({:.2f}s).".format(timer() - start_ts))
if toc_thread:
d = toc_thread.result()
if d:
if not await _maybe_insert_es([d]):
return
DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, 0, 1, 0)
if has_canceled(task_id):
progress_callback(-1, msg="Task has been canceled.")
return
task_time_cost = timer() - task_start_ts
progress_callback(prog=1.0, msg="Task done ({:.2f}s)".format(task_time_cost))
logging.info(
"Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(
task_document_name, task_from_page, task_to_page, len(chunks), token_count, task_time_cost
)
)
finally:
if has_canceled(task_id):
try:
exists = await asyncio.to_thread(
settings.docStoreConn.indexExist,
search.index_name(task_tenant_id),
task_dataset_id,
)
if exists:
await asyncio.to_thread(
settings.docStoreConn.delete,
{"doc_id": task_doc_id},
search.index_name(task_tenant_id),
task_dataset_id,
)
except Exception:
logging.exception(
f"Remove doc({task_doc_id}) from docStore failed when task({task_id}) canceled."
)
async def handle_task():


@ -82,7 +82,7 @@ def id2image(image_id:str|None, storage_get_func: partial):
return
bkt, nm = image_id.split("-")
try:
blob = storage_get_func(bucket=bkt, filename=nm)
blob = storage_get_func(bucket=bkt, fnm=nm)
if not blob:
return
return Image.open(BytesIO(blob))


@ -14,194 +14,92 @@
# limitations under the License.
#
import logging
import re
import json
import time
import os
import copy
from elasticsearch import Elasticsearch, NotFoundError
from elasticsearch_dsl import UpdateByQuery, Q, Search, Index
from elasticsearch_dsl import UpdateByQuery, Q, Search
from elastic_transport import ConnectionTimeout
from common.decorator import singleton
from common.file_utils import get_project_base_directory
from common.misc_utils import convert_bytes
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
FusionExpr
from rag.nlp import is_english, rag_tokenizer
from common.doc_store.doc_store_base import MatchTextExpr, OrderByExpr, MatchExpr, MatchDenseExpr, FusionExpr
from common.doc_store.es_conn_base import ESConnectionBase
from common.float_utils import get_float
from common import settings
from common.constants import PAGERANK_FLD, TAG_FLD
ATTEMPT_TIME = 2
logger = logging.getLogger('ragflow.es_conn')
@singleton
class ESConnection(DocStoreConnection):
def __init__(self):
self.info = {}
logger.info(f"Use Elasticsearch {settings.ES['hosts']} as the doc engine.")
for _ in range(ATTEMPT_TIME):
try:
if self._connect():
break
except Exception as e:
logger.warning(f"{str(e)}. Waiting Elasticsearch {settings.ES['hosts']} to be healthy.")
time.sleep(5)
if not self.es.ping():
msg = f"Elasticsearch {settings.ES['hosts']} is unhealthy in 120s."
logger.error(msg)
raise Exception(msg)
v = self.info.get("version", {"number": "8.11.3"})
v = v["number"].split(".")[0]
if int(v) < 8:
msg = f"Elasticsearch version must be greater than or equal to 8, current version: {v}"
logger.error(msg)
raise Exception(msg)
fp_mapping = os.path.join(get_project_base_directory(), "conf", "mapping.json")
if not os.path.exists(fp_mapping):
msg = f"Elasticsearch mapping file not found at {fp_mapping}"
logger.error(msg)
raise Exception(msg)
self.mapping = json.load(open(fp_mapping, "r"))
logger.info(f"Elasticsearch {settings.ES['hosts']} is healthy.")
def _connect(self):
self.es = Elasticsearch(
settings.ES["hosts"].split(","),
basic_auth=(settings.ES["username"], settings.ES[
"password"]) if "username" in settings.ES and "password" in settings.ES else None,
verify_certs= settings.ES.get("verify_certs", False),
timeout=600 )
if self.es:
self.info = self.es.info()
return True
return False
"""
Database operations
"""
def dbType(self) -> str:
return "elasticsearch"
def health(self) -> dict:
health_dict = dict(self.es.cluster.health())
health_dict["type"] = "elasticsearch"
return health_dict
"""
Table operations
"""
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
if self.indexExist(indexName, knowledgebaseId):
return True
try:
from elasticsearch.client import IndicesClient
return IndicesClient(self.es).create(index=indexName,
settings=self.mapping["settings"],
mappings=self.mapping["mappings"])
except Exception:
logger.exception("ESConnection.createIndex error %s" % (indexName))
def deleteIdx(self, indexName: str, knowledgebaseId: str):
if len(knowledgebaseId) > 0:
# The index need to be alive after any kb deletion since all kb under this tenant are in one index.
return
try:
self.es.indices.delete(index=indexName, allow_no_indices=True)
except NotFoundError:
pass
except Exception:
logger.exception("ESConnection.deleteIdx error %s" % (indexName))
def indexExist(self, indexName: str, knowledgebaseId: str = None) -> bool:
s = Index(indexName, self.es)
for i in range(ATTEMPT_TIME):
try:
return s.exists()
except ConnectionTimeout:
logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
logger.exception(e)
break
return False
class ESConnection(ESConnectionBase):
"""
CRUD operations
"""
def search(
self, selectFields: list[str],
highlightFields: list[str],
self, select_fields: list[str],
highlight_fields: list[str],
condition: dict,
matchExprs: list[MatchExpr],
orderBy: OrderByExpr,
match_expressions: list[MatchExpr],
order_by: OrderByExpr,
offset: int,
limit: int,
indexNames: str | list[str],
knowledgebaseIds: list[str],
aggFields: list[str] = [],
index_names: str | list[str],
knowledgebase_ids: list[str],
agg_fields: list[str] | None = None,
rank_feature: dict | None = None
):
"""
Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html
"""
if isinstance(indexNames, str):
indexNames = indexNames.split(",")
assert isinstance(indexNames, list) and len(indexNames) > 0
if isinstance(index_names, str):
index_names = index_names.split(",")
assert isinstance(index_names, list) and len(index_names) > 0
assert "_id" not in condition
bqry = Q("bool", must=[])
condition["kb_id"] = knowledgebaseIds
bool_query = Q("bool", must=[])
condition["kb_id"] = knowledgebase_ids
for k, v in condition.items():
if k == "available_int":
if v == 0:
bqry.filter.append(Q("range", available_int={"lt": 1}))
bool_query.filter.append(Q("range", available_int={"lt": 1}))
else:
bqry.filter.append(
bool_query.filter.append(
Q("bool", must_not=Q("range", available_int={"lt": 1})))
continue
if not v:
continue
if isinstance(v, list):
bqry.filter.append(Q("terms", **{k: v}))
bool_query.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
bqry.filter.append(Q("term", **{k: v}))
bool_query.filter.append(Q("term", **{k: v}))
else:
raise Exception(
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
s = Search()
vector_similarity_weight = 0.5
for m in matchExprs:
for m in match_expressions:
if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
assert len(matchExprs) == 3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance(matchExprs[1],
MatchDenseExpr) and isinstance(
matchExprs[2], FusionExpr)
assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance(match_expressions[1],
MatchDenseExpr) and isinstance(
match_expressions[2], FusionExpr)
weights = m.fusion_params["weights"]
vector_similarity_weight = get_float(weights.split(",")[1])
for m in matchExprs:
for m in match_expressions:
if isinstance(m, MatchTextExpr):
minimum_should_match = m.extra_options.get("minimum_should_match", 0.0)
if isinstance(minimum_should_match, float):
minimum_should_match = str(int(minimum_should_match * 100)) + "%"
bqry.must.append(Q("query_string", fields=m.fields,
bool_query.must.append(Q("query_string", fields=m.fields,
type="best_fields", query=m.matching_text,
minimum_should_match=minimum_should_match,
boost=1))
bqry.boost = 1.0 - vector_similarity_weight
bool_query.boost = 1.0 - vector_similarity_weight
elif isinstance(m, MatchDenseExpr):
assert (bqry is not None)
assert (bool_query is not None)
similarity = 0.0
if "similarity" in m.extra_options:
similarity = m.extra_options["similarity"]
@ -209,24 +107,24 @@ class ESConnection(DocStoreConnection):
m.topn,
m.topn * 2,
query_vector=list(m.embedding_data),
filter=bqry.to_dict(),
filter=bool_query.to_dict(),
similarity=similarity,
)
if bqry and rank_feature:
if bool_query and rank_feature:
for fld, sc in rank_feature.items():
if fld != PAGERANK_FLD:
fld = f"{TAG_FLD}.{fld}"
bqry.should.append(Q("rank_feature", field=fld, linear={}, boost=sc))
bool_query.should.append(Q("rank_feature", field=fld, linear={}, boost=sc))
if bqry:
s = s.query(bqry)
for field in highlightFields:
if bool_query:
s = s.query(bool_query)
for field in highlight_fields:
s = s.highlight(field)
if orderBy:
if order_by:
orders = list()
for field, order in orderBy.fields:
for field, order in order_by.fields:
order = "asc" if order == 0 else "desc"
if field in ["page_num_int", "top_int"]:
order_info = {"order": order, "unmapped_type": "float",
@ -237,19 +135,19 @@ class ESConnection(DocStoreConnection):
order_info = {"order": order, "unmapped_type": "text"}
orders.append({field: order_info})
s = s.sort(*orders)
for fld in aggFields:
s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)
if agg_fields:
for fld in agg_fields:
s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)
if limit > 0:
s = s[offset:offset + limit]
q = s.to_dict()
logger.debug(f"ESConnection.search {str(indexNames)} query: " + json.dumps(q))
self.logger.debug(f"ESConnection.search {str(index_names)} query: " + json.dumps(q))
for i in range(ATTEMPT_TIME):
try:
#print(json.dumps(q, ensure_ascii=False))
res = self.es.search(index=indexNames,
res = self.es.search(index=index_names,
body=q,
timeout="600s",
# search_type="dfs_query_then_fetch",
@ -257,55 +155,37 @@ class ESConnection(DocStoreConnection):
_source=True)
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
logger.debug(f"ESConnection.search {str(indexNames)} res: " + str(res))
self.logger.debug(f"ESConnection.search {str(index_names)} res: " + str(res))
return res
except ConnectionTimeout:
logger.exception("ES request timeout")
self.logger.exception("ES request timeout")
self._connect()
continue
except Exception as e:
logger.exception(f"ESConnection.search {str(indexNames)} query: " + str(q) + str(e))
self.logger.exception(f"ESConnection.search {str(index_names)} query: " + str(q) + str(e))
raise e
logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
raise Exception("ESConnection.search timeout.")
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
for i in range(ATTEMPT_TIME):
try:
res = self.es.get(index=(indexName),
id=chunkId, source=True, )
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
chunk = res["_source"]
chunk["id"] = chunkId
return chunk
except NotFoundError:
return None
except Exception as e:
logger.exception(f"ESConnection.get({chunkId}) got exception")
raise e
logger.error(f"ESConnection.get timeout for {ATTEMPT_TIME} times!")
raise Exception("ESConnection.get timeout.")
def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
def insert(self, documents: list[dict], index_name: str, knowledgebase_id: str = None) -> list[str]:
# Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html
operations = []
for d in documents:
assert "_id" not in d
assert "id" in d
d_copy = copy.deepcopy(d)
d_copy["kb_id"] = knowledgebaseId
d_copy["kb_id"] = knowledgebase_id
meta_id = d_copy.pop("id", "")
operations.append(
{"index": {"_index": indexName, "_id": meta_id}})
{"index": {"_index": index_name, "_id": meta_id}})
operations.append(d_copy)
res = []
for _ in range(ATTEMPT_TIME):
try:
res = []
r = self.es.bulk(index=(indexName), operations=operations,
r = self.es.bulk(index=index_name, operations=operations,
refresh=False, timeout="60s")
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
return res
@ -316,58 +196,58 @@ class ESConnection(DocStoreConnection):
res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"]))
return res
except ConnectionTimeout:
logger.exception("ES request timeout")
self.logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
res.append(str(e))
logger.warning("ESConnection.insert got exception: " + str(e))
self.logger.warning("ESConnection.insert got exception: " + str(e))
return res
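For reference, the bulk payload assembled above alternates an action line with the document body, which is what the Elasticsearch bulk endpoint expects. A minimal sketch with made-up index, id and field values:

operations = [
    {"index": {"_index": "ragflow_tenant1", "_id": "chunk-0001"}},  # action/metadata line (hypothetical names)
    {"kb_id": "kb-42", "content_with_weight": "hello world"},       # document body; the "id" key was popped above
]
# es.bulk(index="ragflow_tenant1", operations=operations, refresh=False, timeout="60s")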
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
doc = copy.deepcopy(newValue)
def update(self, condition: dict, new_value: dict, index_name: str, knowledgebase_id: str) -> bool:
doc = copy.deepcopy(new_value)
doc.pop("id", None)
condition["kb_id"] = knowledgebaseId
condition["kb_id"] = knowledgebase_id
if "id" in condition and isinstance(condition["id"], str):
# update specific single document
chunkId = condition["id"]
chunk_id = condition["id"]
for i in range(ATTEMPT_TIME):
for k in doc.keys():
if "feas" != k.split("_")[-1]:
continue
try:
self.es.update(index=indexName, id=chunkId, script=f"ctx._source.remove(\"{k}\");")
self.es.update(index=index_name, id=chunk_id, script=f"ctx._source.remove(\"{k}\");")
except Exception:
logger.exception(f"ESConnection.update(index={indexName}, id={chunkId}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
self.logger.exception(f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
try:
self.es.update(index=indexName, id=chunkId, doc=doc)
self.es.update(index=index_name, id=chunk_id, doc=doc)
return True
except Exception as e:
logger.exception(
f"ESConnection.update(index={indexName}, id={chunkId}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: "+str(e))
self.logger.exception(
f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: " + str(e))
break
return False
# update unspecific maybe-multiple documents
bqry = Q("bool")
bool_query = Q("bool")
for k, v in condition.items():
if not isinstance(k, str) or not v:
continue
if k == "exists":
bqry.filter.append(Q("exists", field=v))
bool_query.filter.append(Q("exists", field=v))
continue
if isinstance(v, list):
bqry.filter.append(Q("terms", **{k: v}))
bool_query.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
bqry.filter.append(Q("term", **{k: v}))
bool_query.filter.append(Q("term", **{k: v}))
else:
raise Exception(
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
scripts = []
params = {}
for k, v in newValue.items():
for k, v in new_value.items():
if k == "remove":
if isinstance(v, str):
scripts.append(f"ctx._source.remove('{v}');")
@ -397,8 +277,8 @@ class ESConnection(DocStoreConnection):
raise Exception(
f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
ubq = UpdateByQuery(
index=indexName).using(
self.es).query(bqry)
index=index_name).using(
self.es).query(bool_query)
ubq = ubq.script(source="".join(scripts), params=params)
ubq = ubq.params(refresh=True)
ubq = ubq.params(slices=5)
@ -409,19 +289,18 @@ class ESConnection(DocStoreConnection):
_ = ubq.execute()
return True
except ConnectionTimeout:
logger.exception("ES request timeout")
self.logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
logger.error("ESConnection.update got exception: " + str(e) + "\n".join(scripts))
self.logger.error("ESConnection.update got exception: " + str(e) + "\n".join(scripts))
break
return False
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
qry = None
def delete(self, condition: dict, index_name: str, knowledgebase_id: str) -> int:
assert "_id" not in condition
condition["kb_id"] = knowledgebaseId
condition["kb_id"] = knowledgebase_id
if "id" in condition:
chunk_ids = condition["id"]
if not isinstance(chunk_ids, list):
@ -448,21 +327,21 @@ class ESConnection(DocStoreConnection):
qry.must.append(Q("term", **{k: v}))
else:
raise Exception("Condition value must be int, str or list.")
logger.debug("ESConnection.delete query: " + json.dumps(qry.to_dict()))
self.logger.debug("ESConnection.delete query: " + json.dumps(qry.to_dict()))
for _ in range(ATTEMPT_TIME):
try:
res = self.es.delete_by_query(
index=indexName,
index=index_name,
body=Search().query(qry).to_dict(),
refresh=True)
return res["deleted"]
except ConnectionTimeout:
logger.exception("ES request timeout")
self.logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
logger.warning("ESConnection.delete got exception: " + str(e))
self.logger.warning("ESConnection.delete got exception: " + str(e))
if re.search(r"(not_found)", str(e), re.IGNORECASE):
return 0
return 0
@ -471,27 +350,11 @@ class ESConnection(DocStoreConnection):
Helper functions for search result
"""
def get_total(self, res):
if isinstance(res["hits"]["total"], type({})):
return res["hits"]["total"]["value"]
return res["hits"]["total"]
def get_chunk_ids(self, res):
return [d["_id"] for d in res["hits"]["hits"]]
def __getSource(self, res):
rr = []
for d in res["hits"]["hits"]:
d["_source"]["id"] = d["_id"]
d["_source"]["_score"] = d["_score"]
rr.append(d["_source"])
return rr
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
res_fields = {}
if not fields:
return {}
for d in self.__getSource(res):
for d in self._get_source(res):
m = {n: d.get(n) for n in fields if d.get(n) is not None}
for n, v in m.items():
if isinstance(v, list):
@ -508,124 +371,3 @@ class ESConnection(DocStoreConnection):
if m:
res_fields[d["id"]] = m
return res_fields
def get_highlight(self, res, keywords: list[str], fieldnm: str):
ans = {}
for d in res["hits"]["hits"]:
hlts = d.get("highlight")
if not hlts:
continue
txt = "...".join([a for a in list(hlts.items())[0][1]])
if not is_english(txt.split()):
ans[d["_id"]] = txt
continue
txt = d["_source"][fieldnm]
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
txts = []
for t in re.split(r"[.?!;\n]", txt):
for w in keywords:
t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w), r"\1<em>\2</em>\3", t,
flags=re.IGNORECASE | re.MULTILINE)
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
continue
txts.append(t)
ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]])
return ans
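The English branch of get_highlight wraps keyword hits in <em> tags with the regex shown above. A self-contained sketch, using a made-up keyword and sentence:

import re

keyword = "retrieval"
sentence = "Hybrid retrieval combines BM25 and vectors"
pattern = r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(keyword)
print(re.sub(pattern, r"\1<em>\2</em>\3", sentence, flags=re.IGNORECASE | re.MULTILINE))
# Hybrid <em>retrieval</em> combines BM25 and vectors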
def get_aggregation(self, res, fieldnm: str):
agg_field = "aggs_" + fieldnm
if "aggregations" not in res or agg_field not in res["aggregations"]:
return list()
bkts = res["aggregations"][agg_field]["buckets"]
return [(b["key"], b["doc_count"]) for b in bkts]
"""
SQL
"""
def sql(self, sql: str, fetch_size: int, format: str):
logger.debug(f"ESConnection.sql get sql: {sql}")
sql = re.sub(r"[ `]+", " ", sql)
sql = sql.replace("%", "")
replaces = []
for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
fld, v = r.group(1), r.group(3)
match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
replaces.append(
("{}{}'{}'".format(
r.group(1),
r.group(2),
r.group(3)),
match))
for p, r in replaces:
sql = sql.replace(p, r, 1)
logger.debug(f"ESConnection.sql to es: {sql}")
for i in range(ATTEMPT_TIME):
try:
res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format,
request_timeout="2s")
return res
except ConnectionTimeout:
logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
logger.exception(f"ESConnection.sql got exception. SQL:\n{sql}")
raise Exception(f"SQL error: {e}\n\nSQL: {sql}")
logger.error(f"ESConnection.sql timeout for {ATTEMPT_TIME} times!")
return None
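The sql() helper rewrites LIKE/equality predicates on tokenized columns into Elasticsearch MATCH() calls before submitting the query. A standalone sketch of that rewrite; the table and column names are made up, and fake_tokenize stands in for the rag_tokenizer calls:

import re

def fake_tokenize(text: str) -> str:
    # stand-in for rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(...)); illustrative only
    return text.lower()

sql = "select doc_id from ragflow_idx where content_ltks like 'graph database'"
replaces = []
for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
    fld, v = r.group(1), r.group(3)
    match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(fld, fake_tokenize(v))
    replaces.append(("{}{}'{}'".format(r.group(1), r.group(2), r.group(3)), match))
for p, r in replaces:
    sql = sql.replace(p, r, 1)
print(sql)  # prints the query with the predicate replaced by MATCH(content_ltks, 'graph database', ...)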
def get_cluster_stats(self):
"""
curl -XGET "http://{es_host}/_cluster/stats" -H "kbn-xsrf: reporting" to view raw stats.
"""
raw_stats = self.es.cluster.stats()
logger.debug(f"ESConnection.get_cluster_stats: {raw_stats}")
try:
res = {
'cluster_name': raw_stats['cluster_name'],
'status': raw_stats['status']
}
indices_status = raw_stats['indices']
res.update({
'indices': indices_status['count'],
'indices_shards': indices_status['shards']['total']
})
doc_info = indices_status['docs']
res.update({
'docs': doc_info['count'],
'docs_deleted': doc_info['deleted']
})
store_info = indices_status['store']
res.update({
'store_size': convert_bytes(store_info['size_in_bytes']),
'total_dataset_size': convert_bytes(store_info['total_data_set_size_in_bytes'])
})
mappings_info = indices_status['mappings']
res.update({
'mappings_fields': mappings_info['total_field_count'],
'mappings_deduplicated_fields': mappings_info['total_deduplicated_field_count'],
'mappings_deduplicated_size': convert_bytes(mappings_info['total_deduplicated_mapping_size_in_bytes'])
})
node_info = raw_stats['nodes']
res.update({
'nodes': node_info['count']['total'],
'nodes_version': node_info['versions'],
'os_mem': convert_bytes(node_info['os']['mem']['total_in_bytes']),
'os_mem_used': convert_bytes(node_info['os']['mem']['used_in_bytes']),
'os_mem_used_percent': node_info['os']['mem']['used_percent'],
'jvm_versions': node_info['jvm']['versions'][0]['vm_version'],
'jvm_heap_used': convert_bytes(node_info['jvm']['mem']['heap_used_in_bytes']),
'jvm_heap_max': convert_bytes(node_info['jvm']['mem']['heap_max_in_bytes'])
})
return res
except Exception as e:
logger.exception(f"ESConnection.get_cluster_stats: {e}")
return None


@ -14,365 +14,125 @@
# limitations under the License.
#
import logging
import os
import re
import json
import time
import copy
import infinity
from infinity.common import ConflictType, InfinityException, SortType
from infinity.index import IndexInfo, IndexType
from infinity.connection_pool import ConnectionPool
from infinity.common import InfinityException, SortType
from infinity.errors import ErrorCode
from common.decorator import singleton
import pandas as pd
from common.file_utils import get_project_base_directory
from rag.nlp import is_english
from common.constants import PAGERANK_FLD, TAG_FLD
from common import settings
from rag.utils.doc_store_conn import (
DocStoreConnection,
MatchExpr,
MatchTextExpr,
MatchDenseExpr,
FusionExpr,
OrderByExpr,
)
logger = logging.getLogger("ragflow.infinity_conn")
def field_keyword(field_name: str):
# Treat "*_kwd" tag-like columns as keyword lists except knowledge_graph_kwd; source_id is also keyword-like.
if field_name == "source_id" or (field_name.endswith("_kwd") and field_name not in ["knowledge_graph_kwd", "docnm_kwd", "important_kwd", "question_kwd"]):
return True
return False
def convert_select_fields(output_fields: list[str]) -> list[str]:
for i, field in enumerate(output_fields):
if field in ["docnm_kwd", "title_tks", "title_sm_tks"]:
output_fields[i] = "docnm"
elif field in ["important_kwd", "important_tks"]:
output_fields[i] = "important_keywords"
elif field in ["question_kwd", "question_tks"]:
output_fields[i] = "questions"
elif field in ["content_with_weight", "content_ltks", "content_sm_ltks"]:
output_fields[i] = "content"
elif field in ["authors_tks", "authors_sm_tks"]:
output_fields[i] = "authors"
return list(set(output_fields))
def convert_matching_field(field_weightstr: str) -> str:
tokens = field_weightstr.split("^")
field = tokens[0]
if field == "docnm_kwd" or field == "title_tks":
field = "docnm@ft_docnm_rag_coarse"
elif field == "title_sm_tks":
field = "docnm@ft_docnm_rag_fine"
elif field == "important_kwd":
field = "important_keywords@ft_important_keywords_rag_coarse"
elif field == "important_tks":
field = "important_keywords@ft_important_keywords_rag_fine"
elif field == "question_kwd":
field = "questions@ft_questions_rag_coarse"
elif field == "question_tks":
field = "questions@ft_questions_rag_fine"
elif field == "content_with_weight" or field == "content_ltks":
field = "content@ft_content_rag_coarse"
elif field == "content_sm_ltks":
field = "content@ft_content_rag_fine"
elif field == "authors_tks":
field = "authors@ft_authors_rag_coarse"
elif field == "authors_sm_tks":
field = "authors@ft_authors_rag_fine"
tokens[0] = field
return "^".join(tokens)
def list2str(lst: str|list, sep: str = " ") -> str:
if isinstance(lst, str):
return lst
return sep.join(lst)
def equivalent_condition_to_str(condition: dict, table_instance=None) -> str | None:
assert "_id" not in condition
clmns = {}
if table_instance:
for n, ty, de, _ in table_instance.show_columns().rows():
clmns[n] = (ty, de)
def exists(cln):
nonlocal clmns
assert cln in clmns, f"'{cln}' should be in '{clmns}'."
ty, de = clmns[cln]
if ty.lower().find("cha"):
if not de:
de = ""
return f" {cln}!='{de}' "
return f"{cln}!={de}"
cond = list()
for k, v in condition.items():
if not isinstance(k, str) or not v:
continue
if field_keyword(k):
if isinstance(v, list):
inCond = list()
for item in v:
if isinstance(item, str):
item = item.replace("'", "''")
inCond.append(f"filter_fulltext('{convert_matching_field(k)}', '{item}')")
if inCond:
strInCond = " or ".join(inCond)
strInCond = f"({strInCond})"
cond.append(strInCond)
else:
cond.append(f"filter_fulltext('{convert_matching_field(k)}', '{v}')")
elif isinstance(v, list):
inCond = list()
for item in v:
if isinstance(item, str):
item = item.replace("'", "''")
inCond.append(f"'{item}'")
else:
inCond.append(str(item))
if inCond:
strInCond = ", ".join(inCond)
strInCond = f"{k} IN ({strInCond})"
cond.append(strInCond)
elif k == "must_not":
if isinstance(v, dict):
for kk, vv in v.items():
if kk == "exists":
cond.append("NOT (%s)" % exists(vv))
elif isinstance(v, str):
cond.append(f"{k}='{v}'")
elif k == "exists":
cond.append(exists(v))
else:
cond.append(f"{k}={str(v)}")
return " AND ".join(cond) if cond else "1=1"
def concat_dataframes(df_list: list[pd.DataFrame], selectFields: list[str]) -> pd.DataFrame:
df_list2 = [df for df in df_list if not df.empty]
if df_list2:
return pd.concat(df_list2, axis=0).reset_index(drop=True)
schema = []
for field_name in selectFields:
if field_name == "score()": # Workaround: fix schema is changed to score()
schema.append("SCORE")
elif field_name == "similarity()": # Workaround: fix schema is changed to similarity()
schema.append("SIMILARITY")
else:
schema.append(field_name)
return pd.DataFrame(columns=schema)
from common.doc_store.doc_store_base import MatchExpr, MatchTextExpr, MatchDenseExpr, FusionExpr, OrderByExpr
from common.doc_store.infinity_conn_base import InfinityConnectionBase
@singleton
class InfinityConnection(DocStoreConnection):
def __init__(self):
self.dbName = settings.INFINITY.get("db_name", "default_db")
infinity_uri = settings.INFINITY["uri"]
if ":" in infinity_uri:
host, port = infinity_uri.split(":")
infinity_uri = infinity.common.NetworkAddress(host, int(port))
self.connPool = None
logger.info(f"Use Infinity {infinity_uri} as the doc engine.")
for _ in range(24):
try:
connPool = ConnectionPool(infinity_uri, max_size=4)
inf_conn = connPool.get_conn()
res = inf_conn.show_current_node()
if res.error_code == ErrorCode.OK and res.server_status in ["started", "alive"]:
self._migrate_db(inf_conn)
self.connPool = connPool
connPool.release_conn(inf_conn)
break
connPool.release_conn(inf_conn)
logger.warn(f"Infinity status: {res.server_status}. Waiting Infinity {infinity_uri} to be healthy.")
time.sleep(5)
except Exception as e:
logger.warning(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.")
time.sleep(5)
if self.connPool is None:
msg = f"Infinity {infinity_uri} is unhealthy in 120s."
logger.error(msg)
raise Exception(msg)
logger.info(f"Infinity {infinity_uri} is healthy.")
def _migrate_db(self, inf_conn):
inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
fp_mapping = os.path.join(get_project_base_directory(), "conf", "infinity_mapping.json")
if not os.path.exists(fp_mapping):
raise Exception(f"Mapping file not found at {fp_mapping}")
schema = json.load(open(fp_mapping))
table_names = inf_db.list_tables().table_names
for table_name in table_names:
inf_table = inf_db.get_table(table_name)
index_names = inf_table.list_indexes().index_names
if "q_vec_idx" not in index_names:
# Skip tables not created by me
continue
column_names = inf_table.show_columns()["name"]
column_names = set(column_names)
for field_name, field_info in schema.items():
if field_name in column_names:
continue
res = inf_table.add_columns({field_name: field_info})
assert res.error_code == infinity.ErrorCode.OK
logger.info(f"INFINITY added following column to table {table_name}: {field_name} {field_info}")
if field_info["type"] != "varchar" or "analyzer" not in field_info:
continue
analyzers = field_info["analyzer"]
if isinstance(analyzers, str):
analyzers = [analyzers]
for analyzer in analyzers:
inf_table.create_index(
f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}",
IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}),
ConflictType.Ignore,
)
class InfinityConnection(InfinityConnectionBase):
"""
Database operations
Dataframe and fields convert
"""
def dbType(self) -> str:
return "infinity"
def health(self) -> dict:
"""
Return the health status of the database.
"""
inf_conn = self.connPool.get_conn()
res = inf_conn.show_current_node()
self.connPool.release_conn(inf_conn)
res2 = {
"type": "infinity",
"status": "green" if res.error_code == 0 and res.server_status in ["started", "alive"] else "red",
"error": res.error_msg,
}
return res2
"""
Table operations
"""
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
table_name = f"{indexName}_{knowledgebaseId}"
inf_conn = self.connPool.get_conn()
inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
fp_mapping = os.path.join(get_project_base_directory(), "conf", "infinity_mapping.json")
if not os.path.exists(fp_mapping):
raise Exception(f"Mapping file not found at {fp_mapping}")
schema = json.load(open(fp_mapping))
vector_name = f"q_{vectorSize}_vec"
schema[vector_name] = {"type": f"vector,{vectorSize},float"}
inf_table = inf_db.create_table(
table_name,
schema,
ConflictType.Ignore,
)
inf_table.create_index(
"q_vec_idx",
IndexInfo(
vector_name,
IndexType.Hnsw,
{
"M": "16",
"ef_construction": "50",
"metric": "cosine",
"encode": "lvq",
},
),
ConflictType.Ignore,
)
for field_name, field_info in schema.items():
if field_info["type"] != "varchar" or "analyzer" not in field_info:
continue
analyzers = field_info["analyzer"]
if isinstance(analyzers, str):
analyzers = [analyzers]
for analyzer in analyzers:
inf_table.create_index(
f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}",
IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}),
ConflictType.Ignore,
)
self.connPool.release_conn(inf_conn)
logger.info(f"INFINITY created table {table_name}, vector size {vectorSize}")
def deleteIdx(self, indexName: str, knowledgebaseId: str):
table_name = f"{indexName}_{knowledgebaseId}"
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
db_instance.drop_table(table_name, ConflictType.Ignore)
self.connPool.release_conn(inf_conn)
logger.info(f"INFINITY dropped table {table_name}")
def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
table_name = f"{indexName}_{knowledgebaseId}"
try:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
_ = db_instance.get_table(table_name)
self.connPool.release_conn(inf_conn)
@staticmethod
def field_keyword(field_name: str):
# Treat "*_kwd" tag-like columns as keyword lists except knowledge_graph_kwd; source_id is also keyword-like.
if field_name == "source_id" or (
field_name.endswith("_kwd") and field_name not in ["knowledge_graph_kwd", "docnm_kwd", "important_kwd",
"question_kwd"]):
return True
except Exception as e:
logger.warning(f"INFINITY indexExist {str(e)}")
return False
def convert_select_fields(self, output_fields: list[str]) -> list[str]:
for i, field in enumerate(output_fields):
if field in ["docnm_kwd", "title_tks", "title_sm_tks"]:
output_fields[i] = "docnm"
elif field in ["important_kwd", "important_tks"]:
output_fields[i] = "important_keywords"
elif field in ["question_kwd", "question_tks"]:
output_fields[i] = "questions"
elif field in ["content_with_weight", "content_ltks", "content_sm_ltks"]:
output_fields[i] = "content"
elif field in ["authors_tks", "authors_sm_tks"]:
output_fields[i] = "authors"
return list(set(output_fields))
@staticmethod
def convert_matching_field(field_weight_str: str) -> str:
tokens = field_weight_str.split("^")
field = tokens[0]
if field == "docnm_kwd" or field == "title_tks":
field = "docnm@ft_docnm_rag_coarse"
elif field == "title_sm_tks":
field = "docnm@ft_docnm_rag_fine"
elif field == "important_kwd":
field = "important_keywords@ft_important_keywords_rag_coarse"
elif field == "important_tks":
field = "important_keywords@ft_important_keywords_rag_fine"
elif field == "question_kwd":
field = "questions@ft_questions_rag_coarse"
elif field == "question_tks":
field = "questions@ft_questions_rag_fine"
elif field == "content_with_weight" or field == "content_ltks":
field = "content@ft_content_rag_coarse"
elif field == "content_sm_ltks":
field = "content@ft_content_rag_fine"
elif field == "authors_tks":
field = "authors@ft_authors_rag_coarse"
elif field == "authors_sm_tks":
field = "authors@ft_authors_rag_fine"
tokens[0] = field
return "^".join(tokens)
"""
CRUD operations
"""
def search(
self,
selectFields: list[str],
highlightFields: list[str],
select_fields: list[str],
highlight_fields: list[str],
condition: dict,
matchExprs: list[MatchExpr],
orderBy: OrderByExpr,
match_expressions: list[MatchExpr],
order_by: OrderByExpr,
offset: int,
limit: int,
indexNames: str | list[str],
knowledgebaseIds: list[str],
aggFields: list[str] = [],
index_names: str | list[str],
knowledgebase_ids: list[str],
agg_fields: list[str] | None = None,
rank_feature: dict | None = None,
) -> tuple[pd.DataFrame, int]:
"""
BUG: Infinity returns empty for a highlight field if the query string doesn't use that field.
"""
if isinstance(indexNames, str):
indexNames = indexNames.split(",")
assert isinstance(indexNames, list) and len(indexNames) > 0
if isinstance(index_names, str):
index_names = index_names.split(",")
assert isinstance(index_names, list) and len(index_names) > 0
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
df_list = list()
table_list = list()
output = selectFields.copy()
output = convert_select_fields(output)
for essential_field in ["id"] + aggFields:
output = select_fields.copy()
output = self.convert_select_fields(output)
if agg_fields is None:
agg_fields = []
for essential_field in ["id"] + agg_fields:
if essential_field not in output:
output.append(essential_field)
score_func = ""
score_column = ""
for matchExpr in matchExprs:
for matchExpr in match_expressions:
if isinstance(matchExpr, MatchTextExpr):
score_func = "score()"
score_column = "SCORE"
break
if not score_func:
for matchExpr in matchExprs:
for matchExpr in match_expressions:
if isinstance(matchExpr, MatchDenseExpr):
score_func = "similarity()"
score_column = "SIMILARITY"
break
if matchExprs:
if match_expressions:
if score_func not in output:
output.append(score_func)
if PAGERANK_FLD not in output:
@ -387,11 +147,11 @@ class InfinityConnection(DocStoreConnection):
filter_fulltext = ""
if condition:
table_found = False
for indexName in indexNames:
for kb_id in knowledgebaseIds:
for indexName in index_names:
for kb_id in knowledgebase_ids:
table_name = f"{indexName}_{kb_id}"
try:
filter_cond = equivalent_condition_to_str(condition, db_instance.get_table(table_name))
filter_cond = self.equivalent_condition_to_str(condition, db_instance.get_table(table_name))
table_found = True
break
except Exception:
@ -399,14 +159,14 @@ class InfinityConnection(DocStoreConnection):
if table_found:
break
if not table_found:
logger.error(f"No valid tables found for indexNames {indexNames} and knowledgebaseIds {knowledgebaseIds}")
self.logger.error(f"No valid tables found for indexNames {index_names} and knowledgebaseIds {knowledgebase_ids}")
return pd.DataFrame(), 0
for matchExpr in matchExprs:
for matchExpr in match_expressions:
if isinstance(matchExpr, MatchTextExpr):
if filter_cond and "filter" not in matchExpr.extra_options:
matchExpr.extra_options.update({"filter": filter_cond})
matchExpr.fields = [convert_matching_field(field) for field in matchExpr.fields]
matchExpr.fields = [self.convert_matching_field(field) for field in matchExpr.fields]
fields = ",".join(matchExpr.fields)
filter_fulltext = f"filter_fulltext('{fields}', '{matchExpr.matching_text}')"
if filter_cond:
@ -430,7 +190,7 @@ class InfinityConnection(DocStoreConnection):
for k, v in matchExpr.extra_options.items():
if not isinstance(v, str):
matchExpr.extra_options[k] = str(v)
logger.debug(f"INFINITY search MatchTextExpr: {json.dumps(matchExpr.__dict__)}")
self.logger.debug(f"INFINITY search MatchTextExpr: {json.dumps(matchExpr.__dict__)}")
elif isinstance(matchExpr, MatchDenseExpr):
if filter_fulltext and "filter" not in matchExpr.extra_options:
matchExpr.extra_options.update({"filter": filter_fulltext})
@ -441,16 +201,16 @@ class InfinityConnection(DocStoreConnection):
if similarity:
matchExpr.extra_options["threshold"] = similarity
del matchExpr.extra_options["similarity"]
logger.debug(f"INFINITY search MatchDenseExpr: {json.dumps(matchExpr.__dict__)}")
self.logger.debug(f"INFINITY search MatchDenseExpr: {json.dumps(matchExpr.__dict__)}")
elif isinstance(matchExpr, FusionExpr):
if matchExpr.method == "weighted_sum":
# The default is "minmax" which gives a zero score for the last doc.
matchExpr.fusion_params["normalize"] = "atan"
logger.debug(f"INFINITY search FusionExpr: {json.dumps(matchExpr.__dict__)}")
self.logger.debug(f"INFINITY search FusionExpr: {json.dumps(matchExpr.__dict__)}")
order_by_expr_list = list()
if orderBy.fields:
for order_field in orderBy.fields:
if order_by.fields:
for order_field in order_by.fields:
if order_field[1] == 0:
order_by_expr_list.append((order_field[0], SortType.Asc))
else:
@ -458,8 +218,8 @@ class InfinityConnection(DocStoreConnection):
total_hits_count = 0
# Scatter search tables and gather the results
for indexName in indexNames:
for knowledgebaseId in knowledgebaseIds:
for indexName in index_names:
for knowledgebaseId in knowledgebase_ids:
table_name = f"{indexName}_{knowledgebaseId}"
try:
table_instance = db_instance.get_table(table_name)
@ -467,8 +227,8 @@ class InfinityConnection(DocStoreConnection):
continue
table_list.append(table_name)
builder = table_instance.output(output)
if len(matchExprs) > 0:
for matchExpr in matchExprs:
if len(match_expressions) > 0:
for matchExpr in match_expressions:
if isinstance(matchExpr, MatchTextExpr):
fields = ",".join(matchExpr.fields)
builder = builder.match_text(
@ -491,53 +251,52 @@ class InfinityConnection(DocStoreConnection):
else:
if filter_cond and len(filter_cond) > 0:
builder.filter(filter_cond)
if orderBy.fields:
if order_by.fields:
builder.sort(order_by_expr_list)
builder.offset(offset).limit(limit)
kb_res, extra_result = builder.option({"total_hits_count": True}).to_df()
if extra_result:
total_hits_count += int(extra_result["total_hits_count"])
logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(kb_res)}")
self.logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(kb_res)}")
df_list.append(kb_res)
self.connPool.release_conn(inf_conn)
res = concat_dataframes(df_list, output)
if matchExprs:
res = self.concat_dataframes(df_list, output)
if match_expressions:
res["_score"] = res[score_column] + res[PAGERANK_FLD]
res = res.sort_values(by="_score", ascending=False).reset_index(drop=True)
res = res.head(limit)
logger.debug(f"INFINITY search final result: {str(res)}")
self.logger.debug(f"INFINITY search final result: {str(res)}")
return res, total_hits_count
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
def get(self, chunk_id: str, index_name: str, knowledgebase_ids: list[str]) -> dict | None:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
df_list = list()
assert isinstance(knowledgebaseIds, list)
assert isinstance(knowledgebase_ids, list)
table_list = list()
for knowledgebaseId in knowledgebaseIds:
table_name = f"{indexName}_{knowledgebaseId}"
for knowledgebaseId in knowledgebase_ids:
table_name = f"{index_name}_{knowledgebaseId}"
table_list.append(table_name)
table_instance = None
try:
table_instance = db_instance.get_table(table_name)
except Exception:
logger.warning(f"Table not found: {table_name}, this dataset isn't created in Infinity. Maybe it is created in other document engine.")
self.logger.warning(f"Table not found: {table_name}, this dataset isn't created in Infinity. Maybe it is created in other document engine.")
continue
kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_df()
logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}")
kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunk_id}'").to_df()
self.logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}")
df_list.append(kb_res)
self.connPool.release_conn(inf_conn)
res = concat_dataframes(df_list, ["id"])
res = self.concat_dataframes(df_list, ["id"])
fields = set(res.columns.tolist())
for field in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "question_kwd", "question_tks","content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks"]:
fields.add(field)
res_fields = self.get_fields(res, list(fields))
return res_fields.get(chunkId, None)
return res_fields.get(chunk_id, None)
def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
def insert(self, documents: list[dict], index_name: str, knowledgebase_id: str = None) -> list[str]:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{indexName}_{knowledgebaseId}"
table_name = f"{index_name}_{knowledgebase_id}"
try:
table_instance = db_instance.get_table(table_name)
except InfinityException as e:
@ -553,7 +312,7 @@ class InfinityConnection(DocStoreConnection):
break
if vector_size == 0:
raise ValueError("Cannot infer vector size from documents")
self.createIdx(indexName, knowledgebaseId, vector_size)
self.create_idx(index_name, knowledgebase_id, vector_size)
table_instance = db_instance.get_table(table_name)
# embedding fields can't have a default value....
@ -574,12 +333,12 @@ class InfinityConnection(DocStoreConnection):
d["docnm"] = v
elif k == "title_kwd":
if not d.get("docnm_kwd"):
d["docnm"] = list2str(v)
d["docnm"] = self.list2str(v)
elif k == "title_sm_tks":
if not d.get("docnm_kwd"):
d["docnm"] = list2str(v)
d["docnm"] = self.list2str(v)
elif k == "important_kwd":
d["important_keywords"] = list2str(v)
d["important_keywords"] = self.list2str(v)
elif k == "important_tks":
if not d.get("important_kwd"):
d["important_keywords"] = v
@ -597,11 +356,11 @@ class InfinityConnection(DocStoreConnection):
if not d.get("authors_tks"):
d["authors"] = v
elif k == "question_kwd":
d["questions"] = list2str(v, "\n")
d["questions"] = self.list2str(v, "\n")
elif k == "question_tks":
if not d.get("question_kwd"):
d["questions"] = list2str(v)
elif field_keyword(k):
d["questions"] = self.list2str(v)
elif self.field_keyword(k):
if isinstance(v, list):
d[k] = "###".join(v)
else:
@ -637,15 +396,15 @@ class InfinityConnection(DocStoreConnection):
# logger.info(f"InfinityConnection.insert {json.dumps(documents)}")
table_instance.insert(docs)
self.connPool.release_conn(inf_conn)
logger.debug(f"INFINITY inserted into {table_name} {str_ids}.")
self.logger.debug(f"INFINITY inserted into {table_name} {str_ids}.")
return []
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
def update(self, condition: dict, new_value: dict, index_name: str, knowledgebase_id: str) -> bool:
# if 'position_int' in newValue:
# logger.info(f"update position_int: {newValue['position_int']}")
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{indexName}_{knowledgebaseId}"
table_name = f"{index_name}_{knowledgebase_id}"
table_instance = db_instance.get_table(table_name)
# if "exists" in condition:
# del condition["exists"]
@ -654,57 +413,57 @@ class InfinityConnection(DocStoreConnection):
if table_instance:
for n, ty, de, _ in table_instance.show_columns().rows():
clmns[n] = (ty, de)
filter = equivalent_condition_to_str(condition, table_instance)
filter = self.equivalent_condition_to_str(condition, table_instance)
removeValue = {}
for k, v in list(newValue.items()):
for k, v in list(new_value.items()):
if k == "docnm_kwd":
newValue["docnm"] = list2str(v)
new_value["docnm"] = self.list2str(v)
elif k == "title_kwd":
if not newValue.get("docnm_kwd"):
newValue["docnm"] = list2str(v)
if not new_value.get("docnm_kwd"):
new_value["docnm"] = self.list2str(v)
elif k == "title_sm_tks":
if not newValue.get("docnm_kwd"):
newValue["docnm"] = v
if not new_value.get("docnm_kwd"):
new_value["docnm"] = v
elif k == "important_kwd":
newValue["important_keywords"] = list2str(v)
new_value["important_keywords"] = self.list2str(v)
elif k == "important_tks":
if not newValue.get("important_kwd"):
newValue["important_keywords"] = v
if not new_value.get("important_kwd"):
new_value["important_keywords"] = v
elif k == "content_with_weight":
newValue["content"] = v
new_value["content"] = v
elif k == "content_ltks":
if not newValue.get("content_with_weight"):
newValue["content"] = v
if not new_value.get("content_with_weight"):
new_value["content"] = v
elif k == "content_sm_ltks":
if not newValue.get("content_with_weight"):
newValue["content"] = v
if not new_value.get("content_with_weight"):
new_value["content"] = v
elif k == "authors_tks":
newValue["authors"] = v
new_value["authors"] = v
elif k == "authors_sm_tks":
if not newValue.get("authors_tks"):
newValue["authors"] = v
if not new_value.get("authors_tks"):
new_value["authors"] = v
elif k == "question_kwd":
newValue["questions"] = "\n".join(v)
new_value["questions"] = "\n".join(v)
elif k == "question_tks":
if not newValue.get("question_kwd"):
newValue["questions"] = list2str(v)
elif field_keyword(k):
if not new_value.get("question_kwd"):
new_value["questions"] = self.list2str(v)
elif self.field_keyword(k):
if isinstance(v, list):
newValue[k] = "###".join(v)
new_value[k] = "###".join(v)
else:
newValue[k] = v
new_value[k] = v
elif re.search(r"_feas$", k):
newValue[k] = json.dumps(v)
new_value[k] = json.dumps(v)
elif k == "kb_id":
if isinstance(newValue[k], list):
newValue[k] = newValue[k][0] # since d[k] is a list, but we need a str
if isinstance(new_value[k], list):
new_value[k] = new_value[k][0] # since d[k] is a list, but we need a str
elif k == "position_int":
assert isinstance(v, list)
arr = [num for row in v for num in row]
newValue[k] = "_".join(f"{num:08x}" for num in arr)
new_value[k] = "_".join(f"{num:08x}" for num in arr)
elif k in ["page_num_int", "top_int"]:
assert isinstance(v, list)
newValue[k] = "_".join(f"{num:08x}" for num in v)
new_value[k] = "_".join(f"{num:08x}" for num in v)
elif k == "remove":
if isinstance(v, str):
assert v in clmns, f"'{v}' should be in '{clmns}'."
@ -712,22 +471,22 @@ class InfinityConnection(DocStoreConnection):
if ty.lower().find("cha"):
if not de:
de = ""
newValue[v] = de
new_value[v] = de
else:
for kk, vv in v.items():
removeValue[kk] = vv
del newValue[k]
del new_value[k]
else:
newValue[k] = v
new_value[k] = v
for k in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", "question_kwd", "question_tks"]:
if k in newValue:
del newValue[k]
if k in new_value:
del new_value[k]
remove_opt = {} # "[k,new_value]": [id_to_update, ...]
if removeValue:
col_to_remove = list(removeValue.keys())
row_to_opt = table_instance.output(col_to_remove + ["id"]).filter(filter).to_df()
logger.debug(f"INFINITY search table {str(table_name)}, filter {filter}, result: {str(row_to_opt[0])}")
self.logger.debug(f"INFINITY search table {str(table_name)}, filter {filter}, result: {str(row_to_opt[0])}")
row_to_opt = self.get_fields(row_to_opt, col_to_remove)
for id, old_v in row_to_opt.items():
for k, remove_v in removeValue.items():
@ -740,78 +499,53 @@ class InfinityConnection(DocStoreConnection):
else:
remove_opt[kv_key].append(id)
logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {newValue}.")
self.logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {new_value}.")
for update_kv, ids in remove_opt.items():
k, v = json.loads(update_kv)
table_instance.update(filter + " AND id in ({0})".format(",".join([f"'{id}'" for id in ids])), {k: "###".join(v)})
table_instance.update(filter, newValue)
table_instance.update(filter, new_value)
self.connPool.release_conn(inf_conn)
return True
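Since Infinity columns are flat, list-valued chunk fields are serialized before the update: generic *_kwd tag fields become "###"-joined strings and position lists become fixed-width hex. A tiny illustration with made-up values:

tag_kwd = ["science", "physics"]
print("###".join(tag_kwd))                    # science###physics

position_int = [[5, 10, 20, 30, 40]]          # one row of illustrative coordinates
arr = [n for row in position_int for n in row]
print("_".join(f"{n:08x}" for n in arr))      # 00000005_0000000a_00000014_0000001e_00000028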
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{indexName}_{knowledgebaseId}"
try:
table_instance = db_instance.get_table(table_name)
except Exception:
logger.warning(f"Skipped deleting from table {table_name} since the table doesn't exist.")
return 0
filter = equivalent_condition_to_str(condition, table_instance)
logger.debug(f"INFINITY delete table {table_name}, filter {filter}.")
res = table_instance.delete(filter)
self.connPool.release_conn(inf_conn)
return res.deleted_rows
"""
Helper functions for search result
"""
def get_total(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
if isinstance(res, tuple):
return res[1]
return len(res)
def get_chunk_ids(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
if isinstance(res, tuple):
res = res[0]
return list(res["id"])
def get_fields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
if isinstance(res, tuple):
res = res[0]
if not fields:
return {}
fieldsAll = fields.copy()
fieldsAll.append("id")
fieldsAll = set(fieldsAll)
fields_all = fields.copy()
fields_all.append("id")
fields_all = set(fields_all)
if "docnm" in res.columns:
for field in ["docnm_kwd", "title_tks", "title_sm_tks"]:
if field in fieldsAll:
if field in fields_all:
res[field] = res["docnm"]
if "important_keywords" in res.columns:
if "important_kwd" in fieldsAll:
if "important_kwd" in fields_all:
res["important_kwd"] = res["important_keywords"].apply(lambda v: v.split())
if "important_tks" in fieldsAll:
if "important_tks" in fields_all:
res["important_tks"] = res["important_keywords"]
if "questions" in res.columns:
if "question_kwd" in fieldsAll:
if "question_kwd" in fields_all:
res["question_kwd"] = res["questions"].apply(lambda v: v.splitlines())
if "question_tks" in fieldsAll:
if "question_tks" in fields_all:
res["question_tks"] = res["questions"]
if "content" in res.columns:
for field in ["content_with_weight", "content_ltks", "content_sm_ltks"]:
if field in fieldsAll:
if field in fields_all:
res[field] = res["content"]
if "authors" in res.columns:
for field in ["authors_tks", "authors_sm_tks"]:
if field in fieldsAll:
if field in fields_all:
res[field] = res["authors"]
column_map = {col.lower(): col for col in res.columns}
matched_columns = {column_map[col.lower()]: col for col in fieldsAll if col.lower() in column_map}
none_columns = [col for col in fieldsAll if col.lower() not in column_map]
matched_columns = {column_map[col.lower()]: col for col in fields_all if col.lower() in column_map}
none_columns = [col for col in fields_all if col.lower() not in column_map]
res2 = res[matched_columns.keys()]
res2 = res2.rename(columns=matched_columns)
@ -819,7 +553,7 @@ class InfinityConnection(DocStoreConnection):
for column in list(res2.columns):
k = column.lower()
if field_keyword(k):
if self.field_keyword(k):
res2[column] = res2[column].apply(lambda v: [kwd for kwd in v.split("###") if kwd])
elif re.search(r"_feas$", k):
res2[column] = res2[column].apply(lambda v: json.loads(v) if v else {})
@ -844,95 +578,3 @@ class InfinityConnection(DocStoreConnection):
res2[column] = None
return res2.set_index("id").to_dict(orient="index")
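The final reshaping step above relies on pandas: the result frame is keyed by chunk id and turned into a dict of field dicts. A self-contained illustration with toy data:

import pandas as pd

df = pd.DataFrame({"id": ["c1", "c2"], "content": ["foo", "bar"], "available_int": [1, 0]})
print(df.set_index("id").to_dict(orient="index"))
# {'c1': {'content': 'foo', 'available_int': 1}, 'c2': {'content': 'bar', 'available_int': 0}}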
def get_highlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], fieldnm: str):
if isinstance(res, tuple):
res = res[0]
ans = {}
num_rows = len(res)
column_id = res["id"]
if fieldnm not in res:
return {}
for i in range(num_rows):
id = column_id[i]
txt = res[fieldnm][i]
if re.search(r"<em>[^<>]+</em>", txt, flags=re.IGNORECASE | re.MULTILINE):
ans[id] = txt
continue
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
txts = []
for t in re.split(r"[.?!;\n]", txt):
if is_english([t]):
for w in keywords:
t = re.sub(
r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w),
r"\1<em>\2</em>\3",
t,
flags=re.IGNORECASE | re.MULTILINE,
)
else:
for w in sorted(keywords, key=len, reverse=True):
t = re.sub(
re.escape(w),
f"<em>{w}</em>",
t,
flags=re.IGNORECASE | re.MULTILINE,
)
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
continue
txts.append(t)
if txts:
ans[id] = "...".join(txts)
else:
ans[id] = txt
return ans
def get_aggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str):
"""
Manual aggregation for tag fields since Infinity doesn't provide native aggregation
"""
from collections import Counter
# Extract DataFrame from result
if isinstance(res, tuple):
df, _ = res
else:
df = res
if df.empty or fieldnm not in df.columns:
return []
# Aggregate tag counts
tag_counter = Counter()
for value in df[fieldnm]:
if pd.isna(value) or not value:
continue
# Handle different tag formats
if isinstance(value, str):
# Split by ### for tag_kwd field or comma for other formats
if fieldnm == "tag_kwd" and "###" in value:
tags = [tag.strip() for tag in value.split("###") if tag.strip()]
else:
# Try comma separation as fallback
tags = [tag.strip() for tag in value.split(",") if tag.strip()]
for tag in tags:
if tag: # Only count non-empty tags
tag_counter[tag] += 1
elif isinstance(value, list):
# Handle list format
for tag in value:
if tag and isinstance(tag, str):
tag_counter[tag.strip()] += 1
# Return as list of [tag, count] pairs, sorted by count descending
return [[tag, count] for tag, count in tag_counter.most_common()]
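The same counting logic in isolation, on a toy frame whose tag_kwd column uses the "###" separator described above:

from collections import Counter
import pandas as pd

df = pd.DataFrame({"tag_kwd": ["ai###rag", "rag", None]})
counter = Counter()
for value in df["tag_kwd"]:
    if pd.isna(value) or not value:
        continue
    tags = value.split("###") if "###" in value else value.split(",")
    for tag in (t.strip() for t in tags):
        if tag:
            counter[tag] += 1
print(counter.most_common())  # [('rag', 2), ('ai', 1)]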
"""
SQL
"""
def sql(sql: str, fetch_size: int, format: str):
raise NotImplementedError("Not implemented")


@ -37,9 +37,8 @@ from common import settings
from common.constants import PAGERANK_FLD, TAG_FLD
from common.decorator import singleton
from common.float_utils import get_float
from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, MatchDenseExpr
from rag.nlp import rag_tokenizer
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, \
MatchDenseExpr
ATTEMPT_TIME = 2
OB_QUERY_TIMEOUT = int(os.environ.get("OB_QUERY_TIMEOUT", "100_000_000"))
@ -497,7 +496,7 @@ class OBConnection(DocStoreConnection):
Database operations
"""
def dbType(self) -> str:
def db_type(self) -> str:
return "oceanbase"
def health(self) -> dict:
@ -553,7 +552,7 @@ class OBConnection(DocStoreConnection):
Table operations
"""
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
def create_idx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
vector_field_name = f"q_{vectorSize}_vec"
vector_index_name = f"{vector_field_name}_idx"
@ -604,7 +603,7 @@ class OBConnection(DocStoreConnection):
# always refresh metadata to make sure it contains the latest table structure
self.client.refresh_metadata([indexName])
def deleteIdx(self, indexName: str, knowledgebaseId: str):
def delete_idx(self, indexName: str, knowledgebaseId: str):
if len(knowledgebaseId) > 0:
# The index need to be alive after any kb deletion since all kb under this tenant are in one index.
return
@ -615,7 +614,7 @@ class OBConnection(DocStoreConnection):
except Exception as e:
raise Exception(f"OBConnection.deleteIndex error: {str(e)}")
def indexExist(self, indexName: str, knowledgebaseId: str = None) -> bool:
def index_exist(self, indexName: str, knowledgebaseId: str = None) -> bool:
return self._check_table_exists_cached(indexName)
def _get_count(self, table_name: str, filter_list: list[str] = None) -> int:
@ -1500,7 +1499,7 @@ class OBConnection(DocStoreConnection):
def get_total(self, res) -> int:
return res.total
def get_chunk_ids(self, res) -> list[str]:
def get_doc_ids(self, res) -> list[str]:
return [row["id"] for row in res.chunks]
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:


@ -26,8 +26,7 @@ from opensearchpy import UpdateByQuery, Q, Search, Index
from opensearchpy import ConnectionTimeout
from common.decorator import singleton
from common.file_utils import get_project_base_directory
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
FusionExpr
from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, FusionExpr
from rag.nlp import is_english, rag_tokenizer
from common.constants import PAGERANK_FLD, TAG_FLD
from common import settings
@ -79,7 +78,7 @@ class OSConnection(DocStoreConnection):
Database operations
"""
def dbType(self) -> str:
def db_type(self) -> str:
return "opensearch"
def health(self) -> dict:
@ -91,8 +90,8 @@ class OSConnection(DocStoreConnection):
Table operations
"""
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
if self.indexExist(indexName, knowledgebaseId):
def create_idx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
if self.index_exist(indexName, knowledgebaseId):
return True
try:
from opensearchpy.client import IndicesClient
@ -101,7 +100,7 @@ class OSConnection(DocStoreConnection):
except Exception:
logger.exception("OSConnection.createIndex error %s" % (indexName))
def deleteIdx(self, indexName: str, knowledgebaseId: str):
def delete_idx(self, indexName: str, knowledgebaseId: str):
if len(knowledgebaseId) > 0:
# The index need to be alive after any kb deletion since all kb under this tenant are in one index.
return
@ -112,7 +111,7 @@ class OSConnection(DocStoreConnection):
except Exception:
logger.exception("OSConnection.deleteIdx error %s" % (indexName))
def indexExist(self, indexName: str, knowledgebaseId: str = None) -> bool:
def index_exist(self, indexName: str, knowledgebaseId: str = None) -> bool:
s = Index(indexName, self.os)
for i in range(ATTEMPT_TIME):
try:
@ -460,7 +459,7 @@ class OSConnection(DocStoreConnection):
return res["hits"]["total"]["value"]
return res["hits"]["total"]
def get_chunk_ids(self, res):
def get_doc_ids(self, res):
return [d["_id"] for d in res["hits"]["hits"]]
def __getSource(self, res):


@ -272,6 +272,49 @@ class RedisDB:
self.__open__()
return None
def generate_auto_increment_id(self, key_prefix: str = "id_generator", namespace: str = "default", increment: int = 1, ensure_minimum: int | None = None) -> int:
redis_key = f"{key_prefix}:{namespace}"
try:
# Use pipeline for atomicity
pipe = self.REDIS.pipeline()
# Check if key exists
pipe.exists(redis_key)
# Get/Increment
if ensure_minimum is not None:
# Ensure minimum value
pipe.get(redis_key)
results = pipe.execute()
if results[0] == 0: # Key doesn't exist
start_id = max(1, ensure_minimum)
pipe.set(redis_key, start_id)
pipe.execute()
return start_id
else:
current = int(results[1])
if current < ensure_minimum:
pipe.set(redis_key, ensure_minimum)
pipe.execute()
return ensure_minimum
# Increment operation
next_id = self.REDIS.incrby(redis_key, increment)
# If it's the first time, set a reasonable initial value
if next_id == increment:
self.REDIS.set(redis_key, 1 + increment)
return 1 + increment
return next_id
except Exception as e:
logging.warning("RedisDB.generate_auto_increment_id got exception: " + str(e))
self.__open__()
return -1
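A hedged usage sketch for the new helper; the import path and the REDIS_CONN singleton name are assumptions about this module's layout and may differ:

from rag.utils.redis_conn import REDIS_CONN  # assumed import path

next_id = REDIS_CONN.generate_auto_increment_id(namespace="memory", ensure_minimum=1000)
if next_id == -1:
    # the helper swallows Redis errors and returns -1, so callers must handle that case themselves
    ...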
def transaction(self, key, value, exp=3600):
try:
pipeline = self.REDIS.pipeline(transaction=True)


@ -32,8 +32,8 @@ def add_memory_func(request, WebApiAuth):
payload = {
"name": f"test_memory_{i}",
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
"embd_id": "SILICONFLOW@BAAI/bge-large-zh-v1.5",
"llm_id": "ZHIPU-AI@glm-4-flash"
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
"llm_id": "glm-4-flash@ZHIPU-AI"
}
res = create_memory(WebApiAuth, payload)
memory_ids.append(res["data"]["id"])


@ -49,8 +49,8 @@ class TestMemoryCreate:
payload = {
"name": name,
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
"embd_id": "SILICONFLOW@BAAI/bge-large-zh-v1.5",
"llm_id": "ZHIPU-AI@glm-4-flash"
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
"llm_id": "glm-4-flash@ZHIPU-AI"
}
res = create_memory(WebApiAuth, payload)
assert res["code"] == 0, res
@ -72,8 +72,8 @@ class TestMemoryCreate:
payload = {
"name": name,
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
"embd_id": "SILICONFLOW@BAAI/bge-large-zh-v1.5",
"llm_id": "ZHIPU-AI@glm-4-flash"
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
"llm_id": "glm-4-flash@ZHIPU-AI"
}
res = create_memory(WebApiAuth, payload)
assert res["message"] == expected_message, res
@ -84,8 +84,8 @@ class TestMemoryCreate:
payload = {
"name": name,
"memory_type": ["something"],
"embd_id": "SILICONFLOW@BAAI/bge-large-zh-v1.5",
"llm_id": "ZHIPU-AI@glm-4-flash"
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
"llm_id": "glm-4-flash@ZHIPU-AI"
}
res = create_memory(WebApiAuth, payload)
assert res["message"] == f"Memory type '{ {'something'} }' is not supported.", res
@ -96,8 +96,8 @@ class TestMemoryCreate:
payload = {
"name": name,
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
"embd_id": "SILICONFLOW@BAAI/bge-large-zh-v1.5",
"llm_id": "ZHIPU-AI@glm-4-flash"
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
"llm_id": "glm-4-flash@ZHIPU-AI"
}
res1 = create_memory(WebApiAuth, payload)
assert res1["code"] == 0, res1


@ -101,7 +101,7 @@ class TestMemoryUpdate:
@pytest.mark.p1
def test_llm(self, WebApiAuth, add_memory_func):
memory_ids = add_memory_func
llm_id = "ZHIPU-AI@glm-4"
llm_id = "glm-4@ZHIPU-AI"
payload = {"llm_id": llm_id}
res = update_memory(WebApiAuth, memory_ids[0], payload)
assert res["code"] == 0, res

uv.lock (generated)

@ -3051,7 +3051,7 @@ wheels = [
[[package]]
name = "infinity-sdk"
version = "0.6.11"
version = "0.6.13"
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
dependencies = [
{ name = "datrie" },
@ -3068,9 +3068,9 @@ dependencies = [
{ name = "sqlglot", extra = ["rs"] },
{ name = "thrift" },
]
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/6e/6d/294b4c8fb36f874c92576107fab22da6d64f567fd3e24a312d7bcba5f17a/infinity_sdk-0.6.11.tar.gz", hash = "sha256:f78acd5439c3837715ab308c49be04b416bdaa42a5f4fb840682639ee39d435f", size = 29518792, upload-time = "2025-12-08T05:58:35.167Z" }
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/03/de/56fdc0fa962d5a8e0aa68d16f5321b2d88d79fceb7d0d6cfdde338b65d05/infinity_sdk-0.6.13.tar.gz", hash = "sha256:faf7bc23de7fa549a3842753eddad54ae551ada9df4fff25421658a7fa6fa8c2", size = 29518902, upload-time = "2025-12-24T10:00:01.483Z" }
wheels = [
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/18/cf/d7d1bc584c8f7d1cd8f75c39067179b98ca4bcbe5a86c61e7dbc2b8e692d/infinity_sdk-0.6.11-py3-none-any.whl", hash = "sha256:ec5ac3e710f29db4b875d3e24e20e391ad64270a5e5d189295cc91c362af74d1", size = 29737403, upload-time = "2025-12-08T05:54:58.798Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/f4/a0/8f1e134fdf4ca8bebac7b62caace1816953bb5ffc720d9f0004246c8c38d/infinity_sdk-0.6.13-py3-none-any.whl", hash = "sha256:c08a523d2c27e9a7e6e88be640970530b4661a67c3e9dc3e1aa89533a822fd78", size = 29737403, upload-time = "2025-12-24T09:56:16.93Z" },
]
[[package]]
@ -6229,7 +6229,7 @@ requires-dist = [
{ name = "grpcio-status", specifier = "==1.67.1" },
{ name = "html-text", specifier = "==0.6.2" },
{ name = "infinity-emb", specifier = ">=0.0.66,<0.0.67" },
{ name = "infinity-sdk", specifier = "==0.6.11" },
{ name = "infinity-sdk", specifier = "==0.6.13" },
{ name = "jira", specifier = "==3.10.5" },
{ name = "json-repair", specifier = "==0.35.0" },
{ name = "langfuse", specifier = ">=2.60.0" },

web/package-lock.json (generated): diff suppressed because it is too large.

@ -0,0 +1,10 @@
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M22 12H2M22 12V18C22 18.5304 21.7893 19.0391 21.4142 19.4142C21.0391 19.7893 20.5304 20 20 20H4C3.46957 20 2.96086 19.7893 2.58579 19.4142C2.21071 19.0391 2 18.5304 2 18V12M22 12L18.55 5.11C18.3844 4.77679 18.1292 4.49637 17.813 4.30028C17.4967 4.10419 17.1321 4.0002 16.76 4H7.24C6.86792 4.0002 6.50326 4.10419 6.18704 4.30028C5.87083 4.49637 5.61558 4.77679 5.45 5.11L2 12" stroke="url(#paint0_linear_1415_84974)" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M6 16H6.01M10 16H10.01" stroke="#00BEB4" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<defs>
<linearGradient id="paint0_linear_1415_84974" x1="12.5556" y1="4" x2="12.5556" y2="20" gradientUnits="userSpaceOnUse">
<stop stop-color="#161618"/>
<stop offset="1" stop-color="#666666"/>
</linearGradient>
</defs>
</svg>



@ -18,9 +18,12 @@ import { useFetchKnowledgeBaseConfiguration } from '@/hooks/use-knowledge-reques
import { IModalProps } from '@/interfaces/common';
import { IParserConfig } from '@/interfaces/database/document';
import { IChangeParserConfigRequestBody } from '@/interfaces/request/document';
import { MetadataType } from '@/pages/dataset/components/metedata/hooks/use-manage-modal';
import {
AutoMetadata,
ChunkMethodItem,
EnableTocToggle,
ImageContextWindow,
ParseTypeItem,
} from '@/pages/dataset/dataset-setting/configuration/common-item';
import { zodResolver } from '@hookform/resolvers/zod';
@ -85,6 +88,7 @@ export function ChunkMethodDialog({
visible,
parserConfig,
loading,
documentId,
}: IProps) {
const { t } = useTranslation();
@ -119,6 +123,7 @@ export function ChunkMethodDialog({
auto_questions: z.coerce.number().optional(),
html4excel: z.boolean().optional(),
toc_extraction: z.boolean().optional(),
image_table_context_window: z.coerce.number().optional(),
mineru_parse_method: z.enum(['auto', 'txt', 'ocr']).optional(),
mineru_formula_enable: z.boolean().optional(),
mineru_table_enable: z.boolean().optional(),
@ -140,6 +145,18 @@ export function ChunkMethodDialog({
pages: z
.array(z.object({ from: z.coerce.number(), to: z.coerce.number() }))
.optional(),
metadata: z
.array(
z
.object({
key: z.string().optional(),
description: z.string().optional(),
enum: z.array(z.string().optional()).optional(),
})
.optional(),
)
.optional(),
enable_metadata: z.boolean().optional(),
}),
})
.superRefine((data, ctx) => {
@ -364,10 +381,17 @@ export function ChunkMethodDialog({
className="space-y-3"
>
{selectedTag === DocumentParserType.Naive && (
<EnableTocToggle />
<>
<EnableTocToggle />
<ImageContextWindow />
</>
)}
{showAutoKeywords(selectedTag) && (
<>
<AutoMetadata
type={MetadataType.SingleFileSetting}
otherData={{ documentId }}
/>
<AutoKeywordsFormField></AutoKeywordsFormField>
<AutoQuestionsFormField></AutoQuestionsFormField>
</>


@ -18,6 +18,7 @@ export function useDefaultParserValues() {
auto_questions: 0,
html4excel: false,
toc_extraction: false,
image_table_context_window: 0,
mineru_parse_method: 'auto',
mineru_formula_enable: true,
mineru_table_enable: true,
@ -35,9 +36,11 @@ export function useDefaultParserValues() {
// },
entity_types: [],
pages: [],
metadata: [],
enable_metadata: false,
};
return defaultParserValues;
return defaultParserValues as IParserConfig;
}, [t]);
return defaultParserValues;


@ -35,6 +35,7 @@ import { cn } from '@/lib/utils';
import { t } from 'i18next';
import { Loader } from 'lucide-react';
import { MultiSelect, MultiSelectOptionType } from './ui/multi-select';
import { Switch } from './ui/switch';
// Field type enumeration
export enum FormFieldType {
@ -46,6 +47,7 @@ export enum FormFieldType {
Select = 'select',
MultiSelect = 'multi-select',
Checkbox = 'checkbox',
Switch = 'switch',
Tag = 'tag',
Custom = 'custom',
}
@ -154,6 +156,7 @@ export const generateSchema = (fields: FormFieldConfig[]): ZodSchema<any> => {
}
break;
case FormFieldType.Checkbox:
case FormFieldType.Switch:
fieldSchema = z.boolean();
break;
case FormFieldType.Tag:
@ -193,6 +196,8 @@ export const generateSchema = (fields: FormFieldConfig[]): ZodSchema<any> => {
if (
field.type !== FormFieldType.Number &&
field.type !== FormFieldType.Checkbox &&
field.type !== FormFieldType.Switch &&
field.type !== FormFieldType.Custom &&
field.type !== FormFieldType.Tag &&
field.required
) {
@ -289,7 +294,10 @@ const generateDefaultValues = <T extends FieldValues>(
const lastKey = keys[keys.length - 1];
if (field.defaultValue !== undefined) {
current[lastKey] = field.defaultValue;
} else if (field.type === FormFieldType.Checkbox) {
} else if (
field.type === FormFieldType.Checkbox ||
field.type === FormFieldType.Switch
) {
current[lastKey] = false;
} else if (field.type === FormFieldType.Tag) {
current[lastKey] = [];
@ -299,7 +307,10 @@ const generateDefaultValues = <T extends FieldValues>(
} else {
if (field.defaultValue !== undefined) {
defaultValues[field.name] = field.defaultValue;
} else if (field.type === FormFieldType.Checkbox) {
} else if (
field.type === FormFieldType.Checkbox ||
field.type === FormFieldType.Switch
) {
defaultValues[field.name] = false;
} else if (
field.type === FormFieldType.Tag ||
@ -502,6 +513,32 @@ export const RenderField = ({
)}
/>
);
case FormFieldType.Switch:
return (
<RAGFlowFormItem
{...field}
labelClassName={labelClassName || field.labelClassName}
>
{(fieldProps) => {
const finalFieldProps = field.onChange
? {
...fieldProps,
onChange: (checked: boolean) => {
fieldProps.onChange(checked);
field.onChange?.(checked);
},
}
: fieldProps;
return (
<Switch
checked={finalFieldProps.value as boolean}
onCheckedChange={(checked) => finalFieldProps.onChange(checked)}
disabled={field.disabled}
/>
);
}}
</RAGFlowFormItem>
);
case FormFieldType.Tag:
return (

Some files were not shown because too many files have changed in this diff.