# Source: Mixtral_ether / prune_preserve_quantize_registry.py
# Uploaded by jeduardogruiz ("Upload 22 files", commit 516a027, 11.8 kB).
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Registry responsible for built-in keras classes."""
import tensorflow as tf
from tensorflow_model_optimization.python.core.keras.compat import keras
from tensorflow_model_optimization.python.core.quantization.keras import quant_ops
from tensorflow_model_optimization.python.core.quantization.keras import quantizers
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import (
default_8bit_quantize_registry,)
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import (
default_8bit_quantizers,)
# Shorthand for the Keras layers module obtained through the compat `keras`
# import; its classes are used as keys of the registry's layer-config map.
layers = keras.layers
class _PrunePreserveInfo(object):
"""PrunePreserveInfo."""
def __init__(self, weight_attrs, quantize_config_attrs):
"""Initializes PrunePreserveInfo.
Args:
weight_attrs: list of sparsity preservable weight attributes of layer.
quantize_config_attrs: list of quantization configuration class name.
"""
self.weight_attrs = weight_attrs
self.quantize_config_attrs = quantize_config_attrs
class PrunePreserveQuantizeRegistry():
  """PrunePreserveQuantizeRegistry responsible for built-in keras layers.

  Maps supported built-in Keras layer classes to the weight attributes whose
  sparsity should be preserved during quantization, and to the quantize
  config class names for which that preservation is applied.
  """

  # Keys are built-in Keras layer classes. In each `_PrunePreserveInfo`, the
  # first field lists the layer attributes that hold the kernel weights and
  # the second lists the class names of the layer's quantization configs.
  # Together they decide which layer weights, under which quantization
  # configs, are sparsity preservable.
  _LAYERS_CONFIG_MAP = {
      layers.Conv2D:
          _PrunePreserveInfo(['kernel'], ['Default8BitConvQuantizeConfig']),
      layers.Dense:
          _PrunePreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),
      # DepthwiseConv2D is supported by 8bit QAT but not by pruning, so for
      # DepthwiseConv2D PQAT weight sparsity preservation is disabled (it is
      # listed in `_DISABLE_PRUNE_PRESERVE`).
      layers.DepthwiseConv2D:
          _PrunePreserveInfo(['depthwise_kernel'],
                             ['Default8BitQuantizeConfig']),
      # Layers that are supported by pruning but not yet by QAT:
      # layers.Conv1D:
      # _PrunePreserveInfo(['kernel'], []),
      # layers.Conv2DTranspose:
      # _PrunePreserveInfo(['kernel'], []),
      # layers.Conv3D:
      # _PrunePreserveInfo(['kernel'], []),
      # layers.Conv3DTranspose:
      # _PrunePreserveInfo(['kernel'], []),
      # layers.LocallyConnected1D:
      # _PrunePreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),
      # layers.LocallyConnected2D:
      # _PrunePreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),
      # SeparableConv still needs verification against 8bit QAT:
      # layers.SeparableConv1D:
      # _PrunePreserveInfo(['pointwise_kernel'],
      #                    ['Default8BitConvQuantizeConfig']),
      # layers.SeparableConv2D:
      # _PrunePreserveInfo(['pointwise_kernel'],
      #                    ['Default8BitConvQuantizeConfig']),
      # Embedding still needs verification against 8bit QAT:
      # layers.Embedding: _PrunePreserveInfo(['embeddings'], []),
  }

  # Layer classes that are in `_LAYERS_CONFIG_MAP` but whose quantize config
  # is returned unchanged by `apply_sparsity_preserve_quantize_config`.
  _DISABLE_PRUNE_PRESERVE = frozenset({
      layers.DepthwiseConv2D,
  })

  def __init__(self):
    # Quantize config class name -> sparsity-preserving weight quantizer.
    # One quantizer instance per config name is created here and assigned to
    # every matching config by `apply_sparsity_preserve_quantize_config`.
    self._config_quantizer_map = {
        'Default8BitQuantizeConfig':
            PrunePreserveDefault8BitWeightsQuantizer(),
        'Default8BitConvQuantizeConfig':
            PrunePreserveDefault8BitConvWeightsQuantizer(),
    }

  @classmethod
  def _no_trainable_weights(cls, layer):
    """Returns whether this layer has NO trainable weights.

    Args:
      layer: The layer to check for trainable weights.

    Returns:
      True if `layer.trainable_weights` is empty, False otherwise.
    """
    return not layer.trainable_weights

  @classmethod
  def _disable_prune_preserve(cls, layer):
    """Returns whether prune preservation is disabled for this layer.

    Args:
      layer: The layer to check.

    Returns:
      True if the layer's exact class is in `_DISABLE_PRUNE_PRESERVE`.
    """
    return layer.__class__ in cls._DISABLE_PRUNE_PRESERVE

  @classmethod
  def supports(cls, layer):
    """Returns whether the registry supports this layer type.

    Args:
      layer: The layer to check for support.

    Returns:
      True if the layer has no trainable weights or its exact class is a key
      of `_LAYERS_CONFIG_MAP`; False otherwise. Subclasses of the registered
      layer classes are not matched.
    """
    # Layers without trainable weights are considered supported,
    # e.g., ReLU, Softmax, and AveragePooling2D.
    if cls._no_trainable_weights(layer):
      return True
    return layer.__class__ in cls._LAYERS_CONFIG_MAP

  @classmethod
  def _weight_names(cls, layer):
    """Gets the names of the layer's sparsity-preservable weight attributes.

    Raises:
      KeyError: if the layer has trainable weights but its class is not in
        `_LAYERS_CONFIG_MAP`.
    """
    if cls._no_trainable_weights(layer):
      return []
    return cls._LAYERS_CONFIG_MAP[layer.__class__].weight_attrs

  @classmethod
  def get_sparsity_preservable_weights(cls, layer):
    """Gets sparsity preservable weights from keras layer.

    Layers listed in `_DISABLE_PRUNE_PRESERVE` still report their weights
    here; the disable list is only consulted by
    `apply_sparsity_preserve_quantize_config`.

    Args:
      layer: instance of keras layer

    Returns:
      List of the layer attribute values named in the registry entry (empty
      for layers without trainable weights).

    Raises:
      KeyError: if the layer has trainable weights but is not registered.
    """
    return [getattr(layer, weight) for weight in cls._weight_names(layer)]

  @classmethod
  def get_suppport_quantize_config_names(cls, layer):
    """Gets class names of the supported quantize configs for a layer.

    Args:
      layer: instance of keras layer

    Returns:
      List of supported quantize config class names (empty for layers
      without trainable weights).

    Raises:
      KeyError: if the layer has trainable weights but is not registered.
    """
    # Layers without trainable weights don't need a quantize_config for PQAT.
    if cls._no_trainable_weights(layer):
      return []
    return cls._LAYERS_CONFIG_MAP[layer.__class__].quantize_config_attrs

  def apply_sparsity_preserve_quantize_config(self, layer, quantize_config):
    """Swaps in a sparsity-preserving weight quantizer when applicable.

    Layers without trainable weights, and layer types listed in
    `_DISABLE_PRUNE_PRESERVE`, get `quantize_config` back unchanged.
    Otherwise `quantize_config.weight_quantizer` is replaced in place with
    the sparsity-preserving quantizer registered for the config's class name.

    Args:
      layer: The keras layer the config belongs to.
      quantize_config: Quantization config for `layer`. It is mutated in
        place when sparsity preservation applies.

    Raises:
      ValueError: if the layer type is not supported by this registry, or if
        the config's class is not one of the configs registered for the
        layer type.

    Returns:
      `quantize_config` (the same object that was passed in).
    """
    layer_name = layer.__class__.__name__
    if not self.supports(layer):
      raise ValueError('Layer {} is not supported.'.format(layer_name))

    if (self._no_trainable_weights(layer) or
        self._disable_prune_preserve(layer)):
      return quantize_config

    config_name = quantize_config.__class__.__name__
    allowed_configs = (
        self._LAYERS_CONFIG_MAP[layer.__class__].quantize_config_attrs)
    if config_name not in allowed_configs:
      raise ValueError('Configuration {} is not supported for Layer {}.'
                       .format(config_name, layer_name))

    quantize_config.weight_quantizer = self._config_quantizer_map[config_name]
    return quantize_config
class Default8bitPrunePreserveQuantizeRegistry(PrunePreserveQuantizeRegistry):
  """Prune-preserve registry built on the default 8-bit quantize registry."""

  def get_quantize_config(self, layer):
    """Returns the default 8-bit quantize config for `layer`, sparsity-aware.

    Args:
      layer: input keras layer to produce a quantize config for.

    Returns:
      The config produced by `Default8BitQuantizeRegistry.get_quantize_config`
      after `apply_sparsity_preserve_quantize_config` has (where applicable)
      replaced its weight quantizer with a sparsity-preserving one.

    Raises:
      ValueError: propagated from `apply_sparsity_preserve_quantize_config`
        when the layer or its config is not supported.
    """
    base_registry = (
        default_8bit_quantize_registry.Default8BitQuantizeRegistry())
    base_config = base_registry.get_quantize_config(layer)
    return self.apply_sparsity_preserve_quantize_config(layer, base_config)
class PrunePreserveDefaultWeightsQuantizer(quantizers.LastValueQuantizer):
  """Last-value weight quantizer that keeps zeroed (pruned) weights at zero."""

  def __init__(self, num_bits, per_axis, symmetric, narrow_range):
    """Initializes PrunePreserveDefaultWeightsQuantizer.

    Args:
      num_bits: Number of bits for quantization.
      per_axis: Whether to apply per_axis quantization. The last dimension is
        used as the axis.
      symmetric: If true, use symmetric quantization limits instead of
        training the minimum and maximum of each quantization range
        separately.
      narrow_range: In case of 8 bits, narrow_range nudges the quantized
        range to be [-127, 127] instead of [-128, 127]. This ensures a
        symmetric range has 0 as the centre.
    """
    # LastValueQuantizer.__init__ is called explicitly rather than through
    # super(): with the multiple inheritance used by
    # PrunePreserveDefault8BitConvWeightsQuantizer, super() from here would
    # not necessarily resolve to LastValueQuantizer.
    quantizers.LastValueQuantizer.__init__(
        self, num_bits, per_axis, symmetric, narrow_range)

  def _build_sparsity_mask(self, name, layer):
    """Returns {'sparsity_mask': mask} built from the wrapped layer's weight.

    For finite weights the mask is 1 where the weight is nonzero and 0 where
    it is zero.
    """
    kernel = getattr(layer.layer, name)
    # kernel / kernel is 1 for every nonzero entry; divide_no_nan yields 0
    # where the denominator (the weight itself) is 0.
    return {'sparsity_mask': tf.math.divide_no_nan(kernel, kernel)}

  def build(self, tensor_shape, name, layer):
    """Creates the sparsity mask together with the quantization params.

    Args:
      tensor_shape: Shape of the weights to be quantized.
      name: Name of the weight attribute on the wrapped layer.
      layer: Quantization-wrapped keras layer.

    Returns:
      Dictionary holding 'sparsity_mask' plus the entries returned by the
      parent quantizer's `build`; it is passed to `__call__` as `weights`.
    """
    mask_entry = self._build_sparsity_mask(name, layer)
    quant_params = super().build(tensor_shape, name, layer)
    return {**mask_entry, **quant_params}

  def __call__(self, inputs, training, weights, **kwargs):
    """Quantizes `inputs` after zeroing the positions masked out at build.

    Args:
      inputs: Input tensor (the layer's weights) to be quantized.
      training: Whether the graph is currently training.
      weights: Dictionary produced by `build`; must contain 'sparsity_mask',
        'min_var' and 'max_var'.
      **kwargs: Additional variables which may be passed to the quantizer
        (unused here).

    Returns:
      The quantized tensor.
    """
    masked_inputs = tf.multiply(inputs, weights['sparsity_mask'])
    range_min = weights['min_var']
    range_max = weights['max_var']
    return quant_ops.LastValueQuantize(
        masked_inputs,
        range_min,
        range_max,
        is_training=training,
        num_bits=self.num_bits,
        per_channel=self.per_axis,
        symmetric=self.symmetric,
        narrow_range=self.narrow_range,
    )
class PrunePreserveDefault8BitWeightsQuantizer(
    PrunePreserveDefaultWeightsQuantizer):
  """Sparsity-preserving weight quantizer with the default 8-bit settings."""

  def __init__(self):
    """Uses 8 bits, per-tensor (not per-axis), symmetric, narrow range."""
    super().__init__(
        narrow_range=True,
        symmetric=True,
        per_axis=False,
        num_bits=8,
    )
class PrunePreserveDefault8BitConvWeightsQuantizer(
    PrunePreserveDefaultWeightsQuantizer,
    default_8bit_quantizers.Default8BitConvWeightsQuantizer,
):
  """Sparsity-preserving quantizer for default 8bit Conv2D/DepthwiseConv2D weights.

  Initialization and quantization-parameter creation come from
  `Default8BitConvWeightsQuantizer`. The sparsity mask comes from
  `PrunePreserveDefaultWeightsQuantizer`, and so does the masked `__call__`,
  since that class is first in the MRO.
  """

  # pylint: disable=super-init-not-called
  def __init__(self):
    # Only the Conv quantizer's initializer runs; both bases share the same
    # super, so PrunePreserveDefaultWeightsQuantizer.__init__ is skipped.
    conv_base = default_8bit_quantizers.Default8BitConvWeightsQuantizer
    conv_base.__init__(self)

  def build(self, tensor_shape, name, layer):
    """Builds the sparsity mask plus the Conv quantizer's own parameters."""
    conv_base = default_8bit_quantizers.Default8BitConvWeightsQuantizer
    mask_entry = PrunePreserveDefaultWeightsQuantizer._build_sparsity_mask(
        self, name, layer)
    conv_params = conv_base.build(self, tensor_shape, name, layer)
    return {**mask_entry, **conv_params}