# Source code for deepsparse.utils.onnx

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import logging
import os
import tempfile
from typing import List, Union

import numpy
import onnx


# sparsezoo is treated as optional: any failure while importing it is captured
# instead of propagating, so this module can still be imported without it.
try:
    from sparsezoo import Zoo
    from sparsezoo.objects import File, Model

    sparsezoo_import_error = None
except Exception as sparsezoo_err:
    # Placeholders keep the names defined. Model/File fall back to `object` so
    # they remain usable in annotations, and model_to_path checks
    # `is not object` before using them in isinstance tests.
    Zoo = None
    Model = object
    File = object
    # Kept so model_to_path can re-raise it when a "zoo:" stub is requested
    sparsezoo_import_error = sparsezoo_err


# Names exported by a star-import of this module. translate_onnx_type_to_numpy
# is defined in this module but is not listed here.
__all__ = [
    "ONNX_TENSOR_TYPE_MAP",
    "model_to_path",
    "get_external_inputs",
    "get_external_outputs",
    "get_input_names",
    "get_output_names",
    "generate_random_inputs",
    "override_onnx_batch_size",
    "override_onnx_input_shapes",
]

# Module-level logger, named after this module
_LOGGER = logging.getLogger(__name__)

# Maps ONNX tensor element-type codes (the integer read from an input's
# `type.tensor_type.elem_type`) to the numpy type used to represent them.
# Code 8 has no entry, so translate_onnx_type_to_numpy raises for it.
# NOTE(review): 8 is presumably the ONNX string type -- confirm against the spec.
ONNX_TENSOR_TYPE_MAP = {
    1: numpy.float32,
    2: numpy.uint8,
    3: numpy.int8,
    4: numpy.uint16,
    5: numpy.int16,
    6: numpy.int32,
    7: numpy.int64,
    9: numpy.bool_,
    10: numpy.float16,
    11: numpy.float64,
    12: numpy.uint32,
    13: numpy.uint64,
    14: numpy.complex64,
    15: numpy.complex128,
}


def translate_onnx_type_to_numpy(tensor_type: int):
    """
    Look up the numpy type registered for an ONNX tensor type code

    :param tensor_type: Integer code identifying a type in the ONNX spec
    :return: The numpy type stored under that code in ONNX_TENSOR_TYPE_MAP
    :raises Exception: If the code has no entry in ONNX_TENSOR_TYPE_MAP
    """
    if tensor_type in ONNX_TENSOR_TYPE_MAP:
        return ONNX_TENSOR_TYPE_MAP[tensor_type]

    raise Exception("Unknown ONNX tensor type = {}".format(tensor_type))


def model_to_path(model: Union[str, Model, File]) -> str:
    """
    Resolve any supported model reference to the path of an ONNX file on disk.

    Accepted forms are a path to an ONNX file, a SparseZoo model stub prefixed
    by 'zoo:', a SparseZoo Model object, or a SparseZoo ONNX File object that
    defines the neural network.

    :param model: The model reference to resolve
    :return: Path to the ONNX file; it is checked to exist before returning
    :raises ValueError: If model is falsy, resolves to something other than a
        string, or resolves to a path that does not exist
    """
    if not model:
        raise ValueError("model must be a path, sparsezoo.Model, or sparsezoo.File")

    resolved = model

    if isinstance(resolved, str) and resolved.startswith("zoo:"):
        # A stub can only be loaded through sparsezoo; surface the original
        # import failure if the package could not be imported
        if sparsezoo_import_error is not None:
            raise sparsezoo_import_error
        resolved = Zoo.load_model_from_stub(resolved)

    # The `is not object` guards skip the isinstance checks when the sparsezoo
    # classes were replaced by `object` placeholders at import time
    if Model is not object and isinstance(resolved, Model):
        # use the model's main onnx file
        resolved = resolved.onnx_file.downloaded_path()
    elif File is not object and isinstance(resolved, File):
        # downloaded_path will auto download if not on local system
        resolved = resolved.downloaded_path()

    if isinstance(resolved, str):
        if os.path.exists(resolved):
            return resolved
        raise ValueError("model path must exist: given {}".format(resolved))

    raise ValueError("unsupported type for model: {}".format(type(resolved)))
