Neural Architecture Search with Python — Deep Dive
Differentiable Architecture Search (DARTS)
DARTS turns architecture search into a continuous optimization problem. Instead of making a discrete “choose operation A or B” decision, each edge computes a softmax-weighted mixture of all candidate operations, and the mixture weights (the architecture parameters) are learned by gradient descent alongside the model weights.
DARTS Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
# Available operations in the search space.
# Maps each candidate operation name to a factory that takes the channel
# count C and returns the module implementing that operation.
# The insertion order of this dict is significant: MixedOp instantiates its
# operations by iterating OPS, and architecture weights are matched to the
# operations (and back to their names) purely by position.
# NOTE(review): SepConv, DilConv and Zero are not defined in this section;
# their positional arguments are presumably (C_in, C_out, kernel_size, padding)
# and Zero presumably outputs zeros — confirm against their definitions.
OPS = {
"sep_conv_3x3": lambda C: SepConv(C, C, 3, 1),
"sep_conv_5x5": lambda C: SepConv(C, C, 5, 2),
"dil_conv_3x3": lambda C: DilConv(C, C, 3, 2),
"max_pool_3x3": lambda C: nn.MaxPool2d(3, 1, 1),
"avg_pool_3x3": lambda C: nn.AvgPool2d(3, 1, 1),
"skip_connect": lambda C: nn.Identity(),
"none": lambda C: Zero(),
}
class MixedOp(nn.Module):
    """Continuous relaxation of a single edge: a blend of every candidate op.

    One module is instantiated per entry of ``OPS`` (in ``OPS`` iteration
    order); ``forward`` returns their weighted sum.
    """

    def __init__(self, channels):
        super().__init__()
        # Keep OPS order so that weights[k] always refers to the k-th op name.
        candidates = [build(channels) for build in OPS.values()]
        self.ops = nn.ModuleList(candidates)

    def forward(self, x, weights):
        """Return ``sum_k weights[k] * ops[k](x)``.

        ``weights`` is paired with the operations positionally; surplus
        entries on either side are ignored by ``zip``.
        """
        mixed = 0
        for weight, op in zip(weights, self.ops):
            mixed = mixed + weight * op(x)
        return mixed
class DARTSCell(nn.Module):
    """Searchable cell whose edges are mixed operations with learnable weights.

    The cell has two input states and ``num_nodes`` intermediate nodes.
    Intermediate node ``i`` (state index ``i + 2``) receives one ``MixedOp``
    edge from every earlier state, including both cell inputs. Edges are
    stored under the key ``"<source>_<destination>"``.

    Args:
        channels: Channel count handed to every ``MixedOp``.
        num_nodes: Number of intermediate nodes.
    """

    def __init__(self, channels, num_nodes=4):
        super().__init__()
        self.num_nodes = num_nodes

        # One mixed operation per (source state -> intermediate node) edge.
        self.edges = nn.ModuleDict()
        for dst in range(2, num_nodes + 2):
            for src in range(dst):
                self.edges[f"{src}_{dst}"] = MixedOp(channels)

        # One learnable logit vector per edge, initialised close to zero so
        # every operation starts with nearly equal softmax weight.
        num_ops = len(OPS)
        self.arch_params = nn.ParameterList(
            [nn.Parameter(torch.randn(num_ops) * 1e-3) for _ in self.edges]
        )

    def forward(self, s0, s1):
        """Compute every intermediate node and concatenate them along dim 1."""
        states = [s0, s1]
        # Architecture parameters are consumed in the same order in which the
        # edges were created: node by node, sources in ascending order.
        alpha_stream = iter(self.arch_params)
        for dst in range(2, self.num_nodes + 2):
            incoming = []
            for src, state in enumerate(states):
                key = f"{src}_{dst}"
                if key not in self.edges:
                    continue
                mix = F.softmax(next(alpha_stream), dim=0)
                incoming.append(self.edges[key](state, mix))
            states.append(sum(incoming))
        # The two cell inputs are not part of the output.
        return torch.cat(states[2:], dim=1)
Bi-Level Optimization
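DARTS alternates between optimizing model weights (w) and architecture parameters (α): α is updated on a validation batch and w on a training batch. The trainer below uses the simple first-order scheme, in which each α step treats the current w as fixed: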
class DARTSTrainer:
    """First-order DARTS trainer alternating architecture and weight updates.

    The wrapped ``model`` must be callable on a batch of inputs and expose
    ``weight_params()``, ``arch_params()`` and ``num_nodes``.

    Args:
        model: Searchable network to train.
        train_loader: Iterable of ``(inputs, labels)`` batches used for the
            weight updates.
        val_loader: Iterable of ``(inputs, labels)`` batches used for the
            architecture updates.
        device: Device every batch is moved to. Defaults to the device of the
            model's first weight parameter (CPU when there are none).
    """

    def __init__(self, model, train_loader, val_loader, device=None):
        self.model = model
        weights = list(model.weight_params())
        alphas = list(model.arch_params())
        if device is None:
            device = weights[0].device if weights else torch.device("cpu")
        self.device = device

        # Separate optimizers for weights and architecture
        self.w_optimizer = torch.optim.SGD(
            weights, lr=0.025, momentum=0.9, weight_decay=3e-4
        )
        self.alpha_optimizer = torch.optim.Adam(
            alphas, lr=3e-4, weight_decay=1e-3
        )

        # Keep the loaders so an exhausted iterator can be restarted.
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.train_iter = iter(train_loader)
        self.val_iter = iter(val_loader)

    def _next_batch(self, split):
        """Return the next ``(inputs, labels)`` of ``split`` on ``self.device``.

        ``split`` is ``"train"`` or ``"val"``. When the current iterator is
        exhausted a new pass over the loader is started, so ``step`` can be
        called for more than one epoch. An empty loader still raises
        ``StopIteration``.
        """
        loader = self.train_loader if split == "train" else self.val_loader
        attr = f"{split}_iter"
        try:
            inputs, labels = next(getattr(self, attr))
        except StopIteration:
            restarted = iter(loader)
            setattr(self, attr, restarted)
            inputs, labels = next(restarted)
        return inputs.to(self.device), labels.to(self.device)

    def step(self):
        """Run one alternating update and return ``(train_loss, val_loss)``."""
        # Step 1: Update architecture params on validation data
        val_inputs, val_labels = self._next_batch("val")
        self.alpha_optimizer.zero_grad()
        val_loss = F.cross_entropy(self.model(val_inputs), val_labels)
        val_loss.backward()
        self.alpha_optimizer.step()

        # Step 2: Update model weights on training data. zero_grad() also
        # discards the weight gradients produced by the validation pass.
        train_inputs, train_labels = self._next_batch("train")
        self.w_optimizer.zero_grad()
        train_loss = F.cross_entropy(self.model(train_inputs), train_labels)
        train_loss.backward()
        self.w_optimizer.step()
        return train_loss.item(), val_loss.item()

    def derive_architecture(self, op_names=None):
        """Extract a discrete architecture from the continuous parameters.

        For every intermediate node the two incoming edges with the strongest
        best operation are kept. The "none" operation is never selected,
        following the DARTS derivation rule.

        Assumes ``model.arch_params()`` yields one logit vector per edge,
        ordered node by node with sources ascending (the order used by
        ``DARTSCell``), so node ``i`` owns ``i + 2`` consecutive vectors.

        Args:
            op_names: Operation names matching the positions of each logit
                vector. Defaults to the keys of ``OPS``.

        Returns:
            List of ``(op_name, source_index, node_index)`` tuples, where
            ``node_index`` counts the two cell inputs as states 0 and 1.

        Raises:
            ValueError: If the model exposes fewer architecture vectors than
                ``num_nodes`` requires.
        """
        if op_names is None:
            op_names = list(OPS)
        alphas = list(self.model.arch_params())
        num_nodes = self.model.num_nodes
        required = sum(node + 2 for node in range(num_nodes))
        if len(alphas) < required:
            raise ValueError(
                f"expected at least {required} architecture vectors for "
                f"{num_nodes} nodes, got {len(alphas)}"
            )

        genotype = []
        offset = 0
        for node_idx in range(num_nodes):
            fan_in = node_idx + 2
            node_alphas = alphas[offset:offset + fan_in]
            offset += fan_in

            # For each node, keep top-2 edges (strongest connections)
            edge_scores = []
            for src, alpha in enumerate(node_alphas):
                probs = F.softmax(alpha.detach(), dim=0).tolist()
                ranked = [
                    (prob, name)
                    for prob, name in zip(probs, op_names)
                    if name != "none"
                ]
                if not ranked:
                    continue
                score, op_name = max(ranked, key=lambda item: item[0])
                edge_scores.append((score, src, op_name))

            top_2 = sorted(edge_scores, key=lambda e: e[0], reverse=True)[:2]
            genotype.extend(
                (op_name, src, node_idx + 2) for _, src, op_name in top_2
            )
        return genotype
NAS with NNI (Neural Network Intelligence)
Microsoft’s NNI provides a complete NAS framework:
# pip install nni
import nni
from nni.nas.nn.pytorch import ModelSpace, MutableConv2d, MutableLinear
import nni.nas.strategy as strategy
import nni.nas.evaluator.pytorch as evaluator
# Define a search space as a ModelSpace
class SearchableNet(ModelSpace):
    """Three-convolution classifier with two searchable dimensions.

    Search dimensions (NNI labels):
        conv1_ks: kernel size of the first convolution, one of 3, 5 or 7.
        mid_channels: width between ``conv2`` and ``conv3``, one of 32, 64
            or 128.

    The classifier head is a fixed 64 -> 10 linear layer; its depth is not
    part of the search space.
    """

    def __init__(self):
        super().__init__()
        # Searchable kernel size; "same" padding keeps the spatial size
        # independent of the kernel that gets chosen.
        first_kernel = nni.choice("conv1_ks", [3, 5, 7])
        self.conv1 = MutableConv2d(3, 32, kernel_size=first_kernel, padding="same")

        # Searchable width shared by conv2's output and conv3's input.
        hidden_width = nni.choice("mid_channels", [32, 64, 128])
        self.conv2 = MutableConv2d(32, hidden_width, 3, padding=1)
        self.conv3 = MutableConv2d(hidden_width, 64, 3, padding=1)

        # Fixed classifier head applied to the globally pooled features.
        self.fc = MutableLinear(64, 10)
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """Apply the three ReLU convolutions, global pooling and the head."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))
        pooled = self.pool(x).flatten(1)
        return self.fc(pooled)
# Configure evaluator: trains and validates each sampled architecture.
# NOTE(review): train_loader and val_loader are not defined in this snippet;
# they must be created before this point.
eval_fn = evaluator.Classification(
train_dataloaders=train_loader,
val_dataloaders=val_loader,
max_epochs=10,
gpus=1
)
# Choose search strategy
search_strategy = strategy.RegularizedEvolution(population_size=50, sample_size=25)
# Alternative strategies:
# search_strategy = strategy.Random()
# search_strategy = strategy.TPE()
# search_strategy = strategy.DARTS()
# NOTE(review): DARTS is a one-shot strategy and presumably needs a matching
# evaluator/model-space setup rather than being a drop-in swap — confirm
# against the NNI documentation.
# Run the search
from nni.nas.experiment import NasExperiment
# NOTE(review): the SearchableNet class itself is passed here, not an instance;
# NNI's examples appear to pass an instantiated model space — confirm.
experiment = NasExperiment(SearchableNet, eval_fn, search_strategy)
# NOTE(review): verify that run() accepts max_trial_number in the installed NNI
# version; it may have to be set on experiment.config instead.
experiment.run(max_trial_number=200, port=8080)
# Get best architecture: export_top_models returns a list, so take its first
# (and, with top_k=1, only) entry.
best = experiment.export_top_models(top_k=1)[0]
print(f"Best architecture: {best}")
Optuna-Based Architecture Search
For simpler search spaces, Optuna provides an accessible approach:
import optuna
import torch
import torch.nn as nn
def create_model(trial, in_features=784, num_classes=10):
    """Build an MLP whose depth, widths, activations and dropout are searched.

    Args:
        trial: Optuna trial (or any object offering ``suggest_int``,
            ``suggest_categorical`` and ``suggest_float``).
        in_features: Size of the flattened input vector. The default of 784
            corresponds to flattened 28x28 MNIST images.
        num_classes: Number of output logits.

    Returns:
        ``nn.Sequential`` made of 2-5 hidden blocks (Linear -> activation ->
        optional Dropout) followed by a final Linear classifier.
    """
    activations = {"relu": nn.ReLU, "gelu": nn.GELU, "silu": nn.SiLU}
    layers = []
    width = in_features

    # Search number of layers (2-5)
    n_layers = trial.suggest_int("n_layers", 2, 5)
    for i in range(n_layers):
        # Log-scale sampling spreads trials evenly across orders of magnitude.
        out_features = trial.suggest_int(f"n_units_l{i}", 32, 512, log=True)
        layers.append(nn.Linear(width, out_features))

        # Search activation function
        activation = trial.suggest_categorical(
            f"activation_l{i}", ["relu", "gelu", "silu"]
        )
        layers.append(activations[activation]())

        # Search dropout rate; a sampled rate of exactly 0 adds no layer.
        dropout = trial.suggest_float(f"dropout_l{i}", 0.0, 0.5)
        if dropout > 0:
            layers.append(nn.Dropout(dropout))

        width = out_features

    layers.append(nn.Linear(width, num_classes))
    return nn.Sequential(*layers)
def objective(trial):
    """Optuna objective: train a sampled model and return its last accuracy.

    Besides the architecture (sampled by ``create_model``), the learning rate
    and the optimizer class are searched as well. Accuracy is reported after
    each of the 20 epochs so the pruner can stop unpromising trials.

    Raises:
        optuna.TrialPruned: When the pruner decides to stop this trial.
    """
    net = create_model(trial).cuda()

    # Optimizer hyperparameters are part of the search space too.
    learning_rate = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    optimizer_cls = getattr(
        torch.optim,
        trial.suggest_categorical("optimizer", ["Adam", "SGD", "AdamW"]),
    )
    optimizer = optimizer_cls(net.parameters(), lr=learning_rate)

    num_epochs = 20
    for epoch in range(num_epochs):
        train_one_epoch(net, optimizer, train_loader)
        accuracy = evaluate(net, val_loader)

        # Hand the intermediate result to the pruner and obey its verdict.
        trial.report(accuracy, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()

    return accuracy
# Run search: maximize the accuracy returned by `objective`.
# The median pruner is configured with n_warmup_steps=5.
# NOTE(review): presumably this disables pruning for the first 5 reported
# steps of each trial — confirm against the Optuna documentation.
study = optuna.create_study(
direction="maximize",
pruner=optuna.pruners.MedianPruner(n_warmup_steps=5)
)
study.optimize(objective, n_trials=100)
print(f"Best trial: {study.best_trial.params}")
print(f"Best accuracy: {study.best_trial.value:.4f}")
# Visualize search
# NOTE(review): the return values of the plotting calls are discarded here;
# outside a notebook the returned figures presumably need an explicit
# .show() — confirm.
optuna.visualization.plot_optimization_history(study)
optuna.visualization.plot_param_importances(study)
Evolutionary Architecture Search
import random
import copy
from dataclasses import dataclass
@dataclass
class Architecture:
    """A candidate architecture together with the fitness assigned to it."""

    # Mapping from search-space key to the option chosen for that key.
    genotype: dict
    # Score written by the search loop after evaluation; 0.0 until then.
    fitness: float = 0.0
def random_architecture(search_space):
    """Sample an architecture by picking one random option per search key.

    ``search_space`` maps each key to a non-empty sequence of options.
    """
    sampled = {}
    for name, choices in search_space.items():
        sampled[name] = random.choice(choices)
    return Architecture(genotype=sampled)
def mutate(arch, search_space, mutation_rate=0.3):
    """Return a new architecture with some choices of ``arch`` resampled.

    Every search-space key is resampled independently with probability
    ``mutation_rate``. The parent is left untouched because its genotype is
    deep-copied first; the child starts with the default fitness.
    """
    mutated = copy.deepcopy(arch.genotype)
    for gene, choices in search_space.items():
        resample = random.random() < mutation_rate
        if resample:
            mutated[gene] = random.choice(choices)
    return Architecture(genotype=mutated)
def crossover(parent1, parent2):
    """Uniform crossover: each gene is copied from a randomly chosen parent.

    The child has exactly the keys of ``parent1``; ``parent2`` must provide
    every one of them.
    """
    genes_a = parent1.genotype
    genes_b = parent2.genotype
    child = {
        gene: random.choice([genes_a[gene], genes_b[gene]])
        for gene in genes_a
    }
    return Architecture(genotype=child)
def evolutionary_search(search_space, evaluate_fn,
                        population_size=50, generations=100,
                        tournament_size=5):
    """Regularized (ageing) evolutionary architecture search.

    Each generation samples a tournament from the population, mutates its
    fittest member into a child, evaluates the child and replaces the oldest
    individual with it.

    Args:
        search_space: Mapping from choice name to a sequence of options.
        evaluate_fn: Callable that maps a genotype dict to a fitness value
            (higher is better).
        population_size: Number of individuals kept alive.
        generations: Number of children produced.
        tournament_size: Individuals sampled per tournament; clamped to the
            population size.

    Returns:
        ``(best, history)`` where ``best`` is the fittest architecture
        evaluated at any point of the search and ``history`` holds the best
        fitness of the live population after each generation.

    Raises:
        ValueError: If ``population_size`` or ``tournament_size`` is below 1.
    """
    if population_size < 1:
        raise ValueError("population_size must be at least 1")
    if tournament_size < 1:
        raise ValueError("tournament_size must be at least 1")

    def by_fitness(candidate):
        return candidate.fitness

    # Initialize population
    population = [random_architecture(search_space) for _ in range(population_size)]
    # Evaluate initial population
    for arch in population:
        arch.fitness = evaluate_fn(arch.genotype)

    # Ageing removes individuals regardless of fitness, so the best one may
    # leave the population; remember it separately.
    best_ever = max(population, key=by_fitness)
    contestants = min(tournament_size, len(population))

    history = []
    for gen in range(generations):
        # Tournament selection
        candidates = random.sample(population, contestants)
        parent = max(candidates, key=by_fitness)

        # Mutate to create child
        child = mutate(parent, search_space)
        child.fitness = evaluate_fn(child.genotype)
        if child.fitness > best_ever.fitness:
            best_ever = child

        # Add child, remove oldest
        population.append(child)
        population.pop(0)  # FIFO — regularized evolution

        best = max(population, key=by_fitness)
        history.append(best.fitness)
        if gen % 10 == 0:
            print(f"Gen {gen}: best = {best.fitness:.4f}, "
                  f"genotype = {best.genotype}")

    return best_ever, history
Hardware-Aware Search with Latency Prediction
class LatencyPredictor:
    """Lookup-table latency model: an architecture costs the sum of its layers.

    The table is keyed by ``(op_name, channels, spatial)`` and filled once at
    construction time by benchmarking every operation in ``OPS``.
    """

    def __init__(self, hardware="jetson_nano"):
        self.lookup_table = self._build_lookup_table(hardware)

    def _build_lookup_table(self, hardware):
        """Benchmark every (operation, channels, spatial size) combination.

        ``hardware`` is accepted but not used by this method; the measurement
        is delegated to ``benchmark_op``.
        """
        channel_grid = [16, 32, 64, 128, 256]
        spatial_grid = [56, 28, 14, 7]
        table = {}
        for op_name, make_op in OPS.items():
            for channels in channel_grid:
                for spatial in spatial_grid:
                    # A fresh module is built for every measurement.
                    input_shape = (1, channels, spatial, spatial)
                    table[(op_name, channels, spatial)] = benchmark_op(
                        make_op(channels), input_shape
                    )
        return table

    def predict(self, architecture):
        """Return the summed table latency of all layers in ``architecture``.

        Each layer must expose ``op_name``, ``channels`` and ``spatial``.
        Layers without a table entry contribute 0.
        """
        return sum(
            self.lookup_table.get(
                (layer.op_name, layer.channels, layer.spatial), 0
            )
            for layer in architecture
        )
def hardware_aware_objective(genotype, latency_target_ms=20.0):
    """Score an architecture by accuracy, discounted when it is too slow.

    Within the latency budget the accuracy is returned unchanged. Above it,
    the accuracy is divided by ``(latency / latency_target_ms) ** 1.5``, so
    the discount grows with the size of the overshoot.
    """
    accuracy = train_and_evaluate(build_model(genotype))
    latency = latency_predictor.predict(genotype)

    over_budget = latency > latency_target_ms
    if not over_budget:
        return accuracy

    # Soft penalty: always above 1 here and increasing with the overshoot.
    overshoot_ratio = latency / latency_target_ms
    return accuracy / overshoot_ratio ** 1.5
Zero-Cost NAS Proxies
Estimate architecture quality without training:
def compute_synflow_score(model, input_shape):
    """SynFlow: training-free architecture scoring.

    Replaces every parameter by its absolute value, pushes an all-ones input
    through the network and sums ``param * grad`` over all parameters.

    The model is left as it was found: parameter values and the train/eval
    flag are restored afterwards. Gradients are cleared before scoring (so
    stale gradients cannot leak into the score) and again afterwards.

    Args:
        model: Network to score. The input is created on the device of its
            first parameter (CPU if it has none).
        input_shape: Shape of a single input sample, without batch dimension.

    Returns:
        The SynFlow score as a Python number; 0 if no parameter received a
        gradient.
    """
    params = list(model.parameters())
    device = params[0].device if params else torch.device("cpu")
    was_training = model.training
    saved_values = [p.detach().clone() for p in params]

    for p in params:
        p.grad = None
    model.eval()
    try:
        # Set all parameters to positive
        with torch.no_grad():
            for p in params:
                p.abs_()

        # Forward pass with ones input
        inputs = torch.ones(1, *input_shape, device=device)
        loss = model(inputs).sum()

        # Backward pass
        loss.backward()

        # Score = sum of (param * grad) over every parameter
        score = 0
        for p in params:
            if p.grad is not None:
                score += (p.detach() * p.grad).sum().item()
    finally:
        # Undo the linearisation even if the forward or backward pass failed.
        with torch.no_grad():
            for p, original in zip(params, saved_values):
                p.copy_(original)
        for p in params:
            p.grad = None
        model.train(was_training)
    return score
def rank_architectures_zero_cost(candidates, input_shape):
    """Rank candidate architectures without any training.

    Args:
        candidates: Iterable of genotypes accepted by ``build_model``.
        input_shape: Shape of a single input sample, without batch dimension.

    Returns:
        List of ``(score, genotype)`` tuples, highest score first. Candidates
        with equal scores keep their input order.
    """
    # Use the GPU when there is one, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    scores = []
    for genotype in candidates:
        model = build_model(genotype).to(device)
        score = compute_synflow_score(model, input_shape)
        scores.append((score, genotype))

    # Sort by score only (higher is better). Comparing whole tuples would
    # fall back to comparing genotypes on ties, which fails for dicts.
    scores.sort(key=lambda pair: pair[0], reverse=True)
    return scores
NAS Cost Comparison
| Method | GPU-Days | Accuracy (CIFAR-10) | Notes |
|---|---|---|---|
| NASNet (RL) | 450 | 97.35% | Original — prohibitively expensive |
| AmoebaNet (Evolution) | 150 | 97.45% | Better than RL, still costly |
| DARTS | 1 | 97.24% | Breakthrough in efficiency |
| ProxylessNAS | 4 | 97.10% | Hardware-aware |
| Zero-Cost NAS | 0.01 | ~96.5% | Seconds, not days |
| Random Search (strong baseline) | 4 | 96.9% | Surprisingly competitive |
The one thing to remember: Practical NAS in Python ranges from Optuna-based hyperparameter search (hours, accessible) to DARTS differentiable search (GPU-days, advanced) to evolutionary strategies with hardware-aware objectives — where the right approach depends on your compute budget, search space complexity, and whether you need to hit specific hardware latency targets.
See Also
- Python Hyperparameter Tuning — Learn why adjusting the dials on a computer's learning recipe makes predictions way better.
- Python Knowledge Distillation — How a big expert AI teaches a tiny student AI to be almost as smart — like a professor writing a cheat sheet for an exam.
- Python Model Compression Methods — All the ways Python developers shrink massive AI models to fit on phones and tiny devices — like packing for a trip with a carry-on bag.
- Python Model Pruning Techniques — Why cutting away parts of an AI's brain can make it faster without making it dumber.
- Python PyTorch Quantization — How shrinking numbers inside an AI model makes it run faster on phones and cheaper servers without losing much accuracy.