# Source code for sparseml.onnx.framework.info

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Functionality related to detecting and getting information for
support and sparsification in the ONNX/ONNXRuntime framework.
"""

import logging
from typing import Any

from sparseml.base import Framework, get_version
from sparseml.framework import FrameworkInferenceProviderInfo, FrameworkInfo
from sparseml.onnx.base import check_onnx_install, check_onnxruntime_install
from sparseml.onnx.sparsification import sparsification_info
from sparseml.sparsification import SparsificationInfo


# Public API of this module: the names exported on `from ... import *`.
__all__ = ["is_supported", "detect_framework", "framework_info"]


# Module-level logger; emits the debug messages in detect_framework.
_LOGGER = logging.getLogger(__name__)


def is_supported(item: Any) -> bool:
    """
    Check whether the given item is handled by the onnx/onnxruntime framework.

    The item is classified with detect_framework and is considered supported
    only when that classification is Framework.onnx.

    :param item: The item to detect the support for
    :type item: Any
    :return: True if the item is supported by onnx/onnxruntime, False otherwise
    :rtype: bool
    """
    return detect_framework(item) == Framework.onnx
def detect_framework(item: Any) -> Framework:
    """
    Detect the supported ML framework for a given item specifically for the
    onnx/onnxruntime package.
    Supported input types are the following:
        - A Framework enum (returned as-is)
        - A string of any case naming a Framework member
          (ex: deepsparse, onnx, keras, pytorch, tensorflow_v1)
        - A string path or url containing ".onnx", or any other string
          containing the text "onnx" (ex: "onnxruntime")
        - An onnx ModelProto instance (only checked when onnx is installed)
    If the framework cannot be determined, will return Framework.unknown

    :param item: The item to detect the ML framework for
    :type item: Any
    :return: The detected framework from the given item
    :rtype: Framework
    """
    framework = Framework.unknown
    # normalize string inputs once; None signals "not a string"
    normalized = item.lower().strip() if isinstance(item, str) else None

    if isinstance(item, Framework):
        _LOGGER.debug("framework detected from Framework instance")
        framework = item
    elif normalized is not None and normalized in Framework.__members__:
        _LOGGER.debug("framework detected from Framework string instance")
        framework = Framework[normalized]
    elif normalized is not None and ".onnx" in normalized:
        # checked before the generic "onnx" text match below; otherwise this
        # more specific file path / url branch could never be reached
        _LOGGER.debug("framework detected from .onnx")
        framework = Framework.onnx
    elif normalized is not None and "onnx" in normalized:
        # any other string mentioning onnx
        _LOGGER.debug("framework detected from onnx text")
        framework = Framework.onnx
    elif check_onnx_install(raise_on_error=False):
        # lazy import: onnx is an optional dependency
        from onnx import ModelProto

        if isinstance(item, ModelProto):
            _LOGGER.debug("framework detected from ONNX instance")
            framework = Framework.onnx

    return framework
def _onnxruntime_provider_info(
    name: str,
    description: str,
    device: str,
    execution_provider: str,
    runtime_available: bool,
    available_providers: list,
) -> FrameworkInferenceProviderInfo:
    """
    Build the inference provider info for a single ONNXRuntime execution
    provider.

    :param name: The short name of the provider (ex: cpu, cuda)
    :param description: A human readable description of the provider
    :param device: The device the provider runs on (ex: cpu, gpu)
    :param execution_provider: The ONNXRuntime execution provider name to look
        for in available_providers (ex: CPUExecutionProvider)
    :param runtime_available: True if both onnx and onnxruntime are installed
    :param available_providers: The providers reported as available by
        onnxruntime (empty when onnxruntime is not installed)
    :return: The provider info; available only when the runtime is installed
        and the execution provider is reported as available
    """
    return FrameworkInferenceProviderInfo(
        name=name,
        description=description,
        device=device,
        supported_sparsification=SparsificationInfo(),  # TODO: fill in when available
        available=runtime_available and execution_provider in available_providers,
        properties={},
        warnings=[],
    )


def framework_info() -> FrameworkInfo:
    """
    Detect the information for the onnx/onnxruntime framework such as package
    versions, availability for core actions such as training and inference,
    sparsification support, and inference provider support.

    :return: The framework info for onnx/onnxruntime
    :rtype: FrameworkInfo
    """
    # evaluate each install check once and reuse the results below
    onnxruntime_installed = check_onnxruntime_install(raise_on_error=False)
    onnx_installed = check_onnx_install(raise_on_error=False)
    runtime_available = onnx_installed and onnxruntime_installed

    all_providers = []
    available_providers = []

    if onnxruntime_installed:
        # lazy import: onnxruntime is an optional dependency
        from onnxruntime import get_all_providers, get_available_providers

        available_providers = get_available_providers()
        all_providers = get_all_providers()

    cpu_provider = _onnxruntime_provider_info(
        name="cpu",
        description="Base CPU provider within ONNXRuntime",
        device="cpu",
        execution_provider="CPUExecutionProvider",
        runtime_available=runtime_available,
        available_providers=available_providers,
    )
    gpu_provider = _onnxruntime_provider_info(
        name="cuda",
        description="Base GPU CUDA provider within ONNXRuntime",
        device="gpu",
        execution_provider="CUDAExecutionProvider",
        runtime_available=runtime_available,
        available_providers=available_providers,
    )

    return FrameworkInfo(
        framework=Framework.onnx,
        package_versions={
            "onnx": get_version(package_name="onnx", raise_on_error=False),
            "onnxruntime": (
                get_version(package_name="onnxruntime", raise_on_error=False)
            ),
            "sparsezoo": get_version(
                package_name="sparsezoo",
                raise_on_error=False,
                alternate_package_names=["sparsezoo-nightly"],
            ),
            "sparseml": get_version(
                package_name="sparseml",
                raise_on_error=False,
                alternate_package_names=["sparseml-nightly"],
            ),
        },
        sparsification=sparsification_info(),
        inference_providers=[cpu_provider, gpu_provider],
        properties={
            "available_providers": available_providers,
            "all_providers": all_providers,
        },
        training_available=False,
        sparsification_available=True,
        exporting_onnx_available=True,
        inference_available=True,
    )