Skip to content

Commit 3187b81

Browse files
Merge pull request #284 from tidymodels/sparse-tf
add sparse arg to step_tf()
2 parents 4e50607 + 74adf63 commit 3187b81

File tree

6 files changed

+220
-13
lines changed

6 files changed

+220
-13
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
S3method(.recipes_estimate_sparsity,step_dummy_hash)
44
S3method(.recipes_estimate_sparsity,step_texthash)
5+
S3method(.recipes_estimate_sparsity,step_tf)
56
S3method(bake,step_clean_levels)
67
S3method(bake,step_clean_names)
78
S3method(bake,step_dummy_hash)

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
* `step_texthash()` and `step_dummy_hash()` gained `sparse` argument. When set to `"yes"`, they will produce sparse vectors. (#282)
44

5+
* `step_tf()` gained `sparse` argument. When set to `"yes"`, `step_tf()` will produce sparse vectors. (#284)
6+
57
# textrecipes 1.0.7
68

79
## Improvements

R/tf.R

Lines changed: 94 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#' be stored here once this preprocessing step has been trained by
2121
#' [recipes::prep.recipe()].
2222
#' @template args-prefix
23+
#' @template args-sparse
2324
#' @template args-keep_original_cols
2425
#' @template args-skip
2526
#' @template args-id
@@ -69,6 +70,12 @@
6970
#' cat(result)
7071
#' ```
7172
#'
73+
#' @template sparse-creation
74+
#'
75+
#' @description
76+
#' `sparse = "yes"` doesn't take effect when
77+
#' `weight_scheme = "double normalization"` as it doesn't produce sparse data.
78+
#'
7279
#' @template case-weights-not-supported
7380
#'
7481
#' @seealso [step_tokenize()] to turn characters into [`tokens`][tokenlist()]
@@ -106,6 +113,7 @@ step_tf <-
106113
vocabulary = NULL,
107114
res = NULL,
108115
prefix = "tf",
116+
sparse = "auto",
109117
keep_original_cols = FALSE,
110118
skip = FALSE,
111119
id = rand_id("tf")
@@ -122,6 +130,7 @@ step_tf <-
122130
weight = weight,
123131
vocabulary = vocabulary,
124132
prefix = prefix,
133+
sparse = sparse,
125134
keep_original_cols = keep_original_cols,
126135
skip = skip,
127136
id = id
@@ -148,6 +157,7 @@ step_tf_new <-
148157
vocabulary,
149158
res,
150159
prefix,
160+
sparse,
151161
keep_original_cols,
152162
skip,
153163
id
@@ -163,6 +173,7 @@ step_tf_new <-
163173
vocabulary = vocabulary,
164174
res = res,
165175
prefix = prefix,
176+
sparse = sparse,
166177
keep_original_cols = keep_original_cols,
167178
skip = skip,
168179
id = id
@@ -197,6 +208,7 @@ prep.step_tf <- function(x, training, info = NULL, ...) {
197208
vocabulary = x$vocabulary,
198209
res = token_list,
199210
prefix = x$prefix,
211+
sparse = x$sparse,
200212
keep_original_cols = get_keep_original_cols(x),
201213
skip = x$skip,
202214
id = x$id
@@ -219,11 +231,16 @@ bake.step_tf <- function(object, new_data, ...) {
219231
object$res[[col_name]],
220232
paste0(object$prefix, "_", col_name),
221233
object$weight_scheme,
222-
object$weight
234+
object$weight,
235+
object$sparse
223236
)
224237

225238
if (object$weight_scheme %in% c("binary", "raw count")) {
226-
tf_text <- purrr::map_dfc(tf_text, as.integer)
239+
if (sparse_is_yes(object$sparse)) {
240+
tf_text <- purrr::map_dfc(tf_text, sparsevctrs::as_sparse_integer)
241+
} else {
242+
tf_text <- purrr::map_dfc(tf_text, as.integer)
243+
}
227244
}
228245

229246
tf_text <- recipes::check_name(tf_text, new_data, object, names(tf_text))
@@ -264,12 +281,21 @@ tidy.step_tf <- function(x, ...) {
264281
res
265282
}
266283

267-
# Build the term-frequency columns for a single tokenized variable.
#
# `data` and `names` are handed to `tokenlist_to_dtm()` to obtain the token
# counts, `labels` is the prefix pasted in front of every token name,
# `weights` is the weighting scheme, `weight` is forwarded to `tf_weight()`
# and `sparse` is the step's `sparse` setting.
#
# The result holds one column per entry of `names`. It is produced by the
# sparse code path only when `sparse` is "yes" and the scheme is not
# "double normalization"; every other combination takes the dense path and
# is converted with `as_tibble()`.
tf_function <- function(data, names, labels, weights, weight, sparse) {
  dtm <- tokenlist_to_dtm(data, names)
  new_names <- paste0(labels, "_", names)

  use_sparse <- weights != "double normalization" && sparse_is_yes(sparse)

  if (use_sparse) {
    sparse_counts <- sparsevctrs::coerce_to_sparse_tibble(dtm)
    weighted <- tf_weight_sparse(sparse_counts, weights)
    colnames(weighted) <- new_names
    return(weighted)
  }

  weighted <- tf_weight(as.matrix(dtm), weights, weight)
  colnames(weighted) <- new_names
  as_tibble(weighted)
}
274300

275301
tf_weight <- function(x, scheme, weight) {
@@ -294,6 +320,49 @@ tf_weight <- function(x, scheme, weight) {
294320
}
295321
}
296322

323+
# Apply a term-frequency weighting scheme to sparse token counts.
#
# `x` is a tibble whose columns are sparse count vectors and `scheme` is one
# of "binary", "raw count", "term frequency" or "log normalization".
# Returns a tibble with the same columns, re-weighted and still sparse.
# Any other `scheme` is an error; previously it fell through and silently
# returned `NULL`.
tf_weight_sparse <- function(x, scheme) {
  if (scheme == "raw count") {
    # Counts are already in their final form.
    return(x)
  }

  if (scheme == "binary") {
    res <- lapply(x, function(col) {
      positions <- sparsevctrs::sparse_positions(col)

      # Every stored position becomes an indicator value of 1.
      sparsevctrs::sparse_integer(
        rep(1L, length(positions)),
        positions,
        length(col)
      )
    })

    return(tibble::new_tibble(res))
  }

  if (scheme == "term frequency") {
    counts <- sparsevctrs::coerce_to_sparse_matrix(x)
    row_totals <- Matrix::rowSums(counts)

    # Rows with a zero total contain only zeros. Dividing them by 1 leaves
    # them at 0 and avoids producing 0 / 0 = NaN entries that would have to
    # be overwritten afterwards.
    row_totals[row_totals == 0] <- 1

    return(sparsevctrs::coerce_to_sparse_tibble(counts / row_totals))
  }

  if (scheme == "log normalization") {
    res <- lapply(x, function(col) {
      # log(1 + 0) is 0, so only the stored values need transforming.
      sparsevctrs::sparse_double(
        log(1 + sparsevctrs::sparse_values(col)),
        sparsevctrs::sparse_positions(col),
        length(col)
      )
    })

    return(tibble::new_tibble(res))
  }

  stop(
    "Unsupported `weight_scheme` for sparse output: \"", scheme, "\".",
    call. = FALSE
  )
}
365+
297366
#' @rdname required_pkgs.step
298367
#' @export
299368
required_pkgs.step_tf <- function(x, ...) {
@@ -314,3 +383,21 @@ tunable.step_tf <- function(x, ...) {
314383
component_id = x$id
315384
)
316385
}
386+
387+
#' @export
.recipes_estimate_sparsity.step_tf <- function(x, data, ...) {
  # Heuristic: the mean number of characters among (at most) the first 10
  # values of a column stands in for the number of columns it will produce.
  estimate_n_cols <- function(col) {
    # seq_len() keeps the sample empty for a zero-length column, where
    # seq(1, 0) would have produced the indices c(1, 0).
    sampled <- col[seq_len(min(10, length(col)))]
    n_cols <- floor(mean(nchar(sampled), na.rm = TRUE))

    # Empty columns, all-missing samples and empty strings would otherwise
    # give NaN or 0 columns and, in turn, a non-finite sparsity estimate.
    if (!is.finite(n_cols) || n_cols < 1) {
      n_cols <- 1
    }

    n_cols
  }

  lapply(data, function(col) {
    n_cols <- estimate_n_cols(col)

    c(
      n_cols = n_cols,
      sparsity = 1 - 1 / n_cols
    )
  })
}

R/tokenlist.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,19 +195,19 @@ tokenlist_filter_function <- function(x, fn) {
195195

196196
keeps <- lapply(tokens, fn)
197197

198-
out <- purrr::map2(tokens, keeps, ~.x[.y])
198+
out <- purrr::map2(tokens, keeps, ~ .x[.y])
199199

200200
lemma <- maybe_get_lemma(x)
201201
if (!is.null(lemma)) {
202-
lemma <- purrr::map2(lemma, keeps, ~.x[.y])
202+
lemma <- purrr::map2(lemma, keeps, ~ .x[.y])
203203
names(lemma) <- NULL
204204
} else {
205205
lemma <- NULL
206206
}
207207

208208
pos <- maybe_get_pos(x)
209209
if (!is.null(pos)) {
210-
pos <- purrr::map2(pos, keeps, ~.x[.y])
210+
pos <- purrr::map2(pos, keeps, ~ .x[.y])
211211
names(pos) <- NULL
212212
} else {
213213
pos <- NULL

man/step_tf.Rd

Lines changed: 24 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-tf.R

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ test_that("check_name() is used", {
134134
test_that("tunable", {
135135
rec <-
136136
recipe(~., data = mtcars) %>%
137-
step_tf(all_predictors())
137+
step_tf(all_predictors())
138138
rec_param <- tunable.step_tf(rec$steps[[1]])
139139
expect_equal(rec_param$name, c("weight_scheme", "weight"))
140140
expect_true(all(rec_param$source == "recipe"))
@@ -173,6 +173,101 @@ test_that("bad args", {
173173
)
174174
})
175175

176+
test_that("sparse = 'yes' works", {
  rec <- recipe(~., data = test_data)

  # Tokenize, apply step_tf() with the given settings, then prep and bake.
  bake_tf <- function(weight_scheme, sparse) {
    rec %>%
      step_tokenize(text) %>%
      step_tf(text, weight_scheme = weight_scheme, sparse = sparse) %>%
      prep() %>%
      bake(NULL)
  }

  # Count-based schemes should give sparse integers, the others sparse
  # doubles.
  sparse_checks <- list(
    "raw count" = sparsevctrs::is_sparse_integer,
    "binary" = sparsevctrs::is_sparse_integer,
    "term frequency" = sparsevctrs::is_sparse_double,
    "log normalization" = sparsevctrs::is_sparse_double
  )

  for (scheme in names(sparse_checks)) {
    is_sparse <- sparse_checks[[scheme]]

    dense <- bake_tf(scheme, "no")
    sparse <- bake_tf(scheme, "yes")

    expect_identical(dense, sparse)

    expect_false(any(vapply(dense, is_sparse, logical(1))))
    expect_true(all(vapply(sparse, is_sparse, logical(1))))
  }
})
243+
244+
test_that("sparse argument is backwards compatible", {
  rec <- recipe(~., data = test_data) %>%
    step_tokenize(text) %>%
    step_tf(text, sparse = "no") %>%
    prep()

  exp <- bake(rec, test_data)

  # Simulate old recipe: a step_tf() created before `sparse` existed has no
  # such element. step_tf() is the second step here (the first one is
  # step_tokenize()), so that is the step to strip the argument from.
  rec$steps[[2]]$sparse <- NULL

  expect_identical(
    bake(rec, test_data),
    exp
  )
})
260+
261+
test_that(".recipes_toggle_sparse_args works", {
  rec <- recipe(~., data = test_data) %>%
    step_tokenize(text) %>%
    step_tf(text, sparse = "auto")

  # Sparsity actually observed in the baked training data.
  baked <- bake(prep(rec), NULL)
  observed_sparsity <- sparsevctrs::sparsity(baked)

  # The estimate made from the unprepped recipe should not fall below it.
  estimated_sparsity <- .recipes_estimate_sparsity(rec)

  expect_true(estimated_sparsity >= observed_sparsity)
})
270+
176271
# Infrastructure ---------------------------------------------------------------
177272

178273
test_that("bake method errors when needed non-standard role columns are missing", {

0 commit comments

Comments
 (0)