Add test for API (#3134)

### What problem does this PR solve?

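Add tests for the HTTP API and fix the document/chunk endpoint issues they exposed (for example, filtering documents by `name` and correcting the `dataset_ids` error messages).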

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn>
Co-authored-by: Zhichang Yu <yuzhichang@gmail.com>
This commit is contained in:
liuhua
2024-11-01 22:59:17 +08:00
committed by GitHub
parent 7eafccf78a
commit 44ad9a6cd7
10 changed files with 292 additions and 355 deletions

View File

@ -194,8 +194,11 @@ def list_docs(dataset_id, tenant_id):
if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
return get_error_data_result(retmsg=f"You don't own the dataset {dataset_id}. ")
id = request.args.get("id")
name = request.args.get("name")
if not DocumentService.query(id=id,kb_id=dataset_id):
return get_error_data_result(retmsg=f"You don't own the document {id}.")
if not DocumentService.query(name=name,kb_id=dataset_id):
return get_error_data_result(retmsg=f"You don't own the document {name}.")
offset = int(request.args.get("offset", 1))
keywords = request.args.get("keywords","")
limit = int(request.args.get("limit", 1024))
@ -204,7 +207,7 @@ def list_docs(dataset_id, tenant_id):
desc = False
else:
desc = True
docs, tol = DocumentService.get_list(dataset_id, offset, limit, orderby, desc, keywords, id)
docs, tol = DocumentService.get_list(dataset_id, offset, limit, orderby, desc, keywords, id,name)
# rename key's name
renamed_doc_list = []
@ -321,8 +324,8 @@ def stop_parsing(tenant_id,dataset_id):
doc = DocumentService.query(id=id, kb_id=dataset_id)
if not doc:
return get_error_data_result(retmsg=f"You don't own the document {id}.")
if doc[0].progress == 100.0 or doc[0].progress == 0.0:
return get_error_data_result("Can't stop parsing document with progress at 0 or 100")
if int(doc[0].progress) == 1 or int(doc[0].progress) == 0:
return get_error_data_result("Can't stop parsing document with progress at 0 or 1")
info = {"run": "2", "progress": 0,"chunk_num":0}
DocumentService.update_by_id(id, info)
ELASTICSEARCH.deleteByQuery(
@ -414,9 +417,9 @@ def list_chunks(tenant_id,dataset_id,document_id):
for key, value in chunk.items():
new_key = key_mapping.get(key, key)
renamed_chunk[new_key] = value
if renamed_chunk["available"] == "0":
if renamed_chunk["available"] == 0:
renamed_chunk["available"] = False
if renamed_chunk["available"] == "1":
if renamed_chunk["available"] == 1:
renamed_chunk["available"] = True
res["chunks"].append(renamed_chunk)
return get_result(data=res)
@ -464,6 +467,7 @@ def add_chunk(tenant_id,dataset_id,document_id):
DocumentService.increment_chunk_num(
doc.id, doc.kb_id, c, 1, 0)
d["chunk_id"] = chunk_id
d["kb_id"]=doc.kb_id
# rename keys
key_mapping = {
"chunk_id": "id",
@ -581,10 +585,10 @@ def update_chunk(tenant_id,dataset_id,document_id,chunk_id):
def retrieval_test(tenant_id):
req = request.json
if not req.get("dataset_ids"):
return get_error_data_result("`datasets` is required.")
return get_error_data_result("`dataset_ids` is required.")
kb_ids = req["dataset_ids"]
if not isinstance(kb_ids,list):
return get_error_data_result("`datasets` should be a list")
return get_error_data_result("`dataset_ids` should be a list")
kbs = KnowledgebaseService.get_by_ids(kb_ids)
for id in kb_ids:
if not KnowledgebaseService.query(id=id,tenant_id=tenant_id):

View File

@ -52,11 +52,15 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_list(cls, kb_id, page_number, items_per_page,
orderby, desc, keywords, id):
docs =cls.model.select().where(cls.model.kb_id==kb_id)
orderby, desc, keywords, id, name):
docs = cls.model.select().where(cls.model.kb_id == kb_id)
if id:
docs = docs.where(
cls.model.id== id )
cls.model.id == id)
if name:
docs = docs.where(
cls.model.name == name
)
if keywords:
docs = docs.where(
fn.LOWER(cls.model.name).contains(keywords.lower())
@ -70,7 +74,6 @@ class DocumentService(CommonService):
count = docs.count()
return list(docs.dicts()), count
@classmethod
@DB.connection_context()
def get_by_kb_id(cls, kb_id, page_number, items_per_page,
@ -162,26 +165,27 @@ class DocumentService(CommonService):
cls.model.update_time]
docs = cls.model.select(*fields) \
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
.where(
cls.model.status == StatusEnum.VALID.value,
~(cls.model.type == FileType.VIRTUAL.value),
cls.model.progress == 0,
cls.model.update_time >= current_timestamp() - 1000 * 600,
cls.model.run == TaskStatus.RUNNING.value)\
cls.model.status == StatusEnum.VALID.value,
~(cls.model.type == FileType.VIRTUAL.value),
cls.model.progress == 0,
cls.model.update_time >= current_timestamp() - 1000 * 600,
cls.model.run == TaskStatus.RUNNING.value) \
.order_by(cls.model.update_time.asc())
return list(docs.dicts())
@classmethod
@DB.connection_context()
def get_unfinished_docs(cls):
fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg, cls.model.run]
fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg,
cls.model.run]
docs = cls.model.select(*fields) \
.where(
cls.model.status == StatusEnum.VALID.value,
~(cls.model.type == FileType.VIRTUAL.value),
cls.model.progress < 1,
cls.model.progress > 0)
cls.model.status == StatusEnum.VALID.value,
~(cls.model.type == FileType.VIRTUAL.value),
cls.model.progress < 1,
cls.model.progress > 0)
return list(docs.dicts())
@classmethod
@ -196,12 +200,12 @@ class DocumentService(CommonService):
"Document not found which is supposed to be there")
num = Knowledgebase.update(
token_num=Knowledgebase.token_num +
token_num,
token_num,
chunk_num=Knowledgebase.chunk_num +
chunk_num).where(
chunk_num).where(
Knowledgebase.id == kb_id).execute()
return num
@classmethod
@DB.connection_context()
def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
@ -214,13 +218,13 @@ class DocumentService(CommonService):
"Document not found which is supposed to be there")
num = Knowledgebase.update(
token_num=Knowledgebase.token_num -
token_num,
token_num,
chunk_num=Knowledgebase.chunk_num -
chunk_num
chunk_num
).where(
Knowledgebase.id == kb_id).execute()
return num
@classmethod
@DB.connection_context()
def clear_chunk_num(cls, doc_id):
@ -229,10 +233,10 @@ class DocumentService(CommonService):
num = Knowledgebase.update(
token_num=Knowledgebase.token_num -
doc.token_num,
doc.token_num,
chunk_num=Knowledgebase.chunk_num -
doc.chunk_num,
doc_num=Knowledgebase.doc_num-1
doc.chunk_num,
doc_num=Knowledgebase.doc_num - 1
).where(
Knowledgebase.id == doc.kb_id).execute()
return num
@ -243,8 +247,8 @@ class DocumentService(CommonService):
docs = cls.model.select(
Knowledgebase.tenant_id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)).where(
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
Knowledgebase.id == cls.model.kb_id)).where(
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts()
if not docs:
return
@ -270,8 +274,8 @@ class DocumentService(CommonService):
cls.model.id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)
).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1)
).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1)
docs = docs.dicts()
if not docs:
return False
@ -284,7 +288,7 @@ class DocumentService(CommonService):
cls.model.id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)
).where(cls.model.id == doc_id, Knowledgebase.created_by == user_id).paginate(0, 1)
).where(cls.model.id == doc_id, Knowledgebase.created_by == user_id).paginate(0, 1)
docs = docs.dicts()
if not docs:
return False
@ -296,13 +300,13 @@ class DocumentService(CommonService):
docs = cls.model.select(
Knowledgebase.embd_id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)).where(
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
Knowledgebase.id == cls.model.kb_id)).where(
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts()
if not docs:
return
return docs[0]["embd_id"]
@classmethod
@DB.connection_context()
def get_doc_id_by_doc_name(cls, doc_name):
@ -338,6 +342,7 @@ class DocumentService(CommonService):
dfs_update(old[k], v)
else:
old[k] = v
dfs_update(d.parser_config, config)
cls.update_by_id(id, {"parser_config": d.parser_config})
@ -372,7 +377,7 @@ class DocumentService(CommonService):
finished = True
bad = 0
e, doc = DocumentService.get_by_id(d["id"])
status = doc.run#TaskStatus.RUNNING.value
status = doc.run # TaskStatus.RUNNING.value
for t in tsks:
if 0 <= t.progress < 1:
finished = False
@ -386,9 +391,10 @@ class DocumentService(CommonService):
prg = -1
status = TaskStatus.FAIL.value
elif finished:
if d["parser_config"].get("raptor", {}).get("use_raptor") and d["progress_msg"].lower().find(" raptor")<0:
if d["parser_config"].get("raptor", {}).get("use_raptor") and d["progress_msg"].lower().find(
" raptor") < 0:
queue_raptor_tasks(d)
prg = 0.98 * len(tsks)/(len(tsks)+1)
prg = 0.98 * len(tsks) / (len(tsks) + 1)
msg.append("------ RAPTOR -------")
else:
status = TaskStatus.DONE.value
@ -414,7 +420,6 @@ class DocumentService(CommonService):
return len(cls.model.select(cls.model.id).where(
cls.model.kb_id == kb_id).dicts())
@classmethod
@DB.connection_context()
def do_cancel(cls, doc_id):
@ -579,4 +584,4 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
DocumentService.increment_chunk_num(
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
return [d["id"] for d,_ in files]
return [d["id"] for d, _ in files]