@@ -227,7 +227,7 @@ def make_node(self, condition: "TensorLike", *true_false_branches: Any):
227227
228228 return Apply (
229229 self ,
230- [condition ] + new_inputs_true_branch + new_inputs_false_branch ,
230+ [condition , * new_inputs_true_branch , * new_inputs_false_branch ] ,
231231 output_vars ,
232232 )
233233
@@ -275,11 +275,11 @@ def grad(self, ins, grads):
275275 # condition + epsilon always triggers the same branch as condition
276276 condition_grad = condition .zeros_like ().astype (config .floatX )
277277
278- return (
279- [ condition_grad ]
280- + if_true_op (* inputs_true_grad , return_list = True )
281- + if_false_op (* inputs_false_grad , return_list = True )
282- )
278+ return [
279+ condition_grad ,
280+ * if_true_op (* inputs_true_grad , return_list = True ),
281+ * if_false_op (* inputs_false_grad , return_list = True ),
282+ ]
283283
284284 def make_thunk (self , node , storage_map , compute_map , no_recycling , impl = None ):
285285 cond = node .inputs [0 ]
@@ -397,7 +397,7 @@ def ifelse(
397397
398398 new_ifelse = IfElse (n_outs = len (then_branch ), as_view = False , name = name )
399399
400- ins = [condition ] + list (then_branch ) + list (else_branch )
400+ ins = [condition , * list (then_branch ), * list (else_branch )]
401401 rval = new_ifelse (* ins , return_list = True )
402402
403403 if rval_type is None :
@@ -611,7 +611,7 @@ def apply(self, fgraph):
611611 mn_fs = merging_node .inputs [1 :][merging_node .op .n_outs :]
612612 pl_ts = proposal .inputs [1 :][: proposal .op .n_outs ]
613613 pl_fs = proposal .inputs [1 :][proposal .op .n_outs :]
614- new_ins = [merging_node .inputs [0 ]] + mn_ts + pl_ts + mn_fs + pl_fs
614+ new_ins = [merging_node .inputs [0 ], * mn_ts , * pl_ts , * mn_fs , * pl_fs ]
615615 mn_name = "?"
616616 if merging_node .op .name :
617617 mn_name = merging_node .op .name
@@ -673,7 +673,7 @@ def cond_remove_identical(fgraph, node):
673673
674674 new_ifelse = IfElse (n_outs = len (nw_ts ), as_view = op .as_view , name = op .name )
675675
676- new_ins = [node .inputs [0 ]] + nw_ts + nw_fs
676+ new_ins = [node .inputs [0 ], * nw_ts , * nw_fs ]
677677 new_outs = new_ifelse (* new_ins , return_list = True )
678678
679679 rval = []
@@ -711,7 +711,7 @@ def cond_merge_random_op(fgraph, main_node):
711711 mn_fs = merging_node .inputs [1 :][merging_node .op .n_outs :]
712712 pl_ts = proposal .inputs [1 :][: proposal .op .n_outs ]
713713 pl_fs = proposal .inputs [1 :][proposal .op .n_outs :]
714- new_ins = [merging_node .inputs [0 ]] + mn_ts + pl_ts + mn_fs + pl_fs
714+ new_ins = [merging_node .inputs [0 ], * mn_ts , * pl_ts , * mn_fs , * pl_fs ]
715715 mn_name = "?"
716716 if merging_node .op .name :
717717 mn_name = merging_node .op .name
0 commit comments