Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 41 additions & 14 deletions src-tauri/src/drivers/mysql/explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,22 @@ pub async fn explain_query(
get_mysql_pool(params).await?
};

// Behind a bastion that rejects prepared statements, EXPLAIN variants must
// run over the text protocol (COM_QUERY) — see `super::force_text_protocol`.
let text = super::force_text_protocol(params);

// Detect server version to skip unsupported EXPLAIN variants
let caps = {
let mut vc = pool.acquire().await.map_err(|e| e.to_string())?;
let ver_row = sqlx::query("SELECT VERSION()")
.fetch_one(&mut *vc)
.await
.ok();
let ver_row = if text {
use sqlx::Executor;
(&mut *vc)
.fetch_one(sqlx::raw_sql("SELECT VERSION()"))
.await
} else {
sqlx::query("SELECT VERSION()").fetch_one(&mut *vc).await
}
.ok();
let ver_str: String = ver_row.and_then(|r| r.try_get(0).ok()).unwrap_or_default();
log::debug!("MySQL/MariaDB version: {}", ver_str);
parse_mysql_version(&ver_str)
Expand All @@ -77,7 +86,13 @@ pub async fn explain_query(
if analyze && caps.supports_explain_analyze {
let mut conn = pool.acquire().await.map_err(|e| e.to_string())?;
let analyze_sql = format!("EXPLAIN ANALYZE {}", query);
if let Ok(rows) = sqlx::query(&analyze_sql).fetch_all(&mut *conn).await {
let analyze_res = if text {
use sqlx::Executor;
(&mut *conn).fetch_all(sqlx::raw_sql(&analyze_sql)).await
} else {
sqlx::query(&analyze_sql).fetch_all(&mut *conn).await
};
if let Ok(rows) = analyze_res {
let mut lines = Vec::new();
for row in &rows {
if let Ok(line) = row.try_get::<String, _>(0) {
Expand Down Expand Up @@ -108,7 +123,13 @@ pub async fn explain_query(
if analyze && caps.supports_analyze_format {
let mut conn = pool.acquire().await.map_err(|e| e.to_string())?;
let maria_sql = format!("ANALYZE FORMAT=JSON {}", query);
if let Ok(row) = sqlx::query(&maria_sql).fetch_one(&mut *conn).await {
let maria_res = if text {
use sqlx::Executor;
(&mut *conn).fetch_one(sqlx::raw_sql(&maria_sql)).await
} else {
sqlx::query(&maria_sql).fetch_one(&mut *conn).await
};
if let Ok(row) = maria_res {
if let Ok(raw_json) = row.try_get::<String, _>(0) {
if let Ok(json_val) = serde_json::from_str::<serde_json::Value>(&raw_json) {
if let Some(query_block) = json_val.get("query_block") {
Expand Down Expand Up @@ -142,10 +163,13 @@ pub async fn explain_query(
let mut conn = pool.acquire().await.map_err(|e| e.to_string())?;
let json_sql = format!("EXPLAIN FORMAT=JSON {}", query);
let json_result: Result<String, String> = async {
let row = sqlx::query(&json_sql)
.fetch_one(&mut *conn)
.await
.map_err(|e| e.to_string())?;
let row = if text {
use sqlx::Executor;
(&mut *conn).fetch_one(sqlx::raw_sql(&json_sql)).await
} else {
sqlx::query(&json_sql).fetch_one(&mut *conn).await
}
.map_err(|e| e.to_string())?;
row.try_get::<String, _>(0).map_err(|e| e.to_string())
}
.await;
Expand Down Expand Up @@ -174,10 +198,13 @@ pub async fn explain_query(
// Tabular fallback — works on all MySQL/MariaDB versions
let mut conn = pool.acquire().await.map_err(|e| e.to_string())?;
let explain_sql = format!("EXPLAIN {}", query);
let rows = sqlx::query(&explain_sql)
.fetch_all(&mut *conn)
.await
.map_err(|e| e.to_string())?;
let rows = if text {
use sqlx::Executor;
(&mut *conn).fetch_all(sqlx::raw_sql(&explain_sql)).await
} else {
sqlx::query(&explain_sql).fetch_all(&mut *conn).await
}
.map_err(|e| e.to_string())?;

let (root, raw) = parse_mysql_tabular_explain(&rows);
Ok(ExplainPlan {
Expand Down
8 changes: 7 additions & 1 deletion src-tauri/src/drivers/mysql/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@ where
F: FnMut(&[String], &[Value]) -> Result<(), String> + Send,
{
let pool = get_mysql_pool(params).await?;
let mut rows = sqlx::query(query).fetch(&pool);
// Behind a bastion that rejects prepared statements, stream over the text
// protocol (COM_QUERY) instead — see `super::force_text_protocol`.
let mut rows = if super::force_text_protocol(params) {
sqlx::raw_sql(query).fetch(&pool)
} else {
sqlx::query(query).fetch(&pool)
};
let mut headers: Option<Vec<String>> = None;

while let Some(row_res) = rows.next().await {
Expand Down
92 changes: 92 additions & 0 deletions src-tauri/src/drivers/mysql/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,98 @@ pub(super) fn escape_identifier(name: &str) -> String {
name.replace('`', "``")
}

/// Renders a `&str` as a quoted MySQL string literal for the text protocol.
///
/// Used when a query has to bypass the prepared-statement protocol (e.g.
/// behind a Warpgate-style bastion that rejects `COM_STMT_PREPARE`): the
/// value can no longer travel as a bind parameter, so it is inlined as an
/// escaped literal instead.
///
/// The escaping depends on the server's `sql_mode`: when `NO_BACKSLASH_ESCAPES`
/// is set (ANSI mode, some bastion targets) the backslash is an ordinary
/// character, so a value like `o\'brien` must close the quote by doubling it
/// (`''`) rather than `\'` — otherwise the literal is mis-parsed and user cell
/// values become an injection vector. Quote doubling is also valid in the
/// default mode, but backslash escaping is not portable, so callers must pass
/// the actual server setting via `no_backslash_escapes`.
pub(super) fn mysql_string_literal(s: &str, no_backslash_escapes: bool) -> String {
let mut out = String::with_capacity(s.len() + 2);
out.push('\'');
if no_backslash_escapes {
// Backslash is literal here; the single quote is the only metacharacter
// inside the literal and is escaped by doubling it. Everything else
// (including control bytes and backslashes) is emitted verbatim.
for ch in s.chars() {
if ch == '\'' {
out.push_str("''");
} else {
out.push(ch);
}
}
} else {
// Default mode: mirror `mysql_real_escape_string` (backslash escapes on).
for ch in s.chars() {
match ch {
'\0' => out.push_str("\\0"),
'\n' => out.push_str("\\n"),
'\r' => out.push_str("\\r"),
'\\' => out.push_str("\\\\"),
'\'' => out.push_str("\\'"),
'"' => out.push_str("\\\""),
'\u{1a}' => out.push_str("\\Z"),
c => out.push(c),
}
}
}
out.push('\'');
out
}

/// Renders raw bytes as a MySQL hexadecimal literal (`x'..'`) for the text
/// protocol — the inlined equivalent of binding a `Vec<u8>` blob parameter.
pub(super) fn mysql_bytes_literal(bytes: &[u8]) -> String {
use std::fmt::Write;
let mut out = String::with_capacity(bytes.len() * 2 + 3);
out.push_str("x'");
for b in bytes {
let _ = write!(out, "{:02x}", b);
}
out.push('\'');
out
}

/// Substitutes each `?` placeholder in `sql` with the next quoted string
/// literal from `binds`, in order. Used to turn a parameterised
/// introspection query into a text-protocol statement. Placeholders past
/// the end of `binds` (and `?` chars when `binds` is empty) are left as-is.
/// `no_backslash_escapes` is forwarded to [`mysql_string_literal`] so the
/// literals match the server's `sql_mode`.
///
/// # Safety
///
/// This treats every `?` as a bind placeholder, so it is only sound for the
/// driver's own hand-written introspection queries (whose `?` chars are
/// exclusively placeholders). It must never be used to render arbitrary user
/// SQL, where a `?` could appear inside a string literal.
pub(super) fn inline_str_placeholders(
sql: &str,
binds: &[&str],
no_backslash_escapes: bool,
) -> String {
let mut out = String::with_capacity(sql.len());
let mut iter = binds.iter();
for ch in sql.chars() {
if ch == '?' {
if let Some(b) = iter.next() {
out.push_str(&mysql_string_literal(b, no_backslash_escapes));
continue;
}
}
out.push(ch);
}
out
}

/// Read a string from a MySQL row by index.
/// MySQL 8 information_schema returns VARBINARY/BLOB instead of VARCHAR,
/// so try_get::<String> fails silently. This falls back to reading raw bytes.
Expand Down
Loading