# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Registry responsible for built-in keras classes."""

import tensorflow as tf

from tensorflow_model_optimization.python.core.keras.compat import keras
from tensorflow_model_optimization.python.core.quantization.keras import quant_ops
from tensorflow_model_optimization.python.core.quantization.keras import quantizers
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import (
    default_8bit_quantize_registry,)
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import (
    default_8bit_quantizers,)


layers = keras.layers


class _PrunePreserveInfo(object):
  """PrunePreserveInfo."""

  def __init__(self, weight_attrs, quantize_config_attrs):
    """Initializes PrunePreserveInfo.

    Args:
      weight_attrs: List of sparsity preservable weight attributes of the
        layer.
      quantize_config_attrs: List of supported quantization configuration
        class names.
    """
    self.weight_attrs = weight_attrs
    self.quantize_config_attrs = quantize_config_attrs


class PrunePreserveQuantizeRegistry():
  """PrunePreserveQuantizeRegistry responsible for built-in keras layers."""

  # The keys are built-in keras layers; the first value of each
  # _PrunePreserveInfo lists the layer attributes that hold the kernel
  # weights, and the second value lists the class names of the supported
  # quantization configurations. Together they determine which layer weights
  # are sparsity preservable under which quantize configs.
  _LAYERS_CONFIG_MAP = {
      layers.Conv2D:
      _PrunePreserveInfo(['kernel'], ['Default8BitConvQuantizeConfig']),
      layers.Dense:
      _PrunePreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),

      # DepthwiseConv2D is supported with 8-bit QAT, but not with pruning;
      # thus, for DepthwiseConv2D PQAT, weight sparsity preservation is
      # disabled.
      layers.DepthwiseConv2D:
      _PrunePreserveInfo(['depthwise_kernel'], ['Default8BitQuantizeConfig']),

      # Layers that are supported with pruning, but not yet with QAT.
      # layers.Conv1D:
      # _PrunePreserveInfo(['kernel'], []),
      # layers.Conv2DTranspose:
      # _PrunePreserveInfo(['kernel'], []),
      # layers.Conv3D:
      # _PrunePreserveInfo(['kernel'], []),
      # layers.Conv3DTranspose:
      # _PrunePreserveInfo(['kernel'], []),
      # layers.LocallyConnected1D:
      # _PrunePreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),
      # layers.LocallyConnected2D:
      # _PrunePreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),

      # SeparableConv needs verification with 8-bit QAT.
      # layers.SeparableConv1D:
      # _PrunePreserveInfo(['pointwise_kernel'], \
      #   ['Default8BitConvQuantizeConfig']),
      # layers.SeparableConv2D:
      # _PrunePreserveInfo(['pointwise_kernel'], \
      #   ['Default8BitConvQuantizeConfig']),

      # Embedding needs verification with 8-bit QAT.
      # layers.Embedding: _PrunePreserveInfo(['embeddings'], []),
  }

  _DISABLE_PRUNE_PRESERVE = frozenset({
      layers.DepthwiseConv2D,
  })
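
  # Illustrative consequences of the tables above (a sketch, not exercised
  # code; `conv`, `dw_conv`, and `relu` stand for hypothetical layer
  # instances):
  #   supports(conv)     -> True; the Conv2D kernel's sparsity is preserved.
  #   supports(dw_conv)  -> True, but DepthwiseConv2D is listed in
  #                         _DISABLE_PRUNE_PRESERVE, so its quantize config is
  #                         returned unchanged (no sparsity mask applied).
  #   supports(relu)     -> True, because ReLU has no trainable weights.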

  def __init__(self):

    self._config_quantizer_map = {
        'Default8BitQuantizeConfig':
        PrunePreserveDefault8BitWeightsQuantizer(),
        'Default8BitConvQuantizeConfig':
        PrunePreserveDefault8BitConvWeightsQuantizer(),
    }

  @classmethod
  def _no_trainable_weights(cls, layer):
    """Returns whether this layer has trainable weights.

    Args:
      layer: The layer to check for trainable weights.

    Returns:
      True/False whether the layer has trainable weights.
    """
    return not layer.trainable_weights

  @classmethod
  def _disable_prune_preserve(cls, layer):
    """Returns whether disable this layer for prune preserve.

    Args:
      layer: The layer to check for disable.

    Returns:
      True/False whether disable this layer for prune preserve.
    """

    return layer.__class__ in cls._DISABLE_PRUNE_PRESERVE

  @classmethod
  def supports(cls, layer):
    """Returns whether the registry supports this layer type.

    Args:
      layer: The layer to check for support.

    Returns:
      True/False whether the layer type is supported.
    """

    # layers without trainable weights are considered supported,
    # e.g., ReLU, Softmax, and AveragePooling2D.
    if cls._no_trainable_weights(layer):
      return True

    if layer.__class__ in cls._LAYERS_CONFIG_MAP:
      return True

    return False

  @classmethod
  def _weight_names(cls, layer):
    """Gets the weight names."""
    if cls._no_trainable_weights(layer):
      return []

    return cls._LAYERS_CONFIG_MAP[layer.__class__].weight_attrs

  @classmethod
  def get_sparsity_preservable_weights(cls, layer):
    """Gets sparsity preservable weights from keras layer.

    Args:
      layer: instance of keras layer

    Returns:
      List of sparsity preservable weights
    """
    return [getattr(layer, weight) for weight in cls._weight_names(layer)]

  @classmethod
  def get_suppport_quantize_config_names(cls, layer):
    """Gets class name of supported quantize config for layer.

    Args:
      layer: instance of keras layer

    Returns:
      List of supported quantize config class name.
    """

    # layers without trainable weights don't need quantize_config for pqat
    if cls._no_trainable_weights(layer):
      return []

    return cls._LAYERS_CONFIG_MAP[layer.__class__].quantize_config_attrs

  def apply_sparsity_preserve_quantize_config(self, layer, quantize_config):
    """Applies weights sparsity preservation.

    Args:
      layer: The layer to check for support.
      quantize_config: quantization config to check for support,
        apply sparsity preservation to pruned weights
    Raises:
      ValueError when layer is supported does not have quantization config.
    Returns:
      Returns quantize_config with addon sparsity preserve weight_quantizer.
    """
    if self.supports(layer):
      if (self._no_trainable_weights(layer) or
          self._disable_prune_preserve(layer)):
        return quantize_config
      if (quantize_config.__class__.__name__
          in self._LAYERS_CONFIG_MAP[layer.__class__].quantize_config_attrs):
        quantize_config.weight_quantizer = self._config_quantizer_map[
            quantize_config.__class__.__name__]
      else:
        raise ValueError('Configuration {} is not supported for Layer {}.'
                         .format(str(quantize_config.__class__.__name__),
                                 str(layer.__class__.__name__)))
    else:
      raise ValueError('Layer {} is not supported.'.format(
          str(layer.__class__.__name__)))

    return quantize_config


class Default8bitPrunePreserveQuantizeRegistry(PrunePreserveQuantizeRegistry):
  """Default 8 bit PrunePreserveQuantizeRegistry."""

  def get_quantize_config(self, layer):
    """Returns the quantization config with addon sparsity.

    Args:
      layer: input layer to return quantize config for.

    Returns:
      Returns the quantization config with sparsity preserve weight_quantizer.
    """
    quantize_config = (default_8bit_quantize_registry
                       .Default8BitQuantizeRegistry()
                       .get_quantize_config(layer))
    prune_aware_quantize_config = self.apply_sparsity_preserve_quantize_config(
        layer, quantize_config)

    return prune_aware_quantize_config
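

# A minimal usage sketch (assumption: `layer` is a built-in keras layer, e.g.
# a pruned Conv2D, that is being prepared for 8-bit quantization-aware
# training; the variable names are hypothetical):
#
#   registry = Default8bitPrunePreserveQuantizeRegistry()
#   if registry.supports(layer):
#     # Default 8-bit quantize config for the layer, with its weight quantizer
#     # replaced by a sparsity-preserving variant.
#     quantize_config = registry.get_quantize_config(layer)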


class PrunePreserveDefaultWeightsQuantizer(quantizers.LastValueQuantizer):
  """Quantize weights while preserve sparsity."""

  def __init__(self, num_bits, per_axis, symmetric, narrow_range):
    """Initializes PrunePreserveDefaultWeightsQuantizer.

    Args:
      num_bits: Number of bits for quantization
      per_axis: Whether to apply per_axis quantization. The last dimension is
        used as the axis.
      symmetric: If true, use symmetric quantization limits instead of training
        the minimum and maximum of each quantization range separately.
      narrow_range: In case of 8 bits, narrow_range nudges the quantized range
        to be [-127, 127] instead of [-128, 127]. This ensures the symmetric
        range is centred at 0.
    """
    quantizers.LastValueQuantizer.__init__(self, num_bits, per_axis, symmetric,
                                           narrow_range)

  def _build_sparsity_mask(self, name, layer):
    # Derive a {0, 1} mask from the current (pruned) weights: divide_no_nan
    # yields 1.0 for non-zero entries and 0.0 where a weight is exactly zero.
    weights = getattr(layer.layer, name)
    sparsity_mask = tf.math.divide_no_nan(weights, weights)

    return {'sparsity_mask': sparsity_mask}
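
  # Illustrative values (assumed, not taken from the library) for the mask
  # built above:
  #   weights       = [[ 0.0, 0.5],
  #                    [-0.3, 0.0]]
  #   sparsity_mask = [[ 0.0, 1.0],
  #                    [ 1.0, 0.0]]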

  def build(self, tensor_shape, name, layer):
    """Constructs mask to preserve weights sparsity.

    Args:
      tensor_shape: Shape of weights which needs to be quantized.
      name: Name of weights in layer.
      layer: quantization wrapped keras layer.

    Returns:
      Dictionary of constructed sparsity mask and
      quantization params, the dictionary will be passed
      to __call__ function.
    """
    result = self._build_sparsity_mask(name, layer)
    result.update(
        super(PrunePreserveDefaultWeightsQuantizer,
              self).build(tensor_shape, name, layer))
    return result

  def __call__(self, inputs, training, weights, **kwargs):
    """Applies sparsity preserved quantization to the input tensor.

    Args:
      inputs: Input tensor (layer's weights) to be quantized.
      training: Whether the graph is currently training.
      weights: Dictionary of weights (params) the quantizer can use to
        quantize the tensor (layer's weights). This contains the weights
        created in the `build` function.
      **kwargs: Additional variables which may be passed to the quantizer.

    Returns:
      quantized tensor.
    """

    prune_preserve_inputs = tf.multiply(inputs, weights['sparsity_mask'])

    return quant_ops.LastValueQuantize(
        prune_preserve_inputs,
        weights['min_var'],
        weights['max_var'],
        is_training=training,
        num_bits=self.num_bits,
        per_channel=self.per_axis,
        symmetric=self.symmetric,
        narrow_range=self.narrow_range,
    )


class PrunePreserveDefault8BitWeightsQuantizer(
    PrunePreserveDefaultWeightsQuantizer):
  """PrunePreserveWeightsQuantizer for default 8bit weights."""

  def __init__(self):
    super(PrunePreserveDefault8BitWeightsQuantizer,
          self).__init__(num_bits=8,
                         per_axis=False,
                         symmetric=True,
                         narrow_range=True)
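

# A minimal sketch of how this quantizer is driven (normally the quantize
# wrapper performs these calls; `wrapped_layer` and `kernel` are hypothetical
# names for a quantize-wrapped layer and its pruned kernel variable):
#
#   quantizer = PrunePreserveDefault8BitWeightsQuantizer()
#   weights = quantizer.build(kernel.shape, 'kernel', wrapped_layer)
#   quantized_kernel = quantizer(kernel, training=True, weights=weights)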


class PrunePreserveDefault8BitConvWeightsQuantizer(
    PrunePreserveDefaultWeightsQuantizer,
    default_8bit_quantizers.Default8BitConvWeightsQuantizer,):
  """PrunePreserveWeightsQuantizer for default 8bit Conv2D/DepthwiseConv2D weights."""

  # pylint: disable=super-init-not-called
  def __init__(self):
    # Skip PrunePreserveDefaultWeightsQuantizer.__init__ since both parent
    # classes share the same super class (LastValueQuantizer).
    default_8bit_quantizers.Default8BitConvWeightsQuantizer.__init__(self)

  def build(self, tensor_shape, name, layer):
    result = PrunePreserveDefaultWeightsQuantizer._build_sparsity_mask(
        self, name, layer)
    result.update(
        default_8bit_quantizers.Default8BitConvWeightsQuantizer.build(
            self, tensor_shape, name, layer))
    return result