Compare commits

...

24 Commits

Author SHA1 Message Date
427e0540ca Fix: table tag on chunks. (#12126)
- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-25 11:23:51 +08:00
8f16fac898 Fix: Add a no-data filter condition to MetaData (#12189)
### What problem does this PR solve?

Fix: Add a no-data filter condition to MetaData

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-25 10:42:34 +08:00
ea00478a21 Bump infinity to 0.6.13 (#12181)
### What problem does this PR solve?

Bump infinity to 0.6.13

### Type of change

- [x] Refactoring
2025-12-24 22:33:56 +08:00
906f19e863 Dragging down a downstream node of a Switch operator will cause the end_cpn_ids to contain the ID of the placeholder operator. #12177 (#12178)
### What problem does this PR solve?

Dragging down a downstream node of a Switch operator will cause the
end_cpn_ids to contain the ID of the placeholder operator. #12177

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 19:45:35 +08:00
667dc5467e Fix: Fixed the issue of incorrect agent translation text. #10427 (#12172)
### What problem does this PR solve?

Fix: Fixed the issue of incorrect agent translation text. #10427

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 19:05:26 +08:00
977962fdfe Fix: loopitem None issue. (#12166)
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 17:22:31 +08:00
1e9374a373 Fix:Metadata saving, copywriting and other related issues (#12169)
### What problem does this PR solve?

Fix:Bugs Fixed
- Text overflow issues that caused rendering problems
- Metadata saving, copywriting and other related issues

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 17:21:36 +08:00
9b52ba8061 Feat: add image table context to pipeline splitter (#12167)
### What problem does this PR solve?

Add image table context to pipeline splitter.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-24 16:58:14 +08:00
44671ea413 Fix: type check for chunks (#12164)
### What problem does this PR solve?

Fix: type check for chunks

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 16:36:00 +08:00
c81421d340 Feat: add document metadata setting (#12156)
### What problem does this PR solve?

Add document metadata setting.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: Jin Hai <haijin.chn@gmail.com>
2025-12-24 16:13:50 +08:00
ee93a80e91 Feat: add MiniMax M2.1 (#12148)
### What problem does this PR solve?

Add MiniMax M2.1.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: Jin Hai <haijin.chn@gmail.com>
2025-12-24 16:08:09 +08:00
5c981978c1 Run infinity test before ES (#12159)
### What problem does this PR solve?

As title

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-24 15:52:55 +08:00
7fef285af5 Revert "Bump infinity to 0.6.12 (#12140)" (#12161)
This reverts commit 0588fe79b9.
2025-12-24 15:24:12 +08:00
b1efb905e5 Fix: metadata_obj issue. (#12146)
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 13:40:34 +08:00
6400bf87ba Fix: LLM tool does not exist in multiple retrieval case (#12143)
### What problem does this PR solve?

 Fix LLM tool does not exist in multiple retrieval case

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 13:26:48 +08:00
f239bc02d3 Feat: Support Markdown Rendering for tips in user-fill-up Component #11825 (#12147)
### What problem does this PR solve?

Feat: Support Markdown Rendering for tips in user-fill-up Component
#11825

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-12-24 13:25:56 +08:00
5776fa73a7 refactor: improve memory service date time consistency (#12144)
### What problem does this PR solve?

 improve memory service date time consistency

### Type of change

- [x] Refactoring
2025-12-24 11:00:31 +08:00
fc6af1998b Doc: Added an HTTP request component reference (#12141)
### Type of change

- [x] Documentation Update
2025-12-24 09:35:32 +08:00
0588fe79b9 Bump infinity to 0.6.12 (#12140)
### What problem does this PR solve?

As title

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-24 09:34:54 +08:00
f545265f93 Fix:remove duplicate tool_meta (#12139)
### What problem does this PR solve?
pr:#12117
change:remove duplicate tool_meta

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 09:34:08 +08:00
c987d33649 Feat: deduplicate metadata lists during updates (#12125)
### What problem does this PR solve?

Deduplicate metadata lists during updates.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-24 09:32:55 +08:00
d72debf0db Fix: Add prompts when merging or deleting metadata. (#12138)
### What problem does this PR solve?

Fix: Add prompts when merging or deleting metadata.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
2025-12-24 09:32:41 +08:00
c33134ea2c Fix: table tag on chunks. (#12126)
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 09:32:19 +08:00
17b8bb62b6 Feat: message manage (#12083)
### What problem does this PR solve?

Message CRUD.

Issue #4213 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-23 21:16:25 +08:00
99 changed files with 4577 additions and 1417 deletions

View File

@ -197,37 +197,38 @@ jobs:
echo -e "COMPOSE_PROFILES=\${COMPOSE_PROFILES},tei-cpu" >> docker/.env
echo -e "TEI_MODEL=BAAI/bge-small-en-v1.5" >> docker/.env
echo -e "RAGFLOW_IMAGE=${RAGFLOW_IMAGE}" >> docker/.env
sed -i '1i DOC_ENGINE=infinity' docker/.env
echo "HOST_ADDRESS=http://host.docker.internal:${SVR_HTTP_PORT}" >> ${GITHUB_ENV}
sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} up -d
uv sync --python 3.12 --only-group test --no-default-groups --frozen && uv pip install sdk/python --group test
- name: Run sdk tests against Elasticsearch
- name: Run sdk tests against Infinity
run: |
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
echo "Waiting for service to be available..."
sleep 5
done
source .venv/bin/activate && pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api 2>&1 | tee es_sdk_test.log
source .venv/bin/activate && DOC_ENGINE=infinity pytest -x -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api 2>&1 | tee infinity_sdk_test.log
- name: Run frontend api tests against Elasticsearch
- name: Run frontend api tests against Infinity
run: |
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
echo "Waiting for service to be available..."
sleep 5
done
source .venv/bin/activate && pytest -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py 2>&1 | tee es_api_test.log
- name: Run http api tests against Elasticsearch
source .venv/bin/activate && DOC_ENGINE=infinity pytest -x -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py 2>&1 | tee infinity_api_test.log
- name: Run http api tests against Infinity
run: |
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
echo "Waiting for service to be available..."
sleep 5
done
source .venv/bin/activate && pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee es_http_api_test.log
source .venv/bin/activate && DOC_ENGINE=infinity pytest -x -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee infinity_http_api_test.log
- name: Stop ragflow:nightly
if: always() # always run this step even if previous steps failed
@ -237,35 +238,35 @@ jobs:
- name: Start ragflow:nightly
run: |
sed -i '1i DOC_ENGINE=infinity' docker/.env
sed -i '1i DOC_ENGINE=elasticsearch' docker/.env
sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} up -d
- name: Run sdk tests against Infinity
- name: Run sdk tests against Elasticsearch
run: |
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
echo "Waiting for service to be available..."
sleep 5
done
source .venv/bin/activate && DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api 2>&1 | tee infinity_sdk_test.log
source .venv/bin/activate && DOC_ENGINE=elasticsearch pytest -x -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api 2>&1 | tee es_sdk_test.log
- name: Run frontend api tests against Infinity
- name: Run frontend api tests against Elasticsearch
run: |
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
echo "Waiting for service to be available..."
sleep 5
done
source .venv/bin/activate && DOC_ENGINE=infinity pytest -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py 2>&1 | tee infinity_api_test.log
source .venv/bin/activate && DOC_ENGINE=elasticsearch pytest -x -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py 2>&1 | tee es_api_test.log
- name: Run http api tests against Infinity
- name: Run http api tests against Elasticsearch
run: |
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
echo "Waiting for service to be available..."
sleep 5
done
source .venv/bin/activate && DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee infinity_http_api_test.log
source .venv/bin/activate && DOC_ENGINE=elasticsearch pytest -x -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee es_http_api_test.log
- name: Stop ragflow:nightly
if: always() # always run this step even if previous steps failed

View File

@ -192,6 +192,7 @@ COPY pyproject.toml uv.lock ./
COPY mcp mcp
COPY plugin plugin
COPY common common
COPY memory memory
COPY docker/service_conf.yaml.template ./conf/service_conf.yaml.template
COPY docker/entrypoint.sh ./

View File

@ -29,6 +29,11 @@ from common.versions import get_ragflow_version
admin_bp = Blueprint('admin', __name__, url_prefix='/api/v1/admin')
@admin_bp.route('/ping', methods=['GET'])
def ping():
return success_response('PONG')
@admin_bp.route('/login', methods=['POST'])
def login():
if not request.json:

View File

@ -86,8 +86,9 @@ class Agent(LLM, ToolBase):
self.tools = {}
for idx, cpn in enumerate(self._param.tools):
cpn = self._load_tool_obj(cpn)
name = cpn.get_meta()["function"]["name"]
self.tools[f"{name}_{idx}"] = cpn
original_name = cpn.get_meta()["function"]["name"]
indexed_name = f"{original_name}_{idx}"
self.tools[indexed_name] = cpn
self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id), self._param.llm_id,
max_retries=self._param.max_retries,
@ -95,7 +96,12 @@ class Agent(LLM, ToolBase):
max_rounds=self._param.max_rounds,
verbose_tool_use=True
)
self.tool_meta = [v.get_meta() for _,v in self.tools.items()]
self.tool_meta = []
for indexed_name, tool_obj in self.tools.items():
original_meta = tool_obj.get_meta()
indexed_meta = deepcopy(original_meta)
indexed_meta["function"]["name"] = indexed_name
self.tool_meta.append(indexed_meta)
for mcp in self._param.mcp:
_, mcp_server = MCPServerService.get_by_id(mcp["mcp_id"])
@ -109,7 +115,8 @@ class Agent(LLM, ToolBase):
def _load_tool_obj(self, cpn: dict) -> object:
from agent.component import component_class
param = component_class(cpn["component_name"] + "Param")()
tool_name = cpn["component_name"]
param = component_class(tool_name + "Param")()
param.update(cpn["params"])
try:
param.check()
@ -277,19 +284,15 @@ class Agent(LLM, ToolBase):
else:
user_request = history[-1]["content"]
def build_task_desc(prompt: str, user_request: str, tool_metas: list[dict], user_defined_prompt: dict | None = None) -> str:
def build_task_desc(prompt: str, user_request: str, user_defined_prompt: dict | None = None) -> str:
"""Build a minimal task_desc by concatenating prompt, query, and tool schemas."""
user_defined_prompt = user_defined_prompt or {}
tools_json = json.dumps(tool_metas, ensure_ascii=False, indent=2)
task_desc = (
"### Agent Prompt\n"
f"{prompt}\n\n"
"### User Request\n"
f"{user_request}\n\n"
"### Tools (schemas)\n"
f"{tools_json}\n"
)
if user_defined_prompt:
@ -368,7 +371,7 @@ class Agent(LLM, ToolBase):
hist.append({"role": "user", "content": content})
st = timer()
task_desc = build_task_desc(prompt, user_request, tool_metas, user_defined_prompt)
task_desc = build_task_desc(prompt, user_request, user_defined_prompt)
self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st)
for _ in range(self._param.max_rounds + 1):
if self.check_if_canceled("Agent streaming"):

View File

@ -56,7 +56,6 @@ class LLMParam(ComponentParamBase):
self.check_nonnegative_number(int(self.max_tokens), "[Agent] Max tokens")
self.check_decimal_float(float(self.top_p), "[Agent] Top P")
self.check_empty(self.llm_id, "[Agent] LLM")
self.check_empty(self.sys_prompt, "[Agent] System prompt")
self.check_empty(self.prompts, "[Agent] User prompt")
def gen_conf(self):

View File

@ -113,6 +113,10 @@ class LoopItem(ComponentBase, ABC):
return len(var) == 0
elif operator == "not empty":
return len(var) > 0
elif var is None:
if operator == "empty":
return True
return False
raise Exception(f"Invalid operator: {operator}")

View File

@ -25,10 +25,12 @@ from api.db.services.document_service import DocumentService
from common.metadata_utils import apply_meta_data_filter
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.memory_service import MemoryService
from api.db.joint_services.memory_message_service import query_message
from common import settings
from common.connection_utils import timeout
from rag.app.tag import label_question
from rag.prompts.generator import cross_languages, kb_prompt
from rag.prompts.generator import cross_languages, kb_prompt, memory_prompt
class RetrievalParam(ToolParamBase):
@ -57,6 +59,7 @@ class RetrievalParam(ToolParamBase):
self.top_n = 8
self.top_k = 1024
self.kb_ids = []
self.memory_ids = []
self.kb_vars = []
self.rerank_id = ""
self.empty_response = ""
@ -81,15 +84,7 @@ class RetrievalParam(ToolParamBase):
class Retrieval(ToolBase, ABC):
component_name = "Retrieval"
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
async def _invoke_async(self, **kwargs):
if self.check_if_canceled("Retrieval processing"):
return
if not kwargs.get("query"):
self.set_output("formalized_content", self._param.empty_response)
return
async def _retrieve_kb(self, query_text: str):
kb_ids: list[str] = []
for id in self._param.kb_ids:
if id.find("@") < 0:
@ -124,12 +119,12 @@ class Retrieval(ToolBase, ABC):
if self._param.rerank_id:
rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)
vars = self.get_input_elements_from_text(kwargs["query"])
vars = {k:o["value"] for k,o in vars.items()}
query = self.string_format(kwargs["query"], vars)
vars = self.get_input_elements_from_text(query_text)
vars = {k: o["value"] for k, o in vars.items()}
query = self.string_format(query_text, vars)
doc_ids=[]
if self._param.meta_data_filter!={}:
doc_ids = []
if self._param.meta_data_filter != {}:
metas = DocumentService.get_meta_by_kbs(kb_ids)
def _resolve_manual_filter(flt: dict) -> dict:
@ -198,18 +193,20 @@ class Retrieval(ToolBase, ABC):
if self._param.toc_enhance:
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT)
cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n)
cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs],
chat_mdl, self._param.top_n)
if self.check_if_canceled("Retrieval processing"):
return
if cks:
kbinfos["chunks"] = cks
kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"], [kb.tenant_id for kb in kbs])
kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"],
[kb.tenant_id for kb in kbs])
if self._param.use_kg:
ck = settings.kg_retriever.retrieval(query,
[kb.tenant_id for kb in kbs],
kb_ids,
embd_mdl,
LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT))
[kb.tenant_id for kb in kbs],
kb_ids,
embd_mdl,
LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT))
if self.check_if_canceled("Retrieval processing"):
return
if ck["content_with_weight"]:
@ -218,7 +215,8 @@ class Retrieval(ToolBase, ABC):
kbinfos = {"chunks": [], "doc_aggs": []}
if self._param.use_kg and kbs:
ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl,
LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
if self.check_if_canceled("Retrieval processing"):
return
if ck["content_with_weight"]:
@ -248,6 +246,54 @@ class Retrieval(ToolBase, ABC):
return form_cnt
async def _retrieve_memory(self, query_text: str):
memory_ids: list[str] = [memory_id for memory_id in self._param.memory_ids]
memory_list = MemoryService.get_by_ids(memory_ids)
if not memory_list:
raise Exception("No memory is selected.")
embd_names = list({memory.embd_id for memory in memory_list})
assert len(embd_names) == 1, "Memory use different embedding models."
vars = self.get_input_elements_from_text(query_text)
vars = {k: o["value"] for k, o in vars.items()}
query = self.string_format(query_text, vars)
# query message
message_list = query_message({"memory_id": memory_ids}, {
"query": query,
"similarity_threshold": self._param.similarity_threshold,
"keywords_similarity_weight": self._param.keywords_similarity_weight,
"top_n": self._param.top_n
})
print(f"found {len(message_list)} messages.")
if not message_list:
self.set_output("formalized_content", self._param.empty_response)
return
formated_content = "\n".join(memory_prompt(message_list, 200000))
# set formalized_content output
self.set_output("formalized_content", formated_content)
print(f"formated_content {formated_content}")
return formated_content
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
async def _invoke_async(self, **kwargs):
if self.check_if_canceled("Retrieval processing"):
return
print(f"debug retrieval, query is {kwargs.get('query')}.", flush=True)
if not kwargs.get("query"):
self.set_output("formalized_content", self._param.empty_response)
return
if self._param.kb_ids:
return await self._retrieve_kb(kwargs["query"])
elif self._param.memory_ids:
return await self._retrieve_memory(kwargs["query"])
else:
self.set_output("formalized_content", self._param.empty_response)
return
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
return asyncio.run(self._invoke_async(**kwargs))

View File

@ -192,7 +192,7 @@ async def rerun():
if 0 < doc["progress"] < 1:
return get_data_error_result(message=f"`{doc['name']}` is processing...")
if settings.docStoreConn.indexExist(search.index_name(current_user.id), doc["kb_id"]):
if settings.docStoreConn.index_exist(search.index_name(current_user.id), doc["kb_id"]):
settings.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"])
doc["progress_msg"] = ""
doc["chunk_num"] = 0

View File

@ -447,6 +447,26 @@ async def metadata_update():
return get_json_result(data={"updated": updated, "matched_docs": len(target_doc_ids)})
@manager.route("/update_metadata_setting", methods=["POST"]) # noqa: F821
@login_required
@validate_request("doc_id", "metadata")
async def update_metadata_setting():
req = await get_request_json()
if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
DocumentService.update_parser_config(doc.id, {"metadata": req["metadata"]})
e, doc = DocumentService.get_by_id(doc.id)
if not e:
return get_data_error_result(message="Document not found!")
return get_json_result(data=doc.to_dict())
@manager.route("/thumbnails", methods=["GET"]) # noqa: F821
# @login_required
def thumbnails():
@ -564,7 +584,7 @@ async def run():
DocumentService.update_by_id(id, info)
if req.get("delete", False):
TaskService.filter_delete([Task.doc_id == id])
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
if str(req["run"]) == TaskStatus.RUNNING.value:
@ -615,7 +635,7 @@ async def rename():
"title_tks": title_tks,
"title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
}
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.update(
{"doc_id": req["doc_id"]},
es_body,
@ -696,7 +716,7 @@ async def change_parser():
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
return None

View File

@ -39,9 +39,9 @@ from api.utils.api_utils import get_json_result
from rag.nlp import search
from api.constants import DATASET_NAME_LIMIT
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.doc_store_conn import OrderByExpr
from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType, PAGERANK_FLD
from common import settings
from common.doc_store.doc_store_base import OrderByExpr
from api.apps import login_required, current_user
@ -285,7 +285,7 @@ async def rm():
message="Database error (Knowledgebase removal)!")
for kb in kbs:
settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
settings.docStoreConn.delete_idx(search.index_name(kb.tenant_id), kb.id)
if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
settings.STORAGE_IMPL.remove_bucket(kb.id)
return get_json_result(data=True)
@ -386,7 +386,7 @@ def knowledge_graph(kb_id):
}
obj = {"graph": {}, "mind_map": {}}
if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id):
if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), kb_id):
return get_json_result(data=obj)
sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
if not len(sres.ids):
@ -858,11 +858,11 @@ async def check_embedding():
index_nm = search.index_name(tenant_id)
res0 = docStoreConn.search(
selectFields=[], highlightFields=[],
select_fields=[], highlight_fields=[],
condition={"kb_id": kb_id, "available_int": 1},
matchExprs=[], orderBy=OrderByExpr(),
match_expressions=[], order_by=OrderByExpr(),
offset=0, limit=1,
indexNames=index_nm, knowledgebaseIds=[kb_id]
index_names=index_nm, knowledgebase_ids=[kb_id]
)
total = docStoreConn.get_total(res0)
if total <= 0:
@ -874,14 +874,14 @@ async def check_embedding():
for off in offsets:
res1 = docStoreConn.search(
selectFields=list(base_fields),
highlightFields=[],
select_fields=list(base_fields),
highlight_fields=[],
condition={"kb_id": kb_id, "available_int": 1},
matchExprs=[], orderBy=OrderByExpr(),
match_expressions=[], order_by=OrderByExpr(),
offset=off, limit=1,
indexNames=index_nm, knowledgebaseIds=[kb_id]
index_names=index_nm, knowledgebase_ids=[kb_id]
)
ids = docStoreConn.get_chunk_ids(res1)
ids = docStoreConn.get_doc_ids(res1)
if not ids:
continue

View File

@ -20,10 +20,12 @@ from api.apps import login_required, current_user
from api.db import TenantPermission
from api.db.services.memory_service import MemoryService
from api.db.services.user_service import UserTenantService
from api.db.services.canvas_service import UserCanvasService
from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result, \
not_allowed_parameters
from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human
from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT
from memory.services.messages import MessageService
from common.constants import MemoryType, RetCode, ForgettingPolicy
@ -57,7 +59,6 @@ async def create_memory():
if res:
return get_json_result(message=True, data=format_ret_data_from_memory(memory))
else:
return get_json_result(message=memory, code=RetCode.SERVER_ERROR)
@ -124,7 +125,7 @@ async def update_memory(memory_id):
return get_json_result(message=True, data=memory_dict)
try:
MemoryService.update_memory(memory_id, to_update)
MemoryService.update_memory(current_memory.tenant_id, memory_id, to_update)
updated_memory = MemoryService.get_by_memory_id(memory_id)
return get_json_result(message=True, data=format_ret_data_from_memory(updated_memory))
@ -133,7 +134,7 @@ async def update_memory(memory_id):
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
@manager.route("/<memory_id>", methods=["DELETE"]) # noqa: F821
@manager.route("/<memory_id>", methods=["DELETE"]) # noqa: F821
@login_required
async def delete_memory(memory_id):
memory = MemoryService.get_by_memory_id(memory_id)
@ -141,13 +142,14 @@ async def delete_memory(memory_id):
return get_json_result(message=True, code=RetCode.NOT_FOUND)
try:
MemoryService.delete_memory(memory_id)
MessageService.delete_message({"memory_id": memory_id}, memory.tenant_id, memory_id)
return get_json_result(message=True)
except Exception as e:
logging.error(e)
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
@manager.route("", methods=["GET"]) # noqa: F821
@manager.route("", methods=["GET"]) # noqa: F821
@login_required
async def list_memory():
args = request.args
@ -183,3 +185,26 @@ async def get_memory_config(memory_id):
if not memory:
return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
return get_json_result(message=True, data=format_ret_data_from_memory(memory))
@manager.route("/<memory_id>", methods=["GET"]) # noqa: F821
@login_required
async def get_memory_detail(memory_id):
args = request.args
agent_ids = args.getlist("agent_id")
keywords = args.get("keywords", "")
keywords = keywords.strip()
page = int(args.get("page", 1))
page_size = int(args.get("page_size", 50))
memory = MemoryService.get_by_memory_id(memory_id)
if not memory:
return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
messages = MessageService.list_message(
memory.tenant_id, memory_id, agent_ids, keywords, page, page_size)
agent_name_mapping = {}
if messages["message_list"]:
agent_list = UserCanvasService.get_basic_info_by_canvas_ids([message["agent_id"] for message in messages["message_list"]])
agent_name_mapping = {agent["id"]: agent["title"] for agent in agent_list}
for message in messages["message_list"]:
message["agent_name"] = agent_name_mapping.get(message["agent_id"], "Unknown")
return get_json_result(data={"messages": messages, "storage_type": memory.storage_type}, message=True)

169
api/apps/messages_app.py Normal file
View File

@ -0,0 +1,169 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from quart import request
from api.apps import login_required
from api.db.services.memory_service import MemoryService
from common.time_utils import current_timestamp, timestamp_to_date
from memory.services.messages import MessageService
from api.db.joint_services import memory_message_service
from api.db.joint_services.memory_message_service import query_message
from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result
from common.constants import RetCode
@manager.route("", methods=["POST"]) # noqa: F821
@login_required
@validate_request("memory_id", "agent_id", "session_id", "user_input", "agent_response")
async def add_message():
req = await get_request_json()
memory_ids = req["memory_id"]
agent_id = req["agent_id"]
session_id = req["session_id"]
user_id = req["user_id"] if req.get("user_id") else ""
user_input = req["user_input"]
agent_response = req["agent_response"]
res = []
for memory_id in memory_ids:
success, msg = await memory_message_service.save_to_memory(
memory_id,
{
"user_id": user_id,
"agent_id": agent_id,
"session_id": session_id,
"user_input": user_input,
"agent_response": agent_response
}
)
res.append({
"memory_id": memory_id,
"success": success,
"message": msg
})
if all([r["success"] for r in res]):
return get_json_result(message="Successfully added to memories.")
return get_json_result(code=RetCode.SERVER_ERROR, message="Some messages failed to add.", data=res)
@manager.route("/<memory_id>:<message_id>", methods=["DELETE"]) # noqa: F821
@login_required
async def forget_message(memory_id: str, message_id: int):
memory = MemoryService.get_by_memory_id(memory_id)
if not memory:
return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
forget_time = timestamp_to_date(current_timestamp())
update_succeed = MessageService.update_message(
{"memory_id": memory_id, "message_id": int(message_id)},
{"forget_at": forget_time},
memory.tenant_id, memory_id)
if update_succeed:
return get_json_result(message=update_succeed)
else:
return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to forget message '{message_id}' in memory '{memory_id}'.")
@manager.route("/<memory_id>:<message_id>", methods=["PUT"]) # noqa: F821
@login_required
@validate_request("status")
async def update_message(memory_id: str, message_id: int):
req = await get_request_json()
status = req["status"]
if not isinstance(status, bool):
return get_error_argument_result("Status must be a boolean.")
memory = MemoryService.get_by_memory_id(memory_id)
if not memory:
return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
update_succeed = MessageService.update_message({"memory_id": memory_id, "message_id": int(message_id)}, {"status": status}, memory.tenant_id, memory_id)
if update_succeed:
return get_json_result(message=update_succeed)
else:
return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to set status for message '{message_id}' in memory '{memory_id}'.")
@manager.route("/search", methods=["GET"]) # noqa: F821
@login_required
async def search_message():
args = request.args
print(args, flush=True)
empty_fields = [f for f in ["memory_id", "query"] if not args.get(f)]
if empty_fields:
return get_error_argument_result(f"{', '.join(empty_fields)} can't be empty.")
memory_ids = args.getlist("memory_id")
query = args.get("query")
similarity_threshold = float(args.get("similarity_threshold", 0.2))
keywords_similarity_weight = float(args.get("keywords_similarity_weight", 0.7))
top_n = int(args.get("top_n", 5))
agent_id = args.get("agent_id", "")
session_id = args.get("session_id", "")
filter_dict = {
"memory_id": memory_ids,
"agent_id": agent_id,
"session_id": session_id
}
params = {
"query": query,
"similarity_threshold": similarity_threshold,
"keywords_similarity_weight": keywords_similarity_weight,
"top_n": top_n
}
res = query_message(filter_dict, params)
return get_json_result(message=True, data=res)
@manager.route("", methods=["GET"]) # noqa: F821
@login_required
async def get_messages():
args = request.args
memory_ids = args.getlist("memory_id")
agent_id = args.get("agent_id", "")
session_id = args.get("session_id", "")
limit = int(args.get("limit", 10))
if not memory_ids:
return get_error_argument_result("memory_ids is required.")
memory_list = MemoryService.get_by_ids(memory_ids)
uids = [memory.tenant_id for memory in memory_list]
res = MessageService.get_recent_messages(
uids,
memory_ids,
agent_id,
session_id,
limit
)
return get_json_result(message=True, data=res)
@manager.route("/<memory_id>:<message_id>/content", methods=["GET"]) # noqa: F821
@login_required
async def get_message_content(memory_id:str, message_id: int):
memory = MemoryService.get_by_memory_id(memory_id)
if not memory:
return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
res = MessageService.get_by_message_id(memory_id, message_id, memory.tenant_id)
if res:
return get_json_result(message=True, data=res)
else:
return get_json_result(code=RetCode.NOT_FOUND, message=f"Message '{message_id}' in memory '{memory_id}' not found.")

View File

@ -495,7 +495,7 @@ def knowledge_graph(tenant_id, dataset_id):
}
obj = {"graph": {}, "mind_map": {}}
if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id):
if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), dataset_id):
return get_result(data=obj)
sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
if not len(sres.ids):

View File

@ -1080,7 +1080,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
res["chunks"].append(final_chunk)
_ = Chunk(**final_chunk)
elif settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
elif settings.docStoreConn.index_exist(search.index_name(tenant_id), dataset_id):
sres = settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
res["total"] = sres.total
for id in sres.ids:

View File

@ -30,6 +30,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm
from api.db.services.user_service import TenantService, UserTenantService
from api.db.joint_services.memory_message_service import init_message_id_sequence, init_memory_size_cache
from common.constants import LLMType
from common.file_utils import get_project_base_directory
from common import settings
@ -169,6 +170,8 @@ def init_web_data():
# init_superuser()
add_graph_templates()
init_message_id_sequence()
init_memory_size_cache()
logging.info("init web data success:{}".format(time.time() - start_time))

View File

