Refactor Document API (#2833)

### What problem does this PR solve?

Refactor Document API

### Type of change


- [x] Refactoring

Co-authored-by: liuhua <10215101452@stu.ecnu.edu.cn>
liuhua authored 2024-10-14 20:03:33 +08:00, committed by GitHub
parent df223eddf3
commit 6329427ad5
11 changed files with 393 additions and 418 deletions

View File

@@ -243,7 +243,7 @@ def list(tenant_id):
     page_number = int(request.args.get("page", 1))
     items_per_page = int(request.args.get("page_size", 1024))
     orderby = request.args.get("orderby", "create_time")
-    if request.args.get("desc") == "False":
+    if request.args.get("desc") == "False" or request.args.get("desc") == "false":
        desc = False
    else:
        desc = True
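The case-insensitive `desc` check above is copy-pasted into several handlers in this PR (the identical hunks appear below). A small helper along these lines (hypothetical, not part of this commit) would normalize the flag once:

```python
def parse_bool_flag(value, default=True):
    """Parse a query-string boolean; accepts any casing of 'false'/'true'."""
    if value is None:
        return default
    return value.strip().lower() not in ("false", "0", "no")

# e.g. inside a handler:
# desc = parse_bool_flag(request.args.get("desc"))
```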

View File

@@ -107,11 +107,6 @@ def update(tenant_id,dataset_id):
         if req["tenant_id"] != tenant_id:
             return get_error_data_result(
                 retmsg="Can't change tenant_id.")
-    if "embedding_model" in req:
-        if req["embedding_model"] != t.embd_id:
-            return get_error_data_result(
-                retmsg="Can't change embedding_model.")
-        req.pop("embedding_model")
     e, kb = KnowledgebaseService.get_by_id(dataset_id)
     if "chunk_count" in req:
         if req["chunk_count"] != kb.chunk_num:
@@ -128,6 +123,11 @@ def update(tenant_id,dataset_id):
             return get_error_data_result(
                 retmsg="If chunk count is not 0, parse method is not changable.")
         req['parser_id'] = req.pop('parse_method')
+    if "embedding_model" in req:
+        if kb.chunk_num != 0 and req['parse_method'] != kb.parser_id:
+            return get_error_data_result(
+                retmsg="If chunk count is not 0, parse method is not changable.")
+        req['embd_id'] = req.pop('embedding_model')
     if "name" in req:
         req["name"] = req["name"].strip()
         if req["name"].lower() != kb.name.lower() \
@@ -150,7 +150,7 @@ def list(tenant_id):
     page_number = int(request.args.get("page", 1))
     items_per_page = int(request.args.get("page_size", 1024))
     orderby = request.args.get("orderby", "create_time")
-    if request.args.get("desc") == "False":
+    if request.args.get("desc") == "False" or request.args.get("desc") == "false":
        desc = False
    else:
        desc = True

View File

@@ -8,6 +8,7 @@ from botocore.docs.method import document_model_driven_method
 from flask import request
 from flask_login import login_required, current_user
 from elasticsearch_dsl import Q
+from pygments import highlight
 from sphinx.addnodes import document
 from rag.app.qa import rmPrefix, beAdoc
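Review note: `pygments`, `sphinx.addnodes`, and `botocore.docs.method` are documentation/tooling packages and look like accidental IDE auto-imports in an API handler module; the `highlight` name imported here is also shadowed by the local `highlight` variable that `retrieval_test` introduces further down.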
@@ -158,7 +159,7 @@ def download(tenant_id, dataset_id, document_id):
         return get_error_data_result(retmsg=f'You do not own the dataset {dataset_id}.')
     doc = DocumentService.query(kb_id=dataset_id, id=document_id)
     if not doc:
-        return get_error_data_result(retmsg=f'The dataset not own the document {doc.id}.')
+        return get_error_data_result(retmsg=f'The dataset not own the document {document_id}.')
     # The process of downloading
     doc_id, doc_location = File2DocumentService.get_storage_address(doc_id=document_id)  # minio address
     file_stream = STORAGE_IMPL.get(doc_id, doc_location)
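The one-word change above is a real fix, not cosmetics: `DocumentService.query()` returns a list, so in the `if not doc:` branch `doc` is empty and the old f-string would raise while building the error message instead of returning it. The same list-shaped result is why `doc = doc[0]` lines appear in the hunks below. A minimal standalone repro:

```python
doc = []                    # what the query returns for a missing document
if not doc:
    try:
        msg = f"The dataset not own the document {doc.id}"
    except AttributeError as exc:
        print(exc)          # 'list' object has no attribute 'id'
```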
@@ -294,7 +295,7 @@ def stop_parsing(tenant_id,dataset_id):
     return get_result()


-@manager.route('/dataset/{dataset_id}/document/{document_id}/chunk', methods=['GET'])
+@manager.route('/dataset/<dataset_id>/document/<document_id>/chunk', methods=['GET'])
 @token_required
 def list_chunk(tenant_id,dataset_id,document_id):
     if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
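This and the matching hunks below are the core of the route refactor: Flask only binds URL variables written with angle brackets, so the old `{dataset_id}`-style rules matched the literal text `{dataset_id}` and the view never received real path parameters. A minimal standalone illustration:

```python
from flask import Flask

app = Flask(__name__)

@app.route('/broken/{dataset_id}')   # static rule: matches only the literal "{dataset_id}"
def broken():
    return "only reachable via /broken/{dataset_id}"

@app.route('/fixed/<dataset_id>')    # converter: binds the path segment to dataset_id
def fixed(dataset_id):
    return f"dataset {dataset_id}"
```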
@@ -361,7 +362,7 @@ def list_chunk(tenant_id,dataset_id,document_id):
         return server_error_response(e)


-@manager.route('/dataset/{dataset_id}/document/{document_id}/chunk', methods=['POST'])
+@manager.route('/dataset/<dataset_id>/document/<document_id>/chunk', methods=['POST'])
 @token_required
 def create(tenant_id,dataset_id,document_id):
     if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
@@ -369,6 +370,7 @@ def create(tenant_id,dataset_id,document_id):
     doc = DocumentService.query(id=document_id, kb_id=dataset_id)
     if not doc:
         return get_error_data_result(retmsg=f"You don't own the document {document_id}.")
+    doc = doc[0]
     req = request.json
     if not req.get("content"):
         return get_error_data_result(retmsg="`content` is required")
@@ -418,7 +420,7 @@ def create(tenant_id,dataset_id,document_id):
 #     return get_result(data={"chunk_id": chunk_id})


-@manager.route('dataset/{dataset_id}/document/{document_id}/chunk', methods=['DELETE'])
+@manager.route('dataset/<dataset_id>/document/<document_id>/chunk', methods=['DELETE'])
 @token_required
 def rm_chunk(tenant_id,dataset_id,document_id):
     if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
@@ -426,9 +428,16 @@ def rm_chunk(tenant_id,dataset_id,document_id):
     doc = DocumentService.query(id=document_id, kb_id=dataset_id)
     if not doc:
         return get_error_data_result(retmsg=f"You don't own the document {document_id}.")
+    doc = doc[0]
     req = request.json
     if not req.get("chunk_ids"):
         return get_error_data_result("`chunk_ids` is required")
+    for chunk_id in req.get("chunk_ids"):
+        res = ELASTICSEARCH.get(
+            chunk_id, search.index_name(
+                tenant_id))
+        if not res.get("found"):
+            return server_error_response(f"Chunk {chunk_id} not found")
     if not ELASTICSEARCH.deleteByQuery(
             Q("ids", values=req["chunk_ids"]), search.index_name(tenant_id)):
         return get_error_data_result(retmsg="Index updating failure")
@@ -439,25 +448,26 @@ def rm_chunk(tenant_id,dataset_id,document_id):
-@manager.route('/dataset/{dataset_id}/document/{document_id}/chunk/{chunk_id}', methods=['PUT'])
+@manager.route('/dataset/<dataset_id>/document/<document_id>/chunk/<chunk_id>', methods=['PUT'])
 @token_required
 def set(tenant_id,dataset_id,document_id,chunk_id):
+    res = ELASTICSEARCH.get(
+        chunk_id, search.index_name(
+            tenant_id))
+    if not res.get("found"):
+        return get_error_data_result(f"Chunk {chunk_id} not found")
     if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
         return get_error_data_result(retmsg=f"You don't own the dataset {dataset_id}.")
     doc = DocumentService.query(id=document_id, kb_id=dataset_id)
     if not doc:
         return get_error_data_result(retmsg=f"You don't own the document {document_id}.")
     req = request.json
-    if not req.get("content"):
-        return get_error_data_result("`content` is required")
-    if not req.get("important_keywords"):
-        return get_error_data_result("`important_keywords` is required")
     d = {
         "id": chunk_id,
-        "content_with_weight": req["content"]}
+        "content_with_weight": req.get("content",res.get["content_with_weight"])}
     d["content_ltks"] = rag_tokenizer.tokenize(req["content"])
     d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
-    d["important_kwd"] = req["important_keywords"]
+    d["important_kwd"] = req.get("important_keywords",[])
     d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_keywords"]))
     if "available" in req:
         d["available_int"] = req["available"]
@@ -488,23 +498,27 @@ def set(tenant_id,dataset_id,document_id,chunk_id):
 @token_required
 def retrieval_test(tenant_id):
     req = request.args
-    if not req.get("datasets"):
+    req_json = request.json
+    if not req_json.get("datasets"):
         return get_error_data_result("`datasets` is required.")
-    for id in req.get("datasets"):
+    for id in req_json.get("datasets"):
         if not KnowledgebaseService.query(id=id,tenant_id=tenant_id):
             return get_error_data_result(f"You don't own the dataset {id}.")
-    if not req.get("question"):
+    if "question" not in req_json:
         return get_error_data_result("`question` is required.")
     page = int(req.get("offset", 1))
     size = int(req.get("limit", 30))
-    question = req["question"]
-    kb_id = req["datasets"]
+    question = req_json["question"]
+    kb_id = req_json["datasets"]
     if isinstance(kb_id, str): kb_id = [kb_id]
-    doc_ids = req.get("documents", [])
-    similarity_threshold = float(req.get("similarity_threshold", 0.2))
+    doc_ids = req_json.get("documents", [])
+    similarity_threshold = float(req.get("similarity_threshold", 0.0))
     vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
     top = int(req.get("top_k", 1024))
+    if req.get("highlight")=="False" or req.get("highlight")=="false":
+        highlight = False
+    else:
+        highlight = True
     try:
         e, kb = KnowledgebaseService.get_by_id(kb_id[0])
         if not e:
@@ -524,7 +538,7 @@ def retrieval_test(tenant_id):
         retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
         ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_id, page, size,
                                similarity_threshold, vector_similarity_weight, top,
-                               doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
+                               doc_ids, rerank_mdl=rerank_mdl, highlight=highlight)
         for c in ranks["chunks"]:
             if "vector" in c:
                 del c["vector"]
@@ -543,11 +557,11 @@ def retrieval_test(tenant_id):
             for key, value in chunk.items():
                 new_key = key_mapping.get(key, key)
                 rename_chunk[new_key] = value
-                renamed_chunks.append(rename_chunk)
+            renamed_chunks.append(rename_chunk)
         ranks["chunks"] = renamed_chunks
         return get_result(data=ranks)
     except Exception as e:
         if str(e).find("not_found") > 0:
-            return get_result(retmsg=f'No chunk found! Check the chunk status please!',
+            return get_result(retmsg=f'No chunk found! Check the chunk status please!',
                               retcode=RetCode.DATA_ERROR)
         return server_error_response(e)
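After this refactor `retrieval_test` reads `datasets`, `question`, and `documents` from the JSON body (`req_json`), while `offset`, `limit`, `similarity_threshold`, `vector_similarity_weight`, `top_k`, and `highlight` still come from the query string (`req = request.args`); note also that the default for `similarity_threshold` silently changed from 0.2 to 0.0. A hedged client sketch (endpoint path, port, and auth header are assumptions, not taken from this diff):

```python
import requests

resp = requests.post(
    "http://localhost:9380/api/v1/retrieval",      # hypothetical URL
    params={                                       # still query-string params
        "offset": 1, "limit": 30,
        "similarity_threshold": 0.2,
        "vector_similarity_weight": 0.3,
        "highlight": "false",
    },
    json={                                         # now body params
        "datasets": ["<dataset_id>"],
        "question": "What is RAGFlow?",
    },
    headers={"Authorization": "Bearer <api_key>"}, # hypothetical auth scheme
)
print(resp.json())
```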

View File

@@ -163,7 +163,7 @@ def list(chat_id,tenant_id):
     page_number = int(request.args.get("page", 1))
     items_per_page = int(request.args.get("page_size", 1024))
     orderby = request.args.get("orderby", "create_time")
-    if request.args.get("desc") == "False":
+    if request.args.get("desc") == "False" or request.args.get("desc") == "false":
        desc = False
    else:
        desc = True