Skip to content

Commit 636ae56

Browse files
committed
Add .sign_ln_det*() methods to Determinant*
1 parent 70a1855 commit 636ae56

File tree

2 files changed

+121
-32
lines changed

2 files changed

+121
-32
lines changed

src/solve.rs

Lines changed: 73 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
//! ```
4848
4949
use ndarray::*;
50+
use num_traits::{Float, Zero};
5051

5152
use super::convert::*;
5253
use super::error::*;
@@ -336,16 +337,54 @@ where
336337
/// An interface for calculating determinants of matrix refs.
337338
pub trait Determinant<A: Scalar> {
338339
/// Computes the determinant of the matrix.
339-
fn det(&self) -> Result<A>;
340+
fn det(&self) -> Result<A> {
341+
let (sign, ln_det) = self.sln_det()?;
342+
Ok(sign.mul_real(ln_det.exp()))
343+
}
344+
345+
/// Computes the `(sign, natural_log)` of the determinant of the matrix.
346+
///
347+
/// For real matrices, `sign` is `1`, `0`, or `-1`. For complex matrices,
348+
/// `sign` is `0` or a complex number with absolute value 1. The
349+
/// `natural_log` is the natural logarithm of the absolute value of the
350+
/// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
351+
/// is negative infinity.
352+
///
353+
/// To obtain the determinant, you can compute `sign * natural_log.exp()`
354+
/// or just call `.det()` instead.
355+
///
356+
/// This method is more robust than `.det()` to very small or very large
357+
/// determinants since it returns the natural logarithm of the determinant
358+
/// rather than the determinant itself.
359+
fn sln_det(&self) -> Result<(A, A::Real)>;
340360
}
341361

342362
/// An interface for calculating determinants of matrices.
343-
pub trait DeterminantInto<A: Scalar> {
363+
pub trait DeterminantInto<A: Scalar>: Sized {
344364
/// Computes the determinant of the matrix.
345-
fn det_into(self) -> Result<A>;
365+
fn det_into(self) -> Result<A> {
366+
let (sign, ln_det) = self.sln_det_into()?;
367+
Ok(sign.mul_real(ln_det.exp()))
368+
}
369+
370+
/// Computes the `(sign, natural_log)` of the determinant of the matrix.
371+
///
372+
/// For real matrices, `sign` is `1`, `0`, or `-1`. For complex matrices,
373+
/// `sign` is `0` or a complex number with absolute value 1. The
374+
/// `natural_log` is the natural logarithm of the absolute value of the
375+
/// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
376+
/// is negative infinity.
377+
///
378+
/// To obtain the determinant, you can compute `sign * natural_log.exp()`
379+
/// or just call `.det_into()` instead.
380+
///
381+
/// This method is more robust than `.det()` to very small or very large
382+
/// determinants since it returns the natural logarithm of the determinant
383+
/// rather than the determinant itself.
384+
fn sln_det_into(self) -> Result<(A, A::Real)>;
346385
}
347386

348-
fn lu_det<'a, A, P, U>(ipiv_iter: P, u_diag_iter: U) -> A
387+
fn lu_sln_det<'a, A, P, U>(ipiv_iter: P, u_diag_iter: U) -> (A, A::Real)
349388
where
350389
A: Scalar,
351390
P: Iterator<Item = i32>,
@@ -360,24 +399,27 @@ where
360399
} else {
361400
-A::one()
362401
};
363-
let (upper_sign, ln_det) = u_diag_iter.fold((A::one(), A::zero()), |(upper_sign, ln_det), &elem| {
364-
let abs_elem = elem.abs();
365-
(
366-
upper_sign * elem.div_real(abs_elem),
367-
ln_det.add_real(abs_elem.ln()),
368-
)
369-
});
370-
pivot_sign * upper_sign * ln_det.exp()
402+
let (upper_sign, ln_det) = u_diag_iter.fold(
403+
(A::one(), A::Real::zero()),
404+
|(upper_sign, ln_det), &elem| {
405+
let abs_elem: A::Real = elem.abs();
406+
(upper_sign * elem.div_real(abs_elem), ln_det + abs_elem.ln())
407+
},
408+
);
409+
(pivot_sign * upper_sign, ln_det)
371410
}
372411

