vits
Vision Transformer (ViT) in PyTorch
A PyTorch implementation of Vision Transformers as described in:
- ‘An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale’
- How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers
The official JAX code is available at https://github.com/google-research/vision_transformer
DeiT model definitions and weights are from https://github.com/facebookresearch/deit; see the paper DeiT: Data-efficient Image Transformers - https://arxiv.org/abs/2012.12877
Acknowledgments:
- The paper authors for releasing code and weights, thanks!
- I fixed my class token impl based on Phil Wang’s https://github.com/lucidrains/vit-pytorch … check it out for some einops/einsum fun
- Simple transformer style inspired by Andrej Karpathy’s https://github.com/karpathy/minGPT
- Bert reference code checks against Huggingface Transformers and Tensorflow Bert
Hacked together by / Copyright 2020, Ross Wightman
Classes
- class models.slca_utils.convs.vits.Attention(dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0)
Bases: Module
- class models.slca_utils.convs.vits.Block(dim, num_heads, mlp_ratio=4.0, qkv_bias=False, drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=<class 'torch.nn.modules.activation.GELU'>, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>)
Bases: Module
- class models.slca_utils.convs.vits.VisionTransformer(img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, representation_size=None, distilled=False, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, embed_layer=<class 'timm.layers.patch_embed.PatchEmbed'>, norm_layer=None, act_layer=None, weight_init='', with_adapter=False, global_pool=False)
Bases: Module
Vision Transformer
- A PyTorch impl of An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
- Includes distillation token & head support for DeiT: Data-efficient Image Transformers
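To make the constructor arguments above more concrete, here is a minimal sketch of multi-head self-attention and a pre-norm transformer block, followed by the token-count arithmetic for the default `img_size=224`, `patch_size=16`. The names `SketchAttention` and `SketchBlock` are hypothetical and are not part of this module; the sketch follows the common ViT layout and leaves out drop path, MLP dropout, and the adapter, so the actual classes may differ in detail.

```python
import torch
import torch.nn as nn


class SketchAttention(nn.Module):
    """Hypothetical stand-in that mirrors the Attention constructor arguments."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5  # 1 / sqrt(head_dim)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)  # fused q, k, v projection
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)  # each: (B, heads, N, head_dim)
        attn = (q @ k.transpose(-2, -1)) * self.scale   # (B, heads, N, N)
        attn = self.attn_drop(attn.softmax(dim=-1))
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # merge heads back to dim
        return self.proj_drop(self.proj(x))


class SketchBlock(nn.Module):
    """Hypothetical pre-norm block: x + attn(norm(x)), then x + mlp(norm(x))."""

    def __init__(self, dim, num_heads, mlp_ratio=4.0, qkv_bias=False):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = SketchAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
        self.norm2 = nn.LayerNorm(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        return x + self.mlp(self.norm2(x))


# Token count for the default VisionTransformer arguments.
img_size, patch_size, embed_dim, num_heads = 224, 16, 768, 12
num_patches = (img_size // patch_size) ** 2  # 14 * 14 patches
num_tokens = num_patches + 1                 # plus one class token (no distillation token)

tokens = torch.randn(2, num_tokens, embed_dim)
out = SketchBlock(embed_dim, num_heads, qkv_bias=True)(tokens)
print(num_patches, num_tokens, tuple(out.shape))
```

The block preserves the `(batch, tokens, embed_dim)` shape, which is why `depth` blocks can be stacked without any reshaping in between.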
Functions
- models.slca_utils.convs.vits.checkpoint_filter_fn(state_dict, model)
Convert patch embedding weight from manual patchify + linear proj to conv.
- models.slca_utils.convs.vits.vit_base_patch16_224_in21k(pretrained=False, adapter=False, **kwargs)
ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. NOTE: this model has a valid 21k classifier head and no representation (pre-logits) layer.
- models.slca_utils.convs.vits.vit_base_patch16_224_mocov3(pretrained=False, adapter=False, **kwargs)
ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. NOTE: this model has a valid 21k classifier head and no representation (pre-logits) layer.
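The conversion that `checkpoint_filter_fn` describes (manual patchify + linear projection to conv) can be illustrated with a small sketch. It assumes the old-style checkpoint stores a 2-D weight of shape `(embed_dim, in_chans * patch * patch)` applied to patches flattened in channel-major order; all variable names and sizes here are illustrative, and the real function may handle further state-dict keys.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
embed_dim, in_chans, patch = 8, 3, 4  # small illustrative sizes

# Old-style parameters: a linear projection over flattened patches.
linear_weight = torch.randn(embed_dim, in_chans * patch * patch)
bias = torch.randn(embed_dim)
images = torch.randn(2, in_chans, 16, 16)

# Manual patchify: unfold yields (B, C*P*P, num_patches) in channel-major order.
patches = F.unfold(images, kernel_size=patch, stride=patch).transpose(1, 2)
linear_out = patches @ linear_weight.t() + bias  # (B, num_patches, embed_dim)

# Conversion: reshape the 2-D weight into a 4-D conv kernel (O, C, P, P).
conv_weight = linear_weight.reshape(embed_dim, in_chans, patch, patch)
conv_out = F.conv2d(images, conv_weight, bias, stride=patch)  # (B, embed_dim, H/P, W/P)
conv_out = conv_out.flatten(2).transpose(1, 2)                # (B, num_patches, embed_dim)

# Both paths should agree up to floating-point error.
max_diff = (linear_out - conv_out).abs().max().item()
print(tuple(conv_out.shape), max_diff)
```

A strided convolution with kernel size equal to the patch size computes the same dot products as the linear layer, so the reshape is enough as long as the flattening order matches.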