backup & recover global vars for nested tune.run (#584)

* backup & recover global vars for nested tune.run

* ensure global vars are recovered before returning
Chi Wang authored on 2022-06-14 11:03:54 -07:00, committed by GitHub
parent 65fa72d583, commit 1111d6d43a
2 changed files with 155 additions and 85 deletions
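
The change follows a simple backup-and-restore discipline: tune.run() snapshots the module-level globals it is about to mutate (_use_ray, _verbose, _running_trial, _training_iteration, _runner, plus the logger handlers) and restores them in a finally block, so a tune.run() started inside another tune.run() (for example, tuning a preprocessing step around an AutoML.fit() call, as in the new test below) leaves the outer run's state intact. A simplified sketch of the pattern follows; the bodies are placeholders and the real implementation is in the diff below:

# --- illustrative sketch, not part of the commit ---
import logging

_verbose = 0      # module-level state shared by nested calls
_runner = None

logger = logging.getLogger("sketch")


def run(verbose=0):
    """Snapshot the globals, mutate them, and always restore them on exit."""
    global _verbose, _runner
    old_verbose, old_runner = _verbose, _runner   # backup
    old_handlers = list(logger.handlers)
    try:
        _verbose = verbose
        _runner = object()    # stand-in for the real trial runner
        # ... run trials here; the evaluation function may call run() again ...
    finally:
        _verbose, _runner = old_verbose, old_runner   # recover, even on error
        logger.handlers = old_handlers
# --- end of sketch ---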

@@ -1,5 +1,5 @@
# !
# * Copyright (c) Microsoft Corporation. All rights reserved.
# * Copyright (c) FLAML authors. All rights reserved.
# * Licensed under the MIT License. See LICENSE file in the
# * project root for license information.
from typing import Optional, Union, List, Callable, Tuple
@@ -23,8 +23,6 @@ from .result import DEFAULT_METRIC
import logging
logger = logging.getLogger(__name__)
_use_ray = True
_runner = None
_verbose = 0
@@ -91,32 +89,35 @@ def report(_metric=None, **kwargs):
global _running_trial
global _training_iteration
if _use_ray:
from ray import tune
try:
from ray import tune
return tune.report(_metric, **kwargs)
return tune.report(_metric, **kwargs)
except ImportError:
# calling tune.report() outside tune.run()
return
result = kwargs
if _metric:
result[DEFAULT_METRIC] = _metric
trial = getattr(_runner, "running_trial", None)
if not trial:
return None
if _running_trial == trial:
_training_iteration += 1
else:
result = kwargs
if _metric:
result[DEFAULT_METRIC] = _metric
trial = getattr(_runner, "running_trial", None)
if not trial:
return None
if _running_trial == trial:
_training_iteration += 1
else:
_training_iteration = 0
_running_trial = trial
result["training_iteration"] = _training_iteration
result["config"] = trial.config
if INCUMBENT_RESULT in result["config"]:
del result["config"][INCUMBENT_RESULT]
for key, value in trial.config.items():
result["config/" + key] = value
_runner.process_trial_result(trial, result)
if _verbose > 2:
logger.info(f"result: {result}")
if trial.is_finished():
raise StopIteration
_training_iteration = 0
_running_trial = trial
result["training_iteration"] = _training_iteration
result["config"] = trial.config
if INCUMBENT_RESULT in result["config"]:
del result["config"][INCUMBENT_RESULT]
for key, value in trial.config.items():
result["config/" + key] = value
_runner.process_trial_result(trial, result)
if _verbose > 2:
logger.info(f"result: {result}")
if trial.is_finished():
raise StopIteration
def run(
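
In the report() hunk above, the ray import is now wrapped in try/except ImportError, so calling tune.report() outside tune.run() when ray is not installed returns quietly instead of raising. A minimal sketch of that guarded-import idiom, using an illustrative function name rather than FLAML's:

# --- illustrative sketch, not part of the commit ---
def safe_report(metric, **kwargs):
    """Forward to ray.tune.report when ray is available; otherwise no-op."""
    try:
        from ray import tune  # fails when ray[tune] is not installed
    except ImportError:
        return None           # no ray available: silently skip reporting
    return tune.report(metric, **kwargs)
# --- end of sketch ---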
@@ -146,6 +147,7 @@ def run(
max_failure: Optional[int] = 100,
use_ray: Optional[bool] = False,
use_incumbent_result_in_evaluation: Optional[bool] = None,
**ray_args,
):
"""The trigger for HPO.
@@ -296,15 +298,33 @@ def run(
max_failure: int | the maximal consecutive number of failures to sample
a trial before the tuning is terminated.
use_ray: A boolean of whether to use ray as the backend.
**ray_args: keyword arguments to pass to ray.tune.run().
Only valid when use_ray=True.
"""
global _use_ray
global _verbose
global _running_trial
global _training_iteration
old_use_ray = _use_ray
old_verbose = _verbose
old_running_trial = _running_trial
old_training_iteration = _training_iteration
if not use_ray:
_verbose = verbose
old_handlers = logger.handlers
old_level = logger.getEffectiveLevel()
logger.handlers = []
if (
old_handlers
and isinstance(old_handlers[0], logging.StreamHandler)
and not isinstance(old_handlers[0], logging.FileHandler)
):
# Add the console handler.
logger.addHandler(old_handlers[0])
if verbose > 0:
import os
if local_dir:
import os
os.makedirs(local_dir, exist_ok=True)
logger.addHandler(
logging.FileHandler(
@@ -314,7 +334,7 @@ def run(
+ ".log"
)
)
elif not logger.handlers:
elif not logger.hasHandlers():
# Add the console handler.
_ch = logging.StreamHandler()
logger_formatter = logging.Formatter(
@@ -347,7 +367,7 @@ def run(
flaml_scheduler_reduction_factor = reduction_factor
scheduler = None
try:
import optuna
import optuna as _
SearchAlgorithm = BlendSearch
except ImportError:
@@ -432,18 +452,26 @@ def run(
"Please install ray[tune] or set use_ray=False"
)
_use_ray = True
return tune.run(
evaluation_function,
metric=metric,
mode=mode,
search_alg=search_alg,
scheduler=scheduler,
time_budget_s=time_budget_s,
verbose=verbose,
local_dir=local_dir,
num_samples=num_samples,
resources_per_trial=resources_per_trial,
)
try:
analysis = tune.run(
evaluation_function,
metric=metric,
mode=mode,
search_alg=search_alg,
scheduler=scheduler,
time_budget_s=time_budget_s,
verbose=verbose,
local_dir=local_dir,
num_samples=num_samples,
resources_per_trial=resources_per_trial,
**ray_args,
)
return analysis
finally:
_use_ray = old_use_ray
_verbose = old_verbose
_running_trial = old_running_trial
_training_iteration = old_training_iteration
# simple sequential run without using tune.run() from ray
time_start = time.time()
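
The hunk above forwards the new **ray_args keyword arguments to ray.tune.run and, thanks to the finally block, restores the saved globals even when that call raises. A hedged usage sketch of the forwarding, assuming ray[tune] is installed; the toy objective and the name argument (a regular ray.tune.run keyword) are illustrative and not taken from the commit:

# --- illustrative sketch, not part of the commit ---
from flaml import tune


def evaluate(config):
    # Toy objective: minimize (x - 3)^2.
    return {"loss": (config["x"] - 3) ** 2}


analysis = tune.run(
    evaluate,
    config={"x": tune.uniform(-10, 10)},
    metric="loss",
    mode="min",
    num_samples=20,
    use_ray=True,           # dispatch to ray.tune.run
    name="flaml_ray_demo",  # collected into **ray_args and forwarded
)
print(analysis.best_config)
# --- end of sketch ---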
@@ -453,45 +481,56 @@ def run(
from .trial_runner import SequentialTrialRunner
global _runner
_runner = SequentialTrialRunner(
search_alg=search_alg,
scheduler=scheduler,
metric=metric,
mode=mode,
)
num_trials = 0
if time_budget_s is None:
time_budget_s = np.inf
fail = 0
ub = (len(evaluated_rewards) if evaluated_rewards else 0) + max_failure
while (
time.time() - time_start < time_budget_s
and (num_samples < 0 or num_trials < num_samples)
and fail < ub
):
trial_to_run = _runner.step()
if trial_to_run:
num_trials += 1
if verbose:
logger.info(f"trial {num_trials} config: {trial_to_run.config}")
result = evaluation_function(trial_to_run.config)
if result is not None:
if isinstance(result, dict):
if result:
report(**result)
else:
# When the result returned is an empty dict, set the trial status to error
trial_to_run.set_status(Trial.ERROR)
else:
report(_metric=result)
_runner.stop_trial(trial_to_run)
fail = 0
else:
fail += 1 # break with ub consecutive failures
if fail == ub:
logger.warning(
f"fail to sample a trial for {max_failure} times in a row, stopping."
old_runner = _runner
try:
_runner = SequentialTrialRunner(
search_alg=search_alg,
scheduler=scheduler,
metric=metric,
mode=mode,
)
if verbose > 0:
logger.handlers.clear()
return ExperimentAnalysis(_runner.get_trials(), metric=metric, mode=mode)
num_trials = 0
if time_budget_s is None:
time_budget_s = np.inf
fail = 0
ub = (len(evaluated_rewards) if evaluated_rewards else 0) + max_failure
while (
time.time() - time_start < time_budget_s
and (num_samples < 0 or num_trials < num_samples)
and fail < ub
):
trial_to_run = _runner.step()
if trial_to_run:
num_trials += 1
if verbose:
logger.info(f"trial {num_trials} config: {trial_to_run.config}")
result = evaluation_function(trial_to_run.config)
if result is not None:
if isinstance(result, dict):
if result:
report(**result)
else:
# When the result returned is an empty dict, set the trial status to error
trial_to_run.set_status(Trial.ERROR)
else:
report(_metric=result)
_runner.stop_trial(trial_to_run)
fail = 0
else:
fail += 1 # break with ub consecutive failures
if fail == ub:
logger.warning(
f"fail to sample a trial for {max_failure} times in a row, stopping."
)
analysis = ExperimentAnalysis(_runner.get_trials(), metric=metric, mode=mode)
return analysis
finally:
# recover the global variables in case of nested run
_use_ray = old_use_ray
_verbose = old_verbose
_running_trial = old_running_trial
_training_iteration = old_training_iteration
_runner = old_runner
if not use_ray:
logger.handlers = old_handlers
logger.setLevel(old_level)

@@ -20,6 +20,37 @@ logger.addHandler(logging.FileHandler("logs/tune.log"))
logger.setLevel(logging.INFO)
def test_nested_run():
from flaml import AutoML, tune
data, labels = sklearn.datasets.load_breast_cancer(return_X_y=True)
train_x, val_x, y_train, y_val = train_test_split(data, labels, test_size=0.25)
space_pca = {
"n_components": tune.uniform(0.5, 0.99),
}
def pca_flaml(config):
n_components = config["n_components"]
from sklearn.decomposition import PCA
pca = PCA(n_components)
X_train = pca.fit_transform(train_x)
X_val = pca.transform(val_x)
automl = AutoML()
automl.fit(X_train, y_train, X_val=X_val, y_val=y_val, time_budget=1)
return {"loss": automl.best_loss}
analysis = tune.run(
pca_flaml,
space_pca,
metric="loss",
mode="min",
num_samples=5,
local_dir="logs",
)
print(analysis.best_result)
def train_breast_cancer(config: dict):
# This is a simple training function to be passed into Tune
# Load dataset
@@ -182,7 +213,7 @@ def _test_xgboost(method="BlendSearch"):
logger.info(f"Best model parameters: {best_trial.config}")
def test_nested():
def test_nested_space():
from flaml import tune, CFO
search_space = {