Q-Learning Implementation — Deep Dive

Production-grade tabular Q-Learning

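A bare-bones Q-learning loop works, but it typically hard-codes the exploration rate and mixes evaluation with training. The implementation below adds a linear epsilon-decay schedule (driven by the number of updates performed) and a `greedy` flag for exploration-free evaluation: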

import numpy as np
import gymnasium as gym
from collections import defaultdict

class QLearningAgent:
    """Tabular Q-learning agent with a linearly decaying epsilon-greedy policy.

    States are used as dictionary keys, so they must be hashable (e.g. ints or
    tuples). An unseen state is lazily initialised to a zero vector of length
    ``n_actions``.

    Args:
        n_actions: Number of discrete actions.
        alpha: Learning rate of the TD update.
        gamma: Discount factor.
        epsilon_start: Exploration rate before any update.
        epsilon_end: Exploration rate once the decay has finished.
        epsilon_decay_steps: Number of ``update`` calls over which epsilon is
            linearly annealed from ``epsilon_start`` to ``epsilon_end``. A value
            <= 0 disables the decay: epsilon is ``epsilon_end`` from the start.
    """

    def __init__(
        self,
        n_actions: int,
        alpha: float = 0.1,
        gamma: float = 0.99,
        epsilon_start: float = 1.0,
        epsilon_end: float = 0.01,
        epsilon_decay_steps: int = 50_000,
    ):
        self.n_actions = n_actions
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay_steps = epsilon_decay_steps
        self.q_table: dict[tuple, np.ndarray] = defaultdict(
            lambda: np.zeros(n_actions)
        )
        self.step_count = 0  # number of update() calls so far

    @property
    def epsilon(self) -> float:
        """Current exploration rate, linearly annealed by ``step_count``."""
        if self.epsilon_decay_steps <= 0:
            # A zero-length schedule would divide by zero; treat it as
            # "already fully decayed".
            return self.epsilon_end
        frac = min(self.step_count / self.epsilon_decay_steps, 1.0)
        return self.epsilon_start + frac * (self.epsilon_end - self.epsilon_start)

    def select_action(self, state, greedy: bool = False) -> int:
        """Pick an action epsilon-greedily, or purely greedily if ``greedy``.

        Ties in the greedy choice resolve to the lowest action index
        (``np.argmax`` semantics).
        """
        if not greedy and np.random.random() < self.epsilon:
            return np.random.randint(self.n_actions)
        return int(np.argmax(self.q_table[state]))

    def update(self, state, action, reward, next_state, terminated):
        """Apply one Q-learning update and return the TD error.

        ``terminated`` marks a true terminal transition: no value is
        bootstrapped from ``next_state`` in that case.
        """
        if terminated:
            # Nothing to bootstrap from, and no reason to create a table
            # entry for a terminal state.
            best_next = 0.0
        else:
            best_next = np.max(self.q_table[next_state])
        target = reward + self.gamma * best_next
        td_error = target - self.q_table[state][action]
        self.q_table[state][action] += self.alpha * td_error
        self.step_count += 1
        return td_error

Double Q-Learning (tabular)

Standard Q-Learning overestimates action values because taking the max over noisy Q estimates is biased upward. Double Q-Learning maintains two tables. On each update a coin flip picks the table to update; that table selects the greedy next action, and the other table evaluates it:

class DoubleQLearningAgent:
    """Tabular Double Q-learning with two independently updated Q-tables.

    Args:
        n_actions: Number of discrete actions.
        alpha: Learning rate of the TD update.
        gamma: Discount factor.
    """

    def __init__(self, n_actions: int, alpha=0.1, gamma=0.99):
        self.n_actions = n_actions
        self.alpha = alpha
        self.gamma = gamma
        self.q1: dict[tuple, np.ndarray] = defaultdict(lambda: np.zeros(n_actions))
        self.q2: dict[tuple, np.ndarray] = defaultdict(lambda: np.zeros(n_actions))

    def select_action(self, state, epsilon: float) -> int:
        """Epsilon-greedy action over the summed estimate ``q1 + q2``."""
        if np.random.random() < epsilon:
            return np.random.randint(self.n_actions)
        combined = self.q1[state] + self.q2[state]
        return int(np.argmax(combined))

    def update(self, state, action, reward, next_state, terminated):
        """Update one randomly chosen table and return its TD error.

        The chosen table selects the greedy next action and the other table
        evaluates it, which decouples selection from evaluation. The return
        value matches ``QLearningAgent.update`` (previously ``None``).
        """
        # Fair coin: which table learns on this step.
        if np.random.random() < 0.5:
            learner, evaluator = self.q1, self.q2
        else:
            learner, evaluator = self.q2, self.q1

        if terminated:
            # No bootstrap from a terminal state, and no table entry for it.
            bootstrap = 0.0
        else:
            best_a = int(np.argmax(learner[next_state]))
            bootstrap = evaluator[next_state][best_a]

        target = reward + self.gamma * bootstrap
        td_error = target - learner[state][action]
        learner[state][action] += self.alpha * td_error
        return td_error

Deep Q-Network (DQN)

When states are continuous or high-dimensional, replace the table with a neural network Q(s, a; θ) ≈ Q*(s, a).

Core DQN components

  1. Experience replay buffer — stores transitions and samples mini-batches to break temporal correlation.
  2. Target network — a frozen copy of the Q-network, updated periodically, to stabilise training targets.
  3. Epsilon-greedy exploration — same as tabular, but over network outputs.
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque

class QNetwork(nn.Module):
    """ReLU MLP with two hidden layers mapping an observation to one Q-value per action."""

    def __init__(self, obs_dim: int, n_actions: int, hidden: int = 128):
        super().__init__()
        # Hidden layers: Linear + ReLU for each (fan_in, fan_out) pair, built in
        # input-to-output order so state_dict keys are net.0, net.2, net.4.
        widths = [obs_dim, hidden, hidden]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # Linear output head: one unit per action, no activation.
        layers.append(nn.Linear(hidden, n_actions))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return Q-values; the last axis indexes actions."""
        q = self.net(x)
        return q


class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions with uniform random sampling.

    Once ``capacity`` transitions are stored, each push evicts the oldest one.
    """

    def __init__(self, capacity: int = 100_000):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one ``(state, action, reward, next_state, done)`` transition."""
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size: int):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            Tuple ``(states, actions, rewards, next_states, dones)`` of tensors,
            each with a new leading batch axis: ``states``/``next_states``/
            ``rewards``/``dones`` are float32, ``actions`` is int64.

        Raises:
            ValueError: if ``batch_size`` exceeds the number of stored
                transitions (raised by ``random.sample``).
        """
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        # Stack into contiguous ndarrays first. Building a tensor directly from a
        # tuple of ndarrays is very slow and makes PyTorch emit a UserWarning.
        return (
            torch.from_numpy(np.asarray(states, dtype=np.float32)),
            torch.from_numpy(np.asarray(actions, dtype=np.int64)),
            torch.from_numpy(np.asarray(rewards, dtype=np.float32)),
            torch.from_numpy(np.asarray(next_states, dtype=np.float32)),
            torch.from_numpy(np.asarray(dones, dtype=np.float32)),
        )

    def __len__(self):
        return len(self.buffer)


class DQNAgent:
    """DQN agent with an online Q-network, a periodically synced target network,
    and uniform experience replay.

    Args:
        obs_dim: Length of a flat observation vector.
        n_actions: Number of discrete actions.
        lr: Adam learning rate for the online network.
        gamma: Discount factor.
        buffer_size: Replay-buffer capacity.
        batch_size: Mini-batch size; ``train_step`` is a no-op until the buffer
            holds at least this many transitions.
        target_update_freq: Copy the online weights into the target network
            every this many gradient steps.
    """

    def __init__(
        self,
        obs_dim: int,
        n_actions: int,
        lr: float = 1e-3,
        gamma: float = 0.99,
        buffer_size: int = 100_000,
        batch_size: int = 64,
        target_update_freq: int = 1000,
    ):
        # Hyperparameters.
        self.n_actions = n_actions
        self.gamma = gamma
        self.batch_size = batch_size
        self.target_update_freq = target_update_freq

        # Online and target networks start out with identical weights.
        self.q_net = QNetwork(obs_dim, n_actions)
        self.target_net = QNetwork(obs_dim, n_actions)
        self._sync_target()
        self.target_net.eval()

        self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr)
        self.buffer = ReplayBuffer(buffer_size)
        self.step_count = 0  # completed gradient steps

    def _sync_target(self):
        """Overwrite the target network's weights with the online network's."""
        self.target_net.load_state_dict(self.q_net.state_dict())

    def select_action(self, state, epsilon: float) -> int:
        """Epsilon-greedy action for a single (unbatched) observation."""
        explore = random.random() < epsilon
        if explore:
            return random.randint(0, self.n_actions - 1)
        with torch.no_grad():
            batch_of_one = torch.FloatTensor(state).unsqueeze(0)
            best = self.q_net(batch_of_one).argmax(dim=1)
            return best.item()

    def train_step(self) -> float:
        """Run one gradient step on a replay mini-batch.

        Returns the MSE loss as a Python float, or 0.0 when the buffer does not
        yet hold ``batch_size`` transitions.
        """
        if len(self.buffer) < self.batch_size:
            return 0.0

        obs, acts, rews, next_obs, dones = self.buffer.sample(self.batch_size)

        # Q(s, a) of the actions actually taken.
        predicted = self.q_net(obs).gather(1, acts.unsqueeze(1)).squeeze(1)

        # Bootstrapped targets from the frozen target network; (1 - dones)
        # removes the bootstrap term for terminal transitions.
        with torch.no_grad():
            bootstrap = self.target_net(next_obs).max(dim=1).values
            not_done = 1 - dones
            targets = rews + self.gamma * bootstrap * not_done

        loss = nn.functional.mse_loss(predicted, targets)
        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.q_net.parameters(), max_norm=10.0)
        self.optimizer.step()

        self.step_count += 1
        if self.step_count % self.target_update_freq == 0:
            self._sync_target()

        return loss.item()

Training loop

env = gym.make("CartPole-v1")
# Size the network from the environment rather than hard-coding CartPole's 4/2.
agent = DQNAgent(
    obs_dim=env.observation_space.shape[0],
    n_actions=int(env.action_space.n),
)

epsilon_start, epsilon_end, epsilon_decay = 1.0, 0.01, 50_000
# Environment steps taken so far; drives the epsilon schedule. agent.step_count
# counts gradient steps, which only start once the buffer holds a full batch.
env_steps = 0

try:
    for episode in range(1000):
        state, _ = env.reset()
        total_reward = 0
        done = False

        while not done:
            eps = max(epsilon_end,
                      epsilon_start - (epsilon_start - epsilon_end) * env_steps / epsilon_decay)
            action = agent.select_action(state, eps)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            # Store only true termination as "done": a time-limit truncation
            # should still bootstrap from next_state.
            agent.buffer.push(state, action, reward, next_state, float(terminated))
            agent.train_step()
            state = next_state
            total_reward += reward
            env_steps += 1

        if episode % 50 == 0:
            print(f"Episode {episode}, Reward: {total_reward:.0f}, Eps: {eps:.3f}")
finally:
    # Release the environment's resources even if training is interrupted.
    env.close()

DQN improvements

Double DQN

Use the online network to select the action and the target network to evaluate it:

# Replace target computation in train_step:
# Double DQN target: the online network (q_net) picks the greedy next action,
# the target network evaluates it. This decouples selection from evaluation.
with torch.no_grad():
    # argmax over the action axis; keepdim=True keeps shape (batch, 1) for gather.
    best_actions = self.q_net(next_states).argmax(dim=1, keepdim=True)
    # Target-network value of the action chosen above, squeezed back to (batch,).
    next_q = self.target_net(next_states).gather(1, best_actions).squeeze(1)
    # (1 - dones) zeroes the bootstrap term for terminal transitions.
    targets = rewards + self.gamma * next_q * (1 - dones)

Dueling DQN

Split the network into value and advantage streams:

class DuelingQNetwork(nn.Module):
    """Dueling Q-network: Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a').

    Accepts a single observation of shape ``(obs_dim,)`` or a batch of shape
    ``(batch, obs_dim)``; the action axis of the output is always the last one.

    Args:
        obs_dim: Length of a flat observation vector.
        n_actions: Number of discrete actions.
        hidden: Width of the shared layer and of each stream's hidden layer.
    """

    def __init__(self, obs_dim, n_actions, hidden=128):
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.ReLU(),
        )
        self.value_stream = nn.Sequential(
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 1),
        )
        self.advantage_stream = nn.Sequential(
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, n_actions),
        )

    def forward(self, x):
        features = self.shared(x)
        value = self.value_stream(features)          # (..., 1)
        advantage = self.advantage_stream(features)  # (..., n_actions)
        # Subtract the mean advantage over the action axis for identifiability.
        # dim=-1 (not dim=1) makes unbatched 1-D inputs work too; for batched
        # 2-D inputs the two are identical.
        q_values = value + advantage - advantage.mean(dim=-1, keepdim=True)
        return q_values

Prioritised experience replay

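Sample each transition with probability proportional to (|δ| + ε)^α, where δ is its most recent TD error, so the agent replays surprising experiences more often. Importance-sampling weights (N·P(i))^(−β), normalised by their maximum, correct the bias this introduces: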

class PrioritisedReplayBuffer:
    """Proportional prioritised replay buffer backed by a fixed-size ring.

    Transition ``i`` is sampled with probability ``p_i**alpha / sum_k p_k**alpha``,
    where ``p_i`` is its stored priority. A new transition receives the current
    maximum stored priority (1.0 for the very first one), giving it the highest
    sampling probability among stored transitions.

    Args:
        capacity: Maximum number of stored transitions.
        alpha: Prioritisation exponent (0 = uniform sampling).
    """

    def __init__(self, capacity, alpha=0.6):
        self.capacity = capacity
        self.alpha = alpha
        self.priorities = np.zeros(capacity, dtype=np.float32)
        self.buffer = []
        self.pos = 0  # next slot to write; wraps around once the buffer is full

    def push(self, *transition):
        """Store ``transition`` (any tuple of fields), overwriting the oldest slot when full."""
        size = len(self.buffer)
        priority = self.priorities[:size].max() if size else 1.0
        if size >= self.capacity:
            self.buffer[self.pos] = transition
        else:
            self.buffer.append(transition)
        self.priorities[self.pos] = priority
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size, beta=0.4):
        """Draw ``batch_size`` distinct transitions by priority.

        Returns:
            ``(transitions, indices, weights)``: the list of stored tuples, their
            buffer indices (for ``update_priorities``), and a float32 tensor of
            importance-sampling weights scaled so the largest is 1.
        """
        n = len(self.buffer)
        scaled = self.priorities[:n] ** self.alpha
        probs = scaled / scaled.sum()
        indices = np.random.choice(n, batch_size, p=probs, replace=False)
        # Importance-sampling correction (N * P(i)) ** -beta, normalised by
        # its maximum so it only ever scales updates down.
        is_weights = (n * probs[indices]) ** (-beta)
        is_weights = is_weights / is_weights.max()
        transitions = [self.buffer[i] for i in indices]
        return transitions, indices, torch.FloatTensor(is_weights)

    def update_priorities(self, indices, td_errors):
        """Set each listed transition's priority to ``abs(td_error) + 1e-6``."""
        for i, err in zip(indices, td_errors):
            # The small offset keeps every priority strictly positive.
            self.priorities[i] = abs(err) + 1e-6

Performance comparison

On CartPole-v1 (solved at 475+ avg reward over 100 episodes):

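| Method | Episodes to solve | Notes |
|---|---|---|
| Tabular Q (discretised) | ~3,000 | Requires manual state binning |
| Vanilla DQN | ~300 | Stable with target network |
| Double DQN | ~250 | Reduces overestimation |
| Dueling Double DQN | ~200 | Better value estimation |
| + Prioritised replay | ~150 | Focuses on hard transitions |

These are rough, illustrative figures; actual episode counts vary substantially with hyperparameters and random seed.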

Debugging checklist

  1. Reward not increasing? Check epsilon schedule — if it decays too fast, exploration stops prematurely.
  2. Q-values exploding? Gradient clipping and lower learning rate help. Also check gamma is not too close to 1.0 for environments with unbounded episode lengths.
  3. Target network sync? Verify the update frequency. Too frequent = instability (targets chase themselves). Too rare = stale targets.
  4. Replay buffer too small? The agent overfits to recent transitions. Start with 100K capacity.
  5. NaN in loss? Check for division by zero in priority weights or unnormalised observations.

The one thing to remember: Q-Learning’s journey from a simple table to DQN with double evaluation, dueling heads, and prioritised replay shows how one elegant idea — update Q toward reward plus best future Q — scales from toy problems to Atari-level challenges.

Tags: python · reinforcement-learning · ai · q-learning

See Also