# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility / helper functions
"""

import logging
from collections import OrderedDict
from copy import deepcopy
from functools import reduce
from typing import Any, Dict, List, NamedTuple, Tuple, Union

import numpy
import onnx
import onnxruntime
from onnx import ModelProto, NodeProto, TensorProto, numpy_helper
from onnx.helper import get_attribute_value, make_empty_tensor_value_info

from sparseml.utils import clean_path

_LOGGER = logging.getLogger(__name__)

__all__ = [
"validate_onnx_file",
"check_load_model",
"extract_node_id",
"get_node_by_id",
"get_nodes_by_input_id",
"get_nodes_by_output_id",
"extract_shape",
"get_numpy_dtype",
"NodeShape",
"extract_nodes_shapes_ort",
"extract_nodes_shapes_shape_inference",
"extract_node_shapes",
"get_init_by_name",
"get_attr_float_val_for_node",
"NodeParam",
"conv_node_params",
"gemm_node_params",
"matmul_node_params",
"get_node_params",
"BatchNormParams",
"get_batch_norm_params",
"get_node_attributes",
"get_node_inputs",
"get_node_outputs",
"get_node_input_nodes",
"get_node_output_nodes",
"is_prunable_node",
"is_foldable_node",
"get_prunable_node_from_foldable",
"get_prunable_nodes",
"SparsityMeasurement",
"onnx_nodes_sparsities",
"model_inputs",
"model_outputs",
"get_kernel_shape",
"calculate_flops",
"get_quantize_parent_for_dequantize_node",
"get_tensor_dim_shape",
"set_tensor_dim_shape",
]


def validate_onnx_file(path: str):
"""
Validate that a file at a given path is a valid ONNX model
:param path: the path of the file to validate
:raise ValueError: if not a valid ONNX model
"""
try:
onnx_model = check_load_model(path)
if not onnx_model.opset_import:
raise ValueError("could not parse opset_import")
    except Exception as err:
        raise ValueError(
            "file at {} is not a valid onnx model: {}".format(path, err)
        ) from err


def check_load_model(model: Union[str, ModelProto]) -> ModelProto:
    """
    Load an ONNX model from a given file path if supplied.
    If already a ModelProto, returns it unchanged.
:param model: the model proto or path to the model ONNX file to check for loading
:return: the loaded ONNX ModelProto
"""
if isinstance(model, ModelProto):
return model
if isinstance(model, str):
return onnx.load(clean_path(model))
raise ValueError("unknown type given for model: {}".format(model))


def get_node_by_id(model: ModelProto, node_id: str) -> Union[NodeProto, None]:
"""
Get a node from a model by the node_id generated from extract_node_id
:param model: the model proto loaded from the ONNX file
:param node_id: id of the node to get from the model
:return: the retrieved node or None if no node found
"""
for node in model.graph.node:
if extract_node_id(node) == node_id:
return node
return None


def get_nodes_by_output_id(model: ModelProto, output_id: str) -> List[NodeProto]:
"""
Get all the nodes in a model that have a given id as one of the outputs
:param model: the model proto loaded from the ONNX file
:param output_id: id of the output to get nodes by
:return: the retrieved nodes
"""
nodes = []
for node in model.graph.node:
outputs = get_node_outputs(model, node)
if output_id in outputs:
nodes.append(node)
return nodes


def get_numpy_dtype(tensor: onnx.ValueInfoProto) -> Union[None, numpy.dtype]:
"""
Extract the NumPy dtype of an ONNX tensor.
Returns None if there is not a direct mapping from the ONNX data type
to a NumPy dtype.
:param tensor: the tensor to get the dtype of
:return: a NumPy dtype for the tensor if available otherwise None
"""
data_type_enum = tensor.type.tensor_type.elem_type # type: int
data_type = onnx.TensorProto.DataType.Name(data_type_enum).lower() # type: str
    if data_type == "float":
        # numpy.float was an alias of the builtin float and was removed in
        # NumPy 1.24; ONNX FLOAT is 32 bit, so map it to float32 explicitly
        return numpy.float32
    if hasattr(numpy, data_type):
        return getattr(numpy, data_type)
    return None
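

# Illustrative sketch (not part of the original API): build a zero-filled
# NumPy array matching the dtype of a model's first graph input, e.g. for a
# quick onnxruntime smoke test. The (1, 3, 224, 224) shape is a made-up
# placeholder.
def _example_zero_input(model: ModelProto) -> Union[numpy.ndarray, None]:
    dtype = get_numpy_dtype(model.graph.input[0])
    return numpy.zeros((1, 3, 224, 224), dtype=dtype) if dtype is not None else None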
"""
Tuple containing a node id and its input and output shapes
"""
NodeShape = NamedTuple(
"NodeShape",
[
("id", str),
("input_shapes", Union[List[List[int]], None]),
("output_shapes", Union[List[List[int]], None]),
],
)


def get_init_by_name(model: ModelProto, init_name: str) -> Union[TensorProto, None]:
"""
Get an initializer by name from the ONNX model proto graph
:param model: the model proto loaded from the ONNX file
:param init_name: the name of the initializer to retrieve
:return: the initializer retrieved by name from the model
"""
matching_inits = [
init for init in model.graph.initializer if init.name == init_name
]
if len(matching_inits) == 0:
return None
if len(matching_inits) > 1:
raise ValueError(
"found duplicate inits in the ONNX graph for name {} in {}".format(
init_name, model
)
)
return matching_inits[0]
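

# Illustrative sketch (not part of the original API): fetch an initializer by
# name and convert it to a NumPy array, mirroring what the index based helper
# below does for node inputs.
def _example_init_as_array(
    model: ModelProto, init_name: str
) -> Union[numpy.ndarray, None]:
    init = get_init_by_name(model, init_name)
    return numpy_helper.to_array(init) if init is not None else None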


def _get_node_param_init_by_idx(
model: ModelProto, node: onnx.NodeProto, idx: int
) -> Union[numpy.ndarray, None]:
if idx < len(node.input):
initializer = get_init_by_name(model, node.input[idx])
if initializer is not None:
return numpy_helper.to_array(initializer)
return None


def get_attr_float_val_for_node(node: onnx.NodeProto, attr: str) -> Union[float, None]:
"""
:param node: Node to get the attribute value of
:param attr: Attribute name to match in the node
:return: The value of the attribute if the attribute found in the node and is
a float type. Otherwise returns None
"""
attr = attr.lower()
    node_attr_matches = [att for att in node.attribute if attr in att.name.lower()]
if not node_attr_matches:
return None
node_attr = node_attr_matches[0]
return node_attr.f if node_attr.type == node_attr.FLOAT else None
"""
Simple named tuple for mapping a node value to the init name it came from
"""
NodeParam = NamedTuple("NodeParam", [("name", str), ("val", numpy.ndarray)])
_TRIVIAL_OP_TYPES = {"Reshape", "Transpose"}
def _get_init_by_name_nested(
model, weight_name
) -> Tuple[Union[str, None], Union[TensorProto, None]]:
# traverses graph if weights are reshaped / transposed before becoming layer input
init = get_init_by_name(model, weight_name)
if init is not None:
return weight_name, init
nested = None, None
parent_nodes = get_nodes_by_output_id(model, weight_name)
if not parent_nodes:
return nested
for parent in parent_nodes:
if parent.op_type not in _TRIVIAL_OP_TYPES:
continue
nested = _get_init_by_name_nested(model, parent.input[0])
if nested[0] is not None and nested[1] is not None:
break
return nested


def conv_node_params(
model: ModelProto, node: NodeProto, include_values: bool = True
) -> Tuple[NodeParam, Union[NodeParam, None]]:
"""
Get the params (weight and bias) for a conv node in an ONNX ModelProto
:param model: the model proto loaded from the ONNX file
:param node: the conv node to get the params for
:param include_values: True to include the param values as NumPy arrays
in the returned NodeParam objects.
False to not load the values -- in this event NodeParam.val will be None
:return: a tuple containing the weight, bias (if it is present)
"""
node_id = extract_node_id(node)
if str(node.op_type).lower() != "conv":
raise ValueError("node_id of {} is not a conv: {}".format(node_id, node))
weight_name, weight_init = _get_init_by_name_nested(model, node.input[1])
weight = NodeParam(
weight_name, numpy_helper.to_array(weight_init) if include_values else None
)
if len(node.input) > 2:
bias_name, bias_init = _get_init_by_name_nested(model, node.input[2])
bias = NodeParam(
bias_name, numpy_helper.to_array(bias_init) if include_values else None
)
else:
bias = None
return weight, bias


def _get_matmul_gemm_weight(
model: ModelProto, node: NodeProto, include_values: bool = True
) -> NodeParam:
node_id = extract_node_id(node)
if str(node.op_type).lower() not in ["gemm", "matmul"]:
raise ValueError(
"node id of {} is not a gemm or matmul: {}".format(node_id, node)
)
# for gemm, the positions of weights are not explicit in the definition
weight_inits = [
_get_init_by_name_nested(model, node.input[0]),
_get_init_by_name_nested(model, node.input[1]),
]
# putting this here in case it's changed in the future since the else case below
# falls to expecting only 2
assert len(weight_inits) == 2
if weight_inits[0][1] is None and weight_inits[1][1] is None:
raise ValueError(
"could not find weight for gemm / matmul node with id {}: {}".format(
node_id, node
)
)
elif weight_inits[0][1] is not None and weight_inits[1][1] is not None:
raise ValueError(
"found too many weight inputs for gemm / matmul node with id {}: {}".format(
node_id, node
)
)
elif weight_inits[0][1] is not None:
weight = NodeParam(
weight_inits[0][0],
numpy_helper.to_array(weight_inits[0][1]) if include_values else None,
)
else:
weight = NodeParam(
weight_inits[1][0],
numpy_helper.to_array(weight_inits[1][1]) if include_values else None,
)
return weight


def gemm_node_params(
    model: ModelProto, node: NodeProto, include_values: bool = True
) -> Tuple[NodeParam, Union[NodeParam, None]]:
    """
    Get the params (weight and bias) for a gemm node in an ONNX ModelProto
    :param model: the model proto loaded from the ONNX file
    :param node: the gemm node to get the params for
:param include_values: True to include the param values as NumPy arrays
in the returned NodeParam objects.
False to not load the values -- in this event NodeParam.val will be None
:return: a tuple containing the weight, bias (if it is present)
"""
node_id = extract_node_id(node)
if str(node.op_type).lower() != "gemm":
raise ValueError("node_id of {} is not a gemm: {}".format(node_id, node))
weight = _get_matmul_gemm_weight(model, node, include_values)
if len(node.input) > 2:
bias_name, bias_init = _get_init_by_name_nested(model, node.input[2])
bias = NodeParam(
bias_name, numpy_helper.to_array(bias_init) if include_values else None
)
else:
bias = None
return weight, bias


def matmul_node_params(
    model: ModelProto, node: NodeProto, include_values: bool = True
) -> Tuple[NodeParam, Union[NodeParam, None]]:
    """
    Get the params (weight) for a matmul node in an ONNX ModelProto.
    In the future this will retrieve a following bias addition as the bias
    for the matmul.
    :param model: the model proto loaded from the ONNX file
    :param node: the matmul node to get the params for
    :param include_values: True to include the param values as NumPy arrays
        in the returned NodeParam objects.
        False to not load the values -- in this event NodeParam.val will be None
    :return: a tuple containing the weight and the bias (currently always None)
"""
    # TODO: expand this to grab a bias add if one occurs after the matmul for FCs
node_id = extract_node_id(node)
if str(node.op_type).lower() != "matmul":
raise ValueError("node_id of {} is not a matmul: {}".format(node_id, node))
weight = _get_matmul_gemm_weight(model, node, include_values)
bias = None
return weight, bias


def get_node_params(
    model: ModelProto, node: NodeProto, include_values: bool = True
) -> Tuple[NodeParam, Union[NodeParam, None]]:
    """
    Get the params (weight and bias) for a node in an ONNX ModelProto.
    Must have an op type of one of [conv, gemm, matmul]
    :param model: the model proto loaded from the ONNX file
    :param node: the node to get the params for
:param include_values: True to include the param values as NumPy arrays
in the returned NodeParam objects.
False to not load the values -- in this event NodeParam.val will be None
:return: a tuple containing the weight, bias (if it is present)
"""
node_id = extract_node_id(node)
if str(node.op_type).lower() == "conv":
return conv_node_params(model, node, include_values)
if str(node.op_type).lower() == "gemm":
return gemm_node_params(model, node, include_values)
if str(node.op_type).lower() == "matmul":
return matmul_node_params(model, node, include_values)
raise ValueError(
(
"node_id of {} is not a supported node (conv, gemm, matmul) "
"for params: {}"
).format(node_id, node)
)
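

# Illustrative sketch (not part of the original API): tally the weight
# parameter count for every conv, gemm, and matmul node using the dispatch
# above; nodes whose weights cannot be resolved are skipped.
def _example_param_counts(model: ModelProto) -> Dict[str, int]:
    counts = {}
    for node in model.graph.node:
        if str(node.op_type).lower() not in ("conv", "gemm", "matmul"):
            continue
        try:
            weight, _ = get_node_params(model, node)
        except ValueError:
            continue
        counts[extract_node_id(node)] = int(weight.val.size)
    return counts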
"""
Named tuple for defining the paramters of a batch normalization operator
"""
BatchNormParams = NamedTuple(
"BatchNormParams",
[
("epsilon", float),
("momentum", numpy.ndarray),
("scale", numpy.ndarray),
("bias", numpy.ndarray),
("mean", numpy.ndarray),
("var", numpy.ndarray),
],
)


def get_batch_norm_params(
model: onnx.ModelProto, bn_node: onnx.NodeProto
) -> BatchNormParams:
"""
Get the params and relevant attributes of a batch normalization operator.
Following the ONNX operators spec, will default epsilon and momentum to
1e-5 and 0.9 respectively when not defined.
:param model: the model proto loaded from the ONNX file
:param bn_node: the batch normalization node to get the params for
:return: a BatchNormParams named tuple
"""
bn_attributes = get_node_attributes(bn_node)
return BatchNormParams(
epsilon=(bn_attributes["epsilon"] if "epsilon" in bn_attributes else 1e-5),
momentum=(bn_attributes["momentum"] if "momentum" in bn_attributes else 0.9),
scale=_get_node_param_init_by_idx(model, bn_node, 1),
bias=_get_node_param_init_by_idx(model, bn_node, 2),
mean=_get_node_param_init_by_idx(model, bn_node, 3),
var=_get_node_param_init_by_idx(model, bn_node, 4),
)
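

# Illustrative sketch (not part of the original API): the per-channel
# multiplier that folding this batch norm into a preceding conv would apply
# to the conv weights, i.e. scale / sqrt(var + epsilon).
def _example_bn_fold_multiplier(
    model: ModelProto, bn_node: NodeProto
) -> numpy.ndarray:
    params = get_batch_norm_params(model, bn_node)
    return params.scale / numpy.sqrt(params.var + params.epsilon)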


def get_node_attributes(node: NodeProto) -> Dict[str, Any]:
    """
    :param node: the ONNX node to get the attributes for
    :return: a dictionary containing all attributes for the node
    """
    attributes = {
        attribute.name: get_attribute_value(attribute)
        for attribute in node.attribute
    }
    for key, val in attributes.items():
        if not isinstance(val, (int, float, str, list, dict)):
            attributes[key] = None
return attributes


def get_node_outputs(model: ModelProto, node: NodeProto) -> List[str]:
"""
:param model: the model the node is from
:param node: the node to get all outputs (non initializers) for
:return: the names of all the outputs to the node that are not initializers
"""
outputs = []
for out in node.output:
if get_init_by_name(model, out) is None:
outputs.append(out)
return outputs


def get_node_output_nodes(model: ModelProto, node: NodeProto) -> List[NodeProto]:
"""
Get all of the nodes that share an input edge for the outputs from a given node
:param model: the model the node is from
:param node: the node to get all output nodes for
:return: the list of nodes that share an input edge
for the outputs from the given node
"""
nodes = []
for output_id in get_node_outputs(model, node):
nodes.extend(get_nodes_by_input_id(model, output_id))
return nodes


def is_prunable_node(model: ModelProto, node: NodeProto) -> bool:
    """
    :param model: the model the node is from
    :param node: the node to check for prunability
    :return: True if the given node is prunable, False otherwise
    """
prunable_types = ["conv", "gemm", "matmul"]
if str(node.op_type).lower() not in prunable_types:
return False
try:
# try to get the weight param, if this fails then
# it's not a trainable version of the node and therefore not prunable
get_node_params(model, node, include_values=False)
except Exception:
return False
return True


def is_foldable_node(node: Union[str, NodeProto]) -> bool:
    """
    Foldable nodes are those that ONNX Runtime supports folding into adjacent
    layers in the ONNX graph. More info can be found in the ONNX Runtime docs:
    https://www.onnxruntime.ai/docs/resources/graph-optimizations.html
:param node: the node or node type to check if it is foldable or not
according to the ONNX Runtime specs
:return: True if the node is foldable and therefore can be combined with other
nodes, False otherwise
"""
return (node.lower() if isinstance(node, str) else str(node.op_type).lower()) in [
"batchnormalization",
"add",
"mul",
]


def get_prunable_node_from_foldable(
model: ModelProto,
foldable_node: Union[str, NodeProto],
traverse_previous: bool = True,
max_node_distance: int = 3,
) -> Union[None, NodeProto]:
"""
Get a prunable node that is attached by foldable nodes to a given foldable node.
Returns None if nothing could be found.
Ex: get the convolution that would be folded for an attached BatchNormalization
:param model: the model the node is from
:param foldable_node: the foldable node or node id to find prunable node from
:param traverse_previous: True to only search for previous prunable nodes that the
foldable node could have been attached to for Conv -> BN patterns.
False to only search for following prunable nodes that the foldable node
could have been attached to for BN -> Conv patterns.
:param max_node_distance: The maximum distance
(and therefore number of foldable nodes) the prunable node must be within
to match. Ex: max_node_distance = 3, the prunable node must be within 3
other foldable nodes of the foldable node passed in to match
:return: the found prunable node
"""
if isinstance(foldable_node, str):
foldable_node = get_node_by_id(model, foldable_node)
if not is_foldable_node(foldable_node):
raise ValueError(
"non foldable node passed in for foldable_node: {}".format(
extract_node_id(foldable_node)
)
)
prunable_node = foldable_node
num_steps = 0
while (
prunable_node is not None
and not is_prunable_node(model, prunable_node)
and is_foldable_node(prunable_node)
and num_steps < max_node_distance
):
next_nodes = (
get_node_input_nodes(model, prunable_node)
if traverse_previous
else get_node_output_nodes(model, prunable_node)
)
num_steps += 1
prunable_node = next_nodes[0] if next_nodes else None
return (
None
if prunable_node is None or not is_prunable_node(model, prunable_node)
else prunable_node
)
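

# Illustrative sketch (not part of the original API): map each
# BatchNormalization node id to the id of the prunable node it would fold
# into, skipping batch norms with no prunable node within the default
# traversal distance.
def _example_bn_fold_targets(model: ModelProto) -> Dict[str, str]:
    targets = {}
    for node in model.graph.node:
        if str(node.op_type).lower() != "batchnormalization":
            continue
        prunable = get_prunable_node_from_foldable(model, node)
        if prunable is not None:
            targets[extract_node_id(node)] = extract_node_id(prunable)
    return targets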


def get_prunable_nodes(model: Union[str, ModelProto]) -> List[NodeProto]:
"""
Get the prunable nodes in an ONNX model proto.
Prunable nodes are defined as any conv, gemm, or matmul
:param model: the model proto loaded from the ONNX file
:return: a list of nodes from the model proto
"""
model = check_load_model(model)
prunable_nodes = []
for node in model.graph.node:
if is_prunable_node(model, node):
prunable_nodes.append(node)
return prunable_nodes
"""
A measurement of the sparsity for a given ONNX node or model
"""
SparsityMeasurement = NamedTuple(
"SparsityMeasurement",
[
("node_id", str),
("params_count", int),
("params_zero_count", int),
("sparsity", float),
("density", float),
],
)


def onnx_nodes_sparsities(
    model: Union[str, ModelProto],
) -> Tuple[SparsityMeasurement, Dict[str, SparsityMeasurement]]:
    """
    Retrieve the sparsities of the weight inputs for each prunable
    (conv, gemm, matmul) node in an ONNX graph.
    :param model: ONNX model to use
    :return: a tuple containing the overall sparsity measurement for the model
        and a mapping of measurements for each prunable node found in the model
    """
model = check_load_model(model)
node_inp_sparsities = OrderedDict() # type: Dict[str, SparsityMeasurement]
params_count = 0
params_zero_count = 0
for node in get_prunable_nodes(model):
node_id = extract_node_id(node)
node_key = "{}(id={})".format(node.op_type, node_id)
weight, bias = get_node_params(model, node)
zeros = weight.val.size - numpy.count_nonzero(weight.val)
sparsity = float(zeros) / float(weight.val.size)
density = 1.0 - sparsity
node_inp_sparsities[
"{}_inp={}".format(node_key, weight.name)
] = SparsityMeasurement(node_id, weight.val.size, zeros, sparsity, density)
params_count += weight.val.size
params_zero_count += zeros
return (
SparsityMeasurement(
"ModelProto",
params_count,
params_zero_count,
float(params_zero_count) / float(params_count),
float(params_count - params_zero_count) / float(params_count),
),
node_inp_sparsities,
)
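

# Illustrative sketch (not part of the original API): print a per-node
# sparsity report followed by the overall measurement across all prunable
# weights in the model.
def _example_print_sparsity_report(model: Union[str, ModelProto]) -> None:
    overall, per_node = onnx_nodes_sparsities(model)
    for name, measurement in per_node.items():
        print("{}: {:.2%} sparse".format(name, measurement.sparsity))
    print("overall: {:.2%} sparse".format(overall.sparsity))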


def model_outputs(model: Union[str, ModelProto]) -> List:
    """
    Get the outputs from an ONNX model
    :param model: the loaded model or a file path to the ONNX model
        to get the model outputs for
    :return: the outputs from the model
    """
    model = check_load_model(model)
    return list(model.graph.output)


def get_kernel_shape(attributes: Dict[str, Any]) -> Union[List[float], None]:
"""
Get the kernel shape from a dictionary of a model's attributes
:param attributes: a dictionary of a model's attributes
:return: the kernel shape if attribute contains either the kernel or
kernel_shape field, otherwise None
"""
if "kernel" in attributes:
return attributes["kernel"]
elif "kernel_shape" in attributes:
return attributes["kernel_shape"]
else:
return None


def calculate_flops(
op_type: str,
input_shape: Union[List[List], None] = None,
output_shape: Union[List[List], None] = None,
weight_shape: Union[List, None] = None,
kernel_shape: Union[List, None] = None,
bias_shape: Union[List, None] = None,
attributes: Union[None, Dict[str, Any]] = None,
) -> Union[float, None]:
"""
Calculate flops based on operation type and shape of certain attributes.
If any fields necessary in operation are set to None, will return None
:param op_type: Operation type of flop calculation
:param input_shape: List of input shapes of operation
:param output_shape: List of output shapes of operation
:param weight_shape: Shape of weights in operation if any, else None
:param kernel_shape: Shape of kernel in operation if any, else None
:param bias_shape: Shape of bias in operation if any, else None
:param attributes: The node attributes if any, else None
:return: The amount of floating point operations in the operation
"""
input_shape = _array_as_numeric(input_shape)
output_shape = _array_as_numeric(output_shape)
weight_shape = _array_as_numeric(weight_shape)
kernel_shape = _array_as_numeric(kernel_shape)
bias_shape = _array_as_numeric(bias_shape)
    if op_type in ("Add", "Mul", "Div", "Sub", "Clip"):
        flops = _numpy_prod_with_none_check(output_shape)
    elif op_type in ("Relu", "LeakyRelu", "Sigmoid", "Tanh", "BatchNormalization"):
        flops = _numpy_prod_with_none_check(output_shape)
    elif op_type in ("GlobalAveragePool", "GlobalMaxPool"):
        flops = _numpy_prod_with_none_check(input_shape)
    elif op_type in ("MaxPool", "AveragePool"):
        flops = (
            numpy.prod(output_shape) * numpy.prod(kernel_shape)
            if output_shape is not None and kernel_shape is not None
            else None
        )
elif op_type == "MatMul":
flops = _calculate_flops_matmul(
op_type,
input_shape=input_shape,
output_shape=output_shape,
weight_shape=weight_shape,
)
elif op_type == "Gemm":
flops = _numpy_prod_with_none_check(weight_shape)
flops = flops * 2 if flops is not None else None
elif op_type == "Conv":
flops = (
numpy.prod(kernel_shape) * weight_shape[1] * numpy.prod(output_shape)
if kernel_shape is not None
and weight_shape is not None
and output_shape is not None
else None
)
if (
flops
and attributes
and "group" in attributes
and attributes["group"]
and attributes["group"] > 1
):
# adjust flops for group / depthwise convolutions
flops = flops / attributes["group"]
else:
flops = None
if flops is not None and bias_shape is not None:
if op_type == "Conv":
flops += numpy.prod(bias_shape) * output_shape[0][-1] * output_shape[0][-2]
else:
flops += numpy.prod(bias_shape)
return flops
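

# Illustrative sketch with made-up shapes (not from the original module):
# FLOPs for a 3x3 conv from 64 to 128 channels over a 28x28 output. The Conv
# branch above computes prod(kernel) * in_channels * prod(output), then adds
# the per-pixel bias contribution.
def _example_conv_flops() -> Union[float, None]:
    return calculate_flops(
        "Conv",
        output_shape=[[1, 128, 28, 28]],
        weight_shape=[128, 64, 3, 3],
        kernel_shape=[3, 3],
        bias_shape=[128],
    )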


def _calculate_flops_matmul(
op_type: str,
input_shape: Union[List[List], None] = None,
output_shape: Union[List[List], None] = None,
weight_shape: Union[List, None] = None,
) -> Union[float, None]:
"""
Calculates flops in an ONNX MatMul operation.
    If the input shape contains only a single input, the matrix operation is
    treated as a Gemm operation; otherwise it is treated like a batched NumPy
    matmul. Returns None if any required value is None.
:param op_type: Operation type of flop calculation
:param input_shape: List of input shapes of operation
:param output_shape: List of output shapes of operation
:param weight_shape: Shape of weights in operation if any, else None
:return: The amount of floating point operations in the operation
"""
flops = None
if (
input_shape is not None
and output_shape is not None
and len(input_shape) > 1
and input_shape[0][-1] == input_shape[1][-2]
):
matrix_ops = (
input_shape[0][-2] * input_shape[1][-1] * (2 * input_shape[0][-1] - 1)
)
flops = numpy.prod(output_shape[0][:-2]) * matrix_ops
elif input_shape is not None and len(input_shape) == 1:
flops = _numpy_prod_with_none_check(weight_shape)
flops = flops * 2 if flops is not None else None
return flops


def _numpy_prod_with_none_check(array: Union[List, None]) -> Union[float, None]:
"""
:param array: an array like list
:return: the product of the array if array is not None otherwise return None
"""
return numpy.prod(array) if array is not None else None


def _attempt_cast_as_float(value: Any) -> float:
"""
    :param value: a value
    :return: the value as a float if casting is possible, otherwise 1.0
"""
try:
return float(value)
except (ValueError, TypeError):
return 1.0


def _array_as_numeric(array: Union[List, None]) -> Union[List, None]:
"""
:param array: an array like list
:return: the array with any non numeric or None values replaced with 1
if array itself is not None, otherwise return None
"""
if array is None:
return None
array = numpy.array(array, dtype=object)
# Check if the array datatype is a number
if numpy.issubdtype(array.dtype, numpy.number):
return array
else:
to_float = numpy.vectorize(_attempt_cast_as_float)
return to_float(array)


def get_quantize_parent_for_dequantize_node(
quantized_model: ModelProto, dequantize_node: NodeProto
) -> Union[NodeProto, None]:
"""
Returns the first quantize node found by traversing the first node input of the
given de-quantize node's ancestors. All inputs to de-quantize nodes should have
a quantize node ancestor.
:param quantized_model: the model the de-quantize node is from
:param dequantize_node: the node to get an associated quantize node for
:return: the first quantize node found by traversing the first node input of the
given de-quantize node's ancestors. If no quantize node is found, returns None
"""
curr_node = dequantize_node
while curr_node is not None and curr_node.op_type != "QuantizeLinear":
input_nodes = get_node_input_nodes(quantized_model, curr_node)
curr_node = input_nodes[0] if input_nodes else None
return curr_node
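

# Illustrative sketch (not part of the original API): pair each
# DequantizeLinear node id with the id of its nearest QuantizeLinear ancestor,
# when one can be found.
def _example_quantize_dequantize_pairs(model: ModelProto) -> List[Tuple[str, str]]:
    pairs = []
    for node in model.graph.node:
        if node.op_type != "DequantizeLinear":
            continue
        quant = get_quantize_parent_for_dequantize_node(model, node)
        if quant is not None:
            pairs.append((extract_node_id(node), extract_node_id(quant)))
    return pairs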


def get_tensor_dim_shape(tensor: onnx.ValueInfoProto, dim: int) -> int:
"""
:param tensor: ONNX tensor to get the shape of a dimension of
:param dim: dimension index of the tensor to get the shape of
:return: shape of the tensor at the given dimension
"""
return tensor.type.tensor_type.shape.dim[dim].dim_value


def set_tensor_dim_shape(tensor: onnx.ValueInfoProto, dim: int, value: int):
"""
Sets the shape of the tensor at the given dimension to the given value
:param tensor: ONNX tensor to modify the shape of
:param dim: dimension index of the tensor to modify the shape of
:param value: new shape for the given dimension
"""
tensor.type.tensor_type.shape.dim[dim].dim_value = value
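

# Illustrative sketch (not part of the original API): rewrite dimension 0
# (conventionally the batch dimension) of every graph input and output to a
# fixed batch size before re-saving the model.
def _example_override_batch_size(model: ModelProto, batch_size: int) -> None:
    for tensor in list(model.graph.input) + list(model.graph.output):
        set_tensor_dim_shape(tensor, 0, batch_size)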