nssharmaofficial committed
Commit 58b663c · 1 Parent(s): e7324f8

Fix generate caption function
source/predict_sample.py  +35 -24  CHANGED
@@ -14,7 +14,10 @@ def generate_caption(image: torch.Tensor,
                      image_decoder: Decoder,
                      vocab: Vocab,
                      device: torch.device) -> list[str]:
-    """
+    """
+    Generate a caption for a single image of size (3, 224, 224).
+    Generation starts with <sos>; each predicted word ID is appended to the
+    LSTM input for the next step, until MAX_LENGTH is reached or <eos> is predicted.
 
     Returns:
         list[str]: caption for given image
@@ -25,49 +28,57 @@ def generate_caption(image: torch.Tensor,
     image = image.unsqueeze(0)
     # image: (1, 3, 224, 224)
 
-    features = image_encoder.forward(image)
-    # features: (1, IMAGE_EMB_DIM)
-    features = features.to(device)
-    features = features.unsqueeze(0)
-    # features: (1, 1, IMAGE_EMB_DIM)
-
     hidden = image_decoder.hidden_state_0
     cell = image_decoder.cell_state_0
     # hidden, cell : (NUM_LAYER, 1, HIDDEN_DIM)
 
     sentence = []
 
-    #
-
+    # initialize the LSTM input with the SOS token (ID = 1)
+    input_words = [vocab.SOS]
 
     MAX_LENGTH = 20
 
     for i in range(MAX_LENGTH):
 
-
-
-
+        features = image_encoder.forward(image)
+        # features: (1, IMAGE_EMB_DIM)
+        features = features.to(device)
+        features = features.unsqueeze(0)
+        # features: (1, 1, IMAGE_EMB_DIM)
+
+        input_words_tensor = torch.tensor([input_words])
+        # input_words_tensor : (B=1, SEQ_LENGTH)
+        input_words_tensor = input_words_tensor.to(device)
+
+        lstm_input = emb_layer.forward(input_words_tensor)
+        # lstm_input : (B=1, SEQ_LENGTH, WORD_EMB_DIM)
+
+        lstm_input = lstm_input.permute(1, 0, 2)
+        # lstm_input : (SEQ_LENGTH, B=1, WORD_EMB_DIM)
+        SEQ_LENGTH = lstm_input.shape[0]
 
-
-
-        # lstm_input : (1, 1, WORD_EMB_DIM)
+        features = features.repeat(SEQ_LENGTH, 1, 1)
+        # features : (SEQ_LENGTH, B=1, IMAGE_EMB_DIM)
 
-
-        #
+        next_id_pred, (hidden, cell) = image_decoder.forward(lstm_input, features, hidden, cell)
+        # next_id_pred : (SEQ_LENGTH, 1, VOCAB_SIZE)
 
-
-        #
+        next_id_pred = next_id_pred[-1, 0, :]
+        # next_id_pred : (VOCAB_SIZE)
+        next_id_pred = torch.argmax(next_id_pred)
 
-
-
+        # append the predicted ID to input_words, which is fed back to the LSTM
+        input_words.append(next_id_pred.item())
+
+        # id --> word
+        next_word_pred = vocab.index_to_word(int(next_id_pred.item()))
+        sentence.append(next_word_pred)
 
         # stop if we predict '<eos>'
         if next_word_pred == vocab.index2word[vocab.EOS]:
             break
 
-        sentence.append(next_word_pred)
-        previous_word = next_word_pred
-
     return sentence
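For context, here is a minimal, hypothetical driver for the fixed function. It is a sketch under stated assumptions, not code from this commit: the import paths, the no-argument Encoder/Decoder/Vocab constructors, the image path, and the ImageNet normalization are illustrative; the argument order follows the visible signature lines; and emb_layer (called inside generate_caption but not among the visible parameters) is assumed to resolve in the function's enclosing module.

# Hypothetical usage sketch -- not part of this commit.
import torch
from PIL import Image
from torchvision import transforms

# Illustrative import paths; the real module layout may differ.
from source.encoder import Encoder
from source.decoder import Decoder
from source.vocab import Vocab
from source.predict_sample import generate_caption

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the (3, 224, 224) tensor the docstring expects; generate_caption
# adds the batch dimension itself via image.unsqueeze(0).
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet stats,
                         std=[0.229, 0.224, 0.225]),  # an assumed choice
])
image = preprocess(Image.open("sample.jpg").convert("RGB")).to(device)

vocab = Vocab()  # must expose SOS, EOS, index2word, and index_to_word()
image_encoder = Encoder().to(device).eval()
image_decoder = Decoder().to(device).eval()

with torch.no_grad():  # pure inference, no gradients needed
    caption = generate_caption(image, image_encoder, image_decoder,
                               vocab, device)

print(" ".join(caption))  # caption is a list[str] of predicted words

With this fix, each iteration re-encodes the image, re-embeds the entire input_words prefix, and keeps only the decoder's output distribution at the last position: plain greedy decoding over a growing prefix, rather than feeding back a single previous word.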