Source code for sparseml.pytorch.optim.mask_pruning

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Code related to applying a mask onto a parameter to impose kernel sparsity,
aka model pruning.

NOTE: this file is in the process of being phased out in favor of the
sparsification package. Once all references to mask utils in the optim
package are migrated, this file will be deleted.
"""

from functools import partial
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn import Module, Parameter

from sparseml.pytorch.optim.mask_creator_pruning import (
    PruningMaskCreator,
    UnstructuredPruningMaskCreator,
)
from sparseml.pytorch.optim.mask_pruning_scorer import create_pruning_param_scorer
from sparseml.pytorch.utils import mask_difference

# public API of this module
__all__ = ["ModuleParamPruningMask"]

class ModuleParamPruningMask(object):
    """
    Mask to apply kernel sparsity (model pruning) to a specific parameter in a layer

    :param layers: the layers containing the parameters to mask
    :param param_names: the names of the parameter to mask in each layer. If only
        one name is given, that name will be applied to all layers that this object
        masks. default is weight
    :param store_init: store the init weights in a separate variable that can be
        used and referenced later
    :param store_unmasked: store the unmasked weights in a separate variable that
        can be used and referenced later
    :param track_grad_mom: store the gradient updates to the parameter with a
        momentum variable must be in the range [0.0, 1.0), if set to 0.0 then will
        only keep most recent
    :param mask_creator: object to define sparsity mask creation. If None (the
        default), a new unstructured mask creator is built for this instance
    :param layer_names: the name of the layers the parameters to mask are located in
    :param global_sparsity: set True to enable global pruning. if True, when
        creating sparsity masks for a target sparsity sparsity masks will be created
        such that the average sparsity across all given layers is the target
        sparsity with the lowest global values masked. If False, each layer will be
        masked to the target sparsity ranking values within each individual tensor.
        Default is False
    :param score_type: the method used to score parameters for masking, i.e.
        'magnitude', 'movement'. Default is 'magnitude'
    """

    def __init__(
        self,
        layers: List[Module],
        param_names: Union[str, List[str]] = "weight",
        store_init: bool = False,
        store_unmasked: bool = False,
        track_grad_mom: float = -1.0,
        mask_creator: Optional["PruningMaskCreator"] = None,
        layer_names: Optional[List[str]] = None,
        global_sparsity: bool = False,
        score_type: str = "magnitude",
    ):
        self._layers = layers
        self._param_names = (
            param_names
            if isinstance(param_names, list)
            else [param_names] * len(self._layers)
        )
        self._layer_names = layer_names
        self._store_init = store_init
        self._store_unmasked = store_unmasked
        self._track_grad_mom = track_grad_mom
        # build the default creator per instance rather than sharing one object
        # created at function definition time
        self._mask_creator = (
            mask_creator
            if mask_creator is not None
            else UnstructuredPruningMaskCreator()
        )
        self._global_sparsity = global_sparsity

        self._enabled = False
        self._forward_hooks = [None] * len(self._layers)
        self._undo_mask_hooks = [None] * len(self._layers)
        self._gradient_hooks = [None] * len(self._layers)

        self._params: List[Parameter] = []
        self._param_orig_sizes: List[int] = []
        for layer, param_name in zip(self._layers, self._param_names):
            try:
                layer_param = getattr(layer, param_name)
                self._params.append(layer_param)
                self._param_orig_sizes.append(layer_param.numel())
            except Exception as err:
                raise RuntimeError(
                    f"Error occurred while trying to get param {param_name} "
                    f"in layer {layer}: {err}"
                ) from err

        # initialize masks to all ones
        self._param_masks = [torch.ones(param.shape) for param in self._params]
        self._params_init: List[Optional[Tensor]] = [None] * len(self._layers)
        self._params_unmasked: List[Optional[Tensor]] = [None] * len(self._layers)
        self._params_grad: List[Optional[Tensor]] = [None] * len(self._layers)
        self._params_movement: List[Optional[Tensor]] = [None] * len(self._layers)
        self._params_applied_thinning: List[float] = [0.0] * len(self._layers)

        # create scorer
        self._scorer = create_pruning_param_scorer(self._params, score_type)

        # movement pruning requires weight reintroduction, which in turn needs
        # the unmasked values to be stored
        self._allow_reintroduction = self._scorer.get_name() == "movement"
        self._store_unmasked = self._store_unmasked or self._allow_reintroduction

        self._setup_params_init()
        self._setup_params_unmasked()
        self._setup_params_grad()

    def __len__(self):
        return len(self._layers)

    def __del__(self):
        self._delete_hooks()

    @property
    def layers(self) -> List[Module]:
        """
        :return: the layers containing the parameters to mask
        """
        return self._layers

    @property
    def param_names(self) -> List[str]:
        """
        :return: the names of the parameters to mask in the layers
        """
        return self._param_names

    @property
    def layer_names(self) -> Optional[List[str]]:
        """
        :return: the names of the layers the parameter to mask is located in
        """
        return self._layer_names

    @property
    def names(self) -> List[str]:
        """
        :return: the full names of the sparsity masks in the following format:
            <LAYER>.<PARAM>.sparsity_mask. Built from layer_names, so those must
            have been given at construction
        """
        return [
            f"{layer_name}.{param_name}.sparsity_mask"
            for layer_name, param_name in zip(self._layer_names, self._param_names)
        ]

    @property
    def store_init(self) -> bool:
        """
        :return: store the init weights in a separate variable that can be used and
            referenced later
        """
        return self._store_init

    @property
    def store_unmasked(self) -> bool:
        """
        :return: store the unmasked weights in a separate variable that can be used
            and referenced later
        """
        return self._store_unmasked

    @property
    def track_grad_mom(self) -> float:
        """
        :return: store the gradient updates to the parameter with a momentum
            variable must be in the range [0.0, 1.0), if set to 0.0 then will only
            keep most recent
        """
        return self._track_grad_mom

    @property
    def mask_creator(self) -> "PruningMaskCreator":
        """
        :return: SparsityMaskCreator object used to generate masks
        """
        return self._mask_creator

    @property
    def global_sparsity(self) -> bool:
        """
        :return: True if global pruning is enabled, False otherwise
        """
        return self._global_sparsity

    @property
    def enabled(self) -> bool:
        """
        :return: True if the parameter is currently being masked, False otherwise
        """
        return self._enabled

    @enabled.setter
    def enabled(self, value: bool):
        """
        :param value: True to mask the parameter, False otherwise
        """
        if value and not self._enabled:
            self._create_hooks()
            # start gradient tracking from a clean state each time masking turns on
            self._params_grad = [None] * len(self._params)
            self._setup_params_grad()
        elif not value and self._enabled:
            self._delete_hooks()

        self._enabled = value

    @property
    def allow_reintroduction(self) -> bool:
        """
        :return: True if weight reintroduction is allowed
        """
        return self._allow_reintroduction

    @property
    def params_data(self) -> List[Tensor]:
        """
        :return: the current tensors in each of the parameters
        """
        return [param.data for param in self._params]

    @property
    def param_masks(self) -> List[Tensor]:
        """
        :return: the current masks applied to each of the parameters
        """
        return self._param_masks

    @property
    def params_init(self) -> List[Optional[Tensor]]:
        """
        :return: the initial values of the parameters before being masked
        """
        return self._params_init

    @property
    def params_unmasked(self) -> List[Optional[Tensor]]:
        """
        :return: the unmasked values of the parameters
            (stores the last unmasked value before masking).
            An entry is None when no unmasked value is stored for that parameter
        """
        params_unmasked = []

        for idx, param in enumerate(self._params):
            stored = self._params_unmasked[idx]
            if stored is None:
                params_unmasked.append(None)
                continue

            # current values plus the stored values at the masked-out positions
            masked_out = (self._param_masks[idx] == 0.0).type(param.data.type())
            params_unmasked.append(param.data + masked_out * stored)

        return params_unmasked

    @property
    def params_grad(self) -> List[Optional[Tensor]]:
        """
        :return: the current gradient values for each parameter
        """
        return self._params_grad

    @property
    def score_type(self) -> str:
        """
        :return: the scoring method used to create masks (i.e. magnitude, movement)
        """
        return self._scorer.get_name()

    def set_param_data(self, value: Tensor, param_idx: int):
        """
        :param value: the value to set as the current tensor for the parameter,
            if enabled the mask will be applied
        :param param_idx: index of the parameter in this object to set the data of
        :raises ValueError: if value is None or its shape does not match the
            parameter's shape
        """
        if value is None:
            raise ValueError("param data cannot be set to None")

        if value.shape != self._params[param_idx].data.shape:
            raise ValueError(
                f"param_tensor shape of {value.shape} does not match parameter "
                f"shape of {self._params[param_idx].shape}"
            )

        value = self._check_regen_value(value, param_idx)
        self._check_regen_param_vals(param_idx)
        self._params[param_idx].data.copy_(value)
        # drop the stale unmasked copy and re-create it from the new data
        self._params_unmasked[param_idx] = None
        self._setup_params_unmasked(param_idx)

        if not self._allow_reintroduction:
            self.apply(param_idx)

    def set_param_masks(self, masks: List[Tensor]):
        """
        :param masks: the masks to set and apply as the current param tensors,
            if enabled mask is applied immediately
        :return: the result of mask_difference between the previous and the new
            mask, one entry per given mask
        :raises ValueError: if a mask is None or its shape does not match the
            matching parameter's shape
        """
        mask_diffs = []

        for idx, value in enumerate(masks):
            if value is None:
                raise ValueError("mask cannot be set to None")

            if value.shape != self._params[idx].shape:
                raise ValueError(
                    "mask shape of {} does not match layer.param shape of {}".format(
                        value.shape, self._params[idx].shape
                    )
                )

            value = self._check_regen_value(value, idx)
            self._check_regen_param_vals(idx)

            mask_diffs.append(mask_difference(self._param_masks[idx], value))
            self._param_masks[idx] = value

        self._scorer.mask_update(masks, mask_diffs)

        if not self._allow_reintroduction:
            self.apply()

        return mask_diffs

    def set_param_masks_from_weights(self) -> List[Tensor]:
        """
        Convenience function to set the parameter masks such that the
        mask is 1 if a parameter value is non zero and 0 otherwise,
        unless otherwise defined by this object's mask_creator

        :return: the value returned by set_param_masks for the created masks
        """
        masks = self._mask_creator.create_sparsity_masks_from_tensor(
            [param.data for param in self._params]
        )

        return self.set_param_masks(masks)

    def set_param_masks_from_abs_threshold(
        self, threshold: Union[float, Tensor]
    ) -> List[Tensor]:
        """
        Convenience function to set the parameter masks such that if
        abs(value) <= threshold then the value is masked to 0

        :param threshold: the threshold at which all values will be masked to 0
        :return: the value returned by set_param_masks for the created masks
        """
        param_scores = self._scorer.score_parameters()
        masks = self._mask_creator.create_sparsity_masks_from_threshold(
            param_scores, threshold
        )

        return self.set_param_masks(masks)

    def set_param_masks_from_sparsity(
        self, sparsity: Union[float, List[float]]
    ) -> List[Tensor]:
        """
        Convenience function to set the parameter masks such that each masks have an
        amount of masked values such that the percentage equals the sparsity amount
        given. Masks the absolute smallest values up until sparsity is reached.

        :param sparsity: the decimal sparsity to set the param mask to can also be a
            list where each element is a sparsity for a tensor in the same position
            in the tensor list. If global sparsity is enabled, all values of the
            sparsity list must be the same. A given list is not modified
        :return: the value returned by set_param_masks for the created masks
        """
        param_scores = self._scorer.score_parameters()

        if isinstance(sparsity, (int, float)):
            sparsity = [sparsity] * len(self._params)
        else:
            # work on a copy so the caller's list is not rewritten below
            sparsity = list(sparsity)

        for idx, sparsity_val in enumerate(sparsity):
            applied_thinning = self._params_applied_thinning[idx]
            if applied_thinning > 0.0:
                # adjust sparsity for thinned (compressed) layer param
                # derived from:
                # remaining_num_els * (1 - adjusted_sparsity) = \
                #     orig_num_els * (1 - sparsity)
                # with applied_thinning = 1 - (remaining_num_els / orig_num_els)
                sparsity[idx] = (sparsity_val - applied_thinning) / (
                    1 - applied_thinning
                )

        masks = self._mask_creator.create_sparsity_masks(
            param_scores, sparsity, global_sparsity=self._global_sparsity
        )
        self._scorer.update_last_applied_sparsity(sparsity)

        return self.set_param_masks(masks)

    def apply(self, param_idx: Optional[int] = None):
        """
        apply the current mask to the params tensor (zero out the desired values).
        Does nothing while this object is not enabled

        :param param_idx: index of parameter to apply mask to. if not set, then
            masks will be applied to all parameters with available masks
        """
        if not self._enabled:
            return

        indices = range(len(self._params)) if param_idx is None else [param_idx]

        for idx in indices:
            self._check_regen_param_vals(idx)

            with torch.no_grad():
                if self._store_unmasked:
                    # keep the values that are about to be zeroed
                    self._params_unmasked[idx] = self._params[idx].data.mul(
                        1 - self._param_masks[idx]  # inverted mask
                    )
                self._params[idx].data.mul_(self._param_masks[idx])

    def reset(self):
        """
        resets the current stored tensors such that they will be on the same device
        and have the initial data
        """
        self._check_regen_param_vals()
        for idx, param in enumerate(self._params):
            param.data.copy_(self._params_init[idx])

    def pre_optim_step_update(self):
        """
        updates scores and buffers that depend on gradients. Should be called
        before Optimizer.step() to grab the latest gradients
        """
        self._scorer.pre_optim_step_update(self._param_masks)

    def pruning_end(self, leave_enabled: bool = True):
        """
        Performs any cleanup necessary for this pruning method. Disables weight
        reintroduction if enabled and applies masks

        :param leave_enabled: if False, all pruning hooks will be destroyed.
            Default is True
        """
        self._allow_reintroduction = False
        # apply before disabling: apply() is a no-op once enabled is False, which
        # would otherwise leave the weights unpruned when leave_enabled is False
        self.apply()  # ensure that weights are pruned to final level
        if not leave_enabled:
            self.enabled = False
        self._scorer.on_pruning_end()

    def disable_reintroduction(self):
        """
        if weight reintroduction is enabled (only during movement pruning),
        disables further weight reintroduction
        """
        self._allow_reintroduction = False

    def _regen_like_param(self, tens: Tensor, param_idx: int) -> Tensor:
        # detached copy of tens written into a new tensor created with
        # torch.empty_like from the parameter's current data
        return ModuleParamPruningMask._detach_tens(
            torch.empty_like(self._params[param_idx].data).copy_(tens)
        )

    def _check_regen_value(self, val: Tensor, param_idx: int) -> Tensor:
        # move an incoming tensor next to the parameter if it is on another device
        if self._params[param_idx].data.device != val.device:
            val = self._regen_like_param(val, param_idx)

        return val

    def _check_regen_param_vals(self, param_idx: Optional[int] = None):
        # keep the mask and every stored tensor consistent with the parameter:
        # reset the mask when the parameter has fewer elements than the mask and
        # follow the parameter across devices
        indices = range(len(self._params)) if param_idx is None else [param_idx]

        for idx in indices:
            param_data = self._params[idx].data

            if param_data.numel() < self._param_masks[idx].numel():
                # parameter is smaller than its mask: record the removed fraction
                # and restart from an all ones mask of the new shape
                self._params_applied_thinning[idx] = 1 - (
                    param_data.numel() / self._param_orig_sizes[idx]
                )
                self._param_masks[idx] = ModuleParamPruningMask._detach_tens(
                    torch.ones_like(param_data)
                )

            if param_data.device != self._param_masks[idx].device:
                self._param_masks[idx] = self._regen_like_param(
                    self._param_masks[idx], idx
                )

            for stored in (
                self._params_init,
                self._params_unmasked,
                self._params_grad,
            ):
                if stored[idx] is not None and param_data.device != stored[idx].device:
                    stored[idx] = self._regen_like_param(stored[idx], idx)

        self._scorer.check_regen_param_vals()

    def _create_hooks(self):
        # register the masking hooks once per layer / parameter
        for idx, (param, layer) in enumerate(zip(self._params, self._layers)):
            if self._forward_hooks[idx] is None:
                self._forward_hooks[idx] = layer.register_forward_pre_hook(
                    partial(self._hook_mask_forward, idx)
                )

            if self._allow_reintroduction and self._undo_mask_hooks[idx] is None:
                self._undo_mask_hooks[idx] = layer.register_forward_hook(
                    partial(self._hook_undo_mask, idx)
                )

            if self._gradient_hooks[idx] is None:
                self._gradient_hooks[idx] = param.register_hook(
                    partial(self._hook_mask_gradient, idx)
                )

    def _delete_hooks(self):
        # also reached from __del__, where __init__ may have failed before
        # the parameters were collected
        if not hasattr(self, "_params"):
            return

        for hooks in (
            self._forward_hooks,
            self._undo_mask_hooks,
            self._gradient_hooks,
        ):
            for idx in range(len(self._params)):
                if hooks[idx] is not None:
                    hooks[idx].remove()
                    hooks[idx] = None

    def _hook_mask_forward(
        self, param_idx: int, mod: Module, inp: Union[Tensor, Tuple[Tensor]]
    ):
        # forward pre hook: mask the parameter before the layer runs
        self.apply(param_idx)

    def _hook_undo_mask(self, param_idx, module, inp, out):
        # forward hook: add the stored unmasked values back after the layer ran
        if self._allow_reintroduction:
            with torch.no_grad():
                self._params[param_idx].data.add_(self._params_unmasked[param_idx])

    def _hook_mask_gradient(self, param_idx, grad):
        if 0.0 <= self._track_grad_mom < 1.0:
            # running average of the incoming (not yet masked) gradient
            self._params_grad[param_idx].mul_(self._track_grad_mom).add_(
                (1.0 - self._track_grad_mom) * grad
            )

        if self._allow_reintroduction:
            # do not mask gradient for movement pruning
            return grad

        return grad.mul_(self._param_masks[param_idx])

    def _setup_params_init(self):
        # store (or drop) a copy of each parameter's data according to store_init
        for idx, param in enumerate(self._params):
            if self._store_init and self._params_init[idx] is None:
                self._params_init[idx] = ModuleParamPruningMask._detach_tens(
                    param.data.clone()
                )
            elif not self._store_init and self._params_init[idx] is not None:
                self._params_init[idx] = None

    def _setup_params_unmasked(self, param_idx: Optional[int] = None):
        # store (or drop) a copy of the parameter data according to store_unmasked
        indices = range(len(self._params)) if param_idx is None else [param_idx]

        for idx in indices:
            if self._store_unmasked and self._params_unmasked[idx] is None:
                self._params_unmasked[idx] = ModuleParamPruningMask._detach_tens(
                    self._params[idx].data.clone()
                )
            elif not self._store_unmasked and self._params_unmasked[idx] is not None:
                self._params_unmasked[idx] = None

    def _setup_params_grad(self):
        # create (or drop) the gradient tracking buffers; tracking is on when
        # track_grad_mom is not negative
        for idx, param in enumerate(self._params):
            if self._track_grad_mom >= 0.0 and self._params_grad[idx] is None:
                self._params_grad[idx] = ModuleParamPruningMask._detach_tens(
                    param.data.new_zeros(param.data.shape)
                )
            elif self._track_grad_mom < 0.0 and self._params_grad[idx] is not None:
                self._params_grad[idx] = None

    @staticmethod
    def _detach_tens(tens) -> Tensor:
        return tens.detach().requires_grad_(False)