diff --git "a/modeling_xgenmm.py" "b/modeling_xgenmm.py" new file mode 100644--- /dev/null +++ "b/modeling_xgenmm.py" @@ -0,0 +1,2516 @@ +import ast +import math +from einops import rearrange, repeat +from einops_exts import rearrange_many +from einops import rearrange +from PIL import Image +import torch +from torch import einsum, nn + +import numpy + + +from typing import List, Optional, Tuple, Union +import torch.nn.functional as F +from transformers.modeling_outputs import CausalLMOutputWithPast +from dataclasses import dataclass +from transformers import CLIPVisionModel +from transformers import PreTrainedModel, AutoModelForCausalLM, AutoModel +from transformers import PretrainedConfig, logging, CONFIG_MAPPING +from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer + + +logger = logging.get_logger(__name__) + + +class XGenMMVisionEncoderConfig(PretrainedConfig): + model_type = "xgenmm_vision_encoder" + + def __init__( + self, + model_name: str = "google/siglip-so400m-patch14-384", + anyres_grids: list[int] = [ + [384, 768], + [768, 384], + [768, 768], + [1152, 384], + [384, 1152], + ], + **kwargs, + ): + self.model_name = model_name + self.anyres_grids = anyres_grids + super().__init__(**kwargs) + + +class XGenMMVisionTokenizerConfig(PretrainedConfig): + model_type = "xgenmm_vision_tokenizer" + + def __init__( + self, + vis_feature_dim: int = 1152, + lang_embedding_dim: int = 3072, + num_vis_tokens: int = 128, + image_aspect_ratio: str = "anyres", + **kwargs, + ): + self.vis_feature_dim = vis_feature_dim + self.lang_embedding_dim = lang_embedding_dim + self.num_vis_tokens = num_vis_tokens + self.image_aspect_ratio = image_aspect_ratio + super().__init__(**kwargs) + + +class XGenMMConfig(PretrainedConfig): + model_type = "xgenmm" + + def __init__( + self, + vision_encoder_config: dict = None, + vision_tokenizer_config: dict = None, + text_config: dict = None, + **kwargs, + ): + + if vision_encoder_config is None: + vision_encoder_config = { + "image_aspect_ratio": "pad", + "anyres_patch_sampling": False, + } + logger.info( + "vision_encoder_config is None. initializing the XGenMMVisionEncoderConfig with default values." + ) + + if vision_tokenizer_config is None: + vision_tokenizer_config = {} + logger.info( + "vision_tokenizer_config is None. Initializing the XGenMMVisionTokenizerConfig with default values." 
+ ) + + if text_config is None: + text_config = { + "initial_tokenizer_len": 32012, + "pad_token_id": 32011, + "bos_token_id": 1, + "eos_token_id": 32000, + "vocab_size": 32064, + "hidden_size": 3072, + "intermediate_size": 8192, + "num_hidden_layers": 32, + "num_attention_heads": 32, + "num_key_value_heads": 32, + "resid_pdrop": 0.0, + "embd_pdrop": 0.0, + "attention_dropout": 0.0, + "hidden_act": "silu", + "max_position_embeddings": 4096, + "original_max_position_embeddings": 4096, + "initializer_range": 0.02, + "rms_norm_eps": 1e-05, + "use_cache": True, + "rope_theta": 10000.0, + "rope_scaling": None, + "sliding_window": 2047, + "return_dict": True, + "output_hidden_states": False, + "output_attentions": False, + "torchscript": False, + "torch_dtype": "bfloat16", + "use_bfloat16": False, + "tf_legacy_loss": False, + "pruned_heads": {}, + "tie_word_embeddings": False, + "chunk_size_feed_forward": 0, + "is_encoder_decoder": False, + "is_decoder": False, + "cross_attention_hidden_size": None, + "add_cross_attention": False, + "tie_encoder_decoder": False, + "max_length": 20, + "min_length": 0, + "do_sample": False, + "early_stopping": False, + "num_beams": 1, + "num_beam_groups": 1, + "diversity_penalty": 0.0, + "temperature": 1.0, + "top_k": 50, + "top_p": 1.0, + "typical_p": 1.0, + "repetition_penalty": 1.0, + "length_penalty": 1.0, + "no_repeat_ngram_size": 0, + "encoder_no_repeat_ngram_size": 0, + "bad_words_ids": None, + "num_return_sequences": 1, + "output_scores": False, + "return_dict_in_generate": False, + "forced_bos_token_id": None, + "forced_eos_token_id": None, + "remove_invalid_values": False, + "exponential_decay_length_penalty": None, + "suppress_tokens": None, + "begin_suppress_tokens": None, + "finetuning_task": None, + "id2label": {0: "LABEL_0", 1: "LABEL_1"}, + "label2id": {"LABEL_0": 0, "LABEL_1": 1}, + "tokenizer_class": None, + "prefix": None, + "bos_token_id": 1, + "pad_token_id": 32000, + "eos_token_id": 32000, + "sep_token_id": None, + "decoder_start_token_id": None, + "task_specific_params": None, + "problem_type": None, + "model_type": "phi3", + } + logger.info( + "text_config is None. Initializing the text config with default values (`Phi3Config`)." 
+ ) + + self.vision_encoder_config = XGenMMVisionEncoderConfig(**vision_encoder_config) + + self.vision_tokenizer_config = XGenMMVisionTokenizerConfig( + **vision_tokenizer_config + ) + + text_model_type = ( + text_config["model_type"] if "model_type" in text_config else "phi3" + ) + self.text_config = CONFIG_MAPPING[text_model_type](**text_config) + + for key in ["initial_tokenizer_len", "pad_token_id"]: + if key not in self.text_config.to_dict(): + raise ValueError(f"The key `{key}` is missing in the text_config.") + + super().__init__(**kwargs) + + +def hasattr_recursive(obj, att): + """ + Check if obj has nested attribute + Example: hasattr_recursive(obj, 'a.b.c') is equivalent to hasattr(obj, 'a') and hasattr(obj.a, 'b') and hasattr(obj.a.b, 'c') + """ + if att == "": + return True + i = att.find(".") + if i < 0: + return hasattr(obj, att) + else: + try: + return hasattr_recursive(getattr(obj, att[:i]), att[i + 1 :]) + except: + return False + + +def getattr_recursive(obj, att): + """ + Return nested attribute of obj + Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c + """ + if att == "": + return obj + i = att.find(".") + if i < 0: + return getattr(obj, att) + else: + return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :]) + + +def setattr_recursive(obj, att, val): + """ + Set nested attribute of obj + Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val + """ + if "." in att: + obj = getattr_recursive(obj, ".".join(att.split(".")[:-1])) + setattr(obj, att.split(".")[-1], val) + + +def check_embedding_fns(lang_model): + """Checks for and attempts to set {get/set}_{input/output}_embeddings functions to the model""" + if not has_fn(lang_model, "get_input_embeddings"): + if hasattr_recursive(lang_model, "transformer.wte"): # MPT + lang_model.get_input_embeddings = lambda: lang_model.transformer.wte + elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"): # OPT + lang_model.get_input_embeddings = lambda: lang_model.decoder.embed_tokens + else: + raise ValueError( + "We require the language encoder to have a get_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py." + ) + + if not has_fn(lang_model, "set_input_embeddings"): + if hasattr_recursive(lang_model, "transformer.wte"): # MPT + lang_model.set_input_embeddings = lambda x: setattr_recursive( + lang_model, "transformer.wte", x + ) + elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"): # OPT + lang_model.set_input_embeddings = lambda x: setattr_recursive( + lang_model, "model.decoder.embed_tokens", x + ) + else: + raise ValueError( + "We require the language encoder to have a set_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py." + ) + + if not has_fn(lang_model, "get_output_embeddings"): + if hasattr_recursive(lang_model, "lm_head"): + lang_model.get_output_embeddings = lambda: lang_model.lm_head + else: + raise ValueError( + "We require the language encoder to have a get_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py." 
+ ) + + if not has_fn(lang_model, "set_output_embeddings"): + if hasattr_recursive(lang_model, "lm_head"): + lang_model.set_output_embeddings = lambda x: setattr_recursive( + lang_model, "lm_head", x + ) + else: + raise ValueError( + "We require the language encoder to have a set_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py." + ) + + +def has_fn(model, fn_name): + """Check if model has a function fn_name""" + return callable(getattr(model, fn_name, None)) + + +def stack_with_padding(list_of_tensors, padding_value=0, padding_side="right"): + """ + Stack a list of tensors with padding on one side + Args: + list_of_tensors (list[torch.Tensor]): List of tensors to stack + padding_value (int, optional): Value to pad with. Defaults to 0. + padding_side (str, optional): Side to pad on. Defaults to "right". + Returns: + torch.Tensor: Stacked tensors + """ + max_tokens = max(tensor.size(0) for tensor in list_of_tensors) + padded_tensors = [] + for tensor in list_of_tensors: + num_tokens = tensor.size(0) + if len(tensor.size()) == 1: + padding = torch.full( + (max_tokens - num_tokens,), + padding_value, + dtype=tensor.dtype, + device=tensor.device, + ) + else: + padding = torch.full( + (max_tokens - num_tokens, tensor.size(1)), + padding_value, + dtype=tensor.dtype, + device=tensor.device, + ) + padded_tensor = ( + torch.cat((tensor, padding), dim=0) + if padding_side == "right" + else torch.cat((padding, tensor), dim=0) + ) + padded_tensors.append(padded_tensor) + return torch.stack(padded_tensors) + + +def unpad_image(tensor, original_size, keep_original_shape=False): + """ + Unpads a PyTorch tensor of a padded and resized image. + + Args: + tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format. + original_size (tuple): The original size of the image (height, width). + + Returns: + torch.Tensor: The unpadded image tensor. + """ + original_width, original_height = original_size + current_height, current_width = tensor.shape[1:] + + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if original_aspect_ratio > current_aspect_ratio: + scale_factor = current_width / original_width + new_height = int(original_height * scale_factor) + padding = (current_height - new_height) // 2 + if keep_original_shape: + attention_mask = torch.ones( + (current_height, current_width), device=tensor.device + ) + attention_mask[:padding, :] = 0 + attention_mask[current_height - padding :, :] = 0 + return tensor, attention_mask + else: + unpadded_tensor = tensor[:, padding : current_height - padding, :] + return unpadded_tensor, None + else: + scale_factor = current_height / original_height + new_width = int(original_width * scale_factor) + padding = (current_width - new_width) // 2 + if keep_original_shape: + attention_mask = torch.ones( + (current_height, current_width), device=tensor.device + ) + attention_mask[:, :padding] = 0 + attention_mask[:, current_width - padding :] = 0 + return tensor, attention_mask + else: + unpadded_tensor = tensor[:, :, padding : current_width - padding] + return unpadded_tensor, None + + +def select_best_resolution(original_size, possible_resolutions): + """ + Selects the best resolution from a list of possible resolutions based on the original size. + + Args: + original_size (tuple): The original size of the image in the format (width, height). 
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. + + Returns: + tuple: The best fit resolution in the format (width, height). + """ + original_width, original_height = original_size + best_fit = None + max_effective_resolution = 0 + min_wasted_resolution = float("inf") + + for width, height in possible_resolutions: + scale = min(width / original_width, height / original_height) + downscaled_width, downscaled_height = int(original_width * scale), int( + original_height * scale + ) + effective_resolution = min( + downscaled_width * downscaled_height, original_width * original_height + ) + wasted_resolution = (width * height) - effective_resolution + + if effective_resolution > max_effective_resolution or ( + effective_resolution == max_effective_resolution + and wasted_resolution < min_wasted_resolution + ): + max_effective_resolution = effective_resolution + min_wasted_resolution = wasted_resolution + best_fit = (width, height) + + return best_fit + + +def resize_and_pad_image(image, target_resolution): + """ + Resize and pad an image to a target resolution while maintaining aspect ratio. + + Args: + image (PIL.Image.Image): The input image. + target_resolution (tuple): The target resolution (width, height) of the image. + + Returns: + PIL.Image.Image: The resized and padded image. + """ + original_width, original_height = image.size + target_width, target_height = target_resolution + + scale_w = target_width / original_width + scale_h = target_height / original_height + + if scale_w < scale_h: + new_width = target_width + new_height = min(math.ceil(original_height * scale_w), target_height) + else: + new_height = target_height + new_width = min(math.ceil(original_width * scale_h), target_width) + + # Resize the image + resized_image = image.resize((new_width, new_height)) + + new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0)) + paste_x = (target_width - new_width) // 2 + paste_y = (target_height - new_height) // 2 + new_image.paste(resized_image, (paste_x, paste_y)) + + return new_image + + +def divide_to_patches(image, patch_size): + """ + Divides an image into patches of a specified size. + + Args: + image (PIL.Image.Image): The input image. + patch_size (int): The size of each patch. + + Returns: + list: A list of PIL.Image.Image objects representing the patches. + """ + patches = [] + width, height = image.size + for i in range(0, height, patch_size): + for j in range(0, width, patch_size): + box = (j, i, j + patch_size, i + patch_size) + patch = image.crop(box) + patches.append(patch) + + return patches + + +def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): + """ + Calculate the shape of the image patch grid after the preprocessing for images of any resolution. + + Args: + image_size (tuple): The size of the input image in the format (width, height). + grid_pinpoints (str): A string representation of a list of possible resolutions. + patch_size (int): The size of each image patch. + + Returns: + tuple: The shape of the image patch grid in the format (width, height). + """ + if type(grid_pinpoints) is list: + possible_resolutions = grid_pinpoints + else: + possible_resolutions = ast.literal_eval(grid_pinpoints) + width, height = select_best_resolution(image_size, possible_resolutions) + return width // patch_size, height // patch_size + + +def process_anyres_image(image, processor, grid_pinpoints): + """ + Process an image with variable resolutions. 
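+    Example (illustrative; assumes the processor's first transform exposes
+    size (384, 384) and each processed patch is a 3x384x384 tensor): a 500x700
+    image with grid_pinpoints "[[384, 768], [768, 384]]" is padded to 384x768,
+    split into two 384x384 patches, and stacked together with a resized copy of
+    the full image (patch 0) into a [3, C, 384, 384] tensor.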
+ + Args: + image (PIL.Image.Image): The input image to be processed. + processor: The image processor object. + grid_pinpoints (str): A string representation of a list of possible resolutions. + + Returns: + torch.Tensor: A tensor containing the processed image patches. + """ + # FIXME: determine grid_pinpoints from image sizes. + if type(grid_pinpoints) is list: + possible_resolutions = grid_pinpoints + else: + possible_resolutions = ast.literal_eval(grid_pinpoints) + best_resolution = select_best_resolution(image.size, possible_resolutions) + image_padded = resize_and_pad_image(image, best_resolution) + + processor_size = processor.transforms[0].size + patches = divide_to_patches(image_padded, processor_size[0]) + + image_original_resize = image.resize((processor_size[0], processor_size[0])) + + image_patches = [image_original_resize] + patches + image_patches = [processor(image_patch) for image_patch in image_patches] + return torch.stack(image_patches, dim=0) + + +def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + +class VisionTokenizer(nn.Module): + def __init__(self, dim_media, num_tokens_per_media): + super().__init__() + self.dim_media = dim_media + self.num_tokens_per_media = num_tokens_per_media + + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.heads = heads + inner_dim = dim_head * heads + + self.norm_media = nn.LayerNorm(dim) + self.norm_latents = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents, vision_attn_masks=None): + """ + Args: + x (torch.Tensor): image features + shape (b, T, n1, D) + latent (torch.Tensor): latent features + shape (b, T, n2, D) + """ + x = self.norm_media(x) + latents = self.norm_latents(latents) + + h = self.heads + + q = self.to_q(latents) + kv_input = torch.cat( + (x, latents), dim=-2 + ) # TODO: Change the shape of vision attention mask according to this. + if vision_attn_masks is not None: + vision_attn_masks = torch.cat( + ( + vision_attn_masks, + torch.ones( + (latents.shape[0], latents.shape[-2]), + dtype=latents.dtype, + device=latents.device, + ), + ), + dim=-1, + ) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h) + q = q * self.scale + + # attention + sim = einsum("... i d, ... j d -> ... i j", q, k) + # Apply vision attention mask here. + # Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention + if vision_attn_masks is not None: + attn_bias = torch.zeros( + (q.size(0), 1, 1, q.size(-2), k.size(-2)), + dtype=q.dtype, + device=q.device, + ) + vision_attn_masks = repeat( + vision_attn_masks, "b n -> b 1 1 l n", l=q.size(-2) + ) + attn_bias.masked_fill_(vision_attn_masks.logical_not(), float("-inf")) + sim += attn_bias + + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + out = einsum("... i j, ... j d -> ... 
i d", attn, v) + out = rearrange(out, "b h t n d -> b t n (h d)", h=h) + return self.to_out(out) + + +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +def MLP(dim, inner_dim=-1, out_dim=-1): + inner_dim = dim * 2 if inner_dim < 0 else inner_dim + out_dim = dim if out_dim < 0 else out_dim + + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, out_dim, bias=False), + ) + + +def get_emb(sin_inp): + """ + Gets a base embedding for one dimension with sin and cos intertwined + """ + emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1) + return torch.flatten(emb, -2, -1) + + +class PositionalEncoding1D(nn.Module): + def __init__(self, channels): + """ + :param channels: The last dimension of the tensor you want to apply pos emb to. + """ + super(PositionalEncoding1D, self).__init__() + self.org_channels = channels + channels = int(numpy.ceil(channels / 2) * 2) + self.channels = channels + inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels)) + self.register_buffer("inv_freq", inv_freq) + self.register_buffer("cached_penc", None, persistent=False) + + def forward(self, tensor): + """ + :param tensor: A 3d tensor of size (batch_size, x, ch) + :return: Positional Encoding Matrix of size (batch_size, x, ch) + """ + if len(tensor.shape) != 3: + raise RuntimeError("The input tensor has to be 3d!") + + if self.cached_penc is not None and self.cached_penc.shape == tensor.shape: + return self.cached_penc + + self.cached_penc = None + batch_size, x, orig_ch = tensor.shape + pos_x = torch.arange(x, device=tensor.device, dtype=self.inv_freq.dtype) + sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq) + emb_x = get_emb(sin_inp_x) + emb = torch.zeros((x, self.channels), device=tensor.device, dtype=tensor.dtype) + emb[:, : self.channels] = emb_x + + self.cached_penc = emb[None, :, :orig_ch].repeat(batch_size, 1, 1) + return self.cached_penc + + +class MultiHeadSelfAttention(nn.Module): + def __init__(self, *, dim, inner_dim, heads=8): + super().__init__() + dim_head = inner_dim // heads + self.scale = dim_head**-0.5 + self.heads = heads + + self.norm = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_k = nn.Linear(dim, inner_dim, bias=False) + self.to_v = nn.Linear(dim, inner_dim, bias=False) + + def forward(self, x): + """ + Args: + x (torch.Tensor): image features + shape (b, n, D) + """ + latents = self.norm(x) + + h = self.heads + + q = self.to_q(latents) + k = self.to_k(latents) + v = self.to_v(latents) + q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h) + q = q * self.scale + + # attention + sim = einsum("... i d, ... j d -> ... i j", q, k) + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + out = einsum("... i j, ... j d -> ... 
i d", attn, v) + out = rearrange(out, "b h n d -> b n (h d)", h=h) + return out + + +class TokenLearnerAttentionModule(nn.Module): + def __init__(self, *, dim, num_target_tokens): + super().__init__() + + self.mlp = MLP(dim, inner_dim=num_target_tokens * 2, out_dim=num_target_tokens) + + self.norm = nn.LayerNorm(dim) + self.num_target_tokens = num_target_tokens + + def forward(self, x): + """ + Args: + x (torch.Tensor): image features + shape (b, T, n, D) + """ + inputs = self.norm(x) + + attn = self.mlp(inputs) + attn = attn.softmax(dim=-2) + + out = einsum("... n i, ... n d -> ... i d", attn, x) + + return out + + +class GroupedTokenTuringMachineUnit(nn.Module): + def __init__( + self, + *, + dim, + process_size=128, + memory_size_per_group=4, + num_layers=1, + num_heads=8, + ): + super().__init__() + + self.process_layers = nn.ModuleList([]) + for _ in range(num_layers): + self.process_layers.append( + nn.ModuleList( + [ + MultiHeadSelfAttention( + dim=dim, inner_dim=dim, heads=num_heads + ), + FeedForward(dim=dim, mult=4), + ] + ) + ) + + self.read_layer = TokenLearnerAttentionModule(dim=dim, num_target_tokens=process_size) + self.write_layer = TokenLearnerAttentionModule(dim=dim, num_target_tokens=memory_size_per_group) + + def forward(self, memory_tokens, input_tokens): + """ + Args: + memory_tokens (torch.Tensor): + shape (b, n, group_memory_size, D) + input_tokens (torch.Tensor): + shape (b, n, D) + """ + b, n, g, D = memory_tokens.shape + + input_tokens = input_tokens.unsqueeze(2) # (b, n, 1, D) + all_tokens = torch.cat([memory_tokens, input_tokens], dim=2) + + latents = all_tokens.view(b*n, g+1, D) + + for attn, ff in self.process_layers: + latents = attn(latents) + latents + latents = ff(latents) + latents + + # mem_out_tokens = memory_tokens.view(b*n, g, D) + latents = latents.view(b, n, g+1, D) + mem_out_tokens = torch.cat([memory_tokens, latents], dim=2) + + mem_out_tokens = mem_out_tokens.view(b*n, -1, D) + mem_out_tokens = self.write_layer(mem_out_tokens) + mem_out_tokens = mem_out_tokens.view(b, n, g, D) + + return mem_out_tokens + + +class TokenTuringMachineUnit(nn.Module): + def __init__( + self, + *, + dim, + process_size=64, + memory_size=128, + output_size=32, + num_layers=1, + num_heads=8, + ): + super().__init__() + + self.process_layers = nn.ModuleList([]) + for _ in range(num_layers): + self.process_layers.append( + nn.ModuleList( + [ + MultiHeadSelfAttention( + dim=dim, inner_dim=dim, heads=num_heads + ), + FeedForward(dim=dim, mult=4), + ] + ) + ) + + self.read_layer = TokenLearnerAttentionModule(dim=dim, num_target_tokens=process_size) + self.write_layer = TokenLearnerAttentionModule(dim=dim, num_target_tokens=memory_size) + self.output_layer = TokenLearnerAttentionModule(dim=dim, num_target_tokens=output_size) + + def forward(self, memory_tokens, input_tokens): + """ + Args: + memory_tokens (torch.Tensor): + shape (b, memory_size, D) + input_tokens (torch.Tensor): + shape (b, n, D) + """ + all_tokens = torch.cat([memory_tokens, input_tokens], dim=1) + + latents = self.read_layer(all_tokens) + + for attn, ff in self.process_layers: + latents = attn(latents) + latents + latents = ff(latents) + latents + + mem_out_tokens = torch.cat([memory_tokens, latents], dim=1) + mem_out_tokens = self.write_layer(mem_out_tokens) + + output_tokens = self.output_layer(latents) + + return (mem_out_tokens, output_tokens) + + +class GroupedTokenTuringMachine4(nn.Module): + def __init__( + self, + *, + dim, + process_size=128, + memory_size_per_group=4, + output_size=128, + 
num_layers=4, + num_heads=8, + ): + super().__init__() + + self.ttm_unit = GroupedTokenTuringMachineUnit( + dim=dim, + process_size=process_size, + memory_size_per_group=memory_size_per_group, + num_layers=num_layers, + num_heads=num_heads) + + self.initial_memory = nn.Parameter(torch.randn(process_size, memory_size_per_group, dim)) + + self.pos_emb = PositionalEncoding1D(dim) + + self.final_output = TokenLearnerAttentionModule(dim=dim, num_target_tokens=output_size) + + def forward(self, x): + """ + Args: + x (torch.Tensor): + shape (b, T, n, D) + """ + b, T, n, D = x.shape + + memory_tokens = repeat(self.initial_memory, "n g d -> b n g d", b=b) + + mean_x = torch.mean(x, dim=-2, keepdim=False) + positional_embeddings = self.pos_emb(mean_x) # (b, T, d) + + for i in range(T): + step_tokens = x[:, i, :, :] + + pos = positional_embeddings[:, i, :] + pos = pos.unsqueeze(1) + step_tokens = step_tokens + pos + memory_tokens = self.ttm_unit(memory_tokens, step_tokens) + + output_tokens = memory_tokens.view(b, -1, D) + output_tokens = self.final_output(output_tokens) + + return output_tokens.unsqueeze(1) + + +class GroupedTokenTuringMachine(nn.Module): + def __init__( + self, + *, + dim, + process_size=128, + memory_size_per_group=4, + num_layers=4, + num_heads=8, + ): + super().__init__() + + self.ttm_unit = GroupedTokenTuringMachineUnit( + dim=dim, + process_size=process_size, + memory_size_per_group=memory_size_per_group, + num_layers=num_layers, + num_heads=num_heads) + + self.initial_memory = nn.Parameter(torch.randn(process_size, memory_size_per_group, dim)) + + self.pos_emb = PositionalEncoding1D(dim) + + def forward(self, x): + """ + Args: + x (torch.Tensor): + shape (b, T, n, D) + """ + b, T, n, D = x.shape + + memory_tokens = repeat(self.initial_memory, "n g d -> b n g d", b=b) + + mean_x = torch.mean(x, dim=-2, keepdim=False) + positional_embeddings = self.pos_emb(mean_x) # (b, T, d) + + for i in range(T): + step_tokens = x[:, i, :, :] + + pos = positional_embeddings[:, i, :] + pos = pos.unsqueeze(1) + step_tokens = step_tokens + pos + memory_tokens = self.ttm_unit(memory_tokens, step_tokens) + + memory_tokens = torch.mean(memory_tokens, dim=-2, keepdim=False) + # memory_tokens = torch.amax(memory_tokens, dim=-2, keepdim=False) + + return memory_tokens.unsqueeze(1) + + +class TokenTuringMachine(nn.Module): + def __init__( + self, + *, + dim, + process_size=64, + memory_size=128, + output_size=32, + num_layers=2, + num_heads=8, + final_output_only=False, + memory_out_mode=False, + ): + super().__init__() + + self.ttm_unit = TokenTuringMachineUnit( + dim=dim, + process_size=process_size, + memory_size=memory_size, + output_size=output_size, + num_layers=num_layers, + num_heads=num_heads) + + self.initial_memory = nn.Parameter(torch.randn(memory_size, dim)) + + self.final_output_only = final_output_only + + self.memory_out_mode = memory_out_mode + if self.memory_out_mode: + self.pos_emb = PositionalEncoding1D(dim) + + def forward(self, x): + """ + Args: + x (torch.Tensor): + shape (b, T, n, D) + """ + b, T, n, D = x.shape + + output_tokens_list = [] + + memory_tokens = repeat(self.initial_memory, "n d -> b n d", b=b) + + if self.memory_out_mode: + positional_embeddings = self.pos_emb(x[:, :, 0, :]) + + for i in range(T): + step_tokens = x[:, i, :, :] + + if self.memory_out_mode: + pos = positional_embeddings[:, i, :] + pos = pos.unsqueeze(1) + step_tokens = step_tokens + pos + + # print(step_tokens.shape) + memory_tokens, output_tokens = self.ttm_unit(memory_tokens, step_tokens) + # 
print(f'memory_tokens shape: {memory_tokens.shape}') + # print(f'output_tokens shape: {output_tokens.shape}') + output_tokens_list.append(output_tokens) + + if self.final_output_only: + # return output_tokens.unsqueeze(1) + return output_tokens.unsqueeze(1) + elif self.memory_out_mode: + return memory_tokens.unsqueeze(1) + else: + output_tokens = torch.stack(output_tokens_list, dim=1) + return output_tokens + + +def num_params(module, filter_to_trainable=False): + """Returns the number of parameters in the module, or optionally only the trainable parameters""" + if filter_to_trainable: + return sum(p.numel() for p in module.parameters() if p.requires_grad) + else: + return sum(p.numel() for p in module.parameters()) + + +class PerceiverResampler(VisionTokenizer): + def __init__( + self, + *, + dim, + dim_inner=None, + depth=6, + dim_head=96, + heads=16, + num_latents=128, + max_num_media=None, + max_num_frames=None, + ff_mult=4, + video_mode='gttm', + ): + """ + Perceiver module which takes in image features and outputs image tokens. + Args: + dim (int): dimension of the incoming image features + dim_inner (int, optional): final dimension to project the incoming image features to; + also the final dimension of the outputted features. If None, no projection is used, and dim_inner = dim. + depth (int, optional): number of layers. Defaults to 6. + dim_head (int, optional): dimension of each head. Defaults to 64. + heads (int, optional): number of heads. Defaults to 8. + num_latents (int, optional): number of latent tokens to use in the Perceiver; + also corresponds to number of tokens per sequence to output. Defaults to 64. + max_num_media (int, optional): maximum number of media per sequence to input into the Perceiver + and keep positional embeddings for. If None, no positional embeddings are used. + max_num_frames (int, optional): maximum number of frames to input into the Perceiver + and keep positional embeddings for. If None, no positional embeddings are used. + ff_mult (int, optional): dimension multiplier for the feedforward network. Defaults to 4. 
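+            video_mode (str, optional): temporal aggregation mode. If 'gttm', a
+                GroupedTokenTuringMachine is applied over the T (media) dimension
+                after the Perceiver layers; if None, the per-step latents are
+                returned unchanged. Defaults to 'gttm'.
+
+        Example (illustrative shapes only, assuming the defaults above):
+            resampler = PerceiverResampler(dim=1152, num_latents=128)
+            feats = torch.randn(2, 1, 1, 729, 1152)            # (b, T, F, v, D)
+            tokens = resampler(feats, vision_attn_masks=None)  # (2, 1, 128, 1152)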
+ """ + if dim_inner is not None: + projection = nn.Linear(dim, dim_inner) + else: + projection = None + dim_inner = dim + super().__init__(dim_media=dim, num_tokens_per_media=num_latents) + self.projection = projection + self.latents = nn.Parameter(torch.randn(num_latents, dim)) + + # positional embeddings + self.frame_embs = ( + nn.Parameter(torch.randn(max_num_frames, dim)) + if exists(max_num_frames) + else None + ) + self.media_time_embs = ( + nn.Parameter(torch.randn(max_num_media, 1, dim)) + if exists(max_num_media) + else None + ) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + self.norm = nn.LayerNorm(dim) + + self.video_mode = video_mode + if self.video_mode=='gttm': + # self.ttm = TokenTuringMachine(dim=dim, memory_size=128, memory_out_mode=True) + self.temporal_encoder = GroupedTokenTuringMachine(dim=dim, process_size=128, memory_size_per_group=4) + # self.temporal_encoder = GroupedTokenTuringMachine4(dim=dim, process_size=128, memory_size_per_group=4, output_size=32) + + def forward(self, x, vision_attn_masks): + """ + Args: + x (torch.Tensor): image features + shape (b, T, F, v, D) + vision_attn_masks (torch.Tensor): attention masks for padded visiont tokens (i.e., x) + shape (b, v) + Returns: + shape (b, T, n, D) where n is self.num_latents + """ + b, T, F, v = x.shape[:4] + + # frame and media time embeddings + if exists(self.frame_embs): + frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v) + x = x + frame_embs + x = rearrange( + x, "b T F v d -> b T (F v) d" + ) # flatten the frame and spatial dimensions + if exists(self.media_time_embs): + x = x + self.media_time_embs[:T] + + # blocks + latents = self.latents + latents = repeat(latents, "n d -> b T n d", b=b, T=T) + for attn, ff in self.layers: + latents = attn(x, latents, vision_attn_masks) + latents + latents = ff(latents) + latents + + if self.video_mode is not None: + latents = self.temporal_encoder(latents) + + if exists(self.projection): + return self.projection(self.norm(latents)) + else: + return self.norm(latents) + + +class DecoupledEmbedding(nn.Embedding): + # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding + """ + Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. In practise, the + regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0, + then it will create `num_additional_embeddings` additional parameters that are always trained. If + `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`. + """ + + def __init__( + self, + max_original_id: int, + num_additional_embeddings: int = 0, + _weight: torch.Tensor = None, + num_original_embeddings: int = None, + embedding_dim: int = None, + partially_freeze=True, + device=None, + dtype=None, + pad_token_id=None, + ) -> None: + """ + Args: + max_original_id (`int`): + The largest token id that should be embedded using the regular embedding (regular `weight`). + This is usually len(tokenizer) - 1 before additional tokens are added. + Note that this may not equal self.weight.shape[0] + num_additional_embeddings (`int`): + Number of additional tokens to initialize an Embedding matrix for (`additional_weight`). 
+ _weight (`torch.Tensor`, *optional*, defaults to `None`): The regular weight tensor. + If provided, this sets the `num_original_embeddings` and `embedding_dim` parameters. + num_original_embeddings (`int`): + self.weight.shape[0] + embedding_dim (`int`): + The size of each embedding vector + partially_freeze: (`bool`, *optional*, defaults to `True`): + If `True`, the regular `weight` will be frozen. `additional_weight` is never frozen. + padding_idx (`int`, *optional*): + The padding index (needs to be less than num_embeddings) + + Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`, + `max_norm` or `norm_type`. We are not supporting these. + """ + # validate args + if pad_token_id is not None and pad_token_id > max_original_id: + raise ValueError( + f"pad_token_id must be <= max_original_id. Got {pad_token_id} and {max_original_id}." + + "If the original tokenizer does not have a pad_token_id, use pad_token_id=None." + ) + if _weight is not None: + assert (num_original_embeddings is None) or ( + _weight.shape[0] == num_original_embeddings + ), f"num_original_embeddings={num_original_embeddings} but _weight.shape[0]={_weight.shape[0]}" + assert (embedding_dim is None) or ( + _weight.shape[1] == embedding_dim + ), f"embedding_dim={embedding_dim} but _weight.shape[1]={_weight.shape[1]}" + num_original_embeddings = _weight.shape[0] + embedding_dim = _weight.shape[1] + else: + assert ( + num_original_embeddings is not None + ), "num_original_embeddings must be provided if _weight is not provided" + assert ( + embedding_dim is not None + ), "embedding_dim must be provided if _weight is not provided" + + super().__init__( + num_embeddings=num_original_embeddings, + embedding_dim=embedding_dim, + device=device, + dtype=dtype, + padding_idx=pad_token_id, + _weight=_weight, + ) + self.max_original_id = max_original_id + self.padding_idx = pad_token_id + self.num_additional_embeddings = num_additional_embeddings + if self.num_additional_embeddings > 0: + self.additional_embedding = nn.Embedding( + num_embeddings=self.num_additional_embeddings, + embedding_dim=embedding_dim, + device=device, + dtype=dtype, + ) + self.set_requires_grad( + require_regular_grad=not partially_freeze, require_additional_grad=True + ) + + def set_requires_grad(self, require_regular_grad, require_additional_grad): + """ + Helper function to separately set the requires_grad flag for the regular weight and the additional weight. + """ + self.weight.requires_grad_(require_regular_grad) + self.additional_embedding.requires_grad_(require_additional_grad) + + def forward(self, input_ids): + """ + we have 2 embeddings, with different indices - one pretrained self.weight and another + self.additional_embedding.weight that is being trained. + + in order to make a lookup of the input ids, we: + 1. find out the indices of the entries belonging to the 2nd embedding + 2. extract those values while subtracting the size of the first embedding (num_embeddings), since the 2nd + embedding starts from 0 and not num_embeddings + 3. perform the 2nd embedding lookup + 4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index + 5. perform the 1st embedding lookup + 6. 
now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup + + note: for the 1st embedding lookup we could have looked up only the low indices and not do the padding, but + then we have to create a new tensor and populate it with 2 tensors that are spread out across various indices - + i.e. not a simple concat - I haven't benchmarked the complex case if it's any faster, given that seqlens are + usually relatively short it's probably not faster or if faster not by much - but might be a good idea to + measure. + + """ + if self.num_additional_embeddings == 0: + return F.embedding(input_ids, self.weight) + + # Clone so that we don't modify the original input_ids later on + input_ids = input_ids.clone() + additional_vocab_indices = torch.where(input_ids > self.max_original_id) + input_ids_additional_vocab = input_ids[additional_vocab_indices] + additional_embeddings = self.additional_embedding( + input_ids_additional_vocab - self.max_original_id - 1 + ) + + # for successful lookup replace input_ids with 0, the results of these will be discarded anyway + input_ids[additional_vocab_indices] = 0 + full_vector = F.embedding(input_ids, self.weight) + + # overwrite the records with high indices + full_vector[additional_vocab_indices] = additional_embeddings + + return full_vector + + def extra_repr(self) -> str: + return "num_original_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format( + self.max_original_id + 1, + self.num_additional_embeddings, + self.embedding_dim, + (not self.weight.requires_grad), + ) + + +class DecoupledLinear(nn.Linear): + # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear + """ + Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practise, the + regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `additional_out_features` > 0, + then it will create `additional_out_features * in_features` additional parameters that are always trained. If + `additional_out_features=0`, then the module defaults back to the regular behavior of `nn.Linear`. + """ + + def __init__( + self, + max_original_id: int, + additional_out_features: int = 0, + _weight: torch.Tensor = None, + _bias: torch.Tensor = None, + in_features: int = None, + original_out_features: int = None, + bias: bool = True, + partially_freeze: bool = True, + device=None, + dtype=None, + ) -> None: + """ + Args: + max_original_id (`int`): The largest token id that should be extracted from the regular weight. + This is usually len(tokenizer) - 1 before additional tokens are added. + Note that this may not equal original_out_features - 1 + _weight: torch.Tensor, *optional*, defaults to `None`. The regular weight tensor. + If provided, this sets the `in_features` and `original_out_features` parameters. + _bias: torch.Tensor, *optional*, defaults to `None`. The regular bias tensor. + in_features: int. Input hidden size. + original_out_features: int. Original out_features of the language model's get_output_embeddings() function. + additional_out_features: int. Number of additional trainable dimensions. + bias: bool. Whether to include a bias term. + partially_freeze: bool, *optional*, defaults to `True`): If `True`, the regular `weight` will be frozen. 
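+            device, dtype: optional; forwarded to the underlying `nn.Linear` layers.
+
+        Example (illustrative; `lm_head` stands for a hypothetical existing output head):
+            new_head = DecoupledLinear(
+                max_original_id=32011,
+                additional_out_features=3,
+                _weight=lm_head.weight,
+                _bias=lm_head.bias,
+            )
+            # forward() returns the original logits truncated to max_original_id + 1
+            # columns, concatenated with 3 freshly trainable columns.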
+ """ + # argument validation + if _weight is not None: + assert (_weight.shape[0] == original_out_features) or ( + original_out_features is None + ), f"original_out_features={original_out_features} but _weight.shape[0]={_weight.shape[0]}" + assert (_weight.shape[1] == in_features) or ( + in_features is None + ), f"in_features={in_features} but _weight.shape[1]={_weight.shape[1]}" + in_features = _weight.shape[1] + original_out_features = _weight.shape[0] + else: + assert ( + in_features is not None + ), "in_features must be provided if _weight is not provided" + assert ( + original_out_features is not None + ), "original_out_features must be provided if _weight is not provided" + + if _bias is not None: + assert bias is True, "bias must be True if _bias is provided" + + # initialize original linear + super().__init__(in_features, original_out_features, bias, device, dtype) + + # set weight and bias manually + if _weight is not None: + self.weight = nn.Parameter(_weight) + if _bias is not None: + self.bias = nn.Parameter(_bias) + + self.in_features = in_features + self.original_out_features = original_out_features + self.max_original_id = max_original_id + + # initialize additional linear + self.additional_out_features = additional_out_features + self.has_bias = bias + if additional_out_features > 0: + self.additional_fc = nn.Linear( + in_features=in_features, + out_features=additional_out_features, + bias=self.has_bias, + device=device, + dtype=dtype, + ) + self.set_requires_grad( + require_regular_grad=not partially_freeze, require_additional_grad=True + ) + + def set_requires_grad(self, require_regular_grad, require_additional_grad): + """ + Helper function to separately set the requires_grad flag for the regular weight and the additional weight. + """ + self.weight.requires_grad_(require_regular_grad) + if self.has_bias: + self.bias.requires_grad_(require_regular_grad) + self.additional_fc.requires_grad_(require_additional_grad) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + output = F.linear(input, self.weight, self.bias) + output = output[..., : self.max_original_id + 1] + + if self.additional_out_features > 0: + additional_features = F.linear( + input, self.additional_fc.weight, self.additional_fc.bias + ) + output = torch.cat((output, additional_features), -1) + return output + + def extra_repr(self) -> str: + """Overwriting `nn.Linear.extra_repr` to include new parameters.""" + return "in_features={}, out_features={}, additional_out_features={}, bias={}, partially_freeze={}".format( + self.in_features, + self.max_original_id + 1, + self.additional_out_features, + self.bias is not None, + (not self.weight.requires_grad or not self.bias.requires_grad), + ) + + +class VLM(nn.Module): + """ + Generic vision-language model (VLM) class. + A VLM consists of four components: + 1. A vision encoder that extracts features from pixels, e.g. CLIP + input: (B, T_img, F, C, H, W) + output: (B, T_img, F, v, d) + 2. A vision tokenizer that converts these features to visual token-like embeddings, e.g. Perceiver, or a linear projection head + input: (B, T_img, F, v, d) + output: (B, T_img, n, d) + 3. A fusion method that allows the language model to attend to these tokens, e.g. cross-attention, or placing the tokens directly in the language model's input sequence + 4. 
A language model + """ + + def __init__( + self, + vision_encoder: nn.Module, + vision_tokenizer: nn.Module, + lang_model: nn.Module, + initial_tokenizer_len: int, + pad_token_id: int, + gradient_checkpointing: bool = False, + ): + """ + Args: + vision_encoder (nn.Module): e.g. CLIP + vision_tokenizer (nn.Module): e.g. PerceiverResampler + lang_model (nn.Module): e.g. MPT + initial_tokenizer_len (int): size of the original tokenizer vocab + pad_token_id (int): id of the pad token + gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False. + """ + super().__init__() + + # save dimension information + self.lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1] + if hasattr(lang_model.config, "d_model"): + self.lang_hidden_dim = lang_model.config.d_model # mpt uses d_model + else: + self.lang_hidden_dim = lang_model.config.hidden_size + self.vis_embedding_dim = vision_tokenizer.dim_media + self.num_tokens_per_vis = vision_tokenizer.num_tokens_per_media + + # core components + self.vision_encoder = vision_encoder + self.vision_tokenizer = vision_tokenizer + self.lang_model = lang_model + + # lm embeddings + self.pad_token_id = pad_token_id + self.initial_tokenizer_len = initial_tokenizer_len + input_embeds = DecoupledEmbedding( + max_original_id=initial_tokenizer_len - 1, + num_additional_embeddings=len(self.special_tokens), + _weight=self.lang_model.get_input_embeddings().weight, + pad_token_id=self.pad_token_id, + ) + if hasattr(input_embeds, "additional_embedding"): + input_embeds.additional_embedding.weight.data.normal_( + mean=0.0, + std=( + self.lang_model.config.initializer_range + if hasattr(self.lang_model.config, "initializer_range") + else 0.02 + ), + ) + self.lang_model.set_input_embeddings(input_embeds) + + out_embeds = DecoupledLinear( + max_original_id=initial_tokenizer_len - 1, + additional_out_features=len(self.special_tokens), + _weight=self.lang_model.get_output_embeddings().weight, + _bias=( + self.lang_model.get_output_embeddings().bias + if hasattr(self.lang_model.get_output_embeddings(), "bias") + else None + ), + ) + if hasattr(out_embeds, "additional_fc"): + out_embeds.additional_fc.weight.data.normal_( + mean=0.0, + std=( + self.lang_model.config.initializer_range + if hasattr(self.lang_model.config, "initializer_range") + else 0.02 + ), + ) + self.lang_model.set_output_embeddings(out_embeds) + + # gradient checkpointing + self.vision_tokenizer._use_gradient_checkpointing = gradient_checkpointing + + def forward( + self, + vision_x: Optional[torch.Tensor], + lang_x: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + past_key_values: Optional[ + List[Union[torch.Tensor, Tuple[torch.Tensor]]] + ] = None, + past_media_locations: Optional[torch.Tensor] = None, + past_vision_tokens: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = False, + **kwargs, + ): + """ + Args: + vision_x: Vision input + shape (B, T_img, F, C, H, W) with F=1 + only F = 1 is supported (single-frame videos) + if T_img > the number of media tokens in the corresponding input_ids (lang_x), + only the first number of media tokens in lang_x are used + lang_x: Language input ids, with media tokens denoting where + visual media should be inserted. + shape (B, T_txt) + attention_mask: Attention mask. Defaults to None. + labels: Labels. Defaults to None. 
+ shape (B, T_txt) + past_key_values (Tuple[torch.Tensor]], optional): Past key value pairs for each of the T_txt previous tokens in the language model. Defaults to None. + list of length = number of decoder layers in the LM + exact implementation depends on LM, see Hugging Face docs + past_media_locations (torch.Tensor, optional): boolean mask denoting which of the previous T_txt tokens were media tokens. Defaults to None. + shape (B, T_txt) + past_vision_tokens (torch.Tensor, optional): Previous vision tokens. Defaults to None. + use_cache (Optional[bool], optional): Whether to use cache. Defaults to False. + If True, includes key_values, media_locations, and vision_tokens in the output. + """ + assert not (past_vision_tokens is None) ^ ( + past_media_locations is None + ), "past_vision_tokens and past_media_locations must both be None or both be not None" + + # convert pixels to vision tokens + if vision_x is not None: + vision_features = self._encode_vision_x(vision_x=vision_x) + vision_tokens = self.vision_tokenizer(vision_features) + else: + vision_tokens = None + + # fuse the vision and language tokens + new_inputs = self._prepare_inputs_for_forward( + vision_tokens=vision_tokens, + lang_x=lang_x, + attention_mask=attention_mask, + labels=labels, + past_key_values=past_key_values, + past_media_locations=past_media_locations, + padding_side="right", + past_vision_tokens=past_vision_tokens, + ) + output = self.lang_model( + **new_inputs, + use_cache=use_cache, + past_key_values=past_key_values, + **kwargs, + ) + + # postprocessing may be needed, e.g. to remove extra tokens from logits that were inserted into the language stream + # or to add the past_vision_tokens and past_media_locations to the output + output = self._postprocess_outputs_from_forward( + output=output, + lang_x=lang_x, + vision_tokens=vision_tokens, + use_cache=use_cache, + past_vision_tokens=past_vision_tokens, + past_media_locations=past_media_locations, + ) + + # postforward hooks + self._post_forward_hook() + return output + + def _encode_vision_x_anyres(self, samples, device): + assert self.anyres_grids is not None + image_raw = samples[ + "image" + ] # list of patch list in of shape [1, N_patch, C, H, W] + image_sizes = samples["image_size"] + + # Image_raw can be a list of list of patches, when a `samples` has multiple images. + if isinstance(image_raw[0], list): + images = [x.squeeze(0) for sample_img in image_raw for x in sample_img] + image_sizes = [s for sample_sizes in image_sizes for s in sample_sizes] + else: + # assert isinstance(image_raw[0], torch.Tensor), f"Unkown image type: {image_raw[0]}" + # concate list of patches into one big patch for any res encoding. 
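+            # Each element of image_raw is expected to be a [1, N_patch, C, H, W]
+            # tensor; squeezing dim 0 gives [N_patch, C, H, W] per image before the
+            # per-image patch stacks are concatenated along dim 0 below.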
+ images = [x.squeeze(0) for x in image_raw] # [N_patch, C, H, W] + image = torch.cat(images, dim=0) # [\sum{B}{N_patch_i}, C, H, W] + image = image.to(device) + + with torch.no_grad(): + if self.vision_encoder.__class__.__name__ == "TimmModel": + image_embeds = self.vision_encoder.trunk.forward_features(image) + elif self.vision_encoder.__class__.__name__ in [ + "CLIPVisionModel", + "SiglipVisionTransformer", + ]: + image_embeds = self.vision_encoder(image).last_hidden_state + else: + image_embeds = self.vision_encoder(image)[1] # OpenCLIP returns tuples + + if isinstance(self.vision_encoder, CLIPVisionModel) or isinstance( + self.vision_encoder, SiglipVisionTransformer + ): + base_img_size = self.vision_encoder.config.image_size + else: + base_img_size = self.vision_encoder.image_size[0] + + if self.vision_encoder.__class__.__name__ == "TimmModel": + grid_size = self.vision_encoder.trunk.patch_embed.grid_size + elif self.vision_encoder.__class__.__name__ in [ + "CLIPVisionModel", + "SiglipVisionTransformer", + ]: + grid_size_base = ( + self.vision_encoder.config.image_size + // self.vision_encoder.config.patch_size + ) + grid_size = (grid_size_base, grid_size_base) + else: + grid_size = self.vision_encoder.grid_size + height, width = grid_size + + if not image_embeds.shape[1] == height * width: + assert ( + image_embeds.shape[1] == height * width + 1 + ) # For vision encoders that has [CLS] token. + image_embeds = image_embeds[:, 1:, :] # Drop the cls token for each patch. + n_vis_token_per_patch = image_embeds.shape[1] + + # Split encoded patches and merge patch features + # 1. Get the raw sizes from samples, and split the image embeds [\sum_{B}(N_patch_i), N_tok(16*16), C] + split_sizes = [image.shape[0] for image in images] + image_embeds = torch.split(image_embeds, split_sizes, dim=0) + # 2. For each image (consist of a list of patches), merge the patches spatially (of shape [C, n_patch_height, n_patch_width]) + new_image_embeds = [] + patch_attn_masks = [] + max_n_img_token = -1 + for idx, patch_embeds in enumerate(image_embeds): + if patch_embeds.shape[0] > 1: + # 3. Flatten the patch features and get [C, n_patch_height * (n_patch_width+1)] + base_patch_embeds = patch_embeds[ + 0 + ] # TODO: prepend the CLS token for th base patch embeds (of the resized entire image). + patch_embeds = patch_embeds[1:] + + assert height * width == base_patch_embeds.shape[0] + + num_patch_width, num_patch_height = get_anyres_image_grid_shape( + image_sizes[idx], self.anyres_grids, base_img_size + ) # Hardcoded grid_pinpoints. 
+ patch_embeds = patch_embeds.view( + num_patch_height, num_patch_width, height, width, -1 + ) + + patch_embeds = patch_embeds.permute(4, 0, 2, 1, 3).contiguous() + patch_embeds = patch_embeds.flatten(1, 2).flatten(2, 3) + patch_embeds, patch_attn_mask = unpad_image( + patch_embeds, image_sizes[idx], self.anyres_patch_sampling + ) + if hasattr(self, "image_newline"): + patch_embeds = torch.cat( + ( + patch_embeds, + self.image_newline[:, None, None].expand( + *patch_embeds.shape[:-1], 1 + ), + ), + dim=-1, + ) + if self.anyres_patch_sampling: + patch_embeds = patch_embeds.view( + -1, num_patch_height, num_patch_width, height * width + ) + patch_embeds = patch_embeds.flatten(1, 2).permute(1, 2, 0) + assert patch_attn_mask is not None + patch_attn_mask = patch_attn_mask.view( + num_patch_height, num_patch_width, height * width + ) + patch_attn_mask = patch_attn_mask.flatten(0, 1) + patch_embeds = torch.cat( + (base_patch_embeds.unsqueeze(0), patch_embeds), dim=0 + ) + patch_attn_mask = torch.cat( + ( + torch.ones( + n_vis_token_per_patch, device=patch_embeds.device + ).unsqueeze(0), + patch_attn_mask, + ), + dim=0, + ) + else: + patch_embeds = patch_embeds.flatten(1, 2).transpose(0, 1) + patch_embeds = torch.cat((base_patch_embeds, patch_embeds), dim=0) + else: + patch_embeds = ( + patch_embeds[0].unsqueeze(0) + if self.anyres_patch_sampling + else patch_embeds[0] + ) + patch_attn_mask = ( + torch.ones( + n_vis_token_per_patch, device=patch_embeds.device + ).unsqueeze(0) + if self.anyres_patch_sampling + else None + ) + if hasattr(self, "image_newline"): + patch_embeds = torch.cat( + (patch_embeds, self.image_newline[None]), dim=0 + ) + if not self.anyres_patch_sampling: + max_n_img_token = max(patch_embeds.shape[0], max_n_img_token) + + new_image_embeds.append(patch_embeds) + patch_attn_masks.append(patch_attn_mask) + + if self.anyres_patch_sampling: + # Return individual patches for independent token downsampling. + return new_image_embeds, patch_attn_masks + + # 4. Pad and concat the list of image_embeds [N_tok_i, C] together into a batch. Also modify the query attention mask. + image_embeds = [] + image_atts = [] + for image_embed in new_image_embeds: + n_img_token = image_embed.shape[0] + img_attn = torch.ones( + (max_n_img_token), dtype=torch.long, device=image_embed.device + ) + if n_img_token < max_n_img_token: + padded_embed = torch.zeros( + (max_n_img_token, image_embed.shape[-1]), + dtype=image_embed.dtype, + device=image_embed.device, + ) + padded_embed[:n_img_token, :] = image_embed + img_attn[n_img_token:] = 0 # Mask out the padded entries. + else: + padded_embed = image_embed + image_embeds.append(padded_embed) + image_atts.append(img_attn) + image_embeds = torch.stack( + image_embeds, dim=0 + ) # Shape [B, N_tok_longest, C_dim] + image_atts = torch.stack(image_atts, dim=0) # Shape [B, N_tok_longest, C_dim] + # TODO: reshape image_embeds and image_atts to "b T F v d" + image_embeds = image_embeds[:, None, None, :, :] + # image_atts = image_atts[:, None, None, :, :] + + return image_embeds, image_atts + + def _encode_vision_x(self, vision_x: torch.Tensor): + """ + Compute media tokens from vision input by passing it through vision encoder and conditioning language model. 
+ Args: + vision_x: Vision input + shape (B, T_img, F, C, H, W) + Images in the same chunk are collated along T_img, and frames are collated along F + Currently only F=1 is supported (single-frame videos) + + rearrange code based on https://github.com/dhansmair/flamingo-mini + """ + assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)" + b, T, F = vision_x.shape[:3] + + vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w") + with torch.no_grad(): + if self.vision_encoder.__class__.__name__ == "TimmModel": + vision_x = self.vision_encoder.trunk.forward_features(vision_x) + elif self.vision_encoder.__class__.__name__ in [ + "CLIPVisionModel", + "SiglipVisionTransformer", + ]: + vision_x = self.vision_encoder(vision_x).last_hidden_state + else: + vision_x = self.vision_encoder(vision_x)[1] # OpenCLIP returns tuples + vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F) + return vision_x + + def _concat_vision_cache( + self, lang_x, vision_tokens, past_vision_tokens, past_media_locations, use_cache + ): + """ + Helper function to include the past vision tokens and past media locations in the output. + """ + if use_cache: + if past_media_locations is not None and past_vision_tokens is not None: + if vision_tokens is not None: + updated_vision_tokens = torch.cat( + [ + past_vision_tokens, + vision_tokens, + ], + dim=1, + ) + else: + updated_vision_tokens = past_vision_tokens + updated_media_locations = torch.cat( + [ + past_media_locations, + lang_x == self.media_token_id, + ], + dim=1, + ) + else: + updated_vision_tokens = vision_tokens + updated_media_locations = lang_x == self.media_token_id + + else: + updated_vision_tokens = None + updated_media_locations = None + + return updated_vision_tokens, updated_media_locations + + def generate( + self, + vision_x: torch.Tensor, + lang_x: torch.Tensor, + attention_mask: torch.Tensor = None, + past_key_values: Optional[ + List[Union[torch.Tensor, Tuple[torch.Tensor]]] + ] = None, + past_media_locations: Optional[torch.Tensor] = None, + past_vision_tokens: Optional[torch.Tensor] = None, + **kwargs, + ): + """ + Generate text conditioned on vision and language inputs. + Args: + vision_x (torch.Tensor): Vision input + shape (B, T_img, F, C, H, W) + see documentation for forward + lang_x (torch.Tensor): Language input + shape (B, T_txt) + attention_mask (torch.Tensor, optional): Attention mask. Defaults to None. + **kwargs: see generate documentation in Hugging Face CausalLM models. + Returns: + torch.Tensor: lang_x with generated tokens appended to it + """ + num_beams = kwargs.pop("num_beams", 1) + + # convert pixels to vision tokens + if vision_x is not None: + vision_features = self._encode_vision_x(vision_x=vision_x) + vision_tokens = self.vision_tokenizer(vision_features) + else: + vision_tokens = None + + # fuse the vision and language tokens + # for xattn, vision_x and media_location are repeat_interleaved s.t. 
+
+    def generate(
+        self,
+        vision_x: torch.Tensor,
+        lang_x: torch.Tensor,
+        attention_mask: torch.Tensor = None,
+        past_key_values: Optional[
+            List[Union[torch.Tensor, Tuple[torch.Tensor]]]
+        ] = None,
+        past_media_locations: Optional[torch.Tensor] = None,
+        past_vision_tokens: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        """
+        Generate text conditioned on vision and language inputs.
+        Args:
+            vision_x (torch.Tensor): Vision input
+                shape (B, T_img, F, C, H, W)
+                see documentation for forward
+            lang_x (torch.Tensor): Language input
+                shape (B, T_txt)
+            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
+            **kwargs: see generate documentation in Hugging Face CausalLM models.
+        Returns:
+            torch.Tensor: lang_x with generated tokens appended to it
+        """
+        num_beams = kwargs.pop("num_beams", 1)
+
+        # convert pixels to vision tokens
+        if vision_x is not None:
+            vision_features = self._encode_vision_x(vision_x=vision_x)
+            vision_tokens = self.vision_tokenizer(vision_features)
+        else:
+            vision_tokens = None
+
+        # fuse the vision and language tokens
+        # for xattn, vision_x and media_location are repeat_interleaved s.t.
+        # the total batch size is B * num_beams
+        new_inputs = self._prepare_inputs_for_forward(
+            vision_tokens=vision_tokens,
+            lang_x=lang_x,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            past_media_locations=past_media_locations,
+            past_vision_tokens=past_vision_tokens,
+            padding_side="left",
+            num_beams=num_beams,
+        )
+        output = self.lang_model.generate(
+            **new_inputs,
+            past_key_values=past_key_values,
+            num_beams=num_beams,
+            use_cache=True,
+            **kwargs,
+        )
+        self._post_forward_hook()
+        return output
+
+    @property
+    def num_trainable_params(self):
+        """Print the number of trainable parameters"""
+        return num_params(self, filter_to_trainable=True)
+
+    def set_trainable(self):
+        """
+        Freeze appropriate parameters in the model.
+        """
+        raise NotImplementedError
+
+    def group_params_by_weight_decay(self):
+        """
+        Return a tuple of (params to optimize w/ weight decay, params to optimize w/o weight decay)
+        """
+        params_with_wd, params_without_wd = [], []
+        for n, p in self.named_parameters():
+            if p.requires_grad:
+                if self._should_apply_weight_decay(n):
+                    params_with_wd.append(p)
+                else:
+                    params_without_wd.append(p)
+        return params_with_wd, params_without_wd
+
+    def _should_apply_weight_decay(self, parameter_name):
+        """
+        Return whether weight decay should be applied to a parameter.
+        """
+        raise NotImplementedError
+
+    @property
+    def special_tokens(self):
+        """
+        Returns a dict mapping from the attribute name of a special token to its string format,
+        e.g. "media_token": "<image>"
+        """
+        assert (
+            "media_token" in self._special_tokens
+        ), "VLMs need to request that the tokenizer add a media_token and call set_special_token_ids to set self.media_token_id"
+        return self._special_tokens
+
+    @property
+    def special_token_ids(self):
+        """
+        Returns a list of the special token ids
+        """
+        return [getattr(self, f"{att_name}_id") for att_name in self.special_tokens]
+
+    def set_special_token_ids(self, string_to_ids):
+        """
+        Args:
+            string_to_ids (dict): mapping from token string to id
+        """
+        assert set(self.special_tokens.values()).issubset(set(string_to_ids.keys()))
+        for att_name, token_str in self.special_tokens.items():
+            token_id = string_to_ids[token_str]
+            setattr(self, f"{att_name}_id", token_id)
+            setattr(self.lang_model, f"{att_name}_id", token_id)
+
+    def init_gradient_checkpointing(self):
+        from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+            checkpoint_wrapper,
+            CheckpointWrapper,
+            CheckpointImpl,
+            apply_activation_checkpointing,
+        )
+        from functools import partial
+
+        non_reentrant_wrapper = partial(
+            checkpoint_wrapper,
+            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
+        )
+        apply_activation_checkpointing(
+            self,
+            checkpoint_wrapper_fn=non_reentrant_wrapper,
+            check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False)
+            and not isinstance(m, CheckpointWrapper),
+        )
+
+
+@dataclass
+class VLMOutputWithPast(CausalLMOutputWithPast):
+    """
+    VLMOutputWithPast is a wrapper around CausalLMOutputWithPast that adds the following attributes:
+        past_media_locations: Optional[torch.Tensor] = None,
+        past_vision_tokens: Optional[torch.Tensor] = None,
+    """
+
+    past_media_locations: Optional[torch.Tensor] = None
+    past_vision_tokens: Optional[torch.Tensor] = None
+
+
+def exists(val):
+    return val is not None
+
+
+def FeedForward(dim, mult=4):
+    inner_dim = int(dim * mult)
+    return nn.Sequential(
+        nn.LayerNorm(dim),
+        nn.Linear(dim, inner_dim, bias=False),
+        nn.GELU(),
+        nn.Linear(inner_dim, dim, bias=False),
+    )
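+
+
+# Shape note (illustrative): FeedForward(dim, mult) builds the pre-norm MLP block
+# LayerNorm -> Linear(dim, mult*dim) -> GELU -> Linear(mult*dim, dim) with no bias
+# terms; e.g. FeedForward(1152) uses a 4608-wide hidden layer. It is the MLP used by
+# the resampler blocks earlier in this file.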
+
+
+class VLMWithLanguageStream(VLM):
+    """
+    VLM that fuses modalities by inserting vision tokens directly into the language stream.
+    """
+
+    def __init__(
+        self,
+        vision_encoder: nn.Module,
+        vision_tokenizer: nn.Module,
+        lang_model: nn.Module,
+        initial_tokenizer_len: int,
+        pad_token_id: int,
+        decoder_layers_attr_name: str = None,
+        gradient_checkpointing: bool = False,
+    ):
+        super().__init__(
+            vision_encoder=vision_encoder,
+            vision_tokenizer=vision_tokenizer,
+            lang_model=lang_model,
+            initial_tokenizer_len=initial_tokenizer_len,
+            pad_token_id=pad_token_id,
+            gradient_checkpointing=gradient_checkpointing,
+        )
+        self.decoder_layers_attr_name = decoder_layers_attr_name
+        if decoder_layers_attr_name is not None:
+            for block in getattr_recursive(
+                self.lang_model, self.decoder_layers_attr_name
+            ):
+                block._use_gradient_checkpointing = gradient_checkpointing
+
+    def _prepare_inputs_for_forward(
+        self,
+        vision_tokens: torch.Tensor,
+        lang_x: torch.Tensor,
+        attention_mask: torch.Tensor,
+        labels: torch.Tensor = None,
+        past_key_values=None,
+        vision_attention_mask: Optional[torch.Tensor] = None,
+        past_media_locations: torch.Tensor = None,
+        past_vision_tokens: torch.Tensor = None,
+        padding_side: str = "left",
+        num_beams: int = 1,
+    ):
+        """
+        Insert the vision tokens directly into the language stream.
+        This requires us to modify the input_ids, attention_mask, and labels.
+        """
+        if past_key_values is not None:
+            past_len = past_key_values[0][0].shape[2]
+            assert attention_mask.shape[1] == past_len + lang_x.shape[1], (
+                "Attention_mask must be as long as the entire past len (including image tokens) and current input IDs. "
+                + "Check that you've expanded the attention mask to account for past image tokens."
+            )
+
+        if vision_tokens is None:
+            return {
+                "input_ids": lang_x,
+                "attention_mask": attention_mask,
+                "labels": labels,
+            }
+
+        # get the language embeddings
+        lang_embeds = self.lang_model.get_input_embeddings()(lang_x)
+
+        # build up the multimodal embeddings
+        B = lang_x.shape[0]
+        has_labels = labels is not None
+        multimodal_embeds = []
+        multimodal_attention_mask = []
+        multimodal_labels = [] if has_labels else None
+        for i in range(B):
+            # get index of <image> tokens in lang_x[i]
+            image_token_idxs = torch.where(lang_x[i] == self.media_token_id)[0]
+
+            if len(image_token_idxs) == 0:
+                multimodal_embeds.append(lang_embeds[i].clone())
+                multimodal_attention_mask.append(attention_mask[i].clone())
+                if has_labels:
+                    multimodal_labels.append(labels[i].clone())
+                continue
+
+            # loop through the image_token_idxs and insert the vision tokens
+            new_embed = lang_embeds[i].clone()
+            new_attention_mask = (
+                attention_mask[i].clone() if attention_mask is not None else None
+            )
+            if has_labels:
+                new_label = labels[i].clone()
+
+            for img_num, img_idx in enumerate(image_token_idxs):
+                # Get vision token attention mask for padded llava-style any resolution image tokens.
+                if self.image_aspect_ratio == "anyres":
+                    num_vis_tokens = vision_tokens[i][img_num].shape[0]
+                    if vision_attention_mask is not None:
+                        vis_attention_mask = vision_attention_mask[i]
+                    else:
+                        vis_attention_mask = torch.ones(
+                            num_vis_tokens, dtype=torch.long
+                        ).to(attention_mask.device)
+                else:
+                    # assert (
+                    #     vision_tokens[i][img_num].shape[0] == self.num_tokens_per_vis
+                    # ), f"vision token number mismatch: image embedding ({vision_tokens[i][img_num].shape[0]}) \
+                    #     vs. model.num_tokens_per_vis ({self.num_tokens_per_vis})"
+                    # By default, vision tokens are not padded.
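+                    # Illustrative note: in the non-anyres ("pad") setting each image is
+                    # expected to contribute a fixed number of vision tokens
+                    # (self.num_tokens_per_vis, e.g. 128 for the default resampler), so
+                    # the mask built below is simply all ones of that length.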
+                    num_vis_tokens = vision_tokens[i][img_num].shape[0]
+                    vis_attention_mask = torch.ones(
+                        num_vis_tokens, dtype=torch.long
+                    ).to(attention_mask.device)
+
+                new_embed = torch.cat(
+                    (
+                        new_embed[:img_idx],
+                        vision_tokens[i][img_num],
+                        new_embed[img_idx + 1 :],
+                    ),
+                    dim=0,
+                )
+                new_attention_mask = torch.cat(
+                    (
+                        new_attention_mask[:img_idx],
+                        vis_attention_mask,
+                        new_attention_mask[img_idx + 1 :],
+                    ),
+                    dim=0,
+                )
+                if has_labels:
+                    new_label = torch.cat(
+                        (
+                            new_label[:img_idx],
+                            torch.ones(num_vis_tokens, dtype=torch.long).to(
+                                labels.device
+                            )
+                            * -100,
+                            new_label[img_idx + 1 :],
+                        ),
+                        dim=0,
+                    )
+            multimodal_embeds.append(new_embed)
+            multimodal_attention_mask.append(new_attention_mask)
+            if has_labels:
+                multimodal_labels.append(new_label)
+
+        # stack
+        multimodal_embeds = stack_with_padding(
+            multimodal_embeds,
+            padding_value=self.pad_token_id,
+            padding_side=padding_side,
+        )
+        multimodal_attention_mask = stack_with_padding(
+            multimodal_attention_mask,
+            padding_value=0,
+            padding_side=padding_side,
+        )
+        if has_labels:
+            multimodal_labels = stack_with_padding(
+                multimodal_labels,
+                padding_value=-100,
+                padding_side=padding_side,
+            )
+
+        return {
+            "inputs_embeds": multimodal_embeds,
+            "attention_mask": multimodal_attention_mask,
+            "labels": multimodal_labels,
+        }
+
+    def _postprocess_outputs_from_forward(
+        self,
+        output: CausalLMOutputWithPast,
+        lang_x: torch.Tensor,
+        vision_tokens: torch.Tensor,
+        past_vision_tokens: torch.Tensor,
+        past_media_locations: torch.Tensor,
+        use_cache: bool = False,
+    ):
+        # Include the past vision tokens and past media locations in the output
+        updated_vision_tokens, updated_media_locations = self._concat_vision_cache(
+            lang_x=lang_x,
+            vision_tokens=vision_tokens,
+            past_vision_tokens=past_vision_tokens,
+            past_media_locations=past_media_locations,
+            use_cache=use_cache,
+        )
+
+        # return logits that are the same shape as the original input_ids
+        logits = output.logits
+        batch_logits = []
+        B, T_txt = lang_x.shape
+        for i in range(B):
+            sequence_logits = []
+            logits_j = 0
+            for j in range(T_txt):
+                if lang_x[i, j] != self.media_token_id:
+                    sequence_logits.append(logits[i, logits_j])
+                    logits_j += 1
+                else:
+                    # append the logit for the first image token, then skip over the rest
+                    # note: the model actually learns to predict <im_patch>, not <image>
+                    sequence_logits.append(logits[i, logits_j])
+                    logits_j += self.num_tokens_per_vis
+            sequence_logits = torch.stack(sequence_logits, dim=0)  # (T_txt, vocab_size)
+            batch_logits.append(sequence_logits)
+
+        batch_logits = torch.stack(batch_logits, dim=0)  # (B, T_txt, vocab_size)
+        # The final logits shape should be the same as the original input_ids shape
+        assert batch_logits.shape[:2] == (B, T_txt)
+
+        # assemble the output
+        output = VLMOutputWithPast(
+            loss=output.loss,
+            logits=batch_logits,
+            past_key_values=output.past_key_values,
+            hidden_states=output.hidden_states,
+            attentions=output.attentions,
+            past_media_locations=updated_media_locations,
+            past_vision_tokens=updated_vision_tokens,
+        )
+
+        return output
+
+    def _post_forward_hook(self):
+        pass
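+
+    # Worked example (illustrative): with num_tokens_per_vis = 128, a prompt of
+    # T_txt = 20 text positions containing one <image> token expands to
+    # 20 - 1 + 128 = 147 embedded positions inside the language model;
+    # _postprocess_outputs_from_forward collapses the logits back to 20 positions,
+    # keeping only the logit at the first vision-token slot for the <image> position.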
+
+    @property
+    def num_params_per_module(self):
+        """Print the number of parameters per module in the model"""
+        return "\n".join(
+            [
+                f"Vision encoder: {num_params(self.vision_encoder):,} parameters",
+                f"Vision tokenizer: {num_params(self.vision_tokenizer):,} parameters",
+                f"Language model: {num_params(self.lang_model):,} parameters",
+            ]
+        )
+
+    @property
+    def num_trainable_params_per_module(self):
+        """Print the number of trainable parameters per module in the model"""
+        return "\n".join(
+            [
+                f"Vision encoder: {num_params(self.vision_encoder, filter_to_trainable=True):,} trainable parameters",
+                f"Vision tokenizer: {num_params(self.vision_tokenizer, filter_to_trainable=True):,} trainable parameters",
+                f"Language model: {num_params(self.lang_model, filter_to_trainable=True):,} trainable parameters",
+            ]
+        )
+
+
+class XGenMMPerceiver(VLMWithLanguageStream):
+    def __init__(
+        self,
+        vision_encoder: nn.Module,
+        vision_tokenizer: nn.Module,
+        lang_model: nn.Module,
+        initial_tokenizer_len: int,
+        pad_token_id: int,
+        decoder_layers_attr_name: str = None,
+        gradient_checkpointing: bool = False,
+        image_aspect_ratio: str = "anyres",
+        anyres_patch_sampling: bool = True,
+        anyres_grids: list[int] = None,
+    ):
+        """
+        Args:
+            vision_encoder (nn.Module): HF CLIPModel
+            lang_encoder (nn.Module): HF causal language model
+            vis_feature_dim (int): final dimension of the visual features output by the vision_encoder
+            initial_tokenizer_len (int): size of the tokenizer vocab
+            padding_token_id (int): id of the padding token. None if no padding token; then a padding token
+                will be inserted into self.special_tokens, which factory.py fills after creating new tokens
+            decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
+            gradient_checkpointing (bool, optional): whether to use gradient checkpointing. Defaults to False.
+        """
+        self._special_tokens = {
+            "media_token": "<image>",
+            "image_placeholder_token": "<image placeholder>",
+            "end_of_trunk_token": "<|endofchunk|>",
+        }
+        lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
+        super().__init__(
+            vision_encoder=vision_encoder,
+            vision_tokenizer=vision_tokenizer,
+            lang_model=lang_model,
+            initial_tokenizer_len=initial_tokenizer_len,
+            gradient_checkpointing=gradient_checkpointing,
+            decoder_layers_attr_name=decoder_layers_attr_name,
+            pad_token_id=pad_token_id,
+        )
+        self.image_aspect_ratio = image_aspect_ratio
+        self.anyres_patch_sampling = anyres_patch_sampling
+        self.anyres_grids = anyres_grids
+
+    def set_trainable(self):
+        """
+        Unfreeze everything except the vision_encoder
+        """
+        self.requires_grad_(True)
+        self.vision_encoder.requires_grad_(False)
+
+    def _should_apply_weight_decay(self, parameter_name):
+        """
+        Kosmos applies 0.01 weight decay to everything
+        """
+        return True
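+
+    # Usage note (illustrative, not enforced by this class): prompts are expected to
+    # mark image positions with the "<image>" special token registered above; during
+    # generate() each such position is replaced by that image's resampled vision
+    # tokens before the sequence is handed to the language model.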
+
+    def generate(
+        self,
+        vision_x: torch.Tensor,
+        lang_x: torch.Tensor,
+        image_size: Optional[Tuple] = None,
+        attention_mask: torch.Tensor = None,
+        past_key_values: Optional[
+            List[Union[torch.Tensor, Tuple[torch.Tensor]]]
+        ] = None,
+        past_media_locations: Optional[torch.Tensor] = None,
+        past_vision_tokens: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        """
+        Generate text conditioned on vision and language inputs.
+        Args:
+            vision_x (torch.Tensor): Vision input
+                shape (B, T_img, F, C, H, W)
+                see documentation for forward
+            lang_x (torch.Tensor): Language input
+                shape (B, T_txt)
+            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
+            **kwargs: see generate documentation in Hugging Face CausalLM models.
+        Returns:
+            torch.Tensor: lang_x with generated tokens appended to it
+        """
+        num_beams = kwargs.pop("num_beams", 1)
+
+        # convert pixels to vision tokens
+        vision_attention_mask = None
+        if vision_x is not None:
+            if self.image_aspect_ratio == "anyres":
+                input_dict = dict(image=vision_x, image_size=image_size)
+                vision_features, vision_attn_masks = self._encode_vision_x_anyres(
+                    input_dict, lang_x.device
+                )
+            else:
+                vision_features = self._encode_vision_x(vision_x=vision_x)
+                vision_attn_masks = None
+            # If doing patch sampling, then flatten patches of shape [b, Np_i, v, d] -> [b*Np, v, d]
+            # Same for attention masks: [b, Np, v] -> [b*Np, v]
+            if self.anyres_patch_sampling:
+                split_sizes = [feature.shape[0] for feature in vision_features]
+                # Nested splits for multi-image samples.
+                if isinstance(vision_x[0], list):
+                    nt_images = [len(images) for images in vision_x]
+                    split_split_sizes = []
+                    img_id = 0
+                    for nt in nt_images:
+                        split_split_sizes.append(split_sizes[img_id : img_id + nt])
+                        img_id += nt
+                else:
+                    nt_images = [1] * len(vision_x)
+                    split_split_sizes = split_sizes
+                vision_features = torch.cat(vision_features, dim=0)
+                vision_features = vision_features[
+                    :, None, None, :, :
+                ]  # Expand dimensions.
+                vision_attn_masks = torch.cat(vision_attn_masks, dim=0)
+
+            vision_tokens = self.vision_tokenizer(vision_features, vision_attn_masks)
+
+            # Post-processing: Split the batches into groups of patches and concatenate them together.
+            if self.anyres_patch_sampling:
+                assert isinstance(vision_x, list)
+                if isinstance(vision_x[0], list):
+                    vision_token_groups = torch.split(
+                        vision_tokens,
+                        list(sum(nt_img) for nt_img in split_split_sizes),
+                        dim=0,
+                    )
+                    vision_tokens = []
+
+                    for sample_id, patch_vis_tokens in enumerate(vision_token_groups):
+                        patch_vis_token_groups = torch.split(
+                            patch_vis_tokens, split_split_sizes[sample_id], dim=0
+                        )  # [Np*nt, 1, v, d] -> [[Np_t, 1, v, d], ...]
+                        flatten_vision_tokens = []
+                        for image_vis_token in patch_vis_token_groups:
+                            image_vis_token = image_vis_token.flatten(
+                                0, 2
+                            )  # [Np, 1, v, d] -> [Np*v, d]
+                            flatten_vision_tokens.append(image_vis_token)
+                        vision_tokens_i = flatten_vision_tokens
+                        vision_tokens.append(vision_tokens_i)
+                else:
+                    vision_token_groups = torch.split(vision_tokens, split_sizes, dim=0)
+                    vision_tokens = []
+                    for patch_vis_tokens in vision_token_groups:
+                        patch_vis_tokens = patch_vis_tokens.flatten(
+                            0, 2
+                        )  # [Np, 1, v, d] -> [Np*v, d]
+                        vision_tokens.append(
+                            patch_vis_tokens.unsqueeze(0)
+                        )  # Add the nt dimension.
+        else:
+            vision_tokens = None
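+
+        # Illustrative example (assumed sizes): with anyres_patch_sampling, an image
+        # cut into a base view plus a 2x2 grid gives Np = 5 patches; after the
+        # resampler (v tokens per patch, e.g. 128) and the flatten above, that image
+        # contributes a single [1, Np*v, d] = [1, 640, d] block of vision tokens.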
+
+        # fuse the vision and language tokens
+        # for xattn, vision_x and media_location are repeat_interleaved s.t.
+        # the total batch size is B * num_beams
+        new_inputs = self._prepare_inputs_for_forward(
+            vision_tokens=vision_tokens,
+            lang_x=lang_x,
+            attention_mask=attention_mask,
+            vision_attention_mask=vision_attention_mask,
+            past_key_values=past_key_values,
+            past_media_locations=past_media_locations,
+            past_vision_tokens=past_vision_tokens,
+            padding_side="left",
+            num_beams=num_beams,
+        )
+        if past_key_values is not None:
+            output = self.lang_model.generate(
+                **new_inputs,
+                past_key_values=past_key_values,
+                num_beams=num_beams,
+                use_cache=True,
+                **kwargs,
+            )
+        else:
+            output = self.lang_model.generate(
+                **new_inputs,
+                num_beams=num_beams,
+                use_cache=True,
+                **kwargs,
+            )
+        self._post_forward_hook()
+        return output
+
+
+class XGenMMVisionEncoder(PreTrainedModel):
+    main_input_name = "pixel_values"
+    config_class = XGenMMVisionEncoderConfig
+
+    def __init__(self, config: XGenMMVisionEncoderConfig):
+        super().__init__(config)
+        if config.model_name != "google/siglip-so400m-patch14-384":
+            raise ValueError(
+                f"Unsupported model {config.model_name}. New vision models will be added soon."
+            )
+        self.model = AutoModel.from_pretrained(config.model_name)
+
+    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+        # assert pixel_values.ndim == 4, f"Expected 4D tensor (bs, c, h, w), got {pixel_values.ndim}"
+        return self.model.encode_image(pixel_values)
+
+
+# vision tokenizer
+class XGenMMVisionTokenizer(PreTrainedModel):
+    config_class = XGenMMVisionTokenizerConfig
+
+    def __init__(self, config: XGenMMVisionTokenizerConfig):
+        super().__init__(config)
+        self.model = PerceiverResampler(
+            dim=config.vis_feature_dim,
+            dim_inner=config.lang_embedding_dim,
+            # TODO: hardwiring for now...
+            num_latents=128,
+        )
+
+    def forward(self, vision_features: torch.Tensor, vision_attn_masks: torch.Tensor):
+        return self.model(vision_features, vision_attn_masks)
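+
+
+# Shape sketch (illustrative, assuming the PerceiverResampler defined earlier in this
+# file): the tokenizer maps encoder features of shape (b, T, F, v, vis_feature_dim)
+# plus an optional per-patch attention mask to num_latents = 128 output tokens per
+# image (or per anyres patch), projected to the language model width
+# (e.g. 3072 for Phi-3-mini).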
+
+
+# XGenMM model
+class XGenMMModelForConditionalGeneration(PreTrainedModel):
+    config_class = XGenMMConfig
+
+    def __init__(self, config: XGenMMConfig):
+        super().__init__(config)
+
+        # vision encoder initialization
+        vision_encoder = AutoModel.from_pretrained(
+            config.vision_encoder_config.model_name
+        ).vision_model
+
+        # language model initialization
+        language_model = AutoModelForCausalLM.from_config(config.text_config)
+        check_embedding_fns(language_model)
+        # Update _tied_weights_keys using the base model used.
+        if language_model._tied_weights_keys is not None:
+            self._tied_weights_keys = [
+                f"language_model.{k}" for k in language_model._tied_weights_keys
+            ]
+
+        # vision tokenizer initialization
+        if (
+            config.vision_tokenizer_config.lang_embedding_dim
+            != language_model.get_input_embeddings().weight.shape[1]
+        ):
+            overwrite = language_model.get_input_embeddings().weight.shape[1]
+            config.vision_tokenizer_config.lang_embedding_dim = overwrite
+            print(
+                f"Warning: The language embedding dimension in the vision tokenizer config is different from the language model's embedding dimension. Overwriting the language embedding dimension in the vision tokenizer config to {overwrite}."
+            )
+
+        vision_tokenizer = XGenMMVisionTokenizer(config.vision_tokenizer_config).model
+
+        self.vlm = XGenMMPerceiver(
+            vision_encoder=vision_encoder,
+            vision_tokenizer=vision_tokenizer,
+            lang_model=language_model,
+            initial_tokenizer_len=config.text_config.initial_tokenizer_len,
+            pad_token_id=config.text_config.pad_token_id,
+            image_aspect_ratio=config.vision_encoder_config.image_aspect_ratio,
+            anyres_patch_sampling=config.vision_encoder_config.anyres_patch_sampling,
+            anyres_grids=config.vision_encoder_config.anyres_grids,
+        )
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @torch.no_grad()
+    def generate(
+        self,
+        pixel_values: torch.FloatTensor,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        **generate_kwargs,
+    ) -> torch.LongTensor:
+        self.vlm = self.vlm.eval()
+        return self.vlm.generate(
+            vision_x=pixel_values,
+            lang_x=input_ids,
+            attention_mask=attention_mask,
+            **generate_kwargs,
+        )
+
+    def update_special_tokens(self, tokenizer):
+        tokenizer.add_special_tokens(
+            {"additional_special_tokens": list(self.vlm.special_tokens.values())}
+        )
+        self.vlm.lang_model.config.vocab_size = len(tokenizer)
+        self.vlm.set_special_token_ids(
+            {
+                v: tokenizer.convert_tokens_to_ids(v)
+                for v in self.vlm.special_tokens.values()
+            }
+        )
+        return tokenizer
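+
+
+# Usage sketch (illustrative only; the checkpoint id, processor classes, and prompt
+# format below are assumptions about a typical XGen-MM release, not guarantees made
+# by this file):
+#
+#     from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoImageProcessor
+#
+#     ckpt = "Salesforce/xgen-mm-phi3-mini-instruct-r-v1"  # hypothetical example id
+#     model = AutoModelForVision2Seq.from_pretrained(ckpt, trust_remote_code=True)
+#     tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
+#     tokenizer = model.update_special_tokens(tokenizer)
+#     image_processor = AutoImageProcessor.from_pretrained(ckpt, trust_remote_code=True)
+#
+#     pixel_values = image_processor([raw_image], return_tensors="pt")["pixel_values"]
+#     inputs = tokenizer("<image> Describe this image.", return_tensors="pt")
+#     out = model.generate(
+#         pixel_values=pixel_values,
+#         image_size=[raw_image.size],  # likely required in "anyres" mode
+#         **inputs,
+#         max_new_tokens=64,
+#     )
+#     print(tokenizer.decode(out[0], skip_special_tokens=True))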