Learnable Bernoulli Dropout for Bayesian Deep Learning (2020)
Abstract
Proposes learnable Bernoulli dropout (LBD)
- model-agnostic
- jointly optimized with the other parameters of the model
Benefits
- robust prediction
- uncertainty quantification
1. LBD (Learnable Bernoulli Dropout)
Notation
- training data : \(\mathcal{D}=\left\{\left(\boldsymbol{x}_{i}, y_{i}\right)\right\}_{i=1}^{N}\)
- NN : \(f(\boldsymbol{x} ; \boldsymbol{\theta})\)
- \(L\) layers
- objective function : \(\mathcal{L}(\boldsymbol{\theta} \mid \mathcal{D}) \approx \frac{N}{M} \sum_{i=1}^{M} \mathcal{E}\left(f\left(\boldsymbol{x}_{i} ; \boldsymbol{\theta}\right), y_{i}\right)+\mathcal{R}(\boldsymbol{\theta})\)
- \(M\) : mini-batch size; \(\mathcal{E}\) : loss function; \(\mathcal{R}\) : regularization term
- \(j^{\text {th }}\) fully connected layer with \(K_{j}\) neurons
- \(W_{j} \in \mathbb{R}^{K_{j-1} \times K_{j}}\) : weight matrix connecting layer \(j-1\) to \(j\)
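The \(\frac{N}{M}\) factor in the objective rescales a mini-batch sum so that it stands in for the sum over all \(N\) training examples. A minimal stdlib Python sketch, assuming the per-example losses \(\mathcal{E}\) and the value of \(\mathcal{R}(\boldsymbol{\theta})\) are already computed; `minibatch_objective` and `average_over_all_batches` are hypothetical helpers, and the second one just checks the rescaling by averaging over every possible mini-batch:

```python
from itertools import combinations
from statistics import mean
from typing import Sequence


def minibatch_objective(batch_losses: Sequence[float], n_total: int, reg: float) -> float:
    """Estimate of L(theta | D) from one mini-batch.

    batch_losses: E(f(x_i; theta), y_i) for the M examples in the batch
    n_total:      N, the number of training examples
    reg:          value of the regularizer R(theta)
    """
    m = len(batch_losses)
    # Scale the batch sum by N/M so it stands in for the sum over all N examples.
    return (n_total / m) * sum(batch_losses) + reg


def average_over_all_batches(all_losses: Sequence[float], m: int, reg: float) -> float:
    """Average the estimate over every size-m subset (sampling without replacement)."""
    n = len(all_losses)
    return mean(
        minibatch_objective([all_losses[i] for i in idx], n, reg)
        for idx in combinations(range(n), m)
    )
```

For losses `[0.5, 1.0, 2.0, 4.0]`, batches of size 2 and `reg=0.25`, the average over all six batches equals the full-data value \(7.5 + 0.25\): each example appears in the same fraction \(M/N\) of batches, which the \(N/M\) factor cancels.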
Dropout
- takes the output of each layer
- multiplies it element-wise with a random variable \(\boldsymbol{z}_{j} \sim p\left(\boldsymbol{z}_{j}\right)\)
- e.g. \(p\left(\boldsymbol{z}_{j}\right)=\operatorname{Ber}\left(\sigma\left(\alpha_{j}\right)\right)\): keep probability \(\sigma\left(\alpha_{j}\right)\), dropout rate \(1-\sigma\left(\alpha_{j}\right)\)
- \(\boldsymbol{\alpha}=\left\{\alpha_{j}\right\}_{j=1}^{L}\) : collection of all logits of the dropout parameters
- \(\boldsymbol{z}=\left\{\boldsymbol{z}_{j}\right\}_{j=1}^{L}\) : collection of all dropout masks
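The dropout step with a learnable logit \(\alpha_j\) can be sketched as below (stdlib Python; `bernoulli_dropout` is a hypothetical name, and a layer output is assumed to be a plain list of floats). It samples one mask entry per element with keep probability \(\sigma(\alpha_j)\) and multiplies element-wise:

```python
import math
import random
from typing import List, Sequence


def sigmoid(a: float) -> float:
    return 1.0 / (1.0 + math.exp(-a))


def bernoulli_dropout(h: Sequence[float], alpha: float, rng: random.Random) -> List[float]:
    """Multiply a layer output h element-wise by a mask z ~ Ber(sigmoid(alpha)).

    alpha is the (learnable) logit for this layer: keep probability sigmoid(alpha),
    dropout rate 1 - sigmoid(alpha).
    """
    keep_prob = sigmoid(alpha)
    z = [1.0 if rng.random() < keep_prob else 0.0 for _ in h]
    return [hk * zk for hk, zk in zip(h, z)]
```

The sketch does not rescale kept units by \(1/\sigma(\alpha_j)\) because the notes do not mention it; common frameworks (e.g. PyTorch's `nn.Dropout`) apply that "inverted dropout" scaling during training.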
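Goal :
- minimize the expected loss over the dropout masks, jointly over the dropout logits \(\boldsymbol{\alpha}\) and the remaining parameters \(\boldsymbol{\theta} \backslash \boldsymbol{\alpha}\) (one mask \(\boldsymbol{z}_{i}\) per mini-batch sample):
\[\min _{\boldsymbol{\theta}=\{\boldsymbol{\theta} \backslash \boldsymbol{\alpha}, \boldsymbol{\alpha}\}} \quad \mathbb{E}_{\boldsymbol{z} \sim \prod_{i=1}^{M} \operatorname{Ber}\left(\boldsymbol{z}_{i} ; \sigma(\boldsymbol{\alpha})\right)}[\mathcal{L}(\boldsymbol{\theta}, \boldsymbol{z} \mid \mathcal{D})]\]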
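Because \(\boldsymbol{z}\) is discrete, the gradient of this expectation with respect to \(\boldsymbol{\alpha}\) cannot be taken by backpropagating through the samples, and the notes do not say how it is estimated. One unbiased option for Bernoulli variables with logit parameters is the ARM (augment-REINFORCE-merge) estimator of Yin & Zhou: \(\nabla_{\boldsymbol{\alpha}} \mathbb{E}[f(\boldsymbol{z})] = \mathbb{E}_{\boldsymbol{u} \sim \mathrm{U}(0,1)}\left[\left(f(\mathbb{1}[\boldsymbol{u} > \sigma(-\boldsymbol{\alpha})]) - f(\mathbb{1}[\boldsymbol{u} < \sigma(\boldsymbol{\alpha})])\right)(\boldsymbol{u} - \tfrac{1}{2})\right]\). The sketch applies it to a hypothetical toy objective `f` over a few mask bits and compares it with exact enumeration; it is an illustration under these assumptions, not the paper's training code.

```python
import itertools
import math
import random
from typing import Callable, List, Sequence


def sigmoid(a: float) -> float:
    return 1.0 / (1.0 + math.exp(-a))


def arm_grad(f: Callable[[List[int]], float], alpha: Sequence[float],
             n_samples: int, rng: random.Random) -> List[float]:
    """ARM Monte Carlo estimate of d/d alpha of E_{z ~ Ber(sigmoid(alpha))}[f(z)]."""
    p_keep = [sigmoid(a) for a in alpha]   # sigmoid(alpha)
    p_flip = [sigmoid(-a) for a in alpha]  # sigmoid(-alpha)
    grad = [0.0] * len(alpha)
    for _ in range(n_samples):
        u = [rng.random() for _ in alpha]
        # Two binary masks built from the same uniforms u.
        z_a = [1 if uk > qk else 0 for uk, qk in zip(u, p_flip)]
        z_b = [1 if uk < pk else 0 for uk, pk in zip(u, p_keep)]
        diff = f(z_a) - f(z_b)
        for k, uk in enumerate(u):
            grad[k] += diff * (uk - 0.5)
    return [g / n_samples for g in grad]


def exact_expectation(f: Callable[[List[int]], float], alpha: Sequence[float]) -> float:
    """Enumerate all 2^K masks; only feasible for tiny K, used here as a reference."""
    total = 0.0
    for z in itertools.product([0, 1], repeat=len(alpha)):
        prob = 1.0
        for zk, a in zip(z, alpha):
            prob *= sigmoid(a) if zk else 1.0 - sigmoid(a)
        total += prob * f(list(z))
    return total
```

In a network, `f` would be the mini-batch loss evaluated once with mask `z_a` and once with `z_b`. The two evaluations reuse the same uniforms (an antithetic pair), which is where ARM's variance reduction over plain REINFORCE comes from.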
Dropout rates are learned for:
- 1) supervised DNNs
- 2) unsupervised VAEs
2. Variational Bayesian Inference with LBD
BNN (Bayesian Neural Net)
- collection of weight matrices : \(\boldsymbol{W}=\left\{W_{j}\right\}_{j=1}^{L}\)
- prior : \(p(\boldsymbol{W})\)
- posterior : \(p(\boldsymbol{W} \mid \mathcal{D})=\frac{p(\mathcal{D} \mid \boldsymbol{W}) p(\boldsymbol{W})}{p(\mathcal{D})}\)
- computing the evidence \(p(\mathcal{D})\) is intractable
- so a simple variational distribution \(q_{\boldsymbol{\theta}}(\boldsymbol{W})\) is used to approximate the posterior
LBD (Learnable Bernoulli Dropout) is proposed as the variational approximation
- let each neuron \(k\) in each layer “HAVE ITS OWN DROPOUT RATE, \(\alpha_{j k}\)”
- each layer has..
- mean weight matrix \(M_{j}\)
- dropout parameters \(\boldsymbol{\alpha}_{j}=\left\{\alpha_{j k}\right\}_{k=1}^{K_{j-1}}\)
- variational parameters : \(\boldsymbol{\theta}=\left\{M_{j}, \boldsymbol{\alpha}_{j}\right\}_{j=1}^{L}\)
- variational distribution :
\[q_{\boldsymbol{\theta}}(\boldsymbol{W})=\prod_{j=1}^{L} q_{\boldsymbol{\theta}}\left(W_{j}\right)\]
- where \(q_{\boldsymbol{\theta}}\left(W_{j}\right)=M_{j}^{T} \operatorname{diag}\left(\operatorname{Ber}\left(\boldsymbol{\alpha}_{j}\right)\right)\), i.e. \(\operatorname{diag}(\cdot)\) is a diagonal matrix whose \(k^{\text {th }}\) entry is a sample \(z_{j k} \sim \operatorname{Ber}\left(\alpha_{j k}\right)\)
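Drawing one weight sample from \(q_{\boldsymbol{\theta}}(W_j)\) can be sketched in stdlib Python as below. Assumptions: \(M_j\) is stored as a list of \(K_{j-1}\) rows of length \(K_j\), and \(\alpha_{jk}\) is used directly as the Bernoulli keep probability (as in \(\operatorname{Ber}(\boldsymbol{\alpha}_j)\)) rather than as a logit; `sample_weight` is a hypothetical helper. The product \(M_j^T \operatorname{diag}(\boldsymbol{z})\) has shape \(K_j \times K_{j-1}\), so that is the orientation returned.

```python
import random
from typing import List, Sequence

Matrix = List[List[float]]


def sample_weight(mean_w: Matrix, alpha: Sequence[float], rng: random.Random) -> Matrix:
    """Draw one sample W_j = M_j^T diag(z), with z_k ~ Ber(alpha_k).

    mean_w: M_j stored as K_{j-1} rows of length K_j
    alpha:  K_{j-1} keep probabilities, one per input neuron k (used directly,
            not as logits)
    Returns a K_j x K_{j-1} matrix (list of rows).
    """
    z = [1.0 if rng.random() < a else 0.0 for a in alpha]
    m_t = [list(col) for col in zip(*mean_w)]  # M_j^T, shape K_j x K_{j-1}
    # Right-multiplying by diag(z) scales column k of M_j^T by z_k,
    # so a dropped input neuron k loses all of its outgoing weights at once.
    return [[row[k] * z[k] for k in range(len(z))] for row in m_t]
```

With `mean_w = [[1, 2], [3, 4], [5, 6]]` and keep probabilities `[1.0, 0.0, 1.0]`, the middle input neuron is always dropped, so its column is zeroed: `[[1.0, 0.0, 5.0], [2.0, 0.0, 6.0]]`. This is why per-neuron dropout rates act as per-neuron variational parameters.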
Objective Function
- \(\mathcal{L}\left(\boldsymbol{\theta}=\left\{M_{j}, \boldsymbol{\alpha}_{j}\right\}_{j=1}^{L} \mid \mathcal{D}\right)=-\frac{N}{M} \sum_{i=1}^{M} \log p\left(y_{i} \mid f\left(\boldsymbol{x}_{i} ; \boldsymbol{W}_{i}\right)\right)+\operatorname{KL}\left(q_{\boldsymbol{\theta}}(\boldsymbol{W}) \,\|\, p(\boldsymbol{W})\right)\), where \(\boldsymbol{W}_{i} \sim q_{\boldsymbol{\theta}}(\boldsymbol{W})\) is the weight sample used for the \(i^{\text {th }}\) data point
- \(p\left(y_{i} \mid f\left(\boldsymbol{x}_{i} ; \boldsymbol{W}_{i}\right)\right)\) : likelihood, softmax (classification) / Gaussian (regression)
- KL term : regularization term
- with a quantized zero-mean Gaussian prior with variance \(s^{2}\), the KL term becomes
\(\operatorname{KL}\left(q_{\boldsymbol{\theta}}(\boldsymbol{W}) \,\|\, p(\boldsymbol{W})\right) \propto \sum_{j=1}^{L} \sum_{k=1}^{K_{j-1}}\left(\frac{\alpha_{j k}}{2 s^{2}}\left\|M_{j}[\cdot, k]\right\|^{2}-\mathcal{H}\left(\alpha_{j k}\right)\right)\)
- where \(M_{j}[\cdot, k]\) is the \(k^{\text {th }}\) column of the mean weight matrix \(M_{j}\)
- where \(\mathcal{H}\left(\alpha_{j k}\right)=-\alpha_{j k} \log \alpha_{j k}-\left(1-\alpha_{j k}\right) \log \left(1-\alpha_{j k}\right)\) is the entropy of a Bernoulli random variable with parameter \(\alpha_{j k}\)
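The objective for classification (softmax likelihood) can be sketched as below in stdlib Python. Assumptions: the network outputs for each mini-batch example were computed with its own weight sample \(\boldsymbol{W}_i\) and are passed in as logits; each layer is given as a pair `(alphas, cols)` where `cols[k]` is the weight vector gated by \(\alpha_{jk}\) (\(M_j[\cdot, k]\) in the notes); \(\alpha_{jk}\) is used as the Bernoulli parameter directly; all function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Layer = Tuple[Sequence[float], Sequence[Sequence[float]]]


def bernoulli_entropy(p: float) -> float:
    """H(p) in nats, with the convention 0 * log 0 = 0."""
    return -sum(q * math.log(q) for q in (p, 1.0 - p) if q > 0.0)


def kl_term(layers: Sequence[Layer], s2: float) -> float:
    """sum_j sum_k [ alpha_jk / (2 s^2) * ||M_j[., k]||^2 - H(alpha_jk) ]."""
    total = 0.0
    for alphas, cols in layers:
        for a, col in zip(alphas, cols):
            sq_norm = sum(w * w for w in col)
            total += a / (2.0 * s2) * sq_norm - bernoulli_entropy(a)
    return total


def log_softmax(logits: Sequence[float]) -> List[float]:
    mx = max(logits)  # subtract the max for numerical stability
    lse = mx + math.log(sum(math.exp(v - mx) for v in logits))
    return [v - lse for v in logits]


def lbd_objective(batch_logits: Sequence[Sequence[float]], labels: Sequence[int],
                  n_total: int, layers: Sequence[Layer], s2: float) -> float:
    """-(N/M) * sum_i log p(y_i | f(x_i; W_i)) + KL, with a softmax likelihood."""
    m = len(labels)
    nll = -sum(log_softmax(lg)[y] for lg, y in zip(batch_logits, labels))
    return (n_total / m) * nll + kl_term(layers, s2)
```

The two parts of the KL term pull in opposite directions: the weight-norm part grows with \(\alpha_{jk}\) (keeping a neuron costs its squared weights), while \(-\mathcal{H}(\alpha_{jk})\) is lowest at \(\alpha_{jk} = 0.5\).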
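Posterior predictive \(p(y \mid \boldsymbol{x}, \mathcal{D})\) for a new input \(\boldsymbol{x}\) :
- approximated by MC integration with \(S\) samples : \(p(y \mid \boldsymbol{x}, \mathcal{D}) \approx \frac{1}{S} \sum_{s=1}^{S} p\left(y \mid f\left(\boldsymbol{x} ; \boldsymbol{W}^{(s)}\right)\right)\), where \(\boldsymbol{W}^{(s)} \sim q_{\boldsymbol{\theta}}(\boldsymbol{W})\)
- the entropy of \(p(y \mid \boldsymbol{x}, \mathcal{D})\) serves as a measure of predictive uncertainty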
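The MC approximation and the entropy-based uncertainty can be sketched as below (stdlib Python). `sample_class_probs` uses a hypothetical one-layer softmax model with dropout masks on its inputs only to produce the \(S\) probability vectors; the averaging and entropy steps do not depend on that choice. Keep probabilities \(\alpha_k\) are used directly, as in the Bernoulli parameterization above.

```python
import math
import random
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    mx = max(logits)
    exps = [math.exp(v - mx) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_class_probs(x: Sequence[float], mean_w: Sequence[Sequence[float]],
                       alpha: Sequence[float], n_samples: int,
                       rng: random.Random) -> List[List[float]]:
    """Hypothetical one-layer softmax model: draw S dropout masks on the inputs.

    mean_w[k][c] is the mean weight from input k to class c; alpha[k] is the keep
    probability of input k. Returns S vectors p(y | f(x; W^(s))).
    """
    n_classes = len(mean_w[0])
    samples = []
    for _ in range(n_samples):
        z = [1.0 if rng.random() < a else 0.0 for a in alpha]
        logits = [sum(z[k] * x[k] * mean_w[k][c] for k in range(len(x)))
                  for c in range(n_classes)]
        samples.append(softmax(logits))
    return samples


def mc_predictive(prob_samples: Sequence[Sequence[float]]) -> List[float]:
    """(1/S) * sum_s p(y | f(x; W^(s))), computed per class."""
    s = len(prob_samples)
    return [sum(col) / s for col in zip(*prob_samples)]


def predictive_entropy(probs: Sequence[float]) -> float:
    """Entropy (nats) of the averaged predictive distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)
```

Two confident samples that disagree, `[1.0, 0.0]` and `[0.0, 1.0]`, average to `[0.5, 0.5]` with the maximal two-class entropy \(\log 2\), whereas samples that agree keep the entropy low. Averaging before taking the entropy is what lets disagreement between weight samples show up as uncertainty.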