Image Segmentation in Python — Deep Dive
Production image segmentation is where model accuracy meets latency budgets, memory constraints, and messy real-world data. This guide covers the architecture choices, training strategies, and deployment patterns used by teams shipping segmentation models today.
Architecture landscape
U-Net and variants
U-Net introduced the symmetric encoder-decoder with skip connections. The encoder compresses spatial dimensions while increasing channel depth; the decoder mirrors the path back up, concatenating feature maps from corresponding encoder layers. Skip connections preserve fine-grained spatial detail that the bottleneck would otherwise lose.
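import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="resnet34",      # ResNet-34 backbone as the encoder
    encoder_weights="imagenet",   # start from ImageNet-pretrained weights
    in_channels=3,                # RGB input
    classes=5,                    # number of output classes
)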
Modern variants include U-Net++ (nested skip connections), Attention U-Net (gating mechanisms on skips), and nnU-Net (self-configuring pipeline that auto-tunes preprocessing, architecture, and post-processing for medical datasets).
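To make the skip-connection bookkeeping concrete, here is a minimal NumPy sketch of a single encoder/decoder level. It is not a trainable U-Net: max pooling and nearest-neighbour upsampling stand in for the learned convolutions, and the function names are made up for illustration.

```python
import numpy as np

def max_pool_2x(x):
    """2x2 max pooling on a (C, H, W) array; H and W must be even."""
    c, h, w = x.shape
    return x.reshape(c, h // 2, 2, w // 2, 2).max(axis=(2, 4))

def upsample_2x(x):
    """Nearest-neighbour 2x upsampling on a (C, H, W) array."""
    return x.repeat(2, axis=1).repeat(2, axis=2)

def unet_level(encoder_features):
    """One encoder/decoder level: pool, upsample, then concatenate the skip."""
    bottleneck = max_pool_2x(encoder_features)   # (C, H/2, W/2)
    decoded = upsample_2x(bottleneck)            # (C, H, W), fine detail lost
    # Skip connection: stack the full-resolution encoder map on the channel axis.
    return np.concatenate([encoder_features, decoded], axis=0)

features = np.random.default_rng(0).random((16, 64, 64))
merged = unet_level(features)
print(merged.shape)  # (32, 64, 64)
```

The doubled channel count after concatenation is why U-Net decoder convolutions take twice as many input channels as the upsampled map alone would suggest.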
DeepLabV3+
DeepLab uses atrous (dilated) convolutions to capture multi-scale context without reducing resolution. DeepLabV3+ adds a lightweight decoder on top of the atrous spatial pyramid pooling (ASPP) module, recovering sharp object boundaries.
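import torch
from torchvision.models.segmentation import deeplabv3_resnet101
# torchvision ships DeepLabV3 (without the "+" decoder); weights="DEFAULT" replaces the deprecated pretrained=True.
model = deeplabv3_resnet101(weights="DEFAULT")
num_classes = 5  # example value; set this to your dataset's class count
model.classifier[4] = torch.nn.Conv2d(256, num_classes, kernel_size=1)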
DeepLabV3+ with a ResNet-101 backbone reaches roughly 80% mIoU on PASCAL VOC and runs at about 8 FPS on a consumer GPU; the exact figures depend on input resolution and output stride.
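The effect of the dilation rate is easiest to see in one dimension. The sketch below is a plain NumPy illustration, not DeepLab's implementation: it applies the same 3-tap kernel at rates 1, 2, and 4 and shows that the output keeps the input length while the span of input positions seen by each output grows. The helper names are hypothetical.

```python
import numpy as np

def dilated_conv1d(signal, kernel, dilation=1):
    """'Same'-padded 1D cross-correlation with `dilation` steps between taps."""
    span = dilation * (len(kernel) - 1)      # distance covered by the kernel taps
    pad_left = span // 2
    padded = np.pad(signal, (pad_left, span - pad_left))
    out = np.zeros(len(signal), dtype=float)
    for i, weight in enumerate(kernel):
        offset = i * dilation
        out += weight * padded[offset:offset + len(signal)]
    return out

def receptive_field(kernel_size, dilation):
    """Number of input positions spanned by one output value."""
    return dilation * (kernel_size - 1) + 1

impulse = np.zeros(15)
impulse[7] = 1.0
kernel = np.ones(3)

for rate in (1, 2, 4):
    response = dilated_conv1d(impulse, kernel, dilation=rate)
    # Same output length at every rate; only the spacing of the taps changes.
    print(rate, receptive_field(3, rate), len(response), np.flatnonzero(response).tolist())
```

ASPP runs several such rates in parallel on the same feature map and concatenates the results, which is how it gathers context at multiple scales without extra downsampling.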
Segment Anything Model (SAM)
SAM from Meta AI uses a Vision Transformer (ViT) image encoder, a prompt encoder, and a lightweight mask decoder. The released code accepts points, boxes, and masks as prompts; text prompting was explored in the paper but is not part of the public release. SAM generalizes to novel objects without retraining, which makes it well suited to interactive annotation and zero-shot segmentation.
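import numpy as np
from segment_anything import sam_model_registry, SamPredictor

sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h.pth")
predictor = SamPredictor(sam)

# `image` must be an (H, W, 3) uint8 RGB array, e.g. loaded with PIL or OpenCV.
predictor.set_image(image)

# One click at (x=500, y=375); label 1 marks foreground, 0 marks background.
masks, scores, logits = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    multimask_output=True,  # return several candidate masks with quality scores
)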
SAM’s ViT-H encoder processes a 1024×1024 image in ~150ms on an A100. For edge deployment, MobileSAM distills the encoder into a compact TinyViT (reported at under 10ms on a GPU), and FastSAM replaces the ViT with a CNN-based YOLOv8-seg model.
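With `multimask_output=True`, `predict` returns several candidate masks together with a predicted quality score for each. A common follow-up is to keep the highest-scoring one. The sketch below uses synthetic arrays in place of real predictor output, and `pick_best_mask` is a hypothetical helper, not part of the `segment_anything` package.

```python
import numpy as np

def pick_best_mask(masks, scores):
    """Return the highest-scoring mask and its score.

    masks:  (N, H, W) boolean array, shaped like SamPredictor.predict output
    scores: (N,) predicted mask-quality scores
    """
    best = int(np.argmax(scores))
    return masks[best], float(scores[best])

# Synthetic stand-ins for predictor output: three nested candidate masks.
masks = np.zeros((3, 4, 4), dtype=bool)
masks[0, :1, :1] = True   # small part
masks[1, :2, :2] = True   # sub-object
masks[2, :, :] = True     # whole object
scores = np.array([0.71, 0.93, 0.88])

best_mask, best_score = pick_best_mask(masks, scores)
print(int(best_mask.sum()), best_score)  # 4 0.93
```

For interactive annotation it can be better to show all candidates and let the annotator choose, since a single ambiguous click may legitimately refer to a part or to the whole object.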
Training pipeline
Data preparation
Segmentation labels are typically pixel masks stored as single-channel PNGs in which each pixel value encodes a class ID. Common annotation formats include COCO (polygons or run-length encodings in JSON), Pascal VOC (palette-indexed PNG), and custom binary masks.
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class SegDataset(Dataset):
    """Pairs each RGB image with its single-channel class-ID mask."""

    def __init__(self, image_paths, mask_paths, transform=None):
        self.images = image_paths
        self.masks = mask_paths
        self.transform = transform

    def __getitem__(self, idx):
        image = np.array(Image.open(self.images[idx]).convert("RGB"))
        mask = np.array(Image.open(self.masks[idx]))
        if self.transform:
            # Albumentations-style call: image and mask get the same spatial transform.
            augmented = self.transform(image=image, mask=mask)
            image, mask = augmented["image"], augmented["mask"]
        # torch.as_tensor accepts both NumPy arrays (no transform) and tensors (ToTensorV2).
        return image, torch.as_tensor(mask).long()

    def __len__(self):
        return len(self.images)
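Palette-mode PNGs such as Pascal VOC's already yield class indices when read with PIL, but masks exported as ordinary RGB images need a colour-to-ID conversion first. The following is a minimal NumPy sketch of that step; the palette and the `ignore_index` value of 255 are assumptions for illustration, not part of any dataset specification used above.

```python
import numpy as np

# Hypothetical palette: RGB colour -> class ID. Replace with your dataset's mapping.
PALETTE = {
    (0, 0, 0): 0,      # background
    (128, 0, 0): 1,
    (0, 128, 0): 2,
}

def rgb_mask_to_class_ids(rgb_mask, palette=PALETTE, ignore_index=255):
    """Convert an (H, W, 3) uint8 colour mask into an (H, W) class-ID mask.

    Pixels whose colour is not in the palette are set to `ignore_index`.
    """
    class_ids = np.full(rgb_mask.shape[:2], ignore_index, dtype=np.uint8)
    for colour, class_id in palette.items():
        matches = np.all(rgb_mask == np.array(colour, dtype=rgb_mask.dtype), axis=-1)
        class_ids[matches] = class_id
    return class_ids

rgb = np.zeros((2, 3, 3), dtype=np.uint8)   # 2 rows, 3 columns, RGB
rgb[0, 1] = (128, 0, 0)
rgb[1, 2] = (0, 128, 0)
rgb[1, 0] = (7, 7, 7)                       # colour missing from the palette
ids = rgb_mask_to_class_ids(rgb)
print(ids.tolist())  # [[0, 1, 0], [255, 0, 2]]
```

Mapping unknown colours to an ignore value rather than to background makes annotation mistakes visible instead of silently training on them.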
Loss functions
Cross-entropy loss is the baseline for multi-class segmentation. It treats each pixel as an independent classification problem.
Dice loss optimizes a differentiable (soft) version of the Dice coefficient and handles class imbalance better than cross-entropy. A combined loss (CE + Dice) often outperforms either alone.
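import torch
import torch.nn.functional as F

def dice_loss(pred, target, smooth=1.0):
    """Soft Dice loss. pred: (N, C, H, W) logits; target: (N, H, W) int64 class IDs."""
    pred = torch.softmax(pred, dim=1)
    target_onehot = F.one_hot(target, pred.shape[1])            # (N, H, W, C)
    target_onehot = target_onehot.permute(0, 3, 1, 2).float()   # (N, C, H, W)
    intersection = (pred * target_onehot).sum(dim=(2, 3))
    # Sum of both areas (the Dice denominator), not a set union.
    cardinality = pred.sum(dim=(2, 3)) + target_onehot.sum(dim=(2, 3))
    dice = (2.0 * intersection + smooth) / (cardinality + smooth)
    return 1.0 - dice.mean()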
Focal loss down-weights easy pixels to focus training on hard boundaries. Useful when most of the image is background.
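The down-weighting can be checked numerically. This is a NumPy sketch of the basic focal term, -(1 - p_t)^γ · log(p_t), without the optional per-class α weighting; it is meant to illustrate the idea rather than replace a framework implementation.

```python
import numpy as np

def focal_loss(probs, target, gamma=2.0, eps=1e-7):
    """Mean multi-class focal loss.

    probs:  (N, C, H, W) softmax probabilities
    target: (N, H, W) integer class IDs
    """
    # Probability assigned to the true class at every pixel: shape (N, H, W).
    p_t = np.take_along_axis(probs, target[:, None, :, :], axis=1)[:, 0]
    p_t = np.clip(p_t, eps, 1.0)
    # (1 - p_t)^gamma shrinks the loss of confidently correct (easy) pixels.
    return float(np.mean(-((1.0 - p_t) ** gamma) * np.log(p_t)))

# Two pixels, two classes: one easy (p_t = 0.95) and one hard (p_t = 0.30).
probs = np.array([[[[0.95, 0.30]], [[0.05, 0.70]]]])   # (1, 2, 1, 2)
target = np.array([[[0, 0]]])                           # (1, 1, 2)

ce = focal_loss(probs, target, gamma=0.0)   # gamma = 0 reduces to cross-entropy
fl = focal_loss(probs, target, gamma=2.0)
print(ce, fl)
```

With γ = 2 the easy pixel's term is scaled by 0.05² = 0.0025 while the hard pixel's is scaled by 0.7² = 0.49, so the hard pixel dominates the average.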
Augmentation strategy
Geometric augmentations (flips, rotations, elastic deformation) and photometric ones (brightness, contrast, hue shifts) are standard. Libraries like Albumentations apply identical transforms to both image and mask atomically.
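import albumentations as A
from albumentations.pytorch import ToTensorV2

transform = A.Compose([
    # Recent Albumentations releases take size=(height, width); older ones take height and width separately.
    A.RandomResizedCrop(size=(512, 512), scale=(0.5, 1.0)),
    A.HorizontalFlip(p=0.5),
    A.ElasticTransform(alpha=120, sigma=6, p=0.3),
    A.ColorJitter(brightness=0.2, contrast=0.2, p=0.5),  # photometric: applied to the image only
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])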
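What "identical transforms" means in practice is that one set of random parameters is drawn and applied to both arrays. The sketch below shows this with a hand-written flip and crop in NumPy; it is an illustration of the principle, not how Albumentations is implemented, and `random_flip_and_crop` is a hypothetical helper.

```python
import numpy as np

def random_flip_and_crop(image, mask, crop_size, rng):
    """Apply the same random flip and crop to an (H, W, C) image and its (H, W) mask."""
    if rng.random() < 0.5:
        # Flip both arrays along the width axis so pixels stay aligned.
        image, mask = image[:, ::-1], mask[:, ::-1]
    h, w = mask.shape
    top = int(rng.integers(0, h - crop_size + 1))
    left = int(rng.integers(0, w - crop_size + 1))
    window = (slice(top, top + crop_size), slice(left, left + crop_size))
    return image[window], mask[window]

rng = np.random.default_rng(0)
mask = rng.integers(0, 3, size=(8, 8))
# Toy image whose channels all equal the mask, so alignment is easy to verify.
image = np.stack([mask, mask, mask], axis=-1)

aug_image, aug_mask = random_flip_and_crop(image, mask, crop_size=4, rng=rng)
print(np.array_equal(aug_image[..., 0], aug_mask))  # True
```

Drawing the flip and crop offsets separately for image and mask is the classic bug this pattern avoids: training still runs, but the labels no longer line up with the pixels.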
Optimization and deployment
Model quantization
INT8 quantization via TensorRT or ONNX Runtime often roughly halves inference time, typically at a cost of less than 1 mIoU point. Post-training quantization tends to work well for convolutional segmentation models, whose activation ranges are usually stable enough to calibrate from a small, representative sample of inputs.
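The arithmetic behind INT8 quantization fits in a few lines. The sketch below shows symmetric per-tensor quantization in NumPy on random data. It is a simplified illustration: TensorRT and ONNX Runtime use their own calibration procedures and often per-channel scales, and the function names here are hypothetical.

```python
import numpy as np

def quantize_int8(x):
    """Symmetric per-tensor quantization: float array -> (int8 array, scale)."""
    max_abs = float(np.max(np.abs(x)))
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    """Map int8 values back to approximate floats."""
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
activations = rng.normal(0.0, 1.0, size=(64, 32, 32)).astype(np.float32)

q, scale = quantize_int8(activations)
restored = dequantize(q, scale)
max_error = float(np.max(np.abs(restored - activations)))
# Storage shrinks 4x; rounding error is bounded by half a quantization step.
print(q.nbytes / activations.nbytes, max_error, scale / 2)
```

Because a single outlier sets `max_abs` and therefore the step size for the whole tensor, real calibrators usually clip the range (for example by percentile or entropy) instead of using the raw maximum.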
Tiling for large images
Medical and satellite images can exceed 10,000×10,000 pixels. Process them in overlapping tiles (e.g., 512×512 with 64-pixel overlap), predict each tile, then stitch results using blending in the overlap zone to avoid boundary artifacts.
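The tile, predict, and stitch loop can be sketched as follows. This is one possible implementation under stated assumptions: `predict_fn` is a hypothetical callable standing in for the model, blending uses a simple linear ramp so tile centres count more than borders, and the whole score map is held in memory (a real pipeline for very large images would stream tiles to disk).

```python
import numpy as np

def tile_starts(length, tile, stride):
    """Start offsets that cover [0, length), with the last tile flush to the edge."""
    if length <= tile:
        return [0]
    starts = list(range(0, length - tile, stride))
    starts.append(length - tile)
    return starts

def predict_tiled(image, predict_fn, num_classes, tile=512, overlap=64):
    """Blend per-class scores from overlapping tiles, then take the argmax.

    image:      (H, W, C) array with H and W at least `tile`
    predict_fn: hypothetical callable mapping a (tile, tile, C) array to
                (num_classes, tile, tile) scores
    """
    h, w = image.shape[:2]
    stride = tile - overlap
    # Linear ramp weights: largest at the tile centre, smallest at the border.
    ramp = np.minimum(np.arange(1, tile + 1), np.arange(tile, 0, -1)).astype(np.float32)
    weight = np.outer(ramp, ramp)
    scores = np.zeros((num_classes, h, w), dtype=np.float32)
    weight_sum = np.zeros((h, w), dtype=np.float32)
    for top in tile_starts(h, tile, stride):
        for left in tile_starts(w, tile, stride):
            patch = image[top:top + tile, left:left + tile]
            scores[:, top:top + tile, left:left + tile] += predict_fn(patch) * weight
            weight_sum[top:top + tile, left:left + tile] += weight
    return np.argmax(scores / weight_sum, axis=0)

def threshold_model(patch):
    """Stand-in for a real network: class 1 where the first channel exceeds 0.5."""
    foreground = (patch[..., 0] > 0.5).astype(np.float32)
    return np.stack([1.0 - foreground, foreground])

image = np.random.default_rng(0).random((100, 130, 3)).astype(np.float32)
mask = predict_tiled(image, threshold_model, num_classes=2, tile=32, overlap=8)
print(mask.shape)  # (100, 130)
```

Blending scores before the argmax, rather than stitching hard labels, is what suppresses seams: a pixel near a tile border is decided mostly by the neighbouring tile in which it sits closer to the centre.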
Serving architecture
A typical production flow:
- Client uploads image to an API gateway.
- Gateway sends the image to a GPU inference server (Triton, TorchServe, or BentoML).
- Server returns a compressed mask (run-length encoded or PNG).
- Client renders the mask as an overlay.
Batching requests improves GPU utilization. Dynamic batching in Triton collects requests over a short window (5–10ms) and processes them together.
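For the compressed mask in the response, run-length encoding is simple to implement and works well because segmentation masks contain long runs of identical values. The sketch below is a generic row-major RLE for class-ID masks in NumPy; it is not the COCO RLE format, which is column-major and binary, and the function names are hypothetical.

```python
import numpy as np

def rle_encode(mask):
    """Encode a 2D class-ID mask as (values, run_lengths) in row-major order."""
    flat = np.asarray(mask).ravel()
    if flat.size == 0:
        return np.array([], dtype=flat.dtype), np.array([], dtype=np.int64)
    # Indices where a new run starts.
    starts = np.concatenate(([0], np.flatnonzero(flat[1:] != flat[:-1]) + 1))
    lengths = np.diff(np.concatenate((starts, [flat.size])))
    return flat[starts], lengths

def rle_decode(values, lengths, shape):
    """Rebuild the mask from its runs."""
    return np.repeat(values, lengths).reshape(shape)

mask = np.array([[0, 0, 0, 1],
                 [1, 1, 0, 0],
                 [2, 2, 2, 2]], dtype=np.uint8)
values, lengths = rle_encode(mask)
print(values.tolist(), lengths.tolist())  # [0, 1, 0, 2] [3, 3, 2, 4]

restored = rle_decode(values, lengths, mask.shape)
print(np.array_equal(restored, mask))  # True
```

The client needs the mask shape alongside the runs to decode, so the response payload should carry height and width as well.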
Monitoring
Track inference latency (p50, p95, p99), mIoU on a held-out canary set, and prediction distribution drift (are certain classes disappearing from predictions over time?). A drop in mean prediction confidence often signals distribution shift before mIoU degrades.
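Two of these signals are cheap to compute from predicted masks alone. The sketch below computes mIoU from a confusion matrix and measures class-share drift as the total variation distance between a reference distribution and the current one. The choice of total variation distance and the function names are assumptions for illustration; other divergence measures work as well.

```python
import numpy as np

def mean_iou(pred, target, num_classes):
    """Mean IoU over classes that appear in the prediction or the target."""
    index = target.ravel().astype(np.int64) * num_classes + pred.ravel().astype(np.int64)
    confusion = np.bincount(index, minlength=num_classes ** 2)
    confusion = confusion.reshape(num_classes, num_classes)   # rows: target, cols: pred
    intersection = np.diag(confusion)
    union = confusion.sum(axis=0) + confusion.sum(axis=1) - intersection
    present = union > 0
    return float(np.mean(intersection[present] / union[present]))

def class_fractions(pred, num_classes):
    """Share of pixels assigned to each class."""
    counts = np.bincount(pred.ravel().astype(np.int64), minlength=num_classes)
    return counts / counts.sum()

def distribution_drift(reference, current):
    """Total variation distance between two class-share vectors (0 means identical)."""
    return 0.5 * float(np.abs(reference - current).sum())

target = np.array([[0, 0, 1, 1],
                   [0, 0, 1, 1]])
pred = np.array([[0, 0, 1, 1],
                 [0, 1, 1, 1]])

print(mean_iou(pred, target, num_classes=2))             # 0.775
drift = distribution_drift(class_fractions(target, 2), class_fractions(pred, 2))
print(drift)                                             # 0.125
```

In production the reference vector would come from a window of known-good traffic rather than from labels, since ground truth is usually unavailable at serving time.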
Tradeoffs
| Approach | Pros | Cons |
|---|---|---|
| Classical (watershed, GrabCut) | No GPU needed, fast, interpretable | Struggles with complex scenes |
| CNN (U-Net, DeepLab) | Strong accuracy, well-understood | Needs labeled data, GPU for training |
| Transformer (SAM, SegFormer) | Zero-shot capability, state-of-the-art accuracy | Large models, higher latency |
| Knowledge distillation | Edge-deployable, fast | Accuracy loss vs. teacher model |
Real-world example
A pathology lab processes ~200 whole-slide images per day, each around 100,000×50,000 pixels. They use a U-Net with an EfficientNet-B4 backbone, tiled at 256×256 with 32-pixel overlap. Training used 800 annotated slides; inference runs on two NVIDIA T4 GPUs behind Triton with dynamic batching. End-to-end processing time per slide dropped from 12 minutes (manual annotation) to 45 seconds, with a Dice score of 0.91 on tumor regions.
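A quick back-of-the-envelope calculation shows what these numbers imply for throughput. It assumes every tile of the slide is sent through the model; pipelines for whole-slide images often skip empty background tiles, which would lower the real figure.

```python
import math

def tiles_per_axis(length, tile, overlap):
    """Number of tiles needed to cover `length` pixels with the given overlap."""
    stride = tile - overlap
    return math.ceil((length - tile) / stride) + 1

tiles_x = tiles_per_axis(100_000, tile=256, overlap=32)
tiles_y = tiles_per_axis(50_000, tile=256, overlap=32)
total = tiles_x * tiles_y
print(tiles_x, tiles_y, total)   # 447 224 100128

# Tiles per second needed across both GPUs to finish one slide in 45 seconds.
print(round(total / 45))         # 2225
```

At roughly 100,000 tiles per slide, batching and background filtering matter more to end-to-end time than small differences in per-tile latency.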
The one thing to remember: Choosing a segmentation architecture is less about raw accuracy and more about matching your latency budget, labeling capacity, and deployment target — the best model is the one you can actually ship.