# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
General dataset implementations for Keras
"""

from abc import ABCMeta, abstractmethod

import tensorflow

# Public API of this module; Dataset is the only public definition here.
__all__ = [
    "Dataset",
]

class Dataset(metaclass=ABCMeta):
    """
    Generic dataset implementation for Keras.
    Expected to work with the tf.data APIs.

    Sub classes supply the raw dataset (creator), the per-element processing
    function (processor) and the dataset size (__len__); build then assembles
    the shuffle / map / repeat / batch / prefetch pipeline on top of them.
    """

    @abstractmethod
    def __len__(self):
        """
        Implemented by sub classes.

        :return: the size of the dataset as defined by the sub class
        """
        raise NotImplementedError()

    # Annotations that reference tensorflow are quoted so they are not
    # evaluated when the class body is executed.
    def build(
        self,
        batch_size: int,
        repeat_count: int = None,
        shuffle_buffer_size: int = None,
        prefetch_buffer_size: int = None,
        num_parallel_calls: int = None,
    ) -> "tensorflow.data.Dataset":
        """
        Create the dataset in the current graph using tf.data APIs

        :param batch_size: the batch size to create the dataset for
        :param repeat_count: the number of times to repeat the dataset,
            if unset or None, will repeat indefinitely
        :param shuffle_buffer_size: None if not shuffling,
            otherwise the size of the buffer to use for shuffling data
        :param prefetch_buffer_size: None if not prefetching,
            otherwise the size of the buffer to use for buffering
        :param num_parallel_calls: the number of parallel calls to run the
            processor function with
        :return: a tf.data.Dataset instance
        """
        dataset = self.creator()

        # None and non-positive sizes both disable shuffling
        if shuffle_buffer_size and shuffle_buffer_size > 0:
            dataset = dataset.shuffle(
                shuffle_buffer_size, reshuffle_each_iteration=True
            )

        # Run the sub class processing function over every element
        dataset = dataset.map(self.processor, num_parallel_calls=num_parallel_calls)

        # Together with shuffling above, putting batch after repeat yields
        # batches that straddle epoch boundaries
        dataset = dataset.repeat(repeat_count)
        dataset = dataset.batch(batch_size)

        # None and non-positive sizes both disable prefetching
        if prefetch_buffer_size and prefetch_buffer_size > 0:
            dataset = dataset.prefetch(prefetch_buffer_size)

        return dataset

    @abstractmethod
    def creator(self) -> "tensorflow.data.Dataset":
        """
        Implemented by sub classes to create a dataset for the given impl.

        :return: a created dataset
        """
        raise NotImplementedError()

    @abstractmethod
    def processor(self, *args, **kwargs):
        """
        Implemented by sub classes to parallelize and map processing functions
        for loading the data of the dataset into memory.

        :param args: generic inputs for processing
        :param kwargs: generic inputs for processing
        :return: the processed tensors
        """
        raise NotImplementedError()