GAN Training Patterns in Python — Deep Dive

Advanced GAN training requires understanding loss function design, progressive training strategies, modern architectures like StyleGAN, evaluation metrics, and distributed training patterns. This guide covers production-grade implementations in PyTorch.

Wasserstein GAN with gradient penalty (WGAN-GP)

The original GAN loss uses binary cross-entropy, which can saturate and produce vanishing gradients. Wasserstein loss provides a smoother training signal by measuring the Earth Mover’s distance between real and generated distributions:

import torch
import torch.nn as nn
import torch.autograd as autograd

class WGANCritic(nn.Module):
    """Critic (not discriminator) — outputs unbounded scores, not probabilities."""
    def __init__(self, img_channels=3, features=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(img_channels, features, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(features, features * 2, 4, 2, 1),
            nn.InstanceNorm2d(features * 2),
            nn.LeakyReLU(0.2),
            nn.Conv2d(features * 2, features * 4, 4, 2, 1),
            nn.InstanceNorm2d(features * 4),
            nn.LeakyReLU(0.2),
            nn.Conv2d(features * 4, features * 8, 4, 2, 1),
            nn.InstanceNorm2d(features * 8),
            nn.LeakyReLU(0.2),
            nn.Conv2d(features * 8, 1, 4, 1, 0),
            # No sigmoid — output is unbounded
        )
    
    def forward(self, x):
        return self.net(x).view(-1)

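def gradient_penalty(critic, real, fake, device=None, lambda_gp=10):
    """Enforces 1-Lipschitz constraint via interpolation penalty."""
    batch_size = real.size(0)
    # Default to the device of the input batch instead of hard-coding "cuda"
    device = device if device is not None else real.device
    alpha = torch.rand(batch_size, 1, 1, 1, device=device)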
    
    interpolated = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    critic_interpolated = critic(interpolated)
    
    gradients = autograd.grad(
        outputs=critic_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(critic_interpolated),
        create_graph=True,
        retain_graph=True,
    )[0]
    
    gradients = gradients.view(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)
    penalty = lambda_gp * ((gradient_norm - 1) ** 2).mean()
    return penalty

def train_wgan_gp(
    dataloader, gen, critic,
    epochs=100, latent_dim=100,
    n_critic=5,  # Train critic more often than generator
    lr=1e-4,
    device="cuda",  # gen and critic must already live on this device
):
    opt_g = torch.optim.Adam(gen.parameters(), lr=lr, betas=(0.0, 0.9))
    opt_c = torch.optim.Adam(critic.parameters(), lr=lr, betas=(0.0, 0.9))

    for epoch in range(epochs):
        for real, _ in dataloader:
            real = real.to(device)
            batch_size = real.size(0)

            # --- Train Critic (n_critic times per generator step) ---
            # Simplification: the same real batch is reused for all n_critic updates
            for _ in range(n_critic):
                noise = torch.randn(batch_size, latent_dim, device=device)
                fake = gen(noise).detach()  # No generator gradients in the critic step

                critic_real = critic(real).mean()
                critic_fake = critic(fake).mean()
                gp = gradient_penalty(critic, real, fake, device=device)

                # Critic minimizes E[critic(fake)] - E[critic(real)] + penalty
                loss_c = critic_fake - critic_real + gp

                opt_c.zero_grad()
                loss_c.backward()
                opt_c.step()

            # --- Train Generator ---
            noise = torch.randn(batch_size, latent_dim, device=device)
            fake = gen(noise)
            # Generator maximizes the critic's score on generated samples
            loss_g = -critic(fake).mean()

            opt_g.zero_grad()
            loss_g.backward()
            opt_g.step()

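Key differences from a vanilla GAN: the critic has no sigmoid, the loss is a difference of mean critic scores instead of BCE, a gradient penalty enforces the Lipschitz constraint, the critic uses InstanceNorm2d rather than batch normalization (the penalty is computed per sample, and batch statistics would couple samples), and the critic trains 5 times per generator step.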
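As a sanity check on what the penalty measures, the sketch below applies the same interpolation penalty to a linear critic f(x) = w·x, whose input-gradient norm equals ‖w‖ everywhere. The `LinearCritic` class and the `gp` helper are illustrative names, not part of the training code above, and the example runs on CPU.

```python
import torch
import torch.nn as nn

class LinearCritic(nn.Module):
    """Hypothetical critic f(x) = <w, x>; its gradient w.r.t. x is w everywhere."""
    def __init__(self, weight):
        super().__init__()
        self.weight = nn.Parameter(weight)

    def forward(self, x):
        return (x.flatten(1) * self.weight).sum(dim=1)

def gp(critic, real, fake, lambda_gp=10.0):
    """Interpolation penalty, written to run on whatever device the data is on."""
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    x_hat = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = critic(x_hat)
    grads = torch.autograd.grad(
        outputs=scores, inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
    )[0]
    grad_norm = grads.flatten(1).norm(2, dim=1)
    return lambda_gp * ((grad_norm - 1) ** 2).mean()

torch.manual_seed(0)
real = torch.randn(8, 3, 4, 4)
fake = torch.randn(8, 3, 4, 4)

w = torch.randn(3 * 4 * 4)
unit = LinearCritic(w / w.norm())        # gradient norm 1: satisfies the constraint
steep = LinearCritic(3 * w / w.norm())   # gradient norm 3: violates it

penalty_unit = gp(unit, real, fake).item()
penalty_steep = gp(steep, real, fake).item()
print(penalty_unit, penalty_steep)
```

The first value should be close to 0 and the second close to 40, which is 10 × (3 − 1)². The penalty only cares about the gradient norm at the interpolated points, not about the critic's output values.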

Progressive growing

ProGAN starts training at low resolution (4×4) and progressively adds layers to both networks, gradually increasing to the target resolution. This stabilizes training by starting with easy, low-frequency patterns:

class ProgressiveGenerator(nn.Module):
    def __init__(self, latent_dim=512, max_resolution=256):
        super().__init__()
        self.latent_dim = latent_dim
        
        # Initial block: 4x4
        self.initial = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0),
            nn.LeakyReLU(0.2),
            nn.Conv2d(512, 512, 3, 1, 1),
            nn.LeakyReLU(0.2),
        )
        
        # Progressive blocks (each doubles resolution)
        self.blocks = nn.ModuleList([
            self._make_block(512, 512),  # 4->8
            self._make_block(512, 256),  # 8->16
            self._make_block(256, 128),  # 16->32
            self._make_block(128, 64),   # 32->64
            self._make_block(64, 32),    # 64->128
            self._make_block(32, 16),    # 128->256
        ])
        
        # To-RGB layers for each resolution
        self.to_rgb = nn.ModuleList([
            nn.Conv2d(512, 3, 1),  # 4x4
            nn.Conv2d(512, 3, 1),  # 8x8
            nn.Conv2d(256, 3, 1),  # 16x16
            nn.Conv2d(128, 3, 1),  # 32x32
            nn.Conv2d(64, 3, 1),   # 64x64
            nn.Conv2d(32, 3, 1),   # 128x128
            nn.Conv2d(16, 3, 1),   # 256x256
        ])
        
        self.current_depth = 0
        self.alpha = 1.0  # Blending factor for fade-in
    
    def _make_block(self, in_ch, out_ch):
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(in_ch, out_ch, 3, 1, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(out_ch, out_ch, 3, 1, 1),
            nn.LeakyReLU(0.2),
        )
    
    def forward(self, z):
        x = self.initial(z.view(-1, self.latent_dim, 1, 1))
        
        if self.current_depth == 0:
            return self.to_rgb[0](x)
        
        for i in range(self.current_depth - 1):
            x = self.blocks[i](x)
        
        # Fade-in: blend new block with upsampled previous
        upsampled = nn.functional.interpolate(x, scale_factor=2, mode="nearest")
        upsampled_rgb = self.to_rgb[self.current_depth - 1](upsampled)
        
        new_features = self.blocks[self.current_depth - 1](x)
        new_rgb = self.to_rgb[self.current_depth](new_features)
        
        return self.alpha * new_rgb + (1 - self.alpha) * upsampled_rgb
    
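    def grow(self):
        """Add a new resolution level."""
        if self.current_depth >= len(self.blocks):
            raise RuntimeError("Generator is already at its maximum resolution")
        self.current_depth += 1
        self.alpha = 0.0  # Start faded out; the training loop must ramp this up to 1.0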
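The generator above resets alpha to 0 when it grows but leaves the ramp back to 1 to the training loop. One possible schedule is sketched here, assuming a linear ramp over a fixed number of steps. `fade_in_alpha`, `step_in_phase`, and `fade_steps` are hypothetical names introduced for illustration.

```python
def fade_in_alpha(step_in_phase: int, fade_steps: int) -> float:
    """Linear ramp from 0.0 to 1.0 over the first `fade_steps` steps of a phase."""
    if fade_steps <= 0:
        return 1.0
    return min(1.0, step_in_phase / fade_steps)

# Example: a resolution phase whose first 4 steps fade the new block in
schedule = [fade_in_alpha(s, fade_steps=4) for s in range(7)]
print(schedule)  # [0.0, 0.25, 0.5, 0.75, 1.0, 1.0, 1.0]
```

In a training loop, the value would be assigned to the generator's alpha attribute before each forward pass, with the discriminator's fade-in mirrored using the same value.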

StyleGAN patterns

StyleGAN replaces the traditional generator with a style-based architecture using a mapping network and adaptive instance normalization:

class MappingNetwork(nn.Module):
    """Maps latent z to intermediate space w."""
    def __init__(self, latent_dim=512, w_dim=512, num_layers=8):
        super().__init__()
        layers = []
        for i in range(num_layers):
            layers.extend([
                nn.Linear(latent_dim if i == 0 else w_dim, w_dim),
                nn.LeakyReLU(0.2),
            ])
        self.net = nn.Sequential(*layers)
    
    def forward(self, z):
        return self.net(z)

class StyleBlock(nn.Module):
    """Single block with style modulation."""
    def __init__(self, in_ch, out_ch, w_dim=512):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, 1, 1)
        self.style = nn.Linear(w_dim, out_ch * 2)  # scale and shift
        self.noise_scale = nn.Parameter(torch.zeros(1))
    
    def forward(self, x, w, noise=None):
        x = self.conv(x)
        
        # Add per-pixel noise for stochastic detail
        if noise is None:
            noise = torch.randn(x.size(0), 1, x.size(2), x.size(3), device=x.device)
        x = x + self.noise_scale * noise
        
        # Style modulation via AdaIN
        style = self.style(w).unsqueeze(2).unsqueeze(3)
        gamma, beta = style.chunk(2, dim=1)
        
        x = nn.functional.instance_norm(x)
        return x * (1 + gamma) + beta
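To make the effect of the modulation concrete: after instance normalization, each channel of each sample has zero mean and unit variance, so scaling by (1 + gamma) and shifting by beta sets that channel's statistics directly. The check below uses arbitrary illustrative values for gamma and beta rather than the output of a learned style layer.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 3, 16, 16) * 5.0 + 2.0   # features with arbitrary statistics

# Per-channel style parameters, shaped (N, C, 1, 1) to broadcast over H and W
gamma = torch.tensor([0.5, -0.5, 1.0]).view(1, 3, 1, 1).expand(2, -1, -1, -1)
beta = torch.tensor([1.0, 0.0, -2.0]).view(1, 3, 1, 1).expand(2, -1, -1, -1)

normalized = F.instance_norm(x)              # zero mean, unit variance per (sample, channel)
styled = normalized * (1 + gamma) + beta

channel_mean = styled.mean(dim=(2, 3))                   # approximately beta
channel_std = styled.std(dim=(2, 3), unbiased=False)     # approximately |1 + gamma|
print(channel_mean[0], channel_std[0])
```

Whatever statistics the convolution produced are discarded, and the style vector alone decides the mean and spread of each feature map. This is why a separate style per layer can control features at different scales.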

Evaluation metrics

Fréchet Inception Distance (FID)

FID compares the distribution of generated images to real images in Inception feature space:

from torchvision.models import inception_v3
import numpy as np
from scipy import linalg

class FIDCalculator:
    def __init__(self, device="cuda"):
        # weights= replaces the deprecated pretrained=True (torchvision >= 0.13)
        self.model = inception_v3(weights="IMAGENET1K_V1", transform_input=False)
        self.model.fc = nn.Identity()  # Remove classification head
        self.model = self.model.to(device).eval()
        self.device = device

    @torch.no_grad()
    def extract_features(self, images: torch.Tensor) -> np.ndarray:
        # With transform_input=False the network expects inputs scaled to [-1, 1],
        # which matches the output range of a tanh generator
        images = images.to(self.device)
        # Resize to Inception's expected 299x299
        resized = nn.functional.interpolate(
            images, size=(299, 299), mode="bilinear", align_corners=False
        )
        features = self.model(resized)  # (N, 2048) pooled features
        return features.cpu().numpy()
    
    def calculate_fid(self, real_features, fake_features) -> float:
        mu_real = np.mean(real_features, axis=0)
        mu_fake = np.mean(fake_features, axis=0)
        sigma_real = np.cov(real_features, rowvar=False)
        sigma_fake = np.cov(fake_features, rowvar=False)
        
        diff = mu_real - mu_fake
        covmean = linalg.sqrtm(sigma_real @ sigma_fake)
        
        if np.iscomplexobj(covmean):
            covmean = covmean.real
        
        fid = diff @ diff + np.trace(sigma_real + sigma_fake - 2 * covmean)
        return float(fid)

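Lower is better. As a rough guide, an FID below 10 indicates high quality, and below 5 is state-of-the-art territory for face generation. Scores are only comparable when computed with the same feature extractor, preprocessing, and number of samples, so values from this sketch may not match published numbers.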
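For intuition about the formula, the following sketch evaluates it on synthetic Gaussian features rather than Inception activations. It uses only NumPy and obtains the trace of the matrix square root from the eigenvalues of the covariance product. That is an alternative to `linalg.sqrtm`, chosen here to keep the example dependency-free. `frechet_distance` is an illustrative name. With equal covariances and means shifted by d, the result should be close to ‖d‖².

```python
import numpy as np

def frechet_distance(feats_a: np.ndarray, feats_b: np.ndarray) -> float:
    """Fréchet distance between Gaussians fitted to two (N, D) feature sets."""
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    sigma_a = np.cov(feats_a, rowvar=False)
    sigma_b = np.cov(feats_b, rowvar=False)
    # The product of two PSD matrices has real, non-negative eigenvalues,
    # so Tr(sqrt(sigma_a @ sigma_b)) is the sum of their square roots.
    eigvals = np.linalg.eigvals(sigma_a @ sigma_b).real
    tr_covmean = np.sqrt(np.clip(eigvals, 0.0, None)).sum()
    diff = mu_a - mu_b
    return float(diff @ diff + np.trace(sigma_a) + np.trace(sigma_b) - 2.0 * tr_covmean)

rng = np.random.default_rng(0)
dim, n = 8, 20000
base = rng.standard_normal((n, dim))
same = rng.standard_normal((n, dim))             # same distribution, new samples
shifted = rng.standard_normal((n, dim)) + 0.5    # mean shifted by 0.5 in every dimension

fid_same = frechet_distance(base, same)          # close to 0
fid_shifted = frechet_distance(base, shifted)    # close to 8 * 0.5**2 = 2.0
print(round(fid_same, 3), round(fid_shifted, 3))
```

Even two sample sets from the same distribution give a small positive value, because means and covariances are estimated from finite data. This is one reason FID should be computed with a large, fixed number of samples.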

Inception Score (IS)

The Inception Score measures both quality (confident classifications) and diversity (spread across classes); higher is better:

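def inception_score(images, splits=10, device="cuda"):
    # weights= replaces the deprecated pretrained=True (torchvision >= 0.13)
    model = inception_v3(weights="IMAGENET1K_V1").eval().to(device)

    with torch.no_grad():
        # Inception v3 expects 299x299 inputs; for large sets, run this in mini-batches
        resized = nn.functional.interpolate(
            images.to(device), size=(299, 299), mode="bilinear", align_corners=False
        )
        preds = torch.softmax(model(resized), dim=1).cpu().numpy()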
    
    scores = []
    chunk_size = len(preds) // splits
    for i in range(splits):
        chunk = preds[i * chunk_size:(i + 1) * chunk_size]
        p_y = chunk.mean(axis=0, keepdims=True)
        kl_div = chunk * (np.log(chunk + 1e-10) - np.log(p_y + 1e-10))
        scores.append(np.exp(kl_div.sum(axis=1).mean()))
    
    return float(np.mean(scores)), float(np.std(scores))
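The extremes of the score are easy to verify by hand. If every image receives the same uniform prediction, p(y|x) equals p(y) and the score is 1. If predictions are one-hot and spread evenly over K classes, the score is K. The sketch below applies the same KL computation to hand-made probability arrays; `score_from_probs` is an illustrative helper that skips the split logic.

```python
import numpy as np

def score_from_probs(probs: np.ndarray) -> float:
    """exp(mean KL(p(y|x) || p(y))) for an (N, K) array of class probabilities."""
    p_y = probs.mean(axis=0, keepdims=True)
    kl = probs * (np.log(probs + 1e-10) - np.log(p_y + 1e-10))
    return float(np.exp(kl.sum(axis=1).mean()))

num_classes = 10
uniform = np.full((100, num_classes), 1.0 / num_classes)      # no confidence at all
one_hot = np.eye(num_classes)[np.arange(100) % num_classes]   # confident and diverse
collapsed = np.eye(num_classes)[np.zeros(100, dtype=int)]     # confident, single class

is_uniform = score_from_probs(uniform)
is_one_hot = score_from_probs(one_hot)
is_collapsed = score_from_probs(collapsed)
print(is_uniform, is_one_hot, is_collapsed)
```

The collapsed case scores 1 despite perfectly confident predictions, which shows that the metric penalizes a generator that produces only one class. It cannot detect collapse within a class, since it only looks at predicted labels.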

Training monitoring dashboard

from torch.utils.tensorboard import SummaryWriter
import torchvision.utils as vutils

class GANTrainingMonitor:
    def __init__(self, log_dir="runs/gan", latent_dim=100, device="cuda"):
        self.writer = SummaryWriter(log_dir)
        # Fixed noise makes sample grids comparable across steps
        self.fixed_noise = torch.randn(64, latent_dim, device=device)
    
    def log_step(self, step, gen, disc, loss_g, loss_d, real_batch):
        self.writer.add_scalar("loss/generator", loss_g, step)
        self.writer.add_scalar("loss/discriminator", loss_d, step)
        
        # Mean discriminator output on real vs. generated batches (raw scores, not accuracy)
        with torch.no_grad():
            real_pred = disc(real_batch).mean().item()
            fake = gen(self.fixed_noise[:real_batch.size(0)])
            fake_pred = disc(fake).mean().item()
        
        self.writer.add_scalar("disc/real_score", real_pred, step)
        self.writer.add_scalar("disc/fake_score", fake_pred, step)
        
        # Generate sample grid periodically
        if step % 500 == 0:
            with torch.no_grad():
                samples = gen(self.fixed_noise)
                grid = vutils.make_grid(samples, normalize=True, nrow=8)
                self.writer.add_image("generated", grid, step)
        
        # Gradient norms (training health indicator); call log_step after backward()
        # and before zero_grad(), otherwise .grad is empty or stale
        g_grad = self._gradient_norm(gen)
        d_grad = self._gradient_norm(disc)
        self.writer.add_scalar("gradients/generator", g_grad, step)
        self.writer.add_scalar("gradients/discriminator", d_grad, step)
    
    @staticmethod
    def _gradient_norm(model):
        total = 0
        for p in model.parameters():
            if p.grad is not None:
                total += p.grad.data.norm(2).item() ** 2
        return total ** 0.5

Distributed training

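High-resolution GANs usually need more than one GPU. DistributedDataParallel (DDP) runs one process per GPU and averages gradients across processes: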

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

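def setup_distributed(rank, world_size):
    # Assumes MASTER_ADDR and MASTER_PORT are set in the environment (torchrun sets them)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def train_distributed(rank, world_size, dataloader):
    setup_distributed(rank, world_size)

    # Generator and Discriminator are placeholders for your own nn.Module classes
    gen = Generator().to(rank)
    disc = Discriminator().to(rank)

    gen = DDP(gen, device_ids=[rank])
    disc = DDP(disc, device_ids=[rank])

    # Training loop same as single-GPU
    # DDP handles gradient synchronization automatically
    # The dataloader should use a DistributedSampler so each rank sees its own shard

    # Only save checkpoints from rank 0
    if rank == 0:
        torch.save(gen.module.state_dict(), "generator.pth")

    dist.destroy_process_group()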
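DDP synchronizes gradients but does not split the data; each process needs a loader that yields its own shard. A minimal sketch using `DistributedSampler` follows. Passing `num_replicas` and `rank` explicitly lets it run without an initialized process group, which is done here only for illustration, and the toy dataset is a stand-in for real images.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8).float().unsqueeze(1))  # toy stand-in for images
world_size = 2

def make_loader(rank: int) -> DataLoader:
    # Each rank gets a disjoint subset of indices; shuffle=False keeps the demo deterministic
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False)
    return DataLoader(dataset, batch_size=2, sampler=sampler)

shards = []
for rank in range(world_size):
    loader = make_loader(rank)
    seen = sorted(int(x.item()) for (batch,) in loader for x in batch)
    shards.append(seen)
print(shards)  # [[0, 2, 4, 6], [1, 3, 5, 7]]
```

In real training the sampler would normally be created with shuffling enabled, and its `set_epoch(epoch)` method called at the start of every epoch so that the shuffle order changes between epochs.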

Practical training recipes

Resolution    Architecture       Batch size    Training time (1× A100)
64×64         DCGAN + WGAN-GP    128           4–8 hours
256×256       ProGAN             16–32         2–4 days
512×512       StyleGAN2          8–16          1–2 weeks
1024×1024     StyleGAN2          4–8           2–4 weeks

One thing to remember: stable GAN training rests on a reliable loss (Wasserstein with gradient penalty), a training ratio in which the critic updates more often than the generator, progressive resolution growth for high-resolution output, and continuous monitoring of both gradient norms and discriminator scores. The quality of the training signal determines everything.
