Fix: add tokenized content (#12793)

### What problem does this PR solve?

Add a tokenized content field to the Elasticsearch message index so that full-text queries can match Chinese (zh) messages, and backfill the field for existing messages at startup.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
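
How the fix works, in short: messages indexed before this change have no `tokenized_content_ltks` field, so on startup the server looks them up with a `must_not`/`exists` query and re-saves their `content`, which lets the write path tokenize it. Below is a minimal sketch of that lookup using the `elasticsearch_dsl` helpers the es_conn hunks below rely on; the function name and example values are illustrative, not part of the PR.

```python
from elasticsearch_dsl import Q, Search

def sketch_missing_tokenized_query(memory_id: str, limit: int = 512) -> dict:
    # Match every message in this memory that has no tokenized_content_ltks yet.
    bool_query = Q("bool", must=[Q("term", memory_id=memory_id)])
    bool_query.must_not.append(Q("exists", field="tokenized_content_ltks"))
    # Oldest first, so repeated startups keep making progress.
    s = Search().query(bool_query).sort({"valid_at": {"order": "asc", "unmapped_type": "text"}})
    return s[:limit].to_dict()
```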
Authored by Lynn on 2026-01-23 16:56:03 +08:00; committed by GitHub.
parent 11470906cf · commit f3923452df
5 changed files with 118 additions and 3 deletions

View File

@@ -31,7 +31,7 @@ from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMSer
 from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm
 from api.db.services.user_service import TenantService, UserTenantService
 from api.db.services.system_settings_service import SystemSettingsService
-from api.db.joint_services.memory_message_service import init_message_id_sequence, init_memory_size_cache
+from api.db.joint_services.memory_message_service import init_message_id_sequence, init_memory_size_cache, fix_missing_tokenized_memory
 from common.constants import LLMType
 from common.file_utils import get_project_base_directory
 from common import settings
@@ -175,6 +175,7 @@ def init_web_data():
     add_graph_templates()
     init_message_id_sequence()
     init_memory_size_cache()
+    fix_missing_tokenized_memory()
     logging.info("init web data success:{}".format(time.time() - start_time))
 
 def init_table():

View File

@@ -306,6 +306,24 @@ def init_memory_size_cache():
     logging.info("Memory size cache init done.")
 
 
+def fix_missing_tokenized_memory():
+    if settings.DOC_ENGINE != "elasticsearch":
+        logging.info("Not using elasticsearch as doc engine, no need to fix missing tokenized memory.")
+        return
+    memory_list = MemoryService.get_all_memory()
+    if not memory_list:
+        logging.info("No memory found, no need to fix missing tokenized memory.")
+    else:
+        for m in memory_list:
+            message_list = MessageService.get_missing_field_messages(m.id, m.tenant_id, "tokenized_content_ltks")
+            for msg in message_list:
+                # update content to refresh tokenized field
+                MessageService.update_message({"message_id": msg["message_id"], "memory_id": m.id}, {"content": msg["content"]}, m.tenant_id, m.id)
+            if message_list:
+                logging.info(f"Fixed {len(message_list)} messages missing tokenized field in memory: {m.name}.")
+    logging.info("Fix missing tokenized memory done.")
+
+
 def judge_system_prompt_is_default(system_prompt: str, memory_type: int|list[str]):
     memory_type_list = memory_type if isinstance(memory_type, list) else get_memory_type_human(memory_type)
     return system_prompt == PromptAssembler.assemble_system_prompt({"memory_type": memory_type_list})
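
Note that `fix_missing_tokenized_memory` never calls the tokenizer itself: it rewrites each affected message's existing `content` through `MessageService.update_message`, and the Elasticsearch connector's `update` path (changed further down in this diff) recomputes the tokenized field whenever `content` is rewritten. A compressed sketch of that recompute step, reusing the `rag_tokenizer` functions imported in the es_conn hunk below (the helper name is illustrative):

```python
from rag.nlp.rag_tokenizer import tokenize, fine_grained_tokenize

def sketch_retokenize(update_dict: dict) -> dict:
    # "content" is stored as content_ltks; whenever it is rewritten, also derive
    # the fine-grained tokenized variant that zh full-text queries match against.
    if "content_ltks" in update_dict:
        update_dict["tokenized_content_ltks"] = fine_grained_tokenize(tokenize(update_dict["content_ltks"]))
    return update_dict
```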

View File

@@ -247,6 +247,21 @@ class MessageService:
                 return ids_to_remove, current_size
         return ids_to_remove, current_size
 
+    @classmethod
+    def get_missing_field_messages(cls, memory_id: str, uid: str, field_name: str):
+        select_fields = ["message_id", "content"]
+        _index_name = index_name(uid)
+        res = settings.msgStoreConn.get_missing_field_message(
+            select_fields=select_fields,
+            index_name=_index_name,
+            memory_id=memory_id,
+            field_name=field_name
+        )
+        if not res:
+            return []
+        docs = settings.msgStoreConn.get_fields(res, select_fields)
+        return list(docs.values())
+
     @classmethod
     def get_by_message_id(cls, memory_id: str, message_id: int, uid: str):
         index = index_name(uid)

View File

@@ -27,6 +27,7 @@ from common.doc_store.doc_store_base import MatchExpr, OrderByExpr, MatchTextExp
 from common.doc_store.es_conn_base import ESConnectionBase
 from common.float_utils import get_float
 from common.constants import PAGERANK_FLD, TAG_FLD
+from rag.nlp.rag_tokenizer import tokenize, fine_grained_tokenize
 
 
 ATTEMPT_TIME = 2
@@ -35,13 +36,15 @@ ATTEMPT_TIME = 2
 
 class ESConnection(ESConnectionBase):
     @staticmethod
-    def convert_field_name(field_name: str) -> str:
+    def convert_field_name(field_name: str, use_tokenized_content=False) -> str:
         match field_name:
             case "message_type":
                 return "message_type_kwd"
             case "status":
                 return "status_int"
             case "content":
+                if use_tokenized_content:
+                    return "tokenized_content_ltks"
                 return "content_ltks"
             case _:
                 return field_name
@@ -69,6 +72,7 @@ class ESConnection(ESConnectionBase):
             "status_int": 1 if message["status"] else 0,
             "zone_id": message.get("zone_id", 0),
             "content_ltks": message["content"],
+            "tokenized_content_ltks": fine_grained_tokenize(tokenize(message["content"])),
             f"q_{len(message['content_embed'])}_vec": message["content_embed"],
         }
         return storage_doc
@@ -166,7 +170,7 @@
                 minimum_should_match = m.extra_options.get("minimum_should_match", 0.0)
                 if isinstance(minimum_should_match, float):
                     minimum_should_match = str(int(minimum_should_match * 100)) + "%"
-                bool_query.must.append(Q("query_string", fields=[self.convert_field_name(f) for f in m.fields],
+                bool_query.must.append(Q("query_string", fields=[self.convert_field_name(f, use_tokenized_content=True) for f in m.fields],
                                          type="best_fields", query=m.matching_text,
                                          minimum_should_match=minimum_should_match,
                                          boost=1))
@@ -286,6 +290,51 @@
         self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
         raise Exception("ESConnection.search timeout.")
 
+    def get_missing_field_message(self, select_fields: list[str], index_name: str, memory_id: str, field_name: str, limit: int=512):
+        if not self.index_exist(index_name):
+            return None
+        bool_query = Q("bool", must=[])
+        bool_query.must.append(Q("term", memory_id=memory_id))
+        bool_query.must_not.append(Q("exists", field=field_name))
+        # from old to new
+        order_by = OrderByExpr()
+        order_by.asc("valid_at")
+        # build search
+        s = Search()
+        s = s.query(bool_query)
+        orders = list()
+        for field, order in order_by.fields:
+            order = "asc" if order == 0 else "desc"
+            if field.endswith("_int") or field.endswith("_flt"):
+                order_info = {"order": order, "unmapped_type": "float"}
+            else:
+                order_info = {"order": order, "unmapped_type": "text"}
+            orders.append({field: order_info})
+        s = s.sort(*orders)
+        s = s[:limit]
+        q = s.to_dict()
+
+        # search
+        for i in range(ATTEMPT_TIME):
+            try:
+                res = self.es.search(index=index_name, body=q, timeout="600s", track_total_hits=True, _source=True)
+                if str(res.get("timed_out", "")).lower() == "true":
+                    raise Exception("Es Timeout.")
+                self.logger.debug(f"ESConnection.search {str(index_name)} res: " + str(res))
+                return res
+            except ConnectionTimeout:
+                self.logger.exception("ES request timeout")
+                self._connect()
+                continue
+            except NotFoundError as e:
+                self.logger.debug(f"ESConnection.search {str(index_name)} query: " + str(q) + str(e))
+                return None
+            except Exception as e:
+                self.logger.exception(f"ESConnection.search {str(index_name)} query: " + str(q) + str(e))
+                raise e
+        self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
+        raise Exception("ESConnection.search timeout.")
+
     def get(self, doc_id: str, index_name: str, memory_ids: list[str]) -> dict | None:
         for i in range(ATTEMPT_TIME):
             try:
@@ -345,6 +394,8 @@ class ESConnection(ESConnectionBase):
     def update(self, condition: dict, new_value: dict, index_name: str, memory_id: str) -> bool:
         doc = copy.deepcopy(new_value)
         update_dict = {self.convert_field_name(k): v for k, v in doc.items()}
+        if "content_ltks" in update_dict:
+            update_dict["tokenized_content_ltks"] = fine_grained_tokenize(tokenize(update_dict["content_ltks"]))
         update_dict.pop("id", None)
         condition_dict = {self.convert_field_name(k): v for k, v in condition.items()}
         condition_dict["memory_id"] = memory_id
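
On the search side, the practical effect of passing `use_tokenized_content=True` in the query_string hunk above is that Chinese full-text queries now target the pre-tokenized field instead of the raw content field. A sketch of the resulting clause with the same `elasticsearch_dsl` `Q` helper (the query text and the 30% threshold are example values, not from this PR):

```python
from elasticsearch_dsl import Q

match_clause = Q(
    "query_string",
    fields=["tokenized_content_ltks"],  # convert_field_name("content", use_tokenized_content=True)
    type="best_fields",
    query="你好 世界",                   # example zh query text
    minimum_should_match="30%",
    boost=1,
)
```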

View File

@@ -305,6 +305,36 @@ class InfinityConnection(InfinityConnectionBase):
         self.connPool.release_conn(inf_conn)
         return res
 
+    def get_missing_field_message(self, select_fields: list[str], index_name: str, memory_id: str, field_name: str, limit: int=512):
+        condition = {"memory_id": memory_id, "must_not": {"exists": field_name}}
+        order_by = OrderByExpr()
+        order_by.asc("valid_at_flt")
+        # query
+        inf_conn = self.connPool.get_conn()
+        db_instance = inf_conn.get_database(self.dbName)
+        table_name = f"{index_name}_{memory_id}"
+        table_instance = db_instance.get_table(table_name)
+        column_name_list = [r[0] for r in table_instance.show_columns().rows()]
+        output_fields = [self.convert_message_field_to_infinity(f, column_name_list) for f in select_fields]
+        builder = table_instance.output(output_fields)
+        filter_cond = self.equivalent_condition_to_str(condition, db_instance.get_table(table_name))
+        builder.filter(filter_cond)
+        order_by_expr_list = list()
+        if order_by.fields:
+            for order_field in order_by.fields:
+                order_field_name = self.convert_condition_and_order_field(order_field[0])
+                if order_field[1] == 0:
+                    order_by_expr_list.append((order_field_name, SortType.Asc))
+                else:
+                    order_by_expr_list.append((order_field_name, SortType.Desc))
+            builder.sort(order_by_expr_list)
+        builder.offset(0).limit(limit)
+        mem_res, _ = builder.option({"total_hits_count": True}).to_df()
+        res = self.concat_dataframes(mem_res, output_fields)
+        res.head(limit)
+        self.connPool.release_conn(inf_conn)
+        return res
+
     def get(self, message_id: str, index_name: str, memory_ids: list[str]) -> dict | None:
         inf_conn = self.connPool.get_conn()
         db_instance = inf_conn.get_database(self.dbName)