Stable-Baselines3 — Deep Dive

Architecture overview

SB3 is built on three layers:

  1. Algorithms: PPO, SAC, TD3, etc. On-policy algorithms (PPO, A2C) inherit from OnPolicyAlgorithm and off-policy ones (SAC, TD3, DQN) from OffPolicyAlgorithm, both subclasses of BaseAlgorithm. Each algorithm implements train(), while its parent class drives data collection through collect_rollouts() (and _sample_action() on the off-policy side).
  2. Policies: neural networks that map observations to actions. Policies inherit from BasePolicy and implement _predict().
  3. Buffers: RolloutBuffer for on-policy, ReplayBuffer for off-policy, plus DictRolloutBuffer and DictReplayBuffer for dictionary observations.

Understanding this stack lets you override exactly the right layer for custom behaviour.
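To make the division of labour concrete, here is a toy, pure-Python sketch of the same three-layer split. ToyAlgorithm, ToyPolicy, AlwaysZeroPolicy and ToyBuffer are hypothetical stand-ins, not SB3 classes: the algorithm alternates collect_rollouts() and train(), the policy only maps observations to actions, and the buffer only stores transitions.

```python
import random
from dataclasses import dataclass, field


@dataclass
class ToyBuffer:
    """Buffer layer: only stores (obs, action, reward) transitions."""
    data: list = field(default_factory=list)

    def add(self, obs: float, action: int, reward: float) -> None:
        self.data.append((obs, action, reward))

    def reset(self) -> None:
        self.data.clear()


class ToyPolicy:
    """Policy layer: only maps an observation to an action."""
    def predict(self, obs: float) -> int:
        return 1 if obs > 0 else 0


class AlwaysZeroPolicy(ToyPolicy):
    """A custom policy: overriding this layer leaves the others untouched."""
    def predict(self, obs: float) -> int:
        return 0


class ToyAlgorithm:
    """Algorithm layer: owns a policy and a buffer, alternates collection and training."""
    def __init__(self, policy: ToyPolicy, buffer: ToyBuffer, seed: int = 0):
        self.policy, self.buffer = policy, buffer
        self.rng = random.Random(seed)
        self.updates = 0

    def collect_rollouts(self, n_steps: int) -> None:
        self.buffer.reset()
        for _ in range(n_steps):
            obs = self.rng.uniform(-1.0, 1.0)
            action = self.policy.predict(obs)
            correct = 1 if obs > 0 else 0  # the action this toy task rewards
            self.buffer.add(obs, action, 1.0 if action == correct else 0.0)

    def train(self) -> float:
        # A real algorithm would compute losses and update weights here;
        # this toy just counts updates and reports the mean reward.
        self.updates += 1
        return sum(r for _, _, r in self.buffer.data) / len(self.buffer.data)

    def learn(self, n_iterations: int, n_steps: int) -> float:
        score = 0.0
        for _ in range(n_iterations):
            self.collect_rollouts(n_steps)
            score = self.train()
        return score
```

Swapping in AlwaysZeroPolicy changes behaviour without touching the algorithm or the buffer, which is the kind of targeted override the layered design allows.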

Custom feature extractors

When the default network does not suit your observations (for example, image inputs that need a different CNN), write a custom feature extractor by subclassing BaseFeaturesExtractor:

import torch
import torch.nn as nn
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

class CustomCNN(BaseFeaturesExtractor):
    def __init__(self, observation_space: spaces.Box, features_dim: int = 256):
        super().__init__(observation_space, features_dim)
        n_channels = observation_space.shape[0]  # SB3 passes image observations channel-first (C, H, W)
        self.cnn = nn.Sequential(
            nn.Conv2d(n_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
        )
        # Compute flat size by forward pass
        with torch.no_grad():
            sample = torch.as_tensor(observation_space.sample()[None]).float()
            n_flatten = self.cnn(sample).shape[1]
        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        return self.linear(self.cnn(observations))

# `env` is assumed to be an environment with image observations
model = PPO(
    "CnnPolicy", env,
    policy_kwargs=dict(
        features_extractor_class=CustomCNN,
        features_extractor_kwargs=dict(features_dim=256),
    ),
)
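The forward pass in CustomCNN's constructor computes n_flatten automatically. Doing the same arithmetic by hand shows what it finds and which inputs are too small. The helpers below are hypothetical and use the standard Conv2d output-size formula with no padding and dilation 1; for an 84×84 input (an Atari-style assumption) the three layers give 20, 9 and 7 pixels per side, so n_flatten = 64 × 7 × 7 = 3136.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Output height/width of nn.Conv2d with dilation 1 (floor division)."""
    out = (size + 2 * padding - kernel) // stride + 1
    if out < 1:
        raise ValueError(f"input size {size} too small for kernel {kernel}")
    return out


def flat_dim_after_cnn(height: int, width: int, out_channels: int = 64) -> int:
    """n_flatten for the three conv layers in CustomCNN (no padding)."""
    for kernel, stride in [(8, 4), (4, 2), (3, 1)]:  # (kernel_size, stride) per layer
        height = conv_out(height, kernel, stride)
        width = conv_out(width, kernel, stride)
    return out_channels * height * width


print(flat_dim_after_cnn(84, 84))  # 3136 = 64 * 7 * 7
```

Inputs smaller than 36×36 leave nothing for the third convolution, which the forward-pass trick would surface as a PyTorch error at construction time.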

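For multi-input (Dict) observation spaces, use "MultiInputPolicy", whose default extractor is CombinedExtractor: it passes image keys through NatureCNN, flattens the other keys, and concatenates the results. For custom per-key processing, subclass BaseFeaturesExtractor, process each key independently, and concatenate before returning.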
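The resulting feature size is simply the sum over keys. The hypothetical helper below does that bookkeeping under a simplifying assumption: any 3-D shape is treated as an image that a CNN maps to cnn_output_dim features (256 is CombinedExtractor's default), and every other key is flattened. SB3's actual image detection is stricter (it also checks dtype, bounds and channel count), and Discrete keys are one-hot encoded rather than flattened.

```python
from math import prod


def combined_features_dim(key_shapes: dict, cnn_output_dim: int = 256) -> int:
    """Total features after per-key processing and concatenation (simplified)."""
    total = 0
    for key, shape in key_shapes.items():
        if len(shape) == 3:
            # Assumed image key: a CNN compresses it to cnn_output_dim features
            total += cnn_output_dim
        else:
            # Vector key: flattened as-is
            total += prod(shape)
    return total


print(combined_features_dim({"image": (3, 84, 84), "state": (7,)}))  # 256 + 7 = 263
```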

Hyperparameter tuning with Optuna

SB3’s RL Zoo project bundles Optuna-based tuning. You can also do it manually:

import optuna
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy

def objective(trial):
    lr = trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)
    n_steps = trial.suggest_categorical("n_steps", [256, 512, 1024, 2048])
    gamma = trial.suggest_float("gamma", 0.9, 0.9999, log=True)
    ent_coef = trial.suggest_float("ent_coef", 1e-8, 0.1, log=True)

    model = PPO(
        "MlpPolicy", "LunarLander-v3",  # needs gymnasium[box2d]
        learning_rate=lr, n_steps=n_steps,
        gamma=gamma, ent_coef=ent_coef,
        verbose=0,
    )
    model.learn(total_timesteps=50_000)
    # Score on a fresh environment rather than the training env
    eval_env = make_vec_env("LunarLander-v3", n_envs=1)
    mean_reward, _ = evaluate_policy(model, eval_env, n_eval_episodes=20)
    eval_env.close()
    return mean_reward

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=100)
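The log=True flag matters for parameters spanning orders of magnitude, such as learning_rate and ent_coef. The sketch below shows the idea with plain random sampling in log space. sample_log_uniform is a hypothetical helper, and Optuna's default TPE sampler adapts its proposals after the first trials, so this only illustrates the underlying scale: about half the draws land below 1e-4, whereas uniform sampling over [1e-5, 1e-3] would put only about 9% there.

```python
import math
import random


def sample_log_uniform(low: float, high: float, rng: random.Random) -> float:
    """Draw x so that log(x) is uniform on [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


rng = random.Random(0)
samples = [sample_log_uniform(1e-5, 1e-3, rng) for _ in range(10_000)]
below_1e4 = sum(s < 1e-4 for s in samples) / len(samples)
uniform_fraction = (1e-4 - 1e-5) / (1e-3 - 1e-5)  # same question under linear sampling
print(round(below_1e4, 2), round(uniform_fraction, 2))
```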

Key tuning parameters per algorithm:

| Algorithm | Critical hyperparameters |
|---|---|
| PPO | learning_rate, n_steps, batch_size, n_epochs, ent_coef, clip_range, gamma |
| SAC | learning_rate, buffer_size, batch_size, tau, gamma, ent_coef (auto or fixed) |
| DQN | learning_rate, buffer_size, exploration_fraction, target_update_interval |
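For PPO, n_steps, batch_size and n_epochs interact through the rollout buffer, which holds n_steps × n_envs transitions and is split into minibatches of batch_size on every epoch. The hypothetical helper below does that arithmetic; with SB3's PPO defaults (n_steps=2048, batch_size=64, n_epochs=10) and one env, each update performs 320 gradient steps. SB3 warns when the buffer size is not a multiple of batch_size, because the last minibatch of each epoch is then truncated.

```python
import math


def ppo_update_stats(n_steps: int, n_envs: int, batch_size: int, n_epochs: int) -> dict:
    """Work done by one PPO update for the given hyperparameters."""
    rollout_size = n_steps * n_envs              # transitions in the rollout buffer
    minibatches = math.ceil(rollout_size / batch_size)
    return {
        "rollout_size": rollout_size,
        "minibatches_per_epoch": minibatches,
        "gradient_steps": minibatches * n_epochs,
        "truncated_minibatch": rollout_size % batch_size,  # 0 means no truncation
    }


print(ppo_update_stats(n_steps=2048, n_envs=1, batch_size=64, n_epochs=10))
# 32 minibatches x 10 epochs = 320 gradient steps per update
```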

Replay buffer strategies

Off-policy algorithms store transitions in a replay buffer. The default ReplayBuffer uses uniform sampling. For better sample efficiency:

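  • HER (Hindsight Experience Replay) — relabels transitions from failed episodes with goals that were actually achieved, so they still carry a learning signal. Critical for sparse-reward robotics; in SB3 it is provided as HerReplayBuffer for goal-conditioned Dict observations.
  • Prioritised Experience Replay — not built into SB3 core but available through third-party extensions. Samples transitions with probability proportional to their TD-error magnitude (raised to an exponent α) and corrects the resulting bias with importance-sampling weights.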
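Since PER is not in SB3 core, here is a minimal pure-Python sketch of the proportional variant's sampling step, for intuition only. per_sample is a hypothetical helper; alpha=0.6 and beta=0.4 are commonly used starting values, eps keeps zero-error transitions sampleable, and the weights are normalised by the largest weight in the batch (a common implementation choice).

```python
import random


def per_sample(td_errors, batch_size, alpha=0.6, beta=0.4, eps=1e-6, rng=None):
    """Proportional prioritised sampling with importance-sampling (IS) weights.

    priority_i = (|delta_i| + eps) ** alpha
    P(i)       = priority_i / sum_k priority_k
    w_i        = (N * P(i)) ** (-beta), divided by the largest weight in the batch
    """
    rng = rng or random.Random()
    priorities = [(abs(d) + eps) ** alpha for d in td_errors]
    total = sum(priorities)
    probs = [p / total for p in priorities]
    n = len(td_errors)
    indices = rng.choices(range(n), weights=probs, k=batch_size)
    weights = [(n * probs[i]) ** (-beta) for i in indices]
    max_w = max(weights)
    return indices, [w / max_w for w in weights], probs


indices, is_weights, probs = per_sample([0.1, 2.0, 0.5, 0.0], batch_size=8, rng=random.Random(0))
```

In a full implementation, priorities are refreshed with new TD errors after each gradient step and beta is annealed toward 1; a sum-tree makes sampling O(log N) instead of the O(N) list scan used here.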

Buffer sizing matters. Too small and you lose diversity; too large and you waste RAM:

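from stable_baselines3 import SAC

model = SAC("MlpPolicy", env, buffer_size=1_000_000)  # default size; env assumed defined
# For image observations, reduce the size to avoid OOM. optimize_memory_usage=True skips storing
# next_obs separately, and SB3 requires timeout handling to be off when it is enabled:
model = SAC("CnnPolicy", env, buffer_size=100_000, optimize_memory_usage=True,
            replay_buffer_kwargs=dict(handle_timeout_termination=False))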
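To see why image buffers need shrinking, a back-of-the-envelope estimate helps. The helper below is hypothetical and only counts observation storage (SB3 also stores actions, rewards and done flags), assuming 4 stacked 84×84 uint8 frames per observation as in Atari-style setups.

```python
GIB = 1024 ** 3


def replay_obs_bytes(buffer_size, obs_shape, bytes_per_value=1, optimize_memory_usage=False):
    """Rough bytes needed for stored observations (ignores other fields and overhead)."""
    per_obs = bytes_per_value
    for dim in obs_shape:
        per_obs *= dim
    copies = 1 if optimize_memory_usage else 2  # obs + next_obs unless optimised
    return buffer_size * per_obs * copies


print(round(replay_obs_bytes(1_000_000, (4, 84, 84)) / GIB, 1))  # 52.6 GiB
print(round(replay_obs_bytes(100_000, (4, 84, 84), optimize_memory_usage=True) / GIB, 1))  # 2.6 GiB
```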

Multi-environment training patterns

Separate train and eval environments

Always evaluate on a separate environment to avoid contaminating training statistics:

from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize

train_env = make_vec_env("HalfCheetah-v4", n_envs=8)
train_env = VecNormalize(train_env, norm_obs=True, norm_reward=True)

eval_env = make_vec_env("HalfCheetah-v4", n_envs=1)
eval_env = VecNormalize(eval_env, norm_obs=True, norm_reward=False, training=False)
# Share the training stats (same object, so eval always sees the latest values)
eval_env.obs_rms = train_env.obs_rms

Normalisation pitfalls

VecNormalize tracks running statistics. If you save a model without saving the normalisation stats, the loaded model will see raw (unnormalised) observations and perform terribly. Always save and load together:

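model.save("ppo_cheetah")
train_env.save("vec_normalize.pkl")

# Loading
env = VecNormalize.load("vec_normalize.pkl", make_vec_env("HalfCheetah-v4", n_envs=1))
env.training = False     # freeze the running stats at test time
env.norm_reward = False  # report raw rewards during evaluation
model = PPO.load("ppo_cheetah", env=env)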
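What the .pkl file preserves is essentially a running mean and variance per observation dimension. The scalar, pure-Python sketch below shows that bookkeeping; the class name mirrors SB3's RunningMeanStd helper, but this code is an illustration rather than SB3's implementation. normalize applies the standardise-and-clip step, assuming VecNormalize's defaults clip_obs=10.0 and epsilon=1e-8.

```python
import math


class RunningMeanStd:
    """Scalar running mean/variance with batch (parallel) updates."""
    def __init__(self, epsilon: float = 1e-4):
        self.mean, self.var, self.count = 0.0, 1.0, epsilon

    def update(self, batch) -> None:
        n = len(batch)
        batch_mean = sum(batch) / n
        batch_var = sum((x - batch_mean) ** 2 for x in batch) / n
        delta = batch_mean - self.mean
        total = self.count + n
        self.mean += delta * n / total
        # Combine sums of squared deviations from both parts, plus the between-part term
        m2 = self.var * self.count + batch_var * n + delta ** 2 * self.count * n / total
        self.var = m2 / total
        self.count = total


def normalize(x: float, rms: RunningMeanStd, clip: float = 10.0, eps: float = 1e-8) -> float:
    """Standardise with the running stats, then clip to [-clip, clip]."""
    z = (x - rms.mean) / math.sqrt(rms.var + eps)
    return max(-clip, min(clip, z))


rms = RunningMeanStd()
rms.update([1, 2, 3, 4])
rms.update([5, 6, 7, 8])
print(round(rms.mean, 3), round(rms.var, 3))
```

Fresh statistics (mean 0, variance 1) pass raw observations through almost unchanged, which is exactly the failure described above when the stats are not saved and reloaded.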

Custom callbacks for advanced control

Beyond built-in callbacks, you can subclass BaseCallback:

from stable_baselines3.common.callbacks import BaseCallback

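class CurriculumCallback(BaseCallback):
    """Increase environment difficulty once the success rate passes a threshold."""
    def __init__(self, threshold: float = 0.8, check_freq: int = 10_000, verbose: int = 0):
        super().__init__(verbose)
        self.threshold = threshold
        self.check_freq = check_freq

    def _on_step(self) -> bool:
        # n_calls counts vectorised steps: each call advances every sub-env by one step
        if self.n_calls % self.check_freq == 0:
            success_rate = self._evaluate()
            if success_rate > self.threshold:
                # increase_difficulty is a method your custom env must define
                self.training_env.env_method("increase_difficulty")
        return True  # returning False stops training

    def _evaluate(self) -> float:
        # Placeholder: plug in custom evaluation returning a success rate in [0, 1]
        return 0.0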
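The placeholder _evaluate() needs a success rate from somewhere. One option, sketched below in pure Python, is to keep a sliding window of recent episode outcomes and advance a difficulty level once the windowed rate clears the threshold. SuccessRateCurriculum is a hypothetical helper; outcomes could come, for example, from an is_success entry in the episode info dict, which SB3 reads for its own success-rate logging when present.

```python
from collections import deque


class SuccessRateCurriculum:
    """Track recent episode outcomes and raise a difficulty level when they are good enough."""
    def __init__(self, threshold: float = 0.8, window: int = 100, max_level: int = 5):
        self.threshold = threshold
        self.outcomes = deque(maxlen=window)
        self.level = 0
        self.max_level = max_level

    def record(self, success: bool) -> None:
        self.outcomes.append(1.0 if success else 0.0)

    def success_rate(self) -> float:
        return sum(self.outcomes) / len(self.outcomes) if self.outcomes else 0.0

    def maybe_advance(self) -> bool:
        full = len(self.outcomes) == self.outcomes.maxlen
        if full and self.success_rate() > self.threshold and self.level < self.max_level:
            self.level += 1
            self.outcomes.clear()  # measure afresh at the new difficulty
            return True
        return False
```

Waiting for a full window and clearing it after each promotion avoids advancing on a handful of lucky episodes.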

Exporting for production

For inference outside Python or at low latency:

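  1. ONNX export: convert the policy network to ONNX for C++, Rust, or JavaScript runtimes. Wrap the policy so the exported graph returns deterministic actions instead of sampling:

    import torch

    class OnnxablePolicy(torch.nn.Module):
        def __init__(self, policy):
            super().__init__()
            self.policy = policy

        def forward(self, obs):
            # For PPO this returns (actions, values, log_prob); action clipping/rescaling is not included
            return self.policy(obs, deterministic=True)

    onnx_policy = OnnxablePolicy(model.policy).cpu().eval()
    obs = torch.randn(1, *model.observation_space.shape)
    torch.onnx.export(onnx_policy, obs, "policy.onnx", opset_version=17, input_names=["input"])
  2. TorchScript: torch.jit.trace the same wrapper for use in C++ libtorch:

    traced = torch.jit.trace(onnx_policy, obs)
    traced.save("policy.pt")
  3. Quantisation: apply PyTorch's quantisation tooling (torch.ao.quantization, formerly torch.quantization) for edge deployment. Test thoroughly, since RL policies can be sensitive to precision loss.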
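To see why precision loss matters, here is a pure-Python illustration of symmetric per-tensor int8 quantisation: each weight maps to round(w / scale) with scale = max|w| / 127, so the reconstruction error per weight is at most scale / 2. This is a conceptual sketch with hypothetical helper names, not PyTorch's implementation, which also handles zero points, per-channel scales and activations.

```python
def quantize_int8(weights):
    """Symmetric per-tensor int8 quantisation: q = round(w / scale), scale = max|w| / 127."""
    scale = max(abs(w) for w in weights) / 127 or 1.0  # avoid a zero scale for all-zero input
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale


def dequantize(q, scale):
    return [v * scale for v in q]


w = [0.9, -0.31, 0.002, 0.45, -1.27]
q, scale = quantize_int8(w)
w_hat = dequantize(q, scale)
max_err = max(abs(a - b) for a, b in zip(w, w_hat))
print(q, max_err)
```

Here the small weight 0.002 rounds to zero. In a policy, errors of this size can shift action logits or means enough to change the chosen action, which is why a quantised policy should be re-evaluated in the environment.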

Debugging training failures

When reward flatlines, check in this order:

  1. Reward signal — is it informative? Plot raw rewards per episode.
  2. Observation normalisation — unnormalised or badly scaled observations are the most common silent killer.
  3. Hyperparameters — use the RL Zoo’s known-good defaults as a starting point.
  4. Environment bugs — run check_env and verify that repeated resets with the same seed produce identical trajectories.
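  5. Gradient norms — exploding gradients often mean the learning rate is too high. SB3 does not log gradient norms by default, so add that logging yourself before viewing it in TensorBoard; note that PPO and A2C already clip gradients via max_grad_norm.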
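check_env (from stable_baselines3.common.env_checker) validates the API and spaces; the seeding part of step 4 can be checked with a small harness like the one below. trajectories_match, ToyEnv and LeakyEnv are hypothetical; the harness assumes the Gymnasium reset/step signatures, and for NumPy observations you would pass equal=numpy.array_equal.

```python
import operator
import random


def trajectories_match(make_env, seed, actions, equal=operator.eq):
    """Run two fresh envs with the same reset seed and actions; compare obs and rewards."""
    runs = []
    for _ in range(2):
        env = make_env()
        obs, _ = env.reset(seed=seed)
        traj = [(obs, None)]
        for action in actions:
            obs, reward, terminated, truncated, _ = env.step(action)
            traj.append((obs, reward))
            if terminated or truncated:
                break
        runs.append(traj)
    if len(runs[0]) != len(runs[1]):
        return False
    return all(equal(o1, o2) and r1 == r2
               for (o1, r1), (o2, r2) in zip(runs[0], runs[1]))


class ToyEnv:
    """All randomness comes from the generator seeded in reset()."""
    def reset(self, seed=None, options=None):
        self.rng = random.Random(seed)
        self.state = self.rng.random()
        return self.state, {}

    def step(self, action):
        self.state += action + self.rng.gauss(0, 0.1)
        return self.state, -abs(self.state), False, False, {}


class LeakyEnv(ToyEnv):
    """Bug: draws noise from the global RNG, so the reset seed does not fix trajectories."""
    def step(self, action):
        self.state += action + random.gauss(0, 0.1)
        return self.state, -abs(self.state), False, False, {}
```

LeakyEnv fails the check because it draws noise from Python's global RNG instead of the generator seeded in reset(), a common source of irreproducible environments.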

Performance benchmarks

Approximate training times on a single GPU (RTX 3080) with SB3 defaults:

| Task | Algorithm | Time for 1M steps | Final reward |
|---|---|---|---|
| CartPole-v1 | PPO | ~2 min | 500 (solved) |
| LunarLander-v3 | PPO | ~8 min | ~250 |
| HalfCheetah-v4 | SAC | ~45 min | ~8000 |
| Humanoid-v4 | PPO | ~4 hours | ~5000 |

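These are ballpark numbers. Vectorised environments and tuned hyperparameters can cut times significantly. Note also that SB3 intends PPO with MlpPolicy to run primarily on CPU (it warns when you place it on a GPU), so CPU runs can be faster for the non-image tasks above.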
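Converting the table into throughput makes it easier to compare with your own runs, since SB3 reports steps per second as time/fps in its training logs. implied_fps is a hypothetical helper that just divides the step count by the wall-clock time.

```python
def implied_fps(total_steps: int, minutes: float) -> float:
    """Environment steps per second implied by a wall-clock training time."""
    return total_steps / (minutes * 60)


for task, minutes in [("CartPole-v1 / PPO", 2), ("LunarLander-v3 / PPO", 8),
                      ("HalfCheetah-v4 / SAC", 45), ("Humanoid-v4 / PPO", 240)]:
    print(f"{task}: ~{implied_fps(1_000_000, minutes):,.0f} steps/s")
```

These come out to roughly 8,300, 2,100, 370 and 70 steps/s respectively.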

The one thing to remember: SB3’s real power is the clean separation of algorithm, policy, and buffer — understand that stack and you can customise anything from feature extraction to production deployment.


See Also