Source code for sparseml.keras.optim.mask_pruning

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import collections
import inspect
from typing import List, Union

import tensorflow

from sparseml.keras.optim.mask_pruning_creator import (
    PruningMaskCreator,
    load_mask_creator,
)
from sparseml.keras.utils import keras


__all__ = [
    "MaskedLayer",
    "PruningScheduler",
    "remove_pruning_masks",
]


class PruningScheduler(abc.ABC):
    """
    Abstract pruning scheduler
    """

    _REGISTRY = {}

    def __init_subclass__(cls):
        super().__init_subclass__()
        PruningScheduler._register_class(cls)
    @abc.abstractmethod
    def should_prune(self, step: int) -> bool:
        """
        Check if the given step is the right time for pruning

        :param step: training step
        :return: True if pruning should take place; False otherwise
        """
        raise NotImplementedError("Not implemented")
    @abc.abstractmethod
    def target_sparsity(self, step: int, **kwargs) -> float:
        """
        Compute the target sparsity at the given step

        :param step: training step
        :param kwargs: optional keyword params that a specific scheduler might need
        :return: target sparsity
        """
        raise NotImplementedError("Not implemented")
    @abc.abstractmethod
    def get_config(self):
        raise NotImplementedError("Not implemented")
    @classmethod
    def deserialize(cls, config):
        """
        Deserialize a pruning scheduler from the config returned by the
        scheduler's get_config method

        :param config: a pruning scheduler's config
        :return: a pruning scheduler instance
        """
        if "class_name" not in config:
            raise ValueError("'class_name' not found in config: {}".format(config))
        class_name = config["class_name"]
        return keras.utils.deserialize_keras_object(
            config,
            module_objects=globals(),
            custom_objects={class_name: PruningScheduler._REGISTRY[class_name]},
        )
    @classmethod
    def _register_class(cls, target_cls):
        PruningScheduler._REGISTRY[target_cls.__name__] = target_cls
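A minimal sketch of a concrete scheduler built on the abstract class above; the ConstantSparsityScheduler name and its constructor arguments are illustrative, not part of this module. Subclassing PruningScheduler auto-registers the class via __init_subclass__, which is what lets PruningScheduler.deserialize look it up later:

class ConstantSparsityScheduler(PruningScheduler):
    """Illustrative scheduler: prune to a fixed sparsity every `interval` steps"""

    def __init__(self, sparsity: float = 0.9, interval: int = 100):
        self._sparsity = sparsity
        self._interval = interval

    def should_prune(self, step: int) -> bool:
        # prune periodically once training has started
        return step >= 0 and step % self._interval == 0

    def target_sparsity(self, step: int, **kwargs) -> float:
        # constant target; a real scheduler might ramp this up over time
        return self._sparsity

    def get_config(self):
        # "class_name" is required by PruningScheduler.deserialize above,
        # which rebuilds the instance as cls(**config["config"])
        return {
            "class_name": self.__class__.__name__,
            "config": {"sparsity": self._sparsity, "interval": self._interval},
        }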
MaskedParamInfo = collections.namedtuple(
    "MaskedParamInfo", ["name", "param", "mask", "sparsity"]
)


class MaskAndWeightUpdater:
    """
    Core logic of updating masks and weights

    :param pruning_vars: a list of MaskedParamInfo tuples, each holding a
        parameter's name, weight tensor, mask and sparsity
    :param pruning_scheduler: a pruning scheduler
    :param mask_creator: a mask creator
    :param global_step: a global step tensor
    """

    def __init__(
        self,
        pruning_vars: List[MaskedParamInfo],
        pruning_scheduler: PruningScheduler,
        mask_creator: PruningMaskCreator,
        global_step: tensorflow.Tensor,
    ):
        self._pruning_vars = pruning_vars
        self._pruning_scheduler = pruning_scheduler
        self._mask_creator = mask_creator
        self._global_step = global_step
        self._update_ready = None

    def _is_pruning_step(self) -> bool:
        global_step_val = keras.backend.get_value(self._global_step)
        assert global_step_val >= 0
        update_ready = self._pruning_scheduler.should_prune(global_step_val)
        return update_ready

    def _conditional_training_update(self):
        def _no_update_masks_and_weights():
            return tensorflow.no_op("no_update")

        def _update_masks_and_weights():
            assignments = []
            global_step_val = keras.backend.get_value(self._global_step)
            for masked_param_info in self._pruning_vars:
                new_sparsity = self._pruning_scheduler.target_sparsity(
                    global_step_val
                )
                new_mask = self._mask_creator.create_sparsity_mask(
                    masked_param_info.param, new_sparsity
                )
                assignments.append(masked_param_info.mask.assign(new_mask))
                assignments.append(masked_param_info.sparsity.assign(new_sparsity))
                masked_param = tensorflow.math.multiply(
                    masked_param_info.param, masked_param_info.mask
                )
                assignments.append(masked_param_info.param.assign(masked_param))
            return tensorflow.group(assignments)

        update_ready = self._is_pruning_step()
        self._update_ready = update_ready
        return tensorflow.cond(
            tensorflow.cast(update_ready, tensorflow.bool),
            _update_masks_and_weights,
            _no_update_masks_and_weights,
        )

    def apply_masks(self):
        """
        Apply masks to the weights
        """
        assignments = []
        for masked_param_info in self._pruning_vars:
            masked_param = tensorflow.math.multiply(
                masked_param_info.param, masked_param_info.mask
            )
            assignments.append(masked_param_info.param.assign(masked_param))
        return tensorflow.group(assignments)

    def conditional_update(self, training=None):
        """
        Conditionally update masks and weights

        :param training: if in training mode
        """

        def _update():
            with tensorflow.control_dependencies(
                [self._conditional_training_update()]
            ):
                return tensorflow.no_op("update")

        def _no_update():
            return tensorflow.no_op("no_update")

        training = keras.backend.learning_phase() if training is None else training
        return tensorflow.cond(
            tensorflow.cast(training, tensorflow.bool), _update, _no_update
        )


_LAYER_PRUNABLE_PARAMS_MAP = {
    keras.layers.Conv1D: ["kernel"],
    keras.layers.Conv2D: ["kernel"],
    keras.layers.Conv2DTranspose: ["kernel"],
    keras.layers.Conv3D: ["kernel"],
    keras.layers.Conv3DTranspose: ["kernel"],
    keras.layers.Dense: ["kernel"],
    keras.layers.Embedding: ["embeddings"],
    keras.layers.LocallyConnected1D: ["kernel"],
    keras.layers.LocallyConnected2D: ["kernel"],
    keras.layers.SeparableConv1D: ["pointwise_kernel"],
    keras.layers.SeparableConv2D: ["pointwise_kernel"],
}


def _get_default_prunable_params(layer: keras.layers.Layer):
    if layer.__class__ in _LAYER_PRUNABLE_PARAMS_MAP:
        prunable_param_names = _LAYER_PRUNABLE_PARAMS_MAP[layer.__class__]
        return {
            "{}/{}".format(layer.name, param_name): getattr(layer, param_name)
            for param_name in prunable_param_names
        }
    else:
        expected_layers = list(_LAYER_PRUNABLE_PARAMS_MAP.keys())
        raise ValueError(
            "Layer {} cannot be pruned. Expected layers: {}".format(
                layer, expected_layers
            )
        )
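For illustration, a sketch of the name-to-variable mapping _get_default_prunable_params produces for a built-in Dense layer; the layer name "classifier" and the shapes are arbitrary:

dense = keras.layers.Dense(10, name="classifier")
dense.build(input_shape=(None, 128))
prunable = _get_default_prunable_params(dense)
# {"classifier/kernel": <tf.Variable ... shape=(128, 10)>}; the bias is not
# prunable by default, and an unsupported layer type raises ValueError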
class MaskedLayer(keras.layers.Wrapper):
    """
    Masked layer is a layer wrapping around another layer with a mask; the
    mask however is shared if the enclosed layer is again of MaskedLayer type

    :param layer: either a MaskedLayer or a keras layer
    :param pruning_scheduler: a pruning scheduler
    :param mask_type: the mask type used to load the mask creator,
        e.g. "unstructured"
    :param kwargs: optional params for keras layer constructor, e.g. layer name
    """

    def __init__(
        self,
        layer: keras.layers.Layer,
        pruning_scheduler: PruningScheduler,
        mask_type: Union[str, List[int]] = "unstructured",
        **kwargs,
    ):
        if not isinstance(layer, MaskedLayer) and not isinstance(
            layer, keras.layers.Layer
        ):
            raise ValueError(
                "Invalid layer passed in, expected MaskedLayer or a keras Layer, "
                "but got {}".format(layer)
            )
        super(MaskedLayer, self).__init__(layer, **kwargs)
        self._layer = layer
        self._pruning_scheduler = pruning_scheduler
        self._mask_type = mask_type
        self._mask_creator = None
        self._pruning_vars = []
        self._global_step = None
        self._mask_updater = None
    def build(self, input_shape):
        super(MaskedLayer, self).build(input_shape)
        self._mask_creator = load_mask_creator(self._mask_type)
        self._pruning_vars = self._reuse_or_create_pruning_vars()
        self._global_step = self.add_weight(
            "global_step",
            shape=[],
            initializer=keras.initializers.Constant(-1),
            dtype=tensorflow.int64,
            trainable=False,
        )
        self._mask_updater = MaskAndWeightUpdater(
            self._pruning_vars,
            self._pruning_scheduler,
            self._mask_creator,
            self._global_step,
        )
    def _reuse_or_create_pruning_vars(
        self,
    ) -> List[MaskedParamInfo]:
        if isinstance(self._layer, MaskedLayer):
            # All nested masked layers reuse the pruning vars created
            # for the "core", inner-most, Keras built-in layer
            return self._layer.pruning_vars
        assert isinstance(self._layer, keras.layers.Layer)
        prunable_params = _get_default_prunable_params(self._layer)
        pruning_vars = []
        for name, param in prunable_params.items():
            mask = self.add_weight(
                "mask",
                shape=param.shape,
                initializer=keras.initializers.get("ones"),
                dtype=param.dtype,
                trainable=False,
            )
            sparsity = self.add_weight(
                "sparsity",
                shape=[],
                initializer=keras.initializers.get("zeros"),
                dtype=param.dtype,
                trainable=False,
            )
            pruning_vars.append(MaskedParamInfo(name, param, mask, sparsity))
        return pruning_vars
    def call(self, inputs: tensorflow.Tensor, training=None):
        """
        Forward function for calling layer instance as function
        """
        training = keras.backend.learning_phase() if training is None else training

        def _apply_masks_to_weights():
            with tensorflow.control_dependencies(
                [self._mask_updater.apply_masks()]
            ):
                return tensorflow.no_op("update")

        def _no_apply_masks_to_weights():
            return tensorflow.no_op("no_update_masks")

        tensorflow.cond(
            tensorflow.cast(training, tensorflow.bool),
            _apply_masks_to_weights,
            _no_apply_masks_to_weights,
        )
        args = inspect.getfullargspec(self._layer.call).args
        if "training" in args:
            return self._layer.call(inputs, training=training)
        else:
            return self._layer.call(inputs)
    def get_config(self):
        """
        Get layer config

        Serialization and deserialization should be done using
        keras.serialize/deserialize, which create and retrieve the
        "class_name" field automatically. The resulting config below
        therefore does not contain the field.
        """
        config = super(MaskedLayer, self).get_config()
        if "layer" not in config:
            raise RuntimeError("Expected 'layer' field not found in config")
        config.update(
            {
                "pruning_scheduler": self._pruning_scheduler.get_config(),
                "mask_type": self._mask_type,
            }
        )
        return config
    @classmethod
    def from_config(cls, config):
        config = config.copy()
        layer = keras.layers.deserialize(
            config.pop("layer"), custom_objects={"MaskedLayer": MaskedLayer}
        )
        if not isinstance(layer, MaskedLayer) and not isinstance(
            layer, keras.layers.Layer
        ):
            raise RuntimeError("Unexpected layer created from config")
        pruning_scheduler = PruningScheduler.deserialize(
            config.pop("pruning_scheduler")
        )
        if not isinstance(pruning_scheduler, PruningScheduler):
            raise RuntimeError("Unexpected pruning scheduler type created from config")
        mask_type = config.pop("mask_type")
        masked_layer = MaskedLayer(layer, pruning_scheduler, mask_type, **config)
        return masked_layer
    def compute_output_shape(self, input_shape):
        return self._layer.compute_output_shape(input_shape)
    @property
    def global_step(self):
        return self._global_step

    @property
    def mask_updater(self):
        return self._mask_updater

    @property
    def masks(self):
        return [masked_param_info.mask for masked_param_info in self._pruning_vars]

    @property
    def pruning_vars(self):
        return self._pruning_vars

    @property
    def pruned_layer(self):
        if isinstance(self._layer, MaskedLayer):
            return self._layer.pruned_layer
        elif isinstance(self._layer, keras.layers.Layer):
            return self._layer
        else:
            raise RuntimeError("Unrecognized layer")

    @property
    def masked_layer(self):
        return self._layer
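A usage sketch of wrapping a layer for pruning; ConstantSparsityScheduler refers to the illustrative scheduler sketched earlier, and any PruningScheduler subclass works in its place. In sparseml the global step and mask updates are normally advanced by training callbacks; the manual calls below only illustrate the mechanics:

masked = MaskedLayer(
    keras.layers.Dense(10),
    pruning_scheduler=ConstantSparsityScheduler(sparsity=0.9, interval=100),
    mask_type="unstructured",
)
model = keras.Sequential([masked])
model.build(input_shape=(None, 128))  # creates mask, sparsity and global_step

keras.backend.set_value(masked.global_step, 100)
masked.mask_updater.conditional_update(training=True)  # re-prunes if scheduled

Because get_config/from_config are implemented, the wrapper also round-trips through the standard Keras serialization utilities, as the get_config docstring above describes:

config = keras.layers.serialize(masked)  # adds the "class_name" field
restored = keras.layers.deserialize(
    config, custom_objects={"MaskedLayer": MaskedLayer}
)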
def remove_pruning_masks(model: keras.Model):
    """
    Remove pruning masks from a model that was pruned using the MaskedLayer logic

    :param model: a model that was pruned using MaskedLayer
    :return: the original model with pruned weights
    """

    def _get_pruned_layer(layer):
        # If the model is loaded through the SavedModel format, the layer of
        # type MaskedLayer would belong to a special package, hence the
        # second check below based simply on class name
        is_masked_layer = isinstance(
            layer, MaskedLayer
        ) or layer.__class__.__name__.endswith("MaskedLayer")
        if is_masked_layer:
            return _get_pruned_layer(layer.layer)
        elif isinstance(layer, keras.layers.Layer):
            return layer
        else:
            raise ValueError("Unknown layer type")

    def _remove_pruning_masks(layer):
        is_masked_layer = isinstance(
            layer, MaskedLayer
        ) or layer.__class__.__name__.endswith("MaskedLayer")
        if is_masked_layer:
            return _get_pruned_layer(layer)
        return layer

    # TODO: while the resulting model could be exported to ONNX, its built status
    # is removed
    return keras.models.clone_model(
        model, input_tensors=None, clone_function=_remove_pruning_masks
    )
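Continuing the sketch above, stripping the wrappers after training, e.g. before export; the shape passed to build is illustrative:

stripped = remove_pruning_masks(model)  # a clone with MaskedLayer unwrapped
# per the TODO above, cloning drops the built state, so rebuild before use
stripped.build(input_shape=(None, 128))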