Skip to content

Commit 943d149

Browse files
feat(rust/sedona-functions): Add SRID argument to ST_Point() (#275)
Co-authored-by: Dewey Dunnington <dewey@dunnington.ca>
1 parent 54ec899 commit 943d149

File tree

4 files changed

+330
-23
lines changed

4 files changed

+330
-23
lines changed

python/sedonadb/tests/functions/test_functions.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1313,6 +1313,25 @@ def test_st_point(eng, x, y, expected):
13131313
)
13141314

13151315

1316+
@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
@pytest.mark.parametrize(
    ("x", "y", "srid", "expected"),
    [
        (None, None, None, None),
        (1, 1, None, None),
        (1, 1, 0, 0),
        (1, 1, 4326, 4326),
        (1, 1, "4326", 4326),
    ],
)
def test_st_point_with_srid(eng, x, y, srid, expected):
    """Check ST_SRID() of a point built with the three-argument ST_Point(x, y, srid).

    The parametrized cases cover NULL coordinates, a NULL SRID (NULL result),
    SRID 0, and SRID 4326 given both as an integer and as a string.
    """
    engine = eng.create_or_skip()
    query = f"SELECT ST_SRID(ST_Point({val_or_null(x)}, {val_or_null(y)}, {val_or_null(srid)}))"
    engine.assert_query_result(query, expected)
