Skip to content

Commit 31950f0

Browse files
committed
simplify if clauses in score_elements of pliv
1 parent 1fc52ce commit 31950f0

File tree

1 file changed

+16
-15
lines changed

1 file changed

+16
-15
lines changed

R/double_ml_pliv.R

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -332,9 +332,19 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
332332
w_hat = d - r_hat
333333
v_hat = z - m_hat
334334

335-
if (self$data$n_instr > 1) {
336-
v_hat = z - m_hat
337-
335+
if (self$data$n_instr == 1) {
336+
if (is.character(self$score)) {
337+
if (self$score == "partialling out") {
338+
psi_a = -w_hat * v_hat
339+
psi_b = v_hat * u_hat
340+
}
341+
psis = list(
342+
psi_a = psi_a,
343+
psi_b = psi_b)
344+
} else if (is.function(self$score)) {
345+
psis = self$score(y, z, d, g_hat, m_hat, r_hat, smpls)
346+
}
347+
} else {
338348
stopifnot(self$apply_cross_fitting)
339349

340350
# Projection: r_hat from projection on m_hat
@@ -348,29 +358,20 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
348358
r_r_tilde = resample(task_r_tilde, ml_r_tilde, resampling_r_tilde,
349359
store_models = TRUE)
350360
r_hat_tilde = as.data.table(r_r_tilde$prediction())$response
351-
}
361+
352362
if (is.character(self$score)) {
353-
if (self$data$n_instr == 1) {
354-
if (self$score == "partialling out") {
355-
psi_a = -w_hat * v_hat
356-
psi_b = v_hat * u_hat
357-
}
358-
} else {
359363
if (self$score == "partialling out") {
360364
psi_a = -w_hat * r_hat_tilde
361365
psi_b = r_hat_tilde * u_hat
362366
}
363-
}
364367
psis = list(
365368
psi_a = psi_a,
366369
psi_b = psi_b)
367370
} else if (is.function(self$score)) {
368-
if (self$data$n_instr > 1) {
369-
stop(paste(
371+
stop(paste(
370372
"Callable score not implemented for DoubleMLPLIV with",
371373
"partialX=TRUE and partialZ=FALSE with several instruments."))
372-
}
373-
psis = self$score(y, z, d, g_hat, m_hat, r_hat, smpls)
374+
}
374375
}
375376
return(psis)
376377
},

0 commit comments

Comments
 (0)