Skip to content

Commit 65facd8

Browse files
committed
prevent using the subclassed methods check_score and check_data when constructing DoubleML objects (see also DoubleML/doubleml-for-py#103)
1 parent 1c2e5db commit 65facd8

File tree

5 files changed

+15
-6
lines changed

5 files changed

+15
-6
lines changed

R/double_ml.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -829,7 +829,6 @@ DoubleML = R6Class("DoubleML",
829829
# check and pick up obj_dml_data
830830

831831
assert_class(data, "DoubleMLData")
832-
private$check_data(data)
833832
self$data = data
834833

835834
# initialize learners and parameters which are set model specific
@@ -853,7 +852,7 @@ DoubleML = R6Class("DoubleML",
853852
# check and set dml_procedure and score
854853
assert_choice(dml_procedure, c("dml1", "dml2"))
855854
self$dml_procedure = dml_procedure
856-
self$score = private$check_score(score)
855+
self$score = score
857856

858857
if (self$n_folds == 1 & self$apply_cross_fitting) {
859858
message(paste(

R/double_ml_iivm.R

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,9 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
218218
dml_procedure,
219219
draw_sample_splitting,
220220
apply_cross_fitting)
221+
222+
private$check_data(self$data)
223+
private$check_score(self$score)
221224
private$learner_class = list(
222225
"ml_g" = NULL,
223226
"ml_m" = NULL,
@@ -456,7 +459,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
456459
valid_score = c("LATE")
457460
assertChoice(score, valid_score)
458461
}
459-
return(score)
462+
return()
460463
},
461464
check_data = function(obj_dml_data) {
462465
one_treat = (obj_dml_data$n_treat == 1)

R/double_ml_irm.R

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,8 @@ DoubleMLIRM = R6Class("DoubleMLIRM",
169169
draw_sample_splitting,
170170
apply_cross_fitting)
171171

172+
private$check_data(self$data)
173+
private$check_score(self$score)
172174
private$learner_class = list(
173175
"ml_g" = NULL,
174176
"ml_m" = NULL)
@@ -353,7 +355,7 @@ DoubleMLIRM = R6Class("DoubleMLIRM",
353355
assertChoice(score, valid_score)
354356
}
355357
}
356-
return(score)
358+
return()
357359
},
358360
check_data = function(obj_dml_data) {
359361
if (!is.null(obj_dml_data$z_cols)) {

R/double_ml_pliv.R

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
174174
dml_procedure,
175175
draw_sample_splitting,
176176
apply_cross_fitting)
177+
178+
private$check_data(self$data)
179+
private$check_score(self$score)
177180
assert_logical(partialX, len = 1)
178181
assert_logical(partialZ, len = 1)
179182
self$partialX = partialX
@@ -668,7 +671,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
668671
valid_score = c("partialling out")
669672
assertChoice(score, valid_score)
670673
}
671-
return(score)
674+
return()
672675
},
673676
check_data = function(obj_dml_data) {
674677
return()

R/double_ml_plr.R

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,8 @@ DoubleMLPLR = R6Class("DoubleMLPLR",
145145
draw_sample_splitting,
146146
apply_cross_fitting)
147147

148+
private$check_data(self$data)
149+
private$check_score(self$score)
148150
private$learner_class = list(
149151
"ml_g" = NULL,
150152
"ml_m" = NULL)
@@ -261,7 +263,7 @@ DoubleMLPLR = R6Class("DoubleMLPLR",
261263
valid_score = c("IV-type", "partialling out")
262264
assertChoice(score, valid_score)
263265
}
264-
return(score)
266+
return()
265267
},
266268
check_data = function(obj_dml_data) {
267269
if (!is.null(obj_dml_data$z_cols)) {

0 commit comments

Comments
 (0)