GiusFra committed on
Commit
ecec5b7
1 Parent(s): 161df88

Quantization script

Files changed (2)
  1. minimal_script.py +274 -0
  2. requirements.txt +10 -0
minimal_script.py ADDED
@@ -0,0 +1,274 @@
+ """
+ Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ SPDX-License-Identifier: MIT
+ """
+
+ import argparse
+ import copy
+ from datetime import datetime
+ import json
+ import os
+ import time
+
+ from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint
+ from brevitas.quant.experimental.float_quant_fnuz import Fp8e4m3FNUZActPerTensorFloat
+ from brevitas.quant.scaled_int import Int8ActPerTensorFloat
+ from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat
+ from brevitas_examples.common.generative.nn import LoRACompatibleQuantConv2d, LoRACompatibleQuantLinear
+ from diffusers import DiffusionPipeline
+ from diffusers.models.attention_processor import Attention
+ from diffusers.models.attention_processor import AttnProcessor
+ import pandas as pd
+ import torch
+ from torch import nn
+ from tqdm import tqdm
+ import brevitas.nn as qnn
+
+ from brevitas.graph.base import ModuleToModuleByClass
+ from brevitas.graph.calibrate import bias_correction_mode
+ from brevitas.graph.calibrate import calibration_mode
+ from brevitas.graph.equalize import activation_equalization_mode
+ from brevitas.graph.quantize import layerwise_quantize
+ from brevitas.inject.enum import StatsOp
+ from brevitas.nn.equalized_layer import EqualizedModule
+ from brevitas.utils.torch_utils import KwargsForwardHook
+
+ from brevitas_examples.common.parse_utils import add_bool_arg
+ from brevitas_examples.stable_diffusion.sd_quant.export import export_quant_params
+ from brevitas_examples.stable_diffusion.sd_quant.nn import QuantAttention
+ import brevitas.config as config
+
+ TEST_SEED = 123456
+ torch.manual_seed(TEST_SEED)
+
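+ # Quantizer configuration:
+ # - WeightQuant: asymmetric (shifted) uint8 weight quantization, per output channel,
+ #   with scale and zero-point initialized from weight statistics and learned as parameters.
+ # - InputQuant: int8 per-tensor activation quantization, scale taken from the max statistic.
+ # - OutputQuant: FP8 e4m3 (FNUZ) per-tensor quantization, used for attention outputs below.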
+ class WeightQuant(ShiftedUint8WeightPerChannelFloat):
+     narrow_range = False
+     scaling_min_val = 1e-4
+     quantize_zero_point = True
+     scaling_impl_type = 'parameter_from_stats'
+     zero_point_impl = ParameterFromStatsFromParameterZeroPoint
+
+ class InputQuant(Int8ActPerTensorFloat):
+     scaling_stats_op = StatsOp.MAX
+
+ class OutputQuant(Fp8e4m3FNUZActPerTensorFloat):
+     scaling_stats_op = StatsOp.MAX
+
+ NEGATIVE_PROMPTS = ["normal quality, low quality, worst quality, low res, blurry, nsfw, nude"]
+
+ def load_calib_prompts(calib_data_path, sep="\t"):
+     df = pd.read_csv(calib_data_path, sep=sep)
+     lst = df["caption"].tolist()
+     return lst
+
+ def run_val_inference(
+         pipe,
+         prompts,
+         guidance_scale,
+         total_steps,
+         test_latents=None):
+     with torch.no_grad():
+         for prompt in tqdm(prompts):
+             # No image is generated here: with output_type='latent' the pipeline stops at
+             # the latent representation, before the VAE decode
+             pipe(
+                 prompt,
+                 negative_prompt=NEGATIVE_PROMPTS[0],
+                 latents=test_latents,
+                 output_type='latent',
+                 guidance_scale=guidance_scale,
+                 num_inference_steps=total_steps)
+
+
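+ # Note: run_val_inference drives every calibration-style pass below (activation
+ # equalization, activation calibration, bias correction) over the calibration prompts.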
+ def main(args):
+
+     dtype = getattr(torch, args.dtype)
+
+     calibration_prompts = load_calib_prompts(args.calibration_prompt_path)
+     latents = torch.load(args.path_to_latents).to(torch.float16)
+
+     # Create a timestamped output dir under args.output_path
+     ts = datetime.fromtimestamp(time.time())
+     str_ts = ts.strftime("%Y%m%d_%H%M%S")
+     output_dir = os.path.join(args.output_path, f'{str_ts}')
+     os.mkdir(output_dir)
+
+     # Dump args to json
+     with open(os.path.join(output_dir, 'args.json'), 'w') as fp:
+         json.dump(vars(args), fp)
+
+     # Load model from float checkpoint
+     print(f"Loading model from {args.model}...")
+     pipe = DiffusionPipeline.from_pretrained(args.model, torch_dtype=dtype)
+     print(f"Model loaded from {args.model}.")
+
+     # Move model to target device
+     print(f"Moving model to {args.device}...")
+     pipe = pipe.to(args.device)
+
+     # Enable attention slicing
+     if args.attention_slicing:
+         pipe.enable_attention_slicing()
+
+     # Extract the list of layers to keep unquantized (time-embedding layers)
+     blacklist = []
+     for name, _ in pipe.unet.named_modules():
+         if 'time_emb' in name:
+             blacklist.append(name.split('.')[-1])
+     print(f"Blacklisted layers: {blacklist}")
+
+     # Make sure all LoRA layers have been fused first; otherwise raise an error
+     for m in pipe.unet.modules():
+         if hasattr(m, 'lora_layer') and m.lora_layer is not None:
+             raise RuntimeError("LoRA layers should be fused in before calling into quantization.")
+
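+     # Activation equalization (SmoothQuant-style): scaling factors are migrated between
+     # activations and weights to smooth out activation outliers before quantization;
+     # `alpha` balances how much of the scaling is pushed onto the weights.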
+     pipe.set_progress_bar_config(disable=True)
+     with activation_equalization_mode(
+             pipe.unet,
+             alpha=args.act_eq_alpha,
+             layerwise=True,
+             blacklist_layers=blacklist if args.exclude_blacklist_act_eq else None,
+             add_mul_node=True):
+         # Workaround to expose the `in_features` attribute from the KwargsForwardHook wrapper
+         for m in pipe.unet.modules():
+             if isinstance(m, KwargsForwardHook) and hasattr(m.module, 'in_features'):
+                 m.in_features = m.module.in_features
+         total_steps = args.calibration_steps
+         run_val_inference(
+             pipe,
+             calibration_prompts,
+             total_steps=total_steps,
+             test_latents=latents,
+             guidance_scale=args.guidance_scale)
+
+     # Workaround to expose the `in_features` attribute from the EqualizedModule wrapper
+     for m in pipe.unet.modules():
+         if isinstance(m, EqualizedModule) and hasattr(m.layer, 'in_features'):
+             m.in_features = m.layer.in_features
+
+     quant_layer_kwargs = {
+         'input_quant': InputQuant,
+         'weight_quant': WeightQuant,
+         'dtype': dtype,
+         'device': args.device,
+         'input_dtype': dtype,
+         'input_device': args.device}
+     quant_linear_kwargs = copy.deepcopy(quant_layer_kwargs)
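+     # Optional SDP quantization: replace diffusers' Attention modules with Brevitas
+     # QuantAttention so the softmax output can be quantized to FP8, and additionally
+     # quantize the outputs of the to_q/to_k/to_v projections feeding the attention.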
+     if args.quantize_sdp:
+         output_quant = OutputQuant
+         rewriter = ModuleToModuleByClass(
+             Attention,
+             QuantAttention,
+             softmax_output_quant=output_quant,
+             query_dim=lambda module: module.to_q.in_features,
+             dim_head=lambda module: int(1 / (module.scale ** 2)),
+             processor=AttnProcessor(),
+             is_equalized=True)
+         config.IGNORE_MISSING_KEYS = True
+         pipe.unet = rewriter.apply(pipe.unet)
+         config.IGNORE_MISSING_KEYS = False
+         pipe.unet = pipe.unet.to(args.device)
+         pipe.unet = pipe.unet.to(dtype)
+         what_to_quantize = ['to_q', 'to_k', 'to_v']
+         quant_linear_kwargs['output_quant'] = lambda module, name: output_quant if any(
+             ending in name for ending in what_to_quantize) else None
+         quant_linear_kwargs['output_dtype'] = dtype
+         quant_linear_kwargs['output_device'] = args.device
+
+     layer_map = {
+         nn.Linear: (qnn.QuantLinear, quant_linear_kwargs),
+         nn.Conv2d: (qnn.QuantConv2d, quant_layer_kwargs),
+         'diffusers.models.lora.LoRACompatibleLinear':
+             (LoRACompatibleQuantLinear, quant_layer_kwargs),
+         'diffusers.models.lora.LoRACompatibleConv':
+             (LoRACompatibleQuantConv2d, quant_layer_kwargs)}
+
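+     # layerwise_quantize swaps each float layer for its quantized counterpart according
+     # to the map above; blacklisted (time-embedding) layers are kept in full precision.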
+     pipe.unet = layerwise_quantize(
+         model=pipe.unet, compute_layer_map=layer_map, name_blacklist=blacklist)
+     print("Model quantization applied.")
+
+     pipe.set_progress_bar_config(disable=True)
+
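+     # Two passes over the calibration prompts: calibration_mode collects activation
+     # statistics to initialize the quantizers' scales, then bias_correction_mode adjusts
+     # biases to compensate for the quantization error introduced in each layer's output.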
+     print("Applying activation calibration")
+     with torch.no_grad(), calibration_mode(pipe.unet):
+         run_val_inference(
+             pipe,
+             calibration_prompts,
+             total_steps=args.calibration_steps,
+             test_latents=latents,
+             guidance_scale=args.guidance_scale)
+
+     print("Applying bias correction")
+     with torch.no_grad(), bias_correction_mode(pipe.unet):
+         run_val_inference(
+             pipe,
+             calibration_prompts,
+             total_steps=args.calibration_steps,
+             test_latents=latents,
+             guidance_scale=args.guidance_scale)
+
+     if args.checkpoint_name is not None:
+         torch.save(pipe.unet.state_dict(), os.path.join(output_dir, args.checkpoint_name))
+
+     if args.export_target:
+         export_quant_params(pipe, output_dir)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description='Stable Diffusion quantization')
+     parser.add_argument(
+         '-m',
+         '--model',
+         type=str,
+         default=None,
+         help='Path or name of the model.')
+     parser.add_argument(
+         '-d', '--device', type=str, default='cuda:0', help='Target device for the quantized model.')
+     parser.add_argument(
+         '--calibration-prompt-path', type=str, default=None, help='Path to the calibration prompt file.')
+     parser.add_argument(
+         '--checkpoint-name',
+         type=str,
+         default=None,
+         help='Name to use to store the checkpoint in the output dir. If not provided, no checkpoint is saved.')
+     parser.add_argument(
+         '--path-to-latents',
+         type=str,
+         default=None,
+         help='Load pre-defined latents. If not provided, they are generated based on an internal seed.')
+     parser.add_argument('--guidance-scale', type=float, default=8., help='Guidance scale.')
+     parser.add_argument(
+         '--calibration-steps', type=int, default=8, help='Number of inference steps used during calibration.')
+     add_bool_arg(
+         parser,
+         'output-path',
+         str_true=True,
+         default='.',
+         help='Path where the output folder is generated.')
+     parser.add_argument(
+         '--dtype',
+         default='float16',
+         choices=['float32', 'float16', 'bfloat16'],
+         help='Model dtype. Default: float16')
+     add_bool_arg(
+         parser,
+         'attention-slicing',
+         default=False,
+         help='Enable attention slicing. Default: Disabled')
+     add_bool_arg(
+         parser,
+         'export-target',
+         default=True,
+         help='Export the quantization parameters. Default: Enabled')
+     parser.add_argument(
+         '--act-eq-alpha',
+         type=float,
+         default=0.9,
+         help='Alpha for activation equalization. Default: 0.9')
+     add_bool_arg(
+         parser,
+         'quantize-sdp',
+         default=False,
+         help='Quantize the scaled dot-product attention. Default: Disabled')
+     add_bool_arg(
+         parser,
+         'exclude-blacklist-act-eq',
+         default=False,
+         help='Exclude unquantized layers from activation equalization. Default: Disabled')
+     args = parser.parse_args()
+     print("Args: " + str(vars(args)))
+     main(args)
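
A minimal invocation sketch (the model id and file names below are illustrative placeholders, not part of this commit):

python minimal_script.py -m <model-id-or-path> --calibration-prompt-path prompts.tsv --path-to-latents latents.pt --output-path . --quantize-sdp

The prompt file is read as a tab-separated table with a "caption" column, and the latents tensor is loaded with torch.load and cast to float16.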
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ accelerate==0.23.0
+ diffusers==0.21.2
+ open-clip-torch==2.7.0
+ opencv-python==4.8.1.78
+ pycocotools==2.0.7
+ scipy==1.9.1
+ torchmetrics[image]==1.2.0
+ tqdm
+ transformers==4.33.2
+ brevitas @ git+https://github.com/Xilinx/brevitas@dev