Porjaz committed (verified)
Commit d22911d · 1 Parent(s): a353d95

Update custom_interface_app.py

Files changed (1): custom_interface_app.py (+51 −52)
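The change removes the RMS-based increase_volume helper and the librosa silence-split logic from classify_file_w2v2, replacing them with torchaudio file loading (resampled to 16 kHz when needed) and SpeechBrain VAD-driven segmentation: speech boundaries from vad_model.get_speech_segments are greedily merged into chunks of at most 30 seconds, and each chunk is transcribed in turn.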
custom_interface_app.py CHANGED
@@ -2,6 +2,8 @@ import torch
 from speechbrain.inference.interfaces import Pretrained
 import librosa
 import numpy as np
+import torchaudio
+import os  # needed for os.remove() in classify_file_w2v2
 
 
 class ASR(Pretrained):
@@ -85,69 +86,66 @@ class ASR(Pretrained):
             seq.append(token)
         output = []
         return seq
-
 
-    def increase_volume(self, waveform, threshold_db=-25):
-        # Measure loudness using RMS
-        loudness_vector = librosa.feature.rms(y=waveform)
-        average_loudness = np.mean(loudness_vector)
-        average_loudness_db = librosa.amplitude_to_db(average_loudness)
 
-        print(f"Average Loudness: {average_loudness_db} dB")
-
-        # Check if loudness is below threshold and apply gain if needed
-        if average_loudness_db < threshold_db:
-            # Calculate gain needed
-            gain_db = threshold_db - average_loudness_db
-            gain = librosa.db_to_amplitude(gain_db)  # Convert dB to amplitude factor
-
-            # Apply gain to the audio signal
-            waveform = waveform * gain
-            loudness_vector = librosa.feature.rms(y=waveform)
-            average_loudness = np.mean(loudness_vector)
-            average_loudness_db = librosa.amplitude_to_db(average_loudness)
-
-            print(f"Average Loudness: {average_loudness_db} dB")
-        return waveform
-
-
-    def classify_file_w2v2(self, waveform, device):
+    def classify_file_w2v2(self, file, vad_model, device):
         # Get audio length in seconds
         sr = 16000
+        max_segment_length = 30
+
+        # waveform, sr = librosa.load(file, sr=sr)
+        waveform, file_sr = torchaudio.load(file)
+        # resample if not 16kHz
+        if file_sr != sr:
+            waveform = torchaudio.transforms.Resample(file_sr, sr)(waveform)
+
+        waveform = waveform.squeeze()
         audio_length = len(waveform) / sr
+        print(f"Audio length: {audio_length:.2f} seconds")
 
-        if audio_length >= 30:
+        if audio_length >= max_segment_length:
             print(f"Audio is too long ({audio_length:.2f} seconds), splitting into segments")
-            # Detect non-silent segments
-
-            non_silent_intervals = librosa.effects.split(waveform, top_db=20)  # Adjust top_db for sensitivity
-
+
+            # save waveform temporarily
+            torchaudio.save("temp.wav", waveform.unsqueeze(0), sr)
+            # get boundaries based on VAD
+            boundaries = vad_model.get_speech_segments("temp.wav",
+                                                       large_chunk_size=30,
+                                                       small_chunk_size=10,
+                                                       apply_energy_VAD=True,
+                                                       double_check=True)
+            # remove temp file
+            os.remove("temp.wav")
+
+            # Merge the segments up to max_segment_length
             segments = []
-            current_segment = []
-            current_length = 0
-            max_duration = 30 * sr  # Maximum segment duration in samples (20 seconds)
+            current_start = boundaries[0][0].item()
+            current_end = boundaries[0][1].item()
 
+            for i in range(1, len(boundaries)):
+                next_start = boundaries[i][0].item()
+                next_end = boundaries[i][1].item()
 
-            for interval in non_silent_intervals:
-                start, end = interval
-                segment_part = waveform[start:end]
-
-                # If adding the next part exceeds max duration, store the segment and start a new one
-                if current_length + len(segment_part) > max_duration:
-                    segments.append(np.concatenate(current_segment))
-                    current_segment = []
-                    current_length = 0
-
-                current_segment.append(segment_part)
-                current_length += len(segment_part)
+                # Check if the current segment can merge with the next segment
+                if (current_end - current_start) + (next_end - next_start) <= max_segment_length:
+                    # Extend the current segment
+                    current_end = next_end
+                else:
+                    # Add the current segment to the result and start a new one
+                    segments.append([current_start, current_end])
+                    current_start = next_start
+                    current_end = next_end
 
-            # Append the last segment if it's not empty
-            if current_segment:
-                segments.append(np.concatenate(current_segment))
+            # Add the last segment
+            segments.append([current_start, current_end])
 
             # Process each segment
             outputs = []
             for i, segment in enumerate(segments):
+                start, end = segment
+                start = int(start * sr)
+                end = int(end * sr)
+                segment = waveform[start:end]
                 print(f"Processing segment {i + 1}/{len(segments)}, length: {len(segment) / sr:.2f} seconds")
 
                 # import soundfile as sf
@@ -164,12 +162,13 @@ class ASR(Pretrained):
                 # outputs.append(result)
                 yield result
         else:
-            waveform = torch.tensor(waveform).to(device)
+            waveform, file_sr = torchaudio.load(file)
+            # resample if not 16kHz
+            if file_sr != sr:
+                waveform = torchaudio.transforms.Resample(file_sr, sr)(waveform)
             waveform = waveform.to(device)
-            # Fake a batch:
-            batch = waveform.unsqueeze(0)
             rel_length = torch.tensor([1.0]).to(device)
-            outputs = " ".join(self.encode_batch_w2v2(device, batch, rel_length)[0])
+            outputs = " ".join(self.encode_batch_w2v2(device, waveform, rel_length)[0])
             yield outputs
 
 
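The heart of the new splitting logic is the greedy merge over VAD boundaries. Below is a minimal standalone restatement of that loop for illustration; the function name merge_boundaries and the sample input are not part of the commit:

def merge_boundaries(boundaries, max_segment_length=30):
    # Greedily merge consecutive (start, end) speech boundaries, in seconds,
    # while the combined duration stays within max_segment_length.
    current_start, current_end = boundaries[0]
    segments = []
    for next_start, next_end in boundaries[1:]:
        if (current_end - current_start) + (next_end - next_start) <= max_segment_length:
            current_end = next_end  # extend the current segment
        else:
            segments.append([current_start, current_end])
            current_start, current_end = next_start, next_end
    segments.append([current_start, current_end])
    return segments

# Illustrative boundaries in seconds (three speech bursts):
print(merge_boundaries([[0.0, 12.5], [14.0, 26.0], [28.0, 41.0]]))
# -> [[0.0, 26.0], [28.0, 41.0]]

Note that once segments have been merged, the check uses current_end - current_start, so silence between merged boundaries counts toward the 30-second budget.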
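A minimal sketch of how the updated interface might be called. The VAD is SpeechBrain's pretrained CRDNN model, which provides the get_speech_segments() call used above; the ASR source string and the audio file name are placeholders, not taken from this repo, and assume the repo's hyperparams expose this class:

import torch
from speechbrain.inference.VAD import VAD
from custom_interface_app import ASR

device = "cuda" if torch.cuda.is_available() else "cpu"

# Pretrained SpeechBrain VAD; supports the large_chunk_size / small_chunk_size /
# apply_energy_VAD / double_check options used inside classify_file_w2v2.
vad_model = VAD.from_hparams(
    source="speechbrain/vad-crdnn-libriparty",
    savedir="pretrained_models/vad-crdnn-libriparty",
)

# Placeholder source; classify_file_w2v2 is a generator and yields one
# transcript per merged VAD segment (or a single transcript for short files).
asr_model = ASR.from_hparams(source="your-asr-model", run_opts={"device": device})
for transcript in asr_model.classify_file_w2v2("speech.wav", vad_model, device):
    print(transcript)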