Skip to content

Commit 954e432

Browse files
committed
[looptree] Refactor reads and writes analysis
1 parent f5aa908 commit 954e432

File tree

9 files changed

+216
-180
lines changed

9 files changed

+216
-180
lines changed

pytimeloop/looptree/accesses.py

Lines changed: 180 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -1,154 +1,217 @@
11
from collections import defaultdict
22
from collections.abc import Mapping
3+
from dataclasses import dataclass
34
from numbers import Number
5+
from typing import Optional, overload
46

57
from bindings.looptree import TemporalTag, SequentialTag, PipelineTemporalTag
68

79
import islpy as isl
810

11+
from pytimeloop.looptree.reuse.isl import IslReuseAnalysisOutput
12+
from pytimeloop.looptree.reuse.summarized import SummarizedAnalysisOutput
13+
914
from pytimeloop.isl.singular import get_sum_of_pw_qpolynomial
1015
from pytimeloop.isl.sum import sum_with_mask
1116
from pytimeloop.looptree.mapping_utilities import *
1217

1318

14-
def get_total_accesses(accesses: Mapping):
19+
@dataclass
class Accesses:
    """Action counts for one (buffer, dspace, einsum) combination.

    `total_*` fields are counts summed over every hardware unit;
    `max_per_unit_*` fields are the maximum count seen by any single
    unit.  A field is `None` when that quantity was not computed.
    """
    # Reads summed over all units.
    total_reads: Optional[float]
    # Writes summed over all units.
    total_writes: Optional[float]
    # Largest read count observed on any one unit.
    max_per_unit_reads: Optional[float]
    # Largest write count observed on any one unit.
    max_per_unit_writes: Optional[float]
25+
26+
27+
class BufferAccesses:
    """Lazily-populated store of `Accesses`, keyed by (buffer, dspace, einsum)."""

    def __init__(self):
        # Maps (buffer, dspace, einsum) -> Accesses; entries are created
        # on demand by `get_accesses`.
        self.accesses: dict[tuple, Accesses] = {}

    def get_accesses(self, buffer, dspace, einsum) -> Accesses:
        """Return the record for (buffer, dspace, einsum), creating a
        zero-initialized one on first use."""
        key = (buffer, dspace, einsum)
        entry = self.accesses.get(key)
        if entry is None:
            entry = Accesses(0, 0, 0, 0)
            self.accesses[key] = entry
        return entry

    def items(self):
        """Iterate over ((buffer, dspace, einsum), Accesses) pairs."""
        return self.accesses.items()

    def items_with_buffer(self, ref_buffer):
        """Returns iterator similar to `items` but only for `ref_buffer`"""
        for key, value in self.accesses.items():
            if key[0] == ref_buffer:
                yield key, value
47+
48+
49+
@overload
def summarize_total_and_per_unit_actions(
    reuse_analysis_result: IslReuseAnalysisOutput
):
    pass
@overload
def summarize_total_and_per_unit_actions(
    reuse_analysis_result: SummarizedAnalysisOutput
):
    pass

