Skip to content

Commit 4e5c45a

Browse files
authored
Merge pull request #114 from DoubleML/p-binary-outcome-irm
Binary outcome variable in IRM and IIVM
2 parents 6a88a9e + 91c8b90 commit 4e5c45a

17 files changed

+685
-129
lines changed

R/datasets.R

Lines changed: 55 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,16 @@ fetch_401k = function(return_type = "DoubleMLData", polynomial_features = FALSE,
8484
data = data.frame(
8585
"net_tfa" = data$net_tfa,
8686
model.matrix(formula_x, data),
87-
"p401" = data$p401, "e401" = data$e401)
87+
"p401" = data$p401, "e401" = data$e401
88+
)
8889
d_cols = "p401"
8990
z_cols = "e401"
9091
} else {
9192
# see https://github.com/VC2015/DMLonGitHub/blob/b91cbf96c01eccd73367fbd6601ecdd7aa78403b/401K.R#L67
9293
data = data.frame(
9394
"net_tfa" = data$net_tfa, model.matrix(formula_x, data),
94-
"e401" = data$e401)
95+
"e401" = data$e401
96+
)
9597
d_cols = "e401"
9698
}
9799
if (return_type == "data.frame") {
@@ -103,7 +105,8 @@ fetch_401k = function(return_type = "DoubleMLData", polynomial_features = FALSE,
103105
dt = as.data.table(data)
104106
data = DoubleMLData$new(dt,
105107
y_col = y_col, d_cols = d_cols, x_cols = x_cols,
106-
z_cols = z_cols)
108+
z_cols = z_cols
109+
)
107110
return(data)
108111
}
109112
}
@@ -199,7 +202,9 @@ fetch_401k = function(return_type = "DoubleMLData", polynomial_features = FALSE,
199202
#' x_cols = c(
200203
#' "female", "black", "othrace", "dep1", "dep2",
201204
#' "q2", "q3", "q4", "q5", "q6", "agelt35", "agegt54",
202-
#' "durable", "lusd", "husd"))
205+
#' "durable", "lusd", "husd"
206+
#' )
207+
#' )
203208
#' obj_dml_data_bonus
204209
#' @export
205210
fetch_bonus = function(return_type = "DoubleMLData",
@@ -235,7 +240,8 @@ fetch_bonus = function(return_type = "DoubleMLData",
235240
}
236241
data = data.frame(
237242
"inuidur1" = data$inuidur1, model.matrix(formula_x, data),
238-
"tg" = data$tg)
243+
"tg" = data$tg
244+
)
239245
if (return_type == "data.frame") {
240246
return(data)
241247
} else if (return_type == "data.table") {
@@ -315,7 +321,12 @@ make_plr_CCDDHNR2018 = function(n_obs = 500, dim_x = 20, alpha = 0.5,
315321

316322
assert_choice(
317323
return_type,
318-
c("data.table", "matrix", "data.frame", "DoubleMLData"))
324+
c("data.table", "matrix", "data.frame", "DoubleMLData")
325+
)
326+
assert_count(n_obs)
327+
assert_count(dim_x)
328+
assert_numeric(alpha, len = 1)
329+
319330
cov_mat = toeplitz(0.7^(0:(dim_x - 1)))
320331
a_0 = 1
321332
a_1 = 0.25
@@ -405,6 +416,16 @@ make_plr_CCDDHNR2018 = function(n_obs = 500, dim_x = 20, alpha = 0.5,
405416
make_plr_turrell2018 = function(n_obs = 100, dim_x = 20, theta = 0.5,
406417
return_type = "DoubleMLData", nu = 0, gamma = 1) {
407418

419+
assert_choice(
420+
return_type,
421+
c("data.table", "matrix", "data.frame", "DoubleMLData")
422+
)
423+
assert_count(n_obs)
424+
assert_count(dim_x)
425+
assert_numeric(theta, len = 1)
426+
assert_numeric(nu, len = 1)
427+
assert_numeric(gamma, len = 1)
428+
408429
b = 1 / (1:dim_x)
409430
sigma = genPositiveDefMat(dim_x)
410431
x = rmvnorm(n = n_obs, mean = rep(0, dim_x), sigma = sigma$Sigma)
@@ -497,7 +518,12 @@ make_pliv_CHS2015 = function(n_obs, alpha = 1, dim_x = 200, dim_z = 150,
497518

498519
assert_choice(
499520
return_type,
500-
c("data.table", "matrix", "data.frame", "DoubleMLData"))
521+
c("data.table", "matrix", "data.frame", "DoubleMLData")
522+
)
523+
assert_count(n_obs)
524+
assert_count(dim_x)
525+
assert_count(dim_z)
526+
assert_numeric(alpha, len = 1)
501527
if (dim_x < dim_z) {
502528
stop("Dimension of X should be greater than dimension of Z.")
503529
}
@@ -545,14 +571,13 @@ make_pliv_CHS2015 = function(n_obs, alpha = 1, dim_x = 200, dim_z = 150,
545571
data = DoubleMLData$new(dt,
546572
y_col = "y", d_cols = "d",
547573
x_cols = colnames(x),
548-
z_cols = colnames(z))
574+
z_cols = colnames(z)
575+
)
549576
return(data)
550577
}
551578
}
552-
return(data)
553579
}
554580

555-
556581
#' @title Generates data from a interactive regression (IRM) model.
557582
#'
558583
#' @description
@@ -610,7 +635,16 @@ make_irm_data = function(n_obs = 500, dim_x = 20, theta = 0, R2_d = 0.5,
610635
# inspired by https://onlinelibrary.wiley.com/doi/abs/10.3982/ECTA12723
611636
# (see supplement)
612637

613-
assert_choice(return_type, c("data.table", "matrix", "data.frame", "DoubleMLData"))
638+
assert_choice(
639+
return_type,
640+
c("data.table", "matrix", "data.frame", "DoubleMLData")
641+
)
642+
assert_count(n_obs)
643+
assert_count(dim_x)
644+
assert_numeric(theta, len = 1)
645+
assert_numeric(R2_d, len = 1)
646+
assert_numeric(R2_y, len = 1)
647+
614648
v = runif(n_obs)
615649
zeta = rnorm(n_obs)
616650
cov_mat = toeplitz(0.5^(0:(dim_x - 1)))
@@ -702,10 +736,16 @@ make_iivm_data = function(n_obs = 500, dim_x = 20, theta = 1, alpha_x = 0.2,
702736

703737
assert_choice(
704738
return_type,
705-
c("data.table", "matrix", "data.frame", "DoubleMLData"))
739+
c("data.table", "matrix", "data.frame", "DoubleMLData")
740+
)
741+
assert_count(n_obs)
742+
assert_count(dim_x)
743+
assert_numeric(theta, len = 1)
744+
assert_numeric(alpha_x, len = 1)
706745
xx = rmvnorm(
707746
n = n_obs, mean = rep(0, 2),
708-
sigma = matrix(c(1, 0.3, 0.3, 1), ncol = 2, nrow = 2))
747+
sigma = matrix(c(1, 0.3, 0.3, 1), ncol = 2, nrow = 2)
748+
)
709749
u = xx[, 1]
710750
v = xx[, 2]
711751

@@ -739,7 +779,8 @@ make_iivm_data = function(n_obs = 500, dim_x = 20, theta = 1, alpha_x = 0.2,
739779
data = DoubleMLData$new(dt,
740780
y_col = "y", d_cols = "d",
741781
x_cols = colnames(x),
742-
z_cols = "z")
782+
z_cols = "z"
783+
)
743784
return(data)
744785
}
745786
}

