LPhilp1943 committed on
Commit
7fa5660
1 Parent(s): 592ca27

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -31
app.py CHANGED
@@ -2,16 +2,29 @@ import gradio as gr
2
  import os
3
  import torch
4
  import soundfile as sf
5
- from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, VitsModel, AutoTokenizer
 
 
6
  import librosa
7
  import string
8
 
9
  os.makedirs("output_audio", exist_ok=True)
10
 
 
11
  asr_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h")
12
  asr_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h")
13
- tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
14
- tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
 
 
 
 
 
 
 
 
 
 
15
 
16
  def resample_audio(input_audio_path, target_sr):
17
  waveform, sr = sf.read(input_audio_path)
@@ -19,48 +32,42 @@ def resample_audio(input_audio_path, target_sr):
19
  waveform = librosa.resample(waveform, orig_sr=sr, target_sr=target_sr)
20
  return waveform
21
 
22
- def speech_to_text(input_audio_or_text):
23
- if isinstance(input_audio_or_text, str):
24
- waveform = resample_audio(input_audio_or_text, 16000)
25
- input_values = asr_processor(waveform, sampling_rate=16000, return_tensors="pt").input_values
26
- with torch.no_grad():
27
- logits = asr_model(input_values).logits
28
- predicted_ids = torch.argmax(logits, dim=-1)
29
- transcription = asr_processor.batch_decode(predicted_ids)[0]
30
- else:
31
- transcription = input_audio_or_text
32
  return transcription.strip()
33
 
34
- def text_to_speech(text):
35
- # Ensure the text input is not empty to avoid padding errors in the transformer model
36
  if not text.strip():
37
  return "The text input is empty, please provide a valid string."
38
 
39
- text = text.lower().translate(str.maketrans('', '', string.punctuation))
40
- inputs = tts_tokenizer(text, return_tensors="pt")
41
- inputs['input_ids'] = inputs['input_ids'].long() # Ensure input_ids are of type Long
42
- with torch.no_grad():
43
- output = tts_model(**inputs).waveform
44
- waveform = output.numpy().squeeze()
45
- output_path = os.path.join("output_audio", f"{text[:10].replace(' ', '_')}_to_speech.wav")
 
46
  sf.write(output_path, waveform, 22050)
47
- resampled_waveform = librosa.resample(waveform, orig_sr=22050, target_sr=16000)
48
- resampled_output_path = os.path.join("output_audio", f"{text[:10].replace(' ', '_')}_to_speech_16khz.wav")
49
- sf.write(resampled_output_path, resampled_waveform, 16000)
50
- return resampled_output_path
51
 
52
  def speech_to_speech(input_audio, text_input=None):
53
- transcription = speech_to_text(input_audio) if text_input is None else text_input
54
- synthesized_speech_path = text_to_speech(transcription)
 
55
  return synthesized_speech_path
56
 
57
  iface = gr.Interface(
58
  fn=speech_to_speech,
59
- inputs=[gr.Audio(type="filepath", label="Input Audio"),
60
- gr.Textbox(label="Text Input", placeholder="Enter text to synthesize speech (optional)")],
61
  outputs=gr.Audio(label="Synthesized Speech"),
62
  title="Speech-to-Speech Application",
63
- description="This app converts speech to text and then back to speech, ensuring the output audio is resampled to 16kHz."
64
  )
65
 
66
  iface.launch(share=True)
 
2
  import os
3
  import torch
4
  import soundfile as sf
5
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
6
+ from TTS.tts.configs.xtts_config import XttsConfig
7
+ from TTS.tts.models.xtts import Xtts
8
  import librosa
9
  import string
10
 
11
  os.makedirs("output_audio", exist_ok=True)
12
 
13
# --- Speech recognition (ASR) ------------------------------------------------
# Wav2Vec2 CTC checkpoint. The processor is used both to turn raw audio into
# model inputs and to decode predicted token ids back into text.
asr_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h")
asr_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h")
asr_model.eval()  # inference only

