diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index e570debb2..b7cf58a20 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -16,6 +16,7 @@ import json import logging import random +import re from flask import request from flask_login import login_required, current_user @@ -847,8 +848,13 @@ def check_embedding(): "position_int": full_doc.get("position_int"), "top_int": full_doc.get("top_int"), "content_with_weight": full_doc.get("content_with_weight") or "", + "question_kwd": full_doc.get("question_kwd") or [] }) return out + + def _clean(s: str) -> str: + s = re.sub(r"]{0,12})?>", " ", s or "") + return s if s else "None" req = request.json kb_id = req.get("kb_id", "") embd_id = req.get("embd_id", "") @@ -861,8 +867,10 @@ def check_embedding(): results, eff_sims = [], [] for ck in samples: - txt = (ck.get("content_with_weight") or "").strip() - if not txt: + title = ck.get("doc_name") or "Title" + txt_in = "\n".join(ck.get("question_kwd") or []) or ck.get("content_with_weight") or "" + txt_in = _clean(txt_in) + if not txt_in: results.append({"chunk_id": ck["chunk_id"], "reason": "no_text"}) continue @@ -871,8 +879,16 @@ def check_embedding(): continue try: - qv, _ = emb_mdl.encode_queries(txt) - sim = _cos_sim(qv, ck["vector"]) + v, _ = emb_mdl.encode([title, txt_in]) + sim_content = _cos_sim(v[1], ck["vector"]) + title_w = 0.1 + qv_mix = title_w * v[0] + (1 - title_w) * v[1] + sim_mix = _cos_sim(qv_mix, ck["vector"]) + sim = sim_content + mode = "content_only" + if sim_mix > sim: + sim = sim_mix + mode = "title+content" except Exception: return get_error_data_result(message="embedding failure") @@ -894,8 +910,9 @@ def check_embedding(): "avg_cos_sim": round(float(np.mean(eff_sims)) if eff_sims else 0.0, 6), "min_cos_sim": round(float(np.min(eff_sims)) if eff_sims else 0.0, 6), "max_cos_sim": round(float(np.max(eff_sims)) if eff_sims else 0.0, 6), + "match_mode": mode, } - if summary["avg_cos_sim"] > 0.99: + if summary["avg_cos_sim"] > 0.9: return get_json_result(data={"summary": summary, "results": results}) return get_json_result(code=RetCode.NOT_EFFECTIVE, message="failed", data={"summary": summary, "results": results}) diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index a183bf0cf..370bd2a10 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -442,7 +442,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None): tk_count = 0 if len(tts) == len(cnts): vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(tts[0: 1])) - tts = np.concatenate([vts[0] for _ in range(len(tts))], axis=0) + tts = np.tile(vts[0], (len(cnts), 1)) tk_count += c @timeout(60) @@ -465,8 +465,10 @@ async def embedding(docs, mdl, parser_config=None, callback=None): if not filename_embd_weight: filename_embd_weight = 0.1 title_w = float(filename_embd_weight) - vects = (title_w * tts + (1 - title_w) * - cnts) if len(tts) == len(cnts) else cnts + if tts.ndim == 2 and cnts.ndim == 2 and tts.shape == cnts.shape: + vects = title_w * tts + (1 - title_w) * cnts + else: + vects = cnts assert len(vects) == len(docs) vector_size = 0