Neural Style Transfer in Python — Core Concepts

Neural style transfer (NST) decomposes images into content and style representations using a pretrained convolutional neural network, then optimizes a new image to match the content of one source and the style of another. The technique, introduced by Gatys et al. in 2015, remains foundational for understanding how CNNs represent visual information.

How content and style are captured

A pretrained CNN like VGG-19 processes images through progressively deeper layers. Each layer produces feature maps — multi-channel representations of the input. The key insight is that different layers capture different types of information:

Content representation: Deeper layers (conv4_2 in the original paper; the code below uses conv4_4) capture high-level structure: object shapes, spatial arrangements, and scene composition. Exact pixel values matter little; two photos of the same scene from slightly different angles produce similar deep features.

Style representation: Style is captured through Gram matrices computed from feature maps at multiple layers. A Gram matrix measures correlations between different feature channels — which textures tend to appear together, which colors co-occur. This captures the “feel” of an image without its spatial structure.

The Gram matrix

For a feature map F with shape (C, H, W) — C channels, each H×W pixels — the Gram matrix G is C×C:

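G[i][j] = sum over all spatial positions (h, w) of F[i][h][w] * F[j][h][w]

Implementations often divide G by a constant such as C*H*W so that layers with different feature map sizes contribute on a comparable scale.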

Two images with similar Gram matrices share similar textures and patterns, even if their spatial layouts are completely different. A Gram matrix from Van Gogh’s Starry Night captures swirling textures and blue-yellow correlations without encoding where the stars are.
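
The layout independence described above can be checked on a toy example. The sketch below uses plain Python rather than PyTorch so that each sum is visible; the helper names gram and permute_positions and the 2-channel, 4-position feature map are made up for illustration. It computes the unnormalized Gram matrix, applies the same shuffle of spatial positions to every channel, and confirms that the matrix does not change.

```python
import random

def gram(feature_map):
    """Gram matrix of a feature map given as C lists of flattened H*W activations."""
    channels = len(feature_map)
    return [
        [sum(a * b for a, b in zip(feature_map[i], feature_map[j]))
         for j in range(channels)]
        for i in range(channels)
    ]

def permute_positions(feature_map, order):
    """Apply the same spatial reordering to every channel."""
    return [[channel[k] for k in order] for channel in feature_map]

# Hypothetical 2-channel feature map over a 2x2 grid, flattened to 4 positions.
features = [
    [1, 0, 2, 1],   # channel 0
    [0, 3, 1, 1],   # channel 1
]

order = list(range(4))
random.Random(0).shuffle(order)          # a fixed, arbitrary rearrangement
shuffled = permute_positions(features, order)

g_original = gram(features)
g_shuffled = gram(shuffled)

print(g_original)                 # [[6, 3], [3, 11]]
print(g_original == g_shuffled)   # True
```

Moving activations around changes where things appear but not which channels fire together, and only the latter enters the sums.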

Implementation in PyTorch

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

# Load and preprocess images
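def load_image(path, size=512, device=None):
    # Use the GPU when one is available, otherwise fall back to the CPU
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        # ImageNet statistics expected by the pretrained VGG weights
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])
    # convert("RGB") drops alpha channels and expands grayscale to 3 channels
    image = Image.open(path).convert("RGB")
    return transform(image).unsqueeze(0).to(device)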

# Extract features from VGG-19
class StyleTransferModel(nn.Module):
    def __init__(self):
        super().__init__()
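        # The weights= argument replaces the deprecated pretrained=True (torchvision >= 0.13)
        vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features.eval()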
        self.slices = nn.ModuleList([
            vgg[:4],   # conv1_2 (style)
            vgg[4:9],  # conv2_2 (style)
            vgg[9:18], # conv3_4 (style)
            vgg[18:27],# conv4_4 (style + content)
            vgg[27:36],# conv5_4 (style)
        ])
        for param in self.parameters():
            param.requires_grad_(False)
    
    def forward(self, x):
        features = []
        # Named "block" to avoid shadowing the built-in slice
        for block in self.slices:
            x = block(x)
            features.append(x)
        return features

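def gram_matrix(features):
    # features: (batch, channels, height, width)
    b, c, h, w = features.shape
    # Flatten the spatial dimensions so each channel becomes one row
    flat = features.reshape(b, c, h * w)
    # Channel-by-channel inner products, normalized by the feature map size
    return torch.bmm(flat, flat.transpose(1, 2)) / (c * h * w)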
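
The slice boundaries in StyleTransferModel are easy to get wrong, so it can help to derive the layer names from the network layout. The sketch below assumes that torchvision's vgg19().features follows the standard VGG-19 configuration (a ReLU after every convolution, a max-pool after each block, no batch normalization); VGG19_CFG and layer_names are illustrative names, not torchvision APIs. It shows that each slice ends on the ReLU that follows the named convolution, so the extracted features are post-activation.

```python
# Channel widths per conv layer; "M" marks a max-pool (standard VGG-19 layout).
VGG19_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M",
             512, 512, 512, 512, "M", 512, 512, 512, 512, "M"]

def layer_names(cfg):
    """Return one name per module, in the order a conv/ReLU/pool stack is built."""
    names, block, conv = [], 1, 0
    for entry in cfg:
        if entry == "M":
            names.append(f"pool{block}")
            block, conv = block + 1, 0
        else:
            conv += 1
            names.append(f"conv{block}_{conv}")
            names.append(f"relu{block}_{conv}")
    return names

names = layer_names(VGG19_CFG)
slice_ends = [4, 9, 18, 27, 36]            # stop indices used in the slices above
last_layers = [names[end - 1] for end in slice_ends]

print(len(names))               # 37
print(last_layers)              # ['relu1_2', 'relu2_2', 'relu3_4', 'relu4_4', 'relu5_4']
print(names.index("conv4_2"))   # 21
```

Under the same assumption, conv4_2 (the content layer in the original paper) sits at index 21, inside the fourth slice.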

The optimization loop

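Style transfer optimizes the output image (starting from the content image or noise) to minimize a weighted sum of two losses: a content loss that compares deep feature maps with those of the content image, and a style loss that compares Gram matrices with those of the style image.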

def style_transfer(
    content_img, style_img,
    steps=300, style_weight=1e6, content_weight=1,
):
    # Run the model on whichever device the input images already live on
    model = StyleTransferModel().to(content_img.device)
    
    # Initialize output from content image
    output = content_img.clone().requires_grad_(True)
    optimizer = torch.optim.Adam([output], lr=0.01)
    
    content_features = model(content_img)
    style_features = model(style_img)
    style_grams = [gram_matrix(f) for f in style_features]
    
    for step in range(steps):
        output_features = model(output)
        
        # Content loss: MSE at conv4_4
        content_loss = nn.functional.mse_loss(
            output_features[3], content_features[3]
        )
        
        # Style loss: MSE of Gram matrices at all layers
        style_loss = 0
        for of, sg in zip(output_features, style_grams):
            style_loss += nn.functional.mse_loss(gram_matrix(of), sg)
        
        total_loss = content_weight * content_loss + style_weight * style_loss
        
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
    
    return output.detach()
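
Because load_image applies ImageNet normalization, the tensor returned by style_transfer is still in normalized space and will look wrong if saved directly. A minimal sketch of the inverse step follows; denormalize is a hypothetical helper that is not part of the listing above, and it assumes the same mean and std values used in load_image.

```python
import torch

# Same statistics as in load_image, shaped to broadcast over (B, 3, H, W).
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

def denormalize(img):
    """Undo ImageNet normalization and clamp to the displayable range [0, 1]."""
    mean = IMAGENET_MEAN.to(img.device)
    std = IMAGENET_STD.to(img.device)
    return (img * std + mean).clamp(0.0, 1.0)

# Round-trip check on a random stand-in for an image with values in [0, 1].
original = torch.rand(1, 3, 8, 8)
normalized = (original - IMAGENET_MEAN) / IMAGENET_STD
restored = denormalize(normalized)
print(torch.allclose(restored, original, atol=1e-5))  # True
```

The clamp matters in practice: nothing in the optimization constrains pixel values, so the result can drift slightly outside the valid range. After denormalizing, the batch dimension can be dropped with squeeze(0) and the tensor moved to the CPU before converting it to an image, for example with torchvision.transforms.ToPILImage().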

Fast style transfer

The original optimization takes minutes per image. Fast style transfer trains a feed-forward network that learns to apply a specific style in a single forward pass (milliseconds):

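# Architecture sketch of a fast style transfer network (untrained)
import torch
import torch.nn as nn

# ResidualBlock was not defined in the original listing; this is one
# common form, assumed here so that the example runs.
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with a skip connection; output shape equals input shape."""
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.InstanceNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.InstanceNorm2d(channels),
        )

    def forward(self, x):
        return x + self.block(x)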

class TransformNet(nn.Module):
    """Feed-forward network trained for a single style."""
    def __init__(self):
        super().__init__()
        # Encoder → Residual blocks → Decoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 9, padding=4),
            nn.InstanceNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(),
        )
        self.residual = nn.Sequential(
            *[ResidualBlock(64) for _ in range(5)]
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 3, 9, padding=4),
            nn.Tanh(),
        )
    
    def forward(self, x):
        return self.decoder(self.residual(self.encoder(x)))
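
One detail of this architecture is worth checking: the stride-2 convolution halves the resolution and the transposed convolution doubles it, so the output matches the input size only when the input height and width are even. The sketch below applies the standard output-size formulas for Conv2d and ConvTranspose2d (with dilation 1); the function names are illustrative, and the residual blocks are assumed to preserve spatial size.

```python
def conv_out(n, kernel, stride=1, padding=0):
    """Output length of a convolution along one spatial dimension."""
    return (n + 2 * padding - kernel) // stride + 1

def conv_transpose_out(n, kernel, stride=1, padding=0, output_padding=0):
    """Output length of a transposed convolution along one spatial dimension."""
    return (n - 1) * stride - 2 * padding + kernel + output_padding

def transform_net_size(n):
    """Trace one spatial dimension through the layers of TransformNet."""
    n = conv_out(n, kernel=9, padding=4)             # encoder conv 1: unchanged
    n = conv_out(n, kernel=3, stride=2, padding=1)   # encoder conv 2: halved
    # Residual blocks are assumed to preserve spatial size.
    n = conv_transpose_out(n, kernel=3, stride=2, padding=1, output_padding=1)
    n = conv_out(n, kernel=9, padding=4)             # decoder conv: unchanged
    return n

print(transform_net_size(256))  # 256
print(transform_net_size(255))  # 256 (one pixel larger than the input)
```

Inputs with odd dimensions can be padded or cropped to an even size before the forward pass if the output must align with the input pixel for pixel.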

The tradeoff: optimization-based transfer works with any style image; fast transfer requires training a separate network for each style but runs in real time.

Common misconception

Style transfer does not “paste” artistic textures onto your photo like a sticker. It restructures the image’s visual patterns at multiple scales — from fine brushstroke textures to broad color palettes — using the same feature representations the neural network uses to recognize objects. The output is a genuinely new rendering of the scene, not a filter overlay.

One thing to remember: Neural style transfer works because CNNs naturally separate content (what is depicted) from style (how it is rendered) across their layers, and Gram matrices capture style as texture correlations that can be transferred between any two images.
