@@ -270,7 +270,7 @@ def lambda_fn(h, W1, W2):
270270
271271 f = function ([h0 , W1 , W2 ], o , mode = self .mode )
272272
273- scan_node = [ x for x in f .maker .fgraph .toposort () if isinstance (x .op , Scan )][ 0 ]
273+ scan_node = next ( x for x in f .maker .fgraph .toposort () if isinstance (x .op , Scan ))
274274 assert (
275275 len (
276276 [
@@ -444,9 +444,9 @@ def test_dot_not_output(self):
444444 # Ensure that the optimization was performed correctly in f_opt
445445 # The inner function of scan should have only one output and it should
446446 # not be the result of a Dot
447- scan_node = [
447+ scan_node = next (
448448 node for node in f_opt .maker .fgraph .toposort () if isinstance (node .op , Scan )
449- ][ 0 ]
449+ )
450450 assert len (scan_node .op .inner_outputs ) == 1
451451 assert not isinstance (scan_node .op .inner_outputs [0 ], Dot )
452452
@@ -488,9 +488,9 @@ def inner_fct(vect, mat):
488488 # Ensure that the optimization was performed correctly in f_opt
489489 # The inner function of scan should have only one output and it should
490490 # not be the result of a Dot
491- scan_node = [
491+ scan_node = next (
492492 node for node in f_opt .maker .fgraph .toposort () if isinstance (node .op , Scan )
493- ][ 0 ]
493+ )
494494 # NOTE: WHEN INFER_SHAPE IS RE-ENABLED, BELOW THE SCAN MUST
495495 # HAVE ONLY 1 OUTPUT.
496496 assert len (scan_node .op .inner_outputs ) == 2
@@ -536,9 +536,9 @@ def inner_fct(seq1, previous_output1, nonseq1):
536536 # Ensure that the optimization was performed correctly in f_opt
537537 # The inner function of scan should have only one output and it should
538538 # not be the result of a Dot
539- scan_node = [
539+ scan_node = next (
540540 node for node in f_opt .maker .fgraph .toposort () if isinstance (node .op , Scan )
541- ][ 0 ]
541+ )
542542 assert len (scan_node .op .inner_outputs ) == 2
543543 assert not isinstance (scan_node .op .inner_outputs [0 ], Dot )
544544
@@ -1639,7 +1639,7 @@ def lambda_fn(h, W1, W2):
16391639 )
16401640
16411641 f = function ([h0 , W1 , W2 ], o , mode = get_default_mode ().including ("scan" ))
1642- scan_node = [ x for x in f .maker .fgraph .toposort () if isinstance (x .op , Scan )][ 0 ]
1642+ scan_node = next ( x for x in f .maker .fgraph .toposort () if isinstance (x .op , Scan ))
16431643 assert (
16441644 len (
16451645 [
@@ -1673,7 +1673,7 @@ def lambda_fn(W1, h, W2):
16731673 )
16741674
16751675 f = function ([h0 , W1 , W2 ], o , mode = get_default_mode ().including ("scan" ))
1676- scan_node = [ x for x in f .maker .fgraph .toposort () if isinstance (x .op , Scan )][ 0 ]
1676+ scan_node = next ( x for x in f .maker .fgraph .toposort () if isinstance (x .op , Scan ))
16771677
16781678 assert (
16791679 len (
@@ -1709,7 +1709,7 @@ def lambda_fn(W1, h, W2):
17091709
17101710 # TODO FIXME: This result depends on unrelated rewrites in the "fast" mode.
17111711 f = function ([_h0 , _W1 , _W2 ], o , mode = "FAST_RUN" )
1712- scan_node = [ x for x in f .maker .fgraph .toposort () if isinstance (x .op , Scan )][ 0 ]
1712+ scan_node = next ( x for x in f .maker .fgraph .toposort () if isinstance (x .op , Scan ))
17131713
17141714 assert len (scan_node .op .inner_inputs ) == 1
17151715
0 commit comments