Transfer Learning & Fine-Tuning
> Somebody else spent a million GPU hours teaching a network what edges, textures, and object parts look like. You should borrow those features before training your own.
Type: Build
Languages: Python
Prerequisites: Phase 4 Lesson 03 (CNNs), Phase 4 Lesson 04 (Image Classification)
Time: ~75 minutes
Learning Objectives
- Distinguish feature extraction from fine-tuning and pick the right one based on dataset size, domain distance, and compute budget
- Load a pretrained backbone, replace its classifier head, and train only the head to a working baseline in under 20 lines
- Progressively unfreeze layers with discriminative learning rates so early generic features get smaller updates than late task-specific ones
- Diagnose the three common failures: feature drift from too-high LR on unfrozen blocks, BN statistics collapse on tiny datasets, and catastrophic forgetting
The Problem
Training a ResNet-50 on ImageNet costs around 2,000 GPU-hours. Very few teams have that budget for every task they ship. What almost every team actually ships is a pretrained backbone with a new head trained on a few hundred or few thousand task-specific images.
This is not a shortcut. The first conv block of any ImageNet-trained CNN learns edges and Gabor-like filters. The next few blocks learn textures and simple motifs. The middle blocks learn object parts. The final blocks learn combinations that start to look like the 1,000 ImageNet categories. The first 90% of that hierarchy transfers almost unchanged to medical imaging, industrial inspection, satellite data, and every other vision task — because nature has a limited vocabulary of edges and textures. The last 10% is what you actually train.
Getting transfer right has three bugs waiting for you: destroying pretrained features with a too-high learning rate, starving the model of information by freezing too much, and letting BatchNorm's running statistics drift toward a tiny dataset that the rest of the network never learnt from. This lesson walks each of them on purpose.
The Concept
Feature extraction vs fine-tuning
Two regimes, picked by how much you trust the pretrained features and how much data you have.
(no gradient)"] --> FE2["New head
(trained)"] end subgraph FT["Fine-tuning — end-to-end"] FT1["Pretrained backbone
(tiny LR)"] --> FT2["New head
(normal LR)"] end style FE1 fill:#e5e7eb,stroke:#6b7280 style FE2 fill:#dcfce7,stroke:#16a34a style FT1 fill:#fef3c7,stroke:#d97706 style FT2 fill:#dcfce7,stroke:#16a34a
Rules of thumb:
| Dataset size | Domain distance | Recipe |
|---|---|---|
| < 1k images | close to ImageNet | Freeze backbone, train head only |
| 1k-10k | close | Freeze first 2-3 stages, fine-tune the rest |
| 10k-100k | any | Fine-tune end-to-end with discriminative LR |
| 100k+ | far | Fine-tune everything; consider training from scratch if domain is far enough |
"Close to ImageNet" roughly means natural RGB photos with object-like content. Medical CT scans, overhead satellite imagery, and microscopy are far domains — the features still help, but you will need to let more layers adapt.
Why freezing works at all
The ImageNet features a CNN learns are not specialised to the 1,000 categories. They are specialised to the statistics of natural images: edges at specific orientations, textures, contrast patterns, shape primitives. Those statistics are stable across almost every visual domain a human can name. That is why a model trained on ImageNet and evaluated zero-shot on CIFAR-10 with just a new linear head (no fine-tuning of the backbone) reaches 80%+ accuracy. The head is learning which of the already-learnt features to weight for this task.
Discriminative learning rates
When you do unfreeze, early layers should train slower than late layers. Early layers encode generic features that you want to preserve; late layers encode task-specific structure that you need to move a lot.
Typical recipe:
stage 0 (stem + first group): lr = base_lr / 100 (mostly fixed)
stage 1: lr = base_lr / 10
stage 2: lr = base_lr / 3
stage 3 (last backbone group): lr = base_lr
head: lr = base_lr (or slightly higher)
In PyTorch this is just a list of parameter groups passed to the optimizer. One model, five learning rates, zero extra code.
The BatchNorm problem
BN layers hold running_mean and running_var buffers that were computed on ImageNet. If your task has a different pixel distribution — different lighting, different sensor, different colour space — those buffers are wrong. Three options in order of preference:
- Fine-tune with BN in train mode. Let BN update its running statistics along with everything else. Default choice when the task dataset is medium-sized (>= 5k examples).
- Freeze BN in eval mode. Keep the ImageNet statistics and train only the weights. Correct when your dataset is small enough that BN's moving average would be noisy.
- Replace BN with GroupNorm. Removes the moving-average problem entirely. Used in detection and segmentation backbones where batch size per GPU is tiny.
Getting this wrong silently tanks accuracy by 5-15%.
Head design
The classifier head is 1-3 linear layers plus an optional dropout. Every torchvision backbone ships a default head that you replace:
backbone.fc = nn.Linear(backbone.fc.in_features, num_classes) # ResNet
backbone.classifier[1] = nn.Linear(..., num_classes) # EfficientNet, MobileNet
backbone.heads.head = nn.Linear(..., num_classes) # torchvision ViT
For small datasets, a single linear layer is usually enough. Adding a hidden layer (Linear -> ReLU -> Dropout -> Linear) helps when the task distribution is farther from the backbone's training distribution.
Layer-wise LR decay
A smoother version of discriminative LR used in modern fine-tuning (BEiT, DINOv2, ViT-B fine-tunes). Instead of grouping layers into stages, give every layer a slightly smaller LR than the one above it:
lr_layer_k = base_lr * decay^(L - k)
With decay = 0.75 and L = 12 transformer blocks, the first block trains at 0.75^11 ≈ 0.04x the head's LR. Matters more for transformer fine-tunes than for CNNs, where stage-grouped LRs are usually enough.
What to evaluate
Transfer-learning runs need two numbers you would not track on a scratch run:
- Pretrained-only accuracy — the head's accuracy with the backbone frozen. This is your floor.
- Fine-tuned accuracy — the same model after end-to-end training. This is your ceiling.
If fine-tuned is less than pretrained-only, you have a learning-rate or BN bug. Always print both.
Build It
Step 1: Load a pretrained backbone and inspect it
import torch
import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights
backbone = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
print(backbone)
print()
print("classifier head:", backbone.fc)
print("feature dim:", backbone.fc.in_features)
ResNet18 has four stages (layer1..layer4) plus a stem and a fc head. Every torchvision classification backbone has an analogous structure.
Step 2: Feature extraction — freeze everything, replace the head
def make_feature_extractor(num_classes=10):
model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
for p in model.parameters():
p.requires_grad = False
model.fc = nn.Linear(model.fc.in_features, num_classes)
return model
model = make_feature_extractor(num_classes=10)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f"trainable: {trainable:>10,}")
print(f"frozen: {frozen:>10,}")
Only model.fc is trainable. The backbone is a frozen feature extractor.
Step 3: Discriminative fine-tuning
A utility that builds parameter groups with stage-specific learning rates.
def discriminative_param_groups(model, base_lr=1e-3, decay=0.3):
stages = [
["conv1", "bn1"],
["layer1"],
["layer2"],
["layer3"],
["layer4"],
["fc"],
]
groups = []
for i, names in enumerate(stages):
lr = base_lr * (decay ** (len(stages) - 1 - i))
params = [p for n, p in model.named_parameters()
if any(n.startswith(k) for k in names)]
if params:
groups.append({"params": params, "lr": lr, "name": "_".join(names)})
return groups
model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, 10)
for p in model.parameters():
p.requires_grad = True
groups = discriminative_param_groups(model)
for g in groups:
print(f"{g['name']:>10s} lr={g['lr']:.2e} params={sum(p.numel() for p in g['params']):>8,}")
decay=0.3 means each stage trains at 30% of the rate of the next one. fc gets base_lr, layer4 gets 0.3 * base_lr, conv1 gets 0.3^5 * base_lr ≈ 0.00243 * base_lr. Extreme sounding; empirically it works.
Step 4: BatchNorm handling
Helper to freeze BN running statistics without freezing its weights.
def freeze_bn_stats(model):
for m in model.modules():
if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
m.eval()
for p in m.parameters():
p.requires_grad = False
return model
Call it after you set model.train() at the start of every epoch. model.train() flips everything to training mode; this reverses it only for BN layers.
Step 5: A minimal end-to-end fine-tuning loop
from torch.optim import SGD
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn.functional as F
def fine_tune(model, train_loader, val_loader, device, epochs=5, base_lr=1e-3, freeze_bn=False):
model = model.to(device)
groups = discriminative_param_groups(model, base_lr=base_lr)
optimizer = SGD(groups, momentum=0.9, weight_decay=1e-4, nesterov=True)
scheduler = CosineAnnealingLR(optimizer, T_max=epochs)
for epoch in range(epochs):
model.train()
if freeze_bn:
freeze_bn_stats(model)
tr_loss, tr_correct, tr_total = 0.0, 0, 0
for x, y in train_loader:
x, y = x.to(device), y.to(device)
logits = model(x)
loss = F.cross_entropy(logits, y, label_smoothing=0.1)
optimizer.zero_grad()
loss.backward()
optimizer.step()
tr_loss += loss.item() * x.size(0)
tr_total += x.size(0)
tr_correct += (logits.argmax(-1) == y).sum().item()
scheduler.step()
model.eval()
va_total, va_correct = 0, 0
with torch.no_grad():
for x, y in val_loader:
x, y = x.to(device), y.to(device)
pred = model(x).argmax(-1)
va_total += x.size(0)
va_correct += (pred == y).sum().item()
print(f"epoch {epoch} train {tr_loss/tr_total:.3f}/{tr_correct/tr_total:.3f} "
f"val {va_correct/va_total:.3f}")
return model
Five epochs with the above recipe on CIFAR-10 takes ResNet18-IMAGENET1K_V1 from ~70% zero-shot linear-probe accuracy to ~93% fine-tuned accuracy. The head alone would plateau around 86% without ever touching the backbone.
Step 6: Progressive unfreezing
A schedule that unfreezes one stage per epoch from the end toward the beginning. Mitigates feature drift at the cost of some extra epochs.
def progressive_unfreeze_schedule(model):
stages = ["layer4", "layer3", "layer2", "layer1"]
yielded = set()
def start():
for p in model.parameters():
p.requires_grad = False
for p in model.fc.parameters():
p.requires_grad = True
def unfreeze(epoch):
if epoch < len(stages):
name = stages[epoch]
yielded.add(name)
for n, p in model.named_parameters():
if n.startswith(name):
p.requires_grad = True
return name
return None
return start, unfreeze
Call start() once before the first epoch. Call unfreeze(epoch) at the start of each epoch. Rebuild the optimizer whenever the set of trainable parameters changes, otherwise the frozen params still hold cached moments that confuse it.
Use It
For most real tasks, torchvision.models + three lines is enough. The heavier machinery above matters when you run into the problems that library defaults cannot fix.
from torchvision.models import resnet50, ResNet50_Weights
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
model.fc = nn.Linear(model.fc.in_features, num_classes)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
Two other production-grade defaults:
timmships ~800 pretrained vision backbones with a consistent API (timm.create_model("resnet50", pretrained=True, num_classes=10)). For any fine-tune beyond the torchvision zoo, it is the standard.- For transformers,
transformers.AutoModelForImageClassification.from_pretrained(name, num_labels=N)gives you ViT / BEiT / DeiT with the same loading semantics as text models.
Ship It
This lesson produces:
outputs/prompt-fine-tune-planner.md— a prompt that picks feature-extraction vs progressive vs end-to-end fine-tuning based on dataset size, domain distance, and compute budget.outputs/skill-freeze-inspector.md— a skill that, given a PyTorch model, reports which parameters are trainable, which BatchNorm layers are in eval mode, and whether the optimizer is actually being fed the trainable parameters.
Exercises
- (Easy) Train a
ResNet18as a linear probe (backbone frozen) and as a full fine-tune on the same synthetic-CIFAR dataset. Report both accuracies side by side. Explain which gap tells you the features transfer well and which tells you they do not. - (Medium) Introduce a bug on purpose: set
base_lr = 1e-1on the backbone stage instead of the head. Show the training loss explode, then recover by applying thediscriminative_param_groupshelper. Record the LR at which each stage starts diverging. - (Hard) Take a medical imaging dataset (e.g. CheXpert-small, PatchCamelyon, or HAM10000) and compare three regimes: (a) ImageNet-pretrained frozen backbone + linear head; (b) ImageNet-pretrained fine-tune end-to-end; (c) scratch training. Report accuracy and compute cost for each. At what dataset size does scratch training become competitive?
Key Terms
| Term | What people say | What it actually means |
|---|---|---|
| Feature extraction | "Freeze and train head" | Backbone parameters frozen, only the new classifier head receives gradient |
| Fine-tuning | "Retrain end-to-end" | All parameters trainable, usually with much smaller LR than scratch training |
| Discriminative LR | "Smaller LR for early layers" | Optimizer parameter groups where early-stage LR is a fraction of late-stage LR |
| Layer-wise LR decay | "Smooth LR gradient" | Per-layer LR multiplied by decay^(L - k); common in transformer fine-tunes |
| Catastrophic forgetting | "The model lost ImageNet" | A too-high LR overwrites pretrained features before the new task signal is learnt |
| BN statistics drift | "Running mean is wrong" | BatchNorm running_mean/var computed on a different distribution than the current task, silently hurting accuracy |
| Linear probe | "Frozen backbone + linear head" | Evaluation of pretrained features — accuracy of the best linear classifier on top of the frozen representation |
| Catastrophic collapse | "Everything predicts one class" | Happens when fine-tuning with an LR high enough to destroy features before gradients from the head can stabilise |
Further Reading
- How transferable are features in deep neural networks? (Yosinski et al., 2014) — the paper that quantified feature transferability across layers
- Universal Language Model Fine-tuning (ULMFiT, Howard & Ruder, 2018) — the original discriminative LR / progressive unfreezing recipe; the ideas transfer directly to vision
- timm documentation — the reference for modern vision backbones and the exact fine-tune defaults they were trained with
- A Simple Framework for Linear-Probe Evaluation (Kornblith et al., 2019) — why linear-probe accuracy matters and how to report it correctly