Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ jobs:
run: |
  python -m pip install --upgrade pip
  # Quote the version specifiers: an unquoted '>' is shell output redirection,
  # so `pandas>=2.2.0` would install an unconstrained pandas and write pip's
  # output to a file named "=2.2.0".
  pip install "pandas>=2.2.0" "numpy>=1.26.0" --only-binary :all:
  pip install --no-cache-dir --upgrade "rfi_toolbox @ git+https://github.com/preshanth/rfi_toolbox.git"
  pip install -e .[ci]

- name: Run unit tests
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

**Authors:** Preshanth Jagannathan (pjaganna@nrao.edu), Srikrishna Sekhar (ssekhar@nrao.edu), Derod Deal (dealderod@gmail.com)

SAM-RFI is a Python package that applies Meta's Segment Anything Model 2 (SAM2) for Radio Frequency Interference (RFI) detection and flagging in radio astronomy data. The system processes CASA measurement sets and generates precise segmentation masks for contaminated visibilities.
SAM-RFI is a Python package that applies Meta's Segment Anything Model 2 (SAM2) for Radio Frequency Interference (RFI) detection and flagging in radio astronomy data. The system processes CASA measurement sets and generates precise segmentation masks for contaminated visibilities. To do this, it leverages the `rfi_toolbox` package, which provides general-purpose RFI simulation and measurement set handling.

## Overview

Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ requires-python = ">=3.10"

