Spaces:
Sleeping
Sleeping
File size: 4,474 Bytes
cb7427c 47169db cb7427c 4169569 cb7427c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
import torch
import torch._utils
import torch.nn as nn
import torchvision.models as models
from typing import Tuple
from source.config import Config
class Encoder(nn.Module):
    """Image encoder: frozen pretrained ResNet-50 backbone plus a trainable linear projection.

    The backbone is torchvision's ResNet-50 with its last child module (the final
    classifier ``fc``) removed; the remaining features are projected to
    ``image_emb_dim`` by ``self.fc``.
    """

    def __init__(self, image_emb_dim: int, device: torch.device):
        """Build the encoder.

        Args:
            image_emb_dim: output dimension of the linear projection, i.e. the
                encoder output is (BATCH, image_emb_dim).
            device: stored as ``self.device`` for backward compatibility. ``forward``
                does not move tensors to it; outputs stay on the input/module device.
        """
        super().__init__()
        self.image_emb_dim = image_emb_dim
        self.device = device

        # Pretrained ResNet-50 with every parameter frozen (no gradient updates).
        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        for param in resnet.parameters():
            param.requires_grad_(False)
        # NOTE(review): requires_grad=False does not stop BatchNorm layers from updating
        # their running statistics while the module is in train mode; call
        # self.resnet.eval() if a completely frozen backbone is intended.

        # Drop the last child module (the original classifier `fc`).
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])

        # Trainable projection from the backbone feature size to image_emb_dim.
        self.fc = nn.Linear(in_features=resnet.fc.in_features, out_features=self.image_emb_dim)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Pass images through the frozen backbone, flatten, and project.

        Args:
            > images (torch.Tensor): (BATCH, 3, 224, 224)
        Returns:
            > features (torch.Tensor): (BATCH, IMAGE_EMB_DIM), on the same device
              as the module's parameters.
        """
        features = self.resnet(images)
        # features: (BATCH, 2048, 1, 1) for the ResNet-50 backbone
        features = torch.flatten(features, start_dim=1)
        # features: (BATCH, 2048)
        return self.fc(features)
        # features: (BATCH, IMAGE_EMB_DIM)
class Decoder(nn.Module):
    """Word-by-word LSTM caption decoder producing log-probabilities over the vocabulary."""

    def __init__(self,
                 image_emb_dim: int,
                 word_emb_dim: int,
                 hidden_dim: int,
                 num_layers: int,
                 vocab_size: int,
                 device: torch.device):
        """
        Decoder whose LSTM input is the concatenation of embedded captions and image
        features obtained from the encoder. The initial hidden and cell states are
        learnable parameters initialised to zeros (``hidden_state_0`` / ``cell_state_0``).
        The final classifier is a linear layer with output dimension equal to the
        vocabulary size, followed by log-softmax.

        Args:
            image_emb_dim: size of the image feature vector fed at each step.
            word_emb_dim: size of the word embedding fed at each step.
            hidden_dim: LSTM hidden size.
            num_layers: number of stacked LSTM layers.
            vocab_size: number of output classes (vocabulary entries).
            device: stored as ``self.device`` for backward compatibility; ``forward``
                does not move tensors to it.
        """
        super().__init__()
        self.config = Config()
        self.image_emb_dim = image_emb_dim
        # Misspelled original attribute name, kept as an alias for existing readers.
        self.image_emd_dim = image_emb_dim
        self.word_emb_dim = word_emb_dim
        self.hidden_dim = hidden_dim
        self.num_layer = num_layers
        self.vocab_size = vocab_size
        self.device = device

        # Learnable initial states of shape (NUM_LAYER, 1, HIDDEN_DIM), zero-initialised.
        # NOTE(review): the batch dimension is 1; presumably the caller expands it to
        # the batch size before calling forward — confirm.
        self.hidden_state_0 = nn.Parameter(torch.zeros((self.num_layer, 1, self.hidden_dim)))
        self.cell_state_0 = nn.Parameter(torch.zeros((self.num_layer, 1, self.hidden_dim)))

        self.lstm = nn.LSTM(input_size=self.image_emb_dim + self.word_emb_dim,
                            hidden_size=self.hidden_dim,
                            num_layers=self.num_layer,
                            bidirectional=False)
        # dim=-1 is the vocabulary axis for both batched (3-D) and unbatched (2-D) output.
        self.fc = nn.Sequential(
            nn.Linear(in_features=self.hidden_dim, out_features=self.vocab_size),
            nn.LogSoftmax(dim=-1),
        )

    def forward(self,
                embedded_captions: torch.Tensor,
                features: torch.Tensor,
                hidden: torch.Tensor,
                cell: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward operation of the (word-by-word) decoder.
        The LSTM input (concatenation of embedded_captions and features along the last
        axis) is passed through the LSTM and then the log-softmax classifier.
        Unbatched inputs (batch axis omitted everywhere) are also accepted.
        Args:
            > embedded_captions(torch.Tensor): (SEQ_LENGTH = 1, BATCH, WORD_EMB_DIM)
            > features (torch.Tensor): (1, BATCH, IMAGE_EMB_DIM)
            > hidden (torch.Tensor): (NUM_LAYER, BATCH, HIDDEN_DIM)
            > cell (torch.Tensor): (NUM_LAYER, BATCH, HIDDEN_DIM)
        Returns:
            > output (torch.Tensor): (1, BATCH, VOCAB_SIZE) log-probabilities
            > (hidden, cell) (torch.Tensor, torch.Tensor): (NUM_LAYER, BATCH, HIDDEN_DIM)
        """
        lstm_input = torch.cat((embedded_captions, features), dim=-1)
        output, (hidden, cell) = self.lstm(lstm_input, (hidden, cell))
        # output : (SEQ_LENGTH, BATCH, HIDDEN_DIM)
        # hidden : (NUM_LAYER, BATCH, HIDDEN_DIM)
        output = self.fc(output)
        # output : (SEQ_LENGTH, BATCH, VOCAB_SIZE)
        return output, (hidden, cell)
|