Skip to content

Commit 1ffad82

Browse files
committed
fix rt dropping before hrana close can be sent
1 parent e75e3ee commit 1ffad82

File tree

1 file changed

+32
-40
lines changed

1 file changed

+32
-40
lines changed

src/lib.rs

Lines changed: 32 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,25 @@ use pyo3::create_exception;
33
use pyo3::exceptions::PyValueError;
44
use pyo3::prelude::*;
55
use pyo3::types::{PyList, PyTuple};
6-
use std::cell::RefCell;
6+
use std::cell::{OnceCell, RefCell};
77
use std::sync::Arc;
8+
use tokio::runtime::{Handle, Runtime};
89

910
const LEGACY_TRANSACTION_CONTROL: i32 = -1;
1011

12+
fn rt() -> Handle {
13+
const RT: OnceCell<Runtime> = OnceCell::new();
14+
15+
RT.get_or_init(|| {
16+
tokio::runtime::Builder::new_multi_thread()
17+
.worker_threads(1)
18+
.build()
19+
.unwrap()
20+
})
21+
.handle()
22+
.clone()
23+
}
24+
1125
fn to_py_err(error: libsql_core::errors::Error) -> PyErr {
1226
let msg = match error {
1327
libsql::Error::SqliteFailure(_, err) => err,
@@ -99,7 +113,7 @@ fn _connect_core(
99113
) -> PyResult<Connection> {
100114
let ver = env!("CARGO_PKG_VERSION");
101115
let ver = format!("libsql-python-rpc-{ver}");
102-
let rt = tokio::runtime::Runtime::new().unwrap();
116+
let rt = rt();
103117
let encryption_config = match encryption_key {
104118
Some(key) => {
105119
let cipher = libsql::Cipher::default();
@@ -147,9 +161,8 @@ fn _connect_core(
147161
db,
148162
conn: Arc::new(ConnectionGuard {
149163
conn: Some(conn),
150-
handle: rt.handle().clone(),
164+
handle: rt.clone(),
151165
}),
152-
rt,
153166
isolation_level,
154167
autocommit,
155168
})
@@ -186,7 +199,6 @@ impl Drop for ConnectionGuard {
186199
pub struct Connection {
187200
db: libsql_core::Database,
188201
conn: Arc<ConnectionGuard>,
189-
rt: tokio::runtime::Runtime,
190202
isolation_level: Option<String>,
191203
autocommit: i32,
192204
}
@@ -199,7 +211,6 @@ impl Connection {
199211
fn cursor(&self) -> PyResult<Cursor> {
200212
Ok(Cursor {
201213
arraysize: 1,
202-
rt: self.rt.handle().clone(),
203214
conn: self.conn.clone(),
204215
stmt: RefCell::new(None),
205216
rows: RefCell::new(None),
@@ -212,24 +223,19 @@ impl Connection {
212223

213224
fn sync(self_: PyRef<'_, Self>, py: Python<'_>) -> PyResult<()> {
214225
let fut = {
215-
let _enter = self_.rt.enter();
226+
let _enter = rt().enter();
216227
self_.db.sync()
217228
};
218229
tokio::pin!(fut);
219230

220-
self_
221-
.rt
222-
.block_on(check_signals(py, fut))
223-
.map_err(to_py_err)?;
231+
rt().block_on(check_signals(py, fut)).map_err(to_py_err)?;
224232
Ok(())
225233
}
226234

227235
fn commit(self_: PyRef<'_, Self>) -> PyResult<()> {
228236
// TODO: Switch to libSQL transaction API
229237
if !self_.conn.is_autocommit() {
230-
self_
231-
.rt
232-
.block_on(async { self_.conn.execute("COMMIT", ()).await })
238+
rt().block_on(async { self_.conn.execute("COMMIT", ()).await })
233239
.map_err(to_py_err)?;
234240
}
235241
Ok(())
@@ -238,9 +244,7 @@ impl Connection {
238244
fn rollback(self_: PyRef<'_, Self>) -> PyResult<()> {
239245
// TODO: Switch to libSQL transaction API
240246
if !self_.conn.is_autocommit() {
241-
self_
242-
.rt
243-
.block_on(async { self_.conn.execute("ROLLBACK", ()).await })
247+
rt().block_on(async { self_.conn.execute("ROLLBACK", ()).await })
244248
.map_err(to_py_err)?;
245249
}
246250
Ok(())
@@ -252,8 +256,7 @@ impl Connection {
252256
parameters: Option<&PyTuple>,
253257
) -> PyResult<Cursor> {
254258
let cursor = Connection::cursor(&self_)?;
255-
let rt = self_.rt.handle();
256-
rt.block_on(async { execute(&cursor, sql, parameters).await })?;
259+
rt().block_on(async { execute(&cursor, sql, parameters).await })?;
257260
Ok(cursor)
258261
}
259262

@@ -265,17 +268,15 @@ impl Connection {
265268
let cursor = Connection::cursor(&self_)?;
266269
for parameters in parameters.unwrap().iter() {
267270
let parameters = parameters.extract::<&PyTuple>()?;
268-
self_
269-
.rt
270-
.block_on(async { execute(&cursor, sql.clone(), Some(parameters)).await })?;
271+
rt().block_on(async { execute(&cursor, sql.clone(), Some(parameters)).await })?;
271272
}
272273
Ok(cursor)
273274
}
274275

275276
fn executescript(self_: PyRef<'_, Self>, script: String) -> PyResult<()> {
276-
let _ = self_.rt.block_on(async {
277-
self_.conn.execute_batch(&script).await
278-
}).map_err(to_py_err);
277+
let _ = rt()
278+
.block_on(async { self_.conn.execute_batch(&script).await })
279+
.map_err(to_py_err);
279280
Ok(())
280281
}
281282

@@ -316,7 +317,6 @@ impl Connection {
316317
pub struct Cursor {
317318
#[pyo3(get, set)]
318319
arraysize: usize,
319-
rt: tokio::runtime::Handle,
320320
conn: Arc<ConnectionGuard>,
321321
stmt: RefCell<Option<libsql_core::Statement>>,
322322
rows: RefCell<Option<libsql_core::Rows>>,
@@ -336,9 +336,7 @@ impl Cursor {
336336
sql: String,
337337
parameters: Option<&PyTuple>,
338338
) -> PyResult<pyo3::PyRef<'a, Self>> {
339-
self_
340-
.rt
341-
.block_on(async { execute(&self_, sql, parameters).await })?;
339+
rt().block_on(async { execute(&self_, sql, parameters).await })?;
342340
Ok(self_)
343341
}
344342

@@ -349,9 +347,7 @@ impl Cursor {
349347
) -> PyResult<pyo3::PyRef<'a, Cursor>> {
350348
for parameters in parameters.unwrap().iter() {
351349
let parameters = parameters.extract::<&PyTuple>()?;
352-
self_
353-
.rt
354-
.block_on(async { execute(&self_, sql.clone(), Some(parameters)).await })?;
350+
rt().block_on(async { execute(&self_, sql.clone(), Some(parameters)).await })?;
355351
}
356352
Ok(self_)
357353
}
@@ -360,9 +356,7 @@ impl Cursor {
360356
self_: PyRef<'a, Self>,
361357
script: String,
362358
) -> PyResult<pyo3::PyRef<'a, Self>> {
363-
self_
364-
.rt
365-
.block_on(async { self_.conn.execute_batch(&script).await })
359+
rt().block_on(async { self_.conn.execute_batch(&script).await })
366360
.map_err(to_py_err)?;
367361
Ok(self_)
368362
}
@@ -398,7 +392,7 @@ impl Cursor {
398392
let mut rows = self_.rows.borrow_mut();
399393
match rows.as_mut() {
400394
Some(rows) => {
401-
let row = self_.rt.block_on(rows.next()).map_err(to_py_err)?;
395+
let row = rt().block_on(rows.next()).map_err(to_py_err)?;
402396
match row {
403397
Some(row) => {
404398
let row = convert_row(self_.py(), row, rows.column_count())?;
@@ -422,8 +416,7 @@ impl Cursor {
422416
// done before iterating.
423417
if !*self_.done.borrow() {
424418
for _ in 0..size {
425-
let row = self_
426-
.rt
419+
let row = rt()
427420
.block_on(async { rows.next().await })
428421
.map_err(to_py_err)?;
429422
match row {
@@ -450,8 +443,7 @@ impl Cursor {
450443
Some(rows) => {
451444
let mut elements: Vec<Py<PyAny>> = vec![];
452445
loop {
453-
let row = self_
454-
.rt
446+
let row = rt()
455447
.block_on(async { rows.next().await })
456448
.map_err(to_py_err)?;
457449
match row {

0 commit comments

Comments
 (0)