Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ Attributes:
- [x] Write to 3D raster
- [x] Write to STR3DS
- [x] Transpose if dimensions are not in the expected order
- [ ] Support time units for relative time
- [x] Support time units for relative time
- [ ] Support `end_time`
- [ ] Accept writing into a specific mapset (GRASS 8.5)
- [ ] Accept non homogeneous 3D resolution in NS and EW dimensions (GRASS 8.5)
Expand Down
23 changes: 16 additions & 7 deletions src/xarray_grass/grass_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
from dataclasses import dataclass
from datetime import datetime, timedelta
from pathlib import Path
from typing import Self
from typing import Self, Optional

import numpy as np
import pandas as pd


# Needed to import grass modules
import grass_session # noqa: F401
Expand Down Expand Up @@ -436,10 +438,10 @@ def register_maps_in_stds(
semantic: str,
t_type: str,
stds_type: str,
time_unit: Optional[str] = None,
) -> Self:
"""Create a STDS, create one mapdataset for each map and
register them in the temporal database.
TODO: add support for units other than seconds
"""
# create stds
stds_id = self.get_id_from_name(stds_name)
Expand Down Expand Up @@ -470,9 +472,11 @@ def register_maps_in_stds(
if t_type == "relative":
if not isinstance(map_time, timedelta):
raise TypeError("relative time requires a timedelta object.")
# TODO: support other units
rel_time = map_time.total_seconds()
map_dts.set_relative_time(rel_time, None, "seconds")
if not time_unit:
raise TypeError("relative time requires a time_unit.")
# Convert timedelta to numeric value in the specified unit
rel_time = map_time / pd.Timedelta(1, unit=time_unit)
map_dts.set_relative_time(rel_time, None, time_unit)
elif t_type == "absolute":
if not isinstance(map_time, datetime):
raise TypeError("absolute time requires a datetime object.")
Expand All @@ -484,7 +488,12 @@ def register_maps_in_stds(
# populate the list of MapDataset objects
map_dts_lst.append(map_dts)
# Finally register the maps
t_unit = {"relative": "seconds", "absolute": ""}
# Use provided unit for relative time, empty string for absolute
if t_type == "relative":
t_unit = time_unit
else:
t_unit = ""

map_type = "raster"
if stds_type == "str3ds":
map_type = "raster_3d"
Expand All @@ -493,6 +502,6 @@ def register_maps_in_stds(
map_list=map_dts_lst,
output_stds=stds,
delete_empty=True,
unit=t_unit[t_type],
unit=t_unit,
)
return self
21 changes: 16 additions & 5 deletions src/xarray_grass/to_grass.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,11 +263,7 @@ def _datarray_to_grass(
data: xr.DataArray,
dims: Mapping[str, str],
) -> None:
"""Convert an xarray DataArray to GRASS maps.

Uses standardized (x, y) dimension naming internally. For datasets with
latitude/longitude dimensions, provide explicit mapping via dims parameter.
"""
"""Convert an xarray DataArray to GRASS maps."""
if len(data.dims) > 4 or len(data.dims) < 2:
raise ValueError(
f"Only DataArray with 2 to 4 dimensions are supported. "
Expand Down Expand Up @@ -338,11 +334,25 @@ def _write_stds(self, data: xr.DataArray, dims: Mapping):
# 1. Determine the temporal coordinate and type
time_coord = data[dims["start_time"]]
time_dtype = time_coord.dtype
time_unit = None # Initialize for absolute time case
if isinstance(time_dtype, np.dtypes.DateTime64DType):
temporal_type = "absolute"
elif np.issubdtype(time_dtype, np.integer):
temporal_type = "relative"
time_unit = time_coord.attrs.get("units", None)
if not time_unit:
raise ValueError(
f"Relative time coordinate '{dims['start_time']}' in DataArray '{data.name}' "
"requires a 'units' attribute. "
"Accepted values: 'days', 'hours', 'minutes', 'seconds'."
)
# Validate that the unit is supported by both pandas and GRASS
supported_units = ["days", "hours", "minutes", "seconds"]
if time_unit not in supported_units:
raise ValueError(
f"Unsupported time unit '{time_unit}' for relative time in DataArray '{data.name}'. "
f"Supported units are: {', '.join(supported_units)}. "
)
else:
raise ValueError(f"Temporal type not supported: {time_dtype}")
# 2. Determine the semantic type
Expand Down Expand Up @@ -390,4 +400,5 @@ def _write_stds(self, data: xr.DataArray, dims: Mapping):
semantic=semantic_type,
t_type=temporal_type,
stds_type=stds_type,
time_unit=time_unit,
)
12 changes: 8 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def gen_str3ds(
"""Generate an synthetic str3ds."""
grass_i = GrassInterface()
if temporal_type == "relative":
time_unit = "months"
time_unit = "days"
str3ds_times = [i + 1 for i in range(str3ds_length)]
elif temporal_type == "absolute":
time_unit = ""
Expand Down Expand Up @@ -325,7 +325,7 @@ def create_sample_dataarray(
shape: tuple,
crs_wkt: str,
name: str = "test_data",
time_dim_type: str = "absolute", # "absolute", "relative", or "none"
time_dim_type: str = "none", # "absolute", "relative", or "none"
fill_value_generator=None,
) -> xr.DataArray:
"""
Expand Down Expand Up @@ -359,10 +359,12 @@ def create_sample_dataarray(
# Determine context for spatial dimension naming (2D or 3D)
is_3d_spatial_context = "z" in dims_spec

for dim_key in dims_spec.keys(): # Iterate in the order provided by dims_spec
coord_values = dims_spec[dim_key]
for dim_key, coord_values in dims_spec.items():
actual_dim_name = dim_key # Default to key

if dim_key not in ["time", "x", "y", "z"]:
raise ValueError(f"Unknown dim_key: {dim_key}")

if dim_key == "time":
if time_dim_type == "absolute":
coords[actual_dim_name] = pd.to_datetime(coord_values)
Expand Down Expand Up @@ -402,6 +404,8 @@ def create_sample_dataarray(
name=name,
)
da.attrs["crs_wkt"] = crs_wkt
if time_dim_type == "relative" and "time" in dims_spec:
da["time"].attrs["units"] = "minutes"
return da


Expand Down
89 changes: 79 additions & 10 deletions tests/test_tograss.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,11 +253,11 @@ def test_dataarray_to_strds_conversion(
pytest.fail(f"Unsupported time_dim_type: {time_dim_type}")

dims_spec_for_helper = {
"start_time": time_coords,
"time": time_coords,
"y": np.arange(img_height, dtype=float),
"x": np.arange(img_width, dtype=float),
}
expected_dims_order_in_da = ("start_time", "y", "x")
expected_dims_order_in_da = ("time", "y", "x")

shape = (num_times, img_height, img_width)
session_crs_wkt = grass_i.get_crs_wkt_str()
Expand All @@ -276,6 +276,10 @@ def test_dataarray_to_strds_conversion(
f"DataArray dims {sample_da.dims} do not match expected {expected_dims_order_in_da}"
)

# Set time unit attribute for relative time
if time_dim_type == "relative":
sample_da["time"].attrs["units"] = "days"

target_mapset_name = temp_gisdb.mapset
mapset_path_obj = (
Path(temp_gisdb.gisdb) / temp_gisdb.project / target_mapset_name
Expand All @@ -286,6 +290,7 @@ def test_dataarray_to_strds_conversion(
dataset=sample_da,
mapset=mapset_arg,
create=False,
dims={"test_strds": {"start_time": "time"}},
)
strds_id = f"{sample_da.name}@{target_mapset_name}"

Expand All @@ -303,11 +308,38 @@ def test_dataarray_to_strds_conversion(
f"Expected {num_times} maps in STRDS '{strds_id}', found {len(strds_map_names_in_grass)}."
)

# Check temporal metadata including time units for relative time
t_info = gs.parse_command("t.info", input=strds_id, flags="g", quiet=True)

if time_dim_type == "relative":
# Verify temporal type is relative
assert t_info.get("temporal_type") == "relative", (
f"Expected temporal_type='relative' for STRDS '{strds_id}', "
f"got '{t_info.get('temporal_type')}'"
)
# Verify relative time unit is properly written to the GRASS database
expected_unit = sample_da["time"].attrs.get("units")
assert "unit" in t_info, (
f"Expected 'unit' key in STRDS metadata for '{strds_id}', "
f"but found keys: {list(t_info.keys())}"
)
time_unit = t_info.get("unit")
assert time_unit == expected_unit, (
f"Expected time unit '{expected_unit}' for relative STRDS '{strds_id}', "
f"got '{time_unit}'"
)
elif time_dim_type == "absolute":
# Verify temporal type is absolute
assert t_info.get("temporal_type") == "absolute", (
f"Expected temporal_type='absolute' for STRDS '{strds_id}', "
f"got '{t_info.get('temporal_type')}'"
)

# Check statistics for the first and last time slices
indices_to_check = [0, num_times - 1] if num_times > 0 else []
for idx_in_da_time in indices_to_check:
time_val = sample_da.start_time.values[idx_in_da_time]
da_slice = sample_da.sel(start_time=time_val).astype(
time_val = sample_da.time.values[idx_in_da_time]
da_slice = sample_da.sel(time=time_val).astype(
float
) # Ensure float for comparison

Expand Down Expand Up @@ -390,6 +422,10 @@ def test_dataarray_to_str3ds_conversion(
f"DataArray dims {sample_da.dims} do not match expected {expected_dims_order_in_da}"
)

# Set time unit attribute for relative time
if time_dim_type == "relative":
sample_da["time"].attrs["units"] = "days"

target_mapset_name = temp_gisdb.mapset
mapset_path_obj = (
Path(temp_gisdb.gisdb) / temp_gisdb.project / target_mapset_name
Expand All @@ -415,6 +451,35 @@ def test_dataarray_to_str3ds_conversion(
f"Expected {num_times} maps in STR3DS '{str3ds_id}', found {len(str3ds_maps_in_grass)}."
)

# Check temporal metadata including time units for relative time
t_info = gs.parse_command(
"t.info", type="str3ds", input=str3ds_id, flags="g", quiet=True
)

if time_dim_type == "relative":
# Verify temporal type is relative
assert t_info.get("temporal_type") == "relative", (
f"Expected temporal_type='relative' for STR3DS '{str3ds_id}', "
f"got '{t_info.get('temporal_type')}'"
)
# Verify relative time unit is properly written to the GRASS database
expected_unit = sample_da["time"].attrs.get("units")
assert "unit" in t_info, (
f"Expected 'unit' key in STR3DS metadata for '{str3ds_id}', "
f"but found keys: {list(t_info.keys())}"
)
time_unit = t_info.get("unit")
assert time_unit == expected_unit, (
f"Expected time unit '{expected_unit}' for relative STR3DS '{str3ds_id}', "
f"got '{time_unit}'"
)
elif time_dim_type == "absolute":
# Verify temporal type is absolute
assert t_info.get("temporal_type") == "absolute", (
f"Expected temporal_type='absolute' for STR3DS '{str3ds_id}', "
f"got '{t_info.get('temporal_type')}'"
)

# Check statistics for the first and last time slices
indices_to_check = [0, num_times - 1] if num_times > 0 else []
for idx_in_da_time in indices_to_check:
Expand Down Expand Up @@ -845,11 +910,11 @@ def test_dimension_transposition(
assert da_3d.dims == ("x_3d", "z", "y_3d")
assert da_3d.shape == (width_3d, depth_3d, height_3d)

# 3. Test STRDS: Create with start_time,y,x then transpose to x,y,start_time
# 3. Test STRDS: Create with start_time,y,x then transpose to x,y,time
strds_name = "test_transpose_strds"
da_strds = create_sample_dataarray(
dims_spec={
"start_time": np.arange(1, num_times_strds + 1),
"time": np.arange(1, num_times_strds + 1),
"y": np.arange(height_strds, dtype=float),
"x": np.arange(width_strds, dtype=float),
},
Expand All @@ -861,9 +926,9 @@ def test_dimension_transposition(
.reshape(s)
.astype(float),
)
# Transpose to non-standard order (x, y, start_time)
da_strds = da_strds.transpose("x", "y", "start_time")
assert da_strds.dims == ("x", "y", "start_time")
# Transpose to non-standard order (x, y, time)
da_strds = da_strds.transpose("x", "y", "time")
assert da_strds.dims == ("x", "y", "time")
assert da_strds.shape == (width_strds, height_strds, num_times_strds)

# 4. Test STR3DS: Create with time,z,y,x then transpose to y_3d,time,x_3d,z
Expand Down Expand Up @@ -912,7 +977,11 @@ def test_dimension_transposition(
to_grass(dataset=da_3d, mapset=mapset_arg)

# Write STRDS
to_grass(dataset=da_strds, mapset=mapset_arg)
to_grass(
dataset=da_strds,
mapset=mapset_arg,
dims={strds_name: {"start_time": "time"}},
)

# Write STR3DS
to_grass(
Expand Down