LPhilp1943 committed on
Commit
5e6eee9
1 Parent(s): 8dc6b2e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -38
app.py CHANGED
@@ -4,66 +4,44 @@ import torch
4
  import soundfile as sf
5
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, VitsModel, AutoTokenizer
6
 
7
- # Ensure the output directory exists
8
  os.makedirs("output_audio", exist_ok=True)
9
 
10
- # Load the models and processors
11
  asr_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h")
12
  asr_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h")
13
  tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
14
  tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
15
 
16
  def speech_to_text(input_audio):
17
- # Load and preprocess the audio
18
  waveform, sr = sf.read(input_audio)
19
  input_values = asr_processor(waveform, sampling_rate=sr, return_tensors="pt").input_values
20
-
21
- # Perform speech recognition
22
  with torch.no_grad():
23
  logits = asr_model(input_values).logits
24
  predicted_ids = torch.argmax(logits, dim=-1)
25
-
26
- # Decode the predicted IDs to text
27
  transcription = asr_processor.batch_decode(predicted_ids)[0]
28
- return transcription
29
 
30
- def text_to_speech(text):
31
- # Tokenize text and generate waveform
32
  inputs = tts_tokenizer(text, return_tensors="pt")
33
  with torch.no_grad():
34
  output = tts_model(**inputs).waveform
35
- waveform = output.numpy()
36
-
37
- # Define output path and save waveform as audio file
38
- output_path = "output_audio/text_to_speech.wav"
39
- sf.write(output_path, waveform.squeeze(), 22050)
40
-
41
  return output_path
42
 
43
- def speech_to_speech(input_audio, target_text):
44
- # Synthesize speech directly from target text without transcribing the input audio
45
- return text_to_speech(target_text)
46
 
47
  iface = gr.Interface(
48
- fn={
49
- "Speech to Text": speech_to_text,
50
- "Text to Speech": text_to_speech,
51
- "Speech to Speech": speech_to_speech
52
- },
53
  inputs=[
54
- gr.Audio(label="Speech to Text"),
55
- gr.Textbox(label="Text to Speech"),
56
- [gr.Audio(label="Speech to Speech Input"), gr.Textbox(label="Target Text for Speech to Speech")] # Corrected: Use a list for multiple inputs
57
- ],
58
- outputs=[
59
- gr.Textbox(label="Transcription"),
60
- gr.Audio(label="Synthesized Speech"),
61
- gr.Audio(label="Speech to Speech Output")
62
  ],
 
63
  title="Speech Processing Application",
64
- description="This app uses Facebook's Wav2Vec 2.0 for speech-to-text and VITS for text-to-speech.",
65
- layout="vertical"
66
- )
67
-
68
- if __name__ == "__main__":
69
- iface.launch()
 
4
  import soundfile as sf
5
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, VitsModel, AutoTokenizer
6
 
 
# Synthesized clips are written into this folder, so create it up front.
os.makedirs("output_audio", exist_ok=True)

# Checkpoint identifiers, named once so each processor/model pair stays in sync.
_ASR_CHECKPOINT = "facebook/wav2vec2-large-960h"
_TTS_CHECKPOINT = "facebook/mms-tts-eng"

# Speech recognition: Wav2Vec 2.0 processor plus CTC model.
asr_processor = Wav2Vec2Processor.from_pretrained(_ASR_CHECKPOINT)
asr_model = Wav2Vec2ForCTC.from_pretrained(_ASR_CHECKPOINT)

# Speech synthesis: VITS model and its tokenizer.
tts_model = VitsModel.from_pretrained(_TTS_CHECKPOINT)
tts_tokenizer = AutoTokenizer.from_pretrained(_TTS_CHECKPOINT)
13
 
