From 5674d762f71cfd5ab4e0303ca60fc287eea549bf Mon Sep 17 00:00:00 2001 From: buua436 <66937541+buua436@users.noreply.github.com> Date: Thu, 30 Oct 2025 19:06:16 +0800 Subject: [PATCH] Feat:check embedding model api (#10854) ### What problem does this PR solve? change: Randomly sample `check_num` chunks from knowledge base `kb_id`, re-embed them using `embd_id`, and compare with stored vectors via cosine similarity. If `avg_cos_sim > 0.99`, return success (`code=0`); otherwise return business failure (`code=10`). url: `/v1/kb/check_embedding` Request Body: ``` { "kb_id": "", "embd_id": "BAAI/bge-m3@SILICONFLOW", "check_num": 5 } ``` Success Response: ``` { "code": 0, "message": "success", "data": { "summary": { "avg_cos_sim": 0.999999, "sampled": 5, "valid": 5, "max_cos_sim":0.999999,"min_cos_sim":0.999999,"model":"BAAI/bge-m3@SILICONFLOW" }, "results": [ ... ] } } ``` ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- api/apps/kb_app.py | 143 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index b889b95ee..8f177ddb8 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -15,11 +15,15 @@ # import json import logging +import random from flask import request from flask_login import login_required, current_user +import numpy as np +from api.db import LLMType from api.db.services import duplicate_name +from api.db.services.llm_service import LLMBundle from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService @@ -38,6 +42,7 @@ from api.constants import DATASET_NAME_LIMIT from rag.settings import PAGERANK_FLD from rag.utils.redis_conn import REDIS_CONN from rag.utils.storage_factory import STORAGE_IMPL +from rag.utils.doc_store_conn import OrderByExpr @manager.route('/create', methods=['post']) # noqa: F821 @@ -788,3 +793,141 @@ def delete_kb_task(): return server_error_response(f"Internal error: cannot delete task {pipeline_task_type}") return get_json_result(data=True) + +@manager.route("/check_embedding", methods=["post"]) # noqa: F821 +@login_required +def check_embedding(): + + def _guess_vec_field(src: dict) -> str | None: + for k in src or {}: + if k.endswith("_vec"): + return k + return None + + def _as_float_vec(v): + if v is None: + return [] + if isinstance(v, str): + return [float(x) for x in v.split("\t") if x != ""] + if isinstance(v, (list, tuple, np.ndarray)): + return [float(x) for x in v] + return [] + + def _to_1d(x): + a = np.asarray(x, dtype=np.float32) + return a.reshape(-1) + + def _cos_sim(a, b, eps=1e-12): + a = _to_1d(a) + b = _to_1d(b) + na = np.linalg.norm(a) + nb = np.linalg.norm(b) + if na < eps or nb < eps: + return 0.0 + return float(np.dot(a, b) / (na * nb)) + + def sample_random_chunks_with_vectors( + docStoreConn, + tenant_id: str, + kb_id: str, + n: int = 5, + base_fields=("docnm_kwd","doc_id","content_with_weight","page_num_int","position_int","top_int"), + ): + index_nm = search.index_name(tenant_id) + + res0 = docStoreConn.search( + selectFields=[], highlightFields=[], + condition={"kb_id": kb_id, "available_int": 1}, + matchExprs=[], orderBy=OrderByExpr(), + offset=0, limit=1, + indexNames=index_nm, knowledgebaseIds=[kb_id] + ) + total = docStoreConn.getTotal(res0) + if total <= 0: + return [] + + n = min(n, total) + offsets = sorted(random.sample(range(total), n)) + out = [] + + for off in offsets: + res1 = docStoreConn.search( + selectFields=list(base_fields), + highlightFields=[], + condition={"kb_id": kb_id, "available_int": 1}, + matchExprs=[], orderBy=OrderByExpr(), + offset=off, limit=1, + indexNames=index_nm, knowledgebaseIds=[kb_id] + ) + ids = docStoreConn.getChunkIds(res1) + if not ids: + continue + + cid = ids[0] + full_doc = docStoreConn.get(cid, index_nm, [kb_id]) or {} + vec_field = _guess_vec_field(full_doc) + vec = _as_float_vec(full_doc.get(vec_field)) + + out.append({ + "chunk_id": cid, + "kb_id": kb_id, + "doc_id": full_doc.get("doc_id"), + "doc_name": full_doc.get("docnm_kwd"), + "vector_field": vec_field, + "vector_dim": len(vec), + "vector": vec, + "page_num_int": full_doc.get("page_num_int"), + "position_int": full_doc.get("position_int"), + "top_int": full_doc.get("top_int"), + "content_with_weight": full_doc.get("content_with_weight") or "", + }) + return out + req = request.json + kb_id = req.get("kb_id", "") + embd_id = req.get("embd_id", "") + n = int(req.get("check_num", 5)) + _, kb = KnowledgebaseService.get_by_id(kb_id) + tenant_id = kb.tenant_id + + emb_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id) + samples = sample_random_chunks_with_vectors(settings.docStoreConn, tenant_id=tenant_id, kb_id=kb_id, n=n) + + results, eff_sims = [], [] + for ck in samples: + txt = (ck.get("content_with_weight") or "").strip() + if not txt: + results.append({"chunk_id": ck["chunk_id"], "reason": "no_text"}) + continue + + if not ck.get("vector"): + results.append({"chunk_id": ck["chunk_id"], "reason": "no_stored_vector"}) + continue + + try: + qv, _ = emb_mdl.encode_queries(txt) + sim = _cos_sim(qv, ck["vector"]) + except Exception: + return get_error_data_result(message="embedding failure") + + eff_sims.append(sim) + results.append({ + "chunk_id": ck["chunk_id"], + "doc_id": ck["doc_id"], + "doc_name": ck["doc_name"], + "vector_field": ck["vector_field"], + "vector_dim": ck["vector_dim"], + "cos_sim": round(sim, 6), + }) + + summary = { + "kb_id": kb_id, + "model": embd_id, + "sampled": len(samples), + "valid": len(eff_sims), + "avg_cos_sim": round(float(np.mean(eff_sims)) if eff_sims else 0.0, 6), + "min_cos_sim": round(float(np.min(eff_sims)) if eff_sims else 0.0, 6), + "max_cos_sim": round(float(np.max(eff_sims)) if eff_sims else 0.0, 6), + } + if summary["avg_cos_sim"] > 0.99: + return get_json_result(data={"summary": summary, "results": results}) + return get_json_result(code=settings.RetCode.NOT_EFFECTIVE, message="failed", data={"summary": summary, "results": results})