Skip to content

Commit 1c2e5db

Browse files
authored
Merge pull request #99 from DoubleML/m-boot-algo
adaption of the bootstrap algorithm
2 parents 547f074 + 9eec119 commit 1c2e5db

File tree

2 files changed

+19
-68
lines changed

2 files changed

+19
-68
lines changed

R/double_ml.R

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,33 +1177,10 @@ DoubleML = R6Class("DoubleML",
11771177
}
11781178

11791179
if (self$apply_cross_fitting) {
1180-
if (dml_procedure == "dml1") {
1181-
boot_coefs = boot_t_stat = matrix(NA,
1182-
nrow = n_rep_boot,
1183-
ncol = self$n_folds)
1184-
ii = 0
1185-
for (i_fold in 1:self$n_folds) {
1186-
test_index = test_ids[[i_fold]]
1187-
n_obs_in_fold = length(test_index)
1188-
1189-
J = mean(private$get__psi_a()[test_index])
1190-
boot_coefs[, i_fold] = weights[, (ii + 1):(ii + n_obs_in_fold)] %*%
1191-
private$get__psi()[test_index] / (n_obs_in_fold * J)
1192-
boot_t_stat[, i_fold] = weights[, (ii + 1):(ii + n_obs_in_fold)] %*%
1193-
private$get__psi()[test_index] /
1194-
(n_obs_in_fold * private$get__all_se() * J)
1195-
ii = ii + n_obs_in_fold
1196-
}
1197-
boot_coef = rowMeans(boot_coefs)
1198-
boot_t_stat = rowMeans(boot_t_stat)
1199-
}
1200-
else if (dml_procedure == "dml2") {
12011180
J = mean(private$get__psi_a())
12021181
boot_coef = weights %*% private$get__psi() / (n_obs * J)
12031182
boot_t_stat = weights %*% private$get__psi() /
12041183
(n_obs * private$get__all_se() * J)
1205-
}
1206-
12071184
} else {
12081185
J = mean(private$get__psi_a()[test_index])
12091186
boot_coef = weights %*% private$get__psi()[test_index] /

tests/testthat/helper-08-dml_plr.R

Lines changed: 19 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -222,52 +222,26 @@ dml_plr_boot = function(data, y, d, theta, se, all_preds, dml_procedure = "dml2"
222222
D = data[, d]
223223
Y = data[, y]
224224

225-
# DML 1
226-
if (dml_procedure == "dml1") {
227-
v_hat = u_hat = v_hatd = d_k = matrix(NA, nrow = max(n_k), ncol = n_iters)
228-
229-
for (i in 1:n_iters) {
230-
test_index = test_ids[[i]]
231-
232-
m_hat = m_hat_list[[i]]
233-
g_hat = g_hat_list[[i]]
234-
235-
d_k[, i] = D[test_index]
236-
v_hat[, i] = D[test_index] - m_hat
237-
u_hat[, i] = Y[test_index] - g_hat
238-
v_hatd[, i] = v_hat[, i] * D[test_index]
239-
}
240-
241-
boot = bootstrap_plr(
242-
theta = theta, d = d_k, u_hat = u_hat, v_hat = v_hat,
243-
v_hatd = v_hatd, score = score, se = se,
244-
weights = weights, nRep = nRep)
245-
boot_theta = boot$boot_theta
225+
v_hat = u_hat = v_hatd = matrix(NA, nrow = n, ncol = 1)
226+
227+
for (i in 1:n_iters) {
228+
test_index = test_ids[[i]]
229+
230+
m_hat = m_hat_list[[i]]
231+
g_hat = g_hat_list[[i]]
232+
233+
v_hat[test_index, 1] = D[test_index] - m_hat
234+
u_hat[test_index, 1] = Y[test_index] - g_hat
235+
v_hatd[test_index, 1] = v_hat[test_index] * D[test_index]
236+
246237
}
247-
248-
if (dml_procedure == "dml2") {
249-
250-
v_hat = u_hat = v_hatd = matrix(NA, nrow = n, ncol = 1)
251-
252-
for (i in 1:n_iters) {
253-
test_index = test_ids[[i]]
254-
255-
m_hat = m_hat_list[[i]]
256-
g_hat = g_hat_list[[i]]
257-
258-
v_hat[test_index, 1] = D[test_index] - m_hat
259-
u_hat[test_index, 1] = Y[test_index] - g_hat
260-
v_hatd[test_index, 1] = v_hat[test_index] * D[test_index]
261-
262-
}
263-
264-
boot = bootstrap_plr(
265-
theta = theta, d = D, u_hat = u_hat, v_hat = v_hat,
266-
v_hatd = v_hatd, score = score, se = se,
267-
weights = weights, nRep = nRep)
268-
boot_theta = boot$boot_theta
269-
}
270-
238+
239+
boot = bootstrap_plr(
240+
theta = theta, d = D, u_hat = u_hat, v_hat = v_hat,
241+
v_hatd = v_hatd, score = score, se = se,
242+
weights = weights, nRep = nRep)
243+
boot_theta = boot$boot_theta
244+
271245
return(boot_theta)
272246
}
273247

0 commit comments

Comments
 (0)