# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Code related to applying a mask onto a parameter to impose kernel sparsity,
aka model pruning
NOTE: this file is in the process of being phased out in favor of the
sparsification package. Once all references to mask utils in the optim
package are migrated, this file will be deleted
"""
from functools import partial
from typing import List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch.nn import Module, Parameter
from sparseml.pytorch.optim.mask_creator_pruning import (
PruningMaskCreator,
UnstructuredPruningMaskCreator,
)
from sparseml.pytorch.optim.mask_pruning_scorer import create_pruning_param_scorer
from sparseml.pytorch.utils import mask_difference
__all__ = [
"ModuleParamPruningMask",
]
class ModuleParamPruningMask(object):
"""
Mask to apply kernel sparsity (model pruning) to a specific parameter in a layer
:param layers: the layers containing the parameters to mask
    :param param_names: the names of the parameters to mask in each layer. If
        only one name is given, that name will be applied to all layers that
        this object masks. Default is 'weight'
:param store_init: store the init weights in a separate variable that can be
used and referenced later
:param store_unmasked: store the unmasked weights in a separate variable that
can be used and referenced later
    :param track_grad_mom: store the gradient updates to the parameters with a
        momentum variable. Must be in the range [0.0, 1.0); if set to 0.0, only
        the most recent gradient is kept
    :param mask_creator: object to define sparsity mask creation,
        default is unstructured mask
:param layer_names: the name of the layers the parameters to mask are located in
    :param global_sparsity: set True to enable global pruning. If True, sparsity
        masks will be created such that the average sparsity across all given
        layers equals the target sparsity, masking the lowest globally ranked
        values. If False, each layer will be masked to the target sparsity by
        ranking values within each individual tensor. Default is False
:param score_type: the method used to score parameters for masking, i.e.
'magnitude', 'movement'. Default is 'magnitude'
"""
def __init__(
self,
layers: List[Module],
param_names: Union[str, List[str]] = "weight",
store_init: bool = False,
store_unmasked: bool = False,
track_grad_mom: float = -1.0,
mask_creator: PruningMaskCreator = UnstructuredPruningMaskCreator(),
layer_names: Optional[List[str]] = None,
global_sparsity: bool = False,
score_type: str = "magnitude",
):
self._layers = layers
self._param_names = (
param_names
            if isinstance(param_names, list)
else [param_names] * len(self._layers)
)
self._layer_names = layer_names
self._store_init = store_init
self._store_unmasked = store_unmasked
self._track_grad_mom = track_grad_mom
self._mask_creator = mask_creator
self._global_sparsity = global_sparsity
self._enabled = False
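        # handles for the hooks registered on each layer/parameter, one slot
        # per layer; populated in _create_hooks and cleared in _delete_hooks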
self._forward_hooks = [None] * len(self._layers)
self._undo_mask_hooks = [None] * len(self._layers)
self._gradient_hooks = [None] * len(self._layers)
self._params = [] # type: List[Parameter]
self._param_orig_sizes = [] # type: List[int]
for layer, param_name in zip(self._layers, self._param_names):
try:
layer_param = layer.__getattr__(param_name)
self._params.append(layer_param)
self._param_orig_sizes.append(layer_param.numel())
except Exception as err:
raise RuntimeError(
f"Error occurred while trying to get param {param_name} "
f"in layer {layer}: {err}"
)
# initialize masks to all ones
self._param_masks = [torch.ones(param.shape) for param in self._params]
self._params_init = [None] * len(self._layers) # type: List[Tensor]
self._params_unmasked = [None] * len(self._layers) # type: List[Tensor]
self._params_grad = [None] * len(self._layers) # type: List[Tensor]
self._params_movement = [None] * len(self._layers) # type: List[Tensor]
self._params_applied_thinning = [0.0] * len(self._layers) # type: List[float]
# create scorer
self._scorer = create_pruning_param_scorer(self._params, score_type)
# movement pruning requires weight reintroduction
self._allow_reintroduction = self._scorer.get_name() == "movement"
self._store_unmasked |= self._allow_reintroduction # to support reintroduction
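        # for movement pruning, the pre-mask values must be cached so they can
        # be added back after every forward pass (see _hook_undo_mask)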
self._setup_params_init()
self._setup_params_unmasked()
self._setup_params_grad()
def __len__(self):
return len(self._layers)
def __del__(self):
self._delete_hooks()
@property
def layers(self) -> List[Module]:
"""
:return: the layers containing the parameters to mask
"""
return self._layers
@property
def param_names(self) -> List[str]:
"""
:return: the names of the parameters to mask in the layers
"""
return self._param_names
@property
def layer_names(self) -> Optional[List[str]]:
"""
:return: the names of the layers the parameter to mask is located in
"""
return self._layer_names
@property
def names(self) -> List[str]:
"""
:return: the full names of the sparsity masks in the following format:
<LAYER>.<PARAM>.sparsity_mask
"""
return [
f"{layer_name}.{param_name}.sparsity_mask"
for layer_name, param_name in zip(self._layer_names, self._param_names)
]
@property
def store_init(self) -> bool:
"""
:return: store the init weights in a separate variable that can be used and
referenced later
"""
return self._store_init
@property
def store_unmasked(self) -> bool:
"""
:return: store the unmasked weights in a separate variable that can be used and
referenced later
"""
return self._store_unmasked
@property
def track_grad_mom(self) -> float:
"""
        :return: store the gradient updates to the parameters with a momentum
            variable. Must be in the range [0.0, 1.0); if set to 0.0, only the
            most recent gradient is kept
"""
return self._track_grad_mom
@property
def mask_creator(self) -> PruningMaskCreator:
"""
:return: SparsityMaskCreator object used to generate masks
"""
return self._mask_creator
@property
def global_sparsity(self) -> bool:
"""
:return: True if global pruning is enabled, False otherwise
"""
return self._global_sparsity
@property
def enabled(self) -> bool:
"""
:return: True if the parameter is currently being masked, False otherwise
"""
return self._enabled
@property
def allow_reintroduction(self) -> bool:
"""
:return: True if weight reintroduction is allowed
"""
return self._allow_reintroduction
@enabled.setter
def enabled(self, value: bool):
"""
:param value: True to mask the parameter, False otherwise
"""
if value and not self._enabled:
self._create_hooks()
self._params_grad = [None] * len(self._params)
self._setup_params_grad()
elif not value and self._enabled:
self._delete_hooks()
self._enabled = value
@property
def params_data(self) -> List[Tensor]:
"""
:return: the current tensors in each of the parameters
"""
return [param.data for param in self._params]
@property
def param_masks(self) -> List[Tensor]:
"""
:return: the current masks applied to each of the parameters
"""
return self._param_masks
@property
def params_init(self) -> List[Optional[Tensor]]:
"""
:return: the initial values of the parameters before being masked
"""
return self._params_init
@property
def params_unmasked(self) -> List[Optional[Tensor]]:
"""
:return: the unmasked values of the parameters
(stores the last unmasked value before masking)
"""
params_unmasked = []
for idx in range(len(self._params)):
if self._params_unmasked[idx] is None:
params_unmasked.append(None)
else:
params_unmasked.append(
self._params[idx].data
                    + (self._param_masks[idx] == 0.0).type(self._params[idx].data.type())
* self._params_unmasked[idx]
)
return params_unmasked
@property
def params_grad(self) -> List[Optional[Tensor]]:
"""
:return: the current gradient values for each parameter
"""
return self._params_grad
@property
def score_type(self) -> str:
"""
:return: the scoring method used to create masks (i.e. magnitude, movement)
"""
return self._scorer.get_name()
    def set_param_data(self, value: Tensor, param_idx: int):
"""
:param value: the value to set as the current tensor for the parameter,
if enabled the mask will be applied
:param param_idx: index of the parameter in this object to set the data of
"""
if value is None:
raise ValueError("param data cannot be set to None")
if value.shape != self._params[param_idx].data.shape:
raise ValueError(
f"param_tensor shape of {value.shape} does not match parameter "
f"shape of {self._param[param_idx].shape}"
)
value = self._check_regen_value(value, param_idx)
self._check_regen_param_vals(param_idx)
self._params[param_idx].data.copy_(value)
self._params_unmasked[param_idx] = None
self._setup_params_unmasked(param_idx)
if not self._allow_reintroduction:
self.apply(param_idx)
    def set_param_masks(self, masks: List[Tensor]):
"""
:param masks: the masks to set and apply as the current param tensors,
            if enabled, the masks are applied immediately
"""
mask_diffs = []
for idx, value in enumerate(masks):
if value is None:
raise ValueError("mask cannot be set to None")
if value.shape != self._params[idx].shape:
raise ValueError(
"mask shape of {} does not match layer.param shape of {}".format(
value.shape, self._params[idx].shape
)
)
value = self._check_regen_value(value, idx)
self._check_regen_param_vals(idx)
mask_diff = mask_difference(self._param_masks[idx], value)
self._param_masks[idx] = value
mask_diffs.append(mask_diff)
self._scorer.mask_update(masks, mask_diffs)
if not self._allow_reintroduction:
self.apply()
return mask_diffs
    def set_param_masks_from_weights(self) -> List[Tensor]:
"""
Convenience function to set the parameter masks such that the
mask is 1 if a parameter value is non zero and 0 otherwise,
unless otherwise defined by this object's mask_creator
"""
masks = self._mask_creator.create_sparsity_masks_from_tensor(
[param.data for param in self._params]
)
return self.set_param_masks(masks)
    def set_param_masks_from_abs_threshold(
self, threshold: Union[float, Tensor]
) -> List[Tensor]:
"""
        Convenience function to set the parameter masks such that a value is
        masked to 0 if abs(value) <= threshold
        :param threshold: the threshold at which all values will be masked to 0
"""
param_scores = self._scorer.score_parameters()
masks = self._mask_creator.create_sparsity_masks_from_threshold(
param_scores, threshold
)
return self.set_param_masks(masks)
    def set_param_masks_from_sparsity(
self, sparsity: Union[float, List[float]]
) -> List[Tensor]:
"""
        Convenience function to set the parameter masks such that each mask has
        a fraction of masked values equal to the given sparsity. Masks the
        lowest scored (by default, smallest absolute) values until the target
        sparsity is reached.
        :param sparsity: the decimal sparsity to set the param masks to. Can
            also be a list where each element is the sparsity for the tensor in
            the same position in the tensor list. If global sparsity is enabled,
            all values of the sparsity list must be the same
"""
param_scores = self._scorer.score_parameters()
if isinstance(sparsity, float):
sparsity = [sparsity] * len(self._params)
for idx, sparsity_val in enumerate(sparsity):
applied_thinning = self._params_applied_thinning[idx]
if applied_thinning > 0.0:
# adjust sparsity for thinned (compressed) layer param
# derived from:
# remaining_num_els * (1 - adjusted_sparsity) = \
# orig_num_els * (1 - sparsity)
# with applied_thinning = 1 - (remaining_num_els / orig_num_els)
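                # e.g. (hypothetical numbers): orig_num_els = 100,
                # applied_thinning = 0.5 (50 elements remain), sparsity = 0.8
                # -> adjusted sparsity = (0.8 - 0.5) / (1 - 0.5) = 0.6, so 30
                # of the remaining 50 elements get masked and 80 of the
                # original 100 are removed or zeroed overall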
sparsity[idx] = (sparsity_val - applied_thinning) / (
1 - applied_thinning
)
masks = self._mask_creator.create_sparsity_masks(
param_scores, sparsity, global_sparsity=self._global_sparsity
)
self._scorer.update_last_applied_sparsity(sparsity)
return self.set_param_masks(masks)
    def apply(self, param_idx: Optional[int] = None):
"""
        Apply the current mask(s) to the parameter tensor(s), zeroing out the
        masked values
        :param param_idx: index of the parameter to apply the mask to. If not
            set, masks will be applied to all parameters with available masks
"""
if not self._enabled:
return
indices = range(len(self._params)) if param_idx is None else [param_idx]
for idx in indices:
self._check_regen_param_vals(idx)
with torch.no_grad():
if self._store_unmasked:
self._params_unmasked[idx] = self._params[idx].data.mul(
1 - self._param_masks[idx] # inverted mask
)
self._params[idx].data.mul_(self._param_masks[idx])
    def reset(self):
"""
        Resets the parameters to their stored initial data, regenerating any
        stored tensors onto the parameters' current device
"""
self._check_regen_param_vals()
for idx, param in enumerate(self._params):
param.data.copy_(self._params_init[idx])
    def pre_optim_step_update(self):
"""
updates scores and buffers that depend on gradients. Should be called
before Optimizer.step() to grab the latest gradients
"""
self._scorer.pre_optim_step_update(self._param_masks)
    def pruning_end(self, leave_enabled: bool):
"""
Performs any cleanup necessary for this pruning method.
Disables weight reintroduction if enabled and applies masks
:param leave_enabled: if False, all pruning hooks will be destroyed. Default
is True
"""
if not leave_enabled:
self.enabled = False
self._allow_reintroduction = False
self.apply() # ensure that weights are pruned to final level
self._scorer.on_pruning_end()
    def disable_reintroduction(self):
"""
if weight reintroduction is enabled (only during movement pruning),
disables further weight reintroduction
"""
self._allow_reintroduction = False
def _check_regen_value(self, val: Tensor, param_idx: int) -> Tensor:
if self._params[param_idx].data.device != val.device:
val = ModuleParamPruningMask._detach_tens(
torch.empty_like(self._params[param_idx].data).copy_(val)
)
return val
    def _check_regen_param_vals(self, param_idx: Optional[int] = None):
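        # regenerate masks and cached tensors when a parameter has been thinned
        # (compressed) or moved to a different device since the last update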
indices = range(len(self._params)) if param_idx is None else [param_idx]
for idx in indices:
if self._params[idx].data.numel() < self._param_masks[idx].numel():
self._params_applied_thinning[idx] = 1 - (
self._params[idx].data.numel() / self._param_orig_sizes[idx]
)
self._param_masks[idx] = ModuleParamPruningMask._detach_tens(
torch.ones_like(self._params[idx].data)
)
if self._params[idx].data.device != self._param_masks[idx].device:
self._param_masks[idx] = ModuleParamPruningMask._detach_tens(
torch.empty_like(self._params[idx].data).copy_(
self._param_masks[idx]
)
)
if (
self._params_init[idx] is not None
and self._params[idx].data.device != self._params_init[idx].device
):
self._params_init[idx] = ModuleParamPruningMask._detach_tens(
torch.empty_like(self._params[idx].data).copy_(
self._params_init[idx]
)
)
if (
self._params_unmasked[idx] is not None
and self._params[idx].data.device != self._params_unmasked[idx].device
):
self._params_unmasked[idx] = ModuleParamPruningMask._detach_tens(
torch.empty_like(self._params[idx].data).copy_(
self._params_unmasked[idx]
)
)
if (
self._params_grad[idx] is not None
and self._params[idx].data.device != self._params_grad[idx].device
):
self._params_grad[idx] = ModuleParamPruningMask._detach_tens(
torch.empty_like(self._params[idx].data).copy_(
self._params_grad[idx]
)
)
self._scorer.check_regen_param_vals()
def _create_hooks(self):
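        # forward pre-hooks apply the masks before each forward pass, gradient
        # hooks mask or track gradients, and forward hooks add the masked
        # values back when weight reintroduction is enabled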
for idx, (param, layer) in enumerate(zip(self._params, self._layers)):
if self._forward_hooks[idx] is None:
self._forward_hooks[idx] = layer.register_forward_pre_hook(
partial(self._hook_mask_forward, idx)
)
if self._allow_reintroduction and self._undo_mask_hooks[idx] is None:
self._undo_mask_hooks[idx] = layer.register_forward_hook(
partial(self._hook_undo_mask, idx)
)
if self._gradient_hooks[idx] is None:
self._gradient_hooks[idx] = param.register_hook(
partial(self._hook_mask_gradient, idx)
)
def _delete_hooks(self):
if not hasattr(self, "_params"):
return
for idx in range(len(self._params)):
if self._forward_hooks[idx] is not None:
self._forward_hooks[idx].remove()
self._forward_hooks[idx] = None
if self._undo_mask_hooks[idx] is not None:
self._undo_mask_hooks[idx].remove()
self._undo_mask_hooks[idx] = None
if self._gradient_hooks[idx] is not None:
self._gradient_hooks[idx].remove()
self._gradient_hooks[idx] = None
def _hook_mask_forward(
self, param_idx: int, mod: Module, inp: Union[Tensor, Tuple[Tensor]]
):
self.apply(param_idx)
def _hook_undo_mask(self, param_idx, module, inp, out):
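        # add the cached pre-mask values back after the forward pass so masked
        # weights keep receiving updates and can be reintroduced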
if self._allow_reintroduction:
with torch.no_grad():
self._params[param_idx].data.add_(self._params_unmasked[param_idx])
def _hook_mask_gradient(self, param_idx, grad):
if 0.0 <= self._track_grad_mom < 1.0:
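            # exponential moving average of the gradient:
            # grad_mom <- track_grad_mom * grad_mom + (1 - track_grad_mom) * grad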
self._params_grad[param_idx].mul_(self._track_grad_mom).add_(
(1.0 - self._track_grad_mom) * grad
)
return (
grad.mul_(self._param_masks[param_idx])
if not self._allow_reintroduction
else grad # do not mask gradient for movement pruning
)
def _setup_params_init(self):
for idx, param in enumerate(self._params):
if self._store_init and self._params_init[idx] is None:
self._params_init[idx] = ModuleParamPruningMask._detach_tens(
param.data.clone()
)
elif not self._store_init and self._params_init[idx] is not None:
self._params_init[idx] = None
    def _setup_params_unmasked(self, param_idx: Optional[int] = None):
indices = range(len(self._params)) if param_idx is None else [param_idx]
for idx in indices:
if self._store_unmasked and self._params_unmasked[idx] is None:
self._params_unmasked[idx] = ModuleParamPruningMask._detach_tens(
self._params[idx].data.clone()
)
elif not self._store_unmasked and self._params_unmasked[idx] is not None:
self._params_unmasked[idx] = None
def _setup_params_grad(self):
for idx, param in enumerate(self._params):
if self._track_grad_mom >= 0.0 and self._params_grad[idx] is None:
self._params_grad[idx] = ModuleParamPruningMask._detach_tens(
param.data.new_zeros(param.data.shape)
)
elif self._track_grad_mom < 0.0 and self._params_grad[idx] is not None:
self._params_grad[idx] = None
@staticmethod
def _detach_tens(tens) -> Tensor:
return tens.detach().requires_grad_(False)