[Deep Bayes] 02. Variational Inference

해당 내용은 https://deepbayes.ru/ (Deep Bayes 강의)를 듣고 정리한 내용이다.

1. Introduction

이번 포스트에서는, 다음과 같이 세 가지에 대해 다룰 것이다.

(1) Variational Inference란?

(2) Variational lower bound derivation

(3) Variational mean field approximation


2. Variational Inference

변분 추론이라 불리는 Variational Inference에 대해 많이들 들어 봤을 것이다. 이는 데이터 x가 주어졌을 때 모수 \(\theta\)를 추정 ( = \(p(\theta \mid x)\) )하는 방법 중 하나이다.

\(p(\theta \mid x)\)를 추정하는 방법에는 크게 두 가지가 있다.

  • 1 ) Variational Inference ( Optimization 방법으로 문제 해결 )
  • 2 ) MCMC ( Sampling 방법으로 문제 해결 )


[ Variational Inference ]

우선 Variational Inference의 핵심 아이디어는 \(p(\theta \mid x)\)에 근사하는 \(q(\theta)\)를 찾는 것이다.

특징

  • Biased
  • Faster & More Scalable


이에 대비한 MCMC의 특징은 다음과 같다.

[ MCMC (Markov Chain Monte Carlo) ]

unnormalized된 \(p(\theta \mid x)\)로부터 sampling하여 \(\theta\)를 추정

  • Unbiased
  • Need a lot of samples


Solving Problem using Variational Inference

우리는 이 두 방법 중, Variational Inference에 따라 모수를 추정할 것이다.

이 방법의 핵심은 아까 말했다 시피, \(p(\theta \mid x)\)에 근사하는 \(q(\theta)\)를 추정하는 것에 있는데, 다음과 같은 criterion function을 사용하여 문제를 푼다.

\[F(q) := KL(q(\theta) \mid \mid p(\theta \mid x)) \rightarrow \underset{q(\theta)\in Q}{min}\]

KL Divergence에 대해서는 예전의 포스트를 참고하길 바란다. (https://seunghan96.github.io/stat/13.-vi-Variational_Inference_Intro(1)/) 그래도 간단히 설명하자면, KL Divergence는 두 분포 사이의 거리를 측정하는 지표이다. (정확히 말하면 거리(distance)의 조건인 ‘symmetry’를 충족하지는 못한다. 하지만 두 분포가 얼마나 차이나는지를 보여주는 지표이므로 거리라고 표현하겠다 )

그러면 우리는 \(log\;p(x)\)를 다음과 같이 표현할 수 있다.

\[log\; p(x) = L(q(\theta)) + KL(q(\theta) \mid \mid p(\theta \mid x))\]


위 식의 우변의 왼쪽 부분인 \(L(q(\theta))\)를 ELBO (Evidence Lower Bound)라고 부르고, 오른 쪽 부분인 \(KL(q(\theta) \mid\mid p(\theta \mid x))\)가 우리가 우리가 최소화해야하는 KL-Divergence이다.

위 식을 보면, 좌변인 \(log\; p(x)\)는 \(q\)에 의존하지 않는 반면, 우변의 두 부분은 모두 \(q\)에 대한 함수이다. 따라서, 우리가 풀어야 하는 KL-Divergence 최소화는 결국 ELBO를 최대화하는 것과 같다.


ELBO인 \(L(q(\theta))\) 는 다음과 같이 표현할 수 있다 ( 이 부분에 대한 증명도 위의 포스트 링크를 참고하길 바란다 )

\[L(q( \theta) ) = \int q(\theta)log\frac{p(x,\theta)}{q(\theta)} = E_{q(\theta)}log\;p(x\mid \theta) - KL(q(\theta) \mid \mid p(\theta))\]


따라서 결국 우리가 풀어야하는 문제는 위의 \(L(q( \theta) ) = \int q(\theta)log\frac{p(x,\theta)}{q(\theta)}\)를 최대화 하는 것이고,

여기서 우리는 이제 \(q\)에 대한 가정을 할 수 있다. 어떠한 가정을 하냐에 따라 풀 수 있는 방법이 또 나뉜다. ( Mean Field ApproximationParametric Approximation)


3. Mean Field Approximation

개념은 간단하다. \(q(\theta)\)를 여러 distribution으로 factorize하여 다음과 같이 표현하는 것을 의미한다.

\[q(\theta) = \prod_{j=1}^{m}q_j(\theta_j)\]


이 식에는 \(\theta_1,...\theta_m\)이 서로 독립(independent)이라는 가정이 필요하다. 이 가정을 적용하여 위 2번의 ELBO(Evidence Lower Bound)를 정리하면 다음과 같이 표현할 수 있다.

\[\begin{align*} L(q(\theta)) &= E_{q{(\theta)}} \;logp(x,\theta)- E_{q(\theta)}logq(\theta)\\ &=E_{q{(\theta)}} \;logp(x,\theta)- \sum_{k=1}^{m}E_{q_k(\theta_k)}logq_k(\theta_k)\\ &=E_{q_j{(\theta_j)}} [E_{q_{i\neq j}}logp(x,\theta)]- E_{q_j{(\theta_j)}} logq_j(\theta_j) + Const\\ &=\{r_j(\theta_j) = \frac{1}{Z_j}exp(E_{q_{i\neq j}}logp(x,\theta))\} \\ &=E_{q_j(\theta_j)}log\frac{r_j(\theta_j)}{q_j(\theta_j)} + Const \\ &=-KL(q_j(\theta_j)\mid \mid r_j(\theta_j)) + Const \\ \end{align*}\]


따라서, Mean Field Approximation을 적용했을 때 우리는 ELBO를 최대화 하는 것을 다음과 같은 KL-Divergence “ \(KL(q_j(\theta_j)\mid \mid r_j(\theta_j))\)를 최소화” 하는 문제로 바꿔서 풀 수 있다.


따라서 위의 KL-Divergence를 최소화하기 위해, 우리는 \(q_j(\theta_j)\)를 다음과 같이 설정하면 된다.

\[q_j(\theta_j) = r_j(\theta_j) = \frac{1}{Z_j} exp(E_{q_{i\neq j}}logp(x,\theta))\]


4. Mean Field Variational Inference

위에서 얻어낸 updating equation을 적용하여, Mean Field Assumption을 적용한 Variational Inference ( Mean Field Variational Inference )의 알고리즘을 정리하면 다음과 같다.

Initialize \(q(\theta) = \prod_{j=1}^{m}q_j(\theta_j)\)

Iteration :

  • Update each factor \(q_1 ... q_m\) :

    ​ \(q_j(\theta_j) = \frac{1}{Z_j} exp(E_{q_{i\neq j}}logp(x,\theta))\)

  • Compute ELBO \(L(q(\theta))\)

    Repeat until convergence of ELBO