diff --git a/codes/IAE_JAX_v2_devl.py b/codes/IAE_JAX_v2_devl.py index bbae376..3673b79 100644 --- a/codes/IAE_JAX_v2_devl.py +++ b/codes/IAE_JAX_v2_devl.py @@ -734,8 +734,8 @@ def get_cost(params): if not self.simplex: B = params["Lambda"] @ self.PhiE else: - Lambda = params["Lambda"] / (np.sum(np.abs(params["Lambda"] ), axis=1)[:, np.newaxis] + 1e-3) - B = Lambda @ self.PhiE + params["Lambda"] = params["Lambda"] / (np.sum(params["Lambda"], axis=1)[:, np.newaxis] + 1e-3) + B = params["Lambda"] @ self.PhiE XRec = self.decoder(B)