def speech_to_text(input_audio):
    """Transcribe an audio file with the Wav2Vec 2.0 CTC model.

    Args:
        input_audio: Path or file-like object that ``soundfile`` can read.

    Returns:
        The greedy-decoded transcription with surrounding whitespace
        removed, or an empty string when the file holds no samples.
    """
    # Read as float32 so the samples can be handed to torch without a cast.
    waveform, sr = sf.read(input_audio, dtype="float32")

    # soundfile yields (frames, channels) for multi-channel files, while the
    # processor expects one 1-D channel: average the channels down to mono.
    if waveform.ndim > 1:
        waveform = waveform.mean(axis=1)
    if waveform.size == 0:
        return ""

    # The feature extractor only accepts audio at its own sampling rate, so
    # bring recordings made at any other rate to it (linear interpolation).
    target_sr = asr_processor.feature_extractor.sampling_rate
    if sr != target_sr:
        n_out = max(1, round(waveform.shape[0] * target_sr / sr))
        resampled = torch.nn.functional.interpolate(
            torch.from_numpy(waveform).reshape(1, 1, -1),
            size=n_out,
            mode="linear",
            align_corners=False,
        )
        waveform = resampled.reshape(-1).numpy()

    input_values = asr_processor(
        waveform, sampling_rate=target_sr, return_tensors="pt"
    ).input_values

    # Inference only: no gradients needed.
    with torch.no_grad():
        logits = asr_model(input_values).logits

    # Greedy CTC decoding: most likely token per frame, then collapse.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = asr_processor.batch_decode(predicted_ids)[0]
    return transcription.strip()
22
 
def text_to_speech(text, sample_rate=22050):
    """Synthesize ``text`` with the VITS model and save it as a WAV file.

    Args:
        text: Text to speak. It is lowercased and stripped of ASCII
            punctuation before tokenization.
        sample_rate: Sample rate, in Hz, of the written file. The model
            output is resampled to this rate when it differs from the
            model's native rate, so playback speed and pitch stay correct.

    Returns:
        Path of the WAV file written under ``output_audio/``.

    Raises:
        ValueError: If no text is left to synthesize after normalization.
    """
    # Local import: ``string`` is not among the imports visible at the top
    # of the file, and the punctuation table below needs it.
    import string

    text = text.lower().translate(str.maketrans('', '', string.punctuation))
    if not text.strip():
        raise ValueError("text_to_speech needs non-empty text to synthesize")

    # UI sliders may deliver a float; soundfile wants an integer rate.
    sample_rate = int(sample_rate)

    inputs = tts_tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        output = tts_model(**inputs).waveform

    # The model emits audio at its own native rate. Writing those samples
    # under a different header rate would change speed and pitch, so
    # resample to the requested rate first (linear interpolation).
    native_rate = tts_model.config.sampling_rate
    if sample_rate != native_rate:
        n_out = max(1, round(output.shape[-1] * sample_rate / native_rate))
        output = torch.nn.functional.interpolate(
            output.reshape(1, 1, -1),
            size=n_out,
            mode="linear",
            align_corners=False,
        )

    waveform = output.reshape(-1).cpu().numpy()
    output_path = f"output_audio/{text[:10].replace(' ', '_')}_to_speech.wav"
    sf.write(output_path, waveform, sample_rate)
    return output_path
32
 
def speech_to_speech(input_audio, target_text, sample_rate=22050):
    """Produce synthesized speech for an uploaded clip.

    When ``target_text`` is given it is synthesized directly and the input
    audio is not transcribed, which avoids a speech-recognition pass whose
    result would be thrown away. When ``target_text`` is empty or blank,
    the input audio is transcribed and the transcription is spoken back.

    Args:
        input_audio: Audio accepted by ``speech_to_text``.
        target_text: Text to synthesize; may be empty to use the
            transcription of ``input_audio`` instead.
        sample_rate: Sample rate forwarded to ``text_to_speech``.

    Returns:
        Path of the synthesized WAV file, as returned by ``text_to_speech``.
    """
    if target_text and target_text.strip():
        spoken_text = target_text
    else:
        spoken_text = speech_to_text(input_audio)
    return text_to_speech(spoken_text, sample_rate)
36
 
# Web UI: an audio upload, the text to synthesize and the output sample rate
# go in; the synthesized clip comes out.
iface = gr.Interface(
    fn=speech_to_speech,
    inputs=[
        # "filepath" hands the handler a path that soundfile can open; the
        # older "file" type is deprecated. The upload-source argument is
        # left at the library default because its name differs between
        # Gradio major versions.
        gr.Audio(type="filepath", label="Input Audio"),
        gr.Textbox(label="Target Text"),
        # The initial slider position is set with ``value``, not ``default``.
        gr.Slider(minimum=16000, maximum=48000, step=1000, value=22050, label="Sample Rate"),
    ],
    outputs=gr.Audio(label="Synthesized Speech"),
    title="Speech Processing Application",
    description="This app uses Facebook's Wav2Vec 2.0 for speech-to-text and VITS for text-to-speech.",
)

# Keep ``iface`` bound to the Interface itself and start the server only when
# the file is run as a script, so importing the module has no launch side effect.
if __name__ == "__main__":
    iface.launch()