Feat: check embedding model api (#10854)
### What problem does this PR solve?

Change:
Adds an API to check an embedding model against a knowledge base's stored vectors. It randomly samples `check_num` chunks from knowledge base `kb_id`, re-embeds them using `embd_id`, and compares the new embeddings with the stored vectors via cosine similarity. If `avg_cos_sim > 0.99`, it returns success (`code=0`); otherwise it returns a business failure (`code=10`).

URL:
`/v1/kb/check_embedding`
Request Body:
```json
{
  "kb_id": "<dataset_id>",
  "embd_id": "BAAI/bge-m3@SILICONFLOW",
  "check_num": 5
}
```
Success Response:
```json
{
  "code": 0,
  "message": "success",
  "data": {
    "summary": {
      "avg_cos_sim": 0.999999,
      "sampled": 5,
      "valid": 5,
      "max_cos_sim": 0.999999,
      "min_cos_sim": 0.999999,
      "model": "BAAI/bge-m3@SILICONFLOW"
    },
    "results": [ ... ]
  }
}
```
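
For a quick end-to-end check, a client call could look like the sketch below. The host, port, and `Authorization` header are assumptions about a local deployment (the endpoint is `@login_required`, so a valid session credential is needed):

```python
import requests

# Hypothetical smoke test; adjust host, port, and credentials for your deployment.
resp = requests.post(
    "http://localhost:9380/v1/kb/check_embedding",
    headers={"Authorization": "<session_token>"},  # assumed auth mechanism
    json={
        "kb_id": "<dataset_id>",
        "embd_id": "BAAI/bge-m3@SILICONFLOW",
        "check_num": 5,
    },
)
body = resp.json()
if body["code"] == 0:
    print("model matches stored vectors, avg_cos_sim =", body["data"]["summary"]["avg_cos_sim"])
else:
    print("check failed:", body["message"])
```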
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
Diff:
```diff
@@ -15,11 +15,15 @@
 #
 import json
 import logging
+import random
 
 from flask import request
 from flask_login import login_required, current_user
+import numpy as np
 
+from api.db import LLMType
 from api.db.services import duplicate_name
+from api.db.services.llm_service import LLMBundle
 from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
 from api.db.services.file2document_service import File2DocumentService
 from api.db.services.file_service import FileService
@@ -38,6 +42,7 @@ from api.constants import DATASET_NAME_LIMIT
 from rag.settings import PAGERANK_FLD
 from rag.utils.redis_conn import REDIS_CONN
 from rag.utils.storage_factory import STORAGE_IMPL
+from rag.utils.doc_store_conn import OrderByExpr
 
 
 @manager.route('/create', methods=['post']) # noqa: F821
@@ -788,3 +793,141 @@ def delete_kb_task():
         return server_error_response(f"Internal error: cannot delete task {pipeline_task_type}")
 
     return get_json_result(data=True)
+
+
+@manager.route("/check_embedding", methods=["post"]) # noqa: F821
+@login_required
+def check_embedding():
+
+    def _guess_vec_field(src: dict) -> str | None:
+        for k in src or {}:
+            if k.endswith("_vec"):
+                return k
+        return None
+
+    def _as_float_vec(v):
+        if v is None:
+            return []
+        if isinstance(v, str):
+            return [float(x) for x in v.split("\t") if x != ""]
+        if isinstance(v, (list, tuple, np.ndarray)):
+            return [float(x) for x in v]
+        return []
+
+    def _to_1d(x):
+        a = np.asarray(x, dtype=np.float32)
+        return a.reshape(-1)
+
+    def _cos_sim(a, b, eps=1e-12):
+        a = _to_1d(a)
+        b = _to_1d(b)
+        na = np.linalg.norm(a)
+        nb = np.linalg.norm(b)
+        if na < eps or nb < eps:
+            return 0.0
+        return float(np.dot(a, b) / (na * nb))
+
+    def sample_random_chunks_with_vectors(
+        docStoreConn,
+        tenant_id: str,
+        kb_id: str,
+        n: int = 5,
+        base_fields=("docnm_kwd","doc_id","content_with_weight","page_num_int","position_int","top_int"),
+    ):
+        index_nm = search.index_name(tenant_id)
+
+        res0 = docStoreConn.search(
+            selectFields=[], highlightFields=[],
+            condition={"kb_id": kb_id, "available_int": 1},
+            matchExprs=[], orderBy=OrderByExpr(),
+            offset=0, limit=1,
+            indexNames=index_nm, knowledgebaseIds=[kb_id]
+        )
+        total = docStoreConn.getTotal(res0)
+        if total <= 0:
+            return []
+
+        n = min(n, total)
+        offsets = sorted(random.sample(range(total), n))
+        out = []
+
+        for off in offsets:
+            res1 = docStoreConn.search(
+                selectFields=list(base_fields),
+                highlightFields=[],
+                condition={"kb_id": kb_id, "available_int": 1},
+                matchExprs=[], orderBy=OrderByExpr(),
+                offset=off, limit=1,
+                indexNames=index_nm, knowledgebaseIds=[kb_id]
+            )
+            ids = docStoreConn.getChunkIds(res1)
+            if not ids:
+                continue
+
+            cid = ids[0]
+            full_doc = docStoreConn.get(cid, index_nm, [kb_id]) or {}
+            vec_field = _guess_vec_field(full_doc)
+            vec = _as_float_vec(full_doc.get(vec_field))
+
+            out.append({
+                "chunk_id": cid,
+                "kb_id": kb_id,
+                "doc_id": full_doc.get("doc_id"),
+                "doc_name": full_doc.get("docnm_kwd"),
+                "vector_field": vec_field,
+                "vector_dim": len(vec),
+                "vector": vec,
+                "page_num_int": full_doc.get("page_num_int"),
+                "position_int": full_doc.get("position_int"),
+                "top_int": full_doc.get("top_int"),
+                "content_with_weight": full_doc.get("content_with_weight") or "",
+            })
+        return out
+
+    req = request.json
+    kb_id = req.get("kb_id", "")
+    embd_id = req.get("embd_id", "")
+    n = int(req.get("check_num", 5))
+    _, kb = KnowledgebaseService.get_by_id(kb_id)
+    tenant_id = kb.tenant_id
+
+    emb_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
+    samples = sample_random_chunks_with_vectors(settings.docStoreConn, tenant_id=tenant_id, kb_id=kb_id, n=n)
+
+    results, eff_sims = [], []
+    for ck in samples:
+        txt = (ck.get("content_with_weight") or "").strip()
+        if not txt:
+            results.append({"chunk_id": ck["chunk_id"], "reason": "no_text"})
+            continue
+
+        if not ck.get("vector"):
+            results.append({"chunk_id": ck["chunk_id"], "reason": "no_stored_vector"})
+            continue
+
+        try:
+            qv, _ = emb_mdl.encode_queries(txt)
+            sim = _cos_sim(qv, ck["vector"])
+        except Exception:
+            return get_error_data_result(message="embedding failure")
+
+        eff_sims.append(sim)
+        results.append({
+            "chunk_id": ck["chunk_id"],
+            "doc_id": ck["doc_id"],
+            "doc_name": ck["doc_name"],
+            "vector_field": ck["vector_field"],
+            "vector_dim": ck["vector_dim"],
+            "cos_sim": round(sim, 6),
+        })
+
+    summary = {
+        "kb_id": kb_id,
+        "model": embd_id,
+        "sampled": len(samples),
+        "valid": len(eff_sims),
+        "avg_cos_sim": round(float(np.mean(eff_sims)) if eff_sims else 0.0, 6),
+        "min_cos_sim": round(float(np.min(eff_sims)) if eff_sims else 0.0, 6),
+        "max_cos_sim": round(float(np.max(eff_sims)) if eff_sims else 0.0, 6),
+    }
+    if summary["avg_cos_sim"] > 0.99:
+        return get_json_result(data={"summary": summary, "results": results})
+    return get_json_result(code=settings.RetCode.NOT_EFFECTIVE, message="failed", data={"summary": summary, "results": results})
```
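
A note on the acceptance criterion: re-embedding the same text with the same model should reproduce the stored vector almost exactly, so `avg_cos_sim` stays above 0.99, while a mismatched model drifts far below it. A minimal illustration with made-up vectors (not real embeddings):

```python
import numpy as np

def cos_sim(a, b, eps=1e-12):
    # Same computation as the _cos_sim helper in the diff above.
    a, b = np.asarray(a, np.float32).ravel(), np.asarray(b, np.float32).ravel()
    na, nb = np.linalg.norm(a), np.linalg.norm(b)
    return 0.0 if na < eps or nb < eps else float(np.dot(a, b) / (na * nb))

stored = np.array([0.12, -0.80, 0.58])       # vector indexed at ingest time
same_model = stored + 1e-6                   # tiny numeric jitter on re-embed
other_model = np.array([0.70, 0.10, -0.50])  # hypothetical mismatched model

print(cos_sim(stored, same_model))   # ~1.0, passes the > 0.99 gate
print(cos_sim(stored, other_model))  # far lower, yields code=10
```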
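
One implementation detail worth calling out: depending on the doc-store backend, the stored vector may come back either as a list of floats or as a single tab-joined string, which is why `_as_float_vec` accepts both. A quick check of that normalization (standalone copy of the helper, values made up):

```python
import numpy as np

def as_float_vec(v):
    # Standalone copy of the PR's _as_float_vec: accept None, "\t"-joined
    # strings, or list/tuple/ndarray, and always return a list of floats.
    if v is None:
        return []
    if isinstance(v, str):
        return [float(x) for x in v.split("\t") if x != ""]
    if isinstance(v, (list, tuple, np.ndarray)):
        return [float(x) for x in v]
    return []

assert as_float_vec("0.1\t0.2\t0.3") == [0.1, 0.2, 0.3]
assert as_float_vec([0.1, 0.2, 0.3]) == [0.1, 0.2, 0.3]
assert as_float_vec(None) == []
```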