@@ -332,9 +332,19 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
332332 w_hat = d - r_hat
333333 v_hat = z - m_hat
334334
335- if (self $ data $ n_instr > 1 ) {
336- v_hat = z - m_hat
337-
335+ if (self $ data $ n_instr == 1 ) {
336+ if (is.character(self $ score )) {
337+ if (self $ score == " partialling out" ) {
338+ psi_a = - w_hat * v_hat
339+ psi_b = v_hat * u_hat
340+ }
341+ psis = list (
342+ psi_a = psi_a ,
343+ psi_b = psi_b )
344+ } else if (is.function(self $ score )) {
345+ psis = self $ score(y , z , d , g_hat , m_hat , r_hat , smpls )
346+ }
347+ } else {
338348 stopifnot(self $ apply_cross_fitting )
339349
340350 # Projection: r_hat from projection on m_hat
@@ -348,29 +358,20 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
348358 r_r_tilde = resample(task_r_tilde , ml_r_tilde , resampling_r_tilde ,
349359 store_models = TRUE )
350360 r_hat_tilde = as.data.table(r_r_tilde $ prediction())$ response
351- }
361+
352362 if (is.character(self $ score )) {
353- if (self $ data $ n_instr == 1 ) {
354- if (self $ score == " partialling out" ) {
355- psi_a = - w_hat * v_hat
356- psi_b = v_hat * u_hat
357- }
358- } else {
359363 if (self $ score == " partialling out" ) {
360364 psi_a = - w_hat * r_hat_tilde
361365 psi_b = r_hat_tilde * u_hat
362366 }
363- }
364367 psis = list (
365368 psi_a = psi_a ,
366369 psi_b = psi_b )
367370 } else if (is.function(self $ score )) {
368- if (self $ data $ n_instr > 1 ) {
369- stop(paste(
371+ stop(paste(
370372 " Callable score not implemented for DoubleMLPLIV with" ,
371373 " partialX=TRUE and partialZ=FALSE with several instruments." ))
372- }
373- psis = self $ score(y , z , d , g_hat , m_hat , r_hat , smpls )
374+ }
374375 }
375376 return (psis )
376377 },
0 commit comments