# ImageCaption/source/model.py
import torch
import torch.nn as nn
import torchvision.models as models
from typing import Tuple
from config import Config
class Encoder(nn.Module):
def __init__(self, image_emb_dim: int, device: torch.device):
""" Image encoder to obtain features from images. Contains pretrained Resnet50 with last layer removed
and a linear layer with the output dimension of (BATCH, image_emb_dim)
"""
super(Encoder, self).__init__()
self.image_emb_dim = image_emb_dim
self.device = device
        # pretrained ResNet-50 backbone with frozen parameters
resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
for param in resnet.parameters():
param.requires_grad_(False)
# remove last layer
modules = list(resnet.children())[:-1]
self.resnet = nn.Sequential(*modules)
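        # the truncated backbone ends with adaptive average pooling, so it outputs (BATCH, 2048, 1, 1)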
# define a final classifier
self.fc = nn.Linear(in_features=resnet.fc.in_features, out_features=self.image_emb_dim)
def forward(self, images: torch.Tensor) -> torch.Tensor:
""" Forward operation of encoder, passing images through resnet and then linear layer.
Args:
> images (torch.Tensor): (BATCH, 3, 224, 224)
Returns:
> features (torch.Tensor): (BATCH, IMAGE_EMB_DIM)
"""
features = self.resnet(images)
# features: (BATCH, 2048, 1, 1)
features = features.reshape(features.size(0), -1).to(self.device)
# features: (BATCH, 2048)
features = self.fc(features).to(self.device)
# features: (BATCH, IMAGE_EMB_DIM)
return features
class Decoder(nn.Module):
def __init__(self,
image_emb_dim: int,
word_emb_dim: int,
hidden_dim: int,
num_layers: int,
vocab_size: int,
device: torch.device):
"""
Decoder taking as input for the LSTM layer the concatenation of features obtained from the encoder
and embedded captions obtained from the embedding layer. Hidden and cell states are randomly initialized.
Final classifier is a linear layer with output dimension of the size of a vocabulary.
"""
super(Decoder, self).__init__()
self.config = Config()
        self.image_emb_dim = image_emb_dim
self.word_emb_dim = word_emb_dim
self.hidden_dim = hidden_dim
self.num_layer = num_layers
self.vocab_size = vocab_size
self.device = device
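        # learnable initial hidden and cell states, zero-initialized; shaped (NUM_LAYER, 1, HIDDEN_DIM)
        # and expected to be expanded along the batch dimension by the caller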
self.hidden_state_0 = nn.Parameter(torch.zeros((self.num_layer, 1, self.hidden_dim)))
self.cell_state_0 = nn.Parameter(torch.zeros((self.num_layer, 1, self.hidden_dim)))
        self.lstm = nn.LSTM(input_size=self.image_emb_dim + self.word_emb_dim,
hidden_size=self.hidden_dim,
num_layers=self.num_layer,
bidirectional=False)
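        # final classifier: linear layer + LogSoftmax, producing log-probabilities over the vocabulary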
self.fc = nn.Sequential(
            nn.Linear(in_features=self.hidden_dim, out_features=self.vocab_size),
nn.LogSoftmax(dim=2)
)
def forward(self,
embedded_captions: torch.Tensor,
features: torch.Tensor,
hidden: torch.Tensor,
cell: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
        Forward pass of the (word-by-word) decoder.
        The LSTM input (the concatenation of embedded_captions and features) is passed through the LSTM and then the final classifier.
Args:
            > embedded_captions (torch.Tensor): (SEQ_LENGTH = 1, BATCH, WORD_EMB_DIM)
> features (torch.Tensor): (1, BATCH, IMAGE_EMB_DIM)
> hidden (torch.Tensor): (NUM_LAYER, BATCH, HIDDEN_DIM)
> cell (torch.Tensor): (NUM_LAYER, BATCH, HIDDEN_DIM)
Returns:
> output (torch.Tensor): (1, BATCH, VOCAB_SIZE)
> (hidden, cell) (torch.Tensor, torch.Tensor): (NUM_LAYER, BATCH, HIDDEN_DIM)
"""
lstm_input = torch.cat((embedded_captions, features), dim=2)
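        # lstm_input : (1, BATCH, WORD_EMB_DIM + IMAGE_EMB_DIM)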
output, (hidden, cell) = self.lstm(lstm_input, (hidden, cell))
# output : (SEQ_LENGTH, BATCH, HIDDEN_DIM)
# hidden : (NUM_LAYER, BATCH, HIDDEN_DIM)
output = output.to(self.device)
output = self.fc(output)
# output : (SEQ_LENGTH, BATCH, VOCAB_SIZE)
return output, (hidden, cell)
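

# Minimal smoke-test sketch (not part of the original training/inference pipeline).
# The dimensions below are assumed placeholder values for illustration only; the real
# hyperparameters come from Config, which is not shown in this file.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # assumed hyperparameters (placeholders)
    IMAGE_EMB_DIM = 256
    WORD_EMB_DIM = 256
    HIDDEN_DIM = 512
    NUM_LAYER = 1
    VOCAB_SIZE = 5000
    BATCH = 4

    encoder = Encoder(image_emb_dim=IMAGE_EMB_DIM, device=device).to(device)
    decoder = Decoder(image_emb_dim=IMAGE_EMB_DIM,
                      word_emb_dim=WORD_EMB_DIM,
                      hidden_dim=HIDDEN_DIM,
                      num_layers=NUM_LAYER,
                      vocab_size=VOCAB_SIZE,
                      device=device).to(device)

    # encode a dummy batch of images: (BATCH, 3, 224, 224) -> (BATCH, IMAGE_EMB_DIM)
    images = torch.randn(BATCH, 3, 224, 224).to(device)
    features = encoder(images)

    # one word-by-word decoding step with a dummy embedded token
    embedded = torch.randn(1, BATCH, WORD_EMB_DIM).to(device)            # (1, BATCH, WORD_EMB_DIM)
    hidden = decoder.hidden_state_0.expand(-1, BATCH, -1).contiguous()   # (NUM_LAYER, BATCH, HIDDEN_DIM)
    cell = decoder.cell_state_0.expand(-1, BATCH, -1).contiguous()       # (NUM_LAYER, BATCH, HIDDEN_DIM)
    output, (hidden, cell) = decoder(embedded, features.unsqueeze(0), hidden, cell)
    print(output.shape)  # expected: (1, BATCH, VOCAB_SIZE)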