Policy Gradient Methods — Deep Dive

REINFORCE from scratch

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import gymnasium as gym
import numpy as np

class PolicyNetwork(nn.Module):
    """MLP policy with two hidden ReLU layers that outputs a categorical action distribution."""

    def __init__(self, obs_dim: int, n_actions: int, hidden: int = 128):
        super().__init__()
        # Layers are listed input-to-output; the last one emits one
        # unnormalised logit per action.
        layers = [
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_actions),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> Categorical:
        """Return the action distribution for the observation(s) ``x``."""
        return Categorical(logits=self.net(x))


def compute_returns(rewards: list[float], gamma: float = 0.99) -> torch.Tensor:
    """Compute normalised discounted returns-to-go for one episode.

    Args:
        rewards: Per-step rewards in chronological order.
        gamma: Discount factor applied to future rewards.

    Returns:
        A float32 tensor with one entry per reward. For two or more steps the
        returns are standardised (zero mean, unit sample standard deviation).
        A single-step episode yields ``[0.0]`` and an empty episode yields an
        empty tensor, because the sample standard deviation is undefined
        (NaN) for fewer than two values.
    """
    discounted = []
    running = 0.0
    for r in reversed(rewards):
        running = r + gamma * running
        discounted.append(running)
    # Append-then-reverse is O(n); insert(0, ...) per step would be O(n^2).
    discounted.reverse()
    returns = torch.tensor(discounted, dtype=torch.float32)

    if returns.numel() < 2:
        # Centring a single return gives zero; dividing by std would give NaN
        # and poison the policy gradient.
        return torch.zeros_like(returns)

    # Normalise for variance reduction
    return (returns - returns.mean()) / (returns.std() + 1e-8)


def train_reinforce(env_name="CartPole-v1", episodes=1000, gamma=0.99, lr=1e-3):
    """Train a ``PolicyNetwork`` with REINFORCE, one gradient step per episode.

    The environment is expected to have a flat observation vector
    (``observation_space.shape[0]``) and a discrete action space
    (``action_space.n``), since those are used to size the network.

    Args:
        env_name: Gymnasium environment id passed to ``gym.make``.
        episodes: Number of episodes to run (one optimiser step each).
        gamma: Discount factor forwarded to ``compute_returns``.
        lr: Adam learning rate.

    Returns:
        The trained policy network.
    """
    env = gym.make(env_name)
    try:
        policy = PolicyNetwork(env.observation_space.shape[0], env.action_space.n)
        optimizer = optim.Adam(policy.parameters(), lr=lr)

        for ep in range(episodes):
            state, _ = env.reset()
            log_probs, rewards = [], []
            done = False

            # Roll out one full episode with the current stochastic policy.
            while not done:
                dist = policy(torch.FloatTensor(state))
                action = dist.sample()
                log_probs.append(dist.log_prob(action))
                state, reward, terminated, truncated, _ = env.step(action.item())
                rewards.append(reward)
                # Stop on a true terminal state or on a time-limit truncation.
                done = terminated or truncated

            returns = compute_returns(rewards, gamma)
            # REINFORCE objective: maximise sum_t log pi(a_t | s_t) * G_t,
            # i.e. minimise its negative.
            loss = -(torch.stack(log_probs) * returns).sum()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if ep % 100 == 0:
                print(f"Episode {ep}, Total reward: {sum(rewards):.0f}")
    finally:
        # Release environment resources even if training raises.
        env.close()

    return policy

Why normalising returns matters

Without normalisation, an episode with total return 500 will push gradients 500x harder than one with return 1. Normalising the returns to zero mean and unit variance ensures each episode contributes on a comparable scale, regardless of the reward scale.

Generalised Advantage Estimation (GAE)

GAE (Schulman et al., 2015) balances bias and variance in advantage estimation. It computes an exponentially weighted average (with decay λ) of n-step advantage estimates:

def compute_gae(
    rewards: torch.Tensor,
    values: torch.Tensor,
    next_values: torch.Tensor,
    dones: torch.Tensor,
    gamma: float = 0.99,
    lam: float = 0.95,
) -> torch.Tensor:
    """Compute GAE advantages.

    All inputs are indexed by time step along dimension 0 and must share a
    broadcast-compatible shape, e.g. ``(T,)`` for a single environment or
    ``(T, N)`` for ``N`` parallel environments.

    Args:
        rewards: Reward received at each step.
        values: Critic estimate for the state at each step.
        next_values: Critic estimate for the successor state at each step.
        dones: Episode-end flags (bool or 0/1 numeric); a set flag stops
            bootstrapping and stops the advantage from flowing across the
            episode boundary.
        gamma: Discount factor.
        lam: GAE lambda, the decay of the exponentially weighted sum.

    Returns:
        Advantages with the same shape, dtype and device as the TD residuals.
    """
    # Cast first: "1 - dones" raises for bool tensors.
    not_done = 1.0 - dones.to(values.dtype)
    # One-step TD residuals for every time step at once.
    deltas = rewards + gamma * next_values * not_done - values

    advantages = torch.zeros_like(deltas)
    gae = 0.0
    # Backward recursion: A_t = delta_t + gamma * lam * (1 - done_t) * A_{t+1}.
    for t in reversed(range(len(rewards))):
        gae = deltas[t] + gamma * lam * not_done[t] * gae
        advantages[t] = gae
    return advantages
  • λ = 0 reduces to one-step TD advantage (low variance, high bias).
  • λ = 1 reduces to Monte Carlo returns minus baseline (high variance, low bias).
  • λ = 0.95 is the standard sweet spot.

