fix: Cross validation process isn't always run to completion (#1360)

This commit is contained in:
Daniel Grindrod
2024-10-01 01:24:53 +01:00
committed by GitHub
parent e5d95f5674
commit 5c0f18b7bc

View File

@@ -706,7 +706,6 @@ class GenericTask(Task):
fit_kwargs = {}
if cv_score_agg_func is None:
cv_score_agg_func = default_cv_score_agg_func
start_time = time.time()
val_loss_folds = []
log_metric_folds = []
metric = None
@@ -813,8 +812,6 @@ class GenericTask(Task):
if is_spark_dataframe:
X_train.spark.unpersist() # uncache data to free memory
X_val.spark.unpersist() # uncache data to free memory
if budget and time.time() - start_time >= budget:
break
val_loss, metric = cv_score_agg_func(val_loss_folds, log_metric_folds)
n = total_fold_num
pred_time /= n