update document sdk (#2445)

### What problem does this PR solve?


### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
JobSmithManipulation
2024-09-18 11:08:19 +08:00
committed by GitHub
parent e7dd487779
commit 62cb5f1bac
5 changed files with 348 additions and 58 deletions

View File

@ -84,15 +84,28 @@ def upload(dataset_id, tenant_id):
@token_required
def docinfos(tenant_id):
req = request.args
if "id" not in req and "name" not in req:
return get_data_error_result(
retmsg="Id or name should be provided")
doc_id=None
if "id" in req:
doc_id = req["id"]
e, doc = DocumentService.get_by_id(doc_id)
return get_json_result(data=doc.to_json())
if "name" in req:
doc_name = req["name"]
doc_id = DocumentService.get_doc_id_by_doc_name(doc_name)
e, doc = DocumentService.get_by_id(doc_id)
return get_json_result(data=doc.to_json())
e, doc = DocumentService.get_by_id(doc_id)
#rename key's name
key_mapping = {
"chunk_num": "chunk_count",
"kb_id": "knowledgebase_id",
"token_num": "token_count",
}
renamed_doc = {}
for key, value in doc.to_dict().items():
new_key = key_mapping.get(key, key)
renamed_doc[new_key] = value
return get_json_result(data=renamed_doc)
@manager.route('/save', methods=['POST'])
@ -246,7 +259,7 @@ def rename():
req["doc_id"], {"name": req["name"]}):
return get_data_error_result(
retmsg="Database error (Document rename)!")
informs = File2DocumentService.get_by_document_id(req["doc_id"])
if informs:
e, file = FileService.get_by_id(informs[0].file_id)
@ -259,7 +272,7 @@ def rename():
@manager.route("/<document_id>", methods=["GET"])
@token_required
def download_document(dataset_id, document_id):
def download_document(dataset_id, document_id,tenant_id):
try:
# Check whether there is this document
exist, document = DocumentService.get_by_id(document_id)
@ -313,7 +326,21 @@ def list_docs(dataset_id, tenant_id):
try:
docs, tol = DocumentService.get_by_kb_id(
kb_id, page_number, items_per_page, orderby, desc, keywords)
return get_json_result(data={"total": tol, "docs": docs})
# rename key's name
renamed_doc_list = []
for doc in docs:
key_mapping = {
"chunk_num": "chunk_count",
"kb_id": "knowledgebase_id",
"token_num": "token_count",
}
renamed_doc = {}
for key, value in doc.items():
new_key = key_mapping.get(key, key)
renamed_doc[new_key] = value
renamed_doc_list.append(renamed_doc)
return get_json_result(data={"total": tol, "docs": renamed_doc_list})
except Exception as e:
return server_error_response(e)
@ -436,6 +463,8 @@ def list_chunk(tenant_id):
query["available_int"] = int(req["available_int"])
sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
origin_chunks=[]
for id in sres.ids:
d = {
"chunk_id": id,
@ -455,7 +484,21 @@ def list_chunk(tenant_id):
poss.append([float(d["positions"][i]), float(d["positions"][i + 1]), float(d["positions"][i + 2]),
float(d["positions"][i + 3]), float(d["positions"][i + 4])])
d["positions"] = poss
res["chunks"].append(d)
origin_chunks.append(d)
##rename keys
for chunk in origin_chunks:
key_mapping = {
"chunk_id": "id",
"content_with_weight": "content",
"doc_id": "document_id",
"important_kwd": "important_keywords",
}
renamed_chunk = {}
for key, value in chunk.items():
new_key = key_mapping.get(key, key)
renamed_chunk[new_key] = value
res["chunks"].append(renamed_chunk)
return get_json_result(data=res)
except Exception as e:
if str(e).find("not_found") > 0:
@ -471,8 +514,9 @@ def create(tenant_id):
req = request.json
md5 = hashlib.md5()
md5.update((req["content_with_weight"] + req["doc_id"]).encode("utf-8"))
chunck_id = md5.hexdigest()
d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
chunk_id = md5.hexdigest()
d = {"id": chunk_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
"content_with_weight": req["content_with_weight"]}
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
d["important_kwd"] = req.get("important_kwd", [])
@ -503,20 +547,33 @@ def create(tenant_id):
DocumentService.increment_chunk_num(
doc.id, doc.kb_id, c, 1, 0)
return get_json_result(data={"chunk": d})
# return get_json_result(data={"chunk_id": chunck_id})
d["chunk_id"] = chunk_id
#rename keys
key_mapping = {
"chunk_id": "id",
"content_with_weight": "content",
"doc_id": "document_id",
"important_kwd": "important_keywords",
"kb_id":"knowledge_base_id",
}
renamed_chunk = {}
for key, value in d.items():
new_key = key_mapping.get(key, key)
renamed_chunk[new_key] = value
return get_json_result(data={"chunk": renamed_chunk})
# return get_json_result(data={"chunk_id": chunk_id})
except Exception as e:
return server_error_response(e)
@manager.route('/chunk/rm', methods=['POST'])
@token_required
@validate_request("chunk_ids", "doc_id")
def rm_chunk():
def rm_chunk(tenant_id):
req = request.json
try:
if not ELASTICSEARCH.deleteByQuery(
Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)):
Q("ids", values=req["chunk_ids"]), search.index_name(tenant_id)):
return get_data_error_result(retmsg="Index updating failure")
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
@ -526,4 +583,126 @@ def rm_chunk():
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@manager.route('/chunk/set', methods=['POST'])
@token_required
@validate_request("doc_id", "chunk_id", "content_with_weight",
"important_kwd")
def set(tenant_id):
req = request.json
d = {
"id": req["chunk_id"],
"content_with_weight": req["content_with_weight"]}
d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
d["important_kwd"] = req["important_kwd"]
d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"]))
if "available_int" in req:
d["available_int"] = req["available_int"]
try:
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id:
return get_data_error_result(retmsg="Tenant not found!")
embd_id = DocumentService.get_embd_id(req["doc_id"])
embd_mdl = TenantLLMService.model_instance(
tenant_id, LLMType.EMBEDDING.value, embd_id)
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(retmsg="Document not found!")
if doc.parser_id == ParserType.QA:
arr = [
t for t in re.split(
r"[\n\t]",
req["content_with_weight"]) if len(t) > 1]
if len(arr) != 2:
return get_data_error_result(
retmsg="Q&A must be separated by TAB/ENTER key.")
q, a = rmPrefix(arr[0]), rmPrefix(arr[1])
d = beAdoc(d, arr[0], arr[1], not any(
[rag_tokenizer.is_chinese(t) for t in q + a]))
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
d["q_%d_vec" % len(v)] = v.tolist()
ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@manager.route('/retrieval_test', methods=['POST'])
@token_required
@validate_request("kb_id", "question")
def retrieval_test(tenant_id):
req = request.json
page = int(req.get("page", 1))
size = int(req.get("size", 30))
question = req["question"]
kb_id = req["kb_id"]
if isinstance(kb_id, str): kb_id = [kb_id]
doc_ids = req.get("doc_ids", [])
similarity_threshold = float(req.get("similarity_threshold", 0.2))
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
top = int(req.get("top_k", 1024))
try:
tenants = UserTenantService.query(user_id=tenant_id)
for kid in kb_id:
for tenant in tenants:
if KnowledgebaseService.query(
tenant_id=tenant.tenant_id, id=kid):
break
else:
return get_json_result(
data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
retcode=RetCode.OPERATING_ERROR)
e, kb = KnowledgebaseService.get_by_id(kb_id[0])
if not e:
return get_data_error_result(retmsg="Knowledgebase not found!")
embd_mdl = TenantLLMService.model_instance(
kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
rerank_mdl = None
if req.get("rerank_id"):
rerank_mdl = TenantLLMService.model_instance(
kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
if req.get("keyword", False):
chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question)
retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_id, page, size,
similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
for c in ranks["chunks"]:
if "vector" in c:
del c["vector"]
##rename keys
renamed_chunks=[]
for chunk in ranks["chunks"]:
key_mapping = {
"chunk_id": "id",
"content_with_weight": "content",
"doc_id": "document_id",
"important_kwd": "important_keywords",
}
rename_chunk={}
for key, value in chunk.items():
new_key = key_mapping.get(key, key)
rename_chunk[new_key] = value
renamed_chunks.append(rename_chunk)
ranks["chunks"] = renamed_chunks
return get_json_result(data=ranks)
except Exception as e:
if str(e).find("not_found") > 0:
return get_json_result(data=False, retmsg=f'No chunk found! Check the chunk status please!',
retcode=RetCode.DATA_ERROR)
return server_error_response(e)