@@ -496,7 +496,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
496496 fold_specific_params = private $ fold_specific_params )
497497 z = self $ data $ data_model [[self $ data $ z_cols ]]
498498 } else {
499- m_hat = do.call(
499+ xx = do.call(
500500 cbind ,
501501 lapply(
502502 self $ data $ z_cols ,
@@ -510,19 +510,21 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
510510 est_params = self $ get_params(paste0(" ml_m_" , x )),
511511 return_train_preds = FALSE ,
512512 task_type = private $ task_type $ ml_m ,
513- fold_specific_params = private $ fold_specific_params )
513+ fold_specific_params = private $ fold_specific_params )$ preds
514514 }))
515+ # TODO: Export of fitted models not implemented for this case
516+ m_hat = list (preds = xx , models = NULL )
515517 z = self $ data $ data_model [, self $ data $ z_cols , with = FALSE ]
516518 }
517519
518520 d = self $ data $ data_model [[self $ data $ treat_col ]]
519521 y = self $ data $ data_model [[self $ data $ y_col ]]
520522
521- g_hat = NULL
523+ g_hat = list ( preds = NULL , models = NULL )
522524 if (exists(" ml_g" , where = private $ learner_ )) {
523525 # get an initial estimate for theta using the partialling out score
524- psi_a = - (d - r_hat ) * (z - m_hat )
525- psi_b = (z - m_hat ) * (y - l_hat )
526+ psi_a = - (d - r_hat $ preds ) * (z - m_hat $ preds )
527+ psi_b = (z - m_hat $ preds ) * (y - l_hat $ preds )
526528 theta_initial = - mean(psi_b , na.rm = TRUE ) / mean(psi_a , na.rm = TRUE )
527529
528530 data_aux = data.table(self $ data $ data_model ,
@@ -540,12 +542,19 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
540542 fold_specific_params = private $ fold_specific_params )
541543 }
542544
543- res = private $ score_elements(y , z , d , l_hat , m_hat , r_hat , g_hat , smpls )
545+ res = private $ score_elements(
546+ y , z , d , l_hat $ preds , m_hat $ preds ,
547+ r_hat $ preds , g_hat $ preds , smpls )
544548 res $ preds = list (
545- " ml_l" = l_hat ,
546- " ml_m" = m_hat ,
547- " ml_r" = r_hat ,
548- " ml_g" = g_hat )
549+ " ml_l" = l_hat $ preds ,
550+ " ml_m" = m_hat $ preds ,
551+ " ml_r" = r_hat $ preds ,
552+ " ml_g" = g_hat $ preds )
553+ res $ models = list (
554+ " ml_l" = l_hat $ models ,
555+ " ml_m" = m_hat $ models ,
556+ " ml_r" = r_hat $ models ,
557+ " ml_g" = g_hat $ models )
549558 return (res )
550559 },
551560 score_elements = function (y , z , d , l_hat , m_hat , r_hat , g_hat , smpls ) {
@@ -615,7 +624,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
615624 task_type = private $ task_type $ ml_l ,
616625 fold_specific_params = private $ fold_specific_params )
617626
618- m_hat_list = dml_cv_predict(self $ learner $ ml_m ,
627+ m_hat = dml_cv_predict(self $ learner $ ml_m ,
619628 c(
620629 self $ data $ x_cols ,
621630 self $ data $ other_treat_cols ,
@@ -628,8 +637,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
628637 return_train_preds = TRUE ,
629638 task_type = private $ task_type $ ml_m ,
630639 fold_specific_params = private $ fold_specific_params )
631- m_hat = m_hat_list $ preds
632- data_aux_list = lapply(m_hat_list $ train_preds , function (x ) {
640+ data_aux_list = lapply(m_hat $ train_preds , function (x ) {
633641 setnafill(data.table(self $ data $ data_model , " m_hat_on_train" = x ),
634642 fill = - 9999.99 ) # mlr3 does not allow NA's (values are not used)
635643 })
@@ -650,13 +658,13 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
650658 d = self $ data $ data_model [[self $ data $ treat_col ]]
651659 y = self $ data $ data_model [[self $ data $ y_col ]]
652660
653- u_hat = y - l_hat
654- w_hat = d - m_hat_tilde
661+ u_hat = y - l_hat $ preds
662+ w_hat = d - m_hat_tilde $ preds
655663
656664 if (is.character(self $ score )) {
657665 if (self $ score == " partialling out" ) {
658- psi_a = - w_hat * (m_hat - m_hat_tilde )
659- psi_b = (m_hat - m_hat_tilde ) * u_hat
666+ psi_a = - w_hat * (m_hat $ preds - m_hat_tilde $ preds )
667+ psi_b = (m_hat $ preds - m_hat_tilde $ preds ) * u_hat
660668 }
661669 res = list (
662670 psi_a = psi_a ,
@@ -665,12 +673,16 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
665673 stop(paste(
666674 " Callable score not implemented for DoubleMLPLIV" ,
667675 " with partialX=TRUE and partialZ=TRUE." ))
668- # res = self$score(y, d, g_hat, m_hat, m_hat_tilde)
676+ # res = self$score(y, d, g_hat$preds , m_hat$preds , m_hat_tilde$preds )
669677 }
670678 res $ preds = list (
671- " ml_l" = l_hat ,
672- " ml_m" = m_hat ,
673- " ml_r" = m_hat_tilde )
679+ " ml_l" = l_hat $ preds ,
680+ " ml_m" = m_hat $ preds ,
681+ " ml_r" = m_hat_tilde $ preds )
682+ res $ models = list (
683+ " ml_l" = l_hat $ models ,
684+ " ml_m" = m_hat $ models ,
685+ " ml_r" = m_hat_tilde $ models )
674686 return (res )
675687 },
676688
@@ -697,17 +709,18 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
697709
698710 if (is.character(self $ score )) {
699711 if (self $ score == " partialling out" ) {
700- psi_a = - r_hat * d
701- psi_b = r_hat * y
712+ psi_a = - r_hat $ preds * d
713+ psi_b = r_hat $ preds * y
702714 }
703715 res = list (psi_a = psi_a , psi_b = psi_b )
704716 } else if (is.function(self $ score )) {
705717 stop(paste(
706718 " Callable score not implemented for DoubleMLPLIV" ,
707719 " with partialX=FALSE and partialZ=TRUE." ))
708- # res = self$score(y, z, d, r_hat)
720+ # res = self$score(y, z, d, r_hat$preds )
709721 }
710- res $ preds = list (" ml_r" = r_hat )
722+ res $ preds = list (" ml_r" = r_hat $ preds )
723+ res $ models = list (" ml_r" = r_hat $ models )
711724 return (res )
712725 },
713726
@@ -819,8 +832,8 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
819832 y = self $ data $ data_model [[self $ data $ y_col ]]
820833 z = self $ data $ data_model [[self $ data $ z_cols ]]
821834
822- psi_a = - (d - r_hat ) * (z - m_hat )
823- psi_b = (z - m_hat ) * (y - l_hat )
835+ psi_a = - (d - r_hat $ preds ) * (z - m_hat $ preds )
836+ psi_b = (z - m_hat $ preds ) * (y - l_hat $ preds )
824837 theta_initial = - mean(psi_b , na.rm = TRUE ) / mean(psi_a , na.rm = TRUE )
825838
826839 data_aux = data.table(self $ data $ data_model ,
0 commit comments