Porjaz commited on
Commit
b588470
1 Parent(s): 92d4c0e

Update custom_interface.py

Browse files
Files changed (1) hide show
  1. custom_interface.py +12 -1
custom_interface.py CHANGED
@@ -104,7 +104,7 @@ class CustomEncoderWav2vec2Classifier(Pretrained):
104
  (label encoder should be provided).
105
  """
106
  outputs = self.encode_batch(wavs, wav_lens)
107
- outputs = self.mods.output_mlp(outputs)
108
  out_prob = self.hparams.softmax(outputs)
109
  score, index = torch.max(out_prob, dim=-1)
110
  text_lab = self.hparams.label_encoder.decode_torch(index)
@@ -137,6 +137,17 @@ class CustomEncoderWav2vec2Classifier(Pretrained):
137
  out_prob = self.hparams.softmax(outputs)
138
  score, index = torch.max(out_prob, dim=-1)
139
  text_lab = self.hparams.label_encoder.decode_torch(index)
 
 
 
 
 
 
 
 
 
 
 
140
  return out_prob, score, index, text_lab
141
 
142
  def forward(self, wavs, wav_lens=None, normalize=False):
 
104
  (label encoder should be provided).
105
  """
106
  outputs = self.encode_batch(wavs, wav_lens)
107
+ outputs = self.mods.label_lin(outputs)
108
  out_prob = self.hparams.softmax(outputs)
109
  score, index = torch.max(out_prob, dim=-1)
110
  text_lab = self.hparams.label_encoder.decode_torch(index)
 
137
  out_prob = self.hparams.softmax(outputs)
138
  score, index = torch.max(out_prob, dim=-1)
139
  text_lab = self.hparams.label_encoder.decode_torch(index)
140
+ if text_lab[0] == "1":
141
+ text_lab = "neutral"
142
+ elif text_lab[0] == "2":
143
+ text_lab = "sadness"
144
+ elif text_lab[0] == "3":
145
+ text_lab = "joy"
146
+ elif text_lab[0] == "4":
147
+ text_lab = "anger"
148
+ elif text_lab[0] == "5":
149
+ text_lab = "affection"
150
+
151
  return out_prob, score, index, text_lab
152
 
153
  def forward(self, wavs, wav_lens=None, normalize=False):