[Deep Bayes] Stochastic Variational Inference and Variational AutoEncoders

해당 내용은 https://deepbayes.ru/ (Deep Bayes 강의)를 듣고 정리한 내용이다.

1. Introduction

이번 포스트에서는 다음과 같이 크게 3가지에 대해 다룰 것이다.

  • 1 ) Probabilistic PCA
  • 2 ) VAE ( Variational Auto Encoder )
  • 3 ) Reparameterization Trick


2 . Probabilistic PCA

Probabilistic PCA가 무엇인지 알기 이전에, 우리는 Continuous latent variable에 대해 먼저 알아볼 것이다. 우선, continous variables들은 다음과 같이 여러 연속적인 분포의 mixture로 표현할 수 있다.

\[p(x_i \mid \theta) = \int p(x_i, z_i \mid \theta)d z_i = \int p(x_i \mid z_i \theta)p(z_i \mid \theta) dz_i\]


그리고 위 식을 활용하여, 우리는 다음과 같은 식을 통해 E-step을 수행했었다.

\[q(z_i) = p(z_i \mid x_i, \theta) = \frac{p(x_i \mid z_i, \theta)p(z_i \mid \theta)}{\int p(x_i \mid z_i, \theta)p(z_i \mid \theta) dz_i}\]


위 식에서의 continuous latent variable은 주로 차원 축소 (dimension reduction) 에서 많이 사용된다. 그 중 대표적인 예가 바로 PCA이다.


(1) PCA

PCA는 가장 기본적인 차원축소 기법으로, 많이들 알고 있을 것이다. 하지만 이러한 PCA도 우리가 이전에 배운 EM algorithm 방법으로 해결할 수 있다.


https://miro.medium.com/max/462/1*QALwLkPJG45E7_DsOccdaw.png

우선, observed variable인 X와 continuous latent variable인 Z에 대한 joint distribution은 다음과 같이 나타낼 수 있다.

\(x \in R^D, z \in R^d\) such that D » D

\[p(X,Z \mid \theta) = \prod_{i=1}^{n}p(x_i \mid z_i, \theta)p(z_i \mid \theta) = \prod_{i=1}^{n}N(x_i \mid V_{z_i} + \mu, \sigma^2 I)N(z_i \mid 0, I)\]

위 식에서 \(\theta\)는 (D x d matrix)인 \(V\)와 , D-차원 벡터 \(\mu\)와 스칼라 \(\sigma\)로 구성된다.

\(\underset{\theta}{argmax}\;\;p(X_{tr}\mid \theta)\)를 찾기 위해 어떻게 해야 할까?

위 식에서 나온 두 부분이 모두 Normal Distribution이기 때문에 conjugate를 사용하여 쉽게 풀 수 있지만, 우리는 여기서 를 찾기 위해 EM알고리즘을 사용할 수 있다.


(2) Advantages of EM PCA

  • 1 ) 시간복잡도가 \(O(nD^2)\)에서 \(O(nDd)\)로 줄어든다
  • 2 ) X 데이터의 missing part를 생성할 수 있다.
  • 3) mixture of PCA등과 같이 보다 일반적인 모델(generalized model)로 확장할 수 있다.


(3) Mixture of PCA

https://www.researchgate.net/profile/Joydeep_Ghosh6/publication/3193212/


기존의 PCA

(continuous latent variable Z 하나만을 사용한) 기존의 PCA의 Joint Distribution은 다음과 같이 표현할 수 있었다.

\[p(X,Z \mid \theta) = \prod_{i=1}^{n}p(x_i \mid z_i, \theta)p(z_i \mid \theta) = \prod_{i=1}^{n}N(x_i \mid V_{z_i} + \mu, \sigma^2 I)N(z_i \mid 0, I)\]


하지만, 위 처럼 단 하나의 PCA를 사용하는 것 대신, 여러 개의 PCA를 서로 섞어서 사용할 수 있다. 이는, \(Z\)와 함께 discrete latent variable \(T\)를 사용함으로써 표현할 수 있다. 그 식은 다음과 같다.


Mixture of PCA

\[p(X,Z,T \mid \theta) = \prod_{i=1}^{n}p(x_i \mid t_i,z_i, \theta)p(z_i \mid \theta)p(t_i \mid \theta) = \prod_{i=1}^{n}N(x_i \mid V_{t_i}z_i + \mu_{t_i}, \sigma_{t_i}^2 I)N(z_i \mid 0, I)\pi_t\]

