Refa: treat MinerU as an OCR model (#11849)

### What problem does this PR solve?

 Treat MinerU as an OCR model.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Refactoring
This commit is contained in:
Yongteng Lei
2025-12-09 18:54:14 +08:00
committed by GitHub
parent 30377319d8
commit a94b3b9df2
9 changed files with 283 additions and 43 deletions

View File

@ -25,7 +25,7 @@ from api.utils.api_utils import get_allowed_llm_factories, get_data_error_result
from common.constants import StatusEnum, LLMType from common.constants import StatusEnum, LLMType
from api.db.db_models import TenantLLM from api.db.db_models import TenantLLM
from rag.utils.base64_image import test_image from rag.utils.base64_image import test_image
from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel, OcrModel
@manager.route("/factories", methods=["GET"]) # noqa: F821 @manager.route("/factories", methods=["GET"]) # noqa: F821
@ -43,7 +43,13 @@ def factories():
mdl_types[m.fid] = set([]) mdl_types[m.fid] = set([])
mdl_types[m.fid].add(m.model_type) mdl_types[m.fid].add(m.model_type)
for f in fac: for f in fac:
f["model_types"] = list(mdl_types.get(f["name"], [LLMType.CHAT, LLMType.EMBEDDING, LLMType.RERANK, LLMType.IMAGE2TEXT, LLMType.SPEECH2TEXT, LLMType.TTS])) f["model_types"] = list(
mdl_types.get(
f["name"],
[LLMType.CHAT, LLMType.EMBEDDING, LLMType.RERANK, LLMType.IMAGE2TEXT, LLMType.SPEECH2TEXT, LLMType.TTS, LLMType.OCR],
)
)
return get_json_result(data=fac) return get_json_result(data=fac)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@ -251,6 +257,15 @@ async def add_llm():
pass pass
except RuntimeError as e: except RuntimeError as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e) msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
elif llm["model_type"] == LLMType.OCR.value:
assert factory in OcrModel, f"OCR model from {factory} is not supported yet."
try:
mdl = OcrModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm.get("api_base", ""))
ok, reason = mdl.check_available()
if not ok:
raise RuntimeError(reason or "Model not available")
except Exception as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
else: else:
# TODO: check other type of models # TODO: check other type of models
pass pass
@ -297,6 +312,7 @@ async def delete_factory():
@login_required @login_required
def my_llms(): def my_llms():
try: try:
TenantLLMService.ensure_mineru_from_env(current_user.id)
include_details = request.args.get("include_details", "false").lower() == "true" include_details = request.args.get("include_details", "false").lower() == "true"
if include_details: if include_details:
@ -344,6 +360,7 @@ def list_app():
weighted = [] weighted = []
model_type = request.args.get("model_type") model_type = request.args.get("model_type")
try: try:
TenantLLMService.ensure_mineru_from_env(current_user.id)
objs = TenantLLMService.query(tenant_id=current_user.id) objs = TenantLLMService.query(tenant_id=current_user.id)
facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key and o.status == StatusEnum.VALID.value]) facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key and o.status == StatusEnum.VALID.value])
status = {(o.llm_name + "@" + o.llm_factory) for o in objs if o.status == StatusEnum.VALID.value} status = {(o.llm_name + "@" + o.llm_factory) for o in objs if o.status == StatusEnum.VALID.value}

View File

