# Hugging Face Space status at time of capture: Runtime error
import gradio as gr
import numpy as np
import soundfile as sf
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
# Load the preprocessing/decoding processor and the CTC model from one Hub
# checkpoint, so the processor's tokenizer matches the model's output vocabulary.
MODEL_ID = "facebook/wav2vec2-large-robust-ft-libri-960h"
processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
# define function to read in sound file | |
def map_to_array(file): | |
speech, _ = sf.read(file) | |
return speech | |
# tokenize | |
def inference(audio): | |
input_values = processor(map_to_array(audio.name), return_tensors="pt", padding="longest").input_values # Batch size 1 | |
# retrieve logits | |
logits = model(input_values).logits | |
# take argmax and decode | |
predicted_ids = torch.argmax(logits, dim=-1) | |
transcription = processor.batch_decode(predicted_ids) | |
return transcription[0] | |
inputs = gr.inputs.Audio(label="Input Audio", type="file") | |
outputs = gr.outputs.Textbox(label="Output Text") | |
title = "Robust wav2vec 2.0" | |
description = "Gradio demo for Robust wav2vec 2.0. To use it, simply upload your audio, or click one of the examples to load them. Read more at the links below." | |
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2104.01027'>Robust wav2vec 2.0: Analyzing Domain Shift in Self-Supervised Pre-Training</a> | <a href='https://github.com/pytorch/fairseq'>Github Repo</a></p>" | |
gr.Interface(inference, inputs, outputs, title=title, description=description, article=article).launch() |