Monte Carlo Tree Search — Deep Dive

Complete MCTS implementation

import math
import random
from typing import Optional

class MCTSNode:
    """A node in a UCT search tree built over GameState-like objects.

    Each node wraps one state, remembers the action that produced it from its
    parent, and accumulates a visit count and a running sum of backed-up
    values. Legal actions that have not yet become child nodes wait in
    ``untried_actions``.
    """

    __slots__ = (
        "state", "parent", "action", "children",
        "visits", "total_value", "untried_actions",
    )

    def __init__(self, state, parent=None, action=None):
        self.state = state
        self.parent = parent
        self.action = action  # the move that led from parent.state to state
        self.children: list["MCTSNode"] = []
        self.visits = 0
        self.total_value = 0.0
        # Every legal move starts untried; expand() consumes them from the end.
        self.untried_actions = state.legal_actions()

    @property
    def is_fully_expanded(self) -> bool:
        """True once no untried actions remain."""
        return len(self.untried_actions) == 0

    @property
    def is_terminal(self) -> bool:
        """Whether the wrapped state reports the game as over."""
        return self.state.is_terminal()

    def ucb1(self, c: float = 1.41) -> float:
        """UCB1 score of this node as seen from its parent.

        Unvisited nodes score +inf so every child is tried at least once.
        A visited node needs ``self.parent`` (its visit count is the log term).
        """
        if not self.visits:
            return float("inf")
        mean_value = self.total_value / self.visits
        bonus = c * math.sqrt(math.log(self.parent.visits) / self.visits)
        return mean_value + bonus

    def best_child(self, c: float = 1.41) -> "MCTSNode":
        """Return the child with the highest UCB1 score (earliest child on ties)."""
        def score(child: "MCTSNode") -> float:
            return child.ucb1(c)

        return max(self.children, key=score)

    def expand(self) -> "MCTSNode":
        """Turn one untried action into a new child node and return that child."""
        move = self.untried_actions.pop()
        child = MCTSNode(self.state.apply(move), parent=self, action=move)
        self.children.append(child)
        return child

    def backpropagate(self, value: float):
        """Record one visit worth ``value`` here and on every ancestor.

        The value becomes ``1.0 - value`` at each step up the tree, so
        alternating levels are scored for alternating players.
        """
        current: Optional[MCTSNode] = self
        while current is not None:
            current.visits += 1
            current.total_value += value
            value, current = 1.0 - value, current.parent


def mcts_search(root_state, iterations: int = 1000, c: float = 1.41):
    """Run UCT search from ``root_state`` and return the most-visited root action.

    Each iteration selects down the tree with UCB1, expands one new child,
    plays a uniformly random rollout to the end of the game, and backs the
    result up the path.

    Args:
        root_state: a GameState-like object to search from.
        iterations: number of select/expand/simulate/backup passes.
        c: UCB1 exploration constant.

    Returns:
        The ``action`` of the root child with the most visits.

    Raises:
        ValueError: if the root ends up with no children (for example,
            ``root_state`` is terminal or ``iterations`` is 0).
    """
    root = MCTSNode(root_state)

    for _ in range(iterations):
        # 1. Selection: descend through fully expanded nodes by UCB1.
        node = root
        while not node.is_terminal and node.is_fully_expanded:
            node = node.best_child(c)

        # 2. Expansion
        if not node.is_terminal and not node.is_fully_expanded:
            node = node.expand()

        # 3. Simulation (random rollout), counting the moves it makes.
        state = node.state
        rollout_moves = 0
        while not state.is_terminal():
            action = random.choice(state.legal_actions())
            state = state.apply(action)
            rollout_moves += 1

        # result() scores the terminal state for the player who made the LAST
        # move (1.0 win, 0.5 draw, 0.0 loss). node.total_value must instead be
        # scored for the player who moved INTO node, because the parent picks
        # the child with the best value. Those are the same player only after an
        # even number of rollout moves, so flip on odd parity. This assumes two
        # players alternate moves, matching backpropagate's 1 - value flip.
        value = state.result()
        if rollout_moves % 2 == 1:
            value = 1.0 - value

        # 4. Backpropagation
        node.backpropagate(value)

    if not root.children:
        raise ValueError("mcts_search: root has no children (terminal root state or no iterations)")
    # Choose most-visited child
    return max(root.children, key=lambda n: n.visits).action

The game state interface

The MCTS code above expects a state object with:

class GameState:
    """Interface the MCTS code expects from a game position.

    These are documentation-only stubs; each method returns None.
    """

    def legal_actions(self) -> list:
        """Return the actions available to the player to move."""

    def apply(self, action) -> "GameState":
        """Return the successor position after playing ``action``."""

    def is_terminal(self) -> bool:
        """Return True when the game is over."""

    def result(self) -> float:
        """Score a finished game for the player who just moved (1.0 win, 0.5 draw, 0.0 loss)."""

    def current_player(self) -> int:
        """Return the player whose turn it is."""

Here is a Tic-Tac-Toe implementation:

import numpy as np

class TicTacToe:
    """Tic-Tac-Toe position implementing the GameState interface.

    The board is a flat array of 9 ints in row-major order: 0 means empty,
    1 and 2 are the players' marks. ``player`` is the side to move. Positions
    compare and hash by (board, player), so equal positions reached by
    different move orders are interchangeable, for example in a
    transposition table.
    """

    def __init__(self, board=None, player=1):
        # Accept any 9-element sequence or a 3x3 array; reshape(9) raises
        # ValueError for boards of the wrong size.
        if board is None:
            self.board = np.zeros(9, dtype=int)
        else:
            self.board = np.asarray(board, dtype=int).reshape(9)
        self.player = player

    def current_player(self):
        """Return the player to move (1 or 2), as the GameState interface requires."""
        return self.player

    def legal_actions(self):
        """Return the indices of empty squares as plain Python ints."""
        return [int(i) for i in np.flatnonzero(self.board == 0)]

    def apply(self, action):
        """Return a new position with the mover's mark on ``action``; self is not mutated."""
        new_board = self.board.copy()
        new_board[action] = self.player
        return TicTacToe(new_board, 3 - self.player)

    def is_terminal(self):
        """The game is over when someone has three in a row or the board is full."""
        return self._winner() is not None or len(self.legal_actions()) == 0

    def result(self):
        """Score from the perspective of the player who just moved (3 - self.player).

        Returns 1.0 if that player has won, 0.0 if the opponent has won, and
        0.5 when there is no winner.
        """
        w = self._winner()
        if w is None:
            return 0.5
        return 1.0 if w == 3 - self.player else 0.0

    def __eq__(self, other):
        if not isinstance(other, TicTacToe):
            return NotImplemented
        return self.player == other.player and np.array_equal(self.board, other.board)

    def __hash__(self):
        # Must agree with __eq__: hash the board contents, not the array object.
        return hash((tuple(self.board.tolist()), self.player))

    def _winner(self):
        """Return 1 or 2 if that player owns a full row, column or diagonal, else None."""
        b = self.board.reshape(3, 3)
        lines = list(b) + list(b.T) + [b.diagonal(), np.fliplr(b).diagonal()]
        for line in lines:
            if np.all(line == 1):
                return 1
            if np.all(line == 2):
                return 2
        return None

AlphaZero-style MCTS with neural network

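Replace random rollouts with a neural network that returns (policy_prior, value). The snippets below use an `MCTSNodeNN` class, which is not shown. It is a tree node that also stores its policy `prior` and provides `expand_with_prior`, `best_child_puct` and `is_expanded`: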

def mcts_search_nn(root_state, network, iterations=800, c_puct=1.5, temperature=1.0):
    """AlphaZero-style MCTS: network evaluations replace random rollouts.

    Args:
        root_state: position to search from.
        network: object whose ``evaluate(state)`` returns ``(policy, value)``.
        iterations: number of select / evaluate / backup passes.
        c_puct: exploration constant passed to ``best_child_puct``.
        temperature: how root visit counts become probabilities.
            ``1.0`` (default) gives probabilities proportional to visits.
            ``0`` puts all mass on the most-visited child (first one on ties).
            Values below 1 sharpen the distribution; values above 1 flatten it.

    Returns:
        ``(actions, probs)``: the root children's actions and a NumPy array
        of selection probabilities in the same order.

    Raises:
        ValueError: if ``temperature`` is negative.
    """
    if temperature < 0:
        raise ValueError("temperature must be >= 0")

    root = MCTSNodeNN(root_state)
    # Expand the root up front so the first selection already has
    # prior-weighted children; the root's own value estimate is not needed.
    policy, _ = network.evaluate(root_state)
    root.expand_with_prior(policy)

    for _ in range(iterations):
        node = root

        # Selection with PUCT
        while node.is_expanded and not node.is_terminal:
            node = node.best_child_puct(c_puct)

        if node.is_terminal:
            # Exact game result, scored for the player who just moved.
            value = node.state.result()
        else:
            # Evaluate with the network instead of a rollout.
            # NOTE(review): confirm network.evaluate returns its value on the
            # same scale and from the same player's perspective as
            # state.result(); both are backed up interchangeably here.
            policy, value = network.evaluate(node.state)
            node.expand_with_prior(policy)

        # Backpropagation
        node.backpropagate(value)

    # Temperature-controlled action selection over root visit counts.
    actions = [child.action for child in root.children]
    visits = np.array([child.visits for child in root.children], dtype=float)
    if temperature == 0:
        probs = np.zeros_like(visits)
        if visits.size:
            probs[np.argmax(visits)] = 1.0
    else:
        # Divide by the largest count before raising to 1/T so small
        # temperatures cannot overflow; normalising cancels the scale factor.
        scale = visits.max() if visits.size else 1.0
        weights = (visits / scale) ** (1.0 / temperature)
        probs = weights / weights.sum()
    return actions, probs

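The key change is PUCT ("Predictor + UCT") selection. It adds a prior-weighted exploration bonus to each child's mean value: $Q + c_{puct} \cdot P \cdot \sqrt{N_{parent}} / (1 + N)$.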

def puct_score(self, c_puct: float) -> float:
    """PUCT selection score: Q + c_puct * P * sqrt(N_parent) / (1 + N).

    Q is ``total_value / max(visits, 1)``, so the divisor is never zero for
    an unvisited node. P is this node's ``prior`` from the policy network.
    """
    n = self.visits
    exploration = c_puct * self.prior * math.sqrt(self.parent.visits) / (n + 1)
    mean_value = self.total_value / max(n, 1)
    return mean_value + exploration

The prior comes from the policy network. It biases the search toward moves the network considers promising before those moves have accumulated any visits. (There are no rollouts in this variant.)

Performance optimisation

Tree reuse

After making a move, reuse the corresponding subtree instead of building from scratch:

def advance_root(root, chosen_action):
    """Return the subtree for ``chosen_action`` as a new, detached search root.

    Reuses the first existing child whose action matches, along with all its
    statistics. Otherwise it starts a fresh MCTSNode from the successor state.
    """
    reused = next(
        (child for child in root.children if child.action == chosen_action),
        None,
    )
    if reused is None:
        # The action was never expanded in the old tree.
        return MCTSNode(root.state.apply(chosen_action))
    reused.parent = None  # drop the reference back into the old tree
    return reused

Virtual loss for parallel MCTS

When running multiple search threads, add a “virtual loss” to nodes being evaluated to discourage other threads from selecting the same path:

def apply_virtual_loss(node, virtual_loss=3):
    """Temporarily record ``virtual_loss`` pessimistic visits on ``node``.

    Adds ``virtual_loss`` to ``visits`` and subtracts it from ``total_value``.
    This lowers the node's mean value so other searchers prefer different
    paths while this one is being evaluated. With results in [0, 1], this
    penalty is harsher than recording plain losses (value 0). Undo it with
    ``remove_virtual_loss`` using the same amount.
    """
    node.visits += virtual_loss
    node.total_value -= virtual_loss


def remove_virtual_loss(node, virtual_loss=3):
    """Undo a matching ``apply_virtual_loss(node, virtual_loss)`` call."""
    node.visits -= virtual_loss
    node.total_value += virtual_loss


def select_with_virtual_loss(node, virtual_loss=3, evaluate=None):
    """Evaluate ``node`` while it carries a virtual loss, then record the real result.

    Args:
        node: object with numeric ``visits`` and ``total_value`` attributes.
        virtual_loss: size of the temporary penalty.
        evaluate: callable ``evaluate(node) -> float`` producing the real value.

    Returns:
        The value returned by ``evaluate(node)``.

    Raises:
        ValueError: if ``evaluate`` is not supplied.
    """
    if evaluate is None:
        raise ValueError("select_with_virtual_loss requires an evaluate(node) callable")
    apply_virtual_loss(node, virtual_loss)
    try:
        real_value = evaluate(node)
    finally:
        # Always undo the virtual loss, even if evaluation raises, so the
        # node's statistics are not left permanently skewed.
        remove_virtual_loss(node, virtual_loss)
    # The real evaluation counts as exactly one genuine visit.
    node.visits += 1
    node.total_value += real_value
    return real_value

This is how AlphaZero-style engines run many search threads at once while keeping batched GPU evaluations full.

Transposition table

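In some games, the same position can be reached by different move sequences (transpositions). There, a transposition table maps position hashes to tree nodes, merging statistics. For this to work, states must hash by position: equal positions need equal hashes. Python's default object hash is per instance, so with it no transposition would ever be found.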

class TranspositionMCTS:
    """Transposition table that shares one MCTSNode per position.

    Nodes are looked up by ``key_fn(state)``. The default key is
    ``hash(state)``, which only finds transpositions if states hash by
    position. Two different positions whose hashes collide would then share a
    node. To rule collisions out, pass ``key_fn=lambda s: s`` (for states with
    value-based ``__eq__``/``__hash__``) or another exact, hashable encoding of
    the position.
    """

    def __init__(self, key_fn=None):
        # Callable mapping a state to a hashable table key.
        self.key_fn = hash if key_fn is None else key_fn
        self.table: "dict[object, MCTSNode]" = {}

    def get_or_create(self, state, parent, action):
        """Return the node stored for ``state``, creating and storing it if needed.

        A reused node is re-parented to ``parent`` so the next backpropagation
        follows the path just taken. Its ``action`` keeps the value recorded
        when it was first created, which may not be the move from ``parent``.
        """
        key = self.key_fn(state)
        node = self.table.get(key)
        if node is not None:
            node.parent = parent  # update parent for backprop
            return node
        node = MCTSNode(state, parent, action)
        self.table[key] = node
        return node

Batch neural network evaluation

Collecting multiple leaf nodes and evaluating them in one GPU batch dramatically improves throughput:

def batch_mcts(root_state, network, iterations=800, batch_size=8):
    """MCTS that evaluates leaves with one batched network call per round.

    Relies on helpers not defined in this snippet: ``MCTSNodeNN``,
    ``select_leaf``, ``apply_virtual_loss`` and ``remove_virtual_loss``.

    Terminal leaves are scored with ``state.result()`` instead of the network,
    as in ``mcts_search_nn``. A leaf selected more than once in the same batch
    is evaluated and expanded only once, so its children are not duplicated.

    Returns:
        The root node of the searched tree.
    """
    root = MCTSNodeNN(root_state)
    remaining = iterations

    while remaining > 0:
        # Collect a batch; virtual loss steers later selections in the batch
        # away from leaves that are already queued.
        leaves = []
        for _ in range(min(batch_size, remaining)):
            leaf = select_leaf(root)
            apply_virtual_loss(leaf)
            leaves.append(leaf)

        # Undo every virtual loss (one per selection) before real values go in.
        for leaf in leaves:
            remove_virtual_loss(leaf)

        # Distinct non-terminal leaves, in first-selected order. The same leaf
        # can be picked repeatedly (e.g. the unexpanded root in the first
        # batch), and expanding it twice would duplicate its children.
        pending = list({id(leaf): leaf for leaf in leaves if not leaf.is_terminal}.values())
        value_by_leaf = {}
        if pending:
            policies, values = network.batch_evaluate([leaf.state for leaf in pending])
            for leaf, policy, value in zip(pending, policies, values):
                leaf.expand_with_prior(policy)
                value_by_leaf[id(leaf)] = value

        # Back up once per selection, as the sequential search does per iteration.
        for leaf in leaves:
            if leaf.is_terminal:
                leaf.backpropagate(leaf.state.result())
            else:
                leaf.backpropagate(value_by_leaf[id(leaf)])

        remaining -= len(leaves)
    return root

MCTS for non-game domains

MCTS extends beyond board games:

  • Program synthesis — treat code generation as a tree of token choices.
  • Chemical molecule design — each node is a molecular fragment; actions add atoms or bonds.
  • Combinatorial optimisation — job scheduling, routing, packing.
  • Language model decoding — tree search over token sequences with a value model scoring partial generations.

The state interface remains the same: define legal_actions, apply, is_terminal, and result (or a heuristic evaluation).

Tuning MCTS

| Parameter | Effect | Guideline |
|---|---|---|
| Iteration count | More = stronger play, slower | Scale to time budget |
| C (UCB1) / c_puct | Higher = more exploration | Start at √2, tune by win rate |
| Rollout depth limit | Caps simulation length | Use for slow rollouts |
| Temperature | Controls action randomness from visit counts | High early (diverse training data), zero at test time |
| Dirichlet noise | Adds noise to root priors for exploration | α = 0.3 for chess, 0.03 for Go |

Testing your MCTS

  1. Solve a trivial game — verify MCTS finds the optimal move in Tic-Tac-Toe with enough iterations.
  2. Convergence test — as iterations increase, the chosen action should stabilise.
  3. Symmetry check — symmetric positions should produce symmetric visit distributions (within noise).
  4. Profiling — the bottleneck is usually the simulation phase (or network inference for NN-MCTS). Optimise there first.

The one thing to remember: MCTS turns any game or decision problem into a statistical sampling problem — define your state interface, run enough iterations, and the tree converges to strong play.

Tags: python · reinforcement-learning · ai · search-algorithms

See Also