@ -0,0 +1,233 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from typing import List
from common.time_utils import current_timestamp, timestamp_to_date, format_iso_8601_to_ymd_hms
from common.constants import MemoryType, LLMType
from common.doc_store.doc_store_base import FusionExpr
from api.db.services.memory_service import MemoryService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.llm_service import LLMBundle
from api.utils.memory_utils import get_memory_type_human
from memory.services.messages import MessageService
from memory.services.query import MsgTextQuery, get_vector
from memory.utils.prompt_util import PromptAssembler
from memory.utils.msg_util import get_json_result_from_llm_response
from rag.utils.redis_conn import REDIS_CONN
async def save_to_memory(memory_id: str, message_dict: dict):
"""
:param memory_id:
:param message_dict: {
"user_id": str,
"agent_id": str,
"session_id": str,
"user_input": str,
"agent_response": str
}
"""
memory = MemoryService.get_by_memory_id(memory_id)
if not memory:
return False, f"Memory '{memory_id}' not found."
tenant_id = memory.tenant_id
extracted_content = await extract_by_llm(
tenant_id,
memory.llm_id,
{"temperature": memory.temperature},
get_memory_type_human(memory.memory_type),
message_dict.get("user_input", ""),
message_dict.get("agent_response", "")
) if memory.memory_type != MemoryType.RAW.value else [] # if only RAW, no need to extract
raw_message_id = REDIS_CONN.generate_auto_increment_id(namespace="memory")
message_list = [{
"message_id": raw_message_id,
"message_type": MemoryType.RAW.name.lower(),
"source_id": 0,
"memory_id": memory_id,
"user_id": "",
"agent_id": message_dict["agent_id"],
"session_id": message_dict["session_id"],
"content": f"User Input: {message_dict.get('user_input')}\nAgent Response: {message_dict.get('agent_response')}",
"valid_at": timestamp_to_date(current_timestamp()),
"invalid_at": None,
"forget_at": None,
"status": True
}, *[{
"message_id": REDIS_CONN.generate_auto_increment_id(namespace="memory"),
"message_type": content["message_type"],
"source_id": raw_message_id,
"memory_id": memory_id,
"user_id": "",
"agent_id": message_dict["agent_id"],
"session_id": message_dict["session_id"],
"content": content["content"],
"valid_at": content["valid_at"],
"invalid_at": content["invalid_at"] if content["invalid_at"] else None,
"forget_at": None,
"status": True
} for content in extracted_content]]
embedding_model = LLMBundle(tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
vector_list, _ = embedding_model.encode([msg["content"] for msg in message_list])
for idx, msg in enumerate(message_list):
msg["content_embed"] = vector_list[idx]
vector_dimension = len(vector_list[0])
if not MessageService.has_index(tenant_id, memory_id):
created = MessageService.create_index(tenant_id, memory_id, vector_size=vector_dimension)
if not created:
return False, "Failed to create message index."
new_msg_size = sum([MessageService.calculate_message_size(m) for m in message_list])
current_memory_size = get_memory_size_cache(memory_id, tenant_id)
if new_msg_size + current_memory_size > memory.memory_size:
size_to_delete = current_memory_size + new_msg_size - memory.memory_size
if memory.forgetting_policy == "fifo":
message_ids_to_delete, delete_size = MessageService.pick_messages_to_delete_by_fifo(memory_id, tenant_id, size_to_delete)
MessageService.delete_message({"message_id": message_ids_to_delete}, tenant_id, memory_id)
decrease_memory_size_cache(memory_id, tenant_id, delete_size)
else:
return False, "Failed to insert message into memory. Memory size reached limit and cannot decide which to delete."
fail_cases = MessageService.insert_message(message_list, tenant_id, memory_id)
if fail_cases:
return False, "Failed to insert message into memory. Details: " + "; ".join(fail_cases)
increase_memory_size_cache(memory_id, tenant_id, new_msg_size)
return True, "Message saved successfully."
async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory_type: List[str], user_input: str,
agent_response: str, system_prompt: str = "", user_prompt: str="") -> List[dict]:
llm_type = TenantLLMService.llm_id2llm_type(llm_id)
if not llm_type:
raise RuntimeError(f"Unknown type of LLM '{llm_id}'")
if not system_prompt:
system_prompt = PromptAssembler.assemble_system_prompt({"memory_type": memory_type})
conversation_content = f"User Input: {user_input}\nAgent Response: {agent_response}"
conversation_time = timestamp_to_date(current_timestamp())
user_prompts = []
if user_prompt:
user_prompts.append({"role": "user", "content": user_prompt})
user_prompts.append({"role": "user", "content": f"Conversation: {conversation_content}\nConversation Time: {conversation_time}\nCurrent Time: {conversation_time}"})
else:
user_prompts.append({"role": "user", "content": PromptAssembler.assemble_user_prompt(conversation_content, conversation_time, conversation_time)})
llm = LLMBundle(tenant_id, llm_type, llm_id)
res = await llm.async_chat(system_prompt, user_prompts, extract_conf)
res_json = get_json_result_from_llm_response(res)
return [{
"content": extracted_content["content"],
"valid_at": format_iso_8601_to_ymd_hms(extracted_content["valid_at"]),
"invalid_at": format_iso_8601_to_ymd_hms(extracted_content["invalid_at"]) if extracted_content.get("invalid_at") else "",
"message_type": message_type
} for message_type, extracted_content_list in res_json.items() for extracted_content in extracted_content_list]
def query_message(filter_dict: dict, params: dict):
"""
:param filter_dict: {
"memory_id": List[str],
"agent_id": optional
"session_id": optional
}
:param params: {
"query": question str,
"similarity_threshold": float,
"keywords_similarity_weight": float,
"top_n": int
}
"""
memory_ids = filter_dict["memory_id"]
memory_list = MemoryService.get_by_ids(memory_ids)
if not memory_list:
return []
condition_dict = {k: v for k, v in filter_dict.items() if v}
uids = [memory.tenant_id for memory in memory_list]
question = params["query"]
question = question.strip()
memory = memory_list[0]
embd_model = LLMBundle(memory.tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
match_dense = get_vector(question, embd_model, similarity=params["similarity_threshold"])
match_text, _ = MsgTextQuery().question(question, min_match=0.3)
keywords_similarity_weight = params.get("keywords_similarity_weight", 0.7)
fusion_expr = FusionExpr("weighted_sum", params["top_n"], {"weights": ",".join([str(keywords_similarity_weight), str(1 - keywords_similarity_weight)])})
return MessageService.search_message(memory_ids, condition_dict, uids, [match_text, match_dense, fusion_expr], params["top_n"])
def init_message_id_sequence():
message_id_redis_key = "id_generator:memory"
if REDIS_CONN.exist(message_id_redis_key):
current_max_id = REDIS_CONN.get(message_id_redis_key)
logging.info(f"No need to init message_id sequence, current max id is {current_max_id}.")
else:
max_id = 1
exist_memory_list = MemoryService.get_all_memory()
if not exist_memory_list:
REDIS_CONN.set(message_id_redis_key, max_id)
else:
max_id = MessageService.get_max_message_id(
uid_list=[m.tenant_id for m in exist_memory_list],
memory_ids=[m.id for m in exist_memory_list]
)
REDIS_CONN.set(message_id_redis_key, max_id)
logging.info(f"Init message_id sequence done, current max id is {max_id}.")
def get_memory_size_cache(memory_id: str, uid: str):
redis_key = f"memory_{memory_id}"
if REDIS_CONN.exists(redis_key):
return REDIS_CONN.get(redis_key)
else:
memory_size_map = MessageService.calculate_memory_size(
[memory_id],
[uid]
)
memory_size = memory_size_map.get(memory_id, 0)
set_memory_size_cache(memory_id, memory_size)
return memory_size
def set_memory_size_cache(memory_id: str, size: int):
redis_key = f"memory_{memory_id}"
return REDIS_CONN.set(redis_key, size)
def increase_memory_size_cache(memory_id: str, uid: str, size: int):
current_value = get_memory_size_cache(memory_id, uid)
return set_memory_size_cache(memory_id, current_value + size)
def decrease_memory_size_cache(memory_id: str, uid: str, size: int):
current_value = get_memory_size_cache(memory_id, uid)
return set_memory_size_cache(memory_id, max(current_value - size, 0))
def init_memory_size_cache():
memory_list = MemoryService.get_all_memory()
if not memory_list:
logging.info("No memory found, no need to init memory size.")
else:
memory_size_map = MessageService.calculate_memory_size(
memory_ids=[m.id for m in memory_list],
uid_list=[m.tenant_id for m in memory_list],
)
for memory in memory_list:
memory_size = memory_size_map.get(memory.id, 0)
set_memory_size_cache(memory.id, memory_size)
logging.info("Memory size cache init done.")

View File

@ -34,6 +34,8 @@ from api.db.services.task_service import TaskService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_canvas_version import UserCanvasVersionService
from api.db.services.user_service import TenantService, UserService, UserTenantService
from api.db.services.memory_service import MemoryService
from memory.services.messages import MessageService
from rag.nlp import search
from common.constants import ActiveEnum
from common import settings
@ -200,7 +202,16 @@ def delete_user_data(user_id: str) -> dict:
done_msg += f"- Deleted {llm_delete_res} tenant-LLM records.\n"
langfuse_delete_res = TenantLangfuseService.delete_ty_tenant_id(tenant_id)
done_msg += f"- Deleted {langfuse_delete_res} langfuse records.\n"
# step1.3 delete own tenant
# step1.3 delete memory and messages
user_memory = MemoryService.get_by_tenant_id(tenant_id)
if user_memory:
for memory in user_memory:
if MessageService.has_index(tenant_id, memory.id):
MessageService.delete_index(tenant_id, memory.id)
done_msg += " Deleted memory index."
memory_delete_res = MemoryService.delete_by_ids([m.id for m in user_memory])
done_msg += f"Deleted {memory_delete_res} memory datasets."
# step1.4 delete own tenant
tenant_delete_res = TenantService.delete_by_id(tenant_id)
done_msg += f"- Deleted {tenant_delete_res} tenant.\n"
# step2 delete user-tenant relation

View File

@ -123,6 +123,19 @@ class UserCanvasService(CommonService):
logging.exception(e)
return False, None
@classmethod
@DB.connection_context()
def get_basic_info_by_canvas_ids(cls, canvas_id):
fields = [
cls.model.id,
cls.model.avatar,
cls.model.user_id,
cls.model.title,
cls.model.permission,
cls.model.canvas_category
]
return cls.model.select(*fields).where(cls.model.id.in_(canvas_id)).dicts()
@classmethod
@DB.connection_context()
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,

View File

@ -33,12 +33,13 @@ from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTena
from api.db.db_utils import bulk_insert_into_db
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from common.metadata_utils import dedupe_list
from common.misc_utils import get_uuid
from common.time_utils import current_timestamp, get_format_time
from common.constants import LLMType, ParserType, StatusEnum, TaskStatus, SVR_CONSUMER_GROUP_NAME
from rag.nlp import rag_tokenizer, search
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.doc_store_conn import OrderByExpr
from common.doc_store.doc_store_base import OrderByExpr
from common import settings
@ -345,7 +346,7 @@ class DocumentService(CommonService):
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
page * page_size, page_size, search.index_name(tenant_id),
[doc.kb_id])
chunk_ids = settings.docStoreConn.get_chunk_ids(chunks)
chunk_ids = settings.docStoreConn.get_doc_ids(chunks)
if not chunk_ids:
break
all_chunk_ids.extend(chunk_ids)
@ -696,10 +697,14 @@ class DocumentService(CommonService):
for k,v in r.meta_fields.items():
if k not in meta:
meta[k] = {}
v = str(v)
if v not in meta[k]:
meta[k][v] = []
meta[k][v].append(doc_id)
if not isinstance(v, list):
v = [v]
for vv in v:
if vv not in meta[k]:
if isinstance(vv, list) or isinstance(vv, dict):
continue
meta[k][vv] = []
meta[k][vv].append(doc_id)
return meta
@classmethod
@ -797,7 +802,10 @@ class DocumentService(CommonService):
match_provided = "match" in upd
if isinstance(meta[key], list):
if not match_provided:
meta[key] = new_value
if isinstance(new_value, list):
meta[key] = dedupe_list(new_value)
else:
meta[key] = new_value
changed = True
else:
match_value = upd.get("match")
@ -810,7 +818,7 @@ class DocumentService(CommonService):
else:
new_list.append(item)
if replaced:
meta[key] = new_list
meta[key] = dedupe_list(new_list)
changed = True
else:
if not match_provided:
@ -1230,8 +1238,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
d["q_%d_vec" % len(v)] = v
for b in range(0, len(cks), es_bulk_size):
if try_create_idx:
if not settings.docStoreConn.indexExist(idxnm, kb_id):
settings.docStoreConn.createIdx(idxnm, kb_id, len(vectors[0]))
if not settings.docStoreConn.index_exist(idxnm, kb_id):
settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]))
try_create_idx = False
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)

View File

@ -411,8 +411,6 @@ class KnowledgebaseService(CommonService):
ok, _t = TenantService.get_by_id(tenant_id)
if not ok:
return False, get_data_error_result(message="Tenant not found.")
if kwargs.get("parser_config") and isinstance(kwargs["parser_config"], dict) and not kwargs["parser_config"].get("llm_id"):
kwargs["parser_config"]["llm_id"] = _t.llm_id
# Build payload
kb_id = get_uuid()
@ -427,6 +425,7 @@ class KnowledgebaseService(CommonService):
# Update parser_config (always override with validated default/merged config)
payload["parser_config"] = get_parser_config(parser_id, kwargs.get("parser_config"))
payload["parser_config"]["llm_id"] = _t.llm_id
return True, payload

View File

