Mirror of https://github.com/infiniflow/ragflow.git
Refactor code (#8341)

### What problem does this PR solve?

1. Rename variables
2. Update `if` statements

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
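Both changes are mechanical: camelCase helpers such as `notOverlapped` become snake_case (`not_overlapped`), `imgs` becomes `images`, and negated `if` conditions are flipped so the success path comes first. A minimal sketch of the condition flip, with hypothetical `on_success`/`on_error` stand-ins for the real assertions:

```python
# Hypothetical illustration of the "update if statement" pattern applied
# across the test hunks below; on_success/on_error are stand-ins for the
# real assertions, not names from the repo.
def check_before(expected_code, on_success, on_error):
    if expected_code != 0:      # old style: negated condition first
        on_error()
    else:
        on_success()

def check_after(expected_code, on_success, on_error):
    if expected_code == 0:      # new style: positive case first
        on_success()
    else:
        on_error()

# Both orderings are behaviorally identical:
for code in (0, 1):
    log_a, log_b = [], []
    check_before(code, lambda: log_a.append("ok"), lambda: log_a.append("err"))
    check_after(code, lambda: log_b.append("ok"), lambda: log_b.append("err"))
    assert log_a == log_b
```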
@@ -133,7 +133,7 @@ class Recognizer:
 
     @staticmethod
     def layouts_cleanup(boxes, layouts, far=2, thr=0.7):
-        def notOverlapped(a, b):
+        def not_overlapped(a, b):
             return any([a["x1"] < b["x0"],
                         a["x0"] > b["x1"],
                         a["bottom"] < b["top"],
@@ -144,7 +144,7 @@ class Recognizer:
             j = i + 1
             while j < min(i + far, len(layouts)) \
                     and (layouts[i].get("type", "") != layouts[j].get("type", "")
-                         or notOverlapped(layouts[i], layouts[j])):
+                         or not_overlapped(layouts[i], layouts[j])):
                 j += 1
             if j >= min(i + far, len(layouts)):
                 i += 1
@@ -163,9 +163,9 @@ class Recognizer:
 
             area_i, area_i_1 = 0, 0
             for b in boxes:
-                if not notOverlapped(b, layouts[i]):
+                if not not_overlapped(b, layouts[i]):
                     area_i += Recognizer.overlapped_area(b, layouts[i], False)
-                if not notOverlapped(b, layouts[j]):
+                if not not_overlapped(b, layouts[j]):
                     area_i_1 += Recognizer.overlapped_area(b, layouts[j], False)
 
             if area_i > area_i_1:
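For reference, here is the renamed helper as a standalone, runnable sketch. The hunk truncates after the third disjointness test, so the fourth condition (`a["top"] > b["bottom"]`) is inferred by symmetry rather than copied from the source:

```python
# Standalone sketch of not_overlapped; boxes are dicts with "x0"/"x1"
# (horizontal extent) and "top"/"bottom" (vertical extent, growing
# downward in image coordinates), as in the diff context above.
def not_overlapped(a, b):
    return any([a["x1"] < b["x0"],        # a entirely left of b
                a["x0"] > b["x1"],        # a entirely right of b
                a["bottom"] < b["top"],   # a entirely above b
                a["top"] > b["bottom"]])  # a entirely below b (inferred)

box_a = {"x0": 0, "x1": 10, "top": 0, "bottom": 10}
box_b = {"x0": 20, "x1": 30, "top": 0, "bottom": 10}
assert not_overlapped(box_a, box_b)      # horizontally disjoint
assert not not_overlapped(box_a, box_a)  # a box overlaps itself
```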
@@ -408,18 +408,18 @@ class Recognizer:
 
     def __call__(self, image_list, thr=0.7, batch_size=16):
         res = []
-        imgs = []
+        images = []
         for i in range(len(image_list)):
             if not isinstance(image_list[i], np.ndarray):
-                imgs.append(np.array(image_list[i]))
+                images.append(np.array(image_list[i]))
             else:
-                imgs.append(image_list[i])
+                images.append(image_list[i])
 
-        batch_loop_cnt = math.ceil(float(len(imgs)) / batch_size)
+        batch_loop_cnt = math.ceil(float(len(images)) / batch_size)
         for i in range(batch_loop_cnt):
             start_index = i * batch_size
-            end_index = min((i + 1) * batch_size, len(imgs))
-            batch_image_list = imgs[start_index:end_index]
+            end_index = min((i + 1) * batch_size, len(images))
+            batch_image_list = images[start_index:end_index]
             inputs = self.preprocess(batch_image_list)
             logging.debug("preprocess")
             for ins in inputs:
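The `__call__` hunk only renames `imgs` to `images`, but the batching arithmetic it touches is worth seeing in isolation. A self-contained sketch, with plain integers standing in for images and preprocessing omitted:

```python
# Self-contained sketch of the batch loop from the hunk above;
# batch_bounds is an illustrative helper, not a name from the repo.
import math

def batch_bounds(total, batch_size=16):
    """Yield (start_index, end_index) pairs covering range(total)."""
    batch_loop_cnt = math.ceil(float(total) / batch_size)
    for i in range(batch_loop_cnt):
        start_index = i * batch_size
        end_index = min((i + 1) * batch_size, total)
        yield start_index, end_index

images = list(range(35))  # 35 dummy "images"
batches = [images[s:e] for s, e in batch_bounds(len(images))]
assert [len(b) for b in batches] == [16, 16, 3]  # last batch holds the remainder
```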
@@ -92,13 +92,13 @@ class TestDocumentsParseStop:
 
         res = stop_parse_documnets(get_http_api_auth, dataset_id, payload)
         assert res["code"] == expected_code
-        if expected_code != 0:
-            assert res["message"] == expected_message
-        else:
+        if expected_code == 0:
             completed_document_ids = list(set(document_ids) - set(payload["document_ids"]))
             condition(get_http_api_auth, dataset_id, completed_document_ids)
             validate_document_parse_cancel(get_http_api_auth, dataset_id, payload["document_ids"])
             validate_document_parse_done(get_http_api_auth, dataset_id, completed_document_ids)
+        else:
+            assert res["message"] == expected_message
 
     @pytest.mark.p3
     @pytest.mark.parametrize(
@@ -173,10 +173,10 @@ class TestDocumentsUpdated:
         assert res["code"] == expected_code
         if expected_code == 0:
             res = list_documnets(get_http_api_auth, dataset_id, {"id": document_ids[0]})
-            if chunk_method != "":
-                assert res["data"]["docs"][0]["chunk_method"] == chunk_method
-            else:
+            if chunk_method == "":
                 assert res["data"]["docs"][0]["chunk_method"] == "naive"
+            else:
+                assert res["data"]["docs"][0]["chunk_method"] == chunk_method
         else:
             assert res["message"] == expected_message
 
@@ -532,10 +532,7 @@ class TestUpdateDocumentParserConfig:
         assert res["code"] == expected_code
         if expected_code == 0:
             res = list_documnets(get_http_api_auth, dataset_id, {"id": document_ids[0]})
-            if parser_config != {}:
-                for k, v in parser_config.items():
-                    assert res["data"]["docs"][0]["parser_config"][k] == v
-            else:
+            if parser_config == {}:
                 assert res["data"]["docs"][0]["parser_config"] == {
                     "chunk_token_num": 128,
                     "delimiter": r"\n",
@@ -543,5 +540,8 @@ class TestUpdateDocumentParserConfig:
                     "layout_recognize": "DeepDOC",
                     "raptor": {"use_raptor": False},
                 }
+            else:
+                for k, v in parser_config.items():
+                    assert res["data"]["docs"][0]["parser_config"][k] == v
         if expected_code != 0 or expected_message:
             assert res["message"] == expected_message
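The restructured `parser_config` check now distinguishes an empty config, which must round-trip to the documented defaults, from a partial one, which is compared key by key. A condensed sketch of that logic; the defaults literal below contains only the keys visible in the hunk (the diff elides the middle of the dict):

```python
# Condensed sketch of the restructured assertion; DEFAULTS lists only
# the keys visible in the hunk, and check_parser_config is an
# illustrative helper, not a name from the repo.
DEFAULTS = {
    "chunk_token_num": 128,
    "delimiter": r"\n",
    "layout_recognize": "DeepDOC",
    "raptor": {"use_raptor": False},
}

def check_parser_config(actual, parser_config):
    if parser_config == {}:
        assert actual == DEFAULTS      # no overrides: expect pure defaults
    else:
        for k, v in parser_config.items():
            assert actual[k] == v      # overrides must round-trip

check_parser_config(dict(DEFAULTS), {})
check_parser_config({**DEFAULTS, "chunk_token_num": 256}, {"chunk_token_num": 256})
```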
@@ -162,10 +162,10 @@ class TestSessionsWithChatAssistantList:
         res = list_session_with_chat_assistants(get_http_api_auth, chat_assistant_id, params=params)
         assert res["code"] == expected_code
         if expected_code == 0:
-            if params["name"] != "session_with_chat_assistant_1":
-                assert len(res["data"]) == expected_num
-            else:
+            if params["name"] == "session_with_chat_assistant_1":
                 assert res["data"][0]["name"] == params["name"]
+            else:
+                assert len(res["data"]) == expected_num
         else:
             assert res["message"] == expected_message
 
@@ -189,10 +189,10 @@ class TestSessionsWithChatAssistantList:
         res = list_session_with_chat_assistants(get_http_api_auth, chat_assistant_id, params=params)
         assert res["code"] == expected_code
         if expected_code == 0:
-            if params["id"] != session_ids[0]:
-                assert len(res["data"]) == expected_num
-            else:
+            if params["id"] == session_ids[0]:
                 assert res["data"][0]["id"] == params["id"]
+            else:
+                assert len(res["data"]) == expected_num
         else:
             assert res["message"] == expected_message
 
@@ -174,10 +174,10 @@ class TestDocumentsUpdated:
         assert res["code"] == expected_code
         if expected_code == 0:
             res = list_documents(HttpApiAuth, dataset_id, {"id": document_ids[0]})
-            if chunk_method != "":
-                assert res["data"]["docs"][0]["chunk_method"] == chunk_method
-            else:
+            if chunk_method == "":
                 assert res["data"]["docs"][0]["chunk_method"] == "naive"
+            else:
+                assert res["data"]["docs"][0]["chunk_method"] == chunk_method
         else:
             assert res["message"] == expected_message
 
@@ -533,10 +533,7 @@ class TestUpdateDocumentParserConfig:
         assert res["code"] == expected_code
         if expected_code == 0:
             res = list_documents(HttpApiAuth, dataset_id, {"id": document_ids[0]})
-            if parser_config != {}:
-                for k, v in parser_config.items():
-                    assert res["data"]["docs"][0]["parser_config"][k] == v
-            else:
+            if parser_config == {}:
                 assert res["data"]["docs"][0]["parser_config"] == {
                     "chunk_token_num": 128,
                     "delimiter": r"\n",
@@ -544,5 +541,8 @@ class TestUpdateDocumentParserConfig:
                     "layout_recognize": "DeepDOC",
                     "raptor": {"use_raptor": False},
                 }
+            else:
+                for k, v in parser_config.items():
+                    assert res["data"]["docs"][0]["parser_config"][k] == v
         if expected_code != 0 or expected_message:
             assert res["message"] == expected_message
@@ -163,10 +163,10 @@ class TestSessionsWithChatAssistantList:
         res = list_session_with_chat_assistants(HttpApiAuth, chat_assistant_id, params=params)
         assert res["code"] == expected_code
         if expected_code == 0:
-            if params["name"] != "session_with_chat_assistant_1":
-                assert len(res["data"]) == expected_num
-            else:
+            if params["name"] == "session_with_chat_assistant_1":
                 assert res["data"][0]["name"] == params["name"]
+            else:
+                assert len(res["data"]) == expected_num
         else:
             assert res["message"] == expected_message
 
@@ -190,10 +190,10 @@ class TestSessionsWithChatAssistantList:
         res = list_session_with_chat_assistants(HttpApiAuth, chat_assistant_id, params=params)
         assert res["code"] == expected_code
         if expected_code == 0:
-            if params["id"] != session_ids[0]:
-                assert len(res["data"]) == expected_num
-            else:
+            if params["id"] == session_ids[0]:
                 assert res["data"][0]["id"] == params["id"]
+            else:
+                assert len(res["data"]) == expected_num
         else:
             assert res["message"] == expected_message
 
@@ -126,10 +126,10 @@ class TestSessionsWithChatAssistantList:
             assert expected_message in str(excinfo.value)
         else:
             sessions = chat_assistant.list_sessions(**params)
-            if params["name"] != "session_with_chat_assistant_1":
-                assert len(sessions) == expected_num
-            else:
+            if params["name"] == "session_with_chat_assistant_1":
                 assert sessions[0].name == params["name"]
+            else:
+                assert len(sessions) == expected_num
 
     @pytest.mark.p1
     @pytest.mark.parametrize(
@@ -154,10 +154,10 @@ class TestSessionsWithChatAssistantList:
             assert expected_message in str(excinfo.value)
         else:
             list_sessions = chat_assistant.list_sessions(**params)
-            if "id" in params and params["id"] != sessions[0].id:
-                assert len(list_sessions) == expected_num
-            else:
+            if "id" in params and params["id"] == sessions[0].id:
                 assert list_sessions[0].id == params["id"]
+            else:
+                assert len(list_sessions) == expected_num
 
     @pytest.mark.p3
     @pytest.mark.parametrize(