move searcher and scheduler into tune (#746)

* move into tune

* correct path

* correct path

* import path
This commit is contained in:
Chi Wang
2022-10-04 16:03:22 -07:00
committed by GitHub
parent 3e3ce3e33e
commit 860cbc233e
27 changed files with 46 additions and 44 deletions

View File

@@ -1,4 +1,4 @@
from flaml.searcher import CFO, BlendSearch, FLOW2, BlendSearchTuner, RandomSearch
from flaml.tune.searcher import CFO, BlendSearch, FLOW2, BlendSearchTuner, RandomSearch
from flaml.automl import AutoML, logger_formatter
from flaml.onlineml.autovw import AutoVW
from flaml.version import __version__

View File

@@ -2961,7 +2961,7 @@ class AutoML(BaseEstimator):
else:
from ray.tune.search.optuna import OptunaSearch as SearchAlgo
except (ImportError, AssertionError):
from .searcher.suggestion import OptunaSearch as SearchAlgo
from flaml.tune.searcher.suggestion import OptunaSearch as SearchAlgo
else:
raise NotImplementedError(
f"hpo_method={self._hpo_method} is not recognized. "
@@ -3125,7 +3125,7 @@ class AutoML(BaseEstimator):
else:
from ray.tune.search import ConcurrencyLimiter
except (ImportError, AssertionError):
from .searcher.suggestion import ConcurrencyLimiter
from flaml.tune.searcher.suggestion import ConcurrencyLimiter
if self._hpo_method in ("cfo", "grid"):
from flaml import CFO as SearchAlgo
elif "optuna" == self._hpo_method:
@@ -3138,13 +3138,13 @@ class AutoML(BaseEstimator):
else:
from ray.tune.search.optuna import OptunaSearch as SearchAlgo
except (ImportError, AssertionError):
from .searcher.suggestion import OptunaSearch as SearchAlgo
from flaml.tune.searcher.suggestion import OptunaSearch as SearchAlgo
elif "bs" == self._hpo_method:
from flaml import BlendSearch as SearchAlgo
elif "random" == self._hpo_method:
from flaml.searcher import RandomSearch as SearchAlgo
from flaml.tune.searcher import RandomSearch as SearchAlgo
elif "cfocat" == self._hpo_method:
from flaml.searcher.cfo_cat import CFOCat as SearchAlgo
from flaml.tune.searcher.cfo_cat import CFOCat as SearchAlgo
else:
raise NotImplementedError(
f"hpo_method={self._hpo_method} is not recognized. "

View File

@@ -8,8 +8,8 @@ from flaml.tune import (
polynomial_expansion_set,
)
from flaml.onlineml import OnlineTrialRunner
from flaml.scheduler import ChaChaScheduler
from flaml.searcher import ChampionFrontierSearcher
from flaml.tune.scheduler import ChaChaScheduler
from flaml.tune.searcher import ChampionFrontierSearcher
from flaml.onlineml.trial import get_ns_feature_dim_from_vw_example
logger = logging.getLogger(__name__)

View File

@@ -1,7 +1,7 @@
import numpy as np
import math
from flaml.tune import Trial
from flaml.scheduler import TrialScheduler
from flaml.tune.scheduler import TrialScheduler
import logging

View File

@@ -1,7 +1,7 @@
import numpy as np
import logging
from typing import Dict
from flaml.scheduler import TrialScheduler
from flaml.tune.scheduler import TrialScheduler
from flaml.tune import Trial
logger = logging.getLogger(__name__)

View File

@@ -20,12 +20,12 @@ try:
except (ImportError, AssertionError):
from .suggestion import Searcher
from .suggestion import OptunaSearch as GlobalSearch
from ..tune.trial import unflatten_dict, flatten_dict
from ..tune import INCUMBENT_RESULT
from ..trial import unflatten_dict, flatten_dict
from .. import INCUMBENT_RESULT
from .search_thread import SearchThread
from .flow2 import FLOW2
from ..tune.space import add_cost_to_space, indexof, normalize, define_by_run_func
from ..tune.result import TIME_TOTAL_S
from ..space import add_cost_to_space, indexof, normalize, define_by_run_func
from ..result import TIME_TOTAL_S
import logging
@@ -939,7 +939,7 @@ try:
qloguniform,
)
except (ImportError, AssertionError):
from ..tune.sample import (
from ..sample import (
uniform,
quniform,
choice,

View File

@@ -18,11 +18,10 @@ try:
from ray.tune.utils.util import flatten_dict, unflatten_dict
except (ImportError, AssertionError):
from .suggestion import Searcher
from ..tune import sample
from ..tune.trial import flatten_dict, unflatten_dict
from flaml.tune.sample import _BackwardsCompatibleNumpyRng
from flaml.tune import sample
from ..trial import flatten_dict, unflatten_dict
from flaml.config import SAMPLE_MULTIPLY_FACTOR
from ..tune.space import (
from ..space import (
complete_config,
denormalize,
normalize,
@@ -85,7 +84,7 @@ class FLOW2(Searcher):
self.space = space or {}
self._space = flatten_dict(self.space, prevent_delimiter=True)
self._random = np.random.RandomState(seed)
self.rs_random = _BackwardsCompatibleNumpyRng(seed + 19823)
self.rs_random = sample._BackwardsCompatibleNumpyRng(seed + 19823)
self.seed = seed
self.init_config = init_config
self.best_config = flatten_dict(init_config)

View File

@@ -2,10 +2,9 @@ import numpy as np
import logging
import itertools
from typing import Dict, Optional, List
from flaml.tune import Categorical, Float, PolynomialExpansionSet
from flaml.tune import Trial
from flaml.tune import Categorical, Float, PolynomialExpansionSet, Trial
from flaml.onlineml import VowpalWabbitTrial
from flaml.searcher import CFO
from flaml.tune.searcher import CFO
logger = logging.getLogger(__name__)

View File

@@ -16,7 +16,7 @@ try:
except (ImportError, AssertionError):
from .suggestion import Searcher
from .flow2 import FLOW2
from ..tune.space import add_cost_to_space, unflatten_hierarchical
from ..space import add_cost_to_space, unflatten_hierarchical
import logging
logger = logging.getLogger(__name__)

View File

@@ -23,7 +23,7 @@ import logging
from typing import Any, Dict, Optional, Union, List, Tuple, Callable
import pickle
from .variant_generator import parse_spec_vars
from ..tune.sample import (
from ..sample import (
Categorical,
Domain,
Float,
@@ -32,7 +32,7 @@ from ..tune.sample import (
Quantized,
Uniform,
)
from ..tune.trial import flatten_dict, unflatten_dict
from ..trial import flatten_dict, unflatten_dict
logger = logging.getLogger(__name__)

View File

@@ -20,7 +20,7 @@ import logging
from typing import Any, Dict, Generator, List, Tuple
import numpy
import random
from ..tune.sample import Categorical, Domain, RandomState
from ..sample import Categorical, Domain, RandomState
try:
from ray import __version__ as ray_version

View File

@@ -10,7 +10,7 @@ try:
from ray.tune.search.variant_generator import generate_variants
except (ImportError, AssertionError):
from . import sample
from ..searcher.variant_generator import generate_variants
from .searcher.variant_generator import generate_variants
from typing import Dict, Optional, Any, Tuple, Generator
import numpy as np
import logging

View File

@@ -355,7 +355,7 @@ def run(
else:
logger.setLevel(logging.CRITICAL)
from ..searcher.blendsearch import BlendSearch, CFO
from .searcher.blendsearch import BlendSearch, CFO
if search_alg is None:
flaml_scheduler_resource_attr = (
@@ -409,7 +409,7 @@ def run(
else:
from ray.tune.search import ConcurrencyLimiter
else:
from flaml.searcher.suggestion import ConcurrencyLimiter
from flaml.tune.searcher.suggestion import ConcurrencyLimiter
if (
search_alg.__class__.__name__
in [

View File

@@ -1,4 +1,4 @@
from flaml.searcher.blendsearch import BlendSearchTuner as BST
from flaml.tune.searcher.blendsearch import BlendSearchTuner as BST
class BlendSearchTuner(BST):

View File

@@ -26,7 +26,7 @@ def easy_objective(use_raytune, config):
def test_tune_scheduler(smoke_test=True, use_ray=True, use_raytune=False):
import numpy as np
from flaml.searcher.blendsearch import BlendSearch
from flaml.tune.searcher.blendsearch import BlendSearch
np.random.seed(100)
easy_objective_custom_tune = partial(easy_objective, use_raytune)

View File

@@ -28,7 +28,7 @@ low_cost_partial_config = {"x": 1}
def setup_searcher(searcher_name):
from flaml.searcher.blendsearch import BlendSearch, CFO, RandomSearch
from flaml.tune.searcher.blendsearch import BlendSearch, CFO, RandomSearch
if "cfo" in searcher_name:
searcher = CFO(

View File

@@ -7,7 +7,7 @@ def rosenbrock_function(config: dict):
funcLoss = 50
for key, value in config.items():
if key in ["x1", "x2", "x3", "x4", "x5"]:
funcLoss += value ** 2 - 10 * np.cos(2 * np.pi * value)
funcLoss += value**2 - 10 * np.cos(2 * np.pi * value)
if INCUMBENT_RESULT in config.keys():
print("----------------------------------------------")
print("incumbent result", config[INCUMBENT_RESULT])
@@ -62,7 +62,7 @@ def test_record_incumbent(method="BlendSearch"):
use_incumbent_result_in_evaluation=True,
)
elif method == "CFOCat":
from flaml.searcher.cfo_cat import CFOCat
from flaml.tune.searcher.cfo_cat import CFOCat
algo = CFOCat(
use_incumbent_result_in_evaluation=True,

View File

@@ -26,7 +26,7 @@ def _easy_objective(use_raytune, config):
def test_tune(externally_setup_searcher=False, use_ray=False, use_raytune=False):
from flaml import tune
from flaml.searcher.blendsearch import BlendSearch
from flaml.tune.searcher.blendsearch import BlendSearch
easy_objective_custom_tune = partial(_easy_objective, use_raytune)
search_space = {

View File

@@ -3,7 +3,7 @@ import shutil
import tempfile
import unittest
import numpy as np
from flaml.searcher.suggestion import ConcurrencyLimiter
from flaml.tune.searcher.suggestion import ConcurrencyLimiter
from flaml import tune
from flaml import CFO

View File

@@ -1,6 +1,6 @@
"""Require: pip install flaml[test,ray]
"""
from flaml.scheduler.trial_scheduler import TrialScheduler
from flaml.tune.scheduler.trial_scheduler import TrialScheduler
import numpy as np
from flaml import tune

View File

@@ -32,8 +32,12 @@ def wrong_define_search_space(trial):
def test_searcher():
from flaml.searcher.suggestion import OptunaSearch, Searcher, ConcurrencyLimiter
from flaml.searcher.blendsearch import BlendSearch, CFO, RandomSearch
from flaml.tune.searcher.suggestion import (
OptunaSearch,
Searcher,
ConcurrencyLimiter,
)
from flaml.tune.searcher.blendsearch import BlendSearch, CFO, RandomSearch
from flaml.tune import sample as flamlsample
searcher = Searcher()
@@ -306,6 +310,6 @@ def test_no_optuna():
import sys
subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "-y", "optuna"])
import flaml.searcher.suggestion
import flaml.tune.searcher.suggestion
subprocess.check_call([sys.executable, "-m", "pip", "install", "optuna==2.8.0"])

View File

@@ -69,7 +69,7 @@ def test_define_by_run():
def test_grid():
from flaml.searcher.variant_generator import (
from flaml.tune.searcher.variant_generator import (
generate_variants,
grid_search,
TuneError,

View File

@@ -1,6 +1,6 @@
"""Require: pip install flaml[test,ray]
"""
from flaml.searcher.blendsearch import BlendSearch
from flaml import BlendSearch
import time
import os
from sklearn.model_selection import train_test_split
@@ -146,7 +146,7 @@ def _test_xgboost(method="BlendSearch"):
},
)
elif "CFOCat" == method:
from flaml.searcher.cfo_cat import CFOCat
from flaml.tune.searcher.cfo_cat import CFOCat
algo = CFOCat(
low_cost_partial_config={