@ -14,15 +14,16 @@
# limitations under the License. # limitations under the License.
# #
import os import os
import json
import logging import logging
from langfuse import Langfuse from langfuse import Langfuse
from common import settings from common import settings
from common.constants import LLMType from common.constants import MINERU_DEFAULT_CONFIG, MINERU_ENV_KEYS, LLMType
from api.db.db_models import DB, LLMFactories, TenantLLM from api.db.db_models import DB, LLMFactories, TenantLLM
from api.db.services.common_service import CommonService from api.db.services.common_service import CommonService
from api.db.services.langfuse_service import TenantLangfuseService from api.db.services.langfuse_service import TenantLangfuseService
from api.db.services.user_service import TenantService from api.db.services.user_service import TenantService
from rag.llm import ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel, TTSModel from rag.llm import ChatModel, CvModel, EmbeddingModel, OcrModel, RerankModel, Seq2txtModel, TTSModel
class LLMFactoriesService(CommonService): class LLMFactoriesService(CommonService):
@ -104,6 +105,10 @@ class TenantLLMService(CommonService):
mdlnm = tenant.rerank_id if not llm_name else llm_name mdlnm = tenant.rerank_id if not llm_name else llm_name
elif llm_type == LLMType.TTS: elif llm_type == LLMType.TTS:
mdlnm = tenant.tts_id if not llm_name else llm_name mdlnm = tenant.tts_id if not llm_name else llm_name
elif llm_type == LLMType.OCR:
if not llm_name:
raise LookupError("OCR model name is required")
mdlnm = llm_name
else: else:
assert False, "LLM type error" assert False, "LLM type error"
@ -137,31 +142,31 @@ class TenantLLMService(CommonService):
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
base_url=model_config["api_base"]) base_url=model_config["api_base"])
if llm_type == LLMType.RERANK: elif llm_type == LLMType.RERANK:
if model_config["llm_factory"] not in RerankModel: if model_config["llm_factory"] not in RerankModel:
return None return None
return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
base_url=model_config["api_base"]) base_url=model_config["api_base"])
if llm_type == LLMType.IMAGE2TEXT.value: elif llm_type == LLMType.IMAGE2TEXT.value:
if model_config["llm_factory"] not in CvModel: if model_config["llm_factory"] not in CvModel:
return None return None
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang,
base_url=model_config["api_base"], **kwargs) base_url=model_config["api_base"], **kwargs)
if llm_type == LLMType.CHAT.value: elif llm_type == LLMType.CHAT.value:
if model_config["llm_factory"] not in ChatModel: if model_config["llm_factory"] not in ChatModel:
return None return None
return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
base_url=model_config["api_base"], **kwargs) base_url=model_config["api_base"], **kwargs)
if llm_type == LLMType.SPEECH2TEXT: elif llm_type == LLMType.SPEECH2TEXT:
if model_config["llm_factory"] not in Seq2txtModel: if model_config["llm_factory"] not in Seq2txtModel:
return None return None
return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"],
model_name=model_config["llm_name"], lang=lang, model_name=model_config["llm_name"], lang=lang,
base_url=model_config["api_base"]) base_url=model_config["api_base"])
if llm_type == LLMType.TTS: elif llm_type == LLMType.TTS:
if model_config["llm_factory"] not in TTSModel: if model_config["llm_factory"] not in TTSModel:
return None return None
return TTSModel[model_config["llm_factory"]]( return TTSModel[model_config["llm_factory"]](
@ -169,6 +174,17 @@ class TenantLLMService(CommonService):
model_config["llm_name"], model_config["llm_name"],
base_url=model_config["api_base"], base_url=model_config["api_base"],
) )
elif llm_type == LLMType.OCR:
if model_config["llm_factory"] not in OcrModel:
return None
return OcrModel[model_config["llm_factory"]](
key=model_config["api_key"],
model_name=model_config["llm_name"],
base_url=model_config.get("api_base", ""),
**kwargs,
)
return None return None
@classmethod @classmethod
@ -186,6 +202,7 @@ class TenantLLMService(CommonService):
LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name, LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name,
LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name, LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name,
LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name, LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name,
LLMType.OCR.value: llm_name,
} }
mdlnm = llm_map.get(llm_type) mdlnm = llm_map.get(llm_type)
@ -218,6 +235,61 @@ class TenantLLMService(CommonService):
~(cls.model.llm_name == "text-embedding-3-large")).dicts() ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
return list(objs) return list(objs)
@classmethod
def _collect_mineru_env_config(cls) -> dict | None:
cfg = MINERU_DEFAULT_CONFIG
found = False
for key in MINERU_ENV_KEYS:
val = os.environ.get(key)
if val:
found = True
cfg[key] = val
return cfg if found else None
@classmethod
@DB.connection_context()
def ensure_mineru_from_env(cls, tenant_id: str) -> str | None:
"""
Ensure a MinerU OCR model exists for the tenant if env variables are present.
Return the existing or newly created llm_name, or None if env not set.
"""
cfg = cls._collect_mineru_env_config()
if not cfg:
return None
saved_mineru_models = cls.query(tenant_id=tenant_id, llm_factory="MinerU", model_type=LLMType.OCR.value)
def _parse_api_key(raw: str) -> dict:
try:
return json.loads(raw or "{}")
except Exception:
return {}
for item in saved_mineru_models:
api_cfg = _parse_api_key(item.api_key)
normalized = {k: api_cfg.get(k, MINERU_DEFAULT_CONFIG.get(k)) for k in MINERU_ENV_KEYS}
if normalized == cfg:
return item.llm_name
used_names = {item.llm_name for item in saved_mineru_models}
idx = 1
base_name = "mineru-from-env"
candidate = f"{base_name}-{idx}"
while candidate in used_names:
idx += 1
candidate = f"{base_name}-{idx}"
cls.save(
tenant_id=tenant_id,
llm_factory="MinerU",
llm_name=candidate,
model_type=LLMType.OCR.value,
api_key=json.dumps(cfg),
api_base="",
max_tokens=0,
)
return candidate
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def delete_by_tenant_id(cls, tenant_id): def delete_by_tenant_id(cls, tenant_id):