Actor-Critic (A2C) implementation

class ActorCritic(nn.Module):
    """Shared-trunk network with a categorical policy head and a scalar value head."""

    def __init__(self, obs_dim: int, n_actions: int, hidden: int = 128):
        super().__init__()
        trunk = [nn.Linear(obs_dim, hidden), nn.ReLU()]
        self.shared = nn.Sequential(*trunk)
        self.actor_head = nn.Linear(hidden, n_actions)
        self.critic_head = nn.Linear(hidden, 1)

    def forward(self, x):
        """Return ``(action distribution, value)``; the value keeps its trailing size-1 dim."""
        hidden_out = self.shared(x)
        policy = Categorical(logits=self.actor_head(hidden_out))
        value = self.critic_head(hidden_out)
        return policy, value

    def evaluate(self, states, actions):
        """Return ``(log_probs, values, entropy)`` of ``actions`` under the current policy.

        ``values`` has its trailing size-1 dimension squeezed away.
        """
        # Call forward directly (not self(...)) so no module hooks are triggered.
        policy, value = self.forward(states)
        return policy.log_prob(actions), value.squeeze(-1), policy.entropy()

A2C collects rollouts from multiple parallel environments, computes GAE advantages, and updates both actor and critic in one backward pass:

def a2c_update(model, optimizer, states, actions, returns, advantages,
               value_coef=0.5, entropy_coef=0.01):
    """Run one A2C gradient step on a batch and report the loss components.

    The total loss is ``actor + value_coef * critic - entropy_coef * entropy``.
    Gradients are clipped to a global norm of 0.5 before the optimiser step.

    Returns:
        ``(actor_loss, critic_loss, mean_entropy)`` as Python floats.
    """
    new_log_probs, value_preds, policy_entropy = model.evaluate(states, actions)

    # Advantages are treated as constants: no gradient flows through them.
    fixed_adv = advantages.detach()
    policy_term = -(new_log_probs * fixed_adv).mean()
    value_term = nn.functional.mse_loss(value_preds, returns)
    mean_entropy = policy_entropy.mean()

    # Entropy is subtracted so that minimising the loss encourages exploration.
    total = policy_term + value_coef * value_term - entropy_coef * mean_entropy

    optimizer.zero_grad()
    total.backward()
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
    optimizer.step()

    return policy_term.item(), value_term.item(), mean_entropy.item()

PPO implementation

PPO adds the clipped surrogate objective on top of A2C:

def ppo_update(
    model, optimizer, states, actions, old_log_probs, returns, advantages,
    clip_eps=0.2, value_coef=0.5, entropy_coef=0.01, epochs=4, batch_size=64,
):
    """Run several epochs of clipped-surrogate PPO updates over one rollout.

    Each epoch shuffles the rollout and walks over it in mini-batches. For
    every mini-batch the loss is
    ``actor + value_coef * critic - entropy_coef * entropy`` and gradients are
    clipped to a global norm of 0.5 before the optimiser step.

    Args:
        model: Module exposing ``evaluate(states, actions)`` that returns
            ``(log_probs, values, entropy)``.
        optimizer: Optimiser over ``model``'s parameters.
        states, actions: Rollout data, indexed along dimension 0.
        old_log_probs: Log-probabilities of ``actions`` under the policy that
            collected the rollout.
        returns: Regression targets for the critic.
        advantages: Advantage estimates for the actor.
        clip_eps: Half-width of the probability-ratio clipping interval.
        value_coef: Weight of the critic loss.
        entropy_coef: Weight of the entropy bonus.
        epochs: Number of passes over the rollout.
        batch_size: Mini-batch size; the last batch of an epoch may be smaller.
    """
    dataset_size = states.shape[0]

    # Rollout quantities are fixed targets. Detaching also keeps a graph
    # attached to them from being backpropagated through on every mini-batch.
    old_log_probs = old_log_probs.detach()
    returns = returns.detach()
    advantages = advantages.detach()

    for _ in range(epochs):
        indices = torch.randperm(dataset_size)
        for start in range(0, dataset_size, batch_size):
            idx = indices[start : start + batch_size]
            b_states = states[idx]
            b_actions = actions[idx]
            b_old_lp = old_log_probs[idx]
            b_returns = returns[idx]
            b_adv = advantages[idx]
            # Normalise advantages per mini-batch. A single-sample batch is
            # left as-is: its sample std is NaN and would corrupt the weights.
            if b_adv.numel() > 1:
                b_adv = (b_adv - b_adv.mean()) / (b_adv.std() + 1e-8)

            log_probs, values, entropy = model.evaluate(b_states, b_actions)
            # Probability ratio pi_new / pi_old, computed in log space.
            ratio = torch.exp(log_probs - b_old_lp)

            # Clipped surrogate: take the pessimistic (minimum) of the
            # unclipped and clipped objectives.
            surr1 = ratio * b_adv
            surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * b_adv
            actor_loss = -torch.min(surr1, surr2).mean()

            critic_loss = nn.functional.mse_loss(values, b_returns)
            entropy_bonus = entropy.mean()

            loss = actor_loss + value_coef * critic_loss - entropy_coef * entropy_bonus

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
            optimizer.step()

