Discretely Relaxing Continuous Variables for tractable Variational Inference (NeurIPS 2018)


Variational Inference with DISCRETE latent variable priors

Proposes “DIRECT” … what are its advantages?

( DIRECT = DIscrete RElaxation of ConTinuous variables )

  • 1) exactly compute ELBO gradients

  • 2) training complexity is independent of number of training points

    ( permitting inference on large datasets )

  • 3) fast inference on hardware limited devices

1. Introduction

Hardware restrictions!

solve this problem of efficient Bayesian Inference by considering DISCRETE latent variable models

  • posterior samples will be “quantized” \(\rightarrow\) leading to efficient inference

  • (generally) models with a discrete prior are slow to train ( \(\because\) they require HIGH-variance MC gradient estimates )

    \(\rightarrow\) DIRECT rapidly learns the variational distribution without using any stochastic estimators

Compared with SVI on continuous / discrete latent variable models, the paper reports much better performance!

With a discretized prior, Kronecker matrix algebra can be used for efficient & exact ELBO computation

Overall summary

  • section 2) VI
  • section 3) DIRECT
  • section 4) limitations of proposed approach

2. Variational Inference Background

ELBO with continuous/discrete prior :

  • (continuous) \(\mathrm{ELBO}(\boldsymbol{\theta})=\int q_{\boldsymbol{\theta}}(\mathbf{w})\left(\log \operatorname{Pr}(\mathbf{y} \mid \mathbf{w})+\log \operatorname{Pr}(\mathbf{w})-\log q_{\boldsymbol{\theta}}(\mathbf{w})\right) d \mathbf{w}\)
  • (discrete) \(\mathrm{ELBO}(\boldsymbol{\theta})=\mathbf{q}^{T}(\log \ell+\log \mathbf{p}-\log \mathbf{q})\)
    • \(\log \ell=\left\{\log \operatorname{Pr}\left(\mathbf{y} \mid \mathbf{w}_{i}\right)\right\}_{i=1}^{m}\).
    • \(\log \mathbf{p}=\left\{\log \operatorname{Pr}\left(\mathbf{w}_{i}\right)\right\}_{i=1}^{m}\).
    • \(\mathbf{q}=\left\{q_{\boldsymbol{\theta}}\left(\mathbf{w}_{i}\right)\right\}_{i=1}^{m}\).
    • \(\left\{\mathbf{w}_{i}\right\}_{i=1}^{m}=\mathbf{W} \in \mathbb{R}^{b \times m}\).
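The discrete ELBO above is just a dot product over the \(m\) support points. A minimal NumPy sketch, assuming a toy problem with made-up log-likelihood values and a uniform prior (the helper name `discrete_elbo` is hypothetical, not from the paper):

```python
import numpy as np

def discrete_elbo(q, log_lik, log_prior):
    """ELBO = q^T (log l + log p - log q), summed over the m support points."""
    q = np.asarray(q, dtype=float)
    mask = q > 0  # treat 0 * log 0 as 0
    return float(np.sum(q[mask] * (log_lik[mask] + log_prior[mask] - np.log(q[mask]))))

# Toy values (illustrative assumptions, not from the paper).
rng = np.random.default_rng(0)
m = 5
log_lik = rng.normal(size=m)       # log Pr(y | w_i) for each support point
prior = np.full(m, 1.0 / m)        # uniform prior over the support
log_prior = np.log(prior)

# Exact posterior over the support and the evidence Pr(y).
post = np.exp(log_lik + log_prior)
evidence = post.sum()
post /= evidence

elbo_at_posterior = discrete_elbo(post, log_lik, log_prior)
elbo_at_prior = discrete_elbo(prior, log_lik, log_prior)
```

Here the ELBO equals the log evidence when \(\mathbf{q}\) is the exact posterior and is a lower bound for any other \(\mathbf{q}\), which is what the brute-force computation shows on this small support.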

Computing the ELBO is challenging when \(b\) is large!

So the ELBO is usually not computed explicitly…. instead, an MC estimate of the gradient of the ELBO w.r.t. the variational parameters \(\theta\) is used

Found that **discretely relaxing continuous latent variable priors can improve training and inference performance when using the proposed DIRECT technique, which computes the ELBO ( & its gradient ) directly**

Since discrete…

  • reparameterization trick (X)

  • REINFORCE (O) … but higher variance

    \(\rightarrow\) proposed DIRECT trains much faster!
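To make the variance point concrete, here is a small sketch comparing a REINFORCE (score-function) gradient estimate with the exact gradient for a single categorical variable. It assumes a softmax parameterization \(\mathbf{q}=\operatorname{softmax}(\theta)\) and a fixed objective vector `f` standing in for the log-likelihood term; both are simplifying assumptions for illustration, not the paper's setup.

```python
import numpy as np

def softmax(theta):
    z = np.exp(theta - theta.max())
    return z / z.sum()

rng = np.random.default_rng(0)
m_bar = 4
theta = rng.normal(size=m_bar)   # variational parameters (logits)
f = rng.normal(size=m_bar)       # fixed per-state objective values (assumed)
q = softmax(theta)

# Exact gradient of F(theta) = sum_k q_k f_k:  dF/dtheta_j = q_j (f_j - q^T f)
exact_grad = q * (f - q @ f)

# REINFORCE: sample k ~ q, use f_k * grad log q_k, where grad log q_k = e_k - q
n_samples = 200_000
k = rng.choice(m_bar, size=n_samples, p=q)
score = np.eye(m_bar)[k] - q               # shape (n_samples, m_bar)
per_sample = f[k][:, None] * score
reinforce_grad = per_sample.mean(axis=0)   # unbiased but noisy
per_sample_var = per_sample.var(axis=0)    # variance of a single-sample estimate
```

The averaged estimate approaches `exact_grad` only with many samples, while each single-sample estimate has non-zero variance; an exact computation avoids that noise entirely.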

3. DIRECT : Efficient ELBO Computations with Kronecker Matrix Algebra

DIRECT : allows the ELBO to be computed efficiently & exactly

  • several advantages over existing SVI techniques

  • consider a discrete prior over our latent variables,

    whose support set \(\mathbf{W}\) forms a Cartesian tensor product grid

\(\mathbf{W}=\left(\begin{array}{ccccccc} \overline{\mathbf{w}}_{1}^{T} & \otimes & \mathbf{1}_{\bar{m}}^{T} & \otimes & \cdots & \otimes & \mathbf{1}_{\bar{m}}^{T} \\ \mathbf{1}_{\bar{m}}^{T} & \otimes & \overline{\mathbf{w}}_{2}^{T} & \otimes & \cdots & \otimes & \mathbf{1}_{\bar{m}}^{T} \\ \vdots & & \vdots & & \ddots & & \vdots \\ \mathbf{1}_{\bar{m}}^{T} & \otimes & \mathbf{1}_{\bar{m}}^{T} & \otimes & \cdots & \otimes & \overline{\mathbf{w}}_{b}^{T} \end{array}\right)\)

  • \(\mathbf{1}_{\bar{m}} \in \mathbb{R}^{\bar{m}}\) denotes a vector of ones
  • \(\overline{\mathbf{w}}_{i} \in \mathbb{R}^{\bar{m}}\) contains the \(\bar{m}\) discrete values that the \(i\)th latent variable \(w_{i}\) can take
  • \(m=\bar{m}^{b}\) : total number of support points ( columns of \(\mathbf{W}\) )
  • \(\otimes\) denotes the Kronecker product

The number of columns of \(\mathbf{W} \in \mathbb{R}^{b \times \bar{m}^{b}}\) increases exponentially with respect to \(b\) …. storing or enumerating it is intractable for large \(b\)
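The grid structure of \(\mathbf{W}\) can be sketched with `np.kron` for a tiny case. The function name `build_support_grid` and the values \(b=3\), \(\bar{m}=3\) are assumptions chosen only so the dense matrix stays small:

```python
import numpy as np
from functools import reduce
from itertools import product

def build_support_grid(w_bars):
    """Row i is 1 (x) ... (x) w_bar_i (x) ... (x) 1, so W has shape (b, m_bar**b)."""
    b = len(w_bars)
    rows = []
    for i in range(b):
        factors = [w_bars[j] if j == i else np.ones(len(w_bars[j])) for j in range(b)]
        rows.append(reduce(np.kron, factors))
    return np.vstack(rows)

# b = 3 latent variables, each taking m_bar = 3 discrete values (illustrative).
w_bars = [np.array([-1.0, 0.0, 1.0])] * 3
W = build_support_grid(w_bars)

# The columns of W enumerate every combination of per-variable values.
all_combinations = np.array(list(product(*w_bars)))
matches_enumeration = np.array_equal(W.T, all_combinations)
```

With these sizes `W` already has \(3^{3}=27\) columns; the same construction with \(b=20\) would need more than \(3 \times 10^{9}\) columns, which is why the dense form is only usable for illustration.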

Can alleviate this, if \(\mathbf{q}\), \(\log \mathbf{p}\), \(\log \ell\), and \(\log \mathbf{q}\) can be written as a **sum of Kronecker product vectors** (i.e. \(\sum_{i} \otimes_{j=1}^{b} \mathbf{f}_{j}^{(i)}\))

Computation of ELBO : \(\mathcal{O}\left(\bar{m}^{b}\right) \rightarrow \mathcal{O}(b \bar{m})\)
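The saving comes from the fact that an inner product of two Kronecker product vectors factorizes into a product of \(b\) small inner products. A minimal check, with hypothetical helper names and random factors:

```python
import numpy as np
from functools import reduce

def kron_all(vectors):
    """Dense Kronecker product of a list of vectors (length m_bar**b)."""
    return reduce(np.kron, vectors)

def kron_inner(a_factors, b_factors):
    """(a_1 (x) ... (x) a_b)^T (b_1 (x) ... (x) b_b) = prod_j a_j^T b_j, in O(b * m_bar)."""
    return float(np.prod([a @ b for a, b in zip(a_factors, b_factors)]))

rng = np.random.default_rng(0)
b, m_bar = 4, 3
a_factors = [rng.normal(size=m_bar) for _ in range(b)]
b_factors = [rng.normal(size=m_bar) for _ in range(b)]

dense = float(kron_all(a_factors) @ kron_all(b_factors))  # O(m_bar**b) work
cheap = kron_inner(a_factors, b_factors)                  # O(b * m_bar) work
```

Both values agree up to floating-point error, but only `kron_inner` avoids forming vectors of length \(\bar{m}^{b}\).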

So, how to express \(\mathrm{ELBO}(\boldsymbol{\theta})=\mathbf{q}^{T}(\log \ell+\log \mathbf{p}-\log \mathbf{q})\) with Kronecker products?

  • (prior) \(\mathbf{p}=\otimes_{i=1}^{b} \mathbf{p}_{i}\), where \(\mathbf{p}_{i}=\left\{\operatorname{Pr}\left(w_{i}=\bar{w}_{i j}\right)\right\}_{j=1}^{\bar{m}} \in(0,1)^{\bar{m}}\)

    \(\rightarrow\) this structure for \(\mathbf{p}\) enables \(\log \mathbf{p}\) to be written as a sum of \(b\) Kronecker product vectors.

Rewrite ELBO :

  • (before) \(\mathrm{ELBO}(\boldsymbol{\theta})=\mathbf{q}^{T}(\log \ell+\log \mathbf{p}-\log \mathbf{q})\)

  • (after) \(\mathrm{ELBO}(\boldsymbol{\theta})=\mathbf{q}^{T} \log \ell+\sum_{i=1}^{b} \mathbf{q}_{i}^{T} \log \mathbf{p}_{i}-\sum_{i=1}^{b} \mathbf{q}_{i}^{T} \log \mathbf{q}_{i}\).

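    • \(\mathbf{q}_{i}\) : valid pmf for the \(i\)th latent variable, assuming a factorized variational distribution \(\mathbf{q}=\otimes_{i=1}^{b} \mathbf{q}_{i}\)

      ( such that \(\mathbf{q}_{i}^{T} \mathbf{1}_{\bar{m}}=1\) )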

    • \(\log \ell\) : depends on the probabilistic model used
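The prior and entropy terms of the rewritten ELBO can be checked numerically on a small grid. This sketch assumes factorized \(\mathbf{q}=\otimes_{i} \mathbf{q}_{i}\) and \(\mathbf{p}=\otimes_{i} \mathbf{p}_{i}\) with randomly generated factors; the helper `random_pmf` is hypothetical:

```python
import numpy as np
from functools import reduce

rng = np.random.default_rng(1)
b, m_bar = 3, 4

def random_pmf(size):
    """Random strictly positive probability vector (illustrative only)."""
    x = rng.random(size) + 0.1
    return x / x.sum()

q_factors = [random_pmf(m_bar) for _ in range(b)]
p_factors = [random_pmf(m_bar) for _ in range(b)]

# Dense form: q^T (log p - log q) over all m_bar**b support points.
q = reduce(np.kron, q_factors)
p = reduce(np.kron, p_factors)
dense = float(q @ (np.log(p) - np.log(q)))

# Factored form: sum_i q_i^T log p_i - sum_i q_i^T log q_i, using q_i^T 1 = 1.
factored = sum(float(qi @ (np.log(pi) - np.log(qi)))
               for qi, pi in zip(q_factors, p_factors))
```

The two numbers match because each factor \(\mathbf{q}_{j}\) with \(j \neq i\) sums to one and drops out of the \(i\)th term.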

3-1. Generalized Linear Regression

( focus on popular class of Bayesian GLM )

GLM : \(\mathbf{y}=\mathbf{\Phi} \mathbf{w}+\boldsymbol{\epsilon}\)

  • where \(\boldsymbol{\epsilon} \sim \mathcal{N}\left(\mathbf{0}, \sigma^{2} \mathbf{I}\right)\)
  • and \(\boldsymbol{\Phi}=\left\{\phi_{j}\left(\mathbf{x}_{i}\right)\right\}_{i, j} \in \mathbb{R}^{n \times b}\)

Using the above, the ELBO is :

\(\begin{array}{r} \mathrm{ELBO}(\boldsymbol{\theta})=-\frac{n}{2} \mathbf{q}_{\sigma}^{T} \log \sigma^{2}-\frac{1}{2}\left(\mathbf{q}_{\sigma}^{T} \sigma^{-2}\right)\left(\mathbf{y}^{T} \mathbf{y}-2 \mathbf{s}^{T}\left(\mathbf{\Phi}^{T} \mathbf{y}\right)+\mathbf{s}^{T} \mathbf{\Phi}^{T} \mathbf{\Phi} \mathbf{s}-\operatorname{diag}\left(\mathbf{\Phi}^{T} \mathbf{\Phi}\right)^{T} \mathbf{s}^{2}+\right. \\ \left.\sum_{j=1}^{b} \mathbf{q}_{j}^{T} \mathbf{h}_{j}\right)+\sum_{i=1}^{b}\left(\mathbf{q}_{i}^{T} \log \mathbf{p}_{i}-\mathbf{q}_{i}^{T} \log \mathbf{q}_{i}\right)+\mathbf{q}_{\sigma}^{T} \log \mathbf{p}_{\sigma}-\mathbf{q}_{\sigma}^{T} \log \mathbf{q}_{\sigma} \end{array}\)

  • \(\mathbf{q}_{\sigma}, \mathbf{p}_{\sigma} \in \mathbb{R}^{\bar{m}}\) : factorized variational and prior distributions over the Gaussian noise variance \(\sigma^{2}\)
  • \(\sigma^{2} \in \mathbb{R}^{\bar{m}}\) : the discrete positive values that the noise variance can take
  • \(\mathbf{H}=\left\{\overline{\mathbf{w}}_{j}^{2} \sum_{i=1}^{n} \phi_{i j}^{2}\right\}_{j=1}^{b} \in \mathbb{R}^{\bar{m} \times b}\), with columns \(\mathbf{h}_{j}\)
  • \(\mathbf{s}=\left\{\mathbf{q}_{j}^{T} \overline{\mathbf{w}}_{j}\right\}_{j=1}^{b} \in \mathbb{R}^{b}\).

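Complexity : \(\mathcal{O}\left(b \bar{m}+b^{2}\right)\), once \(\mathbf{y}^{T} \mathbf{y}\), \(\mathbf{\Phi}^{T} \mathbf{y}\) and \(\mathbf{\Phi}^{T} \mathbf{\Phi}\) have been precomputed

\(\rightarrow\) each ELBO evaluation is independent of the number of training points \(n\) ( scalability )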
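The bracketed term in the GLM ELBO is the expected squared error \(\mathbb{E}_{q}\left[\|\mathbf{y}-\mathbf{\Phi} \mathbf{w}\|^{2}\right]\) under a factorized \(\mathbf{q}\). The sketch below checks only that term against brute-force enumeration on a tiny problem; the data, grid values and function name `expected_sq_error` are assumptions for illustration, and the noise-variance factors are left out.

```python
import numpy as np
from itertools import product

rng = np.random.default_rng(2)
n, b, m_bar = 50, 3, 4

Phi = rng.normal(size=(n, b))                       # basis functions (toy data)
y = rng.normal(size=n)
w_bars = [np.sort(rng.normal(size=m_bar)) for _ in range(b)]
q_factors = []
for _ in range(b):
    x = rng.random(m_bar) + 0.1
    q_factors.append(x / x.sum())

# One-time precomputation: the only place the n training points appear.
yty = y @ y
Phity = Phi.T @ y
PhitPhi = Phi.T @ Phi

def expected_sq_error(q_factors):
    """E_q ||y - Phi w||^2 in O(b * m_bar + b^2) using precomputed statistics."""
    s = np.array([q @ w for q, w in zip(q_factors, w_bars)])   # E[w]
    d = np.diag(PhitPhi)
    # sum_j q_j^T h_j with h_j = w_bar_j^2 * sum_i phi_ij^2
    qh = sum(q @ (w ** 2 * d_j) for q, w, d_j in zip(q_factors, w_bars, d))
    return float(yty - 2 * s @ Phity + s @ PhitPhi @ s - d @ s ** 2 + qh)

fast = expected_sq_error(q_factors)

# Brute force over all m_bar**b grid points, for comparison only.
brute = 0.0
for idx in product(range(m_bar), repeat=b):
    w = np.array([w_bars[j][k] for j, k in enumerate(idx)])
    prob = np.prod([q_factors[j][k] for j, k in enumerate(idx)])
    brute += prob * np.sum((y - Phi @ w) ** 2)
```

After the precomputation, `expected_sq_error` never touches `Phi` or `y` directly, which is where the independence from \(n\) comes from.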

Predictive Posterior Computations

(in general) predictive posterior moments are found by sampling from the variational distribution & running the model forward

however, DIRECT uses Kronecker matrix algebra to compute these moments efficiently & exactly!

ex) GLM model

  • exact predictive posterior mean : \(\mathbb{E}\left(y_{*}\right)=\sum_{i=1}^{m} q\left(\mathbf{w}_{i}\right) \int y_{*} \operatorname{Pr}\left(y_{*} \mid \mathbf{w}_{i}\right) d y_{*}=\boldsymbol{\Phi}_{*} \mathbf{W} \mathbf{q}=\boldsymbol{\Phi}_{*} \mathbf{s}\)
    • \(\mathbf{s}=\left\{\mathbf{q}_{j}^{T} \overline{\mathbf{w}}_{j}\right\}_{j=1}^{b} \in \mathbb{R}^{b}\)
    • \(\boldsymbol{\Phi}_{*} \in \mathbb{R}^{1 \times b}\) contains the basis functions evaluated at \(\mathbf{x}_{*}\)
  • requires just \(\mathcal{O}(b)\) time per test point
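The identity \(\mathbf{W} \mathbf{q}=\mathbf{s}\) behind the predictive mean can be verified on a small grid. The values and variable names below are illustrative assumptions; the dense \(\mathbf{W}\) and \(\mathbf{q}\) are built only to confirm the cheap formula.

```python
import numpy as np
from functools import reduce

rng = np.random.default_rng(3)
b, m_bar = 3, 4
w_bars = [rng.normal(size=m_bar) for _ in range(b)]
q_factors = []
for _ in range(b):
    x = rng.random(m_bar) + 0.1
    q_factors.append(x / x.sum())
phi_star = rng.normal(size=b)          # basis functions at one test point (toy)

# Fast path: s_j = q_j^T w_bar_j, then an O(b) dot product per test point.
s = np.array([q @ w for q, w in zip(q_factors, w_bars)])
mean_fast = float(phi_star @ s)

# Dense path: explicit W (b x m_bar**b) and q (m_bar**b), for checking only.
ones = np.ones(m_bar)
W = np.vstack([
    reduce(np.kron, [w_bars[j] if j == i else ones for j in range(b)])
    for i in range(b)
])
q = reduce(np.kron, q_factors)
mean_dense = float(phi_star @ W @ q)
```

Since `s` does not depend on the test point, it can be computed once and reused, leaving a single length-\(b\) dot product per prediction.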

3-2. Deep Neural Networks for Regression

Hierarchical model structure for Bayesian DNN for regression

using DIRECT approach!

would like a non-linear activation that maintains a compact representation of the log-likelihood evaluated at every point of the support set \(\mathbf{W}\)

( that is, \(\log \ell\) should be representable as a sum of as few Kronecker product vectors as possible )

  • ex) quadratic activation function ( \(f(x) = x^2\) )
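One way to see why a quadratic activation is convenient: the elementwise square of a sum of Kronecker product vectors is again a sum of Kronecker product vectors, because \((\mathbf{a} \otimes \mathbf{b}) \circ(\mathbf{c} \otimes \mathbf{d})=(\mathbf{a} \circ \mathbf{c}) \otimes(\mathbf{b} \circ \mathbf{d})\). The sketch below checks this for a two-term sum with random vectors; it illustrates the algebraic property only, not the paper's network construction.

```python
import numpy as np

rng = np.random.default_rng(4)
a, b_, c, d = (rng.normal(size=3) for _ in range(4))

# A pre-activation expressed as a sum of two Kronecker product vectors.
u = np.kron(a, b_) + np.kron(c, d)

# Quadratic activation applied elementwise on the dense vector.
squared_dense = u ** 2

# Same result kept in Kronecker form: 3 terms instead of 2.
squared_kron = (np.kron(a * a, b_ * b_)
                + 2 * np.kron(a * c, b_ * d)
                + np.kron(c * c, d * d))
```

Squaring a sum of \(r\) terms yields up to \(r(r+1)/2\) terms, so the representation stays in Kronecker form but grows with each layer.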

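ELBO can be exactly computed in \(\mathcal{O}\left(\ell \bar{m}(b / \ell)^{4 \ell}\right)\) for a Bayesian DNN with \(\ell\) layers ( here \(\ell\) is the number of layers, not the likelihood vector )

This complexity has no dependence on \(n\) and is polynomial in \(b\) for a fixed number of layers, but it grows exponentially with \(\ell\), so it enables scalable Bayesian inference for networks with a modest number of layers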

4. Limitations & Extensions

Models other than those above ( GLMs, DNNs with quadratic activations ) may not admit a compact Kronecker product representation of \(\log \ell\)!

