Plachta commited on
Commit
e74aea7
·
verified ·
1 Parent(s): bc452bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -20
app.py CHANGED
@@ -5,9 +5,9 @@ import librosa
5
  from modules.commons import build_model, load_checkpoint, recursive_munch
6
  import yaml
7
  from hf_utils import load_custom_model_from_hf
8
- import spaces
9
  import numpy as np
10
  from pydub import AudioSegment
 
11
 
12
  # Load model and configuration
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -33,8 +33,9 @@ model.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
33
  # Load additional modules
34
  from modules.campplus.DTDNN import CAMPPlus
35
 
 
36
  campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
37
- campplus_model.load_state_dict(torch.load(config['model_params']['style_encoder']['campplus_path'], map_location='cpu'))
38
  campplus_model.eval()
39
  campplus_model.to(device)
40
 
@@ -50,6 +51,14 @@ hift_gen.load_state_dict(torch.load(hift_checkpoint_path, map_location='cpu'))
50
  hift_gen.eval()
51
  hift_gen.to(device)
52
 
 
 
 
 
 
 
 
 
53
  speech_tokenizer_type = config['model_params']['speech_tokenizer'].get('type', 'cosyvoice')
54
  if speech_tokenizer_type == 'cosyvoice':
55
  from modules.cosyvoice_tokenizer.frontend import CosyVoiceFrontEnd
@@ -69,6 +78,7 @@ elif speech_tokenizer_type == 'facodec':
69
  codec_encoder[key].load_state_dict(ckpt_params[key], strict=False)
70
  _ = [codec_encoder[key].eval() for key in codec_encoder]
71
  _ = [codec_encoder[key].to(device) for key in codec_encoder]
 
72
  # Generate mel spectrograms
73
  mel_fn_args = {
74
  "n_fft": config['preprocess_params']['spect_params']['n_fft'],
@@ -80,13 +90,24 @@ mel_fn_args = {
80
  "fmax": 8000,
81
  "center": False
82
  }
 
 
 
 
 
 
 
 
 
 
83
  from modules.audio import mel_spectrogram
84
 
85
  to_mel = lambda x: mel_spectrogram(x, **mel_fn_args)
 
86
 
87
  # f0 conditioned model
88
  dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
89
- "DiT_step_440000_seed_v2_uvit_facodec_small_wavenet_f0_pruned.pth",
90
  "config_dit_mel_seed_facodec_small_wavenet_f0.yml")
91
 
92
  config = yaml.safe_load(open(dit_config_path, 'r'))
@@ -114,8 +135,8 @@ def adjust_f0_semitones(f0_sequence, n_semitones):
114
  return f0_sequence * factor
115
 
116
  def crossfade(chunk1, chunk2, overlap):
117
- fade_out = np.cos(np.linspace(0, np.pi / 2, overlap)) ** 2
118
- fade_in = np.cos(np.linspace(np.pi / 2, 0, overlap)) ** 2
119
  chunk2[:overlap] = chunk2[:overlap] * fade_in + chunk1[-overlap:] * fade_out
120
  return chunk2
121
 
@@ -123,14 +144,14 @@ def crossfade(chunk1, chunk2, overlap):
123
  max_context_window = sr // hop_length * 30
124
  overlap_frame_len = 64
125
  overlap_wave_len = overlap_frame_len * hop_length
126
- max_wave_len_per_chunk = 24000 * 20
127
  bitrate = "320k"
128
 
129
- @spaces.GPU
130
  @torch.no_grad()
131
  @torch.inference_mode()
 
132
  def voice_conversion(source, target, diffusion_steps, length_adjust, inference_cfg_rate, n_quantizers, f0_condition, auto_f0_adjust, pitch_shift):
133
  inference_module = model if not f0_condition else model_f0
 
134
  # Load audio
135
  source_audio = librosa.load(source, sr=sr)[0]
136
  ref_audio = librosa.load(target, sr=sr)[0]
@@ -150,6 +171,7 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_c
150
  elif speech_tokenizer_type == 'facodec':
151
  converted_waves_24k = torchaudio.functional.resample(source_audio, sr, 24000)
152
  waves_input = converted_waves_24k.unsqueeze(1)
 
153
  wave_input_chunks = [
154
  waves_input[..., i:i + max_wave_len_per_chunk] for i in range(0, waves_input.size(-1), max_wave_len_per_chunk)
155
  ]
