GAN Training Patterns in Python — Core Concepts

Generative Adversarial Networks (GANs) train two neural networks simultaneously in a minimax game: the Generator creates synthetic samples, and the Discriminator classifies samples as real or generated. The training dynamics are notoriously unstable, making training patterns and stabilization techniques essential knowledge.
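
For reference, the minimax game from the original GAN formulation can be written as one value function that the Discriminator D maximizes and the Generator G minimizes. Here p_data is the real-data distribution and p_z the noise prior. The BCE losses in the training loop are minibatch estimates of these two expectations:

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```

In practice the Generator usually minimizes -log D(G(z)) instead of log(1 - D(G(z))). Labeling generated samples as "real" in the Generator's loss amounts to exactly this substitution.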

Basic GAN architecture

import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, latent_dim=100, img_channels=3, features=64):
        super().__init__()
        self.net = nn.Sequential(
            # Input: latent_dim x 1 x 1
            nn.ConvTranspose2d(latent_dim, features * 8, 4, 1, 0),
            nn.BatchNorm2d(features * 8),
            nn.ReLU(True),
            # State: features*8 x 4 x 4
            nn.ConvTranspose2d(features * 8, features * 4, 4, 2, 1),
            nn.BatchNorm2d(features * 4),
            nn.ReLU(True),
            # State: features*4 x 8 x 8
            nn.ConvTranspose2d(features * 4, features * 2, 4, 2, 1),
            nn.BatchNorm2d(features * 2),
            nn.ReLU(True),
            # State: features*2 x 16 x 16
            nn.ConvTranspose2d(features * 2, features, 4, 2, 1),
            nn.BatchNorm2d(features),
            nn.ReLU(True),
            # State: features x 32 x 32
            nn.ConvTranspose2d(features, img_channels, 4, 2, 1),
            nn.Tanh(),
            # Output: img_channels x 64 x 64
        )
    
    def forward(self, z):
        return self.net(z.view(-1, z.size(1), 1, 1))

class Discriminator(nn.Module):
    def __init__(self, img_channels=3, features=64):
        super().__init__()
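        self.net = nn.Sequential(
            # Input: img_channels x 64 x 64 (no BatchNorm on the input layer, following DCGAN)
            nn.Conv2d(img_channels, features, 4, 2, 1),
            nn.LeakyReLU(0.2, True),
            # State: features x 32 x 32
            nn.Conv2d(features, features * 2, 4, 2, 1),
            nn.BatchNorm2d(features * 2),
            nn.LeakyReLU(0.2, True),
            # State: features*2 x 16 x 16
            nn.Conv2d(features * 2, features * 4, 4, 2, 1),
            nn.BatchNorm2d(features * 4),
            nn.LeakyReLU(0.2, True),
            # State: features*4 x 8 x 8
            nn.Conv2d(features * 4, features * 8, 4, 2, 1),
            nn.BatchNorm2d(features * 8),
            nn.LeakyReLU(0.2, True),
            # State: features*8 x 4 x 4
            nn.Conv2d(features * 8, 1, 4, 1, 0),
            nn.Sigmoid(),
            # Output: 1 x 1 x 1 probability that the input is real
        )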
    
    def forward(self, x):
        return self.net(x).view(-1)
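
The shape comments in both networks follow from PyTorch's output-size formulas for transposed and regular convolutions, assuming dilation 1 and no output padding. A small pure-Python sketch that replays the (kernel, stride, padding) settings of each layer to confirm the 1 → 64 and 64 → 1 spatial progression. The helper names are illustrative, not part of PyTorch:

```python
def conv_transpose_out(size: int, kernel: int, stride: int, padding: int) -> int:
    # Output size of nn.ConvTranspose2d (dilation=1, output_padding=0)
    return (size - 1) * stride - 2 * padding + kernel

def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    # Output size of nn.Conv2d (dilation=1)
    return (size + 2 * padding - kernel) // stride + 1

# (kernel, stride, padding) per layer, copied from the classes above
gen_layers = [(4, 1, 0)] + [(4, 2, 1)] * 4
disc_layers = [(4, 2, 1)] * 4 + [(4, 1, 0)]

gen_sizes = [1]    # latent vector reshaped to 1 x 1
for k, s, p in gen_layers:
    gen_sizes.append(conv_transpose_out(gen_sizes[-1], k, s, p))

disc_sizes = [64]  # 64 x 64 input image
for k, s, p in disc_layers:
    disc_sizes.append(conv_out(disc_sizes[-1], k, s, p))

print(gen_sizes)   # [1, 4, 8, 16, 32, 64]
print(disc_sizes)  # [64, 32, 16, 8, 4, 1]
```

The two networks mirror each other: every stride-2 layer doubles (Generator) or halves (Discriminator) the spatial size, and the single stride-1, padding-0 layer maps between 1 x 1 and 4 x 4.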

The training loop

The Generator and Discriminator alternate updates. The Discriminator trains on both real and fake samples; the Generator trains only through the Discriminator’s feedback:

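import torch
import torch.nn as nn

def train_gan(dataloader, gen, disc, epochs=100, latent_dim=100, lr=2e-4, device="cuda"):
    # Models and batches must live on the same device
    gen.to(device)
    disc.to(device)
    opt_g = torch.optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_d = torch.optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))
    criterion = nn.BCELoss()  # Discriminator ends in Sigmoid, so it outputs probabilities

    for epoch in range(epochs):
        for real_images, _ in dataloader:  # batches of (images, labels); labels unused
            batch_size = real_images.size(0)
            real = real_images.to(device)

            # --- Train Discriminator ---
            noise = torch.randn(batch_size, latent_dim, device=device)
            fake = gen(noise).detach()  # detach so this step does not backprop into G

            d_real = disc(real)
            d_fake = disc(fake)

            # Real samples -> target 1, generated samples -> target 0
            loss_d = (
                criterion(d_real, torch.ones_like(d_real)) +
                criterion(d_fake, torch.zeros_like(d_fake))
            ) / 2

            opt_d.zero_grad()
            loss_d.backward()
            opt_d.step()

            # --- Train Generator ---
            noise = torch.randn(batch_size, latent_dim, device=device)
            fake = gen(noise)
            d_fake = disc(fake)

            # Non-saturating loss: G is rewarded when D labels its fakes as real
            loss_g = criterion(d_fake, torch.ones_like(d_fake))

            # This backward also writes grads into D's parameters; they are
            # cleared by opt_d.zero_grad() before D's next update
            opt_g.zero_grad()
            loss_g.backward()
            opt_g.step()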
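
The Generator step labels its fakes as real, so it minimizes -log D(G(z)) (the "non-saturating" loss) rather than the log(1 - D(G(z))) term of the minimax objective. A quick pure-Python comparison of the two gradients with respect to the Discriminator's pre-sigmoid output (its logit) shows why this matters early in training, when D confidently rejects fakes. The function names are illustrative, and D(G(z)) = 0.01 is an arbitrary example value:

```python
import math

def sigmoid(a: float) -> float:
    return 1.0 / (1.0 + math.exp(-a))

def grad_saturating(a: float) -> float:
    # d/da of log(1 - sigmoid(a)): the minimax generator term
    return -sigmoid(a)

def grad_non_saturating(a: float) -> float:
    # d/da of -log(sigmoid(a)): BCE(D(G(z)), 1), what the loop minimizes
    return -(1.0 - sigmoid(a))

a = math.log(0.01 / 0.99)  # logit at which D(G(z)) = 0.01

print(round(sigmoid(a), 4))              # 0.01
print(round(grad_saturating(a), 4))      # -0.01
print(round(grad_non_saturating(a), 4))  # -0.99
```

The saturating gradient equals -D(G(z)), which vanishes as D gets better at spotting fakes; the non-saturating gradient equals -(1 - D(G(z))), so at D(G(z)) = 0.01 it is 99 times larger while pushing the logit in the same direction.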

Mode collapse

The most common GAN failure is mode collapse — the Generator learns to produce a small set of outputs that fool the Discriminator, ignoring the full diversity of the training data. Instead of generating varied faces, it might produce the same three faces repeatedly.

Signs of mode collapse (loss curves are only a rough guide; inspecting the samples themselves is the most reliable check):

  • Generator loss drops rapidly and stays low
  • Generated samples lack variety
  • Discriminator accuracy oscillates without settling
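
Because loss curves can mislead, a cheap complementary check is to track how different the samples in one generated batch are from each other. A minimal sketch with a hypothetical batch_diversity helper that uses mean pairwise L2 distance in pixel space. This is a crude proxy (feature-space metrics are more informative), and the random tensors stand in for real generator output:

```python
import torch

def batch_diversity(samples: torch.Tensor) -> float:
    """Mean pairwise L2 distance between samples in a batch (hypothetical helper).

    A value that shrinks toward zero over training suggests the generator
    is emitting near-identical outputs.
    """
    flat = samples.flatten(start_dim=1)       # (N, C*H*W)
    return torch.pdist(flat, p=2).mean().item()

torch.manual_seed(0)
# 16 unrelated "images" in [-1, 1], standing in for a healthy generator
diverse = torch.rand(16, 3, 64, 64) * 2 - 1
# One image repeated 16 times with tiny noise, standing in for a collapsed one
one_image = torch.rand(1, 3, 64, 64) * 2 - 1
collapsed = one_image.repeat(16, 1, 1, 1) + 0.01 * torch.randn(16, 3, 64, 64)

print(batch_diversity(diverse) > batch_diversity(collapsed))  # True
```

Logging this value for a fixed noise batch every few epochs makes a sudden drop easy to spot.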

Stabilization techniques

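Spectral normalization: Divides each layer's weight by an estimate of its largest singular value, roughly bounding that layer's Lipschitz constant at 1 and keeping the Discriminator from becoming too sharp a critic too fast. The loop below wraps only the Conv2d layers; the BatchNorm layers are left unconstrained:

disc = Discriminator()
for module in disc.modules():
    if isinstance(module, nn.Conv2d):
        # In place: registers weight_orig plus a power-iteration hook that
        # rescales weight by its estimated largest singular value on each forward pass
        nn.utils.spectral_norm(module)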
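
To see what the wrapper does, a small sketch that applies the same nn.utils.spectral_norm call to a single, arbitrarily sized Conv2d, runs a few training-mode forward passes so the power iteration converges, and compares the largest singular value of the raw and normalized kernels. Recent PyTorch versions also offer torch.nn.utils.parametrizations.spectral_norm, which stores the weights differently:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.utils.spectral_norm(nn.Conv2d(3, 8, 4, 2, 1))
x = torch.randn(1, 3, 16, 16)

# Each training-mode forward pass runs one power-iteration step,
# refining the estimate of the weight's largest singular value.
for _ in range(100):
    conv(x)

# Spectral norm treats the kernel as an (out_channels, in_channels*k*k) matrix
raw = conv.weight_orig.detach().reshape(8, -1)
normalized = conv.weight.detach().reshape(8, -1)

print(torch.linalg.svdvals(raw)[0].item())         # depends on the random init
print(torch.linalg.svdvals(normalized)[0].item())  # close to 1.0
```

The reshaped-matrix norm is an approximation of the convolution's true operator norm, which is why the Lipschitz bound above is described as rough.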

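One-sided label smoothing: Use 0.9 instead of 1.0 for “real” labels while fake labels stay at 0, which keeps the Discriminator from becoming overconfident:

real_labels = torch.full_like(d_real, 0.9)  # smoothed "real" targets
loss_d = (
    criterion(d_real, real_labels) +
    criterion(d_fake, torch.zeros_like(d_fake))  # fake targets stay at 0
) / 2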
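
Why 0.9 curbs overconfidence can be seen from the BCE loss itself: with a hard target of 1.0 the loss keeps shrinking as D(x) approaches 1, while with a target of 0.9 it is minimized at exactly 0.9. A pure-Python sketch that searches a grid of possible D(x) outputs; the bce helper is a scalar stand-in for nn.BCELoss:

```python
import math

def bce(p: float, target: float) -> float:
    """Binary cross-entropy for a single predicted probability p in (0, 1)."""
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))

grid = [i / 1000 for i in range(1, 1000)]  # candidate D(x) outputs 0.001 .. 0.999

# Hard target 1.0: loss keeps shrinking as D(x) -> 1, pushing D toward saturation
best_hard = min(grid, key=lambda p: bce(p, 1.0))
# Smoothed target 0.9: loss is minimized at D(x) = 0.9, so extra confidence is penalized
best_smooth = min(grid, key=lambda p: bce(p, 0.9))

print(best_hard, best_smooth)  # 0.999 0.9
```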

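Two-timescale update rule (TTUR): Give the two networks different learning rates, typically a slower one for the Generator. The 1e-4 / 4e-4 pairing below is a common choice; in train_gan, these two lines would replace the single-lr optimizers: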

opt_g = torch.optim.Adam(gen.parameters(), lr=1e-4, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(disc.parameters(), lr=4e-4, betas=(0.5, 0.999))

Common misconception

GANs do not simply memorize training images. The Generator never sees real images directly; it only receives gradient signals indicating whether the Discriminator was fooled. This indirect learning process tends to produce new samples drawn from learned statistical patterns, but it is not a guarantee: GANs trained on small datasets can still reproduce near-copies of training images. You can check by computing nearest-neighbor distances between generated and training images; samples from a well-trained GAN should not sit at near-zero distance from any training image.
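
One way to run the nearest-neighbor check is sketched below with a hypothetical nearest_neighbor_distances helper built on torch.cdist. It measures plain L2 distance in pixel space, which is crude; comparing features from a pretrained network is usually more meaningful. The random tensors and the planted copy are illustrative stand-ins for real data:

```python
import torch

def nearest_neighbor_distances(generated: torch.Tensor, training: torch.Tensor) -> torch.Tensor:
    """For each generated sample, L2 distance to its closest training sample.

    Hypothetical helper; both inputs are image tensors of shape (N, C, H, W).
    """
    g = generated.flatten(start_dim=1)
    t = training.flatten(start_dim=1)
    # Direct differences (no matmul shortcut) so identical images give exactly 0
    dists = torch.cdist(g, t, compute_mode="donot_use_mm_for_euclid_dist")
    return dists.min(dim=1).values  # (N_generated,)

torch.manual_seed(0)
training = torch.rand(100, 3, 32, 32) * 2 - 1
generated = torch.rand(8, 3, 32, 32) * 2 - 1
generated[0] = training[42]  # plant an exact copy to show what memorization looks like

d = nearest_neighbor_distances(generated, training)
print(d)  # d[0] is 0; the other entries are far from zero
```

For a large training set, the same idea can be applied in chunks of training images to keep the distance matrix in memory.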

GAN variants at a glance

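Variant  | Key innovation                      | Best for
-------- | ----------------------------------- | ----------------------------------------------------------
DCGAN    | Convolutional architecture          | Baseline experiments
WGAN-GP  | Wasserstein loss + gradient penalty | Stable training
StyleGAN | Style-based generator               | High-quality faces
Pix2Pix  | Paired image translation            | Tasks with aligned input/output pairs (e.g., edges to photo)
CycleGAN | Unpaired image translation          | Domain transfer without paired data (e.g., horse to zebra)
ProGAN   | Progressive growing                 | High-resolution synthesis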

One thing to remember: GAN training is a balancing act between Generator and Discriminator — too much advantage on either side causes collapse or stagnation — and techniques like spectral normalization, label smoothing, and careful learning rate tuning keep the competition productive.
