mirror of
https://github.com/microsoft/FLAML.git
synced 2026-02-09 02:09:16 +08:00
Add multi-target regression support

- Modified validation to accept 2D y arrays (n_samples, n_targets)
- Added multi-target detection in generic_task.validate_data
- Filtered unsupported estimators (only XGBoost and CatBoost support multi-target)
- Configured CatBoost with the MultiRMSE objective for multi-target
- Fixed AutoML.predict to not flatten multi-target predictions
- Updated the AutoML.fit docstring to document multi-target support

Co-authored-by: thinkall <3197038+thinkall@users.noreply.github.com>
This commit is contained in:
@@ -634,8 +634,11 @@ class AutoML(BaseEstimator):
|
||||
X = self._state.task.preprocess(X, self._transformer)
|
||||
y_pred = estimator.predict(X, **pred_kwargs)
|
||||
|
||||
if isinstance(y_pred, np.ndarray) and y_pred.ndim > 1 and isinstance(y_pred, np.ndarray):
|
||||
y_pred = y_pred.flatten()
|
||||
# Only flatten if not multi-target regression
|
||||
if isinstance(y_pred, np.ndarray) and y_pred.ndim > 1:
|
||||
is_multi_target = getattr(self._state, 'is_multi_target', False)
|
||||
if not is_multi_target:
|
||||
y_pred = y_pred.flatten()
|
||||
if self._label_transformer:
|
||||
return self._label_transformer.inverse_transform(Series(y_pred.astype(int)))
|
||||
else:
|
||||
@@ -1272,7 +1275,9 @@ class AutoML(BaseEstimator):
|
||||
must be the timestamp column (datetime type). Other columns in
|
||||
the dataframe are assumed to be exogenous variables (categorical or numeric).
|
||||
When using ray, X_train can be a ray.ObjectRef.
|
||||
y_train: A numpy array or a pandas series of labels in shape (n, ).
|
||||
y_train: A numpy array, pandas series, or pandas dataframe of labels in shape (n, )
|
||||
for single-target tasks or (n, k) for multi-target regression tasks.
|
||||
For multi-target regression, only XGBoost and CatBoost estimators are supported.
|
||||
dataframe: A dataframe of training data including label column.
|
||||
For time series forecast tasks, dataframe must be specified and must have
|
||||
at least two columns, timestamp and label, where the first
|
||||
@@ -1883,7 +1888,8 @@ class AutoML(BaseEstimator):
|
||||
self._state.error_metric = error_metric
|
||||
|
||||
is_spark_dataframe = isinstance(X_train, psDataFrame) or isinstance(dataframe, psDataFrame)
|
||||
estimator_list = task.default_estimator_list(estimator_list, is_spark_dataframe)
|
||||
is_multi_target = getattr(self._state, 'is_multi_target', False)
|
||||
estimator_list = task.default_estimator_list(estimator_list, is_spark_dataframe, is_multi_target)
|
||||
|
||||
if is_spark_dataframe and self._use_spark:
|
||||
# For spark dataframe, use_spark must be False because spark models are trained in parallel themselves
|
||||
|
||||
@@ -2081,6 +2081,18 @@ class CatBoostEstimator(BaseEstimator):
|
||||
cat_features = list(X_train.select_dtypes(include="category").columns)
|
||||
else:
|
||||
cat_features = []
|
||||
|
||||
# Detect multi-target regression and set appropriate loss function
|
||||
is_multi_target = False
|
||||
if self._task.is_regression():
|
||||
if isinstance(y_train, np.ndarray) and y_train.ndim == 2 and y_train.shape[1] > 1:
|
||||
is_multi_target = True
|
||||
elif isinstance(y_train, DataFrame) and y_train.shape[1] > 1:
|
||||
is_multi_target = True
|
||||
|
||||
if is_multi_target and "loss_function" not in self.params:
|
||||
self.params["loss_function"] = "MultiRMSE"
|
||||
|
||||
use_best_model = kwargs.get("use_best_model", True)
|
||||
n = max(int(len(y_train) * 0.9), len(y_train) - 1000) if use_best_model else len(y_train)
|
||||
X_tr, y_tr = X_train[:n], y_train[:n]
|
||||
|
||||
@@ -119,13 +119,15 @@ class GenericTask(Task):
|
||||
"a Scipy sparse matrix or a pyspark.pandas dataframe."
|
||||
)
|
||||
assert isinstance(
|
||||
y_train_all, (np.ndarray, pd.Series, psSeries)
|
||||
), "y_train_all must be a numpy array, a pandas series or a pyspark.pandas series."
|
||||
y_train_all, (np.ndarray, pd.Series, pd.DataFrame, psSeries)
|
||||
), "y_train_all must be a numpy array, a pandas series, a pandas dataframe or a pyspark.pandas series."
|
||||
assert X_train_all.size != 0 and y_train_all.size != 0, "Input data must not be empty."
|
||||
if isinstance(X_train_all, np.ndarray) and len(X_train_all.shape) == 1:
|
||||
X_train_all = np.reshape(X_train_all, (X_train_all.size, 1))
|
||||
if isinstance(y_train_all, np.ndarray):
|
||||
y_train_all = y_train_all.flatten()
|
||||
# Only flatten if it's truly 1D (not multi-target)
|
||||
if y_train_all.ndim == 1 or (y_train_all.ndim == 2 and y_train_all.shape[1] == 1):
|
||||
y_train_all = y_train_all.flatten()
|
||||
assert X_train_all.shape[0] == y_train_all.shape[0], "# rows in X_train must match length of y_train."
|
||||
if isinstance(X_train_all, psDataFrame):
|
||||
X_train_all = X_train_all.spark.cache() # cache data to improve compute speed
|
||||
@@ -219,6 +221,20 @@ class GenericTask(Task):
|
||||
automl._X_train_all.columns.to_list() if hasattr(automl._X_train_all, "columns") else None
|
||||
)
|
||||
|
||||
# Detect multi-target regression
|
||||
is_multi_target = False
|
||||
n_targets = 1
|
||||
if self.is_regression():
|
||||
if isinstance(automl._y_train_all, np.ndarray) and automl._y_train_all.ndim == 2:
|
||||
is_multi_target = True
|
||||
n_targets = automl._y_train_all.shape[1]
|
||||
elif isinstance(automl._y_train_all, pd.DataFrame):
|
||||
is_multi_target = True
|
||||
n_targets = automl._y_train_all.shape[1]
|
||||
|
||||
state.is_multi_target = is_multi_target
|
||||
state.n_targets = n_targets
|
||||
|
||||
automl._sample_weight_full = state.fit_kwargs.get(
|
||||
"sample_weight"
|
||||
) # NOTE: _validate_data is before kwargs is updated to fit_kwargs_by_estimator
|
||||
@@ -227,14 +243,16 @@ class GenericTask(Task):
|
||||
"X_val must be None, a numpy array, a pandas dataframe, "
|
||||
"a Scipy sparse matrix or a pyspark.pandas dataframe."
|
||||
)
|
||||
assert isinstance(y_val, (np.ndarray, pd.Series, psSeries)), (
|
||||
"y_val must be None, a numpy array, a pandas series " "or a pyspark.pandas series."
|
||||
assert isinstance(y_val, (np.ndarray, pd.Series, pd.DataFrame, psSeries)), (
|
||||
"y_val must be None, a numpy array, a pandas series, a pandas dataframe " "or a pyspark.pandas series."
|
||||
)
|
||||
assert X_val.size != 0 and y_val.size != 0, (
|
||||
"Validation data are expected to be nonempty. " "Use None for X_val and y_val if no validation data."
|
||||
)
|
||||
if isinstance(y_val, np.ndarray):
|
||||
y_val = y_val.flatten()
|
||||
# Only flatten if it's truly 1D (not multi-target)
|
||||
if y_val.ndim == 1 or (y_val.ndim == 2 and y_val.shape[1] == 1):
|
||||
y_val = y_val.flatten()
|
||||
assert X_val.shape[0] == y_val.shape[0], "# rows in X_val must match length of y_val."
|
||||
if automl._transformer:
|
||||
state.X_val = automl._transformer.transform(X_val)
|
||||
@@ -819,7 +837,7 @@ class GenericTask(Task):
|
||||
pred_time /= n
|
||||
return val_loss, metric, train_time, pred_time
|
||||
|
||||
def default_estimator_list(self, estimator_list: List[str], is_spark_dataframe: bool = False) -> List[str]:
|
||||
def default_estimator_list(self, estimator_list: List[str], is_spark_dataframe: bool = False, is_multi_target: bool = False) -> List[str]:
|
||||
if "auto" != estimator_list:
|
||||
n_estimators = len(estimator_list)
|
||||
if is_spark_dataframe:
|
||||
@@ -848,6 +866,23 @@ class GenericTask(Task):
|
||||
"Non-spark dataframes only support estimator names not ending with `_spark`. Non-supported "
|
||||
"estimators are removed."
|
||||
)
|
||||
|
||||
# Filter out unsupported estimators for multi-target regression
|
||||
if is_multi_target and self.is_regression():
|
||||
# List of estimators that support multi-target regression natively
|
||||
multi_target_supported = ["xgboost", "xgb_limitdepth", "catboost"]
|
||||
original_len = len(estimator_list)
|
||||
estimator_list = [est for est in estimator_list if est in multi_target_supported]
|
||||
if len(estimator_list) == 0:
|
||||
raise ValueError(
|
||||
"Multi-target regression only supports estimators: xgboost, xgb_limitdepth, catboost. "
|
||||
"Non-supported estimators are removed. No estimator is left."
|
||||
)
|
||||
elif original_len != len(estimator_list):
|
||||
logger.warning(
|
||||
"Multi-target regression only supports estimators: xgboost, xgb_limitdepth, catboost. "
|
||||
"Non-supported estimators are removed."
|
||||
)
|
||||
return estimator_list
|
||||
if self.is_rank():
|
||||
estimator_list = ["lgbm", "xgboost", "xgb_limitdepth", "lgbm_spark"]
|
||||
@@ -897,6 +932,18 @@ class GenericTask(Task):
|
||||
for est in estimator_list
|
||||
if (est.endswith("_spark") if is_spark_dataframe else not est.endswith("_spark"))
|
||||
]
|
||||
|
||||
# Filter for multi-target regression support
|
||||
if is_multi_target and self.is_regression():
|
||||
# List of estimators that support multi-target regression natively
|
||||
multi_target_supported = ["xgboost", "xgb_limitdepth", "catboost"]
|
||||
estimator_list = [est for est in estimator_list if est in multi_target_supported]
|
||||
if len(estimator_list) == 0:
|
||||
raise ValueError(
|
||||
"Multi-target regression only supports estimators: xgboost, xgb_limitdepth, catboost. "
|
||||
"No supported estimator is available."
|
||||
)
|
||||
|
||||
return estimator_list
|
||||
|
||||
def default_metric(self, metric: str) -> str:
|
||||
|
||||
@@ -253,6 +253,7 @@ class Task(ABC):
|
||||
self,
|
||||
estimator_list: Union[List[str], str] = "auto",
|
||||
is_spark_dataframe: bool = False,
|
||||
is_multi_target: bool = False,
|
||||
) -> List[str]:
|
||||
"""Return the list of default estimators registered for this task type.
|
||||
|
||||
@@ -262,6 +263,7 @@ class Task(ABC):
|
||||
Args:
|
||||
estimator_list: Either 'auto' or a list of estimator names to be validated.
|
||||
is_spark_dataframe: True if the data is a spark dataframe.
|
||||
is_multi_target: True if the task involves multi-target regression.
|
||||
|
||||
Returns:
|
||||
A list of valid estimator names for this task type.
|
||||
|
||||
@@ -459,7 +459,7 @@ class TimeSeriesTask(Task):
|
||||
pred_time /= n
|
||||
return val_loss, metric, train_time, pred_time
|
||||
|
||||
def default_estimator_list(self, estimator_list: List[str], is_spark_dataframe: bool) -> List[str]:
|
||||
def default_estimator_list(self, estimator_list: List[str], is_spark_dataframe: bool, is_multi_target: bool = False) -> List[str]:
|
||||
assert not is_spark_dataframe, "Spark is not yet supported for time series"
|
||||
|
||||
# TODO: why not do this if/then in the calling function?
|
||||
|
||||
Reference in New Issue
Block a user