@@ -399,6 +399,83 @@ class HaloComms(Queue):
399399 def process (self , clusters ):
400400 return self ._process_fatd (clusters , 1 , seen = set ())
401401
402+ def _derive_halo_schemes (self , c ):
403+ hs = HaloScheme (c .exprs , c .ispace )
404+
405+ # 95% of the times we will just return `hs` as is as there are no guards
406+ if not c .guards :
407+ yield hs , c
408+ return
409+
410+ # This is a more contrived situation in which we might need halo exchanges
411+ # from multiple so called loc-indices -- let's check this out
412+ candidates = []
413+ for f , hse in hs .fmapper .items ():
414+ reads = c .scope .reads [f ]
415+
416+ for d in hse .loc_indices :
417+ if not d ._defines & set (c .guards ):
418+ continue
419+
420+ candidates .append (as_mapper (reads , key = lambda i : i [d ]).values ())
421+
422+ # 4% of the times we will just return `hs` as is
423+ # E.g., we end up here when taking space derivatives of one or more saved
424+ # TimeFunctions in equations evaluating gradients that are controlled by
425+ # a ConditionalDimension (otherwise we would have exited earlier)
426+ if any (len (g ) <= 1 for g in candidates ):
427+ yield hs , c
428+ return
429+
430+ # 1% of the times, finally, we end up here...
431+ # At this point we have to create a mock Cluster for each loc-index,
432+ # containing all and only the accesses to `f` at a given loc-index
433+ # E.g., a mock Cluster at `loc_index=t0` containing the accesses
434+ # `[u(t0, x + 8, ...), u(t0, x + 7, ...)], another mock Cluster at
435+ # `loc_index=t1` containing the accesses `[u(t1, x + 5, ...),
436+ # u(t1, x + 6, ...)]`, and so on
437+ for unordered_groups in candidates :
438+ # Sort for deterministic code generation
439+ groups = sorted (unordered_groups , key = str )
440+ for group in groups :
441+ pointset = sympy .Function ('pointset' )
442+ v = pointset (* [i .access for i in group ])
443+ exprs = [e .func (rhs = v ) for e in c .exprs ]
444+
445+ c1 = c .rebuild (exprs = exprs )
446+
447+ hs = HaloScheme (c1 .exprs , c .ispace )
448+
449+ yield hs , c1
450+
451+ def _make_halo_touch (self , hs , c , prefix ):
452+ points = set ()
453+ for f in hs .fmapper :
454+ for a in c .scope .getreads (f ):
455+ points .add (a .access )
456+
457+ # We also add all written symbols to ultimately create mock WARs
458+ # with `c`, which will prevent the newly created HaloTouch from
459+ # ever being rescheduled
460+ points .update (a .access for a in c .scope .accesses if a .is_write )
461+
462+ # Sort for determinism
463+ # NOTE: not sorting might impact code generation. The order of
464+ # the args is important because that's what search functions honor!
465+ points = sorted (points , key = str )
466+
467+ # Construct the HaloTouch Cluster
468+ expr = Eq (self .B , HaloTouch (* points , halo_scheme = hs ))
469+
470+ key = lambda i : i in prefix [:- 1 ] or i in hs .loc_indices
471+ ispace = c .ispace .project (key )
472+ # HaloTouches are not parallel
473+ properties = c .properties .sequentialize ()
474+
475+ halo_touch = c .rebuild (exprs = expr , ispace = ispace , properties = properties )
476+
477+ return halo_touch
478+
402479 def callback (self , clusters , prefix , seen = None ):
403480 if not prefix :
404481 return clusters
@@ -412,38 +489,18 @@ def callback(self, clusters, prefix, seen=None):
412489 c in seen :
413490 continue
414491
415- hs = HaloScheme (c .exprs , c .ispace )
416- if hs .is_void or \
417- not d ._defines & hs .distributed_aindices :
418- continue
419-
420- points = set ()
421- for f in hs .fmapper :
422- for a in c .scope .getreads (f ):
423- points .add (a .access )
424-
425- # We also add all written symbols to ultimately create mock WARs
426- # with `c`, which will prevent the newly created HaloTouch to ever
427- # be rescheduled after `c` upon topological sorting
428- points .update (a .access for a in c .scope .accesses if a .is_write )
492+ seen .add (c )
429493
430- # Sort for determinism
431- # NOTE: not sorting might impact code generation. The order of
432- # the args is important because that's what search functions honor!
433- points = sorted (points , key = str )
434-
435- # Construct the HaloTouch Cluster
436- expr = Eq (self .B , HaloTouch (* points , halo_scheme = hs ))
494+ for hs , c1 in self ._derive_halo_schemes (c ):
495+ if hs .is_void or \
496+ not d ._defines & hs .distributed_aindices :
497+ continue
437498
438- key = lambda i : i in prefix [:- 1 ] or i in hs .loc_indices
439- ispace = c .ispace .project (key )
440- # HaloTouches are not parallel
441- properties = c .properties .sequentialize ()
499+ halo_touch = self ._make_halo_touch (hs , c1 , prefix )
442500
443- halo_touch = c . rebuild ( exprs = expr , ispace = ispace , properties = properties )
501+ processed . append ( halo_touch )
444502
445- processed .append (halo_touch )
446- seen .update ({halo_touch , c })
503+ seen .add (halo_touch )
447504
448505 processed .extend (clusters )
449506
0 commit comments