Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/driver/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,16 @@ use crate::{
exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError},
format_helpers::quote_ident,
query_result::{PSQLDriverPyQueryResult, PSQLDriverSinglePyQueryResult},
runtime::{rustdriver_future, tokio_runtime},
runtime::tokio_runtime,
};

use super::{
common_options::{LoadBalanceHosts, SslMode, TargetSessionAttrs},
connection_pool::{connect_pool, ConnectionPool},
connection_pool::connect_pool,
cursor::Cursor,
inner_connection::PsqlpyConnection,
transaction::Transaction,
transaction_options::{IsolationLevel, ReadVariant, SynchronousCommit},
utils::build_connection_config,
};

/// Make new connection pool.
Expand Down
28 changes: 26 additions & 2 deletions src/driver/inner_connection.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use bytes::Buf;
use deadpool_postgres::Object;
use deadpool_postgres::{Object, Transaction};
use postgres_types::{ToSql, Type};
use pyo3::{Py, PyAny, Python};
use pyo3::{pyclass, Py, PyAny, Python};
use std::vec;
use tokio_postgres::{Client, CopyInSink, Row, Statement, ToStatement};

Expand All @@ -18,6 +18,11 @@ pub enum PsqlpyConnection {
SingleConn(Client),
}

// #[pyclass]
// struct Portal {
// trans: Transaction<'static>,
// }

impl PsqlpyConnection {
/// Prepare cached statement.
///
Expand All @@ -38,6 +43,25 @@ impl PsqlpyConnection {
}
}

// pub async fn transaction(&mut self) -> Portal {
// match self {
// PsqlpyConnection::PoolConn(pconn, _) => {
// let b = unsafe {
// std::mem::transmute::<Transaction<'_>, Transaction<'static>>(pconn.transaction().await.unwrap())
// };
// Portal {trans: b}
// // let c = b.bind("SELECT 1", &[]).await.unwrap();
// // b.query_portal(&c, 1).await;
// }
// PsqlpyConnection::SingleConn(sconn) => {
// let b = unsafe {
// std::mem::transmute::<Transaction<'_>, Transaction<'static>>(sconn.transaction().await.unwrap())
// };
// Portal {trans: b}
// },
// }
// }

/// Delete prepared statement.
///
/// # Errors
Expand Down
94 changes: 94 additions & 0 deletions src/driver/inner_transaction.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
use deadpool_postgres::Transaction as dp_Transaction;
use postgres_types::ToSql;
use tokio_postgres::{Portal, Row, ToStatement, Transaction as tp_Transaction};

use crate::exceptions::rust_errors::PSQLPyResult;

pub enum PsqlpyTransaction {
PoolTrans(dp_Transaction<'static>),
SingleConnTrans(tp_Transaction<'static>),
}

impl PsqlpyTransaction {
async fn commit(self) -> PSQLPyResult<()> {
match self {
PsqlpyTransaction::PoolTrans(p_txid) => Ok(p_txid.commit().await?),
PsqlpyTransaction::SingleConnTrans(s_txid) => Ok(s_txid.commit().await?),
}
}

async fn rollback(self) -> PSQLPyResult<()> {
match self {
PsqlpyTransaction::PoolTrans(p_txid) => Ok(p_txid.rollback().await?),
PsqlpyTransaction::SingleConnTrans(s_txid) => Ok(s_txid.rollback().await?),
}
}

async fn savepoint(&mut self, sp_name: &str) -> PSQLPyResult<()> {
match self {
PsqlpyTransaction::PoolTrans(p_txid) => {
p_txid.savepoint(sp_name).await?;
Ok(())
}
PsqlpyTransaction::SingleConnTrans(s_txid) => {
s_txid.savepoint(sp_name).await?;
Ok(())
}
}
}

async fn release_savepoint(&self, sp_name: &str) -> PSQLPyResult<()> {
match self {
PsqlpyTransaction::PoolTrans(p_txid) => {
p_txid
.batch_execute(format!("RELEASE SAVEPOINT {sp_name}").as_str())
.await?;
Ok(())
}
PsqlpyTransaction::SingleConnTrans(s_txid) => {
s_txid
.batch_execute(format!("RELEASE SAVEPOINT {sp_name}").as_str())
.await?;
Ok(())
}
}
}

async fn rollback_savepoint(&self, sp_name: &str) -> PSQLPyResult<()> {
match self {
PsqlpyTransaction::PoolTrans(p_txid) => {
p_txid
.batch_execute(format!("ROLLBACK TO SAVEPOINT {sp_name}").as_str())
.await?;
Ok(())
}
PsqlpyTransaction::SingleConnTrans(s_txid) => {
s_txid
.batch_execute(format!("ROLLBACK TO SAVEPOINT {sp_name}").as_str())
.await?;
Ok(())
}
}
}

async fn bind<T>(&self, statement: &T, params: &[&(dyn ToSql + Sync)]) -> PSQLPyResult<Portal>
where
T: ?Sized + ToStatement,
{
match self {
PsqlpyTransaction::PoolTrans(p_txid) => Ok(p_txid.bind(statement, params).await?),
PsqlpyTransaction::SingleConnTrans(s_txid) => {
Ok(s_txid.bind(statement, params).await?)
}
}
}

pub async fn query_portal(&self, portal: &Portal, size: i32) -> PSQLPyResult<Vec<Row>> {
match self {
PsqlpyTransaction::PoolTrans(p_txid) => Ok(p_txid.query_portal(portal, size).await?),
PsqlpyTransaction::SingleConnTrans(s_txid) => {
Ok(s_txid.query_portal(portal, size).await?)
}
}
}
}
2 changes: 2 additions & 0 deletions src/driver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ pub mod connection_pool;
pub mod connection_pool_builder;
pub mod cursor;
pub mod inner_connection;
pub mod inner_transaction;
pub mod listener;
pub mod portal;
pub mod transaction;
pub mod transaction_options;
pub mod utils;
52 changes: 52 additions & 0 deletions src/driver/portal.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use std::sync::Arc;

use pyo3::{pyclass, pymethods};
use tokio_postgres::Portal as tp_Portal;

use crate::{exceptions::rust_errors::PSQLPyResult, query_result::PSQLDriverPyQueryResult};

use super::inner_transaction::PsqlpyTransaction;

#[pyclass]
struct Portal {
transaction: Arc<PsqlpyTransaction>,
inner: tp_Portal,
array_size: i32,
}

impl Portal {
async fn query_portal(&self, size: i32) -> PSQLPyResult<PSQLDriverPyQueryResult> {
let result = self.transaction.query_portal(&self.inner, size).await?;
Ok(PSQLDriverPyQueryResult::new(result))
}
}

#[pymethods]
impl Portal {
#[getter]
fn get_array_size(&self) -> i32 {
self.array_size
}

#[setter]
fn set_array_size(&mut self, value: i32) {
self.array_size = value;
}

async fn fetch_one(&self) -> PSQLPyResult<PSQLDriverPyQueryResult> {
self.query_portal(1).await
}

#[pyo3(signature = (size=None))]
async fn fetch_many(&self, size: Option<i32>) -> PSQLPyResult<PSQLDriverPyQueryResult> {
self.query_portal(size.unwrap_or(self.array_size)).await
}

async fn fetch_all(&self) -> PSQLPyResult<PSQLDriverPyQueryResult> {
self.query_portal(-1).await
}

async fn close(&mut self) {
let _ = Arc::downgrade(&self.transaction);
}
}
Loading