GAN Training Patterns in Python — Deep Dive
Advanced GAN training requires understanding loss function design, progressive training strategies, modern architectures like StyleGAN, evaluation metrics, and distributed training patterns. This guide covers production-grade implementations in PyTorch.
Wasserstein GAN with gradient penalty (WGAN-GP)
The original GAN loss uses binary cross-entropy, which can saturate and produce vanishing gradients. Wasserstein loss provides a smoother training signal by measuring the Earth Mover’s distance between real and generated distributions:
import torch
import torch.nn as nn
import torch.autograd as autograd
class WGANCritic(nn.Module):
    """Convolutional critic for Wasserstein GAN training.

    The network ends in a plain convolution with no sigmoid, so it emits
    unbounded real-valued scores rather than probabilities.

    Args:
        img_channels: Number of channels in the input images.
        features: Channel width of the first convolution; later stages use
            2x, 4x and 8x this width.
    """

    def __init__(self, img_channels=3, features=64):
        super().__init__()
        widths = [features * mult for mult in (1, 2, 4, 8)]

        # Stem: strided convolution without normalisation.
        stages = [
            nn.Conv2d(img_channels, widths[0], 4, 2, 1),
            nn.LeakyReLU(0.2),
        ]
        # Three further downsampling stages: conv -> instance norm -> LeakyReLU.
        for c_in, c_out in zip(widths, widths[1:]):
            stages += [
                nn.Conv2d(c_in, c_out, 4, 2, 1),
                nn.InstanceNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ]
        # Head: 4x4 valid convolution down to a single score channel.
        # Deliberately no sigmoid - the output stays unbounded.
        stages.append(nn.Conv2d(widths[-1], 1, 4, 1, 0))

        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return the critic scores for ``x`` flattened to a 1-D tensor."""
        scores = self.net(x)
        return scores.view(-1)
def gradient_penalty(critic, real, fake, device=None, lambda_gp=10):
    """Enforce the 1-Lipschitz constraint via an interpolation penalty.

    Random points on the straight lines between ``real`` and ``fake`` samples
    are scored by ``critic``; the penalty is the mean squared deviation of the
    per-sample gradient L2 norm from 1, scaled by ``lambda_gp``.

    Args:
        critic: Callable mapping a batch to one score per sample.
        real: Batch of real samples; the first dimension is the batch.
        fake: Batch of generated samples with the same shape as ``real``.
        device: Device on which to draw the interpolation coefficients.
            Defaults to ``real.device`` so the function works for CPU tensors
            and for any GPU index without extra arguments.
        lambda_gp: Weight applied to the penalty term.

    Returns:
        A scalar tensor that can be added to the critic loss (it is built
        with ``create_graph=True`` so it can be back-propagated).
    """
    if device is None:
        device = real.device

    batch_size = real.size(0)
    # One mixing coefficient per sample, broadcast over every non-batch
    # dimension (works for images as well as flat feature vectors).
    alpha_shape = [batch_size] + [1] * (real.dim() - 1)
    alpha = torch.rand(alpha_shape, device=device, dtype=real.dtype)
    interpolated = (alpha * real + (1 - alpha) * fake).requires_grad_(True)

    critic_interpolated = critic(interpolated)
    gradients = autograd.grad(
        outputs=critic_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(critic_interpolated),
        create_graph=True,  # the penalty itself must be differentiable
        retain_graph=True,
    )[0]

    # reshape (not view): the gradient is not guaranteed to be contiguous.
    gradients = gradients.reshape(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)
    return lambda_gp * ((gradient_norm - 1) ** 2).mean()
def train_wgan_gp(
    dataloader, gen, critic,
    epochs=100, latent_dim=100,
    n_critic=5,  # Train critic more often than generator
    lr=1e-4,
    device=None,
):
    """Train a generator/critic pair with the WGAN-GP objective.

    For every batch the critic is updated ``n_critic`` times on
    ``critic(fake) - critic(real) + gradient_penalty`` and the generator is
    then updated once on ``-critic(fake)``.

    Args:
        dataloader: Iterable yielding ``(real_batch, _)`` pairs; the second
            element is ignored.
        gen: Generator module, called with noise of shape
            ``(batch_size, latent_dim)``.
        critic: Critic module returning one score per sample.
        epochs: Number of passes over ``dataloader``.
        latent_dim: Size of the noise vector fed to ``gen``.
        n_critic: Critic updates per generator update.
        lr: Learning rate for both Adam optimizers.
        device: Device used for the data and the noise. Defaults to the
            device of the generator's parameters, so the loop also runs on
            CPU or on a non-default GPU.
    """
    opt_g = torch.optim.Adam(gen.parameters(), lr=lr, betas=(0.0, 0.9))
    opt_c = torch.optim.Adam(critic.parameters(), lr=lr, betas=(0.0, 0.9))

    if device is None:
        device = next(gen.parameters()).device

    for _epoch in range(epochs):
        for real, _ in dataloader:
            real = real.to(device)
            batch_size = real.size(0)

            # --- Train Critic (n_critic times per generator step) ---
            for _ in range(n_critic):
                noise = torch.randn(batch_size, latent_dim, device=device)
                # The generator is not updated here, so skip building its
                # autograd graph instead of building and detaching it.
                with torch.no_grad():
                    fake = gen(noise)

                critic_real = critic(real).mean()
                critic_fake = critic(fake).mean()
                gp = gradient_penalty(critic, real, fake, device=device)
                loss_c = critic_fake - critic_real + gp

                opt_c.zero_grad()
                loss_c.backward()
                opt_c.step()

            # --- Train Generator ---
            noise = torch.randn(batch_size, latent_dim, device=device)
            fake = gen(noise)
            # Maximise the critic score of generated samples.
            loss_g = -critic(fake).mean()

            opt_g.zero_grad()
            loss_g.backward()
            opt_g.step()
Key differences from a vanilla GAN: no sigmoid in the critic, a Wasserstein loss (the difference of mean critic scores instead of BCE), a gradient penalty to enforce the Lipschitz constraint, and a critic that trains 5× per generator step.
Progressive growing
ProGAN starts training at low resolution (4×4) and progressively adds layers to both networks, gradually increasing to the target resolution. This stabilizes training by starting with easy, low-frequency patterns:
class ProgressiveGenerator(nn.Module):
    """Progressively grown generator (ProGAN style).

    Generation starts at 4x4; each call to :meth:`grow` enables one more
    block that doubles the output resolution. While a new block is being
    introduced its RGB output is blended with an upsampled copy of the
    previous resolution, weighted by ``alpha``.

    Args:
        latent_dim: Size of the latent vector.
        max_resolution: Largest output resolution that may be reached via
            :meth:`grow`. Values between powers of two are rounded down;
            the network defines blocks up to 256x256, so larger values are
            capped there.

    Attributes (plain Python, not parameters):
        current_depth: Number of active progressive blocks (0 means 4x4).
        max_depth: Largest permitted ``current_depth``.
        alpha: Fade-in weight of the newest block in ``[0, 1]``; the caller
            is expected to raise it towards 1.0 during training.
    """

    def __init__(self, latent_dim=512, max_resolution=256):
        super().__init__()
        self.latent_dim = latent_dim

        # Initial block: 1x1 latent -> 4x4 feature map
        self.initial = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0),
            nn.LeakyReLU(0.2),
            nn.Conv2d(512, 512, 3, 1, 1),
            nn.LeakyReLU(0.2),
        )

        # Progressive blocks (each doubles resolution)
        self.blocks = nn.ModuleList([
            self._make_block(512, 512),  # 4->8
            self._make_block(512, 256),  # 8->16
            self._make_block(256, 128),  # 16->32
            self._make_block(128, 64),   # 32->64
            self._make_block(64, 32),    # 64->128
            self._make_block(32, 16),    # 128->256
        ])

        # To-RGB layers for each resolution
        self.to_rgb = nn.ModuleList([
            nn.Conv2d(512, 3, 1),  # 4x4
            nn.Conv2d(512, 3, 1),  # 8x8
            nn.Conv2d(256, 3, 1),  # 16x16
            nn.Conv2d(128, 3, 1),  # 32x32
            nn.Conv2d(64, 3, 1),   # 64x64
            nn.Conv2d(32, 3, 1),   # 128x128
            nn.Conv2d(16, 3, 1),   # 256x256
        ])

        # Depth d renders at 4 * 2**d pixels, so the deepest allowed level is
        # floor(log2(max_resolution)) - 2, clamped to the blocks that exist.
        requested_depth = int(max_resolution).bit_length() - 3
        self.max_depth = max(0, min(len(self.blocks), requested_depth))

        self.current_depth = 0
        self.alpha = 1.0  # Blending factor for fade-in

    def _make_block(self, in_ch, out_ch):
        """Build one block: 2x nearest upsample followed by two 3x3 convs."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(in_ch, out_ch, 3, 1, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(out_ch, out_ch, 3, 1, 1),
            nn.LeakyReLU(0.2),
        )

    def forward(self, z):
        """Generate RGB images at the current resolution.

        Args:
            z: Latent batch; it is reshaped to ``(-1, latent_dim, 1, 1)``.

        Returns:
            A 3-channel image batch whose side length is
            ``4 * 2 ** current_depth``.
        """
        x = self.initial(z.view(-1, self.latent_dim, 1, 1))

        if self.current_depth == 0:
            return self.to_rgb[0](x)

        # Run every fully faded-in block; the newest one is handled below.
        for i in range(self.current_depth - 1):
            x = self.blocks[i](x)

        # Fade-in: blend new block with upsampled previous
        upsampled = nn.functional.interpolate(x, scale_factor=2, mode="nearest")
        upsampled_rgb = self.to_rgb[self.current_depth - 1](upsampled)

        new_features = self.blocks[self.current_depth - 1](x)
        new_rgb = self.to_rgb[self.current_depth](new_features)

        return self.alpha * new_rgb + (1 - self.alpha) * upsampled_rgb

    def grow(self):
        """Add a new resolution level and restart the fade-in.

        Raises:
            RuntimeError: If the generator is already at its maximum depth.
                Failing here is clearer than the out-of-range index that
                would otherwise surface on the next forward pass.
        """
        if self.current_depth >= self.max_depth:
            raise RuntimeError(
                "ProgressiveGenerator is already at its maximum resolution "
                f"({4 * 2 ** self.max_depth}x{4 * 2 ** self.max_depth})"
            )
        self.current_depth += 1
        self.alpha = 0.0  # Start faded out, gradually increase
