from torchvision import transforms
import torch
import torch.utils.data
from PIL import Image

from source.vocab import Vocab
from source.model import Decoder, Encoder
from source.config import Config


def generate_caption(image: torch.Tensor,
                     image_encoder: Encoder,
                     emb_layer: torch.nn.Embedding,
                     image_decoder: Decoder,
                     vocab: Vocab,
                     device: torch.device) -> list[str]:
    """ Generate caption of a single image of size (1, 3, 224, 224)

    Returns:
        list[str]: caption for given image
    """

    image = image.to(device)
    # image: (3, 224, 224)
    image = image.unsqueeze(0)
    # image: (1, 3, 224, 224)

    features = image_encoder(image)
    # features: (1, IMAGE_EMB_DIM)
    features = features.to(device)
    features = features.unsqueeze(0)
    # features: (1, 1, IMAGE_EMB_DIM)

    hidden = image_decoder.hidden_state_0
    cell = image_decoder.cell_state_0
    # hidden, cell : (NUM_LAYER, 1, HIDDEN_DIM)

    sentence = []

    # start with '<sos>' as first word
    previous_word = vocab.index2word[vocab.SOS]

    MAX_LENGTH = 20

    for _ in range(MAX_LENGTH):

        input_word_id = vocab.word_to_index(previous_word)
        input_word_tensor = torch.tensor([input_word_id]).unsqueeze(0)
        # input_word_tensor : (1, 1)

        input_word_tensor = input_word_tensor.to(device)
        lstm_input = emb_layer(input_word_tensor)
        # lstm_input : (1, 1, WORD_EMB_DIM)

        next_word_pred, (hidden, cell) = image_decoder(lstm_input, features, hidden, cell)
        # next_word_pred : (1, 1, VOCAB_SIZE)

        next_word_pred = next_word_pred[0, 0, :]
        # next_word_pred : (VOCAB_SIZE)

        next_word_pred = torch.argmax(next_word_pred)
        next_word_pred = vocab.index_to_word(int(next_word_pred.item()))

        # stop if we predict '<eos>'
        if next_word_pred == vocab.index2word[vocab.EOS]:
            break

        sentence.append(next_word_pred)
        previous_word = next_word_pred

    return sentence


def main_caption(image) -> str:
    """ Generate a caption for a raw RGB image given as a (H, W, 3) uint8-compatible numpy array. """

    config = Config()

    vocab = Vocab()
    vocab.load_vocab(config.VOCAB_FILE)

    image = Image.fromarray(image.astype('uint8'), 'RGB')

    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    image = transform(image)

    image_encoder = Encoder(image_emb_dim=config.IMAGE_EMB_DIM,
                            device=config.DEVICE)
    emb_layer = torch.nn.Embedding(num_embeddings=config.VOCAB_SIZE,
                                   embedding_dim=config.WORD_EMB_DIM,
                                   padding_idx=vocab.PADDING_INDEX)
    image_decoder = Decoder(image_emb_dim=config.IMAGE_EMB_DIM,
                            word_emb_dim=config.WORD_EMB_DIM,
                            hidden_dim=config.HIDDEN_DIM,
                            num_layers=config.NUM_LAYER,
                            vocab_size=config.VOCAB_SIZE,
                            device=config.DEVICE)

    emb_layer.eval()
    image_encoder.eval()
    image_decoder.eval()

    emb_layer.load_state_dict(torch.load(f=config.EMBEDDING_WEIGHT_FILE, map_location=config.DEVICE))
    image_encoder.load_state_dict(torch.load(f=config.ENCODER_WEIGHT_FILE, map_location=config.DEVICE))
    image_decoder.load_state_dict(torch.load(f=config.DECODER_WEIGHT_FILE, map_location=config.DEVICE))

    emb_layer = emb_layer.to(config.DEVICE)
    image_encoder = image_encoder.to(config.DEVICE)
    image_decoder = image_decoder.to(config.DEVICE)
    image = image.to(config.DEVICE)

    sentence = generate_caption(image, image_encoder, emb_layer, image_decoder, vocab, device=config.DEVICE)
    description = ' '.join(sentence)

    return description
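

# Minimal usage sketch, assuming numpy is installed and a local RGB image exists
# at the hypothetical path "example.jpg"; main_caption expects the image as a
# (H, W, 3) numpy array, which it converts and preprocesses internally.
if __name__ == "__main__":
    import numpy as np

    img = np.array(Image.open("example.jpg").convert("RGB"))
    print(main_caption(img))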