# Author: Amir Zait — commit d8ec8f4 ("fixed bugs")
# Source view: raw | history | blame — 2.91 kB
from transformers import AutoProcessor, AutoModelForCTC
from transformers import pipeline
import soundfile as sf
import gradio as gr
import librosa
import torch
import sox
import os
# Prefer the GPU when one is visible.
# NOTE(review): neither the model nor its inputs are moved to `device` here,
# so inference presumably stays on the CPU — confirm whether that is intended.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Token read from the environment (None when API_TOKEN is unset).
# NOTE(review): not passed to any of the from_pretrained/pipeline calls below.
api_token = os.getenv("API_TOKEN")

# One Hebrew speech-recognition checkpoint feeds both the processor and the CTC model.
_HE_ASR_CHECKPOINT = "imvladikon/wav2vec2-xls-r-300m-hebrew"
asr_processor = AutoProcessor.from_pretrained(_HE_ASR_CHECKPOINT)
asr_model = AutoModelForCTC.from_pretrained(_HE_ASR_CHECKPOINT)

# Hebrew -> English machine-translation pipeline.
he_en_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-tc-big-he-en")
def process_audio_file(file, sampling_rate=16_000):
    """Load an audio file and convert it into input values for the ASR model.

    Args:
        file: Anything ``librosa.load`` accepts (a path or a file-like object).
        sampling_rate: Rate in Hz to resample the audio to before feature
            extraction. Defaults to 16 kHz, the rate the rest of this file
            feeds to ``asr_processor``.

    Returns:
        The ``input_values`` PyTorch tensor produced by ``asr_processor``.
    """
    # Have librosa resample while decoding. Without ``sr`` librosa.load()
    # resamples to 22050 Hz, so the old code resampled twice. The positional
    # librosa.resample(data, sr, 16000) form is also rejected by
    # librosa >= 0.10, where orig_sr/target_sr are keyword-only.
    data, _ = librosa.load(file, sr=sampling_rate)
    # The module defines ``asr_processor``; the bare name ``processor`` raised
    # NameError. Inputs stay on the CPU, matching where ``asr_model`` lives.
    return asr_processor(
        data, sampling_rate=sampling_rate, return_tensors="pt"
    ).input_values
def transcribe(file_mic, file_upload):
    """Transcribe Hebrew speech from a microphone recording or an uploaded file.

    If both inputs are given, the microphone recording is used and a warning
    line is prepended to the transcription.

    Args:
        file_mic: Microphone recording, or None.
        file_upload: Uploaded audio file, or None.

    Returns:
        The decoded transcription (possibly prefixed with a warning), or an
        error message when neither input was provided.
    """
    if file_mic is None and file_upload is None:
        return "ERROR: You have to either use the microphone or upload an audio file"

    warn_output = ""
    if file_mic is not None and file_upload is not None:
        warn_output = "WARNING: You've uploaded an audio file and used the microphone. The recorded file from the microphone will be used and the uploaded audio will be discarded.\n"
    file = file_mic if file_mic is not None else file_upload

    input_values = process_audio_file(file)
    # Inference only: don't build an autograd graph. The old code also
    # referenced the undefined names ``model``/``processor``.
    with torch.no_grad():
        logits = asr_model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = asr_processor.decode(predicted_ids[0], skip_special_tokens=True)
    return warn_output + transcription
def convert(inputfile, outfile):
    """Re-encode ``inputfile`` with SoX into ``outfile``.

    The output is a mono WAV sampled at 16 kHz, stored as 16-bit
    signed-integer PCM.
    """
    wav_16k_mono = dict(
        file_type="wav",
        rate=16000,
        bits=16,
        channels=1,
        encoding="signed-integer",
    )
    transformer = sox.Transformer()
    transformer.set_output_format(**wav_16k_mono)
    transformer.build(inputfile, outfile)
def parse_transcription(wav_file):
    """Transcribe a Hebrew recording and translate the text into English.

    Steps:
        1. ``convert`` re-encodes the audio into a 16 kHz mono WAV written
           next to the input file.
        2. The wav2vec2 CTC model produces a greedy (argmax) transcription.
        3. The Hebrew text is passed through ``he_en_translator``.

    Args:
        wav_file: The recorded audio. Either a path (``str``/PathLike) or an
            object exposing its path as ``.name``, such as the temp-file
            wrapper Gradio passes for ``type="file"`` inputs. May be None
            when the optional input was left empty.

    Returns:
        The English translation as a string. Returns an empty string when
        nothing was recognised, and an error message when no audio was
        provided.
    """
    if wav_file is None:
        return "ERROR: You have to either use the microphone or upload an audio file"

    source_path = wav_file if isinstance(wav_file, (str, os.PathLike)) else wav_file.name
    # splitext strips only the final extension; split('.')[0] truncated any
    # path whose directory names contain a dot (e.g. /tmp/gradio.x/rec.wav).
    resampled_path = os.path.splitext(os.fspath(source_path))[0] + "16k.wav"
    convert(source_path, resampled_path)
    try:
        speech, _ = sf.read(resampled_path)
    finally:
        # The 16 kHz copy is only needed for this read; don't leave one
        # behind on disk per request.
        os.remove(resampled_path)

    input_values = asr_processor(speech, sampling_rate=16_000, return_tensors="pt").input_values
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = asr_model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = asr_processor.decode(predicted_ids[0], skip_special_tokens=True)
    if not transcription.strip():
        return ""

    # The translation pipeline returns one {"translation_text": ...} dict per
    # input. Return the text itself so the Textbox shows plain English,
    # not a list repr. (The old code also misspelled ``transcription``.)
    translated = he_en_translator(transcription)
    return translated[0]["translation_text"]
# Gradio UI: record Hebrew speech and display the English text.
output = gr.outputs.Textbox(label="TEXT")
# Two optional audio widgets (microphone, upload), both configured with type="file".
input_mic, input_upload = (
    gr.inputs.Audio(source=audio_source, type="file", optional=True)
    for audio_source in ("microphone", "upload")
)
# NOTE(review): input_upload is built but never handed to the Interface.
# parse_transcription takes a single audio argument, so only the
# microphone input is wired up.
interface_settings = {
    "analytics_enabled": False,
    "show_tips": False,
    "theme": 'huggingface',
    "layout": 'horizontal',
    "title": "Draw Me A Sheep in Hebrew",
    "enable_queue": True,
}
demo = gr.Interface(
    parse_transcription,
    inputs=[input_mic],
    outputs=output,
    **interface_settings,
)
demo.launch(inline=False)