# ImageCaption/source/predict_sample.py
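# Generate a caption for a single image using the trained Encoder, embedding
# layer, and Decoder defined in model.py, with vocabulary, weight paths, and
# hyperparameters taken from config.py.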
from torchvision import transforms
import torch
import torch.utils.data
from PIL import Image
from vocab import Vocab
from model import Decoder, Encoder
from config import Config


def generate_caption(image: torch.Tensor,
                     image_encoder: Encoder,
                     emb_layer: torch.nn.Embedding,
                     image_decoder: Decoder,
                     vocab: Vocab,
                     device: torch.device) -> list[str]:
    """Generate a caption for a single image of size (3, 224, 224).

    A batch dimension is added internally before encoding.

    Returns:
        list[str]: predicted caption for the given image, as a list of words
    """
    image = image.to(device)
    # image: (3, 224, 224)
    image = image.unsqueeze(0)
    # image: (1, 3, 224, 224)

    features = image_encoder(image)
    # features: (1, IMAGE_EMB_DIM)
    features = features.to(device)
    features = features.unsqueeze(0)
    # features: (1, 1, IMAGE_EMB_DIM)

    hidden = image_decoder.hidden_state_0
    cell = image_decoder.cell_state_0
    # hidden, cell: (NUM_LAYER, 1, HIDDEN_DIM)

    sentence = []

    # start decoding with '<sos>' as the first input word
    previous_word = vocab.index2word[vocab.SOS]

    MAX_LENGTH = 20

    for _ in range(MAX_LENGTH):
        input_word_id = vocab.word_to_index(previous_word)
        input_word_tensor = torch.tensor([input_word_id]).unsqueeze(0)
        # input_word_tensor: (1, 1)
        input_word_tensor = input_word_tensor.to(device)

        lstm_input = emb_layer(input_word_tensor)
        # lstm_input: (1, 1, WORD_EMB_DIM)

        next_word_pred, (hidden, cell) = image_decoder(lstm_input, features, hidden, cell)
        # next_word_pred: (1, 1, VOCAB_SIZE)
        next_word_pred = next_word_pred[0, 0, :]
        # next_word_pred: (VOCAB_SIZE)

        # greedy decoding: pick the most likely word at each step
        next_word_pred = torch.argmax(next_word_pred)
        next_word_pred = vocab.index_to_word(int(next_word_pred.item()))

        # stop if the model predicts '<eos>'
        if next_word_pred == vocab.index2word[vocab.EOS]:
            break

        sentence.append(next_word_pred)
        previous_word = next_word_pred

    return sentence


def main_caption(image):
    config = Config()

    vocab = Vocab()
    vocab.load_vocab(config.VOCAB_FILE)

    image = Image.fromarray(image.astype('uint8'), 'RGB')

    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    image = transform(image)
    image_encoder = Encoder(image_emb_dim=config.IMAGE_EMB_DIM,
                            device=config.DEVICE)
    emb_layer = torch.nn.Embedding(num_embeddings=config.VOCAB_SIZE,
                                   embedding_dim=config.WORD_EMB_DIM,
                                   padding_idx=vocab.PADDING_INDEX)
    image_decoder = Decoder(image_emb_dim=config.IMAGE_EMB_DIM,
                            word_emb_dim=config.WORD_EMB_DIM,
                            hidden_dim=config.HIDDEN_DIM,
                            num_layers=config.NUM_LAYER,
                            vocab_size=config.VOCAB_SIZE,
                            device=config.DEVICE)
    # load trained weights, then switch all modules to evaluation mode
    emb_layer.load_state_dict(torch.load(f=config.EMBEDDING_WEIGHT_FILE, map_location=config.DEVICE))
    image_encoder.load_state_dict(torch.load(f=config.ENCODER_WEIGHT_FILE, map_location=config.DEVICE))
    image_decoder.load_state_dict(torch.load(f=config.DECODER_WEIGHT_FILE, map_location=config.DEVICE))

    emb_layer.eval()
    image_encoder.eval()
    image_decoder.eval()

    emb_layer = emb_layer.to(config.DEVICE)
    image_encoder = image_encoder.to(config.DEVICE)
    image_decoder = image_decoder.to(config.DEVICE)

    image = image.to(config.DEVICE)

    sentence = generate_caption(image, image_encoder, emb_layer, image_decoder, vocab, device=config.DEVICE)
    description = ' '.join(sentence)

    return description
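

# Example usage (a minimal sketch, not part of the original script): load an
# image from disk as a NumPy array and print the generated caption. The path
# 'sample.jpg' is a placeholder; main_caption expects an RGB image as a
# uint8 NumPy array.
if __name__ == '__main__':
    import numpy as np

    sample = np.array(Image.open('sample.jpg').convert('RGB'))
    print(main_caption(sample))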