LPhilp1943 committed on
Commit
e0a55da
1 Parent(s): 6275fb1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -28
app.py CHANGED
@@ -21,7 +21,6 @@ def resample_audio(input_audio_path, target_sr):
21
 
22
  def speech_to_text(input_audio_or_text):
23
  if isinstance(input_audio_or_text, str):
24
- # If input is audio file path, convert speech to text
25
  waveform = resample_audio(input_audio_or_text, 16000)
26
  input_values = asr_processor(waveform, sampling_rate=16000, return_tensors="pt").input_values
27
  with torch.no_grad():
@@ -29,41 +28,28 @@ def speech_to_text(input_audio_or_text):
29
  predicted_ids = torch.argmax(logits, dim=-1)
30
  transcription = asr_processor.batch_decode(predicted_ids)[0]
31
  else:
32
- # If input is text, directly return it
33
  transcription = input_audio_or_text
34
  return transcription.strip()
35
 
36
  def text_to_speech(text):
37
- if isinstance(text, str):
38
- # If input is text, synthesize speech
39
- text = text.lower().translate(str.maketrans('', '', string.punctuation))
40
- inputs = tts_tokenizer(text, return_tensors="pt")
41
- with torch.no_grad():
42
- output = tts_model(**inputs).waveform
43
- waveform = output.numpy().squeeze()
44
- output_path = os.path.join("output_audio", f"{text[:10].replace(' ', '_')}_to_speech.wav")
45
- sf.write(output_path, waveform, 22050) # Use a fixed sample rate for TTS output
46
- # Resample the TTS output to 16000 Hz for consistency with the ASR model's requirements
47
- resampled_waveform = librosa.resample(waveform, orig_sr=22050, target_sr=16000)
48
- resampled_output_path = os.path.join("output_audio", f"{text[:10].replace(' ', '_')}_to_speech_16khz.wav")
49
- sf.write(resampled_output_path, resampled_waveform, 16000)
50
- return resampled_output_path
51
- else:
52
- # If input is already a path to synthesized speech, return it
53
- return text
54
 
55
  def speech_to_speech(input_audio, text_input=None):
56
- if text_input is None:
57
- # If no text input is provided, convert the input audio to text
58
- transcription = speech_to_text(input_audio)
59
- else:
60
- # If text input is provided, use it directly
61
- transcription = text_input
62
- # Synthesize text to speech and resample to 16kHz
63
  synthesized_speech_path = text_to_speech(transcription)
64
  return synthesized_speech_path
65
 
66
-
67
  iface = gr.Interface(
68
  fn=speech_to_speech,
69
  inputs=[gr.Audio(type="filepath", label="Input Audio"),
@@ -74,4 +60,3 @@ iface = gr.Interface(
74
  )
75
 
76
  iface.launch()
77
-
 
21
 
22
  def speech_to_text(input_audio_or_text):
23
  if isinstance(input_audio_or_text, str):
 
24
  waveform = resample_audio(input_audio_or_text, 16000)
25
  input_values = asr_processor(waveform, sampling_rate=16000, return_tensors="pt").input_values
26
  with torch.no_grad():
 
28
  predicted_ids = torch.argmax(logits, dim=-1)
29
  transcription = asr_processor.batch_decode(predicted_ids)[0]
30
  else:
 
31
  transcription = input_audio_or_text
32
  return transcription.strip()
33
 
34
def text_to_speech(text):
    """Synthesize speech from *text* and return the path to a 16 kHz WAV file.

    The text is lowercased and stripped of punctuation before tokenization.
    The raw TTS output is written at 22050 Hz, then resampled to 16000 Hz so
    the file matches the sample rate the ASR model expects.

    Args:
        text: The input text to synthesize.

    Returns:
        Path to the resampled 16 kHz WAV file under ``output_audio/``.
    """
    # Normalize: lowercase and drop punctuation in one C-level pass.
    text = text.lower().translate(str.maketrans('', '', string.punctuation))
    inputs = tts_tokenizer(text, return_tensors="pt")
    # Some tokenizer versions emit int32 ids; the model requires int64.
    inputs.input_ids = inputs.input_ids.long()
    with torch.no_grad():
        output = tts_model(**inputs).waveform
    waveform = output.numpy().squeeze()
    # Bug fix: sf.write fails if the target directory does not exist.
    os.makedirs("output_audio", exist_ok=True)
    # Build the filename stem once instead of repeating the expression.
    stem = text[:10].replace(' ', '_')
    output_path = os.path.join("output_audio", f"{stem}_to_speech.wav")
    sf.write(output_path, waveform, 22050)
    # Resample to 16 kHz for consistency with the ASR model's requirements.
    resampled_waveform = librosa.resample(waveform, orig_sr=22050, target_sr=16000)
    resampled_output_path = os.path.join("output_audio", f"{stem}_to_speech_16khz.wav")
    sf.write(resampled_output_path, resampled_waveform, 16000)
    return resampled_output_path
 
 
 
 
 
47
 
48
def speech_to_speech(input_audio, text_input=None):
    """Turn input audio (or provided text) into synthesized speech.

    If *text_input* is given it is used verbatim; otherwise the audio is
    transcribed first. Returns the path of the synthesized 16 kHz WAV file.
    """
    if text_input is None:
        transcription = speech_to_text(input_audio)
    else:
        transcription = text_input
    return text_to_speech(transcription)
52
 
 
53
  iface = gr.Interface(
54
  fn=speech_to_speech,
55
  inputs=[gr.Audio(type="filepath", label="Input Audio"),
 
60
  )
61
 
62
  iface.launch()