diff --git a/linear_operator/operators/low_rank_root_added_diag_linear_operator.py b/linear_operator/operators/low_rank_root_added_diag_linear_operator.py index 6e7f1aba..0622209e 100644 --- a/linear_operator/operators/low_rank_root_added_diag_linear_operator.py +++ b/linear_operator/operators/low_rank_root_added_diag_linear_operator.py @@ -40,7 +40,11 @@ def chol_cap_mat(self): V = self._linear_op.root.mT C = ConstantDiagLinearOperator(torch.ones(*V.batch_shape, 1, device=V.device, dtype=V.dtype), V.shape[-2]) - cap_mat = to_dense(C + V.matmul(A_inv.matmul(U))) + if isinstance(self._diag_tensor, ConstantDiagLinearOperator): + sigma_inv = A_inv.diag_values[0] + cap_mat = to_dense(C + sigma_inv * V.matmul(U)) + else: + cap_mat = to_dense(C + V.matmul(A_inv.matmul(U))) chol_cap_mat = psd_safe_cholesky(cap_mat) return chol_cap_mat