diff --git a/README.md b/README.md index 1114d50..b9c12f4 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,24 @@ async def main(): data_dict = dataset_filtered.as_dict() print(data_dict['data']) + # Or load and select in one call. + western_europe, metadata = await dclimate.select_dataset( + request={ + "dataset": "temperature_2m", + "collection": "era5", + "organization": "ecmwf", + "variant": "finalized", + }, + selection={ + # Bounds are [west, south, east, north]. + "bounds": [-12, 35, 16, 60], + "time_range": { + "start": datetime(2024, 1, 1), + "end": datetime(2024, 1, 7, 23), + }, + }, + ) + # Custom IPFS endpoints (optional) async def main_custom_ipfs(): async with dClimateClient( diff --git a/dclimate_client_py/client.py b/dclimate_client_py/client.py index 7ef61fb..093ea97 100644 --- a/dclimate_client_py/client.py +++ b/dclimate_client_py/client.py @@ -46,6 +46,8 @@ def geo_temporal_query( rectangle_kwargs: dict = None, polygon_kwargs: dict = None, multiple_points_kwargs: dict = None, + bounds=None, + bounds_options: dict = None, spatial_agg_kwargs: dict = None, temporal_agg_kwargs: dict = None, rolling_agg_kwargs: dict = None, @@ -64,9 +66,9 @@ def geo_temporal_query( Return either a numpy array of data values or a NetCDF file. - Only one of point, circle, rectangle, or polygon kwargs may be provided. Only one of - temporal or rolling aggregation kwargs may be provided, although they can be chained - with spatial aggregations if desired. + Only one of point, circle, rectangle, bounds, or polygon kwargs may be provided. Only + one of temporal or rolling aggregation kwargs may be provided, although they can be + chained with spatial aggregations if desired. Args: dataset_name (str): Name used to identify the dataset within the STAC catalog (for IPFS) @@ -82,6 +84,10 @@ def geo_temporal_query( circular query rectangle_kwargs (dict, optional): a dictionary of parameters relevant to a rectangular query + bounds (list | tuple | dict, optional): rectangular bounds in + ``[west, south, east, north]`` order, or a mapping with those keys. + bounds_options (dict, optional): optional coordinate key overrides for bounds + selections, using ``latitude_key`` and ``longitude_key``. polygon_kwargs (dict, optional): a dictionary of parameters relevant to a polygonal query multiple_points_kwargs (dict, optional): Parameters for querying multiple specific points. @@ -118,6 +124,7 @@ def geo_temporal_query( polygon_kwargs, multiple_points_kwargs, point_kwargs, + bounds, ] if kwarg_dict is not None ] @@ -175,6 +182,8 @@ def geo_temporal_query( rectangle_kwargs=rectangle_kwargs, polygon_kwargs=polygon_kwargs, multiple_points_kwargs=multiple_points_kwargs, + bounds=bounds, + bounds_options=bounds_options, spatial_agg_kwargs=spatial_agg_kwargs, temporal_agg_kwargs=temporal_agg_kwargs, rolling_agg_kwargs=rolling_agg_kwargs, diff --git a/dclimate_client_py/dclimate_client.py b/dclimate_client_py/dclimate_client.py index ec3054a..389c1e8 100644 --- a/dclimate_client_py/dclimate_client.py +++ b/dclimate_client_py/dclimate_client.py @@ -6,6 +6,8 @@ """ import typing +from collections.abc import Mapping + import requests import xarray as xr from py_hamt import KuboCAS @@ -326,6 +328,53 @@ async def load_dataset( else: return GeotemporalData(ds, dataset_name=metadata["slug"]), metadata + async def select_dataset( + self, + *, + request: typing.Mapping[str, typing.Any], + selection: typing.Mapping[str, typing.Any], + return_xarray: bool = False, + ) -> typing.Union[ + typing.Tuple[GeotemporalData, DatasetMetadata], + typing.Tuple[xr.Dataset, DatasetMetadata], + ]: + """ + Load a dClimate dataset and apply point, bounds, and/or time selections. + + Parameters + ---------- + request : Mapping[str, Any] + Keyword arguments accepted by :meth:`load_dataset`, such as + ``dataset``, ``collection``, ``variant``, ``organization``, or ``cid``. + selection : Mapping[str, Any] + Selection mapping accepted by :meth:`GeotemporalData.select`. + return_xarray : bool, optional + If True, return the raw xarray dataset without applying selections. + + Returns + ------- + Tuple[Union[GeotemporalData, xr.Dataset], DatasetMetadata] + The selected dataset plus metadata. + """ + if not isinstance(request, Mapping): + raise InvalidSelectionError("request must be a mapping.") + if not isinstance(selection, Mapping): + raise InvalidSelectionError("selection must be a mapping.") + + load_kwargs = dict(request) + load_kwargs.pop("return_xarray", None) + if load_kwargs.get("cid") and "dataset" not in load_kwargs: + load_kwargs["dataset"] = "" + + dataset_obj, metadata = await self.load_dataset( + **load_kwargs, + return_xarray=return_xarray, + ) + if not isinstance(dataset_obj, GeotemporalData): + return dataset_obj, metadata + + return dataset_obj.select(selection), metadata + def list_datasets(self) -> typing.Dict[str, typing.Dict[str, typing.Any]]: """ List all available datasets from the STAC catalog. diff --git a/dclimate_client_py/geotemporal_data.py b/dclimate_client_py/geotemporal_data.py index fc371aa..0de3e30 100644 --- a/dclimate_client_py/geotemporal_data.py +++ b/dclimate_client_py/geotemporal_data.py @@ -1,7 +1,10 @@ import datetime import functools +import math +import numbers import operator import typing +from collections.abc import Mapping import geopandas as gpd import pandas as pd @@ -17,6 +20,8 @@ # Users should not select more than this number of data points and coordinates DEFAULT_POINT_LIMIT = 40 * 40 * 50_000 +BoundsSelection = typing.Union[typing.Sequence[float], Mapping[str, typing.Any]] +GeoSelectionOptions = Mapping[str, typing.Any] class GeotemporalData: @@ -165,7 +170,12 @@ def reindex_forecast(self) -> "GeotemporalData": return self._new(self.data.reindex(time=trange)) def point( - self, latitude: float, longitude: float, snap_to_grid: bool = True + self, + latitude: float, + longitude: float, + snap_to_grid: bool = True, + latitude_key: str = "latitude", + longitude_key: str = "longitude", ) -> "GeotemporalData": """Gets a dataset corresponding to the full time series for a single point @@ -178,24 +188,22 @@ def point( snap_to_grid: bool, optional When ``True``, find nearest point to lat, lon in dataset. When ``False``, error out when exact lat, lon is not on dataset grid. + latitude_key: str, optional + Name of the latitude coordinate in the dataset. + longitude_key: str, optional + Name of the longitude coordinate in the dataset. Returns ------- GeotemporalData New dataset restricted to single point """ + selection = {latitude_key: latitude, longitude_key: longitude} if snap_to_grid: - data = self.data.sel( - latitude=latitude, longitude=longitude, method="nearest" - ) + data = self.data.sel(selection, method="nearest") else: try: - data = self.data.sel( - latitude=latitude, - longitude=longitude, - method="nearest", - tolerance=10e-5, - ) + data = self.data.sel(selection, method="nearest", tolerance=10e-5) except KeyError: raise errors.NoDataFoundError( "User requested not to snap_to_grid, but exact coord not in dataset" @@ -264,6 +272,8 @@ def rectangle( min_lon: float, max_lat: float, max_lon: float, + latitude_key: str = "latitude", + longitude_key: str = "longitude", ) -> "GeotemporalData": """Reduce dataset to points in rectangle @@ -278,17 +288,29 @@ def rectangle( Northern limit of rectangle max_lon: float Eastern limit of rectangle + latitude_key: str, optional + Name of the latitude coordinate in the dataset. + longitude_key: str, optional + Name of the longitude coordinate in the dataset. Returns ------- GeotemporalData New dataset """ + try: + latitudes = self.data[latitude_key] + longitudes = self.data[longitude_key] + except KeyError as exc: + raise errors.InvalidSelectionError( + "Latitude/longitude coordinates were not found in the dataset." + ) from exc + data = self.data.where( - (self.data.latitude >= min_lat) - & (self.data.latitude <= max_lat) - & (self.data.longitude >= min_lon) - & (self.data.longitude <= max_lon), + (latitudes >= min_lat) + & (latitudes <= max_lat) + & (longitudes >= min_lon) + & (longitudes <= max_lon), drop=True, ) return self._new(data) @@ -365,6 +387,50 @@ def time_range( data = self.data.sel(time=slice(start_time, end_time)) return self._new(data) + def select(self, selection: GeoSelectionOptions) -> "GeotemporalData": + """Apply a combined point or bounds selection and optional time range. + + The selection mapping accepts: + + - ``point``: ``{"latitude": float, "longitude": float, "options": {...}}`` + - ``bounds``: ``[west, south, east, north]`` or a mapping with those keys + - ``time_range``/``timeRange``: ``[start, end]`` or + ``{"start": start, "end": end}`` + """ + _ensure_mapping(selection, "Selection") + current = self + + point = _get_selection_value(selection, "point") + bounds = _get_selection_value(selection, "bounds") + if point is not None and bounds is not None: + raise errors.InvalidSelectionError( + "Use either point or bounds selection, not both." + ) + + if point is not None: + current = current.point(**_normalize_point_selection(point)) + + time_range = _get_selection_value(selection, "time_range", "timeRange") + if time_range is not None: + current = current.time_range(*_normalize_time_range_selection(time_range)) + + if bounds is not None: + bounds_options = _get_selection_value( + selection, "bounds_options", "boundsOptions" + ) + min_lat, min_lon, max_lat, max_lon, coordinate_options = ( + _normalize_bounds_selection(bounds, bounds_options) + ) + current = current.rectangle( + min_lat, + min_lon, + max_lat, + max_lon, + **coordinate_options, + ) + + return current + def reduce_polygon_to_point( self, polygons_mask: gpd.array.GeometryArray ) -> "GeotemporalData": @@ -586,6 +652,8 @@ def query( rectangle_kwargs: dict = None, polygon_kwargs: dict = None, multiple_points_kwargs: dict = None, + bounds: BoundsSelection = None, + bounds_options: dict = None, spatial_agg_kwargs: dict = None, temporal_agg_kwargs: dict = None, rolling_agg_kwargs: dict = None, @@ -607,6 +675,17 @@ def query( data = data.circle(lat=lat, lon=lon, radius=radius) elif rectangle_kwargs: data = data.rectangle(**rectangle_kwargs) + elif bounds is not None: + min_lat, min_lon, max_lat, max_lon, coordinate_options = ( + _normalize_bounds_selection(bounds, bounds_options) + ) + data = data.rectangle( + min_lat, + min_lon, + max_lat, + max_lon, + **coordinate_options, + ) elif polygon_kwargs: data = data.polygons(**polygon_kwargs, point_limit=point_limit) elif multiple_points_kwargs: @@ -648,6 +727,155 @@ def _new(self, data): return type(self)(data, dataset_name=self.dataset_name, data_var=self._data_var) +def _ensure_mapping(value: typing.Any, label: str) -> Mapping[str, typing.Any]: + if not isinstance(value, Mapping): + raise errors.InvalidSelectionError(f"{label} must be a mapping.") + return value + + +def _get_selection_value( + selection: Mapping[str, typing.Any], + *keys: str, +) -> typing.Any: + for key in keys: + if key in selection: + return selection[key] + return None + + +def _get_required_mapping_value( + selection: Mapping[str, typing.Any], + key: str, + label: str, +) -> typing.Any: + if key not in selection: + raise errors.InvalidSelectionError(f"{label} must include '{key}'.") + return selection[key] + + +def _coerce_finite_number(value: typing.Any, label: str) -> float: + if ( + isinstance(value, bool) + or not isinstance(value, numbers.Real) + or not math.isfinite(value) + ): + raise errors.InvalidSelectionError( + "Bounds selection must use finite west, south, east, and north numbers." + ) + return float(value) + + +def _normalize_coordinate_options( + options: typing.Optional[Mapping[str, typing.Any]], +) -> dict[str, str]: + if options is None: + return {} + _ensure_mapping(options, "Selection options") + + normalized: dict[str, str] = {} + latitude_key = _get_selection_value(options, "latitude_key", "latitudeKey") + longitude_key = _get_selection_value(options, "longitude_key", "longitudeKey") + + if latitude_key is not None: + if not isinstance(latitude_key, str): + raise errors.InvalidSelectionError("latitude_key must be a string.") + normalized["latitude_key"] = latitude_key + if longitude_key is not None: + if not isinstance(longitude_key, str): + raise errors.InvalidSelectionError("longitude_key must be a string.") + normalized["longitude_key"] = longitude_key + + return normalized + + +def _normalize_bounds_selection( + bounds: BoundsSelection, + fallback_options: typing.Optional[Mapping[str, typing.Any]] = None, +) -> tuple[float, float, float, float, dict[str, str]]: + if isinstance(bounds, Mapping): + west = _get_required_mapping_value(bounds, "west", "Bounds selection") + south = _get_required_mapping_value(bounds, "south", "Bounds selection") + east = _get_required_mapping_value(bounds, "east", "Bounds selection") + north = _get_required_mapping_value(bounds, "north", "Bounds selection") + options = bounds.get("options", fallback_options) + elif isinstance(bounds, (str, bytes)): + raise errors.InvalidSelectionError( + "Bounds selection must be [west, south, east, north]." + ) + else: + try: + west, south, east, north = bounds + except (TypeError, ValueError) as exc: + raise errors.InvalidSelectionError( + "Bounds selection must be [west, south, east, north]." + ) from exc + options = fallback_options + + west = _coerce_finite_number(west, "west") + south = _coerce_finite_number(south, "south") + east = _coerce_finite_number(east, "east") + north = _coerce_finite_number(north, "north") + + if west >= east: + raise errors.InvalidSelectionError( + f"west ({west}) must be less than east ({east})." + ) + if south >= north: + raise errors.InvalidSelectionError( + f"south ({south}) must be less than north ({north})." + ) + + return ( + south, + west, + north, + east, + _normalize_coordinate_options(options), + ) + + +def _normalize_point_selection( + point: Mapping[str, typing.Any], +) -> dict[str, typing.Any]: + point = _ensure_mapping(point, "Point selection") + latitude = _get_required_mapping_value(point, "latitude", "Point selection") + longitude = _get_required_mapping_value(point, "longitude", "Point selection") + options = point.get("options") or {} + coordinate_options = _normalize_coordinate_options(options) + + snap_to_grid = _get_selection_value(options, "snap_to_grid", "snapToGrid") + if snap_to_grid is None: + snap_to_grid = True + + return { + "latitude": latitude, + "longitude": longitude, + "snap_to_grid": snap_to_grid, + **coordinate_options, + } + + +def _normalize_time_range_selection( + time_range: typing.Any, +) -> tuple[typing.Any, typing.Any]: + if isinstance(time_range, Mapping): + return ( + _get_required_mapping_value(time_range, "start", "Time range selection"), + _get_required_mapping_value(time_range, "end", "Time range selection"), + ) + if isinstance(time_range, (str, bytes)): + raise errors.InvalidSelectionError( + "Time range selection must be [start, end] or {'start': ..., 'end': ...}." + ) + try: + start, end = time_range + except (TypeError, ValueError) as exc: + raise errors.InvalidSelectionError( + "Time range selection must be [start, end] or {'start': ..., 'end': ...}." + ) from exc + return start, end + + def _haversine( lat1: typing.Union[np.ndarray, float], lon1: typing.Union[np.ndarray, float], diff --git a/pyproject.toml b/pyproject.toml index 3e037d1..1374ac5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "pdm.backend" [project] name = "dclimate-client-py" -version = "0.5.7" # Set a static version or handle it in versioning strategy +version = "0.5.8" # Set a static version or handle it in versioning strategy description = "Python client library for accessing dClimate weather and climate data" readme = "README.md" license = {text = "MIT"} diff --git a/tests/test_geotemporal_data.py b/tests/test_geotemporal_data.py index fafe30f..6ad657f 100644 --- a/tests/test_geotemporal_data.py +++ b/tests/test_geotemporal_data.py @@ -4,6 +4,7 @@ import pytest from dclimate_client_py import dclimate_zarr_errors as errors +from dclimate_client_py.dclimate_client import dClimateClient from dclimate_client_py.geotemporal_data import GeotemporalData @@ -122,6 +123,82 @@ def test_rectangle(dataset): assert np.array_equal(data.data.latitude, (20, 30, 40)) assert np.array_equal(data.data.longitude, (190, 195, 200)) + @staticmethod + def test_rectangle_custom_coordinate_keys(dataset): + renamed_dataset = dataset.rename({"latitude": "lat", "longitude": "lon"}) + data = GeotemporalData(renamed_dataset, dataset_name="fake dataset") + + selected = data.rectangle( + 20, + 190, + 40, + 200, + latitude_key="lat", + longitude_key="lon", + ) + + assert np.array_equal(selected.data.lat, (20, 30, 40)) + assert np.array_equal(selected.data.lon, (190, 195, 200)) + + @staticmethod + def test_select_bounds_and_time_range(dataset): + data = GeotemporalData(dataset, dataset_name="fake dataset") + begin = datetime.datetime(2000, 1, 10) + end = datetime.datetime(2000, 1, 15) + + selected = data.select( + { + "bounds": [190, 20, 200, 40], + "time_range": {"start": begin, "end": end}, + } + ) + + assert np.array_equal(selected.data.latitude, (20, 30, 40)) + assert np.array_equal(selected.data.longitude, (190, 195, 200)) + assert selected.data.sizes["time"] == 6 + + @staticmethod + def test_select_object_bounds_with_coordinate_options(dataset): + renamed_dataset = dataset.rename({"latitude": "lat", "longitude": "lon"}) + data = GeotemporalData(renamed_dataset, dataset_name="fake dataset") + + selected = data.select( + { + "bounds": { + "west": 190, + "south": 20, + "east": 200, + "north": 40, + "options": { + "latitude_key": "lat", + "longitude_key": "lon", + }, + } + } + ) + + assert np.array_equal(selected.data.lat, (20, 30, 40)) + assert np.array_equal(selected.data.lon, (190, 195, 200)) + + @staticmethod + def test_select_rejects_point_and_bounds(dataset): + data = GeotemporalData(dataset, dataset_name="fake dataset") + + with pytest.raises(errors.InvalidSelectionError): + data.select( + { + "point": {"latitude": 20, "longitude": 190}, + "bounds": [190, 20, 200, 40], + } + ) + + @staticmethod + def test_select_rejects_invalid_bounds_order(dataset): + data = GeotemporalData(dataset, dataset_name="fake dataset") + + with pytest.raises(errors.InvalidSelectionError): + data.select({"bounds": [200, 20, 190, 40]}) + @staticmethod def test_time_range(dataset): data = GeotemporalData(dataset, dataset_name="fake dataset") @@ -233,3 +310,68 @@ def test_spatial_aggregation(input_ds): assert float(min_val_rep_pt.data["u100"].values[0]) == pytest.approx( -9.5386962890625 ) + + +@pytest.mark.asyncio +async def test_client_select_dataset_loads_and_selects(monkeypatch, dataset): + dclimate = dClimateClient() + loaded = GeotemporalData(dataset, dataset_name="fake dataset") + metadata = {"slug": "fake dataset", "cid": "bafy-test"} + + async def fake_load_dataset(**kwargs): + assert kwargs == { + "dataset": "temperature_2m", + "collection": "era5", + "organization": "ecmwf", + "variant": "finalized", + "return_xarray": False, + } + return loaded, metadata + + monkeypatch.setattr(dclimate, "load_dataset", fake_load_dataset) + + selected, selected_metadata = await dclimate.select_dataset( + request={ + "dataset": "temperature_2m", + "collection": "era5", + "organization": "ecmwf", + "variant": "finalized", + }, + selection={"bounds": [190, 20, 200, 40]}, + ) + + assert selected_metadata == metadata + assert np.array_equal(selected.data.latitude, (20, 30, 40)) + assert np.array_equal(selected.data.longitude, (190, 195, 200)) + + +@pytest.mark.asyncio +async def test_client_select_dataset_ignores_request_return_xarray( + monkeypatch, dataset +): + dclimate = dClimateClient() + loaded = GeotemporalData(dataset, dataset_name="fake dataset") + metadata = {"slug": "fake dataset", "cid": "bafy-test"} + + async def fake_load_dataset(**kwargs): + assert kwargs == { + "dataset": "temperature_2m", + "collection": "era5", + "return_xarray": False, + } + return loaded, metadata + + monkeypatch.setattr(dclimate, "load_dataset", fake_load_dataset) + + selected, selected_metadata = await dclimate.select_dataset( + request={ + "dataset": "temperature_2m", + "collection": "era5", + "return_xarray": True, + }, + selection={"bounds": [190, 20, 200, 40]}, + ) + + assert selected_metadata == metadata + assert np.array_equal(selected.data.latitude, (20, 30, 40)) + assert np.array_equal(selected.data.longitude, (190, 195, 200))