diff --git a/parafac2/normalize.py b/parafac2/normalize.py index 8908965..8ac5041 100644 --- a/parafac2/normalize.py +++ b/parafac2/normalize.py @@ -79,6 +79,9 @@ def get_deviance(data: csr_array) -> np.ndarray: where y=0 or n-y=0, and ensures the final deviance values are non-negative before taking the square root. """ + # merge duplicate entries in the sparse matrix by summing their values + data.sum_duplicates() + data.eliminate_zeros() # counts per gene pi_j = data.sum(axis=0) @@ -104,7 +107,7 @@ def get_deviance(data: csr_array) -> np.ndarray: # Term 1: y * log(y / mu) = xlogy(y, y) - xlogy(y, mu) # xlogy handles y=0 case correctly returning 0. row, col = data.nonzero() - mu_ij_nn = n_i_col[row, 0] * pi_j[0, col] + mu_ij_nn = n_i[row] * pi_j[col] term1 = data.data * np.log(data.data / mu_ij_nn) # Term 2: (n-y) * log((n-y) / (n-mu)) = xlogy(n-y, n-y) - xlogy(n-y, n-mu)