Authors: Preshanth Jagannathan (pjaganna@nrao.edu), Srikrishna Sekhar (ssekhar@nrao.edu), Derod Deal (dealderod@gmail.com)
SAM-RFI is a Python package that applies Meta's Segment Anything Model 2 (SAM2) to Radio Frequency Interference (RFI) detection and flagging in radio astronomy data. The system processes CASA measurement sets and generates precise segmentation masks for contaminated visibilities. To do this, it builds on the rfi_toolbox package, which provides general-purpose RFI simulation and measurement set handling.
SAM-RFI leverages the state-of-the-art SAM2 vision transformer for RFI segmentation in radio astronomy visibility data. The package provides a complete pipeline from data generation to trained models capable of detecting and flagging RFI with superior accuracy compared to traditional statistical methods.
Key Features:
- SAM2-based segmentation using Hiera transformer architecture
- Physically realistic synthetic data generation with exact ground truth
- Complete training pipeline with validation tracking
- Iterative flagging for progressive RFI cleaning
- GPU-accelerated training and inference
- Command-line interface for all operations
- Modular Python API for custom workflows
Requirements:
- Python 3.10, 3.11, or 3.12
- CUDA-capable GPU (recommended for training)
- CASA tools (optional, for measurement set operations)
# Clone repository
git clone https://github.com/preshanth/SAM-RFI.git
cd SAM-RFI
# Create conda environment
conda create -n samrfi python=3.12 -y
conda activate samrfi
# Install core dependencies (CPU-only, no GPU/CASA)
pip install "pandas>=2.2.0" "numpy>=1.26.0" --only-binary :all:
pip install -e .
SAM-RFI supports modular installation based on your needs:
# GPU support (training and inference)
pip install -e ".[gpu]"
# CASA tools (measurement set operations)
pip install -e ".[casa]"
# GPU + CASA (complete functionality)
pip install -e ".[gpu,casa]"
# Development (all dependencies + testing tools)
pip install -e ".[dev]"
# Install pre-commit hooks (for development)
pre-commit install
Installation extras:
- Core (default): Data preprocessing, synthetic data generation, evaluation metrics
- [gpu]: PyTorch, transformers, SAM2 models (required for training/inference)
- [casa]: CASA tools for measurement set I/O
- [viz]: Interactive visualization tools (HoloViews, Bokeh, Datashader)
- [dev]: All dependencies plus testing and linting tools
- [ci]: Minimal dependencies for continuous integration
# Check CLI availability
samrfi --help
# Test core imports (no GPU/CASA required)
python -c "from samrfi.data import Preprocessor; from samrfi.data_generation import SyntheticDataGenerator; print('Core installation successful')"
# Test GPU functionality (requires [gpu])
python -c "from samrfi.training import SAM2Trainer; from samrfi.inference import RFIPredictor; print('GPU installation successful')"
# Test CASA functionality (requires [casa])
python -c "from samrfi.data.ms_loader import MSLoader; print('CASA installation successful')"
Generate physically realistic training data with exact ground truth masks:
samrfi generate-data \
--source synthetic \
--config configs/synthetic_train_4k.yaml \
--output ./datasets/train_4k
Configuration (configs/synthetic_train_4k.yaml):
synthetic:
  num_samples: 4000
  num_channels: 1024
  num_times: 1024
  num_baselines: 2
  num_pols: 4
  # Physical scales (milli-Jansky and Jansky)
  noise_mjy: 1.0 # 1 mJy Gaussian noise
  rfi_power_min: 1000.0 # 1000 Jy RFI minimum
  rfi_power_max: 10000.0 # 10000 Jy RFI maximum
  # RFI types per sample
  rfi_type_counts:
    narrowband_persistent: 2
    broadband_persistent: 1
    frequency_sweep: 1
    narrowband_bursty: 2
    broadband_bursty: 1
  # Bandpass effects
  enable_bandpass_rolloff: true
  bandpass_polynomial_order: 8
  polarization_correlation: 0.8
processing:
  patch_size: 1024
  stretch: null # No stretch for synthetic (preserves physical scales)
  enable_augmentation: true # 4-way rotation augmentation
  normalize_before_stretch: false
  normalize_after_stretch: false
This generates batched datasets saved to ./datasets/train_4k/exact_masks/ with perfect ground truth masks.
SAM2 models automatically download from HuggingFace on first use. Models are cached at ~/.cache/huggingface/hub/.
samrfi train \
--config configs/gpu_v100_training.yaml \
--dataset ./datasets/train_4k/exact_masks \
--validation-dataset ./datasets/val_1k/exact_masks
Training configuration (configs/gpu_v100_training.yaml):
model:
  sam_checkpoint: large # Options: tiny, small, base_plus, large
  device: cuda
training:
  num_epochs: 10
  batch_size: 12
  learning_rate: 1.0e-5
  weight_decay: 0.0
  save_best_only: true
output:
  output_dir: ./samrfi_data
  save_plots: true
Available SAM2 models:
- tiny (40 MB) - Fastest, lower accuracy
- small (180 MB) - Balanced performance
- base_plus (330 MB) - Good accuracy
- large (850 MB) - Best accuracy, recommended for production
GPU memory requirements:
- 11 GB VRAM: tiny, batch_size=2
- 32 GB VRAM: base_plus, batch_size=12
- 40+ GB VRAM: large, batch_size=8-12
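When scripting training runs on different GPUs, the memory guidance above can be encoded as a lookup. The sketch below is a hypothetical helper (`suggest_training_config` and `VRAM_TIERS` are not part of SAM-RFI); it only restates the three tiers listed, uses the conservative end of the 8-12 batch range for `large`, and omits `small`, for which no memory figure is given.

```python
from typing import Optional, Tuple

# (minimum VRAM in GB, sam_checkpoint, batch_size), largest tier first.
# Values restate the README guidance; they are starting points, not measured limits.
VRAM_TIERS = [
    (40, "large", 8),  # README gives batch_size=8-12; the conservative end is used here
    (32, "base_plus", 12),
    (11, "tiny", 2),
]


def suggest_training_config(vram_gb: float) -> Optional[Tuple[str, int]]:
    """Return (sam_checkpoint, batch_size) for the largest tier that fits, else None."""
    for min_vram, checkpoint, batch_size in VRAM_TIERS:
        if vram_gb >= min_vram:
            return checkpoint, batch_size
    return None  # under 11 GB the README gives no guidance


print(suggest_training_config(16))  # ('tiny', 2)
print(suggest_training_config(80))  # ('large', 8)
```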
Single-pass prediction:
samrfi predict \
--model ./samrfi_data/sam2_rfi_best.pth \
--input observation.ms \
--patch-size 1024
Iterative prediction (recommended for deep cleaning):
samrfi predict \
--model ./samrfi_data/sam2_rfi_best.pth \
--input observation.ms \
--iterations 3 \
--patch-size 1024
Iterative flagging progressively finds fainter RFI by masking already-flagged regions in each iteration. Typically converges in 2-3 passes.
Prediction options:
- --iterations N - Number of flagging passes (default: 1)
- --num-antennas N - Limit number of antennas loaded
- --patch-size SIZE - Must match training patch size
- --stretch {SQRT,LOG10,null} - Must match training configuration
- --threshold FLOAT - Probability threshold (default: adaptive/mean)
- --no-save - Preview only, do not write flags to MS
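The default `--threshold` behaviour, where the cutoff is the mean of the predicted probabilities, can be sketched in NumPy as follows. `apply_threshold` is a hypothetical name, and this illustrates the idea described in the options list rather than reproducing SAM-RFI's code; lowering an explicit threshold flags more pixels.

```python
import numpy as np


def apply_threshold(probabilities, threshold=None):
    """Binarize RFI probabilities; threshold=None uses the mean probability (adaptive)."""
    probabilities = np.asarray(probabilities, dtype=float)
    if threshold is None:
        threshold = probabilities.mean()
    return probabilities > threshold


probs = np.array([[0.05, 0.10], [0.60, 0.90]])
adaptive = apply_threshold(probs)          # mean is about 0.41: flags only 0.60 and 0.90
aggressive = apply_threshold(probs, 0.08)  # flags everything except 0.05
print(adaptive.sum(), aggressive.sum())    # 2 3
```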
# Generate synthetic training data
samrfi generate-data \
--source synthetic \
--config configs/synthetic_train_4k.yaml \
--output ./datasets/train_4k
# Generate data from measurement set
samrfi generate-data \
--source ms \
--config configs/ms_data.yaml \
--output ./datasets/vla_pband
# Train with validation dataset
samrfi train \
--config configs/gpu_v100_training.yaml \
--dataset ./datasets/train_4k/exact_masks \
--validation-dataset ./datasets/val_1k/exact_masks
# Resume training from checkpoint
samrfi train \
--config configs/gpu_v100_training.yaml \
--dataset ./datasets/train_4k/exact_masks \
--resume ./samrfi_data/sam2_rfi_best.pth
# Single-pass prediction
samrfi predict \
--model ./samrfi_data/sam2_rfi_best.pth \
--input observation.ms
# Iterative prediction (3 passes)
samrfi predict \
--model ./samrfi_data/sam2_rfi_best.pth \
--input observation.ms \
--iterations 3
# Create default configuration
samrfi create-config \
--type {training|data|validation} \
--output config.yaml
# Validate configuration
samrfi validate-config --config config.yaml
import numpy as np
from samrfi.data import Preprocessor, TorchDataset
from samrfi.data_generation import SyntheticDataGenerator
from samrfi.evaluation import compute_iou, compute_ffi
# Generate synthetic data
generator = SyntheticDataGenerator(config_path='configs/synthetic_train_4k.yaml')
synthetic_dataset = generator.generate(num_samples=1000, output_dir='./datasets/synthetic')
# Preprocess data: complex visibilities shaped (baselines, pols, channels, times)
data = np.random.randn(2, 4, 1024, 1024) + 1j * np.random.randn(2, 4, 1024, 1024)
preprocessor = Preprocessor(data)
dataset = preprocessor.create_dataset(patch_size=1024, stretch=None)
# Evaluate predictions
# predicted_mask and ground_truth_mask are boolean masks (not defined in this snippet)
iou = compute_iou(predicted_mask, ground_truth_mask)
ffi = compute_ffi(data, flags=predicted_mask)
from samrfi.data.ms_loader import MSLoader
# Load measurement set
loader = MSLoader('observation.ms')
loader.load(num_antennas=5, mode='DATA')
# Access data
data = loader.data # Complex visibilities: (baselines, pols, channels, times)
magnitude = loader.magnitude # Magnitude
flags = loader.load_flags() # Existing flags
# Save new flags
loader.save_flags(predicted_flags) # Boolean flags: (baselines, pols, channels, times)
from samrfi.training import SAM2Trainer
from samrfi.data import TorchDataset
# Load batched dataset
dataset = TorchDataset.from_directory('./datasets/train_4k/exact_masks')
# Create trainer
trainer = SAM2Trainer(dataset, device='cuda')
# Train model
trainer.train(
num_epochs=10,
batch_size=12,
sam_checkpoint='large',
learning_rate=1e-5,
output_dir='./samrfi_data',
save_best_only=True
)
from samrfi.inference import RFIPredictor
# Load predictor with trained model
predictor = RFIPredictor(
model_path='./samrfi_data/sam2_rfi_best.pth',
device='cuda'
)
# Single-pass prediction
flags = predictor.predict_ms(
ms_path='observation.ms',
patch_size=1024,
save_flags=True
)
# Iterative prediction (3 passes)
flags = predictor.predict_iterative(
ms_path='observation.ms',
num_iterations=3,
patch_size=1024,
save_flags=True
)
print(f"Flagged {flags.sum() / flags.size * 100:.2f}% of data")
from samrfi.inference import RFIPredictor
import numpy as np
# Load model
predictor = RFIPredictor(model_path='model.pth', device='cuda')
# Predict on arbitrary-sized array
data = np.load('baseline_data.npy') # Any shape, e.g., (2048, 511)
flags, probabilities = predictor.predict_array(
data,
threshold=None, # Adaptive (uses mean of probabilities)
return_probabilities=True
)
# Save probabilities for custom thresholding
np.save('probabilities.npy', probabilities)
# Apply custom threshold
custom_flags = probabilities > 0.1 # More aggressive flagging
Data pipeline:
[Measurement Set or Synthetic Generator]
↓
MSLoader.load()
├─ Complex visibilities (baselines, pols, channels, times)
└─ Combine spectral windows
↓
Preprocessor.create_dataset()
├─ 4-way rotation augmentation (optional)
├─ Patchify into patch_size × patch_size
├─ Extract 3-channel features:
│ • Channel 1: Spatial gradient (edge detection)
│ • Channel 2: Log amplitude (intensity, [-3, 4])
│ • Channel 3: Phase ([-π, π] → [0, 1])
├─ Apply optional stretch (SQRT/LOG10 for real data, None for synthetic)
└─ ImageNet normalization
↓
BatchedDataset (streaming)
├─ Batch files: batch_*.pt + metadata.json
├─ On-demand loading in DataLoader workers
└─ OS filesystem cache for efficiency
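The feature-extraction and patching steps in the data pipeline can be approximated in NumPy. This is a sketch under stated assumptions, not SAM-RFI's `Preprocessor`: the gradient channel is taken as the gradient magnitude of the clipped log10 amplitude, waterfall dimensions are assumed to be multiples of the patch size (no padding), and the optional stretch and ImageNet normalization are omitted. The function names are hypothetical.

```python
import numpy as np


def three_channel_features(vis):
    """(channels, times) complex waterfall -> (3, channels, times) float image."""
    log_amp = np.clip(np.log10(np.abs(vis) + 1e-12), -3.0, 4.0)  # channel 2: log amplitude
    d_chan, d_time = np.gradient(log_amp)
    gradient = np.hypot(d_chan, d_time)                          # channel 1: edge strength
    phase = (np.angle(vis) + np.pi) / (2 * np.pi)                # channel 3: [-pi, pi] -> [0, 1]
    return np.stack([gradient, log_amp, phase])


def rotations(image):
    """4-way rotation augmentation of a (C, H, W) image: 0, 90, 180, 270 degrees."""
    return [np.rot90(image, k, axes=(1, 2)) for k in range(4)]


def patchify(image, patch_size):
    """Split a (C, H, W) image into non-overlapping (C, patch_size, patch_size) tiles, row-major."""
    c, h, w = image.shape
    tiles = image.reshape(c, h // patch_size, patch_size, w // patch_size, patch_size)
    return tiles.transpose(1, 3, 0, 2, 4).reshape(-1, c, patch_size, patch_size)


rng = np.random.default_rng(0)
vis = rng.normal(size=(256, 256)) + 1j * rng.normal(size=(256, 256))
image = three_channel_features(vis)
patches = [tile for rotated in rotations(image) for tile in patchify(rotated, 128)]
print(image.shape, len(patches), patches[0].shape)  # (3, 256, 256) 16 (3, 128, 128)
```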
Training pipeline:
BatchedDataset
↓
SAMDataset wrapper
├─ Extract bounding boxes from ground truth
├─ Add random perturbation (±20 pixels)
└─ Format: {pixel_values, input_boxes, ground_truth_mask}
↓
SAM2Trainer.train()
├─ Load SAM2Model from HuggingFace
├─ Freeze vision + prompt encoders
├─ Train mask decoder only (~10% of parameters)
├─ Loss: DiceCELoss (Dice + Cross-Entropy)
└─ Save: sam2_rfi_best.pth
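The bounding-box prompts used in training can be derived from a ground-truth mask roughly as below. This follows the common recipe for fine-tuning SAM with box prompts; drawing one box around all flagged pixels, jittering each coordinate independently by up to ±20 pixels, and returning `None` for an empty mask are assumptions, and `box_from_mask` is a hypothetical name.

```python
import numpy as np


def box_from_mask(mask, max_shift=20, rng=None):
    """Return a jittered [x_min, y_min, x_max, y_max] box around all True pixels.

    x indexes columns and y indexes rows. Each coordinate is shifted by a random
    integer in [-max_shift, max_shift], clipped to the image, and re-ordered so
    that min <= max. Returns None for an empty mask.
    """
    rng = np.random.default_rng() if rng is None else rng
    ys, xs = np.nonzero(mask)
    if ys.size == 0:
        return None
    h, w = mask.shape
    box = np.array([xs.min(), ys.min(), xs.max(), ys.max()])
    box = box + rng.integers(-max_shift, max_shift + 1, size=4)
    box = np.clip(box, 0, [w - 1, h - 1, w - 1, h - 1])
    x0, x1 = sorted((int(box[0]), int(box[2])))
    y0, y1 = sorted((int(box[1]), int(box[3])))
    return [x0, y0, x1, y1]


mask = np.zeros((64, 64), dtype=bool)
mask[10:20, 30:40] = True  # rows 10-19, columns 30-39
print(box_from_mask(mask, max_shift=0))  # [30, 10, 39, 19]
print(box_from_mask(mask, rng=np.random.default_rng(0)))  # jittered version of the same box
```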
Inference pipeline:
[Trained Model] + [Measurement Set]
↓
RFIPredictor.predict_ms() or predict_iterative()
↓
MSLoader.load() → Preprocessor → Patches
↓
SAM2Model.forward()
├─ Vision encoder: Extract features
├─ Prompt encoder: Encode bounding boxes
└─ Mask decoder: Predict segmentation
↓
Reconstruction
├─ Sigmoid(logits) > threshold
├─ Reverse rotations
├─ Combine patches → full waterfall
└─ Boolean flags: (baselines, pols, channels, times)
↓
MSLoader.save_flags() → Write to MS FLAG column
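The reconstruction step reverses what preprocessing did. The sketch below assumes a rotation was applied to the whole waterfall before patching, so patches are first stitched back together and then rotated back; the helper names are hypothetical, and the random "logits" stand in for model output.

```python
import numpy as np


def patchify2d(image, p):
    """Row-major tiling of an (H, W) array into (n_tiles, p, p); H and W must be multiples of p."""
    h, w = image.shape
    return image.reshape(h // p, p, w // p, p).transpose(0, 2, 1, 3).reshape(-1, p, p)


def unpatchify2d(patches, n_rows, n_cols):
    """Inverse of patchify2d: (n_rows * n_cols, p, p) -> (n_rows * p, n_cols * p)."""
    p = patches.shape[-1]
    grid = patches.reshape(n_rows, n_cols, p, p)
    return grid.transpose(0, 2, 1, 3).reshape(n_rows * p, n_cols * p)


def logits_to_flags(logits, threshold):
    """Sigmoid(logits) > threshold, elementwise."""
    probabilities = 1.0 / (1.0 + np.exp(-logits))
    return probabilities > threshold


# Round trip on a toy 512 x 512 "logit" waterfall rotated by 90 degrees (k=1).
rng = np.random.default_rng(1)
logits = rng.normal(size=(512, 512))
k = 1
patches = patchify2d(np.rot90(logits, k), 256)  # what the model would see
# ... per-patch model outputs would replace `patches` here ...
restored = np.rot90(unpatchify2d(patches, 2, 2), -k)  # combine patches, then undo the rotation
flags = logits_to_flags(restored, threshold=0.5)
print(np.array_equal(restored, logits), flags.shape)  # True (512, 512)
```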
Iterative flagging progressively discovers fainter RFI: each pass masks the contamination already found, revealing interference that was hidden beneath brighter sources.
Iteration 1: Raw data → Model → Flags_1 (finds bright RFI)
Iteration 2: Data with Flags_1 masked → Model → Flags_2 (finds hidden RFI)
Iteration 3: Data with Flags_1|2 masked → Model → Flags_3 (final cleanup)
Final: Flags_cumulative = Flags_1 | Flags_2 | Flags_3
- Single pass (N=1): Fast, suitable for mild contamination (5-10% flagging)
- 2-3 iterations: Recommended for deep cleaning (15-30% flagging)
- >3 iterations: Diminishing returns, increased risk of over-flagging
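The iteration scheme above amounts to a loop that masks the cumulative flags before each pass and ORs in the new detections. In the sketch below, `predict_fn` is a hypothetical stand-in for one model pass, and masking by setting flagged samples to zero is an assumption about how "masked" data is presented to the model. The toy detector flags anything above half the current peak, so each pass uncovers the next-brightest signal.

```python
import numpy as np


def iterative_flagging(data, predict_fn, num_iterations=3):
    """Return cumulative flags and the percentage of samples newly flagged per pass."""
    cumulative = np.zeros(data.shape, dtype=bool)
    per_pass = []
    for _ in range(num_iterations):
        masked = np.where(cumulative, 0, data)        # hide RFI found so far
        new_flags = predict_fn(masked) & ~cumulative  # keep only new detections
        per_pass.append(float(new_flags.mean() * 100))
        cumulative |= new_flags                       # Flags_1 | Flags_2 | ...
    return cumulative, per_pass


def toy_predict(x):
    """Hypothetical detector: flag samples brighter than half the current peak."""
    amp = np.abs(x)
    return amp > 0.5 * amp.max()


data = np.full(10, 0.01)
data[[2, 5, 7]] = [100.0, 10.0, 1.0]  # bright, fainter, faintest "RFI"
flags, per_pass = iterative_flagging(data, toy_predict)
print(np.flatnonzero(flags), per_pass)  # [2 5 7] [10.0, 10.0, 10.0]
```

A single pass with this toy detector finds only the brightest sample, which mirrors the single-pass versus iterative comparison that follows.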
# Compare single vs iterative
samrfi predict --model model.pth --input obs.ms
# Output: Flagged 12.5% of data
samrfi predict --model model.pth --input obs.ms --iterations 3
# Output: Iteration 1: 12.5%, Iteration 2: 4.2%, Iteration 3: 1.1%
# Total: Flagged 17.8% of data
The synthetic data generator produces physically realistic RFI signatures:
- Narrowband Persistent - Continuous narrowband signals (GPS, satellites)
- Broadband Persistent - Continuous wideband interference (power lines, harmonics)
- Narrowband Bursty - Intermittent narrowband pulses (radar, transmitters)
- Broadband Bursty - Transient wideband events (lightning, arcing)
- Frequency Sweeps - Linear and quadratic chirps (scanning radar)
Physical parameters:
- Noise: 1 mJy (milli-Jansky) Gaussian, matches typical system noise
- RFI Power: 1000-10000 Jy (Jansky), 10^6-10^7 dynamic range
- Bandpass: 8th-order polynomial edge rolloff
- Polarization: Correlated RFI across the XX/YY correlation products (0.8 correlation)
Synthetic data provides exact ground truth masks, enabling supervised training with 100% accurate labels. This is not possible with real observations, where RFI locations are only estimates from statistical flaggers.
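Exact ground truth follows from the fact that the generator itself decides where RFI goes. The toy sketch below is not SAM-RFI's `SyntheticDataGenerator` (it has no bandpass, polarization, or sweeps, and `toy_waterfall` is a hypothetical name): it injects one narrowband persistent and one broadband bursty signal into complex Gaussian noise, assuming 1 mJy per real/imaginary component, and returns the mask of injected pixels alongside the data. With 1000 Jy RFI over 1 mJy noise, the ratio is 10^6, the low end of the range quoted above.

```python
import numpy as np


def toy_waterfall(n_chan=256, n_time=256, noise_jy=1e-3, rfi_jy=1e3, seed=0):
    """Return a complex (channels, times) waterfall and the exact boolean RFI mask."""
    rng = np.random.default_rng(seed)
    shape = (n_chan, n_time)
    # Assumption: noise_jy is the standard deviation of each real/imaginary component.
    vis = noise_jy * (rng.normal(size=shape) + 1j * rng.normal(size=shape))
    mask = np.zeros(shape, dtype=bool)

    chan = rng.integers(10, n_chan - 10)
    mask[chan:chan + 3, :] = True  # narrowband persistent: 3 channels, all times
    t0 = rng.integers(0, n_time - 5)
    mask[:, t0:t0 + 5] = True      # broadband bursty: all channels, 5 time samples

    # Inject RFI with random phase only where the mask is set, so the label is exact.
    phases = rng.uniform(-np.pi, np.pi, size=int(mask.sum()))
    vis[mask] += rfi_jy * np.exp(1j * phases)
    return vis, mask


vis, mask = toy_waterfall()
print(f"flagged fraction: {mask.mean():.3f}")   # 0.031
print(f"RFI-to-noise ratio: {1e3 / 1e-3:.0e}")  # 1e+06
```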
Preservation of Physical Scales:
processing:
  normalize_before_stretch: false # Critical for synthetic data
  normalize_after_stretch: false
  stretch: null # Preserves 10^6-10^7 dynamic range
SAM2 models are automatically downloaded from HuggingFace Hub on first use and cached locally:
from samrfi.training import SAM2Trainer
# Model downloads automatically if not cached (`dataset` is a TorchDataset, as in the training example)
trainer = SAM2Trainer(dataset, device='cuda')
trainer.train(num_epochs=10, sam_checkpoint='large')
Cache location: ~/.cache/huggingface/hub/
To use a custom cache location:
export HF_HOME=/path/to/custom/cache
samrfi train --config config.yaml --dataset ./datasets/train
To pre-download a model:
from samrfi.utils.model_cache import ModelCache
cache = ModelCache()
cache.download_model('large', show_progress=True)
Or via command line:
python -c "from samrfi.utils.model_cache import ModelCache; ModelCache().download_model('large')"
SAM-RFI supports seamless integration with HuggingFace Hub for sharing and downloading trained models and datasets.
Download and use a published model:
# Automatically downloads model from HuggingFace Hub
samrfi predict --model polarimetic/sam-rfi/large --input observation.ms
Publish your trained model:
# Upload model to HuggingFace Hub
samrfi publish --type model \
--input ./samrfi_data/sam2_rfi_best.pth \
--repo-id polarimetic/sam-rfi
Publish a dataset:
# Upload training dataset
samrfi publish --type dataset \
--input ./datasets/train_4k/exact_masks \
--repo-id polarimetic/sam-rfi-dataset
Features:
- Automatic Model Downloads: Models are downloaded and cached on first use
- Smart Path Detection: CLI accepts both local paths and HuggingFace repo IDs
- Private Repositories: Support for private models with token authentication
- Model Cards: Auto-generated documentation with training metrics
- Latest Versioning: Simple "latest" approach per model size
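The "Smart Path Detection" feature can be pictured as follows: an existing local path is used as-is, otherwise the string is treated as a Hub reference such as `polarimetic/sam-rfi/large`, read as a repo ID plus a model size. This is a hypothetical sketch (`resolve_model_reference` is not a SAM-RFI function), and the actual parsing rules may differ.

```python
from pathlib import Path


def resolve_model_reference(ref: str) -> dict:
    """Classify a --model value as a local path or a Hub reference (hypothetical rules)."""
    if Path(ref).exists():
        return {"source": "local", "path": str(Path(ref))}
    parts = ref.strip("/").split("/")
    if len(parts) == 3:  # owner/repo/size, e.g. polarimetic/sam-rfi/large
        owner, repo, size = parts
        return {"source": "hub", "repo_id": f"{owner}/{repo}", "size": size}
    if len(parts) == 2:  # plain repo ID; size left to a default
        return {"source": "hub", "repo_id": "/".join(parts), "size": None}
    raise ValueError(f"{ref!r} is neither an existing file nor a Hub reference")


print(resolve_model_reference("polarimetic/sam-rfi/large"))
```

Checking the filesystem first means a local directory that happens to look like `owner/repo` still wins, which is usually the less surprising behaviour for a CLI.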
from samrfi.inference import RFIPredictor
# Initialize with HuggingFace model (auto-downloads if needed)
predictor = RFIPredictor(
model_path="polarimetic/sam-rfi/large",
device="cuda"
)
# Use normally
flags = predictor.predict_ms("observation.ms")
Models are cached at ~/.cache/huggingface/hub/ after first download. Set custom location:
export HF_HOME=/path/to/custom/cache
For private repositories, set your HuggingFace token:
export HF_TOKEN=hf_xxxxx
samrfi publish --type model --input model.pth --repo-id user/private-repo --private
Get your token from: https://huggingface.co/settings/tokens
For detailed documentation including troubleshooting, batch publishing, and advanced usage, see HuggingFace Integration Guide.
Training can be resumed from any checkpoint to continue where you left off:
# Initial training
samrfi train --config config.yaml --dataset ./datasets/train --epochs 10
# Resume and extend to 20 epochs
samrfi train --config config.yaml --dataset ./datasets/train --epochs 20 \
--resume ./samrfi_data/sam2_rfi_best.pth
Restored state:
- Model weights
- Optimizer state (momentum, learning rates)
- Training/validation loss history
- Epoch counter
Checkpoints:
- sam2_rfi_best.pth - Best validation loss (updated during training)
- model_sam2-large_YYYYMMDD_HHMMSS.pth - Final checkpoint with full state
Fast iteration (debugging):
model:
  sam_checkpoint: tiny
training:
  num_epochs: 3
  batch_size: 8
  learning_rate: 1.0e-4
Production quality:
model:
  sam_checkpoint: large
training:
  num_epochs: 20
  batch_size: 4
  learning_rate: 1.0e-5
  weight_decay: 0.0
Expected behavior: Loss decreases from approximately 1.0 to below 0.3 within 10 epochs.
Troubleshooting stalled training (loss >0.8 after 5 epochs):
- Adjust learning rate (try 5e-6 or 2e-5)
- Verify data quality (visualize sample patches)
- Modify batch size (try 2 or 8)
- Check for NaN values in input data
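For the last troubleshooting item, a quick check for non-finite samples in a complex visibility array can be done with NumPy. `report_nonfinite` is a hypothetical helper; `np.isfinite` marks a complex value as non-finite if either its real or imaginary part is NaN or Inf.

```python
import numpy as np


def report_nonfinite(data):
    """Count NaN/Inf samples in an array (complex values count if either part is bad)."""
    bad = ~np.isfinite(data)
    return {"nonfinite": int(bad.sum()), "fraction": float(bad.mean())}


data = np.ones((2, 4, 8, 8), dtype=complex)  # (baselines, pols, channels, times)
data[0, 0, 0, 0] = np.nan
data[1, 3, 7, 7] = complex(1.0, np.inf)
print(report_nonfinite(data))  # {'nonfinite': 2, 'fraction': 0.00390625}
```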
from samrfi.evaluation import (
compute_iou, # Intersection over Union
compute_precision, # Positive predictive value
compute_recall, # Sensitivity
compute_f1, # Harmonic mean of precision/recall
compute_dice, # Dice coefficient
evaluate_segmentation # All metrics
)
# Evaluate predictions
metrics = evaluate_segmentation(predicted_mask, ground_truth_mask)
# Returns: {'iou': 0.85, 'precision': 0.90, 'recall': 0.82, 'f1': 0.86}
from samrfi.evaluation import (
compute_statistics, # Before/after statistics
compute_ffi, # Flagging Fidelity Index
print_statistics_comparison # Formatted output
)
# Compute Flagging Fidelity Index
ffi_metrics = compute_ffi(data, flags=predicted_mask)
# Returns: {'ffi': 0.65, 'mad_reduction': 0.45, 'std_reduction': 0.52}
# Print comparison
print_statistics_comparison(data, predicted_mask)
Flagging Fidelity Index (FFI): Measures flagging quality by balancing noise reduction against an over-flagging penalty. Higher values indicate better flagging performance.
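For reference, the standard definitions behind the segmentation metrics, plus one way to compute before/after noise statistics, are sketched below. These are illustrative NumPy implementations, not `samrfi.evaluation`. The README does not give the FFI formula, so the `std_reduction`/`mad_reduction` definitions here (fractional drop in the standard deviation and median absolute deviation of the amplitudes after removing flagged samples) are assumptions, and returning 0.0 when a denominator is zero is an arbitrary convention.

```python
import numpy as np


def segmentation_metrics(pred, truth):
    """IoU, precision, recall, F1 and Dice for boolean masks of equal shape."""
    pred = np.asarray(pred, dtype=bool)
    truth = np.asarray(truth, dtype=bool)
    tp = int(np.sum(pred & truth))   # flagged and truly RFI
    fp = int(np.sum(pred & ~truth))  # flagged but clean
    fn = int(np.sum(~pred & truth))  # missed RFI

    def ratio(num, den):
        return num / den if den else 0.0

    precision = ratio(tp, tp + fp)   # positive predictive value
    recall = ratio(tp, tp + fn)      # true positive rate / sensitivity
    return {
        "iou": ratio(tp, tp + fp + fn),
        "precision": precision,
        "recall": recall,
        "f1": ratio(2 * precision * recall, precision + recall),
        "dice": ratio(2 * tp, 2 * tp + fp + fn),  # equals F1 for binary masks
    }


def noise_reduction(data, flags):
    """Fractional drop in std and MAD of |data| after removing flagged samples."""
    amp = np.abs(np.asarray(data)).ravel()
    kept = amp[~np.asarray(flags, dtype=bool).ravel()]

    def mad(x):
        return np.median(np.abs(x - np.median(x)))

    return {
        "std_reduction": float(1.0 - kept.std() / amp.std()),
        "mad_reduction": float(1.0 - mad(kept) / mad(amp)),
    }


pred = np.array([1, 1, 1, 0, 0], dtype=bool)
truth = np.array([1, 1, 0, 1, 0], dtype=bool)
print(segmentation_metrics(pred, truth))  # iou 0.5; precision, recall, f1 and dice all 2/3
```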
# Run all tests
pytest tests/ -v
# Unit tests only
pytest tests/unit -v
# Integration tests
pytest tests/integration -v
# With coverage
pytest tests/ --cov=samrfi --cov-report=html
# Skip slow tests
pytest -m "not slow"
Pre-commit hooks are configured to run automatically on git commit:
# Install hooks
pre-commit install
# Run manually
pre-commit run --all-files
Checks performed:
- Black (code formatting, line length 100)
- Ruff (linting and auto-fixes)
- isort (import sorting)
- Trailing whitespace, EOF, YAML/JSON/TOML validation
- Large file detection (>5MB)
Manual formatting:
# Format code
black src/ tests/ --line-length 100
# Lint code
ruff check src/ tests/ --fix
# Sort imports
isort src/ tests/ --profile black --line-length 100
# Type check (optional)
mypy src/ --ignore-missing-imports
Project structure:
SAM-RFI/
├── src/samrfi/
│ ├── cli.py # Command-line interface
│ ├── config/ # Configuration management
│ │ ├── config_loader.py # YAML configuration loading
│ │ └── validators.py # Configuration validation
│ ├── data/ # Data loading and preprocessing
│ │ ├── ms_loader.py # CASA measurement set I/O
│ │ ├── preprocessor.py # Waterfall to patches pipeline
│ │ ├── sam_dataset.py # PyTorch dataset wrapper
│ │ ├── torch_dataset.py # Batched streaming datasets
│ │ └── gpu_transforms.py # Kornia-based GPU transforms
│ ├── data_generation/ # Dataset generators
│ │ ├── synthetic_generator.py # Physics-based RFI simulation
│ │ └── ms_generator.py # MS to dataset converter
│ ├── training/
│ │ └── sam2_trainer.py # SAM2 training loop
│ ├── inference/
│ │ └── predictor.py # RFI prediction (single/iterative)
│ ├── evaluation/ # Metrics and validation
│ │ ├── metrics.py # Segmentation metrics
│ │ └── statistics.py # Flagging quality statistics
│ └── utils/ # Utilities
│ ├── logger.py # Logging configuration
│ ├── model_cache.py # HuggingFace model downloads
│ └── errors.py # Custom exceptions
│
├── tests/ # Test suite
│ ├── unit/ # Unit tests
│ ├── integration/ # Integration tests
│ └── conftest.py # Shared fixtures
│
├── configs/ # Example configurations
│ ├── gpu_*.yaml # GPU-specific training configs
│ ├── synthetic_*.yaml # Synthetic data configs
│ └── validation.yaml # Validation config
│
├── .github/workflows/ # CI/CD
│ └── ci.yml # GitHub Actions workflow
│
├── pyproject.toml # Package definition
├── .pre-commit-config.yaml # Pre-commit hooks
└── README.md # This file
A paper describing SAM-RFI is in preparation. In the meantime, if you use this software in your research, please cite the repository:
@software{samrfi2025,
title = {SAM-RFI: Radio Frequency Interference Detection with SAM2},
author = {Deal, Derod and Jagannathan, Preshanth},
year = {2025},
url = {https://github.com/preshanth/SAM-RFI}
}
Please check back for the updated citation once the paper is published.
MIT License - see LICENSE for details.
- Meta AI - SAM2 architecture and pre-trained models
- HuggingFace - Transformers library and model hosting
- NRAO - Radio astronomy expertise and computational resources
- NAC - National Astronomy Consortium support and funding
- Issues: https://github.com/preshanth/SAM-RFI/issues
- Documentation: https://sam-rfi.readthedocs.io (coming soon)
- Contact: pjaganna@nrao.edu
