Skip to content

Commit 33087e1

Browse files
committed
fix problems after resolving merge conflicts
1 parent d81a2d4 commit 33087e1

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

R/double_ml_pliv.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,8 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
542542
fold_specific_params = private$fold_specific_params)
543543
}
544544

545-
res = private$score_elements(y, z, d, l_hat, m_hat, r_hat, g_hat, smpls)
545+
res = private$score_elements(y, z, d, l_hat$preds, m_hat$preds,
546+
r_hat$preds, g_hat$preds, smpls)
546547
res$preds = list(
547548
"ml_l" = l_hat$preds,
548549
"ml_m" = m_hat$preds,

R/double_ml_plr.R

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -409,8 +409,8 @@ DoubleMLPLR = R6Class("DoubleMLPLR",
409409
g_hat = list(preds = NULL, models = NULL)
410410
if (exists("ml_g", where = private$learner_)) {
411411
# get an initial estimate for theta using the partialling out score
412-
psi_a = -(d - m_hat) * (d - m_hat)
413-
psi_b = (d - m_hat) * (y - l_hat)
412+
psi_a = -(d - m_hat$preds) * (d - m_hat$preds)
413+
psi_b = (d - m_hat$preds) * (y - l_hat$preds)
414414
theta_initial = -mean(psi_b, na.rm = TRUE) / mean(psi_a, na.rm = TRUE)
415415

416416
data_aux = data.table(self$data$data_model,
@@ -428,7 +428,8 @@ DoubleMLPLR = R6Class("DoubleMLPLR",
428428
fold_specific_params = private$fold_specific_params)
429429
}
430430

431-
res = private$score_elements(y, d, l_hat, m_hat, g_hat, smpls)
431+
res = private$score_elements(y, d, l_hat$preds, m_hat$preds, g_hat$preds,
432+
smpls)
432433
res$preds = list(
433434
"ml_l" = l_hat$preds,
434435
"ml_m" = m_hat$preds,

0 commit comments

Comments
 (0)