Neural Relational Inference for Interacting Systems (2018)
Contents
- Abstract
- NRI Model
- Encoder
- Sampling
- Decoder
- Avoiding degenerate decoders
- Recurrent decoder
- Training
0. Abstract
NRI ( Neural Relational Inference )
- learns to infer interactions,
- while simultaneously learning the dynamics from observational data
takes the form of a VAE
1. NRI Model
consists of 2 parts trained jointly
- ENCODER : predicts the interactions, given the trajectories
- DECODER : learns the dynamical model given the interaction graph
Notation
- data consists of trajectories of \(N\) objects
- \(\mathbf{x}=\left(\mathbf{x}^{1}, \ldots, \mathbf{x}^{T}\right)\) : all trajectories
- \(\mathbf{x}^{t}=\left\{\mathbf{x}_{1}^{t}, \ldots, \mathbf{x}_{N}^{t}\right\}\) : features of all \(N\) objects at time \(t\)
- \(\mathbf{x}_{i}=\left(\mathbf{x}_{i}^{1}, \ldots, \mathbf{x}_{i}^{T}\right)\) : trajectory of object \(i\)
model the dynamics with GNN, given an unknown graph \(\mathbf{z}\)
- \(\mathbf{z}_{ij}\) : discrete edge type between node \(i\) & node \(j\) ( ex. 0 & 1 )
Task : simultaneously learn to PREDICT THE EDGE TYPES & learn the DYNAMICAL MODEL in an unsupervised way
Model : VAE
- Obj function ( = ELBO ) : \(\mathcal{L}=\mathbb{E}_{q_{\phi}(\mathbf{z} \mid \mathbf{x})}\left[\log p_{\theta}(\mathbf{x} \mid \mathbf{z})\right]-\mathrm{KL}\left[q_{\phi}(\mathbf{z} \mid \mathbf{x}) \mid \mid p_{\theta}(\mathbf{z})\right]\)
Encoder : \(q_{\phi}(\mathbf{z} \mid \mathbf{x})\)
- returns a factorized distn over \(\mathbf{z}_{ij}\)
( \(\mathbf{z}_{ij}\) : discrete categorical variable )
Decoder : \(p_{\theta}(\mathbf{x} \mid \mathbf{z})=\prod_{t=1}^{T} p_{\theta}\left(\mathbf{x}^{t+1} \mid \mathbf{x}^{t}, \ldots, \mathbf{x}^{1}, \mathbf{z}\right)\)
- models \(p_{\theta}\left(\mathbf{x}^{t+1} \mid \mathbf{x}^{t}, \ldots, \mathbf{x}^{1}, \mathbf{z}\right)\) with a GNN, given the latent graph structure \(\mathbf{z}\)
- prior : \(p_{\theta}(\mathbf{z})=\prod_{i \neq j} p_{\theta}\left(\mathbf{z}_{i j}\right)\)
- factorized uniform distribution
Differences from a standard VAE
- (1) the decoder is trained to predict MULTIPLE TIME STEPS
- (2) the latent distn is DISCRETE
( need a reparameterization trick for categorical variables \(\rightarrow\) Gumbel softmax )
(1) Encoder
Goal of encoder : infer pair-wise interaction types \(\mathbf{z_{ij}}\)
- given observed trajectories \(\mathbf{x}=\) \(\left(\mathbf{x}^{1}, \ldots, \mathbf{x}^{T}\right)\)
\(\begin{aligned} \mathbf{h}_{j}^{1} &=f_{\mathrm{emb}}\left(\mathbf{x}_{j}\right) \\ v \rightarrow e: \quad \mathbf{h}_{(i, j)}^{1} &=f_{e}^{1}\left(\left[\mathbf{h}_{i}^{1}, \mathbf{h}_{j}^{1}\right]\right) \\ e \rightarrow v: \quad \mathbf{h}_{j}^{2} &=f_{v}^{1}\left(\sum_{i \neq j} \mathbf{h}_{(i, j)}^{1}\right) \\ v \rightarrow e: \quad \mathbf{h}_{(i, j)}^{2} &=f_{e}^{2}\left(\left[\mathbf{h}_{i}^{2}, \mathbf{h}_{j}^{2}\right]\right) \end{aligned}\).
Edge type posterior :
- \(q_{\phi}\left(\mathbf{z}_{i j} \mid \mathbf{x}\right)=\operatorname{softmax}\left(\mathbf{h}_{(i, j)}^{2}\right)\).
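A minimal PyTorch sketch of this encoder pass. Layer sizes, module names, and the input flattening are illustrative choices of mine, not the paper's exact architecture, and for brevity the aggregation sums over all senders including \(i=j\) (the paper excludes self-edges):

```python
import torch
import torch.nn as nn

class NRIEncoder(nn.Module):
    """Two rounds of node->edge / edge->node message passing over the
    fully connected graph, ending in edge-type logits h^2_(i,j)."""

    def __init__(self, n_in, n_hid, n_edge_types):
        super().__init__()
        self.f_emb = nn.Sequential(nn.Linear(n_in, n_hid), nn.ReLU())
        self.f_e1 = nn.Sequential(nn.Linear(2 * n_hid, n_hid), nn.ReLU())
        self.f_v1 = nn.Sequential(nn.Linear(n_hid, n_hid), nn.ReLU())
        self.f_e2 = nn.Linear(2 * n_hid, n_edge_types)

    def forward(self, x):
        # x: [B, N, T * D] -- each node's whole trajectory, flattened
        h = self.f_emb(x)                                   # h_j^1
        N = h.size(1)
        # v -> e : concatenate the two endpoint embeddings for every pair
        h_i = h.unsqueeze(2).expand(-1, -1, N, -1)          # sender i
        h_j = h.unsqueeze(1).expand(-1, N, -1, -1)          # receiver j
        h_e1 = self.f_e1(torch.cat([h_i, h_j], dim=-1))     # h_(i,j)^1
        # e -> v : aggregate incoming messages (summed over all i here,
        # including i = j for brevity; the paper excludes self-edges)
        h_v2 = self.f_v1(h_e1.sum(dim=1))                   # h_j^2
        # v -> e : one more pass, then read out the edge-type logits
        h_i2 = h_v2.unsqueeze(2).expand(-1, -1, N, -1)
        h_j2 = h_v2.unsqueeze(1).expand(-1, N, -1, -1)
        return self.f_e2(torch.cat([h_i2, h_j2], dim=-1))   # [B, N, N, K]
```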
(2) Sampling
Sample from \(q_{\phi}\left(\mathbf{z}_{i j} \mid \mathbf{x}\right)\) ?
- but… \(\mathbf{z}_{ij}\) is a DISCRETE latent variable ( the standard reparameterization trick doesn't apply )
\(\rightarrow\) use Gumbel softmax trick : \(\mathbf{z}_{i j}=\operatorname{softmax}\left(\left(\mathbf{h}_{(i, j)}^{2}+\mathbf{g}\right) / \tau\right)\).
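In PyTorch this is a one-liner via the built-in `torch.nn.functional.gumbel_softmax`; the temperature value below is just an example:

```python
import torch.nn.functional as F

# logits : edge-type logits h^2_(i,j) from the encoder, shape [B, N, N, K]
tau = 0.5  # softmax temperature (example value; tuned in practice)

# differentiable "soft" sample used during training
z = F.gumbel_softmax(logits, tau=tau, hard=False)

# discrete one-hot sample with straight-through gradients
z_hard = F.gumbel_softmax(logits, tau=tau, hard=True)
```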
(3) Decoder
Task of decoder :
- predict dynamics
- predict \(p_{\theta}\left(\mathbf{x}^{t+1} \mid \mathbf{x}^{t}, \ldots, \mathbf{x}^{1}, \mathbf{z}\right)\).
\(\begin{aligned} v \rightarrow e: \quad \tilde{\mathbf{h}}_{(i, j)}^{t} &=\sum_{k} z_{i j, k} \tilde{f}_{e}^{k}\left(\left[\mathbf{x}_{i}^{t}, \mathbf{x}_{j}^{t}\right]\right) \\ e \rightarrow v: \quad \boldsymbol{\mu}_{j}^{t+1} &=\mathbf{x}_{j}^{t}+\tilde{f}_{v}\left(\sum_{i \neq j} \tilde{\mathbf{h}}_{(i, j)}^{t}\right) \\ p\left(\mathbf{x}_{j}^{t+1} \mid \mathbf{x}^{t}, \mathbf{z}\right) &=\mathcal{N}\left(\boldsymbol{\mu}_{j}^{t+1}, \sigma^{2} \mathbf{I}\right) \end{aligned}\).
Since we add the present state \(\mathbf{x}_{j}^{t}\), the model only has to learn the change in state \(\Delta \mathbf{x}_{j}^{t}\).
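A sketch of one step of this decoder, reusing the pairwise-concatenation pattern from the encoder sketch above; again, hidden sizes and names are illustrative:

```python
import torch
import torch.nn as nn

class NRIDecoderStep(nn.Module):
    """One step of the Markovian decoder: a separate edge MLP per edge
    type, weighted by z_ij, then a node MLP that outputs the state change."""

    def __init__(self, n_dim, n_hid, n_edge_types):
        super().__init__()
        self.edge_mlps = nn.ModuleList([
            nn.Sequential(nn.Linear(2 * n_dim, n_hid), nn.ReLU())
            for _ in range(n_edge_types)])
        self.f_v = nn.Sequential(nn.Linear(n_hid, n_hid), nn.ReLU(),
                                 nn.Linear(n_hid, n_dim))

    def forward(self, x_t, z):
        # x_t: [B, N, D] current states; z: [B, N, N, K] edge-type weights
        N = x_t.size(1)
        x_i = x_t.unsqueeze(2).expand(-1, -1, N, -1)
        x_j = x_t.unsqueeze(1).expand(-1, N, -1, -1)
        pair = torch.cat([x_i, x_j], dim=-1)                # [x_i^t, x_j^t]
        # h~_(i,j)^t = sum_k z_ij,k * f~_e^k([x_i^t, x_j^t])
        msg = sum(z[..., k:k + 1] * mlp(pair)
                  for k, mlp in enumerate(self.edge_mlps))
        agg = msg.sum(dim=1)                                # sum over senders i
        # mu_j^{t+1} = x_j^t + f~_v(...) : only the *change* is learned
        return x_t + self.f_v(agg)
```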
(4) Avoiding degenerate decoders
ELBO :
- (1) reconstruction loss term : \(\sum_{t=2}^{T} \log p\left(\mathbf{x}^{t} \mid \mathbf{x}^{t-1}, \mathbf{z}\right)\)
\(\rightarrow\) only involves SINGLE-step predictions
( one step of the dynamics is often easy to predict even while ignoring \(\mathbf{z}\) \(\rightarrow\) risk of a "degenerate" decoder )
How to deal with this?
- (1) predict MULTIPLE steps
\(\rightarrow\) a degenerate decoder would perform much worse
- (2) have a separate MLP for each edge type
\(\rightarrow\) makes the dependence on the edge type more explicit & harder for the model to ignore
Predicting MULTI steps
- rolling forecast
\(\begin{aligned} \boldsymbol{\mu}_{j}^{2} &=f_{\mathrm{dec}}\left(\mathbf{x}_{j}^{1}\right) \\ \boldsymbol{\mu}_{j}^{t+1} &=f_{\mathrm{dec}}\left(\boldsymbol{\mu}_{j}^{t}\right) \quad t=2, \ldots, M \\ \boldsymbol{\mu}_{j}^{M+2} &=f_{\mathrm{dec}}\left(\mathbf{x}_{j}^{M+1}\right) \\ \boldsymbol{\mu}_{j}^{t+1} &=f_{\mathrm{dec}}\left(\boldsymbol{\mu}_{j}^{t}\right) \quad t=M+2, \ldots, 2M \\ & \ldots \end{aligned}\).
( feed in the ground-truth state every \(M\) steps; in between, feed the model's own predictions back in )
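A sketch of this schedule, assuming a `decoder_step(x_t, z)` callable like the module above (\(M=10\) here is just an example value):

```python
import torch

def rollout(decoder_step, x, z, M=10):
    """Rolling forecast: restart from the ground-truth state every M steps,
    otherwise feed the model's own prediction back in.
    x: [B, T, N, D] ground truth; returns predicted means mu^2..mu^T."""
    T = x.size(1)
    preds = []
    last = x[:, 0]
    for t in range(T - 1):
        if t % M == 0:
            last = x[:, t]             # re-feed the ground-truth state
        last = decoder_step(last, z)   # one decoder step: mu^{t+2} = f_dec(.)
        preds.append(last)
    return torch.stack(preds, dim=1)   # [B, T-1, N, D]
```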
(5) Recurrent decoder
use a GRU to model \(p_{\theta}\left(\mathbf{x}^{t+1} \mid \mathbf{x}^{t}, \ldots, \mathbf{x}^{1}, \mathbf{z}\right)\) ( for dynamics that are not Markovian ).
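A simplified sketch of one step of such a GRU-based decoder. The paper's recurrent decoder has additional details that are omitted here; this only conveys the structure: messages computed from hidden states \(\rightarrow\) per-node GRU update \(\rightarrow\) predict the state change:

```python
import torch
import torch.nn as nn

class RecurrentDecoderStep(nn.Module):
    """One step of a GRU-based decoder: edge messages come from the
    per-node hidden states, a GRUCell carries history across time, and
    the output again parameterizes only the change in state."""

    def __init__(self, n_dim, n_hid, n_edge_types):
        super().__init__()
        self.edge_mlps = nn.ModuleList([
            nn.Sequential(nn.Linear(2 * n_hid, n_hid), nn.ReLU())
            for _ in range(n_edge_types)])
        self.gru = nn.GRUCell(n_dim + n_hid, n_hid)
        self.out = nn.Linear(n_hid, n_dim)

    def forward(self, x_t, z, h):
        # x_t: [B, N, D]; z: [B, N, N, K]; h: [B * N, n_hid] GRU state
        B, N, _ = x_t.shape
        h_nodes = h.view(B, N, -1)
        h_i = h_nodes.unsqueeze(2).expand(-1, -1, N, -1)
        h_j = h_nodes.unsqueeze(1).expand(-1, N, -1, -1)
        pair = torch.cat([h_i, h_j], dim=-1)
        msg = sum(z[..., k:k + 1] * mlp(pair)
                  for k, mlp in enumerate(self.edge_mlps))
        agg = msg.sum(dim=1)                       # messages into each node j
        gru_in = torch.cat([x_t, agg], dim=-1).view(B * N, -1)
        h_new = self.gru(gru_in, h)                # per-node recurrent update
        mu = x_t + self.out(h_new).view(B, N, -1)  # predict the state change
        return mu, h_new
```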
(6) Training
Process
- (1) run the encoder & get \(q_{\phi}\left(\mathbf{z}_{i j} \mid \mathbf{x}\right)\)
- (2) sample \(\mathbf{z}_{i j}\) from \(q_{\phi}\left(\mathbf{z}_{i j} \mid \mathbf{x}\right)\)
- (3) run the decoder to compute \(\boldsymbol{\mu}^{2}, \ldots, \boldsymbol{\mu}^{T}\)
ELBO objective
- (1) reconstruction loss : \(\mathbb{E}_{q_{\phi}(\mathbf{z} \mid \mathbf{x})}\left[\log p_{\theta}(\mathbf{x} \mid \mathbf{z})\right]\)
- \(-\sum_{j} \sum_{t=2}^{T} \frac{\left\|\mathbf{x}_{j}^{t}-\boldsymbol{\mu}_{j}^{t}\right\|^{2}}{2 \sigma^{2}}+\text { const }\).
- (2) KL-divergence term : \(\operatorname{KL}\left[q_{\phi}(\mathbf{z} \mid \mathbf{x}) \mid \mid p_{\theta}(\mathbf{z})\right]\)
- \(-\sum_{i \neq j} H\left(q_{\phi}\left(\mathbf{z}_{i j} \mid \mathbf{x}\right)\right)+\text { const}\) ( uniform prior \(\rightarrow\) KL reduces to the negative sum of entropies + const )
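A sketch combining both terms into a training loss ( the negative ELBO ) to minimize; \(\sigma^2\) is the fixed output variance treated as a hyperparameter, with 5e-5 as an example value, and additive constants are dropped:

```python
import math
import torch
import torch.nn.functional as F

def nri_loss(x, mu, logits, sigma2=5e-5):
    """Negative ELBO sketch.
    x:      [B, T, N, D]   ground-truth trajectories
    mu:     [B, T-1, N, D] predicted means mu^2 .. mu^T
    logits: [B, N, N, K]   edge-type logits from the encoder
    sigma2: fixed output variance (hyperparameter; 5e-5 is an example)."""
    # (1) reconstruction: sum_j sum_{t=2}^T ||x_j^t - mu_j^t||^2 / (2 sigma^2)
    recon = ((x[:, 1:] - mu) ** 2).sum() / (2 * sigma2)
    # (2) KL to a uniform prior over K edge types:
    #     KL[q || p] = -H(q) + log K, summed over all edges (i, j)
    log_q = F.log_softmax(logits, dim=-1)
    K = logits.size(-1)
    kl = (log_q.exp() * (log_q + math.log(K))).sum()
    return (recon + kl) / x.size(0)  # mean over the batch
```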