( Basic parts and less important contents are skipped )

10. Approximate Inference

Important task in probabilistic models = “evaluation of the posterior” \(p(\mathbf{Z}\mid\mathbf{X})\), and of expectations computed w.r.t. it

EM algorithm

  • evaluates the expectation of the complete-data log likelihood w.r.t. the posterior of the latent variables


But in practice….

  • problem 1) dimension of latent space : TOO HIGH
  • problem 2) not analytically tractable
    • continuous ) the required integrals may have no closed-form solution
    • discrete ) summation is always possible in principle, but the number of hidden states grows exponentially \(\rightarrow\) too expensive


Thus, we need “APPROXIMATION” … approximation schemes fall into 2 classes

  • 1) stochastic approx

    • MCMC ( Ch.11 )
    • given infinite computational resources, can generate exact results!
    • but computationally demanding \(\rightarrow\) limited to small-scale problems
  • 2) deterministic approx

    • scale well to large applications

    • based on “analytical approximations” to posterior

      ( ex. assume the posterior factorizes, or is well approximated by a Gaussian )

    • Laplace approximation

      ( based on local Gaussian approximation to a mode of distn )

    • Variational Inference, Expectation Propagation


10-1. Variational Inference

“approximate \(p\) with a simple distribution \(q\)”


10-1-1. Factorized distribution

assume that \(q\) factorizes as…

  • \(q(\mathrm{Z})=\prod_{i=1}^{M} q_{i}\left(\mathrm{Z}_{i}\right)\).


Maximize the lower bound \(\mathcal{L}(q)\)

\(\begin{aligned} \mathcal{L}(q) &=\int \prod_{i} q_{i}\left\{\ln p(\mathbf{X}, \mathbf{Z})-\sum_{i} \ln q_{i}\right\} \mathrm{d} \mathbf{Z} \\ &=\int q_{j}\left\{\int \ln p(\mathbf{X}, \mathbf{Z}) \prod_{i \neq j} q_{i} \mathrm{~d} \mathbf{Z}_{i}\right\} \mathrm{d} \mathbf{Z}_{j}-\int q_{j} \ln q_{j} \mathrm{~d} \mathbf{Z}_{j}+\text { const } \\ &=\int q_{j} \ln \tilde{p}\left(\mathbf{X}, \mathbf{Z}_{j}\right) \mathrm{d} \mathbf{Z}_{j}-\int q_{j} \ln q_{j} \mathrm{~d} \mathbf{Z}_{j}+\text { const } \end{aligned}\).

  • define new distribution \(\tilde{p}\left(\mathbf{X}, \mathbf{Z}_{j}\right)\)
    • \(\ln \tilde{p}\left(\mathbf{X}, \mathbf{Z}_{j}\right)=\mathbb{E}_{i \neq j}[\ln p(\mathbf{X}, \mathbf{Z})]+\operatorname{const.}\).
    • where \(\mathbb{E}_{i \neq j}[\ln p(\mathbf{X}, \mathbf{Z})]=\int \ln p(\mathbf{X}, \mathbf{Z}) \prod_{i \neq j} q_{i} \mathrm{~d} \mathbf{Z}_{i}\)


Maximizing the lower bound \(\mathcal{L}(q)\) w.r.t. \(q_j\) ( holding \(q_{i \neq j}\) fixed ) = minimizing the KL divergence, since the expression above is \(-\mathrm{KL}\left(q_{j} \| \tilde{p}\right)\) up to a constant

\(\rightarrow\) the maximum occurs when \(q_{j}\left(\mathbf{Z}_{j}\right)=\tilde{p}\left(\mathbf{X}, \mathbf{Z}_{j}\right)\).


Thus, optimal solution

  • \[\ln q_{j}^{\star}\left(\mathbf{Z}_{j}\right)=\mathbb{E}_{i \neq j}[\ln p(\mathbf{X}, \mathbf{Z})]+\text { const. }\]
    • constant : set by normalizing the distribution
  • \(q_{j}^{\star}\left(\mathbf{Z}_{j}\right)=\frac{\exp \left(\mathbb{E}_{i \neq j}[\ln p(\mathbf{X}, \mathbf{Z})]\right)}{\int \exp \left(\mathbb{E}_{i \neq j}[\ln p(\mathbf{X}, \mathbf{Z})]\right) \mathrm{d} \mathbf{Z}_{j}}\).


10-1-2. Properties of factorized approximations

one approach to VI is based on a “factorized approximation” ( i.e. mean field )

  • ex) using a factorized Gaussian


Example

  • \(p(\mathbf{z})=\mathcal{N}\left(\mathbf{z} \mid \boldsymbol{\mu}, \boldsymbol{\Lambda}^{-1}\right)\), where \(\mathbf{z}=\left(z_{1}, z_{2}\right)\)
    • mean and precision : \(\boldsymbol{\mu}=\left(\begin{array}{c} \mu_{1} \\ \mu_{2} \end{array}\right), \quad \boldsymbol{\Lambda}=\left(\begin{array}{ll} \Lambda_{11} & \Lambda_{12} \\ \Lambda_{21} & \Lambda_{22} \end{array}\right)\).
  • \(q(\mathbf{z})=q_{1}\left(z_{1}\right) q_{2}\left(z_{2}\right)\)


Only retain those terms that depend on \(z_1\)

( all other terms are absorbed into the normalizing constant )

\(\begin{aligned} \ln q_{1}^{\star}\left(z_{1}\right) &=\mathbb{E}_{z_{2}}[\ln p(\mathbf{z})]+\text { const } \\ &=\mathbb{E}_{z_{2}}\left[-\frac{1}{2}\left(z_{1}-\mu_{1}\right)^{2} \Lambda_{11}-\left(z_{1}-\mu_{1}\right) \Lambda_{12}\left(z_{2}-\mu_{2}\right)\right]+\mathrm{const} \\ &=-\frac{1}{2} z_{1}^{2} \Lambda_{11}+z_{1} \mu_{1} \Lambda_{11}-z_{1} \Lambda_{12}\left(\mathbb{E}\left[z_{2}\right]-\mu_{2}\right)+\mathrm{const.} \end{aligned}\).


Therefore, \(q^{\star}\left(z_{1}\right)=\mathcal{N}\left(z_{1} \mid m_{1}, \Lambda_{11}^{-1}\right)\)

  • \(m_{1}=\mu_{1}-\Lambda_{11}^{-1} \Lambda_{12}\left(\mathbb{E}\left[z_{2}\right]-\mu_{2}\right)\).


In the same way, \(q_{2}^{\star}\left(z_{2}\right)=\mathcal{N}\left(z_{2} \mid m_{2}, \Lambda_{22}^{-1}\right)\)

  • \(m_{2}=\mu_{2}-\Lambda_{22}^{-1} \Lambda_{21}\left(\mathbb{E}\left[z_{1}\right]-\mu_{1}\right)\).


\(q^{\star}\left(z_{1}\right)\) needs \(q^{\star}\left(z_{2}\right)\), and vice versa!

\(\rightarrow\) treat these as coupled fixed-point equations : for a non-singular distribution, the only solution is \(\mathbb{E}\left[z_{1}\right]=\mu_{1}\) and \(\mathbb{E}\left[z_{2}\right]=\mu_{2}\)

( the factorized approximation captures the means exactly, but its variances \(\Lambda_{11}^{-1}, \Lambda_{22}^{-1}\) underestimate the true marginal variances )
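
A minimal numeric sketch of these coupled updates ( the parameter values below are assumed purely for illustration ) :

```python
import numpy as np

# Assumed (illustrative) parameters of the true bivariate Gaussian
mu  = np.array([1.0, -1.0])                 # true mean
Lam = np.array([[2.0, 0.8],
                [0.8, 1.5]])                # true precision matrix

# Coordinate updates : each factor mean depends on the other one
E_z1, E_z2 = 0.0, 0.0                       # initial guesses
for _ in range(100):
    E_z1 = mu[0] - Lam[0, 1] / Lam[0, 0] * (E_z2 - mu[1])   # m1 update
    E_z2 = mu[1] - Lam[1, 0] / Lam[1, 1] * (E_z1 - mu[0])   # m2 update

print(E_z1, E_z2)   # -> converges to the true means (1.0, -1.0)
```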


Reverse KL-Divergence (\(\mathrm{KL}(p \| q)\))

  • used in alternative approximate inference, called “EXPECTATION PROPAGATION”

\(\mathrm{KL}(p \| q)=-\int p(\mathbf{Z})\left[\sum_{i=1}^{M} \ln q_{i}\left(\mathbf{Z}_{i}\right)\right] \mathrm{d} \mathbf{Z}+\text { const }\).


Thus, \(q_{j}^{\star}\left(\mathbf{Z}_{j}\right)=\int p(\mathbf{Z}) \prod_{i \neq j} \mathrm{~d} \mathbf{Z}_{i}=p\left(\mathbf{Z}_{j}\right)\).

\(\rightarrow\) optimal solution is just given by the marginal distn of \(p(\mathbf{Z})\) .. closed form!

( do not require iteration )
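
For the bivariate Gaussian example above, the two solutions can be compared directly : the \(\mathrm{KL}(p \| q)\) factors keep the full marginal variances \(\Sigma_{ii}\), while the \(\mathrm{KL}(q \| p)\) factors from 10-1-2 have the smaller variances \(\Lambda_{ii}^{-1}\). A quick sketch, reusing the same assumed numbers :

```python
import numpy as np

Lam = np.array([[2.0, 0.8],
                [0.8, 1.5]])          # same assumed precision as above
Sigma = np.linalg.inv(Lam)            # covariance of p(z)

# KL(p||q) solution : factors are the exact marginals N(mu_i, Sigma_ii)
print("KL(p||q) factor variances :", Sigma[0, 0], Sigma[1, 1])

# KL(q||p) (VI) solution : factor variances are 1/Lam_ii (smaller)
print("KL(q||p) factor variances :", 1 / Lam[0, 0], 1 / Lam[1, 1])
```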


Comparison

( figure : comparison of minimizing the two forms of KL divergence, on a multimodal \(p(\mathbf{Z})\) )

(a) : \(KL(p\mid \mid q)\) ( reverse KL )

(b) : \(KL(q\mid \mid p)\) (1)

(c) : \(KL(q\mid \mid p)\) (2)

  • minimizing (a) : minimized by \(q(Z)\) that is nonzero wherever \(p(Z)\) is nonzero ( zero-avoiding )
    • will average across all of the modes
    • can lead to poor predictive distns
  • minimizing (b,c) : minimized by \(q(Z)\) that avoids regions where \(p(Z)\) is small ( zero-forcing )
    • will tend to lock onto one of the modes


both forms of KL divergence are members of the “alpha family” of divergences

\(\mathrm{D}_{\alpha}(p \| q)=\frac{4}{1-\alpha^{2}}\left(1-\int p(x)^{(1+\alpha) / 2} q(x)^{(1-\alpha) / 2} \mathrm{~d} x\right)\).

  • \(\mathrm{KL}(p \| q)\) corresponds to the limit \(\alpha \rightarrow 1\), and \(\mathrm{KL}(q \| p)\) to \(\alpha \rightarrow-1\)

  • when \(\alpha=0\) : symmetric divergence,

    linearly related to the “Hellinger distance”

    \(\mathrm{D}_{\mathrm{H}}(p \| q)=\int\left(p(x)^{1 / 2}-q(x)^{1 / 2}\right)^{2} \mathrm{~d} x\).
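
A quick numerical check, with two assumed 1-D Gaussians : expanding the square gives \(\mathrm{D}_{\mathrm{H}}=2-2 \int \sqrt{p q} \,\mathrm{d} x\), so \(\mathrm{D}_{\alpha=0}=2\, \mathrm{D}_{\mathrm{H}}\) :

```python
import numpy as np
from scipy.stats import norm

# Two assumed 1-D Gaussians, compared by simple quadrature on a grid
x  = np.linspace(-12.0, 12.0, 20001)
dx = x[1] - x[0]
p = norm.pdf(x, loc=0.0, scale=1.0)
q = norm.pdf(x, loc=1.0, scale=2.0)

def alpha_divergence(p, q, alpha):
    # D_alpha = 4/(1 - alpha^2) * (1 - int p^{(1+a)/2} q^{(1-a)/2} dx)
    integral = np.sum(p ** ((1 + alpha) / 2) * q ** ((1 - alpha) / 2)) * dx
    return 4.0 / (1.0 - alpha ** 2) * (1.0 - integral)

hellinger = np.sum((np.sqrt(p) - np.sqrt(q)) ** 2) * dx
print(alpha_divergence(p, q, 0.0))   # equals 2 * D_H ...
print(2.0 * hellinger)               # ... since D_H = 2 - 2*int sqrt(pq) dx
```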


10-1-3. Ex) Univariate Gaussian

Factorized variational approximation, illustrated with a univariate Gaussian

Goal : infer posterior for

  • mean \(\mu\)
  • precision \(\tau\)


Likelihood : \(p(\mathcal{D} \mid \mu, \tau)=\left(\frac{\tau}{2 \pi}\right)^{N / 2} \exp \left\{-\frac{\tau}{2} \sum_{n=1}^{N}\left(x_{n}-\mu\right)^{2}\right\}\).


Conjugate prior ( Gaussian-Gamma conjugate prior distn )

  • mean (=Normal) : \(p(\mu \mid \tau) =\mathcal{N}\left(\mu \mid \mu_{0},\left(\lambda_{0} \tau\right)^{-1}\right)\)
  • precision (=Gamma) : \(p(\tau) =\operatorname{Gam}\left(\tau \mid a_{0}, b_{0}\right)\)


We will consider a factorized variational approximation, as

\(q(\mu, \tau)=q_{\mu}(\mu) q_{\tau}(\tau)\).


( as we have found in \(\ln q_{j}^{\star}\left(\mathbf{Z}_{j}\right)=\mathbb{E}_{i \neq j}[\ln p(\mathbf{X}, \mathbf{Z})]+\text { const. }\)…. )

Mean (\(\mu\))

\(\begin{aligned} \ln q_{\mu}^{\star}(\mu) &=\mathbb{E}_{\tau}[\ln p(\mathcal{D} \mid \mu, \tau)+\ln p(\mu \mid \tau)]+\text { const } \\ &=-\frac{\mathbb{E}[\tau]}{2}\left\{\lambda_{0}\left(\mu-\mu_{0}\right)^{2}+\sum_{n=1}^{N}\left(x_{n}-\mu\right)^{2}\right\}+\text { const. } \end{aligned}\).


Thus \(q_{\mu}(\mu)\) follows \(\mathcal{N}\left(\mu \mid \mu_{N}, \lambda_{N}^{-1}\right)\)

  • mean ) \(\mu_{N} =\frac{\lambda_{0} \mu_{0}+N \bar{x}}{\lambda_{0}+N}\)
  • precision ) \(\lambda_{N} =\left(\lambda_{0}+N\right) \mathbb{E}[\tau]\)


for \(N \rightarrow \infty\)

  • \(\mu_N \rightarrow \bar{x}\) : same as the MLE result! ( prior is ignored, only the data matters )


Precision (\(\tau\))

\(\begin{aligned} \ln q_{\tau}^{\star}(\tau)=& \mathbb{E}_{\mu}[\ln p(\mathcal{D} \mid \mu, \tau)+\ln p(\mu \mid \tau)]+\ln p(\tau)+\text { const } \\ =&\left(a_{0}-1\right) \ln \tau-b_{0} \tau+\frac{N}{2} \ln \tau \\ &-\frac{\tau}{2} \mathbb{E}_{\mu}\left[\sum_{n=1}^{N}\left(x_{n}-\mu\right)^{2}+\lambda_{0}\left(\mu-\mu_{0}\right)^{2}\right]+\text { const } \end{aligned}\).


Thus \(q_{\tau}(\tau)\) follows \(\operatorname{Gam}\left(\tau \mid a_{N}, b_{N}\right)\)

  • \[a_{N}=a_{0}+\frac{N}{2}\]
  • \[b_{N}=b_{0}+\frac{1}{2} \mathbb{E}_{\mu}\left[\sum_{n=1}^{N}\left(x_{n}-\mu\right)^{2}+\lambda_{0}\left(\mu-\mu_{0}\right)^{2}\right]\]


for \(N \rightarrow \infty\)

  • \(\mathbb{E}[\tau]=a_N / b_N \rightarrow 1/\sigma_{\mathrm{ML}}^2\) : same as the MLE result! ( prior is ignored, only the data matters )


Summary

  • (step 1) make initial guess for \(\mathbb{E}[\tau]\)
  • (step 2) re-compute \(q_{\mu}(\mu)\)
  • (step 3) with revised distn, calculate \(\mathbb{E}[\mu]\) and \(\mathbb{E}\left[\mu^{2}\right]\)
  • (step 4) re-compute \(q_{\tau}(\tau)\)

Repeat steps 1~4 until convergence!
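
A runnable sketch of this loop ( data and hyperparameter values are assumed for illustration; the moments follow from \(\mathbb{E}[\mu]=\mu_N\), \(\mathbb{E}[\mu^2]=\mu_N^2+\lambda_N^{-1}\), \(\mathbb{E}[\tau]=a_N / b_N\) ) :

```python
import numpy as np

# Synthetic data and prior hyperparameters (all values assumed)
rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=0.5, size=200)
N, xbar, x2sum = len(x), x.mean(), np.sum(x ** 2)
mu0, lam0, a0, b0 = 0.0, 1.0, 1.0, 1.0

E_tau = 1.0                                    # (step 1) initial guess
for _ in range(50):
    # (step 2) q_mu(mu) = N(mu | mu_N, lam_N^{-1})
    mu_N  = (lam0 * mu0 + N * xbar) / (lam0 + N)
    lam_N = (lam0 + N) * E_tau
    # (step 3) moments of mu under the revised q_mu
    E_mu, E_mu2 = mu_N, mu_N ** 2 + 1.0 / lam_N
    # (step 4) q_tau(tau) = Gam(tau | a_N, b_N)
    a_N = a0 + N / 2
    b_N = b0 + 0.5 * (x2sum - 2 * N * xbar * E_mu + N * E_mu2
                      + lam0 * (E_mu2 - 2 * mu0 * E_mu + mu0 ** 2))
    E_tau = a_N / b_N                          # feeds back into step 2

print(mu_N, 1.0 / np.sqrt(E_tau))   # approx. posterior mean / data scale
```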


10-2. Illustration : Variational Mixture of Gaussians

\(p(\mathbf{Z} \mid \pi)=\prod_{n=1}^{N} \prod_{k=1}^{K} \pi_{k}^{z_{n k}}\).

\(p(\mathbf{X} \mid \mathbf{Z}, \boldsymbol{\mu}, \boldsymbol{\Lambda})=\prod_{n=1}^{N} \prod_{k=1}^{K} \mathcal{N}\left(\mathbf{x}_{n} \mid \boldsymbol{\mu}_{k}, \boldsymbol{\Lambda}_{k}^{-1}\right)^{z_{n k}}\).


priors for \(\mu, \Lambda\) and \(\pi\)

  • \(\pi\) : \(p(\pi)=\operatorname{Dir}\left(\pi \mid \alpha_{0}\right)=C\left(\alpha_{0}\right) \prod_{k=1}^{K} \pi_{k}^{\alpha_{0}-1}\).

    • \(C\left(\alpha_{0}\right)\) is a normalizing constant

    • \(\alpha_0\) : effective prior number of observations

      ( if \(\alpha_0\) is small \(\rightarrow\) the posterior is influenced mainly by the data, rather than by the prior )

  • \(\mu\), \(\Lambda\) :

    \(\begin{aligned} p(\mu, \Lambda) &=p(\mu \mid \Lambda) p(\Lambda) \\ &=\prod_{k=1}^{K} \mathcal{N}\left(\mu_{k} \mid \mathbf{m}_{0},\left(\beta_{0} \Lambda_{k}\right)^{-1}\right) \mathcal{W}\left(\Lambda_{k} \mid \mathbf{W}_{0}, \nu_{0}\right) \end{aligned}\).

    • Gaussian-Wishart prior

      ( represents conjugate prior, when both mean and precision are unknown )
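
To make the generative model concrete, here is a sketch of ancestral sampling from it ( all hyperparameter values are assumed purely for illustration ) :

```python
import numpy as np
from scipy.stats import wishart

# Assumed hyperparameter values, for illustration only
rng = np.random.default_rng(0)
K, D, N = 3, 2, 500
alpha0, beta0 = 1.0, 1.0
m0, W0, nu0 = np.zeros(D), np.eye(D), float(D)

pi  = rng.dirichlet(alpha0 * np.ones(K))                   # pi ~ Dir(alpha0)
Lam = [wishart(df=nu0, scale=W0).rvs() for _ in range(K)]  # Lam_k ~ W(W0, nu0)
mu  = [rng.multivariate_normal(m0, np.linalg.inv(beta0 * L)) for L in Lam]

z = rng.choice(K, size=N, p=pi)                            # z_n ~ Cat(pi)
X = np.stack([rng.multivariate_normal(mu[k], np.linalg.inv(Lam[k]))
              for k in z])                                 # x_n | z_n = k
```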


10-2-1. Variational Distribution

joint pdf : \(p(\mathbf{X}, \mathbf{Z}, \boldsymbol{\pi}, \boldsymbol{\mu}, \boldsymbol{\Lambda})=p(\mathbf{X} \mid \mathbf{Z}, \boldsymbol{\mu}, \boldsymbol{\Lambda}) p(\mathbf{Z} \mid \pi) p(\boldsymbol{\pi}) p(\boldsymbol{\mu} \mid \boldsymbol{\Lambda}) p(\boldsymbol{\Lambda})\)

variational distn : \(q(\mathbf{Z}, \boldsymbol{\pi}, \boldsymbol{\mu}, \boldsymbol{\Lambda})=q(\mathbf{Z}) q(\boldsymbol{\pi}, \boldsymbol{\mu}, \boldsymbol{\Lambda})\)


Sequential ( coordinate-wise ) updates can then be done easily!

( as we have found in \(\ln q_{j}^{\star}\left(\mathbf{Z}_{j}\right)=\mathbb{E}_{i \neq j}[\ln p(\mathbf{X}, \mathbf{Z})]+\text { const. }\)…. )

  • update \(q(\mathbf{Z})\)
  • update \(q(\boldsymbol{\pi}, \boldsymbol{\mu}, \boldsymbol{\Lambda})\)


update \(q(\mathbf{Z})\)

\(\ln q^{\star}(\mathbf{Z})=\mathbb{E}_{\pi, \mu, \Lambda}[\ln p(\mathbf{X}, \mathbf{Z}, \boldsymbol{\pi}, \boldsymbol{\mu}, \boldsymbol{\Lambda})]+\text { const. }\).

\(\ln q^{\star}(\mathbf{Z})=\mathbb{E}_{\pi}[\ln p(\mathbf{Z} \mid \pi)]+\mathbb{E}_{\mu, \Lambda}[\ln p(\mathbf{X} \mid \mathbf{Z}, \boldsymbol{\mu}, \boldsymbol{\Lambda})]+\text { const. }\) ( by decomposing \(p(\mathbf{X}, \mathbf{Z}, \boldsymbol{\pi}, \boldsymbol{\mu}, \boldsymbol{\Lambda}))\)

\(\ln q^{\star}(\mathbf{Z})=\sum_{n=1}^{N} \sum_{k=1}^{K} z_{n k} \ln \rho_{n k}+\operatorname{const}\).

  • where \(\begin{aligned} \ln \rho_{n k}=& \mathbb{E}\left[\ln \pi_{k}\right]+\frac{1}{2} \mathbb{E}\left[\ln \left|\boldsymbol{\Lambda}_{k}\right|\right]-\frac{D}{2} \ln (2 \pi) -\frac{1}{2} \mathbb{E}_{\boldsymbol{\mu}_{k}, \boldsymbol{\Lambda}_{k}}\left[\left(\mathbf{x}_{n}-\boldsymbol{\mu}_{k}\right)^{\mathrm{T}} \boldsymbol{\Lambda}_{k}\left(\mathbf{x}_{n}-\boldsymbol{\mu}_{k}\right)\right] \end{aligned}\)

\(q^{\star}(\mathbf{Z}) \propto \prod_{n=1}^{N} \prod_{k=1}^{K} \rho_{n k}^{z_{n k}}\) ( by taking exponential )

\(q^{\star}(\mathbf{Z})=\prod_{n=1}^{N} \prod_{k=1}^{K} r_{n k}^{z_{n k}}\) ( by normalizing )

  • where \(r_{n k}=\frac{\rho_{n k}}{\sum_{j=1}^{K} \rho_{n j}}\)


The optimal solution for the factor \(q(\mathbf{Z})\) takes the same functional form as the prior \(p(\mathbf{Z} \mid \boldsymbol{\pi})\)

( since we have used conjugate prior! )

Result : \(\mathbb{E}\left[z_{n k}\right]=r_{n k}\)
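
A sketch of this update as code. The closed-form expectations used below ( digamma formulas for \(\mathbb{E}[\ln \pi_k]\) and \(\mathbb{E}[\ln |\boldsymbol{\Lambda}_k|]\), and the quadratic-form expectation ) are the standard results for Dirichlet and Gaussian-Wishart factors, stated here without derivation :

```python
import numpy as np
from scipy.special import digamma

def responsibilities(X, alpha, beta, m, W, nu):
    """E-step-like update : r_nk from the current q(pi, mu, Lambda)."""
    N, D = X.shape
    K = len(alpha)
    ln_rho = np.empty((N, K))
    E_ln_pi = digamma(alpha) - digamma(alpha.sum())        # E[ln pi_k]
    for k in range(K):
        # E[ln |Lam_k|] = sum_i psi((nu_k + 1 - i)/2) + D ln 2 + ln |W_k|
        E_ln_det = (digamma((nu[k] + 1 - np.arange(1, D + 1)) / 2).sum()
                    + D * np.log(2.0) + np.linalg.slogdet(W[k])[1])
        diff = X - m[k]
        # E[(x - mu_k)^T Lam_k (x - mu_k)] = D/beta_k + nu_k * d^T W_k d
        E_quad = D / beta[k] + nu[k] * np.einsum('ni,ij,nj->n',
                                                 diff, W[k], diff)
        ln_rho[:, k] = (E_ln_pi[k] + 0.5 * E_ln_det
                        - 0.5 * D * np.log(2 * np.pi) - 0.5 * E_quad)
    # normalize rows via log-sum-exp, for numerical stability
    ln_rho -= ln_rho.max(axis=1, keepdims=True)
    r = np.exp(ln_rho)
    return r / r.sum(axis=1, keepdims=True)
```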


update \(q(\boldsymbol{\pi}, \boldsymbol{\mu}, \boldsymbol{\Lambda})\)

\(\begin{array}{l} \ln q^{\star}(\pi, \mu, \Lambda)=\ln p(\pi)+\sum_{k=1}^{K} \ln p\left(\mu_{k}, \Lambda_{k}\right)+\mathbb{E}_{\mathbf{Z}}[\ln p(\mathbf{Z} \mid \pi)] +\sum_{k=1}^{K} \sum_{n=1}^{N} \mathbb{E}\left[z_{n k}\right] \ln \mathcal{N}\left(\mathbf{x}_{n} \mid \mu_{k}, \Lambda_{k}^{-1}\right)+\text {const. } \end{array}\).

  • the RHS decomposes into a term involving only \(\boldsymbol{\pi}\) plus terms involving each \(\left(\boldsymbol{\mu}_{k}, \boldsymbol{\Lambda}_{k}\right)\), so the factor itself factorizes : \(q(\boldsymbol{\pi}, \boldsymbol{\mu}, \boldsymbol{\Lambda})=q(\boldsymbol{\pi}) \prod_{k=1}^{K} q\left(\boldsymbol{\mu}_{k}, \boldsymbol{\Lambda}_{k}\right)\)


\(q(\boldsymbol{\pi})\) : Dirichlet

  • \(\ln q^{*}(\pi)=\left(\alpha_{0}-1\right) \sum_{k=1}^{K} \ln \pi_{k}+\sum_{k=1}^{K} \sum_{n=1}^{N} r_{n k} \ln \pi_{k}+\text { const }\).
  • \(q^{\star}(\pi)=\operatorname{Dir}(\pi \mid \alpha)\).
    • where \(\alpha_{k}=\alpha_{0}+N_{k}\), with \(N_{k}=\sum_{n=1}^{N} r_{n k}\)


\(q\left(\boldsymbol{\mu}_{k}, \boldsymbol{\Lambda}_{k}\right)\) : Gaussian-Wishart

  • \(q^{\star}\left(\boldsymbol{\mu}_{k}, \boldsymbol{\Lambda}_{k}\right)=\mathcal{N}\left(\boldsymbol{\mu}_{k} \mid \mathbf{m}_{k},\left(\beta_{k} \boldsymbol{\Lambda}_{k}\right)^{-1}\right) \mathcal{W}\left(\boldsymbol{\Lambda}_{k} \mid \mathbf{W}_{k}, \nu_{k}\right)\).

  • where \(\overline{\mathbf{x}}_{k}=\frac{1}{N_{k}} \sum_{n=1}^{N} r_{n k} \mathbf{x}_{n}\) and \(\mathbf{S}_{k}=\frac{1}{N_{k}} \sum_{n=1}^{N} r_{n k}\left(\mathbf{x}_{n}-\overline{\mathbf{x}}_{k}\right)\left(\mathbf{x}_{n}-\overline{\mathbf{x}}_{k}\right)^{\mathrm{T}}\), and

    \(\begin{aligned} \beta_{k} &=\beta_{0}+N_{k} \\ \mathbf{m}_{k} &=\frac{1}{\beta_{k}}\left(\beta_{0} \mathbf{m}_{0}+N_{k} \overline{\mathbf{x}}_{k}\right) \\ \mathbf{W}_{k}^{-1} &=\mathbf{W}_{0}^{-1}+N_{k} \mathbf{S}_{k}+\frac{\beta_{0} N_{k}}{\beta_{0}+N_{k}}\left(\overline{\mathbf{x}}_{k}-\mathbf{m}_{0}\right)\left(\overline{\mathbf{x}}_{k}-\mathbf{m}_{0}\right)^{\mathrm{T}} \\ \nu_{k} &=\nu_{0}+N_{k} \end{aligned}\).


These updating equations are analogous to the “M-step” of EM ( while the update of \(q(\mathbf{Z})\), which computes the responsibilities \(r_{nk}\), is analogous to the “E-step” ); a sketch follows below
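
A sketch of these M-step-like updates as code, assuming responsibilities \(r\) from the \(q(\mathbf{Z})\) update ( variable names follow the text ) :

```python
import numpy as np

def update_q_pi_mu_Lam(X, r, alpha0, beta0, m0, W0_inv, nu0):
    """M-step-like update : parameters of q(pi) and q(mu_k, Lam_k)."""
    N, D = X.shape
    K = r.shape[1]
    Nk   = r.sum(axis=0)                           # N_k = sum_n r_nk
    xbar = (r.T @ X) / Nk[:, None]                 # xbar_k, shape (K, D)

    alpha = alpha0 + Nk                            # Dirichlet : alpha_k
    beta  = beta0 + Nk
    m     = (beta0 * m0 + Nk[:, None] * xbar) / beta[:, None]
    nu    = nu0 + Nk

    W = []
    for k in range(K):
        diff = X - xbar[k]
        Sk = (r[:, k, None] * diff).T @ diff / Nk[k]       # S_k
        dm = (xbar[k] - m0)[:, None]
        Wk_inv = (W0_inv + Nk[k] * Sk
                  + beta0 * Nk[k] / (beta0 + Nk[k]) * (dm @ dm.T))
        W.append(np.linalg.inv(Wk_inv))
    return alpha, beta, m, W, nu
```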