mirror of
https://github.com/microsoft/FLAML.git
synced 2026-02-09 02:09:16 +08:00
Fix BlendSearch OptunaSearch warning for non-hierarchical spaces with Ray Tune domains (#1477)
* Initial plan * Fix BlendSearch OptunaSearch warning for non-hierarchical spaces Co-authored-by: thinkall <3197038+thinkall@users.noreply.github.com> * Clean up test file Co-authored-by: thinkall <3197038+thinkall@users.noreply.github.com> * Add regression test for BlendSearch UDF mode warning fix Co-authored-by: thinkall <3197038+thinkall@users.noreply.github.com> * Improve the fix and tests * Fix Define-by-run function passed in argument is not yet supported when using --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: thinkall <3197038+thinkall@users.noreply.github.com> Co-authored-by: Li Jiang <bnujli@gmail.com> Co-authored-by: Li Jiang <lijiang1@microsoft.com>
This commit is contained in:
@@ -217,7 +217,24 @@ class BlendSearch(Searcher):
|
||||
if global_search_alg is not None:
|
||||
self._gs = global_search_alg
|
||||
elif getattr(self, "__name__", None) != "CFO":
|
||||
if space and self._ls.hierarchical:
|
||||
# Use define-by-run for OptunaSearch when needed:
|
||||
# - Hierarchical/conditional spaces are best supported via define-by-run.
|
||||
# - Ray Tune domain/grid specs can trigger an "unresolved search space" warning
|
||||
# unless we switch to define-by-run.
|
||||
use_define_by_run = bool(getattr(self._ls, "hierarchical", False))
|
||||
if (not use_define_by_run) and isinstance(space, dict) and space:
|
||||
try:
|
||||
from .variant_generator import parse_spec_vars
|
||||
|
||||
_, domain_vars, grid_vars = parse_spec_vars(space)
|
||||
use_define_by_run = bool(domain_vars or grid_vars)
|
||||
except Exception:
|
||||
# Be conservative: if we can't determine whether the space is
|
||||
# unresolved, fall back to the original behavior.
|
||||
use_define_by_run = False
|
||||
|
||||
self._use_define_by_run = use_define_by_run
|
||||
if use_define_by_run:
|
||||
from functools import partial
|
||||
|
||||
gs_space = partial(define_by_run_func, space=space)
|
||||
@@ -487,7 +504,7 @@ class BlendSearch(Searcher):
|
||||
self._ls_bound_max,
|
||||
self._subspace.get(trial_id, self._ls.space),
|
||||
)
|
||||
if self._gs is not None and self._experimental and (not self._ls.hierarchical):
|
||||
if self._gs is not None and self._experimental and (not getattr(self, "_use_define_by_run", False)):
|
||||
self._gs.add_evaluated_point(flatten_dict(config), objective)
|
||||
# TODO: recover when supported
|
||||
# converted = convert_key(config, self._gs.space)
|
||||
|
||||
@@ -324,3 +324,26 @@ def test_no_optuna():
|
||||
import flaml.tune.searcher.suggestion
|
||||
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "install", "optuna==2.8.0"])
|
||||
|
||||
|
||||
def test_unresolved_search_space(caplog):
|
||||
import logging
|
||||
|
||||
from flaml import tune
|
||||
from flaml.tune.searcher.blendsearch import BlendSearch
|
||||
|
||||
if caplog is not None:
|
||||
caplog.set_level(logging.INFO)
|
||||
|
||||
BlendSearch(metric="loss", mode="min", space={"lr": tune.uniform(0.001, 0.1), "depth": tune.randint(1, 10)})
|
||||
try:
|
||||
text = caplog.text
|
||||
except AttributeError:
|
||||
text = ""
|
||||
assert (
|
||||
"unresolved search space" not in text and text
|
||||
), "BlendSearch should not produce warning about unresolved search space"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_unresolved_search_space(None)
|
||||
|
||||
Reference in New Issue
Block a user