mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 12:32:30 +08:00
Feat:check embedding model api (#10854)
### What problem does this PR solve?
change:
Randomly sample `check_num` chunks from knowledge base `kb_id`, re-embed
them using `embd_id`, and compare with stored vectors via cosine
similarity. If `avg_cos_sim > 0.99`, return success (`code=0`);
otherwise return business failure (`code=10`).
url:
`/v1/kb/check_embedding`
Request Body:
```
{
"kb_id": "<dataset_id>",
"embd_id": "BAAI/bge-m3@SILICONFLOW",
"check_num": 5
}
```
Success Response:
```
{
"code": 0,
"message": "success",
"data": {
"summary": { "avg_cos_sim": 0.999999, "sampled": 5, "valid": 5, "max_cos_sim":0.999999,"min_cos_sim":0.999999,"model":"BAAI/bge-m3@SILICONFLOW" },
"results": [ ... ]
}
}
```
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -15,11 +15,15 @@
|
||||
#
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
import numpy as np
|
||||
|
||||
from api.db import LLMType
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
@ -38,6 +42,7 @@ from api.constants import DATASET_NAME_LIMIT
|
||||
from rag.settings import PAGERANK_FLD
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
from rag.utils.doc_store_conn import OrderByExpr
|
||||
|
||||
|
||||
@manager.route('/create', methods=['post']) # noqa: F821
|
||||
@ -788,3 +793,141 @@ def delete_kb_task():
|
||||
return server_error_response(f"Internal error: cannot delete task {pipeline_task_type}")
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
@manager.route("/check_embedding", methods=["post"]) # noqa: F821
|
||||
@login_required
|
||||
def check_embedding():
|
||||
|
||||
def _guess_vec_field(src: dict) -> str | None:
|
||||
for k in src or {}:
|
||||
if k.endswith("_vec"):
|
||||
return k
|
||||
return None
|
||||
|
||||
def _as_float_vec(v):
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, str):
|
||||
return [float(x) for x in v.split("\t") if x != ""]
|
||||
if isinstance(v, (list, tuple, np.ndarray)):
|
||||
return [float(x) for x in v]
|
||||
return []
|
||||
|
||||
def _to_1d(x):
|
||||
a = np.asarray(x, dtype=np.float32)
|
||||
return a.reshape(-1)
|
||||
|
||||
def _cos_sim(a, b, eps=1e-12):
|
||||
a = _to_1d(a)
|
||||
b = _to_1d(b)
|
||||
na = np.linalg.norm(a)
|
||||
nb = np.linalg.norm(b)
|
||||
if na < eps or nb < eps:
|
||||
return 0.0
|
||||
return float(np.dot(a, b) / (na * nb))
|
||||
|
||||
def sample_random_chunks_with_vectors(
|
||||
docStoreConn,
|
||||
tenant_id: str,
|
||||
kb_id: str,
|
||||
n: int = 5,
|
||||
base_fields=("docnm_kwd","doc_id","content_with_weight","page_num_int","position_int","top_int"),
|
||||
):
|
||||
index_nm = search.index_name(tenant_id)
|
||||
|
||||
res0 = docStoreConn.search(
|
||||
selectFields=[], highlightFields=[],
|
||||
condition={"kb_id": kb_id, "available_int": 1},
|
||||
matchExprs=[], orderBy=OrderByExpr(),
|
||||
offset=0, limit=1,
|
||||
indexNames=index_nm, knowledgebaseIds=[kb_id]
|
||||
)
|
||||
total = docStoreConn.getTotal(res0)
|
||||
if total <= 0:
|
||||
return []
|
||||
|
||||
n = min(n, total)
|
||||
offsets = sorted(random.sample(range(total), n))
|
||||
out = []
|
||||
|
||||
for off in offsets:
|
||||
res1 = docStoreConn.search(
|
||||
selectFields=list(base_fields),
|
||||
highlightFields=[],
|
||||
condition={"kb_id": kb_id, "available_int": 1},
|
||||
matchExprs=[], orderBy=OrderByExpr(),
|
||||
offset=off, limit=1,
|
||||
indexNames=index_nm, knowledgebaseIds=[kb_id]
|
||||
)
|
||||
ids = docStoreConn.getChunkIds(res1)
|
||||
if not ids:
|
||||
continue
|
||||
|
||||
cid = ids[0]
|
||||
full_doc = docStoreConn.get(cid, index_nm, [kb_id]) or {}
|
||||
vec_field = _guess_vec_field(full_doc)
|
||||
vec = _as_float_vec(full_doc.get(vec_field))
|
||||
|
||||
out.append({
|
||||
"chunk_id": cid,
|
||||
"kb_id": kb_id,
|
||||
"doc_id": full_doc.get("doc_id"),
|
||||
"doc_name": full_doc.get("docnm_kwd"),
|
||||
"vector_field": vec_field,
|
||||
"vector_dim": len(vec),
|
||||
"vector": vec,
|
||||
"page_num_int": full_doc.get("page_num_int"),
|
||||
"position_int": full_doc.get("position_int"),
|
||||
"top_int": full_doc.get("top_int"),
|
||||
"content_with_weight": full_doc.get("content_with_weight") or "",
|
||||
})
|
||||
return out
|
||||
req = request.json
|
||||
kb_id = req.get("kb_id", "")
|
||||
embd_id = req.get("embd_id", "")
|
||||
n = int(req.get("check_num", 5))
|
||||
_, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
tenant_id = kb.tenant_id
|
||||
|
||||
emb_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
|
||||
samples = sample_random_chunks_with_vectors(settings.docStoreConn, tenant_id=tenant_id, kb_id=kb_id, n=n)
|
||||
|
||||
results, eff_sims = [], []
|
||||
for ck in samples:
|
||||
txt = (ck.get("content_with_weight") or "").strip()
|
||||
if not txt:
|
||||
results.append({"chunk_id": ck["chunk_id"], "reason": "no_text"})
|
||||
continue
|
||||
|
||||
if not ck.get("vector"):
|
||||
results.append({"chunk_id": ck["chunk_id"], "reason": "no_stored_vector"})
|
||||
continue
|
||||
|
||||
try:
|
||||
qv, _ = emb_mdl.encode_queries(txt)
|
||||
sim = _cos_sim(qv, ck["vector"])
|
||||
except Exception:
|
||||
return get_error_data_result(message="embedding failure")
|
||||
|
||||
eff_sims.append(sim)
|
||||
results.append({
|
||||
"chunk_id": ck["chunk_id"],
|
||||
"doc_id": ck["doc_id"],
|
||||
"doc_name": ck["doc_name"],
|
||||
"vector_field": ck["vector_field"],
|
||||
"vector_dim": ck["vector_dim"],
|
||||
"cos_sim": round(sim, 6),
|
||||
})
|
||||
|
||||
summary = {
|
||||
"kb_id": kb_id,
|
||||
"model": embd_id,
|
||||
"sampled": len(samples),
|
||||
"valid": len(eff_sims),
|
||||
"avg_cos_sim": round(float(np.mean(eff_sims)) if eff_sims else 0.0, 6),
|
||||
"min_cos_sim": round(float(np.min(eff_sims)) if eff_sims else 0.0, 6),
|
||||
"max_cos_sim": round(float(np.max(eff_sims)) if eff_sims else 0.0, 6),
|
||||
}
|
||||
if summary["avg_cos_sim"] > 0.99:
|
||||
return get_json_result(data={"summary": summary, "results": results})
|
||||
return get_json_result(code=settings.RetCode.NOT_EFFECTIVE, message="failed", data={"summary": summary, "results": results})
|
||||
|
||||
Reference in New Issue
Block a user