Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 42 additions & 12 deletions ocf_data_sampler/select/select_time_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def select_time_slice_nwp(
if dropout_timedeltas is None:
dropout_timedeltas = []

if len(dropout_timedeltas)>0:
if len(dropout_timedeltas) > 0:
if not all(t < pd.Timedelta(0) for t in dropout_timedeltas):
raise ValueError("dropout timedeltas must be negative")
if len(dropout_timedeltas) < 1:
Expand All @@ -75,31 +75,61 @@ def select_time_slice_nwp(
else:
t0_available = t0

# Get the available and relevant init-times
# Find the window of all possible `init_time`s whose forecast horizons could cover the
# start of the target period. This correctly handles the case where the requested time range
# does not contain an `init_time` itself, but is covered by a forecast from a
# previous `init_time`.
#
# For example, if the last NWP init_time was 12:00 with a 36-hour forecast, and we
# request data for 14:00-18:00, this logic will correctly identify the 12:00 init_time
# as a valid source.
t_min = target_times[0] - da.step.values[-1]
init_times = da.init_time_utc.values
available_init_times = init_times[(t_min<=init_times) & (init_times<=t0_available)]

# Find the most recent available init-times for all target-times
selected_init_times = np.array(
[available_init_times[available_init_times<=t][-1] for t in target_times],
)
available_init_times = init_times[(t_min <= init_times) & (init_times <= t0_available)]

# Check if there are any available init times
if len(available_init_times) == 0:
max_step = da.step.values[-1]
raise ValueError(
f"Cannot get NWP data for target time {target_times[0]}. "
f"The latest available init_time is {init_times[-1]}, but an init_time of at least "
f"{t_min} is required to cover this target time (given a "
f"maximum forecast horizon of {max_step}).",
)

# Use numpy.searchsorted to find the index of the most recent available init-time for each
# target-time. `side="right"` ensures that if a target_time is identical to an init_time,
# we get the index of the init_time itself. Subtracting 1 then gives us the index of the
# latest init_time that is less than or equal to the target_time.
indices = np.searchsorted(available_init_times, target_times.values, side="right")
selected_indices = indices - 1

# Check for indices less than 0, which indicate a target_time was before the first available
# init_time
if np.any(selected_indices < 0):
first_bad_idx = np.where(selected_indices < 0)[0][0]
first_bad_target = target_times[first_bad_idx]
raise ValueError(
f"Target time {first_bad_target} is before the first available init time"
f" {available_init_times[0]}.",
)

selected_init_times = available_init_times[selected_indices]

# Find the required steps for all target-times
steps = target_times - selected_init_times

# If we are only selecting from one init-time we can construct the slice so it's faster
if len(np.unique(selected_init_times))==1:
if len(np.unique(selected_init_times)) == 1:
da_sel = da.sel(init_time_utc=selected_init_times[0], step=slice(steps[0], steps[-1]))

# If we are selecting from multiple init times this is more complex and slower
else:
# We want one timestep for each target_time_hourly (obviously!) If we simply do
# We want one timestep for each target_time. If we simply do
# nwp.sel(init_time=init_times, step=steps) then we'll get the *product* of
# init_times and steps, which is not what we want! Instead, we use xarray's
# vectorised-indexing mode via using a DataArray indexer. See the last example here:
# vectorised-indexing mode via using a DataArray indexer. See the last example here:
# https://docs.xarray.dev/en/latest/user-guide/indexing.html#more-advanced-indexing
coords = {"step": steps}
coords = {"time_utc": target_times}
init_time_indexer = xr.DataArray(selected_init_times, coords=coords)
step_indexer = xr.DataArray(steps, coords=coords)
da_sel = da.sel(init_time_utc=init_time_indexer, step=step_indexer)
Expand Down
102 changes: 102 additions & 0 deletions tests/select/test_select_time_slice.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numpy as np
import pandas as pd
import pytest
import xarray as xr

from ocf_data_sampler.select.select_time_slice import select_time_slice, select_time_slice_nwp
from tests.conftest import NWP_FREQ
Expand Down Expand Up @@ -143,3 +145,103 @@ def test_select_time_slice_nwp_with_dropout(da_nwp_like, dropout_hours):
[t if t < t0_delayed else t0_delayed for t in expected_target_times],
).floor(NWP_FREQ)
assert (expected_init_times == da_slice.init_time_utc.values).all()


@pytest.fixture
def nwp_data_array_non_hourly():
    """Build a small random NWP-like DataArray whose init times fall on the half hour.

    The array has two init times (2023-01-01 12:30 and 13:30), three forecast steps
    (0, 30 and 60 minutes), two channels ("t", "dswrf") and a 4 x 5 spatial grid.
    Values are unseeded random float32 numbers; only the coordinates matter to the tests.
    """
    init_times = pd.to_datetime(["2023-01-01 12:30", "2023-01-01 13:30"])
    steps = pd.to_timedelta([0, 0.5, 1], unit="h")
    channels = ["t", "dswrf"]
    x_coords = np.arange(4)
    y_coords = np.arange(5)

    # Dummy payload shaped (init_time_utc, step, channel, x_osgb, y_osgb)
    values = np.random.rand(2, 3, 2, 4, 5).astype(np.float32)

    return xr.DataArray(
        values,
        coords=[init_times, steps, channels, x_coords, y_coords],
        dims=["init_time_utc", "step", "channel", "x_osgb", "y_osgb"],
    )


def test_select_time_slice_nwp_raises_error_for_early_target_time(nwp_data_array_non_hourly):
    """Check that a too-early target time produces a descriptive ValueError.

    The request spans t0 - 1h to t0 with t0 = 13:00, while the fixture's earliest
    init time is 12:30, so the start of the window precedes every init time.
    """
    request = {
        "da": nwp_data_array_non_hourly,
        "t0": pd.Timestamp("2023-01-01 13:00"),
        "interval_start": pd.Timedelta("-1h"),
        "interval_end": pd.Timedelta("0h"),
        "time_resolution": pd.Timedelta("1h"),
    }
    expected_message = r"Target time .* is before the first available init time"

    with pytest.raises(ValueError, match=expected_message):
        select_time_slice_nwp(**request)


def test_select_time_slice_nwp_success(nwp_data_array_non_hourly):
    """Check that a valid t0 selects the latest init time at or before each target time.

    With t0 = 14:00 and a window of t0 - 1h to t0 at 1h resolution, the expected
    target times are 13:00 and 14:00:
      * 13:00 is served by the 12:30 init time (step 30 min)
      * 14:00 is served by the 13:30 init time (step 30 min)
    """
    result = select_time_slice_nwp(
        da=nwp_data_array_non_hourly,
        t0=pd.Timestamp("2023-01-01 14:00"),
        interval_start=pd.Timedelta("-1h"),
        interval_end=pd.Timedelta("0h"),
        time_resolution=pd.Timedelta("1h"),
    )

    wanted_init_times = pd.to_datetime(["2023-01-01 12:30", "2023-01-01 13:30"])
    wanted_steps = pd.TimedeltaIndex(["30min", "30min"])

    # init_time_utc and step collapse into a single time_utc dimension
    assert result.dims == ("time_utc", "channel", "x_osgb", "y_osgb")
    assert len(result.step) == 2
    assert np.all(result.step.values == wanted_steps.values)
    assert np.all(result.init_time_utc.values == wanted_init_times.values)


def test_select_time_slice_nwp_no_available_init_times(nwp_data_array_non_hourly):
    """Check that an error is raised when no init time can cover the target period."""
    # t0 is well before the fixture's first init time (12:30), so nothing can cover it
    very_early_t0 = pd.Timestamp("2023-01-01 10:00")
    request = {
        "da": nwp_data_array_non_hourly,
        "t0": very_early_t0,
        "interval_start": pd.Timedelta("0h"),
        "interval_end": pd.Timedelta("1h"),
        "time_resolution": pd.Timedelta("1h"),
    }

    with pytest.raises(ValueError, match="Cannot get NWP data for target time"):
        select_time_slice_nwp(**request)


def test_select_time_slice_nwp_handles_extended_forecast_coverage(da_nwp_like):
    """Check that an earlier init time is used when its forecast covers the requested range.

    The requested range (14:00-15:00 on 2024-01-02) contains no init time itself; the
    test expects both target times to be served by the 12:00 init time.
    NOTE(review): this relies on the `da_nwp_like` fixture having an init time at 12:00
    and none between 12:00 and 15:00 - confirm against conftest.
    """
    t0 = pd.Timestamp("2024-01-02 13:00")
    hourly = pd.Timedelta("1h")
    offset_to_start = pd.Timedelta("1h")  # window opens at 14:00
    offset_to_end = pd.Timedelta("2h")  # window closes at 15:00

    da_slice = select_time_slice_nwp(
        da_nwp_like,
        t0,
        time_resolution=hourly,
        interval_start=offset_to_start,
        interval_end=offset_to_end,
    )

    # Both target times (14:00, 15:00) should come from the 12:00 init time
    wanted_init_time = pd.Timestamp("2024-01-02 12:00")
    assert (da_slice.init_time_utc == wanted_init_time).all()

    # That puts the steps at 2h and 3h after 12:00
    wanted_steps = pd.to_timedelta(["2h", "3h"])
    assert (da_slice.step.values == wanted_steps).all()

    # init time + step must reproduce the requested valid times
    wanted_valid_times = pd.to_datetime(["2024-01-02 14:00", "2024-01-02 15:00"])
    actual_valid_times = da_slice.init_time_utc + da_slice.step
    assert (actual_valid_times.values == wanted_valid_times).all()
Loading