Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions linear_operator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#!/usr/bin/env python3
from __future__ import annotations

from linear_operator import beta_features, operators, settings, utils
from linear_operator.functions import (
add_diagonal,
Expand Down
1 change: 1 addition & 0 deletions linear_operator/beta_features.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/usr/bin/env python3
from __future__ import annotations

import warnings

Expand Down
34 changes: 16 additions & 18 deletions linear_operator/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@

from __future__ import annotations

from typing import Any, Optional, Tuple, Union
from typing import Any, TypeAlias

import torch

from linear_operator.functions._dsmm import DSMM

LinearOperatorType = Any # Want this to be "LinearOperator" but runtime type checker can't handle
LinearOperatorType: TypeAlias = Any # Want this to be "LinearOperator" but runtime type checker can't handle


Anysor = Union[LinearOperatorType, torch.Tensor]
Anysor: TypeAlias = LinearOperatorType | torch.Tensor


def add_diagonal(input: Anysor, diag: torch.Tensor) -> LinearOperatorType:
Expand Down Expand Up @@ -47,9 +47,7 @@ def add_jitter(input: Anysor, jitter_val: float = 1e-3) -> Anysor:
return input + diag


def diagonalization(
input: Anysor, method: Optional[str] = None
) -> Tuple[torch.Tensor, Union[torch.Tensor, LinearOperatorType]]:
def diagonalization(input: Anysor, method: str | None = None) -> tuple[torch.Tensor, torch.Tensor | LinearOperatorType]:
r"""
Returns a (usually partial) diagonalization of a symmetric positive definite matrix (or batch of matrices).
:math:`\mathbf A`.
Expand All @@ -67,7 +65,7 @@ def diagonalization(


def dsmm(
sparse_mat: Union[torch.sparse.HalfTensor, torch.sparse.FloatTensor, torch.sparse.DoubleTensor],
sparse_mat: torch.sparse.HalfTensor | torch.sparse.FloatTensor | torch.sparse.DoubleTensor,
dense_mat: torch.Tensor,
) -> torch.Tensor:
r"""
Expand Down Expand Up @@ -111,8 +109,8 @@ def inv_quad(input: Anysor, inv_quad_rhs: torch.Tensor, reduce_inv_quad: bool =


def inv_quad_logdet(
input: Anysor, inv_quad_rhs: Optional[torch.Tensor] = None, logdet: bool = False, reduce_inv_quad: bool = True
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
input: Anysor, inv_quad_rhs: torch.Tensor | None = None, logdet: bool = False, reduce_inv_quad: bool = True
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
r"""
Calls both :func:`inv_quad_logdet` and :func:`logdet` on a positive definite matrix (or batch) :math:`\mathbf A`.
However, calling this method is far more efficient and stable than calling each method independently.
Expand All @@ -133,8 +131,8 @@ def inv_quad_logdet(


def pivoted_cholesky(
input: Anysor, rank: int, error_tol: Optional[float] = None, return_pivots: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
input: Anysor, rank: int, error_tol: float | None = None, return_pivots: bool = False
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
r"""
Performs a partial pivoted Cholesky factorization of a positive definite matrix (or batch of matrices).
:math:`\mathbf L \mathbf L^\top = \mathbf A`.
Expand All @@ -161,7 +159,7 @@ def pivoted_cholesky(
return to_linear_operator(input).pivoted_cholesky(rank=rank, error_tol=error_tol, return_pivots=return_pivots)


def root_decomposition(input: Anysor, method: Optional[str] = None) -> LinearOperatorType:
def root_decomposition(input: Anysor, method: str | None = None) -> LinearOperatorType:
r"""
Returns a (usually low-rank) root decomposition linear operator of the
positive definite matrix (or batch of matrices) :math:`\mathbf A`.
Expand All @@ -180,9 +178,9 @@ def root_decomposition(input: Anysor, method: Optional[str] = None) -> LinearOpe

def root_inv_decomposition(
input: Anysor,
initial_vectors: Optional[torch.Tensor] = None,
test_vectors: Optional[torch.Tensor] = None,
method: Optional[str] = None,
initial_vectors: torch.Tensor | None = None,
test_vectors: torch.Tensor | None = None,
method: str | None = None,
) -> LinearOperatorType:
r"""
Returns a (usually low-rank) inverse root decomposition linear operator
Expand All @@ -206,7 +204,7 @@ def root_inv_decomposition(
)


def solve(input: Anysor, rhs: torch.Tensor, lhs: Optional[torch.Tensor] = None) -> torch.Tensor:
def solve(input: Anysor, rhs: torch.Tensor, lhs: torch.Tensor | None = None) -> torch.Tensor:
r"""
Given a positive definite matrix (or batch of matrices) :math:`\mathbf A`,
computes a linear solve with right hand side :math:`\mathbf R`:
Expand Down Expand Up @@ -241,8 +239,8 @@ def solve(input: Anysor, rhs: torch.Tensor, lhs: Optional[torch.Tensor] = None)


def sqrt_inv_matmul(
input: Anysor, rhs: torch.Tensor, lhs: Optional[torch.Tensor] = None
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
input: Anysor, rhs: torch.Tensor, lhs: torch.Tensor | None = None
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
r"""
Given a positive definite matrix (or batch of matrices) :math:`\mathbf A`
and a right hand size :math:`\mathbf R`,
Expand Down
1 change: 1 addition & 0 deletions linear_operator/functions/_diagonalization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/usr/bin/env python3
from __future__ import annotations

import torch
from torch.autograd import Function
Expand Down
1 change: 1 addition & 0 deletions linear_operator/functions/_dsmm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/usr/bin/env python3
from __future__ import annotations

from torch.autograd import Function

Expand Down
1 change: 1 addition & 0 deletions linear_operator/functions/_inv_quad.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/usr/bin/env python3
from __future__ import annotations

import torch
from torch.autograd import Function
Expand Down
1 change: 1 addition & 0 deletions linear_operator/functions/_inv_quad_logdet.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/usr/bin/env python3
from __future__ import annotations

import warnings

Expand Down
1 change: 1 addition & 0 deletions linear_operator/functions/_matmul.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/usr/bin/env python3
from __future__ import annotations

from torch.autograd import Function

Expand Down
1 change: 1 addition & 0 deletions linear_operator/functions/_pivoted_cholesky.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/usr/bin/env python3
from __future__ import annotations

import torch
from torch.autograd import Function
Expand Down
1 change: 1 addition & 0 deletions linear_operator/functions/_root_decomposition.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/usr/bin/env python3
from __future__ import annotations

import torch
from torch.autograd import Function
Expand Down
1 change: 1 addition & 0 deletions linear_operator/functions/_solve.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/usr/bin/env python3
from __future__ import annotations

import torch
from torch.autograd import Function
Expand Down
1 change: 1 addition & 0 deletions linear_operator/functions/_sqrt_inv_matmul.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/usr/bin/env python3
from __future__ import annotations

import torch
from torch.autograd import Function
Expand Down
1 change: 1 addition & 0 deletions linear_operator/operators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/usr/bin/env python3
from __future__ import annotations

from linear_operator.operators._linear_operator import LinearOperator, to_dense
from linear_operator.operators.added_diag_linear_operator import AddedDiagLinearOperator
Expand Down
Loading