Merge branch 'main' into copilot/add-multi-target-support

This commit is contained in:
Li Jiang
2026-01-21 11:47:42 +08:00
committed by GitHub
2 changed files with 65 additions and 2 deletions

View File

@@ -156,6 +156,10 @@ class AutoML(BaseEstimator):
"pred_time": pred_time,
}
```
**Note:** When passing a custom metric function, pass the function itself
(e.g., `metric=custom_metric`), not the result of calling it
(e.g., `metric=custom_metric(...)`). FLAML will call your function
internally during the training process.
task: A string of the task type, e.g.,
'classification', 'regression', 'ts_forecast', 'rank',
'seq-classification', 'seq-regression', 'summarization',
@@ -370,6 +374,8 @@ class AutoML(BaseEstimator):
settings["n_splits"] = settings.get("n_splits", N_SPLITS)
settings["auto_augment"] = settings.get("auto_augment", True)
settings["metric"] = settings.get("metric", "auto")
# Validate that custom metric is callable if not a string
self._validate_metric_parameter(settings["metric"], allow_auto=True)
settings["estimator_list"] = settings.get("estimator_list", "auto")
settings["log_file_name"] = settings.get("log_file_name", "")
settings["max_iter"] = settings.get("max_iter") # no budget by default
@@ -462,6 +468,28 @@ class AutoML(BaseEstimator):
except Exception:
mi.mlflow_client = None
@staticmethod
def _validate_metric_parameter(metric, allow_auto=True):
"""Validate that the metric parameter is either a string or a callable function.
Args:
metric: The metric parameter to validate.
allow_auto: Whether to allow "auto" as a valid string value.
Raises:
ValueError: If metric is not a string or callable function.
"""
if allow_auto and metric == "auto":
return
if not isinstance(metric, str) and not callable(metric):
raise ValueError(
f"The 'metric' parameter must be either a string or a callable function, "
f"but got {type(metric).__name__}. "
f"If you defined a custom_metric function, make sure to pass the function itself "
f"(e.g., metric=custom_metric) and not the result of calling it "
f"(e.g., metric=custom_metric(...))."
)
def get_params(self, deep: bool = False) -> dict:
return self._settings.copy()
@@ -1815,6 +1843,10 @@ class AutoML(BaseEstimator):
"pred_time": pred_time,
}
```
**Note:** When passing a custom metric function, pass the function itself
(e.g., `metric=custom_metric`), not the result of calling it
(e.g., `metric=custom_metric(...)`). FLAML will call your function
internally during the training process.
task: A string of the task type, e.g.,
'classification', 'regression', 'ts_forecast_regression',
'ts_forecast_classification', 'rank', 'seq-classification',
@@ -2100,7 +2132,7 @@ class AutoML(BaseEstimator):
split_ratio = split_ratio or self._settings.get("split_ratio")
n_splits = n_splits or self._settings.get("n_splits")
auto_augment = self._settings.get("auto_augment") if auto_augment is None else auto_augment
metric = metric or self._settings.get("metric")
metric = self._settings.get("metric") if metric is None else metric
estimator_list = estimator_list or self._settings.get("estimator_list")
log_file_name = self._settings.get("log_file_name") if log_file_name is None else log_file_name
max_iter = self._settings.get("max_iter") if max_iter is None else max_iter
@@ -2339,6 +2371,9 @@ class AutoML(BaseEstimator):
and (self._min_sample_size * SAMPLE_MULTIPLY_FACTOR < self._state.data_size[0])
)
# Validate metric parameter before processing
self._validate_metric_parameter(metric, allow_auto=True)
metric = task.default_metric(metric)
self._state.metric = metric
@@ -2986,7 +3021,7 @@ class AutoML(BaseEstimator):
)
logger.info(
" at {:.1f}s,\testimator {}'s best error={:.4f},\tbest estimator {}'s best error={:.4f}".format(
" at {:.1f}s,\testimator {}'s best error={:.4e},\tbest estimator {}'s best error={:.4e}".format(
self._state.time_from_start,
estimator,
search_state.best_loss,

View File

@@ -278,6 +278,34 @@ class TestMultiClass(unittest.TestCase):
except ImportError:
pass
def test_invalid_custom_metric(self):
"""Test that proper error is raised when custom_metric is called instead of passed."""
from sklearn.datasets import load_iris
X_train, y_train = load_iris(return_X_y=True)
# Test with non-callable metric in __init__
with self.assertRaises(ValueError) as context:
automl = AutoML(metric=123) # passing an int instead of function
self.assertIn("must be either a string or a callable function", str(context.exception))
self.assertIn("but got int", str(context.exception))
# Test with non-callable metric in fit
automl = AutoML()
with self.assertRaises(ValueError) as context:
automl.fit(X_train=X_train, y_train=y_train, metric=[], task="classification", time_budget=1)
self.assertIn("must be either a string or a callable function", str(context.exception))
self.assertIn("but got list", str(context.exception))
# Test with tuple (simulating result of calling a function that returns tuple)
with self.assertRaises(ValueError) as context:
automl = AutoML()
automl.fit(
X_train=X_train, y_train=y_train, metric=(0.5, {"loss": 0.5}), task="classification", time_budget=1
)
self.assertIn("must be either a string or a callable function", str(context.exception))
self.assertIn("but got tuple", str(context.exception))
def test_classification(self, as_frame=False):
automl_experiment = AutoML()
automl_settings = {