Variational EM

Recap: EM Algorithm

  • E-Step: $q^{(t+1)} = \text{arg}\min_q \sum_i \mathcal D_{KL}\left( q({\bf z}_i) \mid \mid p({\bf z}_i \mid {\bf x}_i, \theta^{(t)}) \right)$
  • M-Step: $\theta^{(t+1)} = \text{arg}\max_\theta \sum_i \mathbb E_{q_i^{(t+1)}} \left[ \log p({\bf x}_i, {\bf z}_i \mid \theta) \right]$
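As a concrete instance of the two steps, here is a minimal sketch of exact EM for a hypothetical one-dimensional mixture of two Gaussians with known unit variances (an assumption made only to keep the M-step short). Because the posterior over the binary $z_i$ is tractable, the E-step can set $q({\bf z}_i)$ to $p({\bf z}_i \mid {\bf x}_i, \theta^{(t)})$ exactly, which makes every KL term zero. The function names `normal_pdf` and `em_gmm` and the toy data are our own.

```python
import math


def normal_pdf(x, mu):
    # Density of N(mu, 1); the unit variance is a simplifying assumption.
    return math.exp(-0.5 * (x - mu) ** 2) / math.sqrt(2.0 * math.pi)


def em_gmm(xs, mu0, mu1, pi1=0.5, iters=50):
    """Exact EM for a two-component 1-D Gaussian mixture with unit variances."""
    n = len(xs)
    for _ in range(iters):
        # E-step: q_i(z_i = 1) = p(z_i = 1 | x_i, theta^(t)), the exact posterior,
        # so every KL term in the E-step objective is zero.
        r = []
        for x in xs:
            w0 = (1.0 - pi1) * normal_pdf(x, mu0)
            w1 = pi1 * normal_pdf(x, mu1)
            r.append(w1 / (w0 + w1))
        # M-step: closed-form maximiser of sum_i E_q[log p(x_i, z_i | theta)].
        n1 = sum(r)
        n0 = n - n1
        pi1 = n1 / n
        mu1 = sum(ri * x for ri, x in zip(r, xs)) / n1
        mu0 = sum((1.0 - ri) * x for ri, x in zip(r, xs)) / n0
    return mu0, mu1, pi1


data = [-2.1, -1.9, -2.3, -1.7, -2.0, 2.2, 1.8, 2.1, 1.9, 2.0]
mu0, mu1, pi1 = em_gmm(data, mu0=-0.5, mu1=0.5)
print(mu0, mu1, pi1)
```

With this well-separated toy data the means end up close to $-2$ and $2$ and the mixing weight close to $0.5$; with poorer initializations EM can converge to a different local optimum.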

E-Step

$$q^{(t+1)} = \text{arg}\min_{q \in Q} \sum_i \mathcal D_{KL}\left( q({\bf z}_i) \mid \mid p({\bf z}_i \mid {\bf x}_i, \theta^{(t)}) \right)$$

Here $Q$ is the variational family: the set of tractable distributions over which the minimization is carried out.

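Write the posterior in terms of the unnormalized posterior $\hat p({\bf z}_i \mid {\bf x}_i, \theta^{(t)}) = p({\bf x}_i, {\bf z}_i \mid \theta^{(t)})$ (the joint) and the evidence $p({\bf x}_i \mid \theta^{(t)})$:

$$ p({\bf z}_i \mid {\bf x}_i, \theta^{(t)}) = \frac{\hat p({\bf z}_i \mid {\bf x}_i, \theta^{(t)})}{p({\bf x}_i \mid \theta^{(t)})} $$

Substituting this into the KL divergence gives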

\begin{align}
\mathcal D_{KL}\left( q({\bf z}_i) \mid \mid p({\bf z}_i \mid {\bf x}_i, \theta^{(t)}) \right) &= \int q({\bf z}_i) \log \frac{q({\bf z}_i)}{p({\bf z}_i \mid {\bf x}_i, \theta^{(t)})} \, d{\bf z}_i \\
&= \int q({\bf z}_i) \log \frac{q({\bf z}_i)}{\hat p({\bf z}_i \mid {\bf x}_i, \theta^{(t)}) \,/\, p({\bf x}_i \mid \theta^{(t)})} \, d{\bf z}_i \\
&= \int q({\bf z}_i) \log \frac{q({\bf z}_i)}{\hat p({\bf z}_i \mid {\bf x}_i, \theta^{(t)})} \, d{\bf z}_i + \log p({\bf x}_i \mid \theta^{(t)}) \int q({\bf z}_i) \, d{\bf z}_i \\
&= \mathcal D_{KL}\left( q({\bf z}_i) \mid \mid \hat p({\bf z}_i \mid {\bf x}_i, \theta^{(t)}) \right) + \log p({\bf x}_i \mid \theta^{(t)})
\end{align}

where the last step uses $\int q({\bf z}_i) \, d{\bf z}_i = 1$.
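The decomposition can be checked numerically for a discrete ${\bf z}_i$, where the integrals become sums. The sketch below uses made-up joint values $p({\bf x}_i, {\bf z}_i = k \mid \theta^{(t)})$ for $K = 3$ states and an arbitrary $q$ (all numbers are illustrative assumptions); both sides of the identity should agree up to floating-point error.

```python
import math

# Made-up joint values p(x_i, z_i = k | theta^(t)) for one fixed x_i and K = 3 states.
# They play the role of the unnormalised posterior p_hat(z_i | x_i, theta^(t)).
p_hat = [0.02, 0.05, 0.03]
evidence = sum(p_hat)                      # p(x_i | theta^(t))
posterior = [v / evidence for v in p_hat]  # p(z_i | x_i, theta^(t))

q = [0.5, 0.3, 0.2]  # an arbitrary distribution over the K states


def kl(q, p):
    """sum_k q_k log(q_k / p_k); p does not have to be normalised."""
    return sum(qk * math.log(qk / pk) for qk, pk in zip(q, p) if qk > 0)


lhs = kl(q, posterior)                   # D_KL(q || p(z_i | x_i, theta^(t)))
rhs = kl(q, p_hat) + math.log(evidence)  # D_KL(q || p_hat) + log p(x_i | theta^(t))
print(lhs, rhs)
```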

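The second term, $\log p({\bf x}_i \mid \theta^{(t)})$, does not depend on $q$, so it can be dropped when optimizing over $q$. (With the unnormalized $\hat p$ as second argument, the remaining term is not a proper KL divergence: it is bounded below by $-\log p({\bf x}_i \mid \theta^{(t)})$ instead of $0$, but it has the same minimizer.) The E-step therefore becomes: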

$$q^{(t+1)} = \text{arg}\min_{q \in Q} \sum_i \mathcal D_{KL}\left( q({\bf z}_i) \mid \mid \hat p({\bf z}_i \mid {\bf x}_i, \theta^{(t)}) \right)$$
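A small numerical illustration of this E-step objective, assuming a binary ${\bf z}_i$ and a variational family $Q$ of Bernoulli distributions $q({\bf z}_i = 1) = \phi$ searched on a grid (both choices, and the numbers, are ours for illustration). Minimizing $\mathcal D_{KL}(q \mid \mid \hat p)$ should recover the exact posterior, and the minimum value should equal $-\log p({\bf x}_i \mid \theta^{(t)})$, since at that point the proper KL term of the decomposition is zero.

```python
import math

# Made-up joint values p(x_i, z_i | theta^(t)) for a binary z_i in {0, 1}.
p_hat = [0.06, 0.02]
evidence = sum(p_hat)  # p(x_i | theta^(t))


def kl_to_p_hat(phi):
    """D_KL(q || p_hat) for q = Bernoulli(phi), i.e. q(z_i = 1) = phi."""
    q = [1.0 - phi, phi]
    return sum(qk * math.log(qk / pk) for qk, pk in zip(q, p_hat))


# Variational family Q = {Bernoulli(phi) : 0 < phi < 1}, searched on a grid.
grid = [k / 1000 for k in range(1, 1000)]
best_phi = min(grid, key=kl_to_p_hat)

print(best_phi)                                    # exact posterior is 0.02 / 0.08 = 0.25
print(kl_to_p_hat(best_phi), -math.log(evidence))  # the two values agree
```

Here $Q$ happens to contain the exact posterior. With a more restrictive $Q$ the minimum would generally stay above $-\log p({\bf x}_i \mid \theta^{(t)})$; equivalently, $-\mathcal D_{KL}(q \mid \mid \hat p)$ is a lower bound on the log evidence (the ELBO).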

In variational EM with a mean-field approximation we use:

  • for the hidden states ${\bf Z}$, a factorized approximation of the posterior $p({\bf Z} \mid {\bf X})$: $q({\bf Z}) = \prod_i q({\bf z}_i)$;
  • for the parameters $\theta$, a point estimate, i.e. the maximum likelihood (ML) or maximum a posteriori (MAP) estimate.
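A minimal sketch of mean-field variational EM, under a hypothetical model chosen for illustration: each $x_i$ is generated by two independent binary causes $a_i \sim \text{Bern}(\pi_a)$, $b_i \sim \text{Bern}(\pi_b)$ with $x_i \sim \mathcal N(w_a a_i + w_b b_i, 1)$ (unit noise variance is an assumption), and $\theta = (\pi_a, \pi_b, w_a, w_b)$ gets an ML point estimate. The factorization goes one step further than $q({\bf Z}) = \prod_i q({\bf z}_i)$ and also splits each ${\bf z}_i = (a_i, b_i)$ into $q(a_i)\,q(b_i)$, so the E-step becomes coordinate ascent instead of an exact posterior computation. All function names are our own. Because every step maximizes the ELBO with respect to one block of variables, the tracked ELBO should never decrease.

```python
import math
import random


def sigmoid(t):
    # Numerically stable logistic function.
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)


def logit(p):
    return math.log(p) - math.log(1.0 - p)


def bern_entropy(m):
    # Entropy of Bernoulli(m); m is clamped to avoid log(0).
    m = min(max(m, 1e-12), 1.0 - 1e-12)
    return -m * math.log(m) - (1.0 - m) * math.log(1.0 - m)


def elbo(xs, ma, mb, theta):
    """ELBO = sum_i E_q[log p(x_i, a_i, b_i | theta)] + H[q(a_i)] + H[q(b_i)]."""
    pa, pb, wa, wb = theta
    total = 0.0
    for x, a, b in zip(xs, ma, mb):
        # E_q[(x - wa*a - wb*b)^2] under q(a)q(b), using E[a^2] = E[a] for binary a.
        sq = (x * x - 2.0 * x * (wa * a + wb * b)
              + wa * wa * a + wb * wb * b + 2.0 * wa * wb * a * b)
        total += (a * math.log(pa) + (1.0 - a) * math.log(1.0 - pa)
                  + b * math.log(pb) + (1.0 - b) * math.log(1.0 - pb)
                  - 0.5 * math.log(2.0 * math.pi) - 0.5 * sq
                  + bern_entropy(a) + bern_entropy(b))
    return total


def variational_em(xs, theta, iters=30, sweeps=5):
    n = len(xs)
    ma = [0.5] * n  # ma[i] = q(a_i = 1)
    mb = [0.5] * n  # mb[i] = q(b_i = 1)
    history = []
    for _ in range(iters):
        pa, pb, wa, wb = theta
        # E-step: coordinate ascent on the mean-field factors q(a_i) and q(b_i).
        for i, x in enumerate(xs):
            for _sweep in range(sweeps):
                ma[i] = sigmoid(logit(pa) + wa * (x - wb * mb[i]) - 0.5 * wa * wa)
                mb[i] = sigmoid(logit(pb) + wb * (x - wa * ma[i]) - 0.5 * wb * wb)
        # M-step: point estimate of theta maximising sum_i E_q[log p(x_i, a_i, b_i | theta)].
        pa, pb = sum(ma) / n, sum(mb) / n
        saa, sbb = sum(ma), sum(mb)
        sab = sum(a * b for a, b in zip(ma, mb))
        ra = sum(x * a for x, a in zip(xs, ma))
        rb = sum(x * b for x, b in zip(xs, mb))
        det = saa * sbb - sab * sab  # solve the 2x2 normal equations for (wa, wb)
        wa = (sbb * ra - sab * rb) / det
        wb = (saa * rb - sab * ra) / det
        theta = (pa, pb, wa, wb)
        history.append(elbo(xs, ma, mb, theta))
    return theta, history


# Synthetic data from the hypothetical model with pi_a = 0.3, pi_b = 0.6, w_a = 2, w_b = 5.
random.seed(0)
xs = []
for _ in range(200):
    a = 1 if random.random() < 0.3 else 0
    b = 1 if random.random() < 0.6 else 0
    xs.append(2.0 * a + 5.0 * b + random.gauss(0.0, 1.0))

theta, history = variational_em(xs, theta=(0.5, 0.5, 1.0, 4.0))
print(theta)
print(history[0], history[-1])
```

In this toy model the exact posterior over the four joint states of $(a_i, b_i)$ would still be cheap to compute; the mean-field split only shows the mechanics. Unlike exact EM, $\theta$ here is chosen to increase a lower bound on the log-likelihood (the ELBO) rather than the likelihood itself.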