diff --git a/scripts/bench_ligrec.py b/scripts/bench_ligrec.py new file mode 100644 index 000000000..0b4e22a98 --- /dev/null +++ b/scripts/bench_ligrec.py @@ -0,0 +1,203 @@ +""" +Benchmark script for ligrec() -- compare main vs refactored branch. + +Usage: + python scripts/bench_ligrec.py # default config + python scripts/bench_ligrec.py --n-perms 500 # fewer perms (faster) + python scripts/bench_ligrec.py --n-cells 50000 # more cells (slower) + python scripts/bench_ligrec.py --n-runs 5 # average over 5 runs + python scripts/bench_ligrec.py --n-jobs 4 # 4 workers (main only) + python scripts/bench_ligrec.py --no-cache # rebuild data from scratch + +Defaults are calibrated to ~30s per run on Apple M-series (1 core): + 30 000 cells, 2 000 genes, 25 clusters, 6 400 interactions, 1 000 perms. + +The prepared AnnData + interactions are cached under .pytest_cache/ +so repeated runs skip the (slow) data-generation step. +""" + +from __future__ import annotations + +import argparse +import hashlib +import pickle +import time +import warnings +from itertools import product +from pathlib import Path + +warnings.filterwarnings("ignore", category=FutureWarning) +warnings.filterwarnings("ignore", message=".*ImplicitModificationWarning.*") +warnings.filterwarnings("ignore", message=".*Transforming to str index.*") + +import numpy as np +import pandas as pd +from anndata import AnnData + +from squidpy.gr import ligrec + +CACHE_DIR = Path(".pytest_cache") / "bench_ligrec" + + +def _cache_key( + n_cells: int, + n_genes: int, + n_clusters: int, + n_interaction_genes: int, +) -> str: + tag = f"{n_cells}_{n_genes}_{n_clusters}_{n_interaction_genes}" + return hashlib.sha256(tag.encode()).hexdigest()[:16] + + +def _build_adata( + n_cells: int, + n_genes: int, + n_clusters: int, + n_interaction_genes: int, + use_cache: bool, +) -> tuple[AnnData, list[tuple[str, str]]]: + key = _cache_key(n_cells, n_genes, n_clusters, n_interaction_genes) + cache_path = CACHE_DIR / f"{key}.pkl" + + if 
use_cache and cache_path.exists(): + print(f"Loading cached data from {cache_path}", flush=True) + with open(cache_path, "rb") as f: + adata, interactions = pickle.load(f) + print( + f" cells={adata.n_obs}, genes={adata.n_vars}, " + f"clusters={len(adata.obs['cluster'].cat.categories)}, " + f"interactions={len(interactions)}", + flush=True, + ) + return adata, interactions + + print("Building synthetic AnnData...", flush=True) + rng = np.random.default_rng(42) + X = rng.random((n_cells, n_genes)) + cluster_labels = rng.choice([f"c{i}" for i in range(n_clusters)], size=n_cells) + obs = pd.DataFrame({"cluster": pd.Categorical(cluster_labels)}) + var = pd.DataFrame(index=[f"G{i}" for i in range(n_genes)]) + adata = AnnData(X, obs=obs, var=var) + adata.raw = adata.copy() + + igenes = list(adata.var_names[:n_interaction_genes]) + interactions = list(product(igenes, igenes)) + + print( + f" cells={n_cells}, genes={n_genes}, clusters={n_clusters}, " + f"interaction_genes={n_interaction_genes}, " + f"interactions={len(interactions)}", + flush=True, + ) + + CACHE_DIR.mkdir(parents=True, exist_ok=True) + with open(cache_path, "wb") as f: + pickle.dump((adata, interactions), f, protocol=pickle.HIGHEST_PROTOCOL) + print(f" cached to {cache_path}", flush=True) + + return adata, interactions + + +def _run_once( + adata: AnnData, + interactions: list[tuple[str, str]], + n_perms: int, +) -> float: + t0 = time.perf_counter() + ligrec( + adata, + cluster_key="cluster", + interactions=interactions, + n_perms=n_perms, + copy=True, + seed=0, + use_raw=True, + ) + return time.perf_counter() - t0 + + +def main(): + parser = argparse.ArgumentParser(description="Benchmark ligrec()") + parser.add_argument( + "--n-cells", + type=int, + default=30000, + help="Number of cells (default 30000)", + ) + parser.add_argument( + "--n-genes", + type=int, + default=2000, + help="Total genes in AnnData (default 2000)", + ) + parser.add_argument( + "--n-clusters", + type=int, + default=25, + help="Number 
of clusters (default 25)", + ) + parser.add_argument( + "--n-interaction-genes", + type=int, + default=80, + help="Genes used in interactions; n^2 pairs (default 80 -> 6400)", + ) + parser.add_argument( + "--n-perms", + type=int, + default=1000, + help="Number of permutations (default 1000)", + ) + parser.add_argument( + "--n-runs", + type=int, + default=3, + help="Number of timed runs (default 3)", + ) + parser.add_argument( + "--no-cache", + action="store_true", + help="Rebuild data even if cache exists", + ) + args = parser.parse_args() + + adata, interactions = _build_adata( + n_cells=args.n_cells, + n_genes=args.n_genes, + n_clusters=args.n_clusters, + n_interaction_genes=args.n_interaction_genes, + use_cache=not args.no_cache, + ) + + print("\nWarmup (JIT compile)...", flush=True) + small = adata[:50, :].copy() + small.raw = small.copy() + ligrec(small, cluster_key="cluster", interactions=interactions[:4], n_perms=5, copy=True, seed=0, use_raw=True) + print(" done.\n", flush=True) + + n_inter = len(interactions) + n_cls_pairs = len(adata.obs["cluster"].cat.categories) ** 2 + print( + f"Config: {args.n_cells} cells, {args.n_genes} genes, " + f"{args.n_clusters} clusters, {n_inter} interactions, " + f"{n_cls_pairs} cluster pairs, {args.n_perms} perms", + flush=True, + ) + print(f"Running ligrec() {args.n_runs} time(s)...\n", flush=True) + + times = [] + for i in range(args.n_runs): + t = _run_once(adata, interactions, args.n_perms) + times.append(t) + print(f" run {i + 1}: {t:.3f}s", flush=True) + + times_arr = np.array(times) + print(f"\nResults ({args.n_runs} runs):") + print(f" mean: {times_arr.mean():.3f}s") + print(f" median: {np.median(times_arr):.3f}s") + print(f" min: {times_arr.min():.3f}s") + print(f" max: {times_arr.max():.3f}s") + + +if __name__ == "__main__": + main() diff --git a/src/squidpy/gr/_ligrec.py b/src/squidpy/gr/_ligrec.py index a4beecd8f..ed7dd5d83 100644 --- a/src/squidpy/gr/_ligrec.py +++ b/src/squidpy/gr/_ligrec.py @@ -5,14 +5,15 @@ 
from abc import ABC from collections import namedtuple from collections.abc import Iterable, Mapping, Sequence -from functools import partial from itertools import product from types import MappingProxyType from typing import TYPE_CHECKING, Any, Literal, TypeAlias +import numba import numpy as np import pandas as pd from anndata import AnnData +from numba import njit, prange from scanpy import logging as logg from scipy.sparse import csc_matrix from spatialdata import SpatialData @@ -20,7 +21,7 @@ from squidpy._constants._constants import ComplexPolicy, CorrAxis from squidpy._constants._pkg_constants import Key from squidpy._docs import d, inject_docs -from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, parallelize +from squidpy._utils import NDArrayA from squidpy.gr._utils import ( _assert_categorical_obs, _assert_positive, @@ -42,102 +43,6 @@ TempResult = namedtuple("TempResult", ["means", "pvalues"]) -_template = """ -from __future__ import annotations - -from numba import njit, prange -import numpy as np - -@njit(parallel={parallel}, cache=False, fastmath=False) -def _test_{n_cls}_{ret_means}_{parallel}( - interactions: NDArrayA[np.uint32], - interaction_clusters: NDArrayA[np.uint32], - data: NDArrayA[np.float64], - clustering: NDArrayA[np.uint32], - mean: NDArrayA[np.float64], - mask: NDArrayA[np.bool_], - res: NDArrayA[np.float64], - {args} -) -> None: - - {init} - {loop} - {finalize} - - for i in prange(len(interactions)): - rec, lig = interactions[i] - for j in prange(len(interaction_clusters)): - c1, c2 = interaction_clusters[j] - m1, m2 = mean[rec, c1], mean[lig, c2] - - if np.isnan(res[i, j]): - continue - - if m1 > 0 and m2 > 0: - {set_means} - if mask[rec, c1] and mask[lig, c2]: - # both rec, lig are sufficiently expressed in c1, c2 - res[i, j] += (groups[c1, rec] + groups[c2, lig]) > (m1 + m2) - else: - res[i, j] = np.nan - else: - # res_means is initialized with 0s - res[i, j] = np.nan -""" - - -def _create_template(n_cls: int, 
return_means: bool = False, parallel: bool = True) -> str: - if n_cls <= 0: - raise ValueError(f"Expected number of clusters to be positive, found `{n_cls}`.") - - rng = range(n_cls) - init = "".join( - f""" - g{i} = np.zeros((data.shape[1],), dtype=np.float64); s{i} = 0""" - for i in rng - ) - init += """ - error = False - """ - - loop_body = """ - if cl == 0: - g0 += data[row] - s0 += 1""" - loop_body = loop_body + "".join( - f""" - elif cl == {i}: - g{i} += data[row] - s{i} += 1""" - for i in range(1, n_cls) - ) - loop = f""" - for row in prange(data.shape[0]): - cl = clustering[row] - {loop_body} - else: - error = True - """ - finalize = ", ".join(f"g{i} / s{i}" for i in rng) - finalize = f"groups = np.stack(({finalize}))" - - if return_means: - args = "res_means: NDArrayA, # [np.float64]" - set_means = "res_means[i, j] = (m1 + m2) / 2.0" - else: - args = set_means = "" - - return _template.format( - n_cls=n_cls, - parallel=bool(parallel), - ret_means=int(return_means), - args=args, - init=init, - loop=loop, - finalize=finalize, - set_means=set_means, - ) - def _fdr_correct( pvals: pd.DataFrame, @@ -326,8 +231,6 @@ def test( alpha: float = 0.05, copy: bool = False, key_added: str | None = None, - numba_parallel: bool | None = None, - **kwargs: Any, ) -> Mapping[str, pd.DataFrame] | None: """ Perform the permutation test as described in :cite:`cellphonedb`. @@ -355,9 +258,6 @@ def test( key_added Key in :attr:`anndata.AnnData.uns` where the result is stored if ``copy = False``. If `None`, ``'{{cluster_key}}_ligrec'`` will be used. 
- %(numba_parallel)s - %(parallelize)s - Returns ------- %(ligrec_test_returns)s @@ -409,10 +309,9 @@ def test( # much faster than applymap (tested on 1M interactions) interactions_ = np.vectorize(lambda g: gene_mapper[g])(interactions.values) - n_jobs = _get_n_cores(kwargs.pop("n_jobs", None)) start = logg.info( f"Running `{n_perms}` permutations on `{len(interactions)}` interactions " - f"and `{len(clusters)}` cluster combinations using `{n_jobs}` core(s)" + f"and `{len(clusters)}` cluster combinations" ) res = _analysis( data, @@ -421,9 +320,6 @@ def test( threshold=threshold, n_perms=n_perms, seed=seed, - n_jobs=n_jobs, - numba_parallel=numba_parallel, - **kwargs, ) index = pd.MultiIndex.from_frame(interactions, names=[SOURCE, TARGET]) columns = pd.MultiIndex.from_tuples(clusters, names=["cluster_1", "cluster_2"]) @@ -454,6 +350,7 @@ def test( return res _save_data(self._adata, attr="uns", key=Key.uns.ligrec(cluster_key, key_added), data=res, time=start) + return None def _trim_data(self) -> None: """Subset genes :attr:`_data` to those present in interactions.""" @@ -642,7 +539,13 @@ def ligrec( copy: bool = False, key_added: str | None = None, gene_symbols: str | None = None, - **kwargs: Any, + n_perms: int = 1000, + seed: int | None = None, + clusters: Cluster_t | None = None, + alpha: float = 0.05, + interactions_params: Mapping[str, Any] = MappingProxyType({}), + transmitter_params: Mapping[str, Any] = MappingProxyType({"categories": "ligand"}), + receiver_params: Mapping[str, Any] = MappingProxyType({"categories": "receptor"}), ) -> Mapping[str, pd.DataFrame] | None: """ %(PT_test.full_desc)s @@ -664,19 +567,118 @@ def ligrec( with _genesymbols(adata, key=gene_symbols, use_raw=use_raw, make_unique=False): return ( # type: ignore[no-any-return] PermutationTest(adata, use_raw=use_raw) - .prepare(interactions, complex_policy=complex_policy, **kwargs) + .prepare( + interactions, + complex_policy=complex_policy, + interactions_params=interactions_params, + 
@njit(cache=True)
def _permutation_test_chunk(
    data: NDArrayA,
    clustering: NDArrayA,
    inv_counts: NDArrayA,
    mean_obs: NDArrayA,
    interactions: NDArrayA,
    interaction_clusters: NDArrayA,
    valid: NDArrayA,
    n_perms: int,
    chunk_seed: int,
    local_counts: NDArrayA,
) -> None:
    """
    Run ``n_perms`` label permutations one after another and count exceedances.

    Parameters
    ----------
    data
        Array of shape `(n_cells, n_genes)`.
    clustering
        Array of shape `(n_cells,)` with integer cluster labels. Each label is used
        directly as a row index into an array with ``mean_obs.shape[0]`` rows;
        this function does not validate the labels.
    inv_counts
        Per-cluster factor the accumulated sums are multiplied by
        (the caller in this module passes the reciprocal cluster sizes).
    mean_obs
        Array of shape `(n_clusters, n_genes)` with the observed per-cluster values.
    interactions
        Array of shape `(n_interactions, 2)` holding gene column indices.
    interaction_clusters
        Array of shape `(n_interaction_clusters, 2)` holding cluster row indices.
    valid
        Boolean array of shape `(n_interactions, n_interaction_clusters)`;
        entries that are `False` are never tested.
    n_perms
        Number of permutations to run in this chunk.
    chunk_seed
        Seed passed to :func:`numpy.random.seed` before the first permutation.
    local_counts
        Integer array of shape `(n_interactions, n_interaction_clusters)`, modified in place:
        an entry is incremented whenever the permuted score is strictly larger than the observed one.

    Returns
    -------
    Nothing, just updates ``local_counts``.
    """
    n_obs = data.shape[0]
    n_vars = data.shape[1]
    n_groups = mean_obs.shape[0]
    n_pairs = interactions.shape[0]
    n_combos = interaction_clusters.shape[0]

    np.random.seed(chunk_seed)
    for _ in range(n_perms):
        # always start from the original labels, so every permutation is a fresh shuffle of them
        shuffled = clustering.copy()
        np.random.shuffle(shuffled)

        # per-cluster sums under the permuted labels ...
        perm_means = np.zeros((n_groups, n_vars), dtype=np.float64)
        for obs_ix in range(n_obs):
            label = shuffled[obs_ix]
            for var_ix in range(n_vars):
                perm_means[label, var_ix] += data[obs_ix, var_ix]
        # ... scaled per cluster; shuffling keeps the cluster sizes, so the factors stay fixed
        for grp in range(n_groups):
            scale = inv_counts[grp]
            for var_ix in range(n_vars):
                perm_means[grp, var_ix] *= scale

        for i in range(n_pairs):
            rec = interactions[i, 0]
            lig = interactions[i, 1]
            for j in range(n_combos):
                if not valid[i, j]:
                    continue
                cl_a = interaction_clusters[j, 0]
                cl_b = interaction_clusters[j, 1]
                permuted_score = perm_means[cl_a, rec] + perm_means[cl_b, lig]
                observed_score = mean_obs[cl_a, rec] + mean_obs[cl_b, lig]
                if permuted_score > observed_score:
                    local_counts[i, j] += 1


@njit(parallel=True, cache=True)
def _permutation_test(
    data: NDArrayA,
    clustering: NDArrayA,
    inv_counts: NDArrayA,
    mean_obs: NDArrayA,
    mask: NDArrayA,
    interactions: NDArrayA,
    interaction_clusters: NDArrayA,
    valid: NDArrayA,
    n_perms: int,
    chunk_seeds: NDArrayA,
    chunk_sizes: NDArrayA,
    pval_counts: NDArrayA,
) -> None:
    """
    Run the permutation chunks in a :func:`numba.prange` loop and sum up their counts.

    One chunk is executed per entry of ``chunk_seeds``; chunk ``k`` runs ``chunk_sizes[k]``
    permutations seeded with ``chunk_seeds[k]`` and writes into its own
    `(n_interactions, n_interaction_clusters)` accumulator, so chunks never share
    a counter. Since every chunk is seeded separately, the permutations that are drawn
    depend on how the work is split into chunks.

    Parameters
    ----------
    data, clustering, inv_counts, mean_obs, interactions, interaction_clusters, valid
        Forwarded unchanged to :func:`_permutation_test_chunk`.
    mask
        Not read by this function.
    n_perms
        Not read by this function; the per-chunk numbers are taken from ``chunk_sizes``.
    chunk_seeds
        One seed per chunk.
    chunk_sizes
        Number of permutations per chunk, same length as ``chunk_seeds``.
    pval_counts
        Integer array of shape `(n_interactions, n_interaction_clusters)`, modified in place:
        the counts of all chunks are added to it.

    Returns
    -------
    Nothing, just updates ``pval_counts``.
    """
    n_chunks = len(chunk_seeds)
    n_pairs = interactions.shape[0]
    n_combos = interaction_clusters.shape[0]

    # one private accumulator per chunk: (n_chunks, n_interactions, n_interaction_clusters)
    per_chunk = np.zeros((n_chunks, n_pairs, n_combos), dtype=np.int64)

    for chunk in prange(n_chunks):
        _permutation_test_chunk(
            data,
            clustering,
            inv_counts,
            mean_obs,
            interactions,
            interaction_clusters,
            valid,
            chunk_sizes[chunk],
            chunk_seeds[chunk],
            per_chunk[chunk],
        )

    # sequential reduction of the per-chunk accumulators into the output
    for chunk in range(n_chunks):
        for i in range(n_pairs):
            for j in range(n_combos):
                pval_counts[i, j] += per_chunk[chunk, i, j]
- means = meanss[0] - if TYPE_CHECKING: - assert isinstance(means, np.ndarray) - - pvalues = np.sum([r.pvalues for r in res if r.pvalues is not None], axis=0) / float(n_perms) - assert means.shape == pvalues.shape, f"Means and p-values differ in shape: `{means.shape}`, `{pvalues.shape}`." - - return TempResult(means=means, pvalues=pvalues) - clustering = np.array(data["clusters"].values, dtype=np.int32) # densify the data earlier to avoid concatenating sparse arrays # with multiple fill values: '[0.0, nan]' (which leads to PerformanceWarning) data = data.astype({c: np.float64 for c in data.columns if c != "clusters"}) groups = data.groupby("clusters", observed=True) - mean = groups.mean().values.T # (n_genes, n_clusters) + mean_obs = groups.mean().values # (n_clusters, n_genes) # see https://github.com/scverse/squidpy/pull/991#issuecomment-2888506296 # for why we need to cast to int64 here mask = groups.apply( lambda c: ((c > 0).astype(np.int64).sum() / len(c)) >= threshold - ).values.T # (n_genes, n_clusters) - - # (n_cells, n_genes) - data = np.array(data[data.columns.difference(["clusters"])].values, dtype=np.float64, order="C") - # all 3 should be C contiguous - return parallelize( # type: ignore[no-any-return] - _analysis_helper, - np.arange(n_perms, dtype=np.int32).tolist(), - n_jobs=n_jobs, - unit="permutation", - extractor=extractor, - **kwargs, - )( - data, - mean, - mask, - interactions, - interaction_clusters=interaction_clusters, - clustering=clustering, - seed=seed, - numba_parallel=numba_parallel, + ).values # (n_clusters, n_genes) + + counts = groups.size().values.astype(np.float64) + inv_counts = 1.0 / np.maximum(counts, 1) + + data_arr = np.array(data[data.columns.difference(["clusters"])].values, dtype=np.float64, order="C") + + interactions = np.array(interactions, dtype=np.int32) + interaction_clusters = np.array(interaction_clusters, dtype=np.int32) + rec = interactions[:, 0] + lig = interactions[:, 1] + c1 = interaction_clusters[:, 0] + c2 = 
interaction_clusters[:, 1] + + obs_score = mean_obs[c1, :][:, rec].T + mean_obs[c2, :][:, lig].T + nonzero = (mean_obs[c1, :][:, rec].T > 0) & (mean_obs[c2, :][:, lig].T > 0) + valid = nonzero & mask[c1, :][:, rec].T & mask[c2, :][:, lig].T + res_means = np.where(nonzero, obs_score / 2.0, 0.0) + + n_inter = len(rec) + n_cpairs = len(c1) + pval_counts = np.zeros((n_inter, n_cpairs), dtype=np.int64) + + n_threads = numba.get_num_threads() + if n_threads <= 0: + n_threads = 1 + n_threads = min(n_threads, n_perms) + + ss = np.random.SeedSequence(seed) + child_seeds = ss.spawn(n_threads) + chunk_seeds = np.array( + [cs.generate_state(1, dtype=np.uint32)[0] for cs in child_seeds], + dtype=np.int64, ) + base_chunk, remainder = divmod(n_perms, n_threads) + chunk_sizes = np.full(n_threads, base_chunk, dtype=np.int64) + chunk_sizes[:remainder] += 1 -def _analysis_helper( - perms: NDArrayA, - data: NDArrayA, - mean: NDArrayA, - mask: NDArrayA, - interactions: NDArrayA, - interaction_clusters: NDArrayA, - clustering: NDArrayA, - seed: int | None = None, - numba_parallel: bool | None = None, - queue: SigQueue | None = None, -) -> TempResult: - """ - Run the results of mean, percent and shuffled analysis. - - Parameters - ---------- - perms - Permutation indices. Only used to set the ``seed``. - data - Array of shape `(n_cells, n_genes)`. - mean - Array of shape `(n_genes, n_clusters)` representing mean expression per cluster. - mask - Array of shape `(n_genes, n_clusters)` containing `True` if the a gene within a cluster is - expressed at least in ``threshold`` percentage of cells. - interactions - Array of shape `(n_interactions, 2)`. - interaction_clusters - Array of shape `(n_interaction_clusters, 2)`. - clustering - Array of shape `(n_cells,)` containing the original clustering. - seed - Random seed for :class:`numpy.random.RandomState`. - numba_parallel - Whether to use :func:`numba.prange` or not. If `None`, it's determined automatically. 
- queue - Signalling queue to update progress bar. - - Returns - ------- - Tuple of the following format: - - - `'means'` - array of shape `(n_interactions, n_interaction_clusters)` containing the true test - statistic. It is `None` if ``min(perms)!=0`` so that only 1 worker calculates it. - - `'pvalues'` - array of shape `(n_interactions, n_interaction_clusters)` containing `np.sum(T0 > T)` - where `T0` is the test statistic under null hypothesis and `T` is the true test statistic. - """ - rs = np.random.RandomState(None if seed is None else perms[0] + seed) - - clustering = clustering.copy() - n_cls = mean.shape[1] - return_means = np.min(perms) == 0 - - # ideally, these would be both sparse array, but there is no numba impl. (sparse.COO is read-only and very limited) - # keep it f64, because we're setting NaN - res = np.zeros((len(interactions), len(interaction_clusters)), dtype=np.float64) - numba_parallel = ( - (np.prod(res.shape) >= 2**20 or clustering.shape[0] >= 2**15) if numba_parallel is None else numba_parallel # type: ignore[assignment] + _permutation_test( + data_arr, + clustering, + inv_counts, + mean_obs, + mask, + interactions, + interaction_clusters, + valid, + n_perms, + chunk_seeds, + chunk_sizes, + pval_counts, ) - fn_key = f"_test_{n_cls}_{int(return_means)}_{bool(numba_parallel)}" - if fn_key not in globals(): - exec( - compile(_create_template(n_cls, return_means=return_means, parallel=numba_parallel), "", "exec"), # type: ignore[arg-type] - globals(), - ) - _test = globals()[fn_key] + pvalues = pval_counts.astype(np.float64) / n_perms + pvalues[~valid] = np.nan - if return_means: - res_means: NDArrayA | None = np.zeros((len(interactions), len(interaction_clusters)), dtype=np.float64) - test = partial(_test, res_means=res_means) - else: - res_means = None - test = _test - - for _ in perms: - rs.shuffle(clustering) - error = test(interactions, interaction_clusters, data, clustering, mean, mask, res=res) - if error: - raise ValueError("In the 
execution of the numba function, an unhandled case was encountered. ") - # This is mainly to avoid a numba warning - # Otherwise, the numba function wouldn't be - # executed in parallel - # See: https://github.com/scverse/squidpy/issues/994 - if queue is not None: - queue.put(Signal.UPDATE) - - if queue is not None: - queue.put(Signal.FINISH) - - return TempResult(means=res_means, pvalues=res) + return TempResult(means=res_means, pvalues=pvalues) diff --git a/tests/_images/Ligrec_dendrogram_clusters.png b/tests/_images/Ligrec_dendrogram_clusters.png index 1dfd65bcb..4b7794b40 100644 Binary files a/tests/_images/Ligrec_dendrogram_clusters.png and b/tests/_images/Ligrec_dendrogram_clusters.png differ diff --git a/tests/_images/Ligrec_pvalue_threshold.png b/tests/_images/Ligrec_pvalue_threshold.png index 60d94275e..9e1826bc8 100644 Binary files a/tests/_images/Ligrec_pvalue_threshold.png and b/tests/_images/Ligrec_pvalue_threshold.png differ diff --git a/tests/_images/Ligrec_remove_nonsig_interactions.png b/tests/_images/Ligrec_remove_nonsig_interactions.png index 638a28612..179b10a3f 100644 Binary files a/tests/_images/Ligrec_remove_nonsig_interactions.png and b/tests/_images/Ligrec_remove_nonsig_interactions.png differ diff --git a/tests/conftest.py b/tests/conftest.py index 83d405d8d..ee27a6f8a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -256,13 +256,6 @@ def complexes(adata: AnnData) -> Sequence[tuple[str, str]]: ] -@pytest.fixture(scope="session") -def ligrec_no_numba() -> Mapping[str, pd.DataFrame]: - with open("tests/_data/ligrec_no_numba.pickle", "rb") as fin: - data = pickle.load(fin) - return {"means": data[0], "pvalues": data[1], "metadata": data[2]} - - @pytest.fixture(scope="session") def ligrec_result() -> Mapping[str, pd.DataFrame]: adata = _adata.copy() @@ -272,8 +265,6 @@ def ligrec_result() -> Mapping[str, pd.DataFrame]: "leiden", interactions=interactions, n_perms=25, - n_jobs=1, - show_progress_bar=False, copy=True, seed=0, ) diff 
--git a/tests/graph/test_ligrec.py b/tests/graph/test_ligrec.py index 099ecb1b6..2bc293cc1 100644 --- a/tests/graph/test_ligrec.py +++ b/tests/graph/test_ligrec.py @@ -1,11 +1,11 @@ from __future__ import annotations import sys -from collections.abc import Mapping, Sequence +from collections.abc import Sequence from itertools import product -from time import time from typing import TYPE_CHECKING +import numba import numpy as np import pandas as pd import pytest @@ -167,8 +167,6 @@ def test_fdr_axis_works(self, adata: AnnData, interactions: Interactions_t): n_perms=5, corr_axis="clusters", seed=42, - n_jobs=1, - show_progress_bar=False, copy=True, ) ri = ligrec( @@ -177,8 +175,6 @@ def test_fdr_axis_works(self, adata: AnnData, interactions: Interactions_t): interactions=interactions, n_perms=5, corr_axis="interactions", - n_jobs=1, - show_progress_bar=False, seed=42, copy=True, ) @@ -191,7 +187,7 @@ def test_fdr_axis_works(self, adata: AnnData, interactions: Interactions_t): def test_inplace_default_key(self, adata: AnnData, interactions: Interactions_t): key = Key.uns.ligrec(_CK) assert key not in adata.uns - res = ligrec(adata, _CK, interactions=interactions, n_perms=5, copy=False, show_progress_bar=False) + res = ligrec(adata, _CK, interactions=interactions, n_perms=5, copy=False) assert res is None assert isinstance(adata.uns[key], dict) @@ -203,9 +199,7 @@ def test_inplace_default_key(self, adata: AnnData, interactions: Interactions_t) def test_inplace_key_added(self, adata: AnnData, interactions: Interactions_t): assert "foobar" not in adata.uns - res = ligrec( - adata, _CK, interactions=interactions, n_perms=5, copy=False, key_added="foobar", show_progress_bar=False - ) + res = ligrec(adata, _CK, interactions=interactions, n_perms=5, copy=False, key_added="foobar") assert res is None assert isinstance(adata.uns["foobar"], dict) @@ -217,9 +211,7 @@ def test_inplace_key_added(self, adata: AnnData, interactions: Interactions_t): def test_return_no_write(self, 
adata: AnnData, interactions: Interactions_t): assert "foobar" not in adata.uns - r = ligrec( - adata, _CK, interactions=interactions, n_perms=5, copy=True, key_added="foobar", show_progress_bar=False - ) + r = ligrec(adata, _CK, interactions=interactions, n_perms=5, copy=True, key_added="foobar") assert "foobar" not in adata.uns assert len(r) == 3 @@ -235,7 +227,6 @@ def test_pvals_in_correct_range(self, adata: AnnData, interactions: Interactions interactions=interactions, n_perms=5, copy=True, - show_progress_bar=False, corr_method=fdr_method, threshold=0, ) @@ -247,7 +238,7 @@ def test_pvals_in_correct_range(self, adata: AnnData, interactions: Interactions assert np.nanmin(r["pvalues"].values) >= 0, np.nanmin(r["pvalues"].values) def test_result_correct_index(self, adata: AnnData, interactions: Interactions_t): - r = ligrec(adata, _CK, interactions=interactions, n_perms=5, copy=True, show_progress_bar=False) + r = ligrec(adata, _CK, interactions=interactions, n_perms=5, copy=True) np.testing.assert_array_equal(r["means"].index, r["pvalues"].index) np.testing.assert_array_equal(r["pvalues"].index, r["metadata"].index) @@ -261,7 +252,7 @@ def test_result_is_sparse(self, adata: AnnData, interactions: Interactions_t): if TYPE_CHECKING: assert isinstance(interactions, pd.DataFrame) interactions["metadata"] = "foo" - r = ligrec(adata, _CK, interactions=interactions, n_perms=5, seed=2, copy=True, show_progress_bar=False) + r = ligrec(adata, _CK, interactions=interactions, n_perms=5, seed=2, copy=True) assert r["means"].sparse.density <= 0.15 assert r["pvalues"].sparse.density <= 0.95 @@ -272,17 +263,14 @@ def test_result_is_sparse(self, adata: AnnData, interactions: Interactions_t): np.testing.assert_array_equal(r["metadata"].columns, ["metadata"]) np.testing.assert_array_equal(r["metadata"]["metadata"], interactions["metadata"]) - @pytest.mark.parametrize("n_jobs", [1, 2]) - def test_reproducibility_cores(self, adata: AnnData, interactions: Interactions_t, n_jobs: 
int): + def test_reproducibility(self, adata: AnnData, interactions: Interactions_t): r1 = ligrec( adata, _CK, interactions=interactions, n_perms=25, copy=True, - show_progress_bar=False, seed=42, - n_jobs=n_jobs, ) r2 = ligrec( adata, @@ -290,9 +278,7 @@ def test_reproducibility_cores(self, adata: AnnData, interactions: Interactions_ interactions=interactions, n_perms=25, copy=True, - show_progress_bar=False, seed=42, - n_jobs=n_jobs, ) r3 = ligrec( adata, @@ -300,9 +286,7 @@ def test_reproducibility_cores(self, adata: AnnData, interactions: Interactions_ interactions=interactions, n_perms=25, copy=True, - show_progress_bar=False, seed=43, - n_jobs=n_jobs, ) assert r1 is not r2 @@ -313,39 +297,6 @@ def test_reproducibility_cores(self, adata: AnnData, interactions: Interactions_ assert not np.allclose(r3["pvalues"], r1["pvalues"]) assert not np.allclose(r3["pvalues"], r2["pvalues"]) - def test_reproducibility_numba_parallel_off(self, adata: AnnData, interactions: Interactions_t): - t1 = time() - r1 = ligrec( - adata, - _CK, - interactions=interactions, - n_perms=25, - copy=True, - show_progress_bar=False, - seed=42, - numba_parallel=False, - ) - t1 = time() - t1 - - t2 = time() - r2 = ligrec( - adata, - _CK, - interactions=interactions, - n_perms=25, - copy=True, - show_progress_bar=False, - seed=42, - numba_parallel=True, - ) - t2 = time() - t2 - - assert r1 is not r2 - # for such a small data, overhead from parallelization is too high - assert t1 <= t2, (t1, t2) - np.testing.assert_allclose(r1["means"], r2["means"]) - np.testing.assert_allclose(r1["pvalues"], r2["pvalues"]) - def test_paul15_correct_means(self, paul15: AnnData, paul15_means: pd.DataFrame): res = ligrec( paul15, @@ -353,31 +304,29 @@ def test_paul15_correct_means(self, paul15: AnnData, paul15_means: pd.DataFrame) interactions=list(paul15_means.index.to_list()), corr_method=None, copy=True, - show_progress_bar=False, threshold=0.01, seed=0, n_perms=1, - n_jobs=1, ) 
np.testing.assert_array_equal(res["means"].index, paul15_means.index) np.testing.assert_array_equal(res["means"].columns, paul15_means.columns) np.testing.assert_allclose(res["means"].values, paul15_means.values) - def test_reproducibility_numba_off( - self, adata: AnnData, interactions: Interactions_t, ligrec_no_numba: Mapping[str, pd.DataFrame] - ): - r = ligrec( - adata, _CK, interactions=interactions, n_perms=5, copy=True, show_progress_bar=False, seed=42, n_jobs=1 - ) - np.testing.assert_array_equal(r["means"].index, ligrec_no_numba["means"].index) - np.testing.assert_array_equal(r["means"].columns, ligrec_no_numba["means"].columns) - np.testing.assert_array_equal(r["pvalues"].index, ligrec_no_numba["pvalues"].index) - np.testing.assert_array_equal(r["pvalues"].columns, ligrec_no_numba["pvalues"].columns) + def test_reproducibility_single_thread(self, adata: AnnData, interactions: Interactions_t): - np.testing.assert_allclose(r["means"], ligrec_no_numba["means"]) - np.testing.assert_allclose(r["pvalues"], ligrec_no_numba["pvalues"]) - np.testing.assert_array_equal(np.where(np.isnan(r["pvalues"])), np.where(np.isnan(ligrec_no_numba["pvalues"]))) + old_threads = numba.get_num_threads() + try: + numba.set_num_threads(1) + r1 = ligrec(adata, _CK, interactions=interactions, n_perms=5, copy=True, seed=42) + finally: + numba.set_num_threads(old_threads) + + r2 = ligrec(adata, _CK, interactions=interactions, n_perms=5, copy=True, seed=42) + + np.testing.assert_allclose(r1["means"], r2["means"]) + np.testing.assert_allclose(r1["pvalues"], r2["pvalues"]) + np.testing.assert_array_equal(np.where(np.isnan(r1["pvalues"])), np.where(np.isnan(r2["pvalues"]))) def test_logging(self, adata: AnnData, interactions: Interactions_t, capsys): s.logfile = sys.stderr @@ -389,10 +338,8 @@ def test_logging(self, adata: AnnData, interactions: Interactions_t, capsys): interactions=interactions, n_perms=5, copy=False, - show_progress_bar=False, complex_policy="all", 
key_added="ligrec_test", - n_jobs=2, ) err = capsys.readouterr().err @@ -402,7 +349,7 @@ def test_logging(self, adata: AnnData, interactions: Interactions_t, capsys): assert "DEBUG: Creating all gene combinations within complexes" in err assert "DEBUG: Removing interactions with no genes in the data" in err assert "DEBUG: Removing genes not in any interaction" in err - assert "Running `5` permutations on `25` interactions and `25` cluster combinations using `2` core(s)" in err + assert "Running `5` permutations on `25` interactions and `25` cluster combinations" in err assert "Adding `adata.uns['ligrec_test']`" in err def test_non_uniqueness(self, adata: AnnData, interactions: Interactions_t): @@ -418,9 +365,7 @@ def test_non_uniqueness(self, adata: AnnData, interactions: Interactions_t): interactions=interactions, n_perms=1, copy=True, - show_progress_bar=False, seed=42, - numba_parallel=False, ) assert len(res["pvalues"]) == len(expected) @@ -428,7 +373,7 @@ def test_non_uniqueness(self, adata: AnnData, interactions: Interactions_t): @pytest.mark.xfail(reason="AnnData cannot handle writing MultiIndex") def test_writeable(self, adata: AnnData, interactions: Interactions_t, tmpdir): - ligrec(adata, _CK, interactions=interactions, n_perms=5, copy=False, show_progress_bar=False, key_added="foo") + ligrec(adata, _CK, interactions=interactions, n_perms=5, copy=False, key_added="foo") res = adata.uns["foo"] sc.write(tmpdir / "ligrec.h5ad", adata) @@ -448,7 +393,6 @@ def test_gene_symbols(self, adata: AnnData, use_raw: bool): n_perms=5, use_raw=use_raw, copy=True, - show_progress_bar=False, gene_symbols="gene_ids", )