Skip to content

Commit 8e34173

Browse files
committed
Fix fetch_idx
1 parent 95132cb commit 8e34173

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

sumpy/expansion/loopy.py

Lines changed: 13 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -276,17 +276,17 @@ def get_domains(v, iorder, with_sync):
276276
domains = [f"{{ [{x0}_outer]: 0<={x0}_outer<={order//c} }}"]
277277
if with_sync:
278278
expr = f"{c//sync_split}*{x0}_sync_outer + {c}*{x0}_outer"
279-
domains += [f"{{ [{x0}_sync_outer]: 0<={expr}<=order "
279+
domains += [f"{{ [{x0}_sync_outer]: 0<={expr}<={order} "
280280
f"and 0<={x0}_sync_outer<{c//sync_split} }}"]
281281
expr += f" + {v[0]}_inner"
282-
domains += [f"{{ [{v[0]}_inner]: 0<={expr}<=order "
282+
domains += [f"{{ [{v[0]}_inner]: 0<={expr}<={order} "
283283
f"and 0<={v[0]}_inner<{sync_split} }}"]
284284
else:
285285
expr = f"{v[0]}_inner + {c}*{x0}_outer"
286-
domains += [f"{{ [{v[0]}_inner]: 0<={expr}<=order "
286+
domains += [f"{{ [{v[0]}_inner]: 0<={expr}<={order} "
287287
f"and 0<={v[0]}_inner<{c} }}"]
288288
domains += [f"{{ [{v[0]}]: {expr}<={v[0]}<={expr} }}"]
289-
domains += [f"{{ [{iorder}]: {v[0]}<={iorder}<=order }}"]
289+
domains += [f"{{ [{iorder}]: {v[0]}<={iorder}<={order} }}"]
290290
upper_bound = f"{iorder}-{v[0]}"
291291
for i in range(dim - 1, 1, -1):
292292
domains += [f"{{ [{v[i]}]: 0<={v[i]}<={upper_bound} }}"]
@@ -313,7 +313,6 @@ def get_idx(v):
313313
# a synchronization step.
314314
prev_copy_idx = (v[0]//c - 1) % 2
315315
curr_copy_idx = (v[0]//c) % 2
316-
fetch_idx = (v[0]//c) % 2
317316
else:
318317
# We need to sync within the c rows.
319318
# Using the biharmonic 2D example:
@@ -330,7 +329,6 @@ def get_idx(v):
330329
# - Read the rows 4, 5, 6, 7 from coeffs_copy[0, :]
331330
prev_copy_idx = 0
332331
curr_copy_idx = 1
333-
fetch_idx = 0
334332

335333
max_mi_sym = [v[i] - max_mi[i] for i in range(dim)]
336334
scale = -1/deriv_id_to_coeff[max_deriv_id]
@@ -377,6 +375,15 @@ def get_idx(v):
377375
idx = get_idx(v)
378376
domains += get_domains(v, iorder, with_sync=False)[1:]
379377

378+
if c == sync_split:
379+
# We did not have to sync within the c rows.
380+
# We last wrote to coeffs_copy[v[0]//c % 2, :] and we read from it.
381+
fetch_idx = (v[0]//c) % 2
382+
else:
383+
# We need to sync within the c rows.
384+
# We last wrote to coeffs_copy[0, :] and we read from it.
385+
fetch_idx = 0
386+
380387
for ikernel, expr_dict in enumerate(sym_expr_dicts):
381388
expr = sum(coeff * prod(powers[i,
382389
v[i] + max_deriv_order - mi[i]] for i in range(dim))

0 commit comments

Comments
 (0)