PyTorch Lightning Training — Deep Dive
Complete LightningModule Example
A production-ready image classifier with all the recommended methods:
import torch
import torch.nn as nn
import lightning as L
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
from torchmetrics import Accuracy, F1Score
class ImageClassifier(L.LightningModule):
    """EfficientNet-B0 image classifier with a dropout + linear head.

    Batches are expected to be ``(images, labels)`` pairs (as produced by an
    ``ImageFolder``-style dataset); ``predict_step`` also accepts bare image
    batches.

    Args:
        num_classes: Number of target classes.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
    """

    def __init__(self, num_classes: int, lr: float = 1e-3,
                 weight_decay: float = 0.01):
        super().__init__()
        # Stores num_classes/lr/weight_decay in self.hparams and in every
        # checkpoint, which is what lets load_from_checkpoint rebuild the model.
        self.save_hyperparameters()

        # Pretrained backbone; its classifier is replaced by Identity so that
        # forward() returns the features that used to feed that classifier.
        backbone = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)
        in_features = backbone.classifier[1].in_features
        backbone.classifier = nn.Identity()
        self.backbone = backbone

        # Custom head
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(in_features, num_classes),
        )

        # One metric object per stage: torchmetrics objects accumulate state,
        # so sharing one between stages would mix their statistics.
        self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.val_f1 = F1Score(task="multiclass", num_classes=num_classes,
                              average="macro")
        self.test_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.test_f1 = F1Score(task="multiclass", num_classes=num_classes,
                               average="macro")
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return class logits for a batch of images."""
        features = self.backbone(x)
        return self.classifier(features)

    def _shared_step(self, batch):
        """Compute loss and argmax predictions for an ``(images, labels)`` batch."""
        images, labels = batch
        logits = self(images)
        loss = self.criterion(logits, labels)
        preds = logits.argmax(dim=1)
        return loss, preds, labels

    def training_step(self, batch, batch_idx):
        loss, preds, labels = self._shared_step(batch)
        self.train_acc(preds, labels)
        self.log("train/loss", loss, prog_bar=True)
        self.log("train/acc", self.train_acc, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, preds, labels = self._shared_step(batch)
        self.val_acc(preds, labels)
        self.val_f1(preds, labels)
        # sync_dist averages the epoch-level loss across DDP ranks so that
        # callbacks monitoring "val/loss" (e.g. EarlyStopping) see one value.
        # Metric objects below sync themselves and need no flag.
        self.log("val/loss", loss, prog_bar=True, sync_dist=True)
        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True)
        self.log("val/f1", self.val_f1, on_step=False, on_epoch=True)

    def test_step(self, batch, batch_idx):
        """Evaluate one batch; required by ``trainer.test``."""
        loss, preds, labels = self._shared_step(batch)
        self.test_acc(preds, labels)
        self.test_f1(preds, labels)
        self.log("test/loss", loss, sync_dist=True)
        self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)
        self.log("test/f1", self.test_f1, on_step=False, on_epoch=True)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return logits for a batch; labels, if present, are ignored.

        The default ``predict_step`` would call ``self(batch)`` on the whole
        ``(images, labels)`` tuple, which the backbone cannot consume.
        """
        images = batch[0] if isinstance(batch, (list, tuple)) else batch
        return self(images)

    def configure_optimizers(self):
        """AdamW with cosine annealing over the full run."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
        # Lightning steps the scheduler once per epoch by default, so T_max in
        # epochs spans one cosine half-period over the whole run.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.trainer.max_epochs
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
Key details: save_hyperparameters() stores constructor arguments for reproducibility and checkpoint loading. torchmetrics handles metric computation correctly across distributed processes (automatic sync at epoch end).
Production DataModule
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import ImageFolder
class ImageDataModule(L.LightningDataModule):
    """ImageFolder data module with a fixed, seeded 80/20 train/held-out split.

    Args:
        data_dir: Root directory in ``torchvision.datasets.ImageFolder`` layout
            (one sub-directory per class).
        batch_size: Samples per batch for every dataloader.
        num_workers: DataLoader worker processes; ``0`` loads in the main process.
    """

    def __init__(self, data_dir: str, batch_size: int = 32,
                 num_workers: int = 4):
        super().__init__()
        self.save_hyperparameters()
        # Augmenting pipeline, used only for the training split.
        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.2, 0.2, 0.2),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ])
        # Deterministic resize + center crop for evaluation.
        self.val_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ])

    def setup(self, stage=None):
        """Build the train and held-out splits, each with its own transform.

        ``random_split`` returns ``Subset`` views over ONE shared dataset, so
        setting ``.dataset.transform`` on each subset would make the last
        assignment apply to both (training would silently use the validation
        transform). Two ``ImageFolder`` instances over the same directory keep
        the transforms independent; the seeded split indices are shared.
        """
        from torch.utils.data import Subset

        root = self.hparams.data_dir
        train_base = ImageFolder(root, transform=self.train_transform)
        val_base = ImageFolder(root, transform=self.val_transform)
        n = len(train_base)
        train_size = int(0.8 * n)
        val_size = n - train_size
        # Fixed seed -> identical split on every call and on every DDP rank.
        train_split, val_split = random_split(
            train_base, [train_size, val_size],
            generator=torch.Generator().manual_seed(42),
        )
        self.train_ds = train_split
        self.val_ds = Subset(val_base, val_split.indices)

    def _make_loader(self, dataset, shuffle: bool = False):
        """Create a DataLoader with the module's batch size and worker settings."""
        workers = self.hparams.num_workers
        return DataLoader(
            dataset, batch_size=self.hparams.batch_size,
            shuffle=shuffle, num_workers=workers,
            pin_memory=True,
            # persistent_workers=True raises ValueError when num_workers == 0.
            persistent_workers=workers > 0,
        )

    def train_dataloader(self):
        return self._make_loader(self.train_ds, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_ds)

    def test_dataloader(self):
        # No dedicated test split exists, so this reuses the held-out split.
        # Point it at a separate test directory for an unbiased final estimate.
        return self._make_loader(self.val_ds)

    def predict_dataloader(self):
        # Held-out split, evaluation transforms, no shuffling.
        return self._make_loader(self.val_ds)
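`persistent_workers=True` keeps worker processes alive between epochs, avoiding the overhead of respawning them — essential for datasets with expensive initialization. It requires `num_workers > 0` (DataLoader raises a `ValueError` otherwise), so if `num_workers` may be 0, derive it from the worker count: `persistent_workers=num_workers > 0`.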
Custom Callbacks
Gradient Monitoring
class GradientMonitor(L.Callback):
    """Log per-parameter gradient L2 norms to spot vanishing/exploding gradients.

    Norms are read in ``on_before_optimizer_step`` rather than
    ``on_after_backward``. Lightning calls ``on_before_optimizer_step`` once per
    optimizer step, after gradient accumulation has finished. Under AMP
    (``precision="16-mixed"``) it runs after the gradients have been unscaled,
    and before any gradient clipping. ``on_after_backward`` instead runs after
    every micro-batch backward (e.g. 4x per step with
    ``accumulate_grad_batches=4``) and sees loss-scaled gradients. That inflates
    the norms by the scale factor and triggers spurious warnings.

    Args:
        log_every_n_steps: Only inspect gradients when ``trainer.global_step``
            is a multiple of this value.
        warn_threshold: Print a warning for any parameter whose gradient norm
            exceeds this value.

    Raises:
        ValueError: If ``log_every_n_steps`` is smaller than 1.
    """

    def __init__(self, log_every_n_steps: int = 100,
                 warn_threshold: float = 100.0):
        super().__init__()
        if log_every_n_steps < 1:
            raise ValueError("log_every_n_steps must be >= 1")
        self.log_every_n_steps = log_every_n_steps
        self.warn_threshold = warn_threshold

    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
        if trainer.global_step % self.log_every_n_steps != 0:
            return
        for name, param in pl_module.named_parameters():
            if param.grad is None:
                # Frozen, or did not take part in this step's forward pass.
                continue
            grad_norm = param.grad.detach().norm(2).item()
            pl_module.log(f"gradients/{name}_norm", grad_norm)
            if grad_norm > self.warn_threshold:
                trainer.print(f"WARNING: Large gradient in {name}: "
                              f"{grad_norm:.2f}")
Dynamic Batch Size
class DynamicBatchSize(L.Callback):
    """Double the training batch size every ``increase_every`` epochs.

    The size for epoch ``e`` is ``min(initial * 2 ** (e // increase_every),
    max_size)``. It is written to ``trainer.datamodule.hparams.batch_size``,
    which the datamodule's ``train_dataloader`` reads. This has two consequences.

    * The Trainer must rebuild the train dataloader every epoch
      (``L.Trainer(reload_dataloaders_every_n_epochs=1)``). Otherwise the
      loader built at fit start is reused and the new size is never used.
    * The size must be set *before* that rebuild. Lightning rebuilds the loader
      before ``on_train_epoch_start`` runs, so a change made there would take
      effect one epoch late. So the first size is set in ``setup``, and each
      later size in ``on_train_epoch_end`` (for ``current_epoch + 1``).

    Args:
        initial: Batch size for the first ``increase_every`` epochs.
        max_size: Upper bound on the batch size.
        increase_every: Number of epochs between doublings.
    """

    def __init__(self, initial: int = 16, max_size: int = 256,
                 increase_every: int = 5):
        super().__init__()
        self.initial = initial
        self.max_size = max_size
        self.increase_every = increase_every

    def _size_for_epoch(self, epoch: int) -> int:
        """Return the batch size scheduled for ``epoch``."""
        return min(
            self.initial * (2 ** (epoch // self.increase_every)),
            self.max_size,
        )

    def _set_batch_size(self, trainer, epoch: int) -> None:
        """Write the size for ``epoch`` into the datamodule's hyperparameters."""
        if trainer.datamodule is None:
            # Dataloaders were passed to fit() directly; nothing to resize.
            return
        trainer.datamodule.hparams.batch_size = self._size_for_epoch(epoch)

    def setup(self, trainer, pl_module, stage):
        if stage != "fit":
            return
        import warnings

        if trainer.datamodule is None:
            warnings.warn("DynamicBatchSize needs a LightningDataModule; "
                          "batch size will not be changed.")
            return
        if not trainer.reload_dataloaders_every_n_epochs:
            warnings.warn("DynamicBatchSize has no effect unless the Trainer "
                          "is created with reload_dataloaders_every_n_epochs=1.")
        # Runs before the first train dataloader is built.
        self._set_batch_size(trainer, trainer.current_epoch)

    def on_train_epoch_end(self, trainer, pl_module):
        # Prepare the size for the next epoch's dataloader rebuild.
        self._set_batch_size(trainer, trainer.current_epoch + 1)
Trainer Configuration for Production
from lightning.pytorch.callbacks import (
ModelCheckpoint, EarlyStopping, LearningRateMonitor,
RichProgressBar,
)
from lightning.pytorch.loggers import WandbLogger
trainer = L.Trainer(
    # Hardware
    accelerator="gpu",
    devices=4,
    strategy="ddp",
    precision="16-mixed",
    # Training
    max_epochs=100,
    gradient_clip_val=1.0,
    accumulate_grad_batches=4,
    # Callbacks
    callbacks=[
        ModelCheckpoint(
            monitor="val/f1",
            mode="max",
            save_top_k=3,
            # With auto-inserted metric names the "/" in "val/f1" would end up
            # in the filename and create sub-directories; spell the names out.
            filename="epoch={epoch}-val_f1={val/f1:.3f}",
            auto_insert_metric_name=False,
            # Also keep last.ckpt so an interrupted run can be resumed.
            save_last=True,
        ),
        EarlyStopping(
            monitor="val/loss",
            patience=10,
            mode="min",
        ),
        LearningRateMonitor(logging_interval="step"),
        GradientMonitor(),
        RichProgressBar(),
    ],
    # Logging
    logger=WandbLogger(project="image-classification"),
    # Debugging (disable for production)
    # fast_dev_run=True,      # Run 1 batch to verify pipeline
    # overfit_batches=10,     # Overfit on 10 batches to verify model capacity
)
Advanced Training Strategies
FSDP for Large Models
from lightning.pytorch.strategies import FSDPStrategy
# Shard everything and treat each transformer encoder layer as its own
# wrapping and activation-checkpointing unit.
strategy = FSDPStrategy(
    sharding_strategy="FULL_SHARD",
    auto_wrap_policy={nn.TransformerEncoderLayer},
    activation_checkpointing_policy={nn.TransformerEncoderLayer},
    # None: let the Trainer's `precision` argument decide.
    mixed_precision=None,
)
trainer = L.Trainer(
    precision="bf16-mixed",
    devices=8,
    strategy=strategy,
)
DeepSpeed for Extreme Scale
from lightning.pytorch.strategies import DeepSpeedStrategy
# ZeRO stage 3: parameters, gradients and optimizer states are all sharded.
strategy = DeepSpeedStrategy(
    stage=3,
    offload_parameters=False,   # parameters are not offloaded
    offload_optimizer=True,     # optimizer state is offloaded to CPU
    # Communication bucket sizes.
    reduce_bucket_size=5e8,
    allgather_bucket_size=5e8,
)
Profiling and Debugging
Lightning’s built-in profiler identifies bottlenecks:
from lightning.pytorch.profilers import AdvancedProfiler
# Profile a short 2-epoch run; the report is written under ./profiler.
profiler = AdvancedProfiler(filename="perf", dirpath="./profiler")
trainer = L.Trainer(max_epochs=2, profiler=profiler)
trainer.fit(model, datamodule)
The output shows time spent in each method — immediately revealing whether your bottleneck is data loading, forward pass, backward pass, or optimizer steps.
For quick sanity checks:
# Smoke test: a single train batch and a single val batch, end to end.
trainer = L.Trainer(fast_dev_run=True)
trainer.fit(model, datamodule)

# Capacity check: 50 epochs on the same 10 batches should drive the loss
# close to zero if the model can learn at all.
trainer = L.Trainer(max_epochs=50, overfit_batches=10)
trainer.fit(model, datamodule)
Resuming and Fault Tolerance
Resume from a checkpoint after crash or preemption:
# last.ckpt is NOT written by default: enable it with ModelCheckpoint(save_last=True).
# A fixed dirpath lets a restarted job find the previous run's last.ckpt (the
# default checkpoint directory includes a logger version that changes per run).
trainer = L.Trainer(
    default_root_dir="./checkpoints",
    max_epochs=100,
    callbacks=[ModelCheckpoint(dirpath="./checkpoints", save_last=True)],
)
# ckpt_path="last" resumes from the newest last.ckpt. When none exists yet it
# warns and starts fresh, so the same command works for first run and restarts.
trainer.fit(model, datamodule, ckpt_path="last")
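For cloud training with spot/preemptible instances, resume from `last.ckpt` as shown above. A Lightning checkpoint restores the model weights, optimizer and LR-scheduler state, the epoch/global-step counters and callback state, so training continues from the most recent checkpoint. Work done after that checkpoint is repeated, so checkpoint frequently during long epochs (e.g. `ModelCheckpoint(every_n_train_steps=...)`). The experimental "fault-tolerant training" mode from Lightning 1.x, which also tried to restore the exact dataloader position, was removed in Lightning 2.0.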
Testing and Inference
# Test with the best checkpoint tracked by the trainer's ModelCheckpoint callback
trainer.test(model, datamodule, ckpt_path="best")
# Prediction over the datamodule's predict dataloader
predictions = trainer.predict(model, datamodule)
# Load for standalone inference (no Trainer needed). The path is illustrative;
# after training, trainer.checkpoint_callback.best_model_path gives the real file.
loaded_model = ImageClassifier.load_from_checkpoint(
    "checkpoints/epoch=45-val/f1=0.942.ckpt"
)
loaded_model.eval()  # inference behaviour for layers such as the head's Dropout
# inference_mode() disables autograd tracking entirely and is cheaper than no_grad().
with torch.inference_mode():
    output = loaded_model(sample_input)
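`load_from_checkpoint` works because `save_hyperparameters()` stored the constructor arguments in the checkpoint. Lightning calls `__init__` again with those arguments and then loads the saved weights into the rebuilt model. Because `__init__` runs again, `efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)` also loads (and, if not cached, downloads) the ImageNet weights before the checkpoint overwrites them.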
The one thing to remember: Lightning’s real power isn’t the training loop abstraction — it’s the ecosystem of callbacks, strategies, and profilers that let you go from a working prototype to a production multi-node training pipeline by changing configuration, not code.
See Also
- Python Tensorflow Custom Layers How to teach TensorFlow new tricks by building your own custom layers — explained with a cookie cutter analogy.
- Python Tensorflow Data Pipelines How TensorFlow feeds data to your model without wasting time — explained like a restaurant kitchen that never stops cooking.
- Python Tensorflow Keras Api Why Keras is TensorFlow's friendly front door — and how it turns complex math into simple building blocks anyone can stack together.
- Python Tensorflow Model Optimization Why making a trained model smaller and faster matters — explained like packing a suitcase for a trip.
- Python Tensorflow Tensorboard How TensorBoard lets you watch your model learn in real time — explained like a fitness tracker for your AI.