diff --git a/python/omfiles/xarray.py b/python/omfiles/xarray.py index 4bcfb43..3ac2061 100644 --- a/python/omfiles/xarray.py +++ b/python/omfiles/xarray.py @@ -1,10 +1,15 @@ """OmFileReader backend for Xarray.""" # ruff: noqa: D101, D102, D105, D107 -from __future__ import annotations +import itertools +import os +import warnings +from typing import Any, Generator import numpy as np +from omfiles.dask import _validate_chunk_alignment + try: from xarray.core import indexing except ImportError: @@ -21,7 +26,7 @@ from xarray.core.utils import FrozenDict from xarray.core.variable import Variable -from ._rust import OmFileReader, OmVariable +from ._rust import OmFileReader, OmFileWriter, OmVariable # need some special secret attributes to tell us the dimensions DIMENSION_KEY = "_ARRAY_DIMENSIONS" @@ -41,10 +46,16 @@ def open_dataset( with OmFileReader(filename_or_obj) as root_variable: store = OmDataStore(root_variable) store_entrypoint = StoreBackendEntrypoint() - return store_entrypoint.open_dataset( + ds = store_entrypoint.open_dataset( store, drop_variables=drop_variables, ) + coord_attr = "_COORDINATE_VARIABLES" + if coord_attr in ds.attrs: + coord_names = [c for c in ds.attrs[coord_attr].split(",") if c in ds] + ds = ds.set_coords(coord_names) + ds.attrs = {k: v for k, v in ds.attrs.items() if k != coord_attr} + return ds raise ValueError("Failed to open dataset") description = "Use .om files in Xarray" @@ -76,6 +87,11 @@ def _get_attributes_for_variable(self, reader: OmFileReader, path: str): for k, variable in direct_children.items(): child_reader = reader._init_from_variable(variable) if child_reader.is_scalar: + # Skip scalars that have _ARRAY_DIMENSIONS — they are 0-d + # coordinate variables, not plain attributes. 
+            dim_key = path + "/" + k + "/" + DIMENSION_KEY
+            if dim_key in self.variables_store:
+                continue
             attrs[k] = child_reader.read_scalar()
         return attrs
@@ -153,6 +169,31 @@ def _get_datasets(self, reader: OmFileReader):
             data = indexing.LazilyIndexedArray(backend_array)
             datasets[var_key] = Variable(dims=dim_names, data=data, attrs=attrs_for_var, encoding=None, fastpath=True)
+
+        # Handle 0-d (scalar) variables that have _ARRAY_DIMENSIONS metadata.
+        # These are scalar coordinates written by write_dataset.
+        for var_key, var in self.variables_store.items():
+            if var_key in datasets:
+                continue
+            child_reader = reader._init_from_variable(var)
+            if not child_reader.is_scalar:
+                continue
+            dim_path = var_key + "/" + DIMENSION_KEY
+            if dim_path not in self.variables_store:
+                continue
+            dim_reader = reader._init_from_variable(self.variables_store[dim_path])
+            dim_names_str = dim_reader.read_scalar()
+            if isinstance(dim_names_str, str) and dim_names_str == "":
+                dim_names = ()
+            elif isinstance(dim_names_str, str):
+                dim_names = tuple(dim_names_str.split(","))
+            else:
+                dim_names = ()
+            scalar_value = child_reader.read_scalar()
+            attrs = self._get_attributes_for_variable(child_reader, var_key)
+            attrs_for_var = {k: v for k, v in attrs.items() if k != DIMENSION_KEY}
+            datasets[var_key] = Variable(dims=dim_names, data=np.array(scalar_value), attrs=attrs_for_var)
+
         return datasets
 
     def close(self):
@@ -181,3 +222,201 @@ def __getitem__(self, key: indexing.ExplicitIndexer) -> np.typing.ArrayLike:
             indexing.IndexingSupport.BASIC,
             self.reader.__getitem__,
         )
