PyTorch Lightning Training — Deep Dive

Complete LightningModule Example

A production-ready image classifier with all the recommended methods:

import torch
import torch.nn as nn
import lightning as L
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
from torchmetrics import Accuracy, F1Score

class ImageClassifier(L.LightningModule):
    """EfficientNet-B0 image classifier with a dropout + linear head.

    Batches are ``(images, labels)`` pairs with integer class labels, as
    produced by ``torchvision.datasets.ImageFolder`` in ``ImageDataModule``.

    Args:
        num_classes: Number of target classes.
        lr: Learning rate for AdamW.
        weight_decay: AdamW weight decay.
    """

    def __init__(self, num_classes: int, lr: float = 1e-3,
                 weight_decay: float = 0.01):
        super().__init__()
        # Records num_classes/lr/weight_decay in self.hparams and in every
        # checkpoint, which is what lets load_from_checkpoint rebuild the module.
        self.save_hyperparameters()

        # Pretrained backbone; its own classification head is replaced by
        # Identity so the backbone outputs features for our custom head.
        backbone = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)
        in_features = backbone.classifier[1].in_features
        backbone.classifier = nn.Identity()
        self.backbone = backbone

        # Custom head
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(in_features, num_classes),
        )

        # One metric object per stage: torchmetrics accumulate state across
        # batches, so sharing an instance between stages would mix results.
        self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.val_f1 = F1Score(task="multiclass", num_classes=num_classes,
                              average="macro")
        self.test_acc = Accuracy(task="multiclass", num_classes=num_classes)

        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return class logits for a batch of images."""
        features = self.backbone(x)
        return self.classifier(features)

    def _shared_step(self, batch):
        """Return ``(loss, argmax predictions, labels)`` for one labelled batch."""
        images, labels = batch
        logits = self(images)
        loss = self.criterion(logits, labels)
        return loss, logits.argmax(dim=1), labels

    def training_step(self, batch, batch_idx):
        loss, preds, labels = self._shared_step(batch)
        self.train_acc(preds, labels)

        self.log("train/loss", loss, prog_bar=True)
        self.log("train/acc", self.train_acc, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, preds, labels = self._shared_step(batch)
        self.val_acc(preds, labels)
        self.val_f1(preds, labels)

        # A plain tensor is not reduced across DDP ranks unless sync_dist=True;
        # without it each process would see (and early-stop on) its own value.
        self.log("val/loss", loss, prog_bar=True, sync_dist=True)
        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True)
        self.log("val/f1", self.val_f1, on_step=False, on_epoch=True)

    def test_step(self, batch, batch_idx):
        loss, preds, labels = self._shared_step(batch)
        self.test_acc(preds, labels)

        self.log("test/loss", loss, sync_dist=True)
        self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return logits; accepts ``(images, labels)`` pairs or bare image batches."""
        images = batch[0] if isinstance(batch, (list, tuple)) else batch
        return self(images)

    def configure_optimizers(self):
        """AdamW with cosine annealing over ``trainer.max_epochs`` epochs."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.trainer.max_epochs
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

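Key details: save_hyperparameters() stores the constructor arguments in self.hparams and in every checkpoint, which is what makes runs reproducible and load_from_checkpoint work. torchmetrics objects synchronize their state across distributed processes automatically when the epoch-level value is computed. Plain tensors such as the loss are only averaged across processes when logged with sync_dist=True.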

Production DataModule

from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import ImageFolder

class ImageDataModule(L.LightningDataModule):
    """ImageFolder data split 80/20 into an augmented train set and a val set.

    Args:
        data_dir: Root directory in the ``<class_name>/<image>`` layout that
            ``torchvision.datasets.ImageFolder`` expects.
        batch_size: Batch size used by every dataloader.
        num_workers: DataLoader worker processes (0 = load in the main process).
    """

    def __init__(self, data_dir: str, batch_size: int = 32,
                 num_workers: int = 4):
        super().__init__()
        self.save_hyperparameters()

        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.2, 0.2, 0.2),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                               [0.229, 0.224, 0.225]),
        ])
        self.val_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                               [0.229, 0.224, 0.225]),
        ])

    def setup(self, stage=None):
        from torch.utils.data import Subset

        # Two ImageFolder views over the same files, one per transform.
        # Both random_split subsets reference ONE underlying dataset, so
        # setting `.transform` on it per split leaves both splits with
        # whichever transform was assigned last (silently no augmentation).
        train_base = ImageFolder(self.hparams.data_dir,
                                 transform=self.train_transform)
        val_base = ImageFolder(self.hparams.data_dir,
                               transform=self.val_transform)

        n = len(train_base)
        train_size = int(0.8 * n)
        val_size = n - train_size
        # Fixed seed: the same split on every setup() call and every rank.
        self.train_ds, val_split = random_split(
            train_base, [train_size, val_size],
            generator=torch.Generator().manual_seed(42),
        )
        # Same held-out indices, read through the val-transform view.
        self.val_ds = Subset(val_base, val_split.indices)

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader with the shared batch-size/worker settings."""
        workers = self.hparams.num_workers
        return DataLoader(
            dataset, batch_size=self.hparams.batch_size,
            shuffle=shuffle, num_workers=workers,
            pin_memory=True,
            # PyTorch rejects persistent_workers=True when num_workers == 0.
            persistent_workers=workers > 0,
        )

    def train_dataloader(self):
        return self._make_loader(self.train_ds, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_ds, shuffle=False)

    def test_dataloader(self):
        # NOTE: there is no separate test split; this reuses the held-out
        # validation split so trainer.test() can run. Point it at a truly
        # held-out dataset before reporting final numbers.
        return self._make_loader(self.val_ds, shuffle=False)

    def predict_dataloader(self):
        return self._make_loader(self.val_ds, shuffle=False)

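persistent_workers=True keeps DataLoader worker processes alive between epochs instead of respawning them. This matters most for datasets with expensive per-worker initialization. PyTorch raises an error if it is enabled while num_workers=0, so it only applies when data is loaded in worker processes.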

Custom Callbacks

Gradient Monitoring

class GradientMonitor(L.Callback):
    """Log per-parameter gradient L2 norms to detect vanishing/exploding gradients.

    Runs in ``on_before_optimizer_step`` rather than ``on_after_backward``.
    That hook fires once per optimizer step, after gradient accumulation has
    finished, and under AMP after the loss scaling has been undone. The norms
    therefore describe the gradients the optimizer is about to use, not
    partial or GradScaler-scaled ones. Lightning applies
    ``gradient_clip_val`` after this hook, so norms are pre-clipping.

    Args:
        log_every_n_steps: Log only on steps where ``global_step`` is a
            multiple of this value.
        warn_threshold: Print a warning for any parameter whose gradient
            norm exceeds this value.

    Raises:
        ValueError: If ``log_every_n_steps`` is less than 1.
    """

    def __init__(self, log_every_n_steps: int = 100,
                 warn_threshold: float = 100.0):
        super().__init__()
        if log_every_n_steps < 1:
            raise ValueError("log_every_n_steps must be >= 1")
        self.log_every_n_steps = log_every_n_steps
        self.warn_threshold = warn_threshold

    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
        if trainer.global_step % self.log_every_n_steps != 0:
            return

        for name, param in pl_module.named_parameters():
            if param.grad is None:
                continue
            grad_norm = param.grad.detach().norm(2).item()
            pl_module.log(f"gradients/{name}_norm", grad_norm)

            if grad_norm > self.warn_threshold:
                # LightningModule.print only prints on global rank 0.
                pl_module.print(f"WARNING: Large gradient in {name}: "
                                f"{grad_norm:.2f}")

Dynamic Batch Size

class DynamicBatchSize(L.Callback):
    """Grow the datamodule's batch size on a doubling schedule.

    The batch size for epoch ``e`` is ``initial * 2 ** (e // increase_every)``,
    capped at ``max_size``. It is written to
    ``trainer.datamodule.hparams.batch_size``, so the datamodule's
    ``train_dataloader()`` must read its batch size from there.

    Lightning builds an epoch's training DataLoader *before*
    ``on_train_epoch_start`` fires, and rebuilds it only when the Trainer has
    ``reload_dataloaders_every_n_epochs=1``. This callback therefore sets the
    size for an epoch at the end of the previous one (and once at fit start).
    It warns when dataloaders are not reloaded every epoch; in that case
    updating the hyperparameter has no effect.

    Args:
        initial: Batch size for the first ``increase_every`` epochs.
        max_size: Upper bound on the batch size.
        increase_every: Number of epochs between doublings.

    Raises:
        ValueError: If ``initial`` or ``increase_every`` is less than 1.
    """

    def __init__(self, initial: int = 16, max_size: int = 256,
                 increase_every: int = 5):
        super().__init__()
        if initial < 1 or increase_every < 1:
            raise ValueError("initial and increase_every must be >= 1")
        self.initial = initial
        self.max_size = max_size
        self.increase_every = increase_every

    def _size_for_epoch(self, epoch: int) -> int:
        """Return the scheduled batch size for ``epoch``."""
        return min(
            self.initial * (2 ** (epoch // self.increase_every)),
            self.max_size,
        )

    def _set_batch_size(self, trainer, epoch: int) -> None:
        """Write the batch size scheduled for ``epoch`` into the datamodule."""
        datamodule = trainer.datamodule
        if datamodule is None:
            raise RuntimeError(
                "DynamicBatchSize requires trainer.fit(..., datamodule=...)"
            )
        datamodule.hparams.batch_size = self._size_for_epoch(epoch)

    def on_fit_start(self, trainer, pl_module):
        import warnings

        if getattr(trainer, "reload_dataloaders_every_n_epochs", 0) != 1:
            warnings.warn(
                "DynamicBatchSize needs Trainer(reload_dataloaders_every_n_epochs=1); "
                "otherwise the training DataLoader is not rebuilt and batch-size "
                "changes are applied late or not at all.",
                stacklevel=2,
            )
        # Covers the first epoch (or the resumed epoch) before its loader exists.
        self._set_batch_size(trainer, trainer.current_epoch)

    def on_train_epoch_end(self, trainer, pl_module):
        # Schedule the NEXT epoch: its DataLoader is rebuilt before
        # on_train_epoch_start would fire, so setting it there is too late.
        self._set_batch_size(trainer, trainer.current_epoch + 1)

Trainer Configuration for Production

from lightning.pytorch.callbacks import (
    ModelCheckpoint, EarlyStopping, LearningRateMonitor,
    RichProgressBar,
)
from lightning.pytorch.loggers import WandbLogger

trainer = L.Trainer(
    # Hardware
    accelerator="gpu",
    devices=4,
    strategy="ddp",
    precision="16-mixed",

    # Training
    max_epochs=100,
    gradient_clip_val=1.0,
    accumulate_grad_batches=4,

    # Callbacks
    callbacks=[
        ModelCheckpoint(
            monitor="val/f1",
            mode="max",
            save_top_k=3,
            # The metric name contains "/". With automatic "name=value"
            # insertion, "{epoch}-{val/f1:.3f}" renders as
            # "epoch=45-val/f1=0.942.ckpt", i.e. a nested "epoch=45-val/"
            # folder. Spell the names out and disable auto-insertion.
            filename="epoch={epoch}-val_f1={val/f1:.3f}",
            auto_insert_metric_name=False,
            save_last=True,  # also keep last.ckpt for resuming after a crash
        ),
        EarlyStopping(
            monitor="val/loss",
            patience=10,
            mode="min",
        ),
        LearningRateMonitor(logging_interval="step"),
        GradientMonitor(),
        RichProgressBar(),
    ],

    # Logging
    logger=WandbLogger(project="image-classification"),

    # Debugging (disable for production)
    # fast_dev_run=True,  # Run 1 batch to verify pipeline
    # overfit_batches=10,  # Overfit on 10 batches to verify model capacity
)

Advanced Training Strategies

FSDP for Large Models

from lightning.pytorch.strategies import FSDPStrategy

# Each nn.TransformerEncoderLayer becomes its own FSDP unit and is
# activation-checkpointed; FULL_SHARD shards parameters, gradients and
# optimizer state across all devices.
wrap_cls = nn.TransformerEncoderLayer
strategy = FSDPStrategy(
    sharding_strategy="FULL_SHARD",
    auto_wrap_policy={wrap_cls},
    activation_checkpointing_policy={wrap_cls},
    mixed_precision=None,  # defer to the Trainer's `precision` argument
)

trainer = L.Trainer(devices=8, precision="bf16-mixed", strategy=strategy)

DeepSpeed for Extreme Scale

from lightning.pytorch.strategies import DeepSpeedStrategy

strategy = DeepSpeedStrategy(
    stage=3,  # ZeRO-3: shard parameters, gradients and optimizer state
    offload_optimizer=True,  # offload optimizer memory/computation (CPU by default)
    offload_parameters=False,  # parameters stay on the accelerator
    # Bucket sizes are element counts, typed `int` in DeepSpeedStrategy;
    # the literal 5e8 would pass a float.
    allgather_bucket_size=500_000_000,
    reduce_bucket_size=500_000_000,
)

Profiling and Debugging

Lightning’s built-in profiler identifies bottlenecks:

from lightning.pytorch.profilers import AdvancedProfiler

# AdvancedProfiler gathers cProfile statistics per Lightning action and
# writes its report under ./profiler with "perf" in the file name.
profiler = AdvancedProfiler(filename="perf", dirpath="./profiler")
trainer = L.Trainer(max_epochs=2, profiler=profiler)
trainer.fit(model, datamodule=datamodule)

The output shows time spent in each method — immediately revealing whether your bottleneck is data loading, forward pass, backward pass, or optimizer steps.

For quick sanity checks:

# Smoke test: push exactly one train batch and one val batch end to end.
L.Trainer(fast_dev_run=True).fit(model, datamodule=datamodule)

# Capacity check: training loss should approach zero on a fixed set of 10
# batches; if it does not, suspect the model or optimization setup.
trainer = L.Trainer(max_epochs=50, overfit_batches=10)
trainer.fit(model, datamodule=datamodule)

Resuming and Fault Tolerance

Resume from a checkpoint after crash or preemption:

# last.ckpt is NOT written by default: ask ModelCheckpoint for it explicitly.
# An explicit dirpath lets ckpt_path="last" find it before training starts.
trainer = L.Trainer(
    default_root_dir="./checkpoints",
    max_epochs=100,
    callbacks=[ModelCheckpoint(dirpath="./checkpoints", save_last=True)],
)
# ckpt_path="last" resumes from the newest last.ckpt when one exists; on the
# very first run it only warns and starts from scratch, so this line can stay
# in a job script that gets restarted after preemption.
trainer.fit(model, datamodule, ckpt_path="last")

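For cloud training on spot/preemptible instances, resuming from a checkpoint restores the model weights, optimizer and LR-scheduler state, callback state, and loop progress (epoch and step counters). Pair save_last=True with frequent checkpoints (for example ModelCheckpoint(every_n_train_steps=...)) so each preemption loses little work. Lightning 1.x also shipped an experimental fault-tolerant mode that additionally captured dataloader position and RNG state; it was removed in Lightning 2.0.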

Testing and Inference

# Test with the best checkpoint tracked by ModelCheckpoint
trainer.test(model, datamodule, ckpt_path="best")

# Prediction: a list with one predict_step output per batch
predictions = trainer.predict(model, datamodule)

# Load for standalone inference (no Trainer needed once you have the file).
# best_model_path avoids hard-coding the checkpoint filename template and
# logger-specific directory; any saved .ckpt path works here.
best_path = trainer.checkpoint_callback.best_model_path
loaded_model = ImageClassifier.load_from_checkpoint(best_path)
loaded_model.eval()  # disable dropout for inference
with torch.no_grad():
    # sample_input: an image batch preprocessed like the datamodule's val_transform
    output = loaded_model(sample_input)

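load_from_checkpoint works because save_hyperparameters() stored the constructor arguments in the checkpoint. Lightning calls the constructor with those saved arguments, then loads the trained weights from the checkpoint's state_dict. For this model, the constructor still builds the pretrained EfficientNet backbone first; its weights are then overwritten by the checkpoint.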

The one thing to remember: Lightning’s real power isn’t the training loop abstraction — it’s the ecosystem of callbacks, strategies, and profilers that let you go from a working prototype to a production multi-node training pipeline by changing configuration, not code.

Tags: python, machine-learning, pytorch

See Also