View File

@ -73,6 +73,7 @@ class LLMType(StrEnum):
IMAGE2TEXT = 'image2text' IMAGE2TEXT = 'image2text'
RERANK = 'rerank' RERANK = 'rerank'
TTS = 'tts' TTS = 'tts'
OCR = 'ocr'
class TaskStatus(StrEnum): class TaskStatus(StrEnum):
@ -199,3 +200,13 @@ PAGERANK_FLD = "pagerank_fea"
SVR_QUEUE_NAME = "rag_flow_svr_queue" SVR_QUEUE_NAME = "rag_flow_svr_queue"
SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_task_broker" SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_task_broker"
TAG_FLD = "tag_feas" TAG_FLD = "tag_feas"
MINERU_ENV_KEYS = ["MINERU_APISERVER", "MINERU_OUTPUT_DIR", "MINERU_BACKEND", "MINERU_SERVER_URL", "MINERU_DELETE_OUTPUT"]
MINERU_DEFAULT_CONFIG = {
"MINERU_APISERVER": "",
"MINERU_OUTPUT_DIR": "",
"MINERU_BACKEND": "pipeline",
"MINERU_SERVER_URL": "",
"MINERU_DELETE_OUTPUT": 1,
}

View File

@ -5489,6 +5489,14 @@
"model_type": "reranker" "model_type": "reranker"
} }
] ]
},
{
"name": "MinerU",
"logo": "",
"tags": "OCR",
"status": "1",
"rank": "900",
"llm": []
} }
] ]
} }

View File

