Skip to content

Commit b0f5047

Browse files
authored
Merge pull request #141 from DoubleML/p-assert-learner-class
Use task_type instead of learner_class
2 parents dcd73a2 + 88dfd31 commit b0f5047

17 files changed

+411
-321
lines changed

DESCRIPTION

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ Suggests:
5151
sandwich,
5252
AER,
5353
rpart,
54-
bbotk
54+
bbotk,
55+
mlr3pipelines
5556
VignetteBuilder: knitr
5657
Collate:
5758
'double_ml.R'

R/double_ml.R

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1150,7 +1150,7 @@ DoubleML = R6Class("DoubleML",
11501150
i_treat = NA_integer_,
11511151
fold_specific_params = NULL,
11521152
summary_table = NULL,
1153-
learner_class = list(),
1153+
task_type = list(),
11541154
is_cluster_data = FALSE,
11551155
n_folds_per_cluster = NA_integer_,
11561156
smpls_cluster_ = NULL,
@@ -1250,28 +1250,31 @@ DoubleML = R6Class("DoubleML",
12501250
check_character(learner, max.len = 1),
12511251
check_class(learner, "Learner"))
12521252

1253+
if (test_class(learner, "AutoTuner")) {
1254+
stop(paste0(
1255+
"Learners of class 'AutoTuner' are not supported."
1256+
))
1257+
}
12531258
if (is.character(learner)) {
12541259
# warning("Learner provision by character() will be deprecated in the
12551260
# future.")
12561261
learner = lrn(learner)
12571262
}
12581263

1259-
if (Regr & test_class(learner, "LearnerRegr")) {
1260-
private$learner_class[learner_name] = "LearnerRegr"
1261-
}
1262-
if (Classif & test_class(learner, "LearnerClassif")) {
1263-
private$learner_class[learner_name] = "LearnerClassif"
1264+
if ((Regr & learner$task_type == "regr") |
1265+
(Classif & learner$task_type == "classif")) {
1266+
private$task_type[learner_name] = learner$task_type
12641267
}
12651268

1266-
if ((Regr & !Classif & !test_class(learner, "LearnerRegr"))) {
1269+
if ((Regr & !Classif & !learner$task_type == "regr")) {
12671270
stop(paste0(
12681271
"Invalid learner provided for ", learner_name,
1269-
": must be of class 'LearnerRegr'"))
1272+
": 'learner$task_type' must be 'regr'"))
12701273
}
1271-
if ((Classif & !Regr & !test_class(learner, "LearnerClassif"))) {
1274+
if ((Classif & !Regr & !learner$task_type == "classif")) {
12721275
stop(paste0(
12731276
"Invalid learner provided for ", learner_name,
1274-
": must be of class 'LearnerClassif'"))
1277+
": 'learner$task_type must be 'classif'"))
12751278
}
12761279
invisible(learner)
12771280
},
@@ -1333,7 +1336,7 @@ DoubleML = R6Class("DoubleML",
13331336
this_learner = names(tune_settings$measure)[i_msr]
13341337
tune_settings$measure[[this_learner]] = set_default_measure(
13351338
tune_settings$measure[[this_learner]],
1336-
private$learner_class[[this_learner]])
1339+
private$task_type[[this_learner]])
13371340
}
13381341
}
13391342

R/double_ml_iivm.R

