mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 20:42:30 +08:00)

Format file format from Windows/DOS to Unix (#1949)

### What problem does this PR solve?

The related source files are in Windows/DOS format; they are converted to Unix format.

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>

@@ -1,122 +1,122 @@
English | [简体中文](./README_zh.md)

# *Deep*Doc

- [1. Introduction](#1)
- [2. Vision](#2)
- [3. Parser](#3)

<a name="1"></a>
## 1. Introduction

With a bunch of documents from various domains, in various formats, and with diverse retrieval requirements,
accurate analysis becomes a very challenging task. *Deep*Doc was born for that purpose.
There are two parts in *Deep*Doc so far: vision and parser.
You can run the following test programs if you are interested in our results of OCR, layout recognition and TSR.

```bash
python deepdoc/vision/t_ocr.py -h
usage: t_ocr.py [-h] --inputs INPUTS [--output_dir OUTPUT_DIR]

options:
  -h, --help            show this help message and exit
  --inputs INPUTS       Directory where to store images or PDFs, or a file path to a single image or PDF
  --output_dir OUTPUT_DIR
                        Directory where to store the output images. Default: './ocr_outputs'
```

```bash
python deepdoc/vision/t_recognizer.py -h
usage: t_recognizer.py [-h] --inputs INPUTS [--output_dir OUTPUT_DIR] [--threshold THRESHOLD] [--mode {layout,tsr}]

options:
  -h, --help            show this help message and exit
  --inputs INPUTS       Directory where to store images or PDFs, or a file path to a single image or PDF
  --output_dir OUTPUT_DIR
                        Directory where to store the output images. Default: './layouts_outputs'
  --threshold THRESHOLD
                        A threshold to filter out detections. Default: 0.5
  --mode {layout,tsr}   Task mode: layout recognition or table structure recognition
```

Our models are served on Hugging Face. If you have trouble downloading Hugging Face models, this might help:

```bash
export HF_ENDPOINT=https://hf-mirror.com
```
<a name="2"></a>
## 2. Vision

We use visual information to resolve problems as a human would.
- OCR. Since many documents are presented as images, or can at least be converted to images,
OCR is an essential, fundamental and even universal solution for text extraction.

```bash
python deepdoc/vision/t_ocr.py --inputs=path_to_images_or_pdfs --output_dir=path_to_store_result
```

The inputs can be a directory of images or PDFs, or a single image or PDF.
You can look into the folder 'path_to_store_result', which contains images demonstrating the positions of the results
as well as txt files containing the OCR text.
<div align="center" style="margin-top:20px;margin-bottom:20px;">
<img src="https://github.com/infiniflow/ragflow/assets/12318111/f25bee3d-aaf7-4102-baf5-d5208361d110" width="900"/>
</div>
- Layout recognition. Documents from different domains may have various layouts,
such as newspapers, magazines, books and résumés, which are all distinct in terms of layout.
Only when a machine has an accurate layout analysis can it decide whether text parts are successive,
whether a part needs Table Structure Recognition (TSR), or whether a part is a figure described by a nearby caption.
We have 10 basic layout components, which cover most cases:
  - Text
  - Title
  - Figure
  - Figure caption
  - Table
  - Table caption
  - Header
  - Footer
  - Reference
  - Equation

Try the following command to see the layout detection results.
```bash
python deepdoc/vision/t_recognizer.py --inputs=path_to_images_or_pdfs --threshold=0.2 --mode=layout --output_dir=path_to_store_result
```
The inputs can be a directory of images or PDFs, or a single image or PDF.
You can look into the folder 'path_to_store_result', which contains images demonstrating the detection results, as below:
<div align="center" style="margin-top:20px;margin-bottom:20px;">
<img src="https://github.com/infiniflow/ragflow/assets/12318111/07e0f625-9b28-43d0-9fbb-5bf586cd286f" width="1000"/>
</div>
- Table Structure Recognition (TSR). Data tables are a frequently used structure for presenting data, including numbers and text.
The structure of a table might be very complex, with hierarchical headers, spanning cells and projected row headers.
Along with TSR, we also reassemble the content into sentences which can be well comprehended by an LLM.
We have five labels for the TSR task:
  - Column
  - Row
  - Column header
  - Projected row header
  - Spanning cell

Try the following command to see the table structure recognition results.
```bash
python deepdoc/vision/t_recognizer.py --inputs=path_to_images_or_pdfs --threshold=0.2 --mode=tsr --output_dir=path_to_store_result
```
The inputs can be a directory of images or PDFs, or a single image or PDF.
You can look into the folder 'path_to_store_result', which contains both images and HTML pages demonstrating the detection results, as below:
<div align="center" style="margin-top:20px;margin-bottom:20px;">
<img src="https://github.com/infiniflow/ragflow/assets/12318111/cb24e81b-f2ba-49f3-ac09-883d75606f4c" width="1000"/>
</div>
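
A toy illustration of the reassembly idea, with made-up data: pairing every cell with its column header turns a table row into a plain sentence. The real pipeline builds these sentences from the recognized rows, columns and spanning cells rather than from plain lists.

```python
# Toy sketch of table-to-sentence reassembly; the data here is made up.
headers = ["Name", "Age", "City"]
row = ["Alice", "30", "Berlin"]

sentence = "; ".join(f"{h}: {c}" for h, c in zip(headers, row))
print(sentence)  # Name: Alice; Age: 30; City: Berlin
```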

<a name="3"></a>
## 3. Parser

Four kinds of document formats, PDF, DOCX, EXCEL and PPT, have their corresponding parsers.
The most complex one is the PDF parser, owing to PDF's flexibility. The output of the PDF parser includes:
- Text chunks with their own positions in the PDF (page number and rectangular positions).
- Tables with their cropped images from the PDF, and contents which have already been translated into natural language sentences.
- Figures with their captions and the text inside the figures.
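
As a hypothetical illustration of the first item, a positioned text chunk might be represented like this; the field names are ours for illustration, not necessarily the parser's exact schema:

```python
# Hypothetical shape of one positioned text chunk.
chunk = {
    "text": "Revenue grew 12% year over year.",
    "page_number": 3,
    "position": {"x0": 72.0, "x1": 523.5, "top": 140.2, "bottom": 168.9},
}
```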

### Résumé

The résumé is a very complicated kind of document. A résumé, which is composed of unstructured text
with various layouts, can be resolved into structured data composed of nearly a hundred fields.
We haven't opened up the parser yet, as we only open the processing method applied after the parsing procedure.
@@ -1,61 +1,61 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from io import BytesIO
from pptx import Presentation


class RAGFlowPptParser(object):
    def __init__(self):
        super().__init__()

    def __extract(self, shape):
        # Shape type 19 is a table: render each data row as
        # "header: cell; header: cell; ..." so the content reads as sentences.
        if shape.shape_type == 19:
            tb = shape.table
            rows = []
            for i in range(1, len(tb.rows)):
                rows.append("; ".join([tb.cell(
                    0, j).text + ": " + tb.cell(i, j).text for j in range(len(tb.columns)) if tb.cell(i, j)]))
            return "\n".join(rows)

        if shape.has_text_frame:
            return shape.text_frame.text

        # Shape type 6 is a group: recurse into its members in reading order
        # (top-to-bottom in 10-unit bands, then left-to-right). Guard against
        # shapes with no position, as done for slide-level shapes below.
        if shape.shape_type == 6:
            texts = []
            for p in sorted(shape.shapes, key=lambda x: ((x.top if x.top is not None else 0) // 10, x.left)):
                t = self.__extract(p)
                if t:
                    texts.append(t)
            return "\n".join(texts)

    def __call__(self, fnm, from_page, to_page, callback=None):
        ppt = Presentation(fnm) if isinstance(
            fnm, str) else Presentation(BytesIO(fnm))
        txts = []
        self.total_page = len(ppt.slides)
        for i, slide in enumerate(ppt.slides):
            if i < from_page:
                continue
            if i >= to_page:
                break
            texts = []
            for shape in sorted(
                    slide.shapes, key=lambda x: ((x.top if x.top is not None else 0) // 10, x.left)):
                txt = self.__extract(shape)
                if txt:
                    texts.append(txt)
            txts.append("\n".join(texts))

        return txts
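
A minimal usage sketch, assuming a local `slides.pptx` exists; `from_page`/`to_page` bound the slide range, and the parser returns one text blob per processed slide:

```python
# Usage sketch; assumes a local file "slides.pptx" exists.
parser = RAGFlowPptParser()
per_slide_text = parser("slides.pptx", from_page=0, to_page=2)  # slides 0 and 1
for i, text in enumerate(per_slide_text):
    print(f"--- slide {i} ---\n{text}")
```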
@@ -1,65 +1,65 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import datetime


def refactor(cv):
    # Drop bulky intermediate fields that are no longer needed downstream.
    for n in ["raw_txt", "parser_name", "inference", "ori_text", "use_time", "time_stat"]:
        if n in cv and cv[n] is not None:
            del cv[n]
    cv["is_deleted"] = 0
    if "basic" not in cv:
        cv["basic"] = {}
    if cv["basic"].get("photo2"):
        del cv["basic"]["photo2"]

    # Normalize the list-like sections into {"0": item0, "1": item1, ...}
    # maps, stripping per-item "external" payloads along the way.
    for n in ["education", "work", "certificate", "project", "language", "skill", "training"]:
        if n not in cv or cv[n] is None:
            continue
        if isinstance(cv[n], dict):
            cv[n] = [v for _, v in cv[n].items()]
        if not isinstance(cv[n], list):
            del cv[n]
            continue
        vv = []
        for v in cv[n]:
            if "external" in v and v["external"] is not None:
                del v["external"]
            vv.append(v)
        cv[n] = {str(i): vv[i] for i in range(len(vv))}

    # Rename a couple of basic fields.
    basics = [
        ("basic_salary_month", "salary_month"),
        ("expect_annual_salary_from", "expect_annual_salary"),
    ]
    for n, t in basics:
        if cv["basic"].get(n):
            cv["basic"][t] = cv["basic"][n]
            del cv["basic"][n]

    # Derive summary fields from the earliest and latest work/education entries.
    work = sorted([v for _, v in cv.get("work", {}).items()], key=lambda x: x.get("start_time", ""))
    edu = sorted([v for _, v in cv.get("education", {}).items()], key=lambda x: x.get("start_time", ""))

    if work:
        cv["basic"]["work_start_time"] = work[0].get("start_time", "")
        cv["basic"]["management_experience"] = 'Y' if any(
            [w.get("management_experience", '') == 'Y' for w in work]) else 'N'
        cv["basic"]["annual_salary"] = work[-1].get("annual_salary_from", "0")

        for n in ["annual_salary_from", "annual_salary_to", "industry_name", "position_name", "responsibilities",
                  "corporation_type", "scale", "corporation_name"]:
            cv["basic"][n] = work[-1].get(n, "")

    if edu:
        for n in ["school_name", "discipline_name"]:
            if n in edu[-1]:
                cv["basic"][n] = edu[-1][n]

    cv["basic"]["updated_at"] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    if "contact" not in cv:
        cv["contact"] = {}
    if not cv["contact"].get("name"):
        cv["contact"]["name"] = cv["basic"].get("name", "")
    return cv
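
A usage sketch with a hypothetical parsed-résumé dict, showing the main normalizations this function performs:

```python
# Usage sketch; the input dict below is hypothetical.
cv = {
    "raw_txt": "...",  # dropped by refactor()
    "basic": {"name": "Jane Doe", "basic_salary_month": "12"},
    "work": [{"start_time": "2019-01", "corporation_name": "Acme",
              "annual_salary_from": "300000"}],
}
cv = refactor(cv)
print(cv["basic"]["salary_month"])  # "12", renamed from basic_salary_month
print(cv["work"])                   # {"0": {...}}, list normalized to an indexed map
print(cv["contact"]["name"])        # "Jane Doe", copied from basic
```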
@@ -1,4 +1,4 @@
清华大学,2,985,清华
清华大学,2,985,Tsinghua University
清华大学,2,985,THU
北京大学,1,985,北大
@@ -1,186 +1,186 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import json
from deepdoc.parser.resume.entities import degrees, regions, industries

# Output schema: "field_name TYPE" pairs; the refactored record uses the
# field names as its keys.
FIELDS = [
    "address STRING",
    "annual_salary int",
    "annual_salary_from int",
    "annual_salary_to int",
    "birth STRING",
    "card STRING",
    "certificate_obj string",
    "city STRING",
    "corporation_id int",
    "corporation_name STRING",
    "corporation_type STRING",
    "degree STRING",
    "discipline_name STRING",
    "education_obj string",
    "email STRING",
    "expect_annual_salary int",
    "expect_city_names string",
    "expect_industry_name STRING",
    "expect_position_name STRING",
    "expect_salary_from int",
    "expect_salary_to int",
    "expect_type STRING",
    "gender STRING",
    "industry_name STRING",
    "industry_names STRING",
    "is_deleted STRING",
    "is_fertility STRING",
    "is_house STRING",
    "is_management_experience STRING",
    "is_marital STRING",
    "is_oversea STRING",
    "language_obj string",
    "name STRING",
    "nation STRING",
    "phone STRING",
    "political_status STRING",
    "position_name STRING",
    "project_obj string",
    "responsibilities string",
    "salary_month int",
    "scale STRING",
    "school_name STRING",
    "self_remark string",
    "skill_obj string",
    "title_name STRING",
    "tob_resume_id STRING",
    "updated_at Timestamp",
    "wechat STRING",
    "work_obj string",
    "work_experience int",
    "work_start_time BIGINT"
]


def refactor(df):
    def deal_obj(obj, k, kk):
        # Safely fetch obj[k][kk], returning "" when the path is missing
        # or not a dict.
        if not isinstance(obj, dict):
            return ""
        obj = obj.get(k, {})
        if not isinstance(obj, dict):
            return ""
        return obj.get(kk, "")

    def loadjson(line):
        try:
            return json.loads(line)
        except Exception:
            pass
        return {}

    df["obj"] = df["resume_content"].map(loadjson)
    df.fillna("", inplace=True)

    clms = ["tob_resume_id", "updated_at"]

    def extract(nms, cc=None):
        nonlocal clms
        clms.extend(nms)
        for c in nms:
            if cc:
                # Nested field: pull obj[cc][c].
                df[c] = df["obj"].map(lambda x: deal_obj(x, cc, c))
            else:
                # Top-level field: serialize dicts (or empty values) as JSON,
                # otherwise keep the string form with "None" stripped.
                df[c] = df["obj"].map(
                    lambda x: json.dumps(x.get(c, {}), ensure_ascii=False)
                    if isinstance(x, dict) and (isinstance(x.get(c), dict) or not x.get(c))
                    else str(x).replace("None", ""))

    extract(["education", "work", "certificate", "project", "language",
             "skill"])
    extract(["wechat", "phone", "is_deleted",
             "name", "tel", "email"], "contact")
    extract(["nation", "expect_industry_name", "salary_month",
             "industry_ids", "is_house", "birth", "annual_salary_from",
             "annual_salary_to", "card",
             "expect_salary_to", "expect_salary_from",
             "expect_position_name", "gender", "city",
             "is_fertility", "expect_city_names",
             "political_status", "title_name", "expect_annual_salary",
             "industry_name", "address", "position_name", "school_name",
             "corporation_id",
             "is_oversea", "responsibilities",
             "work_start_time", "degree", "management_experience",
             "expect_type", "corporation_type", "scale", "corporation_name",
             "self_remark", "annual_salary", "work_experience",
             "discipline_name", "marital", "updated_at"], "basic")

    # Map coded values to human-readable names.
    df["degree"] = df["degree"].map(lambda x: degrees.get_name(x))
    df["address"] = df["address"].map(lambda x: " ".join(regions.get_names(x)))
    df["industry_names"] = df["industry_ids"].map(
        lambda x: " ".join([" ".join(industries.get_names(i)) for i in str(x).split(",")]))
    clms.append("industry_names")

    def arr2str(a):
        if not a:
            return ""
        if isinstance(a, list):
            a = " ".join([str(i) for i in a])
        return str(a).replace(",", " ")

    df["expect_industry_name"] = df["expect_industry_name"].map(arr2str)
    df["gender"] = df["gender"].map(
        lambda x: "男" if x == 'M' else (
            "女" if x == 'F' else ""))
    for c in ["is_fertility", "is_oversea", "is_house",
              "management_experience", "marital"]:
        df[c] = df[c].map(
            lambda x: '是' if x == 'Y' else (
                '否' if x == 'N' else ""))
    df["is_management_experience"] = df["management_experience"]
    df["is_marital"] = df["marital"]
    clms.extend(["is_management_experience", "is_marital"])

    df.fillna("", inplace=True)
    # Fall back to the landline ("tel") when no mobile phone is present.
    for i in range(len(df)):
        if not df.loc[i, "phone"].strip() and df.loc[i, "tel"].strip():
            df.loc[i, "phone"] = df.loc[i, "tel"].strip()

    # Drop intermediate columns that were only needed for derivation.
    for n in ["industry_ids", "management_experience", "marital", "tel"]:
        for i in range(len(clms)):
            if clms[i] == n:
                del clms[i]
                break

    clms = list(set(clms))

    df = df.reindex(sorted(clms), axis=1)
    # Normalize whitespace so the record is safe for tab-separated storage.
    for c in clms:
        df[c] = df[c].map(
            lambda s: str(s).replace("\t", " ").replace("\n", "\\n").replace("\r", "\\n"))
    return dict(zip([n.split(" ")[0] for n in FIELDS], df.values.tolist()[0]))
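
Note that the final line relies on `sorted(clms)` lining up positionally with the order of `FIELDS` (for example, the `education` column feeds the `education_obj` field). A quick way to inspect the output keys:

```python
# The returned dict's keys are the field names from FIELDS.
field_names = [n.split(" ")[0] for n in FIELDS]
print(field_names[:5])  # ['address', 'annual_salary', 'annual_salary_from', 'annual_salary_to', 'birth']
```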
File diff suppressed because it is too large
@@ -1,61 +1,61 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import pdfplumber

from .ocr import OCR
from .recognizer import Recognizer
from .layout_recognizer import LayoutRecognizer
from .table_structure_recognizer import TableStructureRecognizer


def init_in_out(args):
    from PIL import Image
    import os
    import traceback
    from api.utils.file_utils import traversal_files
    images = []
    outputs = []

    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)

    def pdf_pages(fnm, zoomin=3):
        nonlocal outputs, images
        pdf = pdfplumber.open(fnm)
        # Render every page at 3x zoom (72 dpi * 3 = 216 dpi).
        images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
                  enumerate(pdf.pages)]

        for i, page in enumerate(images):
            outputs.append(os.path.split(fnm)[-1] + f"_{i}.jpg")

    def images_and_outputs(fnm):
        nonlocal outputs, images
        if fnm.split(".")[-1].lower() == "pdf":
            pdf_pages(fnm)
            return
        try:
            images.append(Image.open(fnm))
            outputs.append(os.path.split(fnm)[-1])
        except Exception:
            traceback.print_exc()

    if os.path.isdir(args.inputs):
        for fnm in traversal_files(args.inputs):
            images_and_outputs(fnm)
    else:
        images_and_outputs(args.inputs)

    # Prefix every output file name with the output directory.
    for i in range(len(outputs)):
        outputs[i] = os.path.join(args.output_dir, outputs[i])

    return images, outputs
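
A usage sketch: `init_in_out` only needs an object with `inputs` and `output_dir` attributes, so an `argparse.Namespace` is enough. It assumes a local `sample.pdf` and that the repo's `api` package is importable (run from the ragflow repo root):

```python
# Usage sketch; assumes "sample.pdf" exists and ragflow is on PYTHONPATH.
from argparse import Namespace

args = Namespace(inputs="sample.pdf", output_dir="./ocr_outputs")
images, outputs = init_in_out(args)  # one PIL image and one output path per page
print(len(images), outputs[:2])
```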
@@ -1,151 +1,151 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import re
from collections import Counter
from copy import deepcopy
import numpy as np
from huggingface_hub import snapshot_download

from api.utils.file_utils import get_project_base_directory
from deepdoc.vision import Recognizer


class LayoutRecognizer(Recognizer):
    labels = [
        "_background_",
        "Text",
        "Title",
        "Figure",
        "Figure caption",
        "Table",
        "Table caption",
        "Header",
        "Footer",
        "Reference",
        "Equation",
    ]

    def __init__(self, domain):
        try:
            model_dir = os.path.join(
                get_project_base_directory(),
                "rag/res/deepdoc")
            super().__init__(self.labels, domain, model_dir)
        except Exception:
            # Fall back to downloading the model weights from Hugging Face.
            model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
                                          local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
                                          local_dir_use_symlinks=False)
            super().__init__(self.labels, domain, model_dir)

        self.garbage_layouts = ["footer", "header", "reference"]

    def __call__(self, image_list, ocr_res, scale_factor=3,
                 thr=0.2, batch_size=16, drop=True):
        # Heuristics for page furniture: bullet-only lines, copyright and
        # disclaimer boilerplate, dot leaders, page numbers, long URLs,
        # data-source notes, e-mail addresses and PDF cid artifacts.
        def __is_garbage(b):
            patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$",
                    r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}",
                    "(资料|数据)来源[::]", "[0-9a-z._-]+@[a-z0-9-]+\\.[a-z]{2,3}",
                    "\\(cid *: *[0-9]+ *\\)"
                    ]
            return any([re.search(p, b["text"]) for p in patt])

        layouts = super().__call__(image_list, thr, batch_size)
        assert len(image_list) == len(ocr_res)
        # Tag layout type
        boxes = []
        assert len(image_list) == len(layouts)
        garbages = {}
        page_layout = []
        for pn, lts in enumerate(layouts):
            bxs = ocr_res[pn]
            # Rescale detections back to OCR coordinates; keep low-score
            # detections only for non-garbage layout types.
            lts = [{"type": b["type"],
                    "score": float(b["score"]),
                    "x0": b["bbox"][0] / scale_factor, "x1": b["bbox"][2] / scale_factor,
                    "top": b["bbox"][1] / scale_factor, "bottom": b["bbox"][-1] / scale_factor,
                    "page_number": pn,
                    } for b in lts if float(b["score"]) >= 0.8 or b["type"] not in self.garbage_layouts]
            lts = self.sort_Y_firstly(lts, np.mean(
                [l["bottom"] - l["top"] for l in lts]) / 2)
            lts = self.layouts_cleanup(bxs, lts)
            page_layout.append(lts)

            # Tag layout type; layouts are ready.
            def findLayout(ty):
                nonlocal bxs, lts, self
                lts_ = [lt for lt in lts if lt["type"] == ty]
                i = 0
                while i < len(bxs):
                    if bxs[i].get("layout_type"):
                        i += 1
                        continue
                    if __is_garbage(bxs[i]):
                        bxs.pop(i)
                        continue

                    ii = self.find_overlapped_with_threashold(bxs[i], lts_,
                                                              thr=0.4)
                    if ii is None:  # belongs to nothing
                        bxs[i]["layout_type"] = ""
                        i += 1
                        continue
                    lts_[ii]["visited"] = True
                    # Keep footers/headers that actually sit inside the page body.
                    keep_feats = [
                        lts_[ii]["type"] == "footer" and bxs[i]["bottom"] < image_list[pn].size[1] * 0.9 / scale_factor,
                        lts_[ii]["type"] == "header" and bxs[i]["top"] > image_list[pn].size[1] * 0.1 / scale_factor,
                    ]
                    if drop and lts_[ii]["type"] in self.garbage_layouts and not any(keep_feats):
                        if lts_[ii]["type"] not in garbages:
                            garbages[lts_[ii]["type"]] = []
                        garbages[lts_[ii]["type"]].append(bxs[i]["text"])
                        bxs.pop(i)
                        continue

                    bxs[i]["layoutno"] = f"{ty}-{ii}"
                    bxs[i]["layout_type"] = lts_[ii]["type"] if lts_[
                        ii]["type"] != "equation" else "figure"
                    i += 1

            for lt in ["footer", "header", "reference", "figure caption",
                       "table caption", "title", "table", "text", "figure", "equation"]:
                findLayout(lt)

            # Add a box for figure layouts that contain no text box.
            for i, lt in enumerate(
                    [lt for lt in lts if lt["type"] in ["figure", "equation"]]):
                if lt.get("visited"):
                    continue
                lt = deepcopy(lt)
                del lt["type"]
                lt["text"] = ""
                lt["layout_type"] = "figure"
                lt["layoutno"] = f"figure-{i}"
                bxs.append(lt)

            boxes.extend(bxs)

        ocr_res = boxes

        # Texts dropped as garbage more than once are treated as recurring
        # page furniture and removed everywhere.
        garbag_set = set()
        for k in garbages.keys():
            garbages[k] = Counter(garbages[k])
            for g, c in garbages[k].items():
                if c > 1:
                    garbag_set.add(g)

        ocr_res = [b for b in ocr_res if b["text"].strip() not in garbag_set]
        return ocr_res, page_layout
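
An outline of how the recognizer is driven (see `deepdoc/vision/t_recognizer.py` for the full pipeline). This is not self-contained: it assumes the layout model weights are available and that OCR boxes were produced by deepdoc's OCR stage, with each box carrying at least `text`, `top`, `bottom`, `x0` and `x1`:

```python
# Outline only; `pages` are per-page PIL images and `ocr_boxes` the matching
# per-page OCR results from deepdoc's OCR stage.
recognizer = LayoutRecognizer("layout")
boxes, page_layouts = recognizer(pages, ocr_boxes, scale_factor=3, thr=0.2)
for lt in page_layouts[0]:
    print(lt["type"], lt["score"], lt["top"], lt["bottom"])
```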
deepdoc/vision/ocr.res: 13244 lines (file diff suppressed because it is too large)
File diff suppressed because it is too large
@@ -1,366 +1,366 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import copy
import re
import numpy as np
import cv2
from shapely.geometry import Polygon
import pyclipper


def build_post_process(config, global_config=None):
    support_dict = ['DBPostProcess', 'CTCLabelDecode']

    config = copy.deepcopy(config)
    module_name = config.pop('name')
    if module_name == "None":
        return
    if global_config is not None:
        config.update(global_config)
    assert module_name in support_dict, Exception(
        'post process only support {}'.format(support_dict))
    # Instantiate the named post-process class with the remaining config
    # entries as keyword arguments.
    module_class = eval(module_name)(**config)
    return module_class
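
A minimal sketch of how a config dict selects and parameterizes a post-processor: the 'name' key picks the class and every other key becomes a constructor keyword argument (the values below are just illustrative):

```python
post = build_post_process({
    'name': 'DBPostProcess',
    'thresh': 0.3,        # binarization threshold on the probability map
    'box_thresh': 0.7,    # minimum mean score for a kept box
    'unclip_ratio': 2.0,  # how far detected regions are expanded
})
```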

class DBPostProcess(object):
    """
    The post process for Differentiable Binarization (DB).
    """

    def __init__(self,
                 thresh=0.3,
                 box_thresh=0.7,
                 max_candidates=1000,
                 unclip_ratio=2.0,
                 use_dilation=False,
                 score_mode="fast",
                 box_type='quad',
                 **kwargs):
        self.thresh = thresh
        self.box_thresh = box_thresh
        self.max_candidates = max_candidates
        self.unclip_ratio = unclip_ratio
        self.min_size = 3
        self.score_mode = score_mode
        self.box_type = box_type
        assert score_mode in [
            "slow", "fast"
        ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)

        self.dilation_kernel = None if not use_dilation else np.array(
            [[1, 1], [1, 1]])

    def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
        '''
        _bitmap: single map with shape (1, H, W),
            whose values are binarized as {0, 1}
        '''

        bitmap = _bitmap
        height, width = bitmap.shape

        boxes = []
        scores = []

        contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
                                       cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)

        for contour in contours[:self.max_candidates]:
            # Approximate the contour with a polygon before scoring.
            epsilon = 0.002 * cv2.arcLength(contour, True)
            approx = cv2.approxPolyDP(contour, epsilon, True)
            points = approx.reshape((-1, 2))
            if points.shape[0] < 4:
                continue

            score = self.box_score_fast(pred, points.reshape(-1, 2))
            if self.box_thresh > score:
                continue

            if points.shape[0] > 2:
                box = self.unclip(points, self.unclip_ratio)
                if len(box) > 1:
                    continue
            else:
                continue
            box = box.reshape(-1, 2)

            _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
            if sside < self.min_size + 2:
                continue

            # Rescale from bitmap coordinates to the destination image size.
            box = np.array(box)
            box[:, 0] = np.clip(
                np.round(box[:, 0] / width * dest_width), 0, dest_width)
            box[:, 1] = np.clip(
                np.round(box[:, 1] / height * dest_height), 0, dest_height)
            boxes.append(box.tolist())
            scores.append(score)
        return boxes, scores

    def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
        '''
        _bitmap: single map with shape (1, H, W),
            whose values are binarized as {0, 1}
        '''

        bitmap = _bitmap
        height, width = bitmap.shape

        outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
                                cv2.CHAIN_APPROX_SIMPLE)
        # OpenCV 3 returns (img, contours, hierarchy); OpenCV 4 returns
        # (contours, hierarchy).
        if len(outs) == 3:
            img, contours, _ = outs[0], outs[1], outs[2]
        elif len(outs) == 2:
            contours, _ = outs[0], outs[1]

        num_contours = min(len(contours), self.max_candidates)

        boxes = []
        scores = []
        for index in range(num_contours):
            contour = contours[index]
            points, sside = self.get_mini_boxes(contour)
            if sside < self.min_size:
                continue
            points = np.array(points)
            if self.score_mode == "fast":
                score = self.box_score_fast(pred, points.reshape(-1, 2))
            else:
                score = self.box_score_slow(pred, contour)
            if self.box_thresh > score:
                continue

            box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)
            box, sside = self.get_mini_boxes(box)
            if sside < self.min_size + 2:
                continue
            box = np.array(box)

            box[:, 0] = np.clip(
                np.round(box[:, 0] / width * dest_width), 0, dest_width)
            box[:, 1] = np.clip(
                np.round(box[:, 1] / height * dest_height), 0, dest_height)
            boxes.append(box.astype("int32"))
            scores.append(score)
        return np.array(boxes, dtype="int32"), scores

    def unclip(self, box, unclip_ratio):
        # Expand the polygon outward by distance = area * ratio / perimeter,
        # compensating for the shrunken text regions DB is trained to predict.
        poly = Polygon(box)
        distance = poly.area * unclip_ratio / poly.length
        offset = pyclipper.PyclipperOffset()
        offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        expanded = np.array(offset.Execute(distance))
        return expanded

    def get_mini_boxes(self, contour):
        # Return the minimum-area rectangle's corners ordered top-left,
        # top-right, bottom-right, bottom-left, plus its shorter side length.
        bounding_box = cv2.minAreaRect(contour)
        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])

        index_1, index_2, index_3, index_4 = 0, 1, 2, 3
        if points[1][1] > points[0][1]:
            index_1 = 0
            index_4 = 1
        else:
            index_1 = 1
            index_4 = 0
        if points[3][1] > points[2][1]:
            index_2 = 2
            index_3 = 3
        else:
            index_2 = 3
            index_3 = 2

        box = [
            points[index_1], points[index_2], points[index_3], points[index_4]
        ]
        return box, min(bounding_box[1])

    def box_score_fast(self, bitmap, _box):
        '''
        box_score_fast: use the bbox mean score as the mean score
        '''
        h, w = bitmap.shape[:2]
        box = _box.copy()
        xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1)
        xmax = np.clip(np.ceil(box[:, 0].max()).astype("int32"), 0, w - 1)
        ymin = np.clip(np.floor(box[:, 1].min()).astype("int32"), 0, h - 1)
        ymax = np.clip(np.ceil(box[:, 1].max()).astype("int32"), 0, h - 1)

        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
        box[:, 0] = box[:, 0] - xmin
        box[:, 1] = box[:, 1] - ymin
        cv2.fillPoly(mask, box.reshape(1, -1, 2).astype("int32"), 1)
        return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]

    def box_score_slow(self, bitmap, contour):
        '''
        box_score_slow: use the polygon mean score as the mean score
        '''
        h, w = bitmap.shape[:2]
        contour = contour.copy()
        contour = np.reshape(contour, (-1, 2))

        xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
        xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
        ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
        ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)

        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)

        contour[:, 0] = contour[:, 0] - xmin
        contour[:, 1] = contour[:, 1] - ymin

        cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype("int32"), 1)
        return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]

    def __call__(self, outs_dict, shape_list):
        pred = outs_dict['maps']
        if not isinstance(pred, np.ndarray):
            pred = pred.numpy()
        pred = pred[:, 0, :, :]
        segmentation = pred > self.thresh

        boxes_batch = []
        for batch_index in range(pred.shape[0]):
            src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
            if self.dilation_kernel is not None:
                mask = cv2.dilate(
                    np.array(segmentation[batch_index]).astype(np.uint8),
                    self.dilation_kernel)
            else:
                mask = segmentation[batch_index]
            if self.box_type == 'poly':
                boxes, scores = self.polygons_from_bitmap(pred[batch_index],
                                                          mask, src_w, src_h)
            elif self.box_type == 'quad':
                boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
                                                       src_w, src_h)
            else:
                raise ValueError(
                    "box_type can only be one of ['quad', 'poly']")

            boxes_batch.append({'points': boxes})
        return boxes_batch
|
||||
|
||||
|
||||
class BaseRecLabelDecode(object):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self, character_dict_path=None, use_space_char=False):
|
||||
self.beg_str = "sos"
|
||||
self.end_str = "eos"
|
||||
self.reverse = False
|
||||
self.character_str = []
|
||||
|
||||
if character_dict_path is None:
|
||||
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||
dict_character = list(self.character_str)
|
||||
else:
|
||||
with open(character_dict_path, "rb") as fin:
|
||||
lines = fin.readlines()
|
||||
for line in lines:
|
||||
line = line.decode('utf-8').strip("\n").strip("\r\n")
|
||||
self.character_str.append(line)
|
||||
if use_space_char:
|
||||
self.character_str.append(" ")
|
||||
dict_character = list(self.character_str)
|
||||
if 'arabic' in character_dict_path:
|
||||
self.reverse = True
|
||||
|
||||
dict_character = self.add_special_char(dict_character)
|
||||
self.dict = {}
|
||||
for i, char in enumerate(dict_character):
|
||||
self.dict[char] = i
|
||||
self.character = dict_character
|
||||
|
||||
def pred_reverse(self, pred):
|
||||
pred_re = []
|
||||
c_current = ''
|
||||
for c in pred:
|
||||
if not bool(re.search('[a-zA-Z0-9 :*./%+-]', c)):
|
||||
if c_current != '':
|
||||
pred_re.append(c_current)
|
||||
pred_re.append(c)
|
||||
c_current = ''
|
||||
else:
|
||||
c_current += c
|
||||
if c_current != '':
|
||||
pred_re.append(c_current)
|
||||
|
||||
return ''.join(pred_re[::-1])
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
return dict_character
|
||||
|
||||
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
||||
""" convert text-index into text-label. """
|
||||
result_list = []
|
||||
ignored_tokens = self.get_ignored_tokens()
|
||||
batch_size = len(text_index)
|
||||
for batch_idx in range(batch_size):
|
||||
selection = np.ones(len(text_index[batch_idx]), dtype=bool)
|
||||
if is_remove_duplicate:
|
||||
selection[1:] = text_index[batch_idx][1:] != text_index[
|
||||
batch_idx][:-1]
|
||||
for ignored_token in ignored_tokens:
|
||||
selection &= text_index[batch_idx] != ignored_token
|
||||
|
||||
char_list = [
|
||||
self.character[text_id]
|
||||
for text_id in text_index[batch_idx][selection]
|
||||
]
|
||||
if text_prob is not None:
|
||||
conf_list = text_prob[batch_idx][selection]
|
||||
else:
|
||||
conf_list = [1] * len(selection)
|
||||
if len(conf_list) == 0:
|
||||
conf_list = [0]
|
||||
|
||||
text = ''.join(char_list)
|
||||
|
||||
if self.reverse: # for arabic rec
|
||||
text = self.pred_reverse(text)
|
||||
|
||||
result_list.append((text, np.mean(conf_list).tolist()))
|
||||
return result_list
|
||||
|
||||
def get_ignored_tokens(self):
|
||||
return [0] # for ctc blank
|
||||
|
||||
|
||||
class CTCLabelDecode(BaseRecLabelDecode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self, character_dict_path=None, use_space_char=False,
|
||||
**kwargs):
|
||||
super(CTCLabelDecode, self).__init__(character_dict_path,
|
||||
use_space_char)
|
||||
|
||||
def __call__(self, preds, label=None, *args, **kwargs):
|
||||
if isinstance(preds, tuple) or isinstance(preds, list):
|
||||
preds = preds[-1]
|
||||
if not isinstance(preds, np.ndarray):
|
||||
preds = preds.numpy()
|
||||
preds_idx = preds.argmax(axis=2)
|
||||
preds_prob = preds.max(axis=2)
|
||||
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
|
||||
if label is None:
|
||||
return text
|
||||
label = self.decode(label)
|
||||
return text, label
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
dict_character = ['blank'] + dict_character
|
||||
return dict_character
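
For orientation, here is a minimal, hypothetical sketch of how the two post-processors above are wired together through `build_post_process`; the config keys mirror the `__init__` signatures shown in this file, while the array shapes and values are made-up placeholders:

```python
import numpy as np

# Instantiate the DB detector post-processor and the CTC decoder from config dicts.
det_post = build_post_process({'name': 'DBPostProcess', 'thresh': 0.3,
                               'box_thresh': 0.7, 'box_type': 'quad'})
rec_post = build_post_process({'name': 'CTCLabelDecode'})

# DBPostProcess expects a probability map of shape (N, 1, H, W) plus the
# original image shapes: [src_h, src_w, ratio_h, ratio_w] per image.
prob_maps = np.random.rand(1, 1, 960, 960).astype('float32')
shape_list = [[1080, 1920, 960 / 1080, 960 / 1920]]
det_boxes = det_post({'maps': prob_maps}, shape_list)  # [{'points': ...}]

# CTCLabelDecode expects per-timestep class probabilities (N, T, num_classes);
# 37 = 36 default characters + the CTC 'blank' token.
rec_logits = np.random.rand(1, 40, 37).astype('float32')
texts = rec_post(rec_logits)  # [(text, mean_confidence), ...]
```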
@ -1,452 +1,452 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import math
from copy import deepcopy

import numpy as np
import cv2
import onnxruntime as ort
from huggingface_hub import snapshot_download

from api.utils.file_utils import get_project_base_directory
from .operators import *


class Recognizer(object):
    def __init__(self, label_list, task_name, model_dir=None):
        """
        If you have trouble downloading HuggingFace models, -_^ this might help!!

        For Linux:
        export HF_ENDPOINT=https://hf-mirror.com

        For Windows:
        Good luck
        ^_-

        """
        if not model_dir:
            model_dir = os.path.join(
                get_project_base_directory(),
                "rag/res/deepdoc")
            model_file_path = os.path.join(model_dir, task_name + ".onnx")
            if not os.path.exists(model_file_path):
                model_dir = snapshot_download(
                    repo_id="InfiniFlow/deepdoc",
                    local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
                    local_dir_use_symlinks=False)
                model_file_path = os.path.join(model_dir, task_name + ".onnx")
        else:
            model_file_path = os.path.join(model_dir, task_name + ".onnx")

        if not os.path.exists(model_file_path):
            raise ValueError("not find model file path {}".format(
                model_file_path))
        # The GPU branch is intentionally disabled by the `False and` guard,
        # so inference always falls back to the CPU execution provider.
        if False and ort.get_device() == "GPU":
            options = ort.SessionOptions()
            options.enable_cpu_mem_arena = False
            self.ort_sess = ort.InferenceSession(model_file_path, options=options, providers=[('CUDAExecutionProvider')])
        else:
            self.ort_sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider'])
        self.input_names = [node.name for node in self.ort_sess.get_inputs()]
        self.output_names = [node.name for node in self.ort_sess.get_outputs()]
        self.input_shape = self.ort_sess.get_inputs()[0].shape[2:4]
        self.label_list = label_list

    @staticmethod
    def sort_Y_firstly(arr, threashold):
        # sort by top (y) first, then by x0
        arr = sorted(arr, key=lambda r: (r["top"], r["x0"]))
        for i in range(len(arr) - 1):
            for j in range(i, -1, -1):
                # restore the order using the threshold
                if abs(arr[j + 1]["top"] - arr[j]["top"]) < threashold \
                        and arr[j + 1]["x0"] < arr[j]["x0"]:
                    tmp = deepcopy(arr[j])
                    arr[j] = deepcopy(arr[j + 1])
                    arr[j + 1] = deepcopy(tmp)
        return arr

    @staticmethod
    def sort_X_firstly(arr, threashold, copy=True):
        # sort by x0 first, then by top (y)
        arr = sorted(arr, key=lambda r: (r["x0"], r["top"]))
        for i in range(len(arr) - 1):
            for j in range(i, -1, -1):
                # restore the order using the threshold
                if abs(arr[j + 1]["x0"] - arr[j]["x0"]) < threashold \
                        and arr[j + 1]["top"] < arr[j]["top"]:
                    tmp = deepcopy(arr[j]) if copy else arr[j]
                    arr[j] = deepcopy(arr[j + 1]) if copy else arr[j + 1]
                    arr[j + 1] = deepcopy(tmp) if copy else tmp
        return arr

    @staticmethod
    def sort_C_firstly(arr, thr=0):
        # sort by column ("C") first, then by top (y)
        arr = Recognizer.sort_X_firstly(arr, thr)
        for i in range(len(arr) - 1):
            for j in range(i, -1, -1):
                # restore the order using the threshold
                if "C" not in arr[j] or "C" not in arr[j + 1]:
                    continue
                if arr[j + 1]["C"] < arr[j]["C"] \
                        or (
                        arr[j + 1]["C"] == arr[j]["C"]
                        and arr[j + 1]["top"] < arr[j]["top"]
                ):
                    tmp = arr[j]
                    arr[j] = arr[j + 1]
                    arr[j + 1] = tmp
        return arr

    @staticmethod
    def sort_R_firstly(arr, thr=0):
        # sort by row ("R") first, then by x0
        arr = Recognizer.sort_Y_firstly(arr, thr)
        for i in range(len(arr) - 1):
            for j in range(i, -1, -1):
                if "R" not in arr[j] or "R" not in arr[j + 1]:
                    continue
                if arr[j + 1]["R"] < arr[j]["R"] \
                        or (
                        arr[j + 1]["R"] == arr[j]["R"]
                        and arr[j + 1]["x0"] < arr[j]["x0"]
                ):
                    tmp = arr[j]
                    arr[j] = arr[j + 1]
                    arr[j + 1] = tmp
        return arr

    @staticmethod
    def overlapped_area(a, b, ratio=True):
        tp, btm, x0, x1 = a["top"], a["bottom"], a["x0"], a["x1"]
        if b["x0"] > x1 or b["x1"] < x0:
            return 0
        if b["bottom"] < tp or b["top"] > btm:
            return 0
        x0_ = max(b["x0"], x0)
        x1_ = min(b["x1"], x1)
        assert x0_ <= x1_, "Fuckedup! T:{},B:{},X0:{},X1:{} ==> {}".format(
            tp, btm, x0, x1, b)
        tp_ = max(b["top"], tp)
        btm_ = min(b["bottom"], btm)
        assert tp_ <= btm_, "Fuckedup! T:{},B:{},X0:{},X1:{} => {}".format(
            tp, btm, x0, x1, b)
        ov = (btm_ - tp_) * (x1_ - x0_) if x1 - \
            x0 != 0 and btm - tp != 0 else 0
        if ov > 0 and ratio:
            ov /= (x1 - x0) * (btm - tp)
        return ov

    @staticmethod
    def layouts_cleanup(boxes, layouts, far=2, thr=0.7):
        def notOverlapped(a, b):
            return any([a["x1"] < b["x0"],
                        a["x0"] > b["x1"],
                        a["bottom"] < b["top"],
                        a["top"] > b["bottom"]])

        i = 0
        while i + 1 < len(layouts):
            j = i + 1
            while j < min(i + far, len(layouts)) \
                    and (layouts[i].get("type", "") != layouts[j].get("type", "")
                         or notOverlapped(layouts[i], layouts[j])):
                j += 1
            if j >= min(i + far, len(layouts)):
                i += 1
                continue
            if Recognizer.overlapped_area(layouts[i], layouts[j]) < thr \
                    and Recognizer.overlapped_area(layouts[j], layouts[i]) < thr:
                i += 1
                continue

            # Keep the layout with the higher detection score, if both have one.
            if layouts[i].get("score") and layouts[j].get("score"):
                if layouts[i]["score"] > layouts[j]["score"]:
                    layouts.pop(j)
                else:
                    layouts.pop(i)
                continue

            # Otherwise keep the layout covering more text-box area.
            area_i, area_i_1 = 0, 0
            for b in boxes:
                if not notOverlapped(b, layouts[i]):
                    area_i += Recognizer.overlapped_area(b, layouts[i], False)
                if not notOverlapped(b, layouts[j]):
                    area_i_1 += Recognizer.overlapped_area(b, layouts[j], False)

            if area_i > area_i_1:
                layouts.pop(j)
            else:
                layouts.pop(i)

        return layouts

    def create_inputs(self, imgs, im_info):
        """generate input for different model type
        Args:
            imgs (list(numpy)): list of images (np.ndarray)
            im_info (list(dict)): list of image info
        Returns:
            inputs (dict): input of model
        """
        inputs = {}

        im_shape = []
        scale_factor = []
        if len(imgs) == 1:
            inputs['image'] = np.array((imgs[0],)).astype('float32')
            inputs['im_shape'] = np.array(
                (im_info[0]['im_shape'],)).astype('float32')
            inputs['scale_factor'] = np.array(
                (im_info[0]['scale_factor'],)).astype('float32')
            return inputs

        for e in im_info:
            im_shape.append(np.array((e['im_shape'],)).astype('float32'))
            scale_factor.append(np.array((e['scale_factor'],)).astype('float32'))

        inputs['im_shape'] = np.concatenate(im_shape, axis=0)
        inputs['scale_factor'] = np.concatenate(scale_factor, axis=0)

        # Pad every image in the batch to the largest height/width so they stack.
        imgs_shape = [[e.shape[1], e.shape[2]] for e in imgs]
        max_shape_h = max([e[0] for e in imgs_shape])
        max_shape_w = max([e[1] for e in imgs_shape])
        padding_imgs = []
        for img in imgs:
            im_c, im_h, im_w = img.shape[:]
            padding_im = np.zeros(
                (im_c, max_shape_h, max_shape_w), dtype=np.float32)
            padding_im[:, :im_h, :im_w] = img
            padding_imgs.append(padding_im)
        inputs['image'] = np.stack(padding_imgs, axis=0)
        return inputs

    @staticmethod
    def find_overlapped(box, boxes_sorted_by_y, naive=False):
        if not boxes_sorted_by_y:
            return
        bxs = boxes_sorted_by_y
        # binary search for the y-range that can overlap the box
        s, e, ii = 0, len(bxs), 0
        while s < e and not naive:
            ii = (e + s) // 2
            pv = bxs[ii]
            if box["bottom"] < pv["top"]:
                e = ii
                continue
            if box["top"] > pv["bottom"]:
                s = ii + 1
                continue
            break
        while s < ii:
            if box["top"] > bxs[s]["bottom"]:
                s += 1
            break
        while e - 1 > ii:
            if box["bottom"] < bxs[e - 1]["top"]:
                e -= 1
            break

        max_overlaped_i, max_overlaped = None, 0
        for i in range(s, e):
            ov = Recognizer.overlapped_area(bxs[i], box)
            if ov <= max_overlaped:
                continue
            max_overlaped_i = i
            max_overlaped = ov

        return max_overlaped_i

    @staticmethod
    def find_horizontally_tightest_fit(box, boxes):
        if not boxes:
            return
        min_dis, min_i = 1000000, None
        for i, b in enumerate(boxes):
            if box.get("layoutno", "0") != b.get("layoutno", "0"):
                continue
            dis = min(abs(box["x0"] - b["x0"]), abs(box["x1"] - b["x1"]),
                      abs(box["x0"] + box["x1"] - b["x1"] - b["x0"]) / 2)
            if dis < min_dis:
                min_i = i
                min_dis = dis
        return min_i

    @staticmethod
    def find_overlapped_with_threashold(box, boxes, thr=0.3):
        if not boxes:
            return
        max_overlapped_i, max_overlapped, _max_overlapped = None, thr, 0
        s, e = 0, len(boxes)
        for i in range(s, e):
            ov = Recognizer.overlapped_area(box, boxes[i])
            _ov = Recognizer.overlapped_area(boxes[i], box)
            if (ov, _ov) < (max_overlapped, _max_overlapped):
                continue
            max_overlapped_i = i
            max_overlapped = ov
            _max_overlapped = _ov

        return max_overlapped_i

    def preprocess(self, image_list):
        inputs = []
        if "scale_factor" in self.input_names:
            # Detection-style models expect an explicit scale-factor input.
            preprocess_ops = []
            for op_info in [
                {'interp': 2, 'keep_ratio': False, 'target_size': [800, 608], 'type': 'LinearResize'},
                {'is_scale': True, 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225], 'type': 'StandardizeImage'},
                {'type': 'Permute'},
                {'stride': 32, 'type': 'PadStride'}
            ]:
                new_op_info = op_info.copy()
                op_type = new_op_info.pop('type')
                preprocess_ops.append(eval(op_type)(**new_op_info))

            for im_path in image_list:
                im, im_info = preprocess(im_path, preprocess_ops)
                inputs.append({"image": np.array((im,)).astype('float32'),
                               "scale_factor": np.array((im_info["scale_factor"],)).astype('float32')})
        else:
            # YOLO-style models: resize to the fixed input shape and normalize.
            hh, ww = self.input_shape
            for img in image_list:
                h, w = img.shape[:2]
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                img = cv2.resize(np.array(img).astype('float32'), (ww, hh))
                # Scale input pixel values to 0 to 1
                img /= 255.0
                img = img.transpose(2, 0, 1)
                img = img[np.newaxis, :, :, :].astype(np.float32)
                inputs.append({self.input_names[0]: img, "scale_factor": [w / ww, h / hh]})
        return inputs

    def postprocess(self, boxes, inputs, thr):
        if "scale_factor" in self.input_names:
            # Detection-style output: each row is [class_id, score, x0, y0, x1, y1].
            bb = []
            for b in boxes:
                clsid, bbox, score = int(b[0]), b[2:], b[1]
                if score < thr:
                    continue
                if clsid >= len(self.label_list):
                    continue
                bb.append({
                    "type": self.label_list[clsid].lower(),
                    "bbox": [float(t) for t in bbox.tolist()],
                    "score": float(score)
                })
            return bb

        def xywh2xyxy(x):
            # [x, y, w, h] to [x1, y1, x2, y2]
            y = np.copy(x)
            y[:, 0] = x[:, 0] - x[:, 2] / 2
            y[:, 1] = x[:, 1] - x[:, 3] / 2
            y[:, 2] = x[:, 0] + x[:, 2] / 2
            y[:, 3] = x[:, 1] + x[:, 3] / 2
            return y

        def compute_iou(box, boxes):
            # Compute xmin, ymin, xmax, ymax of the intersections
            xmin = np.maximum(box[0], boxes[:, 0])
            ymin = np.maximum(box[1], boxes[:, 1])
            xmax = np.minimum(box[2], boxes[:, 2])
            ymax = np.minimum(box[3], boxes[:, 3])

            # Compute intersection area
            intersection_area = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin)

            # Compute union area
            box_area = (box[2] - box[0]) * (box[3] - box[1])
            boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
            union_area = box_area + boxes_area - intersection_area

            # Compute IoU
            iou = intersection_area / union_area

            return iou

        def iou_filter(boxes, scores, iou_threshold):
            # Greedy non-maximum suppression: keep the highest-scoring box,
            # drop everything overlapping it beyond the threshold, repeat.
            sorted_indices = np.argsort(scores)[::-1]

            keep_boxes = []
            while sorted_indices.size > 0:
                # Pick the highest-scoring remaining box
                box_id = sorted_indices[0]
                keep_boxes.append(box_id)

                # Compute IoU of the picked box with the rest
                ious = compute_iou(boxes[box_id, :], boxes[sorted_indices[1:], :])

                # Remove boxes with IoU over the threshold
                keep_indices = np.where(ious < iou_threshold)[0]
                sorted_indices = sorted_indices[keep_indices + 1]

            return keep_boxes

        boxes = np.squeeze(boxes).T
        # Filter out object confidence scores below threshold
        scores = np.max(boxes[:, 4:], axis=1)
        boxes = boxes[scores > thr, :]
        scores = scores[scores > thr]
        if len(boxes) == 0:
            return []

        # Get the class with the highest confidence
        class_ids = np.argmax(boxes[:, 4:], axis=1)
        boxes = boxes[:, :4]
        input_shape = np.array([inputs["scale_factor"][0], inputs["scale_factor"][1],
                                inputs["scale_factor"][0], inputs["scale_factor"][1]])
        boxes = np.multiply(boxes, input_shape, dtype=np.float32)
        boxes = xywh2xyxy(boxes)

        # Apply NMS separately for each class
        unique_class_ids = np.unique(class_ids)
        indices = []
        for class_id in unique_class_ids:
            class_indices = np.where(class_ids == class_id)[0]
            class_boxes = boxes[class_indices, :]
            class_scores = scores[class_indices]
            class_keep_boxes = iou_filter(class_boxes, class_scores, 0.2)
            indices.extend(class_indices[class_keep_boxes])

        return [{
            "type": self.label_list[class_ids[i]].lower(),
            "bbox": [float(t) for t in boxes[i].tolist()],
            "score": float(scores[i])
        } for i in indices]

    def __call__(self, image_list, thr=0.7, batch_size=16):
        res = []
        imgs = []
        for i in range(len(image_list)):
            if not isinstance(image_list[i], np.ndarray):
                imgs.append(np.array(image_list[i]))
            else:
                imgs.append(image_list[i])

        batch_loop_cnt = math.ceil(float(len(imgs)) / batch_size)
        for i in range(batch_loop_cnt):
            start_index = i * batch_size
            end_index = min((i + 1) * batch_size, len(imgs))
            batch_image_list = imgs[start_index:end_index]
            inputs = self.preprocess(batch_image_list)
            for ins in inputs:
                bb = self.postprocess(
                    self.ort_sess.run(None, {k: v for k, v in ins.items() if k in self.input_names})[0],
                    ins, thr)
                res.append(bb)

        # seeit.save_results(image_list, res, self.label_list, threshold=thr)

        return res
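
The nested `compute_iou` and `iou_filter` helpers in `postprocess` implement plain per-class greedy NMS. A tiny self-contained check of the IoU arithmetic (restated outside the class purely for illustration):

```python
import numpy as np

def compute_iou(box, boxes):
    # Same arithmetic as the nested helper above.
    xmin = np.maximum(box[0], boxes[:, 0])
    ymin = np.maximum(box[1], boxes[:, 1])
    xmax = np.minimum(box[2], boxes[:, 2])
    ymax = np.minimum(box[3], boxes[:, 3])
    inter = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin)
    union = ((box[2] - box[0]) * (box[3] - box[1])
             + (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
             - inter)
    return inter / union

boxes = np.array([[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]], dtype=float)
print(compute_iou(boxes[0], boxes[1:]))
# [0.6807, 0.0]: the first pair overlaps heavily, the second not at all
```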
@ -1,83 +1,83 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import PIL
from PIL import ImageDraw


def save_results(image_list, results, labels, output_dir='output/', threshold=0.5):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    for idx, im in enumerate(image_list):
        im = draw_box(im, results[idx], labels, threshold=threshold)

        out_path = os.path.join(output_dir, f"{idx}.jpg")
        im.save(out_path, quality=95)
        print("save result to: " + out_path)


def draw_box(im, result, labels, threshold=0.5):
    draw_thickness = min(im.size) // 320
    draw = ImageDraw.Draw(im)
    color_list = get_color_map_list(len(labels))
    clsid2color = {n.lower(): color_list[i] for i, n in enumerate(labels)}
    result = [r for r in result if r["score"] >= threshold]

    for dt in result:
        color = tuple(clsid2color[dt["type"]])
        xmin, ymin, xmax, ymax = dt["bbox"]
        draw.line(
            [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
             (xmin, ymin)],
            width=draw_thickness,
            fill=color)

        # draw label
        text = "{} {:.4f}".format(dt["type"], dt["score"])
        tw, th = imagedraw_textsize_c(draw, text)
        draw.rectangle(
            [(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color)
        draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
    return im


def get_color_map_list(num_classes):
    """
    Args:
        num_classes (int): number of classes
    Returns:
        color_map (list): RGB color list
    """
    color_map = num_classes * [0, 0, 0]
    for i in range(0, num_classes):
        j = 0
        lab = i
        while lab:
            color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
            color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
            color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
            j += 1
            lab >>= 3
    color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
    return color_map


def imagedraw_textsize_c(draw, text):
    # Pillow 10 removed ImageDraw.textsize; fall back to textbbox on newer versions.
    if int(PIL.__version__.split('.')[0]) < 10:
        tw, th = draw.textsize(text)
    else:
        left, top, right, bottom = draw.textbbox((0, 0), text)
        tw, th = right - left, bottom - top

    return tw, th
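
`get_color_map_list` spreads the bits of each class id across the high bits of the R, G and B channels, so neighbouring class ids get clearly distinct colors. Worked out by hand for the first few ids:

```python
>>> get_color_map_list(4)
[[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0]]
```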
@ -1,56 +1,56 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import sys

# Make the repository root importable when the script is run directly.
sys.path.insert(
    0,
    os.path.abspath(
        os.path.join(
            os.path.dirname(
                os.path.abspath(__file__)),
            '../../')))

from deepdoc.vision.seeit import draw_box
from deepdoc.vision import OCR, init_in_out
import argparse
import numpy as np


def main(args):
    ocr = OCR()
    images, outputs = init_in_out(args)

    for i, img in enumerate(images):
        bxs = ocr(np.array(img))
        bxs = [(line[0], line[1][0]) for line in bxs]
        bxs = [{
            "text": t,
            "bbox": [b[0][0], b[0][1], b[1][0], b[-1][1]],
            "type": "ocr",
            "score": 1} for b, t in bxs if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]]
        img = draw_box(images[i], bxs, ["ocr"], 1.)
        img.save(outputs[i], quality=95)
        with open(outputs[i] + ".txt", "w+") as f:
            f.write("\n".join([o["text"] for o in bxs]))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--inputs',
                        help="Directory where to store images or PDFs, or a file path to a single image or PDF",
                        required=True)
    parser.add_argument('--output_dir', help="Directory where to store the output images. Default: './ocr_outputs'",
                        default="./ocr_outputs")
    args = parser.parse_args()
    main(args)
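The OCR engine returns one `(quadrilateral, (text, score))` pair per detected line; the comprehension in `main()` flattens each four-point quadrilateral into an axis-aligned `[x0, y0, x1, y1]` box. A tiny sketch of that conversion on hand-made data (the corner ordering — top-left, top-right, bottom-right, bottom-left — is an assumption read off the indexing, not a documented contract):

```python
quad = [(10, 20), (110, 20), (110, 44), (10, 44)]
# quad[0] is the top-left corner, quad[1] the top-right and quad[-1] the
# bottom-left, so x0/y0 come from quad[0], x1 from quad[1], y1 from quad[-1].
bbox = [quad[0][0], quad[0][1], quad[1][0], quad[-1][1]]
assert bbox == [10, 20, 110, 44]
```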
@ -1,187 +1,187 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import sys

# Make the repository root importable when the script is run directly.
sys.path.insert(
    0,
    os.path.abspath(
        os.path.join(
            os.path.dirname(
                os.path.abspath(__file__)),
            '../../')))

from deepdoc.vision.seeit import draw_box
from deepdoc.vision import Recognizer, LayoutRecognizer, TableStructureRecognizer, OCR, init_in_out
from api.utils.file_utils import get_project_base_directory
import argparse
import re
import numpy as np

def main(args):
    images, outputs = init_in_out(args)
    if args.mode.lower() == "layout":
        labels = LayoutRecognizer.labels
        detr = Recognizer(
            labels,
            "layout",
            os.path.join(
                get_project_base_directory(),
                "rag/res/deepdoc/"))
    if args.mode.lower() == "tsr":
        labels = TableStructureRecognizer.labels
        detr = TableStructureRecognizer()
        ocr = OCR()

    layouts = detr(images, float(args.threshold))
    for i, lyt in enumerate(layouts):
        if args.mode.lower() == "tsr":
            # lyt = [t for t in lyt if t["type"] == "table column"]
            html = get_table_html(images[i], lyt, ocr)
            with open(outputs[i] + ".html", "w+") as f:
                f.write(html)
            lyt = [{
                "type": t["label"],
                "bbox": [t["x0"], t["top"], t["x1"], t["bottom"]],
                "score": t["score"]
            } for t in lyt]
        img = draw_box(images[i], lyt, labels, float(args.threshold))
        img.save(outputs[i], quality=95)
        print("save result to: " + outputs[i])

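In tsr mode the recognizer's components are keyed `x0`/`x1`/`top`/`bottom`/`label`, while `draw_box` expects `type`/`bbox`/`score`; the comprehension at the end of the loop bridges the two schemas. A one-box illustration with made-up values:

```python
component = {"label": "table row", "x0": 40, "x1": 600,
             "top": 120, "bottom": 160, "score": 0.83}
box = {"type": component["label"],
       "bbox": [component["x0"], component["top"],
                component["x1"], component["bottom"]],
       "score": component["score"]}
```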
def get_table_html(img, tb_cpns, ocr):
    boxes = ocr(np.array(img))
    boxes = Recognizer.sort_Y_firstly(
        [{"x0": b[0][0], "x1": b[1][0],
          "top": b[0][1], "text": t[0],
          "bottom": b[-1][1],
          "layout_type": "table",
          "page_number": 0} for b, t in boxes if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]],
        np.mean([b[-1][1] - b[0][1] for b, _ in boxes]) / 3
    )

    def gather(kwd, fzy=10, ption=0.6):
        nonlocal boxes
        eles = Recognizer.sort_Y_firstly(
            [r for r in tb_cpns if re.match(kwd, r["label"])], fzy)
        eles = Recognizer.layouts_cleanup(boxes, eles, 5, ption)
        return Recognizer.sort_Y_firstly(eles, 0)

    headers = gather(r".*header$")
    rows = gather(r".* (row|header)")
    spans = gather(r".*spanning")
    clmns = sorted([r for r in tb_cpns if re.match(
        r"table column$", r["label"])], key=lambda x: x["x0"])
    clmns = Recognizer.layouts_cleanup(boxes, clmns, 5, 0.5)

    # Tag every OCR box with the row, header, column and spanning cell it
    # falls into, so construct_table() can assemble the grid afterwards.
    for b in boxes:
        ii = Recognizer.find_overlapped_with_threashold(b, rows, thr=0.3)
        if ii is not None:
            b["R"] = ii
            b["R_top"] = rows[ii]["top"]
            b["R_bott"] = rows[ii]["bottom"]

        ii = Recognizer.find_overlapped_with_threashold(b, headers, thr=0.3)
        if ii is not None:
            b["H_top"] = headers[ii]["top"]
            b["H_bott"] = headers[ii]["bottom"]
            b["H_left"] = headers[ii]["x0"]
            b["H_right"] = headers[ii]["x1"]
            b["H"] = ii

        ii = Recognizer.find_horizontally_tightest_fit(b, clmns)
        if ii is not None:
            b["C"] = ii
            b["C_left"] = clmns[ii]["x0"]
            b["C_right"] = clmns[ii]["x1"]

        ii = Recognizer.find_overlapped_with_threashold(b, spans, thr=0.3)
        if ii is not None:
            b["H_top"] = spans[ii]["top"]
            b["H_bott"] = spans[ii]["bottom"]
            b["H_left"] = spans[ii]["x0"]
            b["H_right"] = spans[ii]["x1"]
            b["SP"] = ii

    html = """
    <html>
    <head>
    <style>
        ._table_1nkzy_11 {
            margin: auto;
            width: 70%%;
            padding: 10px;
        }
        ._table_1nkzy_11 p {
            margin-bottom: 50px;
            border: 1px solid #e1e1e1;
        }

        caption {
            color: #6ac1ca;
            font-size: 20px;
            height: 50px;
            line-height: 50px;
            font-weight: 600;
            margin-bottom: 10px;
        }

        ._table_1nkzy_11 table {
            width: 100%%;
            border-collapse: collapse;
        }

        th {
            color: #fff;
            background-color: #6ac1ca;
        }

        td:hover {
            background: #c1e8e8;
        }

        tr:nth-child(even) {
            background-color: #f2f2f2;
        }

        ._table_1nkzy_11 th,
        ._table_1nkzy_11 td {
            text-align: center;
            border: 1px solid #ddd;
            padding: 8px;
        }
    </style>
    </head>
    <body>
    %s
    </body>
    </html>
    """ % TableStructureRecognizer.construct_table(boxes, html=True)
    return html

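`gather()` partitions the TSR components by matching regexes against their labels. A self-contained sketch of how the four patterns split a plausible label set (the label strings are an assumption based on common table-structure classes, not read from `TableStructureRecognizer` at runtime):

```python
import re

labels = ["table column header", "table projected row header",
          "table row", "table spanning cell", "table column"]

headers = [l for l in labels if re.match(r".*header$", l)]
rows = [l for l in labels if re.match(r".* (row|header)", l)]
spans = [l for l in labels if re.match(r".*spanning", l)]
cols = [l for l in labels if re.match(r"table column$", l)]
print(headers)  # ['table column header', 'table projected row header']
print(rows)     # ['table column header', 'table projected row header', 'table row']
print(spans)    # ['table spanning cell']
print(cols)     # ['table column']
```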

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--inputs',
                        help="Directory where to store images or PDFs, or a file path to a single image or PDF",
                        required=True)
    parser.add_argument('--output_dir', help="Directory where to store the output images. Default: './layouts_outputs'",
                        default="./layouts_outputs")
    parser.add_argument(
        '--threshold',
        help="A threshold to filter out detections. Default: 0.5",
        default=0.5)
    parser.add_argument('--mode', help="Task mode: layout recognition or table structure recognition", choices=["layout", "tsr"],
                        default="layout")
    args = parser.parse_args()
    main(args)
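Recapping the data flow in `get_table_html()` above: by the time `TableStructureRecognizer.construct_table(boxes, html=True)` runs, each OCR box carries the indices assigned in the tagging loop. A representative tagged box, with made-up coordinates (field names taken from the code):

```python
tagged_box = {
    "x0": 102, "x1": 188, "top": 240, "bottom": 262,
    "text": "42.0", "layout_type": "table", "page_number": 0,
    "R": 3, "R_top": 236, "R_bott": 266,   # enclosing table row
    "C": 1, "C_left": 95, "C_right": 200,  # horizontally tightest column
    # "H"/"SP" plus the H_* bounds appear only when a header or spanning
    # cell overlaps the box.
}
```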