Source code for backbone

# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn as nn


def xavier(m: nn.Module) -> None:
    """
    Initializes a linear module in place with Xavier (Glorot) uniform weights.

    Only modules whose class name is exactly ``Linear`` are touched; every
    other module is left unchanged, which makes the function suitable for
    ``nn.Module.apply``. Weights are drawn uniformly from ``[-a, a]`` with
    ``a = sqrt(3) * sqrt(2 / (fan_in + fan_out))`` and the bias, when present,
    is set to zero.

    Args:
        m: the module to be initialized

    Example::
        >>> net = nn.Sequential(nn.Linear(10, 10), nn.ReLU())
        >>> net.apply(xavier)
    """
    # Guard clause: anything that is not named `Linear` is skipped.
    if m.__class__.__name__ != 'Linear':
        return

    weight = m.weight.data
    fan_out, fan_in = weight.size(0), weight.size(1)

    # Standard deviation of the Xavier scheme, then the matching uniform bound.
    std = 1.0 * math.sqrt(2.0 / (fan_in + fan_out))
    bound = math.sqrt(3.0) * std
    weight.uniform_(-bound, bound)

    if m.bias is not None:
        m.bias.data.fill_(0.0)
def num_flat_features(x: torch.Tensor) -> int:
    """
    Computes how many items each sample holds once the first (batch)
    dimension is set aside.

    Args:
        x: input tensor

    Returns:
        product of the sizes of every dimension from the second one onward
        (1 when the tensor has no such dimensions)
    """
    trailing_dims = x.size()[1:]
    # The product of an empty sequence is 1, matching a loop seeded with 1.
    return math.prod(trailing_dims)
class MammothBackbone(nn.Module):
    """
    A backbone module for the Mammoth model.

    This base class only provides utilities to read and write the parameters
    and gradients of the module as flat tensors; `forward` raises
    `NotImplementedError` and has to be overridden by subclasses.

    Args:
        **kwargs: additional keyword arguments (accepted and ignored by this base class)

    Methods:
        forward: Compute a forward pass.
        features: Get the features of the input tensor (same as forward but with returnt='features').
        get_params: Returns all the parameters concatenated in a single tensor.
        set_params: Sets the parameters to a given value.
        get_grads: Returns all the gradients concatenated in a single tensor.
        get_grads_list: Returns a list containing the gradients (a tensor for each layer).
    """

    def __init__(self, **kwargs) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor, returnt: str = 'out') -> torch.Tensor:
        """
        Compute a forward pass.

        Args:
            x: input tensor (batch_size, *input_shape)
            returnt: return type (a string among `out`, `features`, `both`, or `all`)

        Returns:
            output tensor

        Raises:
            NotImplementedError: always, in this base class
        """
        raise NotImplementedError

    def features(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the features of the input tensor.

        Equivalent to calling `forward` with ``returnt='features'``.

        Args:
            x: input tensor

        Returns:
            features tensor
        """
        return self.forward(x, returnt='features')

    def get_params(self) -> torch.Tensor:
        """
        Returns all the parameters concatenated in a single tensor.

        Parameters are flattened and concatenated in the order given by
        `self.parameters()`.

        Returns:
            parameters tensor
        """
        # reshape (rather than view) also copes with non-contiguous parameters.
        return torch.cat([pp.reshape(-1) for pp in self.parameters()])

    def set_params(self, new_params: torch.Tensor) -> None:
        """
        Sets the parameters to a given value.

        `new_params` is consumed in the order given by `self.parameters()`,
        i.e. using the same layout produced by `get_params`.

        Args:
            new_params: concatenated values to be set (1-D tensor with one
                entry per scalar parameter of the module)

        Raises:
            AssertionError: if `new_params` is not a 1-D tensor of the expected length
        """
        params = list(self.parameters())
        expected = sum(pp.numel() for pp in params)
        assert new_params.size() == torch.Size([expected]), \
            f'expected a flat tensor with {expected} elements, got shape {tuple(new_params.size())}'

        progress = 0
        for pp in params:
            # numel() is a plain int and is 1 for 0-dim parameters, so the
            # slice bounds are always valid Python indices.
            count = pp.numel()
            # NOTE: the slice is a view, so the parameter ends up sharing
            # storage with `new_params`.
            pp.data = new_params[progress: progress + count].view(pp.size())
            progress += count

    def get_grads(self) -> torch.Tensor:
        """
        Returns all the gradients concatenated in a single tensor.

        The layout matches the one of `get_params`; parameters without a
        gradient contribute zeros.

        Returns:
            gradients tensor
        """
        return torch.cat(self.get_grads_list())

    def get_grads_list(self) -> list:
        """
        Returns a list containing the gradients (a tensor for each layer).

        Each entry is the flattened gradient of one parameter, in the order
        given by `self.parameters()`. A parameter whose `.grad` is `None`
        (e.g. frozen, or not reached by the last backward pass) is
        represented by a zero tensor of matching size, so that the list
        stays aligned with the parameters.

        Returns:
            gradients list
        """
        grads = []
        for pp in self.parameters():
            if pp.grad is None:
                grads.append(pp.data.new_zeros(pp.numel()))
            else:
                grads.append(pp.grad.reshape(-1))
        return grads