Skip to content

Commit 1e67dba

Browse files
authored
Merge pull request #108 from DoubleML/m-check-01-binary-for-classif
Extended exception handling for classification
2 parents 15fbc48 + 2b7128d commit 1e67dba

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

R/helper.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,8 @@ initiate_task = function(id, data, target, select_cols, learner_class) {
229229
task = TaskRegr$new(id = id, backend = data, target = target)
230230
} else if (learner_class == "LearnerClassif") {
231231
data[[target]] = factor(data[[target]])
232+
assert_set_equal(levels(data[[target]]),
233+
c("0", "1"))
232234
task = TaskClassif$new(
233235
id = id, backend = data, target = target,
234236
positive = "1")

tests/testthat/test-double_ml_plr_classifier.R

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,33 @@ patrick::with_parameters_test_that("Unit tests for PLR with classifier for ml_m:
8989
}
9090
}
9191
)
92+
93+
test_that("Unit tests for exception handling of PLR with classifier for ml_m:", {
94+
# Only binary outcome with values 0 and 1 is allowed when ml_m is a classifier
95+
96+
# Test with 0 and 2
97+
df = data_irm$df
98+
df['d'] = df['d']*2
99+
dml_data = double_ml_data_from_data_frame(df, y_col = 'y', d_cols = 'd')
100+
double_mlplr_obj = DoubleMLPLR$new(data = dml_data,
101+
ml_g = mlr3::lrn('regr.rpart'),
102+
ml_m = mlr3::lrn('classif.rpart'))
103+
msg = paste("Assertion on 'levels\\(data\\[\\[target\\]\\])' failed:",
104+
"Must be equal to set \\{'0','1'\\}, but is \\{'0','2'\\}.")
105+
expect_error(double_mlplr_obj$fit(),
106+
regexp = msg)
107+
108+
# Test with 0.5 and 1
109+
df = data_irm$df
110+
df['d'] = (df['d']+2)/2
111+
dml_data = double_ml_data_from_data_frame(df, y_col = 'y', d_cols = 'd')
112+
double_mlplr_obj = DoubleMLPLR$new(data = dml_data,
113+
ml_g = mlr3::lrn('regr.rpart'),
114+
ml_m = mlr3::lrn('classif.rpart'))
115+
msg = paste("Assertion on 'levels\\(data\\[\\[target\\]\\])' failed:",
116+
"Must be equal to set \\{'0','1'\\}, but is \\{'1','1.5'\\}.")
117+
expect_error(double_mlplr_obj$fit(),
118+
regexp = msg)
119+
120+
}
121+
)

0 commit comments

Comments
 (0)