Skip to content

Commit 6c40611

Browse files
committed
Introduce AppendResult and move tolerance into orthogonalizer
1 parent 8028feb commit 6c40611

File tree

3 files changed

+84
-34
lines changed

3 files changed

+84
-34
lines changed

src/krylov/householder.rs

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,19 @@ pub struct Householder<A: Scalar> {
5050
///
5151
/// The coefficient is copied into another array, and this does not contain
5252
v: Vec<Array1<A>>,
53+
54+
/// Tolerance
55+
tol: A::Real,
5356
}
5457

5558
impl<A: Scalar + Lapack> Householder<A> {
5659
/// Create a new orthogonalizer
57-
pub fn new(dim: usize) -> Self {
58-
Householder { dim, v: Vec::new() }
60+
pub fn new(dim: usize, tol: A::Real) -> Self {
61+
Householder {
62+
dim,
63+
v: Vec::new(),
64+
tol,
65+
}
5966
}
6067

6168
/// Take a Reflection `P = I - 2ww^T`
@@ -127,6 +134,10 @@ impl<A: Scalar + Lapack> Orthogonalizer for Householder<A> {
127134
self.v.len()
128135
}
129136

137+
fn tolerance(&self) -> A::Real {
138+
self.tol
139+
}
140+
130141
fn decompose<S>(&self, a: &mut ArrayBase<S, Ix1>) -> Array1<A>
131142
where
132143
S: DataMut<Elem = A>,
@@ -146,7 +157,7 @@ impl<A: Scalar + Lapack> Orthogonalizer for Householder<A> {
146157
self.compose_coefficients(&a)
147158
}
148159

149-
fn append<S>(&mut self, mut a: ArrayBase<S, Ix1>, rtol: A::Real) -> Result<Array1<A>, Array1<A>>
160+
fn append<S>(&mut self, mut a: ArrayBase<S, Ix1>) -> AppendResult<A>
150161
where
151162
S: DataMut<Elem = A>,
152163
{
@@ -159,18 +170,18 @@ impl<A: Scalar + Lapack> Orthogonalizer for Householder<A> {
159170
coef[i] = a[i];
160171
}
161172
if self.is_full() {
162-
return Err(coef); // coef[k] must be zero in this case
173+
return AppendResult::Dependent(coef); // coef[k] must be zero in this case
163174
}
164175

165176
let alpha = calc_reflector(&mut a.slice_mut(s![k..]));
166177
coef[k] = alpha;
167178

168-
if alpha.abs() < rtol {
179+
if alpha.abs() < self.tol {
169180
// linearly dependent
170-
return Err(coef);
181+
return AppendResult::Dependent(coef);
171182
}
172183
self.v.push(a.into_owned());
173-
Ok(coef)
184+
AppendResult::Added(coef)
174185
}
175186

176187
fn get_q(&self) -> Q<A> {
@@ -195,8 +206,8 @@ where
195206
A: Scalar + Lapack,
196207
S: Data<Elem = A>,
197208
{
198-
let h = Householder::new(dim);
199-
qr(iter, h, rtol, strategy)
209+
let h = Householder::new(dim, rtol);
210+
qr(iter, h, strategy)
200211
}
201212

202213
#[cfg(test)]

src/krylov/mgs.rs

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,24 @@ use crate::{generate::*, inner::*, norm::Norm};
55

66
/// Iterative orthogonalizer using modified Gram-Schmit procedure
77
#[derive(Debug, Clone)]
8-
pub struct MGS<A> {
8+
pub struct MGS<A: Scalar> {
99
/// Dimension of base space
10-
dimension: usize,
10+
dim: usize,
11+
1112
/// Basis of spanned space
1213
q: Vec<Array1<A>>,
14+
15+
/// Tolerance
16+
tol: A::Real,
1317
}
1418

1519
impl<A: Scalar + Lapack> MGS<A> {
1620
/// Create an empty orthogonalizer
17-
pub fn new(dimension: usize) -> Self {
21+
pub fn new(dim: usize, tol: A::Real) -> Self {
1822
Self {
19-
dimension,
23+
dim,
2024
q: Vec::new(),
25+
tol,
2126
}
2227
}
2328
}
@@ -26,13 +31,17 @@ impl<A: Scalar + Lapack> Orthogonalizer for MGS<A> {
2631
type Elem = A;
2732

2833
fn dim(&self) -> usize {
29-
self.dimension
34+
self.dim
3035
}
3136

3237
fn len(&self) -> usize {
3338
self.q.len()
3439
}
3540

41+
fn tolerance(&self) -> A::Real {
42+
self.tol
43+
}
44+
3645
fn decompose<S>(&self, a: &mut ArrayBase<S, Ix1>) -> Array1<A>
3746
where
3847
S: DataMut<Elem = A>,
@@ -59,21 +68,21 @@ impl<A: Scalar + Lapack> Orthogonalizer for MGS<A> {
5968
self.decompose(&mut a)
6069
}
6170

62-
fn append<S>(&mut self, a: ArrayBase<S, Ix1>, rtol: A::Real) -> Result<Array1<A>, Array1<A>>
71+
fn append<S>(&mut self, a: ArrayBase<S, Ix1>) -> AppendResult<A>
6372
where
6473
A: Lapack,
6574
S: Data<Elem = A>,
6675
{
6776
let mut a = a.into_owned();
6877
let coef = self.decompose(&mut a);
6978
let nrm = coef[coef.len() - 1].re();
70-
if nrm < rtol {
79+
if nrm < self.tol {
7180
// Linearly dependent
72-
return Err(coef);
81+
return AppendResult::Dependent(coef);
7382
}
7483
azip!(mut a in { *a = *a / A::from_real(nrm) });
7584
self.q.push(a);
76-
Ok(coef)
85+
AppendResult::Added(coef)
7786
}
7887

7988
fn get_q(&self) -> Q<A> {
@@ -92,6 +101,6 @@ where
92101
A: Scalar + Lapack,
93102
S: Data<Elem = A>,
94103
{
95-
let mgs = MGS::new(dim);
96-
qr(iter, mgs, rtol, strategy)
104+
let mgs = MGS::new(dim, rtol);
105+
qr(iter, mgs, strategy)
97106
}

src/krylov/mod.rs

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,18 @@ pub type Coefficients<A> = Array1<A>;
4343
/// ```rust
4444
/// # use ndarray::*;
4545
/// # use ndarray_linalg::{krylov::*, *};
46-
/// let mut mgs = MGS::new(3);
47-
/// let coef = mgs.append(array![0.0, 1.0, 0.0], 1e-9).unwrap();
46+
/// let mut mgs = MGS::new(3, 1e-9);
47+
/// let coef = mgs.append(array![0.0, 1.0, 0.0]).into_coeff();
4848
/// close_l2(&coef, &array![1.0], 1e-9);
4949
///
50-
/// let coef = mgs.append(array![1.0, 1.0, 0.0], 1e-9).unwrap();
50+
/// let coef = mgs.append(array![1.0, 1.0, 0.0]).into_coeff();
5151
/// close_l2(&coef, &array![1.0, 1.0], 1e-9);
5252
///
5353
/// // Fail if the vector is linearly dependent
54-
/// assert!(mgs.append(array![1.0, 2.0, 0.0], 1e-9).is_err());
54+
/// assert!(mgs.append(array![1.0, 2.0, 0.0]).is_dependent());
5555
///
5656
/// // You can get coefficients of dependent vector
57-
/// if let Err(coef) = mgs.append(array![1.0, 2.0, 0.0], 1e-9) {
57+
/// if let AppendResult::Dependent(coef) = mgs.append(array![1.0, 2.0, 0.0]) {
5858
/// close_l2(&coef, &array![2.0, 1.0, 0.0], 1e-9);
5959
/// }
6060
/// ```
@@ -76,6 +76,8 @@ pub trait Orthogonalizer {
7676
self.len() == 0
7777
}
7878

79+
fn tolerance(&self) -> <Self::Elem as Scalar>::Real;
80+
7981
/// Decompose given vector into the span of current basis and
8082
/// its tangent space
8183
///
@@ -96,18 +98,47 @@ pub trait Orthogonalizer {
9698
S: Data<Elem = Self::Elem>;
9799

98100
/// Add new vector if the residual is larger than relative tolerance
99-
fn append<S>(
100-
&mut self,
101-
a: ArrayBase<S, Ix1>,
102-
rtol: <Self::Elem as Scalar>::Real,
103-
) -> Result<Coefficients<Self::Elem>, Coefficients<Self::Elem>>
101+
fn append<S>(&mut self, a: ArrayBase<S, Ix1>) -> AppendResult<Self::Elem>
104102
where
105103
S: DataMut<Elem = Self::Elem>;
106104

107105
/// Get Q-matrix of generated basis
108106
fn get_q(&self) -> Q<Self::Elem>;
109107
}
110108

109+
pub enum AppendResult<A> {
110+
Added(Coefficients<A>),
111+
Dependent(Coefficients<A>),
112+
}
113+
114+
impl<A: Scalar> AppendResult<A> {
115+
pub fn into_coeff(self) -> Coefficients<A> {
116+
match self {
117+
AppendResult::Added(c) => c,
118+
AppendResult::Dependent(c) => c,
119+
}
120+
}
121+
122+
pub fn is_dependent(&self) -> bool {
123+
match self {
124+
AppendResult::Added(_) => false,
125+
AppendResult::Dependent(_) => true,
126+
}
127+
}
128+
129+
pub fn coeff(&self) -> &Coefficients<A> {
130+
match self {
131+
AppendResult::Added(c) => c,
132+
AppendResult::Dependent(c) => c,
133+
}
134+
}
135+
136+
pub fn residual_norm(&self) -> A::Real {
137+
let c = self.coeff();
138+
c[c.len() - 1].abs()
139+
}
140+
}
141+
111142
/// Strategy for linearly dependent vectors appearing in iterative QR decomposition
112143
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
113144
pub enum Strategy {
@@ -133,7 +164,6 @@ pub enum Strategy {
133164
pub fn qr<A, S>(
134165
iter: impl Iterator<Item = ArrayBase<S, Ix1>>,
135166
mut ortho: impl Orthogonalizer<Elem = A>,
136-
rtol: A::Real,
137167
strategy: Strategy,
138168
) -> (Q<A>, R<A>)
139169
where
@@ -144,9 +174,9 @@ where
144174

145175
let mut coefs = Vec::new();
146176
for a in iter {
147-
match ortho.append(a.into_owned(), rtol) {
148-
Ok(coef) => coefs.push(coef),
149-
Err(coef) => match strategy {
177+
match ortho.append(a.into_owned()) {
178+
AppendResult::Added(coef) => coefs.push(coef),
179+
AppendResult::Dependent(coef) => match strategy {
150180
Strategy::Terminate => break,
151181
Strategy::Skip => continue,
152182
Strategy::Full => coefs.push(coef),

0 commit comments

Comments
 (0)