mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-23 06:46:40 +08:00
Refa: treat MinerU as an OCR model (#11849)
### What problem does this PR solve? Treat MinerU as an OCR model. ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Refactoring
This commit is contained in:
@ -25,7 +25,7 @@ from api.utils.api_utils import get_allowed_llm_factories, get_data_error_result
|
||||
from common.constants import StatusEnum, LLMType
|
||||
from api.db.db_models import TenantLLM
|
||||
from rag.utils.base64_image import test_image
|
||||
from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel
|
||||
from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel, OcrModel
|
||||
|
||||
|
||||
@manager.route("/factories", methods=["GET"]) # noqa: F821
|
||||
@ -43,7 +43,13 @@ def factories():
|
||||
mdl_types[m.fid] = set([])
|
||||
mdl_types[m.fid].add(m.model_type)
|
||||
for f in fac:
|
||||
f["model_types"] = list(mdl_types.get(f["name"], [LLMType.CHAT, LLMType.EMBEDDING, LLMType.RERANK, LLMType.IMAGE2TEXT, LLMType.SPEECH2TEXT, LLMType.TTS]))
|
||||
f["model_types"] = list(
|
||||
mdl_types.get(
|
||||
f["name"],
|
||||
[LLMType.CHAT, LLMType.EMBEDDING, LLMType.RERANK, LLMType.IMAGE2TEXT, LLMType.SPEECH2TEXT, LLMType.TTS, LLMType.OCR],
|
||||
)
|
||||
)
|
||||
|
||||
return get_json_result(data=fac)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -251,6 +257,15 @@ async def add_llm():
|
||||
pass
|
||||
except RuntimeError as e:
|
||||
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
|
||||
elif llm["model_type"] == LLMType.OCR.value:
|
||||
assert factory in OcrModel, f"OCR model from {factory} is not supported yet."
|
||||
try:
|
||||
mdl = OcrModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm.get("api_base", ""))
|
||||
ok, reason = mdl.check_available()
|
||||
if not ok:
|
||||
raise RuntimeError(reason or "Model not available")
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
|
||||
else:
|
||||
# TODO: check other type of models
|
||||
pass
|
||||
@ -297,6 +312,7 @@ async def delete_factory():
|
||||
@login_required
|
||||
def my_llms():
|
||||
try:
|
||||
TenantLLMService.ensure_mineru_from_env(current_user.id)
|
||||
include_details = request.args.get("include_details", "false").lower() == "true"
|
||||
|
||||
if include_details:
|
||||
@ -344,6 +360,7 @@ def list_app():
|
||||
weighted = []
|
||||
model_type = request.args.get("model_type")
|
||||
try:
|
||||
TenantLLMService.ensure_mineru_from_env(current_user.id)
|
||||
objs = TenantLLMService.query(tenant_id=current_user.id)
|
||||
facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key and o.status == StatusEnum.VALID.value])
|
||||
status = {(o.llm_name + "@" + o.llm_factory) for o in objs if o.status == StatusEnum.VALID.value}
|
||||
|
||||
@ -14,15 +14,16 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from langfuse import Langfuse
|
||||
from common import settings
|
||||
from common.constants import LLMType
|
||||
from common.constants import MINERU_DEFAULT_CONFIG, MINERU_ENV_KEYS, LLMType
|
||||
from api.db.db_models import DB, LLMFactories, TenantLLM
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.langfuse_service import TenantLangfuseService
|
||||
from api.db.services.user_service import TenantService
|
||||
from rag.llm import ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel, TTSModel
|
||||
from rag.llm import ChatModel, CvModel, EmbeddingModel, OcrModel, RerankModel, Seq2txtModel, TTSModel
|
||||
|
||||
|
||||
class LLMFactoriesService(CommonService):
|
||||
@ -104,6 +105,10 @@ class TenantLLMService(CommonService):
|
||||
mdlnm = tenant.rerank_id if not llm_name else llm_name
|
||||
elif llm_type == LLMType.TTS:
|
||||
mdlnm = tenant.tts_id if not llm_name else llm_name
|
||||
elif llm_type == LLMType.OCR:
|
||||
if not llm_name:
|
||||
raise LookupError("OCR model name is required")
|
||||
mdlnm = llm_name
|
||||
else:
|
||||
assert False, "LLM type error"
|
||||
|
||||
@ -137,31 +142,31 @@ class TenantLLMService(CommonService):
|
||||
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
|
||||
base_url=model_config["api_base"])
|
||||
|
||||
if llm_type == LLMType.RERANK:
|
||||
elif llm_type == LLMType.RERANK:
|
||||
if model_config["llm_factory"] not in RerankModel:
|
||||
return None
|
||||
return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
|
||||
base_url=model_config["api_base"])
|
||||
|
||||
if llm_type == LLMType.IMAGE2TEXT.value:
|
||||
elif llm_type == LLMType.IMAGE2TEXT.value:
|
||||
if model_config["llm_factory"] not in CvModel:
|
||||
return None
|
||||
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang,
|
||||
base_url=model_config["api_base"], **kwargs)
|
||||
|
||||
if llm_type == LLMType.CHAT.value:
|
||||
elif llm_type == LLMType.CHAT.value:
|
||||
if model_config["llm_factory"] not in ChatModel:
|
||||
return None
|
||||
return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
|
||||
base_url=model_config["api_base"], **kwargs)
|
||||
|
||||
if llm_type == LLMType.SPEECH2TEXT:
|
||||
elif llm_type == LLMType.SPEECH2TEXT:
|
||||
if model_config["llm_factory"] not in Seq2txtModel:
|
||||
return None
|
||||
return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"],
|
||||
model_name=model_config["llm_name"], lang=lang,
|
||||
base_url=model_config["api_base"])
|
||||
if llm_type == LLMType.TTS:
|
||||
elif llm_type == LLMType.TTS:
|
||||
if model_config["llm_factory"] not in TTSModel:
|
||||
return None
|
||||
return TTSModel[model_config["llm_factory"]](
|
||||
@ -169,6 +174,17 @@ class TenantLLMService(CommonService):
|
||||
model_config["llm_name"],
|
||||
base_url=model_config["api_base"],
|
||||
)
|
||||
|
||||
elif llm_type == LLMType.OCR:
|
||||
if model_config["llm_factory"] not in OcrModel:
|
||||
return None
|
||||
return OcrModel[model_config["llm_factory"]](
|
||||
key=model_config["api_key"],
|
||||
model_name=model_config["llm_name"],
|
||||
base_url=model_config.get("api_base", ""),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@ -186,6 +202,7 @@ class TenantLLMService(CommonService):
|
||||
LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name,
|
||||
LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name,
|
||||
LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name,
|
||||
LLMType.OCR.value: llm_name,
|
||||
}
|
||||
|
||||
mdlnm = llm_map.get(llm_type)
|
||||
@ -218,6 +235,61 @@ class TenantLLMService(CommonService):
|
||||
~(cls.model.llm_name == "text-embedding-3-large")).dicts()
|
||||
return list(objs)
|
||||
|
||||
@classmethod
def _collect_mineru_env_config(cls) -> dict | None:
    """Assemble a MinerU OCR config from environment variables.

    Starts from ``MINERU_DEFAULT_CONFIG`` and overlays every key in
    ``MINERU_ENV_KEYS`` that is set (non-empty) in the environment.

    Returns:
        The merged config dict when at least one MinerU env key is set,
        otherwise ``None`` (meaning "no env-provided MinerU config").
    """
    # BUGFIX: take a copy of the defaults. The original aliased the
    # module-level MINERU_DEFAULT_CONFIG constant and then wrote env
    # values into it, permanently polluting the shared defaults for
    # every later caller (and skewing the normalized-config comparison
    # in ensure_mineru_from_env).
    cfg = dict(MINERU_DEFAULT_CONFIG)
    found = False
    for key in MINERU_ENV_KEYS:
        val = os.environ.get(key)
        if val:
            found = True
            cfg[key] = val
    return cfg if found else None
