cocktailpeanut committed on
Commit
2e246ff
·
1 Parent(s): dc6a3d5
Files changed (2) hide show
  1. app.py +2 -0
  2. diffrhythm/infer/infer.py +7 -1
app.py CHANGED
@@ -38,6 +38,7 @@ def infer_music(lrc, ref_audio_path, steps, file_type, max_frames=2048):
38
  style_prompt = get_style_prompt(muq, ref_audio_path)
39
  negative_style_prompt = get_negative_style_prompt(device)
40
  latent_prompt = get_reference_latent(device, max_frames)
 
41
  generated_song = inference(cfm_model=cfm,
42
  vae_model=vae,
43
  cond=latent_prompt,
@@ -52,6 +53,7 @@ def infer_music(lrc, ref_audio_path, steps, file_type, max_frames=2048):
52
  )
53
  torch.cuda.empty_cache()
54
  gc.collect()
 
55
 
56
  return generated_song
57
 
 
38
  style_prompt = get_style_prompt(muq, ref_audio_path)
39
  negative_style_prompt = get_negative_style_prompt(device)
40
  latent_prompt = get_reference_latent(device, max_frames)
41
+ print(">0")
42
  generated_song = inference(cfm_model=cfm,
43
  vae_model=vae,
44
  cond=latent_prompt,
 
53
  )
54
  torch.cuda.empty_cache()
55
  gc.collect()
56
+ print(">4")
57
 
58
  return generated_song
59
 
diffrhythm/infer/infer.py CHANGED
@@ -78,6 +78,7 @@ def decode_audio(latents, vae_model, chunked=False, overlap=32, chunk_size=128):
78
  def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, steps, sway_sampling_coef, start_time, file_type):
79
 
80
  with torch.inference_mode():
 
81
  generated, _ = cfm_model.sample(
82
  cond=cond,
83
  text=text,
@@ -93,14 +94,19 @@ def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative
93
  gc.collect()
94
 
95
 
 
96
  generated = generated.to(torch.float32)
 
97
  latent = generated.transpose(1, 2) # [b d t]
98
- output = decode_audio(latent, vae_model, chunked=False)
 
 
99
 
100
  del latent, generated
101
  torch.cuda.empty_cache()
102
  gc.collect()
103
 
 
104
 
105
  # Rearrange audio batch to a single sequence
106
  output = rearrange(output, "b d n -> d (b n)")
 
78
  def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, steps, sway_sampling_coef, start_time, file_type):
79
 
80
  with torch.inference_mode():
81
+ print(">1")
82
  generated, _ = cfm_model.sample(
83
  cond=cond,
84
  text=text,
 
94
  gc.collect()
95
 
96
 
97
+ print(">2")
98
  generated = generated.to(torch.float32)
99
+ print(">3")
100
  latent = generated.transpose(1, 2) # [b d t]
101
+ print(">4")
102
+ output = decode_audio(latent, vae_model, chunked=True)
103
+ print(">5")
104
 
105
  del latent, generated
106
  torch.cuda.empty_cache()
107
  gc.collect()
108
 
109
+ print(">6")
110
 
111
  # Rearrange audio batch to a single sequence
112
  output = rearrange(output, "b d n -> d (b n)")