Files
FLAML/test/conftest.py
Li Jiang 1285700d7a Update readme, bump version to 2.4.0, fix CI errors (#1466)
* Update gitignore

* Bump version to 2.4.0

* Update readme

* Pre-download california housing data

* Use pre-downloaded california housing data

* Pin lightning<=2.5.6

* Fix typo in find and replace

* Fix estimators lacking the __sklearn_tags__ attribute

* Pin torch to 2.2.2 in tests

* Fix conflict

* Update pytorch-forecasting

* Update pytorch-forecasting

* Update pytorch-forecasting

* Use numpy<2 for testing

* Update scikit-learn

* Run Build and UT every other day

* Pin pip<24.1

* Pin pip<24.1 in pipeline

* Loosen pip, install pytorch_forecasting only in py311

* Add support for new versions of NLP dependencies

* Fix formats

* Remove redefinition

* Update mlflow versions

* Fix mlflow version syntax

* Update gitignore

* Clean up cache to free space

* Remove clean up action cache

* Fix blendsearch

* Update test workflow

* Update setup.py

* Fix catboost version

* Update workflow

* Prepare for python 3.14

* Support running without catboost installed

* Fix tests

* Fix python_requires

* Update test workflow

* Fix vw tests

* Remove python 3.9

* Fix nlp tests

* Fix prophet

* Print pip freeze for better debugging

* Fix Optuna search not supporting Float parameters with Quantized samplers

* Save dependencies for later inspection

* Fix missing coverage.xml

* Fix github action permission

* Handle python 3.13

* Handle openml not being installed

* Check dependencies before run tests

* Update dependencies

* Fix syntax error

* Use bash

* Update dependencies

* Fix git error

* Loosen mlflow constraints

* Add rerun, use mlflow-skinny

* Fix git error

* Remove ray tests

* Update xgboost versions

* Fix automl pickle error

* Don't test Python 3.10 on macOS as it hangs

* Rebase before push

* Reduce number of branches
2026-01-09 13:40:52 +08:00

58 lines
2.4 KiB
Python

from typing import Any, Dict, List, Union
import numpy as np
import pandas as pd
import pytest
from sklearn.metrics import f1_score, r2_score
# catboost is optional: when it cannot be imported, bind the three names to
# None so code in this module can test for its absence instead of failing at
# import time.
try:
    from catboost import CatBoostClassifier, CatBoostRegressor, Pool
except ImportError:  # pragma: no cover
    CatBoostClassifier = CatBoostRegressor = Pool = None
def _is_catboost_model_type(model_type: type) -> bool:
    """Tell whether ``model_type`` is a CatBoost estimator class.

    When both ``CatBoostClassifier`` and ``CatBoostRegressor`` are available,
    only those exact classes (identity comparison, not subclasses) qualify.
    Otherwise, fall back to checking whether the class's ``__module__``
    name starts with ``"catboost"``.

    :param model_type: The class to inspect, typically ``type(model)``.
    :return: True if the class is treated as a CatBoost model.
    """
    catboost_loaded = CatBoostClassifier is not None and CatBoostRegressor is not None
    if not catboost_loaded:
        owning_module = getattr(model_type, "__module__", "")
        return owning_module.startswith("catboost")
    return any(model_type is known_cls for known_cls in (CatBoostClassifier, CatBoostRegressor))
def _take_rows(data, positions):
    """Select rows of ``data`` by integer position (array of positions or a slice).

    Objects exposing ``.iloc`` (pandas DataFrame/Series) are indexed through it so
    that a non-default index is never mistaken for row labels; anything else
    (e.g. a numpy array) is indexed directly.
    """
    if hasattr(data, "iloc"):
        return data.iloc[positions]
    return data[positions]


def _fit_catboost_with_holdout(model, X_train, y_train) -> None:
    """Fit a CatBoost ``model`` using the tail of the training rows as an eval set.

    The first ``n = max(int(0.9 * len(y_train)), len(y_train) - 1000)`` rows are
    used for fitting. The remaining rows (at most roughly 10% or 1000, whichever is
    smaller) are passed as ``eval_set`` with ``use_best_model=True``.
    Skips the calling test when catboost is not installed.
    """
    if Pool is None:
        pytest.skip("catboost is not installed")
    n_rows = len(y_train)
    n = max(int(n_rows * 0.9), n_rows - 1000)
    head, tail = slice(None, n), slice(n, None)
    eval_set = Pool(data=_take_rows(X_train, tail), label=_take_rows(y_train, tail), cat_features=[])
    model.fit(_take_rows(X_train, head), _take_rows(y_train, head), eval_set=eval_set, use_best_model=True)


def evaluate_cv_folds_with_underlying_model(X_train_all, y_train_all, kf, model: Any, task: str) -> List[float]:
    """Mimic the FLAML CV process to calculate the metrics across each fold.

    For each ``(train_index, val_index)`` pair yielded by ``kf.split``, the training
    positions are shuffled, ``model`` is refit on those rows, and the loss
    ``1 - score`` is computed on the validation rows. A single
    ``np.random.RandomState(2020)`` is shared by all folds, so each fold's shuffle
    depends on the folds before it. CatBoost models (per
    ``_is_catboost_model_type``) are fit with a held-out ``eval_set``; every other
    model is fit with plain ``model.fit(X, y)``.

    :param X_train_all: X training data (pandas DataFrame, or numpy array); rows are selected by position
    :param y_train_all: y training data (pandas Series or numpy array); rows are selected by position
    :param kf: The splitter object to use to generate the folds; ``kf.split(X, y)`` must yield
        positional index arrays
    :param model: The estimator to fit to the data during the CV process; it is refit in place on every fold
    :param task: "classification" scores with ``1 - f1_score``; any other value scores with ``1 - r2_score``
    :return: A list with one loss value per fold, in split order
    """
    rng = np.random.RandomState(2020)
    all_fold_metrics: List[float] = []
    for train_index, val_index in kf.split(X_train_all, y_train_all):
        # Only the training rows are shuffled; validation rows keep split order.
        train_index = rng.permutation(train_index)
        # Index X and y the same (positional) way: kf.split yields positions, and
        # label-based Series lookup would pick wrong rows for a non-default index.
        X_train = _take_rows(X_train_all, train_index)
        y_train = _take_rows(y_train_all, train_index)
        X_val = _take_rows(X_train_all, val_index)
        y_val = _take_rows(y_train_all, val_index)

        if _is_catboost_model_type(type(model)):
            _fit_catboost_with_holdout(model, X_train, y_train)
        else:
            model.fit(X_train, y_train)

        y_pred = model.predict(X_val)
        if task == "classification":
            fold_loss = 1 - f1_score(y_val, y_pred)
        else:
            fold_loss = 1 - r2_score(y_val, y_pred)
        all_fold_metrics.append(float(fold_loss))
    return all_fold_metrics