diff --git a/plugins/sql/api-iife.js b/plugins/sql/api-iife.js index a30f68d909..24b50d105a 100644 --- a/plugins/sql/api-iife.js +++ b/plugins/sql/api-iife.js @@ -1 +1 @@ -if("__TAURI__"in window){var __TAURI_PLUGIN_SQL__=function(){"use strict";async function e(e,t={},s){return window.__TAURI_INTERNALS__.invoke(e,t,s)}"function"==typeof SuppressedError&&SuppressedError;class t{constructor(e){this.path=e}static async load(s){const n=await e("plugin:sql|load",{db:s});return new t(n)}static get(e){return new t(e)}async execute(t,s){const[n,r]=await e("plugin:sql|execute",{db:this.path,query:t,values:s??[]});return{lastInsertId:r,rowsAffected:n}}async select(t,s){return await e("plugin:sql|select",{db:this.path,query:t,values:s??[]})}async close(t){return await e("plugin:sql|close",{db:t})}}return t}();Object.defineProperty(window.__TAURI__,"sql",{value:__TAURI_PLUGIN_SQL__})} +if("__TAURI__"in window){var __TAURI_PLUGIN_SQL__=function(){"use strict";async function e(e,t={},n){return window.__TAURI_INTERNALS__.invoke(e,t,n)}"function"==typeof SuppressedError&&SuppressedError;class t{constructor(e){this.path=e}static async load(n){const r=await e("plugin:sql|load",{db:n});return new t(r)}static get(e){return new t(e)}async execute(t,r){const[s,a]=await e("plugin:sql|execute",{db:this.path,query:t,values:r??[],bindTypes:n(r)});return{lastInsertId:a,rowsAffected:s}}async select(t,r){return await e("plugin:sql|select",{db:this.path,query:t,values:r??[],bindTypes:n(r)})}async close(t){return await e("plugin:sql|close",{db:t})}}const n=e=>{const t=new Map;if(!e)return t;for(const[n,r]of e.entries())(r instanceof Uint8Array||r instanceof ArrayBuffer)&&t.set(n,"bytearray");return t};return t}();Object.defineProperty(window.__TAURI__,"sql",{value:__TAURI_PLUGIN_SQL__})} diff --git a/plugins/sql/guest-js/index.ts b/plugins/sql/guest-js/index.ts index 11d39e70b4..5b27035db3 100644 --- a/plugins/sql/guest-js/index.ts +++ b/plugins/sql/guest-js/index.ts @@ -111,7 +111,8 @@ export default class Database { { db: this.path, query, - values: bindValues ?? [] + values: bindValues ?? [], + bindTypes: buildBindTypes(bindValues) } ) return { @@ -142,7 +143,8 @@ export default class Database { const result = await invoke('plugin:sql|select', { db: this.path, query, - values: bindValues ?? [] + values: bindValues ?? [], + bindTypes: buildBindTypes(bindValues) }) return result @@ -166,3 +168,14 @@ export default class Database { return success } } + +const buildBindTypes = (values?: unknown[]): Map => { + const bindTypes: Map = new Map() + if (!values) return bindTypes + for (const [index, v] of values.entries()) { + if (v instanceof Uint8Array || v instanceof ArrayBuffer) { + bindTypes.set(index, 'bytearray') + } + } + return bindTypes +} diff --git a/plugins/sql/src/commands.rs b/plugins/sql/src/commands.rs index 760d00b2d2..14443a959c 100644 --- a/plugins/sql/src/commands.rs +++ b/plugins/sql/src/commands.rs @@ -5,6 +5,7 @@ use indexmap::IndexMap; use serde_json::Value as JsonValue; use sqlx::migrate::Migrator; +use std::collections::HashMap; use tauri::{command, AppHandle, Runtime, State}; use crate::{DbInstances, DbPool, Error, LastInsertId, Migrations}; @@ -59,11 +60,12 @@ pub(crate) async fn execute( db: String, query: String, values: Vec, + bind_types: HashMap, ) -> Result<(u64, LastInsertId), crate::Error> { let instances = db_instances.0.read().await; let db = instances.get(&db).ok_or(Error::DatabaseNotLoaded(db))?; - db.execute(query, values).await + db.execute(query, values, bind_types).await } #[command] @@ -72,9 +74,10 @@ pub(crate) async fn select( db: String, query: String, values: Vec, + bind_types: HashMap, ) -> Result>, crate::Error> { let instances = db_instances.0.read().await; let db = instances.get(&db).ok_or(Error::DatabaseNotLoaded(db))?; - db.select(query, values).await + db.select(query, values, bind_types).await } diff --git a/plugins/sql/src/wrapper.rs b/plugins/sql/src/wrapper.rs index d47b2d1cbe..6ca4e3a4a2 100644 --- a/plugins/sql/src/wrapper.rs +++ b/plugins/sql/src/wrapper.rs @@ -2,6 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: MIT +use std::collections::HashMap; #[cfg(feature = "sqlite")] use std::fs::create_dir_all; @@ -147,13 +148,22 @@ impl DbPool { &self, _query: String, _values: Vec, + _bind_types: HashMap, ) -> Result<(u64, LastInsertId), crate::Error> { Ok(match self { #[cfg(feature = "sqlite")] DbPool::Sqlite(pool) => { let mut query = sqlx::query(&_query); - for value in _values { - if value.is_null() { + for (i, value) in _values.iter().enumerate() { + if _bind_types.get(&i) == Some(&"bytearray".to_string()) { + let bytes = value + .as_array() + .unwrap() + .iter() + .map(|v| v.as_u64().unwrap() as u8) + .collect::>(); + query = query.bind(bytes); + } else if value.is_null() { query = query.bind(None::); } else if value.is_string() { query = query.bind(value.as_str().unwrap().to_owned()) @@ -215,13 +225,22 @@ impl DbPool { &self, _query: String, _values: Vec, + _bind_types: HashMap, ) -> Result>, crate::Error> { Ok(match self { #[cfg(feature = "sqlite")] DbPool::Sqlite(pool) => { let mut query = sqlx::query(&_query); - for value in _values { - if value.is_null() { + for (i, value) in _values.iter().enumerate() { + if _bind_types.get(&i) == Some(&"bytearray".to_string()) { + let bytes = value + .as_array() + .unwrap() + .iter() + .map(|v| v.as_u64().unwrap() as u8) + .collect::>(); + query = query.bind(bytes); + } else if value.is_null() { query = query.bind(None::); } else if value.is_string() { query = query.bind(value.as_str().unwrap().to_owned())