nssharmaofficial commited on
Commit
0716538
1 Parent(s): f7300ff

Fix forward method

Browse files
Files changed (1) hide show
  1. source/predict_sample.py +1 -1
source/predict_sample.py CHANGED
@@ -61,7 +61,7 @@ def generate_caption(image: torch.Tensor,
61
  features = features.repeat(SEQ_LENGTH, 1, 1)
62
  # features : (SEQ_LENGTH, B=1, IMAGE_EMB_DIM)
63
 
64
- next_id_pred, (hidden, cell) = image_decoder.forward(lstm_input, features, hidden, cell)
65
  # next_id_pred : (SEQ_LENGTH, 1, VOCAB_SIZE)
66
 
67
  next_id_pred = next_id_pred[-1, 0, :]
 
61
  features = features.repeat(SEQ_LENGTH, 1, 1)
62
  # features : (SEQ_LENGTH, B=1, IMAGE_EMB_DIM)
63
 
64
+ next_id_pred, (hidden, cell) = image_decoder.forward(lstm_input, hidden, cell)
65
  # next_id_pred : (SEQ_LENGTH, 1, VOCAB_SIZE)
66
 
67
  next_id_pred = next_id_pred[-1, 0, :]