Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions properties/test_coordinate_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""Property tests comparing CoordinateTransformIndex to PandasIndex."""

from collections.abc import Hashable
from typing import Any

import numpy as np
import pytest

pytest.importorskip("hypothesis")

import hypothesis.strategies as st
from hypothesis import given

import xarray as xr
import xarray.testing.strategies as xrst
from xarray.core.coordinate_transform import CoordinateTransform
from xarray.core.indexes import CoordinateTransformIndex
from xarray.testing import assert_identical

DATA_VAR_NAME = "_test_data_"


class IdentityTransform(CoordinateTransform):
    """Identity transform: each coordinate label equals its dimension position.

    Coordinates and dimensions are paired positionally, i.e. the i-th entry
    of ``self.coord_names`` labels the i-th entry of ``self.dims``.  Lookups
    are therefore always made with the key type that the incoming mapping
    actually uses (dimension names for positions, coordinate names for
    labels), so the transform does not silently rely on coordinate and
    dimension names being spelled the same.
    """

    def forward(self, dim_positions: dict[str, Any]) -> dict[Hashable, Any]:
        """Convert positions keyed by dimension name to labels keyed by coordinate name.

        Raises ``ValueError`` (from ``zip(..., strict=True)``) when the number
        of coordinates differs from the number of dimensions, because an
        identity mapping needs exactly one coordinate per dimension.
        """
        return {
            name: dim_positions[dim]
            for name, dim in zip(self.coord_names, self.dims, strict=True)
        }

    def reverse(self, coord_labels: dict[Hashable, Any]) -> dict[str, Any]:
        """Convert labels keyed by coordinate name to positions keyed by dimension name.

        Mirror image of :meth:`forward`; the same one-to-one pairing applies.
        """
        return {
            dim: coord_labels[name]
            for dim, name in zip(self.dims, self.coord_names, strict=True)
        }

    def equals(
        self, other: CoordinateTransform, exclude: frozenset[Hashable] | None = None
    ) -> bool:
        """Return True when ``other`` is an IdentityTransform with the same ``dim_size``.

        ``exclude`` is accepted for signature compatibility but is not
        consulted by this comparison.
        """
        if not isinstance(other, IdentityTransform):
            return False
        return self.dim_size == other.dim_size


def create_transform_da(sizes: dict[str, int]) -> xr.DataArray:
    """Build a DataArray whose every dimension is indexed by an IdentityTransform.

    The data are the integers ``0 .. prod(shape) - 1`` laid out in the shape
    given by ``sizes``; one ``CoordinateTransformIndex`` is attached per
    dimension.
    """
    dim_names = list(sizes)
    shape = tuple(sizes[name] for name in dim_names)
    values = np.arange(np.prod(shape)).reshape(shape)

    dataset = xr.Dataset({DATA_VAR_NAME: (dim_names, values)})

    # Attach one single-dimension transform index at a time; each
    # ``assign_coords`` call returns a new Dataset that replaces the previous.
    # NOTE(review): a reviewer suggested collecting the Coordinates objects and
    # merging them with ``functools.reduce(operator.or_, coordinates)`` instead
    # of this emulated in-place assignment -- not applied here.
    for name, length in sizes.items():
        index = CoordinateTransformIndex(
            IdentityTransform([name], {name: length}, dtype=np.dtype(np.int64))
        )
        dataset = dataset.assign_coords(xr.Coordinates.from_xindex(index))

    return dataset[DATA_VAR_NAME]


def create_pandas_da(sizes: dict[str, int]) -> xr.DataArray:
    """Build the reference DataArray: same data, plain integer range coordinates.

    Each dimension gets a coordinate ``0 .. size - 1`` passed as a regular
    array, so xarray builds its default index for it.
    """
    dim_names = list(sizes)
    shape = tuple(sizes[name] for name in dim_names)
    values = np.arange(np.prod(shape)).reshape(shape)
    range_coords = {name: np.arange(length) for name, length in sizes.items()}
    return xr.DataArray(
        values, dims=dim_names, coords=range_coords, name=DATA_VAR_NAME
    )


@given(
    st.data(),
    xrst.dimension_sizes(min_dims=1, max_dims=3, min_side=1, max_side=5),
)
def test_basic_indexing(data, sizes):
    """Basic indexers must select identically from transform- and pandas-indexed arrays."""
    reference = create_pandas_da(sizes)
    candidate = create_transform_da(sizes)
    # Draw a single indexer and apply it to both arrays so they are compared
    # on exactly the same selection.
    indexer = data.draw(xrst.basic_indexers(sizes=sizes))
    assert_identical(reference.isel(indexer), candidate.isel(indexer))


@given(
    st.data(),
    xrst.dimension_sizes(min_dims=1, max_dims=3, min_side=1, max_side=5),
)
def test_outer_indexing(data, sizes):
    """Outer array indexers must select identically from transform- and pandas-indexed arrays."""
    reference = create_pandas_da(sizes)
    candidate = create_transform_da(sizes)
    # ``min_dims=1`` is forwarded to the strategy unchanged.
    indexer = data.draw(xrst.outer_array_indexers(sizes=sizes, min_dims=1))
    assert_identical(reference.isel(indexer), candidate.isel(indexer))


@given(
    st.data(),
    xrst.dimension_sizes(min_dims=2, max_dims=3, min_side=1, max_side=5),
)
def test_vectorized_indexing(data, sizes):
    """Vectorized indexers must select identically from transform- and pandas-indexed arrays."""
    reference = create_pandas_da(sizes)
    candidate = create_transform_da(sizes)
    # The same drawn indexer is applied to both arrays.
    indexer = data.draw(xrst.vectorized_indexers(sizes=sizes))
    assert_identical(reference.isel(indexer), candidate.isel(indexer))
66 changes: 66 additions & 0 deletions properties/test_indexing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import pytest

pytest.importorskip("hypothesis")

import hypothesis.strategies as st
from hypothesis import given

import xarray as xr
import xarray.testing.strategies as xrst


def _slice_size(s: slice, dim_size: int) -> int:
"""Compute the size of a slice applied to a dimension."""
return len(range(*s.indices(dim_size)))


@given(
    st.data(),
    xrst.variables(dims=xrst.dimension_sizes(min_dims=1, max_dims=4, min_side=1)),
)
def test_basic_indexing(data, var):
    """Check the output shape produced by basic indexers."""
    indexer = data.draw(xrst.basic_indexers(sizes=var.sizes))
    selected = var.isel(indexer)

    # Walk the dims that survive in the result: an indexed dim shrinks to the
    # slice length, an untouched dim keeps its original length.
    expected = []
    for dim in selected.dims:
        if dim in indexer:
            expected.append(_slice_size(indexer[dim], var.sizes[dim]))
        else:
            expected.append(var.sizes[dim])

    assert selected.shape == tuple(expected)


@given(
    st.data(),
    xrst.variables(dims=xrst.dimension_sizes(min_dims=1, max_dims=4, min_side=1)),
)
def test_outer_indexing(data, var):
    """Check the output shape produced by outer array indexers."""
    indexer = data.draw(xrst.outer_array_indexers(sizes=var.sizes, min_dims=1))
    selected = var.isel(indexer)

    # An indexed dim takes the length of its index array; any other dim keeps
    # its original length.
    expected = []
    for dim in selected.dims:
        if dim in indexer:
            expected.append(len(indexer[dim]))
        else:
            expected.append(var.sizes[dim])

    assert selected.shape == tuple(expected)


@given(
    st.data(),
    xrst.variables(dims=xrst.dimension_sizes(min_dims=2, max_dims=4, min_side=1)),
)
def test_vectorized_indexing(data, var):
    """Check the output shape produced by vectorized indexers."""
    da = xr.DataArray(var)
    indexer = data.draw(xrst.vectorized_indexers(sizes=var.sizes))
    selected = da.isel(indexer)

    # TODO: this logic works because the dims in idxr don't overlap with da.dims
    # Broadcast the indexer arrays against each other; the first broadcast
    # result supplies the size of every dimension the indexers introduce.
    broadcasted = xr.broadcast(*indexer.values())
    first = broadcasted[0]
    broadcast_sizes = dict(zip(first.dims, first.shape, strict=True))

    # Dims already on the variable keep their original size; the remaining
    # result dims take the broadcast size.
    expected_shape = tuple(
        broadcast_sizes[dim] if dim not in var.sizes else var.sizes[dim]
        for dim in selected.dims
    )
    assert selected.shape == expected_shape
8 changes: 6 additions & 2 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2181,7 +2181,7 @@ def _oindex_get(self, indexer: OuterIndexer):
dim_positions = dict(zip(self._dims, positions, strict=False))

result = self._transform.forward(dim_positions)
return np.asarray(result[self._coord_name]).squeeze()
return np.asarray(result[self._coord_name])

def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None:
raise TypeError(
Expand Down Expand Up @@ -2215,7 +2215,11 @@ def __getitem__(self, indexer: ExplicitIndexer):
self._check_and_raise_if_non_basic_indexer(indexer)

# also works with basic indexing
return self._oindex_get(OuterIndexer(indexer.tuple))
res = self._oindex_get(OuterIndexer(indexer.tuple))
squeeze_axes = tuple(
ax for ax, idxr in enumerate(indexer.tuple) if isinstance(idxr, int)
)
return res.squeeze(squeeze_axes) if squeeze_axes else res

def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None:
raise TypeError(
Expand Down
Loading
Loading