Neural Relational Inference for Interacting Systems (2018)

Contents

  0. Abstract
  1. NRI Model
    (1) Encoder
    (2) Sampling
    (3) Decoder
    (4) Avoiding degenerate decoders
    (5) Recurrent decoder
    (6) Training


0. Abstract

NRI ( Neural Relational Inference )

  • learns to infer the interactions,
  • while simultaneously learning the dynamics, purely from observational data

formulated as a VAE ( variational autoencoder )


1. NRI Model

consists of two parts, trained jointly :

  • ENCODER : predicts the interactions, given the trajectories
  • DECODER : learns the dynamical model given the interaction graph


Notation

  • the input consists of trajectories of \(N\) objects
  • \(\mathbf{x}=\left(\mathbf{x}^{1}, \ldots, \mathbf{x}^{T}\right)\) : observations over \(T\) time steps
    • \(\mathbf{x}^{t}=\left\{\mathbf{x}_{1}^{t}, \ldots, \mathbf{x}_{N}^{t}\right\}\) : features of all \(N\) objects at time \(t\)
    • \(\mathbf{x}_{i}=\left(\mathbf{x}_{i}^{1}, \ldots, \mathbf{x}_{i}^{T}\right)\) : trajectory of object \(i\)


model the dynamics with GNN, given an unknown graph \(\mathbf{z}\)

  • \(\mathbf{z}_{ij}\) : discrete edge type between node \(i\) & node \(j\) ( e.g. 0 & 1 )


Task : simultaneously learn to PREDICT THE EDGE TYPES & learn the DYNAMICAL MODEL, in an unsupervised way


Model : VAE

  • Obj function ( = ELBO ) : \(\mathcal{L}=\mathbb{E}_{q_{\phi}(\mathbf{z} \mid \mathbf{x})}\left[\log p_{\theta}(\mathbf{x} \mid \mathbf{z})\right]-\mathrm{KL}\left[q_{\phi}(\mathbf{z} \mid \mathbf{x}) \mid \mid p_{\theta}(\mathbf{z})\right]\)


( Figure 2 of the paper : overview of the NRI model — the encoder infers edge types from the trajectories, the decoder predicts the dynamics given the sampled graph )


Encoder : \(q_{\phi}(\mathbf{z}\mid \mathbf{x})\)

  • returns a factorized distribution over the \(\mathbf{z}_{ij}\)

    ( each \(\mathbf{z}_{ij}\) : discrete categorical variable )


Decoder : \(p_{\theta}(\mathbf{x} \mid \mathbf{z})=\prod_{t=1}^{T} p_{\theta}\left(\mathbf{x}^{t+1} \mid \mathbf{x}^{t}, \ldots, \mathbf{x}^{1}, \mathbf{z}\right)\)

  • models \(p_{\theta}\left(\mathbf{x}^{t+1} \mid \mathbf{x}^{t}, \ldots, \mathbf{x}^{1}, \mathbf{z}\right)\) with a GNN given the latent graph structure \(\mathbf{z}\).

  • prior : \(p_{\theta}(\mathbf{z})=\prod_{i \neq j} p_{\theta}\left(\mathbf{z}_{i j}\right)\)

    • factorized uniform distribution


Differences from the standard VAE

  • (1) the decoder is trained to predict MULTIPLE TIME STEPS

  • (2) the latent distn is DISCRETE

    ( need a reparameterization trick for categorical variables )


(1) Encoder

Goal of the encoder : infer the pairwise interaction types \(\mathbf{z}_{ij}\)

  • given observed trajectories \(\mathbf{x}=\) \(\left(\mathbf{x}^{1}, \ldots, \mathbf{x}^{T}\right)\)

\(\begin{aligned} \mathbf{h}_{j}^{1} &= f_{\mathrm{emb}}\left(\mathbf{x}_{j}\right) \\ v \rightarrow e: \quad \mathbf{h}_{(i, j)}^{1} &= f_{e}^{1}\left(\left[\mathbf{h}_{i}^{1}, \mathbf{h}_{j}^{1}\right]\right) \\ e \rightarrow v: \quad \mathbf{h}_{j}^{2} &= f_{v}^{1}\left(\sum_{i \neq j} \mathbf{h}_{(i, j)}^{1}\right) \\ v \rightarrow e: \quad \mathbf{h}_{(i, j)}^{2} &= f_{e}^{2}\left(\left[\mathbf{h}_{i}^{2}, \mathbf{h}_{j}^{2}\right]\right) \end{aligned}\)

( \(f_{\mathrm{emb}}\) embeds the full trajectory \(\mathbf{x}_{j}\) of each object ; the \(f\)'s are neural networks, e.g. MLPs )


Edge type posterior :

  • \(q_{\phi}\left(\mathbf{z}_{i j} \mid \mathbf{x}\right)=\operatorname{softmax}\left(\mathbf{h}_{(i, j)}^{2}\right)\).
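A minimal PyTorch sketch of this two-round message passing ( a sketch under my own naming & shape conventions, not the paper's reference implementation ) :

```python
import torch
import torch.nn as nn

class NRIEncoder(nn.Module):
    """Two rounds of node->edge->node message passing over the fully connected graph."""
    def __init__(self, n_in, n_hid, n_edge_types):
        super().__init__()
        self.f_emb = nn.Sequential(nn.Linear(n_in, n_hid), nn.ReLU())
        self.f_e1 = nn.Sequential(nn.Linear(2 * n_hid, n_hid), nn.ReLU())
        self.f_v1 = nn.Sequential(nn.Linear(n_hid, n_hid), nn.ReLU())
        self.f_e2 = nn.Linear(2 * n_hid, n_edge_types)  # edge-type logits h^2_(i,j)

    def forward(self, x):
        # x: [B, N, T * n_feat] -- each object's full trajectory, flattened over time
        B, N, _ = x.shape
        h1 = self.f_emb(x)                                 # h_j^1 = f_emb(x_j)
        h1_i = h1.unsqueeze(2).expand(B, N, N, -1)         # sender embedding
        h1_j = h1.unsqueeze(1).expand(B, N, N, -1)         # receiver embedding
        e1 = self.f_e1(torch.cat([h1_i, h1_j], dim=-1))    # v -> e: h^1_(i,j)
        off_diag = 1.0 - torch.eye(N, device=x.device).view(1, N, N, 1)
        h2 = self.f_v1((e1 * off_diag).sum(dim=1))         # e -> v: sum over i != j
        h2_i = h2.unsqueeze(2).expand(B, N, N, -1)
        h2_j = h2.unsqueeze(1).expand(B, N, N, -1)
        return self.f_e2(torch.cat([h2_i, h2_j], dim=-1))  # v -> e: logits [B, N, N, K]
```

A softmax over the last dimension of these logits is exactly the edge-type posterior \(q_{\phi}(\mathbf{z}_{ij} \mid \mathbf{x})\).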


(2) Sampling

Sample from \(q_{\phi}\left(\mathbf{z}_{i j} \mid \mathbf{x}\right)\) ?

  • but…. the latent variable is DISCRETE → cannot backprop through the sampling

\(\rightarrow\) use the Gumbel-softmax trick : \(\mathbf{z}_{i j}=\operatorname{softmax}\left(\left(\mathbf{h}_{(i, j)}^{2}+\mathbf{g}\right) / \tau\right)\)

  • \(\mathbf{g}\) : vector of i.i.d. Gumbel(0, 1) samples
  • \(\tau\) : temperature ( as \(\tau \rightarrow 0\), the samples become one-hot )
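A direct translation of the formula into PyTorch ( a sketch ; `torch.nn.functional.gumbel_softmax` implements the same trick, with an optional straight-through hard mode ) :

```python
import torch
import torch.nn.functional as F

def sample_edges(logits, tau=0.5):
    """z_ij = softmax((h^2_(i,j) + g) / tau), with g i.i.d. Gumbel(0,1) noise."""
    # Gumbel(0,1) samples via inverse-CDF: g = -log(-log(u)), u ~ Uniform(0,1)
    u = torch.rand_like(logits).clamp_min(1e-10)
    g = -torch.log(-torch.log(u))
    return F.softmax((logits + g) / tau, dim=-1)  # low tau -> nearly one-hot samples
```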


(3) Decoder

Task of decoder :

  • predict the dynamics,
  • i.e. model \(p_{\theta}\left(\mathbf{x}^{t+1} \mid \mathbf{x}^{t}, \ldots, \mathbf{x}^{1}, \mathbf{z}\right)\).


\(\begin{aligned} v \rightarrow e: \quad \tilde{\mathbf{h}}_{(i, j)}^{t} &=\sum_{k} z_{i j, k} \tilde{f}_{e}^{k}\left(\left[\mathbf{x}_{i}^{t}, \mathbf{x}_{j}^{t}\right]\right) \\ e \rightarrow v: \quad \boldsymbol{\mu}_{j}^{t+1} &=\mathbf{x}_{j}^{t}+\tilde{f}_{v}\left(\sum_{i \neq j} \tilde{\mathbf{h}}_{(i, j)}^{t}\right) \\ p\left(\mathbf{x}_{j}^{t+1} \mid \mathbf{x}^{t}, \mathbf{z}\right) &=\mathcal{N}\left(\boldsymbol{\mu}_{j}^{t+1}, \sigma^{2} \mathbf{I}\right) \end{aligned}\).


Since we add the present state \(\mathbf{x}_{j}^{t}\), the model only has to learn the change in state \(\Delta \mathbf{x}_{j}^{t}\).
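A sketch of one decoder step matching the equations above, with edge-type-specific MLPs and the skip connection ( names & shapes follow the encoder sketch, not the paper's code ) :

```python
import torch
import torch.nn as nn

class NRIDecoderStep(nn.Module):
    """One step of the decoder: edge-type-specific messages, then predict Delta x."""
    def __init__(self, n_feat, n_hid, n_edge_types):
        super().__init__()
        # One message MLP per edge type -> the decoder cannot silently ignore z
        self.f_e = nn.ModuleList([
            nn.Sequential(nn.Linear(2 * n_feat, n_hid), nn.ReLU())
            for _ in range(n_edge_types)])
        self.f_v = nn.Sequential(nn.Linear(n_hid, n_hid), nn.ReLU(),
                                 nn.Linear(n_hid, n_feat))

    def forward(self, x_t, z):
        # x_t: [B, N, n_feat] current states; z: [B, N, N, K] sampled edge types
        B, N, _ = x_t.shape
        x_i = x_t.unsqueeze(2).expand(B, N, N, -1)
        x_j = x_t.unsqueeze(1).expand(B, N, N, -1)
        pair = torch.cat([x_i, x_j], dim=-1)
        # v -> e: weight each edge-type MLP by z_{ij,k}
        msg = sum(z[..., k:k + 1] * f(pair) for k, f in enumerate(self.f_e))
        # e -> v: aggregate incoming messages over senders i != j
        off_diag = 1.0 - torch.eye(N, device=x_t.device).view(1, N, N, 1)
        agg = (msg * off_diag).sum(dim=1)
        return x_t + self.f_v(agg)  # mu^{t+1} = x^t + Delta x: only the change is learned
```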


(4) Avoiding degenerate decoders

ELBO :

  • (1) reconstruction loss term : \(\sum_{t=2}^{T} \log \left[p\left(\mathbf{x}^{t} \mid \mathbf{x}^{t-1}, \mathbf{z}\right)\right]\)

    \(\rightarrow\) only involves SINGLE-step predictions ; interactions often have only a small effect on one-step dynamics, so the decoder could learn to ignore \(\mathbf{z}\)


How to deal with this?

  • (1) predict MULTIPLE steps ( see below )

    \(\rightarrow\) a degenerate decoder ( which ignores \(\mathbf{z}\) ) would perform much worse

  • (2) have a separate MLP for each edge type

    \(\rightarrow\) makes the dependence on the edge type more explicit & harder for the model to ignore


Predicting MULTIPLE steps

  • rolling forecast : feed in the ground-truth state every \(M\) steps, and the model's own prediction \(\boldsymbol{\mu}\) in between

\(\begin{aligned} \boldsymbol{\mu}_{j}^{2} &= f_{\mathrm{dec}}\left(\mathbf{x}_{j}^{1}\right) \\ \boldsymbol{\mu}_{j}^{t+1} &= f_{\mathrm{dec}}\left(\boldsymbol{\mu}_{j}^{t}\right) \quad t=2, \ldots, M \\ \boldsymbol{\mu}_{j}^{M+2} &= f_{\mathrm{dec}}\left(\mathbf{x}_{j}^{M+1}\right) \\ \boldsymbol{\mu}_{j}^{t+1} &= f_{\mathrm{dec}}\left(\boldsymbol{\mu}_{j}^{t}\right) \quad t=M+2, \ldots, 2M \\ & \ldots \end{aligned}\)
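As a sketch, the rolling forecast amounts to switching the input between the ground truth and the model's own prediction ( `decoder_step` is the single-step model sketched above ) :

```python
import torch

def multistep_predict(decoder_step, x, z, M=10):
    """Rolling forecast: ground truth is fed in every M steps, predictions in between.

    x: [B, T, N, n_feat] ground-truth trajectory; returns mu^2, ..., mu^T.
    """
    preds = []
    mu = x[:, 0]
    for t in range(x.shape[1] - 1):
        inp = x[:, t] if t % M == 0 else mu  # re-ground on the true state every M steps
        mu = decoder_step(inp, z)
        preds.append(mu)
    return torch.stack(preds, dim=1)         # [B, T-1, N, n_feat]
```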


(5) Recurrent decoder

use a GRU to model \(p_{\theta}\left(\mathbf{x}^{t+1} \mid \mathbf{x}^{t}, \ldots, \mathbf{x}^{1}, \mathbf{z}\right)\), so that predictions can depend on the full history ( useful when the dynamics are not Markovian ).
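In the paper the recurrent decoder still performs GNN message passing, now over hidden states ; the sketch below shows only the recurrent skeleton ( one GRU state per object + the same skip connection ), as a minimal illustration :

```python
import torch
import torch.nn as nn

class RecurrentDecoderSkeleton(nn.Module):
    """Per-object GRU state, so the prediction can depend on x^t, ..., x^1."""
    def __init__(self, n_feat, n_hid):
        super().__init__()
        self.gru = nn.GRUCell(n_feat, n_hid)  # in the paper, the input would also
        self.out = nn.Linear(n_hid, n_feat)   # include the aggregated edge messages

    def forward(self, x_t, h):
        # x_t: [B * N, n_feat] states; h: [B * N, n_hid] recurrent state per object
        h = self.gru(x_t, h)
        return x_t + self.out(h), h           # still predict only the change in state
```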


(6) Training

Process

  • (1) run the encoder & get \(q_{\phi}\left(\mathbf{z}_{i j} \mid \mathbf{x}\right)\)
  • (2) sample \(\mathbf{z}_{i j}\) from \(q_{\phi}\left(\mathbf{z}_{i j} \mid \mathbf{x}\right)\)
  • (3) run the decoder to compute \(\boldsymbol{\mu}^{2}, \ldots, \boldsymbol{\mu}^{T}\)
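Wiring the sketches above together, one forward pass looks like this ( all sizes illustrative ) :

```python
import torch

B, T, N, F_DIM, H, K = 8, 50, 5, 4, 64, 2  # batch, time steps, objects, features,
x = torch.randn(B, T, N, F_DIM)            # hidden size, edge types (all illustrative)

encoder = NRIEncoder(T * F_DIM, H, K)
decoder = NRIDecoderStep(F_DIM, H, K)

logits = encoder(x.permute(0, 2, 1, 3).reshape(B, N, T * F_DIM))  # (1) q_phi(z_ij | x)
z = sample_edges(logits, tau=0.5)                                 # (2) Gumbel-softmax sample
mu = multistep_predict(decoder, x, z, M=10)                       # (3) mu^2, ..., mu^T
```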


ELBO objective

  • (1) reconstruction loss : \(\mathbb{E}_{q_{\phi}(\mathbf{z} \mid \mathbf{x})}\left[\log p_{\theta}(\mathbf{x} \mid \mathbf{z})\right]\)
    • \(-\sum_{j} \sum_{t=2}^{T} \frac{\left\|\mathbf{x}_{j}^{t}-\boldsymbol{\mu}_{j}^{t}\right\|^{2}}{2 \sigma^{2}}+\text { const }\).
  • (2) KL-divergence term : \(\operatorname{KL}\left[q_{\phi}(\mathbf{z} \mid \mathbf{x}) \| p_{\theta}(\mathbf{z})\right]\)
    • since the prior is uniform : \(-\sum_{i \neq j} H\left(q_{\phi}\left(\mathbf{z}_{i j} \mid \mathbf{x}\right)\right)+\text { const}\) ( the negative sum of edge-posterior entropies )
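Putting the two terms together as a training loss ( a sketch ; the fixed variance \(\sigma^2\) is a hyperparameter, the value below is only illustrative ) :

```python
import math
import torch
import torch.nn.functional as F

def nri_loss(x, mu, edge_logits, sigma2=5e-5):
    """Negative ELBO for the sketches above.

    x: [B, T, N, n_feat], mu: [B, T-1, N, n_feat] (predictions for t = 2..T),
    edge_logits: [B, N, N, K] encoder output.
    """
    B, _, N, _ = x.shape
    K = edge_logits.shape[-1]
    # (1) Gaussian reconstruction NLL with fixed variance sigma^2 (constants dropped)
    nll = ((x[:, 1:] - mu) ** 2).sum() / (2 * sigma2)
    # (2) KL to the factorized uniform prior: per edge, KL(q || U) = log K - H(q)
    log_q = F.log_softmax(edge_logits, dim=-1)
    off_diag = (1.0 - torch.eye(N, device=x.device)).view(1, N, N, 1)
    neg_entropy = (log_q.exp() * log_q * off_diag).sum()
    kl = neg_entropy + B * N * (N - 1) * math.log(K)
    return (nll + kl) / B  # average over the batch

# usage: loss = nri_loss(x, mu, logits); loss.backward()
```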

