Source code for utils.conf
"""
This module contains utility functions for configuration settings.
"""
# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import sys
import random
import torch
import numpy as np
[docs]
def warn_once(*msg):
"""
Prints a warning message only once.
Args:
msg: the message to be printed
"""
msg = ' '.join([str(m) for m in msg])
if not hasattr(warn_once, 'warned'):
warn_once.warned = set()
if msg not in warn_once.warned:
warn_once.warned.add(msg)
print(msg, file=sys.stderr)
[docs]
def get_device() -> torch.device:
"""
Returns the least used GPU device if available else MPS or CPU.
"""
def _get_device():
# get least used gpu by used memory
if torch.cuda.is_available() and torch.cuda.device_count() > 0:
gpu_memory = []
for i in range(torch.cuda.device_count()):
gpu_memory.append(torch.cuda.memory_allocated(i))
device = torch.device(f'cuda:{np.argmin(gpu_memory)}')
print(f'Using device {device}')
return device
try:
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
print("WARNING: MSP support is still experimental. Use at your own risk!")
return torch.device("mps")
except BaseException:
print("WARNING: Something went wrong with MPS. Using CPU.")
return torch.device("cpu")
# Permanently store the chosen device
if not hasattr(get_device, 'device'):
get_device.device = _get_device()
print(f'Using device {get_device.device}')
return get_device.device
[docs]
def base_path(override=None) -> str:
"""
Returns the base bath where to log accuracies and tensorboard data.
Args:
override: the path to override the default one. Once set, it is stored and used for all the next calls.
Returns:
the base path (default: `./data/`)
"""
if override is not None:
if not os.path.exists(override):
os.makedirs(override)
if not override.endswith('/'):
override += '/'
setattr(base_path, 'path', override)
if not hasattr(base_path, 'path'):
setattr(base_path, 'path', './data/')
return getattr(base_path, 'path')
[docs]
def set_random_seed(seed: int) -> None:
"""
Sets the seeds at a certain value.
Args:
seed: the value to be set
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
try:
torch.cuda.manual_seed_all(seed)
except BaseException:
print('Could not set cuda seed.')
[docs]
def set_random_seed_worker(worker_id) -> None:
"""
Sets the seeds for a worker of a dataloader.
"""
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
random.seed(worker_seed)
[docs]
def create_seeded_dataloader(args, dataset, **dataloader_args) -> torch.utils.data.DataLoader:
"""
Creates a dataloader object from a dataset, setting the seeds for the workers (if `--seed` is set).
Args:
args: the arguments of the program
dataset: the dataset to be loaded
dataloader_args: external arguments of the dataloader
Returns:
the dataloader object
"""
n_cpus = 4 if not hasattr(os, 'sched_getaffinity') else len(os.sched_getaffinity(0))
num_workers = n_cpus if args.num_workers is None else args.num_workers
dataloader_args['num_workers'] = num_workers if 'num_workers' not in dataloader_args else dataloader_args['num_workers']
if args.seed is not None:
worker_generator = torch.Generator()
worker_generator.manual_seed(args.seed)
else:
worker_generator = None
dataloader_args['generator'] = worker_generator if 'generator' not in dataloader_args else dataloader_args['generator']
dataloader_args['worker_init_fn'] = set_random_seed_worker if 'worker_init_fn' not in dataloader_args else dataloader_args['worker_init_fn']
return torch.utils.data.DataLoader(dataset, **dataloader_args)