Skip to content

Commit 0b26ffe

Browse files
committed
test that all levels are preserved with 1 row predictions
1 parent f6ff822 commit 0b26ffe

File tree

5 files changed

+82
-0
lines changed

5 files changed

+82
-0
lines changed

tests/testthat/test-hier_clust-stats.R

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,21 @@ test_that("predicting", {
2727
)
2828
})
2929

30+
test_that("all levels are preserved with 1 row predictions", {
31+
set.seed(1234)
32+
spec <- hier_clust(num_clusters = 3) %>%
33+
set_engine("stats")
34+
35+
res <- fit(spec, ~., mtcars)
36+
37+
preds <- predict(res, mtcars[1, ])
38+
39+
expect_identical(
40+
levels(preds$.pred_cluster),
41+
paste0("Cluster_", 1:3)
42+
)
43+
})
44+
3045
test_that("extract_centroids() works", {
3146
set.seed(1234)
3247
spec <- hier_clust(num_clusters = 3) %>%

tests/testthat/test-k_means-clustMixType.R

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,21 @@ test_that("predicting", {
4040
)
4141
})
4242

43+
test_that("all levels are preserved with 1 row predictions", {
44+
set.seed(1234)
45+
spec <- k_means(num_clusters = 3) %>%
46+
set_engine("clustMixType")
47+
48+
res <- fit(spec, ~., iris)
49+
50+
preds <- predict(res, iris[1, ])
51+
52+
expect_identical(
53+
levels(preds$.pred_cluster),
54+
paste0("Cluster_", 1:3)
55+
)
56+
})
57+
4358
test_that("extract_centroids() works", {
4459
skip_if_not_installed("clustMixType")
4560

tests/testthat/test-k_means-clusterR.R

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,21 @@ test_that("predicting", {
3232
)
3333
})
3434

35+
test_that("all levels are preserved with 1 row predictions", {
36+
set.seed(1234)
37+
spec <- k_means(num_clusters = 3) %>%
38+
set_engine("ClusterR")
39+
40+
res <- fit(spec, ~., mtcars)
41+
42+
preds <- predict(res, mtcars[1, ])
43+
44+
expect_identical(
45+
levels(preds$.pred_cluster),
46+
paste0("Cluster_", 1:3)
47+
)
48+
})
49+
3550
test_that("extract_centroids() works", {
3651
skip_if_not_installed("ClusterR")
3752

tests/testthat/test-k_means-klaR.R

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,28 @@ test_that("predicting", {
5151
)
5252
})
5353

54+
test_that("all levels are preserved with 1 row predictions", {
55+
skip_if_not_installed("klaR")
56+
skip_if_not_installed("modeldata")
57+
58+
data("ames", package = "modeldata")
59+
60+
ames_cat <- dplyr::select(ames, dplyr::where(is.factor))
61+
62+
set.seed(1234)
63+
spec <- k_means(num_clusters = 3) %>%
64+
set_engine("klaR")
65+
66+
res <- fit(spec, ~., ames_cat)
67+
68+
preds <- predict(res, ames_cat[1, ])
69+
70+
expect_identical(
71+
levels(preds$.pred_cluster),
72+
paste0("Cluster_", 1:3)
73+
)
74+
})
75+
5476
test_that("predicting ties argument works", {
5577
skip_if_not_installed("klaR")
5678

tests/testthat/test-k_means-stats.R

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,21 @@ test_that("predicting", {
2727
)
2828
})
2929

30+
test_that("all levels are preserved with 1 row predictions", {
31+
set.seed(1234)
32+
spec <- k_means(num_clusters = 3) %>%
33+
set_engine("stats")
34+
35+
res <- fit(spec, ~., mtcars)
36+
37+
preds <- predict(res, mtcars[1, ])
38+
39+
expect_identical(
40+
levels(preds$.pred_cluster),
41+
paste0("Cluster_", 1:3)
42+
)
43+
})
44+
3045
test_that("extract_centroids() works", {
3146
set.seed(1234)
3247
spec <- k_means(num_clusters = 3) %>%

0 commit comments

Comments
 (0)