Variational Inference in 10 Lines

Leggo my ELBO

  1. One day you come home and realize it stinks like rotten fruit.
  2. You decide the smell right now is at 60% of your maximum smell tolerance (the point at which you, of course, puke). So the smell intensity \(x\) right now is \(0.6\).
  3. Not in much of a hurry, you try to come up with a set of possible causes for this smell incident.
  4. You isolate three factors \(z_1\), \(z_2\), and \(z_3\) – that apple on your desk, your roommate's banana, and the strawberry in the fridge you forgot to give your niece when she visited last week.
  5. Together, those three components constitute a state vector \(z = [z_1, z_2, z_3]\). Each dimension of \(z\) denotes the rottage progress of each fruit, in the range \([0, 1]\).
  6. A state of \([0.2, 0.9, 0.1]\), for example, describes a universe where the apple is 20% rotten, the banana is 90% rotten, and the strawberry is 10% rotten.
  7. Ultimately you want to model \(P(z \vert x)\), which means \(P(\text{causes} \vert \text{smell intensity})\). Maybe next time, you'll walk into your room, embrace the smell, wash your hands (because Covid), and use \(P(z \vert x)\) to easily identify the causes of the smell (\(z\)) given the stink intensity (\(x\)). And maybe do something about it, too.
  8. Being a good Bayesian statistician, you know that \(P(z \vert x) = P(x \vert z)P(z)/P(x)\). So all you need to figure out \(P(z \vert x)\) are \(P(x \vert z)\), \(P(z)\), and \(P(x)\). Seems a bit Goldbergian to go from calculating a single variable to three, eh? But please don't interrupt.
  9. \(P(x \vert z)\) is whatever you say it is. Because you need some assumptions to start your analysis from. Some people call this... [IN THUNDEROUS VOICE] THE MODEL.
  10. \(P(z)\) is also whatever you decide it is. Because who out there has the right to dictate the format in which you jot down how rotten your fruits are? Maybe you want \(z_1\) to be either ROTTEN (\(-99999\)) or NOT ROTTEN (\(42\)). Maybe you are more sensible and use the scheme we described in #5.
  11. The most unfortunate thing about all this is that the last piece of the puzzle, \(P(x)\), is not calculable. You were so close... and yet so far. \(P(x)\), by definition, is \(\iiint_z P(x \vert z)P(z)\,dz_1\,dz_2\,dz_3\). Unfortunately, we as humankind are not smart enough to calculate \(\iiint_z \text{whatever}\,dz_1\,dz_2\,dz_3\) when \(z\) is high-dimensional.
  12. So you either use variational inference or don't ever buy more than a handful of fruits unless you want to go back to that stinky life, helpless and unknowledgeable as to which berries are causing all the stench.
  13. At this point, you've abandoned all hope of calculating \(\iiint_z \text{whatever}\,dz_1\,dz_2\,dz_3\) and just assume \(P(z \vert x)\) to be Gaussian. We call this faux \(P(z \vert x)\), \(Q(z)\).
  14. Not to be defeatist about all this, but since we're assuming a faux distribution \(Q(z)\), there is bound to be some error in our calculation of \(P(z \vert x)\). That error can be expressed as the distance from \(Q(z)\) to \(P(z \vert x)\). And we use Kullback-Leibler Divergence to measure the distance between two probability distributions: \(D_{KL}[Q(z) \| P(z \vert x)]\).
  15. (Why \(Q\) to \(P\) and not \(P\) to \(Q\)? That is a discussion for the next installment of this series, "Forward and Reverse KLD, explained in 3 lines or less". For now, it suffices to say \(Q\) to \(P\) is like YOLO and \(P\) to \(Q\) is like anti-YOLO. And we want YOLO.)
  16. \(D_{KL}[Q(z) \| P(z \vert x)]\) is equal to \(E_q[\log Q(z) - \log P(x, z)] + \log P(x)\) (derivation). If you notice, \(D_{KL}[Q(z) \| P(z \vert x)]\) is never negative, and \(\log P(x)\) is not something we have control over. If you could control \(x\) (the smell), why would you be calculating all this in the first place? So, in order to minimize \(D_{KL}[Q(z) \| P(z \vert x)]\), we have to maximize \(-E_q[\log Q(z) - \log P(x, z)]\), which we dub... [IN THUNDEROUS VOICE] THE ELBO.
  17. So you see, if some parameter configuration of \(Q(z)\) maximizes the ELBO, that \(Q(z)\) is the closest stand-in for the \(P(z \vert x)\) we were after all along.
  18. That's the beauty of variational inference. You abandoned calculating \(P(z \vert x)\) but you still arrived at (a decent approximation of) \(P(z \vert x)\).
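
For the skeptics, the identity in step 16 takes about three lines, assuming nothing beyond the definition of KL divergence and \(P(z \vert x) = P(x, z)/P(x)\):

\[
\begin{aligned}
D_{KL}[Q(z) \| P(z \vert x)] &= E_q[\log Q(z) - \log P(z \vert x)] \\
&= E_q[\log Q(z) - \log P(x, z) + \log P(x)] \\
&= E_q[\log Q(z) - \log P(x, z)] + \log P(x)
\end{aligned}
\]

The last step works because \(\log P(x)\) does not depend on \(z\), so it slides out of the expectation untouched. Rearranged, this says \(\log P(x) = \text{ELBO} + D_{KL}[Q(z) \| P(z \vert x)]\). Since the KL term is never negative, the ELBO can never exceed \(\log P(x)\), the log evidence. Hence the name: Evidence Lower BOund.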

That was 18 lines. Sue me.
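
If you would rather see the fruit than read about it, here is a minimal sketch of steps 9 through 17 in plain Python. Everything specific in it is an assumption made up for illustration, not something the story above pins down: the smell weights `W`, a Gaussian prior on each \(z_i\) (which, admittedly, ignores the \([0, 1]\) range from step 5), Gaussian noise on \(x\), and a \(Q(z)\) with one independent Gaussian per fruit. With those choices the ELBO has a closed form, so plain gradient ascent is enough.

```python
import math

# Hypothetical model for the stinky-room story (every number is an assumption):
#   prior       P(z_i)   = Normal(PRIOR_MEAN, PRIOR_STD^2), independently per fruit
#   likelihood  P(x | z) = Normal(W . z, NOISE_STD^2)
#   guess       Q(z)     = Normal(m, diag(s^2)), a Gaussian as in step 13
W = [0.2, 0.5, 0.3]            # assumed contribution of apple, banana, strawberry
PRIOR_MEAN, PRIOR_STD = 0.5, 0.3
NOISE_STD = 0.1
X = 0.6                        # observed smell intensity from step 2


def elbo(m, log_s):
    """Closed-form ELBO = E_q[log P(x, z) - log Q(z)] for this linear-Gaussian model."""
    s2 = [math.exp(2 * ls) for ls in log_s]
    resid = X - sum(w * mi for w, mi in zip(W, m))
    # E_q[log P(x | z)]
    exp_loglik = (-0.5 * math.log(2 * math.pi * NOISE_STD**2)
                  - (resid**2 + sum(w * w * v for w, v in zip(W, s2)))
                  / (2 * NOISE_STD**2))
    # E_q[log P(z)]
    exp_logprior = sum(-0.5 * math.log(2 * math.pi * PRIOR_STD**2)
                       - ((mi - PRIOR_MEAN)**2 + v) / (2 * PRIOR_STD**2)
                       for mi, v in zip(m, s2))
    # -E_q[log Q(z)], i.e. the entropy of Q
    entropy = sum(0.5 * math.log(2 * math.pi * math.e * v) for v in s2)
    return exp_loglik + exp_logprior + entropy


def elbo_grad(m, log_s):
    """Gradient of the ELBO w.r.t. the means and the log standard deviations of Q."""
    resid = X - sum(w * mi for w, mi in zip(W, m))
    grad_m = [w * resid / NOISE_STD**2 - (mi - PRIOR_MEAN) / PRIOR_STD**2
              for w, mi in zip(W, m)]
    grad_log_s = [1.0 - math.exp(2 * ls) * (w * w / NOISE_STD**2 + 1 / PRIOR_STD**2)
                  for w, ls in zip(W, log_s)]
    return grad_m, grad_log_s


def fit(steps=5000, lr=0.005):
    """Plain gradient ascent on the ELBO, starting from Q equal to the prior."""
    m = [PRIOR_MEAN] * 3
    log_s = [math.log(PRIOR_STD)] * 3
    for _ in range(steps):
        grad_m, grad_log_s = elbo_grad(m, log_s)
        m = [mi + lr * g for mi, g in zip(m, grad_m)]
        log_s = [ls + lr * g for ls, g in zip(log_s, grad_log_s)]
    return m, log_s


def log_evidence():
    """Exact log P(x). Only available because this toy model is linear-Gaussian."""
    mean = sum(w * PRIOR_MEAN for w in W)
    var = NOISE_STD**2 + PRIOR_STD**2 * sum(w * w for w in W)
    return -0.5 * math.log(2 * math.pi * var) - (X - mean)**2 / (2 * var)


if __name__ == "__main__":
    m, log_s = fit()
    print("Q means (apple, banana, strawberry):", [round(v, 3) for v in m])
    print("Q std devs:", [round(math.exp(ls), 3) for ls in log_s])
    print("ELBO:", elbo(m, log_s), "log P(x):", log_evidence())
```

The toy model was picked so that \(\log P(x)\) can be computed exactly, which lets you watch the ELBO climb toward it and stop short. The leftover gap is the KL divergence from step 14: this \(Q(z)\) treats the fruits as independent, while the true \(P(z \vert x)\) does not. In a real problem you would not have `log_evidence()` to compare against, which is the whole reason for maximizing the ELBO instead.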