@@ -2068,6 +2068,39 @@ def test_multi_allreduce_time(self, mode):
20682068 assert np .isclose (np .max (g .data ), 4356.0 )
20692069 assert np .isclose (np .max (h .data ), 4356.0 )
20702070
2071+ @pytest .mark .parallel (mode = 1 )
2072+ def test_multi_allreduce_time_cond (self , mode ):
2073+ space_order = 8
2074+ nx , ny = 11 , 11
2075+
2076+ grid = Grid (shape = (nx , ny ))
2077+ tt = grid .time_dim
2078+ nt = 20
2079+ ct = ConditionalDimension (name = "ct" , parent = tt , factor = 2 )
2080+
2081+ ux = TimeFunction (name = "ux" , grid = grid , time_order = 1 , space_order = space_order )
2082+ g = TimeFunction (name = "g" , grid = grid , dimensions = (ct , ), shape = (int (nt / 2 ),),
2083+ time_dim = ct )
2084+ h = TimeFunction (name = "h" , grid = grid , dimensions = (ct , ), shape = (int (nt / 2 ),),
2085+ time_dim = ct )
2086+
2087+ op = Operator ([Eq (g , 0 ), Eq (ux .forward , tt ), Inc (g , ux ), Inc (h , ux )], name = "Op" )
2088+ assert_structure (op , ['t' , 't,x,y' , 't,x,y' ], 'txyxy' )
2089+
2090+ # Make sure the two allreduce calls are in the time the loop
2091+ iters = FindNodes (Iteration ).visit (op )
2092+ for i in iters :
2093+ if i .dim .is_Time :
2094+ assert len (FindNodes (Call ).visit (i )) == 2 # Two allreduce
2095+ else :
2096+ assert len (FindNodes (Call ).visit (i )) == 0
2097+
2098+ op .apply (time_m = 0 , time_M = nt - 1 )
2099+
2100+ expected = [nx * ny * max (t - 1 , 0 ) for t in range (0 , nt , 2 )]
2101+ assert np .allclose (g .data , expected )
2102+ assert np .allclose (h .data , expected )
2103+
20712104
20722105class TestOperatorAdvanced :
20732106
0 commit comments