An Overview of Deep Semi-Supervised Learning (2020) - Part 2


Contents

  1. Abstract
  2. Introduction
    1. SSL
    2. SSL Methods
    3. Main Assumptions in SSL
    4. Related Problems
  3. Consistency Regularization
    1. Ladder Networks
    2. Pi-Model
    3. Temporal Ensembling
    4. Mean Teachers
    5. Dual Students
    6. Fast-SWA
    7. Virtual Adversarial Training (VAT)
    8. Adversarial Dropout (AdD)
    9. Interpolation Consistency Training (ICT)
    10. Unsupervised Data Augmentation
  4. Entropy Minimization
  5. Proxy-label Methods
    1. Self-training
    2. Multi-view Training
  6. Holistic Methods
    1. MixMatch
    2. ReMixMatch
    3. FixMatch
  7. Generative Models
    1. VAE for SSL
    2. GAN for SSL
  8. Graph-Based SSL
    1. Graph Construction
    2. Label Propagation
  9. Self-Supervision for SSL


4. Entropy Minimization

  • encourage the network to make confident (i.e., low-entropy) predictions on unlabeled data

  • by adding a loss term which minimizes the entropy of the prediction function \(f_\theta(x)\)

    • ex) categorical output space with \(C\) classes : \(-\sum_{k=1}^C f_\theta(x)_k \log f_\theta(x)_k\).
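The entropy term above can be sketched in NumPy as follows (function and variable names are mine, not from the paper):

```python
import numpy as np

def entropy_min_loss(probs, eps=1e-12):
    """Average entropy -sum_k p_k log p_k of the predicted class
    distributions; adding this term to the training loss pushes the
    network toward low-entropy (confident) predictions."""
    probs = np.clip(probs, eps, 1.0)  # avoid log(0)
    return float(np.mean(-np.sum(probs * np.log(probs), axis=1)))

uniform   = np.array([[0.25, 0.25, 0.25, 0.25]])  # maximal entropy, log C
confident = np.array([[0.97, 0.01, 0.01, 0.01]])  # near-zero entropy
```

Minimizing this loss on unlabeled data drives predictions from the first kind toward the second.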


However … the model can quickly overfit, turning low-confidence ( and possibly wrong ) predictions into confident ones

\(\rightarrow\) entropy minimization doesn’t produce competitive results compared to other SSL methods

( but can produce state-of-the-art results when combined with different approaches )


5. Proxy-label Methods

(1) Self-training

Procedure

  • step 1) small amount of labeled data \(D_l\) is first used to train a prediction function \(f_{\theta}\)

  • step 2) \(f_{\theta}\) is used to assign pseudo-labels to \(x \in D_u\)

    • confidence above pre-determined threshold \(\tau\)

      ( Other heuristics can also be used,

      ex) using the relative confidence instead of the absolute confidence )
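Steps 1-2 can be sketched as follows (the `model_predict` interface is a hypothetical helper, not from the paper):

```python
import numpy as np

def pseudo_label(model_predict, X_unlabeled, tau=0.95):
    """Assign pseudo-labels to the unlabeled examples whose predicted
    confidence exceeds the threshold tau; the rest stay unlabeled."""
    probs = model_predict(X_unlabeled)   # (n, C) class probabilities
    conf = probs.max(axis=1)             # absolute confidence
    keep = conf >= tau
    return X_unlabeled[keep], probs.argmax(axis=1)[keep]
```

The selected pairs are then added to \(D_l\) and \(f_\theta\) is retrained; a relative-confidence variant would rank examples by prediction margin instead of thresholding the max probability.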


Impact of self-training \(\approx\) entropy minimization

  • both force the model to output more confident predictions


The main downside of such methods : the model is unable to correct its own mistakes


Meta Pseudo Labels

  • student-teacher setting
  • teacher : produce the proxy labels
    • based on an efficient meta-learning algorithm called Meta Pseudo Labels (MPL)
    • updated by policy gradients
  • encourages the teacher to adjust the target distns of training examples in a way that improves the learning of the student model




Procedure

  • step 1) Student learns from the Teacher
    • given a single input example \(x \in D_l\), teacher \(f_{\theta^{'}}\) produces a target class-distn to train student \(f_{\theta}\)
    • the student is shown \((x, f_{\theta^{'}}(x))\) & updates its parameters
  • step 2) Teacher learns from the Student
    • the new parameters \(\theta^{(t+1)}\) are evaluated on an example \((x_{val}, y_{val})\)
    • this validation performance is the signal used to update the teacher by policy gradients


(2) Multi-view Training

utilizes multi-view data

  • different views can be collected by …
    • different measuring methods
    • creating limited views of the original data


Notation

  • prediction function : \(f_{\theta_i}\)
  • view of \(x\) : \(v_i(x)\)


a) Co-training

  • 2 conditionally independent views \(v_1(x)\) and \(v_2(x)\)
  • 2 prediction functions \(f_{\theta_1}\) and \(f_{\theta_2}\)
    • each trained on one specific view of the labeled set \(\mathcal{D}_l\)


Step 1) train \(f_{\theta_1}\) and \(f_{\theta_2}\) on \(\mathcal{D}_l\)

Step 2) Proxy Labeling

  • unlabeled data is added to the training set of \(f_{\theta_i}\) if \(f_{\theta_j}\) ( \(j \neq i\) ) outputs a confident prediction


etc

  • using different learning algorithms to learn the two classifiers is also OK
  • two views \(v_1(x)\) and \(v_2(x)\) can be generated by injecting noise or DA
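The proxy-labeling exchange (step 2) can be sketched as below; the `predict_proba` interface is assumed, not prescribed by the paper:

```python
import numpy as np

def co_training_exchange(f1, f2, v1, v2, tau=0.9):
    """One co-training round: each model labels the examples it is
    confident about, and those examples are handed to the *other*
    model's training set (returned as (indices, labels) pairs)."""
    picks = []
    for model, view in ((f1, v1), (f2, v2)):
        probs = model.predict_proba(view)
        conf, labels = probs.max(axis=1), probs.argmax(axis=1)
        idx = np.where(conf >= tau)[0]
        picks.append((idx, labels[idx]))
    # picks[0] was labeled by f1 -> goes to f2, and vice versa
    add_to_f2, add_to_f1 = picks
    return add_to_f1, add_to_f2
```

After the exchange, both models are retrained on their enlarged training sets and the round repeats.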


Democratic Co-training

  • replacing the different views of the input data with a number of models with different architectures and learning algorithms
  • trained models are then used to label a given example x if a majority of models confidently agree on its label.


b) Tri-Training

  • to overcome the lack of data with multiple views
  • utilize the agreement of 3 independently trained models


Procedure

  • step 1) \(\mathcal{D}_l\) is used to train \(f_{\theta_1}\), \(f_{\theta_2}\), \(f_{\theta_3}\)
  • step 2) \(x \in \mathcal{D}_u\) is then added to the training set of the function \(f_{\theta_i}\)
    • if the other 2 models agree on its predicted label
  • step 3) training stops if no data points are being added
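Step 2's agreement rule, as a small sketch (the `preds` interface is my own convention):

```python
import numpy as np

def agreement_labels(preds, i):
    """Pseudo-labels for model i: an unlabeled example is added to its
    training set only when the other two models agree on the label.
    preds: list of three (n,) predicted-label arrays, one per model."""
    j, k = [m for m in range(3) if m != i]
    idx = np.where(preds[j] == preds[k])[0]
    return idx, preds[j][idx]
```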


Pros : generally applicable!

  • requires neither the existence of multiple views nor unique learning algorithms

Cons : using with NN \(\rightarrow\) expensive

  • propose to sample a limited number of unlabeled data at each training epoch

    ( the candidate pool size is increased as the training progresses )


Multi-task tri-training

  • used to reduce the time and sample complexity

  • all three models share the same feature-extractor with model-specific classification layers

  • ex) Tri-Net


Cross-View training

  • (previous works) self-training
    • model plays a dual role of a teacher and a student
  • inspiration from multi-view learning and consistency training
  • model is trained to produce consistent predictions across different views of the inputs
  • Instead of using a single model as a teacher and a student…
    • use a shared encoder & multiple prediction modules :
      • (1) auxiliary student modules
      • (2) a primary teacher module



  • primary teacher module :
    • trained only on labeled examples
    • generate the pseudo-labels, by taking as input the full view of the unlabeled inputs
  • student modules:
    • trained to have consistent predictions with the teacher module


Loss function

  • \(\mathcal{L}=\mathcal{L}_u+\mathcal{L}_s=\frac{1}{ \mid \mathcal{D}_u \mid } \sum_{x \in \mathcal{D}_u} \sum_{i=1}^K d_{\mathrm{MSE}}\left(t(e(x)), s_i(e(x))\right)+\frac{1}{ \mid \mathcal{D}_l \mid } \sum_{x, y \in \mathcal{D}_l} \mathrm{H}(t(e(x)), y)\).
    • encoder : \(e\)
    • teacher module : \(t\)
    • \(K\) student modules : \(s_i\) ( where \(i \in [1,K]\) )
      • each receives only a limited view of the input


The student can learn from the teacher

\(\because\) teacher : UNlimited view & student : limited view
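The loss above, spelled out in NumPy (the shapes are my own convention):

```python
import numpy as np

def cvt_loss(t_u, s_u, t_l, y_l, eps=1e-12):
    """L = L_u + L_s, mirroring the Cross-View Training loss.
    t_u: (n_u, C) teacher predictions on full views of unlabeled data
    s_u: (K, n_u, C) predictions of the K student modules (limited views)
    t_l: (n_l, C) teacher predictions on labeled data
    y_l: (n_l,)  integer ground-truth labels"""
    # L_u: sum over the K students of the MSE to the teacher, mean over D_u
    l_u = np.mean(np.sum((s_u - t_u[None]) ** 2, axis=-1).sum(axis=0))
    # L_s: cross-entropy of the teacher on the labeled set
    l_s = np.mean(-np.log(np.clip(t_l[np.arange(len(y_l)), y_l], eps, 1.0)))
    return l_u + l_s
```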


6. Holistic Methods

unify dominant methods in SSL in a single framework


(1) MixMatch

Input

  • batch from \(\mathcal{D}_l\)
  • batch from \(\mathcal{D}_u\)
  • hyperparameters
    • sharpening temperature \(T\)
    • number of augmentations \(K\)
    • Beta distn parameter \(\alpha\)

Output

  • batch of augmented labeled examples
  • batch of augmented unlabeled examples + proxy labels

\(\rightarrow\) can be used to compute the losses


Procedure

  • step 1) Data Augmentation

    • each unlabeled example \(x \in \mathcal{D}_u\) is augmented \(K\) times : \(\tilde{x}_1, \ldots, \tilde{x}_K\).
  • step 2) Label Guessing

    • producing proxy labels
      • step 2-1) generate the predictions for \(\tilde{x}_1, \ldots, \tilde{x}_K\)
      • step 2-2) average the predictions : \(\hat{y}=1 / K \sum_{k=1}^K\left(\hat{y}_k\right)\)
    • result : \(\left(\tilde{x}_1, \hat{y}\right), \ldots,\left(\tilde{x}_K, \hat{y}\right)\).
  • step 3) Sharpening

    • to produce confident predictions
    • \((\hat{y})_k=(\hat{y})_k^{\frac{1}{T}} / \sum_{c=1}^C(\hat{y})_c^{\frac{1}{T}}\).
  • step 4) MixUp

    • Notation
      • \(\mathcal{L}\) : augmented labeled data
      • \(\mathcal{U}\) : augmented unlabeled data + proxy labels
        • \(K\) times larger than the original batch!
    • step 4-1) mix these 2 batches
      • \(\mathcal{W}=\operatorname{Shuffle}(\operatorname{Concat}(\mathcal{L}, \mathcal{U}))\).
    • step 4-2) divide \(\mathcal{W}\)
      • (1) \(\mathcal{W}_1\) : same size as \(\mathcal{L}\)
      • (2) \(\mathcal{W}_2\) : same size as \(\mathcal{U}\)
    • step 4-3) MixUp operation ( + modification )
      • \(\mathcal{L}^{\prime}=\operatorname{MixUp}\left(\mathcal{L}, \mathcal{W}_1\right)\).
      • \(\mathcal{U}^{\prime}=\operatorname{MixUp}\left(\mathcal{U}, \mathcal{W}_2\right)\).
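Label guessing, sharpening, and the modified MixUp can be sketched as follows (a NumPy sketch; names are mine):

```python
import numpy as np

rng = np.random.default_rng(0)

def sharpen(p, T=0.5):
    """Raise each probability to 1/T and renormalize (step 3)."""
    p = p ** (1.0 / T)
    return p / p.sum(axis=-1, keepdims=True)

def guess_label(preds_k, T=0.5):
    """Average the predictions over the K augmentations, then sharpen."""
    return sharpen(preds_k.mean(axis=0), T)

def mixup(x1, y1, x2, y2, alpha=0.75):
    """MixMatch's modified MixUp: lambda is redrawn as max(lam, 1 - lam),
    so the mixed example stays closer to its first argument."""
    lam = rng.beta(alpha, alpha)
    lam = max(lam, 1.0 - lam)
    return lam * x1 + (1 - lam) * x2, lam * y1 + (1 - lam) * y2
```

The `max(lam, 1 - lam)` modification is what lets \(\mathcal{L}'\) stay "mostly labeled" and \(\mathcal{U}'\) "mostly unlabeled" after mixing with the shuffled batch \(\mathcal{W}\).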


Loss Function

  • standard SSL losses
    • CE loss ( supervised )
    • Consistency loss ( unsupervised )
  • \(\mathcal{L}=\mathcal{L}_s+w \mathcal{L}_u=\frac{1}{ \mid \mathcal{L}^{\prime} \mid } \sum_{x, y \in \mathcal{L}^{\prime}} \mathrm{H}\left(y, f_\theta(x)\right)+w \frac{1}{ \mid \mathcal{U}^{\prime} \mid } \sum_{x, \hat{y} \in \mathcal{U}^{\prime}} d_{\mathrm{MSE}}\left(\hat{y}, f_\theta(x)\right)\).


(2) ReMixMatch

improve MixMatch by introducing 2 new techniques :

  • (1) distribution alignment
  • (2) augmentation anchoring




a) Distribution alignment

  • encourages the marginal distn of predictions on unlabeled data to be close to the marginal distn of GT labels

  • ex) given prediction result \(f_\theta(x)\) …

    \(\rightarrow\) \(f_\theta(x)=\operatorname{Normalize}\left(f_\theta(x) \times p(y) / \tilde{y}\right)\)
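A minimal sketch of the alignment step, assuming the running average \(\tilde{y}\) of past model predictions is passed in:

```python
import numpy as np

def align(probs, p_true, p_model):
    """Scale the prediction by p(y) / tilde-y and renormalize, so the
    marginal of the aligned predictions moves toward the label marginal.
    probs:   (C,) model prediction f_theta(x) on one unlabeled example
    p_true:  (C,) marginal distribution of ground-truth labels p(y)
    p_model: (C,) running average of model predictions (tilde-y)"""
    scaled = probs * p_true / p_model
    return scaled / scaled.sum()
```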


b) Augmentation anchoring

  • feeds multiple strongly augmented versions of the input into the model

    & encourage each output to be close to the prediction for a weakly-augmented version of the same input

  • model’s prediction for a weakly augmented version = proxy label


Loss Function

\(\mathcal{L}=\mathcal{L}_s+w \mathcal{L}_u=\frac{1}{ \mid \mathcal{L}^{\prime} \mid } \sum_{x, y \in \mathcal{L}^{\prime}} \mathrm{H}\left(y, f_\theta(x)\right)+w \frac{1}{ \mid \mathcal{U}^{\prime} \mid } \sum_{x, \hat{y} \in \mathcal{U}^{\prime}} \mathrm{H}\left(\hat{y}, f_\theta(x)\right)\),


Also add a self-supervised loss

  • new unlabeled batch \(\hat{\mathcal{U}}^{\prime}\) of examples is created
    • by rotating all of the examples with an angle \(r \sim \{0, 90, 180, 270\}\).
  • \(\mathcal{L}_{S L}=w^{\prime} \frac{1}{ \mid \hat{\mathcal{U}}^{\prime} \mid } \sum_{x, \hat{y} \in \hat{\mathcal{U}}^{\prime}} \mathrm{H}\left(\hat{y}, f_\theta(x)\right)+\lambda \frac{1}{ \mid \hat{\mathcal{U}}^{\prime} \mid } \sum_{x \in \hat{\mathcal{U}}^{\prime}} \mathrm{H}\left(r, f_\theta(x)\right)\).


(3) FixMatch

simple SSL algorithm, that combines

  • (1) consistency regularization
  • (2) pseudo-labeling

both the supervised & unsupervised losses : CE Loss




Labeled & Unlabeled

  • Labeled : just use the GT labels
  • Unlabeled :
    • step 1) 1 weakly augmented version is computed
      • Weak augmentation function : \(A_w\)
    • step 2) predict it & use it as proxy label
      • if confidence > \(\tau\)
    • step 3) \(K\) strongly augmented versions are computed


Unsupervised Loss :

  • \(\mathcal{L}_u=w \frac{1}{K \mid \mathcal{D}_u \mid } \sum_{x \in \mathcal{D}_u} \sum_{i=1}^K 1\left(\max \left(f_\theta\left(A_w(x)\right)\right) \geq \tau\right) \mathrm{H}\left(f_\theta\left(A_w(x)\right), f_\theta\left(A_s(x)\right)\right)\).
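The masked loss above in NumPy, using the hard pseudo-label (argmax of the weak prediction) to supervise the strong views; shapes are my own convention:

```python
import numpy as np

def fixmatch_unsup_loss(weak, strong, tau=0.95, w=1.0, eps=1e-12):
    """weak:   (n, C)    predictions on the weakly augmented views A_w(x)
    strong: (n, K, C) predictions on the K strongly augmented views A_s(x)
    Only examples whose weak confidence clears tau contribute; the weak
    argmax then supervises every strong view via cross-entropy."""
    mask = weak.max(axis=1) >= tau      # 1(max f(A_w(x)) >= tau)
    pseudo = weak.argmax(axis=1)        # hard proxy label
    n, K, _ = strong.shape
    picked = strong[np.arange(n)[:, None], np.arange(K)[None, :],
                    pseudo[:, None]]    # prob of the pseudo class, (n, K)
    ce = -np.log(np.clip(picked, eps, 1.0))
    return w * float((mask[:, None] * ce).sum()) / (K * n)
```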

