Source code for utils.buffer

# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from copy import deepcopy
from typing import List, Tuple

import numpy as np
import torch
import torch.nn as nn

from datasets.utils.continual_dataset import ContinualDataset
from models.utils.continual_model import ContinualModel
from utils.augmentations import apply_transform
from utils.conf import get_device


def icarl_replay(self: "ContinualModel", dataset, val_set_split=0):
    """
    Merge the replay buffer with the current task data.
    Optionally split the replay buffer into a validation set.

    Nothing is done while `self.current_task` is not greater than 0. Otherwise the
    `data` and `targets` of `dataset.train_loader.dataset` are overwritten in place
    with the concatenation of the current-task samples and the buffer samples. When
    `val_set_split > 0`, a random subset of the buffer and an equally sized random
    subset of the training samples are moved to `self.val_dataset`, and
    `self.val_loader` is created on top of it.

    Args:
        self: the model instance
        dataset: the dataset
        val_set_split: the fraction of the replay buffer to be used as validation set

    Raises:
        ValueError: if the dataset exposes `not_aug_transform` but its samples are
            neither 2- nor 3-dimensional, so buffer examples cannot be converted
            back to the dataset storage format.
    """
    if not self.current_task > 0:
        return

    train_set = dataset.train_loader.dataset
    n_buffer = len(self.buffer)

    # Random draws are issued in this exact order (buffer mask, then permutation)
    # so that seeded runs stay reproducible.
    buff_val_mask = torch.rand(n_buffer) < val_set_split
    val_train_mask = torch.zeros(len(train_set.data)).bool()
    # move as many training samples to the validation set as buffer samples
    val_train_mask[torch.randperm(len(train_set.data))[:int(buff_val_mask.sum())]] = True

    # Select how buffer examples are converted back to the dataset storage format.
    # This is resolved BEFORE touching the dataset, so an unsupported layout
    # cannot leave targets and data out of sync.
    if not hasattr(train_set, 'not_aug_transform'):
        def refold_transform(x): return x.cpu()
    else:
        data_ndim = len(train_set.data[0].shape)
        if data_ndim == 3:
            # presumably float images in [0, 1] stored channel-first in the buffer and
            # uint8 channel-last numpy images in the dataset -- TODO confirm
            def refold_transform(x): return (x.cpu() * 255).permute([0, 2, 3, 1]).numpy().astype(np.uint8)
        elif data_ndim == 2:
            # presumably single-channel float images in the buffer and uint8 tensors
            # without channel axis in the dataset -- TODO confirm
            def refold_transform(x): return (x.cpu() * 255).squeeze(1).type(torch.uint8)
        else:
            raise ValueError(
                "icarl_replay supports datasets with 2D or 3D samples, got {}D samples".format(data_ndim))

    if val_set_split > 0:
        # copy taken before the training set is reduced below
        self.val_dataset = deepcopy(train_set)

    data_concatenate = torch.cat if isinstance(train_set.data, torch.Tensor) else np.concatenate

    # only the filled part of the buffer is considered
    buf_labels = self.buffer.labels.cpu().numpy()[:n_buffer]
    buf_examples = (self.buffer.examples)[:n_buffer]

    # REDUCE AND MERGE TRAINING SET
    train_set.targets = np.concatenate([
        train_set.targets[~val_train_mask],
        buf_labels[~buff_val_mask]
    ])
    train_set.data = data_concatenate([
        train_set.data[~val_train_mask],
        refold_transform(buf_examples[~buff_val_mask])
    ])

    if val_set_split > 0:
        # REDUCE AND MERGE VALIDATION SET
        self.val_dataset.targets = np.concatenate([
            self.val_dataset.targets[val_train_mask],
            buf_labels[buff_val_mask]
        ])
        self.val_dataset.data = data_concatenate([
            self.val_dataset.data[val_train_mask],
            refold_transform(buf_examples[buff_val_mask])
        ])

        self.val_loader = torch.utils.data.DataLoader(self.val_dataset, batch_size=self.args.batch_size,
                                                      shuffle=True, num_workers=self.args.num_workers)
def reservoir(num_seen_examples: int, buffer_size: int) -> int:
    """
    Reservoir sampling algorithm.

    Args:
        num_seen_examples: the number of seen examples
        buffer_size: the maximum buffer size

    Returns:
        the target index if the current image is sampled, else -1
    """
    # While the buffer still has free slots, samples are stored in arrival order.
    if num_seen_examples < buffer_size:
        return num_seen_examples

    # Draw a slot uniformly in [0, num_seen_examples]; the sample is kept only
    # if the drawn slot falls inside the buffer.
    drawn_slot = np.random.randint(0, num_seen_examples + 1)
    return drawn_slot if drawn_slot < buffer_size else -1
