1+ from functools import partial
2+
13import numpy as np
24import numpy .linalg
35import pytest
3436 lscalar ,
3537 matrix ,
3638 scalar ,
39+ tensor ,
3740 tensor3 ,
3841 tensor4 ,
3942 vector ,
@@ -150,29 +153,52 @@ def test_qr_modes():
150153
151154class TestSvd (utt .InferShapeTester ):
152155 op_class = SVD
153- dtype = "float32"
154156
155157 def setup_method (self ):
156158 super ().setup_method ()
157159 self .rng = np .random .default_rng (utt .fetch_seed ())
158- self .A = matrix (dtype = self . dtype )
160+ self .A = matrix (dtype = config . floatX )
159161 self .op = svd
160162
161- def test_svd (self ):
162- A = matrix ("A" , dtype = self .dtype )
163- U , S , VT = svd (A )
164- fn = function ([A ], [U , S , VT ])
165- a = self .rng .random ((4 , 4 )).astype (self .dtype )
166- n_u , n_s , n_vt = np .linalg .svd (a )
167- t_u , t_s , t_vt = fn (a )
163+ @pytest .mark .parametrize (
164+ "core_shape" , [(3 , 3 ), (4 , 3 ), (3 , 4 )], ids = ["square" , "tall" , "wide" ]
165+ )
166+ @pytest .mark .parametrize (
167+ "full_matrix" , [True , False ], ids = ["full=True" , "full=False" ]
168+ )
169+ @pytest .mark .parametrize (
170+ "compute_uv" , [True , False ], ids = ["compute_uv=True" , "compute_uv=False" ]
171+ )
172+ @pytest .mark .parametrize (
173+ "batched" , [True , False ], ids = ["batched=True" , "batched=False" ]
174+ )
175+ @pytest .mark .parametrize (
176+ "test_imag" , [True , False ], ids = ["test_imag=True" , "test_imag=False" ]
177+ )
178+ def test_svd (self , core_shape , full_matrix , compute_uv , batched , test_imag ):
179+ dtype = config .floatX
180+ if test_imag :
181+ dtype = "complex128" if dtype .endswith ("64" ) else "complex64"
182+ shape = core_shape if not batched else (10 , * core_shape )
183+ A = tensor ("A" , shape = shape , dtype = dtype )
184+ a = self .rng .random (shape ).astype (dtype )
185+
186+ outputs = svd (A , compute_uv = compute_uv , full_matrices = full_matrix )
187+ outputs = outputs if isinstance (outputs , list ) else [outputs ]
188+ fn = function (inputs = [A ], outputs = outputs )
189+
190+ np_fn = np .vectorize (
191+ partial (np .linalg .svd , compute_uv = compute_uv , full_matrices = full_matrix ),
192+ signature = outputs [0 ].owner .op .core_op .gufunc_signature ,
193+ )
194+
195+ np_outputs = np_fn (a )
196+ pt_outputs = fn (a )
168197
169- assert _allclose (n_u , t_u )
170- assert _allclose (n_s , t_s )
171- assert _allclose (n_vt , t_vt )
198+ np_outputs = np_outputs if isinstance (np_outputs , tuple ) else [np_outputs ]
172199
173- fn = function ([A ], svd (A , compute_uv = False ))
174- t_s = fn (a )
175- assert _allclose (n_s , t_s )
200+ for np_val , pt_val in zip (np_outputs , pt_outputs ):
201+ assert _allclose (np_val , pt_val )
176202
177203 def test_svd_infer_shape (self ):
178204 self .validate_shape ((4 , 4 ), full_matrices = True , compute_uv = True )
@@ -183,7 +209,7 @@ def test_svd_infer_shape(self):
183209
184210 def validate_shape (self , shape , compute_uv = True , full_matrices = True ):
185211 A = self .A
186- A_v = self .rng .random (shape ).astype (self . dtype )
212+ A_v = self .rng .random (shape ).astype (config . floatX )
187213 outputs = self .op (A , full_matrices = full_matrices , compute_uv = compute_uv )
188214 if not compute_uv :
189215 outputs = [outputs ]
@@ -451,8 +477,8 @@ def test_non_tensorial_input(self):
451477 norm (3 , None )
452478
453479 def test_tensor_input (self ):
454- with pytest . raises ( NotImplementedError ):
455- norm ( np . random . random ( (3 , 4 , 5 )), None )
480+ res = norm ( np . random . random (( 3 , 4 , 5 )), None )
481+ assert res . shape . eval () == (3 ,)
456482
457483 def test_numpy_compare (self ):
458484 rng = np .random .default_rng (utt .fetch_seed ())
0 commit comments