mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
update document sdk (#2445)
### What problem does this PR solve? ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
committed by
GitHub
parent
e7dd487779
commit
62cb5f1bac
@ -84,15 +84,28 @@ def upload(dataset_id, tenant_id):
|
||||
@token_required
|
||||
def docinfos(tenant_id):
|
||||
req = request.args
|
||||
if "id" not in req and "name" not in req:
|
||||
return get_data_error_result(
|
||||
retmsg="Id or name should be provided")
|
||||
doc_id=None
|
||||
if "id" in req:
|
||||
doc_id = req["id"]
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
return get_json_result(data=doc.to_json())
|
||||
if "name" in req:
|
||||
doc_name = req["name"]
|
||||
doc_id = DocumentService.get_doc_id_by_doc_name(doc_name)
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
return get_json_result(data=doc.to_json())
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
#rename key's name
|
||||
key_mapping = {
|
||||
"chunk_num": "chunk_count",
|
||||
"kb_id": "knowledgebase_id",
|
||||
"token_num": "token_count",
|
||||
}
|
||||
renamed_doc = {}
|
||||
for key, value in doc.to_dict().items():
|
||||
new_key = key_mapping.get(key, key)
|
||||
renamed_doc[new_key] = value
|
||||
|
||||
return get_json_result(data=renamed_doc)
|
||||
|
||||
|
||||
@manager.route('/save', methods=['POST'])
|
||||
@ -246,7 +259,7 @@ def rename():
|
||||
req["doc_id"], {"name": req["name"]}):
|
||||
return get_data_error_result(
|
||||
retmsg="Database error (Document rename)!")
|
||||
|
||||
|
||||
informs = File2DocumentService.get_by_document_id(req["doc_id"])
|
||||
if informs:
|
||||
e, file = FileService.get_by_id(informs[0].file_id)
|
||||
@ -259,7 +272,7 @@ def rename():
|
||||
|
||||
@manager.route("/<document_id>", methods=["GET"])
|
||||
@token_required
|
||||
def download_document(dataset_id, document_id):
|
||||
def download_document(dataset_id, document_id,tenant_id):
|
||||
try:
|
||||
# Check whether there is this document
|
||||
exist, document = DocumentService.get_by_id(document_id)
|
||||
@ -313,7 +326,21 @@ def list_docs(dataset_id, tenant_id):
|
||||
try:
|
||||
docs, tol = DocumentService.get_by_kb_id(
|
||||
kb_id, page_number, items_per_page, orderby, desc, keywords)
|
||||
return get_json_result(data={"total": tol, "docs": docs})
|
||||
|
||||
# rename key's name
|
||||
renamed_doc_list = []
|
||||
for doc in docs:
|
||||
key_mapping = {
|
||||
"chunk_num": "chunk_count",
|
||||
"kb_id": "knowledgebase_id",
|
||||
"token_num": "token_count",
|
||||
}
|
||||
renamed_doc = {}
|
||||
for key, value in doc.items():
|
||||
new_key = key_mapping.get(key, key)
|
||||
renamed_doc[new_key] = value
|
||||
renamed_doc_list.append(renamed_doc)
|
||||
return get_json_result(data={"total": tol, "docs": renamed_doc_list})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -436,6 +463,8 @@ def list_chunk(tenant_id):
|
||||
query["available_int"] = int(req["available_int"])
|
||||
sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
|
||||
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
|
||||
|
||||
origin_chunks=[]
|
||||
for id in sres.ids:
|
||||
d = {
|
||||
"chunk_id": id,
|
||||
@ -455,7 +484,21 @@ def list_chunk(tenant_id):
|
||||
poss.append([float(d["positions"][i]), float(d["positions"][i + 1]), float(d["positions"][i + 2]),
|
||||
float(d["positions"][i + 3]), float(d["positions"][i + 4])])
|
||||
d["positions"] = poss
|
||||
res["chunks"].append(d)
|
||||
|
||||
origin_chunks.append(d)
|
||||
##rename keys
|
||||
for chunk in origin_chunks:
|
||||
key_mapping = {
|
||||
"chunk_id": "id",
|
||||
"content_with_weight": "content",
|
||||
"doc_id": "document_id",
|
||||
"important_kwd": "important_keywords",
|
||||
}
|
||||
renamed_chunk = {}
|
||||
for key, value in chunk.items():
|
||||
new_key = key_mapping.get(key, key)
|
||||
renamed_chunk[new_key] = value
|
||||
res["chunks"].append(renamed_chunk)
|
||||
return get_json_result(data=res)
|
||||
except Exception as e:
|
||||
if str(e).find("not_found") > 0:
|
||||
@ -471,8 +514,9 @@ def create(tenant_id):
|
||||
req = request.json
|
||||
md5 = hashlib.md5()
|
||||
md5.update((req["content_with_weight"] + req["doc_id"]).encode("utf-8"))
|
||||
chunck_id = md5.hexdigest()
|
||||
d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
|
||||
|
||||
chunk_id = md5.hexdigest()
|
||||
d = {"id": chunk_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
|
||||
"content_with_weight": req["content_with_weight"]}
|
||||
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
||||
d["important_kwd"] = req.get("important_kwd", [])
|
||||
@ -503,20 +547,33 @@ def create(tenant_id):
|
||||
|
||||
DocumentService.increment_chunk_num(
|
||||
doc.id, doc.kb_id, c, 1, 0)
|
||||
return get_json_result(data={"chunk": d})
|
||||
# return get_json_result(data={"chunk_id": chunck_id})
|
||||
d["chunk_id"] = chunk_id
|
||||
#rename keys
|
||||
key_mapping = {
|
||||
"chunk_id": "id",
|
||||
"content_with_weight": "content",
|
||||
"doc_id": "document_id",
|
||||
"important_kwd": "important_keywords",
|
||||
"kb_id":"knowledge_base_id",
|
||||
}
|
||||
renamed_chunk = {}
|
||||
for key, value in d.items():
|
||||
new_key = key_mapping.get(key, key)
|
||||
renamed_chunk[new_key] = value
|
||||
|
||||
return get_json_result(data={"chunk": renamed_chunk})
|
||||
# return get_json_result(data={"chunk_id": chunk_id})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/chunk/rm', methods=['POST'])
|
||||
@token_required
|
||||
@validate_request("chunk_ids", "doc_id")
|
||||
def rm_chunk():
|
||||
def rm_chunk(tenant_id):
|
||||
req = request.json
|
||||
try:
|
||||
if not ELASTICSEARCH.deleteByQuery(
|
||||
Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)):
|
||||
Q("ids", values=req["chunk_ids"]), search.index_name(tenant_id)):
|
||||
return get_data_error_result(retmsg="Index updating failure")
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
@ -526,4 +583,126 @@ def rm_chunk():
|
||||
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@manager.route('/chunk/set', methods=['POST'])
|
||||
@token_required
|
||||
@validate_request("doc_id", "chunk_id", "content_with_weight",
|
||||
"important_kwd")
|
||||
def set(tenant_id):
|
||||
req = request.json
|
||||
d = {
|
||||
"id": req["chunk_id"],
|
||||
"content_with_weight": req["content_with_weight"]}
|
||||
d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
|
||||
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
||||
d["important_kwd"] = req["important_kwd"]
|
||||
d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"]))
|
||||
if "available_int" in req:
|
||||
d["available_int"] = req["available_int"]
|
||||
|
||||
try:
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(retmsg="Tenant not found!")
|
||||
|
||||
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
||||
embd_mdl = TenantLLMService.model_instance(
|
||||
tenant_id, LLMType.EMBEDDING.value, embd_id)
|
||||
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Document not found!")
|
||||
|
||||
if doc.parser_id == ParserType.QA:
|
||||
arr = [
|
||||
t for t in re.split(
|
||||
r"[\n\t]",
|
||||
req["content_with_weight"]) if len(t) > 1]
|
||||
if len(arr) != 2:
|
||||
return get_data_error_result(
|
||||
retmsg="Q&A must be separated by TAB/ENTER key.")
|
||||
q, a = rmPrefix(arr[0]), rmPrefix(arr[1])
|
||||
d = beAdoc(d, arr[0], arr[1], not any(
|
||||
[rag_tokenizer.is_chinese(t) for t in q + a]))
|
||||
|
||||
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
|
||||
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
||||
d["q_%d_vec" % len(v)] = v.tolist()
|
||||
ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@manager.route('/retrieval_test', methods=['POST'])
|
||||
@token_required
|
||||
@validate_request("kb_id", "question")
|
||||
def retrieval_test(tenant_id):
|
||||
req = request.json
|
||||
page = int(req.get("page", 1))
|
||||
size = int(req.get("size", 30))
|
||||
question = req["question"]
|
||||
kb_id = req["kb_id"]
|
||||
if isinstance(kb_id, str): kb_id = [kb_id]
|
||||
doc_ids = req.get("doc_ids", [])
|
||||
similarity_threshold = float(req.get("similarity_threshold", 0.2))
|
||||
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
|
||||
top = int(req.get("top_k", 1024))
|
||||
|
||||
try:
|
||||
tenants = UserTenantService.query(user_id=tenant_id)
|
||||
for kid in kb_id:
|
||||
for tenant in tenants:
|
||||
if KnowledgebaseService.query(
|
||||
tenant_id=tenant.tenant_id, id=kid):
|
||||
break
|
||||
else:
|
||||
return get_json_result(
|
||||
data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
|
||||
retcode=RetCode.OPERATING_ERROR)
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id[0])
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Knowledgebase not found!")
|
||||
|
||||
embd_mdl = TenantLLMService.model_instance(
|
||||
kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
||||
|
||||
rerank_mdl = None
|
||||
if req.get("rerank_id"):
|
||||
rerank_mdl = TenantLLMService.model_instance(
|
||||
kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
|
||||
|
||||
if req.get("keyword", False):
|
||||
chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
|
||||
question += keyword_extraction(chat_mdl, question)
|
||||
|
||||
retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
|
||||
ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_id, page, size,
|
||||
similarity_threshold, vector_similarity_weight, top,
|
||||
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
|
||||
for c in ranks["chunks"]:
|
||||
if "vector" in c:
|
||||
del c["vector"]
|
||||
|
||||
##rename keys
|
||||
renamed_chunks=[]
|
||||
for chunk in ranks["chunks"]:
|
||||
key_mapping = {
|
||||
"chunk_id": "id",
|
||||
"content_with_weight": "content",
|
||||
"doc_id": "document_id",
|
||||
"important_kwd": "important_keywords",
|
||||
}
|
||||
rename_chunk={}
|
||||
for key, value in chunk.items():
|
||||
new_key = key_mapping.get(key, key)
|
||||
rename_chunk[new_key] = value
|
||||
renamed_chunks.append(rename_chunk)
|
||||
ranks["chunks"] = renamed_chunks
|
||||
return get_json_result(data=ranks)
|
||||
except Exception as e:
|
||||
if str(e).find("not_found") > 0:
|
||||
return get_json_result(data=False, retmsg=f'No chunk found! Check the chunk status please!',
|
||||
retcode=RetCode.DATA_ERROR)
|
||||
return server_error_response(e)
|
||||
Reference in New Issue
Block a user