Skip to content

Commit 2a99eb0

Browse files
committed
Avoid hard xarray dependency
Closes #16
1 parent a13ca36 commit 2a99eb0

File tree

3 files changed

+50
-4
lines changed

3 files changed

+50
-4
lines changed

dask_groupby/aggregations.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
2-
from xarray.core import dtypes, utils
2+
3+
from . import xrdtypes as dtypes, xrutils
34

45

56
def _get_fill_value(dtype, fill_value):
@@ -17,7 +18,7 @@ def _get_fill_value(dtype, fill_value):
1718

1819

1920
def _atleast_1d(inp):
20-
if utils.is_scalar(inp):
21+
if xrutils.is_scalar(inp):
2122
inp = (inp,)
2223
return inp
2324

dask_groupby/xrutils.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
# defined in xarray
33

44

5-
from typing import Any
5+
from typing import Any, Iterable
66

7+
import dask.array
78
import numpy as np
9+
import pandas as pd
810

911

1012
def is_duck_array(value: Any) -> bool:
@@ -32,3 +34,47 @@ def is_dask_collection(x):
3234

3335
def is_duck_dask_array(x):
3436
return is_duck_array(x) and is_dask_collection(x)
37+
38+
39+
class ReprObject:
40+
"""Object that prints as the given value, for use with sentinel values."""
41+
42+
__slots__ = ("_value",)
43+
44+
def __init__(self, value: str):
45+
self._value = value
46+
47+
def __repr__(self) -> str:
48+
return self._value
49+
50+
def __eq__(self, other) -> bool:
51+
if isinstance(other, ReprObject):
52+
return self._value == other._value
53+
return False
54+
55+
def __hash__(self) -> int:
56+
return hash((type(self), self._value))
57+
58+
def __dask_tokenize__(self):
59+
from dask.base import normalize_token
60+
61+
return normalize_token((type(self), self._value))
62+
63+
64+
def is_scalar(value: Any, include_0d: bool = True) -> bool:
65+
"""Whether to treat a value as a scalar.
66+
67+
Any non-iterable, string, or 0-D array
68+
"""
69+
NON_NUMPY_SUPPORTED_ARRAY_TYPES = (dask.array.Array, pd.Index)
70+
71+
if include_0d:
72+
include_0d = getattr(value, "ndim", None) == 0
73+
return (
74+
include_0d
75+
or isinstance(value, (str, bytes))
76+
or not (
77+
isinstance(value, (Iterable,) + NON_NUMPY_SUPPORTED_ARRAY_TYPES)
78+
or hasattr(value, "__array_function__")
79+
)
80+
)

requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,3 @@ dask
22
numpy_groupies
33
netcdf4
44
toolz
5-
xarray

0 commit comments

Comments
 (0)