Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
4ef845a
Add staggered fields prototype (works on gtfn)
SF-N Oct 23, 2025
fea7480
Merge remote-tracking branch 'origin/main' into staggered_fields
tehrengruber Nov 10, 2025
fd5266a
Merge remote-tracking branch 'origin/main' into staggered_fields
tehrengruber Nov 13, 2025
80c4149
No debug
tehrengruber Nov 24, 2025
c406c5e
Merge remote-tracking branch 'origin/main' into staggered_fields
tehrengruber Nov 27, 2025
ac1c761
Merge remote-tracking branch 'origin/main' into staggered_fields
tehrengruber Nov 27, 2025
491d1c2
Use cartesian shift syntax in all tests
tehrengruber Dec 4, 2025
b69e966
Remove default offset provider and offsets from default tests
tehrengruber Dec 4, 2025
2f28810
Flip staggered dims in domains and CartesianOffsets
SF-N Dec 9, 2025
8e59bbf
Update staggered tests
SF-N Dec 9, 2025
b691a70
Update staggered tests
SF-N Dec 9, 2025
25b5417
Fix connectivity_for_cartesian_shift
SF-N Dec 12, 2025
d03d1c0
Merge remote-tracking branch 'origin_sf_n/staggered_fields' into stag…
tehrengruber Dec 17, 2025
0f2cb33
Merge remote-tracking branch 'origin/main' into staggered_fields
tehrengruber Jan 9, 2026
c853508
Merge origin_tehrengruber/plus_minus_cart_shift
tehrengruber Jan 9, 2026
726dadf
Fix nanobind segfault
tehrengruber Jan 10, 2026
39e9c37
Cleanup
tehrengruber Jan 10, 2026
d53b4b7
Merge branch 'fix_nanobind_segfault' into staggered_fields
tehrengruber Jan 12, 2026
89c5287
Cleanup
tehrengruber Jan 13, 2026
c6452b0
Small fix in tests
tehrengruber Jan 28, 2026
7b131fc
Cleanup & small fix
tehrengruber Jan 28, 2026
0e0db05
Fix dace tests
tehrengruber Jan 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
302 changes: 173 additions & 129 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,14 +910,22 @@ def remapping(cls) -> ConnectivityKind:
class ConnectivityType: # TODO(havogt): would better live in type_specifications but would have to solve a circular import
domain: tuple[Dimension, ...]
codomain: Dimension
skip_value: Optional[core_defs.IntegralScalar]
skip_value: Optional[
core_defs.IntegralScalar
] # TODO(tehrengruber): isn't this a value of the `NeighborConnectivityType` only
dtype: core_defs.DType

@property
def has_skip_values(self) -> bool:
return self.skip_value is not None


@dataclasses.dataclass(frozen=True)
class CartesianConnectivityType(ConnectivityType):
domain: tuple[Dimension]
offset: int


@dataclasses.dataclass(frozen=True)
class NeighborConnectivityType(ConnectivityType):
# TODO(havogt): refactor towards encoding this information in the local dimensions of the ConnectivityType.domain
Expand All @@ -932,8 +940,7 @@ def neighbor_dim(self) -> Dimension:
return self.domain[1]


