Skip to content

Commit 15fbc48

Browse files
authored
Merge pull request #101 from DoubleML/m-cleanup-tests
Refactoring, cleanup and extension of the unit test framework
2 parents fb97d07 + f908aa3 commit 15fbc48

File tree

76 files changed

+4619
-4311
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

76 files changed

+4619
-4311
lines changed

R/double_ml.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -737,12 +737,12 @@ DoubleML = R6Class("DoubleML",
737737
#' a `matrix()` with adjusted p_values.
738738
p_adjust = function(method = "romano-wolf", return_matrix = TRUE) {
739739
if (all(is.na(self$coef))) {
740-
stop("apply fit() before p_adust().")
740+
stop("apply fit() before p_adjust().")
741741
}
742742

743743
if (tolower(method) %in% c("rw", "romano-wolf")) {
744744
if (is.null(self$boot_t_stat) | all(is.na(self$coef))) {
745-
stop("apply fit() & bootstrap() before p_adust().")
745+
stop("apply fit() & bootstrap() before p_adjust().")
746746
}
747747
k = self$data$n_treat
748748
pinit = p_val_corrected = vector(mode = "numeric", length = k)

R/double_ml_data.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,12 +245,12 @@ DoubleMLData = R6Class("DoubleMLData",
245245
if (any(z_cols %in% d_cols)) {
246246
stop(paste(
247247
"At least one variable/column is set as treatment",
248-
"variable ('d_cols') and instrumental variable in 'z_cols')."))
248+
"variable ('d_cols') and instrumental variable in 'z_cols'."))
249249
}
250250
if (any(z_cols %in% x_cols)) {
251251
stop(paste(
252252
"At least one variable/column is set as covariate ('x_cols')",
253-
"and instrumental variable in 'z_cols')."))
253+
"and instrumental variable in 'z_cols'."))
254254
}
255255
}
256256
}

