nickfraser commited on
Commit
8324b6e
·
1 Parent(s): 13b9094

Feat (script): Added calibration size argument.

Browse files
Files changed (1) hide show
  1. minimal_script.py +7 -1
minimal_script.py CHANGED
@@ -84,6 +84,8 @@ def main(args):
84
  dtype = getattr(torch, args.dtype)
85
 
86
  calibration_prompts = load_calib_prompts(args.calibration_prompt_path)
 
 
87
  latents = torch.load(args.path_to_latents).to(torch.float16)
88
 
89
  # Create output dir. Move to tmp if None
@@ -221,6 +223,11 @@ if __name__ == "__main__":
221
  '-d', '--device', type=str, default='cuda:0', help='Target device for quantized model.')
222
  parser.add_argument(
223
  '--calibration-prompt-path', type=str, default=None, help='Path to calibration prompt')
 
 
 
 
 
224
  parser.add_argument(
225
  '--checkpoint-name',
226
  type=str,
@@ -234,7 +241,6 @@ if __name__ == "__main__":
234
  default=None,
235
  help=
236
  'Load pre-defined latents. If not provided, they are generated based on an internal seed.')
237
-
238
  parser.add_argument('--guidance-scale', type=float, default=8., help='Guidance scale.')
239
  parser.add_argument(
240
  '--calibration-steps', type=float, default=8, help='Steps used during calibration')
 
84
  dtype = getattr(torch, args.dtype)
85
 
86
  calibration_prompts = load_calib_prompts(args.calibration_prompt_path)
87
+ assert args.calibration_prompts <= len(calibration_prompts) , f"--calibration-prompts must be <= {len(calibration_prompts)}"
88
+ calibration_prompts = calibration_prompts[:args.calibration_prompts]
89
  latents = torch.load(args.path_to_latents).to(torch.float16)
90
 
91
  # Create output dir. Move to tmp if None
 
223
  '-d', '--device', type=str, default='cuda:0', help='Target device for quantized model.')
224
  parser.add_argument(
225
  '--calibration-prompt-path', type=str, default=None, help='Path to calibration prompt')
226
+ parser.add_argument(
227
+ '--calibration-prompts',
228
+ type=int,
229
+ default=500,
230
+ help='Number of prompts to use for calibration. Default: %(default)s')
231
  parser.add_argument(
232
  '--checkpoint-name',
233
  type=str,
 
241
  default=None,
242
  help=
243
  'Load pre-defined latents. If not provided, they are generated based on an internal seed.')
 
244
  parser.add_argument('--guidance-scale', type=float, default=8., help='Guidance scale.')
245
  parser.add_argument(
246
  '--calibration-steps', type=float, default=8, help='Steps used during calibration')