Skip to content

Commit 46b5929

Browse files
authored
Merge pull request #162 from DoubleML/p-suggest-change-pliv
Unit test for functional initializer for PLIV
2 parents ecd05c2 + bfc9c91 commit 46b5929

File tree

1 file changed

+63
-0
lines changed

1 file changed

+63
-0
lines changed
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
context("Unit tests for PLIV, partialling out X, Z, XZ")
2+
3+
lgr::get_logger("mlr3")$set_threshold("warn")
4+
5+
on_cran = !identical(Sys.getenv("NOT_CRAN"), "true")
6+
if (on_cran) {
7+
test_cases = expand.grid(
8+
learner = "regr.lm",
9+
dml_procedure = "dml2",
10+
score = "IV-type",
11+
stringsAsFactors = FALSE)
12+
} else {
13+
test_cases = expand.grid(
14+
learner = c("regr.lm", "regr.cv_glmnet"),
15+
dml_procedure = c("dml1", "dml2"),
16+
score = "IV-type",
17+
stringsAsFactors = FALSE)
18+
}
19+
test_cases[".test_name"] = apply(test_cases, 1, paste, collapse = "_")
20+
21+
patrick::with_parameters_test_that("Unit tests for PLIV (partialX functional initialization):",
22+
.cases = test_cases, {
23+
learner_pars = get_default_mlmethod_pliv(learner)
24+
df = data_pliv$df
25+
Xnames = names(df)[names(df) %in% c("y", "d", "z", "z2") == FALSE]
26+
data_ml = double_ml_data_from_data_frame(df,
27+
y_col = "y",
28+
d_cols = "d", x_cols = Xnames, z_cols = "z")
29+
30+
# Partial out X (default PLIV)
31+
set.seed(3141)
32+
double_mlpliv_obj = DoubleMLPLIV$new(data_ml,
33+
n_folds = 5,
34+
ml_l = learner_pars$ml_l$clone(),
35+
ml_m = learner_pars$ml_m$clone(),
36+
ml_r = learner_pars$ml_r$clone(),
37+
ml_g = learner_pars$ml_g$clone(),
38+
dml_procedure = dml_procedure,
39+
score = score)
40+
41+
double_mlpliv_obj$fit()
42+
theta_obj = double_mlpliv_obj$coef
43+
se_obj = double_mlpliv_obj$se
44+
45+
# Partial out X
46+
set.seed(3141)
47+
double_mlpliv_partX = DoubleMLPLIV.partialX(data_ml,
48+
n_folds = 5,
49+
ml_l = learner_pars$ml_l$clone(),
50+
ml_m = learner_pars$ml_m$clone(),
51+
ml_r = learner_pars$ml_r$clone(),
52+
ml_g = learner_pars$ml_g$clone(),
53+
dml_procedure = dml_procedure,
54+
score = score)
55+
56+
double_mlpliv_partX$fit()
57+
theta_partX = double_mlpliv_partX$coef
58+
se_partX = double_mlpliv_partX$se
59+
60+
expect_equal(theta_partX, theta_obj, tolerance = 1e-8)
61+
expect_equal(se_partX, se_obj, tolerance = 1e-8)
62+
}
63+
)

0 commit comments

Comments
 (0)