Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 139 additions & 1 deletion gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ def __init__(
downsample_factor,
flatfield=None,
darkfield=None,
registration_z=None,
registration_t=0,
registration_channel=0,
):
super().__init__()
self.tiff_path = tiff_path
Expand All @@ -139,6 +142,9 @@ def __init__(
self.downsample_factor = downsample_factor
self.flatfield = flatfield
self.darkfield = darkfield
self.registration_z = registration_z
self.registration_t = registration_t
self.registration_channel = registration_channel

def run(self):
try:
Expand All @@ -153,6 +159,9 @@ def run(self):
downsample_factors=(self.downsample_factor, self.downsample_factor),
flatfield=self.flatfield,
darkfield=self.darkfield,
registration_z=self.registration_z,
registration_t=self.registration_t,
channel_to_use=self.registration_channel,
)

positions = np.array(tf_full._tile_positions)
Expand Down Expand Up @@ -198,7 +207,11 @@ def run(self):

# Create a new TileFusion for the subset
tf = TileFusion(
self.tiff_path, downsample_factors=(self.downsample_factor, self.downsample_factor)
self.tiff_path,
downsample_factors=(self.downsample_factor, self.downsample_factor),
registration_z=self.registration_z,
registration_t=self.registration_t,
channel_to_use=self.registration_channel,
)
tf._tile_positions = selected_positions
tf.n_tiles = len(selected_indices)
Expand Down Expand Up @@ -328,6 +341,9 @@ def __init__(
fusion_mode="blended",
flatfield=None,
darkfield=None,
registration_z=None,
registration_t=0,
registration_channel=0,
):
super().__init__()
self.tiff_path = tiff_path
Expand All @@ -337,6 +353,9 @@ def __init__(
self.fusion_mode = fusion_mode
self.flatfield = flatfield
self.darkfield = darkfield
self.registration_z = registration_z
self.registration_t = registration_t
self.registration_channel = registration_channel
self.output_path = None

def run(self):
Expand Down Expand Up @@ -379,6 +398,9 @@ def run(self):
downsample_factors=(self.downsample_factor, self.downsample_factor),
flatfield=self.flatfield,
darkfield=self.darkfield,
registration_z=self.registration_z,
registration_t=self.registration_t,
channel_to_use=self.registration_channel,
)
load_time = time.time() - step_start
self.progress.emit(f"Loaded {tf.n_tiles} tiles ({tf.Y}x{tf.X} each) [{load_time:.1f}s]")
Expand Down Expand Up @@ -714,6 +736,12 @@ def __init__(self):
self.darkfield = None # Shape (C, Y, X) or None
self.flatfield_worker = None

# Dataset dimension state (for registration z/t selection)
self.dataset_n_z = 1
self.dataset_n_t = 1
self.dataset_n_channels = 1
self.dataset_channel_names = []

self.setup_ui()

def setup_ui(self):
Expand Down Expand Up @@ -887,6 +915,36 @@ def setup_ui(self):
downsample_layout.addStretch()
settings_layout.addWidget(self.downsample_widget)

# Registration z/t selection (shown when registration enabled AND multi-z/t dataset)
self.reg_zt_widget = QWidget()
self.reg_zt_widget.setVisible(False)
reg_zt_layout = QHBoxLayout(self.reg_zt_widget)
reg_zt_layout.setContentsMargins(20, 0, 0, 0)
self.reg_z_label = QLabel("Z-level:")
reg_zt_layout.addWidget(self.reg_z_label)
self.reg_z_spin = QSpinBox()
self.reg_z_spin.setRange(0, 0)
self.reg_z_spin.setValue(0)
self.reg_z_spin.setToolTip("Z-level to use for registration")
self.reg_z_spin.setFixedWidth(60)
reg_zt_layout.addWidget(self.reg_z_spin)
self.reg_t_label = QLabel("Timepoint:")
reg_zt_layout.addWidget(self.reg_t_label)
self.reg_t_spin = QSpinBox()
self.reg_t_spin.setRange(0, 0)
self.reg_t_spin.setValue(0)
self.reg_t_spin.setToolTip("Timepoint to use for registration")
self.reg_t_spin.setFixedWidth(60)
reg_zt_layout.addWidget(self.reg_t_spin)
self.reg_channel_label = QLabel("Channel:")
reg_zt_layout.addWidget(self.reg_channel_label)
self.reg_channel_combo = QComboBox()
self.reg_channel_combo.setToolTip("Channel to use for registration")
self.reg_channel_combo.setMinimumWidth(120)
reg_zt_layout.addWidget(self.reg_channel_combo)
reg_zt_layout.addStretch()
settings_layout.addWidget(self.reg_zt_widget)

