|
| 1 | +context("Unit tests for PLIV, partialling out X, Z, XZ") |
| 2 | + |
| 3 | +lgr::get_logger("mlr3")$set_threshold("warn") |
| 4 | + |
| 5 | +on_cran = !identical(Sys.getenv("NOT_CRAN"), "true") |
| 6 | +if (on_cran) { |
| 7 | + test_cases = expand.grid( |
| 8 | + learner = "regr.lm", |
| 9 | + dml_procedure = "dml2", |
| 10 | + score = "IV-type", |
| 11 | + stringsAsFactors = FALSE) |
| 12 | +} else { |
| 13 | + test_cases = expand.grid( |
| 14 | + learner = c("regr.lm", "regr.cv_glmnet"), |
| 15 | + dml_procedure = c("dml1", "dml2"), |
| 16 | + score = "IV-type", |
| 17 | + stringsAsFactors = FALSE) |
| 18 | +} |
| 19 | +test_cases[".test_name"] = apply(test_cases, 1, paste, collapse = "_") |
| 20 | + |
| 21 | +patrick::with_parameters_test_that("Unit tests for PLIV (partialX functional initialization):", |
| 22 | + .cases = test_cases, { |
| 23 | + learner_pars = get_default_mlmethod_pliv(learner) |
| 24 | + df = data_pliv$df |
| 25 | + Xnames = names(df)[names(df) %in% c("y", "d", "z", "z2") == FALSE] |
| 26 | + data_ml = double_ml_data_from_data_frame(df, |
| 27 | + y_col = "y", |
| 28 | + d_cols = "d", x_cols = Xnames, z_cols = "z") |
| 29 | + |
| 30 | + # Partial out X (default PLIV) |
| 31 | + set.seed(3141) |
| 32 | + double_mlpliv_obj = DoubleMLPLIV$new(data_ml, |
| 33 | + n_folds = 5, |
| 34 | + ml_l = learner_pars$ml_l$clone(), |
| 35 | + ml_m = learner_pars$ml_m$clone(), |
| 36 | + ml_r = learner_pars$ml_r$clone(), |
| 37 | + ml_g = learner_pars$ml_g$clone(), |
| 38 | + dml_procedure = dml_procedure, |
| 39 | + score = score) |
| 40 | + |
| 41 | + double_mlpliv_obj$fit() |
| 42 | + theta_obj = double_mlpliv_obj$coef |
| 43 | + se_obj = double_mlpliv_obj$se |
| 44 | + |
| 45 | + # Partial out X |
| 46 | + set.seed(3141) |
| 47 | + double_mlpliv_partX = DoubleMLPLIV.partialX(data_ml, |
| 48 | + n_folds = 5, |
| 49 | + ml_l = learner_pars$ml_l$clone(), |
| 50 | + ml_m = learner_pars$ml_m$clone(), |
| 51 | + ml_r = learner_pars$ml_r$clone(), |
| 52 | + ml_g = learner_pars$ml_g$clone(), |
| 53 | + dml_procedure = dml_procedure, |
| 54 | + score = score) |
| 55 | + |
| 56 | + double_mlpliv_partX$fit() |
| 57 | + theta_partX = double_mlpliv_partX$coef |
| 58 | + se_partX = double_mlpliv_partX$se |
| 59 | + |
| 60 | + expect_equal(theta_partX, theta_obj, tolerance = 1e-8) |
| 61 | + expect_equal(se_partX, se_obj, tolerance = 1e-8) |
| 62 | + } |
| 63 | +) |
0 commit comments