StyleGAN patterns
StyleGAN replaces the traditional generator with a style-based architecture using a mapping network and adaptive instance normalization:
class MappingNetwork(nn.Module):
    """MLP that maps a latent vector ``z`` to the intermediate space ``w``.

    Args:
        latent_dim: Input width of the first linear layer.
        w_dim: Output width of every linear layer.
        num_layers: Number of ``Linear`` + ``LeakyReLU(0.2)`` pairs.
    """

    def __init__(self, latent_dim=512, w_dim=512, num_layers=8):
        super().__init__()
        modules = []
        fan_in = latent_dim  # only the first layer sees the latent width
        for _ in range(num_layers):
            modules.append(nn.Linear(fan_in, w_dim))
            modules.append(nn.LeakyReLU(0.2))
            fan_in = w_dim
        self.net = nn.Sequential(*modules)

    def forward(self, z):
        """Return the intermediate latent for ``z``."""
        w = self.net(z)
        return w
class StyleBlock(nn.Module):
    """One synthesis block: convolution, noise injection, style modulation.

    Args:
        in_ch: Input channels of the 3x3 convolution.
        out_ch: Output channels of the convolution.
        w_dim: Width of the style vector ``w``.
    """

    def __init__(self, in_ch, out_ch, w_dim=512):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
        # A single affine layer yields both the scale and the shift.
        self.style = nn.Linear(w_dim, 2 * out_ch)
        # Learned noise strength; initialised to zero so noise starts disabled.
        self.noise_scale = nn.Parameter(torch.zeros(1))

    def forward(self, x, w, noise=None):
        """Apply the block.

        Args:
            x: Input feature map.
            w: Style vector(s) fed to the affine style layer.
            noise: Optional noise map; when omitted, standard-normal noise
                with one channel and the spatial size of the conv output is
                sampled.

        Returns:
            The instance-normalised features scaled by ``1 + gamma`` and
            shifted by ``beta``.
        """
        feat = self.conv(x)

        # Per-pixel noise for stochastic detail, shared across channels.
        if noise is None:
            noise_shape = (feat.size(0), 1, feat.size(2), feat.size(3))
            noise = torch.randn(*noise_shape, device=feat.device)
        feat = feat + self.noise_scale * noise

        # AdaIN-style modulation: first half of the projection is the scale,
        # second half the shift, each broadcast over the spatial dims.
        projected = self.style(w)[:, :, None, None]
        gamma, beta = projected.chunk(2, dim=1)
        normed = nn.functional.instance_norm(feat)
        return normed * (1 + gamma) + beta
