# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Helper functions to edit ONNX Graphs.
"""

from collections import defaultdict
from typing import Iterable, List, Optional, Union

import numpy
import onnx
from onnx import ModelProto, NodeProto, TensorProto, numpy_helper
from toposort import toposort_flatten

from sparseml.onnx.utils.helpers import get_node_params


__all__ = [
    "ONNXGraph",
    "update_model_param",
    "swap_node_output",
    "remove_node_and_params_from_graph",
    "override_model_batch_size",
    "prune_unstructured",
    "prune_model_one_shot",
    "prune_model_one_shot_iter",
]


class ONNXGraph(object):
    """
    Class for quick look-up of ONNX graph nodes and initializers. If graph state
    changes outside of ONNXGraph class functions, update() should be called.

    :param model: the ONNX graph to represent
    """

    def __init__(self, model: ModelProto):
        self._model = model
        self._output_id_to_node = {}
        self._input_id_to_nodes = defaultdict(list)
        self._name_to_initializer = {}

        self.update()

    @property
    def nodes(self) -> Iterable[NodeProto]:
        """
        :return: ordered collection of nodes in this graph
        """
        return self._model.graph.node

    def update(self, model: Optional[ModelProto] = None):
        """
        Update the graph state based on the model this graph represents or
        the given model

        :param model: model to represent. defaults to current loaded model state
        """
        self._model = model or self._model

        # nodes
        self._output_id_to_node = {}
        self._input_id_to_nodes = defaultdict(list)
        for node in self._model.graph.node:
            self._store_node_edges(node)

        # initializers
        self._name_to_initializer = {
            init.name: init for init in self._model.graph.initializer
        }

    def get_init_by_name(
        self,
        name: str,
        allow_optional: bool = True,
    ) -> Optional[TensorProto]:
        """
        :param name: name of initializer
        :param allow_optional: if True and the given name is not found as an
            initializer, None will be returned. Otherwise a KeyError will be raised
        :return: tensor of initializer with given name, returns None if the name
            does not exist in the cached graph
        """
        init = self._name_to_initializer.get(name, None)
        if not allow_optional and init is None:
            raise KeyError(f"Unable to find initializer {name} in ONNX model")
        return init

    def get_node_by_output_id(self, id: str) -> Optional[NodeProto]:
        """
        :param id: name of output id of node
        :return: the associated node if it is present in the graph, None otherwise
        """
        return self._output_id_to_node.get(id)

    def get_node_parents(
        self, node: NodeProto
    ) -> List[Union[NodeProto, TensorProto, None]]:
        """
        :param node: node to get the input objects for
        :return: input nodes or tensors of this node in order. if an input does not
            exist, None will be returned in its place
        """
        inputs = []
        for input_id in node.input:
            inp = None
            if input_id in self._output_id_to_node:
                inp = self._output_id_to_node[input_id]
            elif input_id in self._name_to_initializer:
                inp = self._name_to_initializer[input_id]
            inputs.append(inp)
        return inputs

    def get_node_single_parent(
        self, node: NodeProto, index: int
    ) -> Union[NodeProto, None]:
        """
        :param node: the node to get the parent node of
        :param index: choose which input to search
        :return: the parent node that produces the input at the given index, or
            None if that input is not produced by a node
        """
        input_id = node.input[index]
        if input_id not in self._output_id_to_node:
            return None
        return self._output_id_to_node[input_id]

    def get_node_children(self, node: NodeProto) -> List[NodeProto]:
        """
        :param node: the node to get the child nodes of
        :return: list of nodes that take any of this node's outputs as an input
        """
        children = []
        for output_id in node.output:
            children.extend(self._input_id_to_nodes[output_id])
        return children

    def get_node_single_child(self, node: NodeProto) -> Union[NodeProto, None]:
        """
        :param node: the node to get the child node of
        :return: child of node if it only has one child, otherwise None
        """
        children = self.get_node_children(node)
        return children[0] if len(children) == 1 else None

    def add_node(self, node: NodeProto):
        """
        Adds the given node to the model and graph state

        :param node: node to add to the model
        """
        self._model.graph.node.append(node)
        self._store_node_edges(node)

    def update_node_input(
        self, node: NodeProto, input_id: str, input_idx: Optional[int] = None
    ):
        """
        :param node: node to update the inputs of
        :param input_id: new input_id to attach to the node
        :param input_idx: optional index of the node input list to update,
            if none is given, the new input id will be appended to the input list
        """
        if input_idx is not None:
            if node in self._input_id_to_nodes[node.input[input_idx]]:
                self._input_id_to_nodes[node.input[input_idx]].remove(node)
            node.input[input_idx] = input_id
        else:
            node.input.append(input_id)
        self._input_id_to_nodes[input_id].append(node)

    def delete_node(self, node: NodeProto):
        """
        deletes the given node from the graph

        :param node: node to delete
        """
        self._model.graph.node.remove(node)
        self._delete_node_edges(node)

    def delete_nodes(self, nodes: List[NodeProto]):
        """
        deletes the given nodes from the graph

        :param nodes: list of nodes to delete
        """
        node_output_ids_to_delete = {node.output[0] for node in nodes}
        nodes_to_keep = []
        for node in self._model.graph.node:
            if node.output[0] in node_output_ids_to_delete:
                self._delete_node_edges(node)
            else:
                nodes_to_keep.append(node)
        self._model.graph.ClearField("node")
        self._model.graph.node.extend(nodes_to_keep)

    def delete_initializers(self, initializers: List[Union[str, TensorProto]]):
        """
        deletes the given initializers from the model

        :param initializers: list of initializers or initializer names to delete
        """
        inits_to_delete = {
            init if isinstance(init, str) else init.name for init in initializers
        }
        inits_to_keep = []
        for init in self._model.graph.initializer:
            if init.name in inits_to_delete:
                # keep edge reference if nodes in the graph still point to the
                # initializer name
                if not self._input_id_to_nodes[init.name]:
                    del self._input_id_to_nodes[init.name]
                del self._name_to_initializer[init.name]
            else:
                inits_to_keep.append(init)
        self._model.graph.ClearField("initializer")
        self._model.graph.initializer.extend(inits_to_keep)

    def delete_unused_initializers(self):
        """
        deletes tensors in the initializer list that are not listed as inputs to any
        node in the current graph state or directly passed as model outputs
        """
        output_names = {out.name for out in self._model.graph.output}
        # delete inits that have no edge
        self.delete_initializers(
            [
                init
                for init in self._model.graph.initializer
                if not self._input_id_to_nodes[init.name]
                and (init.name not in output_names)
            ]
        )

    def sort_nodes_topologically(self):
        """
        Sorts the order of the graph Node repeated field in place in topological
        order as per the ONNX Model proto specifications
        """
        # build toposort DAG input and sort
        model_dag = defaultdict(set)  # node_id -> dependencies
        for parent_node_id, child_nodes in self._input_id_to_nodes.items():
            if parent_node_id not in self._output_id_to_node:
                continue  # parent is an initializer, not a node
            # standardize all references to nodes by their first output id
            parent_node_id = self._output_id_to_node[parent_node_id].output[0]
            for child_node in child_nodes:
                model_dag[child_node.output[0]].add(parent_node_id)
        sorted_node_ids = toposort_flatten(model_dag)

        # deduplicate any nodes from the sorted list
        updated_node_list = []
        seen_ids = set()
        for node_id in sorted_node_ids:
            if node_id in seen_ids:
                continue
            # a node could have multiple ids, all ids will be updated
            node = self._output_id_to_node[node_id]
            updated_node_list.append(node)
            seen_ids.update(node.output)

        # update model node list with topo sorted list
        assert len(updated_node_list) == len(self._model.graph.node)
        self._model.graph.ClearField("node")
        self._model.graph.node.extend(updated_node_list)

    def _store_node_edges(self, node: NodeProto):
        for output_id in node.output:
            self._output_id_to_node[output_id] = node
        for input_id in node.input:
            self._input_id_to_nodes[input_id].append(node)

    def _delete_node_edges(self, node: NodeProto):
        # remove node edges from cache
        for output_id in node.output:
            del self._output_id_to_node[output_id]
        for input_id in node.input:
            self._input_id_to_nodes[input_id].remove(node)
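

# Example usage of ONNXGraph (a minimal sketch; "model.onnx" is a hypothetical
# path to an ONNX file on disk):
#
#     model = onnx.load("model.onnx")
#     graph = ONNXGraph(model)
#     node = graph.nodes[0]
#     parents = graph.get_node_parents(node)  # NodeProtos, TensorProtos, or None
#     children = graph.get_node_children(node)
#     graph.sort_nodes_topologically()  # enforce topological node ordering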


def update_model_param(
    model: ModelProto,
    param_name: str,
    val: numpy.ndarray,
) -> None:
    """
    Removes the parameter with name param_name from the model, then creates a
    new parameter from val and adds it to the model under the same name

    :param model: The model to update
    :param param_name: The parameter name in the model to update
    :param val: The new value of the parameter
    """
    param_matches = [
        param for param in model.graph.initializer if param.name == param_name
    ]
    if param_matches:
        model.graph.initializer.remove(param_matches[0])
    new_param = numpy_helper.from_array(val, param_name)
    model.graph.initializer.append(new_param)
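

# Example (a minimal sketch; "conv1.weight" is a hypothetical initializer name):
#
#     new_val = numpy.zeros((64, 3, 3, 3), dtype=numpy.float32)
#     update_model_param(model, "conv1.weight", new_val)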


def swap_node_output(node: onnx.NodeProto, output: str) -> None:
    """
    Deletes the current output of the node and replaces it with the provided value.
    Assumes that the node only has one output

    :param node: Node to change the output of
    :param output: New output value
    """
    node.output.pop()
    node.output.append(output)
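

# Example (a minimal sketch; assumes the node has exactly one output):
#
#     node = model.graph.node[0]
#     swap_node_output(node, "new_output_name")
#     assert node.output[0] == "new_output_name"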


def remove_node_and_params_from_graph(
    model: ModelProto,
    node: onnx.NodeProto,
    keep_params: Iterable[str] = None,
) -> None:
    """
    Deletes a node from the model graph as well as its parameters listed in node.input

    :param model: Model to delete from
    :param node: Node to delete
    :param keep_params: Names of node input initializers not to remove from the
        graph, default is None
    """
    keep_params = keep_params or []
    # collect matches first to avoid mutating the repeated field while iterating it
    params_to_remove = [
        param
        for param in model.graph.initializer
        if param.name not in keep_params and param.name in node.input
    ]
    for param in params_to_remove:
        model.graph.initializer.remove(param)
    model.graph.node.remove(node)
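

# Example (a minimal sketch; "shared_weight" is a hypothetical initializer that
# other nodes still reference and should therefore be kept):
#
#     node = model.graph.node[0]
#     remove_node_and_params_from_graph(model, node, keep_params=["shared_weight"])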


def _override_tensor_batch_dim(model, tensor, batch_size):
    for init in model.graph.initializer:
        if init.name == tensor.name:
            # This tensor is actually an initializer => skip
            return

    shape = tensor.type.tensor_type.shape

    # skip tensors with variable batch sizes
    if not shape.dim[0].dim_param and shape.dim[0].dim_value > 0:
        shape.dim[0].dim_value = batch_size


def override_model_batch_size(model: ModelProto, batch_size: int) -> ModelProto:
    """
    Rewrites any positive batch dimensions in the model inputs or outputs to the
    given batch_size

    :param model: Model to modify
    :param batch_size: Batch size to enforce
    :return: the given model with inputs and outputs set to batch_size if the batch
        dimensions are not -1
    """
    # This may not work for ONNX graphs that have hard-coded reshape nodes
    for tensor in model.graph.input:
        _override_tensor_batch_dim(model, tensor, batch_size)

    # Do the same for outputs
    for tensor in model.graph.output:
        # Ignore augmented _Reduce nodes
        if "_Reduce" not in tensor.name:
            _override_tensor_batch_dim(model, tensor, batch_size)

    return model
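

# Example (a minimal sketch; "model.onnx" is a hypothetical path):
#
#     model = onnx.load("model.onnx")
#     override_model_batch_size(model, 1)  # fix every static batch dim to 1
#     onnx.save(model, "model-bs1.onnx")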


def prune_unstructured(array: numpy.ndarray, sparsity: float) -> numpy.ndarray:
    """
    Prune a numpy array with unstructured sparsity according to magnitude pruning

    :param array: the array to prune (introduce zeros), will remove the lowest
        absolute values in the array
    :param sparsity: the sparsity value, as a decimal, to impose in the array
    :return: the pruned array
    """
    array = numpy.array(array)  # make a copy because arrays from onnx are read only
    sparse_index = int(round(sparsity * array.size) - 1)

    if sparse_index < 0:
        return array

    sorted_array = numpy.sort(numpy.abs(array.flatten()))
    sparse_thresh = sorted_array[sparse_index]
    array[numpy.abs(array) < sparse_thresh] = 0

    return array
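

# Example (a minimal sketch): prune a random weight matrix to roughly 80% sparsity
#
#     weights = numpy.random.randn(64, 64).astype(numpy.float32)
#     pruned = prune_unstructured(weights, 0.8)
#     measured = float((pruned == 0).sum()) / pruned.size  # close to 0.8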


def prune_model_one_shot(
    model: ModelProto, nodes: List[NodeProto], sparsity: Union[float, List[float]]
):
    """
    Prune a model in-place with one shot pruning (no retraining) according to
    magnitude pruning. Does so in an unstructured way currently

    :param model: the model to apply pruning to
    :param nodes: the nodes within the model to prune to the desired sparsities
    :param sparsity: the sparsity level to prune all nodes to if a float, or the
        sparsity level to prune each node to if a list of floats
    """
    if not isinstance(sparsity, Iterable):
        tmp = float(sparsity)
        sparsity = [tmp for _ in range(len(nodes))]

    if len(nodes) != len(sparsity):
        raise ValueError(
            "len(nodes) {} does not match len(sparsity) {}".format(
                len(nodes), len(sparsity)
            )
        )

    for node, node_sparsity in zip(nodes, sparsity):
        weight, bias = get_node_params(model, node)
        pruned_weight_val = prune_unstructured(weight.val, node_sparsity)
        update_model_param(model, weight.name, pruned_weight_val)
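

# Example (a minimal sketch; prunes every Conv and Gemm node to 90% sparsity):
#
#     prunable = [
#         node for node in model.graph.node if node.op_type in ("Conv", "Gemm")
#     ]
#     prune_model_one_shot(model, prunable, 0.9)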


def prune_model_one_shot_iter(
    model: ModelProto, nodes: List[NodeProto], sparsity: Union[float, List[float]]
):
    """
    Iteratively prune a model in-place with one shot pruning (no retraining)
    according to magnitude pruning. Does so in an unstructured way currently

    :param model: the model to apply pruning to
    :param nodes: the nodes within the model to prune to the desired sparsities
    :param sparsity: the sparsity level to prune all nodes to if a float, or the
        sparsity level to prune each node to if a list of floats
    """
    if not isinstance(sparsity, Iterable):
        tmp = float(sparsity)
        sparsity = [tmp for _ in range(len(nodes))]

    if len(nodes) != len(sparsity):
        raise ValueError(
            "len(nodes) {} does not match len(sparsity) {}".format(
                len(nodes), len(sparsity)
            )
        )

    for index, (node, node_sparsity) in enumerate(zip(nodes, sparsity)):
        weight, bias = get_node_params(model, node)
        pruned_weight_val = prune_unstructured(weight.val, node_sparsity)
        update_model_param(model, weight.name, pruned_weight_val)
        yield (index + 1) / len(nodes)
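

# Example (a minimal sketch; same pruning as prune_model_one_shot, but yields
# fractional progress after each node for reporting; reuses the hypothetical
# "prunable" node list from the example above):
#
#     for progress in prune_model_one_shot_iter(model, prunable, 0.9):
#         print("pruning progress: {:.0%}".format(progress))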