Source code for sparseml.pytorch.models.classification.mnist

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Simple PyTorch implementations for the MNIST dataset.
"""

from torch import Tensor
from torch.nn import (
    AdaptiveAvgPool2d,
    BatchNorm2d,
    Conv2d,
    Linear,
    Module,
    Sequential,
    Sigmoid,
    Softmax,
)

from sparseml.pytorch.models.registry import ModelRegistry
from sparseml.pytorch.nn import ReLU


__all__ = ["MnistNet", "mnist_net"]


class _ConvBNRelu(Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        padding: int,
    ):
        super().__init__()
        self.conv = Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias=False,
        )
        self.bn = BatchNorm2d(out_channels)
        self.act = ReLU(num_channels=out_channels, inplace=True)

    def forward(self, inp: Tensor):
        out = self.conv(inp)
        out = self.bn(out)
        out = self.act(out)

        return out


class _Classifier(Module):
    def __init__(self, in_channels: int, classes: int, class_type: str):
        super().__init__()
        self.avgpool = AdaptiveAvgPool2d(1)
        self.fc = Linear(in_channels, classes)

        if class_type == "single":
            self.softmax = Softmax(dim=1)
        elif class_type == "multi":
            self.softmax = Sigmoid()
        else:
            raise ValueError("unknown class_type given of {}".format(class_type))

    def forward(self, inp: Tensor):
        out = self.avgpool(inp)
        out = out.view(out.size(0), -1)
        logits = self.fc(out)
        classes = self.softmax(logits)

        return logits, classes



class MnistNet(Module):
    """
    A simple convolutional model created for the MNIST dataset

    :param num_classes: the number of classes to classify
    :param class_type: one of [single, multi] to support multi class training;
        default single
    """

    def __init__(
        self,
        num_classes: int = 10,
        class_type: str = "single",
    ):
        super().__init__()
        self.blocks = Sequential(
            _ConvBNRelu(
                in_channels=1, out_channels=16, kernel_size=5, padding=2, stride=1
            ),
            _ConvBNRelu(
                in_channels=16, out_channels=32, kernel_size=5, padding=2, stride=2
            ),
            _ConvBNRelu(
                in_channels=32, out_channels=64, kernel_size=5, padding=2, stride=1
            ),
            _ConvBNRelu(
                in_channels=64, out_channels=128, kernel_size=5, padding=2, stride=2
            ),
        )
        self.classifier = _Classifier(
            in_channels=128, classes=num_classes, class_type=class_type
        )

    def forward(self, inp: Tensor):
        out = self.blocks(inp)
        logits, classes = self.classifier(out)

        return logits, classes


@ModelRegistry.register(
    key=["mnistnet"],
    input_shape=(1, 28, 28),
    domain="cv",
    sub_domain="classification",
    architecture="mnistnet",
    sub_architecture=None,
    default_dataset="mnist",
    default_desc="base",
)
def mnist_net(num_classes: int = 10, class_type: str = "single") -> MnistNet:
    """
    MnistNet implementation; expected input shape is (B, 1, 28, 28)

    :param num_classes: the number of classes to classify
    :param class_type: one of [single, multi] to support multi class training;
        default single
    :return: The created MnistNet Module
    """
    return MnistNet(num_classes, class_type)
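

# The usage sketch below is not part of the original module; it is a minimal,
# hypothetical example assuming only torch and the definitions above. It builds
# the network directly via MnistNet (mnist_net() is the registry-registered
# constructor for the same class) and runs a dummy forward pass on an input of
# the expected (B, 1, 28, 28) shape to show the (logits, classes) output pair.
if __name__ == "__main__":
    import torch

    model = MnistNet(num_classes=10, class_type="single")
    model.eval()

    with torch.no_grad():
        # batch of 8 random "images"; values are arbitrary, shape matches MNIST
        dummy_input = torch.randn(8, 1, 28, 28)
        logits, classes = model(dummy_input)

    # logits are the raw Linear outputs; classes are the softmax probabilities
    print(logits.shape, classes.shape)  # torch.Size([8, 10]) torch.Size([8, 10])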