diff --git a/flaml/model.py b/flaml/model.py index acd42ce6c..71896f11a 100644 --- a/flaml/model.py +++ b/flaml/model.py @@ -396,12 +396,14 @@ class TransformersEstimator(BaseEstimator): ) def fit(self, X_train: DataFrame, y_train: Series, budget=None, **kwargs): + import transformers + + transformers.logging.set_verbosity_error() + from transformers import EarlyStoppingCallback from transformers.trainer_utils import set_seed from transformers import AutoTokenizer - from transformers.data import DataCollatorWithPadding - import transformers from datasets import Dataset from .nlp.utils import ( get_num_labels,