Skip to content

Commit 0537cbe

Browse files
authored
Merge pull request #160 from DoubleML/m-fix-bugs
A couple of bug fixes and minor improvements
2 parents acb9d46 + e09408c commit 0537cbe

File tree

8 files changed: 50 additions and 41 deletions

R/double_ml.R

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -771,7 +771,7 @@ DoubleML = R6Class("DoubleML",
771771
n_folds_tune = 5,
772772
rsmp_tune = mlr3::rsmp("cv", folds = 5),
773773
measure = NULL,
774-
terminator = mlr3tunin::trm("evals", n_evals = 20),
774+
terminator = mlr3tuning::trm("evals", n_evals = 20),
775775
algorithm = mlr3tuning::tnr("grid_search"),
776776
resolution = 5),
777777
tune_on_folds = FALSE) {
@@ -1311,7 +1311,7 @@ DoubleML = R6Class("DoubleML",
13111311
tune_settings$rsmp_tune = rsmp("cv", folds = tune_settings$n_folds_tune)
13121312
}
13131313

1314-
if (test_names(names(tune_settings), must.include = "measure")) {
1314+
if (test_names(names(tune_settings), must.include = "measure") && !is.null(tune_settings$measure)) {
13151315
assert_list(tune_settings$measure)
13161316
if (!test_names(names(tune_settings$measure),
13171317
subset.of = valid_learner)) {
@@ -1327,13 +1327,12 @@ DoubleML = R6Class("DoubleML",
13271327
check_class(tune_settings$measure[[i_msr]], "Measure"))
13281328
}
13291329
} else {
1330-
tune_settings$measure = rep(list(NA), length(valid_learner))
1330+
tune_settings$measure = rep(list(NULL), length(valid_learner))
13311331
names(tune_settings$measure) = valid_learner
13321332
}
13331333

1334-
for (i_msr in seq_len(length(tune_settings$measure))) {
1335-
if (!test_class(tune_settings$measure[[i_msr]], "Measure")) {
1336-
this_learner = names(tune_settings$measure)[i_msr]
1334+
for (this_learner in valid_learner) {
1335+
if (!test_class(tune_settings$measure[[this_learner]], "Measure")) {
13371336
tune_settings$measure[[this_learner]] = set_default_measure(
13381337
tune_settings$measure[[this_learner]],
13391338
private$task_type[[this_learner]])

R/double_ml_iivm.R

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -245,10 +245,6 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
245245

246246
private$check_data(self$data)
247247
private$check_score(self$score)
248-
private$task_type = list(
249-
"ml_g" = NULL,
250-
"ml_m" = NULL,
251-
"ml_r" = NULL)
252248
ml_g = private$assert_learner(ml_g, "ml_g", Regr = TRUE, Classif = TRUE)
253249
ml_m = private$assert_learner(ml_m, "ml_m", Regr = FALSE, Classif = TRUE)
254250
ml_r = private$assert_learner(ml_r, "ml_r", Regr = FALSE, Classif = TRUE)

R/double_ml_irm.R

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,6 @@ DoubleMLIRM = R6Class("DoubleMLIRM",
189189

190190
private$check_data(self$data)
191191
private$check_score(self$score)
192-
private$task_type = list(
193-
"ml_g" = NULL,
194-
"ml_m" = NULL)
195192
ml_g = private$assert_learner(ml_g, "ml_g", Regr = TRUE, Classif = TRUE)
196193
ml_m = private$assert_learner(ml_m, "ml_m", Regr = FALSE, Classif = TRUE)
197194

R/double_ml_pliv.R

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
296296
return_train_preds = FALSE,
297297
task_type = private$task_type$ml_m,
298298
fold_specific_params = private$fold_specific_params)
299+
z = self$data$data_model[[self$data$z_cols]]
299300
} else {
300301
m_hat = do.call(
301302
cbind,
@@ -313,6 +314,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
313314
task_type = private$task_type$ml_m,
314315
fold_specific_params = private$fold_specific_params)
315316
}))
317+
z = self$data$data_model[, self$data$z_cols, with = FALSE]
316318
}
317319

318320
d = self$data$data_model[[self$data$treat_col]]
@@ -328,14 +330,21 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
328330
score_elements = function(y, z, d, g_hat, m_hat, r_hat, smpls) {
329331
u_hat = y - g_hat
330332
w_hat = d - r_hat
333+
v_hat = z - m_hat
331334

332335
if (self$data$n_instr == 1) {
333-
z = self$data$data_model[[self$data$z_cols]]
334-
v_hat = z - m_hat
336+
if (is.character(self$score)) {
337+
if (self$score == "partialling out") {
338+
psi_a = -w_hat * v_hat
339+
psi_b = v_hat * u_hat
340+
}
341+
psis = list(
342+
psi_a = psi_a,
343+
psi_b = psi_b)
344+
} else if (is.function(self$score)) {
345+
psis = self$score(y, z, d, g_hat, m_hat, r_hat, smpls)
346+
}
335347
} else {
336-
z = self$data$data_model[, self$data$z_cols, with = FALSE]
337-
v_hat = z - m_hat
338-
339348
stopifnot(self$apply_cross_fitting)
340349

341350
# Projection: r_hat from projection on m_hat
@@ -349,29 +358,20 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
349358
r_r_tilde = resample(task_r_tilde, ml_r_tilde, resampling_r_tilde,
350359
store_models = TRUE)
351360
r_hat_tilde = as.data.table(r_r_tilde$prediction())$response
352-
}
353-
if (is.character(self$score)) {
354-
if (self$data$n_instr == 1) {
355-
if (self$score == "partialling out") {
356-
psi_a = -w_hat * v_hat
357-
psi_b = v_hat * u_hat
358-
}
359-
} else {
361+
362+
if (is.character(self$score)) {
360363
if (self$score == "partialling out") {
361364
psi_a = -w_hat * r_hat_tilde
362365
psi_b = r_hat_tilde * u_hat
363366
}
364-
}
365-
psis = list(
366-
psi_a = psi_a,
367-
psi_b = psi_b)
368-
} else if (is.function(self$score)) {
369-
if (self$data$n_instr > 1) {
367+
psis = list(
368+
psi_a = psi_a,
369+
psi_b = psi_b)
370+
} else if (is.function(self$score)) {
370371
stop(paste(
371372
"Callable score not implemented for DoubleMLPLIV with",
372373
"partialX=TRUE and partialZ=FALSE with several instruments."))
373374
}
374-
psis = self$score(y, z, d, g_hat, m_hat, r_hat, smpls)
375375
}
376376
return(psis)
377377
},
@@ -608,7 +608,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
608608
self$data$treat_col, data_tune_list,
609609
nuisance_id = "nuis_m",
610610
param_set$ml_m, tune_settings,
611-
tune_settings$measure$ml_g,
611+
tune_settings$measure$ml_m,
612612
private$task_type$ml_m)
613613

614614
m_params = tuning_result_m$params

R/double_ml_plr.R

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,6 @@ DoubleMLPLR = R6Class("DoubleMLPLR",
147147

148148
private$check_data(self$data)
149149
private$check_score(self$score)
150-
private$task_type = list(
151-
"ml_g" = NULL,
152-
"ml_m" = NULL)
153150
ml_g = private$assert_learner(ml_g, "ml_g", Regr = TRUE, Classif = FALSE)
154151
ml_m = private$assert_learner(ml_m, "ml_m", Regr = TRUE, Classif = TRUE)
155152

R/helper.R

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ dml_cv_predict = function(learner, X_cols, y_col,
55
return_train_preds = FALSE, task_type = NULL,
66
fold_specific_params = FALSE) {
77

8-
# TODO: Asserts
8+
valid_task_type = c("regr", "classif")
9+
assertChoice(task_type, valid_task_type)
10+
# TODO: extend asserts
911

1012
if (fold_specific_params) {
1113
stopifnot(length(smpls$train_ids) == length(smpls$test_ids))
@@ -122,6 +124,7 @@ dml_cv_predict = function(learner, X_cols, y_col,
122124

123125
dml_tune = function(learner, X_cols, y_col, data_tune_list,
124126
nuisance_id, param_set, tune_settings, measure, task_type) {
127+
125128
task_tune = lapply(data_tune_list, function(x) {
126129
initiate_task(
127130
id = nuisance_id,
@@ -130,6 +133,9 @@ dml_tune = function(learner, X_cols, y_col, data_tune_list,
130133
select_cols = X_cols,
131134
task_type = task_type)
132135
})
136+
valid_task_type = c("regr", "classif")
137+
assertChoice(task_type, valid_task_type)
138+
133139
ml_learner = initiate_learner(learner, task_type, params = learner$param_set$values)
134140
tuning_instance = lapply(task_tune, function(x) {
135141
TuningInstanceSingleCrit$new(
@@ -154,6 +160,10 @@ dml_tune = function(learner, X_cols, y_col, data_tune_list,
154160

155161
extract_prediction = function(obj_resampling, task_type, n_obs,
156162
return_train_preds = FALSE) {
163+
164+
valid_task_type = c("regr", "classif")
165+
assertChoice(task_type, valid_task_type)
166+
157167
if (compareVersion(as.character(packageVersion("mlr3")), "0.11.0") < 0) {
158168
ind_name = "row_id"
159169
} else {
@@ -204,6 +214,10 @@ extract_prediction = function(obj_resampling, task_type, n_obs,
204214
}
205215

206216
initiate_learner = function(learner, task_type, params, return_train_preds = FALSE) {
217+
218+
valid_task_type = c("regr", "classif")
219+
assertChoice(task_type, valid_task_type)
220+
207221
ml_learner = learner$clone()
208222

209223
if (!is.null(params)) {
@@ -225,6 +239,9 @@ initiate_learner = function(learner, task_type, params, return_train_preds = FAL
225239

226240
# Function to initialize task (regression or classification)
227241
initiate_task = function(id, data, target, select_cols, task_type) {
242+
valid_task_type = c("regr", "classif")
243+
assertChoice(task_type, valid_task_type)
244+
228245
if (!is.null(select_cols)) {
229246
indx = (names(data) %in% c(select_cols, target))
230247
data = data[, indx, with = FALSE]
@@ -277,7 +294,10 @@ get_cond_samples = function(smpls, D) {
277294
}
278295

279296
set_default_measure = function(measure_in = NA, task_type) {
280-
if (is.na(measure_in)) {
297+
valid_task_type = c("regr", "classif")
298+
assertChoice(task_type, valid_task_type)
299+
300+
if (is.null(measure_in)) {
281301
if (task_type == "regr") {
282302
measure = msr("regr.mse")
283303
} else if (task_type == "classif") {

man/DoubleML.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/helper-13-dml_pliv_partial_x.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ fit_nuisance_pliv_partial_x = function(data, y, d, z,
123123
data_r = data[, r_indx, drop = FALSE]
124124
task_r = mlr3::TaskRegr$new(id = paste0("nuis_r_", d), backend = data_r, target = d)
125125
if (!is.null(params_r)) {
126-
ml_g$param_set$values = params_r
126+
ml_r$param_set$values = params_r
127127
}
128128

129129
resampling_r = mlr3::rsmp("custom")

0 commit comments

Comments
 (0)