Update custom_interface.py
Browse files- custom_interface.py +12 -1
custom_interface.py
CHANGED
@@ -104,7 +104,7 @@ class CustomEncoderWav2vec2Classifier(Pretrained):
|
|
104 |
(label encoder should be provided).
|
105 |
"""
|
106 |
outputs = self.encode_batch(wavs, wav_lens)
|
107 |
-
outputs = self.mods.
|
108 |
out_prob = self.hparams.softmax(outputs)
|
109 |
score, index = torch.max(out_prob, dim=-1)
|
110 |
text_lab = self.hparams.label_encoder.decode_torch(index)
|
@@ -137,6 +137,17 @@ class CustomEncoderWav2vec2Classifier(Pretrained):
|
|
137 |
out_prob = self.hparams.softmax(outputs)
|
138 |
score, index = torch.max(out_prob, dim=-1)
|
139 |
text_lab = self.hparams.label_encoder.decode_torch(index)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
return out_prob, score, index, text_lab
|
141 |
|
142 |
def forward(self, wavs, wav_lens=None, normalize=False):
|
|
|
104 |
(label encoder should be provided).
|
105 |
"""
|
106 |
outputs = self.encode_batch(wavs, wav_lens)
|
107 |
+
outputs = self.mods.label_lin(outputs)
|
108 |
out_prob = self.hparams.softmax(outputs)
|
109 |
score, index = torch.max(out_prob, dim=-1)
|
110 |
text_lab = self.hparams.label_encoder.decode_torch(index)
|
|
|
137 |
out_prob = self.hparams.softmax(outputs)
|
138 |
score, index = torch.max(out_prob, dim=-1)
|
139 |
text_lab = self.hparams.label_encoder.decode_torch(index)
|
140 |
+
if text_lab[0] == "1":
|
141 |
+
text_lab = "neutral"
|
142 |
+
elif text_lab[0] == "2":
|
143 |
+
text_lab = "sadness"
|
144 |
+
elif text_lab[0] == "3":
|
145 |
+
text_lab = "joy"
|
146 |
+
elif text_lab[0] == "4":
|
147 |
+
text_lab = "anger"
|
148 |
+
elif text_lab[0] == "5":
|
149 |
+
text_lab = "affection"
|
150 |
+
|
151 |
return out_prob, score, index, text_lab
|
152 |
|
153 |
def forward(self, wavs, wav_lens=None, normalize=False):
|