# -------------------------------------------------------- # Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442) # Github source: https://github.com/microsoft/unilm/tree/master/beit3 # Copyright (c) 2023 Microsoft # Licensed under The MIT License [see LICENSE for details] # --------------------------------------------------------' import torch import torch.nn as nn import torch.nn.functional as F from timm.models.registry import register_model import numpy as np import utils from modeling_utils import BEiT3Wrapper, _get_base_config, _get_large_config class TwoLayerMLP(nn.Module): def __init__( self, in_features, hidden_features, out_features, norm_layer, norm_input=True, ): super().__init__() self.norm1 = norm_layer(in_features) if norm_input else nn.Identity() self.dense1 = nn.Linear(in_features, hidden_features) self.norm2 = norm_layer(hidden_features) self.act = nn.GELU() self.dense2 = nn.Linear(hidden_features, out_features) def forward(self, x): x = self.norm1(x) x = self.dense1(x) x = self.norm2(x) x = self.act(x) return self.dense2(x) class Pooler(nn.Module): def __init__(self, input_features, output_features, norm_layer): super().__init__() self.norm = norm_layer(input_features) self.dense = nn.Linear(input_features, output_features) self.activation = nn.Tanh() def forward(self, x): cls_rep = x[:, 0, :] cls_rep = self.norm(cls_rep) pooled_output = self.dense(cls_rep) pooled_output = self.activation(pooled_output) return pooled_output class BEiT3ForVisualReasoning(BEiT3Wrapper): def __init__( self, args, num_classes, norm_layer=nn.LayerNorm, **kwargs ): super(BEiT3ForVisualReasoning, self).__init__(args=args) embed_dim = args.encoder_embed_dim self.head = TwoLayerMLP( in_features=embed_dim * 4, hidden_features=embed_dim * 2, out_features=num_classes, norm_layer=norm_layer, ) init_scale = 0.001 self.head.apply(self._init_weights) if isinstance(self.head.dense1, nn.Linear): self.head.dense1.weight.data.mul_(init_scale) self.head.dense1.bias.data.mul_(init_scale) if isinstance(self.head.dense2, nn.Linear): self.head.dense2.weight.data.mul_(init_scale) self.head.dense2.bias.data.mul_(init_scale) def forward(self, image_a, image_b, text_description, padding_mask, **kwargs): bsz, _ = text_description.size() vision_input = torch.cat((image_a, image_b), dim=0) language_input = torch.cat((text_description, text_description), dim=0) padding_mask = torch.cat((padding_mask, padding_mask), dim=0) outputs = self.beit3( textual_tokens=language_input, visual_tokens=vision_input, text_padding_position=padding_mask, ) x = outputs["encoder_out"] multiway_split_position = outputs["multiway_split_position"] vision_cls = x[:, 0, :] language_cls = x[:, multiway_split_position, :] cls_rep = torch.cat((vision_cls, language_cls), dim=-1) a, b = torch.split(cls_rep, split_size_or_sections=[bsz, bsz], dim=0) cls_rep = torch.cat((a, b), dim=-1) return self.head(cls_rep) class BEiT3ForImageClassification(BEiT3Wrapper): def __init__( self, args, num_classes, norm_layer=nn.LayerNorm, **kwargs ): super(BEiT3ForImageClassification, self).__init__(args=args) embed_dim = args.encoder_embed_dim self.fc_norm = norm_layer(embed_dim) self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() self.fc_norm.apply(self._init_weights) self.head.apply(self._init_weights) init_scale = 0.001 if isinstance(self.head, nn.Linear): self.head.weight.data.mul_(init_scale) self.head.bias.data.mul_(init_scale) def forward(self, image, **kwargs): x = self.beit3(textual_tokens=None, visual_tokens=image)["encoder_out"] t = x[:, 1:, :] cls_x = self.fc_norm(t.mean(1)) return self.head(cls_x) class BEiT3ForCaptioning(BEiT3Wrapper): def __init__( self, args, **kwargs ): super(BEiT3ForCaptioning, self).__init__(args=args) embed_dim = args.encoder_embed_dim self.mlm_head = nn.Linear(embed_dim, args.vocab_size) self.mlm_head.apply(self._init_weights) def forward(self, image, text_ids, padding_mask, language_masked_pos, text_len=None, incremental_state=None, **kwargs): text_len = text_len if text_len is not None else text_ids.size(1) image_len = self.beit3.vision_embed.num_position_embeddings() max_len = text_len + image_len uni_mask = torch.zeros((max_len, max_len), dtype=torch.long, device=text_ids.device) i_start, i_end = 0, image_len t_start, t_end = image_len, max_len # triangle mask for caption to caption uni_mask[t_start:t_end, t_start:t_end] = torch.tril(torch.ones(text_len, text_len, dtype=torch.long, device=text_ids.device)) # full attention for caption to image uni_mask[t_start:t_end, i_start:i_end] = 1 # full attention for image to image uni_mask[i_start:i_end, i_start:i_end] = 1 uni_mask = 1-uni_mask if incremental_state is not None: for idx in range(self.get_num_layers()): if idx not in incremental_state: incremental_state[idx] = {} # for incremental decoding positions = None if image is None: uni_mask = uni_mask[-2:] padding_mask = None # start position (2 (fairseq starts at 2) + cur_position) is equal to text_len positions = torch.arange(text_len, text_ids.size(1) + text_len, device=text_ids.device).long().unsqueeze(0) outputs = self.beit3( textual_tokens=text_ids, visual_tokens=image, text_padding_position=padding_mask, attn_mask=uni_mask, incremental_state=incremental_state, positions=positions, ) if image is not None: text_feats = outputs["encoder_out"][:, image_len:] else: text_feats = outputs["encoder_out"] if language_masked_pos is not None: text_feats = text_feats[language_masked_pos.bool()] return self.mlm_head(text_feats), incremental_state class BEiT3ForVisualQuestionAnswering(BEiT3Wrapper): def __init__( self, args, num_classes, norm_layer=nn.LayerNorm, **kwargs ): super(BEiT3ForVisualQuestionAnswering, self).__init__(args=args) embed_dim = args.encoder_embed_dim self.pooler = Pooler( input_features=embed_dim, output_features=embed_dim, norm_layer=norm_layer, ) self.pooler.apply(self._init_weights) self.head = nn.Sequential( nn.Linear(embed_dim, embed_dim * 2), norm_layer(embed_dim * 2), nn.GELU(), nn.Linear(embed_dim * 2, num_classes), ) self.head.apply(self._init_weights) def forward(self, image, question, padding_mask, **kwargs): outputs = self.beit3( textual_tokens=question, visual_tokens=image, text_padding_position=padding_mask, ) x = outputs["encoder_out"] cls_rep = self.pooler(x) return self.head(cls_rep) class BEiT3ForRetrieval(BEiT3Wrapper): def __init__( self, args, **kwargs ): super(BEiT3ForRetrieval, self).__init__(args=args) embed_dim = args.encoder_embed_dim self.language_head = nn.Linear(embed_dim, embed_dim, bias=False) self.vision_head = nn.Linear(embed_dim, embed_dim, bias=False) self.language_head.apply(self._init_weights) self.vision_head.apply(self._init_weights) self.criterion = utils.ClipLoss( rank=utils.get_rank(), world_size=utils.get_world_size(), ) self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) def forward(self, image=None, text_description=None, padding_mask=None, only_infer=False, **kwargs): if image is not None: outputs = self.beit3( textual_tokens=None, visual_tokens=image, text_padding_position=None, ) x = outputs["encoder_out"] vision_cls = self.vision_head(x[:, 0, :]) vision_cls = F.normalize(vision_cls, dim=-1) else: vision_cls = None if text_description is not None: outputs = self.beit3( textual_tokens=text_description, visual_tokens=None, text_padding_position=padding_mask, ) x = outputs["encoder_out"] language_cls = self.language_head(x[:, 0, :]) language_cls = F.normalize(language_cls, dim=-1) else: language_cls = None if only_infer: return vision_cls, language_cls else: loss, logits_per_image, logits_per_text = self.criterion( vision_cls, language_cls, self.logit_scale.exp()) return loss, vision_cls, language_cls @register_model def beit3_base_patch16_224_imageclassification(pretrained=False, **kwargs): args = _get_base_config(**kwargs) args.normalize_output = False model = BEiT3ForImageClassification(args, num_classes=1000, **kwargs) return model @register_model def beit3_large_patch16_224_imageclassification(pretrained=False, **kwargs): args = _get_large_config(**kwargs) args.normalize_output = False model = BEiT3ForImageClassification(args, num_classes=1000, **kwargs) return model @register_model def beit3_base_patch16_224_nlvr2(pretrained=False, **kwargs): args = _get_base_config(**kwargs) model = BEiT3ForVisualReasoning(args, num_classes=2, **kwargs) return model @register_model def beit3_large_patch16_224_nlvr2(pretrained=False, **kwargs): args = _get_large_config(**kwargs) model = BEiT3ForVisualReasoning(args, num_classes=2, **kwargs) return model @register_model def beit3_base_patch16_384_vqav2(pretrained=False, **kwargs): args = _get_base_config(img_size=384, **kwargs) args.normalize_output = False model = BEiT3ForVisualQuestionAnswering(args, num_classes=3129, **kwargs) return model @register_model def beit3_base_patch16_480_vqav2(pretrained=False, **kwargs): args = _get_base_config(img_size=480, **kwargs) args.normalize_output = False model = BEiT3ForVisualQuestionAnswering(args, num_classes=3129, **kwargs) return model @register_model def beit3_large_patch16_384_vqav2(pretrained=False, **kwargs): args = _get_large_config(img_size=384, **kwargs) args.normalize_output = False model = BEiT3ForVisualQuestionAnswering(args, num_classes=3129, **kwargs) return model @register_model def beit3_large_patch16_480_vqav2(pretrained=False, **kwargs): args = _get_large_config(img_size=480, **kwargs) args.normalize_output = False model = BEiT3ForVisualQuestionAnswering(args, num_classes=3129, **kwargs) return model @register_model def beit3_large_patch16_768_vqav2(pretrained=False, **kwargs): args = _get_large_config(img_size=768, **kwargs) args.normalize_output = False model = BEiT3ForVisualQuestionAnswering(args, num_classes=3129, **kwargs) return model @register_model def beit3_base_patch16_224_captioning(pretrained=False, **kwargs): args = _get_base_config(**kwargs) model = BEiT3ForCaptioning(args, **kwargs) return model @register_model def beit3_base_patch16_480_captioning(pretrained=False, **kwargs): args = _get_base_config(img_size=480, **kwargs) model = BEiT3ForCaptioning(args, **kwargs) return model @register_model def beit3_large_patch16_480_captioning(pretrained=False, **kwargs): args = _get_large_config(img_size=480, **kwargs) model = BEiT3ForCaptioning(args, **kwargs) return model @register_model def beit3_base_patch16_224_retrieval(pretrained=False, **kwargs): args = _get_base_config(**kwargs) model = BEiT3ForRetrieval(args, **kwargs) return model @register_model def beit3_base_patch16_384_retrieval(pretrained=False, **kwargs): args = _get_base_config(img_size=384, **kwargs) model = BEiT3ForRetrieval(args, **kwargs) return model @register_model def beit3_large_patch16_384_retrieval(pretrained=False, **kwargs): args = _get_large_config(img_size=384, **kwargs) model = BEiT3ForRetrieval(args, **kwargs) return model