@@ -5,22 +5,70 @@ extern crate num_traits;
55
66use ndarray:: * ;
77use ndarray_linalg:: * ;
8- use num_traits:: Zero ;
8+ use num_traits:: { One , Zero } ;
99
10- fn det_3x3 < A , S > ( a : ArrayBase < S , Ix2 > ) -> A
10+ /// Returns the matrix with the specified `row` and `col` removed.
11+ fn matrix_minor < A , S > ( a : ArrayBase < S , Ix2 > , ( row, col) : ( usize , usize ) ) -> Array2 < A >
1112where
1213 A : Scalar ,
1314 S : Data < Elem = A > ,
1415{
15- a[ ( 0 , 0 ) ] * a[ ( 1 , 1 ) ] * a[ ( 2 , 2 ) ] + a[ ( 0 , 1 ) ] * a[ ( 1 , 2 ) ] * a[ ( 2 , 0 ) ] + a[ ( 0 , 2 ) ] * a[ ( 1 , 0 ) ] * a[ ( 2 , 1 ) ] -
16- a[ ( 0 , 2 ) ] * a[ ( 1 , 1 ) ] * a[ ( 2 , 0 ) ] - a[ ( 0 , 1 ) ] * a[ ( 1 , 0 ) ] * a[ ( 2 , 2 ) ] - a[ ( 0 , 0 ) ] * a[ ( 1 , 2 ) ] * a[ ( 2 , 1 ) ]
16+ let mut select_rows = ( 0 ..a. rows ( ) ) . collect :: < Vec < _ > > ( ) ;
17+ select_rows. remove ( row) ;
18+ let mut select_cols = ( 0 ..a. cols ( ) ) . collect :: < Vec < _ > > ( ) ;
19+ select_cols. remove ( col) ;
20+ a. select ( Axis ( 0 ) , & select_rows) . select (
21+ Axis ( 1 ) ,
22+ & select_cols,
23+ )
24+ }
25+
26+ /// Computes the determinant of matrix `a`.
27+ ///
28+ /// Note: This implementation is written to be clearly correct so that it's
29+ /// useful for verification, but it's very inefficient.
30+ fn det_naive < A , S > ( a : ArrayBase < S , Ix2 > ) -> A
31+ where
32+ A : Scalar ,
33+ S : Data < Elem = A > ,
34+ {
35+ assert_eq ! ( a. rows( ) , a. cols( ) ) ;
36+ match a. cols ( ) {
37+ 0 => A :: one ( ) ,
38+ 1 => a[ ( 0 , 0 ) ] ,
39+ cols => {
40+ ( 0 ..cols)
41+ . map ( |col| {
42+ let sign = if col % 2 == 0 { A :: one ( ) } else { -A :: one ( ) } ;
43+ sign * a[ ( 0 , col) ] * det_naive ( matrix_minor ( a. view ( ) , ( 0 , col) ) )
44+ } )
45+ . fold ( A :: zero ( ) , |sum, subdet| sum + subdet)
46+ }
47+ }
48+ }
49+
50+ #[ test]
51+ fn det_empty ( ) {
52+ macro_rules! det_empty {
53+ ( $elem: ty) => {
54+ let a: Array2 <$elem> = Array2 :: zeros( ( 0 , 0 ) ) ;
55+ assert_eq!( a. factorize( ) . unwrap( ) . det( ) . unwrap( ) , One :: one( ) ) ;
56+ assert_eq!( a. factorize( ) . unwrap( ) . det_into( ) . unwrap( ) , One :: one( ) ) ;
57+ assert_eq!( a. det( ) . unwrap( ) , One :: one( ) ) ;
58+ assert_eq!( a. det_into( ) . unwrap( ) , One :: one( ) ) ;
59+ }
60+ }
61+ det_empty ! ( f64 ) ;
62+ det_empty ! ( f32 ) ;
63+ det_empty ! ( c64) ;
64+ det_empty ! ( c32) ;
1765}
1866
1967#[ test]
2068fn det_zero ( ) {
2169 macro_rules! det_zero {
2270 ( $elem: ty) => {
23- let a: Array2 <$elem> = array! [ [ Zero :: zero ( ) ] ] ;
71+ let a: Array2 <$elem> = Array2 :: zeros ( ( 1 , 1 ) ) ;
2472 assert_eq!( a. det( ) . unwrap( ) , Zero :: zero( ) ) ;
2573 assert_eq!( a. det_into( ) . unwrap( ) , Zero :: zero( ) ) ;
2674 }
@@ -54,18 +102,20 @@ fn det() {
54102 ( $elem: ty, $shape: expr, $rtol: expr) => {
55103 let a: Array2 <$elem> = random( $shape) ;
56104 println!( "a = \n {:?}" , a) ;
57- let det = det_3x3 ( a. view( ) ) ;
105+ let det = det_naive ( a. view( ) ) ;
58106 assert_rclose!( a. factorize( ) . unwrap( ) . det( ) . unwrap( ) , det, $rtol) ;
59107 assert_rclose!( a. factorize( ) . unwrap( ) . det_into( ) . unwrap( ) , det, $rtol) ;
60108 assert_rclose!( a. det( ) . unwrap( ) , det, $rtol) ;
61109 assert_rclose!( a. det_into( ) . unwrap( ) , det, $rtol) ;
62110 }
63111 }
64- for & shape in & [ ( 3 , 3 ) . into_shape ( ) , ( 3 , 3 ) . f ( ) ] {
65- det ! ( f64 , shape, 1e-9 ) ;
66- det ! ( f32 , shape, 1e-4 ) ;
67- det ! ( c64, shape, 1e-9 ) ;
68- det ! ( c32, shape, 1e-4 ) ;
112+ for rows in 1 ..5 {
113+ for & shape in & [ ( rows, rows) . into_shape ( ) , ( rows, rows) . f ( ) ] {
114+ det ! ( f64 , shape, 1e-9 ) ;
115+ det ! ( f32 , shape, 1e-4 ) ;
116+ det ! ( c64, shape, 1e-9 ) ;
117+ det ! ( c32, shape, 1e-4 ) ;
118+ }
69119 }
70120}
71121
@@ -80,10 +130,18 @@ fn det_nonsquare() {
80130 assert!( a. det_into( ) . is_err( ) ) ;
81131 }
82132 }
83- for & shape in & [ ( 1 , 2 ) . into_shape ( ) , ( 1 , 2 ) . f ( ) , ( 2 , 1 ) . into_shape ( ) , ( 2 , 1 ) . f ( ) ] {
84- det_nonsquare ! ( f64 , shape) ;
85- det_nonsquare ! ( f32 , shape) ;
86- det_nonsquare ! ( c64, shape) ;
87- det_nonsquare ! ( c32, shape) ;
133+ for & dims in & [ ( 1 , 0 ) , ( 1 , 2 ) , ( 2 , 1 ) , ( 2 , 3 ) ] {
134+ // Work around bug in ndarray: https://github.com/bluss/rust-ndarray/issues/361
135+ let shapes = if dims == ( 1 , 0 ) {
136+ vec ! [ dims. clone( ) . into_shape( ) ]
137+ } else {
138+ vec ! [ dims. clone( ) . into_shape( ) , dims. clone( ) . f( ) ]
139+ } ;
140+ for & shape in & shapes {
141+ det_nonsquare ! ( f64 , shape) ;
142+ det_nonsquare ! ( f32 , shape) ;
143+ det_nonsquare ! ( c64, shape) ;
144+ det_nonsquare ! ( c32, shape) ;
145+ }
88146 }
89147}
0 commit comments