Refa: treat MinerU as an OCR model (#11849)

### What problem does this PR solve?

Treat MinerU as a first-class OCR model, so it is resolved and instantiated through the same tenant LLM service paths as the other model types, and so a MinerU entry can be auto-registered from environment variables.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Refactoring
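
For orientation while reading the diff below: after this change a MinerU instance is stored in the tenant LLM table as an ordinary OCR entry whose `api_key` column holds a JSON-encoded MinerU configuration. A hedged sketch of such a row; the column names follow the `cls.save(...)` call in the diff, while the concrete values and the shape of the JSON blob are invented for illustration:

```python
# Illustrative only: values are made up, field names follow the diff below.
mineru_tenant_llm_row = {
    "tenant_id": "t-0001",
    "llm_factory": "MinerU",
    "llm_name": "mineru-from-env-1",     # auto-generated by ensure_mineru_from_env
    "model_type": "ocr",                 # assumed string value of LLMType.OCR.value
    "api_key": '{"MINERU_...": "..."}',  # JSON blob of the collected env config
    "api_base": "",
    "max_tokens": 0,
}
```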
Author: Yongteng Lei
Date: 2025-12-09 18:54:14 +08:00 (committed by GitHub)
Parent: 30377319d8
Commit: a94b3b9df2
9 changed files with 283 additions and 43 deletions


@@ -14,15 +14,16 @@
# limitations under the License.
#
import os
import json
import logging
from langfuse import Langfuse
from common import settings
from common.constants import LLMType
from common.constants import MINERU_DEFAULT_CONFIG, MINERU_ENV_KEYS, LLMType
from api.db.db_models import DB, LLMFactories, TenantLLM
from api.db.services.common_service import CommonService
from api.db.services.langfuse_service import TenantLangfuseService
from api.db.services.user_service import TenantService
from rag.llm import ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel, TTSModel
from rag.llm import ChatModel, CvModel, EmbeddingModel, OcrModel, RerankModel, Seq2txtModel, TTSModel
class LLMFactoriesService(CommonService):
@@ -104,6 +105,10 @@ class TenantLLMService(CommonService):
mdlnm = tenant.rerank_id if not llm_name else llm_name
elif llm_type == LLMType.TTS:
mdlnm = tenant.tts_id if not llm_name else llm_name
elif llm_type == LLMType.OCR:
if not llm_name:
raise LookupError("OCR model name is required")
mdlnm = llm_name
else:
assert False, "LLM type error"
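
The branch above encodes a rule worth spelling out: unlike chat, embedding, rerank, speech-to-text, and TTS, a tenant has no default OCR model column, so the caller must always pass an explicit model name. A minimal standalone sketch of that resolution rule (a toy function, not the service code itself):

```python
# Minimal sketch of the name-resolution rule added above: most types fall back
# to a tenant-level default, OCR has no default and needs an explicit name.
def resolve_model_name(llm_type: str, tenant_defaults: dict, llm_name: str | None) -> str:
    if llm_type == "ocr":
        if not llm_name:
            raise LookupError("OCR model name is required")  # mirrors the service
        return llm_name
    # Other types: an explicit name wins, otherwise the tenant default (e.g. tts_id).
    return llm_name or tenant_defaults[llm_type]

assert resolve_model_name("ocr", {}, "mineru-from-env-1") == "mineru-from-env-1"
assert resolve_model_name("tts", {"tts": "tenant-tts"}, None) == "tenant-tts"
```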
@@ -137,31 +142,31 @@ class TenantLLMService(CommonService):
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
base_url=model_config["api_base"])
if llm_type == LLMType.RERANK:
elif llm_type == LLMType.RERANK:
if model_config["llm_factory"] not in RerankModel:
return None
return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
base_url=model_config["api_base"])
if llm_type == LLMType.IMAGE2TEXT.value:
elif llm_type == LLMType.IMAGE2TEXT.value:
if model_config["llm_factory"] not in CvModel:
return None
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang,
base_url=model_config["api_base"], **kwargs)
if llm_type == LLMType.CHAT.value:
elif llm_type == LLMType.CHAT.value:
if model_config["llm_factory"] not in ChatModel:
return None
return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
base_url=model_config["api_base"], **kwargs)
if llm_type == LLMType.SPEECH2TEXT:
elif llm_type == LLMType.SPEECH2TEXT:
if model_config["llm_factory"] not in Seq2txtModel:
return None
return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"],
model_name=model_config["llm_name"], lang=lang,
base_url=model_config["api_base"])
if llm_type == LLMType.TTS:
elif llm_type == LLMType.TTS:
if model_config["llm_factory"] not in TTSModel:
return None
return TTSModel[model_config["llm_factory"]](
@@ -169,6 +174,17 @@ class TenantLLMService(CommonService):
model_config["llm_name"],
base_url=model_config["api_base"],
)
elif llm_type == LLMType.OCR:
if model_config["llm_factory"] not in OcrModel:
return None
return OcrModel[model_config["llm_factory"]](
key=model_config["api_key"],
model_name=model_config["llm_name"],
base_url=model_config.get("api_base", ""),
**kwargs,
)
return None
@classmethod
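
The new OCR branch follows the same registry-dispatch pattern as the Cv/Chat/TTS branches: `OcrModel` maps a factory name such as "MinerU" to a model class, and an unknown factory yields `None`. A standalone sketch of that pattern with a toy registry (`FakeMinerUOcr` stands in for the real class exported by `rag.llm`):

```python
# Toy registry illustrating the dispatch used above; illustrative only.
class FakeMinerUOcr:
    def __init__(self, key, model_name, base_url="", **kwargs):
        self.key, self.model_name, self.base_url = key, model_name, base_url

OCR_REGISTRY = {"MinerU": FakeMinerUOcr}

def build_ocr(model_config: dict, **kwargs):
    if model_config["llm_factory"] not in OCR_REGISTRY:
        return None  # unknown factory: caller treats this as "no OCR model"
    return OCR_REGISTRY[model_config["llm_factory"]](
        key=model_config["api_key"],
        model_name=model_config["llm_name"],
        base_url=model_config.get("api_base", ""),
        **kwargs,
    )

ocr = build_ocr({"llm_factory": "MinerU", "api_key": "{}", "llm_name": "mineru-from-env-1"})
```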
@@ -186,6 +202,7 @@ class TenantLLMService(CommonService):
LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name,
LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name,
LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name,
LLMType.OCR.value: llm_name,
}
mdlnm = llm_map.get(llm_type)
@@ -218,6 +235,61 @@ class TenantLLMService(CommonService):
~(cls.model.llm_name == "text-embedding-3-large")).dicts()
return list(objs)
@classmethod
def _collect_mineru_env_config(cls) -> dict | None:
cfg = MINERU_DEFAULT_CONFIG
found = False
for key in MINERU_ENV_KEYS:
val = os.environ.get(key)
if val:
found = True
cfg[key] = val
return cfg if found else None
@classmethod
@DB.connection_context()
def ensure_mineru_from_env(cls, tenant_id: str) -> str | None:
"""
Ensure a MinerU OCR model exists for the tenant if env variables are present.
Return the existing or newly created llm_name, or None if env not set.
"""
cfg = cls._collect_mineru_env_config()
if not cfg:
return None
saved_mineru_models = cls.query(tenant_id=tenant_id, llm_factory="MinerU", model_type=LLMType.OCR.value)
def _parse_api_key(raw: str) -> dict:
try:
return json.loads(raw or "{}")
except Exception:
return {}
for item in saved_mineru_models:
api_cfg = _parse_api_key(item.api_key)
normalized = {k: api_cfg.get(k, MINERU_DEFAULT_CONFIG.get(k)) for k in MINERU_ENV_KEYS}
if normalized == cfg:
return item.llm_name
used_names = {item.llm_name for item in saved_mineru_models}
idx = 1
base_name = "mineru-from-env"
candidate = f"{base_name}-{idx}"
while candidate in used_names:
idx += 1
candidate = f"{base_name}-{idx}"
cls.save(
tenant_id=tenant_id,
llm_factory="MinerU",
llm_name=candidate,
model_type=LLMType.OCR.value,
api_key=json.dumps(cfg),
api_base="",
max_tokens=0,
)
return candidate
@classmethod
@DB.connection_context()
def delete_by_tenant_id(cls, tenant_id):
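
Taken together, `_collect_mineru_env_config` and `ensure_mineru_from_env` make MinerU registration idempotent: if the environment-derived config matches an already-saved entry, the existing name is reused; otherwise a fresh `mineru-from-env-N` name is minted. A self-contained sketch of that flow with toy stand-ins for the constants and the saved DB rows (the env key names and the row shape are illustrative, not the repository's):

```python
import itertools
import json
import os

# Toy stand-ins: the real MINERU_ENV_KEYS / MINERU_DEFAULT_CONFIG live in
# common.constants and their exact keys are not shown in this diff.
MINERU_ENV_KEYS = ("MINERU_API_BASE", "MINERU_BACKEND")
MINERU_DEFAULT_CONFIG = {k: "" for k in MINERU_ENV_KEYS}

def collect_env_config() -> dict | None:
    cfg = dict(MINERU_DEFAULT_CONFIG)  # copy so the shared default is not mutated
    found = False
    for key in MINERU_ENV_KEYS:
        val = os.environ.get(key)
        if val:
            found = True
            cfg[key] = val
    return cfg if found else None

def ensure_registered(saved_rows: list[dict]) -> str | None:
    cfg = collect_env_config()
    if cfg is None:
        return None  # no MinerU env vars set: nothing to register
    for row in saved_rows:  # reuse an entry whose stored config matches the env
        stored = json.loads(row.get("api_key") or "{}")
        normalized = {k: stored.get(k, MINERU_DEFAULT_CONFIG[k]) for k in MINERU_ENV_KEYS}
        if normalized == cfg:
            return row["llm_name"]
    used = {row["llm_name"] for row in saved_rows}
    for idx in itertools.count(1):  # mint the first unused mineru-from-env-N name
        candidate = f"mineru-from-env-{idx}"
        if candidate not in used:
            return candidate  # the real method also persists the new entry here
```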