**Python library for 2D cell/nuclei instance segmentation models written with [PyTorch](https://pytorch.org/).**

[![Generic badge](https://img.shields.io/badge/License-MIT-<COLOR>.svg?style=for-the-badge)](https://github.com/okunator/cellseg_models.pytorch/blob/master/LICENSE)
[![PyTorch - Version](https://img.shields.io/badge/PYTORCH-1.8.1+-red?style=for-the-badge&logo=pytorch)](https://pytorch.org/)
[![Python - Version](https://img.shields.io/badge/PYTHON-3.7+-red?style=for-the-badge&logo=python&logoColor=white)](https://www.python.org/)
<br>
[![Github Test](https://img.shields.io/github/workflow/status/okunator/cellseg_models.pytorch/Tests?label=Tests&logo=github&style=for-the-badge)](https://github.com/okunator/cellseg_models.pytorch/actions/workflows/tests.yml)
## Features
- Pre-trained backbones/encoders from the [timm](https://github.com/rwightman/pytorch-image-models) library.
- All architectures can be augmented to output semantic segmentation masks along with instance segmentation masks (panoptic segmentation).
- Plenty of flexibility to modify the components of the model architectures.
- Multi-GPU inference.
- Popular training losses and benchmarking metrics.
- Simple model training with [pytorch-lightning](https://www.pytorchlightning.ai/) (see the sketch after this list).

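Since training is delegated to pytorch-lightning, here is a minimal, hypothetical training sketch. The `SegTrainingModule` wrapper, the batch structure, and the single cross-entropy loss are illustrative assumptions, not the library's built-in training API:

```python
import cellseg_models_pytorch as csmp
import pytorch_lightning as pl
import torch
import torch.nn.functional as F


class SegTrainingModule(pl.LightningModule):
    """Hypothetical wrapper: supervises only the cell-type head for brevity."""

    def __init__(self):
        super().__init__()
        self.model = csmp.models.cellpose_base(type_classes=5)

    def training_step(self, batch, batch_idx):
        img, type_target = batch  # assumed batch structure: image + cell-type mask
        soft_masks = self.model(img)
        # a real setup would also supervise the "cellpose" flow head
        loss = F.cross_entropy(soft_masks["type"], type_target)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)


trainer = pl.Trainer(max_epochs=10)
# trainer.fit(SegTrainingModule(), train_dataloaders=my_dataloader)
```
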
## Models

## Code Examples

**Define a Cellpose model for cell segmentation.**

```python
import cellseg_models_pytorch as csmp
import torch

model = csmp.models.cellpose_base(type_classes=5)  # num of cell types in training data = 5
x = torch.rand([1, 3, 256, 256])

# NOTE: the outputs still need post-processing.
y = model(x)  # {"cellpose": [1, 2, 256, 256], "type": [1, 5, 256, 256]}
```
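
For orientation: the two `"cellpose"` channels correspond to the regressed Cellpose flows, and the five `"type"` channels are per-pixel cell-type scores. A minimal sketch of a crude preview (proper instance masks require the flow-based post-processing shown later):

```python
# crude per-pixel cell-type preview; NOT an instance segmentation mask
type_map = y["type"].softmax(dim=1).argmax(dim=1)  # [1, 256, 256]
```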

**Define a Cellpose model for panoptic segmentation.**

```python
import cellseg_models_pytorch as csmp
import torch

model = csmp.models.cellpose_plus(type_classes=5, sem_classes=3)  # num of cell types and tissue types
x = torch.rand([1, 3, 256, 256])

# NOTE: the outputs still need post-processing.
y = model(x)  # {"cellpose": [1, 2, 256, 256], "type": [1, 5, 256, 256], "sem": [1, 3, 256, 256]}
```

**Define a custom Cellpose model for panoptic segmentation from the architectural building blocks.**

```python
import cellseg_models_pytorch as csmp
import torch

# two decoder branches.
decoders = ("cellpose", "sem")

# three segmentation heads from the decoders.
heads = {
    "cellpose": {"cellpose": 2, "type": 5},
    "sem": {"sem": 3},
}

model = csmp.CellPoseUnet(
    decoders=decoders,                    # cellpose and semantic decoders
    heads=heads,                          # three output heads
    depth=5,                              # encoder depth
    out_channels=(256, 128, 64, 32, 16),  # num of out channels at each decoder stage
    layer_depths=(4, 4, 4, 4, 4),         # num of conv blocks at each decoder layer
    style_channels=256,                   # num of style vector channels
    enc_name="resnet50",                  # timm encoder
    enc_pretrain=True,                    # imagenet-pretrained encoder
    long_skip="unetpp",                   # unet++ long skips ("unet", "unetpp", "unet3p")
    merge_policy="sum",                   # merge policy for long skips ("cat", "sum")
    short_skip="residual",                # residual short skips ("basic", "residual", "dense")
    normalization="bcn",                  # batch-channel normalization ("bcn", "bn", "gn", "ln", "in")
    activation="gelu",                    # gelu activation instead of relu
    convolution="wsconv",                 # weight-standardized conv ("wsconv", "conv", "scaled_wsconv")
    attention="se",                       # squeeze-and-excitation attention ("se", "gc", "scse", "eca")
    pre_activate=False,                   # normalize and activate after convolution
)

x = torch.rand([1, 3, 256, 256])

# NOTE: the outputs still need post-processing.
y = model(x)  # {"cellpose": [1, 2, 256, 256], "type": [1, 5, 256, 256], "sem": [1, 3, 256, 256]}
```
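
To make the decoder/head wiring concrete, a small sanity check (assuming, as the shape comments above indicate, that the output dict is keyed by head names):

```python
# every head's output should have the channel count declared in `heads`
for decoder_name, decoder_heads in heads.items():
    for head_name, n_channels in decoder_heads.items():
        assert y[head_name].shape[1] == n_channels, head_name
```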

**Run sliding window inference and post-processing for large images.**

```python
import cellseg_models_pytorch as csmp

model = csmp.models.hovernet_base(type_classes=5)
# returns {"hovernet": [B, 2, H, W], "type": [B, 5, H, W], "inst": [B, 2, H, W]}

# the final activations for each model output
out_activations = {"hovernet": "tanh", "type": "softmax", "inst": "softmax"}

# models perform the poorest at the image boundaries. With overlapping patches,
# this causes artifacts that can be mitigated by smoothing the prediction boundaries.
out_boundary_weights = {"hovernet": True, "type": False, "inst": False}

# Sliding window inference for big images using overlapping patches
inferer = csmp.inference.SlidingWindowInferer(
    model=model,
    input_folder="/path/to/images/",
    checkpoint_path="/path/to/model/weights/",
    out_activations=out_activations,
    out_boundary_weights=out_boundary_weights,
    instance_postproc="hovernet",  # THE POST-PROCESSING METHOD
    patch_size=(256, 256),
    stride=128,
    # ... (further kwargs omitted here) ...
    normalization="percentile",  # same normalization as in training
)

# Run sliding window inference.
inferer.infer()

inferer.out_masks
# {"image1": {"inst": [H, W], "type": [H, W]}, ..., "imageN": {"inst": [H, W], "type": [H, W]}}
```
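
The `out_boundary_weights` flags exist because predictions are weakest near patch borders. A minimal sketch of the underlying idea (not the library's exact implementation): weight each patch's logits with a window that decays toward the edges, then normalize by the accumulated weights when stitching. Everything here (canvas size, stride, dummy logits) is illustrative:

```python
import torch

def boundary_weights(patch_size=(256, 256)):
    # ~1 at the patch center, ~0 at the edges
    wy = torch.hann_window(patch_size[0], periodic=False)
    wx = torch.hann_window(patch_size[1], periodic=False)
    return torch.outer(wy, wx)  # [H, W]

w = boundary_weights()

# dummy overlapping 256x256 patches on a 512x512 canvas (stride 128)
origins = [(y0, x0) for y0 in (0, 128, 256) for x0 in (0, 128, 256)]
canvas = torch.zeros(2, 512, 512)  # accumulated weighted logits
norm = torch.zeros(512, 512)       # accumulated weights

for y0, x0 in origins:
    logits = torch.rand(2, 256, 256)  # stand-in for model output on this patch
    canvas[:, y0:y0 + 256, x0:x0 + 256] += logits * w
    norm[y0:y0 + 256, x0:x0 + 256] += w

blended = canvas / norm.clamp_min(1e-8)  # smoothly blended full-image logits
```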