mirror of
https://github.com/microsoft/FLAML.git
synced 2026-02-09 02:09:16 +08:00
adding evaluation (#495)
* adding automl.score * fixing the metric name in train_with_config * adding pickle after score * fixing a bug in automl.pickle
This commit is contained in:
11
flaml/ml.py
11
flaml/ml.py
@@ -219,6 +219,13 @@ def is_in_sklearn_metric_name_set(metric_name):
|
||||
return metric_name.startswith("ndcg") or metric_name in sklearn_metric_name_set
|
||||
|
||||
|
||||
def is_min_metric(metric_name):
|
||||
return (
|
||||
metric_name in ["rmse", "mae", "mse", "log_loss", "mape"]
|
||||
or huggingface_metric_to_mode.get(metric_name, None) == "min"
|
||||
)
|
||||
|
||||
|
||||
def sklearn_metric_loss_score(
|
||||
metric_name,
|
||||
y_predict,
|
||||
@@ -565,6 +572,8 @@ def compute_estimator(
|
||||
|
||||
if isinstance(estimator, TransformersEstimator):
|
||||
fit_kwargs["metric"] = eval_metric
|
||||
fit_kwargs["X_val"] = X_val
|
||||
fit_kwargs["y_val"] = y_val
|
||||
|
||||
if "holdout" == eval_method:
|
||||
val_loss, metric_for_logging, train_time, pred_time = get_val_loss(
|
||||
@@ -633,7 +642,7 @@ def get_classification_objective(num_labels: int) -> str:
|
||||
if num_labels == 2:
|
||||
objective_name = "binary"
|
||||
else:
|
||||
objective_name = "multi"
|
||||
objective_name = "multiclass"
|
||||
return objective_name
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user