continual_model

This module defines the base class for all models. It provides shared utility methods and defines the interface that every model implements.

The observe method is the most important one: it is called at each training iteration and it is responsible for computing the loss and updating the model’s parameters.

The begin_task and end_task methods are called before and after each task, respectively.
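To make this call order concrete, here is a minimal sketch. ToyModel and train are hypothetical stand-ins, not part of the library: ToyModel only records which hook was called, and train calls begin_task once per task, observe once per batch, and end_task once per task. The library's own loop, according to the description of the meta_ wrappers, goes through those wrappers instead of calling the hooks directly.

```python
from typing import List


class ToyModel:
    """Hypothetical stand-in for a ContinualModel subclass that records hook calls."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def begin_task(self, dataset) -> None:
        self.calls.append(f"begin_task({dataset})")

    def observe(self, inputs, labels, not_aug_inputs, epoch=None) -> float:
        self.calls.append(f"observe(epoch={epoch})")
        return 0.0  # a real model would return the training loss here

    def end_task(self, dataset) -> None:
        self.calls.append(f"end_task({dataset})")


def train(model: ToyModel, tasks, n_epochs: int = 1) -> None:
    """Hypothetical driver: begin_task, then observe per batch, then end_task."""
    for dataset, batches in tasks:
        model.begin_task(dataset)
        for epoch in range(n_epochs):
            for inputs, labels in batches:
                # In this sketch the raw inputs double as not_aug_inputs.
                model.observe(inputs, labels, inputs, epoch=epoch)
        model.end_task(dataset)


model = ToyModel()
train(model, [("task0", [([1], [0]), ([2], [1])]), ("task1", [([3], [2])])])
print(model.calls)
```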

The get_parser method returns the parser of the model. Additional model-specific hyper-parameters can be added by overriding this method.
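A minimal sketch of the override pattern, using only the standard argparse module. BaseModel stands in for ContinualModel, and MyModel and its --alpha option are hypothetical names chosen for illustration. The subclass extends the base parser instead of building one from scratch, so any base options are kept.

```python
from argparse import ArgumentParser


class BaseModel:
    """Hypothetical stand-in for ContinualModel; only get_parser is sketched."""

    @staticmethod
    def get_parser() -> ArgumentParser:
        return ArgumentParser(description="Base CL model")


class MyModel(BaseModel):
    @staticmethod
    def get_parser() -> ArgumentParser:
        # Start from the base parser and add model-specific hyper-parameters.
        parser = BaseModel.get_parser()
        parser.add_argument("--alpha", type=float, default=0.5,
                            help="Hypothetical regularization weight.")
        return parser


args = MyModel.get_parser().parse_args(["--alpha", "0.1"])
print(args.alpha)  # 0.1
```

Note that parse_args returns an argparse.Namespace holding the parsed values, while get_parser itself returns the ArgumentParser.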

The get_debug_iters method returns the number of iterations to be used for debugging. Default: 3.

The get_optimizer method returns the optimizer to be used for training. Default: SGD.

The load_buffer method is called when a buffer is loaded. Default: do nothing.

The meta_observe, meta_begin_task and meta_end_task methods are wrappers for the observe, begin_task and end_task methods, respectively. They update the internal counters and log to wandb, if it is installed.
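The wrapper pattern can be sketched as follows. WrappedModel is a hypothetical class, and the counter names _current_task and _task_iteration are invented for illustration; the library's actual counters may differ. Each wrapper delegates to the user-overridable hook and then updates the bookkeeping, so subclasses only override the hooks and cannot forget to update the counters.

```python
class WrappedModel:
    """Hypothetical sketch of the meta_* wrapper pattern; counter names are invented."""

    def __init__(self) -> None:
        self._current_task = 0
        self._task_iteration = 0  # hypothetical per-task iteration counter

    # User-overridable hooks (no-ops and a dummy loss here).
    def begin_task(self, dataset) -> None:
        pass

    def end_task(self, dataset) -> None:
        pass

    def observe(self, inputs, labels, not_aug_inputs, epoch=None) -> float:
        return 0.0

    # Wrappers: delegate to the hook and update internal bookkeeping.
    def meta_begin_task(self, dataset) -> None:
        self._task_iteration = 0  # reset the per-task counter
        self.begin_task(dataset)

    def meta_observe(self, *args, **kwargs) -> float:
        loss = self.observe(*args, **kwargs)
        self._task_iteration += 1
        # A real implementation would also log to wandb here, if installed.
        return loss

    def meta_end_task(self, dataset) -> None:
        self.end_task(dataset)
        self._current_task += 1  # move on to the next task index


m = WrappedModel()
m.meta_begin_task("task0")
m.meta_observe([1], [0], [1])
m.meta_observe([2], [1], [2])
m.meta_end_task("task0")
print(m._current_task, m._task_iteration)  # 1 2
```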

The autolog_wandb method automatically logs to wandb all variables in the observe function whose names start with “_wandb_” or “loss”. It is called by meta_observe if wandb is installed, and it can be overridden to add custom logging.

Classes

class models.utils.continual_model.ContinualModel(backbone, loss, args, transform)

Bases: Module

Continual learning model.

AVAIL_OPTIMS = ['sgd', 'adam', 'adamw']
COMPATIBILITY: List[str]
NAME: str
autolog_wandb(locals, extra=None)

All variables starting with “_wandb_” or “loss” in the observe function are automatically logged to wandb upon return if wandb is installed.
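The name-based selection can be sketched with a plain dictionary filter. collect_autolog and observe_like are hypothetical helpers, and merging extra into the result is an assumption about what that parameter is for; the sketch does not call wandb, it only shows which locals() entries would be picked up.

```python
from typing import Any, Dict, Optional


def collect_autolog(local_vars: Dict[str, Any],
                    extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Hypothetical helper: select the locals that would be logged."""
    logged = {name: value for name, value in local_vars.items()
              if name.startswith(("_wandb_", "loss"))}
    if extra is not None:
        logged.update(extra)  # assumption: extra holds additional values to log
    return logged


def observe_like() -> Dict[str, Any]:
    loss = 0.25
    loss_reg = 0.05
    _wandb_grad_norm = 1.5
    lr = 0.1  # not selected: the name matches neither prefix
    return collect_autolog(locals(), extra={"task": 0})


print(observe_like())
```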

begin_task(dataset)

Prepares the model for the current task. Executed before each task.

property cpt

Returns the raw number of classes per task. Warning: the return value may be either an integer or a list of integers.
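Because the value may be an integer or a list, calling code may want to normalize it first. per_task_classes is a hypothetical helper, and it assumes that a list holds one entry per task while an integer means every task has the same number of classes.

```python
from typing import List, Union


def per_task_classes(cpt: Union[int, List[int]], n_tasks: int) -> List[int]:
    """Hypothetical helper: normalize a cpt value to one class count per task."""
    if isinstance(cpt, int):
        return [cpt] * n_tasks
    if len(cpt) != n_tasks:
        raise ValueError(f"expected {n_tasks} entries, got {len(cpt)}")
    return list(cpt)


print(per_task_classes(2, 5))             # [2, 2, 2, 2, 2]
print(per_task_classes([50, 10, 10], 3))  # [50, 10, 10]
```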

property current_task

Returns the index of the current task.

end_task(dataset)

Prepares the model for the next task. Executed after each task.

forward(x)

Computes a forward pass.

Parameters:
  • x (Tensor) – batch of inputs

  • task_label – some models require the task label (it is not part of the base forward(x) signature)

Returns:

the result of the computation

Return type:

Tensor

get_debug_iters()

Returns the number of iterations to be used for debugging. Default: 3
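One plausible way a training loop could use this value is to cut each epoch short while debugging. run_epoch and its debug_mode flag are hypothetical; the sketch only assumes that "iterations to be used for debugging" means the number of batches processed per epoch in debug mode.

```python
def run_epoch(batches, debug_mode: bool, debug_iters: int = 3) -> int:
    """Hypothetical loop: stop after `debug_iters` batches in debug mode."""
    n_done = 0
    for i, _batch in enumerate(batches):
        if debug_mode and i >= debug_iters:
            break
        n_done += 1  # a real loop would call model.meta_observe(...) here
    return n_done


print(run_epoch(range(100), debug_mode=True))   # 3
print(run_epoch(range(100), debug_mode=False))  # 100
```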

get_optimizer()

Returns the optimizer to be used for training. Default: SGD.

get_parameters()

Returns the parameters of the model.

static get_parser()

Returns the parser of the model.

Additional model-specific hyper-parameters can be added by overriding this method.

Returns:

the parser of the model

Return type:

ArgumentParser

load_buffer(buffer)

Handles a loaded buffer. The default implementation does nothing.

meta_begin_task(dataset)

Wrapper for begin_task method.

Takes care of updating the internal counters.

Parameters:

dataset – the current task’s dataset

meta_end_task(dataset)

Wrapper for end_task method.

Takes care of updating the internal counters.

Parameters:

dataset – the current task’s dataset

meta_observe(*args, **kwargs)

Wrapper for observe method.

Drops unlabeled data if the model does not support it, and logs to wandb if it is installed.

Parameters:
  • inputs – batch of inputs

  • labels – batch of labels

  • not_aug_inputs – batch of inputs without augmentation

  • kwargs – some methods could require additional parameters

Returns:

the value of the loss function
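The unlabeled-data filtering can be sketched with plain lists. drop_unlabeled is a hypothetical helper, and the UNLABELED marker value of -1 is an assumption, not something this page states. The same index set is applied to inputs, labels and not_aug_inputs so the three batches stay aligned.

```python
from typing import List, Sequence, Tuple

UNLABELED = -1  # assumption: label value used to mark unlabeled examples


def drop_unlabeled(inputs: Sequence, labels: Sequence[int],
                   not_aug_inputs: Sequence) -> Tuple[List, List[int], List]:
    """Keep only the examples whose label differs from the UNLABELED marker."""
    keep = [i for i, y in enumerate(labels) if y != UNLABELED]
    return ([inputs[i] for i in keep],
            [labels[i] for i in keep],
            [not_aug_inputs[i] for i in keep])


x, y, x_raw = drop_unlabeled(["a", "b", "c"], [3, -1, 7], ["A", "B", "C"])
print(x, y, x_raw)  # ['a', 'c'] [3, 7] ['A', 'C']
```

With tensors, the same selection would typically be written as a boolean mask such as labels != UNLABELED applied to each batch.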

property n_classes_current_task

Returns the number of classes in the current task. Returns -1 if the task has not been initialized yet.

property n_past_classes

Returns the number of classes seen up to the previous task, i.e., excluding the current one. Returns -1 if the task has not been initialized yet.

property n_remaining_classes

Returns the number of classes remaining to be seen. Returns -1 if the task has not been initialized yet.

property n_seen_classes

Returns the number of classes seen so far. Returns -1 if the task has not been initialized yet.
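The four class counts are related by simple arithmetic. class_counts is a hypothetical re-derivation, not the library's code; it assumes a 0-based current_task (None meaning not initialized), that "seen" includes the current task, and that "remaining" is the total number of classes minus the seen ones.

```python
from typing import Dict, List, Optional, Union


def class_counts(cpt: Union[int, List[int]], n_tasks: int,
                 current_task: Optional[int]) -> Dict[str, int]:
    """Hypothetical re-derivation of the n_*_classes properties.

    current_task is the 0-based index of the current task, or None if no
    task has been initialized yet (all counts are then -1).
    """
    names = ("n_past_classes", "n_classes_current_task",
             "n_seen_classes", "n_remaining_classes")
    if current_task is None:
        return {name: -1 for name in names}
    per_task = [cpt] * n_tasks if isinstance(cpt, int) else list(cpt)
    n_past = sum(per_task[:current_task])   # tasks before the current one
    n_current = per_task[current_task]
    n_seen = n_past + n_current             # includes the current task
    n_remaining = sum(per_task) - n_seen
    return dict(zip(names, (n_past, n_current, n_seen, n_remaining)))


print(class_counts([50, 10, 10, 10, 10, 10], 6, current_task=2))
```

For the uneven split above, this yields 60 past classes, 10 in the current task, 70 seen and 30 remaining.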

abstract observe(inputs, labels, not_aug_inputs, epoch=None)

Computes a training step over a given batch of examples.

Parameters:
  • inputs (Tensor) – batch of examples

  • labels (Tensor) – ground-truth labels

  • not_aug_inputs – batch of examples without augmentation

  • epoch – the current epoch, if provided

  • kwargs – some methods could require additional parameters

Returns:

the value of the loss function

Return type:

float
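The contract (compute the loss on the batch, update the parameters, return the loss as a float) can be shown with a dependency-free toy. ToyRegressor is hypothetical: it fits y ≈ w·x + b by plain gradient descent on the mean squared error instead of using a backbone and optimizer. In PyTorch the same step would usually be written with optimizer.zero_grad(), loss.backward(), optimizer.step() and a final loss.item().

```python
from typing import List


class ToyRegressor:
    """Hypothetical pure-Python stand-in for a model implementing observe().

    Fits y ~ w * x + b with plain gradient descent on the mean squared error.
    """

    def __init__(self, lr: float = 0.1) -> None:
        self.w, self.b, self.lr = 0.0, 0.0, lr

    def observe(self, inputs: List[float], labels: List[float],
                not_aug_inputs: List[float], epoch=None) -> float:
        n = len(inputs)
        preds = [self.w * x + self.b for x in inputs]
        errors = [p - y for p, y in zip(preds, labels)]
        loss = sum(e * e for e in errors) / n
        # Gradients of the mean squared error with respect to w and b.
        grad_w = 2.0 * sum(e * x for e, x in zip(errors, inputs)) / n
        grad_b = 2.0 * sum(errors) / n
        self.w -= self.lr * grad_w
        self.b -= self.lr * grad_b
        return loss  # a plain float, matching the documented return type


model = ToyRegressor()
losses = [model.observe([1.0, 2.0], [2.0, 4.0], [1.0, 2.0]) for _ in range(50)]
print(losses[0])  # 10.0 (loss before the first update)
```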

to(device)

Captures the device to be used for training.