GiusFra committed on
Commit 7f81513
1 Parent(s): 72eb84b

Fix model loading

Files changed (1)
  1. quant_sdxl/quant_sdxl.py +4 -1
quant_sdxl/quant_sdxl.py CHANGED
@@ -102,7 +102,10 @@ def main(args):
 
     # Load model from float checkpoint
     print(f"Loading model from {args.model}...")
-    pipe = DiffusionPipeline.from_pretrained(args.model, torch_dtype=dtype)
+    variant = 'fp16' if dtype == torch.float16 else None
+    pipe = DiffusionPipeline.from_pretrained(args.model, torch_dtype=dtype, variant=variant, use_safetensors=True)
+    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
+    pipe.vae.config.force_upcast=True
     print(f"Model loaded from {args.model}.")
 
     # Move model to target device
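
For context, a minimal sketch of how the updated loading step could sit inside main(args) is shown below. The helper name load_pipeline is hypothetical, and the surrounding pieces (args.model, dtype, device, the diffusers imports) are assumptions inferred from the diff rather than part of this commit.

    # Sketch only: load_pipeline is a hypothetical wrapper around the block changed above,
    # assuming `model`, `dtype`, and `device` are provided by the caller (e.g. from args).
    import torch
    from diffusers import DiffusionPipeline, EulerDiscreteScheduler

    def load_pipeline(model: str, dtype: torch.dtype, device: str = "cuda"):
        print(f"Loading model from {model}...")
        # Request the fp16 weight variant only when loading in half precision;
        # otherwise fall back to the default checkpoint weights.
        variant = "fp16" if dtype == torch.float16 else None
        pipe = DiffusionPipeline.from_pretrained(
            model, torch_dtype=dtype, variant=variant, use_safetensors=True
        )
        # Rebuild the scheduler as EulerDiscreteScheduler from the existing scheduler config.
        pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
        # Force the VAE to upcast to float32 during decode, avoiding fp16 overflow.
        pipe.vae.config.force_upcast = True
        print(f"Model loaded from {model}.")
        return pipe.to(device)

The variant/use_safetensors pair loads the half-precision safetensors weights directly when dtype is torch.float16, which is the likely reason the plain from_pretrained call was failing or loading the wrong weights.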