@@ -276,17 +276,17 @@ def get_domains(v, iorder, with_sync):
276276 domains = [f"{{ [{ x0 } _outer]: 0<={ x0 } _outer<={ order // c } }}" ]
277277 if with_sync :
278278 expr = f"{ c // sync_split } *{ x0 } _sync_outer + { c } *{ x0 } _outer"
279- domains += [f"{{ [{ x0 } _sync_outer]: 0<={ expr } <=order "
279+ domains += [f"{{ [{ x0 } _sync_outer]: 0<={ expr } <={ order } "
280280 f"and 0<={ x0 } _sync_outer<{ c // sync_split } }}" ]
281281 expr += f" + { v [0 ]} _inner"
282- domains += [f"{{ [{ v [0 ]} _inner]: 0<={ expr } <=order "
282+ domains += [f"{{ [{ v [0 ]} _inner]: 0<={ expr } <={ order } "
283283 f"and 0<={ v [0 ]} _inner<{ sync_split } }}" ]
284284 else :
285285 expr = f"{ v [0 ]} _inner + { c } *{ x0 } _outer"
286- domains += [f"{{ [{ v [0 ]} _inner]: 0<={ expr } <=order "
286+ domains += [f"{{ [{ v [0 ]} _inner]: 0<={ expr } <={ order } "
287287 f"and 0<={ v [0 ]} _inner<{ c } }}" ]
288288 domains += [f"{{ [{ v [0 ]} ]: { expr } <={ v [0 ]} <={ expr } }}" ]
289- domains += [f"{{ [{ iorder } ]: { v [0 ]} <={ iorder } <=order }}" ]
289+ domains += [f"{{ [{ iorder } ]: { v [0 ]} <={ iorder } <={ order } }}" ]
290290 upper_bound = f"{ iorder } -{ v [0 ]} "
291291 for i in range (dim - 1 , 1 , - 1 ):
292292 domains += [f"{{ [{ v [i ]} ]: 0<={ v [i ]} <={ upper_bound } }}" ]
@@ -313,7 +313,6 @@ def get_idx(v):
313313 # a synchronization step.
314314 prev_copy_idx = (v [0 ]// c - 1 ) % 2
315315 curr_copy_idx = (v [0 ]// c ) % 2
316- fetch_idx = (v [0 ]// c ) % 2
317316 else :
318317 # We need to sync within the c rows.
319318 # Using the biharmonic 2D example:
@@ -330,7 +329,6 @@ def get_idx(v):
330329 # - Read the rows 4, 5, 6, 7 from coeffs_copy[0, :]
331330 prev_copy_idx = 0
332331 curr_copy_idx = 1
333- fetch_idx = 0
334332
335333 max_mi_sym = [v [i ] - max_mi [i ] for i in range (dim )]
336334 scale = - 1 / deriv_id_to_coeff [max_deriv_id ]
@@ -377,6 +375,15 @@ def get_idx(v):
377375 idx = get_idx (v )
378376 domains += get_domains (v , iorder , with_sync = False )[1 :]
379377
378+ if c == sync_split :
379+ # We did not have to sync within the c rows.
380+ # We last wrote to coeffs_copy[v[0]//c % 2, :] and we read from it.
381+ fetch_idx = (v [0 ]// c ) % 2
382+ else :
383+ # We need to sync within the c rows.
384+ # We last wrote to coeffs_copy[0, :] and we read from it.
385+ fetch_idx = 0
386+
380387 for ikernel , expr_dict in enumerate (sym_expr_dicts ):
381388 expr = sum (coeff * prod (powers [i ,
382389 v [i ] + max_deriv_order - mi [i ]] for i in range (dim ))
0 commit comments