ing0 commited on
Commit
24f6793
·
1 Parent(s): f401ec4
diffrhythm/infer/infer.py CHANGED
@@ -8,6 +8,7 @@ from tqdm import tqdm
8
  import random
9
  import numpy as np
10
  import time
 
11
 
12
  from diffrhythm.infer.infer_utils import (
13
  get_reference_latent,
@@ -17,6 +18,7 @@ from diffrhythm.infer.infer_utils import (
17
  get_negative_style_prompt
18
  )
19
 
 
20
  def decode_audio(latents, vae_model, chunked=False, overlap=32, chunk_size=128):
21
  downsampling_ratio = 2048
22
  io_channels = 2
@@ -72,6 +74,7 @@ def decode_audio(latents, vae_model, chunked=False, overlap=32, chunk_size=128):
72
  y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
73
  return y_final
74
 
 
75
  def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, steps, sway_sampling_coef, start_time):
76
  # import pdb; pdb.set_trace()
77
  s_t = time.time()
@@ -100,7 +103,7 @@ def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative
100
  output = rearrange(output, "b d n -> d (b n)")
101
  output_tensor = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).cpu()
102
  output_np = output_tensor.numpy().T.astype(np.float32)
103
- print(f"**** vae time : {time.tiem()-e_t} ****")
104
  print(output_np.mean(), output_np.min(), output_np.max(), output_np.std())
105
  return (44100, output_np)
106
 
 
8
  import random
9
  import numpy as np
10
  import time
11
+ import spaces
12
 
13
  from diffrhythm.infer.infer_utils import (
14
  get_reference_latent,
 
18
  get_negative_style_prompt
19
  )
20
 
21
+ @spaces.GPU
22
  def decode_audio(latents, vae_model, chunked=False, overlap=32, chunk_size=128):
23
  downsampling_ratio = 2048
24
  io_channels = 2
 
74
  y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
75
  return y_final
76
 
77
+ @spaces.GPU
78
  def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, steps, sway_sampling_coef, start_time):
79
  # import pdb; pdb.set_trace()
80
  s_t = time.time()
 
103
  output = rearrange(output, "b d n -> d (b n)")
104
  output_tensor = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).cpu()
105
  output_np = output_tensor.numpy().T.astype(np.float32)
106
+ print(f"**** vae time : {time.time()-e_t} ****")
107
  print(output_np.mean(), output_np.min(), output_np.max(), output_np.std())
108
  return (44100, output_np)
109
 
diffrhythm/infer/infer_utils.py CHANGED
@@ -35,7 +35,9 @@ def prepare_model(device):
35
  # prepare vae
36
  vae_ckpt_path = hf_hub_download(repo_id="ASLP-lab/DiffRhythm-vae", filename="vae_model.pt")
37
  vae = torch.jit.load(vae_ckpt_path).to(device)
38
-
 
 
39
  return cfm, tokenizer, muq, vae
40
 
41
 
 
35
  # prepare vae
36
  vae_ckpt_path = hf_hub_download(repo_id="ASLP-lab/DiffRhythm-vae", filename="vae_model.pt")
37
  vae = torch.jit.load(vae_ckpt_path).to(device)
38
+ print("********* vae.parameters() ", next(vae.parameters()).dtype)
39
+ vae = vae.half()
40
+ print("********* vae half .parameters() ", next(vae.parameters()).dtype)
41
  return cfm, tokenizer, muq, vae
42
 
43