VICReg : Variance-Invariance-Covariance Regularization for Self-Supervised Learning
Contents
- Abstract
- VICReg : Intuition
- VICReg : Detailed Description
- Method
0. Abstract
main challenge of SSL
- prevent a collapse in which the encoders produce constant or non-informative vectors
introduce VICReg (Variance-Invariance-Covariance Regularization)
- explicitly avoids the collapse problem, with two regularization terms applied to both embeddings separately
- (1) a term that maintains the variance of each embedding dimension above a threshold
- (2) a term that decorrelates each pair of variables
- does not require other techniques…
1. VICReg : Intuition
VICReg (Variance-Invariance-Covariance Regularization)
- a self-supervised method for training joint embedding architectures
- basic idea : use a loss function with three terms
- (1) Invariance :
- the MSE between the embedding vectors
- (2) Variance :
- a hinge loss that maintains the standard deviation (over a batch) of each variable above a given threshold
- forces the embedding vectors of samples within a batch to be different ( a small numerical check of this follows the list )
- (3) Covariance :
- a term that attracts the covariances (over a batch) between every pair of (centered) embedding variables towards zero
- decorrelates the variables of each embedding
- prevents an informational collapse
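To make the collapse intuition concrete, here is a tiny PyTorch check (batch size, embedding dimension, and $\gamma = 1$ are arbitrary illustrative choices): when every sample maps to the same vector, the per-dimension standard deviation is zero and the variance hinge fires at full strength, even though the invariance term alone would be perfectly satisfied.

```python
import torch

# A "collapsed" batch: all 256 samples share one 8-dim embedding vector.
Z = torch.ones(256, 8)

std = Z.std(dim=0)          # std over the batch, per dimension -> all zeros
gamma = 1.0
hinge = torch.clamp(gamma - std, min=0).mean()

print(std.max().item())     # 0.0 : no spread across the batch
print(hinge.item())         # 1.0 : the variance penalty is maximal under collapse
```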
2. VICReg : Detailed Description
use a Siamese net ( a minimal sketch follows this list )
- encoder $f_\theta$ : outputs the representations
- expander $h_\phi$ : maps the representations into an embedding space
- role 1 ) eliminate the information by which the two representations differ
- role 2 ) expand the dimension in a non-linear fashion, so that decorrelating the embedding variables also reduces the dependencies between the variables of the representation vector
- loss function : the invariance term $s$
- learns invariance to data transformations
- regularized with a variance term $v$ and a covariance term $c$
( After pretraining, the expander is discarded and the encoder's representations are used for downstream tasks )
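A minimal sketch of this two-branch setup, using small MLP stand-ins (the paper uses a ResNet-50 encoder and a 3-layer MLP expander; the dimensions here are illustrative only):

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """Stand-in for f_theta : maps an input to its representation y."""
    def __init__(self, in_dim: int = 784, rep_dim: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 256), nn.ReLU(),
            nn.Linear(256, rep_dim),
        )

    def forward(self, x):
        return self.net(x)

class Expander(nn.Module):
    """Stand-in for h_phi : expands a representation y into an embedding z."""
    def __init__(self, rep_dim: int = 128, emb_dim: int = 512):
        super().__init__()
        # non-linear dimension expansion (rep_dim -> emb_dim)
        self.net = nn.Sequential(
            nn.Linear(rep_dim, emb_dim), nn.ReLU(),
            nn.Linear(emb_dim, emb_dim),
        )

    def forward(self, y):
        return self.net(y)

# both branches apply the same encoder and expander
# ( weight sharing is the usual setup, though VICReg does not require it )
f_theta, h_phi = Encoder(), Expander()
```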
(1) Method
a) Notation
- image $i$, sampled from a dataset $\mathcal{D}$
- 2 image transformations $t, t' \sim \mathcal{T}$ ( = random crops of the image, followed by color distortions )
- $x = t(i)$
- $x' = t'(i)$
- 2 representations
- $y = f_\theta(x)$
- $y' = f_\theta(x')$
- 2 embeddings
- $z = h_\phi(y)$
- $z' = h_\phi(y')$
→ Loss is computed on these embeddings ( batch shapes illustrated below )
- batches of embeddings : $Z = [z_1, \dots, z_n]$ and $Z' = [z'_1, \dots, z'_n]$
- $z^j$ : the vector composed of the value at dimension $j$ of every vector in $Z$
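Putting the notation together, a hypothetical forward pass with linear stand-ins for $f_\theta$ and $h_\phi$ (the transformations $t, t'$ are mimicked with additive noise purely for illustration), showing the batch shapes the loss operates on:

```python
import torch
import torch.nn as nn

n, in_dim, rep_dim, d = 32, 784, 128, 512   # batch size n, embedding dimension d
f_theta = nn.Linear(in_dim, rep_dim)        # stand-in encoder
h_phi   = nn.Linear(rep_dim, d)             # stand-in expander

i = torch.randn(n, in_dim)                  # a batch of images i from D
x       = i + 0.1 * torch.randn_like(i)     # x  = t(i)   (noise as a fake transform)
x_prime = i + 0.1 * torch.randn_like(i)     # x' = t'(i)
y, y_prime = f_theta(x), f_theta(x_prime)   # representations, shape (n, rep_dim)
Z, Z_prime = h_phi(y), h_phi(y_prime)       # embeddings Z, Z', shape (n, d)

print(Z.shape, Z_prime.shape)               # torch.Size([32, 512]) twice
```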
b) Variance, invariance and covariance terms
- Variance regularization term $v$ ( sketched in code below )
- a hinge function on the standard deviation of the embeddings along the batch dimension :
- $v(Z) = \frac{1}{d} \sum_{j=1}^{d} \max(0, \gamma - S(z^j, \epsilon))$
- $S(x, \epsilon) = \sqrt{\mathrm{Var}(x) + \epsilon}$, where $\epsilon$ is a small scalar preventing numerical instabilities
- encourages the standard deviation inside the current batch to be equal to $\gamma$ along each dimension
- prevents the collapse in which all inputs are mapped to the same vector
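A sketch of the variance term as written above; $\gamma = 1$ and $\epsilon = 10^{-4}$ follow the values reported in the paper, but treat this as an illustration rather than the reference implementation:

```python
import torch

def variance_term(Z: torch.Tensor, gamma: float = 1.0, eps: float = 1e-4) -> torch.Tensor:
    # S(z^j, eps) : regularized std of each embedding dimension over the batch
    std = torch.sqrt(Z.var(dim=0) + eps)
    # hinge : only dimensions whose std falls below gamma are penalized
    return torch.clamp(gamma - std, min=0).mean()
```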
- Covariance matrix of $Z$
- $C(Z) = \frac{1}{n-1} \sum_{i=1}^{n} (z_i - \bar{z})(z_i - \bar{z})^T$, where $\bar{z} = \frac{1}{n} \sum_{i=1}^{n} z_i$
- Covariance regularization term $c$ ( inspired by Barlow Twins ; sketched below )
→ the sum of the squared off-diagonal coefficients of $C(Z)$, scaled by $1/d$
→ $c(Z) = \frac{1}{d} \sum_{i \neq j} [C(Z)]_{i,j}^2$
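The same, sketched for the covariance term (center the batch, form $C(Z)$, then sum the squared off-diagonal entries):

```python
import torch

def covariance_term(Z: torch.Tensor) -> torch.Tensor:
    n, d = Z.shape
    Z = Z - Z.mean(dim=0)                 # center each embedding variable
    cov = (Z.T @ Z) / (n - 1)             # covariance matrix C(Z), shape (d, d)
    off_diag_sq = cov.pow(2).sum() - cov.pow(2).diagonal().sum()
    return off_diag_sq / d                # squared off-diagonals, scaled by 1/d
```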
- Invariance criterion $s$ ( between $Z$ and $Z'$ ; sketched below )
- the MSE between each pair of vectors :
- $s(Z, Z') = \frac{1}{n} \sum_{i} \| z_i - z'_i \|_2^2$
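And the invariance term, the batch-averaged squared Euclidean distance between paired embeddings (note that `F.mse_loss` would differ by a factor of $1/d$, since it also averages over dimensions):

```python
import torch

def invariance_term(Z: torch.Tensor, Z_prime: torch.Tensor) -> torch.Tensor:
    # s(Z, Z') = (1/n) * sum_i || z_i - z'_i ||_2^2
    return (Z - Z_prime).pow(2).sum(dim=1).mean()
```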
c) Overall loss function
$\ell(Z, Z') = \lambda s(Z, Z') + \mu [v(Z) + v(Z')] + \nu [c(Z) + c(Z')]$
overall objective function ( over an unlabelled dataset $\mathcal{D}$ ; a combined sketch follows )
- $\mathcal{L} = \sum_{I \in \mathcal{D}} \sum_{t, t' \sim \mathcal{T}} \ell(Z^I, Z'^I)$
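Combining everything into one self-contained sketch of the weighted loss; the coefficients $\lambda = \mu = 25$ and $\nu = 1$ are the defaults reported in the paper, and the helper below is a hypothetical illustration rather than the official implementation:

```python
import torch

def vicreg_loss(Z, Z_prime, lam=25.0, mu=25.0, nu=1.0, gamma=1.0, eps=1e-4):
    n, d = Z.shape

    # invariance term s(Z, Z')
    inv = (Z - Z_prime).pow(2).sum(dim=1).mean()

    # variance term v(.) : hinge on the per-dimension std over the batch
    def v(A):
        std = torch.sqrt(A.var(dim=0) + eps)
        return torch.clamp(gamma - std, min=0).mean()

    # covariance term c(.) : squared off-diagonals of the covariance matrix
    def c(A):
        A = A - A.mean(dim=0)
        cov = (A.T @ A) / (n - 1)
        return (cov.pow(2).sum() - cov.pow(2).diagonal().sum()) / d

    return lam * inv + mu * (v(Z) + v(Z_prime)) + nu * (c(Z) + c(Z_prime))

# usage : gradients flow back through Z and Z' into f_theta and h_phi
Z       = torch.randn(32, 512, requires_grad=True)
Z_prime = torch.randn(32, 512, requires_grad=True)
loss = vicreg_loss(Z, Z_prime)
loss.backward()
```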