Neural Architecture Search with Python — Deep Dive

Differentiable Architecture Search (DARTS)

DARTS makes architecture search a continuous optimization problem. Instead of discrete “choose operation A or B,” each edge has a weighted mixture of all operations, and the weights are learned through gradient descent alongside the model weights.

DARTS Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

# Candidate operations of the search space, keyed by name.
# Every value is a factory: call it with the channel count ``C`` to get a
# freshly built module.  Keep the entry order stable -- MixedOp instantiates
# the operations by iterating this dict, and the architecture logits are
# matched to operations by position.
OPS = {
    # Positional args are presumably (C_in, C_out, kernel_size, padding);
    # confirm against the SepConv / DilConv definitions.
    "sep_conv_3x3": lambda C: SepConv(C, C, 3, 1),
    "sep_conv_5x5": lambda C: SepConv(C, C, 5, 2),
    "dil_conv_3x3": lambda C: DilConv(C, C, 3, 2),
    # 3x3 pooling with stride 1 and padding 1; the channel count is unused.
    "max_pool_3x3": lambda C: nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
    "avg_pool_3x3": lambda C: nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
    # Pass the input through unchanged.
    "skip_connect": lambda C: nn.Identity(),
    "none": lambda C: Zero(),
}

class MixedOp(nn.Module):
    """Soft choice between all candidate operations of ``OPS``.

    One instance of every operation is built for the given channel count, in
    ``OPS`` iteration order, so position ``k`` of the weight vector passed to
    ``forward`` refers to the ``k``-th entry of ``OPS``.
    """

    def __init__(self, channels):
        super().__init__()
        candidates = [make_op(channels) for make_op in OPS.values()]
        self.ops = nn.ModuleList(candidates)

    def forward(self, x, weights):
        """Return the sum of ``weights[k] * ops[k](x)`` over all operations."""
        blended = 0
        for coeff, candidate in zip(weights, self.ops):
            blended = blended + coeff * candidate(x)
        return blended


class DARTSCell(nn.Module):
    """Searchable cell: a small DAG whose edges are ``MixedOp`` blends.

    The cell takes two input states and builds ``num_nodes`` intermediate
    nodes.  Intermediate node ``i`` is stored as state ``i + 2`` and has one
    incoming edge from every earlier state.  Each edge owns one learnable
    vector of ``len(OPS)`` architecture logits.
    """

    def __init__(self, channels, num_nodes=4):
        super().__init__()
        self.num_nodes = num_nodes

        # Edge "src_dst" carries state `src` into state `dst`.
        self.edges = nn.ModuleDict()
        for node in range(num_nodes):
            dst = node + 2
            for src in range(dst):  # both cell inputs plus all earlier nodes
                self.edges[f"{src}_{dst}"] = MixedOp(channels)

        # One logit vector per edge, created in the same order as the edges.
        # The tiny random init makes the initial softmax close to uniform.
        n_edges = len(self.edges)
        logits = [nn.Parameter(torch.randn(len(OPS)) * 1e-3) for _ in range(n_edges)]
        self.arch_params = nn.ParameterList(logits)

    def forward(self, s0, s1):
        """Run the cell and return all intermediate nodes concatenated on dim 1."""
        states = [s0, s1]

        # Logit vectors are consumed in edge-creation order.
        next_logits = 0
        for node in range(self.num_nodes):
            dst = node + 2
            incoming = []
            for src, state in enumerate(states):
                key = f"{src}_{dst}"
                if key not in self.edges:
                    continue
                mix = F.softmax(self.arch_params[next_logits], dim=0)
                incoming.append(self.edges[key](state, mix))
                next_logits += 1
            states.append(sum(incoming))

        # The two inputs are dropped; only intermediate nodes form the output.
        return torch.cat(states[2:], dim=1)

Bi-Level Optimization

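DARTS alternates between optimizing the architecture parameters (α) on validation batches and the model weights (w) on training batches. The trainer below uses the first-order approximation: each α step treats the current w as fixed rather than differentiating through a virtual weight update: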

class DARTSTrainer:
    """Alternate first-order DARTS updates of architecture params and weights.

    The model must provide ``weight_params()`` and ``arch_params()`` (each
    returning an iterable of parameters) and be callable on an input batch.
    ``derive_architecture`` additionally reads ``model.num_nodes``.

    Args:
        model: The searchable network.
        train_loader: Iterable of ``(inputs, labels)`` batches used for the
            weight updates.
        val_loader: Iterable of ``(inputs, labels)`` batches used for the
            architecture updates.
    """

    def __init__(self, model, train_loader, val_loader):
        self.model = model

        # Materialise once: the list feeds the optimizer and is also used to
        # find out which device the model lives on.
        self._weights = list(model.weight_params())

        # Separate optimizers for weights and architecture
        self.w_optimizer = torch.optim.SGD(
            self._weights, lr=0.025, momentum=0.9, weight_decay=3e-4
        )
        self.alpha_optimizer = torch.optim.Adam(
            model.arch_params(), lr=3e-4, weight_decay=1e-3
        )

        # Keep the loaders so an exhausted iterator can be restarted.
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.train_iter = iter(train_loader)
        self.val_iter = iter(val_loader)

    def _device(self):
        """Return the device of the first weight tensor.

        Falls back to CUDA (the previously hard-coded target) when the first
        entry of ``weight_params()`` is not a tensor.
        """
        first = self._weights[0] if self._weights else None
        if torch.is_tensor(first):
            return first.device
        return torch.device("cuda")

    def _next_batch(self, split):
        """Return the next ``(inputs, labels)`` of ``split`` on the model device.

        ``split`` is ``"train"`` or ``"val"``.  When the current iterator is
        exhausted a new pass over the corresponding loader is started, so
        ``step()`` can be called for more than one epoch.
        """
        iter_attr, loader_attr = f"{split}_iter", f"{split}_loader"
        try:
            inputs, labels = next(getattr(self, iter_attr))
        except StopIteration:
            fresh = iter(getattr(self, loader_attr))
            setattr(self, iter_attr, fresh)
            inputs, labels = next(fresh)
        device = self._device()
        return inputs.to(device), labels.to(device)

    def step(self):
        """Run one architecture update followed by one weight update.

        Returns:
            ``(train_loss, val_loss)`` as Python floats.
        """
        # Step 1: Update architecture params on validation data
        val_inputs, val_labels = self._next_batch("val")

        self.alpha_optimizer.zero_grad()
        val_loss = F.cross_entropy(self.model(val_inputs), val_labels)
        val_loss.backward()
        self.alpha_optimizer.step()

        # Step 2: Update model weights on training data
        train_inputs, train_labels = self._next_batch("train")

        # zero_grad also discards the weight gradients left over from the
        # validation backward pass above.
        self.w_optimizer.zero_grad()
        train_loss = F.cross_entropy(self.model(train_inputs), train_labels)
        train_loss.backward()
        self.w_optimizer.step()

        return train_loss.item(), val_loss.item()

    def derive_architecture(self):
        """Extract a discrete architecture from the continuous params.

        Assumes ``model.arch_params()`` yields one logit vector per edge,
        grouped by destination node with sources in increasing order (the
        layout ``DARTSCell`` uses) -- confirm for models that hold several
        cells or several alpha sets.

        Returns:
            A list of ``(op_name, src_state, dst_state)`` tuples, at most two
            per intermediate node.

        Raises:
            ValueError: If the model exposes fewer logit vectors than the
                ``num_nodes`` layout requires.
        """
        op_names = list(OPS)
        alphas = list(self.model.arch_params())
        num_nodes = self.model.num_nodes

        needed = sum(node + 2 for node in range(num_nodes))
        if len(alphas) < needed:
            raise ValueError(
                f"expected at least {needed} architecture vectors for "
                f"{num_nodes} nodes, got {len(alphas)}"
            )

        genotype = []
        offset = 0
        for node_idx in range(num_nodes):
            # Node `node_idx` has one incoming edge per earlier state.
            fan_in = node_idx + 2
            node_alphas = alphas[offset:offset + fan_in]
            offset += fan_in

            # For each node, keep top-2 edges (strongest connections)
            # NOTE(review): the DARTS paper ignores the "none" op when
            # ranking edges; this keeps the plain argmax over all ops.
            edge_scores = []
            for j, alpha in enumerate(node_alphas):
                probs = F.softmax(alpha.detach(), dim=0)
                best_op = int(torch.argmax(probs))
                edge_scores.append((probs[best_op].item(), j, op_names[best_op]))

            top_2 = sorted(edge_scores, reverse=True)[:2]
            for _, src, op_name in top_2:
                genotype.append((op_name, src, node_idx + 2))

        return genotype

NAS with NNI (Neural Network Intelligence)

Microsoft’s NNI provides a complete NAS framework:

# pip install nni

import nni
from nni.nas.nn.pytorch import ModelSpace, MutableConv2d, MutableLinear
import nni.nas.strategy as strategy
import nni.nas.evaluator.pytorch as evaluator

# Define a search space as a ModelSpace
class SearchableNet(ModelSpace):
    """Three-convolution CNN model space with two searchable choices.

    ``conv1_ks`` selects the kernel size of the first convolution (3, 5 or 7)
    and ``mid_channels`` selects the width between the second and third
    convolution (32, 64 or 128).  The classifier head is a fixed 64 -> 10
    linear layer.
    """

    def __init__(self):
        super().__init__()

        # Searchable kernel size for the first convolution.
        first_kernel = nni.choice("conv1_ks", [3, 5, 7])
        self.conv1 = MutableConv2d(3, 32, kernel_size=first_kernel, padding="same")

        # Searchable width: conv2's output channels and conv3's input
        # channels share the same choice.
        width = nni.choice("mid_channels", [32, 64, 128])
        self.conv2 = MutableConv2d(32, width, 3, padding=1)
        self.conv3 = MutableConv2d(width, 64, 3, padding=1)

        # Classifier head; its sizes are fixed, nothing is searched here.
        self.fc = MutableLinear(64, 10)

        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """Apply conv + ReLU three times, global-average-pool, then classify."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))
        pooled = self.pool(x)
        return self.fc(pooled.flatten(1))


# Configure evaluator
# NOTE: `train_loader` and `val_loader` are not defined in this snippet and
# must be created before this point.
eval_fn = evaluator.Classification(
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
    max_epochs=10,
    gpus=1
)

# Choose search strategy
search_strategy = strategy.RegularizedEvolution(population_size=50, sample_size=25)

# Alternative strategies:
# search_strategy = strategy.Random()
# search_strategy = strategy.TPE()
# search_strategy = strategy.DARTS()

# Run the search
from nni.nas.experiment import NasExperiment
# NasExperiment takes a model-space instance, not the class itself.
experiment = NasExperiment(SearchableNet(), eval_fn, search_strategy)
# The trial budget is an experiment-config setting, not an argument of run().
experiment.config.max_trial_number = 200
experiment.run(port=8080)

# Get best architecture
best = experiment.export_top_models(top_k=1)[0]
print(f"Best architecture: {best}")

For simpler search spaces, Optuna provides an accessible approach:

import optuna
import torch
import torch.nn as nn

def create_model(trial):
    """Build an MLP classifier whose shape is sampled from an Optuna trial.

    Sampled per model: the number of hidden layers (2-5).  Sampled per hidden
    layer: its width (32-512, log scale), its activation and its dropout
    rate (0.0-0.5; no dropout layer is added when the rate is 0).  The
    network maps 784 inputs to 10 outputs.
    """
    activation_types = {"relu": nn.ReLU, "gelu": nn.GELU, "silu": nn.SiLU}

    modules = []
    width_in = 784  # MNIST flattened

    # Search number of layers (2-5)
    depth = trial.suggest_int("n_layers", 2, 5)

    for idx in range(depth):
        width_out = trial.suggest_int(f"n_units_l{idx}", 32, 512, log=True)
        modules.append(nn.Linear(width_in, width_out))

        # Search activation function
        act_name = trial.suggest_categorical(
            f"activation_l{idx}", ["relu", "gelu", "silu"]
        )
        modules.append(activation_types[act_name]())

        # Search dropout rate
        drop_p = trial.suggest_float(f"dropout_l{idx}", 0.0, 0.5)
        if drop_p > 0:
            modules.append(nn.Dropout(drop_p))

        # The next layer consumes what this one produced.
        width_in = width_out

    modules.append(nn.Linear(width_in, 10))
    return nn.Sequential(*modules)


def objective(trial):
    """Optuna objective: train a sampled model and return its accuracy.

    The model comes from ``create_model``; the learning rate and optimizer
    class are sampled from the trial as well.  After each of the 20 epochs
    the validation accuracy is reported so that the pruner can stop the
    trial early.  The accuracy of the last epoch is returned.
    """
    net = create_model(trial).cuda()

    # The optimizer and its learning rate are part of the search as well.
    step_size = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    opt_name = trial.suggest_categorical("optimizer", ["Adam", "SGD", "AdamW"])
    opt_cls = getattr(torch.optim, opt_name)
    opt = opt_cls(net.parameters(), lr=step_size)

    for epoch in range(20):
        train_one_epoch(net, opt, train_loader)
        accuracy = evaluate(net, val_loader)

        # Intermediate results let the pruner cut off bad trials early.
        trial.report(accuracy, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()

    return accuracy


# Run search
# The median pruner is inactive during the first 5 reported steps of a trial.
study = optuna.create_study(
    direction="maximize",
    pruner=optuna.pruners.MedianPruner(n_warmup_steps=5)
)
study.optimize(objective, n_trials=100)

print(f"Best trial: {study.best_trial.params}")
print(f"Best accuracy: {study.best_trial.value:.4f}")

# Visualize search
# The plot functions only return a figure; outside a notebook it has to be
# shown explicitly or nothing is displayed.
history_fig = optuna.visualization.plot_optimization_history(study)
history_fig.show()
importance_fig = optuna.visualization.plot_param_importances(study)
importance_fig.show()
import random
import copy
from dataclasses import dataclass

@dataclass
class Architecture:
    """A candidate architecture together with its fitness score."""
    # Maps each search-space key to the option chosen for it.
    genotype: dict
    # Score assigned by the evaluation function; stays 0.0 until evaluated.
    fitness: float = 0.0

def random_architecture(search_space):
    """Sample an architecture by picking one random option per search-space key.

    ``search_space`` maps each key to a sequence of options.  The returned
    ``Architecture`` has the default fitness.
    """
    sampled = {}
    for gene, choices in search_space.items():
        sampled[gene] = random.choice(choices)
    return Architecture(genotype=sampled)

def mutate(arch, search_space, mutation_rate=0.3):
    """Return a mutated copy of ``arch``; the input is left untouched.

    Every search-space key is re-sampled with probability ``mutation_rate``.
    A re-sampled key may receive the value it already had.  The returned
    ``Architecture`` has the default fitness.
    """
    child_genes = copy.deepcopy(arch.genotype)
    for gene, choices in search_space.items():
        if not (random.random() < mutation_rate):
            continue
        child_genes[gene] = random.choice(choices)
    return Architecture(genotype=child_genes)

def crossover(parent1, parent2):
    """Combine two architectures by uniform crossover.

    For every key of ``parent1``'s genotype the child takes the value of
    either parent, chosen at random.  ``parent2`` must contain every key of
    ``parent1``.  The returned ``Architecture`` has the default fitness.
    """
    mixed = {
        gene: random.choice([allele, parent2.genotype[gene]])
        for gene, allele in parent1.genotype.items()
    }
    return Architecture(genotype=mixed)

def evolutionary_search(search_space, evaluate_fn,
                        population_size=50, generations=100,
                        tournament_size=5):
    """Regularized (ageing) evolutionary architecture search.

    Each generation runs a tournament among randomly sampled members, mutates
    the winner, evaluates the child, appends it to the population and removes
    the oldest member.

    Args:
        search_space: Mapping from key to a sequence of options.
        evaluate_fn: Callable taking a genotype dict and returning its
            fitness (higher is better).
        population_size: Number of architectures kept alive.
        generations: Number of children to create and evaluate.
        tournament_size: Number of members sampled per tournament; capped at
            ``population_size``.

    Returns:
        ``(best, history)``.  ``best`` is the fittest ``Architecture``
        evaluated at any point of the search.  ``history[g]`` is the best
        fitness present in the population after generation ``g``.

    Raises:
        ValueError: If ``population_size`` or ``tournament_size`` is below 1.
    """
    if population_size < 1:
        raise ValueError("population_size must be at least 1")
    if tournament_size < 1:
        raise ValueError("tournament_size must be at least 1")

    # Initialize population
    population = [random_architecture(search_space) for _ in range(population_size)]

    # Evaluate initial population
    for arch in population:
        arch.fitness = evaluate_fn(arch.genotype)

    # Ageing removes members regardless of fitness, so the best architecture
    # seen so far has to be tracked outside the population.
    best_ever = max(population, key=lambda a: a.fitness)

    # random.sample cannot draw more members than the population holds.
    contestants = min(tournament_size, population_size)

    history = []

    for gen in range(generations):
        # Tournament selection
        candidates = random.sample(population, contestants)
        parent = max(candidates, key=lambda a: a.fitness)

        # Mutate to create child
        child = mutate(parent, search_space)
        child.fitness = evaluate_fn(child.genotype)
        if child.fitness > best_ever.fitness:
            best_ever = child

        # Add child, remove oldest
        population.append(child)
        population.pop(0)  # FIFO — regularized evolution

        best = max(population, key=lambda a: a.fitness)
        history.append(best.fitness)

        if gen % 10 == 0:
            print(f"Gen {gen}: best = {best.fitness:.4f}, "
                  f"genotype = {best.genotype}")

    return best_ever, history

Hardware-Aware Search with Latency Prediction

class LatencyPredictor:
    """Predict inference latency by summing per-operation table entries.

    The table is keyed by ``(op_name, channels, spatial)`` and is filled at
    construction time by benchmarking every operation in ``OPS``.
    """

    def __init__(self, hardware="jetson_nano"):
        self.hardware = hardware
        self.lookup_table = self._build_lookup_table(hardware)

    def _build_lookup_table(self, hardware):
        """Benchmark each operation for a grid of channel counts and sizes.

        NOTE(review): ``hardware`` is not forwarded to ``benchmark_op``, so
        the measurements come from whatever device ``benchmark_op`` runs on.
        Confirm that this is the intended target.

        Returns:
            Dict mapping ``(op_name, channels, spatial)`` to the value
            returned by ``benchmark_op``.
        """
        table = {}
        for op_name in OPS:
            for channels in [16, 32, 64, 128, 256]:
                for spatial in [56, 28, 14, 7]:
                    # Measure actual latency on a (1, C, H, W) input shape
                    op = OPS[op_name](channels)
                    latency = benchmark_op(op, (1, channels, spatial, spatial))
                    table[(op_name, channels, spatial)] = latency
        return table

    def predict(self, architecture):
        """Predict total latency for an architecture.

        Args:
            architecture: Iterable of layers, each exposing ``op_name``,
                ``channels`` and ``spatial`` attributes.

        Returns:
            The sum of the table entries of all layers.  Layers without a
            table entry contribute 0; a ``RuntimeWarning`` lists them because
            the prediction is then an underestimate.
        """
        import warnings

        total = 0
        missing = set()
        for layer in architecture:
            key = (layer.op_name, layer.channels, layer.spatial)
            if key in self.lookup_table:
                total += self.lookup_table[key]
            else:
                missing.add(key)

        if missing:
            warnings.warn(
                "no latency entry for "
                f"{sorted(missing, key=repr)}; counted as 0, so the "
                "prediction underestimates the real latency",
                RuntimeWarning,
                stacklevel=2,
            )
        return total


def hardware_aware_objective(genotype, latency_target_ms=20.0):
    """Score a genotype by accuracy, discounted when it misses the latency target.

    Within the target the plain accuracy is returned.  Above it the accuracy
    is divided by ``(latency / latency_target_ms) ** 1.5``, so the discount
    grows with the overshoot.
    """
    model = build_model(genotype)
    accuracy = train_and_evaluate(model)
    latency = latency_predictor.predict(genotype)

    # Fast enough: no discount.
    if not latency > latency_target_ms:
        return accuracy

    overshoot_ratio = latency / latency_target_ms
    return accuracy / overshoot_ratio ** 1.5

Zero-Cost NAS Proxies

Estimate architecture quality without training:

def compute_synflow_score(model, input_shape):
    """SynFlow: training-free architecture scoring.

    All parameters are temporarily replaced by their absolute values, an
    all-ones input is pushed through the network, and the score is the sum
    over every parameter element of ``param * grad``.  A higher score is
    taken as a sign of better trainability.

    The model is left as it was found: parameter signs and the train/eval
    mode are restored, and the gradients produced here are cleared.  Any
    gradients present before the call are discarded.

    Args:
        model: ``nn.Module`` to score.
        input_shape: Shape of one input sample, without the batch dimension.

    Returns:
        The score as a Python number.
    """
    params = list(model.parameters())
    # Build the probe input on the device the model already lives on.
    device = params[0].device if params else torch.device("cpu")
    was_training = model.training

    # Set all parameters to positive, remembering the signs for later.
    with torch.no_grad():
        signs = [p.sign() for p in params]
        for p in params:
            p.abs_()

    # Stale gradients would be added to the ones computed below.
    model.zero_grad()
    model.eval()
    try:
        # Forward pass with ones input
        inputs = torch.ones(1, *input_shape, device=device)
        loss = model(inputs).sum()

        # Backward pass
        loss.backward()

        # Score = sum of (param * grad) over all parameters
        score = 0
        for p in params:
            if p.grad is not None:
                score += (p * p.grad).sum().item()
    finally:
        # Undo every change made to the caller's model.
        with torch.no_grad():
            for p, sign in zip(params, signs):
                p.mul_(sign)
        model.zero_grad()
        model.train(was_training)

    return score


def rank_architectures_zero_cost(candidates, input_shape):
    """Rank candidate architectures without any training.

    Args:
        candidates: Iterable of genotypes accepted by ``build_model``.
        input_shape: Shape of one input sample, without the batch dimension.

    Returns:
        List of ``(score, genotype)`` tuples, highest score first.
        Candidates with equal scores keep their input order.
    """
    scores = []
    for genotype in candidates:
        model = build_model(genotype).cuda()
        score = compute_synflow_score(model, input_shape)
        scores.append((score, genotype))

    # Sort by score (higher is better).  Sorting on the score alone matters:
    # comparing whole tuples falls through to comparing the genotypes on a
    # tie, which raises TypeError for dict genotypes.
    scores.sort(key=lambda pair: pair[0], reverse=True)
    return scores

NAS Cost Comparison

| Method | GPU-Days | Accuracy (CIFAR-10) | Notes |
|---|---|---|---|
| NASNet (RL) | 450 | 97.35% | Original — prohibitively expensive |
| AmoebaNet (Evolution) | 150 | 97.45% | Better than RL, still costly |
| DARTS | 1 | 97.24% | Breakthrough in efficiency |
| ProxylessNAS | 4 | 97.10% | Hardware-aware |
| Zero-Cost NAS | 0.01 | ~96.5% | Seconds, not days |
| Random Search (strong baseline) | 4 | 96.9% | Surprisingly competitive |

The one thing to remember: Practical NAS in Python ranges from Optuna-based hyperparameter search (hours, accessible) to DARTS differentiable search (GPU-days, advanced) to evolutionary strategies with hardware-aware objectives — where the right approach depends on your compute budget, search space complexity, and whether you need to hit specific hardware latency targets.

Tags: python, machine-learning, model-optimization

See Also