Evaluation metrics
Fréchet Inception Distance (FID)
FID compares the distribution of generated images to real images in Inception feature space:
from torchvision.models import inception_v3
import numpy as np
from scipy import linalg
class FIDCalculator:
    """Compute the Frechet Inception Distance between two image sets.

    Features are taken from a pretrained Inception v3 whose classification
    head has been replaced by an identity, then compared as Gaussians.

    Args:
        device: Device on which the Inception model runs.
    """

    def __init__(self, device="cuda"):
        # NOTE(review): recent torchvision releases deprecate ``pretrained=``
        # in favour of ``weights=``; confirm against the installed version.
        self.model = inception_v3(pretrained=True, transform_input=False)
        self.model.fc = nn.Identity()  # Remove classification head
        self.model = self.model.to(device).eval()
        self.device = device

    @torch.no_grad()
    def extract_features(self, images: torch.Tensor) -> np.ndarray:
        """Return Inception features for ``images`` as a NumPy array.

        The batch is bilinearly resized to 299x299 before being moved to the
        model's device.
        """
        # Resize to Inception's expected 299x299
        resized = nn.functional.interpolate(images, size=299, mode="bilinear")
        features = self.model(resized.to(self.device))
        return features.cpu().numpy()

    def calculate_fid(self, real_features, fake_features, eps=1e-6) -> float:
        """Return the FID between two feature sets.

        Args:
            real_features: Array of shape ``(n_real, dim)``.
            fake_features: Array of shape ``(n_fake, dim)``.
            eps: Diagonal offset added to both covariances when the matrix
                square root of their product is not finite.

        Returns:
            ``|mu_r - mu_f|^2 + Tr(S_r + S_f - 2 * sqrt(S_r @ S_f))``.

        Raises:
            ValueError: If either set has fewer than two samples, since a
                covariance cannot be estimated from a single sample.
        """
        real_features = np.asarray(real_features, dtype=np.float64)
        fake_features = np.asarray(fake_features, dtype=np.float64)
        if len(real_features) < 2 or len(fake_features) < 2:
            raise ValueError(
                "calculate_fid needs at least two samples per feature set"
            )

        mu_real = np.mean(real_features, axis=0)
        mu_fake = np.mean(fake_features, axis=0)
        # np.cov collapses to a 0-d array for one-dimensional features;
        # atleast_2d keeps the matrix algebra below valid in that case.
        sigma_real = np.atleast_2d(np.cov(real_features, rowvar=False))
        sigma_fake = np.atleast_2d(np.cov(fake_features, rowvar=False))

        diff = mu_real - mu_fake
        covmean = linalg.sqrtm(sigma_real @ sigma_fake)
        if not np.isfinite(covmean).all():
            # Near-singular product: regularise the diagonals and retry.
            offset = np.eye(sigma_real.shape[0]) * eps
            covmean = linalg.sqrtm((sigma_real + offset) @ (sigma_fake + offset))
        if np.iscomplexobj(covmean):
            # Imaginary parts here are numerical noise from sqrtm.
            covmean = covmean.real

        fid = diff @ diff + np.trace(sigma_real + sigma_fake - 2 * covmean)
        return float(fid)