@@ -180,8 +202,8 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_c
180
  )
181
  S_ori = torch.cat([codes[1], codes[0]], dim=1)
182
 
183
- mel = to_mel(source_audio.to(device).float())
184
- mel2 = to_mel(ref_audio.to(device).float())
185
 
186
  target_lengths = torch.LongTensor([int(mel.size(2) * length_adjust)]).to(mel.device)
187
  target2_lengths = torch.LongTensor([mel2.size(2)]).to(mel2.device)
@@ -194,8 +216,8 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_c
194
  style2 = campplus_model(feat2.unsqueeze(0))
195
 
196
  if f0_condition:
197
- waves_16k = torchaudio.functional.resample(waves_24k, sr, 16000)
198
- converted_waves_16k = torchaudio.functional.resample(converted_waves_24k, sr, 16000)
199
  F0_ori = rmvpe.infer_from_audio(waves_16k[0], thred=0.03)
200
  F0_alt = rmvpe.infer_from_audio(converted_waves_16k[0], thred=0.03)
201
 
@@ -244,7 +266,10 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_c
244
  mel2, style2, None, diffusion_steps,
245
  inference_cfg_rate=inference_cfg_rate)
246
  vc_target = vc_target[:, :, mel2.size(-1):]
247
- vc_wave = hift_gen.inference(vc_target, f0=None)
 
 
 
248
  if processed_frames == 0:
249
  if is_last_chunk:
250
  output_wave = vc_wave[0].cpu().numpy()
@@ -254,7 +279,7 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_c
254
  output_wave.tobytes(), frame_rate=sr,
255
  sample_width=output_wave.dtype.itemsize, channels=1
256
  ).export(format="mp3", bitrate=bitrate).read()
257
- yield mp3_bytes
258
  break
259
  output_wave = vc_wave[0, :-overlap_wave_len].cpu().numpy()
260
  generated_wave_chunks.append(output_wave)
@@ -265,7 +290,7 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_c
265
  output_wave.tobytes(), frame_rate=sr,
266
  sample_width=output_wave.dtype.itemsize, channels=1
267
  ).export(format="mp3", bitrate=bitrate).read()
268
- yield mp3_bytes
269
  elif is_last_chunk:
270
  output_wave = crossfade(previous_chunk.cpu().numpy(), vc_wave[0].cpu().numpy(), overlap_wave_len)
271
  generated_wave_chunks.append(output_wave)
@@ -275,7 +300,7 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_c
275
  output_wave.tobytes(), frame_rate=sr,
276
  sample_width=output_wave.dtype.itemsize, channels=1
277
  ).export(format="mp3", bitrate=bitrate).read()
278
- yield mp3_bytes
279
  break
280
  else:
281
  output_wave = crossfade(previous_chunk.cpu().numpy(), vc_wave[0, :-overlap_wave_len].cpu().numpy(), overlap_wave_len)
@@ -287,7 +312,7 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_c
287
  output_wave.tobytes(), frame_rate=sr,
288
  sample_width=output_wave.dtype.itemsize, channels=1
289
  ).export(format="mp3", bitrate=bitrate).read()
290
- yield mp3_bytes
291
 
292
 
293
  if __name__ == "__main__":
@@ -308,10 +333,15 @@ if __name__ == "__main__":
308
  ]
309
 
310
  examples = [["examples/source/yae_0.wav", "examples/reference/dingzhen_0.wav", 25, 1.0, 0.7, 1, False, True, 0],
311
- ["examples/source/Wiz Khalifa,Charlie Puth - See You Again [vocals]_[cut_28sec].wav",
312
- "examples/reference/teio_0.wav", 100, 1.0, 0.7, 3, True, True, 0],]
 
 
 
 
313
 
314
- outputs = gr.Audio(label="Output Audio", streaming=True, format='mp3')
 
315
 
316
  gr.Interface(fn=voice_conversion,
317
  description=description,
@@ -320,4 +350,4 @@ if __name__ == "__main__":
320
  title="Seed Voice Conversion",
321
  examples=examples,
322
  cache_examples=False,
323
- ).launch()
 
5
  from modules.commons import build_model, load_checkpoint, recursive_munch
6
  import yaml
7
  from hf_utils import load_custom_model_from_hf
 
8
  import numpy as np
9
  from pydub import AudioSegment
10
+ import spaces
11
 
