|
| 1 | +import numpy as np |
| 2 | +import pytest |
| 3 | +import xarray as xr |
| 4 | + |
| 5 | +import xbatcher # noqa: F401 |
| 6 | +from xbatcher import BatchGenerator |
| 7 | + |
| 8 | + |
| 9 | +@pytest.fixture(scope='module') |
| 10 | +def sample_ds_3d(): |
| 11 | + shape = (10, 50, 100) |
| 12 | + ds = xr.Dataset( |
| 13 | + { |
| 14 | + 'foo': (['time', 'y', 'x'], np.random.rand(*shape)), |
| 15 | + 'bar': (['time', 'y', 'x'], np.random.randint(0, 10, shape)), |
| 16 | + }, |
| 17 | + { |
| 18 | + 'x': (['x'], np.arange(shape[-1])), |
| 19 | + 'y': (['y'], np.arange(shape[-2])), |
| 20 | + }, |
| 21 | + ) |
| 22 | + return ds |
| 23 | + |
| 24 | + |
| 25 | +def test_batch_accessor_ds(sample_ds_3d): |
| 26 | + bg_class = BatchGenerator(sample_ds_3d, input_dims={'x': 5}) |
| 27 | + bg_acc = sample_ds_3d.batch.generator(input_dims={'x': 5}) |
| 28 | + assert isinstance(bg_acc, BatchGenerator) |
| 29 | + for batch_class, batch_acc in zip(bg_class, bg_acc): |
| 30 | + assert isinstance(batch_acc, xr.Dataset) |
| 31 | + assert batch_class.equals(batch_acc) |
| 32 | + |
| 33 | + |
| 34 | +def test_batch_accessor_da(sample_ds_3d): |
| 35 | + sample_da = sample_ds_3d['foo'] |
| 36 | + bg_class = BatchGenerator(sample_da, input_dims={'x': 5}) |
| 37 | + bg_acc = sample_da.batch.generator(input_dims={'x': 5}) |
| 38 | + assert isinstance(bg_acc, BatchGenerator) |
| 39 | + for batch_class, batch_acc in zip(bg_class, bg_acc): |
| 40 | + assert batch_class.equals(batch_acc) |
0 commit comments