diff --git a/package/PartSeg/_roi_mask/main_window.py b/package/PartSeg/_roi_mask/main_window.py index c3b7d4bad..f1b6d5be7 100644 --- a/package/PartSeg/_roi_mask/main_window.py +++ b/package/PartSeg/_roi_mask/main_window.py @@ -1,6 +1,8 @@ import os +from collections.abc import Sequence from contextlib import suppress from functools import partial +from typing import Union import numpy as np from qtpy.QtCore import QByteArray, Qt, Signal, Slot @@ -16,8 +18,10 @@ QMessageBox, QProgressBar, QPushButton, + QScrollArea, QSizePolicy, QSpinBox, + QSplitter, QTabWidget, QTextEdit, QVBoxLayout, @@ -393,7 +397,7 @@ def leaveEvent(self, _event): self.mouse_leave.emit(self.number) -class ChosenComponents(QWidget): +class ChosenComponents(QScrollArea): """ :type check_box: dict[int, ComponentCheckBox] """ @@ -404,6 +408,7 @@ class ChosenComponents(QWidget): def __init__(self): super().__init__() + self.setWidget(QWidget(self)) self.check_box = {} self.check_all_btn = QPushButton("Select all") self.check_all_btn.clicked.connect(self.check_all) @@ -416,19 +421,33 @@ def __init__(self): self.check_layout = FlowLayout() main_layout.addLayout(btn_layout) main_layout.addLayout(self.check_layout) - self.setLayout(main_layout) + self.widget().setLayout(main_layout) + self.setWidgetResizable(True) + self.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff) + self.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAsNeeded) def other_component_choose(self, num): check = self.check_box[num] check.setChecked(not check.isChecked()) + self.ensureWidgetVisible(check) def check_all(self): - for el in self.check_box.values(): - el.setChecked(True) + prev = self.blockSignals(True) + try: + for el in self.check_box.values(): + el.setChecked(True) + finally: + self.blockSignals(prev) + self.check_change_signal.emit() def un_check_all(self): - for el in self.check_box.values(): - el.setChecked(False) + prev = self.blockSignals(True) + try: + for el in self.check_box.values(): + el.setChecked(False) + finally: + self.blockSignals(prev) + self.check_change_signal.emit() def remove_components(self): self.check_layout.clear() @@ -439,10 +458,12 @@ def remove_components(self): el.mouse_enter.disconnect() self.check_box.clear() - def new_choose(self, num, chosen_components): - self.set_chose(range(1, num + 1), chosen_components) + def new_choose(self, num: int, chosen_components: Sequence[int]) -> None: + self.set_components(range(1, num + 1), chosen_components) - def set_chose(self, components_index, chosen_components): + def set_components(self, components_index, chosen_components: Union[Sequence[int], None] = None): + if chosen_components is None: + chosen_components = [] chosen_components = set(chosen_components) self.blockSignals(True) self.remove_components() @@ -459,11 +480,22 @@ def set_chose(self, components_index, chosen_components): self.update() self.check_change_signal.emit() + def set_chosen(self, chosen_components: Sequence[int]): + prev = self.blockSignals(True) + chosen_components = set(chosen_components) + try: + for num, check in self.check_box.items(): + check.setChecked(num in chosen_components) + finally: + self.blockSignals(prev) + self.check_change_signal.emit() + def check_change(self): self.check_change_signal.emit() def change_state(self, num, val): self.check_box[num].setChecked(val) + self.ensureWidgetVisible(self.check_box[num]) def get_state(self, num: int) -> bool: # TODO Check what situation create report of id ID: af9b57f074264169b4353aa1e61d8bc2 @@ -527,6 +559,7 @@ def __init__(self, settings: StackSettings, image_view: StackImageView): # noqa self.choose_components.check_change_signal.connect(image_view.refresh_selected) self.choose_components.mouse_leave.connect(image_view.component_unmark) self.choose_components.mouse_enter.connect(image_view.component_mark) + self.chosen_list = [] self.progress_bar2 = QProgressBar() self.progress_bar2.setHidden(True) @@ -566,8 +599,10 @@ def __init__(self, settings: StackSettings, image_view: StackImageView): # noqa main_layout.addWidget(self.progress_bar2) main_layout.addWidget(self.progress_bar) main_layout.addWidget(self.progress_info_lab) - main_layout.addWidget(self.algorithm_choose_widget, 1) - main_layout.addWidget(self.choose_components) + split = QSplitter(Qt.Orientation.Vertical) + split.addWidget(self.algorithm_choose_widget) + split.addWidget(self.choose_components) + main_layout.addWidget(split, 1) down_layout = QHBoxLayout() down_layout.addWidget(self.keep_chosen_components_chk) down_layout.addWidget(self.show_parameters) @@ -659,7 +694,7 @@ def segmentation(self, val): def _image_changed(self): self.settings.roi = None - self.choose_components.set_chose([], []) + self.choose_components.set_components([], []) def _execute_in_background_init(self): if self.batch_process.isRunning(): diff --git a/package/PartSeg/_roi_mask/stack_settings.py b/package/PartSeg/_roi_mask/stack_settings.py index 3ca23664f..d5acc9de1 100644 --- a/package/PartSeg/_roi_mask/stack_settings.py +++ b/package/PartSeg/_roi_mask/stack_settings.py @@ -170,17 +170,15 @@ def set_project_info(self, data: typing.Union[MaskProjectTuple, PointsInfo]): data.selected_components, self.keep_chosen_components, ) - self.chosen_components_widget.set_chose( - sorted(state2.roi_extraction_parameters.keys()), state2.selected_components - ) self.roi = state2.roi_info + self.chosen_components_widget.set_chosen(state2.selected_components) + self.components_parameters_dict = state2.roi_extraction_parameters else: self.set_history(data.history) - self.chosen_components_widget.set_chose( - sorted(data.roi_extraction_parameters.keys()), data.selected_components - ) self.roi = data.roi_info + self.chosen_components_widget.set_chosen(data.selected_components) + self.components_parameters_dict = data.roi_extraction_parameters @staticmethod @@ -304,18 +302,22 @@ def _set_roi_info( raise ValueError("ROI do not fit to image") from e if save_chosen: state2 = self.transform_state(state, new_roi_info, segmentation_parameters, list_of_components, save_chosen) - self.chosen_components_widget.set_chose( - sorted(state2.roi_extraction_parameters.keys()), state2.selected_components - ) self.roi = state2.roi_info + self.chosen_components_widget.set_chosen(state2.selected_components) self.components_parameters_dict = state2.roi_extraction_parameters else: - selected_parameters = {i: segmentation_parameters[i] for i in new_roi_info.bound_info} - - self.chosen_components_widget.set_chose(sorted(selected_parameters.keys()), list_of_components) self.roi = new_roi_info + self.chosen_components_widget.set_chosen(list_of_components) self.components_parameters_dict = segmentation_parameters + def post_roi_set(self): + if self.chosen_components_widget is not None: + prev = self.chosen_components_widget.blockSignals(True) + try: + self.chosen_components_widget.set_components(self.roi_info.bound_info.keys(), []) + finally: + self.chosen_components_widget.blockSignals(prev) + def get_mask( segmentation: typing.Optional[np.ndarray], mask: typing.Optional[np.ndarray], selected: list[int] diff --git a/package/PartSeg/common_backend/base_settings.py b/package/PartSeg/common_backend/base_settings.py index c19959643..324ca5b0d 100644 --- a/package/PartSeg/common_backend/base_settings.py +++ b/package/PartSeg/common_backend/base_settings.py @@ -145,6 +145,7 @@ def roi(self, val: Union[np.ndarray, ROIInfo]): if val is None: self._roi_info = ROIInfo(val) self._additional_layers = {} + self.post_roi_set() self.roi_clean.emit() return try: @@ -155,8 +156,12 @@ def roi(self, val: Union[np.ndarray, ROIInfo]): except ValueError as e: raise ValueError(ROI_NOT_FIT) from e self._additional_layers = {} + self.post_roi_set() self.roi_changed.emit(self._roi_info) + def post_roi_set(self) -> None: + """called after roi is set, for subclasses to override""" + @property def sizes(self): return self._roi_info.sizes @@ -530,8 +535,7 @@ def set_segmentation_result(self, result: ROIExtractionResult): raise ValueError(ROI_NOT_FIT) from e if result.points is not None: self.points = result.points - self._roi_info = roi_info - self.roi_changed.emit(self._roi_info) + self.roi = roi_info def _load_files_call(self, files_list: list[str]): self.request_load_files.emit(files_list) diff --git a/package/PartSeg/common_gui/napari_image_view.py b/package/PartSeg/common_gui/napari_image_view.py index 95843f16f..2272f16fb 100644 --- a/package/PartSeg/common_gui/napari_image_view.py +++ b/package/PartSeg/common_gui/napari_image_view.py @@ -584,7 +584,7 @@ def _remove_worker(self, sender=None): else: logging.debug("[_remove_worker] %s", sender) - def _add_layer_util(self, index, layer, filters): + def _add_layer_util(self, index: int, layer: _NapariImage, filters: list[tuple[NoiseFilterType, float]]) -> None: if layer not in self.viewer.layers: self.viewer.add_layer(layer)