TabPFN: A Transformer That Solves Small Tabular Classification Problems in a Second

Contents

Abstract

TabPFN

  • Pretrained Transformer

  • Supervised classification for small tabular datasets
  • Pros:
    • Efficient: Inference in less than a second
    • No hyperparameter tuning required
    • Competitive with SoTA
  • Performs in-context learning (ICL)
  • Input: Set-valued input
  • Output: Predictions for the entire test set in a single forward pass


PFN

  • Prior-Data Fitted Network (PFN)
    • Trained offline once
    • Bayesian inference on synthetic datasets drawn from our prior.
  • This prior incorporates ideas from causal reasoning: It entails a large space of structural causal models with a preference for simple structures.

Dataset

  • 18 datasets in the OpenML-CC18 suite
    • Contain up to 1 000 training data points, up to 100 purely numerical features without missing values, and up to 10 classes
  • Results: TabPFN clearly outperforms boosted trees and performs on par with complex state-of-the-art AutoML systems, with up to a 230× speedup (5 700× when using a GPU)
  • These results are also validated on an additional 67 small numerical datasets from OpenML
  • All code, the trained TabPFN, an interactive browser demo and a Colab notebook are provided


1. Introduction

P1) Tabular Data

  • Overlooked by DL

  • Still dominated by GBDT

    ( $\because$ Short training time and robustness )


P2) Tabular Classification

  • Radical change to how tabular classification is done
  • Use a PRETRAINED Transformer
    • Pretraining task: Synthetic dataset classification tasks from a tabular dataset prior


P3) PFN

  • Builds on Prior-Data Fitted Networks (PFNs)
  • Learn the training and prediction algorithm itself
  • Approximate Bayesian inference given any prior one can sample from
  • Approximate the posterior predictive distribution (PPD) directly
  • Can simply design a dataset-generating algorithm that encodes the desired prior.


P4) Prior

Design a prior based on …

  • Bayesian Neural Networks (BNNs)
    • to model complex feature dependencies
  • Structural Causal Models (SCMs)
    • to model potential causal mechanisms underlying tabular data


Prior is defined via parametric distributions

  • e.g., a log-scaled uniform distribution for the average number of nodes in data-generating SCMs.

$\rightarrow$ Resulting PPD: Implicitly models uncertainty over all possible data-generating mechanisms

$\rightarrow$ Corresponds to an infinitely large ensemble of data-generating mechanisms

  • i.e., instantiations of SCMs and BNNs.


Result:

  • Learn to approximate this complex PPD in a single forward-pass
  • Requires no cross-validation or model selection


P5) Key Contribution

TabPFN

  • Single Transformer that has been pre-trained to approximate probabilistic inference for the novel prior above in a single forward pass
  • Learned to solve novel small tabular classification tasks (≤ 1 000 training examples, ≤ 100 purely numerical features without missing values and ≤ 10 classes) in less than a second
  • Yielding SoTA performance.


2. Background on PFNs

(1) PPD for Supervised Learning (SL)

Bayesian framework for SL

  • Prior: defines a space of hypotheses $\Phi$ on the relationship of a set of inputs $x$ to the output labels $y$.
  • Each hypothesis $\phi \in \Phi$ can be seen as a mechanism that generates a data distribution from which we can draw samples forming a dataset.
  • For example, given a prior based on structural causal models, $\Phi$ is the space of structural causal models, a hypothesis $\phi$ is one specific SCM, and a dataset comprises samples generated through this SCM.
  • In practice, a dataset comprises training data with observed labels and test data where labels are missing or held out to assess predictive performance.
  • The PPD for a test sample $x_{\text{test}}$ specifies the distribution of its label $p\left(\cdot \mid x_{\text{test}}, D_{\text{train}}\right)$, which is conditioned on the set of training samples $D_{\text{train}} := \left\{\left(x_1, y_1\right), \ldots, \left(x_n, y_n\right)\right\}$.
  • The PPD can be obtained by integration over the space of hypotheses $\Phi$, where the weight of a hypothesis $\phi \in \Phi$ is determined by its prior probability $p(\phi)$ and the likelihood $p(D \mid \phi)$ of the data $D$ given $\phi$ :
\[p(y \mid x, D) \propto \int_{\Phi} p(y \mid x, \phi) p(D \mid \phi) p(\phi) d \phi\]
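
The integral can be made concrete with a tiny Monte Carlo sketch: sample hypotheses $\phi$ from the prior, weight each by its data likelihood, and average its predictions. The hypothesis space below (1-D logistic mechanisms with a Gaussian prior over a single weight) is purely illustrative and not the paper's prior.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy hypothesis space: y ~ Bernoulli(sigmoid(w * x)) with a standard-normal
# prior over the single weight w.  Illustrative only, not the paper's prior.
def likelihood(w, X, y):
    p = 1.0 / (1.0 + np.exp(-w * X))
    return np.prod(np.where(y == 1, p, 1.0 - p))

def ppd(x_test, X_train, y_train, n_samples=10_000):
    # p(y=1 | x_test, D) ∝ ∫ p(y=1 | x_test, w) p(D | w) p(w) dw,
    # approximated by averaging over prior samples w ~ p(w).
    ws = rng.normal(size=n_samples)                                    # w ~ p(w)
    weights = np.array([likelihood(w, X_train, y_train) for w in ws])  # p(D | w)
    preds = 1.0 / (1.0 + np.exp(-ws * x_test))                         # p(y=1 | x_test, w)
    return np.sum(weights * preds) / np.sum(weights)

X_train = np.array([-2.0, -1.0, 1.0, 2.0])
y_train = np.array([0, 0, 1, 1])
print(ppd(0.5, X_train, y_train))  # approximate P(y = 1 | x = 0.5, D_train)
```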


(2) Synthetic Prior-fitting

Prior-fitting

  • Training of a PFN
  • To approximate the PPD & perform Bayesian prediction


How?

  • Implement it with a prior which is specified by a prior sampling scheme of the form $p(D)=\mathbb{E}_{\phi \sim p(\phi)}[p(D \mid \phi)]$,
  • Step 1) Samples hypotheses with $\phi \sim p(\phi)$
  • Step 2) Samples synthetic datasets with $D \sim p(D \mid \phi)$.
    • Repeatedly sample $D := \left(x_i, y_i\right)_{i \in \{1, \ldots, n\}}$
  • Step 3) Optimize the PFN’s parameters $\theta$
    • To make predictions for $D_{\text {test }} \subset D$
    • Conditioned on the rest of the dataset $D_{\text {train }}=D \backslash D_{\text {test }}$.


Loss function

\[\mathcal{L}_{PFN} = \mathbb{E}_{\left(\{(x_{\text{test}},\, y_{\text{test}})\} \cup D_{\text{train}}\right) \sim p(D)}\left[-\log q_\theta\left(y_{\text{test}} \mid x_{\text{test}}, D_{\text{train}}\right)\right]\]

  • Cross-entropy on held-out examples of synthetic datasets.

$\rightarrow$ Minimizing this loss approximates the true Bayesian posterior predictive distribution.
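
A schematic, runnable sketch of this prior-fitting loop, using stand-ins (`sample_hypothesis`, `sample_dataset`, `TinyPFN`) in place of the paper's tabular prior and 12-layer Transformer; only the training logic (sample $\phi$, sample $D$, split, minimize cross-entropy on the held-out part) mirrors the description above.

```python
import torch
from torch import nn

# Stand-in prior: each hypothesis phi is a random linear "mechanism" in 2D;
# a dataset is drawn from it and labeled by the sign of the mechanism output.
def sample_hypothesis(gen):
    return torch.randn(2, generator=gen)                  # phi ~ p(phi)

def sample_dataset(phi, n, gen):
    X = torch.randn(n, 2, generator=gen)                  # D ~ p(D | phi)
    y = (X @ phi > 0).long()
    return X, y

# Stand-in for q_theta: a tiny set-conditioned classifier (NOT the 12-layer
# Transformer of the paper) that sees per-class means of the training points.
class TinyPFN(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(6, 64), nn.ReLU(), nn.Linear(64, 2))

    def forward(self, X_train, y_train, X_test):
        ctx = torch.cat([
            X_train[y_train == 0].mean(0) if (y_train == 0).any() else torch.zeros(2),
            X_train[y_train == 1].mean(0) if (y_train == 1).any() else torch.zeros(2),
        ]).expand(len(X_test), 4)
        return self.net(torch.cat([X_test, ctx], dim=1))   # logits over classes

gen = torch.Generator().manual_seed(0)
model = TinyPFN()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

for step in range(1_000):                                  # prior-fitting: offline, done once
    phi = sample_hypothesis(gen)                           # step 1: phi ~ p(phi)
    X, y = sample_dataset(phi, n=64, gen=gen)              # step 2: D ~ p(D | phi)
    X_tr, y_tr, X_te, y_te = X[:48], y[:48], X[48:], y[48:]
    logits = model(X_tr, y_tr, X_te)                       # step 3: predict held-out part
    loss = nn.functional.cross_entropy(logits, y_te)       # L_PFN: CE on held-out samples
    opt.zero_grad()
    loss.backward()
    opt.step()
```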


(3) Real World Inference

Pretrained model

  • Inference on unseen real-world datasets
  • Input: $\left\langle D_{\text {train }}, x_{\text {test }}\right\rangle$
  • Output: PPD $q_\theta\left(y \mid x_{\text {test }}, D_{\text {train }}\right)$ in a single forward-pass.

$\therefore$ Perform training and prediction in one step

  • (similar to prediction with Gaussian Processes)
  • Do not use gradient-based learning on data seen at inference time.
  • Can be termed in-context learning (ICL)
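
In practice this single-step training-and-prediction is exposed through a scikit-learn-style wrapper in the released `tabpfn` package; a usage sketch, assuming the package's `TabPFNClassifier` class with `fit`/`predict_proba`:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from tabpfn import TabPFNClassifier   # assumed interface of the released package

X, y = load_breast_cancer(return_X_y=True)          # 569 samples, 30 numerical features
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf = TabPFNClassifier()        # no hyperparameter tuning
clf.fit(X_train, y_train)       # "fit" only stores D_train; no gradient-based learning
proba = clf.predict_proba(X_test)   # PPD estimates from forward passes over <D_train, x_test>
print(proba.shape)              # (n_test, n_classes)
```

Note that all computation happens in the forward passes triggered by `predict_proba`; `fit` merely stores the training set as context.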


(4) Architecture

Transformer

  • Encodes each feature vector and label as a token
  • Input:
    • Accept a variable length training set: $D_{\text {train }}$
      • of feature and label vectors
    • Accept a variable length query set: $x_{\text{test}} = \left\{x_{(\text{test},1)}, \ldots, x_{(\text{test},m)}\right\}$
      • of feature vectors
  • Output:
    • Estimates of the PPD for each query.
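
A small sketch of the attention pattern implied by this input/output structure (my reading of the PFN setup, not the released code): training tokens attend to all training tokens, while each test token attends only to the training tokens and itself, so every query is answered independently in one forward pass.

```python
import torch

def pfn_attention_mask(n_train: int, n_test: int) -> torch.Tensor:
    """Boolean mask (True = may attend) for a sequence of n_train training tokens
    followed by n_test test tokens: training tokens attend to all training tokens,
    test tokens attend only to the training tokens and to themselves."""
    n = n_train + n_test
    mask = torch.zeros(n, n, dtype=torch.bool)
    mask[:, :n_train] = True                       # every token sees the training set
    test_idx = torch.arange(n_train, n)
    mask[test_idx, test_idx] = True                # each test token also sees itself
    return mask

print(pfn_attention_mask(3, 2).int())
```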


3. The TabPFN: A PFN Fitted on a New Prior for Tabular Data

P1) TabPFN

  • Prior-data Fitted Network

  • Fitted on data sampled from a novel prior for tabular data

    • we introduce in Section 4.
  • Modify the original PFN architecture

    • (1) Slight modifications to the attention masks

      $\rightarrow$ Shorter inference time

    • (2) Enable our model to work on datasets with different numbers of features by zero-padding
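
A minimal sketch of the zero-padding idea in (2), assuming a fixed width of 100 features (the released code may additionally rescale the padded features; this only illustrates the padding itself):

```python
import numpy as np

MAX_FEATURES = 100   # fixed feature width the TabPFN was trained with

def pad_features(X: np.ndarray, max_features: int = MAX_FEATURES) -> np.ndarray:
    """Zero-pad a dataset with k <= max_features columns up to the fixed width,
    so one pretrained network handles varying numbers of features."""
    n, k = X.shape
    if k > max_features:
        raise ValueError(f"dataset has {k} > {max_features} features")
    return np.concatenate([X, np.zeros((n, max_features - k))], axis=1)

X = np.random.default_rng(0).normal(size=(5, 7))
print(pad_features(X).shape)   # (5, 100)
```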


P2) Training TabPFN

Prior-fitting phase

  • Train once on samples from the prior
    • described in Section 4.
  • Details:
    • 12-layer Transformer for 18 000 batches of 512 synthetically generated datasets each
    • 20 hours on one machine with 8 GPUs (Nvidia RTX 2080 Ti)
  • Single network that is used for all our evaluations
  • Training step is moderately expensive, but is done offline


P3) Inference with TabPFN

  • Approximates the PPD for our dataset prior
    • i.e., it approximates the marginal predictions across our spaces of SCMs and BNNs (see Section 4), including a bias towards simple and causal explanations for the data.
  • Experiments: Predictions for …
    • (1) Single forward pass of our TabPFN
    • (2) Ensemble of 32 forward passes, each on a dataset modified by a power transformation (applied with probability 0.5) and by a rotation of the indices of feature columns and class labels (see the sketch below)
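
A hedged sketch of this ensembling scheme; `predict_proba` is a hypothetical stand-in for one TabPFN forward pass, and the exact transformations in the released code may differ in detail:

```python
import numpy as np
from sklearn.preprocessing import PowerTransformer

def ensemble_predict(predict_proba, X_train, y_train, X_test, n_members=32, seed=0):
    """Average PPDs over n_members forward passes, each on an input modified by a
    power transform (applied with probability 0.5) and a rotation of feature-column
    and class-label indices.  `predict_proba(X_tr, y_tr, X_te)` stands in for one
    TabPFN forward pass returning probabilities of shape (n_test, n_classes)."""
    rng = np.random.default_rng(seed)
    n_classes = len(np.unique(y_train))
    n_features = X_train.shape[1]
    avg = np.zeros((len(X_test), n_classes))

    for _ in range(n_members):
        X_tr, X_te = X_train.copy(), X_test.copy()
        if rng.random() < 0.5:                                     # power transformation
            pt = PowerTransformer().fit(X_tr)
            X_tr, X_te = pt.transform(X_tr), pt.transform(X_te)
        f_shift = int(rng.integers(n_features))                    # rotate feature columns
        X_tr, X_te = np.roll(X_tr, f_shift, axis=1), np.roll(X_te, f_shift, axis=1)
        c_shift = int(rng.integers(n_classes))                     # rotate class labels
        proba = predict_proba(X_tr, (y_train + c_shift) % n_classes, X_te)
        avg += np.roll(proba, -c_shift, axis=1)                    # undo the label rotation
    return avg / n_members

# Dummy forward pass so the sketch runs end-to-end: uniform class probabilities.
dummy = lambda X_tr, y_tr, X_te: np.full((len(X_te), len(np.unique(y_tr))), 1 / len(np.unique(y_tr)))
X = np.random.default_rng(1).normal(size=(20, 4)); y = np.arange(20) % 3
print(ensemble_predict(dummy, X[:15], y[:15], X[15:]).shape)   # (5, 3)
```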


4. A Prior for Tabular Data

Performance depends on the specification of a suitable prior!

( $\because$ PFN approximates the PPD for this prior )


Section intro

  • [4.1] Fundamental technique for our prior

    • Use distributions instead of point-estimates for almost all of our prior’s hyperparameters.
  • [4.2] Simplicity in our prior

  • [4.3 & 4.4] How we use SCMs & BNNs as fundamental mechanisms to generate diverse data in our prior.

  • [4.5] Into classification tasks

    • SCM and BNN priors: only yield regression tasks

      $\rightarrow$ Show how to convert them to classification tasks
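
Section 4.5 itself is not summarized in detail here; as a rough illustration of the regression-to-classification conversion it describes, one can cut the continuous SCM/BNN target at randomly drawn thresholds to obtain class labels (an illustrative sketch, not the paper's exact procedure):

```python
import numpy as np

def to_classification(y_scalar: np.ndarray, n_classes: int, rng) -> np.ndarray:
    """Turn a continuous target from an SCM/BNN sample into class labels by cutting
    its range at randomly drawn quantiles (illustrative conversion only)."""
    qs = np.sort(rng.uniform(size=n_classes - 1))      # random interval bounds
    thresholds = np.quantile(y_scalar, qs)
    return np.digitize(y_scalar, thresholds)           # labels in {0, ..., n_classes - 1}

rng = np.random.default_rng(0)
y = rng.normal(size=1_000)                             # scalar "regression" target
print(np.bincount(to_classification(y, n_classes=4, rng=rng)))
```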


(1) Fundamentally Probabilistic Models

Fitting a model typically requires finding suitable hyperparameters, e.g., the embedding size, number of layers and activation function for NNs. Commonly, resource-intensive searches need to be employed to find suitable hyperparameters (Zoph and Le, 2017; Feurer and Hutter, 2019). The result of these searches, though, is only a point estimate of the hyperparameter choice. Ensembling over multiple architectures and hyperparameter settings can yield a rough approximation to a distribution over these hyperparameters and has been shown to improve performance (Zaidi et al., 2021; Wenzel et al., 2020). This, however, scales linearly in cost with the number of choices considered.

In contrast, PFNs allow us to be fully Bayesian about our prior’s hyperparameters. By defining a probability distribution over the space of hyperparameters in the prior, such as BNN architectures, the PPD approximated by our TabPFN jointly integrates over this space and the respective model weights. We extend this approach to a mixture not only over hyperparameters but distinct priors: we mix a BNN and an SCM prior, each of which again entails a mixture of architectures and hyperparameters.
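
A toy sketch of this "distributions instead of point estimates" idea: every synthetic dataset in the prior draws its own hyperparameter configuration, e.g. a log-scaled uniform number of SCM nodes and a coin flip between the BNN and SCM prior. Names and ranges below are illustrative, not the paper's exact choices.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_prior_hyperparameters():
    """Draw one hyperparameter configuration per synthetic dataset; every value is
    sampled from a distribution rather than fixed.  Names and ranges are illustrative."""
    return {
        "num_nodes": int(np.exp(rng.uniform(np.log(5), np.log(100)))),       # log-scaled uniform
        "noise_std": float(np.exp(rng.uniform(np.log(1e-3), np.log(1.0)))),  # log-scaled uniform
        "activation": str(rng.choice(["tanh", "relu", "identity"])),
        "use_bnn_prior": bool(rng.random() < 0.5),    # mixture over distinct priors (BNN vs. SCM)
    }

print(sample_prior_hyperparameters())
```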


(2) Simplicity

We base our priors on a notion of simplicity, such as stated by Occam’s Razor or the Speed Prior (Schmidhuber, 2002). When considering competing hypotheses, the simpler one is to be preferred. Work in cognitive science has also uncovered this preference for simple explanations in human thinking (Wojtowicz and DeDeo, 2020). Any notion of simplicity, however, depends on choosing a particular criterion that defines simplicity. In the following, we introduce priors based on SCMs and BNNs, in which we implement simplicity as graphs with few nodes and parameters.


(3) SCM Prior

It has been demonstrated that causal knowledge can facilitate various ML tasks, including semisupervised learning, transfer learning and out-of-distribution generalization (Schölkopf et al., 2012; Janzing, 2020; Rothenhäusler et al., 2018). Tabular data often exhibits causal relationships between columns, and causal mechanisms have been shown to be a strong prior in human reasoning (Waldmann and Hagmayer, 2013; Wojtowicz and DeDeo, 2020). Thus, we base our TabPFN prior on SCMs that model causal relationships (Pearl, 2009; Peters et al., 2017). An SCM consists of a collection $Z := \{z_1, \ldots, z_k\}$ of structural assignments (called mechanisms): $z_i = f_i\left(z_{\mathrm{PA}_{\mathcal{G}}(i)}, \epsilon_i\right)$, where $\mathrm{PA}_{\mathcal{G}}(i)$ is the set of parents of the node $i$ (its direct causes) in an underlying DAG $\mathcal{G}$ (the causal graph), $f_i$ is a (potentially non-linear) deterministic function and $\epsilon_i$ is a noise variable. Causal relationships in $\mathcal{G}$ are represented by directed edges pointing from causes to effects and each mechanism $z_i$ is assigned to a node in $\mathcal{G}$, as visualized in Figure 2.

a) Predictions based on ideas from causal reasoning

Previous works have applied causal reasoning to predict observations on unseen data by using causal inference, a method which seeks to identify causal relations between the components of a system by the use of interventions and observational data (Pearl, 2010; Pearl and Mackenzie, 2018; Lin et al., 2021). The predicted causal representations are then used to make observational predictions on novel samples or to provide explainability. Most existing work focuses on determining a single causal graph to use for downstream prediction, which can be problematic since most kinds of SCMs are non-identifiable without interventional data, and the number of compatible DAGs explodes due to the combinatorial nature of the space of DAGs. Recently, Ke et al. (2022) and Lorch et al. (2022) used transformers to approximate causal graphs from observational and interventional data. We skip any explicit graph representation in our inference step and approximate the PPD directly. Thus, we do not perform causal inference but solve the downstream prediction task directly. This implicit assumption of SCM-like processes generating our data can be explained in Pearl's "ladder of causation", an abstraction of inference categories, where each higher rung represents a more involved notion of inference (Pearl and Mackenzie, 2018). At the lowest rung lies association, which includes most of ML. The second rung considers predicting the effect of interventions, i.e., what happens when we influence features directly. Our work can be considered as "rung 1.5", similar to Kyono et al. (2020; 2021): we do not perform causal reasoning, but make association-based predictions on observational data, assuming SCMs model common datasets well. In Figure 8 in Appendix B, we experimentally show that our predictions indeed align with simple SCM hypotheses.

b) Defining a prior based on causal models

To create a PFN prior based on SCMs, we have to define a sampling procedure that creates supervised learning tasks (i.e., datasets). Here, each dataset is based on one randomly-sampled SCM (including the DAG structure and deterministic functions $f_i$). Given an SCM, we sample a set $z_X$ of nodes in the causal graph $\mathcal{G}$, one for each feature in our synthetic dataset, as well as one node $z_y$ from $\mathcal{G}$. These nodes are observed nodes: values of $z_X$ will be included in the set of features, while values from $z_y$ will act as targets. For each such SCM and list of nodes $z_X$ and $z_y$, $n$ samples are generated by sampling all noise variables in the SCM $n$ times, propagating these through the graph and retrieving the values at the nodes $z_X$ and $z_y$ for all $n$ samples. Figure 2b depicts an SCM with observed feature- and target-nodes in grey. The resulting features and targets are correlated through the generating DAG structure. This leads to features that are conditionally dependent through forward and backward causation, i.e., targets might be a cause or an effect of features. In Figure 3, we compare samples generated by two distinct SCMs to actual datasets, demonstrating the diversity in the space of datasets our prior can model.

In this work, we instantiate a large subfamily of DAGs and deterministic functions $f_i$ to build SCMs described in Appendix C.1. Since efficient sampling is the only requirement we have, the instantiated subfamily is very general, including multiple activation functions and noise distributions.
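
A compact, illustrative generator in the spirit of (b): draw a random DAG, random linear-plus-tanh mechanisms $f_i$, propagate sampled noise through the graph, and read off randomly chosen feature nodes $z_X$ and a target node $z_y$. This is a sketch of the sampling procedure, not the instantiation described in Appendix C.1.

```python
import numpy as np

def sample_scm_dataset(n_samples=100, n_nodes=8, n_features=4, seed=0):
    """Sample one synthetic dataset from a random SCM: a random DAG (strictly
    lower-triangular adjacency over topologically ordered nodes), random
    linear-plus-tanh mechanisms z_i = f_i(parents, eps_i), and randomly chosen
    observed feature nodes z_X plus one target node z_y."""
    rng = np.random.default_rng(seed)
    adj = np.tril(rng.random((n_nodes, n_nodes)) < 0.4, k=-1)     # edge j -> i only if j < i
    weights = rng.normal(size=(n_nodes, n_nodes)) * adj
    Z = np.zeros((n_samples, n_nodes))
    for i in range(n_nodes):                                      # propagate in topological order
        eps = rng.normal(scale=0.3, size=n_samples)               # noise variable eps_i
        Z[:, i] = np.tanh(Z @ weights[i] + eps)                   # mechanism f_i
    observed = rng.choice(n_nodes, size=n_features + 1, replace=False)
    z_X, z_y = observed[:-1], observed[-1]
    return Z[:, z_X], Z[:, z_y]                                   # features and continuous target

X, y = sample_scm_dataset()
print(X.shape, y.shape)   # (100, 4) (100,)
```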
