# ImageCaption/source/predict_sample.py
import torch
from PIL import Image
from torchvision import transforms

from source.config import Config
from source.model import Decoder, Encoder
from source.vocab import Vocab


@torch.no_grad()  # inference only: no gradient tracking needed
def generate_caption(image: torch.Tensor,
                     image_encoder: Encoder,
                     emb_layer: torch.nn.Embedding,
                     image_decoder: Decoder,
                     vocab: Vocab,
                     device: torch.device) -> list[str]:
    """
    Generate a caption for a single image of size (3, 224, 224).

    Decoding is greedy: generation starts with <sos>, and each predicted word ID
    is appended to the LSTM input for the next step, until the sentence reaches
    MAX_LENGTH or <eos> is produced.

    Returns:
        list[str]: caption for the given image
    """
    image = image.to(device)
    # image: (3, 224, 224)
    image = image.unsqueeze(0)
    # image: (1, 3, 224, 224)

    hidden = image_decoder.hidden_state_0
    cell = image_decoder.cell_state_0
    # hidden, cell: (NUM_LAYER, 1, HIDDEN_DIM)

    sentence = []
    # start the sequence with the <sos> token ID
    input_words = [vocab.SOS]
    MAX_LENGTH = 20
    # encode the image once; in eval mode the encoder output is identical at every step
    features = image_encoder(image)
    # features: (1, IMAGE_EMB_DIM)
    features = features.to(device)
    features = features.unsqueeze(0)
    # features: (1, 1, IMAGE_EMB_DIM)

    for _ in range(MAX_LENGTH):
        input_words_tensor = torch.tensor([input_words], device=device)
        # input_words_tensor: (B=1, SEQ_LENGTH)
        lstm_input = emb_layer(input_words_tensor)
        # lstm_input: (B=1, SEQ_LENGTH, WORD_EMB_DIM)
        lstm_input = lstm_input.permute(1, 0, 2)
        # lstm_input: (SEQ_LENGTH, B=1, WORD_EMB_DIM)

        SEQ_LENGTH = lstm_input.shape[0]
        step_features = features.repeat(SEQ_LENGTH, 1, 1)
        # step_features: (SEQ_LENGTH, B=1, IMAGE_EMB_DIM)

        next_id_pred, (hidden, cell) = image_decoder(lstm_input, step_features, hidden, cell)
        # next_id_pred: (SEQ_LENGTH, 1, VOCAB_SIZE)

        # keep only the logits of the last time step and pick the most likely word ID (greedy decoding)
        next_id_pred = next_id_pred[-1, 0, :]
        # next_id_pred: (VOCAB_SIZE,)
        next_id_pred = torch.argmax(next_id_pred)

        # append the predicted ID so it becomes part of the next LSTM input
        input_words.append(next_id_pred.item())

        # ID --> word; stop as soon as <eos> is produced
        next_word_pred = vocab.index_to_word(int(next_id_pred.item()))
        if next_word_pred == vocab.index2word[vocab.EOS]:
            break
        sentence.append(next_word_pred)

    return sentence


def main_caption(image) -> str:
    """
    Generate a caption for an image given as an RGB array of shape (H, W, 3).

    Loads the vocabulary and the trained encoder, decoder, and embedding weights
    specified in Config, then returns the predicted caption as a single string.
    """
    config = Config()

    vocab = Vocab()
    vocab.load_vocab(config.VOCAB_FILE)

    # preprocess: resize, center-crop to 224x224, and normalize with ImageNet statistics
    image = Image.fromarray(image.astype('uint8'), 'RGB')
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    image = transform(image)
    # build the models with the same dimensions used during training
    image_encoder = Encoder(image_emb_dim=config.IMAGE_EMB_DIM,
                            device=config.DEVICE)
    emb_layer = torch.nn.Embedding(num_embeddings=config.VOCAB_SIZE,
                                   embedding_dim=config.WORD_EMB_DIM,
                                   padding_idx=vocab.PADDING_INDEX)
    image_decoder = Decoder(word_emb_dim=config.WORD_EMB_DIM,
                            hidden_dim=config.HIDDEN_DIM,
                            num_layers=config.NUM_LAYER,
                            vocab_size=config.VOCAB_SIZE,
                            device=config.DEVICE)

    # load the trained weights and switch to evaluation mode
    emb_layer.load_state_dict(torch.load(f=config.EMBEDDING_WEIGHT_FILE, map_location=config.DEVICE))
    image_encoder.load_state_dict(torch.load(f=config.ENCODER_WEIGHT_FILE, map_location=config.DEVICE))
    image_decoder.load_state_dict(torch.load(f=config.DECODER_WEIGHT_FILE, map_location=config.DEVICE))

    emb_layer.eval()
    image_encoder.eval()
    image_decoder.eval()

    emb_layer = emb_layer.to(config.DEVICE)
    image_encoder = image_encoder.to(config.DEVICE)
    image_decoder = image_decoder.to(config.DEVICE)
    image = image.to(config.DEVICE)
    sentence = generate_caption(image, image_encoder, emb_layer, image_decoder, vocab, device=config.DEVICE)
    description = ' '.join(sentence)
    return description
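

# Example usage (a minimal sketch, not part of the original script): load an image
# from disk with PIL, convert it to the uint8 RGB array that main_caption expects,
# and print the predicted caption. The path "sample.jpg" is a placeholder, and numpy
# is assumed to be available in the environment.
if __name__ == "__main__":
    import numpy as np

    sample = Image.open("sample.jpg").convert("RGB")  # placeholder image path
    caption = main_caption(np.array(sample))
    print(caption)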