def get_external_inputs(onnx_filepath: str) -> List:
    """
    Gather external inputs of ONNX model

    A graph input is treated as external when its name does not match the name
    of any initializer in the graph.

    :param onnx_filepath: File path to ONNX model
    :return: List of input objects, in the order they appear in the graph
    """
    model = onnx.load(onnx_filepath)
    # A set keeps each membership test constant-time, so the filter stays linear
    # in the number of inputs and initializers
    initializer_names = {
        initializer.name for initializer in model.graph.initializer
    }
    return [
        graph_input
        for graph_input in model.graph.input
        if graph_input.name not in initializer_names
    ]
def get_external_outputs(onnx_filepath: str) -> List:
    """
    Collect the outputs declared on the graph of an ONNX model

    :param onnx_filepath: File path to ONNX model
    :return: List of output objects, in the order they appear in the graph
    """
    graph = onnx.load(onnx_filepath).graph
    return list(graph.output)
def get_input_names(onnx_filepath: str) -> List[str]:
    """
    Collect the name of every external input of an ONNX model

    :param onnx_filepath: File path to ONNX model
    :return: List of string names, in the order get_external_inputs returns them
    """
    external_inputs = get_external_inputs(onnx_filepath)
    return [external_input.name for external_input in external_inputs]
def get_output_names(onnx_filepath: str) -> List[str]:
    """
    Collect the name of every external output of an ONNX model

    :param onnx_filepath: File path to ONNX model
    :return: List of string names, in the order get_external_outputs returns them
    """
    external_outputs = get_external_outputs(onnx_filepath)
    return [external_output.name for external_output in external_outputs]
def generate_random_inputs(
    onnx_filepath: str, batch_size: Union[int, None] = None
) -> List[numpy.ndarray]:
    """
    Generate random data that matches the type and shape of ONNX model,
    with a batch size override

    One array is produced per external input. Values are drawn with
    numpy.random.rand and then cast to the input's element type.

    :param onnx_filepath: File path to ONNX model
    :param batch_size: If provided, override for the batch size dimension
        (the first dimension of every input)
    :return: List of random tensors, one per external input
    :raises Exception: If an input's element type has no numpy equivalent in
        ONNX_TENSOR_TYPE_MAP
    """
    input_data_list = []
    for external_input in get_external_inputs(onnx_filepath):
        input_tensor_type = external_input.type.tensor_type
        elem_type = translate_onnx_type_to_numpy(input_tensor_type.elem_type)
        # NOTE(review): dims without a concrete dim_value presumably read as 0,
        # which would give an empty tensor -- confirm for dynamic-shape models
        in_shape = [int(dim.dim_value) for dim in input_tensor_type.shape.dim]

        if batch_size is not None:
            in_shape[0] = batch_size

        _LOGGER.info(
            "Generating input '{}', type = {}, shape = {}".format(
                external_input.name, numpy.dtype(elem_type).name, in_shape
            )
        )
        input_data_list.append(numpy.random.rand(*in_shape).astype(elem_type))

    return input_data_list
