Averaging Weights Leads to Wider Optima and Better Generalization (2018)


Abstract

DNNs : trained by optimizing a loss function with an SGD variant

Shows that simple averaging of multiple points along the trajectory of SGD, with a cyclical or constant learning rate, leads to better generalization.

Propose SWA (Stochastic Weight Averaging)


1. Introduction

better understanding of the loss surface \(\rightarrow\) accelerate convergence

local optima found by SGD can be connected by simple curves of near-constant (low) loss!

  • Fast Geometric Ensembling (FGE) : samples multiple nearby points in weight space and ensembles their predictions

  • propose Stochastic Weight Averaging (SWA)


Contribution

  • SGD with cyclical & constant learning rates traverses regions of weight space corresponding to high-performance networks
  • FGE ensembles :
    • can be trained in the same time as a single model
    • test predictions for an ensemble of \(k\) models require \(k\) times more computation
  • SWA leads to solutions that are wider than the optima found by SGD


2. Related Work

Better understanding of the geometry of loss surfaces & generalization in DL

SWA is related to work in both (1) optimization & (2) regularization

  • SGD is more likely to converge to broad local optima than batch gradient descent
    • broad optima are more likely to have good test performance
  • SWA is based on averaging multiple points
  • to enforce exploration, run SGD with “constant or cyclical learning rates”


FGE (Fast Geometric Ensembling)

  • using a cyclical learning rate, it is possible to gather models that are spatially close to each other, but produce diverse predictions


Dropout

  • an approach to regularizing DNNs
  • a different (sub-)architecture is created on each forward pass by randomly dropping units
  • dropout can be interpreted as ensembling / approximate Bayesian model averaging


3. Stochastic Weight Averaging

Section 3-1 : trajectories of SGD with a constant and cyclical learning rate

Section 3-2 : SWA algorithm

Section 3-3 : complexity

Section 3-4 : widths of solutions ( SWA vs SGD )


SWA has 2 meanings

  • 1) it is an average of SGD weights
  • 2) with a cyclical or constant learning rate, the SGD proposals are approximately sampling from the loss surface of the DNN, leading to stochastic weights


3-1. Analysis of SGD Trajectories

Cyclical learning rate schedule

  • within each cycle of \(c\) iterations, linearly decrease the learning rate from \(\alpha_1\) to \(\alpha_2\)

    \(\begin{array}{l} \alpha(i)=(1-t(i)) \alpha_{1}+t(i) \alpha_{2} \\ t(i)=\frac{1}{c}(\bmod (i-1, c)+1) \end{array}\).

For even greater exploration, we consider “constant learning rate” \(\alpha(i)=\alpha_1\)
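A minimal Python sketch of the two schedules (the learning-rate values and cycle length below are illustrative, not the paper's):

```python
def cyclical_lr(i, alpha_1, alpha_2, c):
    """Learning rate at iteration i (1-indexed): linearly annealed
    from alpha_1 down to alpha_2 within each cycle of length c."""
    t = ((i - 1) % c + 1) / c          # t = 1/c, 2/c, ..., 1 within a cycle
    return (1 - t) * alpha_1 + t * alpha_2

def constant_lr(i, alpha_1):
    """Constant schedule: alpha(i) = alpha_1 for every iteration."""
    return alpha_1

# Illustrative values: with c = 5 the rate restarts near alpha_1 every 5 iterations.
print([round(cyclical_lr(i, alpha_1=0.05, alpha_2=0.001, c=5), 4) for i in range(1, 11)])
```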

Both schedules explore the region of weight space corresponding to high-accuracy networks.


Main Difference :

individual proposals of SGD with a cyclical learning rate schedule are more accurate than those of a fixed (constant) learning rate schedule

  • (cyclical) spends several epochs fine-tuning after each large step \(\rightarrow\) more accurate individual proposals
  • (fixed) always makes steps of relatively large size \(\rightarrow\) explores more efficiently


3-2. SWA Algorithm

(1) start with a pretrained model \(\hat{w}\)

  • can be trained with conventional training procedure

(2) or stop that conventional training early (reduced budget)

  • without modifying the lr schedule

(3) starting from \(\hat{w}\), continue training

  • using a cyclical or constant lr schedule
  • average the weights captured at the end of each cycle (cyclical lr) or of each epoch (constant lr) to obtain \(w_{\mathrm{SWA}}\)
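A minimal, self-contained sketch of this procedure on a toy least-squares problem with a constant learning rate (the data, model, and hyperparameters are illustrative stand-ins, not the paper's setup):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy least-squares problem standing in for DNN training (illustrative only).
X = rng.normal(size=(200, 10))
y = X @ rng.normal(size=10) + 0.1 * rng.normal(size=200)

def sgd_step(w, lr, xb, yb):
    """One SGD step on the mini-batch squared error."""
    grad = 2 * xb.T @ (xb @ w - yb) / len(yb)
    return w - lr * grad

def minibatch():
    idx = rng.choice(200, size=10)
    return X[idx], y[idx]

# (1)-(2) "Pretrained" weights w_hat from a conventional (here: shortened) SGD run.
w = np.zeros(10)
for _ in range(100):
    w = sgd_step(w, 0.05, *minibatch())
w_hat = w.copy()

# (3) Continue from w_hat with a constant learning rate; average the weights
#     captured at the end of each epoch (each cycle, if the schedule is cyclical).
w, w_swa, n_models = w_hat.copy(), w_hat.copy(), 1
for epoch in range(25):
    for _ in range(20):                                 # one "epoch"
        w = sgd_step(w, 0.05, *minibatch())
    w_swa = (w_swa * n_models + w) / (n_models + 1)     # running average
    n_models += 1

print("loss(SGD) =", np.mean((X @ w - y) ** 2))
print("loss(SWA) =", np.mean((X @ w_swa - y) ** 2))
```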


SWA is related to FGE, except …

  • FGE : averaging the predictions of the models
  • SWA : average their weights


(See Figure 2 of the paper.)


3-3. Computational Complexity

During Training

  • need to maintain a copy of the running average of DNN weights

After Training

  • only need to store the model that aggregates the average

    ( = same memory requirements as standard training )


Extra time is spent only on updating the aggregated weight average :

  • \(w_{\mathrm{SWA}} \leftarrow \frac{w_{\mathrm{SWA}} \cdot n_{\mathrm{models}}+w}{n_{\mathrm{models}}+1}\).
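A sketch of that update as a standalone helper (it applies elementwise, so it works for any weight array):

```python
import numpy as np

def update_swa(w_swa: np.ndarray, w: np.ndarray, n_models: int) -> np.ndarray:
    """Running average over the weights collected so far:
    w_swa <- (w_swa * n_models + w) / (n_models + 1)."""
    return (w_swa * n_models + w) / (n_models + 1)
```

This runs once per cycle (or once per epoch with a constant learning rate), so its cost is negligible next to the SGD steps themselves; for real DNNs, PyTorch ships similar utilities in torch.optim.swa_utils (e.g. AveragedModel).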


3-4. Solution Width

Width of a local optimum is related to generalization

For a random direction \(d\) of unit length and step size \(t\), evaluate the loss along rays from each solution:

\(\begin{array}{l} w_{\mathrm{SWA}}(t, d)=w_{\mathrm{SWA}}+t \cdot d \\ w_{\mathrm{SGD}}(t, d)=w_{\mathrm{SGD}}+t \cdot d \end{array}\).


Line Segment connecting those two

  • \(w(t)=t \cdot w_{\mathrm{SGD}}+(1-t) \cdot w_{\mathrm{SWA}}\).
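A sketch of how these quantities can be probed numerically, with a toy quadratic loss and random stand-ins for \(w_{\mathrm{SWA}}\) and \(w_{\mathrm{SGD}}\) (everything below is illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 10

# Stand-ins for the two solutions (illustrative only).
w_swa = rng.normal(size=dim)
w_sgd = w_swa + 0.5 * rng.normal(size=dim)

def loss(w):
    """Toy loss whose minimum sits at w_sgd (illustrative only)."""
    return np.sum((w - w_sgd) ** 2)

# Loss along a random ray of unit length from each solution (width probe).
d = rng.normal(size=dim)
d /= np.linalg.norm(d)
for t in (0.0, 1.0, 2.0):
    print(f"t={t}:  loss(w_swa + t*d)={loss(w_swa + t * d):.3f}  "
          f"loss(w_sgd + t*d)={loss(w_sgd + t * d):.3f}")

# Loss along the segment w(t) = t * w_sgd + (1 - t) * w_swa.
for t in np.linspace(0.0, 1.0, 5):
    print(f"t={t:.2f}:  loss(w(t))={loss(t * w_sgd + (1 - t) * w_swa):.3f}")
```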


3-5. Connection to Ensembling

FGE (Fast Geometric Ensembling)

  • training ensembles in the time required to train a single model

  • using a cyclical l.r., it generates a sequence of points

    that are close in weight space, but produce different predictions


SWA

  • instead of averaging the predictions, SWA averages the weights


Similarities between FGE & SWA

FGE

  • \(\bar{f}=\frac{1}{n} \sum_{i=1}^{n} f\left(w_{i}\right)\).


SWA

  • average \(w_{\mathrm{SWA}}=\frac{1}{n} \sum_{i=1}^{n} w_{i}\).
  • \(\Delta_{i}=w_{i}-w_{\mathrm{SWA}}\).
  • linearization of \(f\) (the network's prediction as a function of the weights) at \(w_{\mathrm{SWA}}\) : \(f\left(w_{j}\right)=f\left(w_{\mathrm{SWA}}\right)+\left\langle\nabla f\left(w_{\mathrm{SWA}}\right), \Delta_{j}\right\rangle+O\left(\left\|\Delta_{j}\right\|^{2}\right)\)


Difference between averaging the weights & averaging the predictions

\(\begin{array}{c} \bar{f}-f\left(w_{\mathrm{SWA}}\right)=\frac{1}{n} \sum_{i=1}^{n}\left(\left\langle\nabla f\left(w_{\mathrm{SWA}}\right), \Delta_{i}\right\rangle+O\left(\left\|\Delta_{i}\right\|^{2}\right)\right) \\ =\left\langle\nabla f\left(w_{\mathrm{SWA}}\right), \frac{1}{n} \sum_{i=1}^{n} \Delta_{i}\right\rangle+O\left(\Delta^{2}\right)=O\left(\Delta^{2}\right), \end{array}\)

where \(\Delta=\max _{i}\left\|\Delta_{i}\right\|\); the first-order term vanishes because \(\frac{1}{n} \sum_{i=1}^{n} \Delta_{i}=\frac{1}{n} \sum_{i=1}^{n} w_{i}-w_{\mathrm{SWA}}=0\).


Difference between the predictions of different perturbed networks :

  • \(f\left(w_{i}\right)-f\left(w_{j}\right)=\left\langle\nabla f\left(w_{\mathrm{SWA}}\right), \Delta_{i}-\Delta_{j}\right\rangle+O\left(\Delta^{2}\right)\).
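A quick numerical illustration of these two orders (the function \(f\), the perturbations \(\Delta_i\), and their scale are made up for the check):

```python
import numpy as np

rng = np.random.default_rng(1)
dim, n = 10, 8

def f(w):
    """Any smooth scalar 'prediction' as a function of the weights (illustrative)."""
    return np.tanh(w).sum()

w_swa = rng.normal(size=dim)
deltas = 0.01 * rng.normal(size=(n, dim))   # small perturbations Delta_i
deltas -= deltas.mean(axis=0)               # so w_swa is exactly the mean of the w_i
ws = w_swa + deltas                         # "ensemble members" w_i

f_bar = np.mean([f(w) for w in ws])
print("|f_bar - f(w_swa)| =", abs(f_bar - f(w_swa)))      # O(Delta^2): tiny
print("|f(w_1) - f(w_2)|  =", abs(f(ws[0]) - f(ws[1])))   # O(Delta): much larger
```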


conclusion : the SWA prediction differs from the FGE ensemble prediction only by \(O(\Delta^{2})\), i.e., SWA can approximate the FGE ensemble with a single model!