위 식에서 \(\theta\)는 (D x d matrix)인 \(V_k\) , D-차원 벡터 \(\mu_k\), 스칼라 \(\sigma_k\)와 \(p(t_i=k)=\pi_k\)인 확률 \(\pi\)로 구성된다. 기존의 PCA와의 다른 좋은 점은, 이것은 non-linear하게 차원을 축소할 수 있다는 점이다. (위의 그림을 참고)


(4) EM Algorithm for mixture of PCA

위의 mixture of PCA에도 EM Algorithm을 적용하여 풀 수 있다.

E-Step과 M-Step에 대해서 알아보자.


E-step

\[\begin{align*} q(Z,T) & = \prod_{i=1}^{n}q(z_i, t_i) \\ &= \prod_{i=1}^{n}p(z_i,t_i \mid x_i, \theta)\\ &= \prod_{i=1}^{n} \frac{N(x_i \mid V_{t_i}z_i + \mu_{t_i}, \sigma_{t_i}^2 I)N(z_i \mid 0, I)\pi_t}{\sum_{k=1}^{K}\int N(x_i \mid V_{k}z_i + \mu_{k}, \sigma_{k}^2 I)N(z_i \mid 0, I)\pi_k dz_i} \\\end{align*}\]


M-step

\[E_{Z,T}logp(X,Z,T \mid \theta) = \sum_{i=1}^{n}E_{z_i,t_i}(logp(x_i \mid t_i, z_i, \theta) + logp(z_i \mid \theta) + logp(t_i \mid \theta)) \rightarrow \underset{\theta}{max}\]


3. VAE ( Variational Auto Encoder ) (1)

VAE(Variational Encoder)는 생성 모델 (Generative Model)의 대표적인 방법 중 하나이다. 즉, VAE는 확률 분포에 대해 학습을 한 뒤, 이로 부터 새로운 데이터를 생성해내는 방법이다. VAE는 크게 두 개의 구조( Encoder & Decoder) 로 이루어져있다. Encoder는 데이터 X를 입력받아, latent variable Z를 출력한다. Decoder은 이러한 Encoder의 출력인 Z를 받은 뒤, 이를 통해 다양한 이미지를 생성해낸다. 지금까지 probabilistic PCA를 이야기한 이유도, 바로 이 VAE를 설명하기 위해서이다. VAE는 probabilistic PCA의 일반화 버전이라고 생각하면 된다.

https://datascienceschool.net/upfiles/f38b90fa89cb46eba22178edbae07a26.png


Generator

Generator은 factorized된 Gaussian Distribution을 출력하는데, 이 분포의 parameter(\(\mu\) 와 \(\sigma\) )들은 latent variable들의 non-linear function 형태이다. ( Deep Neural Network에 의해 학습됨 ) 이를 통해, 매우 복잡한 고차원의 분포를 단순한 저차원의 분포로 변환할 수 있다.


PCA의 식은 다음과 같았다.

\[p(X,Z \mid \theta) = \prod_{i=1}^{n}p(x_i \mid z_i, \theta)p(z_i \mid \theta) = \prod_{i=1}^{n}N(x_i \mid V_{z_i} + \mu, \sigma^2 I)N(z_i \mid 0, I)\]


여기서 Z에 대해 적분하기는 쉽지 않다 ( intractable )

\[\int p(X,Z\mid \theta)dZ = \prod_{i=1}^{n}\int p(x_i \mid z_i, \theta)p(z_i)dz_i\]


그래서 우리는 이것을 Variational Inference를 통해 optimization문제로 풀 것이다.


Variational Inference 복습

\[q(z_i \mid x_i, \phi) \approx p(z_i \mid x_i, \theta)\]


factorized Gaussian 형태)

\[q(z_i \mid x_i, \phi) = \prod_{j=1}^{d}N(z_{ij}\mid \mu_j(x_i),\sigma_j^{2}(x_i))\]


Variational Inference를 통해 이 문제를 푸려면, 우리는 \(q(Z \mid X, \phi)\) 를 최대화 해야 하고, 이는 결국 \(KL(q(Z \mid X,\phi) \mid \mid p(Z\mid X, \theta))\) 를 최소화해야한다. 이는 결국, ELBO인 다음 식을, \(\phi\)와 \(\theta\)에 대해서 최대화 하는 것이다.

\[L(\phi, \theta) = \int q(Z \mid X, \phi)log \frac{p(X\mid Z, \theta)p(Z)}{q(Z\mid X,\phi)}dZ\]


위의 ELBO를 최대화 하는 것이 우리의 목적이라는 것은, 다르게 말하면 Loss Function으로 대변될 수 있는 아래의 식 (ELBO에 (-)를 붙인 것)을 최소화 하는 것과 같다.

