From 61a0e4229ca974327f7a29c53a1e2b47a72cf487 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Goll?= Date: Tue, 17 Jun 2025 12:06:54 +0200 Subject: [PATCH 1/3] ENH: Update segmentation logic to support Reformat orientation --- .../ContextLogic.py | 15 +- .../SegmentationLogic.py | 146 ++++++++---------- 2 files changed, 68 insertions(+), 93 deletions(-) diff --git a/MultiverSeg/SegmentEditorMultiverSegLib/ContextLogic.py b/MultiverSeg/SegmentEditorMultiverSegLib/ContextLogic.py index 9dad762..f312d21 100644 --- a/MultiverSeg/SegmentEditorMultiverSegLib/ContextLogic.py +++ b/MultiverSeg/SegmentEditorMultiverSegLib/ContextLogic.py @@ -90,17 +90,12 @@ def saveNewExample(self, volume: vtkMRMLVolumeNode, view, segmentID, segmentatio imageArray = slicer.util.arrayFromVolume(volume).copy() maskArray = slicer.util.arrayFromSegmentBinaryLabelmap(segmentationNode, segmentID, volume) - IJKToRAS = np.zeros((3, 3)) - volume.GetIJKToRASDirections(IJKToRAS) - KJIToRAS = IJKToRAS.copy() - KJIToRAS[:, 0] = IJKToRAS[:, 2] - KJIToRAS[:, 2] = IJKToRAS[:, 0] + sliceNodeID = f"vtkMRMLSliceNode{view}" + sliceNode = slicer.mrmlScene.GetNodeByID(sliceNodeID) + axis = segLogic.computeSliceAxis(volume, sliceNode) - imageArray = segLogic.reorderAxisToRAS(imageArray, KJIToRAS) - maskArray = segLogic.reorderAxisToRAS(maskArray, KJIToRAS) - - imageTensor = torch.from_numpy(segLogic.extractSlice(imageArray, k, view)) - maskTensor = torch.from_numpy(segLogic.extractSlice(maskArray, k, view)) + imageTensor = torch.from_numpy(segLogic.extractSlice(imageArray, k, axis)) + maskTensor = torch.from_numpy(segLogic.extractSlice(maskArray, k, axis)) imageTensor = segLogic.preprocessSlice(imageTensor[None]) maskTensor = segLogic.preprocessSlice(maskTensor[None]) diff --git a/MultiverSeg/SegmentEditorMultiverSegLib/SegmentationLogic.py b/MultiverSeg/SegmentEditorMultiverSegLib/SegmentationLogic.py index ced719b..55a12f3 100644 --- a/MultiverSeg/SegmentEditorMultiverSegLib/SegmentationLogic.py +++ b/MultiverSeg/SegmentEditorMultiverSegLib/SegmentationLogic.py @@ -4,6 +4,7 @@ import numpy as np import slicer +import vtkAddon from MRMLCorePython import vtkMRMLSegmentationNode, vtkMRMLScalarVolumeNode, vtkMRMLSliceNode from numpy.ma.core import maximum @@ -134,9 +135,11 @@ def predict(self): KJIToRAS[:, 0] = IJKToRAS[:, 2] KJIToRAS[:, 2] = IJKToRAS[:, 0] - resultSegment = self.reorderAxisToRAS(resultSegment, KJIToRAS) - resultSegment = self.updateSlice(resultSegment, y, k) - resultSegment = self.invertAxisReordering(resultSegment, KJIToRAS) + sliceNodeID = f"vtkMRMLSliceNode{self.workingView}" + sliceNode = slicer.mrmlScene.GetNodeByID(sliceNodeID) + axis = self.computeSliceAxis(volumeNode, sliceNode) + + resultSegment = self.updateSlice(resultSegment, y, k, axis) slicer.util.updateSegmentBinaryLabelmapFromArray(resultSegment, segNode, segmentId, volumeNode) @@ -161,12 +164,9 @@ def rawPredictForSlice(self, sliceNumber: int) -> tuple["torch.Tensor", "torch.S posSegId = segNode.GetSegmentation().GetSegmentIdBySegment(self.posSegment) negSegId = segNode.GetSegmentation().GetSegmentIdBySegment(self.negSegment) - # Create the convertion matrix needed to handle slice selection correctly - IJKToRAS = np.zeros((3, 3)) - volumeNode.GetIJKToRASDirections(IJKToRAS) - KJIToRAS = IJKToRAS.copy() - KJIToRAS[:, 0] = IJKToRAS[:, 2] - KJIToRAS[:, 2] = IJKToRAS[:, 0] + sliceNodeID = f"vtkMRMLSliceNode{self.workingView}" + sliceNode = slicer.mrmlScene.GetNodeByID(sliceNodeID) + axis = self.computeSliceAxis(volumeNode, sliceNode) # Getting the different arrays # Array from slicer.util are K-J-I indexed @@ -175,17 +175,11 @@ def rawPredictForSlice(self, sliceNumber: int) -> tuple["torch.Tensor", "torch.S posArray = slicer.util.arrayFromSegmentBinaryLabelmap(segNode, posSegId, volumeNode) negArray = slicer.util.arrayFromSegmentBinaryLabelmap(segNode, negSegId, volumeNode) - # Reorder axis to be R-A-S indexed - imageArray = self.reorderAxisToRAS(imageArray, KJIToRAS) - resultSegment = self.reorderAxisToRAS(resultSegment, KJIToRAS) - posArray = self.reorderAxisToRAS(posArray, KJIToRAS) - negArray = self.reorderAxisToRAS(negArray, KJIToRAS) - # Extract the slice corresponding to the current view - imageSlice = self.extractSlice(imageArray, sliceNumber) - prevPredSlice = self.extractSlice(resultSegment, sliceNumber) - posSlice = self.extractSlice(posArray, sliceNumber) - negSlice = self.extractSlice(negArray, sliceNumber) + imageSlice = self.extractSlice(imageArray, sliceNumber, axis) + prevPredSlice = self.extractSlice(resultSegment, sliceNumber, axis) + posSlice = self.extractSlice(posArray, sliceNumber, axis) + negSlice = self.extractSlice(negArray, sliceNumber, axis) # Convertion to tensors imageTensor = torch.from_numpy(imageSlice) @@ -203,7 +197,6 @@ def rawPredictForSlice(self, sliceNumber: int) -> tuple["torch.Tensor", "torch.S scribbles = torch.cat((posTensor, negTensor), dim=0) - # print("Starting prediction") y = self.model.predict(imageTensor[None], scribbles=scribbles[None], mask_input=prevPredTensor[None], @@ -235,12 +228,7 @@ def predict3d(self): posSegId = segNode.GetSegmentation().GetSegmentIdBySegment(self.posSegment) negSegId = segNode.GetSegmentation().GetSegmentIdBySegment(self.negSegment) - # Create the convertion matrix needed to handle slice selection correctly - IJKToRAS = np.zeros((3, 3)) - volumeNode.GetIJKToRASDirections(IJKToRAS) - KJIToRAS = IJKToRAS.copy() - KJIToRAS[:, 0] = IJKToRAS[:, 2] - KJIToRAS[:, 2] = IJKToRAS[:, 0] + axis = self.computeSliceAxis(volumeNode, sliceNode) # Getting the different arrays # Array from slicer.util are K-J-I indexed @@ -249,13 +237,7 @@ def predict3d(self): posArray = slicer.util.arrayFromSegmentBinaryLabelmap(segNode, posSegId, volumeNode) negArray = slicer.util.arrayFromSegmentBinaryLabelmap(segNode, negSegId, volumeNode) - # Reorder axis to be R-A-S indexed - imageArray = self.reorderAxisToRAS(imageArray, KJIToRAS) - resultSegment = self.reorderAxisToRAS(resultSegment, KJIToRAS) - posArray = self.reorderAxisToRAS(posArray, KJIToRAS) - negArray = self.reorderAxisToRAS(negArray, KJIToRAS) - - imageSlice = self.extractSlice(imageArray, 0) + imageSlice = self.extractSlice(imageArray, 0, axis) originalDim = imageSlice.shape import torch @@ -267,10 +249,10 @@ def predict3d(self): negTensor = torch.from_numpy(negArray) # Pre process - imageTensor = self.preprocessVolume(imageTensor[None])[0] - posTensor = self.preprocessVolume(posTensor[None], isSegmentation=True)[0] - negTensor = self.preprocessVolume(negTensor[None], isSegmentation=True)[0] - prevPredTensor = self.preprocessVolume(prevPredTensor[None], isSegmentation=True)[0] + imageTensor = self.preprocessVolume(imageTensor[None], axis)[0] + posTensor = self.preprocessVolume(posTensor[None], axis, isSegmentation=True)[0] + negTensor = self.preprocessVolume(negTensor[None], axis, isSegmentation=True)[0] + prevPredTensor = self.preprocessVolume(prevPredTensor[None], axis, isSegmentation=True)[0] progressDialog = slicer.util.createProgressDialog(value=startSlice - 1, minimum=startSlice - 1, @@ -289,10 +271,10 @@ def predict3d(self): sliceLogic.SetSliceOffset(sliceOffset) # Extract the slice corresponding to the current view - imageSlice = self.extractSlice(imageTensor, sliceNumber)[None] - prevPredSlice = self.extractSlice(prevPredTensor, sliceNumber)[None] - posSlice = self.extractSlice(posTensor, sliceNumber)[None] - negSlice = self.extractSlice(negTensor, sliceNumber)[None] + imageSlice = self.extractSlice(imageTensor, sliceNumber, axis)[None] + prevPredSlice = self.extractSlice(prevPredTensor, sliceNumber, axis)[None] + posSlice = self.extractSlice(posTensor, sliceNumber, axis)[None] + negSlice = self.extractSlice(negTensor, sliceNumber, axis)[None] scribbles = torch.cat((posSlice, negSlice), dim=0) @@ -305,7 +287,7 @@ def predict3d(self): y = torchviz.functional.resize(y[0], originalDim)[0] y = self.thresholdPrediction(y) - resultSegment = self.updateSlice(resultSegment, y, sliceNumber) + resultSegment = self.updateSlice(resultSegment, y, sliceNumber, axis) progressDialog.setValue(sliceNumber) if progressDialog.wasCanceled: @@ -313,7 +295,6 @@ def predict3d(self): break slicer.app.processEvents() - resultSegment = self.invertAxisReordering(resultSegment, KJIToRAS) slicer.util.updateSegmentBinaryLabelmapFromArray(resultSegment, segNode, segmentId, volumeNode) def getCurrentSliceIndex(self, sliceColor): @@ -325,44 +306,52 @@ def getCurrentSliceIndex(self, sliceColor): sliceOffset = sliceLogic.GetSliceOffset() return sliceLogic.GetSliceIndexFromOffset(sliceOffset) - 1 # slice is 1-indexed - def reorderAxisToRAS(self, array: np.ndarray, directionMatrix: np.ndarray): - perm_order = np.argmax(np.abs(directionMatrix), axis=0) - return np.transpose(array, axes=perm_order) + def computeSliceAxis(self, volumeNode: vtkMRMLScalarVolumeNode, sliceNode: vtkMRMLSliceNode): + """ + Given the volume node and the slice node, find the axis of the volume which correspond to the stepping direction in the selected view. + :return: + :raise: + """ + # Get the slice normal vector in RAS + sliceToRAS = sliceNode.GetSliceToRAS() + sliceNormal = np.zeros(4) + vtkAddon.vtkAddonMathUtilities.GetOrientationMatrixColumn(sliceToRAS, 2, sliceNormal) + + # Get the KIJ to RAS matrix + IJKToRAS = np.zeros((3, 3)) + volumeNode.GetIJKToRASDirections(IJKToRAS) + KJIToRAS = IJKToRAS.copy() + KJIToRAS[:, 0] = IJKToRAS[:, 2] + KJIToRAS[:, 2] = IJKToRAS[:, 0] + + res = KJIToRAS.T @ sliceNormal[:3] + res = np.abs(res) + + if np.allclose(res, [1, 0, 0], atol=0.01): + return 0 + if np.allclose(res, [0, 1, 0], atol=0.01): + return 1 + if np.allclose(res, [0, 0, 1], atol=0.01): + return 2 + raise ValueError(f"View {self.workingView} is not axis aligned with the volume geometry") def invertAxisReordering(self, permutedArray: np.ndarray, directionMatrix: np.ndarray): perm_order = np.argmax(np.abs(directionMatrix), axis=0) inverse_order = np.argsort(perm_order) # Compute the inverse permutation return np.transpose(permutedArray, axes=inverse_order) - def extractSlice(self, array: np.ndarray, sliceNumber: int, sliceColor=None): - sliceNodeID = f"vtkMRMLSliceNode{self.workingView if sliceColor is None else sliceColor}" - sliceNode: vtkMRMLSliceNode = slicer.mrmlScene.GetNodeByID(sliceNodeID) - - orientation = sliceNode.GetOrientation() - if orientation == "Axial": - orientationAx = 2 - elif orientation == "Sagittal": - orientationAx = 0 - elif orientation == "Coronal": - orientationAx = 1 - else: - raise ValueError(f"Orientation {orientation} is not supported") - - return np.take(array, sliceNumber, axis=orientationAx) - - def updateSlice(self, array: np.ndarray, newSlice: np.ndarray, sliceNumber: int): - sliceNodeID = f"vtkMRMLSliceNode{self.workingView}" - sliceNode: vtkMRMLSliceNode = slicer.mrmlScene.GetNodeByID(sliceNodeID) + def extractSlice(self, array: np.ndarray, sliceNumber: int, axis: int): + return np.take(array, sliceNumber, axis=axis) - orientation = sliceNode.GetOrientation() - if orientation == "Axial": - array[:, :, sliceNumber] = newSlice - elif orientation == "Sagittal": + def updateSlice(self, array: np.ndarray, newSlice: np.ndarray, sliceNumber: int, axis: int): + if axis == 0: array[sliceNumber] = newSlice - elif orientation == "Coronal": + elif axis == 1: array[:, sliceNumber] = newSlice + elif axis == 2: + array[:, :, sliceNumber] = newSlice else: - raise ValueError(f"Orientation {orientation} is not supported") + slicer.util.errorDisplay(f"Error during segmentation update, axis {axis} was given") return array def preprocessSlice(self, slice: "torch.Tensor", isSegmentation=False): @@ -384,7 +373,7 @@ def preprocessSlice(self, slice: "torch.Tensor", isSegmentation=False): return result # 1*W*H - def preprocessVolume(self, volume: "torch.Tensor", isSegmentation=False): + def preprocessVolume(self, volume: "torch.Tensor", axis: int, isSegmentation=False): # volume indexed RAS of shape 1*X*Y*Z import torch if isSegmentation: @@ -392,19 +381,10 @@ def preprocessVolume(self, volume: "torch.Tensor", isSegmentation=False): else: targetDtype = torch.float16 - sliceNodeID = f"vtkMRMLSliceNode{self.workingView}" - sliceNode: vtkMRMLSliceNode = slicer.mrmlScene.GetNodeByID(sliceNodeID) - orientation = sliceNode.GetOrientation() originalSize = volume.shape - if orientation == "Axial": - targetSize = [128, 128, originalSize[3]] - elif orientation == "Sagittal": - targetSize = [originalSize[1], 128, 128] - elif orientation == "Coronal": - targetSize = [128, originalSize[2], 128] - else: - raise ValueError(f"Orientation {orientation} is not supported") + targetSize = [128, 128, 128] + targetSize[axis] = originalSize[axis + 1] # Resizing result = torch.nn.functional.interpolate(volume[None].to(torch.float), targetSize, mode='trilinear').to( From ee86f85a7c6af1867e84ba4cb208771dddedfafd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Goll?= Date: Tue, 17 Jun 2025 12:07:18 +0200 Subject: [PATCH 2/3] TEST: Update test for axis information --- .../Python/SegmentationLogicTestCase.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/MultiverSeg/Testing/Python/SegmentationLogicTestCase.py b/MultiverSeg/Testing/Python/SegmentationLogicTestCase.py index 79576db..764d279 100644 --- a/MultiverSeg/Testing/Python/SegmentationLogicTestCase.py +++ b/MultiverSeg/Testing/Python/SegmentationLogicTestCase.py @@ -73,7 +73,7 @@ def test_volumePreprocess(self): logic.workingView = "Red" dummyVolume = torch.rand([1, 512, 512, 512]) - res = logic.preprocessVolume(dummyVolume) + res = logic.preprocessVolume(dummyVolume, 2) self.assertSequenceEqual(res.shape, [1, 128, 128, 512]) self.assertIs(res.dtype, torch.float16) self.assertAlmostEquals(torch.max(res).item(), 1) @@ -81,7 +81,7 @@ def test_volumePreprocess(self): logic.workingView = "Green" dummyVolume = torch.rand([1, 1000, 43, 100]) - res = logic.preprocessVolume(dummyVolume) + res = logic.preprocessVolume(dummyVolume, 1) self.assertSequenceEqual(res.shape, [1, 128, 43, 128]) self.assertIs(res.dtype, torch.float16) self.assertAlmostEquals(torch.max(res).item(), 1) @@ -89,7 +89,7 @@ def test_volumePreprocess(self): logic.workingView = "Yellow" dummyVolume = torch.randint(2, [1, 1000, 43, 100]) - res = logic.preprocessVolume(dummyVolume, isSegmentation=True) + res = logic.preprocessVolume(dummyVolume, 0, isSegmentation=True) self.assertSequenceEqual(res.shape, [1, 1000, 128, 128]) self.assertIs(res.dtype, torch.bool) @@ -100,19 +100,19 @@ def test_updateSlice(self): logic.workingView = "Red" updatedSlice = torch.zeros([55, 66]) - result = logic.updateSlice(baseVolume, updatedSlice, 25) + result = logic.updateSlice(baseVolume, updatedSlice, 25, 2) self.assertSequenceEqual(result.shape, [55, 66, 77]) self.assertEqual(torch.max(result[:, :, 25]).item(), 0) logic.workingView = "Green" updatedSlice = torch.zeros([55, 77]) - result = logic.updateSlice(baseVolume, updatedSlice, 20) + result = logic.updateSlice(baseVolume, updatedSlice, 20, 1) self.assertSequenceEqual(result.shape, [55, 66, 77]) self.assertEqual(torch.max(result[:, 20]).item(), 0) logic.workingView = "Yellow" updatedSlice = torch.zeros([66, 77]) - result = logic.updateSlice(baseVolume, updatedSlice, 30) + result = logic.updateSlice(baseVolume, updatedSlice, 30, 0) self.assertSequenceEqual(result.shape, [55, 66, 77]) self.assertEqual(torch.max(result[30]).item(), 0) @@ -122,15 +122,15 @@ def test_extractSlice(self): baseVolume = torch.rand([55, 66, 77]) logic.workingView = "Red" - result = logic.extractSlice(baseVolume, 25) + result = logic.extractSlice(baseVolume, 25,2) self.assertSequenceEqual(result.shape, [55, 66]) logic.workingView = "Green" - result = logic.extractSlice(baseVolume, 20) + result = logic.extractSlice(baseVolume, 20,1) self.assertSequenceEqual(result.shape, [55, 77]) logic.workingView = "Yellow" - result = logic.extractSlice(baseVolume, 30) + result = logic.extractSlice(baseVolume, 30,0) self.assertSequenceEqual(result.shape, [66, 77]) def test_rawPredict(self): From 640b34e3a850d53c35a3d6836de1f57a9cb65375 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Goll?= Date: Tue, 17 Jun 2025 12:21:48 +0200 Subject: [PATCH 3/3] DOC: Add doc to segmentation logic class --- .../SegmentationLogic.py | 42 +++++++++++++++---- 1 file changed, 35 insertions(+), 7 deletions(-) diff --git a/MultiverSeg/SegmentEditorMultiverSegLib/SegmentationLogic.py b/MultiverSeg/SegmentEditorMultiverSegLib/SegmentationLogic.py index 55a12f3..ab467e3 100644 --- a/MultiverSeg/SegmentEditorMultiverSegLib/SegmentationLogic.py +++ b/MultiverSeg/SegmentEditorMultiverSegLib/SegmentationLogic.py @@ -38,6 +38,9 @@ def __init__(self, scriptedEffect): self.sliceOffsetRange = (0., 0.) def initSegments(self): + """ + Initialize the segments by creating the positive and negative segments. + """ # Get the current segment self.segmentationNode: vtkMRMLSegmentationNode = self.scriptedEffect.parameterSetNode().GetSegmentationNode() segmentation: vtkSegmentation = self.segmentationNode.GetSegmentation() @@ -59,6 +62,10 @@ def initSegments(self): segmentation.AddSegment(self.negSegment) def initModel(self): + """ + Verify the dependencies and initialize the model. + :return: True if the initialization was successful, False otherwise. + """ from .InstallLogic import InstallLogic, DependenciesLogic progress = slicer.util.createProgressDialog(maximum=10, labelText="Verifying dependencies") @@ -97,6 +104,9 @@ def initModel(self): return True def reset(self): + """ + Remove the pos and neg segments and reset the internal state of the logic. + """ if self.segmentationNode is None: return @@ -113,6 +123,9 @@ def setOffsetRange(self, min: float, max: float): self.sliceOffsetRange = (min, max) def predict(self): + """ + Launch a 2D prediction for the current slice and view. + """ # Get the slice number import torchvision.transforms.v2 as torchviz @@ -144,11 +157,17 @@ def predict(self): slicer.util.updateSegmentBinaryLabelmapFromArray(resultSegment, segNode, segmentId, volumeNode) def thresholdPrediction(self, prediction: "torch.Tensor", threshold=0.5): + """ + Apply a threshold to the prediction. + """ prediction[prediction < threshold] = 0 prediction[prediction >= threshold] = 1 return prediction def rawPredictForSlice(self, sliceNumber: int) -> tuple["torch.Tensor", "torch.Size"]: + """ + Make a prediction for a 2D slice without post-processing + """ # return the raw prediction and the original dimension of the slice (for resizing) import torch # Load the context @@ -206,6 +225,9 @@ def rawPredictForSlice(self, sliceNumber: int) -> tuple["torch.Tensor", "torch.S return y, originalDim def predict3d(self): + """ + Make a 3D prediction + """ sliceNodeID = f"vtkMRMLSliceNode{self.workingView}" sliceNode = slicer.mrmlScene.GetNodeByID(sliceNodeID) @@ -298,6 +320,9 @@ def predict3d(self): slicer.util.updateSegmentBinaryLabelmapFromArray(resultSegment, segNode, segmentId, volumeNode) def getCurrentSliceIndex(self, sliceColor): + """ + Get the index of the current slice for the view sliceColor based on the offset value. + """ sliceNodeID = f"vtkMRMLSliceNode{sliceColor}" sliceNode = slicer.mrmlScene.GetNodeByID(sliceNodeID) @@ -309,8 +334,6 @@ def getCurrentSliceIndex(self, sliceColor): def computeSliceAxis(self, volumeNode: vtkMRMLScalarVolumeNode, sliceNode: vtkMRMLSliceNode): """ Given the volume node and the slice node, find the axis of the volume which correspond to the stepping direction in the selected view. - :return: - :raise: """ # Get the slice normal vector in RAS sliceToRAS = sliceNode.GetSliceToRAS() @@ -335,15 +358,14 @@ def computeSliceAxis(self, volumeNode: vtkMRMLScalarVolumeNode, sliceNode: vtkMR return 2 raise ValueError(f"View {self.workingView} is not axis aligned with the volume geometry") - def invertAxisReordering(self, permutedArray: np.ndarray, directionMatrix: np.ndarray): - perm_order = np.argmax(np.abs(directionMatrix), axis=0) - inverse_order = np.argsort(perm_order) # Compute the inverse permutation - return np.transpose(permutedArray, axes=inverse_order) - def extractSlice(self, array: np.ndarray, sliceNumber: int, axis: int): + """Extract the slice sliceNumber from the array given an axis""" return np.take(array, sliceNumber, axis=axis) def updateSlice(self, array: np.ndarray, newSlice: np.ndarray, sliceNumber: int, axis: int): + """ + Replace the slice in array by the newSlice. sliceNumber and axis are for positional information. + """ if axis == 0: array[sliceNumber] = newSlice elif axis == 1: @@ -355,6 +377,9 @@ def updateSlice(self, array: np.ndarray, newSlice: np.ndarray, sliceNumber: int, return array def preprocessSlice(self, slice: "torch.Tensor", isSegmentation=False): + """ + Preprocess a 2d slice for the model. If isSegmentation, the resulting Tensor in of type bool + """ # Slice of dim of shape 1*W*H import torch import torchvision.transforms.v2 as torchviz @@ -374,6 +399,9 @@ def preprocessSlice(self, slice: "torch.Tensor", isSegmentation=False): return result # 1*W*H def preprocessVolume(self, volume: "torch.Tensor", axis: int, isSegmentation=False): + """ + Apply the preprocessing pipeline on a full volume given an axis. The direction of the axis is not rescaled to allow stepping through each slice. + """ # volume indexed RAS of shape 1*X*Y*Z import torch if isSegmentation: