@@ -2,9 +2,8 @@ use std::collections::HashMap;
22use std:: sync:: Arc ;
33
44use async_trait:: async_trait;
5- use datafusion:: arrow:: datatypes:: { DataType , Field , Schema } ;
6- use datafusion:: common:: { ParamValues , ToDFSchema } ;
7- use datafusion:: error:: DataFusionError ;
5+ use datafusion:: arrow:: datatypes:: DataType ;
6+ use datafusion:: common:: ParamValues ;
87use datafusion:: logical_expr:: LogicalPlan ;
98use datafusion:: prelude:: * ;
109use datafusion:: sql:: parser:: Statement ;
@@ -21,12 +20,12 @@ use pgwire::api::stmt::QueryParser;
2120use pgwire:: api:: stmt:: StoredStatement ;
2221use pgwire:: api:: { ClientInfo , ErrorHandler , PgWireServerHandlers , Type } ;
2322use pgwire:: error:: { PgWireError , PgWireResult } ;
24- use pgwire:: messages:: response:: TransactionStatus ;
2523use pgwire:: types:: format:: FormatOptions ;
2624
2725use crate :: auth:: AuthManager ;
2826use crate :: client;
2927use crate :: hooks:: set_show:: SetShowHook ;
28+ use crate :: hooks:: transactions:: TransactionStatementHook ;
3029use crate :: hooks:: QueryHook ;
3130use arrow_pg:: datatypes:: df;
3231use arrow_pg:: datatypes:: { arrow_schema_to_pg_fields, into_pg_type} ;
@@ -107,7 +106,8 @@ impl DfSessionService {
107106 session_context : Arc < SessionContext > ,
108107 auth_manager : Arc < AuthManager > ,
109108 ) -> DfSessionService {
110- let hooks: Vec < Arc < dyn QueryHook > > = vec ! [ Arc :: new( SetShowHook ) ] ;
109+ let hooks: Vec < Arc < dyn QueryHook > > =
110+ vec ! [ Arc :: new( SetShowHook ) , Arc :: new( TransactionStatementHook ) ] ;
111111 Self :: new_with_hooks ( session_context, auth_manager, hooks)
112112 }
113113
@@ -203,57 +203,6 @@ impl DfSessionService {
203203 // If we can't determine the table, default to All
204204 ResourceType :: All
205205 }
206-
207- async fn try_respond_transaction_statements < C > (
208- & self ,
209- client : & C ,
210- query_lower : & str ,
211- ) -> PgWireResult < Option < Response > >
212- where
213- C : ClientInfo ,
214- {
215- // Transaction handling based on pgwire example:
216- // https://github.com/sunng87/pgwire/blob/master/examples/transaction.rs#L57
217- match query_lower. trim ( ) {
218- "begin" | "begin transaction" | "begin work" | "start transaction" => {
219- match client. transaction_status ( ) {
220- TransactionStatus :: Idle => {
221- Ok ( Some ( Response :: TransactionStart ( Tag :: new ( "BEGIN" ) ) ) )
222- }
223- TransactionStatus :: Transaction => {
224- // PostgreSQL behavior: ignore nested BEGIN, just return SUCCESS
225- // This matches PostgreSQL's handling of nested transaction blocks
226- log:: warn!( "BEGIN command ignored: already in transaction block" ) ;
227- Ok ( Some ( Response :: Execution ( Tag :: new ( "BEGIN" ) ) ) )
228- }
229- TransactionStatus :: Error => {
230- // Can't start new transaction from failed state
231- Err ( PgWireError :: UserError ( Box :: new (
232- pgwire:: error:: ErrorInfo :: new (
233- "ERROR" . to_string ( ) ,
234- "25P01" . to_string ( ) ,
235- "current transaction is aborted, commands ignored until end of transaction block" . to_string ( ) ,
236- ) ,
237- ) ) )
238- }
239- }
240- }
241- "commit" | "commit transaction" | "commit work" | "end" | "end transaction" => {
242- match client. transaction_status ( ) {
243- TransactionStatus :: Idle | TransactionStatus :: Transaction => {
244- Ok ( Some ( Response :: TransactionEnd ( Tag :: new ( "COMMIT" ) ) ) )
245- }
246- TransactionStatus :: Error => {
247- Ok ( Some ( Response :: TransactionEnd ( Tag :: new ( "ROLLBACK" ) ) ) )
248- }
249- }
250- }
251- "rollback" | "rollback transaction" | "rollback work" | "abort" => {
252- Ok ( Some ( Response :: TransactionEnd ( Tag :: new ( "ROLLBACK" ) ) ) )
253- }
254- _ => Ok ( None ) ,
255- }
256- }
257206}
258207
259208#[ async_trait]
@@ -264,15 +213,6 @@ impl SimpleQueryHandler for DfSessionService {
264213 {
265214 log:: debug!( "Received query: {query}" ) ; // Log the query for debugging
266215
267- // Check for transaction commands early to avoid SQL parsing issues with ABORT
268- let query_lower = query. to_lowercase ( ) . trim ( ) . to_string ( ) ;
269- if let Some ( resp) = self
270- . try_respond_transaction_statements ( client, & query_lower)
271- . await ?
272- {
273- return Ok ( vec ! [ resp] ) ;
274- }
275-
276216 let statements = self
277217 . parser
278218 . sql_parser
@@ -314,18 +254,6 @@ impl SimpleQueryHandler for DfSessionService {
314254 }
315255 }
316256
317- // Check if we're in a failed transaction and block non-transaction
318- // commands
319- if client. transaction_status ( ) == TransactionStatus :: Error {
320- return Err ( PgWireError :: UserError ( Box :: new (
321- pgwire:: error:: ErrorInfo :: new (
322- "ERROR" . to_string ( ) ,
323- "25P01" . to_string ( ) ,
324- "current transaction is aborted, commands ignored until end of transaction block" . to_string ( ) ,
325- ) ,
326- ) ) ) ;
327- }
328-
329257 let df_result = {
330258 let timeout = client:: get_statement_timeout ( client) ;
331259 if let Some ( timeout_duration) = timeout {
@@ -480,25 +408,6 @@ impl ExtendedQueryHandler for DfSessionService {
480408 . await ?;
481409 }
482410
483- if let Some ( resp) = self
484- . try_respond_transaction_statements ( client, & query)
485- . await ?
486- {
487- return Ok ( resp) ;
488- }
489-
490- // Check if we're in a failed transaction and block non-transaction
491- // commands
492- if client. transaction_status ( ) == TransactionStatus :: Error {
493- return Err ( PgWireError :: UserError ( Box :: new (
494- pgwire:: error:: ErrorInfo :: new (
495- "ERROR" . to_string ( ) ,
496- "25P01" . to_string ( ) ,
497- "current transaction is aborted, commands ignored until end of transaction block" . to_string ( ) ,
498- ) ,
499- ) ) ) ;
500- }
501-
502411 if let ( _, Some ( ( _, plan) ) ) = & portal. statement . statement {
503412 let param_types = plan
504413 . get_parameter_types ( )
@@ -594,55 +503,6 @@ pub struct Parser {
594503 query_hooks : Vec < Arc < dyn QueryHook > > ,
595504}
596505
597- impl Parser {
598- fn try_shortcut_parse_plan ( & self , sql : & str ) -> Result < Option < LogicalPlan > , DataFusionError > {
599- // Check for transaction commands that shouldn't be parsed by DataFusion
600- let sql_lower = sql. to_lowercase ( ) ;
601- let sql_trimmed = sql_lower. trim ( ) ;
602-
603- if matches ! (
604- sql_trimmed,
605- "" | "begin"
606- | "begin transaction"
607- | "begin work"
608- | "start transaction"
609- | "commit"
610- | "commit transaction"
611- | "commit work"
612- | "end"
613- | "end transaction"
614- | "rollback"
615- | "rollback transaction"
616- | "rollback work"
617- | "abort"
618- ) {
619- // Return a dummy plan for transaction commands - they'll be handled by transaction handler
620- let dummy_schema = datafusion:: common:: DFSchema :: empty ( ) ;
621- return Ok ( Some ( LogicalPlan :: EmptyRelation (
622- datafusion:: logical_expr:: EmptyRelation {
623- produce_one_row : false ,
624- schema : Arc :: new ( dummy_schema) ,
625- } ,
626- ) ) ) ;
627- }
628-
629- // show statement may not be supported by datafusion
630- if sql_trimmed. starts_with ( "show" ) {
631- let show_schema =
632- Arc :: new ( Schema :: new ( vec ! [ Field :: new( "show" , DataType :: Utf8 , false ) ] ) ) ;
633- let df_schema = show_schema. to_dfschema ( ) ?;
634- return Ok ( Some ( LogicalPlan :: EmptyRelation (
635- datafusion:: logical_expr:: EmptyRelation {
636- produce_one_row : true ,
637- schema : Arc :: new ( df_schema) ,
638- } ,
639- ) ) ) ;
640- }
641-
642- Ok ( None )
643- }
644- }
645-
646506#[ async_trait]
647507impl QueryParser for Parser {
648508 type Statement = ( String , Option < ( sqlparser:: ast:: Statement , LogicalPlan ) > ) ;
@@ -667,15 +527,6 @@ impl QueryParser for Parser {
667527 }
668528
669529 let statement = statements. remove ( 0 ) ;
670-
671- // Check for transaction commands that shouldn't be parsed by DataFusion
672- if let Some ( plan) = self
673- . try_shortcut_parse_plan ( sql)
674- . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?
675- {
676- return Ok ( ( sql. to_string ( ) , Some ( ( statement, plan) ) ) ) ;
677- }
678-
679530 let query = statement. to_string ( ) ;
680531
681532 let context = & self . session_context ;
0 commit comments