Transfer Learning with Deep Tabular Models (ICLR 2023)

https://openreview.net/pdf?id=b0RuGUYo8pA


Contents

  1. Abstract
  2. Introduction
  3. TL setup in Tabular Domain
    1. MetaMIMIC for TL
    2. Tabular models
    3. TL setups and baselines
  4. Results for TL
  5. SSL pretraining
    1. MLM
    2. CL
    3. Sup vs. Self-Sup pretraining
  6. Pseudo features


Abstract

Major advantage of neural networks (NN):

  • easily fine-tuned in NEW domains & learn REUSABLE features


Propose Transfer Learning (TL) with Tabular DL


1. Introduction

Design a benchmark TL task using MetaMIMIC repository

  • Compare GBDT methods vs. DL methods

  • Compare Supervised pre-training vs. Self-supervised pre-training


Propose pseudo-feature method

  • for the case when UPstream data features \(\neq\) DOWNstream data features
    • ex) \(x_i\) is only in DOWNstream data
  • Details
    • Step 1) pretrain with UPSTREAM (w/o \(x_i\))
    • Step 2) finetune with DOWNSTREAM
      • task: predicting \(x_i\)
    • Step 3) assign pseudo-values \(\hat{x}_i\) to UPSTREAM
    • Step 4) pretrain with UPSTREAM (with \(\hat{x}_i\))
    • Step 5) finetune with DOWNSTREAM
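
The data flow of Steps 1-3 can be sketched in a few lines. This is only an illustration under loud assumptions: a nearest-neighbour lookup stands in for the pretrained and fine-tuned network, Steps 1-2 are collapsed into a single "learn to predict the missing feature" step, and the names `nearest_neighbor_predict` and `add_pseudo_feature` are hypothetical.

```python
import math

def nearest_neighbor_predict(train_X, train_y, query):
    """Stand-in predictor (assumption): return the target of the closest training row."""
    best = min(range(len(train_X)), key=lambda j: math.dist(train_X[j], query))
    return train_y[best]

def add_pseudo_feature(upstream_X, downstream_X, missing_idx):
    """Fill in a feature that exists only downstream (column `missing_idx`).

    upstream_X   : rows WITHOUT the feature
    downstream_X : rows WITH the feature at position `missing_idx`
    Returns the upstream rows with a pseudo-value inserted at `missing_idx`.
    """
    # Steps 1-2 (collapsed here): learn to predict x_i from the shared
    # features, using downstream rows where x_i is observed.
    shared = [row[:missing_idx] + row[missing_idx + 1:] for row in downstream_X]
    target = [row[missing_idx] for row in downstream_X]

    # Step 3: assign a pseudo-value to every upstream row.
    augmented = []
    for row in upstream_X:
        x_hat = nearest_neighbor_predict(shared, target, row)
        augmented.append(row[:missing_idx] + [x_hat] + row[missing_idx:])
    return augmented

# Toy data (made up): the middle column exists only downstream.
upstream = [[0.0, 1.0], [5.0, 5.0]]
downstream = [[0.1, 10.0, 0.9], [4.8, 20.0, 5.1]]
upstream_aug = add_pseudo_feature(upstream, downstream, missing_idx=1)
print(upstream_aug)
# Steps 4-5 would pretrain on `upstream_aug` and fine-tune on `downstream`.
```

After this step both datasets share the same feature layout, which is what Steps 4-5 rely on.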


Contributions

  • Deep Tabular models + TL
  • Compare two pre-training settings
    • (1) SUPERVISED pre-training
    • (2) SELF-SUPERVISED pre-training
  • Pseudo-feature method
    • to align UPstream & DOWNstream features


2. TL setup in Tabular Domain

(1) MetaMIMIC for TL

figure2

a) MetaMIMIC

  • medical diagnosis data
  • contains similar test results (features) across patients
  • 12 binary prediction tasks
    • related tasks of varied similarity \(\rightarrow\) suitable for TL
  • 34925 patients
  • 172 features ( 1 categorical … gender )


b) UPstream & DOWNstream tasks

Split the 12 MetaMIMIC tasks:

12 tasks = 11 upstream + 1 downstream

  • # of data in downstream: 4/10/20/100/200 ( 5 scenarios )

\(\rightarrow\) total of 60 combinations ( 12 downstream choices \(\times\) 5 sample sizes )
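
A quick way to check the count is to enumerate the scenarios. The task names below are placeholders (assumption); only the counts come from the notes.

```python
from itertools import product

tasks = [f"task_{i}" for i in range(12)]   # placeholder task names (assumption)
sample_sizes = [4, 10, 20, 100, 200]       # downstream sample counts from the notes

scenarios = []
for downstream, n in product(tasks, sample_sizes):
    # Every task other than the chosen downstream one is used upstream.
    upstream = [t for t in tasks if t != downstream]
    scenarios.append({"downstream": downstream, "upstream": upstream, "n_samples": n})

print(len(scenarios))  # 12 * 5 = 60
```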


(2) Tabular models

6 models = 4 DL + 2 GBDT

  • 4 DL = FT-Transformer + TabTransformer + MLP + ResNet
  • 2 GBDT = CatBoost + XGBoost


(3) TL setups and baselines

For the downstream classification head: 4 options ( 2 \(\times\) 2 )

  • (1) classification head: Linear vs MLP
  • (2) feature extractor: fine-tune vs freeze


Baselines

  • NN from scratch ( on downstream data )

  • CatBoost & XGBoost

    • with stacking
    • without stacking

    ( stacking = 11 upstream targets as input features of downstream task )
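
As a minimal sketch of the stacking baseline, the upstream target values of each row are appended to its feature vector before the GBDT is trained. The helper name `stack_upstream_targets` is hypothetical, and the sketch assumes the upstream target values (labels or predictions) are available for every downstream row.

```python
def stack_upstream_targets(features, upstream_targets):
    """Append each row's upstream target values to its feature vector."""
    if len(features) != len(upstream_targets):
        raise ValueError("features and upstream_targets must have the same number of rows")
    return [list(x) + list(t) for x, t in zip(features, upstream_targets)]

# Toy data (made up); 3 upstream targets for brevity, 11 in the setup above.
X_down = [[0.3, 1.2], [0.7, 0.4]]
y_up = [[1, 0, 0], [0, 1, 1]]
X_stacked = stack_upstream_targets(X_down, y_up)
print(X_stacked)
```

The stacked rows are then passed to the GBDT model in place of the raw features.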


3. Results for TL

Compare DL methods vs GBDT methods

  • Metric : rank aggregation metric
    • ranks take into account the statistical significance of performance differences
  • Result : DL > GBDT at all data levels, especially in the LOW data regime ( 4~20 downstream samples )
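
The idea of significance-aware ranking can be illustrated as follows. This is a sketch, not the paper's procedure: a fixed `tolerance` stands in for a statistical test (assumption), the function name is hypothetical, and the scores are made-up numbers for anonymous models.

```python
def significance_aware_ranks(scores, tolerance):
    """Rank models (higher score is better); models whose score is within
    `tolerance` of the leader of the current rank group share that rank."""
    ordered = sorted(scores, key=scores.get, reverse=True)
    ranks = {}
    rank = 0
    leader_score = None
    for name in ordered:
        # Open a new rank group only when the gap to the group leader
        # exceeds the tolerance (i.e. is treated as significant).
        if leader_score is None or leader_score - scores[name] > tolerance:
            rank += 1
            leader_score = scores[name]
        ranks[name] = rank
    return ranks

# Made-up scores for illustration only.
toy_scores = {"model_a": 0.80, "model_b": 0.795, "model_c": 0.76, "model_d": 0.755}
ranks = significance_aware_ranks(toy_scores, tolerance=0.01)
print(ranks)
```

Models that are not distinguishable under the chosen criterion tie, so small, noisy differences do not change the ranking.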


figure2


Summary

  • MLP is competitive, especially in LOW data regime
  • FT-Transformer offers consistent performance gains over GBDT on ALL data levels
  • Representation learning with DL brings significant gain
  • (most cases) MLP head > Linear head


4. SSL pretraining

(1) MLM

Randomly mask one feature & predict it using the other \(n-1\) features
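
Building one training example for this objective might look like the sketch below. The `MASK` placeholder and the function name `make_mlm_example` are assumptions; how the masked slot is actually encoded for the model is not specified in these notes.

```python
import random

MASK = None  # hypothetical placeholder for the masked value

def make_mlm_example(row, rng):
    """Mask one randomly chosen feature; return (masked row, index, target)."""
    idx = rng.randrange(len(row))
    masked = list(row)
    target = masked[idx]
    masked[idx] = MASK
    return masked, idx, target

rng = random.Random(0)
row = [0.5, 1.7, 3.0, 2.2]   # toy row (made up)
masked_row, masked_idx, target = make_mlm_example(row, rng)
# The model sees `masked_row` and is trained to predict `target`
# for the feature at position `masked_idx`.
```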


(2) CL

( follow SAINT )

  • CutMix in the input space
  • Mixup in the embedding space
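
The two augmentations can be sketched on plain lists. This is a simplified illustration: `embed` is a hypothetical stand-in for the learned embedding layer, and the values of `p` and `alpha` are arbitrary choices, not settings taken from the paper.

```python
import random

def cutmix(row, other, p, rng):
    """Input space: replace each feature with `other`'s value with probability p."""
    return [o if rng.random() < p else x for x, o in zip(row, other)]

def mixup(emb, other_emb, alpha):
    """Embedding space: convex combination of two embedding vectors."""
    return [alpha * a + (1 - alpha) * b for a, b in zip(emb, other_emb)]

def embed(row):
    """Hypothetical stand-in for a learned embedding layer."""
    return [2.0 * v for v in row]

rng = random.Random(0)
x = [1.0, 2.0, 3.0, 4.0]            # toy rows (made up)
x_other = [10.0, 20.0, 30.0, 40.0]

x_cut = cutmix(x, x_other, p=0.3, rng=rng)                # corrupt in input space
view = mixup(embed(x_cut), embed(x_other), alpha=0.8)     # mix in embedding space
```

The augmented `view` and the embedding of the original row form a positive pair for the contrastive loss.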


(3) Sup vs. Self-Sup pretraining

figure2

  • CL > FS(From Scratch)
  • MLM < Supervised


5. Pseudo features

See the pseudo-feature method described in the Introduction.
