@@ -269,34 +269,23 @@ def test_measurable_join_univariate(size1, size2, axis, concatenate):
269269
270270
271271@pytest .mark .parametrize (
272- "size1, supp_size1, size2, supp_size2, axis, concatenate" ,
272+ "size1, supp_size1, size2, supp_size2, axis, concatenate, logp_axis " ,
273273 [
274- (None , 2 , None , 2 , 0 , True ),
275- (None , 2 , None , 2 , - 1 , True ),
276- ((5 ,), 2 , (3 ,), 2 , 0 , True ),
277- ((5 ,), 2 , (3 ,), 2 , - 2 , True ),
278- ((2 ,), 5 , (2 ,), 3 , 1 , True ),
279- pytest .param (
280- (2 ,),
281- 5 ,
282- (2 ,),
283- 5 ,
284- 0 ,
285- False ,
286- marks = pytest .mark .xfail (reason = "cannot measure dimshuffled multivariate RVs" ),
287- ),
288- pytest .param (
289- (2 ,),
290- 5 ,
291- (2 ,),
292- 5 ,
293- 1 ,
294- False ,
295- marks = pytest .mark .xfail (reason = "cannot measure dimshuffled multivariate RVs" ),
296- ),
274+ (None , 2 , None , 2 , 0 , True , 0 ),
275+ (None , 2 , None , 2 , - 1 , True , 0 ),
276+ ((5 ,), 2 , (3 ,), 2 , 0 , True , 0 ),
277+ ((5 ,), 2 , (3 ,), 2 , - 2 , True , 0 ),
278+ ((2 ,), 5 , (2 ,), 3 , 1 , True , 0 ),
279+ ((5 , 6 ), 10 , (5 , 1 ), 10 , 1 , True , 1 ),
280+ ((5 , 6 ), 10 , (5 , 1 ), 10 , - 2 , True , 1 ),
281+ ((2 ,), 5 , (2 ,), 5 , 0 , False , 0 ),
282+ ((2 ,), 5 , (2 ,), 5 , 1 , False , 1 ),
283+ ((5 , 6 ), 10 , (5 , 6 ), 10 , 2 , False , 2 ),
297284 ],
298285)
299- def test_measurable_join_multivariate (size1 , supp_size1 , size2 , supp_size2 , axis , concatenate ):
286+ def test_measurable_join_multivariate (
287+ size1 , supp_size1 , size2 , supp_size2 , axis , concatenate , logp_axis
288+ ):
300289 base1_rv = pt .random .multivariate_normal (
301290 np .zeros (supp_size1 ), np .eye (supp_size1 ), size = size1 , name = "base1"
302291 )
@@ -310,19 +299,18 @@ def test_measurable_join_multivariate(size1, supp_size1, size2, supp_size2, axis
310299 base1_vv = base1_rv .clone ()
311300 base2_vv = base2_rv .clone ()
312301 y_vv = y_rv .clone ()
302+
303+ y_logp = logp (y_rv , y_vv )
304+ assert_no_rvs (y_logp )
305+
313306 base_logps = [
314307 pt .atleast_1d (logp )
315308 for logp in conditional_logp ({base1_rv : base1_vv , base2_rv : base2_vv }).values ()
316309 ]
317-
318310 if concatenate :
319- axis_norm = np .core .numeric .normalize_axis_index (axis , base1_rv .ndim )
320- base_logps = pt .concatenate (base_logps , axis = axis_norm - 1 )
311+ expected_logp = pt .concatenate (base_logps , axis = logp_axis )
321312 else :
322- axis_norm = np .core .numeric .normalize_axis_index (axis , base1_rv .ndim + 1 )
323- base_logps = pt .stack (base_logps , axis = axis_norm - 1 )
324- y_logp = y_logp = logp (y_rv , y_vv )
325- assert_no_rvs (y_logp )
313+ expected_logp = pt .stack (base_logps , axis = logp_axis )
326314
327315 base1_testval = base1_rv .eval ()
328316 base2_testval = base2_rv .eval ()
@@ -331,7 +319,7 @@ def test_measurable_join_multivariate(size1, supp_size1, size2, supp_size2, axis
331319 else :
332320 y_testval = np .stack ((base1_testval , base2_testval ), axis = axis )
333321 np .testing .assert_allclose (
334- base_logps .eval ({base1_vv : base1_testval , base2_vv : base2_testval }),
322+ expected_logp .eval ({base1_vv : base1_testval , base2_vv : base2_testval }),
335323 y_logp .eval ({y_vv : y_testval }),
336324 )
337325
0 commit comments