# Core dependencies (CPU-only, minimal)
dependencies = [
# rfi_toolbox: Shared RFI utilities (io, preprocessing, evaluation, datasets)
"rfi_toolbox @ git+https://github.com/preshanth/rfi_toolbox.git",
"numpy>=1.26.0",
"scipy>=1.10.0",
"pandas>=2.2.0",
Expand All @@ -33,7 +35,6 @@ dependencies = [
"tqdm>=4.65.0",
"matplotlib>=3.7.0",
"datasets>=2.10.0",
"patchify>=0.2.3",
"scikit-image>=0.20.0",
]

Expand Down
33 changes: 14 additions & 19 deletions src/samrfi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,13 @@

Usage:
------
>>> # Core data operations (no GPU/CASA required)
>>> from samrfi.data import Preprocessor, TorchDataset
>>> from samrfi.data_generation import SyntheticDataGenerator
>>> # Core data operations
>>> from rfi_toolbox.preprocessing import Preprocessor
>>> from rfi_toolbox.datasets import TorchDataset
>>> from rfi_toolbox.data_generation import SyntheticDataGenerator
>>>
>>> # Optional: CASA-dependent operations
>>> from samrfi.data.ms_loader import MSLoader # Requires pip install samrfi[casa]
>>> from rfi_toolbox.io import MSLoader # Requires pip install samrfi[casa]
>>>
>>> # Optional: GPU/transformers-dependent operations
>>> from samrfi.training import SAM2Trainer # Requires pip install samrfi[gpu]
Expand All @@ -60,25 +61,19 @@
__version__ = "2.0.0"
__author__ = "Derod Deal, Preshanth Jagannathan"

# Config module - always available
# Data module
# Config module
from .config import ConfigLoader
from .data import (
BatchedDataset,
BatchWriter,
HFDatasetWrapper,
Preprocessor,
SAMDataset,
TorchDataset,
)

# Data generation module
from .data_generation import SyntheticDataGenerator
# SAM2-specific data modules
from .data import BatchedDataset, HFDatasetWrapper, SAMDataset

# Note: MSLoader and MSDataGenerator require CASA and are not imported by default
# Use: from samrfi.data.ms_loader import MSLoader
# Use: from samrfi.data_generation.ms_generator import MSDataGenerator
# Note: Shared utilities from rfi_toolbox - import directly when needed:
# from rfi_toolbox.io import MSLoader
# from rfi_toolbox.preprocessing import Preprocessor
# from rfi_toolbox.datasets import BatchWriter, TorchDataset
# from rfi_toolbox.data_generation import SyntheticDataGenerator
# Note: MSDataGenerator requires CASA and is not imported by default
# Use: from samrfi.data_generation import MSDataGenerator

# Note: ModelCache, RFIPredictor, and SAM2Trainer require transformers and are not imported by default
# Use: from samrfi.utils.model_cache import ModelCache
Expand Down
19 changes: 9 additions & 10 deletions src/samrfi/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
"""
Data module - MS loading, preprocessing, and dataset creation
Data module - SAM2-specific dataset wrappers and utilities

NOTE: Core data utilities (MSLoader, Preprocessor, TorchDataset, BatchWriter)
are in rfi_toolbox. Import directly:
from rfi_toolbox.io import MSLoader
from rfi_toolbox.preprocessing import Preprocessor
from rfi_toolbox.datasets import BatchWriter, TorchDataset
"""

# SAM2-specific modules
from .adaptive_patcher import AdaptivePatcher, check_ms_compatibility
from .gpu_dataset import GPUBatchTransformDataset, GPUTransformDataset
from .gpu_transforms import GPUTransforms, create_gpu_transforms
from .hf_dataset_wrapper import HFDatasetWrapper
from .preprocessor import GPUPreprocessor, Preprocessor
from .ram_dataset import RAMCachedDataset
from .sam_dataset import BatchedDataset, SAMDataset
from .torch_dataset import BatchWriter, TorchDataset

__all__ = [
"Preprocessor",
"GPUPreprocessor",
# SAM2-specific
"SAMDataset",
"BatchedDataset",
"TorchDataset",
"BatchWriter",
"HFDatasetWrapper",
"AdaptivePatcher",
"check_ms_compatibility",
Expand All @@ -27,6 +29,3 @@
"GPUBatchTransformDataset",
"RAMCachedDataset",
]

# Note: MSLoader requires CASA and is not imported by default
# Use: from samrfi.data.ms_loader import MSLoader
24 changes: 23 additions & 1 deletion src/samrfi/data/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,36 @@

import numpy as np
import torch
from patchify import patchify
from scipy import stats

from samrfi.utils import logger

from .torch_dataset import TorchDataset


def patchify(array, patch_shape, step):
    """
    Extract patches from 2D array using torch.unfold (replaces patchify library).

    Args:
        array: 2D numpy array (H, W). Non-contiguous views, including ones
            with negative strides (e.g. ``a[::-1]``, ``np.flip``, ``np.rot90``),
            are accepted.
        patch_shape: Tuple (patch_h, patch_w)
        step: Step size for patch extraction (same step on both axes)

    Returns:
        4D array (n_patches_h, n_patches_w, patch_h, patch_w), where the number
        of patches along an axis is ``(size - patch) // step + 1``.
    """
    patch_h, patch_w = patch_shape

    # torch.from_numpy raises ValueError for arrays with negative strides.
    # ascontiguousarray returns C-contiguous input unchanged and copies
    # otherwise, so contiguous callers pay nothing extra.
    tensor = torch.from_numpy(np.ascontiguousarray(array))

    # Use unfold to extract patches: (H, W) -> (n_h, n_w, patch_h, patch_w)
    patches = tensor.unfold(0, patch_h, step).unfold(1, patch_w, step)

    # unfold returns a strided view; materialize it before handing back numpy
    return patches.contiguous().numpy()


# Standalone functions for multiprocessing (must be picklable)
def _patchify_single_waterfall(waterfall, patch_size):
"""
Expand Down
12 changes: 7 additions & 5 deletions src/samrfi/data_generation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Data generation modules for SAM-RFI"""
"""Data generation modules for SAM-RFI

from .synthetic_generator import SyntheticDataGenerator
NOTE: SyntheticDataGenerator is in rfi_toolbox. Import directly:
from rfi_toolbox.data_generation import SyntheticDataGenerator
"""

__all__ = ["SyntheticDataGenerator"]
# SAM2-specific data generation
from .ms_generator import MSDataGenerator

# Note: MSDataGenerator requires CASA and is not imported by default
# Use: from samrfi.data_generation.ms_generator import MSDataGenerator
__all__ = ["MSDataGenerator"]
4 changes: 2 additions & 2 deletions src/samrfi/data_generation/ms_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import json
from pathlib import Path

from samrfi.data import Preprocessor
from samrfi.data.ms_loader import MSLoader
from rfi_toolbox.io import MSLoader
from rfi_toolbox.preprocessing import Preprocessor


class MSDataGenerator:
Expand Down
3 changes: 1 addition & 2 deletions src/samrfi/data_generation/synthetic_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@

import numpy as np
import torch
from rfi_toolbox.preprocessing import Preprocessor
from tqdm import tqdm

from samrfi.data import Preprocessor

# Global generator instance for multiprocessing workers
_global_generator = None
_global_proc_config = None
Expand Down
45 changes: 11 additions & 34 deletions src/samrfi/evaluation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,16 @@
"""
Evaluation metrics and validation tools for RFI segmentation
"""

from .metrics import (
compute_dice,
compute_f1,
compute_iou,
compute_precision,
compute_recall,
evaluate_segmentation,
)
from .statistics import (
compute_calcquality,
compute_ffi,
compute_statistics,
print_statistics_comparison,
)

__all__ = [
"compute_iou",
"compute_precision",
"compute_recall",
"compute_f1",
"compute_dice",
"evaluate_segmentation",
"compute_statistics",
"compute_ffi",
"compute_calcquality",
"print_statistics_comparison",
]
NOTE: Metrics are in rfi_toolbox. Import directly:
from rfi_toolbox.evaluation import (
compute_iou, compute_f1, compute_dice, compute_ffi,
compute_precision, compute_recall, evaluate_segmentation,
compute_statistics, print_statistics_comparison
)
from rfi_toolbox.io import inject_synthetic_data
"""

# Optional CASA dependency for MS injection
try:
from .ms_injection import inject_synthetic_data
# SAM2-specific evaluation tools would go here
# Currently all metrics are in rfi_toolbox

__all__.append("inject_synthetic_data")
except ImportError:
pass # CASA not available
__all__ = []
8 changes: 3 additions & 5 deletions src/samrfi/inference/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@

import numpy as np
import torch
from rfi_toolbox.io import MSLoader
from rfi_toolbox.preprocessing import Preprocessor
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import Sam2Model, Sam2Processor

from samrfi.data import AdaptivePatcher, Preprocessor, SAMDataset
from samrfi.data import AdaptivePatcher, SAMDataset
from samrfi.utils import logger
from samrfi.utils.errors import CheckpointMismatchError

Expand Down Expand Up @@ -554,8 +556,6 @@ def predict_ms(
Returns:
Predicted flags array (baselines, pols, channels, times)
"""
from samrfi.data.ms_loader import MSLoader

logger.info(f"\n{'='*60}")
logger.info("RFI Prediction - Single Pass")
logger.info(f"{'='*60}")
Expand Down Expand Up @@ -677,8 +677,6 @@ def predict_iterative(
Returns:
Cumulative flags from all iterations
"""
from samrfi.data.ms_loader import MSLoader

logger.info(f"\n{'='*60}")
logger.info(f"RFI Prediction - Iterative ({num_iterations} passes)")
logger.info(f"{'='*60}")
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def mock_torch_dataset():
if not TORCH_AVAILABLE:
pytest.skip("torch not available")

from samrfi.data import TorchDataset
from rfi_toolbox.datasets import TorchDataset

# Create mock images (10 samples, 256×256×3)
images = torch.randn(10, 256, 256, 3, dtype=torch.float32)
Expand Down
3 changes: 2 additions & 1 deletion tests/extended/test_full_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def test_tiny_model_inference_on_synthetic_data(self, tmp_path, synthetic_data_w
This is a smoke test to ensure the full pipeline works end-to-end.
Does not train the model (too slow), just tests inference with pretrained weights.
"""
from samrfi.evaluation import evaluate_segmentation
from rfi_toolbox.evaluation import evaluate_segmentation

from samrfi.inference import RFIPredictor

# Get synthetic data
Expand Down
16 changes: 8 additions & 8 deletions tests/integration/test_data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class TestPreprocessingPipeline:

def test_complex_data_to_dataset(self, synthetic_data_with_rfi):
"""Test complete pipeline from complex data to TorchDataset."""
from samrfi.data import Preprocessor
from rfi_toolbox.preprocessing import Preprocessor

data = synthetic_data_with_rfi["data"]
flags = synthetic_data_with_rfi["flags"]
Expand Down Expand Up @@ -46,7 +46,7 @@ def test_complex_data_to_dataset(self, synthetic_data_with_rfi):

def test_preprocessing_metadata_consistency(self, synthetic_waterfall_small):
"""Test that metadata remains consistent through preprocessing."""
from samrfi.data import Preprocessor
from rfi_toolbox.preprocessing import Preprocessor

data = synthetic_waterfall_small[np.newaxis, ...]

Expand Down Expand Up @@ -84,7 +84,7 @@ class TestInferencePipelineIntegration:

def test_metadata_flows_to_reconstruction(self, synthetic_data_with_rfi):
"""Test that metadata flows from preprocessing to reconstruction."""
from samrfi.data import Preprocessor
from rfi_toolbox.preprocessing import Preprocessor

data = synthetic_data_with_rfi["data"]

Expand Down Expand Up @@ -127,7 +127,7 @@ class TestPipelineRobustness:

def test_pipeline_handles_single_baseline(self, synthetic_waterfall_small):
"""Test pipeline works with single baseline."""
from samrfi.data import Preprocessor
from rfi_toolbox.preprocessing import Preprocessor

# Single baseline
data = synthetic_waterfall_small[np.newaxis, ...]
Expand All @@ -143,7 +143,7 @@ def test_pipeline_handles_single_baseline(self, synthetic_waterfall_small):

def test_pipeline_handles_multiple_baselines(self, synthetic_waterfall_small):
"""Test pipeline works with multiple baselines."""
from samrfi.data import Preprocessor
from rfi_toolbox.preprocessing import Preprocessor

# 3 baselines
data = np.stack([synthetic_waterfall_small] * 3)
Expand All @@ -159,7 +159,7 @@ def test_pipeline_handles_multiple_baselines(self, synthetic_waterfall_small):

def test_pipeline_preserves_data_integrity(self, synthetic_data_with_rfi):
"""Test that preprocessing doesn't corrupt data."""
from samrfi.data import Preprocessor
from rfi_toolbox.preprocessing import Preprocessor

data = synthetic_data_with_rfi["data"]
original_data = data.copy()
Expand All @@ -182,7 +182,7 @@ class TestAugmentationConsistency:

def test_augmentation_matches_num_rotations(self, synthetic_waterfall_small):
"""Test that enabling augmentation creates expected number of patches."""
from samrfi.data import Preprocessor
from rfi_toolbox.preprocessing import Preprocessor

data = synthetic_waterfall_small[np.newaxis, ...]

Expand All @@ -206,7 +206,7 @@ def test_augmentation_matches_num_rotations(self, synthetic_waterfall_small):

def test_inference_mode_disables_blank_removal(self, synthetic_waterfall_small):
"""Test that inference mode preserves all patches (no blank removal)."""
from samrfi.data import Preprocessor
from rfi_toolbox.preprocessing import Preprocessor

data = synthetic_waterfall_small[np.newaxis, ...]

Expand Down
Loading