Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4,610 changes: 121 additions & 4,489 deletions .basedpyright/baseline.json

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions doc/ref_transform.rst
Original file line number Diff line number Diff line change
Expand Up @@ -143,4 +143,10 @@ TODO: Matching instruction tags

.. automodule:: loopy.match


Fusing Loops
------------

.. automodule:: loopy.transform.loop_fusion

.. vim: tw=75:spell
6 changes: 6 additions & 0 deletions loopy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,10 @@
simplify_indices,
tag_instructions,
)
from loopy.transform.loop_fusion import (
get_kennedy_unweighted_fusion_candidates,
rename_inames_in_batch,
)
from loopy.transform.pack_and_unpack_args import pack_and_unpack_args_for_call
from loopy.transform.padding import (
add_padding,
Expand Down Expand Up @@ -336,6 +340,7 @@
"get_dot_dependency_graph",
"get_global_barrier_order",
"get_iname_duplication_options",
"get_kennedy_unweighted_fusion_candidates",
"get_mem_access_map",
"get_one_linearized_kernel",
"get_one_scheduled_kernel",
Expand Down Expand Up @@ -382,6 +387,7 @@
"rename_callable",
"rename_iname",
"rename_inames",
"rename_inames_in_batch",
"replace_instruction_ids",
"save_and_reload_temporaries",
"set_argument_order",
Expand Down
6 changes: 3 additions & 3 deletions loopy/auto_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ def auto_test_vs_ref(

need_check = False

events = []
events: list[cl.Event] = []
queue.finish()

logger.info("%s: warmup done" % (test_entrypoint))
Expand Down Expand Up @@ -645,8 +645,8 @@ def auto_test_vs_ref(
evt_start.wait()
evt_end.wait()

elapsed_event = (1e-9*events[-1].profile.END
- 1e-9*events[0].profile.START) \
elapsed_event = (1e-9*events[-1].profile.end
- 1e-9*events[0].profile.start) \
/ timing_rounds
try:
elapsed_event_marker = ((1e-9*evt_end.profile.start
Expand Down
12 changes: 6 additions & 6 deletions loopy/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
check_each_kernel,
)
from loopy.type_inference import TypeReader
from loopy.typing import auto, not_none
from loopy.typing import auto, not_none, set_union


if TYPE_CHECKING:
Expand Down Expand Up @@ -1110,10 +1110,10 @@ def _check_variable_access_ordered_inner(kernel: LoopKernel) -> None:
address_space = _get_address_space(kernel, var)
eq_class = aliasing_equiv_classes[var]

readers = set.union(
*[rmap.get(eq_name, set()) for eq_name in eq_class])
writers = set.union(
*[wmap.get(eq_name, set()) for eq_name in eq_class])
readers = set_union(
rmap.get(eq_name, set()) for eq_name in eq_class)
writers = set_union(
wmap.get(eq_name, set()) for eq_name in eq_class)

for writer in writers:
required_deps = (readers | writers) - {writer}
Expand Down Expand Up @@ -1679,7 +1679,7 @@ def _get_sub_array_ref_swept_range(
return get_access_map(
domain.to_set(),
sar.swept_inames,
kernel.assumptions.to_set()).range()
kernel.assumptions).range()


def _are_sub_array_refs_equivalent(
Expand Down
10 changes: 5 additions & 5 deletions loopy/kernel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
)
from loopy.tools import update_persistent_hash
from loopy.types import LoopyType, NumpyType
from loopy.typing import PreambleGenerator, SymbolMangler, fset_union, not_none
from loopy.typing import InsnId, PreambleGenerator, SymbolMangler, fset_union, not_none


if TYPE_CHECKING:
Expand Down Expand Up @@ -612,8 +612,8 @@ def insn_inames(self, insn: str | InstructionBase) -> frozenset[InameStr]:
return insn.within_inames

@memoize_method
def iname_to_insns(self):
result = {
def iname_to_insns(self) -> Mapping[InameStr, Set[InsnId]]:
result: dict[InameStr, set[InsnId]] = {
iname: set() for iname in self.all_inames()}
for insn in self.instructions:
for iname in insn.within_inames:
Expand Down Expand Up @@ -692,7 +692,7 @@ def compute_deps(insn_id):
# {{{ read and written variables

@memoize_method
def reader_map(self):
def reader_map(self) -> Mapping[str, Set[InsnId]]:
"""
:return: a dict that maps variable names to ids of insns that read that
variable.
Expand All @@ -710,7 +710,7 @@ def reader_map(self):
return result

@memoize_method
def writer_map(self):
def writer_map(self) -> Mapping[str, Set[InsnId]]:
"""
:return: a dict that maps variable names to ids of insns that write
to that variable.
Expand Down
4 changes: 2 additions & 2 deletions loopy/kernel/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@
floord mod ceil floor""".split())


def _gather_isl_identifiers(s):
def _gather_isl_identifiers(s: str):
return set(_IDENTIFIER_RE.findall(s)) - _ISL_KEYWORDS


Expand Down Expand Up @@ -2470,7 +2470,7 @@ def make_function(
# does something.
knl = add_inferred_inames(knl)
from loopy.transform.parameter import fix_parameters
knl = fix_parameters(knl, **fixed_parameters)
knl = fix_parameters(knl, within=None, **fixed_parameters)

# -------------------------------------------------------------------------
# Ordering dependency:
Expand Down
2 changes: 1 addition & 1 deletion loopy/kernel/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def stringify(cls, val: AddressSpace | type[auto]) -> str:

# {{{ arguments

class KernelArgument(ImmutableRecord):
class KernelArgument(ImmutableRecord, Taggable):
"""Base class for all argument types.

.. attribute:: name
Expand Down
7 changes: 4 additions & 3 deletions loopy/kernel/function_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,15 +924,16 @@ def with_descrs(self,
for arg in subkernel.args:
kw = arg.name
if isinstance(arg, ArrayBase):
new_arg_id_to_descr[kw] = (
new_arg_descriptor = (
ArrayArgDescriptor(shape=arg.shape,
dim_tags=arg.dim_tags,
address_space=arg.address_space))
else:
assert isinstance(arg, ValueArg)
new_arg_id_to_descr[kw] = ValueArgDescriptor()
new_arg_descriptor = ValueArgDescriptor()

new_arg_id_to_descr[kw_to_pos[kw]] = new_arg_id_to_descr[kw]
# FIXME: Should decide what the canonical arg identifiers are
new_arg_id_to_descr[kw_to_pos[kw]] = new_arg_descriptor

# }}}

Expand Down
Loading
Loading