\[\begin{align*} Loss(\phi, \theta) &= - \int q(Z \mid X, \phi)log \frac{p(X\mid Z, \theta)p(Z)}{q(Z\mid X,\phi)}dZ\\ &= - \int q(Z \mid X, \phi)log p(X\mid Z, \theta)dZ + KL(q(Z\mid X, \phi) \mid \mid p(Z)) \\ &= - \int q(Z \mid X, \phi)log p(X\mid Z, \theta)dZ - H(q(Z \mid X, \phi))+ H(q(Z\mid X, \phi),p(Z)) \end{align*}\\\]

위 식의 우변에 생성된 세 부분은, 각각 다음을 의미한다.
( https://ratsgo.github.io/generative%20model/2018/01/27/VAE/ 참고 )

  • PART 1 ) \(- \int q(Z \mid X, \phi)log p(X\mid Z, \theta)dZ\)
    • Reconstruction Error
    • \(X\) 에 대한 복원 오차


  • Part 2 ) \(- H(q(Z \mid X, \phi))\)
    • Posterior Entropy
    • Posterior에서 샘플링 된 Z는 최대한 다양해야


  • Part 3) \(H(q(Z\mid X, \phi),p(Z))\)
    • Cross Entropy
    • Posterior & Prior의 정보량은 유사해야


위의 ELBO 식에서, \(\phi\)에 대해서 maximize하는 것이 E-step에 속하고, \(\theta\)에 대해서 maximize하는 것은 M-step에 속한다. 이에 대해서 보다 자세히 알아보겠다.


4. VAE ( Variatonal Auto Encoder ) (2)

1) Stochastic Optimization

\[L(\phi, \theta) = \int q(Z \mid X, \phi)log \frac{p(X\mid Z, \theta)p(Z)}{q(Z\mid X,\phi)}dZ\]

위 식을 optimize하는 것은 여전히 어려운 문제이다.

첫 째, training data가 매우 커서 iteration이 매우 expensive할 수 있다.

둘 째, 위 식이 여전히 다루기 어렵다 (intractable).

따라서, 우리는 이를 mini-batch와 Monte-Carlo Estimation문제를 사용하여 풀 것이다.


2) Optimization w.r.t \(\theta\)

\(\theta\)에 대해서 최적화 하는 것은 어렵지 않다. 우선, 다음과 같이 mini-batch를 뽑은 뒤 Monte-carlo Estimation을 할 수 있다.

\[\begin{align*} \bigtriangledown _\theta L(\phi, \theta) &= \bigtriangledown _\theta \sum_{i=1}^{n}\int q(z_i \mid x_i, \phi)log\frac{p(x_i \mid z_i, \theta)p(z_i)}{q(z_i \mid x_i, \phi)}dz_i \\ &= \sum_{i=1}^{n}\int q(z_i \mid x_i, \phi)\bigtriangledown _\theta log\frac{p(x_i \mid z_i, \theta)p(z_i)}{q(z_i \mid x_i, \phi)}dz_i \\ &\approx n\int q(z_i \mid x_i, \phi)\bigtriangledown _\theta log\frac{p(x_i \mid z_i, \theta)p(z_i)}{q(z_i \mid x_i, \phi)}dz_i \\ &\approx n \bigtriangledown _\theta log p(x_i \mid z^{*}_i, \theta) \\ \end{align*}\]

( 위 식에서 \(z^{*}_i \sim q(z_i \mid x_i, \phi)\) 이다 )


3) Optimization w.r.t \(\phi\)

\(\phi\)의 경우에는 약간 다르다. 그 이유는, density function 그 자체가 , 우리가 미분하고자 하는 \(\phi\)와 관련이 있기 때문이다. 그래서 위의 2) 처럼 gradient를 \(\int\) 안에 집어 넣을 수가 없다.

\[\bigtriangledown_x \int p(y \mid x)h(x,y)dy \neq \int p(y\mid x) \bigtriangledown_x h(x,y)dy\]


[추가 1] Log-derivative Trick

하지만, 우리는 log-derivative trick을 사용하여 위 문제를 해결할 수 있다.

\[\frac{\partial}{\partial x} \int p(y \mid x)h(x,y)dy = \int p(y\mid x)\frac{\partial}{\partial x}h(x,y)dy + \int p(y\mid x)h(x,y)\frac{\partial}{\partial x} logp(y\mid x)dy\]


위 식의 우변에 등장하는 \(\int p(y\mid x)\)는 Monte Carlo Estimation을 통해 풀 수 있다.

그러면, 위 식은 다음과 같이 근사할 수 있다.

