Neural Variational Inference and Learning in Belief Networks ( 2014 )
( 아래의 포스트를 읽기전에, DBN, RBM에 대한 내용에 대해 이해가 선행되어야 한다. 이에 대한 설명은 서울대학교 컴퓨터공학부 장병탁교수님의 자료를 참고하면, 쉽게 이해할 수 있을 것이다. download here : Download )
Abstract
이 논문은 variational posterior부터 efficient exact sampling을 하기 위해, feedforward network를 사용한 non-iterative approximate inference 방법론을 제안한다. 여기서 “모델”과 “inference network”는 ELBO를 최대화하는 방향으로 jointly trained된다. ( 모델과 inference network를 헷갈리면 안된다. 일반적으로, 모델은 \(p(X \mid Z)\), inference network는 \(p(Z \mid X)\) ! )
또한, inference network gradient의 variance가 커지는 문제를 해결하기 위한 variance reduction technique도 제안한다.
1. Introduction
DBN(Deep Belief Network), DBM(Deep Blotzmann)는 powerful 한 latent variable 모델로서 큰 데이터에 적용가능하긴 하지만, efficient한 learning을 하는 것은 아직까지 풀어야 하는 숙제이다. 이처럼, 복잡하면서도 (complex), 큰 데이터에 적용가능한 (scalable) 알고리즘을 만드는 것은 매우 중요한 포인트이다. 이를 충족시키기 위한 방법으로, MCMC나 Variational method 등이 제안되었다. 이들에 대해 간단히 요약하자면, 아래와 같다.
MCMC (sampling)
-
suffer from slow mixing ( too computationally expensive)
\(\rightarrow\) difficult to scale to large dataset
Variational method
- optimization- based
- tend to be more efficient than MCMC
- but, highly expressive models are intractable
- more model-dependent than MCMC
이 논문은, sampling기반의 방법과 variational method의 장점을 결합한 새로운 접근법을 시도한다. 핵심은, variational posterior로 부터 efficient exact sampling을 하기 위해 feedforward neural network를 사용한다는 점이다.
이 inference network는 모델과 함께 ELBO를 maximize하는 방향으로 train 된다. 뿐만 아니라, 이때의 gradient의 variance가 커지는 것을 막기 위해 general variance reduction 테크닉을 제안한다. 이는 RL에서 제안하는 REINFORCE 알고리즘과 매우 닮아있다.
이 알고리즘은 inference를 수행하기 위한 모델로서 stochastic feedforward network를 사용한다는 점에 있어서, 그 알고리즘의 이름은 NVIL (Neural Variational Inference and Learning) 이라고 부른다.
2. NVIL (Neural Variational Inference and Learning)
2-1. Variational Object
( latent variable를 \(z\)로 표현하는 경우가 많으나, RBM,DBN 등에서는 이를 hidden layer \(h\)로 표현하는 경우가 많다 )
우리는 latent variable model인 \(P_{\theta}(x,h)\)를 \(\theta\)에 대해서 학습시키는데에 관심이 있다. 하지만, 대부분의 경우 이는 intractable하고, 따라서 우리는 VI방법을 사용하여 ELBO를 최대화한다.
notation은 다음과 같이 정리하겠다.
-
training set \(\mathcal{D}\), consisting \(x_{1}, \ldots, x_{D},\)
- approximating distribution : \(Q_{\phi}(h \mid x)\)
- (intractable) exact posterior : \(P_{\theta}(h \mid x)\)
ELBO를 정리하면 아래와 같다.
\(\begin{aligned} \log P_{\theta}(x) &=\log \sum_{h} P_{\theta}(x, h) \\ & \geq \sum_{h} Q_{\phi}(h \mid x) \log \frac{P_{\theta}(x, h)}{Q_{\phi}(h \mid x)} \\ &=E_{Q}\left[\log P_{\theta}(x, h)-\log Q_{\phi}(h \mid x)\right] \\ &=\mathcal{L}(x, \theta, \phi) \end{aligned}\).
\(\mathcal{L}(x, \theta, \phi)=\log P_{\theta}(x)-K L\left(Q_{\phi}(h \mid x), P_{\theta}(h \mid x)\right)\).
\[\mathcal{L}(\mathcal{D}, \theta, \phi)=\sum_{i} \mathcal{L}\left(x_{i}, \theta, \phi\right)\]기존의 VI 방법과의 차이점은, “local variational parameter”를 사용하지 않는다는 점이다. 여기서 \(x\)를 approximating distribution \(Q_{\phi}(h \mid x)\)로 mapping하는 것을 inference network라고 부른다.
이 inference network에게는 아래의 두 가지 사항이 요구된다.
- efficient to evaluate
- efficient to sample from
이 network로부터 생성된 sample들은 이후에 gradient estimate를 계산하는데에 사용된다.
large dataset에 적용가능하기 위해, 기존의 방법과 마찬가지로 minibatch에 기반한 stochastic optimization을 사용한다.
2-2. Parameter gradients
위의 ELBO를, model parameter \(\theta\)와 inference network parameter \(\phi\)에 대해 미분한 값은 아래와 같다.
- \(\nabla_{\theta} \mathcal{L}(x)=E_{Q}\left[\nabla_{\theta} \log P_{\theta}(x, h)\right]\).
- \(\nabla_{\phi} \mathcal{L}(x)=E_{Q}[\left(\log P_{\theta}(x, h)-\log Q_{\phi}(h \mid x)\right)]\).
위 식을 MC integration을 사용해서 다시 정리하면, 아래와 같다.
( 이는 unbiased estimator이고, ELBO에 대한 stochastic maximization을 수행하기 위해 사용됨. )
- \(\nabla_{\theta} \mathcal{L}(x) \approx \frac{1}{n} \sum_{i=1}^{n} \nabla_{\theta} \log P_{\theta}\left(x, h^{(i)}\right)\).
- \(\nabla_{\phi} \mathcal{L}(x) \approx \frac{1}{n} \sum_{i=1}^{n}\left(\log P_{\theta}\left(x, h^{(i)}\right)-\log Q_{\phi}\left(h^{(i)} \mid x\right)\right) \times \nabla_{\phi} \log Q_{\phi}\left(h^{(i)} \mid x\right)\).
알다시피, convergence 속도는 estimator의 variance에 따라 크게 의존한다. 하지만, 위 식에서 \(\nabla_{\phi} \mathcal{L}(x)\)의 variance는 그 값이 매우 클수 있고, 이는 결국 속도 저하를 가져올 수 있다. 따라서, 2-3에서는 이 variance를 줄이기 위한 방안을 제안한다.
2-3. Variance Reduction techniques
variance를 줄이는 방법은 생각보다 간단하다.
( 뒤에서 말할 이 방법은 “model-independent” technique이라 general하게 적용가능하다! )
2-3-1. Centering the Learning Signal
ELBO를 \(\phi\)에 대해서 미분한 식을 다시한번 적어보자.
\(\nabla_{\phi} \mathcal{L}(x) \approx \frac{1}{n} \sum_{i=1}^{n}\left(\log P_{\theta}\left(x, h^{(i)}\right)-\log Q_{\phi}\left(h^{(i)} \mid x\right)\right) \times \nabla_{\phi} \log Q_{\phi}\left(h^{(i)} \mid x\right)\).
여기서, \(l_{\phi}(x, h)=\log P_{\theta}(x, h)-\log Q_{\phi}(h \mid x)\). 를 우리는 “learning signal“이라고 부른다. ( 이 learning signal (=\(l_{\phi}(x, h)\) )의 정도에 따라 update되는 정도가 다름을, 식을 통해 직관적으로 알 수 있다. )
조금 이상한 점이 있지 않는가? 우리는 \(Q_{\phi}(h \mid x)\)가 \(P_{\theta}(h \mid x)\)를 근사하게끔 만들고 싶다. 하지만 위의 learning signal을 보면, \(P_{\theta}(h \mid x)\)가 아니라 \(P_{\theta}(x, h)\)가 사용됨을 알 수 있다. 하지만, 이는 문제가 되지 않는다. 그 이유는, 이제 밑에서 설명할 것이다.
우리는 learning signal에, \(h\)에 depend하지 않는 특정 term \(c\)를 빼도, 그 expectation값에는 변화가 없음을 알 수 있다.
\(\begin{aligned} &E_{Q}\left[\left(l_{\phi}(x, h)-c\right) \nabla_{\phi} \log Q_{\phi}(h \mid x)\right] \\ &=E_{Q}\left[l_{\phi}(x, h) \nabla_{\phi} \log Q_{\phi}(h \mid x)\right]-c E_{Q}\left[\nabla_{\phi} \log Q_{\phi}(h \mid x)\right]\\ &=E_{Q}\left[l_{\phi}(x, h) \nabla_{\phi} \log Q_{\phi}(h \mid x)\right] \end{aligned}\)……….. 식 (a)
( \(\because E_{Q}\left[\nabla_{\phi} \log Q_{\phi}(h \mid x)\right]=E_{Q}\left[\frac{\nabla_{\phi} Q_{\phi}(h \mid x)}{Q_{\phi}(h \mid x)}\right]=\nabla_{\phi} E_{Q}[1]=0\) )
그리고, \(\log P_{\theta}(x, h)=\log P_{\theta}(h \mid x)+\log P_{\theta}(x)\)로 decompose되는 것을 알 수 있다. 하지만 여기서 두번째 term인 \(\log P_{\theta}(x)\)는 \(h\)에 depend하지 않는 term이다. 따라서 이 값을 더하거나/빼주어도 그 expectation값에 차이가 없다. 이 덕분에, 우리는 intractable한 \(P_{\theta}(h \mid x)\) term을 계산할 필요가 없어진다!
하지만 우리는 tractability를 얻는 대신, variance가 높아지는 문제점이 발생한다.하지만 다행히도, 우리는 식 (a)에서 \(c\)를 잘 설정함으로써 variance를 줄일 수 있다. 가장 간단한 방법은, \(c\)를 하나의 parameter로써 취급하고 이를 학습하는 것이다.
하지만 우리에게 남아있는 \(\log P_{\theta}(x)\) term때문에, \(c\)는 observation \(x\)에 따른 learning signal의 차이를 잘 capture해내기 어려울 것이다. 따라서, 우리는 추가적으로 observation-dependent term인 \(C_{\psi}(x)\)를 뺴주면 이 문제를 해결할 수 있다.
variance를 줄이기 위해 도입한 두 개의 term, \(c\)와 \(C_{\psi}(x)\)를 우리는 “baseline”이라고 부른다. ( baseline은 RL(Reinforcement Learning)에서 등장하는 개념으로, 이를 공부했다면 이와 같이 명명한 이유를 이해할 수 있을 것이다 )
최종적으로, 우리는 이 centered learning signal인 \(E_{Q}\left[\left(l_{\phi}(x, h)-C_{\psi}(x)-c\right)^{2}\right]\)가 최소화 되도록 학습시킨다.
2-3-2. Variance Normalization
위의 2-3-1을 통해서, 우리는 learning signal을 centering했다. 하지만, 이것의 average magnitude가 큰 폭으로 변동할 수 있기 때문에, 이를 normalization해줄 필요가 있다. 이 Normalization를 통해 signal은 approximately unit variance를 가지고, 안정적인 학습을 가능하게 한다.
( 여기서, variance normalization은, estimate의 standard deviation이 1보다 큰 경우에 한해서 적용한다 )
2-3-3. Local Learning Signals
우리는 지금까지 model이나 inference network의 structure에 대해서 아무런 가정을 하지 않았다. 하지만, conditional independence property를 통해서, 우리는 inference network를 더 simple하고 less noisy한 local learning signal을 사용할 수 있다 ( global learning signal을 사용하는 것 보다 )
model과 inference network를 아래와 같이 factorize할 수 있다.
- \(P_{\theta}(x, h)=P_{\theta}\left(x \mid h^{1}\right) \prod_{i=1}^{n-1} P_{\theta}\left(h^{i} \mid h^{i+1}\right) P_{\theta}\left(h^{n}\right)\).
-
\(Q_{\phi}(h \mid x)=Q_{\phi^{1}}\left(h^{1} \mid x\right) \prod_{i=1}^{n-1} Q_{\phi^{i+1}}\left(h^{i+1} \mid h^{i}\right)\).
- notation
- \(h^{i}\) : latent variables in the \(i^{\text{th}}\) layer
- \(\phi^{i}\) : parameters of the variational distribution in the \(i^{\text{th}}\) layer
- \(h^{i:j}\) : latent variables in layers \(i\) through \(j\)
layer \(i\)에서의 variational distribution의 parameter를 학습하기 위해, 아래의 gradient를 계산해야 한다.
( 이를 Law of Iterated Expectation을 사용하여 추가로 정리할 수 있다.)
\(\nabla_{\phi_{i}} \mathcal{L}(x)=E_{Q(h \mid x)}\left[l_{\phi}(x, h) \nabla_{\phi_{i}} \log Q_{\phi_{i}}\left(h^{i} \mid h^{i-1}\right)\right]\).
\(\begin{array}{l} \nabla_{\phi_{i}} \mathcal{L}(x)=E_{Q\left(h^{1: t-1} \mid x\right)}[\left.\left.E_{Q\left(h^{\imath: n} \mid h^{i-1}\right)}\left[l_{\phi}(x, h) \nabla_{\phi_{i}} \log Q_{\phi_{i}}\left(h^{i} \mid h^{i-1}\right)\right] \mid h^{i-1}\right]\right] \end{array}\).
따라서, layer \(i\)에서의 local learning signal은 아래와 같다
\(l_{\phi}^{i}(x, h)=\log P_{\theta}\left(h^{i-1: n}\right)-\log Q_{\phi}\left(h^{i: n} \mid h^{i-1}\right)\).
( first hidden layer에 대한 signal은, 위의 \(h^0\)대신에 \(x\)를 넣으면 된다. )