from __future__ import annotations from typing import Union from transformers import LlamaConfig, LlamaModel, LlamaForCausalLM from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.generation.utils import GenerateOutput from .configuration_m3d_lamed import LamedConfig from abc import ABC, abstractmethod from torch import Tensor import math from typing import Any, Dict, List import torch import torch.nn as nn from typing import Optional, Tuple, Type from monai.networks.blocks import PatchEmbed import numpy as np import torch.nn.functional as F from einops import rearrange from einops.layers.torch import Rearrange from collections.abc import Sequence from monai.networks.blocks.patchembedding import PatchEmbeddingBlock from monai.networks.blocks.transformerblock import TransformerBlock from monai.networks.nets import ViT class LayerNorm2d(nn.Module): def __init__(self, num_channels: int, eps: float = 1e-6) -> None: super().__init__() self.weight = nn.Parameter(torch.ones(num_channels)) self.bias = nn.Parameter(torch.zeros(num_channels)) self.eps = eps def forward(self, x: torch.Tensor) -> torch.Tensor: u = x.mean(1, keepdim=True) s = (x - u).pow(2).mean(1, keepdim=True) x = (x - u) / torch.sqrt(s + self.eps) x = self.weight[:, None, None] * x + self.bias[:, None, None] return x class MLPBlock(nn.Module): def __init__( self, embedding_dim: int, mlp_dim: int, act: Type[nn.Module] = nn.GELU, ) -> None: super().__init__() self.lin1 = nn.Linear(embedding_dim, mlp_dim) self.lin2 = nn.Linear(mlp_dim, embedding_dim) self.act = act() def forward(self, x: torch.Tensor) -> torch.Tensor: return self.lin2(self.act(self.lin1(x))) class TwoWayTransformer(nn.Module): def __init__( self, depth: int, embedding_dim: int, num_heads: int, mlp_dim: int, activation: Type[nn.Module] = nn.ReLU, attention_downsample_rate: int = 2, ) -> None: """ A transformer decoder that attends to an input image using queries whose positional embedding is supplied. Args: depth (int): number of layers in the transformer embedding_dim (int): the channel dimension for the input embeddings num_heads (int): the number of heads for multihead attention. Must divide embedding_dim mlp_dim (int): the channel dimension internal to the MLP block activation (nn.Module): the activation to use in the MLP block """ super().__init__() self.depth = depth self.embedding_dim = embedding_dim self.num_heads = num_heads self.mlp_dim = mlp_dim self.layers = nn.ModuleList() for i in range(depth): self.layers.append( TwoWayAttentionBlock( embedding_dim=embedding_dim, num_heads=num_heads, mlp_dim=mlp_dim, activation=activation, attention_downsample_rate=attention_downsample_rate, skip_first_layer_pe=(i == 0), ) ) self.final_attn_token_to_image = Attention( embedding_dim, num_heads, downsample_rate=attention_downsample_rate ) self.norm_final_attn = nn.LayerNorm(embedding_dim) def forward( self, image_embedding: Tensor, image_pe: Tensor, point_embedding: Tensor, ) -> Tuple[Tensor, Tensor]: """ Args: image_embedding (torch.Tensor): image to attend to. Should be shape B x embedding_dim x h x w for any h and w. image_pe (torch.Tensor): the positional encoding to add to the image. Must have the same shape as image_embedding. point_embedding (torch.Tensor): the embedding to add to the query points. Must have shape B x N_points x embedding_dim for any N_points. Returns: torch.Tensor: the processed point_embedding torch.Tensor: the processed image_embedding """ # BxCxHxW -> BxHWxC == B x N_image_tokens x C bs, c, h, w, d = image_embedding.shape image_embedding = image_embedding.flatten(2).permute(0, 2, 1) image_pe = image_pe.flatten(2).permute(0, 2, 1) # Prepare queries queries = point_embedding keys = image_embedding # Apply transformer blocks and final layernorm for layer in self.layers: queries, keys = layer( queries=queries, keys=keys, query_pe=point_embedding, key_pe=image_pe, ) # Apply the final attention layer from the points to the image q = queries + point_embedding k = keys + image_pe attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) queries = queries + attn_out queries = self.norm_final_attn(queries) return queries, keys class TwoWayAttentionBlock(nn.Module): def __init__( self, embedding_dim: int, num_heads: int, mlp_dim: int = 2048, activation: Type[nn.Module] = nn.ReLU, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False, ) -> None: """ A transformer block with four layers: (1) self-attention of sparse inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp block on sparse inputs, and (4) cross attention of dense inputs to sparse inputs. Arguments: embedding_dim (int): the channel dimension of the embeddings num_heads (int): the number of heads in the attention layers mlp_dim (int): the hidden dimension of the mlp block activation (nn.Module): the activation of the mlp block skip_first_layer_pe (bool): skip the PE on the first layer """ super().__init__() self.self_attn = Attention(embedding_dim, num_heads) self.norm1 = nn.LayerNorm(embedding_dim) self.cross_attn_token_to_image = Attention( embedding_dim, num_heads, downsample_rate=attention_downsample_rate ) self.norm2 = nn.LayerNorm(embedding_dim) self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) self.norm3 = nn.LayerNorm(embedding_dim) self.norm4 = nn.LayerNorm(embedding_dim) self.cross_attn_image_to_token = Attention( embedding_dim, num_heads, downsample_rate=attention_downsample_rate ) self.skip_first_layer_pe = skip_first_layer_pe def forward( self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor ) -> Tuple[Tensor, Tensor]: # Self attention block if self.skip_first_layer_pe: queries = self.self_attn(q=queries, k=queries, v=queries) else: q = queries + query_pe attn_out = self.self_attn(q=q, k=q, v=queries) queries = queries + attn_out queries = self.norm1(queries) # Cross attention block, tokens attending to image embedding q = queries + query_pe k = keys + key_pe attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) queries = queries + attn_out queries = self.norm2(queries) # MLP block mlp_out = self.mlp(queries) queries = queries + mlp_out queries = self.norm3(queries) # Cross attention block, image embedding attending to tokens q = queries + query_pe k = keys + key_pe attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) keys = keys + attn_out keys = self.norm4(keys) return queries, keys class Attention(nn.Module): """ An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and values. """ def __init__( self, embedding_dim: int, num_heads: int, downsample_rate: int = 1, ) -> None: super().__init__() self.embedding_dim = embedding_dim self.internal_dim = embedding_dim // downsample_rate self.num_heads = num_heads assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." self.q_proj = nn.Linear(embedding_dim, self.internal_dim) self.k_proj = nn.Linear(embedding_dim, self.internal_dim) self.v_proj = nn.Linear(embedding_dim, self.internal_dim) self.out_proj = nn.Linear(self.internal_dim, embedding_dim) def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: b, n, c = x.shape x = x.reshape(b, n, num_heads, c // num_heads) return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head def _recombine_heads(self, x: Tensor) -> Tensor: b, n_heads, n_tokens, c_per_head = x.shape x = x.transpose(1, 2) return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: # Input projections q = self.q_proj(q) k = self.k_proj(k) v = self.v_proj(v) # Separate into heads q = self._separate_heads(q, self.num_heads) k = self._separate_heads(k, self.num_heads) v = self._separate_heads(v, self.num_heads) # Attention _, _, _, c_per_head = q.shape attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens attn = attn / math.sqrt(c_per_head) attn = torch.softmax(attn, dim=-1) # Get output out = attn @ v out = self._recombine_heads(out) out = self.out_proj(out) return out # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa class ImageEncoderViT(nn.Module): def __init__( self, img_size: int = 1024, patch_size: int = 16, in_chans: int = 1, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4.0, out_chans: int = 256, qkv_bias: bool = True, norm_layer: Type[nn.Module] = nn.LayerNorm, act_layer: Type[nn.Module] = nn.GELU, use_abs_pos: bool = True, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, window_size: int = 0, global_attn_indexes: Tuple[int, ...] = (), ) -> None: """ Args: img_size (int): Input image size. patch_size (int): Patch size. in_chans (int): Number of input image channels. embed_dim (int): Patch embedding dimension. depth (int): Depth of ViT. num_heads (int): Number of attention heads in each ViT block. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool): If True, add a learnable bias to query, key, value. norm_layer (nn.Module): Normalization layer. act_layer (nn.Module): Activation layer. use_abs_pos (bool): If True, use absolute positional embeddings. use_rel_pos (bool): If True, add relative positional embeddings to the attention map. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. window_size (int): Window size for window attention blocks. global_attn_indexes (list): Indexes for blocks using global attention. """ super().__init__() self.img_size = img_size # self.patch_embed = PatchEmbed( # kernel_size=(patch_size, patch_size), # stride=(patch_size, patch_size), # in_chans=in_chans, # embed_dim=embed_dim, # ) self.patch_embed = PatchEmbed( patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, spatial_dims=3, ) self.pos_embed: Optional[nn.Parameter] = None if use_abs_pos: # Initialize absolute positional embedding with pretrain image size. self.pos_embed = nn.Parameter( torch.zeros(1, img_size // patch_size, img_size // patch_size, img_size // patch_size, embed_dim) ) self.blocks = nn.ModuleList() for i in range(depth): block = Block( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, norm_layer=norm_layer, act_layer=act_layer, use_rel_pos=use_rel_pos, rel_pos_zero_init=rel_pos_zero_init, window_size=window_size if i not in global_attn_indexes else 0, input_size=(img_size // patch_size, img_size // patch_size), ) self.blocks.append(block) self.neck = nn.Sequential( nn.Conv2d( embed_dim, out_chans, kernel_size=1, bias=False, ), LayerNorm2d(out_chans), nn.Conv2d( out_chans, out_chans, kernel_size=3, padding=1, bias=False, ), LayerNorm2d(out_chans), ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.patch_embed(x) print('patch embedded shape: ', x.shape) # embedded: [8, 768, 6, 6, 6] if self.pos_embed is not None: x = x + self.pos_embed for blk in self.blocks: x = blk(x) x = self.neck(x.permute(0, 3, 1, 2)) return x class Block(nn.Module): """Transformer blocks with support of window attention and residual propagation blocks""" def __init__( self, dim: int, num_heads: int, mlp_ratio: float = 4.0, qkv_bias: bool = True, norm_layer: Type[nn.Module] = nn.LayerNorm, act_layer: Type[nn.Module] = nn.GELU, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, window_size: int = 0, input_size: Optional[Tuple[int, int]] = None, ) -> None: """ Args: dim (int): Number of input channels. num_heads (int): Number of attention heads in each ViT block. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool): If True, add a learnable bias to query, key, value. norm_layer (nn.Module): Normalization layer. act_layer (nn.Module): Activation layer. use_rel_pos (bool): If True, add relative positional embeddings to the attention map. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. window_size (int): Window size for window attention blocks. If it equals 0, then use global attention. input_size (tuple(int, int) or None): Input resolution for calculating the relative positional parameter size. """ super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention2( dim, num_heads=num_heads, qkv_bias=qkv_bias, use_rel_pos=use_rel_pos, rel_pos_zero_init=rel_pos_zero_init, input_size=input_size if window_size == 0 else (window_size, window_size), ) self.norm2 = norm_layer(dim) self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) self.window_size = window_size def forward(self, x: torch.Tensor) -> torch.Tensor: shortcut = x x = self.norm1(x) # Window partition if self.window_size > 0: H, W = x.shape[1], x.shape[2] x, pad_hw = window_partition(x, self.window_size) x = self.attn(x) # Reverse window partition if self.window_size > 0: x = window_unpartition(x, self.window_size, pad_hw, (H, W)) x = shortcut + x x = x + self.mlp(self.norm2(x)) return x class Attention2(nn.Module): """Multi-head Attention block with relative position embeddings.""" def __init__( self, dim: int, num_heads: int = 8, qkv_bias: bool = True, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, input_size: Optional[Tuple[int, int]] = None, ) -> None: """ Args: dim (int): Number of input channels. num_heads (int): Number of attention heads. qkv_bias (bool): If True, add a learnable bias to query, key, value. rel_pos (bool): If True, add relative positional embeddings to the attention map. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. input_size (tuple(int, int) or None): Input resolution for calculating the relative positional parameter size. """ super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.proj = nn.Linear(dim, dim) self.use_rel_pos = use_rel_pos if self.use_rel_pos: assert ( input_size is not None ), "Input size must be provided if using relative positional encoding." # initialize relative positional embeddings self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: B, H, W, _ = x.shape # qkv with shape (3, B, nHead, H * W, C) qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # q, k, v with shape (B * nHead, H * W, C) q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) attn = (q * self.scale) @ k.transpose(-2, -1) if self.use_rel_pos: attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) attn = attn.softmax(dim=-1) x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) x = self.proj(x) return x def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: """ Partition into non-overlapping windows with padding if needed. Args: x (tensor): input tokens with [B, H, W, C]. window_size (int): window size. Returns: windows: windows after partition with [B * num_windows, window_size, window_size, C]. (Hp, Wp): padded height and width before partition """ B, H, W, C = x.shape pad_h = (window_size - H % window_size) % window_size pad_w = (window_size - W % window_size) % window_size if pad_h > 0 or pad_w > 0: x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) Hp, Wp = H + pad_h, W + pad_w x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) return windows, (Hp, Wp) def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]) -> torch.Tensor: """ Window unpartition into original sequences and removing padding. Args: windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. window_size (int): window size. pad_hw (Tuple): padded height and width (Hp, Wp). hw (Tuple): original height and width (H, W) before padding. Returns: x: unpartitioned sequences with [B, H, W, C]. """ Hp, Wp = pad_hw H, W = hw B = windows.shape[0] // (Hp * Wp // window_size // window_size) x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) if Hp > H or Wp > W: x = x[:, :H, :W, :].contiguous() return x def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: """ Get relative positional embeddings according to the relative positions of query and key sizes. Args: q_size (int): size of query q. k_size (int): size of key k. rel_pos (Tensor): relative position embeddings (L, C). Returns: Extracted positional embeddings according to relative positions. """ max_rel_dist = int(2 * max(q_size, k_size) - 1) # Interpolate rel pos if needed. if rel_pos.shape[0] != max_rel_dist: # Interpolate rel pos. rel_pos_resized = F.interpolate( rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), size=max_rel_dist, mode="linear", ) rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) else: rel_pos_resized = rel_pos # Scale the coords with short length if shapes for q and k are different. q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) return rel_pos_resized[relative_coords.long()] def add_decomposed_rel_pos( attn: torch.Tensor, q: torch.Tensor, rel_pos_h: torch.Tensor, rel_pos_w: torch.Tensor, q_size: Tuple[int, int], k_size: Tuple[int, int], ) -> torch.Tensor: """ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 Args: attn (Tensor): attention map. q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. q_size (Tuple): spatial sequence size of query q with (q_h, q_w). k_size (Tuple): spatial sequence size of key k with (k_h, k_w). Returns: attn (Tensor): attention map with added relative positional embeddings. """ q_h, q_w = q_size k_h, k_w = k_size Rh = get_rel_pos(q_h, k_h, rel_pos_h) Rw = get_rel_pos(q_w, k_w, rel_pos_w) B, _, dim = q.shape r_q = q.reshape(B, q_h, q_w, dim) rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) attn = ( attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] ).view(B, q_h * q_w, k_h * k_w) return attn class IdentityMap(nn.Module): def __init__(self): super().__init__() def forward(self, x, *args, **kwargs): return x @property def config(self): return {"mm_projector_type": 'identity'} class SpatialPoolingProjector(nn.Module): def __init__(self, image_size, patch_size, in_dim, out_dim, layer_type, layer_num, pooling_type='spatial', pooling_size=2): super().__init__() self.in_dim = in_dim self.pooling_size = pooling_size self.num_patches_pre = [img // pch for img, pch in zip(image_size, patch_size)] self.num_patches_post = [num // pooling_size for num in self.num_patches_pre] if layer_type == 'linear': depth = int(layer_num) modules = [nn.Linear(in_dim, out_dim)] for _ in range(1, depth): modules.append(nn.Linear(out_dim, out_dim)) self.projector = nn.Sequential(*modules) elif layer_type == 'mlp': depth = int(layer_num) modules = [nn.Linear(in_dim, out_dim)] for _ in range(1, depth): modules.append(nn.GELU()) modules.append(nn.Linear(out_dim, out_dim)) self.projector = nn.Sequential(*modules) else: print("Projector error!") self.pooling_type = pooling_type def forward(self, x): B = x.shape[0] # B*N*D if self.pooling_type == 'spatial': to_3d = Rearrange("b (p1 p2 p3) d -> b d p1 p2 p3", b=B, d=self.in_dim, p1=self.num_patches_pre[0], p2=self.num_patches_pre[1], p3=self.num_patches_pre[2]) x = to_3d(x) x = F.avg_pool3d(x, kernel_size=self.pooling_size, stride=self.pooling_size) to_seq = Rearrange("b d p1 p2 p3 -> b (p1 p2 p3) d", b=B, d=self.in_dim, p1=self.num_patches_post[0], p2=self.num_patches_post[1], p3=self.num_patches_post[2]) x = to_seq(x) elif self.pooling_type == 'sequence': x = x.permute(0, 2, 1) #b d n x = F.avg_pool1d(x, kernel_size=self.pooling_size**3, stride=self.pooling_size**3) x = x.permute(0, 2, 1) #b n d x = rearrange(x, "b n d -> (b n) d") x = self.projector(x) x = rearrange(x, "(b n) d -> b n d", b=B) return x @property def proj_out_num(self): num = 1 for n in self.num_patches_post: num *= n return num class Minigpt(nn.Module): def __init__(self, config=None): super(Minigpt, self).__init__() # c*4 is the input size, and c is the output size for the linear layer inc, ouc = config.mm_hidden_size, config.hidden_size self.linear = nn.Linear(inc * 4, ouc) def forward(self, x): # x is the input tensor with shape [b, num_tokens, c] b, num_tokens, c = x.shape # Check if num_tokens is divisible by 4 if num_tokens % 4 != 0: raise ValueError("num_tokens must be divisible by 4") # Reshape x to [b, num_tokens/4, c*4] x = x.view(b, num_tokens // 4, c * 4) # Apply the linear transformation x = self.linear(x) return x class Vanilla(nn.Module): def __init__(self, config=None): super(Vanilla, self).__init__() # c*4 is the input size, and c is the output size for the linear layer inc, ouc = config.mm_hidden_size, config.hidden_size self.linear = nn.Linear(inc * 4, ouc) def forward(self, x): b, num_tokens, c = x.shape # Check if num_tokens is divisible by 4 if num_tokens % 4 != 0: raise ValueError("num_tokens must be divisible by 4") # First, reshape to [b, num_tokens//4, 4, c] x = x.view(b, num_tokens // 4, 4, c) # Then, permute to interleave the tokens x = x.permute(0, 1, 3, 2).contiguous() # Finally, reshape to [b, num_tokens//4, c*4] to interleave features of 4 tokens x = x.view(b, num_tokens // 4, c * 4) # Apply the linear transformation x = self.linear(x) return x def build_mm_projector(config, delay_load=False, **kwargs): projector_type = getattr(config, 'mm_projector_type') if projector_type == 'linear': return nn.Linear(config.mm_hidden_size, config.hidden_size) elif projector_type == 'spp': return SpatialPoolingProjector(image_size=config.image_size, patch_size=config.patch_size, in_dim=config.mm_hidden_size, out_dim=config.hidden_size, layer_type=config.proj_layer_type, layer_num=config.proj_layer_num, pooling_type=config.proj_pooling_type, pooling_size=config.proj_pooling_size) elif projector_type == 'identity': return IdentityMap() else: raise ValueError(f'Unknown projector type: {projector_type}') class myViT(nn.Module): """ Vision Transformer (ViT), based on: "Dosovitskiy et al., An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " ViT supports Torchscript but only works for Pytorch after 1.8. """ def __init__( self, in_channels: int, img_size: Sequence[int] | int, patch_size: Sequence[int] | int, hidden_size: int = 768, mlp_dim: int = 3072, num_layers: int = 12, num_heads: int = 12, pos_embed: str = "conv", classification: bool = False, num_classes: int = 2, dropout_rate: float = 0.0, spatial_dims: int = 3, post_activation="Tanh", qkv_bias: bool = False, save_attn: bool = False, ) -> None: """ Args: in_channels (int): dimension of input channels. img_size (Union[Sequence[int], int]): dimension of input image. patch_size (Union[Sequence[int], int]): dimension of patch size. hidden_size (int, optional): dimension of hidden layer. Defaults to 768. mlp_dim (int, optional): dimension of feedforward layer. Defaults to 3072. num_layers (int, optional): number of transformer blocks. Defaults to 12. num_heads (int, optional): number of attention heads. Defaults to 12. pos_embed (str, optional): position embedding layer type. Defaults to "conv". classification (bool, optional): bool argument to determine if classification is used. Defaults to False. num_classes (int, optional): number of classes if classification is used. Defaults to 2. dropout_rate (float, optional): faction of the input units to drop. Defaults to 0.0. spatial_dims (int, optional): number of spatial dimensions. Defaults to 3. post_activation (str, optional): add a final acivation function to the classification head when `classification` is True. Default to "Tanh" for `nn.Tanh()`. Set to other values to remove this function. qkv_bias (bool, optional): apply bias to the qkv linear layer in self attention block. Defaults to False. save_attn (bool, optional): to make accessible the attention in self attention block. Defaults to False. Examples:: # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone >>> net = ViT(in_channels=1, img_size=(96,96,96), pos_embed='conv') # for 3-channel with image size of (128,128,128), 24 layers and classification backbone >>> net = ViT(in_channels=3, img_size=(128,128,128), pos_embed='conv', classification=True) # for 3-channel with image size of (224,224), 12 layers and classification backbone >>> net = ViT(in_channels=3, img_size=(224,224), pos_embed='conv', classification=True, spatial_dims=2) """ super().__init__() if not (0 <= dropout_rate <= 1): raise ValueError("dropout_rate should be between 0 and 1.") if hidden_size % num_heads != 0: raise ValueError("hidden_size should be divisible by num_heads.") self.hidden_size = hidden_size self.classification = classification self.patch_embedding = PatchEmbeddingBlock( in_channels=in_channels, img_size=img_size, patch_size=patch_size, hidden_size=hidden_size, num_heads=num_heads, pos_embed=pos_embed, dropout_rate=dropout_rate, spatial_dims=spatial_dims, ) self.blocks = nn.ModuleList( [ TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias, save_attn) for i in range(num_layers) ] ) self.norm = nn.LayerNorm(hidden_size) if self.classification: self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size)) # if post_activation == "Tanh": # self.classification_head = nn.Sequential(nn.Linear(hidden_size, num_classes), nn.Tanh()) # else: # self.classification_head = nn.Linear(hidden_size, num_classes) # type: ignore def forward(self, x): x = self.patch_embedding(x) if hasattr(self, "cls_token"): cls_token = self.cls_token.expand(x.shape[0], -1, -1) x = torch.cat((cls_token, x), dim=1) hidden_states_out = [] for blk in self.blocks: x = blk(x) hidden_states_out.append(x) x = self.norm(x) # if hasattr(self, "classification_head"): # x = self.classification_head(x[:, 0]) return x, hidden_states_out class ViT3DTower(nn.Module): def __init__(self, config): super().__init__() self.config = config self.select_layer = config.vision_select_layer self.select_feature = config.vision_select_feature self.vision_tower = myViT( in_channels=self.config.image_channel, img_size=self.config.image_size, patch_size=self.config.patch_size, pos_embed="perceptron", spatial_dims=len(self.config.patch_size), classification=True, ) def forward(self, images): last_feature, hidden_states = self.vision_tower(images) if self.select_layer == -1: image_features = last_feature elif self.select_layer < -1: image_features = hidden_states[self.select_feature] else: raise ValueError(f'Unexpected select layer: {self.select_layer}') if self.select_feature == 'patch': image_features = image_features[:, 1:] elif self.select_feature == 'cls_patch': image_features = image_features else: raise ValueError(f'Unexpected select feature: {self.select_feature}') return image_features @property def dtype(self): return self.vision_tower.dtype @property def device(self): return self.vision_tower.device @property def hidden_size(self): return self.vision_tower.hidden_size def build_vision_tower(config, **kwargs): vision_tower = getattr(config, 'vision_tower', None) if 'vit3d' in vision_tower.lower(): return ViT3DTower(config, **kwargs) else: raise ValueError(f'Unknown vision tower: {vision_tower}') class LamedMetaModel: def __init__(self, config): super(LamedMetaModel, self).__init__(config) self.config = config if hasattr(config, "vision_tower"): self.vision_tower = build_vision_tower(config) self.mm_projector = build_mm_projector(config) def get_vision_tower(self): vision_tower = getattr(self, 'vision_tower', None) return vision_tower def initialize_vision_modules(self, model_args): self.config.image_channel = model_args.image_channel self.config.image_size = model_args.image_size self.config.patch_size = model_args.patch_size self.config.vision_tower = model_args.vision_tower self.config.vision_select_layer = model_args.vision_select_layer self.config.vision_select_feature = model_args.vision_select_feature self.config.mm_projector_type = model_args.mm_projector_type self.config.proj_layer_type = model_args.proj_layer_type self.config.proj_layer_num = model_args.proj_layer_num self.config.proj_pooling_type = model_args.proj_pooling_type self.config.proj_pooling_size = model_args.proj_pooling_size # vision tower if self.get_vision_tower() is None: self.vision_tower = build_vision_tower(self.config) # If you have a more robust vision encoder, try freezing the vision tower by requires_grad_(False) if model_args.pretrain_vision_model is not None: vision_model_weights = torch.load(model_args.pretrain_vision_model, map_location='cpu') self.vision_tower.vision_tower.load_state_dict(vision_model_weights, strict=True) self.config.mm_hidden_size = self.vision_tower.hidden_size # mm_projector if getattr(self, 'mm_projector', None) is None: self.mm_projector = build_mm_projector(self.config) if model_args.pretrain_mm_mlp_adapter is not None: mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu') def get_w(weights, keyword): return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'), strict=True) class LamedMetaForCausalLM(ABC): @abstractmethod def get_model(self): pass def get_vision_tower(self): return self.get_model().get_vision_tower() def encode_images(self, images): image_features = self.get_model().get_vision_tower()(images) image_features = self.get_model().mm_projector(image_features) return image_features def prepare_inputs_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images): vision_tower = self.get_vision_tower() if vision_tower is None or images is None or input_ids.shape[1] == 1: return input_ids, position_ids, attention_mask, past_key_values, None, labels else: image_features = self.encode_images(images) inputs_embeds = self.get_model().embed_tokens(input_ids) inputs_embeds = torch.cat((inputs_embeds[:, :1, :], image_features, inputs_embeds[:, (image_features.shape[1] + 1):, :]), dim=1) return None, position_ids, attention_mask, past_key_values, inputs_embeds, labels def initialize_vision_tokenizer(self, model_args, tokenizer): num_new_tokens = model_args.num_new_tokens self.resize_token_embeddings(len(tokenizer)) if num_new_tokens > 0: input_embeddings = self.get_input_embeddings().weight.data output_embeddings = self.get_output_embeddings().weight.data input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True) output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True) input_embeddings[-num_new_tokens:] = input_embeddings_avg output_embeddings[-num_new_tokens:] = output_embeddings_avg if model_args.tune_mm_mlp_adapter: for p in self.get_input_embeddings().parameters(): p.requires_grad = True for p in self.get_output_embeddings().parameters(): p.requires_grad = False else: # we add 4 new tokens # if new tokens need input, please train input_embeddings for p in self.get_input_embeddings().parameters(): p.requires_grad = True # if new tokens need predict, please train output_embeddings for p in self.get_output_embeddings().parameters(): p.requires_grad = True if model_args.pretrain_mm_mlp_adapter: mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu') embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight'] if input_embeddings.shape == embed_tokens_weight.shape: input_embeddings = embed_tokens_weight elif embed_tokens_weight.shape[0] == num_new_tokens: input_embeddings[-num_new_tokens:] = embed_tokens_weight else: raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.") class LamedLlamaModel(LamedMetaModel, LlamaModel): config_class = LamedConfig def __init__(self, config: LlamaConfig): super(LamedLlamaModel, self).__init__(config) class LamedLlamaForCausalLM(LamedMetaForCausalLM, LlamaForCausalLM): config_class = LamedConfig def __init__(self, config): super(LlamaForCausalLM, self).__init__(config) self.model = LamedLlamaModel(config) self.pretraining_tp = config.pretraining_tp self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() def get_model(self): return self.model def forward( self, images: Optional[torch.FloatTensor] = None, input_ids: torch.LongTensor = None, labels: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: input_ids_pre = input_ids if inputs_embeds is None: ( input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels ) = self.prepare_inputs_for_multimodal( input_ids, position_ids, attention_mask, past_key_values, labels, images, ) return super().forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict ) @torch.no_grad() def generate( self, images: Optional[torch.Tensor] = None, inputs: Optional[torch.Tensor] = None, **kwargs, ) -> Union[GenerateOutput, torch.LongTensor, Any]: position_ids = kwargs.pop("position_ids", None) attention_mask = kwargs.pop("attention_mask", None) if "inputs_embeds" in kwargs: raise NotImplementedError("`inputs_embeds` is not supported") if images is not None: ( inputs, position_ids, attention_mask, _, inputs_embeds, _ ) = self.prepare_inputs_for_multimodal( inputs, position_ids, attention_mask, None, None, images, ) else: inputs_embeds = self.get_model().embed_tokens(inputs) return super().generate( position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs ) def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): images = kwargs.pop("images", None) inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) if images is not None: inputs['images'] = images return inputs