@@ -66,12 +66,10 @@ def make_checks(loop_orders, dtypes, sub):
6666 if index != "x" :
6767 # Initialize the variables associated to the jth loop
6868 # jump = stride - adjust
69- # If the variable has size 1 in that dim, we set the stride to zero to
70- # emulate broadcasting
7169 jump = f"({ var } _stride{ index } ) - ({ adjust } )"
7270 init += f"""
7371 { var } _n{ index } = PyArray_DIMS({ var } )[{ index } ];
74- { var } _stride{ index } = ( { var } _n { index } == 1)? 0 : PyArray_STRIDES({ var } )[{ index } ] / sizeof({ dtype } );
72+ { var } _stride{ index } = PyArray_STRIDES({ var } )[{ index } ] / sizeof({ dtype } );
7573 { var } _jump{ index } _{ j } = { jump } ;
7674 """
7775 adjust = f"{ var } _n{ index } *{ var } _stride{ index } "
@@ -86,88 +84,73 @@ def make_checks(loop_orders, dtypes, sub):
8684 # This loop builds multiple if conditions to verify that the
8785 # dimensions of the inputs match, and the first one that is true
8886 # raises an informative error message
87+
88+ runtime_broadcast_error_msg = (
89+ "Runtime broadcasting not allowed. "
90+ "One input had a distinct dimension length of 1, but was not marked as broadcastable: "
91+ "(input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld). "
92+ "If broadcasting was intended, use `specify_broadcastable` on the relevant input."
93+ )
94+
8995 for matches in zip (* loop_orders ):
9096 to_compare = [(j , x ) for j , x in enumerate (matches ) if x != "x" ]
9197
9298 # elements of to_compare are pairs ( input_variable_idx, input_variable_dim_idx )
9399 if len (to_compare ) < 2 :
94100 continue
95101
96- # Find first dimension size that is != 1
97- jl , xl = to_compare [- 1 ]
98- non1size_dim_check = f"""
99- npy_intp non1size_dim{ xl } ;
100- non1size_dim{ xl } = """
101- for j , x in to_compare [:- 1 ]:
102- non1size_dim_check += f"(%(lv{ j } )s_n{ x } != 1) ? %(lv{ j } )s_n{ x } : "
103- non1size_dim_check += f"%(lv{ jl } )s_n{ xl } ;"
104- check += non1size_dim_check
105-
106- # Check the nonsize1 dims match
107- # TODO: This is a bit inefficient because we are comparing one dimension against itself
108- check += f"""
109- if (non1size_dim{ xl } != 1)
110- {{
111- """
112- for j , x in to_compare :
102+ j0 , x0 = to_compare [0 ]
103+ for j , x in to_compare [1 :]:
113104 check += f"""
114- if ((%(lv{ j } )s_n{ x } != non1size_dim{ x } ) && (%(lv{ j } )s_n{ x } != 1))
105+ if (%(lv{ j0 } )s_n{ x0 } != %(lv{ j } )s_n{ x } )
106+ {{
107+ if (%(lv{ j0 } )s_n{ x0 } == 1 || %(lv{ j } )s_n{ x } == 1)
115108 {{
116- PyErr_Format(PyExc_ValueError, "Input dimension mismatch. One other input has shape[%%i] = %%lld, but input[%%i].shape[%%i] = %%lld.",
117- { x } ,
118- (long long int) non1size_dim{ x } ,
109+ PyErr_Format(PyExc_ValueError, "{ runtime_broadcast_error_msg } ",
110+ { j0 } ,
111+ { x0 } ,
112+ (long long int) %(lv{ j0 } )s_n{ x0 } ,
113+ { j } ,
114+ { x } ,
115+ (long long int) %(lv{ j } )s_n{ x }
116+ );
117+ }} else {{
118+ PyErr_Format(PyExc_ValueError, "Input dimension mismatch: (input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld)",
119+ { j0 } ,
120+ { x0 } ,
121+ (long long int) %(lv{ j0 } )s_n{ x0 } ,
119122 { j } ,
120123 { x } ,
121124 (long long int) %(lv{ j } )s_n{ x }
122125 );
123- %(fail)s
124126 }}
125- """
126- check += """
127- }
127+ %(fail)s
128+ }}
128129 """
129130
130131 return init % sub + check % sub
131132
132133
133- def compute_broadcast_dimensions (array_name : str , loop_orders , sub ) -> str :
134- """Create c_code to compute broadcasted dimensions of multiple arrays, arising from
135- Elemwise operations.
134+ def compute_output_dims_lengths (array_name : str , loop_orders , sub ) -> str :
135+ """Create c_code to compute the output dimensions of an Elemwise operation.
136136
137137 The code returned by this function populates the array `array_name`, but does not
138138 initialize it.
139139
140- TODO: We can decide to either specialize C code even further given the input types
141- or make it general, regardless of whether static broadcastable information is given
140+ Note: We could specialize C code even further with the known static output shapes
142141 """
143142 dims_c_code = ""
144143 for i , candidates in enumerate (zip (* loop_orders )):
145- # TODO: Are candidates always either "x" or "i"? If that's the case we can
146- # simplify some logic here (e.g., we don't need to track the `idx`).
147- nonx_candidates = tuple (
148- ( idx , c ) for idx , c in enumerate ( candidates ) if c != "x"
149- )
150-
151- # All inputs are known to be broadcastable
152- if not nonx_candidates :
144+ # Borrow the length of the first non-broadcastable input dimension
145+ for j , candidate in enumerate ( candidates ):
146+ if candidate != "x" :
147+ var = sub [ f"lv { int ( j ) } " ]
148+ dims_c_code += f" { array_name } [ { i } ] = { var } _n { candidate } ; \n "
149+ break
150+ # If none is non-broadcastable, the output dimension has a length of 1
151+ else : # no-break
153152 dims_c_code += f"{ array_name } [{ i } ] = 1;\n "
154- continue
155-
156- # There is only one informative source of size
157- if len (nonx_candidates ) == 1 :
158- idx , candidate = nonx_candidates [0 ]
159- var = sub [f"lv{ int (idx )} " ]
160- dims_c_code += f"{ array_name } [{ i } ] = { var } _n{ candidate } ;\n "
161- continue
162153
163- # In this case any non-size 1 variable will define the right size
164- dims_c_code += f"{ array_name } [{ i } ] = "
165- for idx , candidate in nonx_candidates [:- 1 ]:
166- var = sub [f"lv{ int (idx )} " ]
167- dims_c_code += f"({ var } _n{ candidate } != 1)? { var } _n{ candidate } : "
168- idx , candidate = nonx_candidates [- 1 ]
169- var = sub [f"lv{ idx } " ]
170- dims_c_code += f"{ var } _n{ candidate } ;\n "
171154 return dims_c_code
172155
173156
@@ -186,7 +169,7 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
186169 if type .startswith ("PYTENSOR_COMPLEX" ):
187170 type = type .replace ("PYTENSOR_COMPLEX" , "NPY_COMPLEX" )
188171 nd = len (loop_orders [0 ])
189- init_dims = compute_broadcast_dimensions ("dims" , loop_orders , sub )
172+ init_dims = compute_output_dims_lengths ("dims" , loop_orders , sub )
190173
191174 # TODO: it would be interesting to allocate the output in such a
192175 # way that its contiguous dimensions match one of the input's
@@ -359,7 +342,7 @@ def make_reordered_loop(
359342
360343 # Get the (sorted) total number of iterations of each loop
361344 declare_totals = f"int init_totals[{ nnested } ];\n "
362- declare_totals += compute_broadcast_dimensions ("init_totals" , init_loop_orders , sub )
345+ declare_totals += compute_output_dims_lengths ("init_totals" , init_loop_orders , sub )
363346
364347 # Sort totals to match the new order that was computed by sorting
365348 # the loop vector. One integer variable per loop is declared.
0 commit comments