Multi-Agent Reinforcement Learning — Deep Dive

Formal framing

A multi-agent problem is typically modelled as a Decentralised Partially Observable Markov Decision Process (Dec-POMDP):

  • A set of agents N = {1, …, n}
  • Global state space S
  • Per-agent observation function O_i(s) — each agent sees only a partial view
  • Per-agent action space A_i
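  • Transition function T(s′ | s, a_1, …, a_n)
  • Reward function R(s, a_1, …, a_n), shared by all agents in a Dec-POMDP; giving each agent its own reward R_i turns the model into a partially observable stochastic game (POSG)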

The key difference from a single-agent MDP is that the transition and reward depend on the joint action of all agents, but each agent decides based only on its own partial observation.
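To make the tuple concrete, here is a hypothetical two-agent toy, not a standard benchmark. Agents on a line of five cells each observe only their own cell, the joint action moves both, and a single shared reward pays out when they end up in the same cell. The names (`State`, `observe`, `transition`, `reward`) are illustrative. The point is that `transition` and `reward` take the joint action while each policy sees only `observe(s, i)`.

```python
import random
from dataclasses import dataclass

N_CELLS = 5
ACTIONS = (-1, 0, 1)  # per-agent action space A_i: left, stay, right


@dataclass(frozen=True)
class State:
    positions: tuple  # global state s: (cell of agent 0, cell of agent 1)


def observe(state: State, agent: int) -> int:
    """O_i(s): agent i sees only its own cell, never its partner's."""
    return state.positions[agent]


def transition(state: State, joint_action: tuple) -> State:
    """T(s' | s, a_1, a_2), deterministic here: the next state depends on BOTH actions."""
    return State(tuple(
        min(max(p + a, 0), N_CELLS - 1)  # clip moves at the edges of the line
        for p, a in zip(state.positions, joint_action)
    ))


def reward(state: State, joint_action: tuple) -> float:
    """Shared R(s, a_1, a_2): 1.0 if the agents end up in the same cell."""
    nxt = transition(state, joint_action)
    return 1.0 if nxt.positions[0] == nxt.positions[1] else 0.0


def random_policy(obs: int) -> int:
    """A decentralised policy: its only input is the local observation."""
    return random.choice(ACTIONS)


s = State((0, 4))
joint = tuple(random_policy(observe(s, i)) for i in range(2))
r = reward(s, joint)  # 0.0 here: the agents start 4 cells apart
s = transition(s, joint)
```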

CTDE algorithm deep dive: MAPPO

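Multi-Agent PPO (MAPPO) extends PPO to centralised training with decentralised execution (CTDE): each actor acts on its own local observation, while a centralised value function (critic) that sees the global state is used only during training: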

import torch
import torch.nn as nn

class MAPPOCritic(nn.Module):
    """Centralised critic that sees global state."""
    def __init__(self, global_state_dim: int, hidden: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(global_state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, global_state: torch.Tensor) -> torch.Tensor:
        return self.net(global_state)


class MAPPOActor(nn.Module):
    """Decentralised actor that sees only local observation."""
    def __init__(self, obs_dim: int, act_dim: int, hidden: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, act_dim),
        )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.net(obs)  # logits for Categorical distribution
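A minimal sketch of one decentralised-execution step, assuming 3 agents with 18-dimensional observations and 5 discrete actions (sizes chosen for illustration). To stay self-contained it uses `nn.Sequential` stand-ins with the same input/output contract as `MAPPOActor` and `MAPPOCritic`, and a single actor shared by all agents:

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical

n_agents, obs_dim, act_dim = 3, 18, 5  # illustrative sizes

# Stand-ins with the same input/output contract as MAPPOActor / MAPPOCritic
actor = nn.Sequential(nn.Linear(obs_dim, 128), nn.ReLU(), nn.Linear(128, act_dim))
critic = nn.Sequential(
    nn.Linear(n_agents * obs_dim, 256), nn.ReLU(), nn.Linear(256, 1)
)

local_obs = torch.randn(n_agents, obs_dim)  # row i = agent i's observation

with torch.no_grad():
    # Execution: each agent samples from a distribution over its own logits
    dist = Categorical(logits=actor(local_obs))  # n_agents independent distributions
    actions = dist.sample()                      # shape (n_agents,)
    log_probs = dist.log_prob(actions)           # kept for the PPO ratio later

    # Training only: the critic scores the global state, here the
    # concatenation of all local observations
    global_state = local_obs.reshape(1, -1)      # (1, n_agents * obs_dim)
    value = critic(global_state)                 # (1, 1)
```

Only `actor` is needed at execution time; `critic` and the stored `log_probs` feed the training update.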

Training loop sketch:

  1. All agents interact with the environment in parallel, collecting trajectories.
  2. The centralised critic estimates values V(s) from the global state (the concatenation of all observations, or a privileged state vector); together with the rewards, these values give each timestep's advantage.
  3. Each actor updates its policy using PPO’s clipped objective with the centralised advantages.
  4. The critic is updated to minimise value prediction error on the global state.
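Steps 2 to 4 can be sketched as below, assuming a single shared team reward and Generalised Advantage Estimation (GAE) for step 2. GAE and the defaults `gamma=0.99`, `lam=0.95`, `clip_eps=0.2` are common PPO choices, not something the steps above prescribe. Refinements such as value clipping, entropy bonuses and advantage normalisation are omitted.

```python
import torch


def gae(rewards, values, dones, gamma=0.99, lam=0.95):
    """Generalised Advantage Estimation from centralised values V(s_t) (step 2).

    rewards, dones: (T,) shared team reward and 1.0 where an episode ended
    values: (T + 1,) critic outputs, including a bootstrap value for the last state
    """
    T = rewards.shape[0]
    advantages = torch.zeros(T)
    running = 0.0
    for t in reversed(range(T)):
        not_done = 1.0 - dones[t]
        delta = rewards[t] + gamma * values[t + 1] * not_done - values[t]
        running = delta + gamma * lam * not_done * running
        advantages[t] = running
    returns = advantages + values[:-1]  # regression targets for the critic
    return advantages, returns


def ppo_actor_loss(new_log_probs, old_log_probs, advantages, clip_eps=0.2):
    """PPO clipped surrogate (step 3), averaged over timesteps and agents."""
    ratio = torch.exp(new_log_probs - old_log_probs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    return -torch.min(unclipped, clipped).mean()


def critic_loss(predicted_values, returns):
    """Mean squared value error on the global state (step 4)."""
    return ((predicted_values - returns) ** 2).mean()


T, n_agents = 4, 3
rewards = torch.tensor([0.0, 0.0, 1.0, 0.0])
dones = torch.tensor([0.0, 0.0, 0.0, 1.0])
values = torch.zeros(T + 1)  # an untrained critic
adv, ret = gae(rewards, values, dones)  # adv ≈ [0.8845, 0.9405, 1.0, 0.0]

# The shared team advantage is broadcast to every agent's log-probability
old_lp = torch.zeros(T, n_agents)
loss = ppo_actor_loss(old_lp.clone(), old_lp, adv.unsqueeze(-1))
```

Because the reward is shared, every agent receives the same advantage at a given timestep; only the log-probability ratios differ per agent.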

Parameter sharing

When agents are homogeneous (same role), share one set of actor parameters across all agents and distinguish them with a one-hot agent ID appended to the observation. This cuts the actor parameter count by roughly a factor of n and usually speeds convergence, because every agent's experience trains the same network.
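A minimal sketch of parameter sharing, assuming 3 agents, 18-dimensional observations and 5 actions. The one-hot ID is concatenated to each agent's observation, and all agents' rows go through the same network in one batched forward pass. `add_agent_ids` is a hypothetical helper name.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

n_agents, obs_dim, act_dim = 3, 18, 5  # illustrative sizes

# A single actor serves every agent; its input is the observation plus a one-hot ID
shared_actor = nn.Sequential(
    nn.Linear(obs_dim + n_agents, 128), nn.ReLU(), nn.Linear(128, act_dim)
)


def add_agent_ids(obs: torch.Tensor) -> torch.Tensor:
    """(batch, n_agents, obs_dim) -> (batch, n_agents, obs_dim + n_agents)."""
    ids = F.one_hot(torch.arange(n_agents), num_classes=n_agents).float()
    ids = ids.unsqueeze(0).expand(obs.shape[0], -1, -1)  # same IDs for every batch row
    return torch.cat([obs, ids], dim=-1)


obs = torch.randn(8, n_agents, obs_dim)     # 8 timesteps, all agents
logits = shared_actor(add_agent_ids(obs))   # one forward pass: (8, n_agents, act_dim)
```

With these sizes the shared actor has 3,461 parameters, versus 9,231 for three separate actors without the ID input.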

Value decomposition: QMIX

QMIX decomposes the joint action-value function Q_tot into per-agent utilities Q_i while enforcing monotonicity, ∂Q_tot/∂Q_i ≥ 0 for every agent i: raising any individual Q_i never lowers Q_tot.

class QMIXMixer(nn.Module):
    def __init__(self, n_agents: int, state_dim: int, embed_dim: int = 32):
        super().__init__()
        self.n_agents = n_agents
        # Hypernetworks generate mixing weights from global state
        self.hyper_w1 = nn.Sequential(
            nn.Linear(state_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, n_agents * embed_dim),
        )
        self.hyper_w2 = nn.Sequential(
            nn.Linear(state_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim),
        )
        self.hyper_b1 = nn.Linear(state_dim, embed_dim)
        self.hyper_b2 = nn.Sequential(
            nn.Linear(state_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, 1),
        )

    def forward(self, agent_qs: torch.Tensor, state: torch.Tensor):
        # agent_qs: (batch, n_agents), Q_i of each agent's chosen action; state: (batch, state_dim)
        batch_size = agent_qs.size(0)
        agent_qs = agent_qs.unsqueeze(1)  # (batch, 1, n_agents)

        # First layer — monotonicity via abs()
        w1 = torch.abs(self.hyper_w1(state)).view(batch_size, self.n_agents, -1)
        b1 = self.hyper_b1(state).unsqueeze(1)
        hidden = torch.relu(torch.bmm(agent_qs, w1) + b1)

        # Second layer
        w2 = torch.abs(self.hyper_w2(state)).view(batch_size, -1, 1)
        b2 = self.hyper_b2(state).unsqueeze(1)
        q_tot = torch.bmm(hidden, w2) + b2
        return q_tot.squeeze(-1).squeeze(-1)

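The torch.abs() on the hypernetwork outputs keeps every mixing weight non-negative, and because ReLU is non-decreasing, Q_tot is monotonic in each Q_i. The biases depend only on the state, so they need no constraint. This enables decentralised argmax: a joint action that maximises Q_tot is obtained by each agent greedily picking the action that maximises its own Q_i, with no coordination at execution time.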
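The decentralised-argmax claim can be checked numerically. This sketch uses made-up sizes and random stand-ins for the hypernetwork outputs at one fixed state. It mixes per-agent Q-values with the same abs()-constrained two-layer form as `QMIXMixer`, brute-forces Q_tot over all joint actions, and compares the result against each agent greedily maximising its own Q_i:

```python
import itertools

import torch

torch.manual_seed(0)
n_agents, n_actions, embed = 2, 3, 4  # illustrative sizes

# Per-agent utilities Q_i(o_i, a_i) at one timestep: (n_agents, n_actions)
agent_q_table = torch.randn(n_agents, n_actions)

# Random stand-ins for the hypernetwork outputs at one fixed global state;
# abs() keeps the mixing weights non-negative, as in QMIXMixer
w1 = torch.randn(n_agents, embed).abs()
b1 = torch.randn(embed)
w2 = torch.randn(embed, 1).abs()
b2 = torch.randn(1)


def q_tot(agent_qs: torch.Tensor) -> torch.Tensor:
    """Monotonic two-layer mixing of a (n_agents,) vector of chosen-action Q-values."""
    hidden = torch.relu(agent_qs @ w1 + b1)
    return (hidden @ w2 + b2).squeeze()


def chosen_qs(joint_action) -> torch.Tensor:
    """Pick Q_i(a_i) for each agent's action in the joint action."""
    return agent_q_table[torch.arange(n_agents), torch.tensor(joint_action)]


# Centralised argmax: brute force over all n_actions ** n_agents joint actions
joint_actions = list(itertools.product(range(n_actions), repeat=n_agents))
best_score = max(float(q_tot(chosen_qs(ja))) for ja in joint_actions)

# Decentralised argmax: each agent independently maximises its own Q_i
greedy = tuple(agent_q_table.argmax(dim=1).tolist())
greedy_score = float(q_tot(chosen_qs(greedy)))
```

Ties are possible when ReLU units are inactive, so the check is that the greedy joint action attains the maximum, not that it is the unique maximiser. Brute force costs n_actions ** n_agents evaluations, which is exactly what monotonic mixing lets execution avoid.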

PettingZoo integration

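PettingZoo standardises multi-agent environments. Here is a rollout loop using the Parallel API with random actions; a learner would replace the sampling with its policies and consume the stored transitions: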

from pettingzoo.mpe import simple_spread_v3
import numpy as np

env = simple_spread_v3.parallel_env(max_cycles=25, continuous_actions=False)

for episode in range(1000):
    # Seed only the first reset for reproducibility
    observations, infos = env.reset(seed=42 if episode == 0 else None)
    # env.agents empties once every agent is terminated or truncated
    while env.agents:
        actions = {
            agent: env.action_space(agent).sample()  # swap in your policy here
            for agent in env.agents
        }
        observations, rewards, terminations, truncations, infos = env.step(actions)
        # Store transitions per agent for training

env.close()

For CTDE, extract a global state by concatenating all observations:

# Fixed agent order: sorted() would place "agent_10" before "agent_2"
global_state = np.concatenate([observations[a] for a in env.possible_agents])

RLlib multi-agent configuration

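from pettingzoo.mpe import simple_spread_v3
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
from ray.tune.registry import register_env

# RLlib cannot build "simple_spread_v3" from its name alone; register a wrapped env
register_env(
    "simple_spread",
    lambda env_config: ParallelPettingZooEnv(
        simple_spread_v3.parallel_env(max_cycles=25, continuous_actions=False)
    ),
)

config = (
    PPOConfig()
    .environment("simple_spread")
    .multi_agent(
        # One policy shared by all homogeneous agents; spaces are inferred from the env
        policies={"shared_policy"},
        policy_mapping_fn=lambda agent_id, *args, **kwargs: "shared_policy",
    )
    .training(
        train_batch_size=4000,
        minibatch_size=256,  # older Ray releases call this sgd_minibatch_size
    )
)
algo = config.build()
for _ in range(100):
    result = algo.train()
    # Results are nested dicts; metric names differ between Ray versions
    print(f"Episode return mean: {result['env_runners']['episode_return_mean']:.2f}")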

Handling partial observability

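When agents cannot see the full state, a policy that reacts only to the current observation cannot tell apart situations that look the same, so agents need memory of past observations. Options: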

  1. Frame stacking — concatenate the last K observations (cheap, but memory is limited to K steps).
  2. Recurrent policies — use GRU or LSTM in the actor and critic. SB3-Contrib’s RecurrentPPO supports this; for MARL, EPyMARL has recurrent QMIX and MAPPO.
  3. Transformer-based — attention over the observation history. More expressive but heavier.
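Frame stacking (option 1) needs only a small per-agent buffer. A minimal sketch, assuming flat NumPy observation vectors and padding the history with copies of the first observation after a reset (one common convention, not the only one); `FrameStack` is a hypothetical name:

```python
from collections import deque

import numpy as np


class FrameStack:
    """Keeps one agent's last k observations and returns them concatenated."""

    def __init__(self, k: int):
        self.k = k
        self.frames = deque(maxlen=k)  # the oldest frame falls off automatically

    def reset(self, first_obs: np.ndarray) -> np.ndarray:
        # Fill the history with copies of the first observation of the episode
        self.frames.clear()
        for _ in range(self.k):
            self.frames.append(first_obs)
        return self.stacked()

    def step(self, obs: np.ndarray) -> np.ndarray:
        self.frames.append(obs)
        return self.stacked()

    def stacked(self) -> np.ndarray:
        return np.concatenate(list(self.frames))  # shape (k * obs_dim,)
```

In a multi-agent loop you would keep one instance per agent, for example in a dict keyed by `env.possible_agents`.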

Implementation consideration: recurrent policies require storing hidden states in the replay (or rollout) buffer and resetting them at episode boundaries, so that memory from one episode does not leak into the next.
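A minimal sketch of a recurrent actor (option 2) that wipes each row's GRU hidden state when the previous step ended that row's episode. The class name, sizes and the `done_prev` mask convention are assumptions for illustration. The loop stores each step's input hidden state and reset mask, so a learner could later replay the sequence from the same starting memory:

```python
import torch
import torch.nn as nn


class RecurrentActor(nn.Module):
    """GRU actor: the hidden state h summarises past local observations."""

    def __init__(self, obs_dim: int, act_dim: int, hidden: int = 64):
        super().__init__()
        self.hidden = hidden
        self.encoder = nn.Linear(obs_dim, hidden)
        self.gru = nn.GRUCell(hidden, hidden)
        self.head = nn.Linear(hidden, act_dim)

    def initial_state(self, batch: int) -> torch.Tensor:
        return torch.zeros(batch, self.hidden)

    def forward(self, obs: torch.Tensor, h: torch.Tensor, done_prev: torch.Tensor):
        """obs: (batch, obs_dim); h: (batch, hidden); done_prev: (batch,) with 1.0
        where the previous step ended that row's episode."""
        h = h * (1.0 - done_prev).unsqueeze(-1)  # zero the memory at episode boundaries
        h = self.gru(torch.relu(self.encoder(obs)), h)
        return self.head(h), h  # action logits and the new hidden state


torch.manual_seed(0)
actor = RecurrentActor(obs_dim=18, act_dim=5)
h = actor.initial_state(batch=3)  # e.g. 3 agents sharing the actor
done_prev = torch.zeros(3)
buffer = []
for t in range(5):
    obs = torch.randn(3, 18)
    # Store the hidden state and reset mask used at step t alongside the transition
    buffer.append((obs, h.detach(), done_prev))
    logits, h = actor(obs, h, done_prev)
    # Pretend row 2's episode ends at t == 2, so its memory is wiped at t == 3
    done_prev = torch.tensor([0.0, 0.0, 1.0]) if t == 2 else torch.zeros(3)
```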

Emergent behaviour and self-play

In competitive settings, agents can train against copies of themselves (self-play). The policy improves by discovering weaknesses in its own play and then learning to cover them. Techniques:

  • Naive self-play — always train against the latest version. This can lead to cycles, where the policy forgets how to beat strategies it defeated earlier.
  • Fictitious self-play — train against a uniform mix of all past versions. More stable.
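  • Population-based training — maintain a population of policies that train against each other. AlphaStar's league training is a prominent example; OpenAI Five instead trained mostly against its latest self, with a fraction of games against past versions.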
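The first two schemes differ only in how the opponent is drawn. A minimal sketch of an opponent pool with hypothetical names: `latest_prob=1.0` reproduces naive self-play, `latest_prob=0.0` samples uniformly from all stored snapshots (the uniform mix described above), and values in between blend the two. Policies are deep-copied so that later training cannot alter stored snapshots.

```python
import copy
import random
from typing import Any, List, Optional


class OpponentPool:
    """Hypothetical self-play helper holding frozen snapshots of a policy."""

    def __init__(self, latest_prob: float = 0.0, max_size: Optional[int] = None, seed: int = 0):
        self.latest_prob = latest_prob  # 1.0 = naive self-play, 0.0 = uniform over snapshots
        self.max_size = max_size        # None keeps every past version
        self.snapshots: List[Any] = []
        self.rng = random.Random(seed)

    def add(self, policy: Any) -> None:
        """Store a deep copy so later training cannot modify the snapshot."""
        self.snapshots.append(copy.deepcopy(policy))
        if self.max_size is not None and len(self.snapshots) > self.max_size:
            self.snapshots.pop(0)  # evict the oldest version

    def sample(self) -> Any:
        """Draw the opponent for the next episode."""
        if self.rng.random() < self.latest_prob:
            return self.snapshots[-1]
        return self.rng.choice(self.snapshots)
```

In a training loop you would call `add` every few updates and `sample` at the start of each episode; for a PyTorch module, `copy.deepcopy` also copies its parameters.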

Self-play can produce surprisingly creative strategies because agents co-evolve — each breakthrough by one side forces the other to adapt.

Scaling considerations

| Challenge | Solution |
|---|---|
| N agents × M environments = N × M policy forward passes | Parameter sharing; batch all agents through one network |
| Communication overhead in distributed training | Compress messages; use local communication (agents only talk to neighbours) |
| Replay buffer size for N agents | Shared buffer with agent IDs; prioritised sampling per agent |
| Credit assignment | QMIX / COMA / Shapley value attribution |
| Evaluation | Track per-agent and team metrics separately; use Elo ratings in competitive settings |
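For small teams, Shapley attribution can be computed exactly by averaging each agent's marginal contribution over every order in which the team could be assembled. The sketch below assumes a `team_value` function that scores any subset of agents; in a real MARL environment, estimating the value of a partial team is the hard part and is not shown. The `carrier`/`scout`/`idler` example is hypothetical:

```python
from itertools import permutations
from math import factorial


def shapley_values(agents, team_value):
    """Exact Shapley credit for each agent.

    team_value maps a frozenset of agents to that sub-team's return. Each agent's
    credit is its marginal contribution averaged over every order in which the
    team could be assembled; cost is O(n!), so this only suits small teams."""
    credit = {a: 0.0 for a in agents}
    for order in permutations(agents):
        coalition = frozenset()
        for a in order:
            with_a = coalition | {a}
            credit[a] += team_value(with_a) - team_value(coalition)
            coalition = with_a
    n_orders = factorial(len(agents))
    return {a: c / n_orders for a, c in credit.items()}


# Hypothetical task: the team scores 10 only if the carrier AND the scout take part
def team_value(coalition: frozenset) -> float:
    return 10.0 if {"carrier", "scout"} <= coalition else 0.0


credit = shapley_values(["carrier", "scout", "idler"], team_value)
# → {'carrier': 5.0, 'scout': 5.0, 'idler': 0.0}
```

The credits sum to the full team's value (the efficiency property), and an agent that never changes the outcome receives zero. Exact computation is O(n!), so larger teams need sampled permutations or other approximations.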

Debugging MARL training

  1. Start with 2 agents before scaling to N. Most bugs surface with 2.
  2. Verify the environment — ensure rewards, observations, and done flags are correct for every agent.
  3. Plot per-agent reward curves — if one agent’s reward climbs while another’s drops, check for reward hacking or asymmetric information.
  4. Test with a scripted opponent — before training two learning agents against each other, test each against a fixed strategy to verify the learning pipeline works.
  5. Monitor KL divergence of each agent’s policy between updates. Large jumps indicate instability.
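Step 5 can be sketched with `torch.distributions.kl_divergence`, assuming categorical policies and that each agent's logits are evaluated on the same batch of observations just before and just after an update. `per_agent_kl` and the `MAX_KL` threshold are illustrative, not standard values:

```python
import torch
from torch.distributions import Categorical, kl_divergence

MAX_KL = 0.02  # hypothetical alert threshold; tune per task


def per_agent_kl(old_logits: dict, new_logits: dict) -> dict:
    """Mean KL(old || new) per agent.

    Both dicts map agent_id -> (batch, act_dim) logits computed on the SAME
    observations, once before and once after a policy update."""
    return {
        agent: kl_divergence(
            Categorical(logits=old_logits[agent]),
            Categorical(logits=new_logits[agent]),
        ).mean().item()
        for agent in old_logits
    }


torch.manual_seed(0)
old = {"agent_0": torch.randn(16, 5), "agent_1": torch.randn(16, 5)}
new = {
    "agent_0": old["agent_0"].clone(),                     # unchanged policy
    "agent_1": old["agent_1"] + 2.0 * torch.randn(16, 5),  # large update
}
kls = per_agent_kl(old, new)
unstable = [agent for agent, kl in kls.items() if kl > MAX_KL]
```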

The one thing to remember: the CTDE paradigm, a centralised critic for training and decentralised actors for execution, is the practical foundation for most modern multi-agent reinforcement learning in Python.

