Skip to content

Commit eacf226

Browse files
authored
Merge branch 'main' into chore/upgrade-rust-2024-group-5
2 parents 752818f + 83736ef commit eacf226

File tree

3 files changed

+421
-85
lines changed

3 files changed

+421
-85
lines changed

datafusion/physical-expr-common/src/physical_expr.rs

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -579,14 +579,32 @@ pub fn fmt_sql(expr: &dyn PhysicalExpr) -> impl Display + '_ {
579579
pub fn snapshot_physical_expr(
580580
expr: Arc<dyn PhysicalExpr>,
581581
) -> Result<Arc<dyn PhysicalExpr>> {
582+
snapshot_physical_expr_opt(expr).data()
583+
}
584+
585+
/// Take a snapshot of the given `PhysicalExpr` if it is dynamic.
586+
///
587+
/// Take a snapshot of this `PhysicalExpr` if it is dynamic.
588+
/// This is used to capture the current state of `PhysicalExpr`s that may contain
589+
/// dynamic references to other operators in order to serialize it over the wire
590+
/// or treat it via downcast matching.
591+
///
592+
/// See the documentation of [`PhysicalExpr::snapshot`] for more details.
593+
///
594+
/// # Returns
595+
///
596+
/// Returns a `[`Transformed`] indicating whether a snapshot was taken,
597+
/// along with the resulting `PhysicalExpr`.
598+
pub fn snapshot_physical_expr_opt(
599+
expr: Arc<dyn PhysicalExpr>,
600+
) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
582601
expr.transform_up(|e| {
583602
if let Some(snapshot) = e.snapshot()? {
584603
Ok(Transformed::yes(snapshot))
585604
} else {
586605
Ok(Transformed::no(Arc::clone(&e)))
587606
}
588607
})
589-
.data()
590608
}
591609

592610
/// Check the generation of this `PhysicalExpr`.

datafusion/physical-expr/benches/in_list.rs

Lines changed: 114 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,21 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow::array::{Array, ArrayRef, Float32Array, Int32Array, StringArray};
18+
use arrow::array::{
19+
Array, ArrayRef, Float32Array, Int32Array, StringArray, StringViewArray,
20+
};
1921
use arrow::datatypes::{Field, Schema};
2022
use arrow::record_batch::RecordBatch;
2123
use criterion::{Criterion, criterion_group, criterion_main};
2224
use datafusion_common::ScalarValue;
2325
use datafusion_physical_expr::expressions::{col, in_list, lit};
2426
use rand::distr::Alphanumeric;
2527
use rand::prelude::*;
28+
use std::any::TypeId;
2629
use std::hint::black_box;
2730
use std::sync::Arc;
2831

32+
/// Measures how long `in_list(col("a"), exprs)` takes to evaluate against a single RecordBatch.
2933
fn do_bench(c: &mut Criterion, name: &str, values: ArrayRef, exprs: &[ScalarValue]) {
3034
let schema = Schema::new(vec![Field::new("a", values.data_type().clone(), true)]);
3135
let exprs = exprs.iter().map(|s| lit(s.clone())).collect();
@@ -37,79 +41,128 @@ fn do_bench(c: &mut Criterion, name: &str, values: ArrayRef, exprs: &[ScalarValu
3741
});
3842
}
3943

44+
/// Generates a random alphanumeric string of the specified length.
4045
fn random_string(rng: &mut StdRng, len: usize) -> String {
4146
let value = rng.sample_iter(&Alphanumeric).take(len).collect();
4247
String::from_utf8(value).unwrap()
4348
}
4449

