# Source code for utils.main

"""
This script is the main entry point for the Mammoth project. It contains the main function `main()` that orchestrates the training process.

The script performs the following tasks:
- Imports necessary modules and libraries.
- Sets up the necessary paths and configurations.
- Parses command-line arguments.
- Initializes the dataset, model, and other components.
- Trains the model using the `train()` function.

To run the script, execute it directly or import it as a module and call the `main()` function.
"""
# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# needed (don't change it)
import numpy  # noqa
import time
import importlib
import os
import socket
import sys
import datetime
import uuid
from argparse import ArgumentParser
import torch

# Root of the repository (two levels above this file).
mammoth_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Make the repository root and its component directories importable, in this order.
sys.path.extend([
    mammoth_path,
    mammoth_path + '/datasets',
    mammoth_path + '/backbone',
    mammoth_path + '/models',
])

from utils import create_if_not_exists, custom_str_underscore
from utils.args import add_management_args, add_experiment_args
from utils.conf import base_path
from utils.distributed import make_dp
from utils.best_args import best_args
from utils.conf import set_random_seed


def lecun_fix():
    """
    Install a process-wide ``urllib`` opener that sends a browser-like User-Agent.

    Yann moved his website to CloudFlare. You need this now: downloads made
    through ``urllib.request`` after this call carry the header
    ``User-agent: Mozilla/5.0`` (presumably so the site does not reject the
    default Python User-Agent -- TODO confirm).
    """
    # Stdlib urllib.request is what six.moves.urllib.request resolves to on Python 3.
    import urllib.request

    opener = urllib.request.build_opener()
    opener.addheaders = [('User-agent', 'Mozilla/5.0')]
    # Makes this opener the default used by urllib.request.urlopen().
    urllib.request.install_opener(opener)
def parse_args():
    """
    Parse command line arguments for the mammoth program and set up the `args` object.

    Parsing happens in stages. A minimal parser first reads ``--model`` and
    ``--load_best_args`` so that the model-specific parser (extended with the
    management and experiment arguments) can be built. With
    ``--load_best_args``, the stored best hyper-parameters for the chosen model
    and ``--dataset`` (plus ``--buffer_size`` when the model module defines
    ``Buffer``) are appended to the command line before the final parse.
    Afterwards the dataset's default number of epochs is filled in, the random
    seed is set, checkpoint names are derived and option combinations are
    validated.

    Returns:
        args (argparse.Namespace): Parsed command line arguments. ``args.model``
        holds the value mapped from ``get_all_models()`` for the chosen name.

    Raises:
        SystemExit: if no ``--model`` is given or argument parsing fails.
        AssertionError: if incompatible options are combined (see checks below).
    """
    from models import get_all_models, get_model_class
    from datasets import get_dataset_names, get_dataset_class

    models_dict = get_all_models()

    # Stage 1: only what is needed to select the model-specific parser.
    # add_help=False keeps `--help` from being consumed here, and
    # parse_known_args ignores every model/experiment option for now.
    parser = ArgumentParser(description='mammoth', allow_abbrev=False, add_help=False)
    parser.add_argument('--model', type=custom_str_underscore, help='Model name.', choices=list(models_dict.keys()))
    parser.add_argument('--load_best_args', action='store_true',
                        help='Loads the best arguments for each method, '
                        'dataset and memory buffer.')
    args = parser.parse_known_args()[0]

    if args.model is None:
        print('No model specified. Please specify a model with --model to see all other options.')
        print('Available models are: {}'.format(list(models_dict.keys())))
        sys.exit(1)

    mod = importlib.import_module('models.' + models_dict[args.model])
    has_buffer = hasattr(mod, 'Buffer')

    if args.load_best_args:
        parser.add_argument('--dataset', type=str, required=True,
                            choices=get_dataset_names(),
                            help='Which dataset to perform experiments on.')
        if has_buffer:
            parser.add_argument('--buffer_size', type=int, required=True,
                                help='The size of the memory buffer.')
        # parse_known_args (not parse_args): every other option (e.g. --seed) is
        # still in sys.argv and is re-parsed by the full parser below, so it must
        # not be rejected as "unrecognized" at this stage. Required arguments are
        # still enforced.
        args = parser.parse_known_args()[0]

        if args.model == 'joint':
            best = best_args[args.dataset]['sgd']
        else:
            best = best_args[args.dataset][args.model]
        best = best[args.buffer_size] if has_buffer else best[-1]

        parser = get_model_class(args).get_parser()
        add_management_args(parser)
        add_experiment_args(parser)
        # The best values are appended after the user's arguments; for a repeated
        # option argparse keeps the last occurrence, so stored best values win.
        to_parse = sys.argv[1:] + ['--' + k + '=' + str(v) for k, v in best.items()]
        to_parse.remove('--load_best_args')
        args = parser.parse_args(to_parse)
        if args.model == 'joint' and args.dataset == 'mnist-360':
            args.model = 'joint_gcl'
    else:
        parser = get_model_class(args).get_parser()
        add_management_args(parser)
        add_experiment_args(parser)
        args = parser.parse_args()

    # Default the number of epochs to the dataset's value; warn on an override.
    n_epochs = get_dataset_class(args).get_epochs()
    if args.n_epochs is None:
        args.n_epochs = n_epochs
    elif args.n_epochs != n_epochs:
        print('Warning: n_epochs set to {} instead of {}.'.format(args.n_epochs, n_epochs), file=sys.stderr)

    # From here on args.model holds the mapped value from get_all_models().
    args.model = models_dict[args.model]

    if args.lr_scheduler is not None:
        print('Warning: lr_scheduler set to {}, overrides default from dataset.'.format(args.lr_scheduler), file=sys.stderr)

    if args.seed is not None:
        set_random_seed(args.seed)

    if args.savecheck:
        assert args.inference_only == 0, "Should not save checkpoint in inference only mode"
        if not os.path.isdir('checkpoints'):
            create_if_not_exists("checkpoints")

        now = time.strftime("%Y%m%d-%H%M%S")
        extra_ckpt_name = "" if args.ckpt_name is None else f"{args.ckpt_name}_"
        buf_size = args.buffer_size if hasattr(args, 'buffer_size') else 0
        args.ckpt_name = f"{extra_ckpt_name}{args.model}_{args.dataset}_{buf_size}_{args.n_epochs}_{now}"
        # Contains a literal '{}' placeholder, presumably filled in later with
        # str.format (e.g. per task) -- verify against the checkpoint-saving code.
        args.ckpt_name_replace = f"{extra_ckpt_name}{args.model}_{args.dataset}_{'{}'}_{buf_size}__{args.n_epochs}_{now}"
        print("Saving checkpoint into", args.ckpt_name, file=sys.stderr)

    if args.joint:
        assert args.start_from is None and args.stop_after is None, "Joint training does not support start_from and stop_after"
        assert args.enable_other_metrics == 0, "Joint training does not support other metrics"

    assert 0 < args.label_perc <= 1, "label_perc must be in (0, 1]"

    return args