# --- Speech synthesis (XTTS) --------------------------------------------------
# NOTE(review): the three paths below are placeholders and must point at a real
# XTTS config, checkpoint directory and reference speaker clip before the app
# can start.
tts_config_path = "/path/to/xtts/config.json"
tts_checkpoint_dir = "/path/to/xtts/"
speaker_wav_path = "/path/to/target/speaker.wav"  # Update with actual speaker wav path for cloning voice

tts_config = XttsConfig()
tts_config.load_json(tts_config_path)
tts_model = Xtts.init_from_config(tts_config)
tts_model.load_checkpoint(tts_config, checkpoint_dir=tts_checkpoint_dir, eval=True)
# Move to the GPU only when one exists; an unconditional .cuda() raises on
# CPU-only hosts and prevents the app from starting at all.
if torch.cuda.is_available():
    tts_model.cuda()
28
 
29
  def resample_audio(input_audio_path, target_sr):
30
  waveform, sr = sf.read(input_audio_path)
 
32
  waveform = librosa.resample(waveform, orig_sr=sr, target_sr=target_sr)
33
  return waveform
34
 
35
def speech_to_text(input_audio_path):
    """Transcribe the audio file at *input_audio_path*.

    The file is loaded through ``resample_audio`` at 16 kHz, passed through
    the module-level Wav2Vec2 CTC model without gradient tracking, and the
    per-frame argmax token ids are decoded by the processor.

    Returns:
        The decoded text of the first (only) batch item, with surrounding
        whitespace stripped.
    """
    asr_sample_rate = 16000
    audio = resample_audio(input_audio_path, asr_sample_rate)
    features = asr_processor(
        audio, sampling_rate=asr_sample_rate, return_tensors="pt"
    )
    with torch.no_grad():
        token_ids = asr_model(features.input_values).logits.argmax(dim=-1)
    decoded = asr_processor.batch_decode(token_ids)
    return decoded[0].strip()
43
 
44
def text_to_speech(text, output_path="output_audio/output.wav"):
    """Synthesize *text* with the module-level XTTS model and save it as a WAV.

    Args:
        text: Text to speak. ``None``, empty or whitespace-only input is
            rejected without touching the model.
        output_path: Destination WAV file. Its parent directory is created
            when missing.

    Returns:
        ``output_path`` on success. For blank input, the explanatory message
        string is returned instead (kept for backward compatibility; callers
        must not treat that message as a file path).
    """
    if not text or not text.strip():
        return "The text input is empty, please provide a valid string."

    outputs = tts_model.synthesize(
        text,
        tts_config,
        speaker_wav=speaker_wav_path,
        gpt_cond_len=3,
        language="en",
    )

    # XTTS returns the audio under the "wav" key (there is no "waveform"
    # entry), normally already as a NumPy array. Tensors are still accepted
    # in case a different backend hands one back.
    waveform = outputs["wav"]
    if torch.is_tensor(waveform):
        waveform = waveform.cpu().numpy()
    waveform = waveform.squeeze()

    # Write the header with the rate the model actually generates at; XTTS
    # produces 24 kHz audio, so a hard-coded 22050 plays back slow and low.
    sample_rate = getattr(tts_config.audio, "output_sample_rate", 24000)

    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    sf.write(output_path, waveform, sample_rate)
    return output_path
 
 
 
58
 
59
  def speech_to_speech(input_audio, text_input=None):
60
+ if text_input is None:
61
+ text_input = speech_to_text(input_audio)
62
+ synthesized_speech_path = text_to_speech(text_input)
63
  return synthesized_speech_path
64
 
65
  iface = gr.Interface(
66
  fn=speech_to_speech,
67
+ inputs=[gr.Audio(type="filepath", label="Input Audio"), gr.Textbox(label="Text Input", optional=True)],
 
68
  outputs=gr.Audio(label="Synthesized Speech"),
69
  title="Speech-to-Speech Application",
70
+ description="Converts speech to text and then back to speech, ensuring the output audio is of high quality."
71
  )
72
 
73
  iface.launch(share=True)