+
+
+def _write_scalar_safe(writer: OmFileWriter, value: Any, name: str) -> OmVariable | None:
+    """Write a scalar, returning None and warning if the type is unsupported."""
+    try:
+        return writer.write_scalar(value, name=name)
+    except (ValueError, TypeError) as e:
+        warnings.warn(
+            f"Skipping attribute '{name}' with value {value!r}: {e}",
+            UserWarning,
+            stacklevel=3,
+        )
+        return None
+
+
+def _chunked_block_iterator(data: Any) -> 
Generator[np.ndarray, None, None]: + """ + Yield numpy arrays from a chunked array in C-order block traversal. + + Works with any array that exposes ``.numblocks``, ``.blocks[idx]``, + and ``.compute()`` (e.g. dask arrays). No dask import required. + """ + block_index_ranges = [range(n) for n in data.numblocks] + for block_indices in itertools.product(*block_index_ranges): + block = data.blocks[block_indices] + if hasattr(block, "compute"): + yield block.compute() + else: + yield np.asarray(block) + + +def _resolve_chunks_for_variable( + var_name: str, + var: Variable, + encoding: dict[str, dict[str, Any]] | None, + global_chunks: dict[str, int] | None, + data_chunks: tuple | None = None, +) -> list[int]: + """Resolve chunk sizes for a variable using the priority chain.""" + if encoding and var_name in encoding and "chunks" in encoding[var_name]: + return list(encoding[var_name]["chunks"]) + + if global_chunks is not None: + return [global_chunks.get(dim, min(size, 512)) for dim, size in zip(var.dims, var.shape)] + + if data_chunks is not None: + return [int(c[0]) for c in data_chunks] + + return [min(size, 512) for size in var.shape] + + +def _resolve_encoding_for_variable( + var_name: str, + encoding: dict[str, dict[str, Any]] | None, + global_scale_factor: float, + global_add_offset: float, + global_compression: str, +) -> tuple[float, float, str]: + """Resolve compression parameters for a variable.""" + var_enc = (encoding or {}).get(var_name, {}) + sf = var_enc.get("scale_factor", global_scale_factor) + ao = var_enc.get("add_offset", global_add_offset) + comp = var_enc.get("compression", global_compression) + return sf, ao, comp + + +def write_dataset( + ds: Dataset, + path: str | os.PathLike, + *, + fs: Any | None = None, + encoding: dict[str, dict[str, Any]] | None = None, + chunks: dict[str, int] | None = None, + scale_factor: float = 1.0, + add_offset: float = 0.0, + compression: str = "pfor_delta_2d", +) -> None: + """ + Write an xarray Dataset to an OM 
file. + + The resulting file can be read back with ``xr.open_dataset(path, engine="om")``. + + Args: + ds: The xarray Dataset to write. + path: Output file path (local path or path within the fsspec filesystem). + fs: Optional fsspec filesystem object. When provided, the file is written + via ``OmFileWriter.from_fsspec(fs, path)`` instead of the default + local-file writer. + encoding: Per-variable overrides. Keys per variable: ``"chunks"``, + ``"scale_factor"``, ``"add_offset"``, ``"compression"``. + chunks: Global default chunk sizes as ``{dim_name: chunk_size}``. + scale_factor: Global default scale factor for float compression. + add_offset: Global default offset for float compression. + compression: Global default compression algorithm. + """ + path = str(path) + if fs is not None: + writer = OmFileWriter.from_fsspec(fs, path) + else: + writer = OmFileWriter(path) + all_children: list[OmVariable] = [] + + def _write_variable(name: str, var: Variable, is_dim_coord: bool) -> None: + """Write a single variable (data var or non-dimension coordinate).""" + if np.issubdtype(var.dtype, np.datetime64) or np.issubdtype(var.dtype, np.timedelta64): + raise TypeError( + f"Variable '{name}' has dtype {var.dtype}. " + "OM files do not support datetime64/timedelta64 natively. " + "Convert to a numeric type before writing." 
+ ) + + var_children: list[OmVariable] = [] + + if not is_dim_coord: + dim_str = ",".join(var.dims) + dim_var = writer.write_scalar(dim_str, name=DIMENSION_KEY) + var_children.append(dim_var) + + for attr_name, attr_value in var.attrs.items(): + scalar = _write_scalar_safe(writer, attr_value, attr_name) + if scalar is not None: + var_children.append(scalar) + + if var.ndim == 0: + om_var = writer.write_scalar( + var.values[()], + name=name, + children=var_children if var_children else None, + ) + all_children.append(om_var) + return + + data = var.data + is_chunked = not is_dim_coord and hasattr(data, "chunks") and data.chunks is not None + + if is_dim_coord: + resolved_chunks = [var.shape[0]] + else: + resolved_chunks = _resolve_chunks_for_variable( + name, + var, + encoding, + chunks, + data_chunks=data.chunks if is_chunked else None, + ) + + sf, ao, comp = _resolve_encoding_for_variable(name, encoding, scale_factor, add_offset, compression) + + if is_chunked: + _validate_chunk_alignment(data.chunks, resolved_chunks, var.shape) + om_var = writer.write_array_streaming( + dimensions=[int(d) for d in var.shape], + chunks=[int(c) for c in resolved_chunks], + chunk_iterator=_chunked_block_iterator(data), + dtype=var.dtype, + scale_factor=sf, + add_offset=ao, + compression=comp, + name=name, + children=var_children if var_children else None, + ) + else: + om_var = writer.write_array( + var.values, + chunks=resolved_chunks, + scale_factor=sf, + add_offset=ao, + compression=comp, + name=name, + children=var_children if var_children else None, + ) + all_children.append(om_var) + + for var_name in ds.data_vars: + _write_variable(var_name, ds[var_name].variable, is_dim_coord=False) + + non_dim_coords: list[str] = [] + for coord_name in ds.coords: + if coord_name in ds.data_vars: + continue + coord = ds.coords[coord_name] + is_dim_coord = coord.ndim == 1 and coord.dims[0] == coord_name + if not is_dim_coord: + non_dim_coords.append(coord_name) + _write_variable(coord_name, 
coord.variable, is_dim_coord=is_dim_coord) + + # Write list of non-dimension coordinates so the reader can restore them + if non_dim_coords: + coord_list_var = writer.write_scalar(",".join(non_dim_coords), name="_COORDINATE_VARIABLES") + all_children.append(coord_list_var) + + for attr_name, attr_value in ds.attrs.items(): + scalar = _write_scalar_safe(writer, attr_value, attr_name) + if scalar is not None: + all_children.append(scalar) + + root_var = writer.write_group(name="", children=all_children) + writer.close(root_var) diff --git a/tests/test_fsspec.py b/tests/test_fsspec.py index 31edebf..00ff5fc 100644 --- a/tests/test_fsspec.py +++ b/tests/test_fsspec.py @@ -11,6 +11,7 @@ import xarray as xr from fsspec.implementations.local import LocalFileSystem from fsspec.implementations.memory import MemoryFileSystem +from omfiles.xarray import write_dataset from s3fs import S3FileSystem from .test_utils import filter_numpy_size_warning, find_chunk_for_timestamp @@ -277,3 +278,117 @@ def read_slice(idx, start): assert len(results) == num_threads for i, arr in enumerate(results): np.testing.assert_array_equal(arr, data[i * slice_size : (i + 1) * slice_size, :]) + + +# --- write_dataset fsspec tests --- + + +@filter_numpy_size_warning +def test_write_dataset_memory_fsspec(memory_fs): + """write_dataset with fs= writes to a memory filesystem and reads back.""" + ds = xr.Dataset( + {"temperature": (["lat", "lon"], np.random.rand(5, 5).astype(np.float32))}, + coords={ + "lat": np.arange(5, dtype=np.float32), + "lon": np.arange(5, dtype=np.float32), + }, + attrs={"description": "Test dataset"}, + ) + write_dataset(ds, "dataset_test.om", fs=memory_fs, scale_factor=100000.0) + assert_file_exists(memory_fs, "dataset_test.om") + + reader = omfiles.OmFileReader.from_fsspec(memory_fs, "dataset_test.om") + assert reader.num_children > 0 + reader.close() + + +@filter_numpy_size_warning +def test_write_dataset_memory_fsspec_roundtrip(memory_fs): + """Full roundtrip: write_dataset 
via memory fs, read back with xarray.""" + temperature_data = np.random.rand(5, 5).astype(np.float32) + ds = xr.Dataset( + {"temperature": (["lat", "lon"], temperature_data)}, + coords={ + "lat": np.arange(5, dtype=np.float32), + "lon": np.arange(5, dtype=np.float32), + }, + attrs={"description": "fsspec roundtrip test"}, + ) + path = "roundtrip_dataset.om" + write_dataset(ds, path, fs=memory_fs, scale_factor=100000.0) + + # Dump from memory fs to a temp file so xarray can read it back + with tempfile.NamedTemporaryFile(suffix=".om", delete=False) as tmp: + tmp.write(memory_fs.cat(path)) + tmp_path = tmp.name + try: + ds2 = xr.open_dataset(tmp_path, engine="om") + np.testing.assert_array_almost_equal(ds2["temperature"].values, temperature_data, decimal=4) + np.testing.assert_array_equal(ds2.coords["lat"].values, ds.coords["lat"].values) + np.testing.assert_array_equal(ds2.coords["lon"].values, ds.coords["lon"].values) + assert ds2.attrs["description"] == "fsspec roundtrip test" + finally: + os.unlink(tmp_path) + + +@filter_numpy_size_warning +def test_write_dataset_local_fsspec(local_fs): + """write_dataset with a local fsspec filesystem produces a valid file.""" + ds = xr.Dataset( + {"temperature": (["lat", "lon"], np.random.rand(8, 8).astype(np.float32))}, + coords={ + "lat": np.arange(8, dtype=np.float32), + "lon": np.arange(8, dtype=np.float32), + }, + ) + with tempfile.NamedTemporaryFile(suffix=".om", delete=False) as tmp: + tmp_path = tmp.name + try: + write_dataset(ds, tmp_path, fs=local_fs, scale_factor=100000.0) + assert os.path.exists(tmp_path) + assert os.path.getsize(tmp_path) > 0 + + ds2 = xr.open_dataset(tmp_path, engine="om") + np.testing.assert_array_almost_equal(ds2["temperature"].values, ds["temperature"].values, decimal=4) + finally: + os.unlink(tmp_path) + + +@filter_numpy_size_warning +def test_write_dataset_fs_none_backward_compatible(): + """Passing fs=None behaves identically to the default (local path).""" + ds = xr.Dataset( + {"data": 
(["x"], np.arange(5, dtype=np.float32))}, + ) + with tempfile.NamedTemporaryFile(suffix=".om", delete=False) as tmp: + tmp_path = tmp.name + try: + write_dataset(ds, tmp_path, fs=None) + ds2 = xr.open_dataset(tmp_path, engine="om") + np.testing.assert_array_equal(ds2["data"].values, ds["data"].values) + finally: + os.unlink(tmp_path) + + +@filter_numpy_size_warning +def test_write_and_read_dataset_fsspec_roundtrip(memory_fs): + """Full fsspec roundtrip: write_dataset via fs, read back via fsspec file object.""" + temperature_data = np.random.rand(5, 5).astype(np.float32) + ds = xr.Dataset( + {"temperature": (["lat", "lon"], temperature_data)}, + coords={ + "lat": np.arange(5, dtype=np.float32), + "lon": np.arange(5, dtype=np.float32), + }, + attrs={"description": "full fsspec roundtrip"}, + ) + path = "fsspec_full_roundtrip.om" + write_dataset(ds, path, fs=memory_fs, scale_factor=100000.0) + + # Read back via fsspec.core.OpenFile, which xr.open_dataset supports + backend = fsspec.core.OpenFile(memory_fs, path, mode="rb") + ds2 = xr.open_dataset(backend, engine="om") + np.testing.assert_array_almost_equal(ds2["temperature"].values, temperature_data, decimal=4) + np.testing.assert_array_equal(ds2.coords["lat"].values, ds.coords["lat"].values) + np.testing.assert_array_equal(ds2.coords["lon"].values, ds.coords["lon"].values) + assert ds2.attrs["description"] == "full fsspec roundtrip" diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 85f66fd..273e11c 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -3,6 +3,7 @@ import pytest import xarray as xr from omfiles import OmFileReader, OmFileWriter +from omfiles.xarray import write_dataset from xarray.core import indexing from .test_utils import create_test_om_file, filter_numpy_size_warning @@ -150,3 +151,311 @@ def test_xarray_hierarchical_file(empty_temp_om_file): mean_temp = ds["temperature"].mean(dim="TIME") assert mean_temp.shape == (5, 5, 5) assert mean_temp.dims == ("LATITUDE", "LONGITUDE", 
"ALTITUDE") + + +@filter_numpy_size_warning +def test_write_dataset_basic_roundtrip(empty_temp_om_file): + ds = xr.Dataset( + {"temperature": (["lat", "lon"], np.random.rand(5, 5).astype(np.float32))}, + coords={ + "lat": np.arange(5, dtype=np.float32), + "lon": np.arange(5, dtype=np.float32), + }, + attrs={"description": "Test dataset"}, + ) + write_dataset(ds, empty_temp_om_file, scale_factor=100000.0) + ds2 = xr.open_dataset(empty_temp_om_file, engine="om") + + np.testing.assert_array_almost_equal(ds2["temperature"].values, ds["temperature"].values, decimal=4) + np.testing.assert_array_equal(ds2.coords["lat"].values, ds.coords["lat"].values) + np.testing.assert_array_equal(ds2.coords["lon"].values, ds.coords["lon"].values) + assert ds2.attrs["description"] == "Test dataset" + + +@filter_numpy_size_warning +def test_write_dataset_hierarchical_roundtrip(empty_temp_om_file): + """Mirrors test_xarray_hierarchical_file but uses write_dataset.""" + temperature_data = np.random.rand(5, 5, 5, 10).astype(np.float32) + precipitation_data = np.random.rand(5, 5, 10).astype(np.float32) + + ds = xr.Dataset( + { + "temperature": ( + ["LATITUDE", "LONGITUDE", "ALTITUDE", "TIME"], + temperature_data, + {"units": "celsius", "description": "Surface temperature"}, + ), + "precipitation": ( + ["LATITUDE", "LONGITUDE", "TIME"], + precipitation_data, + {"units": "mm", "description": "Precipitation"}, + ), + }, + coords={ + "LATITUDE": np.arange(5, dtype=np.float32), + "LONGITUDE": np.arange(5, dtype=np.float32), + "ALTITUDE": np.arange(5, dtype=np.float32), + "TIME": np.arange(10, dtype=np.float32), + }, + attrs={"description": "This is a hierarchical OM File"}, + ) + + write_dataset(ds, empty_temp_om_file, scale_factor=100000.0) + ds2 = xr.open_dataset(empty_temp_om_file, engine="om") + + assert ds2.attrs["description"] == "This is a hierarchical OM File" + assert set(ds2.data_vars) == {"temperature", "precipitation"} + + 
np.testing.assert_array_almost_equal(ds2["temperature"].values, temperature_data, decimal=4) + assert ds2["temperature"].dims == ("LATITUDE", "LONGITUDE", "ALTITUDE", "TIME") + assert ds2["temperature"].attrs["units"] == "celsius" + assert ds2["temperature"].attrs["description"] == "Surface temperature" + + np.testing.assert_array_almost_equal(ds2["precipitation"].values, precipitation_data, decimal=4) + assert ds2["precipitation"].dims == ("LATITUDE", "LONGITUDE", "TIME") + assert ds2["precipitation"].attrs["units"] == "mm" + + assert ds2["LATITUDE"].dims == ("LATITUDE",) + assert ds2["LONGITUDE"].dims == ("LONGITUDE",) + assert ds2["ALTITUDE"].dims == ("ALTITUDE",) + assert ds2["TIME"].dims == ("TIME",) + + +@filter_numpy_size_warning +def test_write_dataset_per_variable_encoding(empty_temp_om_file): + ds = xr.Dataset( + { + "high_res": (["x", "y"], np.random.rand(10, 10).astype(np.float32)), + "low_res": (["x", "y"], np.random.rand(10, 10).astype(np.float32)), + }, + coords={ + "x": np.arange(10, dtype=np.float32), + "y": np.arange(10, dtype=np.float32), + }, + ) + + write_dataset( + ds, + empty_temp_om_file, + scale_factor=1000.0, + encoding={ + "high_res": {"scale_factor": 100000.0, "chunks": [5, 5]}, + "low_res": {"chunks": [10, 10]}, + }, + ) + ds2 = xr.open_dataset(empty_temp_om_file, engine="om") + + np.testing.assert_array_almost_equal(ds2["high_res"].values, ds["high_res"].values, decimal=4) + np.testing.assert_array_almost_equal(ds2["low_res"].values, ds["low_res"].values, decimal=2) + + +@filter_numpy_size_warning +@pytest.mark.parametrize("dtype", [np.int32, np.int64, np.uint32, np.uint64]) +def test_write_dataset_integer_dtypes(dtype, empty_temp_om_file): + data = np.arange(25, dtype=dtype).reshape(5, 5) + ds = xr.Dataset({"values": (["x", "y"], data)}) + + write_dataset(ds, empty_temp_om_file) + ds2 = xr.open_dataset(empty_temp_om_file, engine="om") + + np.testing.assert_array_equal(ds2["values"].values, data) + assert ds2["values"].dtype == dtype + 
+ +@filter_numpy_size_warning +def test_write_dataset_unsupported_attrs_warning(empty_temp_om_file): + ds = xr.Dataset( + {"data": (["x"], np.arange(5, dtype=np.float32))}, + attrs={"valid": "hello", "invalid": [1, 2, 3]}, + ) + + with pytest.warns(UserWarning, match="Skipping attribute"): + write_dataset(ds, empty_temp_om_file, scale_factor=100000.0) + + ds2 = xr.open_dataset(empty_temp_om_file, engine="om") + assert ds2.attrs["valid"] == "hello" + assert "invalid" not in ds2.attrs + + +def test_write_dataset_datetime_raises(empty_temp_om_file): + time_values = np.array( + ["2020-01-01", "2020-01-02", "2020-01-03", "2020-01-04", "2020-01-05"], dtype="datetime64[ns]" + ) + ds = xr.Dataset( + {"data": (["time"], np.arange(5, dtype=np.float32))}, + coords={"time": time_values}, + ) + + with pytest.raises(TypeError, match="datetime64"): + write_dataset(ds, empty_temp_om_file) + + +@filter_numpy_size_warning +def test_write_dataset_scalar_coordinate(empty_temp_om_file): + """Writing a dataset with a scalar (0-d) coordinate should not segfault.""" + temperature_data = np.random.rand(5, 5).astype(np.float32) + ds = xr.Dataset( + {"temperature": (["lat", "lon"], temperature_data)}, + coords={ + "lat": np.arange(5, dtype=np.float32), + "lon": np.arange(5, dtype=np.float32), + "time": np.float32(42.0), + }, + ) + write_dataset(ds, empty_temp_om_file, scale_factor=100000.0) + loaded = xr.open_dataset(empty_temp_om_file, engine="om") + + assert "time" in loaded.coords + assert "time" not in loaded.data_vars + assert loaded.coords["time"].ndim == 0 + np.testing.assert_almost_equal(float(loaded.coords["time"]), 42.0) + + np.testing.assert_array_almost_equal(loaded["temperature"].values, temperature_data, decimal=4) + np.testing.assert_array_equal(loaded.coords["lat"].values, ds.coords["lat"].values) + np.testing.assert_array_equal(loaded.coords["lon"].values, ds.coords["lon"].values) + + +@filter_numpy_size_warning +def 
test_write_dataset_non_dimension_coordinate(empty_temp_om_file): + """Non-dimension coordinates should preserve their dimensions and coordinate status.""" + valid_time_data = np.arange(6, dtype=np.float32) + ds = xr.Dataset( + {"t2m": (("step", "lat"), np.zeros((6, 10), dtype=np.float32))}, + coords={"valid_time": ("step", valid_time_data)}, + ) + write_dataset(ds, empty_temp_om_file, scale_factor=100000.0) + loaded = xr.open_dataset(empty_temp_om_file, engine="om") + + assert loaded["valid_time"].dims == ("step",) + assert "valid_time" in loaded.coords + assert "valid_time" not in loaded.data_vars + np.testing.assert_array_equal(loaded["valid_time"].values, valid_time_data) + + +@filter_numpy_size_warning +def test_write_dataset_dask_roundtrip(empty_temp_om_file): + da = pytest.importorskip("dask.array") + + np_data = np.random.rand(10, 20).astype(np.float32) + dask_data = da.from_array(np_data, chunks=(5, 10)) + + ds = xr.Dataset( + {"temperature": (["lat", "lon"], dask_data)}, + coords={ + "lat": np.arange(10, dtype=np.float32), + "lon": np.arange(20, dtype=np.float32), + }, + ) + + write_dataset(ds, empty_temp_om_file, scale_factor=100000.0) + ds2 = xr.open_dataset(empty_temp_om_file, engine="om") + + np.testing.assert_array_almost_equal(ds2["temperature"].values, np_data, decimal=4) + np.testing.assert_array_equal(ds2.coords["lat"].values, ds.coords["lat"].values) + np.testing.assert_array_equal(ds2.coords["lon"].values, ds.coords["lon"].values) + + +@filter_numpy_size_warning +def test_write_dataset_dask_mixed_variables(empty_temp_om_file): + da = pytest.importorskip("dask.array") + + np_temp = np.random.rand(10, 20).astype(np.float32) + dask_temp = da.from_array(np_temp, chunks=(5, 10)) + np_precip = np.random.rand(10, 20).astype(np.float32) + + ds = xr.Dataset( + { + "temperature": (["lat", "lon"], dask_temp), + "precipitation": (["lat", "lon"], np_precip), + }, + coords={ + "lat": np.arange(10, dtype=np.float32), + "lon": np.arange(20, dtype=np.float32), + 
}, + ) + + write_dataset(ds, empty_temp_om_file, scale_factor=100000.0) + ds2 = xr.open_dataset(empty_temp_om_file, engine="om") + + np.testing.assert_array_almost_equal(ds2["temperature"].values, np_temp, decimal=4) + np.testing.assert_array_almost_equal(ds2["precipitation"].values, np_precip, decimal=4) + + +@filter_numpy_size_warning +def test_write_dataset_dask_boundary_chunks(empty_temp_om_file): + da = pytest.importorskip("dask.array") + + np_data = np.arange(91, dtype=np.float32).reshape(7, 13) + dask_data = da.from_array(np_data, chunks=(4, 5)) + + ds = xr.Dataset({"data": (["x", "y"], dask_data)}) + + write_dataset(ds, empty_temp_om_file, scale_factor=100000.0) + ds2 = xr.open_dataset(empty_temp_om_file, engine="om") + + np.testing.assert_array_almost_equal(ds2["data"].values, np_data, decimal=4) + + +@filter_numpy_size_warning +def test_write_dataset_dask_with_attributes(empty_temp_om_file): + da = pytest.importorskip("dask.array") + + np_data = np.random.rand(5, 5).astype(np.float32) + dask_data = da.from_array(np_data, chunks=(5, 5)) + + ds = xr.Dataset( + {"temp": (["x", "y"], dask_data, {"units": "K", "long_name": "temperature"})}, + attrs={"source": "test"}, + ) + + write_dataset(ds, empty_temp_om_file, scale_factor=100000.0) + ds2 = xr.open_dataset(empty_temp_om_file, engine="om") + + np.testing.assert_array_almost_equal(ds2["temp"].values, np_data, decimal=4) + assert ds2["temp"].attrs["units"] == "K" + assert ds2["temp"].attrs["long_name"] == "temperature" + assert ds2.attrs["source"] == "test" + + +@filter_numpy_size_warning +@pytest.mark.parametrize("dtype", [np.int32, np.int64, np.uint32]) +def test_write_dataset_dask_integer_dtypes(dtype, empty_temp_om_file): + da = pytest.importorskip("dask.array") + + np_data = np.arange(25, dtype=dtype).reshape(5, 5) + dask_data = da.from_array(np_data, chunks=(5, 5)) + + ds = xr.Dataset({"values": (["x", "y"], dask_data)}) + + write_dataset(ds, empty_temp_om_file) + ds2 = 
xr.open_dataset(empty_temp_om_file, engine="om") + + np.testing.assert_array_equal(ds2["values"].values, np_data) + assert ds2["values"].dtype == dtype + + +@filter_numpy_size_warning +def test_write_dataset_dask_larger_chunks_than_om(empty_temp_om_file): + """Dask blocks larger than OM chunks with explicit smaller OM chunk sizes.""" + da = pytest.importorskip("dask.array") + + np_data = np.random.rand(10, 20).astype(np.float32) + dask_data = da.from_array(np_data, chunks=(10, 20)) + + ds = xr.Dataset( + {"temperature": (["lat", "lon"], dask_data)}, + coords={ + "lat": np.arange(10, dtype=np.float32), + "lon": np.arange(20, dtype=np.float32), + }, + ) + + write_dataset( + ds, + empty_temp_om_file, + chunks={"lat": 5, "lon": 10}, + scale_factor=100000.0, + ) + ds2 = xr.open_dataset(empty_temp_om_file, engine="om") + + np.testing.assert_array_almost_equal(ds2["temperature"].values, np_data, decimal=4)