# From shallow to deep: exploiting feature-based classifiers for domain adaptation in semantic segmentation
This repository provides a PyTorch implementation of our method from the paper:
```bibtex
@article{article,
  author = {Matskevych, Alex and Wolny, Adrian and Pape, Constantin and Kreshuk, Anna},
  year = {2022},
  month = {03},
  pages = {},
  title = {From Shallow to Deep: Exploiting Feature-Based Classifiers for Domain Adaptation in Semantic Segmentation},
  volume = {4},
  journal = {Frontiers in Computer Science},
  doi = {10.3389/fcomp.2022.805166}
}
```
It is a fork of Adrian Wolny's [pytorch-3dunet](https://github.com/wolny/pytorch-3dunet), adjusted for the following tasks:
- Simple Random Forest training and prediction (for the Prediction Enhancer)
- Prediction Enhancer
- Pseudo-label Net (for further domain adaptation with pseudo-labels), including
  - Rectification losses
  - Consistency model & losses
This module is designed to work alongside the original repository without conflicts (both should be usable at the same time).
This library can be used with Anaconda. After installing Anaconda, one can create a new environment from the environment.yaml file by running

```
conda env create -f environment.yaml
```

and then activate the environment after installation by running

```
conda activate shallow2deep
```

One can then install the shallow2deep library into the environment by running

```
pip install .
```

in the root folder.
You need to activate the environment every time you want to use the following commands.
The following steps are illustrated for the training set of the source dataset.
Run them for every dataset you want to use (e.g., training and validation).
Note: The code for the creation of filters and RFs is written to work in 2D.
This means the RFs and filters do not take 3D information into account.
To create filters, you can run

```
shallow2deep_create_filters \
    --source_path /scratch/matskevy/epfl_mitos_2d/training.h5 \
    --save_path /scratch/matskevy/epfl_to_vnc/epfl_filter_data/epfl_mitos_training_filtered.h5 \
    --raw_internal_path raw \
    --label_internal_path label
```

To maximize the enhancer's performance later on, it is important that all RFs trained later use the same filters. You ensure this by running this command for every dataset that you need filters for / want to train the Enhancer on.
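For intuition, the sketch below shows the kind of 2D per-pixel filter features such a filter bank computes, using scipy.ndimage. It is illustrative only: the exact filter set and the HDF5 layout produced by `shallow2deep_create_filters` may differ.

```python
# Illustrative: an ilastik-style 2D filter bank computed with scipy.ndimage.
# The actual filters used by shallow2deep_create_filters and the layout of
# the saved HDF5 file may differ.
import numpy as np
from scipy import ndimage

def filter_features(image: np.ndarray, sigmas=(0.7, 1.6, 3.5)) -> np.ndarray:
    """Stack per-pixel filter responses into a (H, W, n_features) array."""
    img = image.astype(np.float32)
    feats = [img]
    for s in sigmas:
        feats.append(ndimage.gaussian_filter(img, sigma=s))              # smoothing
        feats.append(ndimage.gaussian_gradient_magnitude(img, sigma=s))  # edge strength
        feats.append(ndimage.gaussian_laplace(img, sigma=s))             # blob response
    return np.stack(feats, axis=-1)
```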
To generate RFs, you can run

```
shallow2deep_rf_generation \
    --volume_path /scratch/matskevy/epfl_to_vnc/epfl_filter_data/epfl_mitos_training_filtered.h5 \
    --rf_save_folder /scratch/matskevy/epfl_to_vnc/rf_data/epfl_training_rfs/ \
    --rf_num 200 \
    --min_patch_len 100 \
    --max_patch_len 600 \
    --label_internal_path label
```

While we used 1000 RFs in the paper, 100 or 200 RFs should be sufficient for most cases.
Note: Validation usually needs far fewer RFs (`--rf_num` can be reduced to a single-digit number);
otherwise you have to set `validate_iters` in the trainer config to avoid too many validation steps.
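Conceptually, each generated RF is fit on the per-pixel filter features inside one randomly sampled 2D patch, roughly as in the sketch below (the sampling details and RF hyperparameters of the actual CLI may differ):

```python
# Illustrative: fit one RF on a random 2D patch of filter features.
# Hyperparameters and sampling details of shallow2deep_rf_generation may differ.
import numpy as np
from sklearn.ensemble import RandomForestClassifier

def train_patch_rf(features, labels, min_len=100, max_len=600, seed=None):
    """features: (H, W, n_features); labels: (H, W) with values {0, 1}."""
    rng = np.random.default_rng(seed)
    h, w = labels.shape
    ph = int(rng.integers(min_len, min(max_len, h) + 1))  # random patch height
    pw = int(rng.integers(min_len, min(max_len, w) + 1))  # random patch width
    y0 = int(rng.integers(0, h - ph + 1))
    x0 = int(rng.integers(0, w - pw + 1))
    X = features[y0:y0 + ph, x0:x0 + pw].reshape(-1, features.shape[-1])
    y = labels[y0:y0 + ph, x0:x0 + pw].reshape(-1)
    rf = RandomForestClassifier(n_estimators=100, n_jobs=-1)
    rf.fit(X, y)
    return rf
```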
To predict with the generated RFs, you can run

```
shallow2deep_rf_prediction \
    --volume_path /scratch/matskevy/epfl_to_vnc/epfl_filter_data/epfl_mitos_training_filtered.h5 \
    --rf_path /scratch/matskevy/epfl_to_vnc/rf_data/epfl_training_rfs/ \
    --save_folder /scratch/matskevy/epfl_to_vnc/rf_data/epfl_training_rf_preds/ \
    --raw_internal_path raw \
    --label_internal_path label
```

This creates the RF predictions and thus the dataset for the Prediction Enhancer training.
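Per pixel, the prediction step produces a foreground probability; a minimal sketch of that idea (the real CLI handles many RFs and the dataset keys itself):

```python
# Illustrative: turn one trained RF into a per-pixel foreground probability
# map, which is what the Prediction Enhancer later consumes as 'raw'.
def predict_fg(rf, features):
    """features: (H, W, n_features) -> (H, W) foreground probabilities."""
    h, w, n = features.shape
    proba = rf.predict_proba(features.reshape(-1, n))  # (H*W, n_classes)
    return proba[:, 1].reshape(h, w)  # assumes class 1 is the foreground
```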
Now the label data needs to be added to these predictions. This can be done with the simple script

```
shallow2deep_add_volume_to_dataset \
    --source_file /scratch/matskevy/epfl_mitos_2d/training.h5 \
    --save_path /scratch/matskevy/epfl_to_vnc/rf_data/epfl_training_rf_preds/ \
    --internal_path label
```

The PE is a basic U-Net that can be trained from the previously generated RF predictions. It takes the RF predictions as
'raw' and trains to infer the true 'label'. It is trained by the same principles as most U-Nets and uses the original functionality of the underlying pytorch-3dunet library.
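In spirit, one PE training step looks like the following simplified sketch; the stand-in net, BCE loss, and optimizer here are assumptions, since the actual model and loss are defined via the pytorch-3dunet config:

```python
# Simplified illustration of the PE training idea: a 2D net maps RF
# probability maps ('raw') to the true 'label'. The stand-in net and
# BCE loss are assumptions; the real setup comes from the YAML config.
import torch
import torch.nn as nn

net = nn.Sequential(  # stand-in for the 2D U-Net used in practice
    nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(),
    nn.Conv2d(16, 1, 3, padding=1),
)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)

def train_step(rf_pred: torch.Tensor, label: torch.Tensor) -> float:
    """rf_pred, label: (B, 1, H, W); rf_pred holds RF probabilities."""
    optimizer.zero_grad()
    logits = net(rf_pred)  # the enhancer only ever sees the shallow prediction
    loss = criterion(logits, label.float())
    loss.backward()
    optimizer.step()
    return loss.item()
```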
Example configs can be found in the resources folder:

```
shallow2deep_train --config /scratch/matskevy/shallow2deep/resources/2DUnet_epfl_enhancer/epfl_train_config.yml
```

When training in Ilastik, select the right filters (and the right dimensions) to train the RF with self-made annotations on the target dataset.
You can then use it to make predictions for the target dataset(s) directly from Ilastik, or save the project and
use the headless mode to predict with a command like
```
/scratch/matskevy/ilastik/ilastik-1.4.0b20-Linux/run_ilastik.sh \
    --headless \
    --project=/scratch/matskevy/epfl_to_vnc/ilastik_data/vnc/ilastik_vnc_2d.ilp \
    --output_format=hdf5 \
    --output_filename_format=/scratch/matskevy/epfl_to_vnc/ilastik_data/vnc_prediction/ilastik_predicted_training.h5 \
    --export_source "Probabilities" \
    --input_axes=zyx /scratch/matskevy/vnc_2d_mitos/training.h5/raw
```

In this case, `/scratch/matskevy/ilastik/ilastik-1.4.0b20-Linux/run_ilastik.sh` is the path to Ilastik.
It is important to have the foreground RF prediction probabilities (not just a binary prediction) for the PE
(as it was also trained with probabilities). The headless prediction above yields two channels,
one containing the foreground and the other the background prediction probabilities.
One can use the helper script

```
shallow2deep_extract_single_axis \
    --source_path /scratch/matskevy/epfl_to_vnc/ilastik_data/vnc_prediction/ilastik_predicted_test.h5 \
    --src_internal_path exported_data \
    --tgt_internal_path fg_pred \
    --axis -1 \
    --channel_num 0
```

to create foreground-prediction volumes in the prediction file under the key `fg_pred`.
Note: Normally the probabilities in this Ilastik binary prediction are aligned as (fg, bg) on the last axis,
so the default value for `--axis` is -1 and for `--channel_num` it is 0.
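For reference, the extraction amounts to the following h5py operation (file and dataset names follow the command above; this is an illustrative equivalent of the helper script, not its actual implementation):

```python
# Keep channel 0 (foreground) of the last axis of the ilastik export.
import h5py

with h5py.File("ilastik_predicted_test.h5", "r+") as f:
    probs = f["exported_data"][...]  # (..., 2) with (fg, bg) on the last axis
    f.create_dataset("fg_pred", data=probs[..., 0])
```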
The PE can now be used on the Ilastik RF predictions on the target data to infer enhanced predictions.
This works like the usual NN prediction in the pytorch-3dunet library, and an example config can be found in the resources folder:

```
shallow2deep_predict --config /scratch/matskevy/shallow2deep/resources/2DUnet_epfl_enhancer/epfl_test_config.yml
```

To train a full segmentor with the enhanced predictions as pseudo-labels, one needs to prepare the enhanced predictions to serve as training files.
- Predictions have an additional channel dimension. To bring them to the same dimensions as the raw data, one can again use

  ```
  shallow2deep_extract_single_axis \
      --source_path /scratch/matskevy/epfl_to_vnc/nn_data/enhancer/epfl_to_vnc_mito_predictions/ilastik_predicted_training_predictions.h5 \
      --src_internal_path predictions \
      --tgt_internal_path pseudo_label \
      --axis 0 \
      --channel_num 0
  ```

  to create a volume without the channel dimension.
- One needs to add the `raw` volume to the enhanced predictions, e.g. by running

  ```
  shallow2deep_add_volume_to_dataset \
      --source_file /scratch/matskevy/vnc_2d_mitos/training.h5 \
      --save_path /scratch/matskevy/epfl_to_vnc/nn_data/enhancer/epfl_to_vnc_mito_predictions/ilastik_predicted_training_predictions.h5 \
      --internal_path raw
  ```

Once the datasets contain both necessary volumes, `raw` and `pseudo_label`, one can start the Pseudo-label Net training.
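Before launching the training, it can be worth sanity-checking that each training file really contains both volumes with matching shapes, e.g.:

```python
# Illustrative sanity check before Pseudo-label Net training: both volumes
# must exist and have the same shape. The file name is a placeholder.
import h5py

with h5py.File("ilastik_predicted_training_predictions.h5", "r") as f:
    raw, pseudo = f["raw"], f["pseudo_label"]
    assert raw.shape == pseudo.shape, (raw.shape, pseudo.shape)
    print("raw:", raw.shape, "pseudo_label:", pseudo.shape)
```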
Example configs for training can be found in the resources folder. The different training scenarios include
- Pseudo-label Net training
- Pseudo-label Net training with consistency (as described in the paper)
- Pseudo-label Net training with label rectification (as described in the paper)
- Pseudo-label Net training with consistency & label rectification
```
shallow2deep_train --config /scratch/matskevy/shallow2deep/resources/2DUnet_vnc_pseudo_label/train/train_config_consistency_rectification.yml
```

Inference can be done as usual with the pytorch-3dunet library, as can be seen in the example config in the resources folder:
```
shallow2deep_predict --config /scratch/matskevy/shallow2deep/resources/2DUnet_vnc_pseudo_label/test/test_config_consistency_rectification.yml
```

- Set the `--unequal_filter_order` flag in the RF prediction if the RFs have different sources and the filters could therefore have different orders.
- During filter creation, set the `--compute_boundaries` flag to compute boundaries as labels in case the original labels are a segmentation.
- Most validation metrics expect the labels to be binary, so they binarize the target before computing the metric. This can lead to problems during Pseudo-label Net training. To combat this, the target is rounded to the nearest integer before being passed to a validation metric. If this procedure is not wanted, one can set the `--do_not_round_target` flag in the `eval_metric` config.
- If consistency (as defined in the paper) is used during training, the `final_activation` is crucial during the training of the Pseudo-label Net. The final activation is applied to the EMA Net prediction to normalize it before computing the consistency loss (see the sketch below).
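To make the last point concrete, here is a minimal mean-teacher-style sketch of how an EMA net, the final activation, and a consistency loss fit together. The stand-in net, sigmoid, MSE loss, and the alpha value are assumptions; the actual loss, augmentations, and EMA schedule are defined by the training configs.

```python
# Simplified consistency sketch: the EMA ("teacher") net is a slowly
# updated copy of the student, and final_activation normalizes its
# prediction before the consistency loss is computed.
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

student = nn.Sequential(nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(),
                        nn.Conv2d(16, 1, 3, padding=1))  # stand-in net
ema_net = copy.deepcopy(student)
for p in ema_net.parameters():
    p.requires_grad_(False)  # the teacher is never updated by gradients

final_activation = nn.Sigmoid()

@torch.no_grad()
def ema_update(alpha: float = 0.999):
    """Teacher parameters follow the student as an exponential moving average."""
    for p_s, p_t in zip(student.parameters(), ema_net.parameters()):
        p_t.mul_(alpha).add_(p_s, alpha=1.0 - alpha)

def consistency_loss(x: torch.Tensor) -> torch.Tensor:
    with torch.no_grad():
        target = final_activation(ema_net(x))  # normalized EMA prediction
    return F.mse_loss(final_activation(student(x)), target)
```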
