import torch
import torch.nn as nn
import torchvision.models as models
from typing import Tuple
from config import Config


class Encoder(nn.Module):
    def __init__(self, image_emb_dim: int, device: torch.device):
        """ Image encoder to obtain features from images. Contains pretrained Resnet50 with last layer removed
            and a linear layer with the output dimension of (BATCH, image_emb_dim)
        """

        super(Encoder, self).__init__()
        self.image_emb_dim = image_emb_dim
        self.device = device

        # pretrained ResNet-50 model with frozen parameters
        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        for param in resnet.parameters():
            param.requires_grad_(False)

        # remove the final classification (fc) layer, keeping the pooled 2048-d features
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)

        # linear projection from ResNet features to the image embedding dimension
        self.fc = nn.Linear(in_features=resnet.fc.in_features, out_features=self.image_emb_dim)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """ Forward operation of encoder, passing images through resnet and then linear layer.

        Args:
            > images (torch.Tensor): (BATCH, 3, 224, 224)

        Returns:
            > features (torch.Tensor): (BATCH, IMAGE_EMB_DIM)
        """

        features = self.resnet(images)
        # features: (BATCH, 2048, 1, 1)

        features = features.reshape(features.size(0), -1).to(self.device)
        # features: (BATCH, 2048)

        features = self.fc(features).to(self.device)
        # features: (BATCH, IMAGE_EMB_DIM)

        return features
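

# Illustrative, hypothetical sketch (not referenced elsewhere in the project): passes a dummy
# batch through the encoder to show the expected output shape. The embedding size is an
# assumed value for demonstration only, not a value read from Config.
def _demo_encoder(image_emb_dim: int = 256) -> torch.Tensor:
    device = torch.device("cpu")
    encoder = Encoder(image_emb_dim=image_emb_dim, device=device)

    dummy_images = torch.randn(4, 3, 224, 224)     # (BATCH, 3, 224, 224)
    features = encoder(dummy_images)
    assert features.shape == (4, image_emb_dim)    # (BATCH, IMAGE_EMB_DIM)
    return features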


class Decoder(nn.Module):
    def __init__(self,
                 image_emb_dim: int,
                 word_emb_dim: int,
                 hidden_dim: int,
                 num_layers: int,
                 vocab_size: int,
                 device: torch.device):
        """
        Decoder taking as input for the LSTM layer the concatenation of features obtained from the encoder
        and embedded captions obtained from the embedding layer. Hidden and cell states are randomly initialized.
        Final classifier is a linear layer with output dimension of the size of a vocabulary.
        """

        super(Decoder, self).__init__()

        self.config = Config()

        self.image_emb_dim = image_emb_dim
        self.word_emb_dim = word_emb_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.device = device

        # learnable initial hidden and cell states (zero-initialized), defined with a batch dimension of 1
        self.hidden_state_0 = nn.Parameter(torch.zeros((self.num_layers, 1, self.hidden_dim)))
        self.cell_state_0 = nn.Parameter(torch.zeros((self.num_layers, 1, self.hidden_dim)))

        self.lstm = nn.LSTM(input_size=self.image_emb_dim + self.word_emb_dim,
                            hidden_size=self.hidden_dim,
                            num_layers=self.num_layers,
                            bidirectional=False)

        self.fc = nn.Sequential(
            nn.Linear(in_features=self.hidden_dim, out_features=self.vocab_size),
            nn.LogSoftmax(dim=2)
        )

    def forward(self,
                embedded_captions: torch.Tensor,
                features: torch.Tensor,
                hidden: torch.Tensor,
                cell: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass of the (word-by-word) decoder.
        The LSTM input, the concatenation of embedded_captions and features along the last dimension,
        is passed through the LSTM and then the final classifier.

        Args:

            > embedded_captions (torch.Tensor): (SEQ_LENGTH = 1, BATCH, WORD_EMB_DIM)
            > features (torch.Tensor): (1, BATCH, IMAGE_EMB_DIM)
            > hidden (torch.Tensor): (NUM_LAYER, BATCH, HIDDEN_DIM)
            > cell (torch.Tensor): (NUM_LAYER, BATCH, HIDDEN_DIM)

        Returns:

            > output (torch.Tensor): (1, BATCH, VOCAB_SIZE)
            > (hidden, cell) (torch.Tensor, torch.Tensor): each of shape (NUM_LAYER, BATCH, HIDDEN_DIM)
        """

        lstm_input = torch.cat((embedded_captions, features), dim=2)

        output, (hidden, cell) = self.lstm(lstm_input, (hidden, cell))
        # output : (SEQ_LENGTH, BATCH, HIDDEN_DIM)
        # hidden : (NUM_LAYER, BATCH, HIDDEN_DIM)

        output = output.to(self.device)

        output = self.fc(output)
        # output : (SEQ_LENGTH, BATCH, VOCAB_SIZE)

        return output, (hidden, cell)
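

# Illustrative, hypothetical sketch (not referenced elsewhere in the project): runs a single
# word-by-word decoding step with dummy tensors to show the expected shapes. The dimensions
# below are assumptions for demonstration only, not values read from Config.
def _demo_decoder_step(batch: int = 4,
                       image_emb_dim: int = 256,
                       word_emb_dim: int = 300,
                       hidden_dim: int = 512,
                       num_layers: int = 1,
                       vocab_size: int = 5000) -> torch.Tensor:
    device = torch.device("cpu")
    decoder = Decoder(image_emb_dim, word_emb_dim, hidden_dim, num_layers, vocab_size, device)

    features = torch.randn(1, batch, image_emb_dim)        # encoder output, unsqueezed to SEQ_LENGTH = 1
    embedded_word = torch.randn(1, batch, word_emb_dim)    # one embedded token per sequence

    # the learnable initial states have a batch dimension of 1, so expand them to the batch size
    hidden = decoder.hidden_state_0.expand(-1, batch, -1).contiguous()
    cell = decoder.cell_state_0.expand(-1, batch, -1).contiguous()

    output, (hidden, cell) = decoder(embedded_word, features, hidden, cell)
    assert output.shape == (1, batch, vocab_size)          # (1, BATCH, VOCAB_SIZE)
    return output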