Importance Weighted Hierarchical Variational Inference (NeurIPS 2019)


How expressive can the variational family be?

\(\rightarrow\) Its expressivity is limited by the requirement of having a tractable density function.

The paper introduces a new family of variational upper bounds on a log marginal density for “hierarchical models”.

  • Previously known methods arise as special cases: Hierarchical Variational Models, Semi-Implicit Variational Inference, Doubly Semi-Implicit Variational Inference

1. Introduction

The efficiency and accuracy of VI depend on how close the approximate posterior is to the true posterior.

This paper considers “hierarchical variational models”

  • where \(q(z \mid x)\) is represented as a mixture of tractable distributions \(q(z \mid \psi, x)\) over some tractable mixing distribution \(q(\psi \mid x)\): \(q(z \mid x) = \int q(z \mid \psi,x) q(\psi \mid x) d \psi\).

  • this class of variational models includes semi-implicit models

2. Background

Hierarchical model: \(p_{\theta}(x)=\int p_{\theta}(x \mid z) p_{\theta}(z) d z\)

  • 2 tasks: inference (approximating the posterior \(p_{\theta}(z \mid x)\)) and learning (fitting \(\theta\))

Variational Inference :

  • \(\log p_{\theta}(x) \geq \log p_{\theta}(x)-D_{K L}\left(q_{\phi}(z \mid x) \mid \mid p_{\theta}(z \mid x)\right)=\underset{q_{\phi}(z \mid x)}{\mathbb{E}} \log \frac{p_{\theta}(x, z)}{q_{\phi}(z \mid x)}\).

    requires analytically tractable densities for both \(q_{\phi}(z \mid x)\) and \(p_{\theta}(x, z)\)

  • variational bias can be reduced by tightening the bound

  • IWAE bound :

    • \(\log p_{\theta}(x) \geq \underset{q_{\phi}\left(z_{1: M} \mid x\right)}{\mathbb{E}} \log \frac{1}{M} \sum_{m=1}^{M} \frac{p_{\theta}\left(x, z_{m}\right)}{q_{\phi}\left(z_{m} \mid x\right)}\).
  • the price of this increased tightness is higher computational complexity, which mostly stems from the increased number of evaluations of the high-dimensional decoder \(p_{\theta}(x \mid z)\)
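
To make the ELBO vs. IWAE comparison concrete, here is a minimal NumPy sketch on a toy model that is not from the paper: \(p(z)=\mathcal{N}(0,1)\), \(p(x \mid z)=\mathcal{N}(z,1)\), so that \(\log p(x)\) is known exactly, with a deliberately mismatched \(q(z \mid x)=\mathcal{N}(0,1)\). The helper names (`log_normal`, `log_mean_exp`, `iwae_bound`) are illustrative assumptions.

```python
import numpy as np

def log_normal(v, mean, std):
    """Log density of N(mean, std^2) evaluated at v."""
    return -0.5 * np.log(2 * np.pi) - np.log(std) - 0.5 * ((v - mean) / std) ** 2

def log_mean_exp(a, axis):
    """Numerically stable log(mean(exp(a))) along one axis."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.mean(np.exp(a - m), axis=axis))

def iwae_bound(x, M, q_mean, q_std, n_rep=20000, seed=0):
    """Monte Carlo estimate of the M-sample IWAE bound (M=1 is the ELBO)."""
    rng = np.random.default_rng(seed)
    z = rng.normal(q_mean, q_std, size=(n_rep, M))   # z_m ~ q(z | x)
    log_w = (log_normal(z, 0.0, 1.0)                 # log p(z)   (toy prior)
             + log_normal(x, z, 1.0)                 # log p(x|z) (toy decoder)
             - log_normal(z, q_mean, q_std))         # - log q(z | x)
    return log_mean_exp(log_w, axis=1).mean()

x = 1.0
exact = log_normal(x, 0.0, np.sqrt(2.0))             # p(x) = N(0, 2) in this toy model
elbo = iwae_bound(x, M=1, q_mean=0.0, q_std=1.0)
iwae = iwae_bound(x, M=16, q_mean=0.0, q_std=1.0)
print(f"ELBO={elbo:.3f}  IWAE(M=16)={iwae:.3f}  log p(x)={exact:.3f}")
```

Each of the `M` importance weights needs one evaluation of \(p_{\theta}(x \mid z_m)\), which is where the extra cost mentioned above comes from.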

Hierarchical Variational Model (HVM)

  • \(q_{\phi}(z \mid x)=\int q_{\phi}(z \mid x, \psi) q_{\phi}(\psi \mid x) d \psi\).

    where \(\psi\) are auxiliary latent variables

  • Tightness is controlled by auxiliary variational distribution \(\tau_{\eta}(\psi \mid x, z)\)

    \(\log p_{\theta}(x) \geq \underset{q_{\phi}(z \mid x)}{\mathbb{E}} \log \frac{p_{\theta}(x, z)}{q_{\phi}(z \mid x)} \geq \underset{q_{\phi}(z, \psi \mid x)}{\mathbb{E}}\left[\log p_{\theta}(x, z)-\log \frac{q_{\phi}(z, \psi \mid x)}{\tau_{\eta}(\psi \mid x, z)}\right]\).

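But this introduces an additional auxiliary variational bias ( = gap 2 )

  • (gap 1) : between the log marginal likelihood \(\log p_{\theta}(x)\) and the standard ELBO (first inequality)
  • (gap 2) : between the standard ELBO and the HVM bound (second inequality)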
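
Gap 2 has a closed form, which shows why \(\tau_{\eta}\) controls the tightness. Subtracting the HVM bound from the standard ELBO and using \(q_{\phi}(z, \psi \mid x)=q_{\phi}(z \mid x) q_{\phi}(\psi \mid z, x)\) gives:

\(\underset{q_{\phi}(z \mid x)}{\mathbb{E}} \log \frac{p_{\theta}(x, z)}{q_{\phi}(z \mid x)}-\underset{q_{\phi}(z, \psi \mid x)}{\mathbb{E}}\left[\log p_{\theta}(x, z)-\log \frac{q_{\phi}(z, \psi \mid x)}{\tau_{\eta}(\psi \mid x, z)}\right]=\underset{q_{\phi}(z \mid x)}{\mathbb{E}} D_{K L}\left(q_{\phi}(\psi \mid z, x) \mid \mid \tau_{\eta}(\psi \mid x, z)\right) \geq 0\).

So gap 2 vanishes exactly when \(\tau_{\eta}(\psi \mid x, z)\) matches the (generally intractable) posterior \(q_{\phi}(\psi \mid z, x)\).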

Semi-implicit models

  • hierarchical models \(q_{\phi}(z \mid x)=\int q_{\phi}(z \mid \psi, x) q_{\phi}(\psi \mid x) d \psi\) with

    • implicit but reparameterizable \(q_{\phi}(\psi \mid x)\)
    • explicit \(q_{\phi}(z \mid \psi, x)\)
  • SIVI bound

    \(\log p_{\theta}(x) \geq \underset{q_{\phi}\left(z, \psi_{0} \mid x\right) q_{\phi}\left(\psi_{1: K} \mid x\right)}{\mathbb{E}} \log \frac{p_{\theta}(x, z)}{\frac{1}{K+1} \sum_{k=0}^{K} q_{\phi}\left(z \mid \psi_{k}, x\right)}\).

    ( gets tighter as the number of samples \(K\) increases )

  • SIVI can be generalized to use multiple samples \(z\) similar to the IWAE bound

    \(\log p(x) \geq \mathbb{E}\left[\log \frac{1}{M} \sum_{m=1}^{M} \frac{p_{\theta}\left(x, z_{m}\right)}{\frac{1}{K+1} \sum_{k=0}^{K} q_{\phi}\left(z_{m} \mid x, \psi_{m, k}\right)}\right]\).
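
The key ingredient of the SIVI bound is that the \((K+1)\)-sample mixture in the denominator upper-bounds \(\log q_{\phi}(z \mid x)\) in expectation. Below is a small NumPy sketch of that effect on a toy hierarchy chosen for illustration (not from the paper): \(q(\psi)=\mathcal{N}(0,1)\), \(q(z \mid \psi)=\mathcal{N}(\psi, 0.5^2)\), so the marginal \(q(z)=\mathcal{N}(0, 1.25)\) and \(\mathbb{E} \log q(z)\) are available in closed form. The conditioning on \(x\) is dropped and all function names are assumptions.

```python
import numpy as np

S_Z = 0.5  # std of the explicit conditional q(z | psi); toy value

def log_normal(v, mean, std):
    """Log density of N(mean, std^2) evaluated at v."""
    return -0.5 * np.log(2 * np.pi) - np.log(std) - 0.5 * ((v - mean) / std) ** 2

def log_mean_exp(a, axis):
    """Numerically stable log(mean(exp(a))) along one axis."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.mean(np.exp(a - m), axis=axis))

def sivi_log_q_upper(K, n_rep=20000, seed=0):
    """Estimate E[log 1/(K+1) sum_k q(z | psi_k)], an upper bound on E[log q(z)]."""
    rng = np.random.default_rng(seed)
    psi = rng.normal(0.0, 1.0, size=(n_rep, K + 1))  # psi_0, ..., psi_K ~ q(psi)
    z = rng.normal(psi[:, 0], S_Z)                   # z ~ q(z | psi_0)
    log_cond = log_normal(z[:, None], psi, S_Z)      # log q(z | psi_k) for every k
    return log_mean_exp(log_cond, axis=1).mean()

# Closed form: E[log q(z)] is minus the entropy of N(0, 1 + S_Z^2).
exact = -0.5 * np.log(2 * np.pi * np.e * (1.0 + S_Z ** 2))
ub0 = sivi_log_q_upper(K=0)
ub50 = sivi_log_q_upper(K=50)
print(f"K=0: {ub0:.3f}  K=50: {ub50:.3f}  exact E[log q(z)]: {exact:.3f}")
```

An upper bound on \(\log q_{\phi}(z \mid x)\) in the denominator translates into a lower bound on the ELBO, and the estimate approaches the exact value from above as \(K\) grows.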

2.1 SIVI Insights

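Weaknesses of SIVI and how to generalize the method

SIVI bounds

  • (1) use samples from \(q_{\phi}\left(\psi_{1: K} \mid x\right)\) to describe \(z\), even though these samples are drawn without looking at \(z\)

    \(\rightarrow\) in higher dimensions, it would take many samples to cover the values of \(\psi\) that are relevant for a given \(z\)

  • (2) many such semi-implicit models can be equivalently reformulated as a mixture of 2 explicit distributions

    • due to reparameterizability!

    • (before) implicit \(q_{\phi}(\psi \mid x)\)

      (after) \(\psi=g_{\phi}(\varepsilon \mid x)\) for some \(\varepsilon \sim q(\varepsilon)\) with an explicit density, so that \(q_{\phi}(z \mid x)=\int q_{\phi}\left(z \mid g_{\phi}(\varepsilon \mid x), x\right) q(\varepsilon) d \varepsilon\)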

3. Importance Weighted Hierarchical Variational Inference


3.1 Tractable Lower Bounds on Log Marginal Likelihood with a Hierarchical Proposal

IWHVI (Importance Weighted Hierarchical Variational Inference)

\(\log p_{\theta}(x) \geq \underset{q_{\phi}(z \mid x)}{\mathbb{E}} \log \frac{p_{\theta}(x, z)}{q_{\phi}(z \mid x)} \geq \underset{q_{\phi}\left(z, \psi_{0} \mid x\right) \tau_{\eta}\left(\psi_{1: K} \mid z, x\right)}{\mathbb{E}} \log \frac{p_{\theta}(x, z)}{\frac{1}{K+1} \sum_{k=0}^{K} \frac{q_{\phi}\left(z, \psi_{k} \mid x\right)}{\tau_{\eta}\left(\psi_{k} \mid z, x\right)}}\).

  • introduces an additional auxiliary variational distribution \(\tau_{\eta}(\psi \mid x, z)\)

    ( optimal distribution : \(\tau(\psi \mid z, x)=q(\psi \mid z, x)\), for which the second inequality becomes an equality )
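
The effect of the choice of \(\tau\) can be checked numerically. The sketch below reuses a toy Gaussian hierarchy that is an assumption for illustration, not the paper's setup: \(q(\psi)=\mathcal{N}(0,1)\), \(q(z \mid \psi)=\mathcal{N}(\psi, 0.5^2)\), with \(x\) dropped. In this toy case the posterior \(q(\psi \mid z)\) is Gaussian and known, so the optimal \(\tau\) can be plugged in directly. Setting \(\tau\) to the mixing distribution \(q(\psi)\) makes every term reduce to \(q(z \mid \psi_k)\), i.e. the SIVI denominator.

```python
import numpy as np

S_Z = 0.5                                   # std of q(z | psi); toy value
POST_SLOPE = 1.0 / (1.0 + S_Z ** 2)         # mean of q(psi | z) is POST_SLOPE * z
POST_STD = np.sqrt(S_Z ** 2 / (1.0 + S_Z ** 2))

def log_normal(v, mean, std):
    """Log density of N(mean, std^2) evaluated at v."""
    return -0.5 * np.log(2 * np.pi) - np.log(std) - 0.5 * ((v - mean) / std) ** 2

def log_mean_exp(a, axis):
    """Numerically stable log(mean(exp(a))) along one axis."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.mean(np.exp(a - m), axis=axis))

def iwhvi_log_q_upper(K, tau, n_rep=20000, seed=0):
    """Estimate E[log 1/(K+1) sum_k q(z, psi_k) / tau(psi_k | z)].

    tau = (slope, std) is a hypothetical parameterization: tau(psi | z) = N(slope * z, std^2).
    """
    slope, tau_std = tau
    rng = np.random.default_rng(seed)
    psi0 = rng.normal(0.0, 1.0, size=n_rep)                     # psi_0 ~ q(psi)
    z = rng.normal(psi0, S_Z)                                   # z ~ q(z | psi_0)
    psi_rest = rng.normal((slope * z)[:, None], tau_std, size=(n_rep, K))  # psi_1..K ~ tau
    psi = np.concatenate([psi0[:, None], psi_rest], axis=1)
    zc = z[:, None]
    log_ratio = (log_normal(psi, 0.0, 1.0) + log_normal(zc, psi, S_Z)  # log q(z, psi_k)
                 - log_normal(psi, slope * zc, tau_std))               # - log tau(psi_k | z)
    return log_mean_exp(log_ratio, axis=1).mean()

exact = -0.5 * np.log(2 * np.pi * np.e * (1.0 + S_Z ** 2))     # exact E[log q(z)]
loose = iwhvi_log_q_upper(K=1, tau=(0.0, 1.0))                  # tau = q(psi), SIVI-like
tight = iwhvi_log_q_upper(K=1, tau=(POST_SLOPE, POST_STD))      # tau = q(psi | z), optimal
print(f"tau=q(psi): {loose:.3f}  tau=q(psi|z): {tight:.3f}  exact: {exact:.3f}")
```

With the optimal \(\tau\), each ratio \(q(z, \psi_k) / \tau(\psi_k \mid z)\) equals \(q(z)\) exactly, so the estimate matches \(\mathbb{E} \log q(z)\) up to Monte Carlo noise for any \(K\).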

4. Multisample Extension

4.1 Multisample Bound and Complexity

Generalize the bound to multiple samples of \(z\), as in IWAE:

Doubly Importance Weighted Hierarchical Variational Inference (DIWHVI)

\(\log p_{\theta}(x) \geq \mathbb{E}\left[\log \frac{1}{M} \sum_{m=1}^{M} \frac{p_{\theta}\left(x, z_{m}\right)}{\frac{1}{K+1} \sum_{k=0}^{K} \frac{q_{\phi}\left(z_{m}, \psi_{m, k} \mid x\right)}{\tau_{\eta}\left(\psi_{m, k} \mid z_{m}, x\right)}}\right]\).

The expectation is taken over samples generated as follows:

  1. Sample \(\psi_{m, 0} \sim q_{\phi}(\psi \mid x)\) for \(1 \leq m \leq M\)
  2. Sample \(z_{m} \sim q_{\phi}\left(z \mid x, \psi_{m, 0}\right)\) for \(1 \leq m \leq M\)
  3. Sample \(\psi_{m, k} \sim \tau_{\eta}\left(\psi \mid z_{m}, x\right)\) for \(1 \leq m \leq M\) and \(1 \leq k \leq K\)
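
The three sampling steps and the doubly importance weighted bound can be sketched end to end in NumPy. Everything concrete here is an assumption made for illustration: a toy model \(p(z)=\mathcal{N}(0,1)\), \(p(x \mid z)=\mathcal{N}(z,1)\) with known \(\log p(x)\), a Gaussian hierarchy for \(q\), a deliberately too-wide Gaussian \(\tau\), and the constants `A`, `S_PSI`, `S_Z`, `TAU_STD`.

```python
import numpy as np

A, S_PSI, S_Z, TAU_STD = 0.2, 0.6, 0.6, 0.6   # toy constants, not from the paper

def log_normal(v, mean, std):
    """Log density of N(mean, std^2) evaluated at v."""
    return -0.5 * np.log(2 * np.pi) - np.log(std) - 0.5 * ((v - mean) / std) ** 2

def log_mean_exp(a, axis):
    """Numerically stable log(mean(exp(a))) along one axis."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.mean(np.exp(a - m), axis=axis))

def diwhvi_bound(x, M, K, n_rep=20000, seed=0):
    """Monte Carlo estimate of the doubly importance weighted bound on log p(x)."""
    rng = np.random.default_rng(seed)
    # Step 1: psi_{m,0} ~ q(psi | x) = N(A * x, S_PSI^2)
    psi0 = rng.normal(A * x, S_PSI, size=(n_rep, M))
    # Step 2: z_m ~ q(z | x, psi_{m,0}) = N(psi_{m,0}, S_Z^2)
    z = rng.normal(psi0, S_Z)
    # Step 3: psi_{m,k} ~ tau(psi | z_m, x), k = 1..K (correct mean, too-wide std)
    tau_mean = 0.5 * (A * x + z)
    psi_rest = rng.normal(tau_mean[..., None], TAU_STD, size=(n_rep, M, K))
    psi = np.concatenate([psi0[..., None], psi_rest], axis=-1)   # (n_rep, M, K+1)
    zc = z[..., None]
    # Inner average: upper bound on log q(z_m | x)
    log_ratio = (log_normal(psi, A * x, S_PSI) + log_normal(zc, psi, S_Z)
                 - log_normal(psi, 0.5 * (A * x + zc), TAU_STD))
    log_q_hat = log_mean_exp(log_ratio, axis=-1)                 # (n_rep, M)
    # Outer average: IWAE-style average over the M samples z_m
    log_p = log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)      # log p(x, z_m)
    return log_mean_exp(log_p - log_q_hat, axis=-1).mean()

x = 2.0
exact = log_normal(x, 0.0, np.sqrt(2.0))      # p(x) = N(0, 2) in this toy model
b_small = diwhvi_bound(x, M=1, K=1)
b_big = diwhvi_bound(x, M=8, K=8)
print(f"M=1,K=1: {b_small:.3f}  M=8,K=8: {b_big:.3f}  log p(x): {exact:.3f}")
```

Note the cost structure in this sketch: the decoder term `log_p` is evaluated `M` times per data point, while the variational densities are evaluated `M * (K + 1)` times.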

5. Conclusion

The paper presents a multisample variational upper bound on the log marginal density.

This bound yields tight, tractable lower bounds on the otherwise intractable ELBO when the approximate posterior \(q_{\phi}(z \mid x)\) is an HVM.

Combining this bound with the multisample IWAE bound leads to a tighter lower bound on the log marginal likelihood.

