Source code for models.coda_prompt_utils.model

import torch
import torch.nn as nn
from timm.models.vision_transformer import vit_base_patch16_224
from models.coda_prompt_utils.vit import VisionTransformer
import copy


[docs] class CodaPrompt(nn.Module): def __init__(self, emb_d, n_tasks, prompt_param, key_dim=768): super().__init__() self.task_count = 0 self.emb_d = emb_d self.key_d = key_dim self.n_tasks = n_tasks self._init_smart(emb_d, prompt_param) # e prompt init for e in self.e_layers: # for model saving/loading simplicity, we init the full paramaters here # however, please note that we reinit the new components at each task # in the "spirit of continual learning", as we don't know how many tasks # we will encounter at the start of the task sequence # # in the original paper, we used ortho init at the start - this modification is more # fair in the spirit of continual learning and has little affect on performance e_l = self.e_p_length p = tensor_prompt(self.e_pool_size, e_l, emb_d) k = tensor_prompt(self.e_pool_size, self.key_d) a = tensor_prompt(self.e_pool_size, self.key_d) p = self.gram_schmidt(p) k = self.gram_schmidt(k) a = self.gram_schmidt(a) setattr(self, f'e_p_{e}', p) setattr(self, f'e_k_{e}', k) setattr(self, f'e_a_{e}', a) def _init_smart(self, emb_d, prompt_param): # prompt basic param self.e_pool_size = int(prompt_param[0]) self.e_p_length = int(prompt_param[1]) self.e_layers = [0, 1, 2, 3, 4] # strenth of ortho penalty self.ortho_mu = prompt_param[2]
[docs] def process_task_count(self): self.task_count += 1 # in the spirit of continual learning, we will reinit the new components # for the new task with Gram Schmidt # # in the original paper, we used ortho init at the start - this modification is more # fair in the spirit of continual learning and has little affect on performance # # code for this function is modified from: # https://github.com/legendongary/pytorch-gram-schmidt/blob/master/gram_schmidt.py for e in self.e_layers: K = getattr(self, f'e_k_{e}') A = getattr(self, f'e_a_{e}') P = getattr(self, f'e_p_{e}') k = self.gram_schmidt(K) a = self.gram_schmidt(A) p = self.gram_schmidt(P) setattr(self, f'e_p_{e}', p) setattr(self, f'e_k_{e}', k) setattr(self, f'e_a_{e}', a)
# code for this function is modified from: # https://github.com/legendongary/pytorch-gram-schmidt/blob/master/gram_schmidt.py
[docs] def gram_schmidt(self, vv): def projection(u, v): denominator = (u * u).sum() if denominator < 1e-8: return None else: return (v * u).sum() / denominator * u # check if the tensor is 3D and flatten the last two dimensions if necessary is_3d = len(vv.shape) == 3 if is_3d: shape_2d = copy.deepcopy(vv.shape) vv = vv.view(vv.shape[0], -1) # swap rows and columns vv = vv.T # process matrix size nk = vv.size(1) uu = torch.zeros_like(vv, device=vv.device) # get starting point pt = int(self.e_pool_size / (self.n_tasks)) s = int(self.task_count * pt) f = int((self.task_count + 1) * pt) if s > 0: uu[:, 0:s] = vv[:, 0:s].clone() for k in range(s, f): redo = True while redo: redo = False vk = torch.randn_like(vv[:, k]).to(vv.device) uk = 0 for j in range(0, k): if not redo: uj = uu[:, j].clone() proj = projection(uj, vk) if proj is None: redo = True print('restarting!!!') else: uk = uk + proj if not redo: uu[:, k] = vk - uk for k in range(s, f): uk = uu[:, k].clone() uu[:, k] = uk / (uk.norm()) # undo swapping of rows and columns uu = uu.T # return from 2D if is_3d: uu = uu.view(shape_2d) return torch.nn.Parameter(uu)
[docs] def forward(self, x_querry, l, x_block, train=False, task_id=None): # e prompts e_valid = False if l in self.e_layers: e_valid = True B, C = x_querry.shape K = getattr(self, f'e_k_{l}') A = getattr(self, f'e_a_{l}') p = getattr(self, f'e_p_{l}') pt = int(self.e_pool_size / (self.n_tasks)) s = int(self.task_count * pt) f = int((self.task_count + 1) * pt) # freeze/control past tasks if train: if self.task_count > 0: K = torch.cat((K[:s].detach().clone(), K[s:f]), dim=0) A = torch.cat((A[:s].detach().clone(), A[s:f]), dim=0) p = torch.cat((p[:s].detach().clone(), p[s:f]), dim=0) else: K = K[s:f] A = A[s:f] p = p[s:f] else: K = K[0:f] A = A[0:f] p = p[0:f] # with attention and cosine sim # (b x 1 x d) * soft([1 x k x d]) = (b x k x d) -> attention = k x d a_querry = torch.einsum('bd,kd->bkd', x_querry, A) # # (b x k x d) - [1 x k x d] = (b x k) -> key = k x d n_K = nn.functional.normalize(K, dim=1) q = nn.functional.normalize(a_querry, dim=2) aq_k = torch.einsum('bkd,kd->bk', q, n_K) # (b x 1 x k x 1) * [1 x plen x k x d] = (b x plen x d) -> prompt = plen x k x d P_ = torch.einsum('bk,kld->bld', aq_k, p) # select prompts i = int(self.e_p_length / 2) Ek = P_[:, :i, :] Ev = P_[:, i:, :] # ortho penalty if train and self.ortho_mu > 0: loss = ortho_penalty(K) * self.ortho_mu loss += ortho_penalty(A) * self.ortho_mu loss += ortho_penalty(p.view(p.shape[0], -1)) * self.ortho_mu else: loss = 0 else: loss = 0 # combine prompts for prefix tuning if e_valid: p_return = [Ek, Ev] else: p_return = None # return return p_return, loss, x_block
[docs] def ortho_penalty(t): return ((t @ t.T - torch.eye(t.shape[0]).to(t.device))**2).mean()
[docs] def tensor_prompt(a, b, c=None, ortho=False): if c is None: p = torch.nn.Parameter(torch.FloatTensor(a, b), requires_grad=True) else: p = torch.nn.Parameter(torch.FloatTensor(a, b, c), requires_grad=True) if ortho: nn.init.orthogonal_(p) else: nn.init.uniform_(p) return p
[docs] class Model(nn.Module): def __init__(self, num_classes=10, pt=False, prompt_param=None): super().__init__() self.task_id = None # get feature encoder vit_model = VisionTransformer(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12, ckpt_layer=0, drop_path_rate=0) if pt: load_dict = vit_base_patch16_224(pretrained=True).state_dict() del load_dict['head.weight'] del load_dict['head.bias'] vit_model.load_state_dict(load_dict) # classifier self.last = nn.Linear(768, num_classes) self.prompt = CodaPrompt(768, prompt_param[0], prompt_param[1]) # feature encoder changes if transformer vs resnet self.feat = vit_model # pen: get penultimate features
[docs] def forward(self, x, pen=False, train=False): if self.prompt is not None: with torch.no_grad(): q, _ = self.feat(x) q = q[:, 0, :] out, prompt_loss = self.feat(x, prompt=self.prompt, q=q, train=train, task_id=self.task_id) out = out[:, 0, :] else: out, _ = self.feat(x) out = out[:, 0, :] out = out.view(out.size(0), -1) if not pen: out = self.last(out) if self.prompt is not None and train: return out, prompt_loss else: return out