
Variational Methods and Mean Field Approximation

Bayesian Learning (objective)

$$ p(\theta \mid \mathcal D) = \frac{p(\mathcal D \mid \theta) p(\theta)}{p(\mathcal D)} = \frac{\hat p( \theta \mid \mathcal D )}{p(\mathcal D)} $$

with the unnormalized posterior $\hat p( \theta \mid \mathcal D ) = p(\mathcal D \mid \theta)\, p(\theta)$.

Goal: Find an approximation of the posterior distribution $p(\theta \mid \mathcal D)$.

Variational Bayes

$q_w(\theta)$ is our variational distribution, which approximates $p(\theta \mid \mathcal D)$. It is a member of a parametrized family of distributions; the variational parameters are the $w$'s.

To obtain a good approximation, we use as objective the minimization of the KL divergence between $q_w(\theta)$ and $p(\theta \mid \mathcal D)$:

$$\begin{align} \mathcal D_{KL} \left( q_w(\theta) \mid \mid \frac{\hat p( \theta \mid \mathcal D )}{p(\mathcal D)} \right) &= \int_\Theta q_w(\theta) \log \frac{q_w(\theta)}{\frac{\hat p( \theta \mid \mathcal D )}{p(\mathcal D)}} d \theta\\ &= \int_\Theta q_w(\theta) \log \frac{q_w(\theta)}{\hat p( \theta \mid \mathcal D )} d \theta + \int_\Theta q_w(\theta) \log p(\mathcal D ) d \theta \\ &=\int_\Theta q_w(\theta) \log \frac{q_w(\theta)}{\hat p( \theta \mid \mathcal D )} d \theta + \log p(\mathcal D ) \\ &= \mathcal D_{KL} \left( q_w(\theta) \mid \mid \hat p( \theta \mid \mathcal D ) \right) + const. \end{align}$$

$\log p(\mathcal D )$ does not depend on $\theta$ (nor on $w$), so it is just an additive constant in the objective. Therefore, we do not need to worry about the difficulty of computing the normalizer $p(\mathcal D)$.

To search for an approximation of the posterior $p(\theta \mid \mathcal D)$ in the parametrized family of $q_w(\theta)$ (parametrized by $w$), we minimize $\mathcal D_{KL} \left( q_w(\theta) \mid \mid \hat p( \theta \mid \mathcal D ) \right)$, i.e.,

$$ w^* = \text{arg}\min_w \mathcal D_{KL} \left( q_w(\theta) \mid \mid \hat p( \theta \mid \mathcal D ) \right) $$

Then $q_{w^*}(\theta)$ is an approximation of $p(\theta \mid \mathcal D)$; it is already normalized, e.g., if we are using a known parametric distribution.
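As a concrete illustration (an addition to these notes, not part of them), the following minimal Python sketch fits a Gaussian variational family $q_w(\theta) = \mathcal N(\theta; \mu, \sigma^2)$ with $w = (\mu, \log\sigma)$ to a one-dimensional unnormalized posterior by numerically minimizing $\mathcal D_{KL}(q_w \mid\mid \hat p)$. The toy data, the prior, and the grid-based quadrature are assumptions made only for this example.

```python
import numpy as np
from scipy.optimize import minimize
from scipy.stats import norm

# Toy model (assumed for this sketch): Gaussian likelihood with unknown mean theta
# and a Gaussian prior on theta; p_hat(theta | D) = p(D | theta) * p(theta).
data = np.array([1.2, 0.8, 1.5, 0.9])

def log_p_hat(theta):
    log_lik = norm.logpdf(data[:, None], loc=theta, scale=1.0).sum(axis=0)
    log_prior = norm.logpdf(theta, loc=0.0, scale=2.0)
    return log_lik + log_prior

# Variational family: q_w(theta) = N(theta; mu, sigma^2) with w = (mu, log sigma).
grid = np.linspace(-10.0, 10.0, 4001)   # quadrature grid for the 1-D integral
dx = grid[1] - grid[0]

def kl_up_to_const(w):
    mu, log_sigma = w
    log_q = norm.logpdf(grid, loc=mu, scale=np.exp(log_sigma))
    # D_KL(q_w || p_hat) = int q_w log(q_w / p_hat) dtheta  (log p(D) is dropped).
    return np.sum(np.exp(log_q) * (log_q - log_p_hat(grid))) * dx

w_star = minimize(kl_up_to_const, x0=np.array([0.0, 0.0])).x
print("q_{w*}(theta) = N(mu=%.3f, sigma=%.3f)" % (w_star[0], np.exp(w_star[1])))
```

Since this toy model is conjugate, the optimum essentially recovers the exact Gaussian posterior; for non-conjugate models the same objective is typically optimized with Monte Carlo gradient estimators instead of quadrature.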

Mean Field Approximation

In the mean field approximation the family of distributions factorizes (blockwise or completely), e.g.

for variational Bayes: $$ Q = \left\{ q_w \mid q_w(\theta)= \prod_i q_{w_i}(\theta_i)\right\} $$

Notation

$\theta = (\theta_1, \theta_2, \dots, \theta_n)$
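For example (an illustrative case, not from the notes), with $n = 2$ and $\theta = (\mu, \tau)$, say a mean and a precision, the fully factorized mean field family consists of distributions of the form

$$ q(\mu, \tau) = q(\mu)\, q(\tau) $$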

Optimization by coordinate descent

Loop over the coordinates until convergence (see the sketch after this list):

  • optimize w.r.t. $q_1$
  • optimize w.r.t. $q_2$
  • $\dots$
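A generic skeleton of this loop might look as follows (a sketch only; `factors` and `update_factor` are placeholders for the model-specific representation of the $q_i$ and for the per-coordinate update derived in the next subsection):

```python
def coordinate_descent(factors, update_factor, n_iter=100):
    """Mean field coordinate descent: factors is the list [q_1, ..., q_n]."""
    for _ in range(n_iter):                 # outer loop (until convergence)
        for k in range(len(factors)):       # optimize w.r.t. q_k, all others fixed
            factors[k] = update_factor(k, factors)
    return factors
```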

Optimization w.r.t. $q_k$ in detail

$$ \begin{align} & \text{arg}\min_{q_k} \mathcal D_{KL} \left( \prod_i q({\bf \theta}_i) \mid \mid \hat p( {\bf \theta} \mid {\bf X} ) \right) \\ = & \text{arg}\min_{q_k} \left( \int \prod_i q({\bf \theta}_i) \log\frac{\prod_j q({\bf \theta}_j)}{\hat p( {\bf \theta} \mid {\bf X} )} d{\bf \theta}\right)\\ = &\text{arg}\min_{q_k} \left( \int \prod_i q({\bf \theta}_i) \log \left(\prod _j q({\bf \theta}_j) \right) d{\bf \theta} - \int \prod_i q({\bf \theta}_i) \log{\hat p( {\bf \theta} \mid {\bf X} )} d{\bf \theta} \right)\\ = &\text{arg}\min_{q_k} \left( \sum_j \int \prod_i q({\bf \theta}_i) \log{ q({\bf \theta}_j)} d{\bf \theta} - \int \prod_i q({\bf \theta}_i) \log{\hat p( {\bf \theta} \mid {\bf X} )} d{\bf \theta} \right)\\ = &\text{arg}\min_{q_k} \left( \int \prod_i q({\bf \theta}_i) \log{ q({\bf \theta}_k)} d{\bf \theta} + \sum_{j\neq k} \int \prod_i q({\bf \theta}_i) \log{ q({\bf \theta}_j)} d{\bf \theta} - \int \prod_i q({\bf \theta}_i) \log{\hat p( {\bf \theta} \mid {\bf X} )} d{\bf \theta} \right)\\ =& \text{arg}\min_{q_k} \left( A + B + C \right)\\ \end{align} $$

For convenience, let's simplify the three terms ($A$, $B$, and $C$) separately:


$$\begin{align} A &= \int \prod_i^m q({\bf \theta}_i) \log{ q({\bf \theta}_k)} d{\bf \theta} \\ &= \int q({\bf \theta}_k) \log{ q({\bf \theta}_k)} \left(\prod_{i\neq k} q({\bf \theta}_i)\right) d{\bf \theta} \\ &= \int q({\bf \theta}_k) \log{ q({\bf \theta}_k)} \left( \int \prod_{i\neq k} q({\bf \theta}_i)d{\bf \theta}_{\neq k}\right) d{\bf \theta}_k \\ &= \int q({\bf \theta}_k) \log{ q({\bf \theta}_k)} d{\bf \theta}_k \\ \end{align}$$

Note that we used that the $q$'s are probability densities, i.e., $\int q({\bf \theta}_l)\, d{\bf \theta}_l = 1$ for all $l$.


$$\begin{align} B &= \sum_{j\neq k} \int \prod_i q({\bf \theta}_i) \log{ q({\bf \theta}_j)} d{\bf \theta} \\ &= \sum_{j\neq k} \int q({\bf \theta}_k) \prod_{i\neq k} q({\bf \theta}_i) \log{ q({\bf \theta}_j)} d{\bf \theta}_{\neq k} d{\bf \theta}_k \\ &= \sum_{j\neq k} \int \left( \int q({\bf \theta}_k) d{\bf \theta}_k \right) \prod_{i\neq k} q({\bf \theta}_i) \log{ q({\bf \theta}_j)} d{\bf \theta}_{\neq k} \\ &= \sum_{j\neq k} \int \prod_{i\neq k} q({\bf \theta}_i) \log{ q({\bf \theta}_j)} d{\bf \theta}_{\neq k} \end{align}$$

$B$ is independent of $q({\bf \theta}_k)$ and therefore just a constant w.r.t. our minimization objective.


$$\begin{align} C &= - \int \prod_i q({\bf \theta}_i) \log{\hat p( {\bf \theta} \mid {\bf X} )} d{\bf \theta} \\ &= - \int q({\bf \theta}_k) \left(\int \prod_{i\neq k} q({\bf \theta}_i) \log{\hat p( {\bf \theta} \mid {\bf X} )} d{\bf \theta}_{\neq k} \right) d{\bf \theta}_k \\ &= - \int q({\bf \theta}_k) \mathbb E_{q_{-k}} \left[ \log{\hat p( {\bf \theta} \mid {\bf X} )} \right]d{\bf \theta}_k \end{align}$$

$\mathbb E_{q_{-k}} \left[ \log{\hat p( {\bf \theta} \mid {\bf X} )} \right]$ is the expectation of $\log \hat p$ with respect to all factors $q({\bf \theta}_j)$ with $j \neq k$, i.e., not w.r.t. $q({\bf \theta}_{k})$.

We can write $\mathbb E_{q_{-k}} \left[ \log{\hat p( {\bf \theta} \mid {\bf X} )} \right]$ as a function of ${\bf \theta}_k$ $$ h({\bf \theta}_k) = \mathbb E_{q_{-k}} \left[ \log{\hat p( {\bf \theta} \mid {\bf X} )} \right] $$ because all other components ${\bf \theta}_j$ with $j \neq k$ have been integrated out.

With $h({\bf \theta}_k)$ we can define a new distribution $$ t({\bf \theta}_k) = \frac{\exp (h({\bf \theta}_k))}{\int \exp (h({\bf \theta}_k)) d{\bf \theta}_k} $$

resp. $$ h({\bf \theta}_k) = \log t({\bf \theta}_k) + \log \left( {\int \exp (h({\bf \theta}_k)) d{\bf \theta}_k} \right) = \log t({\bf \theta}_k) + const. $$

So we have $$\begin{align} C &= - \int q({\bf \theta}_k) \mathbb E_{q_{-k}} \left[ \log{\hat p( {\bf \theta} \mid {\bf X} )} \right]d{\bf \theta}_k \\ &= - \int q({\bf \theta}_k) h({\bf \theta}_k) d{\bf \theta}_k \\ &= - \int q({\bf \theta}_k) (\log t({\bf \theta}_k) + const. ) d{\bf \theta}_k \\ &= - \int q({\bf \theta}_k) \log t({\bf \theta}_k) d{\bf \theta}_k - const. \\ \end{align}$$


So, we have for the optimization objective (neglecting the constants):

$$ \begin{align} & \text{arg}\min_{q_k} \mathcal D_{KL} \left( \prod_i q({\bf \theta}_i) \mid \mid \hat p( {\bf \theta} \mid {\bf X} ) \right) \\ = & \text{arg}\min_{q_k} \left(\int q({\bf \theta}_k) \log{ q({\bf \theta}_k)} d{\bf \theta}_k - \int q({\bf \theta}_k) \log t({\bf \theta}_k) d{\bf \theta}_k \right)\\ = & \text{arg}\min_{q_k} \int q({\bf \theta}_k) \log \frac{ q({\bf \theta}_k)}{ t({\bf \theta}_k)} d{\bf \theta}_k \\ = & \text{arg}\min_{q_k} \mathcal D_{KL} \left( q({\bf \theta}_k) \mid \mid t({\bf \theta}_k) \right) \end{align} $$

The KL-divergence is minimal if $q({\bf \theta}_k) = t({\bf \theta}_k)$ or

$$ \log q({\bf \theta}_k) = \log t({\bf \theta}_k) = h({\bf \theta}_k) + const. $$

i.e. if we set

$$ \log q({\bf \theta}_k) = \mathbb E_{q_{-k}} \left[ \log{\hat p( {\bf \theta} \mid {\bf X} )} \right] + const. $$

So, in the coordinate descent optimization loop, $q({\bf \theta}_k)$ is computed by taking the expectation w.r.t. the current approximations of all "other" factors $q({\bf \theta}_{-k})$. The technical term mean field comes from this fact: the variational distribution of ${\bf \theta}_k$ is computed by taking the mean (expectation) with respect to the distributions $q({\bf \theta}_j)$ of all other variables ${\bf \theta}_j$ with $j \neq k$.
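As a concrete example (added here for illustration, not part of the original notes): for the standard conjugate model $x_i \sim \mathcal N(\mu, \tau^{-1})$ with prior $\mu \mid \tau \sim \mathcal N(\mu_0, (\lambda_0\tau)^{-1})$, $\tau \sim \mathrm{Gamma}(a_0, b_0)$ and the factorization $q(\mu, \tau) = q(\mu)\, q(\tau)$, the update $\log q({\bf \theta}_k) = \mathbb E_{q_{-k}}[\log \hat p] + const.$ can be evaluated in closed form: $q(\mu)$ is Gaussian and $q(\tau)$ is a Gamma distribution. A minimal sketch (the toy data and prior values are arbitrary choices for this example):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=1.5, size=100)   # toy data
N, xbar = len(x), x.mean()

# Priors (arbitrary choices): mu | tau ~ N(mu0, (lam0*tau)^-1), tau ~ Gamma(a0, b0)
mu0, lam0, a0, b0 = 0.0, 1.0, 1.0, 1.0

E_tau = 1.0                                    # initialization of E_q[tau]
for _ in range(100):                           # coordinate descent until convergence
    # Update q(mu) = N(mu_N, 1/lam_N):   log q(mu)  = E_{q(tau)}[log p_hat] + const.
    mu_N = (lam0 * mu0 + N * xbar) / (lam0 + N)
    lam_N = (lam0 + N) * E_tau
    # Update q(tau) = Gamma(a_N, b_N):   log q(tau) = E_{q(mu)}[log p_hat] + const.
    a_N = a0 + (N + 1) / 2
    E_sq = np.sum((x - mu_N) ** 2) + N / lam_N          # E_q(mu)[sum_i (x_i - mu)^2]
    b_N = b0 + 0.5 * (E_sq + lam0 * ((mu_N - mu0) ** 2 + 1 / lam_N))
    E_tau = a_N / b_N

print("q(mu)  = N(%.3f, var=%.4f)" % (mu_N, 1 / lam_N))
print("q(tau) = Gamma(a=%.1f, b=%.3f), E[tau] = %.3f" % (a_N, b_N, E_tau))
```

Each pass updates one factor while holding the other fixed through its current expectation, which is exactly the mean field coordinate scheme derived above.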

Variational EM

In the Bayesian approach the parameters are handled as (hidden) random variables, so there is no clear distinction between parameters and variables. Assume we have a model in which we want point estimates (MLE) for some parameters and distributions (Bayes) for others.

  • $\theta$ are here our "Bayes parameters" (or other kinds of hidden variables)
  • $\tau$ are here the "MLE parameters" (point estimates)

We can write:

$p_\tau({\bf X}, {\bf \theta}) = p({\bf X}, {\bf \theta} \mid \tau)$

Recap

The marginal log-likelihood is greater than or equal to the variational lower bound:

$$ \log p({\bf X} \mid \tau) \geq \mathcal L(q, \tau) $$
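This follows from the standard decomposition of the marginal log-likelihood into the lower bound and a (non-negative) KL term:

$$ \log p({\bf X} \mid \tau) = \mathcal L(q, \tau) + \mathcal D_{KL}\left( q({\bf \theta}) \mid \mid p({\bf \theta} \mid {\bf X}, \tau) \right) $$

Hence, for fixed $\tau$, maximizing $\mathcal L(q, \tau)$ w.r.t. $q$ is equivalent to minimizing the KL divergence to the posterior, which is exactly the E-Step below.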

The marginal log-likelihood is maximized implicitly by maximizing the lower bound w.r.t. $q$ and $\tau$:

$$ \max_{q, \tau} \mathcal L(q, \tau) = \max_{q, \tau} \mathbb E_{q({\bf \theta})} \left[\log \frac{p({\bf X}, {\bf \theta} \mid \tau )}{q({\bf \theta})}\right] $$

The EM algorithm has two steps:

  • The maximization of the lower bound w.r.t. $q$

    • E-Step: $q^{(t+1)}= \text{arg}\min_q \mathcal D_{KL}\left( q({\bf \theta}) \mid \mid p({\bf \theta} \mid {\bf X}, \tau^{(t)} )\right)$
  • The maximization of the lower bound w.r.t. $\tau$

    • M-Step: $\tau^{(t+1)} = \text{arg}\max_\tau \mathbb E_{q^{(t+1)}} \left[ \log {p({\bf X}, {\bf \theta} \mid \tau )} \right] $

Note that the E-Step corresponds to the KL-divergence minimization we used above for variational inference in latent variable models. So, we can put things together into variational EM. This is necessary if we cannot simply set, e.g., $q({\bf \theta}_i) = p({\bf \theta}_i\mid {\bf X}, \tau)$ as we did here.

In variational EM with mean field approximation we take

  • for the hidden variables ${\bf \theta}$ an approximation of the posterior $p({\bf \theta}\mid {\bf X}, \tau)$: $q({\bf \theta}) = \prod_i q({\bf \theta}_i)$
  • for $\tau$ a point estimate, i.e. the maximum likelihood or MAP estimate.

In the E-Step we minimize $\mathcal D_{KL}$ under the mean field approximation (in a loop over the coordinates/blocks), i.e., for optimizing

$$q^{(t+1)}= \text{arg}\min_{q\in Q} \mathcal D_{KL}\left( q({\bf \theta}) \mid \mid \hat p({\bf \theta} \mid {\bf X}, \tau^{(t)} )\right)$$

we have an (inner) loop where we set

$$ \log q({\bf \theta}_i) = \mathbb E_{q_{-i}} \left[ \log{\hat p( {\bf \theta} \mid {\bf X}, \tau^{(t)} )} \right] + const. $$

until convergence is reached. Then we go to the M-Step (outer loop).
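A minimal sketch of this nested structure (the functions `mean_field_update` and `m_step` are placeholders for the model-specific inner coordinate update and the maximization of $\mathbb E_q[\log p({\bf X}, {\bf \theta} \mid \tau)]$; they are assumptions of this sketch, not part of the notes):

```python
def variational_em(X, q, tau, mean_field_update, m_step, n_outer=50, n_inner=20):
    """Variational EM: mean field E-Step (inner loop) + point-estimate M-Step."""
    for _ in range(n_outer):
        # E-Step: mean field coordinate updates of q given the current tau
        for _ in range(n_inner):
            q = mean_field_update(X, q, tau)
        # M-Step: point estimate of tau given the current q
        tau = m_step(X, q)
    return q, tau
```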

Literature