373412
impl<A, S> Determinant<A> for LUFactorized<S>
374413
where
375414
A: Scalar,
376415
S: Data<Elem = A>,
377416
{
378-
fn det(&self) -> Result<A> {
417+
fn sln_det(&self) -> Result<(A, A::Real)> {
379418
self.a.ensure_square()?;
380-
Ok(lu_det(self.ipiv.iter().cloned(), self.a.diag().iter()))
419+
Ok(lu_sln_det(
420+
self.ipiv.iter().cloned(),
421+
self.a.diag().iter(),
422+
))
381423
}
382424
}
383425

@@ -386,9 +428,12 @@ where
386428
A: Scalar,
387429
S: Data<Elem = A>,
388430
{
389-
fn det_into(self) -> Result<A> {
431+
fn sln_det_into(self) -> Result<(A, A::Real)> {
390432
self.a.ensure_square()?;
391-
Ok(lu_det(self.ipiv.into_iter(), self.a.into_diag().iter()))
433+
Ok(lu_sln_det(
434+
self.ipiv.into_iter(),
435+
self.a.into_diag().iter(),
436+
))
392437
}
393438
}
394439

@@ -397,11 +442,14 @@ where
397442
A: Scalar,
398443
S: Data<Elem = A>,
399444
{
400-
fn det(&self) -> Result<A> {
445+
fn sln_det(&self) -> Result<(A, A::Real)> {
401446
self.ensure_square()?;
402447
match self.factorize() {
403-
Ok(fac) => fac.det(),
404-
Err(LinalgError::Lapack(LapackError { return_code })) if return_code > 0 => Ok(A::zero()),
448+
Ok(fac) => fac.sln_det(),
449+
Err(LinalgError::Lapack(LapackError { return_code })) if return_code > 0 => {
450+
// The determinant is zero.
451+
Ok((A::zero(), A::Real::neg_infinity()))
452+
}
405453
Err(err) => Err(err),
406454
}
407455
}
@@ -412,11 +460,14 @@ where
412460
A: Scalar,
413461
S: DataMut<Elem = A>,
414462
{
415-
fn det_into(self) -> Result<A> {
463+
fn sln_det_into(self) -> Result<(A, A::Real)> {
416464
self.ensure_square()?;
417465
match self.factorize_into() {
418-
Ok(fac) => fac.det_into(),
419-
Err(LinalgError::Lapack(LapackError { return_code })) if return_code > 0 => Ok(A::zero()),
466+
Ok(fac) => fac.sln_det_into(),
467+
Err(LinalgError::Lapack(LapackError { return_code })) if return_code > 0 => {
468+
// The determinant is zero.
469+
Ok((A::zero(), A::Real::neg_infinity()))
470+
}
420471
Err(err) => Err(err),
421472
}
422473
}

tests/det.rs

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ extern crate num_traits;
55

66
use ndarray::*;
77
use ndarray_linalg::*;
8-
use num_traits::{One, Zero};
8+
use num_traits::{Float, One, Zero};
99

1010
/// Returns the matrix with the specified `row` and `col` removed.
1111
fn matrix_minor<A, S>(a: &ArrayBase<S, Ix2>, (row, col): (usize, usize)) -> Array2<A>
@@ -52,10 +52,16 @@ fn det_empty() {
5252
macro_rules! det_empty {
5353
($elem:ty) => {
5454
let a: Array2<$elem> = Array2::zeros((0, 0));
55-
assert_eq!(a.factorize().unwrap().det().unwrap(), One::one());
56-
assert_eq!(a.factorize().unwrap().det_into().unwrap(), One::one());
57-
assert_eq!(a.det().unwrap(), One::one());
58-
assert_eq!(a.det_into().unwrap(), One::one());
55+
let det = One::one();
56+
let (sign, ln_det) = (One::one(), Zero::zero());
57+
assert_eq!(a.factorize().unwrap().det().unwrap(), det);
58+
assert_eq!(a.factorize().unwrap().sln_det().unwrap(), (sign, ln_det));
59+
assert_eq!(a.factorize().unwrap().det_into().unwrap(), det);
60+
assert_eq!(a.factorize().unwrap().sln_det_into().unwrap(), (sign, ln_det));
61+
assert_eq!(a.det().unwrap(), det);
62+
assert_eq!(a.sln_det().unwrap(), (sign, ln_det));
63+
assert_eq!(a.clone().det_into().unwrap(), det);
64+
assert_eq!(a.sln_det_into().unwrap(), (sign, ln_det));
5965
}
6066
}
6167
det_empty!(f64);
@@ -69,8 +75,12 @@ fn det_zero() {
6975
macro_rules! det_zero {
7076
($elem:ty) => {
7177
let a: Array2<$elem> = Array2::zeros((1, 1));
72-
assert_eq!(a.det().unwrap(), Zero::zero());
73-
assert_eq!(a.det_into().unwrap(), Zero::zero());
78+
let det = Zero::zero();
79+
let (sign, ln_det) = (Zero::zero(), Float::neg_infinity());
80+
assert_eq!(a.det().unwrap(), det);
81+
assert_eq!(a.sln_det().unwrap(), (sign, ln_det));
82+
assert_eq!(a.clone().det_into().unwrap(), det);
83+
assert_eq!(a.sln_det_into().unwrap(), (sign, ln_det));
7484
}
7585
}
7686
det_zero!(f64);
@@ -85,7 +95,9 @@ fn det_zero_nonsquare() {
8595
($elem:ty, $shape:expr) => {
8696
let a: Array2<$elem> = Array2::zeros($shape);
8797
assert!(a.det().is_err());
88-
assert!(a.det_into().is_err());
98+
assert!(a.sln_det().is_err());
99+
assert!(a.clone().det_into().is_err());
100+
assert!(a.sln_det_into().is_err());
89101
}
90102
}
91103
for &shape in &[(1, 2).into_shape(), (1, 2).f()] {
@@ -103,10 +115,32 @@ fn det() {
103115
let a: Array2<$elem> = random($shape);
104116
println!("a = \n{:?}", a);
105117
let det = det_naive(&a);
118+
let sign = det.div_real(det.abs());
119+
let ln_det = det.abs().ln();
106120
assert_rclose!(a.factorize().unwrap().det().unwrap(), det, $rtol);
121+
{
122+
let result = a.factorize().unwrap().sln_det().unwrap();
123+
assert_rclose!(result.0, sign, $rtol);
124+
assert_rclose!(result.1, ln_det, $rtol);
125+
}
107126
assert_rclose!(a.factorize().unwrap().det_into().unwrap(), det, $rtol);
127+
{
128+
let result = a.factorize().unwrap().sln_det_into().unwrap();
129+
assert_rclose!(result.0, sign, $rtol);
130+
assert_rclose!(result.1, ln_det, $rtol);
131+
}
108132
assert_rclose!(a.det().unwrap(), det, $rtol);
109-
assert_rclose!(a.det_into().unwrap(), det, $rtol);
133+
{
134+
let result = a.sln_det().unwrap();
135+
assert_rclose!(result.0, sign, $rtol);
136+
assert_rclose!(result.1, ln_det, $rtol);
137+
}
138+
assert_rclose!(a.clone().det_into().unwrap(), det, $rtol);
139+
{
140+
let result = a.sln_det_into().unwrap();
141+
assert_rclose!(result.0, sign, $rtol);
142+
assert_rclose!(result.1, ln_det, $rtol);
143+
}
110144
}
111145
}
112146
for rows in 1..5 {
@@ -125,9 +159,13 @@ fn det_nonsquare() {
125159
($elem:ty, $shape:expr) => {
126160
let a: Array2<$elem> = random($shape);
127161
assert!(a.factorize().unwrap().det().is_err());
162+
assert!(a.factorize().unwrap().sln_det().is_err());
128163
assert!(a.factorize().unwrap().det_into().is_err());
164+
assert!(a.factorize().unwrap().sln_det_into().is_err());
129165
assert!(a.det().is_err());
130-
assert!(a.det_into().is_err());
166+
assert!(a.sln_det().is_err());
167+
assert!(a.clone().det_into().is_err());
168+
assert!(a.sln_det_into().is_err());
131169
}
132170
}
133171
for &dims in &[(1, 0), (1, 2), (2, 1), (2, 3)] {

0 commit comments

Comments
 (0)