This page explains how to apply a recipe to a custom model. For more details on the concepts of pruning/quantization as well as how to create recipes, see Sparsifying a Model for SparseML Integrations.
In addition to the supported integrations described on the prior page, SparseML is designed to integrate easily into custom training pipelines. This flexibility enables sparsification of any neural network architecture for custom models and use cases. Once SparseML is installed, the necessary code can be plugged into most PyTorch/Keras training pipelines with only a few lines of code.
This section requires the SparseML Torchvision Install to run the Apply the Recipe section.
To enable sparsification of models with recipes, a few edits need to be made to the training pipeline code. Specifically, a `ScheduledModifierManager` instance is used to take over and inject the desired sparsification algorithms into the training process. To do this properly in PyTorch, the `ScheduledModifierManager` requires the instance of the `model` to modify, the `optimizer` used for training, and the number of `steps_per_epoch` to ensure algorithms are applied at the right time. For the integration, the following code illustrates all that is needed:
```python
from sparseml.pytorch.optim import ScheduledModifierManager

manager = ScheduledModifierManager.from_yaml(recipe_path)
optimizer = manager.modify(model, optimizer, steps_per_epoch)

# your typical training loop, using model/optimizer as usual

manager.finalize(model)
```
Walking through this code:

- `ScheduledModifierManager` is imported from the SparseML Python package.
- The `ScheduledModifierManager` is created from a recipe stored as a local file or on the SparseZoo.
- The model and optimizer are passed to `manager.modify` so that the recipe will be applied while training. A wrapped instance of the training optimizer is returned.
- After training completes, `manager.finalize` is called with the model to release all resources.

A simple training example utilizing PyTorch and Torchvision with this SparseML integration is provided below:
```python
import torch
from torch.nn import Linear
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import SGD

from sparseml.pytorch.models import resnet50
from sparseml.pytorch.datasets import ImagenetteDataset, ImagenetteSize
from sparseml.pytorch.optim import ScheduledModifierManager

# Model creation
NUM_CLASSES = 10  # number of Imagenette classes
model = resnet50(pretrained=True, num_classes=NUM_CLASSES)

# Dataset creation
batch_size = 64
train_dataset = ImagenetteDataset(train=True, dataset_size=ImagenetteSize.s320, image_size=224)
train_loader = DataLoader(train_dataset, batch_size, shuffle=True, pin_memory=True, num_workers=8)

# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Loss setup
criterion = CrossEntropyLoss()
optimizer = SGD(model.parameters(), lr=10e-6, momentum=0.9)

# Recipe - in this case, we pull down a recipe from the SparseZoo for ResNet-50
# This can also be a path to a local file
recipe_path = "zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/pruned95_quant-none?recipe_type=original"

# SparseML Integration
manager = ScheduledModifierManager.from_yaml(recipe_path)
optimizer = manager.modify(model, optimizer, steps_per_epoch=len(train_loader))

# Training Loop
for epoch in range(manager.max_epochs):
    running_loss = 0.0
    running_corrects = 0.0
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        with torch.set_grad_enabled(True):
            outputs, _ = model(inputs)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)
            loss.backward()
            optimizer.step()
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)

    epoch_loss = running_loss / len(train_loader.dataset)
    epoch_acc = running_corrects.double() / len(train_loader.dataset)
    print("Training Loss: {:.4f} Acc: {:.4f}".format(epoch_loss, epoch_acc))

manager.finalize(model)
```
To dive into the details of this recipe and how to edit it, visit Supported Integrations. The resulting recipe is included here for easy integration and testing.
```yaml
modifiers:
    - !GlobalMagnitudePruningModifier
        init_sparsity: 0.05
        final_sparsity: 0.8
        start_epoch: 0.0
        end_epoch: 30.0
        update_frequency: 1.0
        params: __ALL_PRUNABLE__

    - !SetLearningRateModifier
        start_epoch: 0.0
        learning_rate: 0.05

    - !LearningRateFunctionModifier
        start_epoch: 30.0
        end_epoch: 50.0
        lr_func: cosine
        init_lr: 0.05
        final_lr: 0.001

    - !QuantizationModifier
        start_epoch: 50.0
        freeze_bn_stats_epoch: 53.0

    - !SetLearningRateModifier
        start_epoch: 50.0
        learning_rate: 10e-6

    - !EpochRangeModifier
        start_epoch: 0.0
        end_epoch: 55.0
```
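To see how this recipe maps onto the training loop above, it can be loaded and inspected on its own. The snippet below is a minimal sketch; it assumes the recipe has been saved to a local file named `recipe.yaml` (as described below), though the SparseZoo stub from the training example could be passed instead:

```python
from sparseml.pytorch.optim import ScheduledModifierManager

# Load the recipe by itself; "recipe.yaml" is assumed to be a local copy of
# the recipe shown above (a SparseZoo stub would also work here).
manager = ScheduledModifierManager.from_yaml("recipe.yaml")

# The training example iterates over manager.max_epochs; for this recipe it
# should report 55, matching the EpochRangeModifier's end_epoch.
print(manager.max_epochs)
```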
The pipeline is ready to sparsify a model with the integration and recipe setup. To begin sparsifying, save the recipe as a local file called `recipe.yaml`. Next, pass the path to that file into the training script as the `recipe_path` argument for the `ScheduledModifierManager.from_yaml(recipe_path)` line (one way to wire this up is sketched below).
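How the recipe path reaches the training script depends on the pipeline; a common pattern is to expose it as a command-line argument. The sketch below is illustrative only (the flag name and default are not part of SparseML):

```python
import argparse

from sparseml.pytorch.optim import ScheduledModifierManager

# Illustrative only: expose the recipe path as a command-line argument so the
# same training script can run with different recipes or SparseZoo stubs.
parser = argparse.ArgumentParser(description="Train a model with a SparseML recipe")
parser.add_argument("--recipe-path", default="recipe.yaml",
                    help="Path to a local recipe file or a SparseZoo stub")
args = parser.parse_args()

manager = ScheduledModifierManager.from_yaml(args.recipe_path)
# Create the model, optimizer, and train_loader as in the example above, then:
# optimizer = manager.modify(model, optimizer, steps_per_epoch=len(train_loader))
```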
With that completed, start the training pipeline, and the result will be a sparsified model.