Source code for sparseml.keras.utils.callbacks

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Built-in callbacks for Keras
"""

from typing import List, Union

from tensorflow import Tensor

from sparseml.keras.utils.compat import keras
from sparseml.keras.utils.logger import KerasLogger, LoggingMode


__all__ = [
    "LoggerSettingCallback",
    "LossesAndMetricsLoggingCallback",
]


class LoggerSettingCallback(keras.callbacks.Callback):
    """
    Class to help correctly set logging modes for callbacks that rely on
    KerasLogger. All callbacks using KerasLogger should derive from this class.

    :param loggers: logger or list of loggers
    """

    def __init__(self, loggers: Union[KerasLogger, List[KerasLogger]]):
        self._loggers = loggers if isinstance(loggers, list) else [loggers]

    def on_epoch_begin(self, epoch, logs=None):
        """
        Called at the beginning of a training epoch

        :param epoch: epoch index
        :param logs: dictionary of logs (see Keras Callback doc)
        """
        super().on_epoch_begin(epoch, logs)
        self._set_logging_mode(LoggingMode.TRAIN)

    def on_epoch_end(self, epoch, logs=None):
        """
        Called at the end of a training epoch

        :param epoch: epoch index
        :param logs: dictionary of logs (see Keras Callback doc)
        """
        super().on_epoch_end(epoch, logs)
        self._set_logging_mode(LoggingMode.TRAIN)

    def on_predict_batch_begin(self, batch, logs=None):
        """
        Called at the beginning of a batch in prediction

        :param batch: batch index in current epoch
        :param logs: dictionary of logs (see Keras Callback doc)
        """
        super().on_predict_batch_begin(batch, logs)
        self._set_logging_mode(LoggingMode.PREDICT)

    def on_predict_batch_end(self, batch, logs=None):
        """
        Called at the end of a batch in prediction

        :param batch: batch index in current epoch
        :param logs: dictionary of logs (see Keras Callback doc)
        """
        super().on_predict_batch_end(batch, logs)
        self._set_logging_mode(LoggingMode.PREDICT)

    def on_predict_begin(self, logs=None):
        """
        Called at the beginning of prediction

        :param logs: dictionary of logs (see Keras Callback doc)
        """
        super().on_predict_begin(logs)
        self._set_logging_mode(LoggingMode.PREDICT)

    def on_predict_end(self, logs=None):
        """
        Called at the end of prediction

        :param logs: dictionary of logs (see Keras Callback doc)
        """
        super().on_predict_end(logs)
        self._set_logging_mode(LoggingMode.PREDICT)

    def on_test_batch_begin(self, batch, logs=None):
        """
        Called at the beginning of a batch in evaluation

        :param batch: batch index in current epoch
        :param logs: dictionary of logs (see Keras Callback doc)
        """
        super().on_test_batch_begin(batch, logs)
        self._set_logging_mode(LoggingMode.TEST)

    def on_test_batch_end(self, batch, logs=None):
        """
        Called at the end of a batch in evaluation

        :param batch: batch index in current epoch
        :param logs: dictionary of logs (see Keras Callback doc)
        """
        super().on_test_batch_end(batch, logs)
        self._set_logging_mode(LoggingMode.TEST)

    def on_test_begin(self, logs=None):
        """
        Called at the beginning of evaluation

        :param logs: dictionary of logs (see Keras Callback doc)
        """
        super().on_test_begin(logs)
        self._set_logging_mode(LoggingMode.TEST)

    def on_test_end(self, logs=None):
        """
        Called at the end of evaluation

        :param logs: dictionary of logs (see Keras Callback doc)
        """
        super().on_test_end(logs)
        self._set_logging_mode(LoggingMode.TEST)

    def on_train_batch_begin(self, batch, logs=None):
        """
        Called at the beginning of a batch in training

        :param batch: batch index in current epoch
        :param logs: dictionary of logs (see Keras Callback doc)
        """
        super().on_train_batch_begin(batch, logs)
        self._set_logging_mode(LoggingMode.TRAIN)

    def on_train_batch_end(self, batch, logs=None):
        """
        Called at the end of a batch in training

        :param batch: batch index in current epoch
        :param logs: dictionary of logs (see Keras Callback doc)
        """
        super().on_train_batch_end(batch, logs)
        self._set_logging_mode(LoggingMode.TRAIN)

    def on_train_begin(self, logs=None):
        """
        Called at the beginning of training

        :param logs: dictionary of logs (see Keras Callback doc)
        """
        super().on_train_begin(logs)
        self._set_logging_mode(LoggingMode.TRAIN)

    def on_train_end(self, logs=None):
        """
        Called at the end of training

        :param logs: dictionary of logs (see Keras Callback doc)
        """
        super().on_train_end(logs)
        self._set_logging_mode(LoggingMode.TRAIN)

    def _set_logging_mode(self, mode: LoggingMode):
        for logger in self._loggers:
            logger.mode = mode
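
# Illustrative sketch (not part of the library): a custom callback that derives
# from LoggerSettingCallback so its loggers are always switched to the correct
# LoggingMode, then logs the optimizer learning rate once per epoch. The
# attribute path `self.model.optimizer.learning_rate` is an assumption about
# the compiled Keras model and optimizer in use.
class _ExampleLearningRateLoggingCallback(LoggerSettingCallback):
    def on_epoch_end(self, epoch, logs=None):
        # the parent call keeps every logger in LoggingMode.TRAIN
        super().on_epoch_end(epoch, logs)
        lr = keras.backend.get_value(self.model.optimizer.learning_rate)
        for logger in self._loggers:
            logger.log_scalar("learning_rate", lr, step=epoch)
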

class LossesAndMetricsLoggingCallback(LoggerSettingCallback):
    """
    Callback to log all losses and metrics

    :param loggers: logger or list of loggers
    :param start_step: a start step tensor when this callback starts to take effect
    """

    def __init__(
        self,
        loggers: Union[KerasLogger, List[KerasLogger]],
        start_step: Union[Tensor, int] = 0,
    ):
        super().__init__(loggers)
        self._start_step = start_step
        self._step = None

    def on_train_begin(self, logs=None):
        """
        Called at the beginning of training

        :param logs: dictionary of logs (see Keras Callback doc)
        """
        super().on_train_begin(logs)
        self._step = keras.backend.get_value(self._start_step)

    def on_epoch_end(self, epoch, logs=None):
        """
        Called at the end of a training epoch

        :param epoch: epoch index
        :param logs: dictionary of logs (see Keras Callback doc)
        """
        super().on_epoch_end(epoch, logs)
        if logs is None:
            return
        for logger in self._loggers:
            assert logger.mode == LoggingMode.TRAIN
            for tag, value in logs.items():
                logger.log_scalar("epoch_{}".format(tag), value, step=epoch)

    def on_train_batch_end(self, batch, logs=None):
        """
        Called at the end of a batch in training

        :param batch: batch index in current epoch
        :param logs: dictionary of logs (see Keras Callback doc)
        """
        super().on_train_batch_end(batch, logs)
        if logs is None:
            return
        for logger in self._loggers:
            assert logger.mode == LoggingMode.TRAIN
            if logger.update_freq == "batch" or (
                isinstance(logger.update_freq, int)
                and self._step % logger.update_freq == 0
            ):
                for tag, value in logs.items():
                    logger.log_scalar("batch_{}".format(tag), value, step=self._step)
        self._step += 1

    def on_test_end(self, logs=None):
        """
        Called at the end of evaluation

        :param logs: dictionary of logs (see Keras Callback doc)
        """
        super().on_test_end(logs)
        if logs is None:
            return
        for logger in self._loggers:
            assert logger.mode == LoggingMode.TEST
            for tag, value in logs.items():
                logger.log_scalar("val_{}".format(tag), value, step=self._step)
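
# Illustrative usage sketch (not part of the library API): shows how
# LossesAndMetricsLoggingCallback can be wired into model.fit. TensorBoardLogger
# and the constructor argument shown are assumptions about a concrete
# KerasLogger implementation; substitute whichever KerasLogger you actually use.
def _example_fit_with_logging(model, train_dataset, validation_dataset):
    from sparseml.keras.utils.logger import TensorBoardLogger  # assumed implementation

    logger = TensorBoardLogger(update_freq="batch")  # assumed constructor signature
    callback = LossesAndMetricsLoggingCallback(loggers=logger, start_step=0)
    # per-batch values are logged as "batch_*", per-epoch values as "epoch_*",
    # and evaluation results as "val_*" via logger.log_scalar
    return model.fit(
        train_dataset,
        validation_data=validation_dataset,
        epochs=2,
        callbacks=[callback],
    )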