Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions cellcommunicationpf2/rank_selection/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
PARAFAC2 Rank Selection via Cell Holdout Cross-Validation.

Main API:
    run_rank_selection(adata, ranks, condition_key) -> pd.DataFrame

Usage:
    from cellcommunicationpf2.rank_selection import run_rank_selection

    df_results = run_rank_selection(
        adata=your_anndata,
        ranks=list(range(2, 15)),
        condition_key="sample",
        n_folds=5
    )

    # Results are in a DataFrame with columns: rank, ot_score_mean, ot_score_std, r2x_mean
    print(df_results)
"""

from .rank_selection import run_rank_selection

__all__ = ["run_rank_selection"]
151 changes: 151 additions & 0 deletions cellcommunicationpf2/rank_selection/rank_selection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
"""PARAFAC2 Rank Selection via Cell Holdout Cross-Validation."""

import time

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
from ott.geometry import pointcloud
from ott.tools import sinkhorn_divergence
from scipy.sparse import issparse
from sklearn.model_selection import StratifiedKFold
from tensorly.parafac2_tensor import parafac2_to_slices

from parafac2.parafac2 import parafac2_nd

jax.config.update("jax_enable_x64", True)


def create_stratified_splits(adata, condition_key, n_folds=5, random_state=42):
    """Build shuffled K-fold cell splits stratified by condition.

    Args:
        adata: Object whose ``obs`` table holds the per-cell metadata.
        condition_key: Column of ``adata.obs`` naming each cell's condition.
        n_folds: Number of folds handed to ``StratifiedKFold``.
        random_state: Seed used for the fold shuffling.

    Returns:
        A list of ``(fold_number, train_indices, test_indices)`` tuples, with
        fold numbers counting up from 0.

    Raises:
        ValueError: If ``condition_key`` is not present in ``adata.obs``.
    """
    if condition_key not in adata.obs:
        raise ValueError(f"Condition key '{condition_key}' not found in adata.obs")

    # Integer category codes act as the stratification labels.
    labels = adata.obs[condition_key].astype("category").cat.codes.values
    folder = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=random_state)

    # StratifiedKFold only needs the number of samples from X, so a zero
    # vector of matching length stands in for the real data.
    placeholder = np.zeros(len(labels))
    return [
        (fold, train_idx, test_idx)
        for fold, (train_idx, test_idx) in enumerate(folder.split(placeholder, labels))
    ]


def reconstruct_from_pf2(pf2_output, condition_idxs):
    """Assemble a dense cells-by-genes matrix from a PARAFAC2 decomposition.

    Args:
        pf2_output: ``(weights, factors, projections)`` triple as accepted by
            ``parafac2_to_slices``. ``factors[2]`` is read for the gene count.
        condition_idxs: Array giving, for every output row, the index of the
            slice that row belongs to.

    Returns:
        A float64 array of shape ``(len(condition_idxs), n_genes)`` in which
        the rows of each condition are filled from that condition's slice.
    """
    weights, factors, projections = pf2_output
    per_condition = parafac2_to_slices((weights, factors, projections))

    out_shape = (len(condition_idxs), factors[2].shape[0])
    X_recon = np.zeros(out_shape, dtype=np.float64)

    # Write each reconstructed slice into the rows owned by its condition.
    for cond in np.unique(condition_idxs):
        rows = condition_idxs == cond
        X_recon[rows, :] = per_condition[cond]

    return X_recon


def _compute_sinkhorn_divergence(X_recon, X_real, epsilon):
    """Return the Sinkhorn divergence between two point clouds as a float.

    Both clouds are divided by the mean row norm of ``X_real`` (or by 1.0 when
    that mean is zero) before being passed to OTT.

    Args:
        X_recon: Reconstructed points, one point per row.
        X_real: Reference points, one point per row.
        epsilon: Regularisation value forwarded to ``sinkhorn_divergence``.
    """
    mean_norm = np.mean(np.linalg.norm(X_real, axis=1))
    scale = 1.0 if mean_norm == 0 else mean_norm

    cloud_real = jnp.array(X_real / scale, dtype=jnp.float64)
    cloud_recon = jnp.array(X_recon / scale, dtype=jnp.float64)

    result = sinkhorn_divergence.sinkhorn_divergence(
        pointcloud.PointCloud, x=cloud_recon, y=cloud_real, epsilon=epsilon
    )

    # The call may hand back either a tuple whose first item is the divergence
    # or an object exposing ``.divergence``; accept both.
    if isinstance(result, tuple):
        return float(result[0])
    return float(result.divergence)


def compute_ot_score(X_recon, X_real, epsilon=0.1, condition_idxs=None):
    """Compute the Sinkhorn divergence between reconstructed and real data.

    Args:
        X_recon: Reconstructed matrix, one cell per row.
        X_real: Real matrix, one cell per row; may be a scipy sparse matrix,
            in which case it is densified first.
        epsilon: Regularisation value forwarded to the Sinkhorn computation.
        condition_idxs: Optional per-row condition index. When given, the same
            boolean mask is applied to both ``X_recon`` and ``X_real``, so both
            must have one row per entry of ``condition_idxs``.

    Returns:
        The divergence as a plain ``float``. With ``condition_idxs`` it is the
        mean of the per-condition divergences, or ``0.0`` when
        ``condition_idxs`` is empty.
    """
    if issparse(X_real):
        X_real = X_real.toarray()

    X_real = np.asarray(X_real, dtype=np.float64)
    X_recon = np.asarray(X_recon, dtype=np.float64)

    if condition_idxs is None:
        return _compute_sinkhorn_divergence(X_recon, X_real, epsilon)

    condition_idxs = np.asarray(condition_idxs)

    # Every value produced by np.unique matches at least one row, so each
    # per-condition slice is non-empty and needs no emptiness guard.
    slice_scores = []
    for cond_idx in np.unique(condition_idxs):
        mask = condition_idxs == cond_idx
        slice_scores.append(
            _compute_sinkhorn_divergence(X_recon[mask], X_real[mask], epsilon)
        )

    # Cast so both branches return the same plain-float type.
    return float(np.mean(slice_scores)) if slice_scores else 0.0


def run_rank_selection(
    adata, ranks, condition_key,
    n_folds=5, n_iter_max=100, tol=1e-6, ot_epsilon=0.1, random_state=1
):
    """Run the cell-holdout cross-validation pipeline for rank selection.

    For every candidate rank and every fold, PARAFAC2 is fitted on the
    training cells, the training data is reconstructed from the fit, and the
    Sinkhorn divergence between that reconstruction and the held-out cells is
    recorded.

    Args:
        adata: Annotated data matrix; indexed by row and read via ``.obs``/``.X``.
        ranks: Iterable of candidate ranks to evaluate.
        condition_key: Column of ``adata.obs`` used for stratification and for
            building condition indices.
        n_folds: Number of cross-validation folds.
        n_iter_max: Iteration cap forwarded to ``parafac2_nd``.
        tol: Convergence tolerance forwarded to ``parafac2_nd``.
        ot_epsilon: Regularisation value forwarded to ``compute_ot_score``.
        random_state: Seed used for both the fold split and the factorization.

    Returns:
        DataFrame with columns: rank, ot_score_mean, ot_score_std, r2x_mean.
        Folds whose fit raises are skipped, and a rank with no successful
        fold contributes no row. The columns are present even when there are
        no rows.
    """
    from ..import_data import add_cond_idxs

    result_columns = ["rank", "ot_score_mean", "ot_score_std", "r2x_mean"]

    print(f"Starting Rank Selection ({n_folds}-fold CV)...")
    print(f"Testing ranks: {ranks}")

    splits = create_stratified_splits(adata, condition_key, n_folds, random_state)
    results = []

    for r in ranks:
        print(f"\n--- Rank {r} ---")
        fold_scores = []
        fold_r2x = []
        # perf_counter is monotonic, so the elapsed time cannot go negative
        # if the wall clock is adjusted mid-run.
        start_time = time.perf_counter()

        for fold_idx, train_idx, test_idx in splits:
            train_adata = add_cond_idxs(adata[train_idx].copy(), condition_key)
            test_adata = adata[test_idx]

            # Best-effort: a failed fit drops this fold but keeps the sweep going.
            try:
                pf2_output, r2x = parafac2_nd(
                    train_adata, rank=r, n_iter_max=n_iter_max,
                    tol=tol, random_state=random_state
                )
            except Exception as e:
                print(f"Fit failed for Rank {r}: {e}")
                continue

            train_cond_idxs = train_adata.obs["condition_unique_idxs"].values
            X_train_recon = reconstruct_from_pf2(pf2_output, train_cond_idxs)

            # The training reconstruction and the held-out cells are compared
            # as two point clouds; no condition indices are passed.
            X_test_real = test_adata.X
            score = compute_ot_score(X_train_recon, X_test_real, ot_epsilon)

            fold_scores.append(score)
            fold_r2x.append(r2x)
            print(f"  Fold {fold_idx+1}: OT = {score:.4f}")

        if fold_scores:
            mean_score = np.mean(fold_scores)
            results.append({
                "rank": r,
                "ot_score_mean": mean_score,
                "ot_score_std": np.std(fold_scores),
                "r2x_mean": np.mean(fold_r2x)
            })
            elapsed = time.perf_counter() - start_time
            print(f"Rank {r} done in {elapsed:.1f}s | Mean OT: {mean_score:.4f}")

    # Pin the columns so an empty result still has the documented schema.
    return pd.DataFrame(results, columns=result_columns)
198 changes: 198 additions & 0 deletions cellcommunicationpf2/rank_selection/tests/generate_validation_plots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
"""Generate rank selection validation plots using synthetic data with known rank."""

import sys
from pathlib import Path

import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
from scipy.sparse import csr_array
from tensorly.parafac2_tensor import parafac2_to_slices
from tensorly.random import random_parafac2

sys.path.insert(0, str(Path(__file__).resolve().parents[3]))

from cellcommunicationpf2.import_data import add_cond_idxs, import_balf_covid
from cellcommunicationpf2.rank_selection import run_rank_selection
from parafac2.parafac2 import parafac2_nd

# Script configuration.
# USE_RANDOM_PARAFAC2 picks the data generator used in the validation loop
# below: True -> create_random_parafac2_data, False -> create_known_rank_data.
USE_RANDOM_PARAFAC2 = False
# Passed as cells_per_cond to whichever generator is selected.
TARGET_CELLS_PER_COND = 2500
# Passed as n_conditions; only the random-PARAFAC2 generator takes it.
TARGET_N_CONDITIONS = 8
# Passed as n_genes to whichever generator is selected.
TARGET_N_GENES = 200

# NOTE(review): the real dataset is loaded at import time regardless of
# USE_RANDOM_PARAFAC2, although only the factorize-reconstruct path reads it.
print("Loading real scRNA-seq data...")
adata_full = import_balf_covid()
print(f"Full data: {adata_full.shape}")


def create_known_rank_data(adata_full, true_rank, cells_per_cond=100, n_genes=200,
                           random_state=42, condition_key="sample"):
    """Factorize real data at a known rank, then reconstruct it.

    The returned data is the PARAFAC2 reconstruction of a subsample of
    ``adata_full``, so its effective rank is ``true_rank``.

    Args:
        adata_full: Source annotated data matrix.
        true_rank: Rank used for the factorization.
        cells_per_cond: Maximum number of cells sampled from each condition.
        n_genes: Number of genes kept, chosen by highest mean expression.
        random_state: Seed for NumPy's global RNG (cell sampling) and for
            ``parafac2_nd``.
        condition_key: Column of ``adata_full.obs`` identifying the condition.

    Returns:
        ``(adata_reconstructed, r2x)``: a copy of the subsampled data whose
        ``X`` is the reconstruction stored as a float32 ``csr_array``, and the
        R2X value reported by the factorization.
    """
    print(f"\n=== Creating known-rank-{true_rank} data (factorize-reconstruct) ===")

    # Seeds the global NumPy RNG; np.random.choice below draws from it.
    np.random.seed(random_state)

    # Sample up to cells_per_cond cells from each condition.
    condition_column = adata_full.obs[condition_key]
    all_sampled_indices = []
    for condition in condition_column.unique():
        cell_indices_for_condition = np.where(condition_column == condition)[0]
        n_cells_to_sample = min(cells_per_cond, len(cell_indices_for_condition))
        sampled_cell_indices = np.random.choice(
            cell_indices_for_condition, size=n_cells_to_sample, replace=False
        )
        all_sampled_indices.extend(sampled_cell_indices)

    adata_subsampled = adata_full[all_sampled_indices].copy()

    # Keep only the top expressed genes.
    gene_mean_expression = np.asarray(adata_subsampled.X.mean(axis=0)).flatten()
    top_gene_indices = np.argsort(gene_mean_expression)[-n_genes:]
    adata_subsampled = adata_subsampled[:, top_gene_indices].copy()

    # Add the condition index column required by parafac2_nd.
    adata_subsampled = add_cond_idxs(adata_subsampled, condition_key)
    print(f"Subsampled: {adata_subsampled.shape}")

    # Factorize at the true rank.
    print(f"Factorizing at rank {true_rank}...")
    pf2_output, r2x = parafac2_nd(
        adata_subsampled, rank=true_rank, n_iter_max=100, tol=1e-6, random_state=random_state
    )
    print(f"R2X at rank {true_rank}: {r2x:.4f}")

    # Reconstruct data from the factorization (this gives exact-rank data).
    reconstructed_slices = parafac2_to_slices(pf2_output)
    condition_idxs = np.asarray(adata_subsampled.obs["condition_unique_idxs"].values)

    # Place each slice in the rows that belong to its condition. Stacking the
    # slices in index order would only be right if the rows were already
    # sorted by condition index, but the cells above were gathered in
    # unique() appearance order.
    # NOTE(review): assumes slice rows follow the within-condition row order
    # of adata_subsampled - confirm against parafac2_nd.
    reconstructed_X = np.zeros(adata_subsampled.shape, dtype=np.float32)
    for cond_idx in np.unique(condition_idxs):
        mask = condition_idxs == cond_idx
        cells_in_condition = int(mask.sum())
        reconstructed_X[mask] = np.asarray(
            reconstructed_slices[cond_idx]
        )[:cells_in_condition, :]

    # Create the output AnnData with the reconstructed matrix.
    adata_reconstructed = adata_subsampled.copy()
    adata_reconstructed.X = csr_array(reconstructed_X)

    return adata_reconstructed, r2x


def create_random_parafac2_data(true_rank, n_conditions=8, cells_per_cond=100, n_genes=200,
                                random_state=42):
    """Build a synthetic dataset from a random PARAFAC2 decomposition.

    Args:
        true_rank: Rank of the random decomposition.
        n_conditions: Number of slices (conditions) to generate.
        cells_per_cond: Number of rows in every slice.
        n_genes: Number of columns in every slice.
        random_state: Seed for ``random_parafac2`` and for the additive noise.

    Returns:
        ``(adata, 1.0)``: an AnnData whose ``X`` is the stacked slices plus
        Gaussian noise scaled to 1% of the data's standard deviation, stored
        as a float32 ``csr_array``, with ``obs['sample']`` and
        ``obs['condition_unique_idxs']`` filled in; the second item is the
        constant ``1.0``.
    """
    print(f"\n=== Creating known-rank-{true_rank} data (random PARAFAC2) ===")

    slice_shapes = [(cells_per_cond, n_genes)] * n_conditions
    decomposition = random_parafac2(
        slice_shapes, rank=true_rank, full=False, random_state=random_state
    )
    blocks = [np.asarray(block) for block in parafac2_to_slices(decomposition)]

    print(f"Generated {n_conditions} conditions with {cells_per_cond} cells each, {n_genes} genes")

    # One label and one integer index per row, in stacking order.
    sample_labels = []
    condition_idxs = []
    for cond, block in enumerate(blocks):
        n_rows = block.shape[0]
        sample_labels += [f"condition_{cond}"] * n_rows
        condition_idxs += [cond] * n_rows

    X = np.vstack(blocks).astype(np.float32)

    # Seed the global RNG right before drawing the noise.
    np.random.seed(random_state)
    noise_scale = 0.01 * np.std(X)
    X += np.random.randn(*X.shape).astype(np.float32) * noise_scale

    adata = ad.AnnData(X=csr_array(X))
    adata.obs['sample'] = sample_labels
    adata.obs['condition_unique_idxs'] = condition_idxs
    adata.var_names = [f"gene_{i}" for i in range(n_genes)]

    print(f"Synthetic data shape: {adata.shape}")

    return adata, 1.0


def get_test_ranks(true_rank, interval=5, overshoot=20):
    """Return the sorted candidate ranks to evaluate, always including ``true_rank``.

    Candidates start at 1 and step by ``interval`` up to and including
    ``true_rank + overshoot``; ``true_rank`` is added if the grid misses it.
    """
    grid = set(range(1, true_rank + overshoot + 1, interval))
    return sorted(grid | {true_rank})

# Ground-truth ranks to validate; the 2x2 figure below draws one panel per entry.
true_ranks = [20, 30, 40, 50]
results_all = {}
method_name = "Random PARAFAC2" if USE_RANDOM_PARAFAC2 else "Factorize-Reconstruct"

# Build a known-rank dataset for each true rank and sweep candidate ranks on it.
for true_rank in true_ranks:
    if USE_RANDOM_PARAFAC2:
        adata, r2x = create_random_parafac2_data(
            true_rank,
            n_conditions=TARGET_N_CONDITIONS,
            cells_per_cond=TARGET_CELLS_PER_COND,
            n_genes=TARGET_N_GENES
        )
    else:
        adata, r2x = create_known_rank_data(
            adata_full, true_rank,
            cells_per_cond=TARGET_CELLS_PER_COND,
            n_genes=TARGET_N_GENES
        )

    ranks_to_test = get_test_ranks(true_rank)
    print(f"\nRunning rank selection on known-rank-{true_rank} data...")
    print(f"Data shape: {adata.shape}, Testing ranks: {ranks_to_test}")

    results = run_rank_selection(
        adata, ranks=ranks_to_test, condition_key="sample",
        n_folds=5, n_iter_max=100, tol=1e-6, ot_epsilon=0.05, random_state=42
    )

    results_all[true_rank] = {"results": results, "r2x_fit": r2x}
    print(f"True rank: {true_rank} validation complete")

# One panel per true rank: mean OT score with std error bars against tested rank.
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes = axes.flatten()

for ax, true_rank in zip(axes, true_ranks):
    results = results_all[true_rank]['results']
    ranks = results["rank"].values
    scores = results["ot_score_mean"].values
    stds = results["ot_score_std"].values

    ax.errorbar(ranks, scores, yerr=stds, marker='o', markersize=7,
                linewidth=2, capsize=4, color='steelblue', label='OT Score')
    ax.axvline(true_rank, color='green', linestyle='-', linewidth=3,
               label=f'True Rank ({true_rank})', alpha=0.8)

    ax.set_xlabel('Rank', fontsize=11, fontweight='bold')
    ax.set_ylabel('OT Score', fontsize=11, fontweight='bold')
    ax.set_title(f'True Rank = {true_rank}', fontsize=12, fontweight='bold')
    ax.set_xticks(ranks)
    ax.tick_params(axis='x', labelsize=8, rotation=45)
    ax.legend(loc='upper right', fontsize=8)
    ax.grid(True, alpha=0.3)

fig.suptitle(f'Rank Selection Validation ({method_name})', fontsize=14, fontweight='bold')
plt.tight_layout()

# Resolve the output directory relative to this file rather than one user's
# home directory: <package>/rank_selection/output. Create it if it is missing.
output_dir = Path(__file__).resolve().parents[1] / "output"
output_dir.mkdir(parents=True, exist_ok=True)
method_slug = method_name.lower().replace('-', '_').replace(' ', '_')
output_path = output_dir / f"validation_{method_slug}.png"
plt.savefig(output_path, dpi=150, bbox_inches='tight')
print(f"\n✓ Plot saved: {output_path}")

print("\n" + "="*50)
print("SUMMARY")
print("="*50)
# Report, for each true rank, the tested rank with the lowest mean OT score.
for true_rank in true_ranks:
    data = results_all[true_rank]
    min_idx = data['results']["ot_score_mean"].idxmin()
    best_rank = data['results'].loc[min_idx, "rank"]
    print(f"True={true_rank}, Best OT Rank={best_rank}, R2X@fit={data['r2x_fit']:.4f}")

plt.show()
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ dependencies = [
"parafac2 @ git+https://github.com/meyer-lab/parafac2.git@main",
"zstandard>=0.23.0",
"pyarrow>=15.0.0",
"jax[cuda12]>=0.4.20",
"ott-jax>=0.4.6",
"kneed>=0.8.5",
]

readme = "README.md"
Expand Down Expand Up @@ -68,3 +71,8 @@ select = [
# Unused arguments
"ARG",
]

[tool.rye]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On a separate branch, we should move this to uv.

dev-dependencies = [
"pytest>=9.0.2",
]
Loading
Loading