Feat:check embedding model api (#10854)

### What problem does this PR solve?
change:
Randomly sample `check_num` chunks from knowledge base `kb_id`, re-embed
them using `embd_id`, and compare with stored vectors via cosine
similarity. If `avg_cos_sim > 0.99`, return success (`code=0`);
otherwise return business failure (`code=10`).

url:
`/v1/kb/check_embedding`

Request Body:
```
{
  "kb_id": "<dataset_id>",
  "embd_id": "BAAI/bge-m3@SILICONFLOW",
  "check_num": 5
}

```
Success Response:
```
{
  "code": 0,
  "message": "success",
  "data": {
    "summary": { "avg_cos_sim": 0.999999, "sampled": 5, "valid": 5, "max_cos_sim":0.999999,"min_cos_sim":0.999999,"model":"BAAI/bge-m3@SILICONFLOW" },
    "results": [ ... ]
  }
}
```

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
buua436
2025-10-30 19:06:16 +08:00
committed by GitHub
parent fa38aed01b
commit 5674d762f7

View File

@ -15,11 +15,15 @@
# #
import json import json
import logging import logging
import random
from flask import request from flask import request
from flask_login import login_required, current_user from flask_login import login_required, current_user
import numpy as np
from api.db import LLMType
from api.db.services import duplicate_name from api.db.services import duplicate_name
from api.db.services.llm_service import LLMBundle
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
from api.db.services.file2document_service import File2DocumentService from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService from api.db.services.file_service import FileService
@ -38,6 +42,7 @@ from api.constants import DATASET_NAME_LIMIT
from rag.settings import PAGERANK_FLD from rag.settings import PAGERANK_FLD
from rag.utils.redis_conn import REDIS_CONN from rag.utils.redis_conn import REDIS_CONN
from rag.utils.storage_factory import STORAGE_IMPL from rag.utils.storage_factory import STORAGE_IMPL
from rag.utils.doc_store_conn import OrderByExpr
@manager.route('/create', methods=['post']) # noqa: F821 @manager.route('/create', methods=['post']) # noqa: F821
@ -788,3 +793,141 @@ def delete_kb_task():
return server_error_response(f"Internal error: cannot delete task {pipeline_task_type}") return server_error_response(f"Internal error: cannot delete task {pipeline_task_type}")
return get_json_result(data=True) return get_json_result(data=True)
@manager.route("/check_embedding", methods=["post"]) # noqa: F821
@login_required
def check_embedding():
def _guess_vec_field(src: dict) -> str | None:
for k in src or {}:
if k.endswith("_vec"):
return k
return None
def _as_float_vec(v):
if v is None:
return []
if isinstance(v, str):
return [float(x) for x in v.split("\t") if x != ""]
if isinstance(v, (list, tuple, np.ndarray)):
return [float(x) for x in v]
return []
def _to_1d(x):
a = np.asarray(x, dtype=np.float32)
return a.reshape(-1)
def _cos_sim(a, b, eps=1e-12):
a = _to_1d(a)
b = _to_1d(b)
na = np.linalg.norm(a)
nb = np.linalg.norm(b)
if na < eps or nb < eps:
return 0.0
return float(np.dot(a, b) / (na * nb))
def sample_random_chunks_with_vectors(
docStoreConn,
tenant_id: str,
kb_id: str,
n: int = 5,
base_fields=("docnm_kwd","doc_id","content_with_weight","page_num_int","position_int","top_int"),
):
index_nm = search.index_name(tenant_id)
res0 = docStoreConn.search(
selectFields=[], highlightFields=[],
condition={"kb_id": kb_id, "available_int": 1},
matchExprs=[], orderBy=OrderByExpr(),
offset=0, limit=1,
indexNames=index_nm, knowledgebaseIds=[kb_id]
)
total = docStoreConn.getTotal(res0)
if total <= 0:
return []
n = min(n, total)
offsets = sorted(random.sample(range(total), n))
out = []
for off in offsets:
res1 = docStoreConn.search(
selectFields=list(base_fields),
highlightFields=[],
condition={"kb_id": kb_id, "available_int": 1},
matchExprs=[], orderBy=OrderByExpr(),
offset=off, limit=1,
indexNames=index_nm, knowledgebaseIds=[kb_id]
)
ids = docStoreConn.getChunkIds(res1)
if not ids:
continue
cid = ids[0]
full_doc = docStoreConn.get(cid, index_nm, [kb_id]) or {}
vec_field = _guess_vec_field(full_doc)
vec = _as_float_vec(full_doc.get(vec_field))
out.append({
"chunk_id": cid,
"kb_id": kb_id,
"doc_id": full_doc.get("doc_id"),
"doc_name": full_doc.get("docnm_kwd"),
"vector_field": vec_field,
"vector_dim": len(vec),
"vector": vec,
"page_num_int": full_doc.get("page_num_int"),
"position_int": full_doc.get("position_int"),
"top_int": full_doc.get("top_int"),
"content_with_weight": full_doc.get("content_with_weight") or "",
})
return out
req = request.json
kb_id = req.get("kb_id", "")
embd_id = req.get("embd_id", "")
n = int(req.get("check_num", 5))
_, kb = KnowledgebaseService.get_by_id(kb_id)
tenant_id = kb.tenant_id
emb_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
samples = sample_random_chunks_with_vectors(settings.docStoreConn, tenant_id=tenant_id, kb_id=kb_id, n=n)
results, eff_sims = [], []
for ck in samples:
txt = (ck.get("content_with_weight") or "").strip()
if not txt:
results.append({"chunk_id": ck["chunk_id"], "reason": "no_text"})
continue
if not ck.get("vector"):
results.append({"chunk_id": ck["chunk_id"], "reason": "no_stored_vector"})
continue
try:
qv, _ = emb_mdl.encode_queries(txt)
sim = _cos_sim(qv, ck["vector"])
except Exception:
return get_error_data_result(message="embedding failure")
eff_sims.append(sim)
results.append({
"chunk_id": ck["chunk_id"],
"doc_id": ck["doc_id"],
"doc_name": ck["doc_name"],
"vector_field": ck["vector_field"],
"vector_dim": ck["vector_dim"],
"cos_sim": round(sim, 6),
})
summary = {
"kb_id": kb_id,
"model": embd_id,
"sampled": len(samples),
"valid": len(eff_sims),
"avg_cos_sim": round(float(np.mean(eff_sims)) if eff_sims else 0.0, 6),
"min_cos_sim": round(float(np.min(eff_sims)) if eff_sims else 0.0, 6),
"max_cos_sim": round(float(np.max(eff_sims)) if eff_sims else 0.0, 6),
}
if summary["avg_cos_sim"] > 0.99:
return get_json_result(data={"summary": summary, "results": results})
return get_json_result(code=settings.RetCode.NOT_EFFECTIVE, message="failed", data={"summary": summary, "results": results})