Neural Style Transfer in Python — Deep Dive
Production neural style transfer requires moving beyond the original Gatys optimization to faster, more flexible architectures. This guide covers Adaptive Instance Normalization (AdaIN), multi-style networks, video consistency, perceptual loss design, and deployment patterns for real-time stylization.
Adaptive Instance Normalization (AdaIN)
AdaIN enables arbitrary style transfer in a single forward pass — no per-style training, no optimization loop. The core operation normalizes content features to match style statistics:
import torch
import torch.nn as nn
def adaptive_instance_normalization(content_feat, style_feat, eps=1e-5):
    """Give content features the per-channel statistics of the style features.

    Each channel of ``content_feat`` is standardized with its own spatial mean
    and standard deviation, then rescaled with the style's spatial mean and
    standard deviation (AdaIN, Huang & Belongie 2017).

    Args:
        content_feat: feature map, typically an ``(N, C, H, W)`` encoder
            output; statistics are computed over dims 2 and 3.
        style_feat: feature map with the same channel count; its batch
            dimension must broadcast against ``content_feat``'s (e.g. 1).
        eps: added to both standard deviations so constant feature maps do
            not divide by zero. Defaults to ``1e-5``.

    Returns:
        Tensor shaped like ``content_feat``.

    Note:
        ``Tensor.std`` uses the unbiased (n - 1) estimator by default.
    """
    spatial_dims = (2, 3)
    c_mean = content_feat.mean(dim=spatial_dims, keepdim=True)
    c_std = content_feat.std(dim=spatial_dims, keepdim=True) + eps
    s_mean = style_feat.mean(dim=spatial_dims, keepdim=True)
    s_std = style_feat.std(dim=spatial_dims, keepdim=True) + eps
    normalized = (content_feat - c_mean) / c_std
    return normalized * s_std + s_mean
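The insight: instance normalization statistics (the per-channel mean and standard deviation of the feature maps) capture style information. Replacing the content features' statistics with the style features' statistics transfers the “feel” of one image to another while the normalized content keeps its spatial structure.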
Complete AdaIN network
from torchvision.models import vgg19
class AdaINStyleTransfer(nn.Module):
    """Arbitrary-style transfer network built around AdaIN.

    A frozen VGG-19 prefix (up to relu4_1) encodes content and style,
    ``adaptive_instance_normalization`` swaps their channel statistics, and a
    trainable decoder maps the result back to image space.

    Args:
        pretrained: load ImageNet weights into the VGG encoder (default, as
            before). Pass ``False`` when the encoder weights will be restored
            from a full-model checkpoint anyway, to skip the download.
    """

    def __init__(self, pretrained: bool = True):
        super().__init__()
        # Encoder: VGG-19 feature modules 0..20, i.e. up to relu4_1 (frozen).
        # Slicing an nn.Sequential already yields an nn.Sequential with the
        # same "0".."20" keys, so checkpoints stay compatible.
        self.encoder = vgg19(pretrained=pretrained).features[:21]
        for param in self.encoder.parameters():
            param.requires_grad_(False)

        # Decoder: mirrors the encoder; three x2 nearest-neighbour upsamples.
        # Module order defines the state_dict keys -- do not reorder.
        self.decoder = nn.Sequential(
            nn.Conv2d(512, 256, 3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 128, 3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 3, 3, padding=1),
        )

    def forward(self, content, style, alpha=1.0):
        """Stylize ``content`` with ``style``.

        Args:
            content: content image batch fed to the encoder.
            style: style image batch; its batch size must broadcast against
                ``content``'s in the AdaIN statistics.
            alpha: style strength. ``1.0`` decodes the pure AdaIN features,
                ``0.0`` decodes the unchanged content features; values in
                between interpolate linearly in feature space.

        Returns:
            The decoded image tensor.
        """
        content_feat = self.encoder(content)
        style_feat = self.encoder(style)
        transferred = adaptive_instance_normalization(content_feat, style_feat)
        # Alpha blending in feature space controls style intensity.
        transferred = alpha * transferred + (1 - alpha) * content_feat
        return self.decoder(transferred)
Training the decoder
def _image_batch(batch):
    """Return the image tensor of a loader batch (bare tensor or ``(images, labels, ...)``)."""
    if isinstance(batch, (list, tuple)):
        return batch[0]
    return batch


def train_adain_decoder(
    model,
    content_loader,
    style_loader,
    epochs=20,
    *,
    lr=1e-4,
    style_weight=10.0,
    device="cuda",
):
    """Train ``model.decoder`` with the AdaIN content + style objective.

    Only the decoder's parameters are optimized; the encoder stays frozen.

    Args:
        model: network exposing ``encoder`` and ``decoder`` (e.g.
            ``AdaINStyleTransfer``), already on ``device``.
        content_loader: iterable of content batches (tensors, or sequences
            whose first element is the image tensor).
        style_loader: iterable of style batches, same conventions.
        epochs: number of passes over the paired loaders.
        lr: Adam learning rate.
        style_weight: multiplier of the style loss relative to the content loss.
        device: device the batches are moved to.
    """
    optimizer = torch.optim.Adam(model.decoder.parameters(), lr=lr)
    mse_loss = nn.MSELoss()
    for _ in range(epochs):
        # zip() stops at the shorter loader, so an epoch has
        # min(len(content_loader), len(style_loader)) steps.
        for content_batch, style_batch in zip(content_loader, style_loader):
            content = _image_batch(content_batch).to(device)
            style = _image_batch(style_batch).to(device)

            # The encoder is frozen, so t = AdaIN(f(content), f(style)) is a
            # constant: compute it once and use it both as the decoder input
            # (what model(content, style) does at alpha=1) and as the
            # content-loss target, instead of re-encoding both images.
            with torch.no_grad():
                target = adaptive_instance_normalization(
                    model.encoder(content), model.encoder(style)
                )
            output = model.decoder(target)

            # Content loss: the re-encoded output should reproduce t.
            content_loss = mse_loss(model.encoder(output), target)
            # Style loss: Gram matrix similarity at multiple layers.
            # NOTE(review): compute_multi_layer_style_loss is not defined in
            # this guide; it must be supplied elsewhere.
            style_loss = compute_multi_layer_style_loss(
                model.encoder, output, style
            )

            loss = content_loss + style_weight * style_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
Perceptual loss functions
Beyond Gram matrices
Johnson et al. proposed perceptual losses, which compare high-level features from a pretrained network rather than raw pixels, producing sharper results than per-pixel losses:
class PerceptualLoss(nn.Module):
    """Feature-reconstruction loss over several depths of a frozen VGG-19.

    The VGG-19 ``features`` stack is cut into consecutive slices at
    ``layers``: slice ``k`` runs modules ``[layers[k-1], layers[k])`` (the
    first slice starts at 0). ``forward`` feeds both images through the
    slices in sequence and sums the MSE between their activations after
    every slice.

    Args:
        layers: exclusive end indices into ``vgg19().features``. They are
            sorted, so order does not matter. ``None`` or an empty sequence
            selects the defaults ``[3, 8, 17, 26, 35]``.

    Raises:
        ValueError: if an index is outside ``0..len(features)``.

    NOTE(review): with torchvision's VGG-19 layout the default cuts end on
    the conv layers right before relu1_2 / relu2_2 / relu3_4 / relu4_4 /
    relu5_4, i.e. the compared features are pre-activation. Use ``idx + 1``
    if post-ReLU features are intended.
    """

    def __init__(self, layers=None):
        super().__init__()
        vgg = vgg19(pretrained=True).features.eval()
        modules = list(vgg.children())
        # Slices chain as [prev:idx]; unsorted indices would produce empty
        # (identity) slices and compare the wrong depths.
        self.layer_indices = sorted(layers) if layers else [3, 8, 17, 26, 35]
        for idx in self.layer_indices:
            if not 0 <= idx <= len(modules):
                raise ValueError(
                    f"layer index {idx} is outside 0..{len(modules)}"
                )
        self.slices = nn.ModuleList()
        prev = 0
        for idx in self.layer_indices:
            self.slices.append(nn.Sequential(*modules[prev:idx]))
            prev = idx
        # The loss network is fixed; only the model being trained learns.
        for param in self.parameters():
            param.requires_grad_(False)

    def forward(self, input, target):
        """Return the summed per-slice MSE between ``input`` and ``target`` features."""
        loss = 0
        x, y = input, target
        for block in self.slices:
            x = block(x)
            y = block(y)
            loss += nn.functional.mse_loss(x, y)
        return loss
class CombinedStyleLoss(nn.Module):
    """Perceptual content loss plus Gram-matrix style loss on the same VGG slices.

    Args:
        style_weight: multiplier of the style term (default ``1e5``).
    """

    def __init__(self, style_weight: float = 1e5):
        super().__init__()
        self.perceptual = PerceptualLoss()
        self.style_weight = style_weight

    def forward(self, output, content, style):
        """Return ``content_loss + style_weight * style_loss``.

        ``content_loss`` is the perceptual (feature MSE) loss between
        ``output`` and ``content``. ``style_loss`` sums, over the perceptual
        network's slices, the MSE between Gram matrices of the ``output``
        and ``style`` activations.
        """
        content_loss = self.perceptual(output, content)
        slices = self.perceptual.slices
        style_loss = 0
        # One forward pass per image through all slices, paired layer by layer.
        for out_feat, style_feat in zip(
            self._slice_features(slices, output),
            self._slice_features(slices, style),
        ):
            style_loss += nn.functional.mse_loss(
                self.gram(out_feat), self.gram(style_feat)
            )
        return content_loss + self.style_weight * style_loss

    @staticmethod
    def _slice_features(slices, x):
        """Run ``x`` through ``slices`` in sequence and collect each slice's output."""
        features = []
        for block in slices:
            x = block(x)
            features.append(x)
        return features

    @staticmethod
    def gram(x):
        """Return the ``(b, c, c)`` channel Gram matrix normalized by ``c * h * w``."""
        b, c, h, w = x.shape
        # reshape (not view) so non-contiguous activations are accepted.
        f = x.reshape(b, c, h * w)
        return torch.bmm(f, f.transpose(1, 2)) / (c * h * w)
Video style transfer
Temporal consistency
Naively applying style transfer frame-by-frame produces flickering. Temporal consistency requires additional constraints:
import cv2
class VideoStyleTransfer:
    """Stylize a video frame by frame, smoothing flicker with optical flow.

    Each frame is stylized independently. From the second frame on, the
    previous *output* frame is warped into the current frame's geometry and
    blended in, so the result is an exponential moving average along the
    estimated motion. Occlusions are not handled.

    Args:
        model: module called as ``model(content, style)`` (e.g.
            ``AdaINStyleTransfer``). ``stylize_frame`` feeds it
            ``(1, 3, H, W)`` RGB float tensors scaled to ``[0, 1]`` --
            NOTE(review): confirm this matches the model's training
            preprocessing.
        device: torch device for the model and all frame tensors.
    """

    def __init__(self, model, device="cuda"):
        self.model = model.to(device)
        self.device = device

    def compute_optical_flow(self, prev_frame, curr_frame):
        """Dense Farneback flow from ``prev_frame`` to ``curr_frame`` (BGR images).

        ``flow[y, x]`` is the ``(dx, dy)`` displacement carrying pixel
        ``(x, y)`` of ``prev_frame`` to its position in ``curr_frame``.
        """
        prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)
        curr_gray = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2GRAY)
        # flow=None, pyr_scale=0.5, levels=3, winsize=15, iterations=3,
        # poly_n=5, poly_sigma=1.2, flags=0
        return cv2.calcOpticalFlowFarneback(
            prev_gray, curr_gray, None, 0.5, 3, 15, 3, 5, 1.2, 0
        )

    def warp_image(self, image, flow):
        """Backward-warp ``image`` by a dense flow field.

        Output pixel ``(x, y)`` samples ``image`` at
        ``(x + flow[y, x, 0], y + flow[y, x, 1])`` (``cv2.remap`` semantics).
        To bring a previous frame into the current frame's geometry, pass
        the flow computed from the *current* frame to the *previous* one.
        """
        import numpy as np  # local import keeps this snippet self-contained

        h, w = flow.shape[:2]
        grid_x, grid_y = np.meshgrid(
            np.arange(w, dtype=np.float32), np.arange(h, dtype=np.float32)
        )
        map_x = grid_x + flow[:, :, 0]
        map_y = grid_y + flow[:, :, 1]
        # Replicate edge pixels for samples that leave the frame instead of
        # blending black borders into the output.
        return cv2.remap(
            image, map_x, map_y, cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_REPLICATE,
        )

    def _to_tensor(self, image):
        """Convert a BGR uint8 image to a ``(1, 3, H, W)`` RGB float tensor in ``[0, 1]``."""
        rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        tensor = torch.from_numpy(rgb).permute(2, 0, 1).float().div(255.0)
        return tensor.unsqueeze(0).to(self.device)

    def _prepare_style(self, style_image):
        """Return ``style_image`` as a tensor on ``self.device``.

        Accepts an already-preprocessed tensor (only moved to the device) or
        a BGR uint8 image (converted like a frame).
        """
        if torch.is_tensor(style_image):
            return style_image.to(self.device)
        return self._to_tensor(style_image)

    def stylize_frame(self, frame, style_image):
        """Stylize one BGR uint8 frame; return a BGR uint8 frame of the same size."""
        content = self._to_tensor(frame)
        style = self._prepare_style(style_image)
        with torch.no_grad():
            output = self.model(content, style)
        # (1, 3, H', W') float -> (H', W', 3) uint8, contiguous for OpenCV.
        rgb = (
            output[0].clamp(0.0, 1.0).mul(255.0).round().to(torch.uint8)
            .permute(1, 2, 0).contiguous().cpu().numpy()
        )
        bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
        # The model output may differ in size from the input (e.g. after
        # pooling/upsampling); addWeighted and the VideoWriter need exactly
        # the frame size.
        h, w = frame.shape[:2]
        if bgr.shape[:2] != (h, w):
            bgr = cv2.resize(bgr, (w, h), interpolation=cv2.INTER_LINEAR)
        return bgr

    def stylize_video(
        self, input_path: str, style_image,
        output_path: str, temporal_weight: float = 0.7
    ):
        """Stylize ``input_path`` into an mp4v video at ``output_path``.

        Args:
            input_path: video readable by ``cv2.VideoCapture``.
            style_image: BGR uint8 image or preprocessed style tensor.
            output_path: destination file.
            temporal_weight: weight of the warped previous output in the
                blend; ``0`` disables temporal smoothing.

        Raises:
            OSError: if the input cannot be opened or the writer cannot be
                created.
        """
        cap = cv2.VideoCapture(input_path)
        if not cap.isOpened():
            raise OSError(f"could not open video: {input_path}")
        fps = cap.get(cv2.CAP_PROP_FPS)
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        writer = cv2.VideoWriter(
            output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h)
        )
        try:
            if not writer.isOpened():
                raise OSError(f"could not open video writer: {output_path}")
            # Convert the style once rather than on every frame.
            style = self._prepare_style(style_image)
            prev_frame = None
            prev_stylized = None
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                stylized = self.stylize_frame(frame, style)
                if prev_frame is not None:
                    # remap pulls pixels, so it needs the current -> previous
                    # flow to fetch, for each current pixel, where it was.
                    flow = self.compute_optical_flow(frame, prev_frame)
                    warped_prev = self.warp_image(prev_stylized, flow)
                    # Blend current stylization with the warped previous output.
                    stylized = cv2.addWeighted(
                        stylized, 1 - temporal_weight,
                        warped_prev, temporal_weight,
                        0,
                    )
                writer.write(stylized)
                prev_frame = frame
                prev_stylized = stylized
        finally:
            cap.release()
            writer.release()
Multi-style interpolation
Blend multiple styles with controllable weights:
class MultiStyleTransfer:
    """Stylize content with a weighted blend of several styles.

    Wraps a model exposing ``encoder`` and ``decoder`` (e.g.
    ``AdaINStyleTransfer``) and mixes the styles' per-channel feature
    statistics before decoding.
    """

    def __init__(self, model):
        self.model = model

    def interpolate_styles(
        self, content, styles: list, weights: list
    ):
        """Decode ``content`` using a weighted mix of style statistics.

        The per-channel means and standard deviations of each style's
        encoder features are combined with ``weights``. The content features
        are standardized and rescaled with the blended statistics. Because
        the standardized content is shared, this equals decoding the weighted
        sum of the per-style AdaIN features.

        Args:
            content: content image tensor accepted by ``model.encoder``.
            styles: style image tensors, one per weight.
            weights: blend weights; must sum to 1 (within 1e-6).

        Returns:
            The decoder output for the blended features.

        Raises:
            ValueError: if ``styles`` and ``weights`` differ in length or the
                weights do not sum to 1.
        """
        # Explicit checks: assert statements are stripped under ``python -O``.
        if len(styles) != len(weights):
            raise ValueError(
                f"got {len(styles)} styles but {len(weights)} weights"
            )
        if abs(sum(weights) - 1.0) >= 1e-6:
            raise ValueError(f"weights must sum to 1, got {sum(weights)}")

        content_feat = self.model.encoder(content)
        # Compute weighted style statistics.
        combined_mean = 0
        combined_std = 0
        for style, weight in zip(styles, weights):
            style_feat = self.model.encoder(style)
            s_mean = style_feat.mean(dim=[2, 3], keepdim=True)
            s_std = style_feat.std(dim=[2, 3], keepdim=True) + 1e-5
            combined_mean += weight * s_mean
            combined_std += weight * s_std

        # Normalize content and apply the combined style statistics.
        c_mean = content_feat.mean(dim=[2, 3], keepdim=True)
        c_std = content_feat.std(dim=[2, 3], keepdim=True) + 1e-5
        normalized = (content_feat - c_mean) / c_std
        transferred = normalized * combined_std + combined_mean
        return self.model.decoder(transferred)
Production serving
Real-time API with model caching
from fastapi import FastAPI, UploadFile, File
from io import BytesIO
import base64
app = FastAPI()

# Load the network once at import time so every request reuses it.
model = AdaINStyleTransfer()
# Accept either a full-model state dict or a decoder-only one: training
# (train_adain_decoder) optimizes only model.decoder, and the file name
# suggests only those weights were saved. map_location="cpu" lets a
# checkpoint saved on any device load before the move to CUDA below.
_state_dict = torch.load("adain_decoder.pth", map_location="cpu")
if any(key.startswith("decoder.") for key in _state_dict):
    model.load_state_dict(_state_dict)
else:
    model.decoder.load_state_dict(_state_dict)
del _state_dict
model = model.to("cuda").eval()

# Pre-encode popular styles so requests only run the content encoder + decoder.
STYLE_CACHE = {}
POPULAR_STYLES = {
    "starry_night": "styles/starry_night.jpg",
    "great_wave": "styles/great_wave.jpg",
    "mosaic": "styles/mosaic.jpg",
    "candy": "styles/candy.jpg",
}
with torch.no_grad():
    for name, path in POPULAR_STYLES.items():
        style_img = load_and_preprocess(path)
        STYLE_CACHE[name] = model.encoder(style_img.to("cuda"))
@app.post("/stylize")
async def stylize(
    content: UploadFile = File(...),
    style_name: str = "starry_night",
    alpha: float = 1.0,
):
    """Stylize an uploaded image with one of the pre-encoded styles.

    Args:
        content: uploaded image file, decoded by ``load_upload``.
        style_name: key of ``STYLE_CACHE``.
        alpha: style strength in ``[0, 1]``.

    Returns:
        ``{"image": <base64 string from tensor_to_base64>}``.

    Raises:
        HTTPException: 400 for an unknown style or an out-of-range alpha.
    """
    from fastapi import HTTPException

    style_feat = STYLE_CACHE.get(style_name)
    if style_feat is None:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown style '{style_name}'. Available: {sorted(STYLE_CACHE)}",
        )
    # Also rejects NaN, since every comparison with NaN is False.
    if not 0.0 <= alpha <= 1.0:
        raise HTTPException(status_code=400, detail="alpha must be between 0 and 1")

    content_img = load_upload(await content.read())
    # NOTE(review): the GPU work below runs on the event-loop thread and
    # blocks other requests; consider offloading it to a thread pool.
    with torch.no_grad():  # inference only: don't build an autograd graph
        # .to() is a no-op if load_upload already returns a CUDA tensor.
        content_feat = model.encoder(content_img.to("cuda"))
        transferred = adaptive_instance_normalization(content_feat, style_feat)
        transferred = alpha * transferred + (1 - alpha) * content_feat
        output = model.decoder(transferred)
    return {"image": tensor_to_base64(output)}
Performance benchmarks
| Method | Resolution | Time (GPU) | Quality |
|---|---|---|---|
| Gatys optimization (300 steps) | 512×512 | 45s | Best |
| Fast style transfer (per-style) | 512×512 | 15ms | Good |
| AdaIN (arbitrary style) | 512×512 | 25ms | Good |
| AdaIN (arbitrary style) | 1024×1024 | 85ms | Good |
For real-time applications (video, interactive), AdaIN at 512×512 achieves 40+ FPS on modern GPUs.
Style-content tradeoff tuning
The balance between preserving content and applying style is the central creative decision:
def generate_tradeoff_gallery(
    model, content, style,
    alphas=(0.0, 0.25, 0.5, 0.75, 1.0),
):
    """Generate a gallery showing style intensity progression.

    Args:
        model: callable ``model(content, style, alpha=...)``.
        content: content input passed through unchanged to ``model``.
        style: style input passed through unchanged to ``model``.
        alphas: style strengths to render, in order. The default is an
            immutable tuple to avoid a shared mutable default argument.

    Returns:
        List of ``(alpha, output)`` pairs in the order of ``alphas``.
    """
    # Inference only: one no_grad context for the whole sweep.
    with torch.no_grad():
        return [(alpha, model(content, style, alpha=alpha)) for alpha in alphas]
Low alpha preserves photographic fidelity; high alpha produces strong artistic transformation. Most users prefer alpha between 0.6 and 0.8 for a visible style effect that does not obscure content.
One thing to remember: AdaIN made arbitrary style transfer practical by reducing it to a statistics transfer operation — matching the per-channel mean and standard deviation of the content features to those of the style features — enabling real-time stylization of any content with any style in a single forward pass.
See Also
- Diffusion Models Stable Diffusion and DALL-E don't 'draw' your images — they clean up a scrambled mess step by step until a picture emerges. Here's the surprisingly simple idea behind it.
- Python Controlnet Image Control Find out how ControlNet lets you boss around an AI artist by giving it sketches, poses, and outlines to follow.
- Python Gan Training Patterns Learn how two neural networks compete like an art forger and a detective to create incredibly realistic fake images.
- Python Image Generation Pipelines Discover how Python chains together multiple steps to turn your ideas into polished AI-generated images, like a factory assembly line for pictures.
- Python Image Inpainting Learn how Python can magically fill in missing parts of a photo, like erasing something and having the picture fix itself.