From b297a4085197653ad8aaafdd7cfd86367b0859a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Mon, 2 Jan 2023 15:36:59 +0100 Subject: [PATCH 01/15] added jax backend support for macenko --- torchstain/base/normalizers/macenko.py | 3 + torchstain/jax/__init__.py | 1 + torchstain/jax/normalizers/__init__.py | 1 + torchstain/jax/normalizers/macenko.py | 105 +++++++++++++++++++++++++ torchstain/jax/utils/__init__.py | 0 torchstain/tf/normalizers/macenko.py | 1 - 6 files changed, 110 insertions(+), 1 deletion(-) create mode 100644 torchstain/jax/__init__.py create mode 100644 torchstain/jax/normalizers/__init__.py create mode 100644 torchstain/jax/normalizers/macenko.py create mode 100644 torchstain/jax/utils/__init__.py diff --git a/torchstain/base/normalizers/macenko.py b/torchstain/base/normalizers/macenko.py index 2c0c988..db806e3 100644 --- a/torchstain/base/normalizers/macenko.py +++ b/torchstain/base/normalizers/macenko.py @@ -8,5 +8,8 @@ def MacenkoNormalizer(backend='torch'): elif backend == "tensorflow": from torchstain.tf.normalizers.macenko import TensorFlowMacenkoNormalizer return TensorFlowMacenkoNormalizer() + elif backend == "jax": + from torchstain.jax.normalizers.macenko import JaxMacenkoNormalizer + return JaxMacenkoNormalizer() else: raise Exception(f'Unknown backend {backend}') diff --git a/torchstain/jax/__init__.py b/torchstain/jax/__init__.py new file mode 100644 index 0000000..c36ce2a --- /dev/null +++ b/torchstain/jax/__init__.py @@ -0,0 +1 @@ +from torchstain.jax import normalizers, utils diff --git a/torchstain/jax/normalizers/__init__.py b/torchstain/jax/normalizers/__init__.py new file mode 100644 index 0000000..de86fc2 --- /dev/null +++ b/torchstain/jax/normalizers/__init__.py @@ -0,0 +1 @@ +from torchstain.jax.normalizers.macenko import JaxMacenkoNormalizer diff --git a/torchstain/jax/normalizers/macenko.py b/torchstain/jax/normalizers/macenko.py new file mode 100644 index 0000000..308b1a7 --- /dev/null +++ b/torchstain/jax/normalizers/macenko.py @@ -0,0 +1,105 @@ +import jax +from jax import numpy as jnp +from torchstain.base.normalizers import HENormalizer + + +class JaxMacenkoNormalizer(HENormalizer): + def __init__(self): + super().__init__() + + self.HERef = jnp.array([[0.5626, 0.2159], + [0.7201, 0.8012], + [0.4062, 0.5581]]) + self.maxCRef = jnp.array([1.9705, 1.0308]) + + def __convert_rgb2od(self, I, Io=240, beta=0.15): + # calculate optical density + OD = -jnp.log((I.astype(jnp.float32) + 1) / Io) + + # remove transparent pixels + ODhat = OD[~jnp.any(OD < beta, axis=1)] + + return OD, ODhat + + def __find_HE(self, ODhat, eigvecs, alpha): + # project on the plane spanned by the eigenvectors corresponding to the two + # largest eigenvalues + That = ODhat.dot(eigvecs[:, 1:3]) + + phi = jnp.arctan2(That[:, 1], That[:, 0]) + + minPhi = jnp.percentile(phi, alpha) + maxPhi = jnp.percentile(phi, 100 - alpha) + + vMin = eigvecs[:, 1:3].dot(jnp.array([(jnp.cos(minPhi), jnp.sin(minPhi))]).T) + vMax = eigvecs[:, 1:3].dot(jnp.array([(jnp.cos(maxPhi), jnp.sin(maxPhi))]).T) + + # a heuristic to make the vector corresponding to hematoxylin first and the + # one corresponding to eosin second + if vMin[0] > vMax[0]: + HE = jnp.array((vMin[:, 0], vMax[:, 0])).T + else: + HE = jnp.array((vMax[:, 0], vMin[:, 0])).T + + return HE + + def __find_concentration(self, OD, HE): + # rows correspond to channels (RGB), columns to OD values + Y = jnp.reshape(OD, (-1, 3)).T + + # determine concentrations of the individual stains + C = jnp.linalg.lstsq(HE, Y, rcond=None)[0] + + return C + + def __compute_matrices(self, I, Io, alpha, beta): + I = I.reshape((-1, 3)) + + OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta) + + # compute eigenvectors + _, eigvecs = jnp.linalg.eigh(jnp.cov(ODhat.T)) + + HE = self.__find_HE(ODhat, eigvecs, alpha) + + C = self.__find_concentration(OD, HE) + + # normalize stain concentrations + maxC = jnp.array([jnp.percentile(C[0, :], 99), jnp.percentile(C[1, :],99)]) + + return HE, C, maxC + + def fit(self, I, Io=240, alpha=1, beta=0.15): + HE, _, maxC = self.__compute_matrices(I, Io, alpha, beta) + + self.HERef = HE + self.maxCRef = maxC + + def normalize(self, I, Io=240, alpha=1, beta=0.15, stains=True): + h, w, c = I.shape + I = I.reshape((-1,3)) + + HE, C, maxC = self.__compute_matrices(I, Io, alpha, beta) + + maxC = jnp.divide(maxC, self.maxCRef) + C2 = jnp.divide(C, maxC[:, jnp.newaxis]) + + # recreate the image using reference mixing matrix + Inorm = jnp.multiply(Io, jnp.exp(-self.HERef.dot(C2))) + Inorm.at[Inorm > 255].set(255) + Inorm = jnp.reshape(Inorm.T, (h, w, c)).astype(jnp.uint8) + + + H, E = None, None + + if stains: + # unmix hematoxylin and eosin + H = jnp.multiply(Io,jnp.exp(jnp.expand_dims(-self.HERef[:, 0], axis=1).dot(jnp.expand_dims(C2[0, :], axis=0)))) + H.at[H > 255].set(255) + H = jnp.reshape(H.T, (h, w, c)).astype(jnp.uint8) + + E = jnp.multiply(Io, jnp.exp(jnp.expand_dims(-self.HERef[:, 1], axis=1).dot(jnp.expand_dims(C2[1, :], axis=0)))) + E.at[E > 255].set(255) + E = jnp.reshape(E.T, (h, w, c)).astype(jnp.uint8) + + return Inorm, H, E diff --git a/torchstain/jax/utils/__init__.py b/torchstain/jax/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/torchstain/tf/normalizers/macenko.py b/torchstain/tf/normalizers/macenko.py index bf21666..578560a 100644 --- a/torchstain/tf/normalizers/macenko.py +++ b/torchstain/tf/normalizers/macenko.py @@ -1,7 +1,6 @@ import tensorflow as tf from torchstain.base.normalizers.he_normalizer import HENormalizer from torchstain.tf.utils import cov, percentile, solveLS -import numpy as np import tensorflow.keras.backend as K From 07c2c5849d9ee05bef1d855f3440d20f40aab23c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Tue, 3 Jan 2023 05:36:09 +0100 Subject: [PATCH 02/15] fixed bug in jax macenko --- torchstain/jax/normalizers/macenko.py | 34 +++++++++++++++++---------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/torchstain/jax/normalizers/macenko.py b/torchstain/jax/normalizers/macenko.py index 308b1a7..0281059 100644 --- a/torchstain/jax/normalizers/macenko.py +++ b/torchstain/jax/normalizers/macenko.py @@ -1,6 +1,8 @@ import jax +from jax import lax from jax import numpy as jnp from torchstain.base.normalizers import HENormalizer +from functools import partial class JaxMacenkoNormalizer(HENormalizer): @@ -17,7 +19,13 @@ def __convert_rgb2od(self, I, Io=240, beta=0.15): OD = -jnp.log((I.astype(jnp.float32) + 1) / Io) # remove transparent pixels - ODhat = OD[~jnp.any(OD < beta, axis=1)] + #ODhat = OD[~jnp.any(OD < beta, axis=1)] + + # jax dont support dynamic shapes, but this: https://stackoverflow.com/a/71694754 + # @FIXME: Not identical to numpy approach above! + mask = ~jnp.any(OD < beta, axis=1) + indices = jnp.where(mask, size=len(mask), fill_value=255) + ODhat = OD.at[indices].get() # mode="fill", fill_value=0) return OD, ODhat @@ -36,10 +44,12 @@ def __find_HE(self, ODhat, eigvecs, alpha): # a heuristic to make the vector corresponding to hematoxylin first and the # one corresponding to eosin second - if vMin[0] > vMax[0]: - HE = jnp.array((vMin[:, 0], vMax[:, 0])).T - else: - HE = jnp.array((vMax[:, 0], vMin[:, 0])).T + HE = lax.cond( + vMin[0, 0] > vMax[0, 0], + lambda x: jnp.array((x[0], x[1])).T, + lambda x: jnp.array((x[0], x[1])).T, + (vMin[:, 0], vMax[:, 0]) + ) return HE @@ -75,9 +85,10 @@ def fit(self, I, Io=240, alpha=1, beta=0.15): self.HERef = HE self.maxCRef = maxC + @partial(jax.jit, static_argnums=(0,)) def normalize(self, I, Io=240, alpha=1, beta=0.15, stains=True): h, w, c = I.shape - I = I.reshape((-1,3)) + I = I.reshape((-1, 3)) HE, C, maxC = self.__compute_matrices(I, Io, alpha, beta) @@ -86,20 +97,19 @@ def normalize(self, I, Io=240, alpha=1, beta=0.15, stains=True): # recreate the image using reference mixing matrix Inorm = jnp.multiply(Io, jnp.exp(-self.HERef.dot(C2))) - Inorm.at[Inorm > 255].set(255) + Inorm = jnp.clip(Inorm, 0, 255) Inorm = jnp.reshape(Inorm.T, (h, w, c)).astype(jnp.uint8) - H, E = None, None - if stains: + if False: # unmix hematoxylin and eosin - H = jnp.multiply(Io,jnp.exp(jnp.expand_dims(-self.HERef[:, 0], axis=1).dot(jnp.expand_dims(C2[0, :], axis=0)))) - H.at[H > 255].set(255) + H = jnp.multiply(Io, jnp.exp(jnp.expand_dims(-self.HERef[:, 0], axis=1).dot(jnp.expand_dims(C2[0, :], axis=0)))) + H = jnp.clip(H, 0, 255) H = jnp.reshape(H.T, (h, w, c)).astype(jnp.uint8) E = jnp.multiply(Io, jnp.exp(jnp.expand_dims(-self.HERef[:, 1], axis=1).dot(jnp.expand_dims(C2[1, :], axis=0)))) - E.at[E > 255].set(255) + E = jnp.clip(E, 0, 255) E = jnp.reshape(E.T, (h, w, c)).astype(jnp.uint8) return Inorm, H, E From a20c69d031983a23472c672bd6047e50cce2a51c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Tue, 3 Jan 2023 07:01:57 +0100 Subject: [PATCH 03/15] fixed masking bug in jax macenko --- torchstain/jax/normalizers/macenko.py | 60 ++++++++++----------------- 1 file changed, 22 insertions(+), 38 deletions(-) diff --git a/torchstain/jax/normalizers/macenko.py b/torchstain/jax/normalizers/macenko.py index 0281059..7c663a2 100644 --- a/torchstain/jax/normalizers/macenko.py +++ b/torchstain/jax/normalizers/macenko.py @@ -14,30 +14,35 @@ def __init__(self): [0.4062, 0.5581]]) self.maxCRef = jnp.array([1.9705, 1.0308]) - def __convert_rgb2od(self, I, Io=240, beta=0.15): + def __find_concentration(self, OD, HE): + # rows correspond to channels (RGB), columns to OD values + Y = jnp.reshape(OD, (-1, 3)).T + + # determine concentrations of the individual stains + C = jnp.linalg.lstsq(HE, Y, rcond=None)[0] + + return C + + def __compute_matrices(self, I, Io, alpha, beta): + I = I.reshape((-1, 3)) + # calculate optical density OD = -jnp.log((I.astype(jnp.float32) + 1) / Io) - # remove transparent pixels - #ODhat = OD[~jnp.any(OD < beta, axis=1)] - - # jax dont support dynamic shapes, but this: https://stackoverflow.com/a/71694754 - # @FIXME: Not identical to numpy approach above! + # compute eigenvectors mask = ~jnp.any(OD < beta, axis=1) - indices = jnp.where(mask, size=len(mask), fill_value=255) - ODhat = OD.at[indices].get() # mode="fill", fill_value=0) + cov = jnp.cov(OD.T, fweights=mask.astype(jnp.int32)) + _, eigvecs = jnp.linalg.eigh(cov) - return OD, ODhat + Th = OD.dot(eigvecs[:, 1:3]) - def __find_HE(self, ODhat, eigvecs, alpha): - # project on the plane spanned by the eigenvectors corresponding to the two - # largest eigenvalues - That = ODhat.dot(eigvecs[:, 1:3]) + phi = jnp.arctan2(Th[:, 1], Th[:, 0]) - phi = jnp.arctan2(That[:, 1], That[:, 0]) + phi = jnp.where(mask, phi, jnp.inf) + pvalid = mask.mean() # proportion that is valid and not masked - minPhi = jnp.percentile(phi, alpha) - maxPhi = jnp.percentile(phi, 100 - alpha) + minPhi = jnp.percentile(phi, alpha * pvalid) + maxPhi = jnp.percentile(phi, (100 - alpha) * pvalid) vMin = eigvecs[:, 1:3].dot(jnp.array([(jnp.cos(minPhi), jnp.sin(minPhi))]).T) vMax = eigvecs[:, 1:3].dot(jnp.array([(jnp.cos(maxPhi), jnp.sin(maxPhi))]).T) @@ -51,27 +56,6 @@ def __find_HE(self, ODhat, eigvecs, alpha): (vMin[:, 0], vMax[:, 0]) ) - return HE - - def __find_concentration(self, OD, HE): - # rows correspond to channels (RGB), columns to OD values - Y = jnp.reshape(OD, (-1, 3)).T - - # determine concentrations of the individual stains - C = jnp.linalg.lstsq(HE, Y, rcond=None)[0] - - return C - - def __compute_matrices(self, I, Io, alpha, beta): - I = I.reshape((-1, 3)) - - OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta) - - # compute eigenvectors - _, eigvecs = jnp.linalg.eigh(jnp.cov(ODhat.T)) - - HE = self.__find_HE(ODhat, eigvecs, alpha) - C = self.__find_concentration(OD, HE) # normalize stain concentrations @@ -102,7 +86,7 @@ def normalize(self, I, Io=240, alpha=1, beta=0.15, stains=True): H, E = None, None - if False: + if stains: # unmix hematoxylin and eosin H = jnp.multiply(Io, jnp.exp(jnp.expand_dims(-self.HERef[:, 0], axis=1).dot(jnp.expand_dims(C2[0, :], axis=0)))) H = jnp.clip(H, 0, 255) From 2366cc55ded2691d0e50ee2c40ba17c71e92ad7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Sun, 22 Jan 2023 22:45:56 +0100 Subject: [PATCH 04/15] made Jax Macenko class a PyTree -> Jax compatible [no ci] --- torchstain/jax/normalizers/macenko.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/torchstain/jax/normalizers/macenko.py b/torchstain/jax/normalizers/macenko.py index 7c663a2..3f9e89a 100644 --- a/torchstain/jax/normalizers/macenko.py +++ b/torchstain/jax/normalizers/macenko.py @@ -2,16 +2,15 @@ from jax import lax from jax import numpy as jnp from torchstain.base.normalizers import HENormalizer -from functools import partial - +from jax import tree_util class JaxMacenkoNormalizer(HENormalizer): def __init__(self): super().__init__() self.HERef = jnp.array([[0.5626, 0.2159], - [0.7201, 0.8012], - [0.4062, 0.5581]]) + [0.7201, 0.8012], + [0.4062, 0.5581]]) self.maxCRef = jnp.array([1.9705, 1.0308]) def __find_concentration(self, OD, HE): @@ -55,7 +54,6 @@ def __compute_matrices(self, I, Io, alpha, beta): lambda x: jnp.array((x[0], x[1])).T, (vMin[:, 0], vMax[:, 0]) ) - C = self.__find_concentration(OD, HE) # normalize stain concentrations @@ -63,13 +61,14 @@ def __compute_matrices(self, I, Io, alpha, beta): return HE, C, maxC + @jax.jit def fit(self, I, Io=240, alpha=1, beta=0.15): HE, _, maxC = self.__compute_matrices(I, Io, alpha, beta) self.HERef = HE self.maxCRef = maxC - @partial(jax.jit, static_argnums=(0,)) + @jax.jit def normalize(self, I, Io=240, alpha=1, beta=0.15, stains=True): h, w, c = I.shape I = I.reshape((-1, 3)) @@ -97,3 +96,16 @@ def normalize(self, I, Io=240, alpha=1, beta=0.15, stains=True): E = jnp.reshape(E.T, (h, w, c)).astype(jnp.uint8) return Inorm, H, E + + def _tree_flatten(self): + children = () # arrays / dynamic values + aux = () # static values + return (), () + + @classmethod + def _tree_unflatten(cls, aux, children): + return cls(*children, *aux) + +tree_util.register_pytree_node( + JaxMacenkoNormalizer, JaxMacenkoNormalizer._tree_flatten,JaxMacenkoNormalizer._tree_unflatten +) From 4eddfc93be96e5bcdc951787444bbdbd41ad8439 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Sun, 22 Jan 2023 23:11:20 +0100 Subject: [PATCH 05/15] minor bug in _tree_flatten [no ci] --- torchstain/jax/normalizers/macenko.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchstain/jax/normalizers/macenko.py b/torchstain/jax/normalizers/macenko.py index 3f9e89a..66afced 100644 --- a/torchstain/jax/normalizers/macenko.py +++ b/torchstain/jax/normalizers/macenko.py @@ -100,7 +100,7 @@ def normalize(self, I, Io=240, alpha=1, beta=0.15, stains=True): def _tree_flatten(self): children = () # arrays / dynamic values aux = () # static values - return (), () + return children, aux @classmethod def _tree_unflatten(cls, aux, children): From 4890bdb3a92daef18edc48fab1116d8b1c7a7419 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Mon, 23 Jan 2023 01:22:28 +0100 Subject: [PATCH 06/15] added macenko jax unit test --- tests/test_jax.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 tests/test_jax.py diff --git a/tests/test_jax.py b/tests/test_jax.py new file mode 100644 index 0000000..7366b90 --- /dev/null +++ b/tests/test_jax.py @@ -0,0 +1,36 @@ +import os +import cv2 +import torchstain +import torchstain.jax +import time +from skimage.metrics import structural_similarity as ssim +import numpy as np +from jax import numpy as jnp + +def test_macenko_jax(): + size = 1024 + curr_file_path = os.path.dirname(os.path.realpath(__file__)) + target = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/target.png")), cv2.COLOR_BGR2RGB), (size, size)) + to_transform = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/source.png")), cv2.COLOR_BGR2RGB), (size, size)) + + # setup preprocessing and preprocess image to be normalized + T = lambda x: jnp.array(np.moveaxis(x, -1, 0).astype("float32")) + t_to_transform = T(to_transform) + + # initialize normalizers for each backend and fit to target image + normalizer = torchstain.normalizers.MacenkoNormalizer(backend='numpy') + normalizer.fit(target) + + jax_normalizer = torchstain.normalizers.MacenkoNormalizer(backend='jax') + jax_normalizer.fit(T(target)) + + # transform + result_numpy, _, _ = normalizer.normalize(I=to_transform) + result_jax, _, _ = jax_normalizer.normalize(I=t_to_transform) + + # convert to numpy and set dtype + result_numpy = result_numpy.astype("float32") + result_jax = np.asarray(result_jax).astype("float32") + + # assess whether the normalized images are identical across backends + np.testing.assert_almost_equal(ssim(result_numpy.flatten(), result_jax.flatten()), 1.0, decimal=4, verbose=True) From d84110f035b9a92af799041dbbb0d13eb22c2ddd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Mon, 23 Jan 2023 01:38:50 +0100 Subject: [PATCH 07/15] run CI on all branches + added jax tests to CI --- .github/workflows/tests_full.yml | 33 +++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests_full.yml b/.github/workflows/tests_full.yml index e550c88..14afd71 100644 --- a/.github/workflows/tests_full.yml +++ b/.github/workflows/tests_full.yml @@ -3,7 +3,7 @@ name: tests on: push: branches: - - main + - "*" pull_request: branches: - main @@ -102,3 +102,34 @@ jobs: - name: Run tests run: pytest -vs tests/test_torch.py + + test-jax: + needs: build + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ windows-2019, ubuntu-18.04, macos-11 ] + python-version: [ 3.6, 3.7, 3.8, 3.9 ] + + steps: + - uses: actions/checkout@v1 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Download artifact + uses: actions/download-artifact@master + with: + name: "Python wheel" + + - name: Install dependencies + run: | + pip install jax opencv-python-headless scikit-image + pip install pytest + + - name: Install wheel + run: pip install --find-links=${{github.workspace}} torchstain + + - name: Run tests + run: pytest -vs tests/test_jax.py From 80131085973d1ad5c1de4be9e2a9423e34dffd19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Mon, 23 Jan 2023 01:48:04 +0100 Subject: [PATCH 08/15] build full only on main + build quick on all branches --- .github/workflows/build.yaml | 4 +++- .github/workflows/tests_full.yml | 6 ++---- .github/workflows/tests_quick.yml | 26 ++++++++++++++++++++++++++ 3 files changed, 31 insertions(+), 5 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index e6719d7..b3d9efc 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -1,7 +1,9 @@ name: Build and upload to PyPI on: - push + push: + branches: + - main jobs: build_wheels: diff --git a/.github/workflows/tests_full.yml b/.github/workflows/tests_full.yml index 14afd71..7fb7d7b 100644 --- a/.github/workflows/tests_full.yml +++ b/.github/workflows/tests_full.yml @@ -3,7 +3,7 @@ name: tests on: push: branches: - - "*" + - main pull_request: branches: - main @@ -124,9 +124,7 @@ jobs: name: "Python wheel" - name: Install dependencies - run: | - pip install jax opencv-python-headless scikit-image - pip install pytest + run: pip install jax jaxlib opencv-python-headless scikit-image pytest - name: Install wheel run: pip install --find-links=${{github.workspace}} torchstain diff --git a/.github/workflows/tests_quick.yml b/.github/workflows/tests_quick.yml index 3e08828..a1f7fd9 100644 --- a/.github/workflows/tests_quick.yml +++ b/.github/workflows/tests_quick.yml @@ -84,3 +84,29 @@ jobs: - name: Run tests run: pytest -vs tests/test_torch.py + + + test-jax: + needs: build + runs-on: ubuntu-18.04 + + steps: + - uses: actions/checkout@v1 + - name: Set up Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 + + - name: Download artifact + uses: actions/download-artifact@master + with: + name: "Python wheel" + + - name: Install dependencies + run: pip install jax jaxlib opencv-python-headless scikit-image pytest + + - name: Install wheel + run: pip install --find-links=${{github.workspace}} torchstain + + - name: Run tests + run: pytest -vs tests/test_jax.py From cae7f954f5b211931b4b2d2d10a961693cdce584 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Mon, 23 Jan 2023 01:48:43 +0100 Subject: [PATCH 09/15] renamed name in workflows --- .github/workflows/tests_full.yml | 2 +- .github/workflows/tests_quick.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests_full.yml b/.github/workflows/tests_full.yml index 7fb7d7b..01e55e9 100644 --- a/.github/workflows/tests_full.yml +++ b/.github/workflows/tests_full.yml @@ -1,4 +1,4 @@ -name: tests +name: full test on: push: diff --git a/.github/workflows/tests_quick.yml b/.github/workflows/tests_quick.yml index a1f7fd9..c022dae 100644 --- a/.github/workflows/tests_quick.yml +++ b/.github/workflows/tests_quick.yml @@ -1,4 +1,4 @@ -name: tests +name: quick tests on: push: From 1877664078c2ccb6b781efcbbfb4efd64cef18dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Mon, 23 Jan 2023 01:52:11 +0100 Subject: [PATCH 10/15] fixed np.float issue + minor array refactor --- torchstain/numpy/normalizers/macenko.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchstain/numpy/normalizers/macenko.py b/torchstain/numpy/normalizers/macenko.py index 37dac55..968b716 100644 --- a/torchstain/numpy/normalizers/macenko.py +++ b/torchstain/numpy/normalizers/macenko.py @@ -10,13 +10,13 @@ def __init__(self): super().__init__() self.HERef = np.array([[0.5626, 0.2159], - [0.7201, 0.8012], - [0.4062, 0.5581]]) + [0.7201, 0.8012], + [0.4062, 0.5581]]) self.maxCRef = np.array([1.9705, 1.0308]) def __convert_rgb2od(self, I, Io=240, beta=0.15): # calculate optical density - OD = -np.log((I.astype(np.float)+1)/Io) + OD = -np.log((I.astype(np.float32)+1)/Io) # remove transparent pixels ODhat = OD[~np.any(OD < beta, axis=1)] From 79b91aabf9922aaf6c8bdb1cdc74c4b459f2345c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Sun, 29 Jan 2023 12:27:48 +0100 Subject: [PATCH 11/15] fix: setting jax.jit decorators differently yields correct output --- torchstain/jax/normalizers/macenko.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/torchstain/jax/normalizers/macenko.py b/torchstain/jax/normalizers/macenko.py index 66afced..9cd4379 100644 --- a/torchstain/jax/normalizers/macenko.py +++ b/torchstain/jax/normalizers/macenko.py @@ -10,9 +10,10 @@ def __init__(self): self.HERef = jnp.array([[0.5626, 0.2159], [0.7201, 0.8012], - [0.4062, 0.5581]]) - self.maxCRef = jnp.array([1.9705, 1.0308]) + [0.4062, 0.5581]], dtype=jnp.float32) + self.maxCRef = jnp.array([1.9705, 1.0308], dtype=jnp.float32) + @jax.jit def __find_concentration(self, OD, HE): # rows correspond to channels (RGB), columns to OD values Y = jnp.reshape(OD, (-1, 3)).T @@ -22,14 +23,14 @@ def __find_concentration(self, OD, HE): return C + @jax.jit def __compute_matrices(self, I, Io, alpha, beta): I = I.reshape((-1, 3)) # calculate optical density OD = -jnp.log((I.astype(jnp.float32) + 1) / Io) - - # compute eigenvectors - mask = ~jnp.any(OD < beta, axis=1) + + mask = ~jnp.any(OD < beta, axis=1) # to remove transparent pixels cov = jnp.cov(OD.T, fweights=mask.astype(jnp.int32)) _, eigvecs = jnp.linalg.eigh(cov) @@ -51,7 +52,7 @@ def __compute_matrices(self, I, Io, alpha, beta): HE = lax.cond( vMin[0, 0] > vMax[0, 0], lambda x: jnp.array((x[0], x[1])).T, - lambda x: jnp.array((x[0], x[1])).T, + lambda x: jnp.array((x[1], x[0])).T, (vMin[:, 0], vMax[:, 0]) ) C = self.__find_concentration(OD, HE) @@ -61,14 +62,14 @@ def __compute_matrices(self, I, Io, alpha, beta): return HE, C, maxC - @jax.jit + #@jax.jit def fit(self, I, Io=240, alpha=1, beta=0.15): HE, _, maxC = self.__compute_matrices(I, Io, alpha, beta) self.HERef = HE self.maxCRef = maxC - @jax.jit + #@jax.jit def normalize(self, I, Io=240, alpha=1, beta=0.15, stains=True): h, w, c = I.shape I = I.reshape((-1, 3)) From dd3e8e6dd6214d6156c91b2dd087c98ef965088a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Sun, 29 Jan 2023 12:30:23 +0100 Subject: [PATCH 12/15] fix: removed unwanted preprocessing in jax unit test --- tests/test_jax.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/test_jax.py b/tests/test_jax.py index 7366b90..e5a033f 100644 --- a/tests/test_jax.py +++ b/tests/test_jax.py @@ -13,20 +13,16 @@ def test_macenko_jax(): target = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/target.png")), cv2.COLOR_BGR2RGB), (size, size)) to_transform = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/source.png")), cv2.COLOR_BGR2RGB), (size, size)) - # setup preprocessing and preprocess image to be normalized - T = lambda x: jnp.array(np.moveaxis(x, -1, 0).astype("float32")) - t_to_transform = T(to_transform) - # initialize normalizers for each backend and fit to target image normalizer = torchstain.normalizers.MacenkoNormalizer(backend='numpy') normalizer.fit(target) jax_normalizer = torchstain.normalizers.MacenkoNormalizer(backend='jax') - jax_normalizer.fit(T(target)) + jax_normalizer.fit(target) # transform result_numpy, _, _ = normalizer.normalize(I=to_transform) - result_jax, _, _ = jax_normalizer.normalize(I=t_to_transform) + result_jax, _, _ = jax_normalizer.normalize(I=to_transform) # convert to numpy and set dtype result_numpy = result_numpy.astype("float32") From 9323c03e8f7019c2553aff76ddae03c1842db416 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Sun, 29 Jan 2023 12:37:33 +0100 Subject: [PATCH 13/15] refactor: updated README regarding JAX - match development branch --- README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 74af142..4c793b2 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ [![Pip Downloads](https://img.shields.io/pypi/dm/torchstain?label=pip%20downloads&logo=python)](https://pypi.org/project/torchstain/) [![DOI](https://zenodo.org/badge/323590093.svg)](https://zenodo.org/badge/latestdoi/323590093) -GPU-accelerated stain normalization tools for histopathological images. Compatible with PyTorch, TensorFlow, and Numpy. +GPU-accelerated stain normalization tools for histopathological images. Compatible with PyTorch, TensorFlow, Numpy, and JAX. Normalization algorithms currently implemented: - Macenko [\[1\]](#reference) (ported from [numpy implementation](https://github.com/schaugf/HEnorm_python)) @@ -47,11 +47,11 @@ norm, H, E = normalizer.normalize(I=t_to_transform, stains=True) ## Implemented algorithms -| Algorithm | numpy | torch | tensorflow | -|-|-|-|-| -| Macenko | ✓ | ✓ | ✓ | -| Reinhard | ✓ | ✓ | ✓ | -| Modified Reinhard | ✓ | ✓ | ✓ | +| Algorithm | numpy | torch | tensorflow | jax | +|-|-|-|-|-| +| Macenko | ✓ | ✓ | ✓ | ✓ | +| Reinhard | ✓ | ✓ | ✓ | ✗ | +| Modified Reinhard | ✓ | ✓ | ✓ | ✗ | ## Backend comparison From b54c2286239aafe5e370103392d07232d3ff3c21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Sun, 29 Jan 2023 12:39:40 +0100 Subject: [PATCH 14/15] fix: updated setup.py to properly support jax during install --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index d27acf7..cb0c60b 100644 --- a/setup.py +++ b/setup.py @@ -22,6 +22,7 @@ extras_require={ "tf": ["tensorflow"], "torch": ["torch"], + "jax": ["jax", "jaxlib"], }, classifiers=[ 'Development Status :: 4 - Beta', From 50f0fc1428be400ef09e080c943a70faee184c93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Sun, 29 Jan 2023 12:43:58 +0100 Subject: [PATCH 15/15] fix: run jax quick test CI on latest ubuntu + python 3.7 --- .github/workflows/tests_quick.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests_quick.yml b/.github/workflows/tests_quick.yml index c022dae..9f82d5c 100644 --- a/.github/workflows/tests_quick.yml +++ b/.github/workflows/tests_quick.yml @@ -88,14 +88,14 @@ jobs: test-jax: needs: build - runs-on: ubuntu-18.04 + runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 - - name: Set up Python 3.8 + - name: Set up Python 3.7 uses: actions/setup-python@v2 with: - python-version: 3.8 + python-version: 3.7 - name: Download artifact uses: actions/download-artifact@master