Skip to content

Commit 9db48dd

Browse files
committed
assert choice regr or classif for task_type; see #157
1 parent 92279ef commit 9db48dd

File tree

1 file changed

+21
-1
lines changed

1 file changed

+21
-1
lines changed

R/helper.R

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ dml_cv_predict = function(learner, X_cols, y_col,
55
return_train_preds = FALSE, task_type = NULL,
66
fold_specific_params = FALSE) {
77

8-
# TODO: Asserts
8+
valid_task_type = c("regr", "classif")
9+
assertChoice(task_type, valid_task_type)
10+
# TODO: extend asserts
911

1012
if (fold_specific_params) {
1113
stopifnot(length(smpls$train_ids) == length(smpls$test_ids))
@@ -122,6 +124,7 @@ dml_cv_predict = function(learner, X_cols, y_col,
122124

123125
dml_tune = function(learner, X_cols, y_col, data_tune_list,
124126
nuisance_id, param_set, tune_settings, measure, task_type) {
127+
125128
task_tune = lapply(data_tune_list, function(x) {
126129
initiate_task(
127130
id = nuisance_id,
@@ -130,6 +133,9 @@ dml_tune = function(learner, X_cols, y_col, data_tune_list,
130133
select_cols = X_cols,
131134
task_type = task_type)
132135
})
136+
valid_task_type = c("regr", "classif")
137+
assertChoice(task_type, valid_task_type)
138+
133139
ml_learner = initiate_learner(learner, task_type, params = learner$param_set$values)
134140
tuning_instance = lapply(task_tune, function(x) {
135141
TuningInstanceSingleCrit$new(
@@ -154,6 +160,10 @@ dml_tune = function(learner, X_cols, y_col, data_tune_list,
154160

155161
extract_prediction = function(obj_resampling, task_type, n_obs,
156162
return_train_preds = FALSE) {
163+
164+
valid_task_type = c("regr", "classif")
165+
assertChoice(task_type, valid_task_type)
166+
157167
if (compareVersion(as.character(packageVersion("mlr3")), "0.11.0") < 0) {
158168
ind_name = "row_id"
159169
} else {
@@ -204,6 +214,10 @@ extract_prediction = function(obj_resampling, task_type, n_obs,
204214
}
205215

206216
initiate_learner = function(learner, task_type, params, return_train_preds = FALSE) {
217+
218+
valid_task_type = c("regr", "classif")
219+
assertChoice(task_type, valid_task_type)
220+
207221
ml_learner = learner$clone()
208222

209223
if (!is.null(params)) {
@@ -225,6 +239,9 @@ initiate_learner = function(learner, task_type, params, return_train_preds = FAL
225239

226240
# Function to initialize task (regression or classification)
227241
initiate_task = function(id, data, target, select_cols, task_type) {
242+
valid_task_type = c("regr", "classif")
243+
assertChoice(task_type, valid_task_type)
244+
228245
if (!is.null(select_cols)) {
229246
indx = (names(data) %in% c(select_cols, target))
230247
data = data[, indx, with = FALSE]
@@ -277,6 +294,9 @@ get_cond_samples = function(smpls, D) {
277294
}
278295

279296
set_default_measure = function(measure_in = NA, task_type) {
297+
valid_task_type = c("regr", "classif")
298+
assertChoice(task_type, valid_task_type)
299+
280300
if (is.na(measure_in)) {
281301
if (task_type == "regr") {
282302
measure = msr("regr.mse")

0 commit comments

Comments
 (0)