Skip to content

Commit c2d02d0

Browse files
authored
Move permissions handling into a QueryHook (#254)
This moves permission handling into a new QueryHook implementation called `PermissionsHook` The `AuthManager` is now moved from `HandlerFactory` and `DfSessionService` and is now held only inside the `PermissionsHook` struct. This does not yet move from string processing to query matching, nor do we yet try to completely decouple the permissions logic completely from everything else.
1 parent c60f38f commit c2d02d0

File tree

4 files changed

+170
-111
lines changed

4 files changed

+170
-111
lines changed

datafusion-postgres/src/handlers.rs

Lines changed: 8 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@ use pgwire::types::format::FormatOptions;
2424

2525
use crate::auth::AuthManager;
2626
use crate::client;
27+
use crate::hooks::permissions::PermissionsHook;
2728
use crate::hooks::set_show::SetShowHook;
2829
use crate::hooks::transactions::TransactionStatementHook;
2930
use crate::hooks::QueryHook;
3031
use arrow_pg::datatypes::df;
3132
use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
32-
use datafusion_pg_catalog::pg_catalog::context::{Permission, ResourceType};
3333
use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
3434

3535
/// Simple startup handler that does no authentication
@@ -52,12 +52,10 @@ impl HandlerFactory {
5252

5353
pub fn new_with_hooks(
5454
session_context: Arc<SessionContext>,
55-
auth_manager: Arc<AuthManager>,
5655
query_hooks: Vec<Arc<dyn QueryHook>>,
5756
) -> Self {
5857
let session_service = Arc::new(DfSessionService::new_with_hooks(
5958
session_context,
60-
auth_manager.clone(),
6159
query_hooks,
6260
));
6361
HandlerFactory { session_service }
@@ -97,7 +95,6 @@ impl ErrorHandler for LoggingErrorHandler {
9795
pub struct DfSessionService {
9896
session_context: Arc<SessionContext>,
9997
parser: Arc<Parser>,
100-
auth_manager: Arc<AuthManager>,
10198
query_hooks: Vec<Arc<dyn QueryHook>>,
10299
}
103100

@@ -106,14 +103,16 @@ impl DfSessionService {
106103
session_context: Arc<SessionContext>,
107104
auth_manager: Arc<AuthManager>,
108105
) -> DfSessionService {
109-
let hooks: Vec<Arc<dyn QueryHook>> =
110-
vec![Arc::new(SetShowHook), Arc::new(TransactionStatementHook)];
111-
Self::new_with_hooks(session_context, auth_manager, hooks)
106+
let hooks: Vec<Arc<dyn QueryHook>> = vec![
107+
Arc::new(PermissionsHook::new(auth_manager)),
108+
Arc::new(SetShowHook),
109+
Arc::new(TransactionStatementHook),
110+
];
111+
Self::new_with_hooks(session_context, hooks)
112112
}
113113

114114
pub fn new_with_hooks(
115115
session_context: Arc<SessionContext>,
116-
auth_manager: Arc<AuthManager>,
117116
query_hooks: Vec<Arc<dyn QueryHook>>,
118117
) -> DfSessionService {
119118
let parser = Arc::new(Parser {
@@ -124,85 +123,9 @@ impl DfSessionService {
124123
DfSessionService {
125124
session_context,
126125
parser,
127-
auth_manager,
128126
query_hooks,
129127
}
130128
}
131-
132-
/// Check if the current user has permission to execute a query
133-
async fn check_query_permission<C>(&self, client: &C, query: &str) -> PgWireResult<()>
134-
where
135-
C: ClientInfo,
136-
{
137-
// Get the username from client metadata
138-
let username = client
139-
.metadata()
140-
.get("user")
141-
.map(|s| s.as_str())
142-
.unwrap_or("anonymous");
143-
144-
// Parse query to determine required permissions
145-
let query_lower = query.to_lowercase();
146-
let query_trimmed = query_lower.trim();
147-
148-
let (required_permission, resource) = if query_trimmed.starts_with("select") {
149-
(Permission::Select, self.extract_table_from_query(query))
150-
} else if query_trimmed.starts_with("insert") {
151-
(Permission::Insert, self.extract_table_from_query(query))
152-
} else if query_trimmed.starts_with("update") {
153-
(Permission::Update, self.extract_table_from_query(query))
154-
} else if query_trimmed.starts_with("delete") {
155-
(Permission::Delete, self.extract_table_from_query(query))
156-
} else if query_trimmed.starts_with("create table")
157-
|| query_trimmed.starts_with("create view")
158-
{
159-
(Permission::Create, ResourceType::All)
160-
} else if query_trimmed.starts_with("drop") {
161-
(Permission::Drop, self.extract_table_from_query(query))
162-
} else if query_trimmed.starts_with("alter") {
163-
(Permission::Alter, self.extract_table_from_query(query))
164-
} else {
165-
// For other queries (SHOW, EXPLAIN, etc.), allow all users
166-
return Ok(());
167-
};
168-
169-
// Check permission
170-
let has_permission = self
171-
.auth_manager
172-
.check_permission(username, required_permission, resource)
173-
.await;
174-
175-
if !has_permission {
176-
return Err(PgWireError::UserError(Box::new(
177-
pgwire::error::ErrorInfo::new(
178-
"ERROR".to_string(),
179-
"42501".to_string(), // insufficient_privilege
180-
format!("permission denied for user \"{username}\""),
181-
),
182-
)));
183-
}
184-
185-
Ok(())
186-
}
187-
188-
/// Extract table name from query (simplified parsing)
189-
fn extract_table_from_query(&self, query: &str) -> ResourceType {
190-
let words: Vec<&str> = query.split_whitespace().collect();
191-
192-
// Simple heuristic to find table names
193-
for (i, word) in words.iter().enumerate() {
194-
let word_lower = word.to_lowercase();
195-
if (word_lower == "from" || word_lower == "into" || word_lower == "table")
196-
&& i + 1 < words.len()
197-
{
198-
let table_name = words[i + 1].trim_matches(|c| c == '(' || c == ')' || c == ';');
199-
return ResourceType::Table(table_name.to_string());
200-
}
201-
}
202-
203-
// If we can't determine the table, default to All
204-
ResourceType::All
205-
}
206129
}
207130

208131
#[async_trait]
@@ -230,19 +153,6 @@ impl SimpleQueryHandler for DfSessionService {
230153
let query = statement.to_string();
231154
let query_lower = query.to_lowercase().trim().to_string();
232155

233-
// Check permissions for the query (skip for SET, transaction, and SHOW statements)
234-
if !query_lower.starts_with("set")
235-
&& !query_lower.starts_with("begin")
236-
&& !query_lower.starts_with("commit")
237-
&& !query_lower.starts_with("rollback")
238-
&& !query_lower.starts_with("start")
239-
&& !query_lower.starts_with("end")
240-
&& !query_lower.starts_with("abort")
241-
&& !query_lower.starts_with("show")
242-
{
243-
self.check_query_permission(client, &query).await?;
244-
}
245-
246156
// Call query hooks with the parsed statement
247157
for hook in &self.query_hooks {
248158
if let Some(result) = hook
@@ -402,12 +312,6 @@ impl ExtendedQueryHandler for DfSessionService {
402312
}
403313
}
404314

405-
// Check permissions for the query (skip for SET and SHOW statements)
406-
if !query.starts_with("set") && !query.starts_with("show") {
407-
self.check_query_permission(client, &portal.statement.statement.0)
408-
.await?;
409-
}
410-
411315
if let (_, Some((_, plan))) = &portal.statement.statement {
412316
let param_types = plan
413317
.get_parameter_types()
@@ -632,10 +536,9 @@ mod tests {
632536
// which would exit the entire statement loop, preventing subsequent statements
633537
// from being processed.
634538
let session_context = Arc::new(SessionContext::new());
635-
let auth_manager = Arc::new(AuthManager::new());
636539

637540
let hooks: Vec<Arc<dyn QueryHook>> = vec![Arc::new(TestHook)];
638-
let service = DfSessionService::new_with_hooks(session_context, auth_manager, hooks);
541+
let service = DfSessionService::new_with_hooks(session_context, hooks);
639542

640543
let mut client = MockClient::new();
641544

datafusion-postgres/src/hooks/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
pub mod permissions;
12
pub mod set_show;
23
pub mod transactions;
34

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
use std::sync::Arc;
2+
3+
use async_trait::async_trait;
4+
use datafusion::common::ParamValues;
5+
use datafusion::logical_expr::LogicalPlan;
6+
use datafusion::prelude::SessionContext;
7+
use datafusion::sql::sqlparser::ast::Statement;
8+
use datafusion_pg_catalog::pg_catalog::context::{Permission, ResourceType};
9+
use pgwire::api::results::Response;
10+
use pgwire::api::ClientInfo;
11+
use pgwire::error::{PgWireError, PgWireResult};
12+
13+
use crate::auth::AuthManager;
14+
use crate::QueryHook;
15+
16+
#[derive(Debug)]
17+
pub struct PermissionsHook {
18+
auth_manager: Arc<AuthManager>,
19+
}
20+
21+
impl PermissionsHook {
22+
pub fn new(auth_manager: Arc<AuthManager>) -> Self {
23+
PermissionsHook { auth_manager }
24+
}
25+
26+
/// Check if the current user has permission to execute a query
27+
async fn check_query_permission<C>(&self, client: &C, query: &str) -> PgWireResult<()>
28+
where
29+
C: ClientInfo + ?Sized,
30+
{
31+
// Get the username from client metadata
32+
let username = client
33+
.metadata()
34+
.get("user")
35+
.map(|s| s.as_str())
36+
.unwrap_or("anonymous");
37+
38+
// Parse query to determine required permissions
39+
let query_lower = query.to_lowercase();
40+
let query_trimmed = query_lower.trim();
41+
42+
let (required_permission, resource) = if query_trimmed.starts_with("select") {
43+
(Permission::Select, self.extract_table_from_query(query))
44+
} else if query_trimmed.starts_with("insert") {
45+
(Permission::Insert, self.extract_table_from_query(query))
46+
} else if query_trimmed.starts_with("update") {
47+
(Permission::Update, self.extract_table_from_query(query))
48+
} else if query_trimmed.starts_with("delete") {
49+
(Permission::Delete, self.extract_table_from_query(query))
50+
} else if query_trimmed.starts_with("create table")
51+
|| query_trimmed.starts_with("create view")
52+
{
53+
(Permission::Create, ResourceType::All)
54+
} else if query_trimmed.starts_with("drop") {
55+
(Permission::Drop, self.extract_table_from_query(query))
56+
} else if query_trimmed.starts_with("alter") {
57+
(Permission::Alter, self.extract_table_from_query(query))
58+
} else {
59+
// For other queries (SHOW, EXPLAIN, etc.), allow all users
60+
return Ok(());
61+
};
62+
63+
// Check permission
64+
let has_permission = self
65+
.auth_manager
66+
.check_permission(username, required_permission, resource)
67+
.await;
68+
69+
if !has_permission {
70+
return Err(PgWireError::UserError(Box::new(
71+
pgwire::error::ErrorInfo::new(
72+
"ERROR".to_string(),
73+
"42501".to_string(), // insufficient_privilege
74+
format!("permission denied for user \"{username}\""),
75+
),
76+
)));
77+
}
78+
79+
Ok(())
80+
}
81+
82+
/// Extract table name from query (simplified parsing)
83+
fn extract_table_from_query(&self, query: &str) -> ResourceType {
84+
let words: Vec<&str> = query.split_whitespace().collect();
85+
86+
// Simple heuristic to find table names
87+
for (i, word) in words.iter().enumerate() {
88+
let word_lower = word.to_lowercase();
89+
if (word_lower == "from" || word_lower == "into" || word_lower == "table")
90+
&& i + 1 < words.len()
91+
{
92+
let table_name = words[i + 1].trim_matches(|c| c == '(' || c == ')' || c == ';');
93+
return ResourceType::Table(table_name.to_string());
94+
}
95+
}
96+
97+
// If we can't determine the table, default to All
98+
ResourceType::All
99+
}
100+
}
101+
102+
#[async_trait]
103+
impl QueryHook for PermissionsHook {
104+
/// called in simple query handler to return response directly
105+
async fn handle_simple_query(
106+
&self,
107+
statement: &Statement,
108+
_session_context: &SessionContext,
109+
client: &mut (dyn ClientInfo + Send + Sync),
110+
) -> Option<PgWireResult<Response>> {
111+
let query_lower = statement.to_string().to_lowercase();
112+
113+
// Check permissions for the query (skip for SET, transaction, and SHOW statements)
114+
if !query_lower.starts_with("set")
115+
&& !query_lower.starts_with("begin")
116+
&& !query_lower.starts_with("commit")
117+
&& !query_lower.starts_with("rollback")
118+
&& !query_lower.starts_with("start")
119+
&& !query_lower.starts_with("end")
120+
&& !query_lower.starts_with("abort")
121+
&& !query_lower.starts_with("show")
122+
{
123+
let res = self.check_query_permission(&*client, &query_lower).await;
124+
if let Err(e) = res {
125+
return Some(Err(e));
126+
}
127+
}
128+
129+
None
130+
}
131+
132+
async fn handle_extended_parse_query(
133+
&self,
134+
_stmt: &Statement,
135+
_session_context: &SessionContext,
136+
_client: &(dyn ClientInfo + Send + Sync),
137+
) -> Option<PgWireResult<LogicalPlan>> {
138+
None
139+
}
140+
141+
async fn handle_extended_query(
142+
&self,
143+
statement: &Statement,
144+
_logical_plan: &LogicalPlan,
145+
_params: &ParamValues,
146+
_session_context: &SessionContext,
147+
client: &mut (dyn ClientInfo + Send + Sync),
148+
) -> Option<PgWireResult<Response>> {
149+
let query = statement.to_string().to_lowercase();
150+
151+
// Check permissions for the query (skip for SET and SHOW statements)
152+
if !query.starts_with("set") && !query.starts_with("show") {
153+
let res = self.check_query_permission(&*client, &query).await;
154+
if let Err(e) = res {
155+
return Some(Err(e));
156+
}
157+
}
158+
None
159+
}
160+
}

datafusion-postgres/src/lib.rs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -99,15 +99,10 @@ pub async fn serve(
9999
pub async fn serve_with_hooks(
100100
session_context: Arc<SessionContext>,
101101
opts: &ServerOptions,
102-
auth_manager: Arc<AuthManager>,
103102
hooks: Vec<Arc<dyn QueryHook>>,
104103
) -> Result<(), std::io::Error> {
105104
// Create the handler factory with authentication
106-
let factory = Arc::new(HandlerFactory::new_with_hooks(
107-
session_context,
108-
auth_manager,
109-
hooks,
110-
));
105+
let factory = Arc::new(HandlerFactory::new_with_hooks(session_context, hooks));
111106

112107
serve_with_handlers(factory, opts).await
113108
}

0 commit comments

Comments
 (0)