"""Gradio demo: Hebrew speech -> Hebrew transcription -> English translation.

Pipeline: audio file -> sox conversion to 16 kHz mono WAV -> wav2vec2 CTC
transcription (Hebrew) -> Helsinki-NLP machine translation (Hebrew -> English).
"""

from transformers import AutoProcessor, AutoModelForCTC
from transformers import pipeline
import soundfile as sf
import gradio as gr
import librosa
import torch
import sox
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
api_token = os.getenv("API_TOKEN")

# Hebrew wav2vec2-XLS-R CTC model for speech recognition.
asr_processor = AutoProcessor.from_pretrained("imvladikon/wav2vec2-xls-r-300m-hebrew")
asr_model = AutoModelForCTC.from_pretrained("imvladikon/wav2vec2-xls-r-300m-hebrew")

# Hebrew -> English translation pipeline.
he_en_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-tc-big-he-en")


def process_audio_file(file):
    """Load *file*, resample to 16 kHz if needed, and return the model-ready
    ``input_values`` tensor produced by the ASR processor."""
    data, sr = librosa.load(file)
    if sr != 16000:
        # Keyword arguments: positional orig_sr/target_sr were removed in
        # librosa 0.10, so the old 3-positional call raises TypeError there.
        data = librosa.resample(data, orig_sr=sr, target_sr=16000)
    # BUG FIX: original referenced an undefined global `processor`
    # (NameError); the loaded processor is named `asr_processor`.
    input_values = asr_processor(
        data, sampling_rate=16_000, return_tensors="pt"
    ).input_values  # .to(device)
    return input_values


def transcribe(file_mic, file_upload):
    """Transcribe an audio file from either the microphone or an upload.

    If both are provided the microphone recording wins and a warning is
    prepended to the returned transcription; if neither is provided an
    error string is returned instead.
    """
    warn_output = ""
    if (file_mic is not None) and (file_upload is not None):
        warn_output = (
            "WARNING: You've uploaded an audio file and used the microphone. "
            "The recorded file from the microphone will be used and the "
            "uploaded audio will be discarded.\n"
        )
        file = file_mic
    elif (file_mic is None) and (file_upload is None):
        return "ERROR: You have to either use the microphone or upload an audio file"
    elif file_mic is not None:
        file = file_mic
    else:
        file = file_upload

    input_values = process_audio_file(file)
    # BUG FIX: original referenced undefined globals `model` / `processor`;
    # the loaded objects are `asr_model` / `asr_processor`.
    logits = asr_model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = asr_processor.decode(predicted_ids[0], skip_special_tokens=True)
    return warn_output + transcription


def convert(inputfile, outfile):
    """Convert *inputfile* to a 16 kHz, mono, signed 16-bit WAV at *outfile*
    using sox."""
    sox_tfm = sox.Transformer()
    sox_tfm.set_output_format(
        file_type="wav", channels=1, encoding="signed-integer", rate=16000, bits=16
    )
    sox_tfm.build(inputfile, outfile)


def parse_transcription(wav_file):
    """Gradio handler: transcribe *wav_file* (Hebrew) and return the English
    translation produced by the translation pipeline."""
    filename = wav_file.name.split('.')[0]
    converted = filename + "16k.wav"
    convert(wav_file.name, converted)
    speech, _ = sf.read(converted)
    print(speech.shape)
    input_values = asr_processor(
        speech, sampling_rate=16_000, return_tensors="pt"
    ).input_values
    logits = asr_model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = asr_processor.decode(predicted_ids[0], skip_special_tokens=True)
    # BUG FIX: original passed the misspelled name `trasncription`
    # (NameError) to the translator.
    translated = he_en_translator(transcription)
    return translated


output = gr.outputs.Textbox(label="TEXT")
input_mic = gr.inputs.Audio(source="microphone", type="file", optional=True)
# NOTE(review): input_upload is built but never passed to the interface
# below; parse_transcription only accepts a single input, so it is left
# unwired to preserve the original behavior.
input_upload = gr.inputs.Audio(source="upload", type="file", optional=True)

gr.Interface(
    parse_transcription,
    inputs=[input_mic],
    outputs=output,
    analytics_enabled=False,
    show_tips=False,
    theme='huggingface',
    layout='horizontal',
    title="Draw Me A Sheep in Hebrew",
    enable_queue=True,
).launch(inline=False)