Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions config/streams/eerie_gridded/eerie_atmo.yml
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
EERIE_ATMO:
type: mesh
stream_id: 2136
loss_weight: 1.0
loss_weight: 0.8
filenames:
- /work/ab0995/a270225/data_transformation/eerie_vzarr_production/eerie_atmo_3d_gridded_instant_1971_2050.parq
source: ['q_50', 'q_100', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925', 'q_1000', 't_50', 't_100', 't_150', 't_200', 't_250', 't_300', 't_400', 't_500', 't_600', 't_700', 't_850', 't_925', 't_1000', 'u_50', 'u_100', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925', 'u_1000', 'v_50', 'v_100', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925', 'v_1000', 'z_50', 'z_100', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925', 'z_1000']
target: ['q_50', 'q_100', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925', 'q_1000', 't_50', 't_100', 't_150', 't_200', 't_250', 't_300', 't_400', 't_500', 't_600', 't_700', 't_850', 't_925', 't_1000', 'u_50', 'u_100', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925', 'u_1000', 'v_50', 'v_100', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925', 'v_1000', 'z_50', 'z_100', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925', 'z_1000']
- /work/ab0995/a270225/data_transformation/eerie_vzarr_production/eerie_atmo_gridded_instant_1971_2050.parq
source: ['q_50', 'q_100', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925', 'q_1000', 't_50', 't_100', 't_150', 't_200', 't_250', 't_300', 't_400', 't_500', 't_600', 't_700', 't_850', 't_925', 't_1000', 'u_50', 'u_100', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925', 'u_1000', 'v_50', 'v_100', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925', 'v_1000', 'z_50', 'z_100', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925', 'z_1000', 'v10', 'u10', 'd2m', 't2m', 'msl', 'sp', 'skt']
target: ['q_50', 'q_100', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925', 'q_1000', 't_50', 't_100', 't_150', 't_200', 't_250', 't_300', 't_400', 't_500', 't_600', 't_700', 't_850', 't_925', 't_1000', 'u_50', 'u_100', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925', 'u_1000', 'v_50', 'v_100', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925', 'v_1000', 'z_50', 'z_100', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925', 'z_1000', 'v10', 'u10', 'd2m', 't2m', 'msl', 'sp']
location_weight : cosine_latitude
sampling_mode: 'global_sparse' # loads given number of random points globally for each time step.
sample_points: 65536 # number of random points to sample globally per time step, if sampling_mode is not null.
masking_rate : 0.6
masking_rate_none : 0.05
token_size : 32
Expand Down
12 changes: 7 additions & 5 deletions config/streams/eerie_gridded/eerie_ocean_elem.yml
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
EERIE_OCEAN_ELEM:
type: mesh
stream_id: 2137
loss_weight: 1.0
loss_weight: 0.1
filenames: # In the case of gridded data there are no separate node and element files. Not every vertical level is used for training, so the 3D gridded file is used instead of the 3D element file.
- /work/ab0995/a270225/data_transformation/eerie_ocean_gr025_3d_gridded_daily.parq
source: # Not every vertical level is present in gridded data.
- avg_uoe_2.5m # ocean u velocity at 2.5m depth
- avg_voe_2.5m # ocean v velocity at 2.5m depth
- avg_uoe_surf # ocean u velocity at 2.5m depth
- avg_von_surf # ocean v velocity at 2.5m depth
target:
- avg_uoe_2.5m
- avg_voe_2.5m
- avg_uoe_surf
- avg_von_surf
location_weight : cosine_latitude
sampling_mode: 'global_sparse' # loads given number of random points globally for each time step.
sample_points: 32768 # number of random points to sample globally per time step, if sampling_mode is not null.
masking_rate : 0.6
masking_rate_none : 0.05
token_size : 32
Expand Down
4 changes: 3 additions & 1 deletion config/streams/eerie_gridded/eerie_ocean_node.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
EERIE_OCEAN_NODE:
type: mesh
stream_id: 2138
loss_weight: 1.0
loss_weight: 0.1
filenames:
- /work/ab0995/a270225/data_transformation/eerie_ocean_gr025_2d_gridded_daily.parq
source:
Expand All @@ -13,6 +13,8 @@ EERIE_OCEAN_NODE:
- avg_sos
- avg_zos
location_weight : cosine_latitude
sampling_mode: 'global_sparse' # loads given number of random points globally for each time step.
sample_points: 32768
masking_rate : 0.6
masking_rate_none : 0.05
token_size : 32
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
TimeWindowHandler,
TIndex,
)
from weathergen.train.utils import Stage

