LPhilp1943's picture
Update app.py
e0a55da verified
raw
history blame
No virus
2.73 kB
import gradio as gr
import os
import torch
import soundfile as sf
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, VitsModel, AutoTokenizer
import librosa
import string
# Folder that text_to_speech saves its synthesized WAV files into.
os.makedirs("output_audio", exist_ok=True)

# Speech recognition: English wav2vec2 CTC checkpoint (processor + model).
ASR_CHECKPOINT = "facebook/wav2vec2-large-960h"
asr_processor = Wav2Vec2Processor.from_pretrained(ASR_CHECKPOINT)
asr_model = Wav2Vec2ForCTC.from_pretrained(ASR_CHECKPOINT)

# Speech synthesis: English MMS VITS checkpoint (model + tokenizer).
TTS_CHECKPOINT = "facebook/mms-tts-eng"
tts_model = VitsModel.from_pretrained(TTS_CHECKPOINT)
tts_tokenizer = AutoTokenizer.from_pretrained(TTS_CHECKPOINT)
def resample_audio(input_audio_path, target_sr, mono=True):
    """Load an audio file and resample it to ``target_sr`` Hz.

    Args:
        input_audio_path: Path to an audio file readable by ``soundfile``.
        target_sr: Desired sampling rate in Hz.
        mono: If True (the default), multi-channel audio is down-mixed to a
            single channel by averaging the channels. The wav2vec2 processor
            used by ``speech_to_text`` needs a 1-D signal. If False, the
            ``(frames, channels)`` layout returned by ``soundfile`` is kept.

    Returns:
        The waveform as a NumPy array sampled at ``target_sr``. It is 1-D for
        mono output; otherwise multi-channel input stays ``(frames, channels)``.
    """
    waveform, sr = sf.read(input_audio_path)
    # soundfile returns multi-channel audio as (frames, channels). Without
    # down-mixing, a 2-D array would be treated as a batch of tiny clips
    # downstream.
    if mono and waveform.ndim > 1:
        waveform = waveform.mean(axis=1)
    if sr != target_sr:
        # Resample along the time axis (axis 0 in soundfile's layout).
        # librosa defaults to the last axis, which for (frames, channels) data
        # would resample across channels instead of over time.
        waveform = librosa.resample(
            waveform, orig_sr=sr, target_sr=target_sr, axis=0
        )
    return waveform
def speech_to_text(input_audio_or_text):
    """Transcribe an audio file with wav2vec2, or pass non-path input through.

    Args:
        input_audio_or_text: One of three things:
            - A path (``str``) to an audio file, which is transcribed.
            - ``None``, meaning no audio was supplied; this yields an empty
              transcription.
            - Any other value, which is treated as an existing transcription
              and returned with surrounding whitespace stripped.

    Returns:
        The transcription with leading/trailing whitespace removed.
    """
    if input_audio_or_text is None:
        # Previously this fell through to ``None.strip()`` and raised
        # AttributeError. Missing audio simply has nothing to transcribe.
        return ""
    if isinstance(input_audio_or_text, str):
        # Audio is loaded and resampled to 16 kHz before reaching the
        # ASR processor, which is told the same rate.
        waveform = resample_audio(input_audio_or_text, 16000)
        input_values = asr_processor(
            waveform, sampling_rate=16000, return_tensors="pt"
        ).input_values
        with torch.no_grad():
            logits = asr_model(input_values).logits
        # Greedy decoding: take the highest-scoring token id at each frame,
        # then let the processor turn the id sequence into text.
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = asr_processor.batch_decode(predicted_ids)[0]
    else:
        transcription = input_audio_or_text
    return transcription.strip()
def text_to_speech(text):
    """Synthesize speech for ``text`` with the MMS VITS model and save it as WAV.

    Two files are written to ``output_audio/``:
    - one at the model's native sampling rate;
    - one resampled to 16 kHz, whose path is returned.

    Args:
        text: Text to synthesize. It is lower-cased and ASCII punctuation
            (``string.punctuation``) is removed before tokenization.

    Returns:
        Path to the 16 kHz WAV file.

    Raises:
        ValueError: If no text remains after punctuation is removed.
    """
    text = text.lower().translate(str.maketrans('', '', string.punctuation))
    if not text.strip():
        # An empty token sequence cannot be synthesized; fail with a clear
        # message instead of an opaque error inside the model.
        raise ValueError("No text to synthesize after removing punctuation.")
    inputs = tts_tokenizer(text, return_tensors="pt")
    # BatchEncoding is dict-like and ``**inputs`` reads its items, so
    # assigning the attribute ``inputs.input_ids`` would not reach the model.
    # Update the item itself.
    inputs["input_ids"] = inputs["input_ids"].long()
    with torch.no_grad():
        output = tts_model(**inputs).waveform
    waveform = output.cpu().numpy().squeeze()
    # Use the rate the model actually generates at instead of a hard-coded
    # 22050 Hz. The hard-coded rate mislabelled the native file and made the
    # "resampled" copy shorter, so it played back too fast.
    model_sr = tts_model.config.sampling_rate
    # Collapse any whitespace (spaces, tabs, newlines) into underscores so
    # the filename stem never contains raw whitespace.
    stem = "_".join(text[:10].split())
    output_path = os.path.join("output_audio", f"{stem}_to_speech.wav")
    sf.write(output_path, waveform, model_sr)
    resampled_waveform = librosa.resample(waveform, orig_sr=model_sr, target_sr=16000)
    resampled_output_path = os.path.join("output_audio", f"{stem}_to_speech_16khz.wav")
    sf.write(resampled_output_path, resampled_waveform, 16000)
    return resampled_output_path
def speech_to_speech(input_audio, text_input=None):
    """Synthesize speech from typed text, or from a transcription of the audio.

    Args:
        input_audio: Path to the recorded or uploaded audio, or ``None`` if
            none was provided.
        text_input: Optional text to speak instead of the audio.
            ``None``, empty, or whitespace-only text counts as "no text".
            Gradio's Textbox passes ``""`` rather than ``None`` when left
            blank, so an ``is None`` check alone never fell back to the audio.

    Returns:
        Path to the synthesized WAV file produced by ``text_to_speech``.

    Raises:
        gr.Error: If there is no usable text and the audio is missing or
            yields an empty transcription.
    """
    if text_input is not None and text_input.strip():
        transcription = text_input
    else:
        if input_audio is None:
            raise gr.Error("Please provide input audio or enter some text.")
        transcription = speech_to_text(input_audio)
        if not transcription:
            raise gr.Error("No speech could be recognized in the input audio.")
    return text_to_speech(transcription)
# Gradio UI: optional audio clip and optional text in, synthesized speech out.
ui_inputs = [
    gr.Audio(type="filepath", label="Input Audio"),
    gr.Textbox(
        label="Text Input",
        placeholder="Enter text to synthesize speech (optional)",
    ),
]
ui_output = gr.Audio(label="Synthesized Speech")

iface = gr.Interface(
    fn=speech_to_speech,
    inputs=ui_inputs,
    outputs=ui_output,
    title="Speech-to-Speech Application",
    description=(
        "This app converts speech to text and then back to speech, "
        "ensuring the output audio is resampled to 16kHz."
    ),
)
iface.launch()