45-
fn do_benches(
46-
c: &mut Criterion,
47-
array_length: usize,
48-
in_list_length: usize,
49-
null_percent: f64,
50-
) {
51-
let mut rng = StdRng::seed_from_u64(120320);
52-
let non_null_percent = 1.0 - null_percent;
53-
54-
for string_length in [5, 10, 20] {
55-
let values: StringArray = (0..array_length)
56-
.map(|_| {
57-
rng.random_bool(non_null_percent)
58-
.then(|| random_string(&mut rng, string_length))
59-
})
60-
.collect();
61-
62-
let in_list: Vec<_> = (0..in_list_length)
63-
.map(|_| ScalarValue::from(random_string(&mut rng, string_length)))
64-
.collect();
65-
66-
do_bench(
67-
c,
68-
&format!(
69-
"in_list_utf8({string_length}) ({array_length}, {null_percent}) IN ({in_list_length}, 0)"
70-
),
71-
Arc::new(values),
72-
&in_list,
73-
)
50+
const IN_LIST_LENGTHS: [usize; 3] = [3, 8, 100];
51+
const NULL_PERCENTS: [f64; 2] = [0., 0.2];
52+
const STRING_LENGTHS: [usize; 3] = [3, 12, 100];
53+
const ARRAY_LENGTH: usize = 1024;
54+
55+
/// Returns a friendly type name for the array type.
56+
fn array_type_name<A: 'static>() -> &'static str {
57+
let id = TypeId::of::<A>();
58+
if id == TypeId::of::<StringArray>() {
59+
"Utf8"
60+
} else if id == TypeId::of::<StringViewArray>() {
61+
"Utf8View"
62+
} else if id == TypeId::of::<Float32Array>() {
63+
"Float32"
64+
} else if id == TypeId::of::<Int32Array>() {
65+
"Int32"
66+
} else {
67+
"Unknown"
7468
}
69+
}
7570

76-
let values: Float32Array = (0..array_length)
77-
.map(|_| rng.random_bool(non_null_percent).then(|| rng.random()))
78-
.collect();
71+
/// Builds a benchmark name from array type, list size, and null percentage.
72+
fn bench_name<A: 'static>(in_list_length: usize, null_percent: f64) -> String {
73+
format!(
74+
"in_list/{}/list={in_list_length}/nulls={}%",
75+
array_type_name::<A>(),
76+
(null_percent * 100.0) as u32
77+
)
78+
}
7979

80-
let in_list: Vec<_> = (0..in_list_length)
81-
.map(|_| ScalarValue::Float32(Some(rng.random())))
82-
.collect();
80+
/// Runs in_list benchmarks for a string array type across all list-size × null-ratio × string-length combinations.
81+
fn bench_string_type<A>(
82+
c: &mut Criterion,
83+
rng: &mut StdRng,
84+
make_scalar: fn(String) -> ScalarValue,
85+
) where
86+
A: Array + FromIterator<Option<String>> + 'static,
87+
{
88+
for in_list_length in IN_LIST_LENGTHS {
89+
for null_percent in NULL_PERCENTS {
90+
for string_length in STRING_LENGTHS {
91+
let values: A = (0..ARRAY_LENGTH)
92+
.map(|_| {
93+
rng.random_bool(1.0 - null_percent)
94+
.then(|| random_string(rng, string_length))
95+
})
96+
.collect();
97+
98+
let in_list: Vec<_> = (0..in_list_length)
99+
.map(|_| make_scalar(random_string(rng, string_length)))
100+
.collect();
101+
102+
do_bench(
103+
c,
104+
&format!(
105+
"{}/str={string_length}",
106+
bench_name::<A>(in_list_length, null_percent)
107+
),
108+
Arc::new(values),
109+
&in_list,
110+
)
111+
}
112+
}
113+
}
114+
}
83115

84-
do_bench(
85-
c,
86-
&format!("in_list_f32 ({array_length}, {null_percent}) IN ({in_list_length}, 0)"),
87-
Arc::new(values),
88-
&in_list,
89-
);
116+
/// Runs in_list benchmarks for a numeric array type across all list-size × null-ratio combinations.
117+
fn bench_numeric_type<T, A>(
118+
c: &mut Criterion,
119+
rng: &mut StdRng,
120+
mut gen_value: impl FnMut(&mut StdRng) -> T,
121+
make_scalar: fn(T) -> ScalarValue,
122+
) where
123+
A: Array + FromIterator<Option<T>> + 'static,
124+
{
125+
for in_list_length in IN_LIST_LENGTHS {
126+
for null_percent in NULL_PERCENTS {
127+
let values: A = (0..ARRAY_LENGTH)
128+
.map(|_| rng.random_bool(1.0 - null_percent).then(|| gen_value(rng)))
129+
.collect();
130+
131+
let in_list: Vec<_> = (0..in_list_length)
132+
.map(|_| make_scalar(gen_value(rng)))
133+
.collect();
134+
135+
do_bench(
136+
c,
137+
&bench_name::<A>(in_list_length, null_percent),
138+
Arc::new(values),
139+
&in_list,
140+
);
141+
}
142+
}
143+
}
90144

91-
let values: Int32Array = (0..array_length)
92-
.map(|_| rng.random_bool(non_null_percent).then(|| rng.random()))
93-
.collect();
145+
/// Entry point: registers in_list benchmarks for Utf8, Utf8View, Float32, and Int32 arrays.
146+
fn criterion_benchmark(c: &mut Criterion) {
147+
let mut rng = StdRng::seed_from_u64(120320);
94148

95-
let in_list: Vec<_> = (0..in_list_length)
96-
.map(|_| ScalarValue::Int32(Some(rng.random())))
97-
.collect();
149+
// Benchmarks for string array types (Utf8, Utf8View)
150+
bench_string_type::<StringArray>(c, &mut rng, |s| ScalarValue::Utf8(Some(s)));
151+
bench_string_type::<StringViewArray>(c, &mut rng, |s| ScalarValue::Utf8View(Some(s)));
98152

99-
do_bench(
153+
// Benchmarks for numeric types
154+
bench_numeric_type::<f32, Float32Array>(
100155
c,
101-
&format!("in_list_i32 ({array_length}, {null_percent}) IN ({in_list_length}, 0)"),
102-
Arc::new(values),
103-
&in_list,
104-
)
105-
}
106-
107-
fn criterion_benchmark(c: &mut Criterion) {
108-
for in_list_length in [1, 3, 10, 100] {
109-
for null_percent in [0., 0.2] {
110-
do_benches(c, 1024, in_list_length, null_percent)
111-
}
112-
}
156+
&mut rng,
157+
|rng| rng.random(),
158+
|v| ScalarValue::Float32(Some(v)),
159+
);
160+
bench_numeric_type::<i32, Int32Array>(
161+
c,
162+
&mut rng,
163+
|rng| rng.random(),
164+
|v| ScalarValue::Int32(Some(v)),
165+
);
113166
}
114167

115168
criterion_group!(benches, criterion_benchmark);

0 commit comments

Comments
 (0)