Source code for sparseml.pytorch.utils.exporter

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Export PyTorch models to the local device
"""
import collections
import logging
import os
import warnings
from copy import deepcopy
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy
import onnx
import torch
from onnx import numpy_helper
from torch import Tensor
from torch.nn import Module
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

from sparseml.pytorch.utils.helpers import (
    tensors_export,
    tensors_module_forward,
    tensors_to_device,
)
from sparseml.pytorch.utils.model import (
    is_parallel_model,
    save_model,
    script_model,
    trace_model,
)
from sparseml.utils import clean_path, create_parent_dirs


__all__ = [
    "ModuleExporter",
    "export_onnx",
]


DEFAULT_ONNX_OPSET = 9 if torch.__version__ < "1.3" else 11
_LOGGER = logging.getLogger(__name__)


class ModuleExporter(object):
    """
    An exporter for exporting PyTorch modules into ONNX format
    as well as numpy arrays for the input and output tensors.

    :param module: the module to export
    :param output_dir: the directory to export the module and extras to
    """

    def __init__(
        self,
        module: Module,
        output_dir: str,
    ):
        if is_parallel_model(module):
            module = module.module

        self._module = deepcopy(module).to("cpu").eval()
        self._output_dir = clean_path(output_dir)

    def export_to_zoo(
        self,
        dataloader: DataLoader,
        original_dataloader: Optional[DataLoader] = None,
        shuffle: bool = False,
        max_samples: int = 20,
        data_split_cb: Optional[Callable[[Any], Tuple[Any, Any]]] = None,
        label_mapping_cb: Optional[Callable[[Any], Any]] = None,
        trace_script: bool = False,
        fail_on_torchscript_failure: bool = True,
        export_entire_model: bool = False,
    ):
        """
        Creates and exports all related content of the module, including sample
        data, ONNX, PyTorch, and TorchScript files.

        :param dataloader: DataLoader used to generate sample data
        :param original_dataloader: Optional dataloader to obtain the
            untransformed image.
        :param shuffle: Whether to shuffle sample data
        :param max_samples: Max number of sample data to create
        :param data_split_cb: Optional callback function to split a data sample into
            a tuple (features, labels). If not provided, assumes the dataloader
            returns a tuple (features, labels).
        :param label_mapping_cb: Optional callback function to map dataset labels
            to other formats.
        :param trace_script: If True, creates torchscript via tracing. Otherwise,
            creates the torchscript via scripting.
        :param fail_on_torchscript_failure: If True, fails if torchscript is unable
            to export the model.
        :param export_entire_model: Exports the entire model instead of state_dict
        """
        sample_batches = []
        sample_labels = []
        sample_originals = None
        if original_dataloader is not None:
            sample_originals = []
            for originals in original_dataloader:
                sample_originals.append(originals)
                if len(sample_originals) == max_samples:
                    break

        for sample in dataloader:
            if data_split_cb is not None:
                features, labels = data_split_cb(sample)
            else:
                features, labels = sample

            if label_mapping_cb:
                labels = label_mapping_cb(labels)

            sample_batches.append(features)
            sample_labels.append(labels)
            if len(sample_batches) == max_samples:
                break

        self.export_onnx(sample_batch=sample_batches[0])
        self.export_pytorch(export_entire_model=export_entire_model)
        try:
            if trace_script:
                self.export_torchscript(sample_batch=sample_batches[0])
            else:
                self.export_torchscript()
        except Exception as e:
            if fail_on_torchscript_failure:
                raise e
            else:
                _LOGGER.warning(
                    f"Unable to create torchscript file. Following error occurred: {e}"
                )

        self.export_samples(
            sample_batches,
            sample_labels=sample_labels,
            sample_originals=sample_originals,
        )

    @classmethod
    def get_output_names(cls, out: Any):
        """
        Get the names of the output tensors. Derived exporters specific to
        frameworks may override this method.

        :param out: outputs of the model
        :return: list of names
        """
        return _get_output_names(out)

    def export_onnx(
        self,
        sample_batch: Any,
        name: str = "model.onnx",
        opset: int = DEFAULT_ONNX_OPSET,
        disable_bn_fusing: bool = True,
        convert_qat: bool = False,
        **export_kwargs,
    ):
        """
        Export an onnx file for the current module and for a sample batch.
        The sample batch is fed through the model to freeze the graph for a
        particular execution.

        :param sample_batch: the batch to export an onnx for, handles creating the
            static graph for onnx as well as setting dimensions
        :param name: name of the onnx file to save
        :param opset: onnx opset to use for the exported model.
            Default is 11; if the torch version is 1.2 or below, default is 9.
        :param disable_bn_fusing: torch >= 1.7.0 only. Set True to disable batch norm
            fusing during torch export. Default and suggested setting is True. Batch
            norm fusing will change the exported parameter names as well as affect
            sensitivity analyses of the exported graph. Additionally, the DeepSparse
            inference engine, and other engines, perform batch norm fusing at model
            compilation.
        :param convert_qat: if True and quantization aware training is detected in
            the module being exported, the resulting QAT ONNX model will be
            converted to a fully quantized ONNX model using
            `quantize_torch_qat_export`. Default is False.
        :param export_kwargs: kwargs to be passed as is to the torch.onnx.export api
            call. Useful to pass in dynamic_axes, input_names, output_names, etc.
            See more on the torch.onnx.export api spec in the PyTorch docs:
            https://pytorch.org/docs/stable/onnx.html
        """
        if not export_kwargs:
            export_kwargs = {}

        if "output_names" not in export_kwargs:
            sample_batch = tensors_to_device(sample_batch, "cpu")
            module = deepcopy(self._module).cpu()
            module.eval()
            with torch.no_grad():
                out = tensors_module_forward(
                    sample_batch, module, check_feat_lab_inp=False
                )
                export_kwargs["output_names"] = self.get_output_names(out)

        export_onnx(
            module=self._module,
            sample_batch=sample_batch,
            file_path=os.path.join(self._output_dir, name),
            opset=opset,
            disable_bn_fusing=disable_bn_fusing,
            convert_qat=convert_qat,
            **export_kwargs,
        )

    def export_torchscript(
        self,
        name: str = "model.pts",
        sample_batch: Optional[Any] = None,
    ):
        """
        Export the torchscript into a pts file within a framework directory. If a
        sample batch is provided, creates the torchscript model in trace mode.
        Otherwise uses scripting to create the torchscript.

        :param name: name of the torchscript file to save
        :param sample_batch: If provided, will create the torchscript model via
            tracing using the sample_batch
        """
        path = os.path.join(self._output_dir, "framework", name)
        create_parent_dirs(path)
        if sample_batch:
            trace_model(path, self._module, sample_batch)
        else:
            script_model(path, self._module)

    def export_pytorch(
        self,
        optimizer: Optional[Optimizer] = None,
        recipe: Optional[str] = None,
        epoch: Optional[int] = None,
        name: str = "model.pth",
        use_zipfile_serialization_if_available: bool = True,
        include_modifiers: bool = False,
        export_entire_model: bool = False,
        arch_key: Optional[str] = None,
    ):
        """
        Export the pytorch state dicts into a pth file within a
        pytorch framework directory.

        :param optimizer: optional optimizer to export along with the module
        :param recipe: the recipe used to obtain the model
        :param epoch: optional epoch to export along with the module
        :param name: name of the pytorch file to save
        :param use_zipfile_serialization_if_available: for torch >= 1.6.0 only
            exports the Module's state dict using the new zipfile serialization
        :param include_modifiers: if True, and a ScheduledOptimizer is provided
            as the optimizer, the associated ScheduledModifierManager and its
            Modifiers will be exported under the 'manager' key. Default is False
        :param export_entire_model: Exports the entire model instead of state_dict
        :param arch_key: if provided, the `arch_key` will be saved in the
            checkpoint
        """
        pytorch_path = os.path.join(self._output_dir, "framework")
        pth_path = os.path.join(pytorch_path, name)
        create_parent_dirs(pth_path)

        if export_entire_model:
            torch.save(self._module, pth_path)
        else:
            save_model(
                pth_path,
                self._module,
                optimizer,
                recipe,
                epoch,
                use_zipfile_serialization_if_available=(
                    use_zipfile_serialization_if_available
                ),
                include_modifiers=include_modifiers,
                arch_key=arch_key,
            )

    def export_samples(
        self,
        sample_batches: List[Any],
        sample_labels: Optional[List[Any]] = None,
        sample_originals: Optional[List[Any]] = None,
        exp_counter: int = 0,
    ):
        """
        Export a list of sample batches as inputs and outputs through the model.

        :param sample_batches: a list of the sample batches to feed through the
            module for saving inputs and outputs
        :param sample_labels: an optional list of sample labels that correspond to
            the batches for saving
        :param sample_originals: an optional list of original (untransformed)
            samples that correspond to the batches for saving
        :param exp_counter: the counter to start exporting the tensor files at
        """
        sample_batches = [tensors_to_device(batch, "cpu") for batch in sample_batches]
        inputs_dir = os.path.join(self._output_dir, "sample-inputs")
        outputs_dir = os.path.join(self._output_dir, "sample-outputs")
        labels_dir = os.path.join(self._output_dir, "sample-labels")
        originals_dir = os.path.join(self._output_dir, "sample-originals")

        with torch.no_grad():
            for batch, lab, orig in zip(
                sample_batches,
                sample_labels if sample_labels else [None for _ in sample_batches],
                sample_originals
                if sample_originals
                else [None for _ in sample_batches],
            ):
                out = tensors_module_forward(batch, self._module)

                exported_input = tensors_export(
                    batch,
                    inputs_dir,
                    name_prefix="inp",
                    counter=exp_counter,
                    break_batch=True,
                )
                if isinstance(out, dict):
                    new_out = []
                    for key in out:
                        new_out.append(out[key])
                    out = new_out
                exported_output = tensors_export(
                    out,
                    outputs_dir,
                    name_prefix="out",
                    counter=exp_counter,
                    break_batch=True,
                )

                if lab is not None:
                    tensors_export(
                        lab, labels_dir, "lab", counter=exp_counter, break_batch=True
                    )

                if orig is not None:
                    tensors_export(
                        orig,
                        originals_dir,
                        "orig",
                        counter=exp_counter,
                        break_batch=True,
                    )

                assert len(exported_input) == len(exported_output)
                exp_counter += len(exported_input)

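# Example usage of ModuleExporter (illustrative sketch, not part of the original
# module): `model` is assumed to be any CPU-loadable torch Module and `val_loader`
# a DataLoader yielding (features, labels) batches; both names are placeholders.
#
#     from sparseml.pytorch.utils.exporter import ModuleExporter
#
#     exporter = ModuleExporter(model, output_dir="exported_model")
#     sample_batch, _ = next(iter(val_loader))
#     exporter.export_onnx(sample_batch=sample_batch)    # exported_model/model.onnx
#     exporter.export_pytorch()                          # exported_model/framework/model.pth
#     exporter.export_samples([sample_batch])            # sample-inputs/ and sample-outputs/
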

def export_onnx(
    module: Module,
    sample_batch: Any,
    file_path: str,
    opset: int = DEFAULT_ONNX_OPSET,
    disable_bn_fusing: bool = True,
    convert_qat: bool = False,
    dynamic_axes: Optional[Union[str, Dict[str, List[int]]]] = None,
    skip_input_quantize: bool = False,
    **export_kwargs,
):
    """
    Export an onnx file for the current module and for a sample batch.
    The sample batch is fed through the model to freeze the graph for a
    particular execution.

    :param module: torch Module object to export
    :param sample_batch: the batch to export an onnx for, handles creating the
        static graph for onnx as well as setting dimensions
    :param file_path: path to the onnx file to save
    :param opset: onnx opset to use for the exported model.
        Default is 11; if the torch version is 1.2 or below, default is 9.
    :param disable_bn_fusing: torch >= 1.7.0 only. Set True to disable batch norm
        fusing during torch export. Default and suggested setting is True. Batch
        norm fusing will change the exported parameter names as well as affect
        sensitivity analyses of the exported graph. Additionally, the DeepSparse
        inference engine, and other engines, perform batch norm fusing at model
        compilation.
    :param convert_qat: if True and quantization aware training is detected in
        the module being exported, the resulting QAT ONNX model will be
        converted to a fully quantized ONNX model using
        `quantize_torch_qat_export`. Default is False.
    :param dynamic_axes: dictionary of input or output names to list of dimensions
        of those tensors that should be exported as dynamic. Pass 'batch'
        to set the first dimension of all inputs and outputs to dynamic. Default
        is an empty dict
    :param skip_input_quantize: if True, the export flow will attempt to delete
        the first QuantizeLinear node(s) immediately after model input and set
        the model input type to UINT8. Default is False
    :param export_kwargs: kwargs to be passed as is to the torch.onnx.export api
        call. Useful to pass in dynamic_axes, input_names, output_names, etc.
        See more on the torch.onnx.export api spec in the PyTorch docs:
        https://pytorch.org/docs/stable/onnx.html
    """
    if not export_kwargs:
        export_kwargs = {}

    if isinstance(sample_batch, Dict) and not isinstance(
        sample_batch, collections.OrderedDict
    ):
        warnings.warn(
            "Sample inputs passed into the ONNX exporter should be in "
            "the same order defined in the model forward function. "
            "Consider using OrderedDict for this purpose.",
            UserWarning,
        )

    sample_batch = tensors_to_device(sample_batch, "cpu")
    create_parent_dirs(file_path)

    module = deepcopy(module).cpu()
    module.eval()

    with torch.no_grad():
        out = tensors_module_forward(sample_batch, module, check_feat_lab_inp=False)

    if "input_names" not in export_kwargs:
        if isinstance(sample_batch, Tensor):
            export_kwargs["input_names"] = ["input"]
        elif isinstance(sample_batch, Dict):
            export_kwargs["input_names"] = list(sample_batch.keys())
            sample_batch = tuple(
                [sample_batch[f] for f in export_kwargs["input_names"]]
            )
        elif isinstance(sample_batch, Iterable):
            export_kwargs["input_names"] = [
                "input_{}".format(index)
                for index, _ in enumerate(iter(sample_batch))
            ]
            if isinstance(sample_batch, List):
                sample_batch = tuple(sample_batch)  # torch.onnx.export requires tuple

    if "output_names" not in export_kwargs:
        export_kwargs["output_names"] = _get_output_names(out)

    if dynamic_axes == "batch":
        dynamic_axes = {
            tensor_name: {0: "batch"}
            for tensor_name in (
                export_kwargs["input_names"] + export_kwargs["output_names"]
            )
        }

    # disable active quantization observers because they cannot be exported
    disabled_observers = []
    for submodule in module.modules():
        if (
            hasattr(submodule, "observer_enabled")
            and submodule.observer_enabled[0] == 1
        ):
            submodule.observer_enabled[0] = 0
            disabled_observers.append(submodule)

    is_quant_module = any(
        hasattr(submodule, "qconfig") and submodule.qconfig
        for submodule in module.modules()
    )
    batch_norms_wrapped = False
    if torch.__version__ >= "1.7" and not is_quant_module and disable_bn_fusing:
        # prevent batch norm fusing by adding a trivial operation before every
        # batch norm layer
        batch_norms_wrapped = _wrap_batch_norms(module)

    torch.onnx.export(
        module,
        sample_batch,
        file_path,
        strip_doc_string=True,
        verbose=False,
        opset_version=opset,
        dynamic_axes=dynamic_axes,
        **export_kwargs,
    )

    # re-enable disabled quantization observers
    for submodule in disabled_observers:
        submodule.observer_enabled[0] = 1

    # onnx file fixes
    onnx_model = onnx.load(file_path)
    # fix changed batch norm names
    _fix_batch_norm_names(onnx_model)
    if batch_norms_wrapped:
        # clean up graph from any injected / wrapped operations
        _delete_trivial_onnx_adds(onnx_model)
    onnx.save(onnx_model, file_path)

    if convert_qat and is_quant_module:
        # overwrite exported model with fully quantized version
        # import here to avoid cyclic dependency
        from sparseml.pytorch.sparsification.quantization import (
            quantize_torch_qat_export,
        )

        use_qlinearconv = hasattr(module, "export_with_qlinearconv") and (
            module.export_with_qlinearconv
        )

        quantize_torch_qat_export(
            model=file_path,
            output_file_path=file_path,
            use_qlinearconv=use_qlinearconv,
        )

    if skip_input_quantize:
        try:
            # import here to avoid cyclic dependency
            from sparseml.pytorch.sparsification.quantization import (
                skip_onnx_input_quantize,
            )

            skip_onnx_input_quantize(file_path, file_path)
        except Exception as e:
            _LOGGER.warning(
                f"Unable to skip input QuantizeLinear op with exception {e}"
            )


def _get_output_names(out: Any):
    """
    Get name of output tensors

    :param out: outputs of the model
    :return: list of names
    """
    output_names = None

    if isinstance(out, Tensor):
        output_names = ["output"]
    elif hasattr(out, "keys") and callable(out.keys):
        output_names = list(out.keys())
    elif isinstance(out, Iterable):
        output_names = ["output_{}".format(index) for index, _ in enumerate(iter(out))]

    return output_names


class _AddNoOpWrapper(Module):
    # trivial wrapper to break up Conv-BN blocks

    def __init__(self, module: Module):
        super().__init__()
        self.module = module

    def forward(self, inp):
        inp = inp + 0  # no-op
        return self.module(inp)


def _get_submodule(module: Module, path: List[str]) -> Module:
    if not path:
        return module
    return _get_submodule(getattr(module, path[0]), path[1:])


def _wrap_batch_norms(module: Module) -> bool:
    # wrap all batch norm layers in module with a trivial wrapper
    # to prevent BN fusing during export
    batch_norms_wrapped = False
    for name, submodule in module.named_modules():
        if (
            isinstance(submodule, torch.nn.BatchNorm1d)
            or isinstance(submodule, torch.nn.BatchNorm2d)
            or isinstance(submodule, torch.nn.BatchNorm3d)
        ):
            submodule_path = name.split(".")
            parent_module = _get_submodule(module, submodule_path[:-1])
            setattr(parent_module, submodule_path[-1], _AddNoOpWrapper(submodule))
            batch_norms_wrapped = True
    return batch_norms_wrapped


def _delete_trivial_onnx_adds(model: onnx.ModelProto):
    # delete all Add nodes in the graph whose second input is a constant node set to 0
    add_nodes = [node for node in model.graph.node if node.op_type == "Add"]
    for add_node in add_nodes:
        try:
            add_const_node = [
                node
                for node in model.graph.node
                if node.output[0] == add_node.input[1]
            ][0]
            add_const_val = numpy_helper.to_array(add_const_node.attribute[0].t)
            if numpy.all(add_const_val == 0.0):
                # update graph edges
                parent_node = [
                    node
                    for node in model.graph.node
                    if add_node.input[0] in node.output
                ]
                if not parent_node:
                    continue
                parent_node[0].output[0] = add_node.output[0]
                # remove node and constant
                model.graph.node.remove(add_node)
                model.graph.node.remove(add_const_node)
        except Exception:
            # skip node on any error
            continue


def _fix_batch_norm_names(model: onnx.ModelProto):
    name_to_inits = {init.name: init for init in model.graph.initializer}
    for node in model.graph.node:
        if node.op_type != "BatchNormalization":
            continue
        for idx in range(len(node.input)):
            init_name = node.input[idx]
            name_parts = init_name.split(".")
            if (
                init_name not in name_to_inits
                or len(name_parts) < 2
                or (name_parts[-2] != "module")
            ):
                continue
            del name_parts[-2]
            new_name = ".".join(name_parts)
            if new_name not in name_to_inits:
                init = name_to_inits[init_name]
                del name_to_inits[init_name]
                init.name = new_name
                node.input[idx] = new_name
                name_to_inits[new_name] = init
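

# For reference, an illustrative sketch (not part of the original module) of the
# output-name inference performed by _get_output_names above; the tensors are
# placeholders:
#
#     _get_output_names(torch.randn(1, 10))                       # -> ["output"]
#     _get_output_names({"logits": torch.randn(1, 10)})           # -> ["logits"]
#     _get_output_names([torch.randn(1, 10), torch.randn(1, 5)])  # -> ["output_0", "output_1"]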