A Simple Baseline for Bayesian Uncertainty in Deep Learning (NeurIPS 2019)
Abstract
SWAG (= SWA-Gaussian): a simple, scalable, general-purpose approach for…
- (1) uncertainty representation
- (2) calibration in deep learning
1. Introduction
Representing uncertainty is crucial.
This paper…
- uses the information contained in the SGD trajectory to efficiently approximate the posterior distribution over the weights of a neural network
- finds that a Gaussian distribution fitted to the first two moments of the SGD iterates captures the local geometry of the posterior!
2. Related Work
2-1. Bayesian Methods
MCMC
- HMC (Hamiltonian Monte Carlo)
- SGHMC (stochastic gradient HMC)
- allows for stochastic gradients to be used in Bayesian Inference
- (crucial for both scalability & exploring a space of solutions to provide good generalization)
- SGLD (stochastic gradient Langevin dynamics)
Variational Inference
- Reparameterization Trick
Dropout Variational Inference
- spike and slab variational distribution
- optimize dropout probabilities as well
Laplace Approximation
- assume a Gaussian posterior, $\mathcal{N}\big(\theta^*, \mathcal{I}(\theta^*)^{-1}\big)$
- $\mathcal{I}(\theta^*)^{-1}$ : inverse of the Fisher information matrix
2-2. SGD-based Approximation
- averaged SGD as an MCMC sampler
2-3. Methods for Calibration of DNNs
(SSDE) Ensembles of several networks
3. SWA-Gaussian for Bayesian Deep Learning
propose SWAG for Bayesian model averaging & uncertainty estimation
3-1. Stochastic Gradient Descent (SGD)
standard training of DNNs
$$\Delta \theta_t = -\eta_t \left( \frac{1}{B} \sum_{i=1}^{B} \nabla_\theta \log p\big(y_i \mid f_\theta(x_i)\big) - \frac{\nabla_\theta \log p(\theta)}{N} \right)$$
- loss function: NLL & regularizer (see the sketch below)
- NLL: $-\sum_i \log p\big(y_i \mid f_\theta(x_i)\big)$
- regularizer: $\log p(\theta)$
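A minimal Python sketch of this per-minibatch loss (my own illustration, not the paper's code); `log_lik` and `log_prior` are hypothetical helpers standing in for the network's log-likelihood and the log-prior.

```python
def minibatch_loss(theta, batch, log_lik, log_prior, N):
    """Minibatch estimate of the training loss: NLL plus the regularizer from the prior.

    theta     : current parameter vector
    batch     : list of (x_i, y_i) pairs (size B)
    log_lik   : hypothetical helper, (theta, x, y) -> log p(y | f_theta(x))
    log_prior : hypothetical helper, theta -> log p(theta)
    N         : total number of training examples
    """
    B = len(batch)
    nll = -sum(log_lik(theta, x, y) for x, y in batch) / B  # data term (negative log-likelihood)
    reg = -log_prior(theta) / N                             # prior term, scaled by 1/N as in the update above
    return nll + reg
```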
3-2. Stochastic Weight Averaging (SWA)
main idea of SWA
- run SGD with a constant learning rate schedule, starting from a pre-trained solution, and average the weights (see the sketch below)
- $\theta_{\text{SWA}} = \frac{1}{T} \sum_{i=1}^{T} \theta_i$
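A minimal NumPy sketch of the averaging step, assuming a hypothetical `sgd_epoch` helper that runs one epoch (or snapshot interval) of constant-learning-rate SGD:

```python
import numpy as np

def run_swa(theta0, sgd_epoch, T):
    """Average the SGD iterates theta_1..theta_T collected at a constant learning rate.

    theta0    : flat parameter vector of the pre-trained solution
    sgd_epoch : hypothetical helper, theta -> theta after one epoch of constant-LR SGD
    T         : number of snapshots to average
    """
    theta = theta0.copy()
    theta_swa = np.zeros_like(theta0)
    for i in range(1, T + 1):
        theta = sgd_epoch(theta)              # one epoch with a constant learning rate
        theta_swa += (theta - theta_swa) / i  # running mean of the iterates
    return theta_swa
```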
3-3. SWAG-Diagonal
simple diagonal form for the covariance matrix:
maintain a running average of the 2nd uncentered moment of each weight,
then compute the covariance using the following standard identity at the end of training (see the sketch below):
- $\overline{\theta^2} = \frac{1}{T} \sum_{i=1}^{T} \theta_i^2$
- $\Sigma_{\text{diag}} = \mathrm{diag}\big(\overline{\theta^2} - \theta_{\text{SWA}}^2\big)$
- Approximate posterior distribution: $\mathcal{N}\big(\theta_{\text{SWA}}, \Sigma_{\text{diag}}\big)$
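A sketch of SWAG-Diagonal under the same assumptions (NumPy, illustrative only): track the running first and second uncentered moments of the iterates, then form the diagonal covariance at the end.

```python
import numpy as np

class SwagDiagonal:
    """Running first/second moments of the SGD iterates -> diagonal Gaussian posterior."""

    def __init__(self, dim):
        self.n = 0
        self.mean = np.zeros(dim)      # running average of theta      (theta_SWA at the end)
        self.sq_mean = np.zeros(dim)   # running average of theta**2   (second uncentered moment)

    def collect(self, theta):
        """Call once per snapshot of constant-LR SGD."""
        self.n += 1
        self.mean += (theta - self.mean) / self.n
        self.sq_mean += (theta**2 - self.sq_mean) / self.n

    def posterior(self):
        # Sigma_diag = diag( E[theta^2] - theta_SWA^2 ), clipped at 0 for numerical safety
        var = np.clip(self.sq_mean - self.mean**2, 0.0, None)
        return self.mean, var          # mean and diagonal of the covariance

    def sample(self, rng=np.random):
        mean, var = self.posterior()
        return mean + np.sqrt(var) * rng.standard_normal(mean.shape)
```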
3-4. SWAG: Low-Rank plus Diagonal Covariance Structure
full SWAG algorithm
- diagonal covariance approximation: too restrictive
- instead, a more flexible low-rank plus diagonal posterior approximation
sample covariance matrix of the SGD iterates:
- sum of outer products:
  $$\Sigma = \frac{1}{T-1} \sum_{i=1}^{T} (\theta_i - \theta_{\text{SWA}})(\theta_i - \theta_{\text{SWA}})^\top \qquad (\text{rank } T)$$
- but $\theta_{\text{SWA}}$ is not known during training, so approximate with the running mean:
  $$\Sigma \approx \frac{1}{T-1} \sum_{i=1}^{T} (\theta_i - \bar{\theta}_i)(\theta_i - \bar{\theta}_i)^\top = \frac{1}{T-1} D D^\top$$
  where $D$ is the deviation matrix with columns $D_i = (\theta_i - \bar{\theta}_i)$,
  and $\bar{\theta}_i$ is the running estimate of the parameters' mean obtained from the first $i$ samples
Combine (1) & (2):
- (1) low-rank approximation: $\Sigma_{\text{low-rank}} = \frac{1}{K-1} \hat{D} \hat{D}^\top$, where $\hat{D}$ keeps only the last $K$ columns of $D$
- (2) diagonal approximation: $\Sigma_{\text{diag}} = \mathrm{diag}\big(\overline{\theta^2} - \theta_{\text{SWA}}^2\big)$

→ approximate posterior: $\mathcal{N}\Big(\theta_{\text{SWA}},\ \tfrac{1}{2}\big(\Sigma_{\text{diag}} + \Sigma_{\text{low-rank}}\big)\Big)$
To sample from SWAG, we use:

$$\tilde{\theta} = \theta_{\text{SWA}} + \frac{1}{\sqrt{2}}\, \Sigma_{\text{diag}}^{1/2}\, z_1 + \frac{1}{\sqrt{2(K-1)}}\, \hat{D}\, z_2, \qquad z_1 \sim \mathcal{N}(0, I_d),\ z_2 \sim \mathcal{N}(0, I_K)$$
Full Algorithm
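This is not the authors' released implementation, just a compact NumPy sketch of the full procedure under the assumptions above: collect running moments and the last K deviation columns during constant-learning-rate SGD, then draw samples with the low-rank plus diagonal formula.

```python
import numpy as np

class Swag:
    """Low-rank plus diagonal SWAG posterior over a flat parameter vector."""

    def __init__(self, dim, max_rank):
        self.n = 0
        self.mean = np.zeros(dim)      # running mean of the iterates (theta_SWA at the end)
        self.sq_mean = np.zeros(dim)   # running second uncentered moment
        self.deviations = []           # columns D_i = theta_i - running mean (keep last K)
        self.max_rank = max_rank       # K

    def collect(self, theta):
        """Call once per snapshot (e.g. once per epoch of constant-LR SGD)."""
        self.n += 1
        self.mean += (theta - self.mean) / self.n
        self.sq_mean += (theta**2 - self.sq_mean) / self.n
        self.deviations.append(theta - self.mean)            # deviation from the running mean
        self.deviations = self.deviations[-self.max_rank:]   # keep only the last K columns

    def sample(self, rng=np.random):
        """Draw theta ~ N(theta_SWA, 0.5 * (Sigma_diag + Sigma_low-rank)); assumes K >= 2."""
        var = np.clip(self.sq_mean - self.mean**2, 0.0, None)
        D = np.stack(self.deviations, axis=1)                 # d x K deviation matrix D_hat
        K = D.shape[1]
        z1 = rng.standard_normal(self.mean.shape)             # z1 ~ N(0, I_d)
        z2 = rng.standard_normal(K)                           # z2 ~ N(0, I_K)
        diag_part = np.sqrt(var) * z1 / np.sqrt(2.0)
        low_rank_part = D @ z2 / np.sqrt(2.0 * (K - 1))
        return self.mean + diag_part + low_rank_part
```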
3-5. Bayesian Model Averaging with SWAG
MAP
- posterior: $\log p(\theta \mid \mathcal{D}) = \log p(\mathcal{D} \mid \theta) + \log p(\theta) + \text{const}$
- prior $p(\theta)$: plays the role of regularization in optimization
The Bayesian procedure “marginalizes” the posterior over $\theta$:

$$p(y^* \mid \mathcal{D}, x^*) = \int p(y^* \mid \theta, x^*)\, p(\theta \mid \mathcal{D})\, d\theta$$

- using MC sampling (see the sketch below):

$$p(y^* \mid \mathcal{D}, x^*) \approx \frac{1}{T} \sum_{t=1}^{T} p(y^* \mid \theta_t, x^*), \qquad \theta_t \sim p(\theta \mid \mathcal{D})$$
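A sketch of the Monte Carlo average over SWAG samples; `predict_proba` is a hypothetical helper that loads a weight vector into the network and returns the predictive distribution for an input:

```python
import numpy as np

def bayesian_model_average(swag, predict_proba, x_star, num_samples=30):
    """Approximate p(y* | D, x*) by averaging predictions over posterior weight samples.

    swag          : fitted Swag object from the sketch above
    predict_proba : hypothetical helper, (theta, x) -> predictive distribution p(y | theta, x)
    x_star        : test input
    num_samples   : number of Monte Carlo samples T
    """
    probs = [predict_proba(swag.sample(), x_star) for _ in range(num_samples)]
    return np.mean(probs, axis=0)   # (1/T) * sum_t p(y* | theta_t, x*)
```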
Prior Choice
- (typically) weight decay is used to regularize DNNs (see the note below)
- when SGD is used with momentum → implicit regularization
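As a side note (the standard correspondence, not something the post derives): with weight-decay coefficient $\lambda$, L2 regularization matches a zero-mean Gaussian prior on the weights,

$$p(\theta) = \mathcal{N}\big(0, \lambda^{-1} I\big) \;\Longrightarrow\; -\log p(\theta) = \frac{\lambda}{2}\, \|\theta\|_2^2 + \text{const},$$

so the weight-decay term plays the role of the log-prior regularizer in the loss of Section 3-1.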