Add validation and clear error messages for custom_metric parameter (#1500)

* Initial plan

* Add validation and documentation for custom_metric parameter

Co-authored-by: thinkall <3197038+thinkall@users.noreply.github.com>

* Refactor validation into reusable method and improve error handling

Co-authored-by: thinkall <3197038+thinkall@users.noreply.github.com>

* Apply pre-commit formatting fixes

Co-authored-by: thinkall <3197038+thinkall@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: thinkall <3197038+thinkall@users.noreply.github.com>
Co-authored-by: Li Jiang <bnujli@gmail.com>
This commit is contained in:
Copilot
2026-01-21 08:58:11 +08:00
committed by GitHub
parent c64eeb5e8d
commit 3d489f1aaa
2 changed files with 64 additions and 1 deletions

View File

@@ -278,6 +278,34 @@ class TestMultiClass(unittest.TestCase):
except ImportError:
pass
def test_invalid_custom_metric(self):
"""Test that proper error is raised when custom_metric is called instead of passed."""
from sklearn.datasets import load_iris
X_train, y_train = load_iris(return_X_y=True)
# Test with non-callable metric in __init__
with self.assertRaises(ValueError) as context:
automl = AutoML(metric=123) # passing an int instead of function
self.assertIn("must be either a string or a callable function", str(context.exception))
self.assertIn("but got int", str(context.exception))
# Test with non-callable metric in fit
automl = AutoML()
with self.assertRaises(ValueError) as context:
automl.fit(X_train=X_train, y_train=y_train, metric=[], task="classification", time_budget=1)
self.assertIn("must be either a string or a callable function", str(context.exception))
self.assertIn("but got list", str(context.exception))
# Test with tuple (simulating result of calling a function that returns tuple)
with self.assertRaises(ValueError) as context:
automl = AutoML()
automl.fit(
X_train=X_train, y_train=y_train, metric=(0.5, {"loss": 0.5}), task="classification", time_budget=1
)
self.assertIn("must be either a string or a callable function", str(context.exception))
self.assertIn("but got tuple", str(context.exception))
def test_classification(self, as_frame=False):
automl_experiment = AutoML()
automl_settings = {