Key PPO hyperparameters

| Parameter | Role | Typical value |
|---|---|---|
| clip_eps | Max policy change per update | 0.1–0.3 |
| n_steps | Rollout length before update | 2048 |
| epochs | Gradient steps per rollout | 3–10 |
| batch_size | Mini-batch size within epoch | 64–256 |
| entropy_coef | Encourages exploration | 0.0–0.01 |
| gamma | Discount factor | 0.99 |
| gae_lambda | GAE bias-variance trade-off | 0.95 |
| learning_rate | Step size | 3e-4 (with linear decay) |

Continuous action spaces

For continuous control, the policy outputs Gaussian parameters:

class GaussianPolicy(nn.Module):
    """Diagonal-Gaussian policy: state-dependent mean, state-independent learnable log-std."""

    def __init__(self, obs_dim: int, act_dim: int, hidden: int = 256):
        super().__init__()
        body = [
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        ]
        self.net = nn.Sequential(*body)
        self.mean_head = nn.Linear(hidden, act_dim)
        # One learnable log-std per action dimension, starting at 0 (std = 1).
        self.log_std = nn.Parameter(torch.zeros(act_dim))

    def forward(self, obs):
        """Return a ``Normal`` distribution over actions for ``obs``."""
        mean = self.mean_head(self.net(obs))
        # Exponentiating the log-std keeps the standard deviation positive.
        return torch.distributions.Normal(mean, torch.exp(self.log_std))

Important: use log_std as a learnable parameter (not a network output) for stability. Initialise it at zero (std = 1) so the policy starts moderately stochastic.

Diagnostic metrics

Track these during training to catch problems early:

| Metric | Healthy range | Problem indicator |
|---|---|---|
| Approx KL divergence | 0.01–0.05 | > 0.1 means updates are too large |
| Clip fraction | 0.1–0.3 | Too high means clip_eps is too tight |
| Entropy | Decreasing slowly | Collapsing to zero = premature convergence |
| Explained variance | 0.5–1.0 | < 0 means critic is worse than predicting mean |
| Gradient norm | Stable | Spikes indicate instability |
# Approximate KL divergence: mean gap between old and new log-probabilities.
with torch.no_grad():
    approx_kl = torch.mean(old_log_probs - log_probs).item()
# Clip fraction: share of samples whose probability ratio lies outside
# [1 - clip_eps, 1 + clip_eps].
clip_frac = (torch.abs(ratio - 1) > clip_eps).float().mean().item()
# Explained variance: 1 minus residual variance over target variance
# (the epsilon guards against a zero-variance target).
ev = 1 - torch.var(returns - values) / (torch.var(returns) + 1e-8)

Common failure modes

  1. Entropy collapse — the policy becomes deterministic too quickly and gets stuck. Fix: increase entropy_coef or add exploration noise.
  2. Value function lag — the critic is poorly trained, producing bad advantage estimates. Fix: increase value_coef or train critic with more epochs.
  3. Reward hacking — the agent exploits reward shaping. Fix: validate with the true objective periodically.
  4. Advantage normalisation bug — normalising across the entire batch instead of per mini-batch can dilute signal. Always normalise per mini-batch in PPO.

The one thing to remember: Policy gradient methods — from REINFORCE through PPO — all share the same core: ∇ log π(a|s) × advantage. Everything else (baselines, GAE, clipping, entropy bonuses) is variance reduction and stability engineering around that one equation.

Tags: python, reinforcement-learning, ai, policy-gradients

See Also