Support pickling the whole AutoML instance, Sync Fabric till 0d4ab16f (#1481)

Li Jiang
2026-01-12 23:04:38 +08:00
committed by GitHub
parent bb213e7ebd
commit ced1d6f331
10 changed files with 452 additions and 38 deletions

View File

@@ -413,13 +413,47 @@ class AutoML(BaseEstimator):
"""
state = self.__dict__.copy()
state.pop("mlflow_integration", None)
# Keep mlflow_integration for post-load visualization (e.g., infos), but
# strip non-picklable runtime-only members (thread futures, clients).
mlflow_integration = state.get("mlflow_integration", None)
if mlflow_integration is not None:
import copy
mi = copy.copy(mlflow_integration)
# These are runtime-only and often contain locks / threads.
if hasattr(mi, "futures"):
mi.futures = {}
if hasattr(mi, "futures_log_model"):
mi.futures_log_model = {}
if hasattr(mi, "train_func"):
mi.train_func = None
if hasattr(mi, "mlflow_client"):
mi.mlflow_client = None
state["mlflow_integration"] = mi
# MLflow signature objects may hold references to Spark/pandas-on-Spark
# inputs and can indirectly capture SparkContext, which is not picklable.
state.pop("estimator_signature", None)
state.pop("pipeline_signature", None)
return state
def __setstate__(self, state):
self.__dict__.update(state)
# Ensure mlflow_integration runtime members exist post-unpickle.
mi = getattr(self, "mlflow_integration", None)
if mi is not None:
if not hasattr(mi, "futures") or mi.futures is None:
mi.futures = {}
if not hasattr(mi, "futures_log_model") or mi.futures_log_model is None:
mi.futures_log_model = {}
if not hasattr(mi, "train_func"):
mi.train_func = None
if not hasattr(mi, "mlflow_client") or mi.mlflow_client is None:
try:
import mlflow as _mlflow
mi.mlflow_client = _mlflow.tracking.MlflowClient()
except Exception:
mi.mlflow_client = None
def get_params(self, deep: bool = False) -> dict:
return self._settings.copy()
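The `__getstate__`/`__setstate__` pair above follows the usual pattern for objects that carry non-picklable runtime members: strip them when serializing, recreate them after loading. A minimal, self-contained sketch of that pattern (the `Tracker` class and its lock are illustrative, not FLAML code):

import pickle
import threading

class Tracker:
    def __init__(self):
        self.results = {}                  # ordinary picklable state
        self._lock = threading.Lock()      # runtime-only; locks cannot be pickled

    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop("_lock", None)           # strip the non-picklable member
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._lock = threading.Lock()      # recreate it post-unpickle

t = Tracker()
t.results["metric"] = 0.9
clone = pickle.loads(pickle.dumps(t))
assert clone.results == {"metric": 0.9} and clone._lock is not None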
@@ -1114,17 +1148,344 @@ class AutoML(BaseEstimator):
return self._state.data_size[0] if self._sample else None
def pickle(self, output_file_name):
"""Serialize the AutoML instance to a pickle file.
Notes:
When the trained estimator(s) are Spark-based, they may hold references
to SparkContext/SparkSession via Spark ML objects. Such objects are not
safely picklable and can cause pickling/broadcast errors.
This method externalizes Spark ML models into an adjacent artifact
directory and stores only lightweight metadata in the pickle.
"""
import os
import pickle
import re
def _safe_name(name: str) -> str:
return re.sub(r"[^A-Za-z0-9_.-]+", "_", name)
def _iter_trained_estimators():
trained = getattr(self, "_trained_estimator", None)
if trained is not None:
yield "_trained_estimator", trained
for est_name in getattr(self, "estimator_list", []) or []:
ss = getattr(self, "_search_states", {}).get(est_name)
te = ss and getattr(ss, "trained_estimator", None)
if te is not None:
yield f"_search_states.{est_name}.trained_estimator", te
def _scrub_pyspark_refs(root_obj):
"""Best-effort removal of pyspark objects prior to pickling.
SparkContext/SparkSession and Spark DataFrame objects are not picklable.
This function finds such objects within common containers and instance
attributes and replaces them with None, returning a restore mapping.
"""
try:
import pyspark
from pyspark.broadcast import Broadcast
from pyspark.sql import DataFrame as SparkDataFrame
from pyspark.sql import SparkSession
try:
import pyspark.pandas as ps
psDataFrameType = getattr(ps, "DataFrame", None)
psSeriesType = getattr(ps, "Series", None)
except Exception:
psDataFrameType = None
psSeriesType = None
bad_types = [
pyspark.SparkContext,
SparkSession,
SparkDataFrame,
Broadcast,
]
if psDataFrameType is not None:
bad_types.append(psDataFrameType)
if psSeriesType is not None:
bad_types.append(psSeriesType)
bad_types = tuple(t for t in bad_types if t is not None)
except Exception:
return {}
restore = {}
visited = set()
def _mark(parent, key, value, path):
restore[(id(parent), key)] = (parent, key, value)
try:
if isinstance(parent, dict):
parent[key] = None
elif isinstance(parent, list):
parent[key] = None
elif isinstance(parent, tuple):
# tuples are immutable; we can't modify in-place
pass
else:
setattr(parent, key, None)
except Exception:
# Best-effort.
pass
def _walk(obj, depth, parent=None, key=None, path="self"):
if obj is None:
return
oid = id(obj)
if oid in visited:
return
visited.add(oid)
if isinstance(obj, bad_types):
if parent is not None:
_mark(parent, key, obj, path)
return
if depth <= 0:
return
if isinstance(obj, dict):
for k, v in list(obj.items()):
_walk(v, depth - 1, parent=obj, key=k, path=f"{path}[{k!r}]")
return
if isinstance(obj, list):
for i, v in enumerate(list(obj)):
_walk(v, depth - 1, parent=obj, key=i, path=f"{path}[{i}]")
return
if isinstance(obj, tuple):
# Can't scrub inside tuples safely, but we still inspect them for diagnostics.
for i, v in enumerate(obj):
_walk(v, depth - 1, parent=None, key=None, path=f"{path}[{i}]")
return
if isinstance(obj, set):
for v in list(obj):
_walk(v, depth - 1, parent=None, key=None, path=f"{path}{{...}}")
return
d = getattr(obj, "__dict__", None)
if isinstance(d, dict):
for attr, v in list(d.items()):
_walk(v, depth - 1, parent=obj, key=attr, path=f"{path}.{attr}")
_walk(root_obj, depth=6)
return restore
# Temporarily remove non-picklable pieces (e.g., SparkContext-backed objects)
# and externalize Spark models.
estimator_to_training_function = {}
spark_restore = []
artifact_dir = None
state_restore = {}
automl_restore = {}
scrub_restore = {}
try:
# Signatures are only used for MLflow logging; they are not required
# for inference and can capture SparkContext via pyspark objects.
for attr in ("estimator_signature", "pipeline_signature"):
if hasattr(self, attr):
automl_restore[attr] = getattr(self, attr)
setattr(self, attr, None)
for estimator in self.estimator_list:
search_state = self._search_states[estimator]
if hasattr(search_state, "training_function"):
estimator_to_training_function[estimator] = search_state.training_function
del search_state.training_function
# AutoMLState may keep Spark / pandas-on-Spark dataframes which are not picklable.
# They are not required for inference, so strip them for serialization.
state = getattr(self, "_state", None)
if state is not None:
for attr in (
"X_train",
"y_train",
"X_train_all",
"y_train_all",
"X_val",
"y_val",
"weight_val",
"groups_val",
"sample_weight_all",
"groups",
"groups_all",
"kf",
):
if hasattr(state, attr):
state_restore[attr] = getattr(state, attr)
setattr(state, attr, None)
for key, est in _iter_trained_estimators():
if getattr(est, "estimator_baseclass", None) != "spark":
continue
# Drop training data reference (Spark DataFrame / pandas-on-Spark).
old_df_train = getattr(est, "df_train", None)
old_model = getattr(est, "_model", None)
model_meta = None
if old_model is not None:
if artifact_dir is None:
artifact_dir = output_file_name + ".flaml_artifacts"
os.makedirs(artifact_dir, exist_ok=True)
# Store the relative dirname so the pickle and the artifact folder can be moved together.
self._flaml_pickle_artifacts_dirname = os.path.basename(artifact_dir)
model_dir = os.path.join(artifact_dir, _safe_name(key))
# Spark ML models are saved as directories.
try:
writer = old_model.write()
writer.overwrite().save(model_dir)
except Exception as e:
raise RuntimeError(
"Failed to externalize Spark model for pickling. "
"Please ensure the Spark ML model supports write().overwrite().save(path)."
) from e
model_meta = {
"path": os.path.relpath(model_dir, os.path.dirname(output_file_name) or "."),
"class": old_model.__class__.__module__ + "." + old_model.__class__.__name__,
}
# Replace in-memory Spark model with metadata only.
est._model = None
est._flaml_spark_model_meta = model_meta
est.df_train = None
spark_restore.append((est, old_model, old_df_train, model_meta))
with open(output_file_name, "wb") as f:
try:
pickle.dump(self, f, pickle.HIGHEST_PROTOCOL)
except Exception:
# Some pyspark objects can still be captured indirectly.
scrub_restore = _scrub_pyspark_refs(self)
if scrub_restore:
f.seek(0)
f.truncate()
pickle.dump(self, f, pickle.HIGHEST_PROTOCOL)
else:
raise
finally:
# Restore training_function and Spark models so current object remains usable.
for estimator, tf in estimator_to_training_function.items():
self._search_states[estimator].training_function = tf
for attr, val in automl_restore.items():
setattr(self, attr, val)
state = getattr(self, "_state", None)
if state is not None and state_restore:
for attr, val in state_restore.items():
setattr(state, attr, val)
for est, old_model, old_df_train, model_meta in spark_restore:
est._model = old_model
est.df_train = old_df_train
if model_meta is not None and hasattr(est, "_flaml_spark_model_meta"):
delattr(est, "_flaml_spark_model_meta")
if scrub_restore:
for _, (parent, key, value) in scrub_restore.items():
try:
if isinstance(parent, dict):
parent[key] = value
elif isinstance(parent, list):
parent[key] = value
else:
setattr(parent, key, value)
except Exception:
pass
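The externalization above relies only on the standard Spark ML persistence API: `model.write().overwrite().save(path)` writes the model as a directory, and the model class' `load(path)` reads it back. A minimal sketch, assuming a local SparkSession; the model type and target path are illustrative:

from pyspark.ml.classification import LogisticRegression, LogisticRegressionModel
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(1.0, Vectors.dense([0.0])), (0.0, Vectors.dense([1.0]))], ["label", "features"]
)
model = LogisticRegression(maxIter=5).fit(df)
model.write().overwrite().save("/tmp/flaml_spark_model_sketch")           # saved as a directory
restored = LogisticRegressionModel.load("/tmp/flaml_spark_model_sketch")  # Class.load(path)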
@classmethod
def load_pickle(cls, input_file_name: str, load_spark_models: bool = True):
"""Load an AutoML instance saved by :meth:`pickle`.
Args:
input_file_name: Path to the pickle file created by :meth:`pickle`.
load_spark_models: Whether to load externalized Spark ML models back
into the estimator objects. If False, Spark estimators will remain
without their underlying Spark model and cannot be used for prediction.
Returns:
The deserialized AutoML instance.
"""
import importlib
import os
import pickle
with open(input_file_name, "rb") as f:
automl = pickle.load(f)
# Recreate per-estimator training_function if it was removed for pickling.
try:
for est_name, ss in getattr(automl, "_search_states", {}).items():
if not hasattr(ss, "training_function"):
ss.training_function = partial(
AutoMLState._compute_with_config_base,
state=automl._state,
estimator=est_name,
)
except Exception:
# Best-effort; training_function is only needed for re-searching.
pass
if not load_spark_models:
return automl
base_dir = os.path.dirname(input_file_name) or "."
def _iter_trained_estimators_loaded():
trained = getattr(automl, "_trained_estimator", None)
if trained is not None:
yield trained
for ss in getattr(automl, "_search_states", {}).values():
te = ss and getattr(ss, "trained_estimator", None)
if te is not None:
yield te
for est in _iter_trained_estimators_loaded():
meta = getattr(est, "_flaml_spark_model_meta", None)
if not meta:
continue
model_path = meta.get("path")
model_class = meta.get("class")
if not model_path or not model_class:
continue
abs_model_path = os.path.join(base_dir, model_path)
module_name, _, class_name = model_class.rpartition(".")
try:
module = importlib.import_module(module_name)
model_cls = getattr(module, class_name)
except Exception as e:
raise RuntimeError(f"Failed to import Spark model class '{model_class}'") from e
# Most Spark ML models support either Class.load(path) or Class.read().load(path).
if hasattr(model_cls, "load"):
est._model = model_cls.load(abs_model_path)
elif hasattr(model_cls, "read"):
est._model = model_cls.read().load(abs_model_path)
else:
try:
from pyspark.ml.pipeline import PipelineModel
loaded_model = PipelineModel.load(abs_model_path)
if not isinstance(loaded_model, model_cls):
raise RuntimeError(
f"Loaded model type '{type(loaded_model).__name__}' does not match expected type '{model_class}'."
)
est._model = loaded_model
except Exception as e:
raise RuntimeError(
f"Spark model class '{model_class}' does not support load/read(). "
"Unable to restore Spark model from artifacts."
) from e
return automl
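A usage sketch of the round trip these two methods provide, assuming an already fitted `automl` instance; the file name and `X_test` are illustrative. The pickle file and, when Spark models were externalized, the adjacent `<name>.flaml_artifacts/` directory must be kept (or moved) together:

from flaml import AutoML

automl.pickle("automl.pkl")                 # writes automl.pkl (+ automl.pkl.flaml_artifacts/ for Spark models)
loaded = AutoML.load_pickle("automl.pkl")   # classmethod; restores externalized Spark models
assert loaded.best_estimator == automl.best_estimator
predictions = loaded.predict(X_test)

# Skip restoring Spark models when only metadata such as best_loss is needed:
meta_only = AutoML.load_pickle("automl.pkl", load_spark_models=False)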
@property
def trainable(self) -> Callable[[dict], float | None]:

View File

@@ -135,6 +135,7 @@ class BaseEstimator(sklearn.base.ClassifierMixin, sklearn.base.BaseEstimator):
self._task = task if isinstance(task, Task) else task_factory(task, None, None)
self.params = self.config2params(config)
self.estimator_class = self._model = None
self.estimator_baseclass = "sklearn"
if "_estimator_type" in self.params:
self._estimator_type = self.params.pop("_estimator_type")
else:
@@ -439,6 +440,7 @@ class SparkEstimator(BaseEstimator):
raise SPARK_ERROR
super().__init__(task, **config)
self.df_train = None
self.estimator_baseclass = "spark"
def _preprocess(
self,
@@ -974,7 +976,7 @@ class TransformersEstimator(BaseEstimator):
from .nlp.huggingface.utils import tokenize_text
from .nlp.utils import is_a_list_of_str
is_str = str(X.dtypes.iloc[0]) in ("string", "str")
is_list_of_str = is_a_list_of_str(X[list(X.keys())[0]].to_list()[0])
if is_str or is_list_of_str:

View File

@@ -5,7 +5,7 @@ from typing import List, Optional
from flaml.automl.task.task import NLG_TASKS
try:
from transformers import Seq2SeqTrainingArguments as TrainingArguments
except ImportError:
TrainingArguments = object

View File

@@ -396,7 +396,7 @@ def load_model(checkpoint_path, task, num_labels=None):
if task in (SEQCLASSIFICATION, SEQREGRESSION):
return AutoModelForSequenceClassification.from_pretrained(
checkpoint_path, config=model_config, ignore_mismatched_sizes=True, trust_remote_code=True
)
elif task == TOKENCLASSIFICATION:
return AutoModelForTokenClassification.from_pretrained(checkpoint_path, config=model_config)

View File

@@ -151,7 +151,7 @@ class TimeSeriesTask(Task):
raise ValueError("Must supply either X_train_all and y_train_all, or dataframe and label")
try:
dataframe.loc[:, self.time_col] = pd.to_datetime(dataframe[self.time_col])
except Exception:
raise ValueError(
f"For '{TS_FORECAST}' task, time column {self.time_col} must contain timestamp values."

View File

@@ -76,6 +76,8 @@ class SklearnWrapper:
self.pca = None
def fit(self, X: pd.DataFrame, y: pd.Series, **kwargs):
if "is_retrain" in kwargs:
kwargs.pop("is_retrain")
self._X = X
self._y = y
@@ -92,7 +94,14 @@ class SklearnWrapper:
for i, model in enumerate(self.models):
offset = i + self.lags
if len(X) - offset > 2:
# A training slice of length 2 triggers the "All features are either constant or ignored" error.
# TODO: investigate why the non-constant features are ignored. Selector?
model.fit(X_trans[: len(X) - offset], y[offset:], **fit_params)
elif len(X) > offset and "catboost" not in str(model).lower():
model.fit(X_trans[: len(X) - offset], y[offset:], **fit_params)
else:
print("[INFO]: Length of data should longer than period + lags.")
return self
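The length check above exists because the i-th direct model is fitted on `X_trans[: len(X) - offset]` against `y[offset:]`, so only `len(X) - offset` rows remain as the offset grows. A small arithmetic sketch (the lags, model count, and row count are illustrative):

lags, n_models, n_rows = 2, 3, 6
for i in range(n_models):
    offset = i + lags
    remaining = n_rows - offset            # rows available to fit model i
    print(f"model {i}: offset={offset}, training rows={remaining}")
# model 0: offset=2, training rows=4
# model 1: offset=3, training rows=3
# model 2: offset=4, training rows=2  -> not above the `> 2` threshold, so it is special-cased or skipped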
def predict(self, X, X_train=None, y_train=None):

View File

@@ -121,7 +121,12 @@ class TimeSeriesDataset:
@property
def X_all(self) -> pd.DataFrame:
# Remove empty or all-NA columns before concatenation
X_train_filtered = self.X_train.dropna(axis=1, how="all")
X_val_filtered = self.X_val.dropna(axis=1, how="all")
# Concatenate the filtered DataFrames
return pd.concat([X_train_filtered, X_val_filtered], axis=0)
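Dropping all-NA columns first sidesteps the pandas deprecation around concatenating empty or all-NA entries (a FutureWarning in pandas >= 2.1) and keeps the result dtypes stable. A tiny sketch with illustrative frames:

import numpy as np
import pandas as pd

X_train = pd.DataFrame({"a": [1.0, 2.0], "b": [np.nan, np.nan]})
X_val = pd.DataFrame({"a": [3.0], "b": [np.nan]})

X_all = pd.concat(
    [X_train.dropna(axis=1, how="all"), X_val.dropna(axis=1, how="all")], axis=0
)
print(X_all.columns.tolist())  # ['a'] -- the all-NA column 'b' is dropped before concat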
@property
def y_train(self) -> pd.DataFrame:
@@ -472,7 +477,7 @@ class DataTransformerTS:
if "__NAN__" not in X[col].cat.categories:
X[col] = X[col].cat.add_categories("__NAN__").fillna("__NAN__")
else:
X[col] = X[col].fillna("__NAN__").infer_objects(copy=False)
X[col] = X[col].astype("category")
for column in self.num_columns:
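The chained `infer_objects(copy=False)` in the hunk above is the replacement pandas (>= 2.1) recommends for the deprecated silent downcasting after `fillna`. A small sketch of the `__NAN__` fill-then-categorize pattern on an illustrative object column:

import numpy as np
import pandas as pd

col = pd.Series(["red", np.nan, "blue"], dtype="object")
filled = col.fillna("__NAN__").infer_objects(copy=False)  # opt in to the non-downcasting behavior
as_cat = filled.astype("category")
print(as_cat.cat.categories.tolist())  # ['__NAN__', 'blue', 'red']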

View File

@@ -130,7 +130,7 @@ class TestRegression(unittest.TestCase):
)
automl.fit(X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val, **settings)
def test_parallel_and_pickle(self, hpo_method=None):
automl_experiment = AutoML()
automl_settings = {
"time_budget": 10,
@@ -153,6 +153,18 @@ class TestRegression(unittest.TestCase):
except ImportError:
return
# test pickle and load_pickle, should work for prediction
automl_experiment.pickle("automl_xgboost_spark.pkl")
automl_loaded = AutoML().load_pickle("automl_xgboost_spark.pkl")
assert automl_loaded.best_estimator == automl_experiment.best_estimator
assert automl_loaded.best_loss == automl_experiment.best_loss
automl_loaded.predict(X_train)
import shutil
shutil.rmtree("automl_xgboost_spark.pkl", ignore_errors=True)
shutil.rmtree("automl_xgboost_spark.pkl.flaml_artifacts", ignore_errors=True)
def test_sparse_matrix_regression_holdout(self):
X_train = scipy.sparse.random(8, 100)
y_train = np.random.uniform(size=8)

View File

@@ -165,7 +165,7 @@ def test_spark_synapseml_rank():
_test_spark_synapseml_lightgbm(spark, "rank")
def test_spark_input_df_and_pickle():
import pandas as pd
file_url = "https://mmlspark.blob.core.windows.net/publicwasb/company_bankruptcy_prediction_data.csv"
@@ -201,6 +201,19 @@ def test_spark_input_df():
**settings,
)
# test pickle and load_pickle, should work for prediction
automl.pickle("automl_spark.pkl")
automl_loaded = AutoML().load_pickle("automl_spark.pkl")
assert automl_loaded.best_estimator == automl.best_estimator
assert automl_loaded.best_loss == automl.best_loss
automl_loaded.predict(df)
automl_loaded.model.estimator.transform(test_data)
import shutil
shutil.rmtree("automl_spark.pkl", ignore_errors=True)
shutil.rmtree("automl_spark.pkl.flaml_artifacts", ignore_errors=True)
if estimator_list == ["rf_spark"]:
return
@@ -393,13 +406,13 @@ def test_auto_convert_dtypes_spark():
if __name__ == "__main__":
# test_spark_synapseml_classification()
# test_spark_synapseml_regression()
# test_spark_synapseml_rank()
test_spark_input_df_and_pickle()
# test_get_random_dataframe()
# test_auto_convert_dtypes_pandas()
# test_auto_convert_dtypes_spark()
# import cProfile
# import pstats

View File

@@ -28,10 +28,10 @@ skip_spark = not spark_available
pytestmark = [pytest.mark.skipif(skip_spark, reason="Spark is not installed. Skip all spark tests."), pytest.mark.spark]
def test_parallel_xgboost_and_pickle(hpo_method=None, data_size=1000):
automl_experiment = AutoML()
automl_settings = {
"time_budget": 10,
"time_budget": 30,
"metric": "ap",
"task": "classification",
"log_file_name": "test/sparse_classification.log",
@@ -53,15 +53,27 @@ def test_parallel_xgboost(hpo_method=None, data_size=1000):
print(automl_experiment.best_iteration)
print(automl_experiment.best_estimator)
# test pickle and load_pickle, should work for prediction
automl_experiment.pickle("automl_xgboost_spark.pkl")
automl_loaded = AutoML().load_pickle("automl_xgboost_spark.pkl")
assert automl_loaded.best_estimator == automl_experiment.best_estimator
assert automl_loaded.best_loss == automl_experiment.best_loss
automl_loaded.predict(X_train)
import shutil
shutil.rmtree("automl_xgboost_spark.pkl", ignore_errors=True)
shutil.rmtree("automl_xgboost_spark.pkl.flaml_artifacts", ignore_errors=True)
def test_parallel_xgboost_others():
# use random search as the hpo_method
test_parallel_xgboost_and_pickle(hpo_method="random")
@pytest.mark.skip(reason="currently not supporting too large data, will support spark dataframe in the future")
def test_large_dataset():
test_parallel_xgboost_and_pickle(data_size=90000000)
@pytest.mark.skipif(
@@ -95,10 +107,10 @@ def test_custom_learner(data_size=1000):
if __name__ == "__main__":
test_parallel_xgboost_and_pickle()
# test_parallel_xgboost_others()
# # test_large_dataset()
# if skip_my_learner:
# print("please run pytest in the root directory of FLAML, i.e., the directory that contains the setup.py file")
# else:
# test_custom_learner()