import torch
import torch._utils
import torch.nn as nn
import torchvision.models as models

from typing import Tuple

from config import Config


class Encoder(nn.Module):

    def __init__(self, image_emb_dim: int, device: torch.device):
        """
        Image encoder that extracts features from images.
        Consists of a pretrained ResNet50 with its last (classification) layer removed,
        followed by a linear layer projecting to the image embedding dimension,
        so the output has shape (BATCH, IMAGE_EMB_DIM).
        """
        super(Encoder, self).__init__()
        self.image_emb_dim = image_emb_dim
        self.device = device

        # pretrained ResNet50 with frozen parameters
        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        for param in resnet.parameters():
            param.requires_grad_(False)

        # remove the last (classification) layer
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)

        # final projection to the image embedding dimension
        self.fc = nn.Linear(in_features=resnet.fc.in_features,
                            out_features=self.image_emb_dim)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the encoder: images are passed through the ResNet backbone
        and then through the linear projection.

        Args:
        > images (torch.Tensor): (BATCH, 3, 224, 224)

        Returns:
        > features (torch.Tensor): (BATCH, IMAGE_EMB_DIM)
        """

        features = self.resnet(images)
        # features: (BATCH, 2048, 1, 1)

        features = features.reshape(features.size(0), -1).to(self.device)
        # features: (BATCH, 2048)

        features = self.fc(features).to(self.device)
        # features: (BATCH, IMAGE_EMB_DIM)

        return features
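# Minimal shape-check sketch for the Encoder (illustrative only): the batch size
# and embedding dimension below are arbitrary assumptions, not values taken from
# Config or the training pipeline.
def _encoder_shape_check() -> None:
    device = torch.device("cpu")
    encoder = Encoder(image_emb_dim=256, device=device)
    images = torch.randn(4, 3, 224, 224)   # (BATCH, 3, 224, 224)
    features = encoder(images)             # (BATCH, IMAGE_EMB_DIM)
    assert features.shape == (4, 256)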
class Decoder(nn.Module):

    def __init__(self,
                 image_emb_dim: int,
                 word_emb_dim: int,
                 hidden_dim: int,
                 num_layers: int,
                 vocab_size: int,
                 device: torch.device):
        """
        Decoder whose LSTM input is the concatenation of the features obtained
        from the encoder and the embedded captions obtained from the embedding layer.
        The initial hidden and cell states are learnable parameters (initialized to zeros).
        The final classifier is a linear layer whose output dimension is the vocabulary size.
        """
        super(Decoder, self).__init__()
        self.config = Config()
        self.image_emb_dim = image_emb_dim
        self.word_emb_dim = word_emb_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.device = device

        # learnable initial hidden and cell states, shared across the batch
        self.hidden_state_0 = nn.Parameter(torch.zeros((self.num_layers, 1, self.hidden_dim)))
        self.cell_state_0 = nn.Parameter(torch.zeros((self.num_layers, 1, self.hidden_dim)))

        self.lstm = nn.LSTM(input_size=self.image_emb_dim + self.word_emb_dim,
                            hidden_size=self.hidden_dim,
                            num_layers=self.num_layers,
                            bidirectional=False)

        # final classifier over the vocabulary
        self.fc = nn.Sequential(
            nn.Linear(in_features=self.hidden_dim, out_features=self.vocab_size),
            nn.LogSoftmax(dim=2)
        )

    def forward(self,
                embedded_captions: torch.Tensor,
                features: torch.Tensor,
                hidden: torch.Tensor,
                cell: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass of the (word-by-word) decoder. The LSTM input (the concatenation
        of embedded_captions and features) is passed through the LSTM and then
        through the linear classifier.

        Args:
        > embedded_captions (torch.Tensor): (SEQ_LENGTH = 1, BATCH, WORD_EMB_DIM)
        > features (torch.Tensor): (1, BATCH, IMAGE_EMB_DIM)
        > hidden (torch.Tensor): (NUM_LAYER, BATCH, HIDDEN_DIM)
        > cell (torch.Tensor): (NUM_LAYER, BATCH, HIDDEN_DIM)

        Returns:
        > output (torch.Tensor): (1, BATCH, VOCAB_SIZE)
        > (hidden, cell) (torch.Tensor, torch.Tensor): (NUM_LAYER, BATCH, HIDDEN_DIM)
        """

        lstm_input = torch.cat((embedded_captions, features), dim=2)

        output, (hidden, cell) = self.lstm(lstm_input, (hidden, cell))
        # output : (SEQ_LENGTH, BATCH, HIDDEN_DIM)
        # hidden : (NUM_LAYER, BATCH, HIDDEN_DIM)

        output = output.to(self.device)
        output = self.fc(output)
        # output : (SEQ_LENGTH, BATCH, VOCAB_SIZE)

        return output, (hidden, cell)
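# Illustrative single-step usage sketch (assumptions: the batch size, embedding
# dimensions, and vocabulary size below are arbitrary; the real values come from
# Config and the training script, which are not shown here). It runs the encoder
# once and a single word-by-word decoding step.
if __name__ == "__main__":
    device = torch.device("cpu")
    batch_size = 4

    encoder = Encoder(image_emb_dim=256, device=device)
    decoder = Decoder(image_emb_dim=256, word_emb_dim=128, hidden_dim=512,
                      num_layers=1, vocab_size=1000, device=device)

    images = torch.randn(batch_size, 3, 224, 224)
    features = encoder(images).unsqueeze(0)
    # features: (1, BATCH, IMAGE_EMB_DIM)

    # One embedded token per image; during training this would come from an
    # nn.Embedding layer over the caption vocabulary (assumed, not shown here).
    embedded_captions = torch.randn(1, batch_size, 128)

    # Broadcast the learnable initial states across the batch.
    hidden = decoder.hidden_state_0.expand(-1, batch_size, -1).contiguous()
    cell = decoder.cell_state_0.expand(-1, batch_size, -1).contiguous()

    output, (hidden, cell) = decoder(embedded_captions, features, hidden, cell)
    print(output.shape)  # torch.Size([1, 4, 1000])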