Skip to content

Commit 8197d0a

Browse files
committed
Simplify SVD
1 parent 018803e commit 8197d0a

File tree

1 file changed

+47
-36
lines changed

1 file changed

+47
-36
lines changed

src/svd.rs

Lines changed: 47 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4,58 +4,69 @@ use ndarray::*;
44

55
use super::convert::*;
66
use super::error::*;
7-
use super::lapack_traits::LapackScalar;
87
use super::layout::*;
8+
use super::types::*;
99

10-
pub trait SVD<U, S, VT> {
11-
fn svd(self, calc_u: bool, calc_vt: bool) -> Result<(Option<U>, S, Option<VT>)>;
10+
pub trait SVD {
11+
type U;
12+
type VT;
13+
type Sigma;
14+
fn svd(&self, calc_u: bool, calc_vt: bool) -> Result<(Option<Self::U>, Self::Sigma, Option<Self::VT>)>;
1215
}
1316

14-
impl<A, S, Su, Svt, Ss> SVD<ArrayBase<Su, Ix2>, ArrayBase<Ss, Ix1>, ArrayBase<Svt, Ix2>> for ArrayBase<S, Ix2>
15-
where A: LapackScalar,
16-
S: DataMut<Elem = A>,
17-
Su: DataOwned<Elem = A>,
18-
Svt: DataOwned<Elem = A>,
19-
Ss: DataOwned<Elem = A::Real>
17+
pub trait SVDInto {
18+
type U;
19+
type VT;
20+
type Sigma;
21+
fn svd_into(self, calc_u: bool, calc_vt: bool) -> Result<(Option<Self::U>, Self::Sigma, Option<Self::VT>)>;
22+
}
23+
24+
pub trait SVDMut {
25+
type U;
26+
type VT;
27+
type Sigma;
28+
fn svd_mut(&mut self, calc_u: bool, calc_vt: bool) -> Result<(Option<Self::U>, Self::Sigma, Option<Self::VT>)>;
29+
}
30+
31+
impl<A, S> SVDInto for ArrayBase<S, Ix2>
32+
where
33+
A: Scalar,
34+
S: DataMut<Elem = A>,
2035
{
21-
fn svd(mut self,
22-
calc_u: bool,
23-
calc_vt: bool)
24-
-> Result<(Option<ArrayBase<Su, Ix2>>, ArrayBase<Ss, Ix1>, Option<ArrayBase<Svt, Ix2>>)> {
25-
(&mut self).svd(calc_u, calc_vt)
36+
type U = Array2<A>;
37+
type VT = Array2<A>;
38+
type Sigma = Array1<A::Real>;
39+
40+
fn svd_into(mut self, calc_u: bool, calc_vt: bool) -> Result<(Option<Self::U>, Self::Sigma, Option<Self::VT>)> {
41+
self.svd_mut(calc_u, calc_vt)
2642
}
2743
}
2844

29-
impl<'a, A, S, Su, Svt, Ss> SVD<ArrayBase<Su, Ix2>, ArrayBase<Ss, Ix1>, ArrayBase<Svt, Ix2>> for &'a ArrayBase<S, Ix2>
30-
where A: LapackScalar + Clone,
31-
S: Data<Elem = A>,
32-
Su: DataOwned<Elem = A>,
33-
Svt: DataOwned<Elem = A>,
34-
Ss: DataOwned<Elem = A::Real>
45+
impl<A, S> SVD for ArrayBase<S, Ix2>
46+
where
47+
A: Scalar,
48+
S: Data<Elem = A>,
3549
{
36-
fn svd(self,
37-
calc_u: bool,
38-
calc_vt: bool)
39-
-> Result<(Option<ArrayBase<Su, Ix2>>, ArrayBase<Ss, Ix1>, Option<ArrayBase<Svt, Ix2>>)> {
50+
type U = Array2<A>;
51+
type VT = Array2<A>;
52+
type Sigma = Array1<A::Real>;
53+
54+
fn svd(&self, calc_u: bool, calc_vt: bool) -> Result<(Option<Self::U>, Self::Sigma, Option<Self::VT>)> {
4055
let a = self.to_owned();
41-
a.svd(calc_u, calc_vt)
56+
a.svd_into(calc_u, calc_vt)
4257
}
4358
}
4459

45-
impl<'a, A, S, Su, Svt, Ss> SVD<ArrayBase<Su, Ix2>, ArrayBase<Ss, Ix1>, ArrayBase<Svt, Ix2>>
46-
for &'a mut ArrayBase<S, Ix2>
60+
impl<A, S> SVDMut for ArrayBase<S, Ix2>
4761
where
48-
A: LapackScalar,
62+
A: Scalar,
4963
S: DataMut<Elem = A>,
50-
Su: DataOwned<Elem = A>,
51-
Svt: DataOwned<Elem = A>,
52-
Ss: DataOwned<Elem = A::Real>,
5364
{
54-
fn svd(
55-
mut self,
56-
calc_u: bool,
57-
calc_vt: bool,
58-
) -> Result<(Option<ArrayBase<Su, Ix2>>, ArrayBase<Ss, Ix1>, Option<ArrayBase<Svt, Ix2>>)> {
65+
type U = Array2<A>;
66+
type VT = Array2<A>;
67+
type Sigma = Array1<A::Real>;
68+
69+
fn svd_mut(&mut self, calc_u: bool, calc_vt: bool) -> Result<(Option<Self::U>, Self::Sigma, Option<Self::VT>)> {
5970
let l = self.layout()?;
6071
let svd_res = A::svd(l, calc_u, calc_vt, self.as_allocated_mut()?)?;
6172
let (n, m) = l.size();

0 commit comments

Comments
 (0)