gem

Classes

class models.gem.Gem(backbone, loss, args, transform)

Bases: ContinualModel

COMPATIBILITY: List[str] = ['class-il', 'domain-il', 'task-il']
NAME: str = 'gem'
end_task(dataset)
static get_parser()
Return type:

ArgumentParser

observe(inputs, labels, not_aug_inputs, epoch=None)

Functions

models.gem.overwrite_grad(params, newgrad, grad_dims)

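Overwrites the parameter gradients with a new gradient vector whenever a violation occurs, that is, whenever the proposed gradient has a negative dot product with the stored gradient of a past task.

Parameters:

params: parameters
newgrad: corrected gradient
grad_dims: list storing the number of parameters in each layer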
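The offset bookkeeping behind overwrite_grad can be illustrated without a deep-learning framework. The sketch below assumes that each layer's gradient is a flat Python list (None for a layer without a gradient) and that layer i owns the slice of newgrad starting at sum(grad_dims[:i]) with length grad_dims[i]. The name overwrite_grad_sketch and the list-based representation are hypothetical; the real function operates on the model's parameter tensors.

```python
from typing import List, Optional


def overwrite_grad_sketch(layer_grads: List[Optional[List[float]]],
                          newgrad: List[float],
                          grad_dims: List[int]) -> None:
    """Hypothetical, framework-free sketch: scatter a flat gradient back into layers.

    layer_grads[i] is the (flattened) gradient of layer i, or None if that
    layer has no gradient; it is modified in place.
    """
    begin = 0
    for grad, n in zip(layer_grads, grad_dims):
        if grad is not None:
            # Assumption: each stored layer gradient has exactly grad_dims[i] entries.
            assert len(grad) == n, "grad_dims does not match the layer size"
            # Copy this layer's slice of the corrected vector in place.
            grad[:] = newgrad[begin:begin + n]
        # Assumption: the offset advances even for layers without a gradient,
        # so the layout stays aligned with the flat vector.
        begin += n
```

For example, with two gradient-carrying layers of sizes 2 and 1 separated by a 2-parameter layer without a gradient, the entries at offsets 2 and 3 of newgrad are simply skipped.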

models.gem.project2cone2(gradient, memories, margin=0.5, eps=0.001)

Solves the dual quadratic program (QP) from the GEM paper, given a proposed gradient “gradient” and a memory of task gradients “memories”. Overwrites “gradient” with the final projected update.

Input: gradient, a p-vector
Input: memories, a (t * p)-vector holding t past-task gradients of length p
Output: x, the projected p-vector
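The dual QP can be sketched without a QP library. The snippet below is a minimal, framework-free sketch that assumes the formulation used in the original GEM reference code: with the t stored gradients as the rows of a matrix M and the proposed gradient g, it minimizes 0.5 * v^T (M M^T + eps*I) v + (M g)^T v subject to v >= margin, then returns x = M^T v + g. Here eps stabilizes the diagonal and margin acts as a lower bound on the dual variables. The name project_gem_dual, the list-based vectors, and the use of coordinate descent in place of a dedicated QP solver are all assumptions for illustration; models.gem.project2cone2 need not be implemented this way.

```python
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def project_gem_dual(gradient: Vector, memories: List[Vector],
                     margin: float = 0.5, eps: float = 1e-3,
                     sweeps: int = 200) -> Vector:
    """Hypothetical sketch of the GEM projection via coordinate descent on the dual.

    gradient: the proposed p-vector g.
    memories: t rows, each a p-vector holding one past task's gradient.
    Solves min_v 0.5 v^T P v + c^T v  s.t.  v >= margin,
    with P = M M^T + eps*I and c = M g, then returns x = M^T v + g.
    """
    t = len(memories)
    # P[i][j] = <m_i, m_j>, with eps added on the diagonal for numerical stability.
    P = [[dot(memories[i], memories[j]) + (eps if i == j else 0.0)
          for j in range(t)] for i in range(t)]
    c = [dot(m, gradient) for m in memories]
    v = [margin] * t  # feasible starting point
    for _ in range(sweeps):
        for i in range(t):
            # Partial derivative of the dual objective with respect to v_i.
            grad_i = dot(P[i], v) + c[i]
            # Exact minimization along coordinate i, clamped to the bound v_i >= margin.
            v[i] = max(margin, v[i] - grad_i / P[i][i])
    # Recover the projected update x = M^T v + g.
    x = list(gradient)
    for vi, m in zip(v, memories):
        for k in range(len(x)):
            x[k] += vi * m[k]
    return x
```

Coordinate descent is chosen here only because each one-dimensional subproblem of a box-constrained convex quadratic has a closed-form, clamped solution, which keeps the sketch dependency-free. With margin=0.0 and eps=0.0, the returned x satisfies <x, m_i> >= 0 (up to convergence tolerance) for every stored gradient m_i.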

models.gem.store_grad(params, grads, grad_dims)

Stores the parameter gradients of past tasks.

Parameters:

params: parameters
grads: gradients
grad_dims: list with the number of parameters per layer
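The flattening step that store_grad performs can be illustrated with plain lists. This sketch assumes that each layer's gradient is already a flat Python list (None for a layer without a gradient) and that grad_dims gives the length of each layer's slice in the flat vector. It also assumes that layers without a gradient leave their slice at zero. The name store_grad_sketch and the list-based representation are hypothetical; the real function writes into the grads buffer using the model's parameter tensors.

```python
from typing import List, Optional


def store_grad_sketch(layer_grads: List[Optional[List[float]]],
                      grad_dims: List[int]) -> List[float]:
    """Hypothetical, framework-free sketch: concatenate per-layer gradients."""
    # Start from zeros so that layers without a gradient contribute zeros.
    flat = [0.0] * sum(grad_dims)
    begin = 0
    for grad, n in zip(layer_grads, grad_dims):
        if grad is not None:
            # Assumption: each layer gradient has exactly grad_dims[i] entries.
            assert len(grad) == n, "grad_dims does not match the layer size"
            flat[begin:begin + n] = grad
        # Advance the offset for every layer, so slices stay aligned.
        begin += n
    return flat
```

A flat vector built this way uses the same offsets that overwrite_grad reads from, so storing and then overwriting with an unchanged vector leaves the layer gradients unchanged.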