mirror of
https://github.com/microsoft/FLAML.git
synced 2026-02-09 02:09:16 +08:00
Flaml: fix lgbm reproducibility (#1369)
* fix: Fixed bug where every underlying LGBMRegressor or LGBMClassifier had n_estimators = 1
* test: Added test showing case where FLAMLised CatBoostModel result isn't reproducible
* fix: Fixing issue where callbacks cause LGBM results to not be reproducible
* Update test/automl/test_regression.py
  Co-authored-by: Li Jiang <bnujli@gmail.com>
* fix: Adding back the LGBM EarlyStopping
* refactor: Fix tweaked to ensure other models aren't likely to be affected
* test: Fixed test to allow reproduced results to be better than the FLAML results, when LGBM earlystopping is involved
---------
Co-authored-by: Daniel Grindrod <Daniel.Grindrod@evotec.com>
Co-authored-by: Li Jiang <bnujli@gmail.com>
This commit is contained in:
@@ -1585,18 +1585,17 @@ class LGBMEstimator(BaseEstimator):
|
||||
callbacks = None
|
||||
if callbacks is None:
|
||||
self._fit(X_train, y_train, **kwargs)
|
||||
else:
|
||||
self._fit(X_train, y_train, callbacks=callbacks, **kwargs)
|
||||
if callbacks is None:
|
||||
# for xgboost>=1.6.0, pop callbacks to enable pickle
|
||||
callbacks = self.params.pop("callbacks")
|
||||
self._model.set_params(callbacks=callbacks[:-1])
|
||||
else:
|
||||
self._fit(X_train, y_train, callbacks=callbacks, **kwargs)
|
||||
best_iteration = (
|
||||
getattr(self._model.get_booster(), "best_iteration", None)
|
||||
if isinstance(self, XGBoostSklearnEstimator)
|
||||
else self._model.best_iteration_
|
||||
)
|
||||
if best_iteration is not None:
|
||||
if best_iteration is not None and best_iteration > 0:
|
||||
self._model.set_params(n_estimators=best_iteration + 1)
|
||||
else:
|
||||
self._fit(X_train, y_train, **kwargs)
|
||||
|
||||
@@ -493,7 +493,7 @@ def test_reproducibility_of_classification_models(estimator: str):
|
||||
"extra_tree",
|
||||
"histgb",
|
||||
"kneighbor",
|
||||
# "lgbm",
|
||||
"lgbm",
|
||||
# "lrl1",
|
||||
"lrl2",
|
||||
"svc",
|
||||
|
||||
@@ -339,6 +339,52 @@ def test_reproducibility_of_catboost_regression_model():
|
||||
assert pytest.approx(val_loss_flaml) == reproduced_val_loss
|
||||
|
||||
|
||||
def test_reproducibility_of_lgbm_regression_model():
    """Check that the best LGBM regressor found by FLAML can be re-scored on the same folds.

    Background: https://github.com/microsoft/FLAML/issues/1368

    The test runs a short LGBM-only search with cross-validation, then feeds the
    returned model and its parameters back through ``evaluate_model_CV`` using
    the splitter and training data kept in the search state. The re-computed
    validation loss must be approximately equal to, or lower than, the loss
    FLAML reported for its best result.
    """
    settings = {
        "time_budget": 3,
        "task": "regression",
        "n_jobs": 1,
        "estimator_list": ["lgbm"],
        "eval_method": "cv",
        "n_splits": 9,
        "metric": "r2",
        "keep_search_state": True,
        "skip_transform": True,
        "retrain_full": True,
    }
    features, target = fetch_california_housing(return_X_y=True, as_frame=True)

    automl = AutoML()
    automl.fit(X_train=features, y_train=target, **settings)

    best_model = automl.model
    assert best_model is not None
    best_config = best_model.get_params()
    flaml_val_loss = automl.best_result["val_loss"]

    # Re-evaluate the winning model with the data and folds retained from the search.
    search_state = automl._state
    reproduced_val_loss, _metric_for_logging, _train_time, _pred_time = search_state.task.evaluate_model_CV(
        config=best_config,
        estimator=best_model,
        X_train_all=search_state.X_train_all,
        y_train_all=search_state.y_train_all,
        budget=None,
        kf=search_state.kf,
        eval_metric="r2",
        best_val_loss=None,
        cv_score_agg_func=None,
        log_training_metric=False,
        fit_kwargs=None,
        free_mem_ratio=0,
    )

    # A strictly lower reproduced loss is accepted as well as an approximately equal one.
    is_reproduced = pytest.approx(flaml_val_loss) == reproduced_val_loss
    assert is_reproduced or flaml_val_loss > reproduced_val_loss
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"estimator",
|
||||
[
|
||||
@@ -347,7 +393,7 @@ def test_reproducibility_of_catboost_regression_model():
|
||||
"extra_tree",
|
||||
"histgb",
|
||||
"kneighbor",
|
||||
# "lgbm",
|
||||
"lgbm",
|
||||
"rf",
|
||||
"xgboost",
|
||||
"xgb_limitdepth",
|
||||
@@ -376,6 +422,7 @@ def test_reproducibility_of_underlying_regression_models(estimator: str):
|
||||
"metric": "r2",
|
||||
"keep_search_state": True,
|
||||
"skip_transform": True,
|
||||
"retrain_full": False,
|
||||
}
|
||||
X, y = fetch_california_housing(return_X_y=True, as_frame=True)
|
||||
automl.fit(X_train=X, y_train=y, **automl_settings)
|
||||
|
||||
Reference in New Issue
Block a user