"""Inference-only Deepseek-OCR model compatible with HuggingFace weights."""
|
|
import math
|
|
from collections.abc import Iterable, Mapping, Sequence
|
|
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from einops import rearrange, repeat
|
|
from transformers import BatchFeature
|
|
|
|
from vllm.config import VllmConfig
|
|
from vllm.model_executor import SamplingMetadata
|
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
|
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
|
MultiModalKwargs, NestedTensors)
|
|
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
|
ImageSize, MultiModalDataItems)
|
|
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
|
BaseProcessingInfo, PromptReplacement,
|
|
PromptUpdate)
|
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
|
from vllm.sequence import IntermediateTensors
|
|
from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
|
|
MlpProjectorConfig,
|
|
VisionEncoderConfig)
|
|
from process.image_process import (
|
|
DeepseekOCRProcessor, count_tiles)
|
|
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
|
# from vllm.utils import is_list_of
|
|
|
|
from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
|
from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
|
init_vllm_registered_model, maybe_prefix,
|
|
merge_multimodal_embeddings)
|
|
|
|
from deepencoder.sam_vary_sdpa import build_sam_vit_b
|
|
from deepencoder.clip_sdpa import build_clip_l
|
|
from deepencoder.build_linear import MlpProjector
|
|
from addict import Dict
|
|
# import time
|
|
from config import IMAGE_SIZE, BASE_SIZE, CROP_MODE, PRINT_NUM_VIS_TOKENS, PROMPT
|
|
# The image token id may be various
|
|
_IMAGE_TOKEN = "<image>"
|
|
|
|
|
|
class DeepseekOCRProcessingInfo(BaseProcessingInfo):

    def get_hf_config(self):
        return self.ctx.get_hf_config(DeepseekVLV2Config)

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(DeepseekOCRProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": None}

    def get_num_image_tokens(self,
                             *,
                             image_width: int,
                             image_height: int,
                             cropping: bool = True) -> int:
        # The resolution settings are fixed by the run config rather than
        # taken from the HF processor.
        image_size = IMAGE_SIZE
        base_size = BASE_SIZE
        patch_size = 16
        downsample_ratio = 4

        if CROP_MODE:
            if image_width <= 640 and image_height <= 640:
                crop_ratio = [1, 1]
            else:
                # Find the tiling grid whose aspect ratio is closest to that
                # of the input image.
                crop_ratio = count_tiles(image_width,
                                         image_height,
                                         image_size=IMAGE_SIZE)

            num_width_tiles, num_height_tiles = crop_ratio
        else:
            num_width_tiles = num_height_tiles = 1

        # Feature-map side lengths after patchification and downsampling,
        # for the global (base) view and for each local tile.
        h = w = math.ceil((base_size // patch_size) / downsample_ratio)
        h2 = w2 = math.ceil((image_size // patch_size) / downsample_ratio)
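
        # Worked example (a sketch, assuming BASE_SIZE=1024, IMAGE_SIZE=640,
        # and a 2x1 tile grid; the "+1" terms are the per-row newline column
        # and the trailing view separator):
        #   h = w = ceil((1024 // 16) / 4) = 16 -> global view: 16 * 17 = 272
        #   h2 = w2 = ceil((640 // 16) / 4) = 10
        #       -> local views: (1 * 10) * (2 * 10 + 1) = 210
        #   total = 272 + 210 + 1 = 483 image tokens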
        global_views_tokens = h * (w + 1)
        if num_width_tiles > 1 or num_height_tiles > 1:
            local_views_tokens = (num_height_tiles * h2) * (num_width_tiles *
                                                            w2 + 1)
        else:
            local_views_tokens = 0

        # +1 for the view separator appended after the features.
        return global_views_tokens + local_views_tokens + 1

    def get_image_size_with_most_features(self) -> ImageSize:
        if IMAGE_SIZE == 1024 and BASE_SIZE == 1280:
            return ImageSize(width=1024 * 2, height=1024 * 2)
        return ImageSize(width=640 * 2, height=640 * 2)


class DeepseekOCRDummyInputsBuilder(
        BaseDummyInputsBuilder[DeepseekOCRProcessingInfo]):

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_images = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        image_token = processor.image_token

        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)

        max_image_size = self.info.get_image_size_with_most_features()

        # Only build dummy image inputs when the prompt actually contains
        # an image placeholder.
        if _IMAGE_TOKEN in PROMPT:
            return {
                "image":
                DeepseekOCRProcessor().tokenize_with_images(
                    images=self._get_dummy_images(
                        width=max_image_size.width,
                        height=max_image_size.height,
                        num_images=num_images),
                    bos=True,
                    eos=True,
                    cropping=CROP_MODE)
            }
        else:
            return {"image": []}


class DeepseekOCRMultiModalProcessor(
        BaseMultiModalProcessor[DeepseekOCRProcessingInfo]):

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        if mm_data:
            processed_outputs = self.info.ctx.call_hf_processor(
                self.info.get_hf_processor(**mm_kwargs),
                dict(prompt=prompt, **mm_data),
                mm_kwargs,
            )
        else:
            # Text-only prompt: fall back to the plain tokenizer.
            tokenizer = self.info.get_tokenizer()
            processed_outputs = tokenizer(prompt,
                                          add_special_tokens=True,
                                          return_tensors="pt")

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            images_spatial_crop=MultiModalFieldConfig.batched("image"),
            images_crop=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        image_token_id = hf_processor.image_token_id
        assert isinstance(image_token_id, int)

        def get_replacement_deepseek_vl2(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                num_image_tokens = images.get_feature_size(item_idx)
            else:
                # The last element of the processor output carries the
                # original (width, height) of each image.
                width = images[0][-1][0][0]
                height = images[0][-1][0][1]

                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=width,
                    image_height=height,
                    cropping=CROP_MODE,
                )
            return [image_token_id] * num_image_tokens
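
        # Each single <image> placeholder in the prompt is expanded to
        # num_image_tokens copies of the image token id (e.g. 483 ids for
        # the worked example above); the vision embeddings are later
        # scattered into exactly those positions.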
        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement_deepseek_vl2,
            )
        ]

    def _cached_apply_hf_processor(
        self,
        prompt: Union[str, list[int]],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs, bool]:
        # The processor logic is different for len(images) <= 2 vs > 2.
        # Since the processing cache assumes that the processor output is
        # invariant of how many images are passed per prompt, we only
        # perform caching for the most common case.
        if mm_data_items.get_count("image", strict=False) > 2:
            # This code path corresponds to the cache being disabled.
            return self._apply_hf_processor_main(
                prompt=prompt,
                mm_items=mm_data_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                enable_hf_prompt_update=True,
            )

        return super()._cached_apply_hf_processor(
            prompt=prompt,
            mm_data_items=mm_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )


@MULTIMODAL_REGISTRY.register_processor(
    DeepseekOCRMultiModalProcessor,
    info=DeepseekOCRProcessingInfo,
    dummy_inputs=DeepseekOCRDummyInputsBuilder)
class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):

    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
        "language.": "language_model.",
    })

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config: DeepseekVLV2Config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        self.vision_config = config.vision_config
        self.projector_config = config.projector_config
        self.text_config = config.text_config

        model_config = vllm_config.model_config
        tokenizer = cached_tokenizer_from_config(model_config)
        self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN]

        # Two-stage vision encoder: a SAM ViT-B backbone followed by a
        # CLIP-L tower that also consumes the SAM features.
        self.sam_model = build_sam_vit_b()
        self.vision_model = build_clip_l()

        n_embed = 1280
        self.projector = MlpProjector(
            Dict(projector_type="linear", input_dim=2048, n_embed=n_embed))
        self.tile_tag = config.tile_tag
        self.global_view_pos = config.global_view_pos

        # Optionally compile the encoders for speed:
        # self.sam_model = torch.compile(self.sam_model, mode="reduce-overhead")
        # self.vision_model = torch.compile(self.vision_model, mode="reduce-overhead")
        # self.projector = torch.compile(self.projector, mode="max-autotune")

        # Special tokens for the image token sequence format: <|\n|> after
        # each feature row and <|view_separator|> between the views. (The
        # "seperator" spelling matches the checkpoint weight names.)
        embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
        if self.tile_tag == "2D":
            self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
            self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)
        else:
            raise ValueError(
                f"Only 2D tile_tag is supported currently, got: {self.tile_tag}"
            )

        # Pick the language-model architecture that matches the text config.
        if self.text_config.topk_method == "noaux_tc":
            architectures = ["DeepseekV3ForCausalLM"]
        elif not self.text_config.use_mla:
            architectures = ["DeepseekForCausalLM"]
        else:
            architectures = ["DeepseekV2ForCausalLM"]

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=self.text_config,
            prefix=maybe_prefix(prefix, "language"),
            architectures=architectures,
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _parse_and_validate_image_input(self, **kwargs: object):
        pixel_values = kwargs.pop("pixel_values", None)
        images_spatial_crop = kwargs.pop("images_spatial_crop", None)
        images_crop = kwargs.pop("images_crop", None)

        # An all-zero pixel_values tensor marks a text-only request.
        if pixel_values is None or torch.sum(pixel_values).item() == 0:
            return None

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        if not isinstance(images_spatial_crop, (torch.Tensor, list)):
            raise ValueError("Incorrect type of image sizes. "
                             f"Got type: {type(images_spatial_crop)}")

        if not isinstance(images_crop, (torch.Tensor, list)):
            raise ValueError("Incorrect type of image crop. "
                             f"Got type: {type(images_crop)}")

        return [pixel_values, images_crop, images_spatial_crop]

    def _pixel_values_to_embedding(
        self,
        pixel_values: torch.Tensor,
        images_crop: torch.Tensor,
        images_spatial_crop: torch.Tensor,
    ) -> NestedTensors:
        # pixel_values (global view): [n_image, batch_size, 3, height, width]
        # images_spatial_crop: [n_image, batch_size, (num_tiles_w, num_tiles_h)]
        # images_crop (local view): [n_image, batch_size, num_patches, 3, h, w]
        # The pixel values and crops are split per image; batch_size is
        # always 1 here.
        images_in_this_batch = []
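
        # Per-view encoder pipeline (a sketch; exact feature shapes depend
        # on the deepencoder builds):
        #   1. sam_model:    pixels -> spatial feature map
        #   2. vision_model: pixels + SAM features -> CLIP token sequence
        #   3. concat:       CLIP tokens (CLS dropped) with flattened SAM
        #                    features along the channel dim (2048 total)
        #   4. projector:    linear 2048 -> 1280 to match the LM embeddings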
        with torch.no_grad():
            for jdx in range(images_spatial_crop.size(0)):
                patches = images_crop[jdx][0].to(torch.bfloat16)  # batch_size = 1
                image_ori = pixel_values[jdx]
                crop_shape = images_spatial_crop[jdx][0]

                if torch.sum(patches).item() != 0:  # all-zero patches mean no crop
                    # Local (tile) features.
                    local_features_1 = self.sam_model(patches)
                    local_features_2 = self.vision_model(patches, local_features_1)
                    local_features = torch.cat(
                        (local_features_2[:, 1:],
                         local_features_1.flatten(2).permute(0, 2, 1)),
                        dim=-1)
                    local_features = self.projector(local_features)

                    # Global (base) features.
                    global_features_1 = self.sam_model(image_ori)
                    global_features_2 = self.vision_model(image_ori,
                                                          global_features_1)
                    global_features = torch.cat(
                        (global_features_2[:, 1:],
                         global_features_1.flatten(2).permute(0, 2, 1)),
                        dim=-1)
                    global_features = self.projector(global_features)

                    if PRINT_NUM_VIS_TOKENS:
                        print('=====================')
                        print('BASE: ', global_features.shape)
                        print('PATCHES: ', local_features.shape)
                        print('=====================')

                    _, hw, n_dim = global_features.shape
                    h = w = int(hw ** 0.5)

                    _2, hw2, n_dim2 = local_features.shape
                    h2 = w2 = int(hw2 ** 0.5)

                    width_crop_num, height_crop_num = crop_shape[0], crop_shape[1]

                    # Append a newline embedding after each row of the
                    # global feature grid, then flatten.
                    global_features = global_features.view(h, w, n_dim)
                    global_features = torch.cat(
                        [global_features,
                         self.image_newline[None, None, :].expand(h, 1, n_dim)],
                        dim=1)
                    global_features = global_features.view(-1, n_dim)

                    # Stitch the tile grid back into one large feature map,
                    # append a newline embedding per row, then flatten.
                    local_features = local_features.view(
                        height_crop_num, width_crop_num, h2, w2,
                        n_dim2).permute(0, 2, 1, 3, 4).reshape(
                            height_crop_num * h2, width_crop_num * w2, n_dim2)
                    local_features = torch.cat(
                        [local_features,
                         self.image_newline[None, None, :].expand(
                             height_crop_num * h2, 1, n_dim2)],
                        dim=1)
                    local_features = local_features.view(-1, n_dim2)

                    # Local view first, then global view, then a separator.
                    global_local_features = torch.cat(
                        [local_features, global_features,
                         self.view_seperator[None, :]],
                        dim=0)
                else:
                    # No crops: encode the global view only.
                    global_features_1 = self.sam_model(image_ori)
                    global_features_2 = self.vision_model(image_ori,
                                                          global_features_1)
                    global_features = torch.cat(
                        (global_features_2[:, 1:],
                         global_features_1.flatten(2).permute(0, 2, 1)),
                        dim=-1)
                    global_features = self.projector(global_features)

                    if PRINT_NUM_VIS_TOKENS:
                        print('=====================')
                        print('BASE: ', global_features.shape)
                        print('NO PATCHES')
                        print('=====================')

                    _, hw, n_dim = global_features.shape
                    h = w = int(hw ** 0.5)

                    global_features = global_features.view(h, w, n_dim)
                    global_features = torch.cat(
                        [global_features,
                         self.image_newline[None, None, :].expand(h, 1, n_dim)],
                        dim=1)
                    global_features = global_features.view(-1, n_dim)

                    global_local_features = torch.cat(
                        [global_features, self.view_seperator[None, :]], dim=0)

                images_in_this_batch.append(global_local_features)

        return images_in_this_batch

    def _process_image_input(self, image_input) -> torch.Tensor:
        # image_input: [pixel_values, images_crop, images_spatial_crop]
        pixel_values = image_input[0].to(torch.bfloat16)
        images_crop = image_input[1]
        images_spatial_crop = image_input[2].to(dtype=torch.long)

        vision_features = self._pixel_values_to_embedding(
            pixel_values=pixel_values,
            images_crop=images_crop,
            images_spatial_crop=images_spatial_crop)

        return vision_features

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)

        if multimodal_embeddings is not None:
            # Scatter the vision embeddings into the positions occupied by
            # image placeholder tokens.
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.image_token_id)

        return inputs_embeds

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs: object):
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner,
        # this condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model(input_ids,
                                            positions,
                                            intermediate_tensors,
                                            inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def load_weights(self,
                     weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        processed_weights = []
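
        # Route checkpoint names to the right submodule before loading:
        # vision/projector weights arrive prefixed with "model.", while the
        # language model arrives unprefixed. E.g. (illustrative names):
        #   model.sam_model.xxx -> sam_model.xxx
        #   lm_head.weight -> language.lm_head.weight
        #       -> language_model.lm_head.weight (via hf_to_vllm_mapper)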
        for name, tensor in weights:
            if ('sam_model' in name or 'vision_model' in name
                    or 'projector' in name or 'image_newline' in name
                    or 'view_seperator' in name):
                new_name = name.replace('model.', '', 1)
            else:
                new_name = 'language.' + name

            processed_weights.append((new_name, tensor))

        loader = AutoWeightsLoader(self)
        autoloaded_weights = loader.load_weights(processed_weights,
                                                 mapper=self.hf_to_vllm_mapper)

        return autoloaded_weights
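
# Usage sketch (hypothetical; model id and image variable are assumptions,
# and generation goes through vLLM's usual entry point once this module is
# registered):
#
#   from vllm import LLM, SamplingParams
#   llm = LLM(model="deepseek-ai/DeepSeek-OCR", trust_remote_code=True)
#   out = llm.generate({
#       "prompt": PROMPT,
#       "multi_modal_data": {"image": image},
#   }, SamplingParams(max_tokens=512))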