Skip to content

Commit ea9d443

Browse files
committed
EigWorkImpl for c64
1 parent 8d2b5e7 commit ea9d443

File tree

1 file changed

+168
-0
lines changed

1 file changed

+168
-0
lines changed

lax/src/eig.rs

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,174 @@ pub struct EigWork<T: Scalar> {
5959
pub rwork: Option<Vec<MaybeUninit<T::Real>>>,
6060
}
6161

62+
#[derive(Debug, Clone, PartialEq)]
63+
pub struct Eig<T: Scalar> {
64+
pub eigs: Vec<T::Complex>,
65+
pub vr: Option<Vec<T::Complex>>,
66+
pub vl: Option<Vec<T::Complex>>,
67+
}
68+
69+
#[derive(Debug, Clone, PartialEq)]
70+
pub struct EigRef<'work, T: Scalar> {
71+
pub eigs: &'work [T::Complex],
72+
pub vr: Option<&'work [T::Complex]>,
73+
pub vl: Option<&'work [T::Complex]>,
74+
}
75+
76+
pub trait EigWorkImpl: Sized {
77+
type Elem: Scalar;
78+
/// Create new working memory for eigenvalues compution.
79+
fn new(calc_v: bool, l: MatrixLayout) -> Result<Self>;
80+
/// Compute eigenvalues and vectors on this working memory.
81+
fn calc<'work>(&'work mut self, a: &mut [Self::Elem]) -> Result<EigRef<'work, Self::Elem>>;
82+
/// Compute eigenvalues and vectors by consuming this working memory.
83+
fn eval(self, a: &mut [Self::Elem]) -> Result<Eig<Self::Elem>>;
84+
}
85+
86+
impl EigWorkImpl for EigWork<c64> {
87+
type Elem = c64;
88+
89+
fn new(calc_v: bool, l: MatrixLayout) -> Result<Self> {
90+
let (n, _) = l.size();
91+
let (jobvl, jobvr) = if calc_v {
92+
match l {
93+
MatrixLayout::C { .. } => (JobEv::All, JobEv::None),
94+
MatrixLayout::F { .. } => (JobEv::None, JobEv::All),
95+
}
96+
} else {
97+
(JobEv::None, JobEv::None)
98+
};
99+
let mut eigs: Vec<MaybeUninit<c64>> = vec_uninit(n as usize);
100+
let mut rwork: Vec<MaybeUninit<f64>> = vec_uninit(2 * n as usize);
101+
102+
let mut vc_l: Option<Vec<MaybeUninit<c64>>> = jobvl.then(|| vec_uninit((n * n) as usize));
103+
let mut vc_r: Option<Vec<MaybeUninit<c64>>> = jobvr.then(|| vec_uninit((n * n) as usize));
104+
105+
// calc work size
106+
let mut info = 0;
107+
let mut work_size = [c64::zero()];
108+
unsafe {
109+
lapack_sys::zgeev_(
110+
jobvl.as_ptr(),
111+
jobvr.as_ptr(),
112+
&n,
113+
std::ptr::null_mut(),
114+
&n,
115+
AsPtr::as_mut_ptr(&mut eigs),
116+
AsPtr::as_mut_ptr(vc_l.as_deref_mut().unwrap_or(&mut [])),
117+
&n,
118+
AsPtr::as_mut_ptr(vc_r.as_deref_mut().unwrap_or(&mut [])),
119+
&n,
120+
AsPtr::as_mut_ptr(&mut work_size),
121+
&(-1),
122+
AsPtr::as_mut_ptr(&mut rwork),
123+
&mut info,
124+
)
125+
};
126+
info.as_lapack_result()?;
127+
128+
let lwork = work_size[0].to_usize().unwrap();
129+
let work: Vec<MaybeUninit<c64>> = vec_uninit(lwork);
130+
Ok(Self {
131+
n,
132+
jobvl,
133+
jobvr,
134+
eigs,
135+
eigs_re: None,
136+
eigs_im: None,
137+
rwork: Some(rwork),
138+
vc_l,
139+
vc_r,
140+
vr_l: None,
141+
vr_r: None,
142+
work,
143+
})
144+
}
145+
146+
fn calc<'work>(&'work mut self, a: &mut [c64]) -> Result<EigRef<'work, c64>> {
147+
let lwork = self.work.len().to_i32().unwrap();
148+
let mut info = 0;
149+
unsafe {
150+
lapack_sys::zgeev_(
151+
self.jobvl.as_ptr(),
152+
self.jobvr.as_ptr(),
153+
&self.n,
154+
AsPtr::as_mut_ptr(a),
155+
&self.n,
156+
AsPtr::as_mut_ptr(&mut self.eigs),
157+
AsPtr::as_mut_ptr(self.vc_l.as_deref_mut().unwrap_or(&mut [])),
158+
&self.n,
159+
AsPtr::as_mut_ptr(self.vc_r.as_deref_mut().unwrap_or(&mut [])),
160+
&self.n,
161+
AsPtr::as_mut_ptr(&mut self.work),
162+
&lwork,
163+
AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()),
164+
&mut info,
165+
)
166+
};
167+
info.as_lapack_result()?;
168+
169+
let eigs = unsafe { self.eigs.slice_assume_init_ref() };
170+
171+
// Hermite conjugate
172+
if let Some(vl) = self.vc_l.as_mut() {
173+
for value in vl {
174+
let value = unsafe { value.assume_init_mut() };
175+
value.im = -value.im;
176+
}
177+
}
178+
Ok(EigRef {
179+
eigs,
180+
vl: self
181+
.vc_l
182+
.as_ref()
183+
.map(|v| unsafe { v.slice_assume_init_ref() }),
184+
vr: self
185+
.vc_r
186+
.as_ref()
187+
.map(|v| unsafe { v.slice_assume_init_ref() }),
188+
})
189+
}
190+
191+
fn eval(mut self, a: &mut [c64]) -> Result<Eig<c64>> {
192+
let lwork = self.work.len().to_i32().unwrap();
193+
let mut info = 0;
194+
unsafe {
195+
lapack_sys::zgeev_(
196+
self.jobvl.as_ptr(),
197+
self.jobvr.as_ptr(),
198+
&self.n,
199+
AsPtr::as_mut_ptr(a),
200+
&self.n,
201+
AsPtr::as_mut_ptr(&mut self.eigs),
202+
AsPtr::as_mut_ptr(self.vc_l.as_deref_mut().unwrap_or(&mut [])),
203+
&self.n,
204+
AsPtr::as_mut_ptr(self.vc_r.as_deref_mut().unwrap_or(&mut [])),
205+
&self.n,
206+
AsPtr::as_mut_ptr(&mut self.work),
207+
&lwork,
208+
AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()),
209+
&mut info,
210+
)
211+
};
212+
info.as_lapack_result()?;
213+
let eigs = unsafe { self.eigs.assume_init() };
214+
215+
// Hermite conjugate
216+
if let Some(vl) = self.vc_l.as_mut() {
217+
for value in vl {
218+
let value = unsafe { value.assume_init_mut() };
219+
value.im = -value.im;
220+
}
221+
}
222+
Ok(Eig {
223+
eigs,
224+
vl: self.vc_l.map(|v| unsafe { v.assume_init() }),
225+
vr: self.vc_r.map(|v| unsafe { v.assume_init() }),
226+
})
227+
}
228+
}
229+
62230
macro_rules! impl_eig_complex {
63231
($scalar:ty, $ev:path) => {
64232
impl Eig_ for $scalar {

0 commit comments

Comments
 (0)