Lines changed: 44 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -138,42 +138,43 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
138138
#' The `DoubleMLData` object providing the data and specifying the variables
139139
#' of the causal model.
140140
#'
141-
#' @param ml_g ([`LearnerRegr`][mlr3::LearnerRegr], `character(1)`) \cr
142-
#' An object of the class [mlr3 regression learner][mlr3::LearnerRegr] to
143-
#' pass a learner, possibly with specified parameters, for example
144-
#' `lrn("regr.cv_glmnet", s = "lambda.min")`.
145-
#' Alternatively, a `character(1)` specifying the name of a
146-
#' [mlr3 regression learner][mlr3::LearnerRegr] that is available in
147-
#' [mlr3](https://mlr3.mlr-org.com/index.html) or its extension packages
148-
#' [mlr3learners](https://mlr3learners.mlr-org.com/) or
149-
#' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/),
150-
#' for example `"regr.cv_glmnet"`. \cr
141+
#' @param ml_g ([`LearnerRegr`][mlr3::LearnerRegr],
142+
#' [`Learner`][mlr3::Learner], `character(1)`) \cr
143+
#' A learner of the class [`LearnerRegr`][mlr3::LearnerRegr], which is
144+
#' available from [mlr3](https://mlr3.mlr-org.com/index.html) or its
145+
#' extension packages [mlr3learners](https://mlr3learners.mlr-org.com/) or
146+
#' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/).
147+
#' Alternatively, a [`Learner`][mlr3::Learner] object with public field
148+
#' `task_type = "regr"` can be passed, for example of class
149+
#' [`GraphLearner`][mlr3pipelines::GraphLearner]. The learner can possibly
150+
#' be passed with specified parameters, for example
151+
#' `lrn("regr.cv_glmnet", s = "lambda.min")`. \cr
151152
#' `ml_g` refers to the nuisance function \eqn{g_0(Z,X) = E[Y|X,Z]}.
152153
#'
153-
#' @param ml_m ([`LearnerClassif`][mlr3::LearnerClassif], `character(1)`) \cr
154-
#' An object of the class
155-
#' [mlr3 classification learner][mlr3::LearnerClassif] to pass a learner,
156-
#' possibly with specified parameters, for example
157-
#' `lrn("classif.cv_glmnet", s = "lambda.min")`.
158-
#' Alternatively, a `character(1)` specifying the name of
159-
#' a [mlr3 classification learner][mlr3::LearnerClassif] that is available
160-
#' in [mlr3](https://mlr3.mlr-org.com/index.html) or its extension packages
161-
#' [mlr3learners](https://mlr3learners.mlr-org.com/) or
162-
#' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/),
163-
#' for example `"classif.cv_glmnet"`. \cr
154+
#' @param ml_m ([`LearnerClassif`][mlr3::LearnerClassif],
155+
#' [`Learner`][mlr3::Learner], `character(1)`) \cr
156+
#' A learner of the class [`LearnerClassif`][mlr3::LearnerClassif], which is
157+
#' available from [mlr3](https://mlr3.mlr-org.com/index.html) or its
158+
#' extension packages [mlr3learners](https://mlr3learners.mlr-org.com/) or
159+
#' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/).
160+
#' Alternatively, a [`Learner`][mlr3::Learner] object with public field
161+
#' `task_type = "classif"` can be passed, for example of class
162+
#' [`GraphLearner`][mlr3pipelines::GraphLearner]. The learner can possibly
163+
#' be passed with specified parameters, for example
164+
#' `lrn("classif.cv_glmnet", s = "lambda.min")`. \cr
164165
#' `ml_m` refers to the nuisance function \eqn{m_0(X) = E[Z|X]}.
165166
#'
166-
#' @param ml_r ([`LearnerClassif`][mlr3::LearnerClassif], `character(1)`) \cr
167-
#' An object of the class
168-
#' [mlr3 classification learner][mlr3::LearnerClassif] to pass a learner,
169-
#' possibly with specified parameters, for example
170-
#' `lrn("classif.cv_glmnet", s = "lambda.min")`.
171-
#' Alternatively, a `character(1)` specifying the name of a
172-
#' [mlr3 classification learner][mlr3::LearnerClassif] that is available in
173-
#' [mlr3](https://mlr3.mlr-org.com/index.html) or its extension packages
174-
#' [mlr3learners](https://mlr3learners.mlr-org.com/) or
175-
#' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/),
176-
#' for example `"classif.cv_glmnet"`. \cr
167+
#' @param ml_r ([`LearnerClassif`][mlr3::LearnerClassif],
168+
#' [`Learner`][mlr3::Learner], `character(1)`) \cr
169+
#' A learner of the class [`LearnerClassif`][mlr3::LearnerClassif], which is
170+
#' available from [mlr3](https://mlr3.mlr-org.com/index.html) or its
171+
#' extension packages [mlr3learners](https://mlr3learners.mlr-org.com/) or
172+
#' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/).
173+
#' Alternatively, a [`Learner`][mlr3::Learner] object with public field
174+
#' `task_type = "classif"` can be passed, for example of class
175+
#' [`GraphLearner`][mlr3pipelines::GraphLearner]. The learner can possibly
176+
#' be passed with specified parameters, for example
177+
#' `lrn("classif.cv_glmnet", s = "lambda.min")`. \cr
177178
#' `ml_r` refers to the nuisance function \eqn{r_0(Z,X) = E[D|X,Z]}.
178179
#'
179180
#' @param n_folds (`integer(1)`)\cr
@@ -241,7 +242,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
241242

