On Modern Deep Learning and Variational Inference (2015)


Abstract

Goal : make results from the Bayesian statistics literature ( in particular variational inference ) usable with deep learning (DL) models


1. Introduction

Stochastic Regularization Techniques, SRT ( e.g. dropout, dropconnect, multiplicative Gaussian noise… )

ex) Dropout

  • each unit is multiplied by a Bernoulli r.v.
  • slows down training, but circumvents over-fitting

  • dropout in deep NNs = variational approximation to deep GP


Modern DL tools make use of stochastic regularization

  • this means that they perform approximate Bayesian inference,

    capturing the stochastic processes underlying the observed data


2. SRT in Deep Networks and GP

( \(L_2\) regularization ) Minimization objective :

\(\mathcal{L}_{\text{dropout}}:=\frac{1}{N} \sum_{i=1}^{N} E\left(\mathbf{y}_{i}, \widehat{\mathbf{y}}_{i}\right)+\lambda \sum_{i=1}^{L}\left(\left\|\mathbf{W}_{i}\right\|_{2}^{2}+\left\|\mathbf{b}_{i}\right\|_{2}^{2}\right)\).

  • with dropout ) we sample a Bernoulli r.v. for every input point and every network unit
  • with multiplicative Gaussian noise ) we multiply each unit by a draw from \(N(1,1)\)
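
The objective and the two noise choices above can be sketched for a network with one hidden layer. This is only an illustration under stated assumptions : the ReLU non-linearity, the squared-error choice for \(E\), the function name `dropout_objective`, and the reading of `p` as the probability of keeping a unit are not given in these notes.

```python
import numpy as np

def dropout_objective(X, Y, W1, b1, W2, b2, lam, p=0.5, noise="bernoulli", rng=None):
    """One stochastic evaluation of the L2-regularized objective (one hidden layer)."""
    rng = np.random.default_rng() if rng is None else rng
    H = np.maximum(X @ W1 + b1, 0.0)                # hidden units (ReLU is an assumption)
    if noise == "bernoulli":
        # dropout: one Bernoulli draw per input point and per hidden unit
        mask = rng.binomial(1, p, size=H.shape)
    else:
        # multiplicative Gaussian noise: one N(1, 1) draw per input point and unit
        mask = rng.normal(1.0, 1.0, size=H.shape)
    Y_hat = (H * mask) @ W2 + b2
    # data term: mean over the N points of E(y_i, y_hat_i); squared error is an assumption
    data_term = np.mean(0.5 * np.sum((Y - Y_hat) ** 2, axis=1))
    # L2 term: lambda * sum over layers of ||W_i||^2 + ||b_i||^2
    l2_term = lam * sum(np.sum(P ** 2) for P in (W1, b1, W2, b2))
    return data_term + l2_term
```

Because the mask is resampled on every call, repeated evaluations with the same parameters give different values of the data term, while the \(L_2\) term stays fixed.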


Gaussian Process

  • powerful tool in statistics that allows us to model distributions over functions

  • classification with a softmax likelihood ( with \(D\) classes )

    \(\begin{array}{l} \mathbf{F} \mid \mathbf{X} \sim \mathcal{N}(\mathbf{0}, \mathbf{K}(\mathbf{X}, \mathbf{X})) \\ y_{n} \mid \mathbf{f}_{n} \sim \text { Categorical }\left(\exp \left(\mathbf{f}_{n}\right) /\left(\sum_{d^{\prime}} \exp \left(f_{n d^{\prime}}\right)\right)\right) \end{array}\).

    • \(\mathbf{F}=\left[\mathbf{f}_{1}, \ldots, \mathbf{f}_{N}\right] \text { with } \mathbf{f}_{n}=\left[f_{n 1}, \ldots, f_{n D}\right] \text { and } f_{n d}=f_{d}\left(\mathbf{x}_{n}\right)\).
  • predictive probability

    \(p\left(\mathbf{y}^{*} \mid \mathbf{x}^{*}, \mathbf{X}, \mathbf{Y}\right)=\int p\left(\mathbf{y}^{*} \mid \mathbf{f}^{*}\right) p\left(\mathbf{f}^{*} \mid \mathbf{x}^{*}, \mathbf{X}, \mathbf{Y}\right) \mathrm{d} \mathbf{f}^{*}\).

  • predictive distribution for a new input \(\mathbf{x}^{*}\), written in terms of a finite set of random variables \(\boldsymbol{\omega}\)

    \(p\left(y^{*} \mid \mathbf{x}^{*}, \mathbf{X}, \mathbf{Y}\right)=\int p\left(y^{*} \mid \mathbf{f}^{*}\right) p\left(\mathbf{f}^{*} \mid \mathbf{x}^{*}, \boldsymbol{\omega}\right) p(\boldsymbol{\omega} \mid \mathbf{X}, \mathbf{Y}) \mathrm{d} \mathbf{f}^{*} \mathrm{~d} \boldsymbol{\omega}\).

    • \(p(\boldsymbol{\omega} \mid \mathbf{X}, \mathbf{Y})\) cannot usually be evaluated analytically
    • thus, define an approximating variational distribution \(q(\boldsymbol{\omega})\)
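
The generative model of the GP with softmax likelihood above ( \(\mathbf{F} \mid \mathbf{X}\), then \(y_n \mid \mathbf{f}_n\) ) can be sampled directly, which may help in reading the notation. The sketch below assumes an RBF kernel for \(\mathbf{K}\) and a small jitter term for numerical stability. Neither is specified in these notes, and `sample_gp_softmax` is a hypothetical name.

```python
import numpy as np

def sample_gp_softmax(X, D, lengthscale=1.0, rng=None):
    """Draw F | X ~ N(0, K(X, X)) and labels y_n | f_n ~ Categorical(softmax(f_n))."""
    rng = np.random.default_rng() if rng is None else rng
    N = X.shape[0]
    sq_dist = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    # RBF kernel (an assumption) plus jitter so that the Cholesky factor exists
    K = np.exp(-0.5 * sq_dist / lengthscale ** 2) + 1e-6 * np.eye(N)
    L = np.linalg.cholesky(K)
    # column d holds f_d(x_1), ..., f_d(x_N); row n is f_n = [f_n1, ..., f_nD]
    F = L @ rng.standard_normal((N, D))
    logits = F - F.max(axis=1, keepdims=True)       # stabilized softmax
    P = np.exp(logits)
    P /= P.sum(axis=1, keepdims=True)
    y = np.array([rng.choice(D, p=P[n]) for n in range(N)])
    return F, P, y
```

Here `F` is stored with one row per data point, so `F[n]` plays the role of \(\mathbf{f}_n\).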


ELBO : \(\mathcal{L}_{\mathrm{VI}}:=\int q(\boldsymbol{\omega}) p(\mathbf{F} \mid \mathbf{X}, \boldsymbol{\omega}) \log p(\mathbf{Y} \mid \mathbf{F}) \mathrm{d} \mathbf{F} \mathrm{d} \boldsymbol{\omega}-\operatorname{KL}(q(\boldsymbol{\omega}) \,\|\, p(\boldsymbol{\omega}))\).


The GP can be approximated by defining \(\boldsymbol{\omega}=\left\{\widehat{\mathbf{M}}_{1}, \widehat{\mathbf{M}}_{2}\right\}\) ( single hidden layer ) :

\(\mathbf{f} \mid \mathbf{x}, \boldsymbol{\omega} \sim \sqrt{\frac{1}{K}} \widehat{\mathbf{M}}_{2} \sigma\left(\widehat{\mathbf{M}}_{1} \mathbf{x}+\mathbf{m}\right)\).

  • for a deep network with \(L\) layers : \(\boldsymbol{\omega} =\left\{\widehat{\mathbf{M}}_{i}\right\}_{i=1}^{L}\).

  • \(\widehat{\mathbf{M}}_{i} =\mathbf{M}_{i} \cdot \operatorname{diag}\left(\left[\mathbf{z}_{i, j}\right]_{j=1}^{K_{i-1}}\right)\).

  • \(\mathbf{z}_{i, j} \sim \text { Bernoulli }\left(p_{i}\right) \text { for } i=1, \ldots, L, j=1, \ldots, K_{i-1}\).

    the probabilities \(p_{i}\) and the matrices \(\mathbf{M}_{i}\) ( of dimensions \(K_{i} \times K_{i-1}\) ) are the variational parameters.
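
Sampling \(\boldsymbol{\omega}\) and evaluating \(\mathbf{f} \mid \mathbf{x}, \boldsymbol{\omega}\) for the single-hidden-layer case can be sketched as follows. Assumptions not taken from these notes : \(\sigma\) is taken to be `tanh`, \(K\) is read as the number of hidden units, and the function names are hypothetical.

```python
import numpy as np

def sample_masked_weights(M, p, rng):
    """Draw M_hat = M · diag(z) with z_j ~ Bernoulli(p), one z_j per column of M."""
    z = rng.binomial(1, p, size=M.shape[1])
    return M * z            # broadcasting over columns equals M @ np.diag(z)

def sample_f(x, M1, M2, m, p1, p2, rng):
    """One draw of f | x, omega = sqrt(1/K) * M2_hat sigma(M1_hat x + m)."""
    K = M1.shape[0]         # number of hidden units (assumed meaning of K)
    M1_hat = sample_masked_weights(M1, p1, rng)
    M2_hat = sample_masked_weights(M2, p2, rng)
    return np.sqrt(1.0 / K) * M2_hat @ np.tanh(M1_hat @ x + m)
```

Zeroing column \(j\) of \(\mathbf{M}_i\) removes the contribution of input unit \(j\) to that layer, which is why masking weight columns corresponds to dropping units.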


Other SRTs are obtained from alternative choices of \(q(\boldsymbol{\omega})\)


3. SRT in arbitrary networks as VI in BNN

Bayesian NN : place a prior distribution over the weights of the NN

  • often a standard matrix Gaussian prior is used
  • \(\mathbf{W}_{i} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})\).
  • \(p\left(y \mid \mathbf{x},\left(\mathbf{W}_{i}\right)_{i=1}^{L}\right)=\text { Categorical }\left(\exp (\widehat{\mathbf{f}}) / \sum_{d^{\prime}} \exp \left(\widehat{f}_{d^{\prime}}\right)\right)\).


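We are interested in the posterior over the weights, \(p\left(\left(\mathbf{W}_{i}\right) \mid \mathbf{X}, \mathbf{Y}\right)\), i.e. the weights most likely to have generated the data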

Minimize the KL divergence between an approximating distribution \(q\) and the posterior :

  • \(\mathrm{KL}\left(q\left(\left(\mathbf{W}_{i}\right)\right) \,\|\, p\left(\left(\mathbf{W}_{i}\right) \mid \mathbf{X}, \mathbf{Y}\right)\right) \propto-\int q\left(\left(\mathbf{W}_{i}\right)\right) \log p\left(\mathbf{Y} \mid \mathbf{X},\left(\mathbf{W}_{i}\right)\right) \mathrm{d}\left(\mathbf{W}_{i}\right)+\mathrm{KL}\left(q\left(\left(\mathbf{W}_{i}\right)\right) \,\|\, p\left(\left(\mathbf{W}_{i}\right)\right)\right)\).
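
The proportionality sign above hides an additive constant. A short derivation using only Bayes' rule makes it explicit. The model evidence \(p(\mathbf{Y} \mid \mathbf{X})\) is the only quantity introduced here, and the prior is assumed not to depend on \(\mathbf{X}\) :

\(\begin{array}{rl}
\mathrm{KL}\left(q\left(\left(\mathbf{W}_{i}\right)\right) \,\|\, p\left(\left(\mathbf{W}_{i}\right) \mid \mathbf{X}, \mathbf{Y}\right)\right)
& =\int q\left(\left(\mathbf{W}_{i}\right)\right) \log \frac{q\left(\left(\mathbf{W}_{i}\right)\right)}{p\left(\left(\mathbf{W}_{i}\right) \mid \mathbf{X}, \mathbf{Y}\right)} \mathrm{d}\left(\mathbf{W}_{i}\right) \\
& =\int q\left(\left(\mathbf{W}_{i}\right)\right) \log \frac{q\left(\left(\mathbf{W}_{i}\right)\right) \, p(\mathbf{Y} \mid \mathbf{X})}{p\left(\mathbf{Y} \mid \mathbf{X},\left(\mathbf{W}_{i}\right)\right) \, p\left(\left(\mathbf{W}_{i}\right)\right)} \mathrm{d}\left(\mathbf{W}_{i}\right) \\
& =-\int q\left(\left(\mathbf{W}_{i}\right)\right) \log p\left(\mathbf{Y} \mid \mathbf{X},\left(\mathbf{W}_{i}\right)\right) \mathrm{d}\left(\mathbf{W}_{i}\right)+\mathrm{KL}\left(q\left(\left(\mathbf{W}_{i}\right)\right) \,\|\, p\left(\left(\mathbf{W}_{i}\right)\right)\right)+\log p(\mathbf{Y} \mid \mathbf{X})
\end{array}\)

Since \(\log p(\mathbf{Y} \mid \mathbf{X})\) does not depend on \(q\), minimizing the KL divergence to the posterior is the same as minimizing the first two terms, so "\(\propto\)" should be read as "equal up to an additive constant".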



Approximating variational distribution \(q\left(\mathbf{W}_{i}\right)\) for every layer \(i\) :

  • \(\mathbf{W}_{i}=\mathbf{M}_{i} \cdot \operatorname{diag}\left(\left[\mathbf{z}_{i, j}\right]_{j=1}^{K_{i-1}}\right)\).
  • \(\mathbf{z}_{i, j} \sim \text { Bernoulli }\left(p_{i}\right) \text { for } i=1, \ldots, L, j=1, \ldots, K_{i-1}\).

    ( if \(z_{i,j} \sim N(1,1)\) instead of Bernoulli, we get “Multiplicative Gaussian Noise” )


Dropout & other SRTs : usually evaluated at test time with the weight matrices set to their “mean” ( no sampling )


Predictive distribution for a new data point ( approximated with MC integration ) :

\(p\left(y^{*} \mid \mathbf{x}^{*}, \mathbf{X}, \mathbf{Y}\right) \approx \int p\left(y^{*} \mid \mathbf{x}^{*},\left(\mathbf{W}_{i}\right)\right) q\left(\left(\mathbf{W}_{i}\right)\right) \mathrm{d}\left(\mathbf{W}_{i}\right) \approx \frac{1}{T} \sum_{t=1}^{T} p\left(y^{*} \mid \mathbf{x}^{*},\left(\mathbf{W}_{i}\right)_{t}\right)\).

  • with \(\left(\mathbf{W}_{i}\right)_{t} \sim q\left(\left(\mathbf{W}_{i}\right)\right)\).
  • this is referred to as MC dropout when \(q(\cdot)\) is defined through Bernoulli r.v.s
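
The MC estimate above, and the usual "mean weights" evaluation it is compared with, can be sketched as follows. Assumptions not taken from these notes : ReLU hidden layers, bias vectors `bs` kept deterministic, `p` read as the keep probability shared by all layers, and all function names.

```python
import numpy as np

def softmax(a):
    a = a - a.max(axis=-1, keepdims=True)   # stabilize before exponentiating
    e = np.exp(a)
    return e / e.sum(axis=-1, keepdims=True)

def forward(x, Ws, bs):
    """p(y | x, (W_i)) for one fixed set of weights; ReLU hidden layers are an assumption."""
    h = x
    for i, (W, b) in enumerate(zip(Ws, bs)):
        h = W @ h + b
        if i < len(Ws) - 1:
            h = np.maximum(h, 0.0)
    return softmax(h)

def mc_dropout_predict(x, Ms, bs, p, T=100, rng=None):
    """Average p(y | x, (W_i)_t) over T draws (W_i)_t ~ q, with W_i = M_i diag(z_i)."""
    rng = np.random.default_rng() if rng is None else rng
    samples = []
    for _ in range(T):
        Ws = [M * rng.binomial(1, p, size=M.shape[1]) for M in Ms]
        samples.append(forward(x, Ws, bs))
    return np.mean(samples, axis=0)

def mean_weight_predict(x, Ms, bs, p):
    """Single pass with each W_i set to its mean under q, E[W_i] = p * M_i."""
    return forward(x, [p * M for M in Ms], bs)
```

The spread of the `T` sampled outputs around their average is what gives MC dropout an uncertainty estimate, which the single mean-weight pass does not provide.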
