Exporting to the ONNX Format

You can export a model to the ONNX format for use with DeepSparse.

ONNX is a generic representation for neural network graphs, and most ML frameworks can convert their models to it. Some inference engines, such as DeepSparse, take ONNX natively in their deployment pipelines, so convenience functions for conversion and export are provided for each supported framework.

Installation Requirements

See the SparseML installation page for installation requirements of each integration.

Exporting PyTorch to ONNX

PyTorch supports ONNX export natively. The ModuleExporter class in the sparseml.pytorch.utils package provides an export_onnx method built on this native support. The sample_batch argument is an example input with the shape the model expects (here a single 1x28x28 MNIST image). For example:

import os
import torch
from sparseml.pytorch.models import mnist_net
from sparseml.pytorch.utils import ModuleExporter

model = mnist_net()
exporter = ModuleExporter(model, output_dir=os.path.join(".", "onnx-export"))
exporter.export_onnx(sample_batch=torch.randn(1, 1, 28, 28))
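
After an export, it can be useful to confirm that an ONNX file actually landed in the output directory before handing it to DeepSparse. The following is a minimal sketch using only the Python standard library. The helper name find_exported_onnx is hypothetical and not part of SparseML, and the sketch assumes nothing about the exported file name beyond its .onnx extension.

```python
import os
from typing import List


def find_exported_onnx(output_dir: str) -> List[str]:
    """Return the sorted paths of all .onnx files found under output_dir."""
    matches = []
    # os.walk yields nothing for a missing directory, so the result is just [].
    for root, _dirs, files in os.walk(output_dir):
        for file_name in files:
            if file_name.lower().endswith(".onnx"):
                matches.append(os.path.join(root, file_name))
    return sorted(matches)


if __name__ == "__main__":
    # Same directory as the export example; adjust to your own output_dir.
    export_dir = os.path.join(".", "onnx-export")
    paths = find_exported_onnx(export_dir)
    if paths:
        for path in paths:
            print(f"{path}: {os.path.getsize(path)} bytes")
    else:
        print(f"no .onnx files found under {export_dir}")
```

The same check applies unchanged to the Keras and TensorFlow V1 exports, since it only inspects the directory passed as output_dir.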

Exporting Keras to ONNX

ONNX export is not built into Keras, but it is supported through keras2onnx, an official ONNX tool. The ModelExporter class in the sparseml.keras.utils package provides an export_onnx method built on top of keras2onnx. For example:

import os
from sparseml.keras.utils import ModelExporter

model = None  # fill in with your model
exporter = ModelExporter(model, output_dir=os.path.join(".", "onnx-export"))
exporter.export_onnx()

Exporting TensorFlow V1 to ONNX

ONNX export is not built into TensorFlow, but it is supported through tf2onnx, an official ONNX tool. The GraphExporter class in the sparseml.tensorflow_v1.utils package provides an export_onnx method built on top of tf2onnx. The ONNX file is created from the protobuf graph representation, so export_pb must be called before export_onnx. For example:

import os
from sparseml.tensorflow_v1.utils import tf_compat, GraphExporter
from sparseml.tensorflow_v1.models import mnist_net

exporter = GraphExporter(output_dir=os.path.join(".", "mnist-tf-export"))

with tf_compat.Graph().as_default() as graph:
    inputs = tf_compat.placeholder(
        tf_compat.float32, [None, 28, 28, 1], name="inputs"
    )
    logits = mnist_net(inputs)
    input_names = [inputs.name]
    output_names = [logits.name]

    with tf_compat.Session() as sess:
        sess.run(tf_compat.global_variables_initializer())
        exporter.export_pb(outputs=[logits])

exporter.export_onnx(inputs=input_names, outputs=output_names)
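
In the example above, input_names and output_names are built from the .name attribute of tensors. In TensorFlow V1 that name has the form op_name:output_index, so the placeholder named "inputs" yields "inputs:0". If only op names are at hand, a small helper can normalize them to the same form. This is a minimal sketch using only the standard library. Both function names are hypothetical and not part of SparseML or tf2onnx, and defaulting to output index 0 is an assumption that holds for single-output ops.

```python
from typing import Tuple


def split_tensor_name(tensor_name: str) -> Tuple[str, int]:
    """Split a name such as "inputs:0" into (op_name, output_index)."""
    op_name, sep, index = tensor_name.rpartition(":")
    if not sep or not index.isdigit():
        # No ":<index>" suffix: treat the whole string as an op name
        # and assume its first output (index 0).
        return tensor_name, 0
    return op_name, int(index)


def to_tensor_name(name: str) -> str:
    """Return name in op_name:output_index form, adding ":0" if missing."""
    op_name, index = split_tensor_name(name)
    return f"{op_name}:{index}"


if __name__ == "__main__":
    for raw in ["inputs", "inputs:0", "scope/logits:1"]:
        print(raw, "->", to_tensor_name(raw))
```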