mirror of
https://github.com/microsoft/FLAML.git
synced 2026-02-09 02:09:16 +08:00
Fix formatting with pre-commit

- Changed single quotes to double quotes for consistency
- Removed trailing whitespace
- Fixed line spacing

Co-authored-by: thinkall <3197038+thinkall@users.noreply.github.com>
This commit is contained in:
@@ -829,7 +829,7 @@ class AutoML(BaseEstimator):
             # Only flatten if not multi-target regression
             if isinstance(y_pred, np.ndarray) and y_pred.ndim > 1:
-                is_multi_target = getattr(self._state, 'is_multi_target', False)
+                is_multi_target = getattr(self._state, "is_multi_target", False)
                 if not is_multi_target:
                     y_pred = y_pred.flatten()
         if self._label_transformer:
@@ -2495,7 +2495,7 @@ class AutoML(BaseEstimator):
         self._state.error_metric = error_metric

         is_spark_dataframe = isinstance(X_train, psDataFrame) or isinstance(dataframe, psDataFrame)
-        is_multi_target = getattr(self._state, 'is_multi_target', False)
+        is_multi_target = getattr(self._state, "is_multi_target", False)
         estimator_list = task.default_estimator_list(estimator_list, is_spark_dataframe, is_multi_target)

         if is_spark_dataframe and self._use_spark:
@@ -373,14 +373,14 @@ class DataTransformer:
             datetime_columns,
         )
         self._drop = drop
-
+
         # Check if y is multi-target (DataFrame or 2D array with multiple targets)
         is_multi_target = False
         if isinstance(y, DataFrame) and y.shape[1] > 1:
             is_multi_target = True
         elif isinstance(y, np.ndarray) and y.ndim == 2 and y.shape[1] > 1:
             is_multi_target = True
-
+
         # Skip label encoding for multi-target regression
         if is_multi_target and task.is_regression():
             self.label_transformer = None
@@ -2112,7 +2112,7 @@ class CatBoostEstimator(BaseEstimator):
            cat_features = list(X_train.select_dtypes(include="category").columns)
        else:
            cat_features = []
-
+
        # Detect multi-target regression and set appropriate loss function
        is_multi_target = False
        if self._task.is_regression():
@@ -2120,10 +2120,10 @@ class CatBoostEstimator(BaseEstimator):
                is_multi_target = True
            elif isinstance(y_train, DataFrame) and y_train.shape[1] > 1:
                is_multi_target = True
-
+
        if is_multi_target and "loss_function" not in self.params:
            self.params["loss_function"] = "MultiRMSE"
-
+
        use_best_model = kwargs.get("use_best_model", True)
        n = max(int(len(y_train) * 0.9), len(y_train) - 1000) if use_best_model else len(y_train)
        X_tr, y_tr = X_train[:n], y_train[:n]
@@ -231,7 +231,7 @@ class GenericTask(Task):
        elif isinstance(automl._y_train_all, pd.DataFrame):
            is_multi_target = True
            n_targets = automl._y_train_all.shape[1]
-
+
        state.is_multi_target = is_multi_target
        state.n_targets = n_targets

@@ -1287,7 +1287,9 @@ class GenericTask(Task):
        pred_time /= n
        return val_loss, metric, train_time, pred_time

-    def default_estimator_list(self, estimator_list: List[str], is_spark_dataframe: bool = False, is_multi_target: bool = False) -> List[str]:
+    def default_estimator_list(
+        self, estimator_list: List[str], is_spark_dataframe: bool = False, is_multi_target: bool = False
+    ) -> List[str]:
        if "auto" != estimator_list:
            n_estimators = len(estimator_list)
        if is_spark_dataframe:
@@ -1316,7 +1318,7 @@ class GenericTask(Task):
                    "Non-spark dataframes only support estimator names not ending with `_spark`. Non-supported "
                    "estimators are removed."
                )
-
+
            # Filter out unsupported estimators for multi-target regression
            if is_multi_target and self.is_regression():
                # List of estimators that support multi-target regression natively
@@ -1382,7 +1384,7 @@ class GenericTask(Task):
                for est in estimator_list
                if (est.endswith("_spark") if is_spark_dataframe else not est.endswith("_spark"))
            ]
-
+
            # Filter for multi-target regression support
            if is_multi_target and self.is_regression():
                # List of estimators that support multi-target regression natively
@@ -1393,7 +1395,7 @@ class GenericTask(Task):
                        "Multi-target regression only supports estimators: xgboost, xgb_limitdepth, catboost. "
                        "No supported estimator is available."
                    )
-
+
        return estimator_list

    def default_metric(self, metric: str) -> str:
@@ -458,7 +458,9 @@ class TimeSeriesTask(Task):
        pred_time /= n
        return val_loss, metric, train_time, pred_time

-    def default_estimator_list(self, estimator_list: List[str], is_spark_dataframe: bool, is_multi_target: bool = False) -> List[str]:
+    def default_estimator_list(
+        self, estimator_list: List[str], is_spark_dataframe: bool, is_multi_target: bool = False
+    ) -> List[str]:
        assert not is_spark_dataframe, "Spark is not yet supported for time series"

        # TODO: why not do this if/then in the calling function?
@@ -16,9 +16,7 @@ class TestMultiTargetRegression(unittest.TestCase):
    def setUp(self):
        """Create multi-target regression datasets for testing."""
        # Create synthetic multi-target regression data
-        self.X, self.y = make_regression(
-            n_samples=200, n_features=10, n_targets=3, random_state=42, noise=0.1
-        )
+        self.X, self.y = make_regression(n_samples=200, n_features=10, n_targets=3, random_state=42, noise=0.1)
        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(
            self.X, self.y, test_size=0.2, random_state=42
        )
Reference in New Issue
Block a user