Skip to content

Commit 5547c46

Browse files
committed
Improvements to reindex
1 parent 97fcfd4 commit 5547c46

File tree

2 files changed

+35
-14
lines changed

2 files changed

+35
-14
lines changed

flox/core.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -323,10 +323,23 @@ def rechunk_for_blockwise(array, axis, labels):
323323
return array.rechunk({axis: newchunks})
324324

325325

326-
def reindex_(array: np.ndarray, from_, to, fill_value=None, axis: int = -1) -> np.ndarray:
326+
def reindex_multiple(array, groups, expected_groups, fill_value, promote=False):
327+
for ax, (group, expect) in enumerate(zip(groups, expected_groups)):
328+
array = reindex_(
329+
array, group, expect, fill_value=fill_value, axis=ax - len(groups), promote=promote
330+
)
331+
return array
332+
327333

328-
assert isinstance(to, pd.Index)
329-
assert axis in (0, -1)
334+
def reindex_(
335+
array: np.ndarray, from_, to, fill_value=None, axis: int = -1, promote: bool = False
336+
) -> np.ndarray:
337+
338+
if not isinstance(to, pd.Index):
339+
if promote:
340+
to = pd.Index(to)
341+
else:
342+
raise ValueError("reindex requires a pandas.Index or promote=True")
330343

331344
if to.ndim > 1:
332345
raise ValueError(f"Cannot reindex to a multidimensional array: {to}")
@@ -353,15 +366,12 @@ def reindex_(array: np.ndarray, from_, to, fill_value=None, axis: int = -1) -> n
353366
if any(idx == -1):
354367
if fill_value is None:
355368
raise ValueError("Filling is required. fill_value cannot be None.")
356-
if axis == 0:
357-
loc = (idx == -1, ...)
358-
else:
359-
loc = (..., idx == -1)
369+
indexer[axis] = idx == -1
360370
# This allows us to match xarray's type promotion rules
361371
if fill_value is xrdtypes.NA or np.isnan(fill_value):
362372
new_dtype, fill_value = xrdtypes.maybe_promote(reindexed.dtype)
363373
reindexed = reindexed.astype(new_dtype, copy=False)
364-
reindexed[loc] = fill_value
374+
reindexed[tuple(indexer)] = fill_value
365375
return reindexed
366376

367377

tests/test_core.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -487,13 +487,24 @@ def test_groupby_all_nan_blocks(engine):
487487
assert_equal(actual, expected)
488488

489489

490-
def test_reindex():
491-
array = np.array([1, 2])
492-
groups = np.array(["a", "b"])
493-
expected_groups = ["a", "b", "c"]
490+
@pytest.mark.parametrize("axis", (0, 1, 2, -1))
491+
def test_reindex(axis):
492+
shape = [2, 2, 2]
494493
fill_value = 0
495-
result = reindex_(array, groups, pd.Index(expected_groups), fill_value, axis=-1)
496-
assert_equal(result, np.array([1, 2, 0]))
494+
495+
array = np.broadcast_to(np.array([1, 2]), shape)
496+
groups = np.array(["a", "b"])
497+
expected_groups = pd.Index(["a", "b", "c"])
498+
actual = reindex_(array, groups, expected_groups, fill_value=fill_value, axis=axis)
499+
500+
if axis < 0:
501+
axis = array.ndim + axis
502+
result_shape = tuple(len(expected_groups) if ax == axis else s for ax, s in enumerate(shape))
503+
slicer = tuple(slice(None, s) for s in shape)
504+
expected = np.full(result_shape, fill_value)
505+
expected[slicer] = array
506+
507+
assert_equal(actual, expected)
497508

498509

499510
@pytest.mark.xfail

0 commit comments

Comments
 (0)