# Source code for datasets

# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import sys

# Directory two levels above this file (named ``mammoth_path``; presumably the
# project root). It is appended to the import path — presumably so the
# project-level imports below (``datasets.*``, ``utils.*``) resolve — and made
# the working directory, so relative paths such as 'datasets' resolve from it.
mammoth_path = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
)
sys.path.append(mammoth_path)
os.chdir(mammoth_path)

import importlib
import inspect
from argparse import Namespace
from typing import Type

from datasets.utils.continual_dataset import ContinualDataset
from utils.conf import warn_once


def get_all_datasets():
    """
    Return the module names of all the dataset files in the ``datasets`` folder.

    The folder is looked up relative to the current working directory. Only
    files ending in ``.py`` are considered; any name containing ``__`` (such as
    ``__init__.py`` or ``__pycache__``) is skipped.

    Returns:
        list: the matching file names without their ``.py`` extension, in
        ``os.listdir`` order.
    """
    # Match the ``.py`` suffix exactly: a plain substring test would also pick up
    # names like ``happy.txt``, ``x.py.bak`` or ``x.pyi``. Those would then be
    # imported as (broken) dataset modules.
    return [os.path.splitext(entry)[0] for entry in os.listdir('datasets')
            if '__' not in entry and entry.endswith('.py')]
def get_dataset(args: Namespace) -> ContinualDataset:
    """
    Create and return an instance of the continual dataset selected by ``args.dataset``.

    Looking up and validating the dataset is delegated entirely to
    ``get_dataset_class``, so both functions report the same errors.

    Args:
        args (Namespace): the arguments which contain the hyperparameters;
            ``args.dataset`` selects the dataset and ``args`` itself is passed
            to the dataset's constructor.

    Raises:
        AssertionError: if ``args.dataset`` is not an available dataset.
        Exception: the error recorded while loading the dataset's module, if any.

    Returns:
        the continual dataset instance
    """
    # No separate membership check here: get_dataset_class performs it, and a
    # duplicate check would hide the callee's error.
    return get_dataset_class(args)(args)
def get_dataset_class(args: Namespace) -> Type[ContinualDataset]:
    """
    Return the class (not an instance) of the continual dataset selected by ``args.dataset``.

    Args:
        args (Namespace): the arguments which contain the ``--dataset`` attribute.

    Raises:
        AssertionError: if ``args.dataset`` is not among the available datasets.
        Exception: if loading the dataset's module failed, the exception recorded
            by ``get_dataset_names`` is re-raised.

    Returns:
        the continual dataset class
    """
    names = get_dataset_names()
    if args.dataset not in names:
        # Raised explicitly rather than via ``assert`` so the check also runs under
        # ``python -O``. The AssertionError type is kept for backward compatibility.
        available = ', '.join(sorted(str(name) for name in names))
        raise AssertionError(
            f"Dataset '{args.dataset}' is not available. Available datasets: {available}"
        )
    dataset = names[args.dataset]
    # The registry stores the import-time exception in place of a class when a
    # dataset module failed to load.
    if isinstance(dataset, Exception):
        raise dataset
    return dataset
def get_dataset_names():
    """
    Return the registry of the available continual datasets.

    Every module listed by ``get_all_datasets()`` is imported as
    ``datasets.<module>`` and each class reachable from it is inspected:

    * classes whose MRO (excluding the class itself) mentions ``ContinualDataset``
      and whose full MRO does not mention ``GCLDataset`` are registered first;
    * classes whose MRO (excluding the class itself) mentions ``GCLDataset`` are
      registered afterwards, so they win on a ``NAME`` clash.

    Each class is registered under its ``NAME`` attribute. If importing or
    scanning a module raises, two warnings are emitted and the exception is
    stored under the module name with underscores replaced by dashes.

    The registry is built on the first call and cached on the function object as
    ``get_dataset_names.names``; later calls return that same dict.

    Returns:
        dict: maps each dataset name to its class, or to the exception raised
        while loading the corresponding module.
    """
    def _dataset_classes(mod):
        """Return the dataset classes found in ``mod``: plain continual ones first, then GCL ones."""
        continual, gcl = [], []
        for attr in mod.__dir__():
            obj = getattr(mod, attr)
            # inspect.isclass accepts classes with any metaclass (e.g. ABCMeta) and
            # rejects every non-class object (e.g. a types.SimpleNamespace
            # instance) before inspect.getmro, which would raise on it.
            if not inspect.isclass(obj):
                continue
            mro = inspect.getmro(obj)
            ancestors = str(mro[1:])
            if 'GCLDataset' in ancestors:
                gcl.append(obj)
            elif 'ContinualDataset' in ancestors and 'GCLDataset' not in str(mro):
                continual.append(obj)
        return continual + gcl

    def _dataset_names():
        names = {}
        for dataset in get_all_datasets():
            try:
                mod = importlib.import_module('datasets.' + dataset)
                for cls in _dataset_classes(mod):
                    names[cls.NAME] = cls
            except Exception as e:
                # Record the failure instead of aborting, so the other datasets stay
                # usable; get_dataset_class re-raises it if this dataset is requested.
                warn_once(f'Error in dataset {dataset}')
                warn_once(e)
                names[dataset.replace('_', '-')] = e
        return names

    if not hasattr(get_dataset_names, 'names'):
        get_dataset_names.names = _dataset_names()
    return get_dataset_names.names