logging.getLogger("fsspec").setLevel(logging.WARNING)
logging.getLogger("fsspec.implementations.reference").setLevel(logging.WARNING)
Expand All @@ -42,6 +43,7 @@ def __init__(
tw_handler: TimeWindowHandler,
filename: Path,
stream_info: dict,
stage: Stage | None = None,
) -> None:
self.filename_source = Path(filename)
if "target_file" in stream_info:
Expand All @@ -59,7 +61,9 @@ def __init__(
self._dask_arrays_trg = {}

self.sampling_mode = stream_info.get("sampling_mode", "patch")
self.sampling_step = stream_info.get("sampling_step", 1)
self.patch_stability_window = stream_info.get("patch_stability_window", 1)
self.filler_values = stream_info.get("filler_values", [])

# Auto-enable staircase mode if window is defined and we are in patch mode
auto_use_counter = self.sampling_mode == "patch" and "patch_stability_window" in stream_info
Expand Down Expand Up @@ -98,6 +102,7 @@ def __init__(
self.lons_src = meta_src["lons"]
self.spatial_indices_src = meta_src["indices"]
self.coords_src = meta_src["coords"]
self.grid_dims_src = meta_src["grid_dims"]

# 2. Probe Target
if self.filename_target != self.filename_source:
Expand All @@ -108,21 +113,32 @@ def __init__(
self.lons_trg = meta_trg["lons"]
self.spatial_indices_trg = meta_trg["indices"]
self.coords_trg = meta_trg["coords"]
self.grid_dims_trg = meta_trg["grid_dims"]
else:
self.lats_trg = self.lats_src
self.lons_trg = self.lons_src
self.spatial_indices_trg = self.spatial_indices_src
self.coords_trg = self.coords_src
self.grid_dims_trg = self.grid_dims_src

ds_time_values = meta_src["time"]
self._len_cached = len(ds_time_values)
self._time_values_cached = ds_time_values

data_start_time = np.datetime64(ds_time_values[0], "ns")
if len(ds_time_values) > 1:
period = np.datetime64(ds_time_values[1], "ns") - data_start_time
native_period = np.datetime64(ds_time_values[1], "ns") - data_start_time
else:
period = np.timedelta64(24, "h")
native_period = np.timedelta64(24, "h")

self.native_period = native_period

if "frequency" in stream_info:
from weathergen.readers_extra.data_reader_grep import _str_to_timedelta

period = _str_to_timedelta(stream_info["frequency"])
else:
period = native_period

data_end_time = np.datetime64(ds_time_values[-1], "ns")

Expand Down Expand Up @@ -209,6 +225,13 @@ def _probe_file(self, filepath, is_source=True):
meta["indices"] = spatial_indices
meta["coords"] = np.stack([lats, lons], axis=1)

# Detect grid structure for 2D regular sampling
meta["grid_dims"] = None
lat_dims = [d for d in ds.sizes if d.lower() in ["lat", "latitude"]]
lon_dims = [d for d in ds.sizes if d.lower() in ["lon", "longitude"]]
if lat_dims and lon_dims:
meta["grid_dims"] = (ds.sizes[lat_dims[0]], ds.sizes[lon_dims[0]])

return meta
except Exception as e:
_logger.error(f"Failed to probe {filepath}: {e}")
Expand Down Expand Up @@ -258,16 +281,25 @@ def _get_persistent_time_idxs(self, idx: TIndex) -> tuple[NDArray, DTRange]:
if dtr.end < self.data_start_time or dtr.start > self.data_end_time:
return (np.array([], dtype=np.int64), dtr)

delta_start = dtr.start - self.data_start_time
start_idx = int(delta_start / self.period)
start_idx = np.searchsorted(self._time_values_cached, dtr.start, side="left")
end_idx = np.searchsorted(self._time_values_cached, dtr.end - t_epsilon, side="right") - 1

delta_end = dtr.end - self.data_start_time - t_epsilon
end_idx = int(delta_end / self.period)
stride = 1
if self.period > self.native_period:
stride = int(self.period / self.native_period)

start_idx = max(0, start_idx)
end_idx = min(len(self._time_values_cached) - 1, end_idx)
if start_idx > end_idx:
# Persistent: find last before window
last_before = start_idx - 1
if last_before >= 0:
return (np.array([last_before], dtype=np.int64), dtr)
else:
return (np.array([], dtype=np.int64), dtr)

return (np.arange(start_idx, end_idx + 1, dtype=np.int64), dtr)
# Generate indices and then subsample by stride
idxs = np.arange(start_idx, end_idx + 1, dtype=np.int64)[::stride]

return (idxs, dtr)

@override
def get_source(self, idx: TIndex) -> ReaderData:
Expand All @@ -288,6 +320,12 @@ def _fetch_data(self, idx: TIndex, channels: list[str], is_source: bool) -> Read
start_t, end_t = t_idxs[0], t_idxs[-1] + 1
n_steps = len(t_idxs)

stride = 1
if self.period > self.native_period:
stride = int(self.period / self.native_period)
# extend end_t to cover the final step when striding
end_t = t_idxs[-1] + stride

spatial_indices_ref = self.spatial_indices_src if is_source else self.spatial_indices_trg
coords_ref = self.coords_src if is_source else self.coords_trg
ds_ref = self.ds_source if is_source else self.ds_target
Expand Down Expand Up @@ -318,6 +356,23 @@ def _fetch_data(self, idx: TIndex, channels: list[str], is_source: bool) -> Read
final_disk_indices = spatial_indices_ref[indices_local]
use_contiguous_read = False

elif self.sampling_mode == "regular":
grid_dims = self.grid_dims_src if is_source else self.grid_dims_trg
if grid_dims:
h, w = grid_dims
n = self.sampling_step
rows = spatial_indices_ref // w
cols = spatial_indices_ref % w
mask = (rows % n == 0) & (cols % n == 0)
indices_local = np.where(mask)[0]
else:
total_points = len(spatial_indices_ref)
indices_local = np.arange(0, total_points, self.sampling_step)

patch_coords_base = coords_ref[indices_local]
final_disk_indices = spatial_indices_ref[indices_local]
use_contiguous_read = False

elif self.patch_size_deg:
lat_range = max(0.0, (self.roi_max_lat - self.roi_min_lat) - self.patch_size_deg)
lon_range = max(0.0, (self.roi_max_lon - self.roi_min_lon) - self.patch_size_deg)
Expand Down Expand Up @@ -374,6 +429,7 @@ def _fetch_data(self, idx: TIndex, channels: list[str], is_source: bool) -> Read
channel_indices,
start_t,
end_t,
stride,
n_steps,
slice(disk_start, disk_stop),
rel_indices,
Expand All @@ -385,6 +441,7 @@ def _fetch_data(self, idx: TIndex, channels: list[str], is_source: bool) -> Read
channel_indices,
start_t,
end_t,
stride,
n_steps,
final_disk_indices,
None,
Expand All @@ -396,9 +453,28 @@ def _fetch_data(self, idx: TIndex, channels: list[str], is_source: bool) -> Read
data_block[np.abs(data_block) > 1e10] = np.nan

coords_flat = np.tile(patch_coords_base, (n_steps, 1))
dt_values = self._time_values_cached[start_t:end_t]
dt_values = self._time_values_cached[start_t:end_t:stride]
dt_flat = np.repeat(dt_values, patch_coords_base.shape[0])

if data_block.size > 0:
# Check for NaNs across any channel
valid_mask = ~np.isnan(data_block).any(axis=1)

# Check for filler values across any channel
if self.filler_values:
valid_mask &= ~np.isin(data_block, self.filler_values).any(axis=1)

data_block = data_block[valid_mask]
coords_flat = coords_flat[valid_mask]
dt_flat = dt_flat[valid_mask]

if data_block.size == 0:
_logger.warning(
f"[Stream {self._stream_info.get('name')}] "
"All points were filtered out (NaNs or filler values). Skipping."
)
return ReaderData.empty(len(channels), n_steps)

rdata = ReaderData(
coords=coords_flat,
geoinfos=np.zeros((len(data_block), 0), dtype=np.float32),
Expand All @@ -408,7 +484,7 @@ def _fetch_data(self, idx: TIndex, channels: list[str], is_source: bool) -> Read
return rdata

def _load_block_from_ds(
self, ds, arr_cache, indices, start_t, end_t, n_steps, disk_indices, rel_indices
self, ds, arr_cache, indices, start_t, end_t, stride, n_steps, disk_indices, rel_indices
) -> np.typing.NDArray:
if rel_indices is not None:
num_points = len(rel_indices)
Expand Down Expand Up @@ -444,7 +520,7 @@ def _load_block_from_ds(

# 2. Slice Time (keeps memory small before we flatten)
if "time" in dims:
sliced = sliced[start_t:end_t]
sliced = sliced[start_t:end_t:stride]

# 3. Compute the block into memory
chunk = sliced.compute().astype(np.float32)
Expand Down
Loading