Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 8 additions & 105 deletions datafusion-postgres/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ use pgwire::types::format::FormatOptions;

use crate::auth::AuthManager;
use crate::client;
use crate::hooks::permissions::PermissionsHook;
use crate::hooks::set_show::SetShowHook;
use crate::hooks::transactions::TransactionStatementHook;
use crate::hooks::QueryHook;
use arrow_pg::datatypes::df;
use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
use datafusion_pg_catalog::pg_catalog::context::{Permission, ResourceType};
use datafusion_pg_catalog::sql::PostgresCompatibilityParser;

/// Simple startup handler that does no authentication
Expand All @@ -52,12 +52,10 @@ impl HandlerFactory {

pub fn new_with_hooks(
session_context: Arc<SessionContext>,
auth_manager: Arc<AuthManager>,
query_hooks: Vec<Arc<dyn QueryHook>>,
) -> Self {
let session_service = Arc::new(DfSessionService::new_with_hooks(
session_context,
auth_manager.clone(),
query_hooks,
));
HandlerFactory { session_service }
Expand Down Expand Up @@ -97,7 +95,6 @@ impl ErrorHandler for LoggingErrorHandler {
pub struct DfSessionService {
session_context: Arc<SessionContext>,
parser: Arc<Parser>,
auth_manager: Arc<AuthManager>,
query_hooks: Vec<Arc<dyn QueryHook>>,
}

Expand All @@ -106,14 +103,16 @@ impl DfSessionService {
session_context: Arc<SessionContext>,
auth_manager: Arc<AuthManager>,
) -> DfSessionService {
let hooks: Vec<Arc<dyn QueryHook>> =
vec![Arc::new(SetShowHook), Arc::new(TransactionStatementHook)];
Self::new_with_hooks(session_context, auth_manager, hooks)
let hooks: Vec<Arc<dyn QueryHook>> = vec![
Arc::new(PermissionsHook::new(auth_manager)),
Arc::new(SetShowHook),
Arc::new(TransactionStatementHook),
];
Self::new_with_hooks(session_context, hooks)
}

pub fn new_with_hooks(
session_context: Arc<SessionContext>,
auth_manager: Arc<AuthManager>,
query_hooks: Vec<Arc<dyn QueryHook>>,
) -> DfSessionService {
let parser = Arc::new(Parser {
Expand All @@ -124,85 +123,9 @@ impl DfSessionService {
DfSessionService {
session_context,
parser,
auth_manager,
query_hooks,
}
}

/// Check if the current user has permission to execute a query
async fn check_query_permission<C>(&self, client: &C, query: &str) -> PgWireResult<()>
where
C: ClientInfo,
{
// Get the username from client metadata
let username = client
.metadata()
.get("user")
.map(|s| s.as_str())
.unwrap_or("anonymous");

// Parse query to determine required permissions
let query_lower = query.to_lowercase();
let query_trimmed = query_lower.trim();

let (required_permission, resource) = if query_trimmed.starts_with("select") {
(Permission::Select, self.extract_table_from_query(query))
} else if query_trimmed.starts_with("insert") {
(Permission::Insert, self.extract_table_from_query(query))
} else if query_trimmed.starts_with("update") {
(Permission::Update, self.extract_table_from_query(query))
} else if query_trimmed.starts_with("delete") {
(Permission::Delete, self.extract_table_from_query(query))
} else if query_trimmed.starts_with("create table")
|| query_trimmed.starts_with("create view")
{
(Permission::Create, ResourceType::All)
} else if query_trimmed.starts_with("drop") {
(Permission::Drop, self.extract_table_from_query(query))
} else if query_trimmed.starts_with("alter") {
(Permission::Alter, self.extract_table_from_query(query))
} else {
// For other queries (SHOW, EXPLAIN, etc.), allow all users
return Ok(());
};

// Check permission
let has_permission = self
.auth_manager
.check_permission(username, required_permission, resource)
.await;

if !has_permission {
return Err(PgWireError::UserError(Box::new(
pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
"42501".to_string(), // insufficient_privilege
format!("permission denied for user \"{username}\""),
),
)));
}

Ok(())
}

/// Extract table name from query (simplified parsing)
fn extract_table_from_query(&self, query: &str) -> ResourceType {
let words: Vec<&str> = query.split_whitespace().collect();

// Simple heuristic to find table names
for (i, word) in words.iter().enumerate() {
let word_lower = word.to_lowercase();
if (word_lower == "from" || word_lower == "into" || word_lower == "table")
&& i + 1 < words.len()
{
let table_name = words[i + 1].trim_matches(|c| c == '(' || c == ')' || c == ';');
return ResourceType::Table(table_name.to_string());
}
}

// If we can't determine the table, default to All
ResourceType::All
}
}

#[async_trait]
Expand Down Expand Up @@ -230,19 +153,6 @@ impl SimpleQueryHandler for DfSessionService {
let query = statement.to_string();
let query_lower = query.to_lowercase().trim().to_string();

// Check permissions for the query (skip for SET, transaction, and SHOW statements)
if !query_lower.starts_with("set")
&& !query_lower.starts_with("begin")
&& !query_lower.starts_with("commit")
&& !query_lower.starts_with("rollback")
&& !query_lower.starts_with("start")
&& !query_lower.starts_with("end")
&& !query_lower.starts_with("abort")
&& !query_lower.starts_with("show")
{
self.check_query_permission(client, &query).await?;
}

// Call query hooks with the parsed statement
for hook in &self.query_hooks {
if let Some(result) = hook
Expand Down Expand Up @@ -402,12 +312,6 @@ impl ExtendedQueryHandler for DfSessionService {
}
}

// Check permissions for the query (skip for SET and SHOW statements)
if !query.starts_with("set") && !query.starts_with("show") {
self.check_query_permission(client, &portal.statement.statement.0)
.await?;
}

if let (_, Some((_, plan))) = &portal.statement.statement {
let param_types = plan
.get_parameter_types()
Expand Down Expand Up @@ -632,10 +536,9 @@ mod tests {
// which would exit the entire statement loop, preventing subsequent statements
// from being processed.
let session_context = Arc::new(SessionContext::new());
let auth_manager = Arc::new(AuthManager::new());

let hooks: Vec<Arc<dyn QueryHook>> = vec![Arc::new(TestHook)];
let service = DfSessionService::new_with_hooks(session_context, auth_manager, hooks);
let service = DfSessionService::new_with_hooks(session_context, hooks);

let mut client = MockClient::new();

Expand Down
1 change: 1 addition & 0 deletions datafusion-postgres/src/hooks/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod permissions;
pub mod set_show;
pub mod transactions;

Expand Down
160 changes: 160 additions & 0 deletions datafusion-postgres/src/hooks/permissions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
use std::sync::Arc;

use async_trait::async_trait;
use datafusion::common::ParamValues;
use datafusion::logical_expr::LogicalPlan;
use datafusion::prelude::SessionContext;
use datafusion::sql::sqlparser::ast::Statement;
use datafusion_pg_catalog::pg_catalog::context::{Permission, ResourceType};
use pgwire::api::results::Response;
use pgwire::api::ClientInfo;
use pgwire::error::{PgWireError, PgWireResult};

use crate::auth::AuthManager;
use crate::QueryHook;

#[derive(Debug)]
pub struct PermissionsHook {
auth_manager: Arc<AuthManager>,
}

impl PermissionsHook {
pub fn new(auth_manager: Arc<AuthManager>) -> Self {
PermissionsHook { auth_manager }
}

/// Check if the current user has permission to execute a query
async fn check_query_permission<C>(&self, client: &C, query: &str) -> PgWireResult<()>
where
C: ClientInfo + ?Sized,
{
// Get the username from client metadata
let username = client
.metadata()
.get("user")
.map(|s| s.as_str())
.unwrap_or("anonymous");

// Parse query to determine required permissions
let query_lower = query.to_lowercase();
let query_trimmed = query_lower.trim();

let (required_permission, resource) = if query_trimmed.starts_with("select") {
(Permission::Select, self.extract_table_from_query(query))
} else if query_trimmed.starts_with("insert") {
(Permission::Insert, self.extract_table_from_query(query))
} else if query_trimmed.starts_with("update") {
(Permission::Update, self.extract_table_from_query(query))
} else if query_trimmed.starts_with("delete") {
(Permission::Delete, self.extract_table_from_query(query))
} else if query_trimmed.starts_with("create table")
|| query_trimmed.starts_with("create view")
{
(Permission::Create, ResourceType::All)
} else if query_trimmed.starts_with("drop") {
(Permission::Drop, self.extract_table_from_query(query))
} else if query_trimmed.starts_with("alter") {
(Permission::Alter, self.extract_table_from_query(query))
} else {
// For other queries (SHOW, EXPLAIN, etc.), allow all users
return Ok(());
};

// Check permission
let has_permission = self
.auth_manager
.check_permission(username, required_permission, resource)
.await;

if !has_permission {
return Err(PgWireError::UserError(Box::new(
pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
"42501".to_string(), // insufficient_privilege
format!("permission denied for user \"{username}\""),
),
)));
}

Ok(())
}

/// Extract table name from query (simplified parsing)
fn extract_table_from_query(&self, query: &str) -> ResourceType {
let words: Vec<&str> = query.split_whitespace().collect();

// Simple heuristic to find table names
for (i, word) in words.iter().enumerate() {
let word_lower = word.to_lowercase();
if (word_lower == "from" || word_lower == "into" || word_lower == "table")
&& i + 1 < words.len()
{
let table_name = words[i + 1].trim_matches(|c| c == '(' || c == ')' || c == ';');
return ResourceType::Table(table_name.to_string());
}
}

// If we can't determine the table, default to All
ResourceType::All
}
}

#[async_trait]
impl QueryHook for PermissionsHook {
/// called in simple query handler to return response directly
async fn handle_simple_query(
&self,
statement: &Statement,
_session_context: &SessionContext,
client: &mut (dyn ClientInfo + Send + Sync),
) -> Option<PgWireResult<Response>> {
let query_lower = statement.to_string().to_lowercase();

// Check permissions for the query (skip for SET, transaction, and SHOW statements)
if !query_lower.starts_with("set")
&& !query_lower.starts_with("begin")
&& !query_lower.starts_with("commit")
&& !query_lower.starts_with("rollback")
&& !query_lower.starts_with("start")
&& !query_lower.starts_with("end")
&& !query_lower.starts_with("abort")
&& !query_lower.starts_with("show")
{
let res = self.check_query_permission(&*client, &query_lower).await;
if let Err(e) = res {
return Some(Err(e));
}
}

None
}

async fn handle_extended_parse_query(
&self,
_stmt: &Statement,
_session_context: &SessionContext,
_client: &(dyn ClientInfo + Send + Sync),
) -> Option<PgWireResult<LogicalPlan>> {
None
}

async fn handle_extended_query(
&self,
statement: &Statement,
_logical_plan: &LogicalPlan,
_params: &ParamValues,
_session_context: &SessionContext,
client: &mut (dyn ClientInfo + Send + Sync),
) -> Option<PgWireResult<Response>> {
let query = statement.to_string().to_lowercase();

// Check permissions for the query (skip for SET and SHOW statements)
if !query.starts_with("set") && !query.starts_with("show") {
let res = self.check_query_permission(&*client, &query).await;
if let Err(e) = res {
return Some(Err(e));
}
}
None
}
}
7 changes: 1 addition & 6 deletions datafusion-postgres/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,10 @@ pub async fn serve(
pub async fn serve_with_hooks(
session_context: Arc<SessionContext>,
opts: &ServerOptions,
auth_manager: Arc<AuthManager>,
hooks: Vec<Arc<dyn QueryHook>>,
) -> Result<(), std::io::Error> {
// Create the handler factory with authentication
let factory = Arc::new(HandlerFactory::new_with_hooks(
session_context,
auth_manager,
hooks,
));
let factory = Arc::new(HandlerFactory::new_with_hooks(session_context, hooks));

serve_with_handlers(factory, opts).await
}
Expand Down
Loading