from torchvision import transforms
import torch
import torch.utils.data
from PIL import Image

from vocab import Vocab
from model import Decoder, Encoder
from config import Config


@torch.no_grad()  # disable gradient tracking: this function is inference-only
def generate_caption(image: torch.Tensor,
                     image_encoder: Encoder,
                     emb_layer: torch.nn.Embedding,
                     image_decoder: Decoder,
                     vocab: Vocab,
                     device: torch.device) -> list[str]:
    """
    Generate the caption of a single image of size (3, 224, 224).

    Returns:
        list[str]: caption for the given image
    """
    image = image.to(device)
    # image: (3, 224, 224)

    image = image.unsqueeze(0)
    # image: (1, 3, 224, 224)

    features = image_encoder(image)
    # features: (1, IMAGE_EMB_DIM)

    features = features.to(device)
    features = features.unsqueeze(0)
    # features: (1, 1, IMAGE_EMB_DIM)

    hidden = image_decoder.hidden_state_0
    cell = image_decoder.cell_state_0
    # hidden, cell: (NUM_LAYER, 1, HIDDEN_DIM)

    sentence = []

    # start with <SOS> as the first word
    previous_word = vocab.index2word[vocab.SOS]

    MAX_LENGTH = 20
    for _ in range(MAX_LENGTH):
        input_word_id = vocab.word_to_index(previous_word)
        input_word_tensor = torch.tensor([input_word_id]).unsqueeze(0)
        # input_word_tensor: (1, 1)
        input_word_tensor = input_word_tensor.to(device)

        lstm_input = emb_layer(input_word_tensor)
        # lstm_input: (1, 1, WORD_EMB_DIM)

        next_word_pred, (hidden, cell) = image_decoder(lstm_input, features, hidden, cell)
        # next_word_pred: (1, 1, VOCAB_SIZE)

        next_word_pred = next_word_pred[0, 0, :]
        # next_word_pred: (VOCAB_SIZE)

        next_word_pred = torch.argmax(next_word_pred)
        next_word_pred = vocab.index_to_word(int(next_word_pred.item()))

        # stop if we predict <EOS>
        if next_word_pred == vocab.index2word[vocab.EOS]:
            break

        sentence.append(next_word_pred)
        previous_word = next_word_pred

    return sentence


def main_caption(image):
    """
    Generate the caption of a single RGB image given as a uint8 numpy array.

    Returns:
        str: space-separated caption
    """
    config = Config()

    vocab = Vocab()
    vocab.load_vocab(config.VOCAB_FILE)

    image = Image.fromarray(image.astype('uint8'), 'RGB')
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    image = transform(image)

    image_encoder = Encoder(image_emb_dim=config.IMAGE_EMB_DIM,
                            device=config.DEVICE)
    emb_layer = torch.nn.Embedding(num_embeddings=config.VOCAB_SIZE,
                                   embedding_dim=config.WORD_EMB_DIM,
                                   padding_idx=vocab.PADDING_INDEX)
    image_decoder = Decoder(image_emb_dim=config.IMAGE_EMB_DIM,
                            word_emb_dim=config.WORD_EMB_DIM,
                            hidden_dim=config.HIDDEN_DIM,
                            num_layers=config.NUM_LAYER,
                            vocab_size=config.VOCAB_SIZE,
                            device=config.DEVICE)

    # load the trained weights, then switch the modules to evaluation mode
    emb_layer.load_state_dict(torch.load(f=config.EMBEDDING_WEIGHT_FILE, map_location=config.DEVICE))
    image_encoder.load_state_dict(torch.load(f=config.ENCODER_WEIGHT_FILE, map_location=config.DEVICE))
    image_decoder.load_state_dict(torch.load(f=config.DECODER_WEIGHT_FILE, map_location=config.DEVICE))

    emb_layer.eval()
    image_encoder.eval()
    image_decoder.eval()

    emb_layer = emb_layer.to(config.DEVICE)
    image_encoder = image_encoder.to(config.DEVICE)
    image_decoder = image_decoder.to(config.DEVICE)

    image = image.to(config.DEVICE)

    sentence = generate_caption(image, image_encoder, emb_layer, image_decoder, vocab, device=config.DEVICE)
    description = ' '.join(sentence)

    return description
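

# Minimal usage sketch (an assumption, not part of the original module): load an
# image from disk with PIL, convert it to a uint8 RGB numpy array as expected by
# main_caption, and print the generated caption. The path "example.jpg" is a
# hypothetical placeholder; the trained weight files referenced in Config must exist.
if __name__ == "__main__":
    import numpy as np

    pil_image = Image.open("example.jpg").convert("RGB")  # hypothetical test image
    image_array = np.asarray(pil_image, dtype=np.uint8)   # (H, W, 3) uint8 array
    caption = main_caption(image_array)
    print(caption)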