VIME : Extending the Success of Self- and Semi-supervised Learning to Tabular Domain
Contents
- Abstract
- Introduction
- Related Works
- Self-SL
- Semi-SL
- Problem Formulation
- Self-SL
- Semi-SL
- Proposed Model : VIME
- Self-SL for tabular data
- Semi-SL for tabular data
0. Abstract
Self- and semi-supervised learning :
- a lot of progress in NLP & CV
- heavily rely on the unique structure in the domain datasets
- ex) NLP : semantic relationships in language
- ex) CV : spatial relationships in images
\(\rightarrow\) not adaptable to general tabular data
VIME (proposal)
( = Value Imputation and Mask Estimation )
- novel self- and semi-supervised learning for tabular data
- (1) create a novel “pretext task” :
- estimating mask vectors from corrupted tabular data
- (2) introduce a novel “tabular data augmentation method”
1. Introduction
Scarce labeled datasets
- ex) 100,000 Genomes project
- sequenced 100,000 genomes from around 85,000 NHS patients affected by a rare disease
- rare diseases : each occurs in fewer than 1 in 2,000 people
No effective self / semi SL for tabular data
\(\because\) heavily rely on the spatial or semantic structure of image or language data
[ Self SL ] pretext tasks
- ex) ( NLP ) BERT : pretext tasks such as masked language modeling & next-sentence prediction
- ex) ( CV ) rotation, jigsaw puzzle, and colorization
[ Semi SL ] regularizer
- regularizers used for the predictive model are based on some prior knowledge of these data structures
- ex) consistency regularizer :
- encourages the predictive model to have the same output distribution on a sample and its augmented variants ( ex. rotation, convex combination of images )
\(\rightarrow\) not applicable to tabular data
Contribution
propose novel self / semi SL for tabular data
- (1) self SL : introduce a novel pretext task
- mask vector estimation ( in addition to feature vector estimation )
- (2) semi SL : introduce a novel tabular DA
- use the trained encoder to generate multiple augmented samples
- by masking each data point with several different masks & imputing the corrupted values for each masked data point
\(\rightarrow\) propose VIME (Value Imputation and Mask Estimation)
2. Related Works
(1) Self-SL
categorized into 2 types :
- (1) pretext task
- (2) contrastive learning
(1) pretext task
- mostly appropriate for images / natural language
- examples )
- images ) surrogate classes prediction (scaling and translation), rotation degree predictions, colorization, relative position of patches estimation, jigsaw puzzle solving, image denoising, partial-to-partial registration
- natural language ) next words and previous words predictions
(2) contrastive learning
- also mostly appropriate for images / natural language
- examples ) contrastive predictive coding, contrastive multi-view coding, SimCLR, momentum contrast
existing work on self-supervised learning, applied to tabular data
- DAE (Denoising Auto-Encoder)
- pretext task : recover the original sample from a corrupted sample
- Context Encoder
- pretext task : reconstruct the original sample from both the corrupted sample and the mask vector.
- TabNet, TaBERT
- pretext task : recovering corrupted tabular data
[ Proposal ]
- new pretext task : recover the mask vector
- novel tabular data augmentation
(2) Semi-SL
categorized into 2 types :
- (1) entropy minimization
- (2) consistency regularization
(1) entropy minimization
- encourages a classifier to output low entropy predictions on unlabeled data
(2) consistency regularization
- encourages consistency between a data point & its stochastically altered version
- ex) Mean Teacher, VAT
- ex) MixMatch, ReMixMatch = (1) + (2)
3. Problem Formulation
Notation
- \(\mathcal{D}_l=\left\{\mathbf{x}_i, y_i\right\}_{i=1}^{N_l}\) : (small) LABELED data
- \(\mathcal{D}_u=\left\{\mathbf{x}_i\right\}_{i=N_l+1}^{N_l+N_u}\) : (large) UNLABELED data
- where \(N_u \gg N_l\), \(\mathbf{x}_i \in \mathcal{X} \subseteq \mathbb{R}^d\) & \(y_i \in \mathcal{Y}\)
- \(y_i\) : scalar / multi-dim vector
- scalar ( in single-task learning )
- multi-dim vector ( in multi-task learning )
- \(f: \mathcal{X} \rightarrow \mathcal{Y}\) : predictive model
- Loss
- \(\sum_{i=1}^{N_l} l\left(f\left(\mathbf{x}_i\right), y_i\right)\) : empirical supervised loss
- \(\mathbb{E}_{(\mathbf{x}, y) \sim p_{X, Y}}[l(f(\mathbf{x}), y)]\) : expected supervised loss
Assumption :
- \(\mathbf{x}_i\) in \(\mathcal{D}_l\) and \(\mathcal{D}_u\) is sampled i.i.d. from a feature distribution \(p_X\)
- \(\left(\mathbf{x}_i, y_i\right)\) in \(\mathcal{D}_l\) are drawn from a joint distribution \(p_{X, Y}\)
- only limited labeled samples from \(p_{X, Y}\) are available
(1) Self-SL
focus on pretext tasks
- challenging, but highly relevant to the downstream tasks that we attempt to solve
self-supervised learning
- (1) encoder function \(e: \mathcal{X} \rightarrow \mathcal{Z}\)
- input : \(\mathbf{x} \in \mathcal{X}\)
- output : \(\mathbf{z}=e(\mathbf{x}) \in \mathcal{Z}\)
- \(\mathbf{z}\) is optimized to solve a pretext task,
- defined with (1) pseudo-label \(y_s \in \mathcal{Y}_s\) & (2) self-supervised loss function \(l_{ss}\)
- (2) pretext predictive model : \(h: \mathcal{Z} \rightarrow \mathcal{Y}_s\)
- trained jointly with the encoder function \(e\), by minimizing the expected self-supervised loss function \(l_{ss}\)
- ( \(\min _{e, h} \mathbb{E}_{\left(\mathbf{x}_s, y_s\right) \sim p_{X_s, Y_s}}\left[l_{s s}\left(y_s,(h \circ e)\left(\mathbf{x}_s\right)\right)\right]\) )
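A minimal PyTorch sketch of this setup ( the layer sizes, the single-hidden-layer encoder, and the abstract loss `l_ss` are illustrative assumptions, not the paper's architecture ) :

```python
import torch
import torch.nn as nn

d, d_z, d_ys = 10, 16, 10                               # example dimensions (assumed)
encoder = nn.Sequential(nn.Linear(d, d_z), nn.ReLU())   # e : X -> Z
pretext_head = nn.Linear(d_z, d_ys)                     # h : Z -> Y_s

def self_sl_objective(x_s, y_s, l_ss):
    """Mini-batch estimate of min_{e,h} E[ l_ss(y_s, (h ∘ e)(x_s)) ]."""
    return l_ss(y_s, pretext_head(encoder(x_s)))
```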
(2) Semi-SL
optimizes the predictive model \(f\) , by minimizing (1) + (2)
- (1) supervised loss function
- (2) unsupervised loss function
\(\min _f \mathbb{E}_{(\mathbf{x}, y) \sim p_{X Y}}[l(y, f(\mathbf{x}))]+\beta \cdot \mathbb{E}_{\mathbf{x} \sim p_X, \mathbf{x}^{\prime} \sim \tilde{p}_X\left(\mathbf{x}^{\prime} \mid \mathbf{x}\right)}\left[l_u\left(f(\mathbf{x}), f\left(\mathbf{x}^{\prime}\right)\right)\right]\).
- \(\mathbf{x}^{\prime}\) : perturbed version of \(\mathbf{x}\)
- assumed to be drawn from a conditional distribution \(\tilde{p}_X\left(\mathbf{x}^{\prime} \mid \mathbf{x}\right)\)
- term (1) : estimated using \(\mathcal{D}_l\)
- term (2) : estimated using \(\mathcal{D}_u\)
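A hedged one-function sketch of how the two terms are combined in practice ( `perturb` stands in for the generic conditional \(\tilde{p}_X(\mathbf{x}' \mid \mathbf{x})\) ; all names are placeholders ) :

```python
def semi_sl_objective(f, l, l_u, x_lab, y_lab, x_unlab, perturb, beta):
    """Mini-batch estimate of E[l(y, f(x))] + beta * E[l_u(f(x), f(x'))]."""
    supervised = l(y_lab, f(x_lab))              # term (1), estimated on D_l
    x_prime = perturb(x_unlab)                   # x' ~ p~_X(x' | x)
    unsupervised = l_u(f(x_unlab), f(x_prime))   # term (2), estimated on D_u
    return supervised + beta * unsupervised
```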
4. Proposed Model: VIME
(1) Self-SL : propose two pretext tasks
(2) Semi-SL : develop an unsupervised loss function
(1) Self-SL for tabular data
propose 2 pretext tasks :
- (1) feature vector estimation
- (2) mask vector estimation.
Goal : optimize a pretext model to….
- recover an input sample (a feature vector) from its corrupted variant,
- estimate the mask vector that has been applied
Notation
- pretext distribution : \(p_{X_s, Y_s}\)
- binary mask vector : \(\mathbf{m}=\left[m_1, \ldots, m_d\right]^{\top} \in\{0,1\}^d\)
- \(m_j\) : randomly sampled from a Bernoulli distribution with prob \(p_m\)
- ( \(p_{\mathbf{m}}=\prod_{j=1}^d \operatorname{Bern}\left(m_j \mid p_m\right)\) )
- pretext generator : \(g_m: \mathcal{X} \times\{0,1\}^d \rightarrow \mathcal{X}\)
- input : \(\mathbf{x}\) from \(\mathcal{D}_u\) & mask vector \(\mathbf{m}\)
- output : masked sample \(\tilde{\mathbf{x}}\)
Pretext Generation
\(\tilde{\mathbf{x}}=g_m(\mathbf{x}, \mathbf{m})=\mathbf{m} \odot \overline{\mathbf{x}}+(1-\mathbf{m}) \odot \mathbf{x}\).
- where the \(j\)-th feature of \(\overline{\mathbf{x}}\) is sampled from the empirical distribution \(\hat{p}_{X_j}=\frac{1}{N_u} \sum_{i=N_l+1}^{N_l+N_u} \delta\left(x_j = x_{i, j}\right)\)
- where \(x_{i, j}\) is the \(j\)-th feature of the \(i\)-th sample in \(\mathcal{D}_u\)
- corrupted sample \(\tilde{\mathbf{x}}\) is not only tabular but also similar to the samples in \(\mathcal{D}_u\)
- (compared to Gaussian Noise, etc … )
- generates \(\tilde{\mathbf{x}}\) that is more difficult to distinguish from \(\mathbf{x}\)
Two sources of randomness
( in pretext distribution \(p_{X_s, Y_s}\) )
- (1) \(\mathbf{m}\) : random vector ( randomness from Bernoulli distn )
- (2) \(g_m\) : pretext generator ( randomness from \(\overline{\mathbf{x}}\) )
\(\rightarrow\) increases the difficulty of reconstruction
( difficulty can be adjusted by hyperparameter \(p_m\) ( = prob of corruption) )
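A minimal NumPy sketch of the two sampling steps above ( mask from \(\mathrm{Bern}(p_m)\), \(\overline{\mathbf{x}}\) via per-column shuffling as a draw from each \(\hat{p}_{X_j}\) ) ; function names are illustrative, not the authors' code :

```python
import numpy as np

def mask_generator(p_m, x):
    """Sample a binary mask matrix, m_ij ~ Bern(p_m), with the same shape as the batch x."""
    return np.random.binomial(1, p_m, size=x.shape).astype(float)

def pretext_generator(m, x):
    """x_tilde = m * x_bar + (1 - m) * x, where each column of x_bar is drawn
    from that feature's empirical distribution (independent column shuffles)."""
    n, d = x.shape
    x_bar = np.zeros_like(x)
    for j in range(d):
        x_bar[:, j] = x[np.random.permutation(n), j]   # sample from \hat{p}_{X_j}
    return m * x_bar + (1 - m) * x

# example usage on an unlabeled batch x_u of shape (n, d):
# m = mask_generator(0.3, x_u)          # p_m = 0.3
# x_tilde = pretext_generator(m, x_u)
```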
Compared to conventional methods, more challenging
( conventional methods ex : rotation, coloring… )
- (conventional) only corrupt the raw values
- (proposed masking) completely removes some of the features from \(\mathbf{x}\) & replaces them with a noise sample \(\overline{\mathbf{x}}\), where each feature may come from a different random sample in \(\mathcal{D}_u\)
Divide the task into 2 sub-tasks ( = pretext tasks )
- (1) Mask vector estimation : predict which features have been masked
- (2) Feature vector estimation : predict the values of the features that have been corrupted.
Predictive model
Separate pretext predictive model ( for each pretext task )
- (1) Mask vector estimator, \(s_m: \mathcal{Z} \rightarrow[0,1]^d\)
- (2) Feature vector estimator, \(s_r: \mathcal{Z} \rightarrow \mathcal{X}\)
Loss Function
\(\min _{e, s_m, s_r} \mathbb{E}_{\mathbf{x} \sim p_X, \mathbf{m} \sim p_{\mathbf{m}}, \tilde{\mathbf{x}} \sim g_m(\mathbf{x}, \mathbf{m})}\left[l_m(\mathbf{m}, \hat{\mathbf{m}})+\alpha \cdot l_r(\mathbf{x}, \hat{\mathbf{x}})\right]\).
- \(\hat{\mathbf{m}}=\left(s_m \circ e\right)(\tilde{\mathbf{x}})\) , \(\hat{\mathbf{x}}=\left(s_r \circ e\right)(\tilde{\mathbf{x}})\)
Term 1) \(l_m\) : sum of the BCE for each dimension of the mask vector
- \(l_m(\mathbf{m}, \hat{\mathbf{m}})=-\frac{1}{d}\left[\sum_{j=1}^d m_j \log \left[\left(s_m \circ e\right)_j(\tilde{\mathbf{x}})\right]+\left(1-m_j\right) \log \left[1-\left(s_m \circ e\right)_j(\tilde{\mathbf{x}})\right]\right]\).
Term 2) \(l_r\) : reconstruction loss
- \(l_r(\mathbf{x}, \hat{\mathbf{x}})=\frac{1}{d}\left[\sum_{j=1}^d\left(x_j-\left(s_r \circ e\right)_j(\tilde{\mathbf{x}})\right)^2\right]\).
- ( for categorical variables, modified with CE loss )
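A hedged PyTorch sketch of the joint pretext objective ( the single-hidden-layer encoder, layer sizes, and the assumption of continuous features are illustrative ; categorical features would swap the MSE term for cross-entropy as noted above ) :

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

d, d_z = 20, 20                                              # example dimensions (assumed)
encoder   = nn.Sequential(nn.Linear(d, d_z), nn.ReLU())      # e   : X -> Z
mask_head = nn.Sequential(nn.Linear(d_z, d), nn.Sigmoid())   # s_m : Z -> [0,1]^d
feat_head = nn.Linear(d_z, d)                                # s_r : Z -> X (continuous features)

def pretext_loss(x, m, x_tilde, alpha=1.0):
    """l_m(m, m_hat) + alpha * l_r(x, x_hat), averaged over features and the batch."""
    z = encoder(x_tilde)
    m_hat = mask_head(z)                                     # (s_m ∘ e)(x_tilde)
    x_hat = feat_head(z)                                     # (s_r ∘ e)(x_tilde)
    l_m = F.binary_cross_entropy(m_hat, m)                   # mask vector estimation (BCE)
    l_r = F.mse_loss(x_hat, x)                               # feature vector estimation (MSE)
    return l_m + alpha * l_r
```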
Intuition
- important for \(e\) to capture the correlation among the features of \(\mathbf{x}\)
- \(s_m\) : identify the masked features from the inconsistency between feature values
- \(s_r\) : impute the masked features by learning from the correlated non-masked features
- ex) if the value of a feature is very different from its correlated features,
\(\rightarrow\) this feature is likely masked and corrupted
(2) Semi-SL for tabular data
show how the encoder \(e\) can be used in semi-supervised learning
( \(f_e=f \circ e\) , \(\hat{y}=f_e(\mathbf{x})\) )
Train predictive model \(f\), with loss function below :
\(\mathcal{L}_{\text {final }}=\mathcal{L}_s+\beta \cdot \mathcal{L}_u\),
- ( supervised loss ) \(\mathcal{L}_s\)
- \(\mathcal{L}_s=\mathbb{E}_{(\mathbf{x}, y) \sim p_{X Y}}\left[l_s\left(y, f_e(\mathbf{x})\right)\right]\).
- ex) MSE, CE
- ( unsupervised (consistency) loss ) \(\mathcal{L}_u\)
- \(\mathcal{L}_u=\mathbb{E}_{\mathbf{x} \sim p_X, \mathbf{m} \sim p_{\mathbf{m}}, \tilde{\mathbf{x}} \sim g_m(\mathbf{x}, \mathbf{m})}\left[\left(f_e(\tilde{\mathbf{x}})-f_e(\mathbf{x})\right)^2\right]\).
- ( inspired by the idea in consistency regularizer )
stochastic approximation of \(\mathcal{L}_u\) :
- with \(K\) augmented samples
- \(\hat{\mathcal{L}}_u=\frac{1}{N_b K} \sum_{i=1}^{N_b} \sum_{k=1}^K\left[\left(f_e\left(\tilde{\mathbf{x}}_{i, k}\right)-f_e\left(\mathbf{x}_i\right)\right)^2\right]=\frac{1}{N_b K} \sum_{i=1}^{N_b} \sum_{k=1}^K\left[\left(f\left(\mathbf{z}_{i, k}\right)-f\left(\mathbf{z}_i\right)\right)^2\right]\).
- \(N_b\) : batch size
output for a new test sample \(\mathbf{x}^t\) : \(\hat{y}=f_e\left(\mathbf{x}^t\right)\)
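A hedged PyTorch sketch of the final training loss, with the encoder \(e\) carried over from the self-supervised stage ( whether it is frozen or fine-tuned, the choice of CE for \(l_s\), and the values of \(K\), \(\beta\), \(p_m\) are assumptions ) :

```python
import torch
import torch.nn.functional as F

def corrupt(x, p_m):
    """Torch version of the pretext generator: m ~ Bern(p_m), x_bar by per-column shuffling."""
    m = torch.bernoulli(torch.full_like(x, p_m))
    x_bar = torch.stack([x[torch.randperm(x.size(0)), j] for j in range(x.size(1))], dim=1)
    return m * x_bar + (1 - m) * x

def semi_sl_loss(f, encoder, x_lab, y_lab, x_unlab, p_m=0.3, K=3, beta=1.0):
    """L_final = L_s + beta * L_u, with L_u averaged over K corrupted views (f_e = f ∘ e)."""
    l_s = F.cross_entropy(f(encoder(x_lab)), y_lab)          # supervised term (CE assumed)
    y_clean = f(encoder(x_unlab))                            # f_e(x) on the unlabeled batch
    l_u = sum(((f(encoder(corrupt(x_unlab, p_m))) - y_clean) ** 2).mean()
              for _ in range(K)) / K                         # consistency term over K views
    return l_s + beta * l_u
```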