Source code for sparseml.tensorflow_v1.optim.mask_pruning

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Code related to applying a mask onto a variable to impose kernel sparsity,
aka model pruning, on a TensorFlow graph.
"""

from collections import namedtuple
from typing import List, Tuple


try:
    import tensorflow.contrib.graph_editor as graph_editor

    tf_contrib_err = None
except Exception as err:
    graph_editor = None
    tf_contrib_err = err

from sparseml.tensorflow_v1.optim.mask_creator_pruning import PruningMaskCreator
from sparseml.tensorflow_v1.utils import (
    clean_tensor_name,
    get_ops_and_inputs_by_name_or_regex,
    get_tensor_var,
    is_prunable_op,
    tf_compat,
    tf_compat_div,
)


__all__ = [
    "PruningOpVars",
    "PruningScope",
    "create_op_pruning",
    "create_graph_ops_pruning",
    "create_ks_scheduled_constant_graph_ops",
    "get_or_create_graph_ops_pruning",
    "apply_op_vars_masks",
    "create_summaries_pruning",
    "create_ks_schedule_ops",
    "get_or_create_ks_schedule_ops",
    "get_or_create_ks_scheduled_graph_ops",
]


PruningOpVars = namedtuple(
    "PruningOpVars", ["op", "op_input", "update", "mask", "masked"]
)


class PruningScope(object):
    """
    Convenience class for dealing with scope and names for kernel sparsity
    in the tf graph.
    """

    NM_KS = "nm_ks"
    NM_KS_OPS = "nm_ks_ops"

    OPS = "ops"
    OPS_INPUT = "input_ops"
    OPS_UPDATE = "update_ops"
    OPS_SUMMARY = "summary_ops"
    OPS_SCHEDULE = "schedule_ops"
    OPS_SPARSITY = "sparsity_ops"

    OP_COND_UPDATE = "nm_conditional_update"
    OP_SPARSITY = "nm_sparsity"
    OP_UPDATE_READY = "nm_update_ready"
    OP_MASKED_VAR = "nm_masked_var"
    OP_MASK_ASSIGN = "nm_mask_assign"
    OP_PRUNE_VARS_ASSIGN = "nm_prune_vars_assign"
    OP_MASK_UPDATE_NO_OP = "nm_mask_update_no_op"
    OP_MASK_UPDATE = "nm_mask_update"
    OP_WEIGHT_UPDATE = "nm_weight_update"
    OP_SAVE = "nm_save"

    VAR_MASK = "nm_mask"
    VAR_THRESHOLD = "nm_threshold"

    @staticmethod
    def general(ks_group: str, additional: str = None, trailing_slash: bool = False):
        """
        Create a general kernel sparsity scope in the tf graph.
        Use cases are for generic ops like target sparsity, conditional updates, etc.

        :param ks_group: the group identifier the scope should be created under
        :param additional: any additional scope that should be added to the end
        :param trailing_slash: include a trailing forward slash if True, else False
        :return: the proper scope
        """
        scope = PruningScope._format(PruningScope.NM_KS_OPS, ks_group)
        scope = PruningScope._format(
            scope, additional=additional, trailing_slash=trailing_slash
        )

        return scope

    @staticmethod
    def model(
        op_tens: tf_compat.Tensor,
        ks_group: str,
        additional: str = None,
        trailing_slash: bool = False,
    ) -> str:
        """
        Create a model specific kernel sparsity scope in the tf graph.
        Use cases are for the specific mask, threshold, etc. variables
        to induce sparsity along with the ops to update those vars.

        :param op_tens: the op tensor to create the scope for
        :param ks_group: the group identifier the scope should be created under
        :param additional: any additional scope that should be added to the end
        :param trailing_slash: include a trailing forward slash if True, else False
        :return: the proper scope
        """
        op_name = clean_tensor_name(op_tens)
        scope = PruningScope._format(
            "{}_{}".format(op_name, PruningScope.NM_KS), ks_group
        )
        scope = PruningScope._format(
            scope, additional=additional, trailing_slash=trailing_slash
        )

        return scope

    @staticmethod
    def collection_name(ks_group: str, name: str) -> str:
        """
        Create a predictable name for a given variable / op in a group for
        lookup / storage in a collection

        :param ks_group: the group identifier the name belongs under
        :param name: the name of the op or variable to be stored or retrieved
        :return: the formatted name for use in a collection
        """
        return "nm_ks_collection_{}_{}".format(ks_group, name)

    @staticmethod
    def _format(
        current: str, additional: str = None, trailing_slash: bool = False
    ) -> str:
        scope = current

        if additional is not None:
            scope = "{}/{}".format(current, additional)

        if trailing_slash:
            scope += "/"

        return scope
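

# Illustrative sketch (not part of the original module): what the scope helpers
# above resolve to for a hypothetical group name "pruning_group". The values in
# the comments follow directly from _format and collection_name; the model
# scope additionally depends on clean_tensor_name(op_tens), so it is graph
# dependent.
def _example_pruning_scope_names():  # hypothetical helper, for illustration only
    group = "pruning_group"
    general_scope = PruningScope.general(group)  # "nm_ks_ops/pruning_group"
    mask_collection = PruningScope.collection_name(group, PruningScope.VAR_MASK)
    # -> "nm_ks_collection_pruning_group_nm_mask"

    return general_scope, mask_collection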


def create_op_pruning_no_update(
    op: tf_compat.Operation,
    op_input: tf_compat.Tensor,
    ks_group: str,
    leave_enabled: bool = True,
    is_after_end_step: tf_compat.Tensor = None,
) -> PruningOpVars:
    """
    Creates the necessary variables and operators to gradually
    apply sparsity to an operator's variable without returning a
    PruningOpVars.update value.

    :param op: the operation to prune to the given sparsity
    :param op_input: the parameter within the op to create a mask for
    :param ks_group: the group identifier the scope should be created under
    :param leave_enabled: True to continue masking the weights after end_epoch,
        False to stop masking
    :param is_after_end_step: only should be provided if leave_enabled is False;
        tensor that is true if the current global step is after end_epoch
    :return: a named tuple containing the assignment op, mask variable,
        threshold tensor, and masked tensor
    """
    if tf_contrib_err:
        raise tf_contrib_err

    op_sgv = graph_editor.sgv(op)

    # create the necessary variables first
    with tf_compat.variable_scope(
        PruningScope.model(op, ks_group), reuse=tf_compat.AUTO_REUSE
    ):
        mask = tf_compat.get_variable(
            PruningScope.VAR_MASK,
            op_input.get_shape(),
            initializer=tf_compat.ones_initializer(),
            trainable=False,
            dtype=op_input.dtype,
        )
    tf_compat.add_to_collection(
        PruningScope.collection_name(ks_group, PruningScope.VAR_MASK), mask
    )

    # create the masked operation and assign as the new input to the op
    with tf_compat.name_scope(PruningScope.model(op, ks_group, trailing_slash=True)):
        masked = tf_compat.multiply(mask, op_input, PruningScope.OP_MASKED_VAR)
        op_inp_tens = (
            masked
            if leave_enabled
            else tf_compat.cond(is_after_end_step, lambda: op_input, lambda: masked)
        )
        op_swapped_inputs = [
            inp if inp != op_input else op_inp_tens for inp in op_sgv.inputs
        ]
        graph_editor.swap_inputs(op, op_swapped_inputs)
    tf_compat.add_to_collection(
        PruningScope.collection_name(ks_group, PruningScope.OP_MASKED_VAR), masked
    )

    return PruningOpVars(op, op_input, None, mask, masked)


def create_op_pruning(
    op: tf_compat.Operation,
    op_input: tf_compat.Tensor,
    sparsity: tf_compat.Tensor,
    update_ready: tf_compat.Tensor,
    leave_enabled: bool,
    is_after_end_step: tf_compat.Tensor,
    ks_group: str,
    mask_creator: PruningMaskCreator,
) -> PruningOpVars:
    """
    Creates the necessary variables and operators to gradually
    apply sparsity to an operator's variable.

    Handles setting a mask on an operator to the given sparsity.
    Sets the mask based on pruning away the lowest absolute magnitude weights.

    :param op: the operation to prune to the given sparsity
    :param op_input: the variable of the parameter within op to prune
    :param sparsity: the target sparsity to use for assigning the masks
    :param update_ready: the tensor where if true will update the mask from sparsity,
        if false will not update the mask
    :param leave_enabled: True to continue masking the weights after end_epoch,
        False to stop masking
    :param is_after_end_step: tensor that is true if the current global step
        is after end_epoch
    :param ks_group: the group identifier the scope should be created under
    :param mask_creator: object to define sparsity mask creation
    :return: a named tuple containing the assignment op, mask variable,
        threshold tensor, and masked tensor
    """
    initial_vars = create_op_pruning_no_update(
        op, op_input, ks_group, leave_enabled, is_after_end_step
    )
    op = initial_vars.op
    op_var_tens = initial_vars.op_input
    mask = initial_vars.mask
    masked = initial_vars.masked

    def _update():
        # create the update ops using the target sparsity tensor
        with tf_compat.name_scope(
            PruningScope.model(
                op,
                ks_group,
                additional=PruningScope.OPS_UPDATE,
                trailing_slash=True,
            )
        ):
            new_mask = mask_creator.create_sparsity_mask(op_var_tens, sparsity)
            weight_var = get_tensor_var(op_var_tens)

            return tf_compat.group(
                tf_compat.assign(mask, new_mask, name=PruningScope.OP_MASK_ASSIGN),
                tf_compat.assign(
                    weight_var,
                    tf_compat.multiply(new_mask, op_var_tens),
                    name=PruningScope.OP_WEIGHT_UPDATE,
                ),
            )

    def _no_update():
        with tf_compat.name_scope(
            PruningScope.model(
                op,
                ks_group,
                additional=PruningScope.OPS_UPDATE,
                trailing_slash=True,
            )
        ):
            # return no op wrapped in group to match update type
            return tf_compat.group(
                tf_compat.constant(
                    0.0,
                    dtype=op_var_tens.dtype,
                    name=PruningScope.OP_MASK_UPDATE_NO_OP,
                )
            )

    with tf_compat.name_scope(
        PruningScope.model(
            op,
            ks_group,
            additional=PruningScope.OPS_UPDATE,
            trailing_slash=True,
        )
    ):
        mask_update = tf_compat.cond(
            update_ready, _update, _no_update, name=PruningScope.OP_MASK_UPDATE
        )

    # add return state to collections
    tf_compat.add_to_collection(
        PruningScope.collection_name(ks_group, PruningScope.OP_MASK_UPDATE), mask_update
    )

    return PruningOpVars(op, op_var_tens, mask_update, mask, masked)


def create_constant_op_pruning(
    op: tf_compat.Operation,
    op_input: tf_compat.Tensor,
    is_start_step: tf_compat.Tensor,
    is_end_step: tf_compat.Tensor,
    ks_group: str,
) -> PruningOpVars:
    """
    Creates PruningOpVars with a constant mask for the given operation.
    On the start step, sets the mask to be 1 for all elements of the weight tensor
    where the operation input is nonzero and 0 elsewhere.
    On the end step, reverts the mask to all 1s and updates the weight.

    :param op: the operation to prune to the given sparsity
    :param op_input: the input tensor to op to create a constant mask for
    :param is_start_step: True only if we are at the start step.
    :param is_end_step: True only if we are at the end step.
    :param ks_group: the group identifier the scope should be created under
    :return: a named tuple containing the assignment op, mask variable,
        threshold tensor, and masked tensor
    """
    initial_vars = create_op_pruning_no_update(op, op_input, ks_group)
    op = initial_vars.op
    op_var_tens = initial_vars.op_input
    mask = initial_vars.mask
    masked = initial_vars.masked

    is_start_or_end_step = tf_compat.logical_or(is_start_step, is_end_step)

    def _set_constant_mask():
        # Assign mask tensor to be 1 for all nonzero values of op_var_tens otherwise 0
        # On end step, revert mask to be all 1s
        with tf_compat.name_scope(
            PruningScope.model(
                op,
                ks_group,
                additional=PruningScope.OPS_UPDATE,
                trailing_slash=True,
            )
        ):
            new_mask = tf_compat.cond(
                is_start_step,
                lambda: tf_compat.cast(
                    tf_compat.not_equal(op_var_tens, 0.0), dtype=op_var_tens.dtype
                ),
                lambda: tf_compat.ones(op_var_tens.shape, dtype=op_var_tens.dtype),
            )
            weight_var = get_tensor_var(op_var_tens)

            return tf_compat.group(
                tf_compat.assign(mask, new_mask, name=PruningScope.OP_MASK_ASSIGN),
                tf_compat.assign(
                    weight_var, masked, name=PruningScope.OP_WEIGHT_UPDATE
                ),
            )

    def _no_op():
        with tf_compat.name_scope(
            PruningScope.model(
                op,
                ks_group,
                additional=PruningScope.OPS_UPDATE,
                trailing_slash=True,
            )
        ):
            # return no op wrapped in group to match update type
            return tf_compat.group(
                tf_compat.constant(
                    0.0,
                    dtype=op_var_tens.dtype,
                    name=PruningScope.OP_MASK_UPDATE_NO_OP,
                )
            )

    with tf_compat.name_scope(
        PruningScope.model(
            op,
            ks_group,
            additional=PruningScope.OPS_UPDATE,
            trailing_slash=True,
        )
    ):
        mask_update = tf_compat.cond(
            is_start_or_end_step,
            _set_constant_mask,
            _no_op,
            name=PruningScope.OP_MASK_UPDATE,
        )

    return PruningOpVars(op, op_var_tens, mask_update, mask, masked)


def create_graph_ops_pruning(
    graph: tf_compat.Graph,
    var_names: List[str],
    sparsity: tf_compat.Tensor,
    update_ready: tf_compat.Tensor,
    leave_enabled: bool,
    is_after_end_step: tf_compat.Tensor,
    ks_group: str,
    mask_creator: PruningMaskCreator,
) -> List[PruningOpVars]:
    """
    Creates the necessary variables and operators to gradually
    apply sparsity to a given list of operators in a graph.

    Handles setting a mask on an operator to the given sparsity.
    Sets the mask based on pruning away the lowest absolute magnitude weights.

    :param graph: the tf graph to pull the operator out of for applying the pruning to
    :param var_names: the names or regex patterns of names of variables to prune
        in the graph to the given sparsity
    :param sparsity: the target sparsity to use for assigning the masks
    :param update_ready: the tensor where if true will update the mask from sparsity,
        if false will not update the mask
    :param leave_enabled: True to continue masking the weights after end_epoch,
        False to stop masking
    :param is_after_end_step: tensor that is true if the current global step
        is after end_epoch
    :param ks_group: the group identifier the scope should be created under
    :param mask_creator: optional object to define sparsity mask creation
    :return: a list of the created named tuples each containing the
        assignment op, mask variable, threshold tensor, and masked tensor
    """
    pruning_op_vars = []
    variable_masks = {}  # cache of mask vars for input variables

    for op, op_input in get_ops_and_inputs_by_name_or_regex(var_names, graph):
        if op_input not in variable_masks:
            op_vars = create_op_pruning(
                op,
                op_input,
                sparsity,
                update_ready,
                leave_enabled,
                is_after_end_step,
                ks_group,
                mask_creator,
            )
            pruning_op_vars.append(op_vars)
            variable_masks[op_input] = op_vars
        else:  # Reuse masks if the input variable is shared and already computed
            _, _, mask_update, mask, masked = variable_masks[op_input]
            pruning_op_vars.append(
                PruningOpVars(op, op_input, mask_update, mask, masked)
            )
        tf_compat.add_to_collection(
            PruningScope.collection_name(ks_group, PruningScope.OPS), op
        )
        tf_compat.add_to_collection(
            PruningScope.collection_name(ks_group, PruningScope.OPS_INPUT), op_input
        )

    return pruning_op_vars


def get_or_create_graph_ops_pruning(
    graph: tf_compat.Graph,
    var_names: List[str],
    sparsity: tf_compat.Tensor,
    update_ready: tf_compat.Tensor,
    leave_enabled: bool,
    is_after_end_step: tf_compat.Tensor,
    ks_group: str,
    mask_creator: PruningMaskCreator,
) -> List[PruningOpVars]:
    """
    Creates or retrieves (if previously created) the necessary variables
    and operators to gradually apply sparsity to a given list of operators in a graph.

    Handles setting a mask on an operator to the given sparsity.
    Sets the mask based on pruning away the lowest absolute magnitude weights.

    :param graph: the tf graph to pull the operator out of for applying the pruning to
    :param var_names: the names or regex patterns of names of variables to prune
        in the graph to the given sparsity
    :param sparsity: the target sparsity to use for assigning the masks
    :param update_ready: the tensor where if true will update the mask from sparsity,
        if false will not update the mask
    :param leave_enabled: True to continue masking the weights after end_epoch,
        False to stop masking
    :param is_after_end_step: tensor that is true if the current global step
        is after end_epoch
    :param ks_group: the group identifier the scope should be created under
    :param mask_creator: optional object to define sparsity mask creation
    :return: a list of the created or retrieved named tuples each containing the
        assignment op, mask variable, threshold tensor, and masked tensor
    """
    ops = tf_compat.get_collection(
        PruningScope.collection_name(ks_group, PruningScope.OPS)
    )
    ops_input = tf_compat.get_collection(
        PruningScope.collection_name(ks_group, PruningScope.OPS_INPUT)
    )
    mask_updates = tf_compat.get_collection(
        PruningScope.collection_name(ks_group, PruningScope.OP_MASK_UPDATE)
    )
    masks = tf_compat.get_collection(
        PruningScope.collection_name(ks_group, PruningScope.VAR_MASK)
    )
    maskeds = tf_compat.get_collection(
        PruningScope.collection_name(ks_group, PruningScope.OP_MASKED_VAR)
    )

    if (
        len(ops) < 1
        or len(ops_input) < 1
        or len(mask_updates) < 1
        or len(masks) < 1
        or len(maskeds) < 1
    ):  # create new pruning ops
        pruning_op_vars = create_graph_ops_pruning(
            graph,
            var_names,
            sparsity,
            update_ready,
            leave_enabled,
            is_after_end_step,
            ks_group,
            mask_creator,
        )
    else:  # use collection pruning ops
        pruning_op_vars = []

        for op, op_input, mask_update, mask, masked in zip(
            ops, ops_input, mask_updates, masks, maskeds
        ):
            pruning_op_vars.append(
                PruningOpVars(op, op_input, mask_update, mask, masked)
            )

    return pruning_op_vars


def create_summaries_pruning(pruning_op_vars: List[PruningOpVars]):
    """
    Create TensorBoard summary ops in the current graph for the
    given list of PruningOpVars.

    :param pruning_op_vars: the list of named tuples containing the masked input to the
        pruned op to record sparsity for in TensorBoard.
    :return: the created summaries for the pruned op vars
    """
    summaries = []

    for op_vars in pruning_op_vars:
        try:
            zero_fraction = tf_compat.zero_fraction
        except Exception:

            def zero_fraction(inp: tf_compat.Tensor):
                nonzero = tf_compat.cast(
                    tf_compat.reduce_sum(
                        tf_compat.cast(tf_compat.not_equal(inp, 0), tf_compat.int64)
                    ),
                    tf_compat.float32,
                )
                size = tf_compat.size(inp, out_type=tf_compat.float32)

                return 1 - tf_compat_div(nonzero, size)

        if is_prunable_op(op_vars.op):
            sum_op = tf_compat.summary.scalar(
                "Modifier_Pruning/{}".format(clean_tensor_name(op_vars.op)),
                zero_fraction(op_vars.masked),
            )
            summaries.append(sum_op)

    return summaries
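

# Usage sketch (hypothetical, not part of the original module): after building
# pruning ops (for example via get_or_create_ks_scheduled_graph_ops below), the
# sparsity summaries can be merged and written out alongside training.
# `log_dir` is a placeholder name used only for illustration.
def _example_log_pruning_summaries(pruning_op_vars, log_dir):
    summaries = create_summaries_pruning(pruning_op_vars)
    merged = tf_compat.summary.merge(summaries)
    writer = tf_compat.summary.FileWriter(log_dir, tf_compat.get_default_graph())

    return merged, writer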


def apply_op_vars_masks(
    pruning_op_vars: List[PruningOpVars], ks_group: str, sess: tf_compat.Session
):
    """
    Apply the masks to the original op's input var so that it can be saved
    with the desired sparsity for later.

    :param pruning_op_vars: the list of named tuples containing the sparse mask
        and the op variable to apply the sparse mask to
    :param ks_group: the group to create the assign ops under
    :param sess: the session to use to run the assign
    """
    for op_vars in pruning_op_vars:
        with tf_compat.name_scope(
            PruningScope.model(op_vars.op, ks_group, PruningScope.OP_SAVE)
        ):
            masked_var = tf_compat.multiply(op_vars.op_input, op_vars.mask)
            input_var = get_tensor_var(op_vars.op_input)
            assign = tf_compat.assign(input_var, masked_var)
            sess.run(assign)
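

# Usage sketch (hypothetical, not part of the original module): bake the learned
# masks into the weight variables before exporting a checkpoint so the saved
# model carries the induced sparsity. `checkpoint_path` is a placeholder used
# only for illustration.
def _example_save_pruned_checkpoint(pruning_op_vars, ks_group, sess, checkpoint_path):
    apply_op_vars_masks(pruning_op_vars, ks_group, sess)
    saver = tf_compat.train.Saver()

    return saver.save(sess, checkpoint_path)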


def create_ks_schedule_ops(
    global_step: tf_compat.Variable,
    begin_step: int,
    end_step: int,
    update_step_freq: int,
    init_sparsity: float,
    final_sparsity: float,
    exponent: float,
    ks_group: str,
) -> Tuple[tf_compat.Tensor, tf_compat.Tensor]:
    """
    Create a gradual schedule for model pruning (kernel sparsity).
    Creates a sparsity tensor that goes from init_sparsity to final_sparsity
    starting at begin_step and ending at end_step.
    Uses the global_step to map those.
    Additionally creates an update_ready tensor that is True if an update
    to the sparsity tensor should be run, False otherwise.

    :param global_step: the global optimizer step for the training graph
    :param begin_step: the global step to begin pruning at
    :param end_step: the global step to end pruning at
    :param update_step_freq: the number of global steps between each weight update
    :param init_sparsity: the starting value for sparsity of a
        weight tensor to be enforced
    :param final_sparsity: the end value for sparsity for a weight tensor to be enforced
    :param exponent: the exponent to use for interpolating between
        init_sparsity and final_sparsity; higher values will lead to larger sparsity
        steps at the beginning vs the end, ie: linear (1) vs cubic (3)
    :param ks_group: the group identifier the scope should be created under
    :return: a tuple containing the signal for update_ready and the target sparsity
    """

    # create the scheduling ops first and the sparsity ops
    with tf_compat.name_scope(
        PruningScope.general(
            ks_group, additional=PruningScope.OPS_SCHEDULE, trailing_slash=True
        )
    ):
        sched_before = tf_compat.less(global_step, begin_step)
        sched_start = tf_compat.equal(global_step, begin_step)
        sched_end = tf_compat.equal(global_step, end_step)
        sched_active = tf_compat.logical_and(
            tf_compat.greater(global_step, begin_step),
            tf_compat.less(global_step, end_step),
        )
        sched_active_inclusive = tf_compat.logical_or(
            sched_active, tf_compat.logical_or(sched_start, sched_end)
        )
        sched_update = tf_compat.cond(
            tf_compat.less_equal(update_step_freq, 0),
            lambda: tf_compat.constant(True),
            lambda: tf_compat.equal(
                tf_compat.mod((global_step - begin_step), update_step_freq), 0
            ),
        )
        sched_update_ready = tf_compat.logical_or(
            tf_compat.logical_or(sched_start, sched_end), sched_update
        )

        percentage = tf_compat.minimum(
            1.0,
            tf_compat.maximum(
                0.0,
                tf_compat_div(
                    tf_compat.cast(global_step - begin_step, tf_compat.float32),
                    end_step - begin_step,
                ),
            ),
        )
        exp_percentage = 1 - tf_compat.pow(1 - percentage, exponent)
        calc_sparsity = (
            tf_compat.multiply(final_sparsity - init_sparsity, exp_percentage)
            + init_sparsity
        )

    # create the update ready tensor and sparsity tensor
    with tf_compat.name_scope(PruningScope.general(ks_group, trailing_slash=True)):
        update_ready = tf_compat.logical_and(
            sched_active_inclusive,
            sched_update_ready,
            name=PruningScope.OP_UPDATE_READY,
        )
        sparsity = tf_compat.case(
            [
                (sched_before, lambda: tf_compat.constant(0.0)),
                (sched_start, lambda: tf_compat.constant(init_sparsity)),
                (sched_active, lambda: calc_sparsity),
            ],
            default=lambda: tf_compat.constant(final_sparsity),
            name=PruningScope.OP_SPARSITY,
        )

    # add return state to collections
    tf_compat.add_to_collection(
        PruningScope.collection_name(ks_group, PruningScope.OP_UPDATE_READY),
        update_ready,
    )
    tf_compat.add_to_collection(
        PruningScope.collection_name(ks_group, PruningScope.OP_SPARSITY), sparsity
    )

    return update_ready, sparsity
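

# Usage sketch (hypothetical, not part of the original module): evaluate the
# schedule tensors directly to inspect the sparsity ramp, assuming a global_step
# variable already exists in the graph. The step counts and sparsity targets
# below are placeholder values.
def _example_inspect_schedule(sess, global_step):
    update_ready, sparsity = create_ks_schedule_ops(
        global_step,
        begin_step=0,
        end_step=100,
        update_step_freq=10,
        init_sparsity=0.05,
        final_sparsity=0.85,
        exponent=3.0,
        ks_group="example_group",
    )

    return sess.run([update_ready, sparsity])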


def get_or_create_ks_schedule_ops(
    global_step: tf_compat.Tensor,
    begin_step: int,
    end_step: int,
    update_step_freq: int,
    init_sparsity: float,
    final_sparsity: float,
    exponent: float,
    ks_group: str,
) -> Tuple[tf_compat.Tensor, tf_compat.Tensor]:
    """
    Creates or retrieves (if previously created) a gradual schedule
    for model pruning (kernel sparsity).
    Creates a sparsity tensor that goes from init_sparsity to final_sparsity
    starting at begin_step and ending at end_step.
    Uses the global_step to map those.
    Additionally creates an update_ready tensor that is True if an update
    to the sparsity tensor should be run, False otherwise.

    :param global_step: the global optimizer step for the training graph
    :param begin_step: the global step to begin pruning at
    :param end_step: the global step to end pruning at
    :param update_step_freq: the number of global steps between each weight update
    :param init_sparsity: the starting value for sparsity of a
        weight tensor to be enforced
    :param final_sparsity: the end value for sparsity for a weight tensor to be enforced
    :param exponent: the exponent to use for interpolating between
        init_sparsity and final_sparsity; higher values will lead to larger sparsity
        steps at the beginning vs the end, ie: linear (1) vs cubic (3)
    :param ks_group: the group identifier the scope should be created under
    :return: a tuple containing the signal for update_ready and the target sparsity
    """
    update_ready = tf_compat.get_collection(
        PruningScope.collection_name(ks_group, PruningScope.OP_UPDATE_READY)
    )
    sparsity = tf_compat.get_collection(
        PruningScope.collection_name(ks_group, PruningScope.OP_SPARSITY)
    )

    update_ready = update_ready[0] if len(update_ready) > 0 else None
    sparsity = sparsity[0] if len(sparsity) > 0 else None

    if update_ready is None or sparsity is None:
        update_ready, sparsity = create_ks_schedule_ops(
            global_step,
            begin_step,
            end_step,
            update_step_freq,
            init_sparsity,
            final_sparsity,
            exponent,
            ks_group,
        )

        # add return state to collections
        tf_compat.add_to_collection(
            PruningScope.collection_name(ks_group, PruningScope.OP_UPDATE_READY),
            update_ready,
        )
        tf_compat.add_to_collection(
            PruningScope.collection_name(ks_group, PruningScope.OP_SPARSITY), sparsity
        )

    return update_ready, sparsity


def get_scheduled_update_op(
    pruning_op_vars: List[PruningOpVars],
    ks_group: str,
):
    """
    Gets or creates (if not previously created) the grouped update operation
    for the given pruning ops so that all of their mask updates can be run
    together in a session.

    :param pruning_op_vars: List of tuples of operation tensors and masks.
    :param ks_group: the group identifier the scope should be created under
    :return: the update operation to run in a session
    """
    update_op = tf_compat.get_collection(
        PruningScope.collection_name(ks_group, PruningScope.OP_COND_UPDATE)
    )
    update_op = update_op[0] if len(update_op) > 0 else None

    if update_op is None:
        update_op = tf_compat.group(*[op_var.update for op_var in pruning_op_vars])

        # add return state to collections
        tf_compat.add_to_collection(
            PruningScope.collection_name(ks_group, PruningScope.OP_COND_UPDATE),
            update_op,
        )

    return update_op


def get_or_create_ks_scheduled_graph_ops(
    graph: tf_compat.Graph,
    global_step: tf_compat.Variable,
    var_names: List[str],
    begin_step: int,
    end_step: int,
    update_step_freq: int,
    init_sparsity: float,
    final_sparsity: float,
    exponent: float,
    leave_enabled: bool,
    ks_group: str,
    mask_creator: PruningMaskCreator,
) -> Tuple[tf_compat.Tensor, List[PruningOpVars], tf_compat.Tensor, tf_compat.Tensor]:
    """
    Gets or creates model pruning (kernel sparsity) ops and vars in the graph
    to be applied over a specific schedule.
    Creates them for the var_names in the graph such that they follow a schedule
    from begin_step to end_step starting at init_sparsity and ending at final_sparsity.

    :param graph: the tf graph to pull the operator out of for applying the pruning to
    :param global_step: the global optimizer step for the training graph
    :param var_names: the names or regex patterns of names of variables to prune
        in the graph
    :param begin_step: the global step to begin pruning at
    :param end_step: the global step to end pruning at
    :param update_step_freq: the number of global steps between each weight update
    :param init_sparsity: the starting value for sparsity of a
        weight tensor to be enforced
    :param final_sparsity: the end value for sparsity for a weight tensor to be enforced
    :param exponent: the exponent to use for interpolating between
        init_sparsity and final_sparsity; higher values will lead to larger sparsity
        steps at the beginning vs the end, ie: linear (1) vs cubic (3)
    :param leave_enabled: True to continue masking the weights after end_epoch,
        False to stop masking
    :param ks_group: the group identifier the scope should be created under
    :param mask_creator: optional object to define sparsity mask creation
    :return: a tuple containing the update operation to run in a session,
        a list of the pruning ops and vars for each desired op in the graph,
        the tensor containing the update_ready signal for the pruning ops,
        the tensor containing the set sparsity for the pruning ops
    """
    update_ready, sparsity = get_or_create_ks_schedule_ops(
        global_step,
        begin_step,
        end_step,
        update_step_freq,
        init_sparsity,
        final_sparsity,
        exponent,
        ks_group,
    )
    is_after_end_step = tf_compat.greater(global_step, end_step)
    pruning_op_vars = get_or_create_graph_ops_pruning(
        graph,
        var_names,
        sparsity,
        update_ready,
        leave_enabled,
        is_after_end_step,
        ks_group,
        mask_creator,
    )
    update_op = get_scheduled_update_op(pruning_op_vars, ks_group)

    return update_op, pruning_op_vars, update_ready, sparsity
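

# End-to-end usage sketch (hypothetical, not part of the original module): wire
# the scheduled pruning ops into a simple training loop. The variable name
# pattern, step counts, sparsity targets, and group name are placeholders, and a
# concrete PruningMaskCreator implementation must be supplied by the caller.
def _example_train_with_pruning(graph, global_step, train_op, mask_creator, num_steps):
    with graph.as_default():
        update_op, pruning_op_vars, update_ready, sparsity = (
            get_or_create_ks_scheduled_graph_ops(
                graph,
                global_step,
                var_names=["re:.*weight"],
                begin_step=0,
                end_step=1000,
                update_step_freq=100,
                init_sparsity=0.05,
                final_sparsity=0.85,
                exponent=3.0,
                leave_enabled=True,
                ks_group="example_group",
                mask_creator=mask_creator,
            )
        )

        with tf_compat.Session(graph=graph) as sess:
            sess.run(tf_compat.global_variables_initializer())

            for _ in range(num_steps):
                sess.run(train_op)
                sess.run(update_op)  # update masks per the pruning schedule

    return pruning_op_vars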


def create_ks_scheduled_constant_graph_ops(
    graph: tf_compat.Graph,
    global_step: tf_compat.Variable,
    var_names: List[str],
    begin_step: int,
    end_step: int,
    ks_group: str,
) -> Tuple[tf_compat.Tensor, List[PruningOpVars]]:
    """
    Creates constant model pruning ops. Does not modify the graph.

    :param graph: the tf graph to pull the operator out of for applying the pruning to
    :param global_step: the global optimizer step for the training graph
    :param var_names: a list of names or regex patterns to create constant ops
        for within the graph
    :param begin_step: the global step to begin pruning at
    :param end_step: the global step to end pruning at
    :param ks_group: the group identifier the scope should be created under
    :return: a tuple containing the update operation to run in a session,
        a list of the pruning ops and vars for each desired op in the graph
    """
    pruning_op_vars = []
    is_start_step = tf_compat.equal(global_step, begin_step)
    is_end_step = tf_compat.equal(global_step, end_step)
    for op, op_input in get_ops_and_inputs_by_name_or_regex(var_names, graph):
        op_vars = create_constant_op_pruning(
            op, op_input, is_start_step, is_end_step, ks_group
        )
        pruning_op_vars.append(op_vars)

    update_op = get_scheduled_update_op(pruning_op_vars, ks_group)

    return update_op, pruning_op_vars
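

# Usage sketch (hypothetical, not part of the original module): hold an existing
# sparsity pattern constant over a step range, for example while fine-tuning an
# already pruned checkpoint. The step values, variable name pattern, and group
# name are placeholders.
def _example_hold_constant_sparsity(graph, global_step):
    update_op, pruning_op_vars = create_ks_scheduled_constant_graph_ops(
        graph,
        global_step,
        var_names=["re:.*weight"],
        begin_step=0,
        end_step=1000,
        ks_group="constant_group",
    )

    return update_op, pruning_op_vars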