TensorFlow Custom Layers — Deep Dive
The Full Layer API
The tf.keras.layers.Layer base class provides a rich API beyond the basic build/call/get_config trio. Mastering these additional methods and properties lets you build layers that behave correctly in every context — eager mode, graph mode, mixed precision, distributed training, and serialization.
Weight Management in Detail
add_weight accepts several parameters that affect training and serialization:
```python
import tensorflow as tf


class AttentionGate(tf.keras.layers.Layer):
    def __init__(self, units, use_bias=True, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.use_bias = use_bias

    def build(self, input_shape):
        self.W_gate = self.add_weight(
            name="gate_kernel",
            shape=(input_shape[-1], self.units),
            initializer="glorot_uniform",
            regularizer=tf.keras.regularizers.l2(1e-4),
            constraint=tf.keras.constraints.MaxNorm(3.0),
            trainable=True,
        )
        if self.use_bias:
            self.b_gate = self.add_weight(
                name="gate_bias",
                shape=(self.units,),
                initializer="zeros",
                trainable=True,
            )
        # Non-trainable state: exponential moving average of gate activations
        self.ema_activation = self.add_weight(
            name="ema_activation",
            shape=(self.units,),
            initializer="zeros",
            trainable=False,
        )

    def call(self, inputs, training=None):
        gate = tf.matmul(inputs, self.W_gate)
        if self.use_bias:
            gate = gate + self.b_gate
        gate = tf.nn.sigmoid(gate)
        if training:
            # Update the moving average only during training
            batch_mean = tf.reduce_mean(gate, axis=0)
            self.ema_activation.assign(
                0.99 * self.ema_activation + 0.01 * batch_mean
            )
        # Gate the first `units` features (assumes input features >= units)
        return inputs[:, :self.units] * gate

    def get_config(self):
        config = super().get_config()
        config.update({
            "units": self.units,
            "use_bias": self.use_bias,
        })
        return config
```
Key parameters for add_weight:
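| Parameter | Purpose |
|---|---|
| initializer | How the weights start (glorot_uniform, he_normal, orthogonal, etc.) |
| regularizer | Adds a penalty term (L1, L2, or custom) to layer.losses, which Keras adds to the training loss |
| constraint | A projection applied to the variable after each optimizer update (MaxNorm, NonNeg, etc.) |
| trainable | Whether the optimizer updates this variable; non-trainable weights (like ema_activation above) are excluded from trainable_weights but are still saved with the layer |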
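The sketch below makes the table concrete. TinyDense is a hypothetical layer written only for this illustration (it is not part of Keras): its L2 regularizer shows up in layer.losses, its non-trainable counter is kept out of trainable_weights, and its NonNeg constraint is applied right after a deliberately oversized SGD step pushes the kernel negative. Keras documents constraints as projections applied after each gradient update; the sketch assumes apply_gradients does this in a custom loop as well, as recent TF 2.x and Keras 3 optimizers do.

```python
import numpy as np
import tensorflow as tf


class TinyDense(tf.keras.layers.Layer):
    """Hypothetical layer used only to illustrate add_weight options."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        self.kernel = self.add_weight(
            name="kernel",
            shape=(input_shape[-1], self.units),
            initializer="ones",
            regularizer=tf.keras.regularizers.l2(1e-4),
            constraint=tf.keras.constraints.NonNeg(),
            trainable=True,
        )
        # Non-trainable state: saved with the layer, ignored by the optimizer
        self.counter = self.add_weight(
            name="counter", shape=(), initializer="zeros", trainable=False
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.kernel)


layer = TinyDense(2)
x = tf.ones((3, 4))
layer(x)  # the first call runs build()

print(len(layer.trainable_weights), len(layer.non_trainable_weights))
print(len(layer.losses))  # one entry: the L2 penalty on the kernel

# The gradient of sum(x @ kernel) is 3 for every kernel entry, so one SGD step
# with learning_rate=10 moves each entry from 1.0 to -29.0; the NonNeg
# constraint then projects them back to zero.
optimizer = tf.keras.optimizers.SGD(learning_rate=10.0)
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(layer(x))
grads = tape.gradient(loss, layer.trainable_weights)
optimizer.apply_gradients(zip(grads, layer.trainable_weights))
print(layer.kernel.numpy())  # every entry is now 0 (possibly printed as -0.)
```

The regularizer is not added to the loss automatically in a hand-written loop like this one; model.fit() does that for you by summing layer.losses into the training loss.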
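Declaring Output Shapes with compute_output_shape
Override compute_output_shape(input_shape) to tell Keras what shape your layer produces without running call():
```python
def compute_output_shape(self, input_shape):
    # (batch, features) -> (batch, units)
    return (input_shape[0], self.units)
```
In recent TF 2.x Keras and Keras 3, the Functional API usually infers output shapes by tracing call() on symbolic inputs, so this override is often optional. It matters when symbolic tracing is expensive or impossible, and when other code calls layer.compute_output_shape() directly; TimeDistributed, for example, delegates its own compute_output_shape to the wrapped layer's.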
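A quick way to check an override is to call compute_output_shape directly and compare the result with the shape the Functional API infers from a symbolic call. Projection is a hypothetical layer written for this sketch; both shapes should agree.

```python
import tensorflow as tf


class Projection(tf.keras.layers.Layer):
    """Hypothetical layer mapping (batch, features) to (batch, units)."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        self.kernel = self.add_weight(
            name="kernel",
            shape=(input_shape[-1], self.units),
            initializer="glorot_uniform",
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.kernel)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.units)


layer = Projection(32)
declared = layer.compute_output_shape((None, 64))  # no build or call needed
inputs = tf.keras.Input(shape=(64,))
inferred = layer(inputs).shape  # shape from the symbolic call
print(declared, tuple(inferred))
```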
Graph-Mode Compatibility
Custom layers must work in both eager mode (debugging) and graph mode (production). The main rules:
Avoid Python State Mutations in call()
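```python
# BAD: a Python int only changes while tf.function traces call(),
# not on every execution of the compiled graph
def call(self, inputs):
    self.call_count += 1
    return inputs * self.scale

# GOOD: a tf.Variable (for example, created with
# self.add_weight(shape=(), dtype="int64", initializer="zeros", trainable=False))
# is updated on every call, in eager and graph mode alike
def call(self, inputs):
    self.call_count.assign_add(1)
    return inputs * self.scale
```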
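A runnable sketch of the difference: CallCounter is a hypothetical layer that counts its calls both ways. When the layer runs inside tf.function, the Python counter only changes while the function is being traced (normally once), while the variable is incremented on every execution.

```python
import tensorflow as tf


class CallCounter(tf.keras.layers.Layer):
    """Hypothetical layer that counts its calls in two different ways."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.python_count = 0  # plain Python int
        self.var_count = self.add_weight(  # tracked TensorFlow variable
            name="var_count",
            shape=(),
            dtype="int64",
            initializer="zeros",
            trainable=False,
        )

    def call(self, inputs):
        self.python_count += 1        # runs only while tracing
        self.var_count.assign_add(1)  # runs on every graph execution
        return inputs


layer = CallCounter()


@tf.function
def step(x):
    return layer(x)


x = tf.ones((2, 3))
for _ in range(5):
    step(x)

# python_count equals the number of traces; var_count equals the number of calls
print(layer.python_count, int(layer.var_count.numpy()))
```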
Use TensorFlow Ops for Control Flow
```python
# OK: Keras usually passes `training` as a Python bool and traces a separate
# graph for each value, so a plain Python `if` on it picks the right branch
def call(self, inputs, training=None):
    if training:
        inputs = tf.nn.dropout(inputs, rate=0.5)
    return inputs

# For conditions that depend on tensor VALUES, use tf.cond: a plain Python
# `if` on a symbolic tensor fails in graph mode unless AutoGraph converts it,
# while tf.cond builds both branches into the graph explicitly
def call(self, inputs):
    return tf.cond(
        tf.reduce_mean(inputs) > 0,
        lambda: inputs * self.pos_scale,
        lambda: inputs * self.neg_scale,
    )
```
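A runnable version of the tf.cond pattern: SignDependentScale is a hypothetical layer that supplies the pos_scale and neg_scale attributes the snippet above assumes (here plain Python floats). Because both example inputs share one shape and dtype, tf.function traces once, and the branch is chosen at run time from the tensor values.

```python
import tensorflow as tf


class SignDependentScale(tf.keras.layers.Layer):
    """Hypothetical layer that picks a scale from the sign of the input mean."""

    def __init__(self, pos_scale=2.0, neg_scale=0.5, **kwargs):
        super().__init__(**kwargs)
        self.pos_scale = pos_scale
        self.neg_scale = neg_scale

    def call(self, inputs):
        return tf.cond(
            tf.reduce_mean(inputs) > 0,
            lambda: inputs * self.pos_scale,
            lambda: inputs * self.neg_scale,
        )


layer = SignDependentScale()


@tf.function
def scaled(x):
    return layer(x)


# Same input signature, so one traced graph handles both branches
pos_out = scaled(tf.constant([1.0, 3.0]))
neg_out = scaled(tf.constant([-1.0, -3.0]))
print(pos_out.numpy(), neg_out.numpy())  # [2. 6.] [-0.5 -1.5]
```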
Masking Support
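If your layer should respect sequence masks (common in NLP), set supports_masking = True, accept a mask argument in call(), and override compute_mask when you need to control the mask passed to downstream layers:
```python
class MaskedDense(tf.keras.layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.supports_masking = True

    def build(self, input_shape):
        # Create the weights that call() uses
        self.kernel = self.add_weight(
            name="kernel",
            shape=(input_shape[-1], self.units),
            initializer="glorot_uniform",
        )
        self.bias = self.add_weight(
            name="bias", shape=(self.units,), initializer="zeros"
        )

    def compute_mask(self, inputs, mask=None):
        return mask  # Pass the mask through unchanged

    def call(self, inputs, mask=None):
        # inputs: (batch, timesteps, features); mask: (batch, timesteps)
        output = tf.matmul(inputs, self.kernel) + self.bias
        if mask is not None:
            # Zero out the outputs at padded timesteps
            output = output * tf.cast(mask[:, :, tf.newaxis], dtype=output.dtype)
        return output
```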
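Masks usually originate upstream. As a sketch of where the mask argument comes from, Keras's built-in Embedding layer with mask_zero=True treats token id 0 as padding and produces a boolean (batch, timesteps) mask; the last line applies the same broadcast-and-multiply zeroing that MaskedDense uses. The token ids are arbitrary illustration data.

```python
import numpy as np
import tensorflow as tf

# Token id 0 is treated as padding when mask_zero=True
embedding = tf.keras.layers.Embedding(input_dim=10, output_dim=4, mask_zero=True)
tokens = tf.constant([[5, 3, 0, 0],
                      [7, 0, 0, 0]])

vectors = embedding(tokens)            # (batch, timesteps, features)
mask = embedding.compute_mask(tokens)  # (batch, timesteps), dtype bool
print(mask.numpy())

# Broadcast the mask over the feature axis and zero out padded positions
masked = vectors * tf.cast(mask[:, :, tf.newaxis], dtype=vectors.dtype)
```

Inside a model, Keras propagates this mask automatically to later layers that set supports_masking = True, which is how MaskedDense receives it.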
Mixed Precision Considerations
When using tf.keras.mixed_precision, computations may run in float16. Custom layers need to handle this:
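```python
class PrecisionAwareLayer(tf.keras.layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # No dtype argument: the weight is created in the policy's
        # variable_dtype (float32 under "mixed_float16"), which keeps
        # optimizer updates numerically stable
        self.kernel = self.add_weight(
            name="kernel",
            shape=(input_shape[-1], self.units),
            initializer="glorot_uniform",
        )

    def call(self, inputs):
        # Compute in compute_dtype (float16 under "mixed_float16").
        # Keras normally autocasts floating inputs and weights already;
        # these explicit casts are no-ops when the dtypes match.
        inputs = tf.cast(inputs, self.compute_dtype)
        kernel = tf.cast(self.kernel, self.compute_dtype)
        return tf.matmul(inputs, kernel)
```
The compute_dtype and variable_dtype properties are set by the mixed precision policy; use them instead of hard-coding float32. Return outputs in compute_dtype rather than casting every layer back to float32, and give the model's final layer dtype="float32" (for example, an output Activation layer) so the loss is computed in full precision.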
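To see which dtypes a policy actually assigns, the sketch below temporarily switches the global policy to "mixed_float16", with a built-in Dense layer standing in for a custom one, and adds a float32 output layer in the way TensorFlow's mixed precision guide suggests. It restores the default "float32" policy at the end so later code is unaffected. On a CPU-only machine TensorFlow may log a warning that mixed precision will be slow; the dtypes are the same.

```python
import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")
try:
    dense = tf.keras.layers.Dense(8)  # created after the policy change
    hidden = dense(tf.ones((2, 4)))
    # Final layer in float32 so the loss sees full-precision values
    output_layer = tf.keras.layers.Activation("linear", dtype="float32")
    outputs = output_layer(hidden)
    print(dense.compute_dtype, dense.variable_dtype)  # float16 float32
    print(hidden.dtype.name, outputs.dtype.name)      # float16 float32
finally:
    # Restore the default policy for any code that runs afterwards
    tf.keras.mixed_precision.set_global_policy("float32")
```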
Building Composite Layers (Blocks)
Production architectures are built from reusable blocks:
```python
class TransformerBlock(tf.keras.layers.Layer):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout_rate

    def build(self, input_shape):
        self.attention = tf.keras.layers.MultiHeadAttention(
            num_heads=self.num_heads, key_dim=self.embed_dim
        )
        self.ffn = tf.keras.Sequential([
            tf.keras.layers.Dense(self.ff_dim, activation="gelu"),
            tf.keras.layers.Dense(self.embed_dim),
        ])
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(self.dropout_rate)
        self.dropout2 = tf.keras.layers.Dropout(self.dropout_rate)

    def call(self, inputs, training=None):
        # Residual additions assume the input's last dimension equals embed_dim.
        # Self-attention sub-block: attention, dropout, residual, layer norm
        attn_output = self.attention(inputs, inputs, training=training)
        attn_output = self.dropout1(attn_output, training=training)
        x = self.layernorm1(inputs + attn_output)
        # Feed-forward sub-block with the same residual + norm pattern
        ffn_output = self.ffn(x, training=training)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(x + ffn_output)

    def get_config(self):
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "ff_dim": self.ff_dim,
            "dropout_rate": self.dropout_rate,
        })
        return config
```
Sub-layers created in build() (or __init__) are automatically tracked — their weights appear in layer.trainable_variables and are included in checkpoints.
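A quick way to confirm the tracking claim is to build a small composite layer and list its variables. TwoLayerMLP is a hypothetical two-Dense block standing in for TransformerBlock, which needs more setup; like TransformerBlock, it creates its sub-layers in build().

```python
import tensorflow as tf


class TwoLayerMLP(tf.keras.layers.Layer):
    """Hypothetical composite layer that creates its sub-layers in build()."""

    def __init__(self, hidden_units, output_units, **kwargs):
        super().__init__(**kwargs)
        self.hidden_units = hidden_units
        self.output_units = output_units

    def build(self, input_shape):
        self.fc1 = tf.keras.layers.Dense(self.hidden_units, activation="relu")
        self.fc2 = tf.keras.layers.Dense(self.output_units)

    def call(self, inputs):
        return self.fc2(self.fc1(inputs))


block = TwoLayerMLP(hidden_units=16, output_units=4)
block(tf.ones((2, 8)))  # the first call builds block, fc1 and fc2

# Two kernels and two biases, collected from the tracked sub-layers
print([tuple(v.shape) for v in block.trainable_variables])
```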
Testing Custom Layers
A production-grade test suite for a custom layer should verify the output shape, the difference between training and inference behavior, config round-tripping, and saving and loading inside a model:
```python
import numpy as np
import tensorflow as tf


def test_layer_output_shape():
    layer = AttentionGate(units=32)
    x = tf.random.normal((4, 64))
    output = layer(x)
    assert output.shape == (4, 32)


def test_layer_training_flag():
    layer = AttentionGate(units=32)
    x = tf.random.normal((4, 64))
    # Inference call: builds the layer but must leave the EMA at zero
    layer(x, training=False)
    np.testing.assert_allclose(layer.ema_activation.numpy(), 0.0)
    # Training call: the EMA should move toward the batch mean
    layer(x, training=True)
    assert not np.allclose(layer.ema_activation.numpy(), 0.0)


def test_layer_serialization():
    layer = AttentionGate(units=32)
    layer.build(input_shape=(None, 64))
    config = layer.get_config()
    restored = AttentionGate.from_config(config)
    assert restored.units == 32


def test_layer_in_model_save_load(tmp_path):
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(64,)),
        AttentionGate(units=32),
    ])
    model.compile(optimizer="adam", loss="mse")
    path = str(tmp_path / "test_model.keras")
    model.save(path)
    # Custom classes must be passed in (or registered) to be deserialized
    loaded = tf.keras.models.load_model(
        path, custom_objects={"AttentionGate": AttentionGate}
    )
    # Sequential.layers excludes the Input, so index 0 is the AttentionGate
    assert loaded.layers[0].units == 32
```
Run these tests under both eager mode and tf.function to catch tracing issues.
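One way to run the same check in both modes is a small helper that calls a layer eagerly, then through tf.function, and compares the results. assert_eager_matches_graph is a hypothetical helper name, and the built-in Dense layer stands in for a custom one here; a layer such as AttentionGate can be passed the same way, with training=... forwarded as a keyword.

```python
import numpy as np
import tensorflow as tf


def assert_eager_matches_graph(layer, x, **call_kwargs):
    """Hypothetical test helper: run a layer eagerly and inside tf.function."""
    eager_out = layer(x, **call_kwargs)  # also builds the layer's weights

    @tf.function
    def graph_call(t):
        return layer(t, **call_kwargs)

    graph_out = graph_call(x)
    np.testing.assert_allclose(
        eager_out.numpy(), graph_out.numpy(), rtol=1e-5, atol=1e-6
    )
    return eager_out, graph_out


eager_out, graph_out = assert_eager_matches_graph(
    tf.keras.layers.Dense(4), tf.random.normal((2, 3))
)
```

While debugging a test that only fails in graph mode, tf.config.run_functions_eagerly(True) makes tf.function bodies run eagerly so you can step through them; switch it back off afterwards.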
Performance Profiling
Custom layers can introduce bottlenecks. Profile with:
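```python
import timeit

import tensorflow as tf

# Time a single forward pass
x = tf.random.normal((32, 256))
layer = AttentionGate(units=128)  # any custom layer defined above works here

# Warm up: the first call builds the weights
_ = layer(x)

# Profile; .numpy() blocks until the result is ready, so asynchronous
# GPU work is included in the measurement
elapsed = timeit.timeit(lambda: layer(x).numpy(), number=1000)
print(f"{elapsed:.3f}s for 1000 calls")
```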
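Because custom layers usually run inside tf.function during fit(), it is worth timing that path as well as eager calls. A sketch, with a built-in Dense layer standing in for the custom layer; absolute numbers depend entirely on the hardware, so no expected timings are given.

```python
import timeit

import tensorflow as tf

layer = tf.keras.layers.Dense(128)  # stand-in for any custom layer
x = tf.random.normal((32, 256))
compiled = tf.function(lambda t: layer(t))

# Warm up: build the weights eagerly, then trace the graph once
layer(x)
compiled(x)

# .numpy() waits for the result so device work is included in the timing
eager_s = timeit.timeit(lambda: layer(x).numpy(), number=200)
graph_s = timeit.timeit(lambda: compiled(x).numpy(), number=200)
print(f"eager: {eager_s:.3f}s  tf.function: {graph_s:.3f}s  (200 calls each)")
```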
Common performance pitfalls in custom layers:
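- Excessive Python ops in call(): move the computation into TensorFlow ops.
- Small tensor operations that cannot utilize GPU parallelism: batch them.
- Unnecessary tf.py_function calls: they run Python code under the GIL, cannot be compiled with XLA, and their bodies are not serialized into a SavedModel.
- Creating tensors inside call(): in eager mode they are rebuilt on every call, so create constants in build() instead.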
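As an illustration of the batching advice, the sketch below computes the same product two ways: one row at a time from Python, and as a single batched matmul. The shapes are arbitrary; the point is that the batched version issues one op instead of 64 and gives the same result up to floating-point rounding.

```python
import numpy as np
import tensorflow as tf

x = tf.random.normal((64, 32))
w = tf.random.normal((32, 16))

# Slow pattern: 64 tiny matmuls dispatched one row at a time from Python
looped = tf.stack([tf.matmul(row[tf.newaxis, :], w)[0] for row in tf.unstack(x)])

# Batched pattern: a single matmul over the whole batch
batched = tf.matmul(x, w)

np.testing.assert_allclose(looped.numpy(), batched.numpy(), rtol=1e-4, atol=1e-4)
```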
Real-World Custom Layer Patterns
Squeeze-and-Excitation blocks (used in EfficientNet) learn channel-wise attention:
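```python
class SEBlock(tf.keras.layers.Layer):
    def __init__(self, reduction=16, **kwargs):
        super().__init__(**kwargs)
        self.reduction = reduction

    def build(self, input_shape):
        channels = input_shape[-1]
        self.squeeze = tf.keras.layers.GlobalAveragePooling2D()
        # Bottleneck width; max(1, ...) guards against channels < reduction
        self.fc1 = tf.keras.layers.Dense(
            max(1, channels // self.reduction), activation="relu"
        )
        self.fc2 = tf.keras.layers.Dense(channels, activation="sigmoid")

    def call(self, inputs):
        se = self.squeeze(inputs)  # (batch, C): average over spatial dims
        se = self.fc1(se)          # (batch, C // reduction): bottleneck
        se = self.fc2(se)          # (batch, C): per-channel gates in (0, 1)
        return inputs * se[:, tf.newaxis, tf.newaxis, :]

    def get_config(self):
        config = super().get_config()
        config.update({"reduction": self.reduction})
        return config
```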
This pattern (global pooling → bottleneck → sigmoid gating) comes from Squeeze-and-Excitation Networks, which won the ILSVRC 2017 classification task, and it is now a standard component of efficient mobile vision architectures such as MobileNetV3 and EfficientNet.
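The same squeeze → bottleneck → gating pattern can also be written with the Functional API. The sketch below is an alternative formulation, not a replacement for SEBlock; se_block is a hypothetical helper name. It uses GlobalAveragePooling2D(keepdims=True), available in TF 2.6+ and Keras 3, so the gates keep a (1, 1, C) shape and the Multiply layer can broadcast them over every spatial position.

```python
import tensorflow as tf


def se_block(x, reduction=16):
    """Hypothetical Functional-API sketch of squeeze-and-excitation gating."""
    channels = x.shape[-1]
    se = tf.keras.layers.GlobalAveragePooling2D(keepdims=True)(x)  # (batch, 1, 1, C)
    se = tf.keras.layers.Dense(max(1, channels // reduction), activation="relu")(se)
    se = tf.keras.layers.Dense(channels, activation="sigmoid")(se)
    # Multiply broadcasts the (1, 1, C) gates over the spatial dimensions
    return tf.keras.layers.multiply([x, se])


inputs = tf.keras.Input(shape=(8, 8, 32))
model = tf.keras.Model(inputs, se_block(inputs))
print(model.output_shape)  # (None, 8, 8, 32)

x = tf.random.normal((2, 8, 8, 32))
y = model(x)  # each channel is scaled by a gate in (0, 1)
```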
The one thing to remember: Custom layers must implement build/call/get_config correctly for graph mode, mixed precision, and serialization — but once they do, they are first-class citizens in every Keras workflow.
See Also
- Python Pytorch Lightning Training How PyTorch Lightning removes the boring parts of training AI models so researchers can focus on ideas instead of boilerplate.
- Python Tensorflow Data Pipelines How TensorFlow feeds data to your model without wasting time — explained like a restaurant kitchen that never stops cooking.
- Python Tensorflow Keras Api Why Keras is TensorFlow's friendly front door — and how it turns complex math into simple building blocks anyone can stack together.
- Python Tensorflow Model Optimization Why making a trained model smaller and faster matters — explained like packing a suitcase for a trip.
- Python Tensorflow Tensorboard How TensorBoard lets you watch your model learn in real time — explained like a fitness tracker for your AI.