Source code for sparseml.keras.optim.mask_pruning_creator

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Classes for defining sparsity masks based on model parameters.
"""

from abc import ABC, abstractmethod
from typing import Any, Callable, Iterable, List, Tuple, Union

import numpy
import tensorflow


__all__ = [
    "PruningMaskCreator",
    "UnstructuredPruningMaskCreator",
    "GroupedPruningMaskCreator",
    "DimensionPruningMaskCreator",
    "BlockPruningMaskCreator",
    "load_mask_creator",
]


[docs]class PruningMaskCreator(ABC): """ Base abstract class for a sparsity mask creator. Subclasses should define all methods for creating masks and their initializers """
[docs] @abstractmethod def get_mask_initializer( self, tensor: tensorflow.Tensor, ) -> Callable[[], tensorflow.Tensor]: """ :param tensor: A tensor of a model layer's weights :return: Tensor initializer function for this sparsity mask """ raise NotImplementedError()
[docs] @abstractmethod def create_sparsity_mask( self, tensor: tensorflow.Tensor, sparsity: tensorflow.Tensor, ) -> tensorflow.Tensor: """ :param tensor: A tensor of a model layer's weights :param sparsity: the target sparsity to use for assigning the masks :return: A sparsity mask close to the set sparsity based on the values of the input tensor """ raise NotImplementedError()
[docs]class UnstructuredPruningMaskCreator(PruningMaskCreator): """ Class for creating unstructured sparsity masks. Masks will be created using unstructured sparsity by pruning weights ranked by their magnitude. """
[docs] def get_mask_initializer( self, tensor: tensorflow.Tensor, ) -> Callable[[], tensorflow.Tensor]: """ :param tensor: A tensor of a model layer's weights :return: Initializer for tensor where an element is 1.0 for nonzero weights and zero for all other weights :raise: ValueError If the dtype is not numeric or boolean """ def non_zero_mask_initializer( shape: tensorflow.TensorShape, dtype: tensorflow.DType = tensorflow.float32, partition_info: Any = None, # unsued variable for compatability ) -> tensorflow.Tensor: dtype = tensorflow.as_dtype(dtype) if not dtype.is_numpy_compatible or dtype == tensorflow.string: raise ValueError("Expected numeric or boolean dtype, got %s." % dtype) return tensorflow.cast(tensorflow.not_equal(tensor, 0.0), dtype=dtype) return non_zero_mask_initializer
[docs] def create_sparsity_mask( self, tensor: tensorflow.Tensor, sparsity: tensorflow.Tensor, ) -> tensorflow.Tensor: """ :param tensor: A tensor of a model layer's weights :param sparsity: the target sparsity to use for assigning the masks :return: A sparsity mask close to the set sparsity based on the values of the input tensor """ abs_var = tensorflow.abs(tensor) # Magnitudes of weights sparse_threshold_index = tensorflow.cast( tensorflow.round( tensorflow.cast(tensorflow.size(abs_var), tensorflow.float32) * sparsity ), tensorflow.int32, ) sparse_threshold_index = tensorflow.minimum( tensorflow.maximum(sparse_threshold_index, 0), tensorflow.size(tensor) - 1, ) try: argsort = tensorflow.argsort except Exception: try: argsort = tensorflow.contrib.framework.argsort except Exception: raise RuntimeError( "cannot find argsort function in tensorflow_v1, " "currently unsupported" ) # produce tensor where each element is the index in sorted order of abs_var abs_var_flat = tensorflow.reshape(abs_var, [-1]) element_ranks_flat = tensorflow.scatter_nd( tensorflow.expand_dims(argsort(abs_var_flat), 1), tensorflow.range(abs_var_flat.get_shape()[0]), abs_var_flat.get_shape(), ) element_ranks = tensorflow.reshape(element_ranks_flat, abs_var.get_shape()) return tensorflow.cast( tensorflow.math.greater_equal(element_ranks, sparse_threshold_index), tensorflow.float32, )
def __str__(self): return "unstructured" def __repr__(self): return str(self)
[docs]class GroupedPruningMaskCreator(UnstructuredPruningMaskCreator): """ Abstract class for a sparsity mask creator that structures masks according to grouping functions. Subclasses should implement group_tensor and _map_mask_to_tensor """ _GROUPING_OPS = { "mean": tensorflow.reduce_mean, "max": tensorflow.reduce_max, "min": tensorflow.reduce_min, }
[docs] @staticmethod def get_grouping_op(grouping_op_name: str) -> tensorflow.Operation: """ :param grouping_op_name: name of grouping operation to get tf operation for :return: tf operation for grouping_op_name if available, raises error otherwise """ if grouping_op_name not in GroupedPruningMaskCreator._GROUPING_OPS: raise ValueError("Invalid grouping op {}, valid grouping ops: {}").format( grouping_op_name, GroupedPruningMaskCreator._GROUPING_OPS ) return GroupedPruningMaskCreator._GROUPING_OPS[grouping_op_name]
[docs] @abstractmethod def group_tensor(self, tensor: tensorflow.Tensor) -> tensorflow.Tensor: """ :param tensor: The tensor to reduce in groups :return: The grouped tensor """ raise NotImplementedError()
@abstractmethod def _map_mask_to_tensor( self, grouped_mask: tensorflow.Tensor, original_tensor_shape: tensorflow.TensorShape, ) -> tensorflow.Tensor: """ :param grouped_mask: A binary mask the size of a tensor from group_tensor :param original_tensor_shape: Shape of the original tensor grouped_mask derives from :return: The values from grouped_mask mapped to a tensor of size original_tensor_shape """ raise NotImplementedError()
[docs] def get_mask_initializer( self, tensor: tensorflow.Tensor, ) -> Callable[[], tensorflow.Tensor]: """ :param tensor: A tensor of a model layer's weights :return: Tensor initializer function for this sparsity mask """ def grouped_non_zero_mask_initializer( shape: tensorflow.TensorShape, dtype: tensorflow.DType = tensorflow.float32, partition_info: Any = None, # unsued variable for compatability ) -> tensorflow.Tensor: dtype = tensorflow.as_dtype(dtype) if not dtype.is_numpy_compatible or dtype == tensorflow.string: raise ValueError("Expected numeric or boolean dtype, got %s." % dtype) grouped_tensor = self.group_tensor(tensor) grouped_mask = tensorflow.not_equal(grouped_tensor, 0.0) mask = self._map_mask_to_tensor(grouped_mask, tensor.shape) return tensorflow.cast(mask, dtype=dtype) return grouped_non_zero_mask_initializer
[docs] def create_sparsity_mask( self, tensor: tensorflow.Tensor, sparsity: tensorflow.Tensor, ) -> tensorflow.Tensor: """ :param tensor: A tensor of a model layer's weights :param sparsity: the target sparsity to use for assigning the masks :return: A sparsity mask close to the set sparsity based on the values of the input tensor """ grouped_tensor = self.group_tensor(tensor) grouped_mask = super().create_sparsity_mask(grouped_tensor, sparsity) return self._map_mask_to_tensor(grouped_mask, tensor.shape)
[docs]class DimensionPruningMaskCreator(GroupedPruningMaskCreator): """ Structured sparsity mask creator that groups sparsity blocks by the given dimension(s) :param dim: The index or list of indices of dimensions to group the mask by or the type of dims to prune (['channel', 'filter']) """ _VALID_DIM_NAMES = ["channel", "filter"] def __init__( self, dim: Union[str, int, List[int]], grouping_op_name: str = "mean", ): if isinstance(dim, int): dim = [dim] self._dim = dim # List[int] self._grouping_op = GroupedPruningMaskCreator.get_grouping_op(grouping_op_name) self._dim_name = None if isinstance(dim, str): if dim in DimensionPruningMaskCreator._VALID_DIM_NAMES: self._dim_name = dim else: raise ValueError( "Invalid Dimension name: {}, valid names: {}".format( dim, DimensionPruningMaskCreator._VALID_DIM_NAMES ) ) def _set_dim_by_name_for_tensor(self, tensor: tensorflow.Tensor): n_dims = len(tensor.shape) if n_dims <= 2: if self._dim_name == "channel": self._dim = [0] else: raise ValueError( f"filter pruning unsupported for tensors with fewer than " f"3 dimensions. Received Tensor with shape {tensor.shape}" ) elif self._dim_name == "channel": # in channel should be the second to last dimension self._dim = [n_dims - 2] elif self._dim_name == "filter": # Non-kernel dimensions should be the last two in a conv (in / out channels) self._dim = [n_dims - 2, n_dims - 1] else: raise ValueError( "Invalid dimension prune type: {}, valid types: {}".format( self._dim_name, DimensionPruningMaskCreator._VALID_DIM_NAMES ) )
[docs] def group_tensor(self, tensor: tensorflow.Tensor) -> tensorflow.Tensor: """ :param tensor: The tensor to transform :return: The absolute mean values of the tensor grouped by the dimension(s) in self._dim """ if self._dim_name is not None: self._set_dim_by_name_for_tensor(tensor) n_dims = len(tensor.shape) reduced_axis = [idx for idx in range(n_dims) if idx not in self._dim] return self._grouping_op( tensorflow.abs(tensor), axis=reduced_axis, keepdims=True, )
def _map_mask_to_tensor( self, grouped_mask: tensorflow.Tensor, original_tensor_shape: tensorflow.TensorShape, ) -> tensorflow.Tensor: """ :param grouped_mask: A binary mask the size of a tensor from group_tensor :param original_tensor_shape: Shape of the original tensor grouped_mask derives from :return: The values from grouped_mask mapped to a tensor of size original_tensor_shape """ # using tile instead of broadcast_to for compatibility with older tf versions # equivalent to: tensorflow.broadcast_to(grouped_mask, original_tensor_shape) tile_vals = [ dim if idx not in self._dim else 1 for (idx, dim) in enumerate(original_tensor_shape) ] return tensorflow.tile(grouped_mask, tile_vals) def __str__(self): if self._dim_name is not None: return self._dim_name return "{}:{}".format(self.__class__.__name__, self._dim) def __repr__(self): return str(self)
[docs]class BlockPruningMaskCreator(GroupedPruningMaskCreator): """ Structured sparsity mask creator that groups the input tensor into blocks of shape block_shape. block_shape must divide the shape of any input tensor evenly and must have exactly 2 elements for the shape of in and out channels in the blocks. :param block_shape: The shape of blocks to strucure blocks of in and out channels in the mask by. -1 represents blocking along the entire dimension. """ def __init__( self, block_shape: List[int], grouping_op_name: str = "mean", ): if len(block_shape) != 2: raise ValueError( ( "Invalid block_shape: {}" " , block_shape must have length == 2 for in and out channels" ).format(block_shape) ) self._block_shape = block_shape self._grouping_op = GroupedPruningMaskCreator.get_grouping_op(grouping_op_name)
[docs] def group_tensor(self, tensor: tensorflow.Tensor) -> tensorflow.Tensor: """ :param tensor: The tensor to transform :return: The absolute mean values of the tensor grouped by blocks of shape self._block_shape """ blocked_tens_shape, _ = self._get_blocked_tens_shape_and_validate(tensor.shape) # reorder so that in and out channel dimensions come before kernel n_dims = len(tensor.shape) if n_dims >= 3: tens_trans_dims = [n_dims - 2, n_dims - 1, *range(n_dims - 2)] tensor = tensorflow.transpose(tensor, tens_trans_dims) blocked_tens = tensorflow.reshape(tensor, blocked_tens_shape) reduced_blocks = self._grouping_op( tensorflow.abs(blocked_tens), 1, keepdims=True ) return reduced_blocks
def _map_mask_to_tensor( self, grouped_mask: tensorflow.Tensor, original_tensor_shape: tensorflow.TensorShape, ) -> tensorflow.Tensor: """ :param grouped_mask: A binary mask the size of a tensor from group_tensor :param original_tensor_shape: Shape of the original tensor grouped_mask derives from :return: The values from grouped_mask mapped to a tensor of size original_tensor_shape """ ( blocked_tens_shape, original_tensor_shape, ) = self._get_blocked_tens_shape_and_validate(original_tensor_shape) block_values_shape = [blocked_tens_shape[0], blocked_tens_shape[2]] # expand so every element has a corresponding value in the original tensor block_mask = tensorflow.reshape(grouped_mask, block_values_shape) block_mask = tensorflow.expand_dims(block_mask, 1) # Recover reduced dimension of block_mask, using tile instead of broadcast_to # for compatibility with older versions of tf block_mask_shape = [dim.value for dim in block_mask.shape] tile_shape = [ int(block_dim / mask_dim) for (block_dim, mask_dim) in zip(blocked_tens_shape, block_mask_shape) ] # equivalent to: tensorflow.broadcast_to(block_mask, blocked_tens_shape) tensor_mask_blocked = tensorflow.tile(block_mask, tile_shape) mask = tensorflow.reshape(tensor_mask_blocked, original_tensor_shape) # Undo channel / kernel transpose if applicable n_dims = len(original_tensor_shape) if n_dims >= 3: tens_trans_dims = [*range(2, n_dims), 0, 1] mask = tensorflow.transpose(mask, tens_trans_dims) return mask def _get_blocked_tens_shape_and_validate( self, tens_shape: tensorflow.TensorShape, ) -> Tuple[List[int], tensorflow.TensorShape]: """ :param tens_shape: The shape of the tensor to group in blocks :return: shape of tens when blocked by block_shape and the original tensor shape with any transposes applied to it :raise: ValueError if we are unable to block tens by shape block_shape """ block_shape = self._block_shape n_dims = len(tens_shape) if len(tens_shape) >= 3: # conv should have block shape like [1, ..., 1, X, Y] block_shape = [*[1] * (n_dims - 2), *block_shape] tens_shape = [dim.value for dim in tens_shape] for idx, shape in enumerate(block_shape): if shape == -1: block_shape[idx] = int(tens_shape[idx]) # Validate if n_dims < 2: raise ValueError( "Invalid tensor shape {}." " BlockSparsityMaskCreator can only create masks from tensors with 2 or" " more dimensions, tensor has {}.".format(tens_shape, n_dims) ) for tens_dim, block_dim in zip(tens_shape, block_shape): if tens_dim % block_dim != 0: raise ValueError( f"Invalid block_shape {block_shape} for parameter shape " f"{tens_shape}. Elements of block_shape must divide parameter " f"shape evenly" ) # If this is a series of conv filters, reorder so in and out channels are first if n_dims >= 3: transpose_idx = [n_dims - 2, n_dims - 1, *range(n_dims - 2)] block_shape = [block_shape[idx] for idx in transpose_idx] tens_shape = [tens_shape[idx] for idx in transpose_idx] # Compute blocked tensor shape if len(block_shape) > 1 and block_shape[1] > 1: blocked_tens_shape = [ tens_shape[0] * tens_shape[1] // (block_shape[0] * block_shape[1]), block_shape[0] * block_shape[1], -1, ] else: blocked_tens_shape = [tens_shape[0] // block_shape[0], block_shape[0], -1] tens_size = numpy.prod(tens_shape) num_block_elements = blocked_tens_shape[0] * blocked_tens_shape[1] blocked_tens_shape[2] = tens_size // num_block_elements return blocked_tens_shape, tens_shape def __str__(self): return str(self._block_shape) def __repr__(self): return str(self)
mask_creator_name_to_constructor_lambda = { "unstructured": lambda: UnstructuredPruningMaskCreator(), "channel": lambda: DimensionPruningMaskCreator("channel"), "filter": lambda: DimensionPruningMaskCreator("filter"), }
[docs]def load_mask_creator(obj: Union[str, Iterable[int]]) -> PruningMaskCreator: """ :param obj: Formatted string or iterable of block_shape specifying SparsityMaskCreator object to return :return: SparsityMaskCreator object created from obj """ if isinstance(obj, str) and obj in mask_creator_name_to_constructor_lambda: constructor_lambda = mask_creator_name_to_constructor_lambda[obj] return constructor_lambda() # Checking for a BlockSparsityMaskCreator string if ("[" in obj and "]" in obj) or ("(" in obj and ")" in obj): stripped_str = obj.strip("[|]|(|)") block_shape = [int(s) for s in stripped_str.split(",")] return BlockPruningMaskCreator(block_shape) if isinstance(obj, list) or isinstance(obj, tuple): return BlockPruningMaskCreator(obj) raise ValueError( "Invalid mask type string: {}, could not map to an object".format(obj) )