
Paper | Documentation | RealMLP-TD-S standalone implementation | Grinsztajn et al. benchmark code | Data archive
PyTabKit provides scikit-learn interfaces for the modern tabular classification and regression methods benchmarked in our paper (see below). It also contains the code we used to benchmark these methods on our benchmarks.
For the best results with `pytabkit`, we recommend using

```python
Ensemble_HPO_Classifier(n_cv=8, use_full_caruana_ensembling=True, use_tabarena_spaces=True, n_hpo_steps=50)
```

with a `val_metric_name` corresponding to your target metric (e.g., `class_error`, `cross_entropy`, `brier`, `1-auc_ovr`), or the corresponding Regressor. (This might take very long to fit.) Alternatively, you can use

```python
RealMLP_HPO_Classifier(n_cv=8, hpo_space_name='tabarena', use_caruana_ensembling=True, n_hyperopt_steps=50)
```

also with `val_metric_name` as above, or the corresponding Regressor. `pytabkit` is currently still easier to use and supports vectorized cross-validation for RealMLP, which can significantly speed up the training.

You can install pytabkit via

```bash
pip install pytabkit[models]
```
The ML models require the `[models]` part. You can use `pytabkit[models,autogluon,extra,hpo,bench,dev]` to install additional dependencies for AutoGluon models, extra preprocessing, hyperparameter optimization methods beyond random search (hyperopt/SMAC), the benchmarking part, and testing/documentation. For the `hpo` part, you might need to install swig (e.g., via pip) if the build of pyrfr fails. See also the documentation.
To run the data download for the meta-train benchmark, you need one of rar, unrar, or 7-zip to be installed on the system.

Most of our machine learning models are directly available via scikit-learn interfaces. For example, you can use RealMLP-TD for classification as follows:
```python
from pytabkit import RealMLP_TD_Classifier

model = RealMLP_TD_Classifier()  # or TabR_S_D_Classifier, CatBoost_TD_Classifier, etc.
model.fit(X_train, y_train)
model.predict(X_test)
```
The code above will automatically select a GPU if available, try to detect categorical columns in dataframes, preprocess numerical variables and regression targets (no standardization required), and use a training-validation split for early stopping. All of this (and much more) can be configured through the constructor and the parameters of the `fit()` method. For example, it is possible to do bagging (ensembling of models trained on 5-fold cross-validation) simply by passing `n_cv=5` to the constructor.

Here is an example showing some of the parameters that can be set explicitly:
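The bagging described above can be illustrated conceptually: train one model per cross-validation training split and average their predictions. This is only a sketch of the idea, not pytabkit's internal implementation; `freq_model` is a toy stand-in for a real estimator.

```python
import numpy as np

def bagged_predict_proba(X_train, y_train, X_test, fit_predict, n_cv=5, seed=0):
    """Average the test predictions of models trained on the n_cv CV training splits.

    fit_predict(X_tr, y_tr, X_te) -> probabilities of shape (len(X_te), n_classes)
    """
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(X_train))
    folds = np.array_split(idx, n_cv)
    probs = []
    for k in range(n_cv):
        # Train on everything except fold k, predict on the test set.
        train_idx = np.concatenate([folds[j] for j in range(n_cv) if j != k])
        probs.append(fit_predict(X_train[train_idx], y_train[train_idx], X_test))
    return np.mean(probs, axis=0)

# Toy "model": predicts the class frequencies of its training set.
def freq_model(X_tr, y_tr, X_te):
    freqs = np.bincount(y_tr, minlength=2) / len(y_tr)
    return np.tile(freqs, (len(X_te), 1))

X = np.zeros((20, 3))
y = np.array([0, 1] * 10)
proba = bagged_predict_proba(X, y, X[:4], freq_model)
print(proba.shape)  # (4, 2)
```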
```python
from pytabkit import RealMLP_TD_Classifier

model = RealMLP_TD_Classifier(device='cpu', random_state=0, n_cv=1, n_refit=0,
                              n_epochs=256, batch_size=256, hidden_sizes=[256] * 3,
                              val_metric_name='cross_entropy',
                              use_ls=False,  # for metrics like AUC / log-loss
                              lr=0.04, verbosity=2)
model.fit(X_train, y_train, X_val, y_val, cat_col_names=['Education'])
model.predict_proba(X_test)
```
See this notebook for more examples. Missing numerical values are currently not allowed and need to be imputed beforehand.
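Since missing numerical values must be imputed before calling `fit()`, here is a minimal median-imputation sketch in plain NumPy (any imputer, e.g. scikit-learn's `SimpleImputer`, works just as well):

```python
import numpy as np

# Replace NaNs in each numeric column by that column's median before fitting.
X = np.array([[1.0, np.nan],
              [2.0, 4.0],
              [np.nan, 6.0]])
col_medians = np.nanmedian(X, axis=0)            # per-column median, ignoring NaNs
X_imputed = np.where(np.isnan(X), col_medians, X)
print(X_imputed)  # [[1.  5. ] [2.  4. ] [1.5 6. ]]
```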
Our ML models are available in up to three variants, all with best-epoch selection:
We provide the following ML models:
To use post-hoc temperature scaling and refinement stopping from our paper *Rethinking Early Stopping: Refine, Then Calibrate*, you can pass the following parameters to the scikit-learn interfaces:

```python
from pytabkit import RealMLP_TD_Classifier

clf = RealMLP_TD_Classifier(
    val_metric_name='ref-ll-ts',   # short for 'refinement_logloss_ts-mix_all'
    calibration_method='ts-mix',   # temperature scaling with Laplace smoothing
    use_ls=False,                  # recommended for cross-entropy loss
)
```
Other calibration methods and validation metrics from probmetrics can be used as well.
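As a concept illustration, plain temperature scaling fits a single temperature `T` on validation logits to minimize log-loss. This is a simplified sketch only; the `ts-mix` method used above additionally applies Laplace smoothing, and probmetrics optimizes differently.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def fit_temperature(logits, labels, grid=np.geomspace(0.1, 10.0, 200)):
    """Pick the temperature T minimizing validation log-loss of softmax(logits / T)."""
    def nll(T):
        p = softmax(logits / T)
        return -np.mean(np.log(p[np.arange(len(labels)), labels] + 1e-12))
    return min(grid, key=nll)

# Overconfident logits: two of the four confident predictions are wrong,
# so scaling the logits down (T > 1) lowers the validation log-loss.
logits = np.array([[4.0, 0.0], [0.0, 4.0], [3.0, 0.0], [0.0, 2.0]])
labels = np.array([0, 1, 1, 0])
T = fit_temperature(logits, labels)
print(T > 1.0)  # True
```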
For reproducing the results from this paper, we refer to the documentation.
Our benchmarking code has functionality for
For more details, we refer to the documentation.
While many preprocessing methods are implemented in this repository, a standalone version of our robust scaling + smooth clipping can be found here.
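The linked standalone version implements the exact method; as a rough illustration of the idea only, the sketch below does quantile-based scaling followed by a smooth, saturating clip. The constants and exact formula here are illustrative assumptions, not necessarily the paper's.

```python
import numpy as np

def robust_scale_smooth_clip(x, eps=1e-30):
    """Sketch: center at the median, scale by the interquartile range,
    then clip smoothly so outliers saturate instead of dominating.
    (Illustrative only; see the standalone implementation for the exact variant.)"""
    med = np.median(x, axis=0)
    q1, q3 = np.percentile(x, [25, 75], axis=0)
    scale = np.maximum(q3 - q1, eps)      # avoid division by zero for constant columns
    z = (x - med) / scale
    return z / np.sqrt(1.0 + (z / 3.0) ** 2)  # smooth clip to roughly (-3, 3)

x = np.array([[0.0], [1.0], [2.0], [3.0], [1000.0]])  # one extreme outlier
out = robust_scale_smooth_clip(x)
print(np.abs(out).max() < 3.0)  # True: the outlier is bounded
```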
If you use this repository for research purposes, please cite our paper:
```bibtex
@inproceedings{holzmuller2024better,
  title={Better by default: {S}trong pre-tuned {MLPs} and boosted trees on tabular data},
  author={Holzm{\"u}ller, David and Grinsztajn, Leo and Steinwart, Ingo},
  booktitle={Neural {Information} {Processing} {Systems}},
  year={2024}
}
```
Code from other repositories is acknowledged in code comments wherever possible. In particular, we used code from https://github.com/yandex-research/rtdl and its sub-packages (Apache 2.0 license), from https://github.com/catboost/benchmarks/ (Apache 2.0 license), and from https://docs.ray.io/en/latest/cluster/vms/user-guides/community/slurm.html (Apache 2.0 license).
Recent changes:

- Added the `n_repeats` parameter to the scikit-learn interfaces for repeated cross-validation.
- Added `use_caruana_ensembling=True`. Removed the `RealMLP_Ensemble_Classifier` and `RealMLP_Ensemble_Regressor` from v1.4.2, since they are now redundant through this feature.
- Renamed the `space` parameter of the GBDT HPO interface to `hpo_space_name`, so it now also works with non-TPE versions.
- Added `val_metric_name` for early stopping on different metrics.
- Fixed `val_metric_name` handling in HPO models and `Ensemble_TD_Regressor`.
- When a `tmp_folder` is specified in HPO models, each model is now saved to disk immediately instead of holding all of them in memory. This can considerably reduce RAM/VRAM usage. In this case, pickled HPO models will still rely on the models stored in the `tmp_folder`.
- Added `RealMLP_Ensemble_Classifier` and `RealMLP_Ensemble_Regressor`, which use weighted ensembling and usually perform better than HPO (but have slower inference). We recommend using the new `hpo_space_name='tabarena'` for best results.
- Moved some dependencies into the `models` optional dependencies, allowing a more lightweight RealMLP installation.
- Targets of shape `(n_samples, 1)` are now supported.
- Added `weight_decay`, `tfms`, and `gradient_clipping_norm` to TabM. The updated default parameters now apply the RTDL quantile transform.
- Replaced `__` by `_` in parameter names for MLP, MLP-PLR, ResNet, and FTT, to comply with scikit-learn interface requirements.
- Switched to `lightning` (while still allowing `pytorch-lightning`), made skorch a lazy import, and removed the msgpack_numpy dependency.