Masked Siamese Networks for Label-Efficient Learning


Contents

  0. Abstract
  1. Prerequisites
    (1) Problem Formulation
    (2) Siamese Networks
    (3) Vision Transformer
  2. Masked Siamese Networks (MSN)
    (1) Input Views
    (2) Patchify and Mask
    (3) Encoder
    (4) Similarity Metric and Predictions
    (5) Training Objective


0. Abstract

propose Masked Siamese Networks (MSN)

  • a self-supervised learning framework for learning image representations
  • matches the representation of an image view containing randomly masked patches to the representation of the original unmasked image


1. Prerequisites

(1) Problem Formulation

Notation

  • \(\mathcal{D}=\left(\mathbf{x}_i\right)_{i=1}^U\) : unlabeled images
  • \(\mathcal{S}=\left(\mathbf{x}_{s_i}, y_i\right)_{i=1}^L\) : labeled images
    • with \(L \ll U\).
    • images in \(\mathcal{S}\) may overlap with the images in \(\mathcal{D}\)


Goal :

  • learn image representations by first pre-training on \(\mathcal{D}\)
  • then adapting the representation to the supervised task using \(\mathcal{S}\)


(2) Siamese Networks

Goal : learn an encoder that produces similar image embeddings for two views of an image

  • encoder \(f_\theta(\cdot)\) : parameterized as DNN
  • representations \(z_i\) and \(z_i^{+}\) of the two views should match
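
a minimal PyTorch sketch of this setup ( the encoder `f_theta` and the two augmented views are assumed given; cosine similarity is one common way to compare the embeddings, not the paper's exact loss ) :

```python
import torch.nn.functional as F

def siamese_similarity(f_theta, view_a, view_b):
    # encode both views with the same weight-shared encoder f_theta
    z_a = f_theta(view_a)   # z_i   : (B, d)
    z_b = f_theta(view_b)   # z_i^+ : (B, d)
    # cosine similarity between paired embeddings; training pushes
    # this toward 1 so that z_i and z_i^+ match
    return F.cosine_similarity(z_a, z_b, dim=-1)  # (B,)
```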


(3) Vision Transformer

use Vision Transformer (ViT) architecture as encoder

  • step 1) extract a sequence of non-overlapping patches of resolution N × N from an image

  • step 2) apply a linear layer to extract patch tokens

  • step 3) add learnable positional embeddings to them

    ( an extra learnable [CLS] token is added to aggregate information from the full sequence of patches )

  • step 4) sequence of tokens is then fed to a stack of Transformer layers

    • composed of self-attention & FC layer ( + skip conn )
  • step 5) output of the [CLS] token = output of the encoder
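
a minimal PyTorch sketch of steps 1)–3) ( patch size, embedding dimension, and token count are illustrative values for a 224 × 224 input ) :

```python
import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    def __init__(self, patch_size=16, in_chans=3, dim=768, num_patches=196):
        super().__init__()
        # step 1 + 2) a strided conv extracts non-overlapping N x N patches
        # and linearly projects each one to a `dim`-dimensional patch token
        self.proj = nn.Conv2d(in_chans, dim, kernel_size=patch_size, stride=patch_size)
        # step 3) learnable [CLS] token + learnable positional embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, dim))

    def forward(self, img):                                   # img : (B, C, H, W)
        tokens = self.proj(img).flatten(2).transpose(1, 2)    # (B, num_patches, dim)
        cls = self.cls_token.expand(img.shape[0], -1, -1)     # (B, 1, dim)
        x = torch.cat([cls, tokens], dim=1) + self.pos_embed  # (B, num_patches+1, dim)
        return x  # fed to the stack of Transformer layers (step 4)
```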


2. Masked Siamese Networks (MSN)

combines invariance-based pre-training with mask denoising

Procedure

  • step 1) random data augmentations to generate 2 views of an image
    • anchor view & target view
  • step 2) random mask is applied to the anchor view
    • target view is left unchanged


( as in clustering-based SSL approaches )

\(\rightarrow\) learning occurs by computing a soft-distribution over a set of prototypes for both the anchor & target views

Objective (CE Loss)

  • assign the representation of the masked anchor view

    to the same prototypes as that of the unmasked target view



( Figure 2 )


(1) Input Views

sample a mini-batch of \(B \geq 1\) images

for each image \(\mathbf{x}_i\) …

  • step 1) apply a random set of data augmentations to generate ( see the sketch after this list )
    • target view = \(\mathbf{x}_i^{+}\)
    • \(M \geq 1\) anchor views = \(\mathbf{x}_{i, 1}, \mathbf{x}_{i, 2}, \ldots, \mathbf{x}_{i, M}\)
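
a hedged sketch of this step, assuming torchvision-style augmentations ( the specific transforms and their parameters are illustrative, not the paper's exact recipe ) :

```python
import torchvision.transforms as T

# illustrative augmentation pipeline (random crop, flip, color jitter)
augment = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ColorJitter(0.4, 0.4, 0.4, 0.1),
    T.ToTensor(),
])

def make_views(pil_image, M=2):
    target_view = augment(pil_image)                       # x_i^+
    anchor_views = [augment(pil_image) for _ in range(M)]  # x_{i,1}, ..., x_{i,M}
    return target_view, anchor_views
```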


(2) Patchify and Mask

step 2) patchify each view ( into patches of resolution \(N \times N\) )

step 3) after patchifying the anchor view \(\mathbf{x}_{i, m}\) …

  • apply the additional step of masking

    ( by randomly dropping some of the patches )

  • Notation

    • \(\hat{\mathbf{x}}_{i, m}\) = sequence of masked anchor
    • \(\hat{\mathbf{x}}_i^{+}\) = sequence of unmasked target patches

    ( because of masking, the two sequences can have different lengths )



2 strategies for masking the anchor views ( sketched below )

  • (1) Random Masking : randomly drop patches anywhere in the image
  • (2) Focal Masking : keep only a local contiguous block of patches and drop the rest
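
a minimal sketch of the two strategies on a patch-token sequence `tokens` of shape (B, P, d) ( keep ratio and block size are illustrative; for simplicity the focal window here is shared across the batch ) :

```python
import torch

def random_mask(tokens, keep_ratio=0.5):
    B, P, d = tokens.shape
    n_keep = int(P * keep_ratio)
    # keep a random subset of patch indices, sampled independently per image
    idx = torch.rand(B, P).argsort(dim=1)[:, :n_keep]             # (B, n_keep)
    return torch.gather(tokens, 1, idx.unsqueeze(-1).expand(-1, -1, d))

def focal_mask(tokens, grid=14, block=7):
    B, P, d = tokens.shape  # assumes P == grid * grid
    # keep one contiguous block x block window of patches, drop the rest
    top = torch.randint(0, grid - block + 1, (1,)).item()
    left = torch.randint(0, grid - block + 1, (1,)).item()
    rows = torch.arange(top, top + block)
    cols = torch.arange(left, left + block)
    idx = (rows[:, None] * grid + cols[None, :]).reshape(-1)      # (block**2,)
    return tokens[:, idx, :]
```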


(3) Encoder

anchor encoder \(f_\theta(\cdot)\)

  • output : \(z_{i, m} \in \mathbb{R}^d\)

    ( = representation of patchified (and masked) anchor view \(\hat{\mathbf{x}}_{i, m}\) )


target encoder \(f_{\bar{\theta}}(\cdot)\)

  • output : \(z_i^{+} \in \mathbb{R}^d\)

    ( = representation of patchified target view \(\hat{\mathbf{x}}_i^{+}\) )


\(\rightarrow\) \(\bar{\theta}\) are updated via an exponential moving average of \(\theta\)

( both encoders correspond to the trunk of a ViT )


output of network = representation of [CLS] token
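
a minimal sketch of the EMA update of \(\bar{\theta}\) ( the momentum value is illustrative ) :

```python
import torch

@torch.no_grad()
def ema_update(target_encoder, anchor_encoder, momentum=0.996):
    # theta_bar <- m * theta_bar + (1 - m) * theta
    for p_bar, p in zip(target_encoder.parameters(), anchor_encoder.parameters()):
        p_bar.mul_(momentum).add_(p, alpha=1.0 - momentum)
```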


(4) Similarity Metric and Predictions

\(\mathbf{q} \in \mathbb{R}^{K \times d}\) : learnable prototypes

to train encoder …

  • compute a distribution based on the similarity between
    • (1) prototypes
    • (2) each anchor and target view pair
  • penalize the encoder for differences between these distributions


For an anchor representation \(z_{i, m}\)…

  • compute a prediction \(p_{i, m} \in \Delta_K\)

    ( by measuring the cosine similarity to the prototypes matrix \(\mathbf{q} \in \mathbb{R}^{K \times d}\) )

  • predictions \(p_{i, m}\) : \(p_{i, m}:=\operatorname{softmax}\left(\frac{z_{i, m} \cdot \mathbf{q}}{\tau}\right)\)


For a target representation \(z_i^{+}\) …

  • generate a prediction \(p_i^{+} \in \Delta_K\)

    ( by measuring the cosine similarity to the prototypes matrix \(\mathbf{q} \in \mathbb{R}^{K \times d}\) )

  • predictions \(p_i^{+}\) : \(p_i^{+}:=\operatorname{softmax}\left(\frac{z_i^{+} \cdot \mathbf{q}}{\tau^{+}}\right)\)

\(\rightarrow\) always choose \(\tau^{+}<\tau\) to encourage sharper target predictions
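
a sketch of both prediction steps ( representations and prototypes are L2-normalized so the dot product is a cosine similarity; the temperature values are illustrative, chosen so that \(\tau^{+} < \tau\) ) :

```python
import torch.nn.functional as F

def prototype_probs(z, q, tau):
    # cosine similarity = dot product of L2-normalized vectors
    z = F.normalize(z, dim=-1)                 # (B, d)
    q = F.normalize(q, dim=-1)                 # (K, d)
    return F.softmax(z @ q.t() / tau, dim=-1)  # (B, K), a point in Delta_K

# anchor prediction p_{i,m} and sharper target prediction p_i^+
# p_anchor = prototype_probs(z_anchor, prototypes, tau=0.1)
# p_target = prototype_probs(z_target, prototypes, tau=0.025)  # tau^+ < tau
```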


(5) Training Objective

when training encoder…

\(\rightarrow\) penalize when the anchor prediction \(p_{i, m}\) is different from the target prediction \(p_i^{+}\)

( enforced via the CE loss \(H\left(p_i^{+}, p_{i, m}\right)\) )


also, incorporate mean entropy maximization (ME-MAX) regularizer

( to encourage the model to utilize the full set of prototypes )

  • average prediction across all the anchor views = \(\bar{p}:=\frac{1}{M B} \sum_{i=1}^B \sum_{m=1}^M p_{i, m}\)
  • meaning = maximize the entropy \(H(\bar{p})\) of the average prediction


overall objective

  • parameters : encoder parameters \(\theta\) and prototypes \(\mathbf{q}\)
  • loss function : \(\frac{1}{M B} \sum_{i=1}^B \sum_{m=1}^M H\left(p_i^{+}, p_{i, m}\right)-\lambda H(\bar{p})\)

( note ) only compute gradients with respect to the anchor predictions \(p_{i, m}\)

( not the target predictions \(p_i^{+}\) )
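
a hedged end-to-end sketch of the objective ( `p_anchor` is assumed to hold all \(M \cdot B\) anchor predictions, `p_target` the matching target predictions repeated per anchor; `lam` is \(\lambda\) ) :

```python
import torch

def msn_loss(p_anchor, p_target, lam=1.0, eps=1e-8):
    # gradients flow only through the anchor predictions
    p_target = p_target.detach()
    # cross-entropy H(p_i^+, p_{i,m}) averaged over all anchors
    ce = -(p_target * torch.log(p_anchor + eps)).sum(dim=-1).mean()
    # ME-MAX: subtracting lam * H(p_bar) from the loss maximizes the
    # entropy of the average anchor prediction p_bar
    p_bar = p_anchor.mean(dim=0)
    entropy = -(p_bar * torch.log(p_bar + eps)).sum()
    return ce - lam * entropy
```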