@ -39,7 +39,6 @@ from sklearn.metrics import silhouette_score
from common.file_utils import get_project_base_directory from common.file_utils import get_project_base_directory
from common.misc_utils import pip_install_torch from common.misc_utils import pip_install_torch
from deepdoc.vision import OCR, AscendLayoutRecognizer, LayoutRecognizer, Recognizer, TableStructureRecognizer from deepdoc.vision import OCR, AscendLayoutRecognizer, LayoutRecognizer, Recognizer, TableStructureRecognizer
from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
from rag.nlp import rag_tokenizer from rag.nlp import rag_tokenizer
from rag.prompts.generator import vision_llm_describe_prompt from rag.prompts.generator import vision_llm_describe_prompt
from common import settings from common import settings
@ -1455,6 +1454,8 @@ class VisionParser(RAGFlowPdfParser):
if pdf_page_num < start_page or pdf_page_num >= end_page: if pdf_page_num < start_page or pdf_page_num >= end_page:
continue continue
from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
text = picture_vision_llm_chunk( text = picture_vision_llm_chunk(
binary=img_binary, binary=img_binary,
vision_model=self.vision_model, vision_model=self.vision_model,

View File

@ -34,7 +34,6 @@ from rag.utils.file_utils import extract_embed_file, extract_links_from_pdf, ext
from deepdoc.parser import DocxParser, ExcelParser, HtmlParser, JsonParser, MarkdownElementExtractor, MarkdownParser, PdfParser, TxtParser from deepdoc.parser import DocxParser, ExcelParser, HtmlParser, JsonParser, MarkdownElementExtractor, MarkdownParser, PdfParser, TxtParser
from deepdoc.parser.figure_parser import VisionFigureParser,vision_figure_parser_docx_wrapper,vision_figure_parser_pdf_wrapper from deepdoc.parser.figure_parser import VisionFigureParser,vision_figure_parser_docx_wrapper,vision_figure_parser_pdf_wrapper
from deepdoc.parser.pdf_parser import PlainParser, VisionParser from deepdoc.parser.pdf_parser import PlainParser, VisionParser
from deepdoc.parser.mineru_parser import MinerUParser
from deepdoc.parser.docling_parser import DoclingParser from deepdoc.parser.docling_parser import DoclingParser
from deepdoc.parser.tcadp_parser import TCADPParser from deepdoc.parser.tcadp_parser import TCADPParser
from rag.nlp import concat_img, find_codec, naive_merge, naive_merge_with_images, naive_merge_docx, rag_tokenizer, tokenize_chunks, tokenize_chunks_with_images, tokenize_table, attach_media_context from rag.nlp import concat_img, find_codec, naive_merge, naive_merge_with_images, naive_merge_docx, rag_tokenizer, tokenize_chunks, tokenize_chunks_with_images, tokenize_table, attach_media_context
@ -58,27 +57,42 @@ def by_deepdoc(filename, binary=None, from_page=0, to_page=100000, lang="Chinese
def by_mineru(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls = None ,**kwargs): def by_mineru(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls = None ,**kwargs):
mineru_executable = os.environ.get("MINERU_EXECUTABLE", "mineru")
mineru_api = os.environ.get("MINERU_APISERVER", "http://host.docker.internal:9987")
pdf_parser = MinerUParser(mineru_path=mineru_executable, mineru_api=mineru_api)
parse_method = kwargs.get("parse_method", "raw") parse_method = kwargs.get("parse_method", "raw")
mineru_llm_name = kwargs.get("mineru_llm_name")
tenant_id = kwargs.get("tenant_id")
if not pdf_parser.check_installation(): pdf_parser = None
callback(-1, "MinerU not found.") if tenant_id:
return None, None, pdf_parser if not mineru_llm_name:
try:
from api.db.services.tenant_llm_service import TenantLLMService
env_name = TenantLLMService.ensure_mineru_from_env(tenant_id)
candidates = TenantLLMService.query(tenant_id=tenant_id, llm_factory="MinerU", model_type=LLMType.OCR.value)
if candidates:
mineru_llm_name = candidates[0].llm_name
elif env_name:
mineru_llm_name = env_name
except Exception as e: # best-effort fallback
logging.warning(f"fallback to env mineru: {e}")
if mineru_llm_name:
try:
ocr_model = LLMBundle(tenant_id, LLMType.OCR, llm_name=mineru_llm_name, lang=lang)
pdf_parser = ocr_model.mdl
sections, tables = pdf_parser.parse_pdf( sections, tables = pdf_parser.parse_pdf(
filepath=filename, filepath=filename,
binary=binary, binary=binary,
callback=callback, callback=callback,
output_dir=os.environ.get("MINERU_OUTPUT_DIR", ""), parse_method=parse_method,
backend=os.environ.get("MINERU_BACKEND", "pipeline"),
server_url=os.environ.get("MINERU_SERVER_URL", ""),
delete_output=bool(int(os.environ.get("MINERU_DELETE_OUTPUT", 1))),
parse_method=parse_method
) )
return sections, tables, pdf_parser return sections, tables, pdf_parser
except Exception as e:
logging.error(f"Failed to parse pdf via LLMBundle MinerU ({mineru_llm_name}): {e}")
if callback:
callback(-1, "MinerU not found.")
return None, None, None
def by_docling(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls = None ,**kwargs): def by_docling(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls = None ,**kwargs):
pdf_parser = DoclingParser() pdf_parser = DoclingParser()
@ -692,7 +706,15 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
return res return res
elif re.search(r"\.pdf$", filename, re.IGNORECASE): elif re.search(r"\.pdf$", filename, re.IGNORECASE):
layout_recognizer = parser_config.get("layout_recognize", "DeepDOC") layout_recognizer_raw = parser_config.get("layout_recognize", "DeepDOC")
parser_model_name = None
layout_recognizer = layout_recognizer_raw
if isinstance(layout_recognizer_raw, str):
lowered = layout_recognizer_raw.lower()
if lowered.startswith("mineru@"):
parser_model_name = layout_recognizer_raw.split("@", 1)[1]
layout_recognizer = "MinerU"
if parser_config.get("analyze_hyperlink", False) and is_root: if parser_config.get("analyze_hyperlink", False) and is_root:
urls = extract_links_from_pdf(binary) urls = extract_links_from_pdf(binary)
@ -711,6 +733,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
lang = lang, lang = lang,
callback = callback, callback = callback,
layout_recognizer = layout_recognizer, layout_recognizer = layout_recognizer,
mineru_llm_name = parser_model_name,
**kwargs **kwargs
) )

View File

@ -31,7 +31,6 @@ from common import settings
from common.constants import LLMType from common.constants import LLMType
from common.misc_utils import get_uuid from common.misc_utils import get_uuid
from deepdoc.parser import ExcelParser from deepdoc.parser import ExcelParser
from deepdoc.parser.mineru_parser import MinerUParser
from deepdoc.parser.pdf_parser import PlainParser, RAGFlowPdfParser, VisionParser from deepdoc.parser.pdf_parser import PlainParser, RAGFlowPdfParser, VisionParser
from deepdoc.parser.tcadp_parser import TCADPParser from deepdoc.parser.tcadp_parser import TCADPParser
from rag.app.naive import Docx from rag.app.naive import Docx
@ -235,25 +234,55 @@ class Parser(ProcessBase):
conf = self._param.setups["pdf"] conf = self._param.setups["pdf"]
self.set_output("output_format", conf["output_format"]) self.set_output("output_format", conf["output_format"])
if conf.get("parse_method").lower() == "deepdoc": raw_parse_method = conf.get("parse_method", "")
parser_model_name = None
parse_method = raw_parse_method
parse_method = parse_method or ""
if isinstance(raw_parse_method, str):
lowered = raw_parse_method.lower()
if lowered.startswith("mineru@"):
parser_model_name = raw_parse_method.split("@", 1)[1]
parse_method = "MinerU"
elif lowered.endswith("@mineru"):
parser_model_name = raw_parse_method.rsplit("@", 1)[0]
parse_method = "MinerU"
if parse_method.lower() == "deepdoc":
bboxes = RAGFlowPdfParser().parse_into_bboxes(blob, callback=self.callback) bboxes = RAGFlowPdfParser().parse_into_bboxes(blob, callback=self.callback)
elif conf.get("parse_method").lower() == "plain_text": elif parse_method.lower() == "plain_text":
lines, _ = PlainParser()(blob) lines, _ = PlainParser()(blob)
bboxes = [{"text": t} for t, _ in lines] bboxes = [{"text": t} for t, _ in lines]
elif conf.get("parse_method").lower() == "mineru": elif parse_method.lower() == "mineru":
mineru_executable = os.environ.get("MINERU_EXECUTABLE", "mineru") def resolve_mineru_llm_name():
mineru_api = os.environ.get("MINERU_APISERVER", "http://host.docker.internal:9987") configured = parser_model_name or conf.get("mineru_llm_name")
pdf_parser = MinerUParser(mineru_path=mineru_executable, mineru_api=mineru_api) if configured:
ok, reason = pdf_parser.check_installation() return configured
if not ok:
raise RuntimeError(f"MinerU not found or server not accessible: {reason}. Please install it via: pip install -U 'mineru[core]'.") tenant_id = self._canvas._tenant_id
if not tenant_id:
return None
from api.db.services.tenant_llm_service import TenantLLMService
env_name = TenantLLMService.ensure_mineru_from_env(tenant_id)
candidates = TenantLLMService.query(tenant_id=tenant_id, llm_factory="MinerU", model_type=LLMType.OCR.value)
if candidates:
return candidates[0].llm_name
return env_name
parser_model_name = resolve_mineru_llm_name()
if not parser_model_name:
raise RuntimeError("MinerU model not configured. Please add MinerU in Model Providers or set MINERU_* env.")
tenant_id = self._canvas._tenant_id
ocr_model = LLMBundle(tenant_id, LLMType.OCR, llm_name=parser_model_name, lang=conf.get("lang", "Chinese"))
pdf_parser = ocr_model.mdl
lines, _ = pdf_parser.parse_pdf( lines, _ = pdf_parser.parse_pdf(
filepath=name, filepath=name,
binary=blob, binary=blob,
callback=self.callback, callback=self.callback,
output_dir=os.environ.get("MINERU_OUTPUT_DIR", ""), parse_method=conf.get("mineru_parse_method", "raw"),
delete_output=bool(int(os.environ.get("MINERU_DELETE_OUTPUT", 1))),
) )
bboxes = [] bboxes = []
for t, poss in lines: for t, poss in lines:
@ -263,7 +292,7 @@ class Parser(ProcessBase):
"text": t, "text": t,
} }
bboxes.append(box) bboxes.append(box)
elif conf.get("parse_method").lower() == "tcadp parser": elif parse_method.lower() == "tcadp parser":
# ADP is a document parsing tool using Tencent Cloud API # ADP is a document parsing tool using Tencent Cloud API
table_result_type = conf.get("table_result_type", "1") table_result_type = conf.get("table_result_type", "1")
markdown_image_response_type = conf.get("markdown_image_response_type", "1") markdown_image_response_type = conf.get("markdown_image_response_type", "1")

