You can export a model to the ONNX format for use with DeepSparse.
ONNX is a generic representation of neural network graphs, and models from most ML frameworks can be converted to it. Some inference engines, such as DeepSparse, take ONNX natively as input for deployment pipelines, so convenience functions for conversion and export are provided for each supported framework.
See the SparseML installation page for installation requirements of each integration.
PyTorch supports ONNX export natively through `torch.onnx`. The `ModuleExporter` class in the `sparseml.pytorch.utils` package provides an `export_onnx` function built on this native support. For example:
```python
import os
import torch
from sparseml.pytorch.models import mnist_net
from sparseml.pytorch.utils import ModuleExporter

model = mnist_net()
exporter = ModuleExporter(model, output_dir=os.path.join(".", "onnx-export"))
# sample_batch is a dummy input with the model's expected shape (batch, channels, height, width)
exporter.export_onnx(sample_batch=torch.randn(1, 1, 28, 28))
```
Keras does not support ONNX natively, but export is available through keras2onnx, an official ONNX tool. The `ModelExporter` class in the `sparseml.keras.utils` package provides an `export_onnx` function built on top of keras2onnx. For example:
```python
import os
from sparseml.keras.utils import ModelExporter

model = None  # fill in with your model
exporter = ModelExporter(model, output_dir=os.path.join(".", "onnx-export"))
exporter.export_onnx()
```
TensorFlow also does not support ONNX natively, but export is available through tf2onnx, an official ONNX tool. The `GraphExporter` class in the `sparseml.tensorflow_v1.utils` package provides an `export_onnx` function built on top of tf2onnx.

Note that the ONNX file is created from the protobuf graph representation, so `export_pb` must be called first. For example:
```python
import os
from sparseml.tensorflow_v1.utils import tf_compat, GraphExporter
from sparseml.tensorflow_v1.models import mnist_net

exporter = GraphExporter(output_dir=os.path.join(".", "mnist-tf-export"))

with tf_compat.Graph().as_default() as graph:
    inputs = tf_compat.placeholder(
        tf_compat.float32, [None, 28, 28, 1], name="inputs"
    )
    logits = mnist_net(inputs)
    input_names = [inputs.name]
    output_names = [logits.name]

    with tf_compat.Session() as sess:
        sess.run(tf_compat.global_variables_initializer())
        # Export the protobuf graph first; the ONNX export is created from it
        exporter.export_pb(outputs=[logits])

# Convert the exported protobuf graph to ONNX
exporter.export_onnx(inputs=input_names, outputs=output_names)
```