def main(args=None):
    """
    Run a Mammoth experiment: build the dataset and model, then call ``train``.

    Args:
        args (argparse.Namespace, optional): pre-parsed arguments. When ``None``
            they are read from the command line with ``parse_args()``.

    Raises:
        ValueError: if ``args.distributed == 'dp'`` and the batch size is smaller
            than the number of visible CUDA devices.
        NotImplementedError: if ``args.distributed == 'ddp'``.
    """
    from models import get_model
    from datasets import ContinualDataset, get_dataset
    from utils.training import train

    lecun_fix()
    if args is None:
        args = parse_args()

    # set base path
    base_path(args.base_path)

    # os.putenv changes the process environment without updating os.environ.
    os.putenv("MKL_SERVICE_FORCE_INTEL", "1")
    os.putenv("NPY_MKL_FORCE_INTEL", "1")

    # Add uuid, timestamp and hostname for logging
    args.conf_jobnum = str(uuid.uuid4())
    args.conf_timestamp = str(datetime.datetime.now())
    args.conf_host = socket.gethostname()
    dataset = get_dataset(args)

    # Fill in dataset defaults for values not given explicitly.
    if args.n_epochs is None and isinstance(dataset, ContinualDataset):
        args.n_epochs = dataset.get_epochs()
    if args.batch_size is None:
        args.batch_size = dataset.get_batch_size()

    if hasattr(importlib.import_module('models.' + args.model), 'Buffer'):
        # Models exposing a Buffer: keep a user-supplied minibatch_size and only
        # fall back to the dataset default when it is missing or None.
        if getattr(args, 'minibatch_size', None) is None:
            args.minibatch_size = dataset.get_minibatch_size()
    else:
        args.minibatch_size = args.batch_size

    backbone = dataset.get_backbone()
    loss = dataset.get_loss()
    model = get_model(args, backbone, loss, dataset.get_transform())

    if args.distributed == 'dp':
        n_gpus = torch.cuda.device_count()
        if args.batch_size < n_gpus:
            raise ValueError(f"Batch too small for DataParallel (Need at least {n_gpus}).")

        model.net = make_dp(model.net)
        model.to('cuda:0')
        args.conf_ngpus = n_gpus
    elif args.distributed == 'ddp':
        # DDP breaks the buffer, it has to be synchronized.
        raise NotImplementedError('Distributed Data Parallel not supported yet.')

    if args.debug_mode:
        print('Debug mode enabled: running only a few forward steps per epoch with W&B disabled.')
        args.nowand = 1

    if args.wandb_entity is None or args.wandb_project is None:
        print('Warning: wandb_entity and wandb_project not set. Disabling wandb.')
        args.nowand = 1
    elif not getattr(args, 'nowand', 0):
        # Enable W&B only if it was not already disabled (by --nowand or by
        # debug mode above); previously this branch silently re-enabled it.
        print('Logging to wandb: {}/{}'.format(args.wandb_entity, args.wandb_project))
        args.nowand = 0

    try:
        import setproctitle
        # set job name
        setproctitle.setproctitle('{}_{}_{}'.format(args.model, args.buffer_size if 'buffer_size' in args else 0, args.dataset))
    except Exception:
        # Best-effort: setproctitle is optional and naming the process must
        # never prevent training from starting.
        pass

    train(model, dataset, args)
if __name__ == '__main__':
    # Script entry point: configure from the command line and run.
    main()