Improving Explorability in Variational Inference with Annealed Variational Objectives (NeurIPS 2018)

Abstract

Demonstrate the drawbacks of biasing the true posterior to be unimodal

Introduce Annealed Variational Objectives (AVO) into the training of hierarchical variational methods


1. Introduction

Variational approximations tend not to propagate uncertainty well…

\(\rightarrow\) can result in biased statistics of certain features of the unobserved variables


This bias is caused by the “variational family” that is used

  • ex) factorial Gaussian
    • can be alleviated by a richer family of distn
    • but this makes optimization more challenging


Variational Free Energy \(F\) :

\(\mathcal{F}(q)=\mathbb{E}_{q}[\log q(z)-\log f(z)]=\mathcal{D}_{\mathrm{KL}}(q \mid \mid f)\).

  • due to KL-term … \(q\) gets penalized for allocating probability mass where \(f\) is low

  • biased towards being excessively confident

    \(\rightarrow\) inhibits the variational approximation from escaping poor local minima

    • especially when the target is “multi-modal”

      ( \(q\) might get stuck fitting only a subset of the modes; see the toy sketch below )
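
To make the zero-forcing behavior concrete, here is a small toy sketch (my own illustration, not from the paper; the bimodal target, the grid search, and all names are assumptions for the demo): fitting a single Gaussian \(q\) to a bimodal unnormalized \(f\) by minimizing a Monte Carlo estimate of \(\mathcal{D}_{\mathrm{KL}}(q \mid\mid f)\) locks onto one mode.

```python
# Toy illustration (not from the paper): the reverse KL D_KL(q || f) is zero-forcing.
# We fit a single Gaussian q to a bimodal target f by a naive grid search over
# (mu, sigma); the optimum locks onto one mode instead of covering both.
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)

def log_f(z):
    # unnormalized bimodal target: equal-weight mixture of N(-3, 0.5) and N(+3, 0.5)
    return np.logaddexp(norm.logpdf(z, -3.0, 0.5), norm.logpdf(z, 3.0, 0.5))

def reverse_kl(mu, sigma, n=20_000):
    # Monte Carlo estimate of E_q[log q(z) - log f(z)]
    z = rng.normal(mu, sigma, size=n)
    return np.mean(norm.logpdf(z, mu, sigma) - log_f(z))

candidates = [(mu, s) for mu in np.linspace(-4, 4, 33) for s in (0.5, 1.0, 2.0, 3.0)]
best_mu, best_sigma = min(candidates, key=lambda p: reverse_kl(*p))
print(best_mu, best_sigma)   # ends up near one mode (mu ~ +-3, small sigma),
                             # not a broad Gaussian covering both modes
```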


2 annealing techniques:

(1) Alpha-annealing

  • energy tempering

  • \(\mathbb{E}_{q}[\log q(z)-\alpha \log f(z)]\).

  • when \(\alpha\) is small … the energy landscape is smoothed out

    ( \(\alpha \log f(z)\) goes to zero everywhere as \(\alpha \rightarrow 0\), leaving only the entropy term )

  • may not be ideal in practice… ( the schedule typically requires many update steps, which is costly when the target posterior itself keeps changing during learning; see the sketch below )
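
Continuing the toy example above (again my own sketch, not the paper's code; the schedule values are arbitrary), evaluating the alpha-annealed loss \(\mathbb{E}_{q}[\log q(z)-\alpha \log f(z)]\) for a few \(\alpha\) values shows how a small \(\alpha\) favors a broad, high-entropy \(q\), while \(\alpha = 1\) recovers the mode-locked fit:

```python
# Minimal sketch (my own illustration) of the alpha-annealed objective
# E_q[log q(z) - alpha * log f(z)]: alpha is ramped from ~0 to 1, so early on the
# entropy term dominates and q is pushed to stay broad / exploratory.
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)

def log_f(z):                      # same toy bimodal target as in the previous sketch
    return np.logaddexp(norm.logpdf(z, -3.0, 0.5), norm.logpdf(z, 3.0, 0.5))

def annealed_loss(mu, sigma, alpha, n=20_000):
    z = rng.normal(mu, sigma, size=n)
    return np.mean(norm.logpdf(z, mu, sigma) - alpha * log_f(z))

for alpha in (0.01, 0.1, 0.5, 1.0):          # toy annealing schedule
    best = min(((m, s) for m in np.linspace(-4, 4, 33) for s in (0.5, 1.0, 2.0, 3.0)),
               key=lambda p: annealed_loss(*p, alpha))
    print(f"alpha={alpha:4.2f} -> best (mu, sigma) = {best}")
# small alpha favors large sigma (high entropy); alpha -> 1 recovers the usual
# zero-forcing fit to a single mode.
```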


(2) Beta-annealing

  • deterministic warm-up

  • applied to improve training of the generative model

  • of greater importance in the case of “hierarchical models”

  • \(\mathbb{E}_{q}[\beta(\log q(z)-\log p(z))-\log p(x \mid z)]\).

    • where \(\beta\) is annealed from 0 to 1

    • 1st term : over-regularizes the model… by forcing \(q\) to be like the prior \(p\) ( so annealing \(\beta\) up from 0 weakens this regularization early in training )

  • in conflict with energy tempering…

    ( energy tempering, unlike beta-annealing, is what allows the approximate posterior to EXPLORE many modes ! see the sketch below )
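
A minimal sketch of the beta-annealed objective on a toy, fixed generative model (my own setup: scalar \(x\), \(p(z)=\mathcal{N}(0,1)\), \(p(x \mid z)=\mathcal{N}(z, 0.5)\); none of this is from the paper). With \(\beta=0\) the approximate posterior collapses onto the likelihood peak; \(\beta=1\) pulls it back toward the prior:

```python
# Minimal sketch (my own toy setup) of the beta-annealed ELBO
# E_q[ beta * (log q(z) - log p(z)) - log p(x|z) ], with beta warmed up 0 -> 1.
# Small beta downweights the prior-contrastive (KL-like) term, so the latent code
# is fit to the data more aggressively early in training.
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
x_obs = 2.0                                   # a single observed scalar

def beta_annealed_loss(mu_q, sigma_q, beta, n=10_000):
    z = rng.normal(mu_q, sigma_q, size=n)     # samples from q(z|x)
    log_q = norm.logpdf(z, mu_q, sigma_q)
    log_prior = norm.logpdf(z, 0.0, 1.0)      # p(z) = N(0, 1)
    log_lik = norm.logpdf(x_obs, z, 0.5)      # p(x|z) = N(x; z, 0.5)
    return np.mean(beta * (log_q - log_prior) - log_lik)

for beta in (0.0, 0.25, 0.5, 1.0):            # warm-up schedule
    best = min(((m, s) for m in np.linspace(-1, 3, 41) for s in (0.1, 0.3, 0.5, 1.0)),
               key=lambda p: beta_annealed_loss(*p, beta))
    print(f"beta={beta:4.2f} -> best q: mu={best[0]:+.2f}, sigma={best[1]:.2f}")
# beta=0 lets q concentrate on the likelihood peak; beta=1 pulls it back toward the prior.
```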


First, we review a few techniques in VI & MCMC that tackle the expressiveness problem

and the optimization problem in inference


2. Related work

recent works in VI :

  • focused on “reducing representational bias”, especially in the setting of amortized VI ( e.g., the VAE )


Importance Weighted Autoencoder (IWAE)

  • uses multiple samples to evaluate the loss, which reduces the variational gap
  • can be expensive when the decoder is much more complex than the encoder


Jackknife Variational Inference

  • reduces the bias of the importance-weighted (IWAE) estimator


3. Background

Latent Variable Model

  • joint prob : \(p_{\theta}(x,z) = p(x \mid z)p(z)\)

  • direct maximization of the marginal likelihood is usually intractable

    ( \(\because\) \(\log p(x)=\log \int p(x, z) \mathrm{d} z\) )

  • thus, training is usually conducted by maximizing ……

    Expected Complete-data Log Likelihood ( ECLL ) over an auxiliary distn \(q\) : \(\max _{\theta} \mathbb{E}_{q(z)}\left[\log p_{\theta}(x, z)\right]\)


If exact inference is possible : set \(q(z)=p(z \mid x)\)

If exact inference is impossible : approximate the true posterior ( usually by MCMC or VI )

  • this approximation induces “bias”

    • \(\mathbb{E}_{q}\left[\log p_{\theta}(x, z)\right]=\log p_{\theta}(x)+\mathbb{E}_{q}\left[\log p_{\theta}(z \mid x)\right]\).

    • Maximizing the ECLL increases the marginal likelihood of the data,

      \(+\) while “biasing the true posterior” to be like the auxiliary distribution \(q\) ( maximizing the second term w.r.t. \(\theta\) decreases \(\mathcal{D}_{\mathrm{KL}}(q \mid \mid p_{\theta}(z \mid x))\) )


Due to the zero-forcing property of the KL…. \(q\) tends to be unimodal & more concentrated than the true posterior


This paper emphasizes that…

  • highly flexible parametric form of \(q\) can potentially alleviate this problem
  • but does not address the issue of finding the optimal parameters


Key points for approximate inference in practice

  • 1) true posterior is likely to be multi-modal

  • 2) biasing the posterior to be unimodal inhibits the model from learning true generative process of data

  • 3)

    • Beta annealing = addresses point 2 ( helps learning the generative model )

      ( by lowering the penalty of the prior contrastive term )

    • Alpha annealing = addresses point 1 ( encourages exploration of multiple modes )

      ( by lowering the penalty of the cross-entropy term )


3-1. Assumption of Variational Family

recent work : more expressive parametric form

  • ex) Hierarchical Variational Inference (HVI)


Hierarchical Variational Inference (HVI)

  • generic family of methods that subsume “discrete mixture proposals”, “auxiliary variable methods”, “normalizing flows”

  • use latent variable model as approximate posterior

    \[q\left(z_{T}\right)=\int q\left(z_{T}, z_{t<T}\right) \mathrm{d} z_{t<T}\]
    • thus, entropy term ( = \(-\mathbb{E}_{q\left(z_{T}\right)}\left[\log q\left(z_{T}\right)\right]\) ) is intractable… needs to be approximated!

      ( can lower bound this, by using “REVERSE NETWORK” , \(r\left(z_{t<T} \mid z_{T}\right)\) )

      \(\begin{aligned} -\mathbb{E}_{q\left(z_{T}\right)}\left[\log q\left(z_{T}\right)\right] & \geq-\mathbb{E}_{q\left(z_{T}\right)}\left[\log q\left(z_{T}\right)+\mathcal{D}_{\mathrm{KL}}\left(q\left(z_{t<T} \mid z_{T}\right) \mid \mid r\left(z_{t<T} \mid z_{T}\right)\right)\right] \\ &=-\mathbb{E}_{q\left(z_{T}, z_{t<T}\right)}\left[\log q\left(z_{T} \mid z_{t<T}\right) q\left(z_{t<T}\right)-\log r\left(z_{t<T} \mid z_{T}\right)\right] \end{aligned}\).

  • Variational Lower bound :

    \(\mathcal{L}(x) \doteq \mathbb{E}_{q\left(z_{T}, z_{t<T}\right)}\left[\log \frac{p\left(x, z_{T}\right) r\left(z_{t<T} \mid z_{T}\right)}{q\left(z_{T} \mid z_{t<T}\right) q\left(z_{t<T}\right)}\right]\).

  • \(q\left(z_{T}\right)\) : can be seen as an infinite mixture model!

When \(q\left(z_{T} \mid z_{t<T}\right)\) is deterministic and invertible, we can choose \(r\) to be its **inverse** function.

  • KL term would vanish

  • Entropy can be computed recursively via the change-of-variable formula ( see the sketch below ):

    \(q\left(z_{t}\right)=q\left(z_{t-1}\right)\left|\operatorname{det} \frac{\partial z_{t}}{\partial z_{t-1}}\right|^{-1}\) .
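
Here is a minimal sketch of this deterministic-and-invertible special case (my own 1-D toy with affine maps, not the paper's architecture): sample \(z_0\) from \(q_0\), push it through invertible steps, and update \(\log q\) with the change-of-variable formula; a closed-form Gaussian check confirms the recursion.

```python
# Minimal sketch (my own toy, not the paper's architecture) of an HVI-style q(z_T)
# built from deterministic invertible steps z_t = a_t * z_{t-1} + b_t. In this
# special case r can be taken to be the exact inverse, the KL term vanishes, and
# log q(z_t) follows recursively from q(z_t) = q(z_{t-1}) * |dz_t/dz_{t-1}|^{-1}.
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)

# toy invertible "transition operators": (scale, shift) per step
steps = [(1.5, 0.3), (0.8, -1.0), (2.0, 0.5)]

def sample_and_log_q(n=5):
    z = rng.normal(size=n)                 # z_0 ~ q_0 = N(0, 1)
    log_q = norm.logpdf(z)                 # log q_0(z_0)
    for a, b in steps:
        z = a * z + b                      # deterministic invertible map
        log_q -= np.log(np.abs(a))         # change of variables: subtract log|dz_t/dz_{t-1}|
    return z, log_q

z_T, log_q_T = sample_and_log_q()

# sanity check: a composition of affine maps of a Gaussian is Gaussian, so compare
# the recursive log-density with the analytic one
scale = np.prod([a for a, _ in steps])
shift = 0.0
for a, b in steps:
    shift = a * shift + b
print(np.allclose(log_q_T, norm.logpdf(z_T, shift, np.abs(scale))))   # True
```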


3-2. Loss Function tempering : annealed importance sampling

purpose of alpha-annealing : let variational distn more “EXPLORATORY”

Annealed Importance Sampling (AIS)

  • MCMC method that encapsulates the concept of alpha-annealing

  • extended state space with \(z_0,z_1,..,z_T\)

    ( \(z_0\) : sampled from a simple initial distn \(\tilde{f}_0\) )

  • intermediate targets : \(\tilde{f}_{t}(z)=\tilde{f}_{0}(z)^{1-\beta_{t}} \tilde{f}_{T}(z)^{\beta_{t}}\) , annealing from \(\tilde{f}_0\) toward the true target \(\tilde{f}_T\)

  • transition operators : \(q_t(z_t \mid z_{t-1})\) ( MCMC kernels constructed to leave the intermediate targets invariant )

  • Importance Weight : \(w_{j}=\frac{\tilde{f}_{1}\left(z_{1}\right)}{\tilde{f}_{0}\left(z_{1}\right)} \frac{\tilde{f}_{2}\left(z_{2}\right)}{\tilde{f}_{1}\left(z_{2}\right)} \ldots \frac{\tilde{f}_{T}\left(z_{T}\right)}{\tilde{f}_{T-1}\left(z_{T}\right)}\).

  • downside : requires constructing a LONG sequence of transitions ( see the toy sketch below )
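
A toy AIS sketch (my own minimal implementation, not the paper's code; the target, schedule, and Metropolis kernel are all assumptions): anneal from \(\tilde{f}_0=\mathcal{N}(0,1)\) to a bimodal \(\tilde{f}_T\) through geometric intermediate targets, accumulate the importance weights, and recover \(\log(Z_T/Z_0)\):

```python
# Toy AIS sketch (my own minimal implementation): anneal from a simple f0 = N(0,1)
# to a bimodal target fT using geometric intermediate targets
# f_t = f0^(1-beta_t) * fT^(beta_t), with one random-walk Metropolis step per temperature.
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)

def log_f0(z): return norm.logpdf(z)                                   # initial distn
def log_fT(z): return np.logaddexp(norm.logpdf(z, -3, 0.5), norm.logpdf(z, 3, 0.5))
def log_ft(z, beta): return (1 - beta) * log_f0(z) + beta * log_fT(z)  # geometric bridge

T, n_chains = 200, 1000
betas = np.linspace(0, 1, T + 1)

z = rng.normal(size=n_chains)                   # z_0 ~ f0
log_w = np.zeros(n_chains)
for t in range(1, T + 1):
    log_w += log_ft(z, betas[t]) - log_ft(z, betas[t - 1])   # importance-weight update
    prop = z + 0.5 * rng.normal(size=n_chains)               # Metropolis move targeting f_t
    accept = np.log(rng.uniform(size=n_chains)) < log_ft(prop, betas[t]) - log_ft(z, betas[t])
    z = np.where(accept, prop, z)

# log of the normalizing-constant ratio Z_T / Z_0 (here Z_T = 2, Z_0 = 1, so ~ log 2)
print(np.logaddexp.reduce(log_w) - np.log(n_chains))   # ~ 0.69
```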


4. Annealed Variational Objectives

inspired by AIS & alpha-annealing

propose to integrate ENERGY TEMPERING into the optimization objective

  • ( like AIS ) consider extended state space with r.v. \(z_0,z_1,..,z_T\)
  • ( unlike AIS ) propose to LEARN the parametric transition operators!


This paper defines …

  • \[q_{t}\left(z_{t}, z_{t-1}\right)=q_{t-1}\left(z_{t-1}\right) q_{t}\left(z_{t} \mid z_{t-1}\right)\]
  • \[q_{t}\left(z_{t}\right)=\int_{z_{0}, \ldots, z_{t-1}} q_{t}\left(z_{t} \mid z_{t-1}\right) \prod_{t^{\prime}=0}^{t-1} q_{t^{\prime}}\left(z_{t^{\prime}} \mid z_{t^{\prime}-1}\right) \mathrm{d} z_{t^{\prime}}\]
  • set \(q_{0}\left(z_{0}\right)=f_{0}\left(z_{0}\right)\)


Consider maximizing the following objective, AVO ( = Annealed Variational Objectives )

  • \(\max _{q_{t}\left(z_{t} \mid z_{t-1}\right), r_{t}\left(z_{t-1} \mid z_{t}\right)} \mathbb{E}_{q_{t}\left(z_{t} \mid z_{t-1}\right) q_{t-1}\left(z_{t-1}\right)}\left[\log \frac{\tilde{f}_{t}\left(z_{t}\right) r_{t}\left(z_{t-1} \mid z_{t}\right)}{q_{t}\left(z_{t} \mid z_{t-1}\right) q_{t-1}\left(z_{t-1}\right)}\right]\).


Goal of each forward transition :

  • “to stochastically transform the samples drawn from the previous step into the intermediate target distn assigned to it” ( see the sketch below )
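
A rough training sketch of the idea (a heavily simplified 1-D toy of my own: Gaussian forward/reverse kernels parameterized by small linear layers, geometric intermediate targets, and one AVO term per transition; this is not the paper's amortized architecture or hyperparameters):

```python
# Rough AVO training sketch (my own simplification): T learned Gaussian transitions
# q_t(z_t | z_{t-1}) = N(mu_t(z_{t-1}), sigma_t^2), each trained on its own annealed
# objective with intermediate target f_t = f0^(1-beta_t) * fT^(beta_t). The reverse
# kernels r_t are also Gaussians. A 1-D toy, not the paper's amortized architecture.
import torch

torch.manual_seed(0)

def log_f0(z): return torch.distributions.Normal(0., 1.).log_prob(z)
def log_fT(z):                                   # bimodal target
    return torch.logsumexp(torch.stack([
        torch.distributions.Normal(-3., 0.5).log_prob(z),
        torch.distributions.Normal(3., 0.5).log_prob(z)]), dim=0)

T = 5
betas = torch.linspace(0, 1, T + 1)

class Transition(torch.nn.Module):               # q_t and its reverse r_t
    def __init__(self):
        super().__init__()
        self.fwd = torch.nn.Linear(1, 2)         # outputs (mu, log_sigma) of q_t
        self.rev = torch.nn.Linear(1, 2)         # outputs (mu, log_sigma) of r_t

transitions = torch.nn.ModuleList([Transition() for _ in range(T)])
opt = torch.optim.Adam(transitions.parameters(), lr=1e-2)

for _ in range(2000):
    opt.zero_grad()
    z_prev = torch.randn(256, 1)                 # z_0 ~ q_0 = f_0
    loss = 0.
    for t, trans in enumerate(transitions, start=1):
        mu, log_sig = trans.fwd(z_prev.detach()).chunk(2, dim=-1)
        q_t = torch.distributions.Normal(mu, log_sig.exp())
        z_t = q_t.rsample()
        mu_r, log_sig_r = trans.rev(z_t).chunk(2, dim=-1)
        r_t = torch.distributions.Normal(mu_r, log_sig_r.exp())
        log_ft = (1 - betas[t]) * log_f0(z_t) + betas[t] * log_fT(z_t)
        # negative AVO for step t (the q_{t-1}(z_{t-1}) term is constant w.r.t.
        # this step's parameters, so it is dropped; z_{t-1} is detached so each
        # transition is trained only on its own annealed objective)
        loss = loss - (log_ft + r_t.log_prob(z_prev.detach()) - q_t.log_prob(z_t)).mean()
        z_prev = z_t
    loss.backward()
    opt.step()
```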


Loss Calibrated AVO

AVO depends on the optimality of each transition operator!

\(\rightarrow\) when used for amortized VI, each AVO update will not necessarily improve the marginal \(q(z_T)\) as an approximate posterior of the current model


Thus, use a loss-calibrated version of AVO ( a convex combination with the variational lower bound \(\mathcal{L}(x)\) ; a minimal sketch follows the equation )

\(\max _{q_{t}\left(z_{t} \mid z_{t-1}\right), r_{t}\left(z_{t-1} \mid z_{t}\right)} a\, \mathbb{E}_{q_{t}\left(z_{t} \mid z_{t-1}\right) q_{t-1}\left(z_{t-1}\right)}\left[\log \frac{\tilde{f}_{t}\left(z_{t}\right) r_{t}\left(z_{t-1} \mid z_{t}\right)}{q_{t}\left(z_{t} \mid z_{t-1}\right) q_{t-1}\left(z_{t-1}\right)}\right]+(1-a) \mathcal{L}(x)\).
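
A trivial sketch of the combination itself (the names `avo_terms`, `elbo`, and the weight `a` are illustrative, not the paper's notation):

```python
# Minimal sketch (my own naming) of the loss-calibrated combination: a convex mix of
# the summed per-transition AVO values and the ordinary variational lower bound L(x).
def loss_calibrated_objective(avo_terms, elbo, a=0.5):
    """Both arguments are per-batch scalars to be maximized; a in [0, 1] trades
    exploration (AVO) against directly fitting the current posterior (ELBO)."""
    return a * avo_terms + (1 - a) * elbo

print(loss_calibrated_objective(avo_terms=-1.2, elbo=-3.4, a=0.3))  # toy numbers
```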


5. Conclusion

The density that can be represented is

  • not only limited by the family of approximate distn

  • but also by the “optimization process”

    \(\rightarrow\) resolve this by incorporating annealed objectives into the training of hierarchical variational methods
