lucio commited on
Commit
ac40f21
1 Parent(s): ccf8b98

fix spanish asr

Browse files
Files changed (1) hide show
  1. app.py +3 -7
app.py CHANGED
@@ -21,9 +21,7 @@ lang_classifier = EncoderClassifier.from_hparams(
21
  )
22
 
23
  def load_hf_model(model_path="facebook/wav2vec2-large-robust-ft-swbd-300h"):
24
- processor = Wav2Vec2Processor.from_pretrained(model_path)
25
- model = AutoModelForCTC.from_pretrained(model_path)
26
- return processor, model
27
 
28
  # download STT model
29
  model_info = {
@@ -52,10 +50,8 @@ def client(audio_data: np.array, sample_rate: int, default_lang: str):
52
 
53
  if text_lab == 'Spanish':
54
  text_lab = 'español'
55
- processor, model = STT_MODELS['español']
56
- inputs = processor(waveform)
57
- logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
58
- result = processor.decode(torch.argmax(logits, dim=-1).cpu().tolist())
59
 
60
  else:
61
  text_lab = default_lang
 
21
  )
22
 
23
  def load_hf_model(model_path="facebook/wav2vec2-large-robust-ft-swbd-300h"):
24
+ return pipeline("automatic-speech-recognition", model=model_path)
 
 
25
 
26
  # download STT model
27
  model_info = {
 
50
 
51
  if text_lab == 'Spanish':
52
  text_lab = 'español'
53
+ asr_pipeline = STT_MODELS['español']
54
+ result = asr_pipeline(waveform, chunk_length_s=5, stride_length_s=1)['text']
 
 
55
 
56
  else:
57
  text_lab = default_lang