Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions datafusion-pg-catalog/src/sql/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,4 +336,17 @@ mod tests {
let match_result = parser.parse_and_replace(sql).expect("failed to parse sql");
assert!(matches!(match_result, MatchResult::Matches(_)));
}

#[test]
fn test_empty_query() {
let parser = PostgresCompatibilityParser::new();
let result = parser.parse(" ").expect("failed to parse sql");
assert!(result.is_empty());

let result = parser.parse("").expect("failed to parse sql");
assert!(result.is_empty());

let result = parser.parse(";").expect("failed to parse sql");
assert!(result.is_empty());
}
}
159 changes: 5 additions & 154 deletions datafusion-postgres/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@ use std::collections::HashMap;
use std::sync::Arc;

use async_trait::async_trait;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::common::{ParamValues, ToDFSchema};
use datafusion::error::DataFusionError;
use datafusion::arrow::datatypes::DataType;
use datafusion::common::ParamValues;
use datafusion::logical_expr::LogicalPlan;
use datafusion::prelude::*;
use datafusion::sql::parser::Statement;
Expand All @@ -21,12 +20,12 @@ use pgwire::api::stmt::QueryParser;
use pgwire::api::stmt::StoredStatement;
use pgwire::api::{ClientInfo, ErrorHandler, PgWireServerHandlers, Type};
use pgwire::error::{PgWireError, PgWireResult};
use pgwire::messages::response::TransactionStatus;
use pgwire::types::format::FormatOptions;

use crate::auth::AuthManager;
use crate::client;
use crate::hooks::set_show::SetShowHook;
use crate::hooks::transactions::TransactionStatementHook;
use crate::hooks::QueryHook;
use arrow_pg::datatypes::df;
use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
Expand Down Expand Up @@ -107,7 +106,8 @@ impl DfSessionService {
session_context: Arc<SessionContext>,
auth_manager: Arc<AuthManager>,
) -> DfSessionService {
let hooks: Vec<Arc<dyn QueryHook>> = vec![Arc::new(SetShowHook)];
let hooks: Vec<Arc<dyn QueryHook>> =
vec![Arc::new(SetShowHook), Arc::new(TransactionStatementHook)];
Self::new_with_hooks(session_context, auth_manager, hooks)
}

Expand Down Expand Up @@ -203,57 +203,6 @@ impl DfSessionService {
// If we can't determine the table, default to All
ResourceType::All
}

async fn try_respond_transaction_statements<C>(
&self,
client: &C,
query_lower: &str,
) -> PgWireResult<Option<Response>>
where
C: ClientInfo,
{
// Transaction handling based on pgwire example:
// https://github.com/sunng87/pgwire/blob/master/examples/transaction.rs#L57
match query_lower.trim() {
"begin" | "begin transaction" | "begin work" | "start transaction" => {
match client.transaction_status() {
TransactionStatus::Idle => {
Ok(Some(Response::TransactionStart(Tag::new("BEGIN"))))
}
TransactionStatus::Transaction => {
// PostgreSQL behavior: ignore nested BEGIN, just return SUCCESS
// This matches PostgreSQL's handling of nested transaction blocks
log::warn!("BEGIN command ignored: already in transaction block");
Ok(Some(Response::Execution(Tag::new("BEGIN"))))
}
TransactionStatus::Error => {
// Can't start new transaction from failed state
Err(PgWireError::UserError(Box::new(
pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
"25P01".to_string(),
"current transaction is aborted, commands ignored until end of transaction block".to_string(),
),
)))
}
}
}
"commit" | "commit transaction" | "commit work" | "end" | "end transaction" => {
match client.transaction_status() {
TransactionStatus::Idle | TransactionStatus::Transaction => {
Ok(Some(Response::TransactionEnd(Tag::new("COMMIT"))))
}
TransactionStatus::Error => {
Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK"))))
}
}
}
"rollback" | "rollback transaction" | "rollback work" | "abort" => {
Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK"))))
}
_ => Ok(None),
}
}
}

#[async_trait]
Expand All @@ -264,15 +213,6 @@ impl SimpleQueryHandler for DfSessionService {
{
log::debug!("Received query: {query}"); // Log the query for debugging

// Check for transaction commands early to avoid SQL parsing issues with ABORT
let query_lower = query.to_lowercase().trim().to_string();
if let Some(resp) = self
.try_respond_transaction_statements(client, &query_lower)
.await?
{
return Ok(vec![resp]);
}

let statements = self
.parser
.sql_parser
Expand Down Expand Up @@ -314,18 +254,6 @@ impl SimpleQueryHandler for DfSessionService {
}
}

// Check if we're in a failed transaction and block non-transaction
// commands
if client.transaction_status() == TransactionStatus::Error {
return Err(PgWireError::UserError(Box::new(
pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
"25P01".to_string(),
"current transaction is aborted, commands ignored until end of transaction block".to_string(),
),
)));
}

