From dccc6b961ccf671848d58426f50b0050dfcdf3c1 Mon Sep 17 00:00:00 2001
From: Max Balandat
Date: Mon, 2 Feb 2026 23:08:01 -0800
Subject: [PATCH] Speed up MatmulLinearOperator.to_dense() for DiagLinearOperators

This can help substantially if some of the operators involved are
`DiagLinearOperator`s. This fixes the issue reported in
https://github.com/meta-pytorch/botorch/issues/3159
---
 .../operators/matmul_linear_operator.py       |  8 +++-
 test/operators/test_matmul_linear_operator.py | 57 ++++++++++++++++++-
 2 files changed, 64 insertions(+), 1 deletion(-)

diff --git a/linear_operator/operators/matmul_linear_operator.py b/linear_operator/operators/matmul_linear_operator.py
index 40c389ca..5f8ffe51 100644
--- a/linear_operator/operators/matmul_linear_operator.py
+++ b/linear_operator/operators/matmul_linear_operator.py
@@ -135,4 +135,12 @@ def _transpose_nonbatch(
     def to_dense(
         self: LinearOperator,  # shape: (*batch, M, N)
     ) -> Tensor:  # shape: (*batch, M, N)
+        # Local import avoids a circular dependency at module load time.
+        from .diag_linear_operator import DiagLinearOperator
+
+        # Use element-wise multiplication for DiagLinearOperators
+        if isinstance(self.left_linear_op, DiagLinearOperator):
+            return self.left_linear_op._diag.unsqueeze(-1) * self.right_linear_op.to_dense()
+        if isinstance(self.right_linear_op, DiagLinearOperator):
+            return self.left_linear_op.to_dense() * self.right_linear_op._diag.unsqueeze(-2)
         return torch.matmul(self.left_linear_op.to_dense(), self.right_linear_op.to_dense())
diff --git a/test/operators/test_matmul_linear_operator.py b/test/operators/test_matmul_linear_operator.py
index f9613354..00bd74b3 100644
--- a/test/operators/test_matmul_linear_operator.py
+++ b/test/operators/test_matmul_linear_operator.py
@@ -4,7 +4,7 @@
 
 import torch
 
-from linear_operator.operators import MatmulLinearOperator
+from linear_operator.operators import DenseLinearOperator, DiagLinearOperator, MatmulLinearOperator
 from linear_operator.test.linear_operator_test_case import LinearOperatorTestCase, RectangularLinearOperatorTestCase
 
 
@@ -56,5 +56,60 @@ def evaluate_linear_op(self, linear_op):
         return linear_op.left_linear_op.tensor.matmul(linear_op.right_linear_op.tensor)
 
 
+class TestMatmulLinearOperatorDiagOptimization(unittest.TestCase):
+    """Tests for efficient diagonal matrix multiplication in to_dense()."""
+
+    def test_diag_left_matmul_to_dense(self):
+        """Test D @ A uses element-wise multiplication."""
+        diag = torch.tensor([1.0, 2.0, 3.0, 4.0])
+        A = torch.randn(4, 5)
+
+        D = DiagLinearOperator(diag)
+        result = MatmulLinearOperator(D, DenseLinearOperator(A))
+
+        expected = torch.diag(diag) @ A
+        self.assertTrue(torch.allclose(result.to_dense(), expected))
+
+    def test_diag_right_matmul_to_dense(self):
+        """Test A @ D uses element-wise multiplication."""
+        diag = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
+        A = torch.randn(4, 5)
+
+        D = DiagLinearOperator(diag)
+        result = MatmulLinearOperator(DenseLinearOperator(A), D)
+
+        expected = A @ torch.diag(diag)
+        self.assertTrue(torch.allclose(result.to_dense(), expected))
+
+    def test_diag_sandwich_to_dense(self):
+        """Test D1 @ A @ D2 uses element-wise multiplication (the main bug fix)."""
+        diag1 = torch.tensor([1.0, 2.0, 3.0, 4.0])
+        diag2 = torch.tensor([0.5, 1.5, 2.5, 3.5])
+        A = torch.randn(4, 4)
+
+        D1 = DiagLinearOperator(diag1)
+        D2 = DiagLinearOperator(diag2)
+
+        result = D1 @ DenseLinearOperator(A) @ D2
+        expected = torch.diag(diag1) @ A @ torch.diag(diag2)
+        self.assertTrue(torch.allclose(result.to_dense(), expected))
+
+    def test_diag_sandwich_batch(self):
+        """Test D1 @ A @ D2 with batch dimensions."""
+        batch_size = 3
+        n = 4
+
+        diag1 = torch.randn(batch_size, n).abs()
+        diag2 = torch.randn(batch_size, n).abs()
+        A = torch.randn(batch_size, n, n)
+
+        D1 = DiagLinearOperator(diag1)
+        D2 = DiagLinearOperator(diag2)
+
+        result = D1 @ DenseLinearOperator(A) @ D2
+        expected = torch.diag_embed(diag1) @ A @ torch.diag_embed(diag2)
+        self.assertTrue(torch.allclose(result.to_dense(), expected))
+
+
 if __name__ == "__main__":
     unittest.main()