Mirror of https://github.com/infiniflow/ragflow.git (synced 2026-01-27 13:46:39 +08:00)
Fix: add tokenized content (#12793)
### What problem does this PR solve?

Add a tokenized content field to the Elasticsearch message index so that full-text queries against Chinese (zh) messages match correctly.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
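For context before the diff: the fix has three parts. New messages get a `tokenized_content_ltks` field computed at write time, full-text matching on `content` is redirected to that field, and a startup backfill re-tokenizes existing messages that lack it. A minimal sketch of the tokenization step, using the `rag.nlp.rag_tokenizer` helpers this PR imports (the sample text is illustrative):

```python
# Sketch: deriving the new ES field from raw message content.
# tokenize() and fine_grained_tokenize() are the helpers imported in the diff below.
from rag.nlp.rag_tokenizer import tokenize, fine_grained_tokenize

content = "今天天气怎么样"  # illustrative Chinese message
doc = {
    "content_ltks": content,  # raw text, stored as before
    # new analyzed field; same expression the diff uses when storing and updating messages
    "tokenized_content_ltks": fine_grained_tokenize(tokenize(content)),
}
```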
```diff
@@ -31,7 +31,7 @@ from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
 from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm
 from api.db.services.user_service import TenantService, UserTenantService
 from api.db.services.system_settings_service import SystemSettingsService
-from api.db.joint_services.memory_message_service import init_message_id_sequence, init_memory_size_cache
+from api.db.joint_services.memory_message_service import init_message_id_sequence, init_memory_size_cache, fix_missing_tokenized_memory
 from common.constants import LLMType
 from common.file_utils import get_project_base_directory
 from common import settings
```
```diff
@@ -175,6 +175,7 @@ def init_web_data():
     add_graph_templates()
     init_message_id_sequence()
     init_memory_size_cache()
+    fix_missing_tokenized_memory()
     logging.info("init web data success:{}".format(time.time() - start_time))
 
 def init_table():
```
```diff
@@ -306,6 +306,24 @@ def init_memory_size_cache():
     logging.info("Memory size cache init done.")
 
 
+def fix_missing_tokenized_memory():
+    if settings.DOC_ENGINE != "elasticsearch":
+        logging.info("Not using elasticsearch as doc engine, no need to fix missing tokenized memory.")
+        return
+    memory_list = MemoryService.get_all_memory()
+    if not memory_list:
+        logging.info("No memory found, no need to fix missing tokenized memory.")
+    else:
+        for m in memory_list:
+            message_list = MessageService.get_missing_field_messages(m.id, m.tenant_id, "tokenized_content_ltks")
+            for msg in message_list:
+                # update content to refresh tokenized field
+                MessageService.update_message({"message_id": msg["message_id"], "memory_id": m.id}, {"content": msg["content"]}, m.tenant_id, m.id)
+            if message_list:
+                logging.info(f"Fixed {len(message_list)} messages missing tokenized field in memory: {m.name}.")
+    logging.info("Fix missing tokenized memory done.")
+
+
 def judge_system_prompt_is_default(system_prompt: str, memory_type: int|list[str]):
     memory_type_list = memory_type if isinstance(memory_type, list) else get_memory_type_human(memory_type)
     return system_prompt == PromptAssembler.assemble_system_prompt({"memory_type": memory_type_list})
```
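Note the design choice here: the backfill does not tokenize anything itself. It re-saves each message's unchanged `content`, and the Elasticsearch `update()` path (patched further down) recomputes `tokenized_content_ltks` whenever `content_ltks` is written. A hedged sketch of one repair step, with names taken from the diff:

```python
# Re-writing identical content is enough to materialize the missing field.
MessageService.update_message(
    {"message_id": msg["message_id"], "memory_id": m.id},  # which message to touch
    {"content": msg["content"]},                           # same content, re-saved
    m.tenant_id,
    m.id,
)
# ...inside ESConnection.update() this then triggers:
#   update_dict["tokenized_content_ltks"] = fine_grained_tokenize(tokenize(update_dict["content_ltks"]))
```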
```diff
@@ -247,6 +247,21 @@ class MessageService:
                 return ids_to_remove, current_size
         return ids_to_remove, current_size
 
+    @classmethod
+    def get_missing_field_messages(cls, memory_id: str, uid: str, field_name: str):
+        select_fields = ["message_id", "content"]
+        _index_name = index_name(uid)
+        res = settings.msgStoreConn.get_missing_field_message(
+            select_fields=select_fields,
+            index_name=_index_name,
+            memory_id=memory_id,
+            field_name=field_name
+        )
+        if not res:
+            return []
+        docs = settings.msgStoreConn.get_fields(res, select_fields)
+        return list(docs.values())
+
     @classmethod
     def get_by_message_id(cls, memory_id: str, message_id: int, uid: str):
         index = index_name(uid)
```
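For illustration, a hedged example of calling the new service method (the identifier values are placeholders):

```python
# Hypothetical IDs for illustration only.
missing = MessageService.get_missing_field_messages(
    memory_id="mem_001",
    uid="tenant_abc",
    field_name="tokenized_content_ltks",
)
# Returns [{"message_id": ..., "content": ...}, ...] via the configured
# msgStoreConn backend, or [] when every message already has the field.
```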
```diff
@@ -27,6 +27,7 @@ from common.doc_store.doc_store_base import MatchExpr, OrderByExpr, MatchTextExpr
 from common.doc_store.es_conn_base import ESConnectionBase
 from common.float_utils import get_float
 from common.constants import PAGERANK_FLD, TAG_FLD
+from rag.nlp.rag_tokenizer import tokenize, fine_grained_tokenize
 
 ATTEMPT_TIME = 2
 
```
```diff
@@ -35,13 +36,15 @@ ATTEMPT_TIME = 2
 class ESConnection(ESConnectionBase):
 
     @staticmethod
-    def convert_field_name(field_name: str) -> str:
+    def convert_field_name(field_name: str, use_tokenized_content=False) -> str:
         match field_name:
             case "message_type":
                 return "message_type_kwd"
             case "status":
                 return "status_int"
             case "content":
+                if use_tokenized_content:
+                    return "tokenized_content_ltks"
                 return "content_ltks"
             case _:
                 return field_name
```
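The flag keeps existing callers untouched: reads and writes that address `content` still map to `content_ltks`, and only full-text matching opts into the tokenized field. For example:

```python
ESConnection.convert_field_name("content")                               # -> "content_ltks"
ESConnection.convert_field_name("content", use_tokenized_content=True)   # -> "tokenized_content_ltks"
ESConnection.convert_field_name("message_type")                          # -> "message_type_kwd"
```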
```diff
@@ -69,6 +72,7 @@ class ESConnection(ESConnectionBase):
             "status_int": 1 if message["status"] else 0,
             "zone_id": message.get("zone_id", 0),
             "content_ltks": message["content"],
+            "tokenized_content_ltks": fine_grained_tokenize(tokenize(message["content"])),
             f"q_{len(message['content_embed'])}_vec": message["content_embed"],
         }
         return storage_doc
```
```diff
@@ -166,7 +170,7 @@
             minimum_should_match = m.extra_options.get("minimum_should_match", 0.0)
             if isinstance(minimum_should_match, float):
                 minimum_should_match = str(int(minimum_should_match * 100)) + "%"
-            bool_query.must.append(Q("query_string", fields=[self.convert_field_name(f) for f in m.fields],
+            bool_query.must.append(Q("query_string", fields=[self.convert_field_name(f, use_tokenized_content=True) for f in m.fields],
                                      type="best_fields", query=m.matching_text,
                                      minimum_should_match=minimum_should_match,
                                      boost=1))
```
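With the flag set, the clause sent to Elasticsearch now targets the analyzed field. Roughly the emitted fragment (the query text and percentage are illustrative):

```python
# Approximate query_string clause after this change:
clause = {
    "query_string": {
        "fields": ["tokenized_content_ltks"],  # was ["content_ltks"]
        "type": "best_fields",
        "query": "今天 天气",
        "minimum_should_match": "30%",
        "boost": 1,
    }
}
```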
```diff
@@ -286,6 +290,51 @@
         self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
         raise Exception("ESConnection.search timeout.")
 
+    def get_missing_field_message(self, select_fields: list[str], index_name: str, memory_id: str, field_name: str, limit: int=512):
+        if not self.index_exist(index_name):
+            return None
+        bool_query = Q("bool", must=[])
+        bool_query.must.append(Q("term", memory_id=memory_id))
+        bool_query.must_not.append(Q("exists", field=field_name))
+        # from old to new
+        order_by = OrderByExpr()
+        order_by.asc("valid_at")
+        # build search
+        s = Search()
+        s = s.query(bool_query)
+        orders = list()
+        for field, order in order_by.fields:
+            order = "asc" if order == 0 else "desc"
+            if field.endswith("_int") or field.endswith("_flt"):
+                order_info = {"order": order, "unmapped_type": "float"}
+            else:
+                order_info = {"order": order, "unmapped_type": "text"}
+            orders.append({field: order_info})
+        s = s.sort(*orders)
+        s = s[:limit]
+        q = s.to_dict()
+        # search
+        for i in range(ATTEMPT_TIME):
+            try:
+                res = self.es.search(index=index_name, body=q, timeout="600s", track_total_hits=True, _source=True)
+                if str(res.get("timed_out", "")).lower() == "true":
+                    raise Exception("Es Timeout.")
+                self.logger.debug(f"ESConnection.search {str(index_name)} res: " + str(res))
+                return res
+            except ConnectionTimeout:
+                self.logger.exception("ES request timeout")
+                self._connect()
+                continue
+            except NotFoundError as e:
+                self.logger.debug(f"ESConnection.search {str(index_name)} query: " + str(q) + str(e))
+                return None
+            except Exception as e:
+                self.logger.exception(f"ESConnection.search {str(index_name)} query: " + str(q) + str(e))
+                raise e
+
+        self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
+        raise Exception("ESConnection.search timeout.")
+
     def get(self, doc_id: str, index_name: str, memory_ids: list[str]) -> dict | None:
         for i in range(ATTEMPT_TIME):
             try:
```
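The `Search` object above serializes to a plain bool query: restrict to one memory, keep only documents where the field is absent, oldest first. Approximately what `s.to_dict()` produces for `field_name="tokenized_content_ltks"` (the memory id is a placeholder):

```python
query_body = {
    "query": {
        "bool": {
            "must": [{"term": {"memory_id": "<memory_id>"}}],
            "must_not": [{"exists": {"field": "tokenized_content_ltks"}}],
        }
    },
    # "valid_at" ends in neither "_int" nor "_flt", so unmapped_type is "text"
    "sort": [{"valid_at": {"order": "asc", "unmapped_type": "text"}}],
}
# plus the size limit contributed by s[:limit] (512 by default)
```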
```diff
@@ -345,6 +394,8 @@
     def update(self, condition: dict, new_value: dict, index_name: str, memory_id: str) -> bool:
         doc = copy.deepcopy(new_value)
         update_dict = {self.convert_field_name(k): v for k, v in doc.items()}
+        if "content_ltks" in update_dict:
+            update_dict["tokenized_content_ltks"] = fine_grained_tokenize(tokenize(update_dict["content_ltks"]))
         update_dict.pop("id", None)
         condition_dict = {self.convert_field_name(k): v for k, v in condition.items()}
         condition_dict["memory_id"] = memory_id
```
```diff
@@ -305,6 +305,36 @@ class InfinityConnection(InfinityConnectionBase):
         self.connPool.release_conn(inf_conn)
         return res
 
+    def get_missing_field_message(self, select_fields: list[str], index_name: str, memory_id: str, field_name: str, limit: int=512):
+        condition = {"memory_id": memory_id, "must_not": {"exists": field_name}}
+        order_by = OrderByExpr()
+        order_by.asc("valid_at_flt")
+        # query
+        inf_conn = self.connPool.get_conn()
+        db_instance = inf_conn.get_database(self.dbName)
+        table_name = f"{index_name}_{memory_id}"
+        table_instance = db_instance.get_table(table_name)
+        column_name_list = [r[0] for r in table_instance.show_columns().rows()]
+        output_fields = [self.convert_message_field_to_infinity(f, column_name_list) for f in select_fields]
+        builder = table_instance.output(output_fields)
+        filter_cond = self.equivalent_condition_to_str(condition, db_instance.get_table(table_name))
+        builder.filter(filter_cond)
+        order_by_expr_list = list()
+        if order_by.fields:
+            for order_field in order_by.fields:
+                order_field_name = self.convert_condition_and_order_field(order_field[0])
+                if order_field[1] == 0:
+                    order_by_expr_list.append((order_field_name, SortType.Asc))
+                else:
+                    order_by_expr_list.append((order_field_name, SortType.Desc))
+            builder.sort(order_by_expr_list)
+        builder.offset(0).limit(limit)
+        mem_res, _ = builder.option({"total_hits_count": True}).to_df()
+        res = self.concat_dataframes(mem_res, output_fields)
+        res.head(limit)
+        self.connPool.release_conn(inf_conn)
+        return res
+
     def get(self, message_id: str, index_name: str, memory_ids: list[str]) -> dict | None:
         inf_conn = self.connPool.get_conn()
         db_instance = inf_conn.get_database(self.dbName)
```
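Both backends now expose `get_missing_field_message` with the same signature, so `MessageService` stays backend-agnostic even though the startup backfill itself only runs when `DOC_ENGINE` is `elasticsearch`. A hedged sketch of the shared call shape, with the wiring assumed from the service-layer hunk above (placeholder values):

```python
# Backend-agnostic dispatch, as used by MessageService.get_missing_field_messages():
res = settings.msgStoreConn.get_missing_field_message(
    select_fields=["message_id", "content"],
    index_name=index_name(uid),  # per-tenant index, as in the diff
    memory_id=memory_id,
    field_name="tokenized_content_ltks",
)
# ESConnection returns a raw search response (or None); InfinityConnection
# returns a dataframe-backed result. Both are then normalized through
# settings.msgStoreConn.get_fields(res, select_fields).
```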