# Patch summary (reviewer commentary; this text precedes the first
# "diff --git" header and is not part of any hunk).
#
# Adds an optional `sel` argument to the "esgf" xarray engine.
#
#   * xarray_esgf/engine.py: `open_dataset` gains `sel` (default None) and
#     forwards it as `sel=sel` in the call shown in the last hunk.
#   * xarray_esgf/client.py: `open_dataset` gains `sel` (default None) and
#     calls `_open_datasets` with keyword arguments, passing `sel or {}`.
#     `_open_datasets` loses its parameter defaults and gains a required
#     `sel`. Dict values in `sel` are converted with `slice(*v["slice"])`;
#     other values are kept as they are. Each opened dataset is then
#     `.sel`-ed with only the entries whose key is in `ds.dims`.
#   * tests/test_20_open_dataset.py: adds `test_time_selection`, parametrized
#     over three `sel` values with expected "time" sizes 12, 1 and 2.
#
# NOTE(review): every dict value in `sel` is indexed with `v["slice"]`, so a
# dict without a "slice" key raises KeyError.
# NOTE(review): keys of `sel` that are not in a dataset's dims are dropped
# from the selection for that dataset without any error.
# NOTE(review): line breaks were restored from a whitespace-flattened copy
# using the hunk line counts; blank context lines and the indentation of
# context lines are inferred -- verify against the repository before applying.
diff --git a/tests/test_20_open_dataset.py b/tests/test_20_open_dataset.py
index f1f80ce..ee4a669 100644
--- a/tests/test_20_open_dataset.py
+++ b/tests/test_20_open_dataset.py
@@ -1,4 +1,6 @@
+from collections.abc import Hashable
 from pathlib import Path
+from typing import Any
 
 import pytest
 import xarray as xr
@@ -103,3 +105,34 @@ def test_combine_coords(tmp_path: Path, index_node: str) -> None:
     )
     assert set(ds.coords) == {"areacella", "lat", "lon", "experiment_id", "orog"}
     assert not ds.data_vars
+
+
+@pytest.mark.parametrize(
+    "sel,expected_size",
+    [
+        ({}, 12),
+        ({"time": "2019-01"}, 1),
+        ({"time": {"slice": ["2019-01", "2019-02"]}}, 2),
+    ],
+)
+def test_time_selection(
+    tmp_path: Path,
+    index_node: str,
+    sel: dict[Hashable, Any],
+    expected_size: int,
+) -> None:
+    esgpull_path = tmp_path / "esgpull"
+    selection = {
+        "query": [
+            '"tas_Amon_EC-Earth3-CC_ssp245_r1i1p1f1_gr_201901-201912.nc"',
+        ]
+    }
+    ds = xr.open_dataset(
+        selection,  # type: ignore[arg-type]
+        esgpull_path=esgpull_path,
+        engine="esgf",
+        index_node=index_node,
+        chunks={},
+        sel=sel,
+    )
+    assert ds.sizes["time"] == expected_size
diff --git a/xarray_esgf/client.py b/xarray_esgf/client.py
index 05a13af..4970e37 100644
--- a/xarray_esgf/client.py
+++ b/xarray_esgf/client.py
@@ -5,7 +5,7 @@
 from collections.abc import Callable, Hashable, Iterable
 from functools import cached_property
 from pathlib import Path
-from typing import Literal, get_args
+from typing import Any, Literal, get_args
 
 import tqdm
 import xarray as xr
@@ -117,10 +117,15 @@ def download(self) -> list[File]:
     def _open_datasets(
         self,
         concat_dims: DATASET_ID_KEYS | Iterable[DATASET_ID_KEYS] | None,
-        drop_variables: str | Iterable[str] | None = None,
-        download: bool = False,
-        show_progress: bool = True,
+        drop_variables: str | Iterable[str] | None,
+        download: bool,
+        show_progress: bool,
+        sel: dict[Hashable, Any],
     ) -> dict[str, Dataset]:
+        sel = {
+            k: slice(*v["slice"]) if isinstance(v, dict) else v for k, v in sel.items()
+        }
+
         if isinstance(concat_dims, str):
             concat_dims = [concat_dims]
         concat_dims = concat_dims or []
@@ -139,6 +144,7 @@ def _open_datasets(
                 drop_variables=drop_variables,
                 storage_options={"ssl": self.verify_ssl},
             )
+            ds = ds.sel({k: v for k, v in sel.items() if k in ds.dims})
             grouped_objects[file.dataset_id].append(ds.drop_encoding())
 
         combined_datasets = {}
@@ -173,9 +179,14 @@ def open_dataset(
         drop_variables: str | Iterable[str] | None = None,
         download: bool = False,
         show_progress: bool = True,
+        sel: dict[Hashable, Any] | None = None,
     ) -> Dataset:
         combined_datasets = self._open_datasets(
-            concat_dims, drop_variables, download, show_progress
+            concat_dims=concat_dims,
+            drop_variables=drop_variables,
+            download=download,
+            show_progress=show_progress,
+            sel=sel or {},
         )
 
         obj = xr.combine_by_coords(
diff --git a/xarray_esgf/engine.py b/xarray_esgf/engine.py
index a00d585..c24d1d9 100644
--- a/xarray_esgf/engine.py
+++ b/xarray_esgf/engine.py
@@ -1,4 +1,4 @@
-from collections.abc import Iterable
+from collections.abc import Hashable, Iterable
 from pathlib import Path
 from typing import Any
 
@@ -22,6 +22,7 @@ def open_dataset(  # type: ignore[override]
         concat_dims: DATASET_ID_KEYS | Iterable[DATASET_ID_KEYS] | None = None,
         download: bool = False,
         show_progress: bool = True,
+        sel: dict[Hashable, Any] | None = None,
     ) -> Dataset:
         client = Client(
             selection=filename_or_obj,
@@ -36,6 +37,7 @@ def open_dataset(  # type: ignore[override]
             drop_variables=drop_variables,
             download=download,
             show_progress=show_progress,
+            sel=sel,
         )
 
     open_dataset_parameters = (