Skip to content

Commit d97d730

Browse files
authored
refactor: extract transaction statement processing as hook (#249)
* refactor: extract transaction statements to hook
* refactor: remove special cases for parser
1 parent d7726a9 commit d97d730

File tree

5 files changed

+176
-182
lines changed

5 files changed

+176
-182
lines changed

datafusion-pg-catalog/src/sql/parser.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,4 +336,17 @@ mod tests {
336336
let match_result = parser.parse_and_replace(sql).expect("failed to parse sql");
337337
assert!(matches!(match_result, MatchResult::Matches(_)));
338338
}
339+
340+
#[test]
341+
fn test_empty_query() {
342+
let parser = PostgresCompatibilityParser::new();
343+
let result = parser.parse(" ").expect("failed to parse sql");
344+
assert!(result.is_empty());
345+
346+
let result = parser.parse("").expect("failed to parse sql");
347+
assert!(result.is_empty());
348+
349+
let result = parser.parse(";").expect("failed to parse sql");
350+
assert!(result.is_empty());
351+
}
339352
}

datafusion-postgres/src/handlers.rs

Lines changed: 5 additions & 154 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@ use std::collections::HashMap;
22
use std::sync::Arc;
33

44
use async_trait::async_trait;
5-
use datafusion::arrow::datatypes::{DataType, Field, Schema};
6-
use datafusion::common::{ParamValues, ToDFSchema};
7-
use datafusion::error::DataFusionError;
5+
use datafusion::arrow::datatypes::DataType;
6+
use datafusion::common::ParamValues;
87
use datafusion::logical_expr::LogicalPlan;
98
use datafusion::prelude::*;
109
use datafusion::sql::parser::Statement;
@@ -21,12 +20,12 @@ use pgwire::api::stmt::QueryParser;
2120
use pgwire::api::stmt::StoredStatement;
2221
use pgwire::api::{ClientInfo, ErrorHandler, PgWireServerHandlers, Type};
2322
use pgwire::error::{PgWireError, PgWireResult};
24-
use pgwire::messages::response::TransactionStatus;
2523
use pgwire::types::format::FormatOptions;
2624

2725
use crate::auth::AuthManager;
2826
use crate::client;
2927
use crate::hooks::set_show::SetShowHook;
28+
use crate::hooks::transactions::TransactionStatementHook;
3029
use crate::hooks::QueryHook;
3130
use arrow_pg::datatypes::df;
3231
use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
@@ -107,7 +106,8 @@ impl DfSessionService {
107106
session_context: Arc<SessionContext>,
108107
auth_manager: Arc<AuthManager>,
109108
) -> DfSessionService {
110-
let hooks: Vec<Arc<dyn QueryHook>> = vec![Arc::new(SetShowHook)];
109+
let hooks: Vec<Arc<dyn QueryHook>> =
110+
vec![Arc::new(SetShowHook), Arc::new(TransactionStatementHook)];
111111
Self::new_with_hooks(session_context, auth_manager, hooks)
112112
}
113113

@@ -203,57 +203,6 @@ impl DfSessionService {
203203
// If we can't determine the table, default to All
204204
ResourceType::All
205205
}
206-
207-
async fn try_respond_transaction_statements<C>(
208-
&self,
209-
client: &C,
210-
query_lower: &str,
211-
) -> PgWireResult<Option<Response>>
212-
where
213-
C: ClientInfo,
214-
{
215-
// Transaction handling based on pgwire example:
216-
// https://github.com/sunng87/pgwire/blob/master/examples/transaction.rs#L57
217-
match query_lower.trim() {
218-
"begin" | "begin transaction" | "begin work" | "start transaction" => {
219-
match client.transaction_status() {
220-
TransactionStatus::Idle => {
221-
Ok(Some(Response::TransactionStart(Tag::new("BEGIN"))))
222-
}
223-
TransactionStatus::Transaction => {
224-
// PostgreSQL behavior: ignore nested BEGIN, just return SUCCESS
225-
// This matches PostgreSQL's handling of nested transaction blocks
226-
log::warn!("BEGIN command ignored: already in transaction block");
227-
Ok(Some(Response::Execution(Tag::new("BEGIN"))))
228-
}
229-
TransactionStatus::Error => {
230-
// Can't start new transaction from failed state
231-
Err(PgWireError::UserError(Box::new(
232-
pgwire::error::ErrorInfo::new(
233-
"ERROR".to_string(),
234-
"25P01".to_string(),
235-
"current transaction is aborted, commands ignored until end of transaction block".to_string(),
236-
),
237-
)))
238-
}
239-
}
240-
}
241-
"commit" | "commit transaction" | "commit work" | "end" | "end transaction" => {
242-
match client.transaction_status() {
243-
TransactionStatus::Idle | TransactionStatus::Transaction => {
244-
Ok(Some(Response::TransactionEnd(Tag::new("COMMIT"))))
245-
}
246-
TransactionStatus::Error => {
247-
Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK"))))
248-
}
249-
}
250-
}
251-
"rollback" | "rollback transaction" | "rollback work" | "abort" => {
252-
Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK"))))
253-
}
254-
_ => Ok(None),
255-
}
256-
}
257206
}
258207

