Skip to content

Commit 367b518

Browse files
authored
Merge branch 'main' into rename_inames
2 parents f00e72a + 9e69cf4 commit 367b518

File tree

2 files changed

+32
-6
lines changed

2 files changed

+32
-6
lines changed

loopy/preprocess.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1751,6 +1751,8 @@ def map_scan_local(expr, rec, callables_table, nresults, arg_dtypes,
17511751

17521752
def map_reduction(expr, rec, callables_table,
17531753
guarding_predicates, nresults=1):
1754+
nonlocal insn_changed
1755+
17541756
# Only expand one level of reduction at a time, going from outermost to
17551757
# innermost. Otherwise we get the (iname + insn) dependencies wrong.
17561758

@@ -1827,6 +1829,10 @@ def _error_if_force_scan_on(cls, msg):
18271829
", ".join(str(kernel.iname_tags(iname))
18281830
for iname in bad_inames)))
18291831

1832+
# }}}
1833+
1834+
insn_changed = True
1835+
18301836
if n_local_par == 0 and n_sequential == 0:
18311837
from loopy.diagnostic import warn_with_kernel
18321838
warn_with_kernel(kernel, "empty_reduction",
@@ -1840,8 +1846,6 @@ def _error_if_force_scan_on(cls, msg):
18401846

18411847
return expr.expr, callables_table
18421848

1843-
# }}}
1844-
18451849
if may_be_implemented_as_scan:
18461850
assert force_scan or automagic_scans_ok
18471851

@@ -1916,7 +1920,7 @@ def _error_if_force_scan_on(cls, msg):
19161920
domains = kernel.domains[:]
19171921

19181922
temp_kernel = kernel
1919-
changed = False
1923+
kernel_changed = False
19201924

19211925
import loopy as lp
19221926
while insn_queue:
@@ -1925,6 +1929,7 @@ def _error_if_force_scan_on(cls, msg):
19251929
new_insn_add_within_inames = set()
19261930

19271931
generated_insns = []
1932+
insn_changed = False
19281933

19291934
insn = insn_queue.pop(0)
19301935

@@ -1947,7 +1952,7 @@ def _error_if_force_scan_on(cls, msg):
19471952
callables_table=cb_mapper.callables_table,
19481953
guarding_predicates=insn.predicates),
19491954

1950-
if generated_insns:
1955+
if insn_changed:
19511956
# An expansion happened, so insert the generated stuff plus
19521957
# ourselves back into the queue.
19531958

@@ -2010,14 +2015,14 @@ def _error_if_force_scan_on(cls, msg):
20102015
domains=domains)
20112016
temp_kernel = lp.replace_instruction_ids(
20122017
temp_kernel, insn_id_replacements)
2013-
changed = True
2018+
kernel_changed = True
20142019
else:
20152020
# nothing happened, we're done with insn
20162021
assert not new_insn_add_depends_on
20172022

20182023
new_insns.append(insn)
20192024

2020-
if changed:
2025+
if kernel_changed:
20212026
kernel = kernel.copy(
20222027
instructions=new_insns,
20232028
temporary_variables=new_temporary_variables,

test/test_reduction.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,27 @@ def test_any_all(ctx_factory):
460460
assert not out_dict["out2"].get()
461461

462462

463+
def test_reduction_without_inames(ctx_factory):
464+
"""Ensure that reductions with no inames get rewritten to the element
465+
being reduced over. This was sometimes erroneously eliminated because
466+
reduction realization used the generation of new statements as a criterion
467+
for whether work was done.
468+
"""
469+
ctx = ctx_factory()
470+
cq = cl.CommandQueue(ctx)
471+
472+
knl = lp.make_kernel(
473+
"{:}",
474+
"""
475+
out = reduce(any, [], 5)
476+
""")
477+
knl = lp.set_options(knl, return_dict=True)
478+
479+
_, out_dict = knl(cq)
480+
481+
assert out_dict["out"].get() == 5
482+
483+
463484
if __name__ == "__main__":
464485
if len(sys.argv) > 1:
465486
exec(sys.argv[1])

0 commit comments

Comments
 (0)