Uncertainty-Aware Attention for Reliable Interpretation and Prediction
Contents
- Abstract
- Introduction
- Approach
- Stochastic Attention with input-adaptive Gaussian Noise
- Variational Inference
0. Abstract
The attention mechanism is a powerful tool that lets a model focus on the relevant features, but...
\(\rightarrow\) it may be UNRELIABLE
To overcome this, the paper introduces the notion of input-dependent uncertainty
( = generates attention for each feature with varying degrees of noise based on the given input )
\(\rightarrow\) learns larger variance on instances with high uncertainty
proposes the UA ( Uncertainty-aware Attention ) mechanism, trained with VI
1. Introduction
Background
- Achieving high reliability is crucial, especially in safety-critical settings!
- Attention : finds the most relevant features for each input instance
( + allows easy interpretation )
- Limitation of attention : it can be unreliable
\(\rightarrow\) we need a model that knows its own limitations!
( = the model itself should know whether it is safe enough to make a prediction/decision )
Proposal
allow the attention to output uncertainty on each input
- Furthermore, leverage this uncertainty when making the final prediction
- Specifically, model the attention weights as Gaussians with input-dependent noise
Contributions
- 1) proposes a variational attention model
- 2) shows that the UA mechanism yields accurate calibration of model uncertainty
- 3) validates it on 6 real-world datasets
2. Approach
STOCHASTIC attention ( not first proposed in this paper )
- \(\mathbf{v}(\mathbf{x}) \in \mathbb{R}^{r \times i}\) : concatenation of \(i\) intermediate features
- each column \(\mathbf{v}_{j}(\mathbf{x})\) is a length-\(r\) vector
- random variables \(\left\{\mathbf{a}_{j}\right\}_{j=1}^{i}\) are then generated conditionally from this \(\mathbf{v}(\mathbf{x})\)
- \(\mathbf{c}(\mathbf{x})=\sum_{j=1}^{i} \mathbf{a}_{j} \odot \mathbf{v}_{j}(\mathbf{x})\) : context vector ( \(\mathbf{c} \in \mathbb{R}^{r}\) )
- \(\hat{\mathbf{y}}=f(\mathbf{c}(\mathbf{x}))\) : final output
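To make the notation concrete, a minimal sketch of the context-vector computation above, assuming the attention weights \(\mathbf{a}_j\) are scalars that have already been produced (all dimensions are illustrative):

```python
import torch

r, i = 16, 5                       # feature length r, number of intermediate features i
v = torch.randn(r, i)              # v(x): i feature columns, each of length r
a = torch.rand(i)                  # {a_j}: one (scalar) attention weight per column

# c(x) = sum_j a_j * v_j(x)  ->  context vector in R^r
c = (v * a.unsqueeze(0)).sum(dim=1)
print(c.shape)                     # torch.Size([16]); y_hat = f(c) would follow
```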
Attention can be deterministic or stochastic
- ex) stochastic attention : \(\mathbf{a}_j\) is drawn from a Bernoulli distribution
\(\rightarrow\) trained by maximizing the ELBO
- stochastic attention outperforms its deterministic counterpart on an image annotation task
2-1. Stochastic Attention with input-adaptive Gaussian Noise
Two limitations of the stochastic attention above
( both stem from drawing the attention directly from a Bernoulli/Multinomial distribution )
- [ Limitation 1 ] the variance of a Bernoulli is DEPENDENT on the allocation probability \(\mu\) ( see the numerical check below )
- \(\mathbf{a} \sim \text{Bernoulli}(\mu)\).
- allocation probability \(\mu\) : whether or not to attend to a feature
- variance of the Bernoulli : \(\sigma^{2}=\mu(1-\mu)\)
- therefore, the variance of \(\mathbf{a}\) cannot be low when \(\mu\) is near 0.5
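A quick numerical check of this coupling (just illustrating \(\mu(1-\mu)\), nothing paper-specific):

```python
# Bernoulli variance is tied to the allocation probability mu: Var[a] = mu * (1 - mu)
for mu in [0.1, 0.3, 0.5, 0.7, 0.9]:
    print(f"mu = {mu:.1f} -> variance = {mu * (1 - mu):.4f}")
# the variance peaks at 0.25 when mu = 0.5, so an ambivalent attention weight
# (mu near 0.5) can never have low variance under a Bernoulli
```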
- [ Solution 1 ]
disentangle the attention strength \(\mathbf{a}\) from the attention uncertainty
\(\rightarrow\) so that the uncertainty can vary even with the same attention strength
- [ Limitation 2 ] vanilla stochastic attention models the noise independently of the input
- ( i.e., in the old formulation the noise \(\boldsymbol{\sigma}\) does not depend on the input )
- [ Solution 2 ] to overcome both limitations above,
model the noise as a function of the input ( \(\boldsymbol{\sigma}(\mathbf{x})\) )
- (old) \(p(\boldsymbol{\omega})=\mathcal{N}\left(\mathbf{0}, \tau^{-1} \mathbf{I}\right), \quad p_{\theta}(\mathbf{z} \mid \mathbf{x}, \boldsymbol{\omega})=\mathcal{N}\left(\boldsymbol{\mu}(\mathbf{x}, \boldsymbol{\omega} ; \theta), \operatorname{diag}\left(\boldsymbol{\sigma}^{2}\right)\right)\)
- (new) \(p(\boldsymbol{\omega})=\mathcal{N}\left(\mathbf{0}, \tau^{-1} \mathbf{I}\right), \quad p_{\theta}(\mathbf{z} \mid \mathbf{x}, \boldsymbol{\omega})=\mathcal{N}\left(\boldsymbol{\mu}(\mathbf{x}, \boldsymbol{\omega} ; \theta), \operatorname{diag}\left(\boldsymbol{\sigma}^{2}(\mathbf{x}, \boldsymbol{\omega} ; \theta)\right)\right)\)
( in both equations, \(\mathbf{z}\) is the attention score before squashing, i.e., \(\mathbf{a}=\pi(\mathbf{z})\) )
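A minimal sketch of this input-adaptive Gaussian attention (the layer shapes and the choice of sigmoid for the squashing \(\pi\) are assumptions for illustration, not necessarily the paper's exact parameterization):

```python
import torch
import torch.nn as nn

class InputAdaptiveGaussianAttention(nn.Module):
    """z ~ N(mu(x), diag(sigma^2(x))), then a = pi(z) (here pi = sigmoid)."""
    def __init__(self, feat_dim):
        super().__init__()
        self.mu_layer = nn.Linear(feat_dim, 1)         # mu(x)
        self.log_sigma_layer = nn.Linear(feat_dim, 1)  # log sigma(x): input-dependent noise

    def forward(self, v):
        # v: (num_features, feat_dim), one row per intermediate feature v_j(x)
        mu = self.mu_layer(v).squeeze(-1)                   # (num_features,)
        sigma = self.log_sigma_layer(v).squeeze(-1).exp()   # (num_features,), positive
        z = mu + sigma * torch.randn_like(mu)               # reparameterized sample of z
        a = torch.sigmoid(z)                                # squash: a = pi(z)
        return a, mu, sigma

attn = InputAdaptiveGaussianAttention(feat_dim=16)
a, mu, sigma = attn(torch.randn(5, 16))   # 5 intermediate features of length 16
print(a.shape, sigma.shape)               # torch.Size([5]) torch.Size([5])
```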
- it is empirically shown that the quality of the uncertainty estimates improves
2-2. Variational Inference
( the model formulated in 2-1. above is now trained with VI )
\(\mathbf{Z}\) : the set of latent variables \(\left\{\mathbf{z}^{(n)}\right\}_{n=1}^{N}\) that stand for the attention weights before squashing.
The posterior \(p(\mathbf{Z}, \boldsymbol{\omega} \mid \mathcal{D})\) is usually computationally intractable!
\(\rightarrow\) use VI ( Variational Inference )
Variational Distribution :
- \(q(\mathbf{Z}, \boldsymbol{\omega} \mid \mathcal{D})=q_{\mathbf{M}}(\boldsymbol{\omega} \mid \mathbf{X}, \mathbf{Y}) q(\mathbf{Z} \mid \mathbf{X}, \mathbf{Y}, \boldsymbol{\omega})\).
- 1st term ) use MC dropout ( with variational parameter \(\mathbf{M}\) ), as sketched below
- 2nd term ) simply set \(q(\mathbf{Z} \mid \mathbf{X}, \mathbf{Y}, \boldsymbol{\omega})=p_{\theta}(\mathbf{Z} \mid \mathbf{X}, \boldsymbol{\omega})\), so the corresponding KL term vanishes
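A hedged sketch of the MC-dropout idea behind \(q_{\mathbf{M}}(\boldsymbol{\omega})\): dropout stays active even at test time, so each stochastic forward pass corresponds to one sample \(\widetilde{\boldsymbol{\omega}}\) (the layer sizes and dropout rate are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MCDropoutEncoder(nn.Module):
    """Dropout kept on at test time: each forward pass is one draw omega ~ q_M(omega)."""
    def __init__(self, in_dim, hid_dim, p=0.5):
        super().__init__()
        self.fc = nn.Linear(in_dim, hid_dim)
        self.p = p

    def forward(self, x):
        h = torch.relu(self.fc(x))
        # training=True forces dropout even in eval mode -> MC dropout
        return F.dropout(h, p=self.p, training=True)

enc = MCDropoutEncoder(in_dim=8, hid_dim=16)
x = torch.randn(3, 8)
print(torch.allclose(enc(x), enc(x)))   # False: two different dropout (weight) samples
```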
Putting this together, training is equivalent to maximizing the ELBO below.
- \(\begin{aligned} \log p(\mathbf{Y} \mid \mathbf{X}) \geq & \mathbb{E}_{\boldsymbol{\omega} \sim q_{\mathbf{M}}(\boldsymbol{\omega} \mid \mathbf{X}, \mathbf{Y}), \mathbf{Z} \sim p_{\theta}(\mathbf{Z} \mid \mathbf{X}, \boldsymbol{\omega})}[\log p(\mathbf{Y} \mid \mathbf{X}, \mathbf{Z}, \boldsymbol{\omega})] \\ &-\mathrm{KL}\left[q_{\mathbf{M}}(\boldsymbol{\omega} \mid \mathbf{X}, \mathbf{Y}) \| p(\boldsymbol{\omega})\right]-\operatorname{KL}\left[q(\mathbf{Z} \mid \mathbf{X}, \mathbf{Y}, \boldsymbol{\omega}) \| p_{\theta}(\mathbf{Z} \mid \mathbf{X}, \boldsymbol{\omega})\right] \end{aligned}\).
Final objective to maximize : \(\mathcal{L}(\theta, \mathbf{M} ; \mathbf{X}, \mathbf{Y})=\sum_{n} \log p_{\theta}\left(\mathbf{y}^{(n)} \mid \tilde{\mathbf{z}}^{(n)}, \mathbf{x}^{(n)}\right)-\lambda\|\mathbf{M}\|^{2}\).
- step 1) sample random weights with dropout masks : \(\widetilde{\boldsymbol{\omega}} \sim q_{\mathbf{M}}(\boldsymbol{\omega} \mid \mathbf{X}, \mathbf{Y})\)
- step 2) sample \(\mathbf{z}\) via the reparameterization trick : \(\tilde{\mathbf{z}}=\boldsymbol{\mu}(\mathbf{x}, \widetilde{\boldsymbol{\omega}} ; \theta)+\boldsymbol{\sigma}(\mathbf{x}, \widetilde{\boldsymbol{\omega}} ; \theta) \odot \tilde{\boldsymbol{\varepsilon}}, \quad \tilde{\boldsymbol{\varepsilon}} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})\)
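Putting the objective and the two sampling steps together, a toy single-instance training step might look like the sketch below (a Gaussian output likelihood is assumed, so the log-likelihood term reduces to a squared error; all module names, sizes, and hyperparameters are illustrative assumptions):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# toy modules: a feature extractor with dropout (plays the role of omega ~ q_M),
# mu / sigma heads for z, and an output layer
features = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Dropout(0.5))
mu_head, log_sigma_head, out_layer = nn.Linear(16, 1), nn.Linear(16, 1), nn.Linear(16, 1)
params = list(features.parameters()) + list(mu_head.parameters()) \
       + list(log_sigma_head.parameters()) + list(out_layer.parameters())
opt = torch.optim.Adam(params, lr=1e-3)

x = torch.randn(5, 8)              # one instance with i = 5 intermediate inputs
y = torch.randn(1)                 # toy regression target
lam = 1e-4                         # weight-decay coefficient lambda

# step 1) the dropout mask drawn inside `features` is one sample of omega
v = features(x)                                        # (5, 16)
# step 2) reparameterized sample of z, then squash to get the attention a
mu, sigma = mu_head(v).squeeze(-1), log_sigma_head(v).squeeze(-1).exp()
z = mu + sigma * torch.randn_like(mu)
a = torch.sigmoid(z)
c = (a.unsqueeze(-1) * v).sum(dim=0)                   # context vector, (16,)
y_hat = out_layer(c)

# maximize log p(y | z~, x) - lambda * ||M||^2  <=>  minimize the loss below
nll = (y_hat - y).pow(2).sum()                         # Gaussian likelihood -> squared error
l2 = sum((p ** 2).sum() for p in params)
loss = nll + lam * l2
opt.zero_grad(); loss.backward(); opt.step()
```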
Testing on a new input \(\mathbf{x}^{*}\) : use MC sampling
- \(p\left(\mathbf{y}^{*} \mid \mathbf{x}^{*}\right)=\iint p\left(\mathbf{y}^{*} \mid \mathbf{x}^{*}, \mathbf{z}\right) p\left(\mathbf{z} \mid \mathbf{x}^{*}, \boldsymbol{\omega}\right) p(\boldsymbol{\omega} \mid \mathbf{X}, \mathbf{Y}) \mathrm{d} \boldsymbol{\omega} \mathrm{d} \mathbf{z} \approx \frac{1}{S} \sum_{s=1}^{S} p\left(\mathbf{y}^{*} \mid \mathbf{x}^{*}, \tilde{\mathbf{z}}^{(s)}\right)\)
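A minimal sketch of this test-time MC estimate: run \(S\) stochastic forward passes (each one re-samples the dropout mask \(\boldsymbol{\omega}\) and the attention noise \(\mathbf{z}\)) and average; `toy_model` below is just a placeholder for the full stochastic network:

```python
import torch
import torch.nn as nn

def mc_predict(model, x_star, num_samples=30):
    """Approximate p(y*|x*) by averaging S stochastic forward passes."""
    model.train()  # keep dropout active at test time (MC dropout)
    preds = torch.stack([model(x_star) for _ in range(num_samples)])
    return preds.mean(dim=0), preds.std(dim=0)   # predictive mean + spread as uncertainty

# toy stand-in for the full stochastic attention network
toy_model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Dropout(0.5), nn.Linear(16, 1))
mean, std = mc_predict(toy_model, torch.randn(1, 8))
print(mean.shape, std.shape)   # torch.Size([1, 1]) torch.Size([1, 1])
```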
Uncertainty Calibration
Confirmed that UA obtains a better (lower) ECE (Expected Calibration Error) than the baselines!
( = the expected gap between accuracy and confidence, w.r.t. the distribution of model confidence )
\(\mathrm{ECE}=\mathbb{E}_{\text{confidence}}\left[\,\left|\,p(\text{correct} \mid \text{confidence})-\text{confidence}\,\right|\,\right]\)
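A hedged sketch of how ECE is typically estimated in practice, by binning predictions over confidence (the 10-bin choice is an assumption, not from this summary):

```python
import numpy as np

def expected_calibration_error(confidences, correct, n_bins=10):
    """ECE = sum_b (|B_b| / N) * |acc(B_b) - conf(B_b)|, binned over confidence."""
    confidences, correct = np.asarray(confidences), np.asarray(correct)
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(bins[:-1], bins[1:]):
        mask = (confidences > lo) & (confidences <= hi)
        if mask.any():
            gap = abs(correct[mask].mean() - confidences[mask].mean())
            ece += mask.mean() * gap     # weight each bin by its fraction of samples
    return ece

# toy example: a well-calibrated model would give a small ECE
conf = np.array([0.9, 0.8, 0.75, 0.6, 0.55])
hit  = np.array([1,   1,   0,    1,   0   ])
print(expected_calibration_error(conf, hit))
```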