|
16 | 16 | // under the License. |
17 | 17 | use std::{sync::Arc, vec}; |
18 | 18 |
|
| 19 | +use arrow_array::builder::BinaryBuilder; |
19 | 20 | use arrow_schema::DataType; |
20 | 21 | use datafusion_common::{error::Result, DataFusionError, ScalarValue}; |
21 | 22 | use datafusion_expr::{ |
22 | 23 | scalar_doc_sections::DOC_SECTION_OTHER, ColumnarValue, Documentation, Volatility, |
23 | 24 | }; |
24 | 25 | use sedona_common::sedona_internal_err; |
25 | | -use sedona_expr::scalar_udf::{SedonaScalarKernel, SedonaScalarUDF}; |
| 26 | +use sedona_expr::scalar_udf::{ScalarKernelRef, SedonaScalarKernel, SedonaScalarUDF}; |
26 | 27 | use sedona_geometry::transform::CrsEngine; |
27 | 28 | use sedona_schema::{crs::deserialize_crs, datatypes::SedonaType, matchers::ArgMatcher}; |
28 | 29 |
|
@@ -227,6 +228,119 @@ fn determine_return_type( |
227 | 228 | sedona_internal_err!("Unexpected argument types: {}, {}", args[0], args[1]) |
228 | 229 | } |
229 | 230 |
|
| 231 | +/// [SedonaScalarKernel] wrapper that handles the SRID argument for constructors like ST_Point |
| 232 | +#[derive(Debug)] |
| 233 | +pub(crate) struct SRIDifiedKernel { |
| 234 | + inner: ScalarKernelRef, |
| 235 | +} |
| 236 | + |
| 237 | +impl SRIDifiedKernel { |
| 238 | + pub(crate) fn new(inner: ScalarKernelRef) -> Self { |
| 239 | + Self { inner } |
| 240 | + } |
| 241 | +} |
| 242 | + |
| 243 | +impl SedonaScalarKernel for SRIDifiedKernel { |
| 244 | + fn return_type_from_args_and_scalars( |
| 245 | + &self, |
| 246 | + args: &[SedonaType], |
| 247 | + scalar_args: &[Option<&ScalarValue>], |
| 248 | + ) -> Result<Option<SedonaType>> { |
| 249 | + // args should consist of the original args and one extra arg for |
| 250 | + // specifying CRS. So, first, validate the length and separate these. |
| 251 | + // |
| 252 | + // [arg0, arg1, ..., crs_arg]; |
| 253 | + // ^^^^^^^^^^^^^^^ |
| 254 | + // orig_args |
| 255 | + let orig_args_len = match (args.len(), scalar_args.len()) { |
| 256 | + (0, 0) => return Ok(None), |
| 257 | + (l1, l2) if l1 == l2 => l1 - 1, |
| 258 | + _ => return sedona_internal_err!("Arg types and arg values have different lengths"), |
| 259 | + }; |
| 260 | + |
| 261 | + let orig_args = &args[..orig_args_len]; |
| 262 | + let orig_scalar_args = &scalar_args[..orig_args_len]; |
| 263 | + |
| 264 | + // Invoke the original return_type_from_args_and_scalars() first before checking the CRS argument |
| 265 | + let mut inner_result = match self |
| 266 | + .inner |
| 267 | + .return_type_from_args_and_scalars(orig_args, orig_scalar_args)? |
| 268 | + { |
| 269 | + Some(sedona_type) => sedona_type, |
| 270 | + // if no match, quit here. Since the CRS arg is also an unintended |
| 271 | + // one, validating it would be a cryptic error to the user. |
| 272 | + None => return Ok(None), |
| 273 | + }; |
| 274 | + |
| 275 | + let crs = match scalar_args[orig_args_len] { |
| 276 | + Some(crs) => crs, |
| 277 | + None => return Ok(None), |
| 278 | + }; |
| 279 | + let new_crs = match crs.cast_to(&DataType::Utf8) { |
| 280 | + Ok(ScalarValue::Utf8(Some(crs))) => { |
| 281 | + if crs == "0" { |
| 282 | + None |
| 283 | + } else { |
| 284 | + validate_crs(&crs, None)?; |
| 285 | + deserialize_crs(&serde_json::Value::String(crs))? |
| 286 | + } |
| 287 | + } |
| 288 | + Ok(ScalarValue::Utf8(None)) => None, |
| 289 | + Ok(_) | Err(_) => return sedona_internal_err!("Can't cast Crs {crs:?} to Utf8"), |
| 290 | + }; |
| 291 | + |
| 292 | + match &mut inner_result { |
| 293 | + SedonaType::Wkb(_, crs) => *crs = new_crs, |
| 294 | + SedonaType::WkbView(_, crs) => *crs = new_crs, |
| 295 | + _ => { |
| 296 | + return sedona_internal_err!("Return type must be Wkb or WkbView"); |
| 297 | + } |
| 298 | + } |
| 299 | + |
| 300 | + Ok(Some(inner_result)) |
| 301 | + } |
| 302 | + |
| 303 | + fn invoke_batch( |
| 304 | + &self, |
| 305 | + arg_types: &[SedonaType], |
| 306 | + args: &[ColumnarValue], |
| 307 | + ) -> Result<ColumnarValue> { |
| 308 | + let orig_args_len = arg_types.len() - 1; |
| 309 | + let orig_arg_types = &arg_types[..orig_args_len]; |
| 310 | + let orig_args = &args[..orig_args_len]; |
| 311 | + |
| 312 | + // Invoke the inner UDF first to propagate any errors even when the CRS is NULL. |
| 313 | + // Note that, this behavior is different from PostGIS. |
| 314 | + let result = self.inner.invoke_batch(orig_arg_types, orig_args)?; |
| 315 | + |
| 316 | + // If the specified SRID is NULL, the result is also NULL. |
| 317 | + if let ColumnarValue::Scalar(sc) = &args[orig_args_len] { |
| 318 | + if sc.is_null() { |
| 319 | + // Create the same length of NULLs as the original result. |
| 320 | + let len = match &result { |
| 321 | + ColumnarValue::Array(array) => array.len(), |
| 322 | + ColumnarValue::Scalar(_) => 1, |
| 323 | + }; |
| 324 | + |
| 325 | + let mut builder = BinaryBuilder::with_capacity(len, 0); |
| 326 | + for _ in 0..len { |
| 327 | + builder.append_null(); |
| 328 | + } |
| 329 | + let new_array = builder.finish(); |
| 330 | + return Ok(ColumnarValue::Array(Arc::new(new_array))); |
| 331 | + } |
| 332 | + } |
| 333 | + |
| 334 | + Ok(result) |
| 335 | + } |
| 336 | + |
| 337 | + fn return_type(&self, _args: &[SedonaType]) -> Result<Option<SedonaType>> { |
| 338 | + sedona_internal_err!( |
| 339 | + "Should not be called because return_type_from_args_and_scalars() is implemented" |
| 340 | + ) |
| 341 | + } |
| 342 | +} |
| 343 | + |
230 | 344 | #[cfg(test)] |
231 | 345 | mod test { |
232 | 346 | use std::rc::Rc; |
|
0 commit comments