Environment Wrappers — Deep Dive
Writing custom wrappers
ObservationWrapper example: relative coordinates
Convert absolute positions into positions relative to the agent, which often simplifies the learning problem. The example assumes a flat observation vector made of consecutive (x, y) pairs:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
class RelativeObservation(gym.ObservationWrapper):
    """Subtract the agent's position from every (x, y) pair in the observation.

    The observation is assumed to be a flat ``Box`` made of consecutive
    (x, y) pairs. ``agent_idx`` is the flat index of the agent's x coordinate,
    so the agent's position is ``obs[agent_idx:agent_idx + 2]``. It must
    therefore be even and leave room for the y coordinate. After the
    transformation the agent's own pair is (0, 0).

    Raises:
        ValueError: if the observation space is not a 1-D Box with an even
            number of entries, or ``agent_idx`` does not address a full pair.
    """

    def __init__(self, env: gym.Env, agent_idx: int = 0):
        super().__init__(env)
        self.agent_idx = agent_idx
        space = self.observation_space
        if (
            not isinstance(space, spaces.Box)
            or len(space.shape) != 1
            or space.shape[0] % 2
        ):
            raise ValueError(
                "RelativeObservation needs a 1-D Box observation of (x, y) "
                f"pairs, got {space}"
            )
        size = space.shape[0]
        if agent_idx % 2 or not 0 <= agent_idx <= size - 2:
            raise ValueError(
                f"agent_idx must be an even index in [0, {size - 2}], "
                f"got {agent_idx}"
            )
        # Each relative pair is pos - agent_pos, so its bounds combine the
        # pair's own bounds with the AGENT's bounds:
        #   [low_pos - high_agent, high_pos - low_agent]
        # The element-wise (low - high) form is wrong for every pair other
        # than the agent's.
        low = space.low.astype(np.float64).reshape(-1, 2)
        high = space.high.astype(np.float64).reshape(-1, 2)
        agent_pair = agent_idx // 2
        rel_low = (low - high[agent_pair]).reshape(-1)
        rel_high = (high - low[agent_pair]).reshape(-1)
        self.observation_space = spaces.Box(
            low=rel_low.astype(np.float32),
            high=rel_high.astype(np.float32),
            dtype=np.float32,
        )

    def observation(self, obs: np.ndarray) -> np.ndarray:
        """Return ``obs`` with the agent's (x, y) subtracted from every pair."""
        pairs = np.asarray(obs, dtype=np.float64).reshape(-1, 2)
        # Every pair, including the agent's own, subtracts the ORIGINAL agent
        # position; the subtraction builds a new array, so nothing is mutated.
        relative = pairs - pairs[self.agent_idx // 2]
        # Cast so the output matches the declared float32 observation_space.
        return relative.reshape(-1).astype(np.float32)
ActionWrapper example: discretise continuous actions
class DiscretiseAction(gym.ActionWrapper):
    """Map discrete indices to pre-defined continuous actions.

    Index ``k`` selects the k-th of ``n_bins`` evenly spaced points between
    ``action_space.low`` and ``action_space.high``, inclusive of both ends.
    For a multi-dimensional Box, every dimension moves together. The result is
    a single "diagonal" line through the action box, not a per-dimension grid.

    Raises:
        ValueError: if the wrapped action space has non-finite bounds, which
            cannot be discretised.
    """

    def __init__(self, env: gym.Env, n_bins: int = 5):
        super().__init__(env)
        low = env.action_space.low
        high = env.action_space.high
        if not (np.all(np.isfinite(low)) and np.all(np.isfinite(high))):
            raise ValueError(
                "DiscretiseAction needs a bounded action space, "
                f"got {env.action_space}"
            )
        # np.linspace yields float64. Cast to the space's own dtype (typically
        # float32), because Box.contains rejects arrays that cannot be safely
        # cast to the space dtype.
        self.actions = np.linspace(low, high, n_bins).astype(
            env.action_space.dtype
        )
        self.action_space = spaces.Discrete(n_bins)

    def action(self, act: int) -> np.ndarray:
        """Return the continuous action for discrete index ``act``."""
        # Return a copy so downstream code that edits the action in place
        # cannot corrupt the lookup table.
        return self.actions[int(act)].copy()
RewardWrapper example: potential-based shaping
class DistanceShaping(gym.RewardWrapper):
    """Add potential-based reward shaping using distance to goal.

    The potential is ``phi(s) = -||s[:2] - goal||``. Each step adds
    ``gamma * phi(s') - phi(s)`` to the environment reward. ``s'`` is the
    observation returned by the wrapped env's ``step``, so shaping sees exactly
    what the layers below produced rather than re-reading the base env.
    Assumes ``obs[:2]`` is the agent's (x, y) position (TODO confirm for your
    env).

    Args:
        env: environment to wrap.
        goal: goal position compared against ``obs[:2]``.
        gamma: discount factor. It should match the learner's gamma for the
            shaping to leave the optimal policy unchanged.
    """

    def __init__(self, env: gym.Env, goal: np.ndarray, gamma: float = 0.99):
        super().__init__(env)
        self.goal = np.asarray(goal, dtype=np.float64)
        self.gamma = gamma
        self._prev_potential = 0.0
        # Potential of the state reached by the most recent step; set in step()
        # and consumed by reward().
        self._next_potential = 0.0

    def _potential(self, obs) -> float:
        """Return the negative Euclidean distance from ``obs[:2]`` to the goal."""
        position = np.asarray(obs, dtype=np.float64)[:2]
        return -float(np.linalg.norm(position - self.goal))

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        self._prev_potential = self._potential(obs)
        return obs, info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        # A truly terminal state gets potential 0. This is the standard
        # condition for potential-based shaping to stay policy-invariant in
        # episodic tasks. A truncated step keeps its real potential: that
        # episode was only cut off, not ended by the environment.
        self._next_potential = 0.0 if terminated else self._potential(obs)
        return obs, self.reward(reward), terminated, truncated, info

    def reward(self, reward: float) -> float:
        """Return ``reward`` plus the shaping term for the latest transition."""
        shaping = self.gamma * self._next_potential - self._prev_potential
        self._prev_potential = self._next_potential
        return reward + shaping
Space transformation correctness
The most common wrapper bug is forgetting to update the observation or action space. RL libraries read these spaces to configure networks:
class BadWrapper(gym.ObservationWrapper):
    """Deliberately broken example: transforms observations but not the space.

    ``observation`` returns only the first 4 entries, but no
    ``observation_space`` is assigned, so the wrapper keeps advertising the
    wrapped env's space. Kept wrong on purpose to illustrate the bug.
    """
    def observation(self, obs):
        return obs[:4] # Truncate to first 4 elements
    # BUG: observation_space still reports original shape!
class GoodWrapper(gym.ObservationWrapper):
    """Keep the first four observation entries and declare a matching Box.

    The declared bounds and dtype are sliced from the wrapped env's
    observation space, so ``observation_space`` and ``observation`` agree.
    """

    def __init__(self, env):
        super().__init__(env)
        source = env.observation_space
        keep = slice(0, 4)
        self.observation_space = spaces.Box(
            low=source.low[keep],
            high=source.high[keep],
            dtype=source.dtype,
        )

    def observation(self, obs):
        # Must use the same slice as the bounds declared in __init__.
        return obs[:4]
Always verify with:
# Reset once and confirm the very first observation fits the declared space.
obs, _ = wrapped_env.reset()
declared_space = wrapped_env.observation_space
assert declared_space.contains(obs), \
    f"Obs shape {obs.shape} not in space {declared_space}"
The full Wrapper base class
For transformations that affect multiple methods, subclass gym.Wrapper directly:
class CurriculumWrapper(gym.Wrapper):
    """Gradually increase environment difficulty.

    The current difficulty is passed to the wrapped env on every reset as
    ``options["difficulty"]``. An episode counts as a success when it
    terminates (not truncates) with a positive final-step reward. After every
    ``window`` completed episodes, the difficulty increases by one if the
    success rate in that window exceeds ``threshold``.

    Args:
        env: environment to wrap; its ``reset`` must accept an ``options`` dict.
        max_difficulty: upper bound for ``difficulty``.
        window: number of completed episodes per evaluation window.
        threshold: success rate that must be exceeded to level up.

    Raises:
        ValueError: if ``window`` is smaller than 1.
    """

    def __init__(self, env: gym.Env, max_difficulty: int = 10,
                 window: int = 100, threshold: float = 0.8):
        super().__init__(env)
        if window < 1:
            raise ValueError(f"window must be >= 1, got {window}")
        self.difficulty = 1
        self.max_difficulty = max_difficulty
        self.window = window
        self.threshold = threshold
        self.episode_count = 0  # number of reset() calls
        self.success_count = 0  # successes in the current window
        self._window_episodes = 0  # completed episodes in the current window

    def reset(self, **kwargs):
        self.episode_count += 1
        # Copy the options so the caller's dict is never mutated, and tolerate
        # an explicit options=None.
        options = dict(kwargs.get("options") or {})
        options["difficulty"] = self.difficulty
        kwargs["options"] = options
        return self.env.reset(**kwargs)

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        # Evaluate only once per finished episode. Checking on every step
        # would run the window test repeatedly and before the episode's
        # outcome is known.
        if terminated or truncated:
            self._record_episode(bool(terminated and reward > 0))
        return obs, reward, terminated, truncated, info

    def _record_episode(self, success: bool) -> None:
        """Count one finished episode and level up at the end of a window."""
        self._window_episodes += 1
        if success:
            self.success_count += 1
        if self._window_episodes < self.window:
            return
        success_rate = self.success_count / self.window
        if success_rate > self.threshold and self.difficulty < self.max_difficulty:
            self.difficulty += 1
            print(f"Difficulty increased to {self.difficulty}")
        self.success_count = 0
        self._window_episodes = 0

    def increase_difficulty(self):
        """External call to increase difficulty (e.g., from a callback)."""
        self.difficulty = min(self.difficulty + 1, self.max_difficulty)
Vectorised environment wrappers
When using gymnasium.vector.SyncVectorEnv or AsyncVectorEnv, standard wrappers do not work directly because the vectorised env handles batches. Use gymnasium.wrappers.vector wrappers or apply wrappers to individual sub-environments:
import gymnasium as gym
def make_env(idx: int):
    """Return a zero-argument factory that builds one wrapped CartPole sub-env.

    ``idx`` lets callers create one factory per sub-environment; it is not used
    inside the factory.
    """
    def _init():
        env = gym.make("CartPole-v1")
        # Apply the time limit INSIDE RecordEpisodeStatistics so the statistics
        # wrapper sees the truncation signal and records time-limited episodes.
        env = gym.wrappers.TimeLimit(env, max_episode_steps=500)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        return env
    return _init


# Each sub-env gets its own wrapper stack
vec_env = gym.vector.SyncVectorEnv([make_env(i) for i in range(4)])
For normalisation across the whole vector, use Stable-Baselines3's VecNormalize, which maintains running statistics shared by all sub-environments:
from stable_baselines3.common.vec_env import VecNormalize, DummyVecEnv
# Build the raw vector env from the per-env factories, then wrap it so that
# observations and rewards are normalised with statistics shared by all 4 copies.
base_vec_env = DummyVecEnv([make_env(rank) for rank in range(4)])
vec_env = VecNormalize(
    base_vec_env,
    norm_obs=True,
    norm_reward=True,
    clip_obs=10.0,
)
Debugging wrapper chains
Inspect the chain
def print_wrapper_chain(env):
"""Print the full wrapper stack."""
current = env
depth = 0
while hasattr(current, 'env'):
print(f"{' ' * depth}{type(current).__name__}")
print(f"{' ' * depth} obs_space: {current.observation_space}")
print(f"{' ' * depth} act_space: {current.action_space}")
current = current.env
depth += 1
print(f"{' ' * depth}{type(current).__name__} (base)")
# Example call; wrapped_env is assumed to be a wrapped environment built earlier.
print_wrapper_chain(wrapped_env)
Compare wrapped vs unwrapped
def debug_step(env, action):
    """Run ONE step and compare the wrapped output with the raw base-env output.

    The base env's ``step`` is intercepted for the duration of the call. This
    captures the raw observation and reward of the SAME transition, plus the
    action that actually reached the base env after any action wrappers. The
    environment advances exactly once. Stepping ``env.unwrapped`` separately
    would advance it twice, desynchronise wrapper state, and feed it the
    untransformed action. If a wrapper calls the base ``step`` several times
    (frame skipping), the last call is shown.

    Returns:
        The wrapped ``(obs, reward, terminated, truncated, info)`` tuple.
    """
    raw_env = env.unwrapped
    own_attrs = vars(raw_env)
    had_own_step = "step" in own_attrs
    previous_step = own_attrs.get("step")
    inner_step = raw_env.step
    captured = {}

    def _recording_step(raw_action):
        result = inner_step(raw_action)
        captured["action"] = raw_action
        captured["result"] = result
        return result

    # Wrappers call self.env.step(...), so an instance attribute on the base
    # env intercepts the innermost call without touching any wrapper.
    raw_env.step = _recording_step
    try:
        obs, reward, term, trunc, info = env.step(action)
    finally:
        if had_own_step:
            raw_env.step = previous_step
        else:
            del raw_env.step

    # np.shape also works for scalar / non-array observations.
    print(f"Wrapped: obs={np.shape(obs)}, reward={reward:.4f}")
    if "result" in captured:
        raw_obs, raw_reward, *_ = captured["result"]
        print(f"Unwrapped: obs={np.shape(raw_obs)}, reward={raw_reward:.4f}, "
              f"action={captured['action']}")
    else:
        print("Unwrapped: base env step was not called by the wrapper chain")
    return obs, reward, term, trunc, info
The check_env utility
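Gymnasium’s env checker can validate both the base environment and the fully wrapped stack. Check the base first: if it passes but the wrapped version fails, the problem lies in one of the wrappers: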
from gymnasium.utils.env_checker import check_env
# Validate the base environment first; if it passes but the wrapped stack
# fails, the problem was introduced by one of the wrappers.
check_env(wrapped_env.unwrapped) # Check base
# NOTE(review): some Gymnasium versions warn when check_env receives a wrapped
# env rather than env.unwrapped; confirm for the installed version.
check_env(wrapped_env) # Check with wrappers
Advanced patterns
Conditional wrappers
Apply wrappers only during training:
def make_env(training: bool = True):
    """Build a HalfCheetah env whose preprocessing matches between train and eval.

    Observation normalisation and action clipping are applied in BOTH modes.
    A policy trained on normalised observations must see normalised
    observations at evaluation time too; restore the training statistics into
    the returned env before evaluating. Only reward normalisation is
    training-only, because it affects learning targets, not the policy's
    inputs.

    Args:
        training: if True, observation statistics keep updating and rewards
            are normalised. If False, observation statistics are frozen where
            supported and rewards are left raw.

    Returns:
        The wrapped environment.
    """
    env = gym.make("HalfCheetah-v4")
    # Innermost, so logged episode returns are the raw environment rewards
    # rather than the normalised rewards produced by NormalizeReward.
    env = gym.wrappers.RecordEpisodeStatistics(env)
    env = gym.wrappers.ClipAction(env)
    env = gym.wrappers.NormalizeObservation(env)
    if training:
        env = gym.wrappers.NormalizeReward(env)
    else:
        # Freeze the running statistics during evaluation.
        # NOTE(review): recent Gymnasium versions expose this as a
        # NormalizeObservation property. On older versions the assignment is
        # an unused attribute and the statistics keep updating. Confirm for
        # the installed version.
        env.update_running_mean = False
    return env
Wrapper factories
Encapsulate common stacks:
def atari_preprocessing(env_id: str, frame_stack: int = 4):
    """Build a DQN-style Atari stack: no-op starts, 4x frame skip, 84x84 grayscale, stacked frames.

    Args:
        env_id: Gymnasium id of an ALE environment.
        frame_stack: number of most recent processed frames to stack.

    Returns:
        The wrapped environment.
    """
    # AtariPreprocessing performs the frame skipping itself (frame_skip=4
    # below), so the underlying ALE env must not skip frames as well.
    # AtariPreprocessing rejects base envs that do (e.g. ALE/*-v5 ids default
    # to frameskip=4).
    env = gym.make(env_id, render_mode=None, frameskip=1)
    env = gym.wrappers.AtariPreprocessing(
        env, noop_max=30, frame_skip=4,
        screen_size=84, grayscale_obs=True,
        grayscale_newaxis=False,
    )
    # NOTE(review): newer Gymnasium releases rename FrameStack to
    # FrameStackObservation; adjust for the installed version.
    env = gym.wrappers.FrameStack(env, frame_stack)
    return env
def continuous_control_stack(env_id: str):
    """Standard preprocessing stack for continuous-control tasks.

    Innermost to outermost: RecordEpisodeStatistics -> ClipAction ->
    NormalizeObservation -> NormalizeReward.

    Args:
        env_id: Gymnasium id of a continuous-action environment.

    Returns:
        The wrapped environment.
    """
    env = gym.make(env_id)
    # Record first: placed outside NormalizeReward it would log normalised
    # rewards, making the reported episode returns meaningless.
    env = gym.wrappers.RecordEpisodeStatistics(env)
    env = gym.wrappers.ClipAction(env)
    env = gym.wrappers.NormalizeObservation(env)
    env = gym.wrappers.NormalizeReward(env)
    return env
Stateful wrappers and serialisation
Wrappers with running statistics (like NormalizeObservation) carry state that must be saved and restored alongside the model:
import pickle
# Save: snapshot the observation normaliser's running statistics.
with open("wrapper_state.pkl", "wb") as f:
    snapshot = {
        "obs_rms_mean": env.obs_rms.mean,
        "obs_rms_var": env.obs_rms.var,
        "obs_rms_count": env.obs_rms.count,
    }
    pickle.dump(snapshot, f)

# Restore: write the saved statistics back into the normaliser.
# SECURITY: pickle.load can execute arbitrary code. Only load files you
# produced yourself, never files from untrusted sources.
with open("wrapper_state.pkl", "rb") as f:
    state = pickle.load(f)
for attr, key in (
    ("mean", "obs_rms_mean"),
    ("var", "obs_rms_var"),
    ("count", "obs_rms_count"),
):
    setattr(env.obs_rms, attr, state[key])
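Forgetting this step is one of the most common deployment bugs: a freshly created normalisation wrapper starts from default statistics, so the model sees observations scaled differently from training (effectively unnormalised) and behaves erratically.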
Performance impact
| Wrapper | Overhead per step | Notes |
|---|---|---|
| ClipAction | ~1 μs | Negligible |
| NormalizeObservation | ~5 μs | Running stats update |
| FrameStack(4) | ~10 μs | Memory copy |
| ResizeObservation(84,84) | ~50 μs | OpenCV resize |
| RecordVideo | ~1 ms | Rendering + encoding |
| NormalizeReward | ~5 μs | Running stats update |
For environments where step takes milliseconds (physics sims), wrapper overhead is negligible. For fast environments (simple grids), it can become significant with deep wrapper chains.
The one thing to remember: Wrappers are the glue between environments and algorithms — get the space transformations right, save stateful wrapper state alongside your model, and always verify with observation_space.contains(obs).
See Also
- Python Monte Carlo Tree Search The clever trick behind AlphaGo — how a program explores millions of possible moves by playing quick random games against itself
- Python Multi Agent Reinforcement What happens when multiple programs learn together in the same world — cooperation, competition, and emergent teamwork
- Python Openai Gym Environments Why OpenAI Gym is the playground where robots and programs learn by trial and error — no prior coding knowledge needed
- Python Policy Gradient Methods Instead of scoring every move, what if the program just learned which moves feel right? That is policy gradients
- Python Q Learning Implementation How a program builds a cheat sheet of every situation and every action to figure out the best move — no teacher required