Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 4 additions & 12 deletions datafusion-postgres/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,7 @@ impl SimpleQueryHandler for DfSessionService {

let mut results = vec![];
'stmt: for statement in statements {
// TODO: improve statement check by using statement directly
let query = statement.to_string();
let query_lower = query.to_lowercase().trim().to_string();

// Call query hooks with the parsed statement
for hook in &self.query_hooks {
Expand Down Expand Up @@ -179,7 +177,7 @@ impl SimpleQueryHandler for DfSessionService {
}
};

if query_lower.starts_with("insert into") {
if matches!(statement, sqlparser::ast::Statement::Insert(_)) {
let resp = map_rows_affected_for_insert(&df).await?;
results.push(resp);
} else {
Expand Down Expand Up @@ -265,13 +263,7 @@ impl ExtendedQueryHandler for DfSessionService {
where
C: ClientInfo + Unpin + Send + Sync,
{
let query = portal
.statement
.statement
.0
.to_lowercase()
.trim()
.to_string();
let query = &portal.statement.statement.0;
log::debug!("Received execute extended query: {query}"); // Log for debugging

// Check query hooks first
Expand Down Expand Up @@ -302,7 +294,7 @@ impl ExtendedQueryHandler for DfSessionService {
}
}

if let (_, Some((_, plan))) = &portal.statement.statement {
if let (_, Some((statement, plan))) = &portal.statement.statement {
let param_types = plan
.get_parameter_types()
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
Expand Down Expand Up @@ -345,7 +337,7 @@ impl ExtendedQueryHandler for DfSessionService {
}
};

if query.starts_with("insert into") {
if matches!(statement, sqlparser::ast::Statement::Insert(_)) {
let resp = map_rows_affected_for_insert(&dataframe).await?;

Ok(resp)
Expand Down
98 changes: 50 additions & 48 deletions datafusion-postgres/src/hooks/permissions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,12 @@ impl PermissionsHook {
PermissionsHook { auth_manager }
}

/// Check if the current user has permission to execute a query
async fn check_query_permission<C>(&self, client: &C, query: &str) -> PgWireResult<()>
/// Check if the current user has permission to execute a statement
async fn check_statement_permission<C>(
&self,
client: &C,
statement: &Statement,
) -> PgWireResult<()>
where
C: ClientInfo + ?Sized,
{
Expand All @@ -35,29 +39,19 @@ impl PermissionsHook {
.map(|s| s.as_str())
.unwrap_or("anonymous");

// Parse query to determine required permissions
let query_lower = query.to_lowercase();
let query_trimmed = query_lower.trim();

let (required_permission, resource) = if query_trimmed.starts_with("select") {
(Permission::Select, ResourceType::All)
} else if query_trimmed.starts_with("insert") {
(Permission::Insert, ResourceType::All)
} else if query_trimmed.starts_with("update") {
(Permission::Update, ResourceType::All)
} else if query_trimmed.starts_with("delete") {
(Permission::Delete, ResourceType::All)
} else if query_trimmed.starts_with("create table")
|| query_trimmed.starts_with("create view")
{
(Permission::Create, ResourceType::All)
} else if query_trimmed.starts_with("drop") {
(Permission::Drop, ResourceType::All)
} else if query_trimmed.starts_with("alter") {
(Permission::Alter, ResourceType::All)
} else {
// For other queries (SHOW, EXPLAIN, etc.), allow all users
return Ok(());
// Determine required permissions based on Statement type
let (required_permission, resource) = match statement {
Statement::Query(_) => (Permission::Select, ResourceType::All),
Statement::Insert(_) => (Permission::Insert, ResourceType::All),
Statement::Update { .. } => (Permission::Update, ResourceType::All),
Statement::Delete(_) => (Permission::Delete, ResourceType::All),
Statement::CreateTable { .. } | Statement::CreateView { .. } => {
(Permission::Create, ResourceType::All)
}
Statement::Drop { .. } => (Permission::Drop, ResourceType::All),
Statement::AlterTable { .. } => (Permission::Alter, ResourceType::All),
// For other statements (SET, SHOW, EXPLAIN, transactions, etc.), allow all users
_ => return Ok(()),
};

// Check permission
Expand All @@ -78,6 +72,21 @@ impl PermissionsHook {

Ok(())
}

/// Check if a statement should skip permission checks
fn should_skip_permission_check(statement: &Statement) -> bool {
matches!(
statement,
Statement::Set { .. }
| Statement::ShowVariable { .. }
| Statement::ShowStatus { .. }
| Statement::StartTransaction { .. }
| Statement::Commit { .. }
| Statement::Rollback { .. }
| Statement::Savepoint { .. }
| Statement::ReleaseSavepoint { .. }
)
}
}

#[async_trait]
Expand All @@ -89,22 +98,14 @@ impl QueryHook for PermissionsHook {
_session_context: &SessionContext,
client: &mut (dyn ClientInfo + Send + Sync),
) -> Option<PgWireResult<Response>> {
let query_lower = statement.to_string().to_lowercase();

// Check permissions for the query (skip for SET, transaction, and SHOW statements)
if !query_lower.starts_with("set")
&& !query_lower.starts_with("begin")
&& !query_lower.starts_with("commit")
&& !query_lower.starts_with("rollback")
&& !query_lower.starts_with("start")
&& !query_lower.starts_with("end")
&& !query_lower.starts_with("abort")
&& !query_lower.starts_with("show")
{
let res = self.check_query_permission(&*client, &query_lower).await;
if let Err(e) = res {
return Some(Err(e));
}
// Skip permission checks for SET, SHOW, and transaction statements
if Self::should_skip_permission_check(statement) {
return None;
}

// Check permissions for other statements
if let Err(e) = self.check_statement_permission(&*client, statement).await {
return Some(Err(e));
}

None
Expand All @@ -127,15 +128,16 @@ impl QueryHook for PermissionsHook {
_session_context: &SessionContext,
client: &mut (dyn ClientInfo + Send + Sync),
) -> Option<PgWireResult<Response>> {
let query = statement.to_string().to_lowercase();
// Skip permission checks for SET and SHOW statements
if Self::should_skip_permission_check(statement) {
return None;
}

// Check permissions for the query (skip for SET and SHOW statements)
if !query.starts_with("set") && !query.starts_with("show") {
let res = self.check_query_permission(&*client, &query).await;
if let Err(e) = res {
return Some(Err(e));
}
// Check permissions for other statements
if let Err(e) = self.check_statement_permission(&*client, statement).await {
return Some(Err(e));
}

None
}
}
32 changes: 32 additions & 0 deletions datafusion-postgres/tests/pgadmin.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use pgwire::api::query::SimpleQueryHandler;

use datafusion_postgres::testing::*;

// pgAdmin startup queries from issue #178
// https://github.com/datafusion-contrib/datafusion-postgres/issues/178
const PGADMIN_QUERIES: &[&str] = &[
// Basic version query (fixed by #179)
"SELECT version()",
// Query to check for BDR extension and replication slots
r#"SELECT CASE
WHEN (SELECT count(extname) FROM pg_catalog.pg_extension WHERE extname='bdr') > 0
THEN 'pgd'
WHEN (SELECT COUNT(*) FROM pg_replication_slots) > 0
THEN 'log'
ELSE NULL
END as type"#,
];

#[tokio::test]
pub async fn test_pgadmin_startup_sql() {
let service = setup_handlers();
let mut client = MockClient::new();

for query in PGADMIN_QUERIES {
SimpleQueryHandler::do_query(&service, &mut client, query)
.await
.unwrap_or_else(|e| {
panic!("failed to run sql:\n--------------\n{query}\n--------------\n{e}")
});
}
}
Loading