Source code for sparseml.pytorch.datasets.generic

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import Tuple, Union

import torch
from torch.utils.data import Dataset


__all__ = [
    "EarlyStopDataset",
    "NoisyDataset",
    "RandNDataset",
    "CacheableDataset",
]


class EarlyStopDataset(Dataset):
    """
    Dataset wrapper that applies an early stop when iterating through
    the dataset; i.e., allows indexing in [0, early_stop).

    :param original: the original dataset to apply an early stop to
    :param early_stop: the total number of data items to run through;
        if -1, the whole dataset is used
    """

    def __init__(self, original: Dataset, early_stop: int):
        self._original = original
        self._early_stop = early_stop

        if self._early_stop > len(self._original):
            raise ValueError(
                (
                    "Cannot apply early stop of {}, "
                    "it is greater than the dataset length of {}"
                ).format(self._early_stop, len(self._original))
            )

    def __getitem__(self, index):
        return self._original.__getitem__(index)

    def __len__(self):
        return self._early_stop if self._early_stop > 0 else self._original.__len__()

    def __repr__(self):
        rep = self._original.__str__()
        rep = re.sub(
            r"Number of datapoints:[ ]+[0-9]+",
            "Number of datapoints: {}".format(self.__len__()),
            rep,
        )

        return rep
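
# Illustrative usage sketch (not part of the original module): cap a synthetic
# 10,000-sample TensorDataset at its first 1,000 items.
#
#   >>> from torch.utils.data import TensorDataset
#   >>> full = TensorDataset(torch.randn(10000, 3, 32, 32), torch.zeros(10000))
#   >>> limited = EarlyStopDataset(full, early_stop=1000)
#   >>> len(limited)
#   1000
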

class NoisyDataset(Dataset):
    """
    Dataset wrapper that adds random noise drawn from a normal (Gaussian)
    distribution with mean 0 and standard deviation intensity on top of
    another dataset's inputs.

    :param original: the dataset to add noise on top of
    :param intensity: the level of noise to add (the standard deviation
        of the generated noise)
    """

    def __init__(self, original: Dataset, intensity: float):
        self._original = original
        self._intensity = intensity

    def __getitem__(self, index):
        x_tens, y_tens = self._original.__getitem__(index)
        noise = torch.zeros(x_tens.size()).normal_(mean=0.0, std=self._intensity)
        # add the noise out of place so tensor-backed datasets
        # (e.g. TensorDataset) are not mutated across repeated reads
        x_tens = x_tens + noise

        return x_tens, y_tens

    def __len__(self):
        return self._original.__len__()
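
# Illustrative usage sketch: add Gaussian noise with a standard deviation of
# 0.1 to every input while leaving the labels untouched.
#
#   >>> from torch.utils.data import TensorDataset
#   >>> base = TensorDataset(torch.zeros(100, 3, 32, 32), torch.zeros(100))
#   >>> noisy = NoisyDataset(base, intensity=0.1)
#   >>> x, y = noisy[0]
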

class RandNDataset(Dataset):
    """
    Generates a random dataset

    :param length: the number of random items to create in the dataset
    :param shape: the shape of the data to create
    :param normalize: True to shift the random data toward the ImageNet
        per-channel distribution (shape must match (3, x, x))
    """

    def __init__(
        self, length: int, shape: Union[int, Tuple[int, ...]], normalize: bool
    ):
        if isinstance(shape, int):
            shape = (3, shape, shape)

        self._data = []

        for _ in range(length):
            tens = torch.randn(*shape)

            if normalize:
                # scale and shift the standard normal samples toward the
                # ImageNet per-channel statistics (assumed means and stds,
                # used here as a sketch of the documented behavior)
                means = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
                stds = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
                tens = tens * stds + means

            self._data.append(tens)

    def __getitem__(self, index):
        return self._data[index], torch.tensor(1)

    def __len__(self):
        return len(self._data)
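
# Illustrative usage sketch: 512 random ImageNet-sized inputs, each paired
# with the constant label 1 (handy for benchmarking model throughput).
#
#   >>> fake_imagenet = RandNDataset(length=512, shape=224, normalize=True)
#   >>> x, y = fake_imagenet[0]
#   >>> tuple(x.shape)
#   (3, 224, 224)
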

class CacheableDataset(Dataset):
    """
    Generates a cacheable dataset, i.e., stores the data in an in-memory
    CPU cache so it does not have to be loaded from disk every time.
    Note, this can only be used with a data loader that has num_workers=0,
    since each worker process would otherwise hold its own copy of the cache.

    :param original: the original dataset to cache
    """

    def __init__(self, original: Dataset):
        self._original = original
        self._cache = {}

    def __getitem__(self, index):
        if index not in self._cache:
            self._cache[index] = self._original[index]

        return self._cache[index]

    def __len__(self):
        return self._original.__len__()
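
# Illustrative usage sketch: cache a dataset in memory (RandNDataset stands in
# for a slow disk-backed dataset); the DataLoader must use num_workers=0 so
# every read shares this process's cache.
#
#   >>> from torch.utils.data import DataLoader
#   >>> slow = RandNDataset(length=512, shape=224, normalize=False)
#   >>> cached = CacheableDataset(slow)
#   >>> loader = DataLoader(cached, batch_size=64, num_workers=0)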