Skip to content

Commit bedab84

Browse files
committed
pass through the trimming threshold to the functional implementation
1 parent 3d7a514 commit bedab84

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

tests/testthat/test-double_ml_iivm.R

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@ if (on_cran) {
1010
learner = "rpart",
1111
dml_procedure = "dml2",
1212
score = "LATE",
13-
trimming_threshold = c(0.0001),
13+
trimming_threshold = c(1e-5),
1414
stringsAsFactors = FALSE)
1515
} else {
1616
test_cases = expand.grid(
1717
learner = "cv_glmnet",
1818
dml_procedure = c("dml1", "dml2"),
1919
score = "LATE",
20-
trimming_threshold = c(0.0001),
20+
trimming_threshold = c(1e-5),
2121
stringsAsFactors = FALSE)
2222
}
2323

@@ -35,7 +35,8 @@ patrick::with_parameters_test_that("Unit tests for IIVM:",
3535
ml_g = learner_pars$ml_g$clone(),
3636
ml_m = learner_pars$ml_m$clone(),
3737
ml_r = learner_pars$ml_r$clone(),
38-
dml_procedure = dml_procedure, score = score)
38+
dml_procedure = dml_procedure, score = score,
39+
trimming_threshold = trimming_threshold)
3940
theta = iivm_hat$coef
4041
se = iivm_hat$se
4142

@@ -45,7 +46,8 @@ patrick::with_parameters_test_that("Unit tests for IIVM:",
4546
n_folds = 5, smpls = iivm_hat$smpls,
4647
all_preds = iivm_hat$all_preds,
4748
score = score,
48-
bootstrap = "normal", n_rep_boot = n_rep_boot)$boot_coef
49+
bootstrap = "normal", n_rep_boot = n_rep_boot,
50+
trimming_threshold = trimming_threshold)$boot_coef
4951

5052
set.seed(3141)
5153
double_mliivm_obj = DoubleMLIIVM$new(

tests/testthat/test-double_ml_irm_user_score.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,15 @@ if (on_cran) {
2525
learner = "regr.rpart",
2626
learner_m = "classif.rpart",
2727
dml_procedure = "dml2",
28-
trimming_threshold = 0.0001,
28+
trimming_threshold = 1e-5,
2929
stringsAsFactors = FALSE)
3030
test_cases[".test_name"] = apply(test_cases, 1, paste, collapse = "_")
3131
} else {
3232
test_cases = expand.grid(
3333
learner = "regr.glmnet",
3434
learner_m = "classif.glmnet",
3535
dml_procedure = c("dml1", "dml2"),
36-
trimming_threshold = c(0.0001, 0.01),
36+
trimming_threshold = c(1e-5, 0.01),
3737
stringsAsFactors = FALSE)
3838
test_cases[".test_name"] = apply(test_cases, 1, paste, collapse = "_")
3939
}

0 commit comments

Comments
 (0)