From 3d489f1aaa4c75cc3c238fedfb69714cb74295f2 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Wed, 21 Jan 2026 08:58:11 +0800 Subject: [PATCH 1/2] Add validation and clear error messages for custom_metric parameter (#1500) * Initial plan * Add validation and documentation for custom_metric parameter Co-authored-by: thinkall <3197038+thinkall@users.noreply.github.com> * Refactor validation into reusable method and improve error handling Co-authored-by: thinkall <3197038+thinkall@users.noreply.github.com> * Apply pre-commit formatting fixes Co-authored-by: thinkall <3197038+thinkall@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: thinkall <3197038+thinkall@users.noreply.github.com> Co-authored-by: Li Jiang --- flaml/automl/automl.py | 37 +++++++++++++++++++++++++++++++++- test/automl/test_multiclass.py | 28 +++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 1 deletion(-) diff --git a/flaml/automl/automl.py b/flaml/automl/automl.py index 01f097359..d24469e7a 100644 --- a/flaml/automl/automl.py +++ b/flaml/automl/automl.py @@ -156,6 +156,10 @@ class AutoML(BaseEstimator): "pred_time": pred_time, } ``` + **Note:** When passing a custom metric function, pass the function itself + (e.g., `metric=custom_metric`), not the result of calling it + (e.g., `metric=custom_metric(...)`). FLAML will call your function + internally during the training process. task: A string of the task type, e.g., 'classification', 'regression', 'ts_forecast', 'rank', 'seq-classification', 'seq-regression', 'summarization', @@ -370,6 +374,8 @@ class AutoML(BaseEstimator): settings["n_splits"] = settings.get("n_splits", N_SPLITS) settings["auto_augment"] = settings.get("auto_augment", True) settings["metric"] = settings.get("metric", "auto") + # Validate that custom metric is callable if not a string + self._validate_metric_parameter(settings["metric"], allow_auto=True) settings["estimator_list"] = settings.get("estimator_list", "auto") settings["log_file_name"] = settings.get("log_file_name", "") settings["max_iter"] = settings.get("max_iter") # no budget by default @@ -462,6 +468,28 @@ class AutoML(BaseEstimator): except Exception: mi.mlflow_client = None + @staticmethod + def _validate_metric_parameter(metric, allow_auto=True): + """Validate that the metric parameter is either a string or a callable function. + + Args: + metric: The metric parameter to validate. + allow_auto: Whether to allow "auto" as a valid string value. + + Raises: + ValueError: If metric is not a string or callable function. + """ + if allow_auto and metric == "auto": + return + if not isinstance(metric, str) and not callable(metric): + raise ValueError( + f"The 'metric' parameter must be either a string or a callable function, " + f"but got {type(metric).__name__}. " + f"If you defined a custom_metric function, make sure to pass the function itself " + f"(e.g., metric=custom_metric) and not the result of calling it " + f"(e.g., metric=custom_metric(...))." + ) + def get_params(self, deep: bool = False) -> dict: return self._settings.copy() @@ -1810,6 +1838,10 @@ class AutoML(BaseEstimator): "pred_time": pred_time, } ``` + **Note:** When passing a custom metric function, pass the function itself + (e.g., `metric=custom_metric`), not the result of calling it + (e.g., `metric=custom_metric(...)`). FLAML will call your function + internally during the training process. task: A string of the task type, e.g., 'classification', 'regression', 'ts_forecast_regression', 'ts_forecast_classification', 'rank', 'seq-classification', @@ -2095,7 +2127,7 @@ class AutoML(BaseEstimator): split_ratio = split_ratio or self._settings.get("split_ratio") n_splits = n_splits or self._settings.get("n_splits") auto_augment = self._settings.get("auto_augment") if auto_augment is None else auto_augment - metric = metric or self._settings.get("metric") + metric = self._settings.get("metric") if metric is None else metric estimator_list = estimator_list or self._settings.get("estimator_list") log_file_name = self._settings.get("log_file_name") if log_file_name is None else log_file_name max_iter = self._settings.get("max_iter") if max_iter is None else max_iter @@ -2334,6 +2366,9 @@ class AutoML(BaseEstimator): and (self._min_sample_size * SAMPLE_MULTIPLY_FACTOR < self._state.data_size[0]) ) + # Validate metric parameter before processing + self._validate_metric_parameter(metric, allow_auto=True) + metric = task.default_metric(metric) self._state.metric = metric diff --git a/test/automl/test_multiclass.py b/test/automl/test_multiclass.py index 123e7c73c..12f8a8aa3 100644 --- a/test/automl/test_multiclass.py +++ b/test/automl/test_multiclass.py @@ -278,6 +278,34 @@ class TestMultiClass(unittest.TestCase): except ImportError: pass + def test_invalid_custom_metric(self): + """Test that proper error is raised when custom_metric is called instead of passed.""" + from sklearn.datasets import load_iris + + X_train, y_train = load_iris(return_X_y=True) + + # Test with non-callable metric in __init__ + with self.assertRaises(ValueError) as context: + automl = AutoML(metric=123) # passing an int instead of function + self.assertIn("must be either a string or a callable function", str(context.exception)) + self.assertIn("but got int", str(context.exception)) + + # Test with non-callable metric in fit + automl = AutoML() + with self.assertRaises(ValueError) as context: + automl.fit(X_train=X_train, y_train=y_train, metric=[], task="classification", time_budget=1) + self.assertIn("must be either a string or a callable function", str(context.exception)) + self.assertIn("but got list", str(context.exception)) + + # Test with tuple (simulating result of calling a function that returns tuple) + with self.assertRaises(ValueError) as context: + automl = AutoML() + automl.fit( + X_train=X_train, y_train=y_train, metric=(0.5, {"loss": 0.5}), task="classification", time_budget=1 + ) + self.assertIn("must be either a string or a callable function", str(context.exception)) + self.assertIn("but got tuple", str(context.exception)) + def test_classification(self, as_frame=False): automl_experiment = AutoML() automl_settings = { From 7ac076d54453675513805b530a5619b8b3d05c0b Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Wed, 21 Jan 2026 09:06:19 +0800 Subject: [PATCH 2/2] Use scientific notation for best error in logger output (#1498) * Initial plan * Change best error format from .4f to .4e for scientific notation Co-authored-by: thinkall <3197038+thinkall@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: thinkall <3197038+thinkall@users.noreply.github.com> Co-authored-by: Li Jiang --- flaml/automl/automl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flaml/automl/automl.py b/flaml/automl/automl.py index d24469e7a..1064fd29d 100644 --- a/flaml/automl/automl.py +++ b/flaml/automl/automl.py @@ -3015,7 +3015,7 @@ class AutoML(BaseEstimator): ) logger.info( - " at {:.1f}s,\testimator {}'s best error={:.4f},\tbest estimator {}'s best error={:.4f}".format( + " at {:.1f}s,\testimator {}'s best error={:.4e},\tbest estimator {}'s best error={:.4e}".format( self._state.time_from_start, estimator, search_state.best_loss,