Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ Attributes:
- [x] Write to STRDS
- [x] Write to 3D raster
- [x] Write to STR3DS
- [ ] Transpose if dimensions are not in the expected order
- [x] Transpose if dimensions are not in the expected order
- [ ] Support time units for relative time
- [ ] Support `end_time`
- [ ] Accept writing into a specific mapset (GRASS 8.5)
Expand Down
18 changes: 18 additions & 0 deletions src/xarray_grass/to_grass.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,10 +304,12 @@ def _datarray_to_grass(
# TODO: reshape to match userGRASS expected dims order
try:
if is_raster:
data = self.transpose(data, dims, arr_type="raster")
self.grass_interface.write_raster_map(data, data.name)
elif is_strds:
self._write_stds(data, dims)
elif is_raster_3d:
data = self.transpose(data, dims, arr_type="raster3d")
self.grass_interface.write_raster3d_map(data, data.name)
elif is_str3ds:
self._write_stds(data, dims)
Expand All @@ -320,6 +322,19 @@ def _datarray_to_grass(
# Restore the original region
self.grass_interface.set_region(current_region)

def transpose(
self, da: xr.DataArray, dims, arr_type: str = "raster"
) -> xr.DataArray:
"""Force dimension order to conform with grass expectation."""
if "raster" == arr_type:
return da.transpose(dims["y"], dims["x"])
elif "raster3d" == arr_type:
return da.transpose(dims["z"], dims["y_3d"], dims["x_3d"])
else:
raise ValueError(
f"Unknown array type: {arr_type}. Must be 'raster' or 'raster3d'."
)

def _write_stds(self, data: xr.DataArray, dims: Mapping):
# 1. Determine the temporal coordinate and type
time_coord = data[dims["start_time"]]
Expand All @@ -337,14 +352,17 @@ def _write_stds(self, data: xr.DataArray, dims: Mapping):
# 2.5 determine if 2D or 3D
is_3d = False
stds_type = "strds"
arr_type = "raster"
if len(data.isel({dims["start_time"]: 0}).dims) == 3:
is_3d = True
stds_type = "str3ds"
arr_type = "raster3d"

# 3. Loop through the time dim:
map_list = []
for index, time in enumerate(time_coord):
darray = data.sel({dims["start_time"]: time})
darray = self.transpose(darray, dims, arr_type=arr_type)
nd_array = darray.values
# 3.1 Write each map individually
raster_name = f"{data.name}_{temporal_type}_{index}"
Expand Down
232 changes: 232 additions & 0 deletions tests/test_tograss.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,3 +779,235 @@ def test_dims_mapping(
)
assert int(info["rows"]) == img_height
assert int(info["cols"]) == img_width

def test_dimension_transposition(
self,
temp_gisdb,
grass_i: GrassInterface,
):
"""Test that to_grass() correctly transposes dimensions to GRASS format.

Creates DataArrays with standard dimension names but in non-standard order,
and verifies they are correctly transposed when written to GRASS.
"""
session_crs_wkt = grass_i.get_crs_wkt_str()
target_mapset_name = temp_gisdb.mapset
mapset_path_obj = (
Path(temp_gisdb.gisdb) / temp_gisdb.project / target_mapset_name
)
mapset_arg = str(mapset_path_obj)

# Define expected dimensions for verification
height_2d, width_2d = 5, 7
depth_3d, height_3d, width_3d = 3, 4, 6
num_times_strds = 2
height_strds, width_strds = 6, 5
num_times_str3ds = 2
depth_str3ds, height_str3ds, width_str3ds = 3, 5, 4

# 1. Test 2D Raster: Create with standard dims but transpose before writing
raster2d_name = "test_transpose_2d"
da_2d = create_sample_dataarray(
dims_spec={
"y": np.arange(height_2d, dtype=float),
"x": np.arange(width_2d, dtype=float),
},
shape=(height_2d, width_2d),
crs_wkt=session_crs_wkt,
name=raster2d_name,
fill_value_generator=lambda s: np.arange(s[0] * s[1])
.reshape(s)
.astype(float),
)
# Transpose to non-standard order (x, y instead of y, x)
da_2d = da_2d.transpose("x", "y")
assert da_2d.dims == ("x", "y"), f"Expected dims ('x', 'y'), got {da_2d.dims}"
assert da_2d.shape == (width_2d, height_2d)

# 2. Test 3D Raster: Create with z,y,x then transpose to x,z,y
raster3d_name = "test_transpose_3d"
res3 = 1000
da_3d = create_sample_dataarray(
dims_spec={
"z": np.arange(depth_3d, dtype=float),
"y": np.linspace(220000, 220000 + (height_3d - 1) * res3, height_3d),
"x": np.linspace(630000, 630000 + (width_3d - 1) * res3, width_3d),
},
shape=(depth_3d, height_3d, width_3d),
crs_wkt=session_crs_wkt,
name=raster3d_name,
fill_value_generator=lambda s: np.arange(s[0] * s[1] * s[2])
.reshape(s)
.astype(float),
)
# Transpose to non-standard order (x_3d, z, y_3d)
da_3d = da_3d.transpose("x_3d", "z", "y_3d")
assert da_3d.dims == ("x_3d", "z", "y_3d")
assert da_3d.shape == (width_3d, depth_3d, height_3d)

# 3. Test STRDS: Create with start_time,y,x then transpose to x,y,start_time
strds_name = "test_transpose_strds"
da_strds = create_sample_dataarray(
dims_spec={
"start_time": np.arange(1, num_times_strds + 1),
"y": np.arange(height_strds, dtype=float),
"x": np.arange(width_strds, dtype=float),
},
shape=(num_times_strds, height_strds, width_strds),
crs_wkt=session_crs_wkt,
name=strds_name,
time_dim_type="relative",
fill_value_generator=lambda s: np.arange(s[0] * s[1] * s[2])
.reshape(s)
.astype(float),
)
# Transpose to non-standard order (x, y, start_time)
da_strds = da_strds.transpose("x", "y", "start_time")
assert da_strds.dims == ("x", "y", "start_time")
assert da_strds.shape == (width_strds, height_strds, num_times_strds)

# 4. Test STR3DS: Create with time,z,y,x then transpose to y_3d,time,x_3d,z
str3ds_name = "test_transpose_str3ds"
res3_str3ds = 1000
da_str3ds = create_sample_dataarray(
dims_spec={
"time": np.arange(1, num_times_str3ds + 1),
"z": np.arange(depth_str3ds, dtype=float),
"y": np.linspace(
220000, 220000 + (height_str3ds - 1) * res3_str3ds, height_str3ds
),
"x": np.linspace(
630000, 630000 + (width_str3ds - 1) * res3_str3ds, width_str3ds
),
},
shape=(num_times_str3ds, depth_str3ds, height_str3ds, width_str3ds),
crs_wkt=session_crs_wkt,
name=str3ds_name,
time_dim_type="relative",
fill_value_generator=lambda s: np.arange(s[0] * s[1] * s[2] * s[3])
.reshape(s)
.astype(float),
)
# Transpose to non-standard order (y_3d, time, x_3d, z)
da_str3ds = da_str3ds.transpose("y_3d", "time", "x_3d", "z")
assert da_str3ds.dims == ("y_3d", "time", "x_3d", "z")
assert da_str3ds.shape == (
height_str3ds,
num_times_str3ds,
width_str3ds,
depth_str3ds,
)

# Write all DataArrays to GRASS
raster2d_id = grass_i.get_id_from_name(raster2d_name)
raster3d_id = grass_i.get_id_from_name(raster3d_name)
strds_id = grass_i.get_id_from_name(strds_name)
str3ds_id = grass_i.get_id_from_name(str3ds_name)

try:
# Write 2D raster
to_grass(dataset=da_2d, mapset=mapset_arg)

# Write 3D raster
to_grass(dataset=da_3d, mapset=mapset_arg)

# Write STRDS
to_grass(dataset=da_strds, mapset=mapset_arg)

# Write STR3DS
to_grass(
dataset=da_str3ds,
mapset=mapset_arg,
dims={str3ds_name: {"start_time": "time"}},
)

# Verify 2D Raster dimensions
info_2d = gs.parse_command("r.info", map=raster2d_id, flags="g", quiet=True)
assert int(info_2d["rows"]) == height_2d, (
f"2D Raster rows mismatch: expected {height_2d}, got {info_2d['rows']}"
)
assert int(info_2d["cols"]) == width_2d, (
f"2D Raster cols mismatch: expected {width_2d}, got {info_2d['cols']}"
)

# Verify 3D Raster dimensions
info_3d = gs.parse_command(
"r3.info", map=raster3d_id, flags="g", quiet=True
)
assert int(info_3d["depths"]) == depth_3d, (
f"3D Raster depths mismatch: expected {depth_3d}, got {info_3d['depths']}"
)
assert int(info_3d["rows"]) == height_3d, (
f"3D Raster rows mismatch: expected {height_3d}, got {info_3d['rows']}"
)
assert int(info_3d["cols"]) == width_3d, (
f"3D Raster cols mismatch: expected {width_3d}, got {info_3d['cols']}"
)

# Verify STRDS
strds_maps = grass_i.list_maps_in_strds(strds_id)
assert len(strds_maps) == num_times_strds, (
f"STRDS map count mismatch: expected {num_times_strds}, got {len(strds_maps)}"
)
# Check dimensions of first map in STRDS
first_map_id = strds_maps[0].id
info_strds = gs.parse_command(
"r.info", map=first_map_id, flags="g", quiet=True
)
assert int(info_strds["rows"]) == height_strds, (
f"STRDS map rows mismatch: expected {height_strds}, got {info_strds['rows']}"
)
assert int(info_strds["cols"]) == width_strds, (
f"STRDS map cols mismatch: expected {width_strds}, got {info_strds['cols']}"
)

# Verify STR3DS
str3ds_maps = grass_i.list_maps_in_str3ds(str3ds_id)
assert len(str3ds_maps) == num_times_str3ds, (
f"STR3DS map count mismatch: expected {num_times_str3ds}, got {len(str3ds_maps)}"
)
# Check dimensions of first map in STR3DS
first_map_3d_id = str3ds_maps[0].id
info_str3ds = gs.parse_command(
"r3.info", map=first_map_3d_id, flags="g", quiet=True
)
assert int(info_str3ds["depths"]) == depth_str3ds, (
f"STR3DS map depths mismatch: expected {depth_str3ds}, got {info_str3ds['depths']}"
)
assert int(info_str3ds["rows"]) == height_str3ds, (
f"STR3DS map rows mismatch: expected {height_str3ds}, got {info_str3ds['rows']}"
)
assert int(info_str3ds["cols"]) == width_str3ds, (
f"STR3DS map cols mismatch: expected {width_str3ds}, got {info_str3ds['cols']}"
)

finally:
# Cleanup
try:
gs.run_command(
"g.remove", flags="f", type="raster", name=raster2d_id, quiet=True
)
except CalledModuleError:
pass
try:
gs.run_command(
"g.remove",
flags="f",
type="raster_3d",
name=raster3d_id,
quiet=True,
)
except CalledModuleError:
pass
try:
gs.run_command(
"t.remove", inputs=strds_id, type="strds", flags="rfd", quiet=True
)
except CalledModuleError:
pass
try:
gs.run_command(
"t.remove", inputs=str3ds_id, type="str3ds", flags="rfd", quiet=True
)
except CalledModuleError:
pass