# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Classes for defining sparsity masks based on model parameters.
"""
from abc import ABC, abstractmethod
from typing import Any, Callable, Iterable, List, Tuple, Union
import numpy
import tensorflow
# Public API of this module: the mask-creator class hierarchy and the
# load_mask_creator factory that builds a creator from a name or block shape.
__all__ = [
"PruningMaskCreator",
"UnstructuredPruningMaskCreator",
"GroupedPruningMaskCreator",
"DimensionPruningMaskCreator",
"BlockPruningMaskCreator",
"load_mask_creator",
]
class PruningMaskCreator(ABC):
    """
    Abstract interface for objects that build sparsity (pruning) masks.

    Concrete subclasses must provide both a mask initializer and a routine that
    produces a mask for a requested sparsity level.
    """

    @abstractmethod
    def get_mask_initializer(
        self, tensor: tensorflow.Tensor
    ) -> Callable[[], tensorflow.Tensor]:
        """
        :param tensor: A tensor of a model layer's weights
        :return: Tensor initializer function for this sparsity mask
        """
        raise NotImplementedError

    @abstractmethod
    def create_sparsity_mask(
        self, tensor: tensorflow.Tensor, sparsity: tensorflow.Tensor
    ) -> tensorflow.Tensor:
        """
        :param tensor: A tensor of a model layer's weights
        :param sparsity: the target sparsity to use for assigning the masks
        :return: A sparsity mask close to the set sparsity based on the values of
            the input tensor
        """
        raise NotImplementedError
class UnstructuredPruningMaskCreator(PruningMaskCreator):
    """
    Class for creating unstructured sparsity masks.
    Masks will be created using unstructured sparsity by pruning weights ranked
    by their magnitude.
    """

    def get_mask_initializer(
        self,
        tensor: tensorflow.Tensor,
    ) -> Callable[..., tensorflow.Tensor]:
        """
        :param tensor: A tensor of a model layer's weights
        :return: Initializer for tensor where an element is 1.0 for nonzero weights
            and zero for all other weights
        :raise: ValueError If the dtype is not numeric or boolean
        """

        def non_zero_mask_initializer(
            shape: tensorflow.TensorShape,
            dtype: tensorflow.DType = tensorflow.float32,
            partition_info: Any = None,  # unused, kept for signature compatibility
        ) -> tensorflow.Tensor:
            # `shape` is not used: the mask is always derived from the captured
            # `tensor`, so it has that tensor's shape
            dtype = tensorflow.as_dtype(dtype)
            if not dtype.is_numpy_compatible or dtype == tensorflow.string:
                raise ValueError("Expected numeric or boolean dtype, got %s." % dtype)
            return tensorflow.cast(tensorflow.not_equal(tensor, 0.0), dtype=dtype)

        return non_zero_mask_initializer

    def create_sparsity_mask(
        self,
        tensor: tensorflow.Tensor,
        sparsity: tensorflow.Tensor,
    ) -> tensorflow.Tensor:
        """
        :param tensor: A tensor of a model layer's weights
        :param sparsity: the target sparsity to use for assigning the masks
        :return: A float32 sparsity mask close to the set sparsity based on the
            magnitudes of the input tensor; at least one element is always kept
        :raise: RuntimeError if no argsort implementation can be found in the
            installed tensorflow
        """
        abs_var = tensorflow.abs(tensor)  # Magnitudes of weights
        num_elements = tensorflow.size(abs_var)

        # Number of elements to prune, rounded to the nearest integer. sparsity is
        # cast so Python floats and non-float32 tensors are both accepted.
        sparse_threshold_index = tensorflow.cast(
            tensorflow.round(
                tensorflow.cast(num_elements, tensorflow.float32)
                * tensorflow.cast(sparsity, tensorflow.float32)
            ),
            tensorflow.int32,
        )
        # Clamp to [0, size - 1] so the highest-ranked element is never pruned
        sparse_threshold_index = tensorflow.minimum(
            tensorflow.maximum(sparse_threshold_index, 0),
            num_elements - 1,
        )

        # tensorflow.argsort is missing from older TF1 releases; fall back to contrib
        argsort = getattr(tensorflow, "argsort", None)
        if argsort is None:
            try:
                argsort = tensorflow.contrib.framework.argsort
            except (AttributeError, ImportError) as err:
                raise RuntimeError(
                    "cannot find argsort function in tensorflow_v1, "
                    "currently unsupported"
                ) from err

        # Produce a tensor where each element is its rank in ascending order of
        # magnitude: scatter range(n) into the positions returned by argsort.
        # Uses the static shape (get_shape), so the tensor shape must be known.
        abs_var_flat = tensorflow.reshape(abs_var, [-1])
        element_ranks_flat = tensorflow.scatter_nd(
            tensorflow.expand_dims(argsort(abs_var_flat), 1),
            tensorflow.range(abs_var_flat.get_shape()[0]),
            abs_var_flat.get_shape(),
        )
        element_ranks = tensorflow.reshape(element_ranks_flat, abs_var.get_shape())

        # Keep every element whose rank is at or above the threshold index
        return tensorflow.cast(
            tensorflow.math.greater_equal(element_ranks, sparse_threshold_index),
            tensorflow.float32,
        )

    def __str__(self):
        return "unstructured"

    def __repr__(self):
        return str(self)
class GroupedPruningMaskCreator(UnstructuredPruningMaskCreator):
    """
    Abstract class for a sparsity mask creator that structures masks according to
    grouping functions. Subclasses should implement group_tensor and
    _map_mask_to_tensor
    """

    # Reductions that may be used to collapse a group of weights into one value
    _GROUPING_OPS = {
        "mean": tensorflow.reduce_mean,
        "max": tensorflow.reduce_max,
        "min": tensorflow.reduce_min,
    }

    @staticmethod
    def get_grouping_op(grouping_op_name: str) -> Callable[..., tensorflow.Tensor]:
        """
        :param grouping_op_name: name of grouping operation to get tf operation for
        :return: tf reduction function registered for grouping_op_name
        :raise: ValueError if grouping_op_name is not one of the valid grouping ops
        """
        grouping_ops = GroupedPruningMaskCreator._GROUPING_OPS
        if grouping_op_name not in grouping_ops:
            # format the message string itself (not the exception object)
            raise ValueError(
                "Invalid grouping op {}, valid grouping ops: {}".format(
                    grouping_op_name, list(grouping_ops)
                )
            )
        return grouping_ops[grouping_op_name]

    @abstractmethod
    def group_tensor(self, tensor: tensorflow.Tensor) -> tensorflow.Tensor:
        """
        :param tensor: The tensor to reduce in groups
        :return: The grouped tensor
        """
        raise NotImplementedError()

    @abstractmethod
    def _map_mask_to_tensor(
        self,
        grouped_mask: tensorflow.Tensor,
        original_tensor_shape: tensorflow.TensorShape,
    ) -> tensorflow.Tensor:
        """
        :param grouped_mask: A binary mask the size of a tensor from group_tensor
        :param original_tensor_shape: Shape of the original tensor grouped_mask
            derives from
        :return: The values from grouped_mask mapped to a tensor of size
            original_tensor_shape
        """
        raise NotImplementedError()

    def get_mask_initializer(
        self,
        tensor: tensorflow.Tensor,
    ) -> Callable[..., tensorflow.Tensor]:
        """
        :param tensor: A tensor of a model layer's weights
        :return: Initializer producing a mask that is 1 for every group whose
            grouped value is nonzero and 0 otherwise, mapped back to tensor's shape
        :raise: ValueError (from the initializer) if the dtype is not numeric or
            boolean
        """

        def grouped_non_zero_mask_initializer(
            shape: tensorflow.TensorShape,
            dtype: tensorflow.DType = tensorflow.float32,
            partition_info: Any = None,  # unused, kept for signature compatibility
        ) -> tensorflow.Tensor:
            # `shape` is not used: the mask is derived from the captured `tensor`
            dtype = tensorflow.as_dtype(dtype)
            if not dtype.is_numpy_compatible or dtype == tensorflow.string:
                raise ValueError("Expected numeric or boolean dtype, got %s." % dtype)
            grouped_tensor = self.group_tensor(tensor)
            grouped_mask = tensorflow.not_equal(grouped_tensor, 0.0)
            mask = self._map_mask_to_tensor(grouped_mask, tensor.shape)
            return tensorflow.cast(mask, dtype=dtype)

        return grouped_non_zero_mask_initializer

    def create_sparsity_mask(
        self,
        tensor: tensorflow.Tensor,
        sparsity: tensorflow.Tensor,
    ) -> tensorflow.Tensor:
        """
        :param tensor: A tensor of a model layer's weights
        :param sparsity: the target sparsity to use for assigning the masks
        :return: A sparsity mask close to the set sparsity based on the values of
            the input tensor
        """
        # Rank whole groups by magnitude, then expand the group mask back out
        grouped_tensor = self.group_tensor(tensor)
        grouped_mask = super().create_sparsity_mask(grouped_tensor, sparsity)
        return self._map_mask_to_tensor(grouped_mask, tensor.shape)
class DimensionPruningMaskCreator(GroupedPruningMaskCreator):
    """
    Structured sparsity mask creator that groups sparsity blocks by the given
    dimension(s)

    :param dim: The index or list of indices of dimensions to group the mask by or
        the type of dims to prune (['channel', 'filter']). Negative indices count
        from the last dimension, as in Python indexing.
    :param grouping_op_name: name of the reduction used to collapse each group,
        one of 'mean', 'max', 'min'
    """

    _VALID_DIM_NAMES = ["channel", "filter"]

    def __init__(
        self,
        dim: Union[str, int, List[int]],
        grouping_op_name: str = "mean",
    ):
        if isinstance(dim, int):
            dim = [dim]
        # For a dim name this holds the name until group_tensor replaces it with
        # concrete indices once the tensor rank is known
        self._dim = dim
        self._grouping_op = GroupedPruningMaskCreator.get_grouping_op(grouping_op_name)
        self._dim_name = None
        if isinstance(dim, str):
            if dim in DimensionPruningMaskCreator._VALID_DIM_NAMES:
                self._dim_name = dim
            else:
                raise ValueError(
                    "Invalid Dimension name: {}, valid names: {}".format(
                        dim, DimensionPruningMaskCreator._VALID_DIM_NAMES
                    )
                )

    def _set_dim_by_name_for_tensor(self, tensor: tensorflow.Tensor):
        # Translate the dim name into concrete dimension indices for this rank
        n_dims = len(tensor.shape)
        if n_dims <= 2:
            if self._dim_name == "channel":
                self._dim = [0]
            else:
                raise ValueError(
                    f"filter pruning unsupported for tensors with fewer than "
                    f"3 dimensions. Received Tensor with shape {tensor.shape}"
                )
        elif self._dim_name == "channel":
            # in channel should be the second to last dimension
            self._dim = [n_dims - 2]
        elif self._dim_name == "filter":
            # Non-kernel dimensions should be the last two in a conv (in / out channels)
            self._dim = [n_dims - 2, n_dims - 1]
        else:
            raise ValueError(
                "Invalid dimension prune type: {}, valid types: {}".format(
                    self._dim_name, DimensionPruningMaskCreator._VALID_DIM_NAMES
                )
            )

    def _resolve_dims(self, n_dims: int) -> List[int]:
        """
        :param n_dims: rank of the tensor being masked
        :return: self._dim with in-range negative indices converted to their
            non-negative equivalents; all other entries are returned unchanged
        """
        # Without this, e.g. -1 never matches range(n_dims) and every axis would
        # be reduced, collapsing the whole tensor into a single group
        return [dim + n_dims if -n_dims <= dim < 0 else dim for dim in self._dim]

    def group_tensor(self, tensor: tensorflow.Tensor) -> tensorflow.Tensor:
        """
        :param tensor: The tensor to transform
        :return: The absolute values of the tensor reduced with the grouping op
            (mean by default) over every dimension not in self._dim; reduced
            dimensions are kept with size 1
        """
        if self._dim_name is not None:
            self._set_dim_by_name_for_tensor(tensor)
        n_dims = len(tensor.shape)
        kept_dims = self._resolve_dims(n_dims)
        reduced_axis = [idx for idx in range(n_dims) if idx not in kept_dims]
        return self._grouping_op(
            tensorflow.abs(tensor),
            axis=reduced_axis,
            keepdims=True,
        )

    def _map_mask_to_tensor(
        self,
        grouped_mask: tensorflow.Tensor,
        original_tensor_shape: tensorflow.TensorShape,
    ) -> tensorflow.Tensor:
        """
        :param grouped_mask: A binary mask the size of a tensor from group_tensor
        :param original_tensor_shape: Shape of the original tensor grouped_mask
            derives from
        :return: The values from grouped_mask mapped to a tensor of size
            original_tensor_shape
        """
        kept_dims = self._resolve_dims(len(original_tensor_shape))
        # using tile instead of broadcast_to for compatibility with older tf versions
        # equivalent to: tensorflow.broadcast_to(grouped_mask, original_tensor_shape)
        tile_vals = [
            1 if idx in kept_dims else dim
            for idx, dim in enumerate(original_tensor_shape)
        ]
        return tensorflow.tile(grouped_mask, tile_vals)

    def __str__(self):
        if self._dim_name is not None:
            return self._dim_name
        return "{}:{}".format(self.__class__.__name__, self._dim)

    def __repr__(self):
        return str(self)
class BlockPruningMaskCreator(GroupedPruningMaskCreator):
    """
    Structured sparsity mask creator that groups the input tensor into blocks of
    shape block_shape.
    block_shape must divide the shape of any input tensor evenly and must have exactly
    2 elements for the shape of in and out channels in the blocks.

    :param block_shape: The shape of blocks to structure blocks of in and out
        channels in the mask by. -1 represents blocking along the entire dimension.
    :param grouping_op_name: name of the reduction used to collapse each block,
        one of 'mean', 'max', 'min'
    """

    def __init__(
        self,
        block_shape: List[int],
        grouping_op_name: str = "mean",
    ):
        if len(block_shape) != 2:
            raise ValueError(
                (
                    "Invalid block_shape: {}"
                    " , block_shape must have length == 2 for in and out channels"
                ).format(block_shape)
            )
        self._block_shape = block_shape
        self._grouping_op = GroupedPruningMaskCreator.get_grouping_op(grouping_op_name)

    def group_tensor(self, tensor: tensorflow.Tensor) -> tensorflow.Tensor:
        """
        :param tensor: The tensor to transform
        :return: The absolute values of the tensor reduced with the grouping op
            (mean by default) over each block of shape self._block_shape
        """
        blocked_tens_shape, _ = self._get_blocked_tens_shape_and_validate(tensor.shape)
        # reorder so that in and out channel dimensions come before kernel
        n_dims = len(tensor.shape)
        if n_dims >= 3:
            tens_trans_dims = [n_dims - 2, n_dims - 1, *range(n_dims - 2)]
            tensor = tensorflow.transpose(tensor, tens_trans_dims)
        blocked_tens = tensorflow.reshape(tensor, blocked_tens_shape)
        # axis 1 of the blocked shape holds the elements of one block
        return self._grouping_op(tensorflow.abs(blocked_tens), 1, keepdims=True)

    def _map_mask_to_tensor(
        self,
        grouped_mask: tensorflow.Tensor,
        original_tensor_shape: tensorflow.TensorShape,
    ) -> tensorflow.Tensor:
        """
        :param grouped_mask: A binary mask the size of a tensor from group_tensor
        :param original_tensor_shape: Shape of the original tensor grouped_mask
            derives from
        :return: The values from grouped_mask mapped to a tensor of size
            original_tensor_shape
        """
        blocked_tens_shape, transposed_shape = self._get_blocked_tens_shape_and_validate(
            original_tensor_shape
        )
        block_values_shape = [blocked_tens_shape[0], blocked_tens_shape[2]]
        # expand so every element has a corresponding value in the original tensor
        block_mask = tensorflow.reshape(grouped_mask, block_values_shape)
        block_mask = tensorflow.expand_dims(block_mask, 1)
        # Recover reduced dimension of block_mask, using tile instead of broadcast_to
        # for compatibility with older versions of tf. as_list() works for both the
        # TF1 (Dimension entries) and TF2 (int entries) TensorShape APIs.
        block_mask_shape = block_mask.shape.as_list()
        tile_shape = [
            block_dim // mask_dim
            for block_dim, mask_dim in zip(blocked_tens_shape, block_mask_shape)
        ]
        # equivalent to: tensorflow.broadcast_to(block_mask, blocked_tens_shape)
        tensor_mask_blocked = tensorflow.tile(block_mask, tile_shape)
        mask = tensorflow.reshape(tensor_mask_blocked, transposed_shape)
        # Undo channel / kernel transpose if applicable
        n_dims = len(transposed_shape)
        if n_dims >= 3:
            tens_trans_dims = [*range(2, n_dims), 0, 1]
            mask = tensorflow.transpose(mask, tens_trans_dims)
        return mask

    def _get_blocked_tens_shape_and_validate(
        self,
        tens_shape: tensorflow.TensorShape,
    ) -> Tuple[List[int], List[int]]:
        """
        :param tens_shape: The shape of the tensor to group in blocks
        :return: shape of tens when blocked by block_shape and the original
            tensor shape (as a list of ints) with any transposes applied to it
        :raise: ValueError if we are unable to block tens by shape block_shape
        """
        # as_list() works for both the TF1 and TF2 TensorShape APIs
        tens_shape = tensorflow.TensorShape(tens_shape).as_list()
        n_dims = len(tens_shape)
        # Validate rank before indexing into tens_shape below
        if n_dims < 2:
            raise ValueError(
                "Invalid tensor shape {}."
                " BlockPruningMaskCreator can only create masks from tensors with 2 or"
                " more dimensions, tensor has {}.".format(tens_shape, n_dims)
            )
        if any(dim is None for dim in tens_shape):
            raise ValueError(
                "Invalid tensor shape {}. BlockPruningMaskCreator requires a fully "
                "defined tensor shape".format(tens_shape)
            )
        # Copy so -1 placeholders are never written back into self._block_shape
        # (which would also fail for tuple block shapes)
        block_shape = list(self._block_shape)
        if n_dims >= 3:  # conv should have block shape like [1, ..., 1, X, Y]
            block_shape = [*[1] * (n_dims - 2), *block_shape]
        for idx, dim in enumerate(block_shape):
            if dim == -1:
                block_shape[idx] = int(tens_shape[idx])
        for tens_dim, block_dim in zip(tens_shape, block_shape):
            if tens_dim % block_dim != 0:
                raise ValueError(
                    f"Invalid block_shape {block_shape} for parameter shape "
                    f"{tens_shape}. Elements of block_shape must divide parameter "
                    f"shape evenly"
                )
        # If this is a series of conv filters, reorder so in and out channels are first
        if n_dims >= 3:
            transpose_idx = [n_dims - 2, n_dims - 1, *range(n_dims - 2)]
            block_shape = [block_shape[idx] for idx in transpose_idx]
            tens_shape = [tens_shape[idx] for idx in transpose_idx]
        # Compute blocked tensor shape: [num_groups, elements_per_block, rest]
        # NOTE(review): when both block dims are > 1 this groups
        # block_shape[0] * block_shape[1] row-major-contiguous elements rather than
        # a block_shape[0] x block_shape[1] tile -- confirm this is intended
        if len(block_shape) > 1 and block_shape[1] > 1:
            blocked_tens_shape = [
                tens_shape[0] * tens_shape[1] // (block_shape[0] * block_shape[1]),
                block_shape[0] * block_shape[1],
                -1,
            ]
        else:
            blocked_tens_shape = [tens_shape[0] // block_shape[0], block_shape[0], -1]
        tens_size = numpy.prod(tens_shape)
        num_block_elements = blocked_tens_shape[0] * blocked_tens_shape[1]
        blocked_tens_shape[2] = tens_size // num_block_elements
        return blocked_tens_shape, tens_shape

    def __str__(self):
        return str(self._block_shape)

    def __repr__(self):
        return str(self)
# Registry of named mask creators; each value builds a fresh creator when called
mask_creator_name_to_constructor_lambda = dict(
    unstructured=lambda: UnstructuredPruningMaskCreator(),
    channel=lambda: DimensionPruningMaskCreator("channel"),
    filter=lambda: DimensionPruningMaskCreator("filter"),
)
def load_mask_creator(obj: Union[str, Iterable[int]]) -> PruningMaskCreator:
    """
    Build a PruningMaskCreator from a registered name or a block shape.

    :param obj: one of
        - a registered name: 'unstructured', 'channel' or 'filter'
        - a string holding a bracketed block shape, e.g. '[1, 4]' or '(1,4)'
        - a list, tuple or other iterable of ints giving the block shape
    :return: PruningMaskCreator object created from obj
    :raise: ValueError if obj cannot be mapped to a mask creator
    """
    if isinstance(obj, str):
        if obj in mask_creator_name_to_constructor_lambda:
            constructor_lambda = mask_creator_name_to_constructor_lambda[obj]
            return constructor_lambda()
        # Checking for a BlockPruningMaskCreator string such as "[1, 4]"
        if ("[" in obj and "]" in obj) or ("(" in obj and ")" in obj):
            # strip surrounding whitespace first so " [1, 4] " also parses
            stripped_str = obj.strip().strip("[|]|(|)")
            block_shape = [int(s) for s in stripped_str.split(",")]
            return BlockPruningMaskCreator(block_shape)
    elif isinstance(obj, (list, tuple)):
        # passed through unchanged so str(creator) keeps the caller's format
        return BlockPruningMaskCreator(obj)
    elif isinstance(obj, Iterable):
        # other iterables (e.g. numpy arrays, generators): materialize once
        return BlockPruningMaskCreator(list(obj))
    raise ValueError(
        "Invalid mask type string: {}, could not map to an object".format(obj)
    )