Q-Learning Implementation — Deep Dive
Production-grade tabular Q-Learning
The basic version works but has subtle issues. Here is a more robust implementation with a linear epsilon-decay schedule and a greedy mode for evaluation:
import numpy as np
import gymnasium as gym
from collections import defaultdict
class QLearningAgent:
    """Tabular Q-learning agent with a linearly decaying epsilon-greedy policy.

    Q-values are stored in a ``defaultdict``, so a state seen for the first
    time starts with a zero Q-value for every action. States must be hashable
    because they are used as dictionary keys.

    Args:
        n_actions: Number of discrete actions.
        alpha: Learning rate for the TD update.
        gamma: Discount factor.
        epsilon_start: Exploration rate before any updates.
        epsilon_end: Exploration rate once the decay window has elapsed.
        epsilon_decay_steps: Number of ``update`` calls over which epsilon moves
            linearly from ``epsilon_start`` to ``epsilon_end``. A value <= 0
            means "no decay window": epsilon is ``epsilon_end`` immediately.
    """

    def __init__(
        self,
        n_actions: int,
        alpha: float = 0.1,
        gamma: float = 0.99,
        epsilon_start: float = 1.0,
        epsilon_end: float = 0.01,
        epsilon_decay_steps: int = 50_000,
    ):
        self.n_actions = n_actions
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay_steps = epsilon_decay_steps
        self.q_table: dict[tuple, np.ndarray] = defaultdict(
            lambda: np.zeros(n_actions)
        )
        # Number of update() calls so far; drives the epsilon schedule.
        self.step_count = 0

    @property
    def epsilon(self) -> float:
        """Current exploration rate (linear decay, then held at epsilon_end)."""
        if self.epsilon_decay_steps <= 0:
            # Degenerate schedule: avoid dividing by zero and use the final value.
            return self.epsilon_end
        frac = min(self.step_count / self.epsilon_decay_steps, 1.0)
        return self.epsilon_start + frac * (self.epsilon_end - self.epsilon_start)

    def select_action(self, state, greedy: bool = False) -> int:
        """Pick an action epsilon-greedily, or purely greedily if ``greedy``.

        Greedy ties resolve to the lowest action index (``np.argmax``).
        """
        if not greedy and np.random.random() < self.epsilon:
            return np.random.randint(self.n_actions)
        return int(np.argmax(self.q_table[state]))

    def update(self, state, action, reward, next_state, terminated):
        """Apply one Q-learning TD update and return the TD error.

        ``terminated`` zeroes the bootstrap term, so a true terminal
        transition's target is just ``reward``.
        """
        best_next = np.max(self.q_table[next_state])
        target = reward + self.gamma * best_next * (1 - terminated)
        td_error = target - self.q_table[state][action]
        self.q_table[state][action] += self.alpha * td_error
        self.step_count += 1
        return td_error
Double Q-Learning (tabular)
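Standard Q-Learning overestimates values because taking the max over noisy Q-value estimates gives an upward-biased estimate of the true maximum (maximisation bias). Double Q-Learning maintains two tables and decouples the two roles of the max: one table selects the greedy action and the other evaluates it: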
class DoubleQLearningAgent:
    """Tabular Double Q-learning (van Hasselt, 2010).

    Two independent Q-tables are kept. On each update a fair coin picks which
    table learns. That table chooses the greedy next action, and the *other*
    table evaluates it. This counteracts the upward bias of a single max.

    Args:
        n_actions: Number of discrete actions.
        alpha: Learning rate for the TD update.
        gamma: Discount factor.
    """

    def __init__(self, n_actions: int, alpha=0.1, gamma=0.99):
        self.n_actions = n_actions
        self.alpha = alpha
        self.gamma = gamma
        self.q1: dict[tuple, np.ndarray] = defaultdict(lambda: np.zeros(n_actions))
        self.q2: dict[tuple, np.ndarray] = defaultdict(lambda: np.zeros(n_actions))

    def select_action(self, state, epsilon: float) -> int:
        """Epsilon-greedy action with respect to the sum of both tables."""
        if np.random.random() < epsilon:
            return np.random.randint(self.n_actions)
        combined = self.q1[state] + self.q2[state]
        return int(np.argmax(combined))

    def update(self, state, action, reward, next_state, terminated):
        """Update one randomly chosen table and return its TD error.

        Returning the TD error mirrors ``QLearningAgent.update``, which lets
        callers log it or feed it to a prioritised buffer.
        """
        if np.random.random() < 0.5:
            learner, evaluator = self.q1, self.q2
        else:
            learner, evaluator = self.q2, self.q1
        # Select with the learning table, evaluate with the other one.
        best_a = int(np.argmax(learner[next_state]))
        target = reward + self.gamma * evaluator[next_state][best_a] * (1 - terminated)
        td_error = target - learner[state][action]
        learner[state][action] += self.alpha * td_error
        return td_error
Deep Q-Network (DQN)
When states are continuous or high-dimensional, replace the table with a neural network Q(s, a; θ) ≈ Q*(s, a).
Core DQN components
- Experience replay buffer — stores transitions and samples mini-batches to break temporal correlation.
- Target network — a frozen copy of the Q-network, updated periodically, to stabilise training targets.
- Epsilon-greedy exploration — same as tabular, but over network outputs.
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
class QNetwork(nn.Module):
    """Fully connected Q-network mapping an observation to one Q-value per action.

    Architecture: two hidden layers of width ``hidden`` with ReLU activations,
    followed by a linear output layer of size ``n_actions``.
    """

    def __init__(self, obs_dim: int, n_actions: int, hidden: int = 128):
        super().__init__()
        sizes = [obs_dim, hidden, hidden]
        modules = []
        for fan_in, fan_out in zip(sizes, sizes[1:]):
            modules += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        modules.append(nn.Linear(hidden, n_actions))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Return Q-values for ``x``; its last dimension must equal ``obs_dim``."""
        return self.net(x)
class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions, sampled uniformly at random."""

    def __init__(self, capacity: int = 100_000):
        # deque(maxlen=...) silently drops the oldest transition once full.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition, evicting the oldest one when at capacity."""
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size: int):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            Tuple ``(states, actions, rewards, next_states, dones)`` of tensors
            with dtypes float32, int64, float32, float32, float32.

        Raises:
            ValueError: if ``batch_size`` exceeds the number of stored
                transitions (raised by ``random.sample``).
        """
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        # Stack each field into one ndarray before converting. Building a
        # tensor directly from a sequence of ndarrays is very slow and makes
        # PyTorch emit a performance warning.
        return (
            torch.as_tensor(np.asarray(states, dtype=np.float32)),
            torch.as_tensor(np.asarray(actions, dtype=np.int64)),
            torch.as_tensor(np.asarray(rewards, dtype=np.float32)),
            torch.as_tensor(np.asarray(next_states, dtype=np.float32)),
            torch.as_tensor(np.asarray(dones, dtype=np.float32)),
        )

    def __len__(self):
        return len(self.buffer)
class DQNAgent:
    """DQN agent: online Q-network, periodically synced target network,
    and a uniform experience-replay buffer.

    Args:
        obs_dim: Length of the observation vector fed to the networks.
        n_actions: Number of discrete actions.
        lr: Adam learning rate for the online network.
        gamma: Discount factor.
        buffer_size: Replay buffer capacity.
        batch_size: Transitions per gradient step; training is skipped until
            the buffer holds at least this many.
        target_update_freq: Gradient steps between target-network syncs.
    """

    def __init__(
        self,
        obs_dim: int,
        n_actions: int,
        lr: float = 1e-3,
        gamma: float = 0.99,
        buffer_size: int = 100_000,
        batch_size: int = 64,
        target_update_freq: int = 1000,
    ):
        self.n_actions = n_actions
        self.gamma = gamma
        self.batch_size = batch_size
        self.target_update_freq = target_update_freq

        # The target network starts as an exact copy of the online network.
        self.q_net = QNetwork(obs_dim, n_actions)
        self.target_net = QNetwork(obs_dim, n_actions)
        self._sync_target()
        self.target_net.eval()

        self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr)
        self.buffer = ReplayBuffer(buffer_size)
        # Counts completed gradient steps (not environment steps).
        self.step_count = 0

    def _sync_target(self):
        """Copy the online network's weights into the target network."""
        self.target_net.load_state_dict(self.q_net.state_dict())

    def select_action(self, state, epsilon: float) -> int:
        """Epsilon-greedy action: uniform random with probability ``epsilon``."""
        explore = random.random() < epsilon
        if explore:
            return random.randint(0, self.n_actions - 1)
        obs = torch.FloatTensor(state).unsqueeze(0)  # add a batch axis
        with torch.no_grad():
            best = self.q_net(obs).argmax(dim=1)
        return best.item()

    def train_step(self) -> float:
        """Run one gradient step on a replay mini-batch and return the loss.

        Returns 0.0 without training while the buffer is smaller than
        ``batch_size``.
        """
        if len(self.buffer) < self.batch_size:
            return 0.0

        batch = self.buffer.sample(self.batch_size)
        states, actions, rewards, next_states, dones = batch

        # Q(s, a) for the actions actually taken.
        predicted = self.q_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        # Bootstrapped targets come from the frozen target network.
        with torch.no_grad():
            bootstrap = self.target_net(next_states).max(dim=1).values
            targets = rewards + self.gamma * bootstrap * (1 - dones)

        loss = nn.functional.mse_loss(predicted, targets)
        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.q_net.parameters(), max_norm=10.0)
        self.optimizer.step()

        self.step_count += 1
        if self.step_count % self.target_update_freq == 0:
            self._sync_target()
        return loss.item()
Training loop
# Train a DQN agent on CartPole-v1 with a linearly decaying epsilon schedule.
env = gym.make("CartPole-v1")
# Read sizes from the env instead of hard-coding them (CartPole: 4 and 2).
agent = DQNAgent(
    obs_dim=int(env.observation_space.shape[0]),
    n_actions=int(env.action_space.n),
)
epsilon_start, epsilon_end, epsilon_decay = 1.0, 0.01, 50_000
try:
    for episode in range(1000):
        state, _ = env.reset()
        total_reward = 0
        done = False
        while not done:
            # Epsilon decays linearly with agent.step_count, which counts
            # gradient steps. It is clamped below at epsilon_end.
            step = agent.step_count
            eps = max(epsilon_end,
                      epsilon_start - (epsilon_start - epsilon_end) * step / epsilon_decay)
            action = agent.select_action(state, eps)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            # Only `terminated` is stored as the done flag. A time-limit
            # truncation is not a true terminal state, so the target should
            # still bootstrap from next_state.
            agent.buffer.push(state, action, reward, next_state, float(terminated))
            agent.train_step()
            state = next_state
            total_reward += reward
        if episode % 50 == 0:
            print(f"Episode {episode}, Reward: {total_reward:.0f}, Eps: {eps:.3f}")
finally:
    # Release the environment's resources even if training is interrupted.
    env.close()
DQN improvements
Double DQN
Use the online network to select the action and the target network to evaluate it:
# Replace target computation in train_step (indent to match its body):
# Double DQN decouples action selection from evaluation. The online network
# (q_net) picks the argmax action for each next state, and the target network
# supplies that action's value.
with torch.no_grad():
    best_actions = self.q_net(next_states).argmax(dim=1, keepdim=True)
    next_q = self.target_net(next_states).gather(1, best_actions).squeeze(1)
    targets = rewards + self.gamma * next_q * (1 - dones)
Dueling DQN
Split the network into value and advantage streams:
class DuelingQNetwork(nn.Module):
    """Dueling Q-network: Q(s, a) = V(s) + A(s, a) - mean over a' of A(s, a').

    A shared ReLU trunk feeds a scalar value head and a per-action advantage
    head. Accepts a single observation of shape ``(obs_dim,)`` or any batch
    shape ``(..., obs_dim)``; the action axis of the output is always last.
    """

    def __init__(self, obs_dim, n_actions, hidden=128):
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.ReLU(),
        )
        self.value_stream = nn.Sequential(
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 1),
        )
        self.advantage_stream = nn.Sequential(
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, n_actions),
        )

    def forward(self, x):
        features = self.shared(x)
        value = self.value_stream(features)          # (..., 1)
        advantage = self.advantage_stream(features)  # (..., n_actions)
        # Subtract the mean advantage over the action axis for
        # identifiability. The action axis is the last one; dim=1 would fail
        # on unbatched input and average the wrong axis for inputs with more
        # than two dimensions.
        q_values = value + advantage - advantage.mean(dim=-1, keepdim=True)
        return q_values
Prioritised experience replay
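Sample transitions with probability proportional to their priority (absolute TD error raised to the power alpha) so the agent focuses on surprising experiences. Because this sampling is non-uniform, each sample also gets an importance-sampling weight that corrects the resulting bias; beta controls how fully it is corrected: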
class PrioritisedReplayBuffer:
    """Proportional prioritised experience replay.

    Transition i is drawn with probability
    P(i) = p_i**alpha / sum_k p_k**alpha and returned with an
    importance-sampling (IS) weight (N * P(i))**(-beta), scaled so the largest
    weight in the batch is 1.

    Args:
        capacity: Maximum number of stored transitions. Once full, the oldest
            slot is overwritten (ring buffer).
        alpha: Prioritisation exponent (0 gives uniform sampling).
    """

    def __init__(self, capacity, alpha=0.6):
        self.capacity = capacity
        self.alpha = alpha
        self.priorities = np.zeros(capacity, dtype=np.float32)
        self.buffer = []
        self.pos = 0  # next slot to write

    def push(self, *transition):
        """Store a transition with the current maximum priority (1.0 if empty).

        Giving new transitions the highest existing priority keeps them from
        being starved before their TD error is known.
        """
        max_prio = self.priorities[:len(self.buffer)].max() if self.buffer else 1.0
        if len(self.buffer) < self.capacity:
            self.buffer.append(transition)
        else:
            self.buffer[self.pos] = transition
        self.priorities[self.pos] = max_prio
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size, beta=0.4):
        """Draw ``batch_size`` transitions in proportion to their priority.

        Draws are i.i.d. from P, i.e. *with replacement*, so a transition may
        appear more than once. The IS weight (N * P(i))**(-beta) is the
        correction for exactly this sampling scheme. Sampling without
        replacement would change the real inclusion probabilities and bias
        the correction.

        Args:
            batch_size: Number of transitions to draw.
            beta: IS exponent (1 = full bias correction).

        Returns:
            ``(batch, indices, weights)``:
            - batch: list of stored transition tuples.
            - indices: their buffer positions, for ``update_priorities``.
            - weights: float32 tensor of IS weights with maximum 1.

        Raises:
            ValueError: if the buffer is empty.
        """
        n = len(self.buffer)
        if n == 0:
            raise ValueError("cannot sample from an empty PrioritisedReplayBuffer")
        probs = self.priorities[:n] ** self.alpha
        probs /= probs.sum()
        indices = np.random.choice(n, batch_size, p=probs, replace=True)
        weights = (n * probs[indices]) ** (-beta)
        weights /= weights.max()
        batch = [self.buffer[i] for i in indices]
        return batch, indices, torch.as_tensor(weights, dtype=torch.float32)

    def update_priorities(self, indices, td_errors):
        """Set each transition's priority to |TD error| + 1e-6.

        The small constant keeps every sampling probability non-zero. If an
        index repeats, the last TD error for it wins.
        """
        for idx, td in zip(indices, td_errors):
            self.priorities[idx] = abs(td) + 1e-6

    def __len__(self):
        return len(self.buffer)
Performance comparison
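On CartPole-v1 (considered solved at an average reward of 475+ over 100 consecutive episodes). The figures below are rough, illustrative ballparks; actual results vary widely with random seed and hyperparameters: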
| Method | Episodes to solve | Notes |
|---|---|---|
| Tabular Q (discretised) | ~3,000 | Requires manual state binning |
| Vanilla DQN | ~300 | Stable with target network |
| Double DQN | ~250 | Reduces overestimation |
| Dueling Double DQN | ~200 | Better value estimation |
| + Prioritised replay | ~150 | Focuses on hard transitions |
Debugging checklist
- Reward not increasing? Check epsilon schedule — if it decays too fast, exploration stops prematurely.
- Q-values exploding? Gradient clipping and lower learning rate help. Also check gamma is not too close to 1.0 for environments with unbounded episode lengths.
- Target network sync? Verify the update frequency. Too frequent = instability (targets chase themselves). Too rare = stale targets.
- Replay buffer too small? The agent overfits to recent transitions. Start with 100K capacity.
- NaN in loss? Check for division by zero in priority weights or unnormalised observations.
The one thing to remember: Q-Learning’s journey from a simple table to DQN with double evaluation, dueling heads, and prioritised replay shows how one elegant idea — update Q toward reward plus best future Q — scales from toy problems to Atari-level challenges.
See Also
- Python Environment Wrappers How thin add-on layers let you change what a learning program sees and does without rewriting the game itself
- Python Monte Carlo Tree Search The clever trick behind AlphaGo — how a program explores millions of possible moves by playing quick random games against itself
- Python Multi Agent Reinforcement What happens when multiple programs learn together in the same world — cooperation, competition, and emergent teamwork
- Python Openai Gym Environments Why OpenAI Gym is the playground where robots and programs learn by trial and error — no prior coding knowledge needed
- Python Policy Gradient Methods Instead of scoring every move, what if the program just learned which moves feel right? That is policy gradients