continual_model
This is the base class for all models. It provides utility methods shared by all models and defines the interface that models implement.
The observe method is the most important one: it is called at each training iteration and it is responsible for computing the loss and updating the model’s parameters.
The begin_task and end_task methods are called before and after each task, respectively.
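The order in which these hooks are invoked can be illustrated with a toy stand-in. This is a minimal sketch, not the library's training loop: `ToyModel` and `train` are hypothetical names, the class does not inherit from ContinualModel, and the real loop does considerably more (data loading, evaluation, and so on).

```python
# Hypothetical stand-in: it only records the order of the hook calls
# described above; it is not the library's ContinualModel.
class ToyModel:
    def __init__(self):
        self.calls = []

    def begin_task(self, dataset):
        self.calls.append(f"begin_task({dataset})")

    def observe(self, inputs, labels, not_aug_inputs):
        # A real model would compute the loss and update its parameters here.
        self.calls.append("observe")
        return 0.0

    def end_task(self, dataset):
        self.calls.append(f"end_task({dataset})")


def train(model, tasks, iters_per_task=2):
    """Drive the model through each task: begin, observe per iteration, end."""
    for task in tasks:
        model.begin_task(task)
        for _ in range(iters_per_task):
            model.observe(inputs=None, labels=None, not_aug_inputs=None)
        model.end_task(task)


model = ToyModel()
train(model, ["task0", "task1"])
print(model.calls)
```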
The get_parser method returns the parser of the model. Additional model-specific hyper-parameters can be added by overriding this method.
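As a sketch of the override pattern, a subclass can start from the parent parser and add its own hyper-parameters with the standard argparse module. `BaseModelSketch`, `MyModelSketch` and `--alpha` are hypothetical; that the base method returns an `argparse.ArgumentParser` and that subclasses extend the parent's parser are assumptions, not confirmed library behavior.

```python
from argparse import ArgumentParser


class BaseModelSketch:
    @staticmethod
    def get_parser() -> ArgumentParser:
        # Assumption: the base parser has no model-specific options.
        return ArgumentParser(description="Base CL model")


class MyModelSketch(BaseModelSketch):
    @staticmethod
    def get_parser() -> ArgumentParser:
        # Start from the parent parser, then add this model's hyper-parameters.
        parser = BaseModelSketch.get_parser()
        parser.add_argument("--alpha", type=float, default=0.5,
                            help="Weight of a hypothetical regularization term.")
        return parser


args = MyModelSketch.get_parser().parse_args(["--alpha", "0.3"])
print(args.alpha)  # 0.3
```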
The get_debug_iters method returns the number of iterations to be used for debugging. Default: 3.
The get_optimizer method returns the optimizer to be used for training. Default: SGD.
The load_buffer method is called when a buffer is loaded. Default: do nothing.
The meta_observe, meta_begin_task and meta_end_task methods are wrappers for the observe, begin_task and end_task methods, respectively. They take care of updating the internal counters and of logging to wandb if it is installed.
The autolog_wandb method automatically logs to wandb all variables in the observe function whose names start with “_wandb_” or “loss”. It is called by meta_observe if wandb is installed, and it can be overridden to add custom logging.
Classes
- class models.utils.continual_model.ContinualModel(backbone, loss, args, transform)
Bases: Module
Continual learning model.
- AVAIL_OPTIMS = ['sgd', 'adam', 'adamw']
- autolog_wandb(locals, extra=None)
All variables starting with “_wandb_” or “loss” in the observe function are automatically logged to wandb upon return if wandb is installed.
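A framework-free sketch of the selection rule: `collect_autolog` and `observe_like` are hypothetical names, no wandb call is made (the selected values are simply returned as a dict), and treating `extra` as a dict of additional values to merge is an assumption, since its exact role is not documented here.

```python
def collect_autolog(local_vars, extra=None):
    """Hypothetical helper: select the locals autolog_wandb is documented to log."""
    logged = {name: value for name, value in local_vars.items()
              if name.startswith(("_wandb_", "loss"))}
    if extra is not None:
        # Assumption: extra holds additional key/value pairs to log.
        logged.update(extra)
    return logged


def observe_like():
    """Mimics the body of an observe method just before it returns."""
    loss = 0.42
    loss_ce = 0.40
    _wandb_lr = 0.1
    batch_size = 32  # not selected: the name has neither prefix
    return collect_autolog(locals(), extra={"task": 0})


print(observe_like())
```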
- property cpt
Returns the raw number of classes per task. Warning: return value might be either an integer or a list of integers.
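Because cpt may be either an integer or a list, code that reads it has to handle both cases. A minimal sketch, assuming the number of tasks is known separately; `classes_per_task` and `n_tasks` are hypothetical names.

```python
from typing import List, Union


def classes_per_task(cpt: Union[int, List[int]], n_tasks: int) -> List[int]:
    """Expand a raw cpt value into an explicit per-task list of class counts."""
    if isinstance(cpt, int):
        # Same number of classes in every task.
        return [cpt] * n_tasks
    if len(cpt) != n_tasks:
        raise ValueError("a list-valued cpt must have one entry per task")
    return list(cpt)


print(classes_per_task(2, 5))             # [2, 2, 2, 2, 2]
print(classes_per_task([50, 10, 10], 3))  # [50, 10, 10]
```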
- property current_task
Returns the index of the current task.
- static get_parser()
Returns the parser of the model.
Additional model-specific hyper-parameters can be added by overriding this method.
- Returns:
the parser of the model
- Return type:
- meta_begin_task(dataset)
Wrapper for the begin_task method.
Takes care of updating the internal counters.
- Parameters:
dataset – the current task’s dataset
- meta_end_task(dataset)
Wrapper for the end_task method.
Takes care of updating the internal counters.
- Parameters:
dataset – the current task’s dataset
- meta_observe(*args, **kwargs)
Wrapper for the observe method.
Drops unlabeled data if the model does not support it, and logs to wandb if it is installed.
- Parameters:
inputs – batch of inputs
labels – batch of labels
not_aug_inputs – batch of inputs without augmentation
kwargs – additional parameters that some models may require
- Returns:
the value of the loss function
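The wrapper pattern can be sketched with plain lists. This is a hypothetical illustration: `SketchModel`, `drop_unlabeled` and the `supports_unlabeled` flag are invented names, marking unlabeled samples with the label -1 is an assumption, and the wandb logging step is only indicated by a comment.

```python
# Assumption: unlabeled samples are marked with this sentinel label.
UNLABELED = -1


def drop_unlabeled(inputs, labels, not_aug_inputs):
    """Keep only the samples whose label is not the UNLABELED sentinel."""
    keep = [i for i, y in enumerate(labels) if y != UNLABELED]
    return ([inputs[i] for i in keep],
            [labels[i] for i in keep],
            [not_aug_inputs[i] for i in keep])


class SketchModel:
    # Hypothetical flag: whether observe can handle unlabeled samples.
    supports_unlabeled = False

    def observe(self, inputs, labels, not_aug_inputs, **kwargs):
        # Stand-in "loss": the number of samples that reached observe.
        return float(len(labels))

    def meta_observe(self, *args, **kwargs):
        inputs, labels, not_aug_inputs = args[:3]
        if not self.supports_unlabeled:
            inputs, labels, not_aug_inputs = drop_unlabeled(
                inputs, labels, not_aug_inputs)
        # The real wrapper would also log to wandb here, if installed.
        return self.observe(inputs, labels, not_aug_inputs, **kwargs)


loss = SketchModel().meta_observe(["a", "b", "c"], [0, -1, 2], ["a0", "b0", "c0"])
print(loss)  # 2.0
```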
- property n_classes_current_task
Returns the number of classes in the current task. Returns -1 if the task has not been initialized yet.
- property n_past_classes
Returns the number of classes seen in the tasks before the current one. Returns -1 if the task has not been initialized yet.
- property n_remaining_classes
Returns the number of classes that remain to be seen. Returns -1 if the task has not been initialized yet.
- property n_seen_classes
Returns the number of classes seen so far. Returns -1 if the task has not been initialized yet.
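The four counters are related by simple arithmetic on cpt and the current task index. A minimal sketch, assuming that n_seen_classes includes the current task, that n_remaining_classes is the total number of classes minus n_seen_classes, and that the counters are refreshed when a task begins while the task index advances when it ends. `TaskCounters` and `n_classes` are hypothetical, and plain attributes stand in for the read-only properties.

```python
from typing import List, Union


class TaskCounters:
    """Hypothetical sketch of per-task class bookkeeping (not the library class)."""

    def __init__(self, cpt: Union[int, List[int]], n_classes: int):
        self.cpt = cpt
        self.n_classes = n_classes
        self._current_task = 0
        # -1 until the first task begins, as documented for the real properties.
        self.n_classes_current_task = -1
        self.n_past_classes = -1
        self.n_seen_classes = -1
        self.n_remaining_classes = -1

    @property
    def current_task(self) -> int:
        return self._current_task

    def meta_begin_task(self, dataset=None):
        t = self._current_task
        if isinstance(self.cpt, int):
            self.n_classes_current_task = self.cpt
            self.n_past_classes = self.cpt * t
        else:
            self.n_classes_current_task = self.cpt[t]
            self.n_past_classes = sum(self.cpt[:t])
        # Assumption: "seen so far" includes the current task.
        self.n_seen_classes = self.n_past_classes + self.n_classes_current_task
        self.n_remaining_classes = self.n_classes - self.n_seen_classes
        # A real model would now call begin_task(dataset).

    def meta_end_task(self, dataset=None):
        # A real model would first call end_task(dataset).
        self._current_task += 1


counters = TaskCounters(cpt=[50, 10, 10, 10, 10, 10], n_classes=100)
for _ in range(2):
    counters.meta_begin_task()
    # prints "0 50 0 50 50", then "1 10 50 60 40"
    print(counters.current_task, counters.n_classes_current_task,
          counters.n_past_classes, counters.n_seen_classes,
          counters.n_remaining_classes)
    counters.meta_end_task()
```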