FID below 10 indicates high quality; below 5 is state-of-the-art for face generation.
Inception Score (IS)
The Inception Score measures both quality (confident classifications) and diversity (spread across classes):
def inception_score(images, splits=10, device=None):
    """Compute the Inception Score of a batch of images.

    The class probabilities predicted by a pretrained Inception v3 are
    divided into ``splits`` groups; for each group the score is
    ``exp(mean_x KL(p(y|x) || p(y)))``.

    Args:
        images: Image batch passed directly to the Inception model.
        splits: Number of groups the predictions are divided into.
        device: Device for the model and the images. Defaults to ``"cuda"``
            when available, otherwise ``"cpu"``.

    Returns:
        ``(mean, std)`` of the per-group scores as Python floats.

    Raises:
        ValueError: If ``splits`` is smaller than 1 or there are fewer
            images than splits (a group would be empty and the score NaN).
    """
    # Validate before loading the model so bad input fails fast.
    num_images = len(images)
    if splits < 1:
        raise ValueError(f"splits must be >= 1, got {splits}")
    if num_images < splits:
        raise ValueError(
            f"need at least {splits} images for {splits} splits, got {num_images}"
        )

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model = inception_v3(pretrained=True).eval().to(device)
    with torch.no_grad():
        preds = torch.softmax(model(images.to(device)), dim=1).cpu().numpy()

    scores = []
    # array_split keeps every sample even when the count is not a multiple
    # of ``splits`` (a fixed chunk size would silently drop the remainder).
    for chunk in np.array_split(preds, splits):
        p_y = chunk.mean(axis=0, keepdims=True)  # marginal class distribution
        kl_div = chunk * (np.log(chunk + 1e-10) - np.log(p_y + 1e-10))
        scores.append(np.exp(kl_div.sum(axis=1).mean()))
    return float(np.mean(scores)), float(np.std(scores))
