Skip to content

Commit 7ef1be1

Browse files
committed
Added overloading of +, *, - to LinearOperator
1 parent 800a4e9 commit 7ef1be1

File tree

1 file changed

+113
-10
lines changed

1 file changed

+113
-10
lines changed

pylops_distributed/LinearOperator.py

Lines changed: 113 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
import copy
12
import numpy as np
23
import dask.array as da
4+
5+
from dask.array.linalg import solve, lstsq
36
from pylops import LinearOperator as pLinearOperator
7+
from pylops_distributed.optimization.cg import cgls
48

59

610
class LinearOperator(pLinearOperator):
@@ -17,11 +21,12 @@ class LinearOperator(pLinearOperator):
1721
overwritten here to simply call their private methods
1822
``_matvec`` and ``_rmatvec`` without any prior check on the input vectors.
1923
20-
.. note:: End users of PyLops should not use this class directly but simply
21-
use operators that are already implemented. This class is meant for
22-
developers and it has to be used as the parent class of any new operator
23-
developed within PyLops-distibuted. Find more details regarding
24-
implementation of new operators at :ref:`addingoperator`.
24+
.. note:: End users of PyLops-distributed should not use this class
25+
directly but simply use operators that are already implemented.
26+
This class is meant for developers and it has to be used as the
27+
parent class of any new operator developed within PyLops-distributed.
28+
Find more details regarding implementation of new operators at
29+
https://pylops.readthedocs.io/en/latest/adding.html.
2530
2631
Parameters
2732
----------
@@ -159,10 +164,29 @@ def __rmatmul__(self, other):
159164

160165
def __rmul__(self, x):
161166
if np.isscalar(x):
162-
return _ScaledLinearOperator(self, x)
167+
return aslinearoperator(_ScaledLinearOperator(self, x))
168+
else:
169+
return NotImplemented
170+
171+
def __pow__(self, p):
172+
if np.isscalar(p):
173+
return aslinearoperator(_PowerLinearOperator(self, p))
163174
else:
164175
return NotImplemented
165176

177+
def __add__(self, x):
178+
if isinstance(x, LinearOperator):
179+
return aslinearoperator(_SumLinearOperator(self, x))
180+
else:
181+
return NotImplemented
182+
183+
def __neg__(self):
184+
return aslinearoperator(_ScaledLinearOperator(self, -1))
185+
186+
def __sub__(self, x):
187+
return self.__add__(-x)
188+
189+
166190
def adjoint(self):
167191
"""Hermitian adjoint.
168192
@@ -181,6 +205,38 @@ def _adjoint(self):
181205
rmatvec=self.matvec,
182206
dtype=self.dtype)
183207

208+
def div1(self, y, niter=100):
209+
r"""Solve the linear problem :math:`\mathbf{y}=\mathbf{A}\mathbf{x}`.
210+
211+
Overloading of operator ``/`` to improve expressivity of
212+
`PyLops-distributed` when solving inverse problems.
213+
214+
Parameters
215+
----------
216+
y : :obj:`dask.array`
217+
Data
218+
niter : :obj:`int`, optional
219+
Number of iterations (to be used only when ``explicit=False``)
220+
221+
Returns
222+
-------
223+
xest : :obj:`dask.array`
224+
Estimated model
225+
226+
"""
227+
xest = self.__truediv__(y, niter=niter)
228+
return xest
229+
230+
def __truediv__(self, y, niter=100):
231+
if self.explicit is True:
232+
if self.A.shape[0] == self.A.shape[1]:
233+
xest = solve(self.A, y)
234+
else:
235+
xest = lstsq(self.A, y)[0]
236+
else:
237+
xest = cgls(self, y, niter=niter)[0]
238+
return xest
239+
184240
def conj(self):
185241
"""Complex conjugate operator
186242
@@ -244,12 +300,23 @@ def __init__(self, A, B):
244300
% (A, B))
245301
if A.compute[0] != B.compute[0] or A.compute[1] != B.compute[1]:
246302
raise ValueError('compute must be the same for A and B')
303+
if A.todask[0] != B.todask[0] or A.todask[1] != B.todask[1]:
304+
raise ValueError('todask must be the same for A and B')
247305
self.args = (A, B)
248306
super(_SumLinearOperator, self).__init__(shape=A.shape,
249307
dtype=A.dtype, Op=None,
250-
explicit=A.explicit,
308+
explicit=A.explicit and
309+
B.explicit,
251310
compute=A.compute,
252311
todask=A.todask)
312+
# Force compute and todask not to be applied to individual operators
313+
Ac = copy.deepcopy(A)
314+
Bc = copy.deepcopy(B)
315+
Ac.compute = (False, False)
316+
Bc.compute = (False, False)
317+
Ac.todask = (False, False)
318+
Bc.todask = (False, False)
319+
self.args = (Ac, Bc)
253320

254321
def _matvec(self, x):
255322
return self.args[0].matvec(x) + self.args[1].matvec(x)
@@ -276,12 +343,20 @@ def __init__(self, A, B):
276343
super(_ProductLinearOperator, self).__init__(shape=(A.shape[0],
277344
B.shape[1]),
278345
dtype=A.dtype, Op=None,
279-
explicit=A.explicit,
346+
explicit=A.explicit and
347+
B.explicit,
280348
compute=(B.compute[0],
281349
A.compute[1]),
282350
todask=(B.todask[0],
283351
A.todask[1]))
284-
self.args = (A, B)
352+
# Force compute and todask not to be applied to individual operators
353+
Ac = copy.deepcopy(A)
354+
Bc = copy.deepcopy(B)
355+
Ac.compute = (False, False)
356+
Bc.compute = (False, False)
357+
Ac.todask = (False, False)
358+
Bc.todask = (False, False)
359+
self.args = (Ac, Bc)
285360

286361
def _matvec(self, x):
287362
return self.args[0].matvec(self.args[1].matvec(x))
@@ -308,7 +383,11 @@ def __init__(self, A, alpha):
308383
explicit=A.explicit,
309384
compute=A.compute,
310385
todask=A.todask)
311-
self.args = (A, alpha)
386+
# Force compute and todask not to be applied to individual operators
387+
Ac = copy.deepcopy(A)
388+
Ac.compute = (False, False)
389+
Ac.todask = (False, False)
390+
self.args = (Ac, alpha)
312391

313392
def _matvec(self, x):
314393
return self.args[1] * self.args[0].matvec(x)
@@ -387,3 +466,27 @@ def _rmatvec(self, x):
387466
def _adjoint(self):
388467
return _ConjLinearOperator(self.oOp.H)
389468

469+
470+
def aslinearoperator(Op):
471+
"""Return Op as a LinearOperator.
472+
473+
Converts any operator into a LinearOperator. This can be used when `Op`
474+
is a private operator, to ensure that the returned operator has all properties
475+
and methods of the parent class.
476+
477+
Parameters
478+
----------
479+
Op : :obj:`pylops_distributed.LinearOperator` or any other Operator
480+
Operator of any type
481+
482+
Returns
483+
-------
484+
Op : :obj:`pylops_distributed.LinearOperator`
485+
Operator of type :obj:`pylops_distributed.LinearOperator`
486+
487+
"""
488+
if isinstance(Op, LinearOperator):
489+
return Op
490+
else:
491+
return LinearOperator(Op.shape, Op.dtype, Op, explicit=Op.explicit,
492+
compute=Op.compute, todask=Op.todask)

0 commit comments

Comments
 (0)