Source code for sparseml.tensorflow_v1.models.registry

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Code related to the TensorFlow V1 model registry for easily creating models.
"""

import re
from typing import Any, Callable, Dict, List, Optional, Union

from sparseml.tensorflow_v1.models.estimator import EstimatorModelFn
from sparseml.tensorflow_v1.utils import tf_compat
from sparseml.utils import TENSORFLOW_V1_FRAMEWORK, parse_optimization_str
from sparsezoo import Zoo
from sparsezoo.objects import Model


__all__ = ["ModelRegistry"]


class _ModelAttributes(object):
    """
    Plain record of the metadata stored alongside a registered model
    constructor. Every constructor argument is kept, unchanged, as an
    attribute of the same name.

    :param input_shape: the specified input shape for the model
    :param domain: the domain the model belongs to
    :param sub_domain: the sub domain the model belongs to
    :param architecture: the architecture the model belongs to
    :param sub_architecture: the sub architecture the model belongs to
    :param default_dataset: the dataset to use by default for loading
        pretrained if not supplied
    :param default_desc: the description to use by default for loading
        pretrained if not supplied
    :param default_model_fn_creator: default model creator to use when
        creating an estimator instance
    :param base_name_scope: the base string used to create the graph under
    :param tl_ignore_tens: a list of tensors to ignore restoring for
        if transfer learning
    :param repo_source: the source repo for the model
    """

    def __init__(
        self,
        input_shape: Any,
        domain: str,
        sub_domain: str,
        architecture: str,
        sub_architecture: str,
        default_dataset: str,
        default_desc: str,
        default_model_fn_creator: EstimatorModelFn,
        base_name_scope: str,
        tl_ignore_tens: List[str],
        repo_source: str,
    ):
        self.input_shape = input_shape

        # coordinates used to look the model up in the zoo
        self.domain, self.sub_domain = domain, sub_domain
        self.architecture, self.sub_architecture = architecture, sub_architecture

        # fallbacks applied when the caller does not pick a pretrained variant
        self.default_dataset, self.default_desc = default_dataset, default_desc
        self.default_model_fn_creator = default_model_fn_creator

        # graph naming information consumed when building a Saver
        self.base_name_scope, self.tl_ignore_tens = base_name_scope, tl_ignore_tens

        self.repo_source = repo_source


class ModelRegistry(object):
    """
    Registry class for creating models.

    Constructors and their attributes are kept in class-level dictionaries
    keyed by model name, so every user of the class sees the same registrations.
    """

    _CONSTRUCTORS = {}  # type: Dict[str, Callable]
    _ATTRIBUTES = {}  # type: Dict[str, _ModelAttributes]

    @staticmethod
    def _check_key(key: str) -> None:
        """
        Validate that a model key has been registered.

        :param key: the model key (name) to validate
        :raises ValueError: if the key is not in the registry
        """
        if key not in ModelRegistry._CONSTRUCTORS:
            # report the registered names only; formatting the dict itself
            # would dump the repr of every constructor function
            raise ValueError(
                "key {} is not in the model registry; available: {}".format(
                    key, ModelRegistry.available_keys()
                )
            )

    @staticmethod
    def available_keys() -> List[str]:
        """
        :return: the keys (models) currently available in the registry
        """
        return list(ModelRegistry._CONSTRUCTORS.keys())

    @staticmethod
    def create(key: str, *args, **kwargs) -> Any:
        """
        Create a new model for the given key

        :param key: the model key (name) to create
        :param args: any args to supply to the graph constructor
        :param kwargs: any keyword args to supply to the graph constructor
        :return: the outputs from the created graph
        :raises ValueError: if the key is not in the registry
        """
        ModelRegistry._check_key(key)

        return ModelRegistry._CONSTRUCTORS[key](*args, **kwargs)

    @staticmethod
    def create_estimator(
        key: str,
        model_dir: str,
        model_fn_params: Optional[Dict[str, Any]],
        run_config: tf_compat.estimator.RunConfig,
        *args,
        **kwargs,
    ) -> tf_compat.estimator.Estimator:
        """
        Create Estimator for a model given the key and extra parameters

        :param key: the key that the model was registered with
        :param model_dir: directory to save results
        :param model_fn_params: parameters for model function;
            None is treated as an empty dictionary
        :param run_config: RunConfig used by the estimator during training
        :param args: additional positional arguments to pass into model constructor
        :param kwargs: additional keyword arguments to pass into model constructor
        :return: an Estimator instance
        :raises KeyError: if the key is not in the registry (unlike the other
            lookups in this class, which raise ValueError)
        """
        model_const = ModelRegistry._CONSTRUCTORS[key]
        attributes = ModelRegistry._ATTRIBUTES[key]
        # the registered creator is a class/factory; instantiate it first
        model_fn_creator = attributes.default_model_fn_creator()
        model_fn = model_fn_creator.create(model_const, *args, **kwargs)
        model_fn_params = {} if model_fn_params is None else model_fn_params

        return tf_compat.estimator.Estimator(
            config=run_config,
            model_dir=model_dir,
            model_fn=model_fn,
            params=model_fn_params,
        )

    @staticmethod
    def create_zoo_model(
        key: str,
        pretrained: Union[bool, str] = True,
        pretrained_dataset: Optional[str] = None,
    ) -> Model:
        """
        Create a sparsezoo Model for the desired model in the zoo

        :param key: the model key (name) to retrieve
        :param pretrained: True to load pretrained weights; to load a specific version
            give a string with the name of the version (pruned-moderate, base),
            default True
        :param pretrained_dataset: The dataset to load for the model.
            If not supplied will default to the one preconfigured for the model.
        :return: the sparsezoo Model reference for the given model
        :raises ValueError: if the key is not in the registry
        """
        ModelRegistry._check_key(key)

        attributes = ModelRegistry._ATTRIBUTES[key]
        # any non-string value (e.g. True) falls back to the registered default
        sparse_name, sparse_category, sparse_target = parse_optimization_str(
            pretrained if isinstance(pretrained, str) else attributes.default_desc
        )
        dataset = (
            attributes.default_dataset
            if pretrained_dataset is None
            else pretrained_dataset
        )

        return Zoo.load_model(
            attributes.domain,
            attributes.sub_domain,
            attributes.architecture,
            attributes.sub_architecture,
            TENSORFLOW_V1_FRAMEWORK,
            attributes.repo_source,
            dataset,
            # NOTE(review): positional slot deliberately left as None;
            # confirm its meaning against the sparsezoo Zoo.load_model signature
            None,
            sparse_name,
            sparse_category,
            sparse_target,
        )

    @staticmethod
    def load_pretrained(
        key: str,
        pretrained: Union[bool, str] = True,
        pretrained_dataset: Optional[str] = None,
        pretrained_path: Optional[str] = None,
        remove_dynamic_tl_vars: bool = False,
        sess: Optional[tf_compat.Session] = None,
        saver: Optional[tf_compat.train.Saver] = None,
    ):
        """
        Load pre-trained variables for a given model into a session.
        Uses a Saver object from TensorFlow to restore the variables
        from an index and data file.

        :param key: the model key (name) to create
        :param pretrained: True to load the default pretrained variables,
            a string to load a specific pretrained graph
            (ex: base, optim, optim-perf),
            or False to not load any pretrained weights.
            The strings "true", "false" and "none" (any casing) are treated
            as the corresponding boolean
        :param pretrained_dataset: The dataset to load pretrained weights for
            (ex: imagenet, mnist, etc).
            If not supplied will default to the one preconfigured for the model.
        :param pretrained_path: A path to the pretrained variables to load,
            if provided will override the pretrained param
        :param remove_dynamic_tl_vars: True to remove the vars that are used for
            transfer learning (have a different shape and should not be restored),
            False to keep all vars in the Saver.
            Only used if saver is None
        :param sess: The session to load the model variables into
            if pretrained_path or pretrained is supplied.
            If not supplied and required, then will use the default session
        :param saver: The Saver instance to use to restore the variables
            for the graph if pretrained_path or pretrained is supplied.
            If not supplied and required, then will create one using the
            ModelRegistry.saver function
        :raises ValueError: if the key is not in the registry
        :raises FileNotFoundError: if, after re-downloading, exactly one
            .index file cannot be found for the zoo model
        """
        ModelRegistry._check_key(key)

        # Normalize string booleans BEFORE deciding whether a session and saver
        # are needed; otherwise "false"/"none" (truthy strings) would still
        # trigger session lookup and Saver construction.
        if isinstance(pretrained, str):
            lowered = pretrained.lower()
            if lowered == "true":
                pretrained = True
            elif lowered in ("false", "none"):
                pretrained = False

        if not pretrained_path and not pretrained:
            # nothing requested, so there is nothing to restore
            return

        if sess is None:
            sess = tf_compat.get_default_session()

        if saver is None:
            saver = ModelRegistry.saver(key, remove_dynamic_tl_vars)

        if pretrained_path:
            # an explicit path always wins over the pretrained param
            saver.restore(sess, pretrained_path)
            return

        zoo_model = ModelRegistry.create_zoo_model(key, pretrained, pretrained_dataset)
        index_suffix = ".index"

        try:
            paths = zoo_model.download_framework_files()
            index_paths = [path for path in paths if path.endswith(index_suffix)]
            # a missing index file raises IndexError here and falls through
            # to the retry below; the checkpoint prefix is the index path
            # without its suffix
            saver.restore(sess, index_paths[0][: -len(index_suffix)])
        except Exception:
            # deliberate best-effort retry:
            # try one more time with overwrite on in case files were corrupted
            paths = zoo_model.download_framework_files(overwrite=True)
            index_paths = [path for path in paths if path.endswith(index_suffix)]

            if len(index_paths) != 1:
                raise FileNotFoundError(
                    "could not find .index file for {}".format(zoo_model.root_path)
                )

            saver.restore(sess, index_paths[0][: -len(index_suffix)])

    @staticmethod
    def input_shape(key: str):
        """
        :param key: the model key (name) to create
        :return: the specified input shape for the model
        :raises ValueError: if the key is not in the registry
        """
        ModelRegistry._check_key(key)

        return ModelRegistry._ATTRIBUTES[key].input_shape

    @staticmethod
    def saver(key: str, remove_dynamic_tl_vars: bool = False) -> tf_compat.train.Saver:
        """
        Get a tf compat saver that contains only the variables for the desired
        architecture specified by key.
        Note, the architecture must have been created in the current graph already
        to work.

        :param key: the model key (name) to get a saver instance for
        :param remove_dynamic_tl_vars: True to remove the vars that are used for
            transfer learning (have a different shape and should not be restored),
            False to keep all vars in the Saver
        :return: a Saver object with the appropriate vars for the model to restore
        :raises ValueError: if the key is not in the registry
        """
        ModelRegistry._check_key(key)

        attributes = ModelRegistry._ATTRIBUTES[key]
        base_name = attributes.base_name_scope

        saver_vars = [
            var
            for var in tf_compat.get_collection(
                tf_compat.GraphKeys.TRAINABLE_VARIABLES
            )
            if base_name in var.name
        ]
        # moving statistics are picked up from the global variables separately
        # (presumably batch-norm running stats that are not trainable -- confirm)
        saver_vars.extend(
            var
            for var in tf_compat.global_variables()
            if ("moving_mean" in var.name or "moving_variance" in var.name)
            and base_name in var.name
        )

        if remove_dynamic_tl_vars:
            # each entry is a regex matched against the start of the var name
            ignore_patterns = attributes.tl_ignore_tens
            saver_vars = [
                var
                for var in saver_vars
                if not any(re.match(pattern, var.name) for pattern in ignore_patterns)
            ]

        return tf_compat.train.Saver(saver_vars)

    @staticmethod
    def register(
        key: Union[str, List[str]],
        input_shape: Any,
        domain: str,
        sub_domain: str,
        architecture: str,
        sub_architecture: str,
        default_dataset: str,
        default_desc: str,
        default_model_fn_creator: EstimatorModelFn,
        base_name_scope: str,
        tl_ignore_tens: List[str],
        repo_source: str = "sparseml",
    ):
        """
        Register a model with the registry. Should be used as a decorator

        :param key: the model key (name) to create, or a list of keys
            to register the same constructor under
        :param input_shape: the specified input shape for the model
        :param domain: the domain the model belongs to; ex: cv, nlp, etc
        :param sub_domain: the sub domain the model belongs to;
            ex: classification, detection, etc
        :param architecture: the architecture the model belongs to;
            ex: resnet, mobilenet, etc
        :param sub_architecture: the sub architecture the model belongs to;
            ex: 50, 101, etc
        :param default_dataset: the dataset to use by default for loading
            pretrained if not supplied
        :param default_desc: the description to use by default for loading
            pretrained if not supplied
        :param default_model_fn_creator: default model creator to use when creating
            estimator instance
        :param base_name_scope: the base string used to create the graph under
        :param tl_ignore_tens: a list of tensors to ignore restoring for
            if transfer learning
        :param repo_source: the source repo for the model, default is sparseml
        :return: the decorator
        :raises ValueError: (when the decorator is applied) if one of the keys
            is already registered
        """
        keys = key if isinstance(key, list) else [key]

        def decorator(const_func):
            for r_key in keys:
                if r_key in ModelRegistry._CONSTRUCTORS:
                    # name the offending key, not the whole list of keys
                    raise ValueError("key {} is already registered".format(r_key))

                ModelRegistry._CONSTRUCTORS[r_key] = const_func
                ModelRegistry._ATTRIBUTES[r_key] = _ModelAttributes(
                    input_shape,
                    domain,
                    sub_domain,
                    architecture,
                    sub_architecture,
                    default_dataset,
                    default_desc,
                    default_model_fn_creator,
                    base_name_scope,
                    tl_ignore_tens,
                    repo_source,
                )

            # hand the constructor back untouched so it stays directly callable
            return const_func

        return decorator