From 3bbd81cbc79c0dda23647e50495b27ff440de046 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Fri, 20 Feb 2026 10:25:34 +0100 Subject: [PATCH 1/7] test: transforming points with multiple partitions --- tests/core/operations/test_transform.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/core/operations/test_transform.py b/tests/core/operations/test_transform.py index 1bb494fb..41b29722 100644 --- a/tests/core/operations/test_transform.py +++ b/tests/core/operations/test_transform.py @@ -586,6 +586,21 @@ def test_transform_elements_and_entire_spatial_data_object(full_sdata: SpatialDa _ = full_sdata.transform_to_coordinate_system("my_space", maintain_positioning=maintain_positioning) +def test_transform_points_with_multiple_partitions(full_sdata: SpatialData, tmp_path: str): + tmpdir = Path(tmp_path) / "tmp.zarr" + full_sdata["points_0"] = PointsModel.parse( + full_sdata["points_0"].repartition(npartitions=4), + transformations={"global": get_transformation(full_sdata["points_0"])}, + ) + + full_sdata.write(tmpdir) + + full_sdata = SpatialData.read(tmpdir) + + # This just needs to run without error + transform(full_sdata["points_0"], to_coordinate_system="global") + + @pytest.mark.parametrize("maintain_positioning", [True, False]) def test_transform_elements_and_entire_spatial_data_object_multi_hop( full_sdata: SpatialData, maintain_positioning: bool From f6496f366d672352dc08f107ac44ca11169fd235 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Fri, 20 Feb 2026 10:26:42 +0100 Subject: [PATCH 2/7] fix: transform points data with multiple partitions. 
--- src/spatialdata/__init__.py | 3 ++- src/spatialdata/_core/operations/transform.py | 13 +++++++++++-- src/spatialdata/_utils.py | 13 +++++++++++++ 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py index bb24f04e..22d8f491 100644 --- a/src/spatialdata/__init__.py +++ b/src/spatialdata/__init__.py @@ -9,6 +9,7 @@ "transformations", "datasets", "dataloader", + "disable_dask_tune_optimization", "concatenate", "rasterize", "rasterize_bins", @@ -72,5 +73,5 @@ from spatialdata._io._utils import get_dask_backing_files from spatialdata._io.format import SpatialDataFormatType from spatialdata._io.io_zarr import read_zarr -from spatialdata._utils import get_pyramid_levels, unpad_raster +from spatialdata._utils import disable_dask_tune_optimization, get_pyramid_levels, unpad_raster from spatialdata.config import settings diff --git a/src/spatialdata/_core/operations/transform.py b/src/spatialdata/_core/operations/transform.py index e821edcf..d29af709 100644 --- a/src/spatialdata/_core/operations/transform.py +++ b/src/spatialdata/_core/operations/transform.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextlib import itertools import warnings from functools import singledispatch @@ -17,6 +18,7 @@ from spatialdata._core.spatialdata import SpatialData from spatialdata._types import ArrayLike +from spatialdata._utils import disable_dask_tune_optimization from spatialdata.models import SpatialElement, get_axes_names, get_model from spatialdata.models._utils import DEFAULT_COORDINATE_SYSTEM, get_channel_names from spatialdata.transformations._utils import _get_scale, compute_coordinates, scale_radii @@ -439,8 +441,15 @@ def _( ) axes = get_axes_names(data) arrays = [] - for ax in axes: - arrays.append(data[ax].to_dask_array(lengths=True).reshape(-1, 1)) + + # Workaround to prevent partition collapse and missing dependency problem for now. 
+ with disable_dask_tune_optimization() if data.npartitions > 1 else contextlib.nullcontext(): + for ax in axes: + # TODO We have to pass on the lengths explicitly as automatic determination with dask graph optimization + # leads to collapse of the partitions. However this causes a missing dependency problem, which for now is + # prevented by setting the optimization to False when performing this operation. + arrays.append(data[ax].to_dask_array(lengths=[len(part) for part in data.partitions]).reshape(-1, 1)) + xdata = DataArray(da.concatenate(arrays, axis=1), coords={"points": range(len(data)), "dim": list(axes)}) xtransformed = transformation._transform_coordinates(xdata) transformed = data.drop(columns=list(axes)).copy() diff --git a/src/spatialdata/_utils.py b/src/spatialdata/_utils.py index 64dd7638..64fb43e7 100644 --- a/src/spatialdata/_utils.py +++ b/src/spatialdata/_utils.py @@ -2,6 +2,7 @@ import re import warnings from collections.abc import Callable, Generator +from contextlib import contextmanager from itertools import islice from typing import Any, TypeVar @@ -9,6 +10,7 @@ import pandas as pd from anndata import AnnData from dask import array as da +from dask import config from dask.array import Array as DaskArray from xarray import DataArray, Dataset, DataTree @@ -20,6 +22,17 @@ RT = TypeVar("RT") +@contextmanager +def disable_dask_tune_optimization() -> Generator[None, None, None]: + """Prevent dask graph optimization when performing operations on dask dataframes with npartitions > 1.""" + old_setting = config.config["optimization"]["tune"]["active"] + config.set({"optimization.tune.active": False}) + try: + yield + finally: + config.set({"optimization.tune.active": old_setting}) + + def _parse_list_into_array(array: list[Number] | ArrayLike) -> ArrayLike: if isinstance(array, list): array = np.array(array) From 164ebd23f5ee97c03ab3260713a9e6d0da49a502 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Fri, 20 Feb 2026 13:23:42 +0100 Subject: 
[PATCH 3/7] test: add test for dask_tune contextmanager --- tests/core/operations/test_transform.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/tests/core/operations/test_transform.py b/tests/core/operations/test_transform.py index 41b29722..a1c93c89 100644 --- a/tests/core/operations/test_transform.py +++ b/tests/core/operations/test_transform.py @@ -1,16 +1,18 @@ +import contextlib import math import tempfile from pathlib import Path import numpy as np import pytest +from dask import config from geopandas.testing import geom_almost_equals from xarray import DataArray, DataTree from spatialdata import transform from spatialdata._core.data_extent import are_extents_equal, get_extent from spatialdata._core.spatialdata import SpatialData -from spatialdata._utils import unpad_raster +from spatialdata._utils import disable_dask_tune_optimization, unpad_raster from spatialdata.models import Image2DModel, PointsModel, ShapesModel, get_axes_names from spatialdata.transformations.operations import ( align_elements_using_landmarks, @@ -601,6 +603,24 @@ def test_transform_points_with_multiple_partitions(full_sdata: SpatialData, tmp_ transform(full_sdata["points_0"], to_coordinate_system="global") +@pytest.mark.parametrize( + "tune,partition", + [ + (True, None), + (False, 4), + ], +) +def test_dask_tune_contextmanager(full_sdata: SpatialData, partition: int | None, tune: bool): + if partition: + full_sdata["points_0"] = PointsModel.parse( + full_sdata["points_0"].repartition(npartitions=4), + transformations={"global": get_transformation(full_sdata["points_0"])}, + ) + + with disable_dask_tune_optimization() if full_sdata["points_0"].npartitions > 1 else contextlib.nullcontext(): + assert config.config["optimization"]["tune"]["active"] is tune + + @pytest.mark.parametrize("maintain_positioning", [True, False]) def test_transform_elements_and_entire_spatial_data_object_multi_hop( full_sdata: SpatialData, maintain_positioning: bool 
From c980ecb5ffe03aa35575a4ea9eaaa5b0b8ac376b Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Fri, 20 Feb 2026 10:41:48 +0100 Subject: [PATCH 4/7] docs: add note explaining the workaround This is a note explaining the workaround in case people run into the partition collapse problem due to dask graph optimization. --- docs/index.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/docs/index.md b/docs/index.md index 73cd1b5e..96f61790 100644 --- a/docs/index.md +++ b/docs/index.md @@ -14,6 +14,22 @@ SpatialData is a data framework that comprises a FAIR storage format and a colle Please see our publication {cite}`marconatoSpatialDataOpenUniversal2024` for citation and to learn more. +:::{note} +With dask >= 2025.2.0, users can get an error as described in [#1064](https://github.com/scverse/spatialdata/issues/1064). While we tried implementing fixes in SpatialData, it can be that +users perform operations on the `Points` data themselves and get this error. In order to prevent it, users can use a context manager we created. + +```python +from spatialdata import disable_dask_tune_optimization +import contextlib +... + +with disable_dask_tune_optimization() if data.npartitions > 1 else contextlib.nullcontext(): + +``` + +This will disable dask graph optimization if the dataframe has more than 1 partition and otherwise keep it enabled. +::: + [//]: # "numfocus-fiscal-sponsor-attribution" spatialdata is part of the scverse® project ([website](https://scverse.org), [governance](https://scverse.org/about/roles)) and is fiscally sponsored by [NumFOCUS](https://numfocus.org/). From 66384c2e917f825c79682ef0130199441fe58223 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Fri, 20 Feb 2026 22:03:53 +0100 Subject: [PATCH 5/7] chore: push lower bound dask version to 2025.12.0 Reason for this is that dask v2025.2.0 does not allow for disabling graph optimization, nor does it keep partition sizes consistent. 
Turning optimization off was introduced in dask 2025.12.0. --- .github/workflows/test.yaml | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index efd31ef3..1635bdd2 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -19,7 +19,7 @@ jobs: fail-fast: false matrix: include: - - {os: windows-latest, python: "3.11", dask-version: "2025.2.0", name: "Dask 2025.2.0"} + - {os: windows-latest, python: "3.11", dask-version: "2025.12.0", name: "Dask 2025.12.0"} - {os: windows-latest, python: "3.13", dask-version: "latest", name: "Dask latest"} - {os: ubuntu-latest, python: "3.11", dask-version: "latest", name: "Dask latest"} - {os: ubuntu-latest, python: "3.13", dask-version: "latest", name: "Dask latest"} diff --git a/pyproject.toml b/pyproject.toml index fb06b861..b14cf0d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ "annsel>=0.1.2", "click", "dask-image", - "dask>=2025.2.0,<2026.1.2", + "dask>=2025.12.0,<2026.1.2", "distributed<2026.1.2", "datashader", "fsspec[s3,http]", From 673b39225ebde829acca30290af3d9ecd592c6e6 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Sat, 21 Feb 2026 13:23:26 +0100 Subject: [PATCH 6/7] docs: add dask issue --- docs/index.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/index.md b/docs/index.md index 96f61790..8e614fa0 100644 --- a/docs/index.md +++ b/docs/index.md @@ -27,7 +27,8 @@ with disable_dask_tune_optimization() if data.npartitions > 1 else contextlib.nu ``` -This will disable dask graph optimization if the dataframe has more than 1 partition and otherwise keep it enabled. +This will disable dask graph optimization if the dataframe has more than 1 partition and otherwise keep it enabled. This solves +the problem discussed in this [dask issue](https://github.com/dask/dask/issues/12193). We are looking into an upstream fix. 
::: [//]: # "numfocus-fiscal-sponsor-attribution" From abf3a40735463800a67fa0248e92dc31841f7792 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Sat, 21 Feb 2026 13:48:24 +0100 Subject: [PATCH 7/7] test: test compute can be performed --- tests/core/operations/test_transform.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/core/operations/test_transform.py b/tests/core/operations/test_transform.py index a1c93c89..c216703f 100644 --- a/tests/core/operations/test_transform.py +++ b/tests/core/operations/test_transform.py @@ -590,17 +590,22 @@ def test_transform_elements_and_entire_spatial_data_object(full_sdata: SpatialDa def test_transform_points_with_multiple_partitions(full_sdata: SpatialData, tmp_path: str): tmpdir = Path(tmp_path) / "tmp.zarr" + points_memory = full_sdata["points_0"].compute() full_sdata["points_0"] = PointsModel.parse( full_sdata["points_0"].repartition(npartitions=4), transformations={"global": get_transformation(full_sdata["points_0"])}, ) + assert points_memory.equals(full_sdata["points_0"].compute()) full_sdata.write(tmpdir) full_sdata = SpatialData.read(tmpdir) # This just needs to run without error - transform(full_sdata["points_0"], to_coordinate_system="global") + data = transform(full_sdata["points_0"], to_coordinate_system="global") + + # test that data still can be computed + data.compute() @pytest.mark.parametrize(