mirror of
https://github.com/microsoft/FLAML.git
synced 2026-02-09 02:09:16 +08:00
Merge branch 'main' into copilot/add-multi-target-support
This commit is contained in:
@@ -156,6 +156,10 @@ class AutoML(BaseEstimator):
|
||||
"pred_time": pred_time,
|
||||
}
|
||||
```
|
||||
**Note:** When passing a custom metric function, pass the function itself
|
||||
(e.g., `metric=custom_metric`), not the result of calling it
|
||||
(e.g., `metric=custom_metric(...)`). FLAML will call your function
|
||||
internally during the training process.
|
||||
task: A string of the task type, e.g.,
|
||||
'classification', 'regression', 'ts_forecast', 'rank',
|
||||
'seq-classification', 'seq-regression', 'summarization',
|
||||
@@ -370,6 +374,8 @@ class AutoML(BaseEstimator):
|
||||
settings["n_splits"] = settings.get("n_splits", N_SPLITS)
|
||||
settings["auto_augment"] = settings.get("auto_augment", True)
|
||||
settings["metric"] = settings.get("metric", "auto")
|
||||
# Validate that custom metric is callable if not a string
|
||||
self._validate_metric_parameter(settings["metric"], allow_auto=True)
|
||||
settings["estimator_list"] = settings.get("estimator_list", "auto")
|
||||
settings["log_file_name"] = settings.get("log_file_name", "")
|
||||
settings["max_iter"] = settings.get("max_iter") # no budget by default
|
||||
@@ -462,6 +468,28 @@ class AutoML(BaseEstimator):
|
||||
except Exception:
|
||||
mi.mlflow_client = None
|
||||
|
||||
@staticmethod
|
||||
def _validate_metric_parameter(metric, allow_auto=True):
|
||||
"""Validate that the metric parameter is either a string or a callable function.
|
||||
|
||||
Args:
|
||||
metric: The metric parameter to validate.
|
||||
allow_auto: Whether to allow "auto" as a valid string value.
|
||||
|
||||
Raises:
|
||||
ValueError: If metric is not a string or callable function.
|
||||
"""
|
||||
if allow_auto and metric == "auto":
|
||||
return
|
||||
if not isinstance(metric, str) and not callable(metric):
|
||||
raise ValueError(
|
||||
f"The 'metric' parameter must be either a string or a callable function, "
|
||||
f"but got {type(metric).__name__}. "
|
||||
f"If you defined a custom_metric function, make sure to pass the function itself "
|
||||
f"(e.g., metric=custom_metric) and not the result of calling it "
|
||||
f"(e.g., metric=custom_metric(...))."
|
||||
)
|
||||
|
||||
def get_params(self, deep: bool = False) -> dict:
|
||||
return self._settings.copy()
|
||||
|
||||
@@ -1815,6 +1843,10 @@ class AutoML(BaseEstimator):
|
||||
"pred_time": pred_time,
|
||||
}
|
||||
```
|
||||
**Note:** When passing a custom metric function, pass the function itself
|
||||
(e.g., `metric=custom_metric`), not the result of calling it
|
||||
(e.g., `metric=custom_metric(...)`). FLAML will call your function
|
||||
internally during the training process.
|
||||
task: A string of the task type, e.g.,
|
||||
'classification', 'regression', 'ts_forecast_regression',
|
||||
'ts_forecast_classification', 'rank', 'seq-classification',
|
||||
@@ -2100,7 +2132,7 @@ class AutoML(BaseEstimator):
|
||||
split_ratio = split_ratio or self._settings.get("split_ratio")
|
||||
n_splits = n_splits or self._settings.get("n_splits")
|
||||
auto_augment = self._settings.get("auto_augment") if auto_augment is None else auto_augment
|
||||
metric = metric or self._settings.get("metric")
|
||||
metric = self._settings.get("metric") if metric is None else metric
|
||||
estimator_list = estimator_list or self._settings.get("estimator_list")
|
||||
log_file_name = self._settings.get("log_file_name") if log_file_name is None else log_file_name
|
||||
max_iter = self._settings.get("max_iter") if max_iter is None else max_iter
|
||||
@@ -2339,6 +2371,9 @@ class AutoML(BaseEstimator):
|
||||
and (self._min_sample_size * SAMPLE_MULTIPLY_FACTOR < self._state.data_size[0])
|
||||
)
|
||||
|
||||
# Validate metric parameter before processing
|
||||
self._validate_metric_parameter(metric, allow_auto=True)
|
||||
|
||||
metric = task.default_metric(metric)
|
||||
self._state.metric = metric
|
||||
|
||||
@@ -2986,7 +3021,7 @@ class AutoML(BaseEstimator):
|
||||
)
|
||||
|
||||
logger.info(
|
||||
" at {:.1f}s,\testimator {}'s best error={:.4f},\tbest estimator {}'s best error={:.4f}".format(
|
||||
" at {:.1f}s,\testimator {}'s best error={:.4e},\tbest estimator {}'s best error={:.4e}".format(
|
||||
self._state.time_from_start,
|
||||
estimator,
|
||||
search_state.best_loss,
|
||||
|
||||
@@ -278,6 +278,34 @@ class TestMultiClass(unittest.TestCase):
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
def test_invalid_custom_metric(self):
|
||||
"""Test that proper error is raised when custom_metric is called instead of passed."""
|
||||
from sklearn.datasets import load_iris
|
||||
|
||||
X_train, y_train = load_iris(return_X_y=True)
|
||||
|
||||
# Test with non-callable metric in __init__
|
||||
with self.assertRaises(ValueError) as context:
|
||||
automl = AutoML(metric=123) # passing an int instead of function
|
||||
self.assertIn("must be either a string or a callable function", str(context.exception))
|
||||
self.assertIn("but got int", str(context.exception))
|
||||
|
||||
# Test with non-callable metric in fit
|
||||
automl = AutoML()
|
||||
with self.assertRaises(ValueError) as context:
|
||||
automl.fit(X_train=X_train, y_train=y_train, metric=[], task="classification", time_budget=1)
|
||||
self.assertIn("must be either a string or a callable function", str(context.exception))
|
||||
self.assertIn("but got list", str(context.exception))
|
||||
|
||||
# Test with tuple (simulating result of calling a function that returns tuple)
|
||||
with self.assertRaises(ValueError) as context:
|
||||
automl = AutoML()
|
||||
automl.fit(
|
||||
X_train=X_train, y_train=y_train, metric=(0.5, {"loss": 0.5}), task="classification", time_budget=1
|
||||
)
|
||||
self.assertIn("must be either a string or a callable function", str(context.exception))
|
||||
self.assertIn("but got tuple", str(context.exception))
|
||||
|
||||
def test_classification(self, as_frame=False):
|
||||
automl_experiment = AutoML()
|
||||
automl_settings = {
|
||||
|
||||
Reference in New Issue
Block a user