# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility / helper functions
"""
import logging
import os
import random
import re
import warnings
from collections import OrderedDict, namedtuple
from contextlib import contextmanager
from copy import deepcopy
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union
import numpy
import torch
from torch import Tensor
from torch.nn import Linear, Module, Parameter
from torch.nn.modules.conv import Conv2d, Conv3d, _ConvNd
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
try:
quant_err = None
from torch.nn.qat import Conv2d as QATConv2d
from torch.nn.qat import Linear as QATLinear
from torch.quantization import QuantWrapper
except Exception as _err:
quant_err = _err
QuantWrapper = None
QATLinear = None
QATConv2d = None
from sparseml.utils import create_dirs, save_numpy
try:
from torch.nn.qat import Conv3d as QATConv3d
except Exception as _err:
quant_conv3d_err = _err
QATConv3d = None
__all__ = [
"default_device",
"device_of",
"get_optim_learning_rate",
"get_optim_groups_learning_rates",
"set_optim_learning_rate",
"early_stop_data_loader",
"infinite_data_loader",
"tensors_batch_size",
"tensors_to_device",
"tensors_to_precision",
"tensors_module_forward",
"tensor_export",
"tensors_export",
"tensor_density",
"tensor_sparsity",
"tensor_list_sparsity",
"tensor_sample",
"mask_difference",
"get_layer",
"replace_layer",
"get_terminal_layers",
"get_conv_layers",
"get_linear_layers",
"get_prunable_layers",
"get_quantizable_layers",
"get_named_layers_and_params_by_regex",
"any_str_or_regex_matches_param_name",
"NamedLayerParam",
"get_layer_param",
"set_deterministic_seeds",
"torch_distributed_zero_first",
"thin_model_from_checkpoint",
"MEMORY_BOUNDED",
"memory_aware_threshold",
]
_LOGGER = logging.getLogger(__name__)
##############################
#
# pytorch device helpers
#
##############################
def default_device() -> str:
    """
    :return: the device that should be defaulted to for the current setup:
        a string listing all device ids (ex: "cuda:0,1") if multiple GPUs are
        available, "cuda" if a single GPU is available, otherwise "cpu"
    """
if not torch.cuda.is_available():
return "cpu"
if torch.cuda.device_count() < 2:
return "cuda"
device_ids = [str(i) for i in range(torch.cuda.device_count())]
return "cuda:{}".format(",".join(device_ids))
def device_of(inputs: Any):
    """
    :param inputs: a tensor or collection (Mapping or Iterable) of tensors
    :return: the device of the first tensor found in the inputs,
        falling back to default_device() if none can be found
    """
if isinstance(inputs, Tensor):
return inputs.device
elif isinstance(inputs, Mapping):
for tens in inputs.values():
return device_of(tens)
elif isinstance(inputs, Iterable):
return device_of(inputs[0])
else:
raise RuntimeError("Unknown type of inputs to device_of function")
return default_device()
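# Illustrative usage sketch (the batch below is an example value, not part of this module):
#
#     device = default_device()              # ex: "cpu", "cuda", or "cuda:0,1"
#     batch = {"input": torch.randn(2, 3)}
#     device_of(batch)                       # device of the first tensor found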
##############################
#
# pytorch optim helpers
#
##############################
def get_optim_learning_rate(optim: Optimizer) -> float:
    """
    :param optim: The optimizer to get the learning rate for
    :return: the learning rate of the first param group in the optimizer
    """
for param_group in optim.param_groups:
return param_group["lr"]
raise RuntimeError("cannot get learning_rate, no param_groups available")
def get_optim_groups_learning_rates(optim: Optimizer) -> List[float]:
    """
    :param optim: The optimizer to get the learning rates for
    :return: a list of the learning rates for each of the param groups
        in the optimizer
    """
return [group["lr"] for group in optim.param_groups]
def set_optim_learning_rate(
    optim: Optimizer, value: float, groups: Optional[List[int]] = None
):
    """
    :param optim: The optimizer to set the learning rate for
    :param value: the learning rate to set; applied to all param groups
        in the optim unless groups is supplied
    :param groups: optional list of param group indices to set the learning
        rate for; if not supplied, all groups are updated
    """
for (index, group) in enumerate(optim.param_groups):
if not groups or index in groups:
group["lr"] = value
##############################
#
# pytorch data loader helpers
#
##############################
def early_stop_data_loader(data_loader: DataLoader, early_stop_steps: int):
    """
    An iterator that yields batches from the data_loader and stops
    after early_stop_steps instead of running through the full loader.

    :param data_loader: the data loader to iterate over
    :param early_stop_steps: the number of steps to run before breaking out early
        instead of running all of the steps in the data loader;
        if < 1 then the full length is run
    :return: an iterable over the data loader that stops after early_stop_steps
    """
counter = 0
for data in data_loader:
yield data
counter += 1
if 0 < early_stop_steps <= counter:
break
def infinite_data_loader(
data_loader: DataLoader, early_stop_steps: int = -1, cache: bool = False
):
"""
A never ending data loader that will keep repeating the one passed in.
Will additionally cache the data if requested.
:param data_loader: the data loader to continually repeat
:param early_stop_steps: if set, the number of steps to run and break out early
instead of running all of the steps in the data loader
:param cache: True to cache the results in memory and return those on
subsequent requests, False otherwise
:return: an iterable for the never ending data loader
"""
cached = None
while True:
if not cache or cached is None:
cached = []
for data in early_stop_data_loader(data_loader, early_stop_steps):
if cache:
cached.append(deepcopy(data))
yield data
else:
for data in cached:
yield data
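# Illustrative usage sketch (loader below is an assumed, user-provided DataLoader):
#
#     for data in early_stop_data_loader(loader, early_stop_steps=10):
#         ...                                   # stops after 10 batches
#     for step, data in enumerate(infinite_data_loader(loader)):
#         if step >= 1000:
#             break                             # the caller must break; it never stops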
##############################
#
# pytorch tensor helper functions
#
##############################
NamedLayerParam = namedtuple(
"NamedLayerParam", ["layer_name", "layer", "param_name", "param"]
)
def tensors_batch_size(tensors: Union[Tensor, Iterable[Tensor], Dict[Any, Tensor]]):
"""
Default function for getting the batch size from a tensor or collection of tensors.
Returns the batch size (zeroth index for shape) of the first found tensor.
Supported use cases:
- single tensor
- Dictionary of single tensors
- Dictionary of iterable of tensors
- Dictionary of dictionary of tensors
- Iterable of single tensors
- Iterable of iterable of tensors
- Iterable of dictionary of tensors
:param tensors: the tensor or collection of tensors to get a batch size from,
taken from the first found tensor
:return: the batch size (0th element of shape) of the first contained
tensor in the data
"""
if isinstance(tensors, Tensor):
return tensors.shape[0]
if isinstance(tensors, Dict):
for key, tens in tensors.items():
batch_size = tensors_batch_size(tens)
if batch_size > -1:
return batch_size
if isinstance(tensors, Iterable):
for tens in tensors:
batch_size = tensors_batch_size(tens)
if batch_size > -1:
return batch_size
return -1
def tensors_to_device(
tensors: Union[Tensor, Iterable[Tensor], Dict[Any, Tensor]], device: str
) -> Union[Tensor, Iterable[Tensor], Dict[Any, Tensor]]:
"""
Default function for putting a tensor or collection of tensors to the proper device.
Returns the tensor references after being placed on the proper device.
Supported use cases:
- single tensor
- Dictionary of single tensors
- Dictionary of iterable of tensors
- Dictionary of dictionary of tensors
- Iterable of single tensors
- Iterable of iterable of tensors
- Iterable of dictionary of tensors
:param tensors: the tensors or collection of tensors to put onto a device
:param device: the string representing the device to put the tensors on,
ex: 'cpu', 'cuda', 'cuda:1'
:return: the tensors or collection of tensors after being placed on the device
"""
if isinstance(tensors, Tensor):
return tensors.to(device)
if isinstance(tensors, OrderedDict):
return OrderedDict(
[(key, tensors_to_device(tens, device)) for key, tens in tensors.items()]
)
if isinstance(tensors, Dict):
return {key: tensors_to_device(tens, device) for key, tens in tensors.items()}
if isinstance(tensors, tuple):
return tuple(tensors_to_device(tens, device) for tens in tensors)
if isinstance(tensors, Iterable):
return [tensors_to_device(tens, device) for tens in tensors]
raise ValueError(
"unrecognized type for tensors given of {}".format(tensors.__class__.__name__)
)
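# Illustrative usage sketch (the nested batch below is an example value):
#
#     batch = {"input": torch.randn(8, 3, 224, 224), "extras": [torch.ones(8)]}
#     batch = tensors_to_device(batch, default_device())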
def tensors_to_precision(
tensors: Union[Tensor, Iterable[Tensor], Dict[Any, Tensor]], full_precision: bool
) -> Union[Tensor, Iterable[Tensor], Dict[Any, Tensor]]:
"""
:param tensors: the tensors to change the precision of
:param full_precision: True for full precision (float 32) and
False for half (float 16)
:return: the tensors converted to the desired precision
"""
if isinstance(tensors, Tensor):
return tensors.float() if full_precision else tensors.half()
if isinstance(tensors, Dict):
return {
key: tensors_to_precision(tens, full_precision)
for key, tens in tensors.items()
}
if isinstance(tensors, tuple):
return tuple(tensors_to_precision(tens, full_precision) for tens in tensors)
if isinstance(tensors, Iterable):
return [tensors_to_precision(tens, full_precision) for tens in tensors]
raise ValueError(
"unrecognized type for tensors given of {}".format(tensors.__class__.__name__)
)
def tensors_module_forward(
    tensors: Union[Tensor, Iterable[Tensor], Mapping[Any, Tensor]],
    module: Module,
    check_feat_lab_inp: bool = True,
) -> Any:
    """
    Default function for calling into a model with data for a forward execution.
    Returns the model result.
    Note, if an iterable the features to be passed into the model are considered
    to be at index 0 and other indices are for labels.

    Supported use cases: single tensor,
    iterable with first tensor taken as the features to pass into the model

    :param tensors: the data to be passed into the model; if an iterable the features
        to be passed into the model are considered to be at index 0 and other indices
        are for labels
    :param module: the module to pass the data into
    :param check_feat_lab_inp: True to check if the incoming tensors look like
        they are made up of features and labels, ie a tuple or list with 2 items
        (typical output from a data loader); in that case the model is called with
        just the first element, assuming it holds the features. False to not check
    :return: the result of calling into the model for a forward pass
    """
if (
(isinstance(tensors, tuple) or isinstance(tensors, List))
and len(tensors) == 2
and check_feat_lab_inp
):
# assume if this is a list or tuple of 2 items that it is made up of
# (features, labels) pass the features into a recursive call for the model
return tensors_module_forward(tensors[0], module, check_feat_lab_inp=False)
if isinstance(tensors, Tensor):
return module(tensors)
if isinstance(tensors, Mapping):
return module(**tensors)
if isinstance(tensors, Iterable):
return module(*tensors)
raise ValueError(
"unrecognized type for data given of {}".format(tensors.__class__.__name__)
)
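# Illustrative usage sketch (model and loader below are assumed user objects):
#
#     for batch in loader:                      # batch may be (features, labels)
#         output = tensors_module_forward(batch, model)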
def tensor_export(
    tensor: Union[Tensor, Dict[str, Tensor], Iterable[Tensor]],
    export_dir: str,
    name: str,
    npz: bool = True,
) -> str:
    """
    :param tensor: tensor to export to a saved numpy array file
    :param export_dir: the directory to export the file in
    :param name: the name of the file, the proper numpy extension
        will be appended to it
    :param npz: True to export as an npz file, False to export as npy
    :return: the path of the numpy file the tensor was exported to
    """
if isinstance(tensor, Tensor):
tensor = tensor.detach().cpu().numpy()
elif isinstance(tensor, Dict):
tensor = OrderedDict(
(key, val.detach().cpu().numpy()) for key, val in tensor.items()
)
elif isinstance(tensor, Iterable):
tensor = [
val.detach().cpu().numpy() if isinstance(val, Tensor) else val
for val in tensor
]
else:
raise ValueError("Unrecognized type given for tensorr {}".format(tensor))
return save_numpy(tensor, export_dir, name, npz)
def tensors_export(
tensors: Union[Tensor, Iterable[Tensor]],
export_dir: str,
name_prefix: str,
counter: int = 0,
break_batch: bool = False,
) -> List[str]:
"""
:param tensors: the tensors to export to a saved numpy array file
:param export_dir: the directory to export the files in
:param name_prefix: the prefix name for the tensors to save as, will append
info about the position of the tensor in a list or dict in addition
to the .npy file format
:param counter: the current counter to save the tensor at
:param break_batch: treat the tensor as a batch and break apart into
multiple tensors
:return: the exported paths
"""
create_dirs(export_dir)
exported_paths = []
if break_batch:
_tensors_export_batch(tensors, export_dir, name_prefix, counter, exported_paths)
else:
_tensors_export_recursive(
tensors, export_dir, name_prefix, counter, exported_paths
)
return exported_paths
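# Illustrative usage sketch (paths and names below are example values):
#
#     sample = torch.randn(4, 3, 224, 224)
#     paths = tensors_export(sample, "/tmp/exports", "inputs", break_batch=True)
#     # each item in the batch is saved to its own file: inputs-0000 ... inputs-0003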
def _tensors_export_recursive(
tensors: Union[Tensor, Iterable[Tensor]],
export_dir: str,
name_prefix: str,
counter: int,
exported_paths: List[str],
):
if isinstance(tensors, Tensor):
exported_paths.append(
tensor_export(tensors, export_dir, "{}-{:04d}".format(name_prefix, counter))
)
return
if isinstance(tensors, Iterable):
for index, tens in enumerate(tensors):
_tensors_export_recursive(
tens,
export_dir,
name_prefix,
counter + index,
exported_paths,
)
return
raise ValueError(
"unrecognized type for tensors given of {}".format(tensors.__class__.__name__)
)
def _tensors_export_batch(
tensors: Union[Tensor, Iterable[Tensor]],
export_dir: str,
name_prefix: str,
counter: int,
exported_paths: List[str],
):
if isinstance(tensors, Tensor):
if len(tensors.shape) == 1:
exported_paths.append(
tensor_export(
tensors, export_dir, "{}-{:04d}".format(name_prefix, counter)
)
)
return
for index, tens in enumerate(tensors):
exported_paths.append(
tensor_export(
tens, export_dir, "{}-{:04d}".format(name_prefix, counter + index)
)
)
return
if isinstance(tensors, Iterable):
for index, tens in enumerate(zip(*tensors)):
exported_paths.append(
tensor_export(
tens, export_dir, "{}-{:04d}".format(name_prefix, counter + index)
)
)
return
raise ValueError(
"unrecognized type for tensors given of {}".format(tensors.__class__.__name__)
)
def tensor_sparsity(
tens: Tensor, dim: Union[None, int, List[int], Tuple[int, ...]] = None
) -> Tensor:
"""
:param tens: the tensor to calculate the sparsity for
:param dim: the dimension(s) to split the calculations over;
ex, can split over batch, channels, or combos
:return: the sparsity of the input tens, ie the fraction of numbers that are zero
"""
if dim is None:
zeros = (tens == 0).sum()
total = tens.numel()
return zeros.float() / float(total)
if isinstance(dim, int):
dim = [dim]
if max(dim) >= len(tens.shape):
raise ValueError(
"Unsupported dim given of {} in {} for tensor shape {}".format(
max(dim), dim, tens.shape
)
)
sum_dims = [ind for ind in range(len(tens.shape)) if ind not in dim]
zeros = (tens == 0).sum(dim=sum_dims) if sum_dims else tens == 0
total = numpy.prod(
[tens.shape[ind] for ind in range(len(tens.shape)) if ind not in dim]
)
permute_order = sorted(
((d, len(dim) - i - 1) for i, d in enumerate(dim)), reverse=True
)
permute = [d[1] for d in permute_order]
if permute != [i for i in range(len(permute))]:
# need to permute to get desired dimensions at the front
zeros = zeros.permute(*permute).contiguous()
return zeros.float() / float(total)
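# Illustrative usage sketch (weights below is an example tensor):
#
#     weights = torch.zeros(64, 16, 3, 3)
#     weights[:32] = 1.0
#     tensor_sparsity(weights)         # tensor(0.5000), fraction of zeros overall
#     tensor_sparsity(weights, dim=0)  # shape (64,), sparsity per output channel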
def tensor_density(tens: Tensor, dim: Union[None, int, Iterable[int]] = None) -> Tensor:
"""
:param tens: the tensor to calculate the density for
:param dim: the dimension(s) to split the calculations over; ex, can split over
batch, channels, or combos
:return: the density of the input tens, ie the fraction of numbers that are non zero
"""
density = (tensor_sparsity(tens, dim) - 1.0) * -1.0
return density
def tensor_sample(
tens: Tensor,
sample_size: int,
dim: Union[None, int, List[int], Tuple[int, ...]] = None,
) -> Tensor:
"""
:param tens: the tensor to grab samples from
:param sample_size: the number of samples to grab overall if dim is not supplied
or per each dim if it is
:param dim: the dimension(s) to split the samples over;
ex, can split over batch, channels, or combos
:return: the sampled tensor
"""
if sample_size < 1:
raise ValueError("improper sample size given of {}".format(sample_size))
if dim is None:
indices = tens.new_zeros((sample_size,)).long().random_(0, tens.numel())
samples = tens.view(-1)[indices]
return samples
if isinstance(dim, int):
dim = [dim]
if max(dim) >= len(tens.shape):
raise ValueError(
"Unsupported dim given of {} in {} for tensor shape {}".format(
max(dim), dim, tens.shape
)
)
if dim != [ind for ind in range(len(dim))]:
# put the desired dimension(s) at the front to sample from
tens = tens.permute(
*dim, *[ind for ind in range(len(tens.shape)) if ind not in dim]
)
dim = [ind for ind in range(len(dim))]
if not tens.is_contiguous():
tens = tens.contiguous()
num_indices = int(numpy.prod([tens.shape[ind] for ind in range(len(dim))]))
elem_per_ind = int(
numpy.prod([tens.shape[ind] for ind in range(len(dim), len(tens.shape))])
)
# create a new tensor with offsets set for each of our elements that we are indexing
indices = tens.new_tensor(
[ind * elem_per_ind for ind in range(num_indices)], dtype=torch.long
).unsqueeze(1)
# now broadcast it across to the total number of elements we should end with
indices = indices * tens.new_ones((num_indices, sample_size), dtype=torch.long)
# finally add in a random number within the available range per index
indices += tens.new_zeros((num_indices, sample_size), dtype=torch.long).random_(
0, elem_per_ind
)
# get our samples
samples = tens.view(-1)[indices.view(-1)]
# reshape for the proper dimension
samples = samples.view(*(tens.shape[ind] for ind in dim), sample_size)
return samples
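# Illustrative usage sketch (weights below is an example tensor):
#
#     weights = torch.randn(64, 16, 3, 3)
#     tensor_sample(weights, 100)          # 100 values sampled across the whole tensor
#     tensor_sample(weights, 100, dim=0)   # shape (64, 100), 100 samples per channel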
def tensor_list_sparsity(tensors: List[Tensor]) -> float:
"""
:param tensors: the list of tensors to calculate the sparsity for
:return: the total sparsity of all tensors in the list
"""
zeros = 0
numel = 0
for tensor in tensors:
zeros += (tensor == 0).sum().item()
numel += tensor.numel()
return float(zeros) / float(numel)
def mask_difference(old_mask: Tensor, new_mask: Tensor) -> Tensor:
    """
    :param old_mask: the old mask to compare against for calculating the difference
    :param new_mask: the new mask to compare with for calculating the difference
    :return: a tensor representing the change from the old_mask to the new_mask;
        values returned as 1.0 are newly unmasked (0.0 => 1.0),
        values returned as -1.0 are newly masked (1.0 => 0.0),
        values returned as 0.0 had no change (0.0 => 0.0 or 1.0 => 1.0)
    """
newly_masked = ((old_mask != new_mask) & (new_mask == 0.0)).type(old_mask.type())
newly_unmasked = ((old_mask != new_mask) & (new_mask == 1.0)).type(old_mask.type())
return -1.0 * newly_masked + newly_unmasked
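# Illustrative usage sketch (masks below are example values):
#
#     old = torch.tensor([1.0, 1.0, 0.0, 0.0])
#     new = torch.tensor([1.0, 0.0, 1.0, 0.0])
#     mask_difference(old, new)  # tensor([ 0., -1.,  1.,  0.])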
##############################
#
# pytorch module helper functions
#
##############################
def get_layer(name: str, module: Module) -> Module:
"""
:param name: the name of the layer to grab from the module
:param module: the module containing the layer to grab
:return: the module representing the layer in the module
"""
layers = name.split(".")
layer = module
for name in layers:
layer = layer.__getattr__(name)
return layer
def replace_layer(
    module: Module,
    name: str,
    replace: Module,
) -> Module:
    """
    General function to replace a layer in a module with the given new one.

    :param module: the module to replace the layer in
    :param name: the name of the layer to replace in the module
    :param replace: the module to replace the layer with
    :return: the original layer that was replaced
    """
parent = module
sections = name.split(".")
for sec in sections[:-1]:
parent = parent.__getattr__(sec)
cur = parent.__getattr__(sections[-1])
parent.__setattr__(sections[-1], replace)
return cur
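# Illustrative usage sketch (model below is an example module):
#
#     model = torch.nn.Sequential(torch.nn.Linear(16, 8), torch.nn.ReLU())
#     old_act = replace_layer(model, "1", torch.nn.LeakyReLU())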
def get_terminal_layers(module: Module) -> Dict[str, Module]:
    """
    :param module: the module to grab all terminal layers for
    :return: a dict of name to module for all of the terminal layers in a model
        (ie not containers; so convs, linears, activations, etc)
    """
terminal = {}
for mod_name, mod in module.named_modules():
# check if it is a root node (only has itself in named_modules)
child_count = 0
for _, __ in mod.named_modules():
child_count += 1
if child_count != 1:
continue
terminal[mod_name] = mod
return terminal
def get_conv_layers(module: Module) -> Dict[str, Module]:
    """
    :param module: the module to grab all conv layers for
    :return: a dict of name to module for all the conv layers in the module
    """
return {
name: mod for name, mod in module.named_modules() if isinstance(mod, _ConvNd)
}
def get_linear_layers(module: Module) -> Dict[str, Module]:
    """
    :param module: the module to grab all linear layers for
    :return: a dict of name to module for all linear layers in the module
    """
return {
name: mod for name, mod in module.named_modules() if isinstance(mod, Linear)
}
def get_prunable_layers(module: Module) -> List[Tuple[str, Module]]:
"""
:param module: the module to get the prunable layers from
:return: a list containing the names and modules of the prunable layers
(Linear, ConvNd)
"""
return [
(name, mod)
for (name, mod) in module.named_modules()
if (
isinstance(mod, Linear)
or isinstance(mod, _ConvNd)
or (QATLinear and isinstance(mod, QATLinear))
or (QATConv2d and isinstance(mod, QATConv2d))
or (QATConv3d and isinstance(mod, QATConv3d))
)
]
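# Illustrative usage sketch (model below is an example module):
#
#     model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.Linear(8, 4))
#     get_prunable_layers(model)  # [("0", Conv2d(...)), ("1", Linear(...))]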
def get_quantizable_layers(module: Module) -> List[Tuple[str, Module]]:
"""
:param module: the module to get the quantizable layers from
:return: a list containing the names and modules of the quantizable layers
(Linear, Conv2d, Conv3d)
"""
if QATLinear is None:
raise ImportError(
"PyTorch version is not setup for Quantization. "
"Please install a QAT compatible version of PyTorch"
)
return [
(name, mod)
for (name, mod) in module.named_modules()
if (
isinstance(mod, Linear)
or isinstance(mod, Conv2d)
or (QATConv3d and isinstance(mod, Conv3d))
)
]
def get_quantized_layers(module: Module) -> List[Tuple[str, Module]]:
"""
:param module: the module to get the quantized layers from
:return: a list containing the names and modules of the quantized layers
(Linear, Conv2d, Conv3d)
"""
if QATLinear is None:
raise ImportError(
"PyTorch version is not setup for Quantization. "
"Please install a QAT compatible version of PyTorch"
)
quantized_layers = []
for (name, mod) in module.named_modules():
if (
(QATLinear and isinstance(mod, QATLinear))
or (QATConv2d and isinstance(mod, QATConv2d))
or (QATConv3d and isinstance(mod, QATConv3d))
):
quantized_layers.append((name, mod))
elif isinstance(mod, Conv3d) and not QATConv3d:
warnings.warn(
"Pytorch version is not setup for Conv3D Quantization. "
"Quantization of Conv3D layers will be skipped",
UserWarning,
)
return quantized_layers
def get_layer_param(param: str, layer: str, module: Module) -> Parameter:
"""
:param param: the name of the param to grab from the layer
:param layer: the name of the layer to grab from the module
:param module: the module containing the layer and the param
:return: the param taken from the given layer in the module
"""
layer = get_layer(layer, module) # type: Module
param = layer.__getattr__(param) # type: Parameter
return param
def get_named_layers_and_params_by_regex(
    module: Module,
    param_names: List[str],
    params_strict: bool = False,
) -> List[NamedLayerParam]:
    """
    :param module: the module to get the matching layers and params from
    :param param_names: a list of names or regex patterns to match with full parameter
        paths. Regex patterns must be specified with the prefix 're:'
    :param params_strict: if True, this function will raise an exception if a
        parameter is not found to match every name or regex in param_names
    :return: a list of NamedLayerParam tuples whose full parameter names in the given
        module match one of the given regex patterns or parameter names
    """
named_layers_and_params = []
found_param_names = []
for layer_name, layer in module.named_modules():
for param_name, param in layer.named_parameters():
if "." in param_name: # skip parameters of nested layers
continue
full_param_name = "{}.{}".format(layer_name, param_name)
if any_str_or_regex_matches_param_name(full_param_name, param_names):
named_layers_and_params.append(
NamedLayerParam(layer_name, layer, param_name, param)
)
found_param_names.append(full_param_name)
elif layer_name.endswith(".module"):
# unwrap layers wrapped with a QuantWrapper and check if they match
parent_layer_name = ".".join(layer_name.split(".")[:-1])
parent_layer = get_layer(parent_layer_name, module)
skip_wrapper_name = "{}.{}".format(parent_layer_name, param_name)
if (
QuantWrapper is not None
and isinstance(parent_layer, QuantWrapper)
and any_str_or_regex_matches_param_name(
skip_wrapper_name, param_names
)
):
named_layers_and_params.append(
NamedLayerParam(layer_name, layer, param_name, param)
)
found_param_names.append(skip_wrapper_name)
if params_strict:
validate_all_params_found(param_names, found_param_names)
return named_layers_and_params
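# Illustrative usage sketch (model and patterns below are example values):
#
#     model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.Linear(8, 4))
#     get_named_layers_and_params_by_regex(model, ["re:.*weight"])
#     # -> NamedLayerParam entries for "0.weight" and "1.weight"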
def any_str_or_regex_matches_param_name(
param_name: str,
name_or_regex_patterns: List[str],
) -> bool:
"""
:param param_name: The name of a parameter
:param name_or_regex_patterns: List of full param names to match to the input or
regex patterns to match with that should be prefixed with 're:'
:return: True if any given str or regex pattern matches the given name
"""
for name_or_regex in name_or_regex_patterns:
if name_or_regex[:3] == "re:":
pattern = name_or_regex[3:]
if re.match(pattern, param_name):
return True
else:
if param_name == name_or_regex:
return True
return False
def validate_all_params_found(
name_or_regex_patterns: List[str],
found_param_names: List[str],
):
"""
    :param name_or_regex_patterns: List of full param names or regex patterns of them
        to check for matches in found_param_names
    :param found_param_names: List of full parameter names to check for matches in
:raise RuntimeError: If there is a name or regex pattern that does not have a
match in found_param_names
"""
for name_or_regex in name_or_regex_patterns:
if "re:" != name_or_regex[:3] and name_or_regex in found_param_names:
continue # name found in list of full parameter names
if "re:" == name_or_regex[:3] and any(
re.match(name_or_regex[3:], name) for name in found_param_names
):
continue # regex pattern matches at least one full parameter name
raise RuntimeError(
"All supplied parameter names or regex patterns not found."
"No match for {} in found parameters {}. \nSupplied {}".format(
name_or_regex, found_param_names, name_or_regex_patterns
)
)
def set_deterministic_seeds(seed: int = 0):
"""
Manually seeds the numpy, random, and torch packages.
Also sets torch.backends.cudnn.deterministic to True
:param seed: the manual seed to use. Default is 0
"""
numpy.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Context manager to make all processes in distributed training wait for the
    local rank 0 process to finish its work before continuing.

    :param local_rank: the local rank of this process
    """
if local_rank not in [-1, 0]:
torch.distributed.barrier()
yield
if local_rank == 0:
torch.distributed.barrier()
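# Illustrative usage sketch (assumes torch.distributed is initialized and local_rank
# was parsed from the launcher; prepare_dataset is a hypothetical helper):
#
#     with torch_distributed_zero_first(local_rank):
#         dataset = prepare_dataset()  # rank 0 runs first, other ranks wait at the barrier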
def thin_model_from_checkpoint(model: Module, state_dict: Dict[str, Any]):
    """
    Updates any Linear/Conv/BN layers in the given model to match their
    respective shapes in the given state dict. Intended for compatibility
    when loading weights for a model from a checkpoint of the same architecture
    but with potentially structured thinning applied. Note that this function
    has no guarantees on accuracy; it will only resize model parameters for
    loading compatibility. All adjustments are done in place.

    :param model: model to potentially adjust parameter shapes of
    :param state_dict: state dict to infer parameter shapes from
    """
first_thinned = True
for param_name, checkpoint_tens in state_dict.items():
if not param_name.endswith(".weight"):
continue # only deal with weight params of modules
layer_name = param_name[:-7]
layer = get_layer(layer_name, model)
if not hasattr(layer, "weight") or (
layer.weight.shape == checkpoint_tens.shape
):
continue # skip if there is no update to shape
# quick check that target layer is some flavor of FC/Conv/BN
layer_type = layer.__class__.__name__
        if (
            "Linear" not in layer_type
            and "Conv" not in layer_type
            and "BatchNorm" not in layer_type
        ):
            continue
orig_shape = layer.weight.shape
target_shape = checkpoint_tens.shape
# update weight param + grad
if len(target_shape) > 1:
layer.weight.data = layer.weight.data[
: target_shape[0], : target_shape[1], ...
]
if layer.weight.grad is not None:
layer.weight.grad = layer.weight.grad[
: target_shape[0], : target_shape[1], ...
]
else:
layer.weight.data = layer.weight.data[: target_shape[0]]
if layer.weight.grad is not None:
layer.weight.grad = layer.weight.grad[: target_shape[0]]
# update bias param + grad
if hasattr(layer, "bias") and layer.bias is not None:
# target output channels should be the first dim of target shape
layer.bias.data = layer.bias.data[: target_shape[0]]
if layer.bias.grad is not None:
layer.bias.grad = layer.bias.grad[: target_shape[0]]
# update layer attributes
if "BatchNorm" in layer_type:
if hasattr(layer, "num_features"):
layer.num_features = layer.weight.size(0)
# BN running mean and var are not stored as Parameters
if hasattr(layer, "running_mean"):
layer.running_mean = torch.zeros_like(layer.running_mean)[
: target_shape[0]
]
if hasattr(layer, "running_var"):
layer.running_var = torch.zeros_like(layer.running_var)[
: target_shape[0]
]
if "Linear" in layer_type:
if hasattr(layer, "out_features"):
layer.out_features = layer.weight.shape[0]
if hasattr(layer, "in_features"):
layer.in_features = layer.weight.shape[1]
if "Conv" in layer_type:
if hasattr(layer, "out_channels"):
layer.out_channels = layer.weight.shape[0]
if hasattr(layer, "in_channels"):
layer.in_channels = layer.weight.shape[1]
if hasattr(layer, "groups") and layer.groups > 1:
layer.groups = layer.weight.shape[0] // layer.weight.shape[1]
if first_thinned:
_LOGGER.info(
"Thinning module layers for compatibility with given state dict:"
)
first_thinned = False
_LOGGER.info(
f"Thinned layer {layer_name} from shape {orig_shape} to "
f"{layer.weight.shape}"
)
##############################
#
# misc pytorch helper functions
#
##############################
MEMORY_BOUNDED = "MEMORY_BOUNDED"
def memory_aware_threshold(tensor: torch.Tensor, idx: int) -> Tensor:
"""
Finds a threshold at the lookup idx in the most efficient way with available
resources. Will be phased out when GPU-memory overhead of torch.sort reduces,
or when torch.kthvalue becomes faster than torch.sort.
:param tensor: A tensor to find a k-th smallest value in, where k=idx+1
:param idx: A lookup index
:return: k-th smallest value from the given tensor, where k=idx+1
"""
try:
if (
MEMORY_BOUNDED in os.environ
and os.environ[MEMORY_BOUNDED].lower() == "true"
):
return torch.kthvalue(tensor.view(-1), idx + 1)[0]
else:
return torch.sort(tensor.view(-1))[0][idx]
except RuntimeError:
_LOGGER.warning(
"Finding threshold from sparsity failed due to lack of memory, "
"will attempt to recover. Consider setting env variable "
f"{MEMORY_BOUNDED}=True in future runs."
)
torch.cuda.empty_cache()
os.environ[MEMORY_BOUNDED] = "True"
return torch.kthvalue(tensor.view(-1), idx + 1)[0]
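# Illustrative usage sketch (values below are example inputs):
#
#     scores = torch.randn(1000).abs()
#     k = int(0.9 * scores.numel()) - 1
#     threshold = memory_aware_threshold(scores, k)
#     mask = (scores > threshold).float()  # keeps roughly the top 10% of scores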