Skip to content

Commit cccf0e9

Browse files
Merge pull request #162 from tidymodels/native-ordering
2 parents 95f97fa + 9ebf11d commit cccf0e9

13 files changed

+152
-59
lines changed

DESCRIPTION

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ Suggests:
3939
ClusterR,
4040
clustMixType (>= 0.3-5),
4141
covr,
42+
klaR,
4243
knitr,
4344
modeldata (>= 1.0.0),
4445
RcppHungarian,

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ S3method(update,k_means)
6565
export("%>%")
6666
export(.hier_clust_fit_stats)
6767
export(.k_means_fit_ClusterR)
68+
export(.k_means_fit_clustMixType)
69+
export(.k_means_fit_klaR)
6870
export(.k_means_fit_stats)
6971
export(augment)
7072
export(cluster_metric_set)

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
* The klaR engine has been added to `k_means()`. This engine allows fitting of k-modes models. (#63)
2020

21+
* Cluster reordering is now done at fitting time, rather than at extraction and prediction time. (#154)
22+
2123
# tidyclust 0.1.2
2224

2325
* The cluster specification methods for `generics::tune_args()` and `generics::tunable()` are now registered unconditionally (#115).

R/extract_fit_summary.R

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -46,23 +46,21 @@ extract_fit_summary.workflow <- function(object, ...) {
4646

4747
#' @export
4848
extract_fit_summary.kmeans <- function(object, ..., prefix = "Cluster_") {
49-
reorder_clusts <- order(unique(object$cluster))
50-
names <- paste0(prefix, seq_len(nrow(object$centers)))
49+
names <- paste0(prefix, seq_along(object$size))
5150
names <- factor(names)
5251

5352
cluster_asignments <- factor(
54-
names[reorder_clusts][object$cluster],
53+
names[object$cluster],
5554
levels = levels(names)
5655
)
5756

58-
centroids <- object$centers[reorder_clusts, , drop = FALSE]
59-
centroids <- tibble::as_tibble(centroids)
57+
centroids <- tibble::as_tibble(object$centers)
6058

6159
list(
6260
cluster_names = names,
6361
centroids = centroids,
64-
n_members = object$size[unique(object$cluster)],
65-
sse_within_total_total = object$withinss[unique(object$cluster)],
62+
n_members = object$size,
63+
sse_within_total_total = object$withinss,
6664
sse_total = object$totss,
6765
orig_labels = unname(object$cluster),
6866
cluster_assignments = cluster_asignments
@@ -73,23 +71,21 @@ extract_fit_summary.kmeans <- function(object, ..., prefix = "Cluster_") {
7371
extract_fit_summary.KMeansCluster <- function(object,
7472
...,
7573
prefix = "Cluster_") {
76-
reorder_clusts <- order(unique(object$cluster))
7774
names <- paste0(prefix, seq_len(nrow(object$centroids)))
7875
names <- factor(names)
7976

8077
cluster_asignments <- factor(
81-
names[reorder_clusts][object$clusters],
78+
names[object$clusters],
8279
levels = levels(names)
8380
)
8481

85-
centroids <- object$centroids[reorder_clusts, , drop = FALSE]
86-
centroids <- tibble::as_tibble(centroids)
82+
centroids <- tibble::as_tibble(object$centroids)
8783

8884
list(
8985
cluster_names = names,
9086
centroids = centroids,
91-
n_members = as.integer(object$obs_per_cluster[unique(object$cluster)]),
92-
sse_within_total_total = object$WCSS_per_cluster[unique(object$cluster)],
87+
n_members = as.integer(object$obs_per_cluster),
88+
sse_within_total_total = as.numeric(object$WCSS_per_cluster),
9389
sse_total = object$total_SSE,
9490
orig_labels = object$clusters,
9591
cluster_assignments = cluster_asignments
@@ -100,23 +96,21 @@ extract_fit_summary.KMeansCluster <- function(object,
10096
extract_fit_summary.kproto <- function(object,
10197
...,
10298
prefix = "Cluster_") {
103-
reorder_clusts <- order(unique(object$cluster))
10499
names <- paste0(prefix, seq_len(nrow(object$centers)))
105100
names <- factor(names)
106101

107102
cluster_asignments <- factor(
108-
names[reorder_clusts][object$cluster],
103+
names[object$cluster],
109104
levels = levels(names)
110105
)
111106

112-
centroids <- object$centers[reorder_clusts, , drop = FALSE]
113-
centroids <- tibble::as_tibble(centroids)
107+
centroids <- tibble::as_tibble(object$centers)
114108

115109
list(
116110
cluster_names = names,
117111
centroids = centroids,
118-
n_members = as.integer(object$size[unique(object$cluster)]),
119-
sse_within_total_total = object$withinss[unique(object$cluster)],
112+
n_members = as.integer(object$size),
113+
sse_within_total_total = object$withinss,
120114
sse_total = object$tot.withinss,
121115
orig_labels = seq_len(length(table(object$cluster))),
122116
cluster_assignments = cluster_asignments
@@ -127,23 +121,21 @@ extract_fit_summary.kproto <- function(object,
127121
extract_fit_summary.kmodes <- function(object,
128122
...,
129123
prefix = "Cluster_") {
130-
reorder_clusts <- order(unique(object$cluster))
131124
names <- paste0(prefix, seq_len(nrow(object$modes)))
132125
names <- factor(names)
133126

134127
cluster_asignments <- factor(
135-
names[reorder_clusts][object$cluster],
128+
names[object$cluster],
136129
levels = levels(names)
137130
)
138131

139-
centroids <- object$modes[reorder_clusts, , drop = FALSE]
140-
centroids <- tibble::as_tibble(centroids)
132+
centroids <- tibble::as_tibble(object$modes)
141133

142134
list(
143135
cluster_names = names,
144136
centroids = centroids,
145-
n_members = as.integer(object$size[unique(object$cluster)]),
146-
sse_within_total_total = object$withinss[unique(object$cluster)],
137+
n_members = as.integer(object$size),
138+
sse_within_total_total = object$withinss,
147139
sse_total = object$tot.withinss,
148140
orig_labels = seq_len(length(table(object$cluster))),
149141
cluster_assignments = cluster_asignments

R/k_means.R

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,8 @@ check_args.k_means <- function(object) {
131131

132132
#' Simple Wrapper around ClusterR kmeans
133133
#'
134-
#' This wrapper runs `ClusterR::KMeans_rcpp` and adds column names to the
135-
#' `centroids` field.
134+
#' This wrapper runs `ClusterR::KMeans_rcpp()` and adds column names to the
135+
#' `centroids` field. It also reorders the clusters.
136136
#'
137137
#' @param data matrix or data frame
138138
#' @param clusters the number of clusters
@@ -160,7 +160,8 @@ check_args.k_means <- function(object) {
160160
#' obs_per_cluster, between.SS_DIV_total.SS
161161
#' @keywords internal
162162
#' @export
163-
.k_means_fit_ClusterR <- function(data, clusters,
163+
.k_means_fit_ClusterR <- function(data,
164+
clusters,
164165
num_init = 1,
165166
max_iters = 100,
166167
initializer = "kmeans++",
@@ -190,14 +191,21 @@ check_args.k_means <- function(object) {
190191
tol_optimal_init = tol_optimal_init,
191192
seed = seed
192193
)
194+
193195
colnames(res$centroids) <- colnames(data)
196+
197+
new_order <- unique(res$clusters)
198+
res$clusters <- order(new_order)[res$clusters]
199+
res$centroids <- res$centroids[new_order, , drop = FALSE]
200+
res$WCSS_per_cluster <- res$WCSS_per_cluster[, new_order, drop = FALSE]
201+
res$obs_per_cluster <- res$obs_per_cluster[, new_order, drop = FALSE]
194202
res
195203
}
196204

197205
#' Simple Wrapper around stats kmeans
198206
#'
199-
#' This wrapper runs `stats::kmeans` and adds a check that `centers` is
200-
#' specified
207+
#' This wrapper runs `stats::kmeans()` and adds a check that `centers` is
208+
#' specified. It also reorders the clusters.
201209
#'
202210
#' @inheritParams stats::kmeans
203211
#' @param ... Other arguments passed to `stats::kmeans()`
@@ -213,5 +221,52 @@ check_args.k_means <- function(object) {
213221
)
214222
}
215223

216-
stats::kmeans(data, centers, ...)
224+
res <- stats::kmeans(data, centers, ...)
225+
new_order <- unique(res$cluster)
226+
res$cluster <- set_names(order(new_order)[res$cluster], names(res$cluster))
227+
res$centers <- res$centers[new_order, , drop = FALSE]
228+
res$withinss <- res$withinss[new_order]
229+
res$size <- res$size[new_order]
230+
res
231+
}
232+
233+
#' Simple Wrapper around clustMixType kproto
234+
#'
235+
#' This wrapper runs `clustMixType::kproto()` and reorders the clusters.
236+
#'
237+
#' @inheritParams clustMixType::kproto
238+
#' @param ... Other arguments passed to `clustMixType::kproto()`
239+
#'
240+
#' @return Result from `clustMixType::kproto()`
241+
#' @keywords internal
242+
#' @export
243+
.k_means_fit_clustMixType <- function(x, k, ...) {
244+
res <- clustMixType::kproto(x, k, ...)
245+
new_order <- unique(res$cluster)
246+
res$cluster <- order(new_order)[res$cluster]
247+
res$centers <- res$centers[new_order, , drop = FALSE]
248+
res$withinss <- res$withinss[new_order]
249+
res$dists <- res$dists[, new_order, drop = FALSE]
250+
res$size <- res$size[new_order]
251+
res
252+
}
253+
254+
#' Simple Wrapper around klaR kmodes
255+
#'
256+
#' This wrapper runs `klaR::kmodes()` and reorders the clusters.
257+
#'
258+
#' @inheritParams klaR::kmodes
259+
#' @param ... Other arguments passed to `klaR::kmodes()`
260+
#'
261+
#' @return Result from `klaR::kmodes()`
262+
#' @keywords internal
263+
#' @export
264+
.k_means_fit_klaR <- function(data, modes, ...) {
265+
res <- klaR::kmodes(data, modes, ...)
266+
new_order <- unique(res$cluster)
267+
res$cluster <- order(new_order)[res$cluster]
268+
res$size <- res$size[new_order]
269+
res$modes <- res$modes[new_order, , drop = FALSE]
270+
res$withindiff <- res$withindiff[new_order]
271+
res
217272
}

R/k_means_data.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ make_k_means <- function() {
163163
interface = "data.frame",
164164
data = c(x = "x"),
165165
protect = c("x", "k", "keep.data"),
166-
func = c(pkg = "clustMixType", fun = "kproto"),
166+
func = c(pkg = "tidyclust", fun = ".k_means_fit_clustMixType"),
167167
defaults = list(keep.data = TRUE, verbose = FALSE)
168168
)
169169
)
@@ -230,7 +230,7 @@ make_k_means <- function() {
230230
interface = "data.frame",
231231
data = c(x = "data"),
232232
protect = c("data", "modes"),
233-
func = c(pkg = "klaR", fun = "kmodes"),
233+
func = c(pkg = "tidyclust", fun = ".k_means_fit_klaR"),
234234
defaults = list()
235235
)
236236
)

R/predict_helpers.R

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,28 @@
1+
make_predictions <- function(x, prefix, n_clusters) {
2+
levels <- seq_len(n_clusters)
3+
factor(x, levels = levels, labels = paste0(prefix, levels))
4+
}
5+
16
.k_means_predict_stats <- function(object, new_data, prefix = "Cluster_") {
2-
res <- object$centers[unique(object$cluster), , drop = FALSE]
7+
res <- object$centers
38
res <- flexclust::dist2(res, new_data)
49
res <- apply(res, 2, which.min)
5-
res <- paste0(prefix, res)
6-
factor(res)
10+
11+
make_predictions(res, prefix, length(object$size))
712
}
813

914
.k_means_predict_ClusterR <- function(object, new_data, prefix = "Cluster_") {
1015
clusters <- predict(object, new_data)
1116
n_clusters <- length(object$obs_per_cluster)
1217

13-
reorder_clusts <- order(union(unique(clusters), seq_len(n_clusters)))
14-
names <- paste0(prefix, seq_len(n_clusters))
15-
res <- names[reorder_clusts][clusters]
16-
17-
factor(res)
18+
make_predictions(clusters, prefix, n_clusters)
1819
}
1920

2021
.k_means_predict_clustMixType <- function(object, new_data, prefix = "Cluster_") {
2122
clusters <- predict(object, new_data)$cluster
2223
n_clusters <- length(object$size)
2324

24-
reorder_clusts <- order(union(unique(clusters), seq_len(n_clusters)))
25-
names <- paste0(prefix, seq_len(n_clusters))
26-
res <- names[reorder_clusts][clusters]
27-
28-
factor(res)
25+
make_predictions(clusters, prefix, n_clusters)
2926
}
3027

3128
.k_means_predict_klaR <- function(object, new_data, prefix = "Cluster_",
@@ -58,9 +55,7 @@
5855
}
5956
}
6057

61-
names <- paste0(prefix, seq_len(n_modes))
62-
63-
factor(names[clusters], levels = names)
58+
make_predictions(clusters, prefix, n_modes)
6459
}
6560

6661
.hier_clust_predict_stats <- function(object, new_data, ..., prefix = "Cluster_") {

man/dot-k_means_fit_ClusterR.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/dot-k_means_fit_clustMixType.Rd

Lines changed: 22 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/dot-k_means_fit_klaR.Rd

Lines changed: 24 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)