PyTorch Custom Datasets — Deep Dive
Anatomy of a Production Dataset
A minimal custom dataset is straightforward, but production datasets handle edge cases: corrupted files, missing labels, variable-length sequences, and multi-modal inputs. Here’s a robust image classification dataset:
import torch
from torch.utils.data import Dataset
from pathlib import Path
from PIL import Image
import logging

logger = logging.getLogger(__name__)


class ImageClassificationDataset(Dataset):
    def __init__(self, root_dir: str, transform=None):
        self.root = Path(root_dir)
        self.transform = transform
        self.samples: list[tuple[Path, int]] = []
        self.class_to_idx: dict[str, int] = {}

        # Build class mapping from directory structure
        class_dirs = sorted(d for d in self.root.iterdir() if d.is_dir())
        self.class_to_idx = {d.name: i for i, d in enumerate(class_dirs)}

        for class_dir in class_dirs:
            label = self.class_to_idx[class_dir.name]
            for img_path in sorted(class_dir.glob("*")):
                if img_path.suffix.lower() in {".jpg", ".jpeg", ".png", ".webp"}:
                    self.samples.append((img_path, label))

        logger.info(f"Loaded {len(self.samples)} samples across "
                    f"{len(self.class_to_idx)} classes")

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, int]:
        img_path, label = self.samples[idx]
        try:
            image = Image.open(img_path).convert("RGB")
        except (OSError, IOError) as e:
            logger.warning(f"Corrupt image {img_path}: {e}, using blank")
            image = Image.new("RGB", (224, 224))
        if self.transform:
            image = self.transform(image)
        return image, label
Key design decisions: lazy image loading (only paths stored in __init__), graceful handling of corrupt files, and sorted directory traversal for deterministic ordering.
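Substituting a blank image keeps batch shapes intact, but it also feeds the model a black frame carrying a real label. One alternative, sketched here as an assumption rather than as part of the dataset above, is to have __getitem__ return None for unreadable files and drop those entries at batching time. The name skip_none_collate is hypothetical; default_collate is PyTorch's standard batching function.

```python
import torch
from torch.utils.data.dataloader import default_collate


def skip_none_collate(batch):
    """Drop samples that failed to load (returned as None), then batch the rest."""
    valid = [sample for sample in batch if sample is not None]
    if not valid:
        # Every sample in this batch failed; the training loop must skip None batches.
        return None
    return default_collate(valid)


# Toy batch: two good (image, label) samples and one failed load.
toy_batch = [(torch.zeros(3, 8, 8), 0), None, (torch.ones(3, 8, 8), 1)]
images, labels = skip_none_collate(toy_batch)
print(images.shape, labels)
```

The trade-off is that batches can come out smaller than batch_size, so any code that assumes a fixed batch dimension needs to tolerate that.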
Custom Collate Functions
When samples have variable sizes — text of different lengths, images of different resolutions, or nested structures — the default collate function fails. Write a custom one:
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader


def variable_length_collate(batch):
    """Collate sequences of different lengths with padding."""
    sequences, labels = zip(*batch)
    # Pad sequences to the length of the longest in this batch
    padded = pad_sequence(sequences, batch_first=True, padding_value=0)
    # Create attention masks (1 for real tokens, 0 for padding)
    masks = torch.zeros_like(padded)
    for i, seq in enumerate(sequences):
        masks[i, :len(seq)] = 1
    labels = torch.tensor(labels, dtype=torch.long)
    return padded, masks, labels


# Usage: `dataset` must yield (sequence tensor, int label) pairs
loader = DataLoader(dataset, batch_size=32, collate_fn=variable_length_collate)
This pattern is essential for NLP tasks where padding to a global maximum length wastes compute on short sequences.
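To make the saving concrete, the following sketch counts how many tensor cells are allocated when padding per batch versus padding everything to the global maximum. The sequence lengths and batch size are made-up toy values chosen only for illustration.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Toy token sequences with very different lengths (values are arbitrary).
lengths = [3, 5, 4, 50]
sequences = [torch.ones(n, dtype=torch.long) for n in lengths]
batch_size = 2
global_max = max(lengths)

# Per-batch padding: each batch is padded only to its own longest sequence.
dynamic_cells = 0
for start in range(0, len(sequences), batch_size):
    chunk = sequences[start:start + batch_size]
    padded = pad_sequence(chunk, batch_first=True, padding_value=0)
    dynamic_cells += padded.numel()

# Global padding: every sequence is padded to the longest in the dataset.
global_cells = len(sequences) * global_max
print(f"per-batch padding: {dynamic_cells} cells, global padding: {global_cells} cells")
```

The gap grows further when batches are formed from sequences of similar length, although that grouping is a separate technique not covered here.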
Multi-Worker Loading Internals
Setting num_workers > 0 spawns separate processes, each with a copy of the dataset object. This has implications:
Shared memory. Workers communicate with the main process via shared memory queues. Large individual samples (high-res images, long sequences) increase serialization overhead. Profile with:
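import time
from torch.utils.data import DataLoader

# `dataset` is the Dataset instance under test
loader = DataLoader(dataset, batch_size=32, num_workers=4)
start = time.perf_counter()
for batch in loader:
    pass  # iterate only; no training step, so this measures loading alone
elapsed = time.perf_counter() - start
print(f"Full epoch load: {elapsed:.1f}s ({len(dataset)/elapsed:.0f} samples/sec)")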
Worker initialization. Use worker_init_fn to seed RNGs differently per worker, preventing duplicate augmentations:
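import random

import numpy as np


def worker_init(worker_id):
    # Inside a worker, torch.initial_seed() already differs per worker
    # (base seed + worker id), so adding worker_id again is unnecessary
    # and could push the value past NumPy's 32-bit seed range.
    seed = torch.initial_seed() % 2**32
    np.random.seed(seed)
    random.seed(seed)


loader = DataLoader(dataset, num_workers=4, worker_init_fn=worker_init)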
File handle limits. Each worker opens its own file handles. With num_workers=8 and a dataset backed by HDF5 or LMDB, you can hit OS file descriptor limits. Monitor with ulimit -n and increase if needed.
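The same check can be done from inside the training script with Python's standard resource module, which is available on Unix-like systems only. This is a minimal sketch: the helper name raise_fd_soft_limit and the target of 4096 are illustrative assumptions, not recommended values.

```python
import resource  # Unix-only standard library module


def raise_fd_soft_limit(target: int) -> int:
    """Raise the soft open-file limit toward `target`, capped by the hard limit."""
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    if hard != resource.RLIM_INFINITY:
        target = min(target, hard)
    if soft == resource.RLIM_INFINITY or soft >= target:
        return soft  # already high enough, nothing to change
    resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
    return target


new_soft = raise_fd_soft_limit(4096)
print("soft file-descriptor limit:", new_soft)
```

Call it before constructing the DataLoader so that worker processes inherit the raised limit. Going beyond the hard limit still requires changing the system configuration.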
IterableDataset for Streaming Data
When data doesn’t support random access — streaming from S3, reading from a Kafka topic, or processing a 500GB text file — use IterableDataset:
import torch
from torch.utils.data import IterableDataset, get_worker_info


class StreamingTextDataset(IterableDataset):
    def __init__(self, file_path: str, tokenizer, max_length: int = 512):
        self.file_path = file_path
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __iter__(self):
        # None in the main process, populated inside a worker process
        worker_info = get_worker_info()
        with open(self.file_path, "r", encoding="utf-8") as f:
            for i, line in enumerate(f):
                # Shard across workers: worker k keeps lines k, k + N, k + 2N, ...
                if worker_info is not None:
                    if i % worker_info.num_workers != worker_info.id:
                        continue
                tokens = self.tokenizer.encode(
                    line.strip(),
                    max_length=self.max_length,
                    truncation=True,
                )
                yield torch.tensor(tokens)
The get_worker_info() call is critical for multi-worker loading — without it, every worker yields the entire dataset, multiplying your data by num_workers.
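The modulo test used for sharding is equivalent to striding through the stream with itertools.islice. The sketch below checks the partitioning logic in isolation, without PyTorch or worker processes. The shard helper and the toy lines are hypothetical stand-ins for the file and the worker ids.

```python
from itertools import islice


def shard(iterable, worker_id: int, num_workers: int):
    """Yield every num_workers-th item, starting at offset worker_id."""
    return islice(iterable, worker_id, None, num_workers)


lines = [f"line-{i}" for i in range(10)]
num_workers = 3
shards = [list(shard(lines, w, num_workers)) for w in range(num_workers)]
for w, items in enumerate(shards):
    print(w, items)
```

Each line lands in exactly one shard, so the union of all workers' output is the original stream with no duplicates. Note that every worker still reads the whole file and merely discards the lines it does not own.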
Efficient Storage Backends
For high-throughput training, reading raw files from disk can become the bottleneck:
LMDB — Memory-mapped key-value store. Reads are nearly instant because the OS handles caching:
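import lmdb
import pickle
from torch.utils.data import Dataset


class LMDBDataset(Dataset):
    def __init__(self, db_path: str):
        self.db_path = db_path
        self.env = None  # opened lazily so each worker process gets its own handle
        # Open briefly just to read the entry count, then close again
        env = lmdb.open(db_path, readonly=True, lock=False)
        with env.begin() as txn:
            self.length = txn.stat()["entries"]
        env.close()

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        if self.env is None:
            self.env = lmdb.open(self.db_path, readonly=True, lock=False)
        with self.env.begin() as txn:
            data = txn.get(str(idx).encode())
        if data is None:
            raise IndexError(f"No entry for key {idx}")
        return pickle.loads(data)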
WebDataset — Tar-based format designed for sequential access and cloud storage. Popular for large-scale distributed training where data lives on S3 or GCS.
Hugging Face datasets — Arrow-backed, memory-mapped. Handles caching, splitting, and streaming with a clean API. Good default for NLP tasks.
Debugging Dataset Issues
Common failure modes and how to catch them:
from torch.utils.data import DataLoader
from tqdm import tqdm

# 1. Check a single sample loads correctly
sample = dataset[0]
print(type(sample), [s.shape if hasattr(s, 'shape') else s for s in sample])

# 2. Check all samples load (catches corrupt files)
errors = []
for i in tqdm(range(len(dataset))):
    try:
        dataset[i]
    except Exception as e:
        errors.append((i, str(e)))
print(f"Errors: {len(errors)}/{len(dataset)}")

# 3. Verify DataLoader produces correct batch shapes
batch = next(iter(DataLoader(dataset, batch_size=4)))
for item in batch:
    if hasattr(item, 'shape'):
        print(item.shape, item.dtype)
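A further check in the same spirit is to count labels, which can surface an empty class directory or a badly skewed split. This sketch assumes every sample is an (input, label) pair. The helper label_histogram and the toy list standing in for a dataset are hypothetical.

```python
from collections import Counter


def label_histogram(dataset) -> Counter:
    """Count labels, assuming each sample is an (input, label) pair."""
    return Counter(int(dataset[i][1]) for i in range(len(dataset)))


# Stand-in dataset: a plain list of (input, label) pairs.
toy_dataset = [("a", 0), ("b", 1), ("c", 1), ("d", 1)]
counts = label_histogram(toy_dataset)
print(counts.most_common())
```

On a real image dataset this loads every sample, so it costs about as much as the full-scan check above. Where labels are stored separately from the inputs, reading them directly is cheaper.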
Reproducibility
For reproducible training, control all sources of randomness in data loading:
generator = torch.Generator()
generator.manual_seed(42)

loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    generator=generator,
    num_workers=4,
    worker_init_fn=worker_init,
)
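The generator ensures identical shuffle order across runs. Combined with worker_init_fn and seeded transforms, the pipeline yields the same batches in the same order on every run, provided that num_workers, the library versions, and the dataset contents stay fixed.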
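The effect of the seeded generator can be verified on a toy dataset. This is a minimal sketch: the helper epoch_order is hypothetical, and it runs with the default num_workers=0 to stay self-contained, so it exercises only the shuffle order and not worker seeding.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def epoch_order(seed: int) -> list[int]:
    """Return the sample order of one shuffled epoch for a given generator seed."""
    dataset = TensorDataset(torch.arange(10))
    generator = torch.Generator()
    generator.manual_seed(seed)
    loader = DataLoader(dataset, batch_size=4, shuffle=True, generator=generator)
    # Each batch is a one-element list holding the batched tensor.
    return [int(x) for (batch,) in loader for x in batch]


run_a = epoch_order(42)
run_b = epoch_order(42)
print(run_a == run_b)
```

Two loaders built with the same seed visit the samples in the same order, while a different seed would normally produce a different permutation.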
The one thing to remember: Production datasets are 20% data access and 80% engineering — error handling, caching, parallel loading, and reproducibility determine whether training runs complete reliably or crash at 3 AM.