Skip to content

Commit 5526694

Browse files
authored
Merge pull request #120 from DoubleML/m-apply-styler
apply styler as described in the wiki
2 parents 2220df3 + 71048d7 commit 5526694

File tree

48 files changed

+2989
-2610
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+2989
-2610
lines changed

R/double_ml.R

Lines changed: 195 additions & 112 deletions
Large diffs are not rendered by default.

R/double_ml_data.R

Lines changed: 63 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,19 @@ DoubleMLData = R6Class("DoubleMLData",
2525
#' @field all_variables (`character()`)\cr
2626
#' All variables available in the dataset.
2727
all_variables = function(value) {
28-
if (missing(value)) return(names(self$data))
29-
else stop("can't set field all_variables")
28+
if (missing(value)) {
29+
return(names(self$data))
30+
} else {
31+
stop("can't set field all_variables")
32+
}
3033
},
3134

3235
#' @field d_cols (`character()`)\cr
3336
#' The treatment variable(s).
3437
d_cols = function(value) {
35-
if (missing(value)) return(private$d_cols_)
36-
else {
38+
if (missing(value)) {
39+
return(private$d_cols_)
40+
} else {
3741
d_cols = value # to get more meaningful assert error messages
3842
reset_value = !is.null(self$data_model)
3943
assert_character(d_cols, unique = TRUE)
@@ -49,37 +53,52 @@ DoubleMLData = R6Class("DoubleMLData",
4953
#' @field data ([`data.table`][data.table::data.table()])\cr
5054
#' Data object.
5155
data = function(value) {
52-
if (missing(value)) return(private$data_)
53-
else stop("can't set field data")
56+
if (missing(value)) {
57+
return(private$data_)
58+
} else {
59+
stop("can't set field data")
60+
}
5461
},
5562

5663
#' @field data_model ([`data.table`][data.table::data.table()])\cr
5764
#' Internal data object that implements the causal model as specified by
5865
#' the user via `y_col`, `d_cols`, `x_cols` and `z_cols`.
5966
data_model = function(value) {
60-
if (missing(value)) return(private$data_model_)
61-
else stop("can't set field data_model")
67+
if (missing(value)) {
68+
return(private$data_model_)
69+
} else {
70+
stop("can't set field data_model")
71+
}
6272
},
6373

6474
#' @field n_instr (`NULL`, `integer(1)`) \cr
6575
#' The number of instruments.
6676
n_instr = function(value) {
67-
if (missing(value)) return(length(self$z_cols))
68-
else stop("can't set field n_instr")
77+
if (missing(value)) {
78+
return(length(self$z_cols))
79+
} else {
80+
stop("can't set field n_instr")
81+
}
6982
},
7083

7184
#' @field n_obs (`integer(1)`) \cr
7285
#' The number of observations.
7386
n_obs = function(value) {
74-
if (missing(value)) return(dim(self$data)[1])
75-
else stop("can't set field n_obs")
87+
if (missing(value)) {
88+
return(dim(self$data)[1])
89+
} else {
90+
stop("can't set field n_obs")
91+
}
7692
},
7793

7894
#' @field n_treat (`integer(1)`) \cr
7995
#' The umber of treatment variables.
8096
n_treat = function(value) {
81-
if (missing(value)) return(length(self$d_cols))
82-
else stop("can't set field n_treat")
97+
if (missing(value)) {
98+
return(length(self$d_cols))
99+
} else {
100+
stop("can't set field n_treat")
101+
}
83102
},
84103

85104
#' @field other_treat_cols (`NULL`, `character()`) \cr
@@ -89,23 +108,30 @@ DoubleMLData = R6Class("DoubleMLData",
89108
#' the fitting stage. If `use_other_treat_as_covariate` is `FALSE`,
90109
#' `other_treat_cols` is `NULL`.
91110
other_treat_cols = function(value) {
92-
if (missing(value)) return(private$other_treat_cols_)
93-
else stop("can't set field other_treat_cols")
111+
if (missing(value)) {
112+
return(private$other_treat_cols_)
113+
} else {
114+
stop("can't set field other_treat_cols")
115+
}
94116
},
95117

96118
#' @field treat_col (`character(1)`) \cr
97119
#' "Active" treatment variable in the multiple-treatment case.
98120
treat_col = function(value) {
99-
if (missing(value)) return(private$treat_col_)
100-
else stop("can't set field treat_col")
121+
if (missing(value)) {
122+
return(private$treat_col_)
123+
} else {
124+
stop("can't set field treat_col")
125+
}
101126
},
102127

103128
#' @field use_other_treat_as_covariate (`logical(1)`) \cr
104129
#' Indicates whether in the multiple-treatment case the other treatment
105130
#' variables should be added as covariates. Default is `TRUE`.
106131
use_other_treat_as_covariate = function(value) {
107-
if (missing(value)) return(private$use_other_treat_as_covariate_)
108-
else {
132+
if (missing(value)) {
133+
return(private$use_other_treat_as_covariate_)
134+
} else {
109135
use_other_treat_as_covariate = value # to get more meaningful assert error messages
110136
reset_value = !is.null(self$data_model)
111137
assert_logical(use_other_treat_as_covariate, len = 1)
@@ -123,8 +149,9 @@ DoubleMLData = R6Class("DoubleMLData",
123149
#' `d_cols`, nor as instrumental variables `z_cols` are used as covariates.
124150
#' Default is `NULL`.
125151
x_cols = function(value) {
126-
if (missing(value)) return(private$x_cols_)
127-
else {
152+
if (missing(value)) {
153+
return(private$x_cols_)
154+
} else {
128155
x_cols = value # to get more meaningful assert error messages
129156
reset_value = !is.null(self$data_model)
130157
if (!is.null(x_cols)) {
@@ -153,8 +180,9 @@ DoubleMLData = R6Class("DoubleMLData",
153180
#' @field y_col (`character(1)`) \cr
154181
#' The outcome variable.
155182
y_col = function(value) {
156-
if (missing(value)) return(private$y_col_)
157-
else {
183+
if (missing(value)) {
184+
return(private$y_col_)
185+
} else {
158186
y_col = value # to get more meaningful assert error messages
159187
reset_value = !is.null(self$data_model)
160188
assert_character(y_col, len = 1)
@@ -170,8 +198,9 @@ DoubleMLData = R6Class("DoubleMLData",
170198
#' @field z_cols (`NULL`, `character()`) \cr
171199
#' The instrumental variables. Default is `NULL`.
172200
z_cols = function(value) {
173-
if (missing(value)) return(private$z_cols_)
174-
else {
201+
if (missing(value)) {
202+
return(private$z_cols_)
203+
} else {
175204
z_cols = value # to get more meaningful assert error messages
176205
reset_value = !is.null(self$data_model)
177206
if (!is.null(z_cols)) {
@@ -239,7 +268,7 @@ DoubleMLData = R6Class("DoubleMLData",
239268

240269
invisible(self)
241270
},
242-
271+
243272
#' @description
244273
#' Print DoubleMLData objects.
245274
print = function() {
@@ -252,10 +281,10 @@ DoubleMLData = R6Class("DoubleMLData",
252281
"Instrument(s): ", paste0(self$z_cols, collapse = ", "), "\n",
253282
"No. Observations: ", self$n_obs, "\n")
254283
cat(header, "\n",
255-
"\n------------------ Data summary ------------------\n",
256-
data_info,
257-
sep = "")
258-
284+
"\n------------------ Data summary ------------------\n",
285+
data_info,
286+
sep = "")
287+
259288
invisible(self)
260289
},
261290

@@ -395,9 +424,9 @@ double_ml_data_from_data_frame = function(df, x_cols = NULL, y_col = NULL,
395424
d_cols = NULL, z_cols = NULL,
396425
use_other_treat_as_covariate = TRUE) {
397426
data = DoubleMLData$new(df,
398-
x_cols = x_cols, y_col = y_col, d_cols = d_cols,
399-
z_cols = z_cols,
400-
use_other_treat_as_covariate = use_other_treat_as_covariate)
427+
x_cols = x_cols, y_col = y_col, d_cols = d_cols,
428+
z_cols = z_cols,
429+
use_other_treat_as_covariate = use_other_treat_as_covariate)
401430
return(data)
402431
}
403432

R/double_ml_iivm.R

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -103,24 +103,33 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
103103
#' always takers in the sample. The entry `never_takers` (`logical(1)`)
104104
#' speficies whether there are never takers in the sample.
105105
subgroups = function(value) {
106-
if (missing(value)) return(private$subgroups_)
107-
else stop("can't set field subgroups")
106+
if (missing(value)) {
107+
return(private$subgroups_)
108+
} else {
109+
stop("can't set field subgroups")
110+
}
108111
},
109112

110113
#' @field trimming_rule (`character(1)`) \cr
111114
#' A `character(1)` specifying the trimming approach.
112115
trimming_rule = function(value) {
113-
if (missing(value)) return(private$trimming_rule_)
114-
else stop("can't set field trimming_rule")
116+
if (missing(value)) {
117+
return(private$trimming_rule_)
118+
} else {
119+
stop("can't set field trimming_rule")
120+
}
115121
},
116122

117123
#' @field trimming_threshold (`numeric(1)`) \cr
118124
#' The threshold used for timming.
119125
trimming_threshold = function(value) {
120-
if (missing(value)) return(private$trimming_threshold_)
121-
else stop("can't set field trimming_threshold")
126+
if (missing(value)) {
127+
return(private$trimming_threshold_)
128+
} else {
129+
stop("can't set field trimming_threshold")
130+
}
122131
}),
123-
132+
124133
public = list(
125134
#' @description
126135
#' Creates a new instance of this R6 class.
@@ -229,7 +238,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
229238
dml_procedure,
230239
draw_sample_splitting,
231240
apply_cross_fitting)
232-
241+
233242
private$check_data(self$data)
234243
private$check_score(self$score)
235244
private$learner_class = list(
@@ -484,7 +493,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
484493
"needs to be specified as treatment variable.")
485494
if (one_treat) {
486495
binary_treat = test_integerish(obj_dml_data$data[[obj_dml_data$d_cols]],
487-
lower = 0, upper = 1)
496+
lower = 0, upper = 1)
488497
if (!(one_treat & binary_treat)) {
489498
stop(err_msg)
490499
}
@@ -500,7 +509,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
500509
"needs to be specified as instrumental variable.")
501510
if (one_instr) {
502511
binary_instr = test_integerish(obj_dml_data$data[[obj_dml_data$z_cols]],
503-
lower = 0, upper = 1)
512+
lower = 0, upper = 1)
504513
if (!(one_instr & binary_instr)) {
505514
stop(err_msg)
506515
}

R/double_ml_irm.R

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,20 +77,26 @@
7777
#'
7878
#' @export
7979
DoubleMLIRM = R6Class("DoubleMLIRM",
80-
inherit = DoubleML,
80+
inherit = DoubleML,
8181
active = list(
8282
#' @field trimming_rule (`character(1)`) \cr
8383
#' A `character(1)` specifying the trimming approach.
8484
trimming_rule = function(value) {
85-
if (missing(value)) return(private$trimming_rule_)
86-
else stop("can't set field trimming_rule")
85+
if (missing(value)) {
86+
return(private$trimming_rule_)
87+
} else {
88+
stop("can't set field trimming_rule")
89+
}
8790
},
8891

8992
#' @field trimming_threshold (`numeric(1)`) \cr
9093
#' The threshold used for timming.
9194
trimming_threshold = function(value) {
92-
if (missing(value)) return(private$trimming_threshold_)
93-
else stop("can't set field trimming_threshold")
95+
if (missing(value)) {
96+
return(private$trimming_threshold_)
97+
} else {
98+
stop("can't set field trimming_threshold")
99+
}
94100
}),
95101

96102
public = list(
@@ -381,7 +387,7 @@ DoubleMLIRM = R6Class("DoubleMLIRM",
381387
"needs to be specified as treatment variable.")
382388
if (one_treat) {
383389
binary_treat = test_integerish(obj_dml_data$data[[obj_dml_data$d_cols]],
384-
lower = 0, upper = 1)
390+
lower = 0, upper = 1)
385391
if (!(one_treat & binary_treat)) {
386392
stop(err_msg)
387393
}

R/double_ml_pliv.R

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -74,17 +74,23 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
7474
#' @field partialX (`logical(1)`) \cr
7575
#' Indicates whether covariates \eqn{X} should be partialled out.
7676
partialX = function(value) {
77-
if (missing(value)) return(private$partialX_)
78-
else stop("can't set field partialX")
77+
if (missing(value)) {
78+
return(private$partialX_)
79+
} else {
80+
stop("can't set field partialX")
81+
}
7982
},
8083

8184
#' @field partialZ (`logical(1)`) \cr
8285
#' Indicates whether instruments \eqn{Z} should be partialled out.
8386
partialZ = function(value) {
84-
if (missing(value)) return(private$partialZ_)
85-
else stop("can't set field partialZ")
87+
if (missing(value)) {
88+
return(private$partialZ_)
89+
} else {
90+
stop("can't set field partialZ")
91+
}
8692
}),
87-
93+
8894
public = list(
8995
#' @description
9096
#' Creates a new instance of this R6 class.
@@ -182,7 +188,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
182188
dml_procedure,
183189
draw_sample_splitting,
184190
apply_cross_fitting)
185-
191+
186192
private$check_data(self$data)
187193
private$check_score(self$score)
188194
assert_logical(partialX, len = 1)
@@ -358,8 +364,9 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
358364
psi_b = psi_b)
359365
} else if (is.function(self$score)) {
360366
if (self$data$n_instr > 1) {
361-
stop(paste("Callable score not implemented for DoubleMLPLIV with",
362-
"partialX=TRUE and partialZ=FALSE with several instruments."))
367+
stop(paste(
368+
"Callable score not implemented for DoubleMLPLIV with",
369+
"partialX=TRUE and partialZ=FALSE with several instruments."))
363370
}
364371
psis = self$score(y, z, d, g_hat, m_hat, r_hat, smpls)
365372
}
@@ -394,7 +401,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
394401
m_hat = m_hat_list$preds
395402
data_aux_list = lapply(m_hat_list$train_preds, function(x) {
396403
setnafill(data.table(self$data$data_model, "m_hat_on_train" = x),
397-
fill = -9999.99) # mlr3 does not allow NA's (values are not used)
404+
fill = -9999.99) # mlr3 does not allow NA's (values are not used)
398405
})
399406

400407
m_hat_tilde = dml_cv_predict(self$learner$ml_r,
@@ -425,8 +432,9 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
425432
psi_a = psi_a,
426433
psi_b = psi_b)
427434
} else if (is.function(self$score)) {
428-
stop(paste("Callable score not implemented for DoubleMLPLIV",
429-
"with partialX=TRUE and partialZ=TRUE."))
435+
stop(paste(
436+
"Callable score not implemented for DoubleMLPLIV",
437+
"with partialX=TRUE and partialZ=TRUE."))
430438
# res = self$score(y, d, g_hat, m_hat, m_hat_tilde)
431439
}
432440
res$preds = list(
@@ -464,8 +472,9 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
464472
}
465473
res = list(psi_a = psi_a, psi_b = psi_b)
466474
} else if (is.function(self$score)) {
467-
stop(paste("Callable score not implemented for DoubleMLPLIV",
468-
"with partialX=FALSE and partialZ=TRUE."))
475+
stop(paste(
476+
"Callable score not implemented for DoubleMLPLIV",
477+
"with partialX=FALSE and partialZ=TRUE."))
469478
# res = self$score(y, z, d, r_hat)
470479
}
471480
res$preds = list("ml_r" = r_hat)

0 commit comments

Comments
 (0)