Context Autoencoder for Self-Supervised Representation Learning
Contents
- Abstract
- Introduction
- Approach
    - Architecture
    - Objective Function
0. Abstract
Context AutoEncoder (CAE)
- novel masked image modeling (MIM) approach
- for self-supervised representation pretraining
Goal of CAE :
- pretrain an encoder by solving the pretext task
- pretext task :
    - estimate the masked patches from the visible patches
 
Details :
- step 1) feed the visible patches into the encoder
    - extract representations
- step 2) make predictions from the visible patches to the masked patches
    - introduce an alignment constraint
        - encourage the alignment between..
            - (1) representations of “predicted” masked patches
            - (2) representations of masked patches “computed from the encoder”
- step 3) the predicted masked patch representations are mapped to the targets of the pretext task through a decoder ( see the sketch below )
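Below is a minimal data-flow sketch of these three steps in PyTorch. Every module, name, and dimension here (`encoder`, `regressor`, `decoder_head`, `D = 192`) is a toy stand-in rather than the paper's implementation; in particular, the real regressor \(\mathcal{H}\) uses cross-attention and leaves \(\mathbf{Z}_v\) untouched, whereas a plain self-attention stack is used below purely to show the data flow.

```python
import torch
import torch.nn as nn

D = 192  # toy embedding dimension

def blocks(depth):
    # small stack of transformer blocks (stand-in modules)
    layer = nn.TransformerEncoderLayer(d_model=D, nhead=4, batch_first=True)
    return nn.TransformerEncoder(layer, num_layers=depth)

encoder = blocks(2)                 # stand-in for the ViT encoder F
regressor = blocks(1)               # stand-in for the latent contextual regressor H
decoder_head = nn.Linear(D, 8192)   # stand-in decoder mapping to target-token logits

x_v = torch.randn(1, 98, D)         # embedded visible patches X_v (toy batch)
q_m = torch.randn(1, 98, D)         # mask queries (learned parameters in the real model)

z_v = encoder(x_v)                              # step 1) Z_v from the visible patches
h = regressor(torch.cat([z_v, q_m], dim=1))     # step 2) predict masked representations
z_m = h[:, 98:]                                 #         keep only the masked positions
y_m = decoder_head(z_m)                         # step 3) map Z_m to pretext-task targets
```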
 
1. Introduction
previous MIM methods ( e.g., BEiT )
- couple the encoding & pretext task completion roles
\(\leftrightarrow\) CAE : separation of encoding ( = representation learning ) & pretext task completion
Downstream tasks
- semantic segmentation
 - object detection
 - instance segmentation
 
CAE
- propose CAE for improving encoding quality
- randomly partition the image into 2 sets of patches
    - (1) visible
    - (2) masked
- architecture
    - (1) encoder
    - (2) latent contextual regressor ( with an alignment constraint )
    - (3) decoder
 
 

2. Approach
CAE pretrains the encoder by solving the masked image modeling task.
(1) Architecture
randomly split an image into two sets of patches ( see the sketch after this list )
- (1) visible patches \(\mathbf{X}_v\)
 - (2) masked patches \(\mathbf{X}_m\)
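A small sketch of this split, assuming the image has already been flattened into 196 patch embeddings (14x14 grid, toy dimension 768). A plain random permutation is used here for brevity; the actual masking is block-wise, as described under the objective function:

```python
import torch

patches = torch.randn(196, 768)            # toy: 14x14 grid of patch embeddings
perm = torch.randperm(196)                 # random order of patch indices
vis_idx, mask_idx = perm[:98], perm[98:]   # 98 visible / 98 masked (ratio from the paper)
x_v = patches[vis_idx]                     # visible patches X_v
x_m = patches[mask_idx]                    # masked patches X_m
```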
 
pretext task :
- predict the masked patches from the visible patches in the encoded representation space
- then, map the predicted representations to the targets
 

a) Encoder \(\mathcal{F}\)
- learns representations only for VISIBLE patches
- maps the visible patches \(\mathbf{X}_v\) to the latent representations \(\mathbf{Z}_v\)
- ( use the ViT as \(\mathcal{F}\) )
- process ( see the sketch below )
    - step 1) embed the visible patches
    - step 2) add positional embeddings \(\mathbf{P}_v\)
    - step 3) send the combined embeddings into transformer blocks & generate \(\mathbf{Z}_v\)
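A sketch of these three steps as a toy ViT-style module; the class name `VisibleEncoder` and all dimensions and depths are illustrative assumptions, not the paper's configuration:

```python
import torch
import torch.nn as nn

class VisibleEncoder(nn.Module):
    """Toy ViT-style encoder F that sees only the visible patches."""
    def __init__(self, patch_dim=768, d=192, n_patches=196, depth=2):
        super().__init__()
        self.embed = nn.Linear(patch_dim, d)                # step 1) patch embedding
        self.pos = nn.Parameter(torch.zeros(n_patches, d))  # full positional table P
        layer = nn.TransformerEncoderLayer(d_model=d, nhead=4, batch_first=True)
        self.blocks = nn.TransformerEncoder(layer, num_layers=depth)

    def forward(self, x_v, vis_idx):
        h = self.embed(x_v) + self.pos[vis_idx]  # step 2) add P_v (visible positions only)
        return self.blocks(h)                    # step 3) transformer blocks -> Z_v

enc = VisibleEncoder()
vis_idx = torch.randperm(196)[:98]   # indices of the visible patches
x_v = torch.randn(1, 98, 768)        # raw visible patch vectors (toy batch)
z_v = enc(x_v, vis_idx)              # latent representations Z_v, shape (1, 98, 192)
```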
 
 
b) Latent contextual regressor \(\mathcal{H}\)
- predicts the masked patch representations from \(\mathbf{Z}_v\)
    - prediction ( = \(\mathbf{Z}_m\) ) : constrained to align with the masked patch representations computed from the encoder
- ( use a series of transformer blocks as \(\mathcal{H}\) )
- in this process, \(\mathbf{Z}_v\) are not updated
 
Initial queries \(\mathbf{Q}_m\) ( = mask queries )
- mask tokens that are learned as model parameters
( = same for all the masked patches )
- serve as the queries in cross-attention, with \(\mathbf{Z}_v\) providing the Key & Value ( see the sketch below )
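A one-block sketch of this cross-attention, under the reading above (queries = mask queries plus their positional embeddings, keys/values = \(\mathbf{Z}_v\)); the real \(\mathcal{H}\) stacks several such blocks, and all shapes here are toy values:

```python
import torch
import torch.nn as nn

d, n_vis, n_mask = 192, 98, 98
attn = nn.MultiheadAttention(d, num_heads=4, batch_first=True)  # one cross-attention block

mask_token = nn.Parameter(torch.zeros(1, 1, d))  # one learned token, shared by all masked patches
pos_m = torch.randn(1, n_mask, d)                # positional embeddings of masked positions (toy)
z_v = torch.randn(1, n_vis, d)                   # encoder output Z_v, treated as fixed context

q = mask_token.expand(-1, n_mask, -1) + pos_m    # initial queries Q_m
z_m, _ = attn(q, z_v, z_v)                       # Key & Value = Z_v; Z_v itself is not updated
```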
 
c) Alignment constraint
- imposed on \(\mathbf{Z}_m\) ( predicted by \(\mathcal{H}\) )
- feed the masked patches \(\mathbf{X}_m\) into the encoder and generate \(\overline{\mathbf{Z}}_m\)
- alignment between…
    - \(\mathbf{Z}_m\) and \(\overline{\mathbf{Z}}_m\) ( see the sketch below )
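A sketch of how the alignment target could be produced: the masked patches go through the same encoder, with gradients stopped on that branch to match the \(\operatorname{sg}[\cdot]\) in the loss below. The `nn.Linear` encoder stand-in and all shapes are toy assumptions:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

enc = nn.Linear(192, 192)          # toy stand-in for the encoder F
x_m = torch.randn(1, 98, 192)      # embedded masked patches X_m
z_m = torch.randn(1, 98, 192, requires_grad=True)  # Z_m predicted by H

with torch.no_grad():              # stop-gradient: the target branch is not trained here
    z_bar_m = enc(x_m)             # \bar{Z}_m : masked patches through the same encoder

align_loss = F.mse_loss(z_m, z_bar_m)  # pull Z_m toward \bar{Z}_m
```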
 
 
d) Decoder
- maps \(\mathbf{Z}_m\) to the targets for masked patches \(\mathbf{Y}_m\)
- ( = stack of transformer blocks, based on self-attention; see the sketch below )
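A toy sketch of such a decoder: a short self-attention stack followed by a linear head over the target-token vocabulary. The depth and width are illustrative assumptions; 8192 is the DALL-E tokenizer's codebook size:

```python
import torch
import torch.nn as nn

d, vocab = 192, 8192   # toy width; 8192 = DALL-E tokenizer codebook size
layer = nn.TransformerEncoderLayer(d_model=d, nhead=4, batch_first=True)
blocks = nn.TransformerEncoder(layer, num_layers=2)   # self-attention stack
head = nn.Linear(d, vocab)                            # logits over target tokens

z_m = torch.randn(1, 98, d)    # predicted masked representations Z_m
y_m = head(blocks(z_m))        # Y_m : (1, 98, 8192) logits for the pretext targets
```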
 
(2) Objective Function
a) Masking & Targets
( following BEiT )
- adopt a random block-wise masking strategy ( see the sketch below )
- for each image, 98 out of 196 (14x14) patches are masked
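A simplified sketch of block-wise masking on the 14x14 patch grid: rectangular blocks are masked until at least 98 patches are covered. BEiT's actual algorithm also samples block aspect ratios within set bounds and controls the final count more carefully; the block-size bounds here are toy choices:

```python
import random
import torch

def blockwise_mask(grid=14, num_mask=98):
    """Mask random rectangular blocks until >= num_mask patches are covered."""
    mask = torch.zeros(grid, grid, dtype=torch.bool)
    while mask.sum() < num_mask:
        h = random.randint(2, 5)            # block height (toy bounds)
        w = random.randint(2, 5)            # block width (toy bounds)
        top = random.randint(0, grid - h)
        left = random.randint(0, grid - w)
        mask[top:top + h, left:left + w] = True
    return mask.flatten()                   # (196,) boolean; True = masked

m = blockwise_mask()
print(int(m.sum()), "masked patches out of", m.numel())
```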
 
pre-trained DALL-E tokenizer
- to generate discrete tokens for forming the targets
( = target tokens for masked patches : \(\overline{\mathbf{Y}}_m\) )
 
b) Loss function
Loss = (1) decoding loss + (2) alignment loss
(1) decoding loss : \(\ell_y\left(\mathbf{Y}_m, \overline{\mathbf{Y}}_m\right)\) ….. CE loss
(2) alignment loss : \(\ell_z\left(\mathbf{Z}_m, \overline{\mathbf{Z}}_m\right)\) …. MSE loss
(3) total loss : \(\ell_y\left(\mathbf{Y}_m, \overline{\mathbf{Y}}_m\right)+\lambda \ell_z\left(\mathbf{Z}_m, \operatorname{sg}\left[\overline{\mathbf{Z}}_m\right]\right)\), where \(\operatorname{sg}[\cdot]\) is the stop-gradient operation ( see the sketch below )
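A self-contained sketch of this objective, with `.detach()` playing the role of \(\operatorname{sg}[\cdot]\); the batch size, dimensions, and \(\lambda\) value are toy assumptions:

```python
import torch
import torch.nn.functional as F

B, N, D, V = 2, 98, 192, 8192   # batch, masked patches, width, vocab (toy)
lam = 1.0                       # loss weight lambda (toy value)

y_m = torch.randn(B, N, V, requires_grad=True)   # decoder logits Y_m
y_bar = torch.randint(0, V, (B, N))              # target token ids \bar{Y}_m
z_m = torch.randn(B, N, D, requires_grad=True)   # predicted representations Z_m
z_bar = torch.randn(B, N, D)                     # encoder representations \bar{Z}_m

dec_loss = F.cross_entropy(y_m.reshape(-1, V), y_bar.reshape(-1))  # l_y : CE loss
align_loss = F.mse_loss(z_m, z_bar.detach())     # l_z with stop-gradient sg[.]
loss = dec_loss + lam * align_loss
loss.backward()
```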