
Commit f957191

Merge pull request #169 from DoubleML/store-models
Store estimated models for nuisance parameters
2 parents 38a5b02 + 9bb6ee5 commit f957191

File tree: 8 files changed (+169 −54 lines)


R/double_ml.R

Lines changed: 46 additions & 1 deletion
@@ -201,6 +201,17 @@ DoubleML = R6Class("DoubleML",
       }
     },

+    #' @field models (`array()`) \cr
+    #' The fitted nuisance models after calling
+    #' `fit(store_models=TRUE)`.
+    models = function(value) {
+      if (missing(value)) {
+        return(private$models_)
+      } else {
+        stop("can't set field models")
+      }
+    },
+
     #' @field pval (`numeric()`) \cr
     #' p-values for the causal parameter(s) after calling `fit()`.
     pval = function(value) {

@@ -359,12 +370,21 @@ DoubleML = R6Class("DoubleML",
     #' Indicates whether the predictions for the nuisance functions should be
     #' stored in field `predictions`. Default is `FALSE`.
     #'
+    #'
+    #' @param store_models (`logical(1)`) \cr
+    #' Indicates whether the fitted models for the nuisance functions should be
+    #' stored in field `models` if you want to analyze the models or extract
+    #' information like variable importance. Default is `FALSE`.
+    #'
     #' @return self
-    fit = function(store_predictions = FALSE) {
+    fit = function(store_predictions = FALSE, store_models = FALSE) {

       if (store_predictions) {
         private$initialize_predictions()
       }
+      if (store_models) {
+        private$initialize_models()
+      }

       # TODO: insert check for tuned params
       for (i_rep in 1:self$n_rep) {

@@ -384,6 +404,9 @@ DoubleML = R6Class("DoubleML",
         if (store_predictions) {
           private$store_predictions(res$preds)
         }
+        if (store_models) {
+          private$store_models(res$models)
+        }

         # estimate the causal parameter
         private$all_coef_[private$i_treat, private$i_rep] = private$est_causal_pars()

@@ -1139,6 +1162,7 @@ DoubleML = R6Class("DoubleML",
     psi_a_ = NULL,
     psi_b_ = NULL,
     predictions_ = NULL,
+    models_ = NULL,
     pval_ = NULL,
     score_ = NULL,
     se_ = NULL,

@@ -1415,6 +1439,19 @@ DoubleML = R6Class("DoubleML",
         },
         simplify = F)
     },
+    initialize_models = function() {
+      private$models_ = sapply(self$params_names(),
+        function(x) {
+          sapply(self$data$d_cols,
+            function(x) {
+              lapply(
+                seq(self$n_rep),
+                function(x) vector("list", length = self$n_folds))
+            },
+            simplify = F)
+        },
+        simplify = F)
+    },
     store_predictions = function(preds) {
       for (learner in self$params_names()) {
         if (!is.null(preds[[learner]])) {

@@ -1424,6 +1461,14 @@ DoubleML = R6Class("DoubleML",
         }
       }
     },
+    store_models = function(models) {
+      for (learner in self$params_names()) {
+        if (!is.null(models[[learner]])) {
+          private$models_[[learner]][[self$data$treat_col]][[
+            private$i_rep]] = models[[learner]]
+        }
+      }
+    },
     # Comment from python: The private properties with __ always deliver the
     # single treatment, single (cross-fitting) sample subselection
     # The slicing is based on the two properties self._i_treat,
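
Taken together, these changes add a store_models flag to fit() and a read-only models field that mirrors the layout built in initialize_models(): learner -> treatment column -> repetition -> fold. A minimal usage sketch follows; the data generator, learner choices, and constructor arguments are assumptions based on the DoubleML/mlr3 APIs rather than part of this commit, and older releases may name the PLR outcome learner ml_g instead of ml_l.

library(DoubleML)
library(mlr3)
library(mlr3learners)

# Sketch only: object construction is assumed, not taken from this diff.
set.seed(1234)
dml_data = make_plr_CCDDHNR2018(n_obs = 500)
ml_l = lrn("regr.ranger", importance = "impurity")
ml_m = lrn("regr.ranger", importance = "impurity")
dml_plr = DoubleMLPLR$new(dml_data, ml_l = ml_l, ml_m = ml_m)

# Keep the fitted nuisance learners alongside the predictions.
dml_plr$fit(store_models = TRUE)

# models is nested as learner -> treatment column -> repetition -> fold,
# matching initialize_models()/store_models() above.
fold1_model = dml_plr$models$ml_l[["d"]][[1]][[1]]

# If the stored per-fold object is a trained mlr3 ranger learner,
# variable importance could be read with fold1_model$importance().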

R/double_ml_iivm.R

Lines changed: 17 additions & 9 deletions
@@ -321,7 +321,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
         fold_specific_params = private$fold_specific_params)

       if (self$subgroups$always_takers == FALSE) {
-        r0_hat = rep(0, self$data$n_obs)
+        r0_hat = list(preds = rep(0, self$data$n_obs), models = NULL)
       } else {
         r0_hat = dml_cv_predict(self$learner$ml_r,
           c(self$data$x_cols, self$data$other_treat_cols),

@@ -336,7 +336,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
       }

       if (self$subgroups$never_takers == FALSE) {
-        r1_hat = rep(1, self$data$n_obs)
+        r1_hat = list(preds = rep(1, self$data$n_obs), models = NULL)
       } else {
         r1_hat = dml_cv_predict(self$learner$ml_r,
           c(self$data$x_cols, self$data$other_treat_cols),

@@ -356,14 +356,22 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
       y = self$data$data_model[[self$data$y_col]]

       res = private$score_elements(
-        y, z, d, g0_hat, g1_hat, m_hat, r0_hat,
-        r1_hat, smpls)
+        y, z, d,
+        g0_hat$preds, g1_hat$preds, m_hat$preds,
+        r0_hat$preds, r1_hat$preds,
+        smpls)
       res$preds = list(
-        "ml_g0" = g0_hat,
-        "ml_g1" = g1_hat,
-        "ml_m" = m_hat,
-        "ml_r0" = r0_hat,
-        "ml_r1" = r1_hat)
+        "ml_g0" = g0_hat$preds,
+        "ml_g1" = g1_hat$preds,
+        "ml_m" = m_hat$preds,
+        "ml_r0" = r0_hat$preds,
+        "ml_r1" = r1_hat$preds)
+      res$models = list(
+        "ml_g0" = g0_hat$models,
+        "ml_g1" = g1_hat$models,
+        "ml_m" = m_hat$models,
+        "ml_r0" = r0_hat$models,
+        "ml_r1" = r1_hat$models)
       return(res)
     },
     score_elements = function(y = y, z = z, d = d,
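
The same pattern recurs in every model class: each nuisance estimate is now a small list carrying a preds vector and a models slot, so score_elements() keeps working on plain prediction vectors while the fitted learners are passed through in res$models. A minimal, self-contained sketch of that convention (illustrative values only, not package code):

# Every nuisance object follows list(preds = ..., models = ...).
n_obs = 100
# No always-takers: predictions are hard-coded and no model is fitted.
r0_hat = list(preds = rep(0, n_obs), models = NULL)

# score_elements() consumes only the prediction vectors ...
stopifnot(is.numeric(r0_hat$preds), length(r0_hat$preds) == n_obs)

# ... while preds and models are collected per learner for later storage.
res = list(
  preds  = list("ml_r0" = r0_hat$preds),
  models = list("ml_r0" = r0_hat$models))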

R/double_ml_irm.R

Lines changed: 12 additions & 5 deletions
@@ -242,7 +242,7 @@ DoubleMLIRM = R6Class("DoubleMLIRM",
         task_type = private$task_type$ml_g,
         fold_specific_params = private$fold_specific_params)

-      g1_hat = NULL
+      g1_hat = list(preds = NULL, models = NULL)
       if ((is.character(self$score) && self$score == "ATE") || is.function(self$score)) {
         g1_hat = dml_cv_predict(self$learner$ml_g,
           c(self$data$x_cols, self$data$other_treat_cols),

@@ -259,11 +259,18 @@ DoubleMLIRM = R6Class("DoubleMLIRM",
       d = self$data$data_model[[self$data$treat_col]]
       y = self$data$data_model[[self$data$y_col]]

-      res = private$score_elements(y, d, g0_hat, g1_hat, m_hat, smpls)
+      res = private$score_elements(
+        y, d,
+        g0_hat$preds, g1_hat$preds, m_hat$preds,
+        smpls)
       res$preds = list(
-        "ml_g0" = g0_hat,
-        "ml_g1" = g1_hat,
-        "ml_m" = m_hat)
+        "ml_g0" = g0_hat$preds,
+        "ml_g1" = g1_hat$preds,
+        "ml_m" = m_hat$preds)
+      res$models = list(
+        "ml_g0" = g0_hat$models,
+        "ml_g1" = g1_hat$models,
+        "ml_m" = m_hat$models)
       return(res)
     },
     score_elements = function(y, d, g0_hat, g1_hat, m_hat, smpls) {

R/double_ml_pliv.R

Lines changed: 40 additions & 27 deletions
@@ -496,7 +496,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
           fold_specific_params = private$fold_specific_params)
         z = self$data$data_model[[self$data$z_cols]]
       } else {
-        m_hat = do.call(
+        xx = do.call(
           cbind,
           lapply(
             self$data$z_cols,

@@ -510,19 +510,21 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
               est_params = self$get_params(paste0("ml_m_", x)),
               return_train_preds = FALSE,
               task_type = private$task_type$ml_m,
-              fold_specific_params = private$fold_specific_params)
+              fold_specific_params = private$fold_specific_params)$preds
           }))
+        # TODO: Export of fitted models not implemented for this case
+        m_hat = list(preds = xx, models = NULL)
         z = self$data$data_model[, self$data$z_cols, with = FALSE]
       }

       d = self$data$data_model[[self$data$treat_col]]
       y = self$data$data_model[[self$data$y_col]]

-      g_hat = NULL
+      g_hat = list(preds = NULL, models = NULL)
       if (exists("ml_g", where = private$learner_)) {
         # get an initial estimate for theta using the partialling out score
-        psi_a = -(d - r_hat) * (z - m_hat)
-        psi_b = (z - m_hat) * (y - l_hat)
+        psi_a = -(d - r_hat$preds) * (z - m_hat$preds)
+        psi_b = (z - m_hat$preds) * (y - l_hat$preds)
         theta_initial = -mean(psi_b, na.rm = TRUE) / mean(psi_a, na.rm = TRUE)

         data_aux = data.table(self$data$data_model,

@@ -540,12 +542,19 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
           fold_specific_params = private$fold_specific_params)
       }

-      res = private$score_elements(y, z, d, l_hat, m_hat, r_hat, g_hat, smpls)
+      res = private$score_elements(
+        y, z, d, l_hat$preds, m_hat$preds,
+        r_hat$preds, g_hat$preds, smpls)
       res$preds = list(
-        "ml_l" = l_hat,
-        "ml_m" = m_hat,
-        "ml_r" = r_hat,
-        "ml_g" = g_hat)
+        "ml_l" = l_hat$preds,
+        "ml_m" = m_hat$preds,
+        "ml_r" = r_hat$preds,
+        "ml_g" = g_hat$preds)
+      res$models = list(
+        "ml_l" = l_hat$models,
+        "ml_m" = m_hat$models,
+        "ml_r" = r_hat$models,
+        "ml_g" = g_hat$models)
       return(res)
     },
     score_elements = function(y, z, d, l_hat, m_hat, r_hat, g_hat, smpls) {

@@ -615,7 +624,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
         task_type = private$task_type$ml_l,
         fold_specific_params = private$fold_specific_params)

-      m_hat_list = dml_cv_predict(self$learner$ml_m,
+      m_hat = dml_cv_predict(self$learner$ml_m,
         c(
           self$data$x_cols,
           self$data$other_treat_cols,

@@ -628,8 +637,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
         return_train_preds = TRUE,
         task_type = private$task_type$ml_m,
         fold_specific_params = private$fold_specific_params)
-      m_hat = m_hat_list$preds
-      data_aux_list = lapply(m_hat_list$train_preds, function(x) {
+      data_aux_list = lapply(m_hat$train_preds, function(x) {
        setnafill(data.table(self$data$data_model, "m_hat_on_train" = x),
          fill = -9999.99) # mlr3 does not allow NA's (values are not used)
      })

@@ -650,13 +658,13 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
       d = self$data$data_model[[self$data$treat_col]]
       y = self$data$data_model[[self$data$y_col]]

-      u_hat = y - l_hat
-      w_hat = d - m_hat_tilde
+      u_hat = y - l_hat$preds
+      w_hat = d - m_hat_tilde$preds

       if (is.character(self$score)) {
         if (self$score == "partialling out") {
-          psi_a = -w_hat * (m_hat - m_hat_tilde)
-          psi_b = (m_hat - m_hat_tilde) * u_hat
+          psi_a = -w_hat * (m_hat$preds - m_hat_tilde$preds)
+          psi_b = (m_hat$preds - m_hat_tilde$preds) * u_hat
         }
         res = list(
           psi_a = psi_a,

@@ -665,12 +673,16 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
         stop(paste(
           "Callable score not implemented for DoubleMLPLIV",
           "with partialX=TRUE and partialZ=TRUE."))
-        # res = self$score(y, d, g_hat, m_hat, m_hat_tilde)
+        # res = self$score(y, d, g_hat$preds, m_hat$preds, m_hat_tilde$preds)
       }
       res$preds = list(
-        "ml_l" = l_hat,
-        "ml_m" = m_hat,
-        "ml_r" = m_hat_tilde)
+        "ml_l" = l_hat$preds,
+        "ml_m" = m_hat$preds,
+        "ml_r" = m_hat_tilde$preds)
+      res$models = list(
+        "ml_l" = l_hat$models,
+        "ml_m" = m_hat$models,
+        "ml_r" = m_hat_tilde$models)
       return(res)
     },

@@ -697,17 +709,18 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",

       if (is.character(self$score)) {
         if (self$score == "partialling out") {
-          psi_a = -r_hat * d
-          psi_b = r_hat * y
+          psi_a = -r_hat$preds * d
+          psi_b = r_hat$preds * y
         }
         res = list(psi_a = psi_a, psi_b = psi_b)
       } else if (is.function(self$score)) {
         stop(paste(
           "Callable score not implemented for DoubleMLPLIV",
           "with partialX=FALSE and partialZ=TRUE."))
-        # res = self$score(y, z, d, r_hat)
+        # res = self$score(y, z, d, r_hat$preds)
       }
-      res$preds = list("ml_r" = r_hat)
+      res$preds = list("ml_r" = r_hat$preds)
+      res$models = list("ml_r" = r_hat$models)
       return(res)
     },

@@ -819,8 +832,8 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
       y = self$data$data_model[[self$data$y_col]]
       z = self$data$data_model[[self$data$z_cols]]

-      psi_a = -(d - r_hat) * (z - m_hat)
-      psi_b = (z - m_hat) * (y - l_hat)
+      psi_a = -(d - r_hat$preds) * (z - m_hat$preds)
+      psi_b = (z - m_hat$preds) * (y - l_hat$preds)
       theta_initial = -mean(psi_b, na.rm = TRUE) / mean(psi_a, na.rm = TRUE)

       data_aux = data.table(self$data$data_model,

R/double_ml_plr.R

Lines changed: 15 additions & 9 deletions
@@ -406,11 +406,11 @@ DoubleMLPLR = R6Class("DoubleMLPLR",
       d = self$data$data_model[[self$data$treat_col]]
       y = self$data$data_model[[self$data$y_col]]

-      g_hat = NULL
+      g_hat = list(preds = NULL, models = NULL)
       if (exists("ml_g", where = private$learner_)) {
         # get an initial estimate for theta using the partialling out score
-        psi_a = -(d - m_hat) * (d - m_hat)
-        psi_b = (d - m_hat) * (y - l_hat)
+        psi_a = -(d - m_hat$preds) * (d - m_hat$preds)
+        psi_b = (d - m_hat$preds) * (y - l_hat$preds)
         theta_initial = -mean(psi_b, na.rm = TRUE) / mean(psi_a, na.rm = TRUE)

         data_aux = data.table(self$data$data_model,

@@ -428,11 +428,17 @@ DoubleMLPLR = R6Class("DoubleMLPLR",
         fold_specific_params = private$fold_specific_params)
       }

-      res = private$score_elements(y, d, l_hat, m_hat, g_hat, smpls)
+      res = private$score_elements(
+        y, d, l_hat$preds, m_hat$preds, g_hat$preds,
+        smpls)
       res$preds = list(
-        "ml_l" = l_hat,
-        "ml_m" = m_hat,
-        "ml_g" = g_hat)
+        "ml_l" = l_hat$preds,
+        "ml_m" = m_hat$preds,
+        "ml_g" = g_hat$preds)
+      res$models = list(
+        "ml_l" = l_hat$models,
+        "ml_m" = m_hat$models,
+        "ml_g" = g_hat$models)
       return(res)
     },
     score_elements = function(y, d, l_hat, m_hat, g_hat, smpls) {

@@ -519,8 +525,8 @@ DoubleMLPLR = R6Class("DoubleMLPLR",
       d = self$data$data_model[[self$data$treat_col]]
       y = self$data$data_model[[self$data$y_col]]

-      psi_a = -(d - m_hat) * (d - m_hat)
-      psi_b = (d - m_hat) * (y - l_hat)
+      psi_a = -(d - m_hat$preds) * (d - m_hat$preds)
+      psi_b = (d - m_hat$preds) * (y - l_hat$preds)
       theta_initial = -mean(psi_b, na.rm = TRUE) / mean(psi_a, na.rm = TRUE)

       data_aux = data.table(self$data$data_model,
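
With one fitted learner stored per fold and repetition, the models can be post-processed after fit(). A hedged sketch, reusing the hypothetical dml_plr object from the first example and assuming the stored per-fold objects are trained mlr3 ranger learners with impurity importance enabled:

# Average the variable importance of ml_m across all repetitions and folds.
models_ml_m = dml_plr$models$ml_m[["d"]]        # list: repetitions -> folds
fold_models = unlist(models_ml_m, recursive = FALSE)
importances = lapply(fold_models, function(fitted) fitted$importance())
imp_avg = Reduce(`+`, importances) / length(importances)
sort(imp_avg, decreasing = TRUE)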
