Masked Siamese Networks for Label-Efficient Learning
Contents
- Abstract
- Prerequisites
    - Problem Formulation
    - Siamese Networks
    - Vision Transformer
- Masked Siamese Networks (MSN)
    - Input Views
    - Patchify and Mask
    - Encoder
    - Similarity Metric and Predictions
    - Training Objective
 
0. Abstract
propose Masked Siamese Networks (MSN)
- a self-supervised learning framework for learning image representations
- matches the representation of an image view containing randomly masked patches to the representation of the original unmasked image
1. Prerequisites
(1) Problem Formulation
Notation
- \(\mathcal{D}=\left(\mathbf{x}_i\right)_{i=1}^U\) : unlabeled images
- \(\mathcal{S}=\left(\mathbf{x}_{si}, y_i\right)_{i=1}^L\) : labeled images
    - with \(L \ll U\).
- images in \(\mathcal{S}\) may overlap with the images in \(\mathcal{D}\)
 
Goal :
- learn image representations by first pre-training on \(\mathcal{D}\)
- then adapting the representation to the supervised task using \(\mathcal{S}\)
(2) Siamese Networks
Goal : learn an encoder that produces similar image embeddings for two views of an image
- encoder \(f_\theta(\cdot)\) : parameterized as DNN
- representations \(z_i\) and \(z_i^{+}\) of the two views should match
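As a minimal sketch of this matching idea (not MSN's actual objective, which is described below), the same encoder can embed two views and be trained to pull the embeddings together, e.g. with a negative cosine similarity; the tiny linear encoder below is just a stand-in.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Minimal Siamese sketch: one shared encoder f_theta embeds two views of an image,
# and the two embeddings are pulled together (negative cosine similarity as a
# stand-in loss; MSN's actual prototype-based objective is described later).
def siamese_matching_loss(f_theta, view_1, view_2):
    z = f_theta(view_1)        # z_i : embedding of the first view
    z_plus = f_theta(view_2)   # z_i^+ : embedding of the second view
    return -F.cosine_similarity(z, z_plus, dim=-1).mean()

# toy usage with a stand-in encoder and random tensors in place of real views
f_theta = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 128))
loss = siamese_matching_loss(f_theta, torch.randn(4, 3, 32, 32), torch.randn(4, 3, 32, 32))
```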
(3) Vision Transformer
use Vision Transformer (ViT) architecture as encoder
- step 1) extract a sequence of non-overlapping patches of resolution N × N from an image
- step 2) apply a linear layer to extract patch tokens
- step 3) add learnable positional embeddings to them ( an extra learnable [CLS] token is added to aggregate information from the full sequence of patches )
- step 4) the sequence of tokens is then fed to a stack of Transformer layers - composed of self-attention & FC layer ( + skip conn )
- step 5) output of the [CLS] token = output of encoder
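A toy PyTorch sketch of steps 1)–5); the patch size, embedding dimension, depth, and the use of `nn.TransformerEncoder` are simplifying assumptions rather than the paper's exact ViT configuration.

```python
import torch
import torch.nn as nn

class TinyViT(nn.Module):
    """Illustrative ViT trunk following steps 1)-5); sizes are placeholder assumptions."""
    def __init__(self, img_size=224, patch=16, dim=384, depth=6, heads=6):
        super().__init__()
        num_patches = (img_size // patch) ** 2
        # steps 1-2) patchify + linear projection, done jointly with a strided conv
        self.patch_embed = nn.Conv2d(3, dim, kernel_size=patch, stride=patch)
        # step 3) learnable positional embeddings + extra learnable [CLS] token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
        # step 4) stack of Transformer layers (self-attention + FC, with skip connections)
        layer = nn.TransformerEncoderLayer(dim, heads, dim * 4, batch_first=True)
        self.blocks = nn.TransformerEncoder(layer, depth)

    def forward(self, x):                                             # x: (B, 3, H, W)
        tokens = self.patch_embed(x).flatten(2).transpose(1, 2)       # (B, N, dim) patch tokens
        cls = self.cls_token.expand(x.size(0), -1, -1)                # (B, 1, dim)
        tokens = torch.cat([cls, tokens], dim=1) + self.pos_embed
        tokens = self.blocks(tokens)
        # step 5) the [CLS] token output is used as the image representation
        return tokens[:, 0]
```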
2. Masked Siamese Networks (MSN)
combines invariance-based pre-training with mask denoising
Procedure
- step 1) random data augmentations to generate 2 views of an image
    - anchor view & target view
 
- step 2) random mask is applied to the anchor view
    - target view is left unchanged
 
( like clustering-based SSL approaches … )
\(\rightarrow\) learning occurs by computing a soft distribution over a set of prototypes for both the anchor & target views
Objective (CE Loss)
- assign the representation of the masked anchor view to the same prototypes as that of the unmasked target view

(1) Input Views
sample a mini-batch of \(B \geq 1\) images
for each image \(\mathbf{x}_i\) …
- step 1) apply a random set of data augmentations to generate :
    - target view = \(\mathbf{x}_i^{+}\)
    - \(M \geq 1\) anchor views = \(\mathbf{x}_{i, 1}, \mathbf{x}_{i, 2}, \ldots, \mathbf{x}_{i, M}\)
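A sketch of the view generation; the augmentation recipe below (random resized crop + horizontal flip) is a simplified stand-in, not the paper's exact pipeline.

```python
from torchvision import transforms

# Simplified augmentation pipeline (stand-in for the paper's actual recipe)
augment = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

def make_views(image, M=2):
    """Generate 1 target view and M anchor views from a single PIL image."""
    target_view = augment(image)                       # x_i^+
    anchor_views = [augment(image) for _ in range(M)]  # x_{i,1}, ..., x_{i,M}
    return target_view, anchor_views
```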
 
(2) Patchify and Mask
step 2) patchify each view ( into \(N \times N\) patches )
step 3) after patchifying the anchor view \(\mathbf{x}_{i, m}\) …
- apply the additional step of masking ( by randomly dropping some of the patches )
- Notation
    - \(\hat{\mathbf{x}}_{i, m}\) = sequence of masked anchor patches
    - \(\hat{\mathbf{x}}_i^{+}\) = sequence of unmasked target patches
- ( because of masking, the two sequences can have different lengths )

2 strategies for masking the anchor views
- (1) Random Masking
- (2) Focal Masking
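A sketch of patchify + random masking for the anchor view; the patch size 16 and the 50% mask ratio are illustrative assumptions (focal masking, which instead keeps only a local block of patches, is not shown).

```python
import torch

def patchify(img, patch=16):                             # img: (C, H, W)
    """Split an image into non-overlapping patch vectors of shape (num_patches, C*patch*patch)."""
    c, h, w = img.shape
    patches = img.unfold(1, patch, patch).unfold(2, patch, patch)         # (C, h/p, w/p, p, p)
    return patches.permute(1, 2, 0, 3, 4).reshape(-1, c * patch * patch)  # (N, C*p*p)

def random_mask(patches, mask_ratio=0.5):
    """Randomly drop a fraction of the patches (random masking strategy)."""
    n = patches.size(0)
    keep = torch.randperm(n)[: int(n * (1 - mask_ratio))]
    return patches[keep]

img = torch.randn(3, 224, 224)
target_patches = patchify(img)                 # unmasked target sequence (196 patches)
anchor_patches = random_mask(patchify(img))    # masked anchor sequence (shorter, ~98 patches)
```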
(3) Encoder
anchor encoder \(f_\theta(\cdot)\)
- output : \(z_{i, m} \in \mathbb{R}^d\) ( = representation of the patchified (and masked) anchor view \(\hat{\mathbf{x}}_{i, m}\) )
target encoder \(f_{\bar{\theta}}(\cdot)\)
- output : \(z_i^{+} \in \mathbb{R}^d\) ( = representation of the patchified target view \(\hat{\mathbf{x}}_i^{+}\) )
\(\rightarrow\) \(\bar{\theta}\) are updated via an exponential moving average of \(\theta\)
( Both encoders correspond to the trunk of a ViT )
output of network = representation of the [CLS] token
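A sketch of the encoder pair and the EMA update, reusing the `TinyViT` sketch from above; the momentum value 0.996 is an illustrative assumption.

```python
import copy
import torch

anchor_encoder = TinyViT()                        # f_theta : trained by backprop
target_encoder = copy.deepcopy(anchor_encoder)    # f_theta_bar : EMA copy, no gradients
for p in target_encoder.parameters():
    p.requires_grad = False

@torch.no_grad()
def ema_update(target, anchor, momentum=0.996):
    """theta_bar <- momentum * theta_bar + (1 - momentum) * theta"""
    for p_bar, p in zip(target.parameters(), anchor.parameters()):
        p_bar.mul_(momentum).add_(p, alpha=1 - momentum)
```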
(4) Similarity Metric and Predictions
\(\mathbf{q} \in \mathbb{R}^{K \times d}\) : learnable prototypes
to train the encoder …
- compute a distribution based on the similarity between
    - (1) the prototypes
    - (2) each anchor and target view representation
- penalize the encoder for differences between these distributions
For an anchor representation \(z_{i, m}\) …
- compute a prediction \(p_{i, m} \in \Delta_K\) ( by measuring the cosine similarity to the prototypes matrix \(\mathbf{q} \in \mathbb{R}^{K \times d}\) )
- prediction \(p_{i, m}\) : \(p_{i, m}:=\operatorname{softmax}\left(\frac{z_{i, m} \cdot \mathbf{q}}{\tau}\right)\)
For a target representation \(z_i^{+}\) …
- generate a prediction \(p_i^{+} \in \Delta_K\) ( by measuring the cosine similarity to the prototypes matrix \(\mathbf{q} \in \mathbb{R}^{K \times d}\) )
- prediction \(p_i^{+}\) : \(p_i^{+}:=\operatorname{softmax}\left(\frac{z_i^{+} \cdot \mathbf{q}}{\tau^{+}}\right)\)
\(\rightarrow\) always choose \(\tau^{+}<\tau\) to encourage sharper target predictions
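A sketch of the prediction step: L2-normalize the representation and the prototypes so their dot product is a cosine similarity, then apply a temperature-scaled softmax. K, d, the batch size, and the two temperatures are illustrative values, not the paper's settings.

```python
import torch
import torch.nn.functional as F

K, d = 1024, 384
prototypes = torch.nn.Parameter(torch.randn(K, d))    # q in R^{K x d}, learnable

def prototype_predictions(z, q, tau):
    """Distribution over the K prototypes based on cosine similarity (temperature tau)."""
    z = F.normalize(z, dim=-1)                         # normalize so z . q is a cosine similarity
    q = F.normalize(q, dim=-1)
    return F.softmax(z @ q.t() / tau, dim=-1)          # (B, K)

z_anchor = torch.randn(8, d)                           # stand-in anchor representations (B = 8)
z_target = torch.randn(8, d)                           # stand-in target representations
p_anchor = prototype_predictions(z_anchor, prototypes, tau=0.1)     # anchor: temperature tau
p_target = prototype_predictions(z_target, prototypes, tau=0.025)   # target: tau^+ < tau (sharper)
```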
(5) Training Objective
when training encoder…
\(\rightarrow\) penalize when the anchor prediction \(p_{i, m}\) is different from the target prediction \(p_i^{+}\)
( enforce this by using the CE loss \(H\left(p_i^{+}, p_{i, m}\right)\) )
also, incorporate mean entropy maximization (ME-MAX) regularizer
( to encourage the model to utilize the full set of prototypes )
- average prediction across all the anchor views = \(\bar{p}:=\frac{1}{M B} \sum_{i=1}^B \sum_{m=1}^M p_{i, m}\)
- i.e., maximize the entropy \(H(\bar{p})\) of the average prediction
overall objective
- parameters : encoder parameters \(\theta\) and prototypes \(\mathbf{q}\)
- loss function : \(\frac{1}{M B} \sum_{i=1}^B \sum_{m=1}^M H\left(p_i^{+}, p_{i, m}\right)-\lambda H(\bar{p})\)
( note ) only compute gradients with respect to the anchor predictions \(p_{i, m}\)
( not the target predictions \(p_i^{+}\) )
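A sketch of the overall loss, assuming the anchor and target predictions have already been flattened to shape (M·B, K) with each anchor row paired with its target row; the stop-gradient on the target branch and the ME-MAX term mirror the objective above, while the weight `lambda_me_max = 1.0` is an illustrative choice.

```python
import torch

def msn_loss(p_anchor, p_target, lambda_me_max=1.0, eps=1e-8):
    """Cross-entropy to the target predictions plus the ME-MAX regularizer."""
    p_target = p_target.detach()                                      # no gradient through the target branch
    ce = -(p_target * torch.log(p_anchor + eps)).sum(dim=-1).mean()   # H(p_i^+, p_{i,m})
    p_bar = p_anchor.mean(dim=0)                                      # average anchor prediction
    me_max = (p_bar * torch.log(p_bar + eps)).sum()                   # = -H(p_bar)
    return ce + lambda_me_max * me_max                                # minimizing -H(p_bar) maximizes H(p_bar)

# usage with the predictions from the previous sketch
loss = msn_loss(p_anchor, p_target)
```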
