diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index 0f3b68b..8fdd5e3 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -24,12 +24,12 @@ use pgwire::types::format::FormatOptions; use crate::auth::AuthManager; use crate::client; +use crate::hooks::permissions::PermissionsHook; use crate::hooks::set_show::SetShowHook; use crate::hooks::transactions::TransactionStatementHook; use crate::hooks::QueryHook; use arrow_pg::datatypes::df; use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type}; -use datafusion_pg_catalog::pg_catalog::context::{Permission, ResourceType}; use datafusion_pg_catalog::sql::PostgresCompatibilityParser; /// Simple startup handler that does no authentication @@ -52,12 +52,10 @@ impl HandlerFactory { pub fn new_with_hooks( session_context: Arc, - auth_manager: Arc, query_hooks: Vec>, ) -> Self { let session_service = Arc::new(DfSessionService::new_with_hooks( session_context, - auth_manager.clone(), query_hooks, )); HandlerFactory { session_service } @@ -97,7 +95,6 @@ impl ErrorHandler for LoggingErrorHandler { pub struct DfSessionService { session_context: Arc, parser: Arc, - auth_manager: Arc, query_hooks: Vec>, } @@ -106,14 +103,16 @@ impl DfSessionService { session_context: Arc, auth_manager: Arc, ) -> DfSessionService { - let hooks: Vec> = - vec![Arc::new(SetShowHook), Arc::new(TransactionStatementHook)]; - Self::new_with_hooks(session_context, auth_manager, hooks) + let hooks: Vec> = vec![ + Arc::new(PermissionsHook::new(auth_manager)), + Arc::new(SetShowHook), + Arc::new(TransactionStatementHook), + ]; + Self::new_with_hooks(session_context, hooks) } pub fn new_with_hooks( session_context: Arc, - auth_manager: Arc, query_hooks: Vec>, ) -> DfSessionService { let parser = Arc::new(Parser { @@ -124,85 +123,9 @@ impl DfSessionService { DfSessionService { session_context, parser, - auth_manager, query_hooks, } } - - /// Check if the current user has permission to execute a query - async fn check_query_permission(&self, client: &C, query: &str) -> PgWireResult<()> - where - C: ClientInfo, - { - // Get the username from client metadata - let username = client - .metadata() - .get("user") - .map(|s| s.as_str()) - .unwrap_or("anonymous"); - - // Parse query to determine required permissions - let query_lower = query.to_lowercase(); - let query_trimmed = query_lower.trim(); - - let (required_permission, resource) = if query_trimmed.starts_with("select") { - (Permission::Select, self.extract_table_from_query(query)) - } else if query_trimmed.starts_with("insert") { - (Permission::Insert, self.extract_table_from_query(query)) - } else if query_trimmed.starts_with("update") { - (Permission::Update, self.extract_table_from_query(query)) - } else if query_trimmed.starts_with("delete") { - (Permission::Delete, self.extract_table_from_query(query)) - } else if query_trimmed.starts_with("create table") - || query_trimmed.starts_with("create view") - { - (Permission::Create, ResourceType::All) - } else if query_trimmed.starts_with("drop") { - (Permission::Drop, self.extract_table_from_query(query)) - } else if query_trimmed.starts_with("alter") { - (Permission::Alter, self.extract_table_from_query(query)) - } else { - // For other queries (SHOW, EXPLAIN, etc.), allow all users - return Ok(()); - }; - - // Check permission - let has_permission = self - .auth_manager - .check_permission(username, required_permission, resource) - .await; - - if !has_permission { - return Err(PgWireError::UserError(Box::new( - pgwire::error::ErrorInfo::new( - "ERROR".to_string(), - "42501".to_string(), // insufficient_privilege - format!("permission denied for user \"{username}\""), - ), - ))); - } - - Ok(()) - } - - /// Extract table name from query (simplified parsing) - fn extract_table_from_query(&self, query: &str) -> ResourceType { - let words: Vec<&str> = query.split_whitespace().collect(); - - // Simple heuristic to find table names - for (i, word) in words.iter().enumerate() { - let word_lower = word.to_lowercase(); - if (word_lower == "from" || word_lower == "into" || word_lower == "table") - && i + 1 < words.len() - { - let table_name = words[i + 1].trim_matches(|c| c == '(' || c == ')' || c == ';'); - return ResourceType::Table(table_name.to_string()); - } - } - - // If we can't determine the table, default to All - ResourceType::All - } } #[async_trait] @@ -230,19 +153,6 @@ impl SimpleQueryHandler for DfSessionService { let query = statement.to_string(); let query_lower = query.to_lowercase().trim().to_string(); - // Check permissions for the query (skip for SET, transaction, and SHOW statements) - if !query_lower.starts_with("set") - && !query_lower.starts_with("begin") - && !query_lower.starts_with("commit") - && !query_lower.starts_with("rollback") - && !query_lower.starts_with("start") - && !query_lower.starts_with("end") - && !query_lower.starts_with("abort") - && !query_lower.starts_with("show") - { - self.check_query_permission(client, &query).await?; - } - // Call query hooks with the parsed statement for hook in &self.query_hooks { if let Some(result) = hook @@ -402,12 +312,6 @@ impl ExtendedQueryHandler for DfSessionService { } } - // Check permissions for the query (skip for SET and SHOW statements) - if !query.starts_with("set") && !query.starts_with("show") { - self.check_query_permission(client, &portal.statement.statement.0) - .await?; - } - if let (_, Some((_, plan))) = &portal.statement.statement { let param_types = plan .get_parameter_types() @@ -632,10 +536,9 @@ mod tests { // which would exit the entire statement loop, preventing subsequent statements // from being processed. let session_context = Arc::new(SessionContext::new()); - let auth_manager = Arc::new(AuthManager::new()); let hooks: Vec> = vec![Arc::new(TestHook)]; - let service = DfSessionService::new_with_hooks(session_context, auth_manager, hooks); + let service = DfSessionService::new_with_hooks(session_context, hooks); let mut client = MockClient::new(); diff --git a/datafusion-postgres/src/hooks/mod.rs b/datafusion-postgres/src/hooks/mod.rs index d6cd7bb..2f12ef9 100644 --- a/datafusion-postgres/src/hooks/mod.rs +++ b/datafusion-postgres/src/hooks/mod.rs @@ -1,3 +1,4 @@ +pub mod permissions; pub mod set_show; pub mod transactions; diff --git a/datafusion-postgres/src/hooks/permissions.rs b/datafusion-postgres/src/hooks/permissions.rs new file mode 100644 index 0000000..4d42527 --- /dev/null +++ b/datafusion-postgres/src/hooks/permissions.rs @@ -0,0 +1,160 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use datafusion::common::ParamValues; +use datafusion::logical_expr::LogicalPlan; +use datafusion::prelude::SessionContext; +use datafusion::sql::sqlparser::ast::Statement; +use datafusion_pg_catalog::pg_catalog::context::{Permission, ResourceType}; +use pgwire::api::results::Response; +use pgwire::api::ClientInfo; +use pgwire::error::{PgWireError, PgWireResult}; + +use crate::auth::AuthManager; +use crate::QueryHook; + +#[derive(Debug)] +pub struct PermissionsHook { + auth_manager: Arc, +} + +impl PermissionsHook { + pub fn new(auth_manager: Arc) -> Self { + PermissionsHook { auth_manager } + } + + /// Check if the current user has permission to execute a query + async fn check_query_permission(&self, client: &C, query: &str) -> PgWireResult<()> + where + C: ClientInfo + ?Sized, + { + // Get the username from client metadata + let username = client + .metadata() + .get("user") + .map(|s| s.as_str()) + .unwrap_or("anonymous"); + + // Parse query to determine required permissions + let query_lower = query.to_lowercase(); + let query_trimmed = query_lower.trim(); + + let (required_permission, resource) = if query_trimmed.starts_with("select") { + (Permission::Select, self.extract_table_from_query(query)) + } else if query_trimmed.starts_with("insert") { + (Permission::Insert, self.extract_table_from_query(query)) + } else if query_trimmed.starts_with("update") { + (Permission::Update, self.extract_table_from_query(query)) + } else if query_trimmed.starts_with("delete") { + (Permission::Delete, self.extract_table_from_query(query)) + } else if query_trimmed.starts_with("create table") + || query_trimmed.starts_with("create view") + { + (Permission::Create, ResourceType::All) + } else if query_trimmed.starts_with("drop") { + (Permission::Drop, self.extract_table_from_query(query)) + } else if query_trimmed.starts_with("alter") { + (Permission::Alter, self.extract_table_from_query(query)) + } else { + // For other queries (SHOW, EXPLAIN, etc.), allow all users + return Ok(()); + }; + + // Check permission + let has_permission = self + .auth_manager + .check_permission(username, required_permission, resource) + .await; + + if !has_permission { + return Err(PgWireError::UserError(Box::new( + pgwire::error::ErrorInfo::new( + "ERROR".to_string(), + "42501".to_string(), // insufficient_privilege + format!("permission denied for user \"{username}\""), + ), + ))); + } + + Ok(()) + } + + /// Extract table name from query (simplified parsing) + fn extract_table_from_query(&self, query: &str) -> ResourceType { + let words: Vec<&str> = query.split_whitespace().collect(); + + // Simple heuristic to find table names + for (i, word) in words.iter().enumerate() { + let word_lower = word.to_lowercase(); + if (word_lower == "from" || word_lower == "into" || word_lower == "table") + && i + 1 < words.len() + { + let table_name = words[i + 1].trim_matches(|c| c == '(' || c == ')' || c == ';'); + return ResourceType::Table(table_name.to_string()); + } + } + + // If we can't determine the table, default to All + ResourceType::All + } +} + +#[async_trait] +impl QueryHook for PermissionsHook { + /// called in simple query handler to return response directly + async fn handle_simple_query( + &self, + statement: &Statement, + _session_context: &SessionContext, + client: &mut (dyn ClientInfo + Send + Sync), + ) -> Option> { + let query_lower = statement.to_string().to_lowercase(); + + // Check permissions for the query (skip for SET, transaction, and SHOW statements) + if !query_lower.starts_with("set") + && !query_lower.starts_with("begin") + && !query_lower.starts_with("commit") + && !query_lower.starts_with("rollback") + && !query_lower.starts_with("start") + && !query_lower.starts_with("end") + && !query_lower.starts_with("abort") + && !query_lower.starts_with("show") + { + let res = self.check_query_permission(&*client, &query_lower).await; + if let Err(e) = res { + return Some(Err(e)); + } + } + + None + } + + async fn handle_extended_parse_query( + &self, + _stmt: &Statement, + _session_context: &SessionContext, + _client: &(dyn ClientInfo + Send + Sync), + ) -> Option> { + None + } + + async fn handle_extended_query( + &self, + statement: &Statement, + _logical_plan: &LogicalPlan, + _params: &ParamValues, + _session_context: &SessionContext, + client: &mut (dyn ClientInfo + Send + Sync), + ) -> Option> { + let query = statement.to_string().to_lowercase(); + + // Check permissions for the query (skip for SET and SHOW statements) + if !query.starts_with("set") && !query.starts_with("show") { + let res = self.check_query_permission(&*client, &query).await; + if let Err(e) = res { + return Some(Err(e)); + } + } + None + } +} diff --git a/datafusion-postgres/src/lib.rs b/datafusion-postgres/src/lib.rs index 4ced1fc..6996cd6 100644 --- a/datafusion-postgres/src/lib.rs +++ b/datafusion-postgres/src/lib.rs @@ -99,15 +99,10 @@ pub async fn serve( pub async fn serve_with_hooks( session_context: Arc, opts: &ServerOptions, - auth_manager: Arc, hooks: Vec>, ) -> Result<(), std::io::Error> { // Create the handler factory with authentication - let factory = Arc::new(HandlerFactory::new_with_hooks( - session_context, - auth_manager, - hooks, - )); + let factory = Arc::new(HandlerFactory::new_with_hooks(session_context, hooks)); serve_with_handlers(factory, opts).await }