# coding=utf-8 # [Apache-2.0] Modified from https://github.com/OFA-Sys/OFA """ PyTorch TiO model.""" import math import random from typing import Optional, Tuple from dataclasses import dataclass import torch from torch import nn from torch.nn import functional as F from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.activations import ACT2FN from transformers import PreTrainedModel # from ...activations import ACT2FN # from ...file_utils import ( # add_code_sample_docstrings, # add_end_docstrings, # add_start_docstrings, # add_start_docstrings_to_model_forward, # replace_return_docstrings, # ) # from ...file_utils import ModelOutput # from ...modeling_outputs import ( # BaseModelOutputWithPastAndCrossAttentions, # Seq2SeqLMOutput, # Seq2SeqModelOutput, # ) # from ...modeling_utils import PreTrainedModel # from ...utils import logging from transformers.utils import logging, ModelOutput from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput from .configuration_tio import TiOConfig from .resnet import ResNet from torch import Tensor from typing import Dict, List, Optional, Tuple logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "TiOConfig" _TOKENIZER_FOR_DOC = "TiOTokenizer" DEFAULT_MAX_SOURCE_POSITIONS = 1024 DEFAULT_MAX_TARGET_POSITIONS = 1024 DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8) try: from apex.normalization import FusedLayerNorm as _FusedLayerNorm has_fused_layernorm = True class FusedLayerNorm(_FusedLayerNorm): @torch.jit.unused def forward(self, x): if not x.is_cuda: return super().forward(x) else: with torch.cuda.device(x.device): return super().forward(x) except ImportError: has_fused_layernorm = False def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False): r""" Layer normalization. If apex is available, use `FusedLayerNorm` instead. """ if torch.jit.is_scripting(): export = True if not export and torch.cuda.is_available() and has_fused_layernorm: return FusedLayerNorm(normalized_shape, eps, elementwise_affine) return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) def make_token_bucket_position(bucket_size, max_position=DEFAULT_MAX_SOURCE_POSITIONS): r""" Make relative position indices for the text. """ context_pos = torch.arange(max_position, dtype=torch.long)[:, None] memory_pos = torch.arange(max_position, dtype=torch.long)[None, :] relative_pos = context_pos - memory_pos sign = torch.sign(relative_pos) mid = bucket_size // 2 abs_pos = torch.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, torch.abs(relative_pos)) # import pdb; pdb.set_trace() log_pos = torch.ceil(torch.log(abs_pos / mid) / math.log((max_position - 1) / mid) * (mid - 1)) + mid log_pos = log_pos.int() # import numpy as np # log_pos = np.ceil(np.log(abs_pos.cpu().numpy() / mid) / math.log((max_position - 1) / mid) * (mid - 1)) + mid # log_pos = torch.tensor(log_pos.astype('int64')) # log_pos = torch.LongTensor(log_pos.astype('int64'), device=abs_pos.device) bucket_pos = torch.where(abs_pos.le(mid), relative_pos, log_pos * sign).long() return bucket_pos + bucket_size - 1 def make_image_bucket_position(bucket_size, num_relative_distance): r""" Make relative position indices for the image. """ coords_h = torch.arange(bucket_size) coords_w = torch.arange(bucket_size) coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += bucket_size - 1 # shift to start from 0 relative_coords[:, :, 1] += bucket_size - 1 relative_coords[:, :, 0] *= 2 * bucket_size - 1 relative_position_index = torch.zeros(size=(bucket_size * bucket_size + 1,) * 2, dtype=relative_coords.dtype) relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww relative_position_index[0, 0:] = num_relative_distance - 3 relative_position_index[0:, 0] = num_relative_distance - 2 relative_position_index[0, 0] = num_relative_distance - 1 return relative_position_index def new_arange(x, *size): r""" Return a Tensor of `size` filled with a range function on the device of x. If size is empty, using the size of the variable x. """ if len(size) == 0: size = x.size() return torch.arange(size[-1], device=x.device).expand(*size).contiguous() def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): r""" Shift input ids one token to the right. """ shifted_input_ids = input_ids.new_zeros(input_ids.shape) shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() shifted_input_ids[:, 0] = decoder_start_token_id assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." # replace possible -100 values in labels by `pad_token_id` shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) return shifted_input_ids def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): r""" Make causal mask used for uni-directional self-attention. """ bsz, tgt_len = input_ids_shape mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min) mask_cond = torch.arange(mask.size(-1)) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask = mask.to(dtype) if past_key_values_length > 0: mask = torch.cat([torch.ones(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1) return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): r""" Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ bsz, src_len = mask.size() tgt_len = tgt_len if tgt_len is not None else src_len expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) inverted_mask = 1.0 - expanded_mask return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) def Embedding(num_embeddings, embedding_dim, padding_idx=None, zero_init=False): r""" Embedding for tokens """ m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5) if padding_idx is not None: nn.init.constant_(m.weight[padding_idx], 0) if zero_init: nn.init.constant_(m.weight, 0) return m def Linear(in_features, out_features, bias=True): r""" Implementation of linear projection with xavier initialization """ m = nn.Linear(in_features, out_features, bias) nn.init.xavier_uniform_(m.weight) if bias: nn.init.constant_(m.bias, 0.0) return m class LayerDropModuleList(nn.ModuleList): r""" A LayerDrop implementation based on :class:`torch.nn.ModuleList`. Args: p (float): probability of dropping out each layer modules (iterable, optional): an iterable of modules to add """ def __init__(self, p, modules=None): super().__init__(modules) self.p = p def __iter__(self): dropout_probs = torch.empty(len(self)).uniform_() for i, m in enumerate(super().__iter__()): if not self.training or (dropout_probs[i] > self.p): yield m def drop_path(x, drop_prob: float = 0.0, training: bool = False): r""" Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). Args: x (`nn.Modules`): input nn layers. drop_prob (`float`): drop path ratio. training (`bool`): whether is training or inference. """ if drop_prob == 0.0 or not training: return x keep_prob = 1 - drop_prob shape = (1, x.shape[1], 1) random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) random_tensor.floor_() # binarize output = x.div(keep_prob) * random_tensor return output class DropPath(nn.Module): r""" Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). Args: drop_prob: drop path ratio. """ def __init__(self, drop_prob=None): super().__init__() self.drop_prob = drop_prob def forward(self, x): return drop_path(x, self.drop_prob, self.training) def extra_repr(self) -> str: return "p={}".format(self.drop_prob) class TiOAttention(nn.Module): r""" Multi-headed attention, with additional implementation for NormFormer. Args: embed_dim (`int`): embedding dimension. num_heads (`int`): the number of attention heads. dropout (`float32`): the ratio for dropout. is_decoder (`bool`): whether or not decoder attention. bias (`bool`): whether to add bias. scale_heads (`bool`): whether to learn scaling heads, only for Normformer. """ def __init__( self, embed_dim: int, num_heads: int, dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, scale_heads: bool = True, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads assert ( self.head_dim * num_heads == self.embed_dim ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})." scale_factor=2 self.scaling = float(self.head_dim * scale_factor) ** -0.5 self.is_decoder = is_decoder self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.attn_dropout = nn.Dropout(p=dropout) self.c_attn = nn.Parameter(torch.ones((self.num_heads,)), requires_grad=True) if scale_heads else None def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): r""" Reshape tensors for multi-head attention. """ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, attn_bias: Optional[torch.Tensor] = None, ): r""" Args: hidden_states (`torch.FloatTensor` of shape `(bsz, tgt_len, embed_dim)`)`: input states. key_value_states (`torch.FloatTensor` of shape (bsz, tgt_len, embed_dim), *optional*): key value states. past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key value states for fast inference. attention_mask (`torch.FloatTensor` of shape `(bsz, 1, tgt_len, seq_len)`, *optional*): attention mask. output_attentions (`bool`, *optional*): whether to output attention weights of all layers. attn_bias (`torch.FloatTensor` of shape `(bsz, 1, tgt_len, src_len)`, *optional*): the attention bias for positional information. Returns: attn_output (`torch.FloatTensor` of shape `(bsz, tgt_len, embed_dim)`): attention outputs. attn_weights_reshaped (`torch.FloatTensor`, *optional*): attention weights of all layers. past_key_value (`torch.FloatTensor`, *optional*): cached key value states for fast inference. """ # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None bsz, tgt_len, embed_dim = hidden_states.size() # get query proj query_states = self.q_proj(hidden_states) * self.scaling # get key, value proj if is_cross_attention and past_key_value is not None: # reuse k,v, cross_attentions key_states = past_key_value[0] value_states = past_key_value[1] elif is_cross_attention: # cross_attentions key_states = self._shape(self.k_proj(key_value_states), -1, bsz) value_states = self._shape(self.v_proj(key_value_states), -1, bsz) elif past_key_value is not None: # reuse k, v, self_attention key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) if self.is_decoder: past_key_value = (key_states, value_states) proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) key_states = key_states.view(*proj_shape) value_states = value_states.view(*proj_shape) src_len = key_states.size(1) attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): raise ValueError( f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" ) # Add attention bias for positional information if attn_bias is not None: attn_weights += attn_bias if attention_mask is not None: if attention_mask.size() != (bsz, 1, tgt_len, src_len): raise ValueError( f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" ) attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) attn_weights = F.softmax(attn_weights, dim=-1) if output_attentions: attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) else: attn_weights_reshaped = None attn_probs = self.attn_dropout(attn_weights) attn_output = torch.bmm(attn_probs, value_states) if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}" ) attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) if self.c_attn is not None: attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) attn_output = torch.einsum("bthd,h->bthd", attn_output, self.c_attn) attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) return attn_output, attn_weights_reshaped, past_key_value class TiOEncoderLayer(nn.Module): r""" TiO encoder layer implementation. Args: config: configuration for TiO. drop_path_rate: the ratio for drop path. """ def __init__(self, config: TiOConfig, drop_path_rate=0.0): super().__init__() self.embed_dim = config.d_model self.self_attn = TiOAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, ) self.self_attn_layer_norm = LayerNorm(self.embed_dim) self.self_attn_mid_layer_norm = LayerNorm(self.embed_dim) if config.normformer else None self.dropout = nn.Dropout(config.dropout) self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = nn.Dropout(config.activation_dropout) self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) self.ffn_layer_norm = LayerNorm(config.encoder_ffn_dim) if config.normformer else None self.final_layer_norm = LayerNorm(self.embed_dim) self.normalize_before = config.encoder_normalize_before self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() def residual_connection(self, x, residual): r""" Residual connection with drop path. """ return residual + self.drop_path(x) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool = False, attn_bias: Optional[torch.Tensor] = None, ): r""" Args: hidden_states (`torch.FloatTensor`): input to the layer of shape *(bsz, src_len, embed_dim)* attention_mask (`torch.FloatTensor`): attention mask of size *(bsz, 1, src_len, src_len)* where padding elements are indicated by very large negative values. output_attentions (`bool`, *optional*): whether to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. attn_bias (`torch.FloatTensor`): bias for positional information. Returns: outputs (`tuple(torch.FloatTensor)`): output hidden states of size (bsz, src_len, embed_dim), optionally with attention weights. """ residual = hidden_states if self.normalize_before: hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states, attn_weights, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, attn_bias=attn_bias, ) if self.self_attn_mid_layer_norm: hidden_states = self.self_attn_mid_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.residual_connection(hidden_states, residual) if not self.normalize_before: hidden_states = self.self_attn_layer_norm(hidden_states) residual = hidden_states if self.normalize_before: hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = self.activation_dropout(hidden_states) if self.ffn_layer_norm: hidden_states = self.ffn_layer_norm(hidden_states) hidden_states = self.fc2(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.residual_connection(hidden_states, residual) if not self.normalize_before: hidden_states = self.final_layer_norm(hidden_states) if hidden_states.dtype == torch.float16 and ( torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() ): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) outputs = (hidden_states,) if output_attentions: outputs += (attn_weights,) return outputs class TiODecoderLayer(nn.Module): r""" TiO decoder layer implementation. Args: config: configuration for TiO. drop_path_rate: the ratio for drop path. """ def __init__(self, config: TiOConfig, drop_path_rate=0.0): super().__init__() self.embed_dim = config.d_model self.self_attn = TiOAttention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, ) self.dropout = nn.Dropout(p=config.dropout) self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = nn.Dropout(p=config.activation_dropout) self.self_attn_layer_norm = LayerNorm(self.embed_dim) self.self_attn_mid_layer_norm = LayerNorm(self.embed_dim) if config.normformer else None self.cross_attn = TiOAttention( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, ) self.cross_attn_layer_norm = LayerNorm(self.embed_dim) self.cross_attn_mid_layer_norm = LayerNorm(self.embed_dim) if config.normformer else None self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) self.ffn_layer_norm = LayerNorm(config.decoder_ffn_dim) if config.normformer else None self.final_layer_norm = LayerNorm(self.embed_dim) self.normalize_before = config.decoder_normalize_before self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() def residual_connection(self, x, residual): r""" Residual connection with drop path. """ return residual + self.drop_path(x) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, self_attn_bias: Optional[torch.Tensor] = None, cross_attn_bias: Optional[torch.Tensor] = None, ): r""" Args: hidden_states (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): input to the layer. attention_mask (`torch.FloatTensor` of shape `(bsz, 1, tgt_len, src_len)`): attention mask where padding elements are indicated by very large negative values. encoder_hidden_states (`torch.FloatTensor` of shape `(batch, seq_len, embed_dim)`): cross attention input to the layer. encoder_attention_mask (`torch.FloatTensor` of shape `(bsz, 1, tgt_len, src_len)`): encoder attention mask where padding elements are indicated by very large negative values. past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`, *optional*): whether to return the attentions tensors of all attention layers. use_cache (`bool`, *optional*): whether to use cache self_attn_bias (`torch.FloatTensor`): self attention bias for positional information. cross_attn_bias (`torch.FloatTensor`): cross attention bias for positional information. """ # Self attention with intermediate layernorm residual = hidden_states if self.normalize_before: hidden_states = self.self_attn_layer_norm(hidden_states) self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None # add present self-attn cache to position 1,2 of present_key_value tuple hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, past_key_value=self_attn_past_key_value, attention_mask=attention_mask, output_attentions=output_attentions, attn_bias=self_attn_bias, ) if self.self_attn_mid_layer_norm: hidden_states = self.self_attn_mid_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.residual_connection(hidden_states, residual) if not self.normalize_before: hidden_states = self.self_attn_layer_norm(hidden_states) # Cross attention with intermediate layernorm cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states if self.normalize_before: hidden_states = self.cross_attn_layer_norm(hidden_states) # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None hidden_states, cross_attn_weights, cross_attn_present_key_value = self.cross_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, past_key_value=cross_attn_past_key_value, output_attentions=output_attentions, attn_bias=cross_attn_bias, ) if self.cross_attn_mid_layer_norm: hidden_states = self.cross_attn_mid_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.residual_connection(hidden_states, residual) if not self.normalize_before: hidden_states = self.cross_attn_layer_norm(hidden_states) # add cross-attn to positions 3,4 of present_key_value tuple present_key_value = present_key_value + cross_attn_present_key_value # FFN with intermediate layernorm residual = hidden_states if self.normalize_before: hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = self.activation_dropout(hidden_states) if self.ffn_layer_norm: hidden_states = self.ffn_layer_norm(hidden_states) hidden_states = self.fc2(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.residual_connection(hidden_states, residual) if not self.normalize_before: hidden_states = self.final_layer_norm(hidden_states) outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights, cross_attn_weights) if use_cache: outputs += (present_key_value,) return outputs class TiOPreTrainedModel(PreTrainedModel): r""" Base class TiO """ config_class = TiOConfig base_model_prefix = "model" supports_gradient_checkpointing = True def _init_weights(self, module): r""" Weight initialization which follows BERT. """ std = self.config.init_std if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() def _set_gradient_checkpointing(self, module, value=False): r""" Turn on the switch of gradient checkpointing. """ if isinstance(module, (TiODecoder, TiOEncoder)): module.gradient_checkpointing = value @dataclass class TiOEncoderOutput(ModelOutput): r""" Base class for TiO's outputs. Args: last_hidden_state (`torch.FloatTensor` of shape `(bsz, seq_len, hidden)`): Sequence of hidden-states at the output of the last layer of the model. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape `(bsz, seq_len, hidden)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(bsz, num_heads, seq_len, seq_len)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. position_embedding (`torch.FloatTensor` of shape `(bsz, seq_len, hidden)`): postional embeddings of the inputs. """ last_hidden_state: torch.FloatTensor = None padding_mask: torch.Tensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None position_embedding: Optional[torch.FloatTensor] = None TiO_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`~TiOConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ TiO_GENERATION_EXAMPLE = r""" Image captioning example: ```python >>> from PIL import Image >>> from torchvision import transforms >>> from transformers import TiOTokenizer, TiOForConditionalGeneration >>> mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] >>> resolution = 256 >>> patch_resize_transform = transforms.Compose([ lambda image: image.convert("RGB"), transforms.Resize((resolution, resolution), interpolation=Image.BICUBIC), transforms.ToTensor(), transforms.Normalize(mean=mean, std=std) ]) >>> model = TiOForConditionalGeneration.from_pretrained(ckpt_dir) >>> tokenizer = TiOTokenizer.from_pretrained(ckpt_dir) >>> txt = " what is the description of the image?" >>> inputs = tokenizer([txt], max_length=1024, return_tensors="pt")["input_ids"] >>> img = Image.open(path_to_image) >>> patch_img = patch_resize_transform(img).unsqueeze(0) >>> gen = model.generate(inputs, patch_img=patch_img, num_beams=4) >>> print(tokenizer.decode(gen, skip_special_tokens=True, clean_up_tokenization_spaces=False)) ``` """ TiO_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): indices of input sequence tokens in the vocabular, and padding will be ignored by default; indices can be obtained using [`~TiOTokenizer`]. patch_images (`torch.FloatTensor` of shape `(bsz, 3, height, width)`): the resized image, which are transformed by the default operations. patch_images_2 (`torch.FloatTensor` of shape `(bsz, 3, height, width)`): the second (if it exists) image. patch_masks (`torch.BoolTensor`): the patches to be masked. token_embeddings (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): token embeddings. sample_patch_num (`int`): the number of patches to sample. decoder_input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): indices of the sequence in the vocabulary. code_masks (`torch.Tensor` of shape `(bsz, seq_len)`): masks only for code generation. attention_mask (`torch.Tensor` of shape `(bsz, seq_len)`): attention mask for decoding. encoder_outputs (`TiOEncoderOutput`): encoder outputs with hidden states, positional embeddings, and padding masks. past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(bsz, num_heads, tgt_len, head_size)`) and 2 additional tensors of shape `(bsz, num_heads, src_len, head_size)`. use_cache (`bool`): whether to use cache for faster inference. output_attentions (`bool`): whether to output attention weights. output_hidden_states (`bool`): whether to output hidden states. return_dict (`bool`): unused. Keep it for generation only. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. """ class TiOEncoder(TiOPreTrainedModel): r""" TiO encoder consisting of layers of [`TiOEncoderLayer`]. Args: config: TiOConfig embed_tokens (`nn.Embedding`, *optional*): output embedding """ def __init__(self, config: TiOConfig, embed_tokens: Optional[nn.Embedding] = None): super().__init__(config) self.dropout = nn.Dropout(config.dropout) self.encoder_layerdrop = config.encoder_layerdrop embed_dim = config.d_model self.padding_idx = config.pad_token_id self.max_source_positions = config.max_position_embeddings self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 self.num_attention_heads = config.encoder_attention_heads if getattr(config, "layernorm_embedding", False): self.layernorm_embedding = LayerNorm(embed_dim) else: self.layernorm_embedding = None if embed_tokens is not None: self.embed_tokens = embed_tokens else: self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) if config.add_type_embedding: self.type_embedding = Embedding(2, embed_dim, padding_idx=None) else: self.type_embedding = None if config.resnet_type == "resnet18": self.embed_images = ResNet([2, 2, 2], drop_path_rate=config.resnet_drop_path_rate) elif config.resnet_type == "resnet34": self.embed_images = ResNet([3, 4, 6], drop_path_rate=config.resnet_drop_path_rate) elif config.resnet_type == "resnet50": self.embed_images = ResNet([3, 4, 6], drop_path_rate=config.resnet_drop_path_rate) elif config.resnet_type == "resnet101": self.embed_images = ResNet([3, 4, 23], drop_path_rate=config.resnet_drop_path_rate) elif config.resnet_type == "resnet152": self.embed_images = ResNet([3, 8, 36], drop_path_rate=config.resnet_drop_path_rate) else: raise NotImplementedError self.image_proj = Linear(1024, embed_dim) if config.resnet_model_path: resnet_state_dict = torch.load(config.resnet_model_path) self.embed_images.load_state_dict(resnet_state_dict) if config.patch_layernorm_embedding: self.patch_layernorm_embedding = LayerNorm(embed_dim) else: self.patch_layernorm_embedding = None self.embed_positions = Embedding(self.max_source_positions + 2, embed_dim) self.embed_image_positions = Embedding(config.image_bucket_size**2 + 1, embed_dim) self.pos_ln = LayerNorm(embed_dim) self.image_pos_ln = LayerNorm(embed_dim) self.pos_scaling = float(embed_dim / self.num_attention_heads * config.attn_scale_factor) ** -0.5 self.pos_q_linear = nn.Linear(embed_dim, embed_dim) self.pos_k_linear = nn.Linear(embed_dim, embed_dim) if self.encoder_layerdrop > 0.0: self.layers = LayerDropModuleList(p=self.encoder_layerdrop) else: self.layers = nn.ModuleList([]) dpr = [x.item() for x in torch.linspace(0, config.encoder_drop_path_rate, config.encoder_layers)] self.layers.extend( [TiOEncoderLayer(config, drop_path_rate=dpr[i]) for i in range(config.encoder_layers)] ) self.num_layers = len(self.layers) if config.encoder_normalize_before: self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None self.token_bucket_size = config.token_bucket_size token_num_rel_dis = 2 * config.token_bucket_size - 1 token_rp_bucket = make_token_bucket_position(config.token_bucket_size) self.token_rel_pos_table_list = nn.ModuleList( [Embedding(token_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in range(config.encoder_layers)] ) self.image_bucket_size = config.image_bucket_size image_num_rel_dis = (2 * config.image_bucket_size - 1) * (2 * config.image_bucket_size - 1) + 3 image_rp_bucket = make_image_bucket_position(config.image_bucket_size, image_num_rel_dis) self.image_rel_pos_table_list = nn.ModuleList( [Embedding(image_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in range(config.encoder_layers)] ) if config.layernorm_embedding: self.layernorm_embedding = LayerNorm(embed_dim) else: self.layernorm_embedding = None self.register_buffer("token_rp_bucket", token_rp_bucket) self.register_buffer("image_rp_bucket", image_rp_bucket) self.entangle_position_embedding = config.entangle_position_embedding self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): r""" Get the embedding weight. """ return self.embed_tokens def set_input_embeddings(self, value): r""" Set the weight of embedding with the given tensor. """ self.embed_tokens = value def get_rel_pos_bias(self, x, idx): r""" Get the relative positional bias of the text, for attention. """ seq_len = x.size(1) rp_bucket = self.token_rp_bucket[:seq_len, :seq_len] values = F.embedding(rp_bucket, self.token_rel_pos_table_list[idx].weight) values = values.unsqueeze(0).expand(x.size(0), -1, -1, -1) values = values.permute([0, 3, 1, 2]) return values.contiguous() def get_image_rel_pos_bias(self, image_position_ids, idx): r""" Get the relative positional bias of the image, for attention. """ bsz, seq_len = image_position_ids.shape rp_bucket_size = self.image_rp_bucket.size(1) rp_bucket = self.image_rp_bucket.unsqueeze(0).expand( bsz, rp_bucket_size, rp_bucket_size ).gather(1, image_position_ids[:, :, None].expand(bsz, seq_len, rp_bucket_size) ).gather(2, image_position_ids[:, None, :].expand(bsz, seq_len, seq_len)) values = F.embedding(rp_bucket, self.image_rel_pos_table_list[idx].weight) values = values.permute(0, 3, 1, 2) return values def get_patch_images_info(self, patch_images, sample_patch_num, device): r""" Get the basic information of the resized image. Args: patch_images (`torch.FloatTensor` of shape `(bsz, 3, height, width)`): the resized image. sample_patch_num (`int`): the number of patches to sample. If it is equal to -1, no sampling will be performed. device: GPU device. Returns: image_embed (`torch.FloatTensor` of shape `(bsz, h * w, hidden)`): the output of the visual encoder. image_num_patches (`int`, equal to `h * w`): the number of patches. image_padding_mask (`torch.BooleanTensor` of shape `(bsz, h*w)`): image padding mask. image_position_ids (`torch.LongTensor` of shape `(bsz, h*w)`): image position ids. image_pos_embed (`torch.FloatTensor` of shape (bsz, h*w, hidden)): the positional embedding. """ image_embed = self.embed_images(patch_images) h, w = image_embed.shape[-2:] image_num_patches = h * w image_padding_mask = patch_images.new_zeros((patch_images.size(0), image_num_patches)).bool() image_position_idx = torch.arange(w).unsqueeze(0).expand(h, w) + \ torch.arange(h).unsqueeze(1) * self.image_bucket_size + 1 image_position_idx = image_position_idx.view(-1).to(device) image_position_ids = image_position_idx[None, :].expand(patch_images.size(0), image_num_patches) image_embed = image_embed.flatten(2).transpose(1, 2) if sample_patch_num is not None: patch_orders = [ random.sample(range(image_num_patches), k=sample_patch_num) for _ in range(patch_images.size(0)) ] patch_orders = torch.LongTensor(patch_orders, device=device) image_embed = image_embed.gather( 1, patch_orders.unsqueeze(2).expand(-1, -1, image_embed.size(2)) ) image_num_patches = sample_patch_num image_padding_mask = image_padding_mask.gather(1, patch_orders) image_position_ids = image_position_ids.gather(1, patch_orders) image_pos_embed = self.embed_image_positions(image_position_ids) return image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed def forward_embedding( self, input_ids, image_embed: Optional[torch.Tensor] = None, image_embed_2: Optional[torch.Tensor] = None, token_embedding: Optional[torch.Tensor] = None, pos_embed: Optional[torch.Tensor] = None, image_pos_embed: Optional[torch.Tensor] = None, image_pos_embed_2: Optional[torch.Tensor] = None ): r""" Generate embeddings of both the image and the text. Actually since TiO unifies both unimodal and multimodal data, image inputs are optional. Args: input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): indices of the tokens in the vocabulary. image_embed (`torch.FloatTensor` of shape `(bsz, h*w, embed_dim)`, *optional*): image embeddings. image_embed_2 (`torch.FloatTensor` of shape `(bsz, h*w, embed_dim)`, *optional*): image embeddings of the second image (if it exists). token_embedding (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`, *optional*): input token embeddings to replace the embeddings of input ids. image_pos_embed (`torch.FloatTensor` of shape `(bsz, h*w, embed_dim)`, *optional*): positional embeddings of the image. image_pos_embed_2 (`torch.FloatTensor` of shape `(bsz, h*w, embed_dim)`, *optional*): positional embeddings of the second image. Returns: x (`torch.FloatTensor` of shape `(bsz, h*w+seq_len, embed_dim)`): embeddings of the input. embed (`torch.FloatTensor` of shape `(bsz, h*w+seq_len, embed_dim)`): embeddings without adding positional and type embeddings. """ # embed tokens and positions if token_embedding is None: token_embedding = self.embed_tokens(input_ids) x = embed = self.embed_scale * token_embedding if self.entangle_position_embedding and pos_embed is not None: x += pos_embed if self.type_embedding is not None: x += self.type_embedding(input_ids.new_zeros(x.size()[:2])) if self.layernorm_embedding is not None: x = self.layernorm_embedding(x) x = self.dropout(x) # embed raw images if image_embed is not None: image_embed = self.image_proj(image_embed) image_x = image_embed = self.embed_scale * image_embed if self.entangle_position_embedding and image_pos_embed is not None: image_x += image_pos_embed if self.type_embedding is not None: image_x += self.type_embedding(input_ids.new_ones(image_x.size()[:2])) if self.patch_layernorm_embedding is not None: image_x = self.patch_layernorm_embedding(image_x) image_x = self.dropout(image_x) x = torch.cat([image_x, x], dim=1) embed = torch.cat([image_embed, embed], dim=1) if image_embed_2 is not None: assert self.type_embedding is not None image_embed_2 = self.image_proj(image_embed_2) image_x_2 = image_embed_2 = self.embed_scale * image_embed_2 if self.entangle_position_embedding and image_pos_embed_2 is not None: image_x_2 += image_pos_embed_2 if self.type_embedding is not None: image_x_2 += self.type_embedding(input_ids.new_full(image_x_2.size()[:2], fill_value=2)) if self.patch_layernorm_embedding is not None: image_x_2 = self.patch_layernorm_embedding(image_x_2) image_x_2 = self.dropout(image_x_2) if self.quant_noise is not None: image_x_2 = self.quant_noise(image_x_2) x = torch.cat([image_x_2, x], dim=1) embed = torch.cat([image_embed_2, embed], dim=1) return x, embed def reorder_encoder_out(self, encoder_out, new_order): """ Reorder encoder output according to *new_order*. Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ if "last_hidden_state" not in encoder_out: new_encoder_out = None else: new_encoder_out = encoder_out["last_hidden_state"].index_select(0, new_order) if "padding_mask" not in encoder_out: new_encoder_padding_mask = None else: new_encoder_padding_mask = encoder_out["padding_mask"].index_select(0, new_order) if "position_embedding" not in encoder_out: new_position_embeddings = None else: new_position_embeddings = encoder_out["position_embedding"].index_select(0, new_order) if "hidden_states" not in encoder_out: new_encoer_states = None else: encoder_states = encoder_out["hidden_states"] new_encoer_states = () if len(encoder_states) > 0: for idx, state in enumerate(encoder_states): new_encoer_states += (state.index_select(0, new_order),) if "attentions" not in encoder_out: attentions = None else: attentions = encoder_out["attentions"] return TiOEncoderOutput( last_hidden_state=new_encoder_out, padding_mask=new_encoder_padding_mask, hidden_states=new_encoer_states, attentions=attentions, position_embedding=new_position_embeddings ) def forward( self, input_ids=None, patch_images: Optional[torch.Tensor] = None, patch_images_2: Optional[torch.Tensor] = None, patch_masks: Optional[torch.Tensor] = None, output_attentions: bool = False, output_hidden_states: bool = False, token_embeddings: Optional[torch.Tensor] = None, sample_patch_num: Optional[int] = None, ): r""" Args: input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): indices of input sequence tokens in the vocabular, and padding will be ignored by default; indices can be obtained using [`~TiOTokenizer`]. patch_images (`torch.FloatTensor` of shape `(bsz, 3, height, width)`): the resized image, which are transformed by the default operations. patch_images_2 (`torch.FloatTensor` of shape `(bsz, 3, height, width)`): the second (if it exists) image. patch_masks (`torch.BoolTensor`): the patches to be masked. output_attentions (`bool`): whether to return all attention weights, output_hidden_states (`bool`): whether to return all hidden states. token_embeddings (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): token embeddings. sample_patch_num (`int`): the number of patches to sample. Returns: [`TiOEncoderOutput`]: last_hidden_state (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): the states of the last layer. padding_mask (`torch.BoolTensor` of shape `(bsz, seq_len)`): the padding mask of the source context. hidden_states (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): the states of all layers including the embeddings. attentions (`torch.FloatTensor` of shape `(bsz, num_heads, seq_len, seq_len)`): the attention weights of all layers. position_embedding (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): positional embeddings of the input image and tokens. """ image_embed = None image_embed_2 = None image_pos_embed = None image_pos_embed_2 = None if patch_images is not None: image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed = \ self.get_patch_images_info(patch_images, sample_patch_num, input_ids.device) # image_padding_mask[~patch_masks] = True # comment the line to temporarily fix the bug of mismatch if patch_images_2 is not None: image_embed_2, image_num_patches_2, image_padding_mask_2, image_position_ids_2, image_pos_embed_2 = \ self.get_patch_images_info(patch_images_2, sample_patch_num, input_ids.device) image_padding_mask_2[~patch_masks] = True encoder_padding_mask = input_ids.eq(self.padding_idx) if patch_images is not None: encoder_padding_mask = torch.cat([image_padding_mask, encoder_padding_mask], dim=1) if patch_images_2 is not None: encoder_padding_mask = torch.cat([image_padding_mask_2, encoder_padding_mask], dim=1) has_pads = encoder_padding_mask.any() pos_embed = self.embed_positions(new_arange(input_ids)) x, encoder_embedding = self.forward_embedding( input_ids, image_embed, image_embed_2, token_embeddings, pos_embed, image_pos_embed, image_pos_embed_2 ) # account for padding while computing the representation if has_pads: x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) pos_embed = self.pos_ln(pos_embed) if patch_images is not None: image_pos_embed = self.image_pos_ln(image_pos_embed) pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1) if patch_images_2 is not None: image_pos_embed_2 = self.image_pos_ln(image_pos_embed_2) pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1) pos_q = self.pos_q_linear(pos_embed).view( x.size(0), x.size(1), self.num_attention_heads, -1 ).transpose(1, 2) * self.pos_scaling pos_k = self.pos_k_linear(pos_embed).view( x.size(0), x.size(1), self.num_attention_heads, -1 ).transpose(1, 2) abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) # expand attention_mask if has_pads: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _expand_mask(~encoder_padding_mask, dtype=x.dtype) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None # encoder layers for idx, layer in enumerate(self.layers): if output_hidden_states: encoder_states += (x,) self_attn_bias = abs_pos_bias.clone() self_attn_bias[:, :, -input_ids.size(1):, -input_ids.size(1):] += self.get_rel_pos_bias(input_ids, idx) if patch_images_2 is not None: self_attn_bias[:, :, :image_num_patches_2, :image_num_patches_2] += \ self.get_image_rel_pos_bias(image_position_ids_2, idx) self_attn_bias[:, :, image_num_patches_2:image_num_patches_2 + image_num_patches, image_num_patches_2:image_num_patches_2 + image_num_patches] += \ self.get_image_rel_pos_bias(image_position_ids, idx) elif patch_images is not None: self_attn_bias[:, :, :x.size(1) - input_ids.size(1), :x.size(1) - input_ids.size(1)] += \ self.get_image_rel_pos_bias(image_position_ids, idx) self_attn_bias = self_attn_bias.reshape(-1, x.size(1), x.size(1)) hidden_outputs = layer(x, attention_mask if has_pads else None, attn_bias=self_attn_bias, output_attentions=output_attentions) x = hidden_outputs[0] if output_attentions: attention = hidden_outputs[1] all_attentions = all_attentions + (attention,) if output_hidden_states: encoder_states += (x,) if self.layer_norm is not None: x = self.layer_norm(x) return TiOEncoderOutput( last_hidden_state=x, padding_mask=encoder_padding_mask, hidden_states=encoder_states, attentions=all_attentions, position_embedding=pos_embed, ) class TiODecoder(TiOPreTrainedModel): r""" TiO decoder consisting of layers of [`TiODecoderLayer`] Args: config: TiOConfig embed_tokens (`nn.Embedding`, *optional*): output embedding """ def __init__(self, config: TiOConfig, embed_tokens: Optional[nn.Embedding] = None, output_projection=None): super().__init__(config) self.dropout = nn.Dropout(config.dropout) self.decoder_layerdrop = config.decoder_layerdrop self.padding_idx = config.pad_token_id self.max_target_positions = config.max_position_embeddings self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 self._future_mask = torch.empty(0) self.share_input_output_embed = config.share_decoder_input_output_embed self.num_attention_heads = config.decoder_attention_heads if embed_tokens is not None: self.embed_tokens = embed_tokens else: self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) self.embed_dim = config.d_model self.output_embed_dim = config.d_model self.layers = nn.ModuleList([TiODecoderLayer(config) for _ in range(config.decoder_layers)]) if config.layernorm_embedding: self.layernorm_embedding = LayerNorm(self.embed_dim) else: self.layernorm_embedding = None self.window_size = config.code_image_size // 8 self.embed_positions = Embedding(self.max_target_positions + 2, self.embed_dim) self.embed_image_positions = Embedding(config.image_bucket_size**2 + 1, self.embed_dim) self.pos_ln = LayerNorm(self.embed_dim) self.image_pos_ln = LayerNorm(self.embed_dim) self.pos_scaling = float(self.embed_dim / self.num_attention_heads * config.attn_scale_factor) ** -0.5 self.self_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim) self.self_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim) self.cross_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim) self.cross_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim) if config.code_layernorm_embedding: self.code_layernorm_embedding = LayerNorm(self.embed_dim) else: self.code_layernorm_embedding = None if self.decoder_layerdrop > 0.0: self.layers = LayerDropModuleList(p=self.decoder_layerdrop) else: self.layers = nn.ModuleList([]) dpr = [x.item() for x in torch.linspace(0, config.decoder_drop_path_rate, config.decoder_layers)] self.layers.extend([TiODecoderLayer(config, drop_path_rate=dpr[i]) for i in range(config.decoder_layers)]) self.num_layers = len(self.layers) if config.decoder_normalize_before: self.layer_norm = LayerNorm(self.embed_dim) else: self.layer_norm = None self.adaptive_softmax = None self.output_projection = output_projection if self.output_projection is None: self.build_output_projection(config) self.token_bucket_size = config.token_bucket_size token_num_rel_dis = 2 * config.token_bucket_size - 1 token_rp_bucket = make_token_bucket_position(config.token_bucket_size) self.token_rel_pos_table_list = nn.ModuleList( [ Embedding(token_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in range(config.decoder_layers) ] ) self.image_bucket_size = config.image_bucket_size image_num_rel_dis = (2 * config.image_bucket_size - 1) * (2 * config.image_bucket_size - 1) + 3 image_rp_bucket = make_image_bucket_position(config.image_bucket_size, image_num_rel_dis) image_position_idx = torch.arange(self.window_size).unsqueeze(0).expand(self.window_size, self.window_size) + \ torch.arange(self.window_size).unsqueeze(1) * config.image_bucket_size + 1 image_position_idx = torch.cat([torch.tensor([0]), image_position_idx.view(-1)]) image_position_idx = torch.cat([image_position_idx, torch.tensor([1024] * 768)]) self.image_rel_pos_table_list = nn.ModuleList( [ Embedding(image_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in range(config.decoder_layers) ] ) self.register_buffer("token_rp_bucket", token_rp_bucket) self.register_buffer("image_rp_bucket", image_rp_bucket) self.register_buffer("image_position_idx", image_position_idx) self.entangle_position_embedding = config.entangle_position_embedding self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() def build_output_projection(self, config): if self.share_input_output_embed: self.output_projection = nn.Linear( self.embed_tokens.weight.shape[1], self.embed_tokens.weight.shape[0], bias=False, ) self.output_projection.weight = self.embed_tokens.weight else: self.output_projection = nn.Linear( self.output_embed_dim, config.vocab_size, bias=False ) nn.init.normal_(self.output_projection.weight, mean=0, std=self.output_embed_dim**-0.5) def get_rel_pos_bias(self, x, idx): r""" Get the relative positional bias of the text, for attention. """ seq_len = x.size(1) rp_bucket = self.token_rp_bucket[:seq_len, :seq_len] values = F.embedding(rp_bucket, self.token_rel_pos_table_list[idx].weight) values = values.permute([2, 0, 1]) return values.contiguous() def get_image_rel_pos_bias(self, x, idx): r""" Get the relative positional bias of the image, for attention. """ seq_len = x.size(1) image_position_idx = self.image_position_idx[:seq_len] rp_bucket = self.image_rp_bucket[image_position_idx][:, image_position_idx] values = F.embedding(rp_bucket, self.image_rel_pos_table_list[idx].weight) values = values.permute(2, 0, 1) return values def get_pos_info(self, tgt_pos_embed, src_pos_embed=None, use_image=False): r""" Get the positional information. Args: tgt_pos_embed (`torch.FloatTensor` of shape `(bsz, tgt_len, embed_dim)`): the target-side positional embeddings. src_pos_embed (`torch.FloatTensor` of shape `(bsz, src_len, embed_dim)`, *optional*): the source-side positional embeddings. use_image (`bool`): whether to use image. Returns: abs_pos_bias (`torch.FloatTensor` of shape `(bsz, src_len, tgt_len, src_len)`): absolute positional bias for attention. """ batch_size = tgt_pos_embed.size(0) tgt_len = tgt_pos_embed.size(1) tgt_pos_embed = self.image_pos_ln(tgt_pos_embed) if use_image else self.pos_ln(tgt_pos_embed) if src_pos_embed is not None: src_len = src_pos_embed.size(1) pos_q = self.cross_pos_q_linear(tgt_pos_embed).view( batch_size, tgt_len, self.num_attention_heads, -1 ).transpose(1, 2) * self.pos_scaling pos_k = self.cross_pos_k_linear(src_pos_embed).view( batch_size, src_len, self.num_attention_heads, -1 ).transpose(1, 2) else: src_len = tgt_pos_embed.size(1) pos_q = self.self_pos_q_linear(tgt_pos_embed).view( batch_size, tgt_len, self.num_attention_heads, -1 ).transpose(1, 2) * self.pos_scaling pos_k = self.self_pos_k_linear(tgt_pos_embed).view( batch_size, src_len, self.num_attention_heads, -1 ).transpose(1, 2) abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) return abs_pos_bias def get_input_embeddings(self): r""" Get the input embeddings """ return self.embed_tokens def set_input_embeddings(self, value): r""" Set the weights of the embeddings with the given tensor. """ self.embed_tokens = value def _prepare_decoder_attention_mask(self, attention_mask, input_shape, dtype, past_key_values_length): r""" Create causal mask for unidirectional decoding. [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] """ combined_attention_mask = None if input_shape[-1] > 1: combined_attention_mask = _make_causal_mask( input_shape, dtype, past_key_values_length=past_key_values_length ) if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] expanded_attn_mask = _expand_mask(attention_mask, dtype, tgt_len=input_shape[-1]) combined_attention_mask = ( expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask.to(expanded_attn_mask.device) ) return combined_attention_mask def max_positions(self): """Maximum output length supported by the decoder.""" if self.embed_positions is None: return self.max_target_positions return self.max_target_positions def get_normalized_probs( self, net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], log_probs: bool, sample: Optional[Dict[str, Tensor]] = None, ): """Get normalized probabilities (or log probs) from a net's output.""" return self.get_normalized_probs_scriptable(net_output, log_probs, sample) def get_normalized_probs_scriptable( self, net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], log_probs: bool, sample: Optional[Dict[str, Tensor]] = None, ): """Get normalized probabilities (or log probs) from a net's output.""" if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None: if sample is not None: assert "target" in sample target = sample["target"] else: target = None out = self.adaptive_softmax.get_log_prob(net_output[0], target=target) return out.exp_() if not log_probs else out logits = net_output[0] if log_probs: return F.log_softmax(logits, dim=-1) else: return F.softmax(logits, dim=-1) def reorder_incremental_state_scripting( self, # incremental_state: Dict[str, Dict[str, Optional[Tensor]]], past_key_values: Optional[torch.Tensor], new_order: Tensor, ): """Main entry point for reordering the incremental state. Due to limitations in TorchScript, we call this function in :class:`fairseq.sequence_generator.SequenceGenerator` instead of calling :func:`reorder_incremental_state` directly. """ input_buffer = past_key_values new_past_key_values = [] if input_buffer is not None: for input_buffer_k in input_buffer: new_input_buffer_k = [] for input in input_buffer_k: if input is None: input = None else: input = input.index_select(0, new_order) new_input_buffer_k.append(input) new_past_key_values.append(new_input_buffer_k) return new_past_key_values def forward( self, input_ids: torch.Tensor = None, attention_mask: torch.Tensor = None, encoder_hidden_states: torch.Tensor = None, encoder_attention_mask: torch.Tensor = None, code_masks: Optional[torch.Tensor] = None, src_pos_embed: torch.Tensor = None, past_key_values: Optional[torch.Tensor] = None, use_cache: bool = False, output_attentions: bool = False, output_hidden_states: bool = False, ): r""" Args: input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): indices of the sequence in the vocabulary. attention_mask (`torch.Tensor` of shape `(bsz, seq_len)`): mask to avoid attention on padding tokens. encoder_hidden_states (`torch.FloatTensor` of shape `(bsz, seq_len, hidden)`): the last hidden state of the encoder. encoder_attention_mask (`torch.Tensor` of shape `(bsz, seq_len)`): the padding mask of the source side. code_masks (`torch.Tensor` of shape `(bsz, seq_len)`): masks only for code generation. src_pos_embed (`torch.FloatTensor` of shape `(bsz, seq_len, hidden)`): the positional embeddings of the source side. past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(bsz, num_heads, tgt_len, head_size)`) and 2 additional tensors of shape `(bsz, num_heads, src_len, head_size)`. use_cache (`bool`): whether to use cache for faster inference. output_attentions (`bool`): whether to output attention weights. output_hidden_states (`bool`): whether to output hidden states. Returns: BaseModelOutputWithPastAndCrossAttentions or a plain tuple: last_hidden_state (`torch.FloatTensor` of shape `(bsz, seq_len, hidden)`): the last hidden states. past_key_values (`tuple(tuple(torch.FloatTensor)): past keys and values for faster inference. hidden_states (`tuple(torch.FloatTensor)`): hidden states of all layers. attentions (`tuple(torch.FloatTensor)): self attention weights of all layers. cross_attentions (`tuple(torch.FloatTensor)): cross attention weights of all layers. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache if past_key_values is not None and len(past_key_values)>0: size = past_key_values[0][0].size() bsz, tgt_len = size[0], size[-2] + 1 token_position_idx = torch.arange(tgt_len, device=input_ids.device).expand([bsz, tgt_len]).contiguous() else: bsz, tgt_len = input_ids.shape token_position_idx = new_arange(input_ids) tgt_pos_embed = self.embed_positions(token_position_idx) if code_masks is not None and torch.any(code_masks): image_position_idx = self.image_position_idx[:input_ids.size(1)].unsqueeze(0).expand(bsz, tgt_len) tgt_pos_embed[code_masks] = self.embed_image_positions(image_position_idx)[code_masks] # self attn position bias self_abs_pos_bias = self.get_pos_info(tgt_pos_embed, use_image=False) if code_masks is not None and torch.any(code_masks): self_image_abs_pos_bias = self.get_pos_info(tgt_pos_embed, use_image=True) self_abs_pos_bias[code_masks] = self_image_abs_pos_bias[code_masks] # cross attn position bias cross_abs_pos_bias = self.get_pos_info(tgt_pos_embed, src_pos_embed=src_pos_embed) if code_masks is not None and torch.any(code_masks): cross_image_abs_pos_bias = self.get_pos_info(tgt_pos_embed, src_pos_embed=src_pos_embed, use_image=True) cross_abs_pos_bias[code_masks] = cross_image_abs_pos_bias[code_masks] cross_abs_pos_bias = cross_abs_pos_bias.reshape(-1, *cross_abs_pos_bias.size()[-2:]) all_prev_output_tokens = input_ids.clone() if past_key_values is not None and len(past_key_values)>0: input_ids = input_ids[:, -1:] cross_abs_pos_bias = cross_abs_pos_bias[:, -1:, :] tgt_pos_embed = tgt_pos_embed[:, -1:, :] # embed tokens and positions x = self.embed_scale * self.embed_tokens(input_ids) if self.entangle_position_embedding and not self.disable_entangle: x += tgt_pos_embed if self.layernorm_embedding is not None: if code_masks is None or not code_masks.any() or not self.code_layernorm_embedding: x = self.layernorm_embedding(x) elif code_masks is not None and code_masks.all(): x = self.code_layernorm_embedding(x) else: x[~code_masks] = self.layernorm_embedding(x[~code_masks]) x[code_masks] = self.code_layernorm_embedding(x[code_masks]) hidden_states = self.dropout(x) # past_key_values_length past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None and len(past_key_values)>0 else 0 shape, dtype = input_ids.shape, hidden_states.dtype attention_mask = self._prepare_decoder_attention_mask(attention_mask, shape, dtype, past_key_values_length) # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None next_decoder_cache = () if use_cache else None # decoder layers for idx, layer in enumerate(self.layers): # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) past_key_value = past_key_values[idx] if past_key_values is not None and len(past_key_values)>0 else None self_attn_bias = self_abs_pos_bias.clone() if code_masks is None or not code_masks.any(): self_attn_bias += self.get_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0) elif code_masks is not None and code_masks.all(): self_attn_bias += self.get_image_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0) else: self_attn_bias[~code_masks] += self.get_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0) self_attn_bias[code_masks] += self.get_image_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0) self_attn_bias = self_attn_bias.reshape(-1, *self_attn_bias.size()[-2:]) if past_key_value is not None and len(past_key_values)>0 : self_attn_bias = self_attn_bias[:, -1:, :] layer_outputs = layer( hidden_states, attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, self_attn_bias=self_attn_bias, cross_attn_bias=cross_abs_pos_bias, ) hidden_states = layer_outputs[0] if use_cache: next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) if output_attentions: all_self_attns += (layer_outputs[1],) if encoder_hidden_states is not None: all_cross_attentions += (layer_outputs[2],) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None if self.layer_norm is not None: hidden_states = self.layer_norm(hidden_states) if self.output_projection is not None: hidden_states = self.output_projection(hidden_states) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, ) # @add_start_docstrings( # "The bare TiO Model outputting raw hidden-states without any specific head on top.", # TiO_START_DOCSTRING, # ) class TiOModel(TiOPreTrainedModel): r""" The TiO model built with an encoder and a decoder only, without any classification head. Args: config (TiOConfig): TiO configuration. """ def __init__(self, config: TiOConfig, **kwargs): super().__init__(config) self.disable_entangle = getattr(kwargs,'disable_entangle',False) self.padding_idx, vocab_size = config.pad_token_id, config.vocab_size shared = nn.Embedding(vocab_size, config.d_model, self.padding_idx) self.encoder = TiOEncoder(config, shared) self.decoder = TiODecoder(config, shared) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): r""" Retrieve input embeddings. """ return self.encoder.get_input_embeddings() def set_input_embeddings(self, value): r""" Set values for input embeddings """ shared = value self.encoder.embed_tokens = shared self.decoder.embed_tokens = shared def get_encoder(self): r""" Retrieve the encoder """ return self.encoder def get_decoder(self): r""" Retrieve the decoder """ return self.decoder # @add_start_docstrings_to_model_forward(TiO_INPUTS_DOCSTRING) # @add_code_sample_docstrings( # processor_class=_TOKENIZER_FOR_DOC, # checkpoint=_CHECKPOINT_FOR_DOC, # output_type=Seq2SeqModelOutput, # config_class=_CONFIG_FOR_DOC, # ) def max_decoder_positions(self): """Maximum length supported by the decoder.""" return self.decoder.max_positions() def get_normalized_probs( self, net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], log_probs: bool, sample: Optional[Dict[str, Tensor]] = None, ): """Get normalized probabilities (or log probs) from a net's output.""" return self.get_normalized_probs_scriptable(net_output, log_probs, sample) def get_normalized_probs_scriptable( self, net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], log_probs: bool, sample: Optional[Dict[str, Tensor]] = None, ): """Scriptable helper function for get_normalized_probs in ~BaseFairseqModel""" if hasattr(self, "decoder"): return self.decoder.get_normalized_probs(net_output, log_probs, sample) elif torch.is_tensor(net_output): # syntactic sugar for simple models which don't have a decoder # (e.g., the classification tutorial) logits = net_output.float() if log_probs: return F.log_softmax(logits, dim=-1) else: return F.softmax(logits, dim=-1) raise NotImplementedError def forward( self, input_ids=None, patch_images=None, patch_images_2=None, patch_masks=None, token_embeddings=None, sample_patch_num=None, decoder_input_ids=None, code_masks=None, attention_mask=None, encoder_outputs=None, past_key_values=None, use_cache=False, output_attentions=False, output_hidden_states=False, return_dict=False, **args ): r""" Args: input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): indices of input sequence tokens in the vocabular, and padding will be ignored by default; indices can be obtained using [`~TiOTokenizer`]. patch_images (`torch.FloatTensor` of shape `(bsz, 3, height, width)`): the resized image, which are transformed by the default operations. patch_images_2 (`torch.FloatTensor` of shape `(bsz, 3, height, width)`): the second (if it exists) image. patch_masks (`torch.BoolTensor`): the patches to be masked. token_embeddings (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): token embeddings. sample_patch_num (`int`): the number of patches to sample. decoder_input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): indices of the sequence in the vocabulary. code_masks (`torch.Tensor` of shape `(bsz, seq_len)`): masks only for code generation. attention_mask (`torch.Tensor` of shape `(bsz, seq_len)`): attention mask for decoding. encoder_outputs (`TiOEncoderOutput`): encoder outputs with hidden states, positional embeddings, and padding masks. past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(bsz, num_heads, tgt_len, head_size)`) and 2 additional tensors of shape `(bsz, num_heads, src_len, head_size)`. use_cache (`bool`): whether to use cache for faster inference. output_attentions (`bool`): whether to output attention weights. output_hidden_states (`bool`): whether to output hidden states. return_dict (`bool`): unused. Keep it for generation only. Returns: Seq2SeqLMOutput: logits (`torch.FloatTensor` of shape `(bsz, seq_len, hidden)`): the last decoder hidden states. past_key_values (`tuple(tuple(torch.FloatTensor)): past keys and values for faster inference. decoder_hidden_states (`tuple(torch.FloatTensor)`): the decoder hidden states of all layers. decoder_attentions (`tuple(torch.FloatTensor)): the decoder self attention weights of all layers. cross_attentions (`tuple(torch.FloatTensor)): cross attention weights of all layers. encoder_last_hidden_state (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): the encoder last hidden state. encoder_hidden_states (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): the encoder states of all layers including the embeddings. encoder_attentions (`torch.FloatTensor` of shape `(bsz, num_heads, seq_len, seq_len)`): the encoder attention weights of all layers. """ # 适配新的tokenizer输入 patch_images = args.get("pixel_values", patch_images) if "labels_attention_mask" in args: attention_mask = args.get("labels_attention_mask")[:-1] decoder_input_ids = args.get("labels")[:-1] output_attentions = output_attentions if output_attentions else self.config.output_attentions output_hidden_states = output_hidden_states if output_hidden_states else self.config.output_hidden_states use_cache = use_cache if use_cache is not None else self.config.use_cache if encoder_outputs is None: encoder_outputs = self.encoder( input_ids=input_ids, patch_images=patch_images, patch_images_2=patch_images_2, patch_masks=patch_masks, output_attentions=output_attentions, output_hidden_states=output_hidden_states, token_embeddings=token_embeddings, sample_patch_num=sample_patch_num, ) # if decoder_input_ids.eq(self.config.pad_token_id).any(): # attention_mask = decoder_input_ids.eq(self.padding_idx) encoder_hidden_states = encoder_outputs.last_hidden_state if past_key_values is not None and len(past_key_values)>0: encoder_attention_mask = _expand_mask( ~encoder_outputs.padding_mask, encoder_hidden_states.dtype, decoder_input_ids[:, -1:].shape[-1] ) else: encoder_attention_mask = _expand_mask( ~encoder_outputs.padding_mask, encoder_hidden_states.dtype, decoder_input_ids.shape[-1] ) src_pos_embed = encoder_outputs.position_embedding decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, code_masks=code_masks, src_pos_embed=src_pos_embed, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) logits = decoder_outputs.last_hidden_state loss = None return Seq2SeqLMOutput( loss=loss, logits=decoder_outputs.last_hidden_state, past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, encoder_last_hidden_state=encoder_outputs.last_hidden_state, encoder_hidden_states=encoder_outputs.hidden_states, encoder_attentions=encoder_outputs.attentions, ) def prepare_inputs_for_generation( self, decoder_input_ids=None, past=None, attention_mask=None, code_masks=None, use_cache=False, encoder_outputs=None, **kwargs ): # if attention_mask is None: attention_mask = decoder_input_ids.new_ones(decoder_input_ids.shape) # cut decoder_input_ids if past is used # if past is not None: # decoder_input_ids = decoder_input_ids[:, -1:] return { "input_ids": None, "patch_images": None, "patch_images_2": None, "patch_masks": None, "token_embeddings": None, "sample_patch_num": None, "attention_mask": attention_mask, "encoder_outputs": encoder_outputs, "past_key_values": past, "decoder_input_ids": decoder_input_ids, "code_masks": code_masks, "use_cache": use_cache, } def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) def _prepare_encoder_decoder_kwargs_for_generation( self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None ): # 1. get encoder encoder = self.get_encoder() # 2. prepare encoder args and encoder kwargs from model kwargs irrelevant_prefix = ["decoder_", "cross_attn", "use_cache", "attention_mask"] encoder_kwargs = { argument: value for argument, value in model_kwargs.items() if not any(argument.startswith(p) for p in irrelevant_prefix) } if encoder_kwargs.get("patch_masks") is None: encoder_kwargs["patch_masks"] = torch.ones((len(inputs_tensor), 1), dtype=torch.bool, device=inputs_tensor.device) # 3. make sure that encoder returns `ModelOutput` model_input_name = model_input_name if model_input_name is not None else self.main_input_name encoder_kwargs[model_input_name] = inputs_tensor model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs) model_kwargs["attention_mask"] = None return model_kwargs @staticmethod def _reorder_cache(past, beam_idx): reordered_past = () for layer_past in past: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past @staticmethod def _expand_inputs_for_generation( input_ids: torch.LongTensor, expand_size: int = 1, is_encoder_decoder: bool = False, attention_mask: Optional[torch.LongTensor] = None, encoder_outputs: Optional[ModelOutput] = None, **model_kwargs, ): expanded_return_idx = ( torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device) ) input_ids = input_ids.index_select(0, expanded_return_idx) if "token_type_ids" in model_kwargs: token_type_ids = model_kwargs["token_type_ids"] model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx) if attention_mask is not None: model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx) if is_encoder_decoder: if encoder_outputs is None: raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select( 0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device) ) encoder_outputs["position_embedding"] = encoder_outputs.position_embedding.index_select( 0, expanded_return_idx.to(encoder_outputs.position_embedding.device) ) encoder_outputs["padding_mask"] = encoder_outputs.padding_mask.index_select( 0, expanded_return_idx.to(encoder_outputs.padding_mask.device) ) model_kwargs["encoder_outputs"] = encoder_outputs return input_ids, model_kwargs from .utils_tio import Utils from .gradio_app import get_gradio_demo TiOModel.utils = Utils TiOModel.get_gradio_demo = get_gradio_demo