Beyond temperature scaling : Obtaining well-calibrated multiclass probabilities with Dirichlet calibration
Contents
- Abstract
- Evaluation of Calibration & Temperature Scaling
- Multiclass-calibrated ( = calibrated )
- Classwise-calibrated
- Confidence-calibrated
- Dirichlet Calibration
- Dirichlet calibration map family
- Interpretability
- Relationship to other families
- Fitting and ODIR regularization
0. Abstract
Most multi-class classifiers are UNcalibrated \(\rightarrow\) OVERCONFIDENCE!
"Temperature scaling", a method frequently used with NNs, improves calibration!
This paper proposes a MULTI-class calibration method that is
- applicable to any model class (it is derived from the Dirichlet distribution)
- a generalization of beta calibration, which handles the binary case
- evaluated with the following metrics: confidence ECE, classwise ECE, log loss, Brier score
1. Evaluation of Calibration & Temperature Scaling
Notation
- probabilistic classifier : \(\hat{\mathbf{p}}: \mathscr{X} \rightarrow \Delta_{k}\)
- output : class probabilities of \(k\) classes
- \(\Delta_{k}=\left\{\left(q_{1}, \ldots, q_{k}\right) \in[0,1]^{k} \mid \sum_{i=1}^{k} q_{i}=1\right\}\).
- output : \(\hat{\mathbf{p}}(\mathbf{x})=\left(\hat{p}_{1}(\mathbf{x}), \ldots, \hat{p}_{k}(\mathbf{x})\right)\)
1-1) Multiclass-calibrated ( = calibrated )
\(P(Y=i \mid \hat{\mathbf{p}}(X)=\mathbf{q})=q_{i} \quad \text { for } i=1, \ldots, k\).
- \(\mathbf{q}=\left(q_{1}, \ldots, q_{k}\right) \in \Delta_{k}\)
- (key point) the condition is on the full prediction vector \(\mathbf{q}\), not on individual classes
1-2) Classwise-calibrated
\(P\left(Y=i \mid \hat{p}_{i}(X)=q_{i}\right)=q_{i}\).
- (key point) the condition is on each class probability \(\hat{p}_{i}(X)\) separately
1-3) Confidence-calibrated
\(P(Y=\operatorname{argmax}(\hat{\mathbf{p}}(X)) \mid \max (\hat{\mathbf{p}}(X))=c)=c\).
- for any \(c \in[0,1]\)
For practical evaluation, the notions 1-1) ~ 1-3) above need to be relaxed!
\(\rightarrow\) equal-width binning ( + visualization with reliability diagrams )
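A minimal sketch of the relaxed (binned) version of confidence calibration, i.e., confidence-ECE with equal-width bins. NumPy only; the helper name is my own, not from the paper.

```python
import numpy as np

def confidence_ece(probs, labels, n_bins=15):
    """Confidence-ECE with equal-width binning.

    probs  : (n, k) array of predicted class probabilities
    labels : (n,) array of true class indices
    """
    conf = probs.max(axis=1)                 # predicted confidence
    pred = probs.argmax(axis=1)              # predicted class
    correct = (pred == labels).astype(float)

    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (conf > lo) & (conf <= hi)
        if in_bin.any():
            # |B_i| / n  *  | avg confidence - accuracy |  within the bin
            ece += in_bin.mean() * abs(conf[in_bin].mean() - correct[in_bin].mean())
    return ece
```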
Temperature Scaling
\(S_{i}(x ; T)=\frac{\exp \left(f_{i}(x) / T\right)}{\sum_{j=1}^{k} \exp \left(f_{j}(x) / T\right)}\).
- \(T\) : temperature scaling parameter ( \(T=1\) during training )
- softmax score : \(S_{\hat{y}}(x ; T)=\max _{i} S_{i}(x ; T)\)
  \(\rightarrow\) scaling by \(T\) helps calibration!
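A small sketch of temperature scaling applied to logits \(f(x)\), in NumPy. The logits array below is a made-up toy example.

```python
import numpy as np

def temperature_scale(logits, T=1.0):
    """Softmax with temperature: S_i(x; T) = exp(f_i(x)/T) / sum_j exp(f_j(x)/T)."""
    z = logits / T
    z = z - z.max(axis=-1, keepdims=True)    # numerical stability
    expz = np.exp(z)
    return expz / expz.sum(axis=-1, keepdims=True)

logits = np.array([[4.0, 1.0, 0.5]])         # toy logits for one input, k = 3
print(temperature_scale(logits, T=1.0))      # sharp (possibly overconfident)
print(temperature_scale(logits, T=2.5))      # softened by T > 1
```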
Reliability Diagram
What do the reliability diagrams (Figure 1) show?
- The temperature-scaled model is
  - confidence-calibrated (O)
  - classwise-calibrated (X)
- With only a single tunable parameter (= \(T\)), per-class differences cannot be accounted for!
- As [Figure 1-d] shows, performing classwise calibration gives much better results!
  ( classwise-ECE is proposed as the measure )
\(\text{classwise-ECE}=\frac{1}{k} \sum_{j=1}^{k} \sum_{i=1}^{m} \frac{\left|B_{i, j}\right|}{n}\left|y_{j}\left(B_{i, j}\right)-\hat{p}_{j}\left(B_{i, j}\right)\right|\).
- \(k, m, n\) : number of classes / bins / instances
- \(\left|B_{i, j}\right|\) : size of bin \(B_{i, j}\)
- \(\hat{p}_{j}\left(B_{i, j}\right)\) : average prediction of class \(j\) probability in bin \(B_{i,j}\)
- \(y_{j}\left(B_{i, j}\right)\) : actual proportion of class \(j\) in bin \(B_{i,j}\)
- class-j-ECE : contribution of single class \(j\) to classwise-ECE
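A sketch of classwise-ECE following the formula above (NumPy, equal-width bins; the helper name is my own).

```python
import numpy as np

def classwise_ece(probs, labels, n_bins=15):
    """classwise-ECE = (1/k) * sum_j class-j-ECE, with equal-width binning."""
    n, k = probs.shape
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    total = 0.0
    for j in range(k):
        p_j = probs[:, j]                      # predicted probability of class j
        y_j = (labels == j).astype(float)      # indicator of class j
        for lo, hi in zip(edges[:-1], edges[1:]):
            in_bin = (p_j > lo) & (p_j <= hi)
            if in_bin.any():
                # |B_ij| / n  *  | y_j(B_ij) - p_hat_j(B_ij) |
                total += in_bin.mean() * abs(y_j[in_bin].mean() - p_j[in_bin].mean())
    return total / k
```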
( Limitations of the existing (class-agnostic) evaluation )
- A classifier that ignores the input and always predicts the overall class proportions is already perfectly calibrated!
- Therefore, additional evaluation metrics are needed!
  ( proper losses, not just the error rate )
- Proper losses?
  - evaluate probabilistic predictions
  - decompose into a calibration loss & a refinement loss
  - are (typically) used as objective functions in post-hoc calibration
  - ex) Brier score, log-loss (sketched below)
- All proper losses are minimized by the same calibration map
  ( = the canonical calibration function of \(\hat{\mathbf{p}}\) )
\(\mu(\mathbf{q})=(P(Y=1 \mid \hat{\mathbf{p}}(X)=\mathbf{q}), \ldots, P(Y=k \mid \hat{\mathbf{p}}(X)=\mathbf{q}))\).
The goal of Dirichlet calibration is likewise to estimate this canonical calibration map \(\mu\)!
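For reference, the two proper losses mentioned above, sketched in NumPy (one-hot targets assumed; the function names are my own).

```python
import numpy as np

def brier_score(probs, labels):
    """Mean squared error between predicted probabilities and one-hot targets."""
    n, k = probs.shape
    onehot = np.eye(k)[labels]
    return np.mean(np.sum((probs - onehot) ** 2, axis=1))

def log_loss(probs, labels, eps=1e-12):
    """Negative log-likelihood of the true class."""
    n = probs.shape[0]
    return -np.mean(np.log(probs[np.arange(n), labels] + eps))
```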
2. Dirichlet Calibration
( Beta calibration : 2 classes )
( Dirichlet calibration : multi-class )
2-1) Dirichlet calibration map family
Assume that the class-conditional distribution of the prediction vectors \(\hat{\mathbf{p}}(X)\) ( vectors over the \(k\) classes ) is a Dirichlet distribution:
\(\hat{\mathbf{p}}(X) \mid Y=j \sim \operatorname{Dir}\left(\alpha^{(j)}\right)\).
- \(\alpha^{(j)}=\left(\alpha_{1}^{(j)}, \ldots, \alpha_{k}^{(j)}\right) \in(0, \infty)^{k}\).
Canonical calibration function : \(P(Y \mid \hat{\mathbf{p}}(X))\)
- This function can be expressed with three equivalent parameterizations.
(1) Generative parameterization
- \(\hat{\mu}_{\text {DirGen }}(\mathbf{q} ; \alpha, \pi)=\left(\pi_{1} f_{1}(\mathbf{q}), \ldots, \pi_{k} f_{k}(\mathbf{q})\right) / z\)
  ( where \(f_{j}\) is the \(\operatorname{Dir}\left(\alpha^{(j)}\right)\) density, \(\pi_{j}\) is the prior of class \(j\), and \(z=\sum_{j=1}^{k} \pi_{j} f_{j}(\mathbf{q})\) is the normalizer )
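A sketch of this generative form via Bayes' rule, using SciPy's Dirichlet density. The \(\alpha\) values and priors below are made-up toy numbers, not fitted parameters.

```python
import numpy as np
from scipy.stats import dirichlet

def dir_gen_calibrate(q, alphas, priors):
    """mu_DirGen(q; alpha, pi) = (pi_j * f_j(q))_j / z, with f_j the Dir(alpha^(j)) pdf."""
    scores = np.array([pi * dirichlet.pdf(q, a) for a, pi in zip(alphas, priors)])
    return scores / scores.sum()

# toy example with k = 3 classes
alphas = [np.array([5.0, 1.0, 1.0]),   # alpha^(1): mass concentrated on class 1
          np.array([1.0, 5.0, 1.0]),   # alpha^(2)
          np.array([1.0, 1.0, 5.0])]   # alpha^(3)
priors = np.array([0.5, 0.3, 0.2])
q = np.array([0.7, 0.2, 0.1])          # uncalibrated prediction vector
print(dir_gen_calibrate(q, alphas, priors))
```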
(2) Linear parameterization
- \(\hat{\mu}_{\text {DirLin }}(\mathbf{q} ; \mathbf{W}, \mathbf{b})=\sigma(\mathbf{W} \ln \mathbf{q}+\mathbf{b})\)
  ( \(\mathbf{W} \in \mathbb{R}^{k \times k}\) is a \(k \times k\) parameter matrix )
  ( \(\mathbf{b} \in \mathbb{R}^{k}\) is a parameter vector of length \(k\) )
  ( \(\sigma\) denotes the softmax function )
(3) Canonical parameterization
- \(\hat{\mu}_{\text {Dir }}(\mathbf{q} ; \mathbf{A}, \mathbf{c})=\sigma\left(\mathbf{A} \ln \frac{\mathbf{q}}{1 / k}+\ln \mathbf{c}\right)\)
  ( \(\mathbf{A} \in[0, \infty)^{k \times k}\) is a \(k \times k\) matrix )
  ( \(\mathbf{c} \in \mathbb{R}^{k}\) is a probability vector of length \(k\) )
Summary of (1)~(3) above
- Benefit of (2) the linear parameterization?
  \(\rightarrow\) it can be implemented by adding one extra layer to the NN (see the sketch after this list)
- (3) Canonical parameterization
  \(\rightarrow\) ANY function in the Dirichlet calibration map family can be represented by a single pair of a matrix \(\mathbf{A}\) and a vector \(\mathbf{c}\)
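A minimal sketch of the linear parameterization \(\hat{\mu}_{\text{DirLin}}\) as one extra layer on top of a frozen classifier (PyTorch; the module name is my own, not from the paper).

```python
import torch
import torch.nn as nn

class DirichletCalibrationLayer(nn.Module):
    """sigma(W ln q + b): a fully-connected layer applied to log-probabilities."""
    def __init__(self, k):
        super().__init__()
        self.linear = nn.Linear(k, k)  # holds W and b

    def forward(self, q, eps=1e-12):
        # q: (batch, k) uncalibrated probabilities from the frozen classifier
        logits = self.linear(torch.log(q + eps))
        return torch.softmax(logits, dim=-1)

# usage sketch: calibrate the outputs of some frozen model on a validation set
layer = DirichletCalibrationLayer(k=10)
q = torch.softmax(torch.randn(4, 10), dim=-1)   # stand-in for model predictions
print(layer(q).shape)                            # torch.Size([4, 10])
```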
2-2) Interpretability
The canonical parameterization is interpretable!
- First, converting from the linear parameterization \(\rightarrow\) canonical parameterization is easy (a small sketch follows after this list):
  - \(a_{i j}=w_{i j}-\min _{i^{\prime}} w_{i^{\prime} j}\)
  - \(\mathbf{c}=\sigma(\mathbf{W} \ln \mathbf{u}+\mathbf{b})\)
  - where \(\mathbf{u}=(1 / k, \ldots, 1 / k)\)
- Interpretation : increasing the element \(a_{i j}\) of the matrix \(\mathbf{A}\)
  = increases the **calibrated probability of class \(i\)**, with an effect size that depends on the **uncalibrated probability of class \(j\)**
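The conversion above, sketched in NumPy: subtract each column's minimum from \(\mathbf{W}\) to get \(\mathbf{A}\), and evaluate the linear map at the uniform vector \(\mathbf{u}\) to get \(\mathbf{c}\).

```python
import numpy as np

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def linear_to_canonical(W, b):
    """a_ij = w_ij - min_i' w_i'j ;  c = sigma(W ln u + b), u = (1/k, ..., 1/k)."""
    k = W.shape[0]
    A = W - W.min(axis=0, keepdims=True)   # subtract each column's minimum
    u = np.full(k, 1.0 / k)
    c = softmax(W @ np.log(u) + b)
    return A, c
```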
Example
[Figure 2-b & 2-c]
- the element \(a_{3,9}=0.63\) increases the probability of class 2 whenever class 8 has a high predicted probability, thereby modifying the decision boundaries
[Figure 2-a]
- \(3+1\) interpretation points in an example for \(k=3\), where each arrow visualises the result of calibration (end of arrow) at a particular point (beginning of arrow)
- Under the Dirichlet calibration map \(\hat{\mu}_{\text {Dir }}(\mathbf{q} ; \mathbf{A}, \mathbf{c})\) :
  - (before) \(\left(\varepsilon, \frac{1-\varepsilon}{k-1}, \ldots, \frac{1-\varepsilon}{k-1}\right), \ldots,\left(\frac{1-\varepsilon}{k-1}, \ldots, \frac{1-\varepsilon}{k-1}, \varepsilon\right)\), and \(\left(\frac{1}{k}, \ldots, \frac{1}{k}\right)\)
  - (after) \(\left(\varepsilon^{a_{11}}, \ldots, \varepsilon^{a_{k 1}}\right) / z_{1}, \ldots,\left(\varepsilon^{a_{1 k}}, \ldots, \varepsilon^{a_{k k}}\right) / z_{k}\), and \(\left(c_{1}, \ldots, c_{k}\right)\)
2-3) Relationship to other families
Temperature scaling maps
- They all belong to the Dirichlet family:
- \(\hat{\mu}_{\text {TempS }}(\mathbf{q} ; t)=\hat{\mu}_{\text {DirLin }}\left(\mathbf{q} ; \frac{1}{t} \mathbf{I}, \mathbf{0}\right)\)
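A quick numerical check of this identity: applying \(\hat{\mu}_{\text{DirLin}}\) with \(\mathbf{W}=\frac{1}{t}\mathbf{I}\), \(\mathbf{b}=\mathbf{0}\) to \(\mathbf{q}=\operatorname{softmax}(\mathbf{z})\) reproduces \(\operatorname{softmax}(\mathbf{z}/t)\), since the softmax is invariant to shifting all inputs by a constant. The numbers below are arbitrary.

```python
import numpy as np

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

t, z = 2.5, np.array([3.0, 1.0, -0.5])
q = softmax(z)                                        # uncalibrated probabilities
dirlin = softmax((1.0 / t) * np.eye(3) @ np.log(q))   # DirLin with W = I/t, b = 0
temps = softmax(z / t)                                # temperature scaling on logits
print(np.allclose(dirlin, temps))                     # True
```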
Matrix Scaling Family
- \(\hat{\mu}_{\text {Mat }}(\mathbf{z} ; \mathbf{W}, \mathbf{b})=\sigma(\mathbf{W} \mathbf{z}+\mathbf{b})\), applied to the logit vector \(\mathbf{z}\) instead of \(\ln \mathbf{q}\)
2-4) Fitting and ODIR regularization
Any calibration model with tens of thousands of parameters will OVERFIT to a small validation set.
Therefore, a novel regularizer, ODIR ( Off-Diagonal and Intercept Regularization ), is proposed!
- helps reduce overfitting
- outperforms temperature scaling
- loss function (a fitting sketch follows at the end of this section):
\(L=\frac{1}{n} \sum_{i=1}^{n} \operatorname{logloss}\left(\hat{\mu}_{\text {DirLin }}\left(\hat{\mathbf{p}}\left(\mathbf{x}_{i}\right) ; \mathbf{W}, \mathbf{b}\right), y_{i}\right)+\lambda \cdot\left(\frac{1}{k(k-1)} \sum_{i \neq j} w_{i j}^{2}\right)+\mu \cdot\left(\frac{1}{k} \sum_{j} b_{j}^{2}\right)\).
- \(w_{i j}, b_{j}\) : elements of \(\mathbf{W}\) and \(\mathbf{b}\)
- \(\lambda, \mu\) : tunable hyper-parameters ( chosen by cross-validation on the validation data )
- Interpretation of the loss function above :
  - (Diagonal elements of \(\mathbf{W}\)) left unregularized, so they are allowed to freely follow the biases of the classes
  - (Intercept \(\mathbf{b}\)) regularized separately from the off-diagonal elements, since it lives on a different scale
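A sketch of fitting \(\hat{\mu}_{\text{DirLin}}\) with the ODIR-regularized log-loss above (PyTorch). The hyper-parameter values, iteration count, and data are placeholders, not the paper's settings.

```python
import torch
import torch.nn as nn

def odir_loss(linear, probs, labels, lam=1e-3, mu=1e-3, eps=1e-12):
    """log-loss of DirLin(q; W, b) plus ODIR penalties on off-diagonal W and on b."""
    W, b = linear.weight, linear.bias
    k = W.shape[0]
    logits = linear(torch.log(probs + eps))
    nll = nn.functional.cross_entropy(logits, labels)     # (1/n) sum of log-losses
    off_diag = W - torch.diag(torch.diag(W))               # zero out the diagonal
    return (nll
            + lam * off_diag.pow(2).sum() / (k * (k - 1))
            + mu * b.pow(2).sum() / k)

# toy fitting loop on stand-in validation predictions
k, n = 10, 256
probs = torch.softmax(torch.randn(n, k), dim=-1)
labels = torch.randint(0, k, (n,))
linear = nn.Linear(k, k)
opt = torch.optim.Adam(linear.parameters(), lr=0.01)
for _ in range(200):
    opt.zero_grad()
    loss = odir_loss(linear, probs, labels)
    loss.backward()
    opt.step()
```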