Skip to content

Commit 37b2f72

Browse files
committed
refactor: use AST pattern matching instead of string matching for statement detection
Fixes #159 The codebase was using starts_with() string checks to detect statement types (SET, SHOW, INSERT, etc.) which is fragile - it can break with leading whitespace, comments, or case variations. Since we already have the parsed Statement AST from sqlparser, this switches to using pattern matching directly on the Statement enum. Changes in permissions.rs: - Renamed check_query_permission to check_statement_permission, now takes &Statement instead of &str - Permission detection (SELECT/INSERT/UPDATE/DELETE/CREATE/DROP/ALTER) now uses match on Statement variants - Added should_skip_permission_check() helper using matches!() macro for SET, SHOW, and transaction statements - Removed all the to_lowercase().starts_with() chains Changes in handlers.rs: - INSERT detection now uses matches!(statement, Statement::Insert(_)) - Removed unnecessary query_lower variable construction - Fixed extended query handler to properly destructure statement from the portal tuple All 12 unit tests pass. The existing integration test failures (dbeaver, metabase, psql) are unrelated - they fail due to missing DataFusion functions like array_length and array_contains.
1 parent 2b51166 commit 37b2f72

File tree

2 files changed

+54
-60
lines changed

2 files changed

+54
-60
lines changed

