Skip to content

Commit 819bdae

Browse files
authored
Fix TIME processing (#236)
* Fix TIME processing. Instead of assuming we want `Time64Microsecond` when converting from `NaiveTime`, convert based on the inferred type. Update the `get_time*_value` functions to return `NaiveTime` instead of `NaiveDateTime`. Previously they were returning None, because that's what `as_datetime` always returns for Time32 & Time64 values. * Add tests for various get time functions. Add tests for the `get_time*_value` functions. * Add test for deserialising time parameters. Add test for deserialising time parameters, and fix an implementation bug found by the test.
1 parent feac76a commit 819bdae

File tree

2 files changed

+144
-16
lines changed

2 files changed

+144
-16
lines changed

arrow-pg/src/datatypes/df.rs

Lines changed: 94 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::iter;
22
use std::sync::Arc;
33

44
use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, Timelike};
5-
use datafusion::arrow::datatypes::{DataType, Date32Type};
5+
use datafusion::arrow::datatypes::{DataType, Date32Type, TimeUnit};
66
use datafusion::arrow::record_batch::RecordBatch;
77
use datafusion::common::ParamValues;
88
use datafusion::prelude::*;
@@ -79,10 +79,8 @@ where
7979
let param_len = portal.parameter_len();
8080
let mut deserialized_params = Vec::with_capacity(param_len);
8181
for i in 0..param_len {
82-
let pg_type = get_pg_type(
83-
portal.statement.parameter_types.get(i),
84-
inferenced_types.get(i).and_then(|v| v.to_owned()),
85-
)?;
82+
let inferenced_type = inferenced_types.get(i).and_then(|v| v.to_owned());
83+
let pg_type = get_pg_type(portal.statement.parameter_types.get(i), inferenced_type)?;
8684
match pg_type {
8785
// enumerate all supported parameter types and deserialize the
8886
// type to ScalarValue
@@ -158,9 +156,36 @@ where
158156
}
159157
Type::TIME => {
160158
let value = portal.parameter::<NaiveTime>(i, &pg_type)?;
161-
deserialized_params.push(ScalarValue::Time64Microsecond(value.map(|t| {
162-
t.num_seconds_from_midnight() as i64 * 1_000_000 + t.nanosecond() as i64 / 1_000
163-
})));
159+
160+
let ns = value.map(|t| {
161+
t.num_seconds_from_midnight() as i64 * 1_000_000_000 + t.nanosecond() as i64
162+
});
163+
164+
let scalar_value = match inferenced_type {
165+
Some(DataType::Time64(TimeUnit::Nanosecond)) => {
166+
ScalarValue::Time64Nanosecond(ns)
167+
}
168+
Some(DataType::Time64(TimeUnit::Microsecond)) => {
169+
ScalarValue::Time64Microsecond(ns.map(|ns| (ns / 1_000) as _))
170+
}
171+
Some(DataType::Time32(TimeUnit::Millisecond)) => {
172+
ScalarValue::Time32Millisecond(ns.map(|ns| (ns / 1_000_000) as _))
173+
}
174+
Some(DataType::Time32(TimeUnit::Second)) => {
175+
ScalarValue::Time32Second(ns.map(|ns| (ns / 1_000_000_000) as _))
176+
}
177+
_ => {
178+
return Err(PgWireError::ApiError(
179+
format!(
180+
"Unable to deserialise time parameter type {:?} to type {:?}",
181+
value, inferenced_type
182+
)
183+
.into(),
184+
))
185+
}
186+
};
187+
188+
deserialized_params.push(scalar_value);
164189
}
165190
Type::UUID => {
166191
let value = portal.parameter::<String>(i, &pg_type)?;
@@ -294,3 +319,64 @@ where
294319

295320
Ok(ParamValues::List(deserialized_params))
296321
}
322+
323+
#[cfg(test)]
324+
mod tests {
325+
use std::sync::Arc;
326+
327+
use arrow::datatypes::DataType;
328+
use bytes::Bytes;
329+
use datafusion::{common::ParamValues, scalar::ScalarValue};
330+
use pgwire::{
331+
api::{portal::Portal, stmt::StoredStatement},
332+
messages::extendedquery::Bind,
333+
};
334+
use postgres_types::Type;
335+
336+
use crate::datatypes::df::deserialize_parameters;
337+
338+
#[test]
339+
fn test_deserialise_time_params() {
340+
let postgres_types = vec![Type::TIME];
341+
342+
let us: i64 = 1_000_000; // 1 second
343+
344+
let bind = Bind::new(
345+
None,
346+
None,
347+
vec![],
348+
vec![Some(Bytes::from(i64::to_be_bytes(us).to_vec()))],
349+
vec![],
350+
);
351+
352+
let stmt = StoredStatement::new("statement_id".into(), "statement", postgres_types);
353+
let portal = Portal::try_new(&bind, Arc::new(stmt)).unwrap();
354+
355+
for (arrow_type, expected) in [
356+
(
357+
DataType::Time32(arrow::datatypes::TimeUnit::Second),
358+
ScalarValue::Time32Second(Some(1)),
359+
),
360+
(
361+
DataType::Time32(arrow::datatypes::TimeUnit::Millisecond),
362+
ScalarValue::Time32Millisecond(Some(1000)),
363+
),
364+
(
365+
DataType::Time64(arrow::datatypes::TimeUnit::Microsecond),
366+
ScalarValue::Time64Microsecond(Some(1000000)),
367+
),
368+
(
369+
DataType::Time64(arrow::datatypes::TimeUnit::Nanosecond),
370+
ScalarValue::Time64Nanosecond(Some(1000000000)),
371+
),
372+
] {
373+
let result = deserialize_parameters(&portal, &[Some(&arrow_type)]).unwrap();
374+
let ParamValues::List(list) = result else {
375+
panic!("expected list");
376+
};
377+
378+
assert_eq!(list.len(), 1);
379+
assert_eq!(list[0], expected)
380+
}
381+
}
382+
}

arrow-pg/src/encoder.rs

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use std::sync::Arc;
77
use arrow::{array::*, datatypes::*};
88
use bytes::BufMut;
99
use bytes::BytesMut;
10+
use chrono::NaiveTime;
1011
use chrono::{NaiveDate, NaiveDateTime};
1112
#[cfg(feature = "datafusion")]
1213
use datafusion::arrow::{array::*, datatypes::*};
@@ -203,43 +204,43 @@ fn get_date64_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDate> {
203204
.value_as_date(idx)
204205
}
205206

206-
fn get_time32_second_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
207+
fn get_time32_second_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveTime> {
207208
if arr.is_null(idx) {
208209
return None;
209210
}
210211
arr.as_any()
211212
.downcast_ref::<Time32SecondArray>()
212213
.unwrap()
213-
.value_as_datetime(idx)
214+
.value_as_time(idx)
214215
}
215216

216-
fn get_time32_millisecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
217+
fn get_time32_millisecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveTime> {
217218
if arr.is_null(idx) {
218219
return None;
219220
}
220221
arr.as_any()
221222
.downcast_ref::<Time32MillisecondArray>()
222223
.unwrap()
223-
.value_as_datetime(idx)
224+
.value_as_time(idx)
224225
}
225226

226-
fn get_time64_microsecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
227+
fn get_time64_microsecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveTime> {
227228
if arr.is_null(idx) {
228229
return None;
229230
}
230231
arr.as_any()
231232
.downcast_ref::<Time64MicrosecondArray>()
232233
.unwrap()
233-
.value_as_datetime(idx)
234+
.value_as_time(idx)
234235
}
235-
fn get_time64_nanosecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
236+
fn get_time64_nanosecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveTime> {
236237
if arr.is_null(idx) {
237238
return None;
238239
}
239240
arr.as_any()
240241
.downcast_ref::<Time64NanosecondArray>()
241242
.unwrap()
242-
.value_as_datetime(idx)
243+
.value_as_time(idx)
243244
}
244245

245246
fn get_numeric_128_value(
@@ -518,4 +519,45 @@ mod tests {
518519

519520
assert!(encoder.encoded_value == val);
520521
}
522+
523+
#[test]
524+
fn test_get_time32_second_value() {
525+
let array = Time32SecondArray::from_iter_values([3723_i32]);
526+
let array: Arc<dyn Array> = Arc::new(array);
527+
let value = get_time32_second_value(&array, 0);
528+
assert_eq!(value, Some(NaiveTime::from_hms_opt(1, 2, 3)).unwrap());
529+
}
530+
531+
#[test]
532+
fn test_get_time32_millisecond_value() {
533+
let array = Time32MillisecondArray::from_iter_values([3723001_i32]);
534+
let array: Arc<dyn Array> = Arc::new(array);
535+
let value = get_time32_millisecond_value(&array, 0);
536+
assert_eq!(
537+
value,
538+
Some(NaiveTime::from_hms_milli_opt(1, 2, 3, 1)).unwrap()
539+
);
540+
}
541+
542+
#[test]
543+
fn test_get_time64_microsecond_value() {
544+
let array = Time64MicrosecondArray::from_iter_values([3723001001_i64]);
545+
let array: Arc<dyn Array> = Arc::new(array);
546+
let value = get_time64_microsecond_value(&array, 0);
547+
assert_eq!(
548+
value,
549+
Some(NaiveTime::from_hms_micro_opt(1, 2, 3, 1001)).unwrap()
550+
);
551+
}
552+
553+
#[test]
554+
fn test_get_time64_nanosecond_value() {
555+
let array = Time64NanosecondArray::from_iter_values([3723001001001_i64]);
556+
let array: Arc<dyn Array> = Arc::new(array);
557+
let value = get_time64_nanosecond_value(&array, 0);
558+
assert_eq!(
559+
value,
560+
Some(NaiveTime::from_hms_nano_opt(1, 2, 3, 1001001)).unwrap()
561+
);
562+
}
521563
}

0 commit comments

Comments
 (0)