242243
private$check_data(self$data)
243244
private$check_score(self$score)
244-
private$learner_class = list(
245+
private$task_type = list(
245246
"ml_g" = NULL,
246247
"ml_m" = NULL,
247248
"ml_r" = NULL)
@@ -295,7 +296,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
295296
smpls = smpls,
296297
est_params = self$get_params("ml_m"),
297298
return_train_preds = FALSE,
298-
learner_class = private$learner_class$ml_m,
299+
task_type = private$task_type$ml_m,
299300
fold_specific_params = private$fold_specific_params)
300301

301302
g0_hat = dml_cv_predict(self$learner$ml_g,
@@ -306,7 +307,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
306307
smpls = cond_smpls$smpls_0,
307308
est_params = self$get_params("ml_g0"),
308309
return_train_preds = FALSE,
309-
learner_class = private$learner_class$ml_g,
310+
task_type = private$task_type$ml_g,
310311
fold_specific_params = private$fold_specific_params)
311312

312313
g1_hat = dml_cv_predict(self$learner$ml_g,
@@ -317,7 +318,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
317318
smpls = cond_smpls$smpls_1,
318319
est_params = self$get_params("ml_g1"),
319320
return_train_preds = FALSE,
320-
learner_class = private$learner_class$ml_g,
321+
task_type = private$task_type$ml_g,
321322
fold_specific_params = private$fold_specific_params)
322323

323324
if (self$subgroups$always_takers == FALSE) {
@@ -331,7 +332,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
331332
smpls = cond_smpls$smpls_0,
332333
est_params = self$get_params("ml_r0"),
333334
return_train_preds = FALSE,
334-
learner_class = private$learner_class$ml_r,
335+
task_type = private$task_type$ml_r,
335336
fold_specific_params = private$fold_specific_params)
336337
}
337338

@@ -346,7 +347,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
346347
smpls = cond_smpls$smpls_1,
347348
est_params = self$get_params("ml_r1"),
348349
return_train_preds = FALSE,
349-
learner_class = private$learner_class$ml_r,
350+
task_type = private$task_type$ml_r,
350351
fold_specific_params = private$fold_specific_params)
351352
}
352353

@@ -421,7 +422,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
421422
nuisance_id = "nuis_m",
422423
param_set$ml_m, tune_settings,
423424
tune_settings$measure$ml_m,
424-
private$learner_class$ml_m)
425+
private$task_type$ml_m)
425426

426427
tuning_result_g0 = dml_tune(self$learner$ml_g,
427428
c(self$data$x_cols, self$data$other_treat_cols),
@@ -430,7 +431,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
430431
nuisance_id = "nuis_g0",
431432
param_set$ml_g, tune_settings,
432433
tune_settings$measure$ml_g,
433-
private$learner_class$ml_g)
434+
private$task_type$ml_g)
434435

435436
tuning_result_g1 = dml_tune(self$learner$ml_g,
436437
c(self$data$x_cols, self$data$other_treat_cols),
@@ -439,7 +440,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
439440
nuisance_id = "nuis_g1",
440441
param_set$ml_g, tune_settings,
441442
tune_settings$measure$ml_g,
442-
private$learner_class$ml_g)
443+
private$task_type$ml_g)
443444

