# Source code for sparseml.optim.modifier

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Code related to modifiers that is shared across frameworks.
Modifiers allow modifying the training process of a model, e.g. to perform model pruning.
"""

import hashlib
import re
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Union

import yaml
from yaml import ScalarNode

from sparseml.optim.helpers import evaluate_recipe_yaml_str_equations
from sparseml.sparsification.types import SparsificationTypes


__all__ = [
    "BaseProp",
    "ModifierProp",
    "BaseObject",
    "BaseModifier",
    "BaseScheduled",
    "BaseUpdate",
    "ModifierYAML",
]


class BaseProp(ABC):
    """
    Abstract descriptor interface that any modifier property decorator
    must implement.
    """

    @abstractmethod
    def __get__(self, obj, obj_type=None):
        """
        :param obj: the object to read the attribute from
        :param obj_type: unused
        :return: the value retrieved from obj
        """

    @abstractmethod
    def __set__(self, obj, value):
        """
        :param obj: the object to set the attribute on
        :param value: the value to set
        """

    @abstractmethod
    def getter(self, func_get: Callable):
        """
        :param func_get: the getter function
        :return: the recreated instance with the new getter function
        """

    @abstractmethod
    def setter(self, func_set: Callable):
        """
        :param func_set: the setter function
        :return: the recreated instance with the new setter function
        """

    def __delete__(self, obj):
        """
        Modifier properties can never be deleted.

        :param obj: the object
        :raises AttributeError: always
        """
        raise AttributeError("can't delete attribute")
class ModifierProp(BaseProp):
    """
    Property used to decorate a modifier.
    Use for creating getters and setters in a modifier.
    Handles making sure props cannot be changed after a certain point;
    ex after initialized.
    Also, marks the properties so they can be easily collected and serialized later.

    :param serializable: True if the property should be serialized (ex in yaml),
        False otherwise. Default True
    :param restrict_initialized: True to keep the property from being set after
        initialized, False otherwise. Default True
    :param restrict_enabled: True to keep the property from being set after enabled,
        False otherwise. Default False
    :param restrict_extras: extra attributes to check, if any are truthy then keep
        from being set. Default None
    :param no_serialize_val: If prop is equal to this value, will not serialize the prop
    :param func_get: The function getter
    :param func_set: The function setter
    :param doc: The docs function
    """

    def __init__(
        self,
        serializable: bool = True,
        restrict_initialized: bool = True,
        restrict_enabled: bool = False,
        restrict_extras: Union[List[str], None] = None,
        no_serialize_val: Any = None,
        func_get: Callable = None,
        func_set: Callable = None,
        doc: Callable = None,
    ):
        self._func_get = func_get
        self._func_set = func_set
        self._serializable = serializable
        self._no_serialize_val = no_serialize_val

        # names of attributes on the owning object that, when truthy, block sets
        self._restrictions = []
        if restrict_initialized:
            self._restrictions.append("_initialized")
        if restrict_enabled:
            self._restrictions.append("_enabled")
        if restrict_extras is not None:
            self._restrictions.extend(restrict_extras)

        if doc is None and self._func_get is not None:
            doc = self._func_get.__doc__
        self.__doc__ = doc

    def __call__(self, getter: Callable) -> BaseProp:
        """
        :param getter: the annotated getter to use to get the attribute
        :return: the current property instance
        """
        self._func_get = getter
        # when used as ``@ModifierProp(...)`` the getter only arrives here, so
        # pick up its docstring unless one was given explicitly
        if self.__doc__ is None:
            self.__doc__ = getattr(getter, "__doc__", None)
        return self

    def __get__(self, obj, obj_type=None):
        """
        Get the attribute from the current given object for the current modifier

        :param obj: the object to get the attribute from
        :param obj_type: unused
        :return: The retrieved value from the obj, or this property instance
            when accessed on the class (obj is None)
        :raises AttributeError: if no getter function has been set
        """
        if obj is None:
            return self
        if self._func_get is None:
            raise AttributeError("unreadable attribute")
        return self._func_get(obj)

    def __set__(self, obj, value):
        """
        Set the attribute in the current given object for the current modifier.
        If the attribute can't be set because of the current modifiers state,
        (ex: initialized) then will raise an AttributeError

        :param obj: the object to set the attribute on
        :param value: the value to set
        :raises AttributeError: if no setter function has been set or any of the
            restriction attributes on obj is truthy
        """
        if self._func_set is None:
            raise AttributeError("can't set attribute")

        for rest in self.restrictions:
            if getattr(obj, rest, False):
                # fall back to the setter's name so that building the message
                # cannot itself fail when no getter has been registered
                prop_name = getattr(self._func_get, "__name__", None) or getattr(
                    self._func_set, "__name__", "attribute"
                )
                raise AttributeError(
                    "Cannot change {} after initializing {}".format(
                        prop_name, obj.__class__.__name__
                    )
                )

        self._func_set(obj, value)

    @property
    def serializable(self) -> bool:
        """
        :return: True if the property should be serialized (ex in yaml),
            False otherwise
        """
        return self._serializable

    @property
    def restrictions(self) -> List[str]:
        """
        :return: The attributes to check for restricting when the attribute
            can be set
        """
        return self._restrictions

    @property
    def no_serialize_val(self) -> Any:
        """
        :return: a value that if the prop is equal to, will not serialize the prop
        """
        return self._no_serialize_val

    def getter(self, func_get: Callable) -> BaseProp:
        """
        Create a ModifierProp based off the current instance with the getter function

        :param func_get: the getter function
        :return: the recreated instance with the new getter function
        """
        return type(self)(
            **self._creator_kwargs(),
            func_get=func_get,
            func_set=self._func_set,
            doc=self.__doc__,
        )

    def setter(self, func_set: Callable) -> BaseProp:
        """
        Create a ModifierProp based off the current instance with the setter function

        :param func_set: the setter function
        :return: the recreated instance with the new setter function
        """
        return type(self)(
            **self._creator_kwargs(),
            func_get=self._func_get,
            func_set=func_set,
            doc=self.__doc__,
        )

    def _creator_kwargs(self) -> Dict:
        """
        :return: the constructor kwargs needed to recreate this property with the
            same configuration (getter, setter and doc are supplied by the caller)
        """
        return {
            "serializable": self._serializable,
            # the restrict_* flags were already folded into _restrictions by
            # __init__, so they are carried over through restrict_extras instead
            "restrict_initialized": False,
            "restrict_enabled": False,
            "restrict_extras": self._restrictions,
            # must be forwarded, otherwise .getter()/.setter() silently reset it
            "no_serialize_val": self._no_serialize_val,
        }
class BaseObject(object):
    """
    Terminal base class that accepts kwargs so that cooperative multiple
    inheritance works across the modifier classes.
    Every keyword argument must have been consumed by the time this
    constructor runs.

    :param kwargs: standard key word args, used to support multi inheritance
    """

    def __init__(self, **kwargs):
        super().__init__()

        if kwargs:
            # anything left over means no class in the MRO claimed it
            message = (
                "kwargs must be empty at _BaseObject, "
                "extras passed in of {} for {}"
            )
            raise ValueError(message.format(kwargs, self.__class__.__name__))
class BaseModifier(BaseObject):
    """
    Parent class meant to be used for all modifiers.
    Handles base implementations for properties and methods.

    :param kwargs: standard key word args, used to support multi inheritance
    """

    @staticmethod
    def _convert_to_framework_modifiers(yaml_str: str, framework: str) -> str:
        """
        :param yaml_str: the yaml string containing modifier tags
        :param framework: the framework name to prefix matching tags with
        :return: the yaml string with matching ``!Name`` tags rewritten to
            ``!{framework}.Name``
        """
        # the negative lookahead skips any tag that is followed by a "."
        # anywhere on the rest of its line
        pattern = re.compile(r"!(?P<mod_class>(?!.*\.)[a-zA-Z_][a-zA-Z^._0-9]+)")
        yaml_str = pattern.sub(r"!{}.\g<mod_class>".format(framework), yaml_str)
        return yaml_str

    @staticmethod
    def load_framework_list(yaml_str: str, framework: str):
        """
        :param yaml_str: a string representation of the yaml syntax to
            load modifiers
        :param framework: the framework to load the modifiers for
        :return: the loaded modifiers list or dictionary of stage name to
            stage modifiers list if given a yaml string of a staged recipe
        :raises ValueError: if no modifiers could be found in the recipe
            (including empty or scalar yaml documents) or if modifiers are
            grouped under a key that does not include 'modifiers'
        """

        def _is_modifiers_key(key) -> bool:
            # yaml keys are not guaranteed to be strings (ex: ints)
            return isinstance(key, str) and "modifiers" in key

        def _load_stage_modifiers(stage_container):
            stage_modifiers = []  # type: List[BaseModifier]
            for name, item in stage_container.items():
                if _is_modifiers_key(name) and isinstance(item, list):
                    stage_modifiers.extend(item)
                elif isinstance(item, BaseModifier):
                    stage_modifiers.append(item)
                elif isinstance(item, list) and any(
                    isinstance(element, BaseModifier) for element in item
                ):
                    # invalid modifier group name
                    modifier_type = type(
                        next(mod for mod in item if isinstance(mod, BaseModifier))
                    )
                    raise ValueError(
                        "Invalid modifier location. Grouped modifiers in recipes must "
                        "be listed in lists with 'modifiers' in its name. A modifier "
                        f"of type {modifier_type} was found in recipe list {name}"
                    )
            return stage_modifiers

        # evaluate recipe equations and load into yaml container object
        yaml_str = evaluate_recipe_yaml_str_equations(yaml_str)
        yaml_str = BaseModifier._convert_to_framework_modifiers(yaml_str, framework)
        container = yaml.safe_load(yaml_str)

        if isinstance(container, BaseModifier):
            modifiers = [container]
        elif isinstance(container, list):
            modifiers = container
        elif not isinstance(container, dict):
            # empty document (None) or a bare scalar: nothing to load, fall
            # through to the shared "no modifiers" error below
            modifiers = []
        elif any(_is_modifiers_key(key) for key in container):
            # non-staged recipe, treat entire recipe as stage
            modifiers = _load_stage_modifiers(container)
        else:
            # staged recipe, return dict of stage_name -> modifiers
            modifiers = {}
            for stage_name, stage_item in container.items():
                if not isinstance(stage_item, dict):
                    continue  # stages must be represented as a Dict
                if any(_is_modifiers_key(key) for key in stage_item):
                    modifiers[stage_name] = _load_stage_modifiers(stage_item)

        if not modifiers:
            raise ValueError(
                "Unable to find any modifiers in given recipe. Modifiers must be "
                "listed as lists under yaml keys that include 'modifiers' in their "
                "name. Those keys and lists may also be nested under an extra key for "
                "staged recipes."
            )

        return modifiers

    @staticmethod
    def load_framework_obj(yaml_str: str, framework: str):
        """
        :param yaml_str: a string representation of the yaml syntax to
            load a modifier
        :param framework: the framework to load the modifier for
        :return: the loaded modifier object
        """
        yaml_str = BaseModifier._convert_to_framework_modifiers(yaml_str, framework)
        modifier = yaml.safe_load(yaml_str)
        return modifier

    @staticmethod
    def yaml_key(clazz, framework: Union[str, None] = None):
        """
        create a key for loading in yaml from the class and the framework

        :param clazz: the class representation to create the key for
        :param framework: the string representing the ML framework the modifier
            class is for. Default is None.
        :return: the formatted key; ex: !{framework}.{clazz.__name__}
        """
        if framework is None:
            return "!{}".format(clazz.__name__)
        return "!{}.{}".format(framework, clazz.__name__)

    @staticmethod
    def comparator(one, two) -> int:
        """
        Comparator implementation for Modifiers.
        Compares first on end_epoch, next on start_epoch, and finally on identifier.

        :param one: first modifier to compare
        :param two: second modifier to compare
        :return: int representing where one is in relation to two
        """
        # compare first on end epoch
        compare = BaseModifier.comparator_ends(one, two)

        if compare == 0:
            # if ends equal, compare next on start
            compare = BaseModifier.comparator_starts(one, two)

        if compare == 0:
            # if still equal, compare on identifier
            compare = BaseModifier.comparator_identifiers(one, two)

        return compare

    @staticmethod
    def comparator_lists(one: List["BaseModifier"], two: List["BaseModifier"]) -> int:
        """
        Comparator for list of modifiers, compares the max end, min start epochs
        of either lists and then the maximal identifiers of either

        :param one: first list of modifiers to compare
        :param two: second list of modifiers to compare
        :return: int representing where one is in relation to two
        :raises ValueError: if either list is empty (raised by max / min)
        """
        # compare first on end epoch
        # NOTE(review): comparator_ends ranks an end_epoch of -1 as the latest
        # end, while max() here ranks -1 as the smallest -- confirm intended
        compare = BaseModifier.comparator_ends(
            max(one, key=lambda mod: mod.end_epoch),
            max(two, key=lambda mod: mod.end_epoch),
        )

        if compare == 0:
            # if ends equal, compare next on start
            compare = BaseModifier.comparator_starts(
                min(one, key=lambda mod: mod.start_epoch),
                min(two, key=lambda mod: mod.start_epoch),
            )

        if compare == 0:
            # if still equal, compare on identifier
            compare = BaseModifier.comparator_identifiers(
                max(one, key=lambda mod: mod.identifier()),
                max(two, key=lambda mod: mod.identifier()),
            )

        return compare

    @staticmethod
    def comparator_ends(one, two) -> int:
        """
        Comparator implementation for Modifiers based on end_epoch.
        Modifiers with ends greater than another will come out higher.

        :param one: first modifier to compare
        :param two: second modifier to compare
        :return: int representing where one is in relation to two
        """
        one_end = getattr(one, "end_epoch", None)
        two_end = getattr(two, "end_epoch", None)

        if one_end == two_end:
            return 0

        if one_end is None or two_end == -1:
            # if no end for one and two has one or two never ends,
            # return one before two
            return -1

        if two_end is None or one_end == -1:
            # if no end for two and one has one or one never ends,
            # return one after two
            return 1

        return -1 if one_end < two_end else 1

    @staticmethod
    def comparator_starts(one, two) -> int:
        """
        Comparator implementation for Modifiers based on start_epoch.
        Modifiers with starts greater than another will come out higher.

        :param one: first modifier to compare
        :param two: second modifier to compare
        :return: int representing where one is in relation to two
        """
        one_start = getattr(one, "start_epoch", None)
        two_start = getattr(two, "start_epoch", None)

        if one_start == two_start:
            return 0

        if one_start is None:
            # if no start for one and two has one, return one before two
            return -1

        if two_start is None:
            # if no start for two and one has one, return one after two
            return 1

        # if one starts after two, return after
        return 1 if one_start > two_start else -1

    @staticmethod
    def comparator_identifiers(one, two) -> int:
        """
        Comparator implementation for Modifiers based on identifier.
        Modifiers with identifiers greater than another will come out higher.

        :param one: first modifier to compare
        :param two: second modifier to compare
        :return: int representing where one is in relation to two
        """
        one_id = one.identifier()
        two_id = two.identifier()

        if one_id < two_id:
            return -1
        if one_id > two_id:
            return 1
        return 0

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._initialized = False
        self._enabled = True

    def __str__(self):
        # NOTE(review): the indent inside the " {}" literal may have been
        # collapsed by text extraction of this file -- confirm against upstream
        formatted = [
            " {}".format("{}: {}".format(key, val))
            for key, val in self.props(only_serializable=True, format_str=True).items()
        ]

        return "{}\n{}".format(
            BaseModifier.yaml_key(self.__class__), "\n".join(formatted)
        )

    def __repr__(self):
        return "{}({})".format(
            self.__class__.__name__,
            self.props(only_serializable=False, format_repr=True),
        )

    @ModifierProp(serializable=False)
    def sparsification_types(self) -> List[SparsificationTypes]:
        """
        :return: the sparsification types this modifier instance will apply
        """
        return []

    @ModifierProp(serializable=False, restrict_initialized=False)
    def initialized(self) -> bool:
        """
        :return: True if the modifier has gone through the initialized life cycle,
            False otherwise
        """
        return self._initialized

    @ModifierProp(serializable=False, restrict_initialized=False)
    def enabled(self) -> bool:
        """
        :return: True if the modifier is currently enabled and making updates,
            False otherwise
        """
        return self._enabled

    @enabled.setter
    def enabled(self, value: bool):
        """
        :param value: True to allow the modifier to make updates, False otherwise
        """
        self._enabled = value

    def identifier(self, extra: Any = "") -> str:
        """
        :param extra: any extra identifier to append to the end of the string
        :return: generate an identifier for the current modifier based on its
            class name and params
        """
        props = self.props(only_serializable=True, format_str=True)
        # sort the formatted props so the hash is deterministic
        props_list = sorted(f"{key}:{val}" for key, val in props.items())
        hash_str = hashlib.md5(str(props_list).encode()).hexdigest()
        suffix = f"-{extra}" if extra else ""

        return f"{self.__class__.__name__}-{hash_str}{suffix}"

    def props(
        self,
        only_serializable: bool,
        format_str: bool = False,
        format_repr: bool = False,
    ) -> Dict[str, Any]:
        """
        Gather all the ModifierProps for the current instance into a
        dictionary collection.

        :param only_serializable: True if only props marked as serializable should
            be collected, False otherwise
        :param format_str: True to format the values properly for a str.
            Ex: None values are formatted to null and otherwise str is called
        :param format_repr: True to format the values properly for a repr.
        :return: the collected properties with names mapping to values
        :raises ValueError: if both format_str and format_repr are True
        """
        if format_str and format_repr:
            raise ValueError(
                "only format_str or format_repr can be True, both are currently True"
            )

        props = {}

        for attr in dir(self):
            if attr.startswith("_"):
                continue

            # dir() also lists attributes that live only on the instance; those
            # have no class-level descriptor, so default to None and skip them
            func = getattr(self.__class__, attr, None)

            if not isinstance(func, ModifierProp) or (
                only_serializable and not func.serializable
            ):
                continue

            val = getattr(self, attr)

            # props whose value equals their no_serialize_val are left out
            if val == func.no_serialize_val:
                continue

            if format_str:
                props[attr] = str(val) if val is not None else "null"
            elif format_repr:
                props[attr] = repr(val)
            else:
                props[attr] = val

        return props
class BaseScheduled(BaseObject):
    """
    Abstract class meant to be used for all scheduled modifiers.
    :py:func `~Modifier` is also meant to be inherited alongside this class.
    Provides the start / end epoch properties and the validation that keeps
    a schedule consistent.

    :param start_epoch: the epoch to start the scheduled modifier at
    :param min_start: the minimum value start_epoch can be,
        otherwise will raise a ValueError
    :param end_epoch: the epoch to end the scheduled modifier at
    :param min_end: the minimum value end_epoch can be,
        otherwise will raise a ValueError
    :param end_comparator: integer value representing how the end_epoch should be
        compared to start_epoch.
        if == None, then end_epoch can only be set to what its initial value was.
        if == -1, then end_epoch can be -1, equal to, or greater than start_epoch.
        if == 0, then end_epoch can be equal to or greater than start_epoch.
        if == 1, then end_epoch can only be greater than start_epoch.
    :param kwargs: standard key word args, used to support multi inheritance
    """

    def __init__(
        self,
        start_epoch: float = -1.0,
        min_start: float = -1.0,
        end_epoch: float = -1.0,
        min_end: float = -1.0,
        end_comparator: Union[int, None] = 0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # start of the schedule and its lower bound
        self._start_epoch = start_epoch
        self._init_start = start_epoch
        self._min_start = min_start

        # end of the schedule, its lower bound and how it relates to the start
        self._end_epoch = end_epoch
        self._init_end = end_epoch
        self._min_end = min_end
        self._end_comparator = end_comparator

        self.validate_schedule()

    @ModifierProp()
    def start_epoch(self) -> float:
        """
        :return: The epoch to start the modifier at
            (set to -1.0 so it starts immediately)
        """
        return self._start_epoch

    @start_epoch.setter
    def start_epoch(self, value: float):
        """
        :param value: The epoch to start the modifier at
            (set to -1.0 so it starts immediately)
        """
        self._start_epoch = value
        self.validate_schedule()

    @ModifierProp()
    def end_epoch(self) -> float:
        """
        :return: The epoch to end the modifier at
            (set to -1.0 so it never ends)
        """
        return self._end_epoch

    @end_epoch.setter
    def end_epoch(self, value: float):
        """
        :param value: The epoch to end the modifier at
            (set to -1.0 so it never ends)
        """
        self._end_epoch = value
        self.validate_schedule()

    def validate_schedule(self):
        """
        Check that the schedule params of the current instance are valid.

        :raises ValueError: if start_epoch or end_epoch is below its minimum, or
            end_epoch violates the rule selected by end_comparator
        """
        start = self._start_epoch
        end = self._end_epoch
        comparator = self._end_comparator
        class_name = self.__class__.__name__

        if start < self._min_start:
            raise ValueError(
                "start_epoch of {} must be greater than or equal to {} for {}".format(
                    start, self._min_start, class_name
                )
            )

        if end < self._min_end:
            raise ValueError(
                "end_epoch of {} must be greater than or equal to {} for {}".format(
                    end, self._min_end, class_name
                )
            )

        # None: end_epoch is pinned to the value given at construction
        if comparator is None and end != self._init_end:
            raise ValueError(
                "end_epoch of {} must be equal the init value of {} for {}".format(
                    end, self._init_end, class_name
                )
            )

        # -1: end may be -1, or anything at or after start
        if comparator == -1 and end != -1 and end < start:
            raise ValueError(
                (
                    "end_epoch of {} must be greater than"
                    " or equal to start_epoch of {} for {} or equal to -1"
                ).format(end, start, class_name)
            )

        # 0: end must be at or after start
        if comparator == 0 and start > end:
            raise ValueError(
                (
                    "end_epoch of {} must be greater than"
                    " or equal to start_epoch of {} for {}"
                ).format(end, start, class_name)
            )

        # 1: end must be strictly after start
        if comparator == 1 and start >= end:
            raise ValueError(
                "end_epoch of {} must be greater than start_epoch of {} for {}".format(
                    end, start, class_name
                )
            )
class BaseUpdate(BaseObject):
    """
    Abstract class meant to be used for all update modifiers.
    :py:func `~Modifier` and :py:func `~ScheduledModifier` are also meant
    to be inherited alongside this class.
    Provides the update_frequency property and the validation that keeps
    it at or above a minimum.

    :param update_frequency: The number of epochs or fraction of epochs to
        update at between start and end
    :param min_frequency: The minimum acceptable value for update_frequency
    :param kwargs: standard key word args, used to support multi inheritance
    """

    def __init__(self, update_frequency: float, min_frequency: float, **kwargs):
        super().__init__(**kwargs)
        self._update_frequency = update_frequency
        self._min_frequency = min_frequency
        self.validate_update()

    @ModifierProp()
    def update_frequency(self) -> float:
        """
        :return: The number of epochs or fraction of epochs to
            update at between start and end
        """
        return self._update_frequency

    @update_frequency.setter
    def update_frequency(self, value: float):
        """
        :param value: The number of epochs or fraction of epochs to
            update at between start and end
        """
        self._update_frequency = value
        self.validate_update()

    def validate_update(self):
        """
        Check that the update params of the current instance are valid.

        :raises ValueError: if update_frequency is below the minimum frequency
        """
        if self._update_frequency >= self._min_frequency:
            return

        message = "update_frequency of {} must be greater than or " "equal to {} for {}"
        raise ValueError(
            message.format(
                self._update_frequency, self._min_frequency, self.__class__.__name__
            )
        )
class ModifierYAML(object):
    """
    Class decorator that makes a modifier class loadable through yaml by
    registering constructors for its tag.

    :param framework: the string representing the ML framework the modifier should
        be stored under
    """

    def __init__(self, framework: str):
        if not framework:
            raise ValueError("a framework is required")

        self._framework = framework

    def __call__(self, clazz):
        """
        :param clazz: the class to create yaml constructors for
        :return: the class after yaml constructors have been added
        """
        tag = BaseModifier.yaml_key(clazz, self._framework)

        def constructor(loader, node):
            # generator constructor: hand back the bare instance first, then
            # finish initializing it once the node's mapping has been built
            instance = clazz.__new__(clazz)
            yield instance

            if isinstance(node, ScalarNode):
                state = {}
            else:
                state = loader.construct_mapping(node, deep=True)

            # ignore the log_types arg in recipes to maintain backwards
            # compatibility while recipes are updated
            state.pop("log_types", None)
            instance.__init__(**state)

        # register with the default loaders first, then with the safe loader
        yaml.add_constructor(tag, constructor)
        yaml.add_constructor(tag, constructor, yaml.SafeLoader)

        return clazz