File size: 1,457 Bytes
aba08c5
 
 
 
676bbaa
 
aba08c5
 
 
676bbaa
aba08c5
 
 
 
676bbaa
aba08c5
 
 
676bbaa
aba08c5
 
 
 
 
 
 
676bbaa
 
 
 
d8b535d
 
 
676bbaa
aba08c5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import soundfile as sf
import torch
import gradio as gr


# Single checkpoint identifier shared by the processor and the model.
MODEL_ID = "facebook/wav2vec2-large-robust-ft-libri-960h"

# The processor turns raw audio into model inputs and decodes predicted ids
# back to text; the CTC model produces the logits in between.
processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)

# define function to read in sound file
def map_to_array(file):
    """Read an audio file and return its samples as a mono array.

    Args:
        file: Path (or file object) accepted by ``soundfile.read``.

    Returns:
        The decoded samples. Multi-channel recordings are averaged down to a
        single channel so the result is always one-dimensional; mono files
        are returned exactly as read.
    """
    # NOTE(review): the file's sample rate is discarded and no resampling is
    # done; the checkpoint presumably expects 16 kHz input -- confirm.
    speech, _ = sf.read(file)
    if speech.ndim > 1:
        # Multi-channel files come back as (frames, channels); a 2-D array
        # passed downstream would be mistaken for a batch of utterances.
        speech = speech.mean(axis=1)
    return speech

# transcribe an uploaded audio file
def inference(audio):
    """Transcribe an uploaded audio file with the wav2vec 2.0 CTC model.

    Args:
        audio: File object handed over by the Gradio audio input; only its
            ``name`` attribute (the path on disk) is read.

    Returns:
        The decoded transcription of the first (and only) batch item.
    """
    waveform = map_to_array(audio.name)
    input_values = processor(
        waveform, return_tensors="pt", padding="longest"
    ).input_values  # Batch size 1

    # Inference only: disable autograd so no graph is recorded and kept
    # alive for every request.
    with torch.no_grad():
        logits = model(input_values).logits

    # Pick the highest-scoring token id at each position and decode to text.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)
    return transcription[0]

# Widgets: an uploaded audio file goes in, the transcription text comes out.
inputs = gr.inputs.Audio(type="file", label="Input Audio")
outputs = gr.outputs.Textbox(label="Output Text")

# Text shown around the demo.
title = "Robust wav2vec 2.0"
description = (
    "Gradio demo for Robust wav2vec 2.0. "
    "To use it, simply upload your audio, or click one of the examples to load them. "
    "Read more at the links below."
)
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2104.01027'>"
    "Robust wav2vec 2.0: Analyzing Domain Shift in Self-Supervised Pre-Training</a>"
    " | "
    "<a href='https://github.com/pytorch/fairseq'>Github Repo</a>"
    "</p>"
)

# Assemble the interface around `inference` and start serving it.
gr.Interface(
    fn=inference,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    article=article,
).launch()