Stable-Baselines3 — Deep Dive
Architecture overview
SB3 is built on three layers:
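- Algorithms — PPO, SAC, TD3, etc. On-policy algorithms (PPO, A2C) inherit from `OnPolicyAlgorithm` and off-policy ones (SAC, TD3, DQN) from `OffPolicyAlgorithm`; both subclass `BaseAlgorithm`. These intermediate classes provide `collect_rollouts()` (plus `_sample_action()` for off-policy), and each concrete algorithm implements `train()`.
- Policies — neural networks that map observations to actions. Policies inherit from `BasePolicy` and implement `_predict()`, which the public `predict()` wraps.
- Buffers — `RolloutBuffer` for on-policy, `ReplayBuffer` for off-policy, plus `DictRolloutBuffer` and `DictReplayBuffer` for dictionary observations.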
Understanding this stack lets you override exactly the right layer for custom behaviour.
Custom feature extractors
When the built-in extractors don't suit your observations (for example, images that need a different CNN), write a custom feature extractor:
```python
import torch
import torch.nn as nn
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class CustomCNN(BaseFeaturesExtractor):
    """CNN for channel-first image observations."""

    def __init__(self, observation_space: spaces.Box, features_dim: int = 256):
        super().__init__(observation_space, features_dim)
        n_channels = observation_space.shape[0]
        self.cnn = nn.Sequential(
            nn.Conv2d(n_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
        )
        # Compute flat size by forward pass
        with torch.no_grad():
            sample = torch.as_tensor(observation_space.sample()[None]).float()
            n_flatten = self.cnn(sample).shape[1]
        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        return self.linear(self.cnn(observations))


# `env` is assumed to be an environment with image observations (e.g. an Atari VecEnv).
# SB3 transposes channel-last images to channel-first before they reach the extractor.
model = PPO(
    "CnnPolicy",
    env,
    policy_kwargs=dict(
        features_extractor_class=CustomCNN,
        features_extractor_kwargs=dict(features_dim=256),
    ),
)
```
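For multi-input (Dict) observation spaces, use MultiInputPolicy: its default CombinedExtractor already encodes each key independently (a CNN for image keys, flattening for the rest) and concatenates the results. To customise that, subclass BaseFeaturesExtractor and build one sub-network per key before concatenating.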
Hyperparameter tuning with Optuna
SB3’s RL Zoo project bundles Optuna-based tuning. You can also do it manually:
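```python
import gymnasium as gym
import optuna
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor


def objective(trial: optuna.Trial) -> float:
    lr = trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)
    n_steps = trial.suggest_categorical("n_steps", [256, 512, 1024, 2048])
    # Sample 1 - gamma on a log scale; a log scale on gamma itself over
    # [0.9, 0.9999] is almost uniform and rarely tries values close to 1
    gamma = 1.0 - trial.suggest_float("one_minus_gamma", 1e-4, 0.1, log=True)
    ent_coef = trial.suggest_float("ent_coef", 1e-8, 0.1, log=True)

    model = PPO(
        "MlpPolicy", "LunarLander-v3",
        learning_rate=lr, n_steps=n_steps,
        gamma=gamma, ent_coef=ent_coef,
        verbose=0,
    )
    model.learn(total_timesteps=50_000)

    # Score on a fresh environment, not the one used for training
    eval_env = Monitor(gym.make("LunarLander-v3"))
    mean_reward, _ = evaluate_policy(model, eval_env, n_eval_episodes=20)
    eval_env.close()
    return float(mean_reward)


study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=100)
print(study.best_params)
```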
Key tuning parameters per algorithm:
| Algorithm | Critical hyperparameters |
|---|---|
| PPO | learning_rate, n_steps, batch_size, n_epochs, ent_coef, clip_range, gamma |
| SAC | learning_rate, buffer_size, batch_size, tau, gamma, ent_coef (auto or fixed) |
| DQN | learning_rate, buffer_size, exploration_fraction, target_update_interval |
Replay buffer strategies
Off-policy algorithms store transitions in a replay buffer. The default ReplayBuffer uses uniform sampling. For better sample efficiency:
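- HER (Hindsight Experience Replay) — relabels transitions from failed episodes with goals that were actually achieved, so even failures provide a learning signal. Critical for sparse-reward, goal-conditioned robotics.
- Prioritised Experience Replay — not built into SB3 core but available through third-party extensions. Samples transitions with probability proportional to their TD error (usually raised to a power α).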
Buffer sizing matters. Too small and you lose diversity; too large and you waste RAM:
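```python
model = SAC("MlpPolicy", env, buffer_size=1_000_000)  # default size
# For image observations, shrink the buffer to avoid OOM. optimize_memory_usage stores each
# obs once; SB3 rejects it together with the default handle_timeout_termination=True.
model = SAC("CnnPolicy", env, buffer_size=100_000, optimize_memory_usage=True,
            replay_buffer_kwargs=dict(handle_timeout_termination=False))
```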
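A back-of-the-envelope estimate helps pick `buffer_size` before you hit an OOM. The helper below is a hypothetical sketch that counts only the observation arrays (separate obs and next_obs arrays, or a single one with `optimize_memory_usage=True`) and ignores actions, rewards, and dones. For 4 stacked 84x84 uint8 frames it gives roughly 56.4 GB at the default 1,000,000 transitions, 5.6 GB at 100,000, and 2.8 GB at 100,000 with `optimize_memory_usage=True`. SB3's `ReplayBuffer` also warns at construction if `psutil` is installed and the buffer looks larger than available RAM.

```python
import numpy as np


def estimate_obs_bytes(obs_shape, dtype, buffer_size, optimize_memory_usage=False):
    """Rough size of a replay buffer's observation storage (hypothetical helper).

    Ignores actions, rewards, dones and Python overhead.
    """
    per_obs = int(np.prod(obs_shape)) * np.dtype(dtype).itemsize
    # Without optimize_memory_usage the buffer keeps separate obs and next_obs arrays
    copies = 1 if optimize_memory_usage else 2
    return per_obs * buffer_size * copies


frames = (4, 84, 84)  # e.g. 4 stacked 84x84 grayscale frames
for size, opt in [(1_000_000, False), (100_000, False), (100_000, True)]:
    gb = estimate_obs_bytes(frames, np.uint8, size, opt) / 1e9
    print(f"buffer_size={size:>9,} optimize_memory_usage={opt}: {gb:.1f} GB")
```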
Multi-environment training patterns
Separate train and eval environments
Always evaluate on a separate environment, so that evaluation episodes neither update the normalisation statistics nor mix into the training logs:

```python
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize

train_env = make_vec_env("HalfCheetah-v4", n_envs=8)
train_env = VecNormalize(train_env, norm_obs=True, norm_reward=True)

eval_env = make_vec_env("HalfCheetah-v4", n_envs=1)
eval_env = VecNormalize(eval_env, norm_obs=True, norm_reward=False, training=False)

# Sync normalisation stats (shares the same object, so eval sees the latest training stats)
eval_env.obs_rms = train_env.obs_rms
```
Normalisation pitfalls
VecNormalize tracks running statistics. If you save a model without saving the normalisation stats, the loaded model will see raw (unnormalised) observations and perform terribly. Always save and load together:
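```python
model.save("ppo_cheetah")
train_env.save("vec_normalize.pkl")

# Loading: wrap a fresh env with the saved stats, then freeze them for evaluation
env = VecNormalize.load("vec_normalize.pkl", make_vec_env("HalfCheetah-v4", n_envs=1))
env.training = False  # stop updating running statistics at test time
env.norm_reward = False  # report raw, unnormalised rewards
model = PPO.load("ppo_cheetah", env=env)
```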
Custom callbacks for advanced control
Beyond built-in callbacks, you can subclass BaseCallback:
```python
from stable_baselines3.common.callbacks import BaseCallback


class CurriculumCallback(BaseCallback):
    """Increase environment difficulty once the success rate passes a threshold."""

    def __init__(self, threshold: float = 0.8, check_freq: int = 10_000, verbose: int = 0):
        super().__init__(verbose)
        self.threshold = threshold
        self.check_freq = check_freq

    def _on_step(self) -> bool:
        # n_calls counts vectorised env steps (n_envs transitions each)
        if self.n_calls % self.check_freq == 0:
            success_rate = self._evaluate()
            if success_rate > self.threshold:
                # Calls increase_difficulty() on every sub-env; the env must define it
                self.training_env.env_method("increase_difficulty")
        return True  # False would stop training

    def _evaluate(self) -> float:
        # Placeholder: custom evaluation returning a success rate in [0, 1]
        return 0.0
```
Exporting for production
For inference outside Python or at low latency:
- ONNX export — convert the policy network to ONNX for C++, Rust, or JavaScript runtimes:

```python
import torch

obs = torch.randn(1, *model.observation_space.shape)
torch.onnx.export(model.policy, obs, "policy.onnx")
```

- TorchScript — `torch.jit.trace` the policy for use in C++ libtorch:

```python
traced = torch.jit.trace(model.policy, obs)
traced.save("policy.pt")
```

- Quantisation — apply `torch.quantization` for edge deployment. Test thoroughly, since RL policies can be sensitive to precision loss.
Debugging training failures
When reward flatlines, check in this order:
- Reward signal — is it informative? Plot raw rewards per episode.
- Observation normalisation — unnormalised or badly scaled observations are the most common silent killer.
- Hyperparameters — use the RL Zoo’s known-good defaults as a starting point.
- Environment bugs — run `check_env` and verify that repeated resets with the same seed produce identical trajectories.
- Gradient norms — SB3 does not log these by default, so record them yourself and inspect them in TensorBoard; exploding norms usually mean the learning rate is too high.
Performance benchmarks
Approximate training times on a single GPU (RTX 3080) with SB3 defaults:
| Task | Algorithm | Time for 1M steps | Final reward |
|---|---|---|---|
| CartPole-v1 | PPO | ~2 min | 500 (solved) |
| LunarLander-v3 | PPO | ~8 min | ~250 |
| HalfCheetah-v4 | SAC | ~45 min | ~8000 |
| Humanoid-v4 | PPO | ~4 hours | ~5000 |
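These are ballpark numbers. Vectorised environments and tuned hyperparameters can cut times significantly. PPO and A2C with MLP policies often run faster on CPU than on GPU; SB3 warns about this when you place them on a GPU.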
The one thing to remember: SB3’s real power is the clean separation of algorithm, policy, and buffer — understand that stack and you can customise anything from feature extraction to production deployment.