This commit is contained in:
skzhang1
2022-08-13 18:56:46 +00:00
38 changed files with 6350 additions and 2512 deletions

View File

@@ -37,6 +37,7 @@ from .model import (
ARIMA,
SARIMAX,
TransformersEstimator,
TemporalFusionTransformerEstimator,
TransformersEstimatorModelSelection,
)
from .data import CLASSIFICATION, group_counts, TS_FORECAST
@@ -122,6 +123,8 @@ def get_estimator_class(task, estimator_name):
estimator_class = SARIMAX
elif estimator_name == "transformer":
estimator_class = TransformersEstimator
elif estimator_name == "tft":
estimator_class = TemporalFusionTransformerEstimator
elif estimator_name == "transformer_ms":
estimator_class = TransformersEstimatorModelSelection
else:
@@ -473,7 +476,7 @@ def evaluate_model_CV(
"label_list"
) # pass the label list on to compute the evaluation metric
groups = None
shuffle = False if task in TS_FORECAST else True
shuffle = getattr(kf, "shuffle", task not in TS_FORECAST)
if isinstance(kf, RepeatedStratifiedKFold):
kf = kf.split(X_train_split, y_train_split)
elif isinstance(kf, GroupKFold):