mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-01-30 23:26:36 +08:00
Fix: add tokenized content (#12793)
### What problem does this PR solve?

Add a tokenized-content ES field that is used when querying zh (Chinese) messages.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@ -27,6 +27,7 @@ from common.doc_store.doc_store_base import MatchExpr, OrderByExpr, MatchTextExp
|
||||
from common.doc_store.es_conn_base import ESConnectionBase
|
||||
from common.float_utils import get_float
|
||||
from common.constants import PAGERANK_FLD, TAG_FLD
|
||||
from rag.nlp.rag_tokenizer import tokenize, fine_grained_tokenize
|
||||
|
||||
ATTEMPT_TIME = 2
|
||||
|
||||
@ -35,13 +36,15 @@ ATTEMPT_TIME = 2
|
||||
class ESConnection(ESConnectionBase):
|
||||
|
||||
@staticmethod
|
||||
def convert_field_name(field_name: str) -> str:
|
||||
def convert_field_name(field_name: str, use_tokenized_content=False) -> str:
|
||||
match field_name:
|
||||
case "message_type":
|
||||
return "message_type_kwd"
|
||||
case "status":
|
||||
return "status_int"
|
||||
case "content":
|
||||
if use_tokenized_content:
|
||||
return "tokenized_content_ltks"
|
||||
return "content_ltks"
|
||||
case _:
|
||||
return field_name
|
||||
@ -69,6 +72,7 @@ class ESConnection(ESConnectionBase):
|
||||
"status_int": 1 if message["status"] else 0,
|
||||
"zone_id": message.get("zone_id", 0),
|
||||
"content_ltks": message["content"],
|
||||
"tokenized_content_ltks": fine_grained_tokenize(tokenize(message["content"])),
|
||||
f"q_{len(message['content_embed'])}_vec": message["content_embed"],
|
||||
}
|
||||
return storage_doc
|
||||
@ -166,7 +170,7 @@ class ESConnection(ESConnectionBase):
|
||||
minimum_should_match = m.extra_options.get("minimum_should_match", 0.0)
|
||||
if isinstance(minimum_should_match, float):
|
||||
minimum_should_match = str(int(minimum_should_match * 100)) + "%"
|
||||
bool_query.must.append(Q("query_string", fields=[self.convert_field_name(f) for f in m.fields],
|
||||
bool_query.must.append(Q("query_string", fields=[self.convert_field_name(f, use_tokenized_content=True) for f in m.fields],
|
||||
type="best_fields", query=m.matching_text,
|
||||
minimum_should_match=minimum_should_match,
|
||||
boost=1))
|
||||
@ -286,6 +290,51 @@ class ESConnection(ESConnectionBase):
|
||||
self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
|
||||
raise Exception("ESConnection.search timeout.")
|
||||
|
||||
def get_missing_field_message(self, select_fields: list[str], index_name: str, memory_id: str, field_name: str, limit: int=512):
    """Search one memory for messages whose documents lack ``field_name``.

    Hits are sorted by ``valid_at`` ascending (oldest first) and capped at
    ``limit``.

    Args:
        select_fields: Accepted for signature parity with the other
            backend; not applied to the query (see the review note below).
        index_name: Index to search.
        memory_id: Only documents whose ``memory_id`` term matches are
            considered.
        field_name: Stored field that must be absent from a document for
            it to match.
        limit: Maximum number of hits requested.

    Returns:
        The raw search response, or ``None`` when the index does not
        exist or the search raises ``NotFoundError``.

    Raises:
        Exception: When the response reports ``timed_out``, when any
            non-timeout error occurs (re-raised as is), or when every one
            of the ``ATTEMPT_TIME`` attempts hit a connection timeout.
    """
    if not self.index_exist(index_name):
        return None

    bool_query = Q(
        "bool",
        must=[Q("term", memory_id=memory_id)],
        must_not=[Q("exists", field=field_name)],
    )

    # from old to new
    order_by = OrderByExpr()
    order_by.asc("valid_at")
    orders = []
    for field, order in order_by.fields:
        is_numeric = field.endswith(("_int", "_flt"))
        orders.append({
            field: {
                "order": "asc" if order == 0 else "desc",
                "unmapped_type": "float" if is_numeric else "text",
            }
        })

    # build search
    q = Search().query(bool_query).sort(*orders)[:limit].to_dict()

    # NOTE(review): select_fields is not used here; the full _source is
    # requested. Confirm whether callers expect only the selected fields.
    for _ in range(ATTEMPT_TIME):
        try:
            res = self.es.search(index=index_name, body=q, timeout="600s", track_total_hits=True, _source=True)
            if str(res.get("timed_out", "")).lower() == "true":
                raise Exception("Es Timeout.")
            self.logger.debug(f"ESConnection.get_missing_field_message {str(index_name)} res: " + str(res))
            return res
        except ConnectionTimeout:
            # Reconnect and retry; only connection timeouts are retried.
            self.logger.exception("ES request timeout")
            self._connect()
            continue
        except NotFoundError as e:
            self.logger.debug(f"ESConnection.get_missing_field_message {str(index_name)} query: " + str(q) + str(e))
            return None
        except Exception as e:
            self.logger.exception(f"ESConnection.get_missing_field_message {str(index_name)} query: " + str(q) + str(e))
            # Bare raise keeps the original traceback intact.
            raise

    self.logger.error(f"ESConnection.get_missing_field_message timeout for {ATTEMPT_TIME} times!")
    raise Exception("ESConnection.get_missing_field_message timeout.")
|
||||
|
||||
def get(self, doc_id: str, index_name: str, memory_ids: list[str]) -> dict | None:
|
||||
for i in range(ATTEMPT_TIME):
|
||||
try:
|
||||
@ -345,6 +394,8 @@ class ESConnection(ESConnectionBase):
|
||||
def update(self, condition: dict, new_value: dict, index_name: str, memory_id: str) -> bool:
|
||||
doc = copy.deepcopy(new_value)
|
||||
update_dict = {self.convert_field_name(k): v for k, v in doc.items()}
|
||||
if "content_ltks" in update_dict:
|
||||
update_dict["tokenized_content_ltks"] = fine_grained_tokenize(tokenize(update_dict["content_ltks"]))
|
||||
update_dict.pop("id", None)
|
||||
condition_dict = {self.convert_field_name(k): v for k, v in condition.items()}
|
||||
condition_dict["memory_id"] = memory_id
|
||||
|
||||
@ -305,6 +305,36 @@ class InfinityConnection(InfinityConnectionBase):
|
||||
self.connPool.release_conn(inf_conn)
|
||||
return res
|
||||
|
||||
def get_missing_field_message(self, select_fields: list[str], index_name: str, memory_id: str, field_name: str, limit: int=512):
    """Query one memory's table for messages that lack ``field_name``.

    Rows are sorted by ``valid_at_flt`` ascending (oldest first) and capped
    at ``limit``.

    Args:
        select_fields: Message fields to return; each is converted with
            ``convert_message_field_to_infinity`` against the table's
            columns.
        index_name: Combined with ``memory_id`` as
            ``f"{index_name}_{memory_id}"`` to form the table name.
        memory_id: Memory whose table is queried; also part of the filter
            condition.
        field_name: Field named in the ``must_not``/``exists`` condition.
        limit: Maximum number of rows returned.

    Returns:
        The result of ``concat_dataframes`` truncated to ``limit`` rows
        (presumably a pandas DataFrame; confirm against
        ``concat_dataframes``).
    """
    condition = {"memory_id": memory_id, "must_not": {"exists": field_name}}
    # from old to new
    order_by = OrderByExpr()
    order_by.asc("valid_at_flt")

    inf_conn = self.connPool.get_conn()
    try:
        db_instance = inf_conn.get_database(self.dbName)
        table_name = f"{index_name}_{memory_id}"
        table_instance = db_instance.get_table(table_name)
        column_name_list = [r[0] for r in table_instance.show_columns().rows()]
        output_fields = [self.convert_message_field_to_infinity(f, column_name_list) for f in select_fields]

        builder = table_instance.output(output_fields)
        # Reuse the table handle fetched above instead of looking it up again.
        filter_cond = self.equivalent_condition_to_str(condition, table_instance)
        builder.filter(filter_cond)

        if order_by.fields:
            order_by_expr_list = [
                (
                    self.convert_condition_and_order_field(order_field[0]),
                    SortType.Asc if order_field[1] == 0 else SortType.Desc,
                )
                for order_field in order_by.fields
            ]
            builder.sort(order_by_expr_list)
        builder.offset(0).limit(limit)

        mem_res, _ = builder.option({"total_hits_count": True}).to_df()
        res = self.concat_dataframes(mem_res, output_fields)
        # head() returns a new object, so its result must be kept to take effect.
        return res.head(limit)
    finally:
        # Hand the connection back to the pool even when a call above raises.
        self.connPool.release_conn(inf_conn)
|
||||
|
||||
def get(self, message_id: str, index_name: str, memory_ids: list[str]) -> dict | None:
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
|
||||
Reference in New Issue
Block a user