import torch
import torch.utils.data
from PIL import Image
from torchvision import transforms

from source.config import Config
from source.model import Decoder, Encoder
from source.vocab import Vocab


def generate_caption(image: torch.Tensor, image_encoder: Encoder, emb_layer: torch.nn.Embedding,
                     image_decoder: Decoder, vocab: Vocab, device: torch.device,
                     max_length: int = 20) -> list[str]:
    """Greedily generate a caption for a single image of size (3, 224, 224).

    Decoding starts from the SOS token. At every step the whole prefix generated
    so far is embedded and fed to the decoder together with the image features,
    starting from the decoder's initial hidden/cell state. The arg-max word ID of
    the last position is appended to the prefix. Decoding stops after
    ``max_length`` steps or as soon as the EOS word is predicted.

    Args:
        image: preprocessed image tensor of shape (3, 224, 224).
        image_encoder: module mapping a (1, 3, 224, 224) batch to image features.
        emb_layer: word embedding applied to the token-ID prefix.
        image_decoder: decoder exposing ``hidden_state_0`` / ``cell_state_0`` and
            returning ``(logits, (hidden, cell))``.
        vocab: vocabulary providing ``SOS``, ``EOS``, ``index2word`` and ``index_to_word``.
        device: device on which inputs are placed.
        max_length: maximum number of words to generate (default 20).

    Returns:
        list[str]: the predicted words. If the EOS word was predicted, it is
        included as the last item.
    """
    eos_word = vocab.index2word[vocab.EOS]

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        image = image.to(device).unsqueeze(0)  # (1, 3, 224, 224)
        # The image features are identical for every decoding step, so encode once.
        features = image_encoder(image).to(device)  # (1, IMAGE_EMB_DIM)
        features = features.unsqueeze(0)  # (1, 1, IMAGE_EMB_DIM)

        input_words = [vocab.SOS]  # prefix of token IDs fed to the decoder
        sentence = []
        for _ in range(max_length):
            input_words_tensor = torch.tensor([input_words], device=device)  # (B=1, SEQ_LENGTH)
            lstm_input = emb_layer(input_words_tensor)  # (B=1, SEQ_LENGTH, WORD_EMB_DIM)
            lstm_input = lstm_input.permute(1, 0, 2)  # (SEQ_LENGTH, B=1, WORD_EMB_DIM)
            seq_length = lstm_input.shape[0]
            step_features = features.repeat(seq_length, 1, 1)  # (SEQ_LENGTH, B=1, IMAGE_EMB_DIM)

            # The full prefix is re-fed every step, so the decoder must start from
            # its initial state each time. Carrying over the state returned by the
            # previous step would make the LSTM consume the earlier tokens twice.
            hidden = image_decoder.hidden_state_0
            cell = image_decoder.cell_state_0  # (NUM_LAYER, 1, HIDDEN_DIM)
            logits, _ = image_decoder(lstm_input, step_features, hidden, cell)  # (SEQ_LENGTH, 1, VOCAB_SIZE)

            # Greedy choice: most likely word at the last position of the prefix.
            next_id = int(torch.argmax(logits[-1, 0, :]).item())
            input_words.append(next_id)

            next_word = vocab.index_to_word(next_id)
            sentence.append(next_word)
            if next_word == eos_word:
                break

    return sentence


def _build_transform() -> transforms.Compose:
    """Return the preprocessing pipeline: resize to 256x256, center-crop 224, to tensor, normalize."""
    return transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


def _load_models(config: Config, vocab: Vocab) -> tuple[Encoder, torch.nn.Embedding, Decoder]:
    """Build encoder, embedding and decoder, load their weights, move them to config.DEVICE in eval mode."""
    image_encoder = Encoder(image_emb_dim=config.IMAGE_EMB_DIM, device=config.DEVICE)
    emb_layer = torch.nn.Embedding(num_embeddings=config.VOCAB_SIZE,
                                   embedding_dim=config.WORD_EMB_DIM,
                                   padding_idx=vocab.PADDING_INDEX)
    image_decoder = Decoder(image_emb_dim=config.IMAGE_EMB_DIM,
                            word_emb_dim=config.WORD_EMB_DIM,
                            hidden_dim=config.HIDDEN_DIM,
                            num_layers=config.NUM_LAYER,
                            vocab_size=config.VOCAB_SIZE,
                            device=config.DEVICE)

    # NOTE(review): torch.load unpickles these files, so only load trusted weight
    # files. Consider weights_only=True if the installed torch version supports it.
    emb_layer.load_state_dict(torch.load(f=config.EMBEDDING_WEIGHT_FILE, map_location=config.DEVICE))
    image_encoder.load_state_dict(torch.load(f=config.ENCODER_WEIGHT_FILE, map_location=config.DEVICE))
    image_decoder.load_state_dict(torch.load(f=config.DECODER_WEIGHT_FILE, map_location=config.DEVICE))

    # .eval() returns the module itself, so it can be chained after .to().
    image_encoder = image_encoder.to(config.DEVICE).eval()
    emb_layer = emb_layer.to(config.DEVICE).eval()
    image_decoder = image_decoder.to(config.DEVICE).eval()
    return image_encoder, emb_layer, image_decoder


def main_caption(image):
    """Load the trained captioning model and return a caption for ``image``.

    Args:
        image: array-like image exposing ``astype`` (presumably a NumPy
            ``(H, W)`` or ``(H, W, C)`` array, e.g. from a UI widget - TODO confirm).

    Returns:
        str: the predicted words joined by single spaces, without the trailing EOS word.
    """
    config = Config()
    vocab = Vocab()
    vocab.load_vocab(config.VOCAB_FILE)

    # Let Pillow infer the mode, then convert. Grayscale and RGBA inputs are
    # handled this way. Forcing mode='RGB' in fromarray fails or misreads
    # arrays that do not have exactly 3 channels.
    pil_image = Image.fromarray(image.astype('uint8')).convert('RGB')
    image_tensor = _build_transform()(pil_image).to(config.DEVICE)

    image_encoder, emb_layer, image_decoder = _load_models(config, vocab)
    sentence = generate_caption(image_tensor, image_encoder, emb_layer, image_decoder, vocab,
                                device=config.DEVICE)

    # generate_caption keeps the EOS word as its last item. Drop it so the
    # control token does not appear in the returned description.
    if sentence and sentence[-1] == vocab.index2word[vocab.EOS]:
        sentence = sentence[:-1]
    return ' '.join(sentence)