Source code for utils.args

# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

if __name__ == '__main__':
    import os
    import sys
    mammoth_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    sys.path.append(mammoth_path)

from argparse import ArgumentParser
from datasets import get_dataset_names
from models import get_all_models
from models.utils.continual_model import ContinualModel
from utils import custom_str_underscore


[docs] def add_experiment_args(parser: ArgumentParser) -> None: """ Adds the arguments used by all the models. Args: parser: the parser instance Returns: None """ parser.add_argument('--dataset', type=str, required=True, choices=get_dataset_names(), help='Which dataset to perform experiments on.') parser.add_argument('--model', type=custom_str_underscore, required=True, help='Model name.', choices=list(get_all_models().keys())) parser.add_argument('--lr', type=float, required=True, help='Learning rate.') parser.add_argument('--optimizer', type=str, default='sgd', choices=ContinualModel.AVAIL_OPTIMS, help='Optimizer.') parser.add_argument('--optim_wd', type=float, default=0., help='optimizer weight decay.') parser.add_argument('--optim_mom', type=float, default=0., help='optimizer momentum.') parser.add_argument('--optim_nesterov', type=int, default=0, help='optimizer nesterov momentum.') parser.add_argument('--lr_scheduler', type=str, help='Learning rate scheduler.') parser.add_argument('--lr_milestones', type=int, nargs='+', default=[], help='Learning rate scheduler milestones (used if `lr_scheduler=multisteplr`).') parser.add_argument('--sched_multistep_lr_gamma', type=float, default=0.1, help='Learning rate scheduler gamma (used if `lr_scheduler=multisteplr`).') parser.add_argument('--n_epochs', type=int, help='Number of epochs.') parser.add_argument('--batch_size', type=int, help='Batch size.') parser.add_argument('--distributed', type=str, default='no', choices=['no', 'dp', 'ddp'], help='Enable distributed training?') parser.add_argument('--savecheck', action='store_true', help='Save checkpoint?') parser.add_argument('--loadcheck', type=str, default=None, help='Path of the checkpoint to load (.pt file for the specific task)') parser.add_argument('--ckpt_name', type=str, required=False, help='(optional) checkpoint save name.') parser.add_argument('--start_from', type=int, default=None, help="Task to start from") parser.add_argument('--stop_after', type=int, default=None, help="Task limit") parser.add_argument('--joint', type=int, choices=[0, 1], default=0, help='Train model on Joint (single task)?') parser.add_argument('--label_perc', type=float, default=1, help='Percentage in (0-1] of labeled examples per task.')
[docs] def add_management_args(parser: ArgumentParser) -> None: """ Adds the management arguments. Args: parser: the parser instance Returns: None """ parser.add_argument('--seed', type=int, default=None, help='The random seed.') parser.add_argument('--permute_classes', type=int, choices=[0, 1], default=0, help='Permute classes before splitting tasks (applies seed before permute if seed is present)?') parser.add_argument('--base_path', type=str, default="./data/", help='The base path where to save datasets, logs, results.') parser.add_argument('--notes', type=str, default=None, help='Notes for this run.') parser.add_argument('--wandb_name', type=str, default=None, help='Wandb name for this run. Overrides the default name (`args.model`).') parser.add_argument('--non_verbose', default=0, choices=[0, 1], type=int, help='Make progress bars non verbose') parser.add_argument('--disable_log', default=0, choices=[0, 1], type=int, help='Disable logging?') parser.add_argument('--num_workers', type=int, default=None, help='Number of workers for the dataloaders (default=infer from number of cpus).') parser.add_argument('--validation', type=int, help='Percentage of validation set drawn from the training set.') parser.add_argument('--enable_other_metrics', default=0, choices=[0, 1], type=int, help='Enable computing additional metrics: forward and backward transfer.') parser.add_argument('--debug_mode', type=int, default=0, choices=[0, 1], help='Run only a few forward steps per epoch') parser.add_argument('--wandb_entity', type=str, help='Wandb entity') parser.add_argument('--wandb_project', type=str, default='mammoth', help='Wandb project name') parser.add_argument('--eval_epochs', type=int, default=None, help='Perform inference intra-task at every `eval_epochs`.') parser.add_argument('--inference_only', action="store_true", help='Perform inference only for each task (no training).')
[docs] def add_rehearsal_args(parser: ArgumentParser) -> None: """ Adds the arguments used by all the rehearsal-based methods Args: parser: the parser instance Returns: None """ parser.add_argument('--buffer_size', type=int, required=True, help='The size of the memory buffer.') parser.add_argument('--minibatch_size', type=int, help='The batch size of the memory buffer.')
class _DocsArgs: """ This class is used to generate the documentation of the arguments. """ def __init__(self, name: str, type_: str, choices: str, default: str, help_: str): self.name = name self.type = type_ self.choices = choices self.default = default self.help = help_ def parse_choices(self) -> str: if self.choices is None: return '' return ', '.join([c.keys() if isinstance(c, dict) else str(c) for c in self.choices]) def __str__(self): tb = '\t' return f"""**\\-\\-{self.name}** : {self.type} *Help*: {self.help}\n - Default: {self.default}\n - Choices: {self.parse_choices() if self.choices is not None else ''}""" if __name__ == '__main__': print("Generating documentation for the arguments...") os.chdir(mammoth_path) parser = ArgumentParser() add_experiment_args(parser) docs_args = [] for action in parser._actions: if action.dest == 'help': continue docs_args.append(_DocsArgs(action.dest, action.type, action.choices, action.default, action.help)) with open('docs/utils/args.rst', 'w') as f: f.write('.. _module-args:\n\n') f.write('Arguments\n') f.write('=========\n\n') f.write('.. rubric:: EXPERIMENT-RELATED ARGS\n\n') for arg in docs_args: f.write(str(arg) + '\n\n') parser = ArgumentParser() add_management_args(parser) docs_args = [] for action in parser._actions: if action.dest == 'help': continue docs_args.append(_DocsArgs(action.dest, action.type, action.choices, action.default, action.help)) with open('docs/utils/args.rst', 'a') as f: f.write('.. rubric:: MANAGEMENT ARGS\n\n') for arg in docs_args: f.write(str(arg) + '\n\n') parser = ArgumentParser() add_rehearsal_args(parser) docs_args = [] for action in parser._actions: if action.dest == 'help': continue docs_args.append(_DocsArgs(action.dest, action.type, action.choices, action.default, action.help)) with open('docs/utils/args.rst', 'a') as f: f.write('.. rubric:: REEHARSAL-ONLY ARGS\n\n') for arg in docs_args: f.write(str(arg) + '\n\n') print("Saving documentation in docs/utils/args.rst") print("Done!")