Training monitoring dashboard
from torch.utils.tensorboard import SummaryWriter
import torchvision.utils as vutils
class GANTrainingMonitor:
    """Log GAN training health to TensorBoard.

    Records losses, mean discriminator scores on real and generated data,
    gradient norms and a periodic grid of samples drawn from a fixed noise
    batch.

    Args:
        log_dir: Directory passed to ``SummaryWriter``.
        latent_dim: Width of the fixed noise vectors; must match the
            generator's input size.
        device: Device for the fixed noise. Defaults to ``"cuda"`` when
            available, otherwise ``"cpu"``.
    """

    def __init__(self, log_dir="runs/gan", latent_dim=100, device=None):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.writer = SummaryWriter(log_dir)
        # Fixed noise so sample grids are comparable across steps.
        self.fixed_noise = torch.randn(64, latent_dim, device=device)

    def log_step(self, step, gen, disc, loss_g, loss_d, real_batch):
        """Write all scalars (and, every 500 steps, a sample grid) for ``step``.

        Args:
            step: Global step used as the TensorBoard x-axis.
            gen: Generator, called on the fixed noise.
            disc: Discriminator/critic, called on real and generated batches.
            loss_g: Generator loss to log.
            loss_d: Discriminator loss to log.
            real_batch: Batch of real samples fed to ``disc``.
        """
        self.writer.add_scalar("loss/generator", loss_g, step)
        self.writer.add_scalar("loss/discriminator", loss_d, step)

        # Mean discriminator output on real vs generated samples
        with torch.no_grad():
            real_pred = disc(real_batch).mean().item()
            # At most as many generated samples as there are real ones.
            fake = gen(self.fixed_noise[:real_batch.size(0)])
            fake_pred = disc(fake).mean().item()
        self.writer.add_scalar("disc/real_score", real_pred, step)
        self.writer.add_scalar("disc/fake_score", fake_pred, step)

        # Generate sample grid periodically
        if step % 500 == 0:
            with torch.no_grad():
                samples = gen(self.fixed_noise)
            grid = vutils.make_grid(samples, normalize=True, nrow=8)
            self.writer.add_image("generated", grid, step)

        # Gradient norms (training health indicator)
        g_grad = self._gradient_norm(gen)
        d_grad = self._gradient_norm(disc)
        self.writer.add_scalar("gradients/generator", g_grad, step)
        self.writer.add_scalar("gradients/discriminator", d_grad, step)

    def close(self):
        """Flush pending events and close the underlying writer."""
        self.writer.close()

    @staticmethod
    def _gradient_norm(model):
        """Return the global L2 norm over all existing parameter gradients.

        Parameters whose ``grad`` is ``None`` are skipped, so a model that
        has not been back-propagated yet yields ``0.0``.
        """
        total = 0.0
        for p in model.parameters():
            if p.grad is not None:
                total += p.grad.detach().norm(2).item() ** 2
        return total ** 0.5
