import copy
import torch
from torch import nn
from models.slca_utils.convs.cifar_resnet import resnet32
from models.slca_utils.convs.resnet import resnet18, resnet34, resnet50
from models.slca_utils.convs.linears import SimpleContinualLinear
from models.slca_utils.convs.vits import vit_base_patch16_224_in21k, vit_base_patch16_224_mocov3
import torch.nn.functional as F
def get_convnet(feature_extractor_type, pretrained=False):
    """Build the backbone network named by ``feature_extractor_type``.

    The name lookup is case-insensitive.

    Args:
        feature_extractor_type: backbone identifier, e.g. ``'resnet18'``,
            ``'resnet18_cifar'`` or ``'vit-b-p16'``.
        pretrained: forwarded to the backbone constructor where supported.
            ``'resnet32'`` ignores it, and ``'vit-b-p16-mocov3'`` always
            requests pretrained weights regardless of its value.

    Returns:
        The constructed backbone module.

    Raises:
        NotImplementedError: if the name is not one of the supported backbones.
    """

    def _build_vit_in21k():
        print("Using ViT-B/16 pretrained on ImageNet21k (NO FINETUNE ON IN1K)")
        return vit_base_patch16_224_in21k(pretrained=pretrained)

    # Constructors are wrapped in callables so only the selected backbone is
    # ever instantiated.
    builders = {
        'resnet32': lambda: resnet32(),
        'resnet18': lambda: resnet18(pretrained=pretrained),
        'resnet18_cifar': lambda: resnet18(pretrained=pretrained, cifar=True),
        'resnet18_cifar_cos': lambda: resnet18(
            pretrained=pretrained, cifar=True, no_last_relu=True
        ),
        'resnet34': lambda: resnet34(pretrained=pretrained),
        'resnet50': lambda: resnet50(pretrained=pretrained),
        'vit-b-p16': _build_vit_in21k,
        # NOTE(review): the MoCo v3 variant hard-codes pretrained=True and so
        # ignores the `pretrained` argument — confirm this is intended.
        'vit-b-p16-mocov3': lambda: vit_base_patch16_224_mocov3(pretrained=True),
    }

    builder = builders.get(feature_extractor_type.lower())
    if builder is None:
        # The message reports the name exactly as the caller passed it.
        raise NotImplementedError('Unknown type {}'.format(feature_extractor_type))
    return builder()
class BaseNet(nn.Module):
    """Container pairing a backbone (``convnet``) with a classifier head (``fc``).

    ``convnet`` is built by ``get_convnet`` at construction time. ``fc`` starts
    as ``None`` and has to be assigned, e.g. by a subclass's ``update_fc``,
    before ``forward`` is called.
    """

    def __init__(self, feature_extractor_type, pretrained):
        """Create the backbone; the head is left unset.

        Args:
            feature_extractor_type: backbone name passed to ``get_convnet``.
            pretrained: forwarded to ``get_convnet``.
        """
        super().__init__()
        self.convnet = get_convnet(feature_extractor_type, pretrained)
        self.fc = None

    @property
    def feature_dim(self):
        """Feature dimensionality reported by the backbone (``convnet.out_dim``)."""
        return self.convnet.out_dim

    def forward(self, x):
        """Run the backbone, then the head on the backbone's ``'features'``.

        Args:
            x: input passed directly to ``self.convnet``.

        Returns:
            The head's output dict merged with the backbone's output dict. The
            backbone dict is merged in last, so its keys overwrite same-named
            keys from the head. The result is expected to look like this
            (exact keys depend on the backbone and head implementations)::

                {
                    'fmaps': [x_1, x_2, ..., x_n],
                    'features': features,
                    'logits': logits,
                }
        """
        backbone_out = self.convnet(x)
        out = self.fc(backbone_out['features'])
        out.update(backbone_out)
        return out

    def update_fc(self, nb_classes):
        """Hook for growing the classifier head; a no-op in the base class."""

    def generate_fc(self, in_dim, out_dim):
        """Hook for building a new classifier head; returns ``None`` here."""

    def copy(self):
        """Return an independent deep copy of this network."""
        # `copy` below resolves to the module imported at file level, not to
        # this method (class attributes are not in a method's scope).
        return copy.deepcopy(self)

    def freeze(self):
        """Disable gradients for every parameter and switch to eval mode.

        Returns:
            ``self``, so the call can be chained.
        """
        for param in self.parameters():
            param.requires_grad = False
        self.eval()
        return self
class FinetuneIncrementalNet(BaseNet):
    """Incremental network whose classifier head is extended as classes arrive.

    Besides the live head ``fc`` it can keep a snapshot ``old_fc`` (see
    ``save_old_fc``). In ``fc_only`` forward passes, the snapshot's logits are
    reported alongside the live head's output.
    """

    def __init__(self, feature_extractor_type, pretrained, fc_with_ln=False):
        """Create the backbone; both heads start unset.

        Args:
            feature_extractor_type: backbone name forwarded to ``BaseNet``.
            pretrained: forwarded to ``BaseNet``.
            fc_with_ln: stored on the instance.
        """
        super().__init__(feature_extractor_type, pretrained)
        # NOTE(review): fc_with_ln is stored but not read by any method of
        # this class (generate_fc ignores it) — confirm whether that is intended.
        self.fc_with_ln = fc_with_ln
        self.old_fc = None

    def update_fc(self, nb_classes, freeze_old=True):
        """Create the head on first use, otherwise extend it.

        Args:
            nb_classes: passed to ``generate_fc`` as ``out_dim`` on the first
                call, and to ``fc.update`` on later calls.
            freeze_old: forwarded to ``fc.update``; unused on the first call.
        """
        if self.fc is not None:
            self.fc.update(nb_classes, freeze_old=freeze_old)
            return
        self.fc = self.generate_fc(self.feature_dim, nb_classes)

    def save_old_fc(self):
        """Snapshot the live head into ``old_fc``.

        The first call deep-copies the whole head. Later calls deep-copy only
        the newest entry of ``fc.heads`` and append it to ``old_fc.heads``.
        """
        if self.old_fc is not None:
            newest_head = copy.deepcopy(self.fc.heads[-1])
            self.old_fc.heads.append(newest_head)
            return
        self.old_fc = copy.deepcopy(self.fc)

    def generate_fc(self, in_dim, out_dim):
        """Return a new ``SimpleContinualLinear(in_dim, out_dim)`` head."""
        return SimpleContinualLinear(in_dim, out_dim)

    def forward(self, x, bcb_no_grad=False, fc_only=False):
        """Compute head outputs, optionally bypassing or detaching the backbone.

        Args:
            x: backbone input; when ``fc_only`` is set it is passed directly to
                the head(s) instead (presumably pre-extracted features — verify
                against callers).
            bcb_no_grad: run the backbone under ``torch.no_grad()``. The head
                still runs under the caller's grad mode. Ignored when
                ``fc_only`` is set.
            fc_only: skip the backbone entirely. If ``old_fc`` exists, its
                ``'logits'`` are added to the result under ``'old_logits'``.

        Returns:
            The head's output dict. On the full path it is merged with the
            backbone's output dict, and backbone keys win on collisions.
        """
        if fc_only:
            head_out = self.fc(x)
            if self.old_fc is not None:
                head_out['old_logits'] = self.old_fc(x)['logits']
            return head_out

        if bcb_no_grad:
            with torch.no_grad():
                backbone_out = self.convnet(x)
        else:
            backbone_out = self.convnet(x)

        merged = self.fc(backbone_out['features'])
        merged.update(backbone_out)
        return merged