Skip to content

Commit 547f074

Browse files
authored
Merge pull request #105 from DoubleML/m-fix96
initialize predictions with NA
2 parents 1a1904b + 81f047c commit 547f074

File tree

4 files changed

+6
-4
lines changed

4 files changed

+6
-4
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ importFrom(R6,R6Class)
1919
importFrom(clusterGeneration,genPositiveDefMat)
2020
importFrom(data.table,as.data.table)
2121
importFrom(data.table,data.table)
22+
importFrom(data.table,setnafill)
2223
importFrom(mlr3,Task)
2324
importFrom(mlr3,TaskClassif)
2425
importFrom(mlr3,TaskRegr)

R/double_ml_pliv.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,8 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
380380
fold_specific_params = private$fold_specific_params)
381381
m_hat = m_hat_list$preds
382382
data_aux_list = lapply(m_hat_list$train_preds, function(x) {
383-
data.table(self$data$data_model, "m_hat_on_train" = x)
383+
setnafill(data.table(self$data$data_model, "m_hat_on_train" = x),
384+
fill = -9999.99) # mlr3 does not allow NA's (values are not used)
384385
})
385386

386387
m_hat_tilde = dml_cv_predict(self$learner$ml_r,

R/helper.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ extract_prediction = function(obj_resampling, learner_class, n_obs,
173173
1:n_iters,
174174
function(x) as.data.table(obj_resampling$predictions("train")[[x]]))
175175
for (i_iter in 1:n_iters) {
176-
preds_vec = vector("numeric", length = n_obs)
176+
preds_vec = as.numeric(rep(NA, n_obs))
177177
f_hat = f_hat_list[[i_iter]]
178178
preds_vec[f_hat[[ind_name]]] = f_hat[[resp_name]]
179179
preds[[i_iter]] = preds_vec
@@ -189,7 +189,7 @@ extract_prediction = function(obj_resampling, learner_class, n_obs,
189189
}
190190
}
191191
} else {
192-
preds = vector("numeric", length = n_obs)
192+
preds = as.numeric(rep(NA, n_obs))
193193
if (testR6(obj_resampling, classes = "ResampleResult")) obj_resampling = list(obj_resampling)
194194
n_obj_rsmp = length(obj_resampling)
195195
for (i_obj_rsmp in 1:n_obj_rsmp) {

R/zzz.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#' @importFrom mlr3tuning TuningInstanceSingleCrit tnr trm
66
#' @importFrom mlr3learners LearnerRegrLM
77
#' @importFrom mlr3misc insert_named
8-
#' @importFrom data.table data.table as.data.table
8+
#' @importFrom data.table data.table as.data.table setnafill
99
#' @importFrom readstata13 read.dta13
1010
#' @importFrom stats formula model.matrix rnorm runif rexp toeplitz pnorm qnorm
1111
#' printCoefmat quantile p.adjust.methods p.adjust median

0 commit comments

Comments
 (0)