Discretely Relaxing Continuous Variables for tractable Variational Inference (NeurIPS 2018)
Abstract
Variational Inference with DISCRETE latent variable priors
proposes “DIRECT” … what is its advantage?
( DIRECT = DIscrete RElaxation of ConTinuous variables )
1) exactly compute ELBO gradients
2) training complexity is independent of the number of training points ( permitting inference on large datasets )
3) fast inference on hardware-limited devices
1. Introduction
Hardware restrictions!
solve this problem of efficient Bayesian inference by considering DISCRETE latent variable models
- posterior samples will be “quantized” \(\rightarrow\) leading to efficient inference
- (generally) models with discrete priors : slow to train ( \(\because\) they require the use of HIGH-variance MC gradient estimates )
\(\rightarrow\) DIRECT rapidly learns the variational distn without the use of any stochastic estimators
Compared with (continuous / discrete latent variable) SVI, much better!
Using a discretized prior, we can make use of Kronecker matrix algebra for efficient & exact ELBO computation
Overall summary
- section 2) VI
- section 3) DIRECT
- section 4) limitations of proposed approach
2. Variational Inference Background
ELBO with continuous/discrete prior :
- (continuous) \(\mathrm{ELBO}(\boldsymbol{\theta})=\int q_{\boldsymbol{\theta}}(\mathbf{w})\left(\log \operatorname{Pr}(\mathbf{y} \mid \mathbf{w})+\log \operatorname{Pr}(\mathbf{w})-\log q_{\boldsymbol{\theta}}(\mathbf{w})\right) d \mathbf{w}\)
- (discrete) \(\mathrm{ELBO}(\boldsymbol{\theta})=\mathbf{q}^{T}(\log \ell+\log \mathbf{p}-\log \mathbf{q})\)
- \(\log \ell=\left\{\log \operatorname{Pr}\left(\mathbf{y} \mid \mathbf{w}_{i}\right)\right\}_{i=1}^{m}\).
- \(\log \mathbf{p}=\left\{\log \operatorname{Pr}\left(\mathbf{w}_{i}\right)\right\}_{i=1}^{m}\).
- \(\mathbf{q}=\left\{q_{\boldsymbol{\theta}}\left(\mathbf{w}_{i}\right)\right\}_{i=1}^{m}\).
- \(\left\{\mathbf{w}_{i}\right\}_{i=1}^{m}=\mathbf{W} \in \mathbb{R}^{b \times m}\).
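To make the discrete ELBO concrete, here is a minimal NumPy sketch (not from the paper) that enumerates the full support \(\mathbf{W}\) for a toy factorized prior/posterior and evaluates \(\mathbf{q}^{T}(\log \ell+\log \mathbf{p}-\log \mathbf{q})\) by brute force; the likelihood `log_lik` is a made-up stand-in.

```python
import itertools
import numpy as np

# Toy setup: b latent variables, each taking m_bar discrete values.
b, m_bar = 3, 4
w_bar = np.linspace(-1.0, 1.0, m_bar)                              # support of each latent variable
W = np.array(list(itertools.product(w_bar, repeat=b))).T           # b x m_bar**b grid of support points

# Hypothetical factorized prior and variational distribution (one categorical per latent).
rng = np.random.default_rng(0)
p_i = rng.dirichlet(np.ones(m_bar), size=b)                        # rows: Pr(w_i = w_bar_ij)
q_i = rng.dirichlet(np.ones(m_bar), size=b)                        # rows: q_theta(w_i)

# Joint log-probabilities over all m_bar**b columns of W.
idx = np.array(list(itertools.product(range(m_bar), repeat=b))).T  # b x m index grid
log_p = np.sum(np.log(p_i[np.arange(b)[:, None], idx]), axis=0)    # log Pr(w_k)
log_q = np.sum(np.log(q_i[np.arange(b)[:, None], idx]), axis=0)    # log q_theta(w_k)
q = np.exp(log_q)

# Made-up log-likelihood log Pr(y | w_k) at every support point.
y = np.array([0.3, -0.2])
log_lik = np.array([-0.5 * np.sum((y - w[:2]) ** 2) for w in W.T])

# ELBO(theta) = q^T (log_lik + log_p - log_q): exact, but it sums over m_bar**b terms.
elbo = q @ (log_lik + log_p - log_q)
print(elbo)
```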
Computing the ELBO is challenging when \(b\) is large!
so the ELBO is usually not explicitly computed… instead, an MC estimate of the gradient of the ELBO w.r.t. the variational parameters \(\theta\) is used
Found that **discretely relaxing continuous latent variable priors can improve training and inference performance when using our proposed DIRECT technique which computes the ELBO( & its gradient ) directly **
Since discrete…
- reparameterization trick (X)
- REINFORCE (O) … but with higher variance
\(\rightarrow\) proposed DIRECT trains much faster!
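To illustrate the contrast, a hedged sketch (not the paper's code) of the score-function (REINFORCE) estimator of \(\nabla_{\theta}\,\mathbb{E}_{q}[f]\) for a single categorical latent, compared against the exact gradient that exists because the expectation is a finite sum; the objective `f` and all names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)
m_bar = 4
theta = rng.normal(size=m_bar)                 # logits of one categorical variational factor
q = np.exp(theta) / np.exp(theta).sum()

def f(k):                                      # hypothetical ELBO integrand at support point k
    return -0.5 * (k - 1.5) ** 2

# REINFORCE / score-function estimate of d/dtheta E_q[f]:
#   g_hat = f(k) * d log q(k) / d theta,  k ~ q   (unbiased, but each sample is noisy)
def reinforce_sample():
    k = rng.choice(m_bar, p=q)
    dlogq = -q.copy()
    dlogq[k] += 1.0                            # gradient of log-softmax w.r.t. the logits
    return f(k) * dlogq

# Exact gradient, available here because E_q[f] = sum_j f(j) q_j is a finite sum (the DIRECT idea).
exact = np.array([sum(f(j) * q[j] * ((j == i) - q[i]) for j in range(m_bar))
                  for i in range(m_bar)])

samples = np.stack([reinforce_sample() for _ in range(2000)])
print("exact:          ", exact)
print("REINFORCE mean: ", samples.mean(0))
print("per-coord std:  ", samples.std(0))      # the spread is the "high variance" referred to above
```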
3. DIRECT : Efficient ELBO Computations with Kronecker Matrix Algebra
DIRECT : allows us to efficiently & exactly compute the ELBO
- several advantages over existing SVI techniques
- consider a discrete prior over our latent variables, whose support set \(\mathbf{W}\) forms a Cartesian tensor product grid :
\(\mathbf{W}=\left(\begin{array}{ccccccc} \overline{\mathbf{w}}_{1}^{T} & \otimes & \mathbf{1}_{\bar{m}}^{T} & \otimes & \cdots & \otimes & \mathbf{1}_{\bar{m}}^{T} \\ \mathbf{1}_{\bar{m}}^{T} & \otimes & \overline{\mathbf{w}}_{2}^{T} & \otimes & \cdots & \otimes & \mathbf{1}_{\bar{m}}^{T} \\ \vdots & & \vdots & & \ddots & & \vdots \\ \mathbf{1}_{\bar{m}}^{T} & \otimes & \mathbf{1}_{\bar{m}}^{T} & \otimes & \cdots & \otimes & \overline{\mathbf{w}}_{b}^{T} \end{array}\right),\).
- \(\mathbf{1}_{\bar{m}} \in \mathbb{R}^{\bar{m}}\) denotes a vector of ones
- \(\overline{\mathbf{w}}_{i} \in \mathbb{R}^{\bar{m}}\) contains the \(\bar{m}\) discrete values that the \(i\)th latent variable \(w_{i}\) can take
- \(m=\bar{m}^{b}\)
- \(\otimes\) denotes the Kronecker product
number of columns of \(\mathbf{W} \in \mathbb{R}^{b \times \bar{m}^{b}}\) increases exponentially with respect to \(b\) … intractable for large \(b\)
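A small sketch (assuming NumPy and, for simplicity, a shared support per latent variable) of how the rows of \(\mathbf{W}\) can be assembled from the Kronecker products above; it mainly shows the exponential column count.

```python
import numpy as np
from functools import reduce

b, m_bar = 3, 4
w_bars = [np.linspace(-1.0, 1.0, m_bar) for _ in range(b)]   # discrete support of each latent variable

# Row i of W: 1^T (x) ... (x) w_bar_i^T (x) ... (x) 1^T  (Kronecker product of b vectors)
ones = np.ones(m_bar)
rows = []
for i in range(b):
    factors = [w_bars[j] if j == i else ones for j in range(b)]
    rows.append(reduce(np.kron, factors))
W = np.vstack(rows)              # shape (b, m_bar**b): columns enumerate the whole grid
print(W.shape)                   # (3, 64) -- the column count m_bar**b explodes with b
```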
Can alleviate this, if \(\mathbf{q}\), \(\log \mathbf{p}, \log \ell\), and \(\log \mathbf{q}\) can be written as a **sum of Kronecker product vectors** (i.e. \(\sum_{i} \otimes_{j=1}^{b} \mathbf{f}_{j}^{(i)}\))
Computation of ELBO : \(\mathcal{O}\left(\bar{m}^{b}\right) \rightarrow \mathcal{O}(b \bar{m})\)
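The reduction rests on the identity \((\otimes_{j} \mathbf{q}_{j})^{T}(\otimes_{j} \mathbf{f}_{j})=\prod_{j} \mathbf{q}_{j}^{T} \mathbf{f}_{j}\), applied per Kronecker-product term. A quick numerical check with toy sizes (NumPy, hypothetical factors):

```python
import numpy as np
from functools import reduce

rng = np.random.default_rng(2)
b, m_bar = 3, 4
q_i = rng.dirichlet(np.ones(m_bar), size=b)       # factorized variational distribution
f_i = rng.normal(size=(b, m_bar))                 # factors of one Kronecker product vector

# Naive: materialize the full m_bar**b vectors and take the inner product -- O(m_bar**b).
q_full = reduce(np.kron, q_i)
f_full = reduce(np.kron, f_i)
naive = q_full @ f_full

# Kronecker identity: (kron_j q_j)^T (kron_j f_j) = prod_j q_j^T f_j -- only O(b * m_bar).
fast = np.prod([q_i[j] @ f_i[j] for j in range(b)])

print(np.allclose(naive, fast))   # True
```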
So, how to express \(\mathrm{ELBO}(\boldsymbol{\theta})=\mathbf{q}^{T}(\log \ell+\log \mathbf{p}-\log \mathbf{q})\) with Kronecker products?
- (prior) \(\mathbf{p}=\otimes_{i=1}^{b} \mathbf{p}_{i}\), where \(\mathbf{p}_{i}=\left\{\operatorname{Pr}\left(w_{i}=\bar{w}_{i j}\right)\right\}_{j=1}^{\bar{m}} \in(0,1)^{\bar{m}}\)
\(\rightarrow\) this structure for \(\mathbf{p}\) enables \(\log \mathbf{p}\) to be written as a sum of \(b\) Kronecker product vectors.
Rewrite ELBO :
- (before) \(\mathrm{ELBO}(\boldsymbol{\theta})=\mathbf{q}^{T}(\log \ell+\log \mathbf{p}-\log \mathbf{q})\)
- (after) \(\mathrm{ELBO}(\boldsymbol{\theta})=\mathbf{q}^{T} \log \ell+\sum_{i=1}^{b} \mathbf{q}_{i}^{T} \log \mathbf{p}_{i}-\sum_{i=1}^{b} \mathbf{q}_{i}^{T} \log \mathbf{q}_{i}\)
- \(\mathbf{q}_{i}\) : valid probability distribution for the \(i\)th latent variable ( such that \(\mathbf{q}_{i}^{T} \mathbf{1}_{\bar{m}}=1\) )
- \(\log \ell\) : depends on the probabilistic model used
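A toy check (random `log_lik` as a stand-in for the model-dependent term) that the “before” and “after” forms agree when \(\mathbf{p}\) and \(\mathbf{q}\) factorize:

```python
import numpy as np
from functools import reduce

rng = np.random.default_rng(3)
b, m_bar = 3, 4
p_i = rng.dirichlet(np.ones(m_bar), size=b)       # factorized prior over each latent
q_i = rng.dirichlet(np.ones(m_bar), size=b)       # factorized variational distribution
log_lik = rng.normal(size=m_bar ** b)             # stand-in for log Pr(y | w_k) on the grid

q_full = reduce(np.kron, q_i)

# "Before": ELBO = q^T (log l + log p - log q) over the full grid.
before = q_full @ (log_lik + np.log(reduce(np.kron, p_i)) - np.log(q_full))

# "After": the prior and entropy terms collapse to 2b small inner products.
after = (q_full @ log_lik
         + sum(q_i[j] @ np.log(p_i[j]) for j in range(b))
         - sum(q_i[j] @ np.log(q_i[j]) for j in range(b)))

print(np.allclose(before, after))   # True
```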
3-1. Generalized Linear Regression
( focus on the popular class of Bayesian GLMs )
GLM : \(\mathbf{y}=\mathbf{\Phi} \mathbf{w}+\boldsymbol{\epsilon}\)
- where \(\boldsymbol{\epsilon} \sim \mathcal{N}\left(\mathbf{0}, \sigma^{2} \mathbf{I}\right)\)
- and \(\boldsymbol{\Phi}=\left\{\phi_{j}\left(\mathbf{x}_{i}\right)\right\}_{i, j} \in \mathbb{R}^{n \times b}\)
using the above, the ELBO becomes :
\(\begin{aligned} \mathrm{ELBO}(\boldsymbol{\theta})= & -\frac{n}{2} \mathbf{q}_{\sigma}^{T} \log \boldsymbol{\sigma}^{2}-\frac{1}{2}\left(\mathbf{q}_{\sigma}^{T} \boldsymbol{\sigma}^{-2}\right)\left(\mathbf{y}^{T} \mathbf{y}-2 \mathbf{s}^{T}\left(\boldsymbol{\Phi}^{T} \mathbf{y}\right)+\mathbf{s}^{T} \boldsymbol{\Phi}^{T} \boldsymbol{\Phi} \mathbf{s}-\operatorname{diag}\left(\boldsymbol{\Phi}^{T} \boldsymbol{\Phi}\right)^{T} \mathbf{s}^{2}+\sum_{j=1}^{b} \mathbf{q}_{j}^{T} \mathbf{h}_{j}\right) \\ & +\sum_{i=1}^{b}\left(\mathbf{q}_{i}^{T} \log \mathbf{p}_{i}-\mathbf{q}_{i}^{T} \log \mathbf{q}_{i}\right)+\mathbf{q}_{\sigma}^{T} \log \mathbf{p}_{\sigma}-\mathbf{q}_{\sigma}^{T} \log \mathbf{q}_{\sigma} \end{aligned}\)
- \(\mathbf{q}_{\sigma}, \mathbf{p}_{\sigma} \in \mathbb{R}^{\bar{m}}\) : factorized variational and prior dist over the Gaussian noise variance \(\sigma^{2}\)
- \(\sigma^{2}\) is restricted to \(\bar{m}\) discrete positive values, collected in \(\boldsymbol{\sigma}^{2} \in \mathbb{R}^{\bar{m}}\)
- \(\mathbf{H}=\left\{\overline{\mathbf{w}}_{j}^{2} \sum_{i=1}^{n} \phi_{i j}^{2}\right\}_{j=1}^{b} \in \mathbb{R}^{\bar{m} \times b}\), with \(j\)th column \(\mathbf{h}_{j}\)
- \(\mathbf{s}=\left\{\mathbf{q}_{j}^{T} \overline{\mathbf{w}}_{j}\right\}_{j=1}^{b} \in \mathbb{R}^{b}\).
Complexity : \(\mathcal{O}\left(b \bar{m}+b^{2}\right)\) per ELBO evaluation ( \(\boldsymbol{\Phi}^{T} \mathbf{y}\) and \(\boldsymbol{\Phi}^{T} \boldsymbol{\Phi}\) can be precomputed once )
\(\rightarrow\) independent of the number of training points ( scalability )
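A sketch of the closed-form GLM ELBO above, assuming NumPy, made-up data, and a shared support \(\overline{\mathbf{w}}\) for every weight; it is checked against brute-force enumeration (only feasible for tiny \(b\)) after restoring the \(-\frac{n}{2}\log 2\pi\) constant that the formula drops.

```python
import itertools
import numpy as np
from functools import reduce

rng = np.random.default_rng(4)
n, b, m_bar = 10, 3, 4                               # training points, weights, grid size per variable
Phi = rng.normal(size=(n, b))                        # basis functions evaluated at the inputs
y = rng.normal(size=n)
w_bar = np.linspace(-2.0, 2.0, m_bar)                # shared support for every weight (assumption)
sig2_bar = np.array([0.1, 0.5, 1.0, 2.0])            # discrete support for the noise variance

q_i = rng.dirichlet(np.ones(m_bar), size=b)          # variational dist over each weight
p_i = rng.dirichlet(np.ones(m_bar), size=b)          # prior over each weight
q_sig = rng.dirichlet(np.ones(len(sig2_bar)))        # variational dist over sigma^2
p_sig = rng.dirichlet(np.ones(len(sig2_bar)))        # prior over sigma^2

# Quantities from the closed form: s_j = q_j^T w_bar_j, h_j = w_bar_j^2 * sum_i phi_ij^2.
s = q_i @ w_bar                                      # (b,) posterior mean of each weight
PtP = Phi.T @ Phi                                    # precomputed once, like Phi^T y
H = np.outer(w_bar ** 2, np.diag(PtP))               # (m_bar, b): column j is h_j

quad = (y @ y - 2 * s @ (Phi.T @ y) + s @ PtP @ s
        - np.diag(PtP) @ s ** 2 + sum(q_i[j] @ H[:, j] for j in range(b)))
elbo = (-0.5 * n * q_sig @ np.log(sig2_bar)
        - 0.5 * (q_sig @ (1.0 / sig2_bar)) * quad
        + sum(q_i[j] @ (np.log(p_i[j]) - np.log(q_i[j])) for j in range(b))
        + q_sig @ (np.log(p_sig) - np.log(q_sig)))
elbo -= 0.5 * n * np.log(2 * np.pi)                  # Gaussian constant omitted in the note's formula

# Brute-force check over the full (w, sigma^2) grid -- O(m_bar**b), only for tiny b.
W = np.array(list(itertools.product(w_bar, repeat=b)))          # (m_bar**b, b)
qw = reduce(np.kron, q_i)
pw = reduce(np.kron, p_i)
brute = 0.0
for k, w in enumerate(W):
    r2 = np.sum((y - Phi @ w) ** 2)
    for t, s2 in enumerate(sig2_bar):
        loglik = -0.5 * n * np.log(2 * np.pi * s2) - 0.5 * r2 / s2
        brute += qw[k] * q_sig[t] * (loglik + np.log(pw[k] * p_sig[t]) - np.log(qw[k] * q_sig[t]))
print(np.allclose(elbo, brute))   # True
```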
Predictive Posterior Computations
(in general) found by sampling from the variational distn & running the model forward
however, DIRECT uses Kronecker matrix algebra to efficiently & exactly compute these predictive moments!
ex) GLM model
- exact predictive posterior mean : \(\mathbb{E}\left(y_{*}\right)=\sum_{i=1}^{m} q_{\boldsymbol{\theta}}\left(\mathbf{w}_{i}\right) \int y_{*} \operatorname{Pr}\left(y_{*} \mid \mathbf{w}_{i}\right) d y_{*}=\Phi_{*} \mathbf{W} \mathbf{q}=\Phi_{*} \mathbf{s}\)
- \(\mathbf{s}=\left\{\mathbf{q}_{j}^{T} \overline{\mathbf{w}}_{j}\right\}_{j=1}^{b} \in \mathbb{R}^{b}\)
- \(\Phi_{*} \in \mathbb{R}^{1 \times b}\) contains the basis functions evaluated at \(x_{*}\)
- requires just \(\mathcal{O}(b)\) time per test point
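A minimal sketch of the exact predictive mean \(\Phi_{*}\mathbf{s}\) under the same toy assumptions as above (shared support per weight, hypothetical learned `q_i`):

```python
import numpy as np

rng = np.random.default_rng(5)
b, m_bar = 3, 4
w_bar = np.linspace(-2.0, 2.0, m_bar)
q_i = rng.dirichlet(np.ones(m_bar), size=b)        # stand-in for a learned factorized variational dist
Phi_star = rng.normal(size=(1, b))                 # basis functions evaluated at a test input x_*

s = q_i @ w_bar                                    # s_j = q_j^T w_bar_j, computed once in O(b * m_bar)
y_star_mean = Phi_star @ s                         # exact predictive mean, O(b) per test point
print(y_star_mean)
```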
3-2. Deep Neural Networks for Regression
Hierarchical model structure for Bayesian DNN for regression
using the DIRECT approach!
would like a non-linear activation that maintains a compact representation of the log-likelihood evaluated at every support point
( that is, \(\log \ell\) should be represented as a sum of as few Kronecker product vectors as possible )
- ex) quadratic activation function ( \(f(x) = x^2\) )
ELBO can be exactly computed in \(\mathcal{O}\left(\ell \bar{m}(b / \ell)^{4 \ell}\right)\) for a Bayesian DNN with \(\ell\) layers
This complexity is independent of the number of training points \(n\) \(\rightarrow\) enables scalable Bayesian inference on large datasets ( although it grows exponentially with the number of layers \(\ell\) )
4. Limitations & Extensions
Models other than the GLMs & DNNs considered here may not admit this Kronecker structure!