# Source code for sparseml.pytorch.models.registry

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Code related to the PyTorch model registry for easily creating models.
"""

from typing import Any, Callable, Dict, List, NamedTuple, Tuple, Union

from torch.nn import Module

from merge_args import merge_args
from sparseml.pytorch.utils import load_model
from sparseml.utils import parse_optimization_str, wrapper_decorator
from sparseml.utils.frameworks import PYTORCH_FRAMEWORK
from sparsezoo import Zoo
from sparsezoo.objects import Model


# Public API of this module: only the registry class is exported.
__all__ = ["ModelRegistry"]


# Simple named tuple storing the metadata a model was registered with:
#   input_shape          - the specified input shape for the model
#   domain / sub_domain  - e.g. cv / classification
#   architecture / sub_architecture - e.g. resnet / 50
#   default_dataset      - dataset used for pretrained weights when none is given
#   default_desc         - description used for pretrained weights when none is given
#   repo_source          - the source repo for the model (e.g. sparseml)
#   ignore_error_tensors - tensors to ignore if there are errors while loading
#   args                 - description -> (constructor kwarg name, value) overrides
class _ModelAttributes(NamedTuple):
    input_shape: Any
    domain: str
    sub_domain: str
    architecture: str
    sub_architecture: str
    default_dataset: str
    default_desc: str
    repo_source: str
    ignore_error_tensors: List[str]
    args: Dict[str, Tuple[str, Any]]


class ModelRegistry(object):
    """
    Registry class for creating models.

    Model constructors are registered under one or more string keys (normally
    through the ``register`` decorator) together with metadata used to locate
    their pretrained weights in the SparseZoo. Models are then created by key
    with ``create``, optionally loading pretrained weights.
    """

    _CONSTRUCTORS = {}  # type: Dict[str, Callable]
    _ATTRIBUTES = {}  # type: Dict[str, _ModelAttributes]

    @staticmethod
    def available_keys() -> List[str]:
        """
        :return: the keys (models) currently available in the registry
        """
        return list(ModelRegistry._CONSTRUCTORS.keys())

    @staticmethod
    def _validate_key(key: str):
        """
        Raise if the given key has not been registered.

        :param key: the model key (name) to check
        :raises ValueError: if the key is not in the registry; the message lists
            the currently available keys
        """
        if key not in ModelRegistry._CONSTRUCTORS:
            raise ValueError(
                "key {} is not in the model registry; available: {}".format(
                    key, ModelRegistry.available_keys()
                )
            )

    @staticmethod
    def create(
        key: str,
        pretrained: Union[bool, str] = False,
        pretrained_path: str = None,
        pretrained_dataset: str = None,
        load_strict: bool = True,
        ignore_error_tensors: List[str] = None,
        **kwargs,
    ) -> Module:
        """
        Create a new model for the given key

        :param key: the model key (name) to create
        :param pretrained: True to load pretrained weights; to load a specific
            version give a string with the name of the version
            (pruned-moderate, base). Default False
        :param pretrained_path: A model file path to load into the created model
        :param pretrained_dataset: The dataset to load for the model
        :param load_strict: True to make sure all states are found in and
            loaded in model, False otherwise; default True
        :param ignore_error_tensors: tensors to ignore if there are errors in loading
        :param kwargs: any keyword args to supply to the model constructor
        :return: the instantiated model
        :raises ValueError: if the key is not in the registry
        """
        ModelRegistry._validate_key(key)
        constructor = ModelRegistry._CONSTRUCTORS[key]

        return constructor(
            pretrained=pretrained,
            pretrained_path=pretrained_path,
            pretrained_dataset=pretrained_dataset,
            load_strict=load_strict,
            ignore_error_tensors=ignore_error_tensors,
            **kwargs,
        )

    @staticmethod
    def create_zoo_model(
        key: str,
        pretrained: Union[bool, str] = True,
        pretrained_dataset: str = None,
    ) -> Model:
        """
        Create a sparsezoo Model for the desired model in the zoo

        :param key: the model key (name) to retrieve
        :param pretrained: True to load pretrained weights for the model's
            registered default description; to load a specific version give a
            string with the name of the version (optim, optim-perf), default True
        :param pretrained_dataset: The dataset to load for the model; if None the
            model's registered default dataset is used
        :return: the sparsezoo Model reference for the given model
        :raises ValueError: if the key is not in the registry
        """
        ModelRegistry._validate_key(key)
        attributes = ModelRegistry._ATTRIBUTES[key]

        # any non-string value (e.g. True) falls back to the registered description
        description = (
            pretrained if isinstance(pretrained, str) else attributes.default_desc
        )
        sparse_name, sparse_category, sparse_target = parse_optimization_str(
            description
        )
        dataset = (
            attributes.default_dataset
            if pretrained_dataset is None
            else pretrained_dataset
        )

        return Zoo.load_model(
            attributes.domain,
            attributes.sub_domain,
            attributes.architecture,
            attributes.sub_architecture,
            PYTORCH_FRAMEWORK,
            attributes.repo_source,
            dataset,
            # NOTE(review): positional None presumably selects the default
            # release version; confirm against sparsezoo's Zoo.load_model
            None,
            sparse_name,
            sparse_category,
            sparse_target,
        )

    @staticmethod
    def input_shape(key: str) -> Any:
        """
        :param key: the model key (name) to look up
        :return: the specified input shape for the model
        :raises ValueError: if the key is not in the registry
        """
        ModelRegistry._validate_key(key)

        return ModelRegistry._ATTRIBUTES[key].input_shape

    @staticmethod
    def register(
        key: Union[str, List[str]],
        input_shape: Any,
        domain: str,
        sub_domain: str,
        architecture: str,
        sub_architecture: str,
        default_dataset: str,
        default_desc: str,
        repo_source: str = "sparseml",
        def_ignore_error_tensors: List[str] = None,
        desc_args: Dict[str, Tuple[str, Any]] = None,
    ):
        """
        Register a model with the registry. Should be used as a decorator

        :param key: the model key (name) to create, or a list/tuple of keys
            that should all map to the same model
        :param input_shape: the specified input shape for the model
        :param domain: the domain the model belongs to; ex: cv, nlp, etc
        :param sub_domain: the sub domain the model belongs to;
            ex: classification, detection, etc
        :param architecture: the architecture the model belongs to;
            ex: resnet, mobilenet, etc
        :param sub_architecture: the sub architecture the model belongs to;
            ex: 50, 101, etc
        :param default_dataset: the dataset to use by default for loading
            pretrained if not supplied
        :param default_desc: the description to use by default for loading
            pretrained if not supplied
        :param repo_source: the source repo for the model, default is sparseml
        :param def_ignore_error_tensors: tensors to ignore if there are
            errors in loading
        :param desc_args: args that should be changed based on the description
        :return: the decorator
        :raises ValueError: if an empty list of keys is given
        """
        keys = list(key) if isinstance(key, (list, tuple)) else [key]

        if not keys:
            raise ValueError("at least one key is required to register a model")

        def decorator(const_func):
            # every key is registered with identical attributes, so the wrapper
            # only needs the first key to look them up at call time
            wrapped_constructor = ModelRegistry._registered_wrapper(
                keys[0], const_func
            )
            ModelRegistry.register_wrapped_model_constructor(
                wrapped_constructor,
                keys,
                input_shape,
                domain,
                sub_domain,
                architecture,
                sub_architecture,
                default_dataset,
                default_desc,
                repo_source,
                def_ignore_error_tensors,
                desc_args,
            )

            return wrapped_constructor

        return decorator

    @staticmethod
    def register_wrapped_model_constructor(
        wrapped_constructor: Callable,
        key: Union[str, List[str]],
        input_shape: Any,
        domain: str,
        sub_domain: str,
        architecture: str,
        sub_architecture: str,
        default_dataset: str,
        default_desc: str,
        repo_source: str,
        def_ignore_error_tensors: List[str] = None,
        desc_args: Dict[str, Tuple[str, Any]] = None,
    ):
        """
        Register a model with the registry from a model constructor or provider function

        :param wrapped_constructor: Model constructor wrapped to be compatible
            by call from ModelRegistry.create should have pretrained,
            pretrained_path, pretrained_dataset, load_strict,
            ignore_error_tensors, and kwargs as arguments
        :param key: the model key (name) to create, or a list/tuple of keys
        :param input_shape: the specified input shape for the model
        :param domain: the domain the model belongs to; ex: cv, nlp, etc
        :param sub_domain: the sub domain the model belongs to;
            ex: classification, detection, etc
        :param architecture: the architecture the model belongs to;
            ex: resnet, mobilenet, etc
        :param sub_architecture: the sub architecture the model belongs to;
            ex: 50, 101, etc
        :param default_dataset: the dataset to use by default for loading
            pretrained if not supplied
        :param default_desc: the description to use by default for loading
            pretrained if not supplied
        :param repo_source: the source repo for the model;
            ex: sparseml, torchvision
        :param def_ignore_error_tensors: tensors to ignore if there are
            errors in loading
        :param desc_args: args that should be changed based on the description
        :return: The constructor wrapper registered with the registry
        :raises ValueError: if any key is already registered or repeated; in
            that case nothing is registered
        """
        keys = list(key) if isinstance(key, (list, tuple)) else [key]

        # validate every key before touching the registry so a duplicate does
        # not leave the model partially registered under some of its keys
        seen = set()
        for r_key in keys:
            if r_key in ModelRegistry._CONSTRUCTORS or r_key in seen:
                raise ValueError("key {} is already registered".format(r_key))
            seen.add(r_key)

        attributes = _ModelAttributes(
            input_shape,
            domain,
            sub_domain,
            architecture,
            sub_architecture,
            default_dataset,
            default_desc,
            repo_source,
            def_ignore_error_tensors,
            desc_args,
        )

        for r_key in keys:
            ModelRegistry._CONSTRUCTORS[r_key] = wrapped_constructor
            ModelRegistry._ATTRIBUTES[r_key] = attributes

        return wrapped_constructor

    @staticmethod
    def _registered_wrapper(
        key: str,
        const_func: Callable,
    ):
        """
        Wrap a raw model constructor so it accepts the registry's pretrained
        loading arguments in addition to its own arguments.

        :param key: the registry key whose attributes the wrapper reads at call time
        :param const_func: the raw model constructor to wrap
        :return: the wrapped constructor
        """

        @merge_args(const_func)
        @wrapper_decorator(const_func)
        def wrapper(
            pretrained_path: str = None,
            pretrained: Union[bool, str] = False,
            pretrained_dataset: str = None,
            load_strict: bool = True,
            ignore_error_tensors: List[str] = None,
            *args,
            **kwargs,
        ):
            """
            :param pretrained_path: A path to the pretrained weights to load,
                if provided will override the pretrained param. May also be
                a SparseZoo stub path preceded by 'zoo:' with the optional
                `?recipe_type=` argument. If given a recipe type, the base
                model weights for that recipe will be loaded
            :param pretrained: True to load the default pretrained weights,
                a string to load a specific pretrained weight
                (ex: base, optim, optim-perf),
                or False to not load any pretrained weights
            :param pretrained_dataset: The dataset to load pretrained weights for
                (ex: imagenet, mnist, etc).
                If not supplied will default to the one preconfigured for the model.
            :param load_strict: True to raise an error on issues with state dict
                loading from pretrained_path or pretrained, False to ignore
            :param ignore_error_tensors: Tensors to ignore while checking the state dict
                for weights loaded from pretrained_path or pretrained
            """
            attributes = ModelRegistry._ATTRIBUTES[key]

            # some descriptions need a different constructor kwarg
            # (desc_args maps description -> (kwarg name, kwarg value))
            if attributes.args and pretrained in attributes.args:
                desc_arg = attributes.args[pretrained]
                kwargs[desc_arg[0]] = desc_arg[1]

            model = const_func(*args, **kwargs)

            # explicitly passed tensors take precedence over registered defaults;
            # always build a fresh list so neither source list is aliased
            ignore = list(ignore_error_tensors or attributes.ignore_error_tensors or [])

            # normalize string booleans ("true" / "false" / "none", any case)
            if isinstance(pretrained, str):
                if pretrained.lower() == "true":
                    pretrained = True
                elif pretrained.lower() in ["false", "none"]:
                    pretrained = False

            if pretrained_path:
                load_model(pretrained_path, model, load_strict, ignore)
            elif pretrained:
                zoo_model = ModelRegistry.create_zoo_model(
                    key, pretrained, pretrained_dataset
                )
                try:
                    paths = zoo_model.download_framework_files(extensions=[".pth"])
                    load_model(paths[0], model, load_strict, ignore)
                except Exception:
                    # try one more time with overwrite on in case file was corrupted
                    paths = zoo_model.download_framework_files(
                        overwrite=True, extensions=[".pth"]
                    )
                    load_model(paths[0], model, load_strict, ignore)

            return model

        return wrapper