@runtime_checkable
class Connectivity(Field[DimsT, core_defs.IntegralScalar], Protocol[DimsT, DimT_co]):
class Connectivity(Field[DimsT, core_defs.IntegralScalar], Generic[DimsT, DimT_co]):
@property
@abc.abstractmethod
def codomain(self) -> DimT_co:
Expand All @@ -947,22 +954,8 @@ def codomain(self) -> DimT_co:
Currently, this would just complicate implementation as we do not use this information.
"""

def __gt_type__(self) -> ConnectivityType:
if is_neighbor_connectivity(self):
return NeighborConnectivityType(
domain=self.domain.dims,
codomain=self.codomain,
dtype=self.dtype,
skip_value=self.skip_value,
max_neighbors=self.ndarray.shape[1],
)
else:
return ConnectivityType(
domain=self.domain.dims,
codomain=self.codomain,
dtype=self.dtype,
skip_value=self.skip_value,
)
@abc.abstractmethod
def __gt_type__(self) -> ConnectivityType: ...

@property
def kind(self) -> ConnectivityKind:
Expand Down Expand Up @@ -1034,6 +1027,115 @@ def __xor__(self, other: Field | core_defs.IntegralScalar) -> Never:
raise TypeError("'Connectivity' does not support this operation.")


DomainDimT = TypeVar("DomainDimT", bound="Dimension")


@dataclasses.dataclass(frozen=True, eq=False)
class CartesianConnectivity(Connectivity[Dims[DomainDimT], DimT]):
domain_dim: DomainDimT
codomain: DimT
offset: int = 0

def __init__(
self, domain_dim: DomainDimT, offset: int = 0, *, codomain: Optional[DimT] = None
) -> None:
object.__setattr__(self, "domain_dim", domain_dim)
object.__setattr__(self, "codomain", codomain if codomain is not None else domain_dim)
object.__setattr__(self, "offset", offset)

@classmethod
def __gt_builtin_func__(cls, _: fbuiltins.BuiltInFunction) -> Never: # type: ignore[override]
raise NotImplementedError()

@property
def ndarray(self) -> Never:
raise NotImplementedError()

def asnumpy(self) -> Never:
raise NotImplementedError()

def as_scalar(self) -> Never:
raise NotImplementedError()

@functools.cached_property
def domain(self) -> Domain:
return Domain(dims=(self.domain_dim,), ranges=(UnitRange.infinite(),))

@property
def __gt_origin__(self) -> Never:
raise TypeError("'CartesianConnectivity' does not support this operation.")

def __gt_type__(self) -> CartesianConnectivityType:
assert len(self.domain.dims) == 1
return CartesianConnectivityType(
domain=self.domain.dims,
codomain=self.codomain,
dtype=self.dtype,
skip_value=self.skip_value,
offset=self.offset,
)

@property
def dtype(self) -> core_defs.DType[core_defs.IntegralScalar]:
return core_defs.Int32DType() # type: ignore[return-value]

# This is a workaround to make this class concrete, since `codomain` is an
# abstract property of the `Connectivity` Protocol.
if not TYPE_CHECKING:

@functools.cached_property
def codomain(self) -> DimT:
raise RuntimeError("This property should be always set in the constructor.")

@property
def skip_value(self) -> None:
return None

@functools.cached_property
def kind(self) -> ConnectivityKind:
return (
ConnectivityKind.translation()
if self.domain_dim == self.codomain
else ConnectivityKind.relocation()
)

@classmethod
def for_translation(
cls, dimension: DomainDimT, offset: int
) -> CartesianConnectivity[DomainDimT, DomainDimT]:
return cast(CartesianConnectivity[DomainDimT, DomainDimT], cls(dimension, offset))

@classmethod
def for_relocation(cls, old: DimT, new: DomainDimT) -> CartesianConnectivity[DomainDimT, DimT]:
return cls(new, codomain=old)

def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRange]:
if not isinstance(image_range, UnitRange):
if image_range.dim != self.codomain:
raise ValueError(
f"Dimension '{image_range.dim}' does not match the codomain dimension '{self.codomain}'."
)

image_range = image_range.unit_range

assert isinstance(image_range, UnitRange)
return (named_range((self.domain_dim, image_range - self.offset)),)

def premap(
self,
index_field: Connectivity | fbuiltins.FieldOffset,
*args: Connectivity | fbuiltins.FieldOffset,
) -> Connectivity:
raise NotImplementedError()

__call__ = premap

def restrict(self, index: AnyIndexSpec) -> Never:
raise NotImplementedError() # we could possibly implement with a FunctionField, but we don't have a use-case

__getitem__ = restrict


# Utility function to construct a `Field` from different buffer representations.
# Consider removing this function and using `Field` constructor directly. See also `_connectivity`.
@functools.singledispatch
Expand Down Expand Up @@ -1061,12 +1163,19 @@ def _connectivity(
raise NotImplementedError


class NeighborConnectivity(Connectivity, Protocol):
# TODO(havogt): work towards encoding this properly in the type
def __gt_type__(self) -> NeighborConnectivityType: ...
class NeighborConnectivity(Connectivity[DimsT, DimT_co]):
def __gt_type__(self) -> NeighborConnectivityType:
return NeighborConnectivityType(
domain=self.domain.dims,
codomain=self.codomain,
dtype=self.dtype,
skip_value=self.skip_value,
max_neighbors=self.ndarray.shape[1],
)


def is_neighbor_connectivity(obj: Any) -> TypeGuard[NeighborConnectivity]:
# TODO: reevaluate
if not isinstance(obj, Connectivity):
return False
domain_dims = obj.domain.dims
Expand All @@ -1078,7 +1187,7 @@ def is_neighbor_connectivity(obj: Any) -> TypeGuard[NeighborConnectivity]:


class NeighborTable(
NeighborConnectivity, Protocol
NeighborConnectivity
): # TODO(havogt): try to express by inheriting from NdArrayConnectivityField (but this would require a protocol to move it out of `embedded.nd_array_field`)
@property
def ndarray(self) -> core_defs.NDArrayObject:
Expand All @@ -1088,12 +1197,15 @@ def ndarray(self) -> core_defs.NDArrayObject:
...


def is_neighbor_table(obj: Any) -> TypeGuard[NeighborTable]:
return is_neighbor_connectivity(obj) and hasattr(obj, "ndarray")
# TODO: delete. A protocol and duck typing in it's current form is not enough since we use the
# type of the connectivity to propagate structural information, e.g. that we have a cartesian
# and not a neighbor connectivity. We would need to extend the protocol for this
# def is_neighbor_table(obj: Any) -> TypeGuard[NeighborTable]:
# return is_neighbor_connectivity(obj) and hasattr(obj, "ndarray")


OffsetProviderElem: TypeAlias = Dimension | NeighborConnectivity
OffsetProviderTypeElem: TypeAlias = Dimension | NeighborConnectivityType
OffsetProviderElem: TypeAlias = CartesianConnectivity | NeighborConnectivity
OffsetProviderTypeElem: TypeAlias = CartesianConnectivityType | NeighborConnectivityType
# Note: `OffsetProvider` and `OffsetProviderType` should not be accessed directly,
# use the `get_offset` and `get_offset_type` functions instead.
OffsetProvider: TypeAlias = Mapping[Tag, OffsetProviderElem]
Expand Down Expand Up @@ -1133,8 +1245,6 @@ def get_offset(offset_provider: OffsetProvider, offset_tag: str) -> OffsetProvid
`OffsetProviderType` should go through this function.
"""
# TODO(havogt): Once we have a custom class for `OffsetProvider`, we can absorb this functionality into it.
if offset_tag.startswith(_IMPLICIT_OFFSET_PREFIX):
return Dimension(value=_get_dimension_name_from_implicit_offset(offset_tag))
if offset_tag not in offset_provider:
raise KeyError(f"Offset '{offset_tag}' not found in offset provider.")
return offset_provider[offset_tag] # TODO return a valid dimension
Expand Down Expand Up @@ -1165,105 +1275,6 @@ def hash_offset_provider_items_by_id(offset_provider: OffsetProvider) -> int:
return hash(tuple((k, id(v)) for k, v in offset_provider.items()))


