Wasserstein Variational Inference (NeurIPS 2018)


  • Introduce Wasserstein Variational Inference ( = Bayesian Inference based on optimal transport theory )

  • Uses a new family of divergences, which includes both (1) f-divergences & (2) the Wasserstein distance

  • Gradients of Wasserstein variational loss : obtained by backpropagating through the Sinkhorn iterations

  • Introduce several new forms of autoencoders

1. Introduction

In Variational Inference, the KL divergence plays the central role

But recently, OPTIMAL TRANSPORT DIVERGENCES such as Wasserstein distance have gained popularity

  • usually in “generative modeling”

    \(\because\) they are well-behaved in situations where the KL-divergence is either infinite or undefined

This paper proposes a new form of variational Bayesian inference, and defines a new c-Wasserstein family of divergences

  • includes (1) f-divergence & (2) Wasserstein distance
    • f-divergences include both “forward & reverse KL”

1-1. Background on joint-contrastive variational inference

Review of joint-contrastive variational inference

  • latent variable \(z\)
  • observed data \(x\)

ex) reverse KL-divergence : \(D_{K L}(q(x, z) \mid \mid p(x, z))=\mathbb{E}_{q(x, z)}\left[\log \frac{q(x, z)}{p(x, z)}\right]\)

  • \(q(x, z)=q(z \mid x) k(x)\) : product between the variational posterior & sampling distribution of the data.
  • advantage : no need to evaluate the intractable distn \(p(z \mid x)\)

1-2. Background on optimal transport

(1) Optimal transport divergences :

  • measure the distance between two probability distns as the cost of transporting probability mass from one to the other.

  • \(\Gamma[p, q]\) : set of all bivariate probability measures on the product space \(X \times X\), whose marginals are \(p\) and \(q\)

  • formula :

    \(W_{c}(p, q)=\inf _{\gamma \in \Gamma[p, q]} \int c\left(x_{1}, x_{2}\right) \mathrm{d} \gamma\left(x_{1}, x_{2}\right)\).

    • exact computation has super-cubic complexity
    • complexity can be reduced by adopting “entropic regularization”

(2) Define new set of joint distn

  • \(U_{\epsilon}[p, q]=\left\{\gamma \in \Gamma[p, q] \mid D_{K L}(\gamma(x, y) \mid \mid p(x) q(y)) \leq \epsilon^{-1}\right\}\).
    • i.e. the joint distns whose “mutual info” between the two variables is bounded by the regularization param \(\epsilon^{-1}\)

(3) Rewrite optimal transport divergence using above :

  • \(W_{c, \epsilon}(p, q)=\inf _{u \in U_{\epsilon}[p, q]} \int c\left(x_{1}, x_{2}\right) \mathrm{d} u\left(x_{1}, x_{2}\right)\).

  • when \(p\) and \(q\) are discrete distn : (by Sinkhorn iterations)

    \(W_{c, \epsilon}(p, q)=\lim _{t \rightarrow \infty} \mathcal{S}_{t}^{\epsilon}[p, q, c]\).

    • \(\mathcal{S}_{t}^{\epsilon}[p, q, c]\) : output of the \(t^{th}\) Sinkhorn iteration
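
A minimal NumPy sketch of the Sinkhorn iterations for discrete \(p\) and \(q\), written in the common penalized (Lagrangian) form of entropic regularization. The function name `sinkhorn_cost`, the strength `reg`, and the iteration count are illustrative assumptions, not taken from the paper; in particular, how `reg` maps to the constraint level \(\epsilon^{-1}\) above is not specified here.

```python
import numpy as np

def sinkhorn_cost(p, q, C, reg=0.1, n_iters=200):
    """Approximate entropy-regularized transport cost between discrete p and q.

    p: (n,) probability weights, q: (m,) probability weights,
    C: (n, m) cost matrix. Returns (cost, transport plan).
    """
    K = np.exp(-C / reg)              # Gibbs kernel; may underflow if C / reg is huge
    u = np.ones_like(p)
    v = np.ones_like(q)
    for _ in range(n_iters):          # truncating this loop gives the t-th iterate
        u = p / (K @ v)               # rescale rows towards marginal p
        v = q / (K.T @ u)             # rescale columns towards marginal q
    plan = u[:, None] * K * v[None, :]
    return float(np.sum(plan * C)), plan

# Toy usage: two identical two-point distributions with a 0/1 cost.
p = np.array([0.5, 0.5])
q = np.array([0.5, 0.5])
C = np.array([[0.0, 1.0], [1.0, 0.0]])
cost, plan = sinkhorn_cost(p, q, C)
```

With a small `reg` the plan concentrates on the zero-cost diagonal, so the returned cost is close to the unregularized value of 0; larger `reg` spreads mass and raises the cost.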


2. Wasserstein Variational Inference

(1) Introduce a new family of divergences that includes both optimal transport divergences & f-divergences

(2) Then, develop a black-box and likelihood-free variational algorithm

2-1. c-Wasserstein divergences

(1) Traditional divergences

  • ex) KL-divergence
  • depend explicitly on the distn \(p\) and \(q\)

(2) Optimal transport divergences

  • ex) Wasserstein distance
  • depend on the distn \(p\) and \(q\) only through the constraints of an optimization problem

(1)+(2) = New divergence, “c-Wasserstein divergence”

  • generalize both forms of dependencies ( (1) and (2) )
  • FORMULA : \(W_{C}(p, q)=\inf _{\gamma \in \Gamma[p, q]} \int C^{p, q}\left(x_{1}, x_{2}\right) \mathrm{d} \gamma\left(x_{1}, x_{2}\right)\).

  • cost function \(C^{p, q}\left(x_{1}, x_{2}\right)\) :

    • depends both on the 2 scalars \(x_{1}\) and \(x_{2}\) and on 2 distributions \(p\) and \(q\)

    • is assumed to satisfy the following properties :

      \(\begin{array}{l} \text { 1. } C^{p, p}\left(x_{1}, x_{2}\right) \geq 0, \forall x_{1}, x_{2} \in \operatorname{supp}(p) \\ \text { 2. } C^{p, p}(x, x)=0, \forall x \in \operatorname{supp}(p) \\ \text { 3. } \mathbb{E}_{\gamma}\left[C^{p, q}\left(x_{1}, x_{2}\right)\right] \geq 0, \forall \gamma \in \Gamma[p, q] \end{array}\).

[ Theorem 1 ]

The functional \(W_{C}(p, q)\) is a (pseudo-)divergence, meaning that \(W_{C}(p, q) \geq 0\) for all \(p\) and \(q\) and \(W_{C}(p, p)=0\) for all \(p\)

  • all optimal transport divergences are part of the “c-Wasserstein family”

[ Theorem 2 ]

Let \(f: \mathbb{R} \rightarrow \mathbb{R}\) be a convex function such that \(f(1)=0 .\) The cost functional \(C^{p, q}(x, y)=f(g(x, y))\) respects property 3 when \(\mathbb{E}_{\gamma}[g(x, y)]=1\) for all \(\gamma \in \Gamma[p, q]\)
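
A one-line sketch of why this holds, using Jensen's inequality (this is a standard argument written out for these notes, not a quotation of the paper's proof). For any \(\gamma \in \Gamma[p, q]\) with \(\mathbb{E}_{\gamma}[g(x, y)]=1\) :

\[\mathbb{E}_{\gamma}\left[C^{p, q}(x, y)\right]=\mathbb{E}_{\gamma}[f(g(x, y))] \geq f\left(\mathbb{E}_{\gamma}[g(x, y)]\right)=f(1)=0\]

The inequality is convexity of \(f\), and the last two steps use the normalization of \(g\) and \(f(1)=0\), which is exactly property 3.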

[ Theorem 3 ]

Let \(W\left(p_{n}, q_{n}\right)\) be the Wasserstein distance between two empirical distributions \(p_{n}\) and \(q_{n}\).

For \(n\) tending to infinity, there is a positive number \(s\) such that \(\mathbb{E}_{p q}\left[W\left(p_{n}, q_{n}\right)\right] \lesssim W(p, q)+n^{-1 / s}\)

2-2. Stochastic Wasserstein variational inference

( loss func ) c-Wasserstein divergence between \(p(x, z)\) and \(q(x, z)\) :

\(\mathcal{L}_{C}[p, q]=W_{C}(p(z, x), q(z, x))=\inf _{\gamma \in \Gamma[p, q]} \int C^{p, q}\left(x_{1}, z_{1} ; x_{2}, z_{2}\right) \mathrm{d} \gamma\left(x_{1}, z_{1} ; x_{2}, z_{2}\right)\)……… eq(A)

  • minimized when \(p=q\)
  • but \(\mathcal{L}_{C}[p, q]\) could be 0 even if \(p \neq q .\)

Black-box MC estimate of the gradient of eq(A) :

  • step 1) discrete c-Wasserstein divergence

    \[\mathcal{L}_{C}\left[p_{n}, q_{n}\right]=\inf _{\gamma} \sum_{j, k} C^{p, q}\left(x_{1}^{(j)}, z_{1}^{(j)}, x_{2}^{(k)}, z_{2}^{(k)}\right) \gamma\left(x_{1}^{(j)}, z_{1}^{(j)}, x_{2}^{(k)}, z_{2}^{(k)}\right)\]
    • where \(\left(x_{1}^{(j)}, z_{1}^{(j)}\right)\) and \(\left(x_{2}^{(k)}, z_{2}^{(k)}\right)\) are sampled from \(p(x, z)\) and \(q(x, z)\) respectively
    • asymptotically unbiased
  • step 2) use the modified loss ( to eliminate bias )

    \(\tilde{\mathcal{L}}_{C}\left[p_{n}, q_{n}\right]=\mathcal{L}_{C}\left[p_{n}, q_{n}\right]-\left(\mathcal{L}_{C}\left[p_{n}, p_{n}\right]+\mathcal{L}_{C}\left[q_{n}, q_{n}\right]\right) / 2\).

    • expectation = 0 ( when \(p=q\) )

      \(\lim _{n \rightarrow \infty} \tilde{\mathcal{L}}_{C}\left[p_{n}, q_{n}\right]=\mathcal{L}_{C}[p, q]\).

  • step 3) compute the gradient of the loss ( using automatic differentiation )

    entropy-regularized version of optimal transport cost can be approximated by truncating the Sinkhorn iterations

    \(\nabla \mathcal{L}_{C}\left[p_{n}, q_{n}\right] \approx \nabla \mathcal{S}_{t}^{\epsilon}\left[p_{n}, q_{n}, C^{p, q}\right]\).
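
Steps 1 and 2 can be sketched in NumPy as follows. This is an illustration under assumptions, not the paper's implementation: the helper names (`sinkhorn_cost`, `sq_dist`, `debiased_loss`), the squared Euclidean cost, uniform sample weights, and the penalized regularization strength `reg` are all choices made for this example. Step 3 would additionally need an automatic-differentiation framework so that gradients can flow through the truncated loop; NumPy is used here only to show the estimator itself.

```python
import numpy as np

def sinkhorn_cost(C, reg=1.0, n_iters=200):
    """Truncated Sinkhorn estimate of the transport cost, uniform weights."""
    n, m = C.shape
    p, q = np.full(n, 1.0 / n), np.full(m, 1.0 / m)
    K = np.exp(-C / reg)                      # Gibbs kernel
    u, v = np.ones(n), np.ones(m)
    for _ in range(n_iters):
        u = p / (K @ v)
        v = q / (K.T @ u)
    return float(np.sum(u[:, None] * K * v[None, :] * C))

def sq_dist(a, b):
    """Pairwise squared Euclidean distances between the rows of a and b."""
    return np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)

def debiased_loss(samples_p, samples_q, cost=sq_dist, reg=1.0, n_iters=200):
    """L[p_n, q_n] - (L[p_n, p_n] + L[q_n, q_n]) / 2 on two sample batches."""
    l_pq = sinkhorn_cost(cost(samples_p, samples_q), reg, n_iters)
    l_pp = sinkhorn_cost(cost(samples_p, samples_p), reg, n_iters)
    l_qq = sinkhorn_cost(cost(samples_q, samples_q), reg, n_iters)
    return l_pq - 0.5 * (l_pp + l_qq)

# Toy usage: each row stands for one joint sample (x, z) flattened to a vector.
rng = np.random.default_rng(0)
batch_p = rng.normal(size=(30, 2))
batch_q = batch_p + 3.0                       # shifted copy of the same batch
loss_same = debiased_loss(batch_p, batch_p)   # self-comparison terms cancel
loss_diff = debiased_loss(batch_p, batch_q)
```

Subtracting the two self-comparison terms is what makes the estimate vanish when both batches coincide, which the plain Sinkhorn cost does not do for a positive `reg`.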

3. Examples of c-Wasserstein divergences

now introduce 2 classes of c-Wasserstein divergences

  • that are suitable for deep Bayesian VI
  • question : how to define COST?

Show that KL-div & f-div are part of c-Wasserstein divergences

3-1. (1) A metric divergence for latent space

cost : \(C_{P B}^{p}\left(z_{1}, z_{2}\right)=d_{x}\left(g_{p}\left(z_{1}\right), g_{p}\left(z_{2}\right)\right)\)

  • simplest way to assign a geometric transport cost to the latent space :

    pull back a metric function from the observable space

  • \(d_{x}\left(x_{1}, x_{2}\right)\) = metric function in the observable space

  • \(g_{p}(z)\) = deterministic function that maps \(z\) to the expected value of \(p(x \mid z)\)
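
A small sketch of how the pulled-back cost matrix could be computed for two batches of latent samples. The linear decoder `g_p`, the weight matrix `W`, and the Euclidean metric `d_x` are hypothetical stand-ins chosen for illustration; in practice \(g_{p}\) would be the mean output of the generator network.

```python
import numpy as np

def pullback_cost(z1, z2, g_p, d_x):
    """Cost matrix with entries C[j, k] = d_x(g_p(z1[j]), g_p(z2[k]))."""
    x1 = g_p(z1)                      # decode first batch, shape (n, dim_x)
    x2 = g_p(z2)                      # decode second batch, shape (m, dim_x)
    return d_x(x1[:, None, :], x2[None, :, :])

# Hypothetical decoder: a fixed linear map from 2-D latents to 3-D observations.
W = np.array([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]])

def g_p(z):
    return z @ W.T

def d_x(a, b):
    """Euclidean metric in the observable space (broadcasts over batches)."""
    return np.linalg.norm(a - b, axis=-1)

z = np.array([[0.0, 0.0], [1.0, 0.0]])
C = pullback_cost(z, z, g_p, d_x)     # (2, 2) symmetric matrix, zero diagonal
```

The resulting matrix can be passed directly to any discrete optimal transport solver as the transport cost on the latent space.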

3-2. (2) Autoencoder divergences

(1) Latent autoencoder cost

cost : \(C_{L A}^{q}\left(x_{1}, z_{1} ; x_{2}, z_{2}\right)=d\left(z_{1}-h_{q}\left(x_{1}\right), z_{2}-h_{q}\left(x_{2}\right)\right)\)

  • ( expected value of \(q(z \mid x)\) is given by the deterministic function \(h_q(x)\) )
  • transport cost between the latent residuals of the two models

(2) Observable autoencoder cost

cost : \(C_{O A}^{p}\left(x_{1}, z_{1} ; x_{2}, z_{2}\right)=d\left(x_{1}-g_{p}\left(z_{1}\right), x_{2}-g_{p}\left(z_{2}\right)\right)\).

  • \(g_{p}(z)\) gives the expected value of the generator

  • if deterministic generator :

    • \(C_{O A}^{p}\left(x_{1}, z_{1} ; x_{2}, z_{2}\right)=d\left(0, x_{2}-g_{p}\left(z_{2}\right)\right)\).

    • then, the resulting divergence is just “average reconstruction error”

      \(\inf _{\gamma \in \Gamma[p, q]} \int d\left(0, x_{2}-g_{p}\left(z_{2}\right)\right) \mathrm{d} \gamma=\mathbb{E}_{q(x, z)}\left[d\left(0, x-g_{p}(z)\right)\right]\).

3-3. \(f\)-divergences

all \(f\)-divergences are part of c-Wasserstein family!

cost : \(C_{f}^{p, q}\left(x_{1}, x_{2}\right)=f\left(\frac{p\left(x_{2}\right)}{q\left(x_{2}\right)}\right)\).

  • \(f\) : convex function such that \(f(1)=0\)

  • by [ Theorem 2 ], it defines a valid c-Wasserstein divergence
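
A short worked check of why Theorem 2 applies, assuming \(p\) is absolutely continuous with respect to \(q\) so that the ratio is well defined. Take \(g\left(x_{1}, x_{2}\right)=p\left(x_{2}\right) / q\left(x_{2}\right)\). Every \(\gamma \in \Gamma[p, q]\) has second marginal \(q\), so

\[\mathbb{E}_{\gamma}\left[g\left(x_{1}, x_{2}\right)\right]=\int q\left(x_{2}\right) \frac{p\left(x_{2}\right)}{q\left(x_{2}\right)} \mathrm{d} x_{2}=1\]

Because the cost ignores \(x_{1}\), every coupling gives the same value and the infimum is trivial :

\[W_{C_{f}}(p, q)=\inf _{\gamma \in \Gamma[p, q]} \int f\left(\frac{p\left(x_{2}\right)}{q\left(x_{2}\right)}\right) \mathrm{d} \gamma\left(x_{1}, x_{2}\right)=\mathbb{E}_{q(x)}\left[f\left(\frac{p(x)}{q(x)}\right)\right]\]

which is the usual form of the \(f\)-divergence between \(p\) and \(q\).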

4. Wasserstein Variational Autoencoders


  • \(\mathcal{D}_{p}\) and \(\mathcal{D}_{q}\) : parametrized probability distributions

  • \(g_{p}(z)\) and \(h_{q}(x)\) : outputs of deep networks

Decoder (probabilistic model) : \(p(z, x)=\mathcal{D}_{p}\left(x \mid g_{p}(z)\right) p(z)\).

Encoder (variational model) : \(q(z, x)=\mathcal{D}_{q}\left(z \mid \boldsymbol{h}_{q}(x)\right) k(x)\)

Define a LARGE family of objective functions of VAEs by combining cost functions!

\(\begin{aligned} C_{\boldsymbol{w}, f}^{p, q}\left(x_{1}, z_{1} ; x_{2}, z_{2}\right)=& w_{1} d_{x}\left(x_{1}, x_{2}\right)+w_{2} C_{P B}^{p}\left(z_{1}, z_{2}\right)+w_{3} C_{L A}^{q}\left(x_{1}, z_{1} ; x_{2}, z_{2}\right) \\ &+w_{4} C_{O A}^{p}\left(x_{1}, z_{1} ; x_{2}, z_{2}\right)+w_{5} C_{f}^{p, q}\left(x_{1}, z_{1} ; x_{2}, z_{2}\right) \end{aligned}\).

5. Connections with related methods

5-1. Operator Variational Inference (?)

Wasserstein Variational Inference = “special case of a generalized version of operator variational inference”

operator variational inference

  • objective : \(\mathcal{L}_{O P}=\sup _{f \in \mathfrak{F}} \zeta\left(\mathbb{E}_{q(x, z)}\left[\mathcal{O}^{p, q} f\right]\right)\).

  • [ dual representation ]

    c-Wasserstein loss : \(W_{C}(p, q)=\sup _{f \in L_{C}[p, q]}\left[\mathbb{E}_{p(x, z)}[f(x, z)]-\mathbb{E}_{q(x, z)}[f(x, z)]\right]\)

    where \(L_{C}[p, q]=\left\{f: X \rightarrow \mathbb{R} \mid f\left(x_{1}, z_{1}\right)- f\left(x_{2}, z_{2}\right) \leq C^{p, q}\left(x_{1}, z_{1} ; x_{2}, z_{2}\right)\right\}\)

  • using importance sampling…

    \(W_{C}(p, q)=\sup _{f \in L_{C}[p, q]}\left[\mathbb{E}_{q(x, z)}\left[\left(\frac{p(x, z)}{q(x, z)}-1\right) f(x, z)\right]\right]\).
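
The importance sampling step, written out (assuming \(q(x, z)>0\) wherever \(p(x, z)>0\), so that the ratio is defined) :

\[\mathbb{E}_{p(x, z)}[f(x, z)]-\mathbb{E}_{q(x, z)}[f(x, z)]=\mathbb{E}_{q(x, z)}\left[\frac{p(x, z)}{q(x, z)} f(x, z)\right]-\mathbb{E}_{q(x, z)}[f(x, z)]=\mathbb{E}_{q(x, z)}\left[\left(\frac{p(x, z)}{q(x, z)}-1\right) f(x, z)\right]\]

Both expectations are now taken under \(q(x, z)\), which matches the form of the operator variational inference objective.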

5-2. Wasserstein Autoencoders (WAE)

recently introduced WAE : uses a “regularized optimal transport divergence between \(p(x)\) and \(k(x)\)”

Regularized Loss : \(\mathcal{L}_{W A}=\mathbb{E}_{q(x, z)}\left[c_{x}\left(x, g_{p}(z)\right)\right]+\lambda D(p(z) \mid \mid q(z))\).

  • derived from optimal transport loss!

    ( \(\mathcal{L}_{W A} \approx W_{c_{x}}(p(x), k(x))\) )

  • when \(D(p(z) \mid \mid q(z))\) is a c-Wasserstein divergence, \(\mathcal{L}_{W A}\) is a Wasserstein variational loss

    \(\mathbb{E}_{q(x, z)}\left[c_{x}\left(x, g_{p}(z)\right)\right]+\lambda W_{C_{z}}(p(z), q(z))=\inf _{\gamma \in \Gamma[p, q]} \int\left[c_{x}\left(x_{2}, g_{p}\left(z_{2}\right)\right)+\lambda C_{z}^{p, q}\left(z_{1}, z_{2}\right)\right] \mathrm{d} \gamma\).