Distributed training
For high-resolution GANs requiring multi-GPU:
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def setup_distributed(rank, world_size):
    """Join the NCCL process group and select this process's CUDA device.

    Args:
        rank: Rank of the calling process; also used as its CUDA device index.
        world_size: Total number of processes in the group.
    """
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    # One process per GPU: make the GPU with index ``rank`` the current device.
    torch.cuda.set_device(rank)
def train_distributed(rank, world_size, dataloader):
    """Entry point for one worker of a multi-GPU training run.

    Initialises the process group, wraps both networks in
    ``DistributedDataParallel`` on device ``rank``, and saves the generator
    weights from rank 0 only. The process group is always destroyed on the
    way out, including when training raises.

    Args:
        rank: Rank of this process; also used as its device index.
        world_size: Total number of participating processes.
        dataloader: Data source for the training loop.
    """
    setup_distributed(rank, world_size)
    try:
        gen = Generator().to(rank)
        disc = Discriminator().to(rank)

        gen = DDP(gen, device_ids=[rank])
        disc = DDP(disc, device_ids=[rank])

        # Training loop same as single-GPU
        # DDP handles gradient synchronization automatically

        # Only save checkpoints from rank 0
        if rank == 0:
            # ``.module`` unwraps DDP so the checkpoint holds the plain
            # generator's state dict.
            torch.save(gen.module.state_dict(), "generator.pth")
    finally:
        # Release the process group even if training failed part-way.
        dist.destroy_process_group()
Practical training recipes
| Resolution | Architecture | Batch size | Training time (1× A100) |
|---|---|---|---|
| 64×64 | DCGAN + WGAN-GP | 128 | 4–8 hours |
| 256×256 | ProGAN | 16–32 | 2–4 days |
| 512×512 | StyleGAN2 | 8–16 | 1–2 weeks |
| 1024×1024 | StyleGAN2 | 4–8 | 2–4 weeks |
One thing to remember: Stable GAN training requires the right loss function (Wasserstein + gradient penalty for reliability), proper training ratio (critic trains more than generator), progressive resolution growth for high-quality output, and continuous monitoring of both gradient norms and discriminator scores — the training signal quality determines everything.
See Also
- Diffusion Models — Stable Diffusion and DALL-E don't 'draw' your images — they unspoil a scrambled mess until a picture emerges. Here's the surprisingly simple idea behind it.
- Python ControlNet Image Control — Find out how ControlNet lets you boss around an AI artist by giving it sketches, poses, and outlines to follow.
- Python Image Generation Pipelines — Discover how Python chains together multiple steps to turn your ideas into polished AI-generated images, like a factory assembly line for pictures.
- Python Image Inpainting — Learn how Python can magically fill in missing parts of a photo, like erasing something and having the picture fix itself.
- Python LoRA Fine Tuning — Learn how LoRA lets you teach an AI new tricks without replacing its entire brain, using tiny add-on lessons instead.