Vision Transformers (ViT)

> Cut the image into patches, treat each patch as a word, run a standard transformer. Don't look back.

Type: Build

Languages: Python

Prerequisites: Phase 7 Lesson 02 (Self-Attention), Phase 4 Lesson 04 (Image Classification)

Time: ~45 minutes

Learning Objectives

The Problem

For a decade, convolution was synonymous with computer vision. CNNs had strong inductive biases — locality, translation equivariance — that nobody thought you could replace. Then Dosovitskiy et al. (2020) showed that a plain transformer applied to flattened image patches, with no convolutional machinery at all, could match or beat the best CNNs at scale.

The catch was "at scale." Trained on ImageNet-1k alone, ViT lost to ResNet. Pretrained on ImageNet-21k or JFT-300M and then fine-tuned on ImageNet-1k, it won. The conclusion was that transformers lacked useful priors but could learn them from enough data. Subsequent work (DeiT, MAE, DINO) showed that with the right training recipe (strong augmentation, self-supervised pretraining, distillation), ViTs train well without a giant labelled pretraining set.

By 2026, pure CNNs are still competitive on edge devices (ConvNeXt is the strongest), but transformers dominate everything else: segmentation (Mask2Former, SegFormer), detection (DETR, RT-DETR), multimodal (CLIP, SigLIP), video (VideoMAE, V-JEPA). The ViT block structure is the one to know.

The Concept

The pipeline

flowchart LR
    IMG["Image<br/>(3, 224, 224)"] --> PATCH["Patch embedding<br/>conv 16x16 s=16<br/>-> (768, 14, 14)"]
    PATCH --> FLAT["Flatten to<br/>(196, 768) tokens"]
    FLAT --> CAT["Prepend<br/>[CLS] token"]
    CAT --> POS["Add learned<br/>positional embed"]
    POS --> ENC["N transformer<br/>encoder blocks"]
    ENC --> CLS["Take [CLS]<br/>token output"]
    CLS --> HEAD["MLP classifier"]
    style PATCH fill:#dbeafe,stroke:#2563eb
    style ENC fill:#fef3c7,stroke:#d97706
    style HEAD fill:#dcfce7,stroke:#16a34a

Seven steps. Patches -> tokens -> attention -> classifier. Most variants (DeiT, Swin, MAE pretraining) change one or two of the seven and leave the rest alone.

Patch embedding

The first conv is the secret. Kernel size 16, stride 16, so a 224x224 image becomes a 14x14 grid of 16x16 patches, each projected to a 768-dim embedding. That single conv both patchifies and linearly projects.

Input:  (3, 224, 224)
Conv (3 -> 768, k=16, s=16, no padding):
Output: (768, 14, 14)
Flatten spatial: (196, 768)

196 patches = 196 tokens. Each token's feature dimension is 768 (ViT-B), 1024 (ViT-L), or 1280 (ViT-H).
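A conv whose kernel equals its stride never overlaps patches, so it is equivalent to cutting out each patch, flattening it, and applying one shared linear layer. The sketch below shows only the cutting-and-flattening half in pure Python; the `patchify` helper and the toy 4x4 image are illustrative and not part of the lesson code.

```python
def patchify(image, patch_size):
    """Split a (C, H, W) nested-list image into flattened, non-overlapping patches.

    Returns H/P * W/P patches in row-major order, each a flat list of
    C * P * P values. A conv with kernel = stride = P applies one shared
    linear map to exactly these vectors.
    """
    channels = len(image)
    height, width = len(image[0]), len(image[0][0])
    assert height % patch_size == 0 and width % patch_size == 0
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            patch = [
                image[c][top + dy][left + dx]
                for c in range(channels)
                for dy in range(patch_size)
                for dx in range(patch_size)
            ]
            patches.append(patch)
    return patches


# Toy 1-channel 4x4 image whose pixel values are 0..15 in row-major order.
toy = [[[row * 4 + col for col in range(4)] for row in range(4)]]
tokens = patchify(toy, patch_size=2)
print(len(tokens), tokens[0], tokens[3])

# Same bookkeeping for a 224x224 RGB input with 16x16 patches.
print((224 // 16) ** 2, 3 * 16 * 16)
```

For a 224x224 RGB image with 16x16 patches, each flattened patch holds 3 * 16 * 16 = 768 raw values, which happens to match the ViT-B embedding width, so the projection there is a 768 -> 768 linear map.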

Class token

A single learned vector prepended to the sequence:

tokens = [CLS; patch_1; patch_2; ...; patch_196]   shape (197, 768)

After N transformer blocks, the [CLS] output is the global image representation. The classification head reads only this one vector.

Positional embedding

Transformers have no built-in notion of spatial position. Add a learned vector to every token:

tokens = tokens + learned_pos_embedding   (also shape (197, 768))

The embedding is a parameter of the model; gradient-based training adapts it to 2D image structure. Sinusoidal 2D alternatives exist but are rarely used in practice.
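For comparison, here is a minimal sketch of one fixed 2D sinusoidal construction: half of the channels encode the patch's row, the other half its column. The function names, the frequency base of 10000, and the sine-then-cosine channel layout are assumptions chosen for illustration; real implementations differ in these details.

```python
import math


def sincos_1d(position, dim, base=10000.0):
    """Encode one scalar position as dim values: dim/2 sines, then dim/2 cosines."""
    assert dim % 2 == 0
    half = dim // 2
    freqs = [1.0 / (base ** (i / half)) for i in range(half)]
    return [math.sin(position * f) for f in freqs] + [math.cos(position * f) for f in freqs]


def sincos_2d(grid_size, dim):
    """Fixed 2D embedding for a grid_size x grid_size patch grid.

    The first dim/2 channels encode the row index, the last dim/2 the column
    index. Returns grid_size**2 vectors of length dim in row-major order.
    """
    assert dim % 4 == 0
    table = []
    for row in range(grid_size):
        for col in range(grid_size):
            table.append(sincos_1d(row, dim // 2) + sincos_1d(col, dim // 2))
    return table


pos = sincos_2d(grid_size=14, dim=768)
# One vector per patch; a [CLS] slot would need its own entry.
print(len(pos), len(pos[0]))
```

Unlike the learned table, nothing here is trained: patches in the same row share the first half of their vector, and patches in the same column share the second half.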

Transformer encoder block

Standard. Multi-head self-attention, MLP, residual connections, pre-LayerNorm.

x = x + MSA(LN(x))
x = x + MLP(LN(x))

MLP is two-layer with GELU: Linear(d -> 4d) -> GELU -> Linear(4d -> d)

ViT-B/16 stacks 12 of these blocks, each with 12 attention heads, totalling 86M parameters.
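The 86M figure can be checked by hand. The sketch below counts the parameters of a plain ViT classifier, assuming a single linear head, biases on every linear layer, and a 1000-class output for ViT-B/16. The `vit_param_count` helper is hypothetical and written only for this check.

```python
def vit_param_count(image_size, patch_size, dim, depth, num_classes,
                    in_channels=3, mlp_ratio=4):
    """Count learnable parameters of a plain ViT classifier (biases included)."""
    num_patches = (image_size // patch_size) ** 2
    patch_embed = in_channels * patch_size * patch_size * dim + dim
    cls_and_pos = dim + (num_patches + 1) * dim
    attention = 4 * dim * dim + 4 * dim          # Q, K, V and output projections
    hidden = mlp_ratio * dim
    mlp = dim * hidden + hidden + hidden * dim + dim
    layer_norms = 2 * (2 * dim)                  # scale + shift, two norms per block
    block = attention + mlp + layer_norms
    final_norm = 2 * dim
    head = dim * num_classes + num_classes
    return patch_embed + cls_and_pos + depth * block + final_norm + head


vit_b16 = vit_param_count(image_size=224, patch_size=16, dim=768, depth=12, num_classes=1000)
print(f"{vit_b16:,}")  # 86,567,656
```

The number of heads does not appear in the count: splitting the same Q, K, V projections into more heads changes how they are used, not how many weights they hold. Almost all of the parameters sit in the 12 blocks, roughly 7.1M each.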

Why pre-LN

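Early transformers used post-LN (x = LN(x + sublayer(x))) and were hard to train past 6-8 layers without careful learning-rate warmup. Pre-LN (x = x + sublayer(LN(x))) keeps an unnormalised identity path from input to output, so deeper stacks train stably with far less dependence on warmup. The original ViT and most modern LLMs use pre-LN; Swin-V2 is a notable exception that moves the norm back to after the sublayer.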
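The difference between the two orderings is easiest to see when the sublayer contributes nothing. A minimal pure-Python sketch, assuming a LayerNorm without learned scale or shift and a stand-in `silent` sublayer (both simplifications for illustration):

```python
import math


def layer_norm(x, eps=1e-5):
    """LayerNorm over a single feature vector, without learned scale or shift."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def pre_ln(x, sublayer):
    """x + sublayer(LN(x)): the skip path carries x through untouched."""
    return [a + b for a, b in zip(x, sublayer(layer_norm(x)))]


def post_ln(x, sublayer):
    """LN(x + sublayer(x)): the skip path is renormalised at every layer."""
    return layer_norm([a + b for a, b in zip(x, sublayer(x))])


def silent(v):
    """Stand-in sublayer that outputs zeros."""
    return [0.0 for _ in v]


x = [1.0, 2.0, 4.0, 9.0]
print(pre_ln(x, silent))   # identical to x
print(post_ln(x, silent))  # rescaled even though the sublayer did nothing
```

With pre-LN, a stack of such blocks is exactly the identity, so the signal and its gradient reach every depth unchanged. With post-LN, every layer rewrites the residual stream, which is one intuition for why deep post-LN stacks are harder to optimise.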

Patch size trade-off

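Bigger patches = fewer tokens = faster but less spatial detail. Token count is (image_size / patch_size)^2 and the attention matrix grows with its square, so halving the patch size quadruples the tokens and makes the attention matrix roughly 16x larger. Swin-V2 uses 4x4 patches and keeps that affordable by restricting attention to local windows in a hierarchy.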
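The trade-off is plain arithmetic. The sketch below tabulates the token count and the size of one head's attention matrix for a 224x224 input at a few patch sizes; the `token_budget` helper is illustrative, and it counts the [CLS] token in the sequence length.

```python
def token_budget(image_size, patch_size):
    """Return (patch tokens, attention-matrix entries per head) for a square image."""
    assert image_size % patch_size == 0
    tokens = (image_size // patch_size) ** 2
    seq_len = tokens + 1                      # plus the [CLS] token
    return tokens, seq_len * seq_len


for patch_size in (32, 16, 14, 8):
    tokens, attn_entries = token_budget(224, patch_size)
    print(f"patch {patch_size:>2}: {tokens:>4} tokens, {attn_entries:>7,} attention entries per head")
```

Going from 16x16 to 8x8 patches takes the sequence from 196 to 784 tokens and the per-head attention matrix from 38,809 to 616,225 entries, per layer and per image.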

DeiT's recipe for training ViT on ImageNet-1k

The original ViT needed JFT-300M to beat CNNs. DeiT (Touvron et al., 2020) trained ViT-B to 81.8% top-1 on ImageNet-1k alone with four changes:

  1. Heavy augmentation: RandAugment, Mixup, CutMix, Random Erasing.
  2. Stochastic depth (drop entire blocks at random during training).
  3. Repeated augmentation (same image sampled 3 times per batch).
  4. Distillation from a CNN teacher (optional, lifts accuracy further).

Every modern ViT training recipe descends from DeiT.
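Of the four ingredients, stochastic depth is the one that touches the block itself. A minimal per-sample sketch in pure Python, assuming the common formulation that rescales surviving branches by 1 / (1 - drop_prob); the `drop_path` name and the 0.25 rate are illustrative choices, not DeiT's published settings.

```python
import random


def drop_path(branch, drop_prob, training, rng):
    """Stochastic depth for one sample's residual branch (a list of floats).

    During training the whole branch is zeroed with probability drop_prob and
    otherwise scaled by 1 / (1 - drop_prob), so its expected value is unchanged.
    At inference the branch passes through as-is.
    """
    if not training or drop_prob == 0.0:
        return list(branch)
    keep_prob = 1.0 - drop_prob
    if rng.random() < drop_prob:
        return [0.0 for _ in branch]
    return [v / keep_prob for v in branch]


rng = random.Random(0)
branch = [0.5, -1.0, 2.0]          # stands in for MSA(LN(x)) or MLP(LN(x))
x = [1.0, 1.0, 1.0]
outputs = [drop_path(branch, drop_prob=0.25, training=True, rng=rng) for _ in range(5)]
for out in outputs:
    print([a + b for a, b in zip(x, out)])   # either x alone or x + branch / 0.75
```

Because the block is residual, dropping the branch leaves x intact, so the network behaves like a randomly shallower one on that step.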

Swin vs ConvNeXt

In 2026, ConvNeXt-V2 and Swin-V2 are both production-grade; the right choice depends on your inference stack (ConvNeXt compiles better for edge) and pretraining corpus.

MAE pretraining

Masked Autoencoder (He et al., 2022): mask 75% of patches at random, train the encoder to process only the visible 25%, train a small decoder to reconstruct the masked patches from the encoder's output. After pretraining, discard the decoder and fine-tune the encoder.

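MAE makes ViT trainable on ImageNet-1k alone: the paper reports 87.8% top-1 for ViT-H using only ImageNet-1k data, the best result in that setting at the time. It remains a standard self-supervised recipe for ViTs.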
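The masking step is a random partition of patch indices. A minimal sketch, assuming a per-image shuffle and a hypothetical `random_masking` helper; a real implementation would do this batched on tensors.

```python
import random


def random_masking(num_patches, mask_ratio, rng):
    """Pick which patch indices the encoder sees and which it must reconstruct."""
    num_keep = int(num_patches * (1.0 - mask_ratio))
    order = list(range(num_patches))
    rng.shuffle(order)
    visible = sorted(order[:num_keep])
    masked = sorted(order[num_keep:])
    return visible, masked


visible, masked = random_masking(num_patches=196, mask_ratio=0.75, rng=random.Random(0))
print(len(visible), len(masked))  # 49 visible tokens go through the encoder, 147 are reconstructed
```

Only the 49 visible tokens pass through the encoder, which is why MAE pretraining is cheap relative to running the full 196-token sequence.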

Build It

Step 1: Patch embedding

import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels=3, patch_size=16, dim=192, image_size=64):
        super().__init__()
        assert image_size % patch_size == 0
        self.proj = nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size)
        num_patches = (image_size // patch_size) ** 2
        self.num_patches = num_patches

    def forward(self, x):
        x = self.proj(x)                      # (N, dim, H/P, W/P)
        return x.flatten(2).transpose(1, 2)   # (N, num_patches, dim)

One conv, one flatten, one transpose. That is the entire image-to-tokens step.

Step 2: Transformer block

Pre-LN, multi-head self-attention, MLP with GELU, residual connections.

class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4, dropout=0.0):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mlp_ratio, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        h = self.ln1(x)  # normalise once and reuse for query, key and value
        a, _ = self.attn(h, h, h, need_weights=False)
        x = x + a
        x = x + self.mlp(self.ln2(x))
        return x

nn.MultiheadAttention handles the splitting into heads, the scaled dot-product, and the output projection. batch_first=True so shapes are (N, seq, dim).
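For reference, the scaled dot-product at the core of that module can be written in a few lines of pure Python. This is a single-head sketch that leaves out the learned Q/K/V projections, the head split, and the output projection; the `attention` and `softmax` helpers are illustrative and are not how PyTorch implements it.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention(queries, keys, values):
    """Single-head scaled dot-product attention on lists of vectors.

    Each output row is a weighted average of the value vectors, with weights
    softmax(q . k / sqrt(d)) taken over all keys.
    """
    scale = 1.0 / math.sqrt(len(keys[0]))
    outputs, weights = [], []
    for q in queries:
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
        row = softmax(scores)
        weights.append(row)
        outputs.append([
            sum(w * v[j] for w, v in zip(row, values))
            for j in range(len(values[0]))
        ])
    return outputs, weights


tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # three toy tokens, d = 2
out, attn = attention(tokens, tokens, tokens)   # self-attention: Q = K = V
print([round(sum(row), 6) for row in attn])     # every row of weights sums to 1
```

Every token attends to every other token, including itself, which is where the quadratic cost in sequence length comes from.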

Step 3: The ViT

class ViT(nn.Module):
    def __init__(self, image_size=64, patch_size=16, in_channels=3,
                 num_classes=10, dim=192, depth=6, num_heads=3, mlp_ratio=4):
        super().__init__()
        self.patch = PatchEmbedding(in_channels, patch_size, dim, image_size)
        num_patches = self.patch.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
        self.blocks = nn.ModuleList([
            Block(dim, num_heads, mlp_ratio) for _ in range(depth)
        ])
        self.ln = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x):
        x = self.patch(x)
        cls = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls, x], dim=1)
        x = x + self.pos_embed
        for blk in self.blocks:
            x = blk(x)
        x = self.ln(x[:, 0])
        return self.head(x)

vit = ViT(image_size=64, patch_size=16, num_classes=10, dim=192, depth=6, num_heads=3)
x = torch.randn(2, 3, 64, 64)
print(f"output: {vit(x).shape}")
print(f"params: {sum(p.numel() for p in vit.parameters()):,}")

About 2.8M parameters (2,822,602), a tiny ViT that is tractable on CPU. Real ViT-B/16 is about 86M: the same class definition with image_size=224, num_classes=1000, dim=768, depth=12, num_heads=12.

Step 4: Sanity check — single image inference

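vit.eval()  # inference mode: disables dropout
with torch.no_grad():  # no autograd graph needed for a sanity check
    logits = vit(torch.randn(1, 3, 64, 64))
print(f"logits: {logits}")
print(f"probs:  {logits.softmax(-1)}")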

This should run without error, and the probabilities should sum to 1. The weights are untrained, so the values themselves carry no meaning yet.

Use It

timm ships most published ViT variants with pretrained weights. Loading one with a freshly initialised 10-class head takes one line:

import timm

model = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=10)

timm is the production default for vision transformers in 2026. Supports ViT, DeiT, Swin, Swin-V2, ConvNeXt, ConvNeXt-V2, MaxViT, MViT, EfficientFormer, and dozens of others under the same API.

For multimodal work (image + text), Hugging Face transformers ships CLIP, SigLIP, BLIP-2, and LLaVA. The image encoder in all of those is a ViT variant.

Ship It

This lesson produces:

Exercises

  1. (Easy) Print the shapes of every intermediate tensor for a forward pass through the tiny ViT above. Confirm: input (N, 3, 64, 64) -> patches (N, 16, 192) -> with CLS (N, 17, 192) -> classifier input (N, 192) -> output (N, num_classes).
  2. (Medium) Fine-tune a pretrained timm ViT-S/16 on the synthetic-CIFAR dataset from Lesson 4. Compare against ResNet-18 fine-tuning on the same data. Report training time and final accuracy.
  3. (Hard) Implement MAE pretraining for the tiny ViT: mask 75% of patches, train the encoder + a small decoder to reconstruct the masked patches. Evaluate linear-probe accuracy on the synthetic data before and after pretraining.

Key Terms

| Term | What people say | What it actually means |
|------|-----------------|------------------------|
| Patch embedding | "The first conv" | A conv with kernel size = stride = patch size; turns the image into a grid of token embeddings |
| Class token | "[CLS]" | A learned vector prepended to the token sequence; its final output is the global image representation |
| Positional embedding | "Learned pos" | A learned vector added to every token so the transformer knows where each patch came from |
| Pre-LN | "LayerNorm before sublayer" | The stable transformer variant: x + sublayer(LN(x)) instead of LN(x + sublayer(x)) |
| Multi-head attention | "Parallel attention" | Standard transformer attention split into num_heads independent subspaces, concatenated afterwards |
| ViT-B/16 | "Base, patch 16" | The canonical size: dim=768, depth=12, heads=12, patch_size=16, image=224; ~86M params |
| DeiT | "Data-efficient ViT" | ViT trained on ImageNet-1k alone with strong augmentation; proved large pretraining datasets are not strictly required |
| MAE | "Masked autoencoder" | Self-supervised pretraining: mask 75% of patches, reconstruct; the dominant ViT pretraining recipe |

Further Reading