@ -15,7 +15,6 @@
#
from typing import List
from api.apps import current_user
from api.db.db_models import DB, Memory, User
from api.db.services import duplicate_name
from api.db.services.common_service import CommonService
@ -23,6 +22,7 @@ from api.utils.memory_utils import calculate_memory_type
from api.constants import MEMORY_NAME_LIMIT
from common.misc_utils import get_uuid
from common.time_utils import get_format_time, current_timestamp
from memory.utils.prompt_util import PromptAssembler
class MemoryService(CommonService):
@ -34,6 +34,17 @@ class MemoryService(CommonService):
def get_by_memory_id(cls, memory_id: str):
return cls.model.select().where(cls.model.id == memory_id).first()
@classmethod
@DB.connection_context()
def get_by_tenant_id(cls, tenant_id: str):
return cls.model.select().where(cls.model.tenant_id == tenant_id)
@classmethod
@DB.connection_context()
def get_all_memory(cls):
memory_list = cls.model.select()
return list(memory_list)
@classmethod
@DB.connection_context()
def get_with_owner_name_by_id(cls, memory_id: str):
@ -53,7 +64,9 @@ class MemoryService(CommonService):
cls.model.forgetting_policy,
cls.model.temperature,
cls.model.system_prompt,
cls.model.user_prompt
cls.model.user_prompt,
cls.model.create_date,
cls.model.create_time
]
memory = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)).where(
cls.model.id == memory_id
@ -72,7 +85,9 @@ class MemoryService(CommonService):
cls.model.memory_type,
cls.model.storage_type,
cls.model.permissions,
cls.model.description
cls.model.description,
cls.model.create_time,
cls.model.create_date
]
memories = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id))
if filter_dict.get("tenant_id"):
@ -102,6 +117,8 @@ class MemoryService(CommonService):
if len(memory_name) > MEMORY_NAME_LIMIT:
return False, f"Memory name {memory_name} exceeds limit of {MEMORY_NAME_LIMIT}."
timestamp = current_timestamp()
format_time = get_format_time()
# build create dict
memory_info = {
"id": get_uuid(),
@ -110,10 +127,11 @@ class MemoryService(CommonService):
"tenant_id": tenant_id,
"embd_id": embd_id,
"llm_id": llm_id,
"create_time": current_timestamp(),
"create_date": get_format_time(),
"update_time": current_timestamp(),
"update_date": get_format_time(),
"system_prompt": PromptAssembler.assemble_system_prompt({"memory_type": memory_type}),
"create_time": timestamp,
"create_date": format_time,
"update_time": timestamp,
"update_date": format_time,
}
obj = cls.model(**memory_info).save(force_insert=True)
@ -126,7 +144,7 @@ class MemoryService(CommonService):
@classmethod
@DB.connection_context()
def update_memory(cls, memory_id: str, update_dict: dict):
def update_memory(cls, tenant_id: str, memory_id: str, update_dict: dict):
if not update_dict:
return 0
if "temperature" in update_dict and isinstance(update_dict["temperature"], str):
@ -135,7 +153,7 @@ class MemoryService(CommonService):
update_dict["name"] = duplicate_name(
cls.query,
name=update_dict["name"],
tenant_id=current_user.id
tenant_id=tenant_id
)
update_dict.update({
"update_time": current_timestamp(),

View File

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC, abstractmethod
from dataclasses import dataclass
import numpy as np
@ -22,7 +21,6 @@ DEFAULT_MATCH_VECTOR_TOPN = 10
DEFAULT_MATCH_SPARSE_TOPN = 10
VEC = list | np.ndarray
@dataclass
class SparseVector:
indices: list[int]
@ -55,14 +53,13 @@ class SparseVector:
def __repr__(self):
return str(self)
class MatchTextExpr(ABC):
class MatchTextExpr:
def __init__(
self,
fields: list[str],
matching_text: str,
topn: int,
extra_options: dict = dict(),
extra_options: dict | None = None,
):
self.fields = fields
self.matching_text = matching_text
@ -70,7 +67,7 @@ class MatchTextExpr(ABC):
self.extra_options = extra_options
class MatchDenseExpr(ABC):
class MatchDenseExpr:
def __init__(
self,
vector_column_name: str,
@ -78,7 +75,7 @@ class MatchDenseExpr(ABC):
embedding_data_type: str,
distance_type: str,
topn: int = DEFAULT_MATCH_VECTOR_TOPN,
extra_options: dict = dict(),
extra_options: dict | None = None,
):
self.vector_column_name = vector_column_name
self.embedding_data = embedding_data
@ -88,7 +85,7 @@ class MatchDenseExpr(ABC):
self.extra_options = extra_options
class MatchSparseExpr(ABC):
class MatchSparseExpr:
def __init__(
self,
vector_column_name: str,
@ -104,7 +101,7 @@ class MatchSparseExpr(ABC):
self.opt_params = opt_params
class MatchTensorExpr(ABC):
class MatchTensorExpr:
def __init__(
self,
column_name: str,
@ -120,7 +117,7 @@ class MatchTensorExpr(ABC):
self.extra_option = extra_option
class FusionExpr(ABC):
class FusionExpr:
def __init__(self, method: str, topn: int, fusion_params: dict | None = None):
self.method = method
self.topn = topn
@ -129,7 +126,8 @@ class FusionExpr(ABC):
MatchExpr = MatchTextExpr | MatchDenseExpr | MatchSparseExpr | MatchTensorExpr | FusionExpr
class OrderByExpr(ABC):
class OrderByExpr:
def __init__(self):
self.fields = list()
def asc(self, field: str):
@ -141,13 +139,14 @@ class OrderByExpr(ABC):
def fields(self):
return self.fields
class DocStoreConnection(ABC):
"""
Database operations
"""
@abstractmethod
def dbType(self) -> str:
def db_type(self) -> str:
"""
Return the type of the database.
"""
@ -165,21 +164,21 @@ class DocStoreConnection(ABC):
"""
@abstractmethod
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
"""
Create an index with given name
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def deleteIdx(self, indexName: str, knowledgebaseId: str):
def delete_idx(self, index_name: str, dataset_id: str):
"""
Delete an index with given name
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
def index_exist(self, index_name: str, dataset_id: str) -> bool:
"""
Check if an index with given name exists
"""
@ -191,16 +190,16 @@ class DocStoreConnection(ABC):
@abstractmethod
def search(
self, selectFields: list[str],
highlightFields: list[str],
self, select_fields: list[str],
highlight_fields: list[str],
condition: dict,
matchExprs: list[MatchExpr],
orderBy: OrderByExpr,
match_expressions: list[MatchExpr],
order_by: OrderByExpr,
offset: int,
limit: int,
indexNames: str|list[str],
knowledgebaseIds: list[str],
aggFields: list[str] = [],
index_names: str|list[str],
dataset_ids: list[str],
agg_fields: list[str] | None = None,
rank_feature: dict | None = None
):
"""
@ -209,28 +208,28 @@ class DocStoreConnection(ABC):
raise NotImplementedError("Not implemented")
@abstractmethod
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
def get(self, data_id: str, index_name: str, dataset_ids: list[str]) -> dict | None:
"""
Get single chunk with given id
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
def insert(self, rows: list[dict], index_name: str, dataset_id: str = None) -> list[str]:
"""
Update or insert a bulk of rows
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
def update(self, condition: dict, new_value: dict, index_name: str, dataset_id: str) -> bool:
"""
Update rows with given conjunctive equivalent filtering condition
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
def delete(self, condition: dict, index_name: str, dataset_id: str) -> int:
"""
Delete rows with given conjunctive equivalent filtering condition
"""
@ -245,7 +244,7 @@ class DocStoreConnection(ABC):
raise NotImplementedError("Not implemented")
@abstractmethod
def get_chunk_ids(self, res):
def get_doc_ids(self, res):
raise NotImplementedError("Not implemented")
@abstractmethod
@ -253,18 +252,18 @@ class DocStoreConnection(ABC):
raise NotImplementedError("Not implemented")
@abstractmethod
def get_highlight(self, res, keywords: list[str], fieldnm: str):
def get_highlight(self, res, keywords: list[str], field_name: str):
raise NotImplementedError("Not implemented")
@abstractmethod
def get_aggregation(self, res, fieldnm: str):
def get_aggregation(self, res, field_name: str):
raise NotImplementedError("Not implemented")
"""
SQL
"""
@abstractmethod
def sql(sql: str, fetch_size: int, format: str):
def sql(self, sql: str, fetch_size: int, format: str):
"""
Run the sql generated by text-to-sql
"""

View File

@ -0,0 +1,326 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import re
import json
import time
import os
from abc import abstractmethod
from elasticsearch import Elasticsearch, NotFoundError
from elasticsearch_dsl import Index
from elastic_transport import ConnectionTimeout
from common.file_utils import get_project_base_directory
from common.misc_utils import convert_bytes
from common.doc_store.doc_store_base import DocStoreConnection, OrderByExpr, MatchExpr
from rag.nlp import is_english, rag_tokenizer
from common import settings
ATTEMPT_TIME = 2
class ESConnectionBase(DocStoreConnection):
def __init__(self, mapping_file_name: str="mapping.json", logger_name: str='ragflow.es_conn'):
self.logger = logging.getLogger(logger_name)
self.info = {}
self.logger.info(f"Use Elasticsearch {settings.ES['hosts']} as the doc engine.")
for _ in range(ATTEMPT_TIME):
try:
if self._connect():
break
except Exception as e:
self.logger.warning(f"{str(e)}. Waiting Elasticsearch {settings.ES['hosts']} to be healthy.")
time.sleep(5)
if not self.es.ping():
msg = f"Elasticsearch {settings.ES['hosts']} is unhealthy in 120s."
self.logger.error(msg)
raise Exception(msg)
v = self.info.get("version", {"number": "8.11.3"})
v = v["number"].split(".")[0]
if int(v) < 8:
msg = f"Elasticsearch version must be greater than or equal to 8, current version: {v}"
self.logger.error(msg)
raise Exception(msg)
fp_mapping = os.path.join(get_project_base_directory(), "conf", mapping_file_name)
if not os.path.exists(fp_mapping):
msg = f"Elasticsearch mapping file not found at {fp_mapping}"
self.logger.error(msg)
raise Exception(msg)
self.mapping = json.load(open(fp_mapping, "r"))
self.logger.info(f"Elasticsearch {settings.ES['hosts']} is healthy.")
def _connect(self):
self.es = Elasticsearch(
settings.ES["hosts"].split(","),
basic_auth=(settings.ES["username"], settings.ES[
"password"]) if "username" in settings.ES and "password" in settings.ES else None,
verify_certs= settings.ES.get("verify_certs", False),
timeout=600 )
if self.es:
self.info = self.es.info()
return True
return False
"""
Database operations
"""
def db_type(self) -> str:
return "elasticsearch"
def health(self) -> dict:
health_dict = dict(self.es.cluster.health())
health_dict["type"] = "elasticsearch"
return health_dict
def get_cluster_stats(self):
"""
curl -XGET "http://{es_host}/_cluster/stats" -H "kbn-xsrf: reporting" to view raw stats.
"""
raw_stats = self.es.cluster.stats()
self.logger.debug(f"ESConnection.get_cluster_stats: {raw_stats}")
try:
res = {
'cluster_name': raw_stats['cluster_name'],
'status': raw_stats['status']
}
indices_status = raw_stats['indices']
res.update({
'indices': indices_status['count'],
'indices_shards': indices_status['shards']['total']
})
doc_info = indices_status['docs']
res.update({
'docs': doc_info['count'],
'docs_deleted': doc_info['deleted']
})
store_info = indices_status['store']
res.update({
'store_size': convert_bytes(store_info['size_in_bytes']),
'total_dataset_size': convert_bytes(store_info['total_data_set_size_in_bytes'])
})
mappings_info = indices_status['mappings']
res.update({
'mappings_fields': mappings_info['total_field_count'],
'mappings_deduplicated_fields': mappings_info['total_deduplicated_field_count'],
'mappings_deduplicated_size': convert_bytes(mappings_info['total_deduplicated_mapping_size_in_bytes'])
})
node_info = raw_stats['nodes']
res.update({
'nodes': node_info['count']['total'],
'nodes_version': node_info['versions'],
'os_mem': convert_bytes(node_info['os']['mem']['total_in_bytes']),
'os_mem_used': convert_bytes(node_info['os']['mem']['used_in_bytes']),
'os_mem_used_percent': node_info['os']['mem']['used_percent'],
'jvm_versions': node_info['jvm']['versions'][0]['vm_version'],
'jvm_heap_used': convert_bytes(node_info['jvm']['mem']['heap_used_in_bytes']),
'jvm_heap_max': convert_bytes(node_info['jvm']['mem']['heap_max_in_bytes'])
})
return res
except Exception as e:
self.logger.exception(f"ESConnection.get_cluster_stats: {e}")
return None
"""
Table operations
"""
def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
if self.index_exist(index_name, dataset_id):
return True
try:
from elasticsearch.client import IndicesClient
return IndicesClient(self.es).create(index=index_name,
settings=self.mapping["settings"],
mappings=self.mapping["mappings"])
except Exception:
self.logger.exception("ESConnection.createIndex error %s" % index_name)
def delete_idx(self, index_name: str, dataset_id: str):
if len(dataset_id) > 0:
# The index need to be alive after any kb deletion since all kb under this tenant are in one index.
return
try:
self.es.indices.delete(index=index_name, allow_no_indices=True)
except NotFoundError:
pass
except Exception:
self.logger.exception("ESConnection.deleteIdx error %s" % index_name)
def index_exist(self, index_name: str, dataset_id: str = None) -> bool:
s = Index(index_name, self.es)
for i in range(ATTEMPT_TIME):
try:
return s.exists()
except ConnectionTimeout:
self.logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
self.logger.exception(e)
break
return False
"""
CRUD operations
"""
def get(self, doc_id: str, index_name: str, dataset_ids: list[str]) -> dict | None:
for i in range(ATTEMPT_TIME):
try:
res = self.es.get(index=index_name,
id=doc_id, source=True, )
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
doc = res["_source"]
doc["id"] = doc_id
return doc
except NotFoundError:
return None
except Exception as e:
self.logger.exception(f"ESConnection.get({doc_id}) got exception")
raise e
self.logger.error(f"ESConnection.get timeout for {ATTEMPT_TIME} times!")
raise Exception("ESConnection.get timeout.")
@abstractmethod
def search(
self, select_fields: list[str],
highlight_fields: list[str],
condition: dict,
match_expressions: list[MatchExpr],
order_by: OrderByExpr,
offset: int,
limit: int,
index_names: str | list[str],
dataset_ids: list[str],
agg_fields: list[str] | None = None,
rank_feature: dict | None = None
):
raise NotImplementedError("Not implemented")
@abstractmethod
def insert(self, documents: list[dict], index_name: str, dataset_id: str = None) -> list[str]:
raise NotImplementedError("Not implemented")
@abstractmethod
def update(self, condition: dict, new_value: dict, index_name: str, dataset_id: str) -> bool:
raise NotImplementedError("Not implemented")
@abstractmethod
def delete(self, condition: dict, index_name: str, dataset_id: str) -> int:
raise NotImplementedError("Not implemented")
"""
Helper functions for search result
"""
def get_total(self, res):
if isinstance(res["hits"]["total"], type({})):
return res["hits"]["total"]["value"]
return res["hits"]["total"]
def get_doc_ids(self, res):
return [d["_id"] for d in res["hits"]["hits"]]
def _get_source(self, res):
rr = []
for d in res["hits"]["hits"]:
d["_source"]["id"] = d["_id"]
d["_source"]["_score"] = d["_score"]
rr.append(d["_source"])
return rr
@abstractmethod
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
raise NotImplementedError("Not implemented")
def get_highlight(self, res, keywords: list[str], field_name: str):
ans = {}
for d in res["hits"]["hits"]:
highlights = d.get("highlight")
if not highlights:
continue
txt = "...".join([a for a in list(highlights.items())[0][1]])
if not is_english(txt.split()):
ans[d["_id"]] = txt
continue
txt = d["_source"][field_name]
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
txt_list = []
for t in re.split(r"[.?!;\n]", txt):
for w in keywords:
t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w), r"\1<em>\2</em>\3", t,
flags=re.IGNORECASE | re.MULTILINE)
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
continue
txt_list.append(t)
ans[d["_id"]] = "...".join(txt_list) if txt_list else "...".join([a for a in list(highlights.items())[0][1]])
return ans
def get_aggregation(self, res, field_name: str):
agg_field = "aggs_" + field_name
if "aggregations" not in res or agg_field not in res["aggregations"]:
return list()
buckets = res["aggregations"][agg_field]["buckets"]
return [(b["key"], b["doc_count"]) for b in buckets]
"""
SQL
"""
def sql(self, sql: str, fetch_size: int, format: str):
self.logger.debug(f"ESConnection.sql get sql: {sql}")
sql = re.sub(r"[ `]+", " ", sql)
sql = sql.replace("%", "")
replaces = []
for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
fld, v = r.group(1), r.group(3)
match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
replaces.append(
("{}{}'{}'".format(
r.group(1),
r.group(2),
r.group(3)),
match))
for p, r in replaces:
sql = sql.replace(p, r, 1)
self.logger.debug(f"ESConnection.sql to es: {sql}")
for i in range(ATTEMPT_TIME):
try:
res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format,
request_timeout="2s")
return res
except ConnectionTimeout:
self.logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
self.logger.exception(f"ESConnection.sql got exception. SQL:\n{sql}")
raise Exception(f"SQL error: {e}\n\nSQL: {sql}")
self.logger.error(f"ESConnection.sql timeout for {ATTEMPT_TIME} times!")
return None

View File

@ -0,0 +1,451 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import re
import json
import time
from abc import abstractmethod
import infinity
from infinity.common import ConflictType
from infinity.index import IndexInfo, IndexType
from infinity.connection_pool import ConnectionPool
from infinity.errors import ErrorCode
import pandas as pd
from common.file_utils import get_project_base_directory
from rag.nlp import is_english
from common import settings
from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr
class InfinityConnectionBase(DocStoreConnection):
def __init__(self, mapping_file_name: str="infinity_mapping.json", logger_name: str="ragflow.infinity_conn"):
self.dbName = settings.INFINITY.get("db_name", "default_db")
self.mapping_file_name = mapping_file_name
self.logger = logging.getLogger(logger_name)
infinity_uri = settings.INFINITY["uri"]
if ":" in infinity_uri:
host, port = infinity_uri.split(":")
infinity_uri = infinity.common.NetworkAddress(host, int(port))
self.connPool = None
self.logger.info(f"Use Infinity {infinity_uri} as the doc engine.")
for _ in range(24):
try:
conn_pool = ConnectionPool(infinity_uri, max_size=4)
inf_conn = conn_pool.get_conn()
res = inf_conn.show_current_node()
if res.error_code == ErrorCode.OK and res.server_status in ["started", "alive"]:
self._migrate_db(inf_conn)
self.connPool = conn_pool
conn_pool.release_conn(inf_conn)
break
conn_pool.release_conn(inf_conn)
self.logger.warning(f"Infinity status: {res.server_status}. Waiting Infinity {infinity_uri} to be healthy.")
time.sleep(5)
except Exception as e:
self.logger.warning(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.")
time.sleep(5)
if self.connPool is None:
msg = f"Infinity {infinity_uri} is unhealthy in 120s."
self.logger.error(msg)
raise Exception(msg)
self.logger.info(f"Infinity {infinity_uri} is healthy.")
def _migrate_db(self, inf_conn):
inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
fp_mapping = os.path.join(get_project_base_directory(), "conf", self.mapping_file_name)
if not os.path.exists(fp_mapping):
raise Exception(f"Mapping file not found at {fp_mapping}")
schema = json.load(open(fp_mapping))
table_names = inf_db.list_tables().table_names
for table_name in table_names:
inf_table = inf_db.get_table(table_name)
index_names = inf_table.list_indexes().index_names
if "q_vec_idx" not in index_names:
# Skip tables not created by me
continue
column_names = inf_table.show_columns()["name"]
column_names = set(column_names)
for field_name, field_info in schema.items():
if field_name in column_names:
continue
res = inf_table.add_columns({field_name: field_info})
assert res.error_code == infinity.ErrorCode.OK
self.logger.info(f"INFINITY added following column to table {table_name}: {field_name} {field_info}")
if field_info["type"] != "varchar" or "analyzer" not in field_info:
continue
analyzers = field_info["analyzer"]
if isinstance(analyzers, str):
analyzers = [analyzers]
for analyzer in analyzers:
inf_table.create_index(
f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}",
IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}),
ConflictType.Ignore,
)
"""
Dataframe and fields convert
"""
@staticmethod
@abstractmethod
def field_keyword(field_name: str):
# judge keyword or not, such as "*_kwd" tag-like columns.
raise NotImplementedError("Not implemented")
@abstractmethod
def convert_select_fields(self, output_fields: list[str]) -> list[str]:
# rm _kwd, _tks, _sm_tks, _with_weight suffix in field name.
raise NotImplementedError("Not implemented")
@staticmethod
@abstractmethod
def convert_matching_field(field_weight_str: str) -> str:
# convert matching field to
raise NotImplementedError("Not implemented")
@staticmethod
def list2str(lst: str | list, sep: str = " ") -> str:
if isinstance(lst, str):
return lst
return sep.join(lst)
def equivalent_condition_to_str(self, condition: dict, table_instance=None) -> str | None:
assert "_id" not in condition
columns = {}
if table_instance:
for n, ty, de, _ in table_instance.show_columns().rows():
columns[n] = (ty, de)
def exists(cln):
nonlocal columns
assert cln in columns, f"'{cln}' should be in '{columns}'."
ty, de = columns[cln]
if ty.lower().find("cha"):
if not de:
de = ""
return f" {cln}!='{de}' "
return f"{cln}!={de}"
cond = list()
for k, v in condition.items():
if not isinstance(k, str) or not v:
continue
if self.field_keyword(k):
if isinstance(v, list):
inCond = list()
for item in v:
if isinstance(item, str):
item = item.replace("'", "''")
inCond.append(f"filter_fulltext('{self.convert_matching_field(k)}', '{item}')")
if inCond:
strInCond = " or ".join(inCond)
strInCond = f"({strInCond})"
cond.append(strInCond)
else:
cond.append(f"filter_fulltext('{self.convert_matching_field(k)}', '{v}')")
elif isinstance(v, list):
inCond = list()
for item in v:
if isinstance(item, str):
item = item.replace("'", "''")
inCond.append(f"'{item}'")
else:
inCond.append(str(item))
if inCond:
strInCond = ", ".join(inCond)
strInCond = f"{k} IN ({strInCond})"
cond.append(strInCond)
elif k == "must_not":
if isinstance(v, dict):
for kk, vv in v.items():
if kk == "exists":
cond.append("NOT (%s)" % exists(vv))
elif isinstance(v, str):
cond.append(f"{k}='{v}'")
elif k == "exists":
cond.append(exists(v))
else:
cond.append(f"{k}={str(v)}")
return " AND ".join(cond) if cond else "1=1"
@staticmethod
def concat_dataframes(df_list: list[pd.DataFrame], select_fields: list[str]) -> pd.DataFrame:
df_list2 = [df for df in df_list if not df.empty]
if df_list2:
return pd.concat(df_list2, axis=0).reset_index(drop=True)
schema = []
for field_name in select_fields:
if field_name == "score()": # Workaround: fix schema is changed to score()
schema.append("SCORE")
elif field_name == "similarity()": # Workaround: fix schema is changed to similarity()
schema.append("SIMILARITY")
else:
schema.append(field_name)
return pd.DataFrame(columns=schema)
"""
Database operations
"""
def db_type(self) -> str:
return "infinity"
def health(self) -> dict:
"""
Return the health status of the database.
"""
inf_conn = self.connPool.get_conn()
res = inf_conn.show_current_node()
self.connPool.release_conn(inf_conn)
res2 = {
"type": "infinity",
"status": "green" if res.error_code == 0 and res.server_status in ["started", "alive"] else "red",
"error": res.error_msg,
}
return res2
"""
Table operations
"""
def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
table_name = f"{index_name}_{dataset_id}"
inf_conn = self.connPool.get_conn()
inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
fp_mapping = os.path.join(get_project_base_directory(), "conf", self.mapping_file_name)
if not os.path.exists(fp_mapping):
raise Exception(f"Mapping file not found at {fp_mapping}")
schema = json.load(open(fp_mapping))
vector_name = f"q_{vector_size}_vec"
schema[vector_name] = {"type": f"vector,{vector_size},float"}
inf_table = inf_db.create_table(
table_name,
schema,
ConflictType.Ignore,
)
inf_table.create_index(
"q_vec_idx",
IndexInfo(
vector_name,
IndexType.Hnsw,
{
"M": "16",
"ef_construction": "50",
"metric": "cosine",
"encode": "lvq",
},
),
ConflictType.Ignore,
)
for field_name, field_info in schema.items():
if field_info["type"] != "varchar" or "analyzer" not in field_info:
continue
analyzers = field_info["analyzer"]
if isinstance(analyzers, str):
analyzers = [analyzers]
for analyzer in analyzers:
inf_table.create_index(
f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}",
IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}),
ConflictType.Ignore,
)
self.connPool.release_conn(inf_conn)
self.logger.info(f"INFINITY created table {table_name}, vector size {vector_size}")
return True
def delete_idx(self, index_name: str, dataset_id: str):
table_name = f"{index_name}_{dataset_id}"
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
db_instance.drop_table(table_name, ConflictType.Ignore)
self.connPool.release_conn(inf_conn)
self.logger.info(f"INFINITY dropped table {table_name}")
def index_exist(self, index_name: str, dataset_id: str) -> bool:
table_name = f"{index_name}_{dataset_id}"
try:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
_ = db_instance.get_table(table_name)
self.connPool.release_conn(inf_conn)
return True
except Exception as e:
self.logger.warning(f"INFINITY indexExist {str(e)}")
return False
"""
CRUD operations
"""
@abstractmethod
def search(
self,
select_fields: list[str],
highlight_fields: list[str],
condition: dict,
match_expressions: list[MatchExpr],
order_by: OrderByExpr,
offset: int,
limit: int,
index_names: str | list[str],
dataset_ids: list[str],
agg_fields: list[str] | None = None,
rank_feature: dict | None = None,
) -> tuple[pd.DataFrame, int]:
raise NotImplementedError("Not implemented")
@abstractmethod
def get(self, doc_id: str, index_name: str, knowledgebase_ids: list[str]) -> dict | None:
raise NotImplementedError("Not implemented")
@abstractmethod
def insert(self, documents: list[dict], index_name: str, dataset_ids: str = None) -> list[str]:
raise NotImplementedError("Not implemented")
@abstractmethod
def update(self, condition: dict, new_value: dict, index_name: str, dataset_id: str) -> bool:
raise NotImplementedError("Not implemented")
def delete(self, condition: dict, index_name: str, dataset_id: str) -> int:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{index_name}_{dataset_id}"
try:
table_instance = db_instance.get_table(table_name)
except Exception:
self.logger.warning(f"Skipped deleting from table {table_name} since the table doesn't exist.")
return 0
filter = self.equivalent_condition_to_str(condition, table_instance)
self.logger.debug(f"INFINITY delete table {table_name}, filter {filter}.")
res = table_instance.delete(filter)
self.connPool.release_conn(inf_conn)
return res.deleted_rows
"""
Helper functions for search result
"""
def get_total(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
if isinstance(res, tuple):
return res[1]
return len(res)
def get_doc_ids(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
if isinstance(res, tuple):
res = res[0]
return list(res["id"])
@abstractmethod
def get_fields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
raise NotImplementedError("Not implemented")
def get_highlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], field_name: str):
if isinstance(res, tuple):
res = res[0]
ans = {}
num_rows = len(res)
column_id = res["id"]
if field_name not in res:
return {}
for i in range(num_rows):
id = column_id[i]
txt = res[field_name][i]
if re.search(r"<em>[^<>]+</em>", txt, flags=re.IGNORECASE | re.MULTILINE):
ans[id] = txt
continue
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
txt_list = []
for t in re.split(r"[.?!;\n]", txt):
if is_english([t]):
for w in keywords:
t = re.sub(
r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w),
r"\1<em>\2</em>\3",
t,
flags=re.IGNORECASE | re.MULTILINE,
)
else:
for w in sorted(keywords, key=len, reverse=True):
t = re.sub(
re.escape(w),
f"<em>{w}</em>",
t,
flags=re.IGNORECASE | re.MULTILINE,
)
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
continue
txt_list.append(t)
if txt_list:
ans[id] = "...".join(txt_list)
else:
ans[id] = txt
return ans
def get_aggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, field_name: str):
"""
Manual aggregation for tag fields since Infinity doesn't provide native aggregation
"""
from collections import Counter
# Extract DataFrame from result
if isinstance(res, tuple):
df, _ = res
else:
df = res
if df.empty or field_name not in df.columns:
return []
# Aggregate tag counts
tag_counter = Counter()
for value in df[field_name]:
if pd.isna(value) or not value:
continue
# Handle different tag formats
if isinstance(value, str):
# Split by ### for tag_kwd field or comma for other formats
if field_name == "tag_kwd" and "###" in value:
tags = [tag.strip() for tag in value.split("###") if tag.strip()]
else:
# Try comma separation as fallback
tags = [tag.strip() for tag in value.split(",") if tag.strip()]
for tag in tags:
if tag: # Only count non-empty tags
tag_counter[tag] += 1
elif isinstance(value, list):
# Handle list format
for tag in value:
if tag and isinstance(tag, str):
tag_counter[tag.strip()] += 1
# Return as list of [tag, count] pairs, sorted by count descending
return [[tag, count] for tag, count in tag_counter.most_common()]
"""
SQL
"""
def sql(self, sql: str, fetch_size: int, format: str):
raise NotImplementedError("Not implemented")

View File

@ -44,21 +44,27 @@ def meta_filter(metas: dict, filters: list[dict], logic: str = "and"):
def filter_out(v2docs, operator, value):
ids = []
for input, docids in v2docs.items():
if operator in ["=", "", ">", "<", "", ""]:
try:
if isinstance(input, list):
input = input[0]
input = float(input)
value = float(value)
except Exception:
input = str(input)
value = str(value)
pass
if isinstance(input, str):
input = input.lower()
if isinstance(value, str):
value = value.lower()
for conds in [
(operator == "contains", str(value).lower() in str(input).lower()),
(operator == "not contains", str(value).lower() not in str(input).lower()),
(operator == "in", str(input).lower() in str(value).lower()),
(operator == "not in", str(input).lower() not in str(value).lower()),
(operator == "start with", str(input).lower().startswith(str(value).lower())),
(operator == "end with", str(input).lower().endswith(str(value).lower())),
(operator == "contains", input in value if not isinstance(input, list) else all([i in value for i in input])),
(operator == "not contains", input not in value if not isinstance(input, list) else all([i not in value for i in input])),
(operator == "in", input in value if not isinstance(input, list) else all([i in value for i in input])),
(operator == "not in", input not in value if not isinstance(input, list) else all([i not in value for i in input])),
(operator == "start with", str(input).lower().startswith(str(value).lower()) if not isinstance(input, list) else "".join([str(i).lower() for i in input]).startswith(str(value).lower())),
(operator == "end with", str(input).lower().endswith(str(value).lower()) if not isinstance(input, list) else "".join([str(i).lower() for i in input]).endswith(str(value).lower())),
(operator == "empty", not input),
(operator == "not empty", input),
(operator == "=", input == value),
@ -145,6 +151,18 @@ async def apply_meta_data_filter(
return doc_ids
def dedupe_list(values: list) -> list:
seen = set()
deduped = []
for item in values:
key = str(item)
if key in seen:
continue
seen.add(key)
deduped.append(item)
return deduped
def update_metadata_to(metadata, meta):
if not meta:
return metadata
@ -156,11 +174,13 @@ def update_metadata_to(metadata, meta):
return metadata
if not isinstance(meta, dict):
return metadata
for k, v in meta.items():
if isinstance(v, list):
v = [vv for vv in v if isinstance(vv, str)]
if not v:
continue
v = dedupe_list(v)
if not isinstance(v, list) and not isinstance(v, str):
continue
if k not in metadata:
@ -171,6 +191,7 @@ def update_metadata_to(metadata, meta):
metadata[k].extend(v)
else:
metadata[k].append(v)
metadata[k] = dedupe_list(metadata[k])
else:
metadata[k] = v
@ -202,4 +223,4 @@ def metadata_schema(metadata: list|None) -> Dict[str, Any]:
}
json_schema["additionalProperties"] = False
return json_schema
return json_schema

72
common/query_base.py Normal file
View File

@ -0,0 +1,72 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
from abc import ABC, abstractmethod
class QueryBase(ABC):
@staticmethod
def is_chinese(line):
arr = re.split(r"[ \t]+", line)
if len(arr) <= 3:
return True
e = 0
for t in arr:
if not re.match(r"[a-zA-Z]+$", t):
e += 1
return e * 1.0 / len(arr) >= 0.7
@staticmethod
def sub_special_char(line):
return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|\+~\^])", r"\\\1", line).strip()
@staticmethod
def rmWWW(txt):
patts = [
(
r"是*(怎么办|什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀|谁|哪位|哪个)是*",
"",
),
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
(
r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ",
" ")
]
otxt = txt
for r, p in patts:
txt = re.sub(r, p, txt, flags=re.IGNORECASE)
if not txt:
txt = otxt
return txt
@staticmethod
def add_space_between_eng_zh(txt):
# (ENG/ENG+NUM) + ZH
txt = re.sub(r'([A-Za-z]+[0-9]+)([\u4e00-\u9fa5]+)', r'\1 \2', txt)
# ENG + ZH
txt = re.sub(r'([A-Za-z])([\u4e00-\u9fa5]+)', r'\1 \2', txt)
# ZH + (ENG/ENG+NUM)
txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z]+[0-9]+)', r'\1 \2', txt)
txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z])', r'\1 \2', txt)
return txt
@abstractmethod
def question(self, text, tbl, min_match):
"""
Returns a query object based on the input text, table, and minimum match criteria.
"""
raise NotImplementedError("Not implemented")

View File

@ -39,6 +39,9 @@ from rag.utils.oss_conn import RAGFlowOSS
from rag.nlp import search
import memory.utils.es_conn as memory_es_conn
import memory.utils.infinity_conn as memory_infinity_conn
LLM = None
LLM_FACTORY = None
LLM_BASE_URL = None
@ -76,9 +79,11 @@ FEISHU_OAUTH = None
OAUTH_CONFIG = None
DOC_ENGINE = os.getenv('DOC_ENGINE', 'elasticsearch')
DOC_ENGINE_INFINITY = (DOC_ENGINE.lower() == "infinity")
MSG_ENGINE = DOC_ENGINE
docStoreConn = None
msgStoreConn = None
retriever = None
kg_retriever = None
@ -256,6 +261,15 @@ def init_settings():
else:
raise Exception(f"Not supported doc engine: {DOC_ENGINE}")
global MSG_ENGINE, msgStoreConn
MSG_ENGINE = DOC_ENGINE # use the same engine for message store
if MSG_ENGINE == "elasticsearch":
ES = get_base_config("es", {})
msgStoreConn = memory_es_conn.ESConnection()
elif MSG_ENGINE == "infinity":
INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
msgStoreConn = memory_infinity_conn.InfinityConnection()
global AZURE, S3, MINIO, OSS, GCS
if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']:
AZURE = get_base_config("azure", {})

View File

@ -14,6 +14,7 @@
# limitations under the License.
import datetime
import logging
import time
def current_timestamp():
@ -123,4 +124,31 @@ def delta_seconds(date_string: str):
3600.0 # If current time is 2024-01-01 13:00:00
"""
dt = datetime.datetime.strptime(date_string, "%Y-%m-%d %H:%M:%S")
return (datetime.datetime.now() - dt).total_seconds()
return (datetime.datetime.now() - dt).total_seconds()
def format_iso_8601_to_ymd_hms(time_str: str) -> str:
"""
Convert ISO 8601 formatted string to "YYYY-MM-DD HH:MM:SS" format.
Args:
time_str: ISO 8601 date string (e.g. "2024-01-01T12:00:00Z")
Returns:
str: Date string in "YYYY-MM-DD HH:MM:SS" format
Example:
>>> format_iso_8601_to_ymd_hms("2024-01-01T12:00:00Z")
'2024-01-01 12:00:00'
"""
from dateutil import parser
try:
if parser.isoparse(time_str):
dt = datetime.datetime.fromisoformat(time_str.replace("Z", "+00:00"))
return dt.strftime("%Y-%m-%d %H:%M:%S")
else:
return time_str
except Exception as e:
logging.error(str(e))
return time_str

View File

@ -1258,6 +1258,12 @@
"status": "1",
"rank": "810",
"llm": [
{
"llm_name": "MiniMax-M2.1",
"tags": "LLM,CHAT,200k",
"max_tokens": 200000,
"model_type": "chat"
},
{
"llm_name": "MiniMax-M2",
"tags": "LLM,CHAT,200k",

View File

@ -0,0 +1,19 @@
{
"id": {"type": "varchar", "default": ""},
"message_id": {"type": "integer", "default": 0},
"message_type_kwd": {"type": "varchar", "default": ""},
"source_id": {"type": "integer", "default": 0},
"memory_id": {"type": "varchar", "default": ""},
"user_id": {"type": "varchar", "default": ""},
"agent_id": {"type": "varchar", "default": ""},
"session_id": {"type": "varchar", "default": ""},
"valid_at": {"type": "varchar", "default": ""},
"valid_at_flt": {"type": "float", "default": 0.0},
"invalid_at": {"type": "varchar", "default": ""},
"invalid_at_flt": {"type": "float", "default": 0.0},
"forget_at": {"type": "varchar", "default": ""},
"forget_at_flt": {"type": "float", "default": 0.0},
"status_int": {"type": "integer", "default": 1},
"zone_id": {"type": "integer", "default": 0},
"content": {"type": "varchar", "default": "", "analyzer": ["rag-coarse", "rag-fine"], "comment": "content_ltks"}
}

View File

@ -1447,6 +1447,7 @@ class VisionParser(RAGFlowPdfParser):
def __init__(self, vision_model, *args, **kwargs):
super().__init__(*args, **kwargs)
self.vision_model = vision_model
self.outlines = []
def __images__(self, fnm, zoomin=3, page_from=0, page_to=299, callback=None):
try:

View File

@ -72,7 +72,7 @@ services:
infinity:
profiles:
- infinity
image: infiniflow/infinity:v0.6.11
image: infiniflow/infinity:v0.6.13
volumes:
- infinity_data:/var/infinity
- ./infinity_conf.toml:/infinity_conf.toml

View File

@ -1,5 +1,5 @@
[general]
version = "0.6.11"
version = "0.6.13"
time_zone = "utc-8"
[network]

View File

@ -0,0 +1,90 @@
---
sidebar_position: 30
slug: /http_request_component
---
# HTTP request component
A component that calls remote services.
---
An **HTTP request** component lets you access remote APIs or services by providing a URL and an HTTP method, and then receive the response. You can customize headers, parameters, proxies, and timeout settings, and use common methods like GET and POST. Its useful for exchanging data with external systems in a workflow.
## Prerequisites
- An accessible remote API or service.
- Add a Token or credentials to the request header, if the target service requires authentication.
## Configurations
### Url
*Required*. The complete request address, for example: http://api.example.com/data.
### Method
The HTTP request method to select. Available options:
- GET
- POST
- PUT
### Timeout
The maximum waiting time for the request, in seconds. Defaults to `60`.
### Headers
Custom HTTP headers can be set here, for example:
```http
{
"Accept": "application/json",
"Cache-Control": "no-cache",
"Connection": "keep-alive"
}
```
### Proxy
Optional. The proxy server address to use for this request.
### Clean HTML
`Boolean`: Whether to remove HTML tags from the returned results and keep plain text only.
### Parameter
*Optional*. Parameters to send with the HTTP request. Supports key-value pairs:
- To assign a value using a dynamic system variable, set it as Variable.
- To override these dynamic values under certain conditions and use a fixed static value instead, Value is the appropriate choice.
:::tip NOTE
- For GET requests, these parameters are appended to the end of the URL.
- For POST/PUT requests, they are sent as the request body.
:::
#### Example setting
![](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/http_settings.png)
#### Example response
```html
{ "args": { "App": "RAGFlow", "Query": "How to do?", "Userid": "241ed25a8e1011f0b979424ebc5b108b" }, "headers": { "Accept": "/", "Accept-Encoding": "gzip, deflate, br, zstd", "Cache-Control": "no-cache", "Host": "httpbin.org", "User-Agent": "python-requests/2.32.2", "X-Amzn-Trace-Id": "Root=1-68c9210c-5aab9088580c130a2f065523" }, "origin": "185.36.193.38", "url": "https://httpbin.org/get?Userid=241ed25a8e1011f0b979424ebc5b108b&App=RAGFlow&Query=How+to+do%3F" }
```
### Output
The global variable name for the output of the HTTP request component, which can be referenced by other components in the workflow.
- `Result`: `string` The response returned by the remote service.
## Example
This is a usage example: a workflow sends a GET request from the **Begin** component to `https://httpbin.org/get` via the **HTTP Request_0** component, passes parameters to the server, and finally outputs the result through the **Message_0** component.
![](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/http_usage.PNG)

View File

@ -24,11 +24,11 @@ from common.misc_utils import get_uuid
from graphrag.query_analyze_prompt import PROMPTS
from graphrag.utils import get_entity_type2samples, get_llm_cache, set_llm_cache, get_relation
from common.token_utils import num_tokens_from_string
from rag.utils.doc_store_conn import OrderByExpr
from rag.nlp.search import Dealer, index_name
from common.float_utils import get_float
from common import settings
from common.doc_store.doc_store_base import OrderByExpr
class KGSearch(Dealer):

View File

@ -26,9 +26,9 @@ from networkx.readwrite import json_graph
from common.misc_utils import get_uuid
from common.connection_utils import timeout
from rag.nlp import rag_tokenizer, search
from rag.utils.doc_store_conn import OrderByExpr
from rag.utils.redis_conn import REDIS_CONN
from common import settings
from common.doc_store.doc_store_base import OrderByExpr
GRAPH_FIELD_SEP = "<SEP>"

View File

@ -96,7 +96,7 @@ ragflow:
infinity:
image:
repository: infiniflow/infinity
tag: v0.6.11
tag: v0.6.13
pullPolicy: IfNotPresent
pullSecrets: []
storage:

0
memory/__init__.py Normal file
View File

View File

240
memory/services/messages.py Normal file
View File

@ -0,0 +1,240 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
from typing import List
from common import settings
from common.doc_store.doc_store_base import OrderByExpr, MatchExpr
def index_name(uid: str): return f"memory_{uid}"
class MessageService:
@classmethod
def has_index(cls, uid: str, memory_id: str):
index = index_name(uid)
return settings.msgStoreConn.index_exist(index, memory_id)
@classmethod
def create_index(cls, uid: str, memory_id: str, vector_size: int):
index = index_name(uid)
return settings.msgStoreConn.create_idx(index, memory_id, vector_size)
@classmethod
def delete_index(cls, uid: str, memory_id: str):
index = index_name(uid)
return settings.msgStoreConn.delete_idx(index, memory_id)
@classmethod
def insert_message(cls, messages: List[dict], uid: str, memory_id: str):
index = index_name(uid)
[m.update({
"id": f'{memory_id}_{m["message_id"]}',
"status": 1 if m["status"] else 0
}) for m in messages]
return settings.msgStoreConn.insert(messages, index, memory_id)
@classmethod
def update_message(cls, condition: dict, update_dict: dict, uid: str, memory_id: str):
index = index_name(uid)
if "status" in update_dict:
update_dict["status"] = 1 if update_dict["status"] else 0
return settings.msgStoreConn.update(condition, update_dict, index, memory_id)
@classmethod
def delete_message(cls, condition: dict, uid: str, memory_id: str):
index = index_name(uid)
return settings.msgStoreConn.delete(condition, index, memory_id)
@classmethod
def list_message(cls, uid: str, memory_id: str, agent_ids: List[str]=None, keywords: str=None, page: int=1, page_size: int=50):
index = index_name(uid)
filter_dict = {}
if agent_ids:
filter_dict["agent_id"] = agent_ids
if keywords:
filter_dict["session_id"] = keywords
order_by = OrderByExpr()
order_by.desc("valid_at")
res = settings.msgStoreConn.search(
select_fields=[
"message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at",
"invalid_at", "forget_at", "status"
],
highlight_fields=[],
condition=filter_dict,
match_expressions=[], order_by=order_by,
offset=(page-1)*page_size, limit=page_size,
index_names=index, memory_ids=[memory_id], agg_fields=[], hide_forgotten=False
)
total_count = settings.msgStoreConn.get_total(res)
doc_mapping = settings.msgStoreConn.get_fields(res, [
"message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id",
"valid_at", "invalid_at", "forget_at", "status"
])
return {
"message_list": list(doc_mapping.values()),
"total_count": total_count
}
@classmethod
def get_recent_messages(cls, uid_list: List[str], memory_ids: List[str], agent_id: str, session_id: str, limit: int):
index_names = [index_name(uid) for uid in uid_list]
condition_dict = {
"agent_id": agent_id,
"session_id": session_id
}
order_by = OrderByExpr()
order_by.desc("valid_at")
res = settings.msgStoreConn.search(
select_fields=[
"message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at",
"invalid_at", "forget_at", "status", "content"
],
highlight_fields=[],
condition=condition_dict,
match_expressions=[], order_by=order_by,
offset=0, limit=limit,
index_names=index_names, memory_ids=memory_ids, agg_fields=[]
)
doc_mapping = settings.msgStoreConn.get_fields(res, [
"message_id", "message_type", "source_id", "memory_id","user_id", "agent_id", "session_id",
"valid_at", "invalid_at", "forget_at", "status", "content"
])
return list(doc_mapping.values())
@classmethod
def search_message(cls, memory_ids: List[str], condition_dict: dict, uid_list: List[str], match_expressions:list[MatchExpr], top_n: int):
index_names = [index_name(uid) for uid in uid_list]
# filter only valid messages by default
if "status" not in condition_dict:
condition_dict["status"] = 1
order_by = OrderByExpr()
order_by.desc("valid_at")
res = settings.msgStoreConn.search(
select_fields=[
"message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id",
"valid_at",
"invalid_at", "forget_at", "status", "content"
],
highlight_fields=[],
condition=condition_dict,
match_expressions=match_expressions,
order_by=order_by,
offset=0, limit=top_n,
index_names=index_names, memory_ids=memory_ids, agg_fields=[]
)
docs = settings.msgStoreConn.get_fields(res, [
"message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at",
"invalid_at", "forget_at", "status", "content"
])
return list(docs.values())
@staticmethod
def calculate_message_size(message: dict):
return sys.getsizeof(message["content"]) + sys.getsizeof(message["content_embed"][0]) * len(message["content_embed"])
@classmethod
def calculate_memory_size(cls, memory_ids: List[str], uid_list: List[str]):
index_names = [index_name(uid) for uid in uid_list]
order_by = OrderByExpr()
order_by.desc("valid_at")
res = settings.msgStoreConn.search(
select_fields=["memory_id", "content", "content_embed"],
highlight_fields=[],
condition={},
match_expressions=[],
order_by=order_by,
offset=0, limit=2000*len(memory_ids),
index_names=index_names, memory_ids=memory_ids, agg_fields=[], hide_forgotten=False
)
docs = settings.msgStoreConn.get_fields(res, ["memory_id", "content", "content_embed"])
size_dict = {}
for doc in docs.values():
if size_dict.get(doc["memory_id"]):
size_dict[doc["memory_id"]] += cls.calculate_message_size(doc)
else:
size_dict[doc["memory_id"]] = cls.calculate_message_size(doc)
return size_dict
@classmethod
def pick_messages_to_delete_by_fifo(cls, memory_id: str, uid: str, size_to_delete: int):
select_fields = ["message_id", "content", "content_embed"]
_index_name = index_name(uid)
res = settings.msgStoreConn.get_forgotten_messages(select_fields, _index_name, memory_id)
message_list = settings.msgStoreConn.get_fields(res, select_fields)
current_size = 0
ids_to_remove = []
for message in message_list:
if current_size < size_to_delete:
current_size += cls.calculate_message_size(message)
ids_to_remove.append(message["message_id"])
else:
return ids_to_remove, current_size
if current_size >= size_to_delete:
return ids_to_remove, current_size
order_by = OrderByExpr()
order_by.asc("valid_at")
res = settings.msgStoreConn.search(
select_fields=["memory_id", "content", "content_embed"],
highlight_fields=[],
condition={},
match_expressions=[],
order_by=order_by,
offset=0, limit=2000,
index_names=[_index_name], memory_ids=[memory_id], agg_fields=[]
)
docs = settings.msgStoreConn.get_fields(res, select_fields)
for doc in docs.values():
if current_size < size_to_delete:
current_size += cls.calculate_message_size(doc)
ids_to_remove.append(doc["memory_id"])
else:
return ids_to_remove, current_size
return ids_to_remove, current_size
@classmethod
def get_by_message_id(cls, memory_id: str, message_id: int, uid: str):
index = index_name(uid)
doc_id = f'{memory_id}_{message_id}'
return settings.msgStoreConn.get(doc_id, index, [memory_id])
@classmethod
def get_max_message_id(cls, uid_list: List[str], memory_ids: List[str]):
order_by = OrderByExpr()
order_by.desc("message_id")
index_names = [index_name(uid) for uid in uid_list]
res = settings.msgStoreConn.search(
select_fields=["message_id"],
highlight_fields=[],
condition={},
match_expressions=[],
order_by=order_by,
offset=0, limit=1,
index_names=index_names, memory_ids=memory_ids,
agg_fields=[], hide_forgotten=False
)
docs = settings.msgStoreConn.get_fields(res, ["message_id"])
if not docs:
return 1
else:
latest_msg = list(docs.values())[0]
return int(latest_msg["message_id"])

185
memory/services/query.py Normal file
View File

@ -0,0 +1,185 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import logging
import json
import numpy as np
from common.query_base import QueryBase
from common.doc_store.doc_store_base import MatchDenseExpr, MatchTextExpr
from common.float_utils import get_float
from rag.nlp import rag_tokenizer, term_weight, synonym
def get_vector(txt, emb_mdl, topk=10, similarity=0.1):
if isinstance(similarity, str) and len(similarity) > 0:
try:
similarity = float(similarity)
except Exception as e:
logging.warning(f"Convert similarity '{similarity}' to float failed: {e}. Using default 0.1")
similarity = 0.1
qv, _ = emb_mdl.encode_queries(txt)
shape = np.array(qv).shape
if len(shape) > 1:
raise Exception(
f"Dealer.get_vector returned array's shape {shape} doesn't match expectation(exact one dimension).")
embedding_data = [get_float(v) for v in qv]
vector_column_name = f"q_{len(embedding_data)}_vec"
return MatchDenseExpr(vector_column_name, embedding_data, 'float', 'cosine', topk, {"similarity": similarity})
class MsgTextQuery(QueryBase):
def __init__(self):
self.tw = term_weight.Dealer()
self.syn = synonym.Dealer()
self.query_fields = [
"content"
]
def question(self, txt, tbl="messages", min_match: float=0.6):
original_query = txt
txt = MsgTextQuery.add_space_between_eng_zh(txt)
txt = re.sub(
r"[ :|\r\n\t,,。??/`!&^%%()\[\]{}<>]+",
" ",
rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(txt.lower())),
).strip()
otxt = txt
txt = MsgTextQuery.rmWWW(txt)
if not self.is_chinese(txt):
txt = self.rmWWW(txt)
tks = rag_tokenizer.tokenize(txt).split()
keywords = [t for t in tks if t]
tks_w = self.tw.weights(tks, preprocess=False)
tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w]
tks_w = [(re.sub(r"^[a-z0-9]$", "", tk), w) for tk, w in tks_w if tk]
tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk]
tks_w = [(tk.strip(), w) for tk, w in tks_w if tk.strip()]
syns = []
for tk, w in tks_w[:256]:
syn = self.syn.lookup(tk)
syn = rag_tokenizer.tokenize(" ".join(syn)).split()
keywords.extend(syn)
syn = ["\"{}\"^{:.4f}".format(s, w / 4.) for s in syn if s.strip()]
syns.append(" ".join(syn))
q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if
tk and not re.match(r"[.^+\(\)-]", tk)]
for i in range(1, len(tks_w)):
left, right = tks_w[i - 1][0].strip(), tks_w[i][0].strip()
if not left or not right:
continue
q.append(
'"%s %s"^%.4f'
% (
tks_w[i - 1][0],
tks_w[i][0],
max(tks_w[i - 1][1], tks_w[i][1]) * 2,
)
)
if not q:
q.append(txt)
query = " ".join(q)
return MatchTextExpr(
self.query_fields, query, 100, {"original_query": original_query}
), keywords
def need_fine_grained_tokenize(tk):
if len(tk) < 3:
return False
if re.match(r"[0-9a-z\.\+#_\*-]+$", tk):
return False
return True
txt = self.rmWWW(txt)
qs, keywords = [], []
for tt in self.tw.split(txt)[:256]: # .split():
if not tt:
continue
keywords.append(tt)
twts = self.tw.weights([tt])
syns = self.syn.lookup(tt)
if syns and len(keywords) < 32:
keywords.extend(syns)
logging.debug(json.dumps(twts, ensure_ascii=False))
tms = []
for tk, w in sorted(twts, key=lambda x: x[1] * -1):
sm = (
rag_tokenizer.fine_grained_tokenize(tk).split()
if need_fine_grained_tokenize(tk)
else []
)
sm = [
re.sub(
r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+",
"",
m,
)
for m in sm
]
sm = [self.sub_special_char(m) for m in sm if len(m) > 1]
sm = [m for m in sm if len(m) > 1]
if len(keywords) < 32:
keywords.append(re.sub(r"[ \\\"']+", "", tk))
keywords.extend(sm)
tk_syns = self.syn.lookup(tk)
tk_syns = [self.sub_special_char(s) for s in tk_syns]
if len(keywords) < 32:
keywords.extend([s for s in tk_syns if s])
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
if len(keywords) >= 32:
break
tk = self.sub_special_char(tk)
if tk.find(" ") > 0:
tk = '"%s"' % tk
if tk_syns:
tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns)
if sm:
tk = f'{tk} OR "%s" OR ("%s"~2)^0.5' % (" ".join(sm), " ".join(sm))
if tk.strip():
tms.append((tk, w))
tms = " ".join([f"({t})^{w}" for t, w in tms])
if len(twts) > 1:
tms += ' ("%s"~2)^1.5' % rag_tokenizer.tokenize(tt)
syns = " OR ".join(
[
'"%s"'
% rag_tokenizer.tokenize(self.sub_special_char(s))
for s in syns
]
)
if syns and tms:
tms = f"({tms})^5 OR ({syns})^0.7"
qs.append(tms)
if qs:
query = " OR ".join([f"({t})" for t in qs if t])
if not query:
query = otxt
return MatchTextExpr(
self.query_fields, query, 100, {"minimum_should_match": min_match, "original_query": original_query}
), keywords
return None, keywords

0
memory/utils/__init__.py Normal file
View File

494
memory/utils/es_conn.py Normal file
View File

@ -0,0 +1,494 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import json
import time
import copy
from elasticsearch import NotFoundError
from elasticsearch_dsl import UpdateByQuery, Q, Search
from elastic_transport import ConnectionTimeout
from common.decorator import singleton
from common.doc_store.doc_store_base import MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, FusionExpr
from common.doc_store.es_conn_base import ESConnectionBase
from common.float_utils import get_float
from common.constants import PAGERANK_FLD, TAG_FLD
ATTEMPT_TIME = 2
@singleton
class ESConnection(ESConnectionBase):
@staticmethod
def convert_field_name(field_name: str) -> str:
match field_name:
case "message_type":
return "message_type_kwd"
case "status":
return "status_int"
case "content":
return "content_ltks"
case _:
return field_name
@staticmethod
def map_message_to_es_fields(message: dict) -> dict:
"""
Map message dictionary fields to Elasticsearch document/Infinity fields.
:param message: A dictionary containing message details.
:return: A dictionary formatted for Elasticsearch/Infinity indexing.
"""
storage_doc = {
"id": message.get("id"),
"message_id": message["message_id"],
"message_type_kwd": message["message_type"],
"source_id": message["source_id"],
"memory_id": message["memory_id"],
"user_id": message["user_id"],
"agent_id": message["agent_id"],
"session_id": message["session_id"],
"valid_at": message["valid_at"],
"invalid_at": message["invalid_at"],
"forget_at": message["forget_at"],
"status_int": 1 if message["status"] else 0,
"zone_id": message.get("zone_id", 0),
"content_ltks": message["content"],
f"q_{len(message['content_embed'])}_vec": message["content_embed"],
}
return storage_doc
@staticmethod
def get_message_from_es_doc(doc: dict) -> dict:
"""
Convert an Elasticsearch/Infinity document back to a message dictionary.
:param doc: A dictionary representing the Elasticsearch/Infinity document.
:return: A dictionary formatted as a message.
"""
embd_field_name = next((key for key in doc.keys() if re.match(r"q_\d+_vec", key)), None)
message = {
"message_id": doc["message_id"],
"message_type": doc["message_type_kwd"],
"source_id": doc["source_id"] if doc["source_id"] else None,
"memory_id": doc["memory_id"],
"user_id": doc.get("user_id", ""),
"agent_id": doc["agent_id"],
"session_id": doc["session_id"],
"zone_id": doc.get("zone_id", 0),
"valid_at": doc["valid_at"],
"invalid_at": doc.get("invalid_at", "-"),
"forget_at": doc.get("forget_at", "-"),
"status": bool(int(doc["status_int"])),
"content": doc.get("content_ltks", ""),
"content_embed": doc.get(embd_field_name, []) if embd_field_name else [],
}
if doc.get("id"):
message["id"] = doc["id"]
return message
"""
CRUD operations
"""
def search(
self, select_fields: list[str],
highlight_fields: list[str],
condition: dict,
match_expressions: list[MatchExpr],
order_by: OrderByExpr,
offset: int,
limit: int,
index_names: str | list[str],
memory_ids: list[str],
agg_fields: list[str] | None = None,
rank_feature: dict | None = None,
hide_forgotten: bool = True
):
"""
Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html
"""
if isinstance(index_names, str):
index_names = index_names.split(",")
assert isinstance(index_names, list) and len(index_names) > 0
assert "_id" not in condition
bool_query = Q("bool", must=[], must_not=[])
if hide_forgotten:
# filter not forget
bool_query.must_not.append(Q("exists", field="forget_at"))
condition["memory_id"] = memory_ids
for k, v in condition.items():
if k == "session_id" and v:
bool_query.filter.append(Q("query_string", **{"query": f"*{v}*", "fields": ["session_id"], "analyze_wildcard": True}))
continue
if not v:
continue
if isinstance(v, list):
bool_query.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
bool_query.filter.append(Q("term", **{k: v}))
else:
raise Exception(
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
s = Search()
vector_similarity_weight = 0.5
for m in match_expressions:
if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance(match_expressions[1],
MatchDenseExpr) and isinstance(
match_expressions[2], FusionExpr)
weights = m.fusion_params["weights"]
vector_similarity_weight = get_float(weights.split(",")[1])
for m in match_expressions:
if isinstance(m, MatchTextExpr):
minimum_should_match = m.extra_options.get("minimum_should_match", 0.0)
if isinstance(minimum_should_match, float):
minimum_should_match = str(int(minimum_should_match * 100)) + "%"
bool_query.must.append(Q("query_string", fields=[self.convert_field_name(f) for f in m.fields],
type="best_fields", query=m.matching_text,
minimum_should_match=minimum_should_match,
boost=1))
bool_query.boost = 1.0 - vector_similarity_weight
elif isinstance(m, MatchDenseExpr):
assert (bool_query is not None)
similarity = 0.0
if "similarity" in m.extra_options:
similarity = m.extra_options["similarity"]
s = s.knn(self.convert_field_name(m.vector_column_name),
m.topn,
m.topn * 2,
query_vector=list(m.embedding_data),
filter=bool_query.to_dict(),
similarity=similarity,
)
if bool_query and rank_feature:
for fld, sc in rank_feature.items():
if fld != PAGERANK_FLD:
fld = f"{TAG_FLD}.{fld}"
bool_query.should.append(Q("rank_feature", field=fld, linear={}, boost=sc))
if bool_query:
s = s.query(bool_query)
for field in highlight_fields:
s = s.highlight(field)
if order_by:
orders = list()
for field, order in order_by.fields:
order = "asc" if order == 0 else "desc"
if field.endswith("_int") or field.endswith("_flt"):
order_info = {"order": order, "unmapped_type": "float"}
else:
order_info = {"order": order, "unmapped_type": "text"}
orders.append({field: order_info})
s = s.sort(*orders)
if agg_fields:
for fld in agg_fields:
s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)
if limit > 0:
s = s[offset:offset + limit]
q = s.to_dict()
self.logger.debug(f"ESConnection.search {str(index_names)} query: " + json.dumps(q))
for i in range(ATTEMPT_TIME):
try:
#print(json.dumps(q, ensure_ascii=False))
res = self.es.search(index=index_names,
body=q,
timeout="600s",
# search_type="dfs_query_then_fetch",
track_total_hits=True,
_source=True)
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
self.logger.debug(f"ESConnection.search {str(index_names)} res: " + str(res))
return res
except ConnectionTimeout:
self.logger.exception("ES request timeout")
self._connect()
continue
except Exception as e:
self.logger.exception(f"ESConnection.search {str(index_names)} query: " + str(q) + str(e))
raise e
self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
raise Exception("ESConnection.search timeout.")
def get_forgotten_messages(self, select_fields: list[str], index_name: str, memory_id: str, limit: int=2000):
bool_query = Q("bool", must_not=[])
bool_query.must_not.append(Q("term", forget_at=None))
bool_query.filter.append(Q("term", memory_id=memory_id))
# from old to new
order_by = OrderByExpr()
order_by.asc("forget_at")
# build search
s = Search()
s = s.query(bool_query)
s = s.sort(order_by)
s = s[:limit]
q = s.to_dict()
# search
for i in range(ATTEMPT_TIME):
try:
res = self.es.search(index=index_name, body=q, timeout="600s", track_total_hits=True, _source=True)
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
self.logger.debug(f"ESConnection.search {str(index_name)} res: " + str(res))
return res
except ConnectionTimeout:
self.logger.exception("ES request timeout")
self._connect()
continue
except Exception as e:
self.logger.exception(f"ESConnection.search {str(index_name)} query: " + str(q) + str(e))
raise e
self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
raise Exception("ESConnection.search timeout.")
def get(self, doc_id: str, index_name: str, memory_ids: list[str]) -> dict | None:
for i in range(ATTEMPT_TIME):
try:
res = self.es.get(index=index_name,
id=doc_id, source=True, )
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
message = res["_source"]
message["id"] = doc_id
return self.get_message_from_es_doc(message)
except NotFoundError:
return None
except Exception as e:
self.logger.exception(f"ESConnection.get({doc_id}) got exception")
raise e
self.logger.error(f"ESConnection.get timeout for {ATTEMPT_TIME} times!")
raise Exception("ESConnection.get timeout.")
def insert(self, documents: list[dict], index_name: str, memory_id: str = None) -> list[str]:
# Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html
operations = []
for d in documents:
assert "_id" not in d
assert "id" in d
d_copy_raw = copy.deepcopy(d)
d_copy = self.map_message_to_es_fields(d_copy_raw)
d_copy["memory_id"] = memory_id
meta_id = d_copy.pop("id", "")
operations.append(
{"index": {"_index": index_name, "_id": meta_id}})
operations.append(d_copy)
res = []
for _ in range(ATTEMPT_TIME):
try:
res = []
r = self.es.bulk(index=index_name, operations=operations,
refresh=False, timeout="60s")
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
return res
for item in r["items"]:
for action in ["create", "delete", "index", "update"]:
if action in item and "error" in item[action]:
res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"]))
return res
except ConnectionTimeout:
self.logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
res.append(str(e))
self.logger.warning("ESConnection.insert got exception: " + str(e))
return res
def update(self, condition: dict, new_value: dict, index_name: str, memory_id: str) -> bool:
doc = copy.deepcopy(new_value)
update_dict = {self.convert_field_name(k): v for k, v in doc.items()}
update_dict.pop("id", None)
condition_dict = {self.convert_field_name(k): v for k, v in condition.items()}
condition_dict["memory_id"] = memory_id
if "id" in condition_dict and isinstance(condition_dict["id"], str):
# update specific single document
message_id = condition_dict["id"]
for i in range(ATTEMPT_TIME):
for k in update_dict.keys():
if "feas" != k.split("_")[-1]:
continue
try:
self.es.update(index=index_name, id=message_id, script=f"ctx._source.remove(\"{k}\");")
except Exception:
self.logger.exception(f"ESConnection.update(index={index_name}, id={message_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
try:
self.es.update(index=index_name, id=message_id, doc=update_dict)
return True
except Exception as e:
self.logger.exception(
f"ESConnection.update(index={index_name}, id={message_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: " + str(e))
break
return False
# update unspecific maybe-multiple documents
bool_query = Q("bool")
for k, v in condition_dict.items():
if not isinstance(k, str) or not v:
continue
if k == "exists":
bool_query.filter.append(Q("exists", field=v))
continue
if isinstance(v, list):
bool_query.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
bool_query.filter.append(Q("term", **{k: v}))
else:
raise Exception(
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
scripts = []
params = {}
for k, v in update_dict.items():
if k == "remove":
if isinstance(v, str):
scripts.append(f"ctx._source.remove('{v}');")
if isinstance(v, dict):
for kk, vv in v.items():
scripts.append(f"int i=ctx._source.{kk}.indexOf(params.p_{kk});ctx._source.{kk}.remove(i);")
params[f"p_{kk}"] = vv
continue
if k == "add":
if isinstance(v, dict):
for kk, vv in v.items():
scripts.append(f"ctx._source.{kk}.add(params.pp_{kk});")
params[f"pp_{kk}"] = vv.strip()
continue
if (not isinstance(k, str) or not v) and k != "status_int":
continue
if isinstance(v, str):
v = re.sub(r"(['\n\r]|\\.)", " ", v)
params[f"pp_{k}"] = v
scripts.append(f"ctx._source.{k}=params.pp_{k};")
elif isinstance(v, int) or isinstance(v, float):
scripts.append(f"ctx._source.{k}={v};")
elif isinstance(v, list):
scripts.append(f"ctx._source.{k}=params.pp_{k};")
params[f"pp_{k}"] = json.dumps(v, ensure_ascii=False)
else:
raise Exception(
f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
ubq = UpdateByQuery(
index=index_name).using(
self.es).query(bool_query)
ubq = ubq.script(source="".join(scripts), params=params)
ubq = ubq.params(refresh=True)
ubq = ubq.params(slices=5)
ubq = ubq.params(conflicts="proceed")
for _ in range(ATTEMPT_TIME):
try:
_ = ubq.execute()
return True
except ConnectionTimeout:
self.logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
self.logger.error("ESConnection.update got exception: " + str(e) + "\n".join(scripts))
break
return False
def delete(self, condition: dict, index_name: str, memory_id: str) -> int:
assert "_id" not in condition
condition_dict = {self.convert_field_name(k): v for k, v in condition.items()}
condition_dict["memory_id"] = memory_id
if "id" in condition_dict:
message_ids = condition_dict["id"]
if not isinstance(message_ids, list):
message_ids = [message_ids]
if not message_ids: # when message_ids is empty, delete all
qry = Q("match_all")
else:
qry = Q("ids", values=message_ids)
else:
qry = Q("bool")
for k, v in condition_dict.items():
if k == "exists":
qry.filter.append(Q("exists", field=v))
elif k == "must_not":
if isinstance(v, dict):
for kk, vv in v.items():
if kk == "exists":
qry.must_not.append(Q("exists", field=vv))
elif isinstance(v, list):
qry.must.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
qry.must.append(Q("term", **{k: v}))
else:
raise Exception("Condition value must be int, str or list.")
self.logger.debug("ESConnection.delete query: " + json.dumps(qry.to_dict()))
for _ in range(ATTEMPT_TIME):
try:
res = self.es.delete_by_query(
index=index_name,
body=Search().query(qry).to_dict(),
refresh=True)
return res["deleted"]
except ConnectionTimeout:
self.logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
self.logger.warning("ESConnection.delete got exception: " + str(e))
if re.search(r"(not_found)", str(e), re.IGNORECASE):
return 0
return 0
"""
Helper functions for search result
"""
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
res_fields = {}
if not fields:
return {}
for doc in self._get_source(res):
message = self.get_message_from_es_doc(doc)
m = {}
for n, v in message.items():
if n not in fields:
continue
if isinstance(v, list):
m[n] = v
continue
if n in ["message_id", "source_id", "valid_at", "invalid_at", "forget_at", "status"] and isinstance(v, (int, float, bool)):
m[n] = v
continue
if not isinstance(v, str):
m[n] = str(v)
else:
m[n] = v
if m:
res_fields[doc["id"]] = m
return res_fields

View File

@ -0,0 +1,467 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import json
import copy
from infinity.common import InfinityException, SortType
from infinity.errors import ErrorCode
from common.decorator import singleton
import pandas as pd
from common.constants import PAGERANK_FLD, TAG_FLD
from common.doc_store.doc_store_base import MatchExpr, MatchTextExpr, MatchDenseExpr, FusionExpr, OrderByExpr
from common.doc_store.infinity_conn_base import InfinityConnectionBase
from common.time_utils import date_string_to_timestamp
@singleton
class InfinityConnection(InfinityConnectionBase):
def __init__(self):
super().__init__()
self.mapping_file_name = "message_infinity_mapping.json"
"""
Dataframe and fields convert
"""
@staticmethod
def field_keyword(field_name: str):
# no keywords right now
return False
@staticmethod
def convert_message_field_to_infinity(field_name: str):
match field_name:
case "message_type":
return "message_type_kwd"
case "status":
return "status_int"
case _:
return field_name
@staticmethod
def convert_infinity_field_to_message(field_name: str):
if field_name.startswith("message_type"):
return "message_type"
if field_name.startswith("status"):
return "status"
if re.match(r"q_\d+_vec", field_name):
return "content_embed"
return field_name
def convert_select_fields(self, output_fields: list[str]) -> list[str]:
return list({self.convert_message_field_to_infinity(f) for f in output_fields})
@staticmethod
def convert_matching_field(field_weight_str: str) -> str:
tokens = field_weight_str.split("^")
field = tokens[0]
if field == "content":
field = "content@ft_contentm_rag_fine"
tokens[0] = field
return "^".join(tokens)
@staticmethod
def convert_condition_and_order_field(field_name: str):
match field_name:
case "message_type":
return "message_type_kwd"
case "status":
return "status_int"
case "valid_at":
return "valid_at_flt"
case "invalid_at":
return "invalid_at_flt"
case "forget_at":
return "forget_at_flt"
case _:
return field_name
"""
CRUD operations
"""
def search(
self,
select_fields: list[str],
highlight_fields: list[str],
condition: dict,
match_expressions: list[MatchExpr],
order_by: OrderByExpr,
offset: int,
limit: int,
index_names: str | list[str],
memory_ids: list[str],
agg_fields: list[str] | None = None,
rank_feature: dict | None = None,
hide_forgotten: bool = True,
) -> tuple[pd.DataFrame, int]:
"""
BUG: Infinity returns empty for a highlight field if the query string doesn't use that field.
"""
if isinstance(index_names, str):
index_names = index_names.split(",")
assert isinstance(index_names, list) and len(index_names) > 0
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
df_list = list()
table_list = list()
if hide_forgotten:
condition.update({"must_not": {"exists": "forget_at_flt"}})
output = select_fields.copy()
output = self.convert_select_fields(output)
if agg_fields is None:
agg_fields = []
for essential_field in ["id"] + agg_fields:
if essential_field not in output:
output.append(essential_field)
score_func = ""
score_column = ""
for matchExpr in match_expressions:
if isinstance(matchExpr, MatchTextExpr):
score_func = "score()"
score_column = "SCORE"
break
if not score_func:
for matchExpr in match_expressions:
if isinstance(matchExpr, MatchDenseExpr):
score_func = "similarity()"
score_column = "SIMILARITY"
break
if match_expressions:
if score_func not in output:
output.append(score_func)
if PAGERANK_FLD not in output:
output.append(PAGERANK_FLD)
output = [f for f in output if f != "_score"]
if limit <= 0:
# ElasticSearch default limit is 10000
limit = 10000
# Prepare expressions common to all tables
filter_cond = None
filter_fulltext = ""
if condition:
condition_dict = {self.convert_condition_and_order_field(k): v for k, v in condition.items()}
table_found = False
for indexName in index_names:
for mem_id in memory_ids:
table_name = f"{indexName}_{mem_id}"
try:
filter_cond = self.equivalent_condition_to_str(condition_dict, db_instance.get_table(table_name))
table_found = True
break
except Exception:
pass
if table_found:
break
if not table_found:
self.logger.error(f"No valid tables found for indexNames {index_names} and memoryIds {memory_ids}")
return pd.DataFrame(), 0
for matchExpr in match_expressions:
if isinstance(matchExpr, MatchTextExpr):
if filter_cond and "filter" not in matchExpr.extra_options:
matchExpr.extra_options.update({"filter": filter_cond})
matchExpr.fields = [self.convert_matching_field(field) for field in matchExpr.fields]
fields = ",".join(matchExpr.fields)
filter_fulltext = f"filter_fulltext('{fields}', '{matchExpr.matching_text}')"
if filter_cond:
filter_fulltext = f"({filter_cond}) AND {filter_fulltext}"
minimum_should_match = matchExpr.extra_options.get("minimum_should_match", 0.0)
if isinstance(minimum_should_match, float):
str_minimum_should_match = str(int(minimum_should_match * 100)) + "%"
matchExpr.extra_options["minimum_should_match"] = str_minimum_should_match
# Add rank_feature support
if rank_feature and "rank_features" not in matchExpr.extra_options:
# Convert rank_feature dict to Infinity's rank_features string format
# Format: "field^feature_name^weight,field^feature_name^weight"
rank_features_list = []
for feature_name, weight in rank_feature.items():
# Use TAG_FLD as the field containing rank features
rank_features_list.append(f"{TAG_FLD}^{feature_name}^{weight}")
if rank_features_list:
matchExpr.extra_options["rank_features"] = ",".join(rank_features_list)
for k, v in matchExpr.extra_options.items():
if not isinstance(v, str):
matchExpr.extra_options[k] = str(v)
self.logger.debug(f"INFINITY search MatchTextExpr: {json.dumps(matchExpr.__dict__)}")
elif isinstance(matchExpr, MatchDenseExpr):
if filter_fulltext and "filter" not in matchExpr.extra_options:
matchExpr.extra_options.update({"filter": filter_fulltext})
for k, v in matchExpr.extra_options.items():
if not isinstance(v, str):
matchExpr.extra_options[k] = str(v)
similarity = matchExpr.extra_options.get("similarity")
if similarity:
matchExpr.extra_options["threshold"] = similarity
del matchExpr.extra_options["similarity"]
self.logger.debug(f"INFINITY search MatchDenseExpr: {json.dumps(matchExpr.__dict__)}")
elif isinstance(matchExpr, FusionExpr):
self.logger.debug(f"INFINITY search FusionExpr: {json.dumps(matchExpr.__dict__)}")
order_by_expr_list = list()
if order_by.fields:
for order_field in order_by.fields:
order_field_name = self.convert_condition_and_order_field(order_field[0])
if order_field[1] == 0:
order_by_expr_list.append((order_field_name, SortType.Asc))
else:
order_by_expr_list.append((order_field_name, SortType.Desc))
total_hits_count = 0
# Scatter search tables and gather the results
for indexName in index_names:
for memory_id in memory_ids:
table_name = f"{indexName}_{memory_id}"
try:
table_instance = db_instance.get_table(table_name)
except Exception:
continue
table_list.append(table_name)
builder = table_instance.output(output)
if len(match_expressions) > 0:
for matchExpr in match_expressions:
if isinstance(matchExpr, MatchTextExpr):
fields = ",".join(matchExpr.fields)
builder = builder.match_text(
fields,
matchExpr.matching_text,
matchExpr.topn,
matchExpr.extra_options.copy(),
)
elif isinstance(matchExpr, MatchDenseExpr):
builder = builder.match_dense(
matchExpr.vector_column_name,
matchExpr.embedding_data,
matchExpr.embedding_data_type,
matchExpr.distance_type,
matchExpr.topn,
matchExpr.extra_options.copy(),
)
elif isinstance(matchExpr, FusionExpr):
builder = builder.fusion(matchExpr.method, matchExpr.topn, matchExpr.fusion_params)
else:
if filter_cond and len(filter_cond) > 0:
builder.filter(filter_cond)
if order_by.fields:
builder.sort(order_by_expr_list)
builder.offset(offset).limit(limit)
mem_res, extra_result = builder.option({"total_hits_count": True}).to_df()
if extra_result:
total_hits_count += int(extra_result["total_hits_count"])
self.logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(mem_res)}")
df_list.append(mem_res)
self.connPool.release_conn(inf_conn)
res = self.concat_dataframes(df_list, output)
if match_expressions:
res["_score"] = res[score_column] + res[PAGERANK_FLD]
res = res.sort_values(by="_score", ascending=False).reset_index(drop=True)
res = res.head(limit)
self.logger.debug(f"INFINITY search final result: {str(res)}")
return res, total_hits_count
def get_forgotten_messages(self, select_fields: list[str], index_name: str, memory_id: str, limit: int=2000):
condition = {"memory_id": memory_id, "exists": "forget_at_flt"}
order_by = OrderByExpr()
order_by.asc("forget_at_flt")
# query
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{index_name}_{memory_id}"
table_instance = db_instance.get_table(table_name)
output_fields = [self.convert_message_field_to_infinity(f) for f in select_fields]
builder = table_instance.output(output_fields)
filter_cond = self.equivalent_condition_to_str(condition, db_instance.get_table(table_name))
builder.filter(filter_cond)
order_by_expr_list = list()
if order_by.fields:
for order_field in order_by.fields:
order_field_name = self.convert_condition_and_order_field(order_field[0])
if order_field[1] == 0:
order_by_expr_list.append((order_field_name, SortType.Asc))
else:
order_by_expr_list.append((order_field_name, SortType.Desc))
builder.sort(order_by_expr_list)
builder.offset(0).limit(limit)
mem_res, _ = builder.option({"total_hits_count": True}).to_df()
res = self.concat_dataframes(mem_res, output_fields)
res.head(limit)
self.connPool.release_conn(inf_conn)
return res
def get(self, message_id: str, index_name: str, memory_ids: list[str]) -> dict | None:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
df_list = list()
assert isinstance(memory_ids, list)
table_list = list()
for memoryId in memory_ids:
table_name = f"{index_name}_{memoryId}"
table_list.append(table_name)
try:
table_instance = db_instance.get_table(table_name)
except Exception:
self.logger.warning(f"Table not found: {table_name}, this memory isn't created in Infinity. Maybe it is created in other document engine.")
continue
mem_res, _ = table_instance.output(["*"]).filter(f"id = '{message_id}'").to_df()
self.logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(mem_res)}")
df_list.append(mem_res)
self.connPool.release_conn(inf_conn)
res = self.concat_dataframes(df_list, ["id"])
fields = set(res.columns.tolist())
res_fields = self.get_fields(res, list(fields))
return res_fields.get(message_id, None)
def insert(self, documents: list[dict], index_name: str, memory_id: str = None) -> list[str]:
if not documents:
return []
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{index_name}_{memory_id}"
vector_size = int(len(documents[0]["content_embed"]))
try:
table_instance = db_instance.get_table(table_name)
except InfinityException as e:
# src/common/status.cppm, kTableNotExist = 3022
if e.error_code != ErrorCode.TABLE_NOT_EXIST:
raise
if vector_size == 0:
raise ValueError("Cannot infer vector size from documents")
self.create_idx(index_name, memory_id, vector_size)
table_instance = db_instance.get_table(table_name)
# embedding fields can't have a default value....
embedding_columns = []
table_columns = table_instance.show_columns().rows()
for n, ty, _, _ in table_columns:
r = re.search(r"Embedding\([a-z]+,([0-9]+)\)", ty)
if not r:
continue
embedding_columns.append((n, int(r.group(1))))
docs = copy.deepcopy(documents)
for d in docs:
assert "_id" not in d
assert "id" in d
for k, v in list(d.items()):
field_name = self.convert_message_field_to_infinity(k)
if field_name in ["valid_at", "invalid_at", "forget_at"]:
d[f"{field_name}_flt"] = date_string_to_timestamp(v) if v else 0
if v is None:
d[field_name] = ""
elif self.field_keyword(k):
if isinstance(v, list):
d[k] = "###".join(v)
else:
d[k] = v
elif k == "memory_id":
if isinstance(d[k], list):
d[k] = d[k][0] # since d[k] is a list, but we need a str
elif field_name == "content_embed":
d[f"q_{vector_size}_vec"] = d["content_embed"]
d.pop("content_embed")
else:
d[field_name] = v
if k != field_name:
d.pop(k)
for n, vs in embedding_columns:
if n in d:
continue
d[n] = [0] * vs
ids = ["'{}'".format(d["id"]) for d in docs]
str_ids = ", ".join(ids)
str_filter = f"id IN ({str_ids})"
table_instance.delete(str_filter)
table_instance.insert(docs)
self.connPool.release_conn(inf_conn)
self.logger.debug(f"INFINITY inserted into {table_name} {str_ids}.")
return []
def update(self, condition: dict, new_value: dict, index_name: str, memory_id: str) -> bool:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{index_name}_{memory_id}"
table_instance = db_instance.get_table(table_name)
columns = {}
if table_instance:
for n, ty, de, _ in table_instance.show_columns().rows():
columns[n] = (ty, de)
condition_dict = {self.convert_condition_and_order_field(k): v for k, v in condition.items()}
filter = self.equivalent_condition_to_str(condition_dict, table_instance)
update_dict = {self.convert_message_field_to_infinity(k): v for k, v in new_value.items()}
date_floats = {}
for k, v in update_dict.items():
if k in ["valid_at", "invalid_at", "forget_at"]:
date_floats[f"{k}_flt"] = date_string_to_timestamp(v) if v else 0
elif self.field_keyword(k):
if isinstance(v, list):
update_dict[k] = "###".join(v)
else:
update_dict[k] = v
elif k == "memory_id":
if isinstance(update_dict[k], list):
update_dict[k] = update_dict[k][0] # since d[k] is a list, but we need a str
else:
update_dict[k] = v
if date_floats:
update_dict.update(date_floats)
self.logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {new_value}.")
table_instance.update(filter, update_dict)
self.connPool.release_conn(inf_conn)
return True
"""
Helper functions for search result
"""
def get_fields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
if isinstance(res, tuple):
res = res[0]
if not fields:
return {}
fields_all = fields.copy()
fields_all.append("id")
fields_all = {self.convert_message_field_to_infinity(f) for f in fields_all}
column_map = {col.lower(): col for col in res.columns}
matched_columns = {column_map[col.lower()]: col for col in fields_all if col.lower() in column_map}
none_columns = [col for col in fields_all if col.lower() not in column_map]
res2 = res[matched_columns.keys()]
res2 = res2.rename(columns=matched_columns)
res2.drop_duplicates(subset=["id"], inplace=True)
for column in list(res2.columns):
k = column.lower()
if self.field_keyword(k):
res2[column] = res2[column].apply(lambda v: [kwd for kwd in v.split("###") if kwd])
else:
pass
for column in ["content"]:
if column in res2:
del res2[column]
for column in none_columns:
res2[column] = None
res_dict = res2.set_index("id").to_dict(orient="index")
return {_id: {self.convert_infinity_field_to_message(k): v for k, v in doc.items()} for _id, doc in res_dict.items()}

37
memory/utils/msg_util.py Normal file
View File

@ -0,0 +1,37 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
def get_json_result_from_llm_response(response_str: str) -> dict:
"""
Parse the LLM response string to extract JSON content.
The function looks for the first and last curly braces to identify the JSON part.
If parsing fails, it returns an empty dictionary.
:param response_str: The response string from the LLM.
:return: A dictionary parsed from the JSON content in the response.
"""
try:
clean_str = response_str.strip()
if clean_str.startswith('```json'):
clean_str = clean_str[7:] # Remove the starting ```json
if clean_str.endswith('```'):
clean_str = clean_str[:-3] # Remove the ending ```
return json.loads(clean_str.strip())
except (ValueError, json.JSONDecodeError):
return {}

201
memory/utils/prompt_util.py Normal file
View File

@ -0,0 +1,201 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, List
from common.constants import MemoryType
from common.time_utils import current_timestamp
class PromptAssembler:
SYSTEM_BASE_TEMPLATE = """**Memory Extraction Specialist**
You are an expert at analyzing conversations to extract structured memory.
{type_specific_instructions}
**OUTPUT REQUIREMENTS:**
1. Output MUST be valid JSON
2. Follow the specified output format exactly
3. Each extracted item MUST have: content, valid_at, invalid_at
4. Timestamps in {timestamp_format} format
5. Only extract memory types specified above
6. Maximum {max_items} items per type
"""
TYPE_INSTRUCTIONS = {
MemoryType.SEMANTIC.name.lower(): """
**EXTRACT SEMANTIC KNOWLEDGE:**
- Universal facts, definitions, concepts, relationships
- Time-invariant, generally true information
- Examples: "The capital of France is Paris", "Water boils at 100°C"
**Timestamp Rules for Semantic Knowledge:**
- valid_at: When the fact became true (e.g., law enactment, discovery)
- invalid_at: When it becomes false (e.g., repeal, disproven) or empty if still true
- Default: valid_at = conversation time, invalid_at = "" for timeless facts
""",
MemoryType.EPISODIC.name.lower(): """
**EXTRACT EPISODIC KNOWLEDGE:**
- Specific experiences, events, personal stories
- Time-bound, person-specific, contextual
- Examples: "Yesterday I fixed the bug", "User reported issue last week"
**Timestamp Rules for Episodic Knowledge:**
- valid_at: Event start/occurrence time
- invalid_at: Event end time or empty if instantaneous
- Extract explicit times: "at 3 PM", "last Monday", "from X to Y"
""",
MemoryType.PROCEDURAL.name.lower(): """
**EXTRACT PROCEDURAL KNOWLEDGE:**
- Processes, methods, step-by-step instructions
- Goal-oriented, actionable, often includes conditions
- Examples: "To reset password, click...", "Debugging steps: 1)..."
**Timestamp Rules for Procedural Knowledge:**
- valid_at: When procedure becomes valid/effective
- invalid_at: When it expires/becomes obsolete or empty if current
- For version-specific: use release dates
- For best practices: invalid_at = ""
"""
}
OUTPUT_TEMPLATES = {
MemoryType.SEMANTIC.name.lower(): """
"semantic": [
{
"content": "Clear factual statement",
"valid_at": "timestamp or empty",
"invalid_at": "timestamp or empty"
}
]
""",
MemoryType.EPISODIC.name.lower(): """
"episodic": [
{
"content": "Narrative event description",
"valid_at": "event start timestamp",
"invalid_at": "event end timestamp or empty"
}
]
""",
MemoryType.PROCEDURAL.name.lower(): """
"procedural": [
{
"content": "Actionable instructions",
"valid_at": "procedure effective timestamp",
"invalid_at": "procedure expiration timestamp or empty"
}
]
"""
}
BASE_USER_PROMPT = """
**CONVERSATION:**
{conversation}
**CONVERSATION TIME:** {conversation_time}
**CURRENT TIME:** {current_time}
"""
@classmethod
def assemble_system_prompt(cls, config: dict) -> str:
types_to_extract = cls._get_types_to_extract(config["memory_type"])
type_instructions = cls._generate_type_instructions(types_to_extract)
output_format = cls._generate_output_format(types_to_extract)
full_prompt = cls.SYSTEM_BASE_TEMPLATE.format(
type_specific_instructions=type_instructions,
timestamp_format=config.get("timestamp_format", "ISO 8601"),
max_items=config.get("max_items_per_type", 5)
)
full_prompt += f"\n**REQUIRED OUTPUT FORMAT (JSON):**\n```json\n{{\n{output_format}\n}}\n```\n"
examples = cls._generate_examples(types_to_extract)
if examples:
full_prompt += f"\n**EXAMPLES:**\n{examples}\n"
return full_prompt
@staticmethod
def _get_types_to_extract(requested_types: List[str]) -> List[str]:
types = set()
for rt in requested_types:
if rt in [e.name.lower() for e in MemoryType] and rt != MemoryType.RAW.name.lower():
types.add(rt)
return list(types)
@classmethod
def _generate_type_instructions(cls, types_to_extract: List[str]) -> str:
target_types = set(types_to_extract)
instructions = [cls.TYPE_INSTRUCTIONS[mt] for mt in target_types]
return "\n".join(instructions)
@classmethod
def _generate_output_format(cls, types_to_extract: List[str]) -> str:
target_types = set(types_to_extract)
output_parts = [cls.OUTPUT_TEMPLATES[mt] for mt in target_types]
return ",\n".join(output_parts)
@staticmethod
def _generate_examples(types_to_extract: list[str]) -> str:
examples = []
if MemoryType.SEMANTIC.name.lower() in types_to_extract:
examples.append("""
**Semantic Example:**
Input: "Python lists are mutable and support various operations."
Output: {"semantic": [{"content": "Python lists are mutable data structures", "valid_at": "2024-01-15T10:00:00", "invalid_at": ""}]}
""")
if MemoryType.EPISODIC.name.lower() in types_to_extract:
examples.append("""
**Episodic Example:**
Input: "I deployed the new feature yesterday afternoon."
Output: {"episodic": [{"content": "User deployed new feature", "valid_at": "2024-01-14T14:00:00", "invalid_at": "2024-01-14T18:00:00"}]}
""")
if MemoryType.PROCEDURAL.name.lower() in types_to_extract:
examples.append("""
**Procedural Example:**
Input: "To debug API errors: 1) Check logs 2) Verify endpoints 3) Test connectivity."
Output: {"procedural": [{"content": "API error debugging: 1. Check logs 2. Verify endpoints 3. Test connectivity", "valid_at": "2024-01-15T10:00:00", "invalid_at": ""}]}
""")
return "\n".join(examples)
@classmethod
def assemble_user_prompt(
cls,
conversation: str,
conversation_time: Optional[str] = None,
current_time: Optional[str] = None
) -> str:
return cls.BASE_USER_PROMPT.format(
conversation=conversation,
conversation_time=conversation_time or "Not specified",
current_time=current_time or current_timestamp(),
)
@classmethod
def get_raw_user_prompt(cls):
return cls.BASE_USER_PROMPT

View File

@ -46,7 +46,7 @@ dependencies = [
"groq==0.9.0",
"grpcio-status==1.67.1",
"html-text==0.6.2",
"infinity-sdk==0.6.11",
"infinity-sdk==0.6.13",
"infinity-emb>=0.0.66,<0.0.67",
"jira==3.10.5",
"json-repair==0.35.0",

View File

@ -77,9 +77,9 @@ class Benchmark:
def init_index(self, vector_size: int):
if self.initialized_index:
return
if settings.docStoreConn.indexExist(self.index_name, self.kb_id):
settings.docStoreConn.deleteIdx(self.index_name, self.kb_id)
settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
if settings.docStoreConn.index_exist(self.index_name, self.kb_id):
settings.docStoreConn.delete_idx(self.index_name, self.kb_id)
settings.docStoreConn.create_idx(self.index_name, self.kb_id, vector_size)
self.initialized_index = True
def ms_marco_index(self, file_path, index_name):

View File

@ -37,7 +37,6 @@ from rag.app.naive import Docx
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.parser.schema import ParserFromUpstream
from rag.llm.cv_model import Base as VLM
from rag.nlp import attach_media_context
from rag.utils.base64_image import image2id
@ -86,8 +85,6 @@ class ParserParam(ProcessParamBase):
"pdf",
],
"output_format": "json",
"table_context_size": 0,
"image_context_size": 0,
},
"spreadsheet": {
"parse_method": "deepdoc", # deepdoc/tcadp_parser
@ -97,8 +94,6 @@ class ParserParam(ProcessParamBase):
"xlsx",
"csv",
],
"table_context_size": 0,
"image_context_size": 0,
},
"word": {
"suffix": [
@ -106,14 +101,10 @@ class ParserParam(ProcessParamBase):
"docx",
],
"output_format": "json",
"table_context_size": 0,
"image_context_size": 0,
},
"text&markdown": {
"suffix": ["md", "markdown", "mdx", "txt"],
"output_format": "json",
"table_context_size": 0,
"image_context_size": 0,
},
"slides": {
"parse_method": "deepdoc", # deepdoc/tcadp_parser
@ -122,8 +113,6 @@ class ParserParam(ProcessParamBase):
"ppt",
],
"output_format": "json",
"table_context_size": 0,
"image_context_size": 0,
},
"image": {
"parse_method": "ocr",
@ -357,11 +346,6 @@ class Parser(ProcessBase):
elif layout == "table":
b["doc_type_kwd"] = "table"
table_ctx = conf.get("table_context_size", 0) or 0
image_ctx = conf.get("image_context_size", 0) or 0
if table_ctx or image_ctx:
bboxes = attach_media_context(bboxes, table_ctx, image_ctx)
if conf.get("output_format") == "json":
self.set_output("json", bboxes)
if conf.get("output_format") == "markdown":
@ -436,11 +420,6 @@ class Parser(ProcessBase):
if table:
result.append({"text": table, "doc_type_kwd": "table"})
table_ctx = conf.get("table_context_size", 0) or 0
image_ctx = conf.get("image_context_size", 0) or 0
if table_ctx or image_ctx:
result = attach_media_context(result, table_ctx, image_ctx)
self.set_output("json", result)
elif output_format == "markdown":
@ -476,11 +455,6 @@ class Parser(ProcessBase):
sections = [{"text": section[0], "image": section[1]} for section in sections if section]
sections.extend([{"text": tb, "image": None, "doc_type_kwd": "table"} for ((_, tb), _) in tbls])
table_ctx = conf.get("table_context_size", 0) or 0
image_ctx = conf.get("image_context_size", 0) or 0
if table_ctx or image_ctx:
sections = attach_media_context(sections, table_ctx, image_ctx)
self.set_output("json", sections)
elif conf.get("output_format") == "markdown":
markdown_text = docx_parser.to_markdown(name, binary=blob)
@ -536,11 +510,6 @@ class Parser(ProcessBase):
if table:
result.append({"text": table, "doc_type_kwd": "table"})
table_ctx = conf.get("table_context_size", 0) or 0
image_ctx = conf.get("image_context_size", 0) or 0
if table_ctx or image_ctx:
result = attach_media_context(result, table_ctx, image_ctx)
self.set_output("json", result)
else:
# Default DeepDOC parser (supports .pptx format)
@ -554,10 +523,6 @@ class Parser(ProcessBase):
# json
assert conf.get("output_format") == "json", "have to be json for ppt"
if conf.get("output_format") == "json":
table_ctx = conf.get("table_context_size", 0) or 0
image_ctx = conf.get("image_context_size", 0) or 0
if table_ctx or image_ctx:
sections = attach_media_context(sections, table_ctx, image_ctx)
self.set_output("json", sections)
def _markdown(self, name, blob):
@ -597,11 +562,6 @@ class Parser(ProcessBase):
json_results.append(json_result)
table_ctx = conf.get("table_context_size", 0) or 0
image_ctx = conf.get("image_context_size", 0) or 0
if table_ctx or image_ctx:
json_results = attach_media_context(json_results, table_ctx, image_ctx)
self.set_output("json", json_results)
else:
self.set_output("text", "\n".join([section_text for section_text, _ in sections]))

View File

@ -23,7 +23,7 @@ from rag.utils.base64_image import id2image, image2id
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.splitter.schema import SplitterFromUpstream
from rag.nlp import naive_merge, naive_merge_with_images
from rag.nlp import attach_media_context, naive_merge, naive_merge_with_images
from common import settings
@ -34,11 +34,15 @@ class SplitterParam(ProcessParamBase):
self.delimiters = ["\n"]
self.overlapped_percent = 0
self.children_delimiters = []
self.table_context_size = 0
self.image_context_size = 0
def check(self):
self.check_empty(self.delimiters, "Delimiters.")
self.check_positive_integer(self.chunk_token_size, "Chunk token size.")
self.check_decimal_float(self.overlapped_percent, "Overlapped percentage: [0, 1)")
self.check_nonnegative_number(self.table_context_size, "Table context size.")
self.check_nonnegative_number(self.image_context_size, "Image context size.")
def get_input_form(self) -> dict[str, dict]:
return {}
@ -103,8 +107,18 @@ class Splitter(ProcessBase):
return
# json
json_result = from_upstream.json_result or []
if self._param.table_context_size or self._param.image_context_size:
for ck in json_result:
if "image" not in ck and ck.get("img_id") and not (isinstance(ck.get("text"), str) and ck.get("text").strip()):
ck["image"] = True
attach_media_context(json_result, self._param.table_context_size, self._param.image_context_size)
for ck in json_result:
if ck.get("image") is True:
del ck["image"]
sections, section_images = [], []
for o in from_upstream.json_result or []:
for o in json_result:
sections.append((o.get("text", ""), o.get("position_tag", "")))
section_images.append(id2image(o.get("img_id"), partial(settings.STORAGE_IMPL.get, tenant_id=self._canvas._tenant_id)))

View File

@ -105,6 +105,9 @@ class Tokenizer(ProcessBase):
async def _invoke(self, **kwargs):
try:
chunks = kwargs.get("chunks")
kwargs["chunks"] = [c for c in chunks if c is not None]
from_upstream = TokenizerFromUpstream.model_validate(kwargs)
except Exception as e:
self.set_output("_ERROR", f"Input error: {str(e)}")

View File

@ -348,7 +348,8 @@ def tokenize_table(tbls, doc, eng, batch_size=10):
d["doc_type_kwd"] = "table"
if img:
d["image"] = img
d["doc_type_kwd"] = "image"
if d["content_with_weight"].find("<tr>") < 0:
d["doc_type_kwd"] = "image"
if poss:
add_positions(d, poss)
res.append(d)
@ -361,7 +362,8 @@ def tokenize_table(tbls, doc, eng, batch_size=10):
d["doc_type_kwd"] = "table"
if img:
d["image"] = img
d["doc_type_kwd"] = "image"
if d["content_with_weight"].find("<tr>") < 0:
d["doc_type_kwd"] = "image"
add_positions(d, poss)
res.append(d)
return res

View File

@ -19,11 +19,12 @@ import json
import re
from collections import defaultdict
from rag.utils.doc_store_conn import MatchTextExpr
from common.query_base import QueryBase
from common.doc_store.doc_store_base import MatchTextExpr
from rag.nlp import rag_tokenizer, term_weight, synonym
class FulltextQueryer:
class FulltextQueryer(QueryBase):
def __init__(self):
self.tw = term_weight.Dealer()
self.syn = synonym.Dealer()
@ -37,64 +38,19 @@ class FulltextQueryer:
"content_sm_ltks",
]
@staticmethod
def sub_special_char(line):
return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|\+~\^])", r"\\\1", line).strip()
@staticmethod
def is_chinese(line):
arr = re.split(r"[ \t]+", line)
if len(arr) <= 3:
return True
e = 0
for t in arr:
if not re.match(r"[a-zA-Z]+$", t):
e += 1
return e * 1.0 / len(arr) >= 0.7
@staticmethod
def rmWWW(txt):
patts = [
(
r"是*(怎么办|什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀|谁|哪位|哪个)是*",
"",
),
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
(
r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ",
" ")
]
otxt = txt
for r, p in patts:
txt = re.sub(r, p, txt, flags=re.IGNORECASE)
if not txt:
txt = otxt
return txt
@staticmethod
def add_space_between_eng_zh(txt):
# (ENG/ENG+NUM) + ZH
txt = re.sub(r'([A-Za-z]+[0-9]+)([\u4e00-\u9fa5]+)', r'\1 \2', txt)
# ENG + ZH
txt = re.sub(r'([A-Za-z])([\u4e00-\u9fa5]+)', r'\1 \2', txt)
# ZH + (ENG/ENG+NUM)
txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z]+[0-9]+)', r'\1 \2', txt)
txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z])', r'\1 \2', txt)
return txt
def question(self, txt, tbl="qa", min_match: float = 0.6):
original_query = txt
txt = FulltextQueryer.add_space_between_eng_zh(txt)
txt = self.add_space_between_eng_zh(txt)
txt = re.sub(
r"[ :|\r\n\t,,。??/`!&^%%()\[\]{}<>]+",
" ",
rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(txt.lower())),
).strip()
otxt = txt
txt = FulltextQueryer.rmWWW(txt)
txt = self.rmWWW(txt)
if not self.is_chinese(txt):
txt = FulltextQueryer.rmWWW(txt)
txt = self.rmWWW(txt)
tks = rag_tokenizer.tokenize(txt).split()
keywords = [t for t in tks if t]
tks_w = self.tw.weights(tks, preprocess=False)
@ -138,7 +94,7 @@ class FulltextQueryer:
return False
return True
txt = FulltextQueryer.rmWWW(txt)
txt = self.rmWWW(txt)
qs, keywords = [], []
for tt in self.tw.split(txt)[:256]: # .split():
if not tt:
@ -164,7 +120,7 @@ class FulltextQueryer:
)
for m in sm
]
sm = [FulltextQueryer.sub_special_char(m) for m in sm if len(m) > 1]
sm = [self.sub_special_char(m) for m in sm if len(m) > 1]
sm = [m for m in sm if len(m) > 1]
if len(keywords) < 32:
@ -172,7 +128,7 @@ class FulltextQueryer:
keywords.extend(sm)
tk_syns = self.syn.lookup(tk)
tk_syns = [FulltextQueryer.sub_special_char(s) for s in tk_syns]
tk_syns = [self.sub_special_char(s) for s in tk_syns]
if len(keywords) < 32:
keywords.extend([s for s in tk_syns if s])
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
@ -181,7 +137,7 @@ class FulltextQueryer:
if len(keywords) >= 32:
break
tk = FulltextQueryer.sub_special_char(tk)
tk = self.sub_special_char(tk)
if tk.find(" ") > 0:
tk = '"%s"' % tk
if tk_syns:
@ -199,7 +155,7 @@ class FulltextQueryer:
syns = " OR ".join(
[
'"%s"'
% rag_tokenizer.tokenize(FulltextQueryer.sub_special_char(s))
% rag_tokenizer.tokenize(self.sub_special_char(s))
for s in syns
]
)
@ -264,10 +220,10 @@ class FulltextQueryer:
keywords = [f'"{k.strip()}"' for k in keywords]
for tk, w in sorted(tks_w, key=lambda x: x[1] * -1)[:keywords_topn]:
tk_syns = self.syn.lookup(tk)
tk_syns = [FulltextQueryer.sub_special_char(s) for s in tk_syns]
tk_syns = [self.sub_special_char(s) for s in tk_syns]
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
tk = FulltextQueryer.sub_special_char(tk)
tk = self.sub_special_char(tk)
if tk.find(" ") > 0:
tk = '"%s"' % tk
if tk_syns:

