diff --git a/src/pool.rs b/src/pool.rs index 2192c8b..c985a89 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -104,19 +104,38 @@ impl PoolBuilder { /// ``` pub async fn open(self) -> Result { let num_conns = self.get_num_conns(); - let opens = (0..num_conns).map(|_| { + + // Open the first connection with full config (including journal_mode). + // This must complete before opening remaining connections to avoid + // concurrent PRAGMA writes on a new database file. + let first = ClientBuilder { + path: self.path.clone(), + flags: self.flags, + journal_mode: self.journal_mode, + vfs: self.vfs.clone(), + } + .open() + .await?; + + // Open remaining connections without journal_mode since it's a + // database-level setting already applied by the first connection. + let opens = (1..num_conns).map(|_| { ClientBuilder { path: self.path.clone(), flags: self.flags, - journal_mode: self.journal_mode, + journal_mode: None, vfs: self.vfs.clone(), } .open() }); - let clients = join_all(opens) - .await - .into_iter() - .collect::, Error>>()?; + let mut clients = vec![first]; + clients.extend( + join_all(opens) + .await + .into_iter() + .collect::, Error>>()?, + ); + Ok(Pool { state: Arc::new(State { clients, @@ -139,17 +158,33 @@ impl PoolBuilder { /// ``` pub fn open_blocking(self) -> Result { let num_conns = self.get_num_conns(); - let clients = (0..num_conns) - .map(|_| { - ClientBuilder { - path: self.path.clone(), - flags: self.flags, - journal_mode: self.journal_mode, - vfs: self.vfs.clone(), - } - .open_blocking() - }) - .collect::, Error>>()?; + + // Open the first connection with full config (including journal_mode). + let first = ClientBuilder { + path: self.path.clone(), + flags: self.flags, + journal_mode: self.journal_mode, + vfs: self.vfs.clone(), + } + .open_blocking()?; + + // Open remaining connections without journal_mode since it's a + // database-level setting already applied by the first connection. + let mut clients = vec![first]; + clients.extend( + (1..num_conns) + .map(|_| { + ClientBuilder { + path: self.path.clone(), + flags: self.flags, + journal_mode: None, + vfs: self.vfs.clone(), + } + .open_blocking() + }) + .collect::, Error>>()?, + ); + Ok(Pool { state: Arc::new(State { clients, diff --git a/tests/tests.rs b/tests/tests.rs index ad15a20..eb22604 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -83,6 +83,7 @@ macro_rules! async_test { async_test!(test_journal_mode); async_test!(test_concurrency); async_test!(test_pool); +async_test!(test_pool_journal_mode); async_test!(test_pool_conn_for_each); async_test!(test_pool_close_concurrent); async_test!(test_pool_num_conns_zero_clamps); @@ -170,6 +171,28 @@ async fn test_pool() { .expect("collecting query results"); } +async fn test_pool_journal_mode() { + let tmp_dir = tempfile::tempdir().unwrap(); + let pool = PoolBuilder::new() + .journal_mode(JournalMode::Wal) + .path(tmp_dir.path().join("sqlite.db")) + .num_conns(4) + .open() + .await + .expect("pool unable to be opened"); + + // Verify all connections see WAL journal mode. + let results = pool + .conn_for_each(|conn| conn.query_row("PRAGMA journal_mode", (), |row| row.get(0))) + .await; + for result in results { + let mode: String = result.unwrap(); + assert_eq!(mode, "wal"); + } + + pool.close().await.expect("closing pool"); +} + async fn test_pool_conn_for_each() { // make dummy db let tmp_dir = tempfile::tempdir().unwrap();