Skip to content

Commit aed4fa6

Browse files
authored
Merge pull request #2671 from devitocodes/muli-allred-v3
compiler: fix conditional reductions
2 parents 7783a27 + cb1a5c6 commit aed4fa6

File tree

2 files changed

+40
-6
lines changed

2 files changed

+40
-6
lines changed

devito/ir/clusters/algorithms.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -457,8 +457,8 @@ def reduction_comms(clusters):
457457
# if `c`'s IterationSpace is such that the reduction can be carried out
458458
found, fifo = split(fifo, lambda dr: dr.ispace.is_subset(c.ispace))
459459
for ispace, reds in groupby(found, key=lambda r: r.ispace):
460-
exprs = [Eq(dr.var, dr) for dr in reds]
461-
processed.append(Cluster(exprs=exprs, ispace=ispace))
460+
exprs = flatten([dr.exprs for dr in reds])
461+
processed.append(c.rebuild(exprs=exprs, ispace=ispace))
462462

463463
# Detect the global distributed reductions in `c`
464464
for e in c.exprs:
@@ -487,15 +487,16 @@ def reduction_comms(clusters):
487487
# The IterationSpace within which the global distributed reduction
488488
# must be carried out
489489
ispace = c.ispace.prefix(lambda d: d in var.free_symbols)
490-
491-
fifo.append(DistReduce(var, op=op, grid=grid, ispace=ispace))
490+
expr = [Eq(var, DistReduce(var, op=op, grid=grid, ispace=ispace))]
491+
fifo.append(c.rebuild(exprs=expr, ispace=ispace))
492492

493493
processed.append(c)
494494

495495
# Leftover reductions are placed at the very end
496496
for ispace, reds in groupby(fifo, key=lambda r: r.ispace):
497-
exprs = [Eq(dr.var, dr) for dr in reds]
498-
processed.append(Cluster(exprs=exprs, ispace=ispace))
497+
reds = list(reds)
498+
exprs = flatten([dr.exprs for dr in reds])
499+
processed.append(reds[0].rebuild(exprs=exprs, ispace=ispace))
499500

500501
return processed
501502

tests/test_mpi.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2068,6 +2068,39 @@ def test_multi_allreduce_time(self, mode):
20682068
assert np.isclose(np.max(g.data), 4356.0)
20692069
assert np.isclose(np.max(h.data), 4356.0)
20702070

2071+
@pytest.mark.parallel(mode=1)
2072+
def test_multi_allreduce_time_cond(self, mode):
2073+
space_order = 8
2074+
nx, ny = 11, 11
2075+
2076+
grid = Grid(shape=(nx, ny))
2077+
tt = grid.time_dim
2078+
nt = 20
2079+
ct = ConditionalDimension(name="ct", parent=tt, factor=2)
2080+
2081+
ux = TimeFunction(name="ux", grid=grid, time_order=1, space_order=space_order)
2082+
g = TimeFunction(name="g", grid=grid, dimensions=(ct, ), shape=(int(nt/2),),
2083+
time_dim=ct)
2084+
h = TimeFunction(name="h", grid=grid, dimensions=(ct, ), shape=(int(nt/2),),
2085+
time_dim=ct)
2086+
2087+
op = Operator([Eq(g, 0), Eq(ux.forward, tt), Inc(g, ux), Inc(h, ux)], name="Op")
2088+
assert_structure(op, ['t', 't,x,y', 't,x,y'], 'txyxy')
2089+
2090+
# Make sure the two allreduce calls are in the time the loop
2091+
iters = FindNodes(Iteration).visit(op)
2092+
for i in iters:
2093+
if i.dim.is_Time:
2094+
assert len(FindNodes(Call).visit(i)) == 2 # Two allreduce
2095+
else:
2096+
assert len(FindNodes(Call).visit(i)) == 0
2097+
2098+
op.apply(time_m=0, time_M=nt-1)
2099+
2100+
expected = [nx * ny * max(t-1, 0) for t in range(0, nt, 2)]
2101+
assert np.allclose(g.data, expected)
2102+
assert np.allclose(h.data, expected)
2103+
20712104

20722105
class TestOperatorAdvanced:
20732106

0 commit comments

Comments
 (0)