Fix HPO evaluation bug (#645)

* fix eval_automl_metric bug causing val_loss inconsistency

* updating starting point search space to continuous

* shortening notebook
Author: Xueqing Liu
Date: 2022-07-28 23:08:42 -04:00
Committed by: GitHub
Parent: d649fefa6b
Commit: 5eb5d43d7f
17 changed files with 724 additions and 362 deletions

View File

@@ -685,7 +685,6 @@ class AutoML(BaseEstimator):
     fit_kwargs_by_estimator = {
         "transformer": {
             "output_dir": "test/data/output/",
-            "ckpt_per_epoch": 1,
             "fp16": False,
         }
     }
@@ -1626,7 +1625,6 @@ class AutoML(BaseEstimator):
     fit_kwargs_by_estimator = {
         "transformer": {
             "output_dir": "test/data/output/",
-            "ckpt_per_epoch": 1,
             "fp16": False,
         }
     }
@@ -2298,7 +2296,6 @@ class AutoML(BaseEstimator):
     fit_kwargs_by_estimator = {
         "transformer": {
             "output_dir": "test/data/output/",
-            "ckpt_per_epoch": 1,
             "fp16": False,
         }
     }
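Taken together, the three docstring examples above now reduce to the call below. A minimal runnable sketch, assuming a pandas text/label training set (the toy data, time budget, and estimator list here are illustrative, not taken from the FLAML tests); note that "ckpt_per_epoch" is no longer an accepted key.

import pandas as pd
from flaml import AutoML

# Toy sentiment data; any DataFrame with a text column works for
# seq-classification.
X_train = pd.DataFrame({"sentence": ["good movie", "terrible movie"] * 8})
y_train = pd.Series([1, 0] * 8)

automl = AutoML()
automl.fit(
    X_train=X_train,
    y_train=y_train,
    task="seq-classification",
    estimator_list=["transformer"],
    time_budget=60,  # illustrative budget in seconds
    fit_kwargs_by_estimator={
        "transformer": {
            "output_dir": "test/data/output/",
            "fp16": False,  # "ckpt_per_epoch" was removed by this commit
        }
    },
)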

View File

@@ -417,32 +417,18 @@ class TransformersEstimator(BaseEstimator):
     def search_space(cls, data_size, task, **params):
         search_space_dict = {
             "learning_rate": {
-                "domain": tune.choice(
-                    [1e-6, 2e-6, 4e-6, 8e-6, 16e-6, 32e-6, 64e-6, 128e-6]
-                ),
-                "init_value": 8e-6,
+                "domain": tune.loguniform(1e-6, 1e-4),
+                "init_value": 1e-5,
             },
             "num_train_epochs": {
-                "domain": tune.choice([1, 3, 5, 7, 9]),
+                "domain": tune.choice([1, 2, 3, 4, 5]),
                 "init_value": 3.0,  # to be consistent with roberta
             },
             "per_device_train_batch_size": {
-                "domain": tune.choice([4, 8, 16, 32]),
+                "domain": tune.choice([4, 8, 16, 32, 64]),
                 "init_value": 32,
             },
-            "warmup_ratio": {
-                "domain": tune.choice([0, 0.1, 0.2, 0.3]),
-                "init_value": 0.0,
-            },
-            "weight_decay": {
-                "domain": tune.choice([0, 0.1, 0.2, 0.3]),
-                "init_value": 0.0,
-            },
-            "adam_epsilon": {
-                "domain": tune.choice([1e-8, 1e-7, 1e-6]),
-                "init_value": 1e-6,
-            },
-            "seed": {"domain": tune.randint(40, 45), "init_value": 42},
+            "seed": {"domain": tune.randint(1, 40), "init_value": 20},
             "global_max_steps": {
                 "domain": sys.maxsize,
                 "init_value": sys.maxsize,
@@ -451,18 +437,6 @@ class TransformersEstimator(BaseEstimator):
         return search_space_dict

-    @property
-    def checkpoint_freq(self):
-        return (
-            int(
-                min(self._training_args.num_train_epochs, 1)
-                * len(self._X_train)
-                / self._training_args.per_device_train_batch_size
-                / self._training_args.ckpt_per_epoch
-            )
-            + 1
-        )
-
     @property
     def fp16(self):
         return self._kwargs.get("gpu_per_trial") and self._training_args.fp16
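For context on what was deleted: checkpoint_freq converted ckpt_per_epoch into a step interval. A worked instance of the removed formula, with illustrative numbers:

# 1000 training examples, per-device batch size 32, ckpt_per_epoch 1:
num_train_epochs, n_train, batch_size, ckpt_per_epoch = 3.0, 1000, 32, 1
freq = int(min(num_train_epochs, 1) * n_train / batch_size / ckpt_per_epoch) + 1
print(freq)  # 32 -> evaluate/save roughly once per pass over the data

After this commit, the cadence instead comes from the new eval_steps / save_steps / logging_steps fields (default 500) added further down in this diff.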
@@ -512,9 +486,6 @@ class TransformersEstimator(BaseEstimator):
             local_dir, self.params, self.trial_id
         )
-        self._training_args.eval_steps = (
-            self._training_args.logging_steps
-        ) = self._training_args.saving_steps = self.checkpoint_freq
         self._training_args.fp16 = self.fp16
         self._training_args.no_cuda = self.no_cuda
@@ -762,7 +733,7 @@ class TransformersEstimator(BaseEstimator):
         if trainer.ckpt_to_metric:
             best_ckpt, _ = min(
-                trainer.ckpt_to_metric.items(), key=lambda x: x[1]["eval_loss"]
+                trainer.ckpt_to_metric.items(), key=lambda x: x[1]["eval_automl_metric"]
             )
             best_ckpt_global_step = trainer.ckpt_to_global_step[best_ckpt]
             for each_ckpt in list(trainer.ckpt_to_metric):
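This one-line change is the bug fix named in the commit title: the best checkpoint is now chosen by FLAML's own objective rather than by HuggingFace's eval_loss, so the checkpoint that gets reported matches the metric being optimized. A self-contained sketch of the selection logic (checkpoint names and numbers are made up):

ckpt_to_metric = {
    "ckpt-100": {"eval_loss": 0.62, "eval_automl_metric": 0.31},
    "ckpt-200": {"eval_loss": 0.59, "eval_automl_metric": 0.35},
}
best_ckpt, _ = min(
    ckpt_to_metric.items(), key=lambda x: x[1]["eval_automl_metric"]
)
# The old key would pick ckpt-200 (lower eval_loss);
# the fixed key picks ckpt-100, the trial's actual optimum.
assert best_ckpt == "ckpt-100"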

View File

@@ -19,9 +19,9 @@ class TrainerForAuto(Seq2SeqTrainer):
             return super().predict(
                 test_dataset,
                 ignore_keys,
-                metric_key_prefix,
-                max_length,
-                num_beams,
+                metric_key_prefix=metric_key_prefix,
+                max_length=max_length,
+                num_beams=num_beams,
             )
         else:
             return super(Seq2SeqTrainer, self).predict(
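Passing these values positionally was fragile: if the upstream predict() signature gains or reorders a parameter, each value silently binds to the wrong name. A toy illustration (simplified signatures, not the transformers API itself):

def predict_v1(test_dataset, ignore_keys=None, metric_key_prefix="test",
               max_length=None, num_beams=None):
    return metric_key_prefix, max_length, num_beams

def predict_v2(test_dataset, ignore_keys=None, max_length=None,
               num_beams=None, metric_key_prefix="test"):  # reordered upstream
    return metric_key_prefix, max_length, num_beams

args = ("ds", None, "eval", 128, 4)
print(predict_v1(*args))  # ('eval', 128, 4) -- as intended
print(predict_v2(*args))  # (4, 'eval', 128) -- silently wrong
print(predict_v2("ds", None, metric_key_prefix="eval",
                 max_length=128, num_beams=4))  # ('eval', 128, 4) -- robust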

View File

@@ -28,7 +28,6 @@ class TrainingArgumentsForAuto(TrainingArguments):
         pad_to_max_length (bool, optional, defaults to "False"):
             whether to pad all samples to model maximum sentence length.
             If False, will pad the samples dynamically when batching to the maximum length in the batch.
-        ckpt_per_epoch (int, optional, defaults to 1): An integer, the number of checkpoints per epoch.
         per_device_eval_batch_size (int, optional, defaults to 1): An integer, the per gpu evaluation batch size.
         label_list (List[str], optional, defaults to None): A list of string, the string list of the label names.
             When the task is sequence labeling/token classification, there are two formats of the labels:
@@ -67,8 +66,6 @@ class TrainingArgumentsForAuto(TrainingArguments):
         },
     )
-    ckpt_per_epoch: int = field(default=1, metadata={"help": "checkpoint per epoch"})
-
     per_device_eval_batch_size: int = field(
         default=1,
         metadata={"help": "per gpu evaluation batch size"},
@@ -78,6 +75,18 @@ class TrainingArgumentsForAuto(TrainingArguments):
         default=None, metadata={"help": "The string list of the label names. "}
     )
+    eval_steps: int = field(
+        default=500, metadata={"help": "Run an evaluation every X steps."}
+    )
+    save_steps: int = field(
+        default=500, metadata={"help": "Save checkpoint every X updates steps."}
+    )
+    logging_steps: int = field(
+        default=500, metadata={"help": "Log every X updates steps."}
+    )

     @staticmethod
     def load_args_from_console():
         from dataclasses import fields
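These three fields replace the per-trial frequencies that checkpoint_freq used to compute. A minimal sketch of the dataclass-field pattern they follow (the same default-plus-help-metadata convention as HuggingFace TrainingArguments):

from dataclasses import dataclass, field

@dataclass
class MiniArgs:
    eval_steps: int = field(
        default=500, metadata={"help": "Run an evaluation every X steps."}
    )
    save_steps: int = field(
        default=500, metadata={"help": "Save checkpoint every X updates steps."}
    )

args = MiniArgs(eval_steps=100)  # override one, keep the other default
print(args.eval_steps, args.save_steps)  # 100 500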

File diff suppressed because one or more lines are too long

View File

@@ -171,7 +171,6 @@
" \"resources_per_trial\": {\"gpu\": 1, \"cpu\": 1},\n",
" \"num_samples\": 1,\n",
" \"time_budget\": 100000, # unlimited time budget\n",
" \"ckpt_per_epoch\": 5,\n",
" \"fp16\": True,\n",
" \"algo_mode\": \"grid\", # set the search algorithm to grid search\n",
" \"space_mode\": \"grid\", # set the search space to the recommended grid space\n",
@@ -326,7 +325,6 @@
" \"resources_per_trial\": {\"gpu\": 1, \"cpu\": 1},\n",
" \"num_samples\": -1,\n",
" \"time_budget\": time_budget,\n",
" \"ckpt_per_epoch\": 5,\n",
" \"fp16\": True,\n",
" \"algo_mode\": \"hpo\", # set the search algorithm mode to hpo\n",
" \"algo_name\": \"rs\",\n",
@@ -364,10 +362,10 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(pid=50964)\u001b[0m {'eval_loss': 0.5942569971084595, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10434782608695652}\n",
"\u001b[2m\u001b[36m(pid=50964)\u001b[0m {'eval_loss': 0.5942569971084595, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10434782608695652}\n",
"\u001b[2m\u001b[36m(pid=50948)\u001b[0m {'eval_loss': 0.649192214012146, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.2}\n",
"\u001b[2m\u001b[36m(pid=50948)\u001b[0m {'eval_loss': 0.649192214012146, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.2}\n"
"\u001B[2m\u001B[36m(pid=50964)\u001B[0m {'eval_loss': 0.5942569971084595, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10434782608695652}\n",
"\u001B[2m\u001B[36m(pid=50964)\u001B[0m {'eval_loss': 0.5942569971084595, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10434782608695652}\n",
"\u001B[2m\u001B[36m(pid=50948)\u001B[0m {'eval_loss': 0.649192214012146, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.2}\n",
"\u001B[2m\u001B[36m(pid=50948)\u001B[0m {'eval_loss': 0.649192214012146, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.2}\n"
]
},
{
@@ -485,12 +483,12 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(pid=54411)\u001b[0m {'eval_loss': 0.624100387096405, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001b[2m\u001b[36m(pid=54411)\u001b[0m {'eval_loss': 0.624100387096405, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001b[2m\u001b[36m(pid=54411)\u001b[0m {'eval_loss': 0.624100387096405, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001b[2m\u001b[36m(pid=54417)\u001b[0m {'eval_loss': 0.5938675999641418, 'eval_accuracy': 0.7156862745098039, 'eval_f1': 0.8258258258258258, 'epoch': 0.5}\n",
"\u001b[2m\u001b[36m(pid=54417)\u001b[0m {'eval_loss': 0.5938675999641418, 'eval_accuracy': 0.7156862745098039, 'eval_f1': 0.8258258258258258, 'epoch': 0.5}\n",
"\u001b[2m\u001b[36m(pid=54417)\u001b[0m {'eval_loss': 0.5938675999641418, 'eval_accuracy': 0.7156862745098039, 'eval_f1': 0.8258258258258258, 'epoch': 0.5}\n"
"\u001B[2m\u001B[36m(pid=54411)\u001B[0m {'eval_loss': 0.624100387096405, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=54411)\u001B[0m {'eval_loss': 0.624100387096405, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=54411)\u001B[0m {'eval_loss': 0.624100387096405, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=54417)\u001B[0m {'eval_loss': 0.5938675999641418, 'eval_accuracy': 0.7156862745098039, 'eval_f1': 0.8258258258258258, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=54417)\u001B[0m {'eval_loss': 0.5938675999641418, 'eval_accuracy': 0.7156862745098039, 'eval_f1': 0.8258258258258258, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=54417)\u001B[0m {'eval_loss': 0.5938675999641418, 'eval_accuracy': 0.7156862745098039, 'eval_f1': 0.8258258258258258, 'epoch': 0.5}\n"
]
},
{
@@ -590,18 +588,18 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(pid=57835)\u001b[0m {'eval_loss': 0.5822290778160095, 'eval_accuracy': 0.7058823529411765, 'eval_f1': 0.8181818181818181, 'epoch': 0.5043478260869565}\n",
"\u001b[2m\u001b[36m(pid=57835)\u001b[0m {'eval_loss': 0.5822290778160095, 'eval_accuracy': 0.7058823529411765, 'eval_f1': 0.8181818181818181, 'epoch': 0.5043478260869565}\n",
"\u001b[2m\u001b[36m(pid=57835)\u001b[0m {'eval_loss': 0.5822290778160095, 'eval_accuracy': 0.7058823529411765, 'eval_f1': 0.8181818181818181, 'epoch': 0.5043478260869565}\n",
"\u001b[2m\u001b[36m(pid=57835)\u001b[0m {'eval_loss': 0.5822290778160095, 'eval_accuracy': 0.7058823529411765, 'eval_f1': 0.8181818181818181, 'epoch': 0.5043478260869565}\n",
"\u001b[2m\u001b[36m(pid=57836)\u001b[0m {'eval_loss': 0.6087244749069214, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10344827586206896}\n",
"\u001b[2m\u001b[36m(pid=57836)\u001b[0m {'eval_loss': 0.6087244749069214, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10344827586206896}\n",
"\u001b[2m\u001b[36m(pid=57836)\u001b[0m {'eval_loss': 0.6087244749069214, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10344827586206896}\n",
"\u001b[2m\u001b[36m(pid=57836)\u001b[0m {'eval_loss': 0.6087244749069214, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10344827586206896}\n",
"\u001b[2m\u001b[36m(pid=57839)\u001b[0m {'eval_loss': 0.5486209392547607, 'eval_accuracy': 0.7034313725490197, 'eval_f1': 0.8141321044546851, 'epoch': 0.5}\n",
"\u001b[2m\u001b[36m(pid=57839)\u001b[0m {'eval_loss': 0.5486209392547607, 'eval_accuracy': 0.7034313725490197, 'eval_f1': 0.8141321044546851, 'epoch': 0.5}\n",
"\u001b[2m\u001b[36m(pid=57839)\u001b[0m {'eval_loss': 0.5486209392547607, 'eval_accuracy': 0.7034313725490197, 'eval_f1': 0.8141321044546851, 'epoch': 0.5}\n",
"\u001b[2m\u001b[36m(pid=57839)\u001b[0m {'eval_loss': 0.5486209392547607, 'eval_accuracy': 0.7034313725490197, 'eval_f1': 0.8141321044546851, 'epoch': 0.5}\n"
"\u001B[2m\u001B[36m(pid=57835)\u001B[0m {'eval_loss': 0.5822290778160095, 'eval_accuracy': 0.7058823529411765, 'eval_f1': 0.8181818181818181, 'epoch': 0.5043478260869565}\n",
"\u001B[2m\u001B[36m(pid=57835)\u001B[0m {'eval_loss': 0.5822290778160095, 'eval_accuracy': 0.7058823529411765, 'eval_f1': 0.8181818181818181, 'epoch': 0.5043478260869565}\n",
"\u001B[2m\u001B[36m(pid=57835)\u001B[0m {'eval_loss': 0.5822290778160095, 'eval_accuracy': 0.7058823529411765, 'eval_f1': 0.8181818181818181, 'epoch': 0.5043478260869565}\n",
"\u001B[2m\u001B[36m(pid=57835)\u001B[0m {'eval_loss': 0.5822290778160095, 'eval_accuracy': 0.7058823529411765, 'eval_f1': 0.8181818181818181, 'epoch': 0.5043478260869565}\n",
"\u001B[2m\u001B[36m(pid=57836)\u001B[0m {'eval_loss': 0.6087244749069214, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10344827586206896}\n",
"\u001B[2m\u001B[36m(pid=57836)\u001B[0m {'eval_loss': 0.6087244749069214, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10344827586206896}\n",
"\u001B[2m\u001B[36m(pid=57836)\u001B[0m {'eval_loss': 0.6087244749069214, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10344827586206896}\n",
"\u001B[2m\u001B[36m(pid=57836)\u001B[0m {'eval_loss': 0.6087244749069214, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10344827586206896}\n",
"\u001B[2m\u001B[36m(pid=57839)\u001B[0m {'eval_loss': 0.5486209392547607, 'eval_accuracy': 0.7034313725490197, 'eval_f1': 0.8141321044546851, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=57839)\u001B[0m {'eval_loss': 0.5486209392547607, 'eval_accuracy': 0.7034313725490197, 'eval_f1': 0.8141321044546851, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=57839)\u001B[0m {'eval_loss': 0.5486209392547607, 'eval_accuracy': 0.7034313725490197, 'eval_f1': 0.8141321044546851, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=57839)\u001B[0m {'eval_loss': 0.5486209392547607, 'eval_accuracy': 0.7034313725490197, 'eval_f1': 0.8141321044546851, 'epoch': 0.5}\n"
]
},
{
@@ -701,21 +699,21 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(pid=61251)\u001b[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001b[2m\u001b[36m(pid=61251)\u001b[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001b[2m\u001b[36m(pid=61251)\u001b[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001b[2m\u001b[36m(pid=61251)\u001b[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001b[2m\u001b[36m(pid=61251)\u001b[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001b[2m\u001b[36m(pid=61255)\u001b[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
"\u001b[2m\u001b[36m(pid=61255)\u001b[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
"\u001b[2m\u001b[36m(pid=61255)\u001b[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
"\u001b[2m\u001b[36m(pid=61255)\u001b[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
"\u001b[2m\u001b[36m(pid=61255)\u001b[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
"\u001b[2m\u001b[36m(pid=61236)\u001b[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n",
"\u001b[2m\u001b[36m(pid=61236)\u001b[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n",
"\u001b[2m\u001b[36m(pid=61236)\u001b[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n",
"\u001b[2m\u001b[36m(pid=61236)\u001b[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n",
"\u001b[2m\u001b[36m(pid=61236)\u001b[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n"
"\u001B[2m\u001B[36m(pid=61251)\u001B[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=61251)\u001B[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=61251)\u001B[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=61251)\u001B[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=61251)\u001B[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=61255)\u001B[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
"\u001B[2m\u001B[36m(pid=61255)\u001B[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
"\u001B[2m\u001B[36m(pid=61255)\u001B[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
"\u001B[2m\u001B[36m(pid=61255)\u001B[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
"\u001B[2m\u001B[36m(pid=61255)\u001B[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
"\u001B[2m\u001B[36m(pid=61236)\u001B[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n",
"\u001B[2m\u001B[36m(pid=61236)\u001B[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n",
"\u001B[2m\u001B[36m(pid=61236)\u001B[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n",
"\u001B[2m\u001B[36m(pid=61236)\u001B[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n",
"\u001B[2m\u001B[36m(pid=61236)\u001B[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n"
]
},
{
@@ -806,4 +804,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 1
-}
+}

View File

@@ -24,7 +24,6 @@ def test_custom_hp_nlp():
     automl_settings["fit_kwargs_by_estimator"] = {
         "transformer": {
             "output_dir": "test/data/output/",
-            "ckpt_per_epoch": 1,
             "fp16": False,
         }
     }

View File

@@ -1,5 +1,5 @@
{"class": "transformer_ms",
"hyperparameters": {"learning_rate": 1e-5, "num_train_epochs": 1.0, "per_device_train_batch_size": 8,
"warmup_ratio": 0.0, "weight_decay": 0.0, "adam_epsilon": 1e-6, "seed": 44, "global_max_steps": 101,
"seed": 44, "global_max_steps": 101,
"model_path": "google/electra-base-discriminator"}
}

View File

@@ -1,5 +1,5 @@
{"class": "transformer_ms",
"hyperparameters": {"learning_rate": 1e-5, "num_train_epochs": 1.0, "per_device_train_batch_size": 8,
"warmup_ratio": 0.0, "weight_decay": 0.0, "adam_epsilon": 1e-6, "seed": 43, "global_max_steps": 100,
"seed": 43, "global_max_steps": 100,
"model_path": "google/electra-base-discriminator"}
}

View File

@@ -1,5 +1,5 @@
{"class": "transformer_ms",
"hyperparameters": {"learning_rate": 1e-5, "num_train_epochs": 1.0, "per_device_train_batch_size": 8,
"warmup_ratio": 0.0, "weight_decay": 0.0, "adam_epsilon": 1e-6, "seed": 41, "global_max_steps": 102,
"seed": 41, "global_max_steps": 102,
"model_path": "google/electra-base-discriminator" }
}

View File

@@ -1,5 +1,5 @@
{"class": "transformer_ms",
"hyperparameters": {"learning_rate": 1e-5, "num_train_epochs": 1.0, "per_device_train_batch_size": 8,
"warmup_ratio": 0.0, "weight_decay": 0.0, "adam_epsilon": 1e-6, "seed": 42, "global_max_steps": 103,
"seed": 42, "global_max_steps": 103,
"model_path": "google/electra-base-discriminator" }
}

View File

@@ -1,5 +1,5 @@
{"class": "transformer_ms",
"hyperparameters": {"learning_rate": 1e-5, "num_train_epochs": 1.0, "per_device_train_batch_size": 8,
"warmup_ratio": 0.0, "weight_decay": 0.0, "adam_epsilon": 1e-6, "seed": 40, "global_max_steps": 105,
"seed": 40, "global_max_steps": 105,
"model_path": "google/electra-base-discriminator"}
}
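The five starting-point files above are all trimmed the same way: warmup_ratio, weight_decay, and adam_epsilon are dropped because they left the search space. A hedged sketch of consuming one such file (the file name below is hypothetical):

import json

with open("starting_point.json") as fin:  # hypothetical file name
    point = json.load(fin)
hp = point["hyperparameters"]
print(hp["learning_rate"], hp["seed"], hp["global_max_steps"])
# Keys such as "warmup_ratio" are gone and must not be assumed present.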

View File

@@ -27,6 +27,21 @@ def test_hf_data():
     except requests.exceptions.HTTPError:
         return

+    import json
+
+    with open("seqclass.log", "r") as fin:
+        for line in fin:
+            each_log = json.loads(line.strip("\n"))
+            if "validation_loss" in each_log:
+                val_loss = each_log["validation_loss"]
+                min_inter_result = min(
+                    each_dict.get("eval_automl_metric", sys.maxsize)
+                    for each_dict in each_log["logged_metric"]["intermediate_results"]
+                )
+                if min_inter_result != sys.maxsize:
+                    assert val_loss == min_inter_result
+
     automl = AutoML()
     automl_settings.pop("max_iter", None)
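The added block is the regression test for the metric fix: the validation_loss that FLAML reports for a trial must equal the best intermediate eval_automl_metric logged during that trial. A self-contained sketch of the same check on a synthetic log entry:

import sys

each_log = {
    "validation_loss": 0.31,
    "logged_metric": {
        "intermediate_results": [
            {"eval_automl_metric": 0.35},
            {"eval_automl_metric": 0.31},
            {},  # entries without the key fall back to sys.maxsize
        ]
    },
}
min_inter_result = min(
    d.get("eval_automl_metric", sys.maxsize)
    for d in each_log["logged_metric"]["intermediate_results"]
)
assert each_log["validation_loss"] == min_inter_result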

View File

@@ -6,7 +6,7 @@ from utils import get_toy_data_summarization, get_automl_settings
 @pytest.mark.skipif(
     sys.platform == "darwin" or sys.version < "3.7",
-    reason="do not run on mac os or py < 3.7",
+    reason="do not run on mac os or py3.6",
 )
 def test_summarization():
     # TODO: manual test for how effective postprocess_seq2seq_prediction_label is

View File

@@ -1518,7 +1518,6 @@ def get_automl_settings(estimator_name="transformer"):
     automl_settings["fit_kwargs_by_estimator"] = {
         estimator_name: {
             "output_dir": "test/data/output/",
-            "ckpt_per_epoch": 1,
             "fp16": False,
         }
     }
@@ -1527,7 +1526,6 @@ def get_automl_settings(estimator_name="transformer"):
         estimator_name: {
             "model_path": "google/electra-small-discriminator",
             "output_dir": "test/data/output/",
-            "ckpt_per_epoch": 1,
             "fp16": False,
         }
     }

View File

@@ -85,7 +85,6 @@ def _test_hf_data():
"transformer": {
"model_path": "facebook/muppet-roberta-base",
"output_dir": "test/data/output/",
"ckpt_per_epoch": 5,
"fp16": True,
}
}

View File

@@ -86,7 +86,6 @@ automl_settings["fit_kwargs_by_estimator"] = { # setting the huggingface argume
"transformer": {
"model_path": "google/electra-small-discriminator", # if model_path is not set, the default model is facebook/muppet-roberta-base: https://huggingface.co/facebook/muppet-roberta-base
"output_dir": "data/output/", # setting the output directory
"ckpt_per_epoch": 5, # setting the number of checkpoints per epoch
"fp16": False,
} # setting whether to use FP16
}
@@ -138,7 +137,6 @@ automl_settings["fit_kwargs_by_estimator"] = { # setting the huggingface ar
"transformer": {
"model_path": "t5-small", # if model_path is not set, the default model is t5-small: https://huggingface.co/t5-small
"output_dir": "data/output/", # setting the output directory
"ckpt_per_epoch": 5, # setting the number of checkpoints per epoch
"fp16": False,
} # setting whether to use FP16
}
@@ -366,6 +364,6 @@ For tasks that are not currently supported, use `flaml.tune` for [customized tun
 ### Link to Jupyter notebook

-To run these examples in our Jupyter notebook, please go to:
+To run more examples, especially examples using Ray Tune, please go to:

 [Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/automl_nlp.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/automl_nlp.ipynb)