From bcd1f305175fb4effe900dba78030652c65d8a07 Mon Sep 17 00:00:00 2001 From: Laurent Courty Date: Sat, 1 Nov 2025 14:06:02 -0600 Subject: [PATCH 1/3] basic lazy loading. --- src/xarray_grass/grass_backend_array.py | 67 +++++++++++++++++++++++++ src/xarray_grass/grass_interface.py | 42 +++++++++++++--- src/xarray_grass/xarray_grass.py | 65 +++++++++++++----------- 3 files changed, 137 insertions(+), 37 deletions(-) create mode 100644 src/xarray_grass/grass_backend_array.py diff --git a/src/xarray_grass/grass_backend_array.py b/src/xarray_grass/grass_backend_array.py new file mode 100644 index 0000000..0110c84 --- /dev/null +++ b/src/xarray_grass/grass_backend_array.py @@ -0,0 +1,67 @@ +""" +Copyright (C) 2025 Laurent Courty + +This program is free software; you can redistribute it and/or +modify it under the terms of the GNU General Public License +as published by the Free Software Foundation; either version 2 +of the License, or (at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. +""" + +from __future__ import annotations +from typing import TYPE_CHECKING +import threading + +import numpy as np +import xarray as xr + +from xarray.backends import BackendArray + +if TYPE_CHECKING: + from xarray_grass.grass_interface import GrassInterface + + +class GrassBackendArray(BackendArray): + """Lazy loading of grass arrays""" + + def __init__( + self, + shape, + dtype, + # lock, + map_id: str, + map_type: str, + grass_interface: GrassInterface, + ): + self.shape = shape + self.dtype = dtype + self._lock = threading.Lock() + self.map_id = map_id + self.map_type = map_type # "raster" or "raster3d" + self.grass_interface = grass_interface + self._array: np.ndarray = None + + def __getitem__(self, key: xr.core.indexing.ExplicitIndexer) -> np.typing.ArrayLike: + """takes in input an index and returns a NumPy array""" + return xr.core.indexing.explicit_indexing_adapter( + key, + self.shape, + xr.core.indexing.IndexingSupport.BASIC, + self._raw_indexing_method, + ) + + def _raw_indexing_method(self, key: tuple): + with self._lock: + if self._array is None: + self._array = self._load_map() + return self._array[key] + + def _load_map(self): + if self.map_type == "raster": + return self.grass_interface.read_raster_map(self.map_id) + else: # 'raster3d' + return self.grass_interface.read_raster3d_map(self.map_id) diff --git a/src/xarray_grass/grass_interface.py b/src/xarray_grass/grass_interface.py index 27b5cdc..b1f005a 100644 --- a/src/xarray_grass/grass_interface.py +++ b/src/xarray_grass/grass_interface.py @@ -1,4 +1,3 @@ -# coding=utf8 """ Copyright (C) 2025 Laurent Courty @@ -29,6 +28,7 @@ from grass.script import array as garray import grass.pygrass.utils as gutils from grass.pygrass import raster as graster +from grass.pygrass.raster.abstract import Info import grass.temporal as tgis from xarray_grass.coord_utils import ( @@ -48,7 +48,8 @@ class GrassConfig: strds_cols = ["id", "start_time", "end_time"] -MapData = namedtuple("MapData", strds_cols) +MapData = namedtuple("MapData", strds_cols + ["dtype"]) + strds_infos = [ "id", "title", @@ -245,6 +246,8 @@ def get_crs_wkt_str() -> str: return gs.read_command("g.proj", flags="wf").replace("\n", "") def grass_dtype(self, dtype: str) -> str: + """Takes a numpy-style data-type description string, + and return a GRASS data type string.""" if dtype in self.dtype_conv["DCELL"]: mtype = "DCELL" elif dtype in self.dtype_conv["CELL"]: @@ -252,9 +255,21 @@ def grass_dtype(self, dtype: str) -> str: elif dtype in self.dtype_conv["FCELL"]: mtype = "FCELL" else: - raise ValueError("datatype incompatible with GRASS!") + raise ValueError(f"datatype '{dtype}' incompatible with GRASS!") return mtype + @staticmethod + def numpy_dtype(mtype: str) -> np.dtype: + if mtype == "CELL": + dtype = np.int64 + elif mtype == "FCELL": + dtype = np.float32 + elif mtype == "DCELL": + dtype = np.float64 + else: + raise ValueError(f"Unknown GRASS data type: {mtype}") + return dtype + @staticmethod def has_mask() -> bool: """Return True if the mapset has a mask, False otherwise.""" @@ -293,12 +308,14 @@ def list_grass_objects(self, mapset: str = None) -> dict[list[str]]: return objects_dict @staticmethod - def get_raster_info(raster_id): - return gs.parse_command("r.info", map=raster_id, flags="e") + def get_raster_info(raster_id: str) -> Info: + return gs.parse_command("r.info", map=raster_id, flags="ge") @staticmethod def get_raster3d_info(raster3d_id): - return gs.parse_command("r3.info", map=raster3d_id, flags="gh") + result = gs.parse_command("r3.info", map=raster3d_id, flags="gh") + # Strip quotes from string values (r3.info -gh returns quoted strings) + return {k: v.strip('"') if isinstance(v, str) else v for k, v in result.items()} def get_stds_infos(self, strds_name, stds_type) -> STRDSInfos: strds_id = self.get_id_from_name(strds_name) @@ -345,7 +362,12 @@ def list_maps_in_str3ds(self, strds_name: str) -> list[MapData]: err_msg = "STR3DS <{}>: Can't find following maps: {}" str_lst = ",".join(maps_not_found) raise RuntimeError(err_msg.format(strds_name, str_lst)) - return [MapData(*i) for i in maplist] + tuple_list = [] + for i in maplist: + mtype = self.get_raster3d_info(i[0])["datatype"] + dtype = self.numpy_dtype(mtype) + tuple_list.append(MapData(*i, dtype=dtype)) + return tuple_list def list_maps_in_strds(self, strds_name: str) -> list[MapData]: strds = tgis.open_stds.open_old_stds(strds_name, "strds") @@ -358,7 +380,11 @@ def list_maps_in_strds(self, strds_name: str) -> list[MapData]: err_msg = "STRDS <{}>: Can't find following maps: {}" str_lst = ",".join(maps_not_found) raise RuntimeError(err_msg.format(strds_name, str_lst)) - return [MapData(*i) for i in maplist] + tuple_list = [] + for i in maplist: + dtype = self.numpy_dtype(Info(i[0]).mtype) + tuple_list.append(MapData(*i, dtype=dtype)) + return tuple_list @staticmethod def read_raster_map(rast_name: str) -> np.ndarray: diff --git a/src/xarray_grass/xarray_grass.py b/src/xarray_grass/xarray_grass.py index 3b148f7..be181db 100644 --- a/src/xarray_grass/xarray_grass.py +++ b/src/xarray_grass/xarray_grass.py @@ -1,4 +1,3 @@ -# coding=utf8 """ Copyright (C) 2025 Laurent Courty @@ -20,11 +19,12 @@ import numpy as np from xarray.backends import BackendEntrypoint -from xarray.backends import BackendArray import xarray as xr import grass_session # noqa: F401 + import xarray_grass from xarray_grass.grass_interface import GrassInterface +from xarray_grass.grass_backend_array import GrassBackendArray class GrassBackendEntrypoint(BackendEntrypoint): @@ -365,15 +365,30 @@ def open_grass_strds(strds_name: str, grass_i: GrassInterface) -> xr.DataArray: coordinates["x"] = x_coords coordinates["y"] = y_coords map_list = grass_i.list_maps_in_strds(strds_id) + region = grass_i.get_region() array_list = [] for map_data in map_list: + # Lazy load the array + backend_array = GrassBackendArray( + shape=(region.rows, region.cols), + dtype=map_data.dtype, + map_id=map_data.id, + map_type="raster", + grass_interface=grass_i, + ) + lazy_array = xr.core.indexing.LazilyIndexedArray(backend_array) + # add time dimension at the beginning + lazy_array_with_time = np.expand_dims(lazy_array, axis=0) + + # ndarray = grass_i.read_raster_map(map_data.id) + # # add time dimension at the beginning + # ndarray = np.expand_dims(ndarray, axis=0) + coordinates[start_time_dim] = [map_data.start_time] coordinates[end_time_dim] = (start_time_dim, [map_data.end_time]) - ndarray = grass_i.read_raster_map(map_data.id) - # add time dimension at the beginning - ndarray = np.expand_dims(ndarray, axis=0) + data_array = xr.DataArray( - ndarray, + lazy_array_with_time, coords=coordinates, dims=dims, name=strds_name, @@ -417,15 +432,26 @@ def open_grass_str3ds(str3ds_name: str, grass_i: GrassInterface) -> xr.DataArray coordinates["y_3d"] = y_coords coordinates["z"] = z_coords map_list = grass_i.list_maps_in_str3ds(str3ds_id) + region = grass_i.get_region() array_list = [] for map_data in map_list: + # Lazy load the map + backend_array = GrassBackendArray( + shape=(region.depths, region.rows3, region.cols3), + dtype=map_data.dtype, + map_id=map_data.id, + map_type="raster3d", + grass_interface=grass_i, + ) + lazy_array = xr.core.indexing.LazilyIndexedArray(backend_array) + # add time dimension at the beginning + lazy_array_with_time = np.expand_dims(lazy_array, axis=0) + coordinates[start_time_dim] = [map_data.start_time] coordinates[end_time_dim] = (start_time_dim, [map_data.end_time]) - ndarray = grass_i.read_raster3d_map(map_data.id) - # add time dimension at the beginning - ndarray = np.expand_dims(ndarray, axis=0) + data_array = xr.DataArray( - ndarray, + lazy_array_with_time, coords=coordinates, dims=dims, name=str3ds_name, @@ -497,22 +523,3 @@ def set_cf_coordinates( da[x_coord].attrs["units"] = spatial_unit da[y_coord].attrs["units"] = spatial_unit return da - - -class GrassBackendArray(BackendArray): - """For lazy loading""" - - def __init__( - self, - shape, - dtype, - lock, - # other backend specific keyword arguments - ): - self.shape = shape - self.dtype = dtype - self.lock = lock - - def __getitem__(self, key: xr.core.indexing.ExplicitIndexer) -> np.typing.ArrayLike: - """takes in input an index and returns a NumPy array""" - pass From 802b28bf00715311e3147c09502100a5f85fb3d2 Mon Sep 17 00:00:00 2001 From: Laurent Courty Date: Sat, 1 Nov 2025 18:01:22 -0600 Subject: [PATCH 2/3] get lazy loading actually working for strds --- README.md | 2 +- src/xarray_grass/__init__.py | 4 +- src/xarray_grass/grass_backend_array.py | 58 +++++-- src/xarray_grass/grass_interface.py | 16 +- src/xarray_grass/xarray_grass.py | 147 +++++++++-------- tests/test_lazy_loading.py | 208 ++++++++++++++++++++++++ 6 files changed, 340 insertions(+), 95 deletions(-) create mode 100644 tests/test_lazy_loading.py diff --git a/README.md b/README.md index a0abac3..e89f5d2 100644 --- a/README.md +++ b/README.md @@ -153,7 +153,7 @@ Attributes: - [ ] Support `end_time` - [ ] Accept writing into a specific mapset (GRASS 8.5) - [ ] Accept non homogeneous 3D resolution in NS and EW dimensions (GRASS 8.5) -- [ ] Lazy loading of all raster types +- [x] Lazy loading of STDS on the time dimension - [ ] Properly test with lat-lon location ### Stretch goals diff --git a/src/xarray_grass/__init__.py b/src/xarray_grass/__init__.py index 733b680..7a94871 100644 --- a/src/xarray_grass/__init__.py +++ b/src/xarray_grass/__init__.py @@ -1,7 +1,9 @@ from xarray_grass.grass_interface import GrassConfig as GrassConfig from xarray_grass.grass_interface import GrassInterface as GrassInterface from xarray_grass.xarray_grass import GrassBackendEntrypoint as GrassBackendEntrypoint -from xarray_grass.xarray_grass import GrassBackendArray as GrassBackendArray +from xarray_grass.grass_backend_array import ( + GrassSTDSBackendArray as GrassSTDSBackendArray, +) from xarray_grass.to_grass import to_grass as to_grass from xarray_grass.coord_utils import RegionData as RegionData diff --git a/src/xarray_grass/grass_backend_array.py b/src/xarray_grass/grass_backend_array.py index 0110c84..27b25e2 100644 --- a/src/xarray_grass/grass_backend_array.py +++ b/src/xarray_grass/grass_backend_array.py @@ -25,25 +25,24 @@ from xarray_grass.grass_interface import GrassInterface -class GrassBackendArray(BackendArray): - """Lazy loading of grass arrays""" +class GrassSTDSBackendArray(BackendArray): + """Lazy loading of grass Space-Time DataSets (multiple maps in time series)""" def __init__( self, shape, dtype, - # lock, - map_id: str, + map_list: list, # List of map metadata objects map_type: str, grass_interface: GrassInterface, ): self.shape = shape self.dtype = dtype self._lock = threading.Lock() - self.map_id = map_id + self.map_list = map_list # List with .id attribute self.map_type = map_type # "raster" or "raster3d" self.grass_interface = grass_interface - self._array: np.ndarray = None + self._cached_maps = {} # Cache loaded maps by index def __getitem__(self, key: xr.core.indexing.ExplicitIndexer) -> np.typing.ArrayLike: """takes in input an index and returns a NumPy array""" @@ -55,13 +54,44 @@ def __getitem__(self, key: xr.core.indexing.ExplicitIndexer) -> np.typing.ArrayL ) def _raw_indexing_method(self, key: tuple): + """Load only the maps needed for the requested slice""" with self._lock: - if self._array is None: - self._array = self._load_map() - return self._array[key] + # key is a tuple of slices/indices for each dimension + # First dimension is time + time_key = key[0] if key else slice(None) + spatial_key = key[1:] if len(key) > 1 else () - def _load_map(self): - if self.map_type == "raster": - return self.grass_interface.read_raster_map(self.map_id) - else: # 'raster3d' - return self.grass_interface.read_raster3d_map(self.map_id) + # Determine which time indices are needed + if isinstance(time_key, slice): + time_indices = range(*time_key.indices(self.shape[0])) + elif isinstance(time_key, int): + time_indices = [time_key] + else: + time_indices = list(time_key) + + # Load only the needed maps + result_list = [] + for t_idx in time_indices: + if t_idx not in self._cached_maps: + map_data = self.map_list[t_idx] + if self.map_type == "raster": + self._cached_maps[t_idx] = self.grass_interface.read_raster_map( + map_data.id + ) + else: # 'raster3d' + self._cached_maps[t_idx] = ( + self.grass_interface.read_raster3d_map(map_data.id) + ) + + # Apply spatial indexing + if spatial_key: + result_list.append(self._cached_maps[t_idx][spatial_key]) + else: + result_list.append(self._cached_maps[t_idx]) + + # Stack results along time dimension + if len(result_list) == 1 and isinstance(time_key, int): + # Single time slice requested as integer index + return result_list[0] + else: + return np.stack(result_list, axis=0) diff --git a/src/xarray_grass/grass_interface.py b/src/xarray_grass/grass_interface.py index b1f005a..654bb2c 100644 --- a/src/xarray_grass/grass_interface.py +++ b/src/xarray_grass/grass_interface.py @@ -28,7 +28,7 @@ from grass.script import array as garray import grass.pygrass.utils as gutils from grass.pygrass import raster as graster -from grass.pygrass.raster.abstract import Info +from grass.pygrass.raster.abstract import Info, RasterAbstractBase import grass.temporal as tgis from xarray_grass.coord_utils import ( @@ -230,11 +230,11 @@ def name_is_str3ds(self, name: str) -> bool: return bool(tgis.SpaceTimeRaster3DDataset(str3ds_id).is_in_db()) def name_is_raster(self, raster_name: str) -> bool: - """return True if the given name is a map in the grass database - False if not - """ + """return True if the given name is a raster map in the grass database.""" + # Using pygrass instead of gscript is at least 40x faster map_id = self.get_id_from_name(raster_name) - return bool(gs.find_file(name=map_id, element="raster").get("file")) + map_object = RasterAbstractBase(map_id) + return map_object.exist() def name_is_raster_3d(self, raster3d_name: str) -> bool: """return True if the given name is a 3D raster in the grass database.""" @@ -261,11 +261,11 @@ def grass_dtype(self, dtype: str) -> str: @staticmethod def numpy_dtype(mtype: str) -> np.dtype: if mtype == "CELL": - dtype = np.int64 + dtype = np.dtype("int64") elif mtype == "FCELL": - dtype = np.float32 + dtype = np.dtype("float32") elif mtype == "DCELL": - dtype = np.float64 + dtype = np.dtype("float64") else: raise ValueError(f"Unknown GRASS data type: {mtype}") return dtype diff --git a/src/xarray_grass/xarray_grass.py b/src/xarray_grass/xarray_grass.py index be181db..9dbd7f5 100644 --- a/src/xarray_grass/xarray_grass.py +++ b/src/xarray_grass/xarray_grass.py @@ -24,7 +24,7 @@ import xarray_grass from xarray_grass.grass_interface import GrassInterface -from xarray_grass.grass_backend_array import GrassBackendArray +from xarray_grass.grass_backend_array import GrassSTDSBackendArray class GrassBackendEntrypoint(BackendEntrypoint): @@ -265,6 +265,8 @@ def open_grass_maps( data_array_list.append(data_array) if raise_on_not_found and any(not_found.values()): raise ValueError(f"Objects not found: {not_found}") + + crs_wkt = gi.get_crs_wkt_str() finally: if session is not None: session.__exit__(None, None, None) @@ -277,7 +279,7 @@ def open_grass_maps( data_array_dict = {da.name: da for da in data_array_list} attrs = { - "crs_wkt": gi.get_crs_wkt_str(), + "crs_wkt": crs_wkt, "Conventions": "CF-1.13-draft", # "title": "", "history": f"{datetime.now(timezone.utc)}: Created with xarray-grass version {xarray_grass.__version__}", @@ -347,9 +349,7 @@ def open_grass_raster_3d(raster_3d_name: str, grass_i: GrassInterface) -> xr.Dat def open_grass_strds(strds_name: str, grass_i: GrassInterface) -> xr.DataArray: - """must be called from within a grass session - TODO: lazy loading - """ + """Open a STRDS with lazy loading - data is only loaded when accessed""" strds_id = grass_i.get_id_from_name(strds_name) strds_name = grass_i.get_name_from_id(strds_id) x_coords, y_coords, _ = get_coordinates(grass_i, raster_3d=False).values() @@ -360,45 +360,46 @@ def open_grass_strds(strds_name: str, grass_i: GrassInterface) -> xr.DataArray: time_unit = strds_infos.time_unit start_time_dim = f"start_time_{strds_name}" end_time_dim = f"end_time_{strds_name}" - dims = [start_time_dim, "y", "x"] - coordinates = dict.fromkeys(dims) - coordinates["x"] = x_coords - coordinates["y"] = y_coords + map_list = grass_i.list_maps_in_strds(strds_id) region = grass_i.get_region() - array_list = [] - for map_data in map_list: - # Lazy load the array - backend_array = GrassBackendArray( - shape=(region.rows, region.cols), - dtype=map_data.dtype, - map_id=map_data.id, - map_type="raster", - grass_interface=grass_i, - ) - lazy_array = xr.core.indexing.LazilyIndexedArray(backend_array) - # add time dimension at the beginning - lazy_array_with_time = np.expand_dims(lazy_array, axis=0) - - # ndarray = grass_i.read_raster_map(map_data.id) - # # add time dimension at the beginning - # ndarray = np.expand_dims(ndarray, axis=0) - - coordinates[start_time_dim] = [map_data.start_time] - coordinates[end_time_dim] = (start_time_dim, [map_data.end_time]) - - data_array = xr.DataArray( - lazy_array_with_time, - coords=coordinates, - dims=dims, - name=strds_name, - ) - array_list.append(data_array) - da_concat = xr.concat(array_list, dim=start_time_dim) + + # Create a single backend array for the entire STRDS + backend_array = GrassSTDSBackendArray( + shape=(len(map_list), region.rows, region.cols), + dtype=map_list[0].dtype, + map_list=map_list, + map_type="raster", + grass_interface=grass_i, + ) + lazy_array = xr.core.indexing.LazilyIndexedArray(backend_array) + + # Create Variable with lazy array + var = xr.Variable(dims=[start_time_dim, "y", "x"], data=lazy_array) + + # Extract time coordinates + start_times = [map_data.start_time for map_data in map_list] + end_times = [map_data.end_time for map_data in map_list] + + # Create coordinates + coordinates = { + "x": x_coords, + "y": y_coords, + start_time_dim: start_times, + end_time_dim: (start_time_dim, end_times), + } + + # Convert to DataArray + data_array = xr.DataArray( + var, + coords=coordinates, + name=strds_name, + ) + # Add CF attributes r_infos = grass_i.get_raster_info(map_list[0].id) da_with_attrs = set_cf_coordinates( - da_concat, + data_array, gi=grass_i, is_3d=False, time_dims=[start_time_dim, end_time_dim], @@ -414,7 +415,7 @@ def open_grass_strds(strds_name: str, grass_i: GrassInterface) -> xr.DataArray: def open_grass_str3ds(str3ds_name: str, grass_i: GrassInterface) -> xr.DataArray: - """Open a series of 3D raster maps. + """Open a STR3DS with lazy loading - data is only loaded when accessed TODO: Figure out what to do when the z value of the maps is time.""" str3ds_id = grass_i.get_id_from_name(str3ds_name) str3ds_name = grass_i.get_name_from_id(str3ds_id) @@ -426,43 +427,47 @@ def open_grass_str3ds(str3ds_name: str, grass_i: GrassInterface) -> xr.DataArray time_unit = strds_infos.time_unit start_time_dim = f"start_time_{str3ds_name}" end_time_dim = f"end_time_{str3ds_name}" - dims = [start_time_dim, "z", "y_3d", "x_3d"] - coordinates = dict.fromkeys(dims) - coordinates["x_3d"] = x_coords - coordinates["y_3d"] = y_coords - coordinates["z"] = z_coords + map_list = grass_i.list_maps_in_str3ds(str3ds_id) region = grass_i.get_region() - array_list = [] - for map_data in map_list: - # Lazy load the map - backend_array = GrassBackendArray( - shape=(region.depths, region.rows3, region.cols3), - dtype=map_data.dtype, - map_id=map_data.id, - map_type="raster3d", - grass_interface=grass_i, - ) - lazy_array = xr.core.indexing.LazilyIndexedArray(backend_array) - # add time dimension at the beginning - lazy_array_with_time = np.expand_dims(lazy_array, axis=0) - - coordinates[start_time_dim] = [map_data.start_time] - coordinates[end_time_dim] = (start_time_dim, [map_data.end_time]) - - data_array = xr.DataArray( - lazy_array_with_time, - coords=coordinates, - dims=dims, - name=str3ds_name, - ) - array_list.append(data_array) - da_concat = xr.concat(array_list, dim=start_time_dim) + # Create a single backend array for the entire STR3DS + backend_array = GrassSTDSBackendArray( + shape=(len(map_list), region.depths, region.rows3, region.cols3), + dtype=map_list[0].dtype, + map_list=map_list, + map_type="raster3d", + grass_interface=grass_i, + ) + lazy_array = xr.core.indexing.LazilyIndexedArray(backend_array) + + # Create Variable with lazy array + var = xr.Variable(dims=[start_time_dim, "z", "y_3d", "x_3d"], data=lazy_array) + + # Extract time coordinates + start_times = [map_data.start_time for map_data in map_list] + end_times = [map_data.end_time for map_data in map_list] + + # Create coordinates + coordinates = { + "x_3d": x_coords, + "y_3d": y_coords, + "z": z_coords, + start_time_dim: start_times, + end_time_dim: (start_time_dim, end_times), + } + + # Convert to DataArray + data_array = xr.DataArray( + var, + coords=coordinates, + name=str3ds_name, + ) + # Add CF attributes r3_infos = grass_i.get_raster3d_info(map_list[0].id) da_with_attrs = set_cf_coordinates( - da_concat, + data_array, gi=grass_i, is_3d=True, z_unit=r3_infos["vertical_units"], diff --git a/tests/test_lazy_loading.py b/tests/test_lazy_loading.py new file mode 100644 index 0000000..28039c0 --- /dev/null +++ b/tests/test_lazy_loading.py @@ -0,0 +1,208 @@ +from pathlib import Path + +import pytest +import xarray as xr + + +ACTUAL_STRDS = "LST_Day_monthly@modis_lst" + + +@pytest.mark.usefixtures("grass_session_fixture", "grass_test_region") +class TestXarrayGrass: + def test_strds_lazy_loading_detailed(self, grass_i, temp_gisdb) -> None: + """Detailed diagnostic to identify exactly where lazy loading breaks. + + This test instruments the code to find the exact point where + LazilyIndexedArray gets converted to numpy array. + """ + from unittest.mock import patch, wraps + from xarray_grass.grass_interface import GrassInterface + + mapset_path = ( + Path(temp_gisdb.gisdb) / Path(temp_gisdb.project) / Path(temp_gisdb.mapset) + ) + + # Track read calls + read_calls = [] + original_read = GrassInterface.read_raster_map + + def tracked_read(map_id): + import traceback + + read_calls.append( + { + "map_id": map_id, + "stack": "".join( + traceback.format_stack()[-8:-1] + ), # More stack frames + } + ) + return original_read(map_id) + + # Track concat calls more carefully + concat_details = [] + original_concat = xr.concat + + @wraps(original_concat) + def tracked_concat(objs, dim, **kwargs): + if objs: + first_var = objs[0]._variable if hasattr(objs[0], "_variable") else None + concat_details.append( + { + "num_objs": len(objs), + "dim": dim, + "first_obj_data_type": type(objs[0].data).__name__, + "first_obj_is_lazy": isinstance( + objs[0].data, xr.core.indexing.LazilyIndexedArray + ), + "first_var_data_type": type(first_var._data).__name__ + if first_var is not None + else None, + "first_var_is_lazy": isinstance( + first_var._data, xr.core.indexing.LazilyIndexedArray + ) + if first_var is not None + else None, + } + ) + result = original_concat(objs, dim, **kwargs) + if concat_details: + concat_details[-1]["result_data_type"] = type(result.data).__name__ + concat_details[-1]["result_is_lazy"] = isinstance( + result.data, xr.core.indexing.LazilyIndexedArray + ) + return result + + # Patch at module level to capture all GrassInterface instances + with ( + patch("xarray_grass.xarray_grass.GrassInterface") as MockGI, + patch("xarray.concat", side_effect=tracked_concat), + ): + # Make mock return real instances with tracked read method + def create_tracked_instance(*args, **kwargs): + real_instance = GrassInterface(*args, **kwargs) + # Replace the static method with our tracking wrapper + real_instance.read_raster_map = staticmethod(tracked_read) + return real_instance + + MockGI.side_effect = create_tracked_instance + + test_dataset = xr.open_dataset(mapset_path, strds=ACTUAL_STRDS) + + print(f"\n=== Total read_raster_map calls: {len(read_calls)}") + print(f"=== Total concat calls: {len(concat_details)}") + + if concat_details: + print("\n=== Concat details:") + for i, call in enumerate(concat_details): + print( + f" Call {i + 1}: Concatenating {call['num_objs']} objects on dim '{call['dim']}'" + ) + print( + f" First obj data: {call['first_obj_data_type']} (lazy={call['first_obj_is_lazy']})" + ) + if call["first_var_data_type"]: + print( + f" First obj Variable._data: {call['first_var_data_type']} (lazy={call['first_var_is_lazy']})" + ) + print( + f" Result data: {call['result_data_type']} (lazy={call['result_is_lazy']})" + ) + + strds_name = grass_i.get_name_from_id(ACTUAL_STRDS) + da = test_dataset[strds_name] + + print("\n=== Final DataArray:") + print(f" data type: {type(da.data).__name__}") + print(f" Variable._data type: {type(da._variable._data).__name__}") + print( + f" Is LazilyIndexedArray: {isinstance(da.data, xr.core.indexing.LazilyIndexedArray)}" + ) + + def test_strds_lazy_loading(self, grass_i, temp_gisdb) -> None: + """Test that STRDS data is actually loaded lazily and not eagerly. + + This test verifies that: + 1. Opening a STRDS doesn't immediately load all data into memory + 2. Data is only loaded when actually accessed via .values + 3. Individual time slices are loaded on demand, not all at once + """ + from unittest.mock import patch + from xarray_grass.grass_interface import GrassInterface + + mapset_path = ( + Path(temp_gisdb.gisdb) / Path(temp_gisdb.project) / Path(temp_gisdb.mapset) + ) + + # Track read_raster_map calls across ALL GrassInterface instances + read_calls = [] + original_read = GrassInterface.read_raster_map + + def tracked_read_raster_map(map_id): + """Wrapper that tracks calls while preserving functionality""" + read_calls.append(map_id) + return original_read(map_id) + + # Patch at module level where GrassInterface is used to create instances + # This ensures ALL instances created during xr.open_dataset() are tracked + with patch("xarray_grass.xarray_grass.GrassInterface") as MockGI: + # Make the mock return real GrassInterface instances with tracked read method + def create_tracked_instance(*args, **kwargs): + real_instance = GrassInterface(*args, **kwargs) + # Replace the static method with our tracking wrapper + real_instance.read_raster_map = staticmethod(tracked_read_raster_map) + return real_instance + + MockGI.side_effect = create_tracked_instance + + # Open the dataset - this should NOT load any raster data + test_dataset = xr.open_dataset(mapset_path, strds=ACTUAL_STRDS) + strds_name = grass_i.get_name_from_id(ACTUAL_STRDS) + + print(f"\n=== After opening: {len(read_calls)} maps read") + + # The key assertion: opening should NOT trigger data loading + assert len(read_calls) == 0, ( + f"Expected 0 raster reads during dataset opening, " + f"but {len(read_calls)} maps were read. " + f"Lazy loading is NOT working!" + ) + + # Get the DataArray - still shouldn't load data + da = test_dataset[strds_name] + assert len(read_calls) == 0, ( + f"Getting DataArray triggered {len(read_calls)} reads" + ) + + # Check it's a lazy array (MemoryCachedArray wrapping LazilyIndexedArray) + print( + f"=== DataArray._variable._data type: {type(da._variable._data).__name__}" + ) + + # Now access a single time slice via .values - this SHOULD trigger loading + read_calls.clear() + _ = da.isel({f"start_time_{strds_name}": 0}).values + + print(f"=== After accessing first slice: {len(read_calls)} maps read") + + # Should have loaded exactly 1 map (the first time slice) + assert len(read_calls) == 1, ( + f"Expected 1 raster read when accessing first time slice, " + f"got {len(read_calls)}" + ) + + # Access another slice - should load one more + read_calls.clear() + _ = da.isel({f"start_time_{strds_name}": 1}).values + + print(f"=== After accessing second slice: {len(read_calls)} maps read") + + assert len(read_calls) == 1, ( + f"Expected 1 raster read when accessing second time slice, " + f"got {len(read_calls)}" + ) + + print("\n✅ Lazy loading is working correctly!") + print(" - Opening STRDS: 0 maps loaded") + print(" - Accessing slice 0: 1 map loaded") + print(" - Accessing slice 1: 1 map loaded") From 69b174046d3521a3cf7bf69df2939039940d8178 Mon Sep 17 00:00:00 2001 From: Laurent Courty Date: Sat, 1 Nov 2025 18:31:28 -0600 Subject: [PATCH 3/3] update comments --- src/xarray_grass/to_grass.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/xarray_grass/to_grass.py b/src/xarray_grass/to_grass.py index 4bce00d..46d4f39 100644 --- a/src/xarray_grass/to_grass.py +++ b/src/xarray_grass/to_grass.py @@ -301,7 +301,6 @@ def _datarray_to_grass( current_region = self.grass_interface.get_region() temp_region = get_region_from_xarray(data, dims) self.grass_interface.set_region(temp_region) - # TODO: reshape to match userGRASS expected dims order try: if is_raster: data = self.transpose(data, dims, arr_type="raster")