mirror of
https://github.com/microsoft/FLAML.git
synced 2026-02-09 02:09:16 +08:00
backup & recover global vars for nested tune.run (#584)
* backup & recover global vars for nested tune.run * ensure recovering global vars before return
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
# !
|
||||
# * Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# * Copyright (c) FLAML authors. All rights reserved.
|
||||
# * Licensed under the MIT License. See LICENSE file in the
|
||||
# * project root for license information.
|
||||
from typing import Optional, Union, List, Callable, Tuple
|
||||
@@ -23,8 +23,6 @@ from .result import DEFAULT_METRIC
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_use_ray = True
|
||||
_runner = None
|
||||
_verbose = 0
|
||||
@@ -91,32 +89,35 @@ def report(_metric=None, **kwargs):
|
||||
global _running_trial
|
||||
global _training_iteration
|
||||
if _use_ray:
|
||||
from ray import tune
|
||||
try:
|
||||
from ray import tune
|
||||
|
||||
return tune.report(_metric, **kwargs)
|
||||
return tune.report(_metric, **kwargs)
|
||||
except ImportError:
|
||||
# calling tune.report() outside tune.run()
|
||||
return
|
||||
result = kwargs
|
||||
if _metric:
|
||||
result[DEFAULT_METRIC] = _metric
|
||||
trial = getattr(_runner, "running_trial", None)
|
||||
if not trial:
|
||||
return None
|
||||
if _running_trial == trial:
|
||||
_training_iteration += 1
|
||||
else:
|
||||
result = kwargs
|
||||
if _metric:
|
||||
result[DEFAULT_METRIC] = _metric
|
||||
trial = getattr(_runner, "running_trial", None)
|
||||
if not trial:
|
||||
return None
|
||||
if _running_trial == trial:
|
||||
_training_iteration += 1
|
||||
else:
|
||||
_training_iteration = 0
|
||||
_running_trial = trial
|
||||
result["training_iteration"] = _training_iteration
|
||||
result["config"] = trial.config
|
||||
if INCUMBENT_RESULT in result["config"]:
|
||||
del result["config"][INCUMBENT_RESULT]
|
||||
for key, value in trial.config.items():
|
||||
result["config/" + key] = value
|
||||
_runner.process_trial_result(trial, result)
|
||||
if _verbose > 2:
|
||||
logger.info(f"result: {result}")
|
||||
if trial.is_finished():
|
||||
raise StopIteration
|
||||
_training_iteration = 0
|
||||
_running_trial = trial
|
||||
result["training_iteration"] = _training_iteration
|
||||
result["config"] = trial.config
|
||||
if INCUMBENT_RESULT in result["config"]:
|
||||
del result["config"][INCUMBENT_RESULT]
|
||||
for key, value in trial.config.items():
|
||||
result["config/" + key] = value
|
||||
_runner.process_trial_result(trial, result)
|
||||
if _verbose > 2:
|
||||
logger.info(f"result: {result}")
|
||||
if trial.is_finished():
|
||||
raise StopIteration
|
||||
|
||||
|
||||
def run(
|
||||
@@ -146,6 +147,7 @@ def run(
|
||||
max_failure: Optional[int] = 100,
|
||||
use_ray: Optional[bool] = False,
|
||||
use_incumbent_result_in_evaluation: Optional[bool] = None,
|
||||
**ray_args,
|
||||
):
|
||||
"""The trigger for HPO.
|
||||
|
||||
@@ -296,15 +298,33 @@ def run(
|
||||
max_failure: int | the maximal consecutive number of failures to sample
|
||||
a trial before the tuning is terminated.
|
||||
use_ray: A boolean of whether to use ray as the backend.
|
||||
**ray_args: keyword arguments to pass to ray.tune.run().
|
||||
Only valid when use_ray=True.
|
||||
"""
|
||||
global _use_ray
|
||||
global _verbose
|
||||
global _running_trial
|
||||
global _training_iteration
|
||||
old_use_ray = _use_ray
|
||||
old_verbose = _verbose
|
||||
old_running_trial = _running_trial
|
||||
old_training_iteration = _training_iteration
|
||||
if not use_ray:
|
||||
_verbose = verbose
|
||||
old_handlers = logger.handlers
|
||||
old_level = logger.getEffectiveLevel()
|
||||
logger.handlers = []
|
||||
if (
|
||||
old_handlers
|
||||
and isinstance(old_handlers[0], logging.StreamHandler)
|
||||
and not isinstance(old_handlers[0], logging.FileHandler)
|
||||
):
|
||||
# Add the console handler.
|
||||
logger.addHandler(old_handlers[0])
|
||||
if verbose > 0:
|
||||
import os
|
||||
|
||||
if local_dir:
|
||||
import os
|
||||
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
logger.addHandler(
|
||||
logging.FileHandler(
|
||||
@@ -314,7 +334,7 @@ def run(
|
||||
+ ".log"
|
||||
)
|
||||
)
|
||||
elif not logger.handlers:
|
||||
elif not logger.hasHandlers():
|
||||
# Add the console handler.
|
||||
_ch = logging.StreamHandler()
|
||||
logger_formatter = logging.Formatter(
|
||||
@@ -347,7 +367,7 @@ def run(
|
||||
flaml_scheduler_reduction_factor = reduction_factor
|
||||
scheduler = None
|
||||
try:
|
||||
import optuna
|
||||
import optuna as _
|
||||
|
||||
SearchAlgorithm = BlendSearch
|
||||
except ImportError:
|
||||
@@ -432,18 +452,26 @@ def run(
|
||||
"Please install ray[tune] or set use_ray=False"
|
||||
)
|
||||
_use_ray = True
|
||||
return tune.run(
|
||||
evaluation_function,
|
||||
metric=metric,
|
||||
mode=mode,
|
||||
search_alg=search_alg,
|
||||
scheduler=scheduler,
|
||||
time_budget_s=time_budget_s,
|
||||
verbose=verbose,
|
||||
local_dir=local_dir,
|
||||
num_samples=num_samples,
|
||||
resources_per_trial=resources_per_trial,
|
||||
)
|
||||
try:
|
||||
analysis = tune.run(
|
||||
evaluation_function,
|
||||
metric=metric,
|
||||
mode=mode,
|
||||
search_alg=search_alg,
|
||||
scheduler=scheduler,
|
||||
time_budget_s=time_budget_s,
|
||||
verbose=verbose,
|
||||
local_dir=local_dir,
|
||||
num_samples=num_samples,
|
||||
resources_per_trial=resources_per_trial,
|
||||
**ray_args,
|
||||
)
|
||||
return analysis
|
||||
finally:
|
||||
_use_ray = old_use_ray
|
||||
_verbose = old_verbose
|
||||
_running_trial = old_running_trial
|
||||
_training_iteration = old_training_iteration
|
||||
|
||||
# simple sequential run without using tune.run() from ray
|
||||
time_start = time.time()
|
||||
@@ -453,45 +481,56 @@ def run(
|
||||
from .trial_runner import SequentialTrialRunner
|
||||
|
||||
global _runner
|
||||
_runner = SequentialTrialRunner(
|
||||
search_alg=search_alg,
|
||||
scheduler=scheduler,
|
||||
metric=metric,
|
||||
mode=mode,
|
||||
)
|
||||
num_trials = 0
|
||||
if time_budget_s is None:
|
||||
time_budget_s = np.inf
|
||||
fail = 0
|
||||
ub = (len(evaluated_rewards) if evaluated_rewards else 0) + max_failure
|
||||
while (
|
||||
time.time() - time_start < time_budget_s
|
||||
and (num_samples < 0 or num_trials < num_samples)
|
||||
and fail < ub
|
||||
):
|
||||
trial_to_run = _runner.step()
|
||||
if trial_to_run:
|
||||
num_trials += 1
|
||||
if verbose:
|
||||
logger.info(f"trial {num_trials} config: {trial_to_run.config}")
|
||||
result = evaluation_function(trial_to_run.config)
|
||||
if result is not None:
|
||||
if isinstance(result, dict):
|
||||
if result:
|
||||
report(**result)
|
||||
else:
|
||||
# When the result returned is an empty dict, set the trial status to error
|
||||
trial_to_run.set_status(Trial.ERROR)
|
||||
else:
|
||||
report(_metric=result)
|
||||
_runner.stop_trial(trial_to_run)
|
||||
fail = 0
|
||||
else:
|
||||
fail += 1 # break with ub consecutive failures
|
||||
if fail == ub:
|
||||
logger.warning(
|
||||
f"fail to sample a trial for {max_failure} times in a row, stopping."
|
||||
old_runner = _runner
|
||||
try:
|
||||
_runner = SequentialTrialRunner(
|
||||
search_alg=search_alg,
|
||||
scheduler=scheduler,
|
||||
metric=metric,
|
||||
mode=mode,
|
||||
)
|
||||
if verbose > 0:
|
||||
logger.handlers.clear()
|
||||
return ExperimentAnalysis(_runner.get_trials(), metric=metric, mode=mode)
|
||||
num_trials = 0
|
||||
if time_budget_s is None:
|
||||
time_budget_s = np.inf
|
||||
fail = 0
|
||||
ub = (len(evaluated_rewards) if evaluated_rewards else 0) + max_failure
|
||||
while (
|
||||
time.time() - time_start < time_budget_s
|
||||
and (num_samples < 0 or num_trials < num_samples)
|
||||
and fail < ub
|
||||
):
|
||||
trial_to_run = _runner.step()
|
||||
if trial_to_run:
|
||||
num_trials += 1
|
||||
if verbose:
|
||||
logger.info(f"trial {num_trials} config: {trial_to_run.config}")
|
||||
result = evaluation_function(trial_to_run.config)
|
||||
if result is not None:
|
||||
if isinstance(result, dict):
|
||||
if result:
|
||||
report(**result)
|
||||
else:
|
||||
# When the result returned is an empty dict, set the trial status to error
|
||||
trial_to_run.set_status(Trial.ERROR)
|
||||
else:
|
||||
report(_metric=result)
|
||||
_runner.stop_trial(trial_to_run)
|
||||
fail = 0
|
||||
else:
|
||||
fail += 1 # break with ub consecutive failures
|
||||
if fail == ub:
|
||||
logger.warning(
|
||||
f"fail to sample a trial for {max_failure} times in a row, stopping."
|
||||
)
|
||||
analysis = ExperimentAnalysis(_runner.get_trials(), metric=metric, mode=mode)
|
||||
return analysis
|
||||
finally:
|
||||
# recover the global variables in case of nested run
|
||||
_use_ray = old_use_ray
|
||||
_verbose = old_verbose
|
||||
_running_trial = old_running_trial
|
||||
_training_iteration = old_training_iteration
|
||||
_runner = old_runner
|
||||
if not use_ray:
|
||||
logger.handlers = old_handlers
|
||||
logger.setLevel(old_level)
|
||||
|
||||
@@ -20,6 +20,37 @@ logger.addHandler(logging.FileHandler("logs/tune.log"))
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def test_nested_run():
|
||||
from flaml import AutoML, tune
|
||||
|
||||
data, labels = sklearn.datasets.load_breast_cancer(return_X_y=True)
|
||||
train_x, val_x, y_train, y_val = train_test_split(data, labels, test_size=0.25)
|
||||
space_pca = {
|
||||
"n_components": tune.uniform(0.5, 0.99),
|
||||
}
|
||||
|
||||
def pca_flaml(config):
|
||||
n_components = config["n_components"]
|
||||
from sklearn.decomposition import PCA
|
||||
|
||||
pca = PCA(n_components)
|
||||
X_train = pca.fit_transform(train_x)
|
||||
X_val = pca.transform(val_x)
|
||||
automl = AutoML()
|
||||
automl.fit(X_train, y_train, X_val=X_val, y_val=y_val, time_budget=1)
|
||||
return {"loss": automl.best_loss}
|
||||
|
||||
analysis = tune.run(
|
||||
pca_flaml,
|
||||
space_pca,
|
||||
metric="loss",
|
||||
mode="min",
|
||||
num_samples=5,
|
||||
local_dir="logs",
|
||||
)
|
||||
print(analysis.best_result)
|
||||
|
||||
|
||||
def train_breast_cancer(config: dict):
|
||||
# This is a simple training function to be passed into Tune
|
||||
# Load dataset
|
||||
@@ -182,7 +213,7 @@ def _test_xgboost(method="BlendSearch"):
|
||||
logger.info(f"Best model parameters: {best_trial.config}")
|
||||
|
||||
|
||||
def test_nested():
|
||||
def test_nested_space():
|
||||
from flaml import tune, CFO
|
||||
|
||||
search_space = {
|
||||
|
||||
Reference in New Issue
Block a user