Skip to content

Commit d6980b7

Browse files
Merge pull request #2589 from devitocodes/nvc-red
compiler: Avoid nvc++ array reductions and add MPI fallback
2 parents 6d494db + 248f95f commit d6980b7

File tree

2 files changed

+8 -4 lines changed

2 files changed

+8 -4 lines changed

devito/passes/iet/languages/openmp.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from sympy import And, Ne, Not
66

77
from devito.arch import AMDGPUX, NVIDIAX, INTELGPUX, PVC
8-
from devito.arch.compiler import GNUCompiler
8+
from devito.arch.compiler import GNUCompiler, NvidiaCompiler
99
from devito.ir import (Call, Conditional, DeviceCall, List, Pragma, Prodder,
1010
ParallelBlock, PointerCast, While, FindSymbols)
1111
from devito.passes.iet.definitions import DataManager, DeviceAwareDataManager
@@ -224,7 +224,11 @@ def _support_array_reduction(cls, compiler):
224224
if isinstance(compiler, GNUCompiler) and \
225225
compiler.version < Version("6.0"):
226226
return False
227-
return True
227+
elif isinstance(compiler, NvidiaCompiler):
228+
# NVC++ does not support array reduction and leads to segfault
229+
return False
230+
else:
231+
return True
228232

229233

230234
class DeviceOmpizer(PragmaDeviceAwareTransformer):

devito/passes/iet/mpi.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,14 +264,14 @@ def _mark_overlappable(iet):
264264

265265

266266
@iet_pass
267-
def make_halo_exchanges(iet, mpimode=None, **kwargs):
267+
def make_halo_exchanges(iet, mpimode=None, fallback='basic', **kwargs):
268268
"""
269269
Lower HaloSpots into halo exchanges for distributed-memory parallelism.
270270
"""
271271
# To produce unique object names
272272
generators = {'msg': generator(), 'comm': generator(), 'comp': generator()}
273273

274-
sync_heb = HaloExchangeBuilder('basic', generators, **kwargs)
274+
sync_heb = HaloExchangeBuilder(fallback, generators, **kwargs)
275275
user_heb = HaloExchangeBuilder(mpimode, generators, **kwargs)
276276
mapper = {}
277277
for hs in FindNodes(HaloSpot).visit(iet):

0 commit comments

Comments
 (0)