self.blend_checkbox = QCheckBox("Enable blending")
self.blend_checkbox.setChecked(False)
self.blend_checkbox.toggled.connect(self.on_blend_toggled)
Expand Down Expand Up @@ -978,6 +1036,31 @@ def on_file_dropped(self, file_path):
self.clear_flatfield_button.setEnabled(False)
self.save_flatfield_button.setEnabled(False)

# Load dataset dimensions for registration z/t selection
try:
from tilefusion import TileFusion

tf_temp = TileFusion(file_path)
self.dataset_n_z = tf_temp.n_z
self.dataset_n_t = tf_temp.n_t
self.dataset_n_channels = tf_temp.channels
if "channel_names" in tf_temp._metadata:
self.dataset_channel_names = tf_temp._metadata["channel_names"]
else:
self.dataset_channel_names = [
f"Channel {i}" for i in range(self.dataset_n_channels)
]
tf_temp.close()
if self.dataset_n_z > 1 or self.dataset_n_t > 1:
self.log(f"Dataset: {self.dataset_n_z} z-levels, {self.dataset_n_t} timepoints")
self._update_reg_zt_controls()
except Exception:
self.dataset_n_z = 1
self.dataset_n_t = 1
self.dataset_n_channels = 1
self.dataset_channel_names = []
self._update_reg_zt_controls()

# Auto-load existing flatfield if present, otherwise disable correction
# For directories (SQUID folders), also check inside the directory
if path.is_dir():
Expand All @@ -997,6 +1080,41 @@ def on_file_dropped(self, file_path):

def on_registration_toggled(self, checked):
self.downsample_widget.setVisible(checked)
self._update_reg_zt_controls()

def _update_reg_zt_controls(self):
"""Update visibility and ranges of registration z/t controls."""
registration_enabled = self.registration_checkbox.isChecked()
has_multi_z = self.dataset_n_z > 1
has_multi_t = self.dataset_n_t > 1
has_multi_channel = self.dataset_n_channels > 1

# Show z/t widget only when registration is enabled AND dataset has multi-z or multi-t or multi-channel
show_zt = registration_enabled and (has_multi_z or has_multi_t or has_multi_channel)
self.reg_zt_widget.setVisible(show_zt)

