Skip to content

Commit 6081853

Browse files
authored
Merge pull request #217 from influxdata/crepererum/check-array-size
feat: check return row count of UDF
2 parents 9caf8b4 + 20f9eda commit 6081853

File tree

5 files changed

+105
-1
lines changed

5 files changed

+105
-1
lines changed

guests/evil/src/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ mod common;
1111
mod env;
1212
mod fs;
1313
mod net;
14+
mod return_data;
1415
mod root;
1516
mod runtime;
1617
mod spin;
@@ -46,6 +47,10 @@ impl Evil {
4647
root: Box::new(common::root_empty),
4748
udfs: Box::new(net::udfs),
4849
},
50+
"return_data" => Self {
51+
root: Box::new(common::root_empty),
52+
udfs: Box::new(return_data::udfs),
53+
},
4954
"root::invalid_entry" => Self {
5055
root: Box::new(root::invalid_entry::root),
5156
udfs: Box::new(common::udfs_empty),

guests/evil/src/return_data.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
//! Payload that returns invalid data.
2+
3+
use std::sync::Arc;
4+
5+
use arrow::{array::StringArray, datatypes::DataType};
6+
use datafusion_common::error::Result as DataFusionResult;
7+
use datafusion_expr::{
8+
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
9+
};
10+
11+
/// UDF that returns the wrong number of rows.
12+
#[derive(Debug, PartialEq, Eq, Hash)]
13+
struct WrongNumberOfRows;
14+
15+
impl ScalarUDFImpl for WrongNumberOfRows {
16+
fn as_any(&self) -> &dyn std::any::Any {
17+
self
18+
}
19+
20+
fn name(&self) -> &str {
21+
"wrong-number-of-rows"
22+
}
23+
24+
fn signature(&self) -> &Signature {
25+
static S: Signature = Signature {
26+
type_signature: TypeSignature::Uniform(0, vec![]),
27+
volatility: Volatility::Immutable,
28+
};
29+
30+
&S
31+
}
32+
33+
fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult<DataType> {
34+
Ok(DataType::Utf8)
35+
}
36+
37+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> {
38+
Ok(ColumnarValue::Array(Arc::new(
39+
(0..=args.number_rows)
40+
.map(|idx| Some(idx.to_string()))
41+
.collect::<StringArray>(),
42+
)))
43+
}
44+
}
45+
46+
/// Returns our evil UDFs.
47+
///
48+
/// The passed `source` is ignored.
49+
#[expect(clippy::unnecessary_wraps, reason = "public API through export! macro")]
50+
pub(crate) fn udfs(_source: String) -> DataFusionResult<Vec<Arc<dyn ScalarUDFImpl>>> {
51+
Ok(vec![Arc::new(WrongNumberOfRows)])
52+
}

host/src/lib.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -797,6 +797,20 @@ impl AsyncScalarUDFImpl for WasmScalarUdf {
797797

798798
drop(store_guard);
799799

800-
return_type.checked_into_root(&self.trusted_data_limits)
800+
match return_type.checked_into_root(&self.trusted_data_limits) {
801+
Ok(ColumnarValue::Scalar(scalar)) => Ok(ColumnarValue::Scalar(scalar)),
802+
Ok(ColumnarValue::Array(array)) if array.len() as u64 != args.number_rows => {
803+
Err(DataFusionError::External(
804+
format!(
805+
"UDF returned array of length {} but should produce {} rows",
806+
array.len(),
807+
args.number_rows
808+
)
809+
.into(),
810+
))
811+
}
812+
Ok(ColumnarValue::Array(array)) => Ok(ColumnarValue::Array(array)),
813+
Err(e) => Err(e),
814+
}
801815
}
802816
}

host/tests/integration_tests/evil/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
mod env;
22
mod fs;
33
mod net;
4+
mod return_data;
45
mod root;
56
mod runtime;
67
mod spin;
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
use std::sync::Arc;
2+
3+
use arrow::datatypes::{DataType, Field};
4+
use datafusion_common::config::ConfigOptions;
5+
use datafusion_expr::{ScalarFunctionArgs, async_udf::AsyncScalarUDFImpl};
6+
7+
use crate::integration_tests::evil::test_utils::try_scalar_udfs;
8+
9+
#[tokio::test]
10+
async fn test_wrong_number_of_rows() {
11+
let [udf] = try_scalar_udfs("return_data")
12+
.await
13+
.unwrap()
14+
.try_into()
15+
.unwrap();
16+
17+
let err = udf
18+
.invoke_async_with_args(ScalarFunctionArgs {
19+
args: vec![],
20+
arg_fields: vec![],
21+
number_rows: 42,
22+
return_field: Arc::new(Field::new("r", DataType::Utf8, true)),
23+
config_options: Arc::new(ConfigOptions::default()),
24+
})
25+
.await
26+
.unwrap_err();
27+
28+
insta::assert_snapshot!(
29+
err,
30+
@"External error: UDF returned array of length 43 but should produce 42 rows",
31+
);
32+
}

0 commit comments

Comments
 (0)