Context Autoencoder for Self-Supervised Representation Learning


Contents

  0. Abstract
  1. Introduction
  2. Approach
    1. Architecture
    2. Objective Function


0. Abstract

Context AutoEncoder (CAE)

  • novel masked image modeling (MIM) approach

  • for self-supervised representation pretraining


Goal of CAE :

  • pretrain an encoder by solving the pretext task
  • pretext task :
    • estimate the masked patches from the visible patches


Details :

  • step 1) feeds the visible patches into the encoder
    • extract representations
  • step 2) predict the representations of the masked patches from the visible patches
  • introduce an alignment constraint
    • encourage the alignment between..
      • (1) representations of “predicted” masked patches
      • (2) representations of masked patches “computed from encoder”
  • step 3) predicted masked patch representations are mapped to the targets of the pretext task through a decoder
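
a minimal PyTorch-style sketch of this three-step flow ( module names such as `encoder`, `regressor`, `decoder` and the tensor shapes are assumptions for illustration, not the paper's code ) :

```python
import torch

def cae_step(encoder, regressor, decoder, x_vis, idx_vis, x_mask, idx_mask, pos_mask):
    """Schematic CAE pretraining forward pass (names and shapes are illustrative)."""
    z_v = encoder(x_vis, idx_vis)            # step 1) Z_v from the visible patches only
    z_m = regressor(z_v, pos_mask)           # step 2) predict Z_m from Z_v
    with torch.no_grad():                    # alignment target Z_bar_m from the same encoder,
        z_m_bar = encoder(x_mask, idx_mask)  # computed without gradient (stop-gradient)
    y_m = decoder(z_m)                       # step 3) map Z_m to pretext-task predictions Y_m
    return z_m, z_m_bar, y_m
```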


1. Introduction

previous MIM methods ( e.g., BEiT )

  • couple the encoding & pretext task completion roles

\(\leftrightarrow\) CAE : separation of encoding ( = representation learning ) & pretext task


Downstream task

  • semantic segmentation
  • object detection
  • instance segmentation


CAE

  • propose CAE for improving encoding quality

  • randomly partition the image into 2 sets of patches

    • (1) visible

    • (2) masked

  • architecture

    • (1) encoder
    • (2) latent contextual regressor (with an alignment constraint)
    • (3) decoder


( Figure 2 )


2. Approach

CAE pretrains the encoder by solving the masked image modeling task


(1) Architecture

randomly split an image into two sets of patches

  • (1) visible patches \(\mathbf{X}_v\)
  • (2) masked patches \(\mathbf{X}_m\)


pretext task :

  • predict the masked patches from the visible patches in the encoded representation space
  • then, map the predicted representations to the targets


( Figure 2 )


a) Encoder \(\mathcal{F}\)

  • learns representations only for VISIBLE patches

  • maps the visible patches \(\mathbf{X}_v\) to the latent representations \(\mathbf{Z}_v\)
  • ( use the ViT as \(\mathcal{F}\) )
  • process
    • step 1) embed the visible patches
    • step 2) add positional embeddings \(\mathbf{P}_v\)
    • step 3) send the combined embeddings into transformer blocks & generate \(\mathbf{Z}_v\)
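
a minimal sketch of \(\mathcal{F}\) following these three steps ( a plain ViT-style encoder that sees only visible tokens ; the class name, arguments & positional-embedding handling are assumptions ) :

```python
import torch
import torch.nn as nn

class VisiblePatchEncoder(nn.Module):
    """ViT-style encoder F operating only on the visible patches (illustrative sketch)."""
    def __init__(self, patch_dim, embed_dim, depth, num_heads, num_patches):
        super().__init__()
        self.patch_embed = nn.Linear(patch_dim, embed_dim)                      # step 1) embed patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))   # positional embeddings P
        block = nn.TransformerEncoderLayer(embed_dim, num_heads, batch_first=True)
        self.blocks = nn.TransformerEncoder(block, depth)                       # step 3) transformer blocks

    def forward(self, x_vis, idx_vis):
        # x_vis: [B, N_vis, patch_dim] visible patches; idx_vis: [B, N_vis] their positions
        tokens = self.patch_embed(x_vis)
        pos = self.pos_embed.expand(x_vis.size(0), -1, -1)
        pos_v = torch.gather(pos, 1, idx_vis.unsqueeze(-1).expand(-1, -1, tokens.size(-1)))
        return self.blocks(tokens + pos_v)                                      # Z_v: [B, N_vis, D]
```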


b) Latent contextual regressor \(\mathcal{H}\)

  • predicts the masked patch representations from \(\mathbf{Z}_v\)
    • prediction ( = \(\mathbf{Z}_m\) ): constrained to align with the masked patch representations computed from encoder
  • ( use a series of cross-attention transformer blocks as \(\mathcal{H}\) — see the sketch below )

  • In this process, \(\mathbf{Z}_v\) are not updated


Initial queries \(\mathbf{Q}_m\) ( = mask queries )

  • mask tokens that are learned as model parameters

    ( = same for all the masked patches )

  • keys & values : the visible patch representations \(\mathbf{Z}_v\) ( together with the mask queries / outputs of the previous block )
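
a minimal sketch of \(\mathcal{H}\) as a stack of cross-attention blocks ( the layer structure, positional handling & names are assumptions ; \(\mathbf{Z}_v\) is only read, never updated ) :

```python
import torch
import torch.nn as nn

class LatentContextualRegressor(nn.Module):
    """H: predicts masked-patch representations Z_m from Z_v via cross-attention (sketch)."""
    def __init__(self, embed_dim, num_heads, depth):
        super().__init__()
        # one learned mask token, shared by all masked patches
        self.mask_query = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.blocks = nn.ModuleList(
            nn.MultiheadAttention(embed_dim, num_heads, batch_first=True) for _ in range(depth)
        )

    def forward(self, z_vis, pos_mask):
        # z_vis: [B, N_vis, D] (kept fixed); pos_mask: [B, N_mask, D] positional embeddings
        q = self.mask_query.expand(pos_mask.size(0), pos_mask.size(1), -1) + pos_mask
        for attn in self.blocks:
            kv = torch.cat([z_vis, q], dim=1)   # keys/values: Z_v together with the query stream
            q, _ = attn(q, kv, kv)              # only the mask-query stream is updated
        return q                                # Z_m: [B, N_mask, D]
```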


c) Alignment constraint

  • imposed on \(\mathbf{Z}_m\) ( predicted by \(\mathcal{H}\))
  • feed the masked patches \(\mathbf{X}_m\) into the encoder \(\mathcal{F}\) and generate \(\overline{\mathbf{Z}}_m\)
  • alignment between…
    • \(\mathbf{Z}_m\) and \(\overline{\mathbf{Z}}_m\)


d) Decoder

  • maps \(\mathbf{Z}_m\) to the predictions \(\mathbf{Y}_m\) for the targets of the masked patches
  • ( = stack of transformer blocks, based on self-attention )
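
a minimal sketch of the decoder ( self-attention blocks plus a linear head over the token vocabulary ; depth & vocabulary size are placeholders ) :

```python
import torch.nn as nn

class Decoder(nn.Module):
    """Maps predicted masked-patch representations Z_m to token predictions Y_m (sketch)."""
    def __init__(self, embed_dim, num_heads, depth, vocab_size):
        super().__init__()
        block = nn.TransformerEncoderLayer(embed_dim, num_heads, batch_first=True)
        self.blocks = nn.TransformerEncoder(block, depth)   # self-attention over the masked tokens
        self.head = nn.Linear(embed_dim, vocab_size)        # logits over the tokenizer vocabulary

    def forward(self, z_m):
        return self.head(self.blocks(z_m))                  # Y_m: [B, N_mask, vocab_size]
```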


(2) Objective Function

a) Masking & Targets

( following BEiT )

  • adopt a random block-wise masking strategy
  • for each image, 98 out of 196 (14x14) patches are masked
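
a minimal sketch of drawing the mask ( for brevity this samples positions uniformly at random ; the paper follows BEiT's block-wise strategy ) :

```python
import torch

def random_mask(batch_size, num_patches=196, num_masked=98):
    """Pick masked / visible patch indices (uniform here; CAE follows BEiT's block-wise masking)."""
    order = torch.rand(batch_size, num_patches).argsort(dim=1)   # random permutation per image
    masked_idx = order[:, :num_masked]                           # 98 masked positions
    visible_idx = order[:, num_masked:]                          # 98 visible positions
    return visible_idx, masked_idx
```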


pre-trained DALL-E tokenizer

  • to generate discrete tokens for forming the targets

    ( = target tokens for masked patches : \(\overline{\mathbf{Y}}_m\) )


b) Loss function

Loss = (1) decoding loss + (2) alignment loss


(1) decoding loss : \(\ell_y\left(\mathbf{Y}_m, \overline{\mathbf{Y}}_m\right)\) ….. CE loss

(2) alignment loss : \(\ell_z\left(\mathbf{Z}_m, \overline{\mathbf{Z}}_m\right)\) …. MSE loss

(3) total loss : \(\ell_y\left(\mathbf{Y}_m, \overline{\mathbf{Y}}_m\right)+\lambda \ell_z\left(\mathbf{Z}_m, \operatorname{sg}\left[\overline{\mathbf{Z}}_m\right]\right) .\)
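
a minimal sketch of the total loss, assuming `y_m_logits` ( decoder outputs ), `y_m_target` ( tokenizer token ids ), `z_m` and `z_m_bar` from the sketches above ; `detach()` plays the role of the stop-gradient \(\operatorname{sg}[\cdot]\), and `lam` is a placeholder for \(\lambda\) :

```python
import torch.nn.functional as F

def cae_loss(y_m_logits, y_m_target, z_m, z_m_bar, lam=1.0):
    """Decoding loss (cross-entropy) + alignment loss (MSE with stop-gradient)."""
    decode_loss = F.cross_entropy(y_m_logits.flatten(0, 1), y_m_target.flatten())  # l_y(Y_m, Y_bar_m)
    align_loss = F.mse_loss(z_m, z_m_bar.detach())                                 # l_z(Z_m, sg[Z_bar_m])
    return decode_loss + lam * align_loss                                          # lam: placeholder weight
```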
