# Hugging Face Space app by patrickvonplaten
# Last commit: "Update app.py" (96bd204)
import os
os.system("pip install gradio==2.8.0b2")
import gradio as gr
import librosa
from transformers import AutoFeatureExtractor, AutoTokenizer, SpeechEncoderDecoderModel
# All three components (feature extractor, slow tokenizer, encoder-decoder
# model) come from the same Hub checkpoint; name it once.
_CHECKPOINT = "facebook/wav2vec2-xls-r-300m-en-to-15"
feature_extractor = AutoFeatureExtractor.from_pretrained(_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT, use_fast=False)
model = SpeechEncoderDecoderModel.from_pretrained(_CHECKPOINT)
def process_audio_file(file):
    """Load an audio file and turn it into model input values.

    Args:
        file: Path of the audio file to load (the Gradio inputs are configured
            with ``type='filepath'``).

    Returns:
        The ``input_values`` produced by ``feature_extractor`` as PyTorch
        tensors (``return_tensors="pt"``).
    """
    target_sr = 16000
    # Ask librosa for 16 kHz up front instead of loading at its default rate
    # and resampling a second time. This also avoids positional sample-rate
    # arguments to librosa.resample, which newer librosa releases reject.
    data, _ = librosa.load(file, sr=target_sr)
    # Passing the rate lets the feature extractor check it matches its config.
    features = feature_extractor(data, sampling_rate=target_sr, return_tensors="pt")
    return features.input_values
def transcribe(file_mic, file_upload, target_language):
    """Translate recorded or uploaded English speech into the chosen language.

    Args:
        file_mic: Filepath of the microphone recording, or None.
        file_upload: Filepath of the uploaded audio file, or None.
        target_language: Dropdown label such as "German (de)"; the text inside
            the last pair of parentheses is the language code looked up in
            ``MAPPING``.

    Returns:
        The decoded translation, prefixed with a warning line when both audio
        sources were supplied, or an "ERROR: ..." message when the inputs
        cannot be used.
    """
    # An unselected dropdown yields no value; report it instead of crashing
    # on .split() below.
    if not target_language:
        return "ERROR: You have to select a target language"

    if file_mic is None and file_upload is None:
        return "ERROR: You have to either use the microphone or upload an audio file"

    warn_output = ""
    if file_mic is not None:
        # The microphone recording wins when both sources are present.
        if file_upload is not None:
            warn_output = (
                "WARNING: You've uploaded an audio file and used the microphone. "
                "The recorded file from the microphone will be used and the "
                "uploaded audio will be discarded.\n"
            )
        file = file_mic
    else:
        file = file_upload

    # "German (de)" -> "de"
    target_code = target_language.split("(")[-1].split(")")[0]
    forced_bos_token_id = MAPPING.get(target_code)
    if forced_bos_token_id is None:
        return f"ERROR: Unsupported target language: {target_language}"

    input_values = process_audio_file(file)
    # forced_bos_token_id makes generation start with the target-language token.
    sequences = model.generate(input_values, forced_bos_token_id=forced_bos_token_id)
    transcription = tokenizer.batch_decode(sequences, skip_special_tokens=True)
    return warn_output + transcription[0]
# Supported target languages as (display name, code, forced BOS token id).
# The dropdown labels and the code -> token id map below are both built from
# this single table so they always stay in sync.
# NOTE(review): the token ids must match the language tokens the checkpoint
# was trained with; verify against the model card before changing any of them.
_LANGUAGES = (
    ("German", "de", 250003),
    ("Turkish", "tr", 250023),
    ("Persian", "fa", 250029),
    ("Swedish", "sv", 250042),
    ("Mongolian", "mn", 250037),
    ("Chinese", "zh", 250025),
    ("Welsh", "cy", 250007),
    ("Catalan", "ca", 250005),
    ("Slovenian", "sl", 250052),
    ("Estonian", "et", 250006),
    ("Indonesian", "id", 250032),
    ("Arabic", "ar", 250001),
    ("Tamil", "ta", 250044),
    ("Latvian", "lv", 250017),
    ("Japanese", "ja", 250012),
)

# Dropdown labels of the form "<Name> (<code>)".
target_language = [f"{name} ({code})" for name, code, _ in _LANGUAGES]

# Language code -> token id passed to generate() as forced_bos_token_id.
MAPPING = {code: token_id for _, code, token_id in _LANGUAGES}
# Gradio 2.x interface: two optional audio inputs (microphone recording and
# file upload) plus the target-language dropdown, all routed into transcribe().
iface = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.inputs.Audio(source="microphone", type='filepath', optional=True),
        gr.inputs.Audio(source="upload", type='filepath', optional=True),
        gr.inputs.Dropdown(target_language),
    ],
    outputs="text",
    layout="horizontal",
    theme="huggingface",
    # Footer links. The emoji is written as an escape (studio microphone +
    # variation selector) so the source file's encoding cannot garble it.
    article=(
        "<p style='text-align: center'>"
        "<a href='https://huggingface.co/facebook/wav2vec2-xls-r-300m-en-to-15' target='_blank'>"
        "Click to learn more about XLS-R-300M-EN-15 </a> | "
        "<a href='https://arxiv.org/abs/2111.09296' target='_blank'>"
        " With \U0001F399\uFE0F from Facebook XLS-R </a></p>"
    ),
    title="XLS-R 300M EN-to-15 Speech Translation",
    description="A simple interface to translate English Speech to 15 possible languages.",
    enable_queue=True,
    allow_flagging=False,
)
iface.launch()