TensorFlow Keras API — Deep Dive

Architecture Under the Hood

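Keras sits atop TensorFlow’s computation graph engine. When you call model.fit(), Keras wraps the training step in a tf.function, which traces it into a TensorFlow graph. That graph can additionally be compiled with XLA (controlled by the jit_compile argument to compile(); whether XLA is on by default depends on the Keras version), and TensorFlow dispatches the resulting operations to CPU, GPU, or TPU. Understanding this pipeline helps you write models that are both elegant and fast.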
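The execution modes in this pipeline are selected through arguments to compile(). The sketch below trains a hypothetical toy classifier on random data three ways: the default traced graph, run_eagerly=True (no tracing, useful for debugging), and jit_compile=True (XLA compilation of the traced step). It only shows that the switches exist; it makes no claim about relative speed.

```python
import numpy as np
import tensorflow as tf


def make_model():
    # Hypothetical toy classifier used only to compare execution modes.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(20,)),
        tf.keras.layers.Dense(32, activation="relu"),
        tf.keras.layers.Dense(3, activation="softmax"),
    ])


x = np.random.rand(256, 20).astype("float32")
y = np.random.randint(0, 3, size=(256,))

modes = {
    "graph": {"run_eagerly": False},  # fit() runs a tf.function-traced train step
    "eager": {"run_eagerly": True},   # Python runs every op one by one (debugging)
    "xla": {"jit_compile": True},     # the traced step is also compiled with XLA
}

histories = {}
for mode, compile_kwargs in modes.items():
    model = make_model()
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", **compile_kwargs)
    histories[mode] = model.fit(x, y, epochs=1, batch_size=32, verbose=0)
    print(mode, histories[mode].history["loss"][-1])
```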

Layer Lifecycle

Every Keras layer goes through a precise lifecycle:

  1. __init__ — Store configuration (units, activation, etc.), but do not create weights yet. The input shape is unknown at this point.
  2. build(input_shape) — Called the first time the layer sees data. Create weight variables using self.add_weight(). This lazy initialization lets the same layer class work with any compatible input shape.
  3. call(inputs, training=None) — The forward computation. The training flag controls behavior differences like dropout (active during training, inactive during inference).
  4. get_config() — Return a dictionary of the layer’s configuration for serialization.

import tensorflow as tf

class ScaledDense(tf.keras.layers.Layer):
    def __init__(self, units, scale=1.0, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.scale = scale

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="glorot_uniform",
            trainable=True,
            name="kernel"
        )
        self.b = self.add_weight(
            shape=(self.units,),
            initializer="zeros",
            trainable=True,
            name="bias"
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.w) * self.scale + self.b

    def get_config(self):
        config = super().get_config()
        config.update({"units": self.units, "scale": self.scale})
        return config

The add_weight method registers variables with TensorFlow’s variable tracking, enabling automatic gradient computation and checkpoint saving.
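The lazy build() step and the variable tracking provided by add_weight can be observed directly. This sketch repeats the ScaledDense layer from above so the block runs on its own; the sizes (8 input features, 3 units) are arbitrary. It checks that no weights exist before the first call, that the first call creates them with shapes inferred from the input, that gradients flow to them, and that get_config() is enough to rebuild the layer.

```python
import tensorflow as tf


class ScaledDense(tf.keras.layers.Layer):
    def __init__(self, units, scale=1.0, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.scale = scale

    def build(self, input_shape):
        self.w = self.add_weight(shape=(input_shape[-1], self.units),
                                 initializer="glorot_uniform", trainable=True, name="kernel")
        self.b = self.add_weight(shape=(self.units,),
                                 initializer="zeros", trainable=True, name="bias")

    def call(self, inputs):
        return tf.matmul(inputs, self.w) * self.scale + self.b

    def get_config(self):
        config = super().get_config()
        config.update({"units": self.units, "scale": self.scale})
        return config


layer = ScaledDense(units=3, scale=0.5)
weights_before = len(layer.weights)  # build() has not run yet, so no variables exist

x = tf.random.normal((2, 8))
with tf.GradientTape() as tape:
    y = layer(x)  # first call runs build() with the input's shape
    loss = tf.reduce_sum(y)

shapes = [tuple(w.shape) for w in layer.weights]  # kernel (8, 3), then bias (3,)
grads = tape.gradient(loss, layer.trainable_weights)

# The config holds only constructor arguments; weights are created again on first use.
clone = ScaledDense.from_config(layer.get_config())
print(weights_before, shapes)
```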

Graph Tracing and tf.function

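When you call model.compile() and then model.fit(), Keras wraps the training step in a tf.function. The first call traces the Python code into a TensorFlow graph; later calls with the same input shapes and dtypes reuse that graph, which the TensorFlow runtime executes without re-running the Python code. The speedup over eager execution depends on the model and is largest when Python overhead from dispatching many small operations dominates the step time.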

Pitfalls to watch for:

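  • Python side effects inside call(): print(), appending to lists, or modifying other Python state runs only while the function is being traced, not on subsequent calls. Use tf.print() for output that should appear on every step.
  • Data-dependent control flow: AutoGraph converts a Python if or while whose condition is a tensor into tf.cond or tf.while_loop, and a for loop over tf.range or a dataset into a graph loop. A for loop over a Python list is unrolled at trace time, so it cannot break on a tensor value. Use tf.cond and tf.while_loop directly when you want the graph structure to be explicit.
  • Input signature changes: each new input shape or dtype triggers a retrace. Pad inputs to consistent shapes, or give tf.function an input_signature with None dimensions, to avoid excessive retracing.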
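The first and third pitfalls, plus the AutoGraph conversion from the second, can be seen in a few lines. In this sketch (function names are hypothetical), a Python list records each trace while tf.print runs on every call. Calling with a new input shape adds a second trace. The second function uses a Python if on a tensor, which AutoGraph turns into tf.cond.

```python
import tensorflow as tf

trace_log = []  # ordinary Python list, appended to only while tracing


@tf.function
def double(x):
    trace_log.append(tuple(x.shape))  # Python side effect: runs once per trace
    tf.print("executing with", x)     # graph op: runs on every call
    return x * 2.0


double(tf.constant([1.0, 2.0]))       # first call: trace, then execute
double(tf.constant([3.0, 4.0]))       # same shape and dtype: graph reused, no trace
double(tf.constant([1.0, 2.0, 3.0]))  # new shape: triggers a retrace
print(trace_log)                      # [(2,), (3,)]


@tf.function
def keep_if_positive_sum(x):
    # AutoGraph converts this tensor-dependent `if` into tf.cond.
    if tf.reduce_sum(x) > 0:
        y = x
    else:
        y = tf.zeros_like(x)
    return y


kept = keep_if_positive_sum(tf.constant([1.0, -0.5]))    # sum is 0.5, so x is returned
zeroed = keep_if_positive_sum(tf.constant([-1.0, 0.5]))  # sum is -0.5, so zeros are returned
```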

The Three APIs in Depth

Sequential: Implementation Details

Sequential is a subclass of Model that maintains an ordered list of layers. When you call model(x), it iterates through self.layers and applies each one. The constraint is strict: single input tensor, single output tensor, no branching.

Under the hood, Sequential builds a Functional model graph the first time it receives input (or immediately, if its first entry is a tf.keras.Input). This means Sequential models get the same graph optimizations as Functional models; they are not slower.
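The claim that Sequential simply applies its layers in order can be checked by hand. This is a minimal sketch with arbitrary layer sizes: it feeds the same batch through model.layers one at a time and compares the result with calling the model. Because the model starts with tf.keras.Input, it is built immediately, and the input placeholder is not listed in model.layers.

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(784,)),  # with an Input, the graph is built right away
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
])

x = tf.random.normal((4, 784))

# Applying the layers one by one reproduces the model's forward pass.
manual = x
for layer in model.layers:
    manual = layer(manual)

np.testing.assert_allclose(model(x).numpy(), manual.numpy(), rtol=1e-5, atol=1e-6)
print(isinstance(model, tf.keras.Model), len(model.layers))
```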

Functional API: Symbolic Tensors

The Functional API works with symbolic tensors — placeholder objects that record the computation graph without executing it:

inputs = tf.keras.Input(shape=(784,))
x = tf.keras.layers.Dense(256, activation="relu")(inputs)
x = tf.keras.layers.Dropout(0.3)(x)
x = tf.keras.layers.Dense(128, activation="relu")(x)
outputs = tf.keras.layers.Dense(10, activation="softmax")(x)

model = tf.keras.Model(inputs=inputs, outputs=outputs)

Each layer call returns a new symbolic tensor that knows its parent. tf.keras.Model walks this graph to determine the full computation path, verify compatibility, and enable features like model.summary() and layer-by-layer inspection.
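Because the Functional graph is recorded, any intermediate symbolic tensor can serve as the output of a second model. That second model is, for example, a feature extractor that shares weights with the original. This sketch rebuilds the model above with hypothetical layer names ("hidden_1", "hidden_2") added so a layer can be looked up with get_layer().

```python
import tensorflow as tf

inputs = tf.keras.Input(shape=(784,))
# Layer names are hypothetical additions so layers can be looked up later.
x = tf.keras.layers.Dense(256, activation="relu", name="hidden_1")(inputs)
x = tf.keras.layers.Dropout(0.3)(x)
x = tf.keras.layers.Dense(128, activation="relu", name="hidden_2")(x)
outputs = tf.keras.layers.Dense(10, activation="softmax")(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# Reuse an intermediate symbolic tensor as a new output.
# The new model shares layers (and therefore weights) with `model`.
feature_extractor = tf.keras.Model(
    inputs=inputs,
    outputs=model.get_layer("hidden_2").output,
)

features = feature_extractor(tf.random.normal((8, 784)))
print(features.shape)  # (8, 128)
```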

Multi-input/multi-output example:

text_input = tf.keras.Input(shape=(100,), name="text")
image_input = tf.keras.Input(shape=(64, 64, 3), name="image")

text_features = tf.keras.layers.Embedding(10000, 64)(text_input)
text_features = tf.keras.layers.GlobalAveragePooling1D()(text_features)

image_features = tf.keras.layers.Conv2D(32, 3, activation="relu")(image_input)
image_features = tf.keras.layers.GlobalAveragePooling2D()(image_features)

merged = tf.keras.layers.Concatenate()([text_features, image_features])
category_output = tf.keras.layers.Dense(5, activation="softmax", name="category")(merged)
priority_output = tf.keras.layers.Dense(1, activation="sigmoid", name="priority")(merged)

model = tf.keras.Model(
    inputs=[text_input, image_input],
    outputs=[category_output, priority_output]
)
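A multi-output model needs one loss per output at compile time, and its inputs and targets are passed in the same order as the Input and output lists. This sketch repeats the model above and trains it for one epoch on hypothetical random data. The loss weights (1.0 and 0.5) are arbitrary illustration values, not recommendations.

```python
import numpy as np
import tensorflow as tf

text_input = tf.keras.Input(shape=(100,), name="text")
image_input = tf.keras.Input(shape=(64, 64, 3), name="image")
text_features = tf.keras.layers.Embedding(10000, 64)(text_input)
text_features = tf.keras.layers.GlobalAveragePooling1D()(text_features)
image_features = tf.keras.layers.Conv2D(32, 3, activation="relu")(image_input)
image_features = tf.keras.layers.GlobalAveragePooling2D()(image_features)
merged = tf.keras.layers.Concatenate()([text_features, image_features])
category_output = tf.keras.layers.Dense(5, activation="softmax", name="category")(merged)
priority_output = tf.keras.layers.Dense(1, activation="sigmoid", name="priority")(merged)
model = tf.keras.Model(inputs=[text_input, image_input],
                       outputs=[category_output, priority_output])

# One loss per output, in the same order as `outputs`; the weights are arbitrary.
model.compile(
    optimizer="adam",
    loss=["sparse_categorical_crossentropy", "binary_crossentropy"],
    loss_weights=[1.0, 0.5],
)

# Hypothetical random data: 32 samples, ordered like the model's inputs and outputs.
n = 32
text_data = np.random.randint(0, 10000, size=(n, 100))  # token ids below vocab size
image_data = np.random.rand(n, 64, 64, 3).astype("float32")
category_labels = np.random.randint(0, 5, size=(n,))
priority_labels = np.random.randint(0, 2, size=(n, 1)).astype("float32")

history = model.fit([text_data, image_data], [category_labels, priority_labels],
                    epochs=1, batch_size=8, verbose=0)
category_probs, priority_scores = model.predict([text_data, image_data], verbose=0)
print(category_probs.shape, priority_scores.shape)  # (32, 5) (32, 1)
```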

Model Subclassing: Dynamic Graphs

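Subclassed models define forward logic imperatively in call(). This enables architectures whose computation depends on the data and therefore cannot be expressed as static graphs. The early-exit loop below relies on eager execution. Inside a tf.function (which fit() uses by default), AutoGraph raises an error for a break that depends on a tensor value in a loop over a Python list. Call such a model eagerly instead, for example from a custom training loop without @tf.function.

class DynamicDepthModel(tf.keras.Model):
    def __init__(self, max_layers=5, units=64, **kwargs):
        super().__init__(**kwargs)
        self.blocks = [
            tf.keras.layers.Dense(units, activation="relu")
            for _ in range(max_layers)
        ]
        self.output_layer = tf.keras.layers.Dense(10, activation="softmax")
        self.gate = tf.keras.layers.Dense(1, activation="sigmoid")

    def call(self, inputs, training=None):
        x = inputs
        for block in self.blocks:
            x = block(x)
            confidence = self.gate(x)
            # Python if/break on a tensor value: works only in eager execution.
            if tf.reduce_mean(confidence) > 0.9:
                break  # Early exit when confident
        return self.output_layer(x)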

Tradeoff: Subclassed models lose some Functional API features — model.summary() cannot show the full architecture without running data through it, and serialization requires explicit get_config / from_config methods.
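The serialization half of this tradeoff looks roughly like the following sketch. The TwoLayerClassifier model is hypothetical. Its weights appear only after data flows through call(), and get_config() / from_config() must be written by hand so the model can be rebuilt from its constructor arguments.

```python
import tensorflow as tf


class TwoLayerClassifier(tf.keras.Model):
    """Hypothetical subclassed model used only to illustrate serialization."""

    def __init__(self, units=32, num_classes=3, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.num_classes = num_classes
        self.hidden = tf.keras.layers.Dense(units, activation="relu")
        self.classifier = tf.keras.layers.Dense(num_classes, activation="softmax")

    def call(self, inputs, training=None):
        return self.classifier(self.hidden(inputs))

    def get_config(self):
        # Only the constructor arguments are needed to rebuild this model.
        return {"units": self.units, "num_classes": self.num_classes, "name": self.name}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


model = TwoLayerClassifier(units=16, num_classes=4, name="demo")
probs = model(tf.zeros((2, 8)))  # weights are created here, on the first call

clone = TwoLayerClassifier.from_config(model.get_config())
print(len(model.weights), clone.get_config())
```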

Custom Training Loops

The compile/fit workflow covers most scenarios, but custom training loops give full control:

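# Assumes `model` (for example, the Functional model above) and `train_dataset`
# (a tf.data.Dataset yielding (x, y) batches) are already defined.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()  # expects softmax probabilities

@tf.function
def train_step(model, x_batch, y_batch):
    with tf.GradientTape() as tape:
        predictions = model(x_batch, training=True)  # training=True enables dropout
        loss = loss_fn(y_batch, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss

for epoch in range(10):
    for x_batch, y_batch in train_dataset:
        loss = train_step(model, x_batch, y_batch)
    print(f"epoch {epoch}: last batch loss = {float(loss):.4f}")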

This pattern is essential when you need:

  • Gradient accumulation across multiple batches (simulating larger batch sizes on limited GPU memory)
  • Gradient clipping before applying updates
  • Multiple loss functions with custom weighting schedules
  • Adversarial training (GANs) where generator and discriminator update alternately
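The first two items can be combined in one loop. This is a minimal sketch, not a library feature. On a hypothetical toy model with random data, it sums gradients over ACCUM_STEPS micro-batches (a name chosen here for illustration), clips the accumulated gradients by global norm, and only then applies one optimizer update. Dividing each micro-batch loss by ACCUM_STEPS makes the sum equal the mean over the effective batch when micro-batches are equal in size. The clip_norm of 1.0 is arbitrary.

```python
import tensorflow as tf

# Hypothetical toy model and random data, only to make the sketch runnable.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(3, activation="softmax"),
])
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()

x = tf.random.normal((64, 4))
y = tf.random.uniform((64,), maxval=3, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(8)  # 8 micro-batches of 8

ACCUM_STEPS = 4  # one optimizer update per 4 micro-batches (effective batch of 32)


def zero_grads():
    return [tf.zeros(v.shape, dtype=v.dtype) for v in model.trainable_variables]


accumulated = zero_grads()
num_updates = 0
for step, (x_batch, y_batch) in enumerate(dataset):
    with tf.GradientTape() as tape:
        predictions = model(x_batch, training=True)
        # Scale so the summed gradients equal the mean over the effective batch.
        loss = loss_fn(y_batch, predictions) / ACCUM_STEPS
    grads = tape.gradient(loss, model.trainable_variables)
    accumulated = [acc + g for acc, g in zip(accumulated, grads)]

    if (step + 1) % ACCUM_STEPS == 0:
        # Clip the accumulated gradients before the update.
        clipped, _ = tf.clip_by_global_norm(accumulated, clip_norm=1.0)
        optimizer.apply_gradients(zip(clipped, model.trainable_variables))
        accumulated = zero_grads()
        num_updates += 1

print("optimizer updates:", num_updates)
```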

Serialization and Deployment

Keras supports multiple serialization formats:

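Format | Use case | Includes optimizer state
--- | --- | ---
SavedModel (model.save("dir") in Keras 2) | Production deployment, TF Serving | Yes
Native Keras (.keras) | Default full-model format in Keras 3 | Yes
HDF5 (.h5) | Quick experiments, legacy code | Yes
Weights only (.weights.h5) | Architecture defined in code | No
ONNX (via tf2onnx) | Cross-framework deployment | No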

SavedModel is the recommended format for production. It serializes the computation graph and weights (and, when written by model.save() in Keras 2, the optimizer state) into a directory that TF Serving loads directly and that the TF Lite and TensorFlow.js converters accept as input.

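# Save (Keras 2, TF 2.15 and earlier): a path without an extension writes a SavedModel directory.
# Keras 3 requires a ".keras" or ".h5" filename for model.save(); use model.export() to write a
# SavedModel, and reload it with tf.saved_model.load() rather than load_model().
model.save("models/my_classifier")

# Load: reconstructs architecture, weights, and optimizer
loaded = tf.keras.models.load_model("models/my_classifier")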

For custom objects (custom layers, losses, metrics), register them with @tf.keras.utils.register_keras_serializable() so they deserialize correctly.
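A registration round trip looks roughly like this sketch. It assumes a recent TensorFlow release that supports the native .keras format. The LearnedScale layer and the package name "demo" are hypothetical. Because the class is registered, load_model() can rebuild it without a custom_objects argument.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


# "demo" is a hypothetical package name; it namespaces the registered class.
@tf.keras.utils.register_keras_serializable(package="demo")
class LearnedScale(tf.keras.layers.Layer):
    """Hypothetical layer that multiplies its input by one trainable scalar."""

    def __init__(self, initial_value=1.0, **kwargs):
        super().__init__(**kwargs)
        self.initial_value = initial_value

    def build(self, input_shape):
        self.factor = self.add_weight(
            shape=(),
            initializer=tf.keras.initializers.Constant(self.initial_value),
            trainable=True,
            name="factor",
        )

    def call(self, inputs):
        return inputs * self.factor

    def get_config(self):
        config = super().get_config()
        config.update({"initial_value": self.initial_value})
        return config


model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(2),
    LearnedScale(initial_value=0.5),
])

x = np.random.rand(3, 4).astype("float32")
path = os.path.join(tempfile.mkdtemp(), "scaled_model.keras")
model.save(path)

# No custom_objects needed: the registration lets load_model find LearnedScale.
restored = tf.keras.models.load_model(path)
np.testing.assert_allclose(model.predict(x, verbose=0), restored.predict(x, verbose=0), rtol=1e-5)
print(tf.keras.utils.get_registered_name(LearnedScale))  # demo>LearnedScale
```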

Performance Optimization Patterns

Mixed Precision Training

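Enabling float16 computation can speed up training substantially on GPUs with Tensor Cores (NVIDIA compute capability 7.0 or higher), usually with minimal accuracy impact:

tf.keras.mixed_precision.set_global_policy("mixed_float16")

Under this policy, layers compute in float16 but store their variables in float32. Small float16 gradients can underflow to zero, so model.compile() wraps the optimizer in a LossScaleOptimizer that applies dynamic loss scaling; custom training loops must apply loss scaling themselves. Keep the model’s final output in float32 (for example, by passing dtype="float32" to the last layer) for numeric stability.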
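The per-layer dtype split can be inspected directly. In this sketch (layer sizes are arbitrary), the hidden layer computes in float16 while keeping float32 variables, and the output layer is forced to float32 so the softmax probabilities come out in full precision. It runs on CPU as well, just without the speedup. The policy is reset at the end so later code is unaffected.

```python
import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(32,)),
    tf.keras.layers.Dense(64, activation="relu"),
    # Keep the final softmax in float32 for numeric stability.
    tf.keras.layers.Dense(10, activation="softmax", dtype="float32"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

hidden = model.layers[0]
probs = model(tf.zeros((2, 32)))
print(hidden.compute_dtype, hidden.variable_dtype)  # float16 float32
print(probs.dtype)                                   # float32

# Restore the default policy so later code is unaffected.
tf.keras.mixed_precision.set_global_policy("float32")
```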

Distribution Strategies

Keras integrates with tf.distribute for multi-GPU and multi-node training:

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = build_model()
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

model.fit(train_dataset, epochs=10)

MirroredStrategy replicates the model on each local GPU, splits each global batch across the replicas, and aggregates gradients with an all-reduce before every update. For multi-machine setups, MultiWorkerMirroredStrategy handles the network communication between workers.
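A runnable version of the snippet above needs concrete build_model() and train_dataset definitions; both are hypothetical stand-ins here. The sketch sizes the dataset batch as a global batch (per-replica size times num_replicas_in_sync), since that is the batch MirroredStrategy splits. With no GPUs visible, MirroredStrategy falls back to a single CPU replica, so the same code runs anywhere.

```python
import numpy as np
import tensorflow as tf


def build_model():
    # Hypothetical stand-in for the build_model() used above.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(20,)),
        tf.keras.layers.Dense(64, activation="relu"),
        tf.keras.layers.Dense(3, activation="softmax"),
    ])


strategy = tf.distribute.MirroredStrategy()  # all visible GPUs, or the CPU if there are none
print("replicas:", strategy.num_replicas_in_sync)

# The dataset batch is the global batch; each replica receives an equal slice of it.
PER_REPLICA_BATCH = 16
global_batch_size = PER_REPLICA_BATCH * strategy.num_replicas_in_sync

with strategy.scope():
    # Model weights created inside the scope are mirrored on every replica.
    model = build_model()
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

x = np.random.rand(256, 20).astype("float32")
y = np.random.randint(0, 3, size=(256,))
train_dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(global_batch_size)

history = model.fit(train_dataset, epochs=1, verbose=0)
```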

Profiling with TensorBoard

tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir="./logs",
    profile_batch="10,20"  # Profile batches 10-20
)
model.fit(train_data, epochs=5, callbacks=[tensorboard_callback])

The TensorBoard profiler shows GPU utilization, operation timing, memory allocation, and data pipeline bottlenecks — critical for identifying whether your model is compute-bound or I/O-bound.

Real-World Architecture Patterns

Residual Connections (Functional API)

inputs = tf.keras.Input(shape=(256,))
x = tf.keras.layers.Dense(256, activation="relu")(inputs)
x = tf.keras.layers.Dense(256)(x)  # Width must match the input so Add() can sum them
x = tf.keras.layers.Add()([x, inputs])  # Skip connection
x = tf.keras.layers.Activation("relu")(x)
outputs = tf.keras.layers.Dense(10, activation="softmax")(x)

model = tf.keras.Model(inputs=inputs, outputs=outputs)
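Add() only works when both branches have the same shape. When a block changes the width, a common remedy is to project the shortcut with its own Dense layer. The residual_block helper below is a hypothetical sketch of that idea, not a Keras API. The first block uses an identity shortcut, and the second projects 256 features down to 128.

```python
import tensorflow as tf


def residual_block(x, units):
    """Two Dense layers plus a skip connection (hypothetical helper, not a Keras API)."""
    shortcut = x
    if x.shape[-1] != units:
        # Add() needs matching shapes, so project the shortcut when widths differ.
        shortcut = tf.keras.layers.Dense(units)(x)
    y = tf.keras.layers.Dense(units, activation="relu")(x)
    y = tf.keras.layers.Dense(units)(y)
    y = tf.keras.layers.Add()([y, shortcut])
    return tf.keras.layers.Activation("relu")(y)


inputs = tf.keras.Input(shape=(256,))
x = residual_block(inputs, 256)  # identity shortcut
x = residual_block(x, 128)       # projected shortcut
outputs = tf.keras.layers.Dense(10, activation="softmax")(x)
model = tf.keras.Model(inputs, outputs)

print(model(tf.zeros((2, 256))).shape)  # (2, 10)
```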

Feature Extraction with Pre-trained Models

base = tf.keras.applications.EfficientNetV2B0(
    weights="imagenet",
    include_top=False,
    input_shape=(224, 224, 3)
)
base.trainable = False  # Freeze pre-trained weights

model = tf.keras.Sequential([
    base,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(5, activation="softmax")
])

This pattern, a frozen pre-trained backbone topped with a custom head, is the standard approach to transfer learning when you have limited training data, and variants of it are common in production image and audio classification systems.
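A common follow-up is a second fine-tuning phase. This sketch assumes two-phase training: train the head with the backbone frozen, then unfreeze the top of the backbone and recompile with a much lower learning rate. It uses the Functional API so the backbone can be called with training=False, which keeps its BatchNormalization layers in inference mode even after unfreezing. The weights=None setting only keeps the sketch offline (use weights="imagenet" in practice), and the 20-layer cutoff and learning rates are arbitrary choices.

```python
import tensorflow as tf

# weights=None keeps this sketch offline; in practice pass weights="imagenet" as above.
base = tf.keras.applications.EfficientNetV2B0(
    weights=None,
    include_top=False,
    input_shape=(224, 224, 3),
)
base.trainable = False

inputs = tf.keras.Input(shape=(224, 224, 3))
# training=False keeps the backbone's BatchNormalization layers in inference mode,
# which still holds after part of the backbone is unfrozen below.
x = base(inputs, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = tf.keras.layers.Dense(5, activation="softmax")(x)
model = tf.keras.Model(inputs, outputs)

# Phase 1: only the new head is trainable (model.fit(...) on your data would go here).
model.compile(optimizer=tf.keras.optimizers.Adam(1e-3), loss="sparse_categorical_crossentropy")
head_only_trainable = len(model.trainable_weights)

# Phase 2: unfreeze the top of the backbone and recompile with a much lower learning rate.
# Recompiling is required for the trainable changes to take effect.
base.trainable = True
for layer in base.layers[:-20]:
    layer.trainable = False
model.compile(optimizer=tf.keras.optimizers.Adam(1e-5), loss="sparse_categorical_crossentropy")
fine_tune_trainable = len(model.trainable_weights)

print(head_only_trainable, fine_tune_trainable)
```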

The one thing to remember: Keras provides three APIs at increasing complexity — Sequential, Functional, and Subclassing — each backed by the same graph engine, so you start simple and only add complexity when your architecture demands it.


See Also