diff --git a/cellcommunicationpf2/rank_selection/__init__.py b/cellcommunicationpf2/rank_selection/__init__.py new file mode 100644 index 0000000..e6ee2c4 --- /dev/null +++ b/cellcommunicationpf2/rank_selection/__init__.py @@ -0,0 +1,23 @@ +""" +PARAFAC2 Rank Selection via Cell Holdout Cross-Validation. + +Main API: + run_rank_selection(adata, ranks, condition_key) -> pd.DataFrame + +Usage: + from cellcommunicationpf2.rank_selection import run_rank_selection + + df_results = run_rank_selection( + adata=your_anndata, + ranks=list(range(2, 15)), + condition_key="sample", + n_folds=5 + ) + + # Results are in a DataFrame with columns: rank, ot_score_mean, ot_score_std, r2x_mean + print(df_results) +""" + +from .rank_selection import run_rank_selection + +__all__ = ["run_rank_selection"] diff --git a/cellcommunicationpf2/rank_selection/rank_selection.py b/cellcommunicationpf2/rank_selection/rank_selection.py new file mode 100644 index 0000000..9877644 --- /dev/null +++ b/cellcommunicationpf2/rank_selection/rank_selection.py @@ -0,0 +1,151 @@ +"""PARAFAC2 Rank Selection via Cell Holdout Cross-Validation.""" + +import time + +import jax +import jax.numpy as jnp +import numpy as np +import pandas as pd +from ott.geometry import pointcloud +from ott.tools import sinkhorn_divergence +from scipy.sparse import issparse +from sklearn.model_selection import StratifiedKFold +from tensorly.parafac2_tensor import parafac2_to_slices + +from parafac2.parafac2 import parafac2_nd + +jax.config.update("jax_enable_x64", True) + + +def create_stratified_splits(adata, condition_key, n_folds=5, random_state=42): + """Create K-Fold splits stratified by condition.""" + if condition_key not in adata.obs: + raise ValueError(f"Condition key '{condition_key}' not found in adata.obs") + + conditions = adata.obs[condition_key].astype("category").cat.codes.values + skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=random_state) + + splits = [] + for i, (train_idx, test_idx) in 
def compute_ot_score(X_recon, X_real, epsilon=0.1, condition_idxs=None):
    """Compute the Sinkhorn divergence between reconstructed and real data.

    Both inputs are densified (if sparse) and cast to float64 before being
    handed to ``_compute_sinkhorn_divergence``.

    Args:
        X_recon: Reconstructed matrix, rows are cells and columns are genes.
            Dense array-like or scipy sparse.
        X_real: Reference matrix with the same number of columns as
            ``X_recon``. Dense array-like or scipy sparse.
        epsilon: Entropic regularisation passed through to the Sinkhorn
            divergence.
        condition_idxs: Optional 1-D array of per-row condition labels. When
            given, the SAME boolean mask is applied to ``X_recon`` and
            ``X_real``, so both matrices must have exactly one row per label.
            The divergence is computed per condition and averaged.

    Returns:
        float: The divergence (``condition_idxs is None``) or the mean of the
        per-condition divergences. ``nan`` is returned when ``condition_idxs``
        is empty, because "no data" must not look like a perfect score of 0.

    Raises:
        ValueError: If either matrix is not 2-D, if the gene dimensions
            differ, or if ``condition_idxs`` does not have one entry per row
            of both matrices.
    """
    # Densify both sides; the original only handled a sparse X_real.
    if issparse(X_real):
        X_real = X_real.toarray()
    if issparse(X_recon):
        X_recon = X_recon.toarray()

    X_real = np.asarray(X_real, dtype=np.float64)
    X_recon = np.asarray(X_recon, dtype=np.float64)

    if X_real.ndim != 2 or X_recon.ndim != 2:
        raise ValueError(
            "X_recon and X_real must be 2-D (cells x genes); got shapes "
            f"{X_recon.shape} and {X_real.shape}"
        )
    if X_real.shape[1] != X_recon.shape[1]:
        raise ValueError(
            "X_recon and X_real must have the same number of genes; got "
            f"{X_recon.shape[1]} and {X_real.shape[1]}"
        )

    if condition_idxs is None:
        return _compute_sinkhorn_divergence(X_recon, X_real, epsilon)

    condition_idxs = np.asarray(condition_idxs)
    n_labels = condition_idxs.shape[0] if condition_idxs.ndim == 1 else -1
    if n_labels != X_real.shape[0] or n_labels != X_recon.shape[0]:
        # Fail early with a readable message instead of a boolean-index error.
        raise ValueError(
            "condition_idxs must be 1-D with one entry per row of both "
            f"X_recon ({X_recon.shape[0]} rows) and X_real "
            f"({X_real.shape[0]} rows); got shape {condition_idxs.shape}"
        )

    # Every value returned by np.unique occurs at least once, so each mask
    # selects a non-empty slice.
    slice_scores = [
        _compute_sinkhorn_divergence(X_recon[mask], X_real[mask], epsilon)
        for mask in (condition_idxs == cond_idx for cond_idx in np.unique(condition_idxs))
    ]

    if not slice_scores:
        return float("nan")
    return float(np.mean(slice_scores))