DomainDimT = TypeVar("DomainDimT", bound="Dimension")


@dataclasses.dataclass(frozen=True, eq=False)
class CartesianConnectivity(Connectivity[Dims[DomainDimT], DimT]):
domain_dim: DomainDimT
codomain: DimT
offset: int = 0

def __init__(
self, domain_dim: DomainDimT, offset: int = 0, *, codomain: Optional[DimT] = None
) -> None:
object.__setattr__(self, "domain_dim", domain_dim)
object.__setattr__(self, "codomain", codomain if codomain is not None else domain_dim)
object.__setattr__(self, "offset", offset)

@classmethod
def __gt_builtin_func__(cls, _: fbuiltins.BuiltInFunction) -> Never: # type: ignore[override]
raise NotImplementedError()

@property
def ndarray(self) -> Never:
raise NotImplementedError()

def asnumpy(self) -> Never:
raise NotImplementedError()

def as_scalar(self) -> Never:
raise NotImplementedError()

@functools.cached_property
def domain(self) -> Domain:
return Domain(dims=(self.domain_dim,), ranges=(UnitRange.infinite(),))

@property
def __gt_origin__(self) -> Never:
raise TypeError("'CartesianConnectivity' does not support this operation.")

@property
def dtype(self) -> core_defs.DType[core_defs.IntegralScalar]:
return core_defs.Int32DType() # type: ignore[return-value]

