Skip to content

Commit 6a5991c

Browse files
committed
Cover array and aggregation functions
1 parent c0ef9a6 commit 6a5991c

File tree

1 file changed

+90
-0
lines changed

1 file changed

+90
-0
lines changed

python/datafusion/functions.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2759,6 +2759,16 @@ def array_position(array: Expr, element: Expr, index: int | None = 1) -> Expr:
27592759
... dfn.functions.array_position(dfn.col("a"), dfn.lit(20)).alias("result"))
27602760
>>> result.collect_column("result")[0].as_py()
27612761
2
2762+
2763+
Use ``index`` to start searching from a given position:
2764+
2765+
>>> df = ctx.from_pydict({"a": [[10, 20, 10, 20]]})
2766+
>>> result = df.select(
2767+
... dfn.functions.array_position(
2768+
... dfn.col("a"), dfn.lit(20), index=3,
2769+
... ).alias("result"))
2770+
>>> result.collect_column("result")[0].as_py()
2771+
4
27622772
"""
27632773
return Expr(f.array_position(array.expr, element.expr, index))
27642774

@@ -3091,6 +3101,14 @@ def array_sort(array: Expr, descending: bool = False, null_first: bool = False)
30913101
>>> result = df.select(dfn.functions.array_sort(dfn.col("a")).alias("result"))
30923102
>>> result.collect_column("result")[0].as_py()
30933103
[1, 2, 3]
3104+
3105+
>>> df = ctx.from_pydict({"a": [[3, None, 1]]})
3106+
>>> result = df.select(
3107+
... dfn.functions.array_sort(
3108+
... dfn.col("a"), descending=True, null_first=True,
3109+
... ).alias("result"))
3110+
>>> result.collect_column("result")[0].as_py()
3111+
[None, 3, 1]
30943112
"""
30953113
desc = "DESC" if descending else "ASC"
30963114
nulls_first = "NULLS FIRST" if null_first else "NULLS LAST"
@@ -3125,6 +3143,16 @@ def array_slice(
31253143
... dfn.lit(3)).alias("result"))
31263144
>>> result.collect_column("result")[0].as_py()
31273145
[2, 3]
3146+
3147+
Use ``stride`` to skip elements:
3148+
3149+
>>> result = df.select(
3150+
... dfn.functions.array_slice(
3151+
... dfn.col("a"), dfn.lit(1), dfn.lit(4),
3152+
... stride=dfn.lit(2),
3153+
... ).alias("result"))
3154+
>>> result.collect_column("result")[0].as_py()
3155+
[1, 3]
31283156
"""
31293157
if stride is not None:
31303158
stride = stride.expr
@@ -3396,6 +3424,15 @@ def approx_percentile_cont(
33963424
... ).alias("v")])
33973425
>>> result.collect_column("v")[0].as_py()
33983426
3.0
3427+
3428+
>>> result = df.aggregate(
3429+
... [], [dfn.functions.approx_percentile_cont(
3430+
... dfn.col("a"), 0.5,
3431+
... num_centroids=10,
3432+
... filter=dfn.col("a") > dfn.lit(1.0),
3433+
... ).alias("v")])
3434+
>>> result.collect_column("v")[0].as_py()
3435+
3.5
33993436
"""
34003437
sort_expr_raw = sort_or_default(sort_expression)
34013438
filter_raw = filter.expr if filter is not None else None
@@ -3436,6 +3473,15 @@ def approx_percentile_cont_with_weight(
34363473
... dfn.col("w"), 0.5).alias("v")])
34373474
>>> result.collect_column("v")[0].as_py()
34383475
2.0
3476+
3477+
>>> result = df.aggregate(
3478+
... [], [dfn.functions.approx_percentile_cont_with_weight(
3479+
... dfn.col("a"), dfn.col("w"), 0.5,
3480+
... num_centroids=10,
3481+
... filter=dfn.col("a") > dfn.lit(1.0),
3482+
... ).alias("v")])
3483+
>>> result.collect_column("v")[0].as_py()
3484+
2.5
34393485
"""
34403486
sort_expr_raw = sort_or_default(sort_expression)
34413487
filter_raw = filter.expr if filter is not None else None
@@ -3478,6 +3524,23 @@ def array_agg(
34783524
... [], [dfn.functions.array_agg(dfn.col("a")).alias("v")])
34793525
>>> result.collect_column("v")[0].as_py()
34803526
[1, 2, 3]
3527+
3528+
>>> df = ctx.from_pydict({"a": [3, 1, 2, 1]})
3529+
>>> result = df.aggregate(
3530+
... [], [dfn.functions.array_agg(
3531+
... dfn.col("a"), distinct=True,
3532+
... ).alias("v")])
3533+
>>> sorted(result.collect_column("v")[0].as_py())
3534+
[1, 2, 3]
3535+
3536+
>>> result = df.aggregate(
3537+
... [], [dfn.functions.array_agg(
3538+
... dfn.col("a"),
3539+
... filter=dfn.col("a") > dfn.lit(1),
3540+
... order_by="a",
3541+
... ).alias("v")])
3542+
>>> result.collect_column("v")[0].as_py()
3543+
[2, 3]
34813544
"""
34823545
order_by_raw = sort_list_to_raw_sort_list(order_by)
34833546
filter_raw = filter.expr if filter is not None else None
@@ -3579,6 +3642,15 @@ def count(
35793642
>>> result = df.aggregate([], [dfn.functions.count(dfn.col("a")).alias("v")])
35803643
>>> result.collect_column("v")[0].as_py()
35813644
3
3645+
3646+
>>> df = ctx.from_pydict({"a": [1, 1, 2, 3]})
3647+
>>> result = df.aggregate(
3648+
... [], [dfn.functions.count(
3649+
... dfn.col("a"), distinct=True,
3650+
... filter=dfn.col("a") > dfn.lit(1),
3651+
... ).alias("v")])
3652+
>>> result.collect_column("v")[0].as_py()
3653+
2
35823654
"""
35833655
filter_raw = filter.expr if filter is not None else None
35843656

@@ -3735,6 +3807,15 @@ def median(
37353807
>>> result = df.aggregate([], [dfn.functions.median(dfn.col("a")).alias("v")])
37363808
>>> result.collect_column("v")[0].as_py()
37373809
2.0
3810+
3811+
>>> df = ctx.from_pydict({"a": [1.0, 1.0, 2.0, 3.0]})
3812+
>>> result = df.aggregate(
3813+
... [], [dfn.functions.median(
3814+
... dfn.col("a"), distinct=True,
3815+
... filter=dfn.col("a") > dfn.lit(0.0),
3816+
... ).alias("v")])
3817+
>>> result.collect_column("v")[0].as_py()
3818+
2.0
37383819
"""
37393820
filter_raw = filter.expr if filter is not None else None
37403821
return Expr(f.median(expression.expr, distinct=distinct, filter=filter_raw))
@@ -4551,6 +4632,15 @@ def bit_xor(
45514632
>>> result = df.aggregate([], [dfn.functions.bit_xor(dfn.col("a")).alias("v")])
45524633
>>> result.collect_column("v")[0].as_py()
45534634
6
4635+
4636+
>>> df = ctx.from_pydict({"a": [5, 5, 3]})
4637+
>>> result = df.aggregate(
4638+
... [], [dfn.functions.bit_xor(
4639+
... dfn.col("a"), distinct=True,
4640+
... filter=dfn.col("a") > dfn.lit(0),
4641+
... ).alias("v")])
4642+
>>> result.collect_column("v")[0].as_py()
4643+
6
45544644
"""
45554645
filter_raw = filter.expr if filter is not None else None
45564646
return Expr(f.bit_xor(expression.expr, distinct=distinct, filter=filter_raw))

0 commit comments

Comments
 (0)