def create_known_rank_data(adata_full, true_rank, cells_per_cond=100, n_genes=200,
                           random_state=42, condition_key="sample"):
    """Build a dataset of known effective rank via factorize-then-reconstruct.

    Cells are subsampled per condition, the highest-mean genes are kept, the
    result is factorized with ``parafac2_nd`` at ``true_rank``, and the
    expression matrix is replaced by the PARAFAC2 reconstruction.

    Args:
        adata_full: AnnData holding the real data; ``obs[condition_key]``
            identifies the condition of each cell.
        true_rank: Rank used for the factorization.
        cells_per_cond: Maximum number of cells drawn (without replacement)
            from each condition.
        n_genes: Number of genes kept, chosen by highest mean expression in
            the subsample.
        random_state: Seed for the cell subsampling and for ``parafac2_nd``.
        condition_key: Column of ``adata_full.obs`` holding the condition
            labels. Defaults to ``"sample"``, the previously hard-coded key.

    Returns:
        tuple: ``(adata_reconstructed, r2x)`` where ``adata_reconstructed`` is
        a copy of the subsampled AnnData whose ``X`` is the float32
        reconstruction stored as a ``csr_array``, and ``r2x`` is the second
        value returned by ``parafac2_nd``.
    """
    print(f"\n=== Creating known-rank-{true_rank} data (factorize-reconstruct) ===")

    # A private RandomState yields the same draws as np.random.seed(random_state)
    # followed by np.random.choice, without reseeding the process-wide generator.
    rng = np.random.RandomState(random_state)

    # Sample up to cells_per_cond cells from every condition.
    condition_labels = adata_full.obs[condition_key]
    sampled_indices = []
    for condition in condition_labels.unique():
        candidates = np.where(condition_labels == condition)[0]
        n_take = min(cells_per_cond, len(candidates))
        sampled_indices.extend(rng.choice(candidates, size=n_take, replace=False))

    adata_subsampled = adata_full[sampled_indices].copy()

    # Keep only the top expressed genes
    gene_mean_expression = np.asarray(adata_subsampled.X.mean(axis=0)).flatten()
    top_gene_indices = np.argsort(gene_mean_expression)[-n_genes:]
    adata_subsampled = adata_subsampled[:, top_gene_indices].copy()

    # Add condition index column required by parafac2_nd
    adata_subsampled = add_cond_idxs(adata_subsampled, condition_key)
    print(f"Subsampled: {adata_subsampled.shape}")

    # Factorize at the true rank
    print(f"Factorizing at rank {true_rank}...")
    pf2_output, r2x = parafac2_nd(
        adata_subsampled, rank=true_rank, n_iter_max=100, tol=1e-6, random_state=random_state
    )
    print(f"R2X at rank {true_rank}: {r2x:.4f}")

    # Reconstruct data from the factorization (this gives us exact-rank data)
    reconstructed_slices = parafac2_to_slices(pf2_output)
    condition_idxs = np.asarray(adata_subsampled.obs["condition_unique_idxs"].values)

    # Write each slice back through a boolean mask so rows stay aligned with
    # adata_subsampled.obs. Stacking slices in index order is only correct if
    # the cells are already sorted by condition index, which the sampling
    # order above does not guarantee.
    # NOTE(review): assumes the values in "condition_unique_idxs" index the
    # PARAFAC2 slices and that each slice's rows follow obs order within the
    # condition — confirm against add_cond_idxs / parafac2_nd.
    reconstructed_X = np.zeros(adata_subsampled.shape, dtype=np.float32)
    for cond_idx in np.unique(condition_idxs):
        in_condition = condition_idxs == cond_idx
        n_cells = int(in_condition.sum())
        reconstructed_X[in_condition] = reconstructed_slices[cond_idx][:n_cells, :]

    # Create output AnnData with reconstructed matrix
    adata_reconstructed = adata_subsampled.copy()
    adata_reconstructed.X = csr_array(reconstructed_X)

    return adata_reconstructed, r2x
