@@ -99,7 +99,7 @@ dml_irmiv = function(data, y, d, z, k = 2, smpls = NULL, mlmethod, params = list
9999 # # m_hat_list =lapply(m_hat_list$test, extract_test_pred)
100100 #
101101 # p_hat_list = lapply(r_p$data$prediction, function(x) x$test$prob[, "1"])
102- p_hat_list = lapply(r_p $ data $ predictions(), function (x ) x $ prob [, " 1" ])
102+ p_hat_list = lapply(r_p $ predictions(), function (x ) x $ prob [, " 1" ])
103103
104104 # nuisance mu0: E[Y|Z=0, X]
105105 mu_indx = names(data ) != d & names(data ) != z
@@ -125,7 +125,7 @@ dml_irmiv = function(data, y, d, z, k = 2, smpls = NULL, mlmethod, params = list
125125 # g0_hat_list = lapply(g0_hat_list, function(x) x$response)
126126 #
127127 # mu0_hat_list = lapply(r_mu0$data$prediction, function(x) x$test$response)
128- mu0_hat_list = lapply(r_mu0 $ data $ predictions(), function (x ) x $ response )
128+ mu0_hat_list = lapply(r_mu0 $ predictions(), function (x ) x $ response )
129129
130130 # nuisance g1: E[Y|Z=1, X]
131131 task_mu1 = mlr3 :: TaskRegr $ new(id = paste0(" nuis_mu1_" , z ), backend = data_mu , target = y )
@@ -139,7 +139,7 @@ dml_irmiv = function(data, y, d, z, k = 2, smpls = NULL, mlmethod, params = list
139139
140140 r_mu1 = mlr3 :: resample(task_mu1 , ml_mu1 , resampling_mu1 , store_models = TRUE )
141141 # mu1_hat_list = lapply(r_mu1$data$prediction, function(x) x$test$response)
142- mu1_hat_list = lapply(r_mu1 $ data $ predictions(), function (x ) x $ response )
142+ mu1_hat_list = lapply(r_mu1 $ predictions(), function (x ) x $ response )
143143
144144
145145 # nuisance m0: E[D|Z=0, X]
@@ -171,7 +171,7 @@ dml_irmiv = function(data, y, d, z, k = 2, smpls = NULL, mlmethod, params = list
171171 test_ids_m0 = lapply(1 : n_iters , function (x ) resampling_m0 $ test_set(x ))
172172 r_m0 = mlr3 :: resample(task_m0 , ml_m0 , resampling_m0 , store_models = TRUE )
173173 # m0_hat_list = lapply(r_m0$data$prediction, function(x) x$test$prob[, "1"])
174- m0_hat_list = lapply(r_m0 $ data $ predictions(), function (x ) x $ prob [, " 1" ])
174+ m0_hat_list = lapply(r_m0 $ predictions(), function (x ) x $ prob [, " 1" ])
175175 }
176176
177177 if (never_takers == FALSE ) {
@@ -194,7 +194,7 @@ dml_irmiv = function(data, y, d, z, k = 2, smpls = NULL, mlmethod, params = list
194194 test_ids_m1 = lapply(1 : n_iters , function (x ) resampling_m1 $ test_set(x ))
195195 r_m1 = mlr3 :: resample(task_m1 , ml_m1 , resampling_m1 , store_models = TRUE )
196196 # m1_hat_list = lapply(r_m1$data$prediction, function(x) x$test$prob[, "1"])
197- m1_hat_list = lapply(r_m1 $ data $ predictions(), function (x ) x $ prob [, " 1" ])
197+ m1_hat_list = lapply(r_m1 $ predictions(), function (x ) x $ prob [, " 1" ])
198198 }
199199
200200 if ((resampling_p $ iters != resampling_mu0 $ iters ) ||
0 commit comments