diff --git a/config/streams/eerie_gridded/eerie_atmo.yml b/config/streams/eerie_gridded/eerie_atmo.yml index 57354da54..3d7b41232 100644 --- a/config/streams/eerie_gridded/eerie_atmo.yml +++ b/config/streams/eerie_gridded/eerie_atmo.yml @@ -1,12 +1,14 @@ EERIE_ATMO: type: mesh stream_id: 2136 - loss_weight: 1.0 + loss_weight: 0.8 filenames: - - /work/ab0995/a270225/data_transformation/eerie_vzarr_production/eerie_atmo_3d_gridded_instant_1971_2050.parq - source: ['q_50', 'q_100', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925', 'q_1000', 't_50', 't_100', 't_150', 't_200', 't_250', 't_300', 't_400', 't_500', 't_600', 't_700', 't_850', 't_925', 't_1000', 'u_50', 'u_100', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925', 'u_1000', 'v_50', 'v_100', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925', 'v_1000', 'z_50', 'z_100', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925', 'z_1000'] - target: ['q_50', 'q_100', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925', 'q_1000', 't_50', 't_100', 't_150', 't_200', 't_250', 't_300', 't_400', 't_500', 't_600', 't_700', 't_850', 't_925', 't_1000', 'u_50', 'u_100', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925', 'u_1000', 'v_50', 'v_100', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925', 'v_1000', 'z_50', 'z_100', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925', 'z_1000'] + - /work/ab0995/a270225/data_transformation/eerie_vzarr_production/eerie_atmo_gridded_instant_1971_2050.parq + source: ['q_50', 'q_100', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925', 'q_1000', 't_50', 't_100', 't_150', 't_200', 't_250', 't_300', 't_400', 't_500', 't_600', 't_700', 't_850', 't_925', 
't_1000', 'u_50', 'u_100', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925', 'u_1000', 'v_50', 'v_100', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925', 'v_1000', 'z_50', 'z_100', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925', 'z_1000', 'v10', 'u10', 'd2m', 't2m', 'msl', 'sp', 'skt'] + target: ['q_50', 'q_100', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925', 'q_1000', 't_50', 't_100', 't_150', 't_200', 't_250', 't_300', 't_400', 't_500', 't_600', 't_700', 't_850', 't_925', 't_1000', 'u_50', 'u_100', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925', 'u_1000', 'v_50', 'v_100', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925', 'v_1000', 'z_50', 'z_100', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925', 'z_1000', 'v10', 'u10', 'd2m', 't2m', 'msl', 'sp'] location_weight : cosine_latitude + sampling_mode: 'global_sparse' # loads given number of random points globally for each time step. + sample_points: 65536 # number of random points to sample globally per time step when sampling_mode is 'global_sparse' masking_rate : 0.6 masking_rate_none : 0.05 token_size : 32 diff --git a/config/streams/eerie_gridded/eerie_ocean_elem.yml b/config/streams/eerie_gridded/eerie_ocean_elem.yml index 2880c1cd0..e48a061ff 100644 --- a/config/streams/eerie_gridded/eerie_ocean_elem.yml +++ b/config/streams/eerie_gridded/eerie_ocean_elem.yml @@ -1,16 +1,18 @@ EERIE_OCEAN_ELEM: type: mesh stream_id: 2137 - loss_weight: 1.0 + loss_weight: 0.1 filenames: # In case of gridded data there's no separate node and element files. Not every vertial level is used for training, so the 3D gridded file is used instead of the 3D element file. 
- /work/ab0995/a270225/data_transformation/eerie_ocean_gr025_3d_gridded_daily.parq source: # Not every vertical level is present in gridded data. - - avg_uoe_2.5m # ocean u velocity at 2.5m depth - - avg_voe_2.5m # ocean v velocity at 2.5m depth + - avg_uoe_surf # ocean u velocity at 2.5m depth + - avg_von_surf # ocean v velocity at 2.5m depth target: - - avg_uoe_2.5m - - avg_voe_2.5m + - avg_uoe_surf + - avg_von_surf location_weight : cosine_latitude + sampling_mode: 'global_sparse' # loads given number of random points globally for each time step. + sample_points: 32768 # number of random points to sample globally per time step when sampling_mode is 'global_sparse' masking_rate : 0.6 masking_rate_none : 0.05 token_size : 32 diff --git a/config/streams/eerie_gridded/eerie_ocean_node.yml b/config/streams/eerie_gridded/eerie_ocean_node.yml index 8d270dfda..a036b6723 100644 --- a/config/streams/eerie_gridded/eerie_ocean_node.yml +++ b/config/streams/eerie_gridded/eerie_ocean_node.yml @@ -1,7 +1,7 @@ EERIE_OCEAN_NODE: type: mesh stream_id: 2138 - loss_weight: 1.0 + loss_weight: 0.1 filenames: - /work/ab0995/a270225/data_transformation/eerie_ocean_gr025_2d_gridded_daily.parq source: @@ -13,6 +13,8 @@ EERIE_OCEAN_NODE: - avg_sos - avg_zos location_weight : cosine_latitude + sampling_mode: 'global_sparse' # loads given number of random points globally for each time step. 
+ sample_points: 32768 masking_rate : 0.6 masking_rate_none : 0.05 token_size : 32 diff --git a/packages/readers_extra/src/weathergen/readers_extra/data_reader_mesh.py b/packages/readers_extra/src/weathergen/readers_extra/data_reader_mesh.py index 5d480f2f1..e955671ea 100644 --- a/packages/readers_extra/src/weathergen/readers_extra/data_reader_mesh.py +++ b/packages/readers_extra/src/weathergen/readers_extra/data_reader_mesh.py @@ -17,6 +17,7 @@ TimeWindowHandler, TIndex, ) +from weathergen.train.utils import Stage logging.getLogger("fsspec").setLevel(logging.WARNING) logging.getLogger("fsspec.implementations.reference").setLevel(logging.WARNING) @@ -42,6 +43,7 @@ def __init__( tw_handler: TimeWindowHandler, filename: Path, stream_info: dict, + stage: Stage | None = None, ) -> None: self.filename_source = Path(filename) if "target_file" in stream_info: @@ -59,7 +61,9 @@ def __init__( self._dask_arrays_trg = {} self.sampling_mode = stream_info.get("sampling_mode", "patch") + self.sampling_step = stream_info.get("sampling_step", 1) self.patch_stability_window = stream_info.get("patch_stability_window", 1) + self.filler_values = stream_info.get("filler_values", []) # Auto-enable staircase mode if window is defined and we are in patch mode auto_use_counter = self.sampling_mode == "patch" and "patch_stability_window" in stream_info @@ -98,6 +102,7 @@ def __init__( self.lons_src = meta_src["lons"] self.spatial_indices_src = meta_src["indices"] self.coords_src = meta_src["coords"] + self.grid_dims_src = meta_src["grid_dims"] # 2. 
Probe Target if self.filename_target != self.filename_source: @@ -108,11 +113,13 @@ def __init__( self.lons_trg = meta_trg["lons"] self.spatial_indices_trg = meta_trg["indices"] self.coords_trg = meta_trg["coords"] + self.grid_dims_trg = meta_trg["grid_dims"] else: self.lats_trg = self.lats_src self.lons_trg = self.lons_src self.spatial_indices_trg = self.spatial_indices_src self.coords_trg = self.coords_src + self.grid_dims_trg = self.grid_dims_src ds_time_values = meta_src["time"] self._len_cached = len(ds_time_values) @@ -120,9 +127,18 @@ def __init__( data_start_time = np.datetime64(ds_time_values[0], "ns") if len(ds_time_values) > 1: - period = np.datetime64(ds_time_values[1], "ns") - data_start_time + native_period = np.datetime64(ds_time_values[1], "ns") - data_start_time else: - period = np.timedelta64(24, "h") + native_period = np.timedelta64(24, "h") + + self.native_period = native_period + + if "frequency" in stream_info: + from weathergen.readers_extra.data_reader_grep import _str_to_timedelta + + period = _str_to_timedelta(stream_info["frequency"]) + else: + period = native_period data_end_time = np.datetime64(ds_time_values[-1], "ns") @@ -209,6 +225,13 @@ def _probe_file(self, filepath, is_source=True): meta["indices"] = spatial_indices meta["coords"] = np.stack([lats, lons], axis=1) + # Detect grid structure for 2D regular sampling + meta["grid_dims"] = None + lat_dims = [d for d in ds.sizes if d.lower() in ["lat", "latitude"]] + lon_dims = [d for d in ds.sizes if d.lower() in ["lon", "longitude"]] + if lat_dims and lon_dims: + meta["grid_dims"] = (ds.sizes[lat_dims[0]], ds.sizes[lon_dims[0]]) + return meta except Exception as e: _logger.error(f"Failed to probe {filepath}: {e}") @@ -258,16 +281,25 @@ def _get_persistent_time_idxs(self, idx: TIndex) -> tuple[NDArray, DTRange]: if dtr.end < self.data_start_time or dtr.start > self.data_end_time: return (np.array([], dtype=np.int64), dtr) - delta_start = dtr.start - self.data_start_time - start_idx = 
int(delta_start / self.period) + start_idx = np.searchsorted(self._time_values_cached, dtr.start, side="left") + end_idx = np.searchsorted(self._time_values_cached, dtr.end - t_epsilon, side="right") - 1 - delta_end = dtr.end - self.data_start_time - t_epsilon - end_idx = int(delta_end / self.period) + stride = 1 + if self.period > self.native_period: + stride = int(self.period / self.native_period) - start_idx = max(0, start_idx) - end_idx = min(len(self._time_values_cached) - 1, end_idx) + if start_idx > end_idx: + # Persistent: find last before window + last_before = start_idx - 1 + if last_before >= 0: + return (np.array([last_before], dtype=np.int64), dtr) + else: + return (np.array([], dtype=np.int64), dtr) - return (np.arange(start_idx, end_idx + 1, dtype=np.int64), dtr) + # Generate indices and then subsample by stride + idxs = np.arange(start_idx, end_idx + 1, dtype=np.int64)[::stride] + + return (idxs, dtr) @override def get_source(self, idx: TIndex) -> ReaderData: @@ -288,6 +320,12 @@ def _fetch_data(self, idx: TIndex, channels: list[str], is_source: bool) -> Read start_t, end_t = t_idxs[0], t_idxs[-1] + 1 n_steps = len(t_idxs) + stride = 1 + if self.period > self.native_period: + stride = int(self.period / self.native_period) + # extend end_t to cover the final step when striding + end_t = t_idxs[-1] + stride + spatial_indices_ref = self.spatial_indices_src if is_source else self.spatial_indices_trg coords_ref = self.coords_src if is_source else self.coords_trg ds_ref = self.ds_source if is_source else self.ds_target @@ -318,6 +356,23 @@ def _fetch_data(self, idx: TIndex, channels: list[str], is_source: bool) -> Read final_disk_indices = spatial_indices_ref[indices_local] use_contiguous_read = False + elif self.sampling_mode == "regular": + grid_dims = self.grid_dims_src if is_source else self.grid_dims_trg + if grid_dims: + h, w = grid_dims + n = self.sampling_step + rows = spatial_indices_ref // w + cols = spatial_indices_ref % w + mask = (rows % n == 
0) & (cols % n == 0) + indices_local = np.where(mask)[0] + else: + total_points = len(spatial_indices_ref) + indices_local = np.arange(0, total_points, self.sampling_step) + + patch_coords_base = coords_ref[indices_local] + final_disk_indices = spatial_indices_ref[indices_local] + use_contiguous_read = False + elif self.patch_size_deg: lat_range = max(0.0, (self.roi_max_lat - self.roi_min_lat) - self.patch_size_deg) lon_range = max(0.0, (self.roi_max_lon - self.roi_min_lon) - self.patch_size_deg) @@ -374,6 +429,7 @@ def _fetch_data(self, idx: TIndex, channels: list[str], is_source: bool) -> Read channel_indices, start_t, end_t, + stride, n_steps, slice(disk_start, disk_stop), rel_indices, @@ -385,6 +441,7 @@ def _fetch_data(self, idx: TIndex, channels: list[str], is_source: bool) -> Read channel_indices, start_t, end_t, + stride, n_steps, final_disk_indices, None, @@ -396,9 +453,28 @@ def _fetch_data(self, idx: TIndex, channels: list[str], is_source: bool) -> Read data_block[np.abs(data_block) > 1e10] = np.nan coords_flat = np.tile(patch_coords_base, (n_steps, 1)) - dt_values = self._time_values_cached[start_t:end_t] + dt_values = self._time_values_cached[start_t:end_t:stride] dt_flat = np.repeat(dt_values, patch_coords_base.shape[0]) + if data_block.size > 0: + # Check for NaNs across any channel + valid_mask = ~np.isnan(data_block).any(axis=1) + + # Check for filler values across any channel + if self.filler_values: + valid_mask &= ~np.isin(data_block, self.filler_values).any(axis=1) + + data_block = data_block[valid_mask] + coords_flat = coords_flat[valid_mask] + dt_flat = dt_flat[valid_mask] + + if data_block.size == 0: + _logger.warning( + f"[Stream {self._stream_info.get('name')}] " + "All points were filtered out (NaNs or filler values). Skipping." 
+ ) + return ReaderData.empty(len(channels), n_steps) + rdata = ReaderData( coords=coords_flat, geoinfos=np.zeros((len(data_block), 0), dtype=np.float32), @@ -408,7 +484,7 @@ def _fetch_data(self, idx: TIndex, channels: list[str], is_source: bool) -> Read return rdata def _load_block_from_ds( - self, ds, arr_cache, indices, start_t, end_t, n_steps, disk_indices, rel_indices + self, ds, arr_cache, indices, start_t, end_t, stride, n_steps, disk_indices, rel_indices ) -> np.typing.NDArray: if rel_indices is not None: num_points = len(rel_indices) @@ -444,7 +520,7 @@ def _load_block_from_ds( # 2. Slice Time (keeps memory small before we flatten) if "time" in dims: - sliced = sliced[start_t:end_t] + sliced = sliced[start_t:end_t:stride] # 3. Compute the block into memory chunk = sliced.compute().astype(np.float32)