Skip to content

Commit 2fd5f12

Browse files
authored
Merge pull request #60 from termoshtt/op_multi
OperatorMulti
2 parents 21298db + 3c53ad3 commit 2fd5f12

File tree

2 files changed

+73
-0
lines changed

2 files changed

+73
-0
lines changed

src/operator.rs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,67 @@ where
3939
self.dot(rhs)
4040
}
4141
}
42+
43+
pub trait OperatorMulti<A, S, D>
44+
where
45+
S: Data<Elem = A>,
46+
D: Dimension,
47+
{
48+
fn op_multi(&self, &ArrayBase<S, D>) -> Array<A, D>;
49+
}
50+
51+
impl<T, A, S, D> OperatorMulti<A, S, D> for T
52+
where
53+
A: Scalar,
54+
S: DataMut<Elem = A>,
55+
D: Dimension + RemoveAxis,
56+
for<'a> T: OperatorMut<ViewRepr<&'a mut A>, D::Smaller>,
57+
{
58+
fn op_multi(&self, a: &ArrayBase<S, D>) -> Array<A, D> {
59+
let a = a.to_owned();
60+
self.op_multi_into(a)
61+
}
62+
}
63+
64+
pub trait OperatorMultiInto<S, D>
65+
where
66+
S: DataMut,
67+
D: Dimension,
68+
{
69+
fn op_multi_into(&self, ArrayBase<S, D>) -> ArrayBase<S, D>;
70+
}
71+
72+
impl<T, A, S, D> OperatorMultiInto<S, D> for T
73+
where
74+
S: DataMut<Elem = A>,
75+
D: Dimension + RemoveAxis,
76+
for<'a> T: OperatorMut<ViewRepr<&'a mut A>, D::Smaller>,
77+
{
78+
fn op_multi_into(&self, mut a: ArrayBase<S, D>) -> ArrayBase<S, D> {
79+
self.op_multi_mut(&mut a);
80+
a
81+
}
82+
}
83+
84+
pub trait OperatorMultiMut<S, D>
85+
where
86+
S: DataMut,
87+
D: Dimension,
88+
{
89+
fn op_multi_mut<'a>(&self, &'a mut ArrayBase<S, D>) -> &'a mut ArrayBase<S, D>;
90+
}
91+
92+
impl<T, A, S, D> OperatorMultiMut<S, D> for T
93+
where
94+
S: DataMut<Elem = A>,
95+
D: Dimension + RemoveAxis,
96+
for<'a> T: OperatorMut<ViewRepr<&'a mut A>, D::Smaller>,
97+
{
98+
fn op_multi_mut<'a>(&self, mut a: &'a mut ArrayBase<S, D>) -> &'a mut ArrayBase<S, D> {
99+
let n = a.ndim();
100+
for mut col in a.axis_iter_mut(Axis(n - 1)) {
101+
self.op_mut(&mut col);
102+
}
103+
a
104+
}
105+
}

tests/diag.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,12 @@ fn diag_2d() {
2222
println!("dm = {:?}", dm);
2323
assert_close_l2!(&dm, &arr2(&[[1.0, 1.0], [2.0, 2.0]]), 1e-7);
2424
}
25+
26+
#[test]
27+
fn diag_2d_multi() {
28+
let d = arr1(&[1.0, 2.0]);
29+
let m = arr2(&[[1.0, 1.0], [1.0, 1.0]]);
30+
let dm = d.into_diagonal().op_multi_into(m);
31+
println!("dm = {:?}", dm);
32+
assert_close_l2!(&dm, &arr2(&[[1.0, 1.0], [2.0, 2.0]]), 1e-7);
33+
}

0 commit comments

Comments
 (0)