Skip to content

Commit d85ce35

Browse files
authored
Merge pull request #106 from DoubleML/m-active-bindings
Use active bindings in the R6 OOP implementation
2 parents 1e67dba + 5d30302 commit d85ce35

24 files changed

+1374
-349
lines changed

R/double_ml.R

Lines changed: 296 additions & 169 deletions
Large diffs are not rendered by default.

R/double_ml_data.R

Lines changed: 159 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -21,72 +21,176 @@
2121
#' d_cols = "d")
2222
#' @export
2323
DoubleMLData = R6Class("DoubleMLData",
24-
public = list(
24+
active = list(
2525
#' @field all_variables (`character()`)\cr
2626
#' All variables available in the dataset.
27-
all_variables = NULL,
27+
all_variables = function(value) {
28+
if (missing(value)) return(names(self$data))
29+
else stop("can't set field all_variables")
30+
},
2831

2932
#' @field d_cols (`character()`)\cr
3033
#' The treatment variable(s).
31-
d_cols = NULL,
34+
d_cols = function(value) {
35+
if (missing(value)) return(private$d_cols_)
36+
else {
37+
d_cols = value # to get more meaningful assert error messages
38+
reset_value = !is.null(self$data_model)
39+
assert_character(d_cols, unique = TRUE)
40+
assert_subset(d_cols, self$all_variables)
41+
private$d_cols_ = d_cols
42+
if (reset_value) {
43+
private$check_disjoint_sets()
44+
self$set_data_model(self$d_cols[1])
45+
}
46+
}
47+
},
3248

3349
#' @field data ([`data.table`][data.table::data.table()])\cr
3450
#' Data object.
35-
data = NULL,
51+
data = function(value) {
52+
if (missing(value)) return(private$data_)
53+
else stop("can't set field data")
54+
},
3655

3756
#' @field data_model ([`data.table`][data.table::data.table()])\cr
3857
#' Internal data object that implements the causal model as specified by
3958
#' the user via `y_col`, `d_cols`, `x_cols` and `z_cols`.
40-
data_model = NULL,
59+
data_model = function(value) {
60+
if (missing(value)) return(private$data_model_)
61+
else stop("can't set field data_model")
62+
},
4163

4264
#' @field n_instr (`NULL`, `integer(1)`) \cr
4365
#' The number of instruments.
44-
n_instr = NULL,
66+
n_instr = function(value) {
67+
if (missing(value)) return(length(self$z_cols))
68+
else stop("can't set field n_instr")
69+
},
4570

4671
#' @field n_obs (`integer(1)`) \cr
4772
#' The number of observations.
48-
n_obs = NULL,
73+
n_obs = function(value) {
74+
if (missing(value)) return(dim(self$data)[1])
75+
else stop("can't set field n_obs")
76+
},
4977

5078
#' @field n_treat (`integer(1)`) \cr
5179
#' The umber of treatment variables.
52-
n_treat = NULL,
80+
n_treat = function(value) {
81+
if (missing(value)) return(length(self$d_cols))
82+
else stop("can't set field n_treat")
83+
},
5384

5485
#' @field other_treat_cols (`NULL`, `character()`) \cr
5586
#' If `use_other_treat_as_covariate` is `TRUE`, `other_treat_cols` are the
5687
#' treatment variables that are not "active" in the multiple-treatment case.
5788
#' These variables then are internally added to the covariates `x_cols` during
5889
#' the fitting stage. If `use_other_treat_as_covariate` is `FALSE`,
5990
#' `other_treat_cols` is `NULL`.
60-
other_treat_cols = NULL,
91+
other_treat_cols = function(value) {
92+
if (missing(value)) return(private$other_treat_cols_)
93+
else stop("can't set field other_treat_cols")
94+
},
6195

6296
#' @field treat_col (`character(1)`) \cr
6397
#' "Active" treatment variable in the multiple-treatment case.
64-
treat_col = NULL,
98+
treat_col = function(value) {
99+
if (missing(value)) return(private$treat_col_)
100+
else stop("can't set field treat_col")
101+
},
65102

66103
#' @field use_other_treat_as_covariate (`logical(1)`) \cr
67104
#' Indicates whether in the multiple-treatment case the other treatment
68105
#' variables should be added as covariates. Default is `TRUE`.
69-
use_other_treat_as_covariate = TRUE,
106+
use_other_treat_as_covariate = function(value) {
107+
if (missing(value)) return(private$use_other_treat_as_covariate_)
108+
else {
109+
use_other_treat_as_covariate = value # to get more meaningful assert error messages
110+
reset_value = !is.null(self$data_model)
111+
assert_logical(use_other_treat_as_covariate, len = 1)
112+
private$use_other_treat_as_covariate_ = use_other_treat_as_covariate
113+
if (reset_value) {
114+
private$check_disjoint_sets()
115+
self$set_data_model(self$d_cols[1])
116+
}
117+
}
118+
},
70119

71120
#' @field x_cols (`NULL`, `character()`) \cr
72121
#' The covariates. If `NULL`, all variables (columns of `data`) which are
73122
#' neither specified as outcome variable `y_col`, nor as treatment variables
74123
#' `d_cols`, nor as instrumental variables `z_cols` are used as covariates.
75124
#' Default is `NULL`.
76-
x_cols = NULL,
125+
x_cols = function(value) {
126+
if (missing(value)) return(private$x_cols_)
127+
else {
128+
x_cols = value # to get more meaningful assert error messages
129+
reset_value = !is.null(self$data_model)
130+
if (!is.null(x_cols)) {
131+
assert_character(x_cols, unique = TRUE)
132+
}
133+
134+
if (!is.null(x_cols)) {
135+
assert_subset(x_cols, self$all_variables)
136+
private$x_cols_ = x_cols
137+
} else {
138+
if (!is.null(self$z_cols)) {
139+
y_d_z = unique(c(self$y_col, self$d_cols, self$z_cols))
140+
private$x_cols_ = setdiff(self$all_variables, y_d_z)
141+
} else {
142+
y_d = union(self$y_col, self$d_cols)
143+
private$x_cols_ = setdiff(self$all_variables, y_d)
144+
}
145+
}
146+
if (reset_value) {
147+
private$check_disjoint_sets()
148+
self$set_data_model(self$d_cols[1])
149+
}
150+
}
151+
},
77152

78153
#' @field y_col (`character(1)`) \cr
79154
#' The outcome variable.
80-
y_col = NULL,
155+
y_col = function(value) {
156+
if (missing(value)) return(private$y_col_)
157+
else {
158+
y_col = value # to get more meaningful assert error messages
159+
reset_value = !is.null(self$data_model)
160+
assert_character(y_col, len = 1)
161+
assert_subset(y_col, self$all_variables)
162+
private$y_col_ = y_col
163+
if (reset_value) {
164+
private$check_disjoint_sets()
165+
self$set_data_model(self$d_cols[1])
166+
}
167+
}
168+
},
81169

82170
#' @field z_cols (`NULL`, `character()`) \cr
83171
#' The instrumental variables. Default is `NULL`.
84-
z_cols = NULL,
172+
z_cols = function(value) {
173+
if (missing(value)) return(private$z_cols_)
174+
else {
175+
z_cols = value # to get more meaningful assert error messages
176+
reset_value = !is.null(self$data_model)
177+
if (!is.null(z_cols)) {
178+
assert_character(z_cols, unique = TRUE)
179+
}
180+
assert_subset(z_cols, self$all_variables)
181+
private$z_cols_ = z_cols
182+
if (reset_value) {
183+
private$check_disjoint_sets()
184+
self$set_data_model(self$d_cols[1])
185+
}
186+
}
187+
}),
85188

189+
public = list(
86190
#' @description
87191
#' Creates a new instance of this [R6][R6::R6Class] class.
88192
#'
89-
#' @param data ([`data.table`][data.table::data.table()])\cr
193+
#' @param data ([`data.table`][data.table::data.table()], `data.frame()`)\cr
90194
#' Data object.
91195
#'
92196
#' @param y_col (`character(1)`) \cr
@@ -114,59 +218,46 @@ DoubleMLData = R6Class("DoubleMLData",
114218
z_cols = NULL,
115219
use_other_treat_as_covariate = TRUE) {
116220

117-
# TBD: Input data.frame
118-
119221
if (all(class(data) == "data.frame")) {
120-
stop(paste("'data' is a data.frame, use",
121-
"'double_ml_data_from_data_frame' call to instantiate",
122-
"DoubleMLData."))
222+
data = data.table(data)
123223
}
124224
assert_class(data, "data.table")
125-
if (!is.null(x_cols)) {
126-
assert_character(x_cols, unique = TRUE)
127-
}
128-
assert_character(y_col, len = 1)
129-
assert_character(d_cols, unique = TRUE)
130-
if (!is.null(z_cols)) {
131-
assert_character(z_cols, unique = TRUE)
132-
}
133-
assert_logical(use_other_treat_as_covariate, len = 1)
225+
assert_character(names(data), unique = TRUE)
134226

135-
self$data = data
136-
self$data_model = NULL
227+
private$data_ = data
137228

138229
self$y_col = y_col
139230
self$d_cols = d_cols
140231
self$z_cols = z_cols
141-
142-
if (!is.null(x_cols)) {
143-
self$x_cols = x_cols
144-
} else {
145-
if (!is.null(self$z_cols)) {
146-
y_d_z = unique(c(self$y_col, self$d_cols, self$z_cols))
147-
self$x_cols = setdiff(names(data), y_d_z)
148-
} else {
149-
y_d = union(self$y_col, self$d_cols)
150-
self$x_cols = setdiff(names(data), y_d)
151-
}
152-
}
153-
232+
self$x_cols = x_cols
154233
private$check_disjoint_sets()
155234

156-
self$treat_col = NULL
157-
self$other_treat_cols = NULL
158235
self$use_other_treat_as_covariate = use_other_treat_as_covariate
159236

160-
self$all_variables = names(self$data)
161-
self$n_treat = length(self$d_cols)
162-
self$n_instr = length(self$z_cols)
163-
self$n_obs = dim(self$data)[1]
164-
165237
# by default, we initialize to the first treatment variable
166238
self$set_data_model(d_cols[1])
167239

168240
invisible(self)
169241
},
242+
243+
#' @description
244+
#' Print DoubleMLData objects.
245+
print = function() {
246+
header = "================= DoubleMLData Object ==================\n"
247+
data_info = paste0(
248+
"Outcome variable: ", self$y_col, "\n",
249+
"Treatment variable(s): ", paste0(self$d_cols, collapse = ", "),
250+
"\n",
251+
"Covariates: ", paste0(self$x_cols, collapse = ", "), "\n",
252+
"Instrument(s): ", paste0(self$z_cols, collapse = ", "), "\n",
253+
"No. Observations: ", self$n_obs, "\n")
254+
cat(header, "\n",
255+
"\n------------------ Data summary ------------------\n",
256+
data_info,
257+
sep = "")
258+
259+
invisible(self)
260+
},
170261

171262
#' @description
172263
#' Setter function for `data_model`. The function implements the causal model
@@ -180,25 +271,20 @@ DoubleMLData = R6Class("DoubleMLData",
180271
assert_character(treatment_var, max.len = 1)
181272
assert_subset(treatment_var, self$d_cols)
182273

183-
if (treatment_var %in% self$x_cols) {
184-
stop(paste(
185-
"The specified treatment variable must not be an element of",
186-
"the covariates 'x_cols'."))
187-
}
188-
self$treat_col = treatment_var
274+
private$treat_col_ = treatment_var
189275

190276
if (self$n_treat > 1) {
191277
if (self$use_other_treat_as_covariate) {
192-
self$other_treat_cols = self$d_cols[self$d_cols != treatment_var]
278+
private$other_treat_cols_ = self$d_cols[self$d_cols != treatment_var]
193279
} else {
194280
message("Control variables do not include other treatment variables")
195-
self$other_treat_cols = NULL
281+
private$other_treat_cols_ = NULL
196282
}
197283
}
198284
col_indx = c(
199285
self$x_cols, self$y_col, self$treat_col, self$other_treat_cols,
200286
self$z_cols)
201-
self$data_model = self$data[, col_indx, with = FALSE]
287+
private$data_model_ = self$data[, col_indx, with = FALSE]
202288
stopifnot(nrow(self$data) == nrow(self$data_model))
203289

204290
# successful assigning treatment variable
@@ -209,6 +295,15 @@ DoubleMLData = R6Class("DoubleMLData",
209295
}
210296
),
211297
private = list(
298+
d_cols_ = NULL,
299+
data_ = NULL,
300+
data_model_ = NULL,
301+
other_treat_cols_ = NULL,
302+
treat_col_ = NULL,
303+
use_other_treat_as_covariate_ = NULL,
304+
x_cols_ = NULL,
305+
y_col_ = NULL,
306+
z_cols_ = NULL,
212307
check_disjoint_sets = function() {
213308
y_col = self$y_col
214309
x_cols = self$x_cols
@@ -281,11 +376,6 @@ DoubleMLData = R6Class("DoubleMLData",
281376
#' @param z_cols (`NULL`, `character()`) \cr
282377
#' The instrumental variables. Default is `NULL`.
283378
#'
284-
#' @param data_class (`character(1)`) \cr
285-
#' Class of returned object. By default, an object of class `DoubleMLData` is
286-
#' returned. Setting `data_class = "data.table"` returns an object of class
287-
#' `data.table`.
288-
#'
289379
#' @param use_other_treat_as_covariate (`logical(1)`) \cr
290380
#' Indicates whether in the multiple-treatment case the other treatment
291381
#' variables should be added as covariates. Default is `TRUE`.
@@ -303,42 +393,11 @@ DoubleMLData = R6Class("DoubleMLData",
303393
#' @export
304394
double_ml_data_from_data_frame = function(df, x_cols = NULL, y_col = NULL,
305395
d_cols = NULL, z_cols = NULL,
306-
data_class = "DoubleMLData",
307396
use_other_treat_as_covariate = TRUE) {
308-
309-
if (is.null(y_col) | is.null(d_cols)) {
310-
stop("Column indices y_col and d_cols not specified.")
311-
}
312-
assert_choice(data_class, c("DoubleMLData", "data.table"))
313-
314-
if (!is.null(x_cols)) {
315-
assert_character(x_cols, unique = TRUE)
316-
}
317-
assert_character(y_col, len = 1)
318-
assert_character(d_cols, unique = TRUE)
319-
320-
if (!is.null(z_cols)) {
321-
assert_character(z_cols, unique = TRUE)
322-
}
323-
if (!is.null(x_cols)) {
324-
x_cols = x_cols
325-
} else {
326-
if (!is.null(z_cols)) {
327-
y_d_z = unique(c(y_col, d_cols, z_cols))
328-
x_cols = setdiff(names(df), y_d_z)
329-
} else {
330-
y_d = union(y_col, d_cols)
331-
x_cols = setdiff(names(df), y_d)
332-
}
333-
}
334-
col_indx = c(x_cols, y_col, d_cols, z_cols)
335-
data = data.table(df)[, col_indx, with = FALSE]
336-
if (data_class == "DoubleMLData") {
337-
data = DoubleMLData$new(data,
338-
x_cols = x_cols, y_col = y_col, d_cols = d_cols,
339-
z_cols = z_cols,
340-
use_other_treat_as_covariate = use_other_treat_as_covariate)
341-
}
397+
data = DoubleMLData$new(df,
398+
x_cols = x_cols, y_col = y_col, d_cols = d_cols,
399+
z_cols = z_cols,
400+
use_other_treat_as_covariate = use_other_treat_as_covariate)
342401
return(data)
343402
}
344403

@@ -366,7 +425,6 @@ double_ml_data_from_data_frame = function(df, x_cols = NULL, y_col = NULL,
366425
#' returned. Setting `data_class = "data.table"` returns an object of class
367426
#' `data.table`.
368427
#'
369-
#'
370428
#' @param use_other_treat_as_covariate (`logical(1)`) \cr
371429
#' Indicates whether in the multiple-treatment case the other treatment
372430
#' variables should be added as covariates. Default is `TRUE`.

0 commit comments

Comments
 (0)