# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Modifiers for inducing / enforcing kernel sparsity (model pruning)
on models while training.
"""

import hashlib
from typing import Any, Dict, List, Tuple, Union

from sparseml.sparsification import (
    ConstantPruningModifier as BaseConstantPruningModifier,
)
from sparseml.sparsification import GMPruningModifier as BaseGMPruningModifier
from sparseml.tensorflow_v1.optim.mask_creator_pruning import (
    PruningMaskCreator,
    load_mask_creator,
)
from sparseml.tensorflow_v1.optim.mask_pruning import (
    PruningOpVars,
    apply_op_vars_masks,
    create_ks_scheduled_constant_graph_ops,
    create_summaries_pruning,
    get_or_create_ks_scheduled_graph_ops,
)
from sparseml.tensorflow_v1.optim.modifier import (
    EXTRAS_KEY_SUMMARIES,
    ModifierProp,
    ScheduledModifier,
    ScheduledUpdateModifier,
    TensorFlowModifierYAML,
)
from sparseml.tensorflow_v1.utils import (
    clean_tensor_name,
    get_ops_and_inputs_by_name_or_regex,
    tf_compat,
)
from sparseml.utils import ALL_TOKEN


__all__ = ["ConstantPruningModifier", "GMPruningModifier"]
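
# Illustrative note (not part of the original module): these modifiers are
# typically instantiated from a YAML recipe rather than constructed directly.
# A minimal sketch, assuming the ScheduledModifierManager from
# sparseml.tensorflow_v1.optim (not imported in this file) is used to parse
# the recipe; treat the exact call as an assumption:
#
#   recipe = """
#   modifiers:
#       - !GMPruningModifier
#           params: __ALL__
#           init_sparsity: 0.05
#           final_sparsity: 0.8
#           start_epoch: 0.0
#           end_epoch: 10.0
#           update_frequency: 1.0
#   """
#   from sparseml.tensorflow_v1.optim import ScheduledModifierManager
#   manager = ScheduledModifierManager.from_yaml(recipe)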


@TensorFlowModifierYAML()
class ConstantPruningModifier(BaseConstantPruningModifier, ScheduledModifier):
    """
    Holds the sparsity level and shape for a given param constant while training.
    Useful for transfer learning use cases.

    | Sample yaml:
    |   !ConstantPruningModifier
    |       params: __ALL__
    |       start_epoch: 0.0
    |       end_epoch: 10.0
    |       log_types: __ALL__

    :param params: List of str names or regex patterns of names for the parameter
        variables to apply the pruning modifier to. Regex patterns must be
        specified with the prefix 're:'. Can also use the token __ALL__
        to specify all prunable layers and weights
    :param start_epoch: The epoch to start the modifier at
    :param end_epoch: The epoch to end the modifier at
    :param log_types: The loggers to allow the learning rate to be logged to,
        default is __ALL__
    """

    def __init__(
        self,
        params: Union[str, List[str]],
        start_epoch: float = -1,
        end_epoch: float = -1,
        log_types: Union[str, List[str]] = ALL_TOKEN,
    ):
        super(ConstantPruningModifier, self).__init__(
            params=params,
            log_types=log_types,
            start_epoch=start_epoch,
            end_epoch=end_epoch,
            end_comparator=None,
        )
        self._prune_op_vars = None
        self._update_ready = None
        self._sparsity = None

    @ModifierProp(serializable=False)
    def ks_group(self) -> str:
        """
        :return: a hashed representation of the settings that identify this instance
        """
        props = self.props(only_serializable=True, format_str=False)
        props = ["{}={}".format(key, val) for key, val in props.items()]
        props.sort()
        props = "&".join(props)

        return "{}".format(hashlib.md5(bytes(props, encoding="utf8")).hexdigest())

    @property
    def prune_op_vars(self) -> Union[None, List[PruningOpVars]]:
        """
        :return: the created pruning op vars in the graph if create_ops has been
            called, else None
        """
        return self._prune_op_vars

    @property
    def update_ready(self):
        """
        :return: the created update_ready tensor for setting the pruning ops
            if create_ops has been called, else None
        """
        return self._update_ready

    @property
    def sparsity(self) -> Union[None, tf_compat.Tensor]:
        """
        :return: the created sparsity tensor for setting the pruning ops
            if create_ops has been called, else None
        """
        return self._sparsity
    def create_ops(
        self,
        steps_per_epoch: int,
        global_step: tf_compat.Tensor,
        graph: tf_compat.Graph,
    ) -> Tuple[List[Union[tf_compat.Tensor, tf_compat.Operation]], Dict[str, Any]]:
        """
        Create the sparsity ops to modify the training graph according to the settings
        for the current instance.

        :param steps_per_epoch: the number of steps (batches) per training epoch
        :param global_step: the global step used while training
        :param graph: the graph to be modified
        :return: a tuple (list of ops, dict of named ops / tensors)
            to be run or used for modifying the training process.
        """
        mod_ops, mod_extras = super().create_ops(graph, None, None)
        start_step, end_step = self.start_end_steps(steps_per_epoch, after_optim=True)

        params = (
            self.params
            if self.params != ALL_TOKEN
            else [
                clean_tensor_name(var.name)
                for _, var in
                # Have ALL_TOKEN match to all variable names for now
                get_ops_and_inputs_by_name_or_regex(["re:.*"], graph)
            ]
        )

        with graph.as_default():
            update_op, prune_op_vars = create_ks_scheduled_constant_graph_ops(
                graph,
                global_step,
                params,
                start_step,
                end_step,
                self.ks_group,
            )

            if self.log_types == ALL_TOKEN or "tensorboard" in self.log_types:
                mod_extras[EXTRAS_KEY_SUMMARIES] = create_summaries_pruning(
                    prune_op_vars
                )

        mod_ops.append(update_op)
        self._prune_op_vars = prune_op_vars
        # self._update_ready = tf_compat.constant(False, name="nm_update_ready")

        return mod_ops, mod_extras
    def initialize_session(self, sess: tf_compat.Session):
        """
        Initialize the mask variables for pruning.

        :param sess: the session to use for initializing
        """
        super().initialize_session(sess)
        masks = [op_vars.mask for op_vars in self._prune_op_vars]

        if masks:
            sess.run(tf_compat.variables_initializer(masks))
    def complete_graph(self, graph: tf_compat.Graph, sess: tf_compat.Session):
        """
        Complete modifying the graph. Resets the pruned op's variables using
        the created masks to zero out the pruned weights for saving.

        :param graph: the modified graph that should be completed and cleaned.
            if not supplied, then will use the default graph
        :param sess: the session to use for completing the modified graph.
            if not supplied, then will use the default session
        :return: the cleaned graph
        """
        super().complete_graph(graph, sess)

        with graph.as_default():
            apply_op_vars_masks(self.prune_op_vars, self.ks_group, sess)
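
# Illustrative sketch (not part of the original module): one way this modifier
# could be wired into a tf.compat.v1 training loop for transfer learning from
# an already-pruned model. The `train_op`, `global_step`, `steps_per_epoch`,
# and `num_steps` names are hypothetical placeholders assumed to come from the
# surrounding training script:
#
#   modifier = ConstantPruningModifier(params="__ALL__", start_epoch=0.0)
#   with tf_compat.Graph().as_default() as graph:
#       ...  # build the model, loss, train_op, and global_step
#       mod_ops, _ = modifier.create_ops(steps_per_epoch, global_step, graph)
#       with tf_compat.Session(graph=graph) as sess:
#           sess.run(tf_compat.global_variables_initializer())
#           modifier.initialize_session(sess)  # initialize the mask variables
#           for _ in range(num_steps):
#               sess.run(train_op)
#               sess.run(mod_ops)  # re-apply the constant masks after each step
#           modifier.complete_graph(graph, sess)  # bake masks into the weights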

@TensorFlowModifierYAML()
class GMPruningModifier(BaseGMPruningModifier, ScheduledUpdateModifier):
    """
    Gradually applies kernel sparsity to a given variable or variables from
    init_sparsity until final_sparsity is reached over a given amount of time
    and applied with an interpolated function for each step taken.

    Applies based on magnitude pruning without any structure to the pruning.

    | Sample yaml:
    |   !GMPruningModifier
    |       params: __ALL__
    |       init_sparsity: 0.05
    |       final_sparsity: 0.8
    |       start_epoch: 0.0
    |       end_epoch: 10.0
    |       update_frequency: 1.0
    |       inter_func: cubic
    |       log_types: __ALL__
    |       mask_type: unstructured
    |       leave_enabled: True

    :param params: List of str names or name regex patterns for the variables in the
        graph to apply the pruning modifier to. Regex patterns must be specified
        with the prefix 're:'. __ALL__ will match to all parameters.
    :param init_sparsity: The initial sparsity for the variable to
        start with at start_epoch
    :param final_sparsity: The final sparsity for the variable to end with at end_epoch
    :param start_epoch: The epoch to start the modifier at
    :param end_epoch: The epoch to end the modifier at
    :param update_frequency: The number of epochs or fraction of epochs to
        update at between start and end
    :param leave_enabled: True to continue masking the weights after end_epoch,
        False to stop masking. Should be set to False if exporting the result
        immediately after or doing some other prune
    :param inter_func: The type of interpolation function to use:
        [linear, cubic, inverse_cubic]
    :param log_types: The loggers to allow the learning rate to be logged to,
        default is __ALL__
    :param mask_type: String to define type of sparsity (options: ['unstructured',
        'channel', 'filter']), List to define block shape of a parameter's in and out
        channels, or a SparsityMaskCreator object. default is 'unstructured'
    """

    def __init__(
        self,
        params: Union[str, List[str]],
        init_sparsity: float,
        final_sparsity: float,
        start_epoch: float,
        end_epoch: float,
        update_frequency: float,
        inter_func: str = "cubic",
        log_types: Union[str, List[str]] = ALL_TOKEN,
        mask_type: Union[str, List[int], PruningMaskCreator] = "unstructured",
        leave_enabled: bool = True,
    ):
        super(GMPruningModifier, self).__init__(
            params=params,
            init_sparsity=init_sparsity,
            final_sparsity=final_sparsity,
            start_epoch=start_epoch,
            end_epoch=end_epoch,
            update_frequency=update_frequency,
            inter_func=inter_func,
            log_types=log_types,
            mask_type=mask_type,
            leave_enabled=leave_enabled,
            min_start=-1.0,
            min_end=0.0,
            end_comparator=1,
            min_frequency=-1.0,
        )
        self._mask_creator = mask_type
        if not isinstance(mask_type, PruningMaskCreator):
            self._mask_creator = load_mask_creator(mask_type)
        self._prune_op_vars = None
        self._update_ready = None
        self._sparsity = None
        self._mask_initializer = None

    @ModifierProp(serializable=False)
    def ks_group(self) -> str:
        """
        :return: a hashed representation of the settings that identify this instance
        """
        props = self.props(only_serializable=True, format_str=False)
        props = ["{}={}".format(key, val) for key, val in props.items()]
        props.sort()
        props = "&".join(props)

        return "{}".format(hashlib.md5(bytes(props, encoding="utf8")).hexdigest())

    @ModifierProp(serializable=False)
    def exponent(self) -> float:
        """
        :return: the exponent to be used in the sparsity schedule
        """
        if self._inter_func == "linear":
            return 1.0

        if self._inter_func == "cubic":
            return 3.0

        if self._inter_func == "inverse_cubic":
            return 1 / 3.0

        raise ValueError(
            "unrecognized value given for inter_func of {}".format(self._inter_func)
        )

    @property
    def prune_op_vars(self) -> Union[None, List[PruningOpVars]]:
        """
        :return: the created pruning op vars in the graph if create_ops has been
            called, else None
        """
        return self._prune_op_vars

    @property
    def update_ready(self):
        """
        :return: the created update_ready tensor for setting the pruning ops
            if create_ops has been called, else None
        """
        return self._update_ready

    @property
    def sparsity(self) -> Union[None, tf_compat.Tensor]:
        """
        :return: the created sparsity tensor for setting the pruning ops
            if create_ops has been called, else None
        """
        return self._sparsity
    def create_ops(
        self,
        steps_per_epoch: int,
        global_step: tf_compat.Tensor,
        graph: tf_compat.Graph,
    ) -> Tuple[List[Union[tf_compat.Tensor, tf_compat.Operation]], Dict[str, Any]]:
        """
        Create the sparsity ops to modify the training graph according to the settings
        for the current instance.

        :param steps_per_epoch: the number of steps (batches) per training epoch
        :param global_step: the global step used while training
        :param graph: the graph to be modified
        :return: a tuple (list of ops, dict of named ops / tensors)
            to be run or used for modifying the training process.
        """
        mod_ops, mod_extras = super().create_ops(graph, steps_per_epoch, global_step)
        start_step, end_step = self.start_end_steps(steps_per_epoch, after_optim=True)
        update_frequency_step = self.update_frequency_steps(steps_per_epoch)
        params = (
            self._params
            if self._params != ALL_TOKEN
            else [
                clean_tensor_name(var.name)
                for _, var in
                # Have ALL_TOKEN match to all variable names for now
                get_ops_and_inputs_by_name_or_regex(["re:.*"], graph)
            ]
        )

        with graph.as_default():
            (
                update_op,
                prune_op_vars,
                update_ready,
                sparsity,
            ) = get_or_create_ks_scheduled_graph_ops(
                graph,
                global_step,
                params,
                start_step,
                end_step,
                update_frequency_step,
                self._init_sparsity,
                self._final_sparsity,
                self.exponent,
                self._leave_enabled,
                self.ks_group,
                self._mask_creator,
            )

            if self.log_types == ALL_TOKEN or "tensorboard" in self.log_types:
                mod_extras[EXTRAS_KEY_SUMMARIES] = create_summaries_pruning(
                    prune_op_vars
                )

        mod_ops.append(update_op)
        self._prune_op_vars = prune_op_vars
        self._update_ready = update_ready
        self._sparsity = sparsity

        # Create and cache the mask initializers to be run
        # through initialize_session. When using the estimator,
        # the initialization is done as part of the init_fn of
        # the training scaffold object, at which point the graph
        # cannot be changed (hence the creation and caching)
        masks = [op_vars.mask for op_vars in self._prune_op_vars]
        self._mask_initializer = (
            tf_compat.variables_initializer(masks) if masks else None
        )

        return mod_ops, mod_extras
    def initialize_session(self, sess: tf_compat.Session):
        """
        Initialize the mask variables for pruning.

        :param sess: the session to use for initializing
        """
        super().initialize_session(sess)
        if self._mask_initializer:
            sess.run(self._mask_initializer)
    def complete_graph(self, graph: tf_compat.Graph, sess: tf_compat.Session):
        """
        Complete modifying the graph. Resets the pruned op's variables using
        the created masks to zero out the pruned weights for saving.

        :param graph: the modified graph that should be completed and cleaned.
            if not supplied, then will use the default graph
        :param sess: the session to use for completing the modified graph.
            if not supplied, then will use the default session
        :return: the cleaned graph
        """
        super().complete_graph(graph, sess)

        with graph.as_default():
            apply_op_vars_masks(self.prune_op_vars, self.ks_group, sess)
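
# Illustrative sketch (not part of the original module). The `exponent`
# property above implies the scheduled sparsity interpolates between
# init_sparsity and final_sparsity as a power of normalized progress; one
# plausible realization (an assumption here, since the actual op graph is
# built by get_or_create_ks_scheduled_graph_ops) is:
#
#     progress = (step - start_step) / (end_step - start_step)  # clipped to [0, 1]
#     sparsity = init_sparsity + (final_sparsity - init_sparsity) * progress ** exponent
#
# Usage sketch; `train_op`, `global_step`, `steps_per_epoch`, and `num_steps`
# are hypothetical names supplied by the surrounding training script:
#
#   modifier = GMPruningModifier(
#       params="__ALL__",
#       init_sparsity=0.05,
#       final_sparsity=0.8,
#       start_epoch=0.0,
#       end_epoch=10.0,
#       update_frequency=1.0,
#   )
#   with tf_compat.Graph().as_default() as graph:
#       ...  # build the model, loss, train_op, and global_step
#       mod_ops, _ = modifier.create_ops(steps_per_epoch, global_step, graph)
#       with tf_compat.Session(graph=graph) as sess:
#           sess.run(tf_compat.global_variables_initializer())
#           modifier.initialize_session(sess)  # run the cached mask initializer
#           for _ in range(num_steps):
#               sess.run(train_op)
#               sess.run(mod_ops)  # step sparsity / masks when an update is ready
#           modifier.complete_graph(graph, sess)  # zero pruned weights for saving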