@@ -52,9 +52,7 @@ get_default_mlmethod_plr = function(learner, default = FALSE) {
5252 params_g = list (
5353 s = " lambda.min" ,
5454 family = " gaussian" ))
55-
5655 }
57-
5856 }
5957
6058 else if (default == TRUE ) {
@@ -64,12 +62,28 @@ get_default_mlmethod_plr = function(learner, default = FALSE) {
6462 params = list (
6563 params_g = list (),
6664 params_m = list ())
65+ }
6766
67+ if (learner == " graph_learner" ) {
68+ # pipeline learner
69+ pipe_learner = mlr3pipelines :: po(" learner" ,
70+ lrn(" regr.glmnet" ),
71+ lambda = 0.01 ,
72+ family = " gaussian" )
73+ mlmethod = list (
74+ mlmethod_m = " graph_learner" ,
75+ mlmethod_g = " graph_learner" )
76+ params = list (
77+ params_g = list (),
78+ params_m = list ())
79+ ml_g = mlr3 :: as_learner(pipe_learner )
80+ ml_m = mlr3 :: as_learner(pipe_learner )
81+ } else {
82+ ml_g = mlr3 :: lrn(mlmethod $ mlmethod_g )
83+ ml_g $ param_set $ values = params $ params_g
84+ ml_m = mlr3 :: lrn(mlmethod $ mlmethod_m )
85+ ml_m $ param_set $ values = params $ params_m
6886 }
69- ml_g = mlr3 :: lrn(mlmethod $ mlmethod_g )
70- ml_g $ param_set $ values = params $ params_g
71- ml_m = mlr3 :: lrn(mlmethod $ mlmethod_m )
72- ml_m $ param_set $ values = params $ params_m
7387
7488 return (list (
7589 mlmethod = mlmethod , params = params ,
@@ -147,12 +161,31 @@ get_default_mlmethod_pliv = function(learner) {
147161 family = " gaussian" ))
148162
149163 }
150- ml_g = mlr3 :: lrn(mlmethod $ mlmethod_g )
151- ml_g $ param_set $ values = params $ params_g
152- ml_m = mlr3 :: lrn(mlmethod $ mlmethod_m )
153- ml_m $ param_set $ values = params $ params_m
154- ml_r = mlr3 :: lrn(mlmethod $ mlmethod_r )
155- ml_r $ param_set $ values = params $ params_r
164+
165+ if (learner == " graph_learner" ) {
166+ # pipeline learner
167+ pipe_learner = mlr3pipelines :: po(" learner" ,
168+ lrn(" regr.glmnet" ),
169+ lambda = 0.01 ,
170+ family = " gaussian" )
171+ mlmethod = list (
172+ mlmethod_m = " graph_learner" ,
173+ mlmethod_g = " graph_learner" ,
174+ mlmethod_r = " graph_learner" )
175+ params = list (
176+ params_g = list (),
177+ params_m = list ())
178+ ml_g = mlr3 :: as_learner(pipe_learner )
179+ ml_m = mlr3 :: as_learner(pipe_learner )
180+ ml_r = mlr3 :: as_learner(pipe_learner )
181+ } else {
182+ ml_g = mlr3 :: lrn(mlmethod $ mlmethod_g )
183+ ml_g $ param_set $ values = params $ params_g
184+ ml_m = mlr3 :: lrn(mlmethod $ mlmethod_m )
185+ ml_m $ param_set $ values = params $ params_m
186+ ml_r = mlr3 :: lrn(mlmethod $ mlmethod_r )
187+ ml_r $ param_set $ values = params $ params_r
188+ }
156189
157190 return (list (
158191 mlmethod = mlmethod , params = params ,
@@ -182,11 +215,30 @@ get_default_mlmethod_irm = function(learner) {
182215 params_m = list (cp = 0.01 , minsplit = 20 ))
183216
184217 }
185- ml_g = mlr3 :: lrn(mlmethod $ mlmethod_g )
186- ml_g $ param_set $ values = params $ params_g
187- ml_m = mlr3 :: lrn(mlmethod $ mlmethod_m , predict_type = " prob" )
188- ml_m $ param_set $ values = params $ params_m
189218
219+ if (learner == " graph_learner" ) {
220+ # pipeline learner
221+ pipe_learner = mlr3pipelines :: po(" learner" ,
222+ lrn(" regr.rpart" ),
223+ cp = 0.01 , minsplit = 20 )
224+ pipe_learner_classif = mlr3pipelines :: po(" learner" ,
225+ lrn(" classif.rpart" ,
226+ predict_type = " prob" ),
227+ cp = 0.01 , minsplit = 20 )
228+ mlmethod = list (
229+ mlmethod_m = " graph_learner" ,
230+ mlmethod_g = " graph_learner" )
231+ params = list (
232+ params_g = list (),
233+ params_m = list ())
234+ ml_g = mlr3 :: as_learner(pipe_learner )
235+ ml_m = mlr3 :: as_learner(pipe_learner_classif )
236+ } else {
237+ ml_g = mlr3 :: lrn(mlmethod $ mlmethod_g )
238+ ml_g $ param_set $ values = params $ params_g
239+ ml_m = mlr3 :: lrn(mlmethod $ mlmethod_m , predict_type = " prob" )
240+ ml_m $ param_set $ values = params $ params_m
241+ }
190242 return (list (
191243 mlmethod = mlmethod , params = params ,
192244 ml_g = ml_g , ml_m = ml_m ))
@@ -219,12 +271,35 @@ get_default_mlmethod_iivm = function(learner) {
219271 params_r = list (cp = 0.01 , minsplit = 20 ))
220272
221273 }
222- ml_g = mlr3 :: lrn(mlmethod $ mlmethod_g )
223- ml_g $ param_set $ values = params $ params_g
224- ml_m = mlr3 :: lrn(mlmethod $ mlmethod_m , predict_type = " prob" )
225- ml_m $ param_set $ values = params $ params_m
226- ml_r = mlr3 :: lrn(mlmethod $ mlmethod_r , predict_type = " prob" )
227- ml_r $ param_set $ values = params $ params_r
274+
275+ if (learner == " graph_learner" ) {
276+ # pipeline learner
277+ pipe_learner = mlr3pipelines :: po(" learner" ,
278+ lrn(" regr.rpart" ),
279+ cp = 0.01 , minsplit = 20 )
280+ pipe_learner_classif = mlr3pipelines :: po(" learner" ,
281+ lrn(" classif.rpart" ,
282+ predict_type = " prob" ),
283+ cp = 0.01 , minsplit = 20 )
284+ mlmethod = list (
285+ mlmethod_m = " graph_learner" ,
286+ mlmethod_g = " graph_learner" ,
287+ mlmethod_r = " graph_learner" )
288+ params = list (
289+ params_g = list (),
290+ params_m = list (),
291+ params_r = list ())
292+ ml_g = mlr3 :: as_learner(pipe_learner )
293+ ml_m = mlr3 :: as_learner(pipe_learner_classif )
294+ ml_r = mlr3 :: as_learner(pipe_learner_classif )
295+ } else {
296+ ml_g = mlr3 :: lrn(mlmethod $ mlmethod_g )
297+ ml_g $ param_set $ values = params $ params_g
298+ ml_m = mlr3 :: lrn(mlmethod $ mlmethod_m , predict_type = " prob" )
299+ ml_m $ param_set $ values = params $ params_m
300+ ml_r = mlr3 :: lrn(mlmethod $ mlmethod_r , predict_type = " prob" )
301+ ml_r $ param_set $ values = params $ params_r
302+ }
228303
229304 return (list (
230305 mlmethod = mlmethod , params = params ,
0 commit comments