|
||||
|
||||
@classmethod
@DB.connection_context()
def ensure_mineru_from_env(cls, tenant_id: str) -> str | None:
    """
    Ensure a MinerU OCR model exists for the tenant if env variables are present.
    Return the existing or newly created llm_name, or None if env not set.
    """
    env_cfg = cls._collect_mineru_env_config()
    if not env_cfg:
        # No MinerU env variables set — nothing to provision.
        return None

    existing = cls.query(tenant_id=tenant_id, llm_factory="MinerU", model_type=LLMType.OCR.value)

    def _effective_cfg(record) -> dict:
        # Decode the stored api_key JSON payload, tolerating empty or
        # corrupt values, then fill gaps from the defaults so partial
        # configs compare correctly against the env-derived one.
        try:
            stored = json.loads(record.api_key or "{}")
        except Exception:
            stored = {}
        return {key: stored.get(key, MINERU_DEFAULT_CONFIG.get(key)) for key in MINERU_ENV_KEYS}

    # Reuse any saved MinerU record whose effective config already
    # matches the environment.
    for record in existing:
        if _effective_cfg(record) == env_cfg:
            return record.llm_name

    # Otherwise mint the first unused "mineru-from-env-<n>" name.
    taken = {record.llm_name for record in existing}
    suffix = 1
    while f"mineru-from-env-{suffix}" in taken:
        suffix += 1
    new_name = f"mineru-from-env-{suffix}"

    cls.save(
        tenant_id=tenant_id,
        llm_factory="MinerU",
        llm_name=new_name,
        model_type=LLMType.OCR.value,
        api_key=json.dumps(env_cfg),
        api_base="",
        max_tokens=0,
    )
    return new_name
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_tenant_id(cls, tenant_id):
|
||||
|
||||
Reference in New Issue
Block a user