# Source code for sparseml.optim.modifier

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Code related to modifiers that is shared across frameworks.
Modifiers allow modifying the training process of a model; ex to perform model pruning.
"""

import hashlib
import re
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Union

import yaml

from sparseml.optim.helpers import evaluate_recipe_yaml_str_equations
from sparseml.utils import ALL_TOKEN, validate_str_iterable


# public API of this module (the names exported by ``from ... import *``)
__all__ = [
    "BaseProp", "ModifierProp", "BaseObject", "BaseModifier",
    "BaseScheduled", "BaseUpdate", "ModifierYAML",
]


class BaseProp(ABC):
    """
    Abstract descriptor interface that property-style decorators
    (ex: ModifierProp) implement: reading, writing, and re-creating the
    descriptor with a new getter or setter. Deleting is always refused.
    """

    @abstractmethod
    def __get__(self, obj, obj_type=None):
        """
        :param obj: the instance the attribute is read from
        :param obj_type: unused
        :return: the value read from obj
        """

    @abstractmethod
    def __set__(self, obj, value):
        """
        :param obj: the instance the attribute is written to
        :param value: the value to store
        """

    def __delete__(self, obj):
        """
        Deleting a modifier property is never allowed.

        :param obj: the instance the delete was attempted on
        :raises AttributeError: always
        """
        raise AttributeError("can't delete attribute")

    @abstractmethod
    def getter(self, func_get: Callable):
        """
        :param func_get: the replacement getter function
        :return: a new descriptor instance that uses func_get
        """

    @abstractmethod
    def setter(self, func_set: Callable):
        """
        :param func_set: the replacement setter function
        :return: a new descriptor instance that uses func_set
        """
class ModifierProp(BaseProp):
    """
    Property used to decorate a modifier.
    Use for creating getters and setters in a modifier.
    Handles making sure props cannot be changed after a certain point;
    ex after initialized.
    Also, marks the properties so they can be easily collected and serialized later.

    :param serializable: True if the property should be serialized (ex in yaml),
        False otherwise. Default True
    :param restrict_initialized: True to keep the property from being set after
        initialized, False otherwise. Default True
    :param restrict_enabled: True to keep the property from being set after enabled,
        False otherwise. Default False
    :param restrict_extras: extra attributes to check, if any are truthy then keep
        from being set. Default None
    :param no_serialize_val: If prop is equal to this value, will not serialize the prop
    :param func_get: The function getter
    :param func_set: The function setter
    :param doc: The docs for the property; defaults to the getter's docstring
    """

    def __init__(
        self,
        serializable: bool = True,
        restrict_initialized: bool = True,
        restrict_enabled: bool = False,
        restrict_extras: List[str] = None,
        no_serialize_val: Any = None,
        func_get: Callable = None,
        func_set: Callable = None,
        doc: Union[str, None] = None,
    ):
        self._func_get = func_get
        self._func_set = func_set
        self._serializable = serializable
        self._restrictions = []
        self._no_serialize_val = no_serialize_val

        if restrict_initialized:
            self._restrictions.append("_initialized")

        if restrict_enabled:
            self._restrictions.append("_enabled")

        if restrict_extras is not None:
            self._restrictions.extend(restrict_extras)

        if doc is None and self._func_get is not None:
            doc = self._func_get.__doc__

        self.__doc__ = doc

    def __call__(self, getter: Callable) -> BaseProp:
        """
        :param getter: the annotated getter to use to get the attribute
        :return: the current property instance
        """
        self._func_get = getter

        # when used as ``@ModifierProp(...)`` no getter existed at __init__ time,
        # so pick up the decorated function's docstring now
        if self.__doc__ is None:
            self.__doc__ = getter.__doc__

        return self

    def __get__(self, obj, obj_type=None):
        """
        Get the attribute from the current given object for the current modifier

        :param obj: the object to get the attribute from
        :param obj_type: unused
        :return: The retrieved value from the obj
        """
        if obj is None:
            # accessed on the class itself: return the descriptor so it can be
            # inspected (ex: isinstance checks on class attributes) or re-created
            # through .getter / .setter
            return self

        if self._func_get is None:
            raise AttributeError("unreadable attribute")

        return self._func_get(obj)

    def __set__(self, obj, value):
        """
        Set the attribute in the current given object for the current modifier.
        If the attribute can't be set because of the current modifiers state,
        (ex: initialized) then will raise a AttributeError

        :param obj: the object to get the attribute from
        :param value: the value to set
        """
        if self._func_set is None:
            raise AttributeError("can't set attribute")

        # any truthy restriction attribute on obj (ex: _initialized) blocks the set
        for restriction in self.restrictions:
            if getattr(obj, restriction, False):
                raise AttributeError(
                    "Cannot change {} after initializing {}".format(
                        self._prop_name(), obj.__class__.__name__
                    )
                )

        self._func_set(obj, value)

    @property
    def serializable(self) -> bool:
        """
        :return: True if the property should be serialized (ex in yaml),
            False otherwise
        """
        return self._serializable

    @property
    def restrictions(self) -> List[str]:
        """
        :return: The attributes to check for restricting when the attribute can be set
        """
        return self._restrictions

    @property
    def no_serialize_val(self) -> Any:
        """
        :return: a value that if the prop is equal to, will not serialize the prop
        """
        return self._no_serialize_val

    def getter(self, func_get: Callable) -> BaseProp:
        """
        Create a ModifierProp based off the current instance with the getter function

        :param func_get: the getter function
        :return: the recreated instance with the new getter function
        """
        return type(self)(
            **self._creator_kwargs(),
            func_get=func_get,
            func_set=self._func_set,
            doc=self.__doc__,
        )

    def setter(self, func_set: Callable) -> BaseProp:
        """
        Create a ModifierProp based off the current instance with the setter function

        :param func_set: the setter function
        :return: the recreated instance with the new setter function
        """
        return type(self)(
            **self._creator_kwargs(),
            func_get=self._func_get,
            func_set=func_set,
            doc=self.__doc__,
        )

    def _creator_kwargs(self) -> Dict:
        # The already-resolved restriction list is forwarded as restrict_extras,
        # so the initialized/enabled flags are turned off to avoid duplicates.
        # no_serialize_val must be carried over too, otherwise a prop re-created
        # via .setter/.getter silently reverts to the default of None.
        return {
            "serializable": self._serializable,
            "restrict_initialized": False,
            "restrict_enabled": False,
            "restrict_extras": self._restrictions,
            "no_serialize_val": self._no_serialize_val,
        }

    def _prop_name(self) -> str:
        # name used in error messages; falls back to the setter when no getter exists
        func = self._func_get if self._func_get is not None else self._func_set
        return getattr(func, "__name__", "property")
class BaseObject(object):
    """
    Terminal base for the cooperative ``**kwargs`` chain used by the modifier
    classes. Every keyword argument must have been consumed by a subclass
    before this initializer runs; any leftovers are reported as an error.

    :param kwargs: standard key word args, used to support multi inheritance;
        must be empty by the time this class is reached
    """

    def __init__(self, **kwargs):
        super().__init__()

        if not kwargs:
            return

        template = (
            "kwargs must be empty at _BaseObject, "
            "extras passed in of {} for {}"
        )
        raise ValueError(template.format(kwargs, self.__class__.__name__))
class BaseModifier(BaseObject):
    """
    Parent class meant to be used for all modifiers.
    Handles base implementations for properties and methods.

    :param log_types: the loggers that can be used by the modifier instance
    :param kwargs: standard key word args, used to support multi inheritance
    """

    @staticmethod
    def _convert_to_framework_modifiers(yaml_str: str, framework: str) -> str:
        """
        Namespace bare YAML tags with the given framework,
        ex: ``!MyModifier`` -> ``!{framework}.MyModifier``.

        :param yaml_str: the yaml string to rewrite
        :param framework: the framework name to prefix the tags with
        :return: the rewritten yaml string
        """
        # The lookahead (?!.*\.) rejects a tag when a "." appears anywhere later
        # on the same line, which keeps already namespaced tags untouched.
        # NOTE(review): it also skips a bare tag that is followed on the same line
        # by any other "." (ex: a float in an inline flow mapping) - confirm
        # recipes never place a bare tag and a dotted value on one line.
        pattern = re.compile(r"!(?P<mod_class>(?!.*\.)[a-zA-Z_][a-zA-Z^._0-9]+)")
        yaml_str = pattern.sub(r"!{}.\g<mod_class>".format(framework), yaml_str)

        return yaml_str

    @staticmethod
    def load_framework_list(yaml_str: str, framework: str):
        """
        :param yaml_str: a string representation of the yaml syntax to load modifiers
        :param framework: the framework to load the modifiers for
        :return: the loaded modifiers list; an empty list if the yaml holds no
            document
        :raises ValueError: if a list of modifiers is stored under a key that
            does not contain 'modifiers' in its name
        """
        yaml_str = evaluate_recipe_yaml_str_equations(yaml_str)
        yaml_str = BaseModifier._convert_to_framework_modifiers(yaml_str, framework)
        container = yaml.safe_load(yaml_str)

        if container is None:
            # empty document (or only comments): nothing to load
            return []

        if isinstance(container, BaseModifier):
            return [container]

        if isinstance(container, list):
            return container

        # Dict: collect modifiers from lists under keys containing "modifiers"
        # and from top level modifier values; other entries are ignored
        modifiers = []

        for name, item in container.items():
            if "modifiers" in name and isinstance(item, list):
                modifiers.extend(item)
            elif isinstance(item, BaseModifier):
                modifiers.append(item)
            elif isinstance(item, list) and any(
                isinstance(element, BaseModifier) for element in item
            ):
                modifier_type = type(
                    next(mod for mod in item if isinstance(mod, BaseModifier))
                )
                raise ValueError(
                    "Invalid modifier location. Grouped modifiers in recipes must "
                    "be listed in lists with 'modifiers' in its name. A modifier "
                    f"of type {modifier_type} was found in recipe list {name}"
                )

        return modifiers

    @staticmethod
    def load_framework_obj(yaml_str: str, framework: str):
        """
        :param yaml_str: a string representation of the yaml syntax to load a modifier
        :param framework: the framework to load the modifier for
        :return: the loaded modifier object
        """
        yaml_str = BaseModifier._convert_to_framework_modifiers(yaml_str, framework)
        modifier = yaml.safe_load(yaml_str)

        return modifier

    @staticmethod
    def yaml_key(clazz, framework: Union[str, None] = None):
        """
        create a key for loading in yaml from the class and the framework

        :param clazz: the class representation to create the key for
        :param framework: the string representing the ML framework the modifier class
            is for. Default is None.
        :return: the formatted key; ex: !{framework}.{clazz.__name__}
        """
        if framework is None:
            return "!{}".format(clazz.__name__)

        return "!{}.{}".format(framework, clazz.__name__)

    @staticmethod
    def comparator(one, two) -> int:
        """
        Comparator implementation for Modifiers.
        Compares first on end_epoch, next on start_epoch, and finally on identifier.

        :param one: first modifier to compare
        :param two: second modifier to compare
        :return: int representing where one is in relation to two
        """
        compare = BaseModifier.comparator_ends(one, two)

        if compare == 0:
            compare = BaseModifier.comparator_starts(one, two)

        if compare == 0:
            compare = BaseModifier.comparator_identifiers(one, two)

        return compare

    @staticmethod
    def comparator_ends(one, two) -> int:
        """
        Comparator implementation for Modifiers based on end_epoch.
        Modifiers with ends greater than another will come out higher.
        A missing end_epoch sorts first and an end_epoch of -1 (never ends)
        sorts last.

        :param one: first modifier to compare
        :param two: second modifier to compare
        :return: int representing where one is in relation to two
        """
        one_end = getattr(one, "end_epoch", None)
        two_end = getattr(two, "end_epoch", None)

        if one_end == two_end:
            return 0

        if one_end is None or two_end == -1:
            # if no end for one and two has one or two never ends, return one before two
            return -1

        if two_end is None or one_end == -1:
            # if no end for two and one has one or one never ends, return one after two
            return 1

        if one_end < two_end:
            return -1

        return 1

    @staticmethod
    def comparator_starts(one, two) -> int:
        """
        Comparator implementation for Modifiers based on start_epoch.
        Modifiers with starts greater than another will come out higher.
        A missing start_epoch sorts first.

        :param one: first modifier to compare
        :param two: second modifier to compare
        :return: int representing where one is in relation to two
        """
        one_start = getattr(one, "start_epoch", None)
        two_start = getattr(two, "start_epoch", None)

        if one_start == two_start:
            return 0

        if one_start is None:
            # if no start for one and two has one, return one before two
            return -1

        if two_start is None:
            # if no start for two and one has one, return one after two
            return 1

        if one_start > two_start:
            # if one starts after two, return after
            return 1

        return -1

    @staticmethod
    def comparator_identifiers(one, two) -> int:
        """
        Comparator implementation for Modifiers based on identifier.
        Modifiers with identifiers greater than another (by string comparison)
        will come out higher.

        :param one: first modifier to compare
        :param two: second modifier to compare
        :return: int representing where one is in relation to two
        """
        one_id = one.identifier()
        two_id = two.identifier()

        if one_id < two_id:
            return -1

        if one_id > two_id:
            return 1

        return 0

    def __init__(self, log_types: Union[str, List[str]] = ALL_TOKEN, **kwargs):
        super().__init__(**kwargs)
        # a falsy log_types disables validation and is stored as None
        self._log_types = (
            validate_str_iterable(
                log_types, "log_types for {}".format(self.__class__.__name__)
            )
            if log_types
            else None
        )
        self._initialized = False
        self._enabled = True

    def __str__(self):
        formatted = [
            " {}: {}".format(key, val)
            for key, val in self.props(only_serializable=True, format_str=True).items()
        ]

        return "{}\n{}".format(
            BaseModifier.yaml_key(self.__class__), "\n".join(formatted)
        )

    def __repr__(self):
        return "{}({})".format(
            self.__class__.__name__,
            self.props(only_serializable=False, format_repr=True),
        )

    @ModifierProp(serializable=True)
    def log_types(self) -> Union[None, str, List[str]]:
        """
        :return: the loggers that can be used by the modifier instance
        """
        return self._log_types

    @ModifierProp(serializable=False, restrict_initialized=False)
    def initialized(self) -> bool:
        """
        :return: True if the modifier has gone through the initialized life cycle,
            False otherwise
        """
        return self._initialized

    @ModifierProp(serializable=False, restrict_initialized=False)
    def enabled(self) -> bool:
        """
        :return: True if the modifier is currently enabled and making updates,
            False otherwise
        """
        return self._enabled

    @enabled.setter
    def enabled(self, value: bool):
        """
        :param value: True to allow the modifier to make updates, False otherwise
        """
        self._enabled = value

    def identifier(self, extra: Any = "") -> str:
        """
        :param extra: any extra identifier to append to the end of the string
        :return: generate an identifier for the current modifier based on its
            class name and params
        """
        props = self.props(only_serializable=True, format_str=True)
        # convert to list and sort to make deterministic
        props_list = sorted(f"{key}:{val}" for key, val in props.items())
        hash_str = hashlib.md5(str(props_list).encode()).hexdigest()

        return f"{self.__class__.__name__}-{hash_str}{f'-{extra}' if extra else ''}"

    def props(
        self,
        only_serializable: bool,
        format_str: bool = False,
        format_repr: bool = False,
    ) -> Dict[str, Any]:
        """
        Gather all the ModifierProps for the current instance into a dictionary
        collection. Props whose value equals their no_serialize_val are skipped.

        :param only_serializable: True if only props marked as serializable should
            be collected, False otherwise
        :param format_str: True to format the values properly for a str.
            Ex: None values are formatted to null and otherwise str is called
        :param format_repr: True to format the values properly for a repr.
        :return: the collected properties with names mapping to values
        :raises ValueError: if both format_str and format_repr are True
        """
        if format_str and format_repr:
            raise ValueError(
                "only format_str or format_repr can be True, both are currently True"
            )

        props = {}

        for attr in dir(self):
            if attr.startswith("_"):
                continue

            # default of None: public instance attributes (not defined on the
            # class) are not ModifierProps and must not raise AttributeError here
            func = getattr(self.__class__, attr, None)

            if not isinstance(func, ModifierProp) or (
                only_serializable and not func.serializable
            ):
                continue

            val = getattr(self, attr)

            if val == func.no_serialize_val:
                continue

            if format_str:
                props[attr] = str(val) if val is not None else "null"
            elif format_repr:
                props[attr] = repr(val)
            else:
                props[attr] = val

        return props
class BaseScheduled(BaseObject):
    """
    Abstract class meant to be used for all scheduled modifiers.
    :py:func `~Modifier` is also meant to be inherited alongside this class.
    Handles base implementations for scheduled properties and methods to allow
    a schedule to be enforced.

    :param start_epoch: the epoch to start the scheduled modifier at
    :param min_start: the minimum value start_epoch can be,
        otherwise will raise a ValueError
    :param end_epoch: the epoch to end the scheduled modifier at
    :param min_end: the minimum value end_epoch can be,
        otherwise will raise a ValueError
    :param end_comparator: integer value representing how the end_epoch should be
        compared to start_epoch.
        if == None, then end_epoch can only be set to what its initial value was.
        if == -1, then end_epoch can be less than, equal, or greater than start_epoch.
        if == 0, then end_epoch can be equal to or greater than start_epoch.
        if == 1, then end_epoch can only be greater than start_epoch.
    :param kwargs: standard key word args, used to support multi inheritance
    """

    def __init__(
        self,
        start_epoch: float = -1.0,
        min_start: float = -1.0,
        end_epoch: float = -1.0,
        min_end: float = -1.0,
        end_comparator: Union[int, None] = 0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._start_epoch = start_epoch
        self._init_start = start_epoch
        self._min_start = min_start
        self._end_epoch = end_epoch
        self._init_end = end_epoch
        self._min_end = min_end
        self._end_comparator = end_comparator
        self.validate_schedule()

    @ModifierProp()
    def start_epoch(self) -> float:
        """
        :return: The epoch to start the modifier at
            (set to -1.0 so it starts immediately)
        """
        return self._start_epoch

    @start_epoch.setter
    def start_epoch(self, value: float):
        """
        :param value: The epoch to start the modifier at
            (set to -1.0 so it starts immediately)
        """
        self._set_and_validate_schedule("_start_epoch", value)

    @ModifierProp()
    def end_epoch(self) -> float:
        """
        :return: The epoch to end the modifier at (set to -1.0 so it never ends)
        """
        return self._end_epoch

    @end_epoch.setter
    def end_epoch(self, value: float):
        """
        :param value: The epoch to end the modifier at (set to -1.0 so it never ends)
        """
        self._set_and_validate_schedule("_end_epoch", value)

    def validate_schedule(self):
        """
        Validate the schedule values of the params for the current instance are valid

        :raises ValueError: if any schedule constraint is violated
        """
        if self._start_epoch < self._min_start:
            raise ValueError(
                "start_epoch of {} must be greater than or equal to {} for {}".format(
                    self._start_epoch, self._min_start, self.__class__.__name__
                )
            )

        if self._end_epoch < self._min_end:
            raise ValueError(
                "end_epoch of {} must be greater than or equal to {} for {}".format(
                    self._end_epoch, self._min_end, self.__class__.__name__
                )
            )

        if self._end_comparator is None and self._end_epoch != self._init_end:
            raise ValueError(
                "end_epoch of {} must be equal the init value of {} for {}".format(
                    self._end_epoch, self._init_end, self.__class__.__name__
                )
            )

        if self._end_comparator == 0 and self._start_epoch > self._end_epoch:
            raise ValueError(
                (
                    "end_epoch of {} must be greater than"
                    " or equal to start_epoch of {} for {}"
                ).format(self._end_epoch, self._start_epoch, self.__class__.__name__)
            )

        if self._end_comparator == 1 and self._start_epoch >= self._end_epoch:
            raise ValueError(
                "end_epoch of {} must be greater than start_epoch of {} for {}".format(
                    self._end_epoch, self._start_epoch, self.__class__.__name__
                )
            )

    def _set_and_validate_schedule(self, attr_name: str, value: float):
        # Assign, then validate; on failure restore the previous value so a
        # rejected assignment does not leave the instance with an invalid schedule.
        # Uniquely named so it cannot clash with BaseUpdate under multi inheritance.
        previous = getattr(self, attr_name)
        setattr(self, attr_name, value)

        try:
            self.validate_schedule()
        except Exception:
            setattr(self, attr_name, previous)
            raise
class BaseUpdate(BaseObject):
    """
    Abstract class meant to be used for all update modifiers.
    :py:func `~Modifier` and :py:func `~ScheduledModifier` are also meant
    to be inherited alongside this class.
    Handles base implementations for scheduled properties and methods to allow
    updates to be enforced.

    :param update_frequency: The number of epochs or fraction of epochs to
        update at between start and end
    :param min_frequency: The minimum acceptable value for update_frequency,
        default -1
    :param kwargs: standard key word args, used to support multi inheritance
    """

    def __init__(self, update_frequency: float, min_frequency: float, **kwargs):
        super().__init__(**kwargs)
        self._update_frequency = update_frequency
        self._min_frequency = min_frequency
        self.validate_update()

    @ModifierProp()
    def update_frequency(self) -> float:
        """
        :return: The number of epochs or fraction of epochs to update at between
            start and end
        """
        return self._update_frequency

    @update_frequency.setter
    def update_frequency(self, value: float):
        """
        :param value: The number of epochs or fraction of epochs to update at between
            start and end
        """
        previous = self._update_frequency
        self._update_frequency = value

        try:
            self.validate_update()
        except Exception:
            # roll back so a rejected value does not linger on the instance
            self._update_frequency = previous
            raise

    def validate_update(self):
        """
        Validate the update schedule values of the params for the current instance
        are valid

        :raises ValueError: if update_frequency is below min_frequency
        """
        if self._update_frequency < self._min_frequency:
            raise ValueError(
                (
                    "update_frequency of {} must be greater than or "
                    "equal to {} for {}"
                ).format(
                    self._update_frequency, self._min_frequency, self.__class__.__name__
                )
            )
class ModifierYAML(object):
    """
    Class decorator that registers a modifier class with the yaml loaders
    under the tag ``!{framework}.{ClassName}``, so recipes can instantiate it
    directly (including through ``yaml.safe_load``).

    :param framework: the string representing the ML framework the modifier
        should be stored under; must be non-empty
    """

    def __init__(self, framework: str):
        if not framework:
            raise ValueError("a framework is required")

        self._framework = framework

    def __call__(self, clazz):
        """
        :param clazz: the class to register yaml constructors for
        :return: clazz itself, unchanged apart from the registration side effect
        """
        tag = BaseModifier.yaml_key(clazz, self._framework)

        def _construct(loader, node):
            # two-step (generator) construction: hand back the bare instance first,
            # then run __init__ with the fully built mapping from the yaml node
            obj = clazz.__new__(clazz)
            yield obj
            obj.__init__(**loader.construct_mapping(node, deep=True))

        # register with the default loader(s) as well as the safe loader
        yaml.add_constructor(tag, _construct)
        yaml.add_constructor(tag, _construct, yaml.SafeLoader)

        return clazz