Environment Wrappers — Deep Dive

Writing custom wrappers

ObservationWrapper example: relative coordinates

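Convert absolute positions to positions relative to the agent, simplifying the learning problem. The example assumes a flat observation of (x, y) pairs, [x0, y0, x1, y1, ...], with the agent's pair at entity index agent_idx: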

import gymnasium as gym
from gymnasium import spaces
import numpy as np

class RelativeObservation(gym.ObservationWrapper):
    """Subtract the agent's (x, y) position from every (x, y) pair."""

    def __init__(self, env: gym.Env, agent_idx: int = 0):
        super().__init__(env)
        self.agent_idx = agent_idx  # entity (pair) index, not element offset
        # Assumes an even-length flat Box: one row per (x, y) pair
        low = env.observation_space.low.reshape(-1, 2)
        high = env.observation_space.high.reshape(-1, 2)
        # Tightest bounds on (entity - agent): pair each entity's bound with
        # the opposite bound of the agent's coordinate
        rel_low = low - high[agent_idx]
        rel_high = high - low[agent_idx]
        self.observation_space = spaces.Box(
            low=rel_low.reshape(-1).astype(np.float32),
            high=rel_high.reshape(-1).astype(np.float32),
            dtype=np.float32,
        )

    def observation(self, obs: np.ndarray) -> np.ndarray:
        # Cast to float32 so the output matches the declared space dtype
        pairs = np.asarray(obs, dtype=np.float32).reshape(-1, 2)
        # Broadcasting subtracts the agent's pair from every row at once
        relative = pairs - pairs[self.agent_idx]
        return relative.reshape(-1)
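Subtracting each coordinate's own bounds (low_i - high_i) is only correct when every entity shares the same bounds. The smallest possible x_i - x_agent is low_i - high_agent, and the largest is high_i - low_agent. The sketch below computes both the relative pairs and these tight bounds in plain Python. The helper names relative_pairs and relative_bounds are illustrative and not part of Gymnasium.

```python
from typing import List, Sequence, Tuple


def relative_pairs(obs: Sequence[float], agent_idx: int = 0) -> List[float]:
    """Flat [x0, y0, x1, y1, ...] -> same layout, relative to entity agent_idx."""
    ax, ay = obs[2 * agent_idx], obs[2 * agent_idx + 1]
    out: List[float] = []
    for i in range(0, len(obs), 2):
        out.extend((obs[i] - ax, obs[i + 1] - ay))
    return out


def relative_bounds(
    low: Sequence[float], high: Sequence[float], agent_idx: int = 0
) -> Tuple[List[float], List[float]]:
    """Tightest per-coordinate bounds on (entity - agent).

    Element i is an x coordinate when i is even and a y coordinate when odd,
    so it is paired with the agent's x or y bounds via i % 2.
    """
    a = 2 * agent_idx
    rel_low = [low[i] - high[a + i % 2] for i in range(len(low))]
    rel_high = [high[i] - low[a + i % 2] for i in range(len(high))]
    return rel_low, rel_high
```

Take an agent bounded by [0, 10] and a second entity bounded by [-5, 5]. The entity's relative x can reach -5 - 10 = -15. The per-element formula would cap it at -5 - 5 = -10, so valid observations would fall outside the declared space.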

ActionWrapper example: discretise continuous actions

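class DiscretiseAction(gym.ActionWrapper):
    """Map discrete indices to pre-defined continuous actions.

    For a multi-dimensional Box, all dimensions move together, so the n_bins
    actions lie on the diagonal from low to high. This suits 1-D action
    spaces; for multi-dimensional control, discretise each dimension
    separately (e.g. with a MultiDiscrete space).
    """

    def __init__(self, env: gym.Env, n_bins: int = 5):
        super().__init__(env)
        assert isinstance(env.action_space, spaces.Box), "expects a Box action space"
        low = env.action_space.low
        high = env.action_space.high
        # Shape (n_bins, *action_shape): row k is the k-th evenly spaced action
        self.actions = np.linspace(low, high, n_bins).astype(env.action_space.dtype)
        self.action_space = spaces.Discrete(n_bins)

    def action(self, act: int) -> np.ndarray:
        return self.actions[int(act)]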
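For a Box with more than one dimension, np.linspace(low, high, n_bins) in DiscretiseAction yields n_bins joint actions in which every dimension moves together along the diagonal. Combinations such as (high, low) are therefore unreachable. One alternative is to discretise each dimension separately and enumerate every combination. Below is a plain-Python sketch; the helper names bin_values and action_grid are illustrative.

```python
from itertools import product
from typing import List, Sequence


def bin_values(low: float, high: float, n_bins: int) -> List[float]:
    """n_bins evenly spaced values from low to high inclusive (like np.linspace)."""
    if n_bins < 1:
        raise ValueError("n_bins must be at least 1")
    if n_bins == 1:
        return [low]
    step = (high - low) / (n_bins - 1)
    return [low + k * step for k in range(n_bins)]


def action_grid(low: Sequence[float], high: Sequence[float], n_bins: int) -> List[List[float]]:
    """Every combination of per-dimension bins: n_bins ** len(low) joint actions."""
    per_dim = [bin_values(lo, hi, n_bins) for lo, hi in zip(low, high)]
    return [list(combo) for combo in product(*per_dim)]
```

The grid grows as n_bins ** n_dims. Beyond a couple of dimensions, a MultiDiscrete action space with one index per dimension is usually the more practical choice.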

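Reward wrapper example: potential-based shaping

RewardWrapper.reward() receives only the reward, but potential-based shaping also needs the next observation. Overriding step() in a plain gym.Wrapper gives direct access to that observation. The alternative, calling a private _get_obs() on the base env, is not part of the Env API and bypasses any inner wrappers.

The shaping term is F = gamma * phi(s') - phi(s), with phi(s) = -||position - goal||. Terminal states get potential 0 so that shaping does not change which policy is optimal. The example assumes obs[:2] holds the agent's (x, y) position:

class DistanceShaping(gym.Wrapper):
    """Add potential-based reward shaping using distance to goal."""

    def __init__(self, env: gym.Env, goal: np.ndarray, gamma: float = 0.99):
        super().__init__(env)
        self.goal = np.asarray(goal, dtype=np.float64)
        self.gamma = gamma  # should match the agent's discount factor
        self._prev_potential = 0.0

    def _potential(self, obs: np.ndarray) -> float:
        return -float(np.linalg.norm(np.asarray(obs[:2]) - self.goal))

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        self._prev_potential = self._potential(obs)
        return obs, info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        # Terminal states get potential 0; truncated states keep their real potential
        current_potential = 0.0 if terminated else self._potential(obs)
        shaping = self.gamma * current_potential - self._prev_potential
        self._prev_potential = current_potential
        return obs, reward + shaping, terminated, truncated, info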
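Potential-based shaping leaves the optimal policy unchanged because the discounted shaping terms telescope. With F_t = gamma * phi(s_{t+1}) - phi(s_t), the episode sum of gamma^t * F_t equals gamma^T * phi(s_T) - phi(s_0). With phi(terminal) = 0 it reduces to -phi(s_0), which no action can influence. The sketch below checks that identity numerically in plain Python, using the same negative-distance potential. The helper names potential and shaping_terms are illustrative.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float]


def potential(pos: Point, goal: Point) -> float:
    """phi(s) = -distance from the agent's position to the goal."""
    return -math.dist(pos, goal)


def shaping_terms(path: Sequence[Point], goal: Point, gamma: float, terminal: bool) -> List[float]:
    """F_t = gamma * phi(s_{t+1}) - phi(s_t) for each transition along path.

    If the final state is terminal, its potential is taken as 0.
    """
    phis = [potential(p, goal) for p in path]
    if terminal:
        phis[-1] = 0.0
    return [gamma * phis[t + 1] - phis[t] for t in range(len(path) - 1)]


def discounted_sum(terms: Sequence[float], gamma: float) -> float:
    return sum(gamma ** t * f for t, f in enumerate(terms))
```

However the intermediate states are chosen, discounted_sum(shaping_terms(path, goal, gamma, terminal=True), gamma) comes out as -potential(path[0], goal).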

Space transformation correctness

A frequent wrapper bug is forgetting to update the observation or action space. RL libraries read these spaces to size network inputs and outputs, so a stale space causes shape errors or a silently mis-sized model:

class BadWrapper(gym.ObservationWrapper):
    def observation(self, obs):
        return obs[:4]  # Truncate to first 4 elements
    # BUG: observation_space still reports original shape!

class GoodWrapper(gym.ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)
        self.observation_space = spaces.Box(
            low=env.observation_space.low[:4],
            high=env.observation_space.high[:4],
            dtype=env.observation_space.dtype,
        )

    def observation(self, obs):
        return obs[:4]

Always verify with:

obs, _ = wrapped_env.reset()
# contains() checks shape, dtype and bounds, so report all three on failure
assert wrapped_env.observation_space.contains(obs), (
    f"obs (shape {obs.shape}, dtype {obs.dtype}) not in {wrapped_env.observation_space}"
)
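Checking only the observation returned by reset() can miss wrappers whose outputs drift out of the declared space later in an episode. The sketch below checks every observation over a rollout of random actions. It relies only on the Gymnasium methods reset, step, action_space.sample, action_space.seed and observation_space.contains. The function name check_spaces_over_rollout is hypothetical.

```python
from typing import Any


def check_spaces_over_rollout(env: Any, n_steps: int = 500, seed: int = 0) -> None:
    """Assert every observation from reset() and step() lies in env.observation_space.

    Assumes the Gymnasium API: reset(seed=...) -> (obs, info) and
    step(action) -> (obs, reward, terminated, truncated, info).
    """
    space = env.observation_space
    env.action_space.seed(seed)  # reproducible random actions
    obs, _ = env.reset(seed=seed)
    assert space.contains(obs), f"reset: obs {obs!r} not in {space}"
    for t in range(n_steps):
        obs, _, terminated, truncated, _ = env.step(env.action_space.sample())
        assert space.contains(obs), f"step {t}: obs {obs!r} not in {space}"
        if terminated or truncated:
            obs, _ = env.reset()
            assert space.contains(obs), f"reset after step {t}: obs {obs!r} not in {space}"
```

Random actions will not reach every state, so a passing rollout is evidence rather than proof that the spaces are right.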

The full Wrapper base class

For transformations that affect multiple methods, subclass gym.Wrapper directly:

class CurriculumWrapper(gym.Wrapper):
    """Gradually increase environment difficulty.

    Assumes the base env reads options["difficulty"] in reset() and that an
    episode is a success when it terminates with a positive final reward.
    """

    def __init__(self, env: gym.Env, max_difficulty: int = 10,
                 window: int = 100, threshold: float = 0.8):
        super().__init__(env)
        self.difficulty = 1
        self.max_difficulty = max_difficulty
        self.window = window          # completed episodes per evaluation
        self.threshold = threshold    # success rate needed to level up
        self.episode_count = 0        # completed episodes in the current window
        self.success_count = 0

    def reset(self, **kwargs):
        # Copy options: callers may pass options=None or reuse their own dict
        options = dict(kwargs.pop("options", None) or {})
        options["difficulty"] = self.difficulty
        return self.env.reset(options=options, **kwargs)

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        if terminated or truncated:
            # Bookkeeping runs once per finished episode, not on every step
            self.episode_count += 1
            if terminated and reward > 0:
                self.success_count += 1
            # Level up after at least 80% success over the last 100 episodes
            if self.episode_count >= self.window:
                rate = self.success_count / self.episode_count
                if rate >= self.threshold and self.difficulty < self.max_difficulty:
                    self.difficulty += 1
                    print(f"Difficulty increased to {self.difficulty}")
                self.episode_count = 0
                self.success_count = 0
        return obs, reward, terminated, truncated, info

    def increase_difficulty(self):
        """External call to increase difficulty (e.g., from a callback)."""
        self.difficulty = min(self.difficulty + 1, self.max_difficulty)
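The curriculum wrapper only passes the level along; the base environment has to read it. The sketch below shows the environment side. It assumes the options["difficulty"] convention used by CurriculumWrapper, and it signals success by terminating with a positive reward.

ToyReachEnv is hypothetical and written as a plain class so it runs without Gymnasium. A real environment would subclass gym.Env, declare observation_space and action_space, and call super().reset(seed=seed).

```python
from typing import Any, Dict, Optional, Tuple


class ToyReachEnv:
    """Hypothetical 1-D task: walk from 0 to a goal `difficulty` cells away."""

    def __init__(self, max_steps: int = 50):
        self.max_steps = max_steps
        self.difficulty = 1
        self.pos = 0
        self.t = 0

    def reset(
        self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
    ) -> Tuple[int, Dict[str, Any]]:
        # Deterministic toy task: seed is accepted only for API compatibility.
        # Read the level chosen by the curriculum; keep the previous one if absent.
        if options and "difficulty" in options:
            self.difficulty = int(options["difficulty"])
        self.pos, self.t = 0, 0
        return self.pos, {"difficulty": self.difficulty}

    def step(self, action: int):
        # Action 1 moves right; any other action moves left
        self.pos += 1 if action == 1 else -1
        self.t += 1
        terminated = self.pos == self.difficulty
        truncated = not terminated and self.t >= self.max_steps
        reward = 1.0 if terminated else 0.0  # positive reward marks a success
        return self.pos, reward, terminated, truncated, {}
```

Higher levels put the goal further away within the same step budget, so the task gets harder as the curriculum raises the difficulty.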

Vectorised environment wrappers

When using gymnasium.vector.SyncVectorEnv or AsyncVectorEnv, single-environment wrappers cannot be applied to the vector env itself, because its reset() and step() operate on batches. Either use the batched wrappers in gymnasium.wrappers.vector (Gymnasium 1.x) or apply ordinary wrappers to each sub-environment inside its factory:

import gymnasium as gym

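def make_env(idx: int):
    # idx is unused here; it is commonly used to vary seeds or to record
    # video from a single sub-env
    def _init():
        # gym.make already applies TimeLimit(max_episode_steps=500) for
        # CartPole-v1, so a second TimeLimit is redundant. To change the limit,
        # pass max_episode_steps to gym.make; truncation then stays inside
        # RecordEpisodeStatistics and is counted as an episode end.
        env = gym.make("CartPole-v1")
        env = gym.wrappers.RecordEpisodeStatistics(env)
        return env
    return _init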

# Each sub-env gets its own wrapper stack
vec_env = gym.vector.SyncVectorEnv([make_env(i) for i in range(4)])

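For normalisation across the vector, Stable-Baselines3's VecNormalize keeps a single set of running statistics updated from every sub-environment's data. It wraps SB3's own VecEnv classes (such as DummyVecEnv), not Gymnasium vector envs: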

from stable_baselines3.common.vec_env import VecNormalize, DummyVecEnv

vec_env = DummyVecEnv([make_env(i) for i in range(4)])
vec_env = VecNormalize(vec_env, norm_obs=True, norm_reward=True, clip_obs=10.0)
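With shared statistics, each step's batch of observations (one per sub-environment) is folded into the same running mean and variance. The sketch below shows the parallel update formula (Chan et al.) that running-statistics helpers such as Stable-Baselines3's RunningMeanStd apply per feature, reduced to one scalar feature in plain Python. The class name RunningStats is illustrative. Its prior (mean 0, variance 1, count 1e-4) mirrors those helpers' defaults and avoids division by zero.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass
class RunningStats:
    """Scalar running mean/variance, updated one batch at a time."""

    mean: float = 0.0
    var: float = 1.0
    count: float = 1e-4

    def update(self, batch: Sequence[float]) -> None:
        n = len(batch)
        if n == 0:
            return
        batch_mean = sum(batch) / n
        batch_var = sum((x - batch_mean) ** 2 for x in batch) / n
        # Parallel (Chan et al.) combination of the stored and batch moments
        delta = batch_mean - self.mean
        total = self.count + n
        m2 = self.var * self.count + batch_var * n + delta ** 2 * self.count * n / total
        self.mean += delta * n / total
        self.var = m2 / total
        self.count = total

    def normalise(self, x: float, eps: float = 1e-8) -> float:
        return (x - self.mean) / (self.var + eps) ** 0.5
```

The combination is exact, so feeding four values per step (one per sub-env) gives the same statistics as one large batch, up to floating-point error.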

Debugging wrapper chains

Inspect the chain

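def print_wrapper_chain(env):
    """Print the full wrapper stack, outermost first.

    For a one-line summary, print(env) also shows the nested wrapper names.
    """
    current = env
    depth = 0
    # isinstance is more reliable than hasattr(current, "env"), since a base
    # env may define its own `env` attribute
    while isinstance(current, gym.Wrapper):
        print(f"{'  ' * depth}{type(current).__name__}")
        print(f"{'  ' * depth}  obs_space: {current.observation_space}")
        print(f"{'  ' * depth}  act_space: {current.action_space}")
        current = current.env
        depth += 1
    print(f"{'  ' * depth}{type(current).__name__} (base)")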

print_wrapper_chain(wrapped_env)

Compare wrapped vs unwrapped

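def debug_step(wrapped_env, raw_env, action, seed: int = 0):
    """Step a wrapped env and an identically seeded raw env once; compare outputs.

    Calling wrapped_env.unwrapped.step() would advance the same simulator a
    second time, so the comparison needs a separate raw instance. If the stack
    contains ActionWrappers, pass raw_env the transformed action instead.
    """
    wrapped_env.reset(seed=seed)
    raw_env.reset(seed=seed)

    obs, reward, *_ = wrapped_env.step(action)
    print(f"Wrapped:   obs={np.shape(obs)}, reward={reward:.4f}")

    raw_obs, raw_reward, *_ = raw_env.step(action)
    print(f"Unwrapped: obs={np.shape(raw_obs)}, reward={raw_reward:.4f}")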
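To see what each layer does to a single step, insert a pass-through recording layer between wrappers while building the stack. The same environment instance is then stepped only once.

Below is a minimal sketch. Tap is a hypothetical name, written as a plain class so it runs standalone. In a real Gymnasium stack it must subclass gym.Wrapper, because gym.Wrapper expects to wrap a gymnasium.Env.

```python
from typing import Any, List, Optional, Tuple

LogEntry = Tuple[str, Any, Any, Optional[float]]  # (layer, action, obs, reward)


class Tap:
    """Hypothetical pass-through layer that records what flows through its position."""

    def __init__(self, env: Any, name: str, log: List[LogEntry]):
        self.env = env
        self.name = name
        self.log = log

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        self.log.append((self.name, None, obs, None))
        return obs, info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        # `action` is what arrived from the layers above; obs and reward are
        # what the layers below produced
        self.log.append((self.name, action, obs, reward))
        return obs, reward, terminated, truncated, info
```

Inner taps log before outer ones, so consecutive entries in the shared log show one layer's transformation. For example, Tap(SomeWrapper(Tap(base_env, "raw", log)), "wrapped", log) records the output before and after SomeWrapper.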

The check_env utility

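Gymnasium's env checker is designed for the base environment and warns when given a wrapped one. Running it on both still catches space mismatches introduced by wrappers. It calls reset() and step() on the env it receives, so use a throwaway instance rather than one whose wrapper statistics you want to keep: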

from gymnasium.utils.env_checker import check_env
check_env(wrapped_env.unwrapped)  # Check base
check_env(wrapped_env)            # Check with wrappers

Advanced patterns

Conditional wrappers

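Apply learning-only wrappers, such as reward normalisation, only during training. Observation normalisation must stay in the evaluation stack too, with its statistics frozen, because the policy was trained on normalised inputs:

def make_env(training: bool = True):
    env = gym.make("HalfCheetah-v4")
    # Innermost, so episode returns are recorded before reward normalisation
    env = gym.wrappers.RecordEpisodeStatistics(env)
    env = gym.wrappers.ClipAction(env)
    env = gym.wrappers.NormalizeObservation(env)
    if training:
        env = gym.wrappers.NormalizeReward(env)
    else:
        # Gymnasium 1.x: stop updating the statistics at evaluation time;
        # load the saved training statistics into env.obs_rms as well
        env.update_running_mean = False
    return env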

Wrapper factories

Encapsulate common stacks:

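def atari_preprocessing(env_id: str, frame_stack: int = 4):
    # AtariPreprocessing does its own frame skipping, so disable the
    # emulator's frameskip (ALE v5 envs default to 4) to avoid skipping twice.
    # With Gymnasium 1.x, register ALE envs first: import ale_py; gym.register_envs(ale_py)
    env = gym.make(env_id, frameskip=1, render_mode=None)
    env = gym.wrappers.AtariPreprocessing(
        env, noop_max=30, frame_skip=4,
        screen_size=84, grayscale_obs=True,
        grayscale_newaxis=False,
    )
    # Named FrameStack(env, num_stack) before Gymnasium 1.0
    env = gym.wrappers.FrameStackObservation(env, frame_stack)
    return env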

def continuous_control_stack(env_id: str):
    env = gym.make(env_id)
    # Innermost, so episode returns are recorded before reward normalisation
    env = gym.wrappers.RecordEpisodeStatistics(env)
    env = gym.wrappers.ClipAction(env)
    env = gym.wrappers.NormalizeObservation(env)
    env = gym.wrappers.NormalizeReward(env)
    return env
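Wrapper order decides what RecordEpisodeStatistics reports. Placed inside a reward normaliser, it sees raw rewards; placed outside, it sees normalised ones. The plain-Python sketch below makes the difference visible using hypothetical stand-ins: ScaleReward for NormalizeReward and ReturnRecorder for RecordEpisodeStatistics.

```python
from typing import Any, List


class ThreeStepEnv:
    """Hypothetical episode of three steps, each paying reward 10."""

    def reset(self, **kwargs):
        self.t = 0
        return self.t, {}

    def step(self, action):
        self.t += 1
        return self.t, 10.0, self.t >= 3, False, {}


class ScaleReward:
    """Stand-in for NormalizeReward: multiplies every reward by a constant."""

    def __init__(self, env: Any, scale: float):
        self.env, self.scale = env, scale

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        return obs, reward * self.scale, terminated, truncated, info


class ReturnRecorder:
    """Stand-in for RecordEpisodeStatistics: sums the rewards visible at its layer."""

    def __init__(self, env: Any):
        self.env = env
        self.episode_return = 0.0
        self.returns: List[float] = []

    def reset(self, **kwargs):
        self.episode_return = 0.0
        return self.env.reset(**kwargs)

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.episode_return += reward
        if terminated or truncated:
            self.returns.append(self.episode_return)
        return obs, reward, terminated, truncated, info


def run_episode(env: Any) -> None:
    env.reset()
    done = False
    while not done:
        _, _, terminated, truncated, _ = env.step(0)
        done = terminated or truncated
```

A recorder placed directly on the base env reports the true return of 30. One placed outside the scaler reports about 0.3, which is not comparable across runs or with published results.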

Stateful wrappers and serialisation

Wrappers with running statistics (like NormalizeObservation) carry state that must be saved and restored alongside the model:

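import pickle

# Gymnasium 1.x no longer forwards attribute access through the wrapper
# chain, so look up the RunningMeanStd object explicitly
obs_rms = env.get_wrapper_attr("obs_rms")

# Save
with open("wrapper_state.pkl", "wb") as f:
    pickle.dump({
        "obs_rms_mean": obs_rms.mean,
        "obs_rms_var": obs_rms.var,
        "obs_rms_count": obs_rms.count,
    }, f)

# Restore into an env built with the same wrapper stack
with open("wrapper_state.pkl", "rb") as f:
    state = pickle.load(f)
obs_rms = env.get_wrapper_attr("obs_rms")
obs_rms.mean = state["obs_rms_mean"]
obs_rms.var = state["obs_rms_var"]
obs_rms.count = state["obs_rms_count"]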

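Forgetting this step is a common deployment bug. The evaluation env starts from fresh statistics, so the model receives inputs normalised differently from those it was trained on and behaves erratically. With Stable-Baselines3, VecNormalize.save() and VecNormalize.load() handle this; set training = False and norm_reward = False on the loaded wrapper before evaluating.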

Performance impact

Wrapper                     Overhead per step   Notes
ClipAction                  ~1 μs               Negligible
NormalizeObservation        ~5 μs               Running stats update
FrameStack(4)               ~10 μs              Memory copy
ResizeObservation(84, 84)   ~50 μs              OpenCV resize
RecordVideo                 ~1 ms               Rendering + encoding
NormalizeReward             ~5 μs               Running stats update

For environments where step() takes milliseconds (physics simulators), wrapper overhead is negligible. Fast environments (simple grids) may take only a few microseconds per step, and there the per-wrapper costs above add up: a chain of ten ~5 μs wrappers adds roughly 50 μs per step.
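The figures above are rough and depend on hardware and library versions, so it is worth measuring on your own setup. The sketch below uses timeit with two hypothetical stand-ins, NullEnv and PassThrough, to isolate the cost of the wrapper calls themselves. To profile a real stack, time env.step() on it the same way, resetting whenever an episode ends.

```python
import timeit
from typing import Any


class NullEnv:
    """Hypothetical env whose step() does no work, so timings isolate wrapper cost."""

    def reset(self, **kwargs):
        return 0.0, {}

    def step(self, action):
        return 0.0, 0.0, False, False, {}


class PassThrough:
    """Hypothetical wrapper that only forwards calls: one extra Python call per layer."""

    def __init__(self, env: Any):
        self.env = env

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

    def step(self, action):
        return self.env.step(action)


def per_step_seconds(depth: int, n: int = 100_000) -> float:
    """Average wall-clock seconds per step() through `depth` pass-through layers."""
    env: Any = NullEnv()
    for _ in range(depth):
        env = PassThrough(env)
    env.reset()
    return timeit.timeit(lambda: env.step(0), number=n) / n
```

Comparing per_step_seconds(0) with per_step_seconds(10) shows the pure call overhead of a deep chain. Real wrappers add their own work on top of that.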

The one thing to remember: Wrappers are the glue between environments and algorithms — get the space transformations right, save stateful wrapper state alongside your model, and always verify with observation_space.contains(obs).