R/double_ml_iivm.R

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -463,24 +463,35 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
463463
},
464464
check_data = function(obj_dml_data) {
465465
one_treat = (obj_dml_data$n_treat == 1)
466-
binary_treat = test_integerish(obj_dml_data$data[[obj_dml_data$d_cols]],
467-
lower = 0, upper = 1)
468-
if (!(one_treat & binary_treat)) {
469-
stop(paste(
470-
"Incompatible data.\n",
471-
"To fit an IIVM model with DoubleML",
472-
"exactly one binary variable with values 0 and 1",
473-
"needs to be specified as treatment variable."))
466+
err_msg = paste(
467+
"Incompatible data.\n",
468+
"To fit an IIVM model with DoubleML",
469+
"exactly one binary variable with values 0 and 1",
470+
"needs to be specified as treatment variable.")
471+
if (one_treat) {
472+
binary_treat = test_integerish(obj_dml_data$data[[obj_dml_data$d_cols]],
473+
lower = 0, upper = 1)
474+
if (!(one_treat & binary_treat)) {
475+
stop(err_msg)
476+
}
477+
} else {
478+
stop(err_msg)
474479
}
480+
475481
one_instr = (obj_dml_data$n_instr == 1)
476-
binary_instr = test_integerish(obj_dml_data$data[[obj_dml_data$z_cols]],
477-
lower = 0, upper = 1)
478-
if (!(one_instr & binary_instr)) {
479-
stop(paste(
480-
"Incompatible data.\n",
481-
"To fit an IIVM model with DoubleML",
482-
"exactly one binary variable with values 0 and 1",
483-
"needs to be specified as instrumental variable."))
482+
err_msg = paste(
483+
"Incompatible data.\n",
484+
"To fit an IIVM model with DoubleML",
485+
"exactly one binary variable with values 0 and 1",
486+
"needs to be specified as instrumental variable.")
487+
if (one_instr) {
488+
binary_instr = test_integerish(obj_dml_data$data[[obj_dml_data$z_cols]],
489+
lower = 0, upper = 1)
490+
if (!(one_instr & binary_instr)) {
491+
stop(err_msg)
492+
}
493+
} else {
494+
stop(err_msg)
484495
}
485496
return()
486497
}

R/double_ml_irm.R

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -351,9 +351,7 @@ DoubleMLIRM = R6Class("DoubleMLIRM",
351351
check_class(score, "function"))
352352
if (is.character(score)) {
353353
valid_score = c("ATE", "ATTE")
354-
if (!(score %in% valid_score)) {
355-
assertChoice(score, valid_score)
356-
}
354+
assertChoice(score, valid_score)
357355
}
358356
return()
359357
},
@@ -362,18 +360,23 @@ DoubleMLIRM = R6Class("DoubleMLIRM",
362360
stop(paste(
363361
"Incompatible data.\n", paste(obj_dml_data$z_cols, collapse = ", "),
364362
"has been set as instrumental variable(s).\n",
365-
"To fit an interactive IV regression model use DoubleMLIIVM
366-
instead of DoubleMLIRM."))
363+
"To fit an interactive IV regression model use DoubleMLIIVM",
364+
"instead of DoubleMLIRM."))
367365
}
368366
one_treat = (obj_dml_data$n_treat == 1)
369-
binary_treat = test_integerish(obj_dml_data$data[[obj_dml_data$d_cols]],
370-
lower = 0, upper = 1)
371-
if (!(one_treat & binary_treat)) {
372-
stop(paste(
373-
"Incompatible data.\n",
374-
"To fit an IRM model with DoubleML",
375-
"exactly one binary variable with values 0 and 1",
376-
"needs to be specified as treatment variable."))
367+
err_msg = paste(
368+
"Incompatible data.\n",
369+
"To fit an IRM model with DoubleML",
370+
"exactly one binary variable with values 0 and 1",
371+
"needs to be specified as treatment variable.")
372+
if (one_treat) {
373+
binary_treat = test_integerish(obj_dml_data$data[[obj_dml_data$d_cols]],
374+
lower = 0, upper = 1)
375+
if (!(one_treat & binary_treat)) {
376+
stop(err_msg)
377+
}
378+
} else {
379+
stop(err_msg)
377380
}
378381
return()
379382
}

R/double_ml_pliv.R

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -348,8 +348,8 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
348348
psi_b = psi_b)
349349
} else if (is.function(self$score)) {
350350
if (self$data$n_instr > 1) {
351-
stop("Callable score not implemented for DoubleMLPLIV with
352-
partialX=TRUE and partialZ=FALSE with several instruments")
351+
stop(paste("Callable score not implemented for DoubleMLPLIV with",
352+
"partialX=TRUE and partialZ=FALSE with several instruments."))
353353
}
354354
psis = self$score(y, z, d, g_hat, m_hat, r_hat, smpls)
355355
}
@@ -415,8 +415,8 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
415415
psi_a = psi_a,
416416
psi_b = psi_b)
417417
} else if (is.function(self$score)) {
418-
stop("Callable score not implemented for DoubleMLPLIV
419-
with partialX=TRUE and partialZ=TRUE.")
418+
stop(paste("Callable score not implemented for DoubleMLPLIV",
419+
"with partialX=TRUE and partialZ=TRUE."))
420420
# res = self$score(y, d, g_hat, m_hat, m_hat_tilde)
421421
}
422422
res$preds = list(
@@ -454,8 +454,8 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
454454
}
455455
res = list(psi_a = psi_a, psi_b = psi_b)
456456
} else if (is.function(self$score)) {
457-
stop("Callable score not implemented for DoubleMLPLIV
458-
with partialX=FALSE and partialZ=TRUE.")
457+
stop(paste("Callable score not implemented for DoubleMLPLIV",
458+
"with partialX=FALSE and partialZ=TRUE."))
459459
# res = self$score(y, z, d, r_hat)
460460
}
461461
res$preds = list("ml_r" = r_hat)
@@ -674,6 +674,13 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
674674
return()
675675
},
676676
check_data = function(obj_dml_data) {
677+
if (obj_dml_data$n_instr == 0) {
678+
stop(paste(
679+
"Incompatible data.\n",
680+
"At least one variable must be set as instrumental variable.\n",
681+
"To fit a partially linear regression model without instrumental",
682+
"variable(s) use DoubleMLPLR instead of DoubleMLPLIV."))
683+
}
677684
return()
678685
}
679686
)

R/double_ml_plr.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,8 +270,8 @@ DoubleMLPLR = R6Class("DoubleMLPLR",
270270
stop(paste(
271271
"Incompatible data.\n", paste(obj_dml_data$z_cols, collapse = ", "),
272272
"has been set as instrumental variable(s).\n",
273-
"To fit a partially linear IV regression model use
274-
DoubleMLPLIV instead of DoubleMLPLR."))
273+
"To fit a partially linear IV regression model use",
274+
"DoubleMLPLIV instead of DoubleMLPLR."))
275275
}
276276
return()
277277
}

