From bd1542afa781b8771b6bed1d65bb6918e516a6aa Mon Sep 17 00:00:00 2001 From: Kaiwen Wu Date: Sun, 14 Dec 2025 16:27:01 -0500 Subject: [PATCH 1/2] add tests for `BlockDiagLinearOperator` --- .../operators/test_block_diag_linear_operator.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/test/operators/test_block_diag_linear_operator.py b/test/operators/test_block_diag_linear_operator.py index a5871239..fb6b8f1a 100644 --- a/test/operators/test_block_diag_linear_operator.py +++ b/test/operators/test_block_diag_linear_operator.py @@ -12,11 +12,19 @@ class TestBlockDiagLinearOperator(LinearOperatorTestCase, unittest.TestCase): seed = 0 should_test_sample = True + # Whether to initialize `BlockDiagLinearOperator` from a tensor or a linear operator. + _initialize_from_tensor = False + def create_linear_op(self): blocks = torch.randn(8, 4, 4) blocks = blocks.matmul(blocks.mT) blocks.add_(torch.eye(4, 4).unsqueeze_(0)) - return BlockDiagLinearOperator(DenseLinearOperator(blocks)) + + return ( + BlockDiagLinearOperator(blocks) + if self._initialize_from_tensor + else BlockDiagLinearOperator(DenseLinearOperator(blocks)) + ) def evaluate_linear_op(self, linear_op): blocks = linear_op.base_linear_op.tensor @@ -26,6 +34,10 @@ def evaluate_linear_op(self, linear_op): return actual +class TestBlockDiagLinearOperatorFromTensor(TestBlockDiagLinearOperator): + _initialize_from_tensor = True + + class TestBlockDiagLinearOperatorBatch(LinearOperatorTestCase, unittest.TestCase): seed = 0 should_test_sample = True @@ -75,7 +87,7 @@ def test_metaclass_constructor(self): base_operators = [torch.randn(k, n), torch.randn(b1, b2, k, n)] subtest_names = ["non-batched input", "batched input"] # repeats tests for both batched and non-batched tensors - for (base_op, test_name) in zip(base_operators, subtest_names): + for base_op, test_name in zip(base_operators, subtest_names): with self.subTest(test_name): base_diag = DiagLinearOperator(base_op) linear_op = BlockDiagLinearOperator(base_diag) From 86092eb50699fd29e85dc98338c857b33e64e39c Mon Sep 17 00:00:00 2001 From: Kaiwen Wu Date: Sun, 14 Dec 2025 16:09:13 -0500 Subject: [PATCH 2/2] fix `BlockDiagLinearOperator.diagonal` --- linear_operator/operators/block_diag_linear_operator.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/linear_operator/operators/block_diag_linear_operator.py b/linear_operator/operators/block_diag_linear_operator.py index 971f5de2..8abf3bd6 100644 --- a/linear_operator/operators/block_diag_linear_operator.py +++ b/linear_operator/operators/block_diag_linear_operator.py @@ -9,6 +9,7 @@ from linear_operator.operators._linear_operator import IndexType, LinearOperator from linear_operator.operators.block_linear_operator import BlockLinearOperator +from linear_operator.operators.dense_linear_operator import DenseLinearOperator from linear_operator.utils.memoize import cached @@ -49,6 +50,9 @@ class BlockDiagLinearOperator(BlockLinearOperator, metaclass=_MetaBlockDiagLinea """ def __init__(self, base_linear_op, block_dim=-3): + if isinstance(base_linear_op, Tensor): + base_linear_op = DenseLinearOperator(base_linear_op) + super().__init__(base_linear_op, block_dim) # block diagonal is restricted to have square diagonal blocks if self.base_linear_op.shape[-1] != self.base_linear_op.shape[-2]: