@@ -471,3 +471,42 @@ test_that("select_best() and show_best() works", {
471471 dplyr :: select(num_clusters , .config )
472472 )
473473})
474+
475+ test_that(" doesn't error if recipes uses id variables" , {
476+ helper_objects <- helper_objects_tidyclust()
477+
478+ mtcars_id <- mtcars %> %
479+ tibble :: rownames_to_column(var = " model" )
480+
481+ rec_id <- recipes :: recipe(~ . , data = mtcars_id ) %> %
482+ recipes :: update_role(model , new_role = " id variable" ) %> %
483+ recipes :: step_normalize(recipes :: all_numeric_predictors())
484+
485+ set.seed(4400 )
486+ wflow <- workflows :: workflow() %> %
487+ workflows :: add_recipe(rec_id ) %> %
488+ workflows :: add_model(helper_objects $ kmeans_mod )
489+ pset <- hardhat :: extract_parameter_set_dials(wflow ) %> %
490+ update(num_clusters = dials :: num_clusters(c(1 , 3 )))
491+ grid <- dials :: grid_regular(pset , levels = 3 )
492+ folds <- rsample :: vfold_cv(mtcars_id , v = 2 )
493+ control <- tune :: control_grid(extract = identity )
494+ metrics <- cluster_metric_set(sse_within_total , sse_total )
495+
496+ res <- tune_cluster(
497+ wflow ,
498+ resamples = folds ,
499+ grid = grid ,
500+ control = control ,
501+ metrics = metrics
502+ )
503+ res_est <- tune :: collect_metrics(res )
504+ res_workflow <- res $ .extracts [[1 ]]$ .extracts [[1 ]]
505+
506+ expect_equal(res $ id , folds $ id )
507+ expect_equal(nrow(res_est ), nrow(grid ) * 2 )
508+ expect_equal(sum(res_est $ .metric == " sse_total" ), nrow(grid ))
509+ expect_equal(sum(res_est $ .metric == " sse_within_total" ), nrow(grid ))
510+ expect_equal(res_est $ n , rep(2 , nrow(grid ) * 2 ))
511+ expect_true(res_workflow $ trained )
512+ })
0 commit comments