View File

@ -24,7 +24,7 @@ from dataclasses import dataclass
from rag.prompts.generator import relevant_chunks_with_toc
from rag.nlp import rag_tokenizer, query
import numpy as np
from rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr, FusionExpr, OrderByExpr
from common.doc_store.doc_store_base import MatchDenseExpr, FusionExpr, OrderByExpr, DocStoreConnection
from common.string_utils import remove_redundant_spaces
from common.float_utils import get_float
from common.constants import PAGERANK_FLD, TAG_FLD
@ -155,7 +155,7 @@ class Dealer:
kwds.add(kk)
logging.debug(f"TOTAL: {total}")
ids = self.dataStore.get_chunk_ids(res)
ids = self.dataStore.get_doc_ids(res)
keywords = list(kwds)
highlight = self.dataStore.get_highlight(res, keywords, "content_with_weight")
aggs = self.dataStore.get_aggregation(res, "docnm_kwd")
@ -545,7 +545,7 @@ class Dealer:
return res
def all_tags(self, tenant_id: str, kb_ids: list[str], S=1000):
if not self.dataStore.indexExist(index_name(tenant_id), kb_ids[0]):
if not self.dataStore.index_exist(index_name(tenant_id), kb_ids[0]):
return []
res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"])
return self.dataStore.get_aggregation(res, "tag_kwd")

