gem
Classes
- class models.gem.Gem(backbone, loss, args, transform)
Bases: ContinualModel
Functions
- models.gem.overwrite_grad(params, newgrad, grad_dims)
Overwrites the gradients with a new gradient vector whenever a violation occurs.
params: parameters
newgrad: corrected gradient
grad_dims: list storing the number of parameters at each layer
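The role of grad_dims can be illustrated with a minimal pure-Python sketch. It assumes the corrected gradient is one flat vector laid out layer after layer, and that each layer's slice is found by summing the preceding entries of grad_dims. The helper name split_flat_gradient is hypothetical and not part of models.gem; the library function would additionally reshape each slice and copy it into the corresponding parameter's gradient.

```python
def split_flat_gradient(newgrad, grad_dims):
    """Split a flat gradient into one chunk per layer (hypothetical helper)."""
    chunks = []
    for count, size in enumerate(grad_dims):
        # Offset of this layer = total number of parameters in earlier layers.
        beg = sum(grad_dims[:count])
        en = beg + size
        chunks.append(newgrad[beg:en])
    return chunks


# Two layers holding 2 and 4 parameters respectively.
flat = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
layers = split_flat_gradient(flat, [2, 4])
print(layers)
```

Each chunk has as many entries as the matching layer has parameters, so the sum of grad_dims must equal the length of the flat vector.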
- models.gem.project2cone2(gradient, memories, margin=0.5, eps=0.001)
Solves the GEM dual QP described in the paper, given a proposed gradient “gradient” and a memory of task gradients “memories”. Overwrites “gradient” with the final projected update.
input: gradient, p-vector
input: memories, (t * p)-vector
output: x, p-vector
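For intuition, the projection has a closed form when there is only one memory gradient (t = 1). The sketch below is not the library implementation, which would need a QP solver for the general case. It assumes the usual GEM dual: minimize 0.5 * v^2 * (m.m + eps) + v * (m.g) subject to v >= margin, then return x = g + v * m. The function name project_single_memory is hypothetical, and it returns a new list instead of overwriting its input.

```python
def project_single_memory(gradient, memory, margin=0.5, eps=1e-3):
    """Closed-form GEM-style projection for a single memory gradient (sketch)."""
    dot_mg = sum(m * g for m, g in zip(memory, gradient))
    # eps regularizes the quadratic term; keep it > 0 if memory may be all zeros.
    dot_mm = sum(m * m for m in memory) + eps
    # Unconstrained minimizer of the dual, clipped to the constraint v >= margin.
    v = max(margin, -dot_mg / dot_mm)
    return [g + v * m for g, m in zip(gradient, memory)]


# Conflicting case: the proposed gradient points against the memory gradient.
g = [1.0, -2.0]
m = [0.0, 1.0]
x = project_single_memory(g, m, margin=0.0, eps=0.0)
print(x)
```

With margin=0.0 and eps=0.0 the result has a non-negative dot product with the memory gradient, which is the constraint GEM enforces. A positive margin pushes the update further toward the memory gradient, even when the two do not conflict.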