Spaces:
Sleeping
Sleeping
from torchvision import transforms | |
import torch | |
import torch.utils.data | |
from PIL import Image | |
from source.vocab import Vocab | |
from source.model import Decoder, Encoder | |
from source.config import Config | |
def generate_caption(image: torch.Tensor,
                     image_encoder: Encoder,
                     emb_layer: torch.nn.Embedding,
                     image_decoder: Decoder,
                     vocab: Vocab,
                     device: torch.device,
                     max_length: int = 20) -> list[str]:
    """ Greedily decode a caption for a single image.

    The image is encoded once; the decoder is then run one token at a time,
    starting from '<sos>' and always feeding the most likely word back in,
    until '<eos>' is predicted or `max_length` steps have been taken.

    Args:
        image: a single image WITHOUT a batch dimension (the batch axis is
            added here); the caller in this file passes (3, 224, 224).
        image_encoder: maps the batched image to its feature vector.
        emb_layer: word-embedding layer used to embed the previous word.
        image_decoder: recurrent decoder; must expose `hidden_state_0` and
            `cell_state_0` as its initial states.
        vocab: vocabulary providing `SOS`, `EOS`, `index2word`,
            `word_to_index` and `index_to_word`.
        device: device on which the tensors created here are placed.
        max_length: maximum number of decoding steps (default 20).

    Returns:
        list[str]: predicted words in order; decoding stops before the
        end-of-sentence token would be appended.
    """
    # Pure inference: disable autograd so no graph is accumulated through the
    # hidden/cell states across decoding steps.
    with torch.no_grad():
        image = image.to(device)
        # image: (3, 224, 224)
        image = image.unsqueeze(0)
        # image: (1, 3, 224, 224)

        features = image_encoder.forward(image)
        # features: (1, IMAGE_EMB_DIM)
        features = features.to(device)
        features = features.unsqueeze(0)
        # features: (1, 1, IMAGE_EMB_DIM)

        hidden = image_decoder.hidden_state_0
        cell = image_decoder.cell_state_0
        # hidden, cell : (NUM_LAYER, 1, HIDDEN_DIM)

        sentence = []
        eos_word = vocab.index2word[vocab.EOS]

        # start with '<sos>' as first word
        previous_word = vocab.index2word[vocab.SOS]

        for _ in range(max_length):
            input_word_id = vocab.word_to_index(previous_word)
            input_word_tensor = torch.tensor([input_word_id]).unsqueeze(0)
            # input_word_tensor : (1, 1)
            input_word_tensor = input_word_tensor.to(device)

            lstm_input = emb_layer.forward(input_word_tensor)
            # lstm_input : (1, 1, WORD_EMB_DIM)

            next_word_pred, (hidden, cell) = image_decoder.forward(
                lstm_input, features, hidden, cell)
            # next_word_pred : (1, 1, VOCAB_SIZE)

            # greedy choice: index of the highest-scoring vocabulary entry
            next_word_id = int(torch.argmax(next_word_pred[0, 0, :]).item())
            next_word = vocab.index_to_word(next_word_id)

            # stop if we predict '<eos>'
            if next_word == eos_word:
                break

            sentence.append(next_word)
            previous_word = next_word

    return sentence
def main_caption(image):
    """ Caption one RGB image supplied as an array.

    Loads the vocabulary and the saved embedding / encoder / decoder weights
    named in `Config`, preprocesses the image to a normalised 224x224 tensor
    and decodes a caption with `generate_caption`.

    Args:
        image: array supporting `astype` that `PIL.Image.fromarray` can read
            as an 'RGB' image.

    Returns:
        str: the predicted words joined by single spaces.
    """
    cfg = Config()

    vocabulary = Vocab()
    vocabulary.load_vocab(cfg.VOCAB_FILE)

    # --- preprocessing: array -> PIL image -> normalised (3, 224, 224) tensor
    pil_image = Image.fromarray(image.astype('uint8'), 'RGB')
    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    image_tensor = preprocess(pil_image)

    # --- build the three model components
    device = cfg.DEVICE
    encoder = Encoder(image_emb_dim=cfg.IMAGE_EMB_DIM, device=device)
    embedding = torch.nn.Embedding(num_embeddings=cfg.VOCAB_SIZE,
                                   embedding_dim=cfg.WORD_EMB_DIM,
                                   padding_idx=vocabulary.PADDING_INDEX)
    decoder = Decoder(image_emb_dim=cfg.IMAGE_EMB_DIM,
                      word_emb_dim=cfg.WORD_EMB_DIM,
                      hidden_dim=cfg.HIDDEN_DIM,
                      num_layers=cfg.NUM_LAYER,
                      vocab_size=cfg.VOCAB_SIZE,
                      device=device)

    # --- switch to inference mode, then restore the trained weights
    modules = (embedding, encoder, decoder)
    weight_files = (cfg.EMBEDDING_WEIGHT_FILE,
                    cfg.ENCODER_WEIGHT_FILE,
                    cfg.DECODER_WEIGHT_FILE)
    for module in modules:
        module.eval()
    for module, weight_file in zip(modules, weight_files):
        module.load_state_dict(torch.load(f=weight_file, map_location=device))

    # --- move everything onto the configured device
    embedding = embedding.to(device)
    encoder = encoder.to(device)
    decoder = decoder.to(device)
    image_tensor = image_tensor.to(device)

    words = generate_caption(image_tensor, encoder, embedding, decoder,
                             vocabulary, device=device)
    return ' '.join(words)