Skip to content

Commit c4d97eb

Browse files
committed
regression tests with mlr3pipelines based learner (class GraphLearner)
add test for each of the causal models
1 parent 6e1a1b6 commit c4d97eb

File tree

6 files changed

+103
-28
lines changed

6 files changed

+103
-28
lines changed

tests/testthat/helper-05-ml-learner.R

Lines changed: 97 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,7 @@ get_default_mlmethod_plr = function(learner, default = FALSE) {
5252
params_g = list(
5353
s = "lambda.min",
5454
family = "gaussian"))
55-
5655
}
57-
5856
}
5957

6058
else if (default == TRUE) {
@@ -64,12 +62,28 @@ get_default_mlmethod_plr = function(learner, default = FALSE) {
6462
params = list(
6563
params_g = list(),
6664
params_m = list())
65+
}
6766

67+
if (learner == "graph_learner") {
68+
# pipeline learner
69+
pipe_learner = mlr3pipelines::po("learner",
70+
lrn("regr.glmnet"),
71+
lambda = 0.01,
72+
family = "gaussian")
73+
mlmethod = list(
74+
mlmethod_m = "graph_learner",
75+
mlmethod_g = "graph_learner")
76+
params = list(
77+
params_g = list(),
78+
params_m = list())
79+
ml_g = mlr3::as_learner(pipe_learner)
80+
ml_m = mlr3::as_learner(pipe_learner)
81+
} else {
82+
ml_g = mlr3::lrn(mlmethod$mlmethod_g)
83+
ml_g$param_set$values = params$params_g
84+
ml_m = mlr3::lrn(mlmethod$mlmethod_m)
85+
ml_m$param_set$values = params$params_m
6886
}
69-
ml_g = mlr3::lrn(mlmethod$mlmethod_g)
70-
ml_g$param_set$values = params$params_g
71-
ml_m = mlr3::lrn(mlmethod$mlmethod_m)
72-
ml_m$param_set$values = params$params_m
7387

7488
return(list(
7589
mlmethod = mlmethod, params = params,
@@ -147,12 +161,31 @@ get_default_mlmethod_pliv = function(learner) {
147161
family = "gaussian"))
148162

149163
}
150-
ml_g = mlr3::lrn(mlmethod$mlmethod_g)
151-
ml_g$param_set$values = params$params_g
152-
ml_m = mlr3::lrn(mlmethod$mlmethod_m)
153-
ml_m$param_set$values = params$params_m
154-
ml_r = mlr3::lrn(mlmethod$mlmethod_r)
155-
ml_r$param_set$values = params$params_r
164+
165+
if (learner == "graph_learner") {
166+
# pipeline learner
167+
pipe_learner = mlr3pipelines::po("learner",
168+
lrn("regr.glmnet"),
169+
lambda = 0.01,
170+
family = "gaussian")
171+
mlmethod = list(
172+
mlmethod_m = "graph_learner",
173+
mlmethod_g = "graph_learner",
174+
mlmethod_r = "graph_learner")
175+
params = list(
176+
params_g = list(),
177+
params_m = list())
178+
ml_g = mlr3::as_learner(pipe_learner)
179+
ml_m = mlr3::as_learner(pipe_learner)
180+
ml_r = mlr3::as_learner(pipe_learner)
181+
} else {
182+
ml_g = mlr3::lrn(mlmethod$mlmethod_g)
183+
ml_g$param_set$values = params$params_g
184+
ml_m = mlr3::lrn(mlmethod$mlmethod_m)
185+
ml_m$param_set$values = params$params_m
186+
ml_r = mlr3::lrn(mlmethod$mlmethod_r)
187+
ml_r$param_set$values = params$params_r
188+
}
156189

157190
return(list(
158191
mlmethod = mlmethod, params = params,
@@ -182,11 +215,30 @@ get_default_mlmethod_irm = function(learner) {
182215
params_m = list(cp = 0.01, minsplit = 20))
183216

184217
}
185-
ml_g = mlr3::lrn(mlmethod$mlmethod_g)
186-
ml_g$param_set$values = params$params_g
187-
ml_m = mlr3::lrn(mlmethod$mlmethod_m, predict_type = "prob")
188-
ml_m$param_set$values = params$params_m
189218

219+
if (learner == "graph_learner") {
220+
# pipeline learner
221+
pipe_learner = mlr3pipelines::po("learner",
222+
lrn("regr.rpart"),
223+
cp = 0.01, minsplit = 20)
224+
pipe_learner_classif = mlr3pipelines::po("learner",
225+
lrn("classif.rpart",
226+
predict_type = "prob"),
227+
cp = 0.01, minsplit = 20)
228+
mlmethod = list(
229+
mlmethod_m = "graph_learner",
230+
mlmethod_g = "graph_learner")
231+
params = list(
232+
params_g = list(),
233+
params_m = list())
234+
ml_g = mlr3::as_learner(pipe_learner)
235+
ml_m = mlr3::as_learner(pipe_learner_classif)
236+
} else {
237+
ml_g = mlr3::lrn(mlmethod$mlmethod_g)
238+
ml_g$param_set$values = params$params_g
239+
ml_m = mlr3::lrn(mlmethod$mlmethod_m, predict_type = "prob")
240+
ml_m$param_set$values = params$params_m
241+
}
190242
return(list(
191243
mlmethod = mlmethod, params = params,
192244
ml_g = ml_g, ml_m = ml_m))
@@ -219,12 +271,35 @@ get_default_mlmethod_iivm = function(learner) {
219271
params_r = list(cp = 0.01, minsplit = 20))
220272

221273
}
222-
ml_g = mlr3::lrn(mlmethod$mlmethod_g)
223-
ml_g$param_set$values = params$params_g
224-
ml_m = mlr3::lrn(mlmethod$mlmethod_m, predict_type = "prob")
225-
ml_m$param_set$values = params$params_m
226-
ml_r = mlr3::lrn(mlmethod$mlmethod_r, predict_type = "prob")
227-
ml_r$param_set$values = params$params_r
274+
275+
if (learner == "graph_learner") {
276+
# pipeline learner
277+
pipe_learner = mlr3pipelines::po("learner",
278+
lrn("regr.rpart"),
279+
cp = 0.01, minsplit = 20)
280+
pipe_learner_classif = mlr3pipelines::po("learner",
281+
lrn("classif.rpart",
282+
predict_type = "prob"),
283+
cp = 0.01, minsplit = 20)
284+
mlmethod = list(
285+
mlmethod_m = "graph_learner",
286+
mlmethod_g = "graph_learner",
287+
mlmethod_r = "graph_learner")
288+
params = list(
289+
params_g = list(),
290+
params_m = list(),
291+
params_r = list())
292+
ml_g = mlr3::as_learner(pipe_learner)
293+
ml_m = mlr3::as_learner(pipe_learner_classif)
294+
ml_r = mlr3::as_learner(pipe_learner_classif)
295+
} else {
296+
ml_g = mlr3::lrn(mlmethod$mlmethod_g)
297+
ml_g$param_set$values = params$params_g
298+
ml_m = mlr3::lrn(mlmethod$mlmethod_m, predict_type = "prob")
299+
ml_m$param_set$values = params$params_m
300+
ml_r = mlr3::lrn(mlmethod$mlmethod_r, predict_type = "prob")
301+
ml_r$param_set$values = params$params_r
302+
}
228303

229304
return(list(
230305
mlmethod = mlmethod, params = params,

tests/testthat/helper-08-dml_plr.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,15 +217,15 @@ fit_nuisance_plr = function(data, y, d,
217217
m_indx = names(data) != y
218218
data_m = data[, m_indx, drop = FALSE]
219219

220-
if (checkmate::test_class(ml_m, "LearnerRegr")) {
220+
if (ml_m$task_type == "regr") {
221221
task_m = mlr3::TaskRegr$new(id = paste0("nuis_m_", d), backend = data_m, target = d)
222222

223223
resampling_m = mlr3::rsmp("custom")
224224
resampling_m$instantiate(task_m, train_ids, test_ids)
225225

226226
r_m = mlr3::resample(task_m, ml_m, resampling_m, store_models = TRUE)
227227
m_hat_list = lapply(r_m$predictions(), function(x) x$response)
228-
} else if (checkmate::test_class(ml_m, "LearnerClassif")) {
228+
} else if ((ml_m$task_type == "classif")) {
229229
ml_m$predict_type = "prob"
230230
data_m[[d]] = factor(data_m[[d]])
231231
task_m = mlr3::TaskClassif$new(

tests/testthat/test-double_ml_iivm.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ if (on_cran) {
1414
stringsAsFactors = FALSE)
1515
} else {
1616
test_cases = expand.grid(
17-
learner = "cv_glmnet",
17+
learner = c("cv_glmnet", "graph_learner"),
1818
dml_procedure = c("dml1", "dml2"),
1919
score = "LATE",
2020
trimming_threshold = c(1e-5),

tests/testthat/test-double_ml_irm.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ if (on_cran) {
1515
test_cases[".test_name"] = apply(test_cases, 1, paste, collapse = "_")
1616
} else {
1717
test_cases = expand.grid(
18-
learner = "cv_glmnet",
18+
learner = c("cv_glmnet", "graph_learner"),
1919
dml_procedure = c("dml1", "dml2"),
2020
score = c("ATE", "ATTE"),
2121
trimming_threshold = 0,

tests/testthat/test-double_ml_pliv.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ if (on_cran) {
1313
stringsAsFactors = FALSE)
1414
} else {
1515
test_cases = expand.grid(
16-
learner = c("regr.lm", "regr.glmnet"),
16+
learner = c("regr.lm", "regr.glmnet", "graph_learner"),
1717
dml_procedure = c("dml1", "dml2"),
1818
score = "partialling out",
1919
stringsAsFactors = FALSE)

tests/testthat/test-double_ml_plr.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ if (on_cran) {
1313
stringsAsFactors = FALSE)
1414
} else {
1515
test_cases = expand.grid(
16-
learner = c("regr.lm", "regr.cv_glmnet"),
16+
learner = c("regr.lm", "regr.cv_glmnet", "graph_learner"),
1717
dml_procedure = c("dml1", "dml2"),
1818
score = c("IV-type", "partialling out"),
1919
stringsAsFactors = FALSE)

0 commit comments

Comments
 (0)