class Buffer:
    """
    The memory buffer of rehearsal method.

    Each stored attribute (`examples`, `labels`, `logits`, `task_labels`) lives in a
    tensor with `buffer_size` rows that is allocated lazily on the first call to
    `add_data`. Only the first `len(self)` rows hold real samples.
    """

    def __init__(self, buffer_size, device="cpu"):
        """
        Initialize a reservoir-based Buffer object.

        Args:
            buffer_size (int): The maximum size of the buffer.
            device (str, optional): The device to store the buffer on. Defaults to "cpu".

        Note:
            If during the `get_data` the transform is PIL, data will be moved to cpu and then back to the device. This is why the device is set to cpu by default.
        """
        self.buffer_size = buffer_size
        self.device = device
        self.num_seen_examples = 0
        # names of the tensors the buffer may hold; order defines the order of returned tuples
        self.attributes = ['examples', 'labels', 'logits', 'task_labels']
        # one (optional) list of attention maps per buffer slot
        self.attention_maps = [None] * buffer_size

    @staticmethod
    def _select(tensor, indexes):
        """
        Index `tensor` along its first dimension.

        Tensor indexes are first moved to the device of `tensor`, so that indexes
        living on another device can be used safely.
        """
        if torch.is_tensor(indexes):
            indexes = indexes.to(tensor.device)
        return tensor[indexes]

    def to(self, device):
        """
        Move the buffer and its attributes to the specified device.

        Args:
            device: The device to move the buffer and its attributes to.

        Returns:
            The buffer instance with the updated device and attributes.
        """
        self.device = device
        # NOTE(review): `attention_maps` are not moved here -- confirm this is intended
        for attr_str in self.used_attributes:
            setattr(self, attr_str, getattr(self, attr_str).to(device))
        return self

    def __len__(self):
        """
        Returns the number items in the buffer.
        """
        return min(self.num_seen_examples, self.buffer_size)

    def init_tensors(self, examples: torch.Tensor, labels: torch.Tensor,
                     logits: torch.Tensor, task_labels: torch.Tensor) -> None:
        """
        Initializes just the required tensors.

        A tensor is allocated (filled with zeros) only for the arguments that are
        not None and that have not been allocated yet.

        Args:
            examples: tensor containing the images
            labels: tensor containing the labels
            logits: tensor containing the outputs of the network
            task_labels: tensor containing the task labels
        """
        provided = {
            'examples': examples,
            'labels': labels,
            'logits': logits,
            'task_labels': task_labels,
        }
        for attr_str in self.attributes:
            attr = provided.get(attr_str)
            if attr is not None and not hasattr(self, attr_str):
                # `labels` and `task_labels` are integer tensors, everything else is float
                typ = torch.int64 if attr_str.endswith('els') else torch.float32
                setattr(self, attr_str, torch.zeros((self.buffer_size, *attr.shape[1:]),
                                                    dtype=typ, device=self.device))

    @property
    def used_attributes(self):
        """
        Returns a list of attributes that are currently being used by the object.
        """
        return [attr_str for attr_str in self.attributes if hasattr(self, attr_str)]

    def add_data(self, examples, labels=None, logits=None, task_labels=None, attention_maps=None):
        """
        Adds the data to the memory buffer according to the reservoir strategy.

        Args:
            examples: tensor containing the images
            labels: tensor containing the labels
            logits: tensor containing the outputs of the network
            task_labels: tensor containing the task labels
            attention_maps: (optional) iterable of tensors indexed by sample; the maps of
                each stored sample are converted to bytes and kept in `self.attention_maps`

        Note:
            Only the examples are required. The other tensors are initialized only if they are provided.
        """
        if not hasattr(self, 'examples'):
            self.init_tensors(examples, labels, logits, task_labels)

        for i in range(examples.shape[0]):
            index = reservoir(self.num_seen_examples, self.buffer_size)
            self.num_seen_examples += 1
            if index < 0:
                # the sample was not selected by the reservoir strategy
                continue
            self.examples[index] = examples[i].to(self.device)
            if labels is not None:
                self.labels[index] = labels[i].to(self.device)
            if logits is not None:
                self.logits[index] = logits[i].to(self.device)
            if task_labels is not None:
                self.task_labels[index] = task_labels[i].to(self.device)
            if attention_maps is not None:
                self.attention_maps[index] = [at[i].byte().to(self.device) for at in attention_maps]

    def get_data(self, size: int, transform: nn.Module = None, return_index=False, device=None,
                 mask_task_out=None, cpt=None) -> Tuple:
        """
        Random samples a batch of size items.

        Only the filled part of the buffer is sampled. If fewer than `size` items are
        available, all the available items are returned (in random order).

        Args:
            size: the number of requested items
            transform: the transformation to be applied (data augmentation)
            return_index: if True, returns the indexes of the sampled items
            device: the device of the returned tensors (defaults to the buffer device)
            mask_task_out: if not None, masks OUT the examples from the given task
            cpt: the number of classes per task (required if mask_task_out is not None and task_labels are not present)

        Returns:
            a tuple containing the requested items. If return_index is True, the tuple contains the indexes as first element.
            The indexes always refer to positions in the buffer, also when `mask_task_out` is used.
        """
        target_device = self.device if device is None else device

        # number of slots that actually contain a sample
        n_filled = min(self.num_seen_examples, self.examples.shape[0])

        if mask_task_out is not None:
            assert hasattr(self, 'task_labels') or cpt is not None
            assert hasattr(self, 'task_labels') or hasattr(self, 'labels')
            if hasattr(self, 'task_labels'):
                sample_tasks = self.task_labels[:n_filled]
            else:
                sample_tasks = self.labels[:n_filled] // cpt
            # Buffer positions of the samples that may be drawn. Restricting to the
            # filled slots prevents zero-initialised rows from being sampled.
            valid_idx = torch.nonzero(sample_tasks != mask_task_out).squeeze(1)
            num_avail_samples = valid_idx.shape[0]
        else:
            num_avail_samples = n_filled

        size = min(size, num_avail_samples)

        choice = np.random.choice(num_avail_samples, size=size, replace=False)
        if transform is None:
            def transform(x): return x

        # map the drawn positions to buffer indexes
        picked = choice if mask_task_out is None else valid_idx[choice]

        selected_samples = self.examples[picked]
        # the transform is applied sample by sample on cpu (see the note in `__init__`)
        ret_tuple = (torch.stack([transform(ee) for ee in selected_samples.cpu()]).to(target_device),)
        for attr_str in self.attributes[1:]:
            if hasattr(self, attr_str):
                ret_tuple += (getattr(self, attr_str)[picked].to(target_device),)

        if not return_index:
            return ret_tuple

        index_tensor = torch.tensor(choice) if mask_task_out is None else picked
        return (index_tensor.to(target_device), ) + ret_tuple

    def get_data_by_index(self, indexes, transform: nn.Module = None, device=None) -> Tuple:
        """
        Returns the data by the given index.

        Args:
            indexes: the indexes of the items
            transform: the transformation to be applied (data augmentation)
            device: the device of the returned tensors (defaults to the buffer device)

        Returns:
            a tuple containing the requested items. The returned items depend on the attributes stored in the buffer from previous calls to `add_data`.
        """
        target_device = self.device if device is None else device

        if transform is None:
            def transform(x): return x

        # examples are selected with the same indexes used for the other attributes
        ret_tuple = (apply_transform(self._select(self.examples, indexes), transform=transform).to(target_device),)
        for attr_str in self.attributes[1:]:
            if hasattr(self, attr_str):
                ret_tuple += (self._select(getattr(self, attr_str), indexes).to(target_device),)
        return ret_tuple

    def is_empty(self) -> bool:
        """
        Returns true if the buffer is empty, false otherwise.
        """
        return self.num_seen_examples == 0

    def get_all_data(self, transform: nn.Module = None, device=None) -> Tuple:
        """
        Return all the items in the memory buffer.

        Args:
            transform: the transformation to be applied (data augmentation)
            device: the device of the returned tensors (defaults to the buffer device)

        Returns:
            a tuple with all the items in the memory buffer
        """
        target_device = self.device if device is None else device
        if transform is None:
            def transform(x): return x

        n_filled = len(self)
        ret_tuple = (apply_transform(self.examples[:n_filled], transform=transform).to(target_device),)
        for attr_str in self.attributes[1:]:
            if hasattr(self, attr_str):
                ret_tuple += (getattr(self, attr_str)[:n_filled].to(target_device),)
        return ret_tuple

    def empty(self) -> None:
        """
        Remove all the stored tensors and reset the counter of seen examples.
        """
        # NOTE(review): `attention_maps` is left untouched here -- confirm this is intended
        for attr_str in self.used_attributes:
            delattr(self, attr_str)
        self.num_seen_examples = 0
