A Simple Baseline for Bayesian Uncertainty in Deep Learning (NeurIPS 2019)


Abstract

SWAG (SWA-Gaussian) : a simple, scalable, general-purpose approach for…

  • (1) uncertainty representation
  • (2) calibration in deep learning


1. Introduction

Representing uncertainty is crucial.

This paper…

  • use the information contained in the SGD trajectory to efficiently approximate the posterior distribution over the weights of the neural network
  • find that a “Gaussian distribution” fitted to the first two moments of the SGD iterates captures the local geometry of the posterior!


2. Related Work

2-1. Bayesian Methods

MCMC

  • HMC (Hamiltonian Monte Carlo)
  • SGHMC (stochastic gradient HMC)
    • allows for stochastic gradients to be used in Bayesian Inference
    • (crucial for both scalability & exploring a space of solutions to provide good generalization)
  • SGLD (stochastic gradient Langevin dynamics)


Variational Inference

  • Reparameterization Trick


Dropout Variational Inference

  • spike and slab variational distribution
  • optimize dropout probabilities as well


Laplace Approximation

  • assume a Gaussian posterior, $\mathcal{N}(\theta, \mathcal{I}(\theta)^{-1})$
    • $\mathcal{I}(\theta)^{-1}$ : inverse of the Fisher information matrix


2-2. SGD-based Approximation

  • averaged SGD as an MCMC sampler


2-3. Methods for Calibration of DNNs

(SSDE) Ensembles of several networks


3. SWA-Gaussian for Bayesian Deep Learning

The paper proposes SWAG for Bayesian model averaging and uncertainty estimation.


3-1. Stochastic Gradient Descent (SGD)

standard training of DNNs

$\Delta\theta_t = -\eta_t\left(-\frac{1}{B}\sum_{i=1}^{B}\nabla_\theta \log p\big(y_i \mid f_\theta(x_i)\big) - \frac{\nabla_\theta \log p(\theta)}{N}\right)$

  • loss function : NLL + regularizer (a minimal code sketch of this update follows below)
    • NLL : $-\sum_i \log p(y_i \mid f_\theta(x_i))$
    • regularizer : $-\log p(\theta)$
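
A minimal PyTorch-style sketch of this update, assuming a classifier `model` and a data loader `loader` (both hypothetical); the weight-decay term stands in for the Gaussian-prior contribution $-\nabla_\theta \log p(\theta)/N$.

```python
import torch
import torch.nn.functional as F

# Minimal sketch (assumed `model` and `loader`): one epoch of SGD on the
# regularized NLL. weight_decay plays the role of the -log p(theta) regularizer.
def sgd_epoch(model, loader, lr=0.01, weight_decay=1e-4):
    opt = torch.optim.SGD(model.parameters(), lr=lr,
                          momentum=0.9, weight_decay=weight_decay)
    for x, y in loader:
        opt.zero_grad()
        loss = F.cross_entropy(model(x), y)  # (1/B) * sum_i -log p(y_i | f_theta(x_i))
        loss.backward()
        opt.step()                           # theta <- theta - lr * grad
```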


3-2. Stochastic Weight Averaging (SWA)

main idea of SWA

  • run SGD with a constant learning rate schedule,

    starting from a pre-trained solution & average the weights

  • $\theta_{\text{SWA}} = \frac{1}{T}\sum_{i=1}^{T}\theta_i$ (see the sketch below)
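
A minimal sketch of this average, assuming `snapshots` is a (hypothetical) list of model copies collected once per epoch while running constant-learning-rate SGD from the pre-trained solution:

```python
import torch
from torch.nn.utils import parameters_to_vector

# Minimal sketch (assumed `snapshots`: model copies saved during SGD training).
# theta_SWA is simply the average of the flattened weight vectors.
thetas = torch.stack([parameters_to_vector(m.parameters()).detach()
                      for m in snapshots])
theta_swa = thetas.mean(dim=0)  # theta_SWA = (1/T) * sum_i theta_i
```

In practice this is maintained as a running average during training rather than by storing all $T$ snapshots.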


3-3. SWAG-Diagonal

simple diagonal format for the covariance matrix.

maintain a running average of the 2nd uncentered moment for each weight,

then compute the covariance using the following standard identity at the end of training:

  • $\overline{\theta^2} = \frac{1}{T}\sum_{i=1}^{T}\theta_i^2$
  • $\Sigma_{\text{diag}} = \mathrm{diag}\big(\overline{\theta^2} - \theta_{\text{SWA}}^2\big)$
  • Approximate posterior distribution : $\mathcal{N}(\theta_{\text{SWA}}, \Sigma_{\text{diag}})$ (see the sketch below)
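
A minimal SWAG-Diagonal sketch under the same assumed `snapshots` list (the paper maintains the two moments as running averages during training; here they are computed from stored iterates for clarity):

```python
import torch
from torch.nn.utils import parameters_to_vector

# Minimal sketch (assumed `snapshots`: the T SGD iterates theta_1..theta_T).
thetas = torch.stack([parameters_to_vector(m.parameters()).detach()
                      for m in snapshots])

theta_swa  = thetas.mean(dim=0)          # first moment:  (1/T) sum_i theta_i
second_mom = (thetas ** 2).mean(dim=0)   # second moment: (1/T) sum_i theta_i^2
sigma_diag = (second_mom - theta_swa ** 2).clamp(min=1e-30)

# Draw one sample from the diagonal approximation N(theta_SWA, Sigma_diag)
sample = theta_swa + sigma_diag.sqrt() * torch.randn_like(theta_swa)
```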


3-4. SWAG : Low-Rank plus Diagonal Covariance Structure

full SWAG algorithm

  • diagonal covariance approximation : TOO restrictive
  • more flexible low-rank plus diagonal posterior approximation


sample covariance matrix of the SGD iterates :

  • sum of outer products

  • $\Sigma = \frac{1}{T-1}\sum_{i=1}^{T}(\theta_i - \theta_{\text{SWA}})(\theta_i - \theta_{\text{SWA}})^{\top}$ ( rank = $T$ )

    • but don’t know θSWA during training
  • $\Sigma \approx \frac{1}{T-1}\sum_{i=1}^{T}(\theta_i - \bar{\theta}_i)(\theta_i - \bar{\theta}_i)^{\top} = \frac{1}{T-1} D D^{\top}$

    • where $D$ is the deviation matrix comprised of columns $D_i = (\theta_i - \bar{\theta}_i)$,

      and $\bar{\theta}_i$ is the running estimate of the parameters’ mean obtained from the first $i$ samples (see the maintenance sketch below)
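
A minimal sketch of how the last $K$ deviation columns could be maintained during training (the names `collect_snapshot`, `running_mean`, and `dev_cols` are hypothetical, not from the paper):

```python
import torch
from torch.nn.utils import parameters_to_vector

K = 20               # maximum number of columns kept (rank of the low-rank term)
running_mean = None  # running estimate theta_bar_i of the parameters' mean
dev_cols = []        # columns D_i = theta_i - theta_bar_i (only the last K kept)
n = 0

def collect_snapshot(model):
    """Update the running mean and append one deviation column per snapshot."""
    global running_mean, n
    theta = parameters_to_vector(model.parameters()).detach()
    n += 1
    if running_mean is None:
        running_mean = theta.clone()
    else:
        running_mean += (theta - running_mean) / n
    dev_cols.append(theta - running_mean)
    if len(dev_cols) > K:
        dev_cols.pop(0)  # keep only the last K deviation columns
```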


Combine (1) & (2)

  • (1) low-rank approximation : $\Sigma_{\text{low-rank}} = \frac{1}{K-1}\hat{D}\hat{D}^{\top}$, where $\hat{D}$ keeps only the last $K$ columns of $D$
  • (2) diagonal approximation : $\Sigma_{\text{diag}} = \mathrm{diag}\big(\overline{\theta^2} - \theta_{\text{SWA}}^2\big)$

$\mathcal{N}\!\big(\theta_{\text{SWA}},\ \tfrac{1}{2}(\Sigma_{\text{diag}} + \Sigma_{\text{low-rank}})\big)$


To sample from SWAG, we use….

  • $\tilde{\theta} = \theta_{\text{SWA}} + \frac{1}{\sqrt{2}}\,\Sigma_{\text{diag}}^{1/2}\, z_1 + \frac{1}{\sqrt{2(K-1)}}\,\hat{D}\, z_2$

    where $z_1 \sim \mathcal{N}(0, I_d)$, $z_2 \sim \mathcal{N}(0, I_K)$ (see the sampling sketch below)
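
A minimal sampling sketch, reusing `theta_swa`, `sigma_diag`, and `dev_cols` from the sketches above (all hypothetical helper names):

```python
import torch

def draw_swag_sample():
    """Draw one weight sample theta_tilde from the SWAG posterior approximation."""
    D_hat = torch.stack(dev_cols, dim=1)   # deviation matrix, shape (num_params, K)
    K = D_hat.shape[1]
    z1 = torch.randn_like(theta_swa)       # z1 ~ N(0, I_d)
    z2 = torch.randn(K)                    # z2 ~ N(0, I_K)
    return (theta_swa
            + sigma_diag.sqrt() * z1 / (2 ** 0.5)       # (1/sqrt(2)) Sigma_diag^{1/2} z1
            + (D_hat @ z2) / ((2 * (K - 1)) ** 0.5))    # (1/sqrt(2(K-1))) D_hat z2
```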


Full Algorithm

(Figure 2: the full SWAG algorithm.)


3-5. Bayesian Model Averaging with SWAG

MAP (maximum a posteriori) estimation

  • posterior : $\log p(\theta \mid \mathcal{D}) = \log p(\mathcal{D} \mid \theta) + \log p(\theta) + \text{const}$
  • prior $p(\theta)$ : acts as regularization in optimization


The Bayesian procedure instead “marginalizes” over the posterior of $\theta$ :

  • $p(y \mid \mathcal{D}, x) = \int p(y \mid \theta, x)\, p(\theta \mid \mathcal{D})\, d\theta$

  • using MC sampling… (see the sketch below)

    $p(y \mid \mathcal{D}, x) \approx \frac{1}{T}\sum_{t=1}^{T} p(y \mid \theta_t, x), \quad \theta_t \sim p(\theta \mid \mathcal{D})$
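
A minimal sketch of this Monte Carlo average over SWAG samples, assuming the hypothetical `draw_swag_sample()` from above and a test batch `x`:

```python
import torch
import torch.nn.functional as F
from torch.nn.utils import vector_to_parameters

def swag_predict(model, x, num_samples=30):
    """Average the predictive distribution over sampled weights theta_t."""
    probs = 0.0
    for _ in range(num_samples):
        vector_to_parameters(draw_swag_sample(), model.parameters())
        # NOTE: the paper also recomputes batch-norm statistics for each sample;
        # that step is omitted here for brevity.
        with torch.no_grad():
            probs = probs + F.softmax(model(x), dim=-1)
    return probs / num_samples  # approx. p(y | D, x) = (1/T) sum_t p(y | theta_t, x)
```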


Prior Choice

  • (typically) weight decay is used to regularize DNNs

  • when SGD is used with momentum, weight decay acts as implicit regularization
