vision_transformer

Vision Transformer (ViT) in PyTorch

A clone of timm’s ViT implementation, extended with a DualPrompt implementation.

Copyright 2020, Ross Wightman

Modification: added code for the DualPrompt implementation (Jaeho Lee, dlwogh9344@khu.ac.kr).

Classes

class models.dualprompt_utils.vision_transformer.Attention(dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0)

Bases: Module

forward(x, *args)
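The constructor arguments above match a standard multi-head self-attention layer. The following is a minimal sketch of such a layer, not the module's actual source: the class name `AttentionSketch` is hypothetical, and ignoring the extra positional arguments in `forward` is an assumption (presumably they let a block pass a prompt to any attention layer with the same call).

```python
import torch
import torch.nn as nn


class AttentionSketch(nn.Module):
    """Plain multi-head self-attention; argument names mirror the signature above."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5  # 1 / sqrt(head_dim)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, *args):
        # Assumption: extra positional arguments are accepted and ignored.
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)  # each (B, heads, N, head_dim)
        attn = (q @ k.transpose(-2, -1)) * self.scale   # (B, heads, N, N)
        attn = self.attn_drop(attn.softmax(dim=-1))
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # merge heads back
        return self.proj_drop(self.proj(x))


# The output keeps the input shape (batch, tokens, dim).
out = AttentionSketch(dim=64, num_heads=8)(torch.randn(2, 197, 64))
```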
class models.dualprompt_utils.vision_transformer.Block(dim, num_heads, mlp_ratio=4.0, qkv_bias=False, drop=0.0, attn_drop=0.0, init_values=None, drop_path=0.0, act_layer=<class 'torch.nn.modules.activation.GELU'>, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, attn_layer=<class 'models.dualprompt_utils.vision_transformer.Attention'>)

Bases: Module

forward(x, prompt=None)
class models.dualprompt_utils.vision_transformer.LayerScale(dim, init_values=1e-05, inplace=False)

Bases: Module

forward(x)
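LayerScale, as commonly implemented, multiplies each channel by a learnable scalar initialised to `init_values`. A minimal sketch under that assumption follows; the class name `LayerScaleSketch` and the parameter name `gamma` are illustrative, not confirmed by this page.

```python
import torch
import torch.nn as nn


class LayerScaleSketch(nn.Module):
    """Per-channel learnable scaling; argument names mirror the signature above."""

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        # One learnable scale per channel, all starting at init_values.
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        # Broadcasts gamma over the leading (batch, token) dimensions.
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


x = torch.ones(2, 4, 8)
y = LayerScaleSketch(dim=8, init_values=0.5)(x)  # every element becomes 0.5
```

A small initial value keeps each residual branch close to the identity at the start of training.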
class models.dualprompt_utils.vision_transformer.ParallelBlock(dim, num_heads, num_parallel=2, mlp_ratio=4.0, qkv_bias=False, init_values=None, drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=<class 'torch.nn.modules.activation.GELU'>, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>)

Bases: Module

forward(x)
class models.dualprompt_utils.vision_transformer.ResPostBlock(dim, num_heads, mlp_ratio=4.0, qkv_bias=False, drop=0.0, attn_drop=0.0, init_values=None, drop_path=0.0, act_layer=<class 'torch.nn.modules.activation.GELU'>, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>)

Bases: Module

forward(x)
init_weights()
class models.dualprompt_utils.vision_transformer.VisionTransformer(img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, init_values=None, class_token=True, no_embed_class=False, fc_norm=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, weight_init='', embed_layer=<class 'timm.layers.patch_embed.PatchEmbed'>, norm_layer=None, act_layer=None, block_fn=<class 'models.dualprompt_utils.vision_transformer.Block'>, prompt_length=None, embedding_key='cls', prompt_init='uniform', prompt_pool=False, prompt_key=False, pool_size=None, top_k=None, batchwise_prompt=False, prompt_key_init='uniform', head_type='token', use_prompt_mask=False, use_g_prompt=False, g_prompt_length=None, g_prompt_layer_idx=None, use_prefix_tune_for_g_prompt=False, use_e_prompt=False, e_prompt_layer_idx=None, use_prefix_tune_for_e_prompt=False, same_key_value=False)

Bases: Module

Vision Transformer. A PyTorch implementation of “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale” (https://arxiv.org/abs/2010.11929).
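The default `img_size`, `patch_size`, and `class_token` arguments determine how many tokens the transformer blocks process. A small sketch of that arithmetic, assuming a square image, non-overlapping patches, and no prompt tokens (the helper `num_tokens` is hypothetical and not part of this module):

```python
def num_tokens(img_size=224, patch_size=16, class_token=True):
    """Sequence length for a square image split into non-overlapping patches."""
    grid = img_size // patch_size      # patches per side: 224 // 16 = 14
    num_patches = grid * grid          # 14 * 14 = 196
    return num_patches + (1 if class_token else 0)


print(num_tokens())                    # 197
print(num_tokens(class_token=False))   # 196
```

Any prompt tokens added by the DualPrompt options would lengthen the sequence beyond this count.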

forward(x, task_id=-1, cls_features=None, train=False)
forward_features(x, task_id=-1, cls_features=None, train=False)
forward_head(res, pre_logits=False)
get_classifier()
group_matcher(coarse=False)
init_weights(mode='')
load_pretrained(checkpoint_path, prefix='')
no_weight_decay()
reset_classifier(num_classes, global_pool=None)
set_grad_checkpointing(enable=True)

Functions

models.dualprompt_utils.vision_transformer.checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False)

Convert patch embedding weights from a manual patchify + linear projection layout to a convolution layout.
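The conversion relies on a linear projection over flattened patches being equivalent to a strided convolution whose kernel is the same weight reshaped to 4-D. The sketch below illustrates that equivalence on random data. It assumes each patch is flattened in (channel, row, column) order and uses made-up sizes; it does not reproduce the state-dict handling of `checkpoint_filter_fn` itself.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
embed_dim, in_chans, patch = 8, 3, 4          # illustrative sizes only
x = torch.randn(2, in_chans, 16, 16)

# "Manual patchify + linear projection": a 2-D weight applied to flattened patches.
linear_weight = torch.randn(embed_dim, in_chans * patch * patch)
patches = F.unfold(x, kernel_size=patch, stride=patch)   # (B, C*p*p, num_patches)
out_linear = linear_weight @ patches                      # (B, embed_dim, num_patches)

# Convolution form: the same weight reshaped to (out, in, kH, kW).
conv_weight = linear_weight.reshape(embed_dim, in_chans, patch, patch)
out_conv = F.conv2d(x, conv_weight, stride=patch).flatten(2)  # (B, embed_dim, num_patches)

same = torch.allclose(out_linear, out_conv, atol=1e-4)
```

Here `same` should be true, because both paths compute the same dot products, just arranged differently.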

models.dualprompt_utils.vision_transformer.get_init_weights_vit(mode='jax', head_bias=0.0)
models.dualprompt_utils.vision_transformer.init_weights_vit_jax(module, name='', head_bias=0.0)

ViT weight initialization, matching the JAX (Flax) implementation.

models.dualprompt_utils.vision_transformer.init_weights_vit_moco(module, name='')

ViT weight initialization, matching the MoCo v3 implementation, minus the fixed PatchEmbed.

models.dualprompt_utils.vision_transformer.init_weights_vit_timm(module, name='')

ViT weight initialization, following the original timm implementation (kept for reproducibility).

models.dualprompt_utils.vision_transformer.resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=())
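This page gives no description for `resize_pos_embed`. A common way to adapt position embeddings to a new input resolution is to leave the prefix (class) tokens untouched and interpolate the patch-grid embeddings in 2-D. The sketch below assumes that approach, along with a square source grid and bicubic interpolation. The function `resize_pos_embed_sketch` is hypothetical and takes a simplified argument list (a target grid size rather than `posemb_new` and `gs_new`).

```python
import math

import torch
import torch.nn.functional as F


def resize_pos_embed_sketch(posemb, new_grid, num_prefix_tokens=1):
    """Interpolate (1, prefix + H*W, dim) position embeddings to a new (H', W') grid."""
    prefix = posemb[:, :num_prefix_tokens]        # e.g. the class-token embedding
    grid = posemb[0, num_prefix_tokens:]          # (H*W, dim)
    gs_old = int(math.sqrt(grid.shape[0]))        # assumes a square source grid
    # (H*W, dim) -> (1, dim, H, W) so that F.interpolate treats dim as channels.
    grid = grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=new_grid, mode="bicubic", align_corners=False)
    # Back to (1, H'*W', dim), then re-attach the prefix tokens.
    grid = grid.permute(0, 2, 3, 1).reshape(1, new_grid[0] * new_grid[1], -1)
    return torch.cat([prefix, grid], dim=1)


posemb = torch.randn(1, 1 + 14 * 14, 32)          # 224 / 16 = 14 patches per side
resized = resize_pos_embed_sketch(posemb, new_grid=(24, 24))  # 384 / 16 = 24
```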
models.dualprompt_utils.vision_transformer.vit_base_patch16_224_dualprompt(pretrained=False, **kwargs)

ViT-Base (ViT-B/16) from the original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from ImageNet-21k at 224x224; source: https://github.com/google-research/vision_transformer.