Bayesian Inference — Deep Dive

MCMC under the hood

Markov Chain Monte Carlo constructs a Markov chain whose stationary distribution is the target posterior. The chain explores parameter space and, in the long run, visits each region in proportion to its posterior probability. After a warm-up phase and enough steps, the samples approximate the posterior.
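
The mechanism is easiest to see in the simplest MCMC algorithm, random-walk Metropolis. PyMC does not use it by default, but the principle is the same: propose a move, accept it with probability min(1, p(x') / p(x)), and otherwise stay put. Because the proposal is symmetric, this rule satisfies detailed balance, so the target is the chain's stationary distribution. A minimal NumPy sketch on a hypothetical 1-D standard normal target (the function name is illustrative):

```python
import numpy as np

def metropolis(log_target, x0, n_steps, step_size, seed=0):
    """Random-walk Metropolis: propose x' = x + noise, accept with prob min(1, p(x') / p(x))."""
    rng = np.random.default_rng(seed)
    x, logp = x0, log_target(x0)
    samples = np.empty(n_steps)
    for i in range(n_steps):
        proposal = x + step_size * rng.normal()
        logp_prop = log_target(proposal)
        if np.log(rng.uniform()) < logp_prop - logp:   # accept; otherwise keep the current x
            x, logp = proposal, logp_prop
        samples[i] = x
    return samples

# Target: an unnormalized standard normal; the normalizing constant is never needed
draws = metropolis(lambda x: -0.5 * x**2, x0=5.0, n_steps=50_000, step_size=2.4)
kept = draws[5_000:]                  # discard warm-up draws taken while moving away from x0
print(kept.mean(), kept.var())        # close to 0 and 1
```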

The NUTS sampler

PyMC defaults to the No-U-Turn Sampler (NUTS), an adaptive variant of Hamiltonian Monte Carlo (HMC). HMC uses the gradient of the log posterior to simulate a trajectory through parameter space, so it can propose moves far from the current position that are still likely to be accepted. Random-walk Metropolis-Hastings, by contrast, proposes small local steps.

NUTS automatically sets:

  • Step size (ε): how far each leapfrog step goes, adapted during tuning
  • Number of leapfrog steps (L): chosen anew at every iteration by extending the trajectory until it starts turning back on itself (a “U-turn”)

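import numpy as np
import pymc as pm

# Hypothetical example data: 100 noisy measurements
data = np.random.default_rng(0).normal(loc=2.0, scale=1.5, size=100)

with pm.Model() as model:
    mu = pm.Normal('mu', mu=0, sigma=10)
    sigma = pm.HalfCauchy('sigma', beta=5)
    obs = pm.Normal('obs', mu=mu, sigma=sigma, observed=data)

    # NUTS with explicit tuning control
    trace = pm.sample(
        draws=4000,
        tune=2000,          # warm-up iterations used to adapt the step size (discarded)
        chains=4,
        target_accept=0.9,  # higher target acceptance -> smaller steps, fewer divergences
        random_seed=42
    )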
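
The U-turn rule from the list above can be sketched in NumPy. This is a simplified, one-directional version of the criterion from the original NUTS paper: keep taking leapfrog steps until the momentum points back toward the starting point, i.e. (q - q0) . p < 0. The real sampler grows the trajectory in both directions by repeated doubling, and PyMC uses a generalized form of the check. The standard 2-D Gaussian target and all function names are illustrative.

```python
import numpy as np

def grad_log_std_normal(q):
    """Gradient of log N(q | 0, I): the 'force' pulling q back toward the mode."""
    return -q

def leapfrog_step(q, p, eps, grad_logp):
    """One leapfrog step: half momentum kick, full position drift, half momentum kick."""
    p = p + 0.5 * eps * grad_logp(q)
    q = q + eps * p
    p = p + 0.5 * eps * grad_logp(q)
    return q, p

def steps_until_u_turn(q0, p0, eps, grad_logp, max_steps=1000):
    """Count leapfrog steps until (q - q0) . p < 0, i.e. the trajectory heads back."""
    q, p = np.asarray(q0, dtype=float), np.asarray(p0, dtype=float)
    start = q.copy()
    for n in range(1, max_steps + 1):
        q, p = leapfrog_step(q, p, eps, grad_logp)
        if np.dot(q - start, p) < 0:
            return n
    return max_steps

# Start on the unit circle with momentum tangent to it: the exact trajectory is a circle,
# and (q - q0) . p = sin(t) turns negative after half a revolution (t = pi).
n = steps_until_u_turn(q0=[1.0, 0.0], p0=[0.0, 1.0], eps=0.1, grad_logp=grad_log_std_normal)
print(n)  # 32, close to pi / eps = 31.4
```

The trajectory length is therefore not a fixed tuning constant: it adapts to how far the trajectory can usefully travel from each starting point.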

MCMC diagnostics

Sampling is not guaranteed to converge. Always check:

R-hat (Gelman-Rubin diagnostic)

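R-hat compares the variance between chains with the variance within them. If all chains have converged to the same distribution, the two agree and R-hat is close to 1. It should be below 1.01 for all parameters: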

import arviz as az

summary = az.summary(trace, var_names=['mu', 'sigma'])
print(summary[['mean', 'sd', 'r_hat', 'ess_bulk', 'ess_tail']])
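
To see what the r_hat column measures, here is a NumPy sketch of the classic split R-hat described in Gelman et al.'s Bayesian Data Analysis (3rd edition). ArviZ's r_hat additionally rank-normalizes the draws, so its values can differ slightly; the function name and the simulated chains are illustrative.

```python
import numpy as np

def split_rhat(draws):
    """Classic split R-hat for one parameter; draws has shape (n_chains, n_draws)."""
    draws = np.asarray(draws, dtype=float)
    half = draws.shape[1] // 2
    # Split every chain in two so drift within a single chain also inflates R-hat
    split = np.concatenate([draws[:, :half], draws[:, half:2 * half]], axis=0)
    n = split.shape[1]
    W = split.var(axis=1, ddof=1).mean()      # mean within-chain variance
    B = n * split.mean(axis=1).var(ddof=1)    # between-chain variance
    var_plus = (n - 1) / n * W + B / n        # pooled estimate of the posterior variance
    return np.sqrt(var_plus / W)

rng = np.random.default_rng(1)
mixed = rng.normal(size=(4, 1000))   # four chains sampling the same target
stuck = mixed.copy()
stuck[3] += 3.0                      # one chain stuck in a different region
print(split_rhat(mixed), split_rhat(stuck))   # close to 1.00 vs. far above 1.01
```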

Effective sample size (ESS)

Autocorrelated samples carry less information than independent ones. ESS estimates the equivalent number of independent samples. Aim for a bulk and tail ESS of at least 100 per chain (at least 400 in total with 4 chains) for reliable posterior summaries.
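
The idea behind ESS can be sketched for a single chain: estimate the autocorrelations, sum them into an integrated autocorrelation time τ, and divide the number of draws by τ. This sketch truncates the sum with Geyer's initial positive sequence; ArviZ's estimator also combines chains and rank-normalizes, so treat it as an illustration, not a replacement. The AR(1) test chain is a hypothetical stand-in for a sticky sampler, chosen because its exact ESS is N(1 - φ)/(1 + φ).

```python
import numpy as np

def autocorrelation(x):
    """Sample autocorrelation of a 1-D chain at lags 0..n-1, computed with an FFT."""
    x = np.asarray(x, dtype=float)
    n = len(x)
    x = x - x.mean()
    f = np.fft.rfft(x, n=2 * n)                        # zero-pad to avoid circular wrap-around
    acov = np.fft.irfft(np.abs(f) ** 2, n=2 * n)[:n]
    return acov / acov[0]

def effective_sample_size(x):
    """ESS = n / tau, where tau = 1 + 2 * (sum of autocorrelations at lags >= 1)."""
    rho = autocorrelation(x)
    n = len(rho)
    tau = -1.0
    # Geyer truncation: add adjacent pairs rho[2k] + rho[2k+1] while they stay positive
    for k in range(0, n - 1, 2):
        pair = rho[k] + rho[k + 1]
        if pair < 0:
            break
        tau += 2.0 * pair
    return n / tau

# Hypothetical AR(1) chain: x[t] = phi * x[t-1] + standard normal noise
rng = np.random.default_rng(0)
n, phi = 20_000, 0.9
chain = np.empty(n)
chain[0] = rng.normal(scale=1 / np.sqrt(1 - phi**2))
for t in range(1, n):
    chain[t] = phi * chain[t - 1] + rng.normal()

print(effective_sample_size(chain), n * (1 - phi) / (1 + phi))   # estimate vs. exact (about 1053)
```

Twenty thousand strongly autocorrelated draws carry roughly the information of a thousand independent ones, which is why raw draw counts are not a substitute for ESS.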

Trace plots and rank plots

az.plot_trace(trace, var_names=['mu', 'sigma'])
az.plot_rank(trace, var_names=['mu', 'sigma'])  # Better for detecting convergence issues

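In a rank plot, each chain's histogram of ranks (computed over all chains pooled together) should look roughly uniform. Systematic differences between chains, such as one chain concentrated in the high or low ranks, indicate convergence failure.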
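
What a rank plot shows can be computed directly: rank all draws of a parameter after pooling the chains, then histogram each chain's ranks. az.plot_rank draws this kind of per-chain rank histogram; the NumPy sketch below uses hypothetical simulated chains and an illustrative function name.

```python
import numpy as np

def rank_histograms(draws, n_bins=20):
    """Per-chain histograms of ranks computed over all chains pooled together."""
    draws = np.asarray(draws, dtype=float)
    n_chains, n_draws = draws.shape
    # argsort of argsort gives each draw's 0-based rank within the pooled sample
    ranks = draws.ravel().argsort().argsort().reshape(n_chains, n_draws)
    edges = np.linspace(0, n_chains * n_draws, n_bins + 1)
    return np.array([np.histogram(r, bins=edges)[0] for r in ranks])

rng = np.random.default_rng(2)
good = rng.normal(size=(4, 1000))
bad = good.copy()
bad[0] += 1.0                               # chain 0 sits higher than the others

print(rank_histograms(good).std(axis=1))    # small spread per chain: roughly flat histograms
print(rank_histograms(bad)[0])              # chain 0: counts pile up in the top bins
```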

Divergences

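NUTS flags a “divergent transition” when the error of the leapfrog integrator blows up along a trajectory, typically because it has entered a region of high curvature that the current step size cannot resolve. Divergences indicate that the geometry of the posterior is too difficult for the sampler, and estimates near those regions may be biased. Fixes: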

  1. Increase target_accept (e.g., 0.95 or 0.99)
  2. Reparameterize the model (see below)
  3. Use stronger priors to constrain problematic regions
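
Why a step size can work in one region and fail in another is easy to reproduce with the leapfrog integrator on a 1-D Gaussian with standard deviation sd. The integrator is stable only while ε < 2·sd; beyond that, the energy error grows exponentially. In a funnel, a step size adapted to the wide mouth (large sigma) violates this bound in the narrow neck (small sigma). Stan and PyMC's NUTS flag a transition as divergent when this energy error exceeds a large threshold. The function names below are illustrative.

```python
def leapfrog(q, p, eps, n_steps, grad_logp):
    """Run n_steps kick-drift-kick leapfrog steps and return the final (q, p)."""
    for _ in range(n_steps):
        p = p + 0.5 * eps * grad_logp(q)
        q = q + eps * p
        p = p + 0.5 * eps * grad_logp(q)
    return q, p

def energy_error(eps, sd, n_steps=20):
    """|Delta H| after one trajectory on a 1-D N(0, sd^2) target, started one sd from the mode."""
    def grad_logp(q):
        return -q / sd**2

    def hamiltonian(q, p):
        # potential energy (-log density, constant dropped) plus kinetic energy
        return q**2 / (2 * sd**2) + p**2 / 2

    q0, p0 = sd, 0.0
    q, p = leapfrog(q0, p0, eps, n_steps, grad_logp)
    return abs(hamiltonian(q, p) - hamiltonian(q0, p0))

eps = 0.5                          # a step size suited to the wide part of a funnel
print(energy_error(eps, sd=1.0))   # small: eps / sd = 0.5 is below the bound of 2
print(energy_error(eps, sd=0.2))   # enormous: eps / sd = 2.5, the trajectory blows up
```

This is also why raising target_accept helps: it forces a smaller step size, which stays below the stability bound in narrower regions.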

Reparameterization

Many convergence problems stem from correlated parameters or funnel geometries. The non-centered parameterization is the most important trick:

# Centered (problematic for hierarchical models)
with pm.Model():
    mu = pm.Normal('mu', 0, 10)
    sigma = pm.HalfNormal('sigma', 5)
    theta = pm.Normal('theta', mu=mu, sigma=sigma, shape=10)

# Non-centered (better geometry)
with pm.Model():
    mu = pm.Normal('mu', 0, 10)
    sigma = pm.HalfNormal('sigma', 5)
    theta_raw = pm.Normal('theta_raw', 0, 1, shape=10)
    theta = pm.Deterministic('theta', mu + sigma * theta_raw)

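The non-centered version samples theta_raw, which is independent of sigma a priori, and recovers theta through a deterministic transform. This removes the funnel (theta's spread collapsing as sigma approaches zero) that causes divergences in the centered version. It helps most when each group has little data; when the data strongly constrain every group, the centered form can sample better.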
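
The decoupling can be checked by simulating from the priors of the two models above (no data, so this is the prior funnel rather than a fitted posterior). In the centered model the sampler works with theta directly, and the log-scale distance of theta from mu moves in lockstep with log sigma; in the non-centered model it works with theta_raw, which is independent of sigma by construction. Only NumPy is used; the variable names mirror the PyMC models.

```python
import numpy as np

rng = np.random.default_rng(3)
n = 100_000

# Draw from the priors above: mu ~ N(0, 10), sigma ~ HalfNormal(5)
mu = rng.normal(0, 10, size=n)
sigma = np.abs(rng.normal(0, 5, size=n))
theta_raw = rng.normal(0, 1, size=n)   # non-centered coordinate
theta = mu + sigma * theta_raw         # centered coordinate (same distribution either way)

# Centered: how far theta sits from mu scales with sigma, which produces the funnel
corr_centered = np.corrcoef(np.log(sigma), np.log(np.abs(theta - mu)))[0, 1]
# Non-centered: theta_raw carries no information about sigma
corr_noncentered = np.corrcoef(np.log(sigma), np.log(np.abs(theta_raw)))[0, 1]

print(corr_centered, corr_noncentered)   # about 0.7 vs. about 0.0
```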

Hierarchical models

Hierarchical (multilevel) models share information across groups while allowing group-level variation. Classic example: estimating batting averages across baseball players.

import numpy as np

# Data: hits and at-bats for 18 players
hits = np.array([18, 17, 16, 15, 14, 14, 13, 12, 11, 11, 10, 10, 10, 10, 10, 9, 8, 7])
at_bats = np.full(len(hits), 45)  # every player has 45 at-bats

with pm.Model() as hierarchical_model:
    # Hyperpriors (population-level)
    alpha = pm.Gamma('alpha', alpha=1, beta=1)
    beta = pm.Gamma('beta', alpha=1, beta=1)
    
    # Player-level parameters (shrunk toward population)
    theta = pm.Beta('theta', alpha=alpha, beta=beta, shape=len(hits))
    
    # Likelihood
    y = pm.Binomial('y', n=at_bats, p=theta, observed=hits)
    
    trace = pm.sample(2000, tune=1000)

The hierarchical structure “shrinks” extreme estimates toward the group mean. The pull is stronger for players with less data (fewer at-bats); in this dataset every player has 45 at-bats, so the most extreme averages simply move the furthest. This is partial pooling, and it typically gives better estimates than both ignoring group structure (complete pooling) and treating groups independently (no pooling).
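
The dependence of shrinkage on sample size can be seen without MCMC in the conjugate case. If the population distribution were known to be Beta(a, b), each player's posterior mean would be (a + hits) / (a + b + at_bats): a weighted average of the raw average and the population mean a / (a + b), with data weight at_bats / (a + b + at_bats) that grows with at-bats. The hierarchical model above learns alpha and beta from the data instead of fixing them; the values a = 27, b = 73 and the two players below are hypothetical.

```python
import numpy as np

def beta_binomial_shrinkage(hits, at_bats, a, b):
    """Posterior means under a fixed Beta(a, b) population prior, plus the weight on the data."""
    hits = np.asarray(hits, dtype=float)
    at_bats = np.asarray(at_bats, dtype=float)
    raw = hits / at_bats
    prior_mean = a / (a + b)
    weight = at_bats / (a + b + at_bats)   # more at-bats -> more trust in the player's own data
    posterior_mean = weight * raw + (1 - weight) * prior_mean
    return raw, posterior_mean, weight

# Two hypothetical .400 hitters: one after 10 at-bats, one after 100
raw, post, w = beta_binomial_shrinkage(hits=[4, 40], at_bats=[10, 100], a=27, b=73)
print(raw)    # 0.4 for both
print(post)   # about 0.282 and 0.335: the 10-at-bat player is pulled much closer to 0.27
```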

Variational inference

When MCMC is too slow (large datasets, complex models, real-time requirements), variational inference (VI) approximates the posterior p with a simpler distribution q, chosen to minimize the KL divergence KL(q ‖ p); this is equivalent to maximizing the evidence lower bound (ELBO):

with pm.Model() as model:
    mu = pm.Normal('mu', 0, 10)
    sigma = pm.HalfNormal('sigma', 5)
    obs = pm.Normal('obs', mu=mu, sigma=sigma, observed=data)
    
    # ADVI: Automatic Differentiation Variational Inference
    approx = pm.fit(method='advi', n=30000)
    trace_vi = approx.sample(5000)

VI trades accuracy for speed. ADVI's default mean-field approximation ignores correlations between parameters and typically underestimates posterior uncertainty (it is overconfident). Use it for:

  • Initial exploration before MCMC
  • Models with very large datasets
  • Real-time inference in production
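
The overconfidence has a closed form when the posterior is Gaussian. For p = N(μ, Σ), the fully factorized Gaussian q that minimizes KL(q ‖ p) gives each coordinate a variance equal to 1 over the corresponding diagonal entry of the precision matrix Σ⁻¹. That is never larger than the true marginal variance and shrinks as correlations grow. The sketch computes this directly rather than running ADVI; the 2-D covariance is hypothetical.

```python
import numpy as np

def mean_field_variances(cov):
    """Variances of the factorized Gaussian q minimizing KL(q || p) for p = N(mu, cov)."""
    precision = np.linalg.inv(cov)
    return 1.0 / np.diag(precision)

rho = 0.9                                  # strong posterior correlation between two parameters
cov = np.array([[1.0, rho], [rho, 1.0]])   # true marginal variances are both 1
print(mean_field_variances(cov))           # both 1 - rho**2 = 0.19: sd 0.44 instead of 1
```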

NumPyro for GPU-accelerated inference

NumPyro (backed by JAX) can run MCMC on GPUs, which can give large speedups for big models; the actual gain depends on the model and hardware, and small models may see little benefit:

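import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
import jax
import jax.numpy as jnp

# Expose 4 CPU devices so the 4 chains run in parallel when no GPU is available;
# call this at program start, before JAX runs any computation
numpyro.set_host_device_count(4)

# Hypothetical example data
data = np.random.default_rng(0).normal(loc=2.0, scale=1.5, size=100)

def model(data):
    mu = numpyro.sample('mu', dist.Normal(0, 10))
    sigma = numpyro.sample('sigma', dist.HalfNormal(5))
    with numpyro.plate('data', len(data)):
        numpyro.sample('obs', dist.Normal(mu, sigma), obs=data)

kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=4000, num_chains=4)
mcmc.run(jax.random.PRNGKey(0), data=jnp.array(data))
mcmc.print_summary()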

NumPyro uses JAX to JIT-compile the model's log density and the sampler through XLA, enabling automatic vectorization and GPU execution. For models with thousands of parameters (e.g., Bayesian neural networks), a compiled, accelerator-backed sampler like this is often the most practical option.

Model comparison

WAIC and LOO-CV

Bayesian model comparison uses information criteria that account for model complexity:

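# Compare two models; model_1 and model_2 are pm.Model objects for the same observed data.
# LOO needs pointwise log-likelihoods, which PyMC >= 5 does not store by default.
with model_1:
    trace_1 = pm.sample(2000, idata_kwargs={'log_likelihood': True})

with model_2:
    trace_2 = pm.sample(2000, idata_kwargs={'log_likelihood': True})

compare = az.compare({
    'model_1': trace_1,
    'model_2': trace_2
}, ic='loo')  # PSIS leave-one-out cross-validation
print(compare)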

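LOO-CV (estimated with Pareto-smoothed importance sampling, PSIS) is generally preferred over WAIC. By default ArviZ reports the expected log pointwise predictive density (ELPD) on the log scale, where higher values indicate better predictive performance, and az.compare ranks the models from best to worst. Only on the deviance scale (−2 × ELPD) do lower values mean better.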
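
The quantity behind ic='loo' can be sketched from a matrix of pointwise log-likelihoods (posterior draws × observations), the values PyMC stores in the log_likelihood group, with chains and draws stacked. Plain importance sampling gives p(y_i | y_-i) ≈ 1 / mean_s[1 / p(y_i | θ_s)]; PSIS stabilizes exactly these weights by smoothing their upper tail with a fitted Pareto distribution, which this sketch omits. The normal model, data, and function names are hypothetical; because the model is conjugate, the exact LOO answer is available for comparison.

```python
import numpy as np

def is_loo_elpd(log_lik):
    """Pointwise LOO log predictive densities from a (n_draws, n_obs) log-likelihood matrix,
    using plain (not Pareto-smoothed) importance sampling."""
    neg = -np.asarray(log_lik, dtype=float)
    m = neg.max(axis=0)
    # log of mean_s exp(-log_lik), computed stably with the max trick
    log_mean_inv_lik = m + np.log(np.mean(np.exp(neg - m), axis=0))
    return -log_mean_inv_lik

def normal_logpdf(x, mean, var):
    return -0.5 * np.log(2 * np.pi * var) - 0.5 * (x - mean) ** 2 / var

# Hypothetical conjugate model: y_i ~ N(mu, 1) with prior mu ~ N(0, 10^2)
rng = np.random.default_rng(4)
y = rng.normal(1.0, 1.0, size=20)
prior_prec = 1 / 10**2
post_prec = prior_prec + len(y)
mu_draws = rng.normal(y.sum() / post_prec, np.sqrt(1 / post_prec), size=20_000)

log_lik = normal_logpdf(y[None, :], mu_draws[:, None], 1.0)   # shape (draws, obs)
elpd_is = is_loo_elpd(log_lik).sum()

# Exact LOO: posterior of mu without y_i, then its predictive density at y_i
loo_prec = prior_prec + len(y) - 1
loo_mean = (y.sum() - y) / loo_prec
elpd_exact = normal_logpdf(y, loo_mean, 1 / loo_prec + 1.0).sum()

print(elpd_is, elpd_exact)   # should agree closely; higher ELPD means better prediction
```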

Bayes factors

Bayes factors compare two models by the ratio of their marginal likelihoods, p(y | M1) / p(y | M2). On Jeffreys' conventional scale, a Bayes factor above 10 counts as “strong evidence.” However, Bayes factors are sensitive to the priors on each model's parameters, even vague ones, and marginal likelihoods are computationally expensive for complex models. Information criteria are usually more practical.
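
In conjugate cases the marginal likelihoods have closed forms, which makes the prior sensitivity easy to see. The sketch compares a point null p = 0.5 against p ~ Beta(a, a) for k successes in n binomial trials; the binomial coefficient cancels, leaving BF₁₀ = [B(k + a, n - k + a) / B(a, a)] / 0.5ⁿ. The data (5 successes in 10 trials) and the prior settings are hypothetical.

```python
from math import exp, lgamma, log

def log_beta(x, y):
    """Log of the Beta function B(x, y)."""
    return lgamma(x) + lgamma(y) - lgamma(x + y)

def bayes_factor_10(k, n, a, p0=0.5):
    """BF for H1: p ~ Beta(a, a) versus H0: p = p0, given k successes in n trials."""
    log_m1 = log_beta(k + a, n - k + a) - log_beta(a, a)   # marginal likelihood under H1
    log_m0 = k * log(p0) + (n - k) * log(1 - p0)           # likelihood under the point null
    return exp(log_m1 - log_m0)                            # binomial coefficients cancel

for a in (1, 10, 50):
    print(a, bayes_factor_10(k=5, n=10, a=a))
```

With a uniform prior (a = 1) the same data give BF₁₀ = 1024 / 2772 ≈ 0.37; making the prior more concentrated around 0.5 moves the Bayes factor toward 1, even though the data have not changed.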

Posterior predictive checks

After fitting, verify the model generates data that resembles the observed data:

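# `model` and `trace`: a PyMC model and the InferenceData that pm.sample returned for it
with model:
    # Adds a posterior_predictive group to `trace` in place
    pm.sample_posterior_predictive(trace, extend_inferencedata=True, random_seed=42)

az.plot_ppc(trace)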

If simulated data looks nothing like real data, the model is misspecified — regardless of how well the sampler converged.
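
A posterior predictive check can be made quantitative with a test statistic T: simulate one replicated dataset per posterior draw and report the fraction of replicates whose T is at least the observed value (a posterior predictive p-value). Values near 0 or 1 flag misfit. The sketch fits a normal model to hypothetical heavy-tailed data, using exact posterior draws under the noninformative prior p(μ, σ²) ∝ 1/σ² so that no sampler is needed, with kurtosis as T; all names are illustrative.

```python
import numpy as np

def kurtosis(x, axis=-1):
    """Sample kurtosis (3 for a normal distribution)."""
    d = x - x.mean(axis=axis, keepdims=True)
    return (d**4).mean(axis=axis) / (d**2).mean(axis=axis) ** 2

rng = np.random.default_rng(5)
y = rng.standard_t(df=2, size=1000)   # hypothetical heavy-tailed observations
n, ybar, s2 = len(y), y.mean(), y.var(ddof=1)

# Exact posterior draws for a normal model under p(mu, sigma^2) proportional to 1 / sigma^2
n_draws = 1000
sigma2 = (n - 1) * s2 / rng.chisquare(n - 1, size=n_draws)
mu = rng.normal(ybar, np.sqrt(sigma2 / n))

# One replicated dataset per posterior draw, shape (n_draws, n)
y_rep = rng.normal(mu[:, None], np.sqrt(sigma2)[:, None], size=(n_draws, n))

p_value = np.mean(kurtosis(y_rep) >= kurtosis(y))
print(kurtosis(y), p_value)   # observed kurtosis far above 3; p-value at or near 0
```

The normal model may still have converged perfectly; the check shows that it cannot reproduce the tails of the data.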

Production deployment pattern

import pickle

import numpy as np

# Fit offline
with model:
    trace = pm.sample(4000, tune=2000)

# Save posterior samples as plain NumPy arrays, pooling chains and draws
posterior_draws = {
    name: trace.posterior[name].values.reshape(-1)
    for name in ['mu', 'sigma']
}
with open('posterior.pkl', 'wb') as f:
    pickle.dump(posterior_draws, f)

# In production: load and predict (only unpickle files you trust)
with open('posterior.pkl', 'rb') as f:
    posterior = pickle.load(f)

# Posterior predictive: one simulated observation per posterior draw of (mu, sigma)
rng = np.random.default_rng()
predictions = rng.normal(posterior['mu'], posterior['sigma'])
prediction_mean = predictions.mean()
prediction_interval = np.percentile(predictions, [2.5, 97.5])

This separates expensive inference from cheap prediction, with full uncertainty propagation.

One thing to remember: Bayesian inference’s real value is not the point estimate — it is the full posterior distribution, which tells you not just “what” but “how sure,” and that uncertainty propagates honestly through every downstream decision.
