ing0 commited on
Commit
1a3ee96
·
1 Parent(s): 99ed263
Files changed (1) hide show
  1. diffrhythm/infer/infer_utils.py +2 -2
diffrhythm/infer/infer_utils.py CHANGED
@@ -34,10 +34,10 @@ def prepare_model(device):
34
 
35
  # prepare vae
36
  vae_ckpt_path = hf_hub_download(repo_id="ASLP-lab/DiffRhythm-vae", filename="vae_model.pt")
37
- vae = torch.jit.load(vae_ckpt_path).to(device)
38
  print("********* vae.parameters() ", next(vae.parameters()).dtype)
39
  vae = vae.half()
40
- print("********* vae half .parameters() ", next(vae.parameters()).dtype)
41
  return cfm, tokenizer, muq, vae
42
 
43
 
 
34
 
35
  # prepare vae
36
  vae_ckpt_path = hf_hub_download(repo_id="ASLP-lab/DiffRhythm-vae", filename="vae_model.pt")
37
+ vae = torch.jit.load(vae_ckpt_path, map_location='cpu').to(device)
38
  print("********* vae.parameters() ", next(vae.parameters()).dtype)
39
  vae = vae.half()
40
+ print("********* vae half parameters() ", next(vae.parameters()).dtype)
41
  return cfm, tokenizer, muq, vae
42
 
43