|
12 | 12 | #' |
13 | 13 | #' @param x The output from [generate()] for computation-based inference or the |
14 | 14 | #' output from [hypothesize()] piped in to here for theory-based inference. |
15 | | -#' @param stat A string giving the type of the statistic to calculate. Current |
| 15 | +#' @param stat A string giving the type of the statistic to calculate or a |
| 16 | +#' function that takes in a replicate of `x` and returns a scalar value. Current |
16 | 17 | #' options include `"mean"`, `"median"`, `"sum"`, `"sd"`, `"prop"`, `"count"`, |
17 | 18 | #' `"diff in means"`, `"diff in medians"`, `"diff in props"`, `"Chisq"` (or |
18 | 19 | #' `"chisq"`), `"F"` (or `"f"`), `"t"`, `"z"`, `"ratio of props"`, `"slope"`, |
19 | 20 | #' `"odds ratio"`, `"ratio of means"`, and `"correlation"`. `infer` only |
20 | 21 | #' supports theoretical tests on one or two means via the `"t"` distribution |
21 | | -#' and one or two proportions via the `"z"`. |
| 22 | +#' and one or two proportions via the `"z"`. See the "Arbitrary test statistics" |
| 23 | +#' section below for more on how to define a custom statistic. |
22 | 24 | #' @param order A string vector of specifying the order in which the levels of |
23 | 25 | #' the explanatory variable should be ordered for subtraction (or division |
24 | 26 | #' for ratio-based statistics), where `order = c("first", "second")` means |
|
31 | 33 | #' |
32 | 34 | #' @return A tibble containing a `stat` column of calculated statistics. |
33 | 35 | #' |
| 36 | +#' @section Arbitrary test statistics: |
| 37 | +#' |
| 38 | +#' In addition to the pre-implemented statistics documented in `stat`, users can |
| 39 | +#' supply an arbitrary test statistic by supplying a function to the `stat` |
| 40 | +#' argument. |
| 41 | +#' |
| 42 | +#' The function should have arguments `stat(x, order, ...)`, where `x` is one |
| 43 | +#' replicate's worth of `x`. The `order` argument and ellipses will be supplied |
| 44 | +#' directly to the `stat` function. Internally, `calculate()` will split `x` up |
| 45 | +#' into data frames by replicate and pass them one-by-one to the supplied `stat`. |
| 46 | +#' For example, to implement `stat = "mean"` as a function, one could write: |
| 47 | +#' |
| 48 | +#' ```r |
| 49 | +#' stat_mean <- function(x, order, ...) {mean(x$hours)} |
| 50 | +#' obs_mean <- |
| 51 | +#' gss %>% |
| 52 | +#' specify(response = hours) %>% |
| 53 | +#' calculate(stat = stat_mean) |
| 54 | +#' |
| 55 | +#' set.seed(1) |
| 56 | +#' null_dist_mean <- |
| 57 | +#' gss %>% |
| 58 | +#' specify(response = hours) %>% |
| 59 | +#' hypothesize(null = "point", mu = 40) %>% |
| 60 | +#' generate(reps = 5, type = "bootstrap") %>% |
| 61 | +#' calculate(stat = stat_mean) |
| 62 | +#' ``` |
| 63 | +#' |
| 64 | +#' Note that the same `stat_mean` function is supplied to both `generate()`d and |
| 65 | +#' non-`generate()`d infer objects--no need to implement support for grouping |
| 66 | +#' by `replicate` yourself. |
| 67 | +#' |
34 | 68 | #' @section Missing levels in small samples: |
35 | 69 | #' In some cases, when bootstrapping with small samples, some generated |
36 | 70 | #' bootstrap samples will have only one level of the explanatory variable |
@@ -113,22 +147,23 @@ calculate <- function( |
113 | 147 | ) { |
114 | 148 | check_type(x, tibble::is_tibble) |
115 | 149 | check_if_mlr(x, "calculate") |
116 | | - stat <- check_calculate_stat(stat) |
117 | | - check_input_vs_stat(x, stat) |
118 | | - check_point_params(x, stat) |
| 150 | + stat_chr <- stat_chr(stat) |
| 151 | + stat_chr <- check_calculate_stat(stat_chr) |
| 152 | + check_input_vs_stat(x, stat_chr) |
| 153 | + check_point_params(x, stat_chr) |
119 | 154 |
|
120 | | - order <- check_order(x, order, in_calculate = TRUE, stat) |
| 155 | + order <- check_order(x, order, in_calculate = TRUE, stat_chr) |
121 | 156 |
|
122 | 157 | if (!is_generated(x)) { |
123 | 158 | x$replicate <- 1L |
124 | 159 | } |
125 | 160 |
|
126 | | - x <- message_on_excessive_null(x, stat = stat, fn = "calculate") |
127 | | - x <- warn_on_insufficient_null(x, stat, ...) |
| 161 | + x <- message_on_excessive_null(x, stat = stat_chr, fn = "calculate") |
| 162 | + x <- warn_on_insufficient_null(x, stat_chr, ...) |
128 | 163 |
|
129 | 164 | # Use S3 method to match correct calculation |
130 | 165 | result <- calc_impl( |
131 | | - structure(stat, class = gsub(" ", "_", stat)), |
| 166 | + structure(stat, class = gsub(" ", "_", stat_chr)), |
132 | 167 | x, |
133 | 168 | order, |
134 | 169 | ... |
@@ -165,8 +200,19 @@ check_if_mlr <- function(x, fn, call = caller_env()) { |
165 | 200 | } |
166 | 201 | } |
167 | 202 |
|
| 203 | +stat_chr <- function(stat) { |
| 204 | + if (rlang::is_function(stat)) { |
| 205 | + return("function") |
| 206 | + } |
| 207 | + |
| 208 | + stat |
| 209 | +} |
| 210 | + |
168 | 211 | check_calculate_stat <- function(stat, call = caller_env()) { |
169 | 212 | check_type(stat, rlang::is_string, call = call) |
| 213 | + if (identical(stat, "function")) { |
| 214 | + return(stat) |
| 215 | + } |
170 | 216 |
|
171 | 217 | # Check for possible `stat` aliases |
172 | 218 | alias_match_id <- match(stat, implemented_stats_aliases[["alias"]]) |
@@ -198,6 +244,10 @@ check_input_vs_stat <- function(x, stat, call = caller_env()) { |
198 | 244 | ) |
199 | 245 | } |
200 | 246 |
|
| 247 | + if (identical(stat, "function")) { |
| 248 | + return(x) |
| 249 | + } |
| 250 | + |
201 | 251 | if (!stat %in% possible_stats) { |
202 | 252 | if (has_explanatory(x)) { |
203 | 253 | msg_tail <- glue( |
@@ -276,7 +326,7 @@ warn_on_insufficient_null <- function(x, stat, ...) { |
276 | 326 | if ( |
277 | 327 | !is_hypothesized(x) && |
278 | 328 | !has_explanatory(x) && |
279 | | - !stat %in% untheorized_stats && |
| 329 | + !stat %in% c(untheorized_stats, "function") && |
280 | 330 | !(stat == "t" && "mu" %in% names(list(...))) |
281 | 331 | ) { |
282 | 332 | attr(x, "null") <- "point" |
@@ -660,3 +710,38 @@ calc_impl.z <- function(type, x, order, ...) { |
660 | 710 | df_out |
661 | 711 | } |
662 | 712 | } |
| 713 | + |
| 714 | +#' @export |
| 715 | +calc_impl.function <- function(type, x, order, ..., call = rlang::caller_env()) { |
| 716 | + rlang::try_fetch( |
| 717 | + { |
| 718 | + if (!identical(dplyr::group_vars(x), "replicate")) { |
| 719 | + x <- dplyr::group_by(x, replicate) |
| 720 | + } |
| 721 | + x_by_replicate <- dplyr::group_split(x) |
| 722 | + res <- purrr::map(x_by_replicate, ~type(.x, order, ...)) |
| 723 | + }, |
| 724 | + error = function(cnd) {rethrow_stat_cnd(cnd, call = call)}, |
| 725 | + warning = function(cnd) {rethrow_stat_cnd(cnd, call = call)} |
| 726 | + ) |
| 727 | + |
| 728 | + if (!rlang::is_scalar_atomic(res[[1]])) { |
| 729 | + cli::cli_abort( |
| 730 | + c( |
| 731 | + "The supplied {.arg stat} function must return a scalar value.", |
| 732 | + "i" = "It returned {.obj_type_friendly {res[[1]]}}." |
| 733 | + ), |
| 734 | + call = call |
| 735 | + ) |
| 736 | + } |
| 737 | + |
| 738 | + tibble::new_tibble(list(stat = unlist(res))) |
| 739 | +} |
| 740 | + |
| 741 | +rethrow_stat_cnd <- function(cnd, call = call) { |
| 742 | + cli::cli_abort( |
| 743 | + "The supplied {.arg stat} function encountered an issue.", |
| 744 | + parent = cnd, |
| 745 | + call = call |
| 746 | + ) |
| 747 | +} |
0 commit comments