# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Code related to applying a mask onto a variable to impose kernel sparsity,
aka model pruning, on a TensorFlow graph.
"""
from collections import namedtuple
from typing import List, Tuple
try:
import tensorflow.contrib.graph_editor as graph_editor
tf_contrib_err = None
except Exception as err:
graph_editor = None
tf_contrib_err = err
from sparseml.tensorflow_v1.optim.mask_creator_pruning import PruningMaskCreator
from sparseml.tensorflow_v1.utils import (
clean_tensor_name,
get_ops_and_inputs_by_name_or_regex,
get_tensor_var,
is_prunable_op,
tf_compat,
tf_compat_div,
)
__all__ = [
"PruningOpVars",
"PruningScope",
"create_op_pruning",
"create_graph_ops_pruning",
"create_ks_scheduled_constant_graph_ops",
"get_or_create_graph_ops_pruning",
"apply_op_vars_masks",
"create_summaries_pruning",
"create_ks_schedule_ops",
"get_or_create_ks_schedule_ops",
"get_or_create_ks_scheduled_graph_ops",
]
PruningOpVars = namedtuple(
"PruningOpVars", ["op", "op_input", "update", "mask", "masked"]
)
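# Field-by-field summary of PruningOpVars (descriptive only):
#   op       -- the operation in the graph being pruned
#   op_input -- the original (unmasked) input tensor / variable for op
#   update   -- an op that refreshes the mask when run (None when not created)
#   mask     -- the 0/1 mask variable that is multiplied against op_input
#   masked   -- the elementwise product mask * op_input that feeds op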
class PruningScope(object):
"""
Convenience class for dealing with scope and names for kernel sparsity
in the tf graph.
"""
NM_KS = "nm_ks"
NM_KS_OPS = "nm_ks_ops"
OPS = "ops"
OPS_INPUT = "input_ops"
OPS_UPDATE = "update_ops"
OPS_SUMMARY = "summary_ops"
OPS_SCHEDULE = "schedule_ops"
OPS_SPARSITY = "sparsity_ops"
OP_COND_UPDATE = "nm_conditional_update"
OP_SPARSITY = "nm_sparsity"
OP_UPDATE_READY = "nm_update_ready"
OP_MASKED_VAR = "nm_masked_var"
OP_MASK_ASSIGN = "nm_mask_assign"
OP_PRUNE_VARS_ASSIGN = "nm_prune_vars_assign"
OP_MASK_UPDATE_NO_OP = "nm_mask_update_no_op"
OP_MASK_UPDATE = "nm_mask_update"
OP_WEIGHT_UPDATE = "nm_weight_update"
OP_SAVE = "nm_save"
VAR_MASK = "nm_mask"
VAR_THRESHOLD = "nm_threshold"
    @staticmethod
def general(ks_group: str, additional: str = None, trailing_slash: bool = False):
"""
Create a general kernel sparsity scope in the tf graph.
Use cases are for generic ops like target sparsity, conditional updates, etc.
:param ks_group: the group identifier the scope should be created under
:param additional: any additional scope that should be added to the end
:param trailing_slash: include a trailing forward slash if True, else False
:return: the proper scope
"""
scope = PruningScope._format(PruningScope.NM_KS_OPS, ks_group)
scope = PruningScope._format(
scope, additional=additional, trailing_slash=trailing_slash
)
return scope
    @staticmethod
def model(
op_tens: tf_compat.Tensor,
ks_group: str,
additional: str = None,
trailing_slash: bool = False,
) -> str:
"""
Create a model specific kernel sparsity scope in the tf graph.
        Use cases are for the specific mask, threshold, etc. variables
        to induce sparsity, along with the ops to update those vars.
:param op_tens: the op tensor to create the scope for
:param ks_group: the group identifier the scope should be created under
:param additional: any additional scope that should be added to the end
:param trailing_slash: include a trailing forward slash if True, else False
:return: the proper scope
"""
op_name = clean_tensor_name(op_tens)
scope = PruningScope._format(
"{}_{}".format(op_name, PruningScope.NM_KS), ks_group
)
scope = PruningScope._format(
scope, additional=additional, trailing_slash=trailing_slash
)
return scope
    @staticmethod
def collection_name(ks_group: str, name: str) -> str:
"""
Create a predictable name for a given variable / op in a group for lookup /
storage in a collection
:param ks_group: the group identifier the name belongs under
:param name: the name of the op or variable to be stored or retrieved
:return: the formatted name for use in a collection
"""
return "nm_ks_collection_{}_{}".format(ks_group, name)
@staticmethod
def _format(
current: str, additional: str = None, trailing_slash: bool = False
) -> str:
scope = current
if additional is not None:
scope = "{}/{}".format(current, additional)
if trailing_slash:
scope += "/"
return scope
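# A quick sketch of the names these helpers produce ("my_group" and the
# "conv1/weight" op name are hypothetical):
#
#     PruningScope.general("my_group")
#     # -> "nm_ks_ops/my_group"
#     PruningScope.general("my_group", additional=PruningScope.OPS_SCHEDULE,
#                          trailing_slash=True)
#     # -> "nm_ks_ops/my_group/schedule_ops/"
#     PruningScope.model(op, "my_group")  # op with cleaned name "conv1/weight"
#     # -> "conv1/weight_nm_ks/my_group"
#     PruningScope.collection_name("my_group", PruningScope.VAR_MASK)
#     # -> "nm_ks_collection_my_group_nm_mask"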
def create_op_pruning_no_update(
op: tf_compat.Operation,
op_input: tf_compat.Tensor,
ks_group: str,
leave_enabled: bool = True,
is_after_end_step: tf_compat.Tensor = None,
) -> PruningOpVars:
"""
    Creates the necessary variables and operators to gradually
    apply sparsity to an operator's variable without returning a
    PruningOpVars.update value.
    :param op: the operation to prune to the given sparsity
    :param op_input: the parameter within the op to create a mask for
    :param ks_group: the group identifier the scope should be created under
    :param leave_enabled: True to continue masking the weights after end_epoch,
        False to stop masking
    :param is_after_end_step: should only be provided if leave_enabled is False;
        a tensor that is true if the current global step is after end_epoch
    :return: a named tuple containing the op, the op input, the update op
        (None for this function), the mask variable, and the masked tensor
"""
if tf_contrib_err:
raise tf_contrib_err
op_sgv = graph_editor.sgv(op)
# create the necessary variables first
with tf_compat.variable_scope(
PruningScope.model(op, ks_group), reuse=tf_compat.AUTO_REUSE
):
mask = tf_compat.get_variable(
PruningScope.VAR_MASK,
op_input.get_shape(),
initializer=tf_compat.ones_initializer(),
trainable=False,
dtype=op_input.dtype,
)
tf_compat.add_to_collection(
PruningScope.collection_name(ks_group, PruningScope.VAR_MASK), mask
)
# create the masked operation and assign as the new input to the op
with tf_compat.name_scope(PruningScope.model(op, ks_group, trailing_slash=True)):
masked = tf_compat.multiply(mask, op_input, PruningScope.OP_MASKED_VAR)
op_inp_tens = (
masked
if leave_enabled
else tf_compat.cond(is_after_end_step, lambda: op_input, lambda: masked)
)
op_swapped_inputs = [
inp if inp != op_input else op_inp_tens for inp in op_sgv.inputs
]
graph_editor.swap_inputs(op, op_swapped_inputs)
tf_compat.add_to_collection(
PruningScope.collection_name(ks_group, PruningScope.OP_MASKED_VAR), masked
)
return PruningOpVars(op, op_input, None, mask, masked)
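# Conceptually, the rewiring above swaps the op's original input for an
# elementwise product with the mask (a sketch; names are illustrative):
#
#     masked = mask * weight        # mask is initialized to all ones
#     op(weight, ...)  becomes  op(masked, ...)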
def create_op_pruning(
op: tf_compat.Operation,
op_input: tf_compat.Tensor,
sparsity: tf_compat.Tensor,
update_ready: tf_compat.Tensor,
leave_enabled: bool,
is_after_end_step: tf_compat.Tensor,
ks_group: str,
mask_creator: PruningMaskCreator,
) -> PruningOpVars:
"""
    Creates the necessary variables and operators to gradually
    apply sparsity to an operator's variable.
    Handles setting a mask on an operator to the given sparsity.
    Sets the mask based on pruning away the lowest absolute magnitude weights.
    :param op: the operation to prune to the given sparsity
    :param op_input: the variable of the parameter within op to prune
    :param sparsity: the target sparsity to use for assigning the masks
    :param update_ready: a tensor that, when true, triggers an update of the
        mask from the current sparsity; when false, the mask is left unchanged
    :param leave_enabled: True to continue masking the weights after end_epoch,
        False to stop masking
    :param is_after_end_step: tensor that is true if the current global step
        is after end_epoch
    :param ks_group: the group identifier the scope should be created under
    :param mask_creator: object to define sparsity mask creation
    :return: a named tuple containing the op, the op input, the mask update op,
        the mask variable, and the masked tensor
"""
initial_vars = create_op_pruning_no_update(
op, op_input, ks_group, leave_enabled, is_after_end_step
)
op = initial_vars.op
op_var_tens = initial_vars.op_input
mask = initial_vars.mask
masked = initial_vars.masked
def _update():
# create the update ops using the target sparsity tensor
with tf_compat.name_scope(
PruningScope.model(
op,
ks_group,
additional=PruningScope.OPS_UPDATE,
trailing_slash=True,
)
):
new_mask = mask_creator.create_sparsity_mask(op_var_tens, sparsity)
weight_var = get_tensor_var(op_var_tens)
return tf_compat.group(
tf_compat.assign(mask, new_mask, name=PruningScope.OP_MASK_ASSIGN),
tf_compat.assign(
weight_var,
tf_compat.multiply(new_mask, op_var_tens),
name=PruningScope.OP_WEIGHT_UPDATE,
),
)
def _no_update():
with tf_compat.name_scope(
PruningScope.model(
op,
ks_group,
additional=PruningScope.OPS_UPDATE,
trailing_slash=True,
)
):
# return no op wrapped in group to match update type
return tf_compat.group(
tf_compat.constant(
0.0, dtype=op_var_tens.dtype, name=PruningScope.OP_MASK_UPDATE_NO_OP
)
)
with tf_compat.name_scope(
PruningScope.model(
op,
ks_group,
additional=PruningScope.OPS_UPDATE,
trailing_slash=True,
)
):
mask_update = tf_compat.cond(
update_ready, _update, _no_update, name=PruningScope.OP_MASK_UPDATE
)
# add return state to collections
tf_compat.add_to_collection(
PruningScope.collection_name(ks_group, PruningScope.OP_MASK_UPDATE), mask_update
)
return PruningOpVars(op, op_var_tens, mask_update, mask, masked)
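# Minimal usage sketch for create_op_pruning, assuming a prunable op and its
# input have already been located in the graph and that
# UnstructuredPruningMaskCreator is available from sparseml.tensorflow_v1.optim
# (names here are illustrative, not prescriptive):
#
#     sparsity = tf_compat.placeholder(dtype=tf_compat.float32, shape=())
#     update_ready = tf_compat.placeholder(dtype=tf_compat.bool, shape=())
#     op_vars = create_op_pruning(
#         op, op_input, sparsity, update_ready, leave_enabled=True,
#         is_after_end_step=tf_compat.constant(False), ks_group="my_group",
#         mask_creator=UnstructuredPruningMaskCreator(),
#     )
#     with tf_compat.Session() as sess:
#         sess.run(tf_compat.global_variables_initializer())
#         sess.run(op_vars.update, feed_dict={sparsity: 0.5, update_ready: True})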
def create_constant_op_pruning(
op: tf_compat.Operation,
op_input: tf_compat.Tensor,
is_start_step: tf_compat.Tensor,
is_end_step: tf_compat.Tensor,
ks_group: str,
) -> PruningOpVars:
"""
    Creates PruningOpVars with a constant mask for the given operation.
    On the start step, sets the mask to 1 for all nonzero values of the
    weight tensor and 0 elsewhere.
    At the end step, reverts the mask to all 1s and updates the weight.
    :param op: the operation to prune to the given sparsity
    :param op_input: the input tensor to op to create a constant mask for
    :param is_start_step: True only if we are at the start step
    :param is_end_step: True only if we are at the end step
    :param ks_group: the group identifier the scope should be created under
    :return: a named tuple containing the op, the op input, the mask update op,
        the mask variable, and the masked tensor
"""
initial_vars = create_op_pruning_no_update(op, op_input, ks_group)
op = initial_vars.op
op_var_tens = initial_vars.op_input
mask = initial_vars.mask
masked = initial_vars.masked
is_start_or_end_step = tf_compat.logical_or(is_start_step, is_end_step)
def _set_constant_mask():
# Assign mask tensor to be 1 for all nonzero values of op_var_tens otherwise 0
# On end step, revert mask to be all 1s
with tf_compat.name_scope(
PruningScope.model(
op,
ks_group,
additional=PruningScope.OPS_UPDATE,
trailing_slash=True,
)
):
new_mask = tf_compat.cond(
is_start_step,
lambda: tf_compat.cast(
tf_compat.not_equal(op_var_tens, 0.0), dtype=op_var_tens.dtype
),
lambda: tf_compat.ones(op_var_tens.shape, dtype=op_var_tens.dtype),
)
weight_var = get_tensor_var(op_var_tens)
return tf_compat.group(
tf_compat.assign(mask, new_mask, name=PruningScope.OP_MASK_ASSIGN),
tf_compat.assign(
weight_var, masked, name=PruningScope.OP_WEIGHT_UPDATE
),
)
def _no_op():
with tf_compat.name_scope(
PruningScope.model(
op,
ks_group,
additional=PruningScope.OPS_UPDATE,
trailing_slash=True,
)
):
# return no op wrapped in group to match update type
return tf_compat.group(
tf_compat.constant(
0.0, dtype=op_var_tens.dtype, name=PruningScope.OP_MASK_UPDATE_NO_OP
)
)
with tf_compat.name_scope(
PruningScope.model(
op,
ks_group,
additional=PruningScope.OPS_UPDATE,
trailing_slash=True,
)
):
mask_update = tf_compat.cond(
is_start_or_end_step,
_set_constant_mask,
_no_op,
name=PruningScope.OP_MASK_UPDATE,
)
return PruningOpVars(op, op_var_tens, mask_update, mask, masked)
def create_graph_ops_pruning(
graph: tf_compat.Graph,
var_names: List[str],
sparsity: tf_compat.Tensor,
update_ready: tf_compat.Tensor,
leave_enabled: bool,
is_after_end_step: tf_compat.Tensor,
ks_group: str,
mask_creator: PruningMaskCreator,
) -> List[PruningOpVars]:
"""
Creates the necessary variables and operators to gradually
apply sparsity to a given list of operators in a graph.
Handles setting a mask on an operator to the given sparsity.
Sets the mask based on pruning away the lowest absolute magnitude weights.
:param graph: the tf graph to pull the operator out of for applying the pruning to
:param var_names: the names or regex patterns of names of variables to prune in the
graph to the given sparsity
:param sparsity: the target sparsity to use for assigning the masks
    :param update_ready: a tensor that, when true, triggers an update of the
        masks from the current sparsity; when false, the masks are left unchanged
    :param leave_enabled: True to continue masking the weights after end_epoch,
        False to stop masking
    :param is_after_end_step: tensor that is true if the current global step
        is after end_epoch
    :param ks_group: the group identifier the scope should be created under
    :param mask_creator: object to define sparsity mask creation
    :return: a list of the created named tuples, each containing the op, the op
        input, the mask update op, the mask variable, and the masked tensor
"""
pruning_op_vars = []
variable_masks = {} # cache of mask vars for input variables
for op, op_input in get_ops_and_inputs_by_name_or_regex(var_names, graph):
if op_input not in variable_masks:
op_vars = create_op_pruning(
op,
op_input,
sparsity,
update_ready,
leave_enabled,
is_after_end_step,
ks_group,
mask_creator,
)
pruning_op_vars.append(op_vars)
variable_masks[op_input] = op_vars
else: # Reuse masks if the input variable is shared and already computed
_, _, mask_update, mask, masked = variable_masks[op_input]
pruning_op_vars.append(
PruningOpVars(op, op_input, mask_update, mask, masked)
)
tf_compat.add_to_collection(
PruningScope.collection_name(ks_group, PruningScope.OPS), op
)
tf_compat.add_to_collection(
PruningScope.collection_name(ks_group, PruningScope.OPS_INPUT), op_input
)
return pruning_op_vars
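# Sketch of pruning all matching variables in a graph at once; var_names may be
# exact names or regex patterns (sparseml's utils accept a "re:" prefix for
# regex matching), and the mask creator shown is an assumed import:
#
#     pruning_op_vars = create_graph_ops_pruning(
#         tf_compat.get_default_graph(),
#         var_names=["re:.*weight"],
#         sparsity=sparsity,
#         update_ready=update_ready,
#         leave_enabled=True,
#         is_after_end_step=tf_compat.constant(False),
#         ks_group="my_group",
#         mask_creator=UnstructuredPruningMaskCreator(),
#     )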
def get_or_create_graph_ops_pruning(
graph: tf_compat.Graph,
var_names: List[str],
sparsity: tf_compat.Tensor,
update_ready: tf_compat.Tensor,
leave_enabled: bool,
is_after_end_step: tf_compat.Tensor,
ks_group: str,
mask_creator: PruningMaskCreator,
) -> List[PruningOpVars]:
"""
Creates or retrieves (if previously created) the necessary variables
and operators to gradually apply sparsity to a given list of operators in a graph.
Handles setting a mask on an operator to the given sparsity.
Sets the mask based on pruning away the lowest absolute magnitude weights.
:param graph: the tf graph to pull the operator out of for applying the pruning to
:param var_names: the names or regex patterns of names of variables to prune in the
graph to the given sparsity
:param sparsity: the target sparsity to use for assigning the masks
    :param update_ready: a tensor that, when true, triggers an update of the
        masks from the current sparsity; when false, the masks are left unchanged
    :param leave_enabled: True to continue masking the weights after end_epoch,
        False to stop masking
    :param is_after_end_step: tensor that is true if the current global step
        is after end_epoch
    :param ks_group: the group identifier the scope should be created under
    :param mask_creator: object to define sparsity mask creation
    :return: a list of the created or retrieved named tuples, each containing
        the op, the op input, the mask update op, the mask variable, and the
        masked tensor
"""
ops = tf_compat.get_collection(
PruningScope.collection_name(ks_group, PruningScope.OPS)
)
ops_input = tf_compat.get_collection(
PruningScope.collection_name(ks_group, PruningScope.OPS_INPUT)
)
mask_updates = tf_compat.get_collection(
PruningScope.collection_name(ks_group, PruningScope.OP_MASK_UPDATE)
)
masks = tf_compat.get_collection(
PruningScope.collection_name(ks_group, PruningScope.VAR_MASK)
)
maskeds = tf_compat.get_collection(
PruningScope.collection_name(ks_group, PruningScope.OP_MASKED_VAR)
)
if (
len(ops) < 1
or len(ops_input) < 1
or len(mask_updates) < 1
or len(masks) < 1
or len(maskeds) < 1
): # create new pruning ops
pruning_op_vars = create_graph_ops_pruning(
graph,
var_names,
sparsity,
update_ready,
leave_enabled,
is_after_end_step,
ks_group,
mask_creator,
)
else: # use collection pruning ops
pruning_op_vars = []
for op, op_input, mask_update, mask, masked in zip(
ops, ops_input, mask_updates, masks, maskeds
):
pruning_op_vars.append(
PruningOpVars(op, op_input, mask_update, mask, masked)
)
return pruning_op_vars
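# Calling this twice with the same ks_group retrieves the ops/vars from the
# graph collections instead of recreating them; a sketch (same hypothetical
# arguments as above):
#
#     a = get_or_create_graph_ops_pruning(graph, var_names, sparsity,
#                                         update_ready, True, is_after_end_step,
#                                         "my_group", mask_creator)
#     b = get_or_create_graph_ops_pruning(graph, var_names, sparsity,
#                                         update_ready, True, is_after_end_step,
#                                         "my_group", mask_creator)
#     assert a[0].mask is b[0].mask  # same underlying variable, not a copy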
def create_summaries_pruning(pruning_op_vars: List[PruningOpVars]):
"""
Create TensorBoard summary ops in the current graph for the
given list of PruningOpVars.
:param pruning_op_vars: the list of named tuples containing the masked input to the
pruned op to record sparsity for in TensorBoard.
:return: the created summaries for the pruned op vars
"""
summaries = []
for op_vars in pruning_op_vars:
try:
zero_fraction = tf_compat.zero_fraction
except Exception:
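            # fall back to computing the zero fraction manually for TF
            # versions where tf_compat.zero_fraction is unavailable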
def zero_fraction(inp: tf_compat.Tensor):
nonzero = tf_compat.cast(
tf_compat.reduce_sum(
tf_compat.cast(tf_compat.not_equal(inp, 0), tf_compat.int64)
),
tf_compat.float32,
)
size = tf_compat.size(inp, out_type=tf_compat.float32)
return 1 - tf_compat_div(nonzero, size)
if is_prunable_op(op_vars.op):
sum_op = tf_compat.summary.scalar(
"Modifier_Pruning/{}".format(clean_tensor_name(op_vars.op)),
zero_fraction(op_vars.masked),
)
summaries.append(sum_op)
return summaries
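# Sketch of wiring these summaries into TensorBoard (log directory and step
# variable are hypothetical):
#
#     summaries = create_summaries_pruning(pruning_op_vars)
#     merged = tf_compat.summary.merge(summaries)
#     writer = tf_compat.summary.FileWriter("/tmp/pruning_logs", sess.graph)
#     writer.add_summary(sess.run(merged), global_step=step)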
def apply_op_vars_masks(
pruning_op_vars: List[PruningOpVars], ks_group: str, sess: tf_compat.Session
):
"""
    Apply the masks to the original op's input variable so that the variable
    can be saved later with the desired sparsity.
:param pruning_op_vars: the list of named tuples containing the sparse mask
and the op variable to apply the sparse mask to
:param ks_group: the group to create the assign ops under
:param sess: the session to use to run the assign
"""
for op_vars in pruning_op_vars:
with tf_compat.name_scope(
PruningScope.model(op_vars.op, ks_group, PruningScope.OP_SAVE)
):
masked_var = tf_compat.multiply(op_vars.op_input, op_vars.mask)
input_var = get_tensor_var(op_vars.op_input)
assign = tf_compat.assign(input_var, masked_var)
sess.run(assign)
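# Sketch: bake the masks into the weights before exporting, so the saved
# checkpoint keeps the desired sparsity (path is hypothetical):
#
#     apply_op_vars_masks(pruning_op_vars, "my_group", sess)
#     tf_compat.train.Saver().save(sess, "/tmp/model/pruned.ckpt")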
def create_ks_schedule_ops(
global_step: tf_compat.Variable,
begin_step: int,
end_step: int,
update_step_freq: int,
init_sparsity: float,
final_sparsity: float,
exponent: float,
ks_group: str,
) -> Tuple[tf_compat.Tensor, tf_compat.Tensor]:
"""
Create a gradual schedule for model pruning (kernel sparsity).
    Creates a sparsity tensor that goes from init_sparsity to final_sparsity
starting at begin_step and ending at end_step.
Uses the global_step to map those.
Additionally creates an update_ready tensor that is True if an update
to the sparsity tensor should be run, False otherwise.
:param global_step: the global optimizer step for the training graph
:param begin_step: the global step to begin pruning at
:param end_step: the global step to end pruning at
:param update_step_freq: the number of global steps between each weight update
    :param init_sparsity: the starting sparsity value to be enforced
        for a weight tensor
    :param final_sparsity: the final sparsity value to be enforced
        for a weight tensor
    :param exponent: the exponent to use for interpolating between
        init_sparsity and final_sparsity; higher values lead to larger sparsity
        steps at the beginning vs the end, i.e. linear (1) vs cubic (3)
:param ks_group: the group identifier the scope should be created under
:return: a tuple containing the signal for update_ready and the target sparsity
"""
# create the scheduling ops first and the sparsity ops
with tf_compat.name_scope(
PruningScope.general(
ks_group, additional=PruningScope.OPS_SCHEDULE, trailing_slash=True
)
):
sched_before = tf_compat.less(global_step, begin_step)
sched_start = tf_compat.equal(global_step, begin_step)
sched_end = tf_compat.equal(global_step, end_step)
sched_active = tf_compat.logical_and(
tf_compat.greater(global_step, begin_step),
tf_compat.less(global_step, end_step),
)
sched_active_inclusive = tf_compat.logical_or(
sched_active, tf_compat.logical_or(sched_start, sched_end)
)
sched_update = tf_compat.cond(
tf_compat.less_equal(update_step_freq, 0),
lambda: tf_compat.constant(True),
lambda: tf_compat.equal(
tf_compat.mod((global_step - begin_step), update_step_freq), 0
),
)
sched_update_ready = tf_compat.logical_or(
tf_compat.logical_or(sched_start, sched_end), sched_update
)
percentage = tf_compat.minimum(
1.0,
tf_compat.maximum(
0.0,
tf_compat_div(
tf_compat.cast(global_step - begin_step, tf_compat.float32),
end_step - begin_step,
),
),
)
exp_percentage = 1 - tf_compat.pow(1 - percentage, exponent)
calc_sparsity = (
tf_compat.multiply(final_sparsity - init_sparsity, exp_percentage)
+ init_sparsity
)
# create the update ready tensor and sparsity tensor
with tf_compat.name_scope(PruningScope.general(ks_group, trailing_slash=True)):
update_ready = tf_compat.logical_and(
sched_active_inclusive,
sched_update_ready,
name=PruningScope.OP_UPDATE_READY,
)
sparsity = tf_compat.case(
[
(sched_before, lambda: tf_compat.constant(0.0)),
(sched_start, lambda: tf_compat.constant(init_sparsity)),
(sched_active, lambda: calc_sparsity),
],
default=lambda: tf_compat.constant(final_sparsity),
name=PruningScope.OP_SPARSITY,
)
# add return state to collections
tf_compat.add_to_collection(
PruningScope.collection_name(ks_group, PruningScope.OP_UPDATE_READY),
update_ready,
)
tf_compat.add_to_collection(
PruningScope.collection_name(ks_group, PruningScope.OP_SPARSITY), sparsity
)
return update_ready, sparsity
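# The schedule interpolates sparsity polynomially between begin_step and
# end_step. A worked example in plain Python mirroring the tensor math above:
#
#     begin_step, end_step = 0, 100
#     init_sparsity, final_sparsity, exponent = 0.05, 0.8, 3.0
#     global_step = 50
#     percentage = min(1.0, max(0.0, (global_step - begin_step)
#                                / (end_step - begin_step)))  # 0.5
#     exp_percentage = 1 - (1 - percentage) ** exponent       # 0.875
#     sparsity = (final_sparsity - init_sparsity) * exp_percentage + init_sparsity
#     # -> 0.75 * 0.875 + 0.05 = 0.70625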
def get_or_create_ks_schedule_ops(
global_step: tf_compat.Tensor,
begin_step: int,
end_step: int,
update_step_freq: int,
init_sparsity: float,
final_sparsity: float,
exponent: float,
ks_group: str,
) -> Tuple[tf_compat.Tensor, tf_compat.Tensor]:
"""
Creates or retrieves (if previously created) a gradual schedule
for model pruning (kernel sparsity).
    Creates a sparsity tensor that goes from init_sparsity to final_sparsity
starting at begin_step and ending at end_step.
Uses the global_step to map those.
Additionally creates an update_ready tensor that is True if an update
to the sparsity tensor should be run, False otherwise.
:param global_step: the global optimizer step for the training graph
:param begin_step: the global step to begin pruning at
:param end_step: the global step to end pruning at
:param update_step_freq: the number of global steps between each weight update
    :param init_sparsity: the starting sparsity value to be enforced
        for a weight tensor
    :param final_sparsity: the final sparsity value to be enforced
        for a weight tensor
    :param exponent: the exponent to use for interpolating between
        init_sparsity and final_sparsity; higher values lead to larger sparsity
        steps at the beginning vs the end, i.e. linear (1) vs cubic (3)
:param ks_group: the group identifier the scope should be created under
:return: a tuple containing the signal for update_ready and the target sparsity
"""
update_ready = tf_compat.get_collection(
PruningScope.collection_name(ks_group, PruningScope.OP_UPDATE_READY)
)
sparsity = tf_compat.get_collection(
PruningScope.collection_name(ks_group, PruningScope.OP_SPARSITY)
)
update_ready = update_ready[0] if len(update_ready) > 0 else None
sparsity = sparsity[0] if len(sparsity) > 0 else None
if update_ready is None or sparsity is None:
update_ready, sparsity = create_ks_schedule_ops(
global_step,
begin_step,
end_step,
update_step_freq,
init_sparsity,
final_sparsity,
exponent,
ks_group,
)
# add return state to collections
tf_compat.add_to_collection(
PruningScope.collection_name(ks_group, PruningScope.OP_UPDATE_READY),
update_ready,
)
tf_compat.add_to_collection(
PruningScope.collection_name(ks_group, PruningScope.OP_SPARSITY), sparsity
)
return update_ready, sparsity
def get_scheduled_update_op(
pruning_op_vars: List[PruningOpVars],
ks_group: str,
):
"""
    Gets or creates a grouped update operation that runs the mask update ops
    for all of the given pruning ops and vars, so they can be triggered with a
    single session run.
    :param pruning_op_vars: list of PruningOpVars whose update ops to group
:param ks_group: the group identifier the scope should be created under
:return: the update operation to run in a session
"""
update_op = tf_compat.get_collection(
PruningScope.collection_name(ks_group, PruningScope.OP_COND_UPDATE)
)
update_op = update_op[0] if len(update_op) > 0 else None
if update_op is None:
update_op = tf_compat.group(*[op_var.update for op_var in pruning_op_vars])
# add return state to collections
tf_compat.add_to_collection(
PruningScope.collection_name(ks_group, PruningScope.OP_COND_UPDATE),
update_op,
)
return update_op
def get_or_create_ks_scheduled_graph_ops(
graph: tf_compat.Graph,
global_step: tf_compat.Variable,
var_names: List[str],
begin_step: int,
end_step: int,
update_step_freq: int,
init_sparsity: float,
final_sparsity: float,
exponent: float,
leave_enabled: bool,
ks_group: str,
mask_creator: PruningMaskCreator,
) -> Tuple[tf_compat.Tensor, List[PruningOpVars], tf_compat.Tensor, tf_compat.Tensor]:
"""
Gets or creates model pruning (kernel sparsity) ops and vars in the graph
to be applied over a specific schedule.
Creates them for the var_names in the graph such that they follow a schedule
from begin_step to end_step starting at init_sparsity and ending at final_sparsity.
:param graph: the tf graph to pull the operator out of for applying the pruning to
:param global_step: the global optimizer step for the training graph
:param var_names: the names or regex patterns of names of variables to prune in the
graph
:param begin_step: the global step to begin pruning at
:param end_step: the global step to end pruning at
:param update_step_freq: the number of global steps between each weight update
    :param init_sparsity: the starting sparsity value to be enforced
        for a weight tensor
    :param final_sparsity: the final sparsity value to be enforced
        for a weight tensor
    :param exponent: the exponent to use for interpolating between
        init_sparsity and final_sparsity; higher values lead to larger sparsity
        steps at the beginning vs the end, i.e. linear (1) vs cubic (3)
:param leave_enabled: True to continue masking the weights after end_epoch,
False to stop masking
:param ks_group: the group identifier the scope should be created under
    :param mask_creator: object to define sparsity mask creation
:return: a tuple containing the update operation to run in a session,
a list of the pruning ops and vars for each desired op in the graph,
the tensor containing the update_ready signal for the pruning ops,
the tensor containing the set sparsity for the pruning ops
"""
update_ready, sparsity = get_or_create_ks_schedule_ops(
global_step,
begin_step,
end_step,
update_step_freq,
init_sparsity,
final_sparsity,
exponent,
ks_group,
)
is_after_end_step = tf_compat.greater(global_step, end_step)
pruning_op_vars = get_or_create_graph_ops_pruning(
graph,
var_names,
sparsity,
update_ready,
leave_enabled,
is_after_end_step,
ks_group,
mask_creator,
)
update_op = get_scheduled_update_op(pruning_op_vars, ks_group)
return update_op, pruning_op_vars, update_ready, sparsity
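# End-to-end sketch of scheduled pruning in a tf v1 training loop; train_op,
# global_step, and the mask creator import are assumed to exist already:
#
#     update_op, pruning_op_vars, update_ready, sparsity = (
#         get_or_create_ks_scheduled_graph_ops(
#             tf_compat.get_default_graph(), global_step,
#             var_names=["re:.*weight"], begin_step=0, end_step=1000,
#             update_step_freq=100, init_sparsity=0.05, final_sparsity=0.8,
#             exponent=3.0, leave_enabled=True, ks_group="my_group",
#             mask_creator=UnstructuredPruningMaskCreator(),
#         )
#     )
#     with tf_compat.Session() as sess:
#         sess.run(tf_compat.global_variables_initializer())
#         for _ in range(1000):
#             sess.run(train_op)
#             sess.run(update_op)  # mask refresh; a no-op between updates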
def create_ks_scheduled_constant_graph_ops(
graph: tf_compat.Graph,
global_step: tf_compat.Variable,
var_names: List[str],
begin_step: int,
end_step: int,
ks_group: str,
) -> Tuple[tf_compat.Tensor, List[PruningOpVars]]:
"""
    Creates constant model pruning ops: masks are set from the weights'
    current zeros at begin_step and reverted to all ones at end_step,
    rather than following a gradual sparsity schedule.
:param graph: the tf graph to pull the operator out of for applying the pruning to
:param global_step: the global optimizer step for the training graph
:param var_names: a list of names or regex patterns to create constant ops
for within the graph
:param begin_step: the global step to begin pruning at
:param end_step: the global step to end pruning at
:param ks_group: the group identifier the scope should be created under
:return: a tuple containing the update operation to run in a session,
a list of the pruning ops and vars for each desired op in the graph
"""
pruning_op_vars = []
is_start_step = tf_compat.equal(global_step, begin_step)
is_end_step = tf_compat.equal(global_step, end_step)
for op, op_input in get_ops_and_inputs_by_name_or_regex(var_names, graph):
op_vars = create_constant_op_pruning(
op, op_input, is_start_step, is_end_step, ks_group
)
pruning_op_vars.append(op_vars)
update_op = get_scheduled_update_op(pruning_op_vars, ks_group)
return update_op, pruning_op_vars
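# Sketch: hold a model's existing zeros fixed between begin_step and end_step
# (variable name is hypothetical):
#
#     update_op, pruning_op_vars = create_ks_scheduled_constant_graph_ops(
#         tf_compat.get_default_graph(), global_step,
#         var_names=["mlp_net/fc1/weight"], begin_step=0, end_step=1000,
#         ks_group="constant_group",
#     )
#     # run update_op every step; the mask is only recomputed at the begin
#     # and end steps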