2121# ' d_cols = "d")
2222# ' @export
2323DoubleMLData = R6Class(" DoubleMLData" ,
24- public = list (
24+ active = list (
2525 # ' @field all_variables (`character()`)\cr
2626 # ' All variables available in the dataset.
27- all_variables = NULL ,
27+ all_variables = function (value ) {
28+ if (missing(value )) return (names(self $ data ))
29+ else stop(" can't set field all_variables" )
30+ },
2831
2932 # ' @field d_cols (`character()`)\cr
3033 # ' The treatment variable(s).
31- d_cols = NULL ,
34+ d_cols = function (value ) {
35+ if (missing(value )) return (private $ d_cols_ )
36+ else {
37+ d_cols = value # to get more meaningful assert error messages
38+ reset_value = ! is.null(self $ data_model )
39+ assert_character(d_cols , unique = TRUE )
40+ assert_subset(d_cols , self $ all_variables )
41+ private $ d_cols_ = d_cols
42+ if (reset_value ) {
43+ private $ check_disjoint_sets()
44+ self $ set_data_model(self $ d_cols [1 ])
45+ }
46+ }
47+ },
3248
3349 # ' @field data ([`data.table`][data.table::data.table()])\cr
3450 # ' Data object.
35- data = NULL ,
51+ data = function (value ) {
52+ if (missing(value )) return (private $ data_ )
53+ else stop(" can't set field data" )
54+ },
3655
3756 # ' @field data_model ([`data.table`][data.table::data.table()])\cr
3857 # ' Internal data object that implements the causal model as specified by
3958 # ' the user via `y_col`, `d_cols`, `x_cols` and `z_cols`.
40- data_model = NULL ,
59+ data_model = function (value ) {
60+ if (missing(value )) return (private $ data_model_ )
61+ else stop(" can't set field data_model" )
62+ },
4163
4264 # ' @field n_instr (`NULL`, `integer(1)`) \cr
4365 # ' The number of instruments.
44- n_instr = NULL ,
66+ n_instr = function (value ) {
67+ if (missing(value )) return (length(self $ z_cols ))
68+ else stop(" can't set field n_instr" )
69+ },
4570
4671 # ' @field n_obs (`integer(1)`) \cr
4772 # ' The number of observations.
48- n_obs = NULL ,
73+ n_obs = function (value ) {
74+ if (missing(value )) return (dim(self $ data )[1 ])
75+ else stop(" can't set field n_obs" )
76+ },
4977
5078 # ' @field n_treat (`integer(1)`) \cr
5179 # ' The umber of treatment variables.
52- n_treat = NULL ,
80+ n_treat = function (value ) {
81+ if (missing(value )) return (length(self $ d_cols ))
82+ else stop(" can't set field n_treat" )
83+ },
5384
5485 # ' @field other_treat_cols (`NULL`, `character()`) \cr
5586 # ' If `use_other_treat_as_covariate` is `TRUE`, `other_treat_cols` are the
5687 # ' treatment variables that are not "active" in the multiple-treatment case.
5788 # ' These variables then are internally added to the covariates `x_cols` during
5889 # ' the fitting stage. If `use_other_treat_as_covariate` is `FALSE`,
5990 # ' `other_treat_cols` is `NULL`.
60- other_treat_cols = NULL ,
91+ other_treat_cols = function (value ) {
92+ if (missing(value )) return (private $ other_treat_cols_ )
93+ else stop(" can't set field other_treat_cols" )
94+ },
6195
6296 # ' @field treat_col (`character(1)`) \cr
6397 # ' "Active" treatment variable in the multiple-treatment case.
64- treat_col = NULL ,
98+ treat_col = function (value ) {
99+ if (missing(value )) return (private $ treat_col_ )
100+ else stop(" can't set field treat_col" )
101+ },
65102
66103 # ' @field use_other_treat_as_covariate (`logical(1)`) \cr
67104 # ' Indicates whether in the multiple-treatment case the other treatment
68105 # ' variables should be added as covariates. Default is `TRUE`.
69- use_other_treat_as_covariate = TRUE ,
106+ use_other_treat_as_covariate = function (value ) {
107+ if (missing(value )) return (private $ use_other_treat_as_covariate_ )
108+ else {
109+ use_other_treat_as_covariate = value # to get more meaningful assert error messages
110+ reset_value = ! is.null(self $ data_model )
111+ assert_logical(use_other_treat_as_covariate , len = 1 )
112+ private $ use_other_treat_as_covariate_ = use_other_treat_as_covariate
113+ if (reset_value ) {
114+ private $ check_disjoint_sets()
115+ self $ set_data_model(self $ d_cols [1 ])
116+ }
117+ }
118+ },
70119
71120 # ' @field x_cols (`NULL`, `character()`) \cr
72121 # ' The covariates. If `NULL`, all variables (columns of `data`) which are
73122 # ' neither specified as outcome variable `y_col`, nor as treatment variables
74123 # ' `d_cols`, nor as instrumental variables `z_cols` are used as covariates.
75124 # ' Default is `NULL`.
76- x_cols = NULL ,
125+ x_cols = function (value ) {
126+ if (missing(value )) return (private $ x_cols_ )
127+ else {
128+ x_cols = value # to get more meaningful assert error messages
129+ reset_value = ! is.null(self $ data_model )
130+ if (! is.null(x_cols )) {
131+ assert_character(x_cols , unique = TRUE )
132+ }
133+
134+ if (! is.null(x_cols )) {
135+ assert_subset(x_cols , self $ all_variables )
136+ private $ x_cols_ = x_cols
137+ } else {
138+ if (! is.null(self $ z_cols )) {
139+ y_d_z = unique(c(self $ y_col , self $ d_cols , self $ z_cols ))
140+ private $ x_cols_ = setdiff(self $ all_variables , y_d_z )
141+ } else {
142+ y_d = union(self $ y_col , self $ d_cols )
143+ private $ x_cols_ = setdiff(self $ all_variables , y_d )
144+ }
145+ }
146+ if (reset_value ) {
147+ private $ check_disjoint_sets()
148+ self $ set_data_model(self $ d_cols [1 ])
149+ }
150+ }
151+ },
77152
78153 # ' @field y_col (`character(1)`) \cr
79154 # ' The outcome variable.
80- y_col = NULL ,
155+ y_col = function (value ) {
156+ if (missing(value )) return (private $ y_col_ )
157+ else {
158+ y_col = value # to get more meaningful assert error messages
159+ reset_value = ! is.null(self $ data_model )
160+ assert_character(y_col , len = 1 )
161+ assert_subset(y_col , self $ all_variables )
162+ private $ y_col_ = y_col
163+ if (reset_value ) {
164+ private $ check_disjoint_sets()
165+ self $ set_data_model(self $ d_cols [1 ])
166+ }
167+ }
168+ },
81169
82170 # ' @field z_cols (`NULL`, `character()`) \cr
83171 # ' The instrumental variables. Default is `NULL`.
84- z_cols = NULL ,
172+ z_cols = function (value ) {
173+ if (missing(value )) return (private $ z_cols_ )
174+ else {
175+ z_cols = value # to get more meaningful assert error messages
176+ reset_value = ! is.null(self $ data_model )
177+ if (! is.null(z_cols )) {
178+ assert_character(z_cols , unique = TRUE )
179+ }
180+ assert_subset(z_cols , self $ all_variables )
181+ private $ z_cols_ = z_cols
182+ if (reset_value ) {
183+ private $ check_disjoint_sets()
184+ self $ set_data_model(self $ d_cols [1 ])
185+ }
186+ }
187+ }),
85188
189+ public = list (
86190 # ' @description
87191 # ' Creates a new instance of this [R6][R6::R6Class] class.
88192 # '
89- # ' @param data ([`data.table`][data.table::data.table()])\cr
193+ # ' @param data ([`data.table`][data.table::data.table()], `data.frame()` )\cr
90194 # ' Data object.
91195 # '
92196 # ' @param y_col (`character(1)`) \cr
@@ -114,59 +218,46 @@ DoubleMLData = R6Class("DoubleMLData",
114218 z_cols = NULL ,
115219 use_other_treat_as_covariate = TRUE ) {
116220
117- # TBD: Input data.frame
118-
119221 if (all(class(data ) == " data.frame" )) {
120- stop(paste(" 'data' is a data.frame, use" ,
121- " 'double_ml_data_from_data_frame' call to instantiate" ,
122- " DoubleMLData." ))
222+ data = data.table(data )
123223 }
124224 assert_class(data , " data.table" )
125- if (! is.null(x_cols )) {
126- assert_character(x_cols , unique = TRUE )
127- }
128- assert_character(y_col , len = 1 )
129- assert_character(d_cols , unique = TRUE )
130- if (! is.null(z_cols )) {
131- assert_character(z_cols , unique = TRUE )
132- }
133- assert_logical(use_other_treat_as_covariate , len = 1 )
225+ assert_character(names(data ), unique = TRUE )
134226
135- self $ data = data
136- self $ data_model = NULL
227+ private $ data_ = data
137228
138229 self $ y_col = y_col
139230 self $ d_cols = d_cols
140231 self $ z_cols = z_cols
141-
142- if (! is.null(x_cols )) {
143- self $ x_cols = x_cols
144- } else {
145- if (! is.null(self $ z_cols )) {
146- y_d_z = unique(c(self $ y_col , self $ d_cols , self $ z_cols ))
147- self $ x_cols = setdiff(names(data ), y_d_z )
148- } else {
149- y_d = union(self $ y_col , self $ d_cols )
150- self $ x_cols = setdiff(names(data ), y_d )
151- }
152- }
153-
232+ self $ x_cols = x_cols
154233 private $ check_disjoint_sets()
155234
156- self $ treat_col = NULL
157- self $ other_treat_cols = NULL
158235 self $ use_other_treat_as_covariate = use_other_treat_as_covariate
159236
160- self $ all_variables = names(self $ data )
161- self $ n_treat = length(self $ d_cols )
162- self $ n_instr = length(self $ z_cols )
163- self $ n_obs = dim(self $ data )[1 ]
164-
165237 # by default, we initialize to the first treatment variable
166238 self $ set_data_model(d_cols [1 ])
167239
168240 invisible (self )
169241 },
242+
243+ # ' @description
244+ # ' Print DoubleMLData objects.
245+ print = function () {
246+ header = " ================= DoubleMLData Object ==================\n "
247+ data_info = paste0(
248+ " Outcome variable: " , self $ y_col , " \n " ,
249+ " Treatment variable(s): " , paste0(self $ d_cols , collapse = " , " ),
250+ " \n " ,
251+ " Covariates: " , paste0(self $ x_cols , collapse = " , " ), " \n " ,
252+ " Instrument(s): " , paste0(self $ z_cols , collapse = " , " ), " \n " ,
253+ " No. Observations: " , self $ n_obs , " \n " )
254+ cat(header , " \n " ,
255+ " \n ------------------ Data summary ------------------\n " ,
256+ data_info ,
257+ sep = " " )
258+
259+ invisible (self )
260+ },
170261
171262 # ' @description
172263 # ' Setter function for `data_model`. The function implements the causal model
@@ -180,25 +271,20 @@ DoubleMLData = R6Class("DoubleMLData",
180271 assert_character(treatment_var , max.len = 1 )
181272 assert_subset(treatment_var , self $ d_cols )
182273
183- if (treatment_var %in% self $ x_cols ) {
184- stop(paste(
185- " The specified treatment variable must not be an element of" ,
186- " the covariates 'x_cols'." ))
187- }
188- self $ treat_col = treatment_var
274+ private $ treat_col_ = treatment_var
189275
190276 if (self $ n_treat > 1 ) {
191277 if (self $ use_other_treat_as_covariate ) {
192- self $ other_treat_cols = self $ d_cols [self $ d_cols != treatment_var ]
278+ private $ other_treat_cols_ = self $ d_cols [self $ d_cols != treatment_var ]
193279 } else {
194280 message(" Control variables do not include other treatment variables" )
195- self $ other_treat_cols = NULL
281+ private $ other_treat_cols_ = NULL
196282 }
197283 }
198284 col_indx = c(
199285 self $ x_cols , self $ y_col , self $ treat_col , self $ other_treat_cols ,
200286 self $ z_cols )
201- self $ data_model = self $ data [, col_indx , with = FALSE ]
287+ private $ data_model_ = self $ data [, col_indx , with = FALSE ]
202288 stopifnot(nrow(self $ data ) == nrow(self $ data_model ))
203289
204290 # successful assigning treatment variable
@@ -209,6 +295,15 @@ DoubleMLData = R6Class("DoubleMLData",
209295 }
210296 ),
211297 private = list (
298+ d_cols_ = NULL ,
299+ data_ = NULL ,
300+ data_model_ = NULL ,
301+ other_treat_cols_ = NULL ,
302+ treat_col_ = NULL ,
303+ use_other_treat_as_covariate_ = NULL ,
304+ x_cols_ = NULL ,
305+ y_col_ = NULL ,
306+ z_cols_ = NULL ,
212307 check_disjoint_sets = function () {
213308 y_col = self $ y_col
214309 x_cols = self $ x_cols
@@ -281,11 +376,6 @@ DoubleMLData = R6Class("DoubleMLData",
281376# ' @param z_cols (`NULL`, `character()`) \cr
282377# ' The instrumental variables. Default is `NULL`.
283378# '
284- # ' @param data_class (`character(1)`) \cr
285- # ' Class of returned object. By default, an object of class `DoubleMLData` is
286- # ' returned. Setting `data_class = "data.table"` returns an object of class
287- # ' `data.table`.
288- # '
289379# ' @param use_other_treat_as_covariate (`logical(1)`) \cr
290380# ' Indicates whether in the multiple-treatment case the other treatment
291381# ' variables should be added as covariates. Default is `TRUE`.
@@ -303,42 +393,11 @@ DoubleMLData = R6Class("DoubleMLData",
303393# ' @export
304394double_ml_data_from_data_frame = function (df , x_cols = NULL , y_col = NULL ,
305395 d_cols = NULL , z_cols = NULL ,
306- data_class = " DoubleMLData" ,
307396 use_other_treat_as_covariate = TRUE ) {
308-
309- if (is.null(y_col ) | is.null(d_cols )) {
310- stop(" Column indices y_col and d_cols not specified." )
311- }
312- assert_choice(data_class , c(" DoubleMLData" , " data.table" ))
313-
314- if (! is.null(x_cols )) {
315- assert_character(x_cols , unique = TRUE )
316- }
317- assert_character(y_col , len = 1 )
318- assert_character(d_cols , unique = TRUE )
319-
320- if (! is.null(z_cols )) {
321- assert_character(z_cols , unique = TRUE )
322- }
323- if (! is.null(x_cols )) {
324- x_cols = x_cols
325- } else {
326- if (! is.null(z_cols )) {
327- y_d_z = unique(c(y_col , d_cols , z_cols ))
328- x_cols = setdiff(names(df ), y_d_z )
329- } else {
330- y_d = union(y_col , d_cols )
331- x_cols = setdiff(names(df ), y_d )
332- }
333- }
334- col_indx = c(x_cols , y_col , d_cols , z_cols )
335- data = data.table(df )[, col_indx , with = FALSE ]
336- if (data_class == " DoubleMLData" ) {
337- data = DoubleMLData $ new(data ,
338- x_cols = x_cols , y_col = y_col , d_cols = d_cols ,
339- z_cols = z_cols ,
340- use_other_treat_as_covariate = use_other_treat_as_covariate )
341- }
397+ data = DoubleMLData $ new(df ,
398+ x_cols = x_cols , y_col = y_col , d_cols = d_cols ,
399+ z_cols = z_cols ,
400+ use_other_treat_as_covariate = use_other_treat_as_covariate )
342401 return (data )
343402}
344403
@@ -366,7 +425,6 @@ double_ml_data_from_data_frame = function(df, x_cols = NULL, y_col = NULL,
366425# ' returned. Setting `data_class = "data.table"` returns an object of class
367426# ' `data.table`.
368427# '
369- # '
370428# ' @param use_other_treat_as_covariate (`logical(1)`) \cr
371429# ' Indicates whether in the multiple-treatment case the other treatment
372430# ' variables should be added as covariates. Default is `TRUE`.
0 commit comments