@@ -131,14 +131,15 @@ where
131131 {
132132 let observation_axis = Axis ( 1 ) ;
133133 let n_observations = A :: from_usize ( self . len_of ( observation_axis) ) . unwrap ( ) ;
134- let dof =
135- if ddof >= n_observations {
136- panic ! ( "`ddof` needs to be strictly smaller than the \
137- number of observations provided for each \
138- random variable!")
139- } else {
140- n_observations - ddof
141- } ;
134+ let dof = if ddof >= n_observations {
135+ panic ! (
136+ "`ddof` needs to be strictly smaller than the \
137+ number of observations provided for each \
138+ random variable!"
139+ )
140+ } else {
141+ n_observations - ddof
142+ } ;
142143 let mean = self . mean_axis ( observation_axis) ;
143144 let denoised = self - & mean. insert_axis ( observation_axis) ;
144145 let covariance = denoised. dot ( & denoised. t ( ) ) ;
@@ -156,7 +157,9 @@ where
156157 // observation per random variable (or no observations at all)
157158 let ddof = -A :: one ( ) ;
158159 let cov = self . cov ( ddof) ;
159- let std = self . std_axis ( observation_axis, ddof) . insert_axis ( observation_axis) ;
160+ let std = self
161+ . std_axis ( observation_axis, ddof)
162+ . insert_axis ( observation_axis) ;
160163 let std_matrix = std. dot ( & std. t ( ) ) ;
161164 // element-wise division
162165 cov / std_matrix
@@ -167,10 +170,10 @@ where
167170mod cov_tests {
168171 use super :: * ;
169172 use ndarray:: array;
173+ use ndarray_rand:: RandomExt ;
170174 use quickcheck:: quickcheck;
171175 use rand;
172176 use rand:: distributions:: Uniform ;
173- use ndarray_rand:: RandomExt ;
174177
175178 quickcheck ! {
176179 fn constant_random_variables_have_zero_covariance_matrix( value: f64 ) -> bool {
@@ -200,10 +203,7 @@ mod cov_tests {
200203 fn test_invalid_ddof ( ) {
201204 let n_random_variables = 3 ;
202205 let n_observations = 4 ;
203- let a = Array :: random (
204- ( n_random_variables, n_observations) ,
205- Uniform :: new ( 0. , 10. )
206- ) ;
206+ let a = Array :: random ( ( n_random_variables, n_observations) , Uniform :: new ( 0. , 10. ) ) ;
207207 let invalid_ddof = ( n_observations as f64 ) + rand:: random :: < f64 > ( ) . abs ( ) ;
208208 a. cov ( invalid_ddof) ;
209209 }
@@ -235,55 +235,46 @@ mod cov_tests {
235235 #[ test]
236236 fn test_covariance_for_random_array ( ) {
237237 let a = array ! [
238- [ 0.72009497 , 0.12568055 , 0.55705966 , 0.5959984 , 0.69471457 ] ,
239- [ 0.56717131 , 0.47619486 , 0.21526298 , 0.88915366 , 0.91971245 ] ,
240- [ 0.59044195 , 0.10720363 , 0.76573717 , 0.54693675 , 0.95923036 ] ,
241- [ 0.24102952 , 0.131347 , 0.11118028 , 0.21451351 , 0.30515539 ] ,
242- [ 0.26952473 , 0.93079841 , 0.8080893 , 0.42814155 , 0.24642258 ]
238+ [ 0.72009497 , 0.12568055 , 0.55705966 , 0.5959984 , 0.69471457 ] ,
239+ [ 0.56717131 , 0.47619486 , 0.21526298 , 0.88915366 , 0.91971245 ] ,
240+ [ 0.59044195 , 0.10720363 , 0.76573717 , 0.54693675 , 0.95923036 ] ,
241+ [ 0.24102952 , 0.131347 , 0.11118028 , 0.21451351 , 0.30515539 ] ,
242+ [ 0.26952473 , 0.93079841 , 0.8080893 , 0.42814155 , 0.24642258 ]
243243 ] ;
244244 let numpy_covariance = array ! [
245- [ 0.05786248 , 0.02614063 , 0.06446215 , 0.01285105 , -0.06443992 ] ,
246- [ 0.02614063 , 0.08733569 , 0.02436933 , 0.01977437 , -0.06715555 ] ,
247- [ 0.06446215 , 0.02436933 , 0.10052129 , 0.01393589 , -0.06129912 ] ,
248- [ 0.01285105 , 0.01977437 , 0.01393589 , 0.00638795 , -0.02355557 ] ,
249- [ -0.06443992 , -0.06715555 , -0.06129912 , -0.02355557 , 0.09909855 ]
245+ [ 0.05786248 , 0.02614063 , 0.06446215 , 0.01285105 , -0.06443992 ] ,
246+ [ 0.02614063 , 0.08733569 , 0.02436933 , 0.01977437 , -0.06715555 ] ,
247+ [ 0.06446215 , 0.02436933 , 0.10052129 , 0.01393589 , -0.06129912 ] ,
248+ [ 0.01285105 , 0.01977437 , 0.01393589 , 0.00638795 , -0.02355557 ] ,
249+ [
250+ -0.06443992 ,
251+ -0.06715555 ,
252+ -0.06129912 ,
253+ -0.02355557 ,
254+ 0.09909855
255+ ]
250256 ] ;
251257 assert_eq ! ( a. ndim( ) , 2 ) ;
252- assert ! (
253- a. cov( 1. ) . all_close(
254- & numpy_covariance,
255- 1e-8
256- )
257- ) ;
258+ assert ! ( a. cov( 1. ) . all_close( & numpy_covariance, 1e-8 ) ) ;
258259 }
259260
260261 #[ test]
261262 #[ should_panic]
262263 // We lose precision, hence the failing assert
263264 fn test_covariance_for_badly_conditioned_array ( ) {
264- let a: Array2 < f64 > = array ! [
265- [ 1e12 + 1. , 1e12 - 1. ] ,
266- [ 1e-6 + 1e-12 , 1e-6 - 1e-12 ] ,
267- ] ;
268- let expected_covariance = array ! [
269- [ 2. , 2e-12 ] , [ 2e-12 , 2e-24 ]
270- ] ;
271- assert ! (
272- a. cov( 1. ) . all_close(
273- & expected_covariance,
274- 1e-24
275- )
276- ) ;
265+ let a: Array2 < f64 > = array ! [ [ 1e12 + 1. , 1e12 - 1. ] , [ 1e-6 + 1e-12 , 1e-6 - 1e-12 ] , ] ;
266+ let expected_covariance = array ! [ [ 2. , 2e-12 ] , [ 2e-12 , 2e-24 ] ] ;
267+ assert ! ( a. cov( 1. ) . all_close( & expected_covariance, 1e-24 ) ) ;
277268 }
278269}
279270
280271#[ cfg( test) ]
281272mod pearson_correlation_tests {
282273 use super :: * ;
283274 use ndarray:: array;
275+ use ndarray_rand:: RandomExt ;
284276 use quickcheck:: quickcheck;
285277 use rand:: distributions:: Uniform ;
286- use ndarray_rand:: RandomExt ;
287278
288279 quickcheck ! {
289280 fn output_matrix_is_symmetric( bound: f64 ) -> bool {
@@ -337,19 +328,14 @@ mod pearson_correlation_tests {
337328 [ 0.26979716 , 0.20887228 , 0.95454999 , 0.96290785 ]
338329 ] ;
339330 let numpy_corrcoeff = array ! [
340- [ 1. , 0.38089376 , 0.08122504 , -0.59931623 , 0.1365648 ] ,
341- [ 0.38089376 , 1. , 0.80918429 , -0.52615195 , 0.38954398 ] ,
342- [ 0.08122504 , 0.80918429 , 1. , 0.07134906 , -0.17324776 ] ,
343- [ -0.59931623 , -0.52615195 , 0.07134906 , 1. , -0.8743213 ] ,
344- [ 0.1365648 , 0.38954398 , -0.17324776 , -0.8743213 , 1. ]
331+ [ 1. , 0.38089376 , 0.08122504 , -0.59931623 , 0.1365648 ] ,
332+ [ 0.38089376 , 1. , 0.80918429 , -0.52615195 , 0.38954398 ] ,
333+ [ 0.08122504 , 0.80918429 , 1. , 0.07134906 , -0.17324776 ] ,
334+ [ -0.59931623 , -0.52615195 , 0.07134906 , 1. , -0.8743213 ] ,
335+ [ 0.1365648 , 0.38954398 , -0.17324776 , -0.8743213 , 1. ]
345336 ] ;
346337 assert_eq ! ( a. ndim( ) , 2 ) ;
347- assert ! (
348- a. pearson_correlation( ) . all_close(
349- & numpy_corrcoeff,
350- 1e-7
351- )
352- ) ;
338+ assert ! ( a. pearson_correlation( ) . all_close( & numpy_corrcoeff, 1e-7 ) ) ;
353339 }
354340
355341}
0 commit comments