A Simple Baseline for Bayesian Uncertainty in Deep Learning (NeurIPS 2019)


Abstract

SWAG (= SWA-Gaussian) : a simple, scalable, general-purpose approach for…

  • (1) uncertainty representation
  • (2) calibration in deep learning


1. Introduction

Representing uncertainty in deep learning is crucial.

This paper…

  • uses the information contained in the SGD trajectory to efficiently approximate the posterior distribution over the weights of the neural network
  • finds that a Gaussian distribution fitted to the first two moments of the SGD iterates captures the local geometry of the posterior


2. Related Work

2-1. Bayesian Methods

MCMC

  • HMC (Hamiltonian Monte Carlo)
  • SGHMC (stochastic gradient HMC)
    • allows stochastic gradients to be used in Bayesian inference
    • (crucial both for scalability and for exploring a space of solutions that provide good generalization)
  • SGLD (stochastic gradient Langevin dynamics)


Variational Inference

  • Reparameterization Trick


Dropout Variational Inference

  • spike-and-slab variational distribution
  • dropout probabilities can also be optimized


Laplace Approximation

  • assumes a Gaussian posterior, \(\mathcal{N}\left(\theta^{*}, \mathcal{I}\left(\theta^{*}\right)^{-1}\right)\).
    • \(\mathcal{I}\left(\theta^{*}\right)^{-1}\) : inverse of the Fisher information matrix


2-2. SGD-Based Approximations

  • averaged SGD as an MCMC sampler


2-3. Methods for Calibration of DNNs

(SSDE) Ensembles of several networks


3. SWA-Gaussian for Bayesian Deep Learning

propose SWAG for Bayesian model averaging & uncertainty estimation


3-1. Stochastic Gradient Descent (SGD)

standard training of DNNs

\(\Delta \theta_{t}=-\eta_{t}\left(\frac{1}{B} \sum_{i=1}^{B} \nabla_{\theta} \log p\left(y_{i} \mid f_{\theta}\left(x_{i}\right)\right)-\frac{\nabla_{\theta} \log p(\theta)}{N}\right)\).

  • loss function : NLL & regularizer (see the sketch below)
    • NLL : \(-\sum_{i} \log p\left(y_{i} \mid f_{\theta}\left(x_{i}\right)\right)\)
    • regularizer : the log prior \(\log p(\theta)\) (weight decay corresponds to a Gaussian prior)
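A minimal numpy sketch of one such SGD step, using a linear softmax model as a stand-in for \(f_\theta\); the model, function names, and Gaussian prior variance are illustrative assumptions, not the paper's code.

```python
import numpy as np

def sgd_step(theta, x_batch, y_batch, lr, N, prior_var=1.0):
    """One SGD step on the minibatch NLL plus the (1/N)-scaled negative log-prior."""
    B = x_batch.shape[0]
    logits = x_batch @ theta                          # (B, C); linear softmax stand-in for f_theta
    probs = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs /= probs.sum(axis=1, keepdims=True)
    one_hot = np.eye(theta.shape[1])[y_batch]         # (B, C)
    grad_nll = x_batch.T @ (probs - one_hot) / B      # gradient of the minibatch NLL
    grad_prior = theta / (prior_var * N)              # gradient of -log p(theta) / N (Gaussian prior)
    return theta - lr * (grad_nll + grad_prior)
```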


3-2. Stochastic Weight Averaging (SWA)

main idea of SWA

  • run SGD with a constant learning rate schedule, starting from a pre-trained solution, and average the weights (see the sketch after this list)

  • \(\theta_{\text {SWA }}=\frac{1}{T} \sum_{i=1}^{T} \theta_{i}\).
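A minimal sketch of the running average \(\theta_{\text{SWA}}\); collecting one snapshot per epoch is an assumption here, following the usual SWA recipe.

```python
import numpy as np

def swa_update(theta_swa, theta_new, n_collected):
    """Incremental form of theta_SWA = (1/T) * sum_i theta_i."""
    return (theta_swa * n_collected + theta_new) / (n_collected + 1)

# usage (e.g., once per epoch during the constant-learning-rate phase):
# theta_swa = swa_update(theta_swa, current_weights, n); n += 1
```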


3-3. SWAG-Diagonal

SWAG-Diagonal uses a simple diagonal form for the covariance matrix.

Maintain a running average of the second uncentered moment of each weight, then compute the covariance at the end of training using the following standard identity (see the sketch after the list):

  • \(\overline{\theta^{2}}=\frac{1}{T} \sum_{i=1}^{T} \theta_{i}^{2}\) .
  • \(\Sigma_{\text {diag }}=\operatorname{diag}\left(\overline{\theta^{2}}-\theta_{\text {SWA }}^{2}\right)\).
  • Approximate posterior distribution : \(\mathcal{N}\left(\theta_{\mathrm{SWA}}, \Sigma_{\text{diag}}\right)\)
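A minimal sketch of SWAG-Diagonal, tracking the first and second uncentered moments and sampling from \(\mathcal{N}(\theta_{\text{SWA}}, \Sigma_{\text{diag}})\); the class name and the variance clipping are assumptions added for numerical safety.

```python
import numpy as np

class SWAGDiagonal:
    """Running first and second uncentered moments of the SGD iterates (flattened weights)."""
    def __init__(self, dim):
        self.mean = np.zeros(dim)       # running estimate of theta_SWA
        self.sq_mean = np.zeros(dim)    # running estimate of bar(theta^2)
        self.n = 0

    def collect(self, theta):
        self.mean = (self.mean * self.n + theta) / (self.n + 1)
        self.sq_mean = (self.sq_mean * self.n + theta ** 2) / (self.n + 1)
        self.n += 1

    def sample(self, rng=np.random):
        # Sigma_diag = diag(bar(theta^2) - theta_SWA^2); clip to avoid tiny negative values
        var = np.clip(self.sq_mean - self.mean ** 2, 1e-30, None)
        return self.mean + np.sqrt(var) * rng.randn(self.mean.shape[0])
```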


3-4. SWAG : Low Rank plus Diagonal Covariance Structure

full SWAG algorithm

  • diagonal covariance approximation : TOO restrictive
  • more flexible low-rank plus diagonal posterior approximation


sample covariance matrix of the SGD iterates :

  • sum of outer products

  • \(\Sigma=\frac{1}{T-1} \sum_{i=1}^{T}\left(\theta_{i}-\theta_{\mathrm{SWA}}\right)\left(\theta_{i}-\theta_{\mathrm{SWA}}\right)^{\top}\) ( rank = \(T\) )

    • but \(\theta_{\mathrm{SWA}}\) is not known during training, so substitute the running estimate \(\bar{\theta}_{i}\) :
  • \(\Sigma \approx \frac{1}{T-1} \sum_{i=1}^{T}\left(\theta_{i}-\bar{\theta}_{i}\right)\left(\theta_{i}-\bar{\theta}_{i}\right)^{\top}=\frac{1}{T-1} D D^{\top}\)

    • where \(D\) is the deviation matrix with columns \(D_{i}=\left(\theta_{i}-\bar{\theta}_{i}\right)\), and \(\bar{\theta}_{i}\) is the running estimate of the parameters’ mean obtained from the first \(i\) samples
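A minimal sketch of maintaining the deviation matrix: store \(D_i = \theta_i - \bar{\theta}_i\) and keep only the most recent \(K\) columns as \(\widehat{D}\) (the flattened-parameter assumption and helper name are illustrative).

```python
from collections import deque
import numpy as np

def update_deviations(deviations, theta, running_mean, K):
    """Append D_i = theta_i - theta_bar_i and keep only the last K deviation vectors."""
    deviations.append(theta - running_mean)   # running_mean = bar(theta)_i from the first i samples
    if len(deviations) > K:
        deviations.popleft()                  # D_hat retains the most recent K columns
    return deviations

# once populated, the d x K matrix D_hat is:
# D_hat = np.stack(list(deviations), axis=1)
```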


Combine (1) & (2)

  • (1) low-rank approximation, keeping only the last \(K\) columns of \(D\) as \(\widehat{D}\) : \(\Sigma_{\text {low-rank }}=\frac{1}{K-1} \cdot \widehat{D} \widehat{D}^{\top}\)
  • (2) diagonal approximation : \(\Sigma_{\text {diag }}=\operatorname{diag}\left(\overline{\theta^{2}}-\theta_{\text {SWA }}^{2}\right)\)

\(\rightarrow\) \(\mathcal{N}\left(\theta_{\text {SWA }}, \frac{1}{2} \cdot\left(\Sigma_{\text {diag }}+\Sigma_{\text {low-rank }}\right)\right)\).


To sample from the SWAG posterior, we use:

  • \(\tilde{\theta}=\theta_{\mathrm{SWA}}+\frac{1}{\sqrt{2}} \cdot \Sigma_{\mathrm{diag}}^{\frac{1}{2}} z_{1}+\frac{1}{\sqrt{2(K-1)}} \widehat{D} z_{2}\).

    where \(z_{1} \sim \mathcal{N}\left(0, I_{d}\right), z_{2} \sim \mathcal{N}\left(0, I_{K}\right)\).
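A minimal sketch of drawing a SWAG sample with the formula above, given the diagonal variances and \(\widehat{D}\) (argument names are assumptions):

```python
import numpy as np

def swag_sample(theta_swa, diag_var, D_hat, rng=np.random):
    """theta = theta_SWA + (1/sqrt(2)) Sigma_diag^{1/2} z1 + (1/sqrt(2(K-1))) D_hat z2."""
    d, K = D_hat.shape
    z1 = rng.randn(d)
    z2 = rng.randn(K)
    return (theta_swa
            + np.sqrt(diag_var) * z1 / np.sqrt(2.0)
            + D_hat @ z2 / np.sqrt(2.0 * (K - 1)))
```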


Full Algorithm

[Figure : pseudo-code for the full SWAG algorithm (moment updates, deviation storage, and sampling)]


3-5. Bayesian Model Averaging with SWAG

MAP (maximum a posteriori) training maximizes the log posterior

  • posterior : \(\log p(\theta \mid \mathcal{D})=\log p(\mathcal{D} \mid \theta)+\log p(\theta)+\text{const}\)
  • prior \(p(\theta)\) : acts as the regularization term in optimization


The Bayesian procedure “marginalizes” over the posterior of \(\theta\), rather than using a single point estimate

  • \(p\left(y_{*} \mid \mathcal{D}, x_{*}\right)=\int p\left(y_{*} \mid \theta, x_{*}\right) p(\theta \mid \mathcal{D}) d \theta\).

  • using MC sampling…

    \(p\left(y_{*} \mid \mathcal{D}, x_{*}\right) \approx \frac{1}{T} \sum_{t=1}^{T} p\left(y_{*} \mid \theta_{t}, x_{*}\right), \quad \theta_{t} \sim p(\theta \mid \mathcal{D})\).
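A minimal sketch of the Monte Carlo average over SWAG samples; `predict_fn` is an assumed callable returning class probabilities \(p(y_* \mid \theta, x_*)\). (In the full method, batch-normalization statistics are also recomputed for each sampled model, which this sketch omits.)

```python
import numpy as np

def bayesian_model_average(predict_fn, swag_samples, x_star):
    """p(y*|D, x*) ~ (1/T) * sum_t p(y*|theta_t, x*), with theta_t drawn from the SWAG posterior."""
    probs = [predict_fn(theta, x_star) for theta in swag_samples]
    return np.mean(probs, axis=0)
```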


Prior Choice

  • (typically) weight decay is used to regularize DNNs, which corresponds to a Gaussian prior over the weights

  • when SGD is used with momentum \(\rightarrow\) the weight decay instead acts as an implicit regularization
