From 11d7bcc4f5e95de6a30911d000262d0714cfe2a6 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Mon, 1 Dec 2025 14:56:54 +0800 Subject: [PATCH 1/2] refactor: extract transaction statements to hook --- datafusion-postgres/src/handlers.rs | 113 +-------------- datafusion-postgres/src/hooks/mod.rs | 1 + datafusion-postgres/src/hooks/transactions.rs | 130 ++++++++++++++++++ 3 files changed, 135 insertions(+), 109 deletions(-) create mode 100644 datafusion-postgres/src/hooks/transactions.rs diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index c12b801..e94a215 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -21,12 +21,12 @@ use pgwire::api::stmt::QueryParser; use pgwire::api::stmt::StoredStatement; use pgwire::api::{ClientInfo, ErrorHandler, PgWireServerHandlers, Type}; use pgwire::error::{PgWireError, PgWireResult}; -use pgwire::messages::response::TransactionStatus; use pgwire::types::format::FormatOptions; use crate::auth::AuthManager; use crate::client; use crate::hooks::set_show::SetShowHook; +use crate::hooks::transactions::TransactionStatementHook; use crate::hooks::QueryHook; use arrow_pg::datatypes::df; use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type}; @@ -107,7 +107,8 @@ impl DfSessionService { session_context: Arc, auth_manager: Arc, ) -> DfSessionService { - let hooks: Vec> = vec![Arc::new(SetShowHook)]; + let hooks: Vec> = + vec![Arc::new(SetShowHook), Arc::new(TransactionStatementHook)]; Self::new_with_hooks(session_context, auth_manager, hooks) } @@ -203,57 +204,6 @@ impl DfSessionService { // If we can't determine the table, default to All ResourceType::All } - - async fn try_respond_transaction_statements( - &self, - client: &C, - query_lower: &str, - ) -> PgWireResult> - where - C: ClientInfo, - { - // Transaction handling based on pgwire example: - // https://github.com/sunng87/pgwire/blob/master/examples/transaction.rs#L57 - match query_lower.trim() { - "begin" | "begin transaction" | "begin work" | "start transaction" => { - match client.transaction_status() { - TransactionStatus::Idle => { - Ok(Some(Response::TransactionStart(Tag::new("BEGIN")))) - } - TransactionStatus::Transaction => { - // PostgreSQL behavior: ignore nested BEGIN, just return SUCCESS - // This matches PostgreSQL's handling of nested transaction blocks - log::warn!("BEGIN command ignored: already in transaction block"); - Ok(Some(Response::Execution(Tag::new("BEGIN")))) - } - TransactionStatus::Error => { - // Can't start new transaction from failed state - Err(PgWireError::UserError(Box::new( - pgwire::error::ErrorInfo::new( - "ERROR".to_string(), - "25P01".to_string(), - "current transaction is aborted, commands ignored until end of transaction block".to_string(), - ), - ))) - } - } - } - "commit" | "commit transaction" | "commit work" | "end" | "end transaction" => { - match client.transaction_status() { - TransactionStatus::Idle | TransactionStatus::Transaction => { - Ok(Some(Response::TransactionEnd(Tag::new("COMMIT")))) - } - TransactionStatus::Error => { - Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK")))) - } - } - } - "rollback" | "rollback transaction" | "rollback work" | "abort" => { - Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK")))) - } - _ => Ok(None), - } - } } #[async_trait] @@ -264,15 +214,6 @@ impl SimpleQueryHandler for DfSessionService { { log::debug!("Received query: {query}"); // Log the query for debugging - // Check for transaction commands early to avoid SQL parsing issues with ABORT - let query_lower = query.to_lowercase().trim().to_string(); - if let Some(resp) = self - .try_respond_transaction_statements(client, &query_lower) - .await? - { - return Ok(vec![resp]); - } - let statements = self .parser .sql_parser @@ -314,18 +255,6 @@ impl SimpleQueryHandler for DfSessionService { } } - // Check if we're in a failed transaction and block non-transaction - // commands - if client.transaction_status() == TransactionStatus::Error { - return Err(PgWireError::UserError(Box::new( - pgwire::error::ErrorInfo::new( - "ERROR".to_string(), - "25P01".to_string(), - "current transaction is aborted, commands ignored until end of transaction block".to_string(), - ), - ))); - } - let df_result = { let timeout = client::get_statement_timeout(client); if let Some(timeout_duration) = timeout { @@ -480,25 +409,6 @@ impl ExtendedQueryHandler for DfSessionService { .await?; } - if let Some(resp) = self - .try_respond_transaction_statements(client, &query) - .await? - { - return Ok(resp); - } - - // Check if we're in a failed transaction and block non-transaction - // commands - if client.transaction_status() == TransactionStatus::Error { - return Err(PgWireError::UserError(Box::new( - pgwire::error::ErrorInfo::new( - "ERROR".to_string(), - "25P01".to_string(), - "current transaction is aborted, commands ignored until end of transaction block".to_string(), - ), - ))); - } - if let (_, Some((_, plan))) = &portal.statement.statement { let param_types = plan .get_parameter_types() @@ -600,22 +510,7 @@ impl Parser { let sql_lower = sql.to_lowercase(); let sql_trimmed = sql_lower.trim(); - if matches!( - sql_trimmed, - "" | "begin" - | "begin transaction" - | "begin work" - | "start transaction" - | "commit" - | "commit transaction" - | "commit work" - | "end" - | "end transaction" - | "rollback" - | "rollback transaction" - | "rollback work" - | "abort" - ) { + if sql_trimmed.is_empty() { // Return a dummy plan for transaction commands - they'll be handled by transaction handler let dummy_schema = datafusion::common::DFSchema::empty(); return Ok(Some(LogicalPlan::EmptyRelation( diff --git a/datafusion-postgres/src/hooks/mod.rs b/datafusion-postgres/src/hooks/mod.rs index 6df8d6e..d6cd7bb 100644 --- a/datafusion-postgres/src/hooks/mod.rs +++ b/datafusion-postgres/src/hooks/mod.rs @@ -1,4 +1,5 @@ pub mod set_show; +pub mod transactions; use async_trait::async_trait; diff --git a/datafusion-postgres/src/hooks/transactions.rs b/datafusion-postgres/src/hooks/transactions.rs new file mode 100644 index 0000000..a7bf2fe --- /dev/null +++ b/datafusion-postgres/src/hooks/transactions.rs @@ -0,0 +1,130 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use datafusion::common::ParamValues; +use datafusion::logical_expr::LogicalPlan; +use datafusion::prelude::SessionContext; +use datafusion::sql::sqlparser::ast::Statement; +use pgwire::api::results::{Response, Tag}; +use pgwire::api::ClientInfo; +use pgwire::error::{PgWireError, PgWireResult}; +use pgwire::messages::response::TransactionStatus; + +use crate::QueryHook; + +/// Hook for processing transaction related statements +/// +/// Note that this hook doesn't create actual transactions. It just responds +/// with reasonable return values. +#[derive(Debug)] +pub struct TransactionStatementHook; + +#[async_trait] +impl QueryHook for TransactionStatementHook { + /// called in simple query handler to return response directly + async fn handle_simple_query( + &self, + statement: &Statement, + _session_context: &SessionContext, + client: &mut (dyn ClientInfo + Send + Sync), + ) -> Option> { + let resp = try_respond_transaction_statements(client, statement) + .await + .transpose(); + + if resp.is_some() { + return resp; + } + + // Check if we're in a failed transaction and block non-transaction + // commands + if client.transaction_status() == TransactionStatus::Error { + return Some(Err(PgWireError::UserError(Box::new( + pgwire::error::ErrorInfo::new( + "ERROR".to_string(), + "25P01".to_string(), + "current transaction is aborted, commands ignored until end of transaction block".to_string(), + ), + )))); + } + + None + } + + async fn handle_extended_parse_query( + &self, + stmt: &Statement, + _session_context: &SessionContext, + _client: &(dyn ClientInfo + Send + Sync), + ) -> Option> { + // We don't generate logical plan for these statements + if matches!( + stmt, + Statement::StartTransaction { .. } + | Statement::Commit { .. } + | Statement::Rollback { .. } + ) { + // Return a dummy plan for transaction commands - they'll be handled by transaction handler + let dummy_schema = datafusion::common::DFSchema::empty(); + return Some(Ok(LogicalPlan::EmptyRelation( + datafusion::logical_expr::EmptyRelation { + produce_one_row: false, + schema: Arc::new(dummy_schema), + }, + ))); + } + None + } + + async fn handle_extended_query( + &self, + statement: &Statement, + _logical_plan: &LogicalPlan, + _params: &ParamValues, + session_context: &SessionContext, + client: &mut (dyn ClientInfo + Send + Sync), + ) -> Option> { + self.handle_simple_query(statement, session_context, client) + .await + } +} + +async fn try_respond_transaction_statements( + client: &C, + stmt: &Statement, +) -> PgWireResult> +where + C: ClientInfo + Send + Sync + ?Sized, +{ + match stmt { + Statement::StartTransaction { .. } => { + match client.transaction_status() { + TransactionStatus::Idle => Ok(Some(Response::TransactionStart(Tag::new("BEGIN")))), + TransactionStatus::Transaction => { + // PostgreSQL behavior: ignore nested BEGIN, just return SUCCESS + // This matches PostgreSQL's handling of nested transaction blocks + log::warn!("BEGIN command ignored: already in transaction block"); + Ok(Some(Response::Execution(Tag::new("BEGIN")))) + } + TransactionStatus::Error => { + // Can't start new transaction from failed state + Err(PgWireError::UserError(Box::new( + pgwire::error::ErrorInfo::new( + "ERROR".to_string(), + "25P01".to_string(), + "current transaction is aborted, commands ignored until end of transaction block".to_string(), + ), + ))) + } + } + } + Statement::Commit { .. } => match client.transaction_status() { + TransactionStatus::Idle | TransactionStatus::Transaction => { + Ok(Some(Response::TransactionEnd(Tag::new("COMMIT")))) + } + TransactionStatus::Error => Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK")))), + }, + Statement::Rollback { .. } => Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK")))), + _ => Ok(None), + } +} From 59674a764f91650e7c35758b84469904b85a1250 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Mon, 1 Dec 2025 15:41:44 +0800 Subject: [PATCH 2/2] refactor: remove special cases for parser --- datafusion-pg-catalog/src/sql/parser.rs | 13 ++++++ datafusion-postgres/src/handlers.rs | 48 +------------------- datafusion-postgres/src/hooks/set_show.rs | 55 +++++++++++------------ 3 files changed, 42 insertions(+), 74 deletions(-) diff --git a/datafusion-pg-catalog/src/sql/parser.rs b/datafusion-pg-catalog/src/sql/parser.rs index 7c09f3d..0420c9c 100644 --- a/datafusion-pg-catalog/src/sql/parser.rs +++ b/datafusion-pg-catalog/src/sql/parser.rs @@ -336,4 +336,17 @@ mod tests { let match_result = parser.parse_and_replace(sql).expect("failed to parse sql"); assert!(matches!(match_result, MatchResult::Matches(_))); } + + #[test] + fn test_empty_query() { + let parser = PostgresCompatibilityParser::new(); + let result = parser.parse(" ").expect("failed to parse sql"); + assert!(result.is_empty()); + + let result = parser.parse("").expect("failed to parse sql"); + assert!(result.is_empty()); + + let result = parser.parse(";").expect("failed to parse sql"); + assert!(result.is_empty()); + } } diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index e94a215..0f3b68b 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -2,9 +2,8 @@ use std::collections::HashMap; use std::sync::Arc; use async_trait::async_trait; -use datafusion::arrow::datatypes::{DataType, Field, Schema}; -use datafusion::common::{ParamValues, ToDFSchema}; -use datafusion::error::DataFusionError; +use datafusion::arrow::datatypes::DataType; +use datafusion::common::ParamValues; use datafusion::logical_expr::LogicalPlan; use datafusion::prelude::*; use datafusion::sql::parser::Statement; @@ -504,40 +503,6 @@ pub struct Parser { query_hooks: Vec>, } -impl Parser { - fn try_shortcut_parse_plan(&self, sql: &str) -> Result, DataFusionError> { - // Check for transaction commands that shouldn't be parsed by DataFusion - let sql_lower = sql.to_lowercase(); - let sql_trimmed = sql_lower.trim(); - - if sql_trimmed.is_empty() { - // Return a dummy plan for transaction commands - they'll be handled by transaction handler - let dummy_schema = datafusion::common::DFSchema::empty(); - return Ok(Some(LogicalPlan::EmptyRelation( - datafusion::logical_expr::EmptyRelation { - produce_one_row: false, - schema: Arc::new(dummy_schema), - }, - ))); - } - - // show statement may not be supported by datafusion - if sql_trimmed.starts_with("show") { - let show_schema = - Arc::new(Schema::new(vec![Field::new("show", DataType::Utf8, false)])); - let df_schema = show_schema.to_dfschema()?; - return Ok(Some(LogicalPlan::EmptyRelation( - datafusion::logical_expr::EmptyRelation { - produce_one_row: true, - schema: Arc::new(df_schema), - }, - ))); - } - - Ok(None) - } -} - #[async_trait] impl QueryParser for Parser { type Statement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>); @@ -562,15 +527,6 @@ impl QueryParser for Parser { } let statement = statements.remove(0); - - // Check for transaction commands that shouldn't be parsed by DataFusion - if let Some(plan) = self - .try_shortcut_parse_plan(sql) - .map_err(|e| PgWireError::ApiError(Box::new(e)))? - { - return Ok((sql.to_string(), Some((statement, plan)))); - } - let query = statement.to_string(); let context = &self.session_context; diff --git a/datafusion-postgres/src/hooks/set_show.rs b/datafusion-postgres/src/hooks/set_show.rs index e1f7747..cb5332f 100644 --- a/datafusion-postgres/src/hooks/set_show.rs +++ b/datafusion-postgres/src/hooks/set_show.rs @@ -47,36 +47,35 @@ impl QueryHook for SetShowHook { _session_context: &SessionContext, _client: &(dyn ClientInfo + Send + Sync), ) -> Option> { - let sql_lower = stmt.to_string().to_lowercase(); - let sql_trimmed = sql_lower.trim(); - - if sql_trimmed.starts_with("show") { - let show_schema = - Arc::new(Schema::new(vec![Field::new("show", DataType::Utf8, false)])); - let result = show_schema - .to_dfschema() - .map(|df_schema| { - LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation { - produce_one_row: true, - schema: Arc::new(df_schema), + match stmt { + Statement::Set { .. } => { + let show_schema = Arc::new(Schema::new(Vec::::new())); + let result = show_schema + .to_dfschema() + .map(|df_schema| { + LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation { + produce_one_row: true, + schema: Arc::new(df_schema), + }) }) - }) - .map_err(|e| PgWireError::ApiError(Box::new(e))); - Some(result) - } else if sql_trimmed.starts_with("set") { - let show_schema = Arc::new(Schema::new(Vec::::new())); - let result = show_schema - .to_dfschema() - .map(|df_schema| { - LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation { - produce_one_row: true, - schema: Arc::new(df_schema), + .map_err(|e| PgWireError::ApiError(Box::new(e))); + Some(result) + } + Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => { + let show_schema = + Arc::new(Schema::new(vec![Field::new("show", DataType::Utf8, false)])); + let result = show_schema + .to_dfschema() + .map(|df_schema| { + LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation { + produce_one_row: true, + schema: Arc::new(df_schema), + }) }) - }) - .map_err(|e| PgWireError::ApiError(Box::new(e))); - Some(result) - } else { - None + .map_err(|e| PgWireError::ApiError(Box::new(e))); + Some(result) + } + _ => None, } }