Source code for sparseml.pytorch.models.classification.vgg

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
PyTorch VGG implementations.
Further info can be found in the paper `here <https://arxiv.org/abs/1409.1556>`__.
"""

from typing import List

from torch import Tensor
from torch.nn import (
    BatchNorm2d,
    Conv2d,
    Dropout,
    Linear,
    MaxPool2d,
    Module,
    Sequential,
    Sigmoid,
    Softmax,
    init,
)

from sparseml.pytorch.models.registry import ModelRegistry
from sparseml.pytorch.nn import ReLU


__all__ = [
    "VGG",
    "vgg11",
    "vgg11bn",
    "vgg13",
    "vgg13bn",
    "vgg16",
    "vgg16bn",
    "vgg19bn",
    "vgg19",
]


def _init_conv(conv: Conv2d):
    init.kaiming_normal_(conv.weight, mode="fan_out", nonlinearity="relu")

    if conv.bias is not None:
        init.constant_(conv.bias, 0)


def _init_batch_norm(norm: BatchNorm2d):
    init.constant_(norm.weight, 1.0)
    init.constant_(norm.bias, 0.0)


def _init_linear(linear: Linear):
    init.normal_(linear.weight, 0, 0.01)
    init.constant_(linear.bias, 0)


class _Block(Module):
    def __init__(self, in_channels: int, out_channels: int, batch_norm: bool):
        super().__init__()
        self.conv = Conv2d(
            in_channels, out_channels, kernel_size=3, padding=1, stride=1
        )
        self.bn = BatchNorm2d(out_channels) if batch_norm else None
        self.act = ReLU(num_channels=out_channels, inplace=True)

        self.initialize()

    def forward(self, inp: Tensor):
        out = self.conv(inp)

        if self.bn is not None:
            out = self.bn(out)

        out = self.act(out)

        return out

    def initialize(self):
        _init_conv(self.conv)

        if self.bn is not None:
            _init_batch_norm(self.bn)


class _Classifier(Module):
    def __init__(self, in_channels: int, num_classes: int, class_type: str = "single"):
        super().__init__()
        self.mlp = Sequential(
            Linear(in_channels * 7 * 7, 4096),
            Dropout(),
            ReLU(num_channels=4096, inplace=True),
            Linear(4096, 4096),
            Dropout(),
            ReLU(num_channels=4096, inplace=True),
            Linear(4096, num_classes),
        )

        if class_type == "single":
            self.softmax = Softmax(dim=1)
        elif class_type == "multi":
            self.softmax = Sigmoid()
        else:
            raise ValueError("unknown class_type given of {}".format(class_type))

    def forward(self, inp: Tensor):
        out = inp.view(inp.size(0), -1)
        logits = self.mlp(out)
        classes = self.softmax(logits)

        return logits, classes


class VGGSectionSettings(object):
    """
    Settings to describe how to put together a VGG architecture
    using user supplied configurations.

    :param num_blocks: the number of blocks to put in the section (conv [bn] relu)
    :param in_channels: the number of input channels to the section
    :param out_channels: the number of output channels from the section
    :param use_batchnorm: True to put batchnorm after each conv, False otherwise
    """

    def __init__(
        self, num_blocks: int, in_channels: int, out_channels: int, use_batchnorm: bool
    ):
        self.num_blocks = num_blocks
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_batchnorm = use_batchnorm



class VGG(Module):
    """
    VGG implementation

    :param sec_settings: the settings for each section in the vgg model
    :param num_classes: the number of classes to classify
    :param class_type: one of [single, multi] to support multi class training;
        default single
    """

    def __init__(
        self,
        sec_settings: List[VGGSectionSettings],
        num_classes: int,
        class_type: str,
    ):
        super(VGG, self).__init__()
        self.sections = Sequential(
            *[VGG.create_section(settings) for settings in sec_settings]
        )
        self.classifier = _Classifier(
            sec_settings[-1].out_channels, num_classes, class_type
        )

    def forward(self, inp):
        out = self.sections(inp)
        logits, classes = self.classifier(out)

        return logits, classes

    @staticmethod
    def create_section(settings: VGGSectionSettings) -> Sequential:
        blocks = []
        in_channels = settings.in_channels

        for _ in range(settings.num_blocks):
            blocks.append(
                _Block(in_channels, settings.out_channels, settings.use_batchnorm)
            )
            in_channels = settings.out_channels

        blocks.append(MaxPool2d(kernel_size=2, stride=2))

        return Sequential(*blocks)
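

# Illustrative sketch (not part of the original module): a VGG can be assembled
# directly from VGGSectionSettings instead of the registered constructors below.
# Because _Classifier hardcodes a 7 x 7 spatial map and every section ends in a
# 2 x 2 max pool, a two-section model like this one expects 28 x 28 inputs.
def _example_two_section_vgg() -> VGG:
    sections = [
        VGGSectionSettings(
            num_blocks=1, in_channels=3, out_channels=32, use_batchnorm=True
        ),
        VGGSectionSettings(
            num_blocks=1, in_channels=32, out_channels=64, use_batchnorm=True
        ),
    ]

    return VGG(sec_settings=sections, num_classes=10, class_type="single")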


@ModelRegistry.register(
    key=["vgg11", "vgg_11", "vgg-11"],
    input_shape=(3, 224, 224),
    domain="cv",
    sub_domain="classification",
    architecture="vgg",
    sub_architecture="11",
    default_dataset="imagenet",
    default_desc="base",
    def_ignore_error_tensors=["classifier.mlp.6.weight", "classifier.mlp.6.bias"],
)
def vgg11(num_classes: int = 1000, class_type: str = "single") -> VGG:
    """
    Standard VGG 11; expected input shape is (B, 3, 224, 224)

    :param num_classes: the number of classes to classify
    :param class_type: one of [single, multi] to support multi class training;
        default single
    :return: The created VGG Module
    """
    sec_settings = [
        VGGSectionSettings(
            num_blocks=1, in_channels=3, out_channels=64, use_batchnorm=False
        ),
        VGGSectionSettings(
            num_blocks=1, in_channels=64, out_channels=128, use_batchnorm=False
        ),
        VGGSectionSettings(
            num_blocks=2, in_channels=128, out_channels=256, use_batchnorm=False
        ),
        VGGSectionSettings(
            num_blocks=2, in_channels=256, out_channels=512, use_batchnorm=False
        ),
        VGGSectionSettings(
            num_blocks=2, in_channels=512, out_channels=512, use_batchnorm=False
        ),
    ]

    return VGG(
        sec_settings=sec_settings, num_classes=num_classes, class_type=class_type
    )


@ModelRegistry.register(
    key=["vgg11bn", "vgg_11bn", "vgg-11bn"],
    input_shape=(3, 224, 224),
    domain="cv",
    sub_domain="classification",
    architecture="vgg",
    sub_architecture="11_bn",
    default_dataset="imagenet",
    default_desc="base",
    def_ignore_error_tensors=["classifier.mlp.6.weight", "classifier.mlp.6.bias"],
)
def vgg11bn(num_classes: int = 1000, class_type: str = "single") -> VGG:
    """
    VGG 11 with batch norm added; expected input shape is (B, 3, 224, 224)

    :param num_classes: the number of classes to classify
    :param class_type: one of [single, multi] to support multi class training;
        default single
    :return: The created VGG Module
    """
    sec_settings = [
        VGGSectionSettings(
            num_blocks=1, in_channels=3, out_channels=64, use_batchnorm=True
        ),
        VGGSectionSettings(
            num_blocks=1, in_channels=64, out_channels=128, use_batchnorm=True
        ),
        VGGSectionSettings(
            num_blocks=2, in_channels=128, out_channels=256, use_batchnorm=True
        ),
        VGGSectionSettings(
            num_blocks=2, in_channels=256, out_channels=512, use_batchnorm=True
        ),
        VGGSectionSettings(
            num_blocks=2, in_channels=512, out_channels=512, use_batchnorm=True
        ),
    ]

    return VGG(
        sec_settings=sec_settings, num_classes=num_classes, class_type=class_type
    )


@ModelRegistry.register(
    key=["vgg13", "vgg_13", "vgg-13"],
    input_shape=(3, 224, 224),
    domain="cv",
    sub_domain="classification",
    architecture="vgg",
    sub_architecture="13",
    default_dataset="imagenet",
    default_desc="base",
    def_ignore_error_tensors=["classifier.mlp.6.weight", "classifier.mlp.6.bias"],
)
def vgg13(num_classes: int = 1000, class_type: str = "single") -> VGG:
    """
    Standard VGG 13; expected input shape is (B, 3, 224, 224)

    :param num_classes: the number of classes to classify
    :param class_type: one of [single, multi] to support multi class training;
        default single
    :return: The created VGG Module
    """
    sec_settings = [
        VGGSectionSettings(
            num_blocks=2, in_channels=3, out_channels=64, use_batchnorm=False
        ),
        VGGSectionSettings(
            num_blocks=2, in_channels=64, out_channels=128, use_batchnorm=False
        ),
        VGGSectionSettings(
            num_blocks=2, in_channels=128, out_channels=256, use_batchnorm=False
        ),
        VGGSectionSettings(
            num_blocks=2, in_channels=256, out_channels=512, use_batchnorm=False
        ),
        VGGSectionSettings(
            num_blocks=2, in_channels=512, out_channels=512, use_batchnorm=False
        ),
    ]

    return VGG(
        sec_settings=sec_settings, num_classes=num_classes, class_type=class_type
    )


@ModelRegistry.register(
    key=["vgg13bn", "vgg_13bn", "vgg-13bn"],
    input_shape=(3, 224, 224),
    domain="cv",
    sub_domain="classification",
    architecture="vgg",
    sub_architecture="13_bn",
    default_dataset="imagenet",
    default_desc="base",
    def_ignore_error_tensors=["classifier.mlp.6.weight", "classifier.mlp.6.bias"],
)
def vgg13bn(num_classes: int = 1000, class_type: str = "single") -> VGG:
    """
    VGG 13 with batch norm added; expected input shape is (B, 3, 224, 224)

    :param num_classes: the number of classes to classify
    :param class_type: one of [single, multi] to support multi class training;
        default single
    :return: The created VGG Module
    """
    sec_settings = [
        VGGSectionSettings(
            num_blocks=2, in_channels=3, out_channels=64, use_batchnorm=True
        ),
        VGGSectionSettings(
            num_blocks=2, in_channels=64, out_channels=128, use_batchnorm=True
        ),
        VGGSectionSettings(
            num_blocks=2, in_channels=128, out_channels=256, use_batchnorm=True
        ),
        VGGSectionSettings(
            num_blocks=2, in_channels=256, out_channels=512, use_batchnorm=True
        ),
        VGGSectionSettings(
            num_blocks=2, in_channels=512, out_channels=512, use_batchnorm=True
        ),
    ]

    return VGG(
        sec_settings=sec_settings, num_classes=num_classes, class_type=class_type
    )


@ModelRegistry.register(
    key=["vgg16", "vgg_16", "vgg-16"],
    input_shape=(3, 224, 224),
    domain="cv",
    sub_domain="classification",
    architecture="vgg",
    sub_architecture="16",
    default_dataset="imagenet",
    default_desc="base",
    def_ignore_error_tensors=["classifier.mlp.6.weight", "classifier.mlp.6.bias"],
)
def vgg16(num_classes: int = 1000, class_type: str = "single") -> VGG:
    """
    Standard VGG 16; expected input shape is (B, 3, 224, 224)

    :param num_classes: the number of classes to classify
    :param class_type: one of [single, multi] to support multi class training;
        default single
    :return: The created VGG Module
    """
    sec_settings = [
        VGGSectionSettings(
            num_blocks=2, in_channels=3, out_channels=64, use_batchnorm=False
        ),
        VGGSectionSettings(
            num_blocks=2, in_channels=64, out_channels=128, use_batchnorm=False
        ),
        VGGSectionSettings(
            num_blocks=3, in_channels=128, out_channels=256, use_batchnorm=False
        ),
        VGGSectionSettings(
            num_blocks=3, in_channels=256, out_channels=512, use_batchnorm=False
        ),
        VGGSectionSettings(
            num_blocks=3, in_channels=512, out_channels=512, use_batchnorm=False
        ),
    ]

    return VGG(
        sec_settings=sec_settings, num_classes=num_classes, class_type=class_type
    )


@ModelRegistry.register(
    key=["vgg16bn", "vgg_16bn", "vgg-16bn"],
    input_shape=(3, 224, 224),
    domain="cv",
    sub_domain="classification",
    architecture="vgg",
    sub_architecture="16_bn",
    default_dataset="imagenet",
    default_desc="base",
    def_ignore_error_tensors=["classifier.mlp.6.weight", "classifier.mlp.6.bias"],
)
def vgg16bn(num_classes: int = 1000, class_type: str = "single") -> VGG:
    """
    VGG 16 with batch norm added; expected input shape is (B, 3, 224, 224)

    :param num_classes: the number of classes to classify
    :param class_type: one of [single, multi] to support multi class training;
        default single
    :return: The created VGG Module
    """
    sec_settings = [
        VGGSectionSettings(
            num_blocks=2, in_channels=3, out_channels=64, use_batchnorm=True
        ),
        VGGSectionSettings(
            num_blocks=2, in_channels=64, out_channels=128, use_batchnorm=True
        ),
        VGGSectionSettings(
            num_blocks=3, in_channels=128, out_channels=256, use_batchnorm=True
        ),
        VGGSectionSettings(
            num_blocks=3, in_channels=256, out_channels=512, use_batchnorm=True
        ),
        VGGSectionSettings(
            num_blocks=3, in_channels=512, out_channels=512, use_batchnorm=True
        ),
    ]

    return VGG(
        sec_settings=sec_settings, num_classes=num_classes, class_type=class_type
    )


@ModelRegistry.register(
    key=["vgg19", "vgg_19", "vgg-19"],
    input_shape=(3, 224, 224),
    domain="cv",
    sub_domain="classification",
    architecture="vgg",
    sub_architecture="19",
    default_dataset="imagenet",
    default_desc="base",
    def_ignore_error_tensors=["classifier.mlp.6.weight", "classifier.mlp.6.bias"],
)
def vgg19(num_classes: int = 1000, class_type: str = "single") -> VGG:
    """
    Standard VGG 19; expected input shape is (B, 3, 224, 224)

    :param num_classes: the number of classes to classify
    :param class_type: one of [single, multi] to support multi class training;
        default single
    :return: The created VGG Module
    """
    sec_settings = [
        VGGSectionSettings(
            num_blocks=2, in_channels=3, out_channels=64, use_batchnorm=False
        ),
        VGGSectionSettings(
            num_blocks=2, in_channels=64, out_channels=128, use_batchnorm=False
        ),
        VGGSectionSettings(
            num_blocks=4, in_channels=128, out_channels=256, use_batchnorm=False
        ),
        VGGSectionSettings(
            num_blocks=4, in_channels=256, out_channels=512, use_batchnorm=False
        ),
        VGGSectionSettings(
            num_blocks=4, in_channels=512, out_channels=512, use_batchnorm=False
        ),
    ]

    return VGG(
        sec_settings=sec_settings, num_classes=num_classes, class_type=class_type
    )


@ModelRegistry.register(
    key=["vgg19bn", "vgg_19bn", "vgg-19bn"],
    input_shape=(3, 224, 224),
    domain="cv",
    sub_domain="classification",
    architecture="vgg",
    sub_architecture="19_bn",
    default_dataset="imagenet",
    default_desc="base",
    def_ignore_error_tensors=["classifier.mlp.6.weight", "classifier.mlp.6.bias"],
)
def vgg19bn(num_classes: int = 1000, class_type: str = "single") -> VGG:
    """
    VGG 19 with batch norm added; expected input shape is (B, 3, 224, 224)

    :param num_classes: the number of classes to classify
    :param class_type: one of [single, multi] to support multi class training;
        default single
    :return: The created VGG Module
    """
    sec_settings = [
        VGGSectionSettings(
            num_blocks=2, in_channels=3, out_channels=64, use_batchnorm=True
        ),
        VGGSectionSettings(
            num_blocks=2, in_channels=64, out_channels=128, use_batchnorm=True
        ),
        VGGSectionSettings(
            num_blocks=4, in_channels=128, out_channels=256, use_batchnorm=True
        ),
        VGGSectionSettings(
            num_blocks=4, in_channels=256, out_channels=512, use_batchnorm=True
        ),
        VGGSectionSettings(
            num_blocks=4, in_channels=512, out_channels=512, use_batchnorm=True
        ),
    ]

    return VGG(
        sec_settings=sec_settings, num_classes=num_classes, class_type=class_type
    )
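

# Illustrative usage sketch (not part of the original module): build one of the
# registered architectures and run a dummy forward pass. This assumes the
# registered constructor can be called directly; inputs are expected with shape
# (B, 3, 224, 224), and forward returns both raw logits and softmax scores.
if __name__ == "__main__":
    import torch

    model = vgg16bn(num_classes=1000, class_type="single")
    logits, classes = model(torch.randn(1, 3, 224, 224))
    print(logits.shape, classes.shape)  # both torch.Size([1, 1000])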