Spaces:
Running
on
Zero
Running
on
Zero
Update pico_model.py
Browse files- pico_model.py +2 -1
pico_model.py
CHANGED
@@ -226,7 +226,8 @@ class PicoDiffusion(ClapText_Onset_2_Audio_Diffusion):
|
|
226 |
ckpt = clap_load_state_dict(freeze_text_encoder_ckpt, skip_params=True)
|
227 |
del_parameter_key = ["text_branch.embeddings.position_ids"]
|
228 |
ckpt = {f"freeze_text_encoder.model.{k}":v for k, v in ckpt.items() if k not in del_parameter_key}
|
229 |
-
diffusion_ckpt = torch.load(diffusion_pt, map_location=self.device)
|
|
|
230 |
del diffusion_ckpt["class_emb.weight"]
|
231 |
ckpt.update(diffusion_ckpt)
|
232 |
self.load_state_dict(ckpt)
|
|
|
226 |
ckpt = clap_load_state_dict(freeze_text_encoder_ckpt, skip_params=True)
|
227 |
del_parameter_key = ["text_branch.embeddings.position_ids"]
|
228 |
ckpt = {f"freeze_text_encoder.model.{k}":v for k, v in ckpt.items() if k not in del_parameter_key}
|
229 |
+
#diffusion_ckpt = torch.load(diffusion_pt, map_location=self.device)
|
230 |
+
diffusion_ckpt = torch.load(diffusion_pt)
|
231 |
del diffusion_ckpt["class_emb.weight"]
|
232 |
ckpt.update(diffusion_ckpt)
|
233 |
self.load_state_dict(ckpt)
|