View File

@ -136,6 +136,19 @@ def kb_prompt(kbinfos, max_tokens, hash_id=False):
return knowledges
def memory_prompt(message_list, max_tokens):
used_token_count = 0
content_list = []
for message in message_list:
current_content_tokens = num_tokens_from_string(message["content"])
if used_token_count + current_content_tokens > max_tokens * 0.97:
logging.warning(f"Not all the retrieval into prompt: {len(content_list)}/{len(message_list)}")
break
content_list.append(message["content"])
used_token_count += current_content_tokens
return content_list
CITATION_PROMPT_TEMPLATE = load_prompt("citation_prompt")
CITATION_PLUS_TEMPLATE = load_prompt("citation_plus")
CONTENT_TAGGING_PROMPT_TEMPLATE = load_prompt("content_tagging_prompt")
@ -326,7 +339,7 @@ def tool_schema(tools_description: list[dict], complete_task=False):
}
for idx, tool in enumerate(tools_description):
name = tool["function"]["name"]
desc[f"{name}_{idx}"] = tool
desc[name] = tool
return "\n\n".join([f"## {i+1}. {fnm}\n{json.dumps(des, ensure_ascii=False, indent=4)}" for i, (fnm, des) in enumerate(desc.items())])

View File

@ -94,7 +94,7 @@ This content will NOT be shown to the user.
## Step 2: Structured Reflection (MANDATORY before `complete_task`)
### Context
- Goal: {{ task_analysis }}
- Goal: Reflect on the current task based on the full conversation context
- Executed tool calls so far (if any): reflect from conversation history
### Task Complexity Assessment

View File

@ -395,9 +395,9 @@ async def build_chunks(task, progress_callback):
await asyncio.gather(*tasks, return_exceptions=True)
raise
metadata = {}
for ck in cks:
metadata = update_metadata_to(metadata, ck["metadata_obj"])
del ck["metadata_obj"]
for doc in docs:
metadata = update_metadata_to(metadata, doc["metadata_obj"])
del doc["metadata_obj"]
if metadata:
e, doc = DocumentService.get_by_id(task["doc_id"])
if e:
@ -506,7 +506,7 @@ def build_TOC(task, docs, progress_callback):
def init_kb(row, vector_size: int):
idxnm = search.index_name(row["tenant_id"])
return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
return settings.docStoreConn.create_idx(idxnm, row.get("kb_id", ""), vector_size)
async def embedding(docs, mdl, parser_config=None, callback=None):

View File

@ -82,7 +82,7 @@ def id2image(image_id:str|None, storage_get_func: partial):
return
bkt, nm = image_id.split("-")
try:
blob = storage_get_func(bucket=bkt, filename=nm)
blob = storage_get_func(bucket=bkt, fnm=nm)
if not blob:
return
return Image.open(BytesIO(blob))

View File

