ONNX Export¶
ONNX is a framework-agnostic representation for neural network graphs; models from most ML frameworks can be converted to it. Some inference engines, such as DeepSparse, natively take in ONNX for deployment pipelines, so convenience functions for conversion and export are provided for the supported frameworks.
Exporting PyTorch to ONNX¶
ONNX export is built into PyTorch natively. The ModuleExporter class under the sparseml.pytorch.utils package features an export_onnx function built on top of this native support.
Example code:
import os
import torch
from sparseml.pytorch.models import mnist_net
from sparseml.pytorch.utils import ModuleExporter
model = mnist_net()
exporter = ModuleExporter(model, output_dir=os.path.join(".", "onnx-export"))
exporter.export_onnx(sample_batch=torch.randn(1, 1, 28, 28))
Exporting Keras to ONNX¶
ONNX export is not built into Keras, but it is supported through the official ONNX tool keras2onnx. The ModelExporter class under the sparseml.keras.utils package features an export_onnx function built on top of keras2onnx.
Example code:
import os
from sparseml.keras.utils import ModelExporter
model = None # fill in with your model
exporter = ModelExporter(model, output_dir=os.path.join(".", "onnx-export"))
exporter.export_onnx()
Exporting TensorFlow V1 to ONNX¶
ONNX export is not built into TensorFlow, but it is supported through the official ONNX tool tf2onnx. The GraphExporter class under the sparseml.tensorflow_v1.utils package features an export_onnx function built on top of tf2onnx. Note that the ONNX file is created from the protobuf graph representation, so export_pb must be called first.
Example code:
import os
from sparseml.tensorflow_v1.utils import tf_compat, GraphExporter
from sparseml.tensorflow_v1.models import mnist_net
exporter = GraphExporter(output_dir=os.path.join(".", "mnist-tf-export"))
with tf_compat.Graph().as_default() as graph:
    inputs = tf_compat.placeholder(
        tf_compat.float32, [None, 28, 28, 1], name="inputs"
    )
    logits = mnist_net(inputs)
    input_names = [inputs.name]
    output_names = [logits.name]

    with tf_compat.Session() as sess:
        sess.run(tf_compat.global_variables_initializer())
        exporter.export_pb(outputs=[logits])
        exporter.export_onnx(inputs=input_names, outputs=output_names)