@@ -138,42 +138,43 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
138138 # ' The `DoubleMLData` object providing the data and specifying the variables
139139 # ' of the causal model.
140140 # '
141- # ' @param ml_g ([`LearnerRegr`][mlr3::LearnerRegr], `character(1)`) \cr
142- # ' An object of the class [mlr3 regression learner][mlr3::LearnerRegr] to
143- # ' pass a learner, possibly with specified parameters, for example
144- # ' `lrn("regr.cv_glmnet", s = "lambda.min")`.
145- # ' Alternatively, a `character(1)` specifying the name of a
146- # ' [mlr3 regression learner][mlr3::LearnerRegr] that is available in
147- # ' [mlr3](https://mlr3.mlr-org.com/index.html) or its extension packages
148- # ' [mlr3learners](https://mlr3learners.mlr-org.com/) or
149- # ' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/),
150- # ' for example `"regr.cv_glmnet"`. \cr
141+ # ' @param ml_g ([`LearnerRegr`][mlr3::LearnerRegr],
142+ # ' [`Learner`][mlr3::Learner], `character(1)`) \cr
143+ # ' A learner of the class [`LearnerRegr`][mlr3::LearnerRegr], which is
144+ # ' available from [mlr3](https://mlr3.mlr-org.com/index.html) or its
145+ # ' extension packages [mlr3learners](https://mlr3learners.mlr-org.com/) or
146+ # ' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/).
147+ # ' Alternatively, a [`Learner`][mlr3::Learner] object with public field
148+ # ' `task_type = "regr"` can be passed, for example of class
149+ # ' [`GraphLearner`][mlr3pipelines::GraphLearner]. The learner can possibly
150+ # ' be passed with specified parameters, for example
151+ # ' `lrn("regr.cv_glmnet", s = "lambda.min")`. \cr
151152 # ' `ml_g` refers to the nuisance function \eqn{g_0(Z,X) = E[Y|X,Z]}.
152153 # '
153- # ' @param ml_m ([`LearnerClassif`][mlr3::LearnerClassif], `character(1)`) \cr
154- # ' An object of the class
155- # ' [mlr3 classification learner][mlr3::LearnerClassif] to pass a learner,
156- # ' possibly with specified parameters, for example
157- # ' `lrn("classif.cv_glmnet", s = "lambda.min")`.
158- # ' Alternatively, a `character(1)` specifying the name of
159- # ' a [mlr3 classification learner ][mlr3::LearnerClassif] that is available
160- # ' in [mlr3](https://mlr3.mlr-org.com/index.html) or its extension packages
161- # ' [mlr3learners](https://mlr3learners.mlr-org.com/) or
162- # ' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/),
163- # ' for example ` "classif.cv_glmnet"`. \cr
154+ # ' @param ml_m ([`LearnerClassif`][mlr3::LearnerClassif],
155+ # ' [`Learner`][mlr3::Learner], `character(1)`) \cr
156+ # ' A learner of the class [`LearnerClassif` ][mlr3::LearnerClassif], which is
157+ # ' available from [mlr3](https://mlr3.mlr-org.com/index.html) or its
158+ # ' extension packages [mlr3learners](https://mlr3learners.mlr-org.com/) or
159+ # ' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/).
160+ # ' Alternatively, a [`Learner` ][mlr3::Learner] object with public field
161+ # ' `task_type = "classif"` can be passed, for example of class
162+ # ' [`GraphLearner`][mlr3pipelines::GraphLearner]. The learner can possibly
163+ # ' be passed with specified parameters, for example
164+ # ' `lrn( "classif.cv_glmnet", s = "lambda.min") `. \cr
164165 # ' `ml_m` refers to the nuisance function \eqn{m_0(X) = E[Z|X]}.
165166 # '
166- # ' @param ml_r ([`LearnerClassif`][mlr3::LearnerClassif], `character(1)`) \cr
167- # ' An object of the class
168- # ' [mlr3 classification learner][mlr3::LearnerClassif] to pass a learner,
169- # ' possibly with specified parameters, for example
170- # ' `lrn("classif.cv_glmnet", s = "lambda.min")`.
171- # ' Alternatively, a `character(1)` specifying the name of a
172- # ' [mlr3 classification learner ][mlr3::LearnerClassif] that is available in
173- # ' [mlr3](https://mlr3.mlr-org.com/index.html) or its extension packages
174- # ' [mlr3learners](https://mlr3learners.mlr-org.com/) or
175- # ' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/),
176- # ' for example ` "classif.cv_glmnet"`. \cr
167+ # ' @param ml_r ([`LearnerClassif`][mlr3::LearnerClassif],
168+ # ' [`Learner`][mlr3::Learner], `character(1)`) \cr
169+ # ' A learner of the class [`LearnerClassif` ][mlr3::LearnerClassif], which is
170+ # ' available from [mlr3](https://mlr3.mlr-org.com/index.html) or its
171+ # ' extension packages [mlr3learners](https://mlr3learners.mlr-org.com/) or
172+ # ' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/).
173+ # ' Alternatively, a [`Learner` ][mlr3::Learner] object with public field
174+ # ' `task_type = "classif"` can be passed, for example of class
175+ # ' [`GraphLearner`][mlr3pipelines::GraphLearner]. The learner can possibly
176+ # ' be passed with specified parameters, for example
177+ # ' `lrn( "classif.cv_glmnet", s = "lambda.min") `. \cr
177178 # ' `ml_r` refers to the nuisance function \eqn{r_0(Z,X) = E[D|X,Z]}.
178179 # '
179180 # ' @param n_folds (`integer(1)`)\cr
@@ -241,7 +242,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
241242
242243 private $ check_data(self $ data )
243244 private $ check_score(self $ score )
244- private $ learner_class = list (
245+ private $ task_type = list (
245246 " ml_g" = NULL ,
246247 " ml_m" = NULL ,
247248 " ml_r" = NULL )
@@ -295,7 +296,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
295296 smpls = smpls ,
296297 est_params = self $ get_params(" ml_m" ),
297298 return_train_preds = FALSE ,
298- learner_class = private $ learner_class $ ml_m ,
299+ task_type = private $ task_type $ ml_m ,
299300 fold_specific_params = private $ fold_specific_params )
300301
301302 g0_hat = dml_cv_predict(self $ learner $ ml_g ,
@@ -306,7 +307,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
306307 smpls = cond_smpls $ smpls_0 ,
307308 est_params = self $ get_params(" ml_g0" ),
308309 return_train_preds = FALSE ,
309- learner_class = private $ learner_class $ ml_g ,
310+ task_type = private $ task_type $ ml_g ,
310311 fold_specific_params = private $ fold_specific_params )
311312
312313 g1_hat = dml_cv_predict(self $ learner $ ml_g ,
@@ -317,7 +318,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
317318 smpls = cond_smpls $ smpls_1 ,
318319 est_params = self $ get_params(" ml_g1" ),
319320 return_train_preds = FALSE ,
320- learner_class = private $ learner_class $ ml_g ,
321+ task_type = private $ task_type $ ml_g ,
321322 fold_specific_params = private $ fold_specific_params )
322323
323324 if (self $ subgroups $ always_takers == FALSE ) {
@@ -331,7 +332,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
331332 smpls = cond_smpls $ smpls_0 ,
332333 est_params = self $ get_params(" ml_r0" ),
333334 return_train_preds = FALSE ,
334- learner_class = private $ learner_class $ ml_r ,
335+ task_type = private $ task_type $ ml_r ,
335336 fold_specific_params = private $ fold_specific_params )
336337 }
337338
@@ -346,7 +347,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
346347 smpls = cond_smpls $ smpls_1 ,
347348 est_params = self $ get_params(" ml_r1" ),
348349 return_train_preds = FALSE ,
349- learner_class = private $ learner_class $ ml_r ,
350+ task_type = private $ task_type $ ml_r ,
350351 fold_specific_params = private $ fold_specific_params )
351352 }
352353
@@ -421,7 +422,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
421422 nuisance_id = " nuis_m" ,
422423 param_set $ ml_m , tune_settings ,
423424 tune_settings $ measure $ ml_m ,
424- private $ learner_class $ ml_m )
425+ private $ task_type $ ml_m )
425426
426427 tuning_result_g0 = dml_tune(self $ learner $ ml_g ,
427428 c(self $ data $ x_cols , self $ data $ other_treat_cols ),
@@ -430,7 +431,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
430431 nuisance_id = " nuis_g0" ,
431432 param_set $ ml_g , tune_settings ,
432433 tune_settings $ measure $ ml_g ,
433- private $ learner_class $ ml_g )
434+ private $ task_type $ ml_g )
434435
435436 tuning_result_g1 = dml_tune(self $ learner $ ml_g ,
436437 c(self $ data $ x_cols , self $ data $ other_treat_cols ),
@@ -439,7 +440,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
439440 nuisance_id = " nuis_g1" ,
440441 param_set $ ml_g , tune_settings ,
441442 tune_settings $ measure $ ml_g ,
442- private $ learner_class $ ml_g )
443+ private $ task_type $ ml_g )
443444
444445 if (self $ subgroups $ always_takers == TRUE ) {
445446 tuning_result_r0 = dml_tune(self $ learner $ ml_r ,
@@ -449,7 +450,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
449450 nuisance_id = " nuis_r0" ,
450451 param_set $ ml_r , tune_settings ,
451452 tune_settings $ measure $ ml_r ,
452- private $ learner_class $ ml_r )
453+ private $ task_type $ ml_r )
453454 } else {
454455 tuning_result_r0 = list (list (), " params" = list (list ()))
455456 }
@@ -462,7 +463,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
462463 nuisance_id = " nuis_r1" ,
463464 param_set $ ml_r , tune_settings ,
464465 tune_settings $ measure $ ml_r ,
465- private $ learner_class $ ml_r )
466+ private $ task_type $ ml_r )
466467 } else {
467468 tuning_result_r1 = list (list (), " params" = list (list ()))
468469 }
0 commit comments