diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b65c0cb95..b00e75ef1 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -219,7 +219,7 @@ jobs: sleep 5 done source .venv/bin/activate && pytest -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py 2>&1 | tee es_api_test.log - + - name: Run http api tests against Elasticsearch run: | export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY="" diff --git a/Dockerfile b/Dockerfile index f44b643bf..5f2c5f6cf 100644 --- a/Dockerfile +++ b/Dockerfile @@ -192,6 +192,7 @@ COPY pyproject.toml uv.lock ./ COPY mcp mcp COPY plugin plugin COPY common common +COPY memory memory COPY docker/service_conf.yaml.template ./conf/service_conf.yaml.template COPY docker/entrypoint.sh ./ diff --git a/admin/server/routes.py b/admin/server/routes.py index 29a4bbd22..e83f3ff08 100644 --- a/admin/server/routes.py +++ b/admin/server/routes.py @@ -29,6 +29,11 @@ from common.versions import get_ragflow_version admin_bp = Blueprint('admin', __name__, url_prefix='/api/v1/admin') +@admin_bp.route('/ping', methods=['GET']) +def ping(): + return success_response('PONG') + + @admin_bp.route('/login', methods=['POST']) def login(): if not request.json: diff --git a/agent/tools/retrieval.py b/agent/tools/retrieval.py index 4dcdba809..08fab6ed5 100644 --- a/agent/tools/retrieval.py +++ b/agent/tools/retrieval.py @@ -25,10 +25,12 @@ from api.db.services.document_service import DocumentService from common.metadata_utils import apply_meta_data_filter from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle +from api.db.services.memory_service import MemoryService +from api.db.joint_services.memory_message_service import query_message from common import settings from common.connection_utils import timeout from rag.app.tag import label_question -from rag.prompts.generator import cross_languages, kb_prompt +from rag.prompts.generator import cross_languages, kb_prompt, memory_prompt class RetrievalParam(ToolParamBase): @@ -57,6 +59,7 @@ class RetrievalParam(ToolParamBase): self.top_n = 8 self.top_k = 1024 self.kb_ids = [] + self.memory_ids = [] self.kb_vars = [] self.rerank_id = "" self.empty_response = "" @@ -81,15 +84,7 @@ class RetrievalParam(ToolParamBase): class Retrieval(ToolBase, ABC): component_name = "Retrieval" - @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))) - async def _invoke_async(self, **kwargs): - if self.check_if_canceled("Retrieval processing"): - return - - if not kwargs.get("query"): - self.set_output("formalized_content", self._param.empty_response) - return - + async def _retrieve_kb(self, query_text: str): kb_ids: list[str] = [] for id in self._param.kb_ids: if id.find("@") < 0: @@ -124,12 +119,12 @@ class Retrieval(ToolBase, ABC): if self._param.rerank_id: rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id) - vars = self.get_input_elements_from_text(kwargs["query"]) - vars = {k:o["value"] for k,o in vars.items()} - query = self.string_format(kwargs["query"], vars) + vars = self.get_input_elements_from_text(query_text) + vars = {k: o["value"] for k, o in vars.items()} + query = self.string_format(query_text, vars) - doc_ids=[] - if self._param.meta_data_filter!={}: + doc_ids = [] + if self._param.meta_data_filter != {}: metas = DocumentService.get_meta_by_kbs(kb_ids) def 
_resolve_manual_filter(flt: dict) -> dict: @@ -198,18 +193,20 @@ class Retrieval(ToolBase, ABC): if self._param.toc_enhance: chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT) - cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n) + cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], + chat_mdl, self._param.top_n) if self.check_if_canceled("Retrieval processing"): return if cks: kbinfos["chunks"] = cks - kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"], [kb.tenant_id for kb in kbs]) + kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"], + [kb.tenant_id for kb in kbs]) if self._param.use_kg: ck = settings.kg_retriever.retrieval(query, - [kb.tenant_id for kb in kbs], - kb_ids, - embd_mdl, - LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT)) + [kb.tenant_id for kb in kbs], + kb_ids, + embd_mdl, + LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT)) if self.check_if_canceled("Retrieval processing"): return if ck["content_with_weight"]: @@ -218,7 +215,8 @@ class Retrieval(ToolBase, ABC): kbinfos = {"chunks": [], "doc_aggs": []} if self._param.use_kg and kbs: - ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, LLMType.CHAT)) + ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, + LLMBundle(kbs[0].tenant_id, LLMType.CHAT)) if self.check_if_canceled("Retrieval processing"): return if ck["content_with_weight"]: @@ -248,6 +246,54 @@ class Retrieval(ToolBase, ABC): return form_cnt + async def _retrieve_memory(self, query_text: str): + memory_ids: list[str] = [memory_id for memory_id in self._param.memory_ids] + memory_list = MemoryService.get_by_ids(memory_ids) + if not memory_list: + raise Exception("No memory is selected.") + + embd_names = list({memory.embd_id for memory in memory_list}) + assert len(embd_names) == 1, "Memory use different embedding models." 
+ + vars = self.get_input_elements_from_text(query_text) + vars = {k: o["value"] for k, o in vars.items()} + query = self.string_format(query_text, vars) + # query message + message_list = query_message({"memory_id": memory_ids}, { + "query": query, + "similarity_threshold": self._param.similarity_threshold, + "keywords_similarity_weight": self._param.keywords_similarity_weight, + "top_n": self._param.top_n + }) + print(f"found {len(message_list)} messages.") + + if not message_list: + self.set_output("formalized_content", self._param.empty_response) + return + formated_content = "\n".join(memory_prompt(message_list, 200000)) + + # set formalized_content output + self.set_output("formalized_content", formated_content) + print(f"formated_content {formated_content}") + return formated_content + + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))) + async def _invoke_async(self, **kwargs): + if self.check_if_canceled("Retrieval processing"): + return + print(f"debug retrieval, query is {kwargs.get('query')}.", flush=True) + if not kwargs.get("query"): + self.set_output("formalized_content", self._param.empty_response) + return + + if self._param.kb_ids: + return await self._retrieve_kb(kwargs["query"]) + elif self._param.memory_ids: + return await self._retrieve_memory(kwargs["query"]) + else: + self.set_output("formalized_content", self._param.empty_response) + return + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))) def _invoke(self, **kwargs): return asyncio.run(self._invoke_async(**kwargs)) diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py index ab6367eb5..1ff1c0d2a 100644 --- a/api/apps/canvas_app.py +++ b/api/apps/canvas_app.py @@ -192,7 +192,7 @@ async def rerun(): if 0 < doc["progress"] < 1: return get_data_error_result(message=f"`{doc['name']}` is processing...") - if settings.docStoreConn.indexExist(search.index_name(current_user.id), doc["kb_id"]): + if settings.docStoreConn.index_exist(search.index_name(current_user.id), doc["kb_id"]): settings.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"]) doc["progress_msg"] = "" doc["chunk_num"] = 0 diff --git a/api/apps/document_app.py b/api/apps/document_app.py index 137ec9ac1..723e909ec 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -564,7 +564,7 @@ async def run(): DocumentService.update_by_id(id, info) if req.get("delete", False): TaskService.filter_delete([Task.doc_id == id]) - if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id): + if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id): settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id) if str(req["run"]) == TaskStatus.RUNNING.value: @@ -615,7 +615,7 @@ async def rename(): "title_tks": title_tks, "title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks), } - if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id): + if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id): settings.docStoreConn.update( {"doc_id": req["doc_id"]}, es_body, @@ -696,7 +696,7 @@ async def change_parser(): tenant_id = DocumentService.get_tenant_id(req["doc_id"]) if not tenant_id: return get_data_error_result(message="Tenant not found!") - if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id): + if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id): settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id) 
return None diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index 5e1d16c77..fff982563 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -39,9 +39,9 @@ from api.utils.api_utils import get_json_result from rag.nlp import search from api.constants import DATASET_NAME_LIMIT from rag.utils.redis_conn import REDIS_CONN -from rag.utils.doc_store_conn import OrderByExpr from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType, PAGERANK_FLD from common import settings +from common.doc_store.doc_store_base import OrderByExpr from api.apps import login_required, current_user @@ -285,7 +285,7 @@ async def rm(): message="Database error (Knowledgebase removal)!") for kb in kbs: settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id) - settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id) + settings.docStoreConn.delete_idx(search.index_name(kb.tenant_id), kb.id) if hasattr(settings.STORAGE_IMPL, 'remove_bucket'): settings.STORAGE_IMPL.remove_bucket(kb.id) return get_json_result(data=True) @@ -386,7 +386,7 @@ def knowledge_graph(kb_id): } obj = {"graph": {}, "mind_map": {}} - if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id): + if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), kb_id): return get_json_result(data=obj) sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id]) if not len(sres.ids): @@ -858,11 +858,11 @@ async def check_embedding(): index_nm = search.index_name(tenant_id) res0 = docStoreConn.search( - selectFields=[], highlightFields=[], + select_fields=[], highlight_fields=[], condition={"kb_id": kb_id, "available_int": 1}, - matchExprs=[], orderBy=OrderByExpr(), + match_expressions=[], order_by=OrderByExpr(), offset=0, limit=1, - indexNames=index_nm, knowledgebaseIds=[kb_id] + index_names=index_nm, knowledgebase_ids=[kb_id] ) total = docStoreConn.get_total(res0) if total <= 0: @@ -874,14 +874,14 @@ async def check_embedding(): for off in offsets: res1 = docStoreConn.search( - selectFields=list(base_fields), - highlightFields=[], + select_fields=list(base_fields), + highlight_fields=[], condition={"kb_id": kb_id, "available_int": 1}, - matchExprs=[], orderBy=OrderByExpr(), + match_expressions=[], order_by=OrderByExpr(), offset=off, limit=1, - indexNames=index_nm, knowledgebaseIds=[kb_id] + index_names=index_nm, knowledgebase_ids=[kb_id] ) - ids = docStoreConn.get_chunk_ids(res1) + ids = docStoreConn.get_doc_ids(res1) if not ids: continue diff --git a/api/apps/memories_app.py b/api/apps/memories_app.py index 9a5cae936..4882b9526 100644 --- a/api/apps/memories_app.py +++ b/api/apps/memories_app.py @@ -20,10 +20,12 @@ from api.apps import login_required, current_user from api.db import TenantPermission from api.db.services.memory_service import MemoryService from api.db.services.user_service import UserTenantService +from api.db.services.canvas_service import UserCanvasService from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result, \ not_allowed_parameters from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT +from memory.services.messages import MessageService from common.constants import MemoryType, RetCode, ForgettingPolicy @@ -57,7 +59,6 @@ async def create_memory(): if res: return get_json_result(message=True, data=format_ret_data_from_memory(memory)) - 
else: return get_json_result(message=memory, code=RetCode.SERVER_ERROR) @@ -124,7 +125,7 @@ async def update_memory(memory_id): return get_json_result(message=True, data=memory_dict) try: - MemoryService.update_memory(memory_id, to_update) + MemoryService.update_memory(current_memory.tenant_id, memory_id, to_update) updated_memory = MemoryService.get_by_memory_id(memory_id) return get_json_result(message=True, data=format_ret_data_from_memory(updated_memory)) @@ -133,7 +134,7 @@ async def update_memory(memory_id): return get_json_result(message=str(e), code=RetCode.SERVER_ERROR) -@manager.route("/", methods=["DELETE"]) # noqa: F821 +@manager.route("/", methods=["DELETE"]) # noqa: F821 @login_required async def delete_memory(memory_id): memory = MemoryService.get_by_memory_id(memory_id) @@ -141,13 +142,14 @@ async def delete_memory(memory_id): return get_json_result(message=True, code=RetCode.NOT_FOUND) try: MemoryService.delete_memory(memory_id) + MessageService.delete_message({"memory_id": memory_id}, memory.tenant_id, memory_id) return get_json_result(message=True) except Exception as e: logging.error(e) return get_json_result(message=str(e), code=RetCode.SERVER_ERROR) -@manager.route("", methods=["GET"]) # noqa: F821 +@manager.route("", methods=["GET"]) # noqa: F821 @login_required async def list_memory(): args = request.args @@ -183,3 +185,26 @@ async def get_memory_config(memory_id): if not memory: return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.") return get_json_result(message=True, data=format_ret_data_from_memory(memory)) + + +@manager.route("/", methods=["GET"]) # noqa: F821 +@login_required +async def get_memory_detail(memory_id): + args = request.args + agent_ids = args.getlist("agent_id") + keywords = args.get("keywords", "") + keywords = keywords.strip() + page = int(args.get("page", 1)) + page_size = int(args.get("page_size", 50)) + memory = MemoryService.get_by_memory_id(memory_id) + if not memory: + return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.") + messages = MessageService.list_message( + memory.tenant_id, memory_id, agent_ids, keywords, page, page_size) + agent_name_mapping = {} + if messages["message_list"]: + agent_list = UserCanvasService.get_basic_info_by_canvas_ids([message["agent_id"] for message in messages["message_list"]]) + agent_name_mapping = {agent["id"]: agent["title"] for agent in agent_list} + for message in messages["message_list"]: + message["agent_name"] = agent_name_mapping.get(message["agent_id"], "Unknown") + return get_json_result(data={"messages": messages, "storage_type": memory.storage_type}, message=True) diff --git a/api/apps/messages_app.py b/api/apps/messages_app.py new file mode 100644 index 000000000..c91831b9a --- /dev/null +++ b/api/apps/messages_app.py @@ -0,0 +1,169 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from quart import request +from api.apps import login_required +from api.db.services.memory_service import MemoryService +from common.time_utils import current_timestamp, timestamp_to_date + +from memory.services.messages import MessageService +from api.db.joint_services import memory_message_service +from api.db.joint_services.memory_message_service import query_message +from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result +from common.constants import RetCode + + +@manager.route("", methods=["POST"]) # noqa: F821 +@login_required +@validate_request("memory_id", "agent_id", "session_id", "user_input", "agent_response") +async def add_message(): + + req = await get_request_json() + memory_ids = req["memory_id"] + agent_id = req["agent_id"] + session_id = req["session_id"] + user_id = req["user_id"] if req.get("user_id") else "" + user_input = req["user_input"] + agent_response = req["agent_response"] + + res = [] + for memory_id in memory_ids: + success, msg = await memory_message_service.save_to_memory( + memory_id, + { + "user_id": user_id, + "agent_id": agent_id, + "session_id": session_id, + "user_input": user_input, + "agent_response": agent_response + } + ) + res.append({ + "memory_id": memory_id, + "success": success, + "message": msg + }) + + if all([r["success"] for r in res]): + return get_json_result(message="Successfully added to memories.") + + return get_json_result(code=RetCode.SERVER_ERROR, message="Some messages failed to add.", data=res) + + +@manager.route("/:", methods=["DELETE"]) # noqa: F821 +@login_required +async def forget_message(memory_id: str, message_id: int): + + memory = MemoryService.get_by_memory_id(memory_id) + if not memory: + return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.") + + forget_time = timestamp_to_date(current_timestamp()) + update_succeed = MessageService.update_message( + {"memory_id": memory_id, "message_id": int(message_id)}, + {"forget_at": forget_time}, + memory.tenant_id, memory_id) + if update_succeed: + return get_json_result(message=update_succeed) + else: + return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to forget message '{message_id}' in memory '{memory_id}'.") + + +@manager.route("/:", methods=["PUT"]) # noqa: F821 +@login_required +@validate_request("status") +async def update_message(memory_id: str, message_id: int): + req = await get_request_json() + status = req["status"] + if not isinstance(status, bool): + return get_error_argument_result("Status must be a boolean.") + + memory = MemoryService.get_by_memory_id(memory_id) + if not memory: + return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.") + + update_succeed = MessageService.update_message({"memory_id": memory_id, "message_id": int(message_id)}, {"status": status}, memory.tenant_id, memory_id) + if update_succeed: + return get_json_result(message=update_succeed) + else: + return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to set status for message '{message_id}' in memory '{memory_id}'.") + + +@manager.route("/search", methods=["GET"]) # noqa: F821 +@login_required +async def search_message(): + args = request.args + print(args, flush=True) + empty_fields = [f for f in ["memory_id", "query"] if not args.get(f)] + if empty_fields: + return get_error_argument_result(f"{', '.join(empty_fields)} can't be empty.") + + memory_ids = args.getlist("memory_id") + query = args.get("query") + similarity_threshold 
= float(args.get("similarity_threshold", 0.2)) + keywords_similarity_weight = float(args.get("keywords_similarity_weight", 0.7)) + top_n = int(args.get("top_n", 5)) + agent_id = args.get("agent_id", "") + session_id = args.get("session_id", "") + + filter_dict = { + "memory_id": memory_ids, + "agent_id": agent_id, + "session_id": session_id + } + params = { + "query": query, + "similarity_threshold": similarity_threshold, + "keywords_similarity_weight": keywords_similarity_weight, + "top_n": top_n + } + res = query_message(filter_dict, params) + return get_json_result(message=True, data=res) + + +@manager.route("", methods=["GET"]) # noqa: F821 +@login_required +async def get_messages(): + args = request.args + memory_ids = args.getlist("memory_id") + agent_id = args.get("agent_id", "") + session_id = args.get("session_id", "") + limit = int(args.get("limit", 10)) + if not memory_ids: + return get_error_argument_result("memory_ids is required.") + memory_list = MemoryService.get_by_ids(memory_ids) + uids = [memory.tenant_id for memory in memory_list] + res = MessageService.get_recent_messages( + uids, + memory_ids, + agent_id, + session_id, + limit + ) + return get_json_result(message=True, data=res) + + +@manager.route("/:/content", methods=["GET"]) # noqa: F821 +@login_required +async def get_message_content(memory_id:str, message_id: int): + memory = MemoryService.get_by_memory_id(memory_id) + if not memory: + return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.") + + res = MessageService.get_by_message_id(memory_id, message_id, memory.tenant_id) + if res: + return get_json_result(message=True, data=res) + else: + return get_json_result(code=RetCode.NOT_FOUND, message=f"Message '{message_id}' in memory '{memory_id}' not found.") diff --git a/api/apps/sdk/dataset.py b/api/apps/sdk/dataset.py index ff860324c..7d52c3fec 100644 --- a/api/apps/sdk/dataset.py +++ b/api/apps/sdk/dataset.py @@ -495,7 +495,7 @@ def knowledge_graph(tenant_id, dataset_id): } obj = {"graph": {}, "mind_map": {}} - if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id): + if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), dataset_id): return get_result(data=obj) sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id]) if not len(sres.ids): diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index d5859e892..c551bbfe1 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -1080,7 +1080,7 @@ def list_chunks(tenant_id, dataset_id, document_id): res["chunks"].append(final_chunk) _ = Chunk(**final_chunk) - elif settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id): + elif settings.docStoreConn.index_exist(search.index_name(tenant_id), dataset_id): sres = settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True) res["total"] = sres.total for id in sres.ids: diff --git a/api/db/init_data.py b/api/db/init_data.py index 1ebc306d3..77f676f09 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -30,6 +30,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm from api.db.services.user_service import TenantService, UserTenantService +from api.db.joint_services.memory_message_service import init_message_id_sequence, init_memory_size_cache from 
common.constants import LLMType from common.file_utils import get_project_base_directory from common import settings @@ -169,6 +170,8 @@ def init_web_data(): # init_superuser() add_graph_templates() + init_message_id_sequence() + init_memory_size_cache() logging.info("init web data success:{}".format(time.time() - start_time)) diff --git a/api/db/joint_services/memory_message_service.py b/api/db/joint_services/memory_message_service.py new file mode 100644 index 000000000..970d76ab1 --- /dev/null +++ b/api/db/joint_services/memory_message_service.py @@ -0,0 +1,233 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import logging +from typing import List + +from common.time_utils import current_timestamp, timestamp_to_date, format_iso_8601_to_ymd_hms +from common.constants import MemoryType, LLMType +from common.doc_store.doc_store_base import FusionExpr +from api.db.services.memory_service import MemoryService +from api.db.services.tenant_llm_service import TenantLLMService +from api.db.services.llm_service import LLMBundle +from api.utils.memory_utils import get_memory_type_human +from memory.services.messages import MessageService +from memory.services.query import MsgTextQuery, get_vector +from memory.utils.prompt_util import PromptAssembler +from memory.utils.msg_util import get_json_result_from_llm_response +from rag.utils.redis_conn import REDIS_CONN + + +async def save_to_memory(memory_id: str, message_dict: dict): + """ + :param memory_id: + :param message_dict: { + "user_id": str, + "agent_id": str, + "session_id": str, + "user_input": str, + "agent_response": str + } + """ + memory = MemoryService.get_by_memory_id(memory_id) + if not memory: + return False, f"Memory '{memory_id}' not found." 
+ + tenant_id = memory.tenant_id + extracted_content = await extract_by_llm( + tenant_id, + memory.llm_id, + {"temperature": memory.temperature}, + get_memory_type_human(memory.memory_type), + message_dict.get("user_input", ""), + message_dict.get("agent_response", "") + ) if memory.memory_type != MemoryType.RAW.value else [] # if only RAW, no need to extract + raw_message_id = REDIS_CONN.generate_auto_increment_id(namespace="memory") + message_list = [{ + "message_id": raw_message_id, + "message_type": MemoryType.RAW.name.lower(), + "source_id": 0, + "memory_id": memory_id, + "user_id": "", + "agent_id": message_dict["agent_id"], + "session_id": message_dict["session_id"], + "content": f"User Input: {message_dict.get('user_input')}\nAgent Response: {message_dict.get('agent_response')}", + "valid_at": timestamp_to_date(current_timestamp()), + "invalid_at": None, + "forget_at": None, + "status": True + }, *[{ + "message_id": REDIS_CONN.generate_auto_increment_id(namespace="memory"), + "message_type": content["message_type"], + "source_id": raw_message_id, + "memory_id": memory_id, + "user_id": "", + "agent_id": message_dict["agent_id"], + "session_id": message_dict["session_id"], + "content": content["content"], + "valid_at": content["valid_at"], + "invalid_at": content["invalid_at"] if content["invalid_at"] else None, + "forget_at": None, + "status": True + } for content in extracted_content]] + embedding_model = LLMBundle(tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id) + vector_list, _ = embedding_model.encode([msg["content"] for msg in message_list]) + for idx, msg in enumerate(message_list): + msg["content_embed"] = vector_list[idx] + vector_dimension = len(vector_list[0]) + if not MessageService.has_index(tenant_id, memory_id): + created = MessageService.create_index(tenant_id, memory_id, vector_size=vector_dimension) + if not created: + return False, "Failed to create message index." + + new_msg_size = sum([MessageService.calculate_message_size(m) for m in message_list]) + current_memory_size = get_memory_size_cache(memory_id, tenant_id) + if new_msg_size + current_memory_size > memory.memory_size: + size_to_delete = current_memory_size + new_msg_size - memory.memory_size + if memory.forgetting_policy == "fifo": + message_ids_to_delete, delete_size = MessageService.pick_messages_to_delete_by_fifo(memory_id, tenant_id, size_to_delete) + MessageService.delete_message({"message_id": message_ids_to_delete}, tenant_id, memory_id) + decrease_memory_size_cache(memory_id, tenant_id, delete_size) + else: + return False, "Failed to insert message into memory. Memory size reached limit and cannot decide which to delete." + fail_cases = MessageService.insert_message(message_list, tenant_id, memory_id) + if fail_cases: + return False, "Failed to insert message into memory. Details: " + "; ".join(fail_cases) + + increase_memory_size_cache(memory_id, tenant_id, new_msg_size) + return True, "Message saved successfully." 
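
Hedged aside on the size-limit handling in `save_to_memory` above: `memory.memory_size` acts as a byte budget tracked in a Redis-backed cache, and when a new batch would overflow it, the FIFO policy deletes the oldest messages until the overflow is covered and then shrinks the cache by the freed size. A minimal, self-contained sketch of that accounting (the `FifoMemory` class, `limit`, and `used` are hypothetical stand-ins, not the actual `MessageService`/Redis API):

```python
# Illustrative sketch only: a toy FIFO budget mirroring the arithmetic in
# save_to_memory(). Names here are hypothetical; the real code delegates to
# MessageService.pick_messages_to_delete_by_fifo and the Redis size cache.
from dataclasses import dataclass, field


@dataclass
class FifoMemory:
    limit: int                                    # plays the role of memory.memory_size
    used: int = 0                                 # plays the role of the cached memory size
    messages: list = field(default_factory=list)  # (message_id, size), oldest first

    def add(self, message_id: int, size: int) -> None:
        overflow = self.used + size - self.limit  # same check as new_msg_size + current > limit
        while overflow > 0 and self.messages:     # FIFO: evict oldest until the batch fits
            _, freed = self.messages.pop(0)
            self.used -= freed
            overflow -= freed
        self.messages.append((message_id, size))
        self.used += size


mem = FifoMemory(limit=100)
for msg_id, size in [(1, 40), (2, 40), (3, 40)]:
    mem.add(msg_id, size)
print(mem.used, [m for m, _ in mem.messages])     # 80 [2, 3] -> message 1 was evicted
```

Unlike this toy, the real path refuses the insert outright when the forgetting policy is not `fifo` and the budget would be exceeded.
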
+ + +async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory_type: List[str], user_input: str, + agent_response: str, system_prompt: str = "", user_prompt: str="") -> List[dict]: + llm_type = TenantLLMService.llm_id2llm_type(llm_id) + if not llm_type: + raise RuntimeError(f"Unknown type of LLM '{llm_id}'") + if not system_prompt: + system_prompt = PromptAssembler.assemble_system_prompt({"memory_type": memory_type}) + conversation_content = f"User Input: {user_input}\nAgent Response: {agent_response}" + conversation_time = timestamp_to_date(current_timestamp()) + user_prompts = [] + if user_prompt: + user_prompts.append({"role": "user", "content": user_prompt}) + user_prompts.append({"role": "user", "content": f"Conversation: {conversation_content}\nConversation Time: {conversation_time}\nCurrent Time: {conversation_time}"}) + else: + user_prompts.append({"role": "user", "content": PromptAssembler.assemble_user_prompt(conversation_content, conversation_time, conversation_time)}) + llm = LLMBundle(tenant_id, llm_type, llm_id) + res = await llm.async_chat(system_prompt, user_prompts, extract_conf) + res_json = get_json_result_from_llm_response(res) + return [{ + "content": extracted_content["content"], + "valid_at": format_iso_8601_to_ymd_hms(extracted_content["valid_at"]), + "invalid_at": format_iso_8601_to_ymd_hms(extracted_content["invalid_at"]) if extracted_content.get("invalid_at") else "", + "message_type": message_type + } for message_type, extracted_content_list in res_json.items() for extracted_content in extracted_content_list] + + +def query_message(filter_dict: dict, params: dict): + """ + :param filter_dict: { + "memory_id": List[str], + "agent_id": optional + "session_id": optional + } + :param params: { + "query": question str, + "similarity_threshold": float, + "keywords_similarity_weight": float, + "top_n": int + } + """ + memory_ids = filter_dict["memory_id"] + memory_list = MemoryService.get_by_ids(memory_ids) + if not memory_list: + return [] + + condition_dict = {k: v for k, v in filter_dict.items() if v} + uids = [memory.tenant_id for memory in memory_list] + + question = params["query"] + question = question.strip() + memory = memory_list[0] + embd_model = LLMBundle(memory.tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id) + match_dense = get_vector(question, embd_model, similarity=params["similarity_threshold"]) + match_text, _ = MsgTextQuery().question(question, min_match=0.3) + keywords_similarity_weight = params.get("keywords_similarity_weight", 0.7) + fusion_expr = FusionExpr("weighted_sum", params["top_n"], {"weights": ",".join([str(keywords_similarity_weight), str(1 - keywords_similarity_weight)])}) + + return MessageService.search_message(memory_ids, condition_dict, uids, [match_text, match_dense, fusion_expr], params["top_n"]) + + +def init_message_id_sequence(): + message_id_redis_key = "id_generator:memory" + if REDIS_CONN.exist(message_id_redis_key): + current_max_id = REDIS_CONN.get(message_id_redis_key) + logging.info(f"No need to init message_id sequence, current max id is {current_max_id}.") + else: + max_id = 1 + exist_memory_list = MemoryService.get_all_memory() + if not exist_memory_list: + REDIS_CONN.set(message_id_redis_key, max_id) + else: + max_id = MessageService.get_max_message_id( + uid_list=[m.tenant_id for m in exist_memory_list], + memory_ids=[m.id for m in exist_memory_list] + ) + REDIS_CONN.set(message_id_redis_key, max_id) + logging.info(f"Init message_id sequence done, current max id is {max_id}.") + 
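
On the retrieval side, `query_message` above combines full-text and dense scores through `FusionExpr("weighted_sum", top_n, ...)`, where the first weight is `keywords_similarity_weight` and the second is its complement, matching the order in which the text and dense match expressions are passed. A rough sketch of that weighting, independent of any doc-store backend (the function name, scores, and IDs below are made up for illustration):

```python
# Hedged illustration of weighted-sum fusion; not the doc-store implementation.
def fuse_weighted_sum(text_scores: dict[str, float],
                      dense_scores: dict[str, float],
                      keywords_weight: float,
                      top_n: int) -> list[tuple[str, float]]:
    ids = set(text_scores) | set(dense_scores)
    fused = {
        i: keywords_weight * text_scores.get(i, 0.0)
           + (1.0 - keywords_weight) * dense_scores.get(i, 0.0)
        for i in ids
    }
    # keep the top_n messages by fused score, highest first
    return sorted(fused.items(), key=lambda kv: kv[1], reverse=True)[:top_n]


text = {"m1": 0.9, "m2": 0.2}
dense = {"m1": 0.3, "m2": 0.8, "m3": 0.6}
print(fuse_weighted_sum(text, dense, keywords_weight=0.7, top_n=2))
# m1 ≈ 0.72, m2 ≈ 0.38; m3 (≈ 0.18) falls outside top_n
```

The default of `keywords_similarity_weight = 0.7` in the endpoints above therefore biases ranking toward keyword matches, with the dense embedding score contributing the remaining 30%.
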
+ +def get_memory_size_cache(memory_id: str, uid: str): + redis_key = f"memory_{memory_id}" + if REDIS_CONN.exists(redis_key): + return REDIS_CONN.get(redis_key) + else: + memory_size_map = MessageService.calculate_memory_size( + [memory_id], + [uid] + ) + memory_size = memory_size_map.get(memory_id, 0) + set_memory_size_cache(memory_id, memory_size) + return memory_size + + +def set_memory_size_cache(memory_id: str, size: int): + redis_key = f"memory_{memory_id}" + return REDIS_CONN.set(redis_key, size) + + +def increase_memory_size_cache(memory_id: str, uid: str, size: int): + current_value = get_memory_size_cache(memory_id, uid) + return set_memory_size_cache(memory_id, current_value + size) + + +def decrease_memory_size_cache(memory_id: str, uid: str, size: int): + current_value = get_memory_size_cache(memory_id, uid) + return set_memory_size_cache(memory_id, max(current_value - size, 0)) + + +def init_memory_size_cache(): + memory_list = MemoryService.get_all_memory() + if not memory_list: + logging.info("No memory found, no need to init memory size.") + else: + memory_size_map = MessageService.calculate_memory_size( + memory_ids=[m.id for m in memory_list], + uid_list=[m.tenant_id for m in memory_list], + ) + for memory in memory_list: + memory_size = memory_size_map.get(memory.id, 0) + set_memory_size_cache(memory.id, memory_size) + logging.info("Memory size cache init done.") diff --git a/api/db/joint_services/user_account_service.py b/api/db/joint_services/user_account_service.py index 6ab7d8774..2e4dfeaab 100644 --- a/api/db/joint_services/user_account_service.py +++ b/api/db/joint_services/user_account_service.py @@ -34,6 +34,8 @@ from api.db.services.task_service import TaskService from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.user_canvas_version import UserCanvasVersionService from api.db.services.user_service import TenantService, UserService, UserTenantService +from api.db.services.memory_service import MemoryService +from memory.services.messages import MessageService from rag.nlp import search from common.constants import ActiveEnum from common import settings @@ -200,7 +202,16 @@ def delete_user_data(user_id: str) -> dict: done_msg += f"- Deleted {llm_delete_res} tenant-LLM records.\n" langfuse_delete_res = TenantLangfuseService.delete_ty_tenant_id(tenant_id) done_msg += f"- Deleted {langfuse_delete_res} langfuse records.\n" - # step1.3 delete own tenant + # step1.3 delete memory and messages + user_memory = MemoryService.get_by_tenant_id(tenant_id) + if user_memory: + for memory in user_memory: + if MessageService.has_index(tenant_id, memory.id): + MessageService.delete_index(tenant_id, memory.id) + done_msg += " Deleted memory index." + memory_delete_res = MemoryService.delete_by_ids([m.id for m in user_memory]) + done_msg += f"Deleted {memory_delete_res} memory datasets." 
+ # step1.4 delete own tenant tenant_delete_res = TenantService.delete_by_id(tenant_id) done_msg += f"- Deleted {tenant_delete_res} tenant.\n" # step2 delete user-tenant relation diff --git a/api/db/services/canvas_service.py b/api/db/services/canvas_service.py index 57b4b5c2a..706ea1aac 100644 --- a/api/db/services/canvas_service.py +++ b/api/db/services/canvas_service.py @@ -123,6 +123,19 @@ class UserCanvasService(CommonService): logging.exception(e) return False, None + @classmethod + @DB.connection_context() + def get_basic_info_by_canvas_ids(cls, canvas_id): + fields = [ + cls.model.id, + cls.model.avatar, + cls.model.user_id, + cls.model.title, + cls.model.permission, + cls.model.canvas_category + ] + return cls.model.select(*fields).where(cls.model.id.in_(canvas_id)).dicts() + @classmethod @DB.connection_context() def get_by_tenant_ids(cls, joined_tenant_ids, user_id, diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index 7b3a253b7..7b8d222d4 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -38,7 +38,7 @@ from common.time_utils import current_timestamp, get_format_time from common.constants import LLMType, ParserType, StatusEnum, TaskStatus, SVR_CONSUMER_GROUP_NAME from rag.nlp import rag_tokenizer, search from rag.utils.redis_conn import REDIS_CONN -from rag.utils.doc_store_conn import OrderByExpr +from common.doc_store.doc_store_base import OrderByExpr from common import settings @@ -345,7 +345,7 @@ class DocumentService(CommonService): chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(), page * page_size, page_size, search.index_name(tenant_id), [doc.kb_id]) - chunk_ids = settings.docStoreConn.get_chunk_ids(chunks) + chunk_ids = settings.docStoreConn.get_doc_ids(chunks) if not chunk_ids: break all_chunk_ids.extend(chunk_ids) @@ -1230,8 +1230,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id): d["q_%d_vec" % len(v)] = v for b in range(0, len(cks), es_bulk_size): if try_create_idx: - if not settings.docStoreConn.indexExist(idxnm, kb_id): - settings.docStoreConn.createIdx(idxnm, kb_id, len(vectors[0])) + if not settings.docStoreConn.index_exist(idxnm, kb_id): + settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0])) try_create_idx = False settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id) diff --git a/api/db/services/memory_service.py b/api/db/services/memory_service.py index bc071a66f..49b699950 100644 --- a/api/db/services/memory_service.py +++ b/api/db/services/memory_service.py @@ -15,7 +15,6 @@ # from typing import List -from api.apps import current_user from api.db.db_models import DB, Memory, User from api.db.services import duplicate_name from api.db.services.common_service import CommonService @@ -23,6 +22,7 @@ from api.utils.memory_utils import calculate_memory_type from api.constants import MEMORY_NAME_LIMIT from common.misc_utils import get_uuid from common.time_utils import get_format_time, current_timestamp +from memory.utils.prompt_util import PromptAssembler class MemoryService(CommonService): @@ -34,6 +34,17 @@ class MemoryService(CommonService): def get_by_memory_id(cls, memory_id: str): return cls.model.select().where(cls.model.id == memory_id).first() + @classmethod + @DB.connection_context() + def get_by_tenant_id(cls, tenant_id: str): + return cls.model.select().where(cls.model.tenant_id == tenant_id) + + @classmethod + @DB.connection_context() + def get_all_memory(cls): + memory_list = 
cls.model.select() + return list(memory_list) + @classmethod @DB.connection_context() def get_with_owner_name_by_id(cls, memory_id: str): @@ -53,7 +64,9 @@ class MemoryService(CommonService): cls.model.forgetting_policy, cls.model.temperature, cls.model.system_prompt, - cls.model.user_prompt + cls.model.user_prompt, + cls.model.create_date, + cls.model.create_time ] memory = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)).where( cls.model.id == memory_id @@ -72,7 +85,9 @@ class MemoryService(CommonService): cls.model.memory_type, cls.model.storage_type, cls.model.permissions, - cls.model.description + cls.model.description, + cls.model.create_time, + cls.model.create_date ] memories = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)) if filter_dict.get("tenant_id"): @@ -110,6 +125,7 @@ class MemoryService(CommonService): "tenant_id": tenant_id, "embd_id": embd_id, "llm_id": llm_id, + "system_prompt": PromptAssembler.assemble_system_prompt({"memory_type": memory_type}), "create_time": current_timestamp(), "create_date": get_format_time(), "update_time": current_timestamp(), @@ -126,7 +142,7 @@ class MemoryService(CommonService): @classmethod @DB.connection_context() - def update_memory(cls, memory_id: str, update_dict: dict): + def update_memory(cls, tenant_id: str, memory_id: str, update_dict: dict): if not update_dict: return 0 if "temperature" in update_dict and isinstance(update_dict["temperature"], str): @@ -135,7 +151,7 @@ class MemoryService(CommonService): update_dict["name"] = duplicate_name( cls.query, name=update_dict["name"], - tenant_id=current_user.id + tenant_id=tenant_id ) update_dict.update({ "update_time": current_timestamp(), diff --git a/common/doc_store/__init__.py b/common/doc_store/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/rag/utils/doc_store_conn.py b/common/doc_store/doc_store_base.py similarity index 82% rename from rag/utils/doc_store_conn.py rename to common/doc_store/doc_store_base.py index 33f030011..fe6304f75 100644 --- a/rag/utils/doc_store_conn.py +++ b/common/doc_store/doc_store_base.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# - from abc import ABC, abstractmethod from dataclasses import dataclass import numpy as np @@ -22,7 +21,6 @@ DEFAULT_MATCH_VECTOR_TOPN = 10 DEFAULT_MATCH_SPARSE_TOPN = 10 VEC = list | np.ndarray - @dataclass class SparseVector: indices: list[int] @@ -55,14 +53,13 @@ class SparseVector: def __repr__(self): return str(self) - -class MatchTextExpr(ABC): +class MatchTextExpr: def __init__( self, fields: list[str], matching_text: str, topn: int, - extra_options: dict = dict(), + extra_options: dict | None = None, ): self.fields = fields self.matching_text = matching_text @@ -70,7 +67,7 @@ class MatchTextExpr(ABC): self.extra_options = extra_options -class MatchDenseExpr(ABC): +class MatchDenseExpr: def __init__( self, vector_column_name: str, @@ -78,7 +75,7 @@ class MatchDenseExpr(ABC): embedding_data_type: str, distance_type: str, topn: int = DEFAULT_MATCH_VECTOR_TOPN, - extra_options: dict = dict(), + extra_options: dict | None = None, ): self.vector_column_name = vector_column_name self.embedding_data = embedding_data @@ -88,7 +85,7 @@ class MatchDenseExpr(ABC): self.extra_options = extra_options -class MatchSparseExpr(ABC): +class MatchSparseExpr: def __init__( self, vector_column_name: str, @@ -104,7 +101,7 @@ class MatchSparseExpr(ABC): self.opt_params = opt_params -class MatchTensorExpr(ABC): +class MatchTensorExpr: def __init__( self, column_name: str, @@ -120,7 +117,7 @@ class MatchTensorExpr(ABC): self.extra_option = extra_option -class FusionExpr(ABC): +class FusionExpr: def __init__(self, method: str, topn: int, fusion_params: dict | None = None): self.method = method self.topn = topn @@ -129,7 +126,8 @@ class FusionExpr(ABC): MatchExpr = MatchTextExpr | MatchDenseExpr | MatchSparseExpr | MatchTensorExpr | FusionExpr -class OrderByExpr(ABC): + +class OrderByExpr: def __init__(self): self.fields = list() def asc(self, field: str): @@ -141,13 +139,14 @@ class OrderByExpr(ABC): def fields(self): return self.fields + class DocStoreConnection(ABC): """ Database operations """ @abstractmethod - def dbType(self) -> str: + def db_type(self) -> str: """ Return the type of the database. 
""" @@ -165,21 +164,21 @@ class DocStoreConnection(ABC): """ @abstractmethod - def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int): + def create_idx(self, index_name: str, dataset_id: str, vector_size: int): """ Create an index with given name """ raise NotImplementedError("Not implemented") @abstractmethod - def deleteIdx(self, indexName: str, knowledgebaseId: str): + def delete_idx(self, index_name: str, dataset_id: str): """ Delete an index with given name """ raise NotImplementedError("Not implemented") @abstractmethod - def indexExist(self, indexName: str, knowledgebaseId: str) -> bool: + def index_exist(self, index_name: str, dataset_id: str) -> bool: """ Check if an index with given name exists """ @@ -191,16 +190,16 @@ class DocStoreConnection(ABC): @abstractmethod def search( - self, selectFields: list[str], - highlightFields: list[str], + self, select_fields: list[str], + highlight_fields: list[str], condition: dict, - matchExprs: list[MatchExpr], - orderBy: OrderByExpr, + match_expressions: list[MatchExpr], + order_by: OrderByExpr, offset: int, limit: int, - indexNames: str|list[str], - knowledgebaseIds: list[str], - aggFields: list[str] = [], + index_names: str|list[str], + dataset_ids: list[str], + agg_fields: list[str] | None = None, rank_feature: dict | None = None ): """ @@ -209,28 +208,28 @@ class DocStoreConnection(ABC): raise NotImplementedError("Not implemented") @abstractmethod - def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None: + def get(self, data_id: str, index_name: str, dataset_ids: list[str]) -> dict | None: """ Get single chunk with given id """ raise NotImplementedError("Not implemented") @abstractmethod - def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]: + def insert(self, rows: list[dict], index_name: str, dataset_id: str = None) -> list[str]: """ Update or insert a bulk of rows """ raise NotImplementedError("Not implemented") @abstractmethod - def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool: + def update(self, condition: dict, new_value: dict, index_name: str, dataset_id: str) -> bool: """ Update rows with given conjunctive equivalent filtering condition """ raise NotImplementedError("Not implemented") @abstractmethod - def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int: + def delete(self, condition: dict, index_name: str, dataset_id: str) -> int: """ Delete rows with given conjunctive equivalent filtering condition """ @@ -245,7 +244,7 @@ class DocStoreConnection(ABC): raise NotImplementedError("Not implemented") @abstractmethod - def get_chunk_ids(self, res): + def get_doc_ids(self, res): raise NotImplementedError("Not implemented") @abstractmethod @@ -253,18 +252,18 @@ class DocStoreConnection(ABC): raise NotImplementedError("Not implemented") @abstractmethod - def get_highlight(self, res, keywords: list[str], fieldnm: str): + def get_highlight(self, res, keywords: list[str], field_name: str): raise NotImplementedError("Not implemented") @abstractmethod - def get_aggregation(self, res, fieldnm: str): + def get_aggregation(self, res, field_name: str): raise NotImplementedError("Not implemented") """ SQL """ @abstractmethod - def sql(sql: str, fetch_size: int, format: str): + def sql(self, sql: str, fetch_size: int, format: str): """ Run the sql generated by text-to-sql """ diff --git a/common/doc_store/es_conn_base.py b/common/doc_store/es_conn_base.py new file mode 100644 
index 000000000..ed8891a4d --- /dev/null +++ b/common/doc_store/es_conn_base.py @@ -0,0 +1,326 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +import re +import json +import time +import os +from abc import abstractmethod + +from elasticsearch import Elasticsearch, NotFoundError +from elasticsearch_dsl import Index +from elastic_transport import ConnectionTimeout +from common.file_utils import get_project_base_directory +from common.misc_utils import convert_bytes +from common.doc_store.doc_store_base import DocStoreConnection, OrderByExpr, MatchExpr +from rag.nlp import is_english, rag_tokenizer +from common import settings + +ATTEMPT_TIME = 2 + + +class ESConnectionBase(DocStoreConnection): + def __init__(self, mapping_file_name: str="mapping.json", logger_name: str='ragflow.es_conn'): + self.logger = logging.getLogger(logger_name) + + self.info = {} + self.logger.info(f"Use Elasticsearch {settings.ES['hosts']} as the doc engine.") + for _ in range(ATTEMPT_TIME): + try: + if self._connect(): + break + except Exception as e: + self.logger.warning(f"{str(e)}. Waiting Elasticsearch {settings.ES['hosts']} to be healthy.") + time.sleep(5) + + if not self.es.ping(): + msg = f"Elasticsearch {settings.ES['hosts']} is unhealthy in 120s." + self.logger.error(msg) + raise Exception(msg) + v = self.info.get("version", {"number": "8.11.3"}) + v = v["number"].split(".")[0] + if int(v) < 8: + msg = f"Elasticsearch version must be greater than or equal to 8, current version: {v}" + self.logger.error(msg) + raise Exception(msg) + fp_mapping = os.path.join(get_project_base_directory(), "conf", mapping_file_name) + if not os.path.exists(fp_mapping): + msg = f"Elasticsearch mapping file not found at {fp_mapping}" + self.logger.error(msg) + raise Exception(msg) + self.mapping = json.load(open(fp_mapping, "r")) + self.logger.info(f"Elasticsearch {settings.ES['hosts']} is healthy.") + + def _connect(self): + self.es = Elasticsearch( + settings.ES["hosts"].split(","), + basic_auth=(settings.ES["username"], settings.ES[ + "password"]) if "username" in settings.ES and "password" in settings.ES else None, + verify_certs= settings.ES.get("verify_certs", False), + timeout=600 ) + if self.es: + self.info = self.es.info() + return True + return False + + """ + Database operations + """ + + def db_type(self) -> str: + return "elasticsearch" + + def health(self) -> dict: + health_dict = dict(self.es.cluster.health()) + health_dict["type"] = "elasticsearch" + return health_dict + + def get_cluster_stats(self): + """ + curl -XGET "http://{es_host}/_cluster/stats" -H "kbn-xsrf: reporting" to view raw stats. 
+ """ + raw_stats = self.es.cluster.stats() + self.logger.debug(f"ESConnection.get_cluster_stats: {raw_stats}") + try: + res = { + 'cluster_name': raw_stats['cluster_name'], + 'status': raw_stats['status'] + } + indices_status = raw_stats['indices'] + res.update({ + 'indices': indices_status['count'], + 'indices_shards': indices_status['shards']['total'] + }) + doc_info = indices_status['docs'] + res.update({ + 'docs': doc_info['count'], + 'docs_deleted': doc_info['deleted'] + }) + store_info = indices_status['store'] + res.update({ + 'store_size': convert_bytes(store_info['size_in_bytes']), + 'total_dataset_size': convert_bytes(store_info['total_data_set_size_in_bytes']) + }) + mappings_info = indices_status['mappings'] + res.update({ + 'mappings_fields': mappings_info['total_field_count'], + 'mappings_deduplicated_fields': mappings_info['total_deduplicated_field_count'], + 'mappings_deduplicated_size': convert_bytes(mappings_info['total_deduplicated_mapping_size_in_bytes']) + }) + node_info = raw_stats['nodes'] + res.update({ + 'nodes': node_info['count']['total'], + 'nodes_version': node_info['versions'], + 'os_mem': convert_bytes(node_info['os']['mem']['total_in_bytes']), + 'os_mem_used': convert_bytes(node_info['os']['mem']['used_in_bytes']), + 'os_mem_used_percent': node_info['os']['mem']['used_percent'], + 'jvm_versions': node_info['jvm']['versions'][0]['vm_version'], + 'jvm_heap_used': convert_bytes(node_info['jvm']['mem']['heap_used_in_bytes']), + 'jvm_heap_max': convert_bytes(node_info['jvm']['mem']['heap_max_in_bytes']) + }) + return res + + except Exception as e: + self.logger.exception(f"ESConnection.get_cluster_stats: {e}") + return None + + """ + Table operations + """ + + def create_idx(self, index_name: str, dataset_id: str, vector_size: int): + if self.index_exist(index_name, dataset_id): + return True + try: + from elasticsearch.client import IndicesClient + return IndicesClient(self.es).create(index=index_name, + settings=self.mapping["settings"], + mappings=self.mapping["mappings"]) + except Exception: + self.logger.exception("ESConnection.createIndex error %s" % index_name) + + def delete_idx(self, index_name: str, dataset_id: str): + if len(dataset_id) > 0: + # The index need to be alive after any kb deletion since all kb under this tenant are in one index. 
+ return + try: + self.es.indices.delete(index=index_name, allow_no_indices=True) + except NotFoundError: + pass + except Exception: + self.logger.exception("ESConnection.deleteIdx error %s" % index_name) + + def index_exist(self, index_name: str, dataset_id: str = None) -> bool: + s = Index(index_name, self.es) + for i in range(ATTEMPT_TIME): + try: + return s.exists() + except ConnectionTimeout: + self.logger.exception("ES request timeout") + time.sleep(3) + self._connect() + continue + except Exception as e: + self.logger.exception(e) + break + return False + + """ + CRUD operations + """ + + def get(self, doc_id: str, index_name: str, dataset_ids: list[str]) -> dict | None: + for i in range(ATTEMPT_TIME): + try: + res = self.es.get(index=index_name, + id=doc_id, source=True, ) + if str(res.get("timed_out", "")).lower() == "true": + raise Exception("Es Timeout.") + doc = res["_source"] + doc["id"] = doc_id + return doc + except NotFoundError: + return None + except Exception as e: + self.logger.exception(f"ESConnection.get({doc_id}) got exception") + raise e + self.logger.error(f"ESConnection.get timeout for {ATTEMPT_TIME} times!") + raise Exception("ESConnection.get timeout.") + + @abstractmethod + def search( + self, select_fields: list[str], + highlight_fields: list[str], + condition: dict, + match_expressions: list[MatchExpr], + order_by: OrderByExpr, + offset: int, + limit: int, + index_names: str | list[str], + dataset_ids: list[str], + agg_fields: list[str] | None = None, + rank_feature: dict | None = None + ): + raise NotImplementedError("Not implemented") + + @abstractmethod + def insert(self, documents: list[dict], index_name: str, dataset_id: str = None) -> list[str]: + raise NotImplementedError("Not implemented") + + @abstractmethod + def update(self, condition: dict, new_value: dict, index_name: str, dataset_id: str) -> bool: + raise NotImplementedError("Not implemented") + + @abstractmethod + def delete(self, condition: dict, index_name: str, dataset_id: str) -> int: + raise NotImplementedError("Not implemented") + + """ + Helper functions for search result + """ + + def get_total(self, res): + if isinstance(res["hits"]["total"], type({})): + return res["hits"]["total"]["value"] + return res["hits"]["total"] + + def get_doc_ids(self, res): + return [d["_id"] for d in res["hits"]["hits"]] + + def _get_source(self, res): + rr = [] + for d in res["hits"]["hits"]: + d["_source"]["id"] = d["_id"] + d["_source"]["_score"] = d["_score"] + rr.append(d["_source"]) + return rr + + @abstractmethod + def get_fields(self, res, fields: list[str]) -> dict[str, dict]: + raise NotImplementedError("Not implemented") + + def get_highlight(self, res, keywords: list[str], field_name: str): + ans = {} + for d in res["hits"]["hits"]: + highlights = d.get("highlight") + if not highlights: + continue + txt = "...".join([a for a in list(highlights.items())[0][1]]) + if not is_english(txt.split()): + ans[d["_id"]] = txt + continue + + txt = d["_source"][field_name] + txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE) + txt_list = [] + for t in re.split(r"[.?!;\n]", txt): + for w in keywords: + t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w), r"\1\2\3", t, + flags=re.IGNORECASE | re.MULTILINE) + if not re.search(r"[^<>]+", t, flags=re.IGNORECASE | re.MULTILINE): + continue + txt_list.append(t) + ans[d["_id"]] = "...".join(txt_list) if txt_list else "...".join([a for a in list(highlights.items())[0][1]]) + + return ans + + def get_aggregation(self, res, 
field_name: str): + agg_field = "aggs_" + field_name + if "aggregations" not in res or agg_field not in res["aggregations"]: + return list() + buckets = res["aggregations"][agg_field]["buckets"] + return [(b["key"], b["doc_count"]) for b in buckets] + + """ + SQL + """ + + def sql(self, sql: str, fetch_size: int, format: str): + self.logger.debug(f"ESConnection.sql get sql: {sql}") + sql = re.sub(r"[ `]+", " ", sql) + sql = sql.replace("%", "") + replaces = [] + for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql): + fld, v = r.group(1), r.group(3) + match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format( + fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v))) + replaces.append( + ("{}{}'{}'".format( + r.group(1), + r.group(2), + r.group(3)), + match)) + + for p, r in replaces: + sql = sql.replace(p, r, 1) + self.logger.debug(f"ESConnection.sql to es: {sql}") + + for i in range(ATTEMPT_TIME): + try: + res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format, + request_timeout="2s") + return res + except ConnectionTimeout: + self.logger.exception("ES request timeout") + time.sleep(3) + self._connect() + continue + except Exception as e: + self.logger.exception(f"ESConnection.sql got exception. SQL:\n{sql}") + raise Exception(f"SQL error: {e}\n\nSQL: {sql}") + self.logger.error(f"ESConnection.sql timeout for {ATTEMPT_TIME} times!") + return None diff --git a/common/doc_store/infinity_conn_base.py b/common/doc_store/infinity_conn_base.py new file mode 100644 index 000000000..7c7c55d19 --- /dev/null +++ b/common/doc_store/infinity_conn_base.py @@ -0,0 +1,451 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import logging +import os +import re +import json +import time +from abc import abstractmethod + +import infinity +from infinity.common import ConflictType +from infinity.index import IndexInfo, IndexType +from infinity.connection_pool import ConnectionPool +from infinity.errors import ErrorCode +import pandas as pd +from common.file_utils import get_project_base_directory +from rag.nlp import is_english +from common import settings +from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr + + +class InfinityConnectionBase(DocStoreConnection): + def __init__(self, mapping_file_name: str="infinity_mapping.json", logger_name: str="ragflow.infinity_conn"): + self.dbName = settings.INFINITY.get("db_name", "default_db") + self.mapping_file_name = mapping_file_name + self.logger = logging.getLogger(logger_name) + infinity_uri = settings.INFINITY["uri"] + if ":" in infinity_uri: + host, port = infinity_uri.split(":") + infinity_uri = infinity.common.NetworkAddress(host, int(port)) + self.connPool = None + self.logger.info(f"Use Infinity {infinity_uri} as the doc engine.") + for _ in range(24): + try: + conn_pool = ConnectionPool(infinity_uri, max_size=4) + inf_conn = conn_pool.get_conn() + res = inf_conn.show_current_node() + if res.error_code == ErrorCode.OK and res.server_status in ["started", "alive"]: + self._migrate_db(inf_conn) + self.connPool = conn_pool + conn_pool.release_conn(inf_conn) + break + conn_pool.release_conn(inf_conn) + self.logger.warning(f"Infinity status: {res.server_status}. Waiting Infinity {infinity_uri} to be healthy.") + time.sleep(5) + except Exception as e: + self.logger.warning(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.") + time.sleep(5) + if self.connPool is None: + msg = f"Infinity {infinity_uri} is unhealthy in 120s." + self.logger.error(msg) + raise Exception(msg) + self.logger.info(f"Infinity {infinity_uri} is healthy.") + + def _migrate_db(self, inf_conn): + inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore) + fp_mapping = os.path.join(get_project_base_directory(), "conf", self.mapping_file_name) + if not os.path.exists(fp_mapping): + raise Exception(f"Mapping file not found at {fp_mapping}") + schema = json.load(open(fp_mapping)) + table_names = inf_db.list_tables().table_names + for table_name in table_names: + inf_table = inf_db.get_table(table_name) + index_names = inf_table.list_indexes().index_names + if "q_vec_idx" not in index_names: + # Skip tables not created by me + continue + column_names = inf_table.show_columns()["name"] + column_names = set(column_names) + for field_name, field_info in schema.items(): + if field_name in column_names: + continue + res = inf_table.add_columns({field_name: field_info}) + assert res.error_code == infinity.ErrorCode.OK + self.logger.info(f"INFINITY added following column to table {table_name}: {field_name} {field_info}") + if field_info["type"] != "varchar" or "analyzer" not in field_info: + continue + analyzers = field_info["analyzer"] + if isinstance(analyzers, str): + analyzers = [analyzers] + for analyzer in analyzers: + inf_table.create_index( + f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}", + IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}), + ConflictType.Ignore, + ) + + """ + Dataframe and fields convert + """ + + @staticmethod + @abstractmethod + def field_keyword(field_name: str): + # judge keyword or not, such as "*_kwd" tag-like columns. 
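+        # Note (added clarification): a True return makes equivalent_condition_to_str() match the column with filter_fulltext(); otherwise a plain equality / IN filter is generated.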
+ raise NotImplementedError("Not implemented") + + @abstractmethod + def convert_select_fields(self, output_fields: list[str]) -> list[str]: + # rm _kwd, _tks, _sm_tks, _with_weight suffix in field name. + raise NotImplementedError("Not implemented") + + @staticmethod + @abstractmethod + def convert_matching_field(field_weight_str: str) -> str: + # convert matching field to + raise NotImplementedError("Not implemented") + + @staticmethod + def list2str(lst: str | list, sep: str = " ") -> str: + if isinstance(lst, str): + return lst + return sep.join(lst) + + def equivalent_condition_to_str(self, condition: dict, table_instance=None) -> str | None: + assert "_id" not in condition + columns = {} + if table_instance: + for n, ty, de, _ in table_instance.show_columns().rows(): + columns[n] = (ty, de) + + def exists(cln): + nonlocal columns + assert cln in columns, f"'{cln}' should be in '{columns}'." + ty, de = columns[cln] + if ty.lower().find("cha"): + if not de: + de = "" + return f" {cln}!='{de}' " + return f"{cln}!={de}" + + cond = list() + for k, v in condition.items(): + if not isinstance(k, str) or not v: + continue + if self.field_keyword(k): + if isinstance(v, list): + inCond = list() + for item in v: + if isinstance(item, str): + item = item.replace("'", "''") + inCond.append(f"filter_fulltext('{self.convert_matching_field(k)}', '{item}')") + if inCond: + strInCond = " or ".join(inCond) + strInCond = f"({strInCond})" + cond.append(strInCond) + else: + cond.append(f"filter_fulltext('{self.convert_matching_field(k)}', '{v}')") + elif isinstance(v, list): + inCond = list() + for item in v: + if isinstance(item, str): + item = item.replace("'", "''") + inCond.append(f"'{item}'") + else: + inCond.append(str(item)) + if inCond: + strInCond = ", ".join(inCond) + strInCond = f"{k} IN ({strInCond})" + cond.append(strInCond) + elif k == "must_not": + if isinstance(v, dict): + for kk, vv in v.items(): + if kk == "exists": + cond.append("NOT (%s)" % exists(vv)) + elif isinstance(v, str): + cond.append(f"{k}='{v}'") + elif k == "exists": + cond.append(exists(v)) + else: + cond.append(f"{k}={str(v)}") + return " AND ".join(cond) if cond else "1=1" + + @staticmethod + def concat_dataframes(df_list: list[pd.DataFrame], select_fields: list[str]) -> pd.DataFrame: + df_list2 = [df for df in df_list if not df.empty] + if df_list2: + return pd.concat(df_list2, axis=0).reset_index(drop=True) + + schema = [] + for field_name in select_fields: + if field_name == "score()": # Workaround: fix schema is changed to score() + schema.append("SCORE") + elif field_name == "similarity()": # Workaround: fix schema is changed to similarity() + schema.append("SIMILARITY") + else: + schema.append(field_name) + return pd.DataFrame(columns=schema) + + """ + Database operations + """ + + def db_type(self) -> str: + return "infinity" + + def health(self) -> dict: + """ + Return the health status of the database. 
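+        The returned dict contains "type" ("infinity"), "status" ("green" or "red") and "error".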
+ """ + inf_conn = self.connPool.get_conn() + res = inf_conn.show_current_node() + self.connPool.release_conn(inf_conn) + res2 = { + "type": "infinity", + "status": "green" if res.error_code == 0 and res.server_status in ["started", "alive"] else "red", + "error": res.error_msg, + } + return res2 + + """ + Table operations + """ + + def create_idx(self, index_name: str, dataset_id: str, vector_size: int): + table_name = f"{index_name}_{dataset_id}" + inf_conn = self.connPool.get_conn() + inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore) + + fp_mapping = os.path.join(get_project_base_directory(), "conf", self.mapping_file_name) + if not os.path.exists(fp_mapping): + raise Exception(f"Mapping file not found at {fp_mapping}") + schema = json.load(open(fp_mapping)) + vector_name = f"q_{vector_size}_vec" + schema[vector_name] = {"type": f"vector,{vector_size},float"} + inf_table = inf_db.create_table( + table_name, + schema, + ConflictType.Ignore, + ) + inf_table.create_index( + "q_vec_idx", + IndexInfo( + vector_name, + IndexType.Hnsw, + { + "M": "16", + "ef_construction": "50", + "metric": "cosine", + "encode": "lvq", + }, + ), + ConflictType.Ignore, + ) + for field_name, field_info in schema.items(): + if field_info["type"] != "varchar" or "analyzer" not in field_info: + continue + analyzers = field_info["analyzer"] + if isinstance(analyzers, str): + analyzers = [analyzers] + for analyzer in analyzers: + inf_table.create_index( + f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}", + IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}), + ConflictType.Ignore, + ) + self.connPool.release_conn(inf_conn) + self.logger.info(f"INFINITY created table {table_name}, vector size {vector_size}") + return True + + def delete_idx(self, index_name: str, dataset_id: str): + table_name = f"{index_name}_{dataset_id}" + inf_conn = self.connPool.get_conn() + db_instance = inf_conn.get_database(self.dbName) + db_instance.drop_table(table_name, ConflictType.Ignore) + self.connPool.release_conn(inf_conn) + self.logger.info(f"INFINITY dropped table {table_name}") + + def index_exist(self, index_name: str, dataset_id: str) -> bool: + table_name = f"{index_name}_{dataset_id}" + try: + inf_conn = self.connPool.get_conn() + db_instance = inf_conn.get_database(self.dbName) + _ = db_instance.get_table(table_name) + self.connPool.release_conn(inf_conn) + return True + except Exception as e: + self.logger.warning(f"INFINITY indexExist {str(e)}") + return False + + """ + CRUD operations + """ + + @abstractmethod + def search( + self, + select_fields: list[str], + highlight_fields: list[str], + condition: dict, + match_expressions: list[MatchExpr], + order_by: OrderByExpr, + offset: int, + limit: int, + index_names: str | list[str], + dataset_ids: list[str], + agg_fields: list[str] | None = None, + rank_feature: dict | None = None, + ) -> tuple[pd.DataFrame, int]: + raise NotImplementedError("Not implemented") + + @abstractmethod + def get(self, doc_id: str, index_name: str, knowledgebase_ids: list[str]) -> dict | None: + raise NotImplementedError("Not implemented") + + @abstractmethod + def insert(self, documents: list[dict], index_name: str, dataset_ids: str = None) -> list[str]: + raise NotImplementedError("Not implemented") + + @abstractmethod + def update(self, condition: dict, new_value: dict, index_name: str, dataset_id: str) -> bool: + raise NotImplementedError("Not implemented") + + def delete(self, condition: dict, index_name: str, dataset_id: 
str) -> int: + inf_conn = self.connPool.get_conn() + db_instance = inf_conn.get_database(self.dbName) + table_name = f"{index_name}_{dataset_id}" + try: + table_instance = db_instance.get_table(table_name) + except Exception: + self.logger.warning(f"Skipped deleting from table {table_name} since the table doesn't exist.") + return 0 + filter = self.equivalent_condition_to_str(condition, table_instance) + self.logger.debug(f"INFINITY delete table {table_name}, filter {filter}.") + res = table_instance.delete(filter) + self.connPool.release_conn(inf_conn) + return res.deleted_rows + + """ + Helper functions for search result + """ + + def get_total(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int: + if isinstance(res, tuple): + return res[1] + return len(res) + + def get_doc_ids(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]: + if isinstance(res, tuple): + res = res[0] + return list(res["id"]) + + @abstractmethod + def get_fields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]: + raise NotImplementedError("Not implemented") + + def get_highlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], field_name: str): + if isinstance(res, tuple): + res = res[0] + ans = {} + num_rows = len(res) + column_id = res["id"] + if field_name not in res: + return {} + for i in range(num_rows): + id = column_id[i] + txt = res[field_name][i] + if re.search(r"[^<>]+", txt, flags=re.IGNORECASE | re.MULTILINE): + ans[id] = txt + continue + txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE) + txt_list = [] + for t in re.split(r"[.?!;\n]", txt): + if is_english([t]): + for w in keywords: + t = re.sub( + r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w), + r"\1\2\3", + t, + flags=re.IGNORECASE | re.MULTILINE, + ) + else: + for w in sorted(keywords, key=len, reverse=True): + t = re.sub( + re.escape(w), + f"{w}", + t, + flags=re.IGNORECASE | re.MULTILINE, + ) + if not re.search(r"[^<>]+", t, flags=re.IGNORECASE | re.MULTILINE): + continue + txt_list.append(t) + if txt_list: + ans[id] = "...".join(txt_list) + else: + ans[id] = txt + return ans + + def get_aggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, field_name: str): + """ + Manual aggregation for tag fields since Infinity doesn't provide native aggregation + """ + from collections import Counter + + # Extract DataFrame from result + if isinstance(res, tuple): + df, _ = res + else: + df = res + + if df.empty or field_name not in df.columns: + return [] + + # Aggregate tag counts + tag_counter = Counter() + + for value in df[field_name]: + if pd.isna(value) or not value: + continue + + # Handle different tag formats + if isinstance(value, str): + # Split by ### for tag_kwd field or comma for other formats + if field_name == "tag_kwd" and "###" in value: + tags = [tag.strip() for tag in value.split("###") if tag.strip()] + else: + # Try comma separation as fallback + tags = [tag.strip() for tag in value.split(",") if tag.strip()] + + for tag in tags: + if tag: # Only count non-empty tags + tag_counter[tag] += 1 + elif isinstance(value, list): + # Handle list format + for tag in value: + if tag and isinstance(tag, str): + tag_counter[tag.strip()] += 1 + + # Return as list of [tag, count] pairs, sorted by count descending + return [[tag, count] for tag, count in tag_counter.most_common()] + + """ + SQL + """ + + def sql(self, sql: str, fetch_size: int, format: str): + raise NotImplementedError("Not implemented") diff --git 
a/common/query_base.py b/common/query_base.py new file mode 100644 index 000000000..eae44514f --- /dev/null +++ b/common/query_base.py @@ -0,0 +1,72 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import re +from abc import ABC, abstractmethod + + +class QueryBase(ABC): + + @staticmethod + def is_chinese(line): + arr = re.split(r"[ \t]+", line) + if len(arr) <= 3: + return True + e = 0 + for t in arr: + if not re.match(r"[a-zA-Z]+$", t): + e += 1 + return e * 1.0 / len(arr) >= 0.7 + + @staticmethod + def sub_special_char(line): + return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|\+~\^])", r"\\\1", line).strip() + + @staticmethod + def rmWWW(txt): + patts = [ + ( + r"是*(怎么办|什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀|谁|哪位|哪个)是*", + "", + ), + (r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "), + ( + r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ", + " ") + ] + otxt = txt + for r, p in patts: + txt = re.sub(r, p, txt, flags=re.IGNORECASE) + if not txt: + txt = otxt + return txt + + @staticmethod + def add_space_between_eng_zh(txt): + # (ENG/ENG+NUM) + ZH + txt = re.sub(r'([A-Za-z]+[0-9]+)([\u4e00-\u9fa5]+)', r'\1 \2', txt) + # ENG + ZH + txt = re.sub(r'([A-Za-z])([\u4e00-\u9fa5]+)', r'\1 \2', txt) + # ZH + (ENG/ENG+NUM) + txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z]+[0-9]+)', r'\1 \2', txt) + txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z])', r'\1 \2', txt) + return txt + + @abstractmethod + def question(self, text, tbl, min_match): + """ + Returns a query object based on the input text, table, and minimum match criteria. 
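+        Concrete implementations (e.g. MsgTextQuery in memory/services/query.py) return a (match expression, keywords) pair.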
+ """ + raise NotImplementedError("Not implemented") diff --git a/common/settings.py b/common/settings.py index a0385e716..88f6fc088 100644 --- a/common/settings.py +++ b/common/settings.py @@ -39,6 +39,9 @@ from rag.utils.oss_conn import RAGFlowOSS from rag.nlp import search +import memory.utils.es_conn as memory_es_conn +import memory.utils.infinity_conn as memory_infinity_conn + LLM = None LLM_FACTORY = None LLM_BASE_URL = None @@ -76,9 +79,11 @@ FEISHU_OAUTH = None OAUTH_CONFIG = None DOC_ENGINE = os.getenv('DOC_ENGINE', 'elasticsearch') DOC_ENGINE_INFINITY = (DOC_ENGINE.lower() == "infinity") +MSG_ENGINE = DOC_ENGINE docStoreConn = None +msgStoreConn = None retriever = None kg_retriever = None @@ -256,6 +261,15 @@ def init_settings(): else: raise Exception(f"Not supported doc engine: {DOC_ENGINE}") + global MSG_ENGINE, msgStoreConn + MSG_ENGINE = DOC_ENGINE # use the same engine for message store + if MSG_ENGINE == "elasticsearch": + ES = get_base_config("es", {}) + msgStoreConn = memory_es_conn.ESConnection() + elif MSG_ENGINE == "infinity": + INFINITY = get_base_config("infinity", {"uri": "infinity:23817"}) + msgStoreConn = memory_infinity_conn.InfinityConnection() + global AZURE, S3, MINIO, OSS, GCS if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']: AZURE = get_base_config("azure", {}) diff --git a/common/time_utils.py b/common/time_utils.py index a924b3405..f501674e8 100644 --- a/common/time_utils.py +++ b/common/time_utils.py @@ -14,6 +14,7 @@ # limitations under the License. import datetime +import logging import time def current_timestamp(): @@ -123,4 +124,31 @@ def delta_seconds(date_string: str): 3600.0 # If current time is 2024-01-01 13:00:00 """ dt = datetime.datetime.strptime(date_string, "%Y-%m-%d %H:%M:%S") - return (datetime.datetime.now() - dt).total_seconds() \ No newline at end of file + return (datetime.datetime.now() - dt).total_seconds() + + +def format_iso_8601_to_ymd_hms(time_str: str) -> str: + """ + Convert ISO 8601 formatted string to "YYYY-MM-DD HH:MM:SS" format. + + Args: + time_str: ISO 8601 date string (e.g. 
"2024-01-01T12:00:00Z") + + Returns: + str: Date string in "YYYY-MM-DD HH:MM:SS" format + + Example: + >>> format_iso_8601_to_ymd_hms("2024-01-01T12:00:00Z") + '2024-01-01 12:00:00' + """ + from dateutil import parser + + try: + if parser.isoparse(time_str): + dt = datetime.datetime.fromisoformat(time_str.replace("Z", "+00:00")) + return dt.strftime("%Y-%m-%d %H:%M:%S") + else: + return time_str + except Exception as e: + logging.error(str(e)) + return time_str diff --git a/conf/message_infinity_mapping.json b/conf/message_infinity_mapping.json new file mode 100644 index 000000000..17e7307e8 --- /dev/null +++ b/conf/message_infinity_mapping.json @@ -0,0 +1,19 @@ +{ + "id": {"type": "varchar", "default": ""}, + "message_id": {"type": "integer", "default": 0}, + "message_type_kwd": {"type": "varchar", "default": ""}, + "source_id": {"type": "integer", "default": 0}, + "memory_id": {"type": "varchar", "default": ""}, + "user_id": {"type": "varchar", "default": ""}, + "agent_id": {"type": "varchar", "default": ""}, + "session_id": {"type": "varchar", "default": ""}, + "valid_at": {"type": "varchar", "default": ""}, + "valid_at_flt": {"type": "float", "default": 0.0}, + "invalid_at": {"type": "varchar", "default": ""}, + "invalid_at_flt": {"type": "float", "default": 0.0}, + "forget_at": {"type": "varchar", "default": ""}, + "forget_at_flt": {"type": "float", "default": 0.0}, + "status_int": {"type": "integer", "default": 1}, + "zone_id": {"type": "integer", "default": 0}, + "content": {"type": "varchar", "default": "", "analyzer": ["rag-coarse", "rag-fine"], "comment": "content_ltks"} +} \ No newline at end of file diff --git a/graphrag/search.py b/graphrag/search.py index 860c58906..7bb46b6b9 100644 --- a/graphrag/search.py +++ b/graphrag/search.py @@ -24,11 +24,11 @@ from common.misc_utils import get_uuid from graphrag.query_analyze_prompt import PROMPTS from graphrag.utils import get_entity_type2samples, get_llm_cache, set_llm_cache, get_relation from common.token_utils import num_tokens_from_string -from rag.utils.doc_store_conn import OrderByExpr from rag.nlp.search import Dealer, index_name from common.float_utils import get_float from common import settings +from common.doc_store.doc_store_base import OrderByExpr class KGSearch(Dealer): diff --git a/graphrag/utils.py b/graphrag/utils.py index 7e3fec1a9..89dbfad75 100644 --- a/graphrag/utils.py +++ b/graphrag/utils.py @@ -26,9 +26,9 @@ from networkx.readwrite import json_graph from common.misc_utils import get_uuid from common.connection_utils import timeout from rag.nlp import rag_tokenizer, search -from rag.utils.doc_store_conn import OrderByExpr from rag.utils.redis_conn import REDIS_CONN from common import settings +from common.doc_store.doc_store_base import OrderByExpr GRAPH_FIELD_SEP = "" diff --git a/memory/__init__.py b/memory/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/memory/services/__init__.py b/memory/services/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/memory/services/messages.py b/memory/services/messages.py new file mode 100644 index 000000000..7c4ab717a --- /dev/null +++ b/memory/services/messages.py @@ -0,0 +1,240 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import sys +from typing import List + +from common import settings +from common.doc_store.doc_store_base import OrderByExpr, MatchExpr + + +def index_name(uid: str): return f"memory_{uid}" + + +class MessageService: + + @classmethod + def has_index(cls, uid: str, memory_id: str): + index = index_name(uid) + return settings.msgStoreConn.index_exist(index, memory_id) + + @classmethod + def create_index(cls, uid: str, memory_id: str, vector_size: int): + index = index_name(uid) + return settings.msgStoreConn.create_idx(index, memory_id, vector_size) + + @classmethod + def delete_index(cls, uid: str, memory_id: str): + index = index_name(uid) + return settings.msgStoreConn.delete_idx(index, memory_id) + + @classmethod + def insert_message(cls, messages: List[dict], uid: str, memory_id: str): + index = index_name(uid) + [m.update({ + "id": f'{memory_id}_{m["message_id"]}', + "status": 1 if m["status"] else 0 + }) for m in messages] + return settings.msgStoreConn.insert(messages, index, memory_id) + + @classmethod + def update_message(cls, condition: dict, update_dict: dict, uid: str, memory_id: str): + index = index_name(uid) + if "status" in update_dict: + update_dict["status"] = 1 if update_dict["status"] else 0 + return settings.msgStoreConn.update(condition, update_dict, index, memory_id) + + @classmethod + def delete_message(cls, condition: dict, uid: str, memory_id: str): + index = index_name(uid) + return settings.msgStoreConn.delete(condition, index, memory_id) + + @classmethod + def list_message(cls, uid: str, memory_id: str, agent_ids: List[str]=None, keywords: str=None, page: int=1, page_size: int=50): + index = index_name(uid) + filter_dict = {} + if agent_ids: + filter_dict["agent_id"] = agent_ids + if keywords: + filter_dict["session_id"] = keywords + order_by = OrderByExpr() + order_by.desc("valid_at") + res = settings.msgStoreConn.search( + select_fields=[ + "message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at", + "invalid_at", "forget_at", "status" + ], + highlight_fields=[], + condition=filter_dict, + match_expressions=[], order_by=order_by, + offset=(page-1)*page_size, limit=page_size, + index_names=index, memory_ids=[memory_id], agg_fields=[], hide_forgotten=False + ) + total_count = settings.msgStoreConn.get_total(res) + doc_mapping = settings.msgStoreConn.get_fields(res, [ + "message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", + "valid_at", "invalid_at", "forget_at", "status" + ]) + return { + "message_list": list(doc_mapping.values()), + "total_count": total_count + } + + @classmethod + def get_recent_messages(cls, uid_list: List[str], memory_ids: List[str], agent_id: str, session_id: str, limit: int): + index_names = [index_name(uid) for uid in uid_list] + condition_dict = { + "agent_id": agent_id, + "session_id": session_id + } + order_by = OrderByExpr() + order_by.desc("valid_at") + res = settings.msgStoreConn.search( + select_fields=[ + "message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at", + "invalid_at", "forget_at", 
"status", "content" + ], + highlight_fields=[], + condition=condition_dict, + match_expressions=[], order_by=order_by, + offset=0, limit=limit, + index_names=index_names, memory_ids=memory_ids, agg_fields=[] + ) + doc_mapping = settings.msgStoreConn.get_fields(res, [ + "message_id", "message_type", "source_id", "memory_id","user_id", "agent_id", "session_id", + "valid_at", "invalid_at", "forget_at", "status", "content" + ]) + return list(doc_mapping.values()) + + @classmethod + def search_message(cls, memory_ids: List[str], condition_dict: dict, uid_list: List[str], match_expressions:list[MatchExpr], top_n: int): + index_names = [index_name(uid) for uid in uid_list] + # filter only valid messages by default + if "status" not in condition_dict: + condition_dict["status"] = 1 + + order_by = OrderByExpr() + order_by.desc("valid_at") + res = settings.msgStoreConn.search( + select_fields=[ + "message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", + "valid_at", + "invalid_at", "forget_at", "status", "content" + ], + highlight_fields=[], + condition=condition_dict, + match_expressions=match_expressions, + order_by=order_by, + offset=0, limit=top_n, + index_names=index_names, memory_ids=memory_ids, agg_fields=[] + ) + docs = settings.msgStoreConn.get_fields(res, [ + "message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at", + "invalid_at", "forget_at", "status", "content" + ]) + return list(docs.values()) + + @staticmethod + def calculate_message_size(message: dict): + return sys.getsizeof(message["content"]) + sys.getsizeof(message["content_embed"][0]) * len(message["content_embed"]) + + @classmethod + def calculate_memory_size(cls, memory_ids: List[str], uid_list: List[str]): + index_names = [index_name(uid) for uid in uid_list] + order_by = OrderByExpr() + order_by.desc("valid_at") + + res = settings.msgStoreConn.search( + select_fields=["memory_id", "content", "content_embed"], + highlight_fields=[], + condition={}, + match_expressions=[], + order_by=order_by, + offset=0, limit=2000*len(memory_ids), + index_names=index_names, memory_ids=memory_ids, agg_fields=[], hide_forgotten=False + ) + docs = settings.msgStoreConn.get_fields(res, ["memory_id", "content", "content_embed"]) + size_dict = {} + for doc in docs.values(): + if size_dict.get(doc["memory_id"]): + size_dict[doc["memory_id"]] += cls.calculate_message_size(doc) + else: + size_dict[doc["memory_id"]] = cls.calculate_message_size(doc) + return size_dict + + @classmethod + def pick_messages_to_delete_by_fifo(cls, memory_id: str, uid: str, size_to_delete: int): + select_fields = ["message_id", "content", "content_embed"] + _index_name = index_name(uid) + res = settings.msgStoreConn.get_forgotten_messages(select_fields, _index_name, memory_id) + message_list = settings.msgStoreConn.get_fields(res, select_fields) + current_size = 0 + ids_to_remove = [] + for message in message_list: + if current_size < size_to_delete: + current_size += cls.calculate_message_size(message) + ids_to_remove.append(message["message_id"]) + else: + return ids_to_remove, current_size + if current_size >= size_to_delete: + return ids_to_remove, current_size + + order_by = OrderByExpr() + order_by.asc("valid_at") + res = settings.msgStoreConn.search( + select_fields=["memory_id", "content", "content_embed"], + highlight_fields=[], + condition={}, + match_expressions=[], + order_by=order_by, + offset=0, limit=2000, + index_names=[_index_name], memory_ids=[memory_id], agg_fields=[] + ) 
+ docs = settings.msgStoreConn.get_fields(res, select_fields) + for doc in docs.values(): + if current_size < size_to_delete: + current_size += cls.calculate_message_size(doc) + ids_to_remove.append(doc["memory_id"]) + else: + return ids_to_remove, current_size + return ids_to_remove, current_size + + @classmethod + def get_by_message_id(cls, memory_id: str, message_id: int, uid: str): + index = index_name(uid) + doc_id = f'{memory_id}_{message_id}' + return settings.msgStoreConn.get(doc_id, index, [memory_id]) + + @classmethod + def get_max_message_id(cls, uid_list: List[str], memory_ids: List[str]): + order_by = OrderByExpr() + order_by.desc("message_id") + index_names = [index_name(uid) for uid in uid_list] + res = settings.msgStoreConn.search( + select_fields=["message_id"], + highlight_fields=[], + condition={}, + match_expressions=[], + order_by=order_by, + offset=0, limit=1, + index_names=index_names, memory_ids=memory_ids, + agg_fields=[], hide_forgotten=False + ) + docs = settings.msgStoreConn.get_fields(res, ["message_id"]) + if not docs: + return 1 + else: + latest_msg = list(docs.values())[0] + return int(latest_msg["message_id"]) diff --git a/memory/services/query.py b/memory/services/query.py new file mode 100644 index 000000000..06f253f6b --- /dev/null +++ b/memory/services/query.py @@ -0,0 +1,185 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import re +import logging +import json +import numpy as np +from common.query_base import QueryBase +from common.doc_store.doc_store_base import MatchDenseExpr, MatchTextExpr +from common.float_utils import get_float +from rag.nlp import rag_tokenizer, term_weight, synonym + + +def get_vector(txt, emb_mdl, topk=10, similarity=0.1): + if isinstance(similarity, str) and len(similarity) > 0: + try: + similarity = float(similarity) + except Exception as e: + logging.warning(f"Convert similarity '{similarity}' to float failed: {e}. 
Using default 0.1") + similarity = 0.1 + qv, _ = emb_mdl.encode_queries(txt) + shape = np.array(qv).shape + if len(shape) > 1: + raise Exception( + f"Dealer.get_vector returned array's shape {shape} doesn't match expectation(exact one dimension).") + embedding_data = [get_float(v) for v in qv] + vector_column_name = f"q_{len(embedding_data)}_vec" + return MatchDenseExpr(vector_column_name, embedding_data, 'float', 'cosine', topk, {"similarity": similarity}) + + +class MsgTextQuery(QueryBase): + + def __init__(self): + self.tw = term_weight.Dealer() + self.syn = synonym.Dealer() + self.query_fields = [ + "content" + ] + + def question(self, txt, tbl="messages", min_match: float=0.6): + original_query = txt + txt = MsgTextQuery.add_space_between_eng_zh(txt) + txt = re.sub( + r"[ :|\r\n\t,,。??/`!!&^%%()\[\]{}<>]+", + " ", + rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(txt.lower())), + ).strip() + otxt = txt + txt = MsgTextQuery.rmWWW(txt) + + if not self.is_chinese(txt): + txt = self.rmWWW(txt) + tks = rag_tokenizer.tokenize(txt).split() + keywords = [t for t in tks if t] + tks_w = self.tw.weights(tks, preprocess=False) + tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w] + tks_w = [(re.sub(r"^[a-z0-9]$", "", tk), w) for tk, w in tks_w if tk] + tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk] + tks_w = [(tk.strip(), w) for tk, w in tks_w if tk.strip()] + syns = [] + for tk, w in tks_w[:256]: + syn = self.syn.lookup(tk) + syn = rag_tokenizer.tokenize(" ".join(syn)).split() + keywords.extend(syn) + syn = ["\"{}\"^{:.4f}".format(s, w / 4.) for s in syn if s.strip()] + syns.append(" ".join(syn)) + + q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if + tk and not re.match(r"[.^+\(\)-]", tk)] + for i in range(1, len(tks_w)): + left, right = tks_w[i - 1][0].strip(), tks_w[i][0].strip() + if not left or not right: + continue + q.append( + '"%s %s"^%.4f' + % ( + tks_w[i - 1][0], + tks_w[i][0], + max(tks_w[i - 1][1], tks_w[i][1]) * 2, + ) + ) + if not q: + q.append(txt) + query = " ".join(q) + return MatchTextExpr( + self.query_fields, query, 100, {"original_query": original_query} + ), keywords + + def need_fine_grained_tokenize(tk): + if len(tk) < 3: + return False + if re.match(r"[0-9a-z\.\+#_\*-]+$", tk): + return False + return True + + txt = self.rmWWW(txt) + qs, keywords = [], [] + for tt in self.tw.split(txt)[:256]: # .split(): + if not tt: + continue + keywords.append(tt) + twts = self.tw.weights([tt]) + syns = self.syn.lookup(tt) + if syns and len(keywords) < 32: + keywords.extend(syns) + logging.debug(json.dumps(twts, ensure_ascii=False)) + tms = [] + for tk, w in sorted(twts, key=lambda x: x[1] * -1): + sm = ( + rag_tokenizer.fine_grained_tokenize(tk).split() + if need_fine_grained_tokenize(tk) + else [] + ) + sm = [ + re.sub( + r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+", + "", + m, + ) + for m in sm + ] + sm = [self.sub_special_char(m) for m in sm if len(m) > 1] + sm = [m for m in sm if len(m) > 1] + + if len(keywords) < 32: + keywords.append(re.sub(r"[ \\\"']+", "", tk)) + keywords.extend(sm) + + tk_syns = self.syn.lookup(tk) + tk_syns = [self.sub_special_char(s) for s in tk_syns] + if len(keywords) < 32: + keywords.extend([s for s in tk_syns if s]) + tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s] + tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns] + + if len(keywords) >= 32: + break + + tk = self.sub_special_char(tk) + if tk.find(" ") > 0: + 
tk = '"%s"' % tk + if tk_syns: + tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns) + if sm: + tk = f'{tk} OR "%s" OR ("%s"~2)^0.5' % (" ".join(sm), " ".join(sm)) + if tk.strip(): + tms.append((tk, w)) + + tms = " ".join([f"({t})^{w}" for t, w in tms]) + + if len(twts) > 1: + tms += ' ("%s"~2)^1.5' % rag_tokenizer.tokenize(tt) + + syns = " OR ".join( + [ + '"%s"' + % rag_tokenizer.tokenize(self.sub_special_char(s)) + for s in syns + ] + ) + if syns and tms: + tms = f"({tms})^5 OR ({syns})^0.7" + + qs.append(tms) + + if qs: + query = " OR ".join([f"({t})" for t in qs if t]) + if not query: + query = otxt + return MatchTextExpr( + self.query_fields, query, 100, {"minimum_should_match": min_match, "original_query": original_query} + ), keywords + return None, keywords \ No newline at end of file diff --git a/memory/utils/__init__.py b/memory/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/memory/utils/es_conn.py b/memory/utils/es_conn.py new file mode 100644 index 000000000..ec3cd48a6 --- /dev/null +++ b/memory/utils/es_conn.py @@ -0,0 +1,494 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import re +import json +import time + +import copy +from elasticsearch import NotFoundError +from elasticsearch_dsl import UpdateByQuery, Q, Search +from elastic_transport import ConnectionTimeout +from common.decorator import singleton +from common.doc_store.doc_store_base import MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, FusionExpr +from common.doc_store.es_conn_base import ESConnectionBase +from common.float_utils import get_float +from common.constants import PAGERANK_FLD, TAG_FLD + +ATTEMPT_TIME = 2 + + +@singleton +class ESConnection(ESConnectionBase): + + @staticmethod + def convert_field_name(field_name: str) -> str: + match field_name: + case "message_type": + return "message_type_kwd" + case "status": + return "status_int" + case "content": + return "content_ltks" + case _: + return field_name + + @staticmethod + def map_message_to_es_fields(message: dict) -> dict: + """ + Map message dictionary fields to Elasticsearch document/Infinity fields. + + :param message: A dictionary containing message details. + :return: A dictionary formatted for Elasticsearch/Infinity indexing. 
+ """ + storage_doc = { + "id": message.get("id"), + "message_id": message["message_id"], + "message_type_kwd": message["message_type"], + "source_id": message["source_id"], + "memory_id": message["memory_id"], + "user_id": message["user_id"], + "agent_id": message["agent_id"], + "session_id": message["session_id"], + "valid_at": message["valid_at"], + "invalid_at": message["invalid_at"], + "forget_at": message["forget_at"], + "status_int": 1 if message["status"] else 0, + "zone_id": message.get("zone_id", 0), + "content_ltks": message["content"], + f"q_{len(message['content_embed'])}_vec": message["content_embed"], + } + return storage_doc + + @staticmethod + def get_message_from_es_doc(doc: dict) -> dict: + """ + Convert an Elasticsearch/Infinity document back to a message dictionary. + + :param doc: A dictionary representing the Elasticsearch/Infinity document. + :return: A dictionary formatted as a message. + """ + embd_field_name = next((key for key in doc.keys() if re.match(r"q_\d+_vec", key)), None) + message = { + "message_id": doc["message_id"], + "message_type": doc["message_type_kwd"], + "source_id": doc["source_id"] if doc["source_id"] else None, + "memory_id": doc["memory_id"], + "user_id": doc.get("user_id", ""), + "agent_id": doc["agent_id"], + "session_id": doc["session_id"], + "zone_id": doc.get("zone_id", 0), + "valid_at": doc["valid_at"], + "invalid_at": doc.get("invalid_at", "-"), + "forget_at": doc.get("forget_at", "-"), + "status": bool(int(doc["status_int"])), + "content": doc.get("content_ltks", ""), + "content_embed": doc.get(embd_field_name, []) if embd_field_name else [], + } + if doc.get("id"): + message["id"] = doc["id"] + return message + + """ + CRUD operations + """ + + def search( + self, select_fields: list[str], + highlight_fields: list[str], + condition: dict, + match_expressions: list[MatchExpr], + order_by: OrderByExpr, + offset: int, + limit: int, + index_names: str | list[str], + memory_ids: list[str], + agg_fields: list[str] | None = None, + rank_feature: dict | None = None, + hide_forgotten: bool = True + ): + """ + Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html + """ + if isinstance(index_names, str): + index_names = index_names.split(",") + assert isinstance(index_names, list) and len(index_names) > 0 + assert "_id" not in condition + bool_query = Q("bool", must=[], must_not=[]) + if hide_forgotten: + # filter not forget + bool_query.must_not.append(Q("exists", field="forget_at")) + + condition["memory_id"] = memory_ids + for k, v in condition.items(): + if k == "session_id" and v: + bool_query.filter.append(Q("query_string", **{"query": f"*{v}*", "fields": ["session_id"], "analyze_wildcard": True})) + continue + if not v: + continue + if isinstance(v, list): + bool_query.filter.append(Q("terms", **{k: v})) + elif isinstance(v, str) or isinstance(v, int): + bool_query.filter.append(Q("term", **{k: v})) + else: + raise Exception( + f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.") + s = Search() + vector_similarity_weight = 0.5 + for m in match_expressions: + if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params: + assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance(match_expressions[1], + MatchDenseExpr) and isinstance( + match_expressions[2], FusionExpr) + weights = m.fusion_params["weights"] + vector_similarity_weight = get_float(weights.split(",")[1]) + for m in 
match_expressions: + if isinstance(m, MatchTextExpr): + minimum_should_match = m.extra_options.get("minimum_should_match", 0.0) + if isinstance(minimum_should_match, float): + minimum_should_match = str(int(minimum_should_match * 100)) + "%" + bool_query.must.append(Q("query_string", fields=[self.convert_field_name(f) for f in m.fields], + type="best_fields", query=m.matching_text, + minimum_should_match=minimum_should_match, + boost=1)) + bool_query.boost = 1.0 - vector_similarity_weight + + elif isinstance(m, MatchDenseExpr): + assert (bool_query is not None) + similarity = 0.0 + if "similarity" in m.extra_options: + similarity = m.extra_options["similarity"] + s = s.knn(self.convert_field_name(m.vector_column_name), + m.topn, + m.topn * 2, + query_vector=list(m.embedding_data), + filter=bool_query.to_dict(), + similarity=similarity, + ) + + if bool_query and rank_feature: + for fld, sc in rank_feature.items(): + if fld != PAGERANK_FLD: + fld = f"{TAG_FLD}.{fld}" + bool_query.should.append(Q("rank_feature", field=fld, linear={}, boost=sc)) + + if bool_query: + s = s.query(bool_query) + for field in highlight_fields: + s = s.highlight(field) + + if order_by: + orders = list() + for field, order in order_by.fields: + order = "asc" if order == 0 else "desc" + if field.endswith("_int") or field.endswith("_flt"): + order_info = {"order": order, "unmapped_type": "float"} + else: + order_info = {"order": order, "unmapped_type": "text"} + orders.append({field: order_info}) + s = s.sort(*orders) + + if agg_fields: + for fld in agg_fields: + s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000) + + if limit > 0: + s = s[offset:offset + limit] + q = s.to_dict() + self.logger.debug(f"ESConnection.search {str(index_names)} query: " + json.dumps(q)) + + for i in range(ATTEMPT_TIME): + try: + #print(json.dumps(q, ensure_ascii=False)) + res = self.es.search(index=index_names, + body=q, + timeout="600s", + # search_type="dfs_query_then_fetch", + track_total_hits=True, + _source=True) + if str(res.get("timed_out", "")).lower() == "true": + raise Exception("Es Timeout.") + self.logger.debug(f"ESConnection.search {str(index_names)} res: " + str(res)) + return res + except ConnectionTimeout: + self.logger.exception("ES request timeout") + self._connect() + continue + except Exception as e: + self.logger.exception(f"ESConnection.search {str(index_names)} query: " + str(q) + str(e)) + raise e + + self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!") + raise Exception("ESConnection.search timeout.") + + def get_forgotten_messages(self, select_fields: list[str], index_name: str, memory_id: str, limit: int=2000): + bool_query = Q("bool", must_not=[]) + bool_query.must_not.append(Q("term", forget_at=None)) + bool_query.filter.append(Q("term", memory_id=memory_id)) + # from old to new + order_by = OrderByExpr() + order_by.asc("forget_at") + # build search + s = Search() + s = s.query(bool_query) + s = s.sort(order_by) + s = s[:limit] + q = s.to_dict() + # search + for i in range(ATTEMPT_TIME): + try: + res = self.es.search(index=index_name, body=q, timeout="600s", track_total_hits=True, _source=True) + if str(res.get("timed_out", "")).lower() == "true": + raise Exception("Es Timeout.") + self.logger.debug(f"ESConnection.search {str(index_name)} res: " + str(res)) + return res + except ConnectionTimeout: + self.logger.exception("ES request timeout") + self._connect() + continue + except Exception as e: + self.logger.exception(f"ESConnection.search {str(index_name)} query: " + str(q) + 
str(e)) + raise e + + self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!") + raise Exception("ESConnection.search timeout.") + + def get(self, doc_id: str, index_name: str, memory_ids: list[str]) -> dict | None: + for i in range(ATTEMPT_TIME): + try: + res = self.es.get(index=index_name, + id=doc_id, source=True, ) + if str(res.get("timed_out", "")).lower() == "true": + raise Exception("Es Timeout.") + message = res["_source"] + message["id"] = doc_id + return self.get_message_from_es_doc(message) + except NotFoundError: + return None + except Exception as e: + self.logger.exception(f"ESConnection.get({doc_id}) got exception") + raise e + self.logger.error(f"ESConnection.get timeout for {ATTEMPT_TIME} times!") + raise Exception("ESConnection.get timeout.") + + def insert(self, documents: list[dict], index_name: str, memory_id: str = None) -> list[str]: + # Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html + operations = [] + for d in documents: + assert "_id" not in d + assert "id" in d + d_copy_raw = copy.deepcopy(d) + d_copy = self.map_message_to_es_fields(d_copy_raw) + d_copy["memory_id"] = memory_id + meta_id = d_copy.pop("id", "") + operations.append( + {"index": {"_index": index_name, "_id": meta_id}}) + operations.append(d_copy) + res = [] + for _ in range(ATTEMPT_TIME): + try: + res = [] + r = self.es.bulk(index=index_name, operations=operations, + refresh=False, timeout="60s") + if re.search(r"False", str(r["errors"]), re.IGNORECASE): + return res + + for item in r["items"]: + for action in ["create", "delete", "index", "update"]: + if action in item and "error" in item[action]: + res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"])) + return res + except ConnectionTimeout: + self.logger.exception("ES request timeout") + time.sleep(3) + self._connect() + continue + except Exception as e: + res.append(str(e)) + self.logger.warning("ESConnection.insert got exception: " + str(e)) + + return res + + def update(self, condition: dict, new_value: dict, index_name: str, memory_id: str) -> bool: + doc = copy.deepcopy(new_value) + update_dict = {self.convert_field_name(k): v for k, v in doc.items()} + update_dict.pop("id", None) + condition_dict = {self.convert_field_name(k): v for k, v in condition.items()} + condition_dict["memory_id"] = memory_id + if "id" in condition_dict and isinstance(condition_dict["id"], str): + # update specific single document + message_id = condition_dict["id"] + for i in range(ATTEMPT_TIME): + for k in update_dict.keys(): + if "feas" != k.split("_")[-1]: + continue + try: + self.es.update(index=index_name, id=message_id, script=f"ctx._source.remove(\"{k}\");") + except Exception: + self.logger.exception(f"ESConnection.update(index={index_name}, id={message_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception") + try: + self.es.update(index=index_name, id=message_id, doc=update_dict) + return True + except Exception as e: + self.logger.exception( + f"ESConnection.update(index={index_name}, id={message_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: " + str(e)) + break + return False + + # update unspecific maybe-multiple documents + bool_query = Q("bool") + for k, v in condition_dict.items(): + if not isinstance(k, str) or not v: + continue + if k == "exists": + bool_query.filter.append(Q("exists", field=v)) + continue + if isinstance(v, list): + bool_query.filter.append(Q("terms", **{k: v})) + elif isinstance(v, str) or isinstance(v, int): + 
bool_query.filter.append(Q("term", **{k: v})) + else: + raise Exception( + f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.") + scripts = [] + params = {} + for k, v in update_dict.items(): + if k == "remove": + if isinstance(v, str): + scripts.append(f"ctx._source.remove('{v}');") + if isinstance(v, dict): + for kk, vv in v.items(): + scripts.append(f"int i=ctx._source.{kk}.indexOf(params.p_{kk});ctx._source.{kk}.remove(i);") + params[f"p_{kk}"] = vv + continue + if k == "add": + if isinstance(v, dict): + for kk, vv in v.items(): + scripts.append(f"ctx._source.{kk}.add(params.pp_{kk});") + params[f"pp_{kk}"] = vv.strip() + continue + if (not isinstance(k, str) or not v) and k != "status_int": + continue + if isinstance(v, str): + v = re.sub(r"(['\n\r]|\\.)", " ", v) + params[f"pp_{k}"] = v + scripts.append(f"ctx._source.{k}=params.pp_{k};") + elif isinstance(v, int) or isinstance(v, float): + scripts.append(f"ctx._source.{k}={v};") + elif isinstance(v, list): + scripts.append(f"ctx._source.{k}=params.pp_{k};") + params[f"pp_{k}"] = json.dumps(v, ensure_ascii=False) + else: + raise Exception( + f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.") + ubq = UpdateByQuery( + index=index_name).using( + self.es).query(bool_query) + ubq = ubq.script(source="".join(scripts), params=params) + ubq = ubq.params(refresh=True) + ubq = ubq.params(slices=5) + ubq = ubq.params(conflicts="proceed") + for _ in range(ATTEMPT_TIME): + try: + _ = ubq.execute() + return True + except ConnectionTimeout: + self.logger.exception("ES request timeout") + time.sleep(3) + self._connect() + continue + except Exception as e: + self.logger.error("ESConnection.update got exception: " + str(e) + "\n".join(scripts)) + break + return False + + def delete(self, condition: dict, index_name: str, memory_id: str) -> int: + assert "_id" not in condition + condition_dict = {self.convert_field_name(k): v for k, v in condition.items()} + condition_dict["memory_id"] = memory_id + if "id" in condition_dict: + message_ids = condition_dict["id"] + if not isinstance(message_ids, list): + message_ids = [message_ids] + if not message_ids: # when message_ids is empty, delete all + qry = Q("match_all") + else: + qry = Q("ids", values=message_ids) + else: + qry = Q("bool") + for k, v in condition_dict.items(): + if k == "exists": + qry.filter.append(Q("exists", field=v)) + + elif k == "must_not": + if isinstance(v, dict): + for kk, vv in v.items(): + if kk == "exists": + qry.must_not.append(Q("exists", field=vv)) + + elif isinstance(v, list): + qry.must.append(Q("terms", **{k: v})) + elif isinstance(v, str) or isinstance(v, int): + qry.must.append(Q("term", **{k: v})) + else: + raise Exception("Condition value must be int, str or list.") + self.logger.debug("ESConnection.delete query: " + json.dumps(qry.to_dict())) + for _ in range(ATTEMPT_TIME): + try: + res = self.es.delete_by_query( + index=index_name, + body=Search().query(qry).to_dict(), + refresh=True) + return res["deleted"] + except ConnectionTimeout: + self.logger.exception("ES request timeout") + time.sleep(3) + self._connect() + continue + except Exception as e: + self.logger.warning("ESConnection.delete got exception: " + str(e)) + if re.search(r"(not_found)", str(e), re.IGNORECASE): + return 0 + return 0 + + """ + Helper functions for search result + """ + + def get_fields(self, res, fields: list[str]) -> dict[str, dict]: + res_fields = {} + if not fields: + return {} + for doc in 
self._get_source(res): + message = self.get_message_from_es_doc(doc) + m = {} + for n, v in message.items(): + if n not in fields: + continue + if isinstance(v, list): + m[n] = v + continue + if n in ["message_id", "source_id", "valid_at", "invalid_at", "forget_at", "status"] and isinstance(v, (int, float, bool)): + m[n] = v + continue + if not isinstance(v, str): + m[n] = str(v) + else: + m[n] = v + + if m: + res_fields[doc["id"]] = m + return res_fields diff --git a/memory/utils/infinity_conn.py b/memory/utils/infinity_conn.py new file mode 100644 index 000000000..7bd351935 --- /dev/null +++ b/memory/utils/infinity_conn.py @@ -0,0 +1,467 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import re +import json +import copy +from infinity.common import InfinityException, SortType +from infinity.errors import ErrorCode + +from common.decorator import singleton +import pandas as pd +from common.constants import PAGERANK_FLD, TAG_FLD +from common.doc_store.doc_store_base import MatchExpr, MatchTextExpr, MatchDenseExpr, FusionExpr, OrderByExpr +from common.doc_store.infinity_conn_base import InfinityConnectionBase +from common.time_utils import date_string_to_timestamp + + +@singleton +class InfinityConnection(InfinityConnectionBase): + def __init__(self): + super().__init__() + self.mapping_file_name = "message_infinity_mapping.json" + + """ + Dataframe and fields convert + """ + + @staticmethod + def field_keyword(field_name: str): + # no keywords right now + return False + + @staticmethod + def convert_message_field_to_infinity(field_name: str): + match field_name: + case "message_type": + return "message_type_kwd" + case "status": + return "status_int" + case _: + return field_name + + @staticmethod + def convert_infinity_field_to_message(field_name: str): + if field_name.startswith("message_type"): + return "message_type" + if field_name.startswith("status"): + return "status" + if re.match(r"q_\d+_vec", field_name): + return "content_embed" + return field_name + + def convert_select_fields(self, output_fields: list[str]) -> list[str]: + return list({self.convert_message_field_to_infinity(f) for f in output_fields}) + + @staticmethod + def convert_matching_field(field_weight_str: str) -> str: + tokens = field_weight_str.split("^") + field = tokens[0] + if field == "content": + field = "content@ft_contentm_rag_fine" + tokens[0] = field + return "^".join(tokens) + + @staticmethod + def convert_condition_and_order_field(field_name: str): + match field_name: + case "message_type": + return "message_type_kwd" + case "status": + return "status_int" + case "valid_at": + return "valid_at_flt" + case "invalid_at": + return "invalid_at_flt" + case "forget_at": + return "forget_at_flt" + case _: + return field_name + + """ + CRUD operations + """ + + def search( + self, + select_fields: list[str], + highlight_fields: list[str], + condition: dict, + match_expressions: list[MatchExpr], + order_by: OrderByExpr, + offset: int, + limit: 
int, + index_names: str | list[str], + memory_ids: list[str], + agg_fields: list[str] | None = None, + rank_feature: dict | None = None, + hide_forgotten: bool = True, + ) -> tuple[pd.DataFrame, int]: + """ + BUG: Infinity returns empty for a highlight field if the query string doesn't use that field. + """ + if isinstance(index_names, str): + index_names = index_names.split(",") + assert isinstance(index_names, list) and len(index_names) > 0 + inf_conn = self.connPool.get_conn() + db_instance = inf_conn.get_database(self.dbName) + df_list = list() + table_list = list() + if hide_forgotten: + condition.update({"must_not": {"exists": "forget_at_flt"}}) + output = select_fields.copy() + output = self.convert_select_fields(output) + if agg_fields is None: + agg_fields = [] + for essential_field in ["id"] + agg_fields: + if essential_field not in output: + output.append(essential_field) + score_func = "" + score_column = "" + for matchExpr in match_expressions: + if isinstance(matchExpr, MatchTextExpr): + score_func = "score()" + score_column = "SCORE" + break + if not score_func: + for matchExpr in match_expressions: + if isinstance(matchExpr, MatchDenseExpr): + score_func = "similarity()" + score_column = "SIMILARITY" + break + if match_expressions: + if score_func not in output: + output.append(score_func) + if PAGERANK_FLD not in output: + output.append(PAGERANK_FLD) + output = [f for f in output if f != "_score"] + if limit <= 0: + # ElasticSearch default limit is 10000 + limit = 10000 + + # Prepare expressions common to all tables + filter_cond = None + filter_fulltext = "" + if condition: + condition_dict = {self.convert_condition_and_order_field(k): v for k, v in condition.items()} + table_found = False + for indexName in index_names: + for mem_id in memory_ids: + table_name = f"{indexName}_{mem_id}" + try: + filter_cond = self.equivalent_condition_to_str(condition_dict, db_instance.get_table(table_name)) + table_found = True + break + except Exception: + pass + if table_found: + break + if not table_found: + self.logger.error(f"No valid tables found for indexNames {index_names} and memoryIds {memory_ids}") + return pd.DataFrame(), 0 + + for matchExpr in match_expressions: + if isinstance(matchExpr, MatchTextExpr): + if filter_cond and "filter" not in matchExpr.extra_options: + matchExpr.extra_options.update({"filter": filter_cond}) + matchExpr.fields = [self.convert_matching_field(field) for field in matchExpr.fields] + fields = ",".join(matchExpr.fields) + filter_fulltext = f"filter_fulltext('{fields}', '{matchExpr.matching_text}')" + if filter_cond: + filter_fulltext = f"({filter_cond}) AND {filter_fulltext}" + minimum_should_match = matchExpr.extra_options.get("minimum_should_match", 0.0) + if isinstance(minimum_should_match, float): + str_minimum_should_match = str(int(minimum_should_match * 100)) + "%" + matchExpr.extra_options["minimum_should_match"] = str_minimum_should_match + + # Add rank_feature support + if rank_feature and "rank_features" not in matchExpr.extra_options: + # Convert rank_feature dict to Infinity's rank_features string format + # Format: "field^feature_name^weight,field^feature_name^weight" + rank_features_list = [] + for feature_name, weight in rank_feature.items(): + # Use TAG_FLD as the field containing rank features + rank_features_list.append(f"{TAG_FLD}^{feature_name}^{weight}") + if rank_features_list: + matchExpr.extra_options["rank_features"] = ",".join(rank_features_list) + + for k, v in matchExpr.extra_options.items(): + if not isinstance(v, str): 
+ matchExpr.extra_options[k] = str(v) + self.logger.debug(f"INFINITY search MatchTextExpr: {json.dumps(matchExpr.__dict__)}") + elif isinstance(matchExpr, MatchDenseExpr): + if filter_fulltext and "filter" not in matchExpr.extra_options: + matchExpr.extra_options.update({"filter": filter_fulltext}) + for k, v in matchExpr.extra_options.items(): + if not isinstance(v, str): + matchExpr.extra_options[k] = str(v) + similarity = matchExpr.extra_options.get("similarity") + if similarity: + matchExpr.extra_options["threshold"] = similarity + del matchExpr.extra_options["similarity"] + self.logger.debug(f"INFINITY search MatchDenseExpr: {json.dumps(matchExpr.__dict__)}") + elif isinstance(matchExpr, FusionExpr): + self.logger.debug(f"INFINITY search FusionExpr: {json.dumps(matchExpr.__dict__)}") + + order_by_expr_list = list() + if order_by.fields: + for order_field in order_by.fields: + order_field_name = self.convert_condition_and_order_field(order_field[0]) + if order_field[1] == 0: + order_by_expr_list.append((order_field_name, SortType.Asc)) + else: + order_by_expr_list.append((order_field_name, SortType.Desc)) + + total_hits_count = 0 + # Scatter search tables and gather the results + for indexName in index_names: + for memory_id in memory_ids: + table_name = f"{indexName}_{memory_id}" + try: + table_instance = db_instance.get_table(table_name) + except Exception: + continue + table_list.append(table_name) + builder = table_instance.output(output) + if len(match_expressions) > 0: + for matchExpr in match_expressions: + if isinstance(matchExpr, MatchTextExpr): + fields = ",".join(matchExpr.fields) + builder = builder.match_text( + fields, + matchExpr.matching_text, + matchExpr.topn, + matchExpr.extra_options.copy(), + ) + elif isinstance(matchExpr, MatchDenseExpr): + builder = builder.match_dense( + matchExpr.vector_column_name, + matchExpr.embedding_data, + matchExpr.embedding_data_type, + matchExpr.distance_type, + matchExpr.topn, + matchExpr.extra_options.copy(), + ) + elif isinstance(matchExpr, FusionExpr): + builder = builder.fusion(matchExpr.method, matchExpr.topn, matchExpr.fusion_params) + else: + if filter_cond and len(filter_cond) > 0: + builder.filter(filter_cond) + if order_by.fields: + builder.sort(order_by_expr_list) + builder.offset(offset).limit(limit) + mem_res, extra_result = builder.option({"total_hits_count": True}).to_df() + if extra_result: + total_hits_count += int(extra_result["total_hits_count"]) + self.logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(mem_res)}") + df_list.append(mem_res) + self.connPool.release_conn(inf_conn) + res = self.concat_dataframes(df_list, output) + if match_expressions: + res["_score"] = res[score_column] + res[PAGERANK_FLD] + res = res.sort_values(by="_score", ascending=False).reset_index(drop=True) + res = res.head(limit) + self.logger.debug(f"INFINITY search final result: {str(res)}") + return res, total_hits_count + + def get_forgotten_messages(self, select_fields: list[str], index_name: str, memory_id: str, limit: int=2000): + condition = {"memory_id": memory_id, "exists": "forget_at_flt"} + order_by = OrderByExpr() + order_by.asc("forget_at_flt") + # query + inf_conn = self.connPool.get_conn() + db_instance = inf_conn.get_database(self.dbName) + table_name = f"{index_name}_{memory_id}" + table_instance = db_instance.get_table(table_name) + output_fields = [self.convert_message_field_to_infinity(f) for f in select_fields] + builder = table_instance.output(output_fields) + filter_cond = 
self.equivalent_condition_to_str(condition, db_instance.get_table(table_name))
+        builder.filter(filter_cond)
+        order_by_expr_list = list()
+        if order_by.fields:
+            for order_field in order_by.fields:
+                order_field_name = self.convert_condition_and_order_field(order_field[0])
+                if order_field[1] == 0:
+                    order_by_expr_list.append((order_field_name, SortType.Asc))
+                else:
+                    order_by_expr_list.append((order_field_name, SortType.Desc))
+        builder.sort(order_by_expr_list)
+        builder.offset(0).limit(limit)
+        mem_res, _ = builder.option({"total_hits_count": True}).to_df()
+        res = self.concat_dataframes([mem_res], output_fields)
+        res = res.head(limit)
+        self.connPool.release_conn(inf_conn)
+        return res
+
+    def get(self, message_id: str, index_name: str, memory_ids: list[str]) -> dict | None:
+        inf_conn = self.connPool.get_conn()
+        db_instance = inf_conn.get_database(self.dbName)
+        df_list = list()
+        assert isinstance(memory_ids, list)
+        table_list = list()
+        for memoryId in memory_ids:
+            table_name = f"{index_name}_{memoryId}"
+            table_list.append(table_name)
+            try:
+                table_instance = db_instance.get_table(table_name)
+            except Exception:
+                self.logger.warning(f"Table not found: {table_name}, this memory wasn't created in Infinity. Maybe it was created in another document engine.")
+                continue
+            mem_res, _ = table_instance.output(["*"]).filter(f"id = '{message_id}'").to_df()
+            self.logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(mem_res)}")
+            df_list.append(mem_res)
+        self.connPool.release_conn(inf_conn)
+        res = self.concat_dataframes(df_list, ["id"])
+        fields = set(res.columns.tolist())
+        res_fields = self.get_fields(res, list(fields))
+        return res_fields.get(message_id, None)
+
+    def insert(self, documents: list[dict], index_name: str, memory_id: str = None) -> list[str]:
+        if not documents:
+            return []
+        inf_conn = self.connPool.get_conn()
+        db_instance = inf_conn.get_database(self.dbName)
+        table_name = f"{index_name}_{memory_id}"
+        vector_size = int(len(documents[0]["content_embed"]))
+        try:
+            table_instance = db_instance.get_table(table_name)
+        except InfinityException as e:
+            # src/common/status.cppm, kTableNotExist = 3022
+            if e.error_code != ErrorCode.TABLE_NOT_EXIST:
+                raise
+            if vector_size == 0:
+                raise ValueError("Cannot infer vector size from documents")
+            self.create_idx(index_name, memory_id, vector_size)
+            table_instance = db_instance.get_table(table_name)
+
+        # embedding fields can't have a default value....
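+        # collect (column_name, dimension) for every embedding column so missing vectors can be zero-filled before insert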
+ embedding_columns = [] + table_columns = table_instance.show_columns().rows() + for n, ty, _, _ in table_columns: + r = re.search(r"Embedding\([a-z]+,([0-9]+)\)", ty) + if not r: + continue + embedding_columns.append((n, int(r.group(1)))) + + docs = copy.deepcopy(documents) + for d in docs: + assert "_id" not in d + assert "id" in d + for k, v in list(d.items()): + field_name = self.convert_message_field_to_infinity(k) + if field_name in ["valid_at", "invalid_at", "forget_at"]: + d[f"{field_name}_flt"] = date_string_to_timestamp(v) if v else 0 + if v is None: + d[field_name] = "" + elif self.field_keyword(k): + if isinstance(v, list): + d[k] = "###".join(v) + else: + d[k] = v + elif k == "memory_id": + if isinstance(d[k], list): + d[k] = d[k][0] # since d[k] is a list, but we need a str + elif field_name == "content_embed": + d[f"q_{vector_size}_vec"] = d["content_embed"] + d.pop("content_embed") + else: + d[field_name] = v + if k != field_name: + d.pop(k) + + for n, vs in embedding_columns: + if n in d: + continue + d[n] = [0] * vs + ids = ["'{}'".format(d["id"]) for d in docs] + str_ids = ", ".join(ids) + str_filter = f"id IN ({str_ids})" + table_instance.delete(str_filter) + table_instance.insert(docs) + self.connPool.release_conn(inf_conn) + self.logger.debug(f"INFINITY inserted into {table_name} {str_ids}.") + return [] + + def update(self, condition: dict, new_value: dict, index_name: str, memory_id: str) -> bool: + inf_conn = self.connPool.get_conn() + db_instance = inf_conn.get_database(self.dbName) + table_name = f"{index_name}_{memory_id}" + table_instance = db_instance.get_table(table_name) + + columns = {} + if table_instance: + for n, ty, de, _ in table_instance.show_columns().rows(): + columns[n] = (ty, de) + condition_dict = {self.convert_condition_and_order_field(k): v for k, v in condition.items()} + filter = self.equivalent_condition_to_str(condition_dict, table_instance) + update_dict = {self.convert_message_field_to_infinity(k): v for k, v in new_value.items()} + date_floats = {} + for k, v in update_dict.items(): + if k in ["valid_at", "invalid_at", "forget_at"]: + date_floats[f"{k}_flt"] = date_string_to_timestamp(v) if v else 0 + elif self.field_keyword(k): + if isinstance(v, list): + update_dict[k] = "###".join(v) + else: + update_dict[k] = v + elif k == "memory_id": + if isinstance(update_dict[k], list): + update_dict[k] = update_dict[k][0] # since d[k] is a list, but we need a str + else: + update_dict[k] = v + if date_floats: + update_dict.update(date_floats) + + self.logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {new_value}.") + table_instance.update(filter, update_dict) + self.connPool.release_conn(inf_conn) + return True + + """ + Helper functions for search result + """ + + def get_fields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]: + if isinstance(res, tuple): + res = res[0] + if not fields: + return {} + fields_all = fields.copy() + fields_all.append("id") + fields_all = {self.convert_message_field_to_infinity(f) for f in fields_all} + + column_map = {col.lower(): col for col in res.columns} + matched_columns = {column_map[col.lower()]: col for col in fields_all if col.lower() in column_map} + none_columns = [col for col in fields_all if col.lower() not in column_map] + + res2 = res[matched_columns.keys()] + res2 = res2.rename(columns=matched_columns) + res2.drop_duplicates(subset=["id"], inplace=True) + + for column in list(res2.columns): + k = column.lower() + if 
self.field_keyword(k): + res2[column] = res2[column].apply(lambda v: [kwd for kwd in v.split("###") if kwd]) + else: + pass + for column in ["content"]: + if column in res2: + del res2[column] + for column in none_columns: + res2[column] = None + + res_dict = res2.set_index("id").to_dict(orient="index") + return {_id: {self.convert_infinity_field_to_message(k): v for k, v in doc.items()} for _id, doc in res_dict.items()} diff --git a/memory/utils/msg_util.py b/memory/utils/msg_util.py new file mode 100644 index 000000000..71a9f6b76 --- /dev/null +++ b/memory/utils/msg_util.py @@ -0,0 +1,37 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import json + + +def get_json_result_from_llm_response(response_str: str) -> dict: + """ + Parse the LLM response string to extract JSON content. + The function looks for the first and last curly braces to identify the JSON part. + If parsing fails, it returns an empty dictionary. + + :param response_str: The response string from the LLM. + :return: A dictionary parsed from the JSON content in the response. + """ + try: + clean_str = response_str.strip() + if clean_str.startswith('```json'): + clean_str = clean_str[7:] # Remove the starting ```json + if clean_str.endswith('```'): + clean_str = clean_str[:-3] # Remove the ending ``` + + return json.loads(clean_str.strip()) + except (ValueError, json.JSONDecodeError): + return {} diff --git a/memory/utils/prompt_util.py b/memory/utils/prompt_util.py new file mode 100644 index 000000000..e46e1be6a --- /dev/null +++ b/memory/utils/prompt_util.py @@ -0,0 +1,201 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import Optional, List + +from common.constants import MemoryType +from common.time_utils import current_timestamp + +class PromptAssembler: + + SYSTEM_BASE_TEMPLATE = """**Memory Extraction Specialist** +You are an expert at analyzing conversations to extract structured memory. + +{type_specific_instructions} + + +**OUTPUT REQUIREMENTS:** +1. Output MUST be valid JSON +2. Follow the specified output format exactly +3. Each extracted item MUST have: content, valid_at, invalid_at +4. Timestamps in {timestamp_format} format +5. Only extract memory types specified above +6. 
Maximum {max_items} items per type +""" + + TYPE_INSTRUCTIONS = { + MemoryType.SEMANTIC.name.lower(): """ + **EXTRACT SEMANTIC KNOWLEDGE:** + - Universal facts, definitions, concepts, relationships + - Time-invariant, generally true information + - Examples: "The capital of France is Paris", "Water boils at 100°C" + + **Timestamp Rules for Semantic Knowledge:** + - valid_at: When the fact became true (e.g., law enactment, discovery) + - invalid_at: When it becomes false (e.g., repeal, disproven) or empty if still true + - Default: valid_at = conversation time, invalid_at = "" for timeless facts + """, + + MemoryType.EPISODIC.name.lower(): """ + **EXTRACT EPISODIC KNOWLEDGE:** + - Specific experiences, events, personal stories + - Time-bound, person-specific, contextual + - Examples: "Yesterday I fixed the bug", "User reported issue last week" + + **Timestamp Rules for Episodic Knowledge:** + - valid_at: Event start/occurrence time + - invalid_at: Event end time or empty if instantaneous + - Extract explicit times: "at 3 PM", "last Monday", "from X to Y" + """, + + MemoryType.PROCEDURAL.name.lower(): """ + **EXTRACT PROCEDURAL KNOWLEDGE:** + - Processes, methods, step-by-step instructions + - Goal-oriented, actionable, often includes conditions + - Examples: "To reset password, click...", "Debugging steps: 1)..." + + **Timestamp Rules for Procedural Knowledge:** + - valid_at: When procedure becomes valid/effective + - invalid_at: When it expires/becomes obsolete or empty if current + - For version-specific: use release dates + - For best practices: invalid_at = "" + """ + } + + OUTPUT_TEMPLATES = { + MemoryType.SEMANTIC.name.lower(): """ + "semantic": [ + { + "content": "Clear factual statement", + "valid_at": "timestamp or empty", + "invalid_at": "timestamp or empty" + } + ] + """, + + MemoryType.EPISODIC.name.lower(): """ + "episodic": [ + { + "content": "Narrative event description", + "valid_at": "event start timestamp", + "invalid_at": "event end timestamp or empty" + } + ] + """, + + MemoryType.PROCEDURAL.name.lower(): """ + "procedural": [ + { + "content": "Actionable instructions", + "valid_at": "procedure effective timestamp", + "invalid_at": "procedure expiration timestamp or empty" + } + ] + """ + } + + BASE_USER_PROMPT = """ +**CONVERSATION:** +{conversation} + +**CONVERSATION TIME:** {conversation_time} +**CURRENT TIME:** {current_time} +""" + + @classmethod + def assemble_system_prompt(cls, config: dict) -> str: + types_to_extract = cls._get_types_to_extract(config["memory_type"]) + + type_instructions = cls._generate_type_instructions(types_to_extract) + + output_format = cls._generate_output_format(types_to_extract) + + full_prompt = cls.SYSTEM_BASE_TEMPLATE.format( + type_specific_instructions=type_instructions, + timestamp_format=config.get("timestamp_format", "ISO 8601"), + max_items=config.get("max_items_per_type", 5) + ) + + full_prompt += f"\n**REQUIRED OUTPUT FORMAT (JSON):**\n```json\n{{\n{output_format}\n}}\n```\n" + + examples = cls._generate_examples(types_to_extract) + if examples: + full_prompt += f"\n**EXAMPLES:**\n{examples}\n" + + return full_prompt + + @staticmethod + def _get_types_to_extract(requested_types: List[str]) -> List[str]: + types = set() + for rt in requested_types: + if rt in [e.name.lower() for e in MemoryType] and rt != MemoryType.RAW.name.lower(): + types.add(rt) + return list(types) + + @classmethod + def _generate_type_instructions(cls, types_to_extract: List[str]) -> str: + target_types = set(types_to_extract) + instructions = 
[cls.TYPE_INSTRUCTIONS[mt] for mt in target_types] + return "\n".join(instructions) + + @classmethod + def _generate_output_format(cls, types_to_extract: List[str]) -> str: + target_types = set(types_to_extract) + output_parts = [cls.OUTPUT_TEMPLATES[mt] for mt in target_types] + return ",\n".join(output_parts) + + @staticmethod + def _generate_examples(types_to_extract: list[str]) -> str: + examples = [] + + if MemoryType.SEMANTIC.name.lower() in types_to_extract: + examples.append(""" + **Semantic Example:** + Input: "Python lists are mutable and support various operations." + Output: {"semantic": [{"content": "Python lists are mutable data structures", "valid_at": "2024-01-15T10:00:00", "invalid_at": ""}]} + """) + + if MemoryType.EPISODIC.name.lower() in types_to_extract: + examples.append(""" + **Episodic Example:** + Input: "I deployed the new feature yesterday afternoon." + Output: {"episodic": [{"content": "User deployed new feature", "valid_at": "2024-01-14T14:00:00", "invalid_at": "2024-01-14T18:00:00"}]} + """) + + if MemoryType.PROCEDURAL.name.lower() in types_to_extract: + examples.append(""" + **Procedural Example:** + Input: "To debug API errors: 1) Check logs 2) Verify endpoints 3) Test connectivity." + Output: {"procedural": [{"content": "API error debugging: 1. Check logs 2. Verify endpoints 3. Test connectivity", "valid_at": "2024-01-15T10:00:00", "invalid_at": ""}]} + """) + + return "\n".join(examples) + + @classmethod + def assemble_user_prompt( + cls, + conversation: str, + conversation_time: Optional[str] = None, + current_time: Optional[str] = None + ) -> str: + return cls.BASE_USER_PROMPT.format( + conversation=conversation, + conversation_time=conversation_time or "Not specified", + current_time=current_time or current_timestamp(), + ) + + @classmethod + def get_raw_user_prompt(cls): + return cls.BASE_USER_PROMPT diff --git a/rag/benchmark.py b/rag/benchmark.py index 100fc32d5..c19785db3 100644 --- a/rag/benchmark.py +++ b/rag/benchmark.py @@ -77,9 +77,9 @@ class Benchmark: def init_index(self, vector_size: int): if self.initialized_index: return - if settings.docStoreConn.indexExist(self.index_name, self.kb_id): - settings.docStoreConn.deleteIdx(self.index_name, self.kb_id) - settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size) + if settings.docStoreConn.index_exist(self.index_name, self.kb_id): + settings.docStoreConn.delete_idx(self.index_name, self.kb_id) + settings.docStoreConn.create_idx(self.index_name, self.kb_id, vector_size) self.initialized_index = True def ms_marco_index(self, file_path, index_name): diff --git a/rag/nlp/query.py b/rag/nlp/query.py index 2c203f521..1cb2f4071 100644 --- a/rag/nlp/query.py +++ b/rag/nlp/query.py @@ -19,11 +19,12 @@ import json import re from collections import defaultdict -from rag.utils.doc_store_conn import MatchTextExpr +from common.query_base import QueryBase +from common.doc_store.doc_store_base import MatchTextExpr from rag.nlp import rag_tokenizer, term_weight, synonym -class FulltextQueryer: +class FulltextQueryer(QueryBase): def __init__(self): self.tw = term_weight.Dealer() self.syn = synonym.Dealer() @@ -37,64 +38,19 @@ class FulltextQueryer: "content_sm_ltks", ] - @staticmethod - def sub_special_char(line): - return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|\+~\^])", r"\\\1", line).strip() - - @staticmethod - def is_chinese(line): - arr = re.split(r"[ \t]+", line) - if len(arr) <= 3: - return True - e = 0 - for t in arr: - if not re.match(r"[a-zA-Z]+$", t): - e += 1 - return e * 1.0 / len(arr) >= 
0.7 - - @staticmethod - def rmWWW(txt): - patts = [ - ( - r"是*(怎么办|什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀|谁|哪位|哪个)是*", - "", - ), - (r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "), - ( - r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ", - " ") - ] - otxt = txt - for r, p in patts: - txt = re.sub(r, p, txt, flags=re.IGNORECASE) - if not txt: - txt = otxt - return txt - - @staticmethod - def add_space_between_eng_zh(txt): - # (ENG/ENG+NUM) + ZH - txt = re.sub(r'([A-Za-z]+[0-9]+)([\u4e00-\u9fa5]+)', r'\1 \2', txt) - # ENG + ZH - txt = re.sub(r'([A-Za-z])([\u4e00-\u9fa5]+)', r'\1 \2', txt) - # ZH + (ENG/ENG+NUM) - txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z]+[0-9]+)', r'\1 \2', txt) - txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z])', r'\1 \2', txt) - return txt - def question(self, txt, tbl="qa", min_match: float = 0.6): original_query = txt - txt = FulltextQueryer.add_space_between_eng_zh(txt) + txt = self.add_space_between_eng_zh(txt) txt = re.sub( r"[ :|\r\n\t,,。??/`!!&^%%()\[\]{}<>]+", " ", rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(txt.lower())), ).strip() otxt = txt - txt = FulltextQueryer.rmWWW(txt) + txt = self.rmWWW(txt) if not self.is_chinese(txt): - txt = FulltextQueryer.rmWWW(txt) + txt = self.rmWWW(txt) tks = rag_tokenizer.tokenize(txt).split() keywords = [t for t in tks if t] tks_w = self.tw.weights(tks, preprocess=False) @@ -138,7 +94,7 @@ class FulltextQueryer: return False return True - txt = FulltextQueryer.rmWWW(txt) + txt = self.rmWWW(txt) qs, keywords = [], [] for tt in self.tw.split(txt)[:256]: # .split(): if not tt: @@ -164,7 +120,7 @@ class FulltextQueryer: ) for m in sm ] - sm = [FulltextQueryer.sub_special_char(m) for m in sm if len(m) > 1] + sm = [self.sub_special_char(m) for m in sm if len(m) > 1] sm = [m for m in sm if len(m) > 1] if len(keywords) < 32: @@ -172,7 +128,7 @@ class FulltextQueryer: keywords.extend(sm) tk_syns = self.syn.lookup(tk) - tk_syns = [FulltextQueryer.sub_special_char(s) for s in tk_syns] + tk_syns = [self.sub_special_char(s) for s in tk_syns] if len(keywords) < 32: keywords.extend([s for s in tk_syns if s]) tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s] @@ -181,7 +137,7 @@ class FulltextQueryer: if len(keywords) >= 32: break - tk = FulltextQueryer.sub_special_char(tk) + tk = self.sub_special_char(tk) if tk.find(" ") > 0: tk = '"%s"' % tk if tk_syns: @@ -199,7 +155,7 @@ class FulltextQueryer: syns = " OR ".join( [ '"%s"' - % rag_tokenizer.tokenize(FulltextQueryer.sub_special_char(s)) + % rag_tokenizer.tokenize(self.sub_special_char(s)) for s in syns ] ) @@ -264,10 +220,10 @@ class FulltextQueryer: keywords = [f'"{k.strip()}"' for k in keywords] for tk, w in sorted(tks_w, key=lambda x: x[1] * -1)[:keywords_topn]: tk_syns = self.syn.lookup(tk) - tk_syns = [FulltextQueryer.sub_special_char(s) for s in tk_syns] + tk_syns = [self.sub_special_char(s) for s in tk_syns] tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s] tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns] - tk = FulltextQueryer.sub_special_char(tk) + tk = self.sub_special_char(tk) if tk.find(" ") > 0: tk = '"%s"' % tk if tk_syns: diff --git a/rag/nlp/search.py b/rag/nlp/search.py index d2129e77f..988887e55 100644 --- a/rag/nlp/search.py +++ 
b/rag/nlp/search.py @@ -24,7 +24,7 @@ from dataclasses import dataclass from rag.prompts.generator import relevant_chunks_with_toc from rag.nlp import rag_tokenizer, query import numpy as np -from rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr, FusionExpr, OrderByExpr +from common.doc_store.doc_store_base import MatchDenseExpr, FusionExpr, OrderByExpr, DocStoreConnection from common.string_utils import remove_redundant_spaces from common.float_utils import get_float from common.constants import PAGERANK_FLD, TAG_FLD @@ -155,7 +155,7 @@ class Dealer: kwds.add(kk) logging.debug(f"TOTAL: {total}") - ids = self.dataStore.get_chunk_ids(res) + ids = self.dataStore.get_doc_ids(res) keywords = list(kwds) highlight = self.dataStore.get_highlight(res, keywords, "content_with_weight") aggs = self.dataStore.get_aggregation(res, "docnm_kwd") @@ -545,7 +545,7 @@ class Dealer: return res def all_tags(self, tenant_id: str, kb_ids: list[str], S=1000): - if not self.dataStore.indexExist(index_name(tenant_id), kb_ids[0]): + if not self.dataStore.index_exist(index_name(tenant_id), kb_ids[0]): return [] res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"]) return self.dataStore.get_aggregation(res, "tag_kwd") diff --git a/rag/prompts/generator.py b/rag/prompts/generator.py index 5f93f112a..d11cc2ce5 100644 --- a/rag/prompts/generator.py +++ b/rag/prompts/generator.py @@ -136,6 +136,19 @@ def kb_prompt(kbinfos, max_tokens, hash_id=False): return knowledges +def memory_prompt(message_list, max_tokens): + used_token_count = 0 + content_list = [] + for message in message_list: + current_content_tokens = num_tokens_from_string(message["content"]) + if used_token_count + current_content_tokens > max_tokens * 0.97: + logging.warning(f"Not all the retrieval into prompt: {len(content_list)}/{len(message_list)}") + break + content_list.append(message["content"]) + used_token_count += current_content_tokens + return content_list + + CITATION_PROMPT_TEMPLATE = load_prompt("citation_prompt") CITATION_PLUS_TEMPLATE = load_prompt("citation_plus") CONTENT_TAGGING_PROMPT_TEMPLATE = load_prompt("content_tagging_prompt") diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 68817a66b..84fdf968f 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -506,7 +506,7 @@ def build_TOC(task, docs, progress_callback): def init_kb(row, vector_size: int): idxnm = search.index_name(row["tenant_id"]) - return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size) + return settings.docStoreConn.create_idx(idxnm, row.get("kb_id", ""), vector_size) async def embedding(docs, mdl, parser_config=None, callback=None): diff --git a/rag/utils/es_conn.py b/rag/utils/es_conn.py index cca3fc7c7..c991ac2a8 100644 --- a/rag/utils/es_conn.py +++ b/rag/utils/es_conn.py @@ -14,194 +14,92 @@ # limitations under the License. 
# -import logging import re import json import time -import os import copy -from elasticsearch import Elasticsearch, NotFoundError -from elasticsearch_dsl import UpdateByQuery, Q, Search, Index +from elasticsearch_dsl import UpdateByQuery, Q, Search from elastic_transport import ConnectionTimeout from common.decorator import singleton -from common.file_utils import get_project_base_directory -from common.misc_utils import convert_bytes -from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \ - FusionExpr -from rag.nlp import is_english, rag_tokenizer +from common.doc_store.doc_store_base import MatchTextExpr, OrderByExpr, MatchExpr, MatchDenseExpr, FusionExpr +from common.doc_store.es_conn_base import ESConnectionBase from common.float_utils import get_float -from common import settings from common.constants import PAGERANK_FLD, TAG_FLD ATTEMPT_TIME = 2 -logger = logging.getLogger('ragflow.es_conn') - @singleton -class ESConnection(DocStoreConnection): - def __init__(self): - self.info = {} - logger.info(f"Use Elasticsearch {settings.ES['hosts']} as the doc engine.") - for _ in range(ATTEMPT_TIME): - try: - if self._connect(): - break - except Exception as e: - logger.warning(f"{str(e)}. Waiting Elasticsearch {settings.ES['hosts']} to be healthy.") - time.sleep(5) - - if not self.es.ping(): - msg = f"Elasticsearch {settings.ES['hosts']} is unhealthy in 120s." - logger.error(msg) - raise Exception(msg) - v = self.info.get("version", {"number": "8.11.3"}) - v = v["number"].split(".")[0] - if int(v) < 8: - msg = f"Elasticsearch version must be greater than or equal to 8, current version: {v}" - logger.error(msg) - raise Exception(msg) - fp_mapping = os.path.join(get_project_base_directory(), "conf", "mapping.json") - if not os.path.exists(fp_mapping): - msg = f"Elasticsearch mapping file not found at {fp_mapping}" - logger.error(msg) - raise Exception(msg) - self.mapping = json.load(open(fp_mapping, "r")) - logger.info(f"Elasticsearch {settings.ES['hosts']} is healthy.") - - def _connect(self): - self.es = Elasticsearch( - settings.ES["hosts"].split(","), - basic_auth=(settings.ES["username"], settings.ES[ - "password"]) if "username" in settings.ES and "password" in settings.ES else None, - verify_certs= settings.ES.get("verify_certs", False), - timeout=600 ) - if self.es: - self.info = self.es.info() - return True - return False - - """ - Database operations - """ - - def dbType(self) -> str: - return "elasticsearch" - - def health(self) -> dict: - health_dict = dict(self.es.cluster.health()) - health_dict["type"] = "elasticsearch" - return health_dict - - """ - Table operations - """ - - def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int): - if self.indexExist(indexName, knowledgebaseId): - return True - try: - from elasticsearch.client import IndicesClient - return IndicesClient(self.es).create(index=indexName, - settings=self.mapping["settings"], - mappings=self.mapping["mappings"]) - except Exception: - logger.exception("ESConnection.createIndex error %s" % (indexName)) - - def deleteIdx(self, indexName: str, knowledgebaseId: str): - if len(knowledgebaseId) > 0: - # The index need to be alive after any kb deletion since all kb under this tenant are in one index. 
- return - try: - self.es.indices.delete(index=indexName, allow_no_indices=True) - except NotFoundError: - pass - except Exception: - logger.exception("ESConnection.deleteIdx error %s" % (indexName)) - - def indexExist(self, indexName: str, knowledgebaseId: str = None) -> bool: - s = Index(indexName, self.es) - for i in range(ATTEMPT_TIME): - try: - return s.exists() - except ConnectionTimeout: - logger.exception("ES request timeout") - time.sleep(3) - self._connect() - continue - except Exception as e: - logger.exception(e) - break - return False +class ESConnection(ESConnectionBase): """ CRUD operations """ def search( - self, selectFields: list[str], - highlightFields: list[str], + self, select_fields: list[str], + highlight_fields: list[str], condition: dict, - matchExprs: list[MatchExpr], - orderBy: OrderByExpr, + match_expressions: list[MatchExpr], + order_by: OrderByExpr, offset: int, limit: int, - indexNames: str | list[str], - knowledgebaseIds: list[str], - aggFields: list[str] = [], + index_names: str | list[str], + knowledgebase_ids: list[str], + agg_fields: list[str] | None = None, rank_feature: dict | None = None ): """ Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html """ - if isinstance(indexNames, str): - indexNames = indexNames.split(",") - assert isinstance(indexNames, list) and len(indexNames) > 0 + if isinstance(index_names, str): + index_names = index_names.split(",") + assert isinstance(index_names, list) and len(index_names) > 0 assert "_id" not in condition - bqry = Q("bool", must=[]) - condition["kb_id"] = knowledgebaseIds + bool_query = Q("bool", must=[]) + condition["kb_id"] = knowledgebase_ids for k, v in condition.items(): if k == "available_int": if v == 0: - bqry.filter.append(Q("range", available_int={"lt": 1})) + bool_query.filter.append(Q("range", available_int={"lt": 1})) else: - bqry.filter.append( + bool_query.filter.append( Q("bool", must_not=Q("range", available_int={"lt": 1}))) continue if not v: continue if isinstance(v, list): - bqry.filter.append(Q("terms", **{k: v})) + bool_query.filter.append(Q("terms", **{k: v})) elif isinstance(v, str) or isinstance(v, int): - bqry.filter.append(Q("term", **{k: v})) + bool_query.filter.append(Q("term", **{k: v})) else: raise Exception( f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.") s = Search() vector_similarity_weight = 0.5 - for m in matchExprs: + for m in match_expressions: if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params: - assert len(matchExprs) == 3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance(matchExprs[1], - MatchDenseExpr) and isinstance( - matchExprs[2], FusionExpr) + assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance(match_expressions[1], + MatchDenseExpr) and isinstance( + match_expressions[2], FusionExpr) weights = m.fusion_params["weights"] vector_similarity_weight = get_float(weights.split(",")[1]) - for m in matchExprs: + for m in match_expressions: if isinstance(m, MatchTextExpr): minimum_should_match = m.extra_options.get("minimum_should_match", 0.0) if isinstance(minimum_should_match, float): minimum_should_match = str(int(minimum_should_match * 100)) + "%" - bqry.must.append(Q("query_string", fields=m.fields, + bool_query.must.append(Q("query_string", fields=m.fields, type="best_fields", query=m.matching_text, minimum_should_match=minimum_should_match, boost=1)) - bqry.boost = 1.0 - 
vector_similarity_weight + bool_query.boost = 1.0 - vector_similarity_weight elif isinstance(m, MatchDenseExpr): - assert (bqry is not None) + assert (bool_query is not None) similarity = 0.0 if "similarity" in m.extra_options: similarity = m.extra_options["similarity"] @@ -209,24 +107,24 @@ class ESConnection(DocStoreConnection): m.topn, m.topn * 2, query_vector=list(m.embedding_data), - filter=bqry.to_dict(), + filter=bool_query.to_dict(), similarity=similarity, ) - if bqry and rank_feature: + if bool_query and rank_feature: for fld, sc in rank_feature.items(): if fld != PAGERANK_FLD: fld = f"{TAG_FLD}.{fld}" - bqry.should.append(Q("rank_feature", field=fld, linear={}, boost=sc)) + bool_query.should.append(Q("rank_feature", field=fld, linear={}, boost=sc)) - if bqry: - s = s.query(bqry) - for field in highlightFields: + if bool_query: + s = s.query(bool_query) + for field in highlight_fields: s = s.highlight(field) - if orderBy: + if order_by: orders = list() - for field, order in orderBy.fields: + for field, order in order_by.fields: order = "asc" if order == 0 else "desc" if field in ["page_num_int", "top_int"]: order_info = {"order": order, "unmapped_type": "float", @@ -237,19 +135,19 @@ class ESConnection(DocStoreConnection): order_info = {"order": order, "unmapped_type": "text"} orders.append({field: order_info}) s = s.sort(*orders) - - for fld in aggFields: - s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000) + if agg_fields: + for fld in agg_fields: + s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000) if limit > 0: s = s[offset:offset + limit] q = s.to_dict() - logger.debug(f"ESConnection.search {str(indexNames)} query: " + json.dumps(q)) + self.logger.debug(f"ESConnection.search {str(index_names)} query: " + json.dumps(q)) for i in range(ATTEMPT_TIME): try: #print(json.dumps(q, ensure_ascii=False)) - res = self.es.search(index=indexNames, + res = self.es.search(index=index_names, body=q, timeout="600s", # search_type="dfs_query_then_fetch", @@ -257,55 +155,37 @@ class ESConnection(DocStoreConnection): _source=True) if str(res.get("timed_out", "")).lower() == "true": raise Exception("Es Timeout.") - logger.debug(f"ESConnection.search {str(indexNames)} res: " + str(res)) + self.logger.debug(f"ESConnection.search {str(index_names)} res: " + str(res)) return res except ConnectionTimeout: - logger.exception("ES request timeout") + self.logger.exception("ES request timeout") self._connect() continue except Exception as e: - logger.exception(f"ESConnection.search {str(indexNames)} query: " + str(q) + str(e)) + self.logger.exception(f"ESConnection.search {str(index_names)} query: " + str(q) + str(e)) raise e - logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!") + self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!") raise Exception("ESConnection.search timeout.") - def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None: - for i in range(ATTEMPT_TIME): - try: - res = self.es.get(index=(indexName), - id=chunkId, source=True, ) - if str(res.get("timed_out", "")).lower() == "true": - raise Exception("Es Timeout.") - chunk = res["_source"] - chunk["id"] = chunkId - return chunk - except NotFoundError: - return None - except Exception as e: - logger.exception(f"ESConnection.get({chunkId}) got exception") - raise e - logger.error(f"ESConnection.get timeout for {ATTEMPT_TIME} times!") - raise Exception("ESConnection.get timeout.") - - def insert(self, documents: list[dict], indexName: str, 
knowledgebaseId: str = None) -> list[str]: + def insert(self, documents: list[dict], index_name: str, knowledgebase_id: str = None) -> list[str]: # Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html operations = [] for d in documents: assert "_id" not in d assert "id" in d d_copy = copy.deepcopy(d) - d_copy["kb_id"] = knowledgebaseId + d_copy["kb_id"] = knowledgebase_id meta_id = d_copy.pop("id", "") operations.append( - {"index": {"_index": indexName, "_id": meta_id}}) + {"index": {"_index": index_name, "_id": meta_id}}) operations.append(d_copy) res = [] for _ in range(ATTEMPT_TIME): try: res = [] - r = self.es.bulk(index=(indexName), operations=operations, + r = self.es.bulk(index=index_name, operations=operations, refresh=False, timeout="60s") if re.search(r"False", str(r["errors"]), re.IGNORECASE): return res @@ -316,58 +196,58 @@ class ESConnection(DocStoreConnection): res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"])) return res except ConnectionTimeout: - logger.exception("ES request timeout") + self.logger.exception("ES request timeout") time.sleep(3) self._connect() continue except Exception as e: res.append(str(e)) - logger.warning("ESConnection.insert got exception: " + str(e)) + self.logger.warning("ESConnection.insert got exception: " + str(e)) return res - def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool: - doc = copy.deepcopy(newValue) + def update(self, condition: dict, new_value: dict, index_name: str, knowledgebase_id: str) -> bool: + doc = copy.deepcopy(new_value) doc.pop("id", None) - condition["kb_id"] = knowledgebaseId + condition["kb_id"] = knowledgebase_id if "id" in condition and isinstance(condition["id"], str): # update specific single document - chunkId = condition["id"] + chunk_id = condition["id"] for i in range(ATTEMPT_TIME): for k in doc.keys(): if "feas" != k.split("_")[-1]: continue try: - self.es.update(index=indexName, id=chunkId, script=f"ctx._source.remove(\"{k}\");") + self.es.update(index=index_name, id=chunk_id, script=f"ctx._source.remove(\"{k}\");") except Exception: - logger.exception(f"ESConnection.update(index={indexName}, id={chunkId}, doc={json.dumps(condition, ensure_ascii=False)}) got exception") + self.logger.exception(f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception") try: - self.es.update(index=indexName, id=chunkId, doc=doc) + self.es.update(index=index_name, id=chunk_id, doc=doc) return True except Exception as e: - logger.exception( - f"ESConnection.update(index={indexName}, id={chunkId}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: "+str(e)) + self.logger.exception( + f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: " + str(e)) break return False # update unspecific maybe-multiple documents - bqry = Q("bool") + bool_query = Q("bool") for k, v in condition.items(): if not isinstance(k, str) or not v: continue if k == "exists": - bqry.filter.append(Q("exists", field=v)) + bool_query.filter.append(Q("exists", field=v)) continue if isinstance(v, list): - bqry.filter.append(Q("terms", **{k: v})) + bool_query.filter.append(Q("terms", **{k: v})) elif isinstance(v, str) or isinstance(v, int): - bqry.filter.append(Q("term", **{k: v})) + bool_query.filter.append(Q("term", **{k: v})) else: raise Exception( f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to 
be int, str or list.") scripts = [] params = {} - for k, v in newValue.items(): + for k, v in new_value.items(): if k == "remove": if isinstance(v, str): scripts.append(f"ctx._source.remove('{v}');") @@ -397,8 +277,8 @@ class ESConnection(DocStoreConnection): raise Exception( f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.") ubq = UpdateByQuery( - index=indexName).using( - self.es).query(bqry) + index=index_name).using( + self.es).query(bool_query) ubq = ubq.script(source="".join(scripts), params=params) ubq = ubq.params(refresh=True) ubq = ubq.params(slices=5) @@ -409,19 +289,18 @@ class ESConnection(DocStoreConnection): _ = ubq.execute() return True except ConnectionTimeout: - logger.exception("ES request timeout") + self.logger.exception("ES request timeout") time.sleep(3) self._connect() continue except Exception as e: - logger.error("ESConnection.update got exception: " + str(e) + "\n".join(scripts)) + self.logger.error("ESConnection.update got exception: " + str(e) + "\n".join(scripts)) break return False - def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int: - qry = None + def delete(self, condition: dict, index_name: str, knowledgebase_id: str) -> int: assert "_id" not in condition - condition["kb_id"] = knowledgebaseId + condition["kb_id"] = knowledgebase_id if "id" in condition: chunk_ids = condition["id"] if not isinstance(chunk_ids, list): @@ -448,21 +327,21 @@ class ESConnection(DocStoreConnection): qry.must.append(Q("term", **{k: v})) else: raise Exception("Condition value must be int, str or list.") - logger.debug("ESConnection.delete query: " + json.dumps(qry.to_dict())) + self.logger.debug("ESConnection.delete query: " + json.dumps(qry.to_dict())) for _ in range(ATTEMPT_TIME): try: res = self.es.delete_by_query( - index=indexName, + index=index_name, body=Search().query(qry).to_dict(), refresh=True) return res["deleted"] except ConnectionTimeout: - logger.exception("ES request timeout") + self.logger.exception("ES request timeout") time.sleep(3) self._connect() continue except Exception as e: - logger.warning("ESConnection.delete got exception: " + str(e)) + self.logger.warning("ESConnection.delete got exception: " + str(e)) if re.search(r"(not_found)", str(e), re.IGNORECASE): return 0 return 0 @@ -471,27 +350,11 @@ class ESConnection(DocStoreConnection): Helper functions for search result """ - def get_total(self, res): - if isinstance(res["hits"]["total"], type({})): - return res["hits"]["total"]["value"] - return res["hits"]["total"] - - def get_chunk_ids(self, res): - return [d["_id"] for d in res["hits"]["hits"]] - - def __getSource(self, res): - rr = [] - for d in res["hits"]["hits"]: - d["_source"]["id"] = d["_id"] - d["_source"]["_score"] = d["_score"] - rr.append(d["_source"]) - return rr - def get_fields(self, res, fields: list[str]) -> dict[str, dict]: res_fields = {} if not fields: return {} - for d in self.__getSource(res): + for d in self._get_source(res): m = {n: d.get(n) for n in fields if d.get(n) is not None} for n, v in m.items(): if isinstance(v, list): @@ -508,124 +371,3 @@ class ESConnection(DocStoreConnection): if m: res_fields[d["id"]] = m return res_fields - - def get_highlight(self, res, keywords: list[str], fieldnm: str): - ans = {} - for d in res["hits"]["hits"]: - hlts = d.get("highlight") - if not hlts: - continue - txt = "...".join([a for a in list(hlts.items())[0][1]]) - if not is_english(txt.split()): - ans[d["_id"]] = txt - continue - - txt = d["_source"][fieldnm] - txt = 
re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE) - txts = [] - for t in re.split(r"[.?!;\n]", txt): - for w in keywords: - t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w), r"\1\2\3", t, - flags=re.IGNORECASE | re.MULTILINE) - if not re.search(r"[^<>]+", t, flags=re.IGNORECASE | re.MULTILINE): - continue - txts.append(t) - ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]]) - - return ans - - def get_aggregation(self, res, fieldnm: str): - agg_field = "aggs_" + fieldnm - if "aggregations" not in res or agg_field not in res["aggregations"]: - return list() - bkts = res["aggregations"][agg_field]["buckets"] - return [(b["key"], b["doc_count"]) for b in bkts] - - """ - SQL - """ - - def sql(self, sql: str, fetch_size: int, format: str): - logger.debug(f"ESConnection.sql get sql: {sql}") - sql = re.sub(r"[ `]+", " ", sql) - sql = sql.replace("%", "") - replaces = [] - for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql): - fld, v = r.group(1), r.group(3) - match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format( - fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v))) - replaces.append( - ("{}{}'{}'".format( - r.group(1), - r.group(2), - r.group(3)), - match)) - - for p, r in replaces: - sql = sql.replace(p, r, 1) - logger.debug(f"ESConnection.sql to es: {sql}") - - for i in range(ATTEMPT_TIME): - try: - res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format, - request_timeout="2s") - return res - except ConnectionTimeout: - logger.exception("ES request timeout") - time.sleep(3) - self._connect() - continue - except Exception as e: - logger.exception(f"ESConnection.sql got exception. SQL:\n{sql}") - raise Exception(f"SQL error: {e}\n\nSQL: {sql}") - logger.error(f"ESConnection.sql timeout for {ATTEMPT_TIME} times!") - return None - - def get_cluster_stats(self): - """ - curl -XGET "http://{es_host}/_cluster/stats" -H "kbn-xsrf: reporting" to view raw stats. 
- """ - raw_stats = self.es.cluster.stats() - logger.debug(f"ESConnection.get_cluster_stats: {raw_stats}") - try: - res = { - 'cluster_name': raw_stats['cluster_name'], - 'status': raw_stats['status'] - } - indices_status = raw_stats['indices'] - res.update({ - 'indices': indices_status['count'], - 'indices_shards': indices_status['shards']['total'] - }) - doc_info = indices_status['docs'] - res.update({ - 'docs': doc_info['count'], - 'docs_deleted': doc_info['deleted'] - }) - store_info = indices_status['store'] - res.update({ - 'store_size': convert_bytes(store_info['size_in_bytes']), - 'total_dataset_size': convert_bytes(store_info['total_data_set_size_in_bytes']) - }) - mappings_info = indices_status['mappings'] - res.update({ - 'mappings_fields': mappings_info['total_field_count'], - 'mappings_deduplicated_fields': mappings_info['total_deduplicated_field_count'], - 'mappings_deduplicated_size': convert_bytes(mappings_info['total_deduplicated_mapping_size_in_bytes']) - }) - node_info = raw_stats['nodes'] - res.update({ - 'nodes': node_info['count']['total'], - 'nodes_version': node_info['versions'], - 'os_mem': convert_bytes(node_info['os']['mem']['total_in_bytes']), - 'os_mem_used': convert_bytes(node_info['os']['mem']['used_in_bytes']), - 'os_mem_used_percent': node_info['os']['mem']['used_percent'], - 'jvm_versions': node_info['jvm']['versions'][0]['vm_version'], - 'jvm_heap_used': convert_bytes(node_info['jvm']['mem']['heap_used_in_bytes']), - 'jvm_heap_max': convert_bytes(node_info['jvm']['mem']['heap_max_in_bytes']) - }) - return res - - except Exception as e: - logger.exception(f"ESConnection.get_cluster_stats: {e}") - return None diff --git a/rag/utils/infinity_conn.py b/rag/utils/infinity_conn.py index 1a0edd418..8805a754b 100644 --- a/rag/utils/infinity_conn.py +++ b/rag/utils/infinity_conn.py @@ -14,365 +14,125 @@ # limitations under the License. # -import logging -import os import re import json -import time import copy -import infinity -from infinity.common import ConflictType, InfinityException, SortType -from infinity.index import IndexInfo, IndexType -from infinity.connection_pool import ConnectionPool +from infinity.common import InfinityException, SortType from infinity.errors import ErrorCode from common.decorator import singleton import pandas as pd -from common.file_utils import get_project_base_directory -from rag.nlp import is_english from common.constants import PAGERANK_FLD, TAG_FLD -from common import settings -from rag.utils.doc_store_conn import ( - DocStoreConnection, - MatchExpr, - MatchTextExpr, - MatchDenseExpr, - FusionExpr, - OrderByExpr, -) - -logger = logging.getLogger("ragflow.infinity_conn") - - -def field_keyword(field_name: str): - # Treat "*_kwd" tag-like columns as keyword lists except knowledge_graph_kwd; source_id is also keyword-like. 
- if field_name == "source_id" or (field_name.endswith("_kwd") and field_name not in ["knowledge_graph_kwd", "docnm_kwd", "important_kwd", "question_kwd"]): - return True - return False - -def convert_select_fields(output_fields: list[str]) -> list[str]: - for i, field in enumerate(output_fields): - if field in ["docnm_kwd", "title_tks", "title_sm_tks"]: - output_fields[i] = "docnm" - elif field in ["important_kwd", "important_tks"]: - output_fields[i] = "important_keywords" - elif field in ["question_kwd", "question_tks"]: - output_fields[i] = "questions" - elif field in ["content_with_weight", "content_ltks", "content_sm_ltks"]: - output_fields[i] = "content" - elif field in ["authors_tks", "authors_sm_tks"]: - output_fields[i] = "authors" - return list(set(output_fields)) - -def convert_matching_field(field_weightstr: str) -> str: - tokens = field_weightstr.split("^") - field = tokens[0] - if field == "docnm_kwd" or field == "title_tks": - field = "docnm@ft_docnm_rag_coarse" - elif field == "title_sm_tks": - field = "docnm@ft_docnm_rag_fine" - elif field == "important_kwd": - field = "important_keywords@ft_important_keywords_rag_coarse" - elif field == "important_tks": - field = "important_keywords@ft_important_keywords_rag_fine" - elif field == "question_kwd": - field = "questions@ft_questions_rag_coarse" - elif field == "question_tks": - field = "questions@ft_questions_rag_fine" - elif field == "content_with_weight" or field == "content_ltks": - field = "content@ft_content_rag_coarse" - elif field == "content_sm_ltks": - field = "content@ft_content_rag_fine" - elif field == "authors_tks": - field = "authors@ft_authors_rag_coarse" - elif field == "authors_sm_tks": - field = "authors@ft_authors_rag_fine" - tokens[0] = field - return "^".join(tokens) - -def list2str(lst: str|list, sep: str = " ") -> str: - if isinstance(lst, str): - return lst - return sep.join(lst) - - -def equivalent_condition_to_str(condition: dict, table_instance=None) -> str | None: - assert "_id" not in condition - clmns = {} - if table_instance: - for n, ty, de, _ in table_instance.show_columns().rows(): - clmns[n] = (ty, de) - - def exists(cln): - nonlocal clmns - assert cln in clmns, f"'{cln}' should be in '{clmns}'." 
- ty, de = clmns[cln] - if ty.lower().find("cha"): - if not de: - de = "" - return f" {cln}!='{de}' " - return f"{cln}!={de}" - - cond = list() - for k, v in condition.items(): - if not isinstance(k, str) or not v: - continue - if field_keyword(k): - if isinstance(v, list): - inCond = list() - for item in v: - if isinstance(item, str): - item = item.replace("'", "''") - inCond.append(f"filter_fulltext('{convert_matching_field(k)}', '{item}')") - if inCond: - strInCond = " or ".join(inCond) - strInCond = f"({strInCond})" - cond.append(strInCond) - else: - cond.append(f"filter_fulltext('{convert_matching_field(k)}', '{v}')") - elif isinstance(v, list): - inCond = list() - for item in v: - if isinstance(item, str): - item = item.replace("'", "''") - inCond.append(f"'{item}'") - else: - inCond.append(str(item)) - if inCond: - strInCond = ", ".join(inCond) - strInCond = f"{k} IN ({strInCond})" - cond.append(strInCond) - elif k == "must_not": - if isinstance(v, dict): - for kk, vv in v.items(): - if kk == "exists": - cond.append("NOT (%s)" % exists(vv)) - elif isinstance(v, str): - cond.append(f"{k}='{v}'") - elif k == "exists": - cond.append(exists(v)) - else: - cond.append(f"{k}={str(v)}") - return " AND ".join(cond) if cond else "1=1" - - -def concat_dataframes(df_list: list[pd.DataFrame], selectFields: list[str]) -> pd.DataFrame: - df_list2 = [df for df in df_list if not df.empty] - if df_list2: - return pd.concat(df_list2, axis=0).reset_index(drop=True) - - schema = [] - for field_name in selectFields: - if field_name == "score()": # Workaround: fix schema is changed to score() - schema.append("SCORE") - elif field_name == "similarity()": # Workaround: fix schema is changed to similarity() - schema.append("SIMILARITY") - else: - schema.append(field_name) - return pd.DataFrame(columns=schema) +from common.doc_store.doc_store_base import MatchExpr, MatchTextExpr, MatchDenseExpr, FusionExpr, OrderByExpr +from common.doc_store.infinity_conn_base import InfinityConnectionBase @singleton -class InfinityConnection(DocStoreConnection): - def __init__(self): - self.dbName = settings.INFINITY.get("db_name", "default_db") - infinity_uri = settings.INFINITY["uri"] - if ":" in infinity_uri: - host, port = infinity_uri.split(":") - infinity_uri = infinity.common.NetworkAddress(host, int(port)) - self.connPool = None - logger.info(f"Use Infinity {infinity_uri} as the doc engine.") - for _ in range(24): - try: - connPool = ConnectionPool(infinity_uri, max_size=4) - inf_conn = connPool.get_conn() - res = inf_conn.show_current_node() - if res.error_code == ErrorCode.OK and res.server_status in ["started", "alive"]: - self._migrate_db(inf_conn) - self.connPool = connPool - connPool.release_conn(inf_conn) - break - connPool.release_conn(inf_conn) - logger.warn(f"Infinity status: {res.server_status}. Waiting Infinity {infinity_uri} to be healthy.") - time.sleep(5) - except Exception as e: - logger.warning(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.") - time.sleep(5) - if self.connPool is None: - msg = f"Infinity {infinity_uri} is unhealthy in 120s." 
- logger.error(msg) - raise Exception(msg) - logger.info(f"Infinity {infinity_uri} is healthy.") - - def _migrate_db(self, inf_conn): - inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore) - fp_mapping = os.path.join(get_project_base_directory(), "conf", "infinity_mapping.json") - if not os.path.exists(fp_mapping): - raise Exception(f"Mapping file not found at {fp_mapping}") - schema = json.load(open(fp_mapping)) - table_names = inf_db.list_tables().table_names - for table_name in table_names: - inf_table = inf_db.get_table(table_name) - index_names = inf_table.list_indexes().index_names - if "q_vec_idx" not in index_names: - # Skip tables not created by me - continue - column_names = inf_table.show_columns()["name"] - column_names = set(column_names) - for field_name, field_info in schema.items(): - if field_name in column_names: - continue - res = inf_table.add_columns({field_name: field_info}) - assert res.error_code == infinity.ErrorCode.OK - logger.info(f"INFINITY added following column to table {table_name}: {field_name} {field_info}") - if field_info["type"] != "varchar" or "analyzer" not in field_info: - continue - analyzers = field_info["analyzer"] - if isinstance(analyzers, str): - analyzers = [analyzers] - for analyzer in analyzers: - inf_table.create_index( - f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}", - IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}), - ConflictType.Ignore, - ) +class InfinityConnection(InfinityConnectionBase): """ - Database operations + Dataframe and fields convert """ - def dbType(self) -> str: - return "infinity" - - def health(self) -> dict: - """ - Return the health status of the database. - """ - inf_conn = self.connPool.get_conn() - res = inf_conn.show_current_node() - self.connPool.release_conn(inf_conn) - res2 = { - "type": "infinity", - "status": "green" if res.error_code == 0 and res.server_status in ["started", "alive"] else "red", - "error": res.error_msg, - } - return res2 - - """ - Table operations - """ - - def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int): - table_name = f"{indexName}_{knowledgebaseId}" - inf_conn = self.connPool.get_conn() - inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore) - - fp_mapping = os.path.join(get_project_base_directory(), "conf", "infinity_mapping.json") - if not os.path.exists(fp_mapping): - raise Exception(f"Mapping file not found at {fp_mapping}") - schema = json.load(open(fp_mapping)) - vector_name = f"q_{vectorSize}_vec" - schema[vector_name] = {"type": f"vector,{vectorSize},float"} - inf_table = inf_db.create_table( - table_name, - schema, - ConflictType.Ignore, - ) - inf_table.create_index( - "q_vec_idx", - IndexInfo( - vector_name, - IndexType.Hnsw, - { - "M": "16", - "ef_construction": "50", - "metric": "cosine", - "encode": "lvq", - }, - ), - ConflictType.Ignore, - ) - for field_name, field_info in schema.items(): - if field_info["type"] != "varchar" or "analyzer" not in field_info: - continue - analyzers = field_info["analyzer"] - if isinstance(analyzers, str): - analyzers = [analyzers] - for analyzer in analyzers: - inf_table.create_index( - f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}", - IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}), - ConflictType.Ignore, - ) - self.connPool.release_conn(inf_conn) - logger.info(f"INFINITY created table {table_name}, vector size {vectorSize}") - - def deleteIdx(self, indexName: str, 
knowledgebaseId: str): - table_name = f"{indexName}_{knowledgebaseId}" - inf_conn = self.connPool.get_conn() - db_instance = inf_conn.get_database(self.dbName) - db_instance.drop_table(table_name, ConflictType.Ignore) - self.connPool.release_conn(inf_conn) - logger.info(f"INFINITY dropped table {table_name}") - - def indexExist(self, indexName: str, knowledgebaseId: str) -> bool: - table_name = f"{indexName}_{knowledgebaseId}" - try: - inf_conn = self.connPool.get_conn() - db_instance = inf_conn.get_database(self.dbName) - _ = db_instance.get_table(table_name) - self.connPool.release_conn(inf_conn) + @staticmethod + def field_keyword(field_name: str): + # Treat "*_kwd" tag-like columns as keyword lists except knowledge_graph_kwd; source_id is also keyword-like. + if field_name == "source_id" or ( + field_name.endswith("_kwd") and field_name not in ["knowledge_graph_kwd", "docnm_kwd", "important_kwd", + "question_kwd"]): return True - except Exception as e: - logger.warning(f"INFINITY indexExist {str(e)}") return False + def convert_select_fields(self, output_fields: list[str]) -> list[str]: + for i, field in enumerate(output_fields): + if field in ["docnm_kwd", "title_tks", "title_sm_tks"]: + output_fields[i] = "docnm" + elif field in ["important_kwd", "important_tks"]: + output_fields[i] = "important_keywords" + elif field in ["question_kwd", "question_tks"]: + output_fields[i] = "questions" + elif field in ["content_with_weight", "content_ltks", "content_sm_ltks"]: + output_fields[i] = "content" + elif field in ["authors_tks", "authors_sm_tks"]: + output_fields[i] = "authors" + return list(set(output_fields)) + + @staticmethod + def convert_matching_field(field_weight_str: str) -> str: + tokens = field_weight_str.split("^") + field = tokens[0] + if field == "docnm_kwd" or field == "title_tks": + field = "docnm@ft_docnm_rag_coarse" + elif field == "title_sm_tks": + field = "docnm@ft_docnm_rag_fine" + elif field == "important_kwd": + field = "important_keywords@ft_important_keywords_rag_coarse" + elif field == "important_tks": + field = "important_keywords@ft_important_keywords_rag_fine" + elif field == "question_kwd": + field = "questions@ft_questions_rag_coarse" + elif field == "question_tks": + field = "questions@ft_questions_rag_fine" + elif field == "content_with_weight" or field == "content_ltks": + field = "content@ft_content_rag_coarse" + elif field == "content_sm_ltks": + field = "content@ft_content_rag_fine" + elif field == "authors_tks": + field = "authors@ft_authors_rag_coarse" + elif field == "authors_sm_tks": + field = "authors@ft_authors_rag_fine" + tokens[0] = field + return "^".join(tokens) + + """ CRUD operations """ def search( self, - selectFields: list[str], - highlightFields: list[str], + select_fields: list[str], + highlight_fields: list[str], condition: dict, - matchExprs: list[MatchExpr], - orderBy: OrderByExpr, + match_expressions: list[MatchExpr], + order_by: OrderByExpr, offset: int, limit: int, - indexNames: str | list[str], - knowledgebaseIds: list[str], - aggFields: list[str] = [], + index_names: str | list[str], + knowledgebase_ids: list[str], + agg_fields: list[str] | None = None, rank_feature: dict | None = None, ) -> tuple[pd.DataFrame, int]: """ BUG: Infinity returns empty for a highlight field if the query string doesn't use that field. 
""" - if isinstance(indexNames, str): - indexNames = indexNames.split(",") - assert isinstance(indexNames, list) and len(indexNames) > 0 + if isinstance(index_names, str): + index_names = index_names.split(",") + assert isinstance(index_names, list) and len(index_names) > 0 inf_conn = self.connPool.get_conn() db_instance = inf_conn.get_database(self.dbName) df_list = list() table_list = list() - output = selectFields.copy() - output = convert_select_fields(output) - for essential_field in ["id"] + aggFields: + output = select_fields.copy() + output = self.convert_select_fields(output) + if agg_fields is None: + agg_fields = [] + for essential_field in ["id"] + agg_fields: if essential_field not in output: output.append(essential_field) score_func = "" score_column = "" - for matchExpr in matchExprs: + for matchExpr in match_expressions: if isinstance(matchExpr, MatchTextExpr): score_func = "score()" score_column = "SCORE" break if not score_func: - for matchExpr in matchExprs: + for matchExpr in match_expressions: if isinstance(matchExpr, MatchDenseExpr): score_func = "similarity()" score_column = "SIMILARITY" break - if matchExprs: + if match_expressions: if score_func not in output: output.append(score_func) if PAGERANK_FLD not in output: @@ -387,11 +147,11 @@ class InfinityConnection(DocStoreConnection): filter_fulltext = "" if condition: table_found = False - for indexName in indexNames: - for kb_id in knowledgebaseIds: + for indexName in index_names: + for kb_id in knowledgebase_ids: table_name = f"{indexName}_{kb_id}" try: - filter_cond = equivalent_condition_to_str(condition, db_instance.get_table(table_name)) + filter_cond = self.equivalent_condition_to_str(condition, db_instance.get_table(table_name)) table_found = True break except Exception: @@ -399,14 +159,14 @@ class InfinityConnection(DocStoreConnection): if table_found: break if not table_found: - logger.error(f"No valid tables found for indexNames {indexNames} and knowledgebaseIds {knowledgebaseIds}") + self.logger.error(f"No valid tables found for indexNames {index_names} and knowledgebaseIds {knowledgebase_ids}") return pd.DataFrame(), 0 - for matchExpr in matchExprs: + for matchExpr in match_expressions: if isinstance(matchExpr, MatchTextExpr): if filter_cond and "filter" not in matchExpr.extra_options: matchExpr.extra_options.update({"filter": filter_cond}) - matchExpr.fields = [convert_matching_field(field) for field in matchExpr.fields] + matchExpr.fields = [self.convert_matching_field(field) for field in matchExpr.fields] fields = ",".join(matchExpr.fields) filter_fulltext = f"filter_fulltext('{fields}', '{matchExpr.matching_text}')" if filter_cond: @@ -430,7 +190,7 @@ class InfinityConnection(DocStoreConnection): for k, v in matchExpr.extra_options.items(): if not isinstance(v, str): matchExpr.extra_options[k] = str(v) - logger.debug(f"INFINITY search MatchTextExpr: {json.dumps(matchExpr.__dict__)}") + self.logger.debug(f"INFINITY search MatchTextExpr: {json.dumps(matchExpr.__dict__)}") elif isinstance(matchExpr, MatchDenseExpr): if filter_fulltext and "filter" not in matchExpr.extra_options: matchExpr.extra_options.update({"filter": filter_fulltext}) @@ -441,16 +201,16 @@ class InfinityConnection(DocStoreConnection): if similarity: matchExpr.extra_options["threshold"] = similarity del matchExpr.extra_options["similarity"] - logger.debug(f"INFINITY search MatchDenseExpr: {json.dumps(matchExpr.__dict__)}") + self.logger.debug(f"INFINITY search MatchDenseExpr: {json.dumps(matchExpr.__dict__)}") elif 
isinstance(matchExpr, FusionExpr): if matchExpr.method == "weighted_sum": # The default is "minmax" which gives a zero score for the last doc. matchExpr.fusion_params["normalize"] = "atan" - logger.debug(f"INFINITY search FusionExpr: {json.dumps(matchExpr.__dict__)}") + self.logger.debug(f"INFINITY search FusionExpr: {json.dumps(matchExpr.__dict__)}") order_by_expr_list = list() - if orderBy.fields: - for order_field in orderBy.fields: + if order_by.fields: + for order_field in order_by.fields: if order_field[1] == 0: order_by_expr_list.append((order_field[0], SortType.Asc)) else: @@ -458,8 +218,8 @@ class InfinityConnection(DocStoreConnection): total_hits_count = 0 # Scatter search tables and gather the results - for indexName in indexNames: - for knowledgebaseId in knowledgebaseIds: + for indexName in index_names: + for knowledgebaseId in knowledgebase_ids: table_name = f"{indexName}_{knowledgebaseId}" try: table_instance = db_instance.get_table(table_name) @@ -467,8 +227,8 @@ class InfinityConnection(DocStoreConnection): continue table_list.append(table_name) builder = table_instance.output(output) - if len(matchExprs) > 0: - for matchExpr in matchExprs: + if len(match_expressions) > 0: + for matchExpr in match_expressions: if isinstance(matchExpr, MatchTextExpr): fields = ",".join(matchExpr.fields) builder = builder.match_text( @@ -491,53 +251,52 @@ class InfinityConnection(DocStoreConnection): else: if filter_cond and len(filter_cond) > 0: builder.filter(filter_cond) - if orderBy.fields: + if order_by.fields: builder.sort(order_by_expr_list) builder.offset(offset).limit(limit) kb_res, extra_result = builder.option({"total_hits_count": True}).to_df() if extra_result: total_hits_count += int(extra_result["total_hits_count"]) - logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(kb_res)}") + self.logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(kb_res)}") df_list.append(kb_res) self.connPool.release_conn(inf_conn) - res = concat_dataframes(df_list, output) - if matchExprs: + res = self.concat_dataframes(df_list, output) + if match_expressions: res["_score"] = res[score_column] + res[PAGERANK_FLD] res = res.sort_values(by="_score", ascending=False).reset_index(drop=True) res = res.head(limit) - logger.debug(f"INFINITY search final result: {str(res)}") + self.logger.debug(f"INFINITY search final result: {str(res)}") return res, total_hits_count - def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None: + def get(self, chunk_id: str, index_name: str, knowledgebase_ids: list[str]) -> dict | None: inf_conn = self.connPool.get_conn() db_instance = inf_conn.get_database(self.dbName) df_list = list() - assert isinstance(knowledgebaseIds, list) + assert isinstance(knowledgebase_ids, list) table_list = list() - for knowledgebaseId in knowledgebaseIds: - table_name = f"{indexName}_{knowledgebaseId}" + for knowledgebaseId in knowledgebase_ids: + table_name = f"{index_name}_{knowledgebaseId}" table_list.append(table_name) - table_instance = None try: table_instance = db_instance.get_table(table_name) except Exception: - logger.warning(f"Table not found: {table_name}, this dataset isn't created in Infinity. Maybe it is created in other document engine.") + self.logger.warning(f"Table not found: {table_name}, this dataset isn't created in Infinity. 
Maybe it is created in other document engine.") continue - kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_df() - logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}") + kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunk_id}'").to_df() + self.logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}") df_list.append(kb_res) self.connPool.release_conn(inf_conn) - res = concat_dataframes(df_list, ["id"]) + res = self.concat_dataframes(df_list, ["id"]) fields = set(res.columns.tolist()) for field in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "question_kwd", "question_tks","content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks"]: fields.add(field) res_fields = self.get_fields(res, list(fields)) - return res_fields.get(chunkId, None) + return res_fields.get(chunk_id, None) - def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]: + def insert(self, documents: list[dict], index_name: str, knowledgebase_id: str = None) -> list[str]: inf_conn = self.connPool.get_conn() db_instance = inf_conn.get_database(self.dbName) - table_name = f"{indexName}_{knowledgebaseId}" + table_name = f"{index_name}_{knowledgebase_id}" try: table_instance = db_instance.get_table(table_name) except InfinityException as e: @@ -553,7 +312,7 @@ class InfinityConnection(DocStoreConnection): break if vector_size == 0: raise ValueError("Cannot infer vector size from documents") - self.createIdx(indexName, knowledgebaseId, vector_size) + self.create_idx(index_name, knowledgebase_id, vector_size) table_instance = db_instance.get_table(table_name) # embedding fields can't have a default value.... @@ -574,12 +333,12 @@ class InfinityConnection(DocStoreConnection): d["docnm"] = v elif k == "title_kwd": if not d.get("docnm_kwd"): - d["docnm"] = list2str(v) + d["docnm"] = self.list2str(v) elif k == "title_sm_tks": if not d.get("docnm_kwd"): - d["docnm"] = list2str(v) + d["docnm"] = self.list2str(v) elif k == "important_kwd": - d["important_keywords"] = list2str(v) + d["important_keywords"] = self.list2str(v) elif k == "important_tks": if not d.get("important_kwd"): d["important_keywords"] = v @@ -597,11 +356,11 @@ class InfinityConnection(DocStoreConnection): if not d.get("authors_tks"): d["authors"] = v elif k == "question_kwd": - d["questions"] = list2str(v, "\n") + d["questions"] = self.list2str(v, "\n") elif k == "question_tks": if not d.get("question_kwd"): - d["questions"] = list2str(v) - elif field_keyword(k): + d["questions"] = self.list2str(v) + elif self.field_keyword(k): if isinstance(v, list): d[k] = "###".join(v) else: @@ -637,15 +396,15 @@ class InfinityConnection(DocStoreConnection): # logger.info(f"InfinityConnection.insert {json.dumps(documents)}") table_instance.insert(docs) self.connPool.release_conn(inf_conn) - logger.debug(f"INFINITY inserted into {table_name} {str_ids}.") + self.logger.debug(f"INFINITY inserted into {table_name} {str_ids}.") return [] - def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool: + def update(self, condition: dict, new_value: dict, index_name: str, knowledgebase_id: str) -> bool: # if 'position_int' in newValue: # logger.info(f"update position_int: {newValue['position_int']}") inf_conn = self.connPool.get_conn() db_instance = inf_conn.get_database(self.dbName) - table_name = f"{indexName}_{knowledgebaseId}" + table_name = f"{index_name}_{knowledgebase_id}" 
table_instance = db_instance.get_table(table_name) # if "exists" in condition: # del condition["exists"] @@ -654,57 +413,57 @@ class InfinityConnection(DocStoreConnection): if table_instance: for n, ty, de, _ in table_instance.show_columns().rows(): clmns[n] = (ty, de) - filter = equivalent_condition_to_str(condition, table_instance) + filter = self.equivalent_condition_to_str(condition, table_instance) removeValue = {} - for k, v in list(newValue.items()): + for k, v in list(new_value.items()): if k == "docnm_kwd": - newValue["docnm"] = list2str(v) + new_value["docnm"] = self.list2str(v) elif k == "title_kwd": - if not newValue.get("docnm_kwd"): - newValue["docnm"] = list2str(v) + if not new_value.get("docnm_kwd"): + new_value["docnm"] = self.list2str(v) elif k == "title_sm_tks": - if not newValue.get("docnm_kwd"): - newValue["docnm"] = v + if not new_value.get("docnm_kwd"): + new_value["docnm"] = v elif k == "important_kwd": - newValue["important_keywords"] = list2str(v) + new_value["important_keywords"] = self.list2str(v) elif k == "important_tks": - if not newValue.get("important_kwd"): - newValue["important_keywords"] = v + if not new_value.get("important_kwd"): + new_value["important_keywords"] = v elif k == "content_with_weight": - newValue["content"] = v + new_value["content"] = v elif k == "content_ltks": - if not newValue.get("content_with_weight"): - newValue["content"] = v + if not new_value.get("content_with_weight"): + new_value["content"] = v elif k == "content_sm_ltks": - if not newValue.get("content_with_weight"): - newValue["content"] = v + if not new_value.get("content_with_weight"): + new_value["content"] = v elif k == "authors_tks": - newValue["authors"] = v + new_value["authors"] = v elif k == "authors_sm_tks": - if not newValue.get("authors_tks"): - newValue["authors"] = v + if not new_value.get("authors_tks"): + new_value["authors"] = v elif k == "question_kwd": - newValue["questions"] = "\n".join(v) + new_value["questions"] = "\n".join(v) elif k == "question_tks": - if not newValue.get("question_kwd"): - newValue["questions"] = list2str(v) - elif field_keyword(k): + if not new_value.get("question_kwd"): + new_value["questions"] = self.list2str(v) + elif self.field_keyword(k): if isinstance(v, list): - newValue[k] = "###".join(v) + new_value[k] = "###".join(v) else: - newValue[k] = v + new_value[k] = v elif re.search(r"_feas$", k): - newValue[k] = json.dumps(v) + new_value[k] = json.dumps(v) elif k == "kb_id": - if isinstance(newValue[k], list): - newValue[k] = newValue[k][0] # since d[k] is a list, but we need a str + if isinstance(new_value[k], list): + new_value[k] = new_value[k][0] # since d[k] is a list, but we need a str elif k == "position_int": assert isinstance(v, list) arr = [num for row in v for num in row] - newValue[k] = "_".join(f"{num:08x}" for num in arr) + new_value[k] = "_".join(f"{num:08x}" for num in arr) elif k in ["page_num_int", "top_int"]: assert isinstance(v, list) - newValue[k] = "_".join(f"{num:08x}" for num in v) + new_value[k] = "_".join(f"{num:08x}" for num in v) elif k == "remove": if isinstance(v, str): assert v in clmns, f"'{v}' should be in '{clmns}'." 
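For illustration only (not part of the patch): the insert() and update() paths above fold RAGFlow's legacy per-variant fields into the consolidated Infinity columns (docnm, important_keywords, questions, content, authors), join tag-like *_kwd lists with "###", and JSON-encode *_feas dictionaries. The standalone sketch below mirrors that mapping in simplified form; it ignores the precedence rules between the *_kwd fields and their tokenized variants, and the helper name is hypothetical.

import json

def map_legacy_field(key: str, value):
    """Return (column, stored_value) for one legacy field (simplified illustration)."""
    joined = " ".join(value) if isinstance(value, list) else value
    if key in ("docnm_kwd", "title_tks", "title_sm_tks"):
        return "docnm", joined
    if key in ("important_kwd", "important_tks"):
        return "important_keywords", joined
    if key in ("question_kwd", "question_tks"):
        # question_kwd is stored one question per line
        sep = "\n" if key == "question_kwd" else " "
        return "questions", sep.join(value) if isinstance(value, list) else value
    if key in ("content_with_weight", "content_ltks", "content_sm_ltks"):
        return "content", joined
    if key in ("authors_tks", "authors_sm_tks"):
        return "authors", joined
    if key.endswith("_kwd"):
        # tag-like keyword lists are stored "###"-joined
        return key, "###".join(value) if isinstance(value, list) else value
    if key.endswith("_feas"):
        # feature dictionaries are stored as JSON strings
        return key, json.dumps(value)
    return key, value

print(map_legacy_field("question_kwd", ["Q1?", "Q2?"]))  # ('questions', 'Q1?\nQ2?')
print(map_legacy_field("tag_kwd", ["a", "b"]))           # ('tag_kwd', 'a###b')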
@@ -712,22 +471,22 @@ class InfinityConnection(DocStoreConnection): if ty.lower().find("cha"): if not de: de = "" - newValue[v] = de + new_value[v] = de else: for kk, vv in v.items(): removeValue[kk] = vv - del newValue[k] + del new_value[k] else: - newValue[k] = v + new_value[k] = v for k in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", "question_kwd", "question_tks"]: - if k in newValue: - del newValue[k] + if k in new_value: + del new_value[k] remove_opt = {} # "[k,new_value]": [id_to_update, ...] if removeValue: col_to_remove = list(removeValue.keys()) row_to_opt = table_instance.output(col_to_remove + ["id"]).filter(filter).to_df() - logger.debug(f"INFINITY search table {str(table_name)}, filter {filter}, result: {str(row_to_opt[0])}") + self.logger.debug(f"INFINITY search table {str(table_name)}, filter {filter}, result: {str(row_to_opt[0])}") row_to_opt = self.get_fields(row_to_opt, col_to_remove) for id, old_v in row_to_opt.items(): for k, remove_v in removeValue.items(): @@ -740,78 +499,53 @@ class InfinityConnection(DocStoreConnection): else: remove_opt[kv_key].append(id) - logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {newValue}.") + self.logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {new_value}.") for update_kv, ids in remove_opt.items(): k, v = json.loads(update_kv) table_instance.update(filter + " AND id in ({0})".format(",".join([f"'{id}'" for id in ids])), {k: "###".join(v)}) - table_instance.update(filter, newValue) + table_instance.update(filter, new_value) self.connPool.release_conn(inf_conn) return True - def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int: - inf_conn = self.connPool.get_conn() - db_instance = inf_conn.get_database(self.dbName) - table_name = f"{indexName}_{knowledgebaseId}" - try: - table_instance = db_instance.get_table(table_name) - except Exception: - logger.warning(f"Skipped deleting from table {table_name} since the table doesn't exist.") - return 0 - filter = equivalent_condition_to_str(condition, table_instance) - logger.debug(f"INFINITY delete table {table_name}, filter {filter}.") - res = table_instance.delete(filter) - self.connPool.release_conn(inf_conn) - return res.deleted_rows - """ Helper functions for search result """ - def get_total(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int: - if isinstance(res, tuple): - return res[1] - return len(res) - - def get_chunk_ids(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]: - if isinstance(res, tuple): - res = res[0] - return list(res["id"]) - def get_fields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]: if isinstance(res, tuple): res = res[0] if not fields: return {} - fieldsAll = fields.copy() - fieldsAll.append("id") - fieldsAll = set(fieldsAll) + fields_all = fields.copy() + fields_all.append("id") + fields_all = set(fields_all) if "docnm" in res.columns: for field in ["docnm_kwd", "title_tks", "title_sm_tks"]: - if field in fieldsAll: + if field in fields_all: res[field] = res["docnm"] if "important_keywords" in res.columns: - if "important_kwd" in fieldsAll: + if "important_kwd" in fields_all: res["important_kwd"] = res["important_keywords"].apply(lambda v: v.split()) - if "important_tks" in fieldsAll: + if "important_tks" in fields_all: res["important_tks"] = res["important_keywords"] if "questions" in 
res.columns: - if "question_kwd" in fieldsAll: + if "question_kwd" in fields_all: res["question_kwd"] = res["questions"].apply(lambda v: v.splitlines()) - if "question_tks" in fieldsAll: + if "question_tks" in fields_all: res["question_tks"] = res["questions"] if "content" in res.columns: for field in ["content_with_weight", "content_ltks", "content_sm_ltks"]: - if field in fieldsAll: + if field in fields_all: res[field] = res["content"] if "authors" in res.columns: for field in ["authors_tks", "authors_sm_tks"]: - if field in fieldsAll: + if field in fields_all: res[field] = res["authors"] column_map = {col.lower(): col for col in res.columns} - matched_columns = {column_map[col.lower()]: col for col in fieldsAll if col.lower() in column_map} - none_columns = [col for col in fieldsAll if col.lower() not in column_map] + matched_columns = {column_map[col.lower()]: col for col in fields_all if col.lower() in column_map} + none_columns = [col for col in fields_all if col.lower() not in column_map] res2 = res[matched_columns.keys()] res2 = res2.rename(columns=matched_columns) @@ -819,7 +553,7 @@ class InfinityConnection(DocStoreConnection): for column in list(res2.columns): k = column.lower() - if field_keyword(k): + if self.field_keyword(k): res2[column] = res2[column].apply(lambda v: [kwd for kwd in v.split("###") if kwd]) elif re.search(r"_feas$", k): res2[column] = res2[column].apply(lambda v: json.loads(v) if v else {}) @@ -844,95 +578,3 @@ class InfinityConnection(DocStoreConnection): res2[column] = None return res2.set_index("id").to_dict(orient="index") - - def get_highlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], fieldnm: str): - if isinstance(res, tuple): - res = res[0] - ans = {} - num_rows = len(res) - column_id = res["id"] - if fieldnm not in res: - return {} - for i in range(num_rows): - id = column_id[i] - txt = res[fieldnm][i] - if re.search(r"[^<>]+", txt, flags=re.IGNORECASE | re.MULTILINE): - ans[id] = txt - continue - txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE) - txts = [] - for t in re.split(r"[.?!;\n]", txt): - if is_english([t]): - for w in keywords: - t = re.sub( - r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w), - r"\1\2\3", - t, - flags=re.IGNORECASE | re.MULTILINE, - ) - else: - for w in sorted(keywords, key=len, reverse=True): - t = re.sub( - re.escape(w), - f"{w}", - t, - flags=re.IGNORECASE | re.MULTILINE, - ) - if not re.search(r"[^<>]+", t, flags=re.IGNORECASE | re.MULTILINE): - continue - txts.append(t) - if txts: - ans[id] = "...".join(txts) - else: - ans[id] = txt - return ans - - def get_aggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str): - """ - Manual aggregation for tag fields since Infinity doesn't provide native aggregation - """ - from collections import Counter - - # Extract DataFrame from result - if isinstance(res, tuple): - df, _ = res - else: - df = res - - if df.empty or fieldnm not in df.columns: - return [] - - # Aggregate tag counts - tag_counter = Counter() - - for value in df[fieldnm]: - if pd.isna(value) or not value: - continue - - # Handle different tag formats - if isinstance(value, str): - # Split by ### for tag_kwd field or comma for other formats - if fieldnm == "tag_kwd" and "###" in value: - tags = [tag.strip() for tag in value.split("###") if tag.strip()] - else: - # Try comma separation as fallback - tags = [tag.strip() for tag in value.split(",") if tag.strip()] - - for tag in tags: - if tag: # Only count non-empty tags - 
tag_counter[tag] += 1 - elif isinstance(value, list): - # Handle list format - for tag in value: - if tag and isinstance(tag, str): - tag_counter[tag.strip()] += 1 - - # Return as list of [tag, count] pairs, sorted by count descending - return [[tag, count] for tag, count in tag_counter.most_common()] - - """ - SQL - """ - - def sql(sql: str, fetch_size: int, format: str): - raise NotImplementedError("Not implemented") diff --git a/rag/utils/ob_conn.py b/rag/utils/ob_conn.py index 3c00be421..0786e9140 100644 --- a/rag/utils/ob_conn.py +++ b/rag/utils/ob_conn.py @@ -37,9 +37,8 @@ from common import settings from common.constants import PAGERANK_FLD, TAG_FLD from common.decorator import singleton from common.float_utils import get_float +from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, MatchDenseExpr from rag.nlp import rag_tokenizer -from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, \ - MatchDenseExpr ATTEMPT_TIME = 2 OB_QUERY_TIMEOUT = int(os.environ.get("OB_QUERY_TIMEOUT", "100_000_000")) @@ -497,7 +496,7 @@ class OBConnection(DocStoreConnection): Database operations """ - def dbType(self) -> str: + def db_type(self) -> str: return "oceanbase" def health(self) -> dict: @@ -553,7 +552,7 @@ class OBConnection(DocStoreConnection): Table operations """ - def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int): + def create_idx(self, indexName: str, knowledgebaseId: str, vectorSize: int): vector_field_name = f"q_{vectorSize}_vec" vector_index_name = f"{vector_field_name}_idx" @@ -604,7 +603,7 @@ class OBConnection(DocStoreConnection): # always refresh metadata to make sure it contains the latest table structure self.client.refresh_metadata([indexName]) - def deleteIdx(self, indexName: str, knowledgebaseId: str): + def delete_idx(self, indexName: str, knowledgebaseId: str): if len(knowledgebaseId) > 0: # The index need to be alive after any kb deletion since all kb under this tenant are in one index. 
return @@ -615,7 +614,7 @@ class OBConnection(DocStoreConnection): except Exception as e: raise Exception(f"OBConnection.deleteIndex error: {str(e)}") - def indexExist(self, indexName: str, knowledgebaseId: str = None) -> bool: + def index_exist(self, indexName: str, knowledgebaseId: str = None) -> bool: return self._check_table_exists_cached(indexName) def _get_count(self, table_name: str, filter_list: list[str] = None) -> int: @@ -1500,7 +1499,7 @@ class OBConnection(DocStoreConnection): def get_total(self, res) -> int: return res.total - def get_chunk_ids(self, res) -> list[str]: + def get_doc_ids(self, res) -> list[str]: return [row["id"] for row in res.chunks] def get_fields(self, res, fields: list[str]) -> dict[str, dict]: diff --git a/rag/utils/opensearch_conn.py b/rag/utils/opensearch_conn.py index 2df1d65ee..2e828be6e 100644 --- a/rag/utils/opensearch_conn.py +++ b/rag/utils/opensearch_conn.py @@ -26,8 +26,7 @@ from opensearchpy import UpdateByQuery, Q, Search, Index from opensearchpy import ConnectionTimeout from common.decorator import singleton from common.file_utils import get_project_base_directory -from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \ - FusionExpr +from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, FusionExpr from rag.nlp import is_english, rag_tokenizer from common.constants import PAGERANK_FLD, TAG_FLD from common import settings @@ -79,7 +78,7 @@ class OSConnection(DocStoreConnection): Database operations """ - def dbType(self) -> str: + def db_type(self) -> str: return "opensearch" def health(self) -> dict: @@ -91,8 +90,8 @@ class OSConnection(DocStoreConnection): Table operations """ - def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int): - if self.indexExist(indexName, knowledgebaseId): + def create_idx(self, indexName: str, knowledgebaseId: str, vectorSize: int): + if self.index_exist(indexName, knowledgebaseId): return True try: from opensearchpy.client import IndicesClient @@ -101,7 +100,7 @@ class OSConnection(DocStoreConnection): except Exception: logger.exception("OSConnection.createIndex error %s" % (indexName)) - def deleteIdx(self, indexName: str, knowledgebaseId: str): + def delete_idx(self, indexName: str, knowledgebaseId: str): if len(knowledgebaseId) > 0: # The index need to be alive after any kb deletion since all kb under this tenant are in one index. 
return @@ -112,7 +111,7 @@ class OSConnection(DocStoreConnection): except Exception: logger.exception("OSConnection.deleteIdx error %s" % (indexName)) - def indexExist(self, indexName: str, knowledgebaseId: str = None) -> bool: + def index_exist(self, indexName: str, knowledgebaseId: str = None) -> bool: s = Index(indexName, self.os) for i in range(ATTEMPT_TIME): try: @@ -460,7 +459,7 @@ class OSConnection(DocStoreConnection): return res["hits"]["total"]["value"] return res["hits"]["total"] - def get_chunk_ids(self, res): + def get_doc_ids(self, res): return [d["_id"] for d in res["hits"]["hits"]] def __getSource(self, res): diff --git a/rag/utils/redis_conn.py b/rag/utils/redis_conn.py index d7f0dcd9d..8dec64fe0 100644 --- a/rag/utils/redis_conn.py +++ b/rag/utils/redis_conn.py @@ -272,6 +272,49 @@ class RedisDB: self.__open__() return None + def generate_auto_increment_id(self, key_prefix: str = "id_generator", namespace: str = "default", increment: int = 1, ensure_minimum: int | None = None) -> int: + redis_key = f"{key_prefix}:{namespace}" + + try: + # Use pipeline for atomicity + pipe = self.REDIS.pipeline() + + # Check if key exists + pipe.exists(redis_key) + + # Get/Increment + if ensure_minimum is not None: + # Ensure minimum value + pipe.get(redis_key) + results = pipe.execute() + + if results[0] == 0: # Key doesn't exist + start_id = max(1, ensure_minimum) + pipe.set(redis_key, start_id) + pipe.execute() + return start_id + else: + current = int(results[1]) + if current < ensure_minimum: + pipe.set(redis_key, ensure_minimum) + pipe.execute() + return ensure_minimum + + # Increment operation + next_id = self.REDIS.incrby(redis_key, increment) + + # If it's the first time, set a reasonable initial value + if next_id == increment: + self.REDIS.set(redis_key, 1 + increment) + return 1 + increment + + return next_id + + except Exception as e: + logging.warning("RedisDB.generate_auto_increment_id got exception: " + str(e)) + self.__open__() + return -1 + def transaction(self, key, value, exp=3600): try: pipeline = self.REDIS.pipeline(transaction=True) diff --git a/test/testcases/test_web_api/test_memory_app/conftest.py b/test/testcases/test_web_api/test_memory_app/conftest.py index 11c7c2a10..7fdd78f53 100644 --- a/test/testcases/test_web_api/test_memory_app/conftest.py +++ b/test/testcases/test_web_api/test_memory_app/conftest.py @@ -32,8 +32,8 @@ def add_memory_func(request, WebApiAuth): payload = { "name": f"test_memory_{i}", "memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)), - "embd_id": "SILICONFLOW@BAAI/bge-large-zh-v1.5", - "llm_id": "ZHIPU-AI@glm-4-flash" + "embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW", + "llm_id": "glm-4-flash@ZHIPU-AI" } res = create_memory(WebApiAuth, payload) memory_ids.append(res["data"]["id"]) diff --git a/test/testcases/test_web_api/test_memory_app/test_create_memory.py b/test/testcases/test_web_api/test_memory_app/test_create_memory.py index d91500bc9..e21c98859 100644 --- a/test/testcases/test_web_api/test_memory_app/test_create_memory.py +++ b/test/testcases/test_web_api/test_memory_app/test_create_memory.py @@ -49,8 +49,8 @@ class TestMemoryCreate: payload = { "name": name, "memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)), - "embd_id": "SILICONFLOW@BAAI/bge-large-zh-v1.5", - "llm_id": "ZHIPU-AI@glm-4-flash" + "embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW", + "llm_id": "glm-4-flash@ZHIPU-AI" } res = create_memory(WebApiAuth, payload) assert 
res["code"] == 0, res @@ -72,8 +72,8 @@ class TestMemoryCreate: payload = { "name": name, "memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)), - "embd_id": "SILICONFLOW@BAAI/bge-large-zh-v1.5", - "llm_id": "ZHIPU-AI@glm-4-flash" + "embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW", + "llm_id": "glm-4-flash@ZHIPU-AI" } res = create_memory(WebApiAuth, payload) assert res["message"] == expected_message, res @@ -84,8 +84,8 @@ class TestMemoryCreate: payload = { "name": name, "memory_type": ["something"], - "embd_id": "SILICONFLOW@BAAI/bge-large-zh-v1.5", - "llm_id": "ZHIPU-AI@glm-4-flash" + "embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW", + "llm_id": "glm-4-flash@ZHIPU-AI" } res = create_memory(WebApiAuth, payload) assert res["message"] == f"Memory type '{ {'something'} }' is not supported.", res @@ -96,8 +96,8 @@ class TestMemoryCreate: payload = { "name": name, "memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)), - "embd_id": "SILICONFLOW@BAAI/bge-large-zh-v1.5", - "llm_id": "ZHIPU-AI@glm-4-flash" + "embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW", + "llm_id": "glm-4-flash@ZHIPU-AI" } res1 = create_memory(WebApiAuth, payload) assert res1["code"] == 0, res1 diff --git a/test/testcases/test_web_api/test_memory_app/test_update_memory.py b/test/testcases/test_web_api/test_memory_app/test_update_memory.py index 4def9d8b1..a801fa994 100644 --- a/test/testcases/test_web_api/test_memory_app/test_update_memory.py +++ b/test/testcases/test_web_api/test_memory_app/test_update_memory.py @@ -101,7 +101,7 @@ class TestMemoryUpdate: @pytest.mark.p1 def test_llm(self, WebApiAuth, add_memory_func): memory_ids = add_memory_func - llm_id = "ZHIPU-AI@glm-4" + llm_id = "glm-4@ZHIPU-AI" payload = {"llm_id": llm_id} res = update_memory(WebApiAuth, memory_ids[0], payload) assert res["code"] == 0, res