From 15bfbafa579cc5fd7de37eb211b3403ff7ea0189 Mon Sep 17 00:00:00 2001 From: Max Balandat Date: Wed, 4 Feb 2026 22:30:04 -0800 Subject: [PATCH] Upgrade to use py3.10 features --- linear_operator/__init__.py | 2 + linear_operator/beta_features.py | 1 + linear_operator/functions/__init__.py | 34 ++-- linear_operator/functions/_diagonalization.py | 1 + linear_operator/functions/_dsmm.py | 1 + linear_operator/functions/_inv_quad.py | 1 + linear_operator/functions/_inv_quad_logdet.py | 1 + linear_operator/functions/_matmul.py | 1 + .../functions/_pivoted_cholesky.py | 1 + .../functions/_root_decomposition.py | 1 + linear_operator/functions/_solve.py | 1 + linear_operator/functions/_sqrt_inv_matmul.py | 1 + linear_operator/operators/__init__.py | 1 + linear_operator/operators/_linear_operator.py | 192 +++++++++--------- .../operators/added_diag_linear_operator.py | 22 +- .../operators/batch_repeat_linear_operator.py | 43 ++-- .../operators/block_diag_linear_operator.py | 57 +++--- .../block_interleaved_linear_operator.py | 46 +++-- .../operators/block_linear_operator.py | 8 +- .../operators/cat_linear_operator.py | 22 +- .../operators/chol_linear_operator.py | 40 ++-- .../operators/constant_mul_linear_operator.py | 12 +- .../operators/dense_linear_operator.py | 35 ++-- .../operators/diag_linear_operator.py | 117 +++++------ .../operators/identity_linear_operator.py | 68 +++---- .../operators/interpolated_linear_operator.py | 16 +- .../operators/keops_linear_operator.py | 4 +- .../operators/kernel_linear_operator.py | 14 +- ...cker_product_added_diag_linear_operator.py | 47 ++--- .../kronecker_product_linear_operator.py | 91 +++++---- .../linear_operator_representation_tree.py | 2 + ...ow_rank_root_added_diag_linear_operator.py | 42 ++-- .../low_rank_root_linear_operator.py | 6 +- .../operators/masked_linear_operator.py | 10 +- .../operators/matmul_linear_operator.py | 24 +-- .../operators/mul_linear_operator.py | 10 +- .../operators/permutation_linear_operator.py | 32 +-- .../operators/psd_sum_linear_operator.py | 2 + .../operators/root_linear_operator.py | 22 +- .../operators/sum_batch_linear_operator.py | 1 + .../sum_kronecker_linear_operator.py | 38 ++-- .../operators/sum_linear_operator.py | 59 +++--- .../operators/toeplitz_linear_operator.py | 10 +- .../operators/triangular_linear_operator.py | 57 +++--- .../operators/zero_linear_operator.py | 52 +++-- linear_operator/settings.py | 1 + linear_operator/test/__init__.py | 1 + linear_operator/test/base_test_case.py | 3 +- .../test/linear_operator_test_case.py | 1 + linear_operator/test/utils.py | 1 + linear_operator/utils/__init__.py | 1 + linear_operator/utils/broadcasting.py | 1 + linear_operator/utils/cholesky.py | 7 +- .../utils/contour_integral_quad.py | 2 + linear_operator/utils/deprecation.py | 1 + linear_operator/utils/errors.py | 1 + linear_operator/utils/generic.py | 8 +- linear_operator/utils/getitem.py | 25 +-- linear_operator/utils/interpolation.py | 2 +- linear_operator/utils/lanczos.py | 1 + linear_operator/utils/linear_cg.py | 1 + linear_operator/utils/memoize.py | 1 + linear_operator/utils/minres.py | 1 + linear_operator/utils/permutation.py | 9 +- linear_operator/utils/pinverse.py | 1 + linear_operator/utils/qr.py | 1 + linear_operator/utils/sparse.py | 96 ++++----- linear_operator/utils/stochastic_lq.py | 1 + linear_operator/utils/toeplitz.py | 1 + linear_operator/utils/warnings.py | 1 + 70 files changed, 727 insertions(+), 690 deletions(-) diff --git a/linear_operator/__init__.py b/linear_operator/__init__.py index 
1b2a5e1c..4be9fd28 100644 --- a/linear_operator/__init__.py +++ b/linear_operator/__init__.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +from __future__ import annotations + from linear_operator import beta_features, operators, settings, utils from linear_operator.functions import ( add_diagonal, diff --git a/linear_operator/beta_features.py b/linear_operator/beta_features.py index 1bb84d6b..f4c2299d 100644 --- a/linear_operator/beta_features.py +++ b/linear_operator/beta_features.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import warnings diff --git a/linear_operator/functions/__init__.py b/linear_operator/functions/__init__.py index 8a2940a9..00abf0aa 100644 --- a/linear_operator/functions/__init__.py +++ b/linear_operator/functions/__init__.py @@ -2,16 +2,16 @@ from __future__ import annotations -from typing import Any, Optional, Tuple, Union +from typing import Any, TypeAlias import torch from linear_operator.functions._dsmm import DSMM -LinearOperatorType = Any # Want this to be "LinearOperator" but runtime type checker can't handle +LinearOperatorType: TypeAlias = Any # Want this to be "LinearOperator" but runtime type checker can't handle -Anysor = Union[LinearOperatorType, torch.Tensor] +Anysor: TypeAlias = LinearOperatorType | torch.Tensor def add_diagonal(input: Anysor, diag: torch.Tensor) -> LinearOperatorType: @@ -47,9 +47,7 @@ def add_jitter(input: Anysor, jitter_val: float = 1e-3) -> Anysor: return input + diag -def diagonalization( - input: Anysor, method: Optional[str] = None -) -> Tuple[torch.Tensor, Union[torch.Tensor, LinearOperatorType]]: +def diagonalization(input: Anysor, method: str | None = None) -> tuple[torch.Tensor, torch.Tensor | LinearOperatorType]: r""" Returns a (usually partial) diagonalization of a symmetric positive definite matrix (or batch of matrices). :math:`\mathbf A`. @@ -67,7 +65,7 @@ def diagonalization( def dsmm( - sparse_mat: Union[torch.sparse.HalfTensor, torch.sparse.FloatTensor, torch.sparse.DoubleTensor], + sparse_mat: torch.sparse.HalfTensor | torch.sparse.FloatTensor | torch.sparse.DoubleTensor, dense_mat: torch.Tensor, ) -> torch.Tensor: r""" @@ -111,8 +109,8 @@ def inv_quad(input: Anysor, inv_quad_rhs: torch.Tensor, reduce_inv_quad: bool = def inv_quad_logdet( - input: Anysor, inv_quad_rhs: Optional[torch.Tensor] = None, logdet: bool = False, reduce_inv_quad: bool = True -) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + input: Anysor, inv_quad_rhs: torch.Tensor | None = None, logdet: bool = False, reduce_inv_quad: bool = True +) -> tuple[torch.Tensor | None, torch.Tensor | None]: r""" Calls both :func:`inv_quad_logdet` and :func:`logdet` on a positive definite matrix (or batch) :math:`\mathbf A`. However, calling this method is far more efficient and stable than calling each method independently. @@ -133,8 +131,8 @@ def inv_quad_logdet( def pivoted_cholesky( - input: Anysor, rank: int, error_tol: Optional[float] = None, return_pivots: bool = False -) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + input: Anysor, rank: int, error_tol: float | None = None, return_pivots: bool = False +) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: r""" Performs a partial pivoted Cholesky factorization of a positive definite matrix (or batch of matrices). :math:`\mathbf L \mathbf L^\top = \mathbf A`. 
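For reference, the typing style standardized on in the hunks above is PEP 604 unions (`X | Y` instead of `Optional[X]`/`Union[X, Y]`) plus PEP 613 explicit aliases (`typing.TypeAlias`), with `from __future__ import annotations` keeping the annotations themselves lazily evaluated (the alias assignments are still evaluated at runtime, which is why Python 3.10+ is required). A minimal sketch of the pattern: the two aliases are copied from the hunk above, while `scale_sketch` is a made-up function used only to illustrate the annotation style.

    from __future__ import annotations  # PEP 563: annotations are stored as strings, not evaluated

    from typing import Any, TypeAlias

    import torch

    # PEP 613 explicit aliases, as defined in linear_operator/functions/__init__.py above.
    LinearOperatorType: TypeAlias = Any  # stand-in for "LinearOperator"
    Anysor: TypeAlias = LinearOperatorType | torch.Tensor


    def scale_sketch(x: Anysor, factor: float | None = None) -> torch.Tensor:
        # `float | None` replaces Optional[float]; built-in generics (tuple[...], list[...])
        # replace typing.Tuple/typing.List throughout the patch.
        if factor is None:
            factor = 1.0
        return torch.as_tensor(x) * factor
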
@@ -161,7 +159,7 @@ def pivoted_cholesky( return to_linear_operator(input).pivoted_cholesky(rank=rank, error_tol=error_tol, return_pivots=return_pivots) -def root_decomposition(input: Anysor, method: Optional[str] = None) -> LinearOperatorType: +def root_decomposition(input: Anysor, method: str | None = None) -> LinearOperatorType: r""" Returns a (usually low-rank) root decomposition linear operator of the positive definite matrix (or batch of matrices) :math:`\mathbf A`. @@ -180,9 +178,9 @@ def root_decomposition(input: Anysor, method: Optional[str] = None) -> LinearOpe def root_inv_decomposition( input: Anysor, - initial_vectors: Optional[torch.Tensor] = None, - test_vectors: Optional[torch.Tensor] = None, - method: Optional[str] = None, + initial_vectors: torch.Tensor | None = None, + test_vectors: torch.Tensor | None = None, + method: str | None = None, ) -> LinearOperatorType: r""" Returns a (usually low-rank) inverse root decomposition linear operator @@ -206,7 +204,7 @@ def root_inv_decomposition( ) -def solve(input: Anysor, rhs: torch.Tensor, lhs: Optional[torch.Tensor] = None) -> torch.Tensor: +def solve(input: Anysor, rhs: torch.Tensor, lhs: torch.Tensor | None = None) -> torch.Tensor: r""" Given a positive definite matrix (or batch of matrices) :math:`\mathbf A`, computes a linear solve with right hand side :math:`\mathbf R`: @@ -241,8 +239,8 @@ def solve(input: Anysor, rhs: torch.Tensor, lhs: Optional[torch.Tensor] = None) def sqrt_inv_matmul( - input: Anysor, rhs: torch.Tensor, lhs: Optional[torch.Tensor] = None -) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + input: Anysor, rhs: torch.Tensor, lhs: torch.Tensor | None = None +) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: r""" Given a positive definite matrix (or batch of matrices) :math:`\mathbf A` and a right hand size :math:`\mathbf R`, diff --git a/linear_operator/functions/_diagonalization.py b/linear_operator/functions/_diagonalization.py index 8b7241ab..0180bced 100644 --- a/linear_operator/functions/_diagonalization.py +++ b/linear_operator/functions/_diagonalization.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import torch from torch.autograd import Function diff --git a/linear_operator/functions/_dsmm.py b/linear_operator/functions/_dsmm.py index 8c316e07..8fd2e80f 100644 --- a/linear_operator/functions/_dsmm.py +++ b/linear_operator/functions/_dsmm.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations from torch.autograd import Function diff --git a/linear_operator/functions/_inv_quad.py b/linear_operator/functions/_inv_quad.py index 6d7a611b..57c05349 100644 --- a/linear_operator/functions/_inv_quad.py +++ b/linear_operator/functions/_inv_quad.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import torch from torch.autograd import Function diff --git a/linear_operator/functions/_inv_quad_logdet.py b/linear_operator/functions/_inv_quad_logdet.py index ec3e7896..c128fcfc 100644 --- a/linear_operator/functions/_inv_quad_logdet.py +++ b/linear_operator/functions/_inv_quad_logdet.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import warnings diff --git a/linear_operator/functions/_matmul.py b/linear_operator/functions/_matmul.py index 13cf96ee..82653e92 100644 --- a/linear_operator/functions/_matmul.py +++ b/linear_operator/functions/_matmul.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations from torch.autograd import Function diff --git 
a/linear_operator/functions/_pivoted_cholesky.py b/linear_operator/functions/_pivoted_cholesky.py index 023eb2ae..01235c05 100644 --- a/linear_operator/functions/_pivoted_cholesky.py +++ b/linear_operator/functions/_pivoted_cholesky.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import torch from torch.autograd import Function diff --git a/linear_operator/functions/_root_decomposition.py b/linear_operator/functions/_root_decomposition.py index 5b3fbe58..3c2cc594 100644 --- a/linear_operator/functions/_root_decomposition.py +++ b/linear_operator/functions/_root_decomposition.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import torch from torch.autograd import Function diff --git a/linear_operator/functions/_solve.py b/linear_operator/functions/_solve.py index 8783b68b..ed515ef7 100644 --- a/linear_operator/functions/_solve.py +++ b/linear_operator/functions/_solve.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import torch from torch.autograd import Function diff --git a/linear_operator/functions/_sqrt_inv_matmul.py b/linear_operator/functions/_sqrt_inv_matmul.py index c137c762..6da628b8 100644 --- a/linear_operator/functions/_sqrt_inv_matmul.py +++ b/linear_operator/functions/_sqrt_inv_matmul.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import torch from torch.autograd import Function diff --git a/linear_operator/operators/__init__.py b/linear_operator/operators/__init__.py index 41d093cf..e278956a 100644 --- a/linear_operator/operators/__init__.py +++ b/linear_operator/operators/__init__.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations from linear_operator.operators._linear_operator import LinearOperator, to_dense from linear_operator.operators.added_diag_linear_operator import AddedDiagLinearOperator diff --git a/linear_operator/operators/_linear_operator.py b/linear_operator/operators/_linear_operator.py index 9e814637..d8ced9ea 100644 --- a/linear_operator/operators/_linear_operator.py +++ b/linear_operator/operators/_linear_operator.py @@ -10,7 +10,7 @@ from abc import abstractmethod from collections import OrderedDict from copy import deepcopy -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable import numpy as np import torch @@ -138,7 +138,7 @@ class LinearOperator(object): :attr:`matrix_shape`. """ - def _check_args(self, *args, **kwargs) -> Union[str, None]: + def _check_args(self, *args, **kwargs) -> str | None: """ (Optional) run checks to see that input arguments and kwargs are valid @@ -333,7 +333,7 @@ def _unsqueeze_batch(self, dim: int) -> LinearOperator: #### # The following methods PROBABLY should be over-written by LinearOperator subclasses for efficiency #### - def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[Optional[Tensor], ...]: + def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> tuple[Tensor | None, ...]: r""" Given :math:`\mathbf U` (left_vecs) and :math:`\mathbf V` (right_vecs), Computes the derivatives of (:math:`\mathbf u^\top \mathbf K \mathbf v`) w.r.t. :math:`\mathbf K`. 
@@ -393,7 +393,7 @@ def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[O return tuple(grads) def _expand_batch( - self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + self: LinearOperator, batch_shape: torch.Size | list[int] # shape: (..., M, N) ) -> LinearOperator: # shape: (..., M, N) """ Expands along batch dimensions. Return size will be *batch_shape x *matrix_shape. @@ -469,15 +469,15 @@ def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indice # Standard LinearOperator methods #### @property - def _args(self) -> Tuple[Union[torch.Tensor, "LinearOperator", int], ...]: + def _args(self) -> tuple[torch.Tensor | LinearOperator | int, ...]: return self._args_memo @_args.setter - def _args(self, args: Tuple[Union[torch.Tensor, "LinearOperator", int], ...]) -> None: + def _args(self, args: tuple[torch.Tensor | LinearOperator | int, ...]) -> None: self._args_memo = args @property - def _kwargs(self) -> Dict[str, Any]: + def _kwargs(self) -> dict[str, Any]: return {**self._differentiable_kwargs, **self._nondifferentiable_kwargs} def _approx_diagonal( @@ -498,7 +498,7 @@ def _approx_diagonal( @cached(name="cholesky") def _cholesky( - self: LinearOperator, upper: Optional[bool] = False # shape: (*batch, N, N) + self: LinearOperator, upper: bool | None = False # shape: (*batch, N, N) ) -> LinearOperator: # shape: (*batch, N, N) """ (Optional) Cholesky-factorizes the LinearOperator @@ -529,9 +529,9 @@ def _cholesky( def _cholesky_solve( self: LinearOperator, # shape: (*batch, N, N) - rhs: Union[LinearOperator, Tensor], # shape: (*batch2, N, M) - upper: Optional[bool] = False, - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, M) + rhs: LinearOperator | Tensor, # shape: (*batch2, N, M) + upper: bool | None = False, + ) -> LinearOperator | Tensor: # shape: (..., N, M) """ (Optional) Assuming that `self` is a Cholesky factor, computes the cholesky solve. @@ -576,7 +576,7 @@ def _diagonal( return self[..., row_col_iter, row_col_iter] def _mul_constant( - self: LinearOperator, other: Union[float, torch.Tensor] # shape: (*batch, M, N) + self: LinearOperator, other: float | torch.Tensor # shape: (*batch, M, N) ) -> LinearOperator: # shape: (*batch, M, N) """ Multiplies the LinearOperator by a constant. @@ -594,7 +594,7 @@ def _mul_constant( def _mul_matrix( self: LinearOperator, # shape: (..., #M, #N) - other: Union[torch.Tensor, LinearOperator], # shape: (..., #M, #N) + other: torch.Tensor | LinearOperator, # shape: (..., #M, #N) ) -> LinearOperator: # shape: (..., M, N) r""" Multiplies the LinearOperator by a (batch of) matrices. @@ -615,7 +615,7 @@ def _mul_matrix( else: return MulLinearOperator(self, other) - def _preconditioner(self) -> Tuple[Optional[Callable], Optional[LinearOperator], Optional[torch.Tensor]]: + def _preconditioner(self) -> tuple[Callable | None, LinearOperator | None, torch.Tensor | None]: r""" (Optional) define a preconditioner (:math:`\mathbf P`) for linear conjugate gradients @@ -688,7 +688,7 @@ def _prod_batch(self, dim: int) -> LinearOperator: def _root_decomposition( self: LinearOperator, # shape: (..., N, N) - ) -> Union[torch.Tensor, LinearOperator]: # shape: (..., N, N) + ) -> torch.Tensor | LinearOperator: # shape: (..., N, N) """ Returns the (usually low-rank) root of a LinearOperator of a PSD matrix. 
@@ -722,9 +722,9 @@ def _root_decomposition_size(self) -> int: def _root_inv_decomposition( self: LinearOperator, # shape: (*batch, N, N) - initial_vectors: Optional[torch.Tensor] = None, - test_vectors: Optional[torch.Tensor] = None, - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, N) + initial_vectors: torch.Tensor | None = None, + test_vectors: torch.Tensor | None = None, + ) -> LinearOperator | Tensor: # shape: (..., N, N) r""" Returns the (usually low-rank) inverse root of a LinearOperator of a PSD matrix. @@ -781,15 +781,15 @@ def _set_requires_grad(self, val: bool) -> None: def _solve( self: LinearOperator, # shape: (..., N, N) rhs: torch.Tensor, # shape: (..., N, C) - preconditioner: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, # shape: (..., N, C) - num_tridiag: Optional[int] = 0, - ) -> Union[ - torch.Tensor, # shape: (..., N, C) - Tuple[ + preconditioner: Callable[[torch.Tensor], torch.Tensor] | None = None, # shape: (..., N, C) + num_tridiag: int | None = 0, + ) -> ( + torch.Tensor # shape: (..., N, C) + | tuple[ torch.Tensor, # shape: (..., N, C) torch.Tensor, # Note that in case of a tuple the second term size depends on num_tridiag # shape: (...) - ], - ]: + ] + ): r""" TODO """ @@ -802,7 +802,7 @@ def _solve( preconditioner=preconditioner, ) - def _solve_preconditioner(self) -> Optional[Callable]: + def _solve_preconditioner(self) -> Callable | None: r""" (Optional) define a preconditioner :math:`\mathbf P` that can be used for linear systems, but not necessarily for log determinants. By default, this can call @@ -864,7 +864,7 @@ def _sum_batch(self, dim: int) -> LinearOperator: @cached(name="svd") def _svd( self: LinearOperator, # shape: (*batch, N, N) - ) -> Tuple[LinearOperator, Tensor, LinearOperator]: # shape: (*batch, N, N), (..., N), (*batch, N, N) + ) -> tuple[LinearOperator, Tensor, LinearOperator]: # shape: (*batch, N, N), (..., N), (*batch, N, N) """Method that allows implementing special-cased SVD computation. Should not be called directly""" # Using symeig is preferable here for psd LinearOperators. # Will need to overwrite this function for non-psd LinearOperators. @@ -878,8 +878,8 @@ def _svd( def _symeig( self: LinearOperator, # shape: (*batch, N, N) eigenvectors: bool = False, - return_evals_as_lazy: Optional[bool] = False, - ) -> Tuple[Tensor, Optional[LinearOperator]]: # shape: (*batch, M), (*batch, N, M) + return_evals_as_lazy: bool | None = False, + ) -> tuple[Tensor, LinearOperator | None]: # shape: (*batch, M), (*batch, N, M) r""" Method that allows implementing special-cased symeig computation. Should not be called directly """ @@ -902,8 +902,8 @@ def _symeig( def _t_matmul( self: LinearOperator, # shape: (*batch, M, N) - rhs: Union[Tensor, LinearOperator], # shape: (*batch2, M, P) - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, P) + rhs: Tensor | LinearOperator, # shape: (*batch2, M, P) + ) -> LinearOperator | Tensor: # shape: (..., N, P) r""" Performs a transpose matrix multiplication :math:`\mathbf K^\top \mathbf M` with the (... x M x N) matrix :math:`\mathbf K` that this LinearOperator represents. 
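A note on the `_solve` hunk above (the block-operator files later in this patch use the same construct): the `Union[...]` return type becomes a parenthesized PEP 604 union so the per-branch shape comments can stay inline. A self-contained sketch of that construct follows, with a made-up body standing in for the real iterative solve.

    from __future__ import annotations

    import torch


    def solve_sketch(
        rhs: torch.Tensor,
        num_tridiag: int | None = 0,
    ) -> (
        torch.Tensor  # returned when num_tridiag is falsy
        | tuple[
            torch.Tensor,  # solution
            torch.Tensor,  # tridiagonal terms; size depends on num_tridiag
        ]
    ):
        # Illustrative body only; the real LinearOperator._solve runs preconditioned CG.
        if not num_tridiag:
            return rhs
        return rhs, torch.zeros(num_tridiag)
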
@@ -929,7 +929,7 @@ def abs(self) -> LinearOperator: @_implements_symmetric(torch.add) def add( self: LinearOperator, # shape: (*batch, M, N) - other: Union[Tensor, LinearOperator], # shape: (*batch, M, N) + other: Tensor | LinearOperator, # shape: (*batch, M, N) alpha: float = None, ) -> LinearOperator: # shape: (*batch, M, N) r""" @@ -1018,10 +1018,10 @@ def add_jitter( def add_low_rank( self: LinearOperator, # shape: (*batch, N, N) - low_rank_mat: Union[Tensor, LinearOperator], # shape: (..., N, _) - root_decomp_method: Optional[str] = None, - root_inv_decomp_method: Optional[str] = None, - generate_roots: Optional[bool] = True, + low_rank_mat: Tensor | LinearOperator, # shape: (..., N, _) + root_decomp_method: str | None = None, + root_inv_decomp_method: str | None = None, + generate_roots: bool | None = True, **root_decomp_kwargs, ) -> LinearOperator: # returns SumLinearOperator # shape: (*batch, N, N) r""" @@ -1346,7 +1346,7 @@ def cpu( return self.__class__(*new_args, **new_kwargs) def cuda( - self: LinearOperator, device_id: Optional[str] = None # shape: (*batch, M, N) + self: LinearOperator, device_id: str | None = None # shape: (*batch, M, N) ) -> LinearOperator: # shape: (*batch, M, N) """ This method operates identically to :func:`torch.nn.Module.cuda`. @@ -1368,7 +1368,7 @@ def cuda( return self.__class__(*new_args, **new_kwargs) @property - def device(self) -> Optional[torch.device]: + def device(self) -> torch.device | None: return self._args[0].device def detach( @@ -1429,8 +1429,8 @@ def diagonal( @cached(name="diagonalization") def diagonalization( - self: LinearOperator, method: Optional[str] = None # shape: (*batch, N, N) - ) -> Tuple[Tensor, Optional[LinearOperator]]: # shape: (*batch, M), (*batch, N, M) + self: LinearOperator, method: str | None = None # shape: (*batch, N, N) + ) -> tuple[Tensor, LinearOperator | None]: # shape: (*batch, M), (*batch, N, M) """ Returns a (usually partial) diagonalization of a symmetric PSD matrix. Options are either "lanczos" or "symeig". "lanczos" runs Lanczos while @@ -1480,7 +1480,7 @@ def dim(self) -> int: return self.ndimension() @_implements(torch.div) - def div(self, other: Union[float, torch.Tensor]) -> LinearOperator: + def div(self, other: float | torch.Tensor) -> LinearOperator: """ Returns the product of this LinearOperator the elementwise reciprocal of another matrix. @@ -1496,7 +1496,7 @@ def div(self, other: Union[float, torch.Tensor]) -> LinearOperator: return self.mul(1.0 / other) def double( - self: LinearOperator, device_id: Optional[str] = None # shape: (*batch, M, N) + self: LinearOperator, device_id: str | None = None # shape: (*batch, M, N) ) -> LinearOperator: # shape: (*batch, M, N) """ This method operates identically to :func:`torch.Tensor.double`. @@ -1506,13 +1506,13 @@ def double( return self.type(torch.double) @property - def dtype(self) -> Optional[torch.dtype]: + def dtype(self) -> torch.dtype | None: return self._args[0].dtype @_implements(torch.linalg.eigh) def eigh( self: LinearOperator, # shape: (*batch, N, N) - ) -> Tuple[Tensor, Optional[LinearOperator]]: # shape: (*batch, N), (*batch, N, N) + ) -> tuple[Tensor, LinearOperator | None]: # shape: (*batch, N), (*batch, N, N) """ Compute the symmetric eigendecomposition of the linear operator. This can be very slow for large tensors. 
@@ -1535,7 +1535,7 @@ def eigh( @_implements(torch.linalg.eigvalsh) def eigvalsh( self: LinearOperator, # shape: (*batch, N, N) - ) -> Union[Tensor, Tuple[Tensor, Optional[LinearOperator]]]: # shape: (*batch, N) or (*batch, N, N) + ) -> Tensor | tuple[Tensor, LinearOperator | None]: # shape: (*batch, N) or (*batch, N, N) """ Compute the eigenvalues of symmetric linear operator. This can be very slow for large tensors. @@ -1569,7 +1569,7 @@ def exp( # We define it here so that we can map the torch function torch.exp to the LinearOperator method raise NotImplementedError(f"torch.exp({self.__class__.__name__}) is not implemented.") - def expand(self, *sizes: Union[torch.Size, int]) -> LinearOperator: + def expand(self, *sizes: torch.Size | int) -> LinearOperator: r""" Returns a new view of the self :obj:`~linear_operator.operators.LinearOperator` with singleton @@ -1607,7 +1607,7 @@ def expand(self, *sizes: Union[torch.Size, int]) -> LinearOperator: return res def float( - self: LinearOperator, device_id: Optional[str] = None # shape: (*batch, M, N) + self: LinearOperator, device_id: str | None = None # shape: (*batch, M, N) ) -> LinearOperator: # shape: (*batch, M, N) """ This method operates identically to :func:`torch.Tensor.float`. @@ -1617,7 +1617,7 @@ def float( return self.type(torch.float) def half( - self: LinearOperator, device_id: Optional[str] = None # shape: (*batch, M, N) + self: LinearOperator, device_id: str | None = None # shape: (*batch, M, N) ) -> LinearOperator: # shape: (*batch, M, N) """ This method operates identically to :func:`torch.Tensor.half`. @@ -1679,12 +1679,12 @@ def inv_quad( def inv_quad_logdet( self: LinearOperator, # shape: (*batch, N, N) - inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) - logdet: Optional[bool] = False, - reduce_inv_quad: Optional[bool] = True, - ) -> Tuple[ # fmt: off - Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) - Optional[Tensor], # shape: (...) + inv_quad_rhs: Tensor | None = None, # shape: (*batch, N, M) or (*batch, N) + logdet: bool | None = False, + reduce_inv_quad: bool | None = True, + ) -> tuple[ # fmt: off + Tensor | None, # shape: (*batch, M) or (*batch) or (0) + Tensor | None, # shape: (...) ]: # fmt: on r""" Calls both :func:`inv_quad` and :func:`logdet` on a positive @@ -1832,8 +1832,8 @@ def logdet( @_implements(torch.matmul) def matmul( self: LinearOperator, # shape: (*batch, M, N) - other: Union[Tensor, LinearOperator], # shape: (*batch2, N, P) or (*batch2, N) - ) -> Union[Tensor, LinearOperator]: # shape: (..., M, P) or (..., M) + other: Tensor | LinearOperator, # shape: (*batch2, N, P) or (*batch2, N) + ) -> Tensor | LinearOperator: # shape: (..., M, P) or (..., M) r""" Performs :math:`\mathbf A \mathbf B`, where :math:`\mathbf A \in \mathbb R^{M \times N}` is the LinearOperator and :math:`\mathbf B` @@ -1869,7 +1869,7 @@ def mT( @_implements_symmetric(torch.mul) def mul( self: LinearOperator, # shape: (*batch, M, N) - other: Union[float, Tensor, LinearOperator], # shape: (*batch2, M, N) + other: float | Tensor | LinearOperator, # shape: (*batch2, M, N) ) -> LinearOperator: # shape: (..., M, N) """ Multiplies the matrix by a constant, or elementwise the matrix by another matrix. 
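Since the `inv_quad_logdet` hunk above now advertises a `tuple[Tensor | None, Tensor | None]` return, a short usage reminder may help; the matrix and right-hand side below are invented for illustration, and only the call itself follows the documented signature.

    import torch
    from linear_operator.operators import DenseLinearOperator

    # Invented SPD example matrix, wrapped as a LinearOperator.
    mat = torch.randn(5, 3)
    A = DenseLinearOperator(mat @ mat.mT + 0.1 * torch.eye(5))

    rhs = torch.randn(5, 2)
    inv_quad, logdet = A.inv_quad_logdet(inv_quad_rhs=rhs, logdet=True)
    # inv_quad: tr(rhs^T A^{-1} rhs), since reduce_inv_quad defaults to True; logdet: log det(A)
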
@@ -1928,7 +1928,7 @@ def numpy(self) -> np.ndarray: return self.to_dense().detach().cpu().numpy() @_implements(torch.permute) - def permute(self, *dims: Union[int, Tuple[int, ...]]) -> LinearOperator: + def permute(self, *dims: int | tuple[int, ...]) -> LinearOperator: """ Returns a view of the original tensor with its dimensions permuted. @@ -1963,9 +1963,9 @@ def permute(self, *dims: Union[int, Tuple[int, ...]]) -> LinearOperator: def pivoted_cholesky( self: LinearOperator, # shape: (*batch, N, N) rank: int, - error_tol: Optional[float] = None, + error_tol: float | None = None, return_pivots: bool = False, - ) -> Union[Tensor, Tuple[Tensor, Tensor]]: # shape: (*batch, N, R), (*batch, N, R), (*batch, N) + ) -> Tensor | tuple[Tensor, Tensor]: # shape: (*batch, N, R), (*batch, N, R), (*batch, N) r""" Performs a partial pivoted Cholesky factorization of the (positive definite) LinearOperator. :math:`\mathbf L \mathbf L^\top = \mathbf K`. @@ -1996,7 +1996,7 @@ def pivoted_cholesky( # TODO: implement keepdim @_implements(torch.prod) - def prod(self, dim: int) -> Union[LinearOperator, torch.Tensor]: + def prod(self, dim: int) -> LinearOperator | torch.Tensor: r""" Returns the product of each row of :math:`\mathbf A` along the batch dimension :attr:`dim`. @@ -2027,7 +2027,7 @@ def prod(self, dim: int) -> Union[LinearOperator, torch.Tensor]: return self._prod_batch(dim) - def repeat(self, *sizes: Union[int, Tuple[int, ...]]) -> LinearOperator: + def repeat(self, *sizes: int | tuple[int, ...]) -> LinearOperator: """ Repeats this tensor along the specified dimensions. @@ -2061,7 +2061,7 @@ def repeat(self, *sizes: Union[int, Tuple[int, ...]]) -> LinearOperator: return BatchRepeatLinearOperator(self, batch_repeat=torch.Size(sizes[:-2])) # TODO: make this method private - def representation(self) -> Tuple[torch.Tensor, ...]: + def representation(self) -> tuple[torch.Tensor, ...]: """ Returns the Tensors that are used to define the LinearOperator """ @@ -2113,7 +2113,7 @@ def requires_grad_(self, val: bool) -> LinearOperator: self._set_requires_grad(val) return self - def reshape(self, *sizes: Union[torch.Size, int, Tuple[int, ...]]) -> LinearOperator: + def reshape(self, *sizes: torch.Size | int | tuple[int, ...]) -> LinearOperator: """ Alias for expand """ @@ -2126,8 +2126,8 @@ def reshape(self, *sizes: Union[torch.Size, int, Tuple[int, ...]]) -> LinearOper @_implements_second_arg(torch.matmul) def rmatmul( self: LinearOperator, # shape: (..., M, N) - other: Union[Tensor, LinearOperator], # shape: (..., P, M) or (..., M) - ) -> Union[Tensor, LinearOperator]: # shape: (..., P, N) or (N) + other: Tensor | LinearOperator, # shape: (..., P, M) or (..., M) + ) -> Tensor | LinearOperator: # shape: (..., P, N) or (N) r""" Performs :math:`\mathbf B \mathbf A`, where :math:`\mathbf A \in \mathbb R^{M \times N}` is the LinearOperator and :math:`\mathbf B` @@ -2144,7 +2144,7 @@ def rmatmul( @cached(name="root_decomposition") def root_decomposition( - self: LinearOperator, method: Optional[str] = None # shape: (*batch, N, N) + self: LinearOperator, method: str | None = None # shape: (*batch, N, N) ) -> LinearOperator: # shape: (*batch, N, N) r""" Returns a (usually low-rank) root decomposition linear operator of the PSD LinearOperator :math:`\mathbf A`. 
@@ -2208,10 +2208,10 @@ def root_decomposition( @cached(name="root_inv_decomposition") def root_inv_decomposition( self: LinearOperator, # shape: (*batch, N, N) - initial_vectors: Optional[torch.Tensor] = None, - test_vectors: Optional[torch.Tensor] = None, - method: Optional[str] = None, - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, N) + initial_vectors: torch.Tensor | None = None, + test_vectors: torch.Tensor | None = None, + method: str | None = None, + ) -> LinearOperator | Tensor: # shape: (..., N, N) r""" Returns a (usually low-rank) inverse root decomposition linear operator of the PSD LinearOperator :math:`\mathbf A`. @@ -2294,7 +2294,7 @@ def root_inv_decomposition( return RootLinearOperator(inv_root) - def size(self, dim: Optional[int] = None) -> Union[torch.Size, int]: + def size(self, dim: int | None = None) -> torch.Size | int: """ Returns he size of the LinearOperator (or the specified dimension). @@ -2313,7 +2313,7 @@ def shape(self) -> torch.Size: def solve( self: LinearOperator, # shape: (..., N, N) right_tensor: Tensor, # shape: (..., N, P) or (N) - left_tensor: Optional[Tensor] = None, # shape: (..., O, N) + left_tensor: Tensor | None = None, # shape: (..., O, N) ) -> Tensor: # shape: (..., N, P) or (..., N) or (..., O, P) or (..., O) r""" Computes a linear solve (w.r.t self = :math:`\mathbf A`) with right hand side :math:`\mathbf R`. @@ -2410,8 +2410,8 @@ def sqrt( def sqrt_inv_matmul( self: LinearOperator, # shape: (*batch, N, N) rhs: Tensor, # shape: (*batch, N, P) - lhs: Optional[Tensor] = None, # shape: (*batch, O, N) - ) -> Union[Tensor, Tuple[Tensor, Tensor]]: # shape: (*batch, N, P), (*batch, O, P), (*batch, O) + lhs: Tensor | None = None, # shape: (*batch, O, N) + ) -> Tensor | tuple[Tensor, Tensor]: # shape: (*batch, N, P), (*batch, O, P), (*batch, O) r""" If the LinearOperator :math:`\mathbf A` is positive definite, computes @@ -2454,7 +2454,7 @@ def sqrt_inv_matmul( return sqrt_inv_matmul_res, inv_quad_res @_implements(torch.squeeze) - def squeeze(self, dim: int) -> Union[LinearOperator, torch.Tensor]: + def squeeze(self, dim: int) -> LinearOperator | torch.Tensor: """ Removes the singleton dimension of a LinearOperator specifed by :attr:`dim`. @@ -2473,7 +2473,7 @@ def squeeze(self, dim: int) -> Union[LinearOperator, torch.Tensor]: @_implements(torch.sub) def sub( self: LinearOperator, # shape: (*batch, M, N) - other: Union[Tensor, LinearOperator], # shape: (*batch, M, N) + other: Tensor | LinearOperator, # shape: (*batch, M, N) alpha: float = None, ) -> LinearOperator: # shape: (*batch, M, N) r""" @@ -2495,7 +2495,7 @@ def sub( return self + (alpha * -1) * other @_implements(torch.sum) - def sum(self, dim: Optional[int] = None) -> Union[LinearOperator, torch.Tensor]: + def sum(self, dim: int | None = None) -> LinearOperator | torch.Tensor: """ Sum the LinearOperator across a dimension. The `dim` controls which batch dimension is summed over. @@ -2539,7 +2539,7 @@ def sum(self, dim: Optional[int] = None) -> Union[LinearOperator, torch.Tensor]: def svd( self: LinearOperator, # shape: (*batch, N, N) - ) -> Tuple[LinearOperator, Tensor, LinearOperator]: # shape: (*batch, N, N), (..., N), (*batch, N, N) + ) -> tuple[LinearOperator, Tensor, LinearOperator]: # shape: (*batch, N, N), (..., N), (*batch, N, N) r""" Compute the SVD of the linear operator :math:`\mathbf A \in \mathbb R^{M \times N}` s.t. :math:`\mathbf A = \mathbf{U S V^\top}`. 
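Similarly, `solve` above keeps its optional `left_tensor`, now typed `Tensor | None`. A hedged usage sketch (the shapes are invented; `to_linear_operator` is imported from the module path shown elsewhere in this patch):

    import torch
    from linear_operator.operators.dense_linear_operator import to_linear_operator

    # Invented example: solve A X = R, optionally projecting with a left-hand matrix L.
    A = to_linear_operator(2.0 * torch.eye(3))
    R = torch.randn(3, 2)
    L = torch.randn(4, 3)

    X = A.solve(R)      # shape (3, 2); same result as torch.linalg.solve(A.to_dense(), R)
    LX = A.solve(R, L)  # shape (4, 2); computes L @ A^{-1} @ R in a single call
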
@@ -2559,7 +2559,7 @@ def svd( @_implements(torch.linalg.svd) def _torch_linalg_svd( self: LinearOperator, # shape: (*batch, N, N) - ) -> Tuple[LinearOperator, Tensor, LinearOperator]: # shape: (*batch, N, N), (..., N), (*batch, N, N) + ) -> tuple[LinearOperator, Tensor, LinearOperator]: # shape: (*batch, N, N), (..., N), (*batch, N, N) r""" A version of self.svd() that matches the torch.linalg.svd API. @@ -2782,14 +2782,14 @@ def zero_mean_mvn_samples( def __sub__( self: LinearOperator, # shape: (..., #M, #N) - other: Union[Tensor, LinearOperator, float], # shape: (..., #M, #N) - ) -> Union[LinearOperator, Tensor]: # shape: (..., M, N) + other: Tensor | LinearOperator | float, # shape: (..., #M, #N) + ) -> LinearOperator | Tensor: # shape: (..., M, N) return self + other.mul(-1) def __add__( self: LinearOperator, # shape: (..., #M, #N) - other: Union[Tensor, LinearOperator, float], # shape: (..., #M, #N) - ) -> Union[LinearOperator, Tensor]: # shape: (..., M, N) + other: Tensor | LinearOperator | float, # shape: (..., #M, #N) + ) -> LinearOperator | Tensor: # shape: (..., M, N) from linear_operator.operators.added_diag_linear_operator import AddedDiagLinearOperator from linear_operator.operators.dense_linear_operator import to_linear_operator from linear_operator.operators.diag_linear_operator import DiagLinearOperator @@ -2814,7 +2814,7 @@ def __add__( else: return SumLinearOperator(self, other) - def __getitem__(self, index: Union[IndexType, Tuple[IndexType, ...]]) -> Union[LinearOperator, torch.Tensor]: + def __getitem__(self, index: IndexType | tuple[IndexType, ...]) -> LinearOperator | torch.Tensor: ndimension = self.ndimension() # Process the index @@ -2927,34 +2927,34 @@ def _isclose(self, other, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: b def __matmul__( self: LinearOperator, # shape: (*batch, M, N) - other: Union[torch.Tensor, LinearOperator], # shape: (*batch2, N, D) or (N) - ) -> Union[torch.Tensor, LinearOperator]: # shape: (..., M, D) or (..., M) + other: torch.Tensor | LinearOperator, # shape: (*batch2, N, D) or (N) + ) -> torch.Tensor | LinearOperator: # shape: (..., M, D) or (..., M) return self.matmul(other) @_implements_second_arg(torch.Tensor.matmul) def __rmatmul__( self: LinearOperator, # shape: (..., M, N) - other: Union[Tensor, LinearOperator], # shape: (..., P, M) or (..., M) - ) -> Union[Tensor, LinearOperator]: # shape: (..., P, N) or (..., N) + other: Tensor | LinearOperator, # shape: (..., P, M) or (..., M) + ) -> Tensor | LinearOperator: # shape: (..., P, N) or (..., N) return self.rmatmul(other) @_implements_second_arg(torch.Tensor.mul) def __mul__( self: LinearOperator, # shape: (*batch, #M, #N) - other: Union[torch.Tensor, LinearOperator, float], # shape: (*batch2, #M, #N) + other: torch.Tensor | LinearOperator | float, # shape: (*batch2, #M, #N) ) -> LinearOperator: # shape: (..., M, N) return self.mul(other) @_implements_second_arg(torch.Tensor.add) def __radd__( self: LinearOperator, # shape: (*batch, #M, #N) - other: Union[torch.Tensor, LinearOperator, float], # shape: (*batch2, #M, #N) + other: torch.Tensor | LinearOperator | float, # shape: (*batch2, #M, #N) ) -> LinearOperator: # shape: (..., M, N) return self + other def __rmul__( self: LinearOperator, # shape: (*batch, #M, #N) - other: Union[torch.Tensor, LinearOperator, float], # shape: (*batch2, #M, #N) + other: torch.Tensor | LinearOperator | float, # shape: (*batch2, #M, #N) ) -> LinearOperator: # shape: (..., M, N) return self.mul(other) @@ -2962,13 +2962,13 @@ def __rmul__( 
@_implements_second_arg(torch.Tensor.sub) def __rsub__( self: LinearOperator, # shape: (*batch, #M, #N) - other: Union[torch.Tensor, LinearOperator, float], # shape: (*batch2, #M, #N) + other: torch.Tensor | LinearOperator | float, # shape: (*batch2, #M, #N) ) -> LinearOperator: # shape: (..., M, N) return self.mul(-1) + other @classmethod def __torch_function__( - cls, func: Callable, types: Tuple[type, ...], args: Tuple[Any, ...] = (), kwargs: Dict[str, Any] = None + cls, func: Callable, types: tuple[type, ...], args: tuple[Any, ...] = (), kwargs: dict[str, Any] = None ) -> Any: if kwargs is None: kwargs = {} @@ -2996,7 +2996,7 @@ def __torch_function__( func = getattr(cls, _HANDLED_FUNCTIONS[func]) return func(*args, **kwargs) - def __truediv__(self, other: Union[torch.Tensor, float]) -> LinearOperator: + def __truediv__(self, other: torch.Tensor | float) -> LinearOperator: return self.div(other) @@ -3008,7 +3008,7 @@ def _import_dotted_name(name: str): return obj -def to_dense(obj: Union[LinearOperator, Tensor]) -> Tensor: +def to_dense(obj: LinearOperator | Tensor) -> Tensor: r""" A function which ensures that `obj` is a (normal) Tensor. - If `obj` is a Tensor, this function does nothing. diff --git a/linear_operator/operators/added_diag_linear_operator.py b/linear_operator/operators/added_diag_linear_operator.py index a75342da..1fb4728a 100644 --- a/linear_operator/operators/added_diag_linear_operator.py +++ b/linear_operator/operators/added_diag_linear_operator.py @@ -3,7 +3,7 @@ from __future__ import annotations import warnings -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable import torch from torch import Tensor @@ -35,8 +35,8 @@ class AddedDiagLinearOperator(SumLinearOperator): def __init__( self, - *linear_ops: Union[Tuple[LinearOperator, DiagLinearOperator], Tuple[DiagLinearOperator, LinearOperator]], - preconditioner_override: Optional[Callable] = None, + *linear_ops: tuple[LinearOperator, DiagLinearOperator] | tuple[DiagLinearOperator, LinearOperator], + preconditioner_override: Callable | None = None, ): linear_ops = list(linear_ops) super(AddedDiagLinearOperator, self).__init__(*linear_ops, preconditioner_override=preconditioner_override) @@ -83,8 +83,8 @@ def add_diagonal( def __add__( self: LinearOperator, # shape: (..., #M, #N) - other: Union[Tensor, LinearOperator, float], # shape: (..., #M, #N) - ) -> Union[LinearOperator, Tensor]: # shape: (..., M, N) + other: Tensor | LinearOperator | float, # shape: (..., #M, #N) + ) -> LinearOperator | Tensor: # shape: (..., M, N) from linear_operator.operators.diag_linear_operator import DiagLinearOperator if isinstance(other, DiagLinearOperator): @@ -92,7 +92,7 @@ def __add__( else: return self.__class__(self._linear_op + other, self._diag_tensor) - def _preconditioner(self) -> Tuple[Optional[Callable], Optional[LinearOperator], Optional[torch.Tensor]]: + def _preconditioner(self) -> tuple[Callable | None, LinearOperator | None, torch.Tensor | None]: r""" Here we use a partial pivoted Cholesky preconditioner: @@ -158,7 +158,7 @@ def _init_cache(self): self._precond_lt = PsdSumLinearOperator(RootLinearOperator(self._piv_chol_self), self._diag_tensor) - def _init_cache_for_constant_diag(self, eye: Tensor, batch_shape: Union[torch.Size, List[int]], n: int, k: int): + def _init_cache_for_constant_diag(self, eye: Tensor, batch_shape: torch.Size | list[int], n: int, k: int): # We can factor out the noise for for both QR and solves. 
self._noise = self._noise.narrow(-2, 0, 1) self._q_cache, self._r_cache = torch.linalg.qr( @@ -171,7 +171,7 @@ def _init_cache_for_constant_diag(self, eye: Tensor, batch_shape: Union[torch.Si logdet = logdet + (n - k) * self._noise.squeeze(-2).squeeze(-1).log() self._precond_logdet_cache = logdet.view(*batch_shape) if len(batch_shape) else logdet.squeeze() - def _init_cache_for_non_constant_diag(self, eye: Tensor, batch_shape: Union[torch.Size, List[int]], n: int): + def _init_cache_for_non_constant_diag(self, eye: Tensor, batch_shape: torch.Size | list[int], n: int): # With non-constant diagonals, we cant factor out the noise as easily self._q_cache, self._r_cache = torch.linalg.qr( torch.cat((self._piv_chol_self / self._noise.sqrt(), eye), dim=-2) @@ -186,7 +186,7 @@ def _init_cache_for_non_constant_diag(self, eye: Tensor, batch_shape: Union[torc @cached(name="svd") def _svd( self: LinearOperator, # shape: (*batch, N, N) - ) -> Tuple[LinearOperator, Tensor, LinearOperator]: # shape: (*batch, N, N), (..., N), (*batch, N, N) + ) -> tuple[LinearOperator, Tensor, LinearOperator]: # shape: (*batch, N, N), (..., N), (*batch, N, N) if isinstance(self._diag_tensor, ConstantDiagLinearOperator): U, S_, V = self._linear_op.svd() S = S_ + self._diag_tensor._diagonal() @@ -196,8 +196,8 @@ def _svd( def _symeig( self: LinearOperator, # shape: (*batch, N, N) eigenvectors: bool = False, - return_evals_as_lazy: Optional[bool] = False, - ) -> Tuple[Tensor, Optional[LinearOperator]]: # shape: (*batch, M), (*batch, N, M) + return_evals_as_lazy: bool | None = False, + ) -> tuple[Tensor, LinearOperator | None]: # shape: (*batch, M), (*batch, N, M) if isinstance(self._diag_tensor, ConstantDiagLinearOperator): evals_, evecs = self._linear_op._symeig(eigenvectors=eigenvectors) evals = evals_ + self._diag_tensor._diagonal() diff --git a/linear_operator/operators/batch_repeat_linear_operator.py b/linear_operator/operators/batch_repeat_linear_operator.py index 333e8f5c..4dc40f54 100644 --- a/linear_operator/operators/batch_repeat_linear_operator.py +++ b/linear_operator/operators/batch_repeat_linear_operator.py @@ -2,7 +2,6 @@ from __future__ import annotations import itertools -from typing import List, Optional, Tuple, Union import torch from torch import Tensor @@ -39,7 +38,7 @@ def __init__(self, base_linear_op, batch_repeat=torch.Size((1,))): @cached(name="cholesky") def _cholesky( - self: LinearOperator, upper: Optional[bool] = False # shape: (*batch, N, N) + self: LinearOperator, upper: bool | None = False # shape: (*batch, N, N) ) -> LinearOperator: # shape: (*batch, N, N) from linear_operator.operators.triangular_linear_operator import TriangularLinearOperator @@ -49,9 +48,9 @@ def _cholesky( def _cholesky_solve( self: LinearOperator, # shape: (*batch, N, N) - rhs: Union[LinearOperator, Tensor], # shape: (*batch2, N, M) - upper: Optional[bool] = False, - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, M) + rhs: LinearOperator | Tensor, # shape: (*batch2, N, M) + upper: bool | None = False, + ) -> LinearOperator | Tensor: # shape: (..., N, M) # TODO: Figure out how to deal with this with TriangularLinearOperator if returned by _cholesky output_shape = _matmul_broadcast_shape(self.shape, rhs.shape) if rhs.shape != output_shape: @@ -63,7 +62,7 @@ def _cholesky_solve( return res def _compute_batch_repeat_size( - self, current_batch_shape: Union[torch.Size, List[int]], desired_batch_shape: Union[torch.Size, List[int]] + self, current_batch_shape: torch.Size | list[int], desired_batch_shape: torch.Size | 
list[int] ) -> torch.Size: batch_repeat = torch.Size( desired_batch_size // current_batch_size @@ -72,7 +71,7 @@ def _compute_batch_repeat_size( return batch_repeat def _expand_batch( - self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + self: LinearOperator, batch_shape: torch.Size | list[int] # shape: (..., M, N) ) -> LinearOperator: # shape: (..., M, N) padding_dims = torch.Size(tuple(1 for _ in range(max(len(batch_shape) + 2 - self.base_linear_op.dim(), 0)))) current_batch_shape = padding_dims + self.base_linear_op.batch_shape @@ -197,7 +196,7 @@ def _permute_batch(self, *dims: int) -> LinearOperator: res = self.__class__(self.base_linear_op._permute_batch(*dims), batch_repeat=new_batch_repeat) return res - def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[Optional[Tensor], ...]: + def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> tuple[Tensor | None, ...]: if self.is_square: left_output_shape = _matmul_broadcast_shape(self.shape, left_vecs.shape) if left_output_shape != left_vecs.shape: @@ -216,14 +215,14 @@ def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[O def _root_decomposition( self: LinearOperator, # shape: (..., N, N) - ) -> Union[torch.Tensor, LinearOperator]: # shape: (..., N, N) + ) -> torch.Tensor | LinearOperator: # shape: (..., N, N) return self.base_linear_op._root_decomposition().repeat(*self.batch_repeat, 1, 1) def _root_inv_decomposition( self: LinearOperator, # shape: (*batch, N, N) - initial_vectors: Optional[torch.Tensor] = None, - test_vectors: Optional[torch.Tensor] = None, - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, N) + initial_vectors: torch.Tensor | None = None, + test_vectors: torch.Tensor | None = None, + ) -> LinearOperator | Tensor: # shape: (..., N, N) return self.base_linear_op._root_inv_decomposition().repeat(*self.batch_repeat, 1, 1) def _size(self) -> torch.Size: @@ -257,12 +256,12 @@ def add_jitter( def inv_quad_logdet( self: LinearOperator, # shape: (*batch, N, N) - inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) - logdet: Optional[bool] = False, - reduce_inv_quad: Optional[bool] = True, - ) -> Tuple[ # fmt: off - Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) - Optional[Tensor], # shape: (...) + inv_quad_rhs: Tensor | None = None, # shape: (*batch, N, M) or (*batch, N) + logdet: bool | None = False, + reduce_inv_quad: bool | None = True, + ) -> tuple[ # fmt: off + Tensor | None, # shape: (*batch, M) or (*batch) or (0) + Tensor | None, # shape: (...) ]: # fmt: on if not self.is_square: raise RuntimeError( @@ -302,7 +301,7 @@ def inv_quad_logdet( return inv_quad_term, logdet_term - def repeat(self, *sizes: Union[int, Tuple[int, ...]]) -> LinearOperator: + def repeat(self, *sizes: int | tuple[int, ...]) -> LinearOperator: if len(sizes) < 3 or tuple(sizes[-2:]) != (1, 1): raise RuntimeError( "Invalid repeat arguments {}. 
Currently, repeat only works to create repeated " @@ -321,7 +320,7 @@ def repeat(self, *sizes: Union[int, Tuple[int, ...]]) -> LinearOperator: @cached(name="svd") def _svd( self: LinearOperator, # shape: (*batch, N, N) - ) -> Tuple[LinearOperator, Tensor, LinearOperator]: # shape: (*batch, N, N), (..., N), (*batch, N, N) + ) -> tuple[LinearOperator, Tensor, LinearOperator]: # shape: (*batch, N, N), (..., N), (*batch, N, N) U_, S_, V_ = self.base_linear_op.svd() U = U_.repeat(*self.batch_repeat, 1, 1) S = S_.repeat(*self.batch_repeat, 1) @@ -331,8 +330,8 @@ def _svd( def _symeig( self: LinearOperator, # shape: (*batch, N, N) eigenvectors: bool = False, - return_evals_as_lazy: Optional[bool] = False, - ) -> Tuple[Tensor, Optional[LinearOperator]]: # shape: (*batch, M), (*batch, N, M) + return_evals_as_lazy: bool | None = False, + ) -> tuple[Tensor, LinearOperator | None]: # shape: (*batch, M), (*batch, N, M) evals, evecs = self.base_linear_op._symeig(eigenvectors=eigenvectors) evals = evals.repeat(*self.batch_repeat, 1) if eigenvectors: diff --git a/linear_operator/operators/block_diag_linear_operator.py b/linear_operator/operators/block_diag_linear_operator.py index a523abc8..05a654fe 100644 --- a/linear_operator/operators/block_diag_linear_operator.py +++ b/linear_operator/operators/block_diag_linear_operator.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 +from __future__ import annotations from abc import ABCMeta -from typing import Callable, Optional, Tuple, Union +from typing import Callable import torch from torch import Tensor @@ -17,7 +18,7 @@ # _MetaBlockDiagLinearOperator(base_linear_op, block_dim=-3) to return a DiagLinearOperator # if base_linear_op is a DiagLinearOperator itself class _MetaBlockDiagLinearOperator(ABCMeta): - def __call__(cls, base_linear_op: Union[LinearOperator, Tensor], block_dim: int = -3): + def __call__(cls, base_linear_op: LinearOperator | Tensor, block_dim: int = -3): from linear_operator.operators.diag_linear_operator import DiagLinearOperator if cls is BlockDiagLinearOperator and isinstance(base_linear_op, DiagLinearOperator): @@ -76,7 +77,7 @@ def _add_batch_dim( @cached(name="cholesky") def _cholesky( - self: LinearOperator, upper: Optional[bool] = False # shape: (*batch, N, N) + self: LinearOperator, upper: bool | None = False # shape: (*batch, N, N) ) -> LinearOperator: # shape: (*batch, N, N) from linear_operator.operators.triangular_linear_operator import TriangularLinearOperator @@ -85,9 +86,9 @@ def _cholesky( def _cholesky_solve( self: LinearOperator, # shape: (*batch, N, N) - rhs: Union[LinearOperator, Tensor], # shape: (*batch2, N, M) - upper: Optional[bool] = False, - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, M) + rhs: LinearOperator | Tensor, # shape: (*batch2, N, M) + upper: bool | None = False, + ) -> LinearOperator | Tensor: # shape: (..., N, M) rhs = self._add_batch_dim(rhs) res = self.base_linear_op._cholesky_solve(rhs, upper=upper) res = self._remove_batch_dim(res) @@ -123,14 +124,14 @@ def _remove_batch_dim(self, other): def _root_decomposition( self: LinearOperator, # shape: (..., N, N) - ) -> Union[torch.Tensor, LinearOperator]: # shape: (..., N, N) + ) -> torch.Tensor | LinearOperator: # shape: (..., N, N) return self.__class__(self.base_linear_op._root_decomposition()) def _root_inv_decomposition( self: LinearOperator, # shape: (*batch, N, N) - initial_vectors: Optional[torch.Tensor] = None, - test_vectors: Optional[torch.Tensor] = None, - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, N) + initial_vectors: 
torch.Tensor | None = None, + test_vectors: torch.Tensor | None = None, + ) -> LinearOperator | Tensor: # shape: (..., N, N) return self.__class__(self.base_linear_op._root_inv_decomposition(initial_vectors)) def _size(self) -> torch.Size: @@ -143,15 +144,15 @@ def _size(self) -> torch.Size: def _solve( self: LinearOperator, # shape: (..., N, N) rhs: torch.Tensor, # shape: (..., N, C) - preconditioner: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, # shape: (..., N, C) - num_tridiag: Optional[int] = 0, - ) -> Union[ - torch.Tensor, # shape: (..., N, C) - Tuple[ + preconditioner: Callable[[torch.Tensor], torch.Tensor] | None = None, # shape: (..., N, C) + num_tridiag: int | None = 0, + ) -> ( + torch.Tensor # shape: (..., N, C) + | tuple[ torch.Tensor, # shape: (..., N, C) torch.Tensor, # Note that in case of a tuple the second term size depends on num_tridiag # shape: (...) - ], - ]: + ] + ): if num_tridiag: return super()._solve(rhs, preconditioner, num_tridiag=num_tridiag) else: @@ -162,12 +163,12 @@ def _solve( def inv_quad_logdet( self: LinearOperator, # shape: (*batch, N, N) - inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) - logdet: Optional[bool] = False, - reduce_inv_quad: Optional[bool] = True, - ) -> Tuple[ # fmt: off - Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) - Optional[Tensor], # shape: (...) + inv_quad_rhs: Tensor | None = None, # shape: (*batch, N, M) or (*batch, N) + logdet: bool | None = False, + reduce_inv_quad: bool | None = True, + ) -> tuple[ # fmt: off + Tensor | None, # shape: (*batch, M) or (*batch) or (0) + Tensor | None, # shape: (...) ]: # fmt: on if inv_quad_rhs is not None: inv_quad_rhs = self._add_batch_dim(inv_quad_rhs) @@ -187,8 +188,8 @@ def inv_quad_logdet( def matmul( self: LinearOperator, # shape: (*batch, M, N) - other: Union[Tensor, LinearOperator], # shape: (*batch2, N, P) or (*batch2, N) - ) -> Union[Tensor, LinearOperator]: # shape: (..., M, P) or (..., M) + other: Tensor | LinearOperator, # shape: (*batch2, N, P) or (*batch2, N) + ) -> Tensor | LinearOperator: # shape: (..., M, P) or (..., M) from linear_operator.operators.diag_linear_operator import DiagLinearOperator # this is trivial if we multiply two BlockDiagLinearOperator with matching block sizes @@ -205,7 +206,7 @@ def matmul( @cached(name="svd") def _svd( self: LinearOperator, # shape: (*batch, N, N) - ) -> Tuple[LinearOperator, Tensor, LinearOperator]: # shape: (*batch, N, N), (..., N), (*batch, N, N) + ) -> tuple[LinearOperator, Tensor, LinearOperator]: # shape: (*batch, N, N), (..., N), (*batch, N, N) U, S, V = self.base_linear_op.svd() # Doesn't make much sense to sort here, o/w we lose the structure S = S.reshape(*S.shape[:-2], S.shape[-2:].numel()) @@ -217,8 +218,8 @@ def _svd( def _symeig( self: LinearOperator, # shape: (*batch, N, N) eigenvectors: bool = False, - return_evals_as_lazy: Optional[bool] = False, - ) -> Tuple[Tensor, Optional[LinearOperator]]: # shape: (*batch, M), (*batch, N, M) + return_evals_as_lazy: bool | None = False, + ) -> tuple[Tensor, LinearOperator | None]: # shape: (*batch, M), (*batch, N, M) evals, evecs = self.base_linear_op._symeig(eigenvectors=eigenvectors) # Doesn't make much sense to sort here, o/w we lose the structure evals = evals.reshape(*evals.shape[:-2], evals.shape[-2:].numel()) diff --git a/linear_operator/operators/block_interleaved_linear_operator.py b/linear_operator/operators/block_interleaved_linear_operator.py index 6b077936..69f0429a 100644 --- 
a/linear_operator/operators/block_interleaved_linear_operator.py +++ b/linear_operator/operators/block_interleaved_linear_operator.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 -from typing import Callable, Optional, Tuple, Union +from __future__ import annotations + +from typing import Callable import torch from torch import Tensor @@ -40,7 +42,7 @@ def _add_batch_dim(self, other): @cached(name="cholesky") def _cholesky( - self: LinearOperator, upper: Optional[bool] = False # shape: (*batch, N, N) + self: LinearOperator, upper: bool | None = False # shape: (*batch, N, N) ) -> LinearOperator: # shape: (*batch, N, N) from linear_operator.operators.triangular_linear_operator import TriangularLinearOperator @@ -49,9 +51,9 @@ def _cholesky( def _cholesky_solve( self: LinearOperator, # shape: (*batch, N, N) - rhs: Union[LinearOperator, Tensor], # shape: (*batch2, N, M) - upper: Optional[bool] = False, - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, M) + rhs: LinearOperator | Tensor, # shape: (*batch2, N, M) + upper: bool | None = False, + ) -> LinearOperator | Tensor: # shape: (..., N, M) rhs = self._add_batch_dim(rhs) res = self.base_linear_op._cholesky_solve(rhs, upper=upper) res = self._remove_batch_dim(res) @@ -88,14 +90,14 @@ def _remove_batch_dim(self, other): def _root_decomposition( self: LinearOperator, # shape: (..., N, N) - ) -> Union[torch.Tensor, LinearOperator]: # shape: (..., N, N) + ) -> torch.Tensor | LinearOperator: # shape: (..., N, N) return self.__class__(self.base_linear_op._root_decomposition()) def _root_inv_decomposition( self: LinearOperator, # shape: (*batch, N, N) - initial_vectors: Optional[torch.Tensor] = None, - test_vectors: Optional[torch.Tensor] = None, - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, N) + initial_vectors: torch.Tensor | None = None, + test_vectors: torch.Tensor | None = None, + ) -> LinearOperator | Tensor: # shape: (..., N, N) return self.__class__(self.base_linear_op._root_inv_decomposition(initial_vectors)) def _size(self) -> torch.Size: @@ -108,15 +110,15 @@ def _size(self) -> torch.Size: def _solve( self: LinearOperator, # shape: (..., N, N) rhs: torch.Tensor, # shape: (..., N, C) - preconditioner: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, # shape: (..., N, C) - num_tridiag: Optional[int] = 0, - ) -> Union[ - torch.Tensor, # shape: (..., N, C) - Tuple[ + preconditioner: Callable[[torch.Tensor], torch.Tensor] | None = None, # shape: (..., N, C) + num_tridiag: int | None = 0, + ) -> ( + torch.Tensor # shape: (..., N, C) + | tuple[ torch.Tensor, # shape: (..., N, C) torch.Tensor, # Note that in case of a tuple the second term size depends on num_tridiag # shape: (...) - ], - ]: + ] + ): if num_tridiag: return super()._solve(rhs, preconditioner, num_tridiag=num_tridiag) else: @@ -127,12 +129,12 @@ def _solve( def inv_quad_logdet( self: LinearOperator, # shape: (*batch, N, N) - inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) - logdet: Optional[bool] = False, - reduce_inv_quad: Optional[bool] = True, - ) -> Tuple[ # fmt: off - Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) - Optional[Tensor], # shape: (...) + inv_quad_rhs: Tensor | None = None, # shape: (*batch, N, M) or (*batch, N) + logdet: bool | None = False, + reduce_inv_quad: bool | None = True, + ) -> tuple[ # fmt: off + Tensor | None, # shape: (*batch, M) or (*batch) or (0) + Tensor | None, # shape: (...) 
]: # fmt: on if inv_quad_rhs is not None: inv_quad_rhs = self._add_batch_dim(inv_quad_rhs) diff --git a/linear_operator/operators/block_linear_operator.py b/linear_operator/operators/block_linear_operator.py index 93903ef3..4a8e9e51 100644 --- a/linear_operator/operators/block_linear_operator.py +++ b/linear_operator/operators/block_linear_operator.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 +from __future__ import annotations from abc import abstractmethod -from typing import List, Optional, Tuple, Union import torch from torch import Tensor @@ -59,7 +59,7 @@ def _add_batch_dim(self, other): raise NotImplementedError def _expand_batch( - self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + self: LinearOperator, batch_shape: torch.Size | list[int] # shape: (..., M, N) ) -> LinearOperator: # shape: (..., M, N) batch_shape = torch.Size((*batch_shape, self.base_linear_op.size(-3))) res = self.__class__(self.base_linear_op._expand_batch(batch_shape)) @@ -117,7 +117,7 @@ def _matmul( res = res.squeeze(-1) return res - def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[Optional[Tensor], ...]: + def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> tuple[Tensor | None, ...]: if left_vecs.ndim == 1: left_vecs = left_vecs.unsqueeze(-1) right_vecs = right_vecs.unsqueeze(-1) @@ -150,7 +150,7 @@ def _remove_batch_dim(self, other): raise NotImplementedError def _mul_constant( - self: LinearOperator, other: Union[float, torch.Tensor] # shape: (*batch, M, N) + self: LinearOperator, other: float | torch.Tensor # shape: (*batch, M, N) ) -> LinearOperator: # shape: (*batch, M, N) # We're using a custom method here - the constant mul is applied to the base_lazy tensor # This preserves the block structure diff --git a/linear_operator/operators/cat_linear_operator.py b/linear_operator/operators/cat_linear_operator.py index b16574a6..7a44209e 100644 --- a/linear_operator/operators/cat_linear_operator.py +++ b/linear_operator/operators/cat_linear_operator.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import List, Optional, Sequence, Tuple, Union +from typing import Sequence import torch from torch import Tensor @@ -103,7 +103,7 @@ def __init__(self, *linear_ops, dim=0, output_device=None): (*rep_tensor.shape[:positive_dim], cat_dim_cum_sizes[-1].item(), *rep_tensor.shape[positive_dim + 1 :]) ) - def _split_slice(self, slice_idx: slice) -> Tuple[Sequence[int], List[slice]]: + def _split_slice(self, slice_idx: slice) -> tuple[Sequence[int], list[slice]]: """ Splits a slice(a, b, None) in to a list of slices [slice(a1, b1, None), slice(a2, b2, None), ...] so that each slice in the list slices in to a single tensor that we have concatenated with this LinearOperator. @@ -161,7 +161,7 @@ def _diagonal( return res def _expand_batch( - self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + self: LinearOperator, batch_shape: torch.Size | list[int] # shape: (..., M, N) ) -> LinearOperator: # shape: (..., M, N) batch_dim = self.cat_dim + 2 if batch_dim < 0: @@ -390,22 +390,22 @@ def to_dense( def inv_quad_logdet( self: LinearOperator, # shape: (*batch, N, N) - inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) - logdet: Optional[bool] = False, - reduce_inv_quad: Optional[bool] = True, - ) -> Tuple[ # fmt: off - Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) - Optional[Tensor], # shape: (...) 
+ inv_quad_rhs: Tensor | None = None, # shape: (*batch, N, M) or (*batch, N) + logdet: bool | None = False, + reduce_inv_quad: bool | None = True, + ) -> tuple[ # fmt: off + Tensor | None, # shape: (*batch, M) or (*batch) or (0) + Tensor | None, # shape: (...) ]: # fmt: on res = super().inv_quad_logdet(inv_quad_rhs, logdet, reduce_inv_quad) return tuple(r.to(self.device) for r in res) @property - def device(self) -> Optional[torch.device]: + def device(self) -> torch.device | None: return self.output_device @property - def devices(self) -> List[torch.device]: + def devices(self) -> list[torch.device]: return [t.device for t in self.linear_ops] @property diff --git a/linear_operator/operators/chol_linear_operator.py b/linear_operator/operators/chol_linear_operator.py index 2983b9b0..697e784f 100644 --- a/linear_operator/operators/chol_linear_operator.py +++ b/linear_operator/operators/chol_linear_operator.py @@ -3,7 +3,7 @@ from __future__ import annotations import warnings -from typing import Callable, Optional, Tuple, Union +from typing import Callable import torch from torch import Tensor @@ -53,7 +53,7 @@ def _chol_diag( @cached(name="cholesky") def _cholesky( - self: LinearOperator, upper: Optional[bool] = False # shape: (*batch, N, N) + self: LinearOperator, upper: bool | None = False # shape: (*batch, N, N) ) -> LinearOperator: # shape: (*batch, N, N) if upper == self.upper: return self.root @@ -70,15 +70,15 @@ def _diagonal( def _solve( self: LinearOperator, # shape: (..., N, N) rhs: torch.Tensor, # shape: (..., N, C) - preconditioner: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, # shape: (..., N, C) - num_tridiag: Optional[int] = 0, - ) -> Union[ - torch.Tensor, # shape: (..., N, C) - Tuple[ + preconditioner: Callable[[torch.Tensor], torch.Tensor] | None = None, # shape: (..., N, C) + num_tridiag: int | None = 0, + ) -> ( + torch.Tensor # shape: (..., N, C) + | tuple[ torch.Tensor, # shape: (..., N, C) torch.Tensor, # Note that in case of a tuple the second term size depends on num_tridiag # shape: (...) - ], - ]: + ] + ): if num_tridiag: return super()._solve(rhs, preconditioner, num_tridiag=num_tridiag) return self.root._cholesky_solve(rhs, upper=self.upper) @@ -120,12 +120,12 @@ def inv_quad( def inv_quad_logdet( self: LinearOperator, # shape: (*batch, N, N) - inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) - logdet: Optional[bool] = False, - reduce_inv_quad: Optional[bool] = True, - ) -> Tuple[ # fmt: off - Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) - Optional[Tensor], # shape: (...) + inv_quad_rhs: Tensor | None = None, # shape: (*batch, N, M) or (*batch, N) + logdet: bool | None = False, + reduce_inv_quad: bool | None = True, + ) -> tuple[ # fmt: off + Tensor | None, # shape: (*batch, M) or (*batch) or (0) + Tensor | None, # shape: (...) 
]: # fmt: on if not self.is_square: raise RuntimeError( @@ -166,17 +166,17 @@ def inv_quad_logdet( def root_inv_decomposition( self: LinearOperator, # shape: (*batch, N, N) - initial_vectors: Optional[torch.Tensor] = None, - test_vectors: Optional[torch.Tensor] = None, - method: Optional[str] = None, - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, N) + initial_vectors: torch.Tensor | None = None, + test_vectors: torch.Tensor | None = None, + method: str | None = None, + ) -> LinearOperator | Tensor: # shape: (..., N, N) inv_root = self.root.inverse() return RootLinearOperator(inv_root._transpose_nonbatch()) def solve( self: LinearOperator, # shape: (..., N, N) right_tensor: Tensor, # shape: (..., N, P) or (N) - left_tensor: Optional[Tensor] = None, # shape: (..., O, N) + left_tensor: Tensor | None = None, # shape: (..., O, N) ) -> Tensor: # shape: (..., N, P) or (..., N) or (..., O, P) or (..., O) is_vector = right_tensor.ndim == 1 if is_vector: diff --git a/linear_operator/operators/constant_mul_linear_operator.py b/linear_operator/operators/constant_mul_linear_operator.py index 994e60e5..2caf74a2 100644 --- a/linear_operator/operators/constant_mul_linear_operator.py +++ b/linear_operator/operators/constant_mul_linear_operator.py @@ -2,8 +2,6 @@ from __future__ import annotations -from typing import List, Optional, Tuple, Union - import torch from torch import Tensor @@ -83,7 +81,7 @@ def _diagonal( return res * self._constant.unsqueeze(-1) def _expand_batch( - self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + self: LinearOperator, batch_shape: torch.Size | list[int] # shape: (..., M, N) ) -> LinearOperator: # shape: (..., M, N) return self.__class__( self.base_linear_op._expand_batch(batch_shape), @@ -123,7 +121,7 @@ def _permute_batch(self, *dims: int) -> LinearOperator: self.base_linear_op._permute_batch(*dims), self._constant.expand(self.batch_shape).permute(*dims) ) - def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[Optional[Tensor], ...]: + def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> tuple[Tensor | None, ...]: # Gradient with respect to the constant constant_deriv = left_vecs * self.base_linear_op._matmul(right_vecs) constant_deriv = constant_deriv.sum(-2).sum(-1) @@ -144,8 +142,8 @@ def _size(self) -> torch.Size: def _t_matmul( self: LinearOperator, # shape: (*batch, M, N) - rhs: Union[Tensor, LinearOperator], # shape: (*batch2, M, P) - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, P) + rhs: Tensor | LinearOperator, # shape: (*batch2, M, P) + ) -> LinearOperator | Tensor: # shape: (..., N, P) res = self.base_linear_op._t_matmul(rhs) res = res * self.expanded_constant return res @@ -184,7 +182,7 @@ def to_dense( @cached(name="root_decomposition") def root_decomposition( - self: LinearOperator, method: Optional[str] = None # shape: (*batch, N, N) + self: LinearOperator, method: str | None = None # shape: (*batch, N, N) ) -> LinearOperator: # shape: (*batch, N, N) if torch.all(self._constant >= 0): base_root = self.base_linear_op.root_decomposition(method=method).root diff --git a/linear_operator/operators/dense_linear_operator.py b/linear_operator/operators/dense_linear_operator.py index 739a5baf..204d55af 100644 --- a/linear_operator/operators/dense_linear_operator.py +++ b/linear_operator/operators/dense_linear_operator.py @@ -2,8 +2,6 @@ from __future__ import annotations -from typing import List, Optional, Tuple, Union - import torch from torch import Tensor @@ -31,9 
+29,9 @@ def __init__(self, tsr): def _cholesky_solve( self: LinearOperator, # shape: (*batch, N, N) - rhs: Union[LinearOperator, Tensor], # shape: (*batch2, N, M) - upper: Optional[bool] = False, - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, M) + rhs: LinearOperator | Tensor, # shape: (*batch2, N, M) + upper: bool | None = False, + ) -> LinearOperator | Tensor: # shape: (..., N, M) return torch.cholesky_solve(rhs, self.to_dense(), upper=upper) def _diagonal( @@ -42,7 +40,7 @@ def _diagonal( return self.tensor.diagonal(dim1=-1, dim2=-2) def _expand_batch( - self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + self: LinearOperator, batch_shape: torch.Size | list[int] # shape: (..., M, N) ) -> LinearOperator: # shape: (..., M, N) return self.__class__(self.tensor.expand(*batch_shape, *self.matrix_shape)) @@ -68,7 +66,7 @@ def _matmul( def _prod_batch(self, dim: int) -> LinearOperator: return self.__class__(self.tensor.prod(dim)) - def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[Optional[Tensor], ...]: + def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> tuple[Tensor | None, ...]: res = left_vecs.matmul(right_vecs.mT) return (res,) @@ -85,8 +83,8 @@ def _transpose_nonbatch( def _t_matmul( self: LinearOperator, # shape: (*batch, M, N) - rhs: Union[Tensor, LinearOperator], # shape: (*batch2, M, P) - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, P) + rhs: Tensor | LinearOperator, # shape: (*batch2, M, P) + ) -> LinearOperator | Tensor: # shape: (..., N, P) return torch.matmul(self.tensor.mT, rhs) def to_dense( @@ -96,8 +94,8 @@ def to_dense( def __add__( self: LinearOperator, # shape: (..., #M, #N) - other: Union[Tensor, LinearOperator, float], # shape: (..., #M, #N) - ) -> Union[LinearOperator, Tensor]: # shape: (..., M, N) + other: Tensor | LinearOperator | float, # shape: (..., #M, #N) + ) -> LinearOperator | Tensor: # shape: (..., M, N) if isinstance(other, DenseLinearOperator): return DenseLinearOperator(self.tensor + other.tensor) elif isinstance(other, torch.Tensor): @@ -106,19 +104,20 @@ def __add__( return super().__add__(other) -def to_linear_operator(obj: Union[torch.Tensor, LinearOperator]) -> LinearOperator: +def to_linear_operator(obj: torch.Tensor | LinearOperator) -> LinearOperator: """ A function which ensures that `obj` is a LinearOperator. - If `obj` is a LinearOperator, this function does nothing. - If `obj` is a (normal) Tensor, this function wraps it with a `DenseLinearOperator`. 
""" - if torch.is_tensor(obj): - return DenseLinearOperator(obj) - elif isinstance(obj, LinearOperator): - return obj - else: - raise TypeError("object of class {} cannot be made into a LinearOperator".format(obj.__class__.__name__)) + match obj: + case _ if torch.is_tensor(obj): + return DenseLinearOperator(obj) + case LinearOperator(): + return obj + case _: + raise TypeError("object of class {} cannot be made into a LinearOperator".format(obj.__class__.__name__)) __all__ = ["DenseLinearOperator", "to_linear_operator"] diff --git a/linear_operator/operators/diag_linear_operator.py b/linear_operator/operators/diag_linear_operator.py index b4d56789..3fcc576f 100644 --- a/linear_operator/operators/diag_linear_operator.py +++ b/linear_operator/operators/diag_linear_operator.py @@ -2,8 +2,6 @@ from __future__ import annotations -from typing import List, Optional, Tuple, Union - import torch from torch import Tensor @@ -28,15 +26,15 @@ def __init__(self, diag: Tensor): def __add__( self: LinearOperator, # shape: (..., #M, #N) - other: Union[Tensor, LinearOperator, float], # shape: (..., #M, #N) - ) -> Union[LinearOperator, Tensor]: # shape: (..., M, N) + other: Tensor | LinearOperator | float, # shape: (..., #M, #N) + ) -> LinearOperator | Tensor: # shape: (..., M, N) if isinstance(other, DiagLinearOperator): return self.add_diagonal(other._diag) from linear_operator.operators.added_diag_linear_operator import AddedDiagLinearOperator return AddedDiagLinearOperator(other, self) - def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[Optional[Tensor], ...]: + def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> tuple[Tensor | None, ...]: # TODO: Use proper batching for input vectors (prepend to shape rather than append) if not self._diag.requires_grad: return (None,) @@ -48,19 +46,19 @@ def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[O @cached(name="cholesky", ignore_args=True) def _cholesky( - self: LinearOperator, upper: Optional[bool] = False # shape: (*batch, N, N) + self: LinearOperator, upper: bool | None = False # shape: (*batch, N, N) ) -> LinearOperator: # shape: (*batch, N, N) return self.sqrt() def _cholesky_solve( self: LinearOperator, # shape: (*batch, N, N) - rhs: Union[LinearOperator, Tensor], # shape: (*batch2, N, M) - upper: Optional[bool] = False, - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, M) + rhs: LinearOperator | Tensor, # shape: (*batch2, N, M) + upper: bool | None = False, + ) -> LinearOperator | Tensor: # shape: (..., N, M) return rhs / self._diag.unsqueeze(-1).pow(2) def _expand_batch( - self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + self: LinearOperator, batch_shape: torch.Size | list[int] # shape: (..., M, N) ) -> LinearOperator: # shape: (..., M, N) return self.__class__(self._diag.expand(*batch_shape, self._diag.size(-1))) @@ -80,13 +78,13 @@ def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indice return res def _mul_constant( - self: LinearOperator, other: Union[float, torch.Tensor] # shape: (*batch, M, N) + self: LinearOperator, other: float | torch.Tensor # shape: (*batch, M, N) ) -> LinearOperator: # shape: (*batch, M, N) return self.__class__(self._diag * other.unsqueeze(-1)) def _mul_matrix( self: LinearOperator, # shape: (..., #M, #N) - other: Union[torch.Tensor, LinearOperator], # shape: (..., #M, #N) + other: torch.Tensor | LinearOperator, # shape: (..., #M, #N) ) -> LinearOperator: # shape: (..., M, N) 
return DiagLinearOperator(self._diag * other._diagonal()) @@ -95,14 +93,14 @@ def _prod_batch(self, dim: int) -> LinearOperator: def _root_decomposition( self: LinearOperator, # shape: (..., N, N) - ) -> Union[torch.Tensor, LinearOperator]: # shape: (..., N, N) + ) -> torch.Tensor | LinearOperator: # shape: (..., N, N) return self.sqrt() def _root_inv_decomposition( self: LinearOperator, # shape: (*batch, N, N) - initial_vectors: Optional[torch.Tensor] = None, - test_vectors: Optional[torch.Tensor] = None, - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, N) + initial_vectors: torch.Tensor | None = None, + test_vectors: torch.Tensor | None = None, + ) -> LinearOperator | Tensor: # shape: (..., N, N) return self.inverse().sqrt() def _size(self) -> torch.Size: @@ -113,8 +111,8 @@ def _sum_batch(self, dim: int) -> LinearOperator: def _t_matmul( self: LinearOperator, # shape: (*batch, M, N) - rhs: Union[Tensor, LinearOperator], # shape: (*batch2, M, P) - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, P) + rhs: Tensor | LinearOperator, # shape: (*batch2, M, P) + ) -> LinearOperator | Tensor: # shape: (..., N, P) # Diagonal matrices always commute return self._matmul(rhs) @@ -162,12 +160,12 @@ def inverse( def inv_quad_logdet( self: LinearOperator, # shape: (*batch, N, N) - inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) - logdet: Optional[bool] = False, - reduce_inv_quad: Optional[bool] = True, - ) -> Tuple[ # fmt: off - Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) - Optional[Tensor], # shape: (...) + inv_quad_rhs: Tensor | None = None, # shape: (*batch, N, M) or (*batch, N) + logdet: bool | None = False, + reduce_inv_quad: bool | None = True, + ) -> tuple[ # fmt: off + Tensor | None, # shape: (*batch, M) or (*batch) or (0) + Tensor | None, # shape: (...) ]: # fmt: on # TODO: Use proper batching for inv_quad_rhs (prepend to shape rather than append) if inv_quad_rhs is None: @@ -204,29 +202,26 @@ def log( # a MatmulLinearOperator is created.
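[Editor's note, not part of the patch] The hunk below rewrites `DiagLinearOperator.matmul` to dispatch with a Python 3.10 `match` statement instead of an `isinstance` chain; the same idiom is applied to `to_dense` in `matmul_linear_operator.py` further down. As a rough, self-contained sketch of the idea (using plain `torch` tensors and a hypothetical `diag_matmul` helper, not the library's operator classes): multiplying by a diagonal matrix never needs a dense product, which is what the `case Tensor():` branch exploits.

import torch

def diag_matmul(diag: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
    # A diagonal matrix only rescales rows, so no dense matmul is needed.
    # Matching on the number of dimensions mirrors the vector/matrix split
    # handled inside the patched `case Tensor():` branch.
    match other.ndim:
        case 1:
            return diag * other  # vector rhs: plain elementwise product
        case _:
            return diag.unsqueeze(-1) * other  # matrix rhs: scale each row

d = torch.tensor([1.0, 2.0, 3.0])
rhs = torch.randn(3, 4)
assert torch.allclose(diag_matmul(d, rhs), torch.diag(d) @ rhs)
assert torch.allclose(diag_matmul(d, rhs[:, 0]), torch.diag(d) @ rhs[:, 0])

The final `case _:` in the patched method plays the same role as the old trailing `return super().matmul(other)`: anything that is not one of the recognized operator types falls through to the generic implementation.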
def matmul( self: LinearOperator, # shape: (*batch, M, N) - other: Union[Tensor, LinearOperator], # shape: (*batch2, N, P) or (*batch2, N) - ) -> Union[Tensor, LinearOperator]: # shape: (..., M, P) or (..., M) - if isinstance(other, Tensor): - diag = self._diag if other.ndim == 1 else self._diag.unsqueeze(-1) - return diag * other - - if isinstance(other, DenseLinearOperator): - return DenseLinearOperator(self @ other.tensor) - - if isinstance(other, DiagLinearOperator): - return DiagLinearOperator(self._diag * other._diag) - - if isinstance(other, TriangularLinearOperator): - return TriangularLinearOperator(self @ other._tensor, upper=other.upper) - - if isinstance(other, BlockDiagLinearOperator): - diag_reshape = self._diag.view(*other.base_linear_op.shape[:-1]) - diag = DiagLinearOperator(diag_reshape) - # using matmul here avoids having to implement special case of elementwise multiplication - # with block diagonal operator, which itself has special cases for vectors and matrices - return BlockDiagLinearOperator(diag @ other.base_linear_op) - - return super().matmul(other) # happens with other structured linear operators + other: Tensor | LinearOperator, # shape: (*batch2, N, P) or (*batch2, N) + ) -> Tensor | LinearOperator: # shape: (..., M, P) or (..., M) + match other: + case Tensor(): + diag = self._diag if other.ndim == 1 else self._diag.unsqueeze(-1) + return diag * other + case DenseLinearOperator(): + return DenseLinearOperator(self @ other.tensor) + case DiagLinearOperator(): + return DiagLinearOperator(self._diag * other._diag) + case TriangularLinearOperator(): + return TriangularLinearOperator(self @ other._tensor, upper=other.upper) + case BlockDiagLinearOperator(): + diag_reshape = self._diag.view(*other.base_linear_op.shape[:-1]) + diag = DiagLinearOperator(diag_reshape) + # using matmul here avoids having to implement special case of elementwise multiplication + # with block diagonal operator, which itself has special cases for vectors and matrices + return BlockDiagLinearOperator(diag @ other.base_linear_op) + case _: + return super().matmul(other) # happens with other structured linear operators def _matmul( self: LinearOperator, # shape: (*batch, M, N) @@ -237,7 +232,7 @@ def _matmul( def solve( self: LinearOperator, # shape: (..., N, N) right_tensor: Tensor, # shape: (..., N, P) or (N) - left_tensor: Optional[Tensor] = None, # shape: (..., O, N) + left_tensor: Tensor | None = None, # shape: (..., O, N) ) -> Tensor: # shape: (..., N, P) or (..., N) or (..., O, P) or (..., O) res = self.inverse()._matmul(right_tensor) if left_tensor is not None: @@ -265,8 +260,8 @@ def sqrt( def sqrt_inv_matmul( self: LinearOperator, # shape: (*batch, N, N) rhs: Tensor, # shape: (*batch, N, P) - lhs: Optional[Tensor] = None, # shape: (*batch, O, N) - ) -> Union[Tensor, Tuple[Tensor, Tensor]]: # shape: (*batch, N, P), (*batch, O, P), (*batch, O) + lhs: Tensor | None = None, # shape: (*batch, O, N) + ) -> Tensor | tuple[Tensor, Tensor]: # shape: (*batch, N, P), (*batch, O, P), (*batch, O) matrix_inv_root = self._root_inv_decomposition() if lhs is None: return matrix_inv_root.matmul(rhs) @@ -284,7 +279,7 @@ def zero_mean_mvn_samples( @cached(name="svd") def _svd( self: LinearOperator, # shape: (*batch, N, N) - ) -> Tuple[LinearOperator, Tensor, LinearOperator]: # shape: (*batch, N, N), (..., N), (*batch, N, N) + ) -> tuple[LinearOperator, Tensor, LinearOperator]: # shape: (*batch, N, N), (..., N), (*batch, N, N) evals, evecs = self._symeig(eigenvectors=True) S = torch.abs(evals) U = 
evecs @@ -294,8 +289,8 @@ def _symeig( self: LinearOperator, # shape: (*batch, N, N) eigenvectors: bool = False, - return_evals_as_lazy: Optional[bool] = False, - ) -> Tuple[Tensor, Optional[LinearOperator]]: # shape: (*batch, M), (*batch, N, M) + return_evals_as_lazy: bool | None = False, + ) -> tuple[Tensor, LinearOperator | None]: # shape: (*batch, M), (*batch, N, M) evals = self._diag if eigenvectors: diag_values = torch.ones(evals.shape[:-1], device=evals.device, dtype=evals.dtype).unsqueeze(-1) @@ -328,8 +323,8 @@ def __init__(self, diag_values: torch.Tensor, diag_shape: int): def __add__( self: LinearOperator, # shape: (..., #M, #N) - other: Union[Tensor, LinearOperator, float], # shape: (..., #M, #N) - ) -> Union[LinearOperator, Tensor]: # shape: (..., M, N) + other: Tensor | LinearOperator | float, # shape: (..., #M, #N) + ) -> LinearOperator | Tensor: # shape: (..., M, N) if isinstance(other, ConstantDiagLinearOperator): if other.shape[-1] == self.shape[-1]: return ConstantDiagLinearOperator(self.diag_values + other.diag_values, self.diag_shape) @@ -339,7 +334,7 @@ def __add__( ) return super().__add__(other) - def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[Optional[Tensor], ...]: + def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> tuple[Tensor | None, ...]: # TODO: Use proper batching for input vectors (prepend to shape rather than append) if not self.diag_values.requires_grad: return (None,) @@ -355,18 +350,18 @@ def _diag( return self.diag_values.expand(*self.diag_values.shape[:-1], self.diag_shape) def _expand_batch( - self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + self: LinearOperator, batch_shape: torch.Size | list[int] # shape: (..., M, N) ) -> LinearOperator: # shape: (..., M, N) return self.__class__(self.diag_values.expand(*batch_shape, 1), diag_shape=self.diag_shape) def _mul_constant( - self: LinearOperator, other: Union[float, torch.Tensor] # shape: (*batch, M, N) + self: LinearOperator, other: float | torch.Tensor # shape: (*batch, M, N) ) -> LinearOperator: # shape: (*batch, M, N) return self.__class__(self.diag_values * other, diag_shape=self.diag_shape) def _mul_matrix( self: LinearOperator, # shape: (..., #M, #N) - other: Union[torch.Tensor, LinearOperator], # shape: (..., #M, #N) + other: torch.Tensor | LinearOperator, # shape: (..., #M, #N) ) -> LinearOperator: # shape: (..., M, N) if isinstance(other, ConstantDiagLinearOperator): if not self.diag_shape == other.diag_shape: @@ -419,8 +414,8 @@ def log( def matmul( self: LinearOperator, # shape: (*batch, M, N) - other: Union[Tensor, LinearOperator], # shape: (*batch2, N, P) or (*batch2, N) - ) -> Union[Tensor, LinearOperator]: # shape: (..., M, P) or (..., M) + other: Tensor | LinearOperator, # shape: (*batch2, N, P) or (*batch2, N) + ) -> Tensor | LinearOperator: # shape: (..., M, P) or (..., M) if isinstance(other, ConstantDiagLinearOperator): return self._mul_matrix(other) return super().matmul(other) diff --git a/linear_operator/operators/identity_linear_operator.py b/linear_operator/operators/identity_linear_operator.py index caee46cf..54bf413e 100644 --- a/linear_operator/operators/identity_linear_operator.py +++ b/linear_operator/operators/identity_linear_operator.py @@ -2,8 +2,6 @@ from __future__ import annotations -from typing import List, Optional, Tuple, Union - import torch from torch import Tensor @@ -29,9 +27,9 @@ class IdentityLinearOperator(ConstantDiagLinearOperator): def __init__( self,
diag_shape: int, - batch_shape: Optional[torch.Size] = torch.Size([]), - dtype: Optional[torch.dtype] = torch.float, - device: Optional[torch.device] = None, + batch_shape: torch.Size | None = torch.Size([]), + dtype: torch.dtype | None = torch.float, + device: torch.device | None = None, ): one = torch.tensor(1.0, dtype=dtype, device=device) LinearOperator.__init__(self, diag_shape=diag_shape, batch_shape=batch_shape, dtype=dtype, device=device) @@ -46,14 +44,14 @@ def batch_shape(self) -> torch.Size: return self._batch_shape @property - def dtype(self) -> Optional[torch.dtype]: + def dtype(self) -> torch.dtype | None: return self._dtype @property - def device(self) -> Optional[torch.device]: + def device(self) -> torch.device | None: return self._device - def _maybe_reshape_rhs(self, rhs: Union[torch.Tensor, LinearOperator]) -> Union[torch.Tensor, LinearOperator]: + def _maybe_reshape_rhs(self, rhs: torch.Tensor | LinearOperator) -> torch.Tensor | LinearOperator: if self._batch_shape != rhs.shape[:-2]: batch_shape = torch.broadcast_shapes(rhs.shape[:-2], self._batch_shape) return rhs.expand(*batch_shape, *rhs.shape[-2:]) @@ -62,19 +60,19 @@ def _maybe_reshape_rhs(self, rhs: Union[torch.Tensor, LinearOperator]) -> Union[ @cached(name="cholesky", ignore_args=True) def _cholesky( - self: LinearOperator, upper: Optional[bool] = False # shape: (*batch, N, N) + self: LinearOperator, upper: bool | None = False # shape: (*batch, N, N) ) -> LinearOperator: # shape: (*batch, N, N) return self def _cholesky_solve( self: LinearOperator, # shape: (*batch, N, N) - rhs: Union[LinearOperator, Tensor], # shape: (*batch2, N, M) - upper: Optional[bool] = False, - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, M) + rhs: LinearOperator | Tensor, # shape: (*batch2, N, M) + upper: bool | None = False, + ) -> LinearOperator | Tensor: # shape: (..., N, M) return self._maybe_reshape_rhs(rhs) def _expand_batch( - self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + self: LinearOperator, batch_shape: torch.Size | list[int] # shape: (..., M, N) ) -> LinearOperator: # shape: (..., M, N) return IdentityLinearOperator( diag_shape=self.diag_shape, batch_shape=batch_shape, dtype=self.dtype, device=self.device @@ -101,13 +99,13 @@ def _matmul( return self._maybe_reshape_rhs(rhs) def _mul_constant( - self: LinearOperator, other: Union[float, torch.Tensor] # shape: (*batch, M, N) + self: LinearOperator, other: float | torch.Tensor # shape: (*batch, M, N) ) -> LinearOperator: # shape: (*batch, M, N) return ConstantDiagLinearOperator(self.diag_values * other, diag_shape=self.diag_shape) def _mul_matrix( self: LinearOperator, # shape: (..., #M, #N) - other: Union[torch.Tensor, LinearOperator], # shape: (..., #M, #N) + other: torch.Tensor | LinearOperator, # shape: (..., #M, #N) ) -> LinearOperator: # shape: (..., M, N) return other @@ -126,14 +124,14 @@ def _prod_batch(self, dim: int) -> LinearOperator: def _root_decomposition( self: LinearOperator, # shape: (..., N, N) - ) -> Union[torch.Tensor, LinearOperator]: # shape: (..., N, N) + ) -> torch.Tensor | LinearOperator: # shape: (..., N, N) return self.sqrt() def _root_inv_decomposition( self: LinearOperator, # shape: (*batch, N, N) - initial_vectors: Optional[torch.Tensor] = None, - test_vectors: Optional[torch.Tensor] = None, - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, N) + initial_vectors: torch.Tensor | None = None, + test_vectors: torch.Tensor | None = None, + ) -> LinearOperator | Tensor: # shape: (..., N, N) return 
self.inverse().sqrt() def _size(self) -> torch.Size: @@ -142,20 +140,20 @@ def _size(self) -> torch.Size: @cached(name="svd") def _svd( self: LinearOperator, # shape: (*batch, N, N) - ) -> Tuple[LinearOperator, Tensor, LinearOperator]: # shape: (*batch, N, N), (..., N), (*batch, N, N) + ) -> tuple[LinearOperator, Tensor, LinearOperator]: # shape: (*batch, N, N), (..., N), (*batch, N, N) return self, self._diag, self def _symeig( self: LinearOperator, # shape: (*batch, N, N) eigenvectors: bool = False, - return_evals_as_lazy: Optional[bool] = False, - ) -> Tuple[Tensor, Optional[LinearOperator]]: # shape: (*batch, M), (*batch, N, M) + return_evals_as_lazy: bool | None = False, + ) -> tuple[Tensor, LinearOperator | None]: # shape: (*batch, M), (*batch, N, M) return self._diag, self def _t_matmul( self: LinearOperator, # shape: (*batch, M, N) - rhs: Union[Tensor, LinearOperator], # shape: (*batch2, M, P) - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, P) + rhs: Tensor | LinearOperator, # shape: (*batch2, M, P) + ) -> LinearOperator | Tensor: # shape: (..., N, P) return self._maybe_reshape_rhs(rhs) def _transpose_nonbatch( @@ -186,12 +184,12 @@ def inverse( def inv_quad_logdet( self: LinearOperator, # shape: (*batch, N, N) - inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) - logdet: Optional[bool] = False, - reduce_inv_quad: Optional[bool] = True, - ) -> Tuple[ # fmt: off - Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) - Optional[Tensor], # shape: (...) + inv_quad_rhs: Tensor | None = None, # shape: (*batch, N, M) or (*batch, N) + logdet: bool | None = False, + reduce_inv_quad: bool | None = True, + ) -> tuple[ # fmt: off + Tensor | None, # shape: (*batch, M) or (*batch) or (0) + Tensor | None, # shape: (...) 
]: # fmt: on # TODO: Use proper batching for inv_quad_rhs (prepend to shape rather than append) if inv_quad_rhs is None: @@ -218,8 +216,8 @@ def log( def matmul( self: LinearOperator, # shape: (*batch, M, N) - other: Union[Tensor, LinearOperator], # shape: (*batch2, N, P) or (*batch2, N) - ) -> Union[Tensor, LinearOperator]: # shape: (..., M, P) or (..., M) + other: Tensor | LinearOperator, # shape: (*batch2, N, P) or (*batch2, N) + ) -> Tensor | LinearOperator: # shape: (..., M, P) or (..., M) is_vec = False if other.dim() == 1: is_vec = True @@ -232,7 +230,7 @@ def matmul( def solve( self: LinearOperator, # shape: (..., N, N) right_tensor: Tensor, # shape: (..., N, P) or (N) - left_tensor: Optional[Tensor] = None, # shape: (..., O, N) + left_tensor: Tensor | None = None, # shape: (..., O, N) ) -> Tensor: # shape: (..., N, P) or (..., N) or (..., O, P) or (..., O) res = self._maybe_reshape_rhs(right_tensor) if left_tensor is not None: @@ -247,8 +245,8 @@ def sqrt( def sqrt_inv_matmul( self: LinearOperator, # shape: (*batch, N, N) rhs: Tensor, # shape: (*batch, N, P) - lhs: Optional[Tensor] = None, # shape: (*batch, O, N) - ) -> Union[Tensor, Tuple[Tensor, Tensor]]: # shape: (*batch, N, P), (*batch, O, P), (*batch, O) + lhs: Tensor | None = None, # shape: (*batch, O, N) + ) -> Tensor | tuple[Tensor, Tensor]: # shape: (*batch, N, P), (*batch, O, P), (*batch, O) if lhs is None: return self._maybe_reshape_rhs(rhs) else: diff --git a/linear_operator/operators/interpolated_linear_operator.py b/linear_operator/operators/interpolated_linear_operator.py index 89455dde..5a60ecb1 100644 --- a/linear_operator/operators/interpolated_linear_operator.py +++ b/linear_operator/operators/interpolated_linear_operator.py @@ -2,8 +2,6 @@ from __future__ import annotations -from typing import List, Optional, Tuple, Union - import torch from torch import Tensor @@ -119,7 +117,7 @@ def _diagonal( return super(InterpolatedLinearOperator, self)._diagonal() def _expand_batch( - self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + self: LinearOperator, batch_shape: torch.Size | list[int] # shape: (..., M, N) ) -> LinearOperator: # shape: (..., M, N) return self.__class__( self.base_linear_op._expand_batch(batch_shape), @@ -221,7 +219,7 @@ def _matmul( return res def _mul_constant( - self: LinearOperator, other: Union[float, torch.Tensor] # shape: (*batch, M, N) + self: LinearOperator, other: float | torch.Tensor # shape: (*batch, M, N) ) -> LinearOperator: # shape: (*batch, M, N) # We're using a custom method here - the constant mul is applied to the base_lazy tensor # This preserves the interpolated structure @@ -235,8 +233,8 @@ def _mul_constant( def _t_matmul( self: LinearOperator, # shape: (*batch, M, N) - rhs: Union[Tensor, LinearOperator], # shape: (*batch2, M, P) - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, P) + rhs: Tensor | LinearOperator, # shape: (*batch2, M, P) + ) -> LinearOperator | Tensor: # shape: (..., N, P) # Get sparse tensor representations of left/right interp matrices left_interp_t = self._sparse_left_interp_t(self.left_interp_indices, self.left_interp_values) right_interp_t = self._sparse_right_interp_t(self.right_interp_indices, self.right_interp_values) @@ -262,7 +260,7 @@ def _t_matmul( res = res.squeeze(-1) return res - def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[Optional[Tensor], ...]: + def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> tuple[Tensor | None, ...]: # Get sparse tensor
representations of left/right interp matrices left_interp_t = self._sparse_left_interp_t(self.left_interp_indices, self.left_interp_values) right_interp_t = self._sparse_right_interp_t(self.right_interp_indices, self.right_interp_values) @@ -414,8 +412,8 @@ def _sum_batch(self, dim: int) -> LinearOperator: def matmul( self: LinearOperator, # shape: (*batch, M, N) - other: Union[Tensor, LinearOperator], # shape: (*batch2, N, P) or (*batch2, N) - ) -> Union[Tensor, LinearOperator]: # shape: (..., M, P) or (..., M) + other: Tensor | LinearOperator, # shape: (*batch2, N, P) or (*batch2, N) + ) -> Tensor | LinearOperator: # shape: (..., M, P) or (..., M) # We're using a custom matmul here, because it is significantly faster than # what we get from the function factory. # The _matmul_closure is optimized for repeated calls, such as for _solve diff --git a/linear_operator/operators/keops_linear_operator.py b/linear_operator/operators/keops_linear_operator.py index 0f1981cf..daa77673 100644 --- a/linear_operator/operators/keops_linear_operator.py +++ b/linear_operator/operators/keops_linear_operator.py @@ -2,8 +2,6 @@ import warnings -from typing import Optional, Tuple - import torch from torch import Tensor @@ -102,7 +100,7 @@ def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: I # Now construct a kernel with those indices return self.__class__(x1, x2, covar_func=self.covar_func, **self.params) - def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[Optional[Tensor], ...]: + def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> tuple[Tensor | None, ...]: """ Use default behavior, but KeOps does not automatically make args contiguous like torch.matmul. diff --git a/linear_operator/operators/kernel_linear_operator.py b/linear_operator/operators/kernel_linear_operator.py index 947acc6d..43fd6dfa 100644 --- a/linear_operator/operators/kernel_linear_operator.py +++ b/linear_operator/operators/kernel_linear_operator.py @@ -1,5 +1,7 @@ +from __future__ import annotations + from collections import defaultdict -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Any, Callable import torch @@ -133,10 +135,10 @@ def __init__( self, x1: Tensor, # shape: (..., M, D) x2: Tensor, # shape: (..., N, D) - covar_func: Callable[..., Union[Tensor, LinearOperator]], # shape: (..., M, N) - num_outputs_per_input: Tuple[int, int] = (1, 1), - num_nonbatch_dimensions: Optional[Dict[str, int]] = None, - **params: Union[Tensor, Any], + covar_func: Callable[..., Tensor | LinearOperator], # shape: (..., M, N) + num_outputs_per_input: tuple[int, int] = (1, 1), + num_nonbatch_dimensions: dict[str, int] | None = None, + **params: Tensor | Any, ): # Change num_nonbatch_dimensions into a default dict if num_nonbatch_dimensions is None: @@ -251,7 +253,7 @@ def _diagonal( @cached(name="covar_mat") def covar_mat( self: LinearOperator, # shape: (..., M, N) - ) -> Union[Tensor, LinearOperator]: # shape: (..., M, N) + ) -> Tensor | LinearOperator: # shape: (..., M, N) return self.covar_func(self.x1, self.x2, **self.tensor_params, **self.nontensor_params) def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> torch.Tensor: diff --git a/linear_operator/operators/kronecker_product_added_diag_linear_operator.py b/linear_operator/operators/kronecker_product_added_diag_linear_operator.py index d362967c..7c3cb40a 100644 --- a/linear_operator/operators/kronecker_product_added_diag_linear_operator.py +++ 
b/linear_operator/operators/kronecker_product_added_diag_linear_operator.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 +from __future__ import annotations -from typing import Callable, Optional, Tuple, Union +from typing import Callable import torch from torch import Tensor @@ -64,12 +65,12 @@ def __init__(self, *linear_ops, preconditioner_override=None): def inv_quad_logdet( self: LinearOperator, # shape: (*batch, N, N) - inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) - logdet: Optional[bool] = False, - reduce_inv_quad: Optional[bool] = True, - ) -> Tuple[ # fmt: off - Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) - Optional[Tensor], # shape: (...) + inv_quad_rhs: Tensor | None = None, # shape: (*batch, N, M) or (*batch, N) + logdet: bool | None = False, + reduce_inv_quad: bool | None = True, + ) -> tuple[ # fmt: off + Tensor | None, # shape: (*batch, M) or (*batch) or (0) + Tensor | None, # shape: (...) ]: # fmt: on if inv_quad_rhs is not None: inv_quad_term, _ = super().inv_quad_logdet( @@ -126,22 +127,22 @@ def _logdet( return super().inv_quad_logdet(logdet=True)[1] - def _preconditioner(self) -> Tuple[Optional[Callable], Optional[LinearOperator], Optional[torch.Tensor]]: + def _preconditioner(self) -> tuple[Callable | None, LinearOperator | None, torch.Tensor | None]: # solves don't use CG so don't waste time computing it return None, None, None def _solve( self: LinearOperator, # shape: (..., N, N) rhs: torch.Tensor, # shape: (..., N, C) - preconditioner: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, # shape: (..., N, C) - num_tridiag: Optional[int] = 0, - ) -> Union[ - torch.Tensor, # shape: (..., N, C) - Tuple[ + preconditioner: Callable[[torch.Tensor], torch.Tensor] | None = None, # shape: (..., N, C) + num_tridiag: int | None = 0, + ) -> ( + torch.Tensor # shape: (..., N, C) + | tuple[ torch.Tensor, # shape: (..., N, C) torch.Tensor, # Note that in case of a tuple the second term size depends on num_tridiag # shape: (...) 
- ], - ]: + ] + ): rhs_dtype = rhs.dtype @@ -224,7 +225,7 @@ def _solve( def _root_decomposition( self: LinearOperator, # shape: (..., N, N) - ) -> Union[torch.Tensor, LinearOperator]: # shape: (..., N, N) + ) -> torch.Tensor | LinearOperator: # shape: (..., N, N) if self._diag_is_constant: evals, q_matrix = self.linear_op.diagonalization() updated_evals = DiagLinearOperator((evals + self.diag_tensor._diagonal()).pow(0.5)) @@ -257,9 +258,9 @@ def _root_decomposition( def _root_inv_decomposition( self: LinearOperator, # shape: (*batch, N, N) - initial_vectors: Optional[torch.Tensor] = None, - test_vectors: Optional[torch.Tensor] = None, - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, N) + initial_vectors: torch.Tensor | None = None, + test_vectors: torch.Tensor | None = None, + ) -> LinearOperator | Tensor: # shape: (..., N, N) if self._diag_is_constant: evals, q_matrix = self.linear_op.diagonalization() inv_sqrt_evals = DiagLinearOperator((evals + self.diag_tensor._diagonal()).pow(-0.5)) @@ -293,8 +294,8 @@ def _root_inv_decomposition( def _symeig( self: LinearOperator, # shape: (*batch, N, N) eigenvectors: bool = False, - return_evals_as_lazy: Optional[bool] = False, - ) -> Tuple[Tensor, Optional[LinearOperator]]: # shape: (*batch, M), (*batch, N, M) + return_evals_as_lazy: bool | None = False, + ) -> tuple[Tensor, LinearOperator | None]: # shape: (*batch, M), (*batch, N, M) # return_evals_as_lazy is a flag to return the eigenvalues as a lazy tensor # which is useful for root decompositions here (see the root_decomposition # method above) @@ -307,8 +308,8 @@ def _symeig( def __add__( self: LinearOperator, # shape: (..., #M, #N) - other: Union[Tensor, LinearOperator, float], # shape: (..., #M, #N) - ) -> Union[LinearOperator, Tensor]: # shape: (..., M, N) + other: Tensor | LinearOperator | float, # shape: (..., #M, #N) + ) -> LinearOperator | Tensor: # shape: (..., M, N) if isinstance(other, ConstantDiagLinearOperator) and self._diag_is_constant: # the other cases have only partial implementations return KroneckerProductAddedDiagLinearOperator(self.linear_op, self.diag_tensor + other) diff --git a/linear_operator/operators/kronecker_product_linear_operator.py b/linear_operator/operators/kronecker_product_linear_operator.py index f76fe37e..32c6e055 100644 --- a/linear_operator/operators/kronecker_product_linear_operator.py +++ b/linear_operator/operators/kronecker_product_linear_operator.py @@ -1,8 +1,9 @@ #!/usr/bin/env python3 +from __future__ import annotations import operator from functools import reduce -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable import torch from torch import Tensor @@ -67,7 +68,7 @@ class KroneckerProductLinearOperator(LinearOperator): :param linear_ops: :math:`\boldsymbol K_1, \ldots, \boldsymbol K_P`: the LinearOperators in the Kronecker product. 
""" - def __init__(self, *linear_ops: Union[Tensor, LinearOperator]): + def __init__(self, *linear_ops: Tensor | LinearOperator): try: linear_ops = tuple(to_linear_operator(linear_op) for linear_op in linear_ops) except TypeError: @@ -96,8 +97,8 @@ def __init__(self, *linear_ops: Union[Tensor, LinearOperator]): def __add__( self: LinearOperator, # shape: (..., #M, #N) - other: Union[Tensor, LinearOperator, float], # shape: (..., #M, #N) - ) -> Union[LinearOperator, Tensor]: # shape: (..., M, N) + other: Tensor | LinearOperator | float, # shape: (..., #M, #N) + ) -> LinearOperator | Tensor: # shape: (..., M, N) if isinstance(other, (KroneckerProductDiagLinearOperator, ConstantDiagLinearOperator)): from linear_operator.operators.kronecker_product_added_diag_linear_operator import ( KroneckerProductAddedDiagLinearOperator, @@ -144,8 +145,8 @@ def add_diagonal( return KroneckerProductAddedDiagLinearOperator(self, diag_tensor) def diagonalization( - self: LinearOperator, method: Optional[str] = None # shape: (*batch, N, N) - ) -> Tuple[Tensor, Optional[LinearOperator]]: # shape: (*batch, M), (*batch, N, M) + self: LinearOperator, method: str | None = None # shape: (*batch, N, N) + ) -> tuple[Tensor, LinearOperator | None]: # shape: (*batch, M), (*batch, N, M) if method is None: method = "symeig" return super().diagonalization(method=method) @@ -161,12 +162,12 @@ def inverse( def inv_quad_logdet( self: LinearOperator, # shape: (*batch, N, N) - inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) - logdet: Optional[bool] = False, - reduce_inv_quad: Optional[bool] = True, - ) -> Tuple[ # fmt: off - Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) - Optional[Tensor], # shape: (...) + inv_quad_rhs: Tensor | None = None, # shape: (*batch, N, M) or (*batch, N) + logdet: bool | None = False, + reduce_inv_quad: bool | None = True, + ) -> tuple[ # fmt: off + Tensor | None, # shape: (*batch, M) or (*batch) or (0) + Tensor | None, # shape: (...) 
]: # fmt: on if inv_quad_rhs is not None: inv_quad_term, _ = super().inv_quad_logdet( @@ -179,7 +180,7 @@ def inv_quad_logdet( @cached(name="cholesky") def _cholesky( - self: LinearOperator, upper: Optional[bool] = False # shape: (*batch, N, N) + self: LinearOperator, upper: bool | None = False # shape: (*batch, N, N) ) -> LinearOperator: # shape: (*batch, N, N) chol_factors = [lt.cholesky(upper=upper) for lt in self.linear_ops] return KroneckerProductTriangularLinearOperator(*chol_factors, upper=upper) @@ -190,7 +191,7 @@ def _diagonal( return _kron_diag(*self.linear_ops) def _expand_batch( - self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + self: LinearOperator, batch_shape: torch.Size | list[int] # shape: (..., M, N) ) -> LinearOperator: # shape: (..., M, N) return self.__class__(*[linear_op._expand_batch(batch_shape) for linear_op in self.linear_ops]) @@ -217,15 +218,15 @@ def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indice def _solve( self: LinearOperator, # shape: (..., N, N) rhs: torch.Tensor, # shape: (..., N, C) - preconditioner: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, # shape: (..., N, C) - num_tridiag: Optional[int] = 0, - ) -> Union[ - torch.Tensor, # shape: (..., N, C) - Tuple[ + preconditioner: Callable[[torch.Tensor], torch.Tensor] | None = None, # shape: (..., N, C) + num_tridiag: int | None = 0, + ) -> ( + torch.Tensor # shape: (..., N, C) + | tuple[ torch.Tensor, # shape: (..., N, C) torch.Tensor, # Note that in case of a tuple the second term size depends on num_tridiag # shape: (...) - ], - ]: + ] + ): # Computes solve by exploiting the identity (A \kron B)^-1 = A^-1 \kron B^-1 # we perform the solve first before worrying about any tridiagonal matrices @@ -284,7 +285,7 @@ def _matmul( @cached(name="root_decomposition") def root_decomposition( - self: LinearOperator, method: Optional[str] = None # shape: (*batch, N, N) + self: LinearOperator, method: str | None = None # shape: (*batch, N, N) ) -> LinearOperator: # shape: (*batch, N, N) from linear_operator.operators import RootLinearOperator @@ -299,10 +300,10 @@ def root_decomposition( @cached(name="root_inv_decomposition") def root_inv_decomposition( self: LinearOperator, # shape: (*batch, N, N) - initial_vectors: Optional[torch.Tensor] = None, - test_vectors: Optional[torch.Tensor] = None, - method: Optional[str] = None, - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, N) + initial_vectors: torch.Tensor | None = None, + test_vectors: torch.Tensor | None = None, + method: str | None = None, + ) -> LinearOperator | Tensor: # shape: (..., N, N) from linear_operator.operators import RootLinearOperator # return a dense root decomposition if the matrix is small @@ -322,7 +323,7 @@ def _size(self) -> torch.Size: @cached(name="svd") def _svd( self: LinearOperator, # shape: (*batch, N, N) - ) -> Tuple[LinearOperator, Tensor, LinearOperator]: # shape: (*batch, N, N), (..., N), (*batch, N, N) + ) -> tuple[LinearOperator, Tensor, LinearOperator]: # shape: (*batch, N, N), (..., N), (*batch, N, N) U, S, V = [], [], [] for lt in self.linear_ops: U_, S_, V_ = lt.svd() @@ -337,8 +338,8 @@ def _svd( def _symeig( self: LinearOperator, # shape: (*batch, N, N) eigenvectors: bool = False, - return_evals_as_lazy: Optional[bool] = False, - ) -> Tuple[Tensor, Optional[LinearOperator]]: # shape: (*batch, M), (*batch, N, M) + return_evals_as_lazy: bool | None = False, + ) -> tuple[Tensor, LinearOperator | None]: # shape: (*batch, M), (*batch, N, M) # 
return_evals_as_lazy is a flag to return the eigenvalues as a lazy tensor # which is useful for root decompositions here (see the root_decomposition # method above) @@ -360,8 +361,8 @@ def _symeig( def _t_matmul( self: LinearOperator, # shape: (*batch, M, N) - rhs: Union[Tensor, LinearOperator], # shape: (*batch2, M, P) - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, P) + rhs: Tensor | LinearOperator, # shape: (*batch2, M, P) + ) -> LinearOperator | Tensor: # shape: (..., N, P) is_vec = rhs.ndimension() == 1 if is_vec: rhs = rhs.unsqueeze(-1) @@ -397,15 +398,15 @@ def inverse( @cached(name="cholesky") def _cholesky( - self: LinearOperator, upper: Optional[bool] = False # shape: (*batch, N, N) + self: LinearOperator, upper: bool | None = False # shape: (*batch, N, N) ) -> LinearOperator: # shape: (*batch, N, N) raise NotImplementedError("_cholesky not applicable to triangular lazy tensors") def _cholesky_solve( self: LinearOperator, # shape: (*batch, N, N) - rhs: Union[LinearOperator, Tensor], # shape: (*batch2, N, M) - upper: Optional[bool] = False, - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, M) + rhs: LinearOperator | Tensor, # shape: (*batch2, N, M) + upper: bool | None = False, + ) -> LinearOperator | Tensor: # shape: (..., N, M) if upper: # res = (U.T @ U)^-1 @ v = U^-1 @ U^-T @ v w = self._transpose_nonbatch().solve(rhs) @@ -419,14 +420,14 @@ def _cholesky_solve( def _symeig( self: LinearOperator, # shape: (*batch, N, N) eigenvectors: bool = False, - return_evals_as_lazy: Optional[bool] = False, - ) -> Tuple[Tensor, Optional[LinearOperator]]: # shape: (*batch, M), (*batch, N, M) + return_evals_as_lazy: bool | None = False, + ) -> tuple[Tensor, LinearOperator | None]: # shape: (*batch, M), (*batch, N, M) raise NotImplementedError("_symeig not applicable to triangular lazy tensors") def solve( self: LinearOperator, # shape: (..., N, N) right_tensor: Tensor, # shape: (..., N, P) or (N) - left_tensor: Optional[Tensor] = None, # shape: (..., O, N) + left_tensor: Tensor | None = None, # shape: (..., O, N) ) -> Tensor: # shape: (..., N, P) or (..., N) or (..., O, P) or (..., O) # For triangular components, using triangular-triangular substitution should generally be good return self._inv_matmul(right_tensor=right_tensor, left_tensor=left_tensor) @@ -441,18 +442,18 @@ class KroneckerProductDiagLinearOperator(DiagLinearOperator, KroneckerProductTri :param linear_ops: Diagonal linear operators (:math:`\mathbf D_1, \mathbf D_2, \ldots \mathbf D_\ell`).
""" - def __init__(self, *linear_ops: Tuple[DiagLinearOperator, ...]): + def __init__(self, *linear_ops: tuple[DiagLinearOperator, ...]): if not all(isinstance(lt, DiagLinearOperator) for lt in linear_ops): raise RuntimeError("Components of KroneckerProductDiagLinearOperator must be DiagLinearOperator.") super(KroneckerProductTriangularLinearOperator, self).__init__(*linear_ops) self.upper = False - def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[Optional[Tensor], ...]: + def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> tuple[Tensor | None, ...]: return KroneckerProductTriangularLinearOperator._bilinear_derivative(self, left_vecs, right_vecs) @cached(name="cholesky") def _cholesky( - self: LinearOperator, upper: Optional[bool] = False # shape: (*batch, N, N) + self: LinearOperator, upper: bool | None = False # shape: (*batch, N, N) ) -> LinearOperator: # shape: (*batch, N, N) chol_factors = [lt.cholesky(upper=upper) for lt in self.linear_ops] return KroneckerProductDiagLinearOperator(*chol_factors) @@ -464,12 +465,12 @@ def _diag( return _kron_diag(*self.linear_ops) def _expand_batch( - self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + self: LinearOperator, batch_shape: torch.Size | list[int] # shape: (..., M, N) ) -> LinearOperator: # shape: (..., M, N) return KroneckerProductTriangularLinearOperator._expand_batch(self, batch_shape) def _mul_constant( - self: LinearOperator, other: Union[float, torch.Tensor] # shape: (*batch, M, N) + self: LinearOperator, other: float | torch.Tensor # shape: (*batch, M, N) ) -> LinearOperator: # shape: (*batch, M, N) return DiagLinearOperator(self._diag * other.unsqueeze(-1)) @@ -485,8 +486,8 @@ def _size(self) -> torch.Size: def _symeig( self: LinearOperator, # shape: (*batch, N, N) eigenvectors: bool = False, - return_evals_as_lazy: Optional[bool] = False, - ) -> Tuple[Tensor, Optional[LinearOperator]]: # shape: (*batch, M), (*batch, N, M) + return_evals_as_lazy: bool | None = False, + ) -> tuple[Tensor, LinearOperator | None]: # shape: (*batch, M), (*batch, N, M) # return_evals_as_lazy is a flag to return the eigenvalues as a lazy tensor # which is useful for root decompositions here (see the root_decomposition # method above) diff --git a/linear_operator/operators/linear_operator_representation_tree.py b/linear_operator/operators/linear_operator_representation_tree.py index 838d0d63..126a2cc9 100644 --- a/linear_operator/operators/linear_operator_representation_tree.py +++ b/linear_operator/operators/linear_operator_representation_tree.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +from __future__ import annotations + import itertools diff --git a/linear_operator/operators/low_rank_root_added_diag_linear_operator.py b/linear_operator/operators/low_rank_root_added_diag_linear_operator.py index 541d9493..b81d2eff 100644 --- a/linear_operator/operators/low_rank_root_added_diag_linear_operator.py +++ b/linear_operator/operators/low_rank_root_added_diag_linear_operator.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 -from typing import Callable, Optional, Tuple, Union +from __future__ import annotations + +from typing import Callable import torch from torch import Tensor @@ -45,7 +47,7 @@ def chol_cap_mat(self): return chol_cap_mat def _mul_constant( - self: LinearOperator, other: Union[float, torch.Tensor] # shape: (*batch, M, N) + self: LinearOperator, other: float | torch.Tensor # shape: (*batch, M, N) ) -> LinearOperator: # shape: (*batch, M, N) # We have to over-ride 
this here for the case where the constant is negative if other > 0: @@ -54,21 +56,21 @@ def _mul_constant( res = AddedDiagLinearOperator(self._linear_op._mul_constant(other), self._diag_tensor._mul_constant(other)) return res - def _preconditioner(self) -> Tuple[Optional[Callable], Optional[LinearOperator], Optional[torch.Tensor]]: + def _preconditioner(self) -> tuple[Callable | None, LinearOperator | None, torch.Tensor | None]: return None, None, None def _solve( self: LinearOperator, # shape: (..., N, N) rhs: torch.Tensor, # shape: (..., N, C) - preconditioner: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, # shape: (..., N, C) - num_tridiag: Optional[int] = 0, - ) -> Union[ - torch.Tensor, # shape: (..., N, C) - Tuple[ + preconditioner: Callable[[torch.Tensor], torch.Tensor] | None = None, # shape: (..., N, C) + num_tridiag: int | None = 0, + ) -> ( + torch.Tensor # shape: (..., N, C) + | tuple[ torch.Tensor, # shape: (..., N, C) torch.Tensor, # Note that in case of a tuple the second term size depends on num_tridiag # shape: (...) - ], - ]: + ] + ): A_inv = self._diag_tensor.inverse() # This is fine since it's a DiagLinearOperator U = self._linear_op.root V = self._linear_op.root.mT @@ -82,7 +84,7 @@ def _solve( return solve - def _solve_preconditioner(self) -> Optional[Callable]: + def _solve_preconditioner(self) -> Callable | None: return None def _sum_batch(self, dim: int) -> LinearOperator: @@ -98,8 +100,8 @@ def _logdet(self): def __add__( self: LinearOperator, # shape: (..., #M, #N) - other: Union[Tensor, LinearOperator, float], # shape: (..., #M, #N) - ) -> Union[LinearOperator, Tensor]: # shape: (..., M, N) + other: Tensor | LinearOperator | float, # shape: (..., #M, #N) + ) -> LinearOperator | Tensor: # shape: (..., M, N) from linear_operator.operators.diag_linear_operator import DiagLinearOperator if isinstance(other, DiagLinearOperator): @@ -109,12 +111,12 @@ def __add__( def inv_quad_logdet( self: LinearOperator, # shape: (*batch, N, N) - inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) - logdet: Optional[bool] = False, - reduce_inv_quad: Optional[bool] = True, - ) -> Tuple[ # fmt: off - Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) - Optional[Tensor], # shape: (...) + inv_quad_rhs: Tensor | None = None, # shape: (*batch, N, M) or (*batch, N) + logdet: bool | None = False, + reduce_inv_quad: bool | None = True, + ) -> tuple[ # fmt: off + Tensor | None, # shape: (*batch, M) or (*batch) or (0) + Tensor | None, # shape: (...) 
]: # fmt: on if not self.is_square: raise RuntimeError( @@ -158,7 +160,7 @@ def inv_quad_logdet( def solve( self: LinearOperator, # shape: (..., N, N) right_tensor: Tensor, # shape: (..., N, P) or (N) - left_tensor: Optional[Tensor] = None, # shape: (..., O, N) + left_tensor: Tensor | None = None, # shape: (..., O, N) ) -> Tensor: # shape: (..., N, P) or (..., N) or (..., O, P) or (..., O) if not self.is_square: raise RuntimeError( diff --git a/linear_operator/operators/low_rank_root_linear_operator.py b/linear_operator/operators/low_rank_root_linear_operator.py index 8f5f0c84..b6d80f60 100644 --- a/linear_operator/operators/low_rank_root_linear_operator.py +++ b/linear_operator/operators/low_rank_root_linear_operator.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -from typing import Union +from __future__ import annotations import torch from torch import Tensor @@ -51,8 +51,8 @@ def add_diagonal( def __add__( self: LinearOperator, # shape: (..., #M, #N) - other: Union[Tensor, LinearOperator, float], # shape: (..., #M, #N) - ) -> Union[LinearOperator, Tensor]: # shape: (..., M, N) + other: Tensor | LinearOperator | float, # shape: (..., #M, #N) + ) -> LinearOperator | Tensor: # shape: (..., M, N) from linear_operator.operators.diag_linear_operator import DiagLinearOperator from linear_operator.operators.low_rank_root_added_diag_linear_operator import ( LowRankRootAddedDiagLinearOperator, diff --git a/linear_operator/operators/masked_linear_operator.py b/linear_operator/operators/masked_linear_operator.py index 9b213152..782a4313 100644 --- a/linear_operator/operators/masked_linear_operator.py +++ b/linear_operator/operators/masked_linear_operator.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, Union +from __future__ import annotations import torch from torch import Tensor @@ -61,8 +61,8 @@ def _matmul( def _t_matmul( self: LinearOperator, # shape: (*batch, M, N) - rhs: Union[Tensor, LinearOperator], # shape: (*batch2, M, P) - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, P) + rhs: Tensor | LinearOperator, # shape: (*batch2, M, P) + ) -> LinearOperator | Tensor: # shape: (..., N, P) rhs_expanded = self._expand(rhs, self.row_mask) res_expanded = self.base._t_matmul(rhs_expanded) res = res_expanded[..., self.col_mask, :] @@ -92,13 +92,13 @@ def to_dense( full_dense = self.base.to_dense() return full_dense[..., self.row_mask, :][..., :, self.col_mask] - def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[Optional[Tensor], ...]: + def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> tuple[Tensor | None, ...]: left_vecs = self._expand(left_vecs, self.row_mask) right_vecs = self._expand(right_vecs, self.col_mask) return self.base._bilinear_derivative(left_vecs, right_vecs) + (None, None) def _expand_batch( - self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + self: LinearOperator, batch_shape: torch.Size | list[int] # shape: (..., M, N) ) -> LinearOperator: # shape: (..., M, N) return self.__class__(self.base._expand_batch(batch_shape), self.row_mask, self.col_mask) diff --git a/linear_operator/operators/matmul_linear_operator.py b/linear_operator/operators/matmul_linear_operator.py index 5f8ffe51..8204bc38 100644 --- a/linear_operator/operators/matmul_linear_operator.py +++ b/linear_operator/operators/matmul_linear_operator.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 - -from typing import List, Optional, Tuple, Union +from __future__ import annotations import torch from torch import Tensor @@ -46,7 
+45,7 @@ def __init__(self, left_linear_op, right_linear_op): self.right_linear_op = right_linear_op def _expand_batch( - self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + self: LinearOperator, batch_shape: torch.Size | list[int] # shape: (..., M, N) ) -> LinearOperator: # shape: (..., M, N) return self.__class__( self.left_linear_op._expand_batch(batch_shape), self.right_linear_op._expand_batch(batch_shape) @@ -103,11 +102,11 @@ def _matmul( def _t_matmul( self: LinearOperator, # shape: (*batch, M, N) - rhs: Union[Tensor, LinearOperator], # shape: (*batch2, M, P) - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, P) + rhs: Tensor | LinearOperator, # shape: (*batch2, M, P) + ) -> LinearOperator | Tensor: # shape: (..., N, P) return self.right_linear_op._t_matmul(self.left_linear_op._t_matmul(rhs)) - def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[Optional[Tensor], ...]: + def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> tuple[Tensor | None, ...]: if left_vecs.ndimension() == 1: left_vecs = left_vecs.unsqueeze(1) right_vecs = right_vecs.unsqueeze(1) @@ -136,9 +135,10 @@ def to_dense( self: LinearOperator, # shape: (*batch, M, N) ) -> Tensor: # shape: (*batch, M, N) # Use element-wise multiplication for DiagLinearOperators - if isinstance(self.left_linear_op, DiagLinearOperator): - return self.left_linear_op._diag.unsqueeze(-1) * self.right_linear_op.to_dense() - if isinstance(self.right_linear_op, DiagLinearOperator): - return self.left_linear_op.to_dense() * self.right_linear_op._diag.unsqueeze(-2) - - return torch.matmul(self.left_linear_op.to_dense(), self.right_linear_op.to_dense()) + match (self.left_linear_op, self.right_linear_op): + case (DiagLinearOperator() as left, _): + return left._diag.unsqueeze(-1) * self.right_linear_op.to_dense() + case (_, DiagLinearOperator() as right): + return self.left_linear_op.to_dense() * right._diag.unsqueeze(-2) + case _: + return torch.matmul(self.left_linear_op.to_dense(), self.right_linear_op.to_dense()) diff --git a/linear_operator/operators/mul_linear_operator.py b/linear_operator/operators/mul_linear_operator.py index 5371dbf7..e7095575 100644 --- a/linear_operator/operators/mul_linear_operator.py +++ b/linear_operator/operators/mul_linear_operator.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -from typing import List, Optional, Tuple, Union +from __future__ import annotations import torch from torch import Tensor @@ -80,7 +80,7 @@ def _matmul( return res def _mul_constant( - self: LinearOperator, other: Union[float, torch.Tensor] # shape: (*batch, M, N) + self: LinearOperator, other: float | torch.Tensor # shape: (*batch, M, N) ) -> LinearOperator: # shape: (*batch, M, N) if other > 0: res = self.__class__(self.left_linear_op._mul_constant(other), self.right_linear_op) @@ -90,7 +90,7 @@ def _mul_constant( res = super()._mul_constant(other) return res - def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[Optional[Tensor], ...]: + def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> tuple[Tensor | None, ...]: if left_vecs.ndimension() == 1: left_vecs = left_vecs.unsqueeze(1) right_vecs = right_vecs.unsqueeze(1) @@ -130,7 +130,7 @@ def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[O return tuple(list(left_deriv_args) + list(right_deriv_args)) def _expand_batch( - self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + self: LinearOperator, 
batch_shape: torch.Size | list[int] # shape: (..., M, N) ) -> LinearOperator: # shape: (..., M, N) return self.__class__( self.left_linear_op._expand_batch(batch_shape), self.right_linear_op._expand_batch(batch_shape) @@ -151,7 +151,7 @@ def _transpose_nonbatch( # mul.linear_op only works with symmetric matrices return self - def representation(self) -> Tuple[torch.Tensor, ...]: + def representation(self) -> tuple[torch.Tensor, ...]: """ Returns the Tensors that are used to define the LinearOperator """ diff --git a/linear_operator/operators/permutation_linear_operator.py b/linear_operator/operators/permutation_linear_operator.py index c59e3765..f21ff6e6 100644 --- a/linear_operator/operators/permutation_linear_operator.py +++ b/linear_operator/operators/permutation_linear_operator.py @@ -1,4 +1,6 @@ -from typing import Callable, Optional, Tuple, Union +from __future__ import annotations + +from typing import Callable import torch from torch import Tensor @@ -20,22 +22,22 @@ def inverse( def _solve( self: LinearOperator, # shape: (..., N, N) rhs: torch.Tensor, # shape: (..., N, C) - preconditioner: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, # shape: (..., N, C) - num_tridiag: Optional[int] = 0, - ) -> Union[ - torch.Tensor, # shape: (..., N, C) - Tuple[ + preconditioner: Callable[[torch.Tensor], torch.Tensor] | None = None, # shape: (..., N, C) + num_tridiag: int | None = 0, + ) -> ( + torch.Tensor # shape: (..., N, C) + | tuple[ torch.Tensor, # shape: (..., N, C) torch.Tensor, # Note that in case of a tuple the second term size depends on num_tridiag # shape: (...) - ], - ]: + ] + ): self._matmul_check_shape(rhs) return self.inverse() @ rhs def _matmul_check_shape(self, rhs: Tensor) -> None: if rhs.shape[-2] != self.shape[-1]: raise ValueError( - f"{rhs.shape[0] = } incompatible with first dimensions of" + f"rhs.shape[0]={rhs.shape[0]} incompatible with first dimensions of " f"permutation operator with shape {self.shape}." ) @@ -43,7 +45,7 @@ def _matmul_batch_shape(self, rhs: Tensor) -> torch.Size: return torch.broadcast_shapes(self.batch_shape, rhs.shape[:-2]) @property - def dtype(self) -> Optional[torch.dtype]: + def dtype(self) -> torch.dtype | None: return self._dtype @@ -65,7 +67,7 @@ class PermutationLinearOperator(AbstractPermutationLinearOperator): def __init__( self, perm: Tensor, - inv_perm: Optional[Tensor] = None, + inv_perm: Tensor | None = None, validate_args: bool = True, ): if not isinstance(perm, Tensor): @@ -88,7 +90,7 @@ def __init__( if (sorted_perm[..., i] != i).any(): raise ValueError( f"Invalid perm-inv_perm input, index {i} missing or not at " - f"correct index for permutation with {perm.shape = }." + f"correct index for permutation with perm.shape={perm.shape}." ) self.perm = perm @@ -115,7 +117,7 @@ def _matmul( indices = batch_indices + (perm_indices, final_indices) return expanded_rhs[indices] - def _batch_indexing_helper(self, batch_shape: torch.Size) -> Tuple: + def _batch_indexing_helper(self, batch_shape: torch.Size) -> tuple: """Creates a tuple of indices with broadcastable shapes to preserve the batch dimensions when indexing into the non-batch dimensions with `perm`. 
@@ -185,7 +187,7 @@ def _transpose_nonbatch( return self @property - def dtype(self) -> Optional[torch.dtype]: + def dtype(self) -> torch.dtype | None: return self._dtype def type(self: LinearOperator, dtype: torch.dtype) -> LinearOperator: @@ -193,5 +195,5 @@ def type(self: LinearOperator, dtype: torch.dtype) -> LinearOperator: return self @property - def device(self) -> Optional[torch.device]: + def device(self) -> torch.device | None: return None diff --git a/linear_operator/operators/psd_sum_linear_operator.py b/linear_operator/operators/psd_sum_linear_operator.py index 4535e75d..c983ce50 100644 --- a/linear_operator/operators/psd_sum_linear_operator.py +++ b/linear_operator/operators/psd_sum_linear_operator.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +from __future__ import annotations + from torch import Tensor from linear_operator.operators._linear_operator import LinearOperator diff --git a/linear_operator/operators/root_linear_operator.py b/linear_operator/operators/root_linear_operator.py index bc029224..15af4e71 100644 --- a/linear_operator/operators/root_linear_operator.py +++ b/linear_operator/operators/root_linear_operator.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -from typing import List, Optional, Union +from __future__ import annotations import torch from torch import Tensor @@ -28,7 +28,7 @@ def _diagonal( return super()._diagonal() def _expand_batch( - self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + self: LinearOperator, batch_shape: torch.Size | list[int] # shape: (..., M, N) ) -> LinearOperator: # shape: (..., M, N) if len(batch_shape) == 0: return self @@ -72,7 +72,7 @@ def _matmul( return self.root._matmul(self.root._t_matmul(rhs)) def _mul_constant( - self: LinearOperator, other: Union[float, torch.Tensor] # shape: (*batch, M, N) + self: LinearOperator, other: float | torch.Tensor # shape: (*batch, M, N) ) -> LinearOperator: # shape: (*batch, M, N) if (other > 0).all(): res = self.__class__(self.root._mul_constant(other.sqrt())) @@ -82,29 +82,29 @@ def _mul_constant( def _t_matmul( self: LinearOperator, # shape: (*batch, M, N) - rhs: Union[Tensor, LinearOperator], # shape: (*batch2, M, P) - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, P) + rhs: Tensor | LinearOperator, # shape: (*batch2, M, P) + ) -> LinearOperator | Tensor: # shape: (..., N, P) # Matrix is symmetric return self._matmul(rhs) def add_low_rank( self: LinearOperator, # shape: (*batch, N, N) - low_rank_mat: Union[Tensor, LinearOperator], # shape: (..., N, _) - root_decomp_method: Optional[str] = None, - root_inv_decomp_method: Optional[str] = None, - generate_roots: Optional[bool] = True, + low_rank_mat: Tensor | LinearOperator, # shape: (..., N, _) + root_decomp_method: str | None = None, + root_inv_decomp_method: str | None = None, + generate_roots: bool | None = True, **root_decomp_kwargs, ) -> LinearOperator: # shape: (*batch, N, N) return super().add_low_rank(low_rank_mat, root_inv_decomp_method=root_inv_decomp_method) def root_decomposition( - self: LinearOperator, method: Optional[str] = None # shape: (*batch, N, N) + self: LinearOperator, method: str | None = None # shape: (*batch, N, N) ) -> LinearOperator: # shape: (*batch, N, N) return self def _root_decomposition( self: LinearOperator, # shape: (..., N, N) - ) -> Union[torch.Tensor, LinearOperator]: # shape: (..., N, N) + ) -> torch.Tensor | LinearOperator: # shape: (..., N, N) return self.root def _root_decomposition_size(self) -> int: diff --git 
a/linear_operator/operators/sum_batch_linear_operator.py b/linear_operator/operators/sum_batch_linear_operator.py index feaf85d4..b15c665d 100644 --- a/linear_operator/operators/sum_batch_linear_operator.py +++ b/linear_operator/operators/sum_batch_linear_operator.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import torch from torch import Tensor diff --git a/linear_operator/operators/sum_kronecker_linear_operator.py b/linear_operator/operators/sum_kronecker_linear_operator.py index 760a08aa..22a88b3a 100644 --- a/linear_operator/operators/sum_kronecker_linear_operator.py +++ b/linear_operator/operators/sum_kronecker_linear_operator.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 -from typing import Callable, Optional, Tuple, Union +from __future__ import annotations + +from typing import Callable import torch from torch import Tensor @@ -40,15 +42,15 @@ def _sum_formulation(self): def _solve( self: LinearOperator, # shape: (..., N, N) rhs: torch.Tensor, # shape: (..., N, C) - preconditioner: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, # shape: (..., N, C) - num_tridiag: Optional[int] = 0, - ) -> Union[ - torch.Tensor, # shape: (..., N, C) - Tuple[ + preconditioner: Callable[[torch.Tensor], torch.Tensor] | None = None, # shape: (..., N, C) + num_tridiag: int | None = 0, + ) -> ( + torch.Tensor # shape: (..., N, C) + | tuple[ torch.Tensor, # shape: (..., N, C) torch.Tensor, # Note that in case of a tuple the second term size depends on num_tridiag # shape: (...) - ], - ]: + ] + ): inner_mat = self._sum_formulation # root decomposition may not be trustworthy if it uses a different method than # root_inv_decomposition. so ensure that we call this locally @@ -72,7 +74,7 @@ def _logdet( def _root_decomposition( self: LinearOperator, # shape: (..., N, N) - ) -> Union[torch.Tensor, LinearOperator]: # shape: (..., N, N) + ) -> torch.Tensor | LinearOperator: # shape: (..., N, N) inner_mat = self._sum_formulation lt2_root = KroneckerProductLinearOperator( *[lt.root_decomposition().root for lt in self.linear_ops[1].linear_ops] @@ -83,9 +85,9 @@ def _root_decomposition( def _root_inv_decomposition( self: LinearOperator, # shape: (*batch, N, N) - initial_vectors: Optional[torch.Tensor] = None, - test_vectors: Optional[torch.Tensor] = None, - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, N) + initial_vectors: torch.Tensor | None = None, + test_vectors: torch.Tensor | None = None, + ) -> LinearOperator | Tensor: # shape: (..., N, N) inner_mat = self._sum_formulation lt2_root_inv = self.linear_ops[1].root_inv_decomposition().root inner_mat_root_inv = inner_mat.root_inv_decomposition().root @@ -94,12 +96,12 @@ def _root_inv_decomposition( def inv_quad_logdet( self: LinearOperator, # shape: (*batch, N, N) - inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) - logdet: Optional[bool] = False, - reduce_inv_quad: Optional[bool] = True, - ) -> Tuple[ # fmt: off - Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) - Optional[Tensor], # shape: (...) + inv_quad_rhs: Tensor | None = None, # shape: (*batch, N, M) or (*batch, N) + logdet: bool | None = False, + reduce_inv_quad: bool | None = True, + ) -> tuple[ # fmt: off + Tensor | None, # shape: (*batch, M) or (*batch) or (0) + Tensor | None, # shape: (...) 
]: # fmt: on inv_quad_term = None logdet_term = None diff --git a/linear_operator/operators/sum_linear_operator.py b/linear_operator/operators/sum_linear_operator.py index 03fa196b..99052cdf 100644 --- a/linear_operator/operators/sum_linear_operator.py +++ b/linear_operator/operators/sum_linear_operator.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -from typing import List, Optional, Tuple, Union +from __future__ import annotations import torch from torch import Tensor @@ -31,7 +31,7 @@ def _diagonal( return sum(linear_op._diagonal().contiguous() for linear_op in self.linear_ops) def _expand_batch( - self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + self: LinearOperator, batch_shape: torch.Size | list[int] # shape: (..., M, N) ) -> LinearOperator: # shape: (..., M, N) expanded_tensors = [linear_op._expand_batch(batch_shape) for linear_op in self.linear_ops] return self.__class__(*expanded_tensors) @@ -51,12 +51,12 @@ def _matmul( return sum(linear_op._matmul(rhs) for linear_op in self.linear_ops) def _mul_constant( - self: LinearOperator, other: Union[float, torch.Tensor] # shape: (*batch, M, N) + self: LinearOperator, other: float | torch.Tensor # shape: (*batch, M, N) ) -> LinearOperator: # shape: (*batch, M, N) # We're using a custom method here - the constant mul is applied to the base_linear_ops return self.__class__(*[lt._mul_constant(other) for lt in self.linear_ops]) - def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[Optional[Tensor], ...]: + def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> tuple[Tensor | None, ...]: return tuple( var for linear_op in self.linear_ops for var in linear_op._bilinear_derivative(left_vecs, right_vecs) ) @@ -69,8 +69,8 @@ def _sum_batch(self, dim: int) -> LinearOperator: def _t_matmul( self: LinearOperator, # shape: (*batch, M, N) - rhs: Union[Tensor, LinearOperator], # shape: (*batch2, M, P) - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, P) + rhs: Tensor | LinearOperator, # shape: (*batch2, M, P) + ) -> LinearOperator | Tensor: # shape: (..., N, P) return sum(linear_op._t_matmul(rhs) for linear_op in self.linear_ops) def _transpose_nonbatch( @@ -87,29 +87,30 @@ def to_dense( def __add__( self: LinearOperator, # shape: (..., #M, #N) - other: Union[Tensor, LinearOperator, float], # shape: (..., #M, #N) - ) -> Union[LinearOperator, Tensor]: # shape: (..., M, N) + other: Tensor | LinearOperator | float, # shape: (..., #M, #N) + ) -> LinearOperator | Tensor: # shape: (..., M, N) from linear_operator.operators.added_diag_linear_operator import AddedDiagLinearOperator from linear_operator.operators.diag_linear_operator import DiagLinearOperator - if isinstance(other, ZeroLinearOperator): - return self - elif isinstance(other, DiagLinearOperator): - return AddedDiagLinearOperator(self, other) - elif isinstance(other, SumLinearOperator): - return SumLinearOperator(*(list(self.linear_ops) + list(other.linear_ops))) - elif isinstance(other, LinearOperator): - return SumLinearOperator(*(list(self.linear_ops) + [other])) - elif isinstance(other, Tensor): - # get broadcast shape, assuming mul broadcasting the same as add broadcasting - broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape) - - # to_linear_operator + broadcast other - broadcasted_other = to_linear_operator(other.expand(broadcasted_shape)) - - # update the lazy tensors' shape as well - new_self = self if broadcasted_shape == self.shape else self._expand_batch(broadcasted_shape[:-2]) - - return 
SumLinearOperator(*(list(new_self.linear_ops) + [broadcasted_other])) - else: - raise AttributeError("other must be a LinearOperator") + match other: + case ZeroLinearOperator(): + return self + case DiagLinearOperator(): + return AddedDiagLinearOperator(self, other) + case SumLinearOperator(): + return SumLinearOperator(*(list(self.linear_ops) + list(other.linear_ops))) + case LinearOperator(): + return SumLinearOperator(*(list(self.linear_ops) + [other])) + case Tensor(): + # get broadcast shape, assuming mul broadcasting the same as add broadcasting + broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape) + + # to_linear_operator + broadcast other + broadcasted_other = to_linear_operator(other.expand(broadcasted_shape)) + + # update the lazy tensors' shape as well + new_self = self if broadcasted_shape == self.shape else self._expand_batch(broadcasted_shape[:-2]) + + return SumLinearOperator(*(list(new_self.linear_ops) + [broadcasted_other])) + case _: + raise AttributeError("other must be a LinearOperator") diff --git a/linear_operator/operators/toeplitz_linear_operator.py b/linear_operator/operators/toeplitz_linear_operator.py index 128873e8..7519b232 100644 --- a/linear_operator/operators/toeplitz_linear_operator.py +++ b/linear_operator/operators/toeplitz_linear_operator.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -from typing import List, Optional, Tuple, Union +from __future__ import annotations import torch from torch import Tensor @@ -31,7 +31,7 @@ def _diagonal( return diag_term.expand(*self.column.size()) def _expand_batch( - self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + self: LinearOperator, batch_shape: torch.Size | list[int] # shape: (..., M, N) ) -> LinearOperator: # shape: (..., M, N) return self.__class__(self.column.expand(*batch_shape, self.column.size(-1))) @@ -47,12 +47,12 @@ def _matmul( def _t_matmul( self: LinearOperator, # shape: (*batch, M, N) - rhs: Union[Tensor, LinearOperator], # shape: (*batch2, M, P) - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, P) + rhs: Tensor | LinearOperator, # shape: (*batch2, M, P) + ) -> LinearOperator | Tensor: # shape: (..., N, P) # Matrix is symmetric return self._matmul(rhs) - def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[Optional[Tensor], ...]: + def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> tuple[Tensor | None, ...]: if left_vecs.ndimension() == 1: left_vecs = left_vecs.unsqueeze(1) right_vecs = right_vecs.unsqueeze(1) diff --git a/linear_operator/operators/triangular_linear_operator.py b/linear_operator/operators/triangular_linear_operator.py index 201b2700..0f5d3aa7 100644 --- a/linear_operator/operators/triangular_linear_operator.py +++ b/linear_operator/operators/triangular_linear_operator.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 +from __future__ import annotations -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, TypeAlias import torch from torch import Tensor @@ -12,7 +13,7 @@ from linear_operator.utils.errors import NotPSDError from linear_operator.utils.memoize import cached -Allsor = Union[Tensor, LinearOperator] +Allsor: TypeAlias = Tensor | LinearOperator class _TriangularLinearOperatorBase: @@ -51,8 +52,8 @@ def __init__(self, tensor: Allsor, upper: bool = False) -> None: def __add__( self: LinearOperator, # shape: (..., #M, #N) - other: Union[Tensor, LinearOperator, float], # shape: (..., #M, #N) - ) -> Union[LinearOperator, Tensor]: # shape: (..., M, N) + 
other: Tensor | LinearOperator | float, # shape: (..., #M, #N) + ) -> LinearOperator | Tensor: # shape: (..., M, N) from linear_operator.operators.diag_linear_operator import DiagLinearOperator if isinstance(other, DiagLinearOperator): @@ -64,15 +65,15 @@ def __add__( return self._tensor + other def _cholesky( - self: LinearOperator, upper: Optional[bool] = False # shape: (*batch, N, N) + self: LinearOperator, upper: bool | None = False # shape: (*batch, N, N) ) -> LinearOperator: # shape: (*batch, N, N) raise NotPSDError("TriangularLinearOperator does not allow a Cholesky decomposition") def _cholesky_solve( self: LinearOperator, # shape: (*batch, N, N) - rhs: Union[LinearOperator, Tensor], # shape: (*batch2, N, M) - upper: Optional[bool] = False, - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, M) + rhs: LinearOperator | Tensor, # shape: (*batch2, N, M) + upper: bool | None = False, + ) -> LinearOperator | Tensor: # shape: (..., N, M) # use custom method if implemented try: res = self._tensor._cholesky_solve(rhs=rhs, upper=upper) @@ -93,7 +94,7 @@ def _diagonal( return self._tensor._diagonal() def _expand_batch( - self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + self: LinearOperator, batch_shape: torch.Size | list[int] # shape: (..., M, N) ) -> LinearOperator: # shape: (..., M, N) if len(batch_shape) == 0: return self @@ -109,20 +110,20 @@ def _matmul( return self._tensor.matmul(rhs) def _mul_constant( - self: LinearOperator, other: Union[float, torch.Tensor] # shape: (*batch, M, N) + self: LinearOperator, other: float | torch.Tensor # shape: (*batch, M, N) ) -> LinearOperator: # shape: (*batch, M, N) return self.__class__(self._tensor * other.unsqueeze(-1), upper=self.upper) def _root_decomposition( self: LinearOperator, # shape: (..., N, N) - ) -> Union[torch.Tensor, LinearOperator]: # shape: (..., N, N) + ) -> torch.Tensor | LinearOperator: # shape: (..., N, N) raise NotPSDError("TriangularLinearOperator does not allow a root decomposition") def _root_inv_decomposition( self: LinearOperator, # shape: (*batch, N, N) - initial_vectors: Optional[torch.Tensor] = None, - test_vectors: Optional[torch.Tensor] = None, - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, N) + initial_vectors: torch.Tensor | None = None, + test_vectors: torch.Tensor | None = None, + ) -> LinearOperator | Tensor: # shape: (..., N, N) raise NotPSDError("TriangularLinearOperator does not allow an inverse root decomposition") def _size(self) -> torch.Size: @@ -131,15 +132,15 @@ def _size(self) -> torch.Size: def _solve( self: LinearOperator, # shape: (..., N, N) rhs: torch.Tensor, # shape: (..., N, C) - preconditioner: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, # shape: (..., N, C) - num_tridiag: Optional[int] = 0, - ) -> Union[ - torch.Tensor, # shape: (..., N, C) - Tuple[ + preconditioner: Callable[[torch.Tensor], torch.Tensor] | None = None, # shape: (..., N, C) + num_tridiag: int | None = 0, + ) -> ( + torch.Tensor # shape: (..., N, C) + | tuple[ torch.Tensor, # shape: (..., N, C) torch.Tensor, # Note that in case of a tuple the second term size depends on num_tridiag # shape: (...) 
- ], - ]: + ] + ): # already triangular, can just call solve for the solve return self.solve(rhs) @@ -179,12 +180,12 @@ def exp( def inv_quad_logdet( self: LinearOperator, # shape: (*batch, N, N) - inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) - logdet: Optional[bool] = False, - reduce_inv_quad: Optional[bool] = True, - ) -> Tuple[ # fmt: off - Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) - Optional[Tensor], # shape: (...) + inv_quad_rhs: Tensor | None = None, # shape: (*batch, N, M) or (*batch, N) + logdet: bool | None = False, + reduce_inv_quad: bool | None = True, + ) -> tuple[ # fmt: off + Tensor | None, # shape: (*batch, M) or (*batch) or (0) + Tensor | None, # shape: (...) ]: # fmt: on if inv_quad_rhs is None: inv_quad_term = torch.empty(0, dtype=self.dtype, device=self.device) @@ -216,7 +217,7 @@ def inverse( def solve( self: LinearOperator, # shape: (..., N, N) right_tensor: Tensor, # shape: (..., N, P) or (N) - left_tensor: Optional[Tensor] = None, # shape: (..., O, N) + left_tensor: Tensor | None = None, # shape: (..., O, N) ) -> Tensor: # shape: (..., N, P) or (..., N) or (..., O, P) or (..., O) squeeze = False if right_tensor.dim() == 1: diff --git a/linear_operator/operators/zero_linear_operator.py b/linear_operator/operators/zero_linear_operator.py index f6665a67..cae129ee 100644 --- a/linear_operator/operators/zero_linear_operator.py +++ b/linear_operator/operators/zero_linear_operator.py @@ -2,8 +2,6 @@ from __future__ import annotations -from typing import List, Optional, Tuple, Union - import torch from torch import Tensor @@ -22,9 +20,7 @@ class ZeroLinearOperator(LinearOperator): :param device: Device that the LinearOperator will be operating on. (Default: CPU). """ - def __init__( - self, *sizes: Tuple[int, ...], dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None - ): + def __init__(self, *sizes: tuple[int, ...], dtype: torch.dtype | None = None, device: torch.device | None = None): super(ZeroLinearOperator, self).__init__(*sizes) self.sizes = list(sizes) @@ -32,14 +28,14 @@ def __init__( self._device = device or torch.device("cpu") @property - def dtype(self) -> Optional[torch.dtype]: + def dtype(self) -> torch.dtype | None: return self._dtype @property - def device(self) -> Optional[torch.device]: + def device(self) -> torch.device | None: return self._device - def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[Optional[Tensor], ...]: + def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> tuple[Tensor | None, ...]: raise RuntimeError("Backwards through a ZeroLinearOperator is not possible") def _diagonal( @@ -49,7 +45,7 @@ def _diagonal( return torch.zeros(shape[:-1], dtype=self.dtype, device=self.device) def _expand_batch( - self: LinearOperator, batch_shape: Union[torch.Size, List[int]] # shape: (..., M, N) + self: LinearOperator, batch_shape: torch.Size | list[int] # shape: (..., M, N) ) -> LinearOperator: # shape: (..., M, N) return self.__class__(*batch_shape, *self.sizes[-2:], dtype=self._dtype, device=self._device) @@ -84,7 +80,7 @@ def _prod_batch(self, dim: int) -> LinearOperator: def _root_decomposition( self: LinearOperator, # shape: (..., N, N) - ) -> Union[torch.Tensor, LinearOperator]: # shape: (..., N, N) + ) -> torch.Tensor | LinearOperator: # shape: (..., N, N) raise RuntimeError("ZeroLinearOperators are not positive definite!") def _root_decomposition_size(self) -> int: @@ -92,9 +88,9 @@ def _root_decomposition_size(self) -> int: def 
_root_inv_decomposition( self: LinearOperator, # shape: (*batch, N, N) - initial_vectors: Optional[torch.Tensor] = None, - test_vectors: Optional[torch.Tensor] = None, - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, N) + initial_vectors: torch.Tensor | None = None, + test_vectors: torch.Tensor | None = None, + ) -> LinearOperator | Tensor: # shape: (..., N, N) raise RuntimeError("ZeroLinearOperators are not positive definite!") def _size(self) -> torch.Size: @@ -107,8 +103,8 @@ def _sum_batch(self, dim: int) -> LinearOperator: def _t_matmul( self: LinearOperator, # shape: (*batch, M, N) - rhs: Union[Tensor, LinearOperator], # shape: (*batch2, M, P) - ) -> Union[LinearOperator, Tensor]: # shape: (..., N, P) + rhs: Tensor | LinearOperator, # shape: (*batch2, M, P) + ) -> LinearOperator | Tensor: # shape: (..., N, P) rhs_size_ind = -2 if rhs.ndimension() > 1 else -1 if self.size(-2) != rhs.size(rhs_size_ind): raise RuntimeError("Size mismatch, self: {}, rhs: {}".format(self.size(), rhs.size())) @@ -171,7 +167,7 @@ def add_diagonal( ) return res - def div(self, other: Union[float, torch.Tensor]) -> LinearOperator: + def div(self, other: float | torch.Tensor) -> LinearOperator: return self def inv_quad( @@ -183,12 +179,12 @@ def inv_quad( def inv_quad_logdet( self: LinearOperator, # shape: (*batch, N, N) - inv_quad_rhs: Optional[Tensor] = None, # shape: (*batch, N, M) or (*batch, N) - logdet: Optional[bool] = False, - reduce_inv_quad: Optional[bool] = True, - ) -> Tuple[ # fmt: off - Optional[Tensor], # shape: (*batch, M) or (*batch) or (0) - Optional[Tensor], # shape: (...) + inv_quad_rhs: Tensor | None = None, # shape: (*batch, N, M) or (*batch, N) + logdet: bool | None = False, + reduce_inv_quad: bool | None = True, + ) -> tuple[ # fmt: off + Tensor | None, # shape: (*batch, M) or (*batch) or (0) + Tensor | None, # shape: (...) 
]: # fmt: on raise RuntimeError("ZeroLinearOperators are not invertible!") @@ -199,8 +195,8 @@ def logdet( def matmul( self: LinearOperator, # shape: (*batch, M, N) - other: Union[Tensor, LinearOperator], # shape: (*batch2, N, P) or (*batch2, N) - ) -> Union[Tensor, LinearOperator]: # shape: (..., M, P) or (..., M) + other: Tensor | LinearOperator, # shape: (*batch2, N, P) or (*batch2, N) + ) -> Tensor | LinearOperator: # shape: (..., M, P) or (..., M) tensor_size_ind = -2 if other.ndimension() > 1 else -1 if self.size(-1) != other.size(tensor_size_ind): raise RuntimeError("Size mismatch, self: {}, other: {}".format(self.size(), other.size())) @@ -215,7 +211,7 @@ def matmul( def mul( self: LinearOperator, # shape: (*batch, M, N) - other: Union[float, Tensor, LinearOperator], # shape: (*batch2, M, N) + other: float | Tensor | LinearOperator, # shape: (*batch2, M, N) ) -> LinearOperator: # shape: (..., M, N) shape = torch.broadcast_shapes(self.shape, other.shape) return self.__class__(*shape, dtype=self._dtype, device=self._device) @@ -223,7 +219,7 @@ def mul( def solve( self: LinearOperator, # shape: (..., N, N) right_tensor: Tensor, # shape: (..., N, P) or (N) - left_tensor: Optional[Tensor] = None, # shape: (..., O, N) + left_tensor: Tensor | None = None, # shape: (..., O, N) ) -> Tensor: # shape: (..., N, P) or (..., N) or (..., O, P) or (..., O) raise RuntimeError("ZeroLinearOperators are not invertible!") @@ -243,6 +239,6 @@ def transpose(self, dim1: int, dim2: int) -> LinearOperator: def __add__( self: LinearOperator, # shape: (..., #M, #N) - other: Union[Tensor, LinearOperator, float], # shape: (..., #M, #N) - ) -> Union[LinearOperator, Tensor]: # shape: (..., M, N) + other: Tensor | LinearOperator | float, # shape: (..., #M, #N) + ) -> LinearOperator | Tensor: # shape: (..., M, N) return other diff --git a/linear_operator/settings.py b/linear_operator/settings.py index 98332e2e..40076e15 100644 --- a/linear_operator/settings.py +++ b/linear_operator/settings.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import logging diff --git a/linear_operator/test/__init__.py b/linear_operator/test/__init__.py index e5a0d9b4..9e270c94 100644 --- a/linear_operator/test/__init__.py +++ b/linear_operator/test/__init__.py @@ -1 +1,2 @@ #!/usr/bin/env python3 +from __future__ import annotations diff --git a/linear_operator/test/base_test_case.py b/linear_operator/test/base_test_case.py index a045a870..ea86734e 100644 --- a/linear_operator/test/base_test_case.py +++ b/linear_operator/test/base_test_case.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import os import random @@ -45,7 +46,7 @@ def assertAllClose(self, tensor1, tensor2, rtol=1e-4, atol=1e-5, equal_nan=False raise AssertionError( f"tensor1 ({tensor1.shape}) and tensor2 ({tensor2.shape}) are not close enough. 
\n" - f"max rtol: {rtol_max:0.8f}\t\tmax atol: {atol_max:0.8f}" + f"max rtol: {rtol_max:0.8f}\t\tmax atol: {atol_max:0.8f}" # noqa: E231 ) def assertEqual(self, item1, item2): diff --git a/linear_operator/test/linear_operator_test_case.py b/linear_operator/test/linear_operator_test_case.py index 435588d2..58cc6950 100644 --- a/linear_operator/test/linear_operator_test_case.py +++ b/linear_operator/test/linear_operator_test_case.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import logging import math diff --git a/linear_operator/test/utils.py b/linear_operator/test/utils.py index 10d6b62d..5d04f2ef 100644 --- a/linear_operator/test/utils.py +++ b/linear_operator/test/utils.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import random from contextlib import contextmanager diff --git a/linear_operator/utils/__init__.py b/linear_operator/utils/__init__.py index fcdcfb44..6c99e26b 100644 --- a/linear_operator/utils/__init__.py +++ b/linear_operator/utils/__init__.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations from linear_operator.utils import ( broadcasting, diff --git a/linear_operator/utils/broadcasting.py b/linear_operator/utils/broadcasting.py index 2fbb645a..39bdd82b 100644 --- a/linear_operator/utils/broadcasting.py +++ b/linear_operator/utils/broadcasting.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import torch diff --git a/linear_operator/utils/cholesky.py b/linear_operator/utils/cholesky.py index 4d779371..dafc80ef 100644 --- a/linear_operator/utils/cholesky.py +++ b/linear_operator/utils/cholesky.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import warnings @@ -38,13 +39,15 @@ def _psd_safe_cholesky(A, out=None, jitter=None, max_tries=None): Aprime.diagonal(dim1=-1, dim2=-2).add_(diag_add) jitter_prev = jitter_new warnings.warn( - f"A not p.d., added jitter of {jitter_new:.1e} to the diagonal", + f"A not p.d., added jitter of {jitter_new:.1e} to the diagonal", # noqa: E231 NumericalWarning, ) L, info = torch.linalg.cholesky_ex(Aprime, out=out) if not torch.any(info): return L - raise NotPSDError(f"Matrix not positive definite after repeatedly adding jitter up to {jitter_new:.1e}.") + raise NotPSDError( + f"Matrix not positive definite after repeatedly adding jitter up to {jitter_new:.1e}." 
# noqa: E231 + ) def psd_safe_cholesky(A, upper=False, out=None, jitter=None, max_tries=None): diff --git a/linear_operator/utils/contour_integral_quad.py b/linear_operator/utils/contour_integral_quad.py index 4141a352..122c3bb2 100644 --- a/linear_operator/utils/contour_integral_quad.py +++ b/linear_operator/utils/contour_integral_quad.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import math import warnings diff --git a/linear_operator/utils/deprecation.py b/linear_operator/utils/deprecation.py index 136517a6..a6bcd1e6 100644 --- a/linear_operator/utils/deprecation.py +++ b/linear_operator/utils/deprecation.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import functools import warnings diff --git a/linear_operator/utils/errors.py b/linear_operator/utils/errors.py index f2b82191..a9930ba4 100644 --- a/linear_operator/utils/errors.py +++ b/linear_operator/utils/errors.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations class CachingError(RuntimeError): diff --git a/linear_operator/utils/generic.py b/linear_operator/utils/generic.py index 33b4bcd0..a7af9774 100644 --- a/linear_operator/utils/generic.py +++ b/linear_operator/utils/generic.py @@ -2,12 +2,10 @@ from __future__ import annotations -from typing import Optional, Set, Tuple - import torch -def _to_helper(*args, **kwargs) -> Tuple[Optional[torch.device], Optional[torch.dtype]]: +def _to_helper(*args, **kwargs) -> tuple[torch.device | None, torch.dtype | None]: """ Silently plucks out dtype and devices from a list of arguments. Can contain `torch.device`, `torch.dtype` and `torch.Tensor` objects as positional arguments @@ -20,8 +18,8 @@ def _to_helper(*args, **kwargs) -> Tuple[Optional[torch.device], Optional[torch. >>> dtype, device = _to_helper(torch.float, torch.device("cpu")) >>> dtype, device = _to_helper(torch.rand(2, dtype=torch.double) """ - dtype_args: Set[torch.dtype] = set() - device_args: Set[torch.device] = set() + dtype_args: set[torch.dtype] = set() + device_args: set[torch.device] = set() for arg in args: if type(arg) is torch.dtype: diff --git a/linear_operator/utils/getitem.py b/linear_operator/utils/getitem.py index 17bf94ee..85783b3f 100644 --- a/linear_operator/utils/getitem.py +++ b/linear_operator/utils/getitem.py @@ -2,22 +2,22 @@ from __future__ import annotations -from typing import Any, Iterable, Tuple, Union +import types +from typing import Any, Iterable import torch from linear_operator import settings from linear_operator.utils.broadcasting import _pad_with_singletons -# EllipsisType is only available in Python 3.10+ -IndexType = Union[type(Ellipsis), slice, Iterable[int], torch.LongTensor, int] +IndexType = types.EllipsisType | slice | Iterable[int] | torch.LongTensor | int # A slice that does nothing to a dimension _noop_index = slice(None, None, None) def _compute_getitem_size( - obj: Union[torch.Tensor, Any], + obj: torch.Tensor | Any, indices: IndexType, # Forward references not supported - obj: Union[torch.Tensor, "LinearOperator"] ) -> torch.Size: """ @@ -96,9 +96,9 @@ def _compute_getitem_size( def _convert_indices_to_tensors( - obj: Union[torch.Tensor, Any], + obj: torch.Tensor | Any, indices: IndexType, # Forward references not supported - obj: Union[torch.Tensor, "LinearOperator"] -) -> Tuple[torch.LongTensor, ...]: +) -> tuple[torch.LongTensor, ...]: """ Given an index made up of tensors/slices/ints, returns a tensor-only index that has the same outcome as the original index (when applied to the obj) @@ -167,12 +167,13 @@ 
def _equal_indices(a, b): """ Helper which checks whether two index components (int, slice, tensor) are equal """ - if torch.is_tensor(a) and torch.is_tensor(b): - return torch.equal(a, b) - elif not torch.is_tensor(a) and not torch.is_tensor(b): - return a == b - else: - return False + match (a, b): + case (a, b) if torch.is_tensor(a) and torch.is_tensor(b): + return torch.equal(a, b) + case (a, b) if not torch.is_tensor(a) and not torch.is_tensor(b): + return a == b + case _: + return False def _is_noop_index(index): diff --git a/linear_operator/utils/interpolation.py b/linear_operator/utils/interpolation.py index 92f8cce9..c32527db 100644 --- a/linear_operator/utils/interpolation.py +++ b/linear_operator/utils/interpolation.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 - +from __future__ import annotations import torch diff --git a/linear_operator/utils/lanczos.py b/linear_operator/utils/lanczos.py index dc34a1f1..e7adaf9e 100644 --- a/linear_operator/utils/lanczos.py +++ b/linear_operator/utils/lanczos.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import torch diff --git a/linear_operator/utils/linear_cg.py b/linear_operator/utils/linear_cg.py index d9b97a2b..25bc69ec 100644 --- a/linear_operator/utils/linear_cg.py +++ b/linear_operator/utils/linear_cg.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import warnings diff --git a/linear_operator/utils/memoize.py b/linear_operator/utils/memoize.py index 4adbe303..9d606e29 100644 --- a/linear_operator/utils/memoize.py +++ b/linear_operator/utils/memoize.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import functools import pickle diff --git a/linear_operator/utils/minres.py b/linear_operator/utils/minres.py index 18c755a8..732e4c33 100644 --- a/linear_operator/utils/minres.py +++ b/linear_operator/utils/minres.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import torch diff --git a/linear_operator/utils/permutation.py b/linear_operator/utils/permutation.py index 8905dc2a..675fc297 100644 --- a/linear_operator/utils/permutation.py +++ b/linear_operator/utils/permutation.py @@ -1,14 +1,15 @@ #!/usr/bin/env python3 +from __future__ import annotations -from typing import Any, Optional, Union +from typing import Any import torch def apply_permutation( - matrix: Union[Any, torch.Tensor], # Union["LinearOperator", torch.Tensor] - left_permutation: Optional[torch.Tensor] = None, - right_permutation: Optional[torch.Tensor] = None, + matrix: Any | torch.Tensor, # Union["LinearOperator", torch.Tensor] + left_permutation: torch.Tensor | None = None, + right_permutation: torch.Tensor | None = None, ): r""" Applies a left and/or right (partial) permutation to a given matrix :math:`\mathbf K`: diff --git a/linear_operator/utils/pinverse.py b/linear_operator/utils/pinverse.py index 4b34e9c2..0cb0055f 100644 --- a/linear_operator/utils/pinverse.py +++ b/linear_operator/utils/pinverse.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import torch from torch import Tensor diff --git a/linear_operator/utils/qr.py b/linear_operator/utils/qr.py index 254e1956..fe5d1c3c 100644 --- a/linear_operator/utils/qr.py +++ b/linear_operator/utils/qr.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import torch diff --git a/linear_operator/utils/sparse.py b/linear_operator/utils/sparse.py index b55ebb4b..e216e384 100644 --- a/linear_operator/utils/sparse.py +++ b/linear_operator/utils/sparse.py @@ 
-1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import torch @@ -156,54 +157,55 @@ def sparse_getitem(sparse, idxs): size = list(sparse.size()) for i, idx in list(enumerate(idxs))[::-1]: - if isinstance(idx, int): - del size[i] - mask = indices[i].eq(idx) - if torch.any(mask): - new_indices = torch.zeros( - indices.size(0) - 1, - torch.sum(mask), - dtype=indices.dtype, - device=indices.device, - ) - for j in range(indices.size(0)): - if i > j: + match idx: + case int(): + del size[i] + mask = indices[i].eq(idx) + if torch.any(mask): + new_indices = torch.zeros( + indices.size(0) - 1, + torch.sum(mask), + dtype=indices.dtype, + device=indices.device, + ) + for j in range(indices.size(0)): + if i > j: + new_indices[j].copy_(indices[j][mask]) + elif i < j: + new_indices[j - 1].copy_(indices[j][mask]) + indices = new_indices + values = values[mask] + else: + indices.resize_(indices.size(0) - 1, 1).zero_() + values.resize_(1).zero_() + + if not len(size): + return sum(values) + + case slice(): + start, stop, step = idx.indices(size[i]) + size = list(size[:i]) + [stop - start] + list(size[i + 1 :]) + if step != 1: + raise RuntimeError("Slicing with step is not supported") + mask = indices[i].lt(stop) & indices[i].ge(start) + if torch.any(mask): + new_indices = torch.zeros( + indices.size(0), + torch.sum(mask), + dtype=indices.dtype, + device=indices.device, + ) + for j in range(indices.size(0)): new_indices[j].copy_(indices[j][mask]) - elif i < j: - new_indices[j - 1].copy_(indices[j][mask]) - indices = new_indices - values = values[mask] - else: - indices.resize_(indices.size(0) - 1, 1).zero_() - values.resize_(1).zero_() - - if not len(size): - return sum(values) - - elif isinstance(idx, slice): - start, stop, step = idx.indices(size[i]) - size = list(size[:i]) + [stop - start] + list(size[i + 1 :]) - if step != 1: - raise RuntimeError("Slicing with step is not supported") - mask = indices[i].lt(stop) & indices[i].ge(start) - if torch.any(mask): - new_indices = torch.zeros( - indices.size(0), - torch.sum(mask), - dtype=indices.dtype, - device=indices.device, - ) - for j in range(indices.size(0)): - new_indices[j].copy_(indices[j][mask]) - new_indices[i].sub_(start) - indices = new_indices - values = values[mask] - else: - indices.resize_(indices.size(0), 1).zero_() - values.resize_(1).zero_() - - else: - raise RuntimeError("Unknown index type") + new_indices[i].sub_(start) + indices = new_indices + values = values[mask] + else: + indices.resize_(indices.size(0), 1).zero_() + values.resize_(1).zero_() + + case _: + raise RuntimeError("Unknown index type") return torch.sparse_coo_tensor(indices, values, torch.Size(size), dtype=values.dtype, device=values.device) diff --git a/linear_operator/utils/stochastic_lq.py b/linear_operator/utils/stochastic_lq.py index 762bec63..16478aae 100644 --- a/linear_operator/utils/stochastic_lq.py +++ b/linear_operator/utils/stochastic_lq.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import torch diff --git a/linear_operator/utils/toeplitz.py b/linear_operator/utils/toeplitz.py index 7064307c..a4cdf1d3 100644 --- a/linear_operator/utils/toeplitz.py +++ b/linear_operator/utils/toeplitz.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import torch from torch.fft import fft, ifft diff --git a/linear_operator/utils/warnings.py b/linear_operator/utils/warnings.py index 5c73d2fa..68a6708a 100644 --- a/linear_operator/utils/warnings.py +++ b/linear_operator/utils/warnings.py @@ -1,4 +1,5 
@@ #!/usr/bin/env python3 +from __future__ import annotations class NumericalWarning(RuntimeWarning):