# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Code related to the TensorFlow V1 model registry for easily creating models.
"""
import re
from typing import Any, Callable, Dict, List, Optional, Union

from sparseml.tensorflow_v1.models.estimator import EstimatorModelFn
from sparseml.tensorflow_v1.utils import tf_compat
from sparseml.utils import TENSORFLOW_V1_FRAMEWORK, parse_optimization_str
from sparsezoo import Zoo
from sparsezoo.objects import Model


__all__ = ["ModelRegistry"]

class _ModelAttributes(object):
def __init__(
self,
input_shape: Any,
domain: str,
sub_domain: str,
architecture: str,
sub_architecture: str,
default_dataset: str,
default_desc: str,
default_model_fn_creator: EstimatorModelFn,
base_name_scope: str,
tl_ignore_tens: List[str],
repo_source: str,
):
self.input_shape = input_shape
self.domain = domain
self.sub_domain = sub_domain
self.architecture = architecture
self.sub_architecture = sub_architecture
self.default_dataset = default_dataset
self.default_desc = default_desc
self.default_model_fn_creator = default_model_fn_creator
self.base_name_scope = base_name_scope
self.tl_ignore_tens = tl_ignore_tens
self.repo_source = repo_source


class ModelRegistry(object):
"""
Registry class for creating models
"""
_CONSTRUCTORS = {} # type: Dict[str, Callable]
_ATTRIBUTES = {} # type: Dict[str, _ModelAttributes]

@staticmethod
def available_keys() -> List[str]:
"""
:return: the keys (models) currently available in the registry
"""
return list(ModelRegistry._CONSTRUCTORS.keys())

@staticmethod
def create(key: str, *args, **kwargs) -> Any:
"""
Create a new model for the given key
:param key: the model key (name) to create
:param args: any args to supply to the graph constructor
:param kwargs: any keyword args to supply to the graph constructor
:return: the outputs from the created graph
"""
if key not in ModelRegistry._CONSTRUCTORS:
raise ValueError(
"key {} is not in the model registry; available: {}".format(
key, ModelRegistry._CONSTRUCTORS
)
)
return ModelRegistry._CONSTRUCTORS[key](*args, **kwargs)
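
# A minimal usage sketch for ModelRegistry.create (hedged: the "resnet50" key and
# the placeholder shape below are illustrative assumptions; use any key returned by
# ModelRegistry.available_keys() and the args that model's constructor accepts):
#
#   print(ModelRegistry.available_keys())
#   inputs = tf_compat.placeholder(tf_compat.float32, [None, 224, 224, 3])
#   logits = ModelRegistry.create("resnet50", inputs)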

@staticmethod
def create_estimator(
key: str,
model_dir: str,
model_fn_params: Optional[Dict[str, Any]],
run_config: tf_compat.estimator.RunConfig,
*args,
**kwargs,
) -> tf_compat.estimator.Estimator:
"""
Create Estimator for a model given the key and extra parameters
:param key: the key that the model was registered with
:param model_dir: directory to save results
:param model_fn_params: parameters for model function
:param run_config: RunConfig used by the estimator during training
:param args: additional positional arguments to pass into model constructor
:param kwargs: additional keyword arguments to pass into model constructor
:return: an Estimator instance
"""
model_const = ModelRegistry._CONSTRUCTORS[key]
attributes = ModelRegistry._ATTRIBUTES[key]
model_fn_creator = attributes.default_model_fn_creator()
model_fn = model_fn_creator.create(model_const, *args, **kwargs)
model_fn_params = {} if model_fn_params is None else model_fn_params
classifier = tf_compat.estimator.Estimator(
config=run_config,
model_dir=model_dir,
model_fn=model_fn,
params=model_fn_params,
)
return classifier
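
# A hedged sketch of wiring a registered model into a tf.estimator workflow; the
# key, model_dir, and RunConfig values are illustrative assumptions only:
#
#   run_config = tf_compat.estimator.RunConfig(save_checkpoints_steps=1000)
#   estimator = ModelRegistry.create_estimator(
#       "resnet50", model_dir="./runs", model_fn_params=None, run_config=run_config
#   )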

@staticmethod
def create_zoo_model(
key: str,
pretrained: Union[bool, str] = True,
pretrained_dataset: Optional[str] = None,
) -> Model:
"""
Create a sparsezoo Model for the desired model in the zoo
:param key: the model key (name) to retrieve
:param pretrained: True to load pretrained weights; to load a specific version
give a string with the name of the version (pruned-moderate, base),
default True
:param pretrained_dataset: The dataset to load for the model
:return: the sparsezoo Model reference for the given model
"""
if key not in ModelRegistry._CONSTRUCTORS:
raise ValueError(
"key {} is not in the model registry; available: {}".format(
key, ModelRegistry._CONSTRUCTORS
)
)
attributes = ModelRegistry._ATTRIBUTES[key]
sparse_name, sparse_category, sparse_target = parse_optimization_str(
pretrained if isinstance(pretrained, str) else attributes.default_desc
)
return Zoo.load_model(
attributes.domain,
attributes.sub_domain,
attributes.architecture,
attributes.sub_architecture,
TENSORFLOW_V1_FRAMEWORK,
attributes.repo_source,
attributes.default_dataset
if pretrained_dataset is None
else pretrained_dataset,
None,
sparse_name,
sparse_category,
sparse_target,
)
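
# A hedged sketch of pulling the SparseZoo reference for a registered key; the key,
# descriptor, and dataset below are illustrative assumptions:
#
#   zoo_model = ModelRegistry.create_zoo_model(
#       "resnet50", pretrained="pruned-moderate", pretrained_dataset="imagenet"
#   )
#   framework_files = zoo_model.download_framework_files()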

@staticmethod
def load_pretrained(
key: str,
pretrained: Union[bool, str] = True,
pretrained_dataset: Optional[str] = None,
pretrained_path: Optional[str] = None,
remove_dynamic_tl_vars: bool = False,
sess: Optional[tf_compat.Session] = None,
saver: Optional[tf_compat.train.Saver] = None,
):
"""
Load pre-trained variables for a given model into a session.
Uses a Saver object from TensorFlow to restore the variables
from an index and data file.
:param key: the model key (name) to create
:param pretrained: True to load the default pretrained variables,
a string to load a specific pretrained graph
(ex: base, optim, optim-perf),
or False to not load any pretrained weights
:param pretrained_dataset: The dataset to load pretrained weights for
(ex: imagenet, mnist, etc).
If not supplied will default to the one preconfigured for the model.
:param pretrained_path: A path to the pretrained variables to load,
if provided will override the pretrained param
:param remove_dynamic_tl_vars: True to remove the vars that are used for
transfer learning (have a different shape and should not be restored),
False to keep all vars in the Saver.
Only used if saver is None
:param sess: The session to load the model variables into
if pretrained_path or pretrained is supplied.
If not supplied and required, then will use the default session
:param saver: The Saver instance to use to restore the variables
for the graph if pretrained_path or pretrained is supplied.
If not supplied and required, then will create one using the
ModelRegistry.saver function
"""
if key not in ModelRegistry._CONSTRUCTORS:
raise ValueError(
"key {} is not in the model registry; available: {}".format(
key, ModelRegistry._CONSTRUCTORS
)
)
if not sess and (pretrained_path or pretrained):
sess = tf_compat.get_default_session()
if not saver and (pretrained_path or pretrained):
saver = ModelRegistry.saver(key, remove_dynamic_tl_vars)
if isinstance(pretrained, str):
if pretrained.lower() == "true":
pretrained = True
elif pretrained.lower() in ["false", "none"]:
pretrained = False
if pretrained_path:
saver.restore(sess, pretrained_path)
elif pretrained:
zoo_model = ModelRegistry.create_zoo_model(
key, pretrained, pretrained_dataset
)
try:
paths = zoo_model.download_framework_files()
index_path = [path for path in paths if path.endswith(".index")]
index_path = index_path[0]
model_path = index_path[:-6]
saver.restore(sess, model_path)
except Exception:
# try one more time with overwrite on in case files were corrupted
paths = zoo_model.download_framework_files(overwrite=True)
index_path = [path for path in paths if path.endswith(".index")]
if len(index_path) != 1:
raise FileNotFoundError(
"could not find .index file for {}".format(zoo_model.root_path)
)
index_path = index_path[0]
model_path = index_path[:-6]
saver.restore(sess, model_path)
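
# A hedged end-to-end sketch: build the graph for a registered key, then restore its
# pretrained variables into a session (the key and input shape are assumptions):
#
#   with tf_compat.Graph().as_default():
#       inputs = tf_compat.placeholder(tf_compat.float32, [None, 224, 224, 3])
#       logits = ModelRegistry.create("resnet50", inputs)
#       with tf_compat.Session() as sess:
#           sess.run(tf_compat.global_variables_initializer())
#           ModelRegistry.load_pretrained("resnet50", pretrained=True, sess=sess)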

@staticmethod
def saver(key: str, remove_dynamic_tl_vars: bool = False) -> tf_compat.train.Saver:
"""
Get a tf compat saver that contains only the variables for the desired
architecture specified by key.
Note, the architecture must have been created in the current graph already
to work.
:param key: the model key (name) to get a saver instance for
:param remove_dynamic_tl_vars: True to remove the vars that are used for
transfer learning (have a different shape and should not be restored),
False to keep all vars in the Saver
:return: a Saver object with the appropriate vars for the model to restore
"""
if key not in ModelRegistry._CONSTRUCTORS:
raise ValueError(
"key {} is not in the model registry; available: {}".format(
key, ModelRegistry._CONSTRUCTORS
)
)
base_name = ModelRegistry._ATTRIBUTES[key].base_name_scope
saver_vars = [
var
for var in tf_compat.get_collection(tf_compat.GraphKeys.TRAINABLE_VARIABLES)
if base_name in var.name
]
saver_vars.extend(
[
var
for var in tf_compat.global_variables()
if ("moving_mean" in var.name or "moving_variance" in var.name)
and base_name in var.name
]
)
if remove_dynamic_tl_vars:
tl_ignore_tens = ModelRegistry._ATTRIBUTES[key].tl_ignore_tens
def _check_ignore(var: tf_compat.Variable) -> bool:
for ignore in tl_ignore_tens:
if re.match(ignore, var.name):
return True
return False
saver_vars = [var for var in saver_vars if not _check_ignore(var)]
saver = tf_compat.train.Saver(saver_vars)
return saver
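
# A hedged sketch of restoring from a local checkpoint with a model-scoped Saver;
# the key and checkpoint path are placeholders, and the graph for the key must
# already exist in the current default graph:
#
#   saver = ModelRegistry.saver("resnet50", remove_dynamic_tl_vars=True)
#   saver.restore(sess, "/path/to/model.ckpt")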

@staticmethod
def register(
key: Union[str, List[str]],
input_shape: Any,
domain: str,
sub_domain: str,
architecture: str,
sub_architecture: str,
default_dataset: str,
default_desc: str,
default_model_fn_creator: EstimatorModelFn,
base_name_scope: str,
tl_ignore_tens: List[str],
repo_source: str = "sparseml",
):
"""
Register a model with the registry. Should be used as a decorator
:param key: the model key (name) to create
:param input_shape: the specified input shape for the model
:param domain: the domain the model belongs to; ex: cv, nlp, etc
:param sub_domain: the sub domain the model belongs to;
ex: classification, detection, etc
:param architecture: the architecture the model belongs to;
ex: resnet, mobilenet, etc
:param sub_architecture: the sub architecture the model belongs to;
ex: 50, 101, etc
:param default_dataset: the dataset to use by default for loading
pretrained if not supplied
:param default_desc: the description to use by default for loading
pretrained if not supplied
:param default_model_fn_creator: default model creator to use when creating
estimator instance
:param base_name_scope: the base string used to create the graph under
:param tl_ignore_tens: a list of tensors to ignore restoring for
if transfer learning
:param repo_source: the source repo for the model, default is sparseml
:return: the decorator
"""
if not isinstance(key, list):
key = [key]
def decorator(const_func):
for r_key in key:
if r_key in ModelRegistry._CONSTRUCTORS:
raise ValueError("key {} is already registered".format(key))
ModelRegistry._CONSTRUCTORS[r_key] = const_func
ModelRegistry._ATTRIBUTES[r_key] = _ModelAttributes(
input_shape,
domain,
sub_domain,
architecture,
sub_architecture,
default_dataset,
default_desc,
default_model_fn_creator,
base_name_scope,
tl_ignore_tens,
repo_source,
)
return const_func
return decorator
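

# A hedged sketch of registering a new constructor via the decorator; every name
# below (the key, shapes, scope, ignore patterns, and the ClassificationEstimatorModelFn
# creator) is an illustrative assumption, not something defined in this module:
#
#   @ModelRegistry.register(
#       key=["examplenet"],
#       input_shape=(224, 224, 3),
#       domain="cv",
#       sub_domain="classification",
#       architecture="examplenet",
#       sub_architecture="base",
#       default_dataset="imagenet",
#       default_desc="base",
#       default_model_fn_creator=ClassificationEstimatorModelFn,
#       base_name_scope="examplenet",
#       tl_ignore_tens=[".+/classes/.+"],
#   )
#   def examplenet(inputs, training=True, num_classes=1000):
#       ...  # build the graph and return its output tensor(s)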