Skip to content

Commit 8b40f28

Browse files
committed
Optimize executescript() to use batching
Refs #70
1 parent b15302e commit 8b40f28

File tree

2 files changed

+26
-10
lines changed

2 files changed

+26
-10
lines changed

src/lib.rs

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -205,16 +205,9 @@ impl Connection {
205205
}
206206

207207
fn executescript(self_: PyRef<'_, Self>, script: String) -> PyResult<()> {
208-
let statements = script.split(';');
209-
for statement in statements {
210-
let statement = statement.trim();
211-
if !statement.is_empty() {
212-
let cursor = Connection::cursor(&self_)?;
213-
self_
214-
.rt
215-
.block_on(async { execute(&cursor, statement.to_string(), None).await })?;
216-
}
217-
}
208+
let _ = self_.rt.block_on(async {
209+
self_.conn.execute_batch(&script).await
210+
}).map_err(to_py_err);
218211
Ok(())
219212
}
220213

@@ -272,6 +265,16 @@ impl Cursor {
272265
Ok(self_)
273266
}
274267

268+
fn executescript<'a>(self_: PyRef<'a, Self>, script: String) -> PyResult<pyo3::PyRef<'a, Self>> {
269+
self_
270+
.rt
271+
.block_on(async {
272+
self_.conn.execute_batch(&script).await
273+
})
274+
.map_err(to_py_err)?;
275+
Ok(self_)
276+
}
277+
275278
#[getter]
276279
fn description(self_: PyRef<'_, Self>) -> PyResult<Option<&PyTuple>> {
277280
let stmt = self_.stmt.borrow();

tests/test_suite.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,19 @@ def test_cursor_executemany(provider):
8888
res = cur.execute("SELECT * FROM users")
8989
assert [(1, 'alice@example.com'), (2, 'bob@example.com')] == res.fetchall()
9090

91+
@pytest.mark.parametrize("provider", ["libsql", "sqlite"])
92+
def test_cursor_executescript(provider):
93+
conn = connect(provider, ":memory:")
94+
cur = conn.cursor()
95+
cur.executescript("""
96+
CREATE TABLE users (id INTEGER, email TEXT);
97+
INSERT INTO users VALUES (1, 'alice@example.org');
98+
INSERT INTO users VALUES (2, 'bob@example.org');
99+
""")
100+
res = cur.execute("SELECT * FROM users")
101+
assert (1, 'alice@example.org') == res.fetchone()
102+
assert (2, 'bob@example.org') == res.fetchone()
103+
91104
@pytest.mark.parametrize("provider", ["libsql", "sqlite"])
92105
def test_lastrowid(provider):
93106
conn = connect(provider, ":memory:")

0 commit comments

Comments
 (0)