mirror of
https://github.com/microsoft/FLAML.git
synced 2026-02-09 02:09:16 +08:00
fix: Fixing the random state of ElasticNetClassifier by default, to ensure reproduciblity. Also included elasticnet in reproducibility tests (#1374)
Co-authored-by: Daniel Grindrod <Daniel.Grindrod@evotec.com> Co-authored-by: Li Jiang <bnujli@gmail.com>
This commit is contained in:
@@ -2429,6 +2429,11 @@ class ElasticNetEstimator(SKLearnEstimator):
|
||||
|
||||
def __init__(self, task="regression", **config):
|
||||
super().__init__(task, **config)
|
||||
self.params.update(
|
||||
{
|
||||
"random_state": config.get("random_seed", 10242048),
|
||||
}
|
||||
)
|
||||
assert self._task.is_regression(), "ElasticNet for regression task only"
|
||||
self.estimator_class = ElasticNet
|
||||
|
||||
|
||||
@@ -236,6 +236,7 @@ def test_multioutput():
|
||||
"estimator",
|
||||
[
|
||||
"catboost",
|
||||
"enet",
|
||||
"extra_tree",
|
||||
"histgb",
|
||||
"kneighbor",
|
||||
@@ -342,6 +343,7 @@ def test_reproducibility_of_catboost_regression_model():
|
||||
"estimator",
|
||||
[
|
||||
"catboost",
|
||||
"enet",
|
||||
"extra_tree",
|
||||
"histgb",
|
||||
"kneighbor",
|
||||
@@ -385,7 +387,6 @@ def test_reproducibility_of_underlying_regression_models(estimator: str):
|
||||
automl._state.X_train_all, automl._state.y_train_all, automl._state.kf, best_model.model, "regression"
|
||||
)
|
||||
)
|
||||
|
||||
assert pytest.approx(val_loss_flaml) == reproduced_val_loss_underlying_model
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user