TL;DR: Approximate inference methods made easy
“MCMC vs VI” is no longer a discussion about your favourite Roman numeral. If you share my trepidation about model performance in the face of data sparsity, or you simply suffer from uncertainty anxiety, you might be tempted into the Bayesian world. Years later, at the precipice of your career (and mental health degeneration), you over-engineer probabilistic models so intractable they would stress Lord Bayes himself into stomach ulcers. The solution? Approximate inference, the true antihero of model simplification.
A closed-form solution to a machine learning model is one that can be written down on a sheet of paper using a finite number of standard mathematical operations. For example, linear models have closed-form solutions IF the Gram matrix XᵀX of the design matrix is invertible; otherwise we obtain a solution using iterative optimisation.
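For intuition, here's a minimal sketch of that closed-form route via the normal equations, using made-up data (the IF above is exactly the invertibility of XᵀX):

```python
import numpy as np

# Closed-form ordinary least squares via the normal equations.
# X and y are made-up data purely for illustration.
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))                 # design matrix: 100 samples, 3 predictors
true_w = np.array([1.5, -2.0, 0.5])
y = X @ true_w + rng.normal(scale=0.1, size=100)

# The closed form exists when X.T @ X is invertible:
w_hat = np.linalg.solve(X.T @ X, X.T @ y)
print(w_hat)                                  # ≈ [1.5, -2.0, 0.5]
```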
Bayesian models do not typically have exact closed-form solutions for their posterior distributions. One thing that typically helps is choosing simple models, Gaussian likelihood functions and conjugate priors. A prior distribution is said to be conjugate to a likelihood function if the resulting posterior belongs to the same distribution family as the prior.
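As a quick illustration of that definition (using a Beta-Bernoulli pair, an illustrative choice rather than anything discussed below):

```python
from scipy.stats import beta

# Beta(a, b) prior on a coin's bias + Bernoulli likelihood -> Beta posterior.
# The counts are made up for illustration.
a, b = 2, 2                              # prior pseudo-counts
heads, tails = 7, 3                      # observed coin flips

posterior = beta(a + heads, b + tails)   # Beta(9, 5): same family as the prior
print(posterior.mean())                  # ≈ 0.643
```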
Bayesian linear regression is a model that typically assumes a Gaussian prior over the regression coefficients and a Gaussian likelihood function. When we update the prior with the observed data (using Bayes’ theorem), the resulting posterior distribution for the regression coefficients will also follow a normal distribution. This can be written down analytically and sampled using standard methods in Python.
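Here's a minimal sketch of that analytic posterior, assuming a zero-mean isotropic Gaussian prior and a known noise variance (both illustrative choices):

```python
import numpy as np

# Conjugate posterior for Bayesian linear regression:
# prior w ~ N(0, tau^2 I), likelihood y ~ N(Xw, sigma^2 I).
rng = np.random.default_rng(1)
X = rng.normal(size=(50, 2))
y = X @ np.array([2.0, -1.0]) + rng.normal(scale=0.5, size=50)
sigma2, tau2 = 0.5 ** 2, 10.0            # noise variance, prior variance

# The posterior over w is Gaussian with closed-form mean and covariance:
post_cov = np.linalg.inv(X.T @ X / sigma2 + np.eye(2) / tau2)
post_mean = post_cov @ X.T @ y / sigma2

# ...and can be sampled with standard methods:
samples = rng.multivariate_normal(post_mean, post_cov, size=1000)
print(post_mean, samples.mean(axis=0))   # both ≈ [2.0, -1.0]
```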
Conjugacy, however, does not always guarantee tractability. High-dimensional parameter spaces, hierarchical structures, and non-Gaussian likelihoods with non-linear prior interactions can give rise to intractable integrals for the normalisation constant (which involve integration over the entire parameter space). This becomes prohibitive when we want to build, say, a multivariate Gaussian linear regression model with many predictors, or when we want to model count data using a Poisson likelihood and control for overfitting using a Laplace (non-conjugate) prior. Thankfully, a solution as old as the first computers comes to the rescue: Markov Chain Monte Carlo (MCMC).
Markov Chain Monte Carlo
Markov Chain Monte Carlo can be described with enough mathematical jargon to send one fleeing back to first-derivative optimisers, so I’ll skip the stomach ulcers and give an intuitive overview instead.
Given a probabilistic model parameterised by latent continuous random variables z, and observed values x, we can write down the joint density P(x, z), and hence the posterior P(z | x) up to its normalising constant. If P(z | x) is intractable, we want to generate an empirical distribution of samples, based on a Markov chain, that approximates the posterior.
This empirical distribution can then be used in place of the analytical solution to estimate posterior means, variances, quantiles, and other probabilistic summaries of the model parameters. The most important question: who is Markov and why are we talking about his chain?
In MCMC, a Markov chain is simply a sequence of samples where each sample is “memoryless”, i.e. the probability of transitioning to the next sample depends only on the current sample and not on the previous history. This helps us reach a “stationary” distribution over samples: once we run the chain long enough, the probability of landing in any fixed region, e.g. P(0 < z < 1 | x), should be the same at iteration n as at iteration m.
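A toy (non-MCMC) illustration of stationarity, with an arbitrary two-state transition matrix:

```python
import numpy as np

# A two-state Markov chain: the state distribution stops changing as n grows.
T = np.array([[0.9, 0.1],
              [0.4, 0.6]])    # T[i, j] = P(next state = j | current state = i)

dist = np.array([1.0, 0.0])   # start entirely in state 0
for _ in range(50):
    dist = dist @ T           # memoryless update: depends only on the current dist

print(dist)                   # ≈ [0.8, 0.2], the stationary distribution
```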
What’s amazing is that the distribution over samples from our Markov chain provides asymptotic exactness; MCMC converges to the true posterior distribution in the limit of infinite samples. How is this implemented in practice?
The Metropolis-Hastings (MH) Algorithm
Metropolis-Hastings (MH) is a specific type of MCMC algorithm used ubiquitously in approximate inference. The idea is to build a chain of samples with a proposal distribution that selects “the next” sample based only on “the current” sample (remember the Markov principle).
Proposed samples with higher probabilities in our posterior are accepted into the chain more frequently, and those with lower probabilities are rejected more often (they don’t make it into the chain). How do we capture the “tails” of our posterior if we’re busy focusing on high probability regions?
This is where the Metropolis-Hastings acceptance/rejection mechanism really shines:
- For any given proposed sample, we define the acceptance probability as min(1, α), where α is the ratio of the target and proposal densities evaluated at the proposed and current samples: α = [P(z′ | x) · q(z | z′)] / [P(z | x) · q(z′ | z)], with z the current sample, z′ the proposed one, and q the proposal distribution.
- Next comes the heart of the algorithm’s exploration-exploitation mechanism: we generate a random number u uniformly in [0, 1]. If u ≤ α, we accept the new sample into the chain; if u > α, we reject the proposal and the chain stays at the current sample (which is recorded again).
- The best bit? Proposals with acceptance probabilities close to 1 reliably move the chain towards higher probability regions, while those with low acceptance probabilities are still occasionally accepted (whenever u ≤ α), exploring lower probability areas and preventing the chain from getting stuck in local modes (a minimal implementation follows below).
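Here's a minimal sketch of that loop, using a symmetric Gaussian proposal (the Metropolis special case, where the q terms in α cancel) and an illustrative bimodal target:

```python
import numpy as np

rng = np.random.default_rng(42)

def unnorm_target(z):
    # A bimodal density known only up to its normalising constant.
    return np.exp(-0.5 * (z - 2) ** 2) + np.exp(-0.5 * (z + 2) ** 2)

def metropolis(n_samples, step=1.0, z0=0.0):
    chain = [z0]
    for _ in range(n_samples - 1):
        current = chain[-1]
        proposal = current + rng.normal(scale=step)        # depends only on current
        alpha = unnorm_target(proposal) / unnorm_target(current)
        u = rng.uniform()                                  # u in [0, 1]
        chain.append(proposal if u <= alpha else current)  # reject -> record current again
    return np.array(chain)

samples = metropolis(10_000)
print(samples.mean(), samples.std())   # ≈ 0 and ≈ 2.2 for this target, given enough samples
```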
What diagnostics do we run to check that MH successfully converged to a stationary empirical approximation to the posterior?
Trace Plots and Chain Mixing
A “trace plot” allows us to inspect the chain by plotting accepted sample values for each iteration.
What we’re looking for is low autocorrelation between successive samples, and full exploration of the sample space, characterised by high variance across moving windows of the trace. Chain 1 is an example of an ideal trace. Chain 2 initially has high autocorrelation and low variance but converges to stationarity after iteration t ≈ 1500. We discard the head segment t < 1500 (the so-called burn-in samples) since they’re unlikely to be representative of our target distribution.
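A sketch of these diagnostics, reusing the `samples` chain from the Metropolis sketch above and mirroring Chain 2's cutoff:

```python
import numpy as np
import matplotlib.pyplot as plt

burn_in = 1500                            # mirrors Chain 2's convergence point

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 3))
ax1.plot(samples)                         # trace plot: sample value per iteration
ax1.axvline(burn_in, color="red")         # everything to the left is discarded
ax1.set(xlabel="iteration", ylabel="z")

# Autocorrelation of the retained chain: we want it to decay quickly.
kept = samples[burn_in:] - samples[burn_in:].mean()
acf = np.correlate(kept, kept, mode="full")[len(kept) - 1:]
ax2.plot(acf[:100] / acf[0])
ax2.set(xlabel="lag", ylabel="autocorrelation")
plt.show()
```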
What about Chain 3? This trace demonstrates poor chain mixing: the chain moves slowly between different regions of the parameter space. One problem could be that we just haven’t let the algorithm run long enough; but as model complexity increases with multimodality, high dimensionality and correlated parameters, the asymptotic-exactness guarantee of MCMC doesn’t come with a tqdm, and you could be waiting for quite a while. In these cases, we present the next best thing: variational inference.
Variational Inference
While MCMC offers asymptotic exactness, it can be computationally intensive and impractical for complex, high-dimensional distributions. Often we’re just interested in a rough approximation to the posterior that scales well for deployment.
Variational Inference (VI) frames the problem of approximating the posterior as an optimisation problem. Starting with a synthetic posterior Q(z | λ) built from a family of simpler distributions (known as the variational family), we optimise over parameters λ to minimise the distance between Q(z | λ) and the true posterior P(z | x). This sounds cool, but how do we choose the variational family? What even is a distance between distributions?
KL-Divergence
The choice of Q(z | λ) depends on the degree of flexibility required (increasing with the complexity of P(z | x)), but common choices are exponential-family or Gaussian distributions. To compare the “closeness” of P and Q, we employ a similarity measure such as the Kullback-Leibler (KL) divergence: DKL(Q||P) = E[log Q(z | λ)] − E[log P(z | x)], where both expectations are taken under Q.
Although the KL divergence is asymmetric (DKL(Q||P) ≠ DKL(P||Q)), it helps to quantify the difference between Q and P.
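A quick numerical check of that asymmetry, estimated on a grid for two illustrative Gaussians:

```python
import numpy as np
from scipy.stats import norm

# KL divergence between two 1D Gaussians, estimated on a dense grid.
z = np.linspace(-10, 10, 10_001)
dz = z[1] - z[0]
p = norm.pdf(z, loc=0, scale=1)
q = norm.pdf(z, loc=1, scale=2)

kl_qp = np.sum(q * np.log(q / p)) * dz   # DKL(Q||P) ≈ 1.31
kl_pq = np.sum(p * np.log(p / q)) * dz   # DKL(P||Q) ≈ 0.44
print(kl_qp, kl_pq)                      # different numbers: KL is asymmetric
```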
“But we don’t have a closed-form solution for P(z | x)!” I hear you exclaiming, correctly. That’s why we compute something called the Evidence Lower Bound (ELBO) instead: ELBO(λ) = log P(x) − DKL(Q(z | λ)||P(z | x)). This can be rearranged as ELBO(λ) = E[log P(x, z)] − E[log Q(z | λ)], which helps us avoid that pesky intractable marginal integral.
Thus, maximising the ELBO is equivalent to minimising the KL divergence, and it serves as an objective function we can optimise using standard methods like coordinate ascent. Once the parameters λ are optimised, the approximating distribution Q(z | λ) serves as a surrogate for the true posterior. This approximation can then be used for downstream tasks like prediction, data imputation, or model interpretation.
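To make that concrete, here's a minimal sketch that maximises a Monte Carlo estimate of the ELBO over a grid of variational parameters (a stand-in for coordinate ascent or gradient methods; the Gaussian target is an illustrative choice so we can verify the answer):

```python
import numpy as np

# Fit Q(z | mu, sigma) = N(mu, sigma^2) by maximising a Monte Carlo ELBO.
rng = np.random.default_rng(7)

def log_joint(z):
    # log P(x, z) up to an additive constant (the intractable log P(x)).
    # Gaussian with mean 1 and std 1.5, so the optimum is known.
    return -0.5 * ((z - 1.0) / 1.5) ** 2

eps = rng.normal(size=5_000)                   # shared noise across candidates
best, best_elbo = None, -np.inf
for mu in np.linspace(-1, 3, 41):
    for sigma in np.linspace(0.5, 3.0, 26):
        z = mu + sigma * eps                   # samples from Q(z | mu, sigma)
        entropy = 0.5 * np.log(2 * np.pi * np.e * sigma ** 2)
        elbo = log_joint(z).mean() + entropy   # E_Q[log P(x, z)] - E_Q[log Q]
        if elbo > best_elbo:
            best, best_elbo = (mu, sigma), elbo

print(best)   # ≈ (1.0, 1.5): Q recovers the true posterior here
```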