@@ -3,9 +3,9 @@
 import itertools
 import json
 import warnings
-from collections.abc import Hashable, Iterator, Sequence
+from collections.abc import Callable, Hashable, Iterator, Sequence
 from operator import itemgetter
-from typing import Any, Callable, Optional, Union
+from typing import Any
 
 import numpy as np
 import xarray as xr
@@ -55,10 +55,10 @@ class BatchSchema:
 
     def __init__(
         self,
-        ds: Union[xr.Dataset, xr.DataArray],
+        ds: xr.Dataset | xr.DataArray,
         input_dims: dict[Hashable, int],
-        input_overlap: Optional[dict[Hashable, int]] = None,
-        batch_dims: Optional[dict[Hashable, int]] = None,
+        input_overlap: dict[Hashable, int] | None = None,
+        batch_dims: dict[Hashable, int] | None = None,
         concat_input_bins: bool = True,
         preload_batch: bool = True,
     ):
@@ -91,9 +91,7 @@ def __init__(
             )
         self.selectors: BatchSelectorSet = self._gen_batch_selectors(ds)
 
-    def _gen_batch_selectors(
-        self, ds: Union[xr.DataArray, xr.Dataset]
-    ) -> BatchSelectorSet:
+    def _gen_batch_selectors(self, ds: xr.DataArray | xr.Dataset) -> BatchSelectorSet:
         """
         Create batch selectors dict, which can be used to create a batch
         from an Xarray data object.
@@ -106,9 +104,7 @@ def _gen_batch_selectors(
         else:  # Each patch gets its own batch
             return {ind: [value] for ind, value in enumerate(patch_selectors)}
 
-    def _gen_patch_selectors(
-        self, ds: Union[xr.DataArray, xr.Dataset]
-    ) -> PatchGenerator:
+    def _gen_patch_selectors(self, ds: xr.DataArray | xr.Dataset) -> PatchGenerator:
         """
         Create an iterator that can be used to index an Xarray Dataset/DataArray.
         """
@@ -127,7 +123,7 @@ def _gen_patch_selectors(
         return all_slices
 
     def _combine_patches_into_batch(
-        self, ds: Union[xr.DataArray, xr.Dataset], patch_selectors: PatchGenerator
+        self, ds: xr.DataArray | xr.Dataset, patch_selectors: PatchGenerator
     ) -> BatchSelectorSet:
         """
         Combine the patch selectors to form a batch
@@ -169,7 +165,7 @@ def _combine_patches_grouped_by_batch_dims(
         return dict(enumerate(batch_selectors))
 
     def _combine_patches_grouped_by_input_and_batch_dims(
-        self, ds: Union[xr.DataArray, xr.Dataset], patch_selectors: PatchGenerator
+        self, ds: xr.DataArray | xr.Dataset, patch_selectors: PatchGenerator
     ) -> BatchSelectorSet:
         """
         Combine patches with multiple slices along ``batch_dims`` grouped into
@@ -197,7 +193,7 @@ def _gen_empty_batch_selectors(self) -> BatchSelectorSet:
         n_batches = np.prod(list(self._n_batches_per_dim.values()))
         return {k: [] for k in range(n_batches)}
 
-    def _gen_patch_numbers(self, ds: Union[xr.DataArray, xr.Dataset]):
+    def _gen_patch_numbers(self, ds: xr.DataArray | xr.Dataset):
         """
         Calculate the number of patches per dimension and the number of patches
         in each batch per dimension.
@@ -214,7 +210,7 @@ def _gen_patch_numbers(self, ds: Union[xr.DataArray, xr.Dataset]):
             for dim, length in self._all_sliced_dims.items()
         }
 
-    def _gen_batch_numbers(self, ds: Union[xr.DataArray, xr.Dataset]):
+    def _gen_batch_numbers(self, ds: xr.DataArray | xr.Dataset):
         """
         Calculate the number of batches per dimension
         """
@@ -324,7 +320,7 @@ def _gen_slices(*, dim_size: int, slice_size: int, overlap: int = 0) -> list[slice]:
 
 
 def _iterate_through_dimensions(
-    ds: Union[xr.Dataset, xr.DataArray],
+    ds: xr.Dataset | xr.DataArray,
     *,
     dims: dict[Hashable, int],
     overlap: dict[Hashable, int] = {},
@@ -350,10 +346,10 @@ def _iterate_through_dimensions(
 
 
 def _drop_input_dims(
-    ds: Union[xr.Dataset, xr.DataArray],
+    ds: xr.Dataset | xr.DataArray,
     input_dims: dict[Hashable, int],
     suffix: str = '_input',
-) -> Union[xr.Dataset, xr.DataArray]:
+) -> xr.Dataset | xr.DataArray:
     # remove input_dims coordinates from datasets, rename the dimensions
     # then put intput_dims back in as coordinates
     out = ds.copy()
@@ -368,9 +364,9 @@ def _drop_input_dims(
 
 
 def _maybe_stack_batch_dims(
-    ds: Union[xr.Dataset, xr.DataArray],
+    ds: xr.Dataset | xr.DataArray,
     input_dims: Sequence[Hashable],
-) -> Union[xr.Dataset, xr.DataArray]:
+) -> xr.Dataset | xr.DataArray:
     batch_dims = [d for d in ds.sizes if d not in input_dims]
     if len(batch_dims) < 2:
         return ds
@@ -424,14 +420,14 @@ class BatchGenerator:
 
     def __init__(
         self,
-        ds: Union[xr.Dataset, xr.DataArray],
+        ds: xr.Dataset | xr.DataArray,
         input_dims: dict[Hashable, int],
         input_overlap: dict[Hashable, int] = {},
         batch_dims: dict[Hashable, int] = {},
         concat_input_dims: bool = False,
         preload_batch: bool = True,
-        cache: Optional[dict[str, Any]] = None,
-        cache_preprocess: Optional[Callable] = None,
+        cache: dict[str, Any] | None = None,
+        cache_preprocess: Callable | None = None,
     ):
         self.ds = ds
         self.cache = cache
@@ -466,14 +462,14 @@ def concat_input_dims(self):
     def preload_batch(self):
         return self._batch_selectors.preload_batch
 
-    def __iter__(self) -> Iterator[Union[xr.DataArray, xr.Dataset]]:
+    def __iter__(self) -> Iterator[xr.DataArray | xr.Dataset]:
         for idx in self._batch_selectors.selectors:
             yield self[idx]
 
     def __len__(self) -> int:
         return len(self._batch_selectors.selectors)
 
-    def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]:
+    def __getitem__(self, idx: int) -> xr.Dataset | xr.DataArray:
         if not isinstance(idx, int):
             raise NotImplementedError(
                 f'{type(self).__name__}.__getitem__ currently requires a single integer key'
@@ -532,7 +528,7 @@ def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]:
     def _batch_in_cache(self, idx: int) -> bool:
         return self.cache is not None and f'{idx}/.zgroup' in self.cache
 
-    def _cache_batch(self, idx: int, batch: Union[xr.Dataset, xr.DataArray]) -> None:
+    def _cache_batch(self, idx: int, batch: xr.Dataset | xr.DataArray) -> None:
         batch.to_zarr(self.cache, group=str(idx), mode='a')
 
     def _get_cached_batch(self, idx: int) -> xr.Dataset:
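Note on the pattern applied throughout this diff: every Union[X, Y] annotation becomes X | Y, every Optional[X] becomes X | None, and Callable is imported from collections.abc instead of typing, leaving Any as the only remaining typing import. The bare | union syntax in annotations requires Python 3.10 or newer (or from __future__ import annotations on older versions). Below is a minimal, self-contained sketch of the new annotation style; the summarize function is a hypothetical example for illustration, not code from xbatcher:

from collections.abc import Callable, Hashable

import xarray as xr


def summarize(
    ds: xr.Dataset | xr.DataArray,             # was: Union[xr.Dataset, xr.DataArray]
    dims: dict[Hashable, int] | None = None,   # was: Optional[dict[Hashable, int]]
    preprocess: Callable | None = None,        # was: Optional[Callable]
) -> xr.Dataset | xr.DataArray:
    # Trim each named dimension to the requested length, then apply the
    # optional preprocessing callable; only the annotations use the new style.
    if dims is not None:
        ds = ds.isel({dim: slice(0, size) for dim, size in dims.items()})
    if preprocess is not None:
        ds = preprocess(ds)
    return ds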