def get_test_ranks(true_rank, interval=5, overshoot=20):
    """Return the sorted grid of ranks to evaluate, always including ``true_rank``.

    The grid is ``1, 1 + interval, 1 + 2 * interval, ...`` up to and including
    ``true_rank + overshoot``; ``true_rank`` is added if the grid misses it.

    Args:
        true_rank: Known rank of the data; must be at least 1.
        interval: Step between consecutive grid ranks; must be at least 1.
        overshoot: How far past ``true_rank`` the grid extends.

    Returns:
        list[int]: Unique ranks in ascending order.

    Raises:
        ValueError: If ``true_rank`` or ``interval`` is smaller than 1. A
            non-positive rank is not a valid factorization rank, and a
            negative interval would silently collapse the grid to
            ``[true_rank]``.
    """
    if true_rank < 1:
        raise ValueError(f"true_rank must be >= 1, got {true_rank}")
    if interval < 1:
        raise ValueError(f"interval must be >= 1, got {interval}")

    # A set removes the need for a separate membership check on true_rank.
    ranks = set(range(1, true_rank + overshoot + 1, interval))
    ranks.add(true_rank)
    return sorted(ranks)
{ranks_to_test}") + + results = run_rank_selection( + adata, ranks=ranks_to_test, condition_key="sample", + n_folds=5, n_iter_max=100, tol=1e-6, ot_epsilon=0.05, random_state=42 + ) + + results_all[true_rank] = {"results": results, "r2x_fit": r2x} + print(f"True rank: {true_rank} validation complete") + +fig, axes = plt.subplots(2, 2, figsize=(12, 10)) +axes = axes.flatten() + +for ax, true_rank in zip(axes, true_ranks): + results = results_all[true_rank]['results'] + ranks = results["rank"].values + scores = results["ot_score_mean"].values + stds = results["ot_score_std"].values + + ax.errorbar(ranks, scores, yerr=stds, marker='o', markersize=7, + linewidth=2, capsize=4, color='steelblue', label='OT Score') + ax.axvline(true_rank, color='green', linestyle='-', linewidth=3, + label=f'True Rank ({true_rank})', alpha=0.8) + + ax.set_xlabel('Rank', fontsize=11, fontweight='bold') + ax.set_ylabel('OT Score', fontsize=11, fontweight='bold') + ax.set_title(f'True Rank = {true_rank}', fontsize=12, fontweight='bold') + ax.set_xticks(ranks) + ax.tick_params(axis='x', labelsize=8, rotation=45) + ax.legend(loc='upper right', fontsize=8) + ax.grid(True, alpha=0.3) + +fig.suptitle(f'Rank Selection Validation ({method_name})', fontsize=14, fontweight='bold') +plt.tight_layout() + +output_path = f"/home/nthomas/cellcommunication-Pf2/cellcommunicationpf2/rank_selection/output/validation_{method_name.lower().replace('-', '_').replace(' ', '_')}.png" +plt.savefig(output_path, dpi=150, bbox_inches='tight') +print(f"\n✓ Plot saved: {output_path}") + +print("\n" + "="*50) +print("SUMMARY") +print("="*50) +for true_rank in true_ranks: + data = results_all[true_rank] + min_idx = data['results']["ot_score_mean"].idxmin() + best_rank = data['results'].loc[min_idx, "rank"] + print(f"True={true_rank}, Best OT Rank={best_rank}, R2X@fit={data['r2x_fit']:.4f}") + +plt.show() diff --git a/pyproject.toml b/pyproject.toml index 81f15b3..0735a8f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ 
-21,6 +21,9 @@ dependencies = [ "parafac2 @ git+https://github.com/meyer-lab/parafac2.git@main", "zstandard>=0.23.0", "pyarrow>=15.0.0", + "jax[cuda12]>=0.4.20", + "ott-jax>=0.4.6", + "kneed>=0.8.5", ] readme = "README.md" @@ -68,3 +71,8 @@ select = [ # Unused arguments "ARG", ] + +[tool.rye] +dev-dependencies = [ + "pytest>=9.0.2", +] diff --git a/requirements-dev.lock b/requirements-dev.lock index 4a859a9..65e8cc1 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -4,12 +4,15 @@ # last locked with the following flags: # pre: false # features: [] -# all-features: false +# all-features: true # with-sources: false # generate-hashes: false # universal: false -e file:. +absl-py==2.3.1 + # via chex + # via optax anndata==0.12.3 # via cellcommunicationpf2 # via liana @@ -26,6 +29,8 @@ cfgv==3.4.0 # via pre-commit charset-normalizer==3.4.1 # via requests +chex==0.1.90 + # via optax colorcet==3.1.0 # via datashader contourpy==1.3.1 @@ -44,6 +49,8 @@ docrep==0.3.2 # via liana donfig==0.8.1.post1 # via zarr +equinox==0.13.2 + # via lineax fastrlock==0.8.3 # via cupy-cuda12x filelock==3.18.0 @@ -57,17 +64,45 @@ identify==2.6.12 # via pre-commit idna==3.10 # via requests +iniconfig==2.3.0 + # via pytest +jax==0.9.0 + # via cellcommunicationpf2 + # via chex + # via equinox + # via jaxopt + # via lineax + # via optax + # via ott-jax +jax-cuda12-pjrt==0.9.0 + # via jax-cuda12-plugin +jax-cuda12-plugin==0.9.0 + # via jax +jaxlib==0.9.0 + # via chex + # via jax + # via jaxopt + # via optax +jaxopt==0.8.5 + # via ott-jax +jaxtyping==0.3.5 + # via equinox + # via lineax joblib==1.4.2 # via pynndescent # via scanpy # via scikit-learn kiwisolver==1.4.8 # via matplotlib +kneed==0.8.5 + # via cellcommunicationpf2 legacy-api-wrap==1.4.1 # via anndata # via scanpy liana==1.5.1 # via cellcommunicationpf2 +lineax==0.0.8 + # via ott-jax llvmlite==0.44.0 # via numba # via pynndescent @@ -79,6 +114,9 @@ matplotlib==3.10.1 # via tensorly-viz mizani==0.14.2 # via plotnine 
+ml-dtypes==0.5.4 + # via jax + # via jaxlib mudata==0.3.2 # via liana multipledispatch==1.0.0 @@ -102,14 +140,22 @@ numcodecs==0.16.3 numpy==2.2.6 # via anndata # via cellcommunicationpf2 + # via chex # via contourpy # via cupy-cuda12x # via datashader # via h5py + # via jax + # via jaxlib + # via jaxopt + # via kneed # via matplotlib # via mizani + # via ml-dtypes # via numba # via numcodecs + # via optax + # via ott-jax # via pacmap # via pandas # via parafac2 @@ -126,10 +172,49 @@ numpy==2.2.6 # via umap-learn # via xarray # via zarr +nvidia-cublas-cu12==12.9.1.4 + # via jax-cuda12-plugin + # via nvidia-cudnn-cu12 + # via nvidia-cusolver-cu12 +nvidia-cuda-cccl-cu12==12.9.27 + # via nvidia-nvshmem-cu12 +nvidia-cuda-cupti-cu12==12.9.79 + # via jax-cuda12-plugin +nvidia-cuda-nvcc-cu12==12.9.86 + # via jax-cuda12-plugin +nvidia-cuda-nvrtc-cu12==12.9.86 + # via jax-cuda12-plugin +nvidia-cuda-runtime-cu12==12.9.79 + # via jax-cuda12-plugin +nvidia-cudnn-cu12==9.18.0.77 + # via jax-cuda12-plugin +nvidia-cufft-cu12==11.4.1.4 + # via jax-cuda12-plugin +nvidia-cusolver-cu12==11.7.5.82 + # via jax-cuda12-plugin +nvidia-cusparse-cu12==12.5.10.65 + # via jax-cuda12-plugin + # via nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.29.2 + # via jax-cuda12-plugin +nvidia-nvjitlink-cu12==12.9.86 + # via jax-cuda12-plugin + # via nvidia-cufft-cu12 + # via nvidia-cusolver-cu12 + # via nvidia-cusparse-cu12 +nvidia-nvshmem-cu12==3.5.19 + # via jax-cuda12-plugin +opt-einsum==3.4.0 + # via jax +optax==0.2.6 + # via ott-jax +ott-jax==0.6.0 + # via cellcommunicationpf2 packaging==24.2 # via anndata # via datashader # via matplotlib + # via pytest # via scanpy # via statsmodels # via xarray @@ -160,12 +245,16 @@ platformdirs==4.3.8 # via virtualenv plotnine==0.15.0 # via liana +pluggy==1.6.0 + # via pytest pre-commit==4.2.0 # via liana pyarrow==23.0.0 # via cellcommunicationpf2 pyct==0.5.0 # via datashader +pygments==2.19.2 + # via pytest pynndescent==0.5.13 # via scanpy # via umap-learn @@ 
-173,6 +262,7 @@ pyparsing==3.2.2 # via matplotlib pyscipopt==5.5.0 # via liana +pytest==9.0.2 python-dateutil==2.9.0.post0 # via matplotlib # via pandas @@ -196,6 +286,10 @@ scipy==1.16.2 # via anndata # via cellcommunicationpf2 # via datashader + # via jax + # via jaxlib + # via jaxopt + # via kneed # via mizani # via parafac2 # via plotnine @@ -211,6 +305,8 @@ seaborn==0.13.2 # via scanpy session-info2==0.1.2 # via scanpy +setuptools==80.10.1 + # via chex six==1.17.0 # via docrep # via python-dateutil @@ -226,6 +322,7 @@ tensorly-viz==0.1.7 threadpoolctl==3.6.0 # via scikit-learn toolz==1.0.0 + # via chex # via datashader tqdm==4.67.1 # via liana @@ -233,6 +330,9 @@ tqdm==4.67.1 # via scanpy # via umap-learn typing-extensions==4.12.2 + # via chex + # via equinox + # via lineax # via numcodecs # via scanpy # via zarr @@ -244,6 +344,9 @@ urllib3==2.3.0 # via requests virtualenv==20.31.2 # via pre-commit +wadler-lindig==0.1.7 + # via equinox + # via jaxtyping xarray==2025.3.0 # via datashader # via tensorly-viz diff --git a/requirements.lock b/requirements.lock index cf39f25..2b6668f 100644 --- a/requirements.lock +++ b/requirements.lock @@ -4,12 +4,15 @@ # last locked with the following flags: # pre: false # features: [] -# all-features: false +# all-features: true # with-sources: false # generate-hashes: false # universal: false -e file:. 
+absl-py==2.3.1 + # via chex + # via optax anndata==0.12.3 # via cellcommunicationpf2 # via liana @@ -26,6 +29,8 @@ cfgv==3.4.0 # via pre-commit charset-normalizer==3.4.1 # via requests +chex==0.1.90 + # via optax colorcet==3.1.0 # via datashader contourpy==1.3.1 @@ -44,6 +49,8 @@ docrep==0.3.2 # via liana donfig==0.8.1.post1 # via zarr +equinox==0.13.2 + # via lineax fastrlock==0.8.3 # via cupy-cuda12x filelock==3.18.0 @@ -57,17 +64,43 @@ identify==2.6.12 # via pre-commit idna==3.10 # via requests +jax==0.9.0 + # via cellcommunicationpf2 + # via chex + # via equinox + # via jaxopt + # via lineax + # via optax + # via ott-jax +jax-cuda12-pjrt==0.9.0 + # via jax-cuda12-plugin +jax-cuda12-plugin==0.9.0 + # via jax +jaxlib==0.9.0 + # via chex + # via jax + # via jaxopt + # via optax +jaxopt==0.8.5 + # via ott-jax +jaxtyping==0.3.5 + # via equinox + # via lineax joblib==1.4.2 # via pynndescent # via scanpy # via scikit-learn kiwisolver==1.4.8 # via matplotlib +kneed==0.8.5 + # via cellcommunicationpf2 legacy-api-wrap==1.4.1 # via anndata # via scanpy liana==1.5.1 # via cellcommunicationpf2 +lineax==0.0.8 + # via ott-jax llvmlite==0.44.0 # via numba # via pynndescent @@ -79,6 +112,9 @@ matplotlib==3.10.1 # via tensorly-viz mizani==0.14.2 # via plotnine +ml-dtypes==0.5.4 + # via jax + # via jaxlib mudata==0.3.2 # via liana multipledispatch==1.0.0 @@ -102,14 +138,22 @@ numcodecs==0.16.3 numpy==2.2.6 # via anndata # via cellcommunicationpf2 + # via chex # via contourpy # via cupy-cuda12x # via datashader # via h5py + # via jax + # via jaxlib + # via jaxopt + # via kneed # via matplotlib # via mizani + # via ml-dtypes # via numba # via numcodecs + # via optax + # via ott-jax # via pacmap # via pandas # via parafac2 @@ -126,6 +170,44 @@ numpy==2.2.6 # via umap-learn # via xarray # via zarr +nvidia-cublas-cu12==12.9.1.4 + # via jax-cuda12-plugin + # via nvidia-cudnn-cu12 + # via nvidia-cusolver-cu12 +nvidia-cuda-cccl-cu12==12.9.27 + # via nvidia-nvshmem-cu12 
+nvidia-cuda-cupti-cu12==12.9.79 + # via jax-cuda12-plugin +nvidia-cuda-nvcc-cu12==12.9.86 + # via jax-cuda12-plugin +nvidia-cuda-nvrtc-cu12==12.9.86 + # via jax-cuda12-plugin +nvidia-cuda-runtime-cu12==12.9.79 + # via jax-cuda12-plugin +nvidia-cudnn-cu12==9.18.0.77 + # via jax-cuda12-plugin +nvidia-cufft-cu12==11.4.1.4 + # via jax-cuda12-plugin +nvidia-cusolver-cu12==11.7.5.82 + # via jax-cuda12-plugin +nvidia-cusparse-cu12==12.5.10.65 + # via jax-cuda12-plugin + # via nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.29.2 + # via jax-cuda12-plugin +nvidia-nvjitlink-cu12==12.9.86 + # via jax-cuda12-plugin + # via nvidia-cufft-cu12 + # via nvidia-cusolver-cu12 + # via nvidia-cusparse-cu12 +nvidia-nvshmem-cu12==3.5.19 + # via jax-cuda12-plugin +opt-einsum==3.4.0 + # via jax +optax==0.2.6 + # via ott-jax +ott-jax==0.6.0 + # via cellcommunicationpf2 packaging==24.2 # via anndata # via datashader @@ -196,6 +278,10 @@ scipy==1.16.2 # via anndata # via cellcommunicationpf2 # via datashader + # via jax + # via jaxlib + # via jaxopt + # via kneed # via mizani # via parafac2 # via plotnine @@ -211,6 +297,8 @@ seaborn==0.13.2 # via scanpy session-info2==0.1.2 # via scanpy +setuptools==80.10.1 + # via chex six==1.17.0 # via docrep # via python-dateutil @@ -226,6 +314,7 @@ tensorly-viz==0.1.7 threadpoolctl==3.6.0 # via scikit-learn toolz==1.0.0 + # via chex # via datashader tqdm==4.67.1 # via liana @@ -233,6 +322,9 @@ tqdm==4.67.1 # via scanpy # via umap-learn typing-extensions==4.14.0 + # via chex + # via equinox + # via lineax # via numcodecs # via scanpy # via zarr @@ -244,6 +336,9 @@ urllib3==2.3.0 # via requests virtualenv==20.31.2 # via pre-commit +wadler-lindig==0.1.7 + # via equinox + # via jaxtyping xarray==2025.3.0 # via datashader # via tensorly-viz