Using Large Ensembles of Control Variates for Variational Inference (NeurIPS 2018)
VI : used with “stochastic optimization”
- “gradient’s variance” is important ( high \(\rightarrow\) poor convergence )
- use a **large number of control variates**!
- present a Bayesian risk minimization framework
1. Introduction
VI = “approximate probabilistic inference”
nowadays, VI addresses a wider range of problems by adopting a black-box view, based on
only evaluating the value or gradient of the target distribution
( optimized via stochastic gradient descent (SGD) )
use CVs ( control variates ) to reduce variance of gradient estimates
Proposes “how to use many CVs”
- 1) propose a systematic view of “how to generate many existing CVs”
- 2) propose an algorithm to use “multiple CVs simultaneously”
2. Preliminaries
2-1. ELBO
VI = transform inference problem into “optimization”, by…
decomposing the “marginal likelihood” as below :
\(\log p(x)=\underbrace{\underset{Z \sim q_{w}(Z)}{\mathbb{E}}\left[\log \frac{p(Z, x)}{q_{w}(Z)}\right]}_{\operatorname{ELBO}(w)}+\underbrace{\operatorname{KL}\left(q_{w}(Z) \mid \mid p(Z \mid x)\right)}_{\mathrm{KL}-\text { divergence }}\).
(target) maximize ELBO
target’s gradient :
\(g(w)=\nabla_{w} \operatorname{ELBO}(w)=\nabla_{w} \underset{Z \sim q_{w}(Z)}{\mathbb{E}}\left[\log p(Z, x)-\log q_{w}(Z)\right]\).
2-2. Control Variates
CV = “r.v with expectation 0”
ex) let \(C\) be CV
let \(Y=X+aC\) then…
- same mean ) \(E[Y] = E[X + aC] = E[X]\), since \(E[C]=0\)
- smaller variance ) with the optimal \(a\) below, \(\operatorname{Var}(Y)=\operatorname{Var}(X)\left(1-\operatorname{Corr}(X, C)^{2}\right)\).
- optimal : \(a=-\operatorname{Cov}(X, C) / \operatorname{Var}(C)\) ( negative sign, because \(Y = X + aC\) )
- good CV for \(X\) = \(C\) that is highly correlated with \(X\)
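A minimal numerical sketch of the bullets above. The toy choice \(X = e^{U}\), \(C = U - 0.5\) with \(U \sim \text{Uniform}(0,1)\) is an assumption made only for illustration and is not from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy setup (illustrative assumption): U ~ Uniform(0, 1),
# X = exp(U) is the quantity of interest, C = U - 0.5 has mean zero.
u = rng.uniform(size=100_000)
x = np.exp(u)
c = u - 0.5

# Optimal weight for Y = X + a*C, estimated from the samples.
a_opt = -np.cov(x, c)[0, 1] / np.var(c, ddof=1)
y = x + a_opt * c

# Compare the empirical variance of Y with Var(X) * (1 - Corr(X, C)^2).
corr = np.corrcoef(x, c)[0, 1]
var_x = np.var(x, ddof=1)
var_y = np.var(y, ddof=1)
predicted_var_y = var_x * (1.0 - corr**2)

print(f"mean X = {x.mean():.4f}, mean Y = {y.mean():.4f}")
print(f"Var X = {var_x:.5f}, Var Y = {var_y:.5f}, predicted = {predicted_var_y:.5f}")
```

Because \(a\) is estimated from the same samples, the variance identity holds exactly in-sample; with a weight estimated on separate samples it would hold only approximately.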
3. Systematic Generation of Control Variates
recipe for creating control variates!
split ELBO gradient into 4 terms :
$$g(w)=\underbrace{\nabla_{w} \underset{q_{w}}{\mathbb{E}} \log p(x \mid Z)}_{g_{1}(w): \text { Data term }}+\underbrace{\nabla_{w} \underset{q_{w}}{\mathbb{E}} \log p(Z)}_{g_{2}(w): \text { Prior term }}-\underbrace{\left.\nabla_{w} \underset{q_{w}}{\mathbb{E}} \log q_{v}(Z)\right|_{v=w}}_{g_{3}(w): \text { Variational term }}-\underbrace{\left.\nabla_{w} \underset{q_{v}}{\mathbb{E}} \log q_{w}(Z)\right|_{v=w}}_{g_{4}(w): \text { Score term }}$$
- [ term 1~3 ] influence of \(w\) on the expectation of some function “independent of \(w\)“
- section 3-1 & 3-2
- [ term 4 ] score term ( function “inside” the expectation depends on \(w\) )
- section 3-3
3-1. CV from pairs of estimators
technique for deriving CVs = take the difference between a pair of unbiased estimators of a general term \(t(w)\) ( = \(g_1, g_2, g_3\), or a combination of them ). Both estimators have mean \(t(w)\), so the difference has expectation zero
general term : \(\begin{aligned} &t(w)=\nabla_{w} \underset{q_{w}(Z)}{\mathbb{E}}[f(Z)]\quad \text { or } \quad \left.\left(\nabla_{w} \underset{q_{w}(Z)}{\mathbb{E}}\left[f_{v}(Z)\right]\right)\right|_{v=w} \end{aligned}\)
example) 5 types of estimator for \(t(w)\)
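A small sketch of the "pair of estimators" idea. The two estimators used here (reparameterization and score-function) are standard choices picked for illustration, and the setup \(q_w = \mathcal{N}(w, 1)\), \(f(z) = z^2\) is an assumption, not an example from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
w, n = 1.5, 200_000
z = w + rng.standard_normal(n)   # Z ~ N(w, 1), written as z = w + eps

# f(z) = z**2, so t(w) = d/dw E[Z^2] = d/dw (w^2 + 1) = 2w.
reparam = 2.0 * z                # d/dw f(w + eps) evaluated at the sample
score_fn = z**2 * (z - w)        # f(z) * d/dw log q_w(z)

# Both estimators are unbiased for t(w), so their difference has mean zero
# and can serve as a control variate.
cv = reparam - score_fn

print(f"t(w) = {2 * w:.3f}")
print(f"reparam mean = {reparam.mean():.3f}, score-function mean = {score_fn.mean():.3f}")
print(f"cv mean = {cv.mean():.3f}")
```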
3-2. CV from approximations
in 3-1… use the fact that “difference of 2 unbiased estimators has expectation zero”
in 3-2… use the fact that “if general term \(t(w)\) is replaced with APPROXIMATION, difference between 2 estimators of the general term still produces a valid CV”
Randomness in the estimators above : due to….
- 1) “distributional sampling error”
- expectations over \(q_w\) are approximated by sampling
- 2) “data subsampling error”
- by drawing a mini-batch
1) Correcting for distributional sampling
- goal : approximate \(f\) with \(\tilde{f}\) to make \(\mathbb{E}[\tilde{f}(Z)]\) easier to estimate
- ex) Paisley et al : approximate the data term with “Taylor approximation in \(z\)“
2) Correcting for data subsampling
random subsets of data
ex) Wang et al : propose to approximate \(f_d(z)\) with “Taylor expansion in \(x\)“
thus leading to an approximate data term, \(\tilde{g}_{1}(w)=\nabla_{w} \mathbb{E}_{q_{w}} \mathbb{E}_{D} \tilde{f}_{D}(Z)\)
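A hedged sketch of a CV built from an approximation, for the distributional sampling case. Everything specific here is an assumption for illustration: \(f(z) = \sin z\), \(q_w = \mathcal{N}(w, \sigma^2)\), and a second-order Taylor expansion \(\tilde{f}\) around a fixed point \(w_0\). It is not the model used by Paisley et al. or Wang et al. The CV is the sampled gradient of \(\tilde{f}\) minus its closed-form expectation.

```python
import numpy as np

rng = np.random.default_rng(0)
w, sigma, n = 1.0, 0.5, 100_000
z = w + sigma * rng.standard_normal(n)   # Z ~ N(w, sigma^2)

# Base estimator: reparameterization gradient of E[sin(Z)] w.r.t. the mean w.
h = np.cos(z)

# Second-order Taylor expansion of sin around w0 (treated as a constant):
#   f_tilde(z) = sin(w0) + cos(w0) (z - w0) - 0.5 sin(w0) (z - w0)^2
w0 = w
approx_grad = np.cos(w0) - np.sin(w0) * (z - w0)        # sampled gradient of f_tilde
approx_grad_mean = np.cos(w0) - np.sin(w0) * (w - w0)   # its expectation, in closed form
c = approx_grad - approx_grad_mean                       # zero-mean control variate

# Scalar optimal weight, estimated from the samples.
a = -np.cov(h, c)[0, 1] / np.var(c, ddof=1)
g_hat = h + a * c

exact = np.cos(w) * np.exp(-sigma**2 / 2)   # closed form of E[cos(Z)]
print(f"exact = {exact:.4f}, base mean = {h.mean():.4f}, with CV = {g_hat.mean():.4f}")
print(f"Var base = {np.var(h):.4f}, Var with CV = {np.var(g_hat):.4f}")
```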
3-3. CV from the score term \(g_4\)
score term = always zero! ( \(g_4(w)=0\) ) … thus it does not need to be estimated
naive control variate : \(T_4 = \nabla_w \log q_w(Z)\), a sampled estimate of \(g_4(w)\), which therefore has expectation zero
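why the score term vanishes ( a short derivation, assuming \(q_w\) is smooth enough to exchange \(\nabla_w\) and the integral ) :

$$g_{4}(w)=\underset{q_{w}}{\mathbb{E}}\left[\nabla_{w} \log q_{w}(Z)\right]=\int q_{w}(z) \frac{\nabla_{w} q_{w}(z)}{q_{w}(z)} d z=\nabla_{w} \int q_{w}(z) d z=\nabla_{w} 1=0$$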
4. Combining multiple CVs
to use CV, need to define…
- a base gradient estimator \(h(w) \in R^D\)
- set of CVs \(\{c_1,...,c_L\}\) where \(c_i \in R^D\)
Multiply each \(c_i\) with weight \(a_i\)… to get the estimator :
\(\hat{g}(w)=h(w)+\sum_{i=1}^{L} a_{i} c_{i}(w)\).
in matrix form : \(\hat{g}(w)=h(w)+C(w) a\)
- \(a \in \mathbb{R}^{L}\) : vector of weights
- \(C \in \mathbb{R}^{D \times L}\) as the matrix with \(c_{i}\) as the i-th column
Goal : find \(a\) such that final gradient has low variance!
[Lemma 4.1]
Let \(h(w) \in \mathbb{R}^{D}\) be a random variable, \(C(w) \in \mathbb{R}^{D \times L}\) a matrix of random variables such that each element has mean zero. For \(a \in \mathbb{R}^{L}\), define \(\hat{g}(w)=h(w)+C(w) a\).
\(\rightarrow\) The value of \(a\) that minimizes \(\mathbb{E}\mid \mid\hat{g}(w)\mid \mid^{2}\) for a given \(w\) is
\(a^{*}(w)=-\underset{p(C, h \mid w)}{\mathbb{E}}\left[C^{T} C\right]^{-1} \underset{p(C, h \mid w)}{\mathbb{E}}\left[C^{T} h\right]\).
- requires \(\mathbb{E}\left[C^{T} C\right]\) and \(\mathbb{E}\left[C^{T} h\right]\) ( usually not available in closed form )
- solution ) use some observed gradients \(h_1,...,h_M\) & CVs \(C_1,...,C_M\) to estimate \(a^{*}\)
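A minimal sketch of the plain empirical estimate of \(a^{*}\) from Lemma 4.1, replacing both expectations with averages over \(M\) samples. The helper name `estimate_cv_weights` and the synthetic data are hypothetical, and this is the unregularized plug-in estimate, not the Bayesian decision rule of section 4-1.

```python
import numpy as np

def estimate_cv_weights(hs, Cs):
    """Plug-in estimate of a* = -E[C^T C]^{-1} E[C^T h].

    hs: array of shape (M, D), the M base gradient samples.
    Cs: array of shape (M, D, L), the M control variate matrices.
    """
    M = hs.shape[0]
    CtC = np.einsum("mdi,mdj->ij", Cs, Cs) / M   # average of C^T C, shape (L, L)
    Cth = np.einsum("mdi,md->i", Cs, hs) / M     # average of C^T h, shape (L,)
    return -np.linalg.solve(CtC, Cth)

# Synthetic data (illustrative assumption): h = g + C b + small noise,
# so the optimal weights are a* = -b.
rng = np.random.default_rng(0)
M, D, L = 5000, 3, 2
g_true = np.array([1.0, -1.0, 0.5])
b = np.array([0.5, -2.0])
Cs = rng.standard_normal((M, D, L))              # zero-mean control variates
hs = g_true + Cs @ b + 0.1 * rng.standard_normal((M, D))

a_hat = estimate_cv_weights(hs, Cs)
g_hat = hs + Cs @ a_hat                          # combined estimator, shape (M, D)

print("a_hat =", a_hat)
print("E||h||^2     ~", np.mean(np.sum(hs**2, axis=1)))
print("E||g_hat||^2 ~", np.mean(np.sum(g_hat**2, axis=1)))
```

Here the same samples are used both to fit \(a\) and to form \(\hat{g}\), purely to keep the sketch short. With a small \(M\) and a large \(L\) this plug-in estimate can be noisy, which is the motivation for the regularized view that follows.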
4-1. Bayesian Regularization
“risk minimization” perspective
[Loss Function] for selecting the vector of weights \(a\), when true parameter is \(\theta\) :
- loss function : \(L(a, \theta)=\underset{C, h \mid \theta}{\mathbb{E}}\mid \mid h+C a\mid \mid^{2}\)
- decision rule : \(\alpha\left(C_{1}, h_{1}, \ldots, C_{M}, h_{M}\right)\)
- which takes as input a “mini-batch” of \(M\) evaluations of \(h\) and \(C\)
Then, for a pre-specified probabilistic model \(p(C, h, \theta)\)….
- Bayesian regret : \(\text { BayesRegret }(\alpha)=\underset{\theta}{\mathbb{E}} \underset{C_{1}, h_{1}, \ldots, C_{M}, h_{M} \mid \theta}{\mathbb{E}}\left[L\left(\alpha\left(C_{1}, h_{1}, \ldots, C_{M}, h_{M}\right), \theta\right)\right]\).
5. Conclusion
Focus on how to obtain low variance gradients given a fixed set of control variates!
Combination algorithm to use multiple control variates