diff --git a/datafusion-postgres/src/hooks/set_show.rs b/datafusion-postgres/src/hooks/set_show.rs index 92fb3f2..c1827bc 100644 --- a/datafusion-postgres/src/hooks/set_show.rs +++ b/datafusion-postgres/src/hooks/set_show.rs @@ -3,9 +3,10 @@ use std::sync::Arc; use async_trait::async_trait; use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::common::{ParamValues, ToDFSchema}; +use datafusion::error::DataFusionError; use datafusion::logical_expr::LogicalPlan; use datafusion::prelude::SessionContext; -use datafusion::sql::sqlparser::ast::{Set, Statement}; +use datafusion::sql::sqlparser::ast::{Expr, Set, Statement}; use log::{info, warn}; use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag}; use pgwire::api::ClientInfo; @@ -134,39 +135,53 @@ where hivevar: false, variable, values, - } if &variable.to_string() == "statement_timeout" => { - let value = values[0].to_string(); - let timeout_str = value.trim_matches('"').trim_matches('\''); - - let timeout = if timeout_str == "0" || timeout_str.is_empty() { - None - } else { - // Parse timeout value (supports ms, s, min formats) - let timeout_ms = if timeout_str.ends_with("ms") { - timeout_str.trim_end_matches("ms").parse::() - } else if timeout_str.ends_with("s") { - timeout_str - .trim_end_matches("s") - .parse::() - .map(|s| s * 1000) - } else if timeout_str.ends_with("min") { - timeout_str - .trim_end_matches("min") - .parse::() - .map(|m| m * 60 * 1000) + } => { + let var = variable.to_string().to_lowercase(); + if var == "statement_timeout" { + let value = values[0].to_string(); + let timeout_str = value.trim_matches('"').trim_matches('\''); + + let timeout = if timeout_str == "0" || timeout_str.is_empty() { + None } else { - // Default to milliseconds - timeout_str.parse::() + // Parse timeout value (supports ms, s, min formats) + let timeout_ms = if timeout_str.ends_with("ms") { + timeout_str.trim_end_matches("ms").parse::() + } else if timeout_str.ends_with("s") { + timeout_str + .trim_end_matches("s") + .parse::() + .map(|s| s * 1000) + } else if timeout_str.ends_with("min") { + timeout_str + .trim_end_matches("min") + .parse::() + .map(|m| m * 60 * 1000) + } else { + // Default to milliseconds + timeout_str.parse::() + }; + + match timeout_ms { + Ok(ms) if ms > 0 => Some(std::time::Duration::from_millis(ms)), + _ => None, + } }; - match timeout_ms { - Ok(ms) if ms > 0 => Some(std::time::Duration::from_millis(ms)), - _ => None, + client::set_statement_timeout(client, timeout); + return Some(Ok(Response::Execution(Tag::new("SET")))); + } else if matches!(var.as_str(), "datestyle" | "bytea_output" | "intervalstyle") + && !values.is_empty() + { + // postgres configuration variables + let value = values[0].clone(); + if let Expr::Value(value) = value { + client + .metadata_mut() + .insert(var, value.into_string().unwrap_or_else(|| "".to_string())); + return Some(Ok(Response::Execution(Tag::new("SET")))); } - }; - - client::set_statement_timeout(client, timeout); - Some(Ok(Response::Execution(Tag::new("SET")))) + } } Set::SetTimeZone { local: false, @@ -175,19 +190,39 @@ where let tz = value.to_string(); let tz = tz.trim_matches('"').trim_matches('\''); client::set_timezone(client, Some(tz)); - Some(Ok(Response::Execution(Tag::new("SET")))) + return Some(Ok(Response::Execution(Tag::new("SET")))); } - _ => { - // pass SET query to datafusion - let query = statement.to_string(); - if let Err(e) = session_context.sql(&query).await { - warn!("SET statement {query} is not supported by datafusion, error {e}, statement ignored"); - } + _ => {} + } - // Always return SET success - Some(Ok(Response::Execution(Tag::new("SET")))) - } + // fallback to datafusion and ignore all errors + if let Err(e) = execute_set_statement(session_context, statement.clone()).await { + warn!( + "SET statement {} is not supported by datafusion, error {e}, statement ignored", + statement + ); } + + // Always return SET success + Some(Ok(Response::Execution(Tag::new("SET")))) +} + +async fn execute_set_statement( + session_context: &SessionContext, + statement: Statement, +) -> Result<(), DataFusionError> { + let state = session_context.state(); + let logical_plan = state + .statement_to_plan(datafusion::sql::parser::Statement::Statement(Box::new( + statement, + ))) + .await + .and_then(|logical_plan| state.optimize(&logical_plan))?; + + session_context + .execute_logical_plan(logical_plan) + .await + .map(|_| ()) } async fn try_respond_show_statements( @@ -204,10 +239,11 @@ where let variables = variable .iter() - .map(|v| &v.value as &str) + .map(|v| v.value.to_lowercase()) .collect::>(); + let variables_ref = variables.iter().map(|s| s.as_str()).collect::>(); - match &variables as &[&str] { + match variables_ref.as_slice() { ["time", "zone"] => { let timezone = client::get_timezone(client).unwrap_or("UTC"); Some(mock_show_response("TimeZone", timezone).map(Response::Query)) @@ -238,6 +274,14 @@ where ["transaction", "isolation", "level"] => { Some(mock_show_response("transaction_isolation", "read_committed").map(Response::Query)) } + ["bytea_output"] | ["datestyle"] | ["intervalstyle"] => { + let val = client + .metadata() + .get(&variables[0]) + .map(|v| v.as_str()) + .unwrap_or(""); + Some(mock_show_response(&variables[0], val).map(Response::Query)) + } _ => { info!("Unsupported show statement: {}", statement); Some(mock_show_response("unsupported_show_statement", "").map(Response::Query)) @@ -288,6 +332,74 @@ mod tests { assert!(show_response.unwrap().is_ok()); } + #[tokio::test] + async fn test_bytea_output_set_and_show() { + let session_context = SessionContext::new(); + let mut client = MockClient::new(); + + // Test setting timeout to 5000ms + let statement = Parser::new(&PostgreSqlDialect {}) + .try_with_sql("set bytea_output = 'hex'") + .unwrap() + .parse_statement() + .unwrap(); + let set_response = + try_respond_set_statements(&mut client, &statement, &session_context).await; + + assert!(set_response.is_some()); + assert!(set_response.unwrap().is_ok()); + + // Verify the timeout was set in client metadata + let bytea_output = client.metadata().get("bytea_output").unwrap(); + assert_eq!(bytea_output, "hex"); + + // Test SHOW statement_timeout + let statement = Parser::new(&PostgreSqlDialect {}) + .try_with_sql("show bytea_output") + .unwrap() + .parse_statement() + .unwrap(); + let show_response = + try_respond_show_statements(&client, &statement, &session_context).await; + + assert!(show_response.is_some()); + assert!(show_response.unwrap().is_ok()); + } + + #[tokio::test] + async fn test_date_style_set_and_show() { + let session_context = SessionContext::new(); + let mut client = MockClient::new(); + + // Test setting timeout to 5000ms + let statement = Parser::new(&PostgreSqlDialect {}) + .try_with_sql("set dateStyle = 'ISO, DMY'") + .unwrap() + .parse_statement() + .unwrap(); + let set_response = + try_respond_set_statements(&mut client, &statement, &session_context).await; + + assert!(set_response.is_some()); + assert!(set_response.unwrap().is_ok()); + + // Verify the timeout was set in client metadata + let bytea_output = client.metadata().get("datestyle").unwrap(); + assert_eq!(bytea_output, "ISO, DMY"); + + // Test SHOW statement_timeout + let statement = Parser::new(&PostgreSqlDialect {}) + .try_with_sql("show dateStyle") + .unwrap() + .parse_statement() + .unwrap(); + let show_response = + try_respond_show_statements(&client, &statement, &session_context).await; + + assert!(show_response.is_some()); + assert!(show_response.unwrap().is_ok()); + } + #[tokio::test] async fn test_statement_timeout_disable() { let session_context = SessionContext::new();