12
  # Load model and configuration
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
33
  # Load additional modules
34
  from modules.campplus.DTDNN import CAMPPlus
35
 
36
+ campplus_ckpt_path = load_custom_model_from_hf("funasr/campplus", "campplus_cn_common.bin", config_filename=None)
37
  campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
38
+ campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu"))
39
  campplus_model.eval()
40
  campplus_model.to(device)
41
 
 
51
  hift_gen.eval()
52
  hift_gen.to(device)
53
 
54
+ from modules.bigvgan import bigvgan
55
+
56
+ bigvgan_model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_22khz_80band_256x', use_cuda_kernel=False)
57
+
58
+ # remove weight norm in the model and set to eval mode
59
+ bigvgan_model.remove_weight_norm()
60
+ bigvgan_model = bigvgan_model.eval().to(device)
61
+
62
  speech_tokenizer_type = config['model_params']['speech_tokenizer'].get('type', 'cosyvoice')
63
  if speech_tokenizer_type == 'cosyvoice':
64
  from modules.cosyvoice_tokenizer.frontend import CosyVoiceFrontEnd
 
78
  codec_encoder[key].load_state_dict(ckpt_params[key], strict=False)
79
  _ = [codec_encoder[key].eval() for key in codec_encoder]
80
  _ = [codec_encoder[key].to(device) for key in codec_encoder]
81
+
82
  # Generate mel spectrograms
83
  mel_fn_args = {
84
  "n_fft": config['preprocess_params']['spect_params']['n_fft'],
 
90
  "fmax": 8000,
91
  "center": False
92
  }
93
+ mel_fn_args_f0 = {
94
+ "n_fft": config['preprocess_params']['spect_params']['n_fft'],
95
+ "win_size": config['preprocess_params']['spect_params']['win_length'],
96
+ "hop_size": config['preprocess_params']['spect_params']['hop_length'],
97
+ "num_mels": config['preprocess_params']['spect_params']['n_mels'],
98
+ "sampling_rate": sr,
99
+ "fmin": 0,
100
+ "fmax": None,
101
+ "center": False
102
+ }
103
  from modules.audio import mel_spectrogram
104
 
105
  to_mel = lambda x: mel_spectrogram(x, **mel_fn_args)
106
+ to_mel_f0 = lambda x: mel_spectrogram(x, **mel_fn_args_f0)
107
 
108
  # f0 conditioned model
109
  dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
110
+ "DiT_seed_v2_uvit_facodec_small_wavenet_f0_bigvgan_pruned.pth",
111
  "config_dit_mel_seed_facodec_small_wavenet_f0.yml")
112
 
113
  config = yaml.safe_load(open(dit_config_path, 'r'))
 
135
  return f0_sequence * factor
136
 
137
  def crossfade(chunk1, chunk2, overlap):
138
+ fade_out = np.linspace(1, 0, overlap)
139
+ fade_in = np.linspace(0, 1, overlap)
140
  chunk2[:overlap] = chunk2[:overlap] * fade_in + chunk1[-overlap:] * fade_out
141
  return chunk2
142
 
 
144
  max_context_window = sr // hop_length * 30
145
  overlap_frame_len = 64
146
  overlap_wave_len = overlap_frame_len * hop_length
 
147
  bitrate = "320k"
148
 
 
149
  @torch.no_grad()
150
  @torch.inference_mode()
151
+ @spaces.GPU
152
  def voice_conversion(source, target, diffusion_steps, length_adjust, inference_cfg_rate, n_quantizers, f0_condition, auto_f0_adjust, pitch_shift):
153
  inference_module = model if not f0_condition else model_f0
154
+ mel_fn = to_mel if not f0_condition else to_mel_f0
155
  # Load audio
156
  source_audio = librosa.load(source, sr=sr)[0]
157
  ref_audio = librosa.load(target, sr=sr)[0]
 
171
  elif speech_tokenizer_type == 'facodec':
172
  converted_waves_24k = torchaudio.functional.resample(source_audio, sr, 24000)
173
  waves_input = converted_waves_24k.unsqueeze(1)
174
+ max_wave_len_per_chunk = 24000 * 20
175
  wave_input_chunks = [
176
  waves_input[..., i:i + max_wave_len_per_chunk] for i in range(0, waves_input.size(-1), max_wave_len_per_chunk)
177
  ]
 
202
  )