R/double_ml_data.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -289,10 +289,10 @@ DoubleMLData = R6Class("DoubleMLData",
289289
},
290290

291291
#' @description
292-
#' Setter function for `data_model`. The function implements the causal model
293-
#' as specified by the user via `y_col`, `d_cols`, `x_cols` and `z_cols` and
294-
#' assigns the role for the treatment variables in the multiple-treatment
295-
#' case.
292+
#' Setter function for `data_model`. The function implements the causal
293+
#' model as specified by the user via `y_col`, `d_cols`, `x_cols` and
294+
#' `z_cols` and assigns the role for the treatment variables in the
295+
#' multiple-treatment case.
296296
#' @param treatment_var (`character()`)\cr
297297
#' Active treatment variable that will be set to `treat_col`.
298298
set_data_model = function(treatment_var) {

R/double_ml_iivm.R

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -139,16 +139,19 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
139139
#' of the causal model.
140140
#'
141141
#' @param ml_g ([`LearnerRegr`][mlr3::LearnerRegr],
142-
#' [`Learner`][mlr3::Learner], `character(1)`) \cr
142+
#' [`LearnerClassif`][mlr3::LearnerClassif], [`Learner`][mlr3::Learner],
143+
#' `character(1)`) \cr
143144
#' A learner of the class [`LearnerRegr`][mlr3::LearnerRegr], which is
144145
#' available from [mlr3](https://mlr3.mlr-org.com/index.html) or its
145146
#' extension packages [mlr3learners](https://mlr3learners.mlr-org.com/) or
146147
#' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/).
148+
#' For binary treatment outcomes, an object of the class
149+
#' [`LearnerClassif`][mlr3::LearnerClassif] can be passed, for example
150+
#' `lrn("classif.cv_glmnet", s = "lambda.min")`.
147151
#' Alternatively, a [`Learner`][mlr3::Learner] object with public field
148-
#' `task_type = "regr"` can be passed, for example of class
149-
#' [`GraphLearner`][mlr3pipelines::GraphLearner]. The learner can possibly
150-
#' be passed with specified parameters, for example
151-
#' `lrn("regr.cv_glmnet", s = "lambda.min")`. \cr
152+
#' `task_type = "regr"` or `task_type = "classif"` can be passed,
153+
#' respectively, for example of class
154+
#' [`GraphLearner`][mlr3pipelines::GraphLearner]. \cr
152155
#' `ml_g` refers to the nuisance function \eqn{g_0(Z,X) = E[Y|X,Z]}.
153156
#'
154157
#' @param ml_m ([`LearnerClassif`][mlr3::LearnerClassif],
@@ -246,7 +249,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
246249
"ml_g" = NULL,
247250
"ml_m" = NULL,
248251
"ml_r" = NULL)
249-
ml_g = private$assert_learner(ml_g, "ml_g", Regr = TRUE, Classif = FALSE)
252+
ml_g = private$assert_learner(ml_g, "ml_g", Regr = TRUE, Classif = TRUE)
250253
ml_m = private$assert_learner(ml_m, "ml_m", Regr = FALSE, Classif = TRUE)
251254
ml_r = private$assert_learner(ml_r, "ml_r", Regr = FALSE, Classif = TRUE)
252255

