fix: Fixed bug with catboost and groups (#1383)

Co-authored-by: Daniel Grindrod <daniel.grindrod@evotec.com>
This commit is contained in:
Daniel Grindrod
2024-12-17 05:54:49 +00:00
committed by GitHub
parent b83c8a7d3b
commit 42d1dcfa0e
2 changed files with 4 additions and 3 deletions

View File

@@ -2066,8 +2066,8 @@ class CatBoostEstimator(BaseEstimator):
self.estimator_class = CatBoostRegressor
def fit(self, X_train, y_train, budget=None, free_mem_ratio=0, **kwargs):
if "is_retrain" in kwargs:
kwargs.pop("is_retrain")
kwargs.pop("is_retrain", None)
kwargs.pop("groups", None)
start_time = time.time()
deadline = start_time + budget if budget else np.inf
train_dir = f"catboost_{str(start_time)}"

View File

@@ -68,7 +68,7 @@ def test_groups():
"model_history": True,
"eval_method": "cv",
"groups": np.random.randint(low=0, high=10, size=len(y)),
"estimator_list": ["lgbm", "rf", "xgboost", "kneighbor"],
"estimator_list": ["catboost", "lgbm", "rf", "xgboost", "kneighbor"],
"learner_selector": "roundrobin",
}
automl.fit(X, y, **automl_settings)
@@ -108,6 +108,7 @@ def test_stratified_groupkfold():
"split_type": splitter,
"groups": X_train["Airline"],
"estimator_list": [
"catboost",
"lgbm",
"rf",
"xgboost",