diff --git a/README.md b/README.md index 67338df..a0abac3 100644 --- a/README.md +++ b/README.md @@ -148,7 +148,7 @@ Attributes: - [x] Write to STRDS - [x] Write to 3D raster - [x] Write to STR3DS - - [ ] Transpose if dimensions are not in the expected order + - [x] Transpose if dimensions are not in the expected order - [ ] Support time units for relative time - [ ] Support `end_time` - [ ] Accept writing into a specific mapset (GRASS 8.5) diff --git a/src/xarray_grass/to_grass.py b/src/xarray_grass/to_grass.py index 8813150..4bce00d 100644 --- a/src/xarray_grass/to_grass.py +++ b/src/xarray_grass/to_grass.py @@ -304,10 +304,12 @@ def _datarray_to_grass( # TODO: reshape to match userGRASS expected dims order try: if is_raster: + data = self.transpose(data, dims, arr_type="raster") self.grass_interface.write_raster_map(data, data.name) elif is_strds: self._write_stds(data, dims) elif is_raster_3d: + data = self.transpose(data, dims, arr_type="raster3d") self.grass_interface.write_raster3d_map(data, data.name) elif is_str3ds: self._write_stds(data, dims) @@ -320,6 +322,19 @@ def _datarray_to_grass( # Restore the original region self.grass_interface.set_region(current_region) + def transpose( + self, da: xr.DataArray, dims, arr_type: str = "raster" + ) -> xr.DataArray: + """Force dimension order to conform with grass expectation.""" + if "raster" == arr_type: + return da.transpose(dims["y"], dims["x"]) + elif "raster3d" == arr_type: + return da.transpose(dims["z"], dims["y_3d"], dims["x_3d"]) + else: + raise ValueError( + f"Unknown array type: {arr_type}. Must be 'raster' or 'raster3d'." + ) + def _write_stds(self, data: xr.DataArray, dims: Mapping): # 1. Determine the temporal coordinate and type time_coord = data[dims["start_time"]] @@ -337,14 +352,17 @@ def _write_stds(self, data: xr.DataArray, dims: Mapping): # 2.5 determine if 2D or 3D is_3d = False stds_type = "strds" + arr_type = "raster" if len(data.isel({dims["start_time"]: 0}).dims) == 3: is_3d = True stds_type = "str3ds" + arr_type = "raster3d" # 3. Loop through the time dim: map_list = [] for index, time in enumerate(time_coord): darray = data.sel({dims["start_time"]: time}) + darray = self.transpose(darray, dims, arr_type=arr_type) nd_array = darray.values # 3.1 Write each map individually raster_name = f"{data.name}_{temporal_type}_{index}" diff --git a/tests/test_tograss.py b/tests/test_tograss.py index c3397b6..534112a 100644 --- a/tests/test_tograss.py +++ b/tests/test_tograss.py @@ -779,3 +779,235 @@ def test_dims_mapping( ) assert int(info["rows"]) == img_height assert int(info["cols"]) == img_width + + def test_dimension_transposition( + self, + temp_gisdb, + grass_i: GrassInterface, + ): + """Test that to_grass() correctly transposes dimensions to GRASS format. + + Creates DataArrays with standard dimension names but in non-standard order, + and verifies they are correctly transposed when written to GRASS. + """ + session_crs_wkt = grass_i.get_crs_wkt_str() + target_mapset_name = temp_gisdb.mapset + mapset_path_obj = ( + Path(temp_gisdb.gisdb) / temp_gisdb.project / target_mapset_name + ) + mapset_arg = str(mapset_path_obj) + + # Define expected dimensions for verification + height_2d, width_2d = 5, 7 + depth_3d, height_3d, width_3d = 3, 4, 6 + num_times_strds = 2 + height_strds, width_strds = 6, 5 + num_times_str3ds = 2 + depth_str3ds, height_str3ds, width_str3ds = 3, 5, 4 + + # 1. Test 2D Raster: Create with standard dims but transpose before writing + raster2d_name = "test_transpose_2d" + da_2d = create_sample_dataarray( + dims_spec={ + "y": np.arange(height_2d, dtype=float), + "x": np.arange(width_2d, dtype=float), + }, + shape=(height_2d, width_2d), + crs_wkt=session_crs_wkt, + name=raster2d_name, + fill_value_generator=lambda s: np.arange(s[0] * s[1]) + .reshape(s) + .astype(float), + ) + # Transpose to non-standard order (x, y instead of y, x) + da_2d = da_2d.transpose("x", "y") + assert da_2d.dims == ("x", "y"), f"Expected dims ('x', 'y'), got {da_2d.dims}" + assert da_2d.shape == (width_2d, height_2d) + + # 2. Test 3D Raster: Create with z,y,x then transpose to x,z,y + raster3d_name = "test_transpose_3d" + res3 = 1000 + da_3d = create_sample_dataarray( + dims_spec={ + "z": np.arange(depth_3d, dtype=float), + "y": np.linspace(220000, 220000 + (height_3d - 1) * res3, height_3d), + "x": np.linspace(630000, 630000 + (width_3d - 1) * res3, width_3d), + }, + shape=(depth_3d, height_3d, width_3d), + crs_wkt=session_crs_wkt, + name=raster3d_name, + fill_value_generator=lambda s: np.arange(s[0] * s[1] * s[2]) + .reshape(s) + .astype(float), + ) + # Transpose to non-standard order (x_3d, z, y_3d) + da_3d = da_3d.transpose("x_3d", "z", "y_3d") + assert da_3d.dims == ("x_3d", "z", "y_3d") + assert da_3d.shape == (width_3d, depth_3d, height_3d) + + # 3. Test STRDS: Create with start_time,y,x then transpose to x,y,start_time + strds_name = "test_transpose_strds" + da_strds = create_sample_dataarray( + dims_spec={ + "start_time": np.arange(1, num_times_strds + 1), + "y": np.arange(height_strds, dtype=float), + "x": np.arange(width_strds, dtype=float), + }, + shape=(num_times_strds, height_strds, width_strds), + crs_wkt=session_crs_wkt, + name=strds_name, + time_dim_type="relative", + fill_value_generator=lambda s: np.arange(s[0] * s[1] * s[2]) + .reshape(s) + .astype(float), + ) + # Transpose to non-standard order (x, y, start_time) + da_strds = da_strds.transpose("x", "y", "start_time") + assert da_strds.dims == ("x", "y", "start_time") + assert da_strds.shape == (width_strds, height_strds, num_times_strds) + + # 4. Test STR3DS: Create with time,z,y,x then transpose to y_3d,time,x_3d,z + str3ds_name = "test_transpose_str3ds" + res3_str3ds = 1000 + da_str3ds = create_sample_dataarray( + dims_spec={ + "time": np.arange(1, num_times_str3ds + 1), + "z": np.arange(depth_str3ds, dtype=float), + "y": np.linspace( + 220000, 220000 + (height_str3ds - 1) * res3_str3ds, height_str3ds + ), + "x": np.linspace( + 630000, 630000 + (width_str3ds - 1) * res3_str3ds, width_str3ds + ), + }, + shape=(num_times_str3ds, depth_str3ds, height_str3ds, width_str3ds), + crs_wkt=session_crs_wkt, + name=str3ds_name, + time_dim_type="relative", + fill_value_generator=lambda s: np.arange(s[0] * s[1] * s[2] * s[3]) + .reshape(s) + .astype(float), + ) + # Transpose to non-standard order (y_3d, time, x_3d, z) + da_str3ds = da_str3ds.transpose("y_3d", "time", "x_3d", "z") + assert da_str3ds.dims == ("y_3d", "time", "x_3d", "z") + assert da_str3ds.shape == ( + height_str3ds, + num_times_str3ds, + width_str3ds, + depth_str3ds, + ) + + # Write all DataArrays to GRASS + raster2d_id = grass_i.get_id_from_name(raster2d_name) + raster3d_id = grass_i.get_id_from_name(raster3d_name) + strds_id = grass_i.get_id_from_name(strds_name) + str3ds_id = grass_i.get_id_from_name(str3ds_name) + + try: + # Write 2D raster + to_grass(dataset=da_2d, mapset=mapset_arg) + + # Write 3D raster + to_grass(dataset=da_3d, mapset=mapset_arg) + + # Write STRDS + to_grass(dataset=da_strds, mapset=mapset_arg) + + # Write STR3DS + to_grass( + dataset=da_str3ds, + mapset=mapset_arg, + dims={str3ds_name: {"start_time": "time"}}, + ) + + # Verify 2D Raster dimensions + info_2d = gs.parse_command("r.info", map=raster2d_id, flags="g", quiet=True) + assert int(info_2d["rows"]) == height_2d, ( + f"2D Raster rows mismatch: expected {height_2d}, got {info_2d['rows']}" + ) + assert int(info_2d["cols"]) == width_2d, ( + f"2D Raster cols mismatch: expected {width_2d}, got {info_2d['cols']}" + ) + + # Verify 3D Raster dimensions + info_3d = gs.parse_command( + "r3.info", map=raster3d_id, flags="g", quiet=True + ) + assert int(info_3d["depths"]) == depth_3d, ( + f"3D Raster depths mismatch: expected {depth_3d}, got {info_3d['depths']}" + ) + assert int(info_3d["rows"]) == height_3d, ( + f"3D Raster rows mismatch: expected {height_3d}, got {info_3d['rows']}" + ) + assert int(info_3d["cols"]) == width_3d, ( + f"3D Raster cols mismatch: expected {width_3d}, got {info_3d['cols']}" + ) + + # Verify STRDS + strds_maps = grass_i.list_maps_in_strds(strds_id) + assert len(strds_maps) == num_times_strds, ( + f"STRDS map count mismatch: expected {num_times_strds}, got {len(strds_maps)}" + ) + # Check dimensions of first map in STRDS + first_map_id = strds_maps[0].id + info_strds = gs.parse_command( + "r.info", map=first_map_id, flags="g", quiet=True + ) + assert int(info_strds["rows"]) == height_strds, ( + f"STRDS map rows mismatch: expected {height_strds}, got {info_strds['rows']}" + ) + assert int(info_strds["cols"]) == width_strds, ( + f"STRDS map cols mismatch: expected {width_strds}, got {info_strds['cols']}" + ) + + # Verify STR3DS + str3ds_maps = grass_i.list_maps_in_str3ds(str3ds_id) + assert len(str3ds_maps) == num_times_str3ds, ( + f"STR3DS map count mismatch: expected {num_times_str3ds}, got {len(str3ds_maps)}" + ) + # Check dimensions of first map in STR3DS + first_map_3d_id = str3ds_maps[0].id + info_str3ds = gs.parse_command( + "r3.info", map=first_map_3d_id, flags="g", quiet=True + ) + assert int(info_str3ds["depths"]) == depth_str3ds, ( + f"STR3DS map depths mismatch: expected {depth_str3ds}, got {info_str3ds['depths']}" + ) + assert int(info_str3ds["rows"]) == height_str3ds, ( + f"STR3DS map rows mismatch: expected {height_str3ds}, got {info_str3ds['rows']}" + ) + assert int(info_str3ds["cols"]) == width_str3ds, ( + f"STR3DS map cols mismatch: expected {width_str3ds}, got {info_str3ds['cols']}" + ) + + finally: + # Cleanup + try: + gs.run_command( + "g.remove", flags="f", type="raster", name=raster2d_id, quiet=True + ) + except CalledModuleError: + pass + try: + gs.run_command( + "g.remove", + flags="f", + type="raster_3d", + name=raster3d_id, + quiet=True, + ) + except CalledModuleError: + pass + try: + gs.run_command( + "t.remove", inputs=strds_id, type="strds", flags="rfd", quiet=True + ) + except CalledModuleError: + pass + try: + gs.run_command( + "t.remove", inputs=str3ds_id, type="str3ds", flags="rfd", quiet=True + ) + except CalledModuleError: + pass