# This is a workaround to make this class concrete, since `codomain` is an
# abstract property of the `Connectivity` Protocol.
if not TYPE_CHECKING:

@functools.cached_property
def codomain(self) -> DimT:
raise RuntimeError("This property should be always set in the constructor.")

@property
def skip_value(self) -> None:
return None

@functools.cached_property
def kind(self) -> ConnectivityKind:
return (
ConnectivityKind.translation()
if self.domain_dim == self.codomain
else ConnectivityKind.relocation()
)

@classmethod
def for_translation(
cls, dimension: DomainDimT, offset: int
) -> CartesianConnectivity[DomainDimT, DomainDimT]:
return cast(CartesianConnectivity[DomainDimT, DomainDimT], cls(dimension, offset))

@classmethod
def for_relocation(cls, old: DimT, new: DomainDimT) -> CartesianConnectivity[DomainDimT, DimT]:
return cls(new, codomain=old)

def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRange]:
if not isinstance(image_range, UnitRange):
if image_range.dim != self.codomain:
raise ValueError(
f"Dimension '{image_range.dim}' does not match the codomain dimension '{self.codomain}'."
)

image_range = image_range.unit_range

assert isinstance(image_range, UnitRange)
return (named_range((self.domain_dim, image_range - self.offset)),)

def premap(
self,
index_field: Connectivity | fbuiltins.FieldOffset,
*args: Connectivity | fbuiltins.FieldOffset,
) -> Connectivity:
raise NotImplementedError()

__call__ = premap

def restrict(self, index: AnyIndexSpec) -> Never:
raise NotImplementedError() # we could possibly implement with a FunctionField, but we don't have a use-case

__getitem__ = restrict


@enum.unique
class GridType(StrEnum):
CARTESIAN = "cartesian"
Expand All @@ -1274,7 +1285,13 @@ def order_dimensions(dims: Iterable[Dimension]) -> list[Dimension]:
"""Find the canonical ordering of the dimensions in `dims`."""
if sum(1 for dim in dims if dim.kind == DimensionKind.LOCAL) > 1:
raise ValueError("There are more than one dimension with DimensionKind 'LOCAL'.")
return sorted(dims, key=lambda dim: (_DIM_KIND_ORDER[dim.kind], dim.value))
return sorted(
dims,
key=lambda dim: (
_DIM_KIND_ORDER[dim.kind],
flip_staggered(dim).value if is_staggered(dim) else dim.value,
),
)


def check_dims(dims: Sequence[Dimension]) -> None:
Expand Down Expand Up @@ -1362,3 +1379,30 @@ def __gt_builtin_func__(cls, /, func: fbuiltins.BuiltInFunction[_R, _P]) -> Call
#: Equivalent to the `_FillValue` attribute in the UGRID Conventions
#: (see: http://ugrid-conventions.github.io/ugrid-conventions/).
_DEFAULT_SKIP_VALUE: Final[int] = -1
_STAGGERED_PREFIX = "_Staggered"


def is_staggered(dim: Dimension) -> bool:
return dim.value.startswith(_STAGGERED_PREFIX)


def flip_staggered(dim: Dimension) -> Dimension:
if is_staggered(dim):
return Dimension(dim.value[len(_STAGGERED_PREFIX) :], dim.kind)
else:
return Dimension(f"{_STAGGERED_PREFIX}{dim.value}", dim.kind)

def as_non_staggered(dim: Dimension) -> Dimension:
if is_staggered(dim):
return flip_staggered(dim)
return dim

def connectivity_for_cartesian_shift(dim: Dimension, offset: int | float) -> CartesianConnectivity:
if isinstance(offset, float):
integral_offset, half = divmod(offset, 1)
assert half == 0.5
if not dim.value.startswith(_STAGGERED_PREFIX):
integral_offset += 1
return CartesianConnectivity(dim, int(integral_offset), codomain=flip_staggered(dim))
else:
return CartesianConnectivity(dim, offset, codomain=dim)
Loading