@@ -36,7 +36,156 @@ def test_all_basic_rvs_are_wrapped():
3636
3737
3838def test_normal ():
39- pass
39+ c_size = tensor ("c_size" , shape = (), dtype = int )
40+ c_size_xr = as_xtensor (c_size , name = "c_size_xr" )
41+
42+ # Vector inputs
43+ mu_vec_in = tensor ("mu" , shape = (2 ,))
44+ sigma_vec_in = tensor ("sigma" , shape = (2 ,))
45+
46+ mu_vec_xr = as_xtensor (mu_vec_in , dims = ("a" ,), name = "mu_xr" )
47+ sigma_vec_xr = as_xtensor (sigma_vec_in , dims = ("a" ,), name = "sigma_xr" )
48+
49+ # Vector inputs: Basic case (no extra_dims)
50+ out_vec = pxr .normal (mu_vec_xr , sigma_vec_xr , rng = rng )
51+ assert out_vec .type .dims == ("a" ,)
52+ assert out_vec .type .shape == (2 ,)
53+ assert equal_computations (
54+ [lower_rewrite (out_vec .values )],
55+ [rewrite_graph (ptr .normal (mu_vec_in , sigma_vec_in , rng = rng ))],
56+ )
57+
58+ mu_val = np .array ([0.0 , 10.0 ])
59+ sigma_val = np .array ([1.0 , 2.0 ])
60+ eval_rng_seed_vec_basic = 12345
61+
62+ actual_val_vec_basic = out_vec .eval (
63+ {
64+ mu_vec_in : mu_val ,
65+ sigma_vec_in : sigma_val ,
66+ rng : np .random .default_rng (eval_rng_seed_vec_basic ),
67+ }
68+ )
69+ expected_val_vec_basic = np .random .default_rng (eval_rng_seed_vec_basic ).normal (
70+ mu_val , sigma_val
71+ )
72+ np .testing .assert_allclose (actual_val_vec_basic , expected_val_vec_basic )
73+
74+ # Vector inputs: With extra_dims
75+ out_vec_extra = pxr .normal (
76+ mu_vec_xr , sigma_vec_xr , extra_dims = dict (c = c_size_xr ), rng = rng
77+ )
78+ assert out_vec_extra .type .dims == ("c" , "a" )
79+ assert equal_computations (
80+ [lower_rewrite (out_vec_extra .values )],
81+ [
82+ rewrite_graph (
83+ ptr .normal (
84+ mu_vec_in , sigma_vec_in , size = (c_size , mu_vec_in .shape [0 ]), rng = rng
85+ )
86+ )
87+ ],
88+ )
89+
90+ c_size_val = 5
91+ eval_rng_seed_vec_extra = 67890
92+ actual_val_vec_extra = out_vec_extra .eval (
93+ {
94+ mu_vec_in : mu_val ,
95+ sigma_vec_in : sigma_val ,
96+ c_size : c_size_val ,
97+ rng : np .random .default_rng (eval_rng_seed_vec_extra ),
98+ }
99+ )
100+ expected_val_vec_extra = np .random .default_rng (eval_rng_seed_vec_extra ).normal (
101+ loc = mu_val , scale = sigma_val , size = (c_size_val , mu_val .shape [0 ])
102+ )
103+ np .testing .assert_allclose (actual_val_vec_extra , expected_val_vec_extra )
104+
105+ # Scalar inputs
106+ mu_scalar_in = tensor ("mu_s" , shape = ())
107+ sigma_scalar_in = tensor ("sigma_s" , shape = ())
108+
109+ mu_scalar_xr = as_xtensor (mu_scalar_in , name = "mu_s_xr" )
110+ sigma_scalar_xr = as_xtensor (sigma_scalar_in , name = "sigma_s_xr" )
111+
112+ # Scalar inputs: Basic case
113+ out_scalar = pxr .normal (mu_scalar_xr , sigma_scalar_xr , rng = rng )
114+ assert out_scalar .type .dims == ()
115+ assert out_scalar .type .shape == ()
116+ assert equal_computations (
117+ [lower_rewrite (out_scalar .values )],
118+ [rewrite_graph (ptr .normal (mu_scalar_in , sigma_scalar_in , rng = rng ))],
119+ )
120+
121+ mu_s_val = 0.0
122+ sigma_s_val = 1.0
123+ eval_rng_seed_scalar_basic = 23456
124+ actual_val_scalar_basic = out_scalar .eval (
125+ {
126+ mu_scalar_in : mu_s_val ,
127+ sigma_scalar_in : sigma_s_val ,
128+ rng : np .random .default_rng (eval_rng_seed_scalar_basic ),
129+ }
130+ )
131+ expected_val_scalar_basic = np .random .default_rng (
132+ eval_rng_seed_scalar_basic
133+ ).normal (mu_s_val , sigma_s_val )
134+ np .testing .assert_allclose (actual_val_scalar_basic , expected_val_scalar_basic )
135+
136+ # Scalar inputs: With extra_dims
137+ out_scalar_extra = pxr .normal (
138+ mu_scalar_xr , sigma_scalar_xr , extra_dims = dict (c = c_size_xr ), rng = rng
139+ )
140+ assert out_scalar_extra .type .dims == ("c" ,)
141+ assert equal_computations (
142+ [lower_rewrite (out_scalar_extra .values )],
143+ [
144+ rewrite_graph (
145+ ptr .normal (mu_scalar_in , sigma_scalar_in , size = (c_size ,), rng = rng )
146+ )
147+ ],
148+ )
149+
150+ eval_rng_seed_scalar_extra = 78901
151+ actual_val_scalar_extra = out_scalar_extra .eval (
152+ {
153+ mu_scalar_in : mu_s_val ,
154+ sigma_scalar_in : sigma_s_val ,
155+ c_size : c_size_val ,
156+ rng : np .random .default_rng (eval_rng_seed_scalar_extra ),
157+ }
158+ )
159+ expected_val_scalar_extra = np .random .default_rng (
160+ eval_rng_seed_scalar_extra
161+ ).normal (loc = mu_s_val , scale = sigma_s_val , size = (c_size_val ,))
162+ np .testing .assert_allclose (actual_val_scalar_extra , expected_val_scalar_extra )
163+
164+ # Error conditions
165+ # Invalid core_dims: normal is element-wise, expects core_dims=() for params.
166+ with pytest .raises (
167+ ValueError ,
168+ match = re .escape (
169+ "Parameter mu_xr has invalid core dimensions ['a']. "
170+ "Expected [] based on RV definition and core_dims argument."
171+ ),
172+ ):
173+ pxr .normal (mu_vec_xr , sigma_vec_xr , core_dims = ("a" ,), rng = rng )
174+
175+ # Invalid extra_dims (conflicting with existing batch dims)
176+ a_size_xr = mu_vec_xr .sizes ["a" ]
177+ with pytest .raises (
178+ ValueError ,
179+ match = re .escape (
180+ "Size dimensions ['a'] conflict with parameter dimensions. They should be unique."
181+ ),
182+ ):
183+ pxr .normal (
184+ mu_vec_xr ,
185+ sigma_vec_xr ,
186+ extra_dims = dict (c = c_size_xr , a = a_size_xr ), # 'a' conflicts
187+ rng = rng ,
188+ )
40189
41190
42191def test_categorical ():
0 commit comments