\[\frac{\partial}{\partial x} \int p(y \mid x)h(x,y)dy = \frac{\partial}{\partial x}h(x,y_0) + h(x,y_0)\frac{\partial}{\partial x} logp(y_0\mid x)\]


[추가 2] Reparameterization Trick

다음과 같은 복잡한 expectation값을 미분하는 경우를 생각해보자.

\[\frac{\partial}{\partial x} \int p(y \mid x)h(x,y)dy\]


우리는 여기서 다음과 같은 \(g(\cdot)\) 함수를 도입하여 다음과 같이 표현할 수 있다.

\[\int p(y\mid x)h(x,y)dy = \int r(\epsilon)h(x,g(\epsilon,x))d\epsilon\]


이렇게 바꿔서 표현할 경우, 미분하기는 더 쉽다.

\[\begin{align*} \frac{\partial}{\partial x} \int p(y\mid x)h(x,y)dy &= \frac{\partial}{\partial x} \int r(\epsilon)h(x,g(\epsilon,x))d\epsilon \\ &\approx \frac{d}{dx} h(x,g(x,\hat{\epsilon}))\\ &= \frac{\partial}{\partial x}h(x,g(x,\hat{\epsilon})) + \frac{\partial}{\partial g}h(x,g(x,\hat{\epsilon}))\frac{\partial}{\partial x}g(x,\hat{\epsilon}) \end{align*}\]

( 여기서 \(\hat{\epsilon} \sim r(\epsilon)\) 이다 )


위의 [추가 1]Log-derivative Trick[추가 2]Reparameterization Trick 을 적용하여 ELBO 식을 풀어보자. 우선, ELBO 식은 다음과 같다.

\[\begin{align*} L(\phi, \theta) &= \int q(Z \mid X, \phi)log \frac{p(X\mid Z, \theta)p(Z)}{q(Z\mid X,\phi)}dZ\\ &= \int q(Z \mid X, \phi)log p(X\mid Z, \theta)dZ - KL(q(Z\mid X, \phi) \mid \mid p(Z)) \end{align*}\]


위 식에서, KL-Divergence 부분은 \(\phi\)에 대해서 쉽게 미분하여 계산할 수 있다. 하지만 그 앞부분인 \(\int q(Z \mid X, \phi)log p(X\mid Z, \theta)dZ\) 은 그렇지 않다. 그래서 Log-Derivative TrickReparameterization Trick을 사용하여 쉽게 풀 것이다. 그 식은 다음과 같다.

\[\begin{align*} \frac{\partial}{\partial \phi}\int q(Z \mid X, \phi)log p(X\mid Z, \theta)dZ &\approx n \frac{\partial}{\partial \phi}\int q(z_i \mid x_i, \phi)log p(x_i \mid z_i, \theta)dZ \\ &= n \frac{\partial}{\partial \phi}\int r(\epsilon) log p(x_i \mid g(\epsilon, x_i, \phi)z_i, \theta)d \epsilon \\ &\approx n \frac{\partial}{\partial \phi}\int log p(x_i \mid g(\hat{\epsilon}, x_i, \phi)z_i, \theta)d \epsilon \\ \end{align*}\]

( 여기서 \(\hat{\epsilon} \sim r(\epsilon)\) 이다 )


5. VAE ( Variatonal Auto Encoder ) (3)

Summary

VAE의 알고리즘은 다음과 같이 정리할 수 있다.

  • Input : Training data \(X\) 와, latent space의 차원 \(d\)

  • 랜덤하게 \(i \sim U\{1,...N\}\) 에서 샘플링하고, ELBO의 stochastic gradient을 계산한 뒤…

    • 1 ) \(\theta\)에 대해 미분

      • \[stoch.grad_{\theta}L(\phi, \theta) = n \frac{\partial}{\partial \theta}log p(x_i \mid z^{*}_{i}, \theta)\]

        ( \(z^{*}_i \sim q(z_i \mid x_i, \phi)\) 이다 )

    • 2) \(\phi\)에 대해 미분

      • \[stoch.grad_{\phi}L(\phi, \theta) = n \frac{\partial}{\partial \phi}log p(x_i \mid g(\hat{\epsilon},x_i,\phi), \theta) - \frac{\partial}{\partial \phi} KL(q(z_i \mid x_i, \phi) \mid \mid p(z_i))\]

        ( \(\hat{\epsilon} \sim r(\hat{\epsilon})\) 이다)

    위 과정을 반복한다


Summary 2

다음의 한 장의 사진으로 위의 전체 프로세스를 요약할 수 있다.


https://i.stack.imgur.com/GgS2y.png