# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from copy import deepcopy
import math
import sys
from argparse import Namespace
from typing import Tuple
import torch
from datasets import get_dataset
from datasets.utils.continual_dataset import ContinualDataset
from datasets.utils.gcl_dataset import GCLDataset
from models.utils.continual_model import ContinualModel
from utils import random_id
from utils.checkpoints import mammoth_load_checkpoint
from utils.loggers import *
from utils.status import ProgressBar
try:
import wandb
except ImportError:
wandb = None
[docs]
def mask_classes(outputs: torch.Tensor, dataset: ContinualDataset, k: int) -> None:
"""
Given the output tensor, the dataset at hand and the current task,
masks the former by setting the responses for the other tasks at -inf.
It is used to obtain the results for the task-il setting.
Args:
outputs: the output tensor
dataset: the continual dataset
k: the task index
"""
outputs[:, 0:k * dataset.N_CLASSES_PER_TASK] = -float('inf')
outputs[:, (k + 1) * dataset.N_CLASSES_PER_TASK:
dataset.N_TASKS * dataset.N_CLASSES_PER_TASK] = -float('inf')
[docs]
@torch.no_grad()
def evaluate(model: ContinualModel, dataset: ContinualDataset, last=False) -> Tuple[list, list]:
"""
Evaluates the accuracy of the model for each past task.
The accuracy is evaluated for all the tasks up to the current one, only for the total number of classes seen so far.
Args:
model: the model to be evaluated
dataset: the continual dataset at hand
Returns:
a tuple of lists, containing the class-il and task-il accuracy for each task
"""
status = model.net.training
model.net.eval()
accs, accs_mask_classes = [], []
n_classes = dataset.get_offsets()[1]
for k, test_loader in enumerate(dataset.test_loaders):
if last and k < len(dataset.test_loaders) - 1:
continue
correct, correct_mask_classes, total = 0.0, 0.0, 0.0
test_iter = iter(test_loader)
i = 0
while True:
try:
data = next(test_iter)
except StopIteration:
break
if model.args.debug_mode and i > model.get_debug_iters():
break
inputs, labels = data
inputs, labels = inputs.to(model.device), labels.to(model.device)
if 'class-il' not in model.COMPATIBILITY and 'general-continual' not in model.COMPATIBILITY:
outputs = model(inputs, k)
else:
outputs = model(inputs)
_, pred = torch.max(outputs[:, :n_classes].data, 1)
correct += torch.sum(pred == labels).item()
total += labels.shape[0]
i += 1
if dataset.SETTING == 'class-il':
mask_classes(outputs, dataset, k)
_, pred = torch.max(outputs.data, 1)
correct_mask_classes += torch.sum(pred == labels).item()
accs.append(correct / total * 100
if 'class-il' in model.COMPATIBILITY or 'general-continual' in model.COMPATIBILITY else 0)
accs_mask_classes.append(correct_mask_classes / total * 100)
model.net.train(status)
return accs, accs_mask_classes
[docs]
def initialize_wandb(args: Namespace) -> None:
"""
Initializes wandb, if installed.
Args:
args: the arguments of the current execution
"""
assert wandb is not None, "Wandb not installed, please install it or run without wandb"
run_name = args.wandb_name if args.wandb_name is not None else args.model
run_id = random_id(5)
name = f'{run_name}_{run_id}'
wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=vars(args), name=name)
args.wandb_url = wandb.run.get_url()
[docs]
def train(model: ContinualModel, dataset: ContinualDataset,
args: Namespace) -> None:
"""
The training process, including evaluations and loggers.
Args:
model: the module to be trained
dataset: the continual dataset at hand
args: the arguments of the current execution
"""
print(args)
if not args.nowand:
initialize_wandb(args)
model.net.to(model.device)
results, results_mask_classes = [], []
if not args.disable_log:
logger = Logger(dataset.SETTING, dataset.NAME, model.NAME)
if args.start_from is not None:
for i in range(args.start_from):
train_loader, _ = dataset.get_data_loaders()
model.meta_begin_task(dataset)
model.meta_end_task(dataset)
if args.loadcheck is not None:
model, past_res = mammoth_load_checkpoint(args, model)
if not args.disable_log and past_res is not None:
(results, results_mask_classes, csvdump) = past_res
logger.load(csvdump)
print('Checkpoint Loaded!')
progress_bar = ProgressBar(joint=args.joint, verbose=not args.non_verbose)
if args.enable_other_metrics:
dataset_copy = get_dataset(args)
for t in range(dataset.N_TASKS):
model.net.train()
_, _ = dataset_copy.get_data_loaders()
if model.NAME != 'icarl' and model.NAME != 'pnn':
random_results_class, random_results_task = evaluate(model, dataset_copy)
print(file=sys.stderr)
start_task = 0 if args.start_from is None else args.start_from
end_task = dataset.N_TASKS if args.stop_after is None else args.stop_after
torch.cuda.empty_cache()
for t in range(start_task, end_task):
model.net.train()
train_loader, test_loader = dataset.get_data_loaders()
model.meta_begin_task(dataset)
if not args.inference_only:
if t and args.enable_other_metrics:
accs = evaluate(model, dataset, last=True)
results[t - 1] = results[t - 1] + accs[0]
if dataset.SETTING == 'class-il':
results_mask_classes[t - 1] = results_mask_classes[t - 1] + accs[1]
scheduler = dataset.get_scheduler(model, args) if not hasattr(model, 'scheduler') else model.scheduler
for epoch in range(model.args.n_epochs):
train_iter = iter(train_loader)
data_len = None
if not isinstance(dataset, GCLDataset):
data_len = len(train_loader)
i = 0
while True:
try:
data = next(train_iter)
except StopIteration:
break
if args.debug_mode and i > model.get_debug_iters():
break
if hasattr(dataset.train_loader.dataset, 'logits'):
inputs, labels, not_aug_inputs, logits = data
inputs = inputs.to(model.device)
labels = labels.to(model.device, dtype=torch.long)
not_aug_inputs = not_aug_inputs.to(model.device)
logits = logits.to(model.device)
loss = model.meta_observe(inputs, labels, not_aug_inputs, logits, epoch=epoch)
else:
inputs, labels, not_aug_inputs = data
inputs, labels = inputs.to(model.device), labels.to(model.device, dtype=torch.long)
not_aug_inputs = not_aug_inputs.to(model.device)
loss = model.meta_observe(inputs, labels, not_aug_inputs, epoch=epoch)
assert not math.isnan(loss)
progress_bar.prog(i, data_len, epoch, t, loss)
i += 1
if scheduler is not None:
scheduler.step()
if args.eval_epochs is not None and epoch % args.eval_epochs == 0 and epoch < model.args.n_epochs - 1:
epoch_accs = evaluate(model, dataset)
log_accs(args, logger, epoch_accs, t, dataset.SETTING, epoch=epoch)
model.meta_end_task(dataset)
accs = evaluate(model, dataset)
results.append(accs[0])
results_mask_classes.append(accs[1])
log_accs(args, logger, accs, t, dataset.SETTING)
if args.savecheck:
save_obj = {
'model': model.state_dict(),
'args': args,
'results': [results, results_mask_classes, logger.dump()],
'optimizer': model.opt.state_dict() if hasattr(model, 'opt') else None,
'scheduler': scheduler.state_dict() if scheduler is not None else None,
}
if 'buffer_size' in model.args:
save_obj['buffer'] = deepcopy(model.buffer).to('cpu')
# Saving model checkpoint
checkpoint_name = f'checkpoints/{args.ckpt_name}_joint.pt' if args.joint else f'checkpoints/{args.ckpt_name}_{t}.pt'
torch.save(save_obj, checkpoint_name)
if args.validation:
del dataset
args.validation = None
final_dataset = get_dataset(args)
for _ in range(final_dataset.N_TASKS):
final_dataset.get_data_loaders()
accs = evaluate(model, final_dataset)
log_accs(args, logger, accs, t, final_dataset.SETTING, prefix="FINAL")
if not args.disable_log and args.enable_other_metrics:
logger.add_bwt(results, results_mask_classes)
logger.add_forgetting(results, results_mask_classes)
if model.NAME != 'icarl' and model.NAME != 'pnn':
logger.add_fwt(results, random_results_class,
results_mask_classes, random_results_task)
if not args.disable_log:
logger.write(vars(args))
if not args.nowand:
d = logger.dump()
d['wandb_url'] = wandb.run.get_url()
wandb.log(d)
if not args.nowand:
wandb.finish()