Skip to content

Commit 1524bc9

Browse files
committed
fix mypy errors
1 parent 8809993 commit 1524bc9

File tree

2 files changed

+24
-31
lines changed

2 files changed

+24
-31
lines changed

pytential/linalg/hmatrix.py

Lines changed: 22 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,11 @@
2121
"""
2222

2323
from dataclasses import dataclass
24-
from typing import Any, Callable, Dict, Iterable, Optional, Union
24+
from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Union
2525

2626
import numpy as np
2727
import numpy.linalg as la
28+
from scipy.sparse.linalg import LinearOperator
2829

2930
from arraycontext import PyOpenCLArrayContext, ArrayOrContainerT, flatten, unflatten
3031
from meshmode.dof_array import DOFArray
@@ -35,13 +36,6 @@
3536
from pytential.linalg.skeletonization import (
3637
SkeletonizationWrangler, SkeletonizationResult)
3738

38-
try:
39-
from scipy.sparse.linalg import LinearOperator
40-
except ImportError:
41-
# NOTE: scipy should be available (for interp_decomp), but just in case
42-
class LinearOperator:
43-
pass
44-
4539
import logging
4640
logger = logging.getLogger(__name__)
4741

@@ -124,7 +118,7 @@ def _update_skeleton_diagonal(
124118
targets, sources = parent.skel_tgt_src_index
125119

126120
# FIXME: nicer way to do this?
127-
mat = np.empty(skeleton.nclusters, dtype=object)
121+
mat: np.ndarray = np.empty(skeleton.nclusters, dtype=object)
128122
for k in range(skeleton.nclusters):
129123
D = skeleton.D[k].copy()
130124

@@ -146,9 +140,9 @@ def _update_skeleton_diagonal(
146140

147141
def _update_skeletons_diagonal(
148142
wrangler: "ProxyHierarchicalMatrixWrangler",
149-
func: Callable[[SkeletonizationResult], np.ndarray],
143+
func: Callable[[SkeletonizationResult], Optional[np.ndarray]],
150144
) -> np.ndarray:
151-
skeletons = np.empty(wrangler.skeletons.shape, dtype=object)
145+
skeletons: np.ndarray = np.empty(wrangler.skeletons.shape, dtype=object)
152146
skeletons[0] = wrangler.skeletons[0]
153147

154148
for i in range(1, wrangler.ctree.nlevels):
@@ -263,11 +257,14 @@ def _matvec(self, x: ArrayOrContainerT) -> ArrayOrContainerT:
263257
else:
264258
raise TypeError(f"unsupported input type: {type(x)}")
265259

260+
assert actx is None or isinstance(actx, PyOpenCLArrayContext)
266261
result = apply_skeleton_forward_matvec(self, ary)
262+
267263
if isinstance(x, DOFArray):
264+
assert actx is not None
268265
result = unflatten(x, actx.from_numpy(result), actx)
269266

270-
return result
267+
return result # type: ignore[return-value]
271268

272269

273270
def apply_skeleton_forward_matvec(
@@ -276,7 +273,7 @@ def apply_skeleton_forward_matvec(
276273
) -> ArrayOrContainerT:
277274
from pytential.linalg.cluster import split_array
278275
targets, sources = hmat.skeletons[0].tgt_src_index
279-
x = split_array(ary, sources)
276+
x = split_array(ary, sources) # type: ignore[arg-type]
280277

281278
# NOTE: this computes a telescoping product of the form
282279
#
@@ -297,7 +294,7 @@ def apply_skeleton_forward_matvec(
297294
#
298295
# which gives back the desired product when we reach the leaf level again.
299296

300-
d_dot_x = np.empty(hmat.nlevels, dtype=object)
297+
d_dot_x: np.ndarray = np.empty(hmat.nlevels, dtype=object)
301298

302299
# {{{ recurse down
303300

@@ -307,8 +304,8 @@ def apply_skeleton_forward_matvec(
307304
assert x.shape == (skeleton.nclusters,)
308305
assert skeleton.tgt_src_index.shape[1] == sum([xi.size for xi in x])
309306

310-
d_dot_x_k = np.empty(skeleton.nclusters, dtype=object)
311-
r_dot_x_k = np.empty(skeleton.nclusters, dtype=object)
307+
d_dot_x_k: np.ndarray = np.empty(skeleton.nclusters, dtype=object)
308+
r_dot_x_k: np.ndarray = np.empty(skeleton.nclusters, dtype=object)
312309

313310
for i in range(skeleton.nclusters):
314311
r_dot_x_k[i] = skeleton.R[i] @ x[i]
@@ -366,23 +363,26 @@ def _matvec(self, x: ArrayOrContainerT) -> ArrayOrContainerT:
366363
else:
367364
raise TypeError(f"unsupported input type: {type(x)}")
368365

366+
assert actx is None or isinstance(actx, PyOpenCLArrayContext)
369367
result = apply_skeleton_backward_matvec(actx, self, ary)
368+
370369
if isinstance(x, DOFArray):
370+
assert actx is not None
371371
result = unflatten(x, actx.from_numpy(result), actx)
372372

373-
return result
373+
return result # type: ignore[return-value]
374374

375375

376376
def apply_skeleton_backward_matvec(
377-
actx: PyOpenCLArrayContext,
377+
actx: Optional[PyOpenCLArrayContext],
378378
hmat: ProxyHierarchicalMatrix,
379379
ary: ArrayOrContainerT,
380380
) -> ArrayOrContainerT:
381381
from pytential.linalg.cluster import split_array
382382
targets, sources = hmat.skeletons[0].tgt_src_index
383383

384-
b = split_array(ary, targets)
385-
r_dot_b = np.empty(hmat.nlevels, dtype=object)
384+
b = split_array(ary, targets) # type: ignore[arg-type]
385+
r_dot_b: np.ndarray = np.empty(hmat.nlevels, dtype=object)
386386

387387
# {{{ recurse down
388388

@@ -412,7 +412,7 @@ def apply_skeleton_backward_matvec(
412412
assert b.shape == (skeleton.nclusters,)
413413
assert skeleton.tgt_src_index.shape[0] == sum([bi.size for bi in b])
414414

415-
dhat_dot_b_k = np.empty(skeleton.nclusters, dtype=object)
415+
dhat_dot_b_k: np.ndarray = np.empty(skeleton.nclusters, dtype=object)
416416
for i in range(skeleton.nclusters):
417417
dhat_dot_b_k[i] = (
418418
skeleton.Dhat[i] @ (skeleton.R[i] @ (skeleton.invD[i] @ b[i]))
@@ -467,7 +467,7 @@ def build_hmatrix_by_proxy(
467467
exprs: Union[sym.Expression, Iterable[sym.Expression]],
468468
input_exprs: Union[sym.Expression, Iterable[sym.Expression]], *,
469469
auto_where: Optional[sym.DOFDescriptorLike] = None,
470-
domains: Optional[Iterable[sym.DOFDescriptorLike]] = None,
470+
domains: Optional[Sequence[sym.DOFDescriptorLike]] = None,
471471
context: Optional[Dict[str, Any]] = None,
472472
id_eps: float = 1.0e-8,
473473

@@ -483,13 +483,6 @@ def build_hmatrix_by_proxy(
483483
_approx_nproxy: Optional[int] = None,
484484
_proxy_radius_factor: Optional[float] = None,
485485
) -> ProxyHierarchicalMatrixWrangler:
486-
try:
487-
import scipy # noqa: F401
488-
except ImportError:
489-
raise ImportError(
490-
"The direct solver requires 'scipy' for the interpolative "
491-
"decomposition used in skeletonization")
492-
493486
from pytential.symbolic.matrix import P2PClusterMatrixBuilder
494487
from pytential.linalg.skeletonization import make_skeletonization_wrangler
495488

pytential/linalg/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -443,8 +443,8 @@ def mnorm(x: np.ndarray, y: np.ndarray) -> "np.floating[Any]":
443443
def skeletonization_matrix(
444444
mat: np.ndarray, skeleton: "SkeletonizationResult",
445445
) -> Tuple[np.ndarray, np.ndarray]:
446-
D = np.empty(skeleton.nclusters, dtype=object)
447-
S = np.empty((skeleton.nclusters, skeleton.nclusters), dtype=object)
446+
D: np.ndarray = np.empty(skeleton.nclusters, dtype=object)
447+
S: np.ndarray = np.empty((skeleton.nclusters, skeleton.nclusters), dtype=object)
448448

449449
from itertools import product
450450
for i, j in product(range(skeleton.nclusters), repeat=2):

0 commit comments

Comments
 (0)