Structured Stochastic Variational Inference ( 2014 )
Abstract
SVI(Stochastic Variational Inference)는 large dataset의 경우에 적합한 posterior를 근사하는 알고리즘이다. 주로 Mean Field 가정을 통해, parameter들 간의 independence를 가정하여 factorization을 한다.
이 논문은, 이러한 mean field approximation 가정을 완화하여, parameter들 간의 dependency를 고려한 알고리즘을 제안한다. 이를 SSVI (Structured Stochastic Variational Inference)라고 한다.
1. Introduction
VI와 MCMC는 approximation inference algorithm의 대표적인 두 가지 방법이다. (각각의 장단점에 대한 설명은 이전 포스트들을 참고하길 바란다)
그 중에서, 이 논문은 VI를 사용하여, multimodal posterior를 가진 고차원의 베이지안 모델 (ex. mixture model, topic model, factor model…)의 문제를 풀고자 한다.
앞서 말했듯, Mean-field는 parameter들 간의 independency를 가정하여 factorized distribution으로 표현한다. 이러한 제약은, tractability 를 가져오는 대신에, approximation이 아주 잘 근사하지는 못한다는 점과, local optimum에 빠질 수 있다는 두가지 문제를 가지고 있다.
이를 해결하고자, parameter들 간의 dependency를 부여하여 위 가정을 완화시키는 방법이 있는데, 이것을 “structured” mean-field approximation이라고 한다. 이는 intractable한 ELBO를 만들지만, 최근에 나온 stochastic optimization등의 방법을 통해 이 또한 풀 수 있다.
Stochastic Variational Inference (SVI)는 MFVI에 stochastic optimization을 적용한 방법이다. SVI는 hierarchical model에 있는 unobserved variable을, (1) global parameter \(\beta\)와 **(2) local hidden variable \(z_1,...z_N\) **으로 쪼갠다. 그런 뒤, \(q(z, \beta)\)와 \(p(z,\beta \mid y)\)사이ㅣ의 KL Divergence를 minimize하는 방향으로 ( 혹은 ELBO를 maximize하는 방향으로 ) 문제를 풀어나간다.
방금 설명한 SVI도, “mean field approximation”에 의존한다. 즉, random variable들 사이의 dependency를 고려하지 못한다. 따라서 이 논문에서는, SVI framework의 generalization인 SSVI (Structued SVI)를 제안하여, global과 local variable들 사이의 dependency를 부여한다.
- Mean Field approximation : \(\left(\prod_{k} q\left(\beta_{k}\right)\right) \prod_{n, m} q\left(z_{n, m}\right)\)
- Structured Mean Field approximation : \(q(z, \beta)=\left(\prod_{k} q\left(\beta_{k}\right)\right) \prod_{n} q\left(z_{n} \mid \beta\right)\)
2. SSVI
이 파트에서는 2가지 SSVI 알고리즘을 제안할 것이다. 그 이전에, (2-1)SSVI가 적용될 수 있는 모델과, 그것이 사용하는 (2-2) variational distribution에 대해서 review할 것이다.
2-1. Model Assumption.
총 N개의 dataset이 있다고 가정하자. \(y_{1:N}\)
이 경우, probability model은 다음과 같이 factorize될 수 있다.
\(p(y, z, \beta)=p(\beta) \prod_{n} p\left(y_{n}, z_{n} \mid \beta\right)\).
- \(\beta\) : global parameter ( 모든 observation에 대해서 share된다)
- \(z_{1:N}\) : \(\beta\)가 주어졌을 때, conditionally independent하다
우리의 모델에 다음과 같은 가정을 할 것이다. ( conditionally conjugate model )
(1) prior : \(p(\beta)\)
-
tractable exponential family
-
\(p(\beta)=h(\beta) \exp \{\eta \cdot t(\beta)-A(\eta)\}\).
(2) likelihood : \(p\left(y_{n}, z_{n} \mid \beta\right)\)
- \(p\left(y_{n}, z_{n} \mid \beta\right)=\exp \{t(\beta)\left.\eta_{n}\left(y_{n}, z_{n}\right)+g_{n}\left(y_{n}, z_{n}\right)\right\}\).
(3) posterior : \(p(\beta \mid y, z)\)
- \[\begin{array}{c} p(\beta \mid y, z)=h(\beta) \exp \left\{\left(\eta+\sum_{n} \eta_{n}\left(y_{n}, z_{n}\right)\right) \cdot t(\beta)-\right. \left.A\left(\eta+\sum_{n} \eta_{n}\left(y_{n}, z_{n}\right)\right)\right\} \end{array}\]
위의 (1)~(3)의 가정은, Hoffman(2013)이 제안했던 모델보다 더 weak한 가정이다.
이 논문은 \(p(z_n \mid y_n, \beta)\)의 tractability에 대한 가정을 요하지 않는다. 따라서 이 모델은 SVI frameowkr에 fit하는 그 어떠한 모델에도 적용이 가능하다 ( ex. Mixture Model, LDA, HMM, Kalman filter 등 )
2-2. Approximating Distribution
우리의 목표는, true intractable posterior in \(p(z,\beta \mid y)\)를, approximating distribution \(q(z, \beta)\)를 사용해서 잘 근사하는 것이다. 그러기 위해, 이 둘 간의 KL divergence를 minimize한다.
가장 간단한 방법은 mean-field approximation으로, \(q(z,\beta)\)를 아래와 같이 factorize하는 것이다.
\(q(z,\beta)=q(\beta)\prod_n \prod_m q(z_{n,m})\).
간단하다는 장점이 있지만, 다음과 같은 2가지의 단점이 있다.
- 1) less able to closely approximate the true posterior
- 2) may introduce additional local minima into the KL divergence.
Structured mean field는, 위 가정을 완화시켜서 dependence를 부여한다.
Minmno et al (2012)는 LDA를 위해 SSVI를 도입했다. 이에 따르면, \(q(z,\beta)\)는 다음과 같이 factorize된다.
\(q(z, \beta)=\left(\prod_{k} q\left(\beta_{k}\right)\right) \prod_{n} q\left(z_{n}\right)\).
( 여기서 최적의 \(q(z_n)\)은 intractable할 수 있으나, MCMC를 통해 sampling될 수 있다. )
이 논문에서는 \(\beta\)와 \(z\)사이의 dependence를 부여하는 SSVI를 제안한다.
여기서는, \(q(z,\beta)\)를 다음과 같이 factorize한다.
\(q(z, \beta)=\left(\prod_{k} q\left(\beta_{k}\right)\right) \prod_{n} q\left(z_{n} \mid \beta\right)\).
Assumption
-
\(q(\beta)\)이 prior \(p(\beta)\)와 동일한 exponential family를 따른다
( 즉, \(q(\beta)=h(\beta) \exp \{\lambda \cdot t(\beta)-A(\lambda)\}\) )
-
\(q(z_n \mid \beta) = q(z_n \mid \gamma_n(\beta))\) (\(q\)로 하여금 rich dependency를 가지게끔, \(\gamma_n(\cdot)\)을 도입 함)
2-3. The Structured Variational Objective
우리의 목표는 \(q(\beta,z)\)를 찾는 것이다. 그러기 위해 구해야 하는 KL-Divergence와 ELBO는 아래와 같다.
\(\begin{aligned} \mathrm{KL}\left(q_{z, \beta} \| p_{z, \beta \mid y}\right)=&-\mathbb{E}_{q}[\log p(y, z, \beta)]+ \mathbb{E}_{q}[\log q(z, \beta)]+\log p(y) \end{aligned}\).
\(\begin{aligned} \mathcal{L} \equiv & \mathbb{E}_{q}[\log p(y, z, \beta)]-\mathbb{E}_{q}[\log q(z, \beta)] \\ &=\mathbb{E}_{q}\left[\log \frac{p(\beta)}{q(\beta)}\right]+\sum_{n} \mathbb{E}_{q}\left[\log \frac{p\left(y_{n}, z_{n} \mid \beta\right)}{q\left(z_{n} \mid \beta\right)}\right] \\ &=\int_{\beta} q(\beta)\left(\log \frac{p(\beta)}{q(\beta)}+\sum_{n} \int_{z_{n}} q\left(z_{n} \mid \beta\right) \log \frac{p\left(y_{n}, z_{n} \mid \beta\right)}{q\left(z_{n} \mid \beta\right)} d z_{n}\right) d \beta\\ &\leq \log p(y) \end{aligned}\).
위 식에서, \(\int_{z_{n}} q\left(z_{n} \mid \beta\right) \log \frac{p\left(y_{n}, z_{n} \mid \beta\right)}{q\left(z_{n} \mid \beta\right)} d z_{n}\) 는 그 자체로 \(n\)번째 observation group의 marginal probability에 대한 Lower Bound이다.
\[\begin{array}{l} \int_{z_{n}} q\left(z_{n} \mid \beta\right) \log \frac{p\left(y_{n}, z_{n} \mid \beta\right)}{q\left(z_{n} \mid \beta\right)} d z_{n} =-\operatorname{KL}\left(q_{z_{n} \mid \beta}|| p_{z_{n} \mid y_{n}, \beta}\right)+\log p\left(y_{n} \mid \beta\right) \leq \log p\left(y_{n} \mid \beta\right) \end{array}\]따라서, 우리는 \(q\left(z_{n} \mid \beta\right)\)와 \(p(z_n \mid y_n, \beta)\)의 KL divergence를 최소화함으로써, \(q\left(z_{n} \mid \beta\right)\)의 global ELBO를 최대화할 수 있다.
그리고 앞서 도입한 \(\gamma_n(\cdot)\)를, 위의 역할을 하게 끔 정의할 것이다.
\(\nabla_{\gamma_n}\int_{z_{n}} q\left(z_{n} \mid \beta\right) \log \frac{p\left(y_{n}, z_{n} \mid \beta\right)}{q\left(z_{n} \mid \beta\right)} d z_{n}=0\).
2-4. Algorithm
ELBO인 \(\int_{\beta} q(\beta)\left(\log \frac{p(\beta)}{q(\beta)}+\sum_{n} \int_{z_{n}} q\left(z_{n} \mid \beta\right) \log \frac{p\left(y_{n}, z_{n} \mid \beta\right)}{q\left(z_{n} \mid \beta\right)} d z_{n}\right) d \beta\)를 maximize하는 알고리즘은 아래와 같다.
매 iteration마다…
step 1) \(\beta \sim q(\beta)\)
step 2) local ELBO를 maximize하도록 하는 \(\gamma_n(\beta)\)들을 각각 계산한다
-
local ELBO : \(\mathbb{E}_{q}\left[\log p\left(y_{n}, z_{n} \mid \beta\right)-\log q\left(z_{n} \mid \beta\right) \mid \beta\right]\)
( 혹은 $$-\operatorname{KL}\left(q_{z_{n} \mid \beta} p_{z_{n} \mid y_{n}, \beta}\right)+\log p\left(y_{n} \mid \beta\right)$$ )
step 3) unbiased estimate \(\hat{\eta}_{n}\) of \(\mathbb{E}_{q}\left[\eta_{n}\left(y_{n}, z\right) \mid \beta\right]\) 를 계산한다
step 4) \(\lambda\)를 (step size \(\rho\)) update한다.
\(\lambda^{\prime}=(1-\rho) \lambda+\rho\left(\eta+V(\beta, \lambda) \sum_{n} \hat{\eta}_{n}\right)\).
< 뒷 부분은 나중에 >