def summarize_total_and_per_unit_actions(
    reuse_analysis_result
):
    """Reduce a reuse-analysis result to total and per-unit action counts.

    Parameters:
        reuse_analysis_result: Either an `IslReuseAnalysisOutput` (actions
            expressed as ISL piecewise quasi-polynomials, reduced here by
            summation / masked max) or a `SummarizedAnalysisOutput`
            (actions already scalar; per-unit values obtained by dividing
            by the buffer fanout).

    Returns:
        Dict mapping each (buffer, dspace, einsum) key to the 6-tuple
        (total_fill, total_read_to_parent, total_read_to_peer,
        max_per_unit_fill, max_per_unit_read_to_parent,
        max_per_unit_read_to_peer).

    Raises:
        TypeError: if `reuse_analysis_result` is of neither supported type.
    """
    result = {}
    if isinstance(reuse_analysis_result, IslReuseAnalysisOutput):
        for key, (tags, fill) in reuse_analysis_result.fills.items():
            read_to_parent = reuse_analysis_result.reads_to_parent[key][1]
            read_to_peer = reuse_analysis_result.reads_to_peer[key][1]

            # Totals: integrate the quasi-polynomials over their domains.
            # NOTE(review): a NaN sum used to be coerced to 0 before this
            # refactor — confirm NaN can no longer occur here.
            total_fill = get_sum_of_pw_qpolynomial(fill)
            total_read_to_parent = get_sum_of_pw_qpolynomial(read_to_parent)
            total_read_to_peer = get_sum_of_pw_qpolynomial(read_to_peer)

            # Per-unit maxima: sum over temporal dims, max over spatial.
            max_per_unit_fill = \
                _sum_over_temporal_max_over_spatial(tags, fill)

            # reads_to_parent may span fewer dims than the fill; mask only
            # the tags covering its input dimensions.
            n_read_to_parent_dim = read_to_parent.dim(isl.dim_type.in_)
            max_per_unit_read_to_parent = \
                _sum_over_temporal_max_over_spatial(tags[:n_read_to_parent_dim],
                                                    read_to_parent)

            max_per_unit_read_to_peer = \
                _sum_over_temporal_max_over_spatial(tags, read_to_peer)

            result[key] = (total_fill,
                           total_read_to_parent,
                           total_read_to_peer,
                           max_per_unit_fill,
                           max_per_unit_read_to_parent,
                           max_per_unit_read_to_peer)
    elif isinstance(reuse_analysis_result, SummarizedAnalysisOutput):
        for key, (tags, fill) in reuse_analysis_result.fills.items():
            buffer_id = key[0]

            read_to_parent = reuse_analysis_result.reads_to_parent[key][1]
            read_to_peer = reuse_analysis_result.reads_to_peer[key][1]

            # Actions are already scalar totals in the summarized output.
            total_fill = fill
            total_read_to_parent = read_to_parent
            total_read_to_peer = read_to_peer

            # Per-unit value assumes actions spread evenly over the fanout.
            fanout = reuse_analysis_result.fanout[buffer_id]

            max_per_unit_fill = fill / fanout
            max_per_unit_read_to_parent = read_to_parent / fanout
            max_per_unit_read_to_peer = read_to_peer / fanout

            result[key] = (total_fill,
                           total_read_to_parent,
                           total_read_to_peer,
                           max_per_unit_fill,
                           max_per_unit_read_to_parent,
                           max_per_unit_read_to_peer)
    else:
        # Previously an unsupported type silently produced an empty result;
        # fail loudly instead.
        raise TypeError(
            f'Unsupported reuse analysis result type: '
            f'{type(reuse_analysis_result).__name__}'
        )
    return result
29114

30115

31-
def reads_and_writes_from_fill_by_parent(fills: Mapping,
32-
reads_to_parent,
33-
mapping,
34-
workload,
35-
is_path=False,
36-
per_unit=False):
116+
117+
@overload
def buffer_accesses_from_buffet_actions(
    reuse_analysis_result: IslReuseAnalysisOutput,
    mapping,
    workload,
    is_path=False
) -> BufferAccesses:
    pass
@overload
def buffer_accesses_from_buffet_actions(
    reuse_analysis_result: SummarizedAnalysisOutput,
    mapping,
    workload,
    is_path=False
) -> BufferAccesses:
    pass