if show_zt:
# Update z spinbox
self.reg_z_label.setVisible(has_multi_z)
self.reg_z_spin.setVisible(has_multi_z)
if has_multi_z:
self.reg_z_spin.setRange(0, self.dataset_n_z - 1)
self.reg_z_spin.setValue(self.dataset_n_z // 2) # Default to middle

# Update t spinbox
self.reg_t_label.setVisible(has_multi_t)
self.reg_t_spin.setVisible(has_multi_t)
if has_multi_t:
self.reg_t_spin.setRange(0, self.dataset_n_t - 1)
self.reg_t_spin.setValue(0) # Default to first timepoint

# Update channel combo
self.reg_channel_label.setVisible(has_multi_channel)
self.reg_channel_combo.setVisible(has_multi_channel)
if has_multi_channel:
self.reg_channel_combo.clear()
self.reg_channel_combo.addItems(self.dataset_channel_names)
self.reg_channel_combo.setCurrentIndex(0)

def on_blend_toggled(self, checked):
self.blend_value_widget.setVisible(checked)
Expand Down Expand Up @@ -1228,6 +1346,13 @@ def run_stitching(self):
flatfield = self.flatfield if self.flatfield_checkbox.isChecked() else None
darkfield = self.darkfield if self.flatfield_checkbox.isChecked() else None

# Get registration z/t values (None means use default middle z)
registration_z = self.reg_z_spin.value() if self.dataset_n_z > 1 else None
registration_t = self.reg_t_spin.value() if self.dataset_n_t > 1 else 0
registration_channel = (
self.reg_channel_combo.currentIndex() if self.dataset_n_channels > 1 else 0
)

self.worker = FusionWorker(
self.drop_area.file_path,
self.registration_checkbox.isChecked(),
Expand All @@ -1236,6 +1361,9 @@ def run_stitching(self):
fusion_mode,
flatfield=flatfield,
darkfield=darkfield,
registration_z=registration_z,
registration_t=registration_t,
registration_channel=registration_channel,
)
self.worker.progress.connect(self.log)
self.worker.finished.connect(self.on_fusion_finished)
Expand Down Expand Up @@ -1294,13 +1422,23 @@ def run_preview(self):
flatfield = self.flatfield if self.flatfield_checkbox.isChecked() else None
darkfield = self.darkfield if self.flatfield_checkbox.isChecked() else None

# Get registration z/t values (None means use default middle z)
registration_z = self.reg_z_spin.value() if self.dataset_n_z > 1 else None
registration_t = self.reg_t_spin.value() if self.dataset_n_t > 1 else 0
registration_channel = (
self.reg_channel_combo.currentIndex() if self.dataset_n_channels > 1 else 0
)

self.preview_worker = PreviewWorker(
self.drop_area.file_path,
self.preview_cols_spin.value(),
self.preview_rows_spin.value(),
self.downsample_spin.value(),
flatfield=flatfield,
darkfield=darkfield,
registration_z=registration_z,
registration_t=registration_t,
registration_channel=registration_channel,
)
self.preview_worker.progress.connect(self.log)
self.preview_worker.finished.connect(self.on_preview_finished)
Expand Down
26 changes: 22 additions & 4 deletions src/tilefusion/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ def __init__(
region: Optional[str] = None,
flatfield: Optional[np.ndarray] = None,
darkfield: Optional[np.ndarray] = None,
registration_z: Optional[int] = None,
registration_t: int = 0,
):
self.tiff_path = Path(tiff_path)
if not self.tiff_path.exists():
Expand Down Expand Up @@ -194,6 +196,18 @@ def __init__(
self._time_folders = self._metadata.get("time_folders", None)
self._middle_z = self.n_z // 2 # Use middle z-level for registration

# Registration z/t selection (validate after n_z/n_t are known)
if registration_z is None:
self._registration_z = self._middle_z
else:
if registration_z < 0 or registration_z >= self.n_z:
raise ValueError(f"registration_z={registration_z} out of range [0, {self.n_z})")
self._registration_z = registration_z

if registration_t < 0 or registration_t >= self.n_t:
raise ValueError(f"registration_t={registration_t} out of range [0, {self.n_t})")
self._registration_t = registration_t

# Configuration
self.downsample_factors = tuple(downsample_factors)
self.ssim_window = int(ssim_window)
Expand Down Expand Up @@ -447,10 +461,12 @@ def _update_profiles(self) -> None:
# I/O methods (delegate to format-specific loaders)
# -------------------------------------------------------------------------

def _read_tile(self, tile_idx: int, z_level: int = None, time_idx: int = 0) -> np.ndarray:
def _read_tile(self, tile_idx: int, z_level: int = None, time_idx: int = None) -> np.ndarray:
"""Read a single tile from the input data (all channels)."""
if z_level is None:
z_level = self._middle_z # Default to middle z for registration
z_level = self._registration_z # Default to registration z-level
if time_idx is None:
time_idx = self._registration_t # Default to registration timepoint

if self._is_zarr_format:
zarr_ts = self._metadata["tensorstore"]
Expand Down Expand Up @@ -493,11 +509,13 @@ def _read_tile_region(
y_slice: slice,
x_slice: slice,
z_level: int = None,
time_idx: int = 0,
time_idx: int = None,
) -> np.ndarray:
"""Read a region of a tile from the input data."""
if z_level is None:
z_level = self._middle_z # Default to middle z for registration
z_level = self._registration_z # Default to registration z-level
if time_idx is None:
time_idx = self._registration_t # Default to registration timepoint

if self._is_zarr_format:
zarr_ts = self._metadata["tensorstore"]
Expand Down