444445
if (self$subgroups$always_takers == TRUE) {
445446
tuning_result_r0 = dml_tune(self$learner$ml_r,
@@ -449,7 +450,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
449450
nuisance_id = "nuis_r0",
450451
param_set$ml_r, tune_settings,
451452
tune_settings$measure$ml_r,
452-
private$learner_class$ml_r)
453+
private$task_type$ml_r)
453454
} else {
454455
tuning_result_r0 = list(list(), "params" = list(list()))
455456
}
@@ -462,7 +463,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
462463
nuisance_id = "nuis_r1",
463464
param_set$ml_r, tune_settings,
464465
tune_settings$measure$ml_r,
465-
private$learner_class$ml_r)
466+
private$task_type$ml_r)
466467
} else {
467468
tuning_result_r1 = list(list(), "params" = list(list()))
468469
}

R/double_ml_irm.R

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -107,29 +107,30 @@ DoubleMLIRM = R6Class("DoubleMLIRM",
107107
#' The `DoubleMLData` object providing the data and specifying the variables
108108
#' of the causal model.
109109
#'
110-
#' @param ml_g ([`LearnerRegr`][mlr3::LearnerRegr], `character(1)`) \cr
111-
#' An object of the class [mlr3 regression learner][mlr3::LearnerRegr] to
112-
#' pass a learner, possibly with specified parameters, for example
113-
#' `lrn("regr.cv_glmnet", s = "lambda.min")`.
114-
#' Alternatively, a `character(1)` specifying the name of a
115-
#' [mlr3 regression learner][mlr3::LearnerRegr] that is available in
116-
#' [mlr3](https://mlr3.mlr-org.com/index.html) or its extension packages
117-
#' [mlr3learners](https://mlr3learners.mlr-org.com/) or
118-
#' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/), for example
119-
#' `"regr.cv_glmnet"`. \cr
110+
#' @param ml_g ([`LearnerRegr`][mlr3::LearnerRegr],
111+
#' [`Learner`][mlr3::Learner], `character(1)`) \cr
112+
#' A learner of the class [`LearnerRegr`][mlr3::LearnerRegr], which is
113+
#' available from [mlr3](https://mlr3.mlr-org.com/index.html) or its
114+
#' extension packages [mlr3learners](https://mlr3learners.mlr-org.com/) or
115+
#' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/).
116+
#' Alternatively, a [`Learner`][mlr3::Learner] object with public field
117+
#' `task_type = "regr"` can be passed, for example of class
118+
#' [`GraphLearner`][mlr3pipelines::GraphLearner]. The learner can possibly
119+
#' be passed with specified parameters, for example
120+
#' `lrn("regr.cv_glmnet", s = "lambda.min")`. \cr
120121
#' `ml_g` refers to the nuisance function \eqn{g_0(X) = E[Y|X,D]}.
121122
#'
122-
#' @param ml_m ([`LearnerClassif`][mlr3::LearnerClassif], `character(1)`) \cr
123-
#' An object of the class
124-
#' [mlr3 classification learner][mlr3::LearnerClassif] to pass a learner,
125-
#' possibly with specified parameters, for example
126-
#' `lrn("classif.cv_glmnet", s = "lambda.min")`.
127-
#' Alternatively, a `character(1)` specifying the name of a
128-
#' [mlr3 classification learner][mlr3::LearnerClassif] that is available
129-
#' in [mlr3](https://mlr3.mlr-org.com/index.html) or its extension packages
130-
#' [mlr3learners](https://mlr3learners.mlr-org.com/) or
131-
#' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/),
132-
#' for example `"classif.cv_glmnet"`. \cr
123+
#' @param ml_m ([`LearnerClassif`][mlr3::LearnerClassif],
124+
#' [`Learner`][mlr3::Learner], `character(1)`) \cr
125+
#' A learner of the class [`LearnerClassif`][mlr3::LearnerClassif], which is
126+
#' available from [mlr3](https://mlr3.mlr-org.com/index.html) or its
127+
#' extension packages [mlr3learners](https://mlr3learners.mlr-org.com/) or
128+
#' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/).
129+
#' Alternatively, a [`Learner`][mlr3::Learner] object with public field
130+
#' `task_type = "classif"` can be passed, for example of class
131+
#' [`GraphLearner`][mlr3pipelines::GraphLearner]. The learner can possibly
132+
#' be passed with specified parameters, for example
133+
#' `lrn("classif.cv_glmnet", s = "lambda.min")`. \cr
133134
#' `ml_m` refers to the nuisance function \eqn{m_0(X) = E[D|X]}.
134135
#'
135136
#' @param n_folds (`integer(1)`)\cr
@@ -185,7 +186,7 @@ DoubleMLIRM = R6Class("DoubleMLIRM",
185186

186187
private$check_data(self$data)
187188
private$check_score(self$score)
188-
private$learner_class = list(
189+
private$task_type = list(
189190
"ml_g" = NULL,
190191
"ml_m" = NULL)
191192
ml_g = private$assert_learner(ml_g, "ml_g", Regr = TRUE, Classif = FALSE)
@@ -227,7 +228,7 @@ DoubleMLIRM = R6Class("DoubleMLIRM",
227228
smpls = smpls,
228229
est_params = self$get_params("ml_m"),
229230
return_train_preds = FALSE,
230-
learner_class = private$learner_class$ml_m,
231+
task_type = private$task_type$ml_m,
231232
fold_specific_params = private$fold_specific_params)
232233

233234
g0_hat = dml_cv_predict(self$learner$ml_g,
@@ -238,7 +239,7 @@ DoubleMLIRM = R6Class("DoubleMLIRM",
238239
smpls = cond_smpls$smpls_0,
239240
est_params = self$get_params("ml_g0"),
240241
return_train_preds = FALSE,
241-
learner_class = private$learner_class$ml_g,
242+
task_type = private$task_type$ml_g,
242243
fold_specific_params = private$fold_specific_params)
243244

244245
g1_hat = NULL
@@ -251,7 +252,7 @@ DoubleMLIRM = R6Class("DoubleMLIRM",
251252
smpls = cond_smpls$smpls_1,
252253
est_params = self$get_params("ml_g1"),
253254
return_train_preds = FALSE,
254-
learner_class = private$learner_class$ml_g,
255+
task_type = private$task_type$ml_g,
255256
fold_specific_params = private$fold_specific_params)
256257
}
257258

@@ -330,7 +331,7 @@ DoubleMLIRM = R6Class("DoubleMLIRM",
330331
nuisance_id = "nuis_m",
331332
param_set$ml_m, tune_settings,
332333
tune_settings$measure$ml_m,
333-
private$learner_class$ml_m)
334+
private$task_type$ml_m)
334335

335336
tuning_result_g0 = dml_tune(self$learner$ml_g,
336337
c(self$data$x_cols, self$data$other_treat_cols),
@@ -339,7 +340,7 @@ DoubleMLIRM = R6Class("DoubleMLIRM",
339340
nuisance_id = "nuis_g0",
340341
param_set$ml_g, tune_settings,
341342
tune_settings$measure$ml_g,
342-
private$learner_class$ml_g)
343+
private$task_type$ml_g)
343344

344345
if ((is.character(self$score) && self$score == "ATE") || is.function(self$score)) {
345346
tuning_result_g1 = dml_tune(self$learner$ml_g,
@@ -349,7 +350,7 @@ DoubleMLIRM = R6Class("DoubleMLIRM",
349350
nuisance_id = "nuis_g1",
350351
param_set$ml_g, tune_settings,
351352
tune_settings$measure$ml_g,
352-
private$learner_class$ml_g)
353+
private$task_type$ml_g)
353354
} else {
354355
tuning_result_g1 = list(list(), "params" = list(list()))
355356
}

0 commit comments

Comments
 (0)