TensorFlow Custom Layers — Deep Dive

The Full Layer API

The tf.keras.layers.Layer base class provides a rich API beyond the basic build/call/get_config trio. Mastering these additional methods and properties lets you build layers that behave correctly in every context — eager mode, graph mode, mixed precision, distributed training, and serialization.

Weight Management in Detail

add_weight accepts several parameters that affect training and serialization:

import tensorflow as tf

class AttentionGate(tf.keras.layers.Layer):
    def __init__(self, units, use_bias=True, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.use_bias = use_bias

    def build(self, input_shape):
        self.W_gate = self.add_weight(
            name="gate_kernel",
            shape=(input_shape[-1], self.units),
            initializer="glorot_uniform",
            regularizer=tf.keras.regularizers.l2(1e-4),
            constraint=tf.keras.constraints.MaxNorm(3.0),
            trainable=True
        )
        if self.use_bias:
            self.b_gate = self.add_weight(
                name="gate_bias",
                shape=(self.units,),
                initializer="zeros",
                trainable=True
            )
        # Non-trainable state: exponential moving average of gate activations
        self.ema_activation = self.add_weight(
            name="ema_activation",
            shape=(self.units,),
            initializer="zeros",
            trainable=False
        )

    def call(self, inputs, training=None):
        gate = tf.matmul(inputs, self.W_gate)
        if self.use_bias:
            gate = gate + self.b_gate
        gate = tf.nn.sigmoid(gate)

        if training:
            batch_mean = tf.reduce_mean(gate, axis=0)
            self.ema_activation.assign(
                0.99 * self.ema_activation + 0.01 * batch_mean
            )

        # Assumes input_shape[-1] >= units: the gate scales the first `units` features
        return inputs[:, :self.units] * gate

    def get_config(self):
        config = super().get_config()
        config.update({
            "units": self.units,
            "use_bias": self.use_bias
        })
        return config

Key parameters for add_weight:

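Parameter     Purpose
initializer   How weights start (glorot_uniform, he_normal, orthogonal, etc.)
regularizer   Adds a penalty term to layer.losses, which Keras adds to the training loss (L1, L2, or custom)
constraint    Projects weights after each optimizer update (MaxNorm rescales, NonNeg zeroes negatives, etc.)
trainable     Whether the variable is listed in trainable_weights and updated by the optimizer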
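The effect of each add_weight argument can be checked directly. The sketch below assumes TF 2.x and uses a hypothetical ConstrainedDense layer (not part of Keras) that mirrors AttentionGate's regularizer and constraint and adds a non-trainable counter. It relies on the built-in Keras optimizers applying variable constraints after each apply_gradients call.

```python
import numpy as np
import tensorflow as tf


class ConstrainedDense(tf.keras.layers.Layer):
    """Hypothetical layer that exercises add_weight's main parameters."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        self.kernel = self.add_weight(
            name="kernel",
            shape=(input_shape[-1], self.units),
            initializer="glorot_uniform",
            regularizer=tf.keras.regularizers.l2(1e-4),    # penalty collected in layer.losses
            constraint=tf.keras.constraints.MaxNorm(3.0),  # applied after optimizer updates
            trainable=True,
        )
        self.seen_batches = self.add_weight(
            name="seen_batches",
            shape=(),
            initializer="zeros",
            trainable=False,  # state only: the optimizer never updates it
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.kernel)


layer = ConstrainedDense(4)
x = tf.random.normal((8, 16))
layer(x)  # the first call runs build()

num_penalties = len(layer.losses)
num_trainable = len(layer.trainable_weights)
num_non_trainable = len(layer.non_trainable_weights)

# Push every kernel column far past the max norm, then take one optimizer step
layer.kernel.assign(np.full((16, 4), 10.0, dtype="float32"))
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(layer(x)) + tf.add_n(layer.losses)
grads = tape.gradient(loss, layer.trainable_weights)
optimizer.apply_gradients(zip(grads, layer.trainable_weights))

# MaxNorm (axis=0 by default) rescales each kernel column to norm <= 3.0
column_norms = np.linalg.norm(layer.kernel.numpy(), axis=0)
print(num_penalties, num_trainable, num_non_trainable)
print(column_norms.max())
```

The regularizer shows up as one entry in layer.losses, the counter stays out of trainable_weights, and the constraint only takes effect once the optimizer has applied an update.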

Output Shape Inference with compute_output_shape

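Override compute_output_shape(input_shape) to tell Keras what shape your layer returns for a given input shape without running call(). For AttentionGate, the batch dimension passes through and the feature dimension becomes units:

def compute_output_shape(self, input_shape):
    # Method of AttentionGate: (batch, features) -> (batch, units)
    return (input_shape[0], self.units)

In TF 2.x, the Functional API usually infers output shapes by tracing call() on symbolic inputs, so for most layers this override is optional rather than critical. It becomes necessary when call() cannot be traced symbolically (for example, a tf.keras 2.x layer constructed with dynamic=True), and it keeps shape inference working for code that calls layer.compute_output_shape() directly.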
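A minimal sketch, assuming TF 2.x: the hypothetical SliceGate layer below repeats AttentionGate's slicing logic and overrides compute_output_shape, so the output shape can be queried before the layer is built, and a Functional model reports the same static shape.

```python
import tensorflow as tf


class SliceGate(tf.keras.layers.Layer):
    """Hypothetical gate layer: keeps the batch axis, outputs `units` features."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        self.kernel = self.add_weight(
            name="kernel",
            shape=(input_shape[-1], self.units),
            initializer="glorot_uniform",
        )

    def call(self, inputs):
        gate = tf.sigmoid(tf.matmul(inputs, self.kernel))
        return inputs[:, : self.units] * gate

    def compute_output_shape(self, input_shape):
        # Batch dimension passes through; the feature dimension becomes `units`
        return (input_shape[0], self.units)


layer = SliceGate(32)
# Shape inference without building the layer or running call()
inferred = layer.compute_output_shape((None, 64))
print(inferred)

# In a Functional model the symbolic output carries the same static shape
inputs = tf.keras.Input(shape=(64,))
outputs = SliceGate(32)(inputs)
print(outputs.shape)
```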

Graph-Mode Compatibility

Custom layers must work in both eager mode (debugging) and graph mode (production). The main rules:

Avoid Python State Mutations in call()

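# BAD: the Python increment runs only while tf.function traces call()
def call(self, inputs):
    self.call_count += 1  # Python int, invisible to the traced graph
    return inputs * self.scale

# GOOD: call_count is a tf.Variable created in build(), e.g. with
# add_weight(name="call_count", shape=(), dtype="int64", initializer="zeros", trainable=False)
def call(self, inputs):
    self.call_count.assign_add(1)  # stateful op, runs on every execution
    return inputs * self.scale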
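The difference shows up as soon as call() runs inside tf.function. This sketch uses a hypothetical CountingScale layer that increments both a Python int and an int64 variable; after one eager call and five graph calls, only the variable has counted every execution.

```python
import tensorflow as tf


class CountingScale(tf.keras.layers.Layer):
    """Hypothetical layer that counts its calls in two different ways."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.python_calls = 0  # plain Python int

    def build(self, input_shape):
        self.scale = self.add_weight(name="scale", shape=(), initializer="ones")
        # Integer, non-trainable counter that lives in the graph
        self.call_count = self.add_weight(
            name="call_count",
            shape=(),
            dtype="int64",
            initializer="zeros",
            trainable=False,
        )

    def call(self, inputs):
        self.python_calls += 1         # runs eagerly, or once per trace
        self.call_count.assign_add(1)  # runs on every execution
        return inputs * self.scale


layer = CountingScale()
x = tf.ones((2, 3))
layer(x)  # build eagerly so no variables are created inside tf.function

forward = tf.function(lambda t: layer(t))
for _ in range(5):
    forward(x)

print("python counter:", layer.python_calls)
print("variable counter:", int(layer.call_count.numpy()))
```

The Python counter only reflects the eager call plus tracing, while the variable reaches 6 (one eager call and five graph executions).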

Use TensorFlow Ops for Control Flow

# FINE: Keras normally passes `training` as a Python bool, so this
# Python `if` is resolved while the function is traced
def call(self, inputs, training=None):
    if training:
        inputs = tf.nn.dropout(inputs, rate=0.5)
    return inputs

# RISKY: Python `if` on a tensor value. Inside tf.function this works only
# because AutoGraph rewrites it into tf.cond; with AutoGraph disabled it
# raises an error about using a tf.Tensor as a Python bool
def call(self, inputs):
    if tf.reduce_mean(inputs) > 0:
        return inputs * self.pos_scale
    return inputs * self.neg_scale

# GOOD: explicit tf.cond behaves the same in eager and graph mode
def call(self, inputs):
    return tf.cond(
        tf.reduce_mean(inputs) > 0,
        lambda: inputs * self.pos_scale,
        lambda: inputs * self.neg_scale
    )
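A small check that tf.cond-based branching behaves identically in eager mode and in a graph. The hypothetical SignDependentScale layer initializes pos_scale to 2.0 and neg_scale to -1.0 (arbitrary values for illustration), and the graph version is traced with autograph=False, so it cannot rely on AutoGraph rewriting.

```python
import numpy as np
import tensorflow as tf


class SignDependentScale(tf.keras.layers.Layer):
    """Hypothetical layer: picks a scale based on the sign of the input mean."""

    def build(self, input_shape):
        self.pos_scale = self.add_weight(
            name="pos_scale", shape=(),
            initializer=tf.keras.initializers.Constant(2.0))
        self.neg_scale = self.add_weight(
            name="neg_scale", shape=(),
            initializer=tf.keras.initializers.Constant(-1.0))

    def call(self, inputs):
        # tf.cond builds both branches as TensorFlow ops, so no tensor
        # ever has to be converted to a Python bool
        return tf.cond(
            tf.reduce_mean(inputs) > 0,
            lambda: inputs * self.pos_scale,
            lambda: inputs * self.neg_scale,
        )


layer = SignDependentScale()
pos = tf.ones((2, 3))
neg = -tf.ones((2, 3))

eager_pos = layer(pos).numpy()
eager_neg = layer(neg).numpy()

# tf.cond does not depend on AutoGraph, so it also works with autograph=False
graph_fn = tf.function(lambda t: layer(t), autograph=False)
graph_pos = graph_fn(pos).numpy()
graph_neg = graph_fn(neg).numpy()

print(eager_pos[0, 0], eager_neg[0, 0])  # 2.0 1.0
```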

Masking Support

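If your layer should respect sequence masks (common in NLP), set supports_masking = True, accept a mask argument in call(), and override compute_mask when the layer needs to control which mask is passed downstream: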

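class MaskedDense(tf.keras.layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.supports_masking = True

    def build(self, input_shape):
        self.kernel = self.add_weight(
            name="kernel", shape=(input_shape[-1], self.units),
            initializer="glorot_uniform")
        self.bias = self.add_weight(
            name="bias", shape=(self.units,), initializer="zeros")

    def compute_mask(self, inputs, mask=None):
        return mask  # Pass mask through unchanged

    def call(self, inputs, mask=None):
        # inputs: (batch, time, features); tensordot contracts the feature axis
        output = tf.tensordot(inputs, self.kernel, axes=1) + self.bias
        if mask is not None:
            # mask: (batch, time) -> (batch, time, 1) to zero out padded steps
            output = output * tf.cast(mask[:, :, tf.newaxis], dtype=output.dtype)
        return output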
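To see a mask flowing through such a layer, the sketch below pairs a built-in Embedding(mask_zero=True), which derives the mask from zero token IDs, with a hypothetical MaskedScale layer that follows the same pattern as MaskedDense. Padded time steps come out as zeros, and compute_mask hands the same mask to whatever layer comes next.

```python
import numpy as np
import tensorflow as tf


class MaskedScale(tf.keras.layers.Layer):
    """Hypothetical mask-aware layer: scales features and zeroes padded steps."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.supports_masking = True

    def build(self, input_shape):
        self.scale = self.add_weight(
            name="scale", shape=(input_shape[-1],), initializer="ones")

    def call(self, inputs, mask=None):
        output = inputs * self.scale
        if mask is not None:
            # (batch, time) bool mask -> (batch, time, 1) to broadcast over features
            output = output * tf.cast(mask[:, :, tf.newaxis], output.dtype)
        return output

    def compute_mask(self, inputs, mask=None):
        return mask


embed = tf.keras.layers.Embedding(input_dim=10, output_dim=4, mask_zero=True)
layer = MaskedScale()

tokens = tf.constant([[3, 5, 0, 0], [7, 0, 0, 0]])  # 0 marks padding
embedded = embed(tokens)
mask = embed.compute_mask(tokens)  # boolean, shape (2, 4)
output = layer(embedded, mask=mask)
out_mask = layer.compute_mask(embedded, mask)

print(mask.numpy())
print(np.abs(output.numpy()[0, 2:]).sum())  # padded steps contribute nothing
```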

Mixed Precision Considerations

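When a mixed precision policy such as "mixed_float16" is active (set with tf.keras.mixed_precision.set_global_policy or passed as a layer's dtype), computations run in float16 while variables stay in float32. Custom layers need to handle this:

class PrecisionAwareLayer(tf.keras.layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        self.kernel = self.add_weight(
            name="kernel",
            shape=(input_shape[-1], self.units),
            initializer="glorot_uniform",
            # No dtype argument: the policy's variable_dtype (float32 under
            # mixed_float16) is used, which keeps weight updates precise
        )

    def call(self, inputs):
        # Keras already autocasts floating-point inputs to compute_dtype;
        # the explicit casts make the intent visible and are harmless
        inputs = tf.cast(inputs, self.compute_dtype)
        kernel = tf.cast(self.kernel, self.compute_dtype)
        # Stay in compute_dtype; only the model's final layer should return float32
        return tf.matmul(inputs, kernel)

The compute_dtype and variable_dtype properties come from the layer's dtype policy (float16 and float32 under "mixed_float16"; layer.dtype is the variable dtype). Use them instead of hard-coding float32, and give the model's last layer dtype="float32" so that the loss is computed in full precision.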
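A quick way to confirm the dtypes is to give a single layer the "mixed_float16" policy through its dtype argument, which leaves the global policy untouched. PrecisionAwareDense is a hypothetical copy of the pattern above, and the final Dense with dtype="float32" stands in for a model's output layer.

```python
import tensorflow as tf


class PrecisionAwareDense(tf.keras.layers.Layer):
    """Hypothetical dense layer that follows its dtype policy."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # No dtype argument: the variable uses the policy's variable dtype
        self.kernel = self.add_weight(
            name="kernel",
            shape=(input_shape[-1], self.units),
            initializer="glorot_uniform",
        )

    def call(self, inputs):
        kernel = tf.cast(self.kernel, self.compute_dtype)
        return tf.matmul(tf.cast(inputs, self.compute_dtype), kernel)


# Per-layer policy, so the global policy is not modified
layer = PrecisionAwareDense(4, dtype="mixed_float16")
out = layer(tf.random.normal((2, 8)))  # float32 input is cast to float16

print(layer.compute_dtype, layer.variable_dtype)
print(out.dtype)  # the layer's output stays in the compute dtype

# Output-layer pattern: a float32 layer casts float16 activations back up
head = tf.keras.layers.Dense(1, dtype="float32")
head_out = head(out)
print(head_out.dtype)
```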

Building Composite Layers (Blocks)

Production architectures are built from reusable blocks:

class TransformerBlock(tf.keras.layers.Layer):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout_rate

    def build(self, input_shape):
        self.attention = tf.keras.layers.MultiHeadAttention(
            num_heads=self.num_heads, key_dim=self.embed_dim
        )
        self.ffn = tf.keras.Sequential([
            tf.keras.layers.Dense(self.ff_dim, activation="gelu"),
            tf.keras.layers.Dense(self.embed_dim),
        ])
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(self.dropout_rate)
        self.dropout2 = tf.keras.layers.Dropout(self.dropout_rate)

    def call(self, inputs, training=None):
        attn_output = self.attention(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        x = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(x)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(x + ffn_output)

    def get_config(self):
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "ff_dim": self.ff_dim,
            "dropout_rate": self.dropout_rate,
        })
        return config

Sub-layers created in build() (or __init__) are automatically tracked — their weights appear in layer.trainable_variables and are included in checkpoints.
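Tracking can be confirmed by counting variables. In this sketch the hypothetical ResidualMLPBlock creates two Dense layers and a LayerNormalization in build(); after one call, the block exposes all six of their variables (two kernels, two biases, gamma and beta).

```python
import tensorflow as tf


class ResidualMLPBlock(tf.keras.layers.Layer):
    """Hypothetical composite layer: Dense -> Dense with a residual connection."""

    def __init__(self, hidden_dim, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim

    def build(self, input_shape):
        self.fc1 = tf.keras.layers.Dense(self.hidden_dim, activation="relu")
        self.fc2 = tf.keras.layers.Dense(input_shape[-1])
        self.norm = tf.keras.layers.LayerNormalization()

    def call(self, inputs):
        return self.norm(inputs + self.fc2(self.fc1(inputs)))


block = ResidualMLPBlock(hidden_dim=16)
x = tf.random.normal((4, 8))
y = block(x)  # builds the block, then each sub-layer on first use

tracked_ids = {id(v) for v in block.trainable_variables}
fc1_ids = {id(v) for v in block.fc1.trainable_variables}

print(len(block.trainable_variables))
```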

Testing Custom Layers

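A production-grade test suite for a custom layer should verify at least the output shape, training-mode behavior, config round-trips, and saving and loading inside a model: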

import numpy as np
import tensorflow as tf

def test_layer_output_shape():
    layer = AttentionGate(units=32)
    x = tf.random.normal((4, 64))
    output = layer(x)
    assert output.shape == (4, 32)

def test_layer_training_flag():
    layer = AttentionGate(units=32)
    x = tf.random.normal((4, 64))
    layer(x, training=False)  # builds the layer; the EMA stays at zero
    ema_before = layer.ema_activation.numpy().copy()
    layer(x, training=False)
    np.testing.assert_allclose(layer.ema_activation.numpy(), ema_before)
    # A training call moves the EMA toward the (positive) sigmoid gate means
    layer(x, training=True)
    assert not np.allclose(layer.ema_activation.numpy(), ema_before)

def test_layer_serialization():
    layer = AttentionGate(units=32)
    layer.build(input_shape=(None, 64))
    config = layer.get_config()
    restored = AttentionGate.from_config(config)
    assert restored.units == 32

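def test_layer_in_model_save_load(tmp_path):
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(64,)),
        AttentionGate(units=32)
    ])
    model.compile(optimizer="adam", loss="mse")
    # Native .keras format (required by Keras 3); custom_objects tells
    # load_model which class to rebuild from the saved config
    path = str(tmp_path / "test_model.keras")
    model.save(path)
    loaded = tf.keras.models.load_model(
        path, custom_objects={"AttentionGate": AttentionGate}
    )
    # model.layers does not list the Input, so the gate is the last entry
    assert loaded.layers[-1].units == 32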

Run these tests under both eager mode and tf.function to catch tracing issues.
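One way to run the same check in both modes is a small helper that compares eager and tf.function outputs and reports how often the function was traced. ScaledTanh and check_eager_matches_graph are hypothetical names; experimental_get_tracing_count() is the tf.function method that reports the trace count, and a value above 1 for same-shaped inputs often points to Python values leaking into the trace signature.

```python
import numpy as np
import tensorflow as tf


class ScaledTanh(tf.keras.layers.Layer):
    """Hypothetical layer used only to illustrate the eager-vs-graph check."""

    def build(self, input_shape):
        self.scale = self.add_weight(
            name="scale", shape=(input_shape[-1],), initializer="ones")

    def call(self, inputs):
        return tf.tanh(inputs) * self.scale


def check_eager_matches_graph(layer, x):
    """Run `layer` eagerly and inside tf.function; fail if the outputs differ."""
    eager_out = layer(x)  # builds the layer outside tf.function
    graph_fn = tf.function(lambda t: layer(t))
    graph_out = graph_fn(x)
    np.testing.assert_allclose(
        eager_out.numpy(), graph_out.numpy(), rtol=1e-6, atol=1e-6)
    # A new tensor with the same shape and dtype should reuse the first trace
    graph_fn(tf.random.normal(x.shape))
    return graph_fn.experimental_get_tracing_count()


def test_scaled_tanh_eager_vs_graph():
    traces = check_eager_matches_graph(ScaledTanh(), tf.random.normal((4, 8)))
    assert traces == 1
```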

Performance Profiling

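Custom layers can introduce bottlenecks. A quick first check is to time repeated forward passes:

import timeit

import tensorflow as tf

# Time a single forward pass (AttentionGate from above: 256 input features >= 128 units)
x = tf.random.normal((32, 256))
layer = AttentionGate(units=128)

# Warm up: the first call builds the weights
_ = layer(x)

# .numpy() waits for the result, so asynchronous GPU work is included in the timing
elapsed = timeit.timeit(lambda: layer(x).numpy(), number=1000)
print(f"{elapsed:.3f}s for 1000 calls")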
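Eager timing includes per-op Python dispatch, while model.fit runs the layer inside a tf.function by default, so it is worth timing both. The sketch below uses a hypothetical TwoLayerMLP as the layer under test; absolute and relative numbers depend on hardware and layer size, so treat them as a rough signal rather than a benchmark.

```python
import timeit

import tensorflow as tf


class TwoLayerMLP(tf.keras.layers.Layer):
    """Hypothetical stand-in for the custom layer being timed."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.fc1 = tf.keras.layers.Dense(units, activation="relu")
        self.fc2 = tf.keras.layers.Dense(units)

    def call(self, inputs):
        return self.fc2(self.fc1(inputs))


def time_calls(fn, x, number=200):
    fn(x).numpy()  # warm-up: builds weights (eager) or traces the graph
    return timeit.timeit(lambda: fn(x).numpy(), number=number)


x = tf.random.normal((32, 256))
layer = TwoLayerMLP(128)
eager_s = time_calls(layer, x)
graph_s = time_calls(tf.function(lambda t: layer(t)), x)
print(f"eager: {eager_s:.3f}s  tf.function: {graph_s:.3f}s for 200 calls")
```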

Common performance pitfalls in custom layers:

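  • Python-level loops and arithmetic in call(): express the computation as TensorFlow ops so it runs inside the graph
  • Many small tensor operations that cannot use GPU parallelism: batch or vectorize them
  • Unnecessary tf.py_function calls: they run Python inside the graph, hold the GIL, cannot be compiled with XLA, and are not serialized into a SavedModel
  • Recreating the same constant tensors on every call() (costly in eager mode): create them once in build() or __init__() instead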
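The first two pitfalls in a minimal form, using a hypothetical row-norm computation: the loop version issues one small tf.norm op per row (and inside tf.function unrolls into one op per row, because range() over a static Python int is not converted into a graph loop), while the vectorized version is a single batched op with the same result.

```python
import numpy as np
import tensorflow as tf


def row_norms_loop(x):
    # Anti-pattern: one tiny op per row, driven by a Python loop
    return tf.stack([tf.norm(x[i]) for i in range(x.shape[0])])


def row_norms_vectorized(x):
    # One batched op over the whole tensor
    return tf.norm(x, axis=1)


x = tf.random.normal((64, 128))
loop_out = row_norms_loop(x)
vec_out = row_norms_vectorized(x)
print(np.max(np.abs(loop_out.numpy() - vec_out.numpy())))
```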

Real-World Custom Layer Patterns

Squeeze-and-Excitation blocks (used in EfficientNet) learn channel-wise attention:

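class SEBlock(tf.keras.layers.Layer):
    def __init__(self, reduction=16, **kwargs):
        super().__init__(**kwargs)
        self.reduction = reduction

    def build(self, input_shape):
        channels = input_shape[-1]
        self.squeeze = tf.keras.layers.GlobalAveragePooling2D()
        # max(1, ...) keeps the bottleneck non-empty when channels < reduction
        self.fc1 = tf.keras.layers.Dense(max(1, channels // self.reduction), activation="relu")
        self.fc2 = tf.keras.layers.Dense(channels, activation="sigmoid")

    def call(self, inputs):
        se = self.squeeze(inputs)  # (batch, channels)
        se = self.fc1(se)          # (batch, channels // reduction)
        se = self.fc2(se)          # (batch, channels), gates in (0, 1)
        # Broadcast the per-channel gates over height and width
        return inputs * se[:, tf.newaxis, tf.newaxis, :]

    def get_config(self):
        config = super().get_config()
        config.update({"reduction": self.reduction})
        return config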

This pattern (global pooling → bottleneck → sigmoid gating) was introduced by SENet, which won the ILSVRC 2017 image classification task, and it is used in efficient architectures such as MobileNetV3 and EfficientNet.
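Stripped of the Keras layers, the three SE stages reduce to a spatial mean, two small matrix products, and a broadcast multiply. The sketch below writes them with raw TensorFlow ops; squeeze_excite is a hypothetical helper, the weights are random placeholders rather than trained parameters, and the reduction ratio of 16 is taken from SEBlock's default.

```python
import numpy as np
import tensorflow as tf


def squeeze_excite(feature_map, w1, b1, w2, b2):
    """SE gating with raw TF ops, mirroring SEBlock's three stages."""
    # Squeeze: global average pool over height and width -> (batch, 1, 1, C)
    s = tf.reduce_mean(feature_map, axis=[1, 2], keepdims=True)
    # Excite: bottleneck to C // r channels with ReLU, back to C with sigmoid
    z = tf.nn.relu(tf.tensordot(s, w1, axes=1) + b1)
    gates = tf.sigmoid(tf.tensordot(z, w2, axes=1) + b2)
    # Scale: broadcast the per-channel gates over every spatial position
    return feature_map * gates, gates


channels, reduction = 32, 16
hidden = max(1, channels // reduction)
rng = np.random.default_rng(0)
w1 = tf.constant(rng.normal(size=(channels, hidden)).astype("float32"))
b1 = tf.zeros((hidden,))
w2 = tf.constant(rng.normal(size=(hidden, channels)).astype("float32"))
b2 = tf.zeros((channels,))

x = tf.random.normal((2, 8, 8, channels))
y, gates = squeeze_excite(x, w1, b1, w2, b2)
print(y.shape, gates.shape)
```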

The one thing to remember: Custom layers must implement build/call/get_config correctly for graph mode, mixed precision, and serialization — but once they do, they are first-class citizens in every Keras workflow.


See Also