Update custom_interface.py
Browse files- custom_interface.py +1 -1
custom_interface.py
CHANGED
@@ -133,7 +133,7 @@ class CustomEncoderWav2vec2Classifier(Pretrained):
|
|
133 |
batch = waveform.unsqueeze(0)
|
134 |
rel_length = torch.tensor([1.0])
|
135 |
outputs = self.encode_batch(batch, rel_length)
|
136 |
-
outputs = self.mods.
|
137 |
out_prob = self.hparams.softmax(outputs)
|
138 |
score, index = torch.max(out_prob, dim=-1)
|
139 |
text_lab = self.hparams.label_encoder.decode_torch(index)
|
|
|
133 |
batch = waveform.unsqueeze(0)
|
134 |
rel_length = torch.tensor([1.0])
|
135 |
outputs = self.encode_batch(batch, rel_length)
|
136 |
+
outputs = self.mods.label_lin(outputs).squeeze(1)
|
137 |
out_prob = self.hparams.softmax(outputs)
|
138 |
score, index = torch.max(out_prob, dim=-1)
|
139 |
text_lab = self.hparams.label_encoder.decode_torch(index)
|