diff --git a/src/driver/connection.rs b/src/driver/connection.rs index ded325a2..2210e303 100644 --- a/src/driver/connection.rs +++ b/src/driver/connection.rs @@ -9,17 +9,16 @@ use crate::{ exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, format_helpers::quote_ident, query_result::{PSQLDriverPyQueryResult, PSQLDriverSinglePyQueryResult}, - runtime::{rustdriver_future, tokio_runtime}, + runtime::tokio_runtime, }; use super::{ common_options::{LoadBalanceHosts, SslMode, TargetSessionAttrs}, - connection_pool::{connect_pool, ConnectionPool}, + connection_pool::connect_pool, cursor::Cursor, inner_connection::PsqlpyConnection, transaction::Transaction, transaction_options::{IsolationLevel, ReadVariant, SynchronousCommit}, - utils::build_connection_config, }; /// Make new connection pool. diff --git a/src/driver/inner_connection.rs b/src/driver/inner_connection.rs index d8acc4d8..bb591de7 100644 --- a/src/driver/inner_connection.rs +++ b/src/driver/inner_connection.rs @@ -1,7 +1,7 @@ use bytes::Buf; -use deadpool_postgres::Object; +use deadpool_postgres::{Object, Transaction}; use postgres_types::{ToSql, Type}; -use pyo3::{Py, PyAny, Python}; +use pyo3::{pyclass, Py, PyAny, Python}; use std::vec; use tokio_postgres::{Client, CopyInSink, Row, Statement, ToStatement}; @@ -18,6 +18,11 @@ pub enum PsqlpyConnection { SingleConn(Client), } +// #[pyclass] +// struct Portal { +// trans: Transaction<'static>, +// } + impl PsqlpyConnection { /// Prepare cached statement. /// @@ -38,6 +43,25 @@ impl PsqlpyConnection { } } + // pub async fn transaction(&mut self) -> Portal { + // match self { + // PsqlpyConnection::PoolConn(pconn, _) => { + // let b = unsafe { + // std::mem::transmute::, Transaction<'static>>(pconn.transaction().await.unwrap()) + // }; + // Portal {trans: b} + // // let c = b.bind("SELECT 1", &[]).await.unwrap(); + // // b.query_portal(&c, 1).await; + // } + // PsqlpyConnection::SingleConn(sconn) => { + // let b = unsafe { + // std::mem::transmute::, Transaction<'static>>(sconn.transaction().await.unwrap()) + // }; + // Portal {trans: b} + // }, + // } + // } + /// Delete prepared statement. /// /// # Errors diff --git a/src/driver/inner_transaction.rs b/src/driver/inner_transaction.rs new file mode 100644 index 00000000..a23f0536 --- /dev/null +++ b/src/driver/inner_transaction.rs @@ -0,0 +1,94 @@ +use deadpool_postgres::Transaction as dp_Transaction; +use postgres_types::ToSql; +use tokio_postgres::{Portal, Row, ToStatement, Transaction as tp_Transaction}; + +use crate::exceptions::rust_errors::PSQLPyResult; + +pub enum PsqlpyTransaction { + PoolTrans(dp_Transaction<'static>), + SingleConnTrans(tp_Transaction<'static>), +} + +impl PsqlpyTransaction { + async fn commit(self) -> PSQLPyResult<()> { + match self { + PsqlpyTransaction::PoolTrans(p_txid) => Ok(p_txid.commit().await?), + PsqlpyTransaction::SingleConnTrans(s_txid) => Ok(s_txid.commit().await?), + } + } + + async fn rollback(self) -> PSQLPyResult<()> { + match self { + PsqlpyTransaction::PoolTrans(p_txid) => Ok(p_txid.rollback().await?), + PsqlpyTransaction::SingleConnTrans(s_txid) => Ok(s_txid.rollback().await?), + } + } + + async fn savepoint(&mut self, sp_name: &str) -> PSQLPyResult<()> { + match self { + PsqlpyTransaction::PoolTrans(p_txid) => { + p_txid.savepoint(sp_name).await?; + Ok(()) + } + PsqlpyTransaction::SingleConnTrans(s_txid) => { + s_txid.savepoint(sp_name).await?; + Ok(()) + } + } + } + + async fn release_savepoint(&self, sp_name: &str) -> PSQLPyResult<()> { + match self { + PsqlpyTransaction::PoolTrans(p_txid) => { + p_txid + .batch_execute(format!("RELEASE SAVEPOINT {sp_name}").as_str()) + .await?; + Ok(()) + } + PsqlpyTransaction::SingleConnTrans(s_txid) => { + s_txid + .batch_execute(format!("RELEASE SAVEPOINT {sp_name}").as_str()) + .await?; + Ok(()) + } + } + } + + async fn rollback_savepoint(&self, sp_name: &str) -> PSQLPyResult<()> { + match self { + PsqlpyTransaction::PoolTrans(p_txid) => { + p_txid + .batch_execute(format!("ROLLBACK TO SAVEPOINT {sp_name}").as_str()) + .await?; + Ok(()) + } + PsqlpyTransaction::SingleConnTrans(s_txid) => { + s_txid + .batch_execute(format!("ROLLBACK TO SAVEPOINT {sp_name}").as_str()) + .await?; + Ok(()) + } + } + } + + async fn bind(&self, statement: &T, params: &[&(dyn ToSql + Sync)]) -> PSQLPyResult + where + T: ?Sized + ToStatement, + { + match self { + PsqlpyTransaction::PoolTrans(p_txid) => Ok(p_txid.bind(statement, params).await?), + PsqlpyTransaction::SingleConnTrans(s_txid) => { + Ok(s_txid.bind(statement, params).await?) + } + } + } + + pub async fn query_portal(&self, portal: &Portal, size: i32) -> PSQLPyResult> { + match self { + PsqlpyTransaction::PoolTrans(p_txid) => Ok(p_txid.query_portal(portal, size).await?), + PsqlpyTransaction::SingleConnTrans(s_txid) => { + Ok(s_txid.query_portal(portal, size).await?) + } + } + } +} diff --git a/src/driver/mod.rs b/src/driver/mod.rs index e7827cd5..416bfa97 100644 --- a/src/driver/mod.rs +++ b/src/driver/mod.rs @@ -4,7 +4,9 @@ pub mod connection_pool; pub mod connection_pool_builder; pub mod cursor; pub mod inner_connection; +pub mod inner_transaction; pub mod listener; +pub mod portal; pub mod transaction; pub mod transaction_options; pub mod utils; diff --git a/src/driver/portal.rs b/src/driver/portal.rs new file mode 100644 index 00000000..d90138b0 --- /dev/null +++ b/src/driver/portal.rs @@ -0,0 +1,52 @@ +use std::sync::Arc; + +use pyo3::{pyclass, pymethods}; +use tokio_postgres::Portal as tp_Portal; + +use crate::{exceptions::rust_errors::PSQLPyResult, query_result::PSQLDriverPyQueryResult}; + +use super::inner_transaction::PsqlpyTransaction; + +#[pyclass] +struct Portal { + transaction: Arc, + inner: tp_Portal, + array_size: i32, +} + +impl Portal { + async fn query_portal(&self, size: i32) -> PSQLPyResult { + let result = self.transaction.query_portal(&self.inner, size).await?; + Ok(PSQLDriverPyQueryResult::new(result)) + } +} + +#[pymethods] +impl Portal { + #[getter] + fn get_array_size(&self) -> i32 { + self.array_size + } + + #[setter] + fn set_array_size(&mut self, value: i32) { + self.array_size = value; + } + + async fn fetch_one(&self) -> PSQLPyResult { + self.query_portal(1).await + } + + #[pyo3(signature = (size=None))] + async fn fetch_many(&self, size: Option) -> PSQLPyResult { + self.query_portal(size.unwrap_or(self.array_size)).await + } + + async fn fetch_all(&self) -> PSQLPyResult { + self.query_portal(-1).await + } + + async fn close(&mut self) { + let _ = Arc::downgrade(&self.transaction); + } +}