Skip to content

Commit 3d6d4f0

Browse files
authored
arbitrary test statistics in calculate() (#542)
1 parent f38e0b1 commit 3d6d4f0

File tree

9 files changed

+334
-14
lines changed

9 files changed

+334
-14
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Generated by roxygen2: do not edit by hand
22

3+
S3method(calc_impl,"function")
34
S3method(calc_impl,Chisq)
45
S3method(calc_impl,F)
56
S3method(calc_impl,correlation)

NEWS.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# infer (development version)
22

3+
* Introduced support for arbitrary test statistics in `calculate()`. In addition
4+
to the pre-implemented `calculate(stat)` options, taken as strings, users can
5+
now supply a function defining any scalar-valued test statistic. See
6+
`?calculate()` to learn more.
7+
38
# infer 1.0.9
49

510
* Replaced usage of deprecated functions ahead of a new release of the ggplot2 package (#557).
@@ -14,6 +19,7 @@
1419

1520
* Added missing commas and addressed formatting issues throughout the vignettes and articles. Backticks for package names were removed and missing parentheses for functions were added (@Joscelinrocha).
1621

22+
1723
# infer 1.0.7
1824

1925
* The aliases `p_value()` and `conf_int()`, first deprecated 6 years ago, now

R/calculate.R

Lines changed: 95 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,15 @@
1212
#'
1313
#' @param x The output from [generate()] for computation-based inference or the
1414
#' output from [hypothesize()] piped in to here for theory-based inference.
15-
#' @param stat A string giving the type of the statistic to calculate. Current
15+
#' @param stat A string giving the type of the statistic to calculate or a
16+
#' function that takes in a replicate of `x` and returns a scalar value. Current
1617
#' options include `"mean"`, `"median"`, `"sum"`, `"sd"`, `"prop"`, `"count"`,
1718
#' `"diff in means"`, `"diff in medians"`, `"diff in props"`, `"Chisq"` (or
1819
#' `"chisq"`), `"F"` (or `"f"`), `"t"`, `"z"`, `"ratio of props"`, `"slope"`,
1920
#' `"odds ratio"`, `"ratio of means"`, and `"correlation"`. `infer` only
2021
#' supports theoretical tests on one or two means via the `"t"` distribution
21-
#' and one or two proportions via the `"z"`.
22+
#' and one or two proportions via the `"z"`. See the "Arbitrary test statistics"
23+
#' section below for more on how to define a custom statistic.
2224
#' @param order A string vector specifying the order in which the levels of
2325
#' the explanatory variable should be ordered for subtraction (or division
2426
#' for ratio-based statistics), where `order = c("first", "second")` means
@@ -31,6 +33,38 @@
3133
#'
3234
#' @return A tibble containing a `stat` column of calculated statistics.
3335
#'
36+
#' @section Arbitrary test statistics:
37+
#'
38+
#' In addition to the pre-implemented statistics documented in `stat`, users can
39+
#' define an arbitrary test statistic by passing a function to the `stat`
40+
#' argument.
41+
#'
42+
#' The function should have arguments `stat(x, order, ...)`, where `x` is one
43+
#' replicate's worth of `x`. The `order` argument and ellipses will be supplied
44+
#' directly to the `stat` function. Internally, `calculate()` will split `x` up
45+
#' into data frames by replicate and pass them one-by-one to the supplied `stat`.
46+
#' For example, to implement `stat = "mean"` as a function, one could write:
47+
#'
48+
#' ```r
49+
#' stat_mean <- function(x, order, ...) {mean(x$hours)}
50+
#' obs_mean <-
51+
#' gss %>%
52+
#' specify(response = hours) %>%
53+
#' calculate(stat = stat_mean)
54+
#'
55+
#' set.seed(1)
56+
#' null_dist_mean <-
57+
#' gss %>%
58+
#' specify(response = hours) %>%
59+
#' hypothesize(null = "point", mu = 40) %>%
60+
#' generate(reps = 5, type = "bootstrap") %>%
61+
#' calculate(stat = stat_mean)
62+
#' ```
63+
#'
64+
#' Note that the same `stat_mean` function is supplied to both `generate()`d and
65+
#' non-`generate()`d infer objects--no need to implement support for grouping
66+
#' by `replicate` yourself.
67+
#'
3468
#' @section Missing levels in small samples:
3569
#' In some cases, when bootstrapping with small samples, some generated
3670
#' bootstrap samples will have only one level of the explanatory variable
@@ -113,22 +147,23 @@ calculate <- function(
113147
) {
114148
check_type(x, tibble::is_tibble)
115149
check_if_mlr(x, "calculate")
116-
stat <- check_calculate_stat(stat)
117-
check_input_vs_stat(x, stat)
118-
check_point_params(x, stat)
150+
stat_chr <- stat_chr(stat)
151+
stat_chr <- check_calculate_stat(stat_chr)
152+
check_input_vs_stat(x, stat_chr)
153+
check_point_params(x, stat_chr)
119154

120-
order <- check_order(x, order, in_calculate = TRUE, stat)
155+
order <- check_order(x, order, in_calculate = TRUE, stat_chr)
121156

122157
if (!is_generated(x)) {
123158
x$replicate <- 1L
124159
}
125160

126-
x <- message_on_excessive_null(x, stat = stat, fn = "calculate")
127-
x <- warn_on_insufficient_null(x, stat, ...)
161+
x <- message_on_excessive_null(x, stat = stat_chr, fn = "calculate")
162+
x <- warn_on_insufficient_null(x, stat_chr, ...)
128163

129164
# Use S3 method to match correct calculation
130165
result <- calc_impl(
131-
structure(stat, class = gsub(" ", "_", stat)),
166+
structure(stat, class = gsub(" ", "_", stat_chr)),
132167
x,
133168
order,
134169
...
@@ -165,8 +200,19 @@ check_if_mlr <- function(x, fn, call = caller_env()) {
165200
}
166201
}
167202

203+
# Reduce `stat` to the string form used for validation and S3 dispatch in
# `calculate()`: a user-supplied function is represented by the sentinel
# "function"; anything else is passed through untouched.
stat_chr <- function(stat) {
  if (!rlang::is_function(stat)) {
    return(stat)
  }

  "function"
}
210+
168211
check_calculate_stat <- function(stat, call = caller_env()) {
169212
check_type(stat, rlang::is_string, call = call)
213+
if (identical(stat, "function")) {
214+
return(stat)
215+
}
170216

171217
# Check for possible `stat` aliases
172218
alias_match_id <- match(stat, implemented_stats_aliases[["alias"]])
@@ -198,6 +244,10 @@ check_input_vs_stat <- function(x, stat, call = caller_env()) {
198244
)
199245
}
200246

247+
if (identical(stat, "function")) {
248+
return(x)
249+
}
250+
201251
if (!stat %in% possible_stats) {
202252
if (has_explanatory(x)) {
203253
msg_tail <- glue(
@@ -276,7 +326,7 @@ warn_on_insufficient_null <- function(x, stat, ...) {
276326
if (
277327
!is_hypothesized(x) &&
278328
!has_explanatory(x) &&
279-
!stat %in% untheorized_stats &&
329+
!stat %in% c(untheorized_stats, "function") &&
280330
!(stat == "t" && "mu" %in% names(list(...)))
281331
) {
282332
attr(x, "null") <- "point"
@@ -660,3 +710,38 @@ calc_impl.z <- function(type, x, order, ...) {
660710
df_out
661711
}
662712
}
713+
714+
# `calc_impl()` method for user-supplied statistic functions.
#
# `type` is the function passed as `calculate(stat = )`, `x` is the data
# carrying a `replicate` column, and `order` plus `...` are handed on to
# `type` unchanged. `x` is split by `replicate` and `type` is called once per
# split. Returns a tibble with a single `stat` column holding one value per
# replicate.
#
# Any error *or warning* signalled while evaluating `type` is rethrown as an
# error attributed to `call`. An error is also raised if any replicate's
# result is not a scalar atomic value.
#' @export
calc_impl.function <- function(type, x, order, ..., call = rlang::caller_env()) {
  # A plain closure (rather than a purrr `~` lambda) is used on purpose:
  # inside a formula lambda `...` refers to the lambda's own arguments, so
  # the dots of this method would never reach `type`.
  apply_stat <- function(replicate_data) {
    type(replicate_data, order, ...)
  }

  rlang::try_fetch(
    {
      if (!identical(dplyr::group_vars(x), "replicate")) {
        x <- dplyr::group_by(x, replicate)
      }
      x_by_replicate <- dplyr::group_split(x)
      res <- purrr::map(x_by_replicate, apply_stat)
    },
    error = function(cnd) {
      rethrow_stat_cnd(cnd, call = call)
    },
    warning = function(cnd) {
      rethrow_stat_cnd(cnd, call = call)
    }
  )

  # Validate every replicate's result, not only the first: a single
  # non-scalar value would otherwise make `unlist()` return a vector whose
  # length no longer matches the number of replicates.
  is_scalar <- purrr::map_lgl(res, rlang::is_scalar_atomic)
  if (!all(is_scalar)) {
    bad_value <- res[[which(!is_scalar)[[1]]]]
    cli::cli_abort(
      c(
        "The supplied {.arg stat} function must return a scalar value.",
        "i" = "It returned {.obj_type_friendly {bad_value}}."
      ),
      call = call
    )
  }

  tibble::new_tibble(list(stat = unlist(res)))
}
740+
741+
# Re-signal a condition raised by a user-supplied `stat` function as an
# error, keeping the original condition attached as the parent.
#
# `cnd` is the caught condition. `call` is the call or environment the new
# error is attributed to; it defaults to the caller's environment. (A default
# of `call = call` would be a self-referential promise and fail with
# "promise already under evaluation" whenever `call` was omitted.)
rethrow_stat_cnd <- function(cnd, call = rlang::caller_env()) {
  cli::cli_abort(
    "The supplied {.arg stat} function encountered an issue.",
    parent = cnd,
    call = call
  )
}

R/observe.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
#'
1616
#' @return A 1-column tibble containing the calculated statistic `stat`.
1717
#'
18+
#' @inheritSection calculate Arbitrary test statistics
19+
#'
1820
#' @examples
1921
#' # calculating the observed mean number of hours worked per week
2022
#' gss |>

man/calculate.Rd

Lines changed: 37 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/observe.Rd

Lines changed: 37 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/_snaps/calculate.md

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,3 +519,49 @@
519519
! Multiple explanatory variables are not supported in `calculate()`.
520520
i When working with multiple explanatory variables, use `fit()` (`?infer::fit.infer()`) instead.
521521

522+
# arbitrary test statistic works
523+
524+
Code
525+
calculate(specify(gss, response = hours), stat = function(x, ...) {
526+
mean(x$hour)
527+
})
528+
Condition
529+
Error in `calculate()`:
530+
! The supplied `stat` function encountered an issue.
531+
Caused by warning:
532+
! Unknown or uninitialised column: `hour`.
533+
534+
---
535+
536+
Code
537+
calculate(specify(gss, response = hours), stat = function(x, ...) {
538+
mean("hey there")
539+
})
540+
Condition
541+
Error in `calculate()`:
542+
! The supplied `stat` function encountered an issue.
543+
Caused by warning in `mean.default()`:
544+
! argument is not numeric or logical: returning NA
545+
546+
---
547+
548+
Code
549+
calculate(specify(gss, response = hours), stat = function(x, ...) {
550+
data.frame(woops = mean(x$hours))
551+
})
552+
Condition
553+
Error in `calculate()`:
554+
! The supplied `stat` function must return a scalar value.
555+
i It returned a data frame.
556+
557+
---
558+
559+
Code
560+
calculate(specify(gss, response = hours), stat = function(x, ...) {
561+
identity
562+
})
563+
Condition
564+
Error in `calculate()`:
565+
! The supplied `stat` function must return a scalar value.
566+
i It returned a function.
567+

0 commit comments

Comments
 (0)