Gumbel-Softmax Trick
1. Reparameterization Trick
세 줄 요약 :
-
stochastic term을, 두 개의 부분 (1) stochastic & (2) deterministic으로 나누는 과정
-
나누는 이유? Back Propagation을 하기 위해
( stochastic한 부분에 대해서는 back-propagation을 할 수 없다. 따라서 deterministic한 부분을 만들어주고, 이에 대해 back-prop을 실시한다. )
-
ex) \(x \sim N(\mu_{\phi}, \sigma^2_{\phi})\)
- stochastic 부분 : \(\epsilon \sim N(0,1)\)
- 따라서, \(x = \mu_{\phi} + \sigma^2_{\phi}\cdot \epsilon\)
Categorical Variable을 reparameterization하는 대표적인 두 가지 방법으로
- (1) Gumbel-Max Trick 과
- (2) Gumbel-softmax Trick이 있다.
2. Gumbel-Max Trick
다음과 같이 Categorical 분포를 따르는 \(z\)가 있다고 해보자.
\[z \sim \text{Categorical}(\pi_1,...,\pi_k)\]이때, 우리는 \(z\)를 다음과 같은 방법으로 샘플링할 수 있다.
\(z = \underset{k}{\text{argmin}} \frac{\xi_k}{\pi_k}\), where \(\xi_k \sim \text{Exp}(1)\)
( 혹은, \(z = \underset{k}{\text{argmax}} \frac{\pi_k}{\xi_k}\))
( \(\pi_k\)가 클 수록, 즉 높은 확률일 수록 더 샘플링 될 확률이 높아지는 꼴이다.)
위 식에 log를 씌우면, 다음과 같이 정리될 수 있다.
\[z = \underset{k}{\text{argmax}} [log\pi_k - log\xi_k]\]( 여기서 \(-log\xi_k\)가 Gubmel(0,1) 분포를 따르기 때문에 해당 방법의 이름이 Gumbel-Max trick이다 )
우리는 위 방법을 통해 stochastic 부분과 deterministic 부분으로 바꾸는 reparameterization을 했다. 하지만, 이는 여전히 argmin/argmax의 특성 상 해당 경계를 제외하고 \(\pi\)에 대한 미분이 불가능(0이 된다)하다.
따라서 우리는 이를 continous하게 만들어 줄 필요가 있고, 그래서 등장한 것이 Gumbel-Softmax Trick이다.
3. Gumbel-Softmax Trick
생각보다 간단하다. Gumbel-max Trick에서 argmax를 softmax로만 바꿔주면 된다.
우선 softmax함수에 대한 식을 정리하면 다음과 같다.
\[\text{softmax}_{\tau}(x)_j=\frac{exp(x_j / \tau)}{\sum_{k=1}^{K}exp(x_k / \tau)}\]여기서 \(\tau\)는 위 softmax의 ‘sharpness’를 결정한다
- \(\tau=0 \rightarrow\) 결국 softmax도 argmax랑 같다.
- \(\tau = \infty \rightarrow\) uniform distribution이 된다
지금까지 \(z\)는 one-hot vector로, discrete했다. 우리는 이를 다음과 같이 softmax를 사용하여 continuous하게 변형할 것이고, 이를 \(\tilde{z}\) 로 나타낼 것이다.
\[\tilde{z}(\gamma, \pi) := \text{softmax}_{\tau}(log\pi_1 + \gamma_1, .... , log\pi_k + \gamma_k)\]-
\[\gamma_k \sim \text{Gumbel}(0,1)\]
- \[\gamma_k = -log(-logu_k)\]
- \[u_k \sim \text{Uniform}(0,1)\]
\(\tau\)값을 어떻게 설정하냐에 따라 sample 및 expectation은 다음과 같이 변한다. ( 여기서 \(\tau\)를 temperature라고 부른다 )
SUMMARY
-
expectation : \(\underset{q_{\phi}(z)}{E}f(z) \approx\underset{q_{\phi}(\tilde{z})}{E}f(\tilde{z}) = \underset{p(\gamma)}{E}f(\tilde{z}(\gamma, \phi))\)
-
gradient : \(\frac{\partial}{\partial \phi}f(\tilde{z}(\gamma, \pi(\phi)))\)