mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
fix bug about fetching knowledge graph (#3394)
### What problem does this PR solve? ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@ -13,7 +13,8 @@ from rag import settings
|
||||
from rag.utils import singleton
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
import polars as pl
|
||||
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, FusionExpr
|
||||
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
|
||||
FusionExpr
|
||||
from rag.nlp import is_english, rag_tokenizer
|
||||
|
||||
|
||||
@ -26,7 +27,8 @@ class ESConnection(DocStoreConnection):
|
||||
try:
|
||||
self.es = Elasticsearch(
|
||||
settings.ES["hosts"].split(","),
|
||||
basic_auth=(settings.ES["username"], settings.ES["password"]) if "username" in settings.ES and "password" in settings.ES else None,
|
||||
basic_auth=(settings.ES["username"], settings.ES[
|
||||
"password"]) if "username" in settings.ES and "password" in settings.ES else None,
|
||||
verify_certs=False,
|
||||
timeout=600
|
||||
)
|
||||
@ -57,6 +59,7 @@ class ESConnection(DocStoreConnection):
|
||||
"""
|
||||
Database operations
|
||||
"""
|
||||
|
||||
def dbType(self) -> str:
|
||||
return "elasticsearch"
|
||||
|
||||
@ -66,6 +69,7 @@ class ESConnection(DocStoreConnection):
|
||||
"""
|
||||
Table operations
|
||||
"""
|
||||
|
||||
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
|
||||
if self.indexExist(indexName, knowledgebaseId):
|
||||
return True
|
||||
@ -97,7 +101,10 @@ class ESConnection(DocStoreConnection):
|
||||
"""
|
||||
CRUD operations
|
||||
"""
|
||||
def search(self, selectFields: list[str], highlightFields: list[str], condition: dict, matchExprs: list[MatchExpr], orderBy: OrderByExpr, offset: int, limit: int, indexNames: str|list[str], knowledgebaseIds: list[str]) -> list[dict] | pl.DataFrame:
|
||||
|
||||
def search(self, selectFields: list[str], highlightFields: list[str], condition: dict, matchExprs: list[MatchExpr],
|
||||
orderBy: OrderByExpr, offset: int, limit: int, indexNames: str | list[str],
|
||||
knowledgebaseIds: list[str]) -> list[dict] | pl.DataFrame:
|
||||
"""
|
||||
Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html
|
||||
"""
|
||||
@ -109,8 +116,10 @@ class ESConnection(DocStoreConnection):
|
||||
bqry = None
|
||||
vector_similarity_weight = 0.5
|
||||
for m in matchExprs:
|
||||
if isinstance(m, FusionExpr) and m.method=="weighted_sum" and "weights" in m.fusion_params:
|
||||
assert len(matchExprs)==3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance(matchExprs[1], MatchDenseExpr) and isinstance(matchExprs[2], FusionExpr)
|
||||
if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
|
||||
assert len(matchExprs) == 3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance(matchExprs[1],
|
||||
MatchDenseExpr) and isinstance(
|
||||
matchExprs[2], FusionExpr)
|
||||
weights = m.fusion_params["weights"]
|
||||
vector_similarity_weight = float(weights.split(",")[1])
|
||||
for m in matchExprs:
|
||||
@ -119,36 +128,41 @@ class ESConnection(DocStoreConnection):
|
||||
if "minimum_should_match" in m.extra_options:
|
||||
minimum_should_match = str(int(m.extra_options["minimum_should_match"] * 100)) + "%"
|
||||
bqry = Q("bool",
|
||||
must=Q("query_string", fields=m.fields,
|
||||
must=Q("query_string", fields=m.fields,
|
||||
type="best_fields", query=m.matching_text,
|
||||
minimum_should_match = minimum_should_match,
|
||||
minimum_should_match=minimum_should_match,
|
||||
boost=1),
|
||||
boost = 1.0 - vector_similarity_weight,
|
||||
)
|
||||
if condition:
|
||||
for k, v in condition.items():
|
||||
if not isinstance(k, str) or not v:
|
||||
continue
|
||||
if isinstance(v, list):
|
||||
bqry.filter.append(Q("terms", **{k: v}))
|
||||
elif isinstance(v, str) or isinstance(v, int):
|
||||
bqry.filter.append(Q("term", **{k: v}))
|
||||
else:
|
||||
raise Exception(f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
|
||||
boost=1.0 - vector_similarity_weight,
|
||||
)
|
||||
elif isinstance(m, MatchDenseExpr):
|
||||
assert(bqry is not None)
|
||||
assert (bqry is not None)
|
||||
similarity = 0.0
|
||||
if "similarity" in m.extra_options:
|
||||
similarity = m.extra_options["similarity"]
|
||||
s = s.knn(m.vector_column_name,
|
||||
m.topn,
|
||||
m.topn * 2,
|
||||
query_vector = list(m.embedding_data),
|
||||
filter = bqry.to_dict(),
|
||||
similarity = similarity,
|
||||
)
|
||||
if matchExprs:
|
||||
s.query = bqry
|
||||
m.topn,
|
||||
m.topn * 2,
|
||||
query_vector=list(m.embedding_data),
|
||||
filter=bqry.to_dict(),
|
||||
similarity=similarity,
|
||||
)
|
||||
|
||||
if condition:
|
||||
if not bqry:
|
||||
bqry = Q("bool", must=[])
|
||||
for k, v in condition.items():
|
||||
if not isinstance(k, str) or not v:
|
||||
continue
|
||||
if isinstance(v, list):
|
||||
bqry.filter.append(Q("terms", **{k: v}))
|
||||
elif isinstance(v, str) or isinstance(v, int):
|
||||
bqry.filter.append(Q("term", **{k: v}))
|
||||
else:
|
||||
raise Exception(
|
||||
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
|
||||
|
||||
if bqry:
|
||||
s = s.query(bqry)
|
||||
for field in highlightFields:
|
||||
s = s.highlight(field)
|
||||
|
||||
@ -157,12 +171,13 @@ class ESConnection(DocStoreConnection):
|
||||
for field, order in orderBy.fields:
|
||||
order = "asc" if order == 0 else "desc"
|
||||
orders.append({field: {"order": order, "unmapped_type": "float",
|
||||
"mode": "avg", "numeric_type": "double"}})
|
||||
"mode": "avg", "numeric_type": "double"}})
|
||||
s = s.sort(*orders)
|
||||
|
||||
if limit > 0:
|
||||
s = s[offset:limit]
|
||||
q = s.to_dict()
|
||||
print(json.dumps(q), flush=True)
|
||||
# logger.info("ESConnection.search [Q]: " + json.dumps(q))
|
||||
|
||||
for i in range(3):
|
||||
@ -189,7 +204,7 @@ class ESConnection(DocStoreConnection):
|
||||
for i in range(3):
|
||||
try:
|
||||
res = self.es.get(index=(indexName),
|
||||
id=chunkId, source=True,)
|
||||
id=chunkId, source=True, )
|
||||
if str(res.get("timed_out", "")).lower() == "true":
|
||||
raise Exception("Es Timeout.")
|
||||
if not res.get("found"):
|
||||
@ -222,7 +237,7 @@ class ESConnection(DocStoreConnection):
|
||||
for _ in range(100):
|
||||
try:
|
||||
r = self.es.bulk(index=(indexName), operations=operations,
|
||||
refresh=False, timeout="600s")
|
||||
refresh=False, timeout="600s")
|
||||
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
|
||||
return res
|
||||
|
||||
@ -249,7 +264,8 @@ class ESConnection(DocStoreConnection):
|
||||
self.es.update(index=indexName, id=chunkId, doc=doc)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.exception(f"ES failed to update(index={indexName}, id={id}, doc={json.dumps(condition, ensure_ascii=False)})")
|
||||
logger.exception(
|
||||
f"ES failed to update(index={indexName}, id={id}, doc={json.dumps(condition, ensure_ascii=False)})")
|
||||
if str(e).find("Timeout") > 0:
|
||||
continue
|
||||
else:
|
||||
@ -263,7 +279,8 @@ class ESConnection(DocStoreConnection):
|
||||
elif isinstance(v, str) or isinstance(v, int):
|
||||
bqry.filter.append(Q("term", **{k: v}))
|
||||
else:
|
||||
raise Exception(f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
|
||||
raise Exception(
|
||||
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
|
||||
scripts = []
|
||||
for k, v in newValue.items():
|
||||
if not isinstance(k, str) or not v:
|
||||
@ -273,7 +290,8 @@ class ESConnection(DocStoreConnection):
|
||||
elif isinstance(v, int):
|
||||
scripts.append(f"ctx._source.{k} = {v}")
|
||||
else:
|
||||
raise Exception(f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
|
||||
raise Exception(
|
||||
f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
|
||||
ubq = UpdateByQuery(
|
||||
index=indexName).using(
|
||||
self.es).query(bqry)
|
||||
@ -313,7 +331,7 @@ class ESConnection(DocStoreConnection):
|
||||
try:
|
||||
res = self.es.delete_by_query(
|
||||
index=indexName,
|
||||
body = Search().query(qry).to_dict(),
|
||||
body=Search().query(qry).to_dict(),
|
||||
refresh=True)
|
||||
return res["deleted"]
|
||||
except Exception as e:
|
||||
@ -325,10 +343,10 @@ class ESConnection(DocStoreConnection):
|
||||
return 0
|
||||
return 0
|
||||
|
||||
|
||||
"""
|
||||
Helper functions for search result
|
||||
"""
|
||||
|
||||
def getTotal(self, res):
|
||||
if isinstance(res["hits"]["total"], type({})):
|
||||
return res["hits"]["total"]["value"]
|
||||
@ -376,12 +394,13 @@ class ESConnection(DocStoreConnection):
|
||||
continue
|
||||
|
||||
txt = d["_source"][fieldnm]
|
||||
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE|re.MULTILINE)
|
||||
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
|
||||
txts = []
|
||||
for t in re.split(r"[.?!;\n]", txt):
|
||||
for w in keywords:
|
||||
t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"%re.escape(w), r"\1<em>\2</em>\3", t, flags=re.IGNORECASE|re.MULTILINE)
|
||||
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE|re.MULTILINE):
|
||||
t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w), r"\1<em>\2</em>\3", t,
|
||||
flags=re.IGNORECASE | re.MULTILINE)
|
||||
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
|
||||
continue
|
||||
txts.append(t)
|
||||
ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]])
|
||||
@ -395,10 +414,10 @@ class ESConnection(DocStoreConnection):
|
||||
bkts = res["aggregations"][agg_field]["buckets"]
|
||||
return [(b["key"], b["doc_count"]) for b in bkts]
|
||||
|
||||
|
||||
"""
|
||||
SQL
|
||||
"""
|
||||
|
||||
def sql(self, sql: str, fetch_size: int, format: str):
|
||||
logger.info(f"ESConnection.sql get sql: {sql}")
|
||||
sql = re.sub(r"[ `]+", " ", sql)
|
||||
@ -413,7 +432,7 @@ class ESConnection(DocStoreConnection):
|
||||
r.group(1),
|
||||
r.group(2),
|
||||
r.group(3)),
|
||||
match))
|
||||
match))
|
||||
|
||||
for p, r in replaces:
|
||||
sql = sql.replace(p, r, 1)
|
||||
@ -421,7 +440,8 @@ class ESConnection(DocStoreConnection):
|
||||
|
||||
for i in range(3):
|
||||
try:
|
||||
res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format, request_timeout="2s")
|
||||
res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format,
|
||||
request_timeout="2s")
|
||||
return res
|
||||
except ConnectionTimeout:
|
||||
logger.exception("ESConnection.sql timeout [Q]: " + sql)
|
||||
|
||||
Reference in New Issue
Block a user