A Hacker's Guide to Approximate Inference

31 January 2021 · 11 mins

I've seen a lot of blog posts explaining approximate inference algorithms, but haven't found any that actually show the implementation details. This is my attempt to explain approximate inference techniques while implementing them in Python from scratch.

Interactive Colab Notebook

Example Model

Let's consider a toy model to work with. Imagine that I give you a coin and flip it, and it lands heads. I tell you that the bias of the coin is some variable $z$. This means that the probability that your coin lands heads is $z$. We don't know what $z$ is, but given an observation from the coin, we would like to infer a distribution over $z$.

In inference parlance, $z$ is known as a latent variable and $x$ is known as an observation. For this example, we will assume that all $z$'s are equally likely, i.e., we have a uniform prior over our latents.

In all, we have access to,

  • A prior over our latent, $p(z) \sim U[0, 1]$.
  • A model for how observations are generated given a particular latent, $P(X = x | Z = z)$. This is fairly straightforward: if we know the bias of the coin is $z$, then $P(X = \text{heads} | Z = z) = z$, and $1 - z$ for the tails case.

We would like to compute a posterior distribution over our latent, given some observation, $P(Z | X)$.

Example Model in Python

Let's code our toy model in Python. I'm assuming we have access to the NumPy and SciPy packages. Let's first define a way to sample latents from our prior,

import numpy as np

def prior():
    return np.random.uniform(0, 1)

And our model,

def model(z):
    # Flip a coin whose probability of landing heads is `z`.
    if np.random.random() < z:
        return "HEADS"
    return "TAILS"

To generate samples from our model we can,

z = prior()
result = model(z)

Analytical Posterior

Before we dive into approximate inference algorithms, let's try to analytically derive a closed-form expression for our desired posterior. I'm going to try to be very explicit about notation for clarity.

Using Bayes' rule,

\begin{align}
P(Z = z | X = \text{heads}) &= \frac{P(X = \text{heads} | Z = z) P(Z = z)}{P(X = \text{heads})} \\
&= \frac{z \times 1}{\int_0^1 z \ dz} \\
&= \boxed{ 2 z }
\end{align}

If we graph our posterior, the density looks like this,

Figure 1: Plotting the formula of our analytical posterior.

This makes sense: if the only observation we had about the coin was that it landed heads, we'd want to assign higher probability mass towards $z = 1$.
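For reference, here's a minimal sketch of how one could produce a plot like the one above (assuming matplotlib is available),

import numpy as np
import matplotlib.pyplot as plt

# Plot the analytical posterior density p(z | heads) = 2z.
z = np.linspace(0, 1, 200)
plt.plot(z, 2 * z)
plt.xlabel("z")
plt.ylabel("p(z | heads)")
plt.show()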

Since this was a toy example, computing the exact posterior in closed form was easy, but usually this is not the case. In particular, if we didn't have access to a closed-form expression for our model, or if $z$ were very high-dimensional, then marginalizing over all possible $z$ (as in the denominator) would be intractable.

Approximate Inference

We'll now discuss several approximate inference techniques. A lot of these methods are actually sampling techniques: they are ways of generating samples from some distribution $p$ that we can't directly sample from. Usually what is assumed is that we know,

\begin{align}
p(x) = \frac{f(x)}{K}
\end{align}

Where $K$ is some normalizing constant. So we basically know what the density should be at any particular $x$, just by evaluating $f(x)$. But generating actual samples is hard.

A quick note about my notation abuse: I use the variable $x$ here not to mean the observation, but just as a generic argument to $f(x)$.

This is actually very useful for our purposes, since our task in the inference setting is to generate samples from the posterior distribution, $P(Z | X = x)$, and we know what this is proportional to,

\begin{align}
P(Z | X = x) &= \frac{P(X = x | Z) P(Z)}{P(X)} \\
\implies P(Z | X = x) &\propto P(X = x | Z) P(Z)
\end{align}

Often, we can construct an unnormalized likelihood function $P(X = x | Z)$ (more like a scoring function, if we don't have access to an underlying model) without worrying about whether it's an actual density function, since the normalizing constants can be written off. This has pretty neat parallels to optimization in general, but that's a story for another day 🙂.
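To make this concrete with our coin: for a heads observation, the unnormalized posterior is simply $f(z) = P(X = \text{heads} | Z = z) \, p(z) = z \times 1 = z$ on $[0, 1]$, and the normalizing constant is the $\int_0^1 z \ dz = 1/2$ we computed earlier, recovering the $2z$ posterior.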

Rejection Sampling

The core idea is to generate samples through our model and only accept the ones that match a particular condition. If some latent $z$ generates an observation $x$, we try a whole bunch of $z$'s from our prior, keep the ones whose simulated result matches our target observation $x$, and plot the distribution over the accepted $z$'s.

Here is how to do this in Python,

observation = "HEADS"
iterations = 100000

samples = []

for _ in range(iterations):
    z = prior()
    result = model(z)
    
    if result == observation:
        samples.append(z)
        
_ = plt.hist(samples, bins=50)

This produces samples from our posterior that match our analytical posterior very closely,

Figure 2: Samples from rejection sampling.

One thing to think about, as an exercise for the reader, is what happens when our model outputs real-valued observations. In that case, result == observation has near-zero probability, and we'd get 0 samples. (Hint: how can we use values from our likelihood function?)
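One possible answer, sketched back on our coin model: instead of requiring an exact match, accept each `z` with probability given by its likelihood. This is only a minimal sketch, and it works directly only when the likelihood is bounded above by a known constant (here it's at most 1),

observation = "HEADS"
iterations = 100000

samples = []

for _ in range(iterations):
    z = prior()

    # Likelihood of our observation under this latent.
    likelihood = z if observation == "HEADS" else 1 - z

    # Accept z with probability equal to its likelihood.
    if np.random.random() < likelihood:
        samples.append(z)

_ = plt.hist(samples, bins=50)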

An issue with rejection sampling is that if we have a high-dimensional latent variable, with our posterior concentrated in relatively few locations across the latent space, we'd have to get incredibly lucky to generate even a few samples. The convergence time hence scales exponentially as the number of dimensions increases.

Markov Chain Monte Carlo (MCMC)

The core idea here is to construct a Markov chain whose steady state approximates our posterior. All states of this Markov chain are in the domain of the target distribution we're trying to sample from; in our case, we're trying to sample from a conditional distribution over latents.

We start with some state $x_0$ and use some transition function to generate subsequent states. The Markov chain will eventually reach a steady state. Once it reaches that steady state, we'll treat all subsequent states as if they were actual samples from our target distribution.
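To make the steady-state idea concrete, here's a tiny, hypothetical example (the transition probabilities are made up) of simulating a two-state Markov chain; the fraction of time spent in each state settles down regardless of where we start,

import numpy as np

# Transition probabilities for a two-state chain: from state 0 we move
# to state 1 with probability 0.3, and from state 1 to state 0 with 0.1.
P = np.array([[0.7, 0.3],
              [0.1, 0.9]])

state = 0
visits = np.zeros(2)

for _ in range(100000):
    state = np.random.choice(2, p=P[state])
    visits[state] += 1

# Approaches the steady-state distribution [0.25, 0.75].
print(visits / visits.sum())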

But how do we construct a Markov chain whose steady state ends up being $P(Z | X = x)$? There are many algorithms for this, including Gibbs sampling and Hamiltonian Monte Carlo. We'll discuss the simplest and most widespread one, called Metropolis-Hastings.

I'll connect this to the conditional case later, but as before, let's say our goal is to generate samples from some target distribution $p(x)$ for which we only have access to the unnormalized density function, i.e.,

\begin{align}
p(x) = \frac{f(x)}{K}
\end{align}

So we only have access to $f(x)$. (This is important because, in Bayesian inference, the normalizing constant is usually intractable.)

Step 1: Work with a known distribution $g$ that will serve as our transition proposal distribution, proposing $x_{t+1} \sim g(x_{t+1} | x_t)$. We want $g$ to be an easy distribution to sample from and evaluate. Note that $x_{t+1}$ isn't actually the next state, it's just a proposal. Consider an acceptance function, $A(x_t \to x_{t+1})$, that gives us the probability of accepting this transition. So to simulate a random walk, we start with some initial state $x_0$, propose $x_1 \sim g(x_1 | x_0)$, and accept it if a random coin flip with probability $A(x_0 \to x_1)$ lands heads; if not, we set $x_1 = x_0$ and continue.

Step 2: Consider the detailed balance condition. Let's say we have some states $a$ and $b$, and we are considering a transition from state $a$ to $b$. This condition guarantees that $p$ is the steady-state distribution of our chain, and it is stated as,

\begin{align*}
p(a) T(a \to b) = p(b) T(b \to a)
\end{align*}

Where $T(a \to b)$ is the transition probability from $a$ to $b$. Substituting what we know ($p = f / K$, so the $K$'s cancel, and splitting each transition into a proposal and an acceptance),

\begin{align*}
f(a) g(b | a) A(a \to b) = f(b) g(a | b) A(b \to a)
\end{align*}

Note that we actually need the PDF of $g$ to compute the $g(b | a)$ term. Therefore, we pick $g$ to be very nice to work with. For instance, if we pick $g$ from the normal family, then $g(b | a)$ is the PDF of a normal $\mathcal{N}(a, \sigma^2)$ evaluated at $b$. Here $\sigma^2$ is a hyperparameter we can choose, also known as the proposal width.
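As a quick, hypothetical illustration of what this looks like in SciPy (the variable names here are made up for the example),

import scipy.stats

# Proposal distribution centered at the current state `a`, with
# standard deviation `proposal_width` playing the role of sigma.
a = 0.3
proposal_width = 0.5
g = scipy.stats.norm(loc=a, scale=proposal_width)

b = g.rvs()         # sample a proposal b ~ g(. | a)
density = g.pdf(b)  # evaluate the PDF g(b | a)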

We can now do some algebra and arrive at,

\begin{align*}
\frac{A(a \to b)}{A(b \to a)} = \frac{f(b)}{f(a)} \times \frac{g(a | b)}{g(b | a)}
\end{align*}

Let,

\begin{align*}
r_f &= \frac{f(b)}{f(a)} \\
r_g &= \frac{g(a | b)}{g(b | a)}
\end{align*}

Then we can satisfy this condition by setting $A(a \to b)$ to,

\begin{align*}
A(a \to b) = \min(1, r_f \times r_g)
\end{align*}

(To check: if $r_f \times r_g \le 1$, then $A(a \to b) = r_f \times r_g$ and $A(b \to a) = 1$, so their ratio is exactly $r_f \times r_g$ as required; the case $r_f \times r_g > 1$ is symmetric.)

Here is a Python implementation of the above. We start with some unnormalized density function f and some starting state x0, and then use the proposal distribution and acceptance function to perform a random walk,

import scipy.stats

def metropolis_hastings(f, x0, iterations=100000):
    """Generate samples from a given unnormalized density function using Metropolis-Hastings.

    Args:
        f:  A function that gives us unnormalized density.
        x0: The starting state for the Markov chain.
        iterations: Number of iterations to run the random walk.

    Returns:
        An array of samples.
    """
    
    # Initialize our sample collection list.
    samples = [x0]
    
    for _ in range(iterations):
        last_sample = samples[-1]

        # By default the new_sample will be our last_sample
        # unless we decide to move.
        new_sample = last_sample

        # Proposal density g(. | last_sample), a normal centered at
        # the last sample (the proposal width defaults to 1).
        g = scipy.stats.norm(loc=last_sample)

        # Propose a candidate.
        candidate = g.rvs()

        # Proposal density g(. | candidate), used for the reverse move.
        g_reverse = scipy.stats.norm(loc=candidate)

        # Compute the acceptance probability. Note that for a symmetric
        # proposal like this one, rg is always 1.
        rf = f(candidate) / f(last_sample)
        rg = g_reverse.pdf(last_sample) / g.pdf(candidate)
        acceptance = min(1, rf * rg)
        
        # Flip a coin to decide if we decide to move to
        # this new candidate.
        if np.random.random() < acceptance:
            new_sample = candidate

        samples.append(new_sample)
        
    return samples

To actually use this, we need an unnormalized density. Observe that for our conditional case, this is just,

\begin{align*}
f(z) &= P(X = x | Z = z) P(Z = z) \\
&\propto P(Z = z | X = x)
\end{align*}

The numerator in Bayes' rule is also known as the joint, $P(X = x, Z = z)$, from the definition of conditional probability. In Python, we can implement this as,

observation = "HEADS"
prior_dist = scipy.stats.uniform(0, 1)

def joint(z):
    """Compute P(X = x | Z = z) P(Z = z) for a fixed global `observation`.

    Intuitively, for a given latent `z`, this computes how likely we are to
    see our observation, weighted by the prior.
    """

    if not (0 <= z <= 1):
        return 0

    prior_z = prior_dist.pdf(z)

    if observation == "HEADS":
        return z * prior_z

    if observation == "TAILS":
        return (1 - z) * prior_z

Now we can use our Metropolis-Hastings sampler on this joint to sample from our posterior. The first few samples MCMC gives us won't be very accurate, so we won't plot those. Discarding them is called burn-in: it accounts for the time it takes the chain to converge to its steady state, after which it produces valid samples.

We can try to do posterior inference using,

samples = metropolis_hastings(joint, 0.5)
_ = plt.hist(samples[1000:], bins=50)

And this yields,

Figure 3: Samples from MCMC sampling.

The MCMC sampler I wrote above is probably the slowest one you can find on the internet, but it serves a pedagogical purpose. PyMC3 implements a solid one, along with a host of other samplers. You can get fancy with MCMC and probabilistic programming and do a random walk over program execution traces to provide a black-boxy interface. This chapter does a great job explaining how WebPPL implements MCMC.
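For a sense of what that looks like, here's a rough sketch of our coin model in PyMC3 (not run here, and API details may vary across versions),

import pymc3 as pm

with pm.Model():
    # Uniform prior over the coin's bias.
    z = pm.Uniform("z", 0, 1)

    # We observed a single flip that landed heads (encoded as 1).
    x = pm.Bernoulli("x", p=z, observed=1)

    # Let PyMC3 pick a sampler and draw posterior samples of z.
    trace = pm.sample(2000)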