@contextlib.contextmanager
def override_onnx_batch_size(onnx_filepath: str, batch_size: int) -> str:
    """
    Rewrite batch sizes of ONNX model, saving the modified model and returning
    its path

    Used as a context manager: the value bound by ``with ... as path`` is the
    path of a temporary copy of the model, which is deleted when the context
    exits. The original file is not modified.

    :param onnx_filepath: File path to ONNX model
    :param batch_size: Override for the batch size dimension (the first
        dimension of every external input)
    :return: File path to modified ONNX model
    """
    model = onnx.load(onnx_filepath)
    initializer_names = {
        initializer.name for initializer in model.graph.initializer
    }
    for graph_input in model.graph.input:
        # Only external inputs (those not backed by an initializer) are resized
        if graph_input.name not in initializer_names:
            graph_input.type.tensor_type.shape.dim[0].dim_value = batch_size

    # Reserve a temp path and close the handle right away: onnx.save reopens
    # the file by name, which is not possible on every platform while the
    # original handle is still open, and neither is unlinking it.
    with tempfile.NamedTemporaryFile(mode="w", delete=False) as shaped_model:
        shaped_model_path = shaped_model.name

    try:
        # Saving inside the try ensures the temp file is removed even if the
        # save itself fails
        onnx.save(model, shaped_model_path)
        yield shaped_model_path
    finally:
        os.unlink(shaped_model_path)
@contextlib.contextmanager
def override_onnx_input_shapes(
    onnx_filepath: str, input_shapes: Union[List[int], List[List[int]]]
) -> str:
    """
    Rewrite input shapes of ONNX model, saving the modified model and returning
    its path

    Used as a context manager: the value bound by ``with ... as path`` is the
    path of a temporary copy of the model, which is deleted when the context
    exits. When input_shapes is None nothing is rewritten and the original
    path is yielded unchanged. The caller's input_shapes list is never
    modified.

    :param onnx_filepath: File path to ONNX model
    :param input_shapes: Override for model's input shapes. Either one shape
        per external input, a single shape (flat list of ints, or a list
        holding one shape) that is applied to every external input, or None
    :return: File path to modified ONNX model
    :raises AssertionError: If the shapes are not lists, their count does not
        match the number of external inputs, or a shape's rank differs from
        the rank of the input it overrides
    """
    if input_shapes is None:
        # A generator-based context manager must yield exactly once; returning
        # here without yielding would make the `with` statement raise
        yield onnx_filepath
        return

    model = onnx.load(onnx_filepath)
    initializer_names = {
        initializer.name for initializer in model.graph.initializer
    }
    external_inputs = [
        graph_input
        for graph_input in model.graph.input
        if graph_input.name not in initializer_names
    ]

    # A flat list of ints describes a single shape; wrap it so the rest of the
    # function can always work with a list of shapes
    if input_shapes and all(isinstance(dim, int) for dim in input_shapes):
        input_shapes = [input_shapes]

    # Input shapes should be a list of lists, even if there is only one input
    assert all(isinstance(inp, list) for inp in input_shapes)

    # If there is a single input shape given and multiple inputs,
    # duplicate for all inputs to apply the same shape. A new list is built
    # so the caller's argument is left untouched.
    if len(input_shapes) == 1 and len(external_inputs) > 1:
        input_shapes = [input_shapes[0] for _ in external_inputs]

    # Make sure that input shapes can map to the ONNX model
    assert len(external_inputs) == len(
        input_shapes
    ), "Mismatch of number of model inputs ({}) and override shapes ({})".format(
        len(external_inputs), len(input_shapes)
    )

    # Overwrite the input shapes of the model
    for external_input, shape_override in zip(external_inputs, input_shapes):
        dims = external_input.type.tensor_type.shape.dim
        assert len(dims) == len(
            shape_override
        ), "Input '{}' shape doesn't match shape override: {} vs {}".format(
            external_input.name,
            dims,
            shape_override,
        )
        for dim, dim_value in zip(dims, shape_override):
            dim.dim_value = dim_value

    # Reserve a temp path and close the handle right away: onnx.save reopens
    # the file by name, which is not possible on every platform while the
    # original handle is still open, and neither is unlinking it.
    with tempfile.NamedTemporaryFile(mode="w", delete=False) as shaped_model:
        shaped_model_path = shaped_model.name

    try:
        # Saving inside the try ensures the temp file is removed even if the
        # save itself fails
        onnx.save(model, shaped_model_path)
        yield shaped_model_path
    finally:
        os.unlink(shaped_model_path)