1333+
1334+
13161335
@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
13171336
@pytest.mark.parametrize(
13181337
("x", "y", "z", "expected"),

rust/sedona-functions/src/st_point.rs

Lines changed: 67 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,18 +30,21 @@ use sedona_schema::{
3030
matchers::ArgMatcher,
3131
};
3232

33-
use crate::executor::WkbExecutor;
33+
use crate::{executor::WkbExecutor, st_setsrid::SRIDifiedKernel};
3434

3535
/// ST_Point() scalar UDF implementation
3636
///
3737
/// Native implementation to create geometries from coordinates.
3838
/// See [`st_geogpoint_udf`] for the corresponding geography constructor.
3939
pub fn st_point_udf() -> SedonaScalarUDF {
40+
let kernel = Arc::new(STGeoFromPoint {
41+
out_type: WKB_GEOMETRY,
42+
});
43+
let sridified_kernel = Arc::new(SRIDifiedKernel::new(kernel.clone()));
44+
4045
SedonaScalarUDF::new(
4146
"st_point",
42-
vec![Arc::new(STGeoFromPoint {
43-
out_type: WKB_GEOMETRY,
44-
})],
47+
vec![sridified_kernel, kernel],
4548
Volatility::Immutable,
4649
Some(doc("ST_Point", "Geometry")),
4750
)
@@ -52,11 +55,14 @@ pub fn st_point_udf() -> SedonaScalarUDF {
5255
/// Native implementation to create geometries from coordinates.
5356
/// See [`st_geogpoint_udf`] for the corresponding geography constructor.
5457
pub fn st_geogpoint_udf() -> SedonaScalarUDF {
58+
let kernel = Arc::new(STGeoFromPoint {
59+
out_type: WKB_GEOGRAPHY,
60+
});
61+
let sridified_kernel = Arc::new(SRIDifiedKernel::new(kernel.clone()));
62+
5563
SedonaScalarUDF::new(
5664
"st_geogpoint",
57-
vec![Arc::new(STGeoFromPoint {
58-
out_type: WKB_GEOGRAPHY,
59-
})],
65+
vec![sridified_kernel, kernel],
6066
Volatility::Immutable,
6167
Some(doc("st_geogpoint", "Geography")),
6268
)
@@ -73,6 +79,7 @@ fn doc(name: &str, out_type_name: &str) -> Documentation {
7379
)
7480
.with_argument("x", "double: X value")
7581
.with_argument("y", "double: Y value")
82+
.with_argument("srid", "srid: EPSG code to set (e.g., 4326)")
7683
.with_sql_example(format!("{name}(-64.36, 45.09)"))
7784
.build()
7885
}
@@ -157,8 +164,11 @@ mod tests {
157164
use arrow_array::create_array;
158165
use arrow_array::ArrayRef;
159166
use arrow_schema::DataType;
167+
use datafusion_expr::Literal;
160168
use datafusion_expr::ScalarUDF;
161169
use rstest::rstest;
170+
use sedona_schema::crs::lnglat;
171+
use sedona_schema::datatypes::Edges;
162172
use sedona_testing::compare::assert_array_equal;
163173
use sedona_testing::{create::create_array, testers::ScalarUdfTester};
164174

@@ -247,6 +257,56 @@ mod tests {
247257
);
248258
}
249259

260+
#[rstest]
261+
#[case(DataType::UInt32, 4326)]
262+
#[case(DataType::Int32, 4326)]
263+
#[case(DataType::Utf8, "4326")]
264+
#[case(DataType::Utf8, "EPSG:4326")]
265+
fn udf_invoke_with_srid(#[case] srid_type: DataType, #[case] srid_value: impl Literal + Copy) {
266+
let udf = st_point_udf();
267+
let tester = ScalarUdfTester::new(
268+
udf.into(),
269+
vec![
270+
SedonaType::Arrow(DataType::Float64),
271+
SedonaType::Arrow(DataType::Float64),
272+
SedonaType::Arrow(srid_type),
273+
],
274+
);
275+
276+
let return_type = tester
277+
.return_type_with_scalar_scalar_scalar(Some(1.0), Some(2.0), Some(srid_value))
278+
.unwrap();
279+
assert_eq!(return_type, SedonaType::Wkb(Edges::Planar, lnglat()));
280+
281+
let result = tester
282+
.invoke_scalar_scalar_scalar(1.0, 2.0, srid_value)
283+
.unwrap();
284+
tester.assert_scalar_result_equals_with_return_type(result, "POINT (1 2)", return_type);
285+
}
286+
287+
#[test]
288+
fn udf_invoke_with_invalid_srid() {
289+
let udf = st_point_udf();
290+
let tester = ScalarUdfTester::new(
291+
udf.into(),
292+
vec![
293+
SedonaType::Arrow(DataType::Float64),
294+
SedonaType::Arrow(DataType::Float64),
295+
SedonaType::Arrow(DataType::Utf8),
296+
],
297+
);
298+
299+
let return_type = tester.return_type_with_scalar_scalar_scalar(
300+
Some(1.0),
301+
Some(2.0),
302+
Some("gazornenplat"),
303+
);
304+
assert!(return_type.is_err());
305+
306+
let result = tester.invoke_scalar_scalar_scalar(1.0, 2.0, "gazornenplat");
307+
assert!(result.is_err());
308+
}
309+
250310
#[test]
251311
fn geog() {
252312
let udf = st_geogpoint_udf();

rust/sedona-functions/src/st_setsrid.rs

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,14 @@
1616
// under the License.
1717
use std::{sync::Arc, vec};
1818

19+
use arrow_array::builder::BinaryBuilder;
1920
use arrow_schema::DataType;
2021
use datafusion_common::{error::Result, DataFusionError, ScalarValue};
2122
use datafusion_expr::{
2223
scalar_doc_sections::DOC_SECTION_OTHER, ColumnarValue, Documentation, Volatility,
2324
};
2425
use sedona_common::sedona_internal_err;
25-
use sedona_expr::scalar_udf::{SedonaScalarKernel, SedonaScalarUDF};
26+
use sedona_expr::scalar_udf::{ScalarKernelRef, SedonaScalarKernel, SedonaScalarUDF};
2627
use sedona_geometry::transform::CrsEngine;
2728
use sedona_schema::{crs::deserialize_crs, datatypes::SedonaType, matchers::ArgMatcher};
2829

@@ -227,6 +228,119 @@ fn determine_return_type(
227228
sedona_internal_err!("Unexpected argument types: {}, {}", args[0], args[1])
228229
}
229230

231+
/// [SedonaScalarKernel] wrapper that handles the SRID argument for constructors like ST_Point
232+
#[derive(Debug)]
233+
pub(crate) struct SRIDifiedKernel {
234+
inner: ScalarKernelRef,
235+
}
236+
237+
impl SRIDifiedKernel {
238+
pub(crate) fn new(inner: ScalarKernelRef) -> Self {
239+
Self { inner }
240+
}
241+
}
242+
243+
impl SedonaScalarKernel for SRIDifiedKernel {
244+
fn return_type_from_args_and_scalars(
245+
&self,
246+
args: &[SedonaType],
247+
scalar_args: &[Option<&ScalarValue>],
248+
) -> Result<Option<SedonaType>> {
249+
// args should consist of the original args and one extra arg for
250+
// specifying CRS. So, first, validate the length and separate these.
251+
//
252+
// [arg0, arg1, ..., crs_arg];
253+
// ^^^^^^^^^^^^^^^
254+
// orig_args
255+
let orig_args_len = match (args.len(), scalar_args.len()) {
256+
(0, 0) => return Ok(None),
257+
(l1, l2) if l1 == l2 => l1 - 1,
258+
_ => return sedona_internal_err!("Arg types and arg values have different lengths"),
259+
};
260+
261+
let orig_args = &args[..orig_args_len];
262+
let orig_scalar_args = &scalar_args[..orig_args_len];
263+
264+
// Invoke the original return_type_from_args_and_scalars() first before checking the CRS argument
265+
let mut inner_result = match self
266+
.inner
267+
.return_type_from_args_and_scalars(orig_args, orig_scalar_args)?
268+
{
269+
Some(sedona_type) => sedona_type,
270+
// if no match, quit here. Since the CRS arg is also an unintended
271+
// one, validating it would be a cryptic error to the user.
272+
None => return Ok(None),
273+
};
274+
275+
let crs = match scalar_args[orig_args_len] {
276+
Some(crs) => crs,
277+
None => return Ok(None),
278+
};
279+
let new_crs = match crs.cast_to(&DataType::Utf8) {
280+
Ok(ScalarValue::Utf8(Some(crs))) => {
281+
if crs == "0" {
282+
None
283+
} else {
284+
validate_crs(&crs, None)?;
285+
deserialize_crs(&serde_json::Value::String(crs))?
286+
}
287+
}
288+
Ok(ScalarValue::Utf8(None)) => None,
289+
Ok(_) | Err(_) => return sedona_internal_err!("Can't cast Crs {crs:?} to Utf8"),
290+
};
291+
292+
match &mut inner_result {
293+
SedonaType::Wkb(_, crs) => *crs = new_crs,
294+
SedonaType::WkbView(_, crs) => *crs = new_crs,
295+
_ => {
296+
return sedona_internal_err!("Return type must be Wkb or WkbView");
297+
}
298+
}
299+
300+
Ok(Some(inner_result))
301+
}
302+
303+
fn invoke_batch(
304+
&self,
305+
arg_types: &[SedonaType],
306+
args: &[ColumnarValue],
307+
) -> Result<ColumnarValue> {
308+
let orig_args_len = arg_types.len() - 1;
309+
let orig_arg_types = &arg_types[..orig_args_len];
310+
let orig_args = &args[..orig_args_len];
311+
312+
// Invoke the inner UDF first to propagate any errors even when the CRS is NULL.
313+
// Note that, this behavior is different from PostGIS.
314+
let result = self.inner.invoke_batch(orig_arg_types, orig_args)?;
315+
316+
// If the specified SRID is NULL, the result is also NULL.
317+
if let ColumnarValue::Scalar(sc) = &args[orig_args_len] {
318+
if sc.is_null() {
319+
// Create the same length of NULLs as the original result.
320+
let len = match &result {
321+
ColumnarValue::Array(array) => array.len(),
322+
ColumnarValue::Scalar(_) => 1,
323+
};
324+
325+
let mut builder = BinaryBuilder::with_capacity(len, 0);
326+
for _ in 0..len {
327+
builder.append_null();
328+
}
329+
let new_array = builder.finish();
330+
return Ok(ColumnarValue::Array(Arc::new(new_array)));
331+
}
332+
}
333+
334+
Ok(result)
335+
}
336+
337+
fn return_type(&self, _args: &[SedonaType]) -> Result<Option<SedonaType>> {
338+
sedona_internal_err!(
339+
"Should not be called because return_type_from_args_and_scalars() is implemented"
340+
)
341+
}
342+
}
343+
230344
#[cfg(test)]
231345
mod test {
232346
use std::rc::Rc;

0 commit comments

Comments
 (0)