from collections import defaultdict
from collections.abc import Mapping
+from dataclasses import dataclass
from numbers import Number
+from typing import Optional, overload

from bindings.looptree import TemporalTag, SequentialTag, PipelineTemporalTag

import islpy as isl

+from pytimeloop.looptree.reuse.isl import IslReuseAnalysisOutput
+from pytimeloop.looptree.reuse.summarized import SummarizedAnalysisOutput
+
from pytimeloop.isl.singular import get_sum_of_pw_qpolynomial
from pytimeloop.isl.sum import sum_with_mask
from pytimeloop.looptree.mapping_utilities import *


-def get_total_accesses(accesses: Mapping):
+@dataclass
+class Accesses:
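+    """Total and maximum per-unit read/write counts for a single
+    (buffer, dataspace, einsum) combination."""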
+    total_reads: Optional[float]
+    total_writes: Optional[float]
+    max_per_unit_reads: Optional[float]
+    max_per_unit_writes: Optional[float]
+
+
+class BufferAccesses:
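+    """Accumulates `Accesses` records keyed by (buffer, dataspace, einsum),
+    creating a zero-initialized record the first time a key is requested."""
+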
+    def __init__(self):
+        self.accesses: dict[tuple, Accesses] = {}
+
+    def get_accesses(self, buffer, dspace, einsum) -> Accesses:
+        key = (buffer, dspace, einsum)
+        if key not in self.accesses:
+            self.accesses[key] = Accesses(0, 0, 0, 0)
+        return self.accesses[key]
+
+    def items(self):
+        return self.accesses.items()
+
+    def items_with_buffer(self, ref_buffer):
+        """Returns an iterator like `items`, but only over entries whose
+        buffer is `ref_buffer`."""
+        return (
+            ((buffer, dspace, einsum), value)
+            for (buffer, dspace, einsum), value in self.accesses.items()
+            if buffer == ref_buffer
+        )
+
+
+@overload
+def summarize_total_and_per_unit_actions(
+    reuse_analysis_result: IslReuseAnalysisOutput
+):
+    pass
+@overload
+def summarize_total_and_per_unit_actions(
+    reuse_analysis_result: SummarizedAnalysisOutput
+):
+    pass
+
+def summarize_total_and_per_unit_actions(
+    reuse_analysis_result
+):
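+    """Summarizes buffet actions from a reuse analysis.
+
+    Returns a dict mapping each (buffer, dataspace, einsum) key to the tuple
+    (total_fill, total_read_to_parent, total_read_to_peer, max_per_unit_fill,
+    max_per_unit_read_to_parent, max_per_unit_read_to_peer).
+    """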
    result = {}
-    for k, v in accesses.items():
-        if isinstance(v, isl.PwQPolynomial):
-            sum = get_sum_of_pw_qpolynomial(v)
-            if sum.is_nan():
-                result[k] = 0
-            else:
-                result[k] = sum.to_python()
-        elif isinstance(v, Number):
-            result[k] = v
-        else:
-            result[k] = v
+    if isinstance(reuse_analysis_result, IslReuseAnalysisOutput):
+        for key, (tags, fill) in reuse_analysis_result.fills.items():
+            read_to_parent = reuse_analysis_result.reads_to_parent[key][1]
+            read_to_peer = reuse_analysis_result.reads_to_peer[key][1]
+
+            total_fill = get_sum_of_pw_qpolynomial(fill)
+            total_read_to_parent = get_sum_of_pw_qpolynomial(read_to_parent)
+            total_read_to_peer = get_sum_of_pw_qpolynomial(read_to_peer)

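+            # Per-unit peaks: sum over temporal loop levels, then take the
+            # maximum over the spatial levels.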
+            max_per_unit_fill = \
+                _sum_over_temporal_max_over_spatial(tags, fill)
+
+            n_read_to_parent_dim = read_to_parent.dim(isl.dim_type.in_)
+            max_per_unit_read_to_parent = \
+                _sum_over_temporal_max_over_spatial(tags[:n_read_to_parent_dim],
+                                                    read_to_parent)
+
+            max_per_unit_read_to_peer = \
+                _sum_over_temporal_max_over_spatial(tags, read_to_peer)
+
+            result[key] = (total_fill,
+                           total_read_to_parent,
+                           total_read_to_peer,
+                           max_per_unit_fill,
+                           max_per_unit_read_to_parent,
+                           max_per_unit_read_to_peer)
+    elif isinstance(reuse_analysis_result, SummarizedAnalysisOutput):
+        for key, (tags, fill) in reuse_analysis_result.fills.items():
+            buffer_id = key[0]
+
+            read_to_parent = reuse_analysis_result.reads_to_parent[key][1]
+            read_to_peer = reuse_analysis_result.reads_to_peer[key][1]
+
+            total_fill = fill
+            total_read_to_parent = read_to_parent
+            total_read_to_peer = read_to_peer
+
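+            # Summarized results only carry totals; per-unit values are
+            # approximated by dividing evenly over the buffer's fanout.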
+            fanout = reuse_analysis_result.fanout[buffer_id]
+
+            max_per_unit_fill = fill / fanout
+            max_per_unit_read_to_parent = read_to_parent / fanout
+            max_per_unit_read_to_peer = read_to_peer / fanout
+
+            result[key] = (total_fill,
+                           total_read_to_parent,
+                           total_read_to_peer,
+                           max_per_unit_fill,
+                           max_per_unit_read_to_parent,
+                           max_per_unit_read_to_peer)
    return result


-def reads_and_writes_from_fill_by_parent(fills: Mapping,
-                                         reads_to_parent,
-                                         mapping,
-                                         workload,
-                                         is_path=False,
-                                         per_unit=False):
+
+@overload
+def buffer_accesses_from_buffet_actions(
+    reuse_analysis_result: IslReuseAnalysisOutput,
+    mapping,
+    workload,
+    is_path=False
+) -> BufferAccesses:
+    pass
+@overload
+def buffer_accesses_from_buffet_actions(
+    reuse_analysis_result: SummarizedAnalysisOutput,
+    mapping,
+    workload,
+    is_path=False
+) -> BufferAccesses:
+    pass
+# TODO: is_path should be removed and we should accept only regular mappings
+def buffer_accesses_from_buffet_actions(
+    reuse_analysis_result,
+    mapping,
+    workload,
+    is_path=False
+) -> BufferAccesses:
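+    """Converts buffet actions from a reuse analysis into read/write access
+    counts per (buffer, dataspace, einsum), accumulated in a `BufferAccesses`."""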
    mapping = mapping['nodes']
    dspace_id_to_name = workload.data_space_id_to_name()
    einsum_id_to_name = workload.einsum_id_to_name()

-    reads = defaultdict(lambda: 0)
-    writes = defaultdict(lambda: 0)

    parent_buffers = get_parent_buffers(mapping, workload, is_path)

-    einsums_with_complete_mappings = get_einsums_with_complete_mappings(mapping, workload, is_path)
+    einsums_with_complete_mappings = \
+        get_einsums_with_complete_mappings(mapping, workload, is_path)

    compute_targets = set()
    for compute_node in get_leaves(mapping, is_path):
        assert compute_node["type"] == "compute"
        compute_targets.add(compute_node["target"])

-    for (buffer_id, dspace_id, einsum_id), (tags, fill) in fills.items():
-        read_to_parent = reads_to_parent[(buffer_id, dspace_id, einsum_id)][1]
+    summarized_actions = \
+        summarize_total_and_per_unit_actions(reuse_analysis_result)

-        if not per_unit:
-            read_to_parent = get_sum_of_pw_qpolynomial(read_to_parent)
-            fill = get_sum_of_pw_qpolynomial(fill)
-        else:
-            fill = sum_with_mask(
-                [
-                    (
-                        isinstance(t, TemporalTag) or
-                        isinstance(t, PipelineTemporalTag) or
-                        isinstance(t, SequentialTag)
-                    )
-                    for t in tags
-                ],
-                fill
-            ).max().to_python()
-            n_read_to_parent_dim = read_to_parent.dim(isl.dim_type.in_)
-            read_to_parent = sum_with_mask(
-                [
-                    (
-                        isinstance(t, TemporalTag) or
-                        isinstance(t, PipelineTemporalTag) or
-                        isinstance(t, SequentialTag)
-                    )
-                    for t in tags[:n_read_to_parent_dim]
-                ],
-                read_to_parent
-            ).max().to_python()
+    accesses_results = BufferAccesses()
+    for (buffer_id, dspace_id, einsum_id), value in summarized_actions.items():
+        (
+            fill,
+            read_to_parent,
+            read_to_peer,
+            max_per_unit_fill,
+            max_per_unit_read_to_parent,
+            max_per_unit_read_to_peer
+        ) = value

        dspace_name = dspace_id_to_name[dspace_id]
        einsum_name = einsum_id_to_name[einsum_id]
        if einsum_id not in einsums_with_complete_mappings:
            continue
+
        parent_buffer = parent_buffers[(buffer_id, dspace_id, einsum_id)]
        if parent_buffer is not None:
-            key = (parent_buffer, dspace_name, einsum_name)
+            accesses = accesses_results.get_accesses(parent_buffer,
+                                                     dspace_name,
+                                                     einsum_name)
            if dspace_id in workload.tensors_written_by_einsum(einsum_id):
-                writes[key] += read_to_parent
-                reads[key] += read_to_parent
-                # Subtracted term: elided first read of a read-write tensor
+                accesses.total_writes += read_to_parent
+                accesses.total_reads += read_to_parent
+
                # TODO: figure out how to do this per unit
-                if not per_unit:
-                    reads[key] -= workload.get_tensor_volume(dspace_id)
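+                # Subtracted term: elided first read of a read-write tensor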
+                total_elided_reads = workload.get_tensor_volume(dspace_id)
+                accesses.total_reads -= total_elided_reads
+
+                accesses.max_per_unit_reads += max_per_unit_read_to_parent
+                accesses.max_per_unit_writes += max_per_unit_read_to_parent
            elif dspace_id in workload.tensors_read_by_einsum(einsum_id):
-                reads[key] += read_to_parent
+                accesses.total_reads += read_to_parent
+
+                accesses.max_per_unit_reads += max_per_unit_read_to_parent
+
        # Fills write into the current buffer, except for compute units (which
        # have no write action) and the top-level buffer
+        accesses = accesses_results.get_accesses(buffer_id,
+                                                 dspace_name,
+                                                 einsum_name)
        if buffer_id not in compute_targets and parent_buffer is not None:
            if dspace_id in workload.tensors_written_by_einsum(einsum_id):
-                writes[(buffer_id, dspace_name, einsum_name)] += fill
-                if not per_unit:
-                    writes[(buffer_id, dspace_name, einsum_name)] -= \
-                        workload.get_tensor_volume(dspace_id)
-            else:
-                writes[(buffer_id, dspace_name, einsum_name)] += fill
-
-    return reads, writes
-
-
-def reads_and_writes_from_fill_by_peer(fills: Mapping,
-                                       mapping,
-                                       workload,
-                                       is_path=False,
-                                       per_unit=False):
-    mapping = mapping['nodes']
-    dspace_id_to_name = workload.data_space_id_to_name()
-    einsum_id_to_name = workload.einsum_id_to_name()
+                accesses.total_writes += fill
+                accesses.max_per_unit_writes += max_per_unit_fill

-    reads = {}
-    writes = {}
-
-    einsums_with_complete_mappings = get_einsums_with_complete_mappings(mapping, workload, is_path)
-
-    for (buffer_id, dspace_id, einsum_id), (tags, fill) in fills.items():
-        if not per_unit:
-            fill = get_sum_of_pw_qpolynomial(fill)
-        else:
-            fill = sum_with_mask(
-                [
-                    (
-                        isinstance(t, TemporalTag) or
-                        isinstance(t, PipelineTemporalTag) or
-                        isinstance(t, SequentialTag)
-                    )
-                    for t in tags
-                ],
-                fill
-            ).max().to_python()
-        einsum_name = einsum_id_to_name[einsum_id]
-        dspace_name = dspace_id_to_name[dspace_id]
-        if einsum_id not in einsums_with_complete_mappings:
-            continue
+                # TODO: figure out how to do this per unit
+                total_elided_writes = workload.get_tensor_volume(dspace_id)
+                accesses.total_writes -= total_elided_writes
+            else:
+                accesses.total_writes += fill
+                accesses.max_per_unit_writes += max_per_unit_fill

-        reads[(buffer_id, dspace_name, einsum_name)] = fill
-        writes[(buffer_id, dspace_name, einsum_name)] = 0  # already accounted for in fill_by_parent
+        accesses.total_reads += read_to_peer
+        accesses.max_per_unit_reads += max_per_unit_read_to_peer

-    return reads, writes
+    return accesses_results


def get_parent_buffers(mapping, workload, is_path):
@@ -190,4 +253,18 @@ def get_parent_buffers(mapping, workload, is_path):
        if dspace_id in dspace_to_top_buffer:
            parent_buffers[key] = dspace_to_top_buffer[dspace_id]

-    return parent_buffers
+    return parent_buffers
+
+
+def _sum_over_temporal_max_over_spatial(tags, actions):
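+    """Sums `actions` over dimensions tagged as temporal (including pipeline
+    temporal and sequential tags), then takes the maximum over the remaining
+    spatial dimensions and converts the result to a Python number."""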
+    return sum_with_mask(
+        [
+            (
+                isinstance(t, TemporalTag) or
+                isinstance(t, PipelineTemporalTag) or
+                isinstance(t, SequentialTag)
+            )
+            for t in tags
+        ],
+        actions
+    ).max().to_python()