# source/predict_sample.py
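"""Caption a single image with the trained Encoder/Decoder model.

Loads the vocabulary and pretrained weights referenced in `Config`, preprocesses
the input image, and decodes a caption greedily, one word at a time.
"""
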
import torch
from PIL import Image
from torchvision import transforms

from source.config import Config
from source.model import Decoder, Encoder
from source.vocab import Vocab


def generate_caption(image: torch.Tensor,
                     image_encoder: Encoder,
                     emb_layer: torch.nn.Embedding,
                     image_decoder: Decoder,
                     vocab: Vocab,
                     device: torch.device) -> list[str]:
""" Generate caption of a single image of size (1, 3, 224, 224)
Returns:
list[str]: caption for given image
"""
image = image.to(device)
# image: (3, 224, 224)
image = image.unsqueeze(0)
# image: (1, 3, 224, 224)
    features = image_encoder(image)
# features: (1, IMAGE_EMB_DIM)
features = features.to(device)
features = features.unsqueeze(0)
# features: (1, 1, IMAGE_EMB_DIM)
hidden = image_decoder.hidden_state_0
cell = image_decoder.cell_state_0
# hidden, cell : (NUM_LAYER, 1, HIDDEN_DIM)
sentence = []
# start with '<sos>' as first word
previous_word = vocab.index2word[vocab.SOS]
MAX_LENGTH = 20
    for _ in range(MAX_LENGTH):
input_word_id = vocab.word_to_index(previous_word)
input_word_tensor = torch.tensor([input_word_id]).unsqueeze(0)
# input_word_tensor : (1, 1)
input_word_tensor = input_word_tensor.to(device)
        lstm_input = emb_layer(input_word_tensor)
# lstm_input : (1, 1, WORD_EMB_DIM)
        next_word_pred, (hidden, cell) = image_decoder(lstm_input, features, hidden, cell)
# next_word_pred : (1, 1, VOCAB_SIZE)
next_word_pred = next_word_pred[0, 0, :]
# next_word_pred : (VOCAB_SIZE)
        next_word_id = torch.argmax(next_word_pred)
        next_word = vocab.index_to_word(int(next_word_id.item()))
        # stop if we predict '<eos>'
        if next_word == vocab.index2word[vocab.EOS]:
            break
        sentence.append(next_word)
        previous_word = next_word
return sentence


def main_caption(image):
config = Config()
vocab = Vocab()
vocab.load_vocab(config.VOCAB_FILE)
image = Image.fromarray(image.astype('uint8'), 'RGB')
transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
image = transform(image)
image_encoder = Encoder(image_emb_dim=config.IMAGE_EMB_DIM,
device=config.DEVICE)
emb_layer = torch.nn.Embedding(num_embeddings=config.VOCAB_SIZE,
embedding_dim=config.WORD_EMB_DIM,
padding_idx=vocab.PADDING_INDEX)
image_decoder = Decoder(image_emb_dim=config.IMAGE_EMB_DIM,
word_emb_dim=config.WORD_EMB_DIM,
hidden_dim=config.HIDDEN_DIM,
num_layers=config.NUM_LAYER,
vocab_size=config.VOCAB_SIZE,
device=config.DEVICE)
    emb_layer.load_state_dict(torch.load(f=config.EMBEDDING_WEIGHT_FILE, map_location=config.DEVICE))
    image_encoder.load_state_dict(torch.load(f=config.ENCODER_WEIGHT_FILE, map_location=config.DEVICE))
    image_decoder.load_state_dict(torch.load(f=config.DECODER_WEIGHT_FILE, map_location=config.DEVICE))

    emb_layer = emb_layer.to(config.DEVICE)
    image_encoder = image_encoder.to(config.DEVICE)
    image_decoder = image_decoder.to(config.DEVICE)

    # switch to inference mode after the pretrained weights are loaded
    emb_layer.eval()
    image_encoder.eval()
    image_decoder.eval()
image = image.to(config.DEVICE)
    # no gradients are needed at inference time
    with torch.no_grad():
        sentence = generate_caption(image, image_encoder, emb_layer, image_decoder, vocab, device=config.DEVICE)

    description = ' '.join(sentence)
    return description
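

# Minimal usage sketch (an assumption, not part of the original file): main_caption()
# appears to expect the image as a numpy array, since it converts it with
# Image.fromarray(). "sample.jpg" is a hypothetical test image path.
if __name__ == "__main__":
    import numpy as np

    sample_image = np.array(Image.open("sample.jpg").convert("RGB"))
    print(main_caption(sample_image))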