Source code for sparseml.keras.datasets.helpers

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
General utilities for dataset implementations for Keras
"""

from typing import Tuple

import tensorflow


# Public API of this module.
__all__ = ["random_scaling_crop"]


def random_scaling_crop(
    scale_range: Tuple[float, float] = (0.8, 1.0),
    ratio_range: Tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0),
):
    """
    Random crop implementation which also randomly scales the crop taken
    as well as the aspect ratio of the crop.

    Each time the returned function is called, a scale ``s`` is sampled
    uniformly from ``scale_range`` and a ratio ``r`` uniformly from
    ``ratio_range``. For an input of height ``H`` and width ``W``, the crop
    is ``round(W * s)`` wide and ``round(H * s / r)`` tall. Each crop
    dimension is clamped to ``[1, original size]``, so the crop is never
    larger than the input and, for a non-empty input, never empty. The crop
    location is chosen by ``tensorflow.image.random_crop``.

    :param scale_range: the (min, max) of the crop scales to take from the
        orig image; both values must be positive with min <= max
    :param ratio_range: the (min, max) of the aspect ratios to take from the
        orig image; both values must be positive with min <= max
    :return: the callable function for random scaling crop op,
        takes in the image and outputs randomly cropped image
    :raises ValueError: if either range is not a (min, max) pair with
        0 < min <= max
    """
    # Unpack once so the closure does not depend on the caller's (possibly
    # mutable) sequences. Unpacking also rejects non-pair inputs.
    scale_min, scale_max = scale_range
    ratio_min, ratio_max = ratio_range

    # Validate eagerly. A zero or negative ratio would otherwise divide by
    # zero or produce negative crop sizes deep inside the graph. The
    # chained comparison also rejects NaN.
    for name, low, high in (
        ("scale_range", scale_min, scale_max),
        ("ratio_range", ratio_min, ratio_max),
    ):
        if not 0 < low <= high:
            raise ValueError(
                "{} must be a (min, max) pair with 0 < min <= max, "
                "given {}".format(name, (low, high))
            )

    def _to_crop_dim(value, limit):
        # Round a float crop size to int32 and clamp it to [1, limit], so
        # tiny scales or small inputs never produce a zero-sized crop.
        rounded = tensorflow.cast(tensorflow.round(value), tensorflow.int32)
        return tensorflow.minimum(tensorflow.maximum(rounded, 1), limit)

    def rand_crop(img: tensorflow.Tensor):
        # Expects a rank-3 image: index 0 is height, 1 is width, and 2 is
        # carried through unchanged into the crop size.
        orig_shape = tensorflow.shape(img)
        orig_height = orig_shape[0]
        orig_width = orig_shape[1]

        # Sample scalars directly instead of slicing a 1-element tensor.
        scale = tensorflow.random.uniform(
            shape=[], minval=scale_min, maxval=scale_max
        )
        ratio = tensorflow.random.uniform(
            shape=[], minval=ratio_min, maxval=ratio_max
        )

        height = _to_crop_dim(
            tensorflow.cast(orig_height, dtype=tensorflow.float32) * scale / ratio,
            orig_height,
        )
        width = _to_crop_dim(
            tensorflow.cast(orig_width, dtype=tensorflow.float32) * scale,
            orig_width,
        )

        return tensorflow.image.random_crop(img, [height, width, orig_shape[2]])

    return rand_crop