mirror of
https://github.com/microsoft/FLAML.git
synced 2026-02-09 02:09:16 +08:00
move searcher and scheduler into tune (#746)
* move into tune * correct path * correct path * import path
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from flaml.searcher import CFO, BlendSearch, FLOW2, BlendSearchTuner, RandomSearch
|
||||
from flaml.tune.searcher import CFO, BlendSearch, FLOW2, BlendSearchTuner, RandomSearch
|
||||
from flaml.automl import AutoML, logger_formatter
|
||||
from flaml.onlineml.autovw import AutoVW
|
||||
from flaml.version import __version__
|
||||
|
||||
@@ -2961,7 +2961,7 @@ class AutoML(BaseEstimator):
|
||||
else:
|
||||
from ray.tune.search.optuna import OptunaSearch as SearchAlgo
|
||||
except (ImportError, AssertionError):
|
||||
from .searcher.suggestion import OptunaSearch as SearchAlgo
|
||||
from flaml.tune.searcher.suggestion import OptunaSearch as SearchAlgo
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"hpo_method={self._hpo_method} is not recognized. "
|
||||
@@ -3125,7 +3125,7 @@ class AutoML(BaseEstimator):
|
||||
else:
|
||||
from ray.tune.search import ConcurrencyLimiter
|
||||
except (ImportError, AssertionError):
|
||||
from .searcher.suggestion import ConcurrencyLimiter
|
||||
from flaml.tune.searcher.suggestion import ConcurrencyLimiter
|
||||
if self._hpo_method in ("cfo", "grid"):
|
||||
from flaml import CFO as SearchAlgo
|
||||
elif "optuna" == self._hpo_method:
|
||||
@@ -3138,13 +3138,13 @@ class AutoML(BaseEstimator):
|
||||
else:
|
||||
from ray.tune.search.optuna import OptunaSearch as SearchAlgo
|
||||
except (ImportError, AssertionError):
|
||||
from .searcher.suggestion import OptunaSearch as SearchAlgo
|
||||
from flaml.tune.searcher.suggestion import OptunaSearch as SearchAlgo
|
||||
elif "bs" == self._hpo_method:
|
||||
from flaml import BlendSearch as SearchAlgo
|
||||
elif "random" == self._hpo_method:
|
||||
from flaml.searcher import RandomSearch as SearchAlgo
|
||||
from flaml.tune.searcher import RandomSearch as SearchAlgo
|
||||
elif "cfocat" == self._hpo_method:
|
||||
from flaml.searcher.cfo_cat import CFOCat as SearchAlgo
|
||||
from flaml.tune.searcher.cfo_cat import CFOCat as SearchAlgo
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"hpo_method={self._hpo_method} is not recognized. "
|
||||
|
||||
@@ -8,8 +8,8 @@ from flaml.tune import (
|
||||
polynomial_expansion_set,
|
||||
)
|
||||
from flaml.onlineml import OnlineTrialRunner
|
||||
from flaml.scheduler import ChaChaScheduler
|
||||
from flaml.searcher import ChampionFrontierSearcher
|
||||
from flaml.tune.scheduler import ChaChaScheduler
|
||||
from flaml.tune.searcher import ChampionFrontierSearcher
|
||||
from flaml.onlineml.trial import get_ns_feature_dim_from_vw_example
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import numpy as np
|
||||
import math
|
||||
from flaml.tune import Trial
|
||||
from flaml.scheduler import TrialScheduler
|
||||
from flaml.tune.scheduler import TrialScheduler
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import numpy as np
|
||||
import logging
|
||||
from typing import Dict
|
||||
from flaml.scheduler import TrialScheduler
|
||||
from flaml.tune.scheduler import TrialScheduler
|
||||
from flaml.tune import Trial
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -20,12 +20,12 @@ try:
|
||||
except (ImportError, AssertionError):
|
||||
from .suggestion import Searcher
|
||||
from .suggestion import OptunaSearch as GlobalSearch
|
||||
from ..tune.trial import unflatten_dict, flatten_dict
|
||||
from ..tune import INCUMBENT_RESULT
|
||||
from ..trial import unflatten_dict, flatten_dict
|
||||
from .. import INCUMBENT_RESULT
|
||||
from .search_thread import SearchThread
|
||||
from .flow2 import FLOW2
|
||||
from ..tune.space import add_cost_to_space, indexof, normalize, define_by_run_func
|
||||
from ..tune.result import TIME_TOTAL_S
|
||||
from ..space import add_cost_to_space, indexof, normalize, define_by_run_func
|
||||
from ..result import TIME_TOTAL_S
|
||||
|
||||
import logging
|
||||
|
||||
@@ -939,7 +939,7 @@ try:
|
||||
qloguniform,
|
||||
)
|
||||
except (ImportError, AssertionError):
|
||||
from ..tune.sample import (
|
||||
from ..sample import (
|
||||
uniform,
|
||||
quniform,
|
||||
choice,
|
||||
@@ -18,11 +18,10 @@ try:
|
||||
from ray.tune.utils.util import flatten_dict, unflatten_dict
|
||||
except (ImportError, AssertionError):
|
||||
from .suggestion import Searcher
|
||||
from ..tune import sample
|
||||
from ..tune.trial import flatten_dict, unflatten_dict
|
||||
from flaml.tune.sample import _BackwardsCompatibleNumpyRng
|
||||
from flaml.tune import sample
|
||||
from ..trial import flatten_dict, unflatten_dict
|
||||
from flaml.config import SAMPLE_MULTIPLY_FACTOR
|
||||
from ..tune.space import (
|
||||
from ..space import (
|
||||
complete_config,
|
||||
denormalize,
|
||||
normalize,
|
||||
@@ -85,7 +84,7 @@ class FLOW2(Searcher):
|
||||
self.space = space or {}
|
||||
self._space = flatten_dict(self.space, prevent_delimiter=True)
|
||||
self._random = np.random.RandomState(seed)
|
||||
self.rs_random = _BackwardsCompatibleNumpyRng(seed + 19823)
|
||||
self.rs_random = sample._BackwardsCompatibleNumpyRng(seed + 19823)
|
||||
self.seed = seed
|
||||
self.init_config = init_config
|
||||
self.best_config = flatten_dict(init_config)
|
||||
@@ -2,10 +2,9 @@ import numpy as np
|
||||
import logging
|
||||
import itertools
|
||||
from typing import Dict, Optional, List
|
||||
from flaml.tune import Categorical, Float, PolynomialExpansionSet
|
||||
from flaml.tune import Trial
|
||||
from flaml.tune import Categorical, Float, PolynomialExpansionSet, Trial
|
||||
from flaml.onlineml import VowpalWabbitTrial
|
||||
from flaml.searcher import CFO
|
||||
from flaml.tune.searcher import CFO
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -16,7 +16,7 @@ try:
|
||||
except (ImportError, AssertionError):
|
||||
from .suggestion import Searcher
|
||||
from .flow2 import FLOW2
|
||||
from ..tune.space import add_cost_to_space, unflatten_hierarchical
|
||||
from ..space import add_cost_to_space, unflatten_hierarchical
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -23,7 +23,7 @@ import logging
|
||||
from typing import Any, Dict, Optional, Union, List, Tuple, Callable
|
||||
import pickle
|
||||
from .variant_generator import parse_spec_vars
|
||||
from ..tune.sample import (
|
||||
from ..sample import (
|
||||
Categorical,
|
||||
Domain,
|
||||
Float,
|
||||
@@ -32,7 +32,7 @@ from ..tune.sample import (
|
||||
Quantized,
|
||||
Uniform,
|
||||
)
|
||||
from ..tune.trial import flatten_dict, unflatten_dict
|
||||
from ..trial import flatten_dict, unflatten_dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -20,7 +20,7 @@ import logging
|
||||
from typing import Any, Dict, Generator, List, Tuple
|
||||
import numpy
|
||||
import random
|
||||
from ..tune.sample import Categorical, Domain, RandomState
|
||||
from ..sample import Categorical, Domain, RandomState
|
||||
|
||||
try:
|
||||
from ray import __version__ as ray_version
|
||||
@@ -10,7 +10,7 @@ try:
|
||||
from ray.tune.search.variant_generator import generate_variants
|
||||
except (ImportError, AssertionError):
|
||||
from . import sample
|
||||
from ..searcher.variant_generator import generate_variants
|
||||
from .searcher.variant_generator import generate_variants
|
||||
from typing import Dict, Optional, Any, Tuple, Generator
|
||||
import numpy as np
|
||||
import logging
|
||||
|
||||
@@ -355,7 +355,7 @@ def run(
|
||||
else:
|
||||
logger.setLevel(logging.CRITICAL)
|
||||
|
||||
from ..searcher.blendsearch import BlendSearch, CFO
|
||||
from .searcher.blendsearch import BlendSearch, CFO
|
||||
|
||||
if search_alg is None:
|
||||
flaml_scheduler_resource_attr = (
|
||||
@@ -409,7 +409,7 @@ def run(
|
||||
else:
|
||||
from ray.tune.search import ConcurrencyLimiter
|
||||
else:
|
||||
from flaml.searcher.suggestion import ConcurrencyLimiter
|
||||
from flaml.tune.searcher.suggestion import ConcurrencyLimiter
|
||||
if (
|
||||
search_alg.__class__.__name__
|
||||
in [
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from flaml.searcher.blendsearch import BlendSearchTuner as BST
|
||||
from flaml.tune.searcher.blendsearch import BlendSearchTuner as BST
|
||||
|
||||
|
||||
class BlendSearchTuner(BST):
|
||||
|
||||
@@ -26,7 +26,7 @@ def easy_objective(use_raytune, config):
|
||||
|
||||
def test_tune_scheduler(smoke_test=True, use_ray=True, use_raytune=False):
|
||||
import numpy as np
|
||||
from flaml.searcher.blendsearch import BlendSearch
|
||||
from flaml.tune.searcher.blendsearch import BlendSearch
|
||||
|
||||
np.random.seed(100)
|
||||
easy_objective_custom_tune = partial(easy_objective, use_raytune)
|
||||
|
||||
@@ -28,7 +28,7 @@ low_cost_partial_config = {"x": 1}
|
||||
|
||||
|
||||
def setup_searcher(searcher_name):
|
||||
from flaml.searcher.blendsearch import BlendSearch, CFO, RandomSearch
|
||||
from flaml.tune.searcher.blendsearch import BlendSearch, CFO, RandomSearch
|
||||
|
||||
if "cfo" in searcher_name:
|
||||
searcher = CFO(
|
||||
|
||||
@@ -7,7 +7,7 @@ def rosenbrock_function(config: dict):
|
||||
funcLoss = 50
|
||||
for key, value in config.items():
|
||||
if key in ["x1", "x2", "x3", "x4", "x5"]:
|
||||
funcLoss += value ** 2 - 10 * np.cos(2 * np.pi * value)
|
||||
funcLoss += value**2 - 10 * np.cos(2 * np.pi * value)
|
||||
if INCUMBENT_RESULT in config.keys():
|
||||
print("----------------------------------------------")
|
||||
print("incumbent result", config[INCUMBENT_RESULT])
|
||||
@@ -62,7 +62,7 @@ def test_record_incumbent(method="BlendSearch"):
|
||||
use_incumbent_result_in_evaluation=True,
|
||||
)
|
||||
elif method == "CFOCat":
|
||||
from flaml.searcher.cfo_cat import CFOCat
|
||||
from flaml.tune.searcher.cfo_cat import CFOCat
|
||||
|
||||
algo = CFOCat(
|
||||
use_incumbent_result_in_evaluation=True,
|
||||
|
||||
@@ -26,7 +26,7 @@ def _easy_objective(use_raytune, config):
|
||||
|
||||
def test_tune(externally_setup_searcher=False, use_ray=False, use_raytune=False):
|
||||
from flaml import tune
|
||||
from flaml.searcher.blendsearch import BlendSearch
|
||||
from flaml.tune.searcher.blendsearch import BlendSearch
|
||||
|
||||
easy_objective_custom_tune = partial(_easy_objective, use_raytune)
|
||||
search_space = {
|
||||
|
||||
@@ -3,7 +3,7 @@ import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
import numpy as np
|
||||
from flaml.searcher.suggestion import ConcurrencyLimiter
|
||||
from flaml.tune.searcher.suggestion import ConcurrencyLimiter
|
||||
from flaml import tune
|
||||
from flaml import CFO
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Require: pip install flaml[test,ray]
|
||||
"""
|
||||
from flaml.scheduler.trial_scheduler import TrialScheduler
|
||||
from flaml.tune.scheduler.trial_scheduler import TrialScheduler
|
||||
import numpy as np
|
||||
from flaml import tune
|
||||
|
||||
|
||||
@@ -32,8 +32,12 @@ def wrong_define_search_space(trial):
|
||||
|
||||
|
||||
def test_searcher():
|
||||
from flaml.searcher.suggestion import OptunaSearch, Searcher, ConcurrencyLimiter
|
||||
from flaml.searcher.blendsearch import BlendSearch, CFO, RandomSearch
|
||||
from flaml.tune.searcher.suggestion import (
|
||||
OptunaSearch,
|
||||
Searcher,
|
||||
ConcurrencyLimiter,
|
||||
)
|
||||
from flaml.tune.searcher.blendsearch import BlendSearch, CFO, RandomSearch
|
||||
from flaml.tune import sample as flamlsample
|
||||
|
||||
searcher = Searcher()
|
||||
@@ -306,6 +310,6 @@ def test_no_optuna():
|
||||
import sys
|
||||
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "-y", "optuna"])
|
||||
import flaml.searcher.suggestion
|
||||
import flaml.tune.searcher.suggestion
|
||||
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "install", "optuna==2.8.0"])
|
||||
|
||||
@@ -69,7 +69,7 @@ def test_define_by_run():
|
||||
|
||||
|
||||
def test_grid():
|
||||
from flaml.searcher.variant_generator import (
|
||||
from flaml.tune.searcher.variant_generator import (
|
||||
generate_variants,
|
||||
grid_search,
|
||||
TuneError,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Require: pip install flaml[test,ray]
|
||||
"""
|
||||
from flaml.searcher.blendsearch import BlendSearch
|
||||
from flaml import BlendSearch
|
||||
import time
|
||||
import os
|
||||
from sklearn.model_selection import train_test_split
|
||||
@@ -146,7 +146,7 @@ def _test_xgboost(method="BlendSearch"):
|
||||
},
|
||||
)
|
||||
elif "CFOCat" == method:
|
||||
from flaml.searcher.cfo_cat import CFOCat
|
||||
from flaml.tune.searcher.cfo_cat import CFOCat
|
||||
|
||||
algo = CFOCat(
|
||||
low_cost_partial_config={
|
||||
|
||||
Reference in New Issue
Block a user