203
  S_ori = torch.cat([codes[1], codes[0]], dim=1)
204
 
205
+ mel = mel_fn(source_audio.to(device).float())
206
+ mel2 = mel_fn(ref_audio.to(device).float())
207
 
208
  target_lengths = torch.LongTensor([int(mel.size(2) * length_adjust)]).to(mel.device)
209
  target2_lengths = torch.LongTensor([mel2.size(2)]).to(mel2.device)
 
216
  style2 = campplus_model(feat2.unsqueeze(0))
217
 
218
  if f0_condition:
219
+ waves_16k = torchaudio.functional.resample(waves_24k, 24000, 16000)
220
+ converted_waves_16k = torchaudio.functional.resample(converted_waves_24k, 24000, 16000)
221
  F0_ori = rmvpe.infer_from_audio(waves_16k[0], thred=0.03)
222
  F0_alt = rmvpe.infer_from_audio(converted_waves_16k[0], thred=0.03)
223
 
 
266
  mel2, style2, None, diffusion_steps,
267
  inference_cfg_rate=inference_cfg_rate)
268
  vc_target = vc_target[:, :, mel2.size(-1):]
269
+ if not f0_condition:
270
+ vc_wave = hift_gen.inference(vc_target, f0=None)
271
+ else:
272
+ vc_wave = bigvgan_model(vc_target)[0]
273
  if processed_frames == 0:
274
  if is_last_chunk:
275
  output_wave = vc_wave[0].cpu().numpy()
 
279
  output_wave.tobytes(), frame_rate=sr,
280
  sample_width=output_wave.dtype.itemsize, channels=1
281
  ).export(format="mp3", bitrate=bitrate).read()
282
+ yield mp3_bytes, (sr, np.concatenate(generated_wave_chunks))
283
  break
284
  output_wave = vc_wave[0, :-overlap_wave_len].cpu().numpy()
285
  generated_wave_chunks.append(output_wave)
 
290
  output_wave.tobytes(), frame_rate=sr,
291
  sample_width=output_wave.dtype.itemsize, channels=1
292
  ).export(format="mp3", bitrate=bitrate).read()
293
+ yield mp3_bytes, None
294
  elif is_last_chunk:
295
  output_wave = crossfade(previous_chunk.cpu().numpy(), vc_wave[0].cpu().numpy(), overlap_wave_len)
296
  generated_wave_chunks.append(output_wave)
 
300
  output_wave.tobytes(), frame_rate=sr,
301
  sample_width=output_wave.dtype.itemsize, channels=1
302
  ).export(format="mp3", bitrate=bitrate).read()
303
+ yield mp3_bytes, (sr, np.concatenate(generated_wave_chunks))
304
  break
305
  else:
306
  output_wave = crossfade(previous_chunk.cpu().numpy(), vc_wave[0, :-overlap_wave_len].cpu().numpy(), overlap_wave_len)
 
312
  output_wave.tobytes(), frame_rate=sr,
313
  sample_width=output_wave.dtype.itemsize, channels=1
314
  ).export(format="mp3", bitrate=bitrate).read()
315
+ yield mp3_bytes, None
316
 
317
 
318
  if __name__ == "__main__":
 
333
  ]
334
 
335
  examples = [["examples/source/yae_0.wav", "examples/reference/dingzhen_0.wav", 25, 1.0, 0.7, 1, False, True, 0],
336
+ ["examples/source/jay_0.wav", "examples/reference/azuma_0.wav", 25, 1.0, 0.7, 1, True, True, 0],
337
+ ["examples/source/Wiz Khalifa,Charlie Puth - See You Again [vocals]_[cut_28sec].wav",
338
+ "examples/reference/teio_0.wav", 100, 1.0, 0.7, 3, True, False, 0],
339
+ ["examples/source/TECHNOPOLIS - 2085 [vocals]_[cut_14sec].wav",
340
+ "examples/reference/trump_0.wav", 50, 1.0, 0.7, 3, True, False, -12],
341
+ ]
342
 
343
+ outputs = [gr.Audio(label="Stream Output Audio", streaming=True, format='mp3'),
344
+ gr.Audio(label="Full Output Audio", streaming=False, format='wav')]
345
 
346
  gr.Interface(fn=voice_conversion,
347
  description=description,
 
350
  title="Seed Voice Conversion",
351
  examples=examples,
352
  cache_examples=False,
353
+ ).launch()