Efficient Variational Inference for Gaussian Process Regression Networks (2013)


Abstract

( Multi-output regression ) the correlation between the outputs \(Y\) may vary across the input space

GPRNs (Gaussian Process Regression Networks)

  • flexible
  • but exact inference is intractable

Thus, propose 2 efficient VI methods for GPRNs


(1) GPRN-MF

  • adopts a mean-field approximation with full Gaussians over the GPRN's parameters


(2) GPRN-NPV

  • non-parametric VI
  • derive analytical forms of ELBO
  • closed-form updates of parameters
  • only \(O(N)\) parameters for the covariances


1. Introduction

Challenges in multi-output regression :

  • 1) develop flexible models able to capture the dependencies between \(Y\)s
  • 2) efficient inference


Various non-probabilistic approaches have been developed.

It is crucial to have full posterior probabilities (not just point predictions).

GPs have proved to be very effective tools for both single- and multi-output regression.


GP-based methods :

  • before) assume that the dependencies between the \(Y\)s are fixed

    ( = independent of the input space )

  • after ) correlation between \(Y\)s can be spatially adaptive

    \(\rightarrow\) GAUSSIAN PROCESS REGRESSION NETWORKS (GPRNs)


This paper proposes “efficient approximate inference methods for GPRNs”

(1) First method : simple MF approach of GPRN

  • show that…
    • 1) can obtain analytical expression of ELBO & closed-form update of variational params
    • 2) parameterize the corresponding covariances with only \(O(N)\) params


(2) Second method : exploits non-parametric VI

  • non-parametric VI to approximate posterior of GPRN’s params
  • can approximate complex distributions that are not well captured by a single Gaussian
  • needs \(O(N)\) variational params


2. GPRN

Input : \(\mathbf{x} \in \mathbb{R}^{D}\).

Output : \(\mathbf{y}(\mathbf{x}) \in \mathbb{R}^{P}\).

  • assumed to be a linear combination of \(Q\) noisy latent functions \(\mathbf{f}(\mathbf{x}) \in \mathbb{R}^{Q}\)
  • corrupted by Gaussian noise

Mixing Coefficients : \(\mathbf{W}(\mathbf{x}) \in \mathbb{R}^{P \times Q}\)


[ GPRN model ]

\(\begin{aligned} \mathrm{y}(\mathrm{x}) &=\mathrm{W}(\mathrm{x})\left[\mathrm{f}(\mathrm{x})+\sigma_{f} \epsilon\right]+\sigma_{y} \mathrm{z} \\ f_{j}(\mathrm{x}) & \sim \mathcal{G} \mathcal{P}\left(0, \kappa_{f}\right), \quad j=1 \ldots Q \\ W_{i j}(\mathrm{x}) & \sim \mathcal{G} \mathcal{P}\left(0, \kappa_{w}\right), \quad i=1, \ldots, P ; j=1, \ldots Q \\ \epsilon & \sim \mathcal{N}\left(\epsilon ; 0, \mathrm{I}_{Q}\right) \\ \mathrm{z} & \sim \mathcal{N}\left(\mathrm{z} ; 0, \mathrm{I}_{P}\right) \end{aligned}\).
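
As a concrete illustration, below is a minimal NumPy sketch that draws one sample from this generative model (the RBF kernel, the lengthscales, and the 1-D input grid are arbitrary choices for the example, not taken from the paper):

```python
import numpy as np

def rbf_kernel(X, lengthscale=1.0, var=1.0, jitter=1e-8):
    """Squared-exponential kernel matrix for 1-D inputs."""
    d2 = (X[:, None] - X[None, :]) ** 2
    return var * np.exp(-0.5 * d2 / lengthscale**2) + jitter * np.eye(len(X))

rng = np.random.default_rng(0)
N, P, Q = 50, 3, 2                 # number of inputs, outputs, latent functions
sigma_f, sigma_y = 0.1, 0.05       # noise scales (arbitrary for the example)

X = np.linspace(0, 5, N)                 # 1-D inputs for simplicity
Kf = rbf_kernel(X, lengthscale=0.5)      # kernel k_f for the latent functions f_j
Kw = rbf_kernel(X, lengthscale=2.0)      # smoother kernel k_w for the weights W_ij

# f_j ~ GP(0, k_f) and W_ij ~ GP(0, k_w), evaluated at the N training inputs
F = rng.multivariate_normal(np.zeros(N), Kf, size=Q).T       # shape (N, Q)
W = rng.multivariate_normal(np.zeros(N), Kw, size=(P, Q))    # shape (P, Q, N)

# y(x_n) = W(x_n) [f(x_n) + sigma_f * eps] + sigma_y * z
Y = np.zeros((N, P))
for n in range(N):
    f_hat = F[n] + sigma_f * rng.standard_normal(Q)
    Y[n] = W[:, :, n] @ f_hat + sigma_y * rng.standard_normal(P)
```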


Advantage of GPRN model :

  • 1) dependencies of outputs \(y\) are induced via latent functions \(\mathbf{f}\)

  • 2) mixing coefficients \(\mathbf{W}(\mathbf{x})\) explicitly depend on \(\mathbf{x}\)

    ( = correlations are spatially adaptive )


Notation

  • Observed Inputs : \(\mathcal{X}=\left\{\left(\mathbf{x}_{i}\right)\right\}_{i=1}^{N}\)

  • Observed Outputs : \(\mathcal{D}=\left\{\left(\mathbf{y}_{i}\right)\right\}_{i=1}^{N}\)

  • concatenation of the noisy latent function values & weights : \(\mathbf{u}=(\hat{\mathbf{f}}, \mathbf{w})\)

  • noisy version of latent function values : \(\hat{\mathrm{f}}=\mathrm{f}+\sigma_{f} \epsilon,\)

  • hyperparameters of GPRN : \(\theta=\left\{\boldsymbol{\theta}_{f}, \boldsymbol{\theta}_{w}, \sigma_{f}, \sigma_{y}\right\}\)


Prior :

\(\mathbf{u}\) : \(p\left(\mathbf{u} \mid \boldsymbol{\theta}_{f}, \boldsymbol{\theta}_{w}, \sigma_{f}\right)=\mathcal{N}\left(\mathbf{u} ; \mathbf{0}, \mathbf{C}_{u}\right)\).


Conditional Likelihood :

\(p(\mathcal{D} \mid \mathbf{u}, \boldsymbol{\theta})=\prod_{n=1}^{N} \mathcal{N}\left(\mathbf{y}\left(\mathbf{x}_{n}\right) ; \mathbf{W}\left(\mathbf{x}_{n}\right) \hat{\mathbf{f}}\left(\mathbf{x}_{n}\right), \sigma_{y}^{2} \mathbf{I}_{P}\right)\).

  • \(\mathrm{y}(\mathrm{x}) =\mathrm{W}(\mathrm{x})\left[\mathrm{f}(\mathrm{x})+\sigma_{f} \epsilon\right]+\sigma_{y} \mathrm{z}\).

  • \(\mathrm{z} \sim \mathcal{N}\left(\mathrm{z} ; 0, \mathrm{I}_{P}\right)\).


Posterior :

\(p(\mathbf{u} \mid \mathcal{D}, \boldsymbol{\theta}) \propto p\left(\mathbf{u} \mid \boldsymbol{\theta}_{f}, \boldsymbol{\theta}_{w}, \sigma_{f}\right) p\left(\mathcal{D} \mid \mathbf{u}, \sigma_{y}\right)\).
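
Putting the prior and the conditional likelihood together, the unnormalized log-posterior is straightforward to evaluate. A minimal sketch, assuming \(\mathbf{C}_u\) is block-diagonal with \(Q\) blocks \(\mathbf{K}_f+\sigma_f^{2}\mathbf{I}\) (for the noisy \(\hat{\mathbf{f}}_j\)) and \(PQ\) blocks \(\mathbf{K}_w\) (for the \(\mathbf{w}_{ij}\)); array shapes follow the sampling sketch in Section 2:

```python
import numpy as np
from scipy.stats import multivariate_normal

def gprn_log_joint(f_hat, W, Y, Kf, Kw, sigma_f, sigma_y):
    """Unnormalized log-posterior: log p(u | theta) + log p(D | u, theta).

    f_hat : (N, Q) noisy latent values  f_hat = f + sigma_f * eps
    W     : (P, Q, N) weight functions at the N training inputs
    Y     : (N, P) observed outputs
    """
    N, Q = f_hat.shape
    P = Y.shape[1]

    # log prior: f_hat_j ~ N(0, Kf + sigma_f^2 I),  w_ij ~ N(0, Kw)
    Kf_hat = Kf + sigma_f**2 * np.eye(N)
    log_prior = sum(
        multivariate_normal.logpdf(f_hat[:, j], mean=np.zeros(N), cov=Kf_hat)
        for j in range(Q)
    )
    log_prior += sum(
        multivariate_normal.logpdf(W[i, j], mean=np.zeros(N), cov=Kw)
        for i in range(P) for j in range(Q)
    )

    # conditional likelihood: y_n ~ N(W(x_n) f_hat(x_n), sigma_y^2 I_P)
    log_lik = sum(
        multivariate_normal.logpdf(Y[n], mean=W[:, :, n] @ f_hat[n],
                                   cov=sigma_y**2 * np.eye(P))
        for n in range(N)
    )
    return log_prior + log_lik
```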


3. VI for GPRNs

minimize KL-divergence :

  • \(\mathrm{KL}(q(\mathbf{u}) \| p(\mathbf{u} \mid \mathcal{D}))=\mathbb{E}_{q}\left[\log \frac{q(\mathbf{u})}{p(\mathbf{u} \mid \mathcal{D})}\right]\).

maximize ELBO :

  • \(\mathcal{L}(q)=\mathbb{E}_{q}[\log p(\mathcal{D} \mid \mathbf{f}, \mathbf{w})]+\mathbb{E}_{q}[\log p(\mathbf{f}, \mathbf{w})]+\mathcal{H}_{q}[q(\mathbf{f}, \mathbf{w})]\).
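
These two objectives are the same thing: writing \(\mathbf{u}=(\mathbf{f}, \mathbf{w})\), the log marginal likelihood splits as

$$\log p(\mathcal{D})=\mathcal{L}(q)+\mathrm{KL}(q(\mathbf{u}) \| p(\mathbf{u} \mid \mathcal{D}))$$

so, since \(\log p(\mathcal{D})\) does not depend on \(q\), maximizing the ELBO \(\mathcal{L}(q)\) is equivalent to minimizing the KL divergence.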


For the mean-field method, we can obtain…

  • 1) analytical expression for ELBO
  • 2) need only \(O(N)\) params for covariances


3-1. MFVI for GPRN

factorized distributions :

\(q(\mathbf{f}, \mathbf{w}) =\prod_{j=1}^{Q} q\left(\mathbf{f}_{j}\right) \prod_{i=1}^{P} q\left(\mathbf{w}_{i j}\right)\).

  • \(q\left(\mathbf{f}_{j}\right) =\mathcal{N}\left(\mathbf{f}_{j} ; \boldsymbol{\mu}_{\mathrm{f}_{j}}, \Sigma_{\mathrm{f}_{j}}\right)\).
  • \(q\left(\mathbf{w}_{i j}\right) =\mathcal{N}\left(\mathbf{w}_{i j} ; \mu_{\mathrm{w}_{\mathrm{ij}}}, \Sigma_{\mathrm{w}_{\mathrm{ij}}}\right)\).


3-1-1. Closed-form ELBO

( full Gaussian mean-field approximation ) each term of the ELBO has a closed form

(1) First term :

\(\begin{aligned} \mathbb{E}_{q}[\log p(\mathcal{D} \mid \mathbf{f}, \mathbf{w})] =&-\frac{N P}{2} \log \left(2 \pi \sigma_{y}^{2}\right) \\ &-\frac{1}{2 \sigma_{y}^{2}} \sum_{n=1}^{N}\left(\mathbf{Y}_{\cdot n}^{T}-\Omega_{\mathrm{w}_{\mathrm{n}}} \nu_{\mathrm{f}_{\mathrm{n}}}\right)^{T}\left(\mathbf{Y}_{\cdot n}^{T}-\Omega_{\mathrm{w}_{\mathrm{n}}} \nu_{\mathrm{f}_{\mathrm{n}}}\right) \\ &-\frac{1}{2 \sigma_{y}^{2}} \sum_{i=1}^{P} \sum_{j=1}^{Q}\left[\operatorname{diag}\left(\Sigma_{\mathrm{f}_{\mathrm{j}}}\right)^{T}\left(\mu_{\mathrm{w}_{\mathrm{ij}}} \bullet \mu_{\mathrm{w}_{\mathrm{ij}}}\right)+\operatorname{diag}\left(\Sigma_{\mathrm{w}_{\mathrm{ij}}}\right)^{T}\left(\mu_{\mathrm{f}_{\mathrm{j}}} \bullet \mu_{\mathrm{f}_{\mathrm{j}}}\right)\right] \end{aligned}\).


(2) Second term :

\(\begin{aligned} \mathbb{E}_{q}[\log p(\mathbf{f}, \mathbf{w})] =&-\frac{1}{2} \sum_{j=1}^{Q}\left(\log \left|\mathbf{K}_{f}\right|+\boldsymbol{\mu}_{\mathrm{f}_{j}}^{T} \mathbf{K}_{f}^{-1} \boldsymbol{\mu}_{\mathrm{f}_{j}}+\operatorname{tr}\left(\mathbf{K}_{f}^{-1} \boldsymbol{\Sigma}_{\mathrm{f}_{\mathrm{j}}}\right)\right) \\ &-\frac{1}{2} \sum_{i, j}\left(\log \left|\mathbf{K}_{w}\right|+\boldsymbol{\mu}_{\mathrm{w}_{\mathrm{ij}}}^{T} \mathbf{K}_{w}^{-1} \boldsymbol{\mu}_{\mathrm{w}_{\mathrm{ij}}}+\operatorname{tr}\left(\mathbf{K}_{w}^{-1} \boldsymbol{\Sigma}_{\mathrm{w}_{\mathrm{ij}}}\right)\right) \end{aligned}\).


(3) Third term :

$$\mathcal{H}[q(\mathbf{f}, \mathbf{w})]=\frac{1}{2} \sum_{j=1}^{Q} \log \left|\boldsymbol{\Sigma}_{\mathrm{f}_{\mathrm{j}}}\right|+\frac{1}{2} \sum_{i, j} \log \left|\boldsymbol{\Sigma}_{\mathrm{w}_{\mathrm{ij}}}\right|+\text{const}$$


3-1-2. Efficient Closed-form Updates for Variational Parameters

Parameters for \(q(\mathbf{f}_j)\)

  • \(\mu_{\mathrm{f}_{\mathrm{j}}}=\frac{1}{\sigma_{y}^{2}} \Sigma_{\mathrm{f}_{j}} \sum_{i=1}^{P}\left(\mathrm{Y}_{\cdot i}-\sum_{k \neq j} \mu_{\mathrm{w}_{\mathrm{ik}}} \bullet \mu_{\mathrm{f}_{\mathrm{k}}}\right) \bullet \mu_{\mathrm{w}_{\mathrm{i} j}}\).
  • \(\boldsymbol{\Sigma}_{\mathrm{f}_{\mathrm{j}}}=\left(\mathbf{K}_{f}^{-1}+\frac{1}{\sigma_{y}^{2}} \sum_{i=1}^{P} \operatorname{diag}\left(\boldsymbol{\mu}_{\mathrm{w}_{\mathrm{ij}}} \bullet \boldsymbol{\mu}_{\mathrm{w}_{\mathrm{ij}}}+\operatorname{Var}\left(\mathbf{w}_{i j}\right)\right)\right)^{-1}\).


Parameters for \(q(\mathbf{w}_{ij})\)

  • \(\mu_{\mathrm{w}_{\mathrm{ij}}}=\frac{1}{\sigma_{y}^{2}} \Sigma_{\mathrm{w}_{\mathrm{ij}}}\left(\mathrm{Y}_{\cdot i}-\sum_{k \neq j} \mu_{\mathrm{f}_{\mathrm{k}}} \bullet \mu_{\mathrm{w}_{\mathrm{ik}}}\right) \bullet \mu_{\mathrm{f}_{\mathrm{j}}}\).
  • \(\Sigma_{\mathrm{w}_{\mathrm{ij}}}=\left(\mathrm{K}_{w}^{-1}+\frac{1}{\sigma_{y}^{2}} \operatorname{diag}\left(\mu_{\mathrm{f}_{j}} \bullet \mu_{\mathrm{f}_{j}}+\operatorname{Var}\left(\mathrm{f}_{j}\right)\right)\right)^{-1}\).
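
Below is a minimal NumPy sketch of one coordinate-ascent sweep through these closed-form updates (variable names and array shapes are my own bookkeeping; \(\bullet\) is the elementwise product and \(\operatorname{Var}(\cdot)\) the diagonal of the corresponding covariance). Note that this naive version stores full \(N \times N\) covariances; the paper's point is that each covariance is fully determined by the \(N\) diagonal entries added to the kernel inverse, hence only \(O(N)\) free parameters per factor.

```python
import numpy as np

def mean_field_sweep(Y, Kf, Kw, sigma_y, mu_f, Sigma_f, mu_w, Sigma_w):
    """One coordinate-ascent sweep over all q(f_j) and q(w_ij) factors.

    Y       : (N, P) outputs; Y[:, i] is the column Y_.i
    mu_f    : (Q, N) means of q(f_j),   Sigma_f : (Q, N, N) covariances
    mu_w    : (P, Q, N) means of q(w_ij), Sigma_w : (P, Q, N, N) covariances
    """
    N, P = Y.shape
    Q = mu_f.shape[0]
    Kf_inv, Kw_inv = np.linalg.inv(Kf), np.linalg.inv(Kw)

    # update q(f_j) = N(mu_f[j], Sigma_f[j])
    for j in range(Q):
        prec_diag = sum(mu_w[i, j] ** 2 + np.diag(Sigma_w[i, j]) for i in range(P))
        Sigma_f[j] = np.linalg.inv(Kf_inv + np.diag(prec_diag) / sigma_y**2)
        resid = sum(
            (Y[:, i] - sum(mu_w[i, k] * mu_f[k] for k in range(Q) if k != j)) * mu_w[i, j]
            for i in range(P)
        )
        mu_f[j] = Sigma_f[j] @ resid / sigma_y**2

    # update q(w_ij) = N(mu_w[i, j], Sigma_w[i, j])
    for i in range(P):
        for j in range(Q):
            prec_diag = mu_f[j] ** 2 + np.diag(Sigma_f[j])
            Sigma_w[i, j] = np.linalg.inv(Kw_inv + np.diag(prec_diag) / sigma_y**2)
            resid = (Y[:, i] - sum(mu_f[k] * mu_w[i, k] for k in range(Q) if k != j)) * mu_f[j]
            mu_w[i, j] = Sigma_w[i, j] @ resid / sigma_y**2

    return mu_f, Sigma_f, mu_w, Sigma_w
```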


3-1-3. Hyper-parameters Learning

hyperparameters : \(\boldsymbol{\theta}=\left\{\boldsymbol{\theta}_{f}, \boldsymbol{\theta}_{w}, \sigma_{f}, \sigma_{y}\right\}\).

These are learned by gradient-based optimization of the ELBO, as sketched below.
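
A minimal sketch of this step, assuming a callable `elbo_fn(theta)` that evaluates the closed-form ELBO above with the variational parameters held fixed (the helper name and the log-parameterization are my own choices for illustration):

```python
import numpy as np
from scipy.optimize import minimize

def fit_hyperparameters(elbo_fn, log_theta0):
    """Maximize the ELBO over hyperparameters theta = exp(log_theta).

    elbo_fn    : callable theta -> ELBO value (closed form, Sec. 3-1-1),
                 variational parameters held fixed (hypothetical helper)
    log_theta0 : initial log-hyperparameters (log keeps them positive)
    """
    res = minimize(lambda lt: -elbo_fn(np.exp(lt)), log_theta0, method="L-BFGS-B")
    return np.exp(res.x), -res.fun   # optimized theta and the achieved ELBO
```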


3-2. Non-parametric VI for GPRN

approximate the posterior of the GPRN, using a mixture of \(K\) isotropic Gaussians

\(q(\mathbf{u})=\frac{1}{K} \sum_{k=1}^{K} q^{(k)}(\mathbf{u})=\frac{1}{K} \sum_{k=1}^{K} \mathcal{N}\left(\mathbf{u} ; \boldsymbol{\mu}^{(k)}, \sigma_{k}^{2} \mathbf{I}\right)\).

  • in practice, \(K\) is very small, so only \(O(N)\) variational parameters are needed
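
As a small sketch, evaluating the log-density of this mixture approximation (the dimensionality of \(\mathbf{u}\) and \(K=3\) below are arbitrary toy values):

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import multivariate_normal

def npv_logpdf(u, mus, sigmas):
    """log q(u) for q(u) = (1/K) sum_k N(u; mu^(k), sigma_k^2 I)."""
    K, dim = mus.shape
    comps = [
        multivariate_normal.logpdf(u, mean=mus[k], cov=sigmas[k] ** 2 * np.eye(dim))
        for k in range(K)
    ]
    return logsumexp(comps) - np.log(K)

# toy example: in the GPRN, u stacks all latent values and weights (dim = NQ + NPQ)
rng = np.random.default_rng(1)
dim, K = 10, 3
mus = rng.standard_normal((K, dim))   # component means mu^(k)
sigmas = np.full(K, 0.5)              # component scales sigma_k
print(npv_logpdf(np.zeros(dim), mus, sigmas))
```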


3-2-1. Closed-form ELBO

With the mixture \(q(\mathbf{u})=\frac{1}{K} \sum_{k=1}^{K} q^{(k)}(\mathbf{u})=\frac{1}{K} \sum_{k=1}^{K} \mathcal{N}\left(\mathbf{u} ; \boldsymbol{\mu}^{(k)}, \sigma_{k}^{2} \mathbf{I}\right)\), the ELBO cannot be computed analytically.

\(\rightarrow\) need approximation


The expectations decompose as follows (using mean-field) :

(1) First term

\(\begin{aligned} \mathbb{E}_{q}[\log p(\mathcal{D} \mid \mathbf{f}, \mathbf{w})] =&-\frac{1}{2 K \sigma_{y}^{2}} \sum_{k} \sum_{n}\left(\mathbf{Y}_{\cdot n}^{T}-\Omega_{\mathrm{w}_{\mathrm{n}}}^{(k)} \nu_{\mathrm{f}_{\mathrm{n}}}^{(k)}\right)^{T}\left(\mathbf{Y}_{\cdot n}^{T}-\Omega_{\mathrm{w}_{\mathrm{n}}}^{(k)} \nu_{\mathrm{f}_{\mathrm{n}}}^{(k)}\right) \\ &-\frac{1}{2 K}\left(\sum_{k, j} \frac{P \sigma_{k}^{2}}{\sigma_{y}^{2}} \mu_{\mathrm{f}_{j}}^{(k)^{T}} \mu_{\mathrm{f}_{j}}^{(k)}+\sum_{k, i, j} \frac{P \sigma_{k}^{2}}{\sigma_{y}^{2}} \mu_{\mathrm{w}_{\mathrm{ij}}}^{(k)^{T}} \mu_{\mathrm{w}_{\mathrm{ij}}}^{(k)}\right) \\ &-\frac{1}{2 K}\left(\sum_{k} \frac{\sigma_{k}^{4}}{\sigma_{y}^{2}} N P Q+N P \log \left(2 \pi \sigma_{y}^{2}\right)\right) \end{aligned}\).


(2) Second term

\(\begin{aligned} \mathbb{E}_{q}[\log p(\mathbf{f}, \mathbf{w})] =&-\frac{1}{2}\left(Q \log \left|\mathbf{K}_{f}\right|+P Q \log \left|\mathbf{K}_{w}\right|\right) \\ &-\frac{1}{2 K}\left[\sum_{k, j} \boldsymbol{\mu}_{\mathrm{f}_{j}}^{(k)^{T}} \mathbf{K}_{f}^{-1} \boldsymbol{\mu}_{\mathrm{f}_{j}}^{(k)}+\sigma_{k}^{2} \operatorname{tr}\left(\mathbf{K}_{f}^{-1}\right)\right. \\ &\left.+\sum_{k, i, j} \boldsymbol{\mu}_{\mathrm{w}_{ij}}^{(k)^{T}} \mathbf{K}_{w}^{-1} \boldsymbol{\mu}_{\mathrm{w}_{ij}}^{(k)}+\sigma_{k}^{2} \operatorname{tr}\left(\mathbf{K}_{w}^{-1}\right)\right] \end{aligned}\).


(3) Third term

\(\mathcal{H}_{q}[q(\mathbf{u})] \geq -\frac{1}{K} \sum_{k=1}^{K} \log \frac{1}{K} \sum_{j=1}^{K} \mathcal{N}\left(\boldsymbol{\mu}^{(k)} ; \boldsymbol{\mu}^{(j)},\left(\sigma_{k}^{2}+\sigma_{j}^{2}\right) \mathbf{I}\right)\).
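
This bound only requires pairwise Gaussian evaluations between the component means, so it stays cheap for small \(K\). A NumPy sketch (component parameters as in the toy example above):

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import multivariate_normal

def npv_entropy_lower_bound(mus, sigmas):
    """Lower bound on the entropy of a mixture of isotropic Gaussians:
    H[q] >= -(1/K) sum_k log[(1/K) sum_j N(mu^(k); mu^(j), (sigma_k^2 + sigma_j^2) I)].
    """
    K, dim = mus.shape
    bound = 0.0
    for k in range(K):
        log_terms = [
            multivariate_normal.logpdf(
                mus[k], mean=mus[j],
                cov=(sigmas[k] ** 2 + sigmas[j] ** 2) * np.eye(dim))
            for j in range(K)
        ]
        bound -= (logsumexp(log_terms) - np.log(K)) / K
    return bound
```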


(1)~(3) together define a tight lower bound on the ELBO.


3-2-2. Optimization of Variational Parameters and Hyper-parameters

Variational params \(\left\{\boldsymbol{\mu}_{\mathrm{f}_{j}}^{(k)}, \boldsymbol{\mu}_{\mathrm{w}_{\mathrm{ij}}}^{(k)}\right\}\) & hyperparameters \(\boldsymbol{\theta}\) are optimized by gradient-based maximization of the ELBO.


3-3. Predictive Distribution

For non-parametric VI, the predictive mean turns out to be…

\(\mathbb{E}\left[\mathbf{y}^{*} \mid \mathbf{x}^{*}, \mathcal{D}\right]=\frac{1}{K} \sum_{k=1}^{K} \mathbf{K}_{w}^{*} \mathbf{K}_{w}^{-1} \boldsymbol{\mu}_{\mathbf{w}}^{(k)} \mathbf{K}_{f}^{*} \mathbf{K}_{f}^{-1} \boldsymbol{\mu}_{\mathbf{f}}^{(k)}\).
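
A sketch of computing this predictive mean at a single test point \(\mathbf{x}^{*}\): here `kf_star` and `kw_star` denote the cross-covariance vectors \(\mathbf{K}_{f}^{*}\) and \(\mathbf{K}_{w}^{*}\) between the test and training inputs, and reshaping \(\boldsymbol{\mu}_{\mathbf{w}}^{(k)}\) into a \(P \times Q \times N\) array is my own bookkeeping:

```python
import numpy as np

def npv_predictive_mean(kf_star, kw_star, Kf, Kw, mu_f, mu_w):
    """E[y* | x*, D] = (1/K) sum_k W*(k) f*(k), via GP-mean projections.

    kf_star, kw_star : (N,) cross-covariances k_f(x*, X) and k_w(x*, X)
    mu_f : (K, Q, N)    component means of the latent functions
    mu_w : (K, P, Q, N) component means of the mixing weights
    """
    K = mu_f.shape[0]
    P = mu_w.shape[1]
    alpha_f = np.linalg.solve(Kf, kf_star)   # K_f^{-1} k_f(x*, X)
    alpha_w = np.linalg.solve(Kw, kw_star)   # K_w^{-1} k_w(x*, X)

    y_star = np.zeros(P)
    for k in range(K):
        f_star = mu_f[k] @ alpha_f    # (Q,)   GP predictive mean of f at x*
        W_star = mu_w[k] @ alpha_w    # (P, Q) GP predictive mean of W at x*
        y_star += W_star @ f_star
    return y_star / K
```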
