Skip to content

Commit f137509

Browse files
committed
feat: Add xarray.open_dataset fsspec support
1 parent b1f2664 commit f137509

2 files changed

Lines changed: 88 additions & 3 deletions

File tree

python/omfiles/xarray.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,15 @@
3232
DIMENSION_KEY = "_ARRAY_DIMENSIONS"
3333

3434

35+
def _is_remote_uri(path: str) -> bool:
36+
return "://" in path
37+
38+
3539
class OmXarrayEntrypoint(BackendEntrypoint):
3640
def guess_can_open(self, filename_or_obj):
    """Report whether this backend can open *filename_or_obj*.

    Accepts either a plain path string ending in ``.om`` or an
    ``(fsspec_filesystem, path)`` tuple whose path ends in ``.om``.

    This hook must never raise: xarray calls it on arbitrary user input
    while auto-detecting an engine, so malformed inputs return ``False``.
    """
    if isinstance(filename_or_obj, tuple):
        # Guard the unpack: a tuple with a length other than 2 would
        # otherwise raise ValueError during engine auto-detection.
        if len(filename_or_obj) != 2:
            return False
        _, path = filename_or_obj
        return isinstance(path, str) and path.endswith(".om")
    return isinstance(filename_or_obj, str) and filename_or_obj.endswith(".om")
3845

3946
def open_dataset(
@@ -42,8 +49,19 @@ def open_dataset(
4249
*,
4350
drop_variables=None,
4451
) -> Dataset:
45-
filename_or_obj = _normalize_path(filename_or_obj)
46-
with OmFileReader(filename_or_obj) as root_variable:
52+
if isinstance(filename_or_obj, tuple):
53+
fs, path = filename_or_obj
54+
reader = OmFileReader.from_fsspec(fs, path)
55+
elif isinstance(filename_or_obj, str) and _is_remote_uri(filename_or_obj):
56+
import fsspec
57+
58+
fs, _, paths = fsspec.get_fs_token_paths(filename_or_obj)
59+
reader = OmFileReader.from_fsspec(fs, paths[0])
60+
else:
61+
filename_or_obj = _normalize_path(filename_or_obj)
62+
reader = OmFileReader(filename_or_obj)
63+
64+
with reader as root_variable:
4765
store = OmDataStore(root_variable)
4866
store_entrypoint = StoreBackendEntrypoint()
4967
ds = store_entrypoint.open_dataset(
@@ -56,7 +74,6 @@ def open_dataset(
5674
ds = ds.set_coords(coord_names)
5775
ds.attrs = {k: v for k, v in ds.attrs.items() if k != coord_attr}
5876
return ds
59-
raise ValueError("Failed to open dataset")
6077

6178
description = "Use .om files in Xarray"
6279

tests/test_fsspec.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,3 +394,71 @@ def test_write_and_read_dataset_fsspec_roundtrip(memory_fs):
394394
np.testing.assert_array_equal(ds2.coords["lat"].values, ds.coords["lat"].values)
395395
np.testing.assert_array_equal(ds2.coords["lon"].values, ds.coords["lon"].values)
396396
assert ds2.attrs["description"] == "full fsspec roundtrip"
397+
398+
399+
# --- open_dataset fsspec tests ---
400+
401+
402+
@filter_numpy_size_warning
def test_open_dataset_fsspec_tuple(memory_fs):
    """open_dataset accepts an (fs, path) tuple to read via fsspec."""
    expected_temps = np.random.rand(5, 5).astype(np.float32)
    axis = np.arange(5, dtype=np.float32)
    source = xr.Dataset(
        {"temperature": (["lat", "lon"], expected_temps)},
        coords={"lat": axis, "lon": axis.copy()},
        attrs={"description": "tuple test"},
    )
    store_path = "tuple_open_test.om"
    write_dataset(source, store_path, fs=memory_fs, scale_factor=100000.0)

    # Hand open_dataset the (filesystem, path) pair directly.
    reopened = xr.open_dataset((memory_fs, store_path), engine="om")
    np.testing.assert_array_almost_equal(reopened["temperature"].values, expected_temps, decimal=4)
    for dim in ("lat", "lon"):
        np.testing.assert_array_equal(reopened.coords[dim].values, source.coords[dim].values)
    assert reopened.attrs["description"] == "tuple test"
422+
423+
424+
@filter_numpy_size_warning
def test_open_dataset_fsspec_tuple_local(local_fs):
    """open_dataset with a local fsspec (fs, path) tuple."""
    grid = np.random.rand(8, 8).astype(np.float32)
    dataset = xr.Dataset(
        {"temperature": (["lat", "lon"], grid)},
        coords={
            "lat": np.arange(8, dtype=np.float32),
            "lon": np.arange(8, dtype=np.float32),
        },
    )
    # delete=False so the file survives the context manager; cleaned up below.
    with tempfile.NamedTemporaryFile(suffix=".om", delete=False) as handle:
        target = handle.name
    try:
        write_dataset(dataset, target, fs=local_fs, scale_factor=100000.0)
        reread = xr.open_dataset((local_fs, target), engine="om")
        np.testing.assert_array_almost_equal(reread["temperature"].values, dataset["temperature"].values, decimal=4)
    finally:
        os.unlink(target)
442+
443+
444+
@filter_numpy_size_warning
def test_open_dataset_fsspec_full_roundtrip(memory_fs):
    """Full roundtrip: write_dataset with fs=, read back with open_dataset (fs, path) tuple."""
    temps = np.random.rand(5, 5).astype(np.float32)
    original = xr.Dataset(
        {"temperature": (["lat", "lon"], temps)},
        coords={name: np.arange(5, dtype=np.float32) for name in ("lat", "lon")},
        attrs={"description": "full roundtrip"},
    )
    target = "full_roundtrip_open.om"
    write_dataset(original, target, fs=memory_fs, scale_factor=100000.0)

    # Reading through the (fs, path) tuple needs no temp file on disk.
    restored = xr.open_dataset((memory_fs, target), engine="om")
    np.testing.assert_array_almost_equal(restored["temperature"].values, temps, decimal=4)
    np.testing.assert_array_equal(restored.coords["lat"].values, original.coords["lat"].values)
    np.testing.assert_array_equal(restored.coords["lon"].values, original.coords["lon"].values)
    assert restored.attrs["description"] == "full roundtrip"

0 commit comments

Comments
 (0)