let df_result = {
let timeout = client::get_statement_timeout(client);
if let Some(timeout_duration) = timeout {
Expand Down Expand Up @@ -480,25 +408,6 @@ impl ExtendedQueryHandler for DfSessionService {
.await?;
}

if let Some(resp) = self
.try_respond_transaction_statements(client, &query)
.await?
{
return Ok(resp);
}

// Check if we're in a failed transaction and block non-transaction
// commands
if client.transaction_status() == TransactionStatus::Error {
return Err(PgWireError::UserError(Box::new(
pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
"25P01".to_string(),
"current transaction is aborted, commands ignored until end of transaction block".to_string(),
),
)));
}

if let (_, Some((_, plan))) = &portal.statement.statement {
let param_types = plan
.get_parameter_types()
Expand Down Expand Up @@ -594,55 +503,6 @@ pub struct Parser {
query_hooks: Vec<Arc<dyn QueryHook>>,
}

impl Parser {
fn try_shortcut_parse_plan(&self, sql: &str) -> Result<Option<LogicalPlan>, DataFusionError> {
// Check for transaction commands that shouldn't be parsed by DataFusion
let sql_lower = sql.to_lowercase();
let sql_trimmed = sql_lower.trim();

if matches!(
sql_trimmed,
"" | "begin"
| "begin transaction"
| "begin work"
| "start transaction"
| "commit"
| "commit transaction"
| "commit work"
| "end"
| "end transaction"
| "rollback"
| "rollback transaction"
| "rollback work"
| "abort"
) {
// Return a dummy plan for transaction commands - they'll be handled by transaction handler
let dummy_schema = datafusion::common::DFSchema::empty();
return Ok(Some(LogicalPlan::EmptyRelation(
datafusion::logical_expr::EmptyRelation {
produce_one_row: false,
schema: Arc::new(dummy_schema),
},
)));
}

// show statement may not be supported by datafusion
if sql_trimmed.starts_with("show") {
let show_schema =
Arc::new(Schema::new(vec![Field::new("show", DataType::Utf8, false)]));
let df_schema = show_schema.to_dfschema()?;
return Ok(Some(LogicalPlan::EmptyRelation(
datafusion::logical_expr::EmptyRelation {
produce_one_row: true,
schema: Arc::new(df_schema),
},
)));
}

Ok(None)
}
}

#[async_trait]
impl QueryParser for Parser {
type Statement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>);
Expand All @@ -667,15 +527,6 @@ impl QueryParser for Parser {
}

let statement = statements.remove(0);

// Check for transaction commands that shouldn't be parsed by DataFusion
if let Some(plan) = self
.try_shortcut_parse_plan(sql)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
{
return Ok((sql.to_string(), Some((statement, plan))));
}

let query = statement.to_string();

let context = &self.session_context;
Expand Down
1 change: 1 addition & 0 deletions datafusion-postgres/src/hooks/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod set_show;
pub mod transactions;

use async_trait::async_trait;

Expand Down
55 changes: 27 additions & 28 deletions datafusion-postgres/src/hooks/set_show.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,36 +47,35 @@ impl QueryHook for SetShowHook {
_session_context: &SessionContext,
_client: &(dyn ClientInfo + Send + Sync),
) -> Option<PgWireResult<LogicalPlan>> {
let sql_lower = stmt.to_string().to_lowercase();
let sql_trimmed = sql_lower.trim();

if sql_trimmed.starts_with("show") {
let show_schema =
Arc::new(Schema::new(vec![Field::new("show", DataType::Utf8, false)]));
let result = show_schema
.to_dfschema()
.map(|df_schema| {
LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
produce_one_row: true,
schema: Arc::new(df_schema),
match stmt {
Statement::Set { .. } => {
let show_schema = Arc::new(Schema::new(Vec::<Field>::new()));
let result = show_schema
.to_dfschema()
.map(|df_schema| {
LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
produce_one_row: true,
schema: Arc::new(df_schema),
})
})
})
.map_err(|e| PgWireError::ApiError(Box::new(e)));
Some(result)
} else if sql_trimmed.starts_with("set") {
let show_schema = Arc::new(Schema::new(Vec::<Field>::new()));
let result = show_schema
.to_dfschema()
.map(|df_schema| {
LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
produce_one_row: true,
schema: Arc::new(df_schema),
.map_err(|e| PgWireError::ApiError(Box::new(e)));
Some(result)
}
Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
let show_schema =
Arc::new(Schema::new(vec![Field::new("show", DataType::Utf8, false)]));
let result = show_schema
.to_dfschema()
.map(|df_schema| {
LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
produce_one_row: true,
schema: Arc::new(df_schema),
})
})
})
.map_err(|e| PgWireError::ApiError(Box::new(e)));
Some(result)
} else {
None
.map_err(|e| PgWireError::ApiError(Box::new(e)));
Some(result)
}
_ => None,
}
}

Expand Down
Loading
Loading