Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 124 additions & 68 deletions src/squidpy/gr/_sepal.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
from __future__ import annotations

from collections.abc import Callable, Sequence
from collections.abc import Sequence
from typing import Literal

import fast_array_utils # noqa: F401
import numpy as np
import pandas as pd
from anndata import AnnData
from numba import njit
from numba import get_num_threads, get_thread_id, njit, prange
from scanpy import logging as logg
from scipy.sparse import csr_matrix, isspmatrix_csr, spmatrix
from scipy.sparse import csc_matrix, csr_matrix, issparse, isspmatrix_csr, spmatrix
from sklearn.metrics import pairwise_distances
from spatialdata import SpatialData

from squidpy._constants._pkg_constants import Key
from squidpy._docs import d, inject_docs
from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, parallelize
from squidpy._utils import NDArrayA
from squidpy.gr._utils import (
_assert_connectivity_key,
_assert_non_empty_sequence,
Expand All @@ -40,9 +41,6 @@ def sepal(
layer: str | None = None,
use_raw: bool = False,
copy: bool = False,
n_jobs: int | None = None,
backend: str = "loky",
show_progress_bar: bool = True,
) -> pd.DataFrame | None:
"""
Identify spatially variable genes with *Sepal*.
Expand Down Expand Up @@ -78,7 +76,6 @@ def sepal(
use_raw
Whether to access :attr:`anndata.AnnData.raw`.
%(copy)s
%(parallelize)s

Returns
-------
Expand Down Expand Up @@ -108,8 +105,6 @@ def sepal(
genes = genes[adata.var["highly_variable"].values]
genes = _assert_non_empty_sequence(genes, name="genes")

n_jobs = _get_n_cores(n_jobs)

g = adata.obsp[connectivity_key]
if not isspmatrix_csr(g):
g = csr_matrix(g)
Expand All @@ -124,27 +119,26 @@ def sepal(

# get counts
vals, genes = _extract_expression(adata, genes=genes, use_raw=use_raw, layer=layer)
start = logg.info(f"Calculating sepal score for `{len(genes)}` genes using `{n_jobs}` core(s)")

score = parallelize(
_score_helper,
collection=np.arange(len(genes)).tolist(),
extractor=np.hstack,
use_ixs=False,
n_jobs=n_jobs,
backend=backend,
show_progress_bar=show_progress_bar,
)(
vals=vals,
max_neighs=max_neighs,
n_iter=n_iter,
sat=sat,
sat_idx=sat_idx,
unsat=unsat,
unsat_idx=unsat_idx,
dt=dt,
thresh=thresh,
)
start = logg.info(f"Calculating sepal score for `{len(genes)}` genes")

use_hex = max_neighs == 6

if issparse(vals):
vals = csc_matrix(vals)
score = _diffusion_batch_csc(
vals,
use_hex,
n_iter,
sat,
sat_idx,
unsat,
unsat_idx,
dt,
thresh,
)
else:
vals_dense = np.ascontiguousarray(vals, dtype=np.float64)
score = _diffusion_batch_dense(vals_dense, use_hex, n_iter, sat, sat_idx, unsat, unsat_idx, dt, thresh)

key_added = "sepal_score"
sepal_score = pd.DataFrame(score, index=genes, columns=[key_added])
Expand All @@ -160,69 +154,131 @@ def sepal(
_save_data(adata, attr="uns", key=key_added, data=sepal_score, time=start)


def _score_helper(
ixs: Sequence[int],
vals: spmatrix | NDArrayA,
max_neighs: int,
@njit(parallel=True)
def _diffusion_batch_csc(
    vals: csc_matrix,
    use_hex: bool,
    n_iter: int,
    sat: NDArrayA,
    sat_idx: NDArrayA,
    unsat: NDArrayA,
    unsat_idx: NDArrayA,
    dt: float,
    thresh: float,
) -> NDArrayA:
    """Compute sepal scores for all genes of a CSC-encoded expression matrix.

    Each gene (one CSC column) is densified into a per-thread buffer and
    simulated with :func:`_diffusion`; its score is ``dt`` times the number
    of iterations until the entropy change falls below ``thresh``.

    NOTE(review): numba's nopython mode does not natively type
    :class:`scipy.sparse.csc_matrix`; presumably attribute access on ``vals``
    relies on a reflected/extension type — confirm against the numba version
    pinned by the project.
    """
    # Raw CSC components: column i occupies data[indptr[i]:indptr[i + 1]].
    indptr = vals.indptr
    indices = vals.indices
    data = vals.data
    n_cells, n_genes = vals.shape
    sat_shape = sat.shape[0]
    n_threads = get_num_threads()

    # Per-thread scratch buffers, allocated once up front so the parallel
    # loop performs no heap allocation (avoids allocator contention).
    conc_buf = np.empty((n_threads, n_cells))
    entropy_buf = np.empty((n_threads, n_iter))
    nhood_buf = np.empty((n_threads, sat_shape))
    dcdt_buf = np.empty((n_threads, n_cells))

    scores = np.empty(n_genes)
    for i in prange(n_genes):
        tid = get_thread_id()
        conc = conc_buf[tid]
        # Densify column i: implicit zeros first, then scatter stored values.
        conc[:] = 0.0
        for j in range(indptr[i], indptr[i + 1]):
            conc[indices[j]] = data[j]
        time_iter = _diffusion(
            conc,
            use_hex,
            n_iter,
            sat,
            sat_idx,
            unsat,
            unsat_idx,
            dt,
            thresh,
            entropy_buf[tid],
            nhood_buf[tid],
            dcdt_buf[tid],
        )
        scores[i] = dt * time_iter
    return scores


@njit(parallel=True)
def _diffusion_batch_dense(
    vals: NDArrayA,
    use_hex: bool,
    n_iter: int,
    sat: NDArrayA,
    sat_idx: NDArrayA,
    unsat: NDArrayA,
    unsat_idx: NDArrayA,
    dt: float,
    thresh: float,
) -> NDArrayA:
    """Compute sepal scores for all genes of a dense ``(n_cells, n_genes)`` matrix.

    Each gene column is copied into a per-thread buffer and simulated with
    :func:`_diffusion`; its score is ``dt`` times the number of iterations
    until the entropy change falls below ``thresh``.
    """
    n_genes = vals.shape[1]
    n_cells = vals.shape[0]
    sat_shape = sat.shape[0]
    n_threads = get_num_threads()

    # Pre-allocate per-thread workspace to avoid allocator contention.
    # NOTE(review): it is an open question whether numba can perform this
    # pre-allocation itself rather than sizing buffers manually via
    # get_num_threads() — no built-in mechanism was found.
    conc_buf = np.empty((n_threads, n_cells))
    entropy_buf = np.empty((n_threads, n_iter))
    nhood_buf = np.empty((n_threads, sat_shape))
    dcdt_buf = np.empty((n_threads, n_cells))

    scores = np.empty(n_genes)
    for i in prange(n_genes):
        tid = get_thread_id()
        conc = conc_buf[tid]
        conc[:] = vals[:, i]  # copy so the simulation never mutates `vals`
        time_iter = _diffusion(
            conc,
            use_hex,
            n_iter,
            sat,
            sat_idx,
            unsat,
            unsat_idx,
            dt,
            thresh,
            entropy_buf[tid],
            nhood_buf[tid],
            dcdt_buf[tid],
        )
        scores[i] = dt * time_iter
    return scores


@njit(fastmath=True)
def _diffusion(
conc: NDArrayA,
laplacian: Callable[[NDArrayA, NDArrayA], float],
use_hex: bool,
n_iter: int,
sat: NDArrayA,
sat_idx: NDArrayA,
unsat: NDArrayA,
unsat_idx: NDArrayA,
dt: float = 0.001,
thresh: float = 1e-8,
dt: float,
thresh: float,
entropy_arr: NDArrayA,
nhood: NDArrayA,
dcdt: NDArrayA,
) -> float:
"""Simulate diffusion process on a regular graph."""
sat_shape, conc_shape = sat.shape[0], conc.shape[0]
entropy_arr = np.zeros(n_iter)
sat_shape = sat.shape[0]
entropy_arr[:] = 0.0
nhood[:] = 0.0
prev_ent = 1.0
nhood = np.zeros(sat_shape)

for i in range(n_iter):
for j in range(sat_shape):
nhood[j] = np.sum(conc[sat_idx[j]])
d2 = laplacian(conc[sat], nhood)
if use_hex:
d2 = _laplacian_hex(conc[sat], nhood)
else:
d2 = _laplacian_rect(conc[sat], nhood)

dcdt = np.zeros(conc_shape)
dcdt[:] = 0.0
dcdt[sat] = d2
conc[sat] += dcdt[sat] * dt
conc[unsat] += dcdt[unsat_idx] * dt
Expand Down
20 changes: 16 additions & 4 deletions tests/graph/test_sepal.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import numba
import numpy as np
from anndata import AnnData
from pandas.testing import assert_frame_equal
Expand All @@ -16,8 +17,14 @@ def test_sepal_seq_par(adata: AnnData):
adata.var["highly_variable"] = rng.choice([True, False], size=adata.var_names.shape, p=[0.005, 0.995])

sepal(adata, max_neighs=6)
df = sepal(adata, max_neighs=6, copy=True, n_jobs=1)
df_parallel = sepal(adata, max_neighs=6, copy=True, n_jobs=2)

prev_threads = numba.get_num_threads()
try:
numba.set_num_threads(1)
df = sepal(adata, max_neighs=6, copy=True)
finally:
numba.set_num_threads(prev_threads)
df_parallel = sepal(adata, max_neighs=6, copy=True)

idx_df = df.index.values
idx_adata = adata[:, adata.var.highly_variable.values].var_names.values
Expand All @@ -40,8 +47,13 @@ def test_sepal_square_seq_par(adata_squaregrid: AnnData):
rng = np.random.default_rng(42)
adata.var["highly_variable"] = rng.choice([True, False], size=adata.var_names.shape)

sepal(adata, max_neighs=4)
df_parallel = sepal(adata, copy=True, n_jobs=2, max_neighs=4)
prev_threads = numba.get_num_threads()
try:
numba.set_num_threads(1)
sepal(adata, max_neighs=4)
finally:
numba.set_num_threads(prev_threads)
df_parallel = sepal(adata, copy=True, max_neighs=4)

idx_df = df_parallel.index.values
idx_adata = adata[:, adata.var.highly_variable.values].var_names.values
Expand Down
Loading