nickfraser
commited on
Commit
·
8324b6e
1
Parent(s):
13b9094
Feat (script): Added calibration size argument.
Browse files- minimal_script.py +7 -1
minimal_script.py
CHANGED
@@ -84,6 +84,8 @@ def main(args):
|
|
84 |
dtype = getattr(torch, args.dtype)
|
85 |
|
86 |
calibration_prompts = load_calib_prompts(args.calibration_prompt_path)
|
|
|
|
|
87 |
latents = torch.load(args.path_to_latents).to(torch.float16)
|
88 |
|
89 |
# Create output dir. Move to tmp if None
|
@@ -221,6 +223,11 @@ if __name__ == "__main__":
|
|
221 |
'-d', '--device', type=str, default='cuda:0', help='Target device for quantized model.')
|
222 |
parser.add_argument(
|
223 |
'--calibration-prompt-path', type=str, default=None, help='Path to calibration prompt')
|
|
|
|
|
|
|
|
|
|
|
224 |
parser.add_argument(
|
225 |
'--checkpoint-name',
|
226 |
type=str,
|
@@ -234,7 +241,6 @@ if __name__ == "__main__":
|
|
234 |
default=None,
|
235 |
help=
|
236 |
'Load pre-defined latents. If not provided, they are generated based on an internal seed.')
|
237 |
-
|
238 |
parser.add_argument('--guidance-scale', type=float, default=8., help='Guidance scale.')
|
239 |
parser.add_argument(
|
240 |
'--calibration-steps', type=float, default=8, help='Steps used during calibration')
|
|
|
84 |
dtype = getattr(torch, args.dtype)
|
85 |
|
86 |
calibration_prompts = load_calib_prompts(args.calibration_prompt_path)
|
87 |
+
assert args.calibration_prompts <= len(calibration_prompts) , f"--calibration-prompts must be <= {len(calibration_prompts)}"
|
88 |
+
calibration_prompts = calibration_prompts[:args.calibration_prompts]
|
89 |
latents = torch.load(args.path_to_latents).to(torch.float16)
|
90 |
|
91 |
# Create output dir. Move to tmp if None
|
|
|
223 |
'-d', '--device', type=str, default='cuda:0', help='Target device for quantized model.')
|
224 |
parser.add_argument(
|
225 |
'--calibration-prompt-path', type=str, default=None, help='Path to calibration prompt')
|
226 |
+
parser.add_argument(
|
227 |
+
'--calibration-prompts',
|
228 |
+
type=int,
|
229 |
+
default=500,
|
230 |
+
help='Number of prompts to use for calibration. Default: %(default)s')
|
231 |
parser.add_argument(
|
232 |
'--checkpoint-name',
|
233 |
type=str,
|
|
|
241 |
default=None,
|
242 |
help=
|
243 |
'Load pre-defined latents. If not provided, they are generated based on an internal seed.')
|
|
|
244 |
parser.add_argument('--guidance-scale', type=float, default=8., help='Guidance scale.')
|
245 |
parser.add_argument(
|
246 |
'--calibration-steps', type=float, default=8, help='Steps used during calibration')
|