ntrees-tuning

The ntrees_tuning package is an extension to sklearn. It adds an ntrees parameter to Random Forests and Gradient Boosting that controls how many trees are used for prediction. The main benefit is that ntrees can be tuned with respect to the OOB error without retraining a new model for each value of ntrees.

Version 0.1.1 on PyPI (pip), 1 maintainer

The package introduces subclasses of the sklearn Random Forest and Gradient Boosting classes (RandomForestClassifier, RandomForestRegressor, GradientBoostingClassifier, GradientBoostingRegressor). Each subclass adds two new methods, predict_ntrees and tune_ntrees, which make it possible to predict with a chosen number of trees and to tune the ntrees parameter.
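To illustrate what predicting with only some of the trees means, here is a minimal sketch that uses plain sklearn only. It is not the package's implementation: the helper predict_first_n_trees is a hypothetical name, and the sketch assumes that "the first n trees" refers to the first n entries of the fitted forest's estimators_ list.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

X, y = make_regression(n_samples=200, n_features=20, random_state=42)
rf = RandomForestRegressor(n_estimators=100, random_state=0).fit(X, y)


def predict_first_n_trees(forest, X, n):
    """Hypothetical helper: average the predictions of the first n fitted trees."""
    if not 1 <= n <= len(forest.estimators_):
        raise ValueError("n must be between 1 and the number of fitted trees")
    # A regression forest predicts the mean of its trees' predictions,
    # so a truncated forest is the mean over a slice of estimators_.
    per_tree = [tree.predict(X) for tree in forest.estimators_[:n]]
    return np.mean(per_tree, axis=0)


pred_10 = predict_first_n_trees(rf, X, 10)    # uses 10 of the 100 trees
pred_all = predict_first_n_trees(rf, X, 100)  # uses every tree
```

With all 100 trees the helper should agree with rf.predict(X) up to floating-point rounding, which is a quick sanity check that the truncation logic is right.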

Example usage:

1. Create data:

from sklearn.datasets import make_classification, make_regression
Xcls, ycls = make_classification(n_samples=200, n_features=20, n_classes=3, random_state=42, n_clusters_per_class=3, n_informative=5)
Xreg, yreg = make_regression(n_samples=200, n_features=20, random_state=42)

2. Create and fit Random Forest and Gradient Boosting models for regression and classification

New custom classes are introduced for tuning the ntrees parameter. They are direct descendants of the sklearn classes (RandomForestClassifier, RandomForestRegressor, GradientBoostingClassifier, GradientBoostingRegressor).

import ntree_tuning as ntt

rf_cls = ntt.Ntree_RandForest_Classifier(n_estimators=100)
rf_cls.fit(Xcls, ycls)

rf_reg = ntt.Ntree_RandForest_Regressor(n_estimators=100)
rf_reg.fit(Xreg, yreg)

gb_cls = ntt.Ntree_GradBoost_Classifier(n_estimators=100, subsample=0.8)
gb_cls.fit(Xcls, ycls)

gb_reg = ntt.Ntree_GradBoost_Regressor(n_estimators=100, subsample=0.8)
gb_reg.fit(Xreg, yreg)

3. Tune ntrees

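You can then call the tune_ntrees method, which returns a dictionary that maps each ntrees value to its OOB error. For Gradient Boosting it is called without arguments; for Random Forests it takes the data together with the minimum, maximum, and step size of ntrees.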

# Gradient Boosting
print(gb_reg.tune_ntrees())
print(gb_cls.tune_ntrees())


# Random Forests
min_trees = 20
max_trees = 80
delta_trees = 5

print(rf_reg.tune_ntrees(Xreg, yreg, min_trees, max_trees, delta_trees))
print(rf_cls.tune_ntrees(Xcls, ycls, min_trees, max_trees, delta_trees))
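For comparison, OOB-based curves over the number of trees can also be sketched with plain sklearn. The sketch below is not the package's method. It assumes that for Gradient Boosting the cumulative sum of oob_improvement_ (only available when subsample < 1.0) is an acceptable proxy for the OOB loss, and that for Random Forests the OOB error is 1 - oob_score_, collected by growing one forest with warm_start. The range 20 to 80 in steps of 5 mirrors the values above; treating 80 as inclusive is an assumption.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier

X, y = make_classification(n_samples=200, n_features=20, n_classes=3,
                           n_clusters_per_class=3, n_informative=5,
                           random_state=42)

# Gradient Boosting: each entry of oob_improvement_ is the change in OOB loss
# contributed by one stage, so the running sum peaks at the best stage count.
gb = GradientBoostingClassifier(n_estimators=100, subsample=0.8, random_state=0)
gb.fit(X, y)
cumulative_gain = np.cumsum(gb.oob_improvement_)
gb_best_ntrees = int(np.argmax(cumulative_gain)) + 1  # stages are 1-based

# Random Forest: warm_start keeps the trees grown so far and only adds the
# missing ones, so the forest is extended rather than rebuilt at each step.
rf = RandomForestClassifier(warm_start=True, oob_score=True, random_state=0)
rf_oob_error = {}
for n in range(20, 81, 5):
    rf.set_params(n_estimators=n)
    rf.fit(X, y)
    rf_oob_error[n] = 1.0 - rf.oob_score_  # OOB misclassification rate

# Pick the forest size with the lowest OOB error.
rf_best_ntrees = min(rf_oob_error, key=rf_oob_error.get)
print(gb_best_ntrees, rf_best_ntrees)
```

The dictionary rf_oob_error has the same shape as the result described for tune_ntrees (ntrees value mapped to OOB error), so the final min(...) line would apply to either.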

4. Predict with ntrees

print(gb_reg.predict_ntrees(Xreg, ntrees=10))
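For Gradient Boosting, plain sklearn already exposes predictions after each boosting stage through staged_predict. The sketch below shows how a prediction from the first n stages could be obtained that way. Whether predict_ntrees works like this internally is not stated here, so treat it as an assumption; staged_prediction is a hypothetical helper name.

```python
from itertools import islice

import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import GradientBoostingRegressor

X, y = make_regression(n_samples=200, n_features=20, random_state=42)
gb = GradientBoostingRegressor(n_estimators=100, subsample=0.8, random_state=0)
gb.fit(X, y)


def staged_prediction(model, X, n):
    """Hypothetical helper: prediction of a fitted booster after its first n stages."""
    if not 1 <= n <= model.n_estimators_:
        raise ValueError("n must be between 1 and the number of fitted stages")
    # staged_predict yields the prediction after stage 1, 2, 3, ...;
    # skip n - 1 of them and take the next one.
    return next(islice(model.staged_predict(X), n - 1, None))


pred_10 = staged_prediction(gb, X, 10)
pred_full = staged_prediction(gb, X, gb.n_estimators_)

# The last stage should coincide with the ordinary predict call.
print(np.allclose(pred_full, gb.predict(X)))
```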