@ -14,194 +14,92 @@
# limitations under the License.
#
import logging
import re
import json
import time
import os
import copy
from elasticsearch import Elasticsearch, NotFoundError
from elasticsearch_dsl import UpdateByQuery, Q, Search, Index
from elasticsearch_dsl import UpdateByQuery, Q, Search
from elastic_transport import ConnectionTimeout
from common.decorator import singleton
from common.file_utils import get_project_base_directory
from common.misc_utils import convert_bytes
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
FusionExpr
from rag.nlp import is_english, rag_tokenizer
from common.doc_store.doc_store_base import MatchTextExpr, OrderByExpr, MatchExpr, MatchDenseExpr, FusionExpr
from common.doc_store.es_conn_base import ESConnectionBase
from common.float_utils import get_float
from common import settings
from common.constants import PAGERANK_FLD, TAG_FLD
ATTEMPT_TIME = 2
logger = logging.getLogger('ragflow.es_conn')
@singleton
class ESConnection(DocStoreConnection):
def __init__(self):
self.info = {}
logger.info(f"Use Elasticsearch {settings.ES['hosts']} as the doc engine.")
for _ in range(ATTEMPT_TIME):
try:
if self._connect():
break
except Exception as e:
logger.warning(f"{str(e)}. Waiting Elasticsearch {settings.ES['hosts']} to be healthy.")
time.sleep(5)
if not self.es.ping():
msg = f"Elasticsearch {settings.ES['hosts']} is unhealthy in 120s."
logger.error(msg)
raise Exception(msg)
v = self.info.get("version", {"number": "8.11.3"})
v = v["number"].split(".")[0]
if int(v) < 8:
msg = f"Elasticsearch version must be greater than or equal to 8, current version: {v}"
logger.error(msg)
raise Exception(msg)
fp_mapping = os.path.join(get_project_base_directory(), "conf", "mapping.json")
if not os.path.exists(fp_mapping):
msg = f"Elasticsearch mapping file not found at {fp_mapping}"
logger.error(msg)
raise Exception(msg)
self.mapping = json.load(open(fp_mapping, "r"))
logger.info(f"Elasticsearch {settings.ES['hosts']} is healthy.")
def _connect(self):
self.es = Elasticsearch(
settings.ES["hosts"].split(","),
basic_auth=(settings.ES["username"], settings.ES[
"password"]) if "username" in settings.ES and "password" in settings.ES else None,
verify_certs= settings.ES.get("verify_certs", False),
timeout=600 )
if self.es:
self.info = self.es.info()
return True
return False
"""
Database operations
"""
def dbType(self) -> str:
return "elasticsearch"
def health(self) -> dict:
health_dict = dict(self.es.cluster.health())
health_dict["type"] = "elasticsearch"
return health_dict
"""
Table operations
"""
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
if self.indexExist(indexName, knowledgebaseId):
return True
try:
from elasticsearch.client import IndicesClient
return IndicesClient(self.es).create(index=indexName,
settings=self.mapping["settings"],
mappings=self.mapping["mappings"])
except Exception:
logger.exception("ESConnection.createIndex error %s" % (indexName))
def deleteIdx(self, indexName: str, knowledgebaseId: str):
if len(knowledgebaseId) > 0:
# The index need to be alive after any kb deletion since all kb under this tenant are in one index.
return
try:
self.es.indices.delete(index=indexName, allow_no_indices=True)
except NotFoundError:
pass
except Exception:
logger.exception("ESConnection.deleteIdx error %s" % (indexName))
def indexExist(self, indexName: str, knowledgebaseId: str = None) -> bool:
s = Index(indexName, self.es)
for i in range(ATTEMPT_TIME):
try:
return s.exists()
except ConnectionTimeout:
logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
logger.exception(e)
break
return False
class ESConnection(ESConnectionBase):
"""
CRUD operations
"""
def search(
self, selectFields: list[str],
highlightFields: list[str],
self, select_fields: list[str],
highlight_fields: list[str],
condition: dict,
matchExprs: list[MatchExpr],
orderBy: OrderByExpr,
match_expressions: list[MatchExpr],
order_by: OrderByExpr,
offset: int,
limit: int,
indexNames: str | list[str],
knowledgebaseIds: list[str],
aggFields: list[str] = [],
index_names: str | list[str],
knowledgebase_ids: list[str],
agg_fields: list[str] | None = None,
rank_feature: dict | None = None
):
"""
Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html
"""
if isinstance(indexNames, str):
indexNames = indexNames.split(",")
assert isinstance(indexNames, list) and len(indexNames) > 0
if isinstance(index_names, str):
index_names = index_names.split(",")
assert isinstance(index_names, list) and len(index_names) > 0
assert "_id" not in condition
bqry = Q("bool", must=[])
condition["kb_id"] = knowledgebaseIds
bool_query = Q("bool", must=[])
condition["kb_id"] = knowledgebase_ids
for k, v in condition.items():
if k == "available_int":
if v == 0:
bqry.filter.append(Q("range", available_int={"lt": 1}))
bool_query.filter.append(Q("range", available_int={"lt": 1}))
else:
bqry.filter.append(
bool_query.filter.append(
Q("bool", must_not=Q("range", available_int={"lt": 1})))
continue
if not v:
continue
if isinstance(v, list):
bqry.filter.append(Q("terms", **{k: v}))
bool_query.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
bqry.filter.append(Q("term", **{k: v}))
bool_query.filter.append(Q("term", **{k: v}))
else:
raise Exception(
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
s = Search()
vector_similarity_weight = 0.5
for m in matchExprs:
for m in match_expressions:
if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
assert len(matchExprs) == 3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance(matchExprs[1],
MatchDenseExpr) and isinstance(
matchExprs[2], FusionExpr)
assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance(match_expressions[1],
MatchDenseExpr) and isinstance(
match_expressions[2], FusionExpr)
weights = m.fusion_params["weights"]
vector_similarity_weight = get_float(weights.split(",")[1])
for m in matchExprs:
for m in match_expressions:
if isinstance(m, MatchTextExpr):
minimum_should_match = m.extra_options.get("minimum_should_match", 0.0)
if isinstance(minimum_should_match, float):
minimum_should_match = str(int(minimum_should_match * 100)) + "%"
bqry.must.append(Q("query_string", fields=m.fields,
bool_query.must.append(Q("query_string", fields=m.fields,
type="best_fields", query=m.matching_text,
minimum_should_match=minimum_should_match,
boost=1))
bqry.boost = 1.0 - vector_similarity_weight
bool_query.boost = 1.0 - vector_similarity_weight
elif isinstance(m, MatchDenseExpr):
assert (bqry is not None)
assert (bool_query is not None)
similarity = 0.0
if "similarity" in m.extra_options:
similarity = m.extra_options["similarity"]
@ -209,24 +107,24 @@ class ESConnection(DocStoreConnection):
m.topn,
m.topn * 2,
query_vector=list(m.embedding_data),
filter=bqry.to_dict(),
filter=bool_query.to_dict(),
similarity=similarity,
)
if bqry and rank_feature:
if bool_query and rank_feature:
for fld, sc in rank_feature.items():
if fld != PAGERANK_FLD:
fld = f"{TAG_FLD}.{fld}"
bqry.should.append(Q("rank_feature", field=fld, linear={}, boost=sc))
bool_query.should.append(Q("rank_feature", field=fld, linear={}, boost=sc))
if bqry:
s = s.query(bqry)
for field in highlightFields:
if bool_query:
s = s.query(bool_query)
for field in highlight_fields:
s = s.highlight(field)
if orderBy:
if order_by:
orders = list()
for field, order in orderBy.fields:
for field, order in order_by.fields:
order = "asc" if order == 0 else "desc"
if field in ["page_num_int", "top_int"]:
order_info = {"order": order, "unmapped_type": "float",
@ -237,19 +135,19 @@ class ESConnection(DocStoreConnection):
order_info = {"order": order, "unmapped_type": "text"}
orders.append({field: order_info})
s = s.sort(*orders)
for fld in aggFields:
s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)
if agg_fields:
for fld in agg_fields:
s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)
if limit > 0:
s = s[offset:offset + limit]
q = s.to_dict()
logger.debug(f"ESConnection.search {str(indexNames)} query: " + json.dumps(q))
self.logger.debug(f"ESConnection.search {str(index_names)} query: " + json.dumps(q))
for i in range(ATTEMPT_TIME):
try:
#print(json.dumps(q, ensure_ascii=False))
res = self.es.search(index=indexNames,
res = self.es.search(index=index_names,
body=q,
timeout="600s",
# search_type="dfs_query_then_fetch",
@ -257,55 +155,37 @@ class ESConnection(DocStoreConnection):
_source=True)
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
logger.debug(f"ESConnection.search {str(indexNames)} res: " + str(res))
self.logger.debug(f"ESConnection.search {str(index_names)} res: " + str(res))
return res
except ConnectionTimeout:
logger.exception("ES request timeout")
self.logger.exception("ES request timeout")
self._connect()
continue
except Exception as e:
logger.exception(f"ESConnection.search {str(indexNames)} query: " + str(q) + str(e))
self.logger.exception(f"ESConnection.search {str(index_names)} query: " + str(q) + str(e))
raise e
logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
raise Exception("ESConnection.search timeout.")
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
for i in range(ATTEMPT_TIME):
try:
res = self.es.get(index=(indexName),
id=chunkId, source=True, )
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
chunk = res["_source"]
chunk["id"] = chunkId
return chunk
except NotFoundError:
return None
except Exception as e:
logger.exception(f"ESConnection.get({chunkId}) got exception")
raise e
logger.error(f"ESConnection.get timeout for {ATTEMPT_TIME} times!")
raise Exception("ESConnection.get timeout.")
def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
def insert(self, documents: list[dict], index_name: str, knowledgebase_id: str = None) -> list[str]:
# Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html
operations = []
for d in documents:
assert "_id" not in d
assert "id" in d
d_copy = copy.deepcopy(d)
d_copy["kb_id"] = knowledgebaseId
d_copy["kb_id"] = knowledgebase_id
meta_id = d_copy.pop("id", "")
operations.append(
{"index": {"_index": indexName, "_id": meta_id}})
{"index": {"_index": index_name, "_id": meta_id}})
operations.append(d_copy)
res = []
for _ in range(ATTEMPT_TIME):
try:
res = []
r = self.es.bulk(index=(indexName), operations=operations,
r = self.es.bulk(index=index_name, operations=operations,
refresh=False, timeout="60s")
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
return res
@ -316,58 +196,58 @@ class ESConnection(DocStoreConnection):
res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"]))
return res
except ConnectionTimeout:
logger.exception("ES request timeout")
self.logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
res.append(str(e))
logger.warning("ESConnection.insert got exception: " + str(e))
self.logger.warning("ESConnection.insert got exception: " + str(e))
return res
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
doc = copy.deepcopy(newValue)
def update(self, condition: dict, new_value: dict, index_name: str, knowledgebase_id: str) -> bool:
doc = copy.deepcopy(new_value)
doc.pop("id", None)
condition["kb_id"] = knowledgebaseId
condition["kb_id"] = knowledgebase_id
if "id" in condition and isinstance(condition["id"], str):
# update specific single document
chunkId = condition["id"]
chunk_id = condition["id"]
for i in range(ATTEMPT_TIME):
for k in doc.keys():
if "feas" != k.split("_")[-1]:
continue
try:
self.es.update(index=indexName, id=chunkId, script=f"ctx._source.remove(\"{k}\");")
self.es.update(index=index_name, id=chunk_id, script=f"ctx._source.remove(\"{k}\");")
except Exception:
logger.exception(f"ESConnection.update(index={indexName}, id={chunkId}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
self.logger.exception(f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
try:
self.es.update(index=indexName, id=chunkId, doc=doc)
self.es.update(index=index_name, id=chunk_id, doc=doc)
return True
except Exception as e:
logger.exception(
f"ESConnection.update(index={indexName}, id={chunkId}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: "+str(e))
self.logger.exception(
f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: " + str(e))
break
return False
# update unspecific maybe-multiple documents
bqry = Q("bool")
bool_query = Q("bool")
for k, v in condition.items():
if not isinstance(k, str) or not v:
continue
if k == "exists":
bqry.filter.append(Q("exists", field=v))
bool_query.filter.append(Q("exists", field=v))
continue
if isinstance(v, list):
bqry.filter.append(Q("terms", **{k: v}))
bool_query.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
bqry.filter.append(Q("term", **{k: v}))
bool_query.filter.append(Q("term", **{k: v}))
else:
raise Exception(
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
scripts = []
params = {}
for k, v in newValue.items():
for k, v in new_value.items():
if k == "remove":
if isinstance(v, str):
scripts.append(f"ctx._source.remove('{v}');")
@ -397,8 +277,8 @@ class ESConnection(DocStoreConnection):
raise Exception(
f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
ubq = UpdateByQuery(
index=indexName).using(
self.es).query(bqry)
index=index_name).using(
self.es).query(bool_query)
ubq = ubq.script(source="".join(scripts), params=params)
ubq = ubq.params(refresh=True)
ubq = ubq.params(slices=5)
@ -409,19 +289,18 @@ class ESConnection(DocStoreConnection):
_ = ubq.execute()
return True
except ConnectionTimeout:
logger.exception("ES request timeout")
self.logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
logger.error("ESConnection.update got exception: " + str(e) + "\n".join(scripts))
self.logger.error("ESConnection.update got exception: " + str(e) + "\n".join(scripts))
break
return False
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
qry = None
def delete(self, condition: dict, index_name: str, knowledgebase_id: str) -> int:
assert "_id" not in condition
condition["kb_id"] = knowledgebaseId
condition["kb_id"] = knowledgebase_id
if "id" in condition:
chunk_ids = condition["id"]
if not isinstance(chunk_ids, list):
@ -448,21 +327,21 @@ class ESConnection(DocStoreConnection):
qry.must.append(Q("term", **{k: v}))
else:
raise Exception("Condition value must be int, str or list.")
logger.debug("ESConnection.delete query: " + json.dumps(qry.to_dict()))
self.logger.debug("ESConnection.delete query: " + json.dumps(qry.to_dict()))
for _ in range(ATTEMPT_TIME):
try:
res = self.es.delete_by_query(
index=indexName,
index=index_name,
body=Search().query(qry).to_dict(),
refresh=True)
return res["deleted"]
except ConnectionTimeout:
logger.exception("ES request timeout")
self.logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
logger.warning("ESConnection.delete got exception: " + str(e))
self.logger.warning("ESConnection.delete got exception: " + str(e))
if re.search(r"(not_found)", str(e), re.IGNORECASE):
return 0
return 0
@ -471,27 +350,11 @@ class ESConnection(DocStoreConnection):
Helper functions for search result
"""
def get_total(self, res):
if isinstance(res["hits"]["total"], type({})):
return res["hits"]["total"]["value"]
return res["hits"]["total"]
def get_chunk_ids(self, res):
return [d["_id"] for d in res["hits"]["hits"]]
def __getSource(self, res):
rr = []
for d in res["hits"]["hits"]:
d["_source"]["id"] = d["_id"]
d["_source"]["_score"] = d["_score"]
rr.append(d["_source"])
return rr
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
res_fields = {}
if not fields:
return {}
for d in self.__getSource(res):
for d in self._get_source(res):
m = {n: d.get(n) for n in fields if d.get(n) is not None}
for n, v in m.items():
if isinstance(v, list):
@ -508,124 +371,3 @@ class ESConnection(DocStoreConnection):
if m:
res_fields[d["id"]] = m
return res_fields
def get_highlight(self, res, keywords: list[str], fieldnm: str):
ans = {}
for d in res["hits"]["hits"]:
hlts = d.get("highlight")
if not hlts:
continue
txt = "...".join([a for a in list(hlts.items())[0][1]])
if not is_english(txt.split()):
ans[d["_id"]] = txt
continue
txt = d["_source"][fieldnm]
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
txts = []
for t in re.split(r"[.?!;\n]", txt):
for w in keywords:
t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w), r"\1<em>\2</em>\3", t,
flags=re.IGNORECASE | re.MULTILINE)
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
continue
txts.append(t)
ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]])
return ans
def get_aggregation(self, res, fieldnm: str):
agg_field = "aggs_" + fieldnm
if "aggregations" not in res or agg_field not in res["aggregations"]:
return list()
bkts = res["aggregations"][agg_field]["buckets"]
return [(b["key"], b["doc_count"]) for b in bkts]
"""
SQL
"""
def sql(self, sql: str, fetch_size: int, format: str):
logger.debug(f"ESConnection.sql get sql: {sql}")
sql = re.sub(r"[ `]+", " ", sql)
sql = sql.replace("%", "")
replaces = []
for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
fld, v = r.group(1), r.group(3)
match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
replaces.append(
("{}{}'{}'".format(
r.group(1),
r.group(2),
r.group(3)),
match))
for p, r in replaces:
sql = sql.replace(p, r, 1)
logger.debug(f"ESConnection.sql to es: {sql}")
for i in range(ATTEMPT_TIME):
try:
res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format,
request_timeout="2s")
return res
except ConnectionTimeout:
logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
logger.exception(f"ESConnection.sql got exception. SQL:\n{sql}")
raise Exception(f"SQL error: {e}\n\nSQL: {sql}")
logger.error(f"ESConnection.sql timeout for {ATTEMPT_TIME} times!")
return None
def get_cluster_stats(self):
"""
curl -XGET "http://{es_host}/_cluster/stats" -H "kbn-xsrf: reporting" to view raw stats.
"""
raw_stats = self.es.cluster.stats()
logger.debug(f"ESConnection.get_cluster_stats: {raw_stats}")
try:
res = {
'cluster_name': raw_stats['cluster_name'],
'status': raw_stats['status']
}
indices_status = raw_stats['indices']
res.update({
'indices': indices_status['count'],
'indices_shards': indices_status['shards']['total']
})
doc_info = indices_status['docs']
res.update({
'docs': doc_info['count'],
'docs_deleted': doc_info['deleted']
})
store_info = indices_status['store']
res.update({
'store_size': convert_bytes(store_info['size_in_bytes']),
'total_dataset_size': convert_bytes(store_info['total_data_set_size_in_bytes'])
})
mappings_info = indices_status['mappings']
res.update({
'mappings_fields': mappings_info['total_field_count'],
'mappings_deduplicated_fields': mappings_info['total_deduplicated_field_count'],
'mappings_deduplicated_size': convert_bytes(mappings_info['total_deduplicated_mapping_size_in_bytes'])
})
node_info = raw_stats['nodes']
res.update({
'nodes': node_info['count']['total'],
'nodes_version': node_info['versions'],
'os_mem': convert_bytes(node_info['os']['mem']['total_in_bytes']),
'os_mem_used': convert_bytes(node_info['os']['mem']['used_in_bytes']),
'os_mem_used_percent': node_info['os']['mem']['used_percent'],
'jvm_versions': node_info['jvm']['versions'][0]['vm_version'],
'jvm_heap_used': convert_bytes(node_info['jvm']['mem']['heap_used_in_bytes']),
'jvm_heap_max': convert_bytes(node_info['jvm']['mem']['heap_max_in_bytes'])
})
return res
except Exception as e:
logger.exception(f"ESConnection.get_cluster_stats: {e}")
return None

View File

@ -14,365 +14,125 @@
# limitations under the License.
#
import logging
import os
import re
import json
import time
import copy
import infinity
from infinity.common import ConflictType, InfinityException, SortType
from infinity.index import IndexInfo, IndexType
from infinity.connection_pool import ConnectionPool
from infinity.common import InfinityException, SortType
from infinity.errors import ErrorCode
from common.decorator import singleton
import pandas as pd
from common.file_utils import get_project_base_directory
from rag.nlp import is_english
from common.constants import PAGERANK_FLD, TAG_FLD
from common import settings
from rag.utils.doc_store_conn import (
DocStoreConnection,
MatchExpr,
MatchTextExpr,
MatchDenseExpr,
FusionExpr,
OrderByExpr,
)
logger = logging.getLogger("ragflow.infinity_conn")
def field_keyword(field_name: str):
# Treat "*_kwd" tag-like columns as keyword lists except knowledge_graph_kwd; source_id is also keyword-like.
if field_name == "source_id" or (field_name.endswith("_kwd") and field_name not in ["knowledge_graph_kwd", "docnm_kwd", "important_kwd", "question_kwd"]):
return True
return False
def convert_select_fields(output_fields: list[str]) -> list[str]:
for i, field in enumerate(output_fields):
if field in ["docnm_kwd", "title_tks", "title_sm_tks"]:
output_fields[i] = "docnm"
elif field in ["important_kwd", "important_tks"]:
output_fields[i] = "important_keywords"
elif field in ["question_kwd", "question_tks"]:
output_fields[i] = "questions"
elif field in ["content_with_weight", "content_ltks", "content_sm_ltks"]:
output_fields[i] = "content"
elif field in ["authors_tks", "authors_sm_tks"]:
output_fields[i] = "authors"
return list(set(output_fields))
def convert_matching_field(field_weightstr: str) -> str:
tokens = field_weightstr.split("^")
field = tokens[0]
if field == "docnm_kwd" or field == "title_tks":
field = "docnm@ft_docnm_rag_coarse"
elif field == "title_sm_tks":
field = "docnm@ft_docnm_rag_fine"
elif field == "important_kwd":
field = "important_keywords@ft_important_keywords_rag_coarse"
elif field == "important_tks":
field = "important_keywords@ft_important_keywords_rag_fine"
elif field == "question_kwd":
field = "questions@ft_questions_rag_coarse"
elif field == "question_tks":
field = "questions@ft_questions_rag_fine"
elif field == "content_with_weight" or field == "content_ltks":
field = "content@ft_content_rag_coarse"
elif field == "content_sm_ltks":
field = "content@ft_content_rag_fine"
elif field == "authors_tks":
field = "authors@ft_authors_rag_coarse"
elif field == "authors_sm_tks":
field = "authors@ft_authors_rag_fine"
tokens[0] = field
return "^".join(tokens)
def list2str(lst: str|list, sep: str = " ") -> str:
if isinstance(lst, str):
return lst
return sep.join(lst)
def equivalent_condition_to_str(condition: dict, table_instance=None) -> str | None:
assert "_id" not in condition
clmns = {}
if table_instance:
for n, ty, de, _ in table_instance.show_columns().rows():
clmns[n] = (ty, de)
def exists(cln):
nonlocal clmns
assert cln in clmns, f"'{cln}' should be in '{clmns}'."
ty, de = clmns[cln]
if ty.lower().find("cha"):
if not de:
de = ""
return f" {cln}!='{de}' "
return f"{cln}!={de}"
cond = list()
for k, v in condition.items():
if not isinstance(k, str) or not v:
continue
if field_keyword(k):
if isinstance(v, list):
inCond = list()
for item in v:
if isinstance(item, str):
item = item.replace("'", "''")
inCond.append(f"filter_fulltext('{convert_matching_field(k)}', '{item}')")
if inCond:
strInCond = " or ".join(inCond)
strInCond = f"({strInCond})"
cond.append(strInCond)
else:
cond.append(f"filter_fulltext('{convert_matching_field(k)}', '{v}')")
elif isinstance(v, list):
inCond = list()
for item in v:
if isinstance(item, str):
item = item.replace("'", "''")
inCond.append(f"'{item}'")
else:
inCond.append(str(item))
if inCond:
strInCond = ", ".join(inCond)
strInCond = f"{k} IN ({strInCond})"
cond.append(strInCond)
elif k == "must_not":
if isinstance(v, dict):
for kk, vv in v.items():
if kk == "exists":
cond.append("NOT (%s)" % exists(vv))
elif isinstance(v, str):
cond.append(f"{k}='{v}'")
elif k == "exists":
cond.append(exists(v))
else:
cond.append(f"{k}={str(v)}")
return " AND ".join(cond) if cond else "1=1"
def concat_dataframes(df_list: list[pd.DataFrame], selectFields: list[str]) -> pd.DataFrame:
df_list2 = [df for df in df_list if not df.empty]
if df_list2:
return pd.concat(df_list2, axis=0).reset_index(drop=True)
schema = []
for field_name in selectFields:
if field_name == "score()": # Workaround: fix schema is changed to score()
schema.append("SCORE")
elif field_name == "similarity()": # Workaround: fix schema is changed to similarity()
schema.append("SIMILARITY")
else:
schema.append(field_name)
return pd.DataFrame(columns=schema)
from common.doc_store.doc_store_base import MatchExpr, MatchTextExpr, MatchDenseExpr, FusionExpr, OrderByExpr
from common.doc_store.infinity_conn_base import InfinityConnectionBase
@singleton
class InfinityConnection(DocStoreConnection):
def __init__(self):
self.dbName = settings.INFINITY.get("db_name", "default_db")
infinity_uri = settings.INFINITY["uri"]
if ":" in infinity_uri:
host, port = infinity_uri.split(":")
infinity_uri = infinity.common.NetworkAddress(host, int(port))
self.connPool = None
logger.info(f"Use Infinity {infinity_uri} as the doc engine.")
for _ in range(24):
try:
connPool = ConnectionPool(infinity_uri, max_size=4)
inf_conn = connPool.get_conn()
res = inf_conn.show_current_node()
if res.error_code == ErrorCode.OK and res.server_status in ["started", "alive"]:
self._migrate_db(inf_conn)
self.connPool = connPool
connPool.release_conn(inf_conn)
break
connPool.release_conn(inf_conn)
logger.warn(f"Infinity status: {res.server_status}. Waiting Infinity {infinity_uri} to be healthy.")
time.sleep(5)
except Exception as e:
logger.warning(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.")
time.sleep(5)
if self.connPool is None:
msg = f"Infinity {infinity_uri} is unhealthy in 120s."
logger.error(msg)
raise Exception(msg)
logger.info(f"Infinity {infinity_uri} is healthy.")
def _migrate_db(self, inf_conn):
inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
fp_mapping = os.path.join(get_project_base_directory(), "conf", "infinity_mapping.json")
if not os.path.exists(fp_mapping):
raise Exception(f"Mapping file not found at {fp_mapping}")
schema = json.load(open(fp_mapping))
table_names = inf_db.list_tables().table_names
for table_name in table_names:
inf_table = inf_db.get_table(table_name)
index_names = inf_table.list_indexes().index_names
if "q_vec_idx" not in index_names:
# Skip tables not created by me
continue
column_names = inf_table.show_columns()["name"]
column_names = set(column_names)
for field_name, field_info in schema.items():
if field_name in column_names:
continue
res = inf_table.add_columns({field_name: field_info})
assert res.error_code == infinity.ErrorCode.OK
logger.info(f"INFINITY added following column to table {table_name}: {field_name} {field_info}")
if field_info["type"] != "varchar" or "analyzer" not in field_info:
continue
analyzers = field_info["analyzer"]
if isinstance(analyzers, str):
analyzers = [analyzers]
for analyzer in analyzers:
inf_table.create_index(
f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}",
IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}),
ConflictType.Ignore,
)
class InfinityConnection(InfinityConnectionBase):
"""
Database operations
Dataframe and fields convert
"""
def dbType(self) -> str:
return "infinity"
def health(self) -> dict:
"""
Return the health status of the database.
"""
inf_conn = self.connPool.get_conn()
res = inf_conn.show_current_node()
self.connPool.release_conn(inf_conn)
res2 = {
"type": "infinity",
"status": "green" if res.error_code == 0 and res.server_status in ["started", "alive"] else "red",
"error": res.error_msg,
}
return res2
"""
Table operations
"""
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
table_name = f"{indexName}_{knowledgebaseId}"
inf_conn = self.connPool.get_conn()
inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
fp_mapping = os.path.join(get_project_base_directory(), "conf", "infinity_mapping.json")
if not os.path.exists(fp_mapping):
raise Exception(f"Mapping file not found at {fp_mapping}")
schema = json.load(open(fp_mapping))
vector_name = f"q_{vectorSize}_vec"
schema[vector_name] = {"type": f"vector,{vectorSize},float"}
inf_table = inf_db.create_table(
table_name,
schema,
ConflictType.Ignore,
)
inf_table.create_index(
"q_vec_idx",
IndexInfo(
vector_name,
IndexType.Hnsw,
{
"M": "16",
"ef_construction": "50",
"metric": "cosine",
"encode": "lvq",
},
),
ConflictType.Ignore,
)
for field_name, field_info in schema.items():
if field_info["type"] != "varchar" or "analyzer" not in field_info:
continue
analyzers = field_info["analyzer"]
if isinstance(analyzers, str):
analyzers = [analyzers]
for analyzer in analyzers:
inf_table.create_index(
f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}",
IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}),
ConflictType.Ignore,
)
self.connPool.release_conn(inf_conn)
logger.info(f"INFINITY created table {table_name}, vector size {vectorSize}")
def deleteIdx(self, indexName: str, knowledgebaseId: str):
table_name = f"{indexName}_{knowledgebaseId}"
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
db_instance.drop_table(table_name, ConflictType.Ignore)
self.connPool.release_conn(inf_conn)
logger.info(f"INFINITY dropped table {table_name}")
def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
table_name = f"{indexName}_{knowledgebaseId}"
try:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
_ = db_instance.get_table(table_name)
self.connPool.release_conn(inf_conn)
@staticmethod
def field_keyword(field_name: str):
# Treat "*_kwd" tag-like columns as keyword lists except knowledge_graph_kwd; source_id is also keyword-like.
if field_name == "source_id" or (
field_name.endswith("_kwd") and field_name not in ["knowledge_graph_kwd", "docnm_kwd", "important_kwd",
"question_kwd"]):
return True
except Exception as e:
logger.warning(f"INFINITY indexExist {str(e)}")
return False
def convert_select_fields(self, output_fields: list[str]) -> list[str]:
for i, field in enumerate(output_fields):
if field in ["docnm_kwd", "title_tks", "title_sm_tks"]:
output_fields[i] = "docnm"
elif field in ["important_kwd", "important_tks"]:
output_fields[i] = "important_keywords"
elif field in ["question_kwd", "question_tks"]:
output_fields[i] = "questions"
elif field in ["content_with_weight", "content_ltks", "content_sm_ltks"]:
output_fields[i] = "content"
elif field in ["authors_tks", "authors_sm_tks"]:
output_fields[i] = "authors"
return list(set(output_fields))
@staticmethod
def convert_matching_field(field_weight_str: str) -> str:
tokens = field_weight_str.split("^")
field = tokens[0]
if field == "docnm_kwd" or field == "title_tks":
field = "docnm@ft_docnm_rag_coarse"
elif field == "title_sm_tks":
field = "docnm@ft_docnm_rag_fine"
elif field == "important_kwd":
field = "important_keywords@ft_important_keywords_rag_coarse"
elif field == "important_tks":
field = "important_keywords@ft_important_keywords_rag_fine"
elif field == "question_kwd":
field = "questions@ft_questions_rag_coarse"
elif field == "question_tks":
field = "questions@ft_questions_rag_fine"
elif field == "content_with_weight" or field == "content_ltks":
field = "content@ft_content_rag_coarse"
elif field == "content_sm_ltks":
field = "content@ft_content_rag_fine"
elif field == "authors_tks":
field = "authors@ft_authors_rag_coarse"
elif field == "authors_sm_tks":
field = "authors@ft_authors_rag_fine"
tokens[0] = field
return "^".join(tokens)
"""
CRUD operations
"""
def search(
self,
selectFields: list[str],
highlightFields: list[str],
select_fields: list[str],
highlight_fields: list[str],
condition: dict,
matchExprs: list[MatchExpr],
orderBy: OrderByExpr,
match_expressions: list[MatchExpr],
order_by: OrderByExpr,
offset: int,
limit: int,
indexNames: str | list[str],
knowledgebaseIds: list[str],
aggFields: list[str] = [],
index_names: str | list[str],
knowledgebase_ids: list[str],
agg_fields: list[str] | None = None,
rank_feature: dict | None = None,
) -> tuple[pd.DataFrame, int]:
"""
BUG: Infinity returns empty for a highlight field if the query string doesn't use that field.
"""
if isinstance(indexNames, str):
indexNames = indexNames.split(",")
assert isinstance(indexNames, list) and len(indexNames) > 0
if isinstance(index_names, str):
index_names = index_names.split(",")
assert isinstance(index_names, list) and len(index_names) > 0
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
df_list = list()
table_list = list()
output = selectFields.copy()
output = convert_select_fields(output)
for essential_field in ["id"] + aggFields:
output = select_fields.copy()
output = self.convert_select_fields(output)
if agg_fields is None:
agg_fields = []
for essential_field in ["id"] + agg_fields:
if essential_field not in output:
output.append(essential_field)
score_func = ""
score_column = ""
for matchExpr in matchExprs:
for matchExpr in match_expressions:
if isinstance(matchExpr, MatchTextExpr):
score_func = "score()"
score_column = "SCORE"
break
if not score_func:
for matchExpr in matchExprs:
for matchExpr in match_expressions:
if isinstance(matchExpr, MatchDenseExpr):
score_func = "similarity()"
score_column = "SIMILARITY"
break
if matchExprs:
if match_expressions:
if score_func not in output:
output.append(score_func)
if PAGERANK_FLD not in output:
@ -387,11 +147,11 @@ class InfinityConnection(DocStoreConnection):
filter_fulltext = ""
if condition:
table_found = False
for indexName in indexNames:
for kb_id in knowledgebaseIds:
for indexName in index_names:
for kb_id in knowledgebase_ids:
table_name = f"{indexName}_{kb_id}"
try:
filter_cond = equivalent_condition_to_str(condition, db_instance.get_table(table_name))
filter_cond = self.equivalent_condition_to_str(condition, db_instance.get_table(table_name))
table_found = True
break
except Exception:
@ -399,14 +159,14 @@ class InfinityConnection(DocStoreConnection):
if table_found:
break
if not table_found:
logger.error(f"No valid tables found for indexNames {indexNames} and knowledgebaseIds {knowledgebaseIds}")
self.logger.error(f"No valid tables found for indexNames {index_names} and knowledgebaseIds {knowledgebase_ids}")
return pd.DataFrame(), 0
for matchExpr in matchExprs:
for matchExpr in match_expressions:
if isinstance(matchExpr, MatchTextExpr):
if filter_cond and "filter" not in matchExpr.extra_options:
matchExpr.extra_options.update({"filter": filter_cond})
matchExpr.fields = [convert_matching_field(field) for field in matchExpr.fields]
matchExpr.fields = [self.convert_matching_field(field) for field in matchExpr.fields]
fields = ",".join(matchExpr.fields)
filter_fulltext = f"filter_fulltext('{fields}', '{matchExpr.matching_text}')"
if filter_cond:
@ -430,7 +190,7 @@ class InfinityConnection(DocStoreConnection):
for k, v in matchExpr.extra_options.items():
if not isinstance(v, str):
matchExpr.extra_options[k] = str(v)
logger.debug(f"INFINITY search MatchTextExpr: {json.dumps(matchExpr.__dict__)}")
self.logger.debug(f"INFINITY search MatchTextExpr: {json.dumps(matchExpr.__dict__)}")
elif isinstance(matchExpr, MatchDenseExpr):
if filter_fulltext and "filter" not in matchExpr.extra_options:
matchExpr.extra_options.update({"filter": filter_fulltext})
@ -441,16 +201,16 @@ class InfinityConnection(DocStoreConnection):
if similarity:
matchExpr.extra_options["threshold"] = similarity
del matchExpr.extra_options["similarity"]
logger.debug(f"INFINITY search MatchDenseExpr: {json.dumps(matchExpr.__dict__)}")
self.logger.debug(f"INFINITY search MatchDenseExpr: {json.dumps(matchExpr.__dict__)}")
elif isinstance(matchExpr, FusionExpr):
if matchExpr.method == "weighted_sum":
# The default is "minmax" which gives a zero score for the last doc.
matchExpr.fusion_params["normalize"] = "atan"
logger.debug(f"INFINITY search FusionExpr: {json.dumps(matchExpr.__dict__)}")
self.logger.debug(f"INFINITY search FusionExpr: {json.dumps(matchExpr.__dict__)}")
order_by_expr_list = list()
if orderBy.fields:
for order_field in orderBy.fields:
if order_by.fields:
for order_field in order_by.fields:
if order_field[1] == 0:
order_by_expr_list.append((order_field[0], SortType.Asc))
else:
@ -458,8 +218,8 @@ class InfinityConnection(DocStoreConnection):
total_hits_count = 0
# Scatter search tables and gather the results
for indexName in indexNames:
for knowledgebaseId in knowledgebaseIds:
for indexName in index_names:
for knowledgebaseId in knowledgebase_ids:
table_name = f"{indexName}_{knowledgebaseId}"
try:
table_instance = db_instance.get_table(table_name)
@ -467,8 +227,8 @@ class InfinityConnection(DocStoreConnection):
continue
table_list.append(table_name)
builder = table_instance.output(output)
if len(matchExprs) > 0:
for matchExpr in matchExprs:
if len(match_expressions) > 0:
for matchExpr in match_expressions:
if isinstance(matchExpr, MatchTextExpr):
fields = ",".join(matchExpr.fields)
builder = builder.match_text(
@ -491,53 +251,52 @@ class InfinityConnection(DocStoreConnection):
else:
if filter_cond and len(filter_cond) > 0:
builder.filter(filter_cond)
if orderBy.fields:
if order_by.fields:
builder.sort(order_by_expr_list)
builder.offset(offset).limit(limit)
kb_res, extra_result = builder.option({"total_hits_count": True}).to_df()
if extra_result:
total_hits_count += int(extra_result["total_hits_count"])
logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(kb_res)}")
self.logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(kb_res)}")
df_list.append(kb_res)
self.connPool.release_conn(inf_conn)
res = concat_dataframes(df_list, output)
if matchExprs:
res = self.concat_dataframes(df_list, output)
if match_expressions:
res["_score"] = res[score_column] + res[PAGERANK_FLD]
res = res.sort_values(by="_score", ascending=False).reset_index(drop=True)
res = res.head(limit)
logger.debug(f"INFINITY search final result: {str(res)}")
self.logger.debug(f"INFINITY search final result: {str(res)}")
return res, total_hits_count
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
def get(self, chunk_id: str, index_name: str, knowledgebase_ids: list[str]) -> dict | None:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
df_list = list()
assert isinstance(knowledgebaseIds, list)
assert isinstance(knowledgebase_ids, list)
table_list = list()
for knowledgebaseId in knowledgebaseIds:
table_name = f"{indexName}_{knowledgebaseId}"
for knowledgebaseId in knowledgebase_ids:
table_name = f"{index_name}_{knowledgebaseId}"
table_list.append(table_name)
table_instance = None
try:
table_instance = db_instance.get_table(table_name)
except Exception:
logger.warning(f"Table not found: {table_name}, this dataset isn't created in Infinity. Maybe it is created in other document engine.")
self.logger.warning(f"Table not found: {table_name}, this dataset isn't created in Infinity. Maybe it is created in other document engine.")
continue
kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_df()
logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}")
kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunk_id}'").to_df()
self.logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}")
df_list.append(kb_res)
self.connPool.release_conn(inf_conn)
res = concat_dataframes(df_list, ["id"])
res = self.concat_dataframes(df_list, ["id"])
fields = set(res.columns.tolist())
for field in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "question_kwd", "question_tks","content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks"]:
fields.add(field)
res_fields = self.get_fields(res, list(fields))
return res_fields.get(chunkId, None)
return res_fields.get(chunk_id, None)
def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
def insert(self, documents: list[dict], index_name: str, knowledgebase_id: str = None) -> list[str]:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{indexName}_{knowledgebaseId}"
table_name = f"{index_name}_{knowledgebase_id}"
try:
table_instance = db_instance.get_table(table_name)
except InfinityException as e:
@ -553,7 +312,7 @@ class InfinityConnection(DocStoreConnection):
break
if vector_size == 0:
raise ValueError("Cannot infer vector size from documents")
self.createIdx(indexName, knowledgebaseId, vector_size)
self.create_idx(index_name, knowledgebase_id, vector_size)
table_instance = db_instance.get_table(table_name)
# embedding fields can't have a default value....
@ -574,12 +333,12 @@ class InfinityConnection(DocStoreConnection):
d["docnm"] = v
elif k == "title_kwd":
if not d.get("docnm_kwd"):
d["docnm"] = list2str(v)
d["docnm"] = self.list2str(v)
elif k == "title_sm_tks":
if not d.get("docnm_kwd"):
d["docnm"] = list2str(v)
d["docnm"] = self.list2str(v)
elif k == "important_kwd":
d["important_keywords"] = list2str(v)
d["important_keywords"] = self.list2str(v)
elif k == "important_tks":
if not d.get("important_kwd"):
d["important_keywords"] = v
@ -597,11 +356,11 @@ class InfinityConnection(DocStoreConnection):
if not d.get("authors_tks"):
d["authors"] = v
elif k == "question_kwd":
d["questions"] = list2str(v, "\n")
d["questions"] = self.list2str(v, "\n")
elif k == "question_tks":
if not d.get("question_kwd"):
d["questions"] = list2str(v)
elif field_keyword(k):
d["questions"] = self.list2str(v)
elif self.field_keyword(k):
if isinstance(v, list):
d[k] = "###".join(v)
else:
@ -637,15 +396,15 @@ class InfinityConnection(DocStoreConnection):
# logger.info(f"InfinityConnection.insert {json.dumps(documents)}")
table_instance.insert(docs)
self.connPool.release_conn(inf_conn)
logger.debug(f"INFINITY inserted into {table_name} {str_ids}.")
self.logger.debug(f"INFINITY inserted into {table_name} {str_ids}.")
return []
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
def update(self, condition: dict, new_value: dict, index_name: str, knowledgebase_id: str) -> bool:
# if 'position_int' in newValue:
# logger.info(f"update position_int: {newValue['position_int']}")
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{indexName}_{knowledgebaseId}"
table_name = f"{index_name}_{knowledgebase_id}"
table_instance = db_instance.get_table(table_name)
# if "exists" in condition:
# del condition["exists"]
@ -654,57 +413,57 @@ class InfinityConnection(DocStoreConnection):
if table_instance:
for n, ty, de, _ in table_instance.show_columns().rows():
clmns[n] = (ty, de)
filter = equivalent_condition_to_str(condition, table_instance)
filter = self.equivalent_condition_to_str(condition, table_instance)
removeValue = {}
for k, v in list(newValue.items()):
for k, v in list(new_value.items()):
if k == "docnm_kwd":
newValue["docnm"] = list2str(v)
new_value["docnm"] = self.list2str(v)
elif k == "title_kwd":
if not newValue.get("docnm_kwd"):
newValue["docnm"] = list2str(v)
if not new_value.get("docnm_kwd"):
new_value["docnm"] = self.list2str(v)
elif k == "title_sm_tks":
if not newValue.get("docnm_kwd"):
newValue["docnm"] = v
if not new_value.get("docnm_kwd"):
new_value["docnm"] = v
elif k == "important_kwd":
newValue["important_keywords"] = list2str(v)
new_value["important_keywords"] = self.list2str(v)
elif k == "important_tks":
if not newValue.get("important_kwd"):
newValue["important_keywords"] = v
if not new_value.get("important_kwd"):
new_value["important_keywords"] = v
elif k == "content_with_weight":
newValue["content"] = v
new_value["content"] = v
elif k == "content_ltks":
if not newValue.get("content_with_weight"):
newValue["content"] = v
if not new_value.get("content_with_weight"):
new_value["content"] = v
elif k == "content_sm_ltks":
if not newValue.get("content_with_weight"):
newValue["content"] = v
if not new_value.get("content_with_weight"):
new_value["content"] = v
elif k == "authors_tks":
newValue["authors"] = v
new_value["authors"] = v
elif k == "authors_sm_tks":
if not newValue.get("authors_tks"):
newValue["authors"] = v
if not new_value.get("authors_tks"):
new_value["authors"] = v
elif k == "question_kwd":
newValue["questions"] = "\n".join(v)
new_value["questions"] = "\n".join(v)
elif k == "question_tks":
if not newValue.get("question_kwd"):
newValue["questions"] = list2str(v)
elif field_keyword(k):
if not new_value.get("question_kwd"):
new_value["questions"] = self.list2str(v)
elif self.field_keyword(k):
if isinstance(v, list):
newValue[k] = "###".join(v)
new_value[k] = "###".join(v)
else:
newValue[k] = v
new_value[k] = v
elif re.search(r"_feas$", k):
newValue[k] = json.dumps(v)
new_value[k] = json.dumps(v)
elif k == "kb_id":
if isinstance(newValue[k], list):
newValue[k] = newValue[k][0] # since d[k] is a list, but we need a str
if isinstance(new_value[k], list):
new_value[k] = new_value[k][0] # since d[k] is a list, but we need a str
elif k == "position_int":
assert isinstance(v, list)
arr = [num for row in v for num in row]
newValue[k] = "_".join(f"{num:08x}" for num in arr)
new_value[k] = "_".join(f"{num:08x}" for num in arr)
elif k in ["page_num_int", "top_int"]:
assert isinstance(v, list)
newValue[k] = "_".join(f"{num:08x}" for num in v)
new_value[k] = "_".join(f"{num:08x}" for num in v)
elif k == "remove":
if isinstance(v, str):
assert v in clmns, f"'{v}' should be in '{clmns}'."
@ -712,22 +471,22 @@ class InfinityConnection(DocStoreConnection):
if ty.lower().find("cha"):
if not de:
de = ""
newValue[v] = de
new_value[v] = de
else:
for kk, vv in v.items():
removeValue[kk] = vv
del newValue[k]
del new_value[k]
else:
newValue[k] = v
new_value[k] = v
for k in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", "question_kwd", "question_tks"]:
if k in newValue:
del newValue[k]
if k in new_value:
del new_value[k]
remove_opt = {} # "[k,new_value]": [id_to_update, ...]
if removeValue:
col_to_remove = list(removeValue.keys())
row_to_opt = table_instance.output(col_to_remove + ["id"]).filter(filter).to_df()
logger.debug(f"INFINITY search table {str(table_name)}, filter {filter}, result: {str(row_to_opt[0])}")
self.logger.debug(f"INFINITY search table {str(table_name)}, filter {filter}, result: {str(row_to_opt[0])}")
row_to_opt = self.get_fields(row_to_opt, col_to_remove)
for id, old_v in row_to_opt.items():
for k, remove_v in removeValue.items():
@ -740,78 +499,53 @@ class InfinityConnection(DocStoreConnection):
else:
remove_opt[kv_key].append(id)
logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {newValue}.")
self.logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {new_value}.")
for update_kv, ids in remove_opt.items():
k, v = json.loads(update_kv)
table_instance.update(filter + " AND id in ({0})".format(",".join([f"'{id}'" for id in ids])), {k: "###".join(v)})
table_instance.update(filter, newValue)
table_instance.update(filter, new_value)
self.connPool.release_conn(inf_conn)
return True
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{indexName}_{knowledgebaseId}"
try:
table_instance = db_instance.get_table(table_name)
except Exception:
logger.warning(f"Skipped deleting from table {table_name} since the table doesn't exist.")
return 0
filter = equivalent_condition_to_str(condition, table_instance)
logger.debug(f"INFINITY delete table {table_name}, filter {filter}.")
res = table_instance.delete(filter)
self.connPool.release_conn(inf_conn)
return res.deleted_rows
"""
Helper functions for search result
"""
def get_total(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
if isinstance(res, tuple):
return res[1]
return len(res)
def get_chunk_ids(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
if isinstance(res, tuple):
res = res[0]
return list(res["id"])
def get_fields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
if isinstance(res, tuple):
res = res[0]
if not fields:
return {}
fieldsAll = fields.copy()
fieldsAll.append("id")
fieldsAll = set(fieldsAll)
fields_all = fields.copy()
fields_all.append("id")
fields_all = set(fields_all)
if "docnm" in res.columns:
for field in ["docnm_kwd", "title_tks", "title_sm_tks"]:
if field in fieldsAll:
if field in fields_all:
res[field] = res["docnm"]
if "important_keywords" in res.columns:
if "important_kwd" in fieldsAll:
if "important_kwd" in fields_all:
res["important_kwd"] = res["important_keywords"].apply(lambda v: v.split())
if "important_tks" in fieldsAll:
if "important_tks" in fields_all:
res["important_tks"] = res["important_keywords"]
if "questions" in res.columns:
if "question_kwd" in fieldsAll:
if "question_kwd" in fields_all:
res["question_kwd"] = res["questions"].apply(lambda v: v.splitlines())
if "question_tks" in fieldsAll:
if "question_tks" in fields_all:
res["question_tks"] = res["questions"]
if "content" in res.columns:
for field in ["content_with_weight", "content_ltks", "content_sm_ltks"]:
if field in fieldsAll:
if field in fields_all:
res[field] = res["content"]
if "authors" in res.columns:
for field in ["authors_tks", "authors_sm_tks"]:
if field in fieldsAll:
if field in fields_all:
res[field] = res["authors"]
column_map = {col.lower(): col for col in res.columns}
matched_columns = {column_map[col.lower()]: col for col in fieldsAll if col.lower() in column_map}
none_columns = [col for col in fieldsAll if col.lower() not in column_map]
matched_columns = {column_map[col.lower()]: col for col in fields_all if col.lower() in column_map}
none_columns = [col for col in fields_all if col.lower() not in column_map]
res2 = res[matched_columns.keys()]
res2 = res2.rename(columns=matched_columns)
@ -819,7 +553,7 @@ class InfinityConnection(DocStoreConnection):
for column in list(res2.columns):
k = column.lower()
if field_keyword(k):
if self.field_keyword(k):
res2[column] = res2[column].apply(lambda v: [kwd for kwd in v.split("###") if kwd])
elif re.search(r"_feas$", k):
res2[column] = res2[column].apply(lambda v: json.loads(v) if v else {})
@ -844,95 +578,3 @@ class InfinityConnection(DocStoreConnection):
res2[column] = None
return res2.set_index("id").to_dict(orient="index")
def get_highlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], fieldnm: str):
if isinstance(res, tuple):
res = res[0]
ans = {}
num_rows = len(res)
column_id = res["id"]
if fieldnm not in res:
return {}
for i in range(num_rows):
id = column_id[i]
txt = res[fieldnm][i]
if re.search(r"<em>[^<>]+</em>", txt, flags=re.IGNORECASE | re.MULTILINE):
ans[id] = txt
continue
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
txts = []
for t in re.split(r"[.?!;\n]", txt):
if is_english([t]):
for w in keywords:
t = re.sub(
r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w),
r"\1<em>\2</em>\3",
t,
flags=re.IGNORECASE | re.MULTILINE,
)
else:
for w in sorted(keywords, key=len, reverse=True):
t = re.sub(
re.escape(w),
f"<em>{w}</em>",
t,
flags=re.IGNORECASE | re.MULTILINE,
)
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
continue
txts.append(t)
if txts:
ans[id] = "...".join(txts)
else:
ans[id] = txt
return ans
def get_aggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str):
"""
Manual aggregation for tag fields since Infinity doesn't provide native aggregation
"""
from collections import Counter
# Extract DataFrame from result
if isinstance(res, tuple):
df, _ = res
else:
df = res
if df.empty or fieldnm not in df.columns:
return []
# Aggregate tag counts
tag_counter = Counter()
for value in df[fieldnm]:
if pd.isna(value) or not value:
continue
# Handle different tag formats
if isinstance(value, str):
# Split by ### for tag_kwd field or comma for other formats
if fieldnm == "tag_kwd" and "###" in value:
tags = [tag.strip() for tag in value.split("###") if tag.strip()]
else:
# Try comma separation as fallback
tags = [tag.strip() for tag in value.split(",") if tag.strip()]
for tag in tags:
if tag: # Only count non-empty tags
tag_counter[tag] += 1
elif isinstance(value, list):
# Handle list format
for tag in value:
if tag and isinstance(tag, str):
tag_counter[tag.strip()] += 1
# Return as list of [tag, count] pairs, sorted by count descending
return [[tag, count] for tag, count in tag_counter.most_common()]
"""
SQL
"""
def sql(sql: str, fetch_size: int, format: str):
raise NotImplementedError("Not implemented")

View File

@ -37,9 +37,8 @@ from common import settings
from common.constants import PAGERANK_FLD, TAG_FLD
from common.decorator import singleton
from common.float_utils import get_float
from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, MatchDenseExpr
from rag.nlp import rag_tokenizer
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, \
MatchDenseExpr
ATTEMPT_TIME = 2
OB_QUERY_TIMEOUT = int(os.environ.get("OB_QUERY_TIMEOUT", "100_000_000"))
@ -497,7 +496,7 @@ class OBConnection(DocStoreConnection):
Database operations
"""
def dbType(self) -> str:
def db_type(self) -> str:
return "oceanbase"
def health(self) -> dict:
@ -553,7 +552,7 @@ class OBConnection(DocStoreConnection):
Table operations
"""
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
def create_idx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
vector_field_name = f"q_{vectorSize}_vec"
vector_index_name = f"{vector_field_name}_idx"
@ -604,7 +603,7 @@ class OBConnection(DocStoreConnection):
# always refresh metadata to make sure it contains the latest table structure
self.client.refresh_metadata([indexName])
def deleteIdx(self, indexName: str, knowledgebaseId: str):
def delete_idx(self, indexName: str, knowledgebaseId: str):
if len(knowledgebaseId) > 0:
# The index need to be alive after any kb deletion since all kb under this tenant are in one index.
return
@ -615,7 +614,7 @@ class OBConnection(DocStoreConnection):
except Exception as e:
raise Exception(f"OBConnection.deleteIndex error: {str(e)}")
def indexExist(self, indexName: str, knowledgebaseId: str = None) -> bool:
def index_exist(self, indexName: str, knowledgebaseId: str = None) -> bool:
return self._check_table_exists_cached(indexName)
def _get_count(self, table_name: str, filter_list: list[str] = None) -> int:
@ -1500,7 +1499,7 @@ class OBConnection(DocStoreConnection):
def get_total(self, res) -> int:
return res.total
def get_chunk_ids(self, res) -> list[str]:
def get_doc_ids(self, res) -> list[str]:
return [row["id"] for row in res.chunks]
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:

View File

@ -26,8 +26,7 @@ from opensearchpy import UpdateByQuery, Q, Search, Index
from opensearchpy import ConnectionTimeout
from common.decorator import singleton
from common.file_utils import get_project_base_directory
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
FusionExpr
from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, FusionExpr
from rag.nlp import is_english, rag_tokenizer
from common.constants import PAGERANK_FLD, TAG_FLD
from common import settings
@ -79,7 +78,7 @@ class OSConnection(DocStoreConnection):
Database operations
"""
def dbType(self) -> str:
def db_type(self) -> str:
return "opensearch"
def health(self) -> dict:
@ -91,8 +90,8 @@ class OSConnection(DocStoreConnection):
Table operations
"""
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
if self.indexExist(indexName, knowledgebaseId):
def create_idx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
if self.index_exist(indexName, knowledgebaseId):
return True
try:
from opensearchpy.client import IndicesClient
@ -101,7 +100,7 @@ class OSConnection(DocStoreConnection):
except Exception:
logger.exception("OSConnection.createIndex error %s" % (indexName))
def deleteIdx(self, indexName: str, knowledgebaseId: str):
def delete_idx(self, indexName: str, knowledgebaseId: str):
if len(knowledgebaseId) > 0:
# The index need to be alive after any kb deletion since all kb under this tenant are in one index.
return
@ -112,7 +111,7 @@ class OSConnection(DocStoreConnection):
except Exception:
logger.exception("OSConnection.deleteIdx error %s" % (indexName))
def indexExist(self, indexName: str, knowledgebaseId: str = None) -> bool:
def index_exist(self, indexName: str, knowledgebaseId: str = None) -> bool:
s = Index(indexName, self.os)
for i in range(ATTEMPT_TIME):
try:
@ -460,7 +459,7 @@ class OSConnection(DocStoreConnection):
return res["hits"]["total"]["value"]
return res["hits"]["total"]
def get_chunk_ids(self, res):
def get_doc_ids(self, res):
return [d["_id"] for d in res["hits"]["hits"]]
def __getSource(self, res):

View File

@ -272,6 +272,49 @@ class RedisDB:
self.__open__()
return None
def generate_auto_increment_id(self, key_prefix: str = "id_generator", namespace: str = "default", increment: int = 1, ensure_minimum: int | None = None) -> int:
redis_key = f"{key_prefix}:{namespace}"
try:
# Use pipeline for atomicity
pipe = self.REDIS.pipeline()
# Check if key exists
pipe.exists(redis_key)
# Get/Increment
if ensure_minimum is not None:
# Ensure minimum value
pipe.get(redis_key)
results = pipe.execute()
if results[0] == 0: # Key doesn't exist
start_id = max(1, ensure_minimum)
pipe.set(redis_key, start_id)
pipe.execute()
return start_id
else:
current = int(results[1])
if current < ensure_minimum:
pipe.set(redis_key, ensure_minimum)
pipe.execute()
return ensure_minimum
# Increment operation
next_id = self.REDIS.incrby(redis_key, increment)
# If it's the first time, set a reasonable initial value
if next_id == increment:
self.REDIS.set(redis_key, 1 + increment)
return 1 + increment
return next_id
except Exception as e:
logging.warning("RedisDB.generate_auto_increment_id got exception: " + str(e))
self.__open__()
return -1
def transaction(self, key, value, exp=3600):
try:
pipeline = self.REDIS.pipeline(transaction=True)

View File

@ -32,8 +32,8 @@ def add_memory_func(request, WebApiAuth):
payload = {
"name": f"test_memory_{i}",
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
"embd_id": "SILICONFLOW@BAAI/bge-large-zh-v1.5",
"llm_id": "ZHIPU-AI@glm-4-flash"
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
"llm_id": "glm-4-flash@ZHIPU-AI"
}
res = create_memory(WebApiAuth, payload)
memory_ids.append(res["data"]["id"])

View File

@ -49,8 +49,8 @@ class TestMemoryCreate:
payload = {
"name": name,
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
"embd_id": "SILICONFLOW@BAAI/bge-large-zh-v1.5",
"llm_id": "ZHIPU-AI@glm-4-flash"
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
"llm_id": "glm-4-flash@ZHIPU-AI"
}
res = create_memory(WebApiAuth, payload)
assert res["code"] == 0, res
@ -72,8 +72,8 @@ class TestMemoryCreate:
payload = {
"name": name,
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
"embd_id": "SILICONFLOW@BAAI/bge-large-zh-v1.5",
"llm_id": "ZHIPU-AI@glm-4-flash"
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
"llm_id": "glm-4-flash@ZHIPU-AI"
}
res = create_memory(WebApiAuth, payload)
assert res["message"] == expected_message, res
@ -84,8 +84,8 @@ class TestMemoryCreate:
payload = {
"name": name,
"memory_type": ["something"],
"embd_id": "SILICONFLOW@BAAI/bge-large-zh-v1.5",
"llm_id": "ZHIPU-AI@glm-4-flash"
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
"llm_id": "glm-4-flash@ZHIPU-AI"
}
res = create_memory(WebApiAuth, payload)
assert res["message"] == f"Memory type '{ {'something'} }' is not supported.", res
@ -96,8 +96,8 @@ class TestMemoryCreate:
payload = {
"name": name,
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
"embd_id": "SILICONFLOW@BAAI/bge-large-zh-v1.5",
"llm_id": "ZHIPU-AI@glm-4-flash"
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
"llm_id": "glm-4-flash@ZHIPU-AI"
}
res1 = create_memory(WebApiAuth, payload)
assert res1["code"] == 0, res1

View File

@ -101,7 +101,7 @@ class TestMemoryUpdate:
@pytest.mark.p1
def test_llm(self, WebApiAuth, add_memory_func):
memory_ids = add_memory_func
llm_id = "ZHIPU-AI@glm-4"
llm_id = "glm-4@ZHIPU-AI"
payload = {"llm_id": llm_id}
res = update_memory(WebApiAuth, memory_ids[0], payload)
assert res["code"] == 0, res

8
uv.lock generated
View File

@ -3051,7 +3051,7 @@ wheels = [
[[package]]
name = "infinity-sdk"
version = "0.6.11"
version = "0.6.13"
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
dependencies = [
{ name = "datrie" },
@ -3068,9 +3068,9 @@ dependencies = [
{ name = "sqlglot", extra = ["rs"] },
{ name = "thrift" },
]
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/6e/6d/294b4c8fb36f874c92576107fab22da6d64f567fd3e24a312d7bcba5f17a/infinity_sdk-0.6.11.tar.gz", hash = "sha256:f78acd5439c3837715ab308c49be04b416bdaa42a5f4fb840682639ee39d435f", size = 29518792, upload-time = "2025-12-08T05:58:35.167Z" }
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/03/de/56fdc0fa962d5a8e0aa68d16f5321b2d88d79fceb7d0d6cfdde338b65d05/infinity_sdk-0.6.13.tar.gz", hash = "sha256:faf7bc23de7fa549a3842753eddad54ae551ada9df4fff25421658a7fa6fa8c2", size = 29518902, upload-time = "2025-12-24T10:00:01.483Z" }
wheels = [
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/18/cf/d7d1bc584c8f7d1cd8f75c39067179b98ca4bcbe5a86c61e7dbc2b8e692d/infinity_sdk-0.6.11-py3-none-any.whl", hash = "sha256:ec5ac3e710f29db4b875d3e24e20e391ad64270a5e5d189295cc91c362af74d1", size = 29737403, upload-time = "2025-12-08T05:54:58.798Z" },
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/f4/a0/8f1e134fdf4ca8bebac7b62caace1816953bb5ffc720d9f0004246c8c38d/infinity_sdk-0.6.13-py3-none-any.whl", hash = "sha256:c08a523d2c27e9a7e6e88be640970530b4661a67c3e9dc3e1aa89533a822fd78", size = 29737403, upload-time = "2025-12-24T09:56:16.93Z" },
]
[[package]]
@ -6229,7 +6229,7 @@ requires-dist = [
{ name = "grpcio-status", specifier = "==1.67.1" },
{ name = "html-text", specifier = "==0.6.2" },
{ name = "infinity-emb", specifier = ">=0.0.66,<0.0.67" },
{ name = "infinity-sdk", specifier = "==0.6.11" },
{ name = "infinity-sdk", specifier = "==0.6.13" },
{ name = "jira", specifier = "==3.10.5" },
{ name = "json-repair", specifier = "==0.35.0" },
{ name = "langfuse", specifier = ">=2.60.0" },

View File

@ -18,7 +18,9 @@ import { useFetchKnowledgeBaseConfiguration } from '@/hooks/use-knowledge-reques
import { IModalProps } from '@/interfaces/common';
import { IParserConfig } from '@/interfaces/database/document';
import { IChangeParserConfigRequestBody } from '@/interfaces/request/document';
import { MetadataType } from '@/pages/dataset/components/metedata/hooks/use-manage-modal';
import {
AutoMetadata,
ChunkMethodItem,
EnableTocToggle,
ImageContextWindow,
@ -86,6 +88,7 @@ export function ChunkMethodDialog({
visible,
parserConfig,
loading,
documentId,
}: IProps) {
const { t } = useTranslation();
@ -142,6 +145,18 @@ export function ChunkMethodDialog({
pages: z
.array(z.object({ from: z.coerce.number(), to: z.coerce.number() }))
.optional(),
metadata: z
.array(
z
.object({
key: z.string().optional(),
description: z.string().optional(),
enum: z.array(z.string().optional()).optional(),
})
.optional(),
)
.optional(),
enable_metadata: z.boolean().optional(),
}),
})
.superRefine((data, ctx) => {
@ -373,6 +388,10 @@ export function ChunkMethodDialog({
)}
{showAutoKeywords(selectedTag) && (
<>
<AutoMetadata
type={MetadataType.SingleFileSetting}
otherData={{ documentId }}
/>
<AutoKeywordsFormField></AutoKeywordsFormField>
<AutoQuestionsFormField></AutoQuestionsFormField>
</>

View File

@ -36,9 +36,11 @@ export function useDefaultParserValues() {
// },
entity_types: [],
pages: [],
metadata: [],
enable_metadata: false,
};
return defaultParserValues;
return defaultParserValues as IParserConfig;
}, [t]);
return defaultParserValues;

View File

@ -35,6 +35,7 @@ import { cn } from '@/lib/utils';
import { t } from 'i18next';
import { Loader } from 'lucide-react';
import { MultiSelect, MultiSelectOptionType } from './ui/multi-select';
import { Switch } from './ui/switch';
// Field type enumeration
export enum FormFieldType {
@ -46,6 +47,7 @@ export enum FormFieldType {
Select = 'select',
MultiSelect = 'multi-select',
Checkbox = 'checkbox',
Switch = 'switch',
Tag = 'tag',
Custom = 'custom',
}
@ -154,6 +156,7 @@ export const generateSchema = (fields: FormFieldConfig[]): ZodSchema<any> => {
}
break;
case FormFieldType.Checkbox:
case FormFieldType.Switch:
fieldSchema = z.boolean();
break;
case FormFieldType.Tag:
@ -193,6 +196,8 @@ export const generateSchema = (fields: FormFieldConfig[]): ZodSchema<any> => {
if (
field.type !== FormFieldType.Number &&
field.type !== FormFieldType.Checkbox &&
field.type !== FormFieldType.Switch &&
field.type !== FormFieldType.Custom &&
field.type !== FormFieldType.Tag &&
field.required
) {
@ -289,7 +294,10 @@ const generateDefaultValues = <T extends FieldValues>(
const lastKey = keys[keys.length - 1];
if (field.defaultValue !== undefined) {
current[lastKey] = field.defaultValue;
} else if (field.type === FormFieldType.Checkbox) {
} else if (
field.type === FormFieldType.Checkbox ||
field.type === FormFieldType.Switch
) {
current[lastKey] = false;
} else if (field.type === FormFieldType.Tag) {
current[lastKey] = [];
@ -299,7 +307,10 @@ const generateDefaultValues = <T extends FieldValues>(
} else {
if (field.defaultValue !== undefined) {
defaultValues[field.name] = field.defaultValue;
} else if (field.type === FormFieldType.Checkbox) {
} else if (
field.type === FormFieldType.Checkbox ||
field.type === FormFieldType.Switch
) {
defaultValues[field.name] = false;
} else if (
field.type === FormFieldType.Tag ||
@ -502,6 +513,32 @@ export const RenderField = ({
)}
/>
);
case FormFieldType.Switch:
return (
<RAGFlowFormItem
{...field}
labelClassName={labelClassName || field.labelClassName}
>
{(fieldProps) => {
const finalFieldProps = field.onChange
? {
...fieldProps,
onChange: (checked: boolean) => {
fieldProps.onChange(checked);
field.onChange?.(checked);
},
}
: fieldProps;
return (
<Switch
checked={finalFieldProps.value as boolean}
onCheckedChange={(checked) => finalFieldProps.onChange(checked)}
disabled={field.disabled}
/>
);
}}
</RAGFlowFormItem>
);
case FormFieldType.Tag:
return (

View File

@ -15,6 +15,7 @@ import { Progress } from '@/components/ui/progress';
import { useControllableState } from '@/hooks/use-controllable-state';
import { cn, formatBytes } from '@/lib/utils';
import { useTranslation } from 'react-i18next';
import { Tooltip, TooltipContent, TooltipTrigger } from './ui/tooltip';
function isFileWithPreview(file: File): file is File & { preview: string } {
return 'preview' in file && typeof file.preview === 'string';
@ -58,10 +59,17 @@ function FileCard({ file, progress, onRemove }: FileCardProps) {
</div>
<div className="flex flex-col flex-1 gap-2 overflow-hidden">
<div className="flex flex-col gap-px">
<p className="line-clamp-1 text-sm font-medium text-foreground/80 text-ellipsis">
{file.name}
</p>
<p className="text-xs text-muted-foreground">
<Tooltip>
<TooltipTrigger asChild>
<p className=" w-fit line-clamp-1 text-sm font-medium text-foreground/80 text-ellipsis truncate max-w-[370px]">
{file.name}
</p>
</TooltipTrigger>
<TooltipContent className="border border-border-button">
{file.name}
</TooltipContent>
</Tooltip>
<p className="text-xs text-text-secondary">
{formatBytes(file.size)}
</p>
</div>
@ -311,7 +319,7 @@ export function FileUploader(props: FileUploaderProps) {
/>
</div>
<div className="flex flex-col gap-px">
<p className="font-medium text-text-secondary">
<p className="font-medium text-text-secondary ">
{title || t('knowledgeDetails.uploadTitle')}
</p>
<p className="text-sm text-text-disabled">

View File

@ -31,14 +31,16 @@ const handleCheckChange = ({
(value: string) => value !== item.id.toString(),
);
const newValue = {
...currentValue,
[parentId]: newParentValues,
};
const newValue = newParentValues?.length
? {
...currentValue,
[parentId]: newParentValues,
}
: { ...currentValue };
if (newValue[parentId].length === 0) {
delete newValue[parentId];
}
// if (newValue[parentId].length === 0) {
// delete newValue[parentId];
// }
return field.onChange(newValue);
} else {
@ -66,20 +68,31 @@ const FilterItem = memo(
}) => {
return (
<div
className={`flex items-center justify-between text-text-primary text-xs ${level > 0 ? 'ml-4' : ''}`}
className={`flex items-center justify-between text-text-primary text-xs ${level > 0 ? 'ml-1' : ''}`}
>
<FormItem className="flex flex-row space-x-3 space-y-0 items-center">
<FormItem className="flex flex-row space-x-3 space-y-0 items-center ">
<FormControl>
<Checkbox
checked={field.value?.includes(item.id.toString())}
onCheckedChange={(checked: boolean) =>
handleCheckChange({ checked, field, item })
}
/>
<div className="flex space-x-3">
<Checkbox
checked={field.value?.includes(item.id.toString())}
onCheckedChange={(checked: boolean) =>
handleCheckChange({ checked, field, item })
}
// className="hidden group-hover:block"
/>
<FormLabel
onClick={() =>
handleCheckChange({
checked: !field.value?.includes(item.id.toString()),
field,
item,
})
}
>
{item.label}
</FormLabel>
</div>
</FormControl>
<FormLabel onClick={(e) => e.stopPropagation()}>
{item.label}
</FormLabel>
</FormItem>
{item.count !== undefined && (
<span className="text-sm">{item.count}</span>
@ -107,11 +120,11 @@ export const FilterField = memo(
<FormField
key={item.id}
control={form.control}
name={parent.field as string}
name={parent.field?.toString() as string}
render={({ field }) => {
if (hasNestedList) {
return (
<div className={`flex flex-col gap-2 ${level > 0 ? 'ml-4' : ''}`}>
<div className={`flex flex-col gap-2 ${level > 0 ? 'ml-1' : ''}`}>
<div
className="flex items-center justify-between cursor-pointer"
onClick={() => {
@ -138,23 +151,6 @@ export const FilterField = memo(
}}
level={level + 1}
/>
// <FilterItem key={child.id} item={child} field={child.field} level={level+1} />
// <div
// className="flex flex-row space-x-3 space-y-0 items-center"
// key={child.id}
// >
// <FormControl>
// <Checkbox
// checked={field.value?.includes(child.id.toString())}
// onCheckedChange={(checked) =>
// handleCheckChange({ checked, field, item: child })
// }
// />
// </FormControl>
// <FormLabel onClick={(e) => e.stopPropagation()}>
// {child.label}
// </FormLabel>
// </div>
))}
</div>
);

View File

@ -11,8 +11,8 @@ import {
useMemo,
useState,
} from 'react';
import { useForm } from 'react-hook-form';
import { ZodArray, ZodString, z } from 'zod';
import { FieldPath, useForm } from 'react-hook-form';
import { z } from 'zod';
import { Button } from '@/components/ui/button';
@ -71,34 +71,37 @@ function CheckboxFormMultiple({
}, {});
}, [resolvedFilters]);
// const FormSchema = useMemo(() => {
// if (resolvedFilters.length === 0) {
// return z.object({});
// }
// return z.object(
// resolvedFilters.reduce<
// Record<
// string,
// ZodArray<ZodString, 'many'> | z.ZodObject<any> | z.ZodOptional<any>
// >
// >((pre, cur) => {
// const hasNested = cur.list?.some(
// (item) => item.list && item.list.length > 0,
// );
// if (hasNested) {
// pre[cur.field] = z
// .record(z.string(), z.array(z.string().optional()).optional())
// .optional();
// } else {
// pre[cur.field] = z.array(z.string().optional()).optional();
// }
// return pre;
// }, {}),
// );
// }, [resolvedFilters]);
const FormSchema = useMemo(() => {
if (resolvedFilters.length === 0) {
return z.object({});
}
return z.object(
resolvedFilters.reduce<
Record<
string,
ZodArray<ZodString, 'many'> | z.ZodObject<any> | z.ZodOptional<any>
>
>((pre, cur) => {
const hasNested = cur.list?.some(
(item) => item.list && item.list.length > 0,
);
if (hasNested) {
pre[cur.field] = z
.record(z.string(), z.array(z.string().optional()).optional())
.optional();
} else {
pre[cur.field] = z.array(z.string().optional()).optional();
}
return pre;
}, {}),
);
}, [resolvedFilters]);
return z.object({});
}, []);
const form = useForm<z.infer<typeof FormSchema>>({
resolver: resolvedFilters.length > 0 ? zodResolver(FormSchema) : undefined,
@ -178,7 +181,9 @@ function CheckboxFormMultiple({
<FormField
key={x.field}
control={form.control}
name={x.field}
name={
x.field.toString() as FieldPath<z.infer<typeof FormSchema>>
}
render={() => (
<FormItem className="space-y-4">
<div>
@ -186,19 +191,20 @@ function CheckboxFormMultiple({
{x.label}
</FormLabel>
</div>
{x.list.map((item) => {
return (
<FilterField
key={item.id}
item={{ ...item }}
parent={{
...x,
id: x.field,
// field: `${x.field}${item.field ? '.' + item.field : ''}`,
}}
/>
);
})}
{x.list?.length &&
x.list.map((item) => {
return (
<FilterField
key={item.id}
item={{ ...item }}
parent={{
...x,
id: x.field,
// field: `${x.field}${item.field ? '.' + item.field : ''}`,
}}
/>
);
})}
<FormMessage />
</FormItem>
)}

View File

@ -14,6 +14,7 @@ import {
IDocumentMetaRequestBody,
} from '@/interfaces/request/document';
import i18n from '@/locales/config';
import { EMPTY_METADATA_FIELD } from '@/pages/dataset/dataset/use-select-filters';
import kbService, { listDocument } from '@/services/knowledge-service';
import api, { api_host } from '@/utils/api';
import { buildChunkHighlights } from '@/utils/document-util';
@ -114,6 +115,20 @@ export const useFetchDocumentList = () => {
refetchInterval: isLoop ? 5000 : false,
enabled: !!knowledgeId || !!id,
queryFn: async () => {
let run = [] as any;
let returnEmptyMetadata = false;
if (filterValue.run && Array.isArray(filterValue.run)) {
run = [...(filterValue.run as string[])];
const returnEmptyMetadataIndex = run.findIndex(
(r: string) => r === EMPTY_METADATA_FIELD,
);
if (returnEmptyMetadataIndex > -1) {
returnEmptyMetadata = true;
run.splice(returnEmptyMetadataIndex, 1);
}
} else {
run = filterValue.run;
}
const ret = await listDocument(
{
kb_id: knowledgeId || id,
@ -123,7 +138,8 @@ export const useFetchDocumentList = () => {
},
{
suffix: filterValue.type as string[],
run_status: filterValue.run as string[],
run_status: run as string[],
return_empty_metadata: returnEmptyMetadata,
metadata: filterValue.metadata as Record<string, string[]>,
},
);

View File

@ -43,6 +43,18 @@ export interface IParserConfig {
task_page_size?: number;
raptor?: Raptor;
graphrag?: GraphRag;
image_context_window?: number;
mineru_parse_method?: 'auto' | 'txt' | 'ocr';
mineru_formula_enable?: boolean;
mineru_table_enable?: boolean;
mineru_lang?: string;
entity_types?: string[];
metadata?: Array<{
key?: string;
description?: string;
enum?: string[];
}>;
enable_metadata?: boolean;
}
interface Raptor {

View File

@ -33,5 +33,6 @@ export interface IFetchKnowledgeListRequestParams {
export interface IFetchDocumentListRequestBody {
suffix?: string[];
run_status?: string[];
return_empty_metadata?: boolean;
metadata?: Record<string, string[]>;
}

View File

@ -176,6 +176,15 @@ Procedural Memory: Learned skills, habits, and automated procedures.`,
},
knowledgeDetails: {
metadata: {
descriptionTip:
'Provide descriptions or examples to guide LLM extract values for this field. If left empty, it will rely on the field name.',
restrictTDefinedValuesTip:
'Enum Mode: Restricts LLM extraction to match preset values only. Define values below.',
valueExists:
'Value already exists. Confirm to merge duplicates and combine all associated files.',
fieldNameExists:
'Field name already exists. Confirm to merge duplicates and combine all associated files.',
fieldExists: 'Field already exists.',
fieldSetting: 'Field settings',
changesAffectNewParses: 'Changes affect new parses only.',
editMetadataForDataset: 'View and edit metadata for ',
@ -185,12 +194,25 @@ Procedural Memory: Learned skills, habits, and automated procedures.`,
manageMetadata: 'Manage metadata',
metadata: 'Metadata',
values: 'Values',
value: 'Value',
action: 'Action',
field: 'Field',
description: 'Description',
fieldName: 'Field name',
editMetadata: 'Edit metadata',
deleteWarn: 'This {{field}} will be removed from all associated files',
deleteManageFieldAllWarn:
'This field and all its corresponding values will be deleted from all associated files.',
deleteManageValueAllWarn:
'This value will be deleted from from all associated files.',
deleteManageFieldSingleWarn:
'This field and all its corresponding values will be deleted from this files.',
deleteManageValueSingleWarn:
'This value will be deleted from this files.',
deleteSettingFieldWarn: `This field will be deleted; existing metadata won't be affected.`,
deleteSettingValueWarn: `This value will be deleted; existing metadata won't be affected.`,
},
emptyMetadata: 'No metadata',
metadataField: 'Metadata field',
systemAttribute: 'System attribute',
localUpload: 'Local upload',
@ -334,9 +356,9 @@ Procedural Memory: Learned skills, habits, and automated procedures.`,
html4excel: 'Excel to HTML',
html4excelTip: `Use with the General chunking method. When disabled, spreadsheets (XLSX or XLS(Excel 97-2003)) in the knowledge base will be parsed into key-value pairs. When enabled, they will be parsed into HTML tables, splitting every 12 rows if the original table has more than 12 rows. See https://ragflow.io/docs/dev/enable_excel2html for details.`,
autoKeywords: 'Auto-keyword',
autoKeywordsTip: `Automatically extract N keywords for each chunk to increase their ranking for queries containing those keywords. Be aware that extra tokens will be consumed by the chat model specified in 'System model settings'. You can check or update the added keywords for a chunk from the chunk list. For details, see https://ragflow.io/docs/dev/autokeyword_autoquestion.`,
autoKeywordsTip: `Automatically extract N keywords for each chunk to increase their ranking for queries containing those keywords. Be aware that extra tokens will be consumed by the indexing model specified in 'Configuration'. You can check or update the added keywords for a chunk from the chunk list. For details, see https://ragflow.io/docs/dev/autokeyword_autoquestion.`,
autoQuestions: 'Auto-question',
autoQuestionsTip: `Automatically extract N questions for each chunk to increase their ranking for queries containing those questions. You can check or update the added questions for a chunk from the chunk list. This feature will not disrupt the chunking process if an error occurs, except that it may add an empty result to the original chunk. Be aware that extra tokens will be consumed by the LLM specified in 'System model settings'. For details, see https://ragflow.io/docs/dev/autokeyword_autoquestion.`,
autoQuestionsTip: `Automatically extract N questions for each chunk to increase their ranking for queries containing those questions. You can check or update the added questions for a chunk from the chunk list. This feature will not disrupt the chunking process if an error occurs, except that it may add an empty result to the original chunk. Be aware that extra tokens will be consumed by the indexing model specified in 'Configuration'. For details, see https://ragflow.io/docs/dev/autokeyword_autoquestion.`,
redo: 'Do you want to clear the existing {{chunkNum}} chunks?',
setMetaData: 'Set meta data',
pleaseInputJson: 'Please enter JSON',
@ -1607,6 +1629,8 @@ Example: Virtual Hosted Style`,
notEmpty: 'Not empty',
in: 'In',
notIn: 'Not in',
is: 'Is',
isNot: 'Is not',
},
switchLogicOperatorOptions: {
and: 'AND',
@ -1868,6 +1892,8 @@ This process aggregates variables from multiple branches into a single variable
beginInputTip:
'By defining input parameters, this content can be accessed by other components in subsequent processes.',
query: 'Query variables',
switchPromptMessage:
'The prompt words will change. Please confirm whether you want to discard the existing prompt words?',
queryRequired: 'Query is required',
queryTip: 'Select the variable you want to use',
agent: 'Agent',

View File

@ -168,6 +168,10 @@ export default {
},
knowledgeDetails: {
metadata: {
descriptionTip:
'提供描述或示例来指导大语言模型为此字段提取值。如果留空,将依赖字段名称。',
restrictTDefinedValuesTip:
'枚举模式:限制大语言模型仅提取预设值。在下方定义值。',
fieldSetting: '字段设置',
changesAffectNewParses: '更改仅影响新解析。',
editMetadataForDataset: '查看和编辑元数据于 ',
@ -177,12 +181,25 @@ export default {
manageMetadata: '管理元数据',
metadata: '元数据',
values: '值',
value: '值',
action: '操作',
field: '字段',
description: '描述',
fieldName: '字段名',
editMetadata: '编辑元数据',
valueExists: '值已存在。确认合并重复项并组合所有关联文件。',
fieldNameExists: '字段名已存在。确认合并重复项并组合所有关联文件。',
fieldExists: '字段名已存在。',
deleteWarn: '此 {{field}} 将从所有关联文件中移除',
deleteManageFieldAllWarn:
'此字段及其所有对应值将从所有关联的文件中删除。',
deleteManageValueAllWarn: '此值将从所有关联的文件中删除。',
deleteManageFieldSingleWarn: '此字段及其所有对应值将从此文件中删除。',
deleteManageValueSingleWarn: '此值将从此文件中删除。',
deleteSettingFieldWarn: `此字段将被删除;现有元数据不会受到影响。`,
deleteSettingValueWarn: `此值将被删除;现有元数据不会受到影响。`,
},
emptyMetadata: '无元数据',
localUpload: '本地上传',
fileSize: '文件大小',
fileType: '文件类型',
@ -311,9 +328,9 @@ export default {
html4excel: '表格转HTML',
html4excelTip: `与 General 切片方法配合使用。未开启状态下表格文件XLSX、XLSExcel 97-2003会按行解析为键值对。开启后表格文件会被解析为 HTML 表格。若原始表格超过 12 行,系统会自动按每 12 行拆分为多个 HTML 表格。欲了解更多详情,请参阅 https://ragflow.io/docs/dev/enable_excel2html。`,
autoKeywords: '自动关键词提取',
autoKeywordsTip: `自动为每个文本块中提取 N 个关键词,用以提升查询精度。请注意:该功能采用“系统模型设置”中设置的默认聊天模型提取关键词,因此也会产生更多 Token 消耗。另外,你也可以手动更新生成的关键词。详情请见 https://ragflow.io/docs/dev/autokeyword_autoquestion。`,
autoKeywordsTip: `自动为每个文本块中提取 N 个关键词,用以提升查询精度。请注意:该功能采用在“配置”中指定的索引模型提取关键词,因此也会产生更多 Token 消耗。另外,你也可以手动更新生成的关键词。详情请见 https://ragflow.io/docs/dev/autokeyword_autoquestion。`,
autoQuestions: '自动问题提取',
autoQuestionsTip: `利用“系统模型设置”中设置的 chat model 对知识库的每个文本块提取 N 个问题以提高其排名得分。请注意,开启后将消耗额外的 token。您可以在块列表中查看、编辑结果。如果自动问题提取发生错误不会妨碍整个分块过程只会将空结果添加到原始文本块。详情请见 https://ragflow.io/docs/dev/autokeyword_autoquestion。`,
autoQuestionsTip: `利用在“配置”中指定的索引模型 对知识库的每个文本块提取 N 个问题以提高其排名得分。请注意,开启后将消耗额外的 token。您可以在块列表中查看、编辑结果。如果自动问题提取发生错误不会妨碍整个分块过程只会将空结果添加到原始文本块。详情请见 https://ragflow.io/docs/dev/autokeyword_autoquestion。`,
redo: '是否清空已有 {{chunkNum}}个 chunk',
setMetaData: '设置元数据',
pleaseInputJson: '请输入JSON',
@ -1505,6 +1522,8 @@ General实体和关系提取提示来自 GitHub - microsoft/graphrag基于
endWith: '结束是',
empty: '为空',
notEmpty: '不为空',
is: '是',
isNot: '不是',
},
switchLogicOperatorOptions: {
and: '与',

View File

@ -4,6 +4,7 @@ import { useSendAgentMessage } from './use-send-agent-message';
import { FileUploadProps } from '@/components/file-upload';
import { NextMessageInput } from '@/components/message-input/next';
import MarkdownContent from '@/components/next-markdown-content';
import MessageItem from '@/components/next-message-item';
import PdfSheet from '@/components/pdf-drawer';
import { useClickDrawer } from '@/components/pdf-drawer/hooks';
@ -102,8 +103,10 @@ function AgentChatBox() {
{message.role === MessageType.Assistant &&
derivedMessages.length - 1 !== i && (
<div>
<div>{message?.data?.tips}</div>
<MarkdownContent
content={message?.data?.tips}
loading={false}
></MarkdownContent>
<div>
{buildInputList(message)?.map((item) => item.value)}
</div>

View File

@ -206,6 +206,7 @@ export const initialSplitterValues = {
chunk_token_size: 512,
overlapped_percent: 0,
delimiters: [{ value: '\n' }],
image_table_context_window: 0,
};
export enum Hierarchy {

View File

@ -1,3 +1,4 @@
import MarkdownContent from '@/components/next-markdown-content';
import { ButtonLoading } from '@/components/ui/button';
import {
Form,
@ -234,7 +235,14 @@ const DebugContent = ({
return (
<>
<section>
{message?.data?.tips && <div className="mb-2">{message.data.tips}</div>}
{message?.data?.tips && (
<div className="mb-2">
<MarkdownContent
content={message?.data?.tips}
loading={false}
></MarkdownContent>
</div>
)}
<Form {...form}>
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-4">
{parameters.map((x, idx) => {

View File

@ -22,6 +22,7 @@ const outputList = buildOutputList(initialSplitterValues.outputs);
export const FormSchema = z.object({
chunk_token_size: z.number(),
image_table_context_window: z.number(),
delimiters: z.array(
z.object({
value: z.string().optional(),
@ -74,6 +75,13 @@ const SplitterForm = ({ node }: INextOperatorForm) => {
min={0}
label={t('flow.overlappedPercent')}
></SliderInputFormField>
<SliderInputFormField
name="image_table_context_window"
max={256}
min={0}
label={t('knowledgeConfiguration.imageTableContextWindow')}
tooltip={t('knowledgeConfiguration.imageTableContextWindowTip')}
></SliderInputFormField>
<section>
<span className="mb-2 inline-block">{t('flow.delimiters')}</span>
<div className="space-y-4">

View File

@ -466,7 +466,7 @@ const useGraphStore = create<RFState>()(
}
},
updateSwitchFormData: (source, sourceHandle, target, isConnecting) => {
const { updateNodeForm, edges } = get();
const { updateNodeForm, edges, getOperatorTypeFromId } = get();
if (sourceHandle) {
// A handle will connect to multiple downstream nodes
let currentHandleTargets = edges
@ -474,7 +474,8 @@ const useGraphStore = create<RFState>()(
(x) =>
x.source === source &&
x.sourceHandle === sourceHandle &&
typeof x.target === 'string',
typeof x.target === 'string' &&
getOperatorTypeFromId(x.target) !== Operator.Placeholder,
)
.map((x) => x.target);

View File

@ -289,10 +289,14 @@ function transformParserParams(params: ParserFormSchemaType) {
}
function transformSplitterParams(params: SplitterFormSchemaType) {
const { image_table_context_window, ...rest } = params;
const imageTableContextWindow = Number(image_table_context_window || 0);
return {
...params,
...rest,
overlapped_percent: Number(params.overlapped_percent) / 100,
delimiters: transformObjectArrayToPureArray(params.delimiters, 'value'),
table_context_size: imageTableContextWindow,
image_context_size: imageTableContextWindow,
// Unset children delimiters if this option is not enabled
children_delimiters: params.enable_children

View File

@ -1,11 +1,15 @@
import message from '@/components/ui/message';
import { useSetModalState } from '@/hooks/common-hooks';
import { useSetDocumentMeta } from '@/hooks/use-document-request';
import {
DocumentApiAction,
useSetDocumentMeta,
} from '@/hooks/use-document-request';
import kbService, {
getMetaDataService,
updateMetaData,
} from '@/services/knowledge-service';
import { useQuery } from '@tanstack/react-query';
import { useQuery, useQueryClient } from '@tanstack/react-query';
import { TFunction } from 'i18next';
import { useCallback, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useParams } from 'umi';
@ -16,12 +20,43 @@ import {
IMetaDataTableData,
MetadataOperations,
ShowManageMetadataModalProps,
} from './interface';
} from '../interface';
export enum MetadataType {
Manage = 1,
UpdateSingle = 2,
Setting = 3,
SingleFileSetting = 4,
}
export const MetadataDeleteMap = (
t: TFunction<'translation', undefined>,
): Record<
MetadataType,
{ title: string; warnFieldText: string; warnValueText: string }
> => {
return {
[MetadataType.Manage]: {
title: t('common.delete') + ' ' + t('knowledgeDetails.metadata.metadata'),
warnFieldText: t('knowledgeDetails.metadata.deleteManageFieldAllWarn'),
warnValueText: t('knowledgeDetails.metadata.deleteManageValueAllWarn'),
},
[MetadataType.Setting]: {
title: t('common.delete') + ' ' + t('knowledgeDetails.metadata.metadata'),
warnFieldText: t('knowledgeDetails.metadata.deleteSettingFieldWarn'),
warnValueText: t('knowledgeDetails.metadata.deleteSettingValueWarn'),
},
[MetadataType.UpdateSingle]: {
title: t('common.delete') + ' ' + t('knowledgeDetails.metadata.metadata'),
warnFieldText: t('knowledgeDetails.metadata.deleteManageFieldSingleWarn'),
warnValueText: t('knowledgeDetails.metadata.deleteManageValueSingleWarn'),
},
[MetadataType.SingleFileSetting]: {
title: t('common.delete') + ' ' + t('knowledgeDetails.metadata.metadata'),
warnFieldText: t('knowledgeDetails.metadata.deleteSettingFieldWarn'),
warnValueText: t('knowledgeDetails.metadata.deleteSettingValueWarn'),
},
};
};
export const util = {
changeToMetaDataTableData(data: IMetaDataReturnType): IMetaDataTableData[] {
return Object.entries(data).map(([key, value]) => {
@ -39,10 +74,21 @@ export const util = {
data: Record<string, string | string[]>,
): IMetaDataTableData[] {
return Object.entries(data).map(([key, value]) => {
let thisValue = [] as string[];
if (value && Array.isArray(value)) {
thisValue = value;
} else if (value && typeof value === 'string') {
thisValue = [value];
} else if (value && typeof value === 'object') {
thisValue = [JSON.stringify(value)];
} else if (value) {
thisValue = [value.toString()];
}
return {
field: key,
description: '',
values: value,
values: thisValue,
} as IMetaDataTableData;
});
},
@ -100,12 +146,42 @@ export const useMetadataOperations = () => {
}));
}, []);
// const addUpdateValue = useCallback(
// (key: string, value: string | string[]) => {
// setOperations((prev) => ({
// ...prev,
// updates: [...prev.updates, { key, value }],
// }));
// },
// [],
// );
const addUpdateValue = useCallback(
(key: string, value: string | string[]) => {
setOperations((prev) => ({
...prev,
updates: [...prev.updates, { key, value }],
}));
(key: string, originalValue: string, newValue: string) => {
setOperations((prev) => {
const existsIndex = prev.updates.findIndex(
(update) => update.key === key && update.match === originalValue,
);
if (existsIndex > -1) {
const updatedUpdates = [...prev.updates];
updatedUpdates[existsIndex] = {
key,
match: originalValue,
value: newValue,
};
return {
...prev,
updates: updatedUpdates,
};
}
return {
...prev,
updates: [
...prev.updates,
{ key, match: originalValue, value: newValue },
],
};
});
},
[],
);
@ -191,9 +267,14 @@ export const useManageMetaDataModal = (
const { data, loading } = useFetchMetaDataManageData(type);
const [tableData, setTableData] = useState<IMetaDataTableData[]>(metaData);
const { operations, addDeleteRow, addDeleteValue, addUpdateValue } =
useMetadataOperations();
const queryClient = useQueryClient();
const {
operations,
addDeleteRow,
addDeleteValue,
addUpdateValue,
resetOperations,
} = useMetadataOperations();
const { setDocumentMeta } = useSetDocumentMeta();
@ -259,11 +340,15 @@ export const useManageMetaDataModal = (
data: operations,
});
if (res.code === 0) {
queryClient.invalidateQueries({
queryKey: [DocumentApiAction.FetchDocumentList],
});
resetOperations();
message.success(t('message.operated'));
callback();
}
},
[operations, id, t],
[operations, id, t, queryClient, resetOperations],
);
const handleSaveUpdateSingle = useCallback(
@ -297,7 +382,26 @@ export const useManageMetaDataModal = (
return data;
},
[tableData, id],
[tableData, id, t],
);
const handleSaveSingleFileSettings = useCallback(
async (callback: () => void) => {
const data = util.tableDataToMetaDataSettingJSON(tableData);
if (otherData?.documentId) {
const { data: res } = await kbService.documentUpdateMetaData({
doc_id: otherData.documentId,
metadata: data,
});
if (res.code === 0) {
message.success(t('message.operated'));
callback?.();
}
}
return data;
},
[tableData, t, otherData],
);
const handleSave = useCallback(
@ -311,12 +415,20 @@ export const useManageMetaDataModal = (
break;
case MetadataType.Setting:
return handleSaveSettings(callback);
case MetadataType.SingleFileSetting:
return handleSaveSingleFileSettings(callback);
default:
handleSaveManage(callback);
break;
}
},
[handleSaveManage, type, handleSaveUpdateSingle, handleSaveSettings],
[
handleSaveManage,
type,
handleSaveUpdateSingle,
handleSaveSettings,
handleSaveSingleFileSettings,
],
);
return {
@ -371,11 +483,3 @@ export const useManageMetadata = () => {
config,
};
};
export const useManageValues = () => {
const [updateValues, setUpdateValues] = useState<{
field: string;
values: string[];
} | null>(null);
return { updateValues, setUpdateValues };
};

View File

@ -0,0 +1,208 @@
import { useCallback, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { MetadataDeleteMap, MetadataType } from '../hooks/use-manage-modal';
import { IManageValuesProps, IMetaDataTableData } from '../interface';
export const useManageValues = (props: IManageValuesProps) => {
const {
data,
isShowValueSwitch,
hideModal,
onSave,
addUpdateValue,
addDeleteValue,
existsKeys,
type,
} = props;
const { t } = useTranslation();
const [metaData, setMetaData] = useState(data);
const [valueError, setValueError] = useState<Record<string, string>>({
field: '',
values: '',
});
const [deleteDialogContent, setDeleteDialogContent] = useState({
visible: false,
title: '',
name: '',
warnText: '',
onOk: () => {},
onCancel: () => {},
});
const hideDeleteModal = () => {
setDeleteDialogContent({
visible: false,
title: '',
name: '',
warnText: '',
onOk: () => {},
onCancel: () => {},
});
};
// Use functional update to avoid closure issues
const handleChange = useCallback(
(field: string, value: any) => {
if (field === 'field' && existsKeys.includes(value)) {
setValueError((prev) => {
return {
...prev,
field:
type === MetadataType.Setting
? t('knowledgeDetails.metadata.fieldExists')
: t('knowledgeDetails.metadata.fieldNameExists'),
};
});
} else if (field === 'field' && !existsKeys.includes(value)) {
setValueError((prev) => {
return {
...prev,
field: '',
};
});
}
setMetaData((prev) => ({
...prev,
[field]: value,
}));
},
[existsKeys, type, t],
);
// Maintain separate state for each input box
const [tempValues, setTempValues] = useState<string[]>([...data.values]);
useEffect(() => {
setTempValues([...data.values]);
setMetaData(data);
}, [data]);
const handleHideModal = useCallback(() => {
hideModal();
setMetaData({} as IMetaDataTableData);
}, [hideModal]);
const handleSave = useCallback(() => {
if (type === MetadataType.Setting && valueError.field) {
return;
}
if (!metaData.restrictDefinedValues && isShowValueSwitch) {
const newMetaData = { ...metaData, values: [] };
onSave(newMetaData);
} else {
onSave(metaData);
}
handleHideModal();
}, [metaData, onSave, handleHideModal, isShowValueSwitch, type, valueError]);
// Handle value changes, only update temporary state
const handleValueChange = useCallback(
(index: number, value: string) => {
setTempValues((prev) => {
if (prev.includes(value)) {
setValueError((prev) => {
return {
...prev,
values: t('knowledgeDetails.metadata.valueExists'),
};
});
} else {
setValueError((prev) => {
return {
...prev,
values: '',
};
});
}
const newValues = [...prev];
newValues[index] = value;
return newValues;
});
},
[t],
);
// Handle blur event, synchronize to main state
const handleValueBlur = useCallback(() => {
// addUpdateValue(metaData.field, [...new Set([...tempValues])]);
tempValues.forEach((newValue, index) => {
if (index < data.values.length) {
const originalValue = data.values[index];
if (originalValue !== newValue) {
addUpdateValue(metaData.field, originalValue, newValue);
}
} else {
if (newValue) {
addUpdateValue(metaData.field, '', newValue);
}
}
});
handleChange('values', [...new Set([...tempValues])]);
}, [handleChange, tempValues, metaData, data, addUpdateValue]);
// Handle delete operation
const handleDelete = useCallback(
(index: number) => {
setTempValues((prev) => {
const newTempValues = [...prev];
addDeleteValue(metaData.field, newTempValues[index]);
newTempValues.splice(index, 1);
return newTempValues;
});
// Synchronize to main state
setMetaData((prev) => {
const newMetaDataValues = [...prev.values];
newMetaDataValues.splice(index, 1);
return {
...prev,
values: newMetaDataValues,
};
});
},
[addDeleteValue, metaData],
);
const showDeleteModal = (item: string, callback: () => void) => {
setDeleteDialogContent({
visible: true,
title: t('common.delete') + ' ' + t('knowledgeDetails.metadata.value'),
name: item,
warnText: MetadataDeleteMap(t)[type as MetadataType].warnValueText,
onOk: () => {
hideDeleteModal();
callback();
},
onCancel: () => {
hideDeleteModal();
},
});
};
// Handle adding new value
const handleAddValue = useCallback(() => {
setTempValues((prev) => [...new Set([...prev, ''])]);
// Synchronize to main state
setMetaData((prev) => ({
...prev,
values: [...new Set([...prev.values, ''])],
}));
}, []);
return {
metaData,
tempValues,
valueError,
deleteDialogContent,
handleChange,
handleValueChange,
handleValueBlur,
handleDelete,
handleAddValue,
showDeleteModal,
handleSave,
handleHideModal,
};
};

View File

@ -39,6 +39,7 @@ export type IManageModalProps = {
export interface IManageValuesProps {
title: ReactNode;
existsKeys: string[];
visible: boolean;
isEditField?: boolean;
isAddValue?: boolean;
@ -46,9 +47,14 @@ export interface IManageValuesProps {
isShowValueSwitch?: boolean;
isVerticalShowValue?: boolean;
data: IMetaDataTableData;
type: MetadataType;
hideModal: () => void;
onSave: (data: IMetaDataTableData) => void;
addUpdateValue: (key: string, value: string | string[]) => void;
addUpdateValue: (
key: string,
originalValue: string,
newValue: string,
) => void;
addDeleteValue: (key: string, value: string) => void;
}
@ -59,7 +65,8 @@ interface DeleteOperation {
interface UpdateOperation {
key: string;
value: string | string[];
match: string;
value: string;
}
export interface MetadataOperations {

View File

@ -25,11 +25,16 @@ import {
useReactTable,
} from '@tanstack/react-table';
import { Plus, Settings, Trash2 } from 'lucide-react';
import { useCallback, useMemo, useState } from 'react';
import { useCallback, useEffect, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { MetadataType, useManageMetaDataModal } from './hook';
import {
MetadataDeleteMap,
MetadataType,
useManageMetaDataModal,
} from './hooks/use-manage-modal';
import { IManageModalProps, IMetaDataTableData } from './interface';
import { ManageValuesModal } from './manage-values-modal';
export const ManageMetadataModal = (props: IManageModalProps) => {
const {
title,
@ -54,6 +59,7 @@ export const ManageMetadataModal = (props: IManageModalProps) => {
values: [],
});
const [currentValueIndex, setCurrentValueIndex] = useState<number>(0);
const [deleteDialogContent, setDeleteDialogContent] = useState({
visible: false,
title: '',
@ -94,12 +100,12 @@ export const ManageMetadataModal = (props: IManageModalProps) => {
description: '',
values: [],
});
// setCurrentValueIndex(tableData.length || 0);
setCurrentValueIndex(tableData.length || 0);
showManageValuesModal();
};
const handleEditValueRow = useCallback(
(data: IMetaDataTableData) => {
// setCurrentValueIndex(index);
(data: IMetaDataTableData, index: number) => {
setCurrentValueIndex(index);
setValueData(data);
showManageValuesModal();
},
@ -133,7 +139,8 @@ export const ManageMetadataModal = (props: IManageModalProps) => {
const values = row.getValue('values') as Array<string>;
return (
<div className="flex items-center gap-1">
{values.length > 0 &&
{Array.isArray(values) &&
values.length > 0 &&
values
.filter((value: string, index: number) => index < 2)
?.map((value: string) => {
@ -153,10 +160,28 @@ export const ManageMetadataModal = (props: IManageModalProps) => {
variant={'delete'}
className="p-0 bg-transparent"
onClick={() => {
handleDeleteSingleValue(
row.getValue('field'),
value,
);
setDeleteDialogContent({
visible: true,
title:
t('common.delete') +
' ' +
t('knowledgeDetails.metadata.value'),
name: value,
warnText:
MetadataDeleteMap(t)[
metadataType as MetadataType
].warnValueText,
onOk: () => {
hideDeleteModal();
handleDeleteSingleValue(
row.getValue('field'),
value,
);
},
onCancel: () => {
hideDeleteModal();
},
});
}}
>
<Trash2 />
@ -166,7 +191,7 @@ export const ManageMetadataModal = (props: IManageModalProps) => {
</Button>
);
})}
{values.length > 2 && (
{Array.isArray(values) && values.length > 2 && (
<div className="text-text-secondary self-end">...</div>
)}
</div>
@ -185,7 +210,7 @@ export const ManageMetadataModal = (props: IManageModalProps) => {
variant={'ghost'}
className="bg-transparent px-1 py-0"
onClick={() => {
handleEditValueRow(row.original);
handleEditValueRow(row.original, row.index);
}}
>
<Settings />
@ -197,11 +222,14 @@ export const ManageMetadataModal = (props: IManageModalProps) => {
setDeleteDialogContent({
visible: true,
title:
t('common.delete') +
' ' +
t('knowledgeDetails.metadata.metadata'),
// t('common.delete') +
// ' ' +
// t('knowledgeDetails.metadata.metadata')
MetadataDeleteMap(t)[metadataType as MetadataType].title,
name: row.getValue('field'),
warnText: t('knowledgeDetails.metadata.deleteWarn'),
warnText:
MetadataDeleteMap(t)[metadataType as MetadataType]
.warnFieldText,
onOk: () => {
hideDeleteModal();
handleDeleteSingleRow(row.getValue('field'));
@ -240,15 +268,29 @@ export const ManageMetadataModal = (props: IManageModalProps) => {
getFilteredRowModel: getFilteredRowModel(),
manualPagination: true,
});
const [shouldSave, setShouldSave] = useState(false);
const handleSaveValues = (data: IMetaDataTableData) => {
setTableData((prev) => {
//If the keys are the same, they need to be merged.
const fieldMap = new Map<string, any>();
let newData;
if (currentValueIndex >= prev.length) {
// Add operation
newData = [...prev, data];
} else {
// Edit operation
newData = prev.map((item, index) => {
if (index === currentValueIndex) {
return data;
}
return item;
});
}
prev.forEach((item) => {
// Deduplicate by field and merge values
const fieldMap = new Map<string, IMetaDataTableData>();
newData.forEach((item) => {
if (fieldMap.has(item.field)) {
const existingItem = fieldMap.get(item.field);
// Merge values if field exists
const existingItem = fieldMap.get(item.field)!;
const mergedValues = [
...new Set([...existingItem.values, ...item.values]),
];
@ -258,20 +300,26 @@ export const ManageMetadataModal = (props: IManageModalProps) => {
}
});
if (fieldMap.has(data.field)) {
const existingItem = fieldMap.get(data.field);
const mergedValues = [
...new Set([...existingItem.values, ...data.values]),
];
fieldMap.set(data.field, { ...existingItem, values: mergedValues });
} else {
fieldMap.set(data.field, data);
}
return Array.from(fieldMap.values());
});
setShouldSave(true);
};
useEffect(() => {
if (shouldSave) {
const timer = setTimeout(() => {
handleSave({ callback: () => {} });
setShouldSave(false);
}, 0);
return () => clearTimeout(timer);
}
}, [tableData, shouldSave, handleSave]);
const existsKeys = useMemo(() => {
return tableData.map((item) => item.field);
}, [tableData]);
return (
<>
<Modal
@ -352,11 +400,14 @@ export const ManageMetadataModal = (props: IManageModalProps) => {
<ManageValuesModal
title={
<div>
{metadataType === MetadataType.Setting
{metadataType === MetadataType.Setting ||
metadataType === MetadataType.SingleFileSetting
? t('knowledgeDetails.metadata.fieldSetting')
: t('knowledgeDetails.metadata.editMetadata')}
</div>
}
type={metadataType}
existsKeys={existsKeys}
visible={manageValuesVisible}
hideModal={hideManageValuesModal}
data={valueData}

View File

@ -1,13 +1,19 @@
import {
ConfirmDeleteDialog,
ConfirmDeleteDialogNode,
} from '@/components/confirm-delete-dialog';
import EditTag from '@/components/edit-tag';
import { Button } from '@/components/ui/button';
import { FormLabel } from '@/components/ui/form';
import { Input } from '@/components/ui/input';
import { Modal } from '@/components/ui/modal/modal';
import { Switch } from '@/components/ui/switch';
import { Textarea } from '@/components/ui/textarea';
import { Plus, Trash2 } from 'lucide-react';
import { memo, useCallback, useEffect, useState } from 'react';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { IManageValuesProps, IMetaDataTableData } from './interface';
import { useManageValues } from './hooks/use-manage-values-modal';
import { IManageValuesProps } from './interface';
// Create a separate input component, wrapped with memo to avoid unnecessary re-renders
const ValueInputItem = memo(
@ -52,102 +58,29 @@ const ValueInputItem = memo(
export const ManageValuesModal = (props: IManageValuesProps) => {
const {
title,
data,
isEditField,
visible,
isAddValue,
isShowDescription,
isShowValueSwitch,
isVerticalShowValue,
hideModal,
onSave,
addUpdateValue,
addDeleteValue,
} = props;
const [metaData, setMetaData] = useState(data);
const {
metaData,
tempValues,
valueError,
deleteDialogContent,
handleChange,
handleValueChange,
handleValueBlur,
handleDelete,
handleAddValue,
showDeleteModal,
handleSave,
handleHideModal,
} = useManageValues(props);
const { t } = useTranslation();
// Use functional update to avoid closure issues
const handleChange = useCallback((field: string, value: any) => {
setMetaData((prev) => ({
...prev,
[field]: value,
}));
}, []);
// Maintain separate state for each input box
const [tempValues, setTempValues] = useState<string[]>([...data.values]);
useEffect(() => {
setTempValues([...data.values]);
setMetaData(data);
}, [data]);
const handleHideModal = useCallback(() => {
hideModal();
setMetaData({} as IMetaDataTableData);
}, [hideModal]);
const handleSave = useCallback(() => {
if (!metaData.restrictDefinedValues && isShowValueSwitch) {
const newMetaData = { ...metaData, values: [] };
onSave(newMetaData);
} else {
onSave(metaData);
}
handleHideModal();
}, [metaData, onSave, handleHideModal, isShowValueSwitch]);
// Handle value changes, only update temporary state
const handleValueChange = useCallback((index: number, value: string) => {
setTempValues((prev) => {
const newValues = [...prev];
newValues[index] = value;
return newValues;
});
}, []);
// Handle blur event, synchronize to main state
const handleValueBlur = useCallback(() => {
addUpdateValue(metaData.field, [...new Set([...tempValues])]);
handleChange('values', [...new Set([...tempValues])]);
}, [handleChange, tempValues, metaData, addUpdateValue]);
// Handle delete operation
const handleDelete = useCallback(
(index: number) => {
setTempValues((prev) => {
const newTempValues = [...prev];
addDeleteValue(metaData.field, newTempValues[index]);
newTempValues.splice(index, 1);
return newTempValues;
});
// Synchronize to main state
setMetaData((prev) => {
const newMetaDataValues = [...prev.values];
newMetaDataValues.splice(index, 1);
return {
...prev,
values: newMetaDataValues,
};
});
},
[addDeleteValue, metaData],
);
// Handle adding new value
const handleAddValue = useCallback(() => {
setTempValues((prev) => [...new Set([...prev, ''])]);
// Synchronize to main state
setMetaData((prev) => ({
...prev,
values: [...new Set([...prev.values, ''])],
}));
}, []);
return (
<Modal
title={title}
@ -172,15 +105,24 @@ export const ManageValuesModal = (props: IManageValuesProps) => {
<Input
value={metaData.field}
onChange={(e) => {
handleChange('field', e.target?.value || '');
const value = e.target?.value || '';
if (/^[a-zA-Z_]*$/.test(value)) {
handleChange('field', value);
}
}}
/>
<div className="text-state-error text-sm">{valueError.field}</div>
</div>
</div>
)}
{isShowDescription && (
<div className="flex flex-col gap-2">
<div>{t('knowledgeDetails.metadata.description')}</div>
<FormLabel
className="text-text-primary text-base"
tooltip={t('knowledgeDetails.metadata.descriptionTip')}
>
{t('knowledgeDetails.metadata.description')}
</FormLabel>
<div>
<Textarea
value={metaData.description}
@ -193,7 +135,12 @@ export const ManageValuesModal = (props: IManageValuesProps) => {
)}
{isShowValueSwitch && (
<div className="flex flex-col gap-2">
<div>{t('knowledgeDetails.metadata.restrictDefinedValues')}</div>
<FormLabel
className="text-text-primary text-base"
tooltip={t('knowledgeDetails.metadata.restrictTDefinedValuesTip')}
>
{t('knowledgeDetails.metadata.restrictDefinedValues')}
</FormLabel>
<div>
<Switch
checked={metaData.restrictDefinedValues || false}
@ -230,7 +177,11 @@ export const ManageValuesModal = (props: IManageValuesProps) => {
item={item}
index={index}
onValueChange={handleValueChange}
onDelete={handleDelete}
onDelete={(idx: number) => {
showDeleteModal(item, () => {
handleDelete(idx);
});
}}
onBlur={handleValueBlur}
/>
);
@ -240,11 +191,41 @@ export const ManageValuesModal = (props: IManageValuesProps) => {
{!isVerticalShowValue && (
<EditTag
value={metaData.values}
onChange={(value) => handleChange('values', value)}
onChange={(value) => {
// find deleted value
const item = metaData.values.find(
(item) => !value.includes(item),
);
if (item) {
showDeleteModal(item, () => {
// handleDelete(idx);
handleChange('values', value);
});
} else {
handleChange('values', value);
}
}}
/>
)}
<div className="text-state-error text-sm">{valueError.values}</div>
</div>
)}
{deleteDialogContent.visible && (
<ConfirmDeleteDialog
open={deleteDialogContent.visible}
onCancel={deleteDialogContent.onCancel}
onOk={deleteDialogContent.onOk}
title={deleteDialogContent.title}
content={{
node: (
<ConfirmDeleteDialogNode
name={deleteDialogContent.name}
warnText={deleteDialogContent.warnText}
/>
),
}}
/>
)}
</div>
</Modal>
);

View File

@ -35,7 +35,8 @@ import {
MetadataType,
useManageMetadata,
util,
} from '../../components/metedata/hook';
} from '../../components/metedata/hooks/use-manage-modal';
import { IMetaDataReturnJSONSettings } from '../../components/metedata/interface';
import { ManageMetadataModal } from '../../components/metedata/manage-modal';
import {
useHandleKbEmbedding,
@ -359,7 +360,13 @@ export function OverlappedPercent() {
);
}
export function AutoMetadata() {
export function AutoMetadata({
type = MetadataType.Setting,
otherData,
}: {
type?: MetadataType;
otherData?: Record<string, any>;
}) {
// get metadata field
const form = useFormContext();
const {
@ -369,6 +376,7 @@ export function AutoMetadata() {
tableData,
config: metadataConfig,
} = useManageMetadata();
const autoMetadataField: FormFieldConfig = {
name: 'parser_config.enable_metadata',
label: t('knowledgeConfiguration.autoMetadata'),
@ -379,6 +387,7 @@ export function AutoMetadata() {
render: (fieldProps: ControllerRenderProps) => (
<div className="flex items-center justify-between">
<Button
type="button"
variant="ghost"
onClick={() => {
const metadata = form.getValues('parser_config.metadata');
@ -387,7 +396,8 @@ export function AutoMetadata() {
showManageMetadataModal({
metadata: tableMetaData,
isCanAdd: true,
type: MetadataType.Setting,
type: type,
record: otherData,
});
}}
>
@ -403,6 +413,10 @@ export function AutoMetadata() {
</div>
),
};
const handleSaveMetadata = (data?: IMetaDataReturnJSONSettings) => {
form.setValue('parser_config.metadata', data || []);
};
return (
<>
<RenderField field={autoMetadataField} />
@ -431,8 +445,8 @@ export function AutoMetadata() {
isShowDescription={true}
isShowValueSwitch={true}
isVerticalShowValue={false}
success={(data) => {
form.setValue('parser_config.metadata', data || []);
success={(data?: IMetaDataReturnJSONSettings) => {
handleSaveMetadata(data);
}}
/>
)}

View File

@ -96,7 +96,7 @@ export const formSchema = z
)
.optional(),
enable_metadata: z.boolean().optional(),
llm_id: z.string().optional(),
llm_id: z.string().min(1, { message: 'Indexing model is required' }),
})
.optional(),
pagerank: z.number(),

View File

@ -16,7 +16,10 @@ import { useFetchKnowledgeBaseConfiguration } from '@/hooks/use-knowledge-reques
import { Pen, Upload } from 'lucide-react';
import { useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { MetadataType, useManageMetadata } from '../components/metedata/hook';
import {
MetadataType,
useManageMetadata,
} from '../components/metedata/hooks/use-manage-modal';
import { ManageMetadataModal } from '../components/metedata/manage-modal';
import { DatasetTable } from './dataset-table';
import Generate from './generate-button/generate';

View File

@ -16,7 +16,10 @@ import { formatDate } from '@/utils/date';
import { ColumnDef } from '@tanstack/table-core';
import { ArrowUpDown, MonitorUp } from 'lucide-react';
import { useTranslation } from 'react-i18next';
import { MetadataType, util } from '../components/metedata/hook';
import {
MetadataType,
util,
} from '../components/metedata/hooks/use-manage-modal';
import { ShowManageMetadataModalProps } from '../components/metedata/interface';
import { DatasetActionCell } from './dataset-action-cell';
import { ParsingStatusCell } from './parsing-status-cell';

View File

@ -1,8 +1,13 @@
import { FilterCollection } from '@/components/list-filter-bar/interface';
import {
FilterCollection,
FilterType,
} from '@/components/list-filter-bar/interface';
import { useTranslate } from '@/hooks/common-hooks';
import { useGetDocumentFilter } from '@/hooks/use-document-request';
import { useMemo } from 'react';
export const EMPTY_METADATA_FIELD = 'empty_metadata';
export function useSelectDatasetFilters() {
const { t } = useTranslate('knowledgeDetails');
const { filter, onOpenChange } = useGetDocumentFilter();
@ -17,34 +22,52 @@ export function useSelectDatasetFilters() {
}
}, [filter.suffix]);
const fileStatus = useMemo(() => {
let list = [] as FilterType[];
if (filter.run_status) {
return Object.keys(filter.run_status).map((x) => ({
list = Object.keys(filter.run_status).map((x) => ({
id: x,
label: t(`runningStatus${x}`),
count: filter.run_status[x as unknown as number],
}));
}
}, [filter.run_status, t]);
if (filter.metadata) {
const emptyMetadata = filter.metadata?.empty_metadata;
if (emptyMetadata) {
list.push({
id: EMPTY_METADATA_FIELD,
label: t('emptyMetadata'),
count: emptyMetadata.true,
});
}
}
return list;
}, [filter.run_status, filter.metadata, t]);
const metaDataList = useMemo(() => {
if (filter.metadata) {
return Object.keys(filter.metadata).map((x) => ({
id: x.toString(),
field: x.toString(),
label: x.toString(),
list: Object.keys(filter.metadata[x]).map((y) => ({
id: y.toString(),
field: y.toString(),
label: y.toString(),
value: [y],
count: filter.metadata[x][y],
})),
count: Object.keys(filter.metadata[x]).reduce(
(acc, cur) => acc + filter.metadata[x][cur],
0,
),
}));
const list = Object.keys(filter.metadata)
?.filter((m) => m !== EMPTY_METADATA_FIELD)
?.map((x) => {
return {
id: x.toString(),
field: x.toString(),
label: x.toString(),
list: Object.keys(filter.metadata[x]).map((y) => ({
id: y.toString(),
field: y.toString(),
label: y.toString(),
value: [y],
count: filter.metadata[x][y],
})),
count: Object.keys(filter.metadata[x]).reduce(
(acc, cur) => acc + filter.metadata[x][cur],
0,
),
};
});
return list;
}
}, [filter.metadata]);
const filters: FilterCollection[] = useMemo(() => {
return [
{ field: 'type', label: 'File Type', list: fileTypes },

View File

@ -1,6 +1,11 @@
import FileStatusBadge from '@/components/file-status-badge';
import { Button } from '@/components/ui/button';
import { Modal } from '@/components/ui/modal/modal';
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from '@/components/ui/tooltip';
import { RunningStatusMap } from '@/constants/knowledge';
import { useTranslate } from '@/hooks/common-hooks';
import React, { useMemo } from 'react';
@ -40,7 +45,14 @@ const InfoItem: React.FC<{
return (
<div className={`flex flex-col mb-4 ${className}`}>
<span className="text-text-secondary text-sm">{label}</span>
<span className="text-text-primary mt-1">{value}</span>
<Tooltip>
<TooltipTrigger asChild>
<span className="text-text-primary mt-1 truncate max-w-[200px]">
{value}
</span>
</TooltipTrigger>
<TooltipContent>{value}</TooltipContent>
</Tooltip>
</div>
);
};
@ -70,9 +82,7 @@ const ProcessLogModal: React.FC<ProcessLogModalProps> = ({
}) => {
const { t } = useTranslate('knowledgeDetails');
const blackKeyList = [''];
console.log('logInfo', initData);
const logInfo = useMemo(() => {
console.log('logInfo', initData);
return initData;
}, [initData]);

View File

@ -17,6 +17,7 @@ import {
} from '@/components/ui/form';
import { Input } from '@/components/ui/input';
import { FormLayout } from '@/constants/form';
import { useFetchTenantInfo } from '@/hooks/use-user-setting-request';
import { IModalProps } from '@/interfaces/common';
import { zodResolver } from '@hookform/resolvers/zod';
import { useEffect } from 'react';
@ -33,6 +34,7 @@ const FormId = 'dataset-creating-form';
export function InputForm({ onOk }: IModalProps<any>) {
const { t } = useTranslation();
const { data: tenantInfo } = useFetchTenantInfo();
const FormSchema = z
.object({
@ -80,7 +82,7 @@ export function InputForm({ onOk }: IModalProps<any>) {
name: '',
parseType: 1,
parser_id: '',
embd_id: '',
embd_id: tenantInfo?.embd_id,
},
});

View File

@ -48,10 +48,10 @@ const {
traceRaptor,
check_embedding,
kbUpdateMetaData,
documentUpdateMetaData,
} = api;
const methods = {
// 知识库管理
createKb: {
url: create_kb,
method: 'post',
@ -220,6 +220,10 @@ const methods = {
url: kbUpdateMetaData,
method: 'post',
},
documentUpdateMetaData: {
url: documentUpdateMetaData,
method: 'post',
},
// getMetaData: {
// url: getMetaData,
// method: 'get',
@ -263,7 +267,7 @@ export const documentFilter = (kb_id: string) =>
export const getMetaDataService = ({ kb_id }: { kb_id: string }) =>
request.post(api.getMetaData, { data: { kb_id } });
export const updateMetaData = ({ kb_id, data }: { kb_id: string; data: any }) =>
request.post(api.updateMetaData, { data: { kb_id, data } });
request.post(api.updateMetaData, { data: { kb_id, ...data } });
export const listDataPipelineLogDocument = (
params?: IFetchKnowledgeListRequestParams,

View File

@ -80,6 +80,7 @@ export default {
getMetaData: `${api_host}/document/metadata/summary`,
updateMetaData: `${api_host}/document/metadata/update`,
kbUpdateMetaData: `${api_host}/kb/update_metadata_setting`,
documentUpdateMetaData: `${api_host}/document/update_metadata_setting`,
// tags
listTag: (knowledgeId: string) => `${api_host}/kb/${knowledgeId}/tags`,