Representation Learning with Contrastive Predictive Coding
Contents
- Abstract
- Contrastive Predictive Coding
  - Motivation & Intuition
  - Contrastive Predictive Coding
  - InfoNCE & Mutual Information Estimation
 
 
0. Abstract
“Contrastive Predictive Coding (CPC)”
- a universal UNsupervised learning approach to extract useful representations from HIGH-dim data
- details
  - predicts the future in the latent space with autoregressive models
  - uses a probabilistic contrastive loss
  - made tractable by negative sampling
 
 
1. Contrastive Predictive Coding
(1) Motivation & Intuition
Main intuition
- encode the underlying SHARED information between different parts of a HIGH-dim signal
  ( + discard LOW-level information & noise )
- Slow features ( = shared info, global structure ... )
  - the further we predict into the future, the lower the amount of shared information
 
 
Challenges in predicting high-dim data
- (1) unimodal losses ( ex. MSE, CE ) are not useful
- (2) generative models : computationally intense
  - waste capacity capturing complex relationships in the data \(x\) ( ignoring the context \(c\) )
  - modeling \(p(x \mid c)\) directly?
    \(\rightarrow\) may not be optimal for the purpose of extracting shared info between \(x\) & \(c\)
 
This paper :
- encode the target \(x\) ( future ) & the context \(c\) ( present ) into compact distributed vectors,
  in a way that maximally preserves the MUTUAL information between \(x\) & \(c\)
- \(I(x ; c)=\sum_{x, c} p(x, c) \log \frac{p(x \mid c)}{p(x)}\)
  ( a toy numeric check of this formula follows below )
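A quick numeric check of this formula on a toy 2×2 joint distribution (the probabilities below are made up purely for illustration):

```python
import numpy as np

# hypothetical joint p(x, c) over two binary variables (rows: x, cols: c)
p_xc = np.array([[0.4, 0.1],
                 [0.1, 0.4]])
p_x = p_xc.sum(axis=1, keepdims=True)   # marginal p(x)
p_c = p_xc.sum(axis=0, keepdims=True)   # marginal p(c)

# I(x; c) = sum_{x,c} p(x,c) log[ p(x|c) / p(x) ],  and p(x|c)/p(x) = p(x,c) / (p(x) p(c))
mi = (p_xc * np.log(p_xc / (p_x * p_c))).sum()
print(f"I(x; c) = {mi:.3f} nats")       # ≈ 0.193 nats
```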
 
(2) Contrastive Predictive Coding
a) architecture
( figure : the encoder \(g_{\text{enc}}\) maps each input \(x_t\) to a latent \(z_t\), and an autoregressive model \(g_{\text{ar}}\) summarizes \(z_{\leq t}\) into a context \(c_t\) used to predict future latents )
b) notation
Model :
- [encoder] \(g_{\text{enc}}\) : \(z_{t}=g_{\text{enc}}\left(x_{t}\right)\)
- [AR model] \(g_{\text{ar}}\) : summarizes all \(z_{\leq t}\) in the latent space & produces the context \(c_{t}=g_{\text{ar}}\left(z_{\leq t}\right)\)
  ( a minimal sketch of both networks follows below )
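A minimal sketch of both networks in PyTorch, assuming a 1-D (audio-like) input; the layer sizes here are illustrative, not the paper's exact configuration (the paper uses a strided convolutional encoder and a GRU for its audio experiments):

```python
import torch
import torch.nn as nn

class CPCNetworks(nn.Module):
    def __init__(self, in_ch=1, z_dim=64, c_dim=128):
        super().__init__()
        # g_enc : maps the raw input x_t to a latent z_t (strided convs downsample in time)
        self.g_enc = nn.Sequential(
            nn.Conv1d(in_ch, z_dim, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv1d(z_dim, z_dim, kernel_size=4, stride=2, padding=1),
        )
        # g_ar : summarizes z_{<=t} into the context c_t
        self.g_ar = nn.GRU(z_dim, c_dim, batch_first=True)

    def forward(self, x):                    # x : (B, in_ch, T)
        z = self.g_enc(x).transpose(1, 2)    # (B, T', z_dim) : one z_t per timestep
        c, _ = self.g_ar(z)                  # (B, T', c_dim) : c[:, t] is c_t
        return z, c
```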
 
c) mutual information
do not predict \(x_{t+k}\) directly with a generative model \(p_{k}\left(x_{t+k} \mid c_{t}\right)\)
\(\rightarrow\) instead, model a density ratio ( which preserves the mutual information between \(x_{t+k}\) & \(c_t\) )
Density ratio
- \(f_{k}\left(x_{t+k}, c_{t}\right) \propto \frac{p\left(x_{t+k} \mid c_{t}\right)}{p\left(x_{t+k}\right)}\)
  ( \(f\) can be unnormalized )
- use a log-bilinear model, \(f_{k}\left(x_{t+k}, c_{t}\right)=\exp \left(z_{t+k}^{T} W_{k} c_{t}\right)\)
  - the linear transformation \(W_{k}^{T} c_{t}\) is used for prediction, with a different \(W_k\) for every step \(k\) ( sketched in code below )
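A sketch of the log-bilinear score, with illustrative dimensions; in practice the score is kept in log space (\(z_{t+k}^T W_k c_t\)) and the \(\exp\) is applied implicitly by the softmax in the loss:

```python
import torch.nn as nn

z_dim, c_dim, K = 64, 128, 4     # illustrative sizes; K = number of prediction steps

# one linear map per step k, playing the role of W_k (maps the context into latent space)
W = nn.ModuleList(nn.Linear(c_dim, z_dim, bias=False) for _ in range(K))

def log_score(z_future, c_t, k):
    """log f_k(x_{t+k}, c_t) = z_{t+k}^T W_k c_t."""
    pred = W[k](c_t)                   # W_k c_t : predicted latent, shape (B, z_dim)
    return (z_future * pred).sum(-1)   # batched dot product, shape (B,)
```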
 
 
Working in LOW-dim
- by using the density ratio \(f_k\) & inferring \(z_{t+k}\) with the encoder,
  \(\rightarrow\) we relieve the model from modeling the high-dim data \(x\)
 
(3) InfoNCE & Mutual Information Estimation
Encoder & AR model : jointly optimized, with a loss based on NCE ( = InfoNCE )
Notation
- \(X=\left\{x_{1}, \ldots, x_{N}\right\}\) : a set of \(N\) random samples
- pos & neg
  - 1 pos ~ \(p\left(x_{t+k} \mid c_{t}\right)\)
  - (N-1) neg ~ \(p\left(x_{t+k}\right)\)
 
 
InfoNCE loss : \(\mathcal{L}_{\mathrm{N}}=-\underset{X}{\mathbb{E}}\left[\log \frac{f_{k}\left(x_{t+k}, c_{t}\right)}{\sum_{x_{j} \in X} f_{k}\left(x_{j}, c_{t}\right)}\right]\)
- optimizing InfoNCE = estimating the density ratio \(f_k\) ( implementation sketch below )
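A sketch of this loss in PyTorch, using the other samples in the batch as the \(N-1\) negatives (a common practical choice; the paper likewise draws negatives from the minibatch). With log-space scores, InfoNCE is exactly categorical cross-entropy with the positive as the target:

```python
import torch
import torch.nn.functional as F

def info_nce_loss(z_pos, c_t, W_k):
    """z_pos : (B, z_dim) latents z_{t+k} (one positive per context)
    c_t   : (B, c_dim) contexts
    W_k   : nn.Linear(c_dim, z_dim, bias=False) for prediction step k"""
    pred = W_k(c_t)                  # (B, z_dim) : W_k c_t
    logits = pred @ z_pos.t()        # (B, B) : row i scores c_i against every z_j
    labels = torch.arange(logits.size(0), device=logits.device)
    # diagonal entries are the positives; off-diagonal entries act as negatives
    return F.cross_entropy(logits, labels)   # = -E[ log softmax of f_k(positive) ]
```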
 
Optimal probability for this loss : \(p\left(d=i \mid X, c_{t}\right)\)
- meaning of \([d=i]\) : \(x_i\) is the positive sample
- \(p\left(d=i \mid X, c_{t}\right)=\frac{p\left(x_{i} \mid c_{t}\right) \prod_{l \neq i} p\left(x_{l}\right)}{\sum_{j=1}^{N} p\left(x_{j} \mid c_{t}\right) \prod_{l \neq j} p\left(x_{l}\right)}=\frac{\frac{p\left(x_{i} \mid c_{t}\right)}{p\left(x_{i}\right)}}{\sum_{j=1}^{N} \frac{p\left(x_{j} \mid c_{t}\right)}{p\left(x_{j}\right)}}\)
  - the probability that \(x_i\) was drawn from \(p\left(x_{t+k} \mid c_{t}\right)\), rather than from \(p\left(x_{t+k}\right)\)
 
 
minimizing the InfoNCE loss \(\mathcal{L}_N\) = maximizing a lower bound on the MI ( Mutual Information )
\(I\left(x_{t+k}, c_{t}\right) \geq \log (N)-\mathcal{L}_{\mathrm{N}}\).
( proof )
\(\begin{aligned} \mathcal{L}_{\mathrm{N}}^{\mathrm{opt}} &=-\underset{X}{\mathbb{E}} \log \left[\frac{\frac{p\left(x_{t+k} \mid c_{t}\right)}{p\left(x_{t+k}\right)}}{\frac{p\left(x_{t+k} \mid c_{t}\right)}{p\left(x_{t+k}\right)}+\sum_{x_{j} \in X_{\text {neg }}} \frac{p\left(x_{j} \mid c_{t}\right)}{p\left(x_{j}\right)}}\right] \\ &=\underset{X}{\mathbb{E}} \log \left[1+\frac{p\left(x_{t+k}\right)}{p\left(x_{t+k} \mid c_{t}\right)} \sum_{x_{j} \in X_{\text {neg }}} \frac{p\left(x_{j} \mid c_{t}\right)}{p\left(x_{j}\right)}\right] \\ & \approx \underset{X}{\mathbb{E}} \log \left[1+\frac{p\left(x_{t+k}\right)}{p\left(x_{t+k} \mid c_{t}\right)}(N-1) \underset{x_{j}}{\mathbb{E}} \frac{p\left(x_{j} \mid c_{t}\right)}{p\left(x_{j}\right)}\right] \\ &=\underset{X}{\mathbb{E}} \log \left[1+\frac{p\left(x_{t+k}\right)}{p\left(x_{t+k} \mid c_{t}\right)}(N-1)\right] \\ & \geq \underset{X}{\mathbb{E}} \log \left[\frac{p\left(x_{t+k}\right)}{p\left(x_{t+k} \mid c_{t}\right)} N\right] \\ &=-I\left(x_{t+k}, c_{t}\right)+\log (N), \end{aligned}\).
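A numerical sanity check of this bound, on a toy problem where everything is known in closed form: \(x\) and \(c\) are jointly Gaussian with correlation \(\rho\), so the true MI is \(-\frac{1}{2}\log(1-\rho^2)\) and the optimal density ratio \(p(x \mid c)/p(x)\) can be plugged in directly. \(\log(N)-\mathcal{L}_N\) should stay below the true MI and tighten as \(N\) grows:

```python
import numpy as np

rng = np.random.default_rng(0)
rho = 0.8
true_mi = -0.5 * np.log(1 - rho**2)            # analytic I(x; c)

def log_density_ratio(x, c):
    # log[ p(x|c) / p(x) ] with x|c ~ N(rho*c, 1-rho^2) and x ~ N(0, 1)
    var = 1 - rho**2
    log_p_xc = -0.5 * np.log(2 * np.pi * var) - (x - rho * c) ** 2 / (2 * var)
    log_p_x = -0.5 * np.log(2 * np.pi) - x**2 / 2
    return log_p_xc - log_p_x

def info_nce(N, trials=20000):
    c = rng.standard_normal(trials)
    x_pos = rho * c + np.sqrt(1 - rho**2) * rng.standard_normal(trials)
    x_neg = rng.standard_normal((trials, N - 1))            # negatives ~ p(x)
    scores = np.concatenate(
        [log_density_ratio(x_pos, c)[:, None],              # column 0 = positive
         log_density_ratio(x_neg, c[:, None])], axis=1)
    m = scores.max(axis=1, keepdims=True)                   # stable log-softmax
    log_softmax = scores - (m + np.log(np.exp(scores - m).sum(axis=1, keepdims=True)))
    return -log_softmax[:, 0].mean()                        # Monte Carlo L_N

for N in [2, 8, 32, 128]:
    print(f"N={N:4d}  log(N) - L_N = {np.log(N) - info_nce(N):.3f}  (true MI = {true_mi:.3f})")
```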