Retrieval & Fine-Tuning for In-Context Tabular Models

Thomas, Valentin, et al. "Retrieval & fine-tuning for in-context tabular models." NeurIPS (2024)

arxiv: https://arxiv.org/pdf/2406.05207


[Figure 2 from the paper]

Abstract

Tabular data + Transformer-based ICL

$\rightarrow$ Promising results on smaller & less complex datasets

$\rightarrow$ Limitation: Struggled to scale to larger & more complex ones


Proposal: LoCalPFN (locally-calibrated PFN)

Base model = TabPFN

Combination of (1) retrieval & (2) fine-tuning

  • (1) Retrieval
    • Local subset of the data by collecting kNN
  • (2) Fine-tuning (FT)
    • Task-specific FT with this retrieved set of neighbours in context

Experiments

  • Extensive evaluation on 95 datasets curated by TabZilla from OpenML


1. Introduction

Challenges of Tabular DL

  • Diversity & heterogeneity of tabular data
  • Tree-based methods
    • More robust to the inherent challenges of tabular data


Recent works: TabPFN

  • TabPFN = Tabular + ICL

  • Trained using a “prior-fitting” procedure

    $\rightarrow$ Encapsulating the heterogeneity of tabular data

  • Process entirely new datasets in a single forward pass w/o training / tuning


Limitation of TabPFN = Scaling issue

$\rightarrow$ Memory scales quadratically in the size of the context!
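A rough back-of-the-envelope sketch of the issue (assumed numbers: 4 attention heads, float32 activations, a single self-attention layer over the whole context; these are illustrative and not TabPFN's actual configuration):

```python
# Quadratic memory of self-attention over the in-context training set:
# each head materialises an (n_context x n_context) attention map.
def attention_matrix_bytes(n_context: int, n_heads: int = 4, bytes_per_float: int = 4) -> int:
    return n_heads * n_context * n_context * bytes_per_float

for n in (1_000, 10_000, 100_000):
    gib = attention_matrix_bytes(n) / 2**30
    print(f"context={n:>7,d}  ~{gib:8.2f} GiB per layer")
# context=  1,000  ~    0.01 GiB per layer
# context= 10,000  ~    1.49 GiB per layer
# context=100,000  ~  149.01 GiB per layer
```

Using the full training set as context therefore becomes infeasible well before typical large-dataset sizes, which motivates restricting the context to a local subset.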


Proposal = LoCalPFN

  • (1) Retrieval
    • kNN of a given query point as the context for classification (see the retrieval sketch after this list)
  • (2) Fine-tuning (FT)
    • FT end-to-end for each task
    • With an approximate neighbour scheme to facilitate backpropagation
  • Experiments: 95-dataset benchmark from TabZilla
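A minimal sketch of the retrieval step, assuming a generic PFN-style predictor `pfn_predict_proba(X_ctx, y_ctx, X_qy)` (a hypothetical placeholder, not the paper's actual API): each query point is classified using only its k nearest training neighbours as context, rather than the full training set.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

def predict_with_local_context(X_train, y_train, X_query, pfn_predict_proba, k=100):
    # Index the training set once, then look up the k nearest neighbours of each query.
    nn = NearestNeighbors(n_neighbors=k).fit(X_train)
    _, idx = nn.kneighbors(X_query)                      # shape (n_query, k)
    probs = []
    for q, neigh in zip(X_query, idx):
        X_ctx, y_ctx = X_train[neigh], y_train[neigh]    # local context for this query
        probs.append(pfn_predict_proba(X_ctx, y_ctx, q[None, :])[0])
    return np.stack(probs)
```

During fine-tuning the paper uses an approximate scheme in which nearby queries share a neighbour set, so contexts can be batched and the model backpropagated through efficiently; that part is omitted from the sketch.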


2. Improving Tabular ICL with Retrieval and Fine-Tuning

(1) Preliminaries on ICL for Tabular Data & TabPFN

LoCalPFN: Applies to ICL, specifically for classification tasks on tabular data

**TabPFN**

  • Trained using a prior-fitting procedure
    • With a large number of synthetic datasets
  • Trains an underlying transformer-based NN on various generative processes (see the sketch below)
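A hedged sketch of one prior-fitting step, assuming placeholder helpers `sample_synthetic_dataset()` and a PFN-style `model(X_ctx, y_ctx, X_qy) -> logits` (hypothetical names, not TabPFN's actual code): the transformer is trained to predict held-out labels of freshly sampled synthetic tasks given the rest of the task as context, which is exactly the ICL setup it is later used in.

```python
import torch
import torch.nn.functional as F

def prior_fitting_step(model, optimizer, sample_synthetic_dataset, n_context=512):
    X, y = sample_synthetic_dataset()              # one synthetic task drawn from the prior
    X_ctx, y_ctx = X[:n_context], y[:n_context]    # in-context "training" examples
    X_qy, y_qy = X[n_context:], y[n_context:]      # held-out query points
    logits = model(X_ctx, y_ctx, X_qy)             # single forward pass, ICL-style
    loss = F.cross_entropy(logits, y_qy)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```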


Details of TabPFN

[Input] Entire training dataset + test dataset

  • $\mathcal{D}_{\text{train}} \triangleq \left\{(x^{i}_{\text{train}}, y^{i}_{\text{train}})\right\}_{i=1}^{N}$.
    • Feature-label pairs $x^{i}_{\text{train}} \in \mathbb{R}^D$ and $y^{i}_{\text{train}} \in \{1, \ldots, C\}$
    • Query point $x_{\text{qy}}$ (potentially in a batch)

[Output] Distribution over labels $y_{\text{qy}} \in \{1, \ldots, C\}$.


Posterior predictive distribution

  • $p_\theta(y_{\text{qy}} \mid x_{\text{qy}}, \mathcal{D}_{\text{train}}) = \frac{\exp\left(f_\theta(x_{\text{qy}}, \mathcal{D}_{\text{train}})[y_{\text{qy}}]\right)}{\sum_{c=1}^C \exp\left(f_\theta(x_{\text{qy}}, \mathcal{D}_{\text{train}})[c]\right)}$

where $[\cdot]$ denotes the vector indexing operation.
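
A small numeric illustration of this softmax, with made-up logits standing in for $f_\theta(x_{\text{qy}}, \mathcal{D}_{\text{train}})$:

```python
import numpy as np

logits = np.array([2.0, -1.0, 0.5])   # one logit per class c = 1..C
probs = np.exp(logits - logits.max()) # subtract max for numerical stability
probs /= probs.sum()                  # p_theta(y_qy = c | x_qy, D_train)
print(probs)                          # approx. [0.79 0.04 0.18], sums to 1
```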

Contrary to classical machine learning methods, which are trained on one dataset and then evaluated on data from the same distribution, TabPFN has been shown to perform classification on a wide range of tasks without any task-specific training, thanks to its diverse prior-fitting procedure. This makes it one of the rare foundation models for tabular data. Key to this is TabPFN's ICL ability: by using training examples as context, analogous to how language transformers use the preceding tokens as context, TabPFN can classify new query points in a single forward pass.
