"""PyTorch VideoLLaMA3 model.""" |
|
|
|
import importlib.util |
|
import os.path as osp |
|
import re |
|
from abc import ABC, abstractmethod |
|
from typing import List, Optional, Tuple, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.utils.checkpoint |
|
|
|
from transformers import AutoModel, Qwen2ForCausalLM, Qwen2Model |
|
from transformers.generation.utils import GenerateOutput |
|
from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
|
try: |
|
from .configuration_videollama3 import Videollama3Qwen2Config |
|
except ModuleNotFoundError: |
|
spec = importlib.util.spec_from_file_location( |
|
"configuration_videollama3", |
|
osp.join(osp.dirname(__file__), "configuration_videollama3.py"), |
|
) |
|
configuration_videollama3 = importlib.util.module_from_spec(spec) |
|
spec.loader.exec_module(configuration_videollama3) |
|
Videollama3Qwen2Config = getattr( |
|
configuration_videollama3, |
|
"Videollama3Qwen2Config", |
|
) |
|
|
|
|
|


def build_mlp(depth, hidden_size, output_hidden_size):
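    """Build a `depth`-layer MLP: a first Linear mapping `hidden_size` to
    `output_hidden_size`, followed by (depth - 1) GELU + Linear blocks.
    With depth == 1 this reduces to a single Linear projection."""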
    modules = [nn.Linear(hidden_size, output_hidden_size)]
    for _ in range(1, depth):
        modules.append(nn.GELU())
        modules.append(nn.Linear(output_hidden_size, output_hidden_size))
    return nn.Sequential(*modules)


def build_vision_projector(config, delay_load=False, **kwargs):
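    """Build the multimodal projector that maps vision-encoder features into the
    language-model hidden space, selected by `config.mm_projector_type`
    ("linear" or "mlp{N}x_gelu")."""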
    projector_type = getattr(config, 'mm_projector_type', 'linear')
    if projector_type == "linear":
        return nn.Linear(config.mm_hidden_size, config.hidden_size)
    elif projector_type.startswith("mlp"):
        return MlpGeluProjector(config, projector_type)
    else:
        raise ValueError(f'Unknown projector type: {projector_type}')


class MlpGeluProjector(nn.Module):
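    """MLP projector with GELU activations; the depth is parsed from the
    projector type string "mlp{N}x_gelu"."""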

    def __init__(self, config, projector_type):
        super().__init__()

        mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
        mlp_depth = int(mlp_gelu_match.group(1))

        self.readout = build_mlp(mlp_depth, config.vision_encoder_config.hidden_size, config.hidden_size)

    def forward(self, x):
        x = self.readout(x)
        return x


class Videollama3MetaModel:
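    """Mixin that attaches the vision encoder and the multimodal projector to a
    language-model backbone. The encoder is either loaded from a pretrained
    checkpoint (`config.vision_encoder`) or rebuilt from
    `config.vision_encoder_config`."""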

    def __init__(self, config):
        super(Videollama3MetaModel, self).__init__(config)
        if config.vision_encoder is not None:
            self.vision_encoder = AutoModel.from_pretrained(
                config.vision_encoder,
                attn_implementation=self.config._attn_implementation,
                torch_dtype=self.dtype,
            )
            self.config.vision_encoder_config = self.vision_encoder.config
            self.config.vision_encoder = None
        elif config.vision_encoder_config is not None:
            self.vision_encoder = AutoModel.from_config(
                self.config.vision_encoder_config,
                attn_implementation=self.config._attn_implementation,
                torch_dtype=self.dtype,
            )
        else:
            raise ValueError("Vision encoder is not provided in config")
        self.mm_projector = build_vision_projector(config)

    def get_vision_encoder(self):
        return self.vision_encoder

    def get_mm_projector(self):
        return self.mm_projector


class Videollama3Qwen2Model(Videollama3MetaModel, Qwen2Model):
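    """Qwen2 text backbone combined with the VideoLLaMA3 vision encoder and
    multimodal projector."""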

    config_class = Videollama3Qwen2Config

    def __init__(self, config: Videollama3Qwen2Config):
        super(Videollama3Qwen2Model, self).__init__(config)


class Videollama3MetaForCausalLM(ABC):
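    """Mixin implementing the multimodal side of the causal LM: encoding visual
    inputs, optional token compression, and merging visual features into the
    text embedding sequence."""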

    @abstractmethod
    def get_model(self):
        pass

    def get_vision_encoder(self):
        return self.get_model().get_vision_encoder()

    def get_mm_projector(self):
        return self.get_model().get_mm_projector()

    def encode_images(
        self,
        pixel_values: torch.FloatTensor,
        grid_sizes: torch.LongTensor,
        merge_sizes: torch.LongTensor,
    ) -> torch.FloatTensor:
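        """Encode packed image/video patches with the vision encoder and project
        them into the language-model hidden space via the mm projector."""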
        mm_features = self.get_model().get_vision_encoder()(
            pixel_values=pixel_values,
            grid_sizes=grid_sizes,
            merge_sizes=merge_sizes,
        )
        mm_features = self.get_model().mm_projector(mm_features)
        return mm_features

    def _get_valid_visual_tokens(
        self,
        mm_features: torch.FloatTensor,
        batched_num_patches: torch.LongTensor,
        modals: List[str],
    ):
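        """Drop feature rows produced for samples whose modal is "text" (these
        correspond to dummy visual inputs); rows from image/video samples are
        kept."""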
        valid_masks = []
        for num_patches, modal in zip(batched_num_patches, modals):
            valid_mask = torch.full((num_patches,), modal != "text", dtype=torch.bool, device=mm_features.device)
            valid_masks.append(valid_mask)
        mm_features = mm_features[torch.cat(valid_masks)]
        return mm_features

    def _maybe_truncate_visual_tokens(
        self,
        mm_features: torch.FloatTensor,
        compression_mask: torch.BoolTensor,
        batched_num_patches: torch.LongTensor,
        modals: List[str],
        input_ids: torch.LongTensor,
        position_ids: Optional[torch.LongTensor] = None,
    ):
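        """Truncate surplus visual feature rows when the packed `input_ids`
        contain fewer image placeholder tokens than feature rows (e.g. because a
        sequence was cut during packing). Sequence boundaries are recovered from
        `position_ids == 0`, and each sample keeps at most as many rows as its
        sequence has placeholder tokens. If the counts already match, the inputs
        are returned unchanged."""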
        if position_ids is None or mm_features.shape[0] == input_ids.eq(self.config.image_token_index).sum():
            return mm_features, compression_mask

        truncation_mask = []
        for num_patches, modal in zip(batched_num_patches, modals):
            if modal == "text":
                truncation_mask.append(torch.ones((0,), dtype=torch.bool, device=input_ids.device))
            else:
                truncation_mask.append(torch.ones((num_patches,), dtype=torch.bool, device=input_ids.device))

        seq_end_indices = torch.nonzero(position_ids == 0)[:, 0]
        seq_end_indices = seq_end_indices[seq_end_indices > 0].tolist() + [len(input_ids)]
        seq_start_indices = [0] + seq_end_indices[:-1]
        num_visual_tokens = [
            input_ids[start:end].eq(self.config.image_token_index).sum()
            for start, end in zip(seq_start_indices, seq_end_indices)
        ]

        for n, mask in zip(num_visual_tokens, truncation_mask):
            if len(mask) > 0:
                mask[n:] = False
        truncation_mask = torch.cat(truncation_mask)

        return mm_features[truncation_mask], compression_mask[truncation_mask]

    def _get_compression_mask(
        self,
        pixel_values: torch.FloatTensor,
        batched_num_patches: torch.LongTensor,
        grid_sizes: torch.LongTensor,
        merge_sizes: torch.LongTensor,
        modals: List[str],
        threshold: float = 0.1,
        min_tokens: int = 1,
    ) -> torch.BoolTensor:
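        """Build the boolean mask used for temporal token compression. Images and
        single-frame videos keep all tokens. For videos, a token is kept when the
        mean absolute pixel difference to the previous frame at the same spatial
        position (scaled by 255) exceeds `threshold`; the first frame is always
        kept, and frames falling below `min_tokens` kept tokens have their first
        `min_tokens` positions forced on. Text samples contribute an empty mask."""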
        batched_images = pixel_values.split(grid_sizes.prod(dim=1).tolist(), dim=0)
        compression_masks = []

        for images, num_patches, grid_size, merge_size, modal in zip(
            batched_images, batched_num_patches, grid_sizes, merge_sizes, modals
        ):
            t, h, w = grid_size
            if modal == "image" or (modal == "video" and t == 1):
                compression_masks.append(torch.ones((num_patches,), dtype=torch.bool, device=images.device))

            elif modal == "video":
                images = images.view(t, (h // merge_size) * (w // merge_size), -1)

                pixel_diff = images[1:] - images[:-1]
                pixel_diff = torch.abs(pixel_diff).mean(dim=-1) * 255
                pixel_diff = torch.cat([torch.full_like(pixel_diff[0:1], threshold + 1), pixel_diff], dim=0)
                mask = pixel_diff > threshold
                padding_ids = torch.nonzero(mask.sum(dim=1) < min_tokens)[:, 0]
                mask[padding_ids, :min_tokens] = 1
                compression_masks.append(mask.flatten())

            else:
                compression_masks.append(torch.ones((0,), dtype=torch.bool, device=images.device))

        return torch.cat(compression_masks)

    def _compress_visual_tokens(
        self,
        compression_mask: torch.BoolTensor,
        mm_features: torch.FloatTensor,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
    ):
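        """Apply `compression_mask`: drop the masked-out visual feature rows and
        remove the corresponding image placeholder positions from `input_ids`,
        `attention_mask` and `labels`. `position_ids` are rebuilt as consecutive
        ranges within each packed sequence (boundaries detected via
        `position_ids == 0`)."""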
        mm_features = mm_features[compression_mask]
        image_selected = (input_ids == self.config.image_token_index)

        text_masks = torch.logical_not(image_selected)
        text_masks[image_selected] = compression_mask
        input_ids = input_ids[text_masks]

        if attention_mask is not None:
            attention_mask = attention_mask[text_masks]
        if labels is not None:
            labels = labels[text_masks]
        if position_ids is not None:
            position_ids = position_ids[text_masks]
            pos_start = [0] + torch.nonzero(position_ids == 0)[:, 0].tolist()
            pos_end = pos_start[1:] + [len(input_ids)]
            position_ids = torch.cat([torch.arange(end - start, device=input_ids.device) for start, end in zip(pos_start, pos_end)])

        return mm_features, input_ids, attention_mask, position_ids, labels

    def prepare_inputs_labels_for_multimodal(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        grid_sizes: Optional[torch.LongTensor] = None,
        merge_sizes: Optional[torch.LongTensor] = None,
        modals: Optional[List[str]] = None,
    ):
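        """Turn token ids plus visual inputs into `inputs_embeds` for the language
        model: flatten the batch, encode and project the visual inputs, optionally
        compress/truncate visual tokens, scatter the visual features into the
        positions marked by `config.image_token_index`, and restore the batch
        shape. Returns
        (input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels),
        with `input_ids` set to None once embeddings have been built. Text-only
        inputs (or single-token decoding steps) are returned unchanged with
        `inputs_embeds` as None."""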
        vision_encoder = self.get_vision_encoder()
        # Text-only batches (or single-token decoding steps) need no multimodal merging.
        if vision_encoder is None or pixel_values is None or input_ids.shape[1] == 1:
            return input_ids, attention_mask, position_ids, past_key_values, None, labels

        # Flatten the batch so visual features can be scattered by token position.
        B, N = input_ids.shape
        input_ids = input_ids.view(B * N)
        if attention_mask is not None:
            attention_mask = attention_mask.view(B * N)
        if position_ids is not None:
            position_ids = position_ids.view(B * N)
        if labels is not None:
            labels = labels.view(B * N)

        # Encode and project visual inputs, then drop rows belonging to text-only samples.
        batched_num_patches = grid_sizes.prod(dim=1).div(merge_sizes ** 2).long()
        mm_features = self.encode_images(pixel_values, grid_sizes, merge_sizes)
        mm_features = self._get_valid_visual_tokens(mm_features, batched_num_patches, modals)

        compression_mask = self._get_compression_mask(
            pixel_values, batched_num_patches, grid_sizes, merge_sizes, modals
        )
        mm_features, compression_mask = self._maybe_truncate_visual_tokens(
            mm_features, compression_mask, batched_num_patches, modals, input_ids, position_ids
        )

        if self.config.use_token_compression:
            assert B == 1, "Token compression is only supported for batch_size=1"
            mm_features, input_ids, attention_mask, position_ids, labels = self._compress_visual_tokens(
                compression_mask, mm_features, input_ids, attention_mask, position_ids, labels
            )

        # Replace the placeholder image-token embeddings with the projected visual features.
        inputs_embeds = self.get_model().embed_tokens(input_ids).clone()

        image_selected = (input_ids == self.config.image_token_index)
        inputs_embeds[image_selected] = inputs_embeds[image_selected] * 0.0 + mm_features

        # Restore the batch dimension before handing the embeddings to the LM.
        C = inputs_embeds.shape[-1]
        inputs_embeds = inputs_embeds.reshape(B, -1, C)
        if attention_mask is not None:
            attention_mask = attention_mask.view(B, -1)
        if labels is not None:
            labels = labels.view(B, -1)
        if position_ids is not None:
            position_ids = position_ids.view(B, -1)

        return None, attention_mask, position_ids, past_key_values, inputs_embeds, labels


class Videollama3Qwen2ForCausalLM(Qwen2ForCausalLM, Videollama3MetaForCausalLM):
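    """VideoLLaMA3 causal language model: a Qwen2 LM head on top of
    `Videollama3Qwen2Model`, extended with the multimodal input pipeline from
    `Videollama3MetaForCausalLM`."""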

    config_class = Videollama3Qwen2Config

    def __init__(self, config, **kwargs):
        super(Qwen2ForCausalLM, self).__init__(config)
        self.model = Videollama3Qwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def get_model(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        pixel_values: Optional[torch.FloatTensor] = None,
        grid_sizes: Optional[torch.LongTensor] = None,
        merge_sizes: Optional[torch.LongTensor] = None,
        modals: Optional[List[str]] = None,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
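        """If `inputs_embeds` is not given, visual inputs are merged into the text
        embeddings via `prepare_inputs_labels_for_multimodal`; the result is then
        passed to the standard `Qwen2ForCausalLM.forward`."""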
        if inputs_embeds is None:
            (
                input_ids,
                attention_mask,
                position_ids,
                past_key_values,
                inputs_embeds,
                labels,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                labels=labels,
                pixel_values=pixel_values,
                grid_sizes=grid_sizes,
                merge_sizes=merge_sizes,
                modals=modals,
            )

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            num_logits_to_keep=num_logits_to_keep,
            **loss_kwargs,
        )

    @torch.no_grad()
    def generate(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        grid_sizes: Optional[torch.LongTensor] = None,
        merge_sizes: Optional[torch.LongTensor] = None,
        modals: Optional[List[str]] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
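        """Generation entry point. When `pixel_values` is provided, the prompt is
        first converted to `inputs_embeds` with the visual features merged in;
        otherwise plain token embeddings are used. Generation is then delegated to
        the inherited `generate` with `inputs_embeds`."""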
        input_ids = kwargs.pop("input_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        position_ids = kwargs.pop("position_ids", None)
        past_key_values = kwargs.pop("past_key_values", None)

        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if pixel_values is not None:
            (
                input_ids,
                attention_mask,
                position_ids,
                past_key_values,
                inputs_embeds,
                labels,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                labels=None,
                pixel_values=pixel_values,
                grid_sizes=grid_sizes,
                merge_sizes=merge_sizes,
                modals=modals,
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(input_ids)

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        images = kwargs.pop("images", None)
        _inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            _inputs['images'] = images
        return _inputs
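

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the model definition).
# The checkpoint id and the processor behaviour below are assumptions for
# illustration; the actual tensors come from the checkpoint's own processor,
# loaded with `trust_remote_code=True`.
#
#   import torch
#   from transformers import AutoModelForCausalLM, AutoProcessor
#
#   model_id = "DAMO-NLP-SG/VideoLLaMA3-7B"  # assumed checkpoint id
#   processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
#   model = AutoModelForCausalLM.from_pretrained(
#       model_id, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto"
#   )
#
#   # The processor is expected to return `input_ids`, `attention_mask`,
#   # `pixel_values`, `grid_sizes`, `merge_sizes` and `modals`, which map
#   # directly onto the keyword arguments of `forward`/`generate` above.
#   output_ids = model.generate(**inputs, max_new_tokens=128)
# ---------------------------------------------------------------------------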