PyTorch Custom Datasets — Deep Dive

Anatomy of a Production Dataset

A minimal custom dataset is straightforward, but production datasets handle edge cases: corrupted files, missing labels, variable-length sequences, and multi-modal inputs. Here’s a robust image classification dataset:

import torch
from torch.utils.data import Dataset
from pathlib import Path
from PIL import Image
import logging

logger = logging.getLogger(__name__)

class ImageClassificationDataset(Dataset):
    def __init__(self, root_dir: str, transform=None):
        self.root = Path(root_dir)
        self.transform = transform
        self.samples: list[tuple[Path, int]] = []
        self.class_to_idx: dict[str, int] = {}

        # Build class mapping from directory structure
        class_dirs = sorted(d for d in self.root.iterdir() if d.is_dir())
        self.class_to_idx = {d.name: i for i, d in enumerate(class_dirs)}

        for class_dir in class_dirs:
            label = self.class_to_idx[class_dir.name]
            for img_path in sorted(class_dir.glob("*")):
                if img_path.suffix.lower() in {".jpg", ".jpeg", ".png", ".webp"}:
                    self.samples.append((img_path, label))

        logger.info(f"Loaded {len(self.samples)} samples across "
                     f"{len(self.class_to_idx)} classes")

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, int]:
        img_path, label = self.samples[idx]
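        try:
            # The context manager closes the file handle; convert() forces the decode
            with Image.open(img_path) as img:
                image = img.convert("RGB")
        except OSError as e:  # IOError is just an alias of OSError in Python 3
            logger.warning(f"Corrupt image {img_path}: {e}, using blank")
            image = Image.new("RGB", (224, 224))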

        if self.transform:
            image = self.transform(image)

        return image, label

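Key design decisions: lazy image loading (only paths are stored in __init__), graceful handling of corrupt files, and sorted directory traversal for deterministic ordering. The blank-image fallback keeps the sample's original label, so it trades a crash for some label noise; check the warnings if many files fail.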
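Instead of substituting a blank image at training time, one alternative is to screen files once while building the index. A minimal sketch using Pillow's Image.verify(); find_valid_images is a hypothetical helper, and note that verify() checks file structure without decoding pixels, so some damage (for example truncated pixel data) may only surface at load time.

```python
from pathlib import Path

from PIL import Image

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".webp"}


def find_valid_images(root: str) -> tuple[list[Path], list[Path]]:
    """Split image files under root into (valid, corrupt) using PIL's verify()."""
    valid, corrupt = [], []
    for path in sorted(Path(root).rglob("*")):
        if path.suffix.lower() not in IMAGE_SUFFIXES:
            continue  # skips directories and non-image files
        try:
            with Image.open(path) as img:
                img.verify()  # integrity check without a full decode
            valid.append(path)
        except (OSError, SyntaxError):
            # Pillow raises OSError (including UnidentifiedImageError) or,
            # for some formats, SyntaxError on broken files
            corrupt.append(path)
    return valid, corrupt
```

Running this once before training lets you log or quarantine bad files up front, so __getitem__ only ever sees paths that opened cleanly at indexing time.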

Custom Collate Functions

When samples have variable sizes (text of different lengths, images of different resolutions, or ragged nested structures), the default collate function, which stacks each field into a single tensor, raises a size-mismatch error. Write a custom one:

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def variable_length_collate(batch):
    """Collate sequences of different lengths with padding."""
    sequences, labels = zip(*batch)

    # Pad sequences to the length of the longest in this batch
    padded = pad_sequence(sequences, batch_first=True, padding_value=0)

    # Create attention masks (1 for real tokens, 0 for padding)
    masks = torch.zeros_like(padded)
    for i, seq in enumerate(sequences):
        masks[i, :len(seq)] = 1

    labels = torch.tensor(labels, dtype=torch.long)
    return padded, masks, labels

# Usage
loader = DataLoader(dataset, batch_size=32, collate_fn=variable_length_collate)

This pattern is essential for NLP tasks where padding to a global maximum length wastes compute on short sequences.
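The Python loop that fills the masks can also be replaced by one broadcast comparison between position indices and sequence lengths. A minimal sketch of that equivalent variant (collate_with_lengths is a hypothetical name), run on a toy batch of three sequences:

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def collate_with_lengths(batch):
    """Pad a batch of (sequence, label) pairs and build masks without a loop."""
    sequences, labels = zip(*batch)
    lengths = torch.tensor([len(s) for s in sequences])
    padded = pad_sequence(list(sequences), batch_first=True, padding_value=0)
    # positions (1, max_len) < lengths (batch, 1) broadcasts to (batch, max_len)
    positions = torch.arange(padded.size(1)).unsqueeze(0)
    masks = (positions < lengths.unsqueeze(1)).long()
    return padded, masks, torch.tensor(labels, dtype=torch.long)


batch = [
    (torch.tensor([5, 6, 7]), 0),
    (torch.tensor([8]), 1),
    (torch.tensor([9, 10]), 1),
]
padded, masks, labels = collate_with_lengths(batch)
print(padded.shape)    # torch.Size([3, 3])
print(masks.tolist())  # [[1, 1, 1], [1, 0, 0], [1, 1, 0]]
```

Each batch is padded only to its own longest sequence (3 here), which is where the savings over a global maximum length come from.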

Multi-Worker Loading Internals

Setting num_workers > 0 spawns separate processes, each with a copy of the dataset object. This has implications:

Shared memory. Workers communicate with the main process via shared memory queues. Large individual samples (high-res images, long sequences) increase serialization overhead. Profile with:

import time
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=32, num_workers=4)
start = time.perf_counter()
for batch in loader:
    pass
elapsed = time.perf_counter() - start
print(f"Full epoch load: {elapsed:.1f}s ({len(dataset)/elapsed:.0f} samples/sec)")
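To pick num_workers empirically, the same measurement can be repeated over several worker counts. A sketch, where measure_throughput is a hypothetical helper and the worker counts are arbitrary examples:

```python
import time

from torch.utils.data import DataLoader


def measure_throughput(dataset, worker_counts=(0, 2, 4), batch_size=32):
    """Return {num_workers: samples/sec} for one full pass over the dataset."""
    results = {}
    for n in worker_counts:
        loader = DataLoader(dataset, batch_size=batch_size, num_workers=n)
        start = time.perf_counter()
        for _ in loader:
            pass
        # Includes worker start-up time, so very short runs understate
        # the steady-state throughput of configurations with many workers
        elapsed = time.perf_counter() - start
        results[n] = len(dataset) / elapsed
    return results
```

For example, `measure_throughput(dataset, worker_counts=(0, 4, 8))` on your real dataset; throughput usually stops improving once workers outpace the consumer or saturate disk or CPU.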

Worker initialization. Use worker_init_fn to seed RNGs differently per worker, preventing duplicate augmentations:

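import random
import numpy as np

def worker_init(worker_id):
    # Inside a worker, torch.initial_seed() already equals base_seed + worker_id,
    # so it differs per worker; reuse it for NumPy and Python's random module
    seed = torch.initial_seed() % 2**32
    np.random.seed(seed)
    random.seed(seed)

loader = DataLoader(dataset, num_workers=4, worker_init_fn=worker_init)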

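File handle limits. Each worker opens its own file handles, so with num_workers=8 and a dataset backed by HDF5 or LMDB you can hit the OS file descriptor limit ("Too many open files"). Check the current limit with ulimit -n and raise it if needed. Also open such handles lazily inside each worker rather than in __init__: handles created before the workers start are inherited by every worker process, and these libraries are not safe to share that way.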
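From Python, the same limit can be checked and raised with the standard-library resource module (Unix only). A minimal sketch; raise_fd_limit and the 4096 target are illustrative choices, not recommendations. Resource limits are inherited by child processes, so call it before the DataLoader starts its workers.

```python
import resource  # Unix-only standard-library module


def raise_fd_limit(target: int = 4096) -> int:
    """Raise the soft open-file limit toward target; return the resulting soft limit."""
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    if soft == resource.RLIM_INFINITY:
        return soft  # already unlimited
    if hard != resource.RLIM_INFINITY:
        target = min(target, hard)  # the soft limit can never exceed the hard limit
    if soft >= target:
        return soft
    try:
        resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
    except (ValueError, OSError):
        return soft  # e.g. the OS caps the per-process limit below target
    return target


print("open-file soft limit:", raise_fd_limit())
```

Raising the hard limit itself needs elevated privileges, which is why the sketch only moves the soft limit up to the existing hard cap.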

IterableDataset for Streaming Data

When data doesn’t support random access — streaming from S3, reading from a Kafka topic, or processing a 500GB text file — use IterableDataset:

from torch.utils.data import IterableDataset, get_worker_info

class StreamingTextDataset(IterableDataset):
    def __init__(self, file_path: str, tokenizer, max_length: int = 512):
        self.file_path = file_path
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __iter__(self):
        worker_info = get_worker_info()

        with open(self.file_path, "r", encoding="utf-8") as f:
            for i, line in enumerate(f):
                # Shard across workers
                if worker_info is not None:
                    if i % worker_info.num_workers != worker_info.id:
                        continue

                tokens = self.tokenizer.encode(
                    line.strip(),
                    max_length=self.max_length,
                    truncation=True,
                )
                yield torch.tensor(tokens)

The get_worker_info() call is critical for multi-worker loading — without it, every worker yields the entire dataset, multiplying your data by num_workers.
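The modulo check can also be expressed with itertools.islice, which performs the same round-robin split. A sketch, where shard and ShardedLineDataset are hypothetical names and an in-memory list stands in for the open file:

```python
import itertools

from torch.utils.data import IterableDataset, get_worker_info


def shard(iterable, num_workers: int, worker_id: int):
    """Round-robin split: keep items worker_id, worker_id + num_workers, ..."""
    return itertools.islice(iterable, worker_id, None, num_workers)


class ShardedLineDataset(IterableDataset):
    """Hypothetical example; `lines` stands in for an open file or other stream."""

    def __init__(self, lines):
        self.lines = lines

    def __iter__(self):
        info = get_worker_info()
        if info is None:  # num_workers=0: the main process reads everything
            return iter(self.lines)
        return shard(iter(self.lines), info.num_workers, info.id)
```

As with the modulo version, every worker still scans the full stream and discards lines it does not own; islice only moves that bookkeeping out of the loop body.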

Efficient Storage Backends

For high-throughput training, reading many small raw files from disk can become the bottleneck:

LMDB — Memory-mapped key-value store. Reads avoid a per-file open() call, and once the data is in the OS page cache they are served straight from memory:

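import lmdb
import pickle
from torch.utils.data import Dataset

class LMDBDataset(Dataset):
    def __init__(self, db_path: str):
        self.db_path = db_path
        self.env = None  # opened lazily, once per worker process
        # Count entries with a short-lived handle closed before workers start.
        # Assumes the database holds only sample entries.
        env = lmdb.open(db_path, readonly=True, lock=False)
        self.length = env.stat()["entries"]
        env.close()

    def _open(self):
        self.env = lmdb.open(self.db_path, readonly=True, lock=False,
                             readahead=False, meminit=False)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        if self.env is None:
            self._open()
        # Assumes samples were stored as pickles under keys b"0", b"1", ...
        with self.env.begin() as txn:
            data = txn.get(str(idx).encode())
        return pickle.loads(data)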

WebDataset — Tar-based format designed for sequential access and cloud storage. Popular for large-scale distributed training where data lives on S3 or GCS.

Hugging Face datasets — Arrow-backed, memory-mapped. Handles caching, splitting, and streaming with a clean API. Good default for NLP tasks.

Debugging Dataset Issues

Common failure modes and how to catch them:

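from torch.utils.data import DataLoader
from tqdm import tqdm

# 1. Check a single sample loads correctly
sample = dataset[0]
print(type(sample), [s.shape if hasattr(s, 'shape') else s for s in sample])

# 2. Check all samples load (catches errors that escape __getitem__;
#    failures handled inside it, like the blank-image fallback, won't show up)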
errors = []
for i in tqdm(range(len(dataset))):
    try:
        dataset[i]
    except Exception as e:
        errors.append((i, str(e)))
print(f"Errors: {len(errors)}/{len(dataset)}")

# 3. Verify DataLoader produces correct batch shapes
batch = next(iter(DataLoader(dataset, batch_size=4)))
for item in batch:
    if hasattr(item, 'shape'):
        print(item.shape, item.dtype)
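A fourth check in the same spirit finds samples whose shapes differ, which is exactly what makes the default collate function fail. A sketch; find_shape_mismatches is a hypothetical helper that assumes each sample is a tuple and compares one field against sample 0.

```python
def find_shape_mismatches(dataset, field: int = 0, limit=None) -> list[int]:
    """Return indices whose sample[field].shape differs from sample 0's."""
    n = len(dataset) if limit is None else min(limit, len(dataset))
    reference = dataset[0][field].shape
    return [i for i in range(n) if dataset[i][field].shape != reference]
```

A non-empty result means you either need a resizing transform or a custom collate function such as the padding one above.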

Reproducibility

For reproducible training, control all sources of randomness in data loading:

from torch.utils.data import DataLoader

generator = torch.Generator()
generator.manual_seed(42)

loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    generator=generator,
    num_workers=4,
    worker_init_fn=worker_init,
)

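The generator fixes the shuffle order, and the DataLoader also draws each epoch's worker base seed from it, so per-worker seeds repeat across runs. Combined with worker_init_fn and transforms that take their randomness only from these seeded RNGs, you get a reproducible data pipeline, provided num_workers stays the same (it decides which worker, and therefore which RNG stream, handles each batch).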
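A quick way to confirm the shuffle order really is fixed by the seed: build two loaders from fresh generators with the same seed and compare the order they produce. A sketch on a hypothetical toy dataset where sample i is just the integer i.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def epoch_order(seed: int, n: int = 10) -> list[int]:
    """Return the sample order of one shuffled epoch for a given generator seed."""
    dataset = TensorDataset(torch.arange(n))  # toy dataset: sample i is i
    generator = torch.Generator()
    generator.manual_seed(seed)
    loader = DataLoader(dataset, batch_size=4, shuffle=True, generator=generator)
    # Each batch is a one-element list holding a tensor of sample values
    return [x.item() for (batch,) in loader for x in batch]


print(epoch_order(42) == epoch_order(42))  # True: same seed, same order
```

The same comparison is a cheap regression test to keep in a training repository, since a stray unseeded shuffle anywhere in the pipeline makes it fail.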

The one thing to remember: Production datasets are 20% data access and 80% engineering — error handling, caching, parallel loading, and reproducibility determine whether training runs complete reliably or crash at 3 AM.