@torch.no_grad()
def fill_buffer(buffer: Buffer, dataset: ContinualDataset, t_idx: int, net: ContinualModel = None, use_herding=False,
                required_attributes: List[str] = None) -> None:
    """
    Adds examples from the current task to the memory buffer.
    Supports images, labels, task_labels, and logits.

    The buffer is split evenly among the classes seen so far: samples of previous
    classes already in the buffer are subsampled, then the samples of the current
    task are added.

    Args:
        buffer: the memory buffer
        dataset: the dataset from which take the examples
        t_idx: the task index
        net: (optional) the model instance. Used if logits are in buffer. If provided, adds logits.
        use_herding: (optional) if True, uses herding strategy. Otherwise, random sampling.
        required_attributes: (optional) the attributes to be added to the buffer. If None and buffer is empty, adds only examples and labels.
    """
    if net is not None:
        # remember the train/eval state so that it can be restored at the end
        mode = net.training
        net.eval()
    else:
        assert not use_herding, "Herding strategy requires a model instance"

    device = net.device if net is not None else get_device()

    # N_CLASSES_PER_TASK is either a single int or a per-task sequence
    classes_per_task = dataset.N_CLASSES_PER_TASK
    if isinstance(classes_per_task, int):
        n_seen_classes = classes_per_task * (t_idx + 1)
        n_past_classes = classes_per_task * t_idx
    else:
        n_seen_classes = sum(classes_per_task[:t_idx + 1])
        n_past_classes = sum(classes_per_task[:t_idx])

    samples_per_class = buffer.buffer_size // n_seen_classes

    # Check for required attributes
    required_attributes = required_attributes or ['examples', 'labels']
    missing_attributes = [attr for attr in required_attributes if attr not in buffer.used_attributes]
    assert not missing_attributes or len(buffer) == 0, \
        "Required attributes not in buffer: {}".format(missing_attributes)

    if t_idx > 0:
        # 1) First, subsample prior classes
        # `get_all_data` returns the tensors in the order of the attributes stored in
        # the buffer, so names are paired through `used_attributes` (read before
        # `empty()` removes them) rather than through `required_attributes`.
        stored_attributes = buffer.used_attributes
        stored_data = dict(zip(stored_attributes, buffer.get_all_data()))
        buf_y = stored_data['labels']

        buffer.empty()
        for cls in buf_y.unique():
            cls_mask = (buf_y == cls)
            kept = {attr_name: values[cls_mask][:samples_per_class]
                    for attr_name, values in stored_data.items() if attr_name in required_attributes}
            buffer.add_data(**kept)

    # 2) Then, fill with current tasks
    loader = dataset.train_loader
    norm_trans = dataset.get_normalization_transform()
    if norm_trans is None:
        def norm_trans(x): return x

    if 'logits' in buffer.used_attributes or 'logits' in required_attributes:
        assert net is not None, "Logits in buffer require a model instance"

    # 2.1 Extract all features
    a_x, a_y, a_f, a_l = [], [], [], []
    for x, y, not_norm_x in loader:
        # keep only the samples of the classes of the current task
        mask = (y >= n_past_classes) & (y < n_seen_classes)
        x, y, not_norm_x = x[mask], y[mask], not_norm_x[mask]
        if not x.size(0):
            continue
        # the non-normalized inputs are the ones stored in the buffer
        a_x.append(not_norm_x.cpu())
        a_y.append(y.cpu())

        if net is not None:
            feats = net(norm_trans(not_norm_x.to(device)), returnt='features')
            outs = net.classifier(feats)
            a_f.append(feats.cpu())
            a_l.append(torch.sigmoid(outs).cpu())
    a_x, a_y = torch.cat(a_x), torch.cat(a_y)
    if net is not None:
        a_f, a_l = torch.cat(a_f), torch.cat(a_l)

    store_logits = 'logits' in required_attributes
    store_task_labels = 'task_labels' in required_attributes

    # 2.2 Compute class means
    for cls in a_y.unique():
        cls_mask = (a_y == cls)
        _x, _y = a_x[cls_mask], a_y[cls_mask]
        # network outputs of this class: needed by both sampling strategies
        _l = a_l[cls_mask] if net is not None else None

        if use_herding:
            feats = a_f[cls_mask]
            mean_feat = feats.mean(0, keepdim=True)

            running_sum = torch.zeros_like(mean_feat)
            i = 0
            while i < samples_per_class and i < feats.shape[0]:
                # distance between the class mean and the mean of the samples selected
                # so far plus each candidate
                cost = (mean_feat - (feats + running_sum) / (i + 1)).norm(2, 1)

                idx_min = cost.argmin().item()
                picked = slice(idx_min, idx_min + 1)

                buffer.add_data(
                    examples=_x[picked].to(device),
                    labels=_y[picked].to(device),
                    logits=_l[picked].to(device) if store_logits else None,
                    task_labels=torch.ones(len(_x[picked])).to(device) * t_idx if store_task_labels else None
                )

                running_sum += feats[picked]
                # push the selected sample far away so that it is not selected again
                feats[idx_min] = feats[idx_min] + 1e6
                i += 1
        else:
            picked = torch.randperm(len(_x))[:samples_per_class]

            buffer.add_data(
                examples=_x[picked].to(device),
                labels=_y[picked].to(device),
                logits=_l[picked].to(device) if store_logits else None,
                task_labels=torch.ones(len(_x[picked])).to(device) * t_idx if store_task_labels else None
            )

    assert len(buffer.examples) <= buffer.buffer_size
    assert buffer.num_seen_examples <= buffer.buffer_size

    if net is not None:
        net.train(mode)