mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: add support for the Ascend layout recognizer (#10105)
### What problem does this PR solve? Supports Ascend layout recognizer. Use the environment variable `LAYOUT_RECOGNIZER_TYPE=ascend` to enable the Ascend layout recognizer, and `ASCEND_LAYOUT_RECOGNIZER_DEVICE_ID=n` (for example, n=0) to specify the Ascend device ID. Ensure that you have installed the [ais tools](https://gitee.com/ascend/tools/tree/master/ais-bench_workload/tool/ais_bench) properly. ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -14,6 +14,8 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from collections import Counter
|
||||
@ -45,28 +47,22 @@ class LayoutRecognizer(Recognizer):
|
||||
|
||||
def __init__(self, domain):
|
||||
try:
|
||||
model_dir = os.path.join(
|
||||
get_project_base_directory(),
|
||||
"rag/res/deepdoc")
|
||||
model_dir = os.path.join(get_project_base_directory(), "rag/res/deepdoc")
|
||||
super().__init__(self.labels, domain, model_dir)
|
||||
except Exception:
|
||||
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
|
||||
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
|
||||
local_dir_use_symlinks=False)
|
||||
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc", local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"), local_dir_use_symlinks=False)
|
||||
super().__init__(self.labels, domain, model_dir)
|
||||
|
||||
self.garbage_layouts = ["footer", "header", "reference"]
|
||||
self.client = None
|
||||
if os.environ.get("TENSORRT_DLA_SVR"):
|
||||
from deepdoc.vision.dla_cli import DLAClient
|
||||
|
||||
self.client = DLAClient(os.environ["TENSORRT_DLA_SVR"])
|
||||
|
||||
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16, drop=True):
|
||||
def __is_garbage(b):
|
||||
patt = [r"^•+$", "^[0-9]{1,2} / ?[0-9]{1,2}$",
|
||||
r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}",
|
||||
"\\(cid *: *[0-9]+ *\\)"
|
||||
]
|
||||
patt = [r"^•+$", "^[0-9]{1,2} / ?[0-9]{1,2}$", r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}", "\\(cid *: *[0-9]+ *\\)"]
|
||||
return any([re.search(p, b["text"]) for p in patt])
|
||||
|
||||
if self.client:
|
||||
@ -82,18 +78,23 @@ class LayoutRecognizer(Recognizer):
|
||||
page_layout = []
|
||||
for pn, lts in enumerate(layouts):
|
||||
bxs = ocr_res[pn]
|
||||
lts = [{"type": b["type"],
|
||||
lts = [
|
||||
{
|
||||
"type": b["type"],
|
||||
"score": float(b["score"]),
|
||||
"x0": b["bbox"][0] / scale_factor, "x1": b["bbox"][2] / scale_factor,
|
||||
"top": b["bbox"][1] / scale_factor, "bottom": b["bbox"][-1] / scale_factor,
|
||||
"x0": b["bbox"][0] / scale_factor,
|
||||
"x1": b["bbox"][2] / scale_factor,
|
||||
"top": b["bbox"][1] / scale_factor,
|
||||
"bottom": b["bbox"][-1] / scale_factor,
|
||||
"page_number": pn,
|
||||
} for b in lts if float(b["score"]) >= 0.4 or b["type"] not in self.garbage_layouts]
|
||||
lts = self.sort_Y_firstly(lts, np.mean(
|
||||
[lt["bottom"] - lt["top"] for lt in lts]) / 2)
|
||||
}
|
||||
for b in lts
|
||||
if float(b["score"]) >= 0.4 or b["type"] not in self.garbage_layouts
|
||||
]
|
||||
lts = self.sort_Y_firstly(lts, np.mean([lt["bottom"] - lt["top"] for lt in lts]) / 2)
|
||||
lts = self.layouts_cleanup(bxs, lts)
|
||||
page_layout.append(lts)
|
||||
|
||||
# Tag layout type, layouts are ready
|
||||
def findLayout(ty):
|
||||
nonlocal bxs, lts, self
|
||||
lts_ = [lt for lt in lts if lt["type"] == ty]
|
||||
@ -106,21 +107,17 @@ class LayoutRecognizer(Recognizer):
|
||||
bxs.pop(i)
|
||||
continue
|
||||
|
||||
ii = self.find_overlapped_with_threshold(bxs[i], lts_,
|
||||
thr=0.4)
|
||||
if ii is None: # belong to nothing
|
||||
ii = self.find_overlapped_with_threshold(bxs[i], lts_, thr=0.4)
|
||||
if ii is None:
|
||||
bxs[i]["layout_type"] = ""
|
||||
i += 1
|
||||
continue
|
||||
lts_[ii]["visited"] = True
|
||||
keep_feats = [
|
||||
lts_[
|
||||
ii]["type"] == "footer" and bxs[i]["bottom"] < image_list[pn].size[1] * 0.9 / scale_factor,
|
||||
lts_[
|
||||
ii]["type"] == "header" and bxs[i]["top"] > image_list[pn].size[1] * 0.1 / scale_factor,
|
||||
lts_[ii]["type"] == "footer" and bxs[i]["bottom"] < image_list[pn].size[1] * 0.9 / scale_factor,
|
||||
lts_[ii]["type"] == "header" and bxs[i]["top"] > image_list[pn].size[1] * 0.1 / scale_factor,
|
||||
]
|
||||
if drop and lts_[
|
||||
ii]["type"] in self.garbage_layouts and not any(keep_feats):
|
||||
if drop and lts_[ii]["type"] in self.garbage_layouts and not any(keep_feats):
|
||||
if lts_[ii]["type"] not in garbages:
|
||||
garbages[lts_[ii]["type"]] = []
|
||||
garbages[lts_[ii]["type"]].append(bxs[i]["text"])
|
||||
@ -128,17 +125,14 @@ class LayoutRecognizer(Recognizer):
|
||||
continue
|
||||
|
||||
bxs[i]["layoutno"] = f"{ty}-{ii}"
|
||||
bxs[i]["layout_type"] = lts_[ii]["type"] if lts_[
|
||||
ii]["type"] != "equation" else "figure"
|
||||
bxs[i]["layout_type"] = lts_[ii]["type"] if lts_[ii]["type"] != "equation" else "figure"
|
||||
i += 1
|
||||
|
||||
for lt in ["footer", "header", "reference", "figure caption",
|
||||
"table caption", "title", "table", "text", "figure", "equation"]:
|
||||
for lt in ["footer", "header", "reference", "figure caption", "table caption", "title", "table", "text", "figure", "equation"]:
|
||||
findLayout(lt)
|
||||
|
||||
# add box to figure layouts which has not text box
|
||||
for i, lt in enumerate(
|
||||
[lt for lt in lts if lt["type"] in ["figure", "equation"]]):
|
||||
for i, lt in enumerate([lt for lt in lts if lt["type"] in ["figure", "equation"]]):
|
||||
if lt.get("visited"):
|
||||
continue
|
||||
lt = deepcopy(lt)
|
||||
@ -206,13 +200,11 @@ class LayoutRecognizer4YOLOv10(LayoutRecognizer):
|
||||
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
|
||||
top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1))
|
||||
left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1))
|
||||
img = cv2.copyMakeBorder(
|
||||
img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
|
||||
) # add border
|
||||
img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)) # add border
|
||||
img /= 255.0
|
||||
img = img.transpose(2, 0, 1)
|
||||
img = img[np.newaxis, :, :, :].astype(np.float32)
|
||||
inputs.append({self.input_names[0]: img, "scale_factor": [shape[1]/ww, shape[0]/hh, dw, dh]})
|
||||
inputs.append({self.input_names[0]: img, "scale_factor": [shape[1] / ww, shape[0] / hh, dw, dh]})
|
||||
|
||||
return inputs
|
||||
|
||||
@ -230,8 +222,7 @@ class LayoutRecognizer4YOLOv10(LayoutRecognizer):
|
||||
boxes[:, 2] -= inputs["scale_factor"][2]
|
||||
boxes[:, 1] -= inputs["scale_factor"][3]
|
||||
boxes[:, 3] -= inputs["scale_factor"][3]
|
||||
input_shape = np.array([inputs["scale_factor"][0], inputs["scale_factor"][1], inputs["scale_factor"][0],
|
||||
inputs["scale_factor"][1]])
|
||||
input_shape = np.array([inputs["scale_factor"][0], inputs["scale_factor"][1], inputs["scale_factor"][0], inputs["scale_factor"][1]])
|
||||
boxes = np.multiply(boxes, input_shape, dtype=np.float32)
|
||||
|
||||
unique_class_ids = np.unique(class_ids)
|
||||
@ -243,8 +234,223 @@ class LayoutRecognizer4YOLOv10(LayoutRecognizer):
|
||||
class_keep_boxes = nms(class_boxes, class_scores, 0.45)
|
||||
indices.extend(class_indices[class_keep_boxes])
|
||||
|
||||
return [{
|
||||
"type": self.label_list[class_ids[i]].lower(),
|
||||
"bbox": [float(t) for t in boxes[i].tolist()],
|
||||
"score": float(scores[i])
|
||||
} for i in indices]
|
||||
return [{"type": self.label_list[class_ids[i]].lower(), "bbox": [float(t) for t in boxes[i].tolist()], "score": float(scores[i])} for i in indices]
|
||||
|
||||
|
||||
class AscendLayoutRecognizer(Recognizer):
|
||||
labels = [
|
||||
"title",
|
||||
"Text",
|
||||
"Reference",
|
||||
"Figure",
|
||||
"Figure caption",
|
||||
"Table",
|
||||
"Table caption",
|
||||
"Table caption",
|
||||
"Equation",
|
||||
"Figure caption",
|
||||
]
|
||||
|
||||
def __init__(self, domain):
|
||||
from ais_bench.infer.interface import InferSession
|
||||
|
||||
model_dir = os.path.join(get_project_base_directory(), "rag/res/deepdoc")
|
||||
model_file_path = os.path.join(model_dir, domain + ".om")
|
||||
|
||||
if not os.path.exists(model_file_path):
|
||||
raise ValueError(f"Model file not found: {model_file_path}")
|
||||
|
||||
device_id = int(os.getenv("ASCEND_LAYOUT_RECOGNIZER_DEVICE_ID", 0))
|
||||
self.session = InferSession(device_id=device_id, model_path=model_file_path)
|
||||
self.input_shape = self.session.get_inputs()[0].shape[2:4] # H,W
|
||||
self.garbage_layouts = ["footer", "header", "reference"]
|
||||
|
||||
def preprocess(self, image_list):
|
||||
inputs = []
|
||||
H, W = self.input_shape
|
||||
for img in image_list:
|
||||
h, w = img.shape[:2]
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)
|
||||
|
||||
r = min(H / h, W / w)
|
||||
new_unpad = (int(round(w * r)), int(round(h * r)))
|
||||
dw, dh = (W - new_unpad[0]) / 2.0, (H - new_unpad[1]) / 2.0
|
||||
|
||||
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
|
||||
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
|
||||
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
|
||||
img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))
|
||||
|
||||
img /= 255.0
|
||||
img = img.transpose(2, 0, 1)[np.newaxis, :, :, :].astype(np.float32)
|
||||
|
||||
inputs.append(
|
||||
{
|
||||
"image": img,
|
||||
"scale_factor": [w / new_unpad[0], h / new_unpad[1]],
|
||||
"pad": [dw, dh],
|
||||
"orig_shape": [h, w],
|
||||
}
|
||||
)
|
||||
return inputs
|
||||
|
||||
def postprocess(self, boxes, inputs, thr=0.25):
|
||||
arr = np.squeeze(boxes)
|
||||
if arr.ndim == 1:
|
||||
arr = arr.reshape(1, -1)
|
||||
|
||||
results = []
|
||||
if arr.shape[1] == 6:
|
||||
# [x1,y1,x2,y2,score,cls]
|
||||
m = arr[:, 4] >= thr
|
||||
arr = arr[m]
|
||||
if arr.size == 0:
|
||||
return []
|
||||
xyxy = arr[:, :4].astype(np.float32)
|
||||
scores = arr[:, 4].astype(np.float32)
|
||||
cls_ids = arr[:, 5].astype(np.int32)
|
||||
|
||||
if "pad" in inputs:
|
||||
dw, dh = inputs["pad"]
|
||||
sx, sy = inputs["scale_factor"]
|
||||
xyxy[:, [0, 2]] -= dw
|
||||
xyxy[:, [1, 3]] -= dh
|
||||
xyxy *= np.array([sx, sy, sx, sy], dtype=np.float32)
|
||||
else:
|
||||
# backup
|
||||
sx, sy = inputs["scale_factor"]
|
||||
xyxy *= np.array([sx, sy, sx, sy], dtype=np.float32)
|
||||
|
||||
keep_indices = []
|
||||
for c in np.unique(cls_ids):
|
||||
idx = np.where(cls_ids == c)[0]
|
||||
k = nms(xyxy[idx], scores[idx], 0.45)
|
||||
keep_indices.extend(idx[k])
|
||||
|
||||
for i in keep_indices:
|
||||
cid = int(cls_ids[i])
|
||||
if 0 <= cid < len(self.labels):
|
||||
results.append({"type": self.labels[cid].lower(), "bbox": [float(t) for t in xyxy[i].tolist()], "score": float(scores[i])})
|
||||
return results
|
||||
|
||||
raise ValueError(f"Unexpected output shape: {arr.shape}")
|
||||
|
||||
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16, drop=True):
|
||||
import re
|
||||
from collections import Counter
|
||||
|
||||
assert len(image_list) == len(ocr_res)
|
||||
|
||||
images = [np.array(im) if not isinstance(im, np.ndarray) else im for im in image_list]
|
||||
layouts_all_pages = [] # list of list[{"type","score","bbox":[x1,y1,x2,y2]}]
|
||||
|
||||
conf_thr = max(thr, 0.08)
|
||||
|
||||
batch_loop_cnt = math.ceil(float(len(images)) / batch_size)
|
||||
for bi in range(batch_loop_cnt):
|
||||
s = bi * batch_size
|
||||
e = min((bi + 1) * batch_size, len(images))
|
||||
batch_images = images[s:e]
|
||||
|
||||
inputs_list = self.preprocess(batch_images)
|
||||
logging.debug("preprocess done")
|
||||
|
||||
for ins in inputs_list:
|
||||
feeds = [ins["image"]]
|
||||
out_list = self.session.infer(feeds=feeds, mode="static")
|
||||
|
||||
for out in out_list:
|
||||
lts = self.postprocess(out, ins, conf_thr)
|
||||
|
||||
page_lts = []
|
||||
for b in lts:
|
||||
if float(b["score"]) >= 0.4 or b["type"] not in self.garbage_layouts:
|
||||
x0, y0, x1, y1 = b["bbox"]
|
||||
page_lts.append(
|
||||
{
|
||||
"type": b["type"],
|
||||
"score": float(b["score"]),
|
||||
"x0": float(x0) / scale_factor,
|
||||
"x1": float(x1) / scale_factor,
|
||||
"top": float(y0) / scale_factor,
|
||||
"bottom": float(y1) / scale_factor,
|
||||
"page_number": len(layouts_all_pages),
|
||||
}
|
||||
)
|
||||
layouts_all_pages.append(page_lts)
|
||||
|
||||
def _is_garbage_text(box):
|
||||
patt = [r"^•+$", r"^[0-9]{1,2} / ?[0-9]{1,2}$", r"^[0-9]{1,2} of [0-9]{1,2}$", r"^http://[^ ]{12,}", r"\(cid *: *[0-9]+ *\)"]
|
||||
return any(re.search(p, box.get("text", "")) for p in patt)
|
||||
|
||||
boxes_out = []
|
||||
page_layout = []
|
||||
garbages = {}
|
||||
|
||||
for pn, lts in enumerate(layouts_all_pages):
|
||||
if lts:
|
||||
avg_h = np.mean([lt["bottom"] - lt["top"] for lt in lts])
|
||||
lts = self.sort_Y_firstly(lts, avg_h / 2 if avg_h > 0 else 0)
|
||||
|
||||
bxs = ocr_res[pn]
|
||||
lts = self.layouts_cleanup(bxs, lts)
|
||||
page_layout.append(lts)
|
||||
|
||||
def _tag_layout(ty):
|
||||
nonlocal bxs, lts
|
||||
lts_of_ty = [lt for lt in lts if lt["type"] == ty]
|
||||
i = 0
|
||||
while i < len(bxs):
|
||||
if bxs[i].get("layout_type"):
|
||||
i += 1
|
||||
continue
|
||||
if _is_garbage_text(bxs[i]):
|
||||
bxs.pop(i)
|
||||
continue
|
||||
|
||||
ii = self.find_overlapped_with_threshold(bxs[i], lts_of_ty, thr=0.4)
|
||||
if ii is None:
|
||||
bxs[i]["layout_type"] = ""
|
||||
i += 1
|
||||
continue
|
||||
|
||||
lts_of_ty[ii]["visited"] = True
|
||||
|
||||
keep_feats = [
|
||||
lts_of_ty[ii]["type"] == "footer" and bxs[i]["bottom"] < image_list[pn].shape[0] * 0.9 / scale_factor,
|
||||
lts_of_ty[ii]["type"] == "header" and bxs[i]["top"] > image_list[pn].shape[0] * 0.1 / scale_factor,
|
||||
]
|
||||
if drop and lts_of_ty[ii]["type"] in self.garbage_layouts and not any(keep_feats):
|
||||
garbages.setdefault(lts_of_ty[ii]["type"], []).append(bxs[i].get("text", ""))
|
||||
bxs.pop(i)
|
||||
continue
|
||||
|
||||
bxs[i]["layoutno"] = f"{ty}-{ii}"
|
||||
bxs[i]["layout_type"] = lts_of_ty[ii]["type"] if lts_of_ty[ii]["type"] != "equation" else "figure"
|
||||
i += 1
|
||||
|
||||
for ty in ["footer", "header", "reference", "figure caption", "table caption", "title", "table", "text", "figure", "equation"]:
|
||||
_tag_layout(ty)
|
||||
|
||||
figs = [lt for lt in lts if lt["type"] in ["figure", "equation"]]
|
||||
for i, lt in enumerate(figs):
|
||||
if lt.get("visited"):
|
||||
continue
|
||||
lt = deepcopy(lt)
|
||||
lt.pop("type", None)
|
||||
lt["text"] = ""
|
||||
lt["layout_type"] = "figure"
|
||||
lt["layoutno"] = f"figure-{i}"
|
||||
bxs.append(lt)
|
||||
|
||||
boxes_out.extend(bxs)
|
||||
|
||||
garbag_set = set()
|
||||
for k, lst in garbages.items():
|
||||
cnt = Counter(lst)
|
||||
for g, c in cnt.items():
|
||||
if c > 1:
|
||||
garbag_set.add(g)
|
||||
|
||||
ocr_res_new = [b for b in boxes_out if b["text"].strip() not in garbag_set]
|
||||
return ocr_res_new, page_layout
|
||||
|
||||
Reference in New Issue
Block a user