Stable Diffusion API in Python — Deep Dive
Running Stable Diffusion in production means going beyond single-image generation scripts. This guide covers memory optimization, batching strategies, model management, and building robust APIs around diffusion models.
Memory optimization strategies
Half-precision and attention slicing
A full Stable Diffusion v1.5 checkpoint uses roughly 4 GB in float32. Switching to float16 halves that immediately. For GPUs with limited VRAM, attention slicing processes the attention computation in chunks:
from diffusers import StableDiffusionPipeline
import torch

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    safety_checker=None,  # saves ~1 GB VRAM
)
pipe = pipe.to("cuda")
pipe.enable_attention_slicing()  # reduces peak VRAM by ~30%
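The roughly 4 GB figure for float32 can be checked with simple arithmetic on parameter counts. The sketch below assumes approximate, commonly quoted parameter counts for SD 1.5 (about 860M for the U-Net, 123M for the text encoder, 84M for the VAE). These counts are assumptions rather than figures from this guide, and the estimate covers weights only, not activations.

```python
# Approximate parameter counts for SD 1.5 (assumed values, weights only).
APPROX_PARAMS = {
    "unet": 860_000_000,
    "text_encoder": 123_000_000,
    "vae": 84_000_000,
}
BYTES_PER_PARAM = {"float32": 4, "float16": 2}


def weight_memory_gb(dtype: str) -> float:
    """Estimate weight memory in GB (10^9 bytes) for the given dtype."""
    total_params = sum(APPROX_PARAMS.values())
    return total_params * BYTES_PER_PARAM[dtype] / 1e9


if __name__ == "__main__":
    for dtype in BYTES_PER_PARAM:
        print(f"{dtype}: {weight_memory_gb(dtype):.2f} GB")
```

Peak VRAM during generation is higher than these numbers because activations and attention buffers come on top of the weights, which is where attention slicing helps.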
xFormers memory-efficient attention
The xformers library provides fused attention kernels that reduce memory usage and increase speed by 20–40%:
pipe.enable_xformers_memory_efficient_attention()
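This single line can decide whether a model fits on an 8 GB GPU. On PyTorch 2.0 and later, diffusers already uses PyTorch's built-in scaled dot-product attention by default, which gives similar benefits without the extra dependency.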
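Because xFormers is an optional dependency, a deployment may want to degrade gracefully when it is missing. The helper below is a minimal sketch of that idea: enable_efficient_attention is a hypothetical name, and falling back to attention slicing is one reasonable choice, not the only one.

```python
def enable_efficient_attention(pipe) -> str:
    """Try xFormers attention; fall back to attention slicing.

    Returns the name of the strategy that was enabled.
    """
    try:
        pipe.enable_xformers_memory_efficient_attention()
        return "xformers"
    except Exception:
        # xFormers is missing or incompatible with this GPU/PyTorch build.
        pipe.enable_attention_slicing()
        return "attention_slicing"
```

Logging the returned value at startup makes it easy to see which path a given machine actually took.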
Sequential CPU offloading
For extremely constrained environments, offload model components to CPU when not in use:
pipe.enable_sequential_cpu_offload()
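Weights move to the GPU one submodule at a time, only for that submodule's forward pass. Latency increases by roughly 3x, but VRAM usage drops to under 3 GB. Call this instead of pipe.to("cuda"): moving the whole pipeline to the GPU first largely cancels the memory savings.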
Model CPU offloading (balanced approach)
A middle ground keeps the full model on CPU but moves entire components to GPU as needed:
pipe.enable_model_cpu_offload()
This adds much less latency than sequential offloading, with VRAM usage around 4 GB for SD 1.5.
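The choice between the two offloading modes can be automated from the amount of free VRAM. The following is a rough sketch: both function names are hypothetical, and the 6 GB and 4 GB thresholds are illustrative values loosely based on the SD 1.5 figures quoted above, so they should be measured on real hardware before being relied on.

```python
def choose_offload_strategy(free_vram_gb: float) -> str:
    """Pick a memory strategy from available VRAM (illustrative thresholds)."""
    if free_vram_gb >= 6:
        return "none"  # keep the whole pipeline on the GPU
    if free_vram_gb >= 4:
        return "model_cpu_offload"
    return "sequential_cpu_offload"


def apply_offload_strategy(pipe, strategy: str) -> None:
    """Call the matching diffusers method on an already-loaded pipeline."""
    if strategy == "none":
        pipe.to("cuda")
    elif strategy == "model_cpu_offload":
        pipe.enable_model_cpu_offload()
    elif strategy == "sequential_cpu_offload":
        pipe.enable_sequential_cpu_offload()
    else:
        raise ValueError(f"unknown strategy: {strategy}")
```

On a CUDA machine, torch.cuda.mem_get_info() returns free and total bytes, which could feed the first function.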
Batch generation
Generating several images in one call batches every denoising step, so the GPU is used far more fully than when images are generated one at a time:
def generate_batch(pipe, prompt, count=4, seed=42):
    generators = [
        torch.Generator("cuda").manual_seed(seed + i)
        for i in range(count)
    ]
    images = pipe(
        [prompt] * count,
        generator=generators,
        num_inference_steps=25,
    ).images
    return images
On an A100, generating 4 images takes roughly 1.5x the time of generating 1 — significant throughput gains.
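To make the throughput claim concrete, the gain is the batch size divided by the relative latency. The helper below is a small sketch with a hypothetical function name; the 1.5x figure is the one quoted above, not a new measurement.

```python
def batch_throughput_gain(batch_size: int, relative_time: float) -> float:
    """Images-per-second gain of a batch versus single-image generation.

    relative_time is the batch latency divided by the single-image latency.
    """
    if batch_size < 1 or relative_time <= 0:
        raise ValueError("batch_size must be >= 1 and relative_time > 0")
    return batch_size / relative_time


gain = batch_throughput_gain(batch_size=4, relative_time=1.5)
print(f"{gain:.2f}x throughput")  # prints "2.67x throughput"
```

The same calculation with measured latencies shows where larger batches stop paying off, typically once VRAM or compute is saturated.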
SDXL and model variants
Stable Diffusion XL uses a two-stage pipeline with a base model and optional refiner:
import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline

base = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
)
base = base.to("cuda")

refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
)
refiner = refiner.to("cuda")

# Generate with base, then refine.
# The base stops at 80% of the schedule and hands latents to the refiner.
base_image = base(
    "professional photograph of alpine meadow",
    num_inference_steps=40,
    denoising_end=0.8,
    output_type="latent",
).images

refined = refiner(
    "professional photograph of alpine meadow",
    image=base_image,
    num_inference_steps=40,
    denoising_start=0.8,
).images[0]
The base handles composition and structure; the refiner adds fine details and textures.
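The handoff fraction determines how the step budget is divided between the two models. The sketch below approximates that split with plain arithmetic; split_steps is a hypothetical helper, and the exact step counts diffusers runs can differ slightly depending on the scheduler.

```python
def split_steps(num_inference_steps: int, handoff: float) -> tuple[int, int]:
    """Approximate how many steps the base and refiner each run.

    handoff is the fraction passed as denoising_end / denoising_start.
    """
    if not 0.0 < handoff < 1.0:
        raise ValueError("handoff must be between 0 and 1")
    base_steps = round(num_inference_steps * handoff)
    return base_steps, num_inference_steps - base_steps


base_steps, refiner_steps = split_steps(40, 0.8)
print(base_steps, refiner_steps)  # prints "32 8"
```

With the settings used above, the base does most of the work and the refiner only touches the last fifth of the schedule.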
Building an API service
A production API wraps the pipeline with request queuing, health checks, and error handling:
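from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
import asyncio
import uuid

app = FastAPI()
request_queue: asyncio.Queue = asyncio.Queue(maxsize=100)
results: dict[str, str] = {}  # job_id -> output path, or "error: ..." on failure
pending: set[str] = set()     # job ids that have been accepted

class GenerationRequest(BaseModel):
    prompt: str
    negative_prompt: str = ""
    steps: int = Field(default=25, ge=10, le=50)
    guidance_scale: float = Field(default=7.5, ge=1, le=20)
    # Stable Diffusion pipelines require dimensions divisible by 8
    width: int = Field(default=512, ge=256, le=1024, multiple_of=8)
    height: int = Field(default=512, ge=256, le=1024, multiple_of=8)

@app.get("/health")
async def health():
    return {"status": "ok", "queue_depth": request_queue.qsize()}

@app.post("/generate")
async def generate(req: GenerationRequest):
    job_id = str(uuid.uuid4())
    try:
        # Reject immediately when the queue is full instead of blocking the request
        request_queue.put_nowait((job_id, req))
    except asyncio.QueueFull:
        raise HTTPException(status_code=503, detail="Queue is full, retry later")
    pending.add(job_id)
    return {"job_id": job_id, "status": "queued"}

@app.get("/result/{job_id}")
async def get_result(job_id: str):
    if job_id in results:
        result = results[job_id]
        if result.startswith("error: "):
            return {"status": "failed", "detail": result}
        return {"status": "complete", "image_url": result}
    if job_id in pending:
        return {"status": "processing"}
    raise HTTPException(status_code=404, detail="Unknown job_id")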
GPU worker loop
A dedicated worker processes the queue sequentially, avoiding GPU contention:
async def gpu_worker():
    pipe = load_pipeline()  # initialized once
    while True:
        job_id, req = await request_queue.get()
        try:
            # Run the blocking pipeline call in a thread so the event loop
            # can keep answering HTTP requests during generation.
            output = await asyncio.to_thread(
                pipe,
                req.prompt,
                negative_prompt=req.negative_prompt,
                num_inference_steps=req.steps,
                guidance_scale=req.guidance_scale,
                width=req.width,
                height=req.height,
            )
            path = f"/outputs/{job_id}.png"
            output.images[0].save(path)
            results[job_id] = path
        except Exception as e:
            results[job_id] = f"error: {str(e)}"
        finally:
            request_queue.task_done()
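The queue-and-worker pattern can be tried without a GPU. The sketch below uses only the standard library: fake_generate is a hypothetical stand-in for the blocking pipeline call, and run_demo starts one worker task, feeds it jobs, and waits for the queue to drain. It illustrates the pattern under these assumptions and is not the service's actual startup code.

```python
import asyncio
import time


def fake_generate(prompt: str) -> str:
    """Stand-in for the blocking pipeline call (hypothetical)."""
    time.sleep(0.01)
    return f"image for {prompt!r}"


async def worker(queue: asyncio.Queue, results: dict) -> None:
    """Process jobs one at a time so only one generation runs at once."""
    while True:
        job_id, prompt = await queue.get()
        try:
            # Off-load the blocking call so the event loop stays responsive.
            results[job_id] = await asyncio.to_thread(fake_generate, prompt)
        except Exception as exc:
            results[job_id] = f"error: {exc}"
        finally:
            queue.task_done()


async def run_demo(prompts: list[str]) -> dict:
    queue: asyncio.Queue = asyncio.Queue(maxsize=100)
    results: dict = {}
    task = asyncio.create_task(worker(queue, results))
    for i, prompt in enumerate(prompts):
        await queue.put((f"job-{i}", prompt))
    await queue.join()  # wait until every job has been marked done
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass
    return results


if __name__ == "__main__":
    print(asyncio.run(run_demo(["a cat", "a dog"])))
```

In a FastAPI service, the equivalent of asyncio.create_task(worker(...)) would run in a startup or lifespan hook, with a reference to the task kept so it is not garbage collected.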
Model management in multi-model environments
When serving multiple checkpoints (realistic, anime, architectural), swap models efficiently:
from diffusers import StableDiffusionPipeline
import gc
import torch

class ModelManager:
    def __init__(self):
        self.current_model = None
        self.pipe = None

    def load(self, model_id: str):
        if model_id == self.current_model:
            return self.pipe
        # Clear previous model. Reset the attributes rather than deleting them,
        # so the manager stays usable if loading the new model fails.
        if self.pipe is not None:
            self.pipe = None
            self.current_model = None
            gc.collect()
            torch.cuda.empty_cache()
        self.pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
        ).to("cuda")
        self.pipe.enable_xformers_memory_efficient_attention()
        self.current_model = model_id
        return self.pipe
Compilation with torch.compile
PyTorch 2.0+ can compile the U-Net for 20–40% speedup on supported hardware:
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
# First call triggers compilation (slow), subsequent calls are faster
image = pipe("warmup prompt").images[0]
Compilation adds 60–120 seconds on first run but pays for itself in production where the same pipeline handles thousands of requests.
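Whether compilation pays off depends on how many requests a process serves before it restarts. The sketch below estimates the break-even point. The function name is hypothetical, the 2.0 second baseline latency is an assumed example value, and the speedup is interpreted here as a fractional reduction in per-request latency.

```python
import math


def requests_to_break_even(compile_seconds: float,
                           base_latency_seconds: float,
                           latency_reduction: float) -> int:
    """Number of requests after which compilation has paid for itself."""
    saved_per_request = base_latency_seconds * latency_reduction
    if saved_per_request <= 0:
        raise ValueError("compilation must save time per request")
    return math.ceil(compile_seconds / saved_per_request)


# 100 s of compilation, 2.0 s per image, 25% faster after compiling
print(requests_to_break_even(100, 2.0, 0.25))  # prints 200
```

For short-lived workers or autoscaled instances that restart often, this number is worth comparing against the expected requests per process lifetime.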
Safety and content filtering
The default safety_checker catches some NSFW content but is far from comprehensive. Production systems typically add multiple layers:
# Custom post-generation filter
from transformers import pipeline

nsfw_classifier = pipeline(
    "image-classification",
    model="Falconsai/nsfw_image_detection"
)

def is_safe(image) -> bool:
    result = nsfw_classifier(image)
    return all(r["score"] < 0.8 for r in result if r["label"] == "nsfw")
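The "multiple layers" idea can be sketched as a prompt check before generation combined with an image check afterwards. In the example below, passes_filters and blocked_terms are hypothetical names, the classifier is passed in as a callable returning the same list-of-dicts shape as the image-classification pipeline above, and a simple substring blocklist stands in for whatever prompt moderation a real system would use.

```python
from typing import Callable, Iterable


def passes_filters(
    prompt: str,
    image: object,
    classifier: Callable[[object], list[dict]],
    blocked_terms: Iterable[str],
    threshold: float = 0.8,
) -> bool:
    """Layer 1: reject blocked prompts. Layer 2: classify the image."""
    lowered = prompt.lower()
    if any(term.lower() in lowered for term in blocked_terms):
        return False
    scores = classifier(image)
    return all(r["score"] < threshold for r in scores if r["label"] == "nsfw")
```

In practice the prompt check would run before generation to save GPU time, and the image check after it.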
Tradeoffs to consider
| Decision | Pro | Con |
|---|---|---|
| Float16 vs Float32 | Half the VRAM, faster | Rare precision artifacts |
| Safety checker on | Content filtering | +1 GB VRAM, 10% slower |
| xFormers attention | Faster, less VRAM | Extra dependency, build complexity |
| torch.compile | 20–40% faster after warmup | 2 min compilation, no dynamic shapes |
| CPU offloading | Fits on small GPUs | 2–3x slower generation |
One thing to remember: Production Stable Diffusion in Python is about managing GPU memory, batching intelligently, and wrapping the pipeline in infrastructure that handles queuing, model swapping, and safety filtering — the generation call itself is the easy part.