Mirror of https://github.com/infiniflow/ragflow.git, synced 2025-12-08 20:42:30 +08:00
Remove unused code, separate layout detection out as a new API. Add new RAG method 'table' (#55)
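The main change in this commit is that layout and table detection no longer run through a local PPDet model; a new HuParser.__remote_call method posts page images to a layout-detection service, and self.layouter / self.tbl_det become thin wrappers around it. As a rough illustration of the contract the new code assumes (the /v1/layout/detect/<species> path, multipart "image" fields, the "threashold" form field, and the retcode/retmsg/data JSON envelope, all taken from the diff below), a standalone request might look like this sketch; the server URL, token value, "page.png" file, and the "general" species name are placeholders, not part of the commit.

# Minimal sketch of the remote layout-detection call introduced by this commit.
# INFINIFLOW_SERVER / INFINIFLOW_TOKEN must be set, as __remote_call requires;
# "page.png" and the "general" species segment are illustrative placeholders.
import os
import requests

url = os.environ["INFINIFLOW_SERVER"]      # base URL of the layout service
token = os.environ["INFINIFLOW_TOKEN"]     # sent as the Authorization header

with open("page.png", "rb") as f:          # a rendered PDF page image
    img = f.read()

res = requests.post(
    url + "/v1/layout/detect/general",     # species segment; "table_component" is used for table detection
    files=[("image", img)],
    data={"threashold": 0.7},              # field name spelled exactly as in the diff
    headers={"Authorization": token},
    timeout=10,
)
payload = res.json()
if payload["retcode"] != 0:                # non-zero retcode signals a server-side error
    raise RuntimeError(payload["retmsg"])
layouts = payload["data"]                  # one list of detected regions per uploaded image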
@@ -1,7 +1,10 @@
# -*- coding: utf-8 -*-
import os
import random
from functools import partial

import fitz
import requests
import xgboost as xgb
from io import BytesIO
import torch
@@ -10,13 +13,14 @@ import pdfplumber
import logging
from PIL import Image
import numpy as np

from api.db import ParserType
from rag.nlp import huqie
from collections import Counter
from copy import deepcopy
from rag.cv.table_recognize import TableTransformer
from rag.cv.ppdetection import PPDet
from huggingface_hub import hf_hub_download


logging.getLogger("pdfminer").setLevel(logging.WARNING)
@@ -25,8 +29,10 @@ class HuParser:
        from paddleocr import PaddleOCR
        logging.getLogger("ppocr").setLevel(logging.ERROR)
        self.ocr = PaddleOCR(use_angle_cls=False, lang="ch")
        self.layouter = PPDet("/data/newpeak/medical-gpt/res/ppdet")
        self.tbl_det = PPDet("/data/newpeak/medical-gpt/res/ppdet.tbl")
        if not hasattr(self, "model_speciess"):
            self.model_speciess = ParserType.GENERAL.value
        self.layouter = partial(self.__remote_call, self.model_speciess)
        self.tbl_det = partial(self.__remote_call, "table_component")

        self.updown_cnt_mdl = xgb.Booster()
        if torch.cuda.is_available():
@@ -45,6 +51,38 @@ class HuParser:

        """

    def __remote_call(self, species, images, thr=0.7):
        url = os.environ.get("INFINIFLOW_SERVER")
        if not url:raise EnvironmentError("Please set environment variable: 'INFINIFLOW_SERVER'")
        token = os.environ.get("INFINIFLOW_TOKEN")
        if not token:raise EnvironmentError("Please set environment variable: 'INFINIFLOW_TOKEN'")

        def convert_image_to_bytes(PILimage):
            image = BytesIO()
            PILimage.save(image, format='png')
            image.seek(0)
            return image.getvalue()

        images = [convert_image_to_bytes(img) for img in images]

        def remote_call():
            nonlocal images, thr
            res = requests.post(url+"/v1/layout/detect/"+species, files=[("image", img) for img in images], data={"threashold": thr},
                                headers={"Authorization": token}, timeout=len(images) * 10)
            res = res.json()
            if res["retcode"] != 0: raise RuntimeError(res["retmsg"])
            return res["data"]

        for _ in range(3):
            try:
                return remote_call()
            except RuntimeError as e:
                raise e
            except Exception as e:
                logging.error("layout_predict:"+str(e))
        return remote_call()


    def __char_width(self, c):
        return (c["x1"] - c["x0"]) // len(c["text"])
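For context on the hunk above: functools.partial pre-binds the species argument of __remote_call, so the existing call sites (self.layouter(...), self.tbl_det(...)) keep their old one-argument shape while now going over the network. A minimal sketch of the same pattern, with hypothetical names standing in for the parser code:

# Sketch of the partial-binding pattern used for self.layouter / self.tbl_det.
# detect(), "general" and "table_component" mirror the diff; the body is illustrative.
from functools import partial

def detect(species, images, thr=0.7):
    # stand-in for HuParser.__remote_call: species is fixed per detector,
    # images and thr come from the call site
    return [f"{species} boxes for image {i} (thr={thr})" for i, _ in enumerate(images)]

layouter = partial(detect, "general")          # layout detection for the current parser type
tbl_det = partial(detect, "table_component")   # table structure detection

print(layouter(["page-1", "page-2"]))          # species already bound; same call shape as before
print(tbl_det(["page-1"], thr=0.5))            # keyword arguments still pass through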
@@ -344,7 +382,7 @@ class HuParser:
        return layouts

    def __table_paddle(self, images):
        tbls = self.tbl_det([np.array(img) for img in images], thr=0.5)
        tbls = self.tbl_det(images, thr=0.5)
        res = []
        # align left&right for rows, align top&bottom for columns
        for tbl in tbls:
@@ -522,7 +560,7 @@ class HuParser:
        assert len(self.page_images) == len(self.boxes)
        # Tag layout type
        boxes = []
        layouts = self.layouter([np.array(img) for img in self.page_images])
        layouts = self.layouter(self.page_images)
        assert len(self.page_images) == len(layouts)
        for pn, lts in enumerate(layouts):
            bxs = self.boxes[pn]
@@ -1705,7 +1743,8 @@ class HuParser:
            self.__ocr_paddle(i + 1, img, chars, zoomin)

        if not self.is_english and not any([c for c in self.page_chars]) and self.boxes:
            self.is_english = re.search(r"[\na-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join([b["text"] for b in random.choices([b for bxs in self.boxes for b in bxs], k=30)]))
            bxes = [b for bxs in self.boxes for b in bxs]
            self.is_english = re.search(r"[\na-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join([b["text"] for b in random.choices(bxes, k=min(30, len(bxes)))]))

        logging.info("Is it English:", self.is_english)