From 5324ad2227e1c830231fdf7d7026fa356e429750 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Wed, 29 Oct 2025 04:45:52 +0800 Subject: [PATCH 1/2] feat: add support for postgres variables of formatting --- datafusion-postgres/src/hooks/set_show.rs | 196 +++++++++++++++++----- 1 file changed, 154 insertions(+), 42 deletions(-) diff --git a/datafusion-postgres/src/hooks/set_show.rs b/datafusion-postgres/src/hooks/set_show.rs index 92fb3f2..0568f9c 100644 --- a/datafusion-postgres/src/hooks/set_show.rs +++ b/datafusion-postgres/src/hooks/set_show.rs @@ -3,9 +3,10 @@ use std::sync::Arc; use async_trait::async_trait; use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::common::{ParamValues, ToDFSchema}; +use datafusion::error::DataFusionError; use datafusion::logical_expr::LogicalPlan; use datafusion::prelude::SessionContext; -use datafusion::sql::sqlparser::ast::{Set, Statement}; +use datafusion::sql::sqlparser::ast::{Expr, Set, Statement}; use log::{info, warn}; use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag}; use pgwire::api::ClientInfo; @@ -134,39 +135,53 @@ where hivevar: false, variable, values, - } if &variable.to_string() == "statement_timeout" => { - let value = values[0].to_string(); - let timeout_str = value.trim_matches('"').trim_matches('\''); - - let timeout = if timeout_str == "0" || timeout_str.is_empty() { - None - } else { - // Parse timeout value (supports ms, s, min formats) - let timeout_ms = if timeout_str.ends_with("ms") { - timeout_str.trim_end_matches("ms").parse::() - } else if timeout_str.ends_with("s") { - timeout_str - .trim_end_matches("s") - .parse::() - .map(|s| s * 1000) - } else if timeout_str.ends_with("min") { - timeout_str - .trim_end_matches("min") - .parse::() - .map(|m| m * 60 * 1000) + } => { + let var = variable.to_string().to_lowercase(); + if var == "statement_timeout" { + let value = values[0].to_string(); + let timeout_str = 
value.trim_matches('"').trim_matches('\''); + + let timeout = if timeout_str == "0" || timeout_str.is_empty() { + None } else { - // Default to milliseconds - timeout_str.parse::() + // Parse timeout value (supports ms, s, min formats) + let timeout_ms = if timeout_str.ends_with("ms") { + timeout_str.trim_end_matches("ms").parse::() + } else if timeout_str.ends_with("s") { + timeout_str + .trim_end_matches("s") + .parse::() + .map(|s| s * 1000) + } else if timeout_str.ends_with("min") { + timeout_str + .trim_end_matches("min") + .parse::() + .map(|m| m * 60 * 1000) + } else { + // Default to milliseconds + timeout_str.parse::() + }; + + match timeout_ms { + Ok(ms) if ms > 0 => Some(std::time::Duration::from_millis(ms)), + _ => None, + } }; - match timeout_ms { - Ok(ms) if ms > 0 => Some(std::time::Duration::from_millis(ms)), - _ => None, + client::set_statement_timeout(client, timeout); + return Some(Ok(Response::Execution(Tag::new("SET")))); + } else if matches!(var.as_str(), "datestyle" | "bytea_output" | "intervalstyle") { + if values.len() > 0 { + // postgres configuration variables + let value = values[0].clone(); + if let Expr::Value(value) = value { + client + .metadata_mut() + .insert(var, value.into_string().unwrap_or_else(|| "".to_string())); + return Some(Ok(Response::Execution(Tag::new("SET")))); + } } - }; - - client::set_statement_timeout(client, timeout); - Some(Ok(Response::Execution(Tag::new("SET")))) + } } Set::SetTimeZone { local: false, @@ -175,19 +190,39 @@ where let tz = value.to_string(); let tz = tz.trim_matches('"').trim_matches('\''); client::set_timezone(client, Some(tz)); - Some(Ok(Response::Execution(Tag::new("SET")))) + return Some(Ok(Response::Execution(Tag::new("SET")))); } - _ => { - // pass SET query to datafusion - let query = statement.to_string(); - if let Err(e) = session_context.sql(&query).await { - warn!("SET statement {query} is not supported by datafusion, error {e}, statement ignored"); - } + _ => {} + } - // Always 
return SET success - Some(Ok(Response::Execution(Tag::new("SET")))) - } + // fallback to datafusion and ignore all errors + if let Err(e) = execute_set_statement(session_context, statement.clone()).await { + warn!( + "SET statement {} is not supported by datafusion, error {e}, statement ignored", + statement.to_string() + ); } + + // Always return SET success + Some(Ok(Response::Execution(Tag::new("SET")))) +} + +async fn execute_set_statement( + session_context: &SessionContext, + statement: Statement, +) -> Result<(), DataFusionError> { + let state = session_context.state(); + let logical_plan = state + .statement_to_plan(datafusion::sql::parser::Statement::Statement(Box::new( + statement, + ))) + .await + .and_then(|logical_plan| state.optimize(&logical_plan))?; + + session_context + .execute_logical_plan(logical_plan) + .await + .map(|_| ()) } async fn try_respond_show_statements( @@ -204,10 +239,11 @@ where let variables = variable .iter() - .map(|v| &v.value as &str) + .map(|v| v.value.to_lowercase()) .collect::>(); + let variables_ref = variables.iter().map(|s| s.as_str()).collect::>(); - match &variables as &[&str] { + match variables_ref.as_slice() { ["time", "zone"] => { let timezone = client::get_timezone(client).unwrap_or("UTC"); Some(mock_show_response("TimeZone", timezone).map(Response::Query)) @@ -238,6 +274,14 @@ where ["transaction", "isolation", "level"] => { Some(mock_show_response("transaction_isolation", "read_committed").map(Response::Query)) } + ["bytea_output"] | ["datestyle"] | ["intervalstyle"] => { + let val = client + .metadata() + .get(&variables[0]) + .map(|v| v.as_str()) + .unwrap_or(""); + Some(mock_show_response(&variables[0], val).map(Response::Query)) + } _ => { info!("Unsupported show statement: {}", statement); Some(mock_show_response("unsupported_show_statement", "").map(Response::Query)) @@ -288,6 +332,74 @@ mod tests { assert!(show_response.unwrap().is_ok()); } + #[tokio::test] + async fn test_bytea_output_set_and_show() { + 
let session_context = SessionContext::new(); + let mut client = MockClient::new(); + + // Test setting bytea_output to 'hex' + let statement = Parser::new(&PostgreSqlDialect {}) + .try_with_sql("set bytea_output = 'hex'") + .unwrap() + .parse_statement() + .unwrap(); + let set_response = + try_respond_set_statements(&mut client, &statement, &session_context).await; + + assert!(set_response.is_some()); + assert!(set_response.unwrap().is_ok()); + + // Verify the value was stored in client metadata + let bytea_output = client.metadata().get("bytea_output").unwrap(); + assert_eq!(bytea_output, "hex"); + + // Test SHOW bytea_output + let statement = Parser::new(&PostgreSqlDialect {}) + .try_with_sql("show bytea_output") + .unwrap() + .parse_statement() + .unwrap(); + let show_response = + try_respond_show_statements(&client, &statement, &session_context).await; + + assert!(show_response.is_some()); + assert!(show_response.unwrap().is_ok()); + } + + #[tokio::test] + async fn test_date_style_set_and_show() { + let session_context = SessionContext::new(); + let mut client = MockClient::new(); + + // Test setting dateStyle to 'ISO, DMY' + let statement = Parser::new(&PostgreSqlDialect {}) + .try_with_sql("set dateStyle = 'ISO, DMY'") + .unwrap() + .parse_statement() + .unwrap(); + let set_response = + try_respond_set_statements(&mut client, &statement, &session_context).await; + + assert!(set_response.is_some()); + assert!(set_response.unwrap().is_ok()); + + // Verify the value was stored in client metadata under the lowercased name + let bytea_output = client.metadata().get("datestyle").unwrap(); + assert_eq!(bytea_output, "ISO, DMY"); + + // Test SHOW dateStyle + let statement = Parser::new(&PostgreSqlDialect {}) + .try_with_sql("show dateStyle") + .unwrap() + .parse_statement() + .unwrap(); + let show_response = + try_respond_show_statements(&client, &statement, &session_context).await; + + assert!(show_response.is_some()); + assert!(show_response.unwrap().is_ok()); + } + #[tokio::test] async fn 
test_statement_timeout_disable() { let session_context = SessionContext::new(); From a62f3b5611318057d896f3d454da1e7fbe7bb9a2 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Wed, 29 Oct 2025 04:51:23 +0800 Subject: [PATCH 2/2] fix: resolve lint issue --- datafusion-postgres/src/hooks/set_show.rs | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/datafusion-postgres/src/hooks/set_show.rs b/datafusion-postgres/src/hooks/set_show.rs index 0568f9c..c1827bc 100644 --- a/datafusion-postgres/src/hooks/set_show.rs +++ b/datafusion-postgres/src/hooks/set_show.rs @@ -170,16 +170,16 @@ where client::set_statement_timeout(client, timeout); return Some(Ok(Response::Execution(Tag::new("SET")))); - } else if matches!(var.as_str(), "datestyle" | "bytea_output" | "intervalstyle") { - if values.len() > 0 { - // postgres configuration variables - let value = values[0].clone(); - if let Expr::Value(value) = value { - client - .metadata_mut() - .insert(var, value.into_string().unwrap_or_else(|| "".to_string())); - return Some(Ok(Response::Execution(Tag::new("SET")))); - } + } else if matches!(var.as_str(), "datestyle" | "bytea_output" | "intervalstyle") + && !values.is_empty() + { + // postgres configuration variables + let value = values[0].clone(); + if let Expr::Value(value) = value { + client + .metadata_mut() + .insert(var, value.into_string().unwrap_or_else(|| "".to_string())); + return Some(Ok(Response::Execution(Tag::new("SET")))); } } } @@ -199,7 +199,7 @@ where if let Err(e) = execute_set_statement(session_context, statement.clone()).await { warn!( "SET statement {} is not supported by datafusion, error {e}, statement ignored", - statement.to_string() + statement ); }