From 54a34dd40cd4064b16eeccf59b498907761479df Mon Sep 17 00:00:00 2001 From: moenigin <15244500+moenigin@users.noreply.github.com> Date: Tue, 15 Aug 2023 20:10:18 +0200 Subject: [PATCH 1/8] Update proofreading.py implements requested changes add a class that allows storing locations of errors - add some typing and docstring Update proofreading.py pass actionstate to store_error_location with different mode input directly & remove intermedeiary functions --- ffn/utils/proofreading.py | 409 +++++++++++++++++++++++++++++++++----- 1 file changed, 363 insertions(+), 46 deletions(-) diff --git a/ffn/utils/proofreading.py b/ffn/utils/proofreading.py index 7fb8aba..33c230e 100644 --- a/ffn/utils/proofreading.py +++ b/ffn/utils/proofreading.py @@ -24,17 +24,29 @@ import networkx as nx import neuroglancer +from typing import Union, Optional, Iterable, Any + class Base: """Base class for proofreading workflows. - To use, define a subclass overriding the `set_init_state` method to - provide initial Neuroglancer settings. - - The segmentation volume needs to be called `seg`. + To use, define a subclass overriding the `set_init_state` method to provide + initial Neuroglancer settings. The segmentation volume needs to be called + `seg`. """ - def __init__(self, num_to_prefetch=10, locations=None, objects=None): + def __init__(self, + num_to_prefetch: int = 10, + locations: Optional[Iterable[tuple[int, int, int]]] = None, + objects: Optional[ + Union[dict[str, Any], Iterable[int]]] = None): + """Initializes the Base class for proofreading. + + Args: + num_to_prefetch: Number of items to prefetch. + locations: List of xyz coordinates corresponding to object locations. + objects: Object IDs or a dictionary mapping layer names to object IDs. + """ self.viewer = neuroglancer.Viewer() self.num_to_prefetch = num_to_prefetch @@ -55,7 +67,8 @@ def __init__(self, num_to_prefetch=10, locations=None, objects=None): self.set_init_state() - def _set_todo(self, objects): + def _set_todo(self, objects: Union[list[str, Any], Iterable[int]]) -> None: + """Private method to set the todo list.""" for o in objects: if isinstance(o, collections.abc.Mapping): self.todo.append(o) @@ -65,14 +78,28 @@ def _set_todo(self, objects): else: self.todo.append({'seg': [o]}) - def set_init_state(self): + def set_init_state(self) -> None: + """Sets the initial state for Neuroglancer. + Subclasses should override this method. + """ raise NotImplementedError() - def update_msg(self, msg): + def update_msg(self, msg: str) -> None: + """Updates the status message in Neuroglancer viewer.""" with self.viewer.config_state.txn() as s: s.status_messages['status'] = msg - def update_segments(self, segments, loc=None, layer='seg'): + def update_segments(self, + segments: list[int], + loc: Optional[tuple[int, int, int]] = None, + layer: str = 'seg') -> None: + """Updates segments in Neuroglancer viewer. + + Args: + segments: List of segment IDs to update. + loc: 3D coordinates to set the viewer to. + layer: Layer name in Neuroglancer to be updated. + """ s = copy.deepcopy(self.viewer.state) l = s.layers[layer] l.segments = segments @@ -90,53 +117,78 @@ def update_segments(self, segments, loc=None, layer='seg'): self.viewer.set_state(s) - def toggle_equiv(self): + def toggle_equiv(self) -> None: + """Toggle the apply equivalence flag and update the batch.""" self.apply_equivs = not self.apply_equivs self.update_batch() - def batch_dec(self): + def batch_dec(self) -> None: + """Decrease the batch size by half and update the batch.""" self.batch //= 2 self.batch = max(self.batch, 1) self.update_batch() - def batch_inc(self): + def batch_inc(self) -> None: + """Increase the batch size by double and update the batch.""" self.batch *= 2 self.update_batch() - def next_batch(self): + def next_batch(self) -> None: + """Move to the next batch of segments and update the viewer.""" self.index += self.batch self.index = min(self.index, len(self.todo) - 1) self.prefetch() self.update_batch() - def prev_batch(self): + def prev_batch(self) -> None: + """Move to the previous batch of segments and update the viewer.""" self.index -= self.batch self.index = max(0, self.index) self.update_batch() - def list_segments(self, index=None, layer='seg'): + def list_segments(self, + index: Optional[int] = None, + layer: str = 'seg') -> list[int]: + """Get a list of segment IDs for a given index and layer. + + Args: + index: Index of segments to list. + layer: Layer name to list the segments from. + + Returns: + List of segment IDs. + """ if index is None: index = self.index return list( - set( - itertools.chain( - *[x[layer] for x in self.todo[index:index + self.batch]]))) + set( + itertools.chain( + *[x[layer] for x in self.todo[index:index + self.batch]]))) - def custom_msg(self): + def custom_msg(self) -> str: + """Generate a custom message for the current state. + + Returns: + A custom message string. + """ return '' - def update_batch(self, update=True): + def update_batch(self) -> None: + """Update the segments displayed in the viewer based on batch settings.""" if self.batch == 1 and self.locations is not None: loc = self.locations[self.index] else: loc = None for layer in self.managed_layers: - self.update_segments(self.list_segments(layer=layer), loc, layer=layer) + self.update_segments(self.list_segments(layer=layer), loc, + layer=layer) self.update_msg('index:%d/%d batch:%d %s' % - (self.index, len(self.todo), self.batch, self.custom_msg())) + (self.index, len(self.todo), self.batch, + self.custom_msg())) - def prefetch(self): + def prefetch(self) -> None: + """Pre-fetch the segments for smoother navigation in the viewer.""" prefetch_states = [] for i in range(self.num_to_prefetch): idx = self.index + (i + 1) * self.batch @@ -145,7 +197,7 @@ def prefetch(self): prefetch_state = copy.deepcopy(self.viewer.state) for layer in self.managed_layers: prefetch_state.layers[layer].segments = self.list_segments( - idx, layer=layer) + idx, layer=layer) prefetch_state.layout = '3d' if self.locations is not None: prefetch_state.position = self.locations[idx] @@ -154,10 +206,29 @@ def prefetch(self): with self.viewer.config_state.txn() as s: s.prefetch = [ - neuroglancer.PrefetchState(state=prefetch_state, priority=-i) - for i, prefetch_state in enumerate(prefetch_states) + neuroglancer.PrefetchState(state=prefetch_state, priority=-i) + for i, prefetch_state in enumerate(prefetch_states) ] + def get_cursor_position(self, + action_state: neuroglancer.viewer_config_state.ActionState): + """Return coordinates of the cursor position from a neuroglancer action state + + Args: + action_state : Neuroglancer action state + + Returns: + (x, y, z) cursor position + """ + try: + cursor_position = [int(x) for x in + action_state.mouse_voxel_coordinates] + except Exception: + self.update_msg('cursor misplaced') + return + + return cursor_position + class ObjectReview(Base): """Base class for rapid (agglomerated) object review. @@ -166,7 +237,11 @@ class ObjectReview(Base): batches. """ - def __init__(self, objects, bad, num_to_prefetch=10, locations=None): + def __init__(self, + objects: Iterable, + bad: list, + num_to_prefetch: int = 10, + locations: Optional[Iterable[tuple[int, int, int]]] = None): """Constructor. Args: @@ -176,19 +251,27 @@ def __init__(self, objects, bad, num_to_prefetch=10, locations=None): bad: set in which to store objects or groups of objects flagged as bad. num_to_prefetch: number of items from `objects` to prefetch locations: iterable of xyz tuples of length len(objects). If specified, - the cursor will be automaticaly moved to the location corresponding to + the cursor will be automatically moved to the location corresponding to the current object if batch == 1. """ super().__init__( - num_to_prefetch=num_to_prefetch, locations=locations, objects=objects) + num_to_prefetch=num_to_prefetch, locations=locations, + objects=objects) self.bad = bad + self.set_keybindings() + + self.update_batch() + + def set_keybindings(self) -> None: + """Set key bindings for the viewer.""" self.viewer.actions.add('next-batch', lambda s: self.next_batch()) self.viewer.actions.add('prev-batch', lambda s: self.prev_batch()) self.viewer.actions.add('dec-batch', lambda s: self.batch_dec()) self.viewer.actions.add('inc-batch', lambda s: self.batch_inc()) self.viewer.actions.add('mark-bad', lambda s: self.mark_bad()) - self.viewer.actions.add('mark-removed-bad', lambda s: self.mark_removed_bad()) + self.viewer.actions.add('mark-removed-bad', + lambda s: self.mark_removed_bad()) self.viewer.actions.add('toggle-equiv', lambda s: self.toggle_equiv()) with self.viewer.config_state.txn() as s: @@ -200,12 +283,20 @@ def __init__(self, objects, bad, num_to_prefetch=10, locations=None): s.input_event_bindings.viewer['keyt'] = 'toggle-equiv' s.input_event_bindings.viewer['keya'] = 'mark-removed-bad' - self.update_batch() + def custom_msg(self) -> str: + """Construct a custom message for the current state. - def custom_msg(self): + Returns: + A formatted message indicating the number of bad objects. + """ return 'num_bad: %d' % len(self.bad) - def mark_bad(self): + def mark_bad(self) -> None: + """Mark an object or group of objects as bad. + + If the batch size is greater than 1, the user is prompted to decrease + the batch size. + """ if self.batch > 1: self.update_msg('decrease batch to 1 to mark objects bad') return @@ -216,21 +307,243 @@ def mark_bad(self): else: self.bad.add(frozenset(sids)) - self.update_msg('marked bad: %r' % (sids, )) + self.update_msg('marked bad: %r' % (sids,)) self.next_batch() - def mark_removed_bad(self): + def mark_removed_bad(self) -> None: + """From the set of original objects mark those bad that are not displayed. + Update the message with the IDs of the newly marked bad objects. + """ original = set(self.list_segments()) new_bad = original - set(self.viewer.state.layers['seg'].segments) if new_bad: self.bad |= new_bad - self.update_msg('marked bad: %r' % (new_bad, )) + self.update_msg('marked bad: %r' % (new_bad,)) + + +class ObjectReviewStoreLocation(ObjectReview): + """Class to mark and store locations of errors in the segmentation + + To mark a merger, move the cursor to a spot of the false merger and press 'w'. + Then, move the cursor to a spot within the object that should belong to a + separate object and press 'shift + W'. Yellow point annotations indicate the + merger. For split errors, proceed in similar manner but press 'd' and + 'shift + D', which will display blue annotations. + Marked locations can be deleted either by pressing 'ctrl + Z' (to delete the + last marked location) or by hovering the cursor over one of the point + annotations and pressing 'ctrl + v'. + + Attributes: + seg_error_coordinates: A mapping of annotation identifier substrings to + error coordinate pairs. + Example: {'m0': [[x1,y1,z1],[x2,y2,z2]], 's0':[[x1,y1,z1],[x2,y2,z2]], ...} + - Keys starting with 'm' indicate merge errors. + - Keys starting with 's' indicate split errors. + temp_coord_list: Temporary storage for coordinates. + """ + + def __init__(self, + objects: list, + bad: list, + seg_error_coordinates: Optional[ + list[str, list[list[int]]]] = {}, + load_annotations: bool = False) -> None: + """Initialize the ObjectReviewStoreLocation class. + + Args: + objects: A list of objects. + bad: A list of bad objects or markers. + seg_error_coordinates: A dictionary of error coordinates. + load_annotations: A flag to indicate if annotations should be loaded. + """ + super(ObjectReviewStoreLocation, self).__init__(objects, bad) + self.seg_error_coordinates = seg_error_coordinates + if load_annotations and seg_error_coordinates: + for k, v in seg_error_coordinates.items(): + self.annotate_error_locations(v, k) + self.temp_coord_list = [] + + def set_keybindings(self) -> None: + """Set key bindings for the viewer.""" + super().set_keybindings() + self.viewer.actions.add('merge0', + lambda s: self.store_error_location(s, index=0, + mode='merger')) + self.viewer.actions.add('merge1', + lambda s: self.store_error_location(s, index=1, + mode='merger')) + self.viewer.actions.add('split0', + lambda s: self.store_error_location(s, index=0, + mode='split')) + self.viewer.actions.add('split1', + lambda s: self.store_error_location(s, index=1, + mode='split')) + self.viewer.actions.add('delete_from_annotation', + self.delete_location_from_annotation) + self.viewer.actions.add('delete_last_entry', + lambda s: self.delete_last_location()) + + with self.viewer.config_state.txn() as s: + s.input_event_bindings.viewer['keyw'] = 'merge0' + s.input_event_bindings.viewer['shift+keyw'] = 'merge1' + s.input_event_bindings.viewer['keyd'] = 'split0' + s.input_event_bindings.viewer['shift+keyd'] = 'split1' + s.input_event_bindings.viewer[ + 'control+keyv'] = 'delete_from_annotation' + s.input_event_bindings.viewer[ + 'control+keyz'] = 'delete_last_entry' + + def get_id(self, mode: str) -> str: + """Generate a unique identifier for an error based on its type. + + Args: + mode: Error type, either 'merge' or 'split'. + + Returns: + A unique identifier string. + """ + id_ = mode[0] + if any(self.seg_error_coordinates): + counter = int( + max([x[1:] for x in self.seg_error_coordinates.keys()])) + 1 + else: + counter = 0 + id_ = id_ + str(counter) + return id_ + + def store_error_location(self, + action_state: neuroglancer.viewer_config_state.ActionState, + mode: str, + index: int = 0) -> None: + """Store error locations. + + Args: + action_state: State of the viewer during the action. + mode: Type of the error ('merger' or 'split'). + index: Indicates if it's the first or second coordinate (0 or 1). + """ + location = self.get_cursor_position(action_state) + if location is None: + return + + if index == 1 and not self.temp_coord_list: + self.update_msg('You have not entered a first coord yet') + return + + if index == 0 and self.temp_coord_list: + self.temp_coord_list = [] + + self.temp_coord_list.append(location) + + if index == 1: + if self.temp_coord_list[0] == self.temp_coord_list[1]: + self.update_msg( + 'You entered the same coordinate twice. Try again!') + self.temp_coord_list = [] + return + + identifier = self.get_id(mode=mode) + self.seg_error_coordinates.update( + {identifier: self.temp_coord_list}) + self.annotate_error_locations(self.temp_coord_list, identifier) + self.temp_coord_list = [] + + def annotate_error_locations(self, + coordinate_lst: list[list[int]], + id_: str) -> None: + """Annotate the error locations in the viewer. + + Args: + coordinate_lst: List of coordinates to be annotated. + id_: Unique identifier for the error. + """ + for i, coord in enumerate(coordinate_lst): + annotation_id = id_ + f'_{i}' + self.mk_point_annotation(coord, annotation_id) + + def mk_point_annotation(self, + coordinate: list[int], + annotation_id: str) -> None: + """Create a point annotation in the viewer. + + Args: + coordinate: 3D coordinate of the annotation point. + annotation_id: Unique identifier for the annotation. + """ + if annotation_id.startswith('m'): + color = '#fae505' + else: + color = '#05f2fa' + annotation = neuroglancer.PointAnnotation(id=annotation_id, + point=coordinate, + props=[color]) + with self.viewer.txn() as s: + annotations = s.layers['annotation'].annotations + annotations.append(annotation) + + def get_annotation_id(self, + action_state: neuroglancer.viewer_config_state.ActionState) -> \ + Optional[str]: + """Retrieve the ID of a selected annotation. + + Args: + action_state: neuroglancer.viewer_config_state.ActionState. + + Returns: + The selected object's ID or None if retrieval fails. + """ + try: + selection_state = action_state.selected_values[ + 'annotation'].to_json() + selected_object = selection_state['annotationId'] + except Exception: + self.update_msg('Could not retrieve annotation id') + return + + return selected_object + + def delete_location_from_annotation(self, + action_state: neuroglancer.viewer_config_state.ActionState) -> None: + """Delete the error location pair associated with the annotation at the cursor position + + Args: + action_state: State of the viewer during the action. + """ + id_ = self.get_annotation_id(action_state) + if id_ is None: + return + + target_key = id_[:2] + del self.seg_error_coordinates[target_key] + + to_remove = [target_key + '_0', target_key + '_1'] + self.delete_annotation(to_remove) + + def delete_annotation(self, to_remove: list[str]) -> None: + """Delete specified annotations from the viewer. + + Args: + to_remove: list of annotation IDs to be removed. + """ + with self.viewer.txn() as s: + annotations = s.layers['annotation'].annotations + annotations = [a for a in annotations if a.id not in to_remove] + s.layers['annotation'].annotations = annotations + + def delete_last_location(self): + """Delete the last error location pair tagged.""" + last_key = list(self.seg_error_coordinates.keys())[-1] + del self.seg_error_coordinates[last_key] + + to_remove = [last_key + '_0', last_key + '_1'] + self.delete_annotation(to_remove) class ObjectClassification(Base): """Base class for object classification.""" - def __init__(self, objects, key_to_class, num_to_prefetch=10, locations=None): + def __init__(self, objects, key_to_class, num_to_prefetch=10, + locations=None): """Constructor. Args: @@ -239,7 +552,8 @@ def __init__(self, objects, key_to_class, num_to_prefetch=10, locations=None): num_to_prefetch: number of `objects` to prefetch """ super().__init__( - num_to_prefetch=num_to_prefetch, locations=locations, objects=objects) + num_to_prefetch=num_to_prefetch, locations=locations, + objects=objects) self.results = defaultdict(set) # class -> ids @@ -249,11 +563,12 @@ def __init__(self, objects, key_to_class, num_to_prefetch=10, locations=None): for key, cls in key_to_class.items(): self.viewer.actions.add( - 'classify-%s' % cls, lambda s, cls=cls: self.classify(cls)) + 'classify-%s' % cls, lambda s, cls=cls: self.classify(cls)) with self.viewer.config_state.txn() as s: for key, cls in key_to_class.items(): - s.input_event_bindings.viewer['key%s' % key] = 'classify-%s' % cls + s.input_event_bindings.viewer[ + 'key%s' % key] = 'classify-%s' % cls # Navigation without classification. s.input_event_bindings.viewer['keyj'] = 'mr-next-batch' @@ -313,7 +628,8 @@ def __init__(self, graph, objects, bad, num_to_prefetch=0): self.viewer.actions.add('accept-split', lambda s: self.accept_split()) self.viewer.actions.add('split-inc', lambda s: self.inc_split()) self.viewer.actions.add('split-dec', lambda s: self.dec_split()) - self.viewer.actions.add('merge-segments', lambda s: self.merge_segments()) + self.viewer.actions.add('merge-segments', + lambda s: self.merge_segments()) self.viewer.actions.add('mark-bad', lambda s: self.mark_bad()) self.viewer.actions.add('next-batch', lambda s: self.next_batch()) self.viewer.actions.add('prev-batch', lambda s: self.prev_batch()) @@ -332,11 +648,12 @@ def __init__(self, graph, objects, bad, num_to_prefetch=0): with self.viewer.txn() as s: s.layers['split'] = neuroglancer.SegmentationLayer( - source=s.layers['seg'].source) + source=s.layers['seg'].source) s.layers['split'].visible = False def merge_segments(self): - sids = [sid for sid in self.viewer.state.layers['seg'].segments if sid > 0] + sids = [sid for sid in self.viewer.state.layers['seg'].segments if + sid > 0] self.graph.add_edges_from(zip(sids, sids[1:])) def update_split(self): @@ -384,7 +701,7 @@ def start_split(self): self.split_objects[1]) self.split_index = 1 self.update_msg( - 'splitting: %s' % ('-'.join(str(x) for x in self.split_path))) + 'splitting: %s' % ('-'.join(str(x) for x in self.split_path))) s = copy.deepcopy(self.viewer.state) s.layers['seg'].visible = False @@ -396,7 +713,7 @@ def add_split(self, s): if len(self.split_objects) < 2: self.split_objects.append(s.selected_values['seg'].value) self.update_msg( - 'split: %s' % (':'.join(str(x) for x in self.split_objects))) + 'split: %s' % (':'.join(str(x) for x in self.split_objects))) if len(self.split_objects) == 2: self.start_split() @@ -412,5 +729,5 @@ def mark_bad(self): else: self.bad.add(frozenset(sids)) - self.update_msg('marked bad: %r' % (sids, )) + self.update_msg('marked bad: %r' % (sids,)) self.next_batch() From 6bc285708d4230536b20832eb8961ac3da15e0f1 Mon Sep 17 00:00:00 2001 From: moenigin <15244500+moenigin@users.noreply.github.com> Date: Mon, 21 Aug 2023 07:43:51 +0200 Subject: [PATCH 2/8] pyink run on proofreading.py --- ffn/utils/proofreading.py | 1345 +++++++++++++++++++------------------ 1 file changed, 674 insertions(+), 671 deletions(-) diff --git a/ffn/utils/proofreading.py b/ffn/utils/proofreading.py index 33c230e..e9f03de 100644 --- a/ffn/utils/proofreading.py +++ b/ffn/utils/proofreading.py @@ -28,706 +28,709 @@ class Base: - """Base class for proofreading workflows. - - To use, define a subclass overriding the `set_init_state` method to provide - initial Neuroglancer settings. The segmentation volume needs to be called - `seg`. - """ - - def __init__(self, - num_to_prefetch: int = 10, - locations: Optional[Iterable[tuple[int, int, int]]] = None, - objects: Optional[ - Union[dict[str, Any], Iterable[int]]] = None): - """Initializes the Base class for proofreading. - - Args: - num_to_prefetch: Number of items to prefetch. - locations: List of xyz coordinates corresponding to object locations. - objects: Object IDs or a dictionary mapping layer names to object IDs. - """ - self.viewer = neuroglancer.Viewer() - self.num_to_prefetch = num_to_prefetch - - self.managed_layers = set(['seg']) - self.todo = [] # items are maps from layer name to lists of segment IDs - if objects is not None: - self._set_todo(objects) - - self.index = 0 - self.batch = 1 - self.apply_equivs = False - - if locations is not None: - self.locations = list(locations) - assert len(self.todo) == len(locations) - else: - self.locations = None - - self.set_init_state() - - def _set_todo(self, objects: Union[list[str, Any], Iterable[int]]) -> None: - """Private method to set the todo list.""" - for o in objects: - if isinstance(o, collections.abc.Mapping): - self.todo.append(o) - self.managed_layers |= set(o.keys()) - elif isinstance(o, collections.abc.Iterable): - self.todo.append({'seg': o}) - else: - self.todo.append({'seg': [o]}) - - def set_init_state(self) -> None: - """Sets the initial state for Neuroglancer. - Subclasses should override this method. - """ - raise NotImplementedError() - - def update_msg(self, msg: str) -> None: - """Updates the status message in Neuroglancer viewer.""" - with self.viewer.config_state.txn() as s: - s.status_messages['status'] = msg - - def update_segments(self, - segments: list[int], - loc: Optional[tuple[int, int, int]] = None, - layer: str = 'seg') -> None: - """Updates segments in Neuroglancer viewer. - - Args: - segments: List of segment IDs to update. - loc: 3D coordinates to set the viewer to. - layer: Layer name in Neuroglancer to be updated. - """ - s = copy.deepcopy(self.viewer.state) - l = s.layers[layer] - l.segments = segments - - if not self.apply_equivs: - l.equivalences.clear() - else: - l.equivalences.clear() - for a in self.todo[self.index:self.index + self.batch]: - a = [aa[layer] for aa in a] - l.equivalences.union(*a) - - if loc is not None: - s.position = loc - - self.viewer.set_state(s) - - def toggle_equiv(self) -> None: - """Toggle the apply equivalence flag and update the batch.""" - self.apply_equivs = not self.apply_equivs - self.update_batch() - - def batch_dec(self) -> None: - """Decrease the batch size by half and update the batch.""" - self.batch //= 2 - self.batch = max(self.batch, 1) - self.update_batch() - - def batch_inc(self) -> None: - """Increase the batch size by double and update the batch.""" - self.batch *= 2 - self.update_batch() - - def next_batch(self) -> None: - """Move to the next batch of segments and update the viewer.""" - self.index += self.batch - self.index = min(self.index, len(self.todo) - 1) - self.prefetch() - self.update_batch() - - def prev_batch(self) -> None: - """Move to the previous batch of segments and update the viewer.""" - self.index -= self.batch - self.index = max(0, self.index) - self.update_batch() - - def list_segments(self, - index: Optional[int] = None, - layer: str = 'seg') -> list[int]: - """Get a list of segment IDs for a given index and layer. - - Args: - index: Index of segments to list. - layer: Layer name to list the segments from. - - Returns: - List of segment IDs. - """ - if index is None: - index = self.index - return list( - set( - itertools.chain( - *[x[layer] for x in self.todo[index:index + self.batch]]))) - - def custom_msg(self) -> str: - """Generate a custom message for the current state. - - Returns: - A custom message string. - """ - return '' - - def update_batch(self) -> None: - """Update the segments displayed in the viewer based on batch settings.""" - if self.batch == 1 and self.locations is not None: - loc = self.locations[self.index] - else: - loc = None - - for layer in self.managed_layers: - self.update_segments(self.list_segments(layer=layer), loc, - layer=layer) - self.update_msg('index:%d/%d batch:%d %s' % - (self.index, len(self.todo), self.batch, - self.custom_msg())) - - def prefetch(self) -> None: - """Pre-fetch the segments for smoother navigation in the viewer.""" - prefetch_states = [] - for i in range(self.num_to_prefetch): - idx = self.index + (i + 1) * self.batch - if idx >= len(self.todo): - break - prefetch_state = copy.deepcopy(self.viewer.state) - for layer in self.managed_layers: - prefetch_state.layers[layer].segments = self.list_segments( - idx, layer=layer) - prefetch_state.layout = '3d' - if self.locations is not None: - prefetch_state.position = self.locations[idx] - - prefetch_states.append(prefetch_state) - - with self.viewer.config_state.txn() as s: - s.prefetch = [ - neuroglancer.PrefetchState(state=prefetch_state, priority=-i) - for i, prefetch_state in enumerate(prefetch_states) - ] - - def get_cursor_position(self, - action_state: neuroglancer.viewer_config_state.ActionState): - """Return coordinates of the cursor position from a neuroglancer action state - - Args: - action_state : Neuroglancer action state - - Returns: - (x, y, z) cursor position + """Base class for proofreading workflows. + + To use, define a subclass overriding the `set_init_state` method to provide + initial Neuroglancer settings. The segmentation volume needs to be called + `seg`. """ - try: - cursor_position = [int(x) for x in - action_state.mouse_voxel_coordinates] - except Exception: - self.update_msg('cursor misplaced') - return - return cursor_position + def __init__( + self, + num_to_prefetch: int = 10, + locations: Optional[Iterable[tuple[int, int, int]]] = None, + objects: Optional[Union[dict[str, Any], Iterable[int]]] = None, + ): + """Initializes the Base class for proofreading. + + Args: + num_to_prefetch: Number of items to prefetch. + locations: List of xyz coordinates corresponding to object locations. + objects: Object IDs or a dictionary mapping layer names to object IDs. + """ + self.viewer = neuroglancer.Viewer() + self.num_to_prefetch = num_to_prefetch + + self.managed_layers = set(["seg"]) + self.todo = [] # items are maps from layer name to lists of segment IDs + if objects is not None: + self._set_todo(objects) + + self.index = 0 + self.batch = 1 + self.apply_equivs = False + + if locations is not None: + self.locations = list(locations) + assert len(self.todo) == len(locations) + else: + self.locations = None + + self.set_init_state() + + def _set_todo(self, objects: Union[list[str, Any], Iterable[int]]) -> None: + """Private method to set the todo list.""" + for o in objects: + if isinstance(o, collections.abc.Mapping): + self.todo.append(o) + self.managed_layers |= set(o.keys()) + elif isinstance(o, collections.abc.Iterable): + self.todo.append({"seg": o}) + else: + self.todo.append({"seg": [o]}) + + def set_init_state(self) -> None: + """Sets the initial state for Neuroglancer. + Subclasses should override this method. + """ + raise NotImplementedError() + + def update_msg(self, msg: str) -> None: + """Updates the status message in Neuroglancer viewer.""" + with self.viewer.config_state.txn() as s: + s.status_messages["status"] = msg + + def update_segments( + self, + segments: list[int], + loc: Optional[tuple[int, int, int]] = None, + layer: str = "seg", + ) -> None: + """Updates segments in Neuroglancer viewer. + + Args: + segments: List of segment IDs to update. + loc: 3D coordinates to set the viewer to. + layer: Layer name in Neuroglancer to be updated. + """ + s = copy.deepcopy(self.viewer.state) + l = s.layers[layer] + l.segments = segments + + if not self.apply_equivs: + l.equivalences.clear() + else: + l.equivalences.clear() + for a in self.todo[self.index : self.index + self.batch]: + a = [aa[layer] for aa in a] + l.equivalences.union(*a) + + if loc is not None: + s.position = loc + + self.viewer.set_state(s) + + def toggle_equiv(self) -> None: + """Toggle the apply equivalence flag and update the batch.""" + self.apply_equivs = not self.apply_equivs + self.update_batch() + + def batch_dec(self) -> None: + """Decrease the batch size by half and update the batch.""" + self.batch //= 2 + self.batch = max(self.batch, 1) + self.update_batch() + + def batch_inc(self) -> None: + """Increase the batch size by double and update the batch.""" + self.batch *= 2 + self.update_batch() + + def next_batch(self) -> None: + """Move to the next batch of segments and update the viewer.""" + self.index += self.batch + self.index = min(self.index, len(self.todo) - 1) + self.prefetch() + self.update_batch() + + def prev_batch(self) -> None: + """Move to the previous batch of segments and update the viewer.""" + self.index -= self.batch + self.index = max(0, self.index) + self.update_batch() + + def list_segments( + self, index: Optional[int] = None, layer: str = "seg" + ) -> list[int]: + """Get a list of segment IDs for a given index and layer. + + Args: + index: Index of segments to list. + layer: Layer name to list the segments from. + + Returns: + List of segment IDs. + """ + if index is None: + index = self.index + return list( + set( + itertools.chain( + *[x[layer] for x in self.todo[index : index + self.batch]] + ) + ) + ) + + def custom_msg(self) -> str: + """Generate a custom message for the current state. + + Returns: + A custom message string. + """ + return "" + + def update_batch(self) -> None: + """Update the segments displayed in the viewer based on batch settings.""" + if self.batch == 1 and self.locations is not None: + loc = self.locations[self.index] + else: + loc = None + + for layer in self.managed_layers: + self.update_segments(self.list_segments(layer=layer), loc, layer=layer) + self.update_msg( + "index:%d/%d batch:%d %s" + % (self.index, len(self.todo), self.batch, self.custom_msg()) + ) + + def prefetch(self) -> None: + """Pre-fetch the segments for smoother navigation in the viewer.""" + prefetch_states = [] + for i in range(self.num_to_prefetch): + idx = self.index + (i + 1) * self.batch + if idx >= len(self.todo): + break + prefetch_state = copy.deepcopy(self.viewer.state) + for layer in self.managed_layers: + prefetch_state.layers[layer].segments = self.list_segments( + idx, layer=layer + ) + prefetch_state.layout = "3d" + if self.locations is not None: + prefetch_state.position = self.locations[idx] + + prefetch_states.append(prefetch_state) + + with self.viewer.config_state.txn() as s: + s.prefetch = [ + neuroglancer.PrefetchState(state=prefetch_state, priority=-i) + for i, prefetch_state in enumerate(prefetch_states) + ] + + def get_cursor_position( + self, action_state: neuroglancer.viewer_config_state.ActionState + ): + """Return coordinates of the cursor position from a neuroglancer action state + + Args: + action_state : Neuroglancer action state + + Returns: + (x, y, z) cursor position + """ + try: + cursor_position = [int(x) for x in action_state.mouse_voxel_coordinates] + except Exception: + self.update_msg("cursor misplaced") + return + + return cursor_position class ObjectReview(Base): - """Base class for rapid (agglomerated) object review. - - To achieve good throughput, smaller objects are usually reviewed in - batches. - """ - - def __init__(self, - objects: Iterable, - bad: list, - num_to_prefetch: int = 10, - locations: Optional[Iterable[tuple[int, int, int]]] = None): - """Constructor. - - Args: - objects: iterable of object IDs or iterables of object IDs. In the latter - case it is assumed that every iterable forms a group of objects to be - agglomerated together. - bad: set in which to store objects or groups of objects flagged as bad. - num_to_prefetch: number of items from `objects` to prefetch - locations: iterable of xyz tuples of length len(objects). If specified, - the cursor will be automatically moved to the location corresponding to - the current object if batch == 1. - """ - super().__init__( - num_to_prefetch=num_to_prefetch, locations=locations, - objects=objects) - self.bad = bad - - self.set_keybindings() - - self.update_batch() - - def set_keybindings(self) -> None: - """Set key bindings for the viewer.""" - self.viewer.actions.add('next-batch', lambda s: self.next_batch()) - self.viewer.actions.add('prev-batch', lambda s: self.prev_batch()) - self.viewer.actions.add('dec-batch', lambda s: self.batch_dec()) - self.viewer.actions.add('inc-batch', lambda s: self.batch_inc()) - self.viewer.actions.add('mark-bad', lambda s: self.mark_bad()) - self.viewer.actions.add('mark-removed-bad', - lambda s: self.mark_removed_bad()) - self.viewer.actions.add('toggle-equiv', lambda s: self.toggle_equiv()) - - with self.viewer.config_state.txn() as s: - s.input_event_bindings.viewer['keyj'] = 'next-batch' - s.input_event_bindings.viewer['keyk'] = 'prev-batch' - s.input_event_bindings.viewer['keym'] = 'dec-batch' - s.input_event_bindings.viewer['keyp'] = 'inc-batch' - s.input_event_bindings.viewer['keyv'] = 'mark-bad' - s.input_event_bindings.viewer['keyt'] = 'toggle-equiv' - s.input_event_bindings.viewer['keya'] = 'mark-removed-bad' - - def custom_msg(self) -> str: - """Construct a custom message for the current state. - - Returns: - A formatted message indicating the number of bad objects. - """ - return 'num_bad: %d' % len(self.bad) - - def mark_bad(self) -> None: - """Mark an object or group of objects as bad. + """Base class for rapid (agglomerated) object review. - If the batch size is greater than 1, the user is prompted to decrease - the batch size. + To achieve good throughput, smaller objects are usually reviewed in + batches. """ - if self.batch > 1: - self.update_msg('decrease batch to 1 to mark objects bad') - return - - sids = self.todo[self.index]['seg'] - if len(sids) == 1: - self.bad.add(list(sids)[0]) - else: - self.bad.add(frozenset(sids)) - - self.update_msg('marked bad: %r' % (sids,)) - self.next_batch() - - def mark_removed_bad(self) -> None: - """From the set of original objects mark those bad that are not displayed. - Update the message with the IDs of the newly marked bad objects. - """ - original = set(self.list_segments()) - new_bad = original - set(self.viewer.state.layers['seg'].segments) - if new_bad: - self.bad |= new_bad - self.update_msg('marked bad: %r' % (new_bad,)) + + def __init__( + self, + objects: Iterable, + bad: list, + num_to_prefetch: int = 10, + locations: Optional[Iterable[tuple[int, int, int]]] = None, + ): + """Constructor. + + Args: + objects: iterable of object IDs or iterables of object IDs. In the latter + case it is assumed that every iterable forms a group of objects to be + agglomerated together. + bad: set in which to store objects or groups of objects flagged as bad. + num_to_prefetch: number of items from `objects` to prefetch + locations: iterable of xyz tuples of length len(objects). If specified, + the cursor will be automatically moved to the location corresponding to + the current object if batch == 1. + """ + super().__init__( + num_to_prefetch=num_to_prefetch, locations=locations, objects=objects + ) + self.bad = bad + + self.set_keybindings() + + self.update_batch() + + def set_keybindings(self) -> None: + """Set key bindings for the viewer.""" + self.viewer.actions.add("next-batch", lambda s: self.next_batch()) + self.viewer.actions.add("prev-batch", lambda s: self.prev_batch()) + self.viewer.actions.add("dec-batch", lambda s: self.batch_dec()) + self.viewer.actions.add("inc-batch", lambda s: self.batch_inc()) + self.viewer.actions.add("mark-bad", lambda s: self.mark_bad()) + self.viewer.actions.add("mark-removed-bad", lambda s: self.mark_removed_bad()) + self.viewer.actions.add("toggle-equiv", lambda s: self.toggle_equiv()) + + with self.viewer.config_state.txn() as s: + s.input_event_bindings.viewer["keyj"] = "next-batch" + s.input_event_bindings.viewer["keyk"] = "prev-batch" + s.input_event_bindings.viewer["keym"] = "dec-batch" + s.input_event_bindings.viewer["keyp"] = "inc-batch" + s.input_event_bindings.viewer["keyv"] = "mark-bad" + s.input_event_bindings.viewer["keyt"] = "toggle-equiv" + s.input_event_bindings.viewer["keya"] = "mark-removed-bad" + + def custom_msg(self) -> str: + """Construct a custom message for the current state. + + Returns: + A formatted message indicating the number of bad objects. + """ + return "num_bad: %d" % len(self.bad) + + def mark_bad(self) -> None: + """Mark an object or group of objects as bad. + + If the batch size is greater than 1, the user is prompted to decrease + the batch size. + """ + if self.batch > 1: + self.update_msg("decrease batch to 1 to mark objects bad") + return + + sids = self.todo[self.index]["seg"] + if len(sids) == 1: + self.bad.add(list(sids)[0]) + else: + self.bad.add(frozenset(sids)) + + self.update_msg("marked bad: %r" % (sids,)) + self.next_batch() + + def mark_removed_bad(self) -> None: + """From the set of original objects mark those bad that are not displayed. + Update the message with the IDs of the newly marked bad objects. + """ + original = set(self.list_segments()) + new_bad = original - set(self.viewer.state.layers["seg"].segments) + if new_bad: + self.bad |= new_bad + self.update_msg("marked bad: %r" % (new_bad,)) class ObjectReviewStoreLocation(ObjectReview): - """Class to mark and store locations of errors in the segmentation - - To mark a merger, move the cursor to a spot of the false merger and press 'w'. - Then, move the cursor to a spot within the object that should belong to a - separate object and press 'shift + W'. Yellow point annotations indicate the - merger. For split errors, proceed in similar manner but press 'd' and - 'shift + D', which will display blue annotations. - Marked locations can be deleted either by pressing 'ctrl + Z' (to delete the - last marked location) or by hovering the cursor over one of the point - annotations and pressing 'ctrl + v'. - - Attributes: - seg_error_coordinates: A mapping of annotation identifier substrings to - error coordinate pairs. - Example: {'m0': [[x1,y1,z1],[x2,y2,z2]], 's0':[[x1,y1,z1],[x2,y2,z2]], ...} - - Keys starting with 'm' indicate merge errors. - - Keys starting with 's' indicate split errors. - temp_coord_list: Temporary storage for coordinates. - """ - - def __init__(self, - objects: list, - bad: list, - seg_error_coordinates: Optional[ - list[str, list[list[int]]]] = {}, - load_annotations: bool = False) -> None: - """Initialize the ObjectReviewStoreLocation class. - - Args: - objects: A list of objects. - bad: A list of bad objects or markers. - seg_error_coordinates: A dictionary of error coordinates. - load_annotations: A flag to indicate if annotations should be loaded. + """Class to mark and store locations of errors in the segmentation + + To mark a merger, move the cursor to a spot of the false merger and press 'w'. + Then, move the cursor to a spot within the object that should belong to a + separate object and press 'shift + W'. Yellow point annotations indicate the + merger. For split errors, proceed in similar manner but press 'd' and + 'shift + D', which will display blue annotations. + Marked locations can be deleted either by pressing 'ctrl + Z' (to delete the + last marked location) or by hovering the cursor over one of the point + annotations and pressing 'ctrl + v'. + + Attributes: + seg_error_coordinates: A mapping of annotation identifier substrings to + error coordinate pairs. + Example: {'m0': [[x1,y1,z1],[x2,y2,z2]], 's0':[[x1,y1,z1],[x2,y2,z2]], ...} + - Keys starting with 'm' indicate merge errors. + - Keys starting with 's' indicate split errors. + temp_coord_list: Temporary storage for coordinates. """ - super(ObjectReviewStoreLocation, self).__init__(objects, bad) - self.seg_error_coordinates = seg_error_coordinates - if load_annotations and seg_error_coordinates: - for k, v in seg_error_coordinates.items(): - self.annotate_error_locations(v, k) - self.temp_coord_list = [] - - def set_keybindings(self) -> None: - """Set key bindings for the viewer.""" - super().set_keybindings() - self.viewer.actions.add('merge0', - lambda s: self.store_error_location(s, index=0, - mode='merger')) - self.viewer.actions.add('merge1', - lambda s: self.store_error_location(s, index=1, - mode='merger')) - self.viewer.actions.add('split0', - lambda s: self.store_error_location(s, index=0, - mode='split')) - self.viewer.actions.add('split1', - lambda s: self.store_error_location(s, index=1, - mode='split')) - self.viewer.actions.add('delete_from_annotation', - self.delete_location_from_annotation) - self.viewer.actions.add('delete_last_entry', - lambda s: self.delete_last_location()) - - with self.viewer.config_state.txn() as s: - s.input_event_bindings.viewer['keyw'] = 'merge0' - s.input_event_bindings.viewer['shift+keyw'] = 'merge1' - s.input_event_bindings.viewer['keyd'] = 'split0' - s.input_event_bindings.viewer['shift+keyd'] = 'split1' - s.input_event_bindings.viewer[ - 'control+keyv'] = 'delete_from_annotation' - s.input_event_bindings.viewer[ - 'control+keyz'] = 'delete_last_entry' - - def get_id(self, mode: str) -> str: - """Generate a unique identifier for an error based on its type. - - Args: - mode: Error type, either 'merge' or 'split'. - - Returns: - A unique identifier string. - """ - id_ = mode[0] - if any(self.seg_error_coordinates): - counter = int( - max([x[1:] for x in self.seg_error_coordinates.keys()])) + 1 - else: - counter = 0 - id_ = id_ + str(counter) - return id_ - - def store_error_location(self, - action_state: neuroglancer.viewer_config_state.ActionState, - mode: str, - index: int = 0) -> None: - """Store error locations. - - Args: - action_state: State of the viewer during the action. - mode: Type of the error ('merger' or 'split'). - index: Indicates if it's the first or second coordinate (0 or 1). - """ - location = self.get_cursor_position(action_state) - if location is None: - return - - if index == 1 and not self.temp_coord_list: - self.update_msg('You have not entered a first coord yet') - return - - if index == 0 and self.temp_coord_list: - self.temp_coord_list = [] - self.temp_coord_list.append(location) - - if index == 1: - if self.temp_coord_list[0] == self.temp_coord_list[1]: - self.update_msg( - 'You entered the same coordinate twice. Try again!') + def __init__( + self, + objects: list, + bad: list, + seg_error_coordinates: Optional[list[str, list[list[int]]]] = {}, + load_annotations: bool = False, + ) -> None: + """Initialize the ObjectReviewStoreLocation class. + + Args: + objects: A list of objects. + bad: A list of bad objects or markers. + seg_error_coordinates: A dictionary of error coordinates. + load_annotations: A flag to indicate if annotations should be loaded. + """ + super(ObjectReviewStoreLocation, self).__init__(objects, bad) + self.seg_error_coordinates = seg_error_coordinates + if load_annotations and seg_error_coordinates: + for k, v in seg_error_coordinates.items(): + self.annotate_error_locations(v, k) self.temp_coord_list = [] - return - - identifier = self.get_id(mode=mode) - self.seg_error_coordinates.update( - {identifier: self.temp_coord_list}) - self.annotate_error_locations(self.temp_coord_list, identifier) - self.temp_coord_list = [] - - def annotate_error_locations(self, - coordinate_lst: list[list[int]], - id_: str) -> None: - """Annotate the error locations in the viewer. - - Args: - coordinate_lst: List of coordinates to be annotated. - id_: Unique identifier for the error. - """ - for i, coord in enumerate(coordinate_lst): - annotation_id = id_ + f'_{i}' - self.mk_point_annotation(coord, annotation_id) - - def mk_point_annotation(self, - coordinate: list[int], - annotation_id: str) -> None: - """Create a point annotation in the viewer. - - Args: - coordinate: 3D coordinate of the annotation point. - annotation_id: Unique identifier for the annotation. - """ - if annotation_id.startswith('m'): - color = '#fae505' - else: - color = '#05f2fa' - annotation = neuroglancer.PointAnnotation(id=annotation_id, - point=coordinate, - props=[color]) - with self.viewer.txn() as s: - annotations = s.layers['annotation'].annotations - annotations.append(annotation) - - def get_annotation_id(self, - action_state: neuroglancer.viewer_config_state.ActionState) -> \ - Optional[str]: - """Retrieve the ID of a selected annotation. - - Args: - action_state: neuroglancer.viewer_config_state.ActionState. - - Returns: - The selected object's ID or None if retrieval fails. - """ - try: - selection_state = action_state.selected_values[ - 'annotation'].to_json() - selected_object = selection_state['annotationId'] - except Exception: - self.update_msg('Could not retrieve annotation id') - return - - return selected_object - - def delete_location_from_annotation(self, - action_state: neuroglancer.viewer_config_state.ActionState) -> None: - """Delete the error location pair associated with the annotation at the cursor position - - Args: - action_state: State of the viewer during the action. - """ - id_ = self.get_annotation_id(action_state) - if id_ is None: - return - target_key = id_[:2] - del self.seg_error_coordinates[target_key] - - to_remove = [target_key + '_0', target_key + '_1'] - self.delete_annotation(to_remove) - - def delete_annotation(self, to_remove: list[str]) -> None: - """Delete specified annotations from the viewer. - - Args: - to_remove: list of annotation IDs to be removed. - """ - with self.viewer.txn() as s: - annotations = s.layers['annotation'].annotations - annotations = [a for a in annotations if a.id not in to_remove] - s.layers['annotation'].annotations = annotations - - def delete_last_location(self): - """Delete the last error location pair tagged.""" - last_key = list(self.seg_error_coordinates.keys())[-1] - del self.seg_error_coordinates[last_key] - - to_remove = [last_key + '_0', last_key + '_1'] - self.delete_annotation(to_remove) + def set_keybindings(self) -> None: + """Set key bindings for the viewer.""" + super().set_keybindings() + self.viewer.actions.add( + "merge0", lambda s: self.store_error_location(s, index=0, mode="merger") + ) + self.viewer.actions.add( + "merge1", lambda s: self.store_error_location(s, index=1, mode="merger") + ) + self.viewer.actions.add( + "split0", lambda s: self.store_error_location(s, index=0, mode="split") + ) + self.viewer.actions.add( + "split1", lambda s: self.store_error_location(s, index=1, mode="split") + ) + self.viewer.actions.add( + "delete_from_annotation", self.delete_location_from_annotation + ) + self.viewer.actions.add( + "delete_last_entry", lambda s: self.delete_last_location() + ) + + with self.viewer.config_state.txn() as s: + s.input_event_bindings.viewer["keyw"] = "merge0" + s.input_event_bindings.viewer["shift+keyw"] = "merge1" + s.input_event_bindings.viewer["keyd"] = "split0" + s.input_event_bindings.viewer["shift+keyd"] = "split1" + s.input_event_bindings.viewer["control+keyv"] = "delete_from_annotation" + s.input_event_bindings.viewer["control+keyz"] = "delete_last_entry" + + def get_id(self, mode: str) -> str: + """Generate a unique identifier for an error based on its type. + + Args: + mode: Error type, either 'merge' or 'split'. + + Returns: + A unique identifier string. + """ + id_ = mode[0] + if any(self.seg_error_coordinates): + counter = int(max([x[1:] for x in self.seg_error_coordinates.keys()])) + 1 + else: + counter = 0 + id_ = id_ + str(counter) + return id_ + + def store_error_location( + self, + action_state: neuroglancer.viewer_config_state.ActionState, + mode: str, + index: int = 0, + ) -> None: + """Store error locations. + + Args: + action_state: State of the viewer during the action. + mode: Type of the error ('merger' or 'split'). + index: Indicates if it's the first or second coordinate (0 or 1). + """ + location = self.get_cursor_position(action_state) + if location is None: + return + + if index == 1 and not self.temp_coord_list: + self.update_msg("You have not entered a first coord yet") + return + + if index == 0 and self.temp_coord_list: + self.temp_coord_list = [] + + self.temp_coord_list.append(location) + + if index == 1: + if self.temp_coord_list[0] == self.temp_coord_list[1]: + self.update_msg("You entered the same coordinate twice. Try again!") + self.temp_coord_list = [] + return + + identifier = self.get_id(mode=mode) + self.seg_error_coordinates.update({identifier: self.temp_coord_list}) + self.annotate_error_locations(self.temp_coord_list, identifier) + self.temp_coord_list = [] + + def annotate_error_locations( + self, coordinate_lst: list[list[int]], id_: str + ) -> None: + """Annotate the error locations in the viewer. + + Args: + coordinate_lst: List of coordinates to be annotated. + id_: Unique identifier for the error. + """ + for i, coord in enumerate(coordinate_lst): + annotation_id = id_ + f"_{i}" + self.mk_point_annotation(coord, annotation_id) + + def mk_point_annotation(self, coordinate: list[int], annotation_id: str) -> None: + """Create a point annotation in the viewer. + + Args: + coordinate: 3D coordinate of the annotation point. + annotation_id: Unique identifier for the annotation. + """ + if annotation_id.startswith("m"): + color = "#fae505" + else: + color = "#05f2fa" + annotation = neuroglancer.PointAnnotation( + id=annotation_id, point=coordinate, props=[color] + ) + with self.viewer.txn() as s: + annotations = s.layers["annotation"].annotations + annotations.append(annotation) + + def get_annotation_id( + self, action_state: neuroglancer.viewer_config_state.ActionState + ) -> Optional[str]: + """Retrieve the ID of a selected annotation. + + Args: + action_state: neuroglancer.viewer_config_state.ActionState. + + Returns: + The selected object's ID or None if retrieval fails. + """ + try: + selection_state = action_state.selected_values["annotation"].to_json() + selected_object = selection_state["annotationId"] + except Exception: + self.update_msg("Could not retrieve annotation id") + return + + return selected_object + + def delete_location_from_annotation( + self, action_state: neuroglancer.viewer_config_state.ActionState + ) -> None: + """Delete the error location pair associated with the annotation at the cursor position + + Args: + action_state: State of the viewer during the action. + """ + id_ = self.get_annotation_id(action_state) + if id_ is None: + return + + target_key = id_[:2] + del self.seg_error_coordinates[target_key] + + to_remove = [target_key + "_0", target_key + "_1"] + self.delete_annotation(to_remove) + + def delete_annotation(self, to_remove: list[str]) -> None: + """Delete specified annotations from the viewer. + + Args: + to_remove: list of annotation IDs to be removed. + """ + with self.viewer.txn() as s: + annotations = s.layers["annotation"].annotations + annotations = [a for a in annotations if a.id not in to_remove] + s.layers["annotation"].annotations = annotations + + def delete_last_location(self): + """Delete the last error location pair tagged.""" + last_key = list(self.seg_error_coordinates.keys())[-1] + del self.seg_error_coordinates[last_key] + + to_remove = [last_key + "_0", last_key + "_1"] + self.delete_annotation(to_remove) class ObjectClassification(Base): - """Base class for object classification.""" + """Base class for object classification.""" - def __init__(self, objects, key_to_class, num_to_prefetch=10, - locations=None): - """Constructor. + def __init__(self, objects, key_to_class, num_to_prefetch=10, locations=None): + """Constructor. - Args: - objects: iterable of object IDs - key_to_class: dict mapping keys to class labels - num_to_prefetch: number of `objects` to prefetch - """ - super().__init__( - num_to_prefetch=num_to_prefetch, locations=locations, - objects=objects) + Args: + objects: iterable of object IDs + key_to_class: dict mapping keys to class labels + num_to_prefetch: number of `objects` to prefetch + """ + super().__init__( + num_to_prefetch=num_to_prefetch, locations=locations, objects=objects + ) - self.results = defaultdict(set) # class -> ids + self.results = defaultdict(set) # class -> ids - self.viewer.actions.add('mr-next-batch', lambda s: self.next_batch()) - self.viewer.actions.add('mr-prev-batch', lambda s: self.prev_batch()) - self.viewer.actions.add('unclassify', lambda s: self.classify(None)) + self.viewer.actions.add("mr-next-batch", lambda s: self.next_batch()) + self.viewer.actions.add("mr-prev-batch", lambda s: self.prev_batch()) + self.viewer.actions.add("unclassify", lambda s: self.classify(None)) - for key, cls in key_to_class.items(): - self.viewer.actions.add( - 'classify-%s' % cls, lambda s, cls=cls: self.classify(cls)) + for key, cls in key_to_class.items(): + self.viewer.actions.add( + "classify-%s" % cls, lambda s, cls=cls: self.classify(cls) + ) - with self.viewer.config_state.txn() as s: - for key, cls in key_to_class.items(): - s.input_event_bindings.viewer[ - 'key%s' % key] = 'classify-%s' % cls + with self.viewer.config_state.txn() as s: + for key, cls in key_to_class.items(): + s.input_event_bindings.viewer["key%s" % key] = "classify-%s" % cls - # Navigation without classification. - s.input_event_bindings.viewer['keyj'] = 'mr-next-batch' - s.input_event_bindings.viewer['keyk'] = 'mr-prev-batch' - s.input_event_bindings.viewer['keyv'] = 'unclassify' + # Navigation without classification. + s.input_event_bindings.viewer["keyj"] = "mr-next-batch" + s.input_event_bindings.viewer["keyk"] = "mr-prev-batch" + s.input_event_bindings.viewer["keyv"] = "unclassify" - self.update_batch() + self.update_batch() - def custom_msg(self): - return ' '.join('%s:%d' % (k, len(v)) for k, v in self.results.items()) + def custom_msg(self): + return " ".join("%s:%d" % (k, len(v)) for k, v in self.results.items()) - def classify(self, cls): - sid = list(self.todo[self.index]['seg'])[0] - for v in self.results.values(): - v -= set([sid]) + def classify(self, cls): + sid = list(self.todo[self.index]["seg"])[0] + for v in self.results.values(): + v -= set([sid]) - if cls is not None: - self.results[cls].add(sid) + if cls is not None: + self.results[cls].add(sid) - self.next_batch() + self.next_batch() class GraphUpdater(Base): - """Base class for agglomeration graph modification. - - Usage: - * splitting - 1) select merged objects (start with a supervoxel, then press 'c') - 2) shift-click on two supervoxels that should be separated; a new layer - will be displayed showing the supervoxels along the shortest path - between selected objects - 3) use '[' and ']' to restrict the path so that the displayed supervoxels - are not wrongly merged - 4) press 's' to remove the edge next to the last shown one from the - agglomeration graph - - * merging - 1) select segments to be merged - 2) press 'm' - - Press 'c' to add any supervoxels connected to the ones currently displayed - (according to the current state of the agglomeraton graph). - """ - - def __init__(self, graph, objects, bad, num_to_prefetch=0): - super().__init__(objects=objects, num_to_prefetch=num_to_prefetch) - self.graph = graph - self.split_objects = [] - self.split_path = [] - self.split_index = 1 - self.sem = threading.Semaphore() - - self.bad = bad - self.viewer.actions.add('add-ccs', lambda s: self.add_ccs()) - self.viewer.actions.add('clear-splits', lambda s: self.clear_splits()) - self.viewer.actions.add('add-split', self.add_split) - self.viewer.actions.add('accept-split', lambda s: self.accept_split()) - self.viewer.actions.add('split-inc', lambda s: self.inc_split()) - self.viewer.actions.add('split-dec', lambda s: self.dec_split()) - self.viewer.actions.add('merge-segments', - lambda s: self.merge_segments()) - self.viewer.actions.add('mark-bad', lambda s: self.mark_bad()) - self.viewer.actions.add('next-batch', lambda s: self.next_batch()) - self.viewer.actions.add('prev-batch', lambda s: self.prev_batch()) - - with self.viewer.config_state.txn() as s: - s.input_event_bindings.viewer['keyj'] = 'next-batch' - s.input_event_bindings.viewer['keyk'] = 'prev-batch' - s.input_event_bindings.viewer['keyc'] = 'add-ccs' - s.input_event_bindings.viewer['keya'] = 'clear-splits' - s.input_event_bindings.viewer['keym'] = 'merge-segments' - s.input_event_bindings.viewer['shift+bracketleft'] = 'split-dec' - s.input_event_bindings.viewer['shift+bracketright'] = 'split-inc' - s.input_event_bindings.viewer['keys'] = 'accept-split' - s.input_event_bindings.data_view['shift+mousedown0'] = 'add-split' - s.input_event_bindings.viewer['keyv'] = 'mark-bad' - - with self.viewer.txn() as s: - s.layers['split'] = neuroglancer.SegmentationLayer( - source=s.layers['seg'].source) - s.layers['split'].visible = False - - def merge_segments(self): - sids = [sid for sid in self.viewer.state.layers['seg'].segments if - sid > 0] - self.graph.add_edges_from(zip(sids, sids[1:])) - - def update_split(self): - s = copy.deepcopy(self.viewer.state) - s.layers['split'].segments = list(self.split_path)[:self.split_index] - self.viewer.set_state(s) - - def inc_split(self): - self.split_index = min(len(self.split_path), self.split_index + 1) - self.update_split() - - def dec_split(self): - self.split_index = max(1, self.split_index - 1) - self.update_split() - - def add_ccs(self): - if self.sem.acquire(blocking=False): - curr = set(self.viewer.state.layers['seg'].segments) - for sid in self.viewer.state.layers['seg'].segments: - if sid in self.graph: - curr |= set(nx.node_connected_component(self.graph, sid)) - - self.update_segments(curr) - self.sem.release() - - def accept_split(self): - edge = self.split_path[self.split_index - 1:self.split_index + 1] - if len(edge) < 2: - return - - self.graph.remove_edge(edge[0], edge[1]) - self.clear_splits() - - def clear_splits(self): - self.split_objects = [] - self.update_msg('splits cleared') - - s = copy.deepcopy(self.viewer.state) - s.layers['split'].visible = False - s.layers['seg'].visible = True - self.viewer.set_state(s) - - def start_split(self): - self.split_path = nx.shortest_path(self.graph, self.split_objects[0], - self.split_objects[1]) - self.split_index = 1 - self.update_msg( - 'splitting: %s' % ('-'.join(str(x) for x in self.split_path))) - - s = copy.deepcopy(self.viewer.state) - s.layers['seg'].visible = False - s.layers['split'].visible = True - self.viewer.set_state(s) - self.update_split() - - def add_split(self, s): - if len(self.split_objects) < 2: - self.split_objects.append(s.selected_values['seg'].value) - self.update_msg( - 'split: %s' % (':'.join(str(x) for x in self.split_objects))) - - if len(self.split_objects) == 2: - self.start_split() - - def mark_bad(self): - if self.batch > 1: - self.update_msg('decrease batch to 1 to mark objects bad') - return - - sids = self.todo[self.index]['seg'] - if len(sids) == 1: - self.bad.add(list(sids)[0]) - else: - self.bad.add(frozenset(sids)) - - self.update_msg('marked bad: %r' % (sids,)) - self.next_batch() + """Base class for agglomeration graph modification. + + Usage: + * splitting + 1) select merged objects (start with a supervoxel, then press 'c') + 2) shift-click on two supervoxels that should be separated; a new layer + will be displayed showing the supervoxels along the shortest path + between selected objects + 3) use '[' and ']' to restrict the path so that the displayed supervoxels + are not wrongly merged + 4) press 's' to remove the edge next to the last shown one from the + agglomeration graph + + * merging + 1) select segments to be merged + 2) press 'm' + + Press 'c' to add any supervoxels connected to the ones currently displayed + (according to the current state of the agglomeraton graph). + """ + + def __init__(self, graph, objects, bad, num_to_prefetch=0): + super().__init__(objects=objects, num_to_prefetch=num_to_prefetch) + self.graph = graph + self.split_objects = [] + self.split_path = [] + self.split_index = 1 + self.sem = threading.Semaphore() + + self.bad = bad + self.viewer.actions.add("add-ccs", lambda s: self.add_ccs()) + self.viewer.actions.add("clear-splits", lambda s: self.clear_splits()) + self.viewer.actions.add("add-split", self.add_split) + self.viewer.actions.add("accept-split", lambda s: self.accept_split()) + self.viewer.actions.add("split-inc", lambda s: self.inc_split()) + self.viewer.actions.add("split-dec", lambda s: self.dec_split()) + self.viewer.actions.add("merge-segments", lambda s: self.merge_segments()) + self.viewer.actions.add("mark-bad", lambda s: self.mark_bad()) + self.viewer.actions.add("next-batch", lambda s: self.next_batch()) + self.viewer.actions.add("prev-batch", lambda s: self.prev_batch()) + + with self.viewer.config_state.txn() as s: + s.input_event_bindings.viewer["keyj"] = "next-batch" + s.input_event_bindings.viewer["keyk"] = "prev-batch" + s.input_event_bindings.viewer["keyc"] = "add-ccs" + s.input_event_bindings.viewer["keya"] = "clear-splits" + s.input_event_bindings.viewer["keym"] = "merge-segments" + s.input_event_bindings.viewer["shift+bracketleft"] = "split-dec" + s.input_event_bindings.viewer["shift+bracketright"] = "split-inc" + s.input_event_bindings.viewer["keys"] = "accept-split" + s.input_event_bindings.data_view["shift+mousedown0"] = "add-split" + s.input_event_bindings.viewer["keyv"] = "mark-bad" + + with self.viewer.txn() as s: + s.layers["split"] = neuroglancer.SegmentationLayer( + source=s.layers["seg"].source + ) + s.layers["split"].visible = False + + def merge_segments(self): + sids = [sid for sid in self.viewer.state.layers["seg"].segments if sid > 0] + self.graph.add_edges_from(zip(sids, sids[1:])) + + def update_split(self): + s = copy.deepcopy(self.viewer.state) + s.layers["split"].segments = list(self.split_path)[: self.split_index] + self.viewer.set_state(s) + + def inc_split(self): + self.split_index = min(len(self.split_path), self.split_index + 1) + self.update_split() + + def dec_split(self): + self.split_index = max(1, self.split_index - 1) + self.update_split() + + def add_ccs(self): + if self.sem.acquire(blocking=False): + curr = set(self.viewer.state.layers["seg"].segments) + for sid in self.viewer.state.layers["seg"].segments: + if sid in self.graph: + curr |= set(nx.node_connected_component(self.graph, sid)) + + self.update_segments(curr) + self.sem.release() + + def accept_split(self): + edge = self.split_path[self.split_index - 1 : self.split_index + 1] + if len(edge) < 2: + return + + self.graph.remove_edge(edge[0], edge[1]) + self.clear_splits() + + def clear_splits(self): + self.split_objects = [] + self.update_msg("splits cleared") + + s = copy.deepcopy(self.viewer.state) + s.layers["split"].visible = False + s.layers["seg"].visible = True + self.viewer.set_state(s) + + def start_split(self): + self.split_path = nx.shortest_path( + self.graph, self.split_objects[0], self.split_objects[1] + ) + self.split_index = 1 + self.update_msg("splitting: %s" % ("-".join(str(x) for x in self.split_path))) + + s = copy.deepcopy(self.viewer.state) + s.layers["seg"].visible = False + s.layers["split"].visible = True + self.viewer.set_state(s) + self.update_split() + + def add_split(self, s): + if len(self.split_objects) < 2: + self.split_objects.append(s.selected_values["seg"].value) + self.update_msg("split: %s" % (":".join(str(x) for x in self.split_objects))) + + if len(self.split_objects) == 2: + self.start_split() + + def mark_bad(self): + if self.batch > 1: + self.update_msg("decrease batch to 1 to mark objects bad") + return + + sids = self.todo[self.index]["seg"] + if len(sids) == 1: + self.bad.add(list(sids)[0]) + else: + self.bad.add(frozenset(sids)) + + self.update_msg("marked bad: %r" % (sids,)) + self.next_batch() From e5d48b9ff963489af772bf91db9812c57446c934 Mon Sep 17 00:00:00 2001 From: moenigin <15244500+moenigin@users.noreply.github.com> Date: Tue, 22 Aug 2023 09:10:22 +0200 Subject: [PATCH 3/8] Update proofreading.py reformat with pyink from toml --- ffn/utils/proofreading.py | 1346 +++++++++++++++++++------------------ 1 file changed, 675 insertions(+), 671 deletions(-) diff --git a/ffn/utils/proofreading.py b/ffn/utils/proofreading.py index e9f03de..dc7e283 100644 --- a/ffn/utils/proofreading.py +++ b/ffn/utils/proofreading.py @@ -28,709 +28,713 @@ class Base: - """Base class for proofreading workflows. - - To use, define a subclass overriding the `set_init_state` method to provide - initial Neuroglancer settings. The segmentation volume needs to be called - `seg`. + """Base class for proofreading workflows. + + To use, define a subclass overriding the `set_init_state` method to provide + initial Neuroglancer settings. The segmentation volume needs to be called + `seg`. + """ + + def __init__( + self, + num_to_prefetch: int = 10, + locations: Optional[Iterable[tuple[int, int, int]]] = None, + objects: Optional[Union[dict[str, Any], Iterable[int]]] = None, + ): + """Initializes the Base class for proofreading. + + Args: + num_to_prefetch: Number of items to prefetch. + locations: List of xyz coordinates corresponding to object locations. + objects: Object IDs or a dictionary mapping layer names to object IDs. """ - - def __init__( - self, - num_to_prefetch: int = 10, - locations: Optional[Iterable[tuple[int, int, int]]] = None, - objects: Optional[Union[dict[str, Any], Iterable[int]]] = None, - ): - """Initializes the Base class for proofreading. - - Args: - num_to_prefetch: Number of items to prefetch. - locations: List of xyz coordinates corresponding to object locations. - objects: Object IDs or a dictionary mapping layer names to object IDs. - """ - self.viewer = neuroglancer.Viewer() - self.num_to_prefetch = num_to_prefetch - - self.managed_layers = set(["seg"]) - self.todo = [] # items are maps from layer name to lists of segment IDs - if objects is not None: - self._set_todo(objects) - - self.index = 0 - self.batch = 1 - self.apply_equivs = False - - if locations is not None: - self.locations = list(locations) - assert len(self.todo) == len(locations) - else: - self.locations = None - - self.set_init_state() - - def _set_todo(self, objects: Union[list[str, Any], Iterable[int]]) -> None: - """Private method to set the todo list.""" - for o in objects: - if isinstance(o, collections.abc.Mapping): - self.todo.append(o) - self.managed_layers |= set(o.keys()) - elif isinstance(o, collections.abc.Iterable): - self.todo.append({"seg": o}) - else: - self.todo.append({"seg": [o]}) - - def set_init_state(self) -> None: - """Sets the initial state for Neuroglancer. - Subclasses should override this method. - """ - raise NotImplementedError() - - def update_msg(self, msg: str) -> None: - """Updates the status message in Neuroglancer viewer.""" - with self.viewer.config_state.txn() as s: - s.status_messages["status"] = msg - - def update_segments( - self, - segments: list[int], - loc: Optional[tuple[int, int, int]] = None, - layer: str = "seg", - ) -> None: - """Updates segments in Neuroglancer viewer. - - Args: - segments: List of segment IDs to update. - loc: 3D coordinates to set the viewer to. - layer: Layer name in Neuroglancer to be updated. - """ - s = copy.deepcopy(self.viewer.state) - l = s.layers[layer] - l.segments = segments - - if not self.apply_equivs: - l.equivalences.clear() - else: - l.equivalences.clear() - for a in self.todo[self.index : self.index + self.batch]: - a = [aa[layer] for aa in a] - l.equivalences.union(*a) - - if loc is not None: - s.position = loc - - self.viewer.set_state(s) - - def toggle_equiv(self) -> None: - """Toggle the apply equivalence flag and update the batch.""" - self.apply_equivs = not self.apply_equivs - self.update_batch() - - def batch_dec(self) -> None: - """Decrease the batch size by half and update the batch.""" - self.batch //= 2 - self.batch = max(self.batch, 1) - self.update_batch() - - def batch_inc(self) -> None: - """Increase the batch size by double and update the batch.""" - self.batch *= 2 - self.update_batch() - - def next_batch(self) -> None: - """Move to the next batch of segments and update the viewer.""" - self.index += self.batch - self.index = min(self.index, len(self.todo) - 1) - self.prefetch() - self.update_batch() - - def prev_batch(self) -> None: - """Move to the previous batch of segments and update the viewer.""" - self.index -= self.batch - self.index = max(0, self.index) - self.update_batch() - - def list_segments( - self, index: Optional[int] = None, layer: str = "seg" - ) -> list[int]: - """Get a list of segment IDs for a given index and layer. - - Args: - index: Index of segments to list. - layer: Layer name to list the segments from. - - Returns: - List of segment IDs. - """ - if index is None: - index = self.index - return list( - set( - itertools.chain( - *[x[layer] for x in self.todo[index : index + self.batch]] - ) + self.viewer = neuroglancer.Viewer() + self.num_to_prefetch = num_to_prefetch + + self.managed_layers = set(["seg"]) + self.todo = [] # items are maps from layer name to lists of segment IDs + if objects is not None: + self._set_todo(objects) + + self.index = 0 + self.batch = 1 + self.apply_equivs = False + + if locations is not None: + self.locations = list(locations) + assert len(self.todo) == len(locations) + else: + self.locations = None + + self.set_init_state() + + def _set_todo(self, objects: Union[list[str, Any], Iterable[int]]) -> None: + """Private method to set the todo list.""" + for o in objects: + if isinstance(o, collections.abc.Mapping): + self.todo.append(o) + self.managed_layers |= set(o.keys()) + elif isinstance(o, collections.abc.Iterable): + self.todo.append({"seg": o}) + else: + self.todo.append({"seg": [o]}) + + def set_init_state(self) -> None: + """Sets the initial state for Neuroglancer. + Subclasses should override this method. + """ + raise NotImplementedError() + + def update_msg(self, msg: str) -> None: + """Updates the status message in Neuroglancer viewer.""" + with self.viewer.config_state.txn() as s: + s.status_messages["status"] = msg + + def update_segments( + self, + segments: list[int], + loc: Optional[tuple[int, int, int]] = None, + layer: str = "seg", + ) -> None: + """Updates segments in Neuroglancer viewer. + + Args: + segments: List of segment IDs to update. + loc: 3D coordinates to set the viewer to. + layer: Layer name in Neuroglancer to be updated. + """ + s = copy.deepcopy(self.viewer.state) + l = s.layers[layer] + l.segments = segments + + if not self.apply_equivs: + l.equivalences.clear() + else: + l.equivalences.clear() + for a in self.todo[self.index : self.index + self.batch]: + a = [aa[layer] for aa in a] + l.equivalences.union(*a) + + if loc is not None: + s.position = loc + + self.viewer.set_state(s) + + def toggle_equiv(self) -> None: + """Toggle the apply equivalence flag and update the batch.""" + self.apply_equivs = not self.apply_equivs + self.update_batch() + + def batch_dec(self) -> None: + """Decrease the batch size by half and update the batch.""" + self.batch //= 2 + self.batch = max(self.batch, 1) + self.update_batch() + + def batch_inc(self) -> None: + """Increase the batch size by double and update the batch.""" + self.batch *= 2 + self.update_batch() + + def next_batch(self) -> None: + """Move to the next batch of segments and update the viewer.""" + self.index += self.batch + self.index = min(self.index, len(self.todo) - 1) + self.prefetch() + self.update_batch() + + def prev_batch(self) -> None: + """Move to the previous batch of segments and update the viewer.""" + self.index -= self.batch + self.index = max(0, self.index) + self.update_batch() + + def list_segments( + self, index: Optional[int] = None, layer: str = "seg" + ) -> list[int]: + """Get a list of segment IDs for a given index and layer. + + Args: + index: Index of segments to list. + layer: Layer name to list the segments from. + + Returns: + List of segment IDs. + """ + if index is None: + index = self.index + return list( + set( + itertools.chain( + *[x[layer] for x in self.todo[index : index + self.batch]] ) ) + ) + + def custom_msg(self) -> str: + """Generate a custom message for the current state. - def custom_msg(self) -> str: - """Generate a custom message for the current state. - - Returns: - A custom message string. - """ - return "" - - def update_batch(self) -> None: - """Update the segments displayed in the viewer based on batch settings.""" - if self.batch == 1 and self.locations is not None: - loc = self.locations[self.index] - else: - loc = None - - for layer in self.managed_layers: - self.update_segments(self.list_segments(layer=layer), loc, layer=layer) - self.update_msg( - "index:%d/%d batch:%d %s" - % (self.index, len(self.todo), self.batch, self.custom_msg()) + Returns: + A custom message string. + """ + return "" + + def update_batch(self) -> None: + """Update the segments displayed in the viewer based on batch settings.""" + if self.batch == 1 and self.locations is not None: + loc = self.locations[self.index] + else: + loc = None + + for layer in self.managed_layers: + self.update_segments(self.list_segments(layer=layer), loc, layer=layer) + self.update_msg( + "index:%d/%d batch:%d %s" + % (self.index, len(self.todo), self.batch, self.custom_msg()) + ) + + def prefetch(self) -> None: + """Pre-fetch the segments for smoother navigation in the viewer.""" + prefetch_states = [] + for i in range(self.num_to_prefetch): + idx = self.index + (i + 1) * self.batch + if idx >= len(self.todo): + break + prefetch_state = copy.deepcopy(self.viewer.state) + for layer in self.managed_layers: + prefetch_state.layers[layer].segments = self.list_segments( + idx, layer=layer ) + prefetch_state.layout = "3d" + if self.locations is not None: + prefetch_state.position = self.locations[idx] - def prefetch(self) -> None: - """Pre-fetch the segments for smoother navigation in the viewer.""" - prefetch_states = [] - for i in range(self.num_to_prefetch): - idx = self.index + (i + 1) * self.batch - if idx >= len(self.todo): - break - prefetch_state = copy.deepcopy(self.viewer.state) - for layer in self.managed_layers: - prefetch_state.layers[layer].segments = self.list_segments( - idx, layer=layer - ) - prefetch_state.layout = "3d" - if self.locations is not None: - prefetch_state.position = self.locations[idx] - - prefetch_states.append(prefetch_state) - - with self.viewer.config_state.txn() as s: - s.prefetch = [ - neuroglancer.PrefetchState(state=prefetch_state, priority=-i) - for i, prefetch_state in enumerate(prefetch_states) - ] - - def get_cursor_position( - self, action_state: neuroglancer.viewer_config_state.ActionState - ): - """Return coordinates of the cursor position from a neuroglancer action state - - Args: - action_state : Neuroglancer action state - - Returns: - (x, y, z) cursor position - """ - try: - cursor_position = [int(x) for x in action_state.mouse_voxel_coordinates] - except Exception: - self.update_msg("cursor misplaced") - return - - return cursor_position + prefetch_states.append(prefetch_state) + with self.viewer.config_state.txn() as s: + s.prefetch = [ + neuroglancer.PrefetchState(state=prefetch_state, priority=-i) + for i, prefetch_state in enumerate(prefetch_states) + ] -class ObjectReview(Base): - """Base class for rapid (agglomerated) object review. + def get_cursor_position( + self, action_state: neuroglancer.viewer_config_state.ActionState + ): + """Return coordinates of the cursor position from a neuroglancer action state + + Args: + action_state : Neuroglancer action state - To achieve good throughput, smaller objects are usually reviewed in - batches. + Returns: + (x, y, z) cursor position """ + try: + cursor_position = [int(x) for x in action_state.mouse_voxel_coordinates] + except Exception: + self.update_msg("cursor misplaced") + return - def __init__( - self, - objects: Iterable, - bad: list, - num_to_prefetch: int = 10, - locations: Optional[Iterable[tuple[int, int, int]]] = None, - ): - """Constructor. - - Args: - objects: iterable of object IDs or iterables of object IDs. In the latter - case it is assumed that every iterable forms a group of objects to be - agglomerated together. - bad: set in which to store objects or groups of objects flagged as bad. - num_to_prefetch: number of items from `objects` to prefetch - locations: iterable of xyz tuples of length len(objects). If specified, - the cursor will be automatically moved to the location corresponding to - the current object if batch == 1. - """ - super().__init__( - num_to_prefetch=num_to_prefetch, locations=locations, objects=objects - ) - self.bad = bad - - self.set_keybindings() - - self.update_batch() - - def set_keybindings(self) -> None: - """Set key bindings for the viewer.""" - self.viewer.actions.add("next-batch", lambda s: self.next_batch()) - self.viewer.actions.add("prev-batch", lambda s: self.prev_batch()) - self.viewer.actions.add("dec-batch", lambda s: self.batch_dec()) - self.viewer.actions.add("inc-batch", lambda s: self.batch_inc()) - self.viewer.actions.add("mark-bad", lambda s: self.mark_bad()) - self.viewer.actions.add("mark-removed-bad", lambda s: self.mark_removed_bad()) - self.viewer.actions.add("toggle-equiv", lambda s: self.toggle_equiv()) - - with self.viewer.config_state.txn() as s: - s.input_event_bindings.viewer["keyj"] = "next-batch" - s.input_event_bindings.viewer["keyk"] = "prev-batch" - s.input_event_bindings.viewer["keym"] = "dec-batch" - s.input_event_bindings.viewer["keyp"] = "inc-batch" - s.input_event_bindings.viewer["keyv"] = "mark-bad" - s.input_event_bindings.viewer["keyt"] = "toggle-equiv" - s.input_event_bindings.viewer["keya"] = "mark-removed-bad" - - def custom_msg(self) -> str: - """Construct a custom message for the current state. - - Returns: - A formatted message indicating the number of bad objects. - """ - return "num_bad: %d" % len(self.bad) - - def mark_bad(self) -> None: - """Mark an object or group of objects as bad. - - If the batch size is greater than 1, the user is prompted to decrease - the batch size. - """ - if self.batch > 1: - self.update_msg("decrease batch to 1 to mark objects bad") - return - - sids = self.todo[self.index]["seg"] - if len(sids) == 1: - self.bad.add(list(sids)[0]) - else: - self.bad.add(frozenset(sids)) - - self.update_msg("marked bad: %r" % (sids,)) - self.next_batch() - - def mark_removed_bad(self) -> None: - """From the set of original objects mark those bad that are not displayed. - Update the message with the IDs of the newly marked bad objects. - """ - original = set(self.list_segments()) - new_bad = original - set(self.viewer.state.layers["seg"].segments) - if new_bad: - self.bad |= new_bad - self.update_msg("marked bad: %r" % (new_bad,)) + return cursor_position + + +class ObjectReview(Base): + """Base class for rapid (agglomerated) object review. + + To achieve good throughput, smaller objects are usually reviewed in + batches. + """ + + def __init__( + self, + objects: Iterable, + bad: list, + num_to_prefetch: int = 10, + locations: Optional[Iterable[tuple[int, int, int]]] = None, + ): + """Constructor. + + Args: + objects: iterable of object IDs or iterables of object IDs. In the latter + case it is assumed that every iterable forms a group of objects to be + agglomerated together. + bad: set in which to store objects or groups of objects flagged as bad. + num_to_prefetch: number of items from `objects` to prefetch + locations: iterable of xyz tuples of length len(objects). If specified, + the cursor will be automatically moved to the location corresponding to + the current object if batch == 1. + """ + super().__init__( + num_to_prefetch=num_to_prefetch, locations=locations, objects=objects + ) + self.bad = bad + + self.set_keybindings() + + self.update_batch() + + def set_keybindings(self) -> None: + """Set key bindings for the viewer.""" + self.viewer.actions.add("next-batch", lambda s: self.next_batch()) + self.viewer.actions.add("prev-batch", lambda s: self.prev_batch()) + self.viewer.actions.add("dec-batch", lambda s: self.batch_dec()) + self.viewer.actions.add("inc-batch", lambda s: self.batch_inc()) + self.viewer.actions.add("mark-bad", lambda s: self.mark_bad()) + self.viewer.actions.add( + "mark-removed-bad", lambda s: self.mark_removed_bad() + ) + self.viewer.actions.add("toggle-equiv", lambda s: self.toggle_equiv()) + + with self.viewer.config_state.txn() as s: + s.input_event_bindings.viewer["keyj"] = "next-batch" + s.input_event_bindings.viewer["keyk"] = "prev-batch" + s.input_event_bindings.viewer["keym"] = "dec-batch" + s.input_event_bindings.viewer["keyp"] = "inc-batch" + s.input_event_bindings.viewer["keyv"] = "mark-bad" + s.input_event_bindings.viewer["keyt"] = "toggle-equiv" + s.input_event_bindings.viewer["keya"] = "mark-removed-bad" + + def custom_msg(self) -> str: + """Construct a custom message for the current state. + + Returns: + A formatted message indicating the number of bad objects. + """ + return "num_bad: %d" % len(self.bad) + + def mark_bad(self) -> None: + """Mark an object or group of objects as bad. + + If the batch size is greater than 1, the user is prompted to decrease + the batch size. + """ + if self.batch > 1: + self.update_msg("decrease batch to 1 to mark objects bad") + return + + sids = self.todo[self.index]["seg"] + if len(sids) == 1: + self.bad.add(list(sids)[0]) + else: + self.bad.add(frozenset(sids)) + + self.update_msg("marked bad: %r" % (sids,)) + self.next_batch() + + def mark_removed_bad(self) -> None: + """From the set of original objects mark those bad that are not displayed. + Update the message with the IDs of the newly marked bad objects. + """ + original = set(self.list_segments()) + new_bad = original - set(self.viewer.state.layers["seg"].segments) + if new_bad: + self.bad |= new_bad + self.update_msg("marked bad: %r" % (new_bad,)) class ObjectReviewStoreLocation(ObjectReview): - """Class to mark and store locations of errors in the segmentation - - To mark a merger, move the cursor to a spot of the false merger and press 'w'. - Then, move the cursor to a spot within the object that should belong to a - separate object and press 'shift + W'. Yellow point annotations indicate the - merger. For split errors, proceed in similar manner but press 'd' and - 'shift + D', which will display blue annotations. - Marked locations can be deleted either by pressing 'ctrl + Z' (to delete the - last marked location) or by hovering the cursor over one of the point - annotations and pressing 'ctrl + v'. - - Attributes: - seg_error_coordinates: A mapping of annotation identifier substrings to - error coordinate pairs. - Example: {'m0': [[x1,y1,z1],[x2,y2,z2]], 's0':[[x1,y1,z1],[x2,y2,z2]], ...} - - Keys starting with 'm' indicate merge errors. - - Keys starting with 's' indicate split errors. - temp_coord_list: Temporary storage for coordinates. + """Class to mark and store locations of errors in the segmentation + + To mark a merger, move the cursor to a spot of the false merger and press 'w'. + Then, move the cursor to a spot within the object that should belong to a + separate object and press 'shift + W'. Yellow point annotations indicate the + merger. For split errors, proceed in similar manner but press 'd' and + 'shift + D', which will display blue annotations. + Marked locations can be deleted either by pressing 'ctrl + Z' (to delete the + last marked location) or by hovering the cursor over one of the point + annotations and pressing 'ctrl + v'. + + Attributes: + seg_error_coordinates: A mapping of annotation identifier substrings to + error coordinate pairs. + Example: {'m0': [[x1,y1,z1],[x2,y2,z2]], 's0':[[x1,y1,z1],[x2,y2,z2]], ...} + - Keys starting with 'm' indicate merge errors. + - Keys starting with 's' indicate split errors. + temp_coord_list: Temporary storage for coordinates. + """ + + def __init__( + self, + objects: list, + bad: list, + seg_error_coordinates: Optional[list[str, list[list[int]]]] = {}, + load_annotations: bool = False, + ) -> None: + """Initialize the ObjectReviewStoreLocation class. + + Args: + objects: A list of objects. + bad: A list of bad objects or markers. + seg_error_coordinates: A dictionary of error coordinates. + load_annotations: A flag to indicate if annotations should be loaded. + """ + super(ObjectReviewStoreLocation, self).__init__(objects, bad) + self.seg_error_coordinates = seg_error_coordinates + if load_annotations and seg_error_coordinates: + for k, v in seg_error_coordinates.items(): + self.annotate_error_locations(v, k) + self.temp_coord_list = [] + + def set_keybindings(self) -> None: + """Set key bindings for the viewer.""" + super().set_keybindings() + self.viewer.actions.add( + "merge0", lambda s: self.store_error_location(s, idx=0, mode="merger") + ) + self.viewer.actions.add( + "merge1", lambda s: self.store_error_location(s, idx=1, mode="merger") + ) + self.viewer.actions.add( + "split0", lambda s: self.store_error_location(s, idx=0, mode="split") + ) + self.viewer.actions.add( + "split1", lambda s: self.store_error_location(s, idx=1, mode="split") + ) + self.viewer.actions.add( + "delete_from_annotation", self.delete_location_from_annotation + ) + self.viewer.actions.add( + "delete_last_entry", lambda s: self.delete_last_location() + ) + + with self.viewer.config_state.txn() as s: + s.input_event_bindings.viewer["keyw"] = "merge0" + s.input_event_bindings.viewer["shift+keyw"] = "merge1" + s.input_event_bindings.viewer["keyd"] = "split0" + s.input_event_bindings.viewer["shift+keyd"] = "split1" + s.input_event_bindings.viewer["control+keyv"] = "delete_from_annotation" + s.input_event_bindings.viewer["control+keyz"] = "delete_last_entry" + + def get_id(self, mode: str) -> str: + """Generate a unique identifier for an error based on its type. + + Args: + mode: Error type, either 'merge' or 'split'. + + Returns: + A unique identifier string. """ + id_ = mode[0] + if any(self.seg_error_coordinates): + counter = int(max([x[1:] for x in self.seg_error_coordinates.keys()])) + 1 + else: + counter = 0 + id_ = id_ + str(counter) + return id_ + + def store_error_location( + self, + action_state: neuroglancer.viewer_config_state.ActionState, + mode: str, + idx: int = 0, + ) -> None: + """Store error locations. + + Args: + action_state: State of the viewer during the action. + mode: Type of the error ('merger' or 'split'). + idx: Indicates if it's the first or second coordinate (0 or 1). + """ + location = self.get_cursor_position(action_state) + if location is None: + return + + if idx == 1 and not self.temp_coord_list: + self.update_msg("You have not entered a first coord yet") + return + + if idx == 0 and self.temp_coord_list: + self.temp_coord_list = [] - def __init__( - self, - objects: list, - bad: list, - seg_error_coordinates: Optional[list[str, list[list[int]]]] = {}, - load_annotations: bool = False, - ) -> None: - """Initialize the ObjectReviewStoreLocation class. - - Args: - objects: A list of objects. - bad: A list of bad objects or markers. - seg_error_coordinates: A dictionary of error coordinates. - load_annotations: A flag to indicate if annotations should be loaded. - """ - super(ObjectReviewStoreLocation, self).__init__(objects, bad) - self.seg_error_coordinates = seg_error_coordinates - if load_annotations and seg_error_coordinates: - for k, v in seg_error_coordinates.items(): - self.annotate_error_locations(v, k) + self.temp_coord_list.append(location) + + if idx == 1: + if self.temp_coord_list[0] == self.temp_coord_list[1]: + self.update_msg("You entered the same coordinate twice. Try again!") self.temp_coord_list = [] + return - def set_keybindings(self) -> None: - """Set key bindings for the viewer.""" - super().set_keybindings() - self.viewer.actions.add( - "merge0", lambda s: self.store_error_location(s, index=0, mode="merger") - ) - self.viewer.actions.add( - "merge1", lambda s: self.store_error_location(s, index=1, mode="merger") - ) - self.viewer.actions.add( - "split0", lambda s: self.store_error_location(s, index=0, mode="split") - ) - self.viewer.actions.add( - "split1", lambda s: self.store_error_location(s, index=1, mode="split") - ) - self.viewer.actions.add( - "delete_from_annotation", self.delete_location_from_annotation - ) - self.viewer.actions.add( - "delete_last_entry", lambda s: self.delete_last_location() - ) + identifier = self.get_id(mode=mode) + self.seg_error_coordinates.update({identifier: self.temp_coord_list}) + self.annotate_error_locations(self.temp_coord_list, identifier) + self.temp_coord_list = [] - with self.viewer.config_state.txn() as s: - s.input_event_bindings.viewer["keyw"] = "merge0" - s.input_event_bindings.viewer["shift+keyw"] = "merge1" - s.input_event_bindings.viewer["keyd"] = "split0" - s.input_event_bindings.viewer["shift+keyd"] = "split1" - s.input_event_bindings.viewer["control+keyv"] = "delete_from_annotation" - s.input_event_bindings.viewer["control+keyz"] = "delete_last_entry" - - def get_id(self, mode: str) -> str: - """Generate a unique identifier for an error based on its type. - - Args: - mode: Error type, either 'merge' or 'split'. - - Returns: - A unique identifier string. - """ - id_ = mode[0] - if any(self.seg_error_coordinates): - counter = int(max([x[1:] for x in self.seg_error_coordinates.keys()])) + 1 - else: - counter = 0 - id_ = id_ + str(counter) - return id_ - - def store_error_location( - self, - action_state: neuroglancer.viewer_config_state.ActionState, - mode: str, - index: int = 0, - ) -> None: - """Store error locations. - - Args: - action_state: State of the viewer during the action. - mode: Type of the error ('merger' or 'split'). - index: Indicates if it's the first or second coordinate (0 or 1). - """ - location = self.get_cursor_position(action_state) - if location is None: - return - - if index == 1 and not self.temp_coord_list: - self.update_msg("You have not entered a first coord yet") - return - - if index == 0 and self.temp_coord_list: - self.temp_coord_list = [] - - self.temp_coord_list.append(location) - - if index == 1: - if self.temp_coord_list[0] == self.temp_coord_list[1]: - self.update_msg("You entered the same coordinate twice. Try again!") - self.temp_coord_list = [] - return - - identifier = self.get_id(mode=mode) - self.seg_error_coordinates.update({identifier: self.temp_coord_list}) - self.annotate_error_locations(self.temp_coord_list, identifier) - self.temp_coord_list = [] - - def annotate_error_locations( - self, coordinate_lst: list[list[int]], id_: str - ) -> None: - """Annotate the error locations in the viewer. - - Args: - coordinate_lst: List of coordinates to be annotated. - id_: Unique identifier for the error. - """ - for i, coord in enumerate(coordinate_lst): - annotation_id = id_ + f"_{i}" - self.mk_point_annotation(coord, annotation_id) - - def mk_point_annotation(self, coordinate: list[int], annotation_id: str) -> None: - """Create a point annotation in the viewer. - - Args: - coordinate: 3D coordinate of the annotation point. - annotation_id: Unique identifier for the annotation. - """ - if annotation_id.startswith("m"): - color = "#fae505" - else: - color = "#05f2fa" - annotation = neuroglancer.PointAnnotation( - id=annotation_id, point=coordinate, props=[color] - ) - with self.viewer.txn() as s: - annotations = s.layers["annotation"].annotations - annotations.append(annotation) - - def get_annotation_id( - self, action_state: neuroglancer.viewer_config_state.ActionState - ) -> Optional[str]: - """Retrieve the ID of a selected annotation. - - Args: - action_state: neuroglancer.viewer_config_state.ActionState. - - Returns: - The selected object's ID or None if retrieval fails. - """ - try: - selection_state = action_state.selected_values["annotation"].to_json() - selected_object = selection_state["annotationId"] - except Exception: - self.update_msg("Could not retrieve annotation id") - return - - return selected_object - - def delete_location_from_annotation( - self, action_state: neuroglancer.viewer_config_state.ActionState - ) -> None: - """Delete the error location pair associated with the annotation at the cursor position - - Args: - action_state: State of the viewer during the action. - """ - id_ = self.get_annotation_id(action_state) - if id_ is None: - return - - target_key = id_[:2] - del self.seg_error_coordinates[target_key] - - to_remove = [target_key + "_0", target_key + "_1"] - self.delete_annotation(to_remove) - - def delete_annotation(self, to_remove: list[str]) -> None: - """Delete specified annotations from the viewer. - - Args: - to_remove: list of annotation IDs to be removed. - """ - with self.viewer.txn() as s: - annotations = s.layers["annotation"].annotations - annotations = [a for a in annotations if a.id not in to_remove] - s.layers["annotation"].annotations = annotations - - def delete_last_location(self): - """Delete the last error location pair tagged.""" - last_key = list(self.seg_error_coordinates.keys())[-1] - del self.seg_error_coordinates[last_key] - - to_remove = [last_key + "_0", last_key + "_1"] - self.delete_annotation(to_remove) + def annotate_error_locations( + self, coordinate_lst: list[list[int]], id_: str + ) -> None: + """Annotate the error locations in the viewer. + + Args: + coordinate_lst: List of coordinates to be annotated. + id_: Unique identifier for the error. + """ + for i, coord in enumerate(coordinate_lst): + annotation_id = id_ + f"_{i}" + self.mk_point_annotation(coord, annotation_id) + + def mk_point_annotation( + self, coordinate: list[int], annotation_id: str + ) -> None: + """Create a point annotation in the viewer. + + Args: + coordinate: 3D coordinate of the annotation point. + annotation_id: Unique identifier for the annotation. + """ + if annotation_id.startswith("m"): + color = "#fae505" + else: + color = "#05f2fa" + annotation = neuroglancer.PointAnnotation( + id=annotation_id, point=coordinate, props=[color] + ) + with self.viewer.txn() as s: + annotations = s.layers["annotation"].annotations + annotations.append(annotation) + + def get_annotation_id( + self, action_state: neuroglancer.viewer_config_state.ActionState + ) -> Optional[str]: + """Retrieve the ID of a selected annotation. + + Args: + action_state: neuroglancer.viewer_config_state.ActionState. + + Returns: + The selected object's ID or None if retrieval fails. + """ + try: + selection_state = action_state.selected_values["annotation"].to_json() + selected_object = selection_state["annotationId"] + except Exception: + self.update_msg("Could not retrieve annotation id") + return + + return selected_object + + def delete_location_from_annotation( + self, action_state: neuroglancer.viewer_config_state.ActionState + ) -> None: + """Delete the error location pair associated with the annotation at the cursor position + + Args: + action_state: State of the viewer during the action. + """ + id_ = self.get_annotation_id(action_state) + if id_ is None: + return + + target_key = id_[:2] + del self.seg_error_coordinates[target_key] + + to_remove = [target_key + "_0", target_key + "_1"] + self.delete_annotation(to_remove) + + def delete_annotation(self, to_remove: list[str]) -> None: + """Delete specified annotations from the viewer. + + Args: + to_remove: list of annotation IDs to be removed. + """ + with self.viewer.txn() as s: + annotations = s.layers["annotation"].annotations + annotations = [a for a in annotations if a.id not in to_remove] + s.layers["annotation"].annotations = annotations + + def delete_last_location(self): + """Delete the last error location pair tagged.""" + last_key = list(self.seg_error_coordinates.keys())[-1] + del self.seg_error_coordinates[last_key] + + to_remove = [last_key + "_0", last_key + "_1"] + self.delete_annotation(to_remove) class ObjectClassification(Base): - """Base class for object classification.""" - - def __init__(self, objects, key_to_class, num_to_prefetch=10, locations=None): - """Constructor. - - Args: - objects: iterable of object IDs - key_to_class: dict mapping keys to class labels - num_to_prefetch: number of `objects` to prefetch - """ - super().__init__( - num_to_prefetch=num_to_prefetch, locations=locations, objects=objects - ) + """Base class for object classification.""" - self.results = defaultdict(set) # class -> ids + def __init__(self, objects, key_to_class, num_to_prefetch=10, locations=None): + """Constructor. - self.viewer.actions.add("mr-next-batch", lambda s: self.next_batch()) - self.viewer.actions.add("mr-prev-batch", lambda s: self.prev_batch()) - self.viewer.actions.add("unclassify", lambda s: self.classify(None)) + Args: + objects: iterable of object IDs + key_to_class: dict mapping keys to class labels + num_to_prefetch: number of `objects` to prefetch + """ + super().__init__( + num_to_prefetch=num_to_prefetch, locations=locations, objects=objects + ) - for key, cls in key_to_class.items(): - self.viewer.actions.add( - "classify-%s" % cls, lambda s, cls=cls: self.classify(cls) - ) + self.results = defaultdict(set) # class -> ids - with self.viewer.config_state.txn() as s: - for key, cls in key_to_class.items(): - s.input_event_bindings.viewer["key%s" % key] = "classify-%s" % cls + self.viewer.actions.add("mr-next-batch", lambda s: self.next_batch()) + self.viewer.actions.add("mr-prev-batch", lambda s: self.prev_batch()) + self.viewer.actions.add("unclassify", lambda s: self.classify(None)) - # Navigation without classification. - s.input_event_bindings.viewer["keyj"] = "mr-next-batch" - s.input_event_bindings.viewer["keyk"] = "mr-prev-batch" - s.input_event_bindings.viewer["keyv"] = "unclassify" + for key, cls in key_to_class.items(): + self.viewer.actions.add( + "classify-%s" % cls, lambda s, cls=cls: self.classify(cls) + ) - self.update_batch() + with self.viewer.config_state.txn() as s: + for key, cls in key_to_class.items(): + s.input_event_bindings.viewer["key%s" % key] = "classify-%s" % cls - def custom_msg(self): - return " ".join("%s:%d" % (k, len(v)) for k, v in self.results.items()) + # Navigation without classification. + s.input_event_bindings.viewer["keyj"] = "mr-next-batch" + s.input_event_bindings.viewer["keyk"] = "mr-prev-batch" + s.input_event_bindings.viewer["keyv"] = "unclassify" - def classify(self, cls): - sid = list(self.todo[self.index]["seg"])[0] - for v in self.results.values(): - v -= set([sid]) + self.update_batch() - if cls is not None: - self.results[cls].add(sid) + def custom_msg(self): + return " ".join("%s:%d" % (k, len(v)) for k, v in self.results.items()) - self.next_batch() + def classify(self, cls): + sid = list(self.todo[self.index]["seg"])[0] + for v in self.results.values(): + v -= set([sid]) + if cls is not None: + self.results[cls].add(sid) -class GraphUpdater(Base): - """Base class for agglomeration graph modification. - - Usage: - * splitting - 1) select merged objects (start with a supervoxel, then press 'c') - 2) shift-click on two supervoxels that should be separated; a new layer - will be displayed showing the supervoxels along the shortest path - between selected objects - 3) use '[' and ']' to restrict the path so that the displayed supervoxels - are not wrongly merged - 4) press 's' to remove the edge next to the last shown one from the - agglomeration graph - - * merging - 1) select segments to be merged - 2) press 'm' - - Press 'c' to add any supervoxels connected to the ones currently displayed - (according to the current state of the agglomeraton graph). - """ + self.next_batch() - def __init__(self, graph, objects, bad, num_to_prefetch=0): - super().__init__(objects=objects, num_to_prefetch=num_to_prefetch) - self.graph = graph - self.split_objects = [] - self.split_path = [] - self.split_index = 1 - self.sem = threading.Semaphore() - - self.bad = bad - self.viewer.actions.add("add-ccs", lambda s: self.add_ccs()) - self.viewer.actions.add("clear-splits", lambda s: self.clear_splits()) - self.viewer.actions.add("add-split", self.add_split) - self.viewer.actions.add("accept-split", lambda s: self.accept_split()) - self.viewer.actions.add("split-inc", lambda s: self.inc_split()) - self.viewer.actions.add("split-dec", lambda s: self.dec_split()) - self.viewer.actions.add("merge-segments", lambda s: self.merge_segments()) - self.viewer.actions.add("mark-bad", lambda s: self.mark_bad()) - self.viewer.actions.add("next-batch", lambda s: self.next_batch()) - self.viewer.actions.add("prev-batch", lambda s: self.prev_batch()) - - with self.viewer.config_state.txn() as s: - s.input_event_bindings.viewer["keyj"] = "next-batch" - s.input_event_bindings.viewer["keyk"] = "prev-batch" - s.input_event_bindings.viewer["keyc"] = "add-ccs" - s.input_event_bindings.viewer["keya"] = "clear-splits" - s.input_event_bindings.viewer["keym"] = "merge-segments" - s.input_event_bindings.viewer["shift+bracketleft"] = "split-dec" - s.input_event_bindings.viewer["shift+bracketright"] = "split-inc" - s.input_event_bindings.viewer["keys"] = "accept-split" - s.input_event_bindings.data_view["shift+mousedown0"] = "add-split" - s.input_event_bindings.viewer["keyv"] = "mark-bad" - - with self.viewer.txn() as s: - s.layers["split"] = neuroglancer.SegmentationLayer( - source=s.layers["seg"].source - ) - s.layers["split"].visible = False - - def merge_segments(self): - sids = [sid for sid in self.viewer.state.layers["seg"].segments if sid > 0] - self.graph.add_edges_from(zip(sids, sids[1:])) - - def update_split(self): - s = copy.deepcopy(self.viewer.state) - s.layers["split"].segments = list(self.split_path)[: self.split_index] - self.viewer.set_state(s) - - def inc_split(self): - self.split_index = min(len(self.split_path), self.split_index + 1) - self.update_split() - - def dec_split(self): - self.split_index = max(1, self.split_index - 1) - self.update_split() - - def add_ccs(self): - if self.sem.acquire(blocking=False): - curr = set(self.viewer.state.layers["seg"].segments) - for sid in self.viewer.state.layers["seg"].segments: - if sid in self.graph: - curr |= set(nx.node_connected_component(self.graph, sid)) - - self.update_segments(curr) - self.sem.release() - - def accept_split(self): - edge = self.split_path[self.split_index - 1 : self.split_index + 1] - if len(edge) < 2: - return - - self.graph.remove_edge(edge[0], edge[1]) - self.clear_splits() - - def clear_splits(self): - self.split_objects = [] - self.update_msg("splits cleared") - - s = copy.deepcopy(self.viewer.state) - s.layers["split"].visible = False - s.layers["seg"].visible = True - self.viewer.set_state(s) - - def start_split(self): - self.split_path = nx.shortest_path( - self.graph, self.split_objects[0], self.split_objects[1] - ) - self.split_index = 1 - self.update_msg("splitting: %s" % ("-".join(str(x) for x in self.split_path))) - - s = copy.deepcopy(self.viewer.state) - s.layers["seg"].visible = False - s.layers["split"].visible = True - self.viewer.set_state(s) - self.update_split() - - def add_split(self, s): - if len(self.split_objects) < 2: - self.split_objects.append(s.selected_values["seg"].value) - self.update_msg("split: %s" % (":".join(str(x) for x in self.split_objects))) - - if len(self.split_objects) == 2: - self.start_split() - - def mark_bad(self): - if self.batch > 1: - self.update_msg("decrease batch to 1 to mark objects bad") - return - - sids = self.todo[self.index]["seg"] - if len(sids) == 1: - self.bad.add(list(sids)[0]) - else: - self.bad.add(frozenset(sids)) - - self.update_msg("marked bad: %r" % (sids,)) - self.next_batch() + +class GraphUpdater(Base): + """Base class for agglomeration graph modification. + + Usage: + * splitting + 1) select merged objects (start with a supervoxel, then press 'c') + 2) shift-click on two supervoxels that should be separated; a new layer + will be displayed showing the supervoxels along the shortest path + between selected objects + 3) use '[' and ']' to restrict the path so that the displayed supervoxels + are not wrongly merged + 4) press 's' to remove the edge next to the last shown one from the + agglomeration graph + + * merging + 1) select segments to be merged + 2) press 'm' + + Press 'c' to add any supervoxels connected to the ones currently displayed + (according to the current state of the agglomeraton graph). + """ + + def __init__(self, graph, objects, bad, num_to_prefetch=0): + super().__init__(objects=objects, num_to_prefetch=num_to_prefetch) + self.graph = graph + self.split_objects = [] + self.split_path = [] + self.split_index = 1 + self.sem = threading.Semaphore() + + self.bad = bad + self.viewer.actions.add("add-ccs", lambda s: self.add_ccs()) + self.viewer.actions.add("clear-splits", lambda s: self.clear_splits()) + self.viewer.actions.add("add-split", self.add_split) + self.viewer.actions.add("accept-split", lambda s: self.accept_split()) + self.viewer.actions.add("split-inc", lambda s: self.inc_split()) + self.viewer.actions.add("split-dec", lambda s: self.dec_split()) + self.viewer.actions.add("merge-segments", lambda s: self.merge_segments()) + self.viewer.actions.add("mark-bad", lambda s: self.mark_bad()) + self.viewer.actions.add("next-batch", lambda s: self.next_batch()) + self.viewer.actions.add("prev-batch", lambda s: self.prev_batch()) + + with self.viewer.config_state.txn() as s: + s.input_event_bindings.viewer["keyj"] = "next-batch" + s.input_event_bindings.viewer["keyk"] = "prev-batch" + s.input_event_bindings.viewer["keyc"] = "add-ccs" + s.input_event_bindings.viewer["keya"] = "clear-splits" + s.input_event_bindings.viewer["keym"] = "merge-segments" + s.input_event_bindings.viewer["shift+bracketleft"] = "split-dec" + s.input_event_bindings.viewer["shift+bracketright"] = "split-inc" + s.input_event_bindings.viewer["keys"] = "accept-split" + s.input_event_bindings.data_view["shift+mousedown0"] = "add-split" + s.input_event_bindings.viewer["keyv"] = "mark-bad" + + with self.viewer.txn() as s: + s.layers["split"] = neuroglancer.SegmentationLayer( + source=s.layers["seg"].source + ) + s.layers["split"].visible = False + + def merge_segments(self): + sids = [sid for sid in self.viewer.state.layers["seg"].segments if sid > 0] + self.graph.add_edges_from(zip(sids, sids[1:])) + + def update_split(self): + s = copy.deepcopy(self.viewer.state) + s.layers["split"].segments = list(self.split_path)[: self.split_index] + self.viewer.set_state(s) + + def inc_split(self): + self.split_index = min(len(self.split_path), self.split_index + 1) + self.update_split() + + def dec_split(self): + self.split_index = max(1, self.split_index - 1) + self.update_split() + + def add_ccs(self): + if self.sem.acquire(blocking=False): + curr = set(self.viewer.state.layers["seg"].segments) + for sid in self.viewer.state.layers["seg"].segments: + if sid in self.graph: + curr |= set(nx.node_connected_component(self.graph, sid)) + + self.update_segments(curr) + self.sem.release() + + def accept_split(self): + edge = self.split_path[self.split_index - 1 : self.split_index + 1] + if len(edge) < 2: + return + + self.graph.remove_edge(edge[0], edge[1]) + self.clear_splits() + + def clear_splits(self): + self.split_objects = [] + self.update_msg("splits cleared") + + s = copy.deepcopy(self.viewer.state) + s.layers["split"].visible = False + s.layers["seg"].visible = True + self.viewer.set_state(s) + + def start_split(self): + self.split_path = nx.shortest_path( + self.graph, self.split_objects[0], self.split_objects[1] + ) + self.split_index = 1 + self.update_msg("splitting: %s" % "-".join(str(x) for x in self.split_path)) + + s = copy.deepcopy(self.viewer.state) + s.layers["seg"].visible = False + s.layers["split"].visible = True + self.viewer.set_state(s) + self.update_split() + + def add_split(self, s): + if len(self.split_objects) < 2: + self.split_objects.append(s.selected_values["seg"].value) + self.update_msg("split: %s" % ":".join(str(x) for x in self.split_objects)) + + if len(self.split_objects) == 2: + self.start_split() + + def mark_bad(self): + if self.batch > 1: + self.update_msg("decrease batch to 1 to mark objects bad") + return + + sids = self.todo[self.index]["seg"] + if len(sids) == 1: + self.bad.add(list(sids)[0]) + else: + self.bad.add(frozenset(sids)) + + self.update_msg("marked bad: %r" % (sids,)) + self.next_batch() From 813ce6daf2c2b69a2423d36e55ad87af304d6b18 Mon Sep 17 00:00:00 2001 From: moenigin <15244500+moenigin@users.noreply.github.com> Date: Tue, 22 Aug 2023 09:27:44 +0200 Subject: [PATCH 4/8] Update proofreading.py correct remaining indent of 4 in docstring --- ffn/utils/proofreading.py | 80 +++++++++++++++++++-------------------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/ffn/utils/proofreading.py b/ffn/utils/proofreading.py index dc7e283..c69d6ce 100644 --- a/ffn/utils/proofreading.py +++ b/ffn/utils/proofreading.py @@ -44,9 +44,9 @@ def __init__( """Initializes the Base class for proofreading. Args: - num_to_prefetch: Number of items to prefetch. - locations: List of xyz coordinates corresponding to object locations. - objects: Object IDs or a dictionary mapping layer names to object IDs. + num_to_prefetch: Number of items to prefetch. + locations: List of xyz coordinates corresponding to object locations. + objects: Object IDs or a dictionary mapping layer names to object IDs. """ self.viewer = neuroglancer.Viewer() self.num_to_prefetch = num_to_prefetch @@ -99,9 +99,9 @@ def update_segments( """Updates segments in Neuroglancer viewer. Args: - segments: List of segment IDs to update. - loc: 3D coordinates to set the viewer to. - layer: Layer name in Neuroglancer to be updated. + segments: List of segment IDs to update. + loc: 3D coordinates to set the viewer to. + layer: Layer name in Neuroglancer to be updated. """ s = copy.deepcopy(self.viewer.state) l = s.layers[layer] @@ -155,11 +155,11 @@ def list_segments( """Get a list of segment IDs for a given index and layer. Args: - index: Index of segments to list. - layer: Layer name to list the segments from. + index: Index of segments to list. + layer: Layer name to list the segments from. Returns: - List of segment IDs. + List of segment IDs. """ if index is None: index = self.index @@ -175,7 +175,7 @@ def custom_msg(self) -> str: """Generate a custom message for the current state. Returns: - A custom message string. + A custom message string. """ return "" @@ -223,10 +223,10 @@ def get_cursor_position( """Return coordinates of the cursor position from a neuroglancer action state Args: - action_state : Neuroglancer action state + action_state : Neuroglancer action state Returns: - (x, y, z) cursor position + (x, y, z) cursor position """ try: cursor_position = [int(x) for x in action_state.mouse_voxel_coordinates] @@ -297,7 +297,7 @@ def custom_msg(self) -> str: """Construct a custom message for the current state. Returns: - A formatted message indicating the number of bad objects. + A formatted message indicating the number of bad objects. """ return "num_bad: %d" % len(self.bad) @@ -334,22 +334,22 @@ def mark_removed_bad(self) -> None: class ObjectReviewStoreLocation(ObjectReview): """Class to mark and store locations of errors in the segmentation - To mark a merger, move the cursor to a spot of the false merger and press 'w'. - Then, move the cursor to a spot within the object that should belong to a - separate object and press 'shift + W'. Yellow point annotations indicate the - merger. For split errors, proceed in similar manner but press 'd' and + To mark a merger, move the cursor to a spot of the false merger and press + 'w'. Then, move the cursor to a spot within the object that should belong to + a separate object and press 'shift + W'. Yellow point annotations indicate + the merger. For split errors, proceed in similar manner but press 'd' and 'shift + D', which will display blue annotations. Marked locations can be deleted either by pressing 'ctrl + Z' (to delete the last marked location) or by hovering the cursor over one of the point annotations and pressing 'ctrl + v'. Attributes: - seg_error_coordinates: A mapping of annotation identifier substrings to - error coordinate pairs. - Example: {'m0': [[x1,y1,z1],[x2,y2,z2]], 's0':[[x1,y1,z1],[x2,y2,z2]], ...} - - Keys starting with 'm' indicate merge errors. - - Keys starting with 's' indicate split errors. - temp_coord_list: Temporary storage for coordinates. + seg_error_coordinates: A mapping of annotation identifier substrings to + error coordinate pairs. + Example: {'m0': [[x1,y1,z1],[x2,y2,z2]], 's0':[[x1,y1,z1],[x2,y2,z2]], ...} + - Keys starting with 'm' indicate merge errors. + - Keys starting with 's' indicate split errors. + temp_coord_list: Temporary storage for coordinates. """ def __init__( @@ -362,10 +362,10 @@ def __init__( """Initialize the ObjectReviewStoreLocation class. Args: - objects: A list of objects. - bad: A list of bad objects or markers. - seg_error_coordinates: A dictionary of error coordinates. - load_annotations: A flag to indicate if annotations should be loaded. + objects: A list of objects. + bad: A list of bad objects or markers. + seg_error_coordinates: A dictionary of error coordinates. + load_annotations: A flag to indicate if annotations should be loaded. """ super(ObjectReviewStoreLocation, self).__init__(objects, bad) self.seg_error_coordinates = seg_error_coordinates @@ -408,10 +408,10 @@ def get_id(self, mode: str) -> str: """Generate a unique identifier for an error based on its type. Args: - mode: Error type, either 'merge' or 'split'. + mode: Error type, either 'merge' or 'split'. Returns: - A unique identifier string. + A unique identifier string. """ id_ = mode[0] if any(self.seg_error_coordinates): @@ -430,9 +430,9 @@ def store_error_location( """Store error locations. Args: - action_state: State of the viewer during the action. - mode: Type of the error ('merger' or 'split'). - idx: Indicates if it's the first or second coordinate (0 or 1). + action_state: State of the viewer during the action. + mode: Type of the error ('merger' or 'split'). + idx: Indicates if it's the first or second coordinate (0 or 1). """ location = self.get_cursor_position(action_state) if location is None: @@ -464,8 +464,8 @@ def annotate_error_locations( """Annotate the error locations in the viewer. Args: - coordinate_lst: List of coordinates to be annotated. - id_: Unique identifier for the error. + coordinate_lst: List of coordinates to be annotated. + id_: Unique identifier for the error. """ for i, coord in enumerate(coordinate_lst): annotation_id = id_ + f"_{i}" @@ -477,8 +477,8 @@ def mk_point_annotation( """Create a point annotation in the viewer. Args: - coordinate: 3D coordinate of the annotation point. - annotation_id: Unique identifier for the annotation. + coordinate: 3D coordinate of the annotation point. + annotation_id: Unique identifier for the annotation. """ if annotation_id.startswith("m"): color = "#fae505" @@ -497,10 +497,10 @@ def get_annotation_id( """Retrieve the ID of a selected annotation. Args: - action_state: neuroglancer.viewer_config_state.ActionState. + action_state: neuroglancer.viewer_config_state.ActionState. Returns: - The selected object's ID or None if retrieval fails. + The selected object's ID or None if retrieval fails. """ try: selection_state = action_state.selected_values["annotation"].to_json() @@ -517,7 +517,7 @@ def delete_location_from_annotation( """Delete the error location pair associated with the annotation at the cursor position Args: - action_state: State of the viewer during the action. + action_state: State of the viewer during the action. """ id_ = self.get_annotation_id(action_state) if id_ is None: @@ -533,7 +533,7 @@ def delete_annotation(self, to_remove: list[str]) -> None: """Delete specified annotations from the viewer. Args: - to_remove: list of annotation IDs to be removed. + to_remove: list of annotation IDs to be removed. """ with self.viewer.txn() as s: annotations = s.layers["annotation"].annotations From b09792fe60ee814510afd7213b9d2ed4237e983f Mon Sep 17 00:00:00 2001 From: moenigin <15244500+moenigin@users.noreply.github.com> Date: Tue, 22 Aug 2023 21:32:09 +0200 Subject: [PATCH 5/8] Update proofreading.py - include requested updates - fix bug for annotation selection --- ffn/utils/proofreading.py | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/ffn/utils/proofreading.py b/ffn/utils/proofreading.py index c69d6ce..a021721 100644 --- a/ffn/utils/proofreading.py +++ b/ffn/utils/proofreading.py @@ -367,7 +367,7 @@ def __init__( seg_error_coordinates: A dictionary of error coordinates. load_annotations: A flag to indicate if annotations should be loaded. """ - super(ObjectReviewStoreLocation, self).__init__(objects, bad) + super().__init__(objects, bad) self.seg_error_coordinates = seg_error_coordinates if load_annotations and seg_error_coordinates: for k, v in seg_error_coordinates.items(): @@ -459,19 +459,19 @@ def store_error_location( self.temp_coord_list = [] def annotate_error_locations( - self, coordinate_lst: list[list[int]], id_: str + self, coordinates: list[list[int]], error_id: str ) -> None: """Annotate the error locations in the viewer. Args: - coordinate_lst: List of coordinates to be annotated. - id_: Unique identifier for the error. + coordinates: List of coordinates to be annotated. + error_id: Unique identifier for the error. """ - for i, coord in enumerate(coordinate_lst): - annotation_id = id_ + f"_{i}" - self.mk_point_annotation(coord, annotation_id) + for i, coord in enumerate(coordinates): + annotation_id = f"{error_id}_{i}" + self.make_point_annotation(coord, annotation_id) - def mk_point_annotation( + def make_point_annotation( self, coordinate: list[int], annotation_id: str ) -> None: """Create a point annotation in the viewer. @@ -488,8 +488,7 @@ def mk_point_annotation( id=annotation_id, point=coordinate, props=[color] ) with self.viewer.txn() as s: - annotations = s.layers["annotation"].annotations - annotations.append(annotation) + s.layers["annotation"].annotations.append(annotation) def get_annotation_id( self, action_state: neuroglancer.viewer_config_state.ActionState @@ -519,17 +518,17 @@ def delete_location_from_annotation( Args: action_state: State of the viewer during the action. """ - id_ = self.get_annotation_id(action_state) - if id_ is None: + ann_id = self.get_annotation_id(action_state) + if ann_id is None: return - target_key = id_[:2] + target_key, _ = ann_id.split("_") del self.seg_error_coordinates[target_key] - to_remove = [target_key + "_0", target_key + "_1"] + to_remove = frozenset([target_key + "_0", target_key + "_1"]) self.delete_annotation(to_remove) - def delete_annotation(self, to_remove: list[str]) -> None: + def delete_annotation(self, to_remove: frozenset[str]) -> None: """Delete specified annotations from the viewer. Args: @@ -540,9 +539,9 @@ def delete_annotation(self, to_remove: list[str]) -> None: annotations = [a for a in annotations if a.id not in to_remove] s.layers["annotation"].annotations = annotations - def delete_last_location(self): + def delete_last_location(self) -> None: """Delete the last error location pair tagged.""" - last_key = list(self.seg_error_coordinates.keys())[-1] + last_key = next(reversed(self.seg_error_coordinates)) del self.seg_error_coordinates[last_key] to_remove = [last_key + "_0", last_key + "_1"] From d3cfa9677d8e246df4d9c2751b0afc908060fa19 Mon Sep 17 00:00:00 2001 From: moenigin <15244500+moenigin@users.noreply.github.com> Date: Wed, 23 Aug 2023 18:00:31 +0200 Subject: [PATCH 6/8] Update proofreading.py transform list of annotations id pair to frozenset --- ffn/utils/proofreading.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ffn/utils/proofreading.py b/ffn/utils/proofreading.py index a021721..8c5d864 100644 --- a/ffn/utils/proofreading.py +++ b/ffn/utils/proofreading.py @@ -544,7 +544,7 @@ def delete_last_location(self) -> None: last_key = next(reversed(self.seg_error_coordinates)) del self.seg_error_coordinates[last_key] - to_remove = [last_key + "_0", last_key + "_1"] + to_remove = frozenset([last_key + "_0", last_key + "_1"]) self.delete_annotation(to_remove) From 5f7c55e2278ee6cb6b068b16618e80d4be7012c4 Mon Sep 17 00:00:00 2001 From: moenigin <15244500+moenigin@users.noreply.github.com> Date: Tue, 12 Sep 2023 07:23:32 +0200 Subject: [PATCH 7/8] Update proofreading.py correct typing --- ffn/utils/proofreading.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/ffn/utils/proofreading.py b/ffn/utils/proofreading.py index 8c5d864..4eb8b8b 100644 --- a/ffn/utils/proofreading.py +++ b/ffn/utils/proofreading.py @@ -24,7 +24,7 @@ import networkx as nx import neuroglancer -from typing import Union, Optional, Iterable, Any +from typing import Union, Optional, Sequence class Base: @@ -38,8 +38,8 @@ class Base: def __init__( self, num_to_prefetch: int = 10, - locations: Optional[Iterable[tuple[int, int, int]]] = None, - objects: Optional[Union[dict[str, Any], Iterable[int]]] = None, + locations: Optional[list[Sequence[int]]] = None, + objects: Optional[Union[dict[str, int], Sequence[int]]] = None, ): """Initializes the Base class for proofreading. @@ -68,7 +68,7 @@ def __init__( self.set_init_state() - def _set_todo(self, objects: Union[list[str, Any], Iterable[int]]) -> None: + def _set_todo(self, objects: Union[dict[str, int], Sequence[int]]) -> None: """Private method to set the todo list.""" for o in objects: if isinstance(o, collections.abc.Mapping): @@ -93,7 +93,7 @@ def update_msg(self, msg: str) -> None: def update_segments( self, segments: list[int], - loc: Optional[tuple[int, int, int]] = None, + loc: Optional[Sequence[int]] = None, layer: str = "seg", ) -> None: """Updates segments in Neuroglancer viewer. @@ -246,10 +246,10 @@ class ObjectReview(Base): def __init__( self, - objects: Iterable, - bad: list, + objects: Union[dict[str, int], Sequence[int]], + bad: set, num_to_prefetch: int = 10, - locations: Optional[Iterable[tuple[int, int, int]]] = None, + locations: Optional[list[Sequence[int]]] = None, ): """Constructor. @@ -355,8 +355,8 @@ class ObjectReviewStoreLocation(ObjectReview): def __init__( self, objects: list, - bad: list, - seg_error_coordinates: Optional[list[str, list[list[int]]]] = {}, + bad: set, + seg_error_coordinates: Optional[dict[str, list]] = {}, load_annotations: bool = False, ) -> None: """Initialize the ObjectReviewStoreLocation class. From 45f20906800ff1e9469c97ec257468f55090169c Mon Sep 17 00:00:00 2001 From: moenigin <15244500+moenigin@users.noreply.github.com> Date: Wed, 13 Sep 2023 20:55:25 +0200 Subject: [PATCH 8/8] several fixes - fix/silence type hint errors (I hope) - fix error annotation count - fix error type assignment by only the second location --- ffn/utils/proofreading.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/ffn/utils/proofreading.py b/ffn/utils/proofreading.py index 4eb8b8b..e5e3ef2 100644 --- a/ffn/utils/proofreading.py +++ b/ffn/utils/proofreading.py @@ -24,7 +24,7 @@ import networkx as nx import neuroglancer -from typing import Union, Optional, Sequence +from typing import Union, Optional, Sequence, cast class Base: @@ -92,7 +92,7 @@ def update_msg(self, msg: str) -> None: def update_segments( self, - segments: list[int], + segments: Union[set[int], list[int]], loc: Optional[Sequence[int]] = None, layer: str = "seg", ) -> None: @@ -112,7 +112,7 @@ def update_segments( else: l.equivalences.clear() for a in self.todo[self.index : self.index + self.batch]: - a = [aa[layer] for aa in a] + a = [cast(dict, aa)[layer] for aa in a] l.equivalences.union(*a) if loc is not None: @@ -151,7 +151,7 @@ def prev_batch(self) -> None: def list_segments( self, index: Optional[int] = None, layer: str = "seg" - ) -> list[int]: + ) -> list: """Get a list of segment IDs for a given index and layer. Args: @@ -356,7 +356,7 @@ def __init__( self, objects: list, bad: set, - seg_error_coordinates: Optional[dict[str, list]] = {}, + seg_error_coordinates: dict[str, list] = {}, load_annotations: bool = False, ) -> None: """Initialize the ObjectReviewStoreLocation class. @@ -373,6 +373,7 @@ def __init__( for k, v in seg_error_coordinates.items(): self.annotate_error_locations(v, k) self.temp_coord_list = [] + self.cur_error_type = None def set_keybindings(self) -> None: """Set key bindings for the viewer.""" @@ -415,7 +416,7 @@ def get_id(self, mode: str) -> str: """ id_ = mode[0] if any(self.seg_error_coordinates): - counter = int(max([x[1:] for x in self.seg_error_coordinates.keys()])) + 1 + counter = max([int(x[1:]) for x in self.seg_error_coordinates.keys()]) + 1 else: counter = 0 id_ = id_ + str(counter) @@ -442,8 +443,14 @@ def store_error_location( self.update_msg("You have not entered a first coord yet") return - if idx == 0 and self.temp_coord_list: - self.temp_coord_list = [] + if idx == 1 and self.cur_error_type != mode: + self.update_msg("error type of first and second location do not match") + return + + if idx == 0: + self.cur_error_type = mode + if self.temp_coord_list: + self.temp_coord_list = [] self.temp_coord_list.append(location) @@ -457,6 +464,7 @@ def store_error_location( self.seg_error_coordinates.update({identifier: self.temp_coord_list}) self.annotate_error_locations(self.temp_coord_list, identifier) self.temp_coord_list = [] + self.cur_error_type = None def annotate_error_locations( self, coordinates: list[list[int]], error_id: str @@ -541,7 +549,7 @@ def delete_annotation(self, to_remove: frozenset[str]) -> None: def delete_last_location(self) -> None: """Delete the last error location pair tagged.""" - last_key = next(reversed(self.seg_error_coordinates)) + last_key = next(reversed(self.seg_error_coordinates.keys())) del self.seg_error_coordinates[last_key] to_remove = frozenset([last_key + "_0", last_key + "_1"])