diff --git a/src/squidpy/gr/_sepal.py b/src/squidpy/gr/_sepal.py index 7e60085d1..1197a6a6a 100644 --- a/src/squidpy/gr/_sepal.py +++ b/src/squidpy/gr/_sepal.py @@ -1,20 +1,21 @@ from __future__ import annotations -from collections.abc import Callable, Sequence +from collections.abc import Sequence from typing import Literal +import fast_array_utils # noqa: F401 import numpy as np import pandas as pd from anndata import AnnData -from numba import njit +from numba import get_num_threads, get_thread_id, njit, prange from scanpy import logging as logg -from scipy.sparse import csr_matrix, isspmatrix_csr, spmatrix +from scipy.sparse import csc_matrix, csr_matrix, issparse, isspmatrix_csr, spmatrix from sklearn.metrics import pairwise_distances from spatialdata import SpatialData from squidpy._constants._pkg_constants import Key from squidpy._docs import d, inject_docs -from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, parallelize +from squidpy._utils import NDArrayA from squidpy.gr._utils import ( _assert_connectivity_key, _assert_non_empty_sequence, @@ -40,9 +41,6 @@ def sepal( layer: str | None = None, use_raw: bool = False, copy: bool = False, - n_jobs: int | None = None, - backend: str = "loky", - show_progress_bar: bool = True, ) -> pd.DataFrame | None: """ Identify spatially variable genes with *Sepal*. @@ -78,7 +76,6 @@ def sepal( use_raw Whether to access :attr:`anndata.AnnData.raw`. %(copy)s - %(parallelize)s Returns ------- @@ -108,8 +105,6 @@ def sepal( genes = genes[adata.var["highly_variable"].values] genes = _assert_non_empty_sequence(genes, name="genes") - n_jobs = _get_n_cores(n_jobs) - g = adata.obsp[connectivity_key] if not isspmatrix_csr(g): g = csr_matrix(g) @@ -124,27 +119,26 @@ def sepal( # get counts vals, genes = _extract_expression(adata, genes=genes, use_raw=use_raw, layer=layer) - start = logg.info(f"Calculating sepal score for `{len(genes)}` genes using `{n_jobs}` core(s)") - - score = parallelize( - _score_helper, - collection=np.arange(len(genes)).tolist(), - extractor=np.hstack, - use_ixs=False, - n_jobs=n_jobs, - backend=backend, - show_progress_bar=show_progress_bar, - )( - vals=vals, - max_neighs=max_neighs, - n_iter=n_iter, - sat=sat, - sat_idx=sat_idx, - unsat=unsat, - unsat_idx=unsat_idx, - dt=dt, - thresh=thresh, - ) + start = logg.info(f"Calculating sepal score for `{len(genes)}` genes") + + use_hex = max_neighs == 6 + + if issparse(vals): + vals = csc_matrix(vals) + score = _diffusion_batch_csc( + vals, + use_hex, + n_iter, + sat, + sat_idx, + unsat, + unsat_idx, + dt, + thresh, + ) + else: + vals_dense = np.ascontiguousarray(vals, dtype=np.float64) + score = _diffusion_batch_dense(vals_dense, use_hex, n_iter, sat, sat_idx, unsat, unsat_idx, dt, thresh) key_added = "sepal_score" sepal_score = pd.DataFrame(score, index=genes, columns=[key_added]) @@ -160,10 +154,10 @@ def sepal( _save_data(adata, attr="uns", key=key_added, data=sepal_score, time=start) -def _score_helper( - ixs: Sequence[int], - vals: spmatrix | NDArrayA, - max_neighs: int, +@njit(parallel=True) +def _diffusion_batch_csc( + vals: csc_matrix, + use_hex: bool, n_iter: int, sat: NDArrayA, sat_idx: NDArrayA, @@ -171,58 +165,120 @@ def _score_helper( unsat_idx: NDArrayA, dt: float, thresh: float, - queue: SigQueue | None = None, ) -> NDArrayA: - if max_neighs == 4: - fun = _laplacian_rect - elif max_neighs == 6: - fun = _laplacian_hex - else: - raise NotImplementedError(f"Laplacian for `{max_neighs}` neighbors is not yet implemented.") - - score = [] - for i in ixs: - if isinstance(vals, spmatrix): - conc = vals[:, i].toarray().flatten() # Safe to call toarray() - else: - conc = vals[:, i].copy() # vals is assumed to be a NumPy array here - - time_iter = _diffusion(conc, fun, n_iter, sat, sat_idx, unsat, unsat_idx, dt=dt, thresh=thresh) - score.append(dt * time_iter) - - if queue is not None: - queue.put(Signal.UPDATE) - - if queue is not None: - queue.put(Signal.FINISH) - - return np.array(score) + indptr = vals.indptr + indices = vals.indices + data = vals.data + n_cells, n_genes = vals.shape + sat_shape = sat.shape[0] + n_threads = get_num_threads() + + conc_buf = np.empty((n_threads, n_cells)) + entropy_buf = np.empty((n_threads, n_iter)) + nhood_buf = np.empty((n_threads, sat_shape)) + dcdt_buf = np.empty((n_threads, n_cells)) + + scores = np.empty(n_genes) + for i in prange(n_genes): + tid = get_thread_id() + conc = conc_buf[tid] + conc[:] = 0.0 + for j in range(indptr[i], indptr[i + 1]): + conc[indices[j]] = data[j] + time_iter = _diffusion( + conc, + use_hex, + n_iter, + sat, + sat_idx, + unsat, + unsat_idx, + dt, + thresh, + entropy_buf[tid], + nhood_buf[tid], + dcdt_buf[tid], + ) + scores[i] = dt * time_iter + return scores + + +@njit(parallel=True) +def _diffusion_batch_dense( + vals: NDArrayA, + use_hex: bool, + n_iter: int, + sat: NDArrayA, + sat_idx: NDArrayA, + unsat: NDArrayA, + unsat_idx: NDArrayA, + dt: float, + thresh: float, +) -> NDArrayA: + n_genes = vals.shape[1] + n_cells = vals.shape[0] + sat_shape = sat.shape[0] + n_threads = get_num_threads() + + # Pre-allocate per-thread workspace to avoid allocator contention + conc_buf = np.empty((n_threads, n_cells)) + entropy_buf = np.empty((n_threads, n_iter)) + nhood_buf = np.empty((n_threads, sat_shape)) + dcdt_buf = np.empty((n_threads, n_cells)) + + scores = np.empty(n_genes) + for i in prange(n_genes): + tid = get_thread_id() + conc = conc_buf[tid] + conc[:] = vals[:, i] + time_iter = _diffusion( + conc, + use_hex, + n_iter, + sat, + sat_idx, + unsat, + unsat_idx, + dt, + thresh, + entropy_buf[tid], + nhood_buf[tid], + dcdt_buf[tid], + ) + scores[i] = dt * time_iter + return scores @njit(fastmath=True) def _diffusion( conc: NDArrayA, - laplacian: Callable[[NDArrayA, NDArrayA], float], + use_hex: bool, n_iter: int, sat: NDArrayA, sat_idx: NDArrayA, unsat: NDArrayA, unsat_idx: NDArrayA, - dt: float = 0.001, - thresh: float = 1e-8, + dt: float, + thresh: float, + entropy_arr: NDArrayA, + nhood: NDArrayA, + dcdt: NDArrayA, ) -> float: """Simulate diffusion process on a regular graph.""" - sat_shape, conc_shape = sat.shape[0], conc.shape[0] - entropy_arr = np.zeros(n_iter) + sat_shape = sat.shape[0] + entropy_arr[:] = 0.0 + nhood[:] = 0.0 prev_ent = 1.0 - nhood = np.zeros(sat_shape) for i in range(n_iter): for j in range(sat_shape): nhood[j] = np.sum(conc[sat_idx[j]]) - d2 = laplacian(conc[sat], nhood) + if use_hex: + d2 = _laplacian_hex(conc[sat], nhood) + else: + d2 = _laplacian_rect(conc[sat], nhood) - dcdt = np.zeros(conc_shape) + dcdt[:] = 0.0 dcdt[sat] = d2 conc[sat] += dcdt[sat] * dt conc[unsat] += dcdt[unsat_idx] * dt diff --git a/tests/graph/test_sepal.py b/tests/graph/test_sepal.py index 8fb711f5e..a54f8d7ac 100644 --- a/tests/graph/test_sepal.py +++ b/tests/graph/test_sepal.py @@ -1,5 +1,6 @@ from __future__ import annotations +import numba import numpy as np from anndata import AnnData from pandas.testing import assert_frame_equal @@ -16,8 +17,14 @@ def test_sepal_seq_par(adata: AnnData): adata.var["highly_variable"] = rng.choice([True, False], size=adata.var_names.shape, p=[0.005, 0.995]) sepal(adata, max_neighs=6) - df = sepal(adata, max_neighs=6, copy=True, n_jobs=1) - df_parallel = sepal(adata, max_neighs=6, copy=True, n_jobs=2) + + prev_threads = numba.get_num_threads() + try: + numba.set_num_threads(1) + df = sepal(adata, max_neighs=6, copy=True) + finally: + numba.set_num_threads(prev_threads) + df_parallel = sepal(adata, max_neighs=6, copy=True) idx_df = df.index.values idx_adata = adata[:, adata.var.highly_variable.values].var_names.values @@ -40,8 +47,13 @@ def test_sepal_square_seq_par(adata_squaregrid: AnnData): rng = np.random.default_rng(42) adata.var["highly_variable"] = rng.choice([True, False], size=adata.var_names.shape) - sepal(adata, max_neighs=4) - df_parallel = sepal(adata, copy=True, n_jobs=2, max_neighs=4) + prev_threads = numba.get_num_threads() + try: + numba.set_num_threads(1) + sepal(adata, max_neighs=4) + finally: + numba.set_num_threads(prev_threads) + df_parallel = sepal(adata, copy=True, max_neighs=4) idx_df = df_parallel.index.values idx_adata = adata[:, adata.var.highly_variable.values].var_names.values