mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 12:32:30 +08:00
Feat: update check_embedding api (#11254)
### What problem does this PR solve?

pr: #10854
change: update check_embedding api

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -16,6 +16,7 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
|
import re
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import login_required, current_user
|
from flask_login import login_required, current_user
|
||||||
@ -847,8 +848,13 @@ def check_embedding():
|
|||||||
"position_int": full_doc.get("position_int"),
|
"position_int": full_doc.get("position_int"),
|
||||||
"top_int": full_doc.get("top_int"),
|
"top_int": full_doc.get("top_int"),
|
||||||
"content_with_weight": full_doc.get("content_with_weight") or "",
|
"content_with_weight": full_doc.get("content_with_weight") or "",
|
||||||
|
"question_kwd": full_doc.get("question_kwd") or []
|
||||||
})
|
})
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def _clean(s: str) -> str:
|
||||||
|
s = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", s or "")
|
||||||
|
return s if s else "None"
|
||||||
req = request.json
|
req = request.json
|
||||||
kb_id = req.get("kb_id", "")
|
kb_id = req.get("kb_id", "")
|
||||||
embd_id = req.get("embd_id", "")
|
embd_id = req.get("embd_id", "")
|
||||||
@ -861,8 +867,10 @@ def check_embedding():
|
|||||||
|
|
||||||
results, eff_sims = [], []
|
results, eff_sims = [], []
|
||||||
for ck in samples:
|
for ck in samples:
|
||||||
txt = (ck.get("content_with_weight") or "").strip()
|
title = ck.get("doc_name") or "Title"
|
||||||
if not txt:
|
txt_in = "\n".join(ck.get("question_kwd") or []) or ck.get("content_with_weight") or ""
|
||||||
|
txt_in = _clean(txt_in)
|
||||||
|
if not txt_in:
|
||||||
results.append({"chunk_id": ck["chunk_id"], "reason": "no_text"})
|
results.append({"chunk_id": ck["chunk_id"], "reason": "no_text"})
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -871,8 +879,16 @@ def check_embedding():
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
qv, _ = emb_mdl.encode_queries(txt)
|
v, _ = emb_mdl.encode([title, txt_in])
|
||||||
sim = _cos_sim(qv, ck["vector"])
|
sim_content = _cos_sim(v[1], ck["vector"])
|
||||||
|
title_w = 0.1
|
||||||
|
qv_mix = title_w * v[0] + (1 - title_w) * v[1]
|
||||||
|
sim_mix = _cos_sim(qv_mix, ck["vector"])
|
||||||
|
sim = sim_content
|
||||||
|
mode = "content_only"
|
||||||
|
if sim_mix > sim:
|
||||||
|
sim = sim_mix
|
||||||
|
mode = "title+content"
|
||||||
except Exception:
|
except Exception:
|
||||||
return get_error_data_result(message="embedding failure")
|
return get_error_data_result(message="embedding failure")
|
||||||
|
|
||||||
@ -894,8 +910,9 @@ def check_embedding():
|
|||||||
"avg_cos_sim": round(float(np.mean(eff_sims)) if eff_sims else 0.0, 6),
|
"avg_cos_sim": round(float(np.mean(eff_sims)) if eff_sims else 0.0, 6),
|
||||||
"min_cos_sim": round(float(np.min(eff_sims)) if eff_sims else 0.0, 6),
|
"min_cos_sim": round(float(np.min(eff_sims)) if eff_sims else 0.0, 6),
|
||||||
"max_cos_sim": round(float(np.max(eff_sims)) if eff_sims else 0.0, 6),
|
"max_cos_sim": round(float(np.max(eff_sims)) if eff_sims else 0.0, 6),
|
||||||
|
"match_mode": mode,
|
||||||
}
|
}
|
||||||
if summary["avg_cos_sim"] > 0.99:
|
if summary["avg_cos_sim"] > 0.9:
|
||||||
return get_json_result(data={"summary": summary, "results": results})
|
return get_json_result(data={"summary": summary, "results": results})
|
||||||
return get_json_result(code=RetCode.NOT_EFFECTIVE, message="failed", data={"summary": summary, "results": results})
|
return get_json_result(code=RetCode.NOT_EFFECTIVE, message="failed", data={"summary": summary, "results": results})
|
||||||
|
|
||||||
|
|||||||
@ -442,7 +442,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
|
|||||||
tk_count = 0
|
tk_count = 0
|
||||||
if len(tts) == len(cnts):
|
if len(tts) == len(cnts):
|
||||||
vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(tts[0: 1]))
|
vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(tts[0: 1]))
|
||||||
tts = np.concatenate([vts[0] for _ in range(len(tts))], axis=0)
|
tts = np.tile(vts[0], (len(cnts), 1))
|
||||||
tk_count += c
|
tk_count += c
|
||||||
|
|
||||||
@timeout(60)
|
@timeout(60)
|
||||||
@ -465,8 +465,10 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
|
|||||||
if not filename_embd_weight:
|
if not filename_embd_weight:
|
||||||
filename_embd_weight = 0.1
|
filename_embd_weight = 0.1
|
||||||
title_w = float(filename_embd_weight)
|
title_w = float(filename_embd_weight)
|
||||||
vects = (title_w * tts + (1 - title_w) *
|
if tts.ndim == 2 and cnts.ndim == 2 and tts.shape == cnts.shape:
|
||||||
cnts) if len(tts) == len(cnts) else cnts
|
vects = title_w * tts + (1 - title_w) * cnts
|
||||||
|
else:
|
||||||
|
vects = cnts
|
||||||
|
|
||||||
assert len(vects) == len(docs)
|
assert len(vects) == len(docs)
|
||||||
vector_size = 0
|
vector_size = 0
|
||||||
|
|||||||
Reference in New Issue
Block a user