# TODO: is_path should be removed and we should accept only regular mappings
def buffer_accesses_from_buffet_actions(
    reuse_analysis_result,
    mapping,
    workload,
    is_path=False
) -> BufferAccesses:
    """Convert buffet actions (fills and reads) into buffer read/write counts.

    For each (buffer, dspace, einsum) with a complete mapping:
      - reads_to_parent count as reads AND (for written tensors) writes at
        the parent buffer;
      - fills count as writes at the buffer itself, except at compute
        units (no write action) and top-level buffers (no parent);
      - reads_to_peer count as reads at the buffer itself.

    Parameters:
        reuse_analysis_result: IslReuseAnalysisOutput or
            SummarizedAnalysisOutput; summarized via
            `summarize_total_and_per_unit_actions`.
        mapping: mapping spec; only `mapping['nodes']` is used.
        workload: workload object providing id-to-name maps, tensor
            read/write sets, and tensor volumes.
        is_path: legacy flag forwarded to mapping utilities (see TODO).

    Returns:
        BufferAccesses with total and max-per-unit reads/writes accumulated.
    """
    mapping = mapping['nodes']
    dspace_id_to_name = workload.data_space_id_to_name()
    einsum_id_to_name = workload.einsum_id_to_name()

    parent_buffers = get_parent_buffers(mapping, workload, is_path)

    einsums_with_complete_mappings = \
        get_einsums_with_complete_mappings(mapping, workload, is_path)

    # Leaves of the mapping are compute nodes; their targets have no
    # write action for fills.
    compute_targets = set()
    for compute_node in get_leaves(mapping, is_path):
        assert compute_node["type"] == "compute"
        compute_targets.add(compute_node["target"])

    summarized_actions = \
        summarize_total_and_per_unit_actions(reuse_analysis_result)

    accesses_results = BufferAccesses()
    for (buffer_id, dspace_id, einsum_id), value in summarized_actions.items():
        (
            fill,
            read_to_parent,
            read_to_peer,
            max_per_unit_fill,
            max_per_unit_read_to_parent,
            max_per_unit_read_to_peer
        ) = value

        if einsum_id not in einsums_with_complete_mappings:
            continue

        dspace_name = dspace_id_to_name[dspace_id]
        einsum_name = einsum_id_to_name[einsum_id]

        # Reads to the parent buffer are accounted at the parent.
        parent_buffer = parent_buffers[(buffer_id, dspace_id, einsum_id)]
        if parent_buffer is not None:
            accesses = accesses_results.get_accesses(parent_buffer,
                                                     dspace_name,
                                                     einsum_name)
            if dspace_id in workload.tensors_written_by_einsum(einsum_id):
                # A written tensor's updates travel to the parent as a
                # read-modify-write: count both a read and a write.
                accesses.total_writes += read_to_parent
                accesses.total_reads += read_to_parent

                # Subtracted term: elided first read of a read-write tensor.
                # TODO: figure out how to do this per unit
                total_elided_reads = workload.get_tensor_volume(dspace_id)
                accesses.total_reads -= total_elided_reads

                accesses.max_per_unit_reads += max_per_unit_read_to_parent
                accesses.max_per_unit_writes += max_per_unit_read_to_parent
            elif dspace_id in workload.tensors_read_by_einsum(einsum_id):
                accesses.total_reads += read_to_parent

                # BUGFIX: previously added the TOTAL read_to_parent here;
                # the per-unit counter must accumulate the per-unit value
                # (mirrors the written-tensor branch above).
                accesses.max_per_unit_reads += max_per_unit_read_to_parent

        # Fills will write into current buffer except for compute (which does
        # not have write action) and top-level buffer
        accesses = accesses_results.get_accesses(buffer_id,
                                                 dspace_name,
                                                 einsum_name)
        if buffer_id not in compute_targets and parent_buffer is not None:
            accesses.total_writes += fill
            accesses.max_per_unit_writes += max_per_unit_fill
            if dspace_id in workload.tensors_written_by_einsum(einsum_id):
                # Subtracted term: elided first write of a written tensor.
                # TODO: figure out how to do this per unit
                total_elided_writes = workload.get_tensor_volume(dspace_id)
                accesses.total_writes -= total_elided_writes

        # Peer-to-peer transfers are reads at this buffer.
        accesses.total_reads += read_to_peer
        accesses.max_per_unit_reads += max_per_unit_read_to_peer

    return accesses_results
152215

153216

154217
def get_parent_buffers(mapping, workload, is_path):
@@ -190,4 +253,18 @@ def get_parent_buffers(mapping, workload, is_path):
190253
if dspace_id in dspace_to_top_buffer:
191254
parent_buffers[key] = dspace_to_top_buffer[dspace_id]
192255

193-
return parent_buffers
256+
return parent_buffers
257+
258+
259+
def _sum_over_temporal_max_over_spatial(tags, actions):
    """Reduce `actions` to a scalar: sum masked (temporal) dims, max the rest.

    A dimension is treated as temporal when its tag is a TemporalTag,
    PipelineTemporalTag, or SequentialTag; the remaining (spatial)
    dimensions are reduced with `max`.  Returns a plain Python number.
    """
    # Idiom: one isinstance call with a tuple of types replaces the
    # original three chained isinstance checks.
    temporal_mask = [
        isinstance(t, (TemporalTag, PipelineTemporalTag, SequentialTag))
        for t in tags
    ]
    return sum_with_mask(temporal_mask, actions).max().to_python()

0 commit comments

Comments
 (0)