Importance Weighted Hierarchical Variational Inference (NeurIPS 2019)
Abstract
Expressivity of the variational family?
\(\rightarrow\) limited by the requirement of having a tractable density function.
Introduces a new family of variational upper bounds on a log marginal density in the case of “hierarchical models”
- ex) Hierarchical Variational Models, Semi-Implicit Variational Inference, Doubly Semi-Implicit Variational Inference
1. Introduction
Efficiency & accuracy of VI depend on how close the approximate posterior is to the true posterior!
This paper considers “hierarchical variational models”
- where \(q(z \mid x)\) is represented as a mixture of tractable distributions \(q(z \mid \psi, x)\) over some mixing distribution \(q(\psi \mid x)\): \(q(z \mid x) = \int q(z \mid \psi,x)\, q(\psi \mid x)\, d \psi\)
- such variational models contain semi-implicit models as a special case
2. Background
hierarchical model : \(p_{\theta}(x)=\int p_{\theta}(x \mid z) p_{\theta}(z) d z\)
- two tasks: inference (approximating the posterior \(p_{\theta}(z \mid x)\)) and learning (maximizing \(\log p_{\theta}(x)\) over \(\theta\))
Variational Inference:
- \(\log p_{\theta}(x) \geq \log p_{\theta}(x)-D_{KL}\left(q_{\phi}(z \mid x) \,\|\, p_{\theta}(z \mid x)\right)=\underset{q_{\phi}(z \mid x)}{\mathbb{E}} \log \frac{p_{\theta}(x, z)}{q_{\phi}(z \mid x)}\)
- estimating this bound (the ELBO) requires analytically tractable densities for both \(q_{\phi}(z \mid x)\) and \(p_{\theta}(x, z)\)
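As a minimal sketch of a single-sample ELBO estimator, assuming an illustrative 1-D toy model \(p(z)=\mathcal{N}(0,1)\), \(p(x \mid z)=\mathcal{N}(z,1)\) and Gaussian proposal \(q(z \mid x)=\mathcal{N}(\mu,\sigma^2)\) (all names hypothetical, not from the paper):

```python
# Single-sample ELBO estimate on an illustrative 1-D Gaussian model.
import numpy as np
from scipy.stats import norm

def elbo_estimate(x, mu, sigma, rng):
    z = mu + sigma * rng.standard_normal()                   # z ~ q(z | x), reparameterized
    log_joint = norm.logpdf(z, 0, 1) + norm.logpdf(x, z, 1)  # log p(x, z)
    return log_joint - norm.logpdf(z, mu, sigma)             # log p(x, z) - log q(z | x)

rng = np.random.default_rng(0)
print(np.mean([elbo_estimate(1.0, 0.5, 0.8, rng) for _ in range(10_000)]))
```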
- variational bias can be reduced by tightening the bound
IWAE bound:
- \(\log p_{\theta}(x) \geq \underset{q_{\phi}\left(z_{1: M} \mid x\right)}{\mathbb{E}} \log \frac{1}{M} \sum_{m=1}^{M} \frac{p_{\theta}\left(x, z_{m}\right)}{q_{\phi}\left(z_{m} \mid x\right)}\)
- the price of this increased tightness is higher computational complexity, which mostly stems from the increased number of evaluations of the high-dimensional decoder \(p_{\theta}(x \mid z)\)
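A corresponding sketch of the \(M\)-sample IWAE estimator on the same illustrative toy model (not the paper's code):

```python
# M-sample IWAE bound on the same illustrative 1-D model as above.
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

def iwae_estimate(x, mu, sigma, M, rng):
    z = mu + sigma * rng.standard_normal(M)  # z_1..z_M ~ q(z | x)
    log_w = (norm.logpdf(z, 0, 1)            # log p(z_m)
             + norm.logpdf(x, z, 1)          # + log p(x | z_m)
             - norm.logpdf(z, mu, sigma))    # - log q(z_m | x)
    return logsumexp(log_w) - np.log(M)      # log (1/M) sum_m p(x, z_m) / q(z_m | x)
```

With \(M=1\) this reduces to the ELBO estimator above; larger \(M\) trades extra decoder evaluations for a tighter bound.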
Hierarchical Variational Model (HVM)
- \(q_{\phi}(z \mid x)=\int q_{\phi}(z \mid x, \psi)\, q_{\phi}(\psi \mid x)\, d \psi\)
- where \(\psi\) are auxiliary latent variables
- tightness is controlled by an auxiliary variational distribution \(\tau_{\eta}(\psi \mid x, z)\):
\(\log p_{\theta}(x) \geq \underset{q_{\phi}(z \mid x)}{\mathbb{E}} \log \frac{p_{\theta}(x, z)}{q_{\phi}(z \mid x)} \geq \underset{q_{\phi}(z, \psi \mid x)}{\mathbb{E}}\left[\log p_{\theta}(x, z)-\log \frac{q_{\phi}(z, \psi \mid x)}{\tau_{\eta}(\psi \mid x, z)}\right]\).
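A minimal sketch of a single-sample estimator of this auxiliary bound, assuming illustrative Gaussian choices \(q(\psi \mid x)=\mathcal{N}(0,1)\), \(q(z \mid \psi, x)=\mathcal{N}(\psi, s^2)\) and \(\tau(\psi \mid x, z)=\mathcal{N}(a z, t^2)\) (the knobs \(a, t\) are hypothetical):

```python
# Single-sample estimate of the HVM auxiliary bound under illustrative
# Gaussian choices for q(psi | x), q(z | psi, x) and tau(psi | x, z).
import numpy as np
from scipy.stats import norm

def hvm_bound_estimate(x, s, a, t, rng):
    psi = rng.standard_normal()                                    # psi ~ q(psi | x)
    z = psi + s * rng.standard_normal()                            # z ~ q(z | psi, x)
    log_joint = norm.logpdf(z, 0, 1) + norm.logpdf(x, z, 1)        # log p(x, z)
    log_q_joint = norm.logpdf(psi, 0, 1) + norm.logpdf(z, psi, s)  # log q(z, psi | x)
    return log_joint - log_q_joint + norm.logpdf(psi, a * z, t)    # + log tau(psi | x, z)
```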
But this brings in another, auxiliary variational bias (= gap 2)
- (gap 1): between the log marginal density and ELBO (1)
- (gap 2): between ELBO (1) and the auxiliary bound (2)
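Where gap 2 comes from, as a one-line check (standard auxiliary-variable algebra, not specific to this paper): subtracting bound (2) from ELBO (1) leaves

\( \underset{q_{\phi}(z \mid x)}{\mathbb{E}}\, D_{KL}\left(q_{\phi}(\psi \mid x, z) \,\|\, \tau_{\eta}(\psi \mid x, z)\right) \geq 0, \)

which vanishes exactly when \(\tau_{\eta}(\psi \mid x, z)\) matches the proposal's own posterior \(q_{\phi}(\psi \mid x, z)\).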
Semi-implicit models
- hierarchical models \(q_{\phi}(z \mid x)=\int q_{\phi}(z \mid \psi, x)\, q_{\phi}(\psi \mid x)\, d \psi\), with:
- implicit but reparameterizable \(q_{\phi}(\psi \mid x)\)
- explicit \(q_{\phi}(z \mid \psi, x)\)
- SIVI bound:
\(\log p_{\theta}(x) \geq \underset{q_{\phi}\left(z, \psi_{0} \mid x\right) q_{\phi}\left(\psi_{1: K} \mid x\right)}{\mathbb{E}} \log \frac{p_{\theta}(x, z)}{\frac{1}{K+1} \sum_{k=0}^{K} q_{\phi}\left(z \mid \psi_{k}, x\right)}\).
( gets tighter as the number of samples \(K\) increases )
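A minimal sketch of the SIVI estimator, assuming an illustrative semi-implicit toy proposal \(q(\psi \mid x)=\mathcal{N}(0,1)\), \(q(z \mid \psi, x)=\mathcal{N}(\psi, s^2)\) over the same 1-D model (hypothetical names throughout):

```python
# SIVI bound with K extra mixing samples on an illustrative toy proposal.
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

def sivi_estimate(x, s, K, rng):
    psi = rng.standard_normal(K + 1)        # psi_0 (used to draw z) and psi_1..psi_K ~ q(psi | x)
    z = psi[0] + s * rng.standard_normal()  # z ~ q(z | psi_0, x)
    log_joint = norm.logpdf(z, 0, 1) + norm.logpdf(x, z, 1)        # log p(x, z)
    log_q_mix = logsumexp(norm.logpdf(z, psi, s)) - np.log(K + 1)  # log (1/(K+1)) sum_k q(z | psi_k, x)
    return log_joint - log_q_mix
```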
- SIVI can be generalized to use multiple samples of \(z\), similar to the IWAE bound:
\(\log p(x) \geq \mathbb{E}\left[\log \frac{1}{M} \sum_{m=1}^{M} \frac{p_{\theta}\left(x, z_{m}\right)}{\frac{1}{K+1} \sum_{k=0}^{K} q_{\phi}\left(z_{m} \mid x, \psi_{m, k}\right)}\right]\).
2.1 SIVI Insights
SIVI’s points of weakness & how to generalize the method
Weaknesses of the SIVI bound:
- (1) it uses samples from \(q_{\phi}\left(\psi_{1: K} \mid x\right)\) to describe \(z\)
\(\rightarrow\) in higher dimensions this takes many samples, since generic samples of \(\psi\) rarely explain any particular \(z\) well
- (2) many such semi-implicit models can be equivalently reformulated as a mixture of two explicit distributions
- due to reparameterizability!
- (before) implicit \(q_{\phi}(\psi \mid x)\)
- (after) \(\psi=g_{\phi}(\varepsilon, x)\) for some \(\varepsilon \sim q(\varepsilon)\) with an explicit density \(q(\varepsilon)\)
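A hedged illustration of point (2), with \(\mathrm{NN}_{\phi}\) as a placeholder pushforward network (not the paper's notation):

\( \varepsilon \sim \mathcal{N}(0, I), \quad \psi=\mathrm{NN}_{\phi}(\varepsilon, x) \quad \Rightarrow \quad q_{\phi}(z \mid x)=\int q_{\phi}\left(z \mid \mathrm{NN}_{\phi}(\varepsilon, x), x\right) \mathcal{N}(\varepsilon \mid 0, I)\, d \varepsilon \)

so the implicit mixture over \(\psi\) becomes a mixture over \(\varepsilon\) in which both the mixing density and the conditional are explicit.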
3. Importance Weighted Hierarchical Variational Inference
3.1 Tractable Lower Bounds on the Log Marginal Likelihood with a Hierarchical Proposal
IWHVI (Importance Weighted Hierarchical Variational Inference)
\(\log p_{\theta}(x) \geq \underset{q_{\phi}(z \mid x)}{\mathbb{E}} \log \frac{p_{\theta}(x, z)}{q_{\phi}(z \mid x)} \geq \underset{q_{\phi}\left(z, \psi_{0} \mid x\right) \tau_{\eta}\left(\psi_{1: K} \mid z, x\right)}{\mathbb{E}} \log \frac{p_{\theta}(x, z)}{\frac{1}{K+1} \sum_{k=0}^{K} \frac{q_{\phi}\left(z, \psi_{k} \mid x\right)}{\tau_{\eta}\left(\psi_{k} \mid z, x\right)}}\).
- introduces an additional auxiliary variational distribution \(\tau_{\eta}(\psi \mid x, z)\)
( optimal distribution: \(\tau(\psi \mid z, x)=q(\psi \mid z, x)\), in which case each ratio \(q_{\phi}(z, \psi_{k} \mid x) / \tau(\psi_{k} \mid z, x)\) equals \(q_{\phi}(z \mid x)\) exactly and the bound recovers the ELBO )
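A minimal sketch of the IWHVI estimator under the same illustrative Gaussian choices as before, with auxiliary \(\tau(\psi \mid z, x)=\mathcal{N}(a z, t^2)\) (the parameters \(a, t\) are hypothetical stand-ins for \(\eta\)):

```python
# IWHVI bound: K samples from tau(psi | z, x) plus the original psi_0,
# under illustrative Gaussian choices for q and tau.
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

def iwhvi_estimate(x, s, a, t, K, rng):
    psi0 = rng.standard_normal()            # psi_0 ~ q(psi | x)
    z = psi0 + s * rng.standard_normal()    # z ~ q(z | psi_0, x)
    psi = np.concatenate([[psi0], a * z + t * rng.standard_normal(K)])  # psi_1..psi_K ~ tau
    log_r = (norm.logpdf(psi, 0, 1)         # log q(psi_k | x)
             + norm.logpdf(z, psi, s)       # + log q(z | psi_k, x)
             - norm.logpdf(psi, a * z, t))  # - log tau(psi_k | z, x)
    log_q_bound = logsumexp(log_r) - np.log(K + 1)  # upper-bounds log q(z | x) in expectation
    return norm.logpdf(z, 0, 1) + norm.logpdf(x, z, 1) - log_q_bound   # log p(x, z) - bound
```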
4. Multisample Extension
4.1 Multisample Bound and Complexity
Generalize the bound!
Doubly Importance Weighted Hierarchical Variational Inference
\(\log p_{\theta}(x) \geq \mathbb{E}\left[\log \frac{1}{M} \sum_{m=1}^{M} \frac{p_{\theta}\left(x, z_{m}\right)}{\frac{1}{K+1} \sum_{k=0}^{K} \frac{q_{\phi}\left(z_{m}, \psi_{m, k} \mid x\right)}{\tau_{\eta}\left(\psi_{m, k} \mid z_{m}, x\right)}}\right]\).
- Sample \(\psi_{m, 0} \sim q_{\phi}(\psi \mid x)\) for \(1 \leq m \leq M\)
- Sample \(z_{m} \sim q_{\phi}\left(z \mid x, \psi_{m, 0}\right)\) for \(1 \leq m \leq M\)
- Sample \(\psi_{m, k} \sim \tau_{\eta}\left(\psi \mid z_{m}, x\right)\) for \(1 \leq m \leq M\) and \(1 \leq k \leq K\)
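A minimal sketch combining the two levels (outer \(M\)-sample IWAE average, inner \(K\)-sample IWHVI bound), with the same illustrative Gaussian toy densities and hypothetical \(a, t\):

```python
# DIWHVI estimator: vectorized over M outer z-samples and K inner tau-samples.
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

def diwhvi_estimate(x, s, a, t, M, K, rng):
    psi0 = rng.standard_normal(M)          # psi_{m,0} ~ q(psi | x)
    z = psi0 + s * rng.standard_normal(M)  # z_m ~ q(z | x, psi_{m,0})
    psi = np.concatenate([psi0[:, None],   # column k = 0 keeps psi_{m,0}
                          a * z[:, None] + t * rng.standard_normal((M, K))], axis=1)
    log_r = (norm.logpdf(psi, 0, 1)                  # log q(psi_{m,k} | x)
             + norm.logpdf(z[:, None], psi, s)       # + log q(z_m | psi_{m,k}, x)
             - norm.logpdf(psi, a * z[:, None], t))  # - log tau(psi_{m,k} | z_m, x)
    inner = logsumexp(log_r, axis=1) - np.log(K + 1)          # inner IWHVI bound per m
    log_w = norm.logpdf(z, 0, 1) + norm.logpdf(x, z, 1) - inner
    return logsumexp(log_w) - np.log(M)                       # outer IWAE average
```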
5. Conclusion
presents a multisample variational upper bound on the log marginal density \(q_{\phi}(z \mid x)\) of a hierarchical proposal
this allows tight, tractable lower bounds on the otherwise intractable ELBO in the case of an HVM \(q_{\phi}(z \mid x)\)
combining this bound with the multisample IWAE bound then leads to an even tighter lower bound on the log marginal likelihood