mirror of
https://github.com/microsoft/FLAML.git
synced 2026-02-09 02:09:16 +08:00
update
This commit is contained in:
@@ -37,6 +37,7 @@ from .model import (
|
||||
ARIMA,
|
||||
SARIMAX,
|
||||
TransformersEstimator,
|
||||
TemporalFusionTransformerEstimator,
|
||||
TransformersEstimatorModelSelection,
|
||||
)
|
||||
from .data import CLASSIFICATION, group_counts, TS_FORECAST
|
||||
@@ -122,6 +123,8 @@ def get_estimator_class(task, estimator_name):
|
||||
estimator_class = SARIMAX
|
||||
elif estimator_name == "transformer":
|
||||
estimator_class = TransformersEstimator
|
||||
elif estimator_name == "tft":
|
||||
estimator_class = TemporalFusionTransformerEstimator
|
||||
elif estimator_name == "transformer_ms":
|
||||
estimator_class = TransformersEstimatorModelSelection
|
||||
else:
|
||||
@@ -473,7 +476,7 @@ def evaluate_model_CV(
|
||||
"label_list"
|
||||
) # pass the label list on to compute the evaluation metric
|
||||
groups = None
|
||||
shuffle = False if task in TS_FORECAST else True
|
||||
shuffle = getattr(kf, "shuffle", task not in TS_FORECAST)
|
||||
if isinstance(kf, RepeatedStratifiedKFold):
|
||||
kf = kf.split(X_train_split, y_train_split)
|
||||
elif isinstance(kf, GroupKFold):
|
||||
|
||||
Reference in New Issue
Block a user