diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 65e8c50..53c97c4 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -33,6 +33,7 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         pip install pandas>=2.2.0 numpy>=1.26.0 --only-binary :all:
+        pip install --no-cache-dir --upgrade "rfi_toolbox @ git+https://github.com/preshanth/rfi_toolbox.git"
         pip install -e .[ci]

     - name: Run unit tests
diff --git a/README.md b/README.md
index 712a868..68b92c0 100644
--- a/README.md
+++ b/README.md
@@ -12,7 +12,7 @@

 **Authors:** Preshanth Jagannathan (pjaganna@nrao.edu), Srikrishna Sekhar (ssekhar@nrao.edu), Derod Deal (dealderod@gmail.com)

-SAM-RFI is a Python package that applies Meta's Segment Anything Model 2 (SAM2) for Radio Frequency Interference (RFI) detection and flagging in radio astronomy data. The system processes CASA measurement sets and generates precise segmentation masks for contaminated visibilities.
+SAM-RFI is a Python package that applies Meta's Segment Anything Model 2 (SAM2) for Radio Frequency Interference (RFI) detection and flagging in radio astronomy data. The system processes CASA measurement sets and generates precise segmentation masks for contaminated visibilities. To do this, it leverages the `rfi_toolbox` package, which provides general-purpose RFI simulation and measurement-set handling utilities.

 ## Overview
diff --git a/pyproject.toml b/pyproject.toml
index 13ecc52..6446a6f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,6 +25,8 @@ requires-python = ">=3.10"

 # Core dependencies (CPU-only, minimal)
 dependencies = [
+    # rfi_toolbox: Shared RFI utilities (io, preprocessing, evaluation, datasets)
+    "rfi_toolbox @ git+https://github.com/preshanth/rfi_toolbox.git",
     "numpy>=1.26.0",
     "scipy>=1.10.0",
     "pandas>=2.2.0",
@@ -33,7 +35,6 @@ dependencies = [
     "tqdm>=4.65.0",
     "matplotlib>=3.7.0",
     "datasets>=2.10.0",
-    "patchify>=0.2.3",
     "scikit-image>=0.20.0",
 ]
diff --git a/src/samrfi/__init__.py b/src/samrfi/__init__.py
index 1d40abd..77db0f3 100644
--- a/src/samrfi/__init__.py
+++ b/src/samrfi/__init__.py
@@ -35,12 +35,13 @@
 Usage:
 ------
->>> # Core data operations (no GPU/CASA required)
->>> from samrfi.data import Preprocessor, TorchDataset
->>> from samrfi.data_generation import SyntheticDataGenerator
+>>> # Core data operations
+>>> from rfi_toolbox.preprocessing import Preprocessor
+>>> from rfi_toolbox.datasets import TorchDataset
+>>> from rfi_toolbox.data_generation import SyntheticDataGenerator
 >>>
 >>> # Optional: CASA-dependent operations
->>> from samrfi.data.ms_loader import MSLoader  # Requires pip install samrfi[casa]
+>>> from rfi_toolbox.io import MSLoader  # Requires pip install samrfi[casa]
 >>>
 >>> # Optional: GPU/transformers-dependent operations
 >>> from samrfi.training import SAM2Trainer  # Requires pip install samrfi[gpu]
@@ -60,25 +61,19 @@
 __version__ = "2.0.0"
 __author__ = "Derod Deal, Preshanth Jagannathan"

-# Config module - always available
-# Data module
 # Config module
 from .config import ConfigLoader

-from .data import (
-    BatchedDataset,
-    BatchWriter,
-    HFDatasetWrapper,
-    Preprocessor,
-    SAMDataset,
-    TorchDataset,
-)
-# Data generation module
-from .data_generation import SyntheticDataGenerator
+# SAM2-specific data modules
+from .data import BatchedDataset, HFDatasetWrapper, SAMDataset

-# Note: MSLoader and MSDataGenerator require CASA and are not imported by default
-# Use: from samrfi.data.ms_loader import MSLoader
-# Use: from samrfi.data_generation.ms_generator import MSDataGenerator
+# Note: Shared
utilities from rfi_toolbox - import directly when needed: +# from rfi_toolbox.io import MSLoader +# from rfi_toolbox.preprocessing import Preprocessor +# from rfi_toolbox.datasets import BatchWriter, TorchDataset +# from rfi_toolbox.data_generation import SyntheticDataGenerator +# Note: MSDataGenerator requires CASA and is not imported by default +# Use: from samrfi.data_generation import MSDataGenerator # Note: ModelCache, RFIPredictor, and SAM2Trainer require transformers and are not imported by default # Use: from samrfi.utils.model_cache import ModelCache diff --git a/src/samrfi/data/__init__.py b/src/samrfi/data/__init__.py index a56245b..0542997 100644 --- a/src/samrfi/data/__init__.py +++ b/src/samrfi/data/__init__.py @@ -1,23 +1,25 @@ """ -Data module - MS loading, preprocessing, and dataset creation +Data module - SAM2-specific dataset wrappers and utilities + +NOTE: Core data utilities (MSLoader, Preprocessor, TorchDataset, BatchWriter) +are in rfi_toolbox. Import directly: + from rfi_toolbox.io import MSLoader + from rfi_toolbox.preprocessing import Preprocessor + from rfi_toolbox.datasets import BatchWriter, TorchDataset """ +# SAM2-specific modules from .adaptive_patcher import AdaptivePatcher, check_ms_compatibility from .gpu_dataset import GPUBatchTransformDataset, GPUTransformDataset from .gpu_transforms import GPUTransforms, create_gpu_transforms from .hf_dataset_wrapper import HFDatasetWrapper -from .preprocessor import GPUPreprocessor, Preprocessor from .ram_dataset import RAMCachedDataset from .sam_dataset import BatchedDataset, SAMDataset -from .torch_dataset import BatchWriter, TorchDataset __all__ = [ - "Preprocessor", - "GPUPreprocessor", + # SAM2-specific "SAMDataset", "BatchedDataset", - "TorchDataset", - "BatchWriter", "HFDatasetWrapper", "AdaptivePatcher", "check_ms_compatibility", @@ -27,6 +29,3 @@ "GPUBatchTransformDataset", "RAMCachedDataset", ] - -# Note: MSLoader requires CASA and is not imported by default -# Use: from samrfi.data.ms_loader import MSLoader diff --git a/src/samrfi/data/preprocessor.py b/src/samrfi/data/preprocessor.py index 5fe43ea..f06fa37 100644 --- a/src/samrfi/data/preprocessor.py +++ b/src/samrfi/data/preprocessor.py @@ -9,7 +9,6 @@ import numpy as np import torch -from patchify import patchify from scipy import stats from samrfi.utils import logger @@ -17,6 +16,29 @@ from .torch_dataset import TorchDataset +def patchify(array, patch_shape, step): + """ + Extract patches from 2D array using torch.unfold (replaces patchify library). 
+ + Args: + array: 2D numpy array (H, W) + patch_shape: Tuple (patch_h, patch_w) + step: Step size for patch extraction + + Returns: + 4D array (n_patches_h, n_patches_w, patch_h, patch_w) + """ + patch_h, patch_w = patch_shape + tensor = torch.from_numpy(array) + + # Use unfold to extract patches: (H, W) -> (n_h, n_w, patch_h, patch_w) + patches = tensor.unfold(0, patch_h, step).unfold(1, patch_w, step) + + # Rearrange to match patchify output format + patches = patches.contiguous().numpy() + return patches + + # Standalone functions for multiprocessing (must be picklable) def _patchify_single_waterfall(waterfall, patch_size): """ diff --git a/src/samrfi/data_generation/__init__.py b/src/samrfi/data_generation/__init__.py index 1dc7ed8..6e3a60a 100644 --- a/src/samrfi/data_generation/__init__.py +++ b/src/samrfi/data_generation/__init__.py @@ -1,8 +1,10 @@ -"""Data generation modules for SAM-RFI""" +"""Data generation modules for SAM-RFI -from .synthetic_generator import SyntheticDataGenerator +NOTE: SyntheticDataGenerator is in rfi_toolbox. Import directly: + from rfi_toolbox.data_generation import SyntheticDataGenerator +""" -__all__ = ["SyntheticDataGenerator"] +# SAM2-specific data generation +from .ms_generator import MSDataGenerator -# Note: MSDataGenerator requires CASA and is not imported by default -# Use: from samrfi.data_generation.ms_generator import MSDataGenerator +__all__ = ["MSDataGenerator"] diff --git a/src/samrfi/data_generation/ms_generator.py b/src/samrfi/data_generation/ms_generator.py index 1f0ab79..5ca40a2 100644 --- a/src/samrfi/data_generation/ms_generator.py +++ b/src/samrfi/data_generation/ms_generator.py @@ -5,8 +5,8 @@ import json from pathlib import Path -from samrfi.data import Preprocessor -from samrfi.data.ms_loader import MSLoader +from rfi_toolbox.io import MSLoader +from rfi_toolbox.preprocessing import Preprocessor class MSDataGenerator: diff --git a/src/samrfi/data_generation/synthetic_generator.py b/src/samrfi/data_generation/synthetic_generator.py index 45fc942..63eaaca 100644 --- a/src/samrfi/data_generation/synthetic_generator.py +++ b/src/samrfi/data_generation/synthetic_generator.py @@ -7,10 +7,9 @@ import numpy as np import torch +from rfi_toolbox.preprocessing import Preprocessor from tqdm import tqdm -from samrfi.data import Preprocessor - # Global generator instance for multiprocessing workers _global_generator = None _global_proc_config = None diff --git a/src/samrfi/evaluation/__init__.py b/src/samrfi/evaluation/__init__.py index f603061..f0fb7ff 100644 --- a/src/samrfi/evaluation/__init__.py +++ b/src/samrfi/evaluation/__init__.py @@ -1,39 +1,16 @@ """ Evaluation metrics and validation tools for RFI segmentation -""" - -from .metrics import ( - compute_dice, - compute_f1, - compute_iou, - compute_precision, - compute_recall, - evaluate_segmentation, -) -from .statistics import ( - compute_calcquality, - compute_ffi, - compute_statistics, - print_statistics_comparison, -) -__all__ = [ - "compute_iou", - "compute_precision", - "compute_recall", - "compute_f1", - "compute_dice", - "evaluate_segmentation", - "compute_statistics", - "compute_ffi", - "compute_calcquality", - "print_statistics_comparison", -] +NOTE: Metrics are in rfi_toolbox. 
Import directly: + from rfi_toolbox.evaluation import ( + compute_iou, compute_f1, compute_dice, compute_ffi, + compute_precision, compute_recall, evaluate_segmentation, + compute_statistics, print_statistics_comparison + ) + from rfi_toolbox.io import inject_synthetic_data +""" -# Optional CASA dependency for MS injection -try: - from .ms_injection import inject_synthetic_data +# SAM2-specific evaluation tools would go here +# Currently all metrics are in rfi_toolbox - __all__.append("inject_synthetic_data") -except ImportError: - pass # CASA not available +__all__ = [] diff --git a/src/samrfi/inference/predictor.py b/src/samrfi/inference/predictor.py index 884ba91..5663d95 100644 --- a/src/samrfi/inference/predictor.py +++ b/src/samrfi/inference/predictor.py @@ -8,11 +8,13 @@ import numpy as np import torch +from rfi_toolbox.io import MSLoader +from rfi_toolbox.preprocessing import Preprocessor from torch.utils.data import DataLoader from tqdm import tqdm from transformers import Sam2Model, Sam2Processor -from samrfi.data import AdaptivePatcher, Preprocessor, SAMDataset +from samrfi.data import AdaptivePatcher, SAMDataset from samrfi.utils import logger from samrfi.utils.errors import CheckpointMismatchError @@ -554,8 +556,6 @@ def predict_ms( Returns: Predicted flags array (baselines, pols, channels, times) """ - from samrfi.data.ms_loader import MSLoader - logger.info(f"\n{'='*60}") logger.info("RFI Prediction - Single Pass") logger.info(f"{'='*60}") @@ -677,8 +677,6 @@ def predict_iterative( Returns: Cumulative flags from all iterations """ - from samrfi.data.ms_loader import MSLoader - logger.info(f"\n{'='*60}") logger.info(f"RFI Prediction - Iterative ({num_iterations} passes)") logger.info(f"{'='*60}") diff --git a/tests/conftest.py b/tests/conftest.py index 158da8f..8571b00 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -233,7 +233,7 @@ def mock_torch_dataset(): if not TORCH_AVAILABLE: pytest.skip("torch not available") - from samrfi.data import TorchDataset + from rfi_toolbox.datasets import TorchDataset # Create mock images (10 samples, 256×256×3) images = torch.randn(10, 256, 256, 3, dtype=torch.float32) diff --git a/tests/extended/test_full_validation.py b/tests/extended/test_full_validation.py index a837e0d..e1bf62f 100644 --- a/tests/extended/test_full_validation.py +++ b/tests/extended/test_full_validation.py @@ -25,7 +25,8 @@ def test_tiny_model_inference_on_synthetic_data(self, tmp_path, synthetic_data_w This is a smoke test to ensure the full pipeline works end-to-end. Does not train the model (too slow), just tests inference with pretrained weights. 
""" - from samrfi.evaluation import evaluate_segmentation + from rfi_toolbox.evaluation import evaluate_segmentation + from samrfi.inference import RFIPredictor # Get synthetic data diff --git a/tests/integration/test_data_pipeline.py b/tests/integration/test_data_pipeline.py index 32df433..99bf9be 100644 --- a/tests/integration/test_data_pipeline.py +++ b/tests/integration/test_data_pipeline.py @@ -13,7 +13,7 @@ class TestPreprocessingPipeline: def test_complex_data_to_dataset(self, synthetic_data_with_rfi): """Test complete pipeline from complex data to TorchDataset.""" - from samrfi.data import Preprocessor + from rfi_toolbox.preprocessing import Preprocessor data = synthetic_data_with_rfi["data"] flags = synthetic_data_with_rfi["flags"] @@ -46,7 +46,7 @@ def test_complex_data_to_dataset(self, synthetic_data_with_rfi): def test_preprocessing_metadata_consistency(self, synthetic_waterfall_small): """Test that metadata remains consistent through preprocessing.""" - from samrfi.data import Preprocessor + from rfi_toolbox.preprocessing import Preprocessor data = synthetic_waterfall_small[np.newaxis, ...] @@ -84,7 +84,7 @@ class TestInferencePipelineIntegration: def test_metadata_flows_to_reconstruction(self, synthetic_data_with_rfi): """Test that metadata flows from preprocessing to reconstruction.""" - from samrfi.data import Preprocessor + from rfi_toolbox.preprocessing import Preprocessor data = synthetic_data_with_rfi["data"] @@ -127,7 +127,7 @@ class TestPipelineRobustness: def test_pipeline_handles_single_baseline(self, synthetic_waterfall_small): """Test pipeline works with single baseline.""" - from samrfi.data import Preprocessor + from rfi_toolbox.preprocessing import Preprocessor # Single baseline data = synthetic_waterfall_small[np.newaxis, ...] @@ -143,7 +143,7 @@ def test_pipeline_handles_single_baseline(self, synthetic_waterfall_small): def test_pipeline_handles_multiple_baselines(self, synthetic_waterfall_small): """Test pipeline works with multiple baselines.""" - from samrfi.data import Preprocessor + from rfi_toolbox.preprocessing import Preprocessor # 3 baselines data = np.stack([synthetic_waterfall_small] * 3) @@ -159,7 +159,7 @@ def test_pipeline_handles_multiple_baselines(self, synthetic_waterfall_small): def test_pipeline_preserves_data_integrity(self, synthetic_data_with_rfi): """Test that preprocessing doesn't corrupt data.""" - from samrfi.data import Preprocessor + from rfi_toolbox.preprocessing import Preprocessor data = synthetic_data_with_rfi["data"] original_data = data.copy() @@ -182,7 +182,7 @@ class TestAugmentationConsistency: def test_augmentation_matches_num_rotations(self, synthetic_waterfall_small): """Test that enabling augmentation creates expected number of patches.""" - from samrfi.data import Preprocessor + from rfi_toolbox.preprocessing import Preprocessor data = synthetic_waterfall_small[np.newaxis, ...] @@ -206,7 +206,7 @@ def test_augmentation_matches_num_rotations(self, synthetic_waterfall_small): def test_inference_mode_disables_blank_removal(self, synthetic_waterfall_small): """Test that inference mode preserves all patches (no blank removal).""" - from samrfi.data import Preprocessor + from rfi_toolbox.preprocessing import Preprocessor data = synthetic_waterfall_small[np.newaxis, ...] 
diff --git a/tests/unit/test_aaa_imports.py b/tests/unit/test_aaa_imports.py new file mode 100644 index 0000000..3a0333b --- /dev/null +++ b/tests/unit/test_aaa_imports.py @@ -0,0 +1,188 @@ +""" +Smoke tests for imports - runs first to catch import issues early. + +These tests verify that core imports work correctly without triggering +circular dependencies or missing optional dependencies. + +Naming: test_aaa_* ensures this runs first alphabetically in pytest. +""" + +import pytest + + +class TestCoreImports: + """Test that core samrfi imports work without rfi_toolbox.""" + + def test_samrfi_base_import(self): + """samrfi package should import without triggering rfi_toolbox initialization.""" + try: + import samrfi + + assert samrfi.__version__ is not None + except ImportError as e: + pytest.fail(f"Failed to import samrfi base package: {e}") + + def test_samrfi_config_import(self): + """samrfi.config should import without external dependencies.""" + try: + from samrfi.config import ConfigLoader + + assert ConfigLoader is not None + except ImportError as e: + pytest.fail(f"Failed to import samrfi.config: {e}") + + def test_samrfi_utils_import(self): + """samrfi.utils.errors should import without external dependencies.""" + try: + from samrfi.utils.errors import ( + CheckpointMismatchError, + ConfigValidationError, + DataShapeError, + ) + + assert CheckpointMismatchError is not None + assert ConfigValidationError is not None + assert DataShapeError is not None + except ImportError as e: + pytest.fail(f"Failed to import samrfi.utils.errors: {e}") + + def test_samrfi_data_import(self): + """samrfi.data (SAM2-specific modules) should import without rfi_toolbox.""" + try: + from samrfi.data import BatchedDataset, HFDatasetWrapper, SAMDataset + + assert BatchedDataset is not None + assert HFDatasetWrapper is not None + assert SAMDataset is not None + except ImportError as e: + pytest.fail(f"Failed to import samrfi.data: {e}") + + +class TestRFIToolboxImports: + """Test that rfi_toolbox imports work correctly.""" + + def test_rfi_toolbox_preprocessing_import(self): + """rfi_toolbox.preprocessing should import cleanly.""" + try: + from rfi_toolbox.preprocessing import GPUPreprocessor, Preprocessor + + assert Preprocessor is not None + assert GPUPreprocessor is not None + except ImportError as e: + pytest.fail( + f"Failed to import rfi_toolbox.preprocessing: {e}\n" + "This suggests a circular import or missing dependency in rfi_toolbox." + ) + + def test_rfi_toolbox_datasets_import(self): + """rfi_toolbox.datasets should import without sklearn.""" + try: + from rfi_toolbox.datasets import BatchWriter, RFIMaskDataset, TorchDataset + + assert BatchWriter is not None + assert TorchDataset is not None + assert RFIMaskDataset is not None + except ImportError as e: + pytest.fail( + f"Failed to import rfi_toolbox.datasets: {e}\n" + "This suggests sklearn is required (should be optional) or circular import." + ) + + def test_rfi_toolbox_data_generation_import(self): + """rfi_toolbox.data_generation should import cleanly.""" + try: + from rfi_toolbox.data_generation import RawPatchDataset, SyntheticDataGenerator + + assert SyntheticDataGenerator is not None + assert RawPatchDataset is not None + except ImportError as e: + pytest.fail( + f"Failed to import rfi_toolbox.data_generation: {e}\n" + "This suggests a circular import in rfi_toolbox." 
+            )
+
+    def test_rfi_toolbox_evaluation_import(self):
+        """rfi_toolbox.evaluation should import without torch."""
+        try:
+            from rfi_toolbox.evaluation import (  # noqa: F401
+                compute_dice,
+                compute_f1,
+                compute_ffi,
+                compute_iou,
+                compute_precision,
+                compute_recall,
+                evaluate_segmentation,
+            )
+
+            assert compute_iou is not None
+            assert compute_ffi is not None
+            assert evaluate_segmentation is not None
+        except ImportError as e:
+            pytest.fail(f"Failed to import rfi_toolbox.evaluation: {e}")
+
+
+class TestOptionalImports:
+    """Test optional imports that require specific dependencies."""
+
+    @pytest.mark.requires_casa
+    def test_rfi_toolbox_io_import(self):
+        """rfi_toolbox.io requires CASA - should fail gracefully if missing."""
+        try:
+            from rfi_toolbox.io import MSLoader
+
+            # If CASA is not available, MSLoader should be None
+            if MSLoader is None:
+                pytest.skip("CASA not available - MSLoader is None (expected)")
+        except ImportError as e:
+            # This is expected if CASA is not installed
+            assert "CASA" in str(e) or "casatools" in str(e)
+
+
+class TestImportIsolation:
+    """Test that imports don't have unintended side effects."""
+
+    def test_samrfi_import_does_not_trigger_rfi_toolbox(self):
+        """Importing samrfi should not initialize rfi_toolbox package."""
+        import sys
+
+        # Remove rfi_toolbox from sys.modules if present
+        rfi_toolbox_modules = [key for key in sys.modules if key.startswith("rfi_toolbox")]
+        for mod in rfi_toolbox_modules:
+            del sys.modules[mod]
+
+        # Import samrfi
+        import samrfi  # noqa: F401
+
+        # Check that rfi_toolbox.__init__ was not loaded
+        assert (
+            "rfi_toolbox" not in sys.modules
+        ), "Importing samrfi should not trigger rfi_toolbox initialization"
+
+    def test_import_order_independence(self):
+        """Importing in different orders should not cause failures."""
+        import sys
+
+        # Clear all samrfi and rfi_toolbox modules
+        for key in list(sys.modules.keys()):
+            if key.startswith(("samrfi", "rfi_toolbox")):
+                del sys.modules[key]
+
+        # Test order 1: samrfi first
+        try:
+            import samrfi  # noqa: F401
+
+            from rfi_toolbox.preprocessing import Preprocessor  # noqa: F401
+        except ImportError as e:
+            pytest.fail(f"Failed with samrfi first: {e}")
+
+        # Clear and test order 2: rfi_toolbox first
+        for key in list(sys.modules.keys()):
+            if key.startswith(("samrfi", "rfi_toolbox")):
+                del sys.modules[key]
+
+        try:
+            from rfi_toolbox.preprocessing import Preprocessor  # noqa: F401
+
+            import samrfi  # noqa: F401
+        except ImportError as e:
+            pytest.fail(f"Failed with rfi_toolbox first: {e}")
diff --git a/tests/unit/test_metrics.py b/tests/unit/test_metrics.py
index bb5488e..ee15962 100644
--- a/tests/unit/test_metrics.py
+++ b/tests/unit/test_metrics.py
@@ -6,8 +6,7 @@
 import numpy as np
 import torch
-
-from samrfi.evaluation import (
+from rfi_toolbox.evaluation import (
     compute_calcquality,
     compute_dice,
     compute_f1,
diff --git a/tests/unit/test_preprocessor.py b/tests/unit/test_preprocessor.py
index e5af34d..64bed56 100644
--- a/tests/unit/test_preprocessor.py
+++ b/tests/unit/test_preprocessor.py
@@ -6,8 +6,7 @@
 import numpy as np
 import torch
-
-from samrfi.data import Preprocessor
+from rfi_toolbox.preprocessing import Preprocessor


class TestPatchification:
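Taken together, the import surface after this refactor splits along the lines below. This is a summary sketch assembled only from the imports visible in this diff: shared utilities move to `rfi_toolbox`, while SAM2- and CASA-specific pieces stay in `samrfi`.

# Shared RFI utilities now come from rfi_toolbox
from rfi_toolbox.preprocessing import Preprocessor
from rfi_toolbox.datasets import BatchWriter, TorchDataset
from rfi_toolbox.data_generation import SyntheticDataGenerator
from rfi_toolbox.evaluation import evaluate_segmentation
from rfi_toolbox.io import MSLoader  # still requires CASA

# SAM2-specific pieces remain in samrfi
from samrfi.data import AdaptivePatcher, BatchedDataset, HFDatasetWrapper, SAMDataset
from samrfi.data_generation import MSDataGenerator  # requires CASA

# Pre-refactor, the same names were imported as, e.g.:
# from samrfi.data import Preprocessor, TorchDataset, BatchWriter
# from samrfi.data_generation import SyntheticDataGenerator
# from samrfi.evaluation import evaluate_segmentation
# from samrfi.data.ms_loader import MSLoader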