2121"""
2222
2323from dataclasses import dataclass
24- from typing import Any , Callable , Dict , Iterable , Optional , Union
24+ from typing import Any , Callable , Dict , Iterable , Optional , Sequence , Union
2525
2626import numpy as np
2727import numpy .linalg as la
28+ from scipy .sparse .linalg import LinearOperator
2829
2930from arraycontext import PyOpenCLArrayContext , ArrayOrContainerT , flatten , unflatten
3031from meshmode .dof_array import DOFArray
3536from pytential .linalg .skeletonization import (
3637 SkeletonizationWrangler , SkeletonizationResult )
3738
38- try :
39- from scipy .sparse .linalg import LinearOperator
40- except ImportError :
41- # NOTE: scipy should be available (for interp_decomp), but just in case
42- class LinearOperator :
43- pass
44-
4539import logging
4640logger = logging .getLogger (__name__ )
4741
@@ -124,7 +118,7 @@ def _update_skeleton_diagonal(
124118 targets , sources = parent .skel_tgt_src_index
125119
126120 # FIXME: nicer way to do this?
127- mat = np .empty (skeleton .nclusters , dtype = object )
121+ mat : np . ndarray = np .empty (skeleton .nclusters , dtype = object )
128122 for k in range (skeleton .nclusters ):
129123 D = skeleton .D [k ].copy ()
130124
@@ -146,9 +140,9 @@ def _update_skeleton_diagonal(
146140
147141def _update_skeletons_diagonal (
148142 wrangler : "ProxyHierarchicalMatrixWrangler" ,
149- func : Callable [[SkeletonizationResult ], np .ndarray ],
143+ func : Callable [[SkeletonizationResult ], Optional [ np .ndarray ] ],
150144 ) -> np .ndarray :
151- skeletons = np .empty (wrangler .skeletons .shape , dtype = object )
145+ skeletons : np . ndarray = np .empty (wrangler .skeletons .shape , dtype = object )
152146 skeletons [0 ] = wrangler .skeletons [0 ]
153147
154148 for i in range (1 , wrangler .ctree .nlevels ):
@@ -263,11 +257,14 @@ def _matvec(self, x: ArrayOrContainerT) -> ArrayOrContainerT:
263257 else :
264258 raise TypeError (f"unsupported input type: { type (x )} " )
265259
260+ assert actx is None or isinstance (actx , PyOpenCLArrayContext )
266261 result = apply_skeleton_forward_matvec (self , ary )
262+
267263 if isinstance (x , DOFArray ):
264+ assert actx is not None
268265 result = unflatten (x , actx .from_numpy (result ), actx )
269266
270- return result
267+ return result # type: ignore[return-value]
271268
272269
273270def apply_skeleton_forward_matvec (
@@ -276,7 +273,7 @@ def apply_skeleton_forward_matvec(
276273 ) -> ArrayOrContainerT :
277274 from pytential .linalg .cluster import split_array
278275 targets , sources = hmat .skeletons [0 ].tgt_src_index
279- x = split_array (ary , sources )
276+ x = split_array (ary , sources ) # type: ignore[arg-type]
280277
281278 # NOTE: this computes a telescoping product of the form
282279 #
@@ -297,7 +294,7 @@ def apply_skeleton_forward_matvec(
297294 #
298295 # which gives back the desired product when we reach the leaf level again.
299296
300- d_dot_x = np .empty (hmat .nlevels , dtype = object )
297+ d_dot_x : np . ndarray = np .empty (hmat .nlevels , dtype = object )
301298
302299 # {{{ recurse down
303300
@@ -307,8 +304,8 @@ def apply_skeleton_forward_matvec(
307304 assert x .shape == (skeleton .nclusters ,)
308305 assert skeleton .tgt_src_index .shape [1 ] == sum ([xi .size for xi in x ])
309306
310- d_dot_x_k = np .empty (skeleton .nclusters , dtype = object )
311- r_dot_x_k = np .empty (skeleton .nclusters , dtype = object )
307+ d_dot_x_k : np . ndarray = np .empty (skeleton .nclusters , dtype = object )
308+ r_dot_x_k : np . ndarray = np .empty (skeleton .nclusters , dtype = object )
312309
313310 for i in range (skeleton .nclusters ):
314311 r_dot_x_k [i ] = skeleton .R [i ] @ x [i ]
@@ -366,23 +363,26 @@ def _matvec(self, x: ArrayOrContainerT) -> ArrayOrContainerT:
366363 else :
367364 raise TypeError (f"unsupported input type: { type (x )} " )
368365
366+ assert actx is None or isinstance (actx , PyOpenCLArrayContext )
369367 result = apply_skeleton_backward_matvec (actx , self , ary )
368+
370369 if isinstance (x , DOFArray ):
370+ assert actx is not None
371371 result = unflatten (x , actx .from_numpy (result ), actx )
372372
373- return result
373+ return result # type: ignore[return-value]
374374
375375
376376def apply_skeleton_backward_matvec (
377- actx : PyOpenCLArrayContext ,
377+ actx : Optional [ PyOpenCLArrayContext ] ,
378378 hmat : ProxyHierarchicalMatrix ,
379379 ary : ArrayOrContainerT ,
380380 ) -> ArrayOrContainerT :
381381 from pytential .linalg .cluster import split_array
382382 targets , sources = hmat .skeletons [0 ].tgt_src_index
383383
384- b = split_array (ary , targets )
385- r_dot_b = np .empty (hmat .nlevels , dtype = object )
384+ b = split_array (ary , targets ) # type: ignore[arg-type]
385+ r_dot_b : np . ndarray = np .empty (hmat .nlevels , dtype = object )
386386
387387 # {{{ recurse down
388388
@@ -412,7 +412,7 @@ def apply_skeleton_backward_matvec(
412412 assert b .shape == (skeleton .nclusters ,)
413413 assert skeleton .tgt_src_index .shape [0 ] == sum ([bi .size for bi in b ])
414414
415- dhat_dot_b_k = np .empty (skeleton .nclusters , dtype = object )
415+ dhat_dot_b_k : np . ndarray = np .empty (skeleton .nclusters , dtype = object )
416416 for i in range (skeleton .nclusters ):
417417 dhat_dot_b_k [i ] = (
418418 skeleton .Dhat [i ] @ (skeleton .R [i ] @ (skeleton .invD [i ] @ b [i ]))
@@ -467,7 +467,7 @@ def build_hmatrix_by_proxy(
467467 exprs : Union [sym .Expression , Iterable [sym .Expression ]],
468468 input_exprs : Union [sym .Expression , Iterable [sym .Expression ]], * ,
469469 auto_where : Optional [sym .DOFDescriptorLike ] = None ,
470- domains : Optional [Iterable [sym .DOFDescriptorLike ]] = None ,
470+ domains : Optional [Sequence [sym .DOFDescriptorLike ]] = None ,
471471 context : Optional [Dict [str , Any ]] = None ,
472472 id_eps : float = 1.0e-8 ,
473473
@@ -483,13 +483,6 @@ def build_hmatrix_by_proxy(
483483 _approx_nproxy : Optional [int ] = None ,
484484 _proxy_radius_factor : Optional [float ] = None ,
485485 ) -> ProxyHierarchicalMatrixWrangler :
486- try :
487- import scipy # noqa: F401
488- except ImportError :
489- raise ImportError (
490- "The direct solver requires 'scipy' for the interpolative "
491- "decomposition used in skeletonization" )
492-
493486 from pytential .symbolic .matrix import P2PClusterMatrixBuilder
494487 from pytential .linalg .skeletonization import make_skeletonization_wrangler
495488
0 commit comments