Porjaz commited on
Commit
66f4d36
1 Parent(s): d76033c

Update custom_interface.py

Browse files
Files changed (1) hide show
  1. custom_interface.py +1 -1
custom_interface.py CHANGED
@@ -133,7 +133,7 @@ class CustomEncoderWav2vec2Classifier(Pretrained):
133
  batch = waveform.unsqueeze(0)
134
  rel_length = torch.tensor([1.0])
135
  outputs = self.encode_batch(batch, rel_length)
136
- outputs = self.mods.output_mlp(outputs).squeeze(1)
137
  out_prob = self.hparams.softmax(outputs)
138
  score, index = torch.max(out_prob, dim=-1)
139
  text_lab = self.hparams.label_encoder.decode_torch(index)
 
133
  batch = waveform.unsqueeze(0)
134
  rel_length = torch.tensor([1.0])
135
  outputs = self.encode_batch(batch, rel_length)
136
+ outputs = self.mods.label_lin(outputs).squeeze(1)
137
  out_prob = self.hparams.softmax(outputs)
138
  score, index = torch.max(out_prob, dim=-1)
139
  text_lab = self.hparams.label_encoder.decode_torch(index)