@@ -5,7 +5,9 @@ dml_cv_predict = function(learner, X_cols, y_col,
55 return_train_preds = FALSE , task_type = NULL ,
66 fold_specific_params = FALSE ) {
77
8- # TODO: Asserts
8+ valid_task_type = c(" regr" , " classif" )
9+ assertChoice(task_type , valid_task_type )
10+ # TODO: extend asserts
911
1012 if (fold_specific_params ) {
1113 stopifnot(length(smpls $ train_ids ) == length(smpls $ test_ids ))
@@ -122,6 +124,7 @@ dml_cv_predict = function(learner, X_cols, y_col,
122124
123125dml_tune = function (learner , X_cols , y_col , data_tune_list ,
124126 nuisance_id , param_set , tune_settings , measure , task_type ) {
127+
125128 task_tune = lapply(data_tune_list , function (x ) {
126129 initiate_task(
127130 id = nuisance_id ,
@@ -130,6 +133,9 @@ dml_tune = function(learner, X_cols, y_col, data_tune_list,
130133 select_cols = X_cols ,
131134 task_type = task_type )
132135 })
136+ valid_task_type = c(" regr" , " classif" )
137+ assertChoice(task_type , valid_task_type )
138+
133139 ml_learner = initiate_learner(learner , task_type , params = learner $ param_set $ values )
134140 tuning_instance = lapply(task_tune , function (x ) {
135141 TuningInstanceSingleCrit $ new(
@@ -154,6 +160,10 @@ dml_tune = function(learner, X_cols, y_col, data_tune_list,
154160
155161extract_prediction = function (obj_resampling , task_type , n_obs ,
156162 return_train_preds = FALSE ) {
163+
164+ valid_task_type = c(" regr" , " classif" )
165+ assertChoice(task_type , valid_task_type )
166+
157167 if (compareVersion(as.character(packageVersion(" mlr3" )), " 0.11.0" ) < 0 ) {
158168 ind_name = " row_id"
159169 } else {
@@ -204,6 +214,10 @@ extract_prediction = function(obj_resampling, task_type, n_obs,
204214}
205215
206216initiate_learner = function (learner , task_type , params , return_train_preds = FALSE ) {
217+
218+ valid_task_type = c(" regr" , " classif" )
219+ assertChoice(task_type , valid_task_type )
220+
207221 ml_learner = learner $ clone()
208222
209223 if (! is.null(params )) {
@@ -225,6 +239,9 @@ initiate_learner = function(learner, task_type, params, return_train_preds = FAL
225239
226240# Function to initialize task (regression or classification)
227241initiate_task = function (id , data , target , select_cols , task_type ) {
242+ valid_task_type = c(" regr" , " classif" )
243+ assertChoice(task_type , valid_task_type )
244+
228245 if (! is.null(select_cols )) {
229246 indx = (names(data ) %in% c(select_cols , target ))
230247 data = data [, indx , with = FALSE ]
@@ -277,7 +294,10 @@ get_cond_samples = function(smpls, D) {
277294}
278295
279296set_default_measure = function (measure_in = NA , task_type ) {
280- if (is.na(measure_in )) {
297+ valid_task_type = c(" regr" , " classif" )
298+ assertChoice(task_type , valid_task_type )
299+
300+ if (is.null(measure_in )) {
281301 if (task_type == " regr" ) {
282302 measure = msr(" regr.mse" )
283303 } else if (task_type == " classif" ) {
0 commit comments