diff --git a/api/db/init_data.py b/api/db/init_data.py
index cb2feb748..49a094eb3 100644
--- a/api/db/init_data.py
+++ b/api/db/init_data.py
@@ -31,7 +31,7 @@ from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMSer
 from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm
 from api.db.services.user_service import TenantService, UserTenantService
 from api.db.services.system_settings_service import SystemSettingsService
-from api.db.joint_services.memory_message_service import init_message_id_sequence, init_memory_size_cache
+from api.db.joint_services.memory_message_service import init_message_id_sequence, init_memory_size_cache, fix_missing_tokenized_memory
 from common.constants import LLMType
 from common.file_utils import get_project_base_directory
 from common import settings
@@ -175,6 +175,7 @@ def init_web_data():
     add_graph_templates()
     init_message_id_sequence()
     init_memory_size_cache()
+    fix_missing_tokenized_memory()
     logging.info("init web data success:{}".format(time.time() - start_time))
 
 def init_table():
diff --git a/api/db/joint_services/memory_message_service.py b/api/db/joint_services/memory_message_service.py
index 490a16ac2..8f6621247 100644
--- a/api/db/joint_services/memory_message_service.py
+++ b/api/db/joint_services/memory_message_service.py
@@ -306,6 +306,24 @@ def init_memory_size_cache():
     logging.info("Memory size cache init done.")
 
 
+def fix_missing_tokenized_memory():
+    if settings.DOC_ENGINE != "elasticsearch":
+        logging.info("Not using elasticsearch as doc engine, no need to fix missing tokenized memory.")
+        return
+    memory_list = MemoryService.get_all_memory()
+    if not memory_list:
+        logging.info("No memory found, no need to fix missing tokenized memory.")
+    else:
+        for m in memory_list:
+            message_list = MessageService.get_missing_field_messages(m.id, m.tenant_id, "tokenized_content_ltks")
+            for msg in message_list:
+                # update content to refresh tokenized field
+                MessageService.update_message({"message_id": msg["message_id"], "memory_id": m.id}, {"content": msg["content"]}, m.tenant_id, m.id)
+            if message_list:
+                logging.info(f"Fixed {len(message_list)} messages missing tokenized field in memory: {m.name}.")
+    logging.info("Fix missing tokenized memory done.")
+
+
 def judge_system_prompt_is_default(system_prompt: str, memory_type: int|list[str]):
     memory_type_list = memory_type if isinstance(memory_type, list) else get_memory_type_human(memory_type)
     return system_prompt == PromptAssembler.assemble_system_prompt({"memory_type": memory_type_list})
diff --git a/memory/services/messages.py b/memory/services/messages.py
index fe855905c..9c85c458c 100644
--- a/memory/services/messages.py
+++ b/memory/services/messages.py
@@ -247,6 +247,21 @@ class MessageService:
             return ids_to_remove, current_size
         return ids_to_remove, current_size
 
+    @classmethod
+    def get_missing_field_messages(cls, memory_id: str, uid: str, field_name: str):
+        select_fields = ["message_id", "content"]
+        _index_name = index_name(uid)
+        res = settings.msgStoreConn.get_missing_field_message(
+            select_fields=select_fields,
+            index_name=_index_name,
+            memory_id=memory_id,
+            field_name=field_name
+        )
+        if not res:
+            return []
+        docs = settings.msgStoreConn.get_fields(res, select_fields)
+        return list(docs.values())
+
     @classmethod
     def get_by_message_id(cls, memory_id: str, message_id: int, uid: str):
         index = index_name(uid)
diff --git a/memory/utils/es_conn.py b/memory/utils/es_conn.py
index 77d16dc41..afa06a169 100644
--- a/memory/utils/es_conn.py
+++ b/memory/utils/es_conn.py
@@ -27,6 +27,7 @@ from common.doc_store.doc_store_base import MatchExpr, OrderByExpr, MatchTextExp
 from common.doc_store.es_conn_base import ESConnectionBase
 from common.float_utils import get_float
 from common.constants import PAGERANK_FLD, TAG_FLD
+from rag.nlp.rag_tokenizer import tokenize, fine_grained_tokenize
 
 ATTEMPT_TIME = 2
 
@@ -35,13 +36,15 @@ ATTEMPT_TIME = 2
 
 class ESConnection(ESConnectionBase):
     @staticmethod
-    def convert_field_name(field_name: str) -> str:
+    def convert_field_name(field_name: str, use_tokenized_content=False) -> str:
         match field_name:
             case "message_type":
                 return "message_type_kwd"
             case "status":
                 return "status_int"
             case "content":
+                if use_tokenized_content:
+                    return "tokenized_content_ltks"
                 return "content_ltks"
             case _:
                 return field_name
@@ -69,6 +72,7 @@ class ESConnection(ESConnectionBase):
             "status_int": 1 if message["status"] else 0,
             "zone_id": message.get("zone_id", 0),
             "content_ltks": message["content"],
+            "tokenized_content_ltks": fine_grained_tokenize(tokenize(message["content"])),
             f"q_{len(message['content_embed'])}_vec": message["content_embed"],
         }
         return storage_doc
@@ -166,7 +170,7 @@ class ESConnection(ESConnectionBase):
                 minimum_should_match = m.extra_options.get("minimum_should_match", 0.0)
                 if isinstance(minimum_should_match, float):
                     minimum_should_match = str(int(minimum_should_match * 100)) + "%"
-                bool_query.must.append(Q("query_string", fields=[self.convert_field_name(f) for f in m.fields],
+                bool_query.must.append(Q("query_string", fields=[self.convert_field_name(f, use_tokenized_content=True) for f in m.fields],
                                          type="best_fields", query=m.matching_text,
                                          minimum_should_match=minimum_should_match,
                                          boost=1))
@@ -286,6 +290,51 @@ class ESConnection(ESConnectionBase):
         self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
         raise Exception("ESConnection.search timeout.")
 
+    def get_missing_field_message(self, select_fields: list[str], index_name: str, memory_id: str, field_name: str, limit: int = 512):
+        if not self.index_exist(index_name):
+            return None
+        bool_query = Q("bool", must=[])
+        bool_query.must.append(Q("term", memory_id=memory_id))
+        bool_query.must_not.append(Q("exists", field=field_name))
+        # from old to new
+        order_by = OrderByExpr()
+        order_by.asc("valid_at")
+        # build search
+        s = Search()
+        s = s.query(bool_query)
+        orders = list()
+        for field, order in order_by.fields:
+            order = "asc" if order == 0 else "desc"
+            if field.endswith("_int") or field.endswith("_flt"):
+                order_info = {"order": order, "unmapped_type": "float"}
+            else:
+                order_info = {"order": order, "unmapped_type": "text"}
+            orders.append({field: order_info})
+        s = s.sort(*orders)
+        s = s[:limit]
+        q = s.to_dict()
+        # search
+        for i in range(ATTEMPT_TIME):
+            try:
+                res = self.es.search(index=index_name, body=q, timeout="600s", track_total_hits=True, _source=True)
+                if str(res.get("timed_out", "")).lower() == "true":
+                    raise Exception("Es Timeout.")
+                self.logger.debug(f"ESConnection.search {str(index_name)} res: " + str(res))
+                return res
+            except ConnectionTimeout:
+                self.logger.exception("ES request timeout")
+                self._connect()
+                continue
+            except NotFoundError as e:
+                self.logger.debug(f"ESConnection.search {str(index_name)} query: " + str(q) + str(e))
+                return None
+            except Exception as e:
+                self.logger.exception(f"ESConnection.search {str(index_name)} query: " + str(q) + str(e))
+                raise e
+
+        self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
+        raise Exception("ESConnection.search timeout.")
+
     def get(self, doc_id: str, index_name: str, memory_ids: list[str]) -> dict | None:
         for i in range(ATTEMPT_TIME):
             try:
@@ -345,6 +394,8 @@
     def update(self, condition: dict, new_value: dict, index_name: str, memory_id: str) -> bool:
         doc = copy.deepcopy(new_value)
         update_dict = {self.convert_field_name(k): v for k, v in doc.items()}
+        if "content_ltks" in update_dict:
+            update_dict["tokenized_content_ltks"] = fine_grained_tokenize(tokenize(update_dict["content_ltks"]))
         update_dict.pop("id", None)
         condition_dict = {self.convert_field_name(k): v for k, v in condition.items()}
         condition_dict["memory_id"] = memory_id
diff --git a/memory/utils/infinity_conn.py b/memory/utils/infinity_conn.py
index 932655a1d..c7998542e 100644
--- a/memory/utils/infinity_conn.py
+++ b/memory/utils/infinity_conn.py
@@ -305,6 +305,36 @@ class InfinityConnection(InfinityConnectionBase):
         self.connPool.release_conn(inf_conn)
         return res
 
+    def get_missing_field_message(self, select_fields: list[str], index_name: str, memory_id: str, field_name: str, limit: int = 512):
+        condition = {"memory_id": memory_id, "must_not": {"exists": field_name}}
+        order_by = OrderByExpr()
+        order_by.asc("valid_at_flt")
+        # query
+        inf_conn = self.connPool.get_conn()
+        db_instance = inf_conn.get_database(self.dbName)
+        table_name = f"{index_name}_{memory_id}"
+        table_instance = db_instance.get_table(table_name)
+        column_name_list = [r[0] for r in table_instance.show_columns().rows()]
+        output_fields = [self.convert_message_field_to_infinity(f, column_name_list) for f in select_fields]
+        builder = table_instance.output(output_fields)
+        filter_cond = self.equivalent_condition_to_str(condition, table_instance)
+        builder.filter(filter_cond)
+        order_by_expr_list = list()
+        if order_by.fields:
+            for order_field in order_by.fields:
+                order_field_name = self.convert_condition_and_order_field(order_field[0])
+                if order_field[1] == 0:
+                    order_by_expr_list.append((order_field_name, SortType.Asc))
+                else:
+                    order_by_expr_list.append((order_field_name, SortType.Desc))
+            builder.sort(order_by_expr_list)
+        builder.offset(0).limit(limit)
+        mem_res, _ = builder.option({"total_hits_count": True}).to_df()
+        res = self.concat_dataframes(mem_res, output_fields)
+        res = res.head(limit)
+        self.connPool.release_conn(inf_conn)
+        return res
+
     def get(self, message_id: str, index_name: str, memory_ids: list[str]) -> dict | None:
         inf_conn = self.connPool.get_conn()
         db_instance = inf_conn.get_database(self.dbName)
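Note for reviewers (not part of the patch): the Elasticsearch request that the new ESConnection.get_missing_field_message assembles reduces to the body below. This is a sketch reconstructed from the hunk above; "<memory_id>" is a placeholder, and the field name and default limit of 512 come straight from the code.

    # Approximate output of Search().query(bool_query).sort(...)[:512].to_dict()
    # as built above. "valid_at" ends in neither "_int" nor "_flt", so the sort
    # falls into the {"unmapped_type": "text"} branch.
    q = {
        "query": {
            "bool": {
                "must": [{"term": {"memory_id": "<memory_id>"}}],
                "must_not": [{"exists": {"field": "tokenized_content_ltks"}}],
            }
        },
        "sort": [{"valid_at": {"order": "asc", "unmapped_type": "text"}}],
        "size": 512,
    }

Design-wise, the backfill in fix_missing_tokenized_memory never tokenizes by itself: it writes each hit's content back through MessageService.update_message, so the new branch in ESConnection.update recomputes tokenized_content_ltks, keeping fresh writes and repairs on a single code path.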