I've seen a lot of blogs explaining approximate inference algorithms, but haven't found any that actually show the implementation details. This is going to be my attempt to try and explain approximate inference techniques while implementing them in Python from scratch.
Interactive Colab Notebook
Let's consider a toy model to work with. Imagine that I give you a coin and flip it, and it lands heads. I tell you that the bias of the coin is some variable $z$. This means that the probability that your coin lands heads is $z$. We don't know what $z$ is, but given an observation from the coin, we would like to infer a distribution over $z$.
In inference parlance, $z$ is known as a latent variable and $x$ is known as an observation. For this example, we will assume that all $z$'s are equally likely, i.e., we have a uniform prior over our latents.
In all, we have access to,

$$Z \sim \mathrm{Uniform}(0, 1), \qquad P(X = \mathrm{HEADS} \mid Z = z) = z$$
We would like to compute a posterior distribution over our latent, given some observation, $P(Z = z \mid X = x)$.
Let's code our toy model in Python. I'm assuming we have access to the NumPy and SciPy packages. Let's first define a way to sample latents from our prior,
import numpy as np

def prior():
    return np.random.uniform(0, 1)
And our model,
def model(z):
if np.random.random() < z:
return "HEADS"
return "TAILS"
To generate samples from our model we can,
z = prior()
result = model(z)
Before we dive into approximate inference algorithms, let's try to analytically derive a closed-form expression for our desired posterior. I'm going to try to be very explicit about notation for clarity.
Using Bayes' rule,

$$P(Z = z \mid X = \mathrm{HEADS}) = \frac{P(X = \mathrm{HEADS} \mid Z = z)\, P(Z = z)}{\int_0^1 P(X = \mathrm{HEADS} \mid Z = z')\, P(Z = z')\, dz'} = \frac{z \cdot 1}{\int_0^1 z'\, dz'} = 2z$$
If we graph our posterior, the density looks like this,
Plotting the formula of our analytical posterior.
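For reference, here's a minimal sketch of how a plot like this could be produced (assuming matplotlib is available; the density $2z$ comes from the derivation above),

import numpy as np
import matplotlib.pyplot as plt

# Analytical posterior density P(Z = z | X = HEADS) = 2z on [0, 1].
zs = np.linspace(0, 1, 200)
plt.plot(zs, 2 * zs)
plt.xlabel("z")
plt.ylabel("density")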
This makes sense: if the only observation we had about the coin was that it landed heads, we'd want to assign higher probability mass towards larger values of $z$.
Since this was a toy example, computing the exact posterior in closed form was easy, but usually, this is not the case. In particular, if we didn't have access to a closed-form expression for our model, or if $z$ was very high dimensional, then marginalizing over all possible $z$ (as in the denominator) would be intractable.
We'll now discuss several approximate inference techniques. A lot of these methods are actually sampling techniques: they are ways of generating samples from some distribution that we can't directly sample from. Usually what is assumed is that we know,

$$p(x) = \frac{f(x)}{Z}$$

Where $Z$ is some normalizing constant. So we lowkey know what the density should be for some particular $x$, just by evaluating $f(x)$. But generating actual samples is hard.
A quick note about my notation abuse here: I use the variable $x$ here not to mean the observation, but just as a generic parameter into $f$.
This is actually very useful for our purposes, since our task in the inference setting is to generate samples from the posterior distribution, and we know what this is proportional to,

$$P(Z = z \mid X = x) \propto P(X = x \mid Z = z)\, P(Z = z)$$
Often, if we don't have access to an underlying model, we can construct an unnormalized likelihood function that acts more like a scoring function, without worrying about whether it's an actual density, since the normalizing constants can be written off. This has pretty neat parallels to optimization in general, but that's a story for another day 🙂.
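To make that concrete, here's a toy unnormalized density (my own example, separate from the coin model): we can evaluate it pointwise and compare relative densities, even though we never compute its normalizing constant,

import numpy as np

def f_unnormalized(x):
    # Proportional to a standard normal density; the 1 / sqrt(2 * pi)
    # normalizing constant is deliberately left out.
    return np.exp(-0.5 * x ** 2)

# Relative densities are still meaningful without the constant.
print(f_unnormalized(0.0) / f_unnormalized(1.0))  # ~1.65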
The core idea of rejection sampling is to generate samples through our model, and only accept samples that match a particular condition. If you had some latent $z$ generating some observation $x$, try a whole bunch of $z$'s from our prior and see which ones produce our target observation $x$. Then plot the distribution over the accepted $z$'s.
Here is how to do this in Python,
observation = "HEADS"
iterations = 100000
samples = []
for _ in range(iterations):
z = prior()
result = model(z)
if result == observation:
samples.append(z)
_ = plt.hist(samples, bins=50)
This produces samples from our posterior that match our analytical posterior very closely,
Samples from rejection sampling.
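As a quick sanity check (my own addition), the analytical posterior $2z$ has mean $\int_0^1 z \cdot 2z \, dz = 2/3$, so the sample mean should land close to that,

print(np.mean(samples))  # should be close to 2/3 ≈ 0.667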
One thing to think about as an exercise for the reader is what happens when our model outputs real-valued observations. In that case, getting `result == observation` has near-zero probability, and we'd get 0 samples. (Hint: how can we use values from our likelihood function?)
An issue with rejection sampling is that if we have a high-dimensional latent variable, with our posterior concentrated in relatively few locations across the latent space, we'd have to get incredibly lucky to even generate a few samples. The convergence time hence scales exponentially as the number of dimensions increases.
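To get a feel for this, here's a rough sketch (my own toy extension, not part of the original model): give each of $d$ coins its own bias drawn from the prior, observe HEADS on every one, and watch the acceptance rate collapse roughly like $(1/2)^d$,

import numpy as np

def acceptance_rate(d, iterations=100000):
    # Fraction of prior samples whose d coin flips all come up HEADS.
    accepted = 0
    for _ in range(iterations):
        zs = np.random.uniform(0, 1, size=d)   # one latent bias per coin
        flips = np.random.random(size=d) < zs  # flip each coin once
        if flips.all():                        # accept only if every flip is HEADS
            accepted += 1
    return accepted / iterations

for d in [1, 5, 10, 20]:
    print(d, acceptance_rate(d))  # roughly (1/2) ** d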
The core idea behind Markov chain Monte Carlo (MCMC) is to construct a Markov chain whose steady state approximates our posterior. All states of this Markov chain are in the domain of the target distribution we're trying to sample from; in our case, that target is a conditional distribution over latents.
Let's start with some state $x_0$, and have some transition function that generates subsequent states. Markov chains (under the right conditions) end up having a steady state. Once the chain reaches that steady state, we'll pretend all subsequent states are actually samples from our target distribution.
But how do we construct such a Markov chain so that its steady state ends up being our target distribution? There are many algorithms for this, including Gibbs sampling and Hamiltonian Monte Carlo. We'll discuss the most widespread and simplest one, called Metropolis-Hastings.
I'll connect this to the conditional stuff later, but as before, let's say our goal is to generate samples from some target distribution for which we only have access to the unnormalized density function, i.e.,

$$p(x) = \frac{f(x)}{Z}$$

So we only have access to $f$ (this is important because, in Bayesian inference, the normalizing constant $Z$ is intractable).
Step 1: Work with a known distribution that will serve as our transition proposal distribution, $g(x' \mid x)$. We want $g$ to be an easy distribution to sample from. Note that $x'$ isn't actually the next state, it's just a proposal. Consider an acceptance function $A(x', x)$ that gives us the probability of accepting this transition. So to simulate a random walk, we start with some initial state $x_0$, run it through $g$ to get a proposal $x'$, and accept $x_1 = x'$ if a random coin flip with probability $A(x', x_0)$ lands heads; if not, we set $x_1 = x_0$, and continue.
Step 2: Consider the detailed balance condition. Let's say we have some states $x$ and $x'$ and we are considering a transition from state $x$ to $x'$. This condition tells us when we have converged and is stated by,

$$p(x)\, T(x \to x') = p(x')\, T(x' \to x)$$
Where $T(x \to x')$ is the transition probability from $x$ to $x'$. Substituting what we know,

$$p(x)\, g(x' \mid x)\, A(x', x) = p(x')\, g(x \mid x')\, A(x, x')$$
Note that we actually need the PDF of $g$ to compute the $g(x \mid x')$ term. Therefore, we pick $g$ to be very nice to work with. For instance, if we pick $g$ to be from a normal family, then $g(x \mid x')$ is the PDF evaluated at $x$ for a normal $\mathcal{N}(x', \sigma)$. Here $\sigma$ is a hyperparameter we can choose and is also known as the proposal width.
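To make the proposal concrete, here's a small sketch of sampling and evaluating such a normal proposal with SciPy (`proposal_width` is my name for the hyperparameter above; the implementation below just uses SciPy's default width of 1),

import scipy.stats

x_current = 0.4       # current state of the chain
proposal_width = 0.5  # the proposal width hyperparameter

g = scipy.stats.norm(loc=x_current, scale=proposal_width)
candidate = g.rvs()          # propose a new state
density = g.pdf(candidate)   # the proposal density g(candidate | x_current)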
We can now do some algebra and arrive at,

$$\frac{A(x', x)}{A(x, x')} = \frac{p(x')\, g(x \mid x')}{p(x)\, g(x' \mid x)} = \frac{f(x')\, g(x \mid x')}{f(x)\, g(x' \mid x)}$$

Note that the normalizing constant $Z$ cancels in the ratio, which is why only having access to $f$ is enough.
Let,

$$r = \frac{f(x')\, g(x \mid x')}{f(x)\, g(x' \mid x)}$$
Then we must set,

$$A(x', x) = \min(1, r)$$

You can check that this choice satisfies the ratio above regardless of whether $r$ is smaller or larger than 1.
Here is a Python implementation of the above. We start with some unnormalized density function `f` and some starting state `x0`, and then use the proposal distribution and acceptance function to perform a random walk,
import scipy.stats

def metropolis_hastings(f, x0, iterations=100000):
"""Generate samples from a given unnormalized density function using Metropolis-Hastings.
Args:
f: A function that gives us unnormalized density.
x0: The starting state for the Markov chain.
iterations: Number of iterations to run the random walk.
Returns:
An array of samples.
"""
# Initialize our sample collection list.
samples = [x0]
for _ in range(iterations):
last_sample = samples[-1]
# By default the new_sample will be our last_sample
# unless we decide to move.
new_sample = last_sample
# Proposal density.
g = scipy.stats.norm(loc=last_sample)
# Propose a candidate.
candidate = g.rvs()
# Compute the acceptance probability.
rf = f(candidate) / f(last_sample)
rg = g.pdf(last_sample) / g.pdf(candidate)
acceptance = min(1, rf * rg)
# Flip a coin to decide if we decide to move to
# this new candidate.
if np.random.random() < acceptance:
new_sample = candidate
samples.append(new_sample)
return samples
To actually use this, we need an unnormalized density. Observe that for our conditional case, this is just,

$$f(z) = P(X = x \mid Z = z)\, P(Z = z) \propto P(Z = z \mid X = x)$$
The numerator in Bayes' rule, $P(X = x \mid Z = z)\, P(Z = z) = P(X = x, Z = z)$, is also known as the joint, from the definition of conditional probability. In Python, we can implement this as,
observation = "HEADS"
prior_dist = scipy.stats.uniform(0, 1)
def joint(z):
"""Compute P(X = x | Z = z) P(Z = z) for some fixed global `observation`.
Intuitively, it's computing for some latent `z`, what's the probability we get
our observation.
"""
if not (0 <= z <= 1):
return 0
prior_z = prior_dist.pdf(z)
if observation == "HEADS":
return z * prior_z
if observation == "TAILS":
return (1 - z) * prior_z
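As a quick check (my own addition), for the HEADS observation the joint should just equal $z$ on $[0, 1]$ and vanish outside of it,

print(joint(0.25))  # 0.25 (= 0.25 * 1, since the uniform prior density is 1 here)
print(joint(1.5))   # 0 (outside the support of the prior)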
Now we can use our Metropolis-Hastings sampler on this joint to sample from our posterior. The first few samples MCMC gives us won't be very accurate, so we won't plot those. This is called burn-in: the time it takes for the chain to converge to a point where it produces valid samples.
We can try to do posterior inference using,
samples = metropolis_hastings(joint, 0.5)
_ = plt.hist(samples[1000:], bins=50)
And this yields,
Samples from MCMC sampling.
The MCMC sampler I wrote above is probably the slowest one you can find on the internet, but it serves a pedagogical purpose. PyMC3 implements a solid one, along with a host of other samplers. You can get fancy with MCMC and probabilistic programming and do a random walk over program execution traces to provide a black-boxy interface. This chapter does a great job explaining how WebPPL implements MCMC.