Skip to content

Commit a611699

Browse files
committed
Merge branch 'master' of github.com:DoubleML/doubleml-for-r into m-mlr3-dev-tests
2 parents aec69d8 + c383fcd commit a611699

File tree

4 files changed

+13
-13
lines changed

4 files changed

+13
-13
lines changed

tests/testthat/helper-08-dml_plr.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ score = "IV-type", se_type = "ls", ...) {
6565
# # g_hat_list = mlr::getRRPredictionList(r_g)
6666
# #g_hat_list = lapply(g_hat_list$test, extract_test_pred)
6767
# g_hat_list = lapply(g_hat_list, function(x) x$response)
68-
g_hat_list = lapply(r_g$data$predictions(), function(x) x$response)
68+
g_hat_list = lapply(r_g$predictions(), function(x) x$response)
6969
# nuisance m
7070
m_indx = names(data) != y
7171
data_m = data[, m_indx, drop = FALSE]
@@ -88,7 +88,7 @@ score = "IV-type", se_type = "ls", ...) {
8888
# # m_hat_list = mlr::getRRPredictionList(r_m)
8989
# m_hat_list = lapply(m_hat_list, function(x) x$response)
9090
# # m_hat_list =lapply(m_hat_list$test, extract_test_pred)
91-
m_hat_list = lapply(r_m$data$predictions(), function(x) x$response)
91+
m_hat_list = lapply(r_m$predictions(), function(x) x$response)
9292

9393

9494
# if ((rin$desc$iters != r_g$pred$instance$desc$iters) ||

tests/testthat/helper-09-dml_plriv.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ dml_plriv = function(data, y, d, z, k = 2, smpls = NULL, mlmethod,
7575
# #g_hat_list = lapply(g_hat_list$test, extract_test_pred)
7676
# g_hat_list = lapply(g_hat_list, function(x) x$response)
7777
# g_hat_list = lapply(r_g$data$prediction, function(x) x$test$response)
78-
g_hat_list = lapply(r_g$data$predictions(), function(x) x$response)
78+
g_hat_list = lapply(r_g$predictions(), function(x) x$response)
7979

8080
# nuisance m: E[Z|X]
8181
m_indx = names(data) != y & names(data) != d
@@ -99,7 +99,7 @@ dml_plriv = function(data, y, d, z, k = 2, smpls = NULL, mlmethod,
9999
# m_hat_list = lapply(m_hat_list, function(x) x$response)
100100
# # m_hat_list =lapply(m_hat_list$test, extract_test_pred)
101101
# m_hat_list = lapply(r_m$data$prediction, function(x) x$test$response)
102-
m_hat_list = lapply(r_m$data$predictions(), function(x) x$response)
102+
m_hat_list = lapply(r_m$predictions(), function(x) x$response)
103103

104104

105105
# nuisance r: E[D|X]
@@ -124,7 +124,7 @@ dml_plriv = function(data, y, d, z, k = 2, smpls = NULL, mlmethod,
124124
# r_hat_list = lapply(r_hat_list, function(x) x$response)
125125
# # m_hat_list =lapply(m_hat_list$test, extract_test_pred)
126126
# r_hat_list = lapply(r_r$data$prediction, function(x) x$test$response)
127-
r_hat_list = lapply(r_r$data$predictions(), function(x) x$response)
127+
r_hat_list = lapply(r_r$predictions(), function(x) x$response)
128128

129129

130130

tests/testthat/helper-10-dml_irm.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ bootstrap = "normal", nRep = 500, ...) {
7777
# m_hat_list = lapply(m_hat_list, function(x) x$response)
7878
# # m_hat_list =lapply(m_hat_list$test, extract_test_pred)
7979
# m_hat_list = lapply(r_m$data$prediction, function(x) x$test$prob[, "1"])
80-
m_hat_list = lapply(r_m$data$predictions(), function(x) x$prob[, "1"])
80+
m_hat_list = lapply(r_m$predictions(), function(x) x$prob[, "1"])
8181

8282
# nuisance g0: E[Y|D=0, X]
8383
g_indx = names(data) != d
@@ -102,7 +102,7 @@ bootstrap = "normal", nRep = 500, ...) {
102102
# #g_hat_list = lapply(g_hat_list$test, extract_test_pred)
103103
# g0_hat_list = lapply(g0_hat_list, function(x) x$response)
104104
# g0_hat_list = lapply(r_g0$data$prediction, function(x) x$test$response)
105-
g0_hat_list = lapply(r_g0$data$predictions(), function(x) x$response)
105+
g0_hat_list = lapply(r_g0$predictions(), function(x) x$response)
106106

107107
# nuisance g1: E[Y|D=1, X]
108108
task_g1 = mlr3::TaskRegr$new(id = paste0("nuis_g1_", d), backend = data_g, target = y)
@@ -124,7 +124,7 @@ bootstrap = "normal", nRep = 500, ...) {
124124
# g1_hat_list = lapply(g1_hat_list, function(x) x$response)
125125
# # }
126126
# g1_hat_list = lapply(r_g1$data$prediction, function(x) x$test$response)
127-
g1_hat_list = lapply(r_g1$data$predictions(), function(x) x$response)
127+
g1_hat_list = lapply(r_g1$predictions(), function(x) x$response)
128128

129129

130130
if ((resampling_m$iters != resampling_g0$iters) ||

tests/testthat/helper-11-dml_irmiv.R

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ dml_irmiv = function(data, y, d, z, k = 2, smpls = NULL, mlmethod, params = list
9999
# # m_hat_list =lapply(m_hat_list$test, extract_test_pred)
100100
#
101101
# p_hat_list = lapply(r_p$data$prediction, function(x) x$test$prob[, "1"])
102-
p_hat_list = lapply(r_p$data$predictions(), function(x) x$prob[, "1"])
102+
p_hat_list = lapply(r_p$predictions(), function(x) x$prob[, "1"])
103103

104104
# nuisance mu0: E[Y|Z=0, X]
105105
mu_indx = names(data) != d & names(data) != z
@@ -125,7 +125,7 @@ dml_irmiv = function(data, y, d, z, k = 2, smpls = NULL, mlmethod, params = list
125125
# g0_hat_list = lapply(g0_hat_list, function(x) x$response)
126126
#
127127
# mu0_hat_list = lapply(r_mu0$data$prediction, function(x) x$test$response)
128-
mu0_hat_list = lapply(r_mu0$data$predictions(), function(x) x$response)
128+
mu0_hat_list = lapply(r_mu0$predictions(), function(x) x$response)
129129

130130
# nuisance g1: E[Y|Z=1, X]
131131
task_mu1 = mlr3::TaskRegr$new(id = paste0("nuis_mu1_", z), backend = data_mu, target = y)
@@ -139,7 +139,7 @@ dml_irmiv = function(data, y, d, z, k = 2, smpls = NULL, mlmethod, params = list
139139

140140
r_mu1 = mlr3::resample(task_mu1, ml_mu1, resampling_mu1, store_models = TRUE)
141141
# mu1_hat_list = lapply(r_mu1$data$prediction, function(x) x$test$response)
142-
mu1_hat_list = lapply(r_mu1$data$predictions(), function(x) x$response)
142+
mu1_hat_list = lapply(r_mu1$predictions(), function(x) x$response)
143143

144144

145145
# nuisance m0: E[D|Z=0, X]
@@ -171,7 +171,7 @@ dml_irmiv = function(data, y, d, z, k = 2, smpls = NULL, mlmethod, params = list
171171
test_ids_m0 = lapply(1:n_iters, function(x) resampling_m0$test_set(x))
172172
r_m0 = mlr3::resample(task_m0, ml_m0, resampling_m0, store_models = TRUE)
173173
# m0_hat_list = lapply(r_m0$data$prediction, function(x) x$test$prob[, "1"])
174-
m0_hat_list = lapply(r_m0$data$predictions(), function(x) x$prob[, "1"])
174+
m0_hat_list = lapply(r_m0$predictions(), function(x) x$prob[, "1"])
175175
}
176176

177177
if (never_takers == FALSE) {
@@ -194,7 +194,7 @@ dml_irmiv = function(data, y, d, z, k = 2, smpls = NULL, mlmethod, params = list
194194
test_ids_m1 = lapply(1:n_iters, function(x) resampling_m1$test_set(x))
195195
r_m1 = mlr3::resample(task_m1, ml_m1, resampling_m1, store_models = TRUE)
196196
# m1_hat_list = lapply(r_m1$data$prediction, function(x) x$test$prob[, "1"])
197-
m1_hat_list = lapply(r_m1$data$predictions(), function(x) x$prob[, "1"])
197+
m1_hat_list = lapply(r_m1$predictions(), function(x) x$prob[, "1"])
198198
}
199199

200200
if ((resampling_p$iters != resampling_mu0$iters) ||

0 commit comments

Comments
 (0)