Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ Attributes:
- [ ] Support `end_time`
- [ ] Accept writing into a specific mapset (GRASS 8.5)
- [ ] Accept non homogeneous 3D resolution in NS and EW dimensions (GRASS 8.5)
- [ ] Lazy loading of all raster types
- [x] Lazy loading of STDS on the time dimension
- [ ] Properly test with lat-lon location

### Stretch goals
Expand Down
4 changes: 3 additions & 1 deletion src/xarray_grass/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from xarray_grass.grass_interface import GrassConfig as GrassConfig
from xarray_grass.grass_interface import GrassInterface as GrassInterface
from xarray_grass.xarray_grass import GrassBackendEntrypoint as GrassBackendEntrypoint
from xarray_grass.xarray_grass import GrassBackendArray as GrassBackendArray
from xarray_grass.grass_backend_array import (
GrassSTDSBackendArray as GrassSTDSBackendArray,
)
from xarray_grass.to_grass import to_grass as to_grass
from xarray_grass.coord_utils import RegionData as RegionData

Expand Down
97 changes: 97 additions & 0 deletions src/xarray_grass/grass_backend_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""
Copyright (C) 2025 Laurent Courty

This program is free software; you can redistribute it and/or
modify it under the terms of the GNU General Public License
as published by the Free Software Foundation; either version 2
of the License, or (at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
"""

from __future__ import annotations
from typing import TYPE_CHECKING
import threading

import numpy as np
import xarray as xr

from xarray.backends import BackendArray

if TYPE_CHECKING:
from xarray_grass.grass_interface import GrassInterface


class GrassSTDSBackendArray(BackendArray):
"""Lazy loading of grass Space-Time DataSets (multiple maps in time series)"""

def __init__(
self,
shape,
dtype,
map_list: list, # List of map metadata objects
map_type: str,
grass_interface: GrassInterface,
):
self.shape = shape
self.dtype = dtype
self._lock = threading.Lock()
self.map_list = map_list # List with .id attribute
self.map_type = map_type # "raster" or "raster3d"
self.grass_interface = grass_interface
self._cached_maps = {} # Cache loaded maps by index

def __getitem__(self, key: xr.core.indexing.ExplicitIndexer) -> np.typing.ArrayLike:
"""takes in input an index and returns a NumPy array"""
return xr.core.indexing.explicit_indexing_adapter(
key,
self.shape,
xr.core.indexing.IndexingSupport.BASIC,
self._raw_indexing_method,
)

def _raw_indexing_method(self, key: tuple):
"""Load only the maps needed for the requested slice"""
with self._lock:
# key is a tuple of slices/indices for each dimension
# First dimension is time
time_key = key[0] if key else slice(None)
spatial_key = key[1:] if len(key) > 1 else ()

# Determine which time indices are needed
if isinstance(time_key, slice):
time_indices = range(*time_key.indices(self.shape[0]))
elif isinstance(time_key, int):
time_indices = [time_key]
else:
time_indices = list(time_key)

# Load only the needed maps
result_list = []
for t_idx in time_indices:
if t_idx not in self._cached_maps:
map_data = self.map_list[t_idx]
if self.map_type == "raster":
self._cached_maps[t_idx] = self.grass_interface.read_raster_map(
map_data.id
)
else: # 'raster3d'
self._cached_maps[t_idx] = (
self.grass_interface.read_raster3d_map(map_data.id)
)

# Apply spatial indexing
if spatial_key:
result_list.append(self._cached_maps[t_idx][spatial_key])
else:
result_list.append(self._cached_maps[t_idx])

# Stack results along time dimension
if len(result_list) == 1 and isinstance(time_key, int):
# Single time slice requested as integer index
return result_list[0]
else:
return np.stack(result_list, axis=0)
50 changes: 38 additions & 12 deletions src/xarray_grass/grass_interface.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# coding=utf8
"""
Copyright (C) 2025 Laurent Courty

Expand Down Expand Up @@ -29,6 +28,7 @@
from grass.script import array as garray
import grass.pygrass.utils as gutils
from grass.pygrass import raster as graster
from grass.pygrass.raster.abstract import Info, RasterAbstractBase
import grass.temporal as tgis

from xarray_grass.coord_utils import (
Expand All @@ -48,7 +48,8 @@ class GrassConfig:


strds_cols = ["id", "start_time", "end_time"]
MapData = namedtuple("MapData", strds_cols)
MapData = namedtuple("MapData", strds_cols + ["dtype"])

strds_infos = [
"id",
"title",
Expand Down Expand Up @@ -229,11 +230,11 @@ def name_is_str3ds(self, name: str) -> bool:
return bool(tgis.SpaceTimeRaster3DDataset(str3ds_id).is_in_db())

def name_is_raster(self, raster_name: str) -> bool:
"""return True if the given name is a map in the grass database
False if not
"""
"""return True if the given name is a raster map in the grass database."""
# Using pygrass instead of gscript is at least 40x faster
map_id = self.get_id_from_name(raster_name)
return bool(gs.find_file(name=map_id, element="raster").get("file"))
map_object = RasterAbstractBase(map_id)
return map_object.exist()

def name_is_raster_3d(self, raster3d_name: str) -> bool:
"""return True if the given name is a 3D raster in the grass database."""
Expand All @@ -245,16 +246,30 @@ def get_crs_wkt_str() -> str:
return gs.read_command("g.proj", flags="wf").replace("\n", "")

def grass_dtype(self, dtype: str) -> str:
"""Takes a numpy-style data-type description string,
and return a GRASS data type string."""
if dtype in self.dtype_conv["DCELL"]:
mtype = "DCELL"
elif dtype in self.dtype_conv["CELL"]:
mtype = "CELL"
elif dtype in self.dtype_conv["FCELL"]:
mtype = "FCELL"
else:
raise ValueError("datatype incompatible with GRASS!")
raise ValueError(f"datatype '{dtype}' incompatible with GRASS!")
return mtype

@staticmethod
def numpy_dtype(mtype: str) -> np.dtype:
if mtype == "CELL":
dtype = np.dtype("int64")
elif mtype == "FCELL":
dtype = np.dtype("float32")
elif mtype == "DCELL":
dtype = np.dtype("float64")
else:
raise ValueError(f"Unknown GRASS data type: {mtype}")
return dtype

@staticmethod
def has_mask() -> bool:
"""Return True if the mapset has a mask, False otherwise."""
Expand Down Expand Up @@ -293,12 +308,14 @@ def list_grass_objects(self, mapset: str = None) -> dict[list[str]]:
return objects_dict

@staticmethod
def get_raster_info(raster_id):
return gs.parse_command("r.info", map=raster_id, flags="e")
def get_raster_info(raster_id: str) -> Info:
return gs.parse_command("r.info", map=raster_id, flags="ge")

@staticmethod
def get_raster3d_info(raster3d_id):
return gs.parse_command("r3.info", map=raster3d_id, flags="gh")
result = gs.parse_command("r3.info", map=raster3d_id, flags="gh")
# Strip quotes from string values (r3.info -gh returns quoted strings)
return {k: v.strip('"') if isinstance(v, str) else v for k, v in result.items()}

def get_stds_infos(self, strds_name, stds_type) -> STRDSInfos:
strds_id = self.get_id_from_name(strds_name)
Expand Down Expand Up @@ -345,7 +362,12 @@ def list_maps_in_str3ds(self, strds_name: str) -> list[MapData]:
err_msg = "STR3DS <{}>: Can't find following maps: {}"
str_lst = ",".join(maps_not_found)
raise RuntimeError(err_msg.format(strds_name, str_lst))
return [MapData(*i) for i in maplist]
tuple_list = []
for i in maplist:
mtype = self.get_raster3d_info(i[0])["datatype"]
dtype = self.numpy_dtype(mtype)
tuple_list.append(MapData(*i, dtype=dtype))
return tuple_list

def list_maps_in_strds(self, strds_name: str) -> list[MapData]:
strds = tgis.open_stds.open_old_stds(strds_name, "strds")
Expand All @@ -358,7 +380,11 @@ def list_maps_in_strds(self, strds_name: str) -> list[MapData]:
err_msg = "STRDS <{}>: Can't find following maps: {}"
str_lst = ",".join(maps_not_found)
raise RuntimeError(err_msg.format(strds_name, str_lst))
return [MapData(*i) for i in maplist]
tuple_list = []
for i in maplist:
dtype = self.numpy_dtype(Info(i[0]).mtype)
tuple_list.append(MapData(*i, dtype=dtype))
return tuple_list

@staticmethod
def read_raster_map(rast_name: str) -> np.ndarray:
Expand Down
1 change: 0 additions & 1 deletion src/xarray_grass/to_grass.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,6 @@ def _datarray_to_grass(
current_region = self.grass_interface.get_region()
temp_region = get_region_from_xarray(data, dims)
self.grass_interface.set_region(temp_region)
# TODO: reshape to match userGRASS expected dims order
try:
if is_raster:
data = self.transpose(data, dims, arr_type="raster")
Expand Down
Loading