sparseml.pytorch.models package

Submodules

sparseml.pytorch.models.registry module

Code related to the PyTorch model registry for easily creating models.

class sparseml.pytorch.models.registry.ModelRegistry

Bases: object

Registry class for creating models

static available_keys() → List[str]
Returns

the keys (models) currently available in the registry
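A minimal sketch of how a name-keyed registry can back available_keys(), assuming a plain dict from key to constructor. ToyRegistry, its add method, and the keys used are hypothetical stand-ins, not sparseml's implementation. The toy's rejection of duplicate keys is a design choice of this sketch only.

```python
from typing import Callable, Dict, List


class ToyRegistry:
    """Hypothetical stand-in: maps each model key to its constructor."""

    _constructors: Dict[str, Callable[[], object]] = {}

    @staticmethod
    def add(key: str, constructor: Callable[[], object]) -> None:
        # The toy refuses to silently replace an existing key
        if key in ToyRegistry._constructors:
            raise ValueError(f"key {key} is already registered")
        ToyRegistry._constructors[key] = constructor

    @staticmethod
    def available_keys() -> List[str]:
        # Keys (models) currently available, in registration order
        return list(ToyRegistry._constructors)


ToyRegistry.add("resnet50", lambda: "resnet50 model")
ToyRegistry.add("mobilenet", lambda: "mobilenet model")

# Typical use: validate a key before trying to create the model
key = "resnet50"
if key not in ToyRegistry.available_keys():
    raise ValueError(f"unknown model key: {key}")
print(ToyRegistry.available_keys())  # ['resnet50', 'mobilenet']
```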

static create(key: str, pretrained: Union[bool, str] = False, pretrained_path: Optional[str] = None, pretrained_dataset: Optional[str] = None, load_strict: bool = True, ignore_error_tensors: Optional[List[str]] = None, **kwargs) → torch.nn.modules.module.Module

Create a new model for the given key

Parameters
  • key – the model key (name) to create

  • pretrained – True to load pretrained weights; to load a specific version, give a string with the name of the version (pruned-moderate, base). Default False

  • pretrained_path – A model file path to load into the created model

  • pretrained_dataset – The dataset to load for the model

  • load_strict – True to require that every state is found in and loaded into the model, False otherwise; default True

  • ignore_error_tensors – tensors to ignore if there are errors in loading

  • kwargs – any keyword args to supply to the model constructor

Returns

the instantiated model
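The dispatch that create performs can be sketched as: look up the key, fail on unknown keys, then forward the loading options and any extra keyword arguments to the registered constructor. This is a toy under stated assumptions; _TOY_CONSTRUCTORS, _toy_resnet, toy_create, and the num_classes argument are all hypothetical, and the stand-in "model" is a dict rather than a torch Module, so no weights are actually loaded.

```python
from typing import Any, Callable, Dict, List, Optional, Union

# Hypothetical registry contents: key -> wrapped constructor
_TOY_CONSTRUCTORS: Dict[str, Callable[..., Any]] = {}


def _toy_resnet(pretrained: Union[bool, str] = False,
                pretrained_path: Optional[str] = None,
                pretrained_dataset: Optional[str] = None,
                load_strict: bool = True,
                ignore_error_tensors: Optional[List[str]] = None,
                **kwargs: Any) -> Dict[str, Any]:
    # Stand-in "model" that only records how it was built (no weights are loaded)
    return {"arch": "resnet", "pretrained": pretrained, "kwargs": kwargs}


_TOY_CONSTRUCTORS["resnet50"] = _toy_resnet


def toy_create(key: str, pretrained: Union[bool, str] = False,
               pretrained_path: Optional[str] = None,
               pretrained_dataset: Optional[str] = None,
               load_strict: bool = True,
               ignore_error_tensors: Optional[List[str]] = None,
               **kwargs: Any) -> Any:
    if key not in _TOY_CONSTRUCTORS:
        raise ValueError(f"key {key} is not in the model registry")
    # Forward the loading options plus any extra constructor kwargs unchanged
    return _TOY_CONSTRUCTORS[key](
        pretrained=pretrained,
        pretrained_path=pretrained_path,
        pretrained_dataset=pretrained_dataset,
        load_strict=load_strict,
        ignore_error_tensors=ignore_error_tensors,
        **kwargs,
    )


model = toy_create("resnet50", pretrained="base", num_classes=10)
print(model)  # {'arch': 'resnet', 'pretrained': 'base', 'kwargs': {'num_classes': 10}}
```

Here num_classes travels through **kwargs to the constructor, which is the route the kwargs parameter above describes.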

static create_zoo_model(key: str, pretrained: Union[bool, str] = True, pretrained_dataset: Optional[str] = None) → sparsezoo.objects.model.Model

Create a sparsezoo Model for the desired model in the zoo

Parameters
  • key – the model key (name) to retrieve

  • pretrained – True to load pretrained weights; to load a specific version give a string with the name of the version (optim, optim-perf), default True

  • pretrained_dataset – The dataset to load for the model

Returns

the sparsezoo Model reference for the given model
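A sketch of how the pretrained and pretrained_dataset arguments could resolve against the defaults a model was registered with, assuming a string names a version, any other value falls back to default_desc, and a missing dataset falls back to default_dataset. _TOY_DEFAULTS and toy_resolve_pretrained are hypothetical; the sketch does not contact sparsezoo or build a real zoo reference.

```python
from typing import Dict, Optional, Union

# Hypothetical metadata, mirroring the default_dataset/default_desc given to register
_TOY_DEFAULTS: Dict[str, Dict[str, str]] = {
    "resnet50": {"default_dataset": "imagenet", "default_desc": "base"},
}


def toy_resolve_pretrained(key: str, pretrained: Union[bool, str] = True,
                           pretrained_dataset: Optional[str] = None) -> Dict[str, str]:
    defaults = _TOY_DEFAULTS[key]
    # A string names a specific version; otherwise use the registered default version
    desc = pretrained if isinstance(pretrained, str) else defaults["default_desc"]
    # Fall back to the registered default dataset when none is supplied
    if pretrained_dataset is not None:
        dataset = pretrained_dataset
    else:
        dataset = defaults["default_dataset"]
    return {"dataset": dataset, "desc": desc}


print(toy_resolve_pretrained("resnet50"))
# {'dataset': 'imagenet', 'desc': 'base'}
print(toy_resolve_pretrained("resnet50", pretrained="pruned-moderate"))
# {'dataset': 'imagenet', 'desc': 'pruned-moderate'}
```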

static input_shape(key: str) → Any
Parameters

key – the model key (name) to look up

Returns

the specified input shape for the model
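One common use of a registered input shape is sizing a dummy batch, for example before tracing or exporting a model. The sketch below assumes the shape is a single (channels, height, width) tuple; since the return type is Any, a given model may store something else. _TOY_INPUT_SHAPES, toy_input_shape, and batched_shape are hypothetical.

```python
from typing import Any, Dict, Tuple

# Hypothetical input shapes (channels, height, width) as they might be passed to register
_TOY_INPUT_SHAPES: Dict[str, Any] = {"resnet50": (3, 224, 224)}


def toy_input_shape(key: str) -> Any:
    return _TOY_INPUT_SHAPES[key]


def batched_shape(key: str, batch_size: int) -> Tuple[int, ...]:
    # Prepend a batch dimension, e.g. to size a dummy input for tracing or export
    return (batch_size, *toy_input_shape(key))


print(batched_shape("resnet50", 8))  # (8, 3, 224, 224)
```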

static register(key: Union[str, List[str]], input_shape: Any, domain: str, sub_domain: str, architecture: str, sub_architecture: str, default_dataset: str, default_desc: str, repo_source: str = 'sparseml', def_ignore_error_tensors: Optional[List[str]] = None, desc_args: Optional[Dict[str, Tuple[str, Any]]] = None)

Register a model with the registry. Intended to be used as a decorator on the model constructor.

Parameters
  • key – the model key (name), or list of keys, to register

  • input_shape – the specified input shape for the model

  • domain – the domain the model belongs to; ex: cv, nlp, etc

  • sub_domain – the sub domain the model belongs to; ex: classification, detection, etc

  • architecture – the architecture the model belongs to; ex: resnet, mobilenet, etc

  • sub_architecture – the sub architecture the model belongs to; ex: 50, 101, etc

  • default_dataset – the dataset to use by default for loading pretrained if not supplied

  • default_desc – the description to use by default for loading pretrained if not supplied

  • repo_source – the source repo for the model, default is sparseml

  • def_ignore_error_tensors – tensors to ignore if there are errors in loading

  • desc_args – args that should be changed based on the description

Returns

the decorator

static register_wrapped_model_constructor(wrapped_constructor: Callable, key: Union[str, List[str]], input_shape: Any, domain: str, sub_domain: str, architecture: str, sub_architecture: str, default_dataset: str, default_desc: str, repo_source: str, def_ignore_error_tensors: Optional[List[str]] = None, desc_args: Optional[Dict[str, Tuple[str, Any]]] = None)

Register a model with the registry from a model constructor or provider function

Parameters
  • wrapped_constructor – the model constructor, wrapped so that it can be called from ModelRegistry.create; it should accept pretrained, pretrained_path, pretrained_dataset, load_strict, ignore_error_tensors, and kwargs as arguments

  • key – the model key (name), or list of keys, to register

  • input_shape – the specified input shape for the model

  • domain – the domain the model belongs to; ex: cv, nlp, etc

  • sub_domain – the sub domain the model belongs to; ex: classification, detection, etc

  • architecture – the architecture the model belongs to; ex: resnet, mobilenet, etc

  • sub_architecture – the sub architecture the model belongs to; ex: 50, 101, etc

  • default_dataset – the dataset to use by default for loading pretrained if not supplied

  • default_desc – the description to use by default for loading pretrained if not supplied

  • repo_source – the source repo for the model; ex: sparseml, torchvision

  • def_ignore_error_tensors – tensors to ignore if there are errors in loading

  • desc_args – args that should be changed based on the description

Returns

The constructor wrapper registered with the registry

Module contents

Code for creating and loading models in PyTorch