TensorFlow Federated Learning — Deep Dive
TensorFlow Federated Architecture
TensorFlow Federated (TFF) separates federated computations into two layers:
- Federated Learning API: High-level tools for building federated training and evaluation pipelines
- Federated Core API: Low-level primitives for defining custom federated algorithms with explicit placement annotations (tff.SERVER, tff.CLIENTS)
All TFF computations are defined as pure functional transformations — no side effects, no mutable state outside the computation graph. This makes them inspectable, testable, and deployable across different execution backends.
Building a Federated Model
Step 1: Define the Model
import tensorflow as tf
import tensorflow_federated as tff

def create_keras_model():
    return tf.keras.Sequential([
        tf.keras.layers.Input(shape=(784,)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

def model_fn():
    # Build a fresh Keras model on every call; TFF invokes model_fn itself
    keras_model = create_keras_model()
    return tff.learning.models.from_keras_model(
        keras_model,
        # preprocessed_example_dataset is any one preprocessed client
        # dataset from Step 2; only its element_spec is used here
        input_spec=preprocessed_example_dataset.element_spec,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
    )
Step 2: Prepare Federated Data
In simulation, TFF takes federated data as a Python list of tf.data.Dataset objects, one per client:
# Simulate federated data from EMNIST
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

# Each client's data is a tf.data.Dataset
def preprocess(dataset):
    # Flatten each 28x28 image to 784 features and pair it with its label
    return dataset.map(
        lambda x: (tf.reshape(x['pixels'], [-1]), x['label'])
    ).shuffle(512).batch(32)

# Use a fixed sample of 100 clients for this simulation
sample_clients = emnist_train.client_ids[:100]
federated_train_data = [
    preprocess(emnist_train.create_dataset(client_id))
    for client_id in sample_clients
]
# Any one client dataset supplies the input_spec that model_fn reads in Step 1
preprocessed_example_dataset = federated_train_data[0]
Step 3: Build the Federated Process
# Standard Federated Averaging
trainer = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn=model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0)
)

# Initialize server state
state = trainer.initialize()

# Run training rounds
for round_num in range(50):
    # In production, select different clients each round
    result = trainer.next(state, federated_train_data)
    state = result.state
    metrics = result.metrics
    print(f"Round {round_num}: "
          f"loss={metrics['client_work']['train']['loss']:.4f}, "
          f"accuracy={metrics['client_work']['train']['sparse_categorical_accuracy']:.4f}")
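To make the aggregation step concrete, here is a minimal pure-Python sketch of what weighted federated averaging computes on the server. It assumes each client reports an update as a flat list of floats together with its local example count. The function name `weighted_average` and the toy numbers are illustrative and not part of TFF.

```python
def weighted_average(client_updates, client_num_examples):
    """Average client update vectors, weighting each by its example count."""
    total = sum(client_num_examples)
    dim = len(client_updates[0])
    averaged = [0.0] * dim
    for update, n in zip(client_updates, client_num_examples):
        for i, value in enumerate(update):
            # Clients with more local data pull the average harder
            averaged[i] += value * n / total
    return averaged


# Toy round: three clients with different amounts of local data
client_updates = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
client_num_examples = [10, 30, 60]
global_update = weighted_average(client_updates, client_num_examples)
print(global_update)
```

With a server SGD learning rate of 1.0, as in the trainer above, applying this averaged update amounts to replacing the global weights with the weighted mean of the client results.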
Custom Aggregation Strategies
FedAvg with Momentum (FedAvgM)
trainer = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn=model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(
        learning_rate=1.0, momentum=0.9
    )
)
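As a rough illustration of what server momentum does, the sketch below applies a heavy-ball update to the averaged client update, treating that average as the direction to move in. It is a simplified model of the idea under the assumption of flat weight lists, not TFF's internal code, and `server_momentum_step` is a hypothetical name.

```python
def server_momentum_step(weights, avg_update, velocity,
                         learning_rate=1.0, momentum=0.9):
    """One server step: accumulate velocity, then move the global weights."""
    new_velocity = [
        momentum * v + learning_rate * d
        for v, d in zip(velocity, avg_update)
    ]
    new_weights = [w + v for w, v in zip(weights, new_velocity)]
    return new_weights, new_velocity


# Three rounds in which clients keep pushing in the same direction
weights = [0.0, 0.0]
velocity = [0.0, 0.0]
for _ in range(3):
    weights, velocity = server_momentum_step(weights, [1.0, -1.0], velocity)
print(weights, velocity)
```

Because the velocity carries over between rounds, consistent client updates compound, which is why momentum can speed up convergence when round-to-round updates agree.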
FedProx (Proximal Term)
FedProx adds a regularization term that penalizes local models for drifting too far from the global model, improving convergence with non-IID data:
trainer = tff.learning.algorithms.build_weighted_fed_prox(
    model_fn=model_fn,
    proximal_strength=0.1,  # Higher = more regularization
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(1.0)
)
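The proximal term can be sketched as one local SGD step on the objective "local loss + (mu / 2) * ||w - w_global||^2", whose gradient adds mu * (w - w_global) to the ordinary loss gradient. The helper below is a hypothetical pure-Python illustration with flat weight lists, not the TFF implementation.

```python
def fedprox_local_step(weights, global_weights, grad,
                       learning_rate, proximal_strength):
    """One SGD step on the local loss plus the FedProx proximal penalty."""
    return [
        w - learning_rate * (g + proximal_strength * (w - wg))
        for w, wg, g in zip(weights, global_weights, grad)
    ]


# With a zero loss gradient, only the proximal term acts:
# the local weight is pulled back toward the global weight
local = fedprox_local_step(
    weights=[1.0],
    global_weights=[0.0],
    grad=[0.0],
    learning_rate=0.5,
    proximal_strength=0.1,
)
print(local)
```

Setting `proximal_strength` to 0 recovers a plain SGD step, which matches the intuition that FedProx generalizes FedAvg's local training.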
Custom Aggregation with Federated Core
For full control, use the Federated Core API. This example targets older TFF releases: tff.federated_collect has since been removed, and newer releases expose tff.tf_computation as tff.tensorflow.computation, so check it against your installed version:
@tff.federated_computation(
    tff.FederatedType(tf.float32, tff.CLIENTS)
)
def trimmed_mean(client_values):
    """Robust aggregation that drops outliers."""
    # Collect all values at server
    all_values = tff.federated_collect(client_values)

    @tff.tf_computation(tff.SequenceType(tf.float32))
    def compute_trimmed_mean(values):
        # Gather the sequence of scalars into a single 1-D tensor
        tensor = values.reduce(
            tf.constant([], dtype=tf.float32),
            lambda acc, x: tf.concat([acc, [x]], axis=0)
        )
        sorted_vals = tf.sort(tensor)
        n = tf.shape(sorted_vals)[0]
        # Drop the lowest and highest 10% of values
        trim = tf.cast(tf.cast(n, tf.float32) * 0.1, tf.int32)
        return tf.reduce_mean(sorted_vals[trim:n-trim])

    return tff.federated_map(compute_trimmed_mean, all_values)
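The same trimming logic is easier to verify outside TFF. This minimal sketch mirrors the steps above (sort, drop the lowest and highest 10%, average the rest) on a plain Python list; the function name and the sample values are illustrative.

```python
def trimmed_mean(values, trim_fraction=0.1):
    """Mean of the values after dropping trim_fraction from each end."""
    sorted_vals = sorted(values)
    n = len(sorted_vals)
    trim = int(n * trim_fraction)
    kept = sorted_vals[trim:n - trim]
    return sum(kept) / len(kept)


# Nine honest clients and one outlier reporting 1000.0
client_values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 1000.0]
robust = trimmed_mean(client_values)
naive = sum(client_values) / len(client_values)
print(robust, naive)
```

The plain mean is dragged to 104.5 by the single outlier, while the trimmed mean stays at 5.5, which is the robustness property the aggregation is after.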
Differential Privacy Integration
User-Level Differential Privacy
# Clip individual client contributions and add Gaussian noise
dp_aggregator = tff.aggregators.DifferentiallyPrivateFactory.gaussian_fixed(
    noise_multiplier=0.1,
    clients_per_round=100,
    clip=1.0  # L2 norm clip for each client's update
)
# The DP factory is an unweighted aggregator, so pair it with the
# unweighted FedAvg builder rather than build_weighted_fed_avg
trainer = tff.learning.algorithms.build_unweighted_fed_avg(
    model_fn=model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(1.0),
    model_aggregator=dp_aggregator
)
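Conceptually, the clip-and-noise step works as in the following sketch: each client update is scaled down to an L2 norm of at most `clip`, the clipped updates are summed, Gaussian noise with standard deviation `noise_multiplier * clip` is added to the sum, and the result is divided by the number of clients. This is a simplified pure-Python model with hypothetical helper names, not the TFF aggregator itself.

```python
import math
import random


def clip_update(update, clip):
    """Scale the update down so its L2 norm is at most clip."""
    norm = math.sqrt(sum(x * x for x in update))
    scale = min(1.0, clip / norm) if norm > 0 else 1.0
    return [x * scale for x in update]


def dp_average(updates, clip, noise_multiplier, rng):
    """Clip each update, add Gaussian noise to the sum, then average."""
    clipped = [clip_update(u, clip) for u in updates]
    stddev = noise_multiplier * clip
    dim = len(updates[0])
    noisy_sum = [
        sum(u[i] for u in clipped) + rng.gauss(0.0, stddev)
        for i in range(dim)
    ]
    return [s / len(updates) for s in noisy_sum]


updates = [[3.0, 4.0], [0.3, 0.4]]  # L2 norms 5.0 and 0.5
rng = random.Random(0)              # seeded only to make the demo repeatable
print(dp_average(updates, clip=1.0, noise_multiplier=0.1, rng=rng))
```

Clipping bounds how much any single client can move the average, and that bound is what lets the added noise translate into a formal privacy guarantee.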
Privacy Budget Accounting
Track the accumulated privacy cost (epsilon, delta) across training rounds:
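from tensorflow_privacy.privacy.analysis import compute_dp_sgd_privacy_lib
# For user-level DP the unit of privacy is a client, so the accountant
# takes client counts rather than example counts. total_clients,
# clients_per_round, and num_rounds are placeholders for your deployment,
# and the estimate assumes clients are sampled uniformly at random.
epsilon, optimal_order = compute_dp_sgd_privacy_lib.compute_dp_sgd_privacy(
    n=total_clients,
    batch_size=clients_per_round,
    noise_multiplier=0.1,
    # One "epoch" here is one pass over the client population
    epochs=num_rounds * clients_per_round / total_clients,
    delta=1e-5
)
print(f"Privacy guarantee: (ε={epsilon:.2f}, δ=1e-5)")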
A common target is ε < 10, and smaller values mean stronger privacy. A lower noise multiplier usually preserves more model accuracy but weakens the guarantee. Google’s Gboard reportedly achieves ε ≈ 8-10 with thousands of rounds.
Secure Aggregation
Secure aggregation ensures the server only sees the sum of client updates:
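# Use secure aggregation factory; values outside the thresholds are clipped
secure_agg = tff.aggregators.SecureSumFactory(
    upper_bound_threshold=2.0,
    lower_bound_threshold=-2.0
)
# SecureSumFactory is an unweighted aggregator, so pair it with the
# unweighted FedAvg builder. It returns a sum rather than a mean, so
# the server step scales with the number of clients unless you normalize.
trainer = tff.learning.algorithms.build_unweighted_fed_avg(
    model_fn=model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(1.0),
    model_aggregator=secure_agg
)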
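One way to see why the server can learn the sum without seeing individual values is pairwise masking, the idea behind many secure aggregation protocols. In the toy sketch below every pair of clients shares a random mask that one adds and the other subtracts, so the masks cancel in the total. It is an illustration only: real protocols also handle key agreement and client dropouts, and the names here are hypothetical.

```python
import random

MODULUS = 2 ** 32  # all arithmetic is done modulo a fixed integer


def mask_inputs(client_values, rng):
    """Return masked values whose sum (mod MODULUS) equals the true sum."""
    masked = list(client_values)
    n = len(masked)
    for i in range(n):
        for j in range(i + 1, n):
            # Clients i and j agree on a shared random mask
            mask = rng.randrange(MODULUS)
            masked[i] = (masked[i] + mask) % MODULUS
            masked[j] = (masked[j] - mask) % MODULUS
    return masked


client_values = [5, 17, 42]
masked = mask_inputs(client_values, random.Random(0))
server_sum = sum(masked) % MODULUS
print(masked, server_sum)
```

Each masked value on its own looks like a random number, yet the server still recovers the exact total of 64.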
Communication Efficiency
Compression
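# Lossy compression of model updates
# sparsity_encoding_fn is a placeholder for an encoder defined elsewhere;
# check the import path of as_gather_encoder against your installed versions
compressed_agg = tff.aggregators.EncodedSumFactory(
    tff.aggregators.encoders.as_gather_encoder(
        sparsity_encoding_fn,
        tf.TensorSpec(shape=[128, 10])
    )
)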
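The snippet above leaves the encoder undefined, so as one possible reading of "sparsity encoding", here is a sketch of top-k magnitude sparsification: the client sends only the k largest-magnitude entries as (index, value) pairs and the server fills the rest with zeros. This is an assumed technique chosen for illustration, with hypothetical function names, and is not the encoder the TFF snippet refers to.

```python
def top_k_sparsify(update, k):
    """Keep the k largest-magnitude entries as sorted (index, value) pairs."""
    top = sorted(range(len(update)),
                 key=lambda i: abs(update[i]),
                 reverse=True)[:k]
    return sorted((i, update[i]) for i in top)


def densify(pairs, size):
    """Rebuild a dense vector, filling dropped entries with zeros."""
    dense = [0.0] * size
    for index, value in pairs:
        dense[index] = value
    return dense


update = [0.01, -0.9, 0.05, 0.7, -0.02]
encoded = top_k_sparsify(update, k=2)    # what the client transmits
decoded = densify(encoded, len(update))  # what the server reconstructs
print(encoded, decoded)
```

The compression is lossy: small entries are dropped entirely, trading some accuracy per round for a much smaller upload.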
Client-Side Training Optimization
# Client optimizer factory; TFF calls it to build a fresh optimizer
def client_optimizer_fn():
    return tf.keras.optimizers.SGD(
        learning_rate=0.02,
        # Fixed here; a rate adapted to local dataset size would be set at this point
    )

# Decay the client learning rate as optimizer steps accumulate
# (ExponentialDecay lowers the rate over time; it is not a warmup schedule)
client_lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.01,
    decay_steps=10,
    decay_rate=0.96
)
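For reference, with its default non-staircase setting this schedule follows the closed form initial_learning_rate * decay_rate ** (step / decay_steps). The small sketch below reproduces that formula in plain Python with the same parameters, so the decay can be inspected without TensorFlow; the function name is illustrative.

```python
def exponential_decay(step, initial_learning_rate=0.01,
                      decay_steps=10, decay_rate=0.96):
    """Learning rate after `step` optimizer steps (continuous decay)."""
    return initial_learning_rate * decay_rate ** (step / decay_steps)


for step in (0, 10, 20, 100):
    print(step, exponential_decay(step))
```

Every 10 steps the rate shrinks by a factor of 0.96, so it only ever decreases from its initial value of 0.01.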
Simulation at Scale
Multi-GPU Simulation
import grpc  # needed for the remote channels below

# Distribute simulation across GPUs
tff.backends.native.set_sync_local_cpp_execution_context()

# Or use the remote execution backend for multi-machine simulation
# (num_workers and the worker-{i}:8000 addresses are deployment-specific)
tff.backends.native.set_remote_python_execution_context(
    channels=[
        grpc.insecure_channel(f'worker-{i}:8000')
        for i in range(num_workers)
    ]
)
Realistic Client Simulation
Model real-world conditions in simulation:
import random

def simulate_round(state, all_client_data, clients_per_round=100):
    # Simulate client selection with availability constraints (30% online)
    available = [c for c in all_client_data if random.random() < 0.3]
    selected = random.sample(available, min(clients_per_round, len(available)))
    # Simulate stragglers by dropping slow clients
    completed = [
        c for c in selected
        if random.random() < 0.95  # 5% dropout rate
    ]
    if not completed:
        # No client finished this round; keep the previous state
        return state, None
    result = trainer.next(state, completed)
    return result.state, result.metrics
Cross-Silo Federated Learning
Cross-silo FL (hospitals, banks) differs from cross-device (phones):
| Aspect | Cross-Device | Cross-Silo |
|---|---|---|
| Participants | Millions of phones | 2-100 organizations |
| Availability | Intermittent | Always on |
| Data per participant | Small (KBs-MBs) | Large (GBs-TBs) |
| Communication | Bandwidth-limited | Datacenter links |
| Trust model | Untrusted clients | Semi-trusted |
Cross-Silo Implementation
# Each silo preprocesses large local datasets
# (parse_medical_image is a placeholder for your own parsing function)
def silo_dataset(hospital_data_path):
    dataset = tf.data.TFRecordDataset(hospital_data_path)
    return dataset.map(parse_medical_image).batch(64)

# Fewer participants than cross-device, each with far more data
trainer = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn=model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.Adam(1e-4),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(1.0),
)
state = trainer.initialize()

# silo_datasets is a list with one silo_dataset(...) per organization.
# To run more local epochs per round, repeat each silo's dataset,
# for example with dataset.repeat(num_epochs).
for round_num in range(200):
    result = trainer.next(state, silo_datasets)
    state = result.state
NVIDIA’s FLARE framework builds on similar principles for healthcare: hospitals collaboratively train radiology models without sharing patient images. Some published studies report federated models reaching 95-98% of centralized accuracy on medical imaging tasks, though results vary with the task and with how the data is split across sites.
Personalization Strategies
Local Fine-Tuning
After global training, each client fine-tunes the model on local data:
# Global model from federated training
# (extract_model_from_state and final_state are placeholders for your own code)
global_model = extract_model_from_state(final_state)

# Per-client personalization
def personalize(client_data, num_epochs=5):
    # Start each client from a private copy of the global weights
    local_model = tf.keras.models.clone_model(global_model)
    local_model.set_weights(global_model.get_weights())
    local_model.compile(
        optimizer=tf.keras.optimizers.SGD(1e-3),
        loss='sparse_categorical_crossentropy'
    )
    local_model.fit(client_data, epochs=num_epochs)
    return local_model
Split Personalization
Train the shared feature-extraction layers with federated learning and keep the classification layers local:
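def model_fn_with_personalization():
    # create_model, spec, loss, and accuracy are placeholders
    model = create_model()
    # Intent: keep the last layer local and aggregate only the others.
    # from_keras_model has no option for this, so the split has to be
    # implemented in the aggregation step or with a dedicated API such
    # as TFF's federated reconstruction (build_fed_recon).
    return tff.learning.models.from_keras_model(
        model,
        input_spec=spec,
        loss=loss,
        metrics=[accuracy],
    )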
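Since the snippet above only states the intent, the following sketch shows the aggregation side of split personalization in plain Python: models are dictionaries from layer name to a flat weight list, only the layers named as shared are averaged, and each client keeps its own copy of everything else. The layer names and helper functions are hypothetical and this is not a TFF API.

```python
def aggregate_shared_layers(client_models, shared_layer_names):
    """Average only the shared layers across clients."""
    averaged = {}
    for name in shared_layer_names:
        layers = [model[name] for model in client_models]
        averaged[name] = [sum(vals) / len(layers) for vals in zip(*layers)]
    return averaged


def apply_global_update(client_model, averaged):
    """Overwrite shared layers with the average; keep local layers as-is."""
    updated = dict(client_model)
    updated.update(averaged)
    return updated


client_models = [
    {"features": [1.0, 2.0], "classifier": [10.0]},
    {"features": [3.0, 4.0], "classifier": [20.0]},
]
averaged = aggregate_shared_layers(client_models, ["features"])
personalized = [apply_global_update(m, averaged) for m in client_models]
print(personalized)
```

After the round both clients share the same feature weights, while each classifier stays exactly as the client trained it.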
Production Monitoring
Track these metrics across federated rounds:
| Metric | What to Monitor |
|---|---|
| Participation rate | % of selected clients that complete training |
| Round duration | Wall-clock time per round |
| Model accuracy (eval) | Performance on held-out centralized test set |
| Privacy budget (ε) | Cumulative privacy spend |
| Communication cost | Bytes sent/received per round |
| Client diversity | Distribution of participating device types/regions |
Alert on:
- Participation rate dropping below 70% (infrastructure issue)
- Accuracy not improving for 20+ rounds (convergence problem)
- Privacy budget approaching limit (stop training or reduce noise)
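The three alert rules can be expressed as a small check that runs after every round. This is a minimal sketch under stated assumptions: the inputs are values your own pipeline already tracks, "approaching the limit" is taken to mean 90% of the budget (an arbitrary choice), and the function and alert names are hypothetical.

```python
def check_alerts(participation_rate, eval_accuracy_history,
                 epsilon_spent, epsilon_limit,
                 patience=20, budget_warning_fraction=0.9):
    """Return the names of the alert rules that currently fire."""
    alerts = []
    # Rule 1: too few selected clients complete training
    if participation_rate < 0.70:
        alerts.append("participation")
    # Rule 2: the best accuracy of the last `patience` rounds
    # does not beat the best accuracy seen before them
    if len(eval_accuracy_history) > patience:
        best_before = max(eval_accuracy_history[:-patience])
        recent_best = max(eval_accuracy_history[-patience:])
        if recent_best <= best_before:
            alerts.append("convergence")
    # Rule 3: cumulative privacy spend is close to the limit
    if epsilon_spent >= budget_warning_fraction * epsilon_limit:
        alerts.append("privacy_budget")
    return alerts


stalled = check_alerts(0.65, [0.5] * 25, epsilon_spent=9.5, epsilon_limit=10.0)
healthy = check_alerts(0.90, [0.1, 0.2, 0.3], epsilon_spent=1.0, epsilon_limit=10.0)
print(stalled, healthy)
```

Returning rule names rather than raising keeps the check easy to wire into whatever alerting system the deployment already uses.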
The one thing to remember: TFF provides the building blocks — FedAvg, differential privacy, secure aggregation, compression — but production federated learning requires careful tuning of client selection, communication efficiency, and privacy budgets specific to your deployment scenario.