@@ -191,6 +191,8 @@ def _infer_shape(
191191
192192 """
193193
194+ from pytensor .tensor .extra_ops import broadcast_shape_iter
195+
194196 size_len = get_vector_length (size )
195197
196198 if size_len > 0 :
@@ -216,57 +218,52 @@ def _infer_shape(
216218
217219 # Broadcast the parameters
218220 param_shapes = params_broadcast_shapes (
219- param_shapes or [shape_tuple (p ) for p in dist_params ], self .ndims_params
221+ param_shapes or [shape_tuple (p ) for p in dist_params ],
222+ self .ndims_params ,
220223 )
221224
222- def slice_ind_dims (p , ps , n ):
225+ def extract_batch_shape (p , ps , n ):
223226 shape = tuple (ps )
224227
225228 if n == 0 :
226- return ( p , shape )
229+ return shape
227230
228- ind_slice = (slice (None ),) * (p .ndim - n ) + (0 ,) * n
229- ind_shape = [
231+ batch_shape = [
230232 s if b is False else constant (1 , "int64" )
231- for s , b in zip (shape [:- n ], p .broadcastable [:- n ])
233+ for s , b in zip (shape [:- n ], p .type . broadcastable [:- n ])
232234 ]
233- return (
234- p [ind_slice ],
235- ind_shape ,
236- )
235+ return batch_shape
237236
238237 # These are versions of our actual parameters with the anticipated
239238 # dimensions (i.e. support dimensions) removed so that only the
240239 # independent variate dimensions are left.
241- params_ind_slice = tuple (
242- slice_ind_dims (p , ps , n )
240+ params_batch_shape = tuple (
241+ extract_batch_shape (p , ps , n )
243242 for p , ps , n in zip (dist_params , param_shapes , self .ndims_params )
244243 )
245244
246- if len (params_ind_slice ) == 1 :
247- _ , shape_ind = params_ind_slice [ 0 ]
248- elif len (params_ind_slice ) > 1 :
245+ if len (params_batch_shape ) == 1 :
246+ [ batch_shape ] = params_batch_shape
247+ elif len (params_batch_shape ) > 1 :
249248 # If there are multiple parameters, the dimensions of their
250249 # independent variates should broadcast together.
251- p_slices , p_shapes = zip (* params_ind_slice )
252-
253- shape_ind = pytensor .tensor .extra_ops .broadcast_shape_iter (
254- p_shapes , arrays_are_shapes = True
250+ batch_shape = broadcast_shape_iter (
251+ params_batch_shape ,
252+ arrays_are_shapes = True ,
255253 )
256-
257254 else :
258255 # Distribution has no parameters
259- shape_ind = ()
256+ batch_shape = ()
260257
261258 if self .ndim_supp == 0 :
262- shape_supp = ()
259+ supp_shape = ()
263260 else :
264- shape_supp = self ._supp_shape_from_params (
261+ supp_shape = self ._supp_shape_from_params (
265262 dist_params ,
266263 param_shapes = param_shapes ,
267264 )
268265
269- shape = tuple (shape_ind ) + tuple (shape_supp )
266+ shape = tuple (batch_shape ) + tuple (supp_shape )
270267 if not shape :
271268 shape = constant ([], dtype = "int64" )
272269
0 commit comments