Multi-Agent Reinforcement Learning — Deep Dive
Formal framing
A multi-agent problem is typically modelled as a Decentralised Partially Observable Markov Decision Process (Dec-POMDP):
- A set of agents N = {1, …, n}
- Global state space S
- Per-agent observation function O_i(s) — each agent sees only a partial view
- Per-agent action space A_i
- Transition function T(s’ | s, a_1, …, a_n)
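- Reward function R(s, a_1, …, a_n): a single shared team reward in a Dec-POMDP; giving each agent its own reward R_i turns the model into a partially observable stochastic game (POSG)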
The key difference from a single-agent MDP is that the transition and reward depend on the joint action of all agents, but each agent decides based only on its own partial observation.
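To make the tuple concrete, here is a hypothetical two-agent toy problem. The class name, state encoding and reward rule are invented for illustration and do not come from any library. The global state is an integer in {0, 1, 2, 3}. Each agent observes only one bit of it, and both the shared reward and the next state depend on the joint action.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class ToyDecPOMDP:
    """Hypothetical two-agent Dec-POMDP used only to illustrate the definition.

    Global state s is an int in {0, 1, 2, 3}. Agent 1 sees its low bit,
    agent 2 its high bit. Each agent chooses action 0 or 1.
    """

    state: int = 0

    def observe(self, agent: int) -> int:
        # O_i(s): each agent sees only its own bit of the global state
        return self.state % 2 if agent == 1 else self.state // 2

    def step(self, actions: Tuple[int, int]) -> Tuple[Tuple[int, int], float]:
        a_1, a_2 = actions
        # R(s, a_1, a_2): shared reward, 1.0 only if the joint action spells out s
        reward = 1.0 if (a_1, a_2) == (self.state % 2, self.state // 2) else 0.0
        # T(s' | s, a_1, a_2): deterministic here, but it depends on both actions
        self.state = (self.state + a_1 + 2 * a_2 + 1) % 4
        return (self.observe(1), self.observe(2)), reward


# Decentralised execution: each agent acts on its own observation only
env = ToyDecPOMDP(state=2)
obs = (env.observe(1), env.observe(2))
obs, reward = env.step(obs)  # each agent copies its own bit, so reward == 1.0
```

In this toy each agent's own bit is enough to act optimally. In most real problems the local view is not sufficient, and that gap is where coordination becomes hard.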
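CTDE (centralised training, decentralised execution) algorithm deep dive: MAPPO
Multi-Agent PPO (MAPPO) extends PPO with a centralised value function. During training the critic sees the global state, while each actor acts only on its own local observation: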
```python
import torch
import torch.nn as nn


class MAPPOCritic(nn.Module):
    """Centralised critic that sees global state."""

    def __init__(self, global_state_dim: int, hidden: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(global_state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, global_state: torch.Tensor) -> torch.Tensor:
        return self.net(global_state)


class MAPPOActor(nn.Module):
    """Decentralised actor that sees only local observation."""

    def __init__(self, obs_dim: int, act_dim: int, hidden: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, act_dim),
        )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.net(obs)  # logits for Categorical distribution
```
Training loop sketch:
- All agents interact with the environment in parallel, collecting trajectories.
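- The centralised critic estimates state values from the global state (a concatenation of all observations, or a privileged state vector); advantages are then computed from these values, typically with generalised advantage estimation (GAE).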
- Each actor updates its policy using PPO’s clipped objective with the centralised advantages.
- The critic is updated to minimise value prediction error on the global state.
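The four steps above can be sketched as two functions. This sketch assumes discrete actions, a shared team reward, and GAE for the advantages. The function names, the gamma/lam defaults and the clip value of 0.2 are illustrative choices, not fixed parts of MAPPO.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical


def compute_gae(rewards, values, dones, gamma=0.99, lam=0.95):
    """Generalised advantage estimation for one agent's trajectory.

    rewards, dones: (T,); values: (T + 1,) centralised V(s_t), including a
    bootstrap value for the state after the last step.
    """
    T = rewards.shape[0]
    advantages = torch.zeros(T)
    gae = torch.tensor(0.0)
    for t in reversed(range(T)):
        not_done = 1.0 - dones[t]
        # One-step TD error using the centralised critic's values
        delta = rewards[t] + gamma * values[t + 1] * not_done - values[t]
        gae = delta + gamma * lam * not_done * gae
        advantages[t] = gae
    return advantages, advantages + values[:-1]  # (advantages, value targets)


def mappo_losses(actor, critic, obs, global_state, actions, old_log_probs,
                 advantages, returns, clip_eps=0.2):
    """PPO clipped actor loss on local obs, MSE critic loss on global state."""
    dist = Categorical(logits=actor(obs))
    ratio = torch.exp(dist.log_prob(actions) - old_log_probs)
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
    # Pessimistic (minimum) of the unclipped and clipped surrogate objectives
    actor_loss = -torch.min(ratio * advantages, clipped * advantages).mean()
    values = critic(global_state).squeeze(-1)
    critic_loss = F.mse_loss(values, returns)
    return actor_loss, critic_loss
```

In a real run, the values passed to compute_gae come from the critic under torch.no_grad(), and advantages are often normalised per batch before the update.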
Parameter sharing
When agents are homogeneous (same role), share one set of actor parameters across all agents and distinguish them by appending a one-hot agent ID to each observation. The parameter count then stays fixed as agents are added, and convergence is usually faster because every agent's experience trains the same network.
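A minimal sketch of parameter sharing, assuming every agent has the same observation size and observations arrive as a (batch, n_agents, obs_dim) tensor. The helper add_agent_ids and the SharedActor class are illustrative names, not part of any library.

```python
import torch
import torch.nn as nn


def add_agent_ids(obs: torch.Tensor) -> torch.Tensor:
    """Append a one-hot agent ID to each agent's observation.

    obs: (batch, n_agents, obs_dim) -> (batch, n_agents, obs_dim + n_agents)
    """
    batch, n_agents, _ = obs.shape
    ids = torch.eye(n_agents, dtype=obs.dtype, device=obs.device)  # (n_agents, n_agents)
    ids = ids.unsqueeze(0).expand(batch, -1, -1)                    # (batch, n_agents, n_agents)
    return torch.cat([obs, ids], dim=-1)


class SharedActor(nn.Module):
    """One actor network used by every homogeneous agent."""

    def __init__(self, obs_dim: int, n_agents: int, act_dim: int, hidden: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim + n_agents, hidden),
            nn.ReLU(),
            nn.Linear(hidden, act_dim),
        )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        # Fold the agent axis into the batch so all agents share one forward pass
        batch, n_agents, _ = obs.shape
        x = add_agent_ids(obs).reshape(batch * n_agents, -1)
        return self.net(x).reshape(batch, n_agents, -1)  # per-agent logits
```

Folding agents into the batch dimension means one network call serves every agent, regardless of how many there are.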
Value decomposition: QMIX
QMIX decomposes the joint action-value function Q_tot into per-agent utilities Q_i and combines them with a mixing network that is monotonic in each input, ∂Q_tot/∂Q_i ≥ 0, so raising any agent's Q_i can never lower Q_tot:
```python
import torch
import torch.nn as nn


class QMIXMixer(nn.Module):
    def __init__(self, n_agents: int, state_dim: int, embed_dim: int = 32):
        super().__init__()
        self.n_agents = n_agents
        # Hypernetworks generate mixing weights and biases from the global state
        self.hyper_w1 = nn.Sequential(
            nn.Linear(state_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, n_agents * embed_dim),
        )
        self.hyper_w2 = nn.Sequential(
            nn.Linear(state_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim),
        )
        self.hyper_b1 = nn.Linear(state_dim, embed_dim)
        self.hyper_b2 = nn.Sequential(
            nn.Linear(state_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, 1),
        )

    def forward(self, agent_qs: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
        # agent_qs: (batch, n_agents) chosen-action Q-values; state: (batch, state_dim)
        batch_size = agent_qs.size(0)
        agent_qs = agent_qs.unsqueeze(1)  # (batch, 1, n_agents)
        # First layer: abs() keeps the weights non-negative, which gives monotonicity
        w1 = torch.abs(self.hyper_w1(state)).view(batch_size, self.n_agents, -1)  # (batch, n_agents, embed)
        b1 = self.hyper_b1(state).unsqueeze(1)  # (batch, 1, embed)
        hidden = torch.relu(torch.bmm(agent_qs, w1) + b1)  # (batch, 1, embed)
        # Second layer: non-negative weights again; biases may take any sign
        w2 = torch.abs(self.hyper_w2(state)).view(batch_size, -1, 1)  # (batch, embed, 1)
        b2 = self.hyper_b2(state).unsqueeze(1)  # (batch, 1, 1)
        q_tot = torch.bmm(hidden, w2) + b2  # (batch, 1, 1)
        return q_tot.squeeze(-1).squeeze(-1)  # (batch,)
```
The torch.abs() on the hyper-network weights enforces non-negative mixing, guaranteeing monotonicity. This enables decentralised argmax: each agent can greedily pick its best action without coordinating.
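The decentralised-argmax claim can be checked numerically. The sketch below re-implements a bare-bones monotonic mixer with random, untrained weights (it does not reuse QMIXMixer, so it runs on its own). It then compares each agent's greedy choice against a brute-force search over all joint actions. The reasoning is that if every Q_i is at its maximum and Q_tot is non-decreasing in each Q_i, no other joint action can give a larger Q_tot. The function names are illustrative.

```python
import itertools
import torch


def monotonic_mix(agent_qs, w1, b1, w2, b2):
    """Two-layer mixer whose weights pass through abs(), as in QMIX.

    agent_qs: (n_agents,), w1: (n_agents, embed), b1: (embed,),
    w2: (embed,), b2: (1,). Returns a one-element tensor Q_tot.
    """
    hidden = torch.relu(agent_qs @ w1.abs() + b1)
    return hidden @ w2.abs() + b2


def check_decentralised_argmax(n_agents=3, n_actions=4, embed=8, seed=0):
    g = torch.Generator().manual_seed(seed)
    q = torch.randn(n_agents, n_actions, generator=g)  # per-agent utilities Q_i(a_i)
    w1 = torch.randn(n_agents, embed, generator=g)
    b1 = torch.randn(embed, generator=g)
    w2 = torch.randn(embed, generator=g)
    b2 = torch.randn(1, generator=g)
    agents = torch.arange(n_agents)

    # Decentralised: every agent independently takes its own argmax
    greedy = q.argmax(dim=1)
    q_greedy = monotonic_mix(q[agents, greedy], w1, b1, w2, b2).item()

    # Centralised: brute force over all n_actions ** n_agents joint actions
    q_best = max(
        monotonic_mix(q[agents, torch.tensor(joint)], w1, b1, w2, b2).item()
        for joint in itertools.product(range(n_actions), repeat=n_agents)
    )
    return q_greedy, q_best
```

With 3 agents and 4 actions, the brute force evaluates 64 joint actions, while the greedy route only scans 4 actions per agent. That gap grows exponentially with the number of agents.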
PettingZoo integration
PettingZoo standardises multi-agent environments. Here is a parallel-API rollout loop. It samples random actions, so swap in your policy to turn it into a training loop:
```python
from pettingzoo.mpe import simple_spread_v3
import numpy as np

env = simple_spread_v3.parallel_env(max_cycles=25, continuous_actions=False)

for episode in range(1000):
    # Seed only the first reset so the run is reproducible
    observations, infos = env.reset(seed=42 if episode == 0 else None)
    # env.agents empties once every agent has terminated or been truncated
    while env.agents:
        # Random actions; replace with your policy's actions when training
        actions = {
            agent: env.action_space(agent).sample()
            for agent in env.agents
        }
        observations, rewards, terminations, truncations, infos = env.step(actions)
        # Store transitions per agent for training

env.close()
```
For CTDE, extract a global state by concatenating all observations:

```python
global_state = np.concatenate([observations[a] for a in sorted(observations)])
```
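Sorting the keys of observations only gives a stable layout while every agent is still present. Once an agent leaves the episode, the concatenated vector shrinks and the critic's input size changes. A hedged sketch of one workaround, assuming all agents share the same observation size: fix the order up front (for example from env.possible_agents) and zero-pad absent agents. build_global_state is a hypothetical helper, not a PettingZoo function.

```python
import numpy as np


def build_global_state(observations: dict, agent_order: list, obs_dim: int) -> np.ndarray:
    """Concatenate per-agent observations in a fixed agent order.

    Agents missing from `observations` (e.g. already terminated) are
    zero-padded so the critic's input size never changes.
    """
    parts = [
        np.asarray(observations[a], dtype=np.float32) if a in observations
        else np.zeros(obs_dim, dtype=np.float32)
        for a in agent_order
    ]
    return np.concatenate(parts)
```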
RLlib multi-agent configuration
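The configuration below registers the PettingZoo environment with RLlib and trains one policy shared by all agents. Stock PPO conditions its value function on each agent's own observation, so this is parameter sharing rather than a centralised critic.

```python
from pettingzoo.mpe import simple_spread_v3
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
from ray.tune.registry import register_env

# RLlib can only build environments that have been registered by name
register_env(
    "simple_spread_v3",
    lambda env_config: ParallelPettingZooEnv(simple_spread_v3.parallel_env()),
)

config = (
    PPOConfig()
    .environment("simple_spread_v3")
    .multi_agent(
        # A set of policy IDs lets RLlib infer observation/action spaces from the env
        policies={"shared_policy"},
        # Every agent maps to the same policy (parameter sharing)
        policy_mapping_fn=lambda agent_id, *args, **kwargs: "shared_policy",
    )
    .training(
        train_batch_size=4000,
        minibatch_size=256,  # called sgd_minibatch_size in older Ray releases
    )
)

algo = config.build()
for _ in range(100):
    result = algo.train()
    # Results are nested dicts; metric names differ across Ray versions
    print(f"Episode return mean: {result['env_runners']['episode_return_mean']:.2f}")
```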
Handling partial observability
When agents cannot see the full state, they need memory. Options:
- Frame stacking: concatenate the last K observations (cheap, limited).
- Recurrent policies: use a GRU or LSTM in the actor and critic. SB3-Contrib’s RecurrentPPO supports this; for MARL, EPyMARL has recurrent QMIX and MAPPO.
- Transformer-based: attention over the observation history. More expressive but heavier.
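Frame stacking needs nothing more than a fixed-length queue per agent. The FrameStacker class below is a hypothetical helper that assumes flat vector observations. At the start of an episode it pads the history with copies of the first observation, which is one common convention.

```python
from collections import deque

import numpy as np


class FrameStacker:
    """Keeps the last k observations for one agent and returns them concatenated."""

    def __init__(self, k: int):
        self.k = k
        self.frames = deque(maxlen=k)

    def reset(self, first_obs) -> np.ndarray:
        # Fill the history with copies of the first observation at episode start
        self.frames.clear()
        for _ in range(self.k):
            self.frames.append(np.asarray(first_obs, dtype=np.float32))
        return self.stacked()

    def push(self, obs) -> np.ndarray:
        # deque(maxlen=k) drops the oldest frame automatically
        self.frames.append(np.asarray(obs, dtype=np.float32))
        return self.stacked()

    def stacked(self) -> np.ndarray:
        return np.concatenate(list(self.frames))
```

Keep one stacker per agent (for example a dict keyed by agent name) and feed the stacked vector to the actor in place of the raw observation. The actor's input size becomes k times the observation size.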
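Implementation consideration: recurrent policies need their hidden states handled explicitly. Either store hidden states alongside transitions in the buffer, or store whole episodes and re-run the RNN from the start; in both cases, reset the hidden state at episode boundaries.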
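A sketch of a recurrent actor, assuming discrete actions and one time step per call. The done argument zeroes the hidden state when the previous step ended an episode, which is one common way to implement the reset described above. The class and argument names are illustrative.

```python
import torch
import torch.nn as nn


class RecurrentActor(nn.Module):
    """GRU-based actor for partially observable agents."""

    def __init__(self, obs_dim: int, act_dim: int, hidden: int = 64):
        super().__init__()
        self.hidden = hidden
        self.encoder = nn.Linear(obs_dim, hidden)
        self.gru = nn.GRUCell(hidden, hidden)
        self.head = nn.Linear(hidden, act_dim)

    def initial_state(self, batch: int) -> torch.Tensor:
        return torch.zeros(batch, self.hidden)

    def forward(self, obs: torch.Tensor, h: torch.Tensor, done: torch.Tensor):
        """One time step.

        obs: (batch, obs_dim); h: (batch, hidden); done: (batch,) floats,
        1.0 where the previous step ended an episode.
        """
        h = h * (1.0 - done).unsqueeze(-1)  # wipe memory at episode boundaries
        x = torch.relu(self.encoder(obs))
        h = self.gru(x, h)
        return self.head(h), h  # (logits, new hidden state)
```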
Emergent behaviour and self-play
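In competitive settings, agents can train against copies of themselves (self-play). The policy improves by discovering and exploiting weaknesses in its own earlier behaviour, which it must then learn to cover. Techniques:
- Naive self-play: always train against the latest version. Can lead to strategy cycles (A beats B, B beats C, C beats A).
- Fictitious self-play: train against a uniform mix of all past versions. More stable.
- Population-based training: maintain a population of policies that train against each other. AlphaStar's league training is a well-known example; OpenAI Five instead combined games against its latest version with games against past versions.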
Self-play can produce surprisingly creative strategies because agents co-evolve — each breakthrough by one side forces the other to adapt.
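A sketch of opponent sampling that covers the first two schemes, assuming policies can be snapshotted (for example with copy.deepcopy of a model's state_dict). OpponentPool and its p_latest parameter are hypothetical names. Setting p_latest=1.0 gives naive self-play, and lower values mix in past versions. That mixing is a simplified stand-in for fictitious self-play, not the full algorithm.

```python
import random


class OpponentPool:
    """Stores frozen policy snapshots and samples self-play opponents.

    With probability p_latest the newest snapshot is returned (naive self-play);
    otherwise an opponent is drawn uniformly from every stored snapshot.
    """

    def __init__(self, p_latest: float = 0.8, seed: int = 0):
        self.snapshots = []
        self.p_latest = p_latest
        self.rng = random.Random(seed)

    def add(self, snapshot) -> None:
        self.snapshots.append(snapshot)

    def sample(self):
        if not self.snapshots:
            raise ValueError("pool is empty")
        if self.rng.random() < self.p_latest:
            return self.snapshots[-1]
        return self.rng.choice(self.snapshots)
```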
Scaling considerations
| Challenge | Solution |
|---|---|
| N agents × M environments = N*M policy forwards | Parameter sharing; batch all agents through one network |
| Communication overhead in distributed training | Compress messages; use local communication (agents only talk to neighbours) |
| Replay buffer size for N agents | Shared buffer with agent IDs; prioritised sampling per agent |
| Credit assignment | QMIX / COMA / Shapley value attribution |
| Evaluation | Track per-agent and team metrics separately; use Elo ratings in competitive settings |
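For the competitive evaluation row, the standard Elo update is short enough to write out. This sketch uses the common logistic form with a 400-point scale and a K-factor of 32. Both constants are conventions you can tune.

```python
def elo_update(rating_a: float, rating_b: float, score_a: float, k: float = 32.0):
    """Standard Elo update after one game between A and B.

    score_a is 1.0 for a win by A, 0.5 for a draw, 0.0 for a loss.
    Returns the new (rating_a, rating_b); the total rating is conserved.
    """
    expected_a = 1.0 / (1.0 + 10 ** ((rating_b - rating_a) / 400.0))
    delta = k * (score_a - expected_a)
    return rating_a + delta, rating_b - delta


# Two equally rated policies: the winner gains k / 2 points
new_a, new_b = elo_update(1500.0, 1500.0, 1.0)  # (1516.0, 1484.0)
```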
Debugging MARL training
- Start with 2 agents before scaling to N. Most bugs surface with 2.
- Verify the environment — ensure rewards, observations, and done flags are correct for every agent.
- Plot per-agent reward curves — if one agent’s reward climbs while another’s drops, check for reward hacking or asymmetric information.
- Test with a scripted opponent — before training two learning agents against each other, test each against a fixed strategy to verify the learning pipeline works.
- Monitor KL divergence of each agent’s policy between updates. Large jumps indicate instability.
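Per-agent KL can be computed from the logits an actor produced before and after an update on the same batch. This sketch assumes categorical policies and logits shaped (batch, n_agents, n_actions). categorical_kl is an illustrative helper, and the threshold for a "large" jump is left to you.

```python
import torch
import torch.nn.functional as F


def categorical_kl(old_logits: torch.Tensor, new_logits: torch.Tensor) -> torch.Tensor:
    """Mean KL(old || new) between categorical policies, reported per agent.

    Both inputs: (batch, n_agents, n_actions). Returns shape (n_agents,).
    """
    old_logp = F.log_softmax(old_logits, dim=-1)
    new_logp = F.log_softmax(new_logits, dim=-1)
    kl = (old_logp.exp() * (old_logp - new_logp)).sum(dim=-1)  # (batch, n_agents)
    return kl.mean(dim=0)
```

Logging one value per agent makes it easy to spot a single agent whose policy is lurching while the others stay stable.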
The one thing to remember: the CTDE paradigm (centralised critic for training, decentralised actors for execution) is the practical foundation for most modern cooperative multi-agent reinforcement learning in Python.
See Also
- Python Environment Wrappers: How thin add-on layers let you change what a learning program sees and does without rewriting the game itself
- Python Monte Carlo Tree Search: The clever trick behind AlphaGo, how a program explores millions of possible moves by playing quick random games against itself
- Python Openai Gym Environments: Why OpenAI Gym is the playground where robots and programs learn by trial and error, no prior coding knowledge needed
- Python Policy Gradient Methods: Instead of scoring every move, what if the program just learned which moves feel right? That is policy gradients
- Python Q Learning Implementation: How a program builds a cheat sheet of every situation and every action to figure out the best move, no teacher required