FLAML/test/spark/test_exceptions.py
Li Jiang 1285700d7a Update readme, bump version to 2.4.0, fix CI errors (#1466)
* Update gitignore

* Bump version to 2.4.0

* Update readme

* Pre-download california housing data

* Use pre-downloaded california housing data

* Pin lightning<=2.5.6

* Fix typo in find and replace

* Fix "estimators has no attribute __sklearn_tags__" error

* Pin torch to 2.2.2 in tests

* Fix conflict

* Update pytorch-forecasting

* Update pytorch-forecasting

* Update pytorch-forecasting

* Use numpy<2 for testing

* Update scikit-learn

* Run Build and UT every other day

* Pin pip<24.1

* Pin pip<24.1 in pipeline

* Loosen pip, install pytorch_forecasting only in py311

* Add support for new versions of nlp dependencies

* Fix formats

* Remove redefinition

* Update mlflow versions

* Fix mlflow version syntax

* Update gitignore

* Clean up cache to free space

* Remove clean up action cache

* Fix blendsearch

* Update test workflow

* Update setup.py

* Fix catboost version

* Update workflow

* Prepare for python 3.14

* Support no catboost

* Fix tests

* Fix python_requires

* Update test workflow

* Fix vw tests

* Remove python 3.9

* Fix nlp tests

* Fix prophet

* Print pip freeze for better debugging

* Fix "Optuna search does not support parameters of type Float with samplers of type Quantized" error

* Save dependencies for later inspection

* Fix "coverage.xml does not exist" error

* Fix github action permission

* Handle python 3.13

* Address "openml is not installed" error

* Check dependencies before running tests

* Update dependencies

* Fix syntax error

* Use bash

* Update dependencies

* Fix git error

* Loosen mlflow constraints

* Add rerun, use mlflow-skinny

* Fix git error

* Remove ray tests

* Update xgboost versions

* Fix automl pickle error

* Don't test Python 3.10 on macOS as it gets stuck

* Rebase before push

* Reduce number of branches
2026-01-09 13:40:52 +08:00

81 lines
2.6 KiB
Python

import os

import pytest

from flaml import AutoML
from flaml.automl.data import load_openml_dataset
from flaml.tune.spark.utils import check_spark

spark_available, _ = check_spark()
skip_spark = not spark_available

pytestmark = [
    pytest.mark.skipif(skip_spark, reason="Spark is not installed. Skip all spark tests."),
    pytest.mark.spark,
]

# Cap the number of concurrent Spark trials FLAML launches in these tests
os.environ["FLAML_MAX_CONCURRENT"] = "2"


def base_automl(n_concurrent_trials=1, use_ray=False, use_spark=False, verbose=0):
    from minio.error import ServerError

    try:
        X_train, X_test, y_train, y_test = load_openml_dataset(dataset_id=537, data_dir="./")
    except (ServerError, Exception):
        # Fall back to the pre-downloaded California housing data if OpenML is unreachable
        from sklearn.datasets import fetch_california_housing

        X_train, y_train = fetch_california_housing(return_X_y=True, data_home="test")
    automl = AutoML()
    settings = {
        "time_budget": 3,  # total running time in seconds
        "metric": "r2",  # primary metric for regression; can be chosen from ['mae', 'mse', 'r2', 'rmse', 'mape']
        "estimator_list": ["lgbm", "rf", "xgboost"],  # list of ML learners
        "task": "regression",  # task type
        "log_file_name": "houses_experiment.log",  # flaml log file
        "seed": 7654321,  # random seed
        "n_concurrent_trials": n_concurrent_trials,  # the maximum number of concurrent trials
        "use_ray": use_ray,  # whether to use Ray for distributed training
        "use_spark": use_spark,  # whether to use Spark for distributed training
        "verbose": verbose,
    }
    automl.fit(X_train=X_train, y_train=y_train, **settings)
print("Best ML leaner:", automl.best_estimator)
print("Best hyperparmeter config:", automl.best_config)
print(f"Best accuracy on validation data: {1 - automl.best_loss:.4g}")
print(f"Training duration of best run: {automl.best_config_train_time:.4g} s")


def test_both_ray_spark():
    # Requesting both Ray and Spark at once must raise a ValueError
    with pytest.raises(ValueError):
        base_automl(n_concurrent_trials=2, use_ray=True, use_spark=True)


def test_verboses():
    for verbose in [1, 3, 5]:
        base_automl(verbose=verbose)


def test_import_error():
    from importlib import reload

    import flaml.tune.spark.utils as utils

    # Simulate pyspark not being installed
    reload(utils)
    utils._have_spark = False
    spark_available, spark_error_msg = utils.check_spark()
    assert not spark_available
    assert isinstance(spark_error_msg, ImportError)

    # Simulate an unsupported (too old) Spark version
    reload(utils)
    utils._spark_major_minor_version = (1, 1)
    spark_available, spark_error_msg = utils.check_spark()
    assert not spark_available
    assert isinstance(spark_error_msg, ImportError)

    # Restore the real module state for any tests that run afterwards
    reload(utils)
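
# For orientation: check_spark() consults module-level state captured at import
# time (hence the reload() calls above). A rough sketch of the pattern being
# exercised -- an illustration, not FLAML's actual implementation:
#
#     try:
#         import pyspark
#         _have_spark = True
#     except ImportError:
#         _have_spark = False
#
#     def check_spark():
#         if not _have_spark:
#             return False, ImportError("pyspark is not installed")
#         if _spark_major_minor_version < MIN_SPARK_VERSION:
#             return False, ImportError("pyspark is too old")
#         return True, None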


if __name__ == "__main__":
    base_automl()
    test_import_error()
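
For reference, a minimal sketch of how the same check_spark() gate can be used outside the test suite to opt into Spark-parallel tuning. The synthetic make_regression data, the 10-second budget, and the print messages are illustrative choices, not part of this test file:

from sklearn.datasets import make_regression

from flaml import AutoML
from flaml.tune.spark.utils import check_spark

spark_available, spark_error_msg = check_spark()
if spark_available:
    X, y = make_regression(n_samples=200, n_features=5, random_state=0)
    automl = AutoML()
    automl.fit(
        X_train=X,
        y_train=y,
        task="regression",
        metric="r2",
        time_budget=10,  # illustrative budget in seconds
        use_spark=True,  # run trials on the Spark cluster
        n_concurrent_trials=2,  # parallel trials (capped by FLAML_MAX_CONCURRENT, as in the tests above)
    )
    print("Best ML learner:", automl.best_estimator)
else:
    print("Spark unavailable, running sequentially instead:", spark_error_msg)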