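"""Caption-generation inference for the image-captioning demo.

`generate_caption` greedily decodes a caption for a single preprocessed image;
`main_caption` preprocesses a raw RGB image (numpy array), loads the trained
encoder, embedding and decoder weights, and returns the caption as a string.
"""
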
from torchvision import transforms
import torch
import torch.utils.data
from PIL import Image
from source.vocab import Vocab
from source.model import Decoder, Encoder
from source.config import Config


def generate_caption(image: torch.Tensor,
                     image_encoder: Encoder,
                     emb_layer: torch.nn.Embedding,
                     image_decoder: Decoder,
                     vocab: Vocab,
                     device: torch.device) -> list[str]:
    """Generate a caption for a single image of shape (3, 224, 224).

    Returns:
        list[str]: caption for the given image, as a list of words
    """
    image = image.to(device)
    # image: (3, 224, 224)
    image = image.unsqueeze(0)
    # image: (1, 3, 224, 224)
    features = image_encoder(image)
    # features: (1, IMAGE_EMB_DIM)
    features = features.to(device)
    features = features.unsqueeze(0)
    # features: (1, 1, IMAGE_EMB_DIM)

    hidden = image_decoder.hidden_state_0
    cell = image_decoder.cell_state_0
    # hidden, cell: (NUM_LAYER, 1, HIDDEN_DIM)

    sentence = []
    # start with '<sos>' as the first word
    previous_word = vocab.index2word[vocab.SOS]
    MAX_LENGTH = 20

    for _ in range(MAX_LENGTH):
        input_word_id = vocab.word_to_index(previous_word)
        input_word_tensor = torch.tensor([input_word_id]).unsqueeze(0)
        # input_word_tensor: (1, 1)
        input_word_tensor = input_word_tensor.to(device)

        lstm_input = emb_layer(input_word_tensor)
        # lstm_input: (1, 1, WORD_EMB_DIM)

        next_word_pred, (hidden, cell) = image_decoder(lstm_input, features, hidden, cell)
        # next_word_pred: (1, 1, VOCAB_SIZE)
        next_word_pred = next_word_pred[0, 0, :]
        # next_word_pred: (VOCAB_SIZE,)
        next_word_id = torch.argmax(next_word_pred)
        next_word = vocab.index_to_word(int(next_word_id.item()))

        # stop if we predict '<eos>'
        if next_word == vocab.index2word[vocab.EOS]:
            break

        sentence.append(next_word)
        previous_word = next_word

    return sentence


def main_caption(image):
    config = Config()

    vocab = Vocab()
    vocab.load_vocab(config.VOCAB_FILE)

    image = Image.fromarray(image.astype('uint8'), 'RGB')
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    image = transform(image)

    image_encoder = Encoder(image_emb_dim=config.IMAGE_EMB_DIM,
                            device=config.DEVICE)
    emb_layer = torch.nn.Embedding(num_embeddings=config.VOCAB_SIZE,
                                   embedding_dim=config.WORD_EMB_DIM,
                                   padding_idx=vocab.PADDING_INDEX)
    image_decoder = Decoder(image_emb_dim=config.IMAGE_EMB_DIM,
                            word_emb_dim=config.WORD_EMB_DIM,
                            hidden_dim=config.HIDDEN_DIM,
                            num_layers=config.NUM_LAYER,
                            vocab_size=config.VOCAB_SIZE,
                            device=config.DEVICE)

    # load trained weights, move the modules to the target device, then switch to eval mode
    emb_layer.load_state_dict(torch.load(f=config.EMBEDDING_WEIGHT_FILE, map_location=config.DEVICE))
    image_encoder.load_state_dict(torch.load(f=config.ENCODER_WEIGHT_FILE, map_location=config.DEVICE))
    image_decoder.load_state_dict(torch.load(f=config.DECODER_WEIGHT_FILE, map_location=config.DEVICE))

    emb_layer = emb_layer.to(config.DEVICE)
    image_encoder = image_encoder.to(config.DEVICE)
    image_decoder = image_decoder.to(config.DEVICE)

    emb_layer.eval()
    image_encoder.eval()
    image_decoder.eval()

    image = image.to(config.DEVICE)

    # caption generation is pure inference, so disable gradient tracking
    with torch.no_grad():
        sentence = generate_caption(image, image_encoder, emb_layer, image_decoder, vocab, device=config.DEVICE)

    description = ' '.join(sentence)

    return description
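

if __name__ == "__main__":
    # Minimal usage sketch (assumption): main_caption() expects an RGB image as a
    # uint8 numpy array, as a Gradio callback would receive it. The path
    # 'example.jpg' is hypothetical and not part of the original repository.
    import numpy as np

    demo_image = np.array(Image.open("example.jpg").convert("RGB"))
    print(main_caption(demo_image))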