tests/testthat/helper-01-dgp.R

Lines changed: 0 additions & 27 deletions
This file was deleted.
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Combine standard errors from several repetitions into a single value.
#
# Arguments:
#   se_s          numeric vector of standard errors (presumably one per
#                 repeated sample split -- confirm against callers).
#   coefficients  value the entries of `theta_s` are compared against.
#   theta_s       numeric vector of estimates, same length as `se_s`.
#
# Returns:
#   sqrt(median(se_s^2 + (theta_s - coefficients)^2)).
se_repeated = function(se_s, coefficients, theta_s) {
  squared_deviation = (theta_s - coefficients)^2
  adjusted_variance = se_s^2 + squared_deviation
  sqrt(stats::median(adjusted_variance))
}
5+
6+
7+
# Draw a k-fold cross-validation partition of the rows of `data` via mlr3.
#
# Arguments:
#   k     number of folds, written to the `folds` value of the resampling.
#   data  data handed to mlr3::Task$new() as backend; the task is only used
#         to instantiate the resampling.
#
# Returns:
#   list(train_ids = <list>, test_ids = <list>) with one element per
#   resampling iteration, holding the result of train_set() / test_set().
sample_splitting = function(k, data) {
  cv = mlr3::ResamplingCV$new()
  cv$param_set$values$folds = k

  # The task is a mere carrier for the rows; its id and type are placeholders.
  carrier_task = mlr3::Task$new("dummy_resampling", "regr", data)
  cv = cv$instantiate(carrier_task)

  fold_ids = seq_len(cv$iters)
  train_ids = lapply(fold_ids, function(fold) cv$train_set(fold))
  test_ids = lapply(fold_ids, function(fold) cv$test_set(fold))

  list(train_ids = train_ids, test_ids = test_ids)
}
21+
22+
23+
# Draw random multiplier-bootstrap weights.
#
# Arguments:
#   bootstrap   one of "Bayes", "normal" or "wild"; anything else raises
#               the error "invalid boot method".
#   n_rep_boot  number of rows of the returned matrix.
#   n_obs       number of columns of the returned matrix.
#
# Returns:
#   numeric matrix with n_rep_boot rows and n_obs columns, filled row by row
#   from n_rep_boot * n_obs draws:
#     "Bayes"  : rexp(rate = 1) - 1
#     "normal" : rnorm()
#     "wild"   : rnorm() / sqrt(2) + (rnorm()^2 - 1) / 2, using two
#                separate sets of normal draws.
draw_bootstrap_weights = function(bootstrap, n_rep_boot, n_obs) {
  n_draws = n_rep_boot * n_obs

  draws = switch(bootstrap,
    "Bayes" = stats::rexp(n_draws, rate = 1) - 1,
    "normal" = stats::rnorm(n_draws),
    "wild" = {
      # Order matters for reproducibility: the linear term is drawn first.
      linear_part = stats::rnorm(n_draws) / sqrt(2)
      quadratic_part = (stats::rnorm(n_draws)^2 - 1) / 2
      linear_part + quadratic_part
    },
    stop("invalid boot method")
  )

  matrix(draws, nrow = n_rep_boot, ncol = n_obs, byrow = TRUE)
}
37+
38+
39+
# Compute multiplier-bootstrap coefficients and t-statistics from scores.
#
# Arguments:
#   theta, k, smpls  not used by this function; kept so existing callers
#                    keep working.
#   se               value each bootstrapped coefficient is divided by to
#                    obtain the t-statistic.
#   psi              numeric vector of score values.
#   psi_a            numeric vector whose mean is used as scaling factor J.
#   n_rep_boot       number of bootstrap repetitions (may be 0).
#   weights          matrix with one row of weights per repetition; row i
#                    is multiplied elementwise with psi.
#
# Returns:
#   list(boot_coef, boot_t_stat), each a 1 x n_rep_boot numeric matrix with
#   boot_coef[1, i] = mean(weights[i, ] * 1 / J * psi) and
#   boot_t_stat[1, i] = boot_coef[1, i] / se.
functional_bootstrap = function(theta, se, psi, psi_a, k, smpls,
  n_rep_boot, weights) {
  J = mean(psi_a)

  # Preallocate as numeric so the result type does not depend on n_rep_boot.
  boot_coef = matrix(NA_real_, nrow = 1, ncol = n_rep_boot)
  boot_t_stat = matrix(NA_real_, nrow = 1, ncol = n_rep_boot)

  # seq_len() instead of seq(): seq(0) is c(1, 0) and would index a
  # non-existent repetition when n_rep_boot is 0.
  for (i_rep in seq_len(n_rep_boot)) {
    boot_coef[1, i_rep] = mean(weights[i_rep, ] * 1 / J * psi)
    boot_t_stat[1, i_rep] = boot_coef[1, i_rep] / se
  }

  list(boot_coef = boot_coef, boot_t_stat = boot_t_stat)
}
53+
54+
# Clip values to the interval [trimming_threshold, 1 - trimming_threshold].
#
# Arguments:
#   values              numeric vector to clip.
#   trimming_threshold  clipping bound; if it is not greater than 0 the
#                       input is returned unchanged.
#
# Returns:
#   `values` with entries below the threshold raised to it and entries
#   above 1 - threshold lowered to 1 - threshold.
trim_vec = function(values, trimming_threshold) {
  if (!(trimming_threshold > 0)) {
    return(values)
  }

  lower_bound = trimming_threshold
  upper_bound = 1 - trimming_threshold

  too_small = values < lower_bound
  values[too_small] = lower_bound

  # Evaluated after the lower clip, exactly like the sequential original.
  too_large = values > upper_bound
  values[too_large] = upper_bound

  return(values)
}

