|
| 1 | +from typing import Any, Dict, Tuple |
| 2 | + |
| 3 | +import torch |
| 4 | + |
| 5 | +from ...decoders import Decoder |
| 6 | +from ...modules.misc_modules import StyleReshape |
| 7 | +from ._base_model import BaseMultiTaskSegModel |
| 8 | +from ._seg_head import SegHead |
| 9 | +from ._timm_encoder import TimmEncoder |
| 10 | + |
| 11 | +__all__ = ["MultiTaskUnet"] |
| 12 | + |
| 13 | + |
class MultiTaskUnet(BaseMultiTaskSegModel):
    def __init__(
        self,
        decoders: Tuple[str, ...],
        heads: Dict[str, Dict[str, int]],
        n_layers: Dict[str, Tuple[int, ...]],
        n_blocks: Dict[str, Tuple[Tuple[int, ...], ...]],
        out_channels: Dict[str, Tuple[int, ...]],
        long_skips: Dict[str, str],
        dec_params: Dict[str, Tuple[Dict[str, Any], ...]],
        depth: int = 4,
        style_channels: int = 256,
        enc_name: str = "resnet50",
        enc_pretrain: bool = True,
        enc_freeze: bool = False,
    ) -> None:
        """Create a universal multi-task (2D) unet.

        NOTE: For experimental purposes.

        Parameters
        ----------
        decoders : Tuple[str, ...]
            Names of the decoder branches of this network. E.g. ("cellpose", "sem")
        heads : Dict[str, Dict[str, int]]
            Names of the decoder branches (has to match `decoders`) mapped to dicts
            of output name - number of output classes. E.g.
            {"cellpose": {"type": 4, "cellpose": 2}, "sem": {"sem": 5}}
        n_layers : Dict[str, Tuple[int, ...]]
            The number of conv layers inside each of the decoder stages.
        n_blocks : Dict[str, Tuple[Tuple[int, ...], ...]]
            The number of blocks inside each conv-layer in each decoder stage.
        out_channels : Dict[str, Tuple[int, ...]]
            Out channels for each decoder stage, per decoder branch.
        long_skips : Dict[str, str]
            long skip method to be used. One of: "unet", "unetpp", "unet3p",
            "unet3p-lite", None
        dec_params : Dict[str, Tuple[Dict[str, Any], ...]]
            The keyword args for each of the distinct decoder stages. Includes the
            parameters for the long skip connections and convolutional layers of the
            decoder itself. See the `DecoderStage` documentation for more info.
        depth : int, default=4
            The depth of the encoder. I.e. Number of returned feature maps from
            the encoder. Maximum depth = 5.
        style_channels : int, default=256
            Number of style vector channels. If None, style vectors are ignored.
        enc_name : str, default="resnet50"
            Name of the encoder. See timm docs for more info.
        enc_pretrain : bool, default=True
            Whether to use imagenet pretrained weights in the encoder.
        enc_freeze : bool, default=False
            Freeze encoder weights for training.
        """
        super().__init__()
        self.enc_freeze = enc_freeze
        use_style = style_channels is not None
        self.heads = heads

        # set timm encoder
        self.encoder = TimmEncoder(
            enc_name,
            depth=depth,
            pretrained=enc_pretrain,
        )

        # style-vector module; operates on the deepest encoder feature map
        self.make_style = None
        if use_style:
            self.make_style = StyleReshape(self.encoder.out_channels[0], style_channels)

        # set decoders, one branch per name in `decoders`
        for decoder_name in decoders:
            decoder = Decoder(
                enc_channels=list(self.encoder.out_channels),
                style_channels=style_channels,
                out_channels=out_channels[decoder_name],
                long_skip=long_skips[decoder_name],
                n_layers=n_layers[decoder_name],
                n_blocks=n_blocks[decoder_name],
                stage_params=dec_params[decoder_name],
            )
            self.add_module(f"{decoder_name}_decoder", decoder)

        # set heads
        for decoder_name in heads.keys():
            # BUGFIX: previously this used the stale `decoder` variable left over
            # from the loop above, so every head received the out-channels of the
            # *last* decoder built. Look up the branch's own decoder instead.
            branch_decoder = getattr(self, f"{decoder_name}_decoder")
            for output_name, n_classes in heads[decoder_name].items():
                seg_head = SegHead(
                    in_channels=branch_decoder.out_channels,
                    out_channels=n_classes,
                    kernel_size=1,
                )
                self.add_module(f"{output_name}_seg_head", seg_head)

        self.name = f"MultiTaskUnet-{enc_name}"

        # init decoder weights
        self.initialize()

        # freeze encoder if specified
        if enc_freeze:
            self.freeze_encoder()

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Forward pass of Multi-task U-net.

        Parameters
        ----------
        x : torch.Tensor
            Input image batch.

        Returns
        -------
        Dict[str, torch.Tensor]
            Output name mapped to the corresponding seg-head output.
        """
        self._check_input_shape(x)

        feats = self.encoder(x)

        style = None
        if self.make_style is not None:
            # style vector is computed from the deepest encoder feature map
            style = self.make_style(feats[0])

        dec_feats = self.forward_dec_features(feats, style)

        # map every head name to a decoder feature map; heads whose name does
        # not match a decoder branch fall back to the aux-branch features
        for decoder_name in self.heads.keys():
            for head_name in self.heads[decoder_name].keys():
                k = self.aux_key if head_name not in dec_feats.keys() else head_name
                dec_feats[head_name] = dec_feats[k]

        out = self.forward_heads(dec_feats)

        return out
0 commit comments