datafusion-postgres/src/handlers.rs

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,7 @@ impl SimpleQueryHandler for DfSessionService {
139139

140140
let mut results = vec![];
141141
'stmt: for statement in statements {
142-
// TODO: improve statement check by using statement directly
143142
let query = statement.to_string();
144-
let query_lower = query.to_lowercase().trim().to_string();
145143

146144
// Call query hooks with the parsed statement
147145
for hook in &self.query_hooks {
@@ -179,7 +177,7 @@ impl SimpleQueryHandler for DfSessionService {
179177
}
180178
};
181179

182-
if query_lower.starts_with("insert into") {
180+
if matches!(statement, sqlparser::ast::Statement::Insert(_)) {
183181
let resp = map_rows_affected_for_insert(&df).await?;
184182
results.push(resp);
185183
} else {
@@ -265,13 +263,7 @@ impl ExtendedQueryHandler for DfSessionService {
265263
where
266264
C: ClientInfo + Unpin + Send + Sync,
267265
{
268-
let query = portal
269-
.statement
270-
.statement
271-
.0
272-
.to_lowercase()
273-
.trim()
274-
.to_string();
266+
let query = portal.statement.statement.0.to_string();
275267
log::debug!("Received execute extended query: {query}"); // Log for debugging
276268

277269
// Check query hooks first
@@ -302,7 +294,7 @@ impl ExtendedQueryHandler for DfSessionService {
302294
}
303295
}
304296

305-
if let (_, Some((_, plan))) = &portal.statement.statement {
297+
if let (_, Some((statement, plan))) = &portal.statement.statement {
306298
let param_types = plan
307299
.get_parameter_types()
308300
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
@@ -345,7 +337,7 @@ impl ExtendedQueryHandler for DfSessionService {
345337
}
346338
};
347339

348-
if query.starts_with("insert into") {
340+
if matches!(statement, sqlparser::ast::Statement::Insert(_)) {
349341
let resp = map_rows_affected_for_insert(&dataframe).await?;
350342

351343
Ok(resp)

datafusion-postgres/src/hooks/permissions.rs

Lines changed: 50 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,12 @@ impl PermissionsHook {
2323
PermissionsHook { auth_manager }
2424
}
2525

26-
/// Check if the current user has permission to execute a query
27-
async fn check_query_permission<C>(&self, client: &C, query: &str) -> PgWireResult<()>
26+
/// Check if the current user has permission to execute a statement
27+
async fn check_statement_permission<C>(
28+
&self,
29+
client: &C,
30+
statement: &Statement,
31+
) -> PgWireResult<()>
2832
where
2933
C: ClientInfo + ?Sized,
3034
{
@@ -35,29 +39,19 @@ impl PermissionsHook {
3539
.map(|s| s.as_str())
3640
.unwrap_or("anonymous");
3741

38-
// Parse query to determine required permissions
39-
let query_lower = query.to_lowercase();
40-
let query_trimmed = query_lower.trim();
41-
42-
let (required_permission, resource) = if query_trimmed.starts_with("select") {
43-
(Permission::Select, ResourceType::All)
44-
} else if query_trimmed.starts_with("insert") {
45-
(Permission::Insert, ResourceType::All)
46-
} else if query_trimmed.starts_with("update") {
47-
(Permission::Update, ResourceType::All)
48-
} else if query_trimmed.starts_with("delete") {
49-
(Permission::Delete, ResourceType::All)
50-
} else if query_trimmed.starts_with("create table")
51-
|| query_trimmed.starts_with("create view")
52-
{
53-
(Permission::Create, ResourceType::All)
54-
} else if query_trimmed.starts_with("drop") {
55-
(Permission::Drop, ResourceType::All)
56-
} else if query_trimmed.starts_with("alter") {
57-
(Permission::Alter, ResourceType::All)
58-
} else {
59-
// For other queries (SHOW, EXPLAIN, etc.), allow all users
60-
return Ok(());
42+
// Determine required permissions based on Statement type
43+
let (required_permission, resource) = match statement {
44+
Statement::Query(_) => (Permission::Select, ResourceType::All),
45+
Statement::Insert(_) => (Permission::Insert, ResourceType::All),
46+
Statement::Update { .. } => (Permission::Update, ResourceType::All),
47+
Statement::Delete(_) => (Permission::Delete, ResourceType::All),
48+
Statement::CreateTable { .. } | Statement::CreateView { .. } => {
49+
(Permission::Create, ResourceType::All)
50+
}
51+
Statement::Drop { .. } => (Permission::Drop, ResourceType::All),
52+
Statement::AlterTable { .. } => (Permission::Alter, ResourceType::All),
53+
// For other statements (SET, SHOW, EXPLAIN, transactions, etc.), allow all users
54+
_ => return Ok(()),
6155
};
6256

6357
// Check permission
@@ -78,6 +72,21 @@ impl PermissionsHook {
7872

7973
Ok(())
8074
}
75+
76+
/// Check if a statement should skip permission checks
77+
fn should_skip_permission_check(statement: &Statement) -> bool {
78+
matches!(
79+
statement,
80+
Statement::Set { .. }
81+
| Statement::ShowVariable { .. }
82+
| Statement::ShowStatus { .. }
83+
| Statement::StartTransaction { .. }
84+
| Statement::Commit { .. }
85+
| Statement::Rollback { .. }
86+
| Statement::Savepoint { .. }
87+
| Statement::ReleaseSavepoint { .. }
88+
)
89+
}
8190
}
8291

8392
#[async_trait]
@@ -89,22 +98,14 @@ impl QueryHook for PermissionsHook {
8998
_session_context: &SessionContext,
9099
client: &mut (dyn ClientInfo + Send + Sync),
91100
) -> Option<PgWireResult<Response>> {
92-
let query_lower = statement.to_string().to_lowercase();
93-
94-
// Check permissions for the query (skip for SET, transaction, and SHOW statements)
95-
if !query_lower.starts_with("set")
96-
&& !query_lower.starts_with("begin")
97-
&& !query_lower.starts_with("commit")
98-
&& !query_lower.starts_with("rollback")
99-
&& !query_lower.starts_with("start")
100-
&& !query_lower.starts_with("end")
101-
&& !query_lower.starts_with("abort")
102-
&& !query_lower.starts_with("show")
103-
{
104-
let res = self.check_query_permission(&*client, &query_lower).await;
105-
if let Err(e) = res {
106-
return Some(Err(e));
107-
}
101+
// Skip permission checks for SET, SHOW, and transaction statements
102+
if Self::should_skip_permission_check(statement) {
103+
return None;
104+
}
105+
106+
// Check permissions for other statements
107+
if let Err(e) = self.check_statement_permission(&*client, statement).await {
108+
return Some(Err(e));
108109
}
109110

110111
None
@@ -127,15 +128,16 @@ impl QueryHook for PermissionsHook {
127128
_session_context: &SessionContext,
128129
client: &mut (dyn ClientInfo + Send + Sync),
129130
) -> Option<PgWireResult<Response>> {
130-
let query = statement.to_string().to_lowercase();
131+
// Skip permission checks for SET and SHOW statements
132+
if Self::should_skip_permission_check(statement) {
133+
return None;
134+
}
131135

132-
// Check permissions for the query (skip for SET and SHOW statements)
133-
if !query.starts_with("set") && !query.starts_with("show") {
134-
let res = self.check_query_permission(&*client, &query).await;
135-
if let Err(e) = res {
136-
return Some(Err(e));
137-
}
136+
// Check permissions for other statements
137+
if let Err(e) = self.check_statement_permission(&*client, statement).await {
138+
return Some(Err(e));
138139
}
140+
139141
None
140142
}
141143
}

0 commit comments

Comments
 (0)