Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 52 additions & 17 deletions src/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,19 +104,38 @@ impl PoolBuilder {
/// ```
pub async fn open(self) -> Result<Pool, Error> {
let num_conns = self.get_num_conns();
let opens = (0..num_conns).map(|_| {

// Open the first connection with full config (including journal_mode).
// This must complete before opening remaining connections to avoid
// concurrent PRAGMA writes on a new database file.
let first = ClientBuilder {
path: self.path.clone(),
flags: self.flags,
journal_mode: self.journal_mode,
vfs: self.vfs.clone(),
}
.open()
.await?;

// Open remaining connections without journal_mode since it's a
// database-level setting already applied by the first connection.
let opens = (1..num_conns).map(|_| {
ClientBuilder {
path: self.path.clone(),
flags: self.flags,
journal_mode: self.journal_mode,
journal_mode: None,
vfs: self.vfs.clone(),
}
.open()
});
let clients = join_all(opens)
.await
.into_iter()
.collect::<Result<Vec<Client>, Error>>()?;
let mut clients = vec![first];
clients.extend(
join_all(opens)
.await
.into_iter()
.collect::<Result<Vec<Client>, Error>>()?,
);

Ok(Pool {
state: Arc::new(State {
clients,
Expand All @@ -139,17 +158,33 @@ impl PoolBuilder {
/// ```
pub fn open_blocking(self) -> Result<Pool, Error> {
let num_conns = self.get_num_conns();
let clients = (0..num_conns)
.map(|_| {
ClientBuilder {
path: self.path.clone(),
flags: self.flags,
journal_mode: self.journal_mode,
vfs: self.vfs.clone(),
}
.open_blocking()
})
.collect::<Result<Vec<Client>, Error>>()?;

// Open the first connection with full config (including journal_mode).
let first = ClientBuilder {
path: self.path.clone(),
flags: self.flags,
journal_mode: self.journal_mode,
vfs: self.vfs.clone(),
}
.open_blocking()?;

// Open remaining connections without journal_mode since it's a
// database-level setting already applied by the first connection.
let mut clients = vec![first];
clients.extend(
(1..num_conns)
.map(|_| {
ClientBuilder {
path: self.path.clone(),
flags: self.flags,
journal_mode: None,
vfs: self.vfs.clone(),
}
.open_blocking()
})
.collect::<Result<Vec<Client>, Error>>()?,
);

Ok(Pool {
state: Arc::new(State {
clients,
Expand Down
23 changes: 23 additions & 0 deletions tests/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ macro_rules! async_test {
async_test!(test_journal_mode);
async_test!(test_concurrency);
async_test!(test_pool);
async_test!(test_pool_journal_mode);
async_test!(test_pool_conn_for_each);
async_test!(test_pool_close_concurrent);
async_test!(test_pool_num_conns_zero_clamps);
Expand Down Expand Up @@ -170,6 +171,28 @@ async fn test_pool() {
.expect("collecting query results");
}

async fn test_pool_journal_mode() {
let tmp_dir = tempfile::tempdir().unwrap();
let pool = PoolBuilder::new()
.journal_mode(JournalMode::Wal)
.path(tmp_dir.path().join("sqlite.db"))
.num_conns(4)
.open()
.await
.expect("pool unable to be opened");

// Verify all connections see WAL journal mode.
let results = pool
.conn_for_each(|conn| conn.query_row("PRAGMA journal_mode", (), |row| row.get(0)))
.await;
for result in results {
let mode: String = result.unwrap();
assert_eq!(mode, "wal");
}

pool.close().await.expect("closing pool");
}

async fn test_pool_conn_for_each() {
// make dummy db
let tmp_dir = tempfile::tempdir().unwrap();
Expand Down