R/double_ml_irm.R

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,16 +108,19 @@ DoubleMLIRM = R6Class("DoubleMLIRM",
108108
#' of the causal model.
109109
#'
110110
#' @param ml_g ([`LearnerRegr`][mlr3::LearnerRegr],
111-
#' [`Learner`][mlr3::Learner], `character(1)`) \cr
111+
#' [`LearnerClassif`][mlr3::LearnerClassif], [`Learner`][mlr3::Learner],
112+
#' `character(1)`) \cr
112113
#' A learner of the class [`LearnerRegr`][mlr3::LearnerRegr], which is
113114
#' available from [mlr3](https://mlr3.mlr-org.com/index.html) or its
114115
#' extension packages [mlr3learners](https://mlr3learners.mlr-org.com/) or
115116
#' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/).
117+
#' For binary treatment outcomes, an object of the class
118+
#' [`LearnerClassif`][mlr3::LearnerClassif] can be passed, for example
119+
#' `lrn("classif.cv_glmnet", s = "lambda.min")`.
116120
#' Alternatively, a [`Learner`][mlr3::Learner] object with public field
117-
#' `task_type = "regr"` can be passed, for example of class
118-
#' [`GraphLearner`][mlr3pipelines::GraphLearner]. The learner can possibly
119-
#' be passed with specified parameters, for example
120-
#' `lrn("regr.cv_glmnet", s = "lambda.min")`. \cr
121+
#' `task_type = "regr"` or `task_type = "classif"` can be passed,
122+
#' respectively, for example of class
123+
#' [`GraphLearner`][mlr3pipelines::GraphLearner]. \cr
121124
#' `ml_g` refers to the nuisance function \eqn{g_0(X) = E[Y|X,D]}.
122125
#'
123126
#' @param ml_m ([`LearnerClassif`][mlr3::LearnerClassif],
@@ -189,7 +192,7 @@ DoubleMLIRM = R6Class("DoubleMLIRM",
189192
private$task_type = list(
190193
"ml_g" = NULL,
191194
"ml_m" = NULL)
192-
ml_g = private$assert_learner(ml_g, "ml_g", Regr = TRUE, Classif = FALSE)
195+
ml_g = private$assert_learner(ml_g, "ml_g", Regr = TRUE, Classif = TRUE)
193196
ml_m = private$assert_learner(ml_m, "ml_m", Regr = FALSE, Classif = TRUE)
194197

195198
private$learner_ = list(

R/helper.R

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,9 @@ extract_prediction = function(obj_resampling, task_type, n_obs,
190190
}
191191
} else {
192192
preds = rep(NA_real_, n_obs)
193-
if (testR6(obj_resampling, classes = "ResampleResult")) obj_resampling = list(obj_resampling)
193+
if (testR6(obj_resampling, classes = "ResampleResult")) {
194+
obj_resampling = list(obj_resampling)
195+
}
194196
n_obj_rsmp = length(obj_resampling)
195197
for (i_obj_rsmp in 1:n_obj_rsmp) {
196198
f_hat = as.data.table(obj_resampling[[i_obj_rsmp]]$prediction("test"))
@@ -205,7 +207,9 @@ initiate_learner = function(learner, task_type, params, return_train_preds = FAL
205207
ml_learner = learner$clone()
206208

207209
if (!is.null(params)) {
208-
ml_learner$param_set$values = insert_named(ml_learner$param_set$values, params)
210+
ml_learner$param_set$values = insert_named(
211+
ml_learner$param_set$values,
212+
params)
209213
} # else if (is.null(params) | length(params) == 0) {
210214
# message("No parameters provided for learners. Default values are used.")
211215
# }

man/DoubleMLData.Rd

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/DoubleMLIIVM.Rd

Lines changed: 8 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/DoubleMLIRM.Rd

Lines changed: 8 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/fetch_bonus.Rd

Lines changed: 3 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)