nssharmaofficial committed on
Commit
58b663c
1 Parent(s): e7324f8

Fix generate caption function

Browse files
Files changed (1) hide show
  1. source/predict_sample.py +35 -24
source/predict_sample.py CHANGED
@@ -14,7 +14,10 @@ def generate_caption(image: torch.Tensor,
14
  image_decoder: Decoder,
15
  vocab: Vocab,
16
  device: torch.device) -> list[str]:
17
- """ Generate caption of a single image of size (1, 3, 224, 224)
 
 
 
18
 
19
  Returns:
20
  list[str]: caption for given image
@@ -25,49 +28,57 @@ def generate_caption(image: torch.Tensor,
25
  image = image.unsqueeze(0)
26
  # image: (1, 3, 224, 224)
27
 
28
- features = image_encoder.forward(image)
29
- # features: (1, IMAGE_EMB_DIM)
30
- features = features.to(device)
31
- features = features.unsqueeze(0)
32
- # features: (1, 1, IMAGE_EMB_DIM)
33
-
34
  hidden = image_decoder.hidden_state_0
35
  cell = image_decoder.cell_state_0
36
  # hidden, cell : (NUM_LAYER, 1, HIDDEN_DIM)
37
 
38
  sentence = []
39
 
40
- # start with '<sos>' as first word
41
- previous_word = vocab.index2word[vocab.SOS]
42
 
43
  MAX_LENGTH = 20
44
 
45
  for i in range(MAX_LENGTH):
46
 
47
- input_word_id = vocab.word_to_index(previous_word)
48
- input_word_tensor = torch.tensor([input_word_id]).unsqueeze(0)
49
- # input_word_tensor : (1, 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
- input_word_tensor = input_word_tensor.to(device)
52
- lstm_input = emb_layer.forward(input_word_tensor)
53
- # lstm_input : (1, 1, WORD_EMB_DIM)
54
 
55
- next_word_pred, (hidden, cell) = image_decoder.forward(lstm_input, features, hidden, cell)
56
- # next_word_pred : (1, 1, VOCAB_SIZE)
57
 
58
- next_word_pred = next_word_pred[0, 0, :]
59
- # next_word_pred : (VOCAB_SIZE)
 
60
 
61
- next_word_pred = torch.argmax(next_word_pred)
62
- next_word_pred = vocab.index_to_word(int(next_word_pred.item()))
 
 
 
 
63
 
64
  # stop if we predict '<eos>'
65
  if next_word_pred == vocab.index2word[vocab.EOS]:
66
  break
67
 
68
- sentence.append(next_word_pred)
69
- previous_word = next_word_pred
70
-
71
  return sentence
72
 
73
 
 
14
  image_decoder: Decoder,
15
  vocab: Vocab,
16
  device: torch.device) -> list[str]:
17
+ """
18
+ Generate caption of a single image of size (3, 224, 224).
19
+ Generating of caption starts with <sos>, and each next predicted word ID
20
+ is appended for the next LSTM input until the sentence reaches MAX_LENGTH or <eos>.
21
 
22
  Returns:
23
  list[str]: caption for given image
 
28
  image = image.unsqueeze(0)
29
  # image: (1, 3, 224, 224)
30
 
 
 
 
 
 
 
31
  hidden = image_decoder.hidden_state_0
32
  cell = image_decoder.cell_state_0
33
  # hidden, cell : (NUM_LAYER, 1, HIDDEN_DIM)
34
 
35
  sentence = []
36
 
37
+ # initialize LSTM input to SOS token = 1
38
+ input_words = [vocab.SOS]
39
 
40
  MAX_LENGTH = 20
41
 
42
  for i in range(MAX_LENGTH):
43
 
44
+ features = image_encoder.forward(image)
45
+ # features: (1, IMAGE_EMB_DIM)
46
+ features = features.to(device)
47
+ features = features.unsqueeze(0)
48
+ # features: (1, 1, IMAGE_EMB_DIM)
49
+
50
+ input_words_tensor = torch.tensor([input_words])
51
+ # input_word_tensor : (B=1, SEQ_LENGTH)
52
+ input_words_tensor = input_words_tensor.to(device)
53
+
54
+ lstm_input = emb_layer.forward(input_words_tensor)
55
+ # lstm_input : (B=1, SEQ_LENGTH, WORD_EMB_DIM)
56
+
57
+ lstm_input = lstm_input.permute(1, 0, 2)
58
+ # lstm_input : (SEQ_LENGTH, B=1, WORD_EMB_DIM)
59
+ SEQ_LENGTH = lstm_input.shape[0]
60
 
61
+ features = features.repeat(SEQ_LENGTH, 1, 1)
62
+ # features : (SEQ_LENGTH, B=1, IMAGE_EMB_DIM)
 
63
 
64
+ next_id_pred, (hidden, cell) = image_decoder.forward(lstm_input, features, hidden, cell)
65
+ # next_id_pred : (SEQ_LENGTH, 1, VOCAB_SIZE)
66
 
67
+ next_id_pred = next_id_pred[-1, 0, :]
68
+ # next_id_pred : (VOCAB_SIZE)
69
+ next_id_pred = torch.argmax(next_id_pred)
70
 
71
+ # append it to input_words which will be again as input for LSTM
72
+ input_words.append(next_id_pred.item())
73
+
74
+ # id --> word
75
+ next_word_pred = vocab.index_to_word(int(next_id_pred.item()))
76
+ sentence.append(next_word_pred)
77
 
78
  # stop if we predict '<eos>'
79
  if next_word_pred == vocab.index2word[vocab.EOS]:
80
  break
81
 
 
 
 
82
  return sentence
83
 
84