259208
#[async_trait]
@@ -264,15 +213,6 @@ impl SimpleQueryHandler for DfSessionService {
264213
{
265214
log::debug!("Received query: {query}"); // Log the query for debugging
266215

267-
// Check for transaction commands early to avoid SQL parsing issues with ABORT
268-
let query_lower = query.to_lowercase().trim().to_string();
269-
if let Some(resp) = self
270-
.try_respond_transaction_statements(client, &query_lower)
271-
.await?
272-
{
273-
return Ok(vec![resp]);
274-
}
275-
276216
let statements = self
277217
.parser
278218
.sql_parser
@@ -314,18 +254,6 @@ impl SimpleQueryHandler for DfSessionService {
314254
}
315255
}
316256

317-
// Check if we're in a failed transaction and block non-transaction
318-
// commands
319-
if client.transaction_status() == TransactionStatus::Error {
320-
return Err(PgWireError::UserError(Box::new(
321-
pgwire::error::ErrorInfo::new(
322-
"ERROR".to_string(),
323-
"25P01".to_string(),
324-
"current transaction is aborted, commands ignored until end of transaction block".to_string(),
325-
),
326-
)));
327-
}
328-
329257
let df_result = {
330258
let timeout = client::get_statement_timeout(client);
331259
if let Some(timeout_duration) = timeout {
@@ -480,25 +408,6 @@ impl ExtendedQueryHandler for DfSessionService {
480408
.await?;
481409
}
482410

483-
if let Some(resp) = self
484-
.try_respond_transaction_statements(client, &query)
485-
.await?
486-
{
487-
return Ok(resp);
488-
}
489-
490-
// Check if we're in a failed transaction and block non-transaction
491-
// commands
492-
if client.transaction_status() == TransactionStatus::Error {
493-
return Err(PgWireError::UserError(Box::new(
494-
pgwire::error::ErrorInfo::new(
495-
"ERROR".to_string(),
496-
"25P01".to_string(),
497-
"current transaction is aborted, commands ignored until end of transaction block".to_string(),
498-
),
499-
)));
500-
}
501-
502411
if let (_, Some((_, plan))) = &portal.statement.statement {
503412
let param_types = plan
504413
.get_parameter_types()
@@ -594,55 +503,6 @@ pub struct Parser {
594503
query_hooks: Vec<Arc<dyn QueryHook>>,
595504
}
596505

597-
impl Parser {
598-
fn try_shortcut_parse_plan(&self, sql: &str) -> Result<Option<LogicalPlan>, DataFusionError> {
599-
// Check for transaction commands that shouldn't be parsed by DataFusion
600-
let sql_lower = sql.to_lowercase();
601-
let sql_trimmed = sql_lower.trim();
602-
603-
if matches!(
604-
sql_trimmed,
605-
"" | "begin"
606-
| "begin transaction"
607-
| "begin work"
608-
| "start transaction"
609-
| "commit"
610-
| "commit transaction"
611-
| "commit work"
612-
| "end"
613-
| "end transaction"
614-
| "rollback"
615-
| "rollback transaction"
616-
| "rollback work"
617-
| "abort"
618-
) {
619-
// Return a dummy plan for transaction commands - they'll be handled by transaction handler
620-
let dummy_schema = datafusion::common::DFSchema::empty();
621-
return Ok(Some(LogicalPlan::EmptyRelation(
622-
datafusion::logical_expr::EmptyRelation {
623-
produce_one_row: false,
624-
schema: Arc::new(dummy_schema),
625-
},
626-
)));
627-
}
628-
629-
// show statement may not be supported by datafusion
630-
if sql_trimmed.starts_with("show") {
631-
let show_schema =
632-
Arc::new(Schema::new(vec![Field::new("show", DataType::Utf8, false)]));
633-
let df_schema = show_schema.to_dfschema()?;
634-
return Ok(Some(LogicalPlan::EmptyRelation(
635-
datafusion::logical_expr::EmptyRelation {
636-
produce_one_row: true,
637-
schema: Arc::new(df_schema),
638-
},
639-
)));
640-
}
641-
642-
Ok(None)
643-
}
644-
}
645-
646506
#[async_trait]
647507
impl QueryParser for Parser {
648508
type Statement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>);
@@ -667,15 +527,6 @@ impl QueryParser for Parser {
667527
}
668528

669529
let statement = statements.remove(0);
670-
671-
// Check for transaction commands that shouldn't be parsed by DataFusion
672-
if let Some(plan) = self
673-
.try_shortcut_parse_plan(sql)
674-
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
675-
{
676-
return Ok((sql.to_string(), Some((statement, plan))));
677-
}
678-
679530
let query = statement.to_string();
680531

681532
let context = &self.session_context;

datafusion-postgres/src/hooks/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
pub mod set_show;
2+
pub mod transactions;
23

34
use async_trait::async_trait;
45

datafusion-postgres/src/hooks/set_show.rs

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -47,36 +47,35 @@ impl QueryHook for SetShowHook {
4747
_session_context: &SessionContext,
4848
_client: &(dyn ClientInfo + Send + Sync),
4949
) -> Option<PgWireResult<LogicalPlan>> {
50-
let sql_lower = stmt.to_string().to_lowercase();
51-
let sql_trimmed = sql_lower.trim();
52-
53-
if sql_trimmed.starts_with("show") {
54-
let show_schema =
55-
Arc::new(Schema::new(vec![Field::new("show", DataType::Utf8, false)]));
56-
let result = show_schema
57-
.to_dfschema()
58-
.map(|df_schema| {
59-
LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
60-
produce_one_row: true,
61-
schema: Arc::new(df_schema),
50+
match stmt {
51+
Statement::Set { .. } => {
52+
let show_schema = Arc::new(Schema::new(Vec::<Field>::new()));
53+
let result = show_schema
54+
.to_dfschema()
55+
.map(|df_schema| {
56+
LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
57+
produce_one_row: true,
58+
schema: Arc::new(df_schema),
59+
})
6260
})
63-
})
64-
.map_err(|e| PgWireError::ApiError(Box::new(e)));
65-
Some(result)
66-
} else if sql_trimmed.starts_with("set") {
67-
let show_schema = Arc::new(Schema::new(Vec::<Field>::new()));
68-
let result = show_schema
69-
.to_dfschema()
70-
.map(|df_schema| {
71-
LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
72-
produce_one_row: true,
73-
schema: Arc::new(df_schema),
61+
.map_err(|e| PgWireError::ApiError(Box::new(e)));
62+
Some(result)
63+
}
64+
Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
65+
let show_schema =
66+
Arc::new(Schema::new(vec![Field::new("show", DataType::Utf8, false)]));
67+
let result = show_schema
68+
.to_dfschema()
69+
.map(|df_schema| {
70+
LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
71+
produce_one_row: true,
72+
schema: Arc::new(df_schema),
73+
})
7474
})
75-
})
76-
.map_err(|e| PgWireError::ApiError(Box::new(e)));
77-
Some(result)
78-
} else {
79-
None
75+
.map_err(|e| PgWireError::ApiError(Box::new(e)));
76+
Some(result)
77+
}
78+
_ => None,
8079
}
8180
}
8281

0 commit comments

Comments
 (0)