Source code for datasets.transforms.denormalization

# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


class DeNormalize:
    """Callable transform that reverts a per-channel normalization.

    Each slice along the first dimension of the input is multiplied by the
    matching entry of ``std`` and then shifted by the matching entry of
    ``mean``.
    """

    def __init__(self, mean, std):
        """
        Stores the per-channel statistics used to undo the normalization.

        Args:
            mean (list): Per-channel mean values, one entry per channel.
            std (list): Per-channel standard deviation values, one entry per channel.
        """
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Denormalizes the given tensor channel by channel, in place.

        The input is modified directly and the very same object is returned,
        so the caller's tensor is overwritten; pass a clone to keep the
        normalized data.

        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be denormalized.

        Returns:
            Tensor: The input tensor, now holding the denormalized image.
        """
        # zip() stops at the shortest of its arguments: channels beyond
        # len(mean) or len(std) are left untouched.
        for channel, offset, scale in zip(tensor, self.mean, self.std):
            # Iterating a tensor yields views of its first-dimension slices,
            # so the in-place ops below write through to ``tensor``.
            scaled = channel.mul_(scale)
            scaled.add_(offset)
        return tensor