Learnable Bernoulli Dropout for Bayesian Deep Learning (2020)


Abstract

Proposes learnable Bernoulli dropout (LBD)

  • model agnostic
  • jointly optimized with the other parameters of the model


Benefits

  • robust prediction
  • uncertainty quantification


1. LBD (Learnable Bernoulli Dropout)

Notation

  • training data : \(\mathcal{D}=\left\{\left(\boldsymbol{x}_{i}, y_{i}\right)\right\}_{i=1}^{N}\)
  • NN : \(f(\boldsymbol{x} ; \boldsymbol{\theta})\)
    • \(L\) layers
  • objective function : \(\mathcal{L}(\boldsymbol{\theta} \mid \mathcal{D}) \approx \frac{N}{M} \sum_{i=1}^{M} \mathcal{E}\left(f\left(\boldsymbol{x}_{i} ; \boldsymbol{\theta}\right), y_{i}\right)+\mathcal{R}(\boldsymbol{\theta})\)
    • \(M\) : mini-batch size
  • \(j^{\text {th }}\) fully connected layer with \(K_{j}\) neurons
  • \(W_{j} \in \mathbb{R}^{K_{j-1} \times K_{j}}\) : weight matrix connecting layer \(j-1\) to \(j\)


Dropout

  • takes the output of each layer
  • multiplies it element-wise with a random variable \(\boldsymbol{z}_{j} \sim p\left(\boldsymbol{z}_{j}\right)\) (a minimal layer sketch follows after this list)
    • ex) \(p\left(\boldsymbol{z}_{j}\right)\) : \(\operatorname{Ber}\left(\sigma\left(\alpha_{j}\right)\right)\) … dropout rate = \(1-\sigma\left(\alpha_{j}\right)\)
    • \(\boldsymbol{\alpha}=\left\{\alpha_{j}\right\}_{j=1}^{L}\) : collection of all logits of the dropout parameters
    • \(\boldsymbol{z}=\left\{\boldsymbol{z}_{j}\right\}_{j=1}^{L}\) : collection of all dropout random variables
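A minimal sketch of such a layer, assuming PyTorch. The class name and the Concrete/Gumbel-style relaxation used here so that gradients flow into \(\alpha\) are illustrative assumptions, not necessarily the gradient estimator used in the paper.

```python
# Minimal sketch of a learnable Bernoulli dropout layer (PyTorch assumed).
# The relaxed-Bernoulli sampling below is an illustrative substitution for
# differentiability w.r.t. alpha, not the paper's exact estimator.
import torch
import torch.nn as nn

class LearnableBernoulliDropout(nn.Module):
    def __init__(self, num_features, init_logit=2.0, temperature=0.1):
        super().__init__()
        # alpha is the logit of the keep probability sigma(alpha)
        self.alpha = nn.Parameter(torch.full((num_features,), init_logit))
        self.temperature = temperature

    def forward(self, x):
        keep_prob = torch.sigmoid(self.alpha)          # keep rate = sigma(alpha)
        if self.training:
            # relaxed Bernoulli sample z in (0, 1), reparameterized so that
            # gradients w.r.t. alpha flow through the mask
            u = torch.rand_like(x)
            logits = (torch.log(keep_prob) - torch.log1p(-keep_prob)
                      + torch.log(u) - torch.log1p(-u))
            z = torch.sigmoid(logits / self.temperature)
        else:
            z = keep_prob   # expected mask; keep sampling instead for MC uncertainty
        return x * z        # element-wise multiplication with z_j
```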


Goal :

  • minimize the expected loss over the dropout realizations, jointly w.r.t. the other model parameters and the dropout logits \(\boldsymbol{\alpha}\) (a one-step training sketch follows below) …

    \(\min _{\boldsymbol{\theta}=\{\boldsymbol{\theta} \backslash \boldsymbol{\alpha}, \boldsymbol{\alpha}\}} \quad \mathbb{E}_{\boldsymbol{z} \sim \prod_{i=1}^{M} \operatorname{Ber}\left(\boldsymbol{z}_{i} ; \sigma(\boldsymbol{\alpha})\right)}[\mathcal{L}(\boldsymbol{\theta}, \boldsymbol{z} \mid \mathcal{D})]\)
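A hypothetical single training step for this objective, reusing the LearnableBernoulliDropout sketch above: one Monte Carlo dropout sample per mini-batch, with gradients reaching both the weights and \(\boldsymbol{\alpha}\). Architecture, sizes, and learning rate are illustrative; the KL regularizer of Section 2 is omitted here.

```python
# Hypothetical one-step sketch of the joint optimization of weights and
# dropout logits alpha (assumed PyTorch; names and sizes are illustrative).
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(784, 256), nn.ReLU(),
    LearnableBernoulliDropout(256),        # sketch defined above
    nn.Linear(256, 10),
)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)   # includes the alpha logits

def train_step(x, y, n_train, batch_size):
    opt.zero_grad()
    logits = model(x)                      # one dropout realization z ~ Ber(sigma(alpha))
    nll = nn.functional.cross_entropy(logits, y, reduction="sum")
    loss = (n_train / batch_size) * nll    # N/M scaling of the mini-batch loss
    loss.backward()                        # gradients reach both weights and alpha
    opt.step()
    return loss.item()
```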


Learn dropout rates for..

  • 1) supervised DNNs
  • 2) unsupervised VAEs


2. Variational Bayesian Inference with LBD

BNN (Bayesian Neural Net)

  • collection of weight matrix : \(\boldsymbol{W}=\left\{W_{j}\right\}_{j=1}^{L}\)

  • prior : \(p(W)\)
  • posterior : \(p(W \mid \mathcal{D})=\frac{p(\mathcal{D} \mid W) p(W)}{p(\mathcal{D})}\)
  • intractability of calculating \(p(\mathcal{D})\)
    • simple variational distn \(q_{\boldsymbol{\theta}}(W)\) to approximate posterior


propose LBD (Learnable Bernoulli Dropout) as variational approximation

  • let each neuron \(k\) in each layer “HAVE ITS OWN DROPOUT RATE, \(\alpha_{j k}\)”
  • each layer has..
    • mean weight matrix \(M_{j}\)
    • dropout parameters \(\boldsymbol{\alpha}_{j}=\left\{\alpha_{j k}\right\}_{k=1}^{K_{j-1}}\)
  • variational distn consists of…
    • \[\boldsymbol{\theta}=\left\{M_{j}, \boldsymbol{\alpha}_{j}\right\}_{j=1}^{L}.\]
  • \[q_{\boldsymbol{\theta}}(\boldsymbol{W})=\prod_{j=1}^{L} q_{\boldsymbol{\theta}}\left(W_{j}\right)\]
    • where \(q_{\boldsymbol{\theta}}\left(W_{j}\right)=M_{j}^{T} \operatorname{diag}\left(\operatorname{Ber}\left(\boldsymbol{\alpha}_{j}\right)\right)\)
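An illustrative sample from the \(q_{\boldsymbol{\theta}}\left(W_{j}\right)\) above: each Bernoulli draw \(z_{jk}\) decides whether the corresponding column of \(M_{j}\) is kept or zeroed. Shapes and names are assumptions, and the transpose convention is dropped for brevity.

```python
# Illustrative sample from q_theta(W_j): dropped columns of M_j become zero.
import torch

M_j = torch.randn(64, 128)        # mean weight matrix of layer j (illustrative shape)
alpha_j = torch.full((128,), 0.9) # per-neuron keep probabilities alpha_jk
z_j = torch.bernoulli(alpha_j)    # z_jk ~ Ber(alpha_jk)
W_j = M_j * z_j                   # equals M_j diag(z_j): column k is zeroed when z_jk = 0
```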


Objective Function

  • \(\mathcal{L}\left(\boldsymbol{\theta}=\left\{M_{j}, \boldsymbol{\alpha}_{j}\right\}_{j=1}^{L} \mid \mathcal{D}\right) =-\frac{N}{M} \sum_{i=1}^{M} \log p\left(y_{i} \mid f\left(\boldsymbol{x}_{i} ; \boldsymbol{W}_{i}\right)\right) +\operatorname{KL}\left(q_{\theta}(\boldsymbol{W}) \mid \mid p(\boldsymbol{W})\right)\).

    • \(p\left(y_{i} \mid f\left(\boldsymbol{x}_{i} ; \boldsymbol{W}_{i}\right)\right)\) : softmax (classification) / Gaussian (regression)
    • KL-term : regularization term
  • if we use a quantized zero-mean Gaussian prior with variance \(s^{2}\)…

    \(\rightarrow\) \(\operatorname{KL}\left(q_{\boldsymbol{\theta}}(\boldsymbol{W}) \mid \mid p(\boldsymbol{W})\right) \propto \sum_{j=1}^{L} \sum_{k=1}^{K_{j-1}}\left[\frac{\alpha_{j k}}{2 s^{2}} \mid \mid M_{j}[\cdot, k] \mid \mid ^{2}-\mathcal{H}\left(\alpha_{j k}\right)\right]\)

    • where \(M_{j}[\cdot, k]\) represents the \(k^{\text {th }}\) column of the mean weight matrix \(M_{j}\)
    • where \(\mathcal{H}\left(\alpha_{j k}\right)\) : entropy of a Bernoulli random variable with parameter \(\alpha_{j k}\)
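A sketch of this regularizer for a single layer: \(\alpha\)-weighted squared column norms of \(M_{j}\) minus the Bernoulli entropies. The function and tensor names are assumptions; \(s\) is the prior standard deviation.

```python
# Sketch of the per-layer KL regularizer above (names are illustrative).
import torch

def kl_regularizer(M_j, alpha_j, s=1.0, eps=1e-8):
    # alpha_jk / (2 s^2) * ||M_j[:, k]||^2 for every column k
    weight_term = (alpha_j / (2 * s ** 2)) * M_j.pow(2).sum(dim=0)
    # entropy H(alpha_jk) of a Bernoulli random variable with parameter alpha_jk
    entropy = -(alpha_j * torch.log(alpha_j + eps)
                + (1.0 - alpha_j) * torch.log(1.0 - alpha_j + eps))
    return (weight_term - entropy).sum()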


posterior predictive \(p(y \mid \boldsymbol{x}, \mathcal{D})\) for new data \(\boldsymbol{x}\) :

  • approximated by MC integration with \(S\) samples \(\boldsymbol{W}^{(s)} \sim q_{\boldsymbol{\theta}}(\boldsymbol{W})\) : \(\frac{1}{S} \sum_{s=1}^{S} p\left(y \mid f\left(\boldsymbol{x} ; \boldsymbol{W}^{(s)}\right)\right)\)
  • entropy of \(p(y \mid \boldsymbol{x}, \mathcal{D})\): measure of uncertainty
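A sketch of this MC estimate and its predictive entropy as an uncertainty measure, assuming a classification model whose learnable dropout layers keep sampling at test time (here simply left in training mode); the function name and arguments are illustrative.

```python
# Sketch of the MC posterior predictive and its entropy (PyTorch assumed).
import torch

def predict_with_uncertainty(model, x, S=20, eps=1e-8):
    model.train()                          # keep sampling W^(s) ~ q_theta(W)
    with torch.no_grad():
        probs = torch.stack([torch.softmax(model(x), dim=-1) for _ in range(S)])
    p_mean = probs.mean(dim=0)             # (1/S) sum_s p(y | f(x; W^(s)))
    entropy = -(p_mean * torch.log(p_mean + eps)).sum(dim=-1)   # uncertainty per input
    return p_mean, entropy
```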
