mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
refine admin initialization (#75)
This commit is contained in:
@ -230,7 +230,7 @@ class HuParser:
|
||||
b["H_right"] = headers[ii]["x1"]
|
||||
b["H"] = ii
|
||||
|
||||
ii = Recognizer.find_overlapped_with_threashold(b, clmns, thr=0.3)
|
||||
ii = Recognizer.find_horizontally_tightest_fit(b, clmns)
|
||||
if ii is not None:
|
||||
b["C"] = ii
|
||||
b["C_left"] = clmns[ii]["x0"]
|
||||
|
||||
@ -37,7 +37,7 @@ class LayoutRecognizer(Recognizer):
|
||||
super().__init__(self.labels, domain,
|
||||
os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
|
||||
|
||||
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.7, batch_size=16):
|
||||
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16):
|
||||
def __is_garbage(b):
|
||||
patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$",
|
||||
r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}",
|
||||
|
||||
@ -2,7 +2,6 @@ import copy
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
import paddle
|
||||
from shapely.geometry import Polygon
|
||||
import pyclipper
|
||||
|
||||
@ -215,7 +214,7 @@ class DBPostProcess(object):
|
||||
|
||||
def __call__(self, outs_dict, shape_list):
|
||||
pred = outs_dict['maps']
|
||||
if isinstance(pred, paddle.Tensor):
|
||||
if not isinstance(pred, np.ndarray):
|
||||
pred = pred.numpy()
|
||||
pred = pred[:, 0, :, :]
|
||||
segmentation = pred > self.thresh
|
||||
@ -339,7 +338,7 @@ class CTCLabelDecode(BaseRecLabelDecode):
|
||||
def __call__(self, preds, label=None, *args, **kwargs):
|
||||
if isinstance(preds, tuple) or isinstance(preds, list):
|
||||
preds = preds[-1]
|
||||
if isinstance(preds, paddle.Tensor):
|
||||
if not isinstance(preds, np.ndarray):
|
||||
preds = preds.numpy()
|
||||
preds_idx = preds.argmax(axis=2)
|
||||
preds_prob = preds.max(axis=2)
|
||||
|
||||
@ -259,6 +259,18 @@ class Recognizer(object):
|
||||
|
||||
return max_overlaped_i
|
||||
|
||||
@staticmethod
|
||||
def find_horizontally_tightest_fit(box, boxes):
|
||||
if not boxes:
|
||||
return
|
||||
min_dis, min_i = 1000000, None
|
||||
for i,b in enumerate(boxes):
|
||||
dis = min(abs(box["x0"] - b["x0"]), abs(box["x1"] - b["x1"]), abs(box["x0"]+box["x1"] - b["x1"] - b["x0"])/2)
|
||||
if dis < min_dis:
|
||||
min_i = i
|
||||
min_dis = dis
|
||||
return min_i
|
||||
|
||||
@staticmethod
|
||||
def find_overlapped_with_threashold(box, boxes, thr=0.3):
|
||||
if not boxes:
|
||||
|
||||
@ -74,6 +74,7 @@ def get_table_html(img, tb_cpns, ocr):
|
||||
clmns = sorted([r for r in tb_cpns if re.match(
|
||||
r"table column$", r["label"])], key=lambda x: x["x0"])
|
||||
clmns = Recognizer.layouts_cleanup(boxes, clmns, 5, 0.5)
|
||||
|
||||
for b in boxes:
|
||||
ii = Recognizer.find_overlapped_with_threashold(b, rows, thr=0.3)
|
||||
if ii is not None:
|
||||
@ -89,7 +90,7 @@ def get_table_html(img, tb_cpns, ocr):
|
||||
b["H_right"] = headers[ii]["x1"]
|
||||
b["H"] = ii
|
||||
|
||||
ii = Recognizer.find_overlapped_with_threashold(b, clmns, thr=0.3)
|
||||
ii = Recognizer.find_horizontally_tightest_fit(b, clmns)
|
||||
if ii is not None:
|
||||
b["C"] = ii
|
||||
b["C_left"] = clmns[ii]["x0"]
|
||||
@ -102,6 +103,7 @@ def get_table_html(img, tb_cpns, ocr):
|
||||
b["H_left"] = spans[ii]["x0"]
|
||||
b["H_right"] = spans[ii]["x1"]
|
||||
b["SP"] = ii
|
||||
|
||||
html = """
|
||||
<html>
|
||||
<head>
|
||||
|
||||
@ -14,7 +14,6 @@ import logging
|
||||
import os
|
||||
import re
|
||||
from collections import Counter
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -37,7 +36,7 @@ class TableStructureRecognizer(Recognizer):
|
||||
super().__init__(self.labels, "tsr",
|
||||
os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
|
||||
|
||||
def __call__(self, images, thr=0.5):
|
||||
def __call__(self, images, thr=0.2):
|
||||
tbls = super().__call__(images, thr)
|
||||
res = []
|
||||
# align left&right for rows, align top&bottom for columns
|
||||
@ -56,8 +55,8 @@ class TableStructureRecognizer(Recognizer):
|
||||
"row") > 0 or b["label"].find("header") > 0]
|
||||
if not left:
|
||||
continue
|
||||
left = np.median(left) if len(left) > 4 else np.min(left)
|
||||
right = np.median(right) if len(right) > 4 else np.max(right)
|
||||
left = np.mean(left) if len(left) > 4 else np.min(left)
|
||||
right = np.mean(right) if len(right) > 4 else np.max(right)
|
||||
for b in lts:
|
||||
if b["label"].find("row") > 0 or b["label"].find("header") > 0:
|
||||
if b["x0"] > left:
|
||||
@ -129,6 +128,7 @@ class TableStructureRecognizer(Recognizer):
|
||||
i = 0
|
||||
while i < len(boxes):
|
||||
if TableStructureRecognizer.is_caption(boxes[i]):
|
||||
if is_english: cap + " "
|
||||
cap += boxes[i]["text"]
|
||||
boxes.pop(i)
|
||||
i -= 1
|
||||
@ -398,7 +398,7 @@ class TableStructureRecognizer(Recognizer):
|
||||
for i in range(clmno):
|
||||
if not tbl[r][i]:
|
||||
continue
|
||||
txt = "".join([a["text"].strip() for a in tbl[r][i]])
|
||||
txt = " ".join([a["text"].strip() for a in tbl[r][i]])
|
||||
headers[r][i] = txt
|
||||
hdrset.add(txt)
|
||||
if all([not t for t in headers[r]]):
|
||||
|
||||
Reference in New Issue
Block a user