Refactor Chunk API (#2855)

### What problem does this PR solve?

Refactor Chunk API
#2846
### Type of change


- [x] Refactoring

---------

Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn>
Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
liuhua
2024-10-16 18:41:24 +08:00
committed by GitHub
parent b9fa00f341
commit dab92ac1e8
11 changed files with 760 additions and 791 deletions

View File

@ -119,13 +119,11 @@ def update_doc(tenant_id, dataset_id, document_id):
if informs:
e, file = FileService.get_by_id(informs[0].file_id)
FileService.update_by_id(file.id, {"name": req["name"]})
if "parser_config" in req:
DocumentService.update_parser_config(doc.id, req["parser_config"])
if "parser_method" in req:
if doc.parser_id.lower() == req["parser_method"].lower():
if "parser_config" in req:
if req["parser_config"] == doc.parser_config:
return get_result(retcode=RetCode.SUCCESS)
else:
return get_result(retcode=RetCode.SUCCESS)
return get_result()
if doc.type == FileType.VISUAL or re.search(
r"\.(ppt|pptx|pages)$", doc.name):
@ -146,8 +144,6 @@ def update_doc(tenant_id, dataset_id, document_id):
return get_error_data_result(retmsg="Tenant not found!")
ELASTICSEARCH.deleteByQuery(
Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
if "parser_config" in req:
DocumentService.update_parser_config(doc.id, req["parser_config"])
return get_result()
@ -258,6 +254,8 @@ def parse(tenant_id,dataset_id):
if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
return get_error_data_result(retmsg=f"You don't own the dataset {dataset_id}.")
req = request.json
if not req.get("document_ids"):
return get_error_data_result("`document_ids` is required")
for id in req["document_ids"]:
if not DocumentService.query(id=id,kb_id=dataset_id):
return get_error_data_result(retmsg=f"You don't own the document {id}.")
@ -283,9 +281,14 @@ def stop_parsing(tenant_id,dataset_id):
if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
return get_error_data_result(retmsg=f"You don't own the dataset {dataset_id}.")
req = request.json
if not req.get("document_ids"):
return get_error_data_result("`document_ids` is required")
for id in req["document_ids"]:
if not DocumentService.query(id=id,kb_id=dataset_id):
doc = DocumentService.query(id=id, kb_id=dataset_id)
if not doc:
return get_error_data_result(retmsg=f"You don't own the document {id}.")
if doc[0].progress == 100.0 or doc[0].progress == 0.0:
return get_error_data_result("Can't stop parsing document with progress at 0 or 100")
info = {"run": "2", "progress": 0}
DocumentService.update_by_id(id, info)
# if str(req["run"]) == TaskStatus.CANCEL.value:
@ -297,7 +300,7 @@ def stop_parsing(tenant_id,dataset_id):
@manager.route('/dataset/<dataset_id>/document/<document_id>/chunk', methods=['GET'])
@token_required
def list_chunk(tenant_id,dataset_id,document_id):
def list_chunks(tenant_id,dataset_id,document_id):
if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
return get_error_data_result(retmsg=f"You don't own the dataset {dataset_id}.")
doc=DocumentService.query(id=document_id, kb_id=dataset_id)
@ -309,57 +312,58 @@ def list_chunk(tenant_id,dataset_id,document_id):
page = int(req.get("offset", 1))
size = int(req.get("limit", 30))
question = req.get("keywords", "")
try:
query = {
"doc_ids": [doc_id], "page": page, "size": size, "question": question, "sort": True
query = {
"doc_ids": [doc_id], "page": page, "size": size, "question": question, "sort": True
}
sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
origin_chunks = []
sign = 0
for id in sres.ids:
d = {
"chunk_id": id,
"content_with_weight": rmSpace(sres.highlight[id]) if question and id in sres.highlight else sres.field[
id].get(
"content_with_weight", ""),
"doc_id": sres.field[id]["doc_id"],
"docnm_kwd": sres.field[id]["docnm_kwd"],
"important_kwd": sres.field[id].get("important_kwd", []),
"img_id": sres.field[id].get("img_id", ""),
"available_int": sres.field[id].get("available_int", 1),
"positions": sres.field[id].get("position_int", "").split("\t")
}
if "available_int" in req:
query["available_int"] = int(req["available_int"])
sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
if len(d["positions"]) % 5 == 0:
poss = []
for i in range(0, len(d["positions"]), 5):
poss.append([float(d["positions"][i]), float(d["positions"][i + 1]), float(d["positions"][i + 2]),
float(d["positions"][i + 3]), float(d["positions"][i + 4])])
d["positions"] = poss
origin_chunks = []
for id in sres.ids:
d = {
"chunk_id": id,
"content_with_weight": rmSpace(sres.highlight[id]) if question and id in sres.highlight else sres.field[
id].get(
"content_with_weight", ""),
"doc_id": sres.field[id]["doc_id"],
"docnm_kwd": sres.field[id]["docnm_kwd"],
"important_kwd": sres.field[id].get("important_kwd", []),
"img_id": sres.field[id].get("img_id", ""),
"available_int": sres.field[id].get("available_int", 1),
"positions": sres.field[id].get("position_int", "").split("\t")
}
if len(d["positions"]) % 5 == 0:
poss = []
for i in range(0, len(d["positions"]), 5):
poss.append([float(d["positions"][i]), float(d["positions"][i + 1]), float(d["positions"][i + 2]),
float(d["positions"][i + 3]), float(d["positions"][i + 4])])
d["positions"] = poss
origin_chunks.append(d)
if req.get("id"):
if req.get("id") == id:
origin_chunks.clear()
origin_chunks.append(d)
sign = 1
break
if req.get("id"):
if sign == 0:
return get_error_data_result(f"Can't find this chunk {req.get('id')}")
for chunk in origin_chunks:
key_mapping = {
"chunk_id": "id",
"content_with_weight": "content",
"doc_id": "document_id",
"important_kwd": "important_keywords",
"img_id": "image_id",
}
renamed_chunk = {}
for key, value in chunk.items():
new_key = key_mapping.get(key, key)
renamed_chunk[new_key] = value
res["chunks"].append(renamed_chunk)
return get_result(data=res)
origin_chunks.append(d)
##rename keys
for chunk in origin_chunks:
key_mapping = {
"chunk_id": "id",
"content_with_weight": "content",
"doc_id": "document_id",
"important_kwd": "important_keywords",
"img_id": "image_id",
}
renamed_chunk = {}
for key, value in chunk.items():
new_key = key_mapping.get(key, key)
renamed_chunk[new_key] = value
res["chunks"].append(renamed_chunk)
return get_result(data=res)
except Exception as e:
if str(e).find("not_found") > 0:
return get_result(retmsg=f'No chunk found!',
retcode=RetCode.DATA_ERROR)
return server_error_response(e)
@manager.route('/dataset/<dataset_id>/document/<document_id>/chunk', methods=['POST'])
@ -374,6 +378,9 @@ def create(tenant_id,dataset_id,document_id):
req = request.json
if not req.get("content"):
return get_error_data_result(retmsg="`content` is required")
if "important_keywords" in req:
if type(req["important_keywords"]) != list:
return get_error_data_result("`important_keywords` is required to be a list")
md5 = hashlib.md5()
md5.update((req["content"] + document_id).encode("utf-8"))
@ -381,8 +388,8 @@ def create(tenant_id,dataset_id,document_id):
d = {"id": chunk_id, "content_ltks": rag_tokenizer.tokenize(req["content"]),
"content_with_weight": req["content"]}
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
d["important_kwd"] = req.get("important_kwd", [])
d["important_tks"] = rag_tokenizer.tokenize(" ".join(req.get("important_kwd", [])))
d["important_kwd"] = req.get("important_keywords", [])
d["important_tks"] = rag_tokenizer.tokenize(" ".join(req.get("important_keywords", [])))
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
d["kb_id"] = [doc.kb_id]
@ -432,12 +439,12 @@ def rm_chunk(tenant_id,dataset_id,document_id):
req = request.json
if not req.get("chunk_ids"):
return get_error_data_result("`chunk_ids` is required")
query = {
"doc_ids": [doc.id], "page": 1, "size": 1024, "question": "", "sort": True}
sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
for chunk_id in req.get("chunk_ids"):
res = ELASTICSEARCH.get(
chunk_id, search.index_name(
tenant_id))
if not res.get("found"):
return server_error_response(f"Chunk {chunk_id} not found")
if chunk_id not in sres.ids:
return get_error_data_result(f"Chunk {chunk_id} not found")
if not ELASTICSEARCH.deleteByQuery(
Q("ids", values=req["chunk_ids"]), search.index_name(tenant_id)):
return get_error_data_result(retmsg="Index updating failure")
@ -451,24 +458,36 @@ def rm_chunk(tenant_id,dataset_id,document_id):
@manager.route('/dataset/<dataset_id>/document/<document_id>/chunk/<chunk_id>', methods=['PUT'])
@token_required
def set(tenant_id,dataset_id,document_id,chunk_id):
res = ELASTICSEARCH.get(
try:
res = ELASTICSEARCH.get(
chunk_id, search.index_name(
tenant_id))
if not res.get("found"):
return get_error_data_result(f"Chunk {chunk_id} not found")
except Exception as e:
return get_error_data_result(f"Can't find this chunk {chunk_id}")
if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
return get_error_data_result(retmsg=f"You don't own the dataset {dataset_id}.")
doc = DocumentService.query(id=document_id, kb_id=dataset_id)
if not doc:
return get_error_data_result(retmsg=f"You don't own the document {document_id}.")
doc = doc[0]
query = {
"doc_ids": [document_id], "page": 1, "size": 1024, "question": "", "sort": True
}
sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
if chunk_id not in sres.ids:
return get_error_data_result(f"You don't own the chunk {chunk_id}")
req = request.json
content=res["_source"].get("content_with_weight")
d = {
"id": chunk_id,
"content_with_weight": req.get("content",res.get["content_with_weight"])}
d["content_ltks"] = rag_tokenizer.tokenize(req["content"])
"content_with_weight": req.get("content",content)}
d["content_ltks"] = rag_tokenizer.tokenize(d["content_with_weight"])
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
d["important_kwd"] = req.get("important_keywords",[])
d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_keywords"]))
if "important_keywords" in req:
if type(req["important_keywords"]) != list:
return get_error_data_result("`important_keywords` is required to be a list")
d["important_kwd"] = req.get("important_keywords")
d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_keywords"]))
if "available" in req:
d["available_int"] = req["available"]
embd_id = DocumentService.get_embd_id(document_id)
@ -478,7 +497,7 @@ def set(tenant_id,dataset_id,document_id,chunk_id):
arr = [
t for t in re.split(
r"[\n\t]",
req["content"]) if len(t) > 1]
d["content_with_weight"]) if len(t) > 1]
if len(arr) != 2:
return get_error_data_result(
retmsg="Q&A must be separated by TAB/ENTER key.")
@ -486,7 +505,7 @@ def set(tenant_id,dataset_id,document_id,chunk_id):
d = beAdoc(d, arr[0], arr[1], not any(
[rag_tokenizer.is_chinese(t) for t in q + a]))
v, c = embd_mdl.encode([doc.name, req["content"]])
v, c = embd_mdl.encode([doc.name, d["content_with_weight"]])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
d["q_%d_vec" % len(v)] = v.tolist()
ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
@ -505,7 +524,7 @@ def retrieval_test(tenant_id):
for id in kb_id:
if not KnowledgebaseService.query(id=id,tenant_id=tenant_id):
return get_error_data_result(f"You don't own the dataset {id}.")
if "question" not in req_json:
if "question" not in req:
return get_error_data_result("`question` is required.")
page = int(req.get("offset", 1))
size = int(req.get("limit", 30))

View File

@ -24,10 +24,9 @@ from api.utils import get_uuid
from api.utils.api_utils import get_error_data_result
from api.utils.api_utils import get_result, token_required
@manager.route('/chat/<chat_id>/session', methods=['POST'])
@token_required
def create(tenant_id, chat_id):
def create(tenant_id,chat_id):
req = request.json
req["dialog_id"] = chat_id
dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
@ -51,14 +50,13 @@ def create(tenant_id, chat_id):
del conv["reference"]
return get_result(data=conv)
@manager.route('/chat/<chat_id>/session/<session_id>', methods=['PUT'])
@token_required
def update(tenant_id, chat_id, session_id):
def update(tenant_id,chat_id,session_id):
req = request.json
req["dialog_id"] = chat_id
conv_id = session_id
conv = ConversationService.query(id=conv_id, dialog_id=chat_id)
conv = ConversationService.query(id=conv_id,dialog_id=chat_id)
if not conv:
return get_error_data_result(retmsg="Session does not exist")
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
@ -74,16 +72,30 @@ def update(tenant_id, chat_id, session_id):
return get_result()
@manager.route('/chat/<chat_id>/session/<session_id>/completion', methods=['POST'])
@manager.route('/chat/<chat_id>/completion', methods=['POST'])
@token_required
def completion(tenant_id, chat_id, session_id):
def completion(tenant_id,chat_id):
req = request.json
# req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
# {"role": "user", "content": "上海有吗?"}
# ]}
if not req.get("session_id"):
conv = {
"id": get_uuid(),
"dialog_id": chat_id,
"name": req.get("name", "New session"),
"message": [{"role": "assistant", "content": "Hi! I am your assistantcan I help you?"}]
}
if not conv.get("name"):
return get_error_data_result(retmsg="Name can not be empty.")
ConversationService.save(**conv)
e, conv = ConversationService.get_by_id(conv["id"])
session_id=conv.id
else:
session_id = req.get("session_id")
if not req.get("question"):
return get_error_data_result(retmsg="Please input your question.")
conv = ConversationService.query(id=session_id, dialog_id=chat_id)
conv = ConversationService.query(id=session_id,dialog_id=chat_id)
if not conv:
return get_error_data_result(retmsg="Session does not exist")
conv = conv[0]
@ -117,17 +129,18 @@ def completion(tenant_id, chat_id, session_id):
conv.message[-1] = {"role": "assistant", "content": ans["answer"],
"id": message_id, "prompt": ans.get("prompt", "")}
ans["id"] = message_id
ans["session_id"]=session_id
def stream():
nonlocal dia, msg, req, conv
try:
for ans in chat(dia, msg, **req):
fillin_conv(ans)
yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
ConversationService.update_by_id(conv.id, conv.to_dict())
except Exception as e:
yield "data:" + json.dumps({"code": 500, "message": str(e),
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
"data": {"answer": "**ERROR**: " + str(e),"reference": []}},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "data": True}, ensure_ascii=False) + "\n\n"
@ -148,15 +161,14 @@ def completion(tenant_id, chat_id, session_id):
break
return get_result(data=answer)
@manager.route('/chat/<chat_id>/session', methods=['GET'])
@token_required
def list(chat_id, tenant_id):
def list(chat_id,tenant_id):
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
return get_error_data_result(retmsg=f"You don't own the assistant {chat_id}.")
id = request.args.get("id")
name = request.args.get("name")
session = ConversationService.query(id=id, name=name, dialog_id=chat_id)
session = ConversationService.query(id=id,name=name,dialog_id=chat_id)
if not session:
return get_error_data_result(retmsg="The session doesn't exist")
page_number = int(request.args.get("page", 1))
@ -166,7 +178,7 @@ def list(chat_id, tenant_id):
desc = False
else:
desc = True
convs = ConversationService.get_list(chat_id, page_number, items_per_page, orderby, desc, id, name)
convs = ConversationService.get_list(chat_id,page_number,items_per_page,orderby,desc,id,name)
if not convs:
return get_result(data=[])
for conv in convs:
@ -201,17 +213,16 @@ def list(chat_id, tenant_id):
del conv["reference"]
return get_result(data=convs)
@manager.route('/chat/<chat_id>/session', methods=["DELETE"])
@token_required
def delete(tenant_id, chat_id):
def delete(tenant_id,chat_id):
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
return get_error_data_result(retmsg="You don't own the chat")
ids = request.json.get("ids")
if not ids:
return get_error_data_result(retmsg="`ids` is required in deleting operation")
for id in ids:
conv = ConversationService.query(id=id, dialog_id=chat_id)
conv = ConversationService.query(id=id,dialog_id=chat_id)
if not conv:
return get_error_data_result(retmsg="The chat doesn't own the session")
ConversationService.delete_by_id(id)