Policy Gradient Methods — Deep Dive
REINFORCE from scratch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import gymnasium as gym
import numpy as np
class PolicyNetwork(nn.Module):
    """MLP policy that maps an observation to a categorical action distribution.

    Two hidden ReLU layers of width ``hidden`` feed a final linear layer that
    emits one logit per discrete action.
    """

    def __init__(self, obs_dim: int, n_actions: int, hidden: int = 128):
        super().__init__()
        layers = [
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_actions),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> Categorical:
        """Return the action distribution built from the network's logits for ``x``."""
        return Categorical(logits=self.net(x))
def compute_returns(rewards: list[float], gamma: float = 0.99) -> torch.Tensor:
    """Compute normalised discounted returns-to-go for one episode.

    Args:
        rewards: Per-step rewards in chronological order.
        gamma: Discount factor applied to each later step.

    Returns:
        A float32 tensor with one entry per reward. Each entry is the
        discounted return from that step onward, shifted to zero mean and
        scaled to unit standard deviation. Inputs with fewer than two rewards
        are only mean-centred, because the sample standard deviation is
        undefined (NaN) for them; an empty input yields an empty tensor.
    """
    discounted = []
    running = 0.0
    # Accumulate back-to-front and flip once at the end; inserting at index 0
    # on every step would make this loop quadratic in the episode length.
    for reward in reversed(rewards):
        running = reward + gamma * running
        discounted.append(running)
    discounted.reverse()

    returns = torch.tensor(discounted, dtype=torch.float32)
    if returns.numel() == 0:
        return returns
    centred = returns - returns.mean()
    if returns.numel() < 2:
        # std() of a single element is NaN, which would poison the loss and
        # every weight after the next optimiser step.
        return centred
    # Normalise for variance reduction
    return centred / (returns.std() + 1e-8)
def train_reinforce(env_name="CartPole-v1", episodes=1000, gamma=0.99, lr=1e-3):
    """Train a ``PolicyNetwork`` with REINFORCE, one gradient step per episode.

    Args:
        env_name: Gymnasium environment id passed to ``gym.make``.
        episodes: Number of episodes to run (one optimiser step each).
        gamma: Discount factor forwarded to ``compute_returns``.
        lr: Adam learning rate.

    Returns:
        The trained ``PolicyNetwork``.

    Prints the undiscounted episode reward every 100 episodes. The
    environment is closed before returning, even if training raises.
    """
    env = gym.make(env_name)
    try:
        policy = PolicyNetwork(env.observation_space.shape[0], env.action_space.n)
        optimizer = optim.Adam(policy.parameters(), lr=lr)

        for ep in range(episodes):
            state, _ = env.reset()
            log_probs, rewards = [], []
            done = False

            # Roll out one full episode with the current stochastic policy.
            while not done:
                dist = policy(torch.as_tensor(state, dtype=torch.float32))
                action = dist.sample()
                log_probs.append(dist.log_prob(action))
                state, reward, terminated, truncated, _ = env.step(action.item())
                rewards.append(reward)
                done = terminated or truncated

            # REINFORCE objective: maximise sum_t log pi(a_t | s_t) * G_t,
            # i.e. minimise its negative.
            returns = compute_returns(rewards, gamma)
            loss = (-torch.stack(log_probs) * returns).sum()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if ep % 100 == 0:
                print(f"Episode {ep}, Total reward: {sum(rewards):.0f}")
    finally:
        # Release the environment's resources on both success and failure.
        env.close()
    return policy
Why normalising returns matters
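Without normalisation, the raw return scales each gradient term directly: a step with return 500 pushes the gradient 500x harder than one with return 1. Subtracting the mean (which acts as a simple baseline) and dividing by the standard deviation gives the returns zero mean and unit variance, so each episode's update has a comparable magnitude regardless of the reward scale.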
Generalised Advantage Estimation (GAE)
GAE (Schulman et al., 2015) balances bias and variance in advantage estimation. It computes an exponentially weighted average (with decay λ) of n-step advantage estimates:
def compute_gae(
    rewards: torch.Tensor,
    values: torch.Tensor,
    next_values: torch.Tensor,
    dones: torch.Tensor,
    gamma: float = 0.99,
    lam: float = 0.95,
) -> torch.Tensor:
    """Compute GAE advantages.

    Args:
        rewards: Rewards with time as the first dimension; any trailing
            dimensions (e.g. parallel environments) are processed elementwise.
        values: Value estimates for each step, same shape as ``rewards``.
        next_values: Value estimates for each successor state, same shape.
        dones: Episode-end flags, same shape; bool or numeric (non-zero = done).
        gamma: Discount factor.
        lam: GAE lambda controlling the bias/variance trade-off.

    Returns:
        A float32 tensor of advantages with the same shape as the inputs,
        allocated on the device of the inputs.
    """
    # `1 - tensor` raises for bool tensors, so convert the mask to float first.
    not_done = 1.0 - dones.to(torch.float32)
    # The one-step TD residuals do not depend on the recursion; compute them once.
    deltas = rewards + gamma * next_values * not_done - values

    advantages = torch.zeros_like(deltas, dtype=torch.float32)
    gae = 0.0
    # Backward recursion: A_t = delta_t + gamma * lam * (1 - done_t) * A_{t+1}.
    for t in reversed(range(len(rewards))):
        gae = deltas[t] + gamma * lam * not_done[t] * gae
        advantages[t] = gae
    return advantages
- λ = 0 reduces to one-step TD advantage (low variance, high bias).
- λ = 1 reduces to Monte Carlo returns minus baseline (high variance, low bias).
- λ = 0.95 is the standard sweet spot.
Actor-Critic (A2C) implementation
class ActorCritic(nn.Module):
    """Actor-critic network with a shared trunk and two linear heads.

    A single Linear+ReLU trunk feeds an actor head (one logit per discrete
    action) and a critic head (one scalar value).
    """

    def __init__(self, obs_dim: int, n_actions: int, hidden: int = 128):
        super().__init__()
        trunk = [nn.Linear(obs_dim, hidden), nn.ReLU()]
        self.shared = nn.Sequential(*trunk)
        self.actor_head = nn.Linear(hidden, n_actions)
        self.critic_head = nn.Linear(hidden, 1)

    def forward(self, x):
        """Return ``(action distribution, value)``; the value keeps its trailing size-1 dim."""
        hidden_out = self.shared(x)
        policy = Categorical(logits=self.actor_head(hidden_out))
        value = self.critic_head(hidden_out)
        return policy, value

    def evaluate(self, states, actions):
        """Score ``actions`` taken in ``states``.

        Returns a ``(log_probs, values, entropy)`` tuple; ``values`` has the
        critic's trailing size-1 dimension squeezed away.
        """
        hidden_out = self.shared(states)
        policy = Categorical(logits=self.actor_head(hidden_out))
        return (
            policy.log_prob(actions),
            self.critic_head(hidden_out).squeeze(-1),
            policy.entropy(),
        )
A2C collects rollouts from multiple parallel environments, computes GAE advantages, and updates both actor and critic in one backward pass:
def a2c_update(model, optimizer, states, actions, returns, advantages,
               value_coef=0.5, entropy_coef=0.01):
    """Run one combined actor/critic gradient step.

    Args:
        model: Module exposing ``evaluate(states, actions)`` that returns
            ``(log_probs, values, entropy)``, plus ``parameters()``.
        optimizer: Optimiser over ``model``'s parameters.
        states: Batch of observations passed to ``model.evaluate``.
        actions: Actions taken in ``states``.
        returns: Regression targets for the critic.
        advantages: Advantage estimates weighting the policy gradient.
        value_coef: Weight of the critic loss in the total loss.
        entropy_coef: Weight of the entropy bonus (subtracted from the loss).

    Returns:
        ``(actor_loss, critic_loss, entropy)`` as Python floats.
    """
    log_probs, values, entropy = model.evaluate(states, actions)

    # Both rollout statistics are fixed targets for this step. Detaching the
    # critic target as well as the advantages keeps gradients from flowing
    # into whatever graph produced them.
    actor_loss = -(log_probs * advantages.detach()).mean()
    critic_loss = nn.functional.mse_loss(values, returns.detach())
    entropy_bonus = entropy.mean()

    loss = actor_loss + value_coef * critic_loss - entropy_coef * entropy_bonus

    optimizer.zero_grad()
    loss.backward()
    # Cap the global gradient norm before stepping to limit update size.
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
    optimizer.step()

    return actor_loss.item(), critic_loss.item(), entropy_bonus.item()
PPO implementation
PPO adds the clipped surrogate objective on top of A2C:
def ppo_update(
    model, optimizer, states, actions, old_log_probs, returns, advantages,
    clip_eps=0.2, value_coef=0.5, entropy_coef=0.01, epochs=4, batch_size=64,
):
    """Run several epochs of clipped-surrogate PPO updates over one rollout.

    Args:
        model: Module exposing ``evaluate(states, actions)`` that returns
            ``(log_probs, values, entropy)``, plus ``parameters()``.
        optimizer: Optimiser over ``model``'s parameters.
        states: Rollout observations; ``states.shape[0]`` is the dataset size.
        actions: Actions taken in ``states``.
        old_log_probs: Log-probabilities of ``actions`` under the policy that
            collected the rollout.
        returns: Regression targets for the critic.
        advantages: Advantage estimates for each sample.
        clip_eps: Half-width of the probability-ratio clipping interval.
        value_coef: Weight of the critic loss in the total loss.
        entropy_coef: Weight of the entropy bonus (subtracted from the loss).
        epochs: Number of passes over the rollout.
        batch_size: Mini-batch size; the last mini-batch of an epoch may be
            smaller.
    """
    # Rollout statistics are constants for the whole update. Detaching them up
    # front stops gradients leaking into the behaviour policy and avoids
    # backpropagating twice through a graph the first step already freed.
    old_log_probs = old_log_probs.detach()
    returns = returns.detach()
    advantages = advantages.detach()

    dataset_size = states.shape[0]
    for _ in range(epochs):
        indices = torch.randperm(dataset_size)
        for start in range(0, dataset_size, batch_size):
            idx = indices[start : start + batch_size]
            b_states = states[idx]
            b_actions = actions[idx]
            b_old_lp = old_log_probs[idx]
            b_returns = returns[idx]
            b_adv = advantages[idx]

            # Normalise advantages per mini-batch. A leftover mini-batch of a
            # single sample is left unnormalised: its std() is NaN and would
            # turn the loss and every weight into NaN.
            if b_adv.numel() > 1:
                b_adv = (b_adv - b_adv.mean()) / (b_adv.std() + 1e-8)

            log_probs, values, entropy = model.evaluate(b_states, b_actions)
            ratio = torch.exp(log_probs - b_old_lp)

            # Clipped surrogate: take the pessimistic (minimum) objective so
            # moving the ratio outside [1 - eps, 1 + eps] earns no extra gain.
            surr1 = ratio * b_adv
            surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * b_adv
            actor_loss = -torch.min(surr1, surr2).mean()

            critic_loss = nn.functional.mse_loss(values, b_returns)
            entropy_bonus = entropy.mean()

            loss = actor_loss + value_coef * critic_loss - entropy_coef * entropy_bonus

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
            optimizer.step()
Key PPO hyperparameters
| Parameter | Role | Typical value |
|---|---|---|
| clip_eps | Max policy change per update | 0.1–0.3 |
| n_steps | Rollout length before update | 2048 |
| epochs | Gradient steps per rollout | 3–10 |
| batch_size | Mini-batch size within epoch | 64–256 |
| entropy_coef | Encourages exploration | 0.0–0.01 |
| gamma | Discount factor | 0.99 |
| gae_lambda | GAE bias-variance trade-off | 0.95 |
| learning_rate | Step size | 3e-4 (with linear decay) |
Continuous action spaces
For continuous control, the policy outputs Gaussian parameters:
class GaussianPolicy(nn.Module):
    """Diagonal-Gaussian policy for continuous actions.

    A two-layer ReLU MLP produces the action mean; the log standard deviation
    is a free parameter vector, independent of the observation.
    """

    def __init__(self, obs_dim: int, act_dim: int, hidden: int = 256):
        super().__init__()
        body = [
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        ]
        self.net = nn.Sequential(*body)
        self.mean_head = nn.Linear(hidden, act_dim)
        # Learnable log-std, initialised to zero so the initial std is exp(0) = 1.
        self.log_std = nn.Parameter(torch.zeros(act_dim))

    def forward(self, obs):
        """Return a ``Normal`` whose mean comes from the network and whose std is ``exp(log_std)``."""
        mu = self.mean_head(self.net(obs))
        sigma = torch.exp(self.log_std)
        return torch.distributions.Normal(mu, sigma)
Important: make log_std a learnable parameter (not a network output) for stability. Initialize it near zero (so std ≈ 1) so that the policy starts moderately stochastic.
Diagnostic metrics
Track these during training to catch problems early:
| Metric | Healthy range | Problem indicator |
|---|---|---|
| Approx KL divergence | 0.01–0.05 | > 0.1 means updates are too large |
| Clip fraction | 0.1–0.3 | Too high means clip_eps is too tight |
| Entropy | Decreasing slowly | Collapsing to zero = premature convergence |
| Explained variance | 0.5–1.0 | < 0 means critic is worse than predicting mean |
| Gradient norm | Stable | Spikes indicate instability |
# The diagnostics are read-only statistics, so compute all of them without
# autograd and convert each one to a plain Python float for logging.
with torch.no_grad():
    # Approx KL: mean log-ratio between the old and the current policy on the
    # sampled actions.
    approx_kl = (old_log_probs - log_probs).mean().item()
    # Clip fraction: share of samples whose probability ratio left the
    # [1 - clip_eps, 1 + clip_eps] interval.
    clip_frac = ((ratio - 1).abs() > clip_eps).float().mean().item()
    # Explained variance: 1 - Var[returns - values] / Var[returns]; the small
    # epsilon guards against division by zero when the returns are constant.
    ev = (1 - (returns - values).var() / (returns.var() + 1e-8)).item()
Common failure modes
- Entropy collapse — the policy becomes deterministic too quickly and gets stuck. Fix: increase entropy_coef or add exploration noise.
- Value function lag — the critic is poorly trained, producing bad advantage estimates. Fix: increase value_coef or train the critic with more epochs.
- Reward hacking — the agent exploits reward shaping. Fix: validate with the true objective periodically.
- Advantage normalisation bug — normalising across the entire batch instead of per mini-batch can dilute signal. Always normalise per mini-batch in PPO.
The one thing to remember: Policy gradient methods — from REINFORCE through PPO — all share the same core: ∇ log π(a|s) × advantage. Everything else (baselines, GAE, clipping, entropy bonuses) is variance reduction and stability engineering around that one equation.
See Also
- Python Environment Wrappers — How thin add-on layers let you change what a learning program sees and does without rewriting the game itself
- Python Monte Carlo Tree Search — The clever trick behind AlphaGo — how a program explores millions of possible moves by playing quick random games against itself
- Python Multi Agent Reinforcement — What happens when multiple programs learn together in the same world — cooperation, competition, and emergent teamwork
- Python Openai Gym Environments — Why OpenAI Gym is the playground where robots and programs learn by trial and error — no prior coding knowledge needed
- Python Q Learning Implementation — How a program builds a cheat sheet of every situation and every action to figure out the best move — no teacher required