TensorFlow Federated Learning — Deep Dive

TensorFlow Federated Architecture

TensorFlow Federated (TFF) separates federated computations into two layers:

  • Federated Learning API — High-level tools for building federated training and evaluation pipelines
  • Federated Core API — Low-level primitives for defining custom federated algorithms with explicit placement annotations (tff.SERVER, tff.CLIENTS)

All TFF computations are defined as pure functional transformations — no side effects, no mutable state outside the computation graph. This makes them inspectable, testable, and deployable across different execution backends.

Building a Federated Model

Step 1: Define the Model

import tensorflow as tf
import tensorflow_federated as tff

def create_keras_model():
    return tf.keras.Sequential([
        tf.keras.layers.Input(shape=(784,)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

def model_fn():
    keras_model = create_keras_model()
    return tff.learning.models.from_keras_model(
        keras_model,
        input_spec=federated_train_data[0].element_spec,  # client datasets are built in Step 2
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
    )

Step 2: Prepare Federated Data

TFF expects data as a list of datasets — one per client:

# Simulate federated data from EMNIST
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

# Each client's data is a tf.data.Dataset
def preprocess(dataset):
    return dataset.map(
        lambda x: (tf.reshape(x['pixels'], [-1]), x['label'])
    ).shuffle(512).batch(32)

# Select a subset of clients for each round
sample_clients = emnist_train.client_ids[:100]
federated_train_data = [
    preprocess(emnist_train.create_dataset(client_id))
    for client_id in sample_clients
]

Step 3: Build the Federated Process

# Standard Federated Averaging
trainer = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn=model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0)
)

# Initialize server state
state = trainer.initialize()

# Run training rounds
for round_num in range(50):
    # In production, select different clients each round
    result = trainer.next(state, federated_train_data)
    state = result.state
    metrics = result.metrics

    print(f"Round {round_num}: "
          f"loss={metrics['client_work']['train']['loss']:.4f}, "
          f"accuracy={metrics['client_work']['train']['sparse_categorical_accuracy']:.4f}")
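
To make the aggregation step concrete, here is a minimal pure-Python sketch of the arithmetic behind weighted federated averaging: each client's update counts in proportion to its number of local examples. The function name weighted_average and the toy numbers are illustrative assumptions, not part of the TFF API.

```python
def weighted_average(client_updates, client_num_examples):
    """Average per-client update vectors, weighted by local example count."""
    total = sum(client_num_examples)
    dim = len(client_updates[0])
    return [
        sum(u[i] * n for u, n in zip(client_updates, client_num_examples)) / total
        for i in range(dim)
    ]

# Three hypothetical clients with 2-dimensional updates
updates = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
num_examples = [10, 30, 60]
avg = weighted_average(updates, num_examples)
```

The client with 60 examples dominates the result, which is the intended behavior of the weighted variant.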

Custom Aggregation Strategies

FedAvg with Momentum (FedAvgM)

trainer = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn=model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(
        learning_rate=1.0, momentum=0.9
    )
)

FedProx (Proximal Term)

FedProx adds a regularization term that penalizes local models for drifting too far from the global model, improving convergence with non-IID data:

trainer = tff.learning.algorithms.build_weighted_fed_prox(
    model_fn=model_fn,
    proximal_strength=0.1,  # Higher = more regularization
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(1.0)
)
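
The effect of the proximal term can be sketched without TFF. Assuming the local objective is loss(w) + (mu / 2) * ||w - w_global||^2, each local gradient step gains an extra pull of mu * (w - w_global) toward the global model. The function fedprox_step and its arguments are hypothetical names for illustration; mu plays the role of proximal_strength, up to the exact scaling convention the library uses.

```python
def fedprox_step(w, w_global, grad, lr=0.02, mu=0.1):
    """One local SGD step on loss(w) + (mu / 2) * ||w - w_global||^2."""
    return [
        wi - lr * (gi + mu * (wi - wgi))
        for wi, wgi, gi in zip(w, w_global, grad)
    ]

w_global = [0.0, 0.0]
w_local = [1.0, -2.0]
zero_grad = [0.0, 0.0]  # with no loss gradient, only the proximal pull acts
w_next = fedprox_step(w_local, w_global, zero_grad)
```

With a zero loss gradient, every coordinate moves slightly back toward the global weights, which is how the term limits client drift.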

Custom Aggregation with Federated Core

For full control, use the Federated Core API:

@tff.federated_computation(
    tff.FederatedType(tf.float32, tff.CLIENTS)
)
def trimmed_mean(client_values):
    """Robust aggregation that drops outliers."""
    # Collect all values at server. Note: tff.federated_collect is only
    # available in older TFF releases.
    all_values = tff.federated_collect(client_values)

    @tff.tf_computation(tff.SequenceType(tf.float32))
    def compute_trimmed_mean(values):
        tensor = values.reduce(
            tf.constant([], dtype=tf.float32),
            lambda acc, x: tf.concat([acc, [x]], axis=0)
        )
        sorted_vals = tf.sort(tensor)
        n = tf.shape(sorted_vals)[0]
        trim = tf.cast(tf.cast(n, tf.float32) * 0.1, tf.int32)
        return tf.reduce_mean(sorted_vals[trim:n-trim])

    return tff.federated_map(compute_trimmed_mean, all_values)
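
The trimming logic itself is easy to check outside TFF. The sketch below is a plain-Python version of the same idea, assuming a 10% trim on each side as in the code above; the sample values are made up to include one outlier.

```python
def trimmed_mean(values, trim_fraction=0.1):
    """Mean after dropping the lowest and highest trim_fraction of values."""
    ordered = sorted(values)
    trim = int(len(ordered) * trim_fraction)
    kept = ordered[trim:len(ordered) - trim]
    return sum(kept) / len(kept)

# Nine well-behaved clients and one outlier
client_values = [1.0, 1.1, 0.9, 1.0, 1.2, 0.8, 1.0, 1.1, 0.9, 100.0]
robust = trimmed_mean(client_values)
plain = sum(client_values) / len(client_values)
```

The plain mean is dragged far from the typical value by the single outlier, while the trimmed mean stays close to 1.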

Differential Privacy Integration

User-Level Differential Privacy

# Clip individual client contributions and add Gaussian noise
dp_query = tff.aggregators.DifferentiallyPrivateFactory.gaussian_fixed(
    noise_multiplier=0.1,
    clients_per_round=100,
    clip=1.0  # L2 norm clip for each client's update
)

# The DP factory is an unweighted aggregator, so pair it with unweighted FedAvg
trainer = tff.learning.algorithms.build_unweighted_fed_avg(
    model_fn=model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(1.0),
    model_aggregator=dp_query
)
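
The two ingredients named in the comment above, clipping and Gaussian noise, can be sketched in plain Python. This is an illustration under stated assumptions, not the library's implementation: each update is scaled so its L2 norm is at most clip, and noise with standard deviation noise_multiplier * clip is added to the sum before averaging. The helper names are hypothetical.

```python
import math
import random

def clip_by_l2_norm(update, clip):
    """Scale the update down so that its L2 norm does not exceed clip."""
    norm = math.sqrt(sum(x * x for x in update))
    scale = min(1.0, clip / norm) if norm > 0 else 1.0
    return [x * scale for x in update]

def dp_average(updates, clip, noise_multiplier, rng):
    """Clip each update, add Gaussian noise to the sum, then average."""
    clipped = [clip_by_l2_norm(u, clip) for u in updates]
    stddev = noise_multiplier * clip
    dim = len(updates[0])
    noisy_sum = [
        sum(u[i] for u in clipped) + rng.gauss(0.0, stddev)
        for i in range(dim)
    ]
    return [s / len(updates) for s in noisy_sum]

rng = random.Random(0)
updates = [[3.0, 4.0], [0.3, 0.4]]
clipped_large = clip_by_l2_norm(updates[0], clip=1.0)   # norm 5.0 is scaled to 1.0
clipped_small = clip_by_l2_norm(updates[1], clip=1.0)   # norm 0.5 is left alone
private_mean = dp_average(updates, clip=1.0, noise_multiplier=0.1, rng=rng)
```

Clipping bounds how much any single client can move the average, which is what lets a fixed amount of noise hide each client's contribution.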

Privacy Budget Accounting

Track the accumulated privacy cost (epsilon, delta) across training rounds:

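from tensorflow_privacy.privacy.analysis import compute_dp_sgd_privacy_lib

# For user-level DP the accounting unit is a client, not a training example:
# n is the client population and each round samples clients_per_round of them.
epsilon, optimal_order = compute_dp_sgd_privacy_lib.compute_dp_sgd_privacy(
    n=total_clients,
    batch_size=clients_per_round,
    noise_multiplier=0.1,
    epochs=num_rounds * clients_per_round / total_clients,
    delta=1e-5
)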

print(f"Privacy guarantee: (ε={epsilon:.2f}, δ=1e-5)")

A typical target is ε < 10 for meaningful privacy. Lower noise multipliers cost less model accuracy but provide weaker guarantees. Google’s Gboard has reported ε ≈ 8-10 with thousands of rounds.

Secure Aggregation

Secure aggregation ensures the server only sees the sum of client updates:

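# SecureSumFactory produces a secure *sum*; wrap it in a mean factory so the
# server applies an averaged update rather than a raw sum
secure_sum = tff.aggregators.SecureSumFactory(
    upper_bound_threshold=2.0,
    lower_bound_threshold=-2.0
)
secure_agg = tff.aggregators.UnweightedMeanFactory(value_sum_factory=secure_sum)

# The secure aggregator is unweighted, so use the unweighted FedAvg builder
trainer = tff.learning.algorithms.build_unweighted_fed_avg(
    model_fn=model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(1.0),
    model_aggregator=secure_agg
)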
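
How a server can learn a sum without seeing any individual value can be illustrated with pairwise masking. The toy sketch below is an assumption-laden illustration, not the protocol TFF runs: every pair of clients shares a random mask that one adds and the other subtracts, so the masks cancel in the total. Real protocols also handle key agreement and client dropouts, which are omitted here.

```python
import random

def masked_updates(client_values, modulus=2**32, seed=0):
    """Add pairwise masks that cancel in the sum (toy secure aggregation)."""
    rng = random.Random(seed)
    n = len(client_values)
    masked = list(client_values)
    for i in range(n):
        for j in range(i + 1, n):
            mask = rng.randrange(modulus)
            masked[i] = (masked[i] + mask) % modulus  # client i adds the mask
            masked[j] = (masked[j] - mask) % modulus  # client j subtracts it
    return masked

values = [5, 11, 20]  # integer-encoded client updates
masked = masked_updates(values)
server_sum = sum(masked) % 2**32
```

Each masked value on its own looks random to the server, yet the modular sum equals the sum of the original values.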

Communication Efficiency

Compression

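# Lossy compression of model updates (older TFF releases). The encoders come
# from TensorFlow Model Optimization; 8-bit quantization is used as the example.
from tensorflow_model_optimization.python.core.internal import tensor_encoding as te

def encoder_fn(tensor_spec):
    # EncodedSumFactory calls this once per update tensor to build its encoder
    spec = tf.TensorSpec(tensor_spec.shape, tensor_spec.dtype)
    return te.encoders.as_gather_encoder(
        te.encoders.uniform_quantization(bits=8), spec
    )

compressed_agg = tff.aggregators.MeanFactory(
    tff.aggregators.EncodedSumFactory(encoder_fn)
)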
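
As one illustration of what lossy compression of an update means, the sketch below applies uniform 8-bit quantization in plain Python. It is a simplified stand-in chosen for clarity, not the encoder any particular library uses; the function names are hypothetical.

```python
def quantize(values, bits=8):
    """Map floats onto 2**bits evenly spaced levels between min and max."""
    lo, hi = min(values), max(values)
    levels = 2**bits - 1
    scale = (hi - lo) / levels if hi > lo else 1.0
    codes = [round((v - lo) / scale) for v in values]
    return codes, lo, scale

def dequantize(codes, lo, scale):
    """Recover approximate floats from integer codes."""
    return [lo + c * scale for c in codes]

update = [-1.0, -0.25, 0.0, 0.5, 1.0]
codes, lo, scale = quantize(update)
restored = dequantize(codes, lo, scale)
max_error = max(abs(a - b) for a, b in zip(update, restored))
```

Each value is sent as one byte instead of a 32-bit float, at the cost of a reconstruction error of at most half a quantization step.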

Client-Side Training Optimization

# Client optimizer factory. Local epochs are not an optimizer setting: repeat
# or truncate each client's dataset (dataset.repeat(n), dataset.take(k)) to
# train less on devices with little data.
def client_optimizer_fn():
    return tf.keras.optimizers.SGD(learning_rate=0.02)

# Decay the client learning rate over local steps (a decay schedule, not a warmup)
client_lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.01,
    decay_steps=10,
    decay_rate=0.96
)
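
For reference, the non-staircase form of an exponential decay schedule computes initial_learning_rate * decay_rate ** (step / decay_steps). The plain-Python sketch below reproduces that formula with the same parameter values so the resulting learning rates can be inspected without TensorFlow; the function name is an illustrative assumption.

```python
def exponential_decay(step, initial_learning_rate=0.01,
                      decay_steps=10, decay_rate=0.96):
    """Learning rate after `step` optimizer steps (continuous decay)."""
    return initial_learning_rate * decay_rate ** (step / decay_steps)

lr_start = exponential_decay(0)
lr_after_10 = exponential_decay(10)
lr_after_100 = exponential_decay(100)
```

With these settings the rate drops by 4% every 10 steps, so it falls by about a third over 100 steps.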

Simulation at Scale

Multi-GPU Simulation

import grpc

# Distribute simulation across GPUs
tff.backends.native.set_sync_local_cpp_execution_context()

# Or use the remote execution backend for multi-machine simulation
# (num_workers and the worker-{i}:8000 addresses are deployment-specific)
tff.backends.native.set_remote_python_execution_context(
    channels=[
        grpc.insecure_channel(f'worker-{i}:8000')
        for i in range(num_workers)
    ]
)

Realistic Client Simulation

Model real-world conditions in simulation:

import random

def simulate_round(state, all_client_data, clients_per_round=100):
    # Simulate client selection with availability constraints
    available = [c for c in all_client_data if random.random() < 0.3]
    selected = random.sample(available, min(clients_per_round, len(available)))

    # Simulate stragglers by dropping slow clients
    completed = [
        c for c in selected
        if random.random() < 0.95  # 5% dropout rate
    ]

    # Skip the round if every selected client dropped out
    if not completed:
        return state, None

    result = trainer.next(state, completed)
    return result.state, result.metrics

Cross-Silo Federated Learning

Cross-silo FL (hospitals, banks) differs from cross-device (phones):

Aspect                Cross-Device         Cross-Silo
Participants          Millions of phones   2-100 organizations
Availability          Intermittent         Always on
Data per participant  Small (KBs-MBs)      Large (GBs-TBs)
Communication         Bandwidth-limited    Datacenter links
Trust model           Untrusted clients    Semi-trusted

Cross-Silo Implementation

# Each silo preprocesses large local datasets
def silo_dataset(hospital_data_path):
    dataset = tf.data.TFRecordDataset(hospital_data_path)
    return dataset.map(parse_medical_image).batch(64)

# Fewer participants, more local training per round
trainer = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn=model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.Adam(1e-4),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(1.0),
)

# silo_datasets holds one silo_dataset(...) per participating organization.
# Each round makes a full pass over every silo's large local dataset.
state = trainer.initialize()
for round_num in range(200):
    result = trainer.next(state, silo_datasets)
    state = result.state

NVIDIA’s FLARE framework builds on similar principles for healthcare: hospitals collaboratively train radiology models without sharing patient images. Studies have reported federated models reaching 95-98% of centralized accuracy on medical imaging tasks.

Personalization Strategies

Local Fine-Tuning

After global training, each client fine-tunes the model on local data:

# Global model from federated training: copy the trained weights into Keras
global_model = create_keras_model()
trainer.get_model_weights(final_state).assign_weights_to(global_model)

# Per-client personalization
def personalize(client_data, num_epochs=5):
    local_model = tf.keras.models.clone_model(global_model)
    local_model.set_weights(global_model.get_weights())
    local_model.compile(
        optimizer=tf.keras.optimizers.SGD(1e-3),
        loss='sparse_categorical_crossentropy'
    )
    local_model.fit(client_data, epochs=num_epochs)
    return local_model

Split Personalization

Train the shared feature-extraction layers with federated averaging and keep the classification layers local:

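def model_fn_with_personalization():
    model = create_keras_model()
    # from_keras_model alone federates every trainable variable; keeping the
    # last layer local must be handled by the training algorithm itself.
    # (spec, loss, and accuracy below are placeholders.)
    return tff.learning.models.from_keras_model(
        model,
        input_spec=spec,
        loss=loss,
        metrics=[accuracy],
    )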

Production Monitoring

Track these metrics across federated rounds:

Metric                 What to Monitor
Participation rate     % of selected clients that complete training
Round duration         Wall-clock time per round
Model accuracy (eval)  Performance on held-out centralized test set
Privacy budget (ε)     Cumulative privacy spend
Communication cost     Bytes sent/received per round
Client diversity       Distribution of participating device types/regions

Alert on:

  • Participation rate dropping below 70% (infrastructure issue)
  • Accuracy not improving for 20+ rounds (convergence problem)
  • Privacy budget approaching limit (stop training or increase noise)
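
The alert rules above can be expressed as a small check that runs after every round. This is a minimal sketch: the function name, the argument layout, and the 90% budget-warning threshold are assumptions made for illustration, while the 70% participation floor and the 20-round patience come from the list.

```python
def check_alerts(participation_rates, eval_accuracies, epsilon, epsilon_budget,
                 min_participation=0.7, patience=20, budget_warning=0.9):
    """Return the names of the alert conditions that currently hold."""
    alerts = []
    # Latest round's participation fell below the floor
    if participation_rates and participation_rates[-1] < min_participation:
        alerts.append("participation")
    # No accuracy in the last `patience` rounds beats the earlier best
    if len(eval_accuracies) > patience:
        best_before = max(eval_accuracies[:-patience])
        if max(eval_accuracies[-patience:]) <= best_before:
            alerts.append("convergence")
    # Cumulative privacy spend is close to the budget
    if epsilon >= budget_warning * epsilon_budget:
        alerts.append("privacy_budget")
    return alerts

stalled_history = [0.5] * 21  # no improvement over the last 20 rounds
alerts = check_alerts([0.9, 0.65], stalled_history, epsilon=9.5, epsilon_budget=10.0)
healthy = check_alerts([0.9], [0.1, 0.2], epsilon=1.0, epsilon_budget=10.0)
```

Returning alert names rather than raising keeps the check easy to wire into whatever paging or dashboard system the deployment already uses.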

The one thing to remember: TFF provides the building blocks — FedAvg, differential privacy, secure aggregation, compression — but production federated learning requires careful tuning of client selection, communication efficiency, and privacy budgets specific to your deployment scenario.
