Mirror of https://github.com/microsoft/FLAML.git, synced 2026-02-09 02:09:16 +08:00
Support pickling the whole AutoML instance, Sync Fabric till 0d4ab16f (#1481)
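In short, this change lets a fitted AutoML instance be pickled and restored for prediction, with any Spark ML models externalized next to the pickle file. A minimal sketch of the intended round trip, mirroring the tests added below (the training data, task, and file name here are placeholders):

from flaml import AutoML

automl = AutoML()
automl.fit(X_train=X_train, y_train=y_train, task="classification", time_budget=30)

# Serialize: Spark ML models (if any) are written under "automl.pkl.flaml_artifacts/"
# next to the pickle, and only lightweight metadata goes into the pickle itself.
automl.pickle("automl.pkl")

# Restore and predict; load_spark_models=False would skip reloading Spark ML models.
automl_loaded = AutoML().load_pickle("automl.pkl")
automl_loaded.predict(X_train)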
@@ -413,13 +413,47 @@ class AutoML(BaseEstimator):
        """
        state = self.__dict__.copy()
        # Keep mlflow_integration for post-load visualization (e.g., infos), but
        # strip non-picklable runtime-only members (thread futures, clients).
        mlflow_integration = state.get("mlflow_integration", None)
        if mlflow_integration is not None:
            import copy

            mi = copy.copy(mlflow_integration)
            # These are runtime-only and often contain locks / threads.
            if hasattr(mi, "futures"):
                mi.futures = {}
            if hasattr(mi, "futures_log_model"):
                mi.futures_log_model = {}
            if hasattr(mi, "train_func"):
                mi.train_func = None
            if hasattr(mi, "mlflow_client"):
                mi.mlflow_client = None
            state["mlflow_integration"] = mi
        # MLflow signature objects may hold references to Spark/pandas-on-Spark
        # inputs and can indirectly capture SparkContext, which is not picklable.
        state.pop("estimator_signature", None)
        state.pop("pipeline_signature", None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Ensure mlflow_integration runtime members exist post-unpickle.
        mi = getattr(self, "mlflow_integration", None)
        if mi is not None:
            if not hasattr(mi, "futures") or mi.futures is None:
                mi.futures = {}
            if not hasattr(mi, "futures_log_model") or mi.futures_log_model is None:
                mi.futures_log_model = {}
            if not hasattr(mi, "train_func"):
                mi.train_func = None
            if not hasattr(mi, "mlflow_client") or mi.mlflow_client is None:
                try:
                    import mlflow as _mlflow

                    mi.mlflow_client = _mlflow.tracking.MlflowClient()
                except Exception:
                    mi.mlflow_client = None

    def get_params(self, deep: bool = False) -> dict:
        return self._settings.copy()
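The same strip-on-dump, recreate-on-load pattern, shown as a standalone sketch with a stand-in runtime-only member (a threading lock instead of MLflow clients and futures):

import pickle
import threading


class JobRunner:
    def __init__(self):
        self.results = []  # ordinary picklable state
        self._lock = threading.Lock()  # runtime-only, not picklable

    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop("_lock", None)  # drop the non-picklable member before dumping
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._lock = threading.Lock()  # recreate it after unpickling


runner = pickle.loads(pickle.dumps(JobRunner()))
assert isinstance(runner._lock, type(threading.Lock()))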
@@ -1114,17 +1148,344 @@ class AutoML(BaseEstimator):
        return self._state.data_size[0] if self._sample else None

    def pickle(self, output_file_name):
        """Serialize the AutoML instance to a pickle file.

        Notes:
            When the trained estimator(s) are Spark-based, they may hold references
            to SparkContext/SparkSession via Spark ML objects. Such objects are not
            safely picklable and can cause pickling/broadcast errors.

            This method externalizes Spark ML models into an adjacent artifact
            directory and stores only lightweight metadata in the pickle.
        """
        import os
        import pickle
        import re

        def _safe_name(name: str) -> str:
            return re.sub(r"[^A-Za-z0-9_.-]+", "_", name)

        def _iter_trained_estimators():
            trained = getattr(self, "_trained_estimator", None)
            if trained is not None:
                yield "_trained_estimator", trained
            for est_name in getattr(self, "estimator_list", []) or []:
                ss = getattr(self, "_search_states", {}).get(est_name)
                te = ss and getattr(ss, "trained_estimator", None)
                if te is not None:
                    yield f"_search_states.{est_name}.trained_estimator", te

        def _scrub_pyspark_refs(root_obj):
            """Best-effort removal of pyspark objects prior to pickling.

            SparkContext/SparkSession and Spark DataFrame objects are not picklable.
            This function finds such objects within common containers and instance
            attributes and replaces them with None, returning a restore mapping.
            """
            try:
                import pyspark
                from pyspark.broadcast import Broadcast
                from pyspark.sql import DataFrame as SparkDataFrame
                from pyspark.sql import SparkSession

                try:
                    import pyspark.pandas as ps

                    psDataFrameType = getattr(ps, "DataFrame", None)
                    psSeriesType = getattr(ps, "Series", None)
                except Exception:
                    psDataFrameType = None
                    psSeriesType = None

                bad_types = [
                    pyspark.SparkContext,
                    SparkSession,
                    SparkDataFrame,
                    Broadcast,
                ]
                if psDataFrameType is not None:
                    bad_types.append(psDataFrameType)
                if psSeriesType is not None:
                    bad_types.append(psSeriesType)
                bad_types = tuple(t for t in bad_types if t is not None)
            except Exception:
                return {}

            restore = {}
            visited = set()

            def _mark(parent, key, value, path):
                restore[(id(parent), key)] = (parent, key, value)
                try:
                    if isinstance(parent, dict):
                        parent[key] = None
                    elif isinstance(parent, list):
                        parent[key] = None
                    elif isinstance(parent, tuple):
                        # tuples are immutable; we can't modify them in place
                        pass
                    else:
                        setattr(parent, key, None)
                except Exception:
                    # Best-effort.
                    pass

            def _walk(obj, depth, parent=None, key=None, path="self"):
                if obj is None:
                    return
                oid = id(obj)
                if oid in visited:
                    return
                visited.add(oid)

                if isinstance(obj, bad_types):
                    if parent is not None:
                        _mark(parent, key, obj, path)
                    return
                if depth <= 0:
                    return

                if isinstance(obj, dict):
                    for k, v in list(obj.items()):
                        _walk(v, depth - 1, parent=obj, key=k, path=f"{path}[{k!r}]")
                    return
                if isinstance(obj, list):
                    for i, v in enumerate(list(obj)):
                        _walk(v, depth - 1, parent=obj, key=i, path=f"{path}[{i}]")
                    return
                if isinstance(obj, tuple):
                    # Tuples are immutable, so we can't scrub them in place; still inspect their items.
                    for i, v in enumerate(obj):
                        _walk(v, depth - 1, parent=None, key=None, path=f"{path}[{i}]")
                    return
                if isinstance(obj, set):
                    for v in list(obj):
                        _walk(v, depth - 1, parent=None, key=None, path=f"{path}{{...}}")
                    return

                d = getattr(obj, "__dict__", None)
                if isinstance(d, dict):
                    for attr, v in list(d.items()):
                        _walk(v, depth - 1, parent=obj, key=attr, path=f"{path}.{attr}")

            _walk(root_obj, depth=6)
            return restore

        # Temporarily remove non-picklable pieces (e.g., SparkContext-backed objects)
        # and externalize Spark models.
        estimator_to_training_function = {}
        spark_restore = []
        artifact_dir = None
        state_restore = {}
        automl_restore = {}
        scrub_restore = {}

        try:
            # Signatures are only used for MLflow logging; they are not required
            # for inference and can capture SparkContext via pyspark objects.
            for attr in ("estimator_signature", "pipeline_signature"):
                if hasattr(self, attr):
                    automl_restore[attr] = getattr(self, attr)
                    setattr(self, attr, None)

            for estimator in self.estimator_list:
                search_state = self._search_states[estimator]
                if hasattr(search_state, "training_function"):
                    estimator_to_training_function[estimator] = search_state.training_function
                    del search_state.training_function

            # AutoMLState may keep Spark / pandas-on-Spark dataframes which are not picklable.
            # They are not required for inference, so strip them for serialization.
            state = getattr(self, "_state", None)
            if state is not None:
                for attr in (
                    "X_train",
                    "y_train",
                    "X_train_all",
                    "y_train_all",
                    "X_val",
                    "y_val",
                    "weight_val",
                    "groups_val",
                    "sample_weight_all",
                    "groups",
                    "groups_all",
                    "kf",
                ):
                    if hasattr(state, attr):
                        state_restore[attr] = getattr(state, attr)
                        setattr(state, attr, None)

            for key, est in _iter_trained_estimators():
                if getattr(est, "estimator_baseclass", None) != "spark":
                    continue

                # Drop training data reference (Spark DataFrame / pandas-on-Spark).
                old_df_train = getattr(est, "df_train", None)
                old_model = getattr(est, "_model", None)

                model_meta = None
                if old_model is not None:
                    if artifact_dir is None:
                        artifact_dir = output_file_name + ".flaml_artifacts"
                        os.makedirs(artifact_dir, exist_ok=True)
                        # Store the relative dirname so the pickle and the folder can be moved together.
                        self._flaml_pickle_artifacts_dirname = os.path.basename(artifact_dir)

                    model_dir = os.path.join(artifact_dir, _safe_name(key))
                    # Spark ML models are saved as directories.
                    try:
                        writer = old_model.write()
                        writer.overwrite().save(model_dir)
                    except Exception as e:
                        raise RuntimeError(
                            "Failed to externalize Spark model for pickling. "
                            "Please ensure the Spark ML model supports write().overwrite().save(path)."
                        ) from e

                    model_meta = {
                        "path": os.path.relpath(model_dir, os.path.dirname(output_file_name) or "."),
                        "class": old_model.__class__.__module__ + "." + old_model.__class__.__name__,
                    }
                    # Replace the in-memory Spark model with metadata only.
                    est._model = None
                    est._flaml_spark_model_meta = model_meta

                est.df_train = None
                spark_restore.append((est, old_model, old_df_train, model_meta))

            with open(output_file_name, "wb") as f:
                try:
                    pickle.dump(self, f, pickle.HIGHEST_PROTOCOL)
                except Exception:
                    # Some pyspark objects can still be captured indirectly.
                    scrub_restore = _scrub_pyspark_refs(self)
                    if scrub_restore:
                        f.seek(0)
                        f.truncate()
                        pickle.dump(self, f, pickle.HIGHEST_PROTOCOL)
                    else:
                        raise
        finally:
            # Restore training_function and Spark models so the current object remains usable.
            for estimator, tf in estimator_to_training_function.items():
                self._search_states[estimator].training_function = tf

            for attr, val in automl_restore.items():
                setattr(self, attr, val)

            state = getattr(self, "_state", None)
            if state is not None and state_restore:
                for attr, val in state_restore.items():
                    setattr(state, attr, val)

            for est, old_model, old_df_train, model_meta in spark_restore:
                est._model = old_model
                est.df_train = old_df_train
                if model_meta is not None and hasattr(est, "_flaml_spark_model_meta"):
                    delattr(est, "_flaml_spark_model_meta")

            if scrub_restore:
                for _, (parent, key, value) in scrub_restore.items():
                    try:
                        if isinstance(parent, dict):
                            parent[key] = value
                        elif isinstance(parent, list):
                            parent[key] = value
                        else:
                            setattr(parent, key, value)
                    except Exception:
                        pass

    @classmethod
    def load_pickle(cls, input_file_name: str, load_spark_models: bool = True):
        """Load an AutoML instance saved by :meth:`pickle`.

        Args:
            input_file_name: Path to the pickle file created by :meth:`pickle`.
            load_spark_models: Whether to load externalized Spark ML models back
                into the estimator objects. If False, Spark estimators will remain
                without their underlying Spark model and cannot be used for predict.

        Returns:
            The deserialized AutoML instance.
        """
        import importlib
        import os
        import pickle

        with open(input_file_name, "rb") as f:
            automl = pickle.load(f)

        # Recreate per-estimator training_function if it was removed for pickling.
        try:
            for est_name, ss in getattr(automl, "_search_states", {}).items():
                if not hasattr(ss, "training_function"):
                    ss.training_function = partial(
                        AutoMLState._compute_with_config_base,
                        state=automl._state,
                        estimator=est_name,
                    )
        except Exception:
            # Best-effort; training_function is only needed for re-searching.
            pass

        if not load_spark_models:
            return automl

        base_dir = os.path.dirname(input_file_name) or "."

        def _iter_trained_estimators_loaded():
            trained = getattr(automl, "_trained_estimator", None)
            if trained is not None:
                yield trained
            for ss in getattr(automl, "_search_states", {}).values():
                te = ss and getattr(ss, "trained_estimator", None)
                if te is not None:
                    yield te

        for est in _iter_trained_estimators_loaded():
            meta = getattr(est, "_flaml_spark_model_meta", None)
            if not meta:
                continue
            model_path = meta.get("path")
            model_class = meta.get("class")
            if not model_path or not model_class:
                continue

            abs_model_path = os.path.join(base_dir, model_path)

            module_name, _, class_name = model_class.rpartition(".")
            try:
                module = importlib.import_module(module_name)
                model_cls = getattr(module, class_name)
            except Exception as e:
                raise RuntimeError(f"Failed to import Spark model class '{model_class}'") from e

            # Most Spark ML models support either Class.load(path) or Class.read().load(path).
            if hasattr(model_cls, "load"):
                est._model = model_cls.load(abs_model_path)
            elif hasattr(model_cls, "read"):
                est._model = model_cls.read().load(abs_model_path)
            else:
                try:
                    from pyspark.ml.pipeline import PipelineModel

                    loaded_model = PipelineModel.load(abs_model_path)
                    if not isinstance(loaded_model, model_cls):
                        raise RuntimeError(
                            f"Loaded model type '{type(loaded_model).__name__}' does not match expected type '{model_class}'."
                        )
                    est._model = loaded_model
                except Exception as e:
                    raise RuntimeError(
                        f"Spark model class '{model_class}' does not support load/read(). "
                        "Unable to restore Spark model from artifacts."
                    ) from e

        return automl

    @property
    def trainable(self) -> Callable[[dict], float | None]:
@@ -135,6 +135,7 @@ class BaseEstimator(sklearn.base.ClassifierMixin, sklearn.base.BaseEstimator):
        self._task = task if isinstance(task, Task) else task_factory(task, None, None)
        self.params = self.config2params(config)
        self.estimator_class = self._model = None
        self.estimator_baseclass = "sklearn"
        if "_estimator_type" in self.params:
            self._estimator_type = self.params.pop("_estimator_type")
        else:
@@ -439,6 +440,7 @@ class SparkEstimator(BaseEstimator):
            raise SPARK_ERROR
        super().__init__(task, **config)
        self.df_train = None
        self.estimator_baseclass = "spark"

    def _preprocess(
        self,
@@ -974,7 +976,7 @@ class TransformersEstimator(BaseEstimator):
        from .nlp.huggingface.utils import tokenize_text
        from .nlp.utils import is_a_list_of_str

        is_str = str(X.dtypes.iloc[0]) in ("string", "str")
        is_list_of_str = is_a_list_of_str(X[list(X.keys())[0]].to_list()[0])

        if is_str or is_list_of_str:
@@ -5,7 +5,7 @@ from typing import List, Optional
from flaml.automl.task.task import NLG_TASKS

try:
    from transformers import Seq2SeqTrainingArguments as TrainingArguments
except ImportError:
    TrainingArguments = object

@@ -396,7 +396,7 @@ def load_model(checkpoint_path, task, num_labels=None):
    if task in (SEQCLASSIFICATION, SEQREGRESSION):
        return AutoModelForSequenceClassification.from_pretrained(
            checkpoint_path, config=model_config, ignore_mismatched_sizes=True, trust_remote_code=True
        )
    elif task == TOKENCLASSIFICATION:
        return AutoModelForTokenClassification.from_pretrained(checkpoint_path, config=model_config)
@@ -151,7 +151,7 @@ class TimeSeriesTask(Task):
            raise ValueError("Must supply either X_train_all and y_train_all, or dataframe and label")

        try:
            dataframe.loc[:, self.time_col] = pd.to_datetime(dataframe[self.time_col])
        except Exception:
            raise ValueError(
                f"For '{TS_FORECAST}' task, time column {self.time_col} must contain timestamp values."
@@ -76,6 +76,8 @@ class SklearnWrapper:
        self.pca = None

    def fit(self, X: pd.DataFrame, y: pd.Series, **kwargs):
        if "is_retrain" in kwargs:
            kwargs.pop("is_retrain")
        self._X = X
        self._y = y
@@ -92,7 +94,14 @@ class SklearnWrapper:
        for i, model in enumerate(self.models):
            offset = i + self.lags
            if len(X) - offset > 2:
                # A training slice of length 2 can raise "All features are either constant or ignored."
                # TODO: see why the non-constant features are ignored. Selector?
                model.fit(X_trans[: len(X) - offset], y[offset:], **fit_params)
            elif len(X) > offset and "catboost" not in str(model).lower():
                model.fit(X_trans[: len(X) - offset], y[offset:], **fit_params)
            else:
                print("[INFO]: Length of data should be longer than period + lags.")
        return self

    def predict(self, X, X_train=None, y_train=None):
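To make the per-horizon slicing above concrete, a standalone sketch with hypothetical sizes: the model for horizon step i is fit on the first len(X) - offset feature rows against the target shifted forward by offset = i + lags.

import numpy as np

n, lags, i = 10, 2, 1
offset = i + lags  # 3
X_trans = np.arange(n).reshape(-1, 1)
y = np.arange(100, 100 + n)

X_fit, y_fit = X_trans[: n - offset], y[offset:]
assert len(X_fit) == len(y_fit) == n - offset  # 7 aligned (feature row, future target) pairs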
@@ -121,7 +121,12 @@ class TimeSeriesDataset:
    @property
    def X_all(self) -> pd.DataFrame:
        # Remove empty or all-NA columns before concatenation.
        X_train_filtered = self.X_train.dropna(axis=1, how="all")
        X_val_filtered = self.X_val.dropna(axis=1, how="all")

        # Concatenate the filtered DataFrames.
        return pd.concat([X_train_filtered, X_val_filtered], axis=0)

    @property
    def y_train(self) -> pd.DataFrame:
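The dropna(axis=1, how="all") filtering above matters because recent pandas releases deprecate letting empty or all-NA columns participate in concat dtype inference and emit a FutureWarning for them; a minimal standalone sketch of the same guard (toy frames, not FLAML data):

import numpy as np
import pandas as pd

X_train = pd.DataFrame({"a": [1.0, 2.0], "b": [np.nan, np.nan]})
X_val = pd.DataFrame({"a": [3.0], "b": [np.nan]})

# Dropping all-NA columns first keeps the concatenation free of the deprecated
# empty/all-NA dtype inference and of the resulting warning.
X_all = pd.concat(
    [X_train.dropna(axis=1, how="all"), X_val.dropna(axis=1, how="all")],
    axis=0,
)
print(X_all)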
@@ -472,7 +477,7 @@ class DataTransformerTS:
                if "__NAN__" not in X[col].cat.categories:
                    X[col] = X[col].cat.add_categories("__NAN__").fillna("__NAN__")
            else:
                X[col] = X[col].fillna("__NAN__").infer_objects(copy=False)
                X[col] = X[col].astype("category")

        for column in self.num_columns:
@@ -130,7 +130,7 @@ class TestRegression(unittest.TestCase):
        )
        automl.fit(X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val, **settings)

    def test_parallel_and_pickle(self, hpo_method=None):
        automl_experiment = AutoML()
        automl_settings = {
            "time_budget": 10,
@@ -153,6 +153,18 @@ class TestRegression(unittest.TestCase):
        except ImportError:
            return

        # test pickle and load_pickle, should work for prediction
        automl_experiment.pickle("automl_xgboost_spark.pkl")
        automl_loaded = AutoML().load_pickle("automl_xgboost_spark.pkl")
        assert automl_loaded.best_estimator == automl_experiment.best_estimator
        assert automl_loaded.best_loss == automl_experiment.best_loss
        automl_loaded.predict(X_train)

        import shutil

        shutil.rmtree("automl_xgboost_spark.pkl", ignore_errors=True)
        shutil.rmtree("automl_xgboost_spark.pkl.flaml_artifacts", ignore_errors=True)

    def test_sparse_matrix_regression_holdout(self):
        X_train = scipy.sparse.random(8, 100)
        y_train = np.random.uniform(size=8)
@@ -165,7 +165,7 @@ def test_spark_synapseml_rank():
    _test_spark_synapseml_lightgbm(spark, "rank")


def test_spark_input_df_and_pickle():
    import pandas as pd

    file_url = "https://mmlspark.blob.core.windows.net/publicwasb/company_bankruptcy_prediction_data.csv"
@@ -201,6 +201,19 @@ def test_spark_input_df():
        **settings,
    )

    # test pickle and load_pickle, should work for prediction
    automl.pickle("automl_spark.pkl")
    automl_loaded = AutoML().load_pickle("automl_spark.pkl")
    assert automl_loaded.best_estimator == automl.best_estimator
    assert automl_loaded.best_loss == automl.best_loss
    automl_loaded.predict(df)
    automl_loaded.model.estimator.transform(test_data)

    import shutil

    shutil.rmtree("automl_spark.pkl", ignore_errors=True)
    shutil.rmtree("automl_spark.pkl.flaml_artifacts", ignore_errors=True)

    if estimator_list == ["rf_spark"]:
        return
@@ -393,13 +406,13 @@ def test_auto_convert_dtypes_spark():

if __name__ == "__main__":
    # test_spark_synapseml_classification()
    # test_spark_synapseml_regression()
    # test_spark_synapseml_rank()
    test_spark_input_df_and_pickle()
    # test_get_random_dataframe()
    # test_auto_convert_dtypes_pandas()
    # test_auto_convert_dtypes_spark()

    # import cProfile
    # import pstats
@@ -28,10 +28,10 @@ skip_spark = not spark_available
pytestmark = [pytest.mark.skipif(skip_spark, reason="Spark is not installed. Skip all spark tests."), pytest.mark.spark]


def test_parallel_xgboost_and_pickle(hpo_method=None, data_size=1000):
    automl_experiment = AutoML()
    automl_settings = {
        "time_budget": 30,
        "metric": "ap",
        "task": "classification",
        "log_file_name": "test/sparse_classification.log",
@@ -53,15 +53,27 @@ def test_parallel_xgboost(hpo_method=None, data_size=1000):
    print(automl_experiment.best_iteration)
    print(automl_experiment.best_estimator)

    # test pickle and load_pickle, should work for prediction
    automl_experiment.pickle("automl_xgboost_spark.pkl")
    automl_loaded = AutoML().load_pickle("automl_xgboost_spark.pkl")
    assert automl_loaded.best_estimator == automl_experiment.best_estimator
    assert automl_loaded.best_loss == automl_experiment.best_loss
    automl_loaded.predict(X_train)

    import shutil

    shutil.rmtree("automl_xgboost_spark.pkl", ignore_errors=True)
    shutil.rmtree("automl_xgboost_spark.pkl.flaml_artifacts", ignore_errors=True)


def test_parallel_xgboost_others():
    # use random search as the hpo_method
    test_parallel_xgboost_and_pickle(hpo_method="random")


@pytest.mark.skip(reason="currently not supporting too large data, will support spark dataframe in the future")
def test_large_dataset():
    test_parallel_xgboost_and_pickle(data_size=90000000)


@pytest.mark.skipif(
@@ -95,10 +107,10 @@ def test_custom_learner(data_size=1000):

if __name__ == "__main__":
    test_parallel_xgboost_and_pickle()
    # test_parallel_xgboost_others()
    # # test_large_dataset()
    # if skip_my_learner:
    #     print("please run pytest in the root directory of FLAML, i.e., the directory that contains the setup.py file")
    # else:
    #     test_custom_learner()