diff --git a/example_config.toml b/example_config.toml index 8c0e0b1..38d6e4b 100644 --- a/example_config.toml +++ b/example_config.toml @@ -44,24 +44,54 @@ leiden_resolution = "0.2,0.5,1.0" # Generate and save UMAP plots save_plots = true -[annotate] +[quantitate] # Path to input clustered .zarr file input = "clustered.zarr" # Modify the input file in place instead of creating a new file inplace = false # Path to output .zarr file (required unless inplace = true) -output = "annotated.zarr" +output = "scored.zarr" # Path to CSV file with marker genes (columns: cell_type, gene) +# At least one of markers or preset_resources is required markers = "markers.csv" -# Cluster column key to use for annotation (e.g., "leiden_res0p5") -# If null, will use all leiden_res* columns found -cluster_key = null -# Pre-calculate MLM enrichment scores for pathway/TF resources -calculate_ulm = true -# Minimum sensitivity for PanglaoDB markers in MLM (default: 0.5) -panglao_min_sensitivity = 0.5 -# Minimum number of marker genes per cell type for MLM annotation +# Key suffix for custom marker scores (stored as obsm['score_mlm_']) +score_key = "custom" +# Scoring method: "mlm" (default) or "ulm" +method = "mlm" +# Minimum number of targets per source for decoupler tmin = 2 +# Comma-separated built-in resources to score: panglao, hallmark, collectri, dorothea, progeny +# Leave empty or remove to skip preset scoring +preset_resources = null +# Minimum sensitivity for PanglaoDB markers (used when panglao is in preset_resources) +panglao_min_sensitivity = 0.5 +# Only use canonical PanglaoDB markers +panglao_canonical_only = true +# Optional cell filter: "column==value" (e.g. "cell_type==Fibroblast") +# If null, all cells are scored +filter_obs = null +# Generate and save enrichment heatmap plots +save_plots = false + +[assign] +# Path to input scored .zarr file (produced by quantitate) +input = "scored.zarr" +# Modify the input file in place instead of creating a new file +inplace = false +# Path to output .zarr file (required unless inplace = true) +output = "annotated.zarr" +# Full obsm key of the score matrix to use (must match quantitate output) +# e.g. "score_mlm_custom" or "score_mlm_PanglaoDB" +score_key = "score_mlm_custom" +# Cluster column key to assign (e.g. "leiden_res0p5") +# If null, all leiden_res* columns are used +cluster_key = null +# obs column name for cell type labels; defaults to "cell_type_res{resolution}" +annotation_key = null +# Assignment strategy: "top_positive" (default) +strategy = "top_positive" +# Run differential expression per cluster key +run_de = true # Generate and save annotation plots save_plots = true diff --git a/makefile b/makefile index f841847..f24e2f6 100644 --- a/makefile +++ b/makefile @@ -19,7 +19,7 @@ help: @echo " make format - Format code with black" @echo " make clean - Remove build artifacts and caches" @echo " make clean-all - Remove build artifacts, caches, and venv" - @echo " make run ROOT=/path - Run full pipeline using config.toml in ROOT directory" + @echo " make run ROOT=/path - Run full pipeline (6 steps) using config.toml in ROOT directory" # Create virtual environment venv: @@ -140,11 +140,15 @@ run: xenium_process cluster --config "config.toml" || exit 1 @cd "$(ROOT)" && \ echo "" && \ - echo "Step 4: Annotate cell types" && \ - xenium_process annotate --config "config.toml" || exit 1 + echo "Step 4: Quantitate enrichment scores" && \ + xenium_process quantitate --config "config.toml" || exit 1 @cd "$(ROOT)" && \ echo "" && \ - echo "Step 5: Differential expression analysis" && \ + echo "Step 5: Assign cell type labels" && \ + xenium_process assign --config "config.toml" || exit 1 + @cd "$(ROOT)" && \ + echo "" && \ + echo "Step 6: Differential expression analysis" && \ xenium_process differential --config "config.toml" || exit 1 @echo "" @echo "==========================================" diff --git a/projects/PDAC_HIV/config.toml b/projects/PDAC_HIV/config.toml index ab9b29d..49d8006 100644 --- a/projects/PDAC_HIV/config.toml +++ b/projects/PDAC_HIV/config.toml @@ -33,23 +33,47 @@ leiden_resolution = "0.2,0.5,1.0" save_plots = true resume = true -[annotate] +[quantitate] # Path to input clustered .zarr file input = "data.zarr" # Modify the input file in place instead of creating a new file inplace = true # Path to CSV file with marker genes (columns: cell_type, gene) markers = "markers.csv" - -# Pre-calculate MLM enrichment scores for pathway/TF resources -calculate_ulm = true -# Minimum sensitivity for PanglaoDB markers in MLM (default: 0.5) -panglao_min_sensitivity = 0.5 -# Minimum number of marker genes per cell type for MLM annotation +# Key suffix for custom marker scores +score_key = "custom" +# Scoring method +method = "mlm" +# Minimum number of targets per source for decoupler tmin = 2 +# Also score against built-in decoupler resources +preset_resources = "panglao" +# Minimum sensitivity for PanglaoDB markers +panglao_min_sensitivity = 0.5 +# Only use canonical PanglaoDB markers +panglao_canonical_only = true +# No cell filter — score all cells +filter_obs = null +# Generate and save enrichment heatmap plots +save_plots = true + +[assign] +# Path to input scored .zarr file +input = "data.zarr" +# Modify the input file in place +inplace = true +# obsm key of the scores to assign from (produced by quantitate) +score_key = "score_mlm_custom" +# Use all leiden_res* columns +cluster_key = null +# Default annotation key naming +annotation_key = null +# Assignment strategy +strategy = "top_positive" +# Run differential expression +run_de = true # Generate and save annotation plots save_plots = true -resume = true #[differential] # Path to input .zarr file with annotations diff --git a/tests/functional/test_config_integration.py b/tests/functional/test_config_integration.py index 48e5087..bb04e22 100644 --- a/tests/functional/test_config_integration.py +++ b/tests/functional/test_config_integration.py @@ -33,12 +33,27 @@ def test_config_file(tmp_path): leiden_resolution = "0.3,0.6" save_plots = true -[annotate] +[quantitate] input = "placeholder.zarr" -output = "config_annotated.zarr" -calculate_ulm = true -panglao_min_sensitivity = 0.6 +output = "config_scored.zarr" +markers = "markers.csv" +score_key = "custom" +method = "mlm" tmin = 3 +preset_resources = null +panglao_min_sensitivity = 0.6 +panglao_canonical_only = true +filter_obs = null +save_plots = false + +[assign] +input = "config_scored.zarr" +output = "config_annotated.zarr" +score_key = "score_mlm_custom" +cluster_key = null +annotation_key = null +strategy = "top_positive" +run_de = true save_plots = false [differential] diff --git a/tests/functional/test_full_pipeline.py b/tests/functional/test_full_pipeline.py index 11f529b..820e7ca 100644 --- a/tests/functional/test_full_pipeline.py +++ b/tests/functional/test_full_pipeline.py @@ -59,18 +59,31 @@ def test_full_pipeline_end_to_end(test_samples_csv, test_markers_csv, tmp_zarr_c assert result.returncode == 0, f"Cluster failed: {result.stderr}" - # Step 4: Annotate (inplace) + # Step 4: Quantitate – score cells against the marker gene list (inplace) result = subprocess.run([ sys.executable, '-m', 'xenium_process.cli', - 'annotate', + 'quantitate', '--input', str(concat_output), '--inplace', - '--markers', str(test_markers_csv) + '--markers', str(test_markers_csv), + '--tmin', '1', ], capture_output=True, text=True) - - assert result.returncode == 0, f"Annotate failed: {result.stderr}" - - # Step 5: Differential analysis + + assert result.returncode == 0, f"Quantitate failed: {result.stderr}" + + # Step 5: Assign – label clusters from the scored obsm matrix (inplace) + result = subprocess.run([ + sys.executable, '-m', 'xenium_process.cli', + 'assign', + '--input', str(concat_output), + '--inplace', + '--score-key', 'score_mlm_custom', + '--cluster-key', 'leiden_res0p5', + ], capture_output=True, text=True) + + assert result.returncode == 0, f"Assign failed: {result.stderr}" + + # Step 6: Differential analysis diff_output_dir = tmp_zarr_cleanup / "differential" result = subprocess.run([ sys.executable, '-m', 'xenium_process.cli', diff --git a/tests/unit/test_annotation.py b/tests/unit/test_annotation.py index 9a3b1af..aa9e90c 100644 --- a/tests/unit/test_annotation.py +++ b/tests/unit/test_annotation.py @@ -3,7 +3,10 @@ """ import pytest +import logging +import numpy as np import pandas as pd +import anndata as ad from xenium_process.core import annotation @@ -83,3 +86,194 @@ def test_run_differential_expression_resume(mock_adata_with_clusters): rank_key = f"rank_genes_{cluster_key}" assert rank_key in adata.uns + +# --------------------------------------------------------------------------- +# filter_cells_by_obs +# --------------------------------------------------------------------------- + +def test_filter_cells_by_obs_basic(mock_adata): + """filter_cells_by_obs returns correct mask and subset for a valid expression.""" + adata = mock_adata + # mock_adata has a 'status' column with values 'HIV' and 'NEG' + mask, adata_sub = annotation.filter_cells_by_obs(adata, "status==HIV") + + expected_count = (adata.obs["status"] == "HIV").sum() + assert mask.sum() == expected_count + assert adata_sub.n_obs == expected_count + + +def test_filter_cells_by_obs_invalid_column(mock_adata): + """filter_cells_by_obs raises KeyError for a non-existent column.""" + with pytest.raises(KeyError, match="nonexistent_col"): + annotation.filter_cells_by_obs(mock_adata, "nonexistent_col==foo") + + +def test_filter_cells_by_obs_no_matches_warns(mock_adata, caplog): + """filter_cells_by_obs logs a warning when 0 cells match.""" + with caplog.at_level(logging.WARNING, logger="root"): + mask, adata_sub = annotation.filter_cells_by_obs(mock_adata, "status==NONEXISTENT") + + assert mask.sum() == 0 + assert adata_sub.n_obs == 0 + assert any("0 cells" in msg for msg in caplog.messages) + + +def test_filter_cells_by_obs_invalid_expr_format(mock_adata): + """filter_cells_by_obs raises ValueError when '==' is missing.""" + with pytest.raises(ValueError, match="=="): + annotation.filter_cells_by_obs(mock_adata, "status") + + +# --------------------------------------------------------------------------- +# run_enrichment_scoring +# --------------------------------------------------------------------------- + +def _make_net_df(mock_adata): + """Return a minimal network DataFrame using genes that exist in mock_adata.""" + genes = list(mock_adata.var_names[:4]) + return pd.DataFrame( + { + "source": ["TypeA", "TypeA", "TypeB", "TypeB"], + "target": genes, + "weight": [1, 1, 1, 1], + } + ) + + +def test_run_enrichment_scoring_mlm_stores_obsm(mock_adata_with_clusters): + """After MLM scoring, obsm contains the expected key with shape (n_obs, n_sources).""" + adata = mock_adata_with_clusters + import scanpy as sc + sc.pp.normalize_total(adata) + sc.pp.log1p(adata) + + net_df = _make_net_df(adata) + adata = annotation.run_enrichment_scoring(adata, net_df, score_key="test", method="mlm", tmin=1) + + key = "score_mlm_test" + assert key in adata.obsm + scores = adata.obsm[key] + assert scores.shape[0] == adata.n_obs + assert scores.shape[1] == 2 # TypeA and TypeB + + +def test_run_enrichment_scoring_ulm_stores_obsm(mock_adata_with_clusters): + """After ULM scoring, obsm contains key prefixed with 'score_ulm_'.""" + adata = mock_adata_with_clusters + import scanpy as sc + sc.pp.normalize_total(adata) + sc.pp.log1p(adata) + + net_df = _make_net_df(adata) + adata = annotation.run_enrichment_scoring(adata, net_df, score_key="test", method="ulm", tmin=1) + + key = "score_ulm_test" + assert key in adata.obsm + assert adata.obsm[key].shape[0] == adata.n_obs + + +def test_run_enrichment_scoring_with_mask_fills_nan(mock_adata_with_clusters): + """When a mask is provided, excluded rows contain NaN in the score matrix.""" + adata = mock_adata_with_clusters + import scanpy as sc + sc.pp.normalize_total(adata) + sc.pp.log1p(adata) + + net_df = _make_net_df(adata) + + # Only first 50 cells + mask = np.zeros(adata.n_obs, dtype=bool) + mask[:50] = True + + adata = annotation.run_enrichment_scoring( + adata, net_df, score_key="subset", method="mlm", tmin=1, mask=mask + ) + + key = "score_mlm_subset" + assert key in adata.obsm + scores = adata.obsm[key] + # Convert to numpy for easier NaN checking + if hasattr(scores, "values"): + scores_np = scores.values + else: + scores_np = np.asarray(scores) + assert scores_np.shape[0] == adata.n_obs + # Excluded cells (indices 50+) should be NaN + assert np.all(np.isnan(scores_np[50:])) + # Included cells (indices 0–49) should NOT all be NaN + assert not np.all(np.isnan(scores_np[:50])) + + +# --------------------------------------------------------------------------- +# assign_clusters +# --------------------------------------------------------------------------- + +def _adata_with_scores(mock_adata_with_clusters): + """Return adata with a fake score matrix in obsm ready for assign_clusters.""" + adata = mock_adata_with_clusters + import scanpy as sc + sc.pp.normalize_total(adata) + sc.pp.log1p(adata) + + net_df = _make_net_df(adata) + adata = annotation.run_enrichment_scoring(adata, net_df, score_key="custom", method="mlm", tmin=1) + return adata + + +def test_assign_clusters_top_positive_adds_obs_column(mock_adata_with_clusters): + """assign_clusters with top_positive strategy adds annotation column as categorical.""" + adata = _adata_with_scores(mock_adata_with_clusters) + + adata = annotation.assign_clusters( + adata, + score_key="score_mlm_custom", + cluster_key="leiden_res0p5", + annotation_key="cell_type_test", + strategy="top_positive", + ) + + assert "cell_type_test" in adata.obs.columns + assert str(adata.obs["cell_type_test"].dtype) == "category" + assert adata.obs["cell_type_test"].notna().all() + + +def test_assign_clusters_unknown_for_no_positive_stat(mock_adata_with_clusters): + """Clusters with no positive enrichment score are labelled 'Unknown'.""" + adata = mock_adata_with_clusters + + # Create all-zero / all-negative score matrix → no positive stats + import pandas as pd + n_obs = adata.n_obs + zero_scores = pd.DataFrame( + np.full((n_obs, 2), -1.0), + index=adata.obs_names, + columns=["TypeA", "TypeB"], + ) + adata.obsm["score_mlm_zeros"] = zero_scores + + adata = annotation.assign_clusters( + adata, + score_key="score_mlm_zeros", + cluster_key="leiden_res0p5", + annotation_key="test_annotation", + strategy="top_positive", + ) + + assert "test_annotation" in adata.obs.columns + # All clusters have no positive stat, so all should be "Unknown" + assert (adata.obs["test_annotation"] == "Unknown").all() + + +def test_assign_clusters_invalid_strategy_raises(mock_adata_with_clusters): + """assign_clusters raises ValueError for an unrecognised strategy.""" + adata = _adata_with_scores(mock_adata_with_clusters) + + with pytest.raises(ValueError, match="nonexistent"): + annotation.assign_clusters( + adata, + score_key="score_mlm_custom", + cluster_key="leiden_res0p5", + annotation_key="test_annotation", + strategy="nonexistent", + ) + diff --git a/tests/unit/test_assign_command.py b/tests/unit/test_assign_command.py new file mode 100644 index 0000000..e36b4af --- /dev/null +++ b/tests/unit/test_assign_command.py @@ -0,0 +1,185 @@ +""" +Unit tests for the assign command. + +Verifies that CLI arguments are correctly wired through to the core +annotation functions. All external I/O is mocked. +""" + +import pytest +from argparse import Namespace +from unittest.mock import patch, MagicMock, call +import pandas as pd + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_args(**overrides): + """Return a Namespace mimicking fully-parsed assign arguments.""" + defaults = dict( + input="test.zarr", + output="output.zarr", + inplace=False, + score_key="score_mlm_custom", + cluster_key=None, + annotation_key=None, + strategy="top_positive", + run_de=True, + save_plots=False, + config=None, + ) + defaults.update(overrides) + return Namespace(**defaults) + + +def _run_assign(args): + """Run assign.main(args) with all common I/O mocked; returns mock_adata.""" + from xenium_process.commands import assign + + mock_sdata = MagicMock() + mock_adata = MagicMock() + mock_adata.n_obs = 100 + mock_adata.n_vars = 300 + # score_key is present in obsm by default + mock_adata.obsm = {args.score_key: MagicMock()} + mock_adata.obs.columns = ["leiden_res0p5", "leiden_res1p0"] + + with patch("xenium_process.commands.assign.Path") as mock_path_cls, \ + patch("xenium_process.commands.assign.load_existing_spatial_data") as mock_load, \ + patch("xenium_process.commands.assign.save_spatial_data"), \ + patch("xenium_process.commands.assign.set_table"), \ + patch("xenium_process.commands.assign.prepare_spatial_data_for_save"), \ + patch("xenium_process.commands.assign.get_output_path") as mock_out, \ + patch("xenium_process.commands.assign.get_table") as mock_get_table, \ + patch("xenium_process.commands.assign.annotation.assign_clusters") as mock_assign_clusters, \ + patch("xenium_process.commands.assign.annotation.run_differential_expression") as mock_de: + + mock_path_obj = MagicMock() + mock_path_obj.exists.return_value = True + mock_path_cls.return_value = mock_path_obj + mock_out.return_value = mock_path_obj + + mock_load.return_value = mock_sdata + mock_get_table.return_value = mock_adata + mock_assign_clusters.return_value = mock_adata + mock_de.return_value = mock_adata + + try: + assign.main(args) + except SystemExit: + pass + + return mock_adata, mock_assign_clusters, mock_de + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +def test_assign_calls_assign_clusters(): + """assign_clusters is called with score_key and default strategy='top_positive'.""" + args = _make_args(cluster_key="leiden_res0p5") + _, mock_assign_clusters, _ = _run_assign(args) + + assert mock_assign_clusters.called + call_kwargs = mock_assign_clusters.call_args[1] + assert call_kwargs["score_key"] == "score_mlm_custom" + assert call_kwargs["strategy"] == "top_positive" + + +def test_assign_custom_strategy_passed(): + """--strategy threshold is forwarded to assign_clusters.""" + args = _make_args(cluster_key="leiden_res0p5", strategy="threshold") + _, mock_assign_clusters, _ = _run_assign(args) + + assert mock_assign_clusters.called + assert mock_assign_clusters.call_args[1]["strategy"] == "threshold" + + +def test_assign_auto_discovers_leiden_keys(): + """Without --cluster-key, assign_clusters is called once per leiden_res* column.""" + args = _make_args() # cluster_key=None → auto-discover + _, mock_assign_clusters, _ = _run_assign(args) + + # The mock_adata has two leiden_res* columns + assert mock_assign_clusters.call_count == 2 + called_keys = {c[1]["cluster_key"] for c in mock_assign_clusters.call_args_list} + assert "leiden_res0p5" in called_keys + assert "leiden_res1p0" in called_keys + + +def test_assign_single_cluster_key(): + """--cluster-key leiden_res0p5 → assign_clusters called exactly once.""" + args = _make_args(cluster_key="leiden_res0p5") + _, mock_assign_clusters, _ = _run_assign(args) + + assert mock_assign_clusters.call_count == 1 + assert mock_assign_clusters.call_args[1]["cluster_key"] == "leiden_res0p5" + + +def test_assign_annotation_key_derived_from_cluster_key(): + """leiden_res0p5 → annotation_key='cell_type_res0p5' (no --annotation-key set).""" + args = _make_args(cluster_key="leiden_res0p5") + _, mock_assign_clusters, _ = _run_assign(args) + + assert mock_assign_clusters.called + assert mock_assign_clusters.call_args[1]["annotation_key"] == "cell_type_res0p5" + + +def test_assign_custom_annotation_key(): + """--annotation-key my_labels is forwarded to assign_clusters.""" + args = _make_args(cluster_key="leiden_res0p5", annotation_key="my_labels") + _, mock_assign_clusters, _ = _run_assign(args) + + assert mock_assign_clusters.called + assert mock_assign_clusters.call_args[1]["annotation_key"] == "my_labels" + + +def test_assign_de_runs_by_default(): + """run_differential_expression is called for each cluster key by default.""" + args = _make_args() # run_de=True, auto-discover two cluster keys + _, _, mock_de = _run_assign(args) + + assert mock_de.call_count == 2 + + +def test_assign_de_skipped_when_disabled(): + """--run-de false → run_differential_expression is never called.""" + args = _make_args(run_de=False) + _, _, mock_de = _run_assign(args) + + assert not mock_de.called + + +def test_assign_missing_score_key_exits(): + """When obsm does not contain the score key, main() exits with code 1.""" + from xenium_process.commands import assign + + args = _make_args(score_key="score_mlm_nonexistent") + + mock_sdata = MagicMock() + mock_adata = MagicMock() + mock_adata.obsm = {} # score key absent + mock_adata.obs.columns = ["leiden_res0p5"] + + with patch("xenium_process.commands.assign.Path") as mock_path_cls, \ + patch("xenium_process.commands.assign.load_existing_spatial_data") as mock_load, \ + patch("xenium_process.commands.assign.save_spatial_data"), \ + patch("xenium_process.commands.assign.set_table"), \ + patch("xenium_process.commands.assign.prepare_spatial_data_for_save"), \ + patch("xenium_process.commands.assign.get_output_path") as mock_out, \ + patch("xenium_process.commands.assign.get_table") as mock_get_table: + + mock_path_obj = MagicMock() + mock_path_obj.exists.return_value = True + mock_path_cls.return_value = mock_path_obj + mock_out.return_value = mock_path_obj + + mock_load.return_value = mock_sdata + mock_get_table.return_value = mock_adata + + with pytest.raises(SystemExit) as exc_info: + assign.main(args) + + assert exc_info.value.code == 1 diff --git a/tests/unit/test_commands.py b/tests/unit/test_commands.py index ecd19e4..8d8bb9c 100644 --- a/tests/unit/test_commands.py +++ b/tests/unit/test_commands.py @@ -1,7 +1,9 @@ """ Unit tests for CLI command modules. -These tests verify that command-line arguments are properly passed to core functions. +These tests verify that command-line arguments are properly passed to core +functions. The annotate command has been replaced by quantitate + assign; +see test_quantitate_command.py and test_assign_command.py for those tests. """ import pytest @@ -10,151 +12,169 @@ from pathlib import Path -def test_annotate_command_passes_tmin_default(): - """Test that annotate command uses default tmin value of 2.""" - from xenium_process.commands import annotate - - # Create mock args with tmin default +# --------------------------------------------------------------------------- +# quantitate – quick smoke tests for tmin and score_key wiring +# --------------------------------------------------------------------------- + +def test_quantitate_passes_tmin_default(): + """quantitate uses default tmin=2.""" + from xenium_process.commands import quantitate + args = Namespace( input="test.zarr", output="output.zarr", inplace=False, markers="markers.csv", - cluster_key=None, - calculate_ulm=False, + score_key="custom", + method="mlm", + tmin=2, + preset_resources=None, panglao_min_sensitivity=0.5, - tmin=2, # Default value + panglao_canonical_only=True, + filter_obs=None, save_plots=False, - config=None # No config file + config=None, ) - - # Mock all the functions that would be called - with patch('xenium_process.commands.annotate.load_existing_spatial_data') as mock_load, \ - patch('xenium_process.commands.annotate.get_table') as mock_get_table, \ - patch('xenium_process.commands.annotate.annotation.load_marker_genes') as mock_load_markers, \ - patch('xenium_process.commands.annotate.annotation.annotate_with_markers') as mock_annotate, \ - patch('xenium_process.commands.annotate.annotation.run_differential_expression'), \ - patch('xenium_process.commands.annotate.save_spatial_data'), \ - patch('xenium_process.commands.annotate.set_table'), \ - patch('xenium_process.commands.annotate.prepare_spatial_data_for_save'), \ - patch('xenium_process.commands.annotate.Path'): - - # Setup mocks + + with patch("xenium_process.commands.quantitate.Path") as mock_path_cls, \ + patch("xenium_process.commands.quantitate.load_existing_spatial_data") as mock_load, \ + patch("xenium_process.commands.quantitate.save_spatial_data"), \ + patch("xenium_process.commands.quantitate.set_table"), \ + patch("xenium_process.commands.quantitate.prepare_spatial_data_for_save"), \ + patch("xenium_process.commands.quantitate.get_output_path") as mock_out, \ + patch("xenium_process.commands.quantitate.get_table") as mock_get_table, \ + patch("xenium_process.commands.quantitate.annotation.load_marker_genes") as mock_load_markers, \ + patch("xenium_process.commands.quantitate.annotation.markers_dict_to_dataframe") as mock_to_df, \ + patch("xenium_process.commands.quantitate.annotation.run_enrichment_scoring") as mock_score: + + mock_path_obj = MagicMock() + mock_path_obj.exists.return_value = True + mock_path_cls.return_value = mock_path_obj + mock_out.return_value = mock_path_obj + mock_sdata = MagicMock() mock_adata = MagicMock() - mock_adata.obs.columns = ['leiden_res0p5'] + mock_adata.var_names = [] + mock_adata.obsm = {} mock_load.return_value = mock_sdata mock_get_table.return_value = mock_adata - mock_load_markers.return_value = {'T cells': ['CD3D']} - - # Make Path(markers).exists() return True - with patch('xenium_process.commands.annotate.Path') as mock_path_class: - mock_path_obj = MagicMock() - mock_path_obj.exists.return_value = True - mock_path_class.return_value = mock_path_obj - - # Run the command - try: - annotate.main(args) - except SystemExit: - pass # Expected when paths don't exist - - # Verify annotate_with_markers was called with tmin=2 - assert mock_annotate.called, "annotate_with_markers should have been called" - call_kwargs = mock_annotate.call_args[1] - assert 'tmin' in call_kwargs, "tmin parameter should be passed" - assert call_kwargs['tmin'] == 2, "tmin should be 2 (default)" - - -def test_annotate_command_passes_custom_tmin(): - """Test that annotate command respects custom tmin value.""" - from xenium_process.commands import annotate - - # Create mock args with custom tmin + mock_load_markers.return_value = {"T cells": ["CD3D"]} + mock_to_df.return_value = MagicMock() + mock_score.return_value = mock_adata + + try: + quantitate.main(args) + except SystemExit: + pass + + assert mock_score.called + assert mock_score.call_args[1]["tmin"] == 2 + + +def test_quantitate_passes_custom_tmin(): + """quantitate respects a custom tmin value.""" + from xenium_process.commands import quantitate + args = Namespace( input="test.zarr", output="output.zarr", inplace=False, markers="markers.csv", - cluster_key=None, - calculate_ulm=False, + score_key="custom", + method="mlm", + tmin=1, + preset_resources=None, panglao_min_sensitivity=0.5, - tmin=1, # Custom value for small marker sets + panglao_canonical_only=True, + filter_obs=None, save_plots=False, - config=None # No config file + config=None, ) - - # Mock all the functions - with patch('xenium_process.commands.annotate.load_existing_spatial_data') as mock_load, \ - patch('xenium_process.commands.annotate.get_table') as mock_get_table, \ - patch('xenium_process.commands.annotate.annotation.load_marker_genes') as mock_load_markers, \ - patch('xenium_process.commands.annotate.annotation.annotate_with_markers') as mock_annotate, \ - patch('xenium_process.commands.annotate.annotation.run_differential_expression'), \ - patch('xenium_process.commands.annotate.save_spatial_data'), \ - patch('xenium_process.commands.annotate.set_table'), \ - patch('xenium_process.commands.annotate.prepare_spatial_data_for_save'): - - # Setup mocks + + with patch("xenium_process.commands.quantitate.Path") as mock_path_cls, \ + patch("xenium_process.commands.quantitate.load_existing_spatial_data") as mock_load, \ + patch("xenium_process.commands.quantitate.save_spatial_data"), \ + patch("xenium_process.commands.quantitate.set_table"), \ + patch("xenium_process.commands.quantitate.prepare_spatial_data_for_save"), \ + patch("xenium_process.commands.quantitate.get_output_path") as mock_out, \ + patch("xenium_process.commands.quantitate.get_table") as mock_get_table, \ + patch("xenium_process.commands.quantitate.annotation.load_marker_genes") as mock_load_markers, \ + patch("xenium_process.commands.quantitate.annotation.markers_dict_to_dataframe") as mock_to_df, \ + patch("xenium_process.commands.quantitate.annotation.run_enrichment_scoring") as mock_score: + + mock_path_obj = MagicMock() + mock_path_obj.exists.return_value = True + mock_path_cls.return_value = mock_path_obj + mock_out.return_value = mock_path_obj + mock_sdata = MagicMock() mock_adata = MagicMock() - mock_adata.obs.columns = ['leiden_res0p5'] + mock_adata.var_names = [] + mock_adata.obsm = {} mock_load.return_value = mock_sdata mock_get_table.return_value = mock_adata - mock_load_markers.return_value = {'T cells': ['CD3D']} - - # Make Path(markers).exists() return True - with patch('xenium_process.commands.annotate.Path') as mock_path_class: - mock_path_obj = MagicMock() - mock_path_obj.exists.return_value = True - mock_path_class.return_value = mock_path_obj - - # Run the command - try: - annotate.main(args) - except SystemExit: - pass - - # Verify annotate_with_markers was called with tmin=1 - assert mock_annotate.called - call_kwargs = mock_annotate.call_args[1] - assert call_kwargs['tmin'] == 1, "tmin should be 1 (custom value)" - - -def test_annotate_command_without_markers_no_tmin_error(): - """Test that annotate command works without markers (no tmin needed).""" - from xenium_process.commands import annotate - - # Create mock args without markers + mock_load_markers.return_value = {"T cells": ["CD3D"]} + mock_to_df.return_value = MagicMock() + mock_score.return_value = mock_adata + + try: + quantitate.main(args) + except SystemExit: + pass + + assert mock_score.called + assert mock_score.call_args[1]["tmin"] == 1 + + +# --------------------------------------------------------------------------- +# assign – smoke test for default DE behaviour +# --------------------------------------------------------------------------- + +def test_assign_runs_de_by_default(): + """assign runs differential expression by default.""" + from xenium_process.commands import assign + args = Namespace( input="test.zarr", output="output.zarr", inplace=False, - markers=None, # No markers + score_key="score_mlm_custom", cluster_key="leiden_res0p5", - calculate_ulm=False, - panglao_min_sensitivity=0.5, + annotation_key=None, + strategy="top_positive", + run_de=True, save_plots=False, - config=None # No config file + config=None, ) - - # Mock all the functions - with patch('xenium_process.commands.annotate.load_existing_spatial_data') as mock_load, \ - patch('xenium_process.commands.annotate.get_table') as mock_get_table, \ - patch('xenium_process.commands.annotate.annotation.run_differential_expression'), \ - patch('xenium_process.commands.annotate.save_spatial_data'), \ - patch('xenium_process.commands.annotate.set_table'), \ - patch('xenium_process.commands.annotate.prepare_spatial_data_for_save'), \ - patch('xenium_process.commands.annotate.Path'): - - mock_sdata = MagicMock() - mock_adata = MagicMock() - mock_adata.obs.columns = ['leiden_res0p5'] + + mock_sdata = MagicMock() + mock_adata = MagicMock() + mock_adata.obsm = {"score_mlm_custom": MagicMock()} + mock_adata.obs.columns = ["leiden_res0p5"] + + with patch("xenium_process.commands.assign.Path") as mock_path_cls, \ + patch("xenium_process.commands.assign.load_existing_spatial_data") as mock_load, \ + patch("xenium_process.commands.assign.save_spatial_data"), \ + patch("xenium_process.commands.assign.set_table"), \ + patch("xenium_process.commands.assign.prepare_spatial_data_for_save"), \ + patch("xenium_process.commands.assign.get_output_path") as mock_out, \ + patch("xenium_process.commands.assign.get_table") as mock_get_table, \ + patch("xenium_process.commands.assign.annotation.assign_clusters") as mock_assign_clusters, \ + patch("xenium_process.commands.assign.annotation.run_differential_expression") as mock_de: + + mock_path_obj = MagicMock() + mock_path_obj.exists.return_value = True + mock_path_cls.return_value = mock_path_obj + mock_out.return_value = mock_path_obj + mock_load.return_value = mock_sdata mock_get_table.return_value = mock_adata - - # Should not raise an error + mock_assign_clusters.return_value = mock_adata + mock_de.return_value = mock_adata + try: - annotate.main(args) + assign.main(args) except SystemExit: - pass # Expected when paths don't exist + pass + assert mock_de.called, "run_differential_expression should have been called" diff --git a/tests/unit/test_quantitate_command.py b/tests/unit/test_quantitate_command.py new file mode 100644 index 0000000..6050d4f --- /dev/null +++ b/tests/unit/test_quantitate_command.py @@ -0,0 +1,516 @@ +""" +Unit tests for the quantitate command. + +Verifies that CLI arguments are correctly wired through to the core +annotation functions. All external I/O is mocked. +""" + +import pytest +from argparse import Namespace +from unittest.mock import patch, MagicMock, call +import pandas as pd + + +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- + +def _make_args(**overrides): + """Return a Namespace mimicking fully-parsed quantitate arguments.""" + defaults = dict( + input="test.zarr", + output="output.zarr", + inplace=False, + markers=None, + score_key="custom", + method="mlm", + tmin=2, + preset_resources=None, + panglao_min_sensitivity=0.5, + panglao_canonical_only=True, + filter_obs=None, + save_plots=False, + config=None, + ) + defaults.update(overrides) + return Namespace(**defaults) + + +_MOCK_NET_DF = pd.DataFrame( + {"source": ["T cells", "T cells"], "target": ["CD3D", "CD3E"], "weight": [1, 1]} +) + +_COMMON_PATCHES = [ + "xenium_process.commands.quantitate.load_existing_spatial_data", + "xenium_process.commands.quantitate.save_spatial_data", + "xenium_process.commands.quantitate.set_table", + "xenium_process.commands.quantitate.prepare_spatial_data_for_save", + "xenium_process.commands.quantitate.get_output_path", +] + + +def _run_main_with_patches(args, extra_patches=None): + """ + Run quantitate.main(args) with common I/O mocked. + + Returns a dict mapping patch target → Mock for assertions. + """ + from xenium_process.commands import quantitate + + all_patches = _COMMON_PATCHES + (extra_patches or []) + + mocks = {} + with patch("xenium_process.commands.quantitate.Path") as mock_path_cls: + # Make Path(anything).exists() return True + mock_path_instance = MagicMock() + mock_path_instance.exists.return_value = True + mock_path_instance.__truediv__ = lambda self, other: self # plots_dir / "..." + mock_path_cls.return_value = mock_path_instance + + with patch.multiple("xenium_process.commands.quantitate", **{p.split(".")[-1]: MagicMock() for p in _COMMON_PATCHES}): + mock_sdata = MagicMock() + mock_adata = MagicMock() + mock_adata.n_obs = 100 + mock_adata.n_vars = 300 + mock_adata.var_names = [f"gene_{i}" for i in range(300)] + mock_adata.obsm = {} + + with patch("xenium_process.commands.quantitate.load_existing_spatial_data") as mock_load, \ + patch("xenium_process.commands.quantitate.save_spatial_data"), \ + patch("xenium_process.commands.quantitate.set_table"), \ + patch("xenium_process.commands.quantitate.prepare_spatial_data_for_save"), \ + patch("xenium_process.commands.quantitate.get_output_path") as mock_out: + mock_load.return_value = mock_sdata + mock_out.return_value = mock_path_instance + + from xenium_process.commands.quantitate import get_table as _gt + with patch("xenium_process.commands.quantitate.get_table") as mock_get_table: + mock_get_table.return_value = mock_adata + + try: + quantitate.main(args) + except SystemExit: + pass + + return mock_adata, mock_get_table + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +def test_quantitate_custom_markers_calls_run_scoring(): + """Given --markers, run_enrichment_scoring is called with the marker DataFrame.""" + from xenium_process.commands import quantitate + + args = _make_args(markers="markers.csv") + + mock_markers = {"T cells": ["CD3D", "CD3E"]} + + with patch("xenium_process.commands.quantitate.Path") as mock_path_cls, \ + patch("xenium_process.commands.quantitate.load_existing_spatial_data") as mock_load, \ + patch("xenium_process.commands.quantitate.save_spatial_data"), \ + patch("xenium_process.commands.quantitate.set_table"), \ + patch("xenium_process.commands.quantitate.prepare_spatial_data_for_save"), \ + patch("xenium_process.commands.quantitate.get_output_path") as mock_out, \ + patch("xenium_process.commands.quantitate.get_table") as mock_get_table, \ + patch("xenium_process.commands.quantitate.annotation.load_marker_genes") as mock_load_markers, \ + patch("xenium_process.commands.quantitate.annotation.markers_dict_to_dataframe") as mock_to_df, \ + patch("xenium_process.commands.quantitate.annotation.run_enrichment_scoring") as mock_score: + + mock_path_obj = MagicMock() + mock_path_obj.exists.return_value = True + mock_path_cls.return_value = mock_path_obj + mock_out.return_value = mock_path_obj + + mock_sdata = MagicMock() + mock_adata = MagicMock() + mock_adata.n_obs = 100 + mock_adata.n_vars = 300 + mock_adata.var_names = [] + mock_adata.obsm = {} + mock_load.return_value = mock_sdata + mock_get_table.return_value = mock_adata + mock_load_markers.return_value = mock_markers + mock_to_df.return_value = _MOCK_NET_DF + mock_score.return_value = mock_adata + + try: + quantitate.main(args) + except SystemExit: + pass + + assert mock_score.called, "run_enrichment_scoring should have been called" + call_kwargs = mock_score.call_args[1] + assert call_kwargs["score_key"] == "custom" + + +def test_quantitate_custom_score_key_passed_through(): + """--score-key fibroblasts is forwarded to run_enrichment_scoring.""" + from xenium_process.commands import quantitate + + args = _make_args(markers="markers.csv", score_key="fibroblasts") + + with patch("xenium_process.commands.quantitate.Path") as mock_path_cls, \ + patch("xenium_process.commands.quantitate.load_existing_spatial_data") as mock_load, \ + patch("xenium_process.commands.quantitate.save_spatial_data"), \ + patch("xenium_process.commands.quantitate.set_table"), \ + patch("xenium_process.commands.quantitate.prepare_spatial_data_for_save"), \ + patch("xenium_process.commands.quantitate.get_output_path") as mock_out, \ + patch("xenium_process.commands.quantitate.get_table") as mock_get_table, \ + patch("xenium_process.commands.quantitate.annotation.load_marker_genes") as mock_load_markers, \ + patch("xenium_process.commands.quantitate.annotation.markers_dict_to_dataframe") as mock_to_df, \ + patch("xenium_process.commands.quantitate.annotation.run_enrichment_scoring") as mock_score: + + mock_path_obj = MagicMock() + mock_path_obj.exists.return_value = True + mock_path_cls.return_value = mock_path_obj + mock_out.return_value = mock_path_obj + + mock_sdata = MagicMock() + mock_adata = MagicMock() + mock_adata.var_names = [] + mock_adata.obsm = {} + mock_load.return_value = mock_sdata + mock_get_table.return_value = mock_adata + mock_load_markers.return_value = {"T cells": ["CD3D"]} + mock_to_df.return_value = _MOCK_NET_DF + mock_score.return_value = mock_adata + + try: + quantitate.main(args) + except SystemExit: + pass + + assert mock_score.called + assert mock_score.call_args[1]["score_key"] == "fibroblasts" + + +def test_quantitate_method_mlm_default(): + """Without --method, run_enrichment_scoring is called with method='mlm'.""" + from xenium_process.commands import quantitate + + args = _make_args(markers="markers.csv") # default method="mlm" + + with patch("xenium_process.commands.quantitate.Path") as mock_path_cls, \ + patch("xenium_process.commands.quantitate.load_existing_spatial_data") as mock_load, \ + patch("xenium_process.commands.quantitate.save_spatial_data"), \ + patch("xenium_process.commands.quantitate.set_table"), \ + patch("xenium_process.commands.quantitate.prepare_spatial_data_for_save"), \ + patch("xenium_process.commands.quantitate.get_output_path") as mock_out, \ + patch("xenium_process.commands.quantitate.get_table") as mock_get_table, \ + patch("xenium_process.commands.quantitate.annotation.load_marker_genes") as mock_load_markers, \ + patch("xenium_process.commands.quantitate.annotation.markers_dict_to_dataframe") as mock_to_df, \ + patch("xenium_process.commands.quantitate.annotation.run_enrichment_scoring") as mock_score: + + mock_path_obj = MagicMock() + mock_path_obj.exists.return_value = True + mock_path_cls.return_value = mock_path_obj + mock_out.return_value = mock_path_obj + + mock_sdata = MagicMock() + mock_adata = MagicMock() + mock_adata.var_names = [] + mock_adata.obsm = {} + mock_load.return_value = mock_sdata + mock_get_table.return_value = mock_adata + mock_load_markers.return_value = {"T cells": ["CD3D"]} + mock_to_df.return_value = _MOCK_NET_DF + mock_score.return_value = mock_adata + + try: + quantitate.main(args) + except SystemExit: + pass + + assert mock_score.called + assert mock_score.call_args[1]["method"] == "mlm" + + +def test_quantitate_method_ulm(): + """--method ulm is forwarded to run_enrichment_scoring.""" + from xenium_process.commands import quantitate + + args = _make_args(markers="markers.csv", method="ulm") + + with patch("xenium_process.commands.quantitate.Path") as mock_path_cls, \ + patch("xenium_process.commands.quantitate.load_existing_spatial_data") as mock_load, \ + patch("xenium_process.commands.quantitate.save_spatial_data"), \ + patch("xenium_process.commands.quantitate.set_table"), \ + patch("xenium_process.commands.quantitate.prepare_spatial_data_for_save"), \ + patch("xenium_process.commands.quantitate.get_output_path") as mock_out, \ + patch("xenium_process.commands.quantitate.get_table") as mock_get_table, \ + patch("xenium_process.commands.quantitate.annotation.load_marker_genes") as mock_load_markers, \ + patch("xenium_process.commands.quantitate.annotation.markers_dict_to_dataframe") as mock_to_df, \ + patch("xenium_process.commands.quantitate.annotation.run_enrichment_scoring") as mock_score: + + mock_path_obj = MagicMock() + mock_path_obj.exists.return_value = True + mock_path_cls.return_value = mock_path_obj + mock_out.return_value = mock_path_obj + + mock_sdata = MagicMock() + mock_adata = MagicMock() + mock_adata.var_names = [] + mock_adata.obsm = {} + mock_load.return_value = mock_sdata + mock_get_table.return_value = mock_adata + mock_load_markers.return_value = {"T cells": ["CD3D"]} + mock_to_df.return_value = _MOCK_NET_DF + mock_score.return_value = mock_adata + + try: + quantitate.main(args) + except SystemExit: + pass + + assert mock_score.called + assert mock_score.call_args[1]["method"] == "ulm" + + +def test_quantitate_tmin_default(): + """Default tmin=2 is passed to run_enrichment_scoring.""" + from xenium_process.commands import quantitate + + args = _make_args(markers="markers.csv") + + with patch("xenium_process.commands.quantitate.Path") as mock_path_cls, \ + patch("xenium_process.commands.quantitate.load_existing_spatial_data") as mock_load, \ + patch("xenium_process.commands.quantitate.save_spatial_data"), \ + patch("xenium_process.commands.quantitate.set_table"), \ + patch("xenium_process.commands.quantitate.prepare_spatial_data_for_save"), \ + patch("xenium_process.commands.quantitate.get_output_path") as mock_out, \ + patch("xenium_process.commands.quantitate.get_table") as mock_get_table, \ + patch("xenium_process.commands.quantitate.annotation.load_marker_genes") as mock_load_markers, \ + patch("xenium_process.commands.quantitate.annotation.markers_dict_to_dataframe") as mock_to_df, \ + patch("xenium_process.commands.quantitate.annotation.run_enrichment_scoring") as mock_score: + + mock_path_obj = MagicMock() + mock_path_obj.exists.return_value = True + mock_path_cls.return_value = mock_path_obj + mock_out.return_value = mock_path_obj + + mock_sdata = MagicMock() + mock_adata = MagicMock() + mock_adata.var_names = [] + mock_adata.obsm = {} + mock_load.return_value = mock_sdata + mock_get_table.return_value = mock_adata + mock_load_markers.return_value = {"T cells": ["CD3D"]} + mock_to_df.return_value = _MOCK_NET_DF + mock_score.return_value = mock_adata + + try: + quantitate.main(args) + except SystemExit: + pass + + assert mock_score.called + assert mock_score.call_args[1]["tmin"] == 2 + + +def test_quantitate_tmin_custom(): + """--tmin 1 is forwarded to run_enrichment_scoring.""" + from xenium_process.commands import quantitate + + args = _make_args(markers="markers.csv", tmin=1) + + with patch("xenium_process.commands.quantitate.Path") as mock_path_cls, \ + patch("xenium_process.commands.quantitate.load_existing_spatial_data") as mock_load, \ + patch("xenium_process.commands.quantitate.save_spatial_data"), \ + patch("xenium_process.commands.quantitate.set_table"), \ + patch("xenium_process.commands.quantitate.prepare_spatial_data_for_save"), \ + patch("xenium_process.commands.quantitate.get_output_path") as mock_out, \ + patch("xenium_process.commands.quantitate.get_table") as mock_get_table, \ + patch("xenium_process.commands.quantitate.annotation.load_marker_genes") as mock_load_markers, \ + patch("xenium_process.commands.quantitate.annotation.markers_dict_to_dataframe") as mock_to_df, \ + patch("xenium_process.commands.quantitate.annotation.run_enrichment_scoring") as mock_score: + + mock_path_obj = MagicMock() + mock_path_obj.exists.return_value = True + mock_path_cls.return_value = mock_path_obj + mock_out.return_value = mock_path_obj + + mock_sdata = MagicMock() + mock_adata = MagicMock() + mock_adata.var_names = [] + mock_adata.obsm = {} + mock_load.return_value = mock_sdata + mock_get_table.return_value = mock_adata + mock_load_markers.return_value = {"T cells": ["CD3D"]} + mock_to_df.return_value = _MOCK_NET_DF + mock_score.return_value = mock_adata + + try: + quantitate.main(args) + except SystemExit: + pass + + assert mock_score.called + assert mock_score.call_args[1]["tmin"] == 1 + + +def test_quantitate_filter_obs_parsed_and_passed(): + """--filter-obs causes filter_cells_by_obs to be called before scoring.""" + from xenium_process.commands import quantitate + + args = _make_args(markers="markers.csv", filter_obs="cell_type==Fibroblast") + + mock_mask = MagicMock() + + with patch("xenium_process.commands.quantitate.Path") as mock_path_cls, \ + patch("xenium_process.commands.quantitate.load_existing_spatial_data") as mock_load, \ + patch("xenium_process.commands.quantitate.save_spatial_data"), \ + patch("xenium_process.commands.quantitate.set_table"), \ + patch("xenium_process.commands.quantitate.prepare_spatial_data_for_save"), \ + patch("xenium_process.commands.quantitate.get_output_path") as mock_out, \ + patch("xenium_process.commands.quantitate.get_table") as mock_get_table, \ + patch("xenium_process.commands.quantitate.annotation.filter_cells_by_obs") as mock_filter, \ + patch("xenium_process.commands.quantitate.annotation.load_marker_genes") as mock_load_markers, \ + patch("xenium_process.commands.quantitate.annotation.markers_dict_to_dataframe") as mock_to_df, \ + patch("xenium_process.commands.quantitate.annotation.run_enrichment_scoring") as mock_score: + + mock_path_obj = MagicMock() + mock_path_obj.exists.return_value = True + mock_path_cls.return_value = mock_path_obj + mock_out.return_value = mock_path_obj + + mock_sdata = MagicMock() + mock_adata = MagicMock() + mock_adata.var_names = [] + mock_adata.obsm = {} + mock_load.return_value = mock_sdata + mock_get_table.return_value = mock_adata + mock_filter.return_value = (mock_mask, MagicMock()) + mock_load_markers.return_value = {"T cells": ["CD3D"]} + mock_to_df.return_value = _MOCK_NET_DF + mock_score.return_value = mock_adata + + try: + quantitate.main(args) + except SystemExit: + pass + + # filter_cells_by_obs must be called with the expression + mock_filter.assert_called_once_with(mock_adata, "cell_type==Fibroblast") + + # The mask must be forwarded to run_enrichment_scoring + assert mock_score.called + assert mock_score.call_args[1]["mask"] is mock_mask + + +def test_quantitate_preset_resources_calls_load_preset(): + """--preset-resources panglao,hallmark: load_preset_resource called twice; run_enrichment_scoring called for each.""" + from xenium_process.commands import quantitate + + args = _make_args(preset_resources="panglao,hallmark") # no custom markers + + mock_net = MagicMock() + + with patch("xenium_process.commands.quantitate.Path") as mock_path_cls, \ + patch("xenium_process.commands.quantitate.load_existing_spatial_data") as mock_load, \ + patch("xenium_process.commands.quantitate.save_spatial_data"), \ + patch("xenium_process.commands.quantitate.set_table"), \ + patch("xenium_process.commands.quantitate.prepare_spatial_data_for_save"), \ + patch("xenium_process.commands.quantitate.get_output_path") as mock_out, \ + patch("xenium_process.commands.quantitate.get_table") as mock_get_table, \ + patch("xenium_process.commands.quantitate.annotation.load_preset_resource") as mock_load_preset, \ + patch("xenium_process.commands.quantitate.annotation.run_enrichment_scoring") as mock_score: + + mock_path_obj = MagicMock() + mock_path_obj.exists.return_value = True + mock_path_cls.return_value = mock_path_obj + mock_out.return_value = mock_path_obj + + mock_sdata = MagicMock() + mock_adata = MagicMock() + mock_adata.var_names = [] + mock_adata.obsm = {} + mock_load.return_value = mock_sdata + mock_get_table.return_value = mock_adata + mock_load_preset.return_value = mock_net + mock_score.return_value = mock_adata + + try: + quantitate.main(args) + except SystemExit: + pass + + # load_preset_resource called once per resource + assert mock_load_preset.call_count == 2 + preset_names_called = [c[0][0] for c in mock_load_preset.call_args_list] + assert "panglao" in preset_names_called + assert "hallmark" in preset_names_called + + # run_enrichment_scoring called once per resource + assert mock_score.call_count == 2 + + +def test_quantitate_panglao_sensitivity_passed(): + """--panglao-min-sensitivity 0.7 is forwarded to load_preset_resource.""" + from xenium_process.commands import quantitate + + args = _make_args(preset_resources="panglao", panglao_min_sensitivity=0.7) + + with patch("xenium_process.commands.quantitate.Path") as mock_path_cls, \ + patch("xenium_process.commands.quantitate.load_existing_spatial_data") as mock_load, \ + patch("xenium_process.commands.quantitate.save_spatial_data"), \ + patch("xenium_process.commands.quantitate.set_table"), \ + patch("xenium_process.commands.quantitate.prepare_spatial_data_for_save"), \ + patch("xenium_process.commands.quantitate.get_output_path") as mock_out, \ + patch("xenium_process.commands.quantitate.get_table") as mock_get_table, \ + patch("xenium_process.commands.quantitate.annotation.load_preset_resource") as mock_load_preset, \ + patch("xenium_process.commands.quantitate.annotation.run_enrichment_scoring") as mock_score: + + mock_path_obj = MagicMock() + mock_path_obj.exists.return_value = True + mock_path_cls.return_value = mock_path_obj + mock_out.return_value = mock_path_obj + + mock_sdata = MagicMock() + mock_adata = MagicMock() + mock_adata.var_names = [] + mock_adata.obsm = {} + mock_load.return_value = mock_sdata + mock_get_table.return_value = mock_adata + mock_load_preset.return_value = MagicMock() + mock_score.return_value = mock_adata + + try: + quantitate.main(args) + except SystemExit: + pass + + assert mock_load_preset.called + call_kwargs = mock_load_preset.call_args[1] + assert call_kwargs["panglao_min_sensitivity"] == 0.7 + + +def test_quantitate_no_markers_no_preset_exits(): + """Neither --markers nor --preset-resources → sys.exit(1).""" + from xenium_process.commands import quantitate + + args = _make_args() # both None by default + + with patch("xenium_process.commands.quantitate.Path") as mock_path_cls, \ + patch("xenium_process.commands.quantitate.load_existing_spatial_data") as mock_load, \ + patch("xenium_process.commands.quantitate.save_spatial_data"), \ + patch("xenium_process.commands.quantitate.set_table"), \ + patch("xenium_process.commands.quantitate.prepare_spatial_data_for_save"), \ + patch("xenium_process.commands.quantitate.get_output_path") as mock_out, \ + patch("xenium_process.commands.quantitate.get_table") as mock_get_table: + + mock_path_obj = MagicMock() + mock_path_obj.exists.return_value = True + mock_path_cls.return_value = mock_path_obj + mock_out.return_value = mock_path_obj + + mock_sdata = MagicMock() + mock_adata = MagicMock() + mock_load.return_value = mock_sdata + mock_get_table.return_value = mock_adata + + with pytest.raises(SystemExit) as exc_info: + quantitate.main(args) + + assert exc_info.value.code == 1 diff --git a/xenium_process/cli.py b/xenium_process/cli.py index a324b88..c04cc62 100644 --- a/xenium_process/cli.py +++ b/xenium_process/cli.py @@ -13,7 +13,7 @@ # Suppress warnings warnings.filterwarnings('ignore', category=FutureWarning) -from xenium_process.commands import concat, normalize, cluster, annotate, differential +from xenium_process.commands import concat, normalize, cluster, quantitate, assign, differential from xenium_process.utils.helpers import setup_logging @@ -39,19 +39,31 @@ def create_parser() -> argparse.ArgumentParser: # Cluster with multiple resolutions xenium_process cluster --input merged.zarr --inplace --leiden-resolution 0.2,0.5,1.0 - # Annotate with markers - xenium_process annotate --input merged.zarr --inplace --markers markers.csv - + # Score enrichment with a custom marker list (all cells) + xenium_process quantitate --input clustered.zarr --inplace --markers markers.csv + + # Score only fibroblasts against a custom list, plus built-in PanglaoDB + xenium_process quantitate --input clustered.zarr --inplace \\ + --markers markers.csv --filter-obs "cell_type==Fibroblast" \\ + --preset-resources panglao + + # Assign cell type labels to clusters from the computed scores + xenium_process assign --input clustered.zarr --inplace \\ + --score-key score_mlm_custom + # Differential analysis between groups - xenium_process differential --input merged.zarr --output-dir results/ \\ + xenium_process differential --input annotated.zarr --output-dir results/ \\ --groupby status --compare-groups HIV,NEG # Full pipeline (separate files) xenium_process concat --input samples.csv --output step1_concat.zarr xenium_process normalize --input step1_concat.zarr --output step2_normalized.zarr xenium_process cluster --input step2_normalized.zarr --output step3_clustered.zarr - xenium_process annotate --input step3_clustered.zarr --output step4_annotated.zarr - xenium_process differential --input step4_annotated.zarr --output-dir results/ + xenium_process quantitate --input step3_clustered.zarr --output step4_scored.zarr \\ + --markers markers.csv + xenium_process assign --input step4_scored.zarr --output step5_annotated.zarr \\ + --score-key score_mlm_custom + xenium_process differential --input step5_annotated.zarr --output-dir results/ """ ) @@ -89,15 +101,32 @@ def create_parser() -> argparse.ArgumentParser: cluster.add_arguments(cluster_parser) cluster_parser.set_defaults(func=cluster.main) - # Add annotate subcommand - annotate_parser = subparsers.add_parser( - 'annotate', - help='Annotate cell types', - description='Perform cell type annotation using marker genes and/or MLM scoring' + # Add quantitate subcommand + quantitate_parser = subparsers.add_parser( + 'quantitate', + help='Run enrichment scoring (MLM/ULM) on a gene list or built-in resources', + description=( + 'Run MLM or ULM enrichment scoring using a custom marker gene list, ' + 'decoupler built-in resources (panglao, hallmark, collectri, dorothea, progeny), ' + 'or both. Supports optional cell filtering via --filter-obs.' + ), ) - annotate.add_arguments(annotate_parser) - annotate_parser.set_defaults(func=annotate.main) - + quantitate.add_arguments(quantitate_parser) + quantitate_parser.set_defaults(func=quantitate.main) + + # Add assign subcommand + assign_parser = subparsers.add_parser( + 'assign', + help='Assign cell type labels to clusters from enrichment scores', + description=( + 'Read an enrichment score matrix from obsm (produced by quantitate) ' + 'and assign a cell type label to each cluster using a configurable strategy. ' + 'Optionally runs per-cluster differential expression.' + ), + ) + assign.add_arguments(assign_parser) + assign_parser.set_defaults(func=assign.main) + # Add differential subcommand differential_parser = subparsers.add_parser( 'differential', diff --git a/xenium_process/commands/annotate.py b/xenium_process/commands/annotate.py deleted file mode 100644 index b5d9afe..0000000 --- a/xenium_process/commands/annotate.py +++ /dev/null @@ -1,226 +0,0 @@ -#!/usr/bin/env python3 -""" -Annotate command: Perform cell type annotation using marker genes and/or MLM scoring. -""" - -import argparse -import logging -import sys -from pathlib import Path - -from xenium_process.core.data_io import load_existing_spatial_data, save_spatial_data -from xenium_process.core import annotation -from xenium_process.core import plotting -from xenium_process.utils.helpers import ( - get_table, set_table, get_output_path, - prepare_spatial_data_for_save, parse_resolutions -) -from xenium_process.utils.config import load_config, merge_config_with_args - - -def add_arguments(parser: argparse.ArgumentParser) -> None: - """ - Add arguments for the annotate command. - - Args: - parser: ArgumentParser to add arguments to - """ - parser.add_argument( - '--input', - required=False, - help='Path to input clustered .zarr file' - ) - parser.add_argument( - '--output', - help='Path to output .zarr file (required unless --inplace is used)' - ) - parser.add_argument( - '--inplace', - action='store_true', - help='Modify the input file in place instead of creating a new file' - ) - parser.add_argument( - '--markers', - help='Path to CSV file with marker genes (columns: cell_type, gene)' - ) - parser.add_argument( - '--cluster-key', - default=None, - help='Cluster column key to use for annotation (e.g., leiden_res0p5). If not specified, will use all leiden_res* columns found' - ) - parser.add_argument( - '--calculate-ulm', - action='store_true', - help='Pre-calculate MLM enrichment scores for pathway/TF resources' - ) - parser.add_argument( - '--panglao-min-sensitivity', - type=float, - default=0.5, - help='Minimum sensitivity for PanglaoDB markers in MLM (default: 0.5)' - ) - parser.add_argument( - '--tmin', - type=int, - default=2, - help='Minimum number of marker genes per cell type for MLM annotation (default: 2)' - ) - parser.add_argument( - '--save-plots', - action='store_true', - help='Generate and save annotation plots' - ) - parser.add_argument( - '--config', - help='Path to TOML configuration file (optional)' - ) - - -def main(args: argparse.Namespace) -> None: - """ - Execute the annotate command. - - Args: - args: Parsed command-line arguments - """ - # Load and merge config if provided - if args.config: - try: - config_dict = load_config(args.config) - # Create a temporary parser to get defaults - temp_parser = argparse.ArgumentParser() - add_arguments(temp_parser) - args = merge_config_with_args('annotate', config_dict, args, temp_parser) - except Exception as e: - logging.error(f"Error loading config file: {e}") - sys.exit(1) - - # Validate required arguments (after config merge) - if not args.input: - logging.error("--input is required (provide via CLI or config file)") - sys.exit(1) - - logging.info("="*60) - logging.info("Xenium Process: Cell Type Annotation") - logging.info("="*60) - - # Validate inputs - input_path = Path(args.input) - if not input_path.exists(): - logging.error(f"Input file not found: {input_path}") - sys.exit(1) - - try: - output_path = get_output_path(args.input, args.output, args.inplace) - except ValueError as e: - logging.error(str(e)) - sys.exit(1) - - try: - # Load spatial data - sdata = load_existing_spatial_data(input_path) - adata = get_table(sdata) - - if adata is None: - raise ValueError("No expression table found in spatial data") - - logging.info(f"Starting annotation: {adata.n_obs} cells × {adata.n_vars} genes") - - # Calculate MLM enrichment scores if requested (before annotation) - if args.calculate_ulm: - adata = annotation.calculate_mlm_scores( - adata, - use_panglao=True, - panglao_min_sensitivity=args.panglao_min_sensitivity - ) - - # Determine which cluster keys to annotate - if args.cluster_key: - cluster_keys = [args.cluster_key] - else: - # Find all leiden_res* columns - cluster_keys = [col for col in adata.obs.columns if col.startswith('leiden_res')] - if not cluster_keys: - logging.warning("No leiden clustering columns found. Run cluster command first.") - - # Cell type annotation if markers provided - markers = None - if args.markers: - markers_path = Path(args.markers) - if not markers_path.exists(): - logging.error(f"Markers file not found: {markers_path}") - sys.exit(1) - - markers = annotation.load_marker_genes(str(markers_path)) - - # Annotate each clustering resolution - for cluster_key in cluster_keys: - # Extract resolution from cluster key for annotation key naming - res_str = cluster_key.replace("leiden_res", "") - annotation_key = f"cell_type_res{res_str}" - - adata = annotation.annotate_with_markers( - adata, - markers, - cluster_key=cluster_key, - annotation_key=annotation_key, - tmin=args.tmin - ) - - # Run differential expression for each cluster key if not already done - for cluster_key in cluster_keys: - adata = annotation.run_differential_expression(adata, cluster_key) - - # Update the SpatialData table with processed AnnData - prepare_spatial_data_for_save(adata) - set_table(sdata, adata) - - # Save results - if args.inplace: - logging.info(f"Saving results in place: {output_path}") - else: - output_path.parent.mkdir(parents=True, exist_ok=True) - logging.info(f"Saving results to: {output_path}") - - save_spatial_data(sdata, output_path, overwrite=args.inplace) - - # Generate plots if requested - if args.save_plots: - plots_dir = output_path.parent / "plots" - plots_dir.mkdir(exist_ok=True) - - for cluster_key in cluster_keys: - # Extract resolution - res_str = cluster_key.replace("leiden_res", "") - try: - resolution = float(res_str.replace("p", ".")) - except ValueError: - resolution = None - - annotation_key = f"cell_type_res{res_str}" - - # UMAP with annotation - if annotation_key in adata.obs.columns: - plotting.save_umap_plots( - adata, plots_dir, cluster_key, annotation_key, resolution - ) - - # Marker dotplots - if markers: - plotting.save_marker_dotplot(adata, plots_dir, markers, cluster_key, resolution) - - # Differential expression plots - plotting.save_de_plots(adata, plots_dir, cluster_key, resolution) - - # Enrichment heatmap - if annotation_key in adata.obs.columns: - plotting.create_enrichment_heatmap(adata, plots_dir, cluster_key, resolution) - - logging.info("="*60) - logging.info(f"Annotation complete: {output_path}") - logging.info("="*60) - - except Exception as e: - logging.error(f"Annotation failed: {e}", exc_info=True) - sys.exit(1) - diff --git a/xenium_process/commands/assign.py b/xenium_process/commands/assign.py new file mode 100644 index 0000000..45e03d8 --- /dev/null +++ b/xenium_process/commands/assign.py @@ -0,0 +1,236 @@ +#!/usr/bin/env python3 +""" +Assign command: Assign cell type labels to clusters from enrichment scores. + +Reads a score matrix stored in obsm (produced by the quantitate command), +applies a configurable assignment strategy to label each cluster, and +optionally runs per-cluster differential expression. +""" + +import argparse +import logging +import sys +from pathlib import Path + +from xenium_process.core.data_io import load_existing_spatial_data, save_spatial_data +from xenium_process.core import annotation +from xenium_process.core import plotting +from xenium_process.utils.helpers import ( + get_table, set_table, get_output_path, + prepare_spatial_data_for_save, +) +from xenium_process.utils.config import load_config, merge_config_with_args + + +def add_arguments(parser: argparse.ArgumentParser) -> None: + """Add arguments for the assign command.""" + parser.add_argument( + "--input", + required=False, + help="Path to input scored .zarr file (produced by quantitate)", + ) + parser.add_argument( + "--output", + help="Path to output .zarr file (required unless --inplace is used)", + ) + parser.add_argument( + "--inplace", + action="store_true", + help="Modify the input file in place instead of creating a new file", + ) + parser.add_argument( + "--score-key", + required=False, + default=None, + help=( + "Full obsm key name holding the enrichment scores to use for assignment " + "(e.g. 'score_mlm_custom', 'score_mlm_PanglaoDB'). " + "Must match the key produced by the quantitate command." + ), + ) + parser.add_argument( + "--cluster-key", + default=None, + help=( + "Cluster column key in obs to annotate (e.g. leiden_res0p5). " + "If not specified, all leiden_res* columns are used." + ), + ) + parser.add_argument( + "--annotation-key", + default=None, + help=( + "obs column name to write cell type labels into. " + "Defaults to 'cell_type_res{resolution}' derived from --cluster-key." + ), + ) + parser.add_argument( + "--strategy", + default="top_positive", + choices=list(annotation.STRATEGY_REGISTRY), + help=( + "Assignment strategy. Default: 'top_positive' (highest positive " + "enrichment stat per cluster; 'Unknown' if none)." + ), + ) + parser.add_argument( + "--run-de", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "Run differential expression (rank_genes_groups) per cluster key. " + "Use --no-run-de to skip. Default: True" + ), + ) + parser.add_argument( + "--save-plots", + action="store_true", + help="Generate and save UMAP, dotplot, DE, and enrichment heatmap plots", + ) + parser.add_argument( + "--config", + help="Path to TOML configuration file (optional)", + ) + + +def main(args: argparse.Namespace) -> None: + """Execute the assign command.""" + # Load and merge config if provided + if args.config: + try: + config_dict = load_config(args.config) + temp_parser = argparse.ArgumentParser() + add_arguments(temp_parser) + args = merge_config_with_args("assign", config_dict, args, temp_parser) + except Exception as exc: + logging.error(f"Error loading config file: {exc}") + sys.exit(1) + + # Validate required arguments + if not args.input: + logging.error("--input is required (provide via CLI or config file)") + sys.exit(1) + + if not args.score_key: + logging.error("--score-key is required (provide via CLI or config file)") + sys.exit(1) + + logging.info("=" * 60) + logging.info("Xenium Process: Cluster Label Assignment (assign)") + logging.info("=" * 60) + + input_path = Path(args.input) + if not input_path.exists(): + logging.error(f"Input file not found: {input_path}") + sys.exit(1) + + try: + output_path = get_output_path(args.input, args.output, args.inplace) + except ValueError as exc: + logging.error(str(exc)) + sys.exit(1) + + try: + sdata = load_existing_spatial_data(input_path) + adata = get_table(sdata) + + if adata is None: + raise ValueError("No expression table found in spatial data") + + logging.info(f"Loaded dataset: {adata.n_obs} cells × {adata.n_vars} genes") + + # Validate score key presence + if args.score_key not in adata.obsm: + logging.error( + f"Score key '{args.score_key}' not found in adata.obsm. " + f"Available keys: {list(adata.obsm.keys())}. " + "Run the quantitate command first." + ) + sys.exit(1) + + # Determine cluster keys + if args.cluster_key: + cluster_keys = [args.cluster_key] + else: + cluster_keys = [col for col in adata.obs.columns if col.startswith("leiden_res")] + if not cluster_keys: + logging.warning( + "No leiden_res* columns found. Run the cluster command first." + ) + + # ------------------------------------------------------------------ # + # Assign cell type labels per clustering resolution + # ------------------------------------------------------------------ # + for cluster_key in cluster_keys: + res_str = cluster_key.replace("leiden_res", "") + + # Determine annotation key + if args.annotation_key: + annotation_key = args.annotation_key + else: + annotation_key = f"cell_type_res{res_str}" + + adata = annotation.assign_clusters( + adata, + score_key=args.score_key, + cluster_key=cluster_key, + annotation_key=annotation_key, + strategy=args.strategy, + ) + + # ------------------------------------------------------------------ # + # Differential expression + # ------------------------------------------------------------------ # + if args.run_de: + for cluster_key in cluster_keys: + adata = annotation.run_differential_expression(adata, cluster_key) + + # ------------------------------------------------------------------ # + # Save + # ------------------------------------------------------------------ # + prepare_spatial_data_for_save(adata) + set_table(sdata, adata) + + if not args.inplace: + output_path.parent.mkdir(parents=True, exist_ok=True) + + save_spatial_data(sdata, output_path, overwrite=args.inplace) + logging.info(f"Saved results to: {output_path}") + + # ------------------------------------------------------------------ # + # Optional plots + # ------------------------------------------------------------------ # + if args.save_plots: + plots_dir = output_path.parent / "plots" + plots_dir.mkdir(exist_ok=True) + + for cluster_key in cluster_keys: + res_str = cluster_key.replace("leiden_res", "") + try: + resolution = float(res_str.replace("p", ".")) + except ValueError: + resolution = None + + if args.annotation_key: + annotation_key = args.annotation_key + else: + annotation_key = f"cell_type_res{res_str}" + + if annotation_key in adata.obs.columns: + plotting.save_umap_plots( + adata, plots_dir, cluster_key, annotation_key, resolution + ) + plotting.create_enrichment_heatmap( + adata, plots_dir, cluster_key, resolution + ) + + if args.run_de: + plotting.save_de_plots(adata, plots_dir, cluster_key, resolution) + + logging.info("=" * 60) + logging.info("Assign complete.") + logging.info("=" * 60) + + except Exception as exc: + logging.error(f"Assign failed: {exc}", exc_info=True) + sys.exit(1) diff --git a/xenium_process/commands/quantitate.py b/xenium_process/commands/quantitate.py new file mode 100644 index 0000000..c12c109 --- /dev/null +++ b/xenium_process/commands/quantitate.py @@ -0,0 +1,265 @@ +#!/usr/bin/env python3 +""" +Quantitate command: Run MLM/ULM enrichment scoring across the dataset. + +Supports a custom marker gene list (CSV), decoupler built-in resources +(panglao, hallmark, collectri, dorothea, progeny), or both simultaneously. +An optional cell filter restricts scoring to a subset of cells; scores are +written back into the full dataset with NaN for excluded cells. +""" + +import argparse +import logging +import sys +from pathlib import Path + +from xenium_process.core.data_io import load_existing_spatial_data, save_spatial_data +from xenium_process.core import annotation +from xenium_process.utils.helpers import ( + get_table, set_table, get_output_path, + prepare_spatial_data_for_save, +) +from xenium_process.utils.config import load_config, merge_config_with_args + +VALID_METHODS = ("mlm", "ulm") +VALID_PRESETS = annotation.PRESET_RESOURCE_NAMES # ('panglao', 'hallmark', ...) + + +def add_arguments(parser: argparse.ArgumentParser) -> None: + """Add arguments for the quantitate command.""" + parser.add_argument( + "--input", + required=False, + help="Path to input clustered .zarr file", + ) + parser.add_argument( + "--output", + help="Path to output .zarr file (required unless --inplace is used)", + ) + parser.add_argument( + "--inplace", + action="store_true", + help="Modify the input file in place instead of creating a new file", + ) + parser.add_argument( + "--markers", + default=None, + help=( + "Path to CSV file with marker genes (columns: cell_type, gene). " + "At least one of --markers or --preset-resources is required." + ), + ) + parser.add_argument( + "--score-key", + default="custom", + help=( + "Key suffix for custom marker scores stored in obsm. " + "Result is stored at obsm['score__']. " + "Default: 'custom'" + ), + ) + parser.add_argument( + "--method", + default="mlm", + choices=list(VALID_METHODS), + help="Decoupler scoring method: 'mlm' (default) or 'ulm'", + ) + parser.add_argument( + "--tmin", + type=int, + default=2, + help="Minimum number of targets per source for decoupler (default: 2)", + ) + parser.add_argument( + "--preset-resources", + default=None, + help=( + "Comma-separated list of built-in decoupler resources to score against. " + f"Valid names: {', '.join(VALID_PRESETS)}. " + "Each resource is stored at obsm['score__']." + ), + ) + parser.add_argument( + "--panglao-min-sensitivity", + type=float, + default=0.5, + help="Minimum sensitivity for PanglaoDB markers (default: 0.5)", + ) + parser.add_argument( + "--panglao-canonical-only", + action="store_true", + default=True, + help="Only use canonical PanglaoDB markers (default: True)", + ) + parser.add_argument( + "--filter-obs", + default=None, + help=( + "Filter expression 'column==value' to subset cells before scoring " + "(e.g. 'cell_type==Fibroblast'). Scores for excluded cells are NaN." + ), + ) + parser.add_argument( + "--save-plots", + action="store_true", + help="Generate and save enrichment heatmap plots", + ) + parser.add_argument( + "--config", + help="Path to TOML configuration file (optional)", + ) + + +def main(args: argparse.Namespace) -> None: + """Execute the quantitate command.""" + # Load and merge config if provided + if args.config: + try: + config_dict = load_config(args.config) + temp_parser = argparse.ArgumentParser() + add_arguments(temp_parser) + args = merge_config_with_args("quantitate", config_dict, args, temp_parser) + except Exception as exc: + logging.error(f"Error loading config file: {exc}") + sys.exit(1) + + # Validate required arguments + if not args.input: + logging.error("--input is required (provide via CLI or config file)") + sys.exit(1) + + if not args.markers and not args.preset_resources: + logging.error( + "At least one of --markers or --preset-resources must be specified." + ) + sys.exit(1) + + logging.info("=" * 60) + logging.info("Xenium Process: Enrichment Scoring (quantitate)") + logging.info("=" * 60) + + input_path = Path(args.input) + if not input_path.exists(): + logging.error(f"Input file not found: {input_path}") + sys.exit(1) + + try: + output_path = get_output_path(args.input, args.output, args.inplace) + except ValueError as exc: + logging.error(str(exc)) + sys.exit(1) + + try: + sdata = load_existing_spatial_data(input_path) + adata = get_table(sdata) + + if adata is None: + raise ValueError("No expression table found in spatial data") + + logging.info(f"Loaded dataset: {adata.n_obs} cells × {adata.n_vars} genes") + + # ------------------------------------------------------------------ # + # Optional cell filter + # ------------------------------------------------------------------ # + mask = None + if args.filter_obs: + try: + mask, _ = annotation.filter_cells_by_obs(adata, args.filter_obs) + except (ValueError, KeyError) as exc: + logging.error(f"Invalid --filter-obs expression: {exc}") + sys.exit(1) + + # ------------------------------------------------------------------ # + # Custom marker gene scoring + # ------------------------------------------------------------------ # + if args.markers: + markers_path = Path(args.markers) + if not markers_path.exists(): + logging.error(f"Markers file not found: {markers_path}") + sys.exit(1) + + markers = annotation.load_marker_genes(str(markers_path)) + net_df = annotation.markers_dict_to_dataframe(markers) + + all_marker_genes = set(net_df["target"]) + missing = all_marker_genes - set(adata.var_names) + if missing: + logging.info(f"Note: {len(missing)} marker genes not found in dataset") + + adata = annotation.run_enrichment_scoring( + adata, + net_df=net_df, + score_key=args.score_key, + method=args.method, + tmin=args.tmin, + mask=mask, + ) + + # ------------------------------------------------------------------ # + # Preset resource scoring + # ------------------------------------------------------------------ # + if args.preset_resources: + preset_names = [n.strip() for n in args.preset_resources.split(",") if n.strip()] + for name in preset_names: + try: + net_df = annotation.load_preset_resource( + name, + panglao_min_sensitivity=args.panglao_min_sensitivity, + panglao_canonical_only=args.panglao_canonical_only, + ) + except ValueError as exc: + logging.error(str(exc)) + sys.exit(1) + + adata = annotation.run_enrichment_scoring( + adata, + net_df=net_df, + score_key=name, + method=args.method, + tmin=args.tmin, + mask=mask, + ) + + # ------------------------------------------------------------------ # + # Save + # ------------------------------------------------------------------ # + prepare_spatial_data_for_save(adata) + set_table(sdata, adata) + + if not args.inplace: + output_path.parent.mkdir(parents=True, exist_ok=True) + + save_spatial_data(sdata, output_path, overwrite=args.inplace) + logging.info(f"Saved results to: {output_path}") + + # ------------------------------------------------------------------ # + # Optional plots + # ------------------------------------------------------------------ # + if args.save_plots: + from xenium_process.core import plotting + plots_dir = output_path.parent / "plots" + plots_dir.mkdir(exist_ok=True) + + scored_keys = [] + if args.markers: + scored_keys.append(f"score_{args.method}_{args.score_key}") + if args.preset_resources: + for name in [n.strip() for n in args.preset_resources.split(",") if n.strip()]: + scored_keys.append(f"score_{args.method}_{name}") + + for key in scored_keys: + if key in adata.obsm: + try: + plotting.create_enrichment_heatmap( + adata, plots_dir, key, resolution=None + ) + except Exception as exc: + logging.warning(f" Could not save enrichment heatmap for {key}: {exc}") + + logging.info("=" * 60) + logging.info("Quantitate complete.") + logging.info("=" * 60) + + except Exception as exc: + logging.error(f"Quantitate failed: {exc}", exc_info=True) + sys.exit(1) diff --git a/xenium_process/core/annotation.py b/xenium_process/core/annotation.py index 805be49..5b80c43 100644 --- a/xenium_process/core/annotation.py +++ b/xenium_process/core/annotation.py @@ -3,86 +3,433 @@ Cell type annotation and enrichment scoring functions. This module handles marker-based cell type annotation using decoupler, -differential expression analysis, and MLM score calculation for -multiple pathway/TF resources. +differential expression analysis, and MLM/ULM score calculation for +custom gene lists and multiple built-in pathway/TF resources. """ import logging -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import anndata as ad import decoupler as dc +import numpy as np import pandas as pd import scanpy as sc +# --------------------------------------------------------------------------- +# Strategy registry for cluster assignment +# --------------------------------------------------------------------------- + +def _strategy_top_positive( + acts: ad.AnnData, + adata: ad.AnnData, + cluster_key: str, +) -> Dict[str, str]: + """ + Default assignment strategy: for each cluster pick the cell type with + the highest positive enrichment statistic (rankby_group). Clusters with + no positive stat are labelled "Unknown". + + Args: + acts: AnnData of enrichment scores (cells × cell_types) + adata: Full AnnData object (used to read cluster labels) + cluster_key: obs column with cluster assignments + + Returns: + Dict mapping cluster label → assigned cell type name + """ + enr = dc.tl.rankby_group(acts, groupby=cluster_key) + + annotation_dict = ( + enr[enr["stat"] > 0] + .groupby("group", observed=True) + .head(1) + .set_index("group")["name"] + .to_dict() + ) + + # Clusters with no positive score get "Unknown" + for cluster in adata.obs[cluster_key].unique(): + if cluster not in annotation_dict: + annotation_dict[cluster] = "Unknown" + + return annotation_dict + + +def _strategy_threshold( + acts: ad.AnnData, + adata: ad.AnnData, + cluster_key: str, + threshold: float = 0.0, +) -> Dict[str, str]: + """ + Stub: assign the top cell type per cluster only when its mean score + exceeds *threshold*; otherwise assign "Unknown". + """ + raise NotImplementedError( + "The 'threshold' strategy is not yet implemented. " + "Use 'top_positive' for now." + ) + + +def _strategy_top_n_vote( + acts: ad.AnnData, + adata: ad.AnnData, + cluster_key: str, + n: int = 3, +) -> Dict[str, str]: + """ + Stub: consensus vote among top-N scoring cell types per cluster. + """ + raise NotImplementedError( + "The 'top_n_vote' strategy is not yet implemented. " + "Use 'top_positive' for now." + ) + + +STRATEGY_REGISTRY: Dict = { + "top_positive": _strategy_top_positive, + "threshold": _strategy_threshold, + "top_n_vote": _strategy_top_n_vote, +} + + +# --------------------------------------------------------------------------- +# Cell filtering +# --------------------------------------------------------------------------- + +def filter_cells_by_obs( + adata: ad.AnnData, + expr: str, +) -> Tuple[np.ndarray, ad.AnnData]: + """ + Parse a simple equality expression and return a boolean mask + subset. + + Args: + adata: AnnData object + expr: Expression of the form "column==value", e.g. "cell_type==Fibroblast" + + Returns: + (mask, adata_sub) where mask is a boolean array of length n_obs + and adata_sub is adata[mask]. + + Raises: + ValueError: If expr does not contain exactly one "==" separator. + KeyError: If the referenced column does not exist in adata.obs. + """ + if "==" not in expr: + raise ValueError( + f"filter_cells_by_obs: expression must contain '==', got: {expr!r}" + ) + + col, val = expr.split("==", 1) + col = col.strip() + val = val.strip() + + if col not in adata.obs.columns: + raise KeyError( + f"filter_cells_by_obs: column {col!r} not found in adata.obs. " + f"Available columns: {list(adata.obs.columns)}" + ) + + mask = (adata.obs[col] == val).values + + n_match = mask.sum() + if n_match == 0: + logging.warning( + f"filter_cells_by_obs: filter '{expr}' matched 0 cells. " + "Scores will be all NaN." + ) + else: + logging.info( + f"filter_cells_by_obs: filter '{expr}' matched {n_match} / {adata.n_obs} cells" + ) + + adata_sub = adata[mask].copy() + return mask, adata_sub + + +# --------------------------------------------------------------------------- +# Preset resource loading +# --------------------------------------------------------------------------- + +PRESET_RESOURCE_NAMES = ("panglao", "hallmark", "collectri", "dorothea", "progeny") + + +def load_preset_resource( + name: str, + panglao_min_sensitivity: float = 0.5, + panglao_canonical_only: bool = True, + organism: str = "human", +) -> pd.DataFrame: + """ + Load a named decoupler built-in gene-set resource as a DataFrame. + + Supported names: panglao, hallmark, collectri, dorothea, progeny. + + Args: + name: Resource name (case-insensitive match against PRESET_RESOURCE_NAMES) + panglao_min_sensitivity: Sensitivity threshold for PanglaoDB filtering + panglao_canonical_only: Only use canonical PanglaoDB markers + organism: Organism name (default: "human") + + Returns: + DataFrame with at least columns "source" and "target" in decoupler format. + + Raises: + ValueError: If name is not a recognised preset resource. + """ + name_lower = name.lower() + + if name_lower == "panglao": + return get_panglao_markers( + organism=organism, + min_sensitivity=panglao_min_sensitivity, + canonical_only=panglao_canonical_only, + ) + elif name_lower == "hallmark": + logging.info(f"Loading Hallmark gene sets (organism={organism})") + return dc.op.hallmark(organism=organism) + elif name_lower == "collectri": + logging.info(f"Loading CollectRI TF regulons (organism={organism})") + return dc.op.collectri(organism=organism) + elif name_lower == "dorothea": + logging.info(f"Loading DoRothEA TF regulons (organism={organism})") + return dc.op.dorothea(organism=organism) + elif name_lower == "progeny": + logging.info(f"Loading PROGENy pathway gene sets (organism={organism})") + return dc.op.progeny(organism=organism) + else: + raise ValueError( + f"Unknown preset resource: {name!r}. " + f"Supported names: {PRESET_RESOURCE_NAMES}" + ) + + +# --------------------------------------------------------------------------- +# Core enrichment scoring +# --------------------------------------------------------------------------- + +def run_enrichment_scoring( + adata: ad.AnnData, + net_df: pd.DataFrame, + score_key: str, + method: str = "mlm", + tmin: int = 2, + mask: Optional[np.ndarray] = None, +) -> ad.AnnData: + """ + Run MLM or ULM enrichment scoring and store the result in adata.obsm. + + When *mask* is provided, scoring is performed only on the masked subset + and the results are written back into the full adata (NaN for excluded + cells). + + Args: + adata: Full AnnData object (normalised expression in .X) + net_df: Gene-set network in decoupler format (source, target[, weight]) + score_key: Key suffix; result stored at obsm[f"score_{method}_{score_key}"] + method: "mlm" or "ulm" + tmin: Minimum number of targets per source + mask: Optional boolean array of length n_obs; if provided, scoring + runs on adata[mask] and is merged back with NaN fill. + + Returns: + adata with obsm key added/updated in-place (also returned for chaining). + + Raises: + ValueError: If method is not "mlm" or "ulm". + """ + if method not in ("mlm", "ulm"): + raise ValueError(f"method must be 'mlm' or 'ulm', got {method!r}") + + obsm_key = f"score_{method}_{score_key}" + raw_obsm_key = f"score_{method}" # decoupler always writes to this key + + if mask is not None: + adata_run = adata[mask].copy() + else: + adata_run = adata + + logging.info( + f"Running {method.upper()} scoring on {adata_run.n_obs} cells " + f"with {len(net_df)} network entries (score_key={score_key!r})" + ) + + try: + if method == "mlm": + dc.mt.mlm(data=adata_run, net=net_df, verbose=False, tmin=tmin) + else: + dc.mt.ulm(data=adata_run, net=net_df, verbose=False, tmin=tmin) + except Exception as exc: + logging.warning(f" Enrichment scoring failed: {exc}") + return adata + + scores = adata_run.obsm[raw_obsm_key] + + if mask is not None: + # Build a full-sized DataFrame initialised to NaN, fill in scored rows + if isinstance(scores, pd.DataFrame): + full = pd.DataFrame( + np.nan, + index=adata.obs_names, + columns=scores.columns, + dtype=float, + ) + full.loc[adata_run.obs_names] = scores.values + else: + full = np.full((adata.n_obs, scores.shape[1]), np.nan, dtype=float) + full[mask] = scores + adata.obsm[obsm_key] = full + else: + adata.obsm[obsm_key] = scores + + n_sources = scores.shape[1] if hasattr(scores, "shape") and len(scores.shape) > 1 else 1 + logging.info(f" Stored {n_sources} source scores at obsm['{obsm_key}']") + + return adata + + +# --------------------------------------------------------------------------- +# Cluster assignment +# --------------------------------------------------------------------------- + +def assign_clusters( + adata: ad.AnnData, + score_key: str, + cluster_key: str, + annotation_key: str, + strategy: str = "top_positive", +) -> ad.AnnData: + """ + Assign a cell type label to each cluster based on enrichment scores. + + Args: + adata: AnnData object with scores in obsm[score_key] + score_key: Full obsm key produced by run_enrichment_scoring + (e.g. "score_mlm_custom") + cluster_key: obs column containing cluster assignments + annotation_key: obs column name to write labels into + strategy: Assignment strategy name from STRATEGY_REGISTRY + + Returns: + adata with annotation_key column added to obs. + + Raises: + ValueError: If strategy is not in STRATEGY_REGISTRY or score_key missing. + """ + if strategy not in STRATEGY_REGISTRY: + raise ValueError( + f"Unknown strategy {strategy!r}. " + f"Available strategies: {list(STRATEGY_REGISTRY)}" + ) + + if score_key not in adata.obsm: + raise ValueError( + f"Score key {score_key!r} not found in adata.obsm. " + f"Available keys: {list(adata.obsm.keys())}" + ) + + logging.info( + f"Assigning clusters using strategy={strategy!r}, " + f"cluster_key={cluster_key!r}, score_key={score_key!r}" + ) + + acts = dc.pp.get_obsm(adata, score_key) + + strategy_fn = STRATEGY_REGISTRY[strategy] + annotation_dict = strategy_fn(acts, adata, cluster_key) + + adata.obs[annotation_key] = adata.obs[cluster_key].map(annotation_dict) + adata.obs[annotation_key] = adata.obs[annotation_key].astype("category") + + annotation_counts = adata.obs[annotation_key].value_counts() + logging.info("Cell type assignment summary:") + for cell_type, count in annotation_counts.items(): + logging.info(f" {cell_type}: {count} cells") + + return adata + + +# --------------------------------------------------------------------------- +# Legacy / convenience functions (kept for backward compatibility) +# --------------------------------------------------------------------------- + def load_marker_genes(marker_path: str) -> Dict[str, List[str]]: """ Load marker genes from CSV file. - + Args: marker_path: Path to CSV file with columns: cell_type, gene - + Returns: Dictionary mapping cell type to list of marker genes """ logging.info(f"Loading marker genes from {marker_path}") - + df = pd.read_csv(marker_path) - + if not all(col in df.columns for col in ["cell_type", "gene"]): raise ValueError("Marker CSV must have 'cell_type' and 'gene' columns") - - # Group by cell type + markers = df.groupby("cell_type")["gene"].apply(list).to_dict() - + total_markers = sum(len(genes) for genes in markers.values()) logging.info(f"Loaded {len(markers)} cell types with {total_markers} total marker genes") - + return markers +def markers_dict_to_dataframe(markers: Dict[str, List[str]]) -> pd.DataFrame: + """ + Convert a {cell_type: [gene, ...]} dict to decoupler network format. + + Returns: + DataFrame with columns: source, target, weight (all weights = 1) + """ + rows = [ + {"source": cell_type, "target": gene} + for cell_type, genes in markers.items() + for gene in genes + ] + df = pd.DataFrame(rows) + df["weight"] = 1 + return df + + def get_panglao_markers( organism: str = "human", min_sensitivity: float = 0.5, - canonical_only: bool = True + canonical_only: bool = True, ) -> pd.DataFrame: """ Get PanglaoDB markers with filtering. - - Best practices per: - https://decoupler.readthedocs.io/en/latest/notebooks/scell/rna_sc.html#panglaodb - + Args: organism: Organism name ('human' or 'mouse') min_sensitivity: Minimum sensitivity threshold (0-1) canonical_only: If True, only use canonical markers - + Returns: DataFrame with columns: source (cell_type), target (gene) """ - logging.info(f"Loading PanglaoDB markers (organism={organism}, min_sensitivity={min_sensitivity})") - + logging.info( + f"Loading PanglaoDB markers (organism={organism}, min_sensitivity={min_sensitivity})" + ) + markers = dc.op.resource("PanglaoDB", organism=organism) - - # Apply filters + filters = markers[organism].astype(bool) - if canonical_only: filters &= markers["canonical_marker"].astype(bool) - - filters &= (markers[f"{organism}_sensitivity"].astype(float) > min_sensitivity) - + filters &= markers[f"{organism}_sensitivity"].astype(float) > min_sensitivity + markers = markers[filters] - - # Remove duplicates markers = markers[~markers.duplicated(["cell_type", "genesymbol"])] - - # Rename columns to decoupler format markers = markers.rename(columns={"cell_type": "source", "genesymbol": "target"}) - + logging.info(f" Filtered to {len(markers)} PanglaoDB markers") - return markers[["source", "target"]] @@ -92,15 +439,15 @@ def annotate_with_markers( cluster_key: str = "leiden", annotation_key: str = "cell_type", resume: bool = False, - tmin: int = 2 + tmin: int = 2, ) -> ad.AnnData: """ Annotate clusters with cell types based on marker gene expression using decoupler's multivariate linear model (MLM) approach. - - This method uses enrichment analysis to test if marker gene collections - are enriched in cells, similar to the approach in the Scverse tutorial. - + + This is kept for backward compatibility; internally it delegates to + run_enrichment_scoring and assign_clusters. + Args: adata: AnnData object markers: Dictionary mapping cell type to list of marker genes @@ -108,77 +455,34 @@ def annotate_with_markers( annotation_key: Key name for storing cell type annotations resume: If True, skip if annotation already exists tmin: Minimum number of targets per source (default: 2) - + Returns: AnnData object with cell type annotations added """ if resume and annotation_key in adata.obs.columns: - logging.info(f"Cell type annotation already exists (resuming)") + logging.info("Cell type annotation already exists (resuming)") return adata - - logging.info(f"Annotating cell types using decoupler MLM (cluster_key={cluster_key})") - - # Check which marker genes are present - all_marker_genes = set() - for genes in markers.values(): - all_marker_genes.update(genes) - - missing_genes = all_marker_genes - set(adata.var_names) - if missing_genes: - logging.info(f"Note: {len(missing_genes)} marker genes not found in dataset") - - # Convert markers dictionary to DataFrame format expected by decoupler - # Format: columns 'source' (cell_type) and 'target' (gene) - marker_rows = [] - for cell_type, genes in markers.items(): - for gene in genes: - marker_rows.append({"source": cell_type, "target": gene}) - - marker_df = pd.DataFrame(marker_rows) - - # Add weight column (all weights = 1) - marker_df["weight"] = 1 - - logging.info(f"Running MLM with {len(marker_df)} marker gene entries across {len(markers)} cell types") - - # Run multivariate linear model - # This calculates enrichment scores for each cell type in each cell - dc.mt.mlm(adata, net=marker_df, verbose=False, tmin=tmin) - - # Extract the MLM scores from adata.obsm - # This creates a new AnnData-like object with cells x cell_types - acts = dc.pp.get_obsm(adata, "score_mlm") - - # For each cluster, find the cell type with highest enrichment score - # Use decoupler's rankby_group to get top scoring cell type per cluster - enr = dc.tl.rankby_group(acts, groupby=cluster_key) - - # Get the top cell type (highest stat) for each cluster - # Filter to positive stats only (stat > 0) - annotation_dict = ( - enr[enr["stat"] > 0] - .groupby("group", observed=True) - .head(1) - .set_index("group")["name"] - .to_dict() + + # Log coverage + all_marker_genes = {g for genes in markers.values() for g in genes} + missing = all_marker_genes - set(adata.var_names) + if missing: + logging.info(f"Note: {len(missing)} marker genes not found in dataset") + + net_df = markers_dict_to_dataframe(markers) + + score_key = annotation_key # re-use annotation_key as score_key suffix + adata = run_enrichment_scoring(adata, net_df, score_key=score_key, method="mlm", tmin=tmin) + + obsm_key = f"score_mlm_{score_key}" + adata = assign_clusters( + adata, + score_key=obsm_key, + cluster_key=cluster_key, + annotation_key=annotation_key, + strategy="top_positive", ) - - # Handle clusters that may not have positive enrichment scores - all_clusters = adata.obs[cluster_key].unique() - for cluster in all_clusters: - if cluster not in annotation_dict: - annotation_dict[cluster] = "Unknown" - - # Map cluster annotations to cells - adata.obs[annotation_key] = adata.obs[cluster_key].map(annotation_dict) - adata.obs[annotation_key] = adata.obs[annotation_key].astype("category") - - # Log annotation summary - annotation_counts = adata.obs[annotation_key].value_counts() - logging.info("Cell type annotation summary:") - for cell_type, count in annotation_counts.items(): - logging.info(f" {cell_type}: {count} cells") - + return adata @@ -187,72 +491,51 @@ def calculate_mlm_scores( use_panglao: bool = True, panglao_min_sensitivity: float = 0.5, tmin: int = 5, - resume: bool = False + resume: bool = False, ) -> ad.AnnData: """ Pre-calculate MLM scores for multiple decoupler resources. - + Resources include: - hallmark: Hallmark gene sets - collectri: Transcription factor regulons - dorothea: TF activity inference - progeny: Pathway activity - PanglaoDB: Cell type markers (optional, filtered) - + Scores are stored in adata.obsm[f'score_mlm_{resource}'] - + Args: adata: AnnData object with normalized data use_panglao: If True, include PanglaoDB markers panglao_min_sensitivity: Minimum sensitivity for PanglaoDB markers tmin: Minimum number of targets per source (default: 5) resume: If True, skip resources that already have scores - + Returns: AnnData object with MLM scores added to obsm """ logging.info("Calculating MLM scores for pathway/TF resources") - - # Define resources using decoupler omnipath API (dc.op.*) - resources = [ - ('hallmark', dc.op.hallmark(organism='human')), - ('collectri', dc.op.collectri(organism='human')), - ('dorothea', dc.op.dorothea(organism='human')), - ('progeny', dc.op.progeny(organism='human')) - ] - + + resource_names = ["hallmark", "collectri", "dorothea", "progeny"] if use_panglao: - panglao = get_panglao_markers( - organism="human", - min_sensitivity=panglao_min_sensitivity, - canonical_only=True - ) - resources.append(('PanglaoDB', panglao)) - - # Calculate MLM for each resource - for name, resource in resources: - obsm_key = f'score_mlm_{name}' - + resource_names.append("panglao") + + for name in resource_names: + obsm_key = f"score_mlm_{name}" if resume and obsm_key in adata.obsm: logging.info(f" MLM scores for {name} already calculated (resuming)") continue - - logging.info(f" Calculating MLM for {name}") - + try: - # Run MLM - dc.mt.mlm(data=adata, net=resource, tmin=tmin) - - # Store in named obsm key - adata.obsm[obsm_key] = adata.obsm['score_mlm'].copy() - - # Get shape info - n_sources = adata.obsm[obsm_key].shape[1] if len(adata.obsm[obsm_key].shape) > 1 else 1 - logging.info(f" Calculated scores for {n_sources} sources") - - except Exception as e: - logging.warning(f" Failed to calculate MLM for {name}: {e}") - + net_df = load_preset_resource( + name, + panglao_min_sensitivity=panglao_min_sensitivity, + ) + adata = run_enrichment_scoring(adata, net_df, score_key=name, method="mlm", tmin=tmin) + except Exception as exc: + logging.warning(f" Failed to calculate MLM for {name}: {exc}") + logging.info("MLM score calculation complete") return adata @@ -261,44 +544,42 @@ def run_differential_expression( adata: ad.AnnData, cluster_key: str, method: str = "wilcoxon", - resume: bool = False + resume: bool = False, ) -> ad.AnnData: """ Run differential expression analysis to find marker genes for each cluster. - + Args: adata: AnnData object cluster_key: Key in adata.obs containing cluster assignments method: Statistical test to use (default: wilcoxon) resume: If True, skip if differential expression already computed - + Returns: AnnData object with differential expression results added """ rank_key = f"rank_genes_{cluster_key}" - + if resume and "rank_genes_groups" in adata.uns and adata.uns.get("rank_genes_groups_key") == rank_key: logging.info(f"Differential expression already computed for {cluster_key} (resuming)") return adata - + logging.info(f"Running differential expression analysis for {cluster_key}") - - # Run rank_genes_groups + sc.tl.rank_genes_groups( adata, groupby=cluster_key, method=method, - use_raw=False, # Use normalized data in .X + use_raw=False, key_added=rank_key, - layer=None + layer=None, ) - - # Store which key was used + adata.uns["rank_genes_groups_key"] = rank_key - + n_clusters = adata.obs[cluster_key].nunique() logging.info(f" Differential expression completed for {n_clusters} clusters") - + return adata @@ -306,11 +587,11 @@ def save_differential_expression_results( adata: ad.AnnData, cluster_key: str, output_dir, - n_genes: int = 100 + n_genes: int = 100, ) -> None: """ Save differential expression results to CSV files. - + Args: adata: AnnData object with differential expression results cluster_key: Key in adata.obs containing cluster assignments @@ -318,30 +599,25 @@ def save_differential_expression_results( n_genes: Number of top genes to save per cluster """ rank_key = f"rank_genes_{cluster_key}" - + if rank_key not in adata.uns: logging.warning(f" No differential expression results found for {cluster_key}") return - + logging.info(f" Saving differential expression results for {cluster_key}") - - # Get the differential expression results as a DataFrame + result = sc.get.rank_genes_groups_df(adata, group=None, key=rank_key) - - # Save all results + de_dir = output_dir / "differential_expression" de_dir.mkdir(exist_ok=True) - + res_str = cluster_key.replace("leiden_res", "") - - # Save complete results + all_results_path = de_dir / f"deg_all_clusters_res{res_str}.csv" result.to_csv(all_results_path, index=False) logging.info(f" Saved all DE genes to {all_results_path}") - - # Save top N genes per cluster + top_results_path = de_dir / f"deg_top{n_genes}_per_cluster_res{res_str}.csv" - top_result = result.groupby('group').head(n_genes) + top_result = result.groupby("group").head(n_genes) top_result.to_csv(top_results_path, index=False) logging.info(f" Saved top {n_genes} DE genes per cluster to {top_results_path}") -