Skip to content

Commit 866c359

Browse files
committed
Add checks in Solve for compatible shapes
Before, mismatched shapes would lead to segfaults or other memory errors.
1 parent 178b302 commit 866c359

File tree

1 file changed

+64
-0
lines changed

1 file changed

+64
-0
lines changed

ndarray-linalg/src/solve.rs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,59 +77,103 @@ pub use lax::{Pivot, Transpose};
7777
pub trait Solve<A: Scalar> {
7878
/// Solves a system of linear equations `A * x = b` where `A` is `self`, `b`
7979
/// is the argument, and `x` is the successful result.
80+
///
81+
/// # Panics
82+
///
83+
/// Panics if the length of `b` is not the equal to the number of columns
84+
/// of `A`.
8085
fn solve<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
8186
let mut b = replicate(b);
8287
self.solve_inplace(&mut b)?;
8388
Ok(b)
8489
}
90+
8591
/// Solves a system of linear equations `A * x = b` where `A` is `self`, `b`
8692
/// is the argument, and `x` is the successful result.
93+
///
94+
/// # Panics
95+
///
96+
/// Panics if the length of `b` is not the equal to the number of columns
97+
/// of `A`.
8798
fn solve_into<S: DataMut<Elem = A>>(
8899
&self,
89100
mut b: ArrayBase<S, Ix1>,
90101
) -> Result<ArrayBase<S, Ix1>> {
91102
self.solve_inplace(&mut b)?;
92103
Ok(b)
93104
}
105+
94106
/// Solves a system of linear equations `A * x = b` where `A` is `self`, `b`
95107
/// is the argument, and `x` is the successful result.
108+
///
109+
/// # Panics
110+
///
111+
/// Panics if the length of `b` is not the equal to the number of columns
112+
/// of `A`.
96113
fn solve_inplace<'a, S: DataMut<Elem = A>>(
97114
&self,
98115
b: &'a mut ArrayBase<S, Ix1>,
99116
) -> Result<&'a mut ArrayBase<S, Ix1>>;
100117

101118
/// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b`
102119
/// is the argument, and `x` is the successful result.
120+
///
121+
/// # Panics
122+
///
123+
/// Panics if the length of `b` is not the equal to the number of rows of
124+
/// `A`.
103125
fn solve_t<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
104126
let mut b = replicate(b);
105127
self.solve_t_inplace(&mut b)?;
106128
Ok(b)
107129
}
130+
108131
/// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b`
109132
/// is the argument, and `x` is the successful result.
133+
///
134+
/// # Panics
135+
///
136+
/// Panics if the length of `b` is not the equal to the number of rows of
137+
/// `A`.
110138
fn solve_t_into<S: DataMut<Elem = A>>(
111139
&self,
112140
mut b: ArrayBase<S, Ix1>,
113141
) -> Result<ArrayBase<S, Ix1>> {
114142
self.solve_t_inplace(&mut b)?;
115143
Ok(b)
116144
}
145+
117146
/// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b`
118147
/// is the argument, and `x` is the successful result.
148+
///
149+
/// # Panics
150+
///
151+
/// Panics if the length of `b` is not the equal to the number of rows of
152+
/// `A`.
119153
fn solve_t_inplace<'a, S: DataMut<Elem = A>>(
120154
&self,
121155
b: &'a mut ArrayBase<S, Ix1>,
122156
) -> Result<&'a mut ArrayBase<S, Ix1>>;
123157

124158
/// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b`
125159
/// is the argument, and `x` is the successful result.
160+
///
161+
/// # Panics
162+
///
163+
/// Panics if the length of `b` is not the equal to the number of rows of
164+
/// `A`.
126165
fn solve_h<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
127166
let mut b = replicate(b);
128167
self.solve_h_inplace(&mut b)?;
129168
Ok(b)
130169
}
131170
/// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b`
132171
/// is the argument, and `x` is the successful result.
172+
///
173+
/// # Panics
174+
///
175+
/// Panics if the length of `b` is not the equal to the number of rows of
176+
/// `A`.
133177
fn solve_h_into<S: DataMut<Elem = A>>(
134178
&self,
135179
mut b: ArrayBase<S, Ix1>,
@@ -139,6 +183,11 @@ pub trait Solve<A: Scalar> {
139183
}
140184
/// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b`
141185
/// is the argument, and `x` is the successful result.
186+
///
187+
/// # Panics
188+
///
189+
/// Panics if the length of `b` is not the equal to the number of rows of
190+
/// `A`.
142191
fn solve_h_inplace<'a, S: DataMut<Elem = A>>(
143192
&self,
144193
b: &'a mut ArrayBase<S, Ix1>,
@@ -167,6 +216,11 @@ where
167216
where
168217
Sb: DataMut<Elem = A>,
169218
{
219+
assert_eq!(
220+
rhs.len(),
221+
self.a.len_of(Axis(1)),
222+
"The length of `rhs` must be compatible with the shape of the factored matrix.",
223+
);
170224
A::solve(
171225
self.a.square_layout()?,
172226
Transpose::No,
@@ -183,6 +237,11 @@ where
183237
where
184238
Sb: DataMut<Elem = A>,
185239
{
240+
assert_eq!(
241+
rhs.len(),
242+
self.a.len_of(Axis(0)),
243+
"The length of `rhs` must be compatible with the shape of the factored matrix.",
244+
);
186245
A::solve(
187246
self.a.square_layout()?,
188247
Transpose::Transpose,
@@ -199,6 +258,11 @@ where
199258
where
200259
Sb: DataMut<Elem = A>,
201260
{
261+
assert_eq!(
262+
rhs.len(),
263+
self.a.len_of(Axis(0)),
264+
"The length of `rhs` must be compatible with the shape of the factored matrix.",
265+
);
202266
A::solve(
203267
self.a.square_layout()?,
204268
Transpose::Hermite,

0 commit comments

Comments
 (0)