View File

@ -121,6 +121,7 @@ EmbeddingModel = globals().get("EmbeddingModel", {})
RerankModel = globals().get("RerankModel", {}) RerankModel = globals().get("RerankModel", {})
Seq2txtModel = globals().get("Seq2txtModel", {}) Seq2txtModel = globals().get("Seq2txtModel", {})
TTSModel = globals().get("TTSModel", {}) TTSModel = globals().get("TTSModel", {})
OcrModel = globals().get("OcrModel", {})
MODULE_MAPPING = { MODULE_MAPPING = {
@ -130,6 +131,7 @@ MODULE_MAPPING = {
"rerank_model": RerankModel, "rerank_model": RerankModel,
"sequence2txt_model": Seq2txtModel, "sequence2txt_model": Seq2txtModel,
"tts_model": TTSModel, "tts_model": TTSModel,
"ocr_model": OcrModel,
} }
package_name = __name__ package_name = __name__
@ -171,4 +173,5 @@ __all__ = [
"RerankModel", "RerankModel",
"Seq2txtModel", "Seq2txtModel",
"TTSModel", "TTSModel",
"OcrModel",
] ]

76
rag/llm/ocr_model.py Normal file
View File

@ -0,0 +1,76 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import os
from typing import Any, Optional, Tuple
from deepdoc.parser.mineru_parser import MinerUParser
class Base:
def __init__(self, key: str, model_name: str, **kwargs):
self.model_name = model_name
def parse_pdf(self, filepath: str, binary=None, **kwargs) -> Tuple[Any, Any]:
raise NotImplementedError("Please implement parse_pdf!")
class MinerUOcrModel(Base, MinerUParser):
_FACTORY_NAME = "MinerU"
def __init__(self, key: str, model_name: str, **kwargs):
Base.__init__(self, key, model_name, **kwargs)
cfg = {}
if key:
try:
cfg = json.loads(key)
except Exception:
cfg = {}
self.mineru_api = cfg.get("MINERU_APISERVER", os.environ.get("MINERU_APISERVER", "http://host.docker.internal:9987"))
self.mineru_output_dir = cfg.get("MINERU_OUTPUT_DIR", os.environ.get("MINERU_OUTPUT_DIR", ""))
self.mineru_backend = cfg.get("MINERU_BACKEND", os.environ.get("MINERU_BACKEND", "pipeline"))
self.mineru_server_url = cfg.get("MINERU_SERVER_URL", os.environ.get("MINERU_SERVER_URL", ""))
self.mineru_delete_output = bool(int(cfg.get("MINERU_DELETE_OUTPUT", os.environ.get("MINERU_DELETE_OUTPUT", 1))))
self.mineru_executable = os.environ.get("MINERU_EXECUTABLE", "mineru")
logging.info(f"Parsered MinerU config: {cfg}")
MinerUParser.__init__(self, mineru_path=self.mineru_executable, mineru_api=self.mineru_api, mineru_server_url=self.mineru_server_url)
def check_available(self, backend: Optional[str] = None, server_url: Optional[str] = None) -> Tuple[bool, str]:
backend = backend or self.mineru_backend
server_url = server_url or self.mineru_server_url
return self.check_installation(backend=backend, server_url=server_url)
def parse_pdf(self, filepath: str, binary=None, callback=None, parse_method: str = "raw", **kwargs):
ok, reason = self.check_available()
if not ok:
raise RuntimeError(f"MinerU not found or server not accessible: {reason}. Please install it via: pip install -U 'mineru[core]'.")
sections, tables = MinerUParser.parse_pdf(
self,
filepath=filepath,
binary=binary,
callback=callback,
output_dir=self.mineru_output_dir,
backend=self.mineru_backend,
server_url=self.mineru_server_url,
delete_output=self.mineru_delete_output,
parse_method=parse_method,
)
return sections, tables