ZeyuXie commited on
Commit
8366c6c
1 Parent(s): 49bb4d0

Update pico_model.py

Browse files
Files changed (1) hide show
  1. pico_model.py +2 -1
pico_model.py CHANGED
@@ -226,7 +226,8 @@ class PicoDiffusion(ClapText_Onset_2_Audio_Diffusion):
226
  ckpt = clap_load_state_dict(freeze_text_encoder_ckpt, skip_params=True)
227
  del_parameter_key = ["text_branch.embeddings.position_ids"]
228
  ckpt = {f"freeze_text_encoder.model.{k}":v for k, v in ckpt.items() if k not in del_parameter_key}
229
- diffusion_ckpt = torch.load(diffusion_pt, map_location=self.device)
 
230
  del diffusion_ckpt["class_emb.weight"]
231
  ckpt.update(diffusion_ckpt)
232
  self.load_state_dict(ckpt)
 
226
  ckpt = clap_load_state_dict(freeze_text_encoder_ckpt, skip_params=True)
227
  del_parameter_key = ["text_branch.embeddings.position_ids"]
228
  ckpt = {f"freeze_text_encoder.model.{k}":v for k, v in ckpt.items() if k not in del_parameter_key}
229
+ #diffusion_ckpt = torch.load(diffusion_pt, map_location=self.device)
230
+ diffusion_ckpt = torch.load(diffusion_pt)
231
  del diffusion_ckpt["class_emb.weight"]
232
  ckpt.update(diffusion_ckpt)
233
  self.load_state_dict(ckpt)