tests/testthat/helper-02-simdata.R

Lines changed: 0 additions & 29 deletions
This file was deleted.

tests/testthat/helper-03-dgp.R

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,41 @@ dgp1_toeplitz = function(n, p, betamax = 4, decay = 0.99, threshold = 0, noiseva
132132

133133
return(data)
134134
}
135+
136+
# Simulate data with many instruments z, covariates x, treatment d and
# outcome y.
#
# Arguments:
#   n_obs  number of observations (rows).
#   alpha  coefficient of d in the outcome equation.
#   dim_x  number of covariates.
#   dim_z  number of instruments; the construction of Pi below requires
#          dim_z >= dim_x.
#
# Returns:
#   data.frame with columns X1..X<dim_x>, y, d, Z1..Z<dim_z>.
make_data_pliv_partialZ = function(n_obs, alpha = 1, dim_x = 5, dim_z = 150) {
  # Error terms with unit variances and covariance 0.6:
  # column 1 enters y, column 2 enters d.
  noise_cov = matrix(c(1, 0.6, 0.6, 1), ncol = 2)
  noise = mvtnorm::rmvnorm(n = n_obs, mean = rep(0, 2), sigma = noise_cov)
  epsilon = noise[, 1]
  u = noise[, 2]

  # Covariates: mean zero, Toeplitz covariance built from 0.5^(0:(dim_x - 1)).
  x_cov = toeplitz(0.5^(0:(dim_x - 1)))
  x = mvtnorm::rmvnorm(n = n_obs, mean = rep(0, dim_x), sigma = x_cov)

  # Instrument noise: mean zero, covariance 0.25 * identity.
  xi_cov = 0.25 * diag(x = 1, ncol = dim_z, nrow = dim_z)
  xi = mvtnorm::rmvnorm(n = n_obs, mean = rep(0, dim_z), sigma = xi_cov)

  # Quadratically decaying coefficients.
  beta = 1 / seq_len(dim_x)^2
  gamma = beta
  delta = 1 / seq_len(dim_z)^2

  # Pi = [I | 0] copies each covariate into one of the first dim_x
  # instruments; the remaining instruments are pure noise.
  Pi = cbind(
    diag(x = 1, ncol = dim_x, nrow = dim_x),
    matrix(0, nrow = dim_x, ncol = (dim_z - dim_x)))

  z = x %*% Pi + xi
  d = x %*% gamma + z %*% delta + u
  y = alpha * d + x %*% beta + epsilon

  colnames(x) = paste0("X", seq_len(dim_x))
  colnames(z) = paste0("Z", seq_len(dim_z))
  colnames(y) = "y"
  colnames(d) = "d"

  data.frame(x, y, d, z)
}
172+

0 commit comments

Comments
 (0)