mirror of
https://github.com/microsoft/FLAML.git
synced 2026-02-09 02:09:16 +08:00
Add validation and clear error messages for custom_metric parameter (#1500)
* Initial plan * Add validation and documentation for custom_metric parameter Co-authored-by: thinkall <3197038+thinkall@users.noreply.github.com> * Refactor validation into reusable method and improve error handling Co-authored-by: thinkall <3197038+thinkall@users.noreply.github.com> * Apply pre-commit formatting fixes Co-authored-by: thinkall <3197038+thinkall@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: thinkall <3197038+thinkall@users.noreply.github.com> Co-authored-by: Li Jiang <bnujli@gmail.com>
This commit is contained in:
@@ -278,6 +278,34 @@ class TestMultiClass(unittest.TestCase):
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
def test_invalid_custom_metric(self):
|
||||
"""Test that proper error is raised when custom_metric is called instead of passed."""
|
||||
from sklearn.datasets import load_iris
|
||||
|
||||
X_train, y_train = load_iris(return_X_y=True)
|
||||
|
||||
# Test with non-callable metric in __init__
|
||||
with self.assertRaises(ValueError) as context:
|
||||
automl = AutoML(metric=123) # passing an int instead of function
|
||||
self.assertIn("must be either a string or a callable function", str(context.exception))
|
||||
self.assertIn("but got int", str(context.exception))
|
||||
|
||||
# Test with non-callable metric in fit
|
||||
automl = AutoML()
|
||||
with self.assertRaises(ValueError) as context:
|
||||
automl.fit(X_train=X_train, y_train=y_train, metric=[], task="classification", time_budget=1)
|
||||
self.assertIn("must be either a string or a callable function", str(context.exception))
|
||||
self.assertIn("but got list", str(context.exception))
|
||||
|
||||
# Test with tuple (simulating result of calling a function that returns tuple)
|
||||
with self.assertRaises(ValueError) as context:
|
||||
automl = AutoML()
|
||||
automl.fit(
|
||||
X_train=X_train, y_train=y_train, metric=(0.5, {"loss": 0.5}), task="classification", time_budget=1
|
||||
)
|
||||
self.assertIn("must be either a string or a callable function", str(context.exception))
|
||||
self.assertIn("but got tuple", str(context.exception))
|
||||
|
||||
def test_classification(self, as_frame=False):
|
||||
automl_experiment = AutoML()
|
||||
automl_settings = {
|
||||
|
||||
Reference in New Issue
Block a user