# Source code for sparseml.pytorch.utils.loss

```
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Code related to convenience functions for controlling the calculation of losses and
metrics.
Additionally adds in support for knowledge distillation
"""
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as TF
from torch import Tensor
from torch.nn import Module
from sparseml.pytorch.utils.helpers import tensors_module_forward
from sparseml.pytorch.utils.yolo_helpers import (
box_giou,
build_targets,
get_output_grid_shapes,
yolo_v3_anchor_groups,
)
__all__ = [
"TEACHER_LOSS_KEY",
"DEFAULT_LOSS_KEY",
"LossWrapper",
"BinaryCrossEntropyLossWrapper",
"CrossEntropyLossWrapper",
"InceptionCrossEntropyLossWrapper",
"KDSettings",
"KDLossWrapper",
"SSDLossWrapper",
"YoloLossWrapper",
"Accuracy",
"TopKAccuracy",
]
TEACHER_LOSS_KEY = "__teacher_loss__"
DEFAULT_LOSS_KEY = "__loss__"
[docs]class LossWrapper(object):
"""
Generic loss class for controlling how to feed inputs and compare
with predictions for standard loss functions and metrics.
:param loss_fn: the loss function to calculate on forward call of this object,
accessible in the returned Dict at DEFAULT_LOSS_KEY
:param extras: extras representing other metrics that should be calculated
in addition to the loss
:param deconstruct_tensors: True to break the tensors up into expected
predictions and labels, False to pass the tensors as is to loss and extras
"""
def __init__(
self,
loss_fn: Callable[[Any, Any], Tensor],
extras: Union[None, Dict[str, Callable]] = None,
deconstruct_tensors: bool = True,
):
super(LossWrapper, self).__init__()
self._loss_fn = loss_fn
self._extras = extras
self._deconstruct_tensors = deconstruct_tensors
def __call__(self, data: Any, pred: Any) -> Dict[str, Tensor]:
return self.forward(data, pred)
def __repr__(self):
def _create_repr(_obj: Any) -> str:
if hasattr(_obj, "__name__"):
return _obj.__name__
if hasattr(_obj, "__class__"):
return _obj.__class__.__name__
return str(_obj)
extras = (
[_create_repr(extra) for extra in self._extras.values()]
if self._extras is not None
else []
)
return "{}(Loss: {}; Extras: {})".format(
self.__class__.__name__, _create_repr(self._loss_fn), ",".join(extras)
)
@property
def available_losses(self) -> Tuple[str, ...]:
"""
:return: a collection of all the loss and metrics keys available
for this instance
"""
return (DEFAULT_LOSS_KEY, *list(self._extras.keys()))
[docs] def forward(self, data: Any, pred: Any) -> Dict[str, Tensor]:
"""
:param data: the input data to the model, expected to contain the labels
:param pred: the predicted output from the model
:return: a dictionary containing all calculated losses and metrics with
the loss from the loss_fn at DEFAULT_LOSS_KEY
"""
calculated = {
DEFAULT_LOSS_KEY: self._loss_fn(
self.get_preds(data, pred, DEFAULT_LOSS_KEY),
self.get_labels(data, pred, DEFAULT_LOSS_KEY),
)
}
if self._extras:
for extra, func in self._extras.items():
calculated[extra] = func(
self.get_preds(data, pred, extra),
self.get_labels(data, pred, extra),
)
return calculated
[docs] def get_preds(self, data: Any, pred: Any, name: str) -> Any:
"""
overridable function that is responsible for extracting the predictions
from a model's output
:param data: data from a data loader
:param pred: the prediction from the model, if it is a tensor returns this,
if it is an iterable returns first
:param name: the name of the loss function that is asking for the
information for calculation
:return: the predictions from the model for the loss function
"""
if isinstance(pred, Tensor) or not self._deconstruct_tensors:
return pred
# assume that the desired prediction for loss is in the first instance
if isinstance(pred, Iterable):
for tens in pred:
return tens
raise TypeError(
"unsupported type of pred given of {}".format(pred.__class__.__name__)
)
[docs] def get_labels(self, data: Any, pred: Any, name: str) -> Any:
"""
overridable function that is responsible for extracting the labels
for the loss calculation from the input data to the model
:param data: data from a data loader, expected to contain a tuple of
(features, labels)
:param pred: the predicted output from a model
:param name: the name of the loss function that is asking for the
information for calculation
:return: the label for the data
"""
if isinstance(data, Iterable) and not isinstance(data, Tensor):
labels = None
for tens in data:
labels = tens
if labels is not None:
return labels
raise TypeError(
"unsupported type of data given of {}".format(data.__class__.__name__)
)
[docs]class BinaryCrossEntropyLossWrapper(LossWrapper):
"""
Convenience class for doing binary cross entropy loss calculations,
ie the default loss function is TF.binary_cross_entropy_with_logits.
:param extras: extras representing other metrics that should be calculated
in addition to the loss
"""
def __init__(
self,
extras: Union[None, Dict] = None,
):
super().__init__(
TF.binary_cross_entropy_with_logits,
extras,
)
[docs]class CrossEntropyLossWrapper(LossWrapper):
"""
Convenience class for doing cross entropy loss calculations,
ie the default loss function is TF.cross_entropy.
:param extras: extras representing other metrics that should be calculated
in addition to the loss
"""
def __init__(
self,
extras: Union[None, Dict] = None,
):
super().__init__(TF.cross_entropy, extras)
[docs]class InceptionCrossEntropyLossWrapper(LossWrapper):
"""
Loss wrapper for training an inception model that has an aux output
with cross entropy.
Defines the loss in the following way:
aux_weight * cross_entropy(aux_pred, lab) + cross_entropy(pred, lab)
Additionally adds cross_entropy into the extras.
:param extras: extras representing other metrics that should be calculated
in addition to the loss
:param aux_weight: the weight to use for the cross_entropy value
calculated from the aux output
"""
def __init__(
self,
extras: Union[None, Dict] = None,
aux_weight: float = 0.4,
):
if extras is None:
extras = {}
extras["cross_entropy"] = TF.cross_entropy
self._aux_weight = aux_weight
super().__init__(self.loss, extras)
[docs] def loss(self, preds: Tuple[Tensor, Tensor], labels: Tensor):
"""
Loss function for inception to combine the overall outputs from the model
along with the the auxiliary loss from an earlier point in the model
:param preds: the predictions tuple containing [aux output, output]
:param labels: the labels to compare to
:return: the combined cross entropy value
"""
aux_loss = TF.cross_entropy(preds[0], labels)
loss = TF.cross_entropy(preds[1], labels)
return loss + self._aux_weight * aux_loss
[docs] def get_preds(
self, data: Any, pred: Tuple[Tensor, Tensor, Tensor], name: str
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
"""
Override get_preds for the inception training output.
Specifically expects the pred from the model to be a three tensor tuple:
(aux logits, logits, classes)
For the loss function returns a tuple containing (aux logits, logits),
for all other extras returns the logits tensor
:param data: data from a data loader
:param pred: the prediction from an inception model,
expected to be a tuple containing (aux logits, logits, classes)
:param name: the name of the loss function that is asking for the
information for calculation
:return: the predictions from the model for the loss function;
a tuple containing (aux logits, logits),
for all other extras returns the logits tensor
"""
if name == DEFAULT_LOSS_KEY:
return pred[0], pred[1] # return aux, logits for loss function
return pred[1] # return logits for other calculations
[docs]class KDSettings(object):
"""
properties class for settings for applying knowledge distillation as
part of the loss calculation.
:param teacher: the teacher that provides targets for the student to learn from
:param temp_student: temperature coefficient for the student
:param temp_teacher: temperature coefficient for the teacher
:param weight: the weight for how much of the kd loss to use in proportion
with the original loss
:param contradict_hinton: in hinton's original paper they included T^2
as a scaling factor some implementations dropped this factor
so contradicting hinton does not scale by T^2
"""
def __init__(
self,
teacher: Module,
temp_student: float = 5.0,
temp_teacher: float = 5.0,
weight: float = 0.5,
contradict_hinton: bool = False,
):
self._teacher = teacher
self._temp_student = temp_student
self._temp_teacher = temp_teacher
self._weight = weight
self._contradict_hinton = contradict_hinton
@property
def teacher(self) -> Module:
"""
:return: the teacher that provides targets for the student to learn from
"""
return self._teacher
@property
def temp_student(self) -> float:
"""
:return: temperature coefficient for the student
"""
return self._temp_student
@property
def temp_teacher(self) -> float:
"""
:return: temperature coefficient for the teacher
"""
return self._temp_teacher
@property
def weight(self) -> float:
"""
:return: the weight for how much of the kd loss to use in proportion
with the original loss
"""
return self._weight
@property
def contradict_hinton(self) -> bool:
"""
:return: in hinton's original paper they included T^2 as a scaling factor
some implementations dropped this factor so contradicting hinton
does not scale by T^2
"""
return self._contradict_hinton
[docs]class KDLossWrapper(LossWrapper):
"""
Special case of the loss wrapper that allows knowledge distillation.
Makes some assumptions specifically for image classification tasks,
so may not work out of the box for everything.
:param loss_fn: the loss function to calculate on forward call of this object,
accessible in the returned Dict at DEFAULT_LOSS_KEY
:param extras: extras representing other metrics that should be
calculated in addition to the loss
:param deconstruct_tensors: True to break the tensors up into expected
predictions and labels, False to pass the tensors as is to loss and extras
:param kd_settings: the knowledge distillation settings that guide
how to calculate the total loss
"""
def __init__(
self,
loss_fn: Callable[[Any, Any], Tensor],
extras: Union[None, Dict[str, Callable]] = None,
deconstruct_tensors: bool = True,
kd_settings: Union[None, KDSettings] = None,
):
super().__init__(loss_fn, extras, deconstruct_tensors)
self._kd_settings = kd_settings # type: KDSettings
[docs] def get_inputs(self, data: Any, pred: Any, name: str) -> Any:
"""
overridable function that is responsible for extracting the inputs to the model
from the input data to the model and the output from the model
:param data: data from a data loader, expected to contain a tuple of
(features, labels)
:param pred: the predicted output from a model
:param name: the name of the loss function that is asking for the information
for calculation
:return: the input data for the model
"""
if isinstance(data, Tensor):
return data
if isinstance(data, Iterable):
for tens in data:
return tens
raise TypeError(
"unsupported type of data given of {}".format(data.__class__.__name__)
)
[docs] def forward(self, data: Any, pred: Any) -> Dict[str, Tensor]:
"""
override to calculate the knowledge distillation loss if kd_settings
is supplied and not None
:param data: the input data to the model, expected to contain the labels
:param pred: the predicted output from the model
:return: a dictionary containing all calculated losses and metrics with
the loss from the loss_fn at DEFAULT_LOSS_KEY
"""
losses = super().forward(data, pred)
if self._kd_settings is not None:
with torch.no_grad():
teacher = self._kd_settings.teacher # type: Module
preds_teacher = tensors_module_forward(
self.get_inputs(data, pred, TEACHER_LOSS_KEY), teacher.eval()
)
preds_teacher = self.get_preds(data, preds_teacher, TEACHER_LOSS_KEY)
soft_log_probs = TF.log_softmax(
self.get_preds(data, pred, DEFAULT_LOSS_KEY)
/ self._kd_settings.temp_student,
dim=1,
)
soft_targets = TF.softmax(
preds_teacher / self._kd_settings.temp_teacher, dim=1
)
distill_loss = (
TF.kl_div(soft_log_probs, soft_targets, size_average=False)
/ soft_targets.shape[0]
)
if not self._kd_settings.contradict_hinton:
# in hinton's original paper they included T^2 as a scaling factor
# some implementations dropped this factor
# so contradicting hinton does not scale by T^2
distill_loss = (
(self._kd_settings.temp_student + self._kd_settings.temp_teacher)
/ 2
) ** 2 * distill_loss
losses[DEFAULT_LOSS_KEY] = (
self._kd_settings.weight * distill_loss
+ (1 - self._kd_settings.weight) * losses[DEFAULT_LOSS_KEY]
)
return losses
[docs]class SSDLossWrapper(LossWrapper):
"""
Loss wrapper for SSD models. Implements the loss as the sum of:
1. Confidence Loss: All labels, with hard negative mining
2. Localization Loss: Only on positive labels
:param extras: extras representing other metrics that should be
calculated in addition to the loss
"""
def __init__(
self,
extras: Union[None, Dict] = None,
):
if extras is None:
extras = {}
self._localization_loss = nn.SmoothL1Loss(reduction="none")
self._confidence_loss = nn.CrossEntropyLoss(reduction="none")
super().__init__(self.loss, extras)
[docs] def loss(self, preds: Tuple[Tensor, Tensor], labels: Tuple[Tensor, Tensor, Tensor]):
"""
Calculates the loss for a multibox SSD output as the sum of the confidence
and localization loss for the positive samples in the predictor with hard
negative mining.
:param preds: the predictions tuple containing [predicted_boxes,
predicted_lables].
:param labels: the labels to compare to
:return: the combined location and confidence losses
"""
# extract predicted and ground truth boxes / labels
predicted_boxes, predicted_scores = preds
ground_boxes, ground_labels, _ = labels
# create positive label mask and count positive samples
positive_mask = ground_labels > 0
num_pos_labels = positive_mask.sum(dim=1) # shape: BATCH_SIZE,1
# sum loss on localization values, and mask out negative results
loc_loss = self._localization_loss(predicted_boxes, ground_boxes).sum(dim=1)
loc_loss = (positive_mask.float() * loc_loss).sum(dim=1)
# confidence loss with hard negative mining
con_loss_init = self._confidence_loss(predicted_scores, ground_labels)
# create mask to select 3 negative samples for every positive sample per image
con_loss_neg_vals = con_loss_init.clone()
con_loss_neg_vals[positive_mask] = 0 # clear positive sample values
_, neg_sample_sorted_idx = con_loss_neg_vals.sort(dim=1, descending=True)
_, neg_sample_rank = neg_sample_sorted_idx.sort(dim=1) # ascending value rank
neg_threshold = torch.clamp( # set threshold to 3x number of positive samples
3 * num_pos_labels, max=positive_mask.size(1)
).unsqueeze(-1)
neg_mask = neg_sample_rank < neg_threshold # select samples with highest loss
con_loss = (con_loss_init * (positive_mask.float() + neg_mask.float())).sum(
dim=1
)
# take average total loss over number of positive samples
# sets loss to 0 for images with no objects
total_loss = loc_loss + con_loss
pos_label_mask = (num_pos_labels > 0).float()
num_pos_labels = num_pos_labels.float().clamp(min=1e-6)
return (total_loss * pos_label_mask / num_pos_labels).mean(dim=0)
[docs] def get_preds(
self, data: Any, pred: Tuple[Tensor, Tensor], name: str
) -> Tuple[Tensor, Tensor]:
"""
Override get_preds for SSD model output.
:param data: data from a data loader
:param pred: the prediction from an ssd model: two tensors
representing object location and object label respectively
:param name: the name of the loss function that is asking for the
information for calculation
:return: the predictions from the model without any changes
"""
return pred[0], pred[1] # predicted locations, predicted labels
[docs]class YoloLossWrapper(LossWrapper):
"""
Loss wrapper for Yolo models. Implements the loss as a sum of class loss,
objectness loss, and GIoU
:param extras: extras representing other metrics that should be
calculated in addition to the loss
:param anchor_groups: List of n,2 tensors of the Yolo model's anchor points
for each output group
"""
def __init__(
self,
extras: Union[None, Dict] = None,
anchor_groups: List[Tensor] = None,
):
if extras is None:
extras = {}
self.anchor_groups = anchor_groups or yolo_v3_anchor_groups()
self.class_loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([1.0]))
self.obj_loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([1.0]))
super().__init__(self.loss, extras)
[docs] def loss(self, preds: List[Tensor], labels: Tuple[Tensor, Tensor]):
"""
Calculates the loss for a Yolo model output as the sum of the box, object,
and class losses
:param preds: the predictions list containing objectness, class, and location
values for each detector in the Yolo model.
:param labels: the labels to compare to
:return: the combined box, object, and class losses
"""
targets, _ = labels
grid_shapes = get_output_grid_shapes(preds)
target_classes, target_boxes, target_indices, anchors = build_targets(
targets, self.anchor_groups, grid_shapes
)
device = targets.device
self.class_loss_fn = self.class_loss_fn.to(device)
self.obj_loss_fn = self.obj_loss_fn.to(device)
class_loss = torch.zeros(1, device=device)
box_loss = torch.zeros(1, device=device)
object_loss = torch.zeros(1, device=device)
num_targets = 0
object_loss_balance = [4.0, 1.0, 0.4, 0.1] # usually only first 3 used
for i, pred in enumerate(preds):
image, anchor, grid_x, grid_y = target_indices[i]
target_object = torch.zeros_like(pred[..., 0], device=device)
if image.shape[0]:
num_targets += image.shape[0]
# filter for predictions on actual objects
predictions = pred[image, anchor, grid_x, grid_y]
# box loss
predictions_xy = predictions[:, :2].sigmoid() * 2.0 - 0.5
predictions_wh = (predictions[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
predictions_box = torch.cat((predictions_xy, predictions_wh), 1).to(
device
)
giou = box_giou(predictions_box.T, target_boxes[i].T)
box_loss += (1.0 - giou).mean()
# giou
target_object[image, anchor, grid_x, grid_y] = (
giou.detach().clamp(0).type(target_object.dtype)
)
# class loss
target_class_mask = torch.zeros_like(predictions[:, 5:], device=device)
target_class_mask[range(image.shape[0]), target_classes[i]] = 1.0
class_loss += self.class_loss_fn(predictions[:, 5:], target_class_mask)
object_loss += (
self.obj_loss_fn(pred[..., 4], target_object) * object_loss_balance[i]
)
# scale losses
box_loss *= 0.05
class_loss *= 0.5
batch_size = preds[0].shape[0]
loss = batch_size * (box_loss + object_loss + class_loss)
return loss, box_loss, object_loss, class_loss
[docs] def forward(self, data: Any, pred: Any) -> Dict[str, Tensor]:
"""
:param data: the input data to the model, expected to contain the labels
:param pred: the predicted output from the model
:return: a dictionary containing all calculated losses (default, giou,
object, and class) and metrics with the loss from the loss_fn at
DEFAULT_LOSS_KEY
"""
loss, box_loss, object_loss, class_loss = self._loss_fn(
self.get_preds(data, pred, DEFAULT_LOSS_KEY),
self.get_labels(data, pred, DEFAULT_LOSS_KEY),
)
calculated = {
DEFAULT_LOSS_KEY: loss,
"giou": box_loss,
"object": object_loss,
"classification": class_loss,
}
if self._extras:
for extra, func in self._extras.items():
calculated[extra] = func(
self.get_preds(data, pred, extra),
self.get_labels(data, pred, extra),
)
return calculated
[docs] def get_preds(self, data: Any, pred: List[Tensor], name: str) -> List[Tensor]:
"""
Override get_preds for SSD model output.
:param data: data from a data loader
:param pred: the prediction from an ssd model: two tensors
representing object location and object label respectively
:param name: the name of the loss function that is asking for the
information for calculation
:return: the predictions from the model without any changes
"""
return pred
[docs]class Accuracy(Module):
"""
Class for calculating the accuracy for a given prediction and the labels
for comparison.
Expects the inputs to be from a range of 0 to 1 and sets a crossing threshold at 0.5
the labels are similarly rounded.
"""
[docs] def forward(self, pred: Tensor, lab: Tensor) -> Tensor:
"""
:param pred: the models prediction to compare with
:param lab: the labels for the data to compare to
:return: the calculated accuracy
"""
return Accuracy.calculate(pred, lab)
[docs] @staticmethod
def calculate(pred: Tensor, lab: Tensor):
"""
:param pred: the models prediction to compare with
:param lab: the labels for the data to compare to
:return: the calculated accuracy
"""
pred = pred >= 0.5
lab = lab >= 0.5
correct = (pred == lab).sum()
total = lab.numel()
acc = correct.float() / total * 100.0
return acc
[docs]class TopKAccuracy(Module):
"""
Class for calculating the top k accuracy for a given prediction and the labels for
comparison; ie the top1 or top5 accuracy. top1 is equivalent to the Accuracy class
:param topk: the numbers of buckets the model is considered to be correct within
"""
def __init__(self, topk: int = 1):
super(TopKAccuracy, self).__init__()
self._topk = topk
[docs] def forward(self, pred: Tensor, lab: Tensor) -> Tensor:
"""
:param pred: the models prediction to compare with
:param lab: the labels for the data to compare to
:return: the calculated topk accuracy
"""
return TopKAccuracy.calculate(pred, lab, self._topk)
[docs] @staticmethod
def calculate(pred: Tensor, lab: Tensor, topk: int):
"""
:param pred: the models prediction to compare with
:param lab: the labels for the data to compare to
:param topk: the number of bins to be within for the correct label
:return: the calculated topk accuracy
"""
with torch.no_grad():
batch_size = lab.size(0)
_, pred = pred.topk(topk, 1, True, True)
pred = pred.t()
correct = pred.eq(lab.view(1, -1).expand_as(pred))
correct_k = (
correct[:topk].contiguous().view(-1).float().sum(0, keepdim=True)
)
correct_k = correct_k.mul_(100.0 / batch_size)
return correct_k
```