diff --git a/gui/app.py b/gui/app.py index aac8551..822abe1 100644 --- a/gui/app.py +++ b/gui/app.py @@ -131,6 +131,9 @@ def __init__( downsample_factor, flatfield=None, darkfield=None, + registration_z=None, + registration_t=0, + registration_channel=0, ): super().__init__() self.tiff_path = tiff_path @@ -139,6 +142,9 @@ def __init__( self.downsample_factor = downsample_factor self.flatfield = flatfield self.darkfield = darkfield + self.registration_z = registration_z + self.registration_t = registration_t + self.registration_channel = registration_channel def run(self): try: @@ -153,6 +159,9 @@ def run(self): downsample_factors=(self.downsample_factor, self.downsample_factor), flatfield=self.flatfield, darkfield=self.darkfield, + registration_z=self.registration_z, + registration_t=self.registration_t, + channel_to_use=self.registration_channel, ) positions = np.array(tf_full._tile_positions) @@ -198,7 +207,11 @@ def run(self): # Create a new TileFusion for the subset tf = TileFusion( - self.tiff_path, downsample_factors=(self.downsample_factor, self.downsample_factor) + self.tiff_path, + downsample_factors=(self.downsample_factor, self.downsample_factor), + registration_z=self.registration_z, + registration_t=self.registration_t, + channel_to_use=self.registration_channel, ) tf._tile_positions = selected_positions tf.n_tiles = len(selected_indices) @@ -328,6 +341,9 @@ def __init__( fusion_mode="blended", flatfield=None, darkfield=None, + registration_z=None, + registration_t=0, + registration_channel=0, ): super().__init__() self.tiff_path = tiff_path @@ -337,6 +353,9 @@ def __init__( self.fusion_mode = fusion_mode self.flatfield = flatfield self.darkfield = darkfield + self.registration_z = registration_z + self.registration_t = registration_t + self.registration_channel = registration_channel self.output_path = None def run(self): @@ -379,6 +398,9 @@ def run(self): downsample_factors=(self.downsample_factor, self.downsample_factor), flatfield=self.flatfield, darkfield=self.darkfield, + registration_z=self.registration_z, + registration_t=self.registration_t, + channel_to_use=self.registration_channel, ) load_time = time.time() - step_start self.progress.emit(f"Loaded {tf.n_tiles} tiles ({tf.Y}x{tf.X} each) [{load_time:.1f}s]") @@ -714,6 +736,12 @@ def __init__(self): self.darkfield = None # Shape (C, Y, X) or None self.flatfield_worker = None + # Dataset dimension state (for registration z/t selection) + self.dataset_n_z = 1 + self.dataset_n_t = 1 + self.dataset_n_channels = 1 + self.dataset_channel_names = [] + self.setup_ui() def setup_ui(self): @@ -887,6 +915,36 @@ def setup_ui(self): downsample_layout.addStretch() settings_layout.addWidget(self.downsample_widget) + # Registration z/t selection (shown when registration enabled AND multi-z/t dataset) + self.reg_zt_widget = QWidget() + self.reg_zt_widget.setVisible(False) + reg_zt_layout = QHBoxLayout(self.reg_zt_widget) + reg_zt_layout.setContentsMargins(20, 0, 0, 0) + self.reg_z_label = QLabel("Z-level:") + reg_zt_layout.addWidget(self.reg_z_label) + self.reg_z_spin = QSpinBox() + self.reg_z_spin.setRange(0, 0) + self.reg_z_spin.setValue(0) + self.reg_z_spin.setToolTip("Z-level to use for registration") + self.reg_z_spin.setFixedWidth(60) + reg_zt_layout.addWidget(self.reg_z_spin) + self.reg_t_label = QLabel("Timepoint:") + reg_zt_layout.addWidget(self.reg_t_label) + self.reg_t_spin = QSpinBox() + self.reg_t_spin.setRange(0, 0) + self.reg_t_spin.setValue(0) + self.reg_t_spin.setToolTip("Timepoint to use for registration") + self.reg_t_spin.setFixedWidth(60) + reg_zt_layout.addWidget(self.reg_t_spin) + self.reg_channel_label = QLabel("Channel:") + reg_zt_layout.addWidget(self.reg_channel_label) + self.reg_channel_combo = QComboBox() + self.reg_channel_combo.setToolTip("Channel to use for registration") + self.reg_channel_combo.setMinimumWidth(120) + reg_zt_layout.addWidget(self.reg_channel_combo) + reg_zt_layout.addStretch() + settings_layout.addWidget(self.reg_zt_widget) + self.blend_checkbox = QCheckBox("Enable blending") self.blend_checkbox.setChecked(False) self.blend_checkbox.toggled.connect(self.on_blend_toggled) @@ -978,6 +1036,31 @@ def on_file_dropped(self, file_path): self.clear_flatfield_button.setEnabled(False) self.save_flatfield_button.setEnabled(False) + # Load dataset dimensions for registration z/t selection + try: + from tilefusion import TileFusion + + tf_temp = TileFusion(file_path) + self.dataset_n_z = tf_temp.n_z + self.dataset_n_t = tf_temp.n_t + self.dataset_n_channels = tf_temp.channels + if "channel_names" in tf_temp._metadata: + self.dataset_channel_names = tf_temp._metadata["channel_names"] + else: + self.dataset_channel_names = [ + f"Channel {i}" for i in range(self.dataset_n_channels) + ] + tf_temp.close() + if self.dataset_n_z > 1 or self.dataset_n_t > 1: + self.log(f"Dataset: {self.dataset_n_z} z-levels, {self.dataset_n_t} timepoints") + self._update_reg_zt_controls() + except Exception: + self.dataset_n_z = 1 + self.dataset_n_t = 1 + self.dataset_n_channels = 1 + self.dataset_channel_names = [] + self._update_reg_zt_controls() + # Auto-load existing flatfield if present, otherwise disable correction # For directories (SQUID folders), also check inside the directory if path.is_dir(): @@ -997,6 +1080,41 @@ def on_file_dropped(self, file_path): def on_registration_toggled(self, checked): self.downsample_widget.setVisible(checked) + self._update_reg_zt_controls() + + def _update_reg_zt_controls(self): + """Update visibility and ranges of registration z/t controls.""" + registration_enabled = self.registration_checkbox.isChecked() + has_multi_z = self.dataset_n_z > 1 + has_multi_t = self.dataset_n_t > 1 + has_multi_channel = self.dataset_n_channels > 1 + + # Show z/t widget only when registration is enabled AND dataset has multi-z or multi-t or multi-channel + show_zt = registration_enabled and (has_multi_z or has_multi_t or has_multi_channel) + self.reg_zt_widget.setVisible(show_zt) + + if show_zt: + # Update z spinbox + self.reg_z_label.setVisible(has_multi_z) + self.reg_z_spin.setVisible(has_multi_z) + if has_multi_z: + self.reg_z_spin.setRange(0, self.dataset_n_z - 1) + self.reg_z_spin.setValue(self.dataset_n_z // 2) # Default to middle + + # Update t spinbox + self.reg_t_label.setVisible(has_multi_t) + self.reg_t_spin.setVisible(has_multi_t) + if has_multi_t: + self.reg_t_spin.setRange(0, self.dataset_n_t - 1) + self.reg_t_spin.setValue(0) # Default to first timepoint + + # Update channel combo + self.reg_channel_label.setVisible(has_multi_channel) + self.reg_channel_combo.setVisible(has_multi_channel) + if has_multi_channel: + self.reg_channel_combo.clear() + self.reg_channel_combo.addItems(self.dataset_channel_names) + self.reg_channel_combo.setCurrentIndex(0) def on_blend_toggled(self, checked): self.blend_value_widget.setVisible(checked) @@ -1228,6 +1346,13 @@ def run_stitching(self): flatfield = self.flatfield if self.flatfield_checkbox.isChecked() else None darkfield = self.darkfield if self.flatfield_checkbox.isChecked() else None + # Get registration z/t values (None means use default middle z) + registration_z = self.reg_z_spin.value() if self.dataset_n_z > 1 else None + registration_t = self.reg_t_spin.value() if self.dataset_n_t > 1 else 0 + registration_channel = ( + self.reg_channel_combo.currentIndex() if self.dataset_n_channels > 1 else 0 + ) + self.worker = FusionWorker( self.drop_area.file_path, self.registration_checkbox.isChecked(), @@ -1236,6 +1361,9 @@ def run_stitching(self): fusion_mode, flatfield=flatfield, darkfield=darkfield, + registration_z=registration_z, + registration_t=registration_t, + registration_channel=registration_channel, ) self.worker.progress.connect(self.log) self.worker.finished.connect(self.on_fusion_finished) @@ -1294,6 +1422,13 @@ def run_preview(self): flatfield = self.flatfield if self.flatfield_checkbox.isChecked() else None darkfield = self.darkfield if self.flatfield_checkbox.isChecked() else None + # Get registration z/t values (None means use default middle z) + registration_z = self.reg_z_spin.value() if self.dataset_n_z > 1 else None + registration_t = self.reg_t_spin.value() if self.dataset_n_t > 1 else 0 + registration_channel = ( + self.reg_channel_combo.currentIndex() if self.dataset_n_channels > 1 else 0 + ) + self.preview_worker = PreviewWorker( self.drop_area.file_path, self.preview_cols_spin.value(), @@ -1301,6 +1436,9 @@ def run_preview(self): self.downsample_spin.value(), flatfield=flatfield, darkfield=darkfield, + registration_z=registration_z, + registration_t=registration_t, + registration_channel=registration_channel, ) self.preview_worker.progress.connect(self.log) self.preview_worker.finished.connect(self.on_preview_finished) diff --git a/src/tilefusion/core.py b/src/tilefusion/core.py index 1c49d68..d28bc7e 100644 --- a/src/tilefusion/core.py +++ b/src/tilefusion/core.py @@ -110,6 +110,8 @@ def __init__( region: Optional[str] = None, flatfield: Optional[np.ndarray] = None, darkfield: Optional[np.ndarray] = None, + registration_z: Optional[int] = None, + registration_t: int = 0, ): self.tiff_path = Path(tiff_path) if not self.tiff_path.exists(): @@ -194,6 +196,18 @@ def __init__( self._time_folders = self._metadata.get("time_folders", None) self._middle_z = self.n_z // 2 # Use middle z-level for registration + # Registration z/t selection (validate after n_z/n_t are known) + if registration_z is None: + self._registration_z = self._middle_z + else: + if registration_z < 0 or registration_z >= self.n_z: + raise ValueError(f"registration_z={registration_z} out of range [0, {self.n_z})") + self._registration_z = registration_z + + if registration_t < 0 or registration_t >= self.n_t: + raise ValueError(f"registration_t={registration_t} out of range [0, {self.n_t})") + self._registration_t = registration_t + # Configuration self.downsample_factors = tuple(downsample_factors) self.ssim_window = int(ssim_window) @@ -447,10 +461,12 @@ def _update_profiles(self) -> None: # I/O methods (delegate to format-specific loaders) # ------------------------------------------------------------------------- - def _read_tile(self, tile_idx: int, z_level: int = None, time_idx: int = 0) -> np.ndarray: + def _read_tile(self, tile_idx: int, z_level: int = None, time_idx: int = None) -> np.ndarray: """Read a single tile from the input data (all channels).""" if z_level is None: - z_level = self._middle_z # Default to middle z for registration + z_level = self._registration_z # Default to registration z-level + if time_idx is None: + time_idx = self._registration_t # Default to registration timepoint if self._is_zarr_format: zarr_ts = self._metadata["tensorstore"] @@ -493,11 +509,13 @@ def _read_tile_region( y_slice: slice, x_slice: slice, z_level: int = None, - time_idx: int = 0, + time_idx: int = None, ) -> np.ndarray: """Read a region of a tile from the input data.""" if z_level is None: - z_level = self._middle_z # Default to middle z for registration + z_level = self._registration_z # Default to registration z-level + if time_idx is None: + time_idx = self._registration_t # Default to registration timepoint if self._is_zarr_format: zarr_ts = self._metadata["tensorstore"]