ing0 commited on
Commit
ee08889
·
1 Parent(s): ee20368
Files changed (1) hide show
  1. diffrhythm/infer/infer_utils.py +1 -1
diffrhythm/infer/infer_utils.py CHANGED
@@ -35,7 +35,7 @@ def prepare_model(device):
35
  # prepare vae
36
  vae_ckpt_path = hf_hub_download(repo_id="ASLP-lab/DiffRhythm-vae", filename="vae_model.pt")
37
  print(f"****************** {device} ******************")
38
- vae = torch.jit.load(vae_ckpt_path).to(device)
39
 
40
  return cfm, tokenizer, muq, vae
41
 
 
35
  # prepare vae
36
  vae_ckpt_path = hf_hub_download(repo_id="ASLP-lab/DiffRhythm-vae", filename="vae_model.pt")
37
  print(f"****************** {device} ******************")
38
+ vae = torch.jit.load(vae_ckpt_path, map_location=device)
39
 
40
  return cfm, tokenizer, muq, vae
41