22//! for tridiagonal matrix
33
44use super :: * ;
5- use crate :: { error:: * , layout:: MatrixLayout } ;
5+ use crate :: { error:: * , layout:: * } ;
66use cauchy:: * ;
77use num_traits:: Zero ;
88use std:: ops:: { Index , IndexMut } ;
@@ -130,11 +130,11 @@ impl<A: Scalar> IndexMut<[i32; 2]> for Tridiagonal<A> {
130130pub trait Tridiagonal_ : Scalar + Sized {
131131 /// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using
132132 /// partial pivoting with row interchanges.
133- unsafe fn lu_tridiagonal ( a : Tridiagonal < Self > ) -> Result < LUFactorizedTridiagonal < Self > > ;
133+ fn lu_tridiagonal ( a : Tridiagonal < Self > ) -> Result < LUFactorizedTridiagonal < Self > > ;
134134
135- unsafe fn rcond_tridiagonal ( lu : & LUFactorizedTridiagonal < Self > ) -> Result < Self :: Real > ;
135+ fn rcond_tridiagonal ( lu : & LUFactorizedTridiagonal < Self > ) -> Result < Self :: Real > ;
136136
137- unsafe fn solve_tridiagonal (
137+ fn solve_tridiagonal (
138138 lu : & LUFactorizedTridiagonal < Self > ,
139139 bl : MatrixLayout ,
140140 t : Transpose ,
@@ -143,18 +143,23 @@ pub trait Tridiagonal_: Scalar + Sized {
143143}
144144
145145macro_rules! impl_tridiagonal {
146- ( $scalar: ty, $gttrf: path, $gtcon: path, $gttrs: path) => {
146+ ( @real, $scalar: ty, $gttrf: path, $gtcon: path, $gttrs: path) => {
147+ impl_tridiagonal!( @body, $scalar, $gttrf, $gtcon, $gttrs, iwork) ;
148+ } ;
149+ ( @complex, $scalar: ty, $gttrf: path, $gtcon: path, $gttrs: path) => {
150+ impl_tridiagonal!( @body, $scalar, $gttrf, $gtcon, $gttrs, ) ;
151+ } ;
152+ ( @body, $scalar: ty, $gttrf: path, $gtcon: path, $gttrs: path, $( $iwork: ident) * ) => {
147153 impl Tridiagonal_ for $scalar {
148- unsafe fn lu_tridiagonal(
149- mut a: Tridiagonal <Self >,
150- ) -> Result <LUFactorizedTridiagonal <Self >> {
154+ fn lu_tridiagonal( mut a: Tridiagonal <Self >) -> Result <LUFactorizedTridiagonal <Self >> {
151155 let ( n, _) = a. l. size( ) ;
152156 let mut du2 = vec![ Zero :: zero( ) ; ( n - 2 ) as usize ] ;
153157 let mut ipiv = vec![ 0 ; n as usize ] ;
154158 // We have to calc one-norm before LU factorization
155159 let a_opnorm_one = a. opnorm_one( ) ;
156- $gttrf( n, & mut a. dl, & mut a. d, & mut a. du, & mut du2, & mut ipiv)
157- . as_lapack_result( ) ?;
160+ let mut info = 0 ;
161+ unsafe { $gttrf( n, & mut a. dl, & mut a. d, & mut a. du, & mut du2, & mut ipiv, & mut info, ) } ;
162+ info. as_lapack_result( ) ?;
158163 Ok ( LUFactorizedTridiagonal {
159164 a,
160165 du2,
@@ -163,56 +168,80 @@ macro_rules! impl_tridiagonal {
163168 } )
164169 }
165170
166- unsafe fn rcond_tridiagonal( lu: & LUFactorizedTridiagonal <Self >) -> Result <Self :: Real > {
171+ fn rcond_tridiagonal( lu: & LUFactorizedTridiagonal <Self >) -> Result <Self :: Real > {
167172 let ( n, _) = lu. a. l. size( ) ;
168173 let ipiv = & lu. ipiv;
174+ let mut work = vec![ Self :: zero( ) ; 2 * n as usize ] ;
175+ $(
176+ let mut $iwork = vec![ 0 ; n as usize ] ;
177+ ) *
169178 let mut rcond = Self :: Real :: zero( ) ;
170- $gtcon(
171- NormType :: One as u8 ,
172- n,
173- & lu. a. dl,
174- & lu. a. d,
175- & lu. a. du,
176- & lu. du2,
177- ipiv,
178- lu. a_opnorm_one,
179- & mut rcond,
180- )
181- . as_lapack_result( ) ?;
179+ let mut info = 0 ;
180+ unsafe {
181+ $gtcon(
182+ NormType :: One as u8 ,
183+ n,
184+ & lu. a. dl,
185+ & lu. a. d,
186+ & lu. a. du,
187+ & lu. du2,
188+ ipiv,
189+ lu. a_opnorm_one,
190+ & mut rcond,
191+ & mut work,
192+ $( & mut $iwork, ) *
193+ & mut info,
194+ ) ;
195+ }
196+ info. as_lapack_result( ) ?;
182197 Ok ( rcond)
183198 }
184199
185- unsafe fn solve_tridiagonal(
200+ fn solve_tridiagonal(
186201 lu: & LUFactorizedTridiagonal <Self >,
187- bl : MatrixLayout ,
202+ b_layout : MatrixLayout ,
188203 t: Transpose ,
189204 b: & mut [ Self ] ,
190205 ) -> Result <( ) > {
191206 let ( n, _) = lu. a. l. size( ) ;
192- let ( _, nrhs) = bl. size( ) ;
193207 let ipiv = & lu. ipiv;
194- let ldb = bl. lda( ) ;
195- $gttrs(
196- lu. a. l. lapacke_layout( ) ,
197- t as u8 ,
198- n,
199- nrhs,
200- & lu. a. dl,
201- & lu. a. d,
202- & lu. a. du,
203- & lu. du2,
204- ipiv,
205- b,
206- ldb,
207- )
208- . as_lapack_result( ) ?;
208+ // Transpose if b is C-continuous
209+ let mut b_t = None ;
210+ let b_layout = match b_layout {
211+ MatrixLayout :: C { .. } => {
212+ b_t = Some ( vec![ Self :: zero( ) ; b. len( ) ] ) ;
213+ transpose( b_layout, b, b_t. as_mut( ) . unwrap( ) )
214+ }
215+ MatrixLayout :: F { .. } => b_layout,
216+ } ;
217+ let ( ldb, nrhs) = b_layout. size( ) ;
218+ let mut info = 0 ;
219+ unsafe {
220+ $gttrs(
221+ t as u8 ,
222+ n,
223+ nrhs,
224+ & lu. a. dl,
225+ & lu. a. d,
226+ & lu. a. du,
227+ & lu. du2,
228+ ipiv,
229+ b_t. as_mut( ) . map( |v| v. as_mut_slice( ) ) . unwrap_or( b) ,
230+ ldb,
231+ & mut info,
232+ ) ;
233+ }
234+ info. as_lapack_result( ) ?;
235+ if let Some ( b_t) = b_t {
236+ transpose( b_layout, & b_t, b) ;
237+ }
209238 Ok ( ( ) )
210239 }
211240 }
212241 } ;
213242} // impl_tridiagonal!
214243
215- impl_tridiagonal ! ( f64 , lapacke :: dgttrf, lapacke :: dgtcon, lapacke :: dgttrs) ;
216- impl_tridiagonal ! ( f32 , lapacke :: sgttrf, lapacke :: sgtcon, lapacke :: sgttrs) ;
217- impl_tridiagonal ! ( c64, lapacke :: zgttrf, lapacke :: zgtcon, lapacke :: zgttrs) ;
218- impl_tridiagonal ! ( c32, lapacke :: cgttrf, lapacke :: cgtcon, lapacke :: cgttrs) ;
244+ impl_tridiagonal ! ( @real , f64 , lapack :: dgttrf, lapack :: dgtcon, lapack :: dgttrs) ;
245+ impl_tridiagonal ! ( @real , f32 , lapack :: sgttrf, lapack :: sgtcon, lapack :: sgttrs) ;
246+ impl_tridiagonal ! ( @complex , c64, lapack :: zgttrf, lapack :: zgtcon, lapack :: zgttrs) ;
247+ impl_tridiagonal ! ( @complex , c32, lapack :: cgttrf, lapack :: cgtcon, lapack :: cgttrs) ;
0 commit comments