Discretely Relaxing Continuous Variables for tractable Variational Inference (NeurIPS 2018)
Variational Inference with DISCRETE latent variable priors
propose “DIRECT” … advantage?
( DIRECT = DIscrete RElaxation of ConTinuous variables )
1) exactly compute ELBO gradients
2) training complexity is independent of number of training points
( permitting inference on large datasets )
3) fast inference on hardware limited devices
1. Introduction
Hardware restrictions!
solve this problem of efficient Bayesian Inference by considering DISCRETE latent variable models
posterior samples will be “quantized” \(\rightarrow\) leading to efficient inference
(generally) model with discrete prior : slow ( \(\because\) requiring the use of HIGH variance MC gradient estimates )
\(\rightarrow\) DIRECT rapidly learns the variational distn without the use of any stochastic estimators
Compared with SVI (with either continuous or discrete latent variables), reported to give better training and inference performance
Using discretized prior, can make use of Kronecker matrix algebra, for efficient & exact ELBO computation
Overall summary
- section 2) VI
- section 3) DIRECT
- section 4) limitations of proposed approach
2. Variational Inference Background
ELBO with continuous/discrete prior :
- (continuous) \(\mathrm{ELBO}(\boldsymbol{\theta})=\int q_{\boldsymbol{\theta}}(\mathbf{w})\left(\log \operatorname{Pr}(\mathbf{y} \mid \mathbf{w})+\log \operatorname{Pr}(\mathbf{w})-\log q_{\boldsymbol{\theta}}(\mathbf{w})\right) d \mathbf{w}\)
- (discrete) \(\mathrm{ELBO}(\boldsymbol{\theta})=\mathbf{q}^{T}(\log \ell+\log \mathbf{p}-\log \mathbf{q})\)
- \(\log \ell=\left\{\log \operatorname{Pr}\left(\mathbf{y} \mid \mathbf{w}_{i}\right)\right\}_{i=1}^{m}\).
- \(\log \mathbf{p}=\left\{\log \operatorname{Pr}\left(\mathbf{w}_{i}\right)\right\}_{i=1}^{m}\).
- \(\mathbf{q}=\left\{q_{\boldsymbol{\theta}}\left(\mathbf{w}_{i}\right)\right\}_{i=1}^{m}\).
- \(\left\{\mathbf{w}_{i}\right\}_{i=1}^{m}=\mathbf{W} \in \mathbb{R}^{b \times m}\).
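To make the discrete ELBO concrete, here is a minimal sketch that evaluates \(\mathbf{q}^{T}(\log \ell+\log \mathbf{p}-\log \mathbf{q})\) by brute force. The toy likelihood values, the uniform prior and the helper name `discrete_elbo` are illustrative assumptions, not taken from the paper.

```python
import numpy as np

def discrete_elbo(q, log_lik, log_prior):
    """ELBO = q^T (log_lik + log_prior - log q) over m support points."""
    q = np.asarray(q, dtype=float)
    return float(q @ (log_lik + log_prior - np.log(q)))

# Hypothetical toy problem with m = 4 support points.
log_lik = np.log(np.array([0.1, 0.4, 0.3, 0.2]))   # log Pr(y | w_i)
log_prior = np.log(np.full(4, 0.25))               # uniform prior Pr(w_i)
q_uniform = np.full(4, 0.25)

# The exact posterior maximizes the ELBO, where it equals the log evidence.
joint = np.exp(log_lik + log_prior)
log_evidence = float(np.log(joint.sum()))
q_posterior = joint / joint.sum()

elbo_uniform = discrete_elbo(q_uniform, log_lik, log_prior)
elbo_posterior = discrete_elbo(q_posterior, log_lik, log_prior)
```

Any other choice of \(\mathbf{q}\) (such as the uniform one here) gives an ELBO no larger than the log evidence. This direct sum costs \(\mathcal{O}(m)\), which is only feasible while \(m\) stays small.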
Computing the ELBO is challenging when \(b\) is large!
So the ELBO is usually not computed explicitly; instead, an MC estimate of its gradient w.r.t. the variational params \(\theta\) is used
Found that **discretely relaxing continuous latent variable priors can improve training and inference performance when using the proposed DIRECT technique, which computes the ELBO (and its gradient) directly**
Since the latent variables are discrete…
reparameterization trick (X) : not applicable
REINFORCE (O) : applicable, but higher variance
\(\rightarrow\) proposed DIRECT trains much faster!
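The variance point can be checked on a toy case. The sketch below is not the paper's experiment: it assumes a single categorical latent variable with a softmax parameterization (my choice), and compares the exact ELBO gradient with the single-sample score-function (REINFORCE) estimator. Because the support is tiny, the estimator's mean and variance are computed exactly by enumeration rather than by sampling.

```python
import numpy as np

def softmax(theta):
    z = np.exp(theta - theta.max())
    return z / z.sum()

def elbo(theta, g):
    """Discrete ELBO for one categorical variable; g = log-lik + log-prior."""
    q = softmax(theta)
    return float(q @ (g - np.log(q)))

def exact_grad(theta, g):
    """Closed-form gradient of the ELBO with respect to the logits theta."""
    q = softmax(theta)
    a = g - np.log(q)
    return q * (a - q @ a)

def score_function_terms(theta, g):
    """Row i is the single-sample REINFORCE estimate obtained when w_i is drawn."""
    q = softmax(theta)
    a = g - np.log(q)
    grad_log_q = np.eye(len(q)) - q      # row i holds d log q_i / d theta
    return a[:, None] * grad_log_q

# Hypothetical logits and log-joint values for m = 4 support points.
theta = np.array([0.5, -1.0, 0.2, 1.5])
g = np.log(np.array([0.1, 0.4, 0.3, 0.2])) + np.log(0.25)

q = softmax(theta)
terms = score_function_terms(theta, g)
reinforce_mean = q @ terms                          # expectation over w ~ q
reinforce_var = q @ (terms - reinforce_mean) ** 2   # per-component variance
grad = exact_grad(theta, g)
```

The estimator is unbiased (`reinforce_mean` matches `grad`), but each sample carries the nonzero variance in `reinforce_var`, whereas the exact gradient has none.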
3. DIRECT : Efficient ELBO Computations with Kronecker Matrix Algebra
DIRECT : allows the ELBO to be computed efficiently & exactly
several advantages over existing SVI techniques
consider a discrete prior over our latent variables,
whose support set \(\mathbf{W}\) forms a Cartesian tensor product grid
\(\mathbf{W}=\left(\begin{array}{ccccccc} \overline{\mathbf{w}}_{1}^{T} & \otimes & \mathbf{1}_{\bar{m}}^{T} & \otimes & \cdots & \otimes & \mathbf{1}_{\bar{m}}^{T} \\ \mathbf{1}_{\bar{m}}^{T} & \otimes & \overline{\mathbf{w}}_{2}^{T} & \otimes & \cdots & \otimes & \mathbf{1}_{\bar{m}}^{T} \\ \vdots & & \vdots & & \ddots & & \vdots \\ \mathbf{1}_{\bar{m}}^{T} & \otimes & \mathbf{1}_{\bar{m}}^{T} & \otimes & \cdots & \otimes & \overline{\mathbf{w}}_{b}^{T} \end{array}\right)\)
- \(\mathbf{1}_{\bar{m}} \in \mathbb{R}^{\bar{m}}\) denotes a vector of ones
- \(\overline{\mathbf{w}}_{i} \in \mathbb{R}^{\bar{m}}\) contains the \(\bar{m}\) discrete values that the \(i\)th latent variable \(w_{i}\) can take
- \(m=\bar{m}^{b}\) is the total number of support points (columns of \(\mathbf{W}\))
- \(\otimes\) denotes the Kronecker product
The number of columns of \(\mathbf{W} \in \mathbb{R}^{b \times \bar{m}^{b}}\) increases exponentially with \(b\), so storing or looping over \(\mathbf{W}\) is intractable for large \(b\)
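A small sketch of the grid construction, assuming NumPy's `np.kron` for the Kronecker product. The grid values and the helper name `tensor_grid` are hypothetical. The columns of the result are cross-checked against a plain Cartesian product.

```python
import itertools
from functools import reduce
import numpy as np

def tensor_grid(values):
    """Stack the b rows of W; row i is kron(1, ..., w_bar_i, ..., 1)."""
    rows = []
    for i in range(len(values)):
        factors = [v if j == i else np.ones_like(v) for j, v in enumerate(values)]
        rows.append(reduce(np.kron, factors))
    return np.vstack(rows)

# Hypothetical grid: b = 3 latent variables, m_bar = 4 values each.
w_bar = [np.array([-1.0, -0.5, 0.5, 1.0]) for _ in range(3)]
W = tensor_grid(w_bar)                      # shape (b, m_bar ** b) = (3, 64)

# Each column of W is one joint configuration (w_1, w_2, w_3).
columns = np.array(list(itertools.product(*w_bar)))
```

Already at \(b=3\), \(\bar{m}=4\) there are 64 columns; each extra latent variable multiplies that by \(\bar{m}\), which is why \(\mathbf{W}\) should never be formed explicitly in practice.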
Can alleviate this if \(\mathbf{q}\), \(\log \mathbf{p}\), \(\log \ell\), and \(\log \mathbf{q}\) can each be written as a **sum of Kronecker product vectors** (i.e. \(\sum_{i} \otimes_{j=1}^{b} \mathbf{f}_{j}^{(i)}\), with each \(\mathbf{f}_{j}^{(i)} \in \mathbb{R}^{\bar{m}}\))
Computation of ELBO : \(\mathcal{O}\left(\bar{m}^{b}\right) \rightarrow \mathcal{O}(b \bar{m})\)
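The reduction rests on a standard identity: the inner product of two Kronecker product vectors factorizes, \(\left(\otimes_{j} \mathbf{a}_{j}\right)^{T}\left(\otimes_{j} \mathbf{b}_{j}\right)=\prod_{j} \mathbf{a}_{j}^{T} \mathbf{b}_{j}\). A minimal check with random factors (the names `kron_all` and `kron_inner` are my own):

```python
from functools import reduce
import numpy as np

def kron_all(factors):
    """Explicit Kronecker product of a list of vectors (length m_bar ** b)."""
    return reduce(np.kron, factors)

def kron_inner(a_factors, b_factors):
    """Inner product of two Kronecker product vectors in O(b * m_bar)."""
    return float(np.prod([u @ v for u, v in zip(a_factors, b_factors)]))

rng = np.random.default_rng(0)
b, m_bar = 6, 3
a_factors = [rng.random(m_bar) for _ in range(b)]
b_factors = [rng.random(m_bar) for _ in range(b)]

fast = kron_inner(a_factors, b_factors)                    # b dot products of length m_bar
slow = float(kron_all(a_factors) @ kron_all(b_factors))    # touches m_bar ** b entries
```

For a sum of several Kronecker product vectors, the same identity is applied term by term.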
So, how to express \(\mathrm{ELBO}(\boldsymbol{\theta})=\mathbf{q}^{T}(\log \ell+\log \mathbf{p}-\log \mathbf{q})\) with Kronecker products?
(prior) \(\mathbf{p}=\otimes_{i=1}^{b} \mathbf{p}_{i}\), where \(\mathbf{p}_{i}=\left\{\operatorname{Pr}\left(w_{i}=\bar{w}_{i j}\right)\right\}_{j=1}^{\bar{m}} \in(0,1)^{\bar{m}}\)
\(\rightarrow\) this structure for \(\mathbf{p}\) enables \(\log \mathbf{p}\) to be written as a sum of \(b\) Kronecker product vectors.
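Why \(b\) terms: taking the log of a Kronecker product turns it into a sum, \(\log \mathbf{p}=\sum_{i=1}^{b} \mathbf{1} \otimes \cdots \otimes \log \mathbf{p}_{i} \otimes \cdots \otimes \mathbf{1}\). A quick numerical check, with randomly drawn (hypothetical) prior factors:

```python
from functools import reduce
import numpy as np

def kron_all(factors):
    return reduce(np.kron, factors)

rng = np.random.default_rng(1)
b, m_bar = 3, 4
p_factors = [rng.dirichlet(np.ones(m_bar)) for _ in range(b)]   # each p_i sums to 1

p = kron_all(p_factors)                     # full prior over m_bar ** b grid points

# Sum of b Kronecker product vectors, with log p_i in slot i and ones elsewhere.
log_p_sum = np.zeros(m_bar ** b)
for i in range(b):
    factors = [np.log(p_factors[j]) if j == i else np.ones(m_bar) for j in range(b)]
    log_p_sum += kron_all(factors)
```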
Rewrite ELBO :
(before) \(\mathrm{ELBO}(\boldsymbol{\theta})=\mathbf{q}^{T}(\log \ell+\log \mathbf{p}-\log \mathbf{q})\)
(after) \(\mathrm{ELBO}(\boldsymbol{\theta})=\mathbf{q}^{T} \log \ell+\sum_{i=1}^{b} \mathbf{q}_{i}^{T} \log \mathbf{p}_{i}-\sum_{i=1}^{b} \mathbf{q}_{i}^{T} \log \mathbf{q}_{i}\).
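\(\mathbf{q}_{i}\) : valid pmf for the \(i\)th latent variable, assuming a factorized (mean-field) variational distn \(\mathbf{q}=\otimes_{i=1}^{b} \mathbf{q}_{i}\)
( such that \(\mathbf{q}_{i}^{T} \mathbf{1}_{\bar{m}}=1\) )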
\(\log \ell\) : depends on the probabilistic model used
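The prior and entropy terms of the rewritten ELBO can be verified numerically. The sketch assumes a factorized \(\mathbf{q}=\otimes_{i} \mathbf{q}_{i}\) with each factor summing to one, plus a uniform prior chosen only for illustration; the \(\mathbf{q}^{T} \log \ell\) term is left out because it is model-specific.

```python
from functools import reduce
import numpy as np

def kron_all(factors):
    return reduce(np.kron, factors)

rng = np.random.default_rng(2)
b, m_bar = 4, 3
q_factors = [rng.dirichlet(np.ones(m_bar)) for _ in range(b)]
p_factors = [np.full(m_bar, 1.0 / m_bar) for _ in range(b)]

# Factorized form: O(b * m_bar) work.
prior_term = sum(qi @ np.log(pi) for qi, pi in zip(q_factors, p_factors))
entropy_term = -sum(qi @ np.log(qi) for qi in q_factors)

# Brute force over all m_bar ** b grid points, for comparison only.
q = kron_all(q_factors)
p = kron_all(p_factors)
prior_brute = q @ np.log(p)
entropy_brute = -(q @ np.log(q))
```

The two forms agree because every factor \(\mathbf{q}_{j}^{T} \mathbf{1}_{\bar{m}}\) with \(j \neq i\) equals one and drops out of the product.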
3-1. Generalized Linear Regression
( focus on popular class of Bayesian GLM )
GLM : \(\mathbf{y}=\mathbf{\Phi} \mathbf{w}+\boldsymbol{\epsilon}\)
- where \(\boldsymbol{\epsilon} \sim \mathcal{N}\left(\mathbf{0}, \sigma^{2} \mathbf{I}\right)\)
- and \(\boldsymbol{\Phi}=\left\{\phi_{j}\left(\mathbf{x}_{i}\right)\right\}_{i, j} \in \mathbb{R}^{n \times b}\)
using above.. ELBO :
\(\begin{array}{r} \mathrm{ELBO}(\boldsymbol{\theta})=-\frac{n}{2} \mathbf{q}_{\sigma}^{T} \log \sigma^{2}-\frac{1}{2}\left(\mathbf{q}_{\sigma}^{T} \sigma^{-2}\right)\left(\mathbf{y}^{T} \mathbf{y}-2 \mathbf{s}^{T}\left(\mathbf{\Phi}^{T} \mathbf{y}\right)+\mathbf{s}^{T} \mathbf{\Phi}^{T} \mathbf{\Phi} \mathbf{s}-\operatorname{diag}\left(\mathbf{\Phi}^{T} \mathbf{\Phi}\right)^{T} \mathbf{s}^{2}+\right. \\ \left.\sum_{j=1}^{b} \mathbf{q}_{j}^{T} \mathbf{h}_{j}\right)+\sum_{i=1}^{b}\left(\mathbf{q}_{i}^{T} \log \mathbf{p}_{i}-\mathbf{q}_{i}^{T} \log \mathbf{q}_{i}\right)+\mathbf{q}_{\sigma}^{T} \log \mathbf{p}_{\sigma}-\mathbf{q}_{\sigma}^{T} \log \mathbf{q}_{\sigma} \end{array}\)
- \(\mathbf{q}_{\sigma}, \mathbf{p}_{\sigma} \in \mathbb{R}^{\bar{m}}\) : factorized variational and prior distns over the Gaussian noise variance \(\sigma^{2}\)
- \(\sigma^{2} \in \mathbb{R}^{\bar{m}}\) : the discrete positive values the noise variance can take
- \(\mathbf{H}=\left\{\overline{\mathbf{w}}_{j}^{2} \sum_{i=1}^{n} \phi_{i j}^{2}\right\}_{j=1}^{b} \in \mathbb{R}^{\bar{m} \times b}\), with \(\mathbf{h}_{j}\) its \(j\)th column
- \(\mathbf{s}=\left\{\mathbf{q}_{j}^{T} \overline{\mathbf{w}}_{j}\right\}_{j=1}^{b} \in \mathbb{R}^{b}\).
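Complexity : \(\mathcal{O}\left(b \bar{m}+b^{2}\right)\), given \(\mathbf{\Phi}^{T} \mathbf{\Phi}\), \(\mathbf{\Phi}^{T} \mathbf{y}\) and \(\mathbf{y}^{T} \mathbf{y}\) precomputed once before training
\(\rightarrow\) independent of the number of training points \(n\) ( scalability )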
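A sketch of the GLM ELBO above, checked against brute-force enumeration on a tiny problem. Everything concrete here is an assumption for illustration: the synthetic data, the uniform priors, the grids, and the function names. The brute-force version drops the constant \(-\frac{n}{2} \log 2 \pi\) so that it matches the expression as written in these notes.

```python
import itertools
import numpy as np

def glm_elbo(q_w, w_bar, q_sig, sig2, log_p_w, log_p_sig, PtP, Pty, yty, n):
    """ELBO from the precomputed statistics Phi^T Phi, Phi^T y and y^T y only."""
    s = np.array([q @ w for q, w in zip(q_w, w_bar)])               # E[w_j]
    d = np.diag(PtP)                                                # sum_i phi_ij^2
    qh = sum(q @ (w ** 2 * dj) for q, w, dj in zip(q_w, w_bar, d))  # sum_j q_j^T h_j
    sq_err = yty - 2 * s @ Pty + s @ PtP @ s - d @ s ** 2 + qh      # E||y - Phi w||^2
    elbo = -0.5 * n * (q_sig @ np.log(sig2)) - 0.5 * (q_sig @ (1.0 / sig2)) * sq_err
    elbo += sum(q @ (lp - np.log(q)) for q, lp in zip(q_w, log_p_w))
    elbo += q_sig @ (log_p_sig - np.log(q_sig))
    return float(elbo)

def glm_elbo_brute(q_w, w_bar, q_sig, sig2, log_p_w, log_p_sig, Phi, y):
    """Same quantity by summing over every (w, sigma^2) grid point."""
    n, total = len(y), 0.0
    for idx in itertools.product(*[range(len(w)) for w in w_bar]):
        w = np.array([w_bar[j][i] for j, i in enumerate(idx)])
        qw = np.prod([q_w[j][i] for j, i in enumerate(idx)])
        lpw = sum(log_p_w[j][i] for j, i in enumerate(idx))
        r2 = np.sum((y - Phi @ w) ** 2)
        for k in range(len(sig2)):
            log_lik = -0.5 * n * np.log(sig2[k]) - 0.5 * r2 / sig2[k]
            weight = qw * q_sig[k]
            total += weight * (log_lik + lpw + log_p_sig[k] - np.log(weight))
    return float(total)

rng = np.random.default_rng(3)
n, b, m_bar = 50, 3, 4
Phi = rng.normal(size=(n, b))
y = Phi @ np.array([0.5, -1.0, 1.0]) + 0.1 * rng.normal(size=n)
w_bar = [np.linspace(-1.0, 1.0, m_bar) for _ in range(b)]
sig2 = np.array([0.01, 0.1, 1.0, 10.0])
q_w = [rng.dirichlet(np.ones(m_bar)) for _ in range(b)]
q_sig = rng.dirichlet(np.ones(m_bar))
log_p_w = [np.log(np.full(m_bar, 1.0 / m_bar)) for _ in range(b)]
log_p_sig = np.log(np.full(m_bar, 1.0 / m_bar))

# The data enter only through these three statistics, computed once.
PtP, Pty, yty = Phi.T @ Phi, Phi.T @ y, y @ y
fast = glm_elbo(q_w, w_bar, q_sig, sig2, log_p_w, log_p_sig, PtP, Pty, yty, n)
slow = glm_elbo_brute(q_w, w_bar, q_sig, sig2, log_p_w, log_p_sig, Phi, y)
```

After the one-off pass over the data to form `PtP`, `Pty` and `yty`, each call to `glm_elbo` never touches `Phi` or `y` again, which is where the independence from \(n\) comes from.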
Predictive Posterior Computations
(in general) found by sampling from the variational distn & running the model forward
however, DIRECT uses Kronecker matrix algebra to compute the predictive posterior moments efficiently and exactly!
ex) GLM model
- exact predictive posterior mean : \(\mathbb{E}\left(y_{*}\right)=\sum_{i=1}^{m} q\left(\mathbf{w}_{i}\right) \int y_{*} \operatorname{Pr}\left(y_{*} \mid \mathbf{w}_{i}\right) d y_{*}=\Phi_{*} \mathbf{W} \mathbf{q}=\Phi_{*} \mathbf{s}\)
- \(\mathbf{s}=\left\{\mathbf{q}_{j}^{T} \overline{\mathbf{w}}_{j}\right\}_{j=1}^{b} \in \mathbb{R}^{b}\)
- \(\Phi_{*} \in \mathbb{R}^{1 \times b}\) contains the basis functions evaluated at \(x_{*}\)
- requires just \(\mathcal{O}(b)\) time per test point
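A small check that \(\Phi_{*} \mathbf{W} \mathbf{q}=\Phi_{*} \mathbf{s}\), forming \(\mathbf{W}\) and \(\mathbf{q}\) explicitly only for comparison. The grids, variational factors and test features are randomly generated placeholders.

```python
from functools import reduce
import numpy as np

rng = np.random.default_rng(4)
b, m_bar = 3, 4
w_bar = [np.sort(rng.normal(size=m_bar)) for _ in range(b)]     # grid per variable
q_factors = [rng.dirichlet(np.ones(m_bar)) for _ in range(b)]   # mean-field q_j
phi_star = rng.normal(size=(1, b))          # basis functions at a test input x_*

# s costs O(b * m_bar) once; afterwards each test point is an O(b) dot product.
s = np.array([q @ w for q, w in zip(q_factors, w_bar)])
mean_fast = (phi_star @ s).item()

# Brute force: build the b x m_bar**b grid and the full q vector.
ones = np.ones(m_bar)
W = np.vstack([
    reduce(np.kron, [w_bar[j] if j == i else ones for j in range(b)])
    for i in range(b)
])
q = reduce(np.kron, q_factors)
mean_brute = (phi_star @ W @ q).item()
```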
3-2. Deep Neural Networks for Regression
Hierarchical model structure for Bayesian DNN for regression
using DIRECT approach!
would like a non-linear activation that maintains a compact representation of the log-likelihood evaluated at every point of the grid
( that is, \(\log \ell\) to be represented as a sum of as few Kronecker product vectors as possible )
- ex) quadratic activation function ( \(f(x) = x^2\) )
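One way to see why a polynomial activation helps: for a single unit with quadratic activation and a factorized \(q\), the expected output depends only on per-weight first and second moments. The sketch below illustrates that one fact; it is not the paper's hierarchical DNN model, and the inputs and grids are made up.

```python
import itertools
import numpy as np

rng = np.random.default_rng(5)
b, m_bar = 3, 3
x = rng.normal(size=b)                                     # inputs to one unit
w_bar = [np.array([-1.0, 0.0, 1.0]) for _ in range(b)]     # grid for each weight
q_factors = [rng.dirichlet(np.ones(m_bar)) for _ in range(b)]

s = np.array([q @ w for q, w in zip(q_factors, w_bar)])        # E[w_j]
v = np.array([q @ w ** 2 for q, w in zip(q_factors, w_bar)])   # E[w_j^2]

# E[(x^T w)^2] = (x^T s)^2 + sum_j x_j^2 * Var(w_j), in O(b * m_bar).
fast = (x @ s) ** 2 + (x ** 2) @ (v - s ** 2)

# Brute force over all m_bar ** b weight configurations.
brute = 0.0
for idx in itertools.product(range(m_bar), repeat=b):
    w = np.array([w_bar[j][i] for j, i in enumerate(idx)])
    prob = np.prod([q_factors[j][i] for j, i in enumerate(idx)])
    brute += prob * (x @ w) ** 2
```

An activation such as a sigmoid has no finite polynomial expansion of this kind, so the same moment-based shortcut would not apply directly.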
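ELBO can be exactly computed in \(\mathcal{O}\left(\ell \bar{m}(b / \ell)^{4 \ell}\right)\) for a Bayesian DNN with \(\ell\) layers ( here \(\ell\) is the number of layers, not the likelihood vector )
As with the GLM, this cost does not depend on the number of training points, but the exponent \(4 \ell\) makes it grow quickly with depth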
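To get a feel for the stated bound, it can simply be evaluated for a few depths. This is plain arithmetic on the formula with constants ignored. The values \(b=100\) and \(\bar{m}=4\) are arbitrary, and \(b\) is read as the total number of latent variables, as in the earlier sections (an assumption, since it is not restated here).

```python
def dnn_elbo_cost(depth, b, m_bar):
    """Evaluate depth * m_bar * (b / depth) ** (4 * depth), ignoring constants."""
    return depth * m_bar * (b / depth) ** (4 * depth)

# Hypothetical sizes: b = 100 latent variables, m_bar = 4 values each.
costs = {depth: dnn_elbo_cost(depth, b=100, m_bar=4) for depth in (1, 2, 3)}
```

For these values the count rises by several orders of magnitude with each added layer, so the exact computation looks practical mainly for shallow networks or small layers.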
4. Limitations & Extensions
Other models ( beyond the GLMs and quadratic-activation DNNs above ) may not admit a compact Kronecker product representation of \(\log \ell\)!