From d29bd6bd3f242b6cffb3c3a2d3b49f33228039e8 Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Tue, 29 Apr 2025 22:48:56 +0200 Subject: [PATCH 01/15] Started implementing new traits and realizations for convertions from Python to Rust types --- src/driver/inner_connection.rs | 2 +- src/extra_types.rs | 23 +- src/value_converter/additional_types.rs | 2 + src/value_converter/dto/converter_impls.rs | 199 ++++++++++++++++++ src/value_converter/dto/enums.rs | 82 ++++++++ .../{models/dto.rs => dto/impls.rs} | 165 ++------------- src/value_converter/dto/mod.rs | 3 + src/value_converter/funcs/from_python.rs | 112 +++------- src/value_converter/mod.rs | 2 + src/value_converter/models/mod.rs | 1 - src/value_converter/models/serde_value.rs | 149 ++++++++++--- src/value_converter/traits.rs | 9 + 12 files changed, 488 insertions(+), 261 deletions(-) create mode 100644 src/value_converter/dto/converter_impls.rs create mode 100644 src/value_converter/dto/enums.rs rename src/value_converter/{models/dto.rs => dto/impls.rs} (74%) create mode 100644 src/value_converter/dto/mod.rs create mode 100644 src/value_converter/traits.rs diff --git a/src/driver/inner_connection.rs b/src/driver/inner_connection.rs index ae060baa..2dfbcbb7 100644 --- a/src/driver/inner_connection.rs +++ b/src/driver/inner_connection.rs @@ -10,8 +10,8 @@ use crate::{ query_result::{PSQLDriverPyQueryResult, PSQLDriverSinglePyQueryResult}, value_converter::{ consts::QueryParameter, + dto::enums::PythonDTO, funcs::{from_python::convert_parameters_and_qs, to_python::postgres_to_py}, - models::dto::PythonDTO, }, }; diff --git a/src/extra_types.rs b/src/extra_types.rs index 1e8d22b4..48058ff4 100644 --- a/src/extra_types.rs +++ b/src/extra_types.rs @@ -13,13 +13,19 @@ use crate::{ exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, value_converter::{ additional_types::{Circle as RustCircle, Line as RustLine}, + dto::enums::PythonDTO, funcs::from_python::{ build_flat_geo_coords, build_geo_coords, py_sequence_into_postgres_array, }, - models::{dto::PythonDTO, serde_value::build_serde_value}, + models::serde_value::build_serde_value, }, }; +pub struct PythonArray; +pub struct PythonDecimal; +pub struct PythonUUID; +pub struct PythonEnum; + #[pyclass] #[derive(Clone)] pub struct PgVector(Vec); @@ -34,7 +40,7 @@ impl PgVector { impl PgVector { #[must_use] - pub fn inner_value(self) -> Vec { + pub fn inner(self) -> Vec { self.0 } } @@ -49,7 +55,7 @@ macro_rules! build_python_type { impl $st_name { #[must_use] - pub fn retrieve_value(&self) -> $rust_type { + pub fn inner(&self) -> $rust_type { self.inner_value } } @@ -135,7 +141,12 @@ macro_rules! build_json_py_type { impl $st_name { #[must_use] - pub fn inner(&self) -> &$rust_type { + pub fn inner(&self) -> $rust_type { + self.inner.clone() + } + + #[must_use] + pub fn inner_ref(&self) -> &$rust_type { &self.inner } } @@ -144,7 +155,7 @@ macro_rules! build_json_py_type { impl $st_name { #[new] #[allow(clippy::missing_errors_doc)] - pub fn new_class(value: Py) -> RustPSQLDriverPyResult { + pub fn new_class(value: &Bound<'_, PyAny>) -> RustPSQLDriverPyResult { Ok(Self { inner: build_serde_value(value)?, }) @@ -223,7 +234,7 @@ macro_rules! build_geo_type { impl $st_name { #[must_use] - pub fn retrieve_value(&self) -> $rust_type { + pub fn inner(&self) -> $rust_type { self.inner.clone() } } diff --git a/src/value_converter/additional_types.rs b/src/value_converter/additional_types.rs index 5dd435a0..1159939a 100644 --- a/src/value_converter/additional_types.rs +++ b/src/value_converter/additional_types.rs @@ -13,6 +13,8 @@ use pyo3::{ use serde::{Deserialize, Serialize}; use tokio_postgres::types::{FromSql, Type}; +pub struct NonePyType; + macro_rules! build_additional_rust_type { ($st_name:ident, $rust_type:ty) => { #[derive(Debug)] diff --git a/src/value_converter/dto/converter_impls.rs b/src/value_converter/dto/converter_impls.rs new file mode 100644 index 00000000..97675af3 --- /dev/null +++ b/src/value_converter/dto/converter_impls.rs @@ -0,0 +1,199 @@ +use std::net::IpAddr; + +use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime}; +use pg_interval::Interval; +use pyo3::{ + types::{PyAnyMethods, PyDateTime, PyDelta, PyDict}, + Bound, PyAny, +}; +use rust_decimal::Decimal; +use uuid::Uuid; + +use crate::{ + exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, + extra_types::{self, PythonDecimal, PythonUUID}, + value_converter::{ + additional_types::NonePyType, + funcs::from_python::{ + extract_datetime_from_python_object_attrs, py_sequence_into_postgres_array, + }, + models::serde_value::build_serde_value, + traits::PythonToDTO, + }, +}; + +use super::enums::PythonDTO; + +impl PythonToDTO for NonePyType { + fn to_python_dto(_python_param: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult { + Ok(PythonDTO::PyNone) + } +} + +macro_rules! construct_simple_type_matcher { + ($match_type:ty, $kind:path) => { + impl PythonToDTO for $match_type { + fn to_python_dto(python_param: &Bound<'_, PyAny>) -> RustPSQLDriverPyResult { + Ok($kind(python_param.extract::<$match_type>()?)) + } + } + }; +} + +construct_simple_type_matcher!(bool, PythonDTO::PyBool); +construct_simple_type_matcher!(Vec, PythonDTO::PyBytes); +construct_simple_type_matcher!(String, PythonDTO::PyString); +construct_simple_type_matcher!(f32, PythonDTO::PyFloat32); +construct_simple_type_matcher!(f64, PythonDTO::PyFloat64); +construct_simple_type_matcher!(i16, PythonDTO::PyIntI16); +construct_simple_type_matcher!(i32, PythonDTO::PyIntI32); +construct_simple_type_matcher!(i64, PythonDTO::PyIntI64); +construct_simple_type_matcher!(NaiveDate, PythonDTO::PyDate); +construct_simple_type_matcher!(NaiveTime, PythonDTO::PyTime); + +impl PythonToDTO for PyDateTime { + fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult { + let timestamp_tz = python_param.extract::>(); + if let Ok(pydatetime_tz) = timestamp_tz { + return Ok(PythonDTO::PyDateTimeTz(pydatetime_tz)); + } + + let timestamp_no_tz = python_param.extract::(); + if let Ok(pydatetime_no_tz) = timestamp_no_tz { + return Ok(PythonDTO::PyDateTime(pydatetime_no_tz)); + } + + let timestamp_tz = extract_datetime_from_python_object_attrs(python_param); + if let Ok(pydatetime_tz) = timestamp_tz { + return Ok(PythonDTO::PyDateTimeTz(pydatetime_tz)); + } + + return Err(RustPSQLDriverError::PyToRustValueConversionError( + "Can not convert you datetime to rust type".into(), + )); + } +} + +impl PythonToDTO for PyDelta { + fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult { + let duration = python_param.extract::()?; + if let Some(interval) = Interval::from_duration(duration) { + return Ok(PythonDTO::PyInterval(interval)); + } + return Err(RustPSQLDriverError::PyToRustValueConversionError( + "Cannot convert timedelta from Python to inner Rust type.".to_string(), + )); + } +} + +impl PythonToDTO for PyDict { + fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult { + let serde_value = build_serde_value(python_param)?; + + return Ok(PythonDTO::PyJsonb(serde_value)); + } +} + +macro_rules! construct_extra_type_matcher { + ($match_type:ty, $kind:path) => { + impl PythonToDTO for $match_type { + fn to_python_dto(python_param: &Bound<'_, PyAny>) -> RustPSQLDriverPyResult { + Ok($kind(python_param.extract::<$match_type>()?.inner())) + } + } + }; +} + +construct_extra_type_matcher!(extra_types::JSONB, PythonDTO::PyJsonb); +construct_extra_type_matcher!(extra_types::JSON, PythonDTO::PyJson); +construct_extra_type_matcher!(extra_types::MacAddr6, PythonDTO::PyMacAddr6); +construct_extra_type_matcher!(extra_types::MacAddr8, PythonDTO::PyMacAddr8); +construct_extra_type_matcher!(extra_types::Point, PythonDTO::PyPoint); +construct_extra_type_matcher!(extra_types::Box, PythonDTO::PyBox); +construct_extra_type_matcher!(extra_types::Path, PythonDTO::PyPath); +construct_extra_type_matcher!(extra_types::Line, PythonDTO::PyLine); +construct_extra_type_matcher!(extra_types::LineSegment, PythonDTO::PyLineSegment); +construct_extra_type_matcher!(extra_types::Circle, PythonDTO::PyCircle); + +impl PythonToDTO for PythonDecimal { + fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult { + Ok(PythonDTO::PyDecimal(Decimal::from_str_exact( + python_param.str()?.extract::<&str>()?, + )?)) + } +} + +impl PythonToDTO for PythonUUID { + fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult { + Ok(PythonDTO::PyUUID(Uuid::parse_str( + python_param.str()?.extract::<&str>()?, + )?)) + } +} + +impl PythonToDTO for extra_types::PythonArray { + fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult { + Ok(PythonDTO::PyArray(py_sequence_into_postgres_array( + python_param, + )?)) + } +} + +impl PythonToDTO for IpAddr { + fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult { + if let Ok(id_address) = python_param.extract::() { + return Ok(PythonDTO::PyIpAddress(id_address)); + } + + Err(RustPSQLDriverError::PyToRustValueConversionError( + "Parameter passed to IpAddr is incorrect.".to_string(), + )) + } +} + +impl PythonToDTO for extra_types::PythonEnum { + fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult { + let string = python_param.extract::()?; + return Ok(PythonDTO::PyString(string)); + } +} + +macro_rules! construct_array_type_matcher { + ($match_type:ty) => { + impl PythonToDTO for $match_type { + fn to_python_dto(python_param: &Bound<'_, PyAny>) -> RustPSQLDriverPyResult { + python_param + .extract::<$match_type>()? + ._convert_to_python_dto() + } + } + }; +} + +construct_array_type_matcher!(extra_types::BoolArray); +construct_array_type_matcher!(extra_types::UUIDArray); +construct_array_type_matcher!(extra_types::VarCharArray); +construct_array_type_matcher!(extra_types::TextArray); +construct_array_type_matcher!(extra_types::Int16Array); +construct_array_type_matcher!(extra_types::Int32Array); +construct_array_type_matcher!(extra_types::Int64Array); +construct_array_type_matcher!(extra_types::Float32Array); +construct_array_type_matcher!(extra_types::Float64Array); +construct_array_type_matcher!(extra_types::MoneyArray); +construct_array_type_matcher!(extra_types::IpAddressArray); +construct_array_type_matcher!(extra_types::JSONBArray); +construct_array_type_matcher!(extra_types::JSONArray); +construct_array_type_matcher!(extra_types::DateArray); +construct_array_type_matcher!(extra_types::TimeArray); +construct_array_type_matcher!(extra_types::DateTimeArray); +construct_array_type_matcher!(extra_types::DateTimeTZArray); +construct_array_type_matcher!(extra_types::MacAddr6Array); +construct_array_type_matcher!(extra_types::MacAddr8Array); +construct_array_type_matcher!(extra_types::NumericArray); +construct_array_type_matcher!(extra_types::PointArray); +construct_array_type_matcher!(extra_types::BoxArray); +construct_array_type_matcher!(extra_types::PathArray); +construct_array_type_matcher!(extra_types::LineArray); +construct_array_type_matcher!(extra_types::LsegArray); +construct_array_type_matcher!(extra_types::CircleArray); +construct_array_type_matcher!(extra_types::IntervalArray); diff --git a/src/value_converter/dto/enums.rs b/src/value_converter/dto/enums.rs new file mode 100644 index 00000000..00e88a10 --- /dev/null +++ b/src/value_converter/dto/enums.rs @@ -0,0 +1,82 @@ +use chrono::{self, DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime}; +use geo_types::{Line as LineSegment, LineString, Point, Rect}; +use macaddr::{MacAddr6, MacAddr8}; +use pg_interval::Interval; +use rust_decimal::Decimal; +use serde_json::Value; +use std::{fmt::Debug, net::IpAddr}; +use uuid::Uuid; + +use crate::value_converter::additional_types::{Circle, Line}; +use postgres_array::array::Array; + +#[derive(Debug, Clone, PartialEq)] +pub enum PythonDTO { + // Primitive + PyNone, + PyBytes(Vec), + PyBool(bool), + PyUUID(Uuid), + PyVarChar(String), + PyText(String), + PyString(String), + PyIntI16(i16), + PyIntI32(i32), + PyIntI64(i64), + PyIntU32(u32), + PyIntU64(u64), + PyFloat32(f32), + PyFloat64(f64), + PyMoney(i64), + PyDate(NaiveDate), + PyTime(NaiveTime), + PyDateTime(NaiveDateTime), + PyDateTimeTz(DateTime), + PyInterval(Interval), + PyIpAddress(IpAddr), + PyList(Vec), + PyArray(Array), + PyTuple(Vec), + PyJsonb(Value), + PyJson(Value), + PyMacAddr6(MacAddr6), + PyMacAddr8(MacAddr8), + PyDecimal(Decimal), + PyCustomType(Vec), + PyPoint(Point), + PyBox(Rect), + PyPath(LineString), + PyLine(Line), + PyLineSegment(LineSegment), + PyCircle(Circle), + // Arrays + PyBoolArray(Array), + PyUuidArray(Array), + PyVarCharArray(Array), + PyTextArray(Array), + PyInt16Array(Array), + PyInt32Array(Array), + PyInt64Array(Array), + PyFloat32Array(Array), + PyFloat64Array(Array), + PyMoneyArray(Array), + PyIpAddressArray(Array), + PyJSONBArray(Array), + PyJSONArray(Array), + PyDateArray(Array), + PyTimeArray(Array), + PyDateTimeArray(Array), + PyDateTimeTZArray(Array), + PyMacAddr6Array(Array), + PyMacAddr8Array(Array), + PyNumericArray(Array), + PyPointArray(Array), + PyBoxArray(Array), + PyPathArray(Array), + PyLineArray(Array), + PyLsegArray(Array), + PyCircleArray(Array), + PyIntervalArray(Array), + // PgVector + PyPgVector(Vec), +} diff --git a/src/value_converter/models/dto.rs b/src/value_converter/dto/impls.rs similarity index 74% rename from src/value_converter/models/dto.rs rename to src/value_converter/dto/impls.rs index 8609a600..b634d8b8 100644 --- a/src/value_converter/models/dto.rs +++ b/src/value_converter/dto/impls.rs @@ -1,111 +1,44 @@ use chrono::{self, DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime}; -use geo_types::{Line as LineSegment, LineString, Point, Rect}; -use macaddr::{MacAddr6, MacAddr8}; use pg_interval::Interval; use postgres_types::ToSql; use rust_decimal::Decimal; use serde_json::{json, Value}; -use std::{fmt::Debug, net::IpAddr}; +use std::net::IpAddr; use uuid::Uuid; use bytes::{BufMut, BytesMut}; use postgres_protocol::types; -use pyo3::{PyObject, Python, ToPyObject}; +use pyo3::{Bound, IntoPyObject, PyAny, Python}; use tokio_postgres::types::{to_sql_checked, Type}; use crate::{ exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, - value_converter::additional_types::{ - Circle, Line, RustLineSegment, RustLineString, RustPoint, RustRect, + value_converter::{ + additional_types::{Circle, Line, RustLineSegment, RustLineString, RustPoint, RustRect}, + models::serde_value::pythondto_array_to_serde, }, }; use pgvector::Vector as PgVector; -use postgres_array::{array::Array, Dimension}; -#[derive(Debug, Clone, PartialEq)] -pub enum PythonDTO { - // Primitive - PyNone, - PyBytes(Vec), - PyBool(bool), - PyUUID(Uuid), - PyVarChar(String), - PyText(String), - PyString(String), - PyIntI16(i16), - PyIntI32(i32), - PyIntI64(i64), - PyIntU32(u32), - PyIntU64(u64), - PyFloat32(f32), - PyFloat64(f64), - PyMoney(i64), - PyDate(NaiveDate), - PyTime(NaiveTime), - PyDateTime(NaiveDateTime), - PyDateTimeTz(DateTime), - PyInterval(Interval), - PyIpAddress(IpAddr), - PyList(Vec), - PyArray(Array), - PyTuple(Vec), - PyJsonb(Value), - PyJson(Value), - PyMacAddr6(MacAddr6), - PyMacAddr8(MacAddr8), - PyDecimal(Decimal), - PyCustomType(Vec), - PyPoint(Point), - PyBox(Rect), - PyPath(LineString), - PyLine(Line), - PyLineSegment(LineSegment), - PyCircle(Circle), - // Arrays - PyBoolArray(Array), - PyUuidArray(Array), - PyVarCharArray(Array), - PyTextArray(Array), - PyInt16Array(Array), - PyInt32Array(Array), - PyInt64Array(Array), - PyFloat32Array(Array), - PyFloat64Array(Array), - PyMoneyArray(Array), - PyIpAddressArray(Array), - PyJSONBArray(Array), - PyJSONArray(Array), - PyDateArray(Array), - PyTimeArray(Array), - PyDateTimeArray(Array), - PyDateTimeTZArray(Array), - PyMacAddr6Array(Array), - PyMacAddr8Array(Array), - PyNumericArray(Array), - PyPointArray(Array), - PyBoxArray(Array), - PyPathArray(Array), - PyLineArray(Array), - PyLsegArray(Array), - PyCircleArray(Array), - PyIntervalArray(Array), - // PgVector - PyPgVector(Vec), -} +use super::enums::PythonDTO; + +impl<'py> IntoPyObject<'py> for PythonDTO { + type Target = PyAny; + type Output = Bound<'py, Self::Target>; + type Error = std::convert::Infallible; -impl ToPyObject for PythonDTO { - fn to_object(&self, py: Python<'_>) -> PyObject { + fn into_pyobject(self, py: Python<'py>) -> Result { match self { - PythonDTO::PyNone => py.None(), - PythonDTO::PyBool(pybool) => pybool.to_object(py), + PythonDTO::PyNone => Ok(py.None().into_bound(py)), + PythonDTO::PyBool(pybool) => Ok(pybool.into_pyobject(py)?.to_owned().into_any()), PythonDTO::PyString(py_string) | PythonDTO::PyText(py_string) - | PythonDTO::PyVarChar(py_string) => py_string.to_object(py), - PythonDTO::PyIntI32(pyint) => pyint.to_object(py), - PythonDTO::PyIntI64(pyint) => pyint.to_object(py), - PythonDTO::PyIntU64(pyint) => pyint.to_object(py), - PythonDTO::PyFloat32(pyfloat) => pyfloat.to_object(py), - PythonDTO::PyFloat64(pyfloat) => pyfloat.to_object(py), + | PythonDTO::PyVarChar(py_string) => Ok(py_string.into_pyobject(py)?.into_any()), + PythonDTO::PyIntI32(pyint) => Ok(pyint.into_pyobject(py)?.into_any()), + PythonDTO::PyIntI64(pyint) => Ok(pyint.into_pyobject(py)?.into_any()), + PythonDTO::PyIntU64(pyint) => Ok(pyint.into_pyobject(py)?.into_any()), + PythonDTO::PyFloat32(pyfloat) => Ok(pyfloat.into_pyobject(py)?.into_any()), + PythonDTO::PyFloat64(pyfloat) => Ok(pyfloat.into_pyobject(py)?.into_any()), _ => unreachable!(), } } @@ -431,61 +364,3 @@ impl ToSql for PythonDTO { to_sql_checked!(); } - -/// Convert Array of `PythonDTO`s to serde `Value`. -/// -/// It can convert multidimensional arrays. -fn pythondto_array_to_serde(array: Option>) -> RustPSQLDriverPyResult { - match array { - Some(array) => inner_pythondto_array_to_serde( - array.dimensions(), - array.iter().collect::>().as_slice(), - 0, - 0, - ), - None => Ok(Value::Null), - } -} - -/// Inner conversion array of `PythonDTO`s to serde `Value`. -#[allow(clippy::cast_sign_loss)] -fn inner_pythondto_array_to_serde( - dimensions: &[Dimension], - data: &[&PythonDTO], - dimension_index: usize, - mut lower_bound: usize, -) -> RustPSQLDriverPyResult { - let current_dimension = dimensions.get(dimension_index); - - if let Some(current_dimension) = current_dimension { - let possible_next_dimension = dimensions.get(dimension_index + 1); - match possible_next_dimension { - Some(next_dimension) => { - let mut final_list: Value = Value::Array(vec![]); - - for _ in 0..current_dimension.len as usize { - if dimensions.get(dimension_index + 1).is_some() { - let inner_pylist = inner_pythondto_array_to_serde( - dimensions, - &data[lower_bound..next_dimension.len as usize + lower_bound], - dimension_index + 1, - 0, - )?; - match final_list { - Value::Array(ref mut array) => array.push(inner_pylist), - _ => unreachable!(), - } - lower_bound += next_dimension.len as usize; - }; - } - - return Ok(final_list); - } - None => { - return data.iter().map(|x| x.to_serde_value()).collect(); - } - } - } - - Ok(Value::Array(vec![])) -} diff --git a/src/value_converter/dto/mod.rs b/src/value_converter/dto/mod.rs new file mode 100644 index 00000000..5be9ae5b --- /dev/null +++ b/src/value_converter/dto/mod.rs @@ -0,0 +1,3 @@ +pub mod converter_impls; +pub mod enums; +pub mod impls; diff --git a/src/value_converter/funcs/from_python.rs b/src/value_converter/funcs/from_python.rs index 4fe73290..adad8879 100644 --- a/src/value_converter/funcs/from_python.rs +++ b/src/value_converter/funcs/from_python.rs @@ -5,7 +5,7 @@ use itertools::Itertools; use pg_interval::Interval; use postgres_array::{Array, Dimension}; use rust_decimal::Decimal; -use serde_json::{json, Map, Value}; +use serde_json::{Map, Value}; use std::net::IpAddr; use uuid::Uuid; @@ -21,7 +21,7 @@ use crate::{ exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, extra_types::{self}, value_converter::{ - consts::KWARGS_QUERYSTRINGS, models::dto::PythonDTO, + consts::KWARGS_QUERYSTRINGS, dto::enums::PythonDTO, utils::extract_value_from_python_object_or_raise, }, }; @@ -74,45 +74,37 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult< if parameter.is_instance_of::() { return Ok(PythonDTO::PyFloat32( - parameter - .extract::()? - .retrieve_value(), + parameter.extract::()?.inner(), )); } if parameter.is_instance_of::() { return Ok(PythonDTO::PyFloat64( - parameter - .extract::()? - .retrieve_value(), + parameter.extract::()?.inner(), )); } if parameter.is_instance_of::() { return Ok(PythonDTO::PyIntI16( - parameter - .extract::()? - .retrieve_value(), + parameter.extract::()?.inner(), )); } if parameter.is_instance_of::() { return Ok(PythonDTO::PyIntI32( - parameter - .extract::()? - .retrieve_value(), + parameter.extract::()?.inner(), )); } if parameter.is_instance_of::() { return Ok(PythonDTO::PyIntI64( - parameter.extract::()?.retrieve_value(), + parameter.extract::()?.inner(), )); } if parameter.is_instance_of::() { return Ok(PythonDTO::PyMoney( - parameter.extract::()?.retrieve_value(), + parameter.extract::()?.inner(), )); } @@ -192,13 +184,13 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult< if parameter.is_instance_of::() { return Ok(PythonDTO::PyJsonb( - parameter.extract::()?.inner().clone(), + parameter.extract::()?.inner(), )); } if parameter.is_instance_of::() { return Ok(PythonDTO::PyJson( - parameter.extract::()?.inner().clone(), + parameter.extract::()?.inner(), )); } @@ -214,58 +206,56 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult< )); } - if parameter.get_type().name()? == "UUID" { - return Ok(PythonDTO::PyUUID(Uuid::parse_str( - parameter.str()?.extract::<&str>()?, - )?)); - } - - if parameter.get_type().name()? == "decimal.Decimal" - || parameter.get_type().name()? == "Decimal" - { - return Ok(PythonDTO::PyDecimal(Decimal::from_str_exact( - parameter.str()?.extract::<&str>()?, - )?)); - } - if parameter.is_instance_of::() { return Ok(PythonDTO::PyPoint( - parameter.extract::()?.retrieve_value(), + parameter.extract::()?.inner(), )); } if parameter.is_instance_of::() { return Ok(PythonDTO::PyBox( - parameter.extract::()?.retrieve_value(), + parameter.extract::()?.inner(), )); } if parameter.is_instance_of::() { return Ok(PythonDTO::PyPath( - parameter.extract::()?.retrieve_value(), + parameter.extract::()?.inner(), )); } if parameter.is_instance_of::() { return Ok(PythonDTO::PyLine( - parameter.extract::()?.retrieve_value(), + parameter.extract::()?.inner(), )); } if parameter.is_instance_of::() { return Ok(PythonDTO::PyLineSegment( - parameter - .extract::()? - .retrieve_value(), + parameter.extract::()?.inner(), )); } if parameter.is_instance_of::() { return Ok(PythonDTO::PyCircle( - parameter.extract::()?.retrieve_value(), + parameter.extract::()?.inner(), )); } + if parameter.get_type().name()? == "UUID" { + return Ok(PythonDTO::PyUUID(Uuid::parse_str( + parameter.str()?.extract::<&str>()?, + )?)); + } + + if parameter.get_type().name()? == "decimal.Decimal" + || parameter.get_type().name()? == "Decimal" + { + return Ok(PythonDTO::PyDecimal(Decimal::from_str_exact( + parameter.str()?.extract::<&str>()?, + )?)); + } + if parameter.is_instance_of::() { return parameter .extract::()? @@ -430,7 +420,7 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult< if parameter.is_instance_of::() { return Ok(PythonDTO::PyPgVector( - parameter.extract::()?.inner_value(), + parameter.extract::()?.inner(), )); } @@ -462,7 +452,7 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult< /// - The retrieved values are invalid for constructing a date, time, or datetime (e.g., invalid month or day) /// - The timezone information (`tzinfo`) is not available or cannot be parsed /// - The resulting datetime is ambiguous or invalid (e.g., due to DST transitions) -fn extract_datetime_from_python_object_attrs( +pub fn extract_datetime_from_python_object_attrs( parameter: &pyo3::Bound<'_, PyAny>, ) -> Result, RustPSQLDriverError> { let year = extract_value_from_python_object_or_raise::(parameter, "year")?; @@ -686,44 +676,6 @@ pub fn convert_seq_parameters( Ok(result_vec) } -/// Convert python List of Dict type or just Dict into serde `Value`. -/// -/// # Errors -/// May return error if cannot convert Python type into Rust one. -#[allow(clippy::needless_pass_by_value)] -pub fn build_serde_value(value: Py) -> RustPSQLDriverPyResult { - Python::with_gil(|gil| { - let bind_value = value.bind(gil); - if bind_value.is_instance_of::() { - let mut result_vec: Vec = vec![]; - - let params = bind_value.extract::>>()?; - - for inner in params { - let inner_bind = inner.bind(gil); - if inner_bind.is_instance_of::() { - let python_dto = py_to_rust(inner_bind)?; - result_vec.push(python_dto.to_serde_value()?); - } else if inner_bind.is_instance_of::() { - let serde_value = build_serde_value(inner)?; - result_vec.push(serde_value); - } else { - return Err(RustPSQLDriverError::PyToRustValueConversionError( - "PyJSON must have dicts.".to_string(), - )); - } - } - Ok(json!(result_vec)) - } else if bind_value.is_instance_of::() { - return py_to_rust(bind_value)?.to_serde_value(); - } else { - return Err(RustPSQLDriverError::PyToRustValueConversionError( - "PyJSON must be dict value.".to_string(), - )); - } - }) -} - /// Convert two python parameters(x and y) to Coord from `geo_type`. /// Also it checks that passed values is int or float. /// diff --git a/src/value_converter/mod.rs b/src/value_converter/mod.rs index e8cbc82b..7d08bf3f 100644 --- a/src/value_converter/mod.rs +++ b/src/value_converter/mod.rs @@ -1,5 +1,7 @@ pub mod additional_types; pub mod consts; +pub mod dto; pub mod funcs; pub mod models; +pub mod traits; pub mod utils; diff --git a/src/value_converter/models/mod.rs b/src/value_converter/models/mod.rs index 92d26e49..b36f3bff 100644 --- a/src/value_converter/models/mod.rs +++ b/src/value_converter/models/mod.rs @@ -1,5 +1,4 @@ pub mod decimal; -pub mod dto; pub mod interval; pub mod serde_value; pub mod uuid; diff --git a/src/value_converter/models/serde_value.rs b/src/value_converter/models/serde_value.rs index b39f7737..0bf6652f 100644 --- a/src/value_converter/models/serde_value.rs +++ b/src/value_converter/models/serde_value.rs @@ -1,16 +1,20 @@ +use postgres_array::{Array, Dimension}; use postgres_types::FromSql; -use serde_json::{json, Value}; +use serde_json::{json, Map, Value}; use uuid::Uuid; use pyo3::{ - types::{PyAnyMethods, PyDict, PyList}, + types::{PyAnyMethods, PyDict, PyDictMethods, PyList, PyTuple}, Bound, FromPyObject, Py, PyAny, PyObject, PyResult, Python, ToPyObject, }; use tokio_postgres::types::Type; use crate::{ exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, - value_converter::funcs::{from_python::py_to_rust, to_python::build_python_from_serde_value}, + value_converter::{ + dto::enums::PythonDTO, + funcs::{from_python::py_to_rust, to_python::build_python_from_serde_value}, + }, }; /// Struct for Value. @@ -22,7 +26,7 @@ pub struct InternalSerdeValue(Value); impl<'a> FromPyObject<'a> for InternalSerdeValue { fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult { - let serde_value = build_serde_value(ob.clone().unbind())?; + let serde_value = build_serde_value(ob)?; Ok(InternalSerdeValue(serde_value)) } @@ -50,36 +54,67 @@ impl<'a> FromSql<'a> for InternalSerdeValue { } } +fn serde_value_from_list( + gil: Python<'_>, + bind_value: &Bound<'_, PyAny>, +) -> RustPSQLDriverPyResult { + let mut result_vec: Vec = vec![]; + + let params = bind_value.extract::>>()?; + + for inner in params { + let inner_bind = inner.bind(gil); + if inner_bind.is_instance_of::() { + let python_dto = py_to_rust(inner_bind)?; + result_vec.push(python_dto.to_serde_value()?); + } else if inner_bind.is_instance_of::() { + let serde_value = build_serde_value(inner.bind(gil))?; + result_vec.push(serde_value); + } else { + return Err(RustPSQLDriverError::PyToRustValueConversionError( + "PyJSON must have dicts.".to_string(), + )); + } + } + Ok(json!(result_vec)) +} + +fn serde_value_from_dict(bind_value: &Bound<'_, PyAny>) -> RustPSQLDriverPyResult { + let dict = bind_value.downcast::().map_err(|error| { + RustPSQLDriverError::PyToRustValueConversionError(format!( + "Can't cast to inner dict: {error}" + )) + })?; + + let mut serde_map: Map = Map::new(); + + for dict_item in dict.items() { + let py_list = dict_item.downcast::().map_err(|error| { + RustPSQLDriverError::PyToRustValueConversionError(format!( + "Cannot cast to list: {error}" + )) + })?; + + let key = py_list.get_item(0)?.extract::()?; + let value = py_to_rust(&py_list.get_item(1)?)?; + + serde_map.insert(key, value.to_serde_value()?); + } + + return Ok(Value::Object(serde_map)); +} + /// Convert python List of Dict type or just Dict into serde `Value`. /// /// # Errors /// May return error if cannot convert Python type into Rust one. #[allow(clippy::needless_pass_by_value)] -pub fn build_serde_value(value: Py) -> RustPSQLDriverPyResult { +pub fn build_serde_value(value: &Bound<'_, PyAny>) -> RustPSQLDriverPyResult { Python::with_gil(|gil| { - let bind_value = value.bind(gil); - if bind_value.is_instance_of::() { - let mut result_vec: Vec = vec![]; - - let params = bind_value.extract::>>()?; - - for inner in params { - let inner_bind = inner.bind(gil); - if inner_bind.is_instance_of::() { - let python_dto = py_to_rust(inner_bind)?; - result_vec.push(python_dto.to_serde_value()?); - } else if inner_bind.is_instance_of::() { - let serde_value = build_serde_value(inner)?; - result_vec.push(serde_value); - } else { - return Err(RustPSQLDriverError::PyToRustValueConversionError( - "PyJSON must have dicts.".to_string(), - )); - } - } - Ok(json!(result_vec)) - } else if bind_value.is_instance_of::() { - return py_to_rust(bind_value)?.to_serde_value(); + if value.is_instance_of::() { + return serde_value_from_list(gil, value); + } else if value.is_instance_of::() { + return serde_value_from_dict(value); } else { return Err(RustPSQLDriverError::PyToRustValueConversionError( "PyJSON must be dict value.".to_string(), @@ -87,3 +122,61 @@ pub fn build_serde_value(value: Py) -> RustPSQLDriverPyResult { } }) } + +/// Convert Array of `PythonDTO`s to serde `Value`. +/// +/// It can convert multidimensional arrays. +pub fn pythondto_array_to_serde(array: Option>) -> RustPSQLDriverPyResult { + match array { + Some(array) => inner_pythondto_array_to_serde( + array.dimensions(), + array.iter().collect::>().as_slice(), + 0, + 0, + ), + None => Ok(Value::Null), + } +} + +/// Inner conversion array of `PythonDTO`s to serde `Value`. +#[allow(clippy::cast_sign_loss)] +fn inner_pythondto_array_to_serde( + dimensions: &[Dimension], + data: &[&PythonDTO], + dimension_index: usize, + mut lower_bound: usize, +) -> RustPSQLDriverPyResult { + let current_dimension = dimensions.get(dimension_index); + + if let Some(current_dimension) = current_dimension { + let possible_next_dimension = dimensions.get(dimension_index + 1); + match possible_next_dimension { + Some(next_dimension) => { + let mut final_list: Value = Value::Array(vec![]); + + for _ in 0..current_dimension.len as usize { + if dimensions.get(dimension_index + 1).is_some() { + let inner_pylist = inner_pythondto_array_to_serde( + dimensions, + &data[lower_bound..next_dimension.len as usize + lower_bound], + dimension_index + 1, + 0, + )?; + match final_list { + Value::Array(ref mut array) => array.push(inner_pylist), + _ => unreachable!(), + } + lower_bound += next_dimension.len as usize; + }; + } + + return Ok(final_list); + } + None => { + return data.iter().map(|x| x.to_serde_value()).collect(); + } + } + } + + Ok(Value::Array(vec![])) +} diff --git a/src/value_converter/traits.rs b/src/value_converter/traits.rs new file mode 100644 index 00000000..ca44a7d0 --- /dev/null +++ b/src/value_converter/traits.rs @@ -0,0 +1,9 @@ +use pyo3::PyAny; + +use crate::exceptions::rust_errors::RustPSQLDriverPyResult; + +use super::dto::enums::PythonDTO; + +pub trait PythonToDTO { + fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult; +} From 50855b5a5c227dea79f671d306bcae115b85a674 Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Fri, 2 May 2025 22:35:06 +0200 Subject: [PATCH 02/15] Full value converter refactor --- Cargo.lock | 10 +- src/driver/connection.rs | 27 +- src/driver/connection_pool.rs | 8 +- src/driver/connection_pool_builder.rs | 6 +- src/driver/cursor.rs | 44 +- src/driver/inner_connection.rs | 184 ++++----- src/driver/listener/core.rs | 26 +- src/driver/listener/structs.rs | 4 +- src/driver/transaction.rs | 59 ++- src/driver/utils.rs | 8 +- src/exceptions/rust_errors.rs | 8 +- src/extra_types.rs | 25 +- src/lib.rs | 1 + src/query_result.rs | 29 +- src/row_factories.rs | 6 +- src/runtime.rs | 6 +- src/statement/cache.rs | 50 +++ src/statement/mod.rs | 7 + src/statement/parameters.rs | 255 ++++++++++++ src/statement/query.rs | 92 +++++ src/statement/statement.rs | 30 ++ src/statement/statement_builder.rs | 100 +++++ src/statement/traits.rs | 8 + src/statement/utils.rs | 1 + src/value_converter/consts.rs | 2 + src/value_converter/dto/converter_impls.rs | 65 +-- src/value_converter/dto/impls.rs | 6 +- .../{funcs => }/from_python.rs | 375 ++++-------------- src/value_converter/funcs/mod.rs | 2 - src/value_converter/mod.rs | 3 +- src/value_converter/models/serde_value.rs | 19 +- src/value_converter/params_converters.rs | 0 src/value_converter/{funcs => }/to_python.rs | 21 +- src/value_converter/traits.rs | 6 +- 34 files changed, 890 insertions(+), 603 deletions(-) create mode 100644 src/statement/cache.rs create mode 100644 src/statement/mod.rs create mode 100644 src/statement/parameters.rs create mode 100644 src/statement/query.rs create mode 100644 src/statement/statement.rs create mode 100644 src/statement/statement_builder.rs create mode 100644 src/statement/traits.rs create mode 100644 src/statement/utils.rs rename src/value_converter/{funcs => }/from_python.rs (64%) delete mode 100644 src/value_converter/funcs/mod.rs create mode 100644 src/value_converter/params_converters.rs rename src/value_converter/{funcs => }/to_python.rs (97%) diff --git a/Cargo.lock b/Cargo.lock index fee82b45..df4dc951 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -881,7 +881,7 @@ checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265" [[package]] name = "postgres-derive" version = "0.4.5" -source = "git+https://github.com/chandr-andr/rust-postgres.git?branch=psqlpy#e4e1047e701318b31c61330e428ebd8ade7ed1cb" +source = "git+https://github.com/chandr-andr/rust-postgres.git?branch=psqlpy#5780895bfa8a0b9142df225b65bc6e59f7dbee61" dependencies = [ "heck", "proc-macro2", @@ -892,7 +892,7 @@ dependencies = [ [[package]] name = "postgres-openssl" version = "0.5.0" -source = "git+https://github.com/chandr-andr/rust-postgres.git?branch=psqlpy#e4e1047e701318b31c61330e428ebd8ade7ed1cb" +source = "git+https://github.com/chandr-andr/rust-postgres.git?branch=psqlpy#5780895bfa8a0b9142df225b65bc6e59f7dbee61" dependencies = [ "openssl", "tokio", @@ -903,7 +903,7 @@ dependencies = [ [[package]] name = "postgres-protocol" version = "0.6.7" -source = "git+https://github.com/chandr-andr/rust-postgres.git?branch=psqlpy#e4e1047e701318b31c61330e428ebd8ade7ed1cb" +source = "git+https://github.com/chandr-andr/rust-postgres.git?branch=psqlpy#5780895bfa8a0b9142df225b65bc6e59f7dbee61" dependencies = [ "base64", "byteorder", @@ -920,7 +920,7 @@ dependencies = [ [[package]] name = "postgres-types" version = "0.2.7" -source = "git+https://github.com/chandr-andr/rust-postgres.git?branch=psqlpy#e4e1047e701318b31c61330e428ebd8ade7ed1cb" +source = "git+https://github.com/chandr-andr/rust-postgres.git?branch=psqlpy#5780895bfa8a0b9142df225b65bc6e59f7dbee61" dependencies = [ "array-init", "bytes", @@ -1540,7 +1540,7 @@ dependencies = [ [[package]] name = "tokio-postgres" version = "0.7.11" -source = "git+https://github.com/chandr-andr/rust-postgres.git?branch=psqlpy#e4e1047e701318b31c61330e428ebd8ade7ed1cb" +source = "git+https://github.com/chandr-andr/rust-postgres.git?branch=psqlpy#5780895bfa8a0b9142df225b65bc6e59f7dbee61" dependencies = [ "async-trait", "byteorder", diff --git a/src/driver/connection.rs b/src/driver/connection.rs index 3c0595bb..8f2a4b40 100644 --- a/src/driver/connection.rs +++ b/src/driver/connection.rs @@ -6,7 +6,7 @@ use std::{collections::HashSet, net::IpAddr, sync::Arc}; use tokio_postgres::{binary_copy::BinaryCopyInWriter, config::Host, Config}; use crate::{ - exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, + exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, format_helpers::quote_ident, query_result::{PSQLDriverPyQueryResult, PSQLDriverSinglePyQueryResult}, runtime::tokio_runtime, @@ -137,7 +137,7 @@ impl Connection { return self.pg_config.get_options(); } - async fn __aenter__<'a>(self_: Py) -> RustPSQLDriverPyResult> { + async fn __aenter__<'a>(self_: Py) -> PSQLPyResult> { let (db_client, db_pool) = pyo3::Python::with_gil(|gil| { let self_ = self_.borrow(gil); (self_.db_client.clone(), self_.db_pool.clone()) @@ -169,7 +169,7 @@ impl Connection { _exception_type: Py, exception: Py, _traceback: Py, - ) -> RustPSQLDriverPyResult<()> { + ) -> PSQLPyResult<()> { let (is_exception_none, py_err) = pyo3::Python::with_gil(|gil| { ( exception.is_none(gil), @@ -205,7 +205,7 @@ impl Connection { querystring: String, parameters: Option>, prepared: Option, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); if let Some(db_client) = db_client { @@ -227,10 +227,7 @@ impl Connection { /// May return Err Result if: /// 1) Connection is closed. /// 2) Cannot execute querystring. - pub async fn execute_batch( - self_: pyo3::Py, - querystring: String, - ) -> RustPSQLDriverPyResult<()> { + pub async fn execute_batch(self_: pyo3::Py, querystring: String) -> PSQLPyResult<()> { let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); if let Some(db_client) = db_client { @@ -256,7 +253,7 @@ impl Connection { querystring: String, parameters: Option>>, prepared: Option, - ) -> RustPSQLDriverPyResult<()> { + ) -> PSQLPyResult<()> { let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); if let Some(db_client) = db_client { @@ -282,7 +279,7 @@ impl Connection { querystring: String, parameters: Option>, prepared: Option, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); if let Some(db_client) = db_client { @@ -312,7 +309,7 @@ impl Connection { querystring: String, parameters: Option>, prepared: Option, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); if let Some(db_client) = db_client { @@ -339,7 +336,7 @@ impl Connection { querystring: String, parameters: Option>, prepared: Option, - ) -> RustPSQLDriverPyResult> { + ) -> PSQLPyResult> { let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); if let Some(db_client) = db_client { @@ -365,7 +362,7 @@ impl Connection { read_variant: Option, deferrable: Option, synchronous_commit: Option, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { if let Some(db_client) = &self.db_client { return Ok(Transaction::new( db_client.clone(), @@ -401,7 +398,7 @@ impl Connection { fetch_number: Option, scroll: Option, prepared: Option, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { if let Some(db_client) = &self.db_client { return Ok(Cursor::new( db_client.clone(), @@ -446,7 +443,7 @@ impl Connection { table_name: String, columns: Option>, schema_name: Option, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); let mut table_name = quote_ident(&table_name); if let Some(schema_name) = schema_name { diff --git a/src/driver/connection_pool.rs b/src/driver/connection_pool.rs index 24780a6a..0c52c256 100644 --- a/src/driver/connection_pool.rs +++ b/src/driver/connection_pool.rs @@ -4,7 +4,7 @@ use pyo3::{pyclass, pyfunction, pymethods, Py, PyAny}; use std::sync::Arc; use tokio_postgres::Config; -use crate::exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}; +use crate::exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}; use super::{ common_options::{ConnRecyclingMethod, LoadBalanceHosts, SslMode, TargetSessionAttrs}, @@ -75,7 +75,7 @@ pub fn connect( ca_file: Option, max_db_pool_size: Option, conn_recycling_method: Option, -) -> RustPSQLDriverPyResult { +) -> PSQLPyResult { if let Some(max_db_pool_size) = max_db_pool_size { if max_db_pool_size < 2 { return Err(RustPSQLDriverError::ConnectionPoolConfigurationError( @@ -289,7 +289,7 @@ impl ConnectionPool { conn_recycling_method: Option, ssl_mode: Option, ca_file: Option, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { connect( dsn, username, @@ -382,7 +382,7 @@ impl ConnectionPool { /// /// # Errors /// May return Err Result if cannot get new connection from the pool. - pub async fn connection(self_: pyo3::Py) -> RustPSQLDriverPyResult { + pub async fn connection(self_: pyo3::Py) -> PSQLPyResult { let (db_pool, pg_config) = pyo3::Python::with_gil(|gil| { let slf = self_.borrow(gil); (slf.pool.clone(), slf.pg_config.clone()) diff --git a/src/driver/connection_pool_builder.rs b/src/driver/connection_pool_builder.rs index e0610942..42cdd641 100644 --- a/src/driver/connection_pool_builder.rs +++ b/src/driver/connection_pool_builder.rs @@ -3,7 +3,7 @@ use std::{net::IpAddr, time::Duration}; use deadpool_postgres::{Manager, ManagerConfig, Pool, RecyclingMethod}; use pyo3::{pyclass, pymethods, Py, Python}; -use crate::exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}; +use crate::exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}; use super::{ common_options, @@ -38,7 +38,7 @@ impl ConnectionPoolBuilder { /// /// # Errors /// May return error if cannot build new connection pool. - fn build(&self) -> RustPSQLDriverPyResult { + fn build(&self) -> PSQLPyResult { let mgr_config: ManagerConfig; if let Some(conn_recycling_method) = self.conn_recycling_method.as_ref() { mgr_config = ManagerConfig { @@ -84,7 +84,7 @@ impl ConnectionPoolBuilder { /// /// # Error /// If size more than 2. - fn max_pool_size(self_: Py, pool_size: usize) -> RustPSQLDriverPyResult> { + fn max_pool_size(self_: Py, pool_size: usize) -> PSQLPyResult> { if pool_size < 2 { return Err(RustPSQLDriverError::ConnectionPoolConfigurationError( "Maximum database pool size must be more than 1".into(), diff --git a/src/driver/cursor.rs b/src/driver/cursor.rs index f391d1c1..1f435ef5 100644 --- a/src/driver/cursor.rs +++ b/src/driver/cursor.rs @@ -6,7 +6,7 @@ use pyo3::{ use tokio_postgres::{config::Host, Config}; use crate::{ - exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, + exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, query_result::PSQLDriverPyQueryResult, runtime::rustdriver_future, }; @@ -23,9 +23,9 @@ trait CursorObjectTrait { querystring: &str, prepared: &Option, parameters: &Option>, - ) -> RustPSQLDriverPyResult<()>; + ) -> PSQLPyResult<()>; - async fn cursor_close(&self, closed: &bool, cursor_name: &str) -> RustPSQLDriverPyResult<()>; + async fn cursor_close(&self, closed: &bool, cursor_name: &str) -> PSQLPyResult<()>; } impl CursorObjectTrait for PsqlpyConnection { @@ -43,7 +43,7 @@ impl CursorObjectTrait for PsqlpyConnection { querystring: &str, prepared: &Option, parameters: &Option>, - ) -> RustPSQLDriverPyResult<()> { + ) -> PSQLPyResult<()> { let mut cursor_init_query = format!("DECLARE {cursor_name}"); if let Some(scroll) = scroll { if *scroll { @@ -70,7 +70,7 @@ impl CursorObjectTrait for PsqlpyConnection { /// /// # Errors /// May return Err Result if cannot execute querystring. - async fn cursor_close(&self, closed: &bool, cursor_name: &str) -> RustPSQLDriverPyResult<()> { + async fn cursor_close(&self, closed: &bool, cursor_name: &str) -> PSQLPyResult<()> { if *closed { return Err(RustPSQLDriverError::CursorCloseError( "Cursor is already closed".into(), @@ -232,7 +232,7 @@ impl Cursor { slf } - async fn __aenter__<'a>(slf: Py) -> RustPSQLDriverPyResult> { + async fn __aenter__<'a>(slf: Py) -> PSQLPyResult> { let (db_transaction, cursor_name, scroll, querystring, prepared, parameters) = Python::with_gil(|gil| { let self_ = slf.borrow(gil); @@ -265,7 +265,7 @@ impl Cursor { _exception_type: Py, exception: Py, _traceback: Py, - ) -> RustPSQLDriverPyResult<()> { + ) -> PSQLPyResult<()> { let (db_transaction, closed, cursor_name, is_exception_none, py_err) = pyo3::Python::with_gil(|gil| { let self_ = slf.borrow(gil); @@ -307,7 +307,7 @@ impl Cursor { /// we didn't find any solution how to implement it without /// # Errors /// May return Err Result if can't execute querystring. - fn __anext__(&self) -> RustPSQLDriverPyResult> { + fn __anext__(&self) -> PSQLPyResult> { let db_transaction = self.db_transaction.clone(); let fetch_number = self.fetch_number; let cursor_name = self.cursor_name.clone(); @@ -343,7 +343,7 @@ impl Cursor { /// # Errors /// May return Err Result /// if cannot execute querystring for cursor declaration. - pub async fn start(&mut self) -> RustPSQLDriverPyResult<()> { + pub async fn start(&mut self) -> PSQLPyResult<()> { let db_transaction_arc = self.db_transaction.clone(); if let Some(db_transaction) = db_transaction_arc { @@ -370,7 +370,7 @@ impl Cursor { /// /// # Errors /// May return Err Result if cannot execute query. - pub async fn close(&mut self) -> RustPSQLDriverPyResult<()> { + pub async fn close(&mut self) -> PSQLPyResult<()> { let db_transaction_arc = self.db_transaction.clone(); if let Some(db_transaction) = db_transaction_arc { @@ -396,7 +396,7 @@ impl Cursor { pub async fn fetch<'a>( slf: Py, fetch_number: Option, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { let (db_transaction, inner_fetch_number, cursor_name) = Python::with_gil(|gil| { let self_ = slf.borrow(gil); ( @@ -437,7 +437,7 @@ impl Cursor { /// /// # Errors /// May return Err Result if cannot execute query. - pub async fn fetch_next<'a>(slf: Py) -> RustPSQLDriverPyResult { + pub async fn fetch_next<'a>(slf: Py) -> PSQLPyResult { let (db_transaction, cursor_name) = Python::with_gil(|gil| { let self_ = slf.borrow(gil); (self_.db_transaction.clone(), self_.cursor_name.clone()) @@ -464,7 +464,7 @@ impl Cursor { /// /// # Errors /// May return Err Result if cannot execute query. - pub async fn fetch_prior<'a>(slf: Py) -> RustPSQLDriverPyResult { + pub async fn fetch_prior<'a>(slf: Py) -> PSQLPyResult { let (db_transaction, cursor_name) = Python::with_gil(|gil| { let self_ = slf.borrow(gil); (self_.db_transaction.clone(), self_.cursor_name.clone()) @@ -491,7 +491,7 @@ impl Cursor { /// /// # Errors /// May return Err Result if cannot execute query. - pub async fn fetch_first<'a>(slf: Py) -> RustPSQLDriverPyResult { + pub async fn fetch_first<'a>(slf: Py) -> PSQLPyResult { let (db_transaction, cursor_name) = Python::with_gil(|gil| { let self_ = slf.borrow(gil); (self_.db_transaction.clone(), self_.cursor_name.clone()) @@ -518,7 +518,7 @@ impl Cursor { /// /// # Errors /// May return Err Result if cannot execute query. - pub async fn fetch_last<'a>(slf: Py) -> RustPSQLDriverPyResult { + pub async fn fetch_last<'a>(slf: Py) -> PSQLPyResult { let (db_transaction, cursor_name) = Python::with_gil(|gil| { let self_ = slf.borrow(gil); (self_.db_transaction.clone(), self_.cursor_name.clone()) @@ -548,7 +548,7 @@ impl Cursor { pub async fn fetch_absolute<'a>( slf: Py, absolute_number: i64, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { let (db_transaction, cursor_name) = Python::with_gil(|gil| { let self_ = slf.borrow(gil); (self_.db_transaction.clone(), self_.cursor_name.clone()) @@ -582,7 +582,7 @@ impl Cursor { pub async fn fetch_relative<'a>( slf: Py, relative_number: i64, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { let (db_transaction, cursor_name) = Python::with_gil(|gil| { let self_ = slf.borrow(gil); (self_.db_transaction.clone(), self_.cursor_name.clone()) @@ -613,9 +613,7 @@ impl Cursor { /// /// # Errors /// May return Err Result if cannot execute query. - pub async fn fetch_forward_all<'a>( - slf: Py, - ) -> RustPSQLDriverPyResult { + pub async fn fetch_forward_all<'a>(slf: Py) -> PSQLPyResult { let (db_transaction, cursor_name) = Python::with_gil(|gil| { let self_ = slf.borrow(gil); (self_.db_transaction.clone(), self_.cursor_name.clone()) @@ -649,7 +647,7 @@ impl Cursor { pub async fn fetch_backward<'a>( slf: Py, backward_count: i64, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { let (db_transaction, cursor_name) = Python::with_gil(|gil| { let self_ = slf.borrow(gil); (self_.db_transaction.clone(), self_.cursor_name.clone()) @@ -680,9 +678,7 @@ impl Cursor { /// /// # Errors /// May return Err Result if cannot execute query. - pub async fn fetch_backward_all<'a>( - slf: Py, - ) -> RustPSQLDriverPyResult { + pub async fn fetch_backward_all<'a>(slf: Py) -> PSQLPyResult { let (db_transaction, cursor_name) = Python::with_gil(|gil| { let self_ = slf.borrow(gil); (self_.db_transaction.clone(), self_.cursor_name.clone()) diff --git a/src/driver/inner_connection.rs b/src/driver/inner_connection.rs index 2dfbcbb7..a7e9d233 100644 --- a/src/driver/inner_connection.rs +++ b/src/driver/inner_connection.rs @@ -1,18 +1,15 @@ use bytes::Buf; use deadpool_postgres::Object; -use postgres_types::ToSql; +use postgres_types::{ToSql, Type}; use pyo3::{Py, PyAny, Python}; use std::vec; use tokio_postgres::{Client, CopyInSink, Row, Statement, ToStatement}; use crate::{ - exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, + exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, query_result::{PSQLDriverPyQueryResult, PSQLDriverSinglePyQueryResult}, - value_converter::{ - consts::QueryParameter, - dto::enums::PythonDTO, - funcs::{from_python::convert_parameters_and_qs, to_python::postgres_to_py}, - }, + statement::{statement::PsqlpyStatement, statement_builder::StatementBuilder}, + value_converter::to_python::postgres_to_py, }; #[allow(clippy::module_name_repetitions)] @@ -26,13 +23,39 @@ impl PsqlpyConnection { /// /// # Errors /// May return Err if cannot prepare statement. - pub async fn prepare_cached(&self, query: &str) -> RustPSQLDriverPyResult { + pub async fn prepare(&self, query: &str) -> PSQLPyResult { match self { PsqlpyConnection::PoolConn(pconn) => return Ok(pconn.prepare_cached(query).await?), PsqlpyConnection::SingleConn(sconn) => return Ok(sconn.prepare(query).await?), } } + /// Delete prepared statement. + /// + /// # Errors + /// May return Err if cannot prepare statement. + pub async fn drop_prepared(&self, stmt: &Statement) -> PSQLPyResult<()> { + let query = format!("DEALLOCATE PREPARE {}", stmt.name()); + match self { + PsqlpyConnection::PoolConn(pconn) => return Ok(pconn.batch_execute(&query).await?), + PsqlpyConnection::SingleConn(sconn) => return Ok(sconn.batch_execute(&query).await?), + } + } + + /// Prepare and delete statement. + /// + /// # Errors + /// Can return Err if cannot prepare statement. + pub async fn prepare_then_drop(&self, query: &str) -> PSQLPyResult> { + let types: Vec; + + let stmt = self.prepare(query).await?; + types = stmt.params().to_vec(); + self.drop_prepared(&stmt).await?; + + Ok(types) + } + /// Prepare cached statement. /// /// # Errors @@ -41,7 +64,7 @@ impl PsqlpyConnection { &self, statement: &T, params: &[&(dyn ToSql + Sync)], - ) -> RustPSQLDriverPyResult> + ) -> PSQLPyResult> where T: ?Sized + ToStatement, { @@ -57,7 +80,7 @@ impl PsqlpyConnection { /// /// # Errors /// May return Err if cannot execute statement. - pub async fn batch_execute(&self, query: &str) -> RustPSQLDriverPyResult<()> { + pub async fn batch_execute(&self, query: &str) -> PSQLPyResult<()> { match self { PsqlpyConnection::PoolConn(pconn) => return Ok(pconn.batch_execute(query).await?), PsqlpyConnection::SingleConn(sconn) => return Ok(sconn.batch_execute(query).await?), @@ -72,7 +95,7 @@ impl PsqlpyConnection { &self, statement: &T, params: &[&(dyn ToSql + Sync)], - ) -> RustPSQLDriverPyResult + ) -> PSQLPyResult where T: ?Sized + ToStatement, { @@ -91,38 +114,28 @@ impl PsqlpyConnection { querystring: String, parameters: Option>, prepared: Option, - ) -> RustPSQLDriverPyResult { - let prepared = prepared.unwrap_or(true); - - let (qs, params) = convert_parameters_and_qs(querystring, parameters)?; + ) -> PSQLPyResult { + let statement = StatementBuilder::new(querystring, parameters, self, prepared) + .build() + .await?; - let boxed_params = ¶ms - .iter() - .map(|param| param as &QueryParameter) - .collect::>() - .into_boxed_slice(); + let prepared = prepared.unwrap_or(true); let result = if prepared { self.query( - &self.prepare_cached(&qs).await.map_err(|err| { + &self.prepare(&statement.sql_stmt()).await.map_err(|err| { RustPSQLDriverError::ConnectionExecuteError(format!( "Cannot prepare statement, error - {err}" )) })?, - boxed_params, + &statement.params(), ) .await - .map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot execute statement, error - {err}" - )) - })? + .map_err(|err| RustPSQLDriverError::ConnectionExecuteError(format!("{err}")))? } else { - self.query(&qs, boxed_params).await.map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot execute statement, error - {err}" - )) - })? + self.query(statement.sql_stmt(), &statement.params()) + .await + .map_err(|err| RustPSQLDriverError::ConnectionExecuteError(format!("{err}")))? }; Ok(PSQLDriverPyQueryResult::new(result)) @@ -133,38 +146,28 @@ impl PsqlpyConnection { querystring: String, parameters: Option>, prepared: Option, - ) -> RustPSQLDriverPyResult { - let prepared = prepared.unwrap_or(true); - - let (qs, params) = convert_parameters_and_qs(querystring, parameters)?; + ) -> PSQLPyResult { + let statement = StatementBuilder::new(querystring, parameters, self, prepared) + .build() + .await?; - let boxed_params = ¶ms - .iter() - .map(|param| param as &QueryParameter) - .collect::>() - .into_boxed_slice(); + let prepared = prepared.unwrap_or(true); let result = if prepared { self.query( - &self.prepare_cached(&qs).await.map_err(|err| { + &self.prepare(statement.sql_stmt()).await.map_err(|err| { RustPSQLDriverError::ConnectionExecuteError(format!( "Cannot prepare statement, error - {err}" )) })?, - boxed_params, + &statement.params(), ) .await - .map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot execute statement, error - {err}" - )) - })? + .map_err(|err| RustPSQLDriverError::ConnectionExecuteError(format!("{err}")))? } else { - self.query(&qs, boxed_params).await.map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot execute statement, error - {err}" - )) - })? + self.query(statement.sql_stmt(), &statement.params()) + .await + .map_err(|err| RustPSQLDriverError::ConnectionExecuteError(format!("{err}")))? }; Ok(PSQLDriverPyQueryResult::new(result)) @@ -172,41 +175,40 @@ impl PsqlpyConnection { pub async fn execute_many( &self, - mut querystring: String, + querystring: String, parameters: Option>>, prepared: Option, - ) -> RustPSQLDriverPyResult<()> { - let prepared = prepared.unwrap_or(true); - - let mut params: Vec> = vec![]; + ) -> PSQLPyResult<()> { + let mut statements: Vec = vec![]; if let Some(parameters) = parameters { for vec_of_py_any in parameters { // TODO: Fix multiple qs creation - let (qs, parsed_params) = - convert_parameters_and_qs(querystring.clone(), Some(vec_of_py_any))?; - querystring = qs; - params.push(parsed_params); + let statement = + StatementBuilder::new(querystring.clone(), Some(vec_of_py_any), self, prepared) + .build() + .await?; + + statements.push(statement); } } - for param in params { - let boxed_params = ¶m - .iter() - .map(|param| param as &QueryParameter) - .collect::>() - .into_boxed_slice(); + let prepared = prepared.unwrap_or(true); + for statement in statements { let querystring_result = if prepared { - let prepared_stmt = &self.prepare_cached(&querystring).await; + let prepared_stmt = &self.prepare(&statement.sql_stmt()).await; if let Err(error) = prepared_stmt { return Err(RustPSQLDriverError::ConnectionExecuteError(format!( "Cannot prepare statement in execute_many, operation rolled back {error}", ))); } - self.query(&self.prepare_cached(&querystring).await?, boxed_params) - .await + self.query( + &self.prepare(&statement.sql_stmt()).await?, + &statement.params(), + ) + .await } else { - self.query(&querystring, boxed_params).await + self.query(statement.sql_stmt(), &statement.params()).await }; if let Err(error) = querystring_result { @@ -224,38 +226,28 @@ impl PsqlpyConnection { querystring: String, parameters: Option>, prepared: Option, - ) -> RustPSQLDriverPyResult { - let prepared = prepared.unwrap_or(true); - - let (qs, params) = convert_parameters_and_qs(querystring, parameters)?; + ) -> PSQLPyResult { + let statement = StatementBuilder::new(querystring, parameters, self, prepared) + .build() + .await?; - let boxed_params = ¶ms - .iter() - .map(|param| param as &QueryParameter) - .collect::>() - .into_boxed_slice(); + let prepared = prepared.unwrap_or(true); let result = if prepared { self.query_one( - &self.prepare_cached(&qs).await.map_err(|err| { + &self.prepare(&statement.sql_stmt()).await.map_err(|err| { RustPSQLDriverError::ConnectionExecuteError(format!( "Cannot prepare statement, error - {err}" )) })?, - boxed_params, + &statement.params(), ) .await - .map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot execute statement, error - {err}" - )) - })? + .map_err(|err| RustPSQLDriverError::ConnectionExecuteError(format!("{err}")))? } else { - self.query_one(&qs, boxed_params).await.map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot execute statement, error - {err}" - )) - })? + self.query_one(statement.sql_stmt(), &statement.params()) + .await + .map_err(|err| RustPSQLDriverError::ConnectionExecuteError(format!("{err}")))? }; return Ok(result); @@ -266,7 +258,7 @@ impl PsqlpyConnection { querystring: String, parameters: Option>, prepared: Option, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { let result = self .fetch_row_raw(querystring, parameters, prepared) .await?; @@ -279,7 +271,7 @@ impl PsqlpyConnection { querystring: String, parameters: Option>, prepared: Option, - ) -> RustPSQLDriverPyResult> { + ) -> PSQLPyResult> { let result = self .fetch_row_raw(querystring, parameters, prepared) .await?; @@ -294,7 +286,7 @@ impl PsqlpyConnection { /// /// # Errors /// May return Err if cannot execute copy data. - pub async fn copy_in(&self, statement: &T) -> RustPSQLDriverPyResult> + pub async fn copy_in(&self, statement: &T) -> PSQLPyResult> where T: ?Sized + ToStatement, U: Buf + 'static + Send, diff --git a/src/driver/listener/core.rs b/src/driver/listener/core.rs index 83aa9b3e..16b323d8 100644 --- a/src/driver/listener/core.rs +++ b/src/driver/listener/core.rs @@ -18,7 +18,7 @@ use crate::{ inner_connection::PsqlpyConnection, utils::{build_tls, is_coroutine_function, ConfiguredTLS}, }, - exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, + exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, runtime::{rustdriver_future, tokio_runtime}, }; @@ -89,7 +89,7 @@ impl Listener { } #[allow(clippy::unused_async)] - async fn __aenter__<'a>(slf: Py) -> RustPSQLDriverPyResult> { + async fn __aenter__<'a>(slf: Py) -> PSQLPyResult> { Ok(slf) } @@ -99,7 +99,7 @@ impl Listener { _exception_type: Py, exception: Py, _traceback: Py, - ) -> RustPSQLDriverPyResult<()> { + ) -> PSQLPyResult<()> { let (client, is_exception_none, py_err) = pyo3::Python::with_gil(|gil| { let self_ = slf.borrow(gil); ( @@ -126,7 +126,7 @@ impl Listener { Err(RustPSQLDriverError::ListenerClosedError) } - fn __anext__(&self) -> RustPSQLDriverPyResult>> { + fn __anext__(&self) -> PSQLPyResult>> { let Some(client) = self.connection.db_client() else { return Err(RustPSQLDriverError::ListenerStartError( "Listener doesn't have underlying client, please call startup".into(), @@ -167,7 +167,7 @@ impl Listener { } #[getter] - fn connection(&self) -> RustPSQLDriverPyResult { + fn connection(&self) -> PSQLPyResult { if !self.is_started { return Err(RustPSQLDriverError::ListenerStartError( "Listener isn't started up".into(), @@ -177,7 +177,7 @@ impl Listener { Ok(self.connection.clone()) } - async fn startup(&mut self) -> RustPSQLDriverPyResult<()> { + async fn startup(&mut self) -> PSQLPyResult<()> { if self.is_started { return Err(RustPSQLDriverError::ListenerStartError( "Listener is already started".into(), @@ -238,11 +238,7 @@ impl Listener { } #[pyo3(signature = (channel, callback))] - async fn add_callback( - &mut self, - channel: String, - callback: Py, - ) -> RustPSQLDriverPyResult<()> { + async fn add_callback(&mut self, channel: String, callback: Py) -> PSQLPyResult<()> { if !is_coroutine_function(callback.clone())? { return Err(RustPSQLDriverError::ListenerCallbackError); } @@ -279,7 +275,7 @@ impl Listener { self.update_listen_query().await; } - fn listen(&mut self) -> RustPSQLDriverPyResult<()> { + fn listen(&mut self) -> PSQLPyResult<()> { let Some(client) = self.connection.db_client() else { return Err(RustPSQLDriverError::ListenerStartError( "Cannot start listening, underlying connection doesn't exist".into(), @@ -343,7 +339,7 @@ async fn dispatch_callback( listener_callback: &ListenerCallback, listener_notification: ListenerNotification, connection: Connection, -) -> RustPSQLDriverPyResult<()> { +) -> PSQLPyResult<()> { listener_callback .call(listener_notification.clone(), connection) .await?; @@ -355,7 +351,7 @@ async fn execute_listen( is_listened: &Arc>, listen_query: &Arc>, client: &Arc, -) -> RustPSQLDriverPyResult<()> { +) -> PSQLPyResult<()> { let mut write_is_listened = is_listened.write().await; if !write_is_listened.eq(&true) { @@ -371,7 +367,7 @@ async fn execute_listen( Ok(()) } -fn process_message(message: Option) -> RustPSQLDriverPyResult { +fn process_message(message: Option) -> PSQLPyResult { let Some(async_message) = message else { return Err(RustPSQLDriverError::ListenerError("Wow".into())); }; diff --git a/src/driver/listener/structs.rs b/src/driver/listener/structs.rs index 4d53a408..6236547e 100644 --- a/src/driver/listener/structs.rs +++ b/src/driver/listener/structs.rs @@ -6,7 +6,7 @@ use tokio_postgres::Notification; use crate::{ driver::connection::Connection, - exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, + exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, runtime::tokio_runtime, }; @@ -126,7 +126,7 @@ impl ListenerCallback { &self, lister_notification: ListenerNotification, connection: Connection, - ) -> RustPSQLDriverPyResult<()> { + ) -> PSQLPyResult<()> { let (callback, task_locals) = Python::with_gil(|py| (self.callback.clone(), self.task_locals.clone_ref(py))); diff --git a/src/driver/transaction.rs b/src/driver/transaction.rs index 2fa38ba5..60f054b7 100644 --- a/src/driver/transaction.rs +++ b/src/driver/transaction.rs @@ -9,7 +9,7 @@ use pyo3::{ use tokio_postgres::{binary_copy::BinaryCopyInWriter, config::Host, Config}; use crate::{ - exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, + exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, format_helpers::quote_ident, query_result::{PSQLDriverPyQueryResult, PSQLDriverSinglePyQueryResult}, }; @@ -29,9 +29,9 @@ pub trait TransactionObjectTrait { read_variant: Option, defferable: Option, synchronous_commit: Option, - ) -> impl std::future::Future> + Send; - fn commit(&self) -> impl std::future::Future> + Send; - fn rollback(&self) -> impl std::future::Future> + Send; + ) -> impl std::future::Future> + Send; + fn commit(&self) -> impl std::future::Future> + Send; + fn rollback(&self) -> impl std::future::Future> + Send; } impl TransactionObjectTrait for PsqlpyConnection { @@ -41,7 +41,7 @@ impl TransactionObjectTrait for PsqlpyConnection { read_variant: Option, deferrable: Option, synchronous_commit: Option, - ) -> RustPSQLDriverPyResult<()> { + ) -> PSQLPyResult<()> { let mut querystring = "START TRANSACTION".to_string(); if let Some(level) = isolation_level { @@ -84,7 +84,7 @@ impl TransactionObjectTrait for PsqlpyConnection { Ok(()) } - async fn commit(&self) -> RustPSQLDriverPyResult<()> { + async fn commit(&self) -> PSQLPyResult<()> { self.batch_execute("COMMIT;").await.map_err(|err| { RustPSQLDriverError::TransactionCommitError(format!( "Cannot execute COMMIT statement, error - {err}" @@ -92,7 +92,7 @@ impl TransactionObjectTrait for PsqlpyConnection { })?; Ok(()) } - async fn rollback(&self) -> RustPSQLDriverPyResult<()> { + async fn rollback(&self) -> PSQLPyResult<()> { self.batch_execute("ROLLBACK;").await.map_err(|err| { RustPSQLDriverError::TransactionRollbackError(format!( "Cannot execute ROLLBACK statement, error - {err}" @@ -144,7 +144,7 @@ impl Transaction { } } - fn check_is_transaction_ready(&self) -> RustPSQLDriverPyResult<()> { + fn check_is_transaction_ready(&self) -> PSQLPyResult<()> { if !self.is_started { return Err(RustPSQLDriverError::TransactionBeginError( "Transaction is not started, please call begin() on transaction".into(), @@ -242,7 +242,7 @@ impl Transaction { self_ } - async fn __aenter__<'a>(self_: Py) -> RustPSQLDriverPyResult> { + async fn __aenter__<'a>(self_: Py) -> PSQLPyResult> { let ( is_started, is_done, @@ -302,7 +302,7 @@ impl Transaction { _exception_type: Py, exception: Py, _traceback: Py, - ) -> RustPSQLDriverPyResult<()> { + ) -> PSQLPyResult<()> { let (is_transaction_ready, is_exception_none, py_err, db_client) = pyo3::Python::with_gil(|gil| { let self_ = self_.borrow(gil); @@ -345,7 +345,7 @@ impl Transaction { /// 1) Transaction is not started /// 2) Transaction is done /// 3) Cannot execute `COMMIT` command - pub async fn commit(&mut self) -> RustPSQLDriverPyResult<()> { + pub async fn commit(&mut self) -> PSQLPyResult<()> { self.check_is_transaction_ready()?; if let Some(db_client) = &self.db_client { db_client.commit().await?; @@ -366,7 +366,7 @@ impl Transaction { /// 1) Transaction is not started /// 2) Transaction is done /// 3) Can not execute ROLLBACK command - pub async fn rollback(&mut self) -> RustPSQLDriverPyResult<()> { + pub async fn rollback(&mut self) -> PSQLPyResult<()> { self.check_is_transaction_ready()?; if let Some(db_client) = &self.db_client { db_client.rollback().await?; @@ -394,7 +394,7 @@ impl Transaction { querystring: String, parameters: Option>, prepared: Option, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { let (is_transaction_ready, db_client) = pyo3::Python::with_gil(|gil| { let self_ = self_.borrow(gil); (self_.check_is_transaction_ready(), self_.db_client.clone()) @@ -419,7 +419,7 @@ impl Transaction { /// May return Err Result if: /// 1) Transaction is closed. /// 2) Cannot execute querystring. - pub async fn execute_batch(self_: Py, querystring: String) -> RustPSQLDriverPyResult<()> { + pub async fn execute_batch(self_: Py, querystring: String) -> PSQLPyResult<()> { let (is_transaction_ready, db_client) = pyo3::Python::with_gil(|gil| { let self_ = self_.borrow(gil); (self_.check_is_transaction_ready(), self_.db_client.clone()) @@ -448,7 +448,7 @@ impl Transaction { querystring: String, parameters: Option>, prepared: Option, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { let (is_transaction_ready, db_client) = pyo3::Python::with_gil(|gil| { let self_ = self_.borrow(gil); (self_.check_is_transaction_ready(), self_.db_client.clone()) @@ -481,7 +481,7 @@ impl Transaction { querystring: String, parameters: Option>, prepared: Option, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { let (is_transaction_ready, db_client) = pyo3::Python::with_gil(|gil| { let self_ = self_.borrow(gil); (self_.check_is_transaction_ready(), self_.db_client.clone()) @@ -511,7 +511,7 @@ impl Transaction { querystring: String, parameters: Option>, prepared: Option, - ) -> RustPSQLDriverPyResult> { + ) -> PSQLPyResult> { let (is_transaction_ready, db_client) = pyo3::Python::with_gil(|gil| { let self_ = self_.borrow(gil); (self_.check_is_transaction_ready(), self_.db_client.clone()) @@ -539,7 +539,7 @@ impl Transaction { querystring: String, parameters: Option>>, prepared: Option, - ) -> RustPSQLDriverPyResult<()> { + ) -> PSQLPyResult<()> { let (is_transaction_ready, db_client) = pyo3::Python::with_gil(|gil| { let self_ = self_.borrow(gil); (self_.check_is_transaction_ready(), self_.db_client.clone()) @@ -564,7 +564,7 @@ impl Transaction { /// 1) Transaction is already started. /// 2) Transaction is done. /// 3) Cannot execute `BEGIN` command. - pub async fn begin(self_: Py) -> RustPSQLDriverPyResult<()> { + pub async fn begin(self_: Py) -> PSQLPyResult<()> { let ( is_started, is_done, @@ -629,10 +629,7 @@ impl Transaction { /// 2) Transaction is done /// 3) Specified savepoint name is exists /// 4) Can not execute SAVEPOINT command - pub async fn create_savepoint( - self_: Py, - savepoint_name: String, - ) -> RustPSQLDriverPyResult<()> { + pub async fn create_savepoint(self_: Py, savepoint_name: String) -> PSQLPyResult<()> { let (is_transaction_ready, is_savepoint_name_exists, db_client) = pyo3::Python::with_gil(|gil| { let self_ = self_.borrow(gil); @@ -673,10 +670,7 @@ impl Transaction { /// 2) Transaction is done /// 3) Specified savepoint name doesn't exists /// 4) Can not execute RELEASE SAVEPOINT command - pub async fn release_savepoint( - self_: Py, - savepoint_name: String, - ) -> RustPSQLDriverPyResult<()> { + pub async fn release_savepoint(self_: Py, savepoint_name: String) -> PSQLPyResult<()> { let (is_transaction_ready, is_savepoint_name_exists, db_client) = pyo3::Python::with_gil(|gil| { let self_ = self_.borrow(gil); @@ -717,10 +711,7 @@ impl Transaction { /// 2) Transaction is done /// 3) Specified savepoint name doesn't exist /// 4) Can not execute ROLLBACK TO SAVEPOINT command - pub async fn rollback_savepoint( - self_: Py, - savepoint_name: String, - ) -> RustPSQLDriverPyResult<()> { + pub async fn rollback_savepoint(self_: Py, savepoint_name: String) -> PSQLPyResult<()> { let (is_transaction_ready, is_savepoint_name_exists, db_client) = pyo3::Python::with_gil(|gil| { let self_ = self_.borrow(gil); @@ -765,7 +756,7 @@ impl Transaction { self_: Py, queries: Option>, prepared: Option, - ) -> RustPSQLDriverPyResult> { + ) -> PSQLPyResult> { let (is_transaction_ready, db_client) = pyo3::Python::with_gil(|gil| { let self_ = self_.borrow(gil); @@ -827,7 +818,7 @@ impl Transaction { fetch_number: Option, scroll: Option, prepared: Option, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { if let Some(db_client) = &self.db_client { return Ok(Cursor::new( db_client.clone(), @@ -857,7 +848,7 @@ impl Transaction { table_name: String, columns: Option>, schema_name: Option, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); let mut table_name = quote_ident(&table_name); if let Some(schema_name) = schema_name { diff --git a/src/driver/utils.rs b/src/driver/utils.rs index 3d0d59e3..15ca4123 100644 --- a/src/driver/utils.rs +++ b/src/driver/utils.rs @@ -6,7 +6,7 @@ use postgres_openssl::MakeTlsConnector; use pyo3::{types::PyAnyMethods, Py, PyAny, Python}; use tokio_postgres::{Config, NoTls}; -use crate::exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}; +use crate::exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}; use super::common_options::{self, LoadBalanceHosts, SslMode, TargetSessionAttrs}; @@ -40,7 +40,7 @@ pub fn build_connection_config( keepalives_retries: Option, load_balance_hosts: Option, ssl_mode: Option, -) -> RustPSQLDriverPyResult { +) -> PSQLPyResult { if tcp_user_timeout_nanosec.is_some() && tcp_user_timeout_sec.is_none() { return Err(RustPSQLDriverError::ConnectionPoolConfigurationError( "tcp_user_timeout_nanosec must be used with tcp_user_timeout_sec param.".into(), @@ -182,7 +182,7 @@ pub enum ConfiguredTLS { pub fn build_tls( ca_file: &Option, ssl_mode: &Option, -) -> RustPSQLDriverPyResult { +) -> PSQLPyResult { if let Some(ca_file) = ca_file { let mut builder = SslConnector::builder(SslMethod::tls())?; builder.set_ca_file(ca_file)?; @@ -224,7 +224,7 @@ pub fn build_manager( /// May return Err Result if cannot /// 1) import inspect /// 2) extract boolean -pub fn is_coroutine_function(function: Py) -> RustPSQLDriverPyResult { +pub fn is_coroutine_function(function: Py) -> PSQLPyResult { let is_coroutine_function: bool = Python::with_gil(|py| { let inspect = py.import("inspect")?; diff --git a/src/exceptions/rust_errors.rs b/src/exceptions/rust_errors.rs index 48af50cb..b6694da1 100644 --- a/src/exceptions/rust_errors.rs +++ b/src/exceptions/rust_errors.rs @@ -14,7 +14,7 @@ use super::python_errors::{ TransactionRollbackError, TransactionSavepointError, UUIDValueConvertError, }; -pub type RustPSQLDriverPyResult = Result; +pub type PSQLPyResult = Result; #[derive(Error, Debug)] pub enum RustPSQLDriverError { @@ -29,9 +29,9 @@ pub enum RustPSQLDriverError { ConnectionPoolExecuteError(String), // Connection Errors - #[error("Connection error: {0}.")] + #[error("{0}")] BaseConnectionError(String), - #[error("Connection execute error: {0}.")] + #[error("{0}")] ConnectionExecuteError(String), #[error("Underlying connection is returned to the pool")] ConnectionClosedError, @@ -81,7 +81,7 @@ pub enum RustPSQLDriverError { #[error("Python exception: {0}.")] RustPyError(#[from] pyo3::PyErr), - #[error("Database engine exception: {0}.")] + #[error("{0}")] RustDriverError(#[from] deadpool_postgres::tokio_postgres::Error), #[error("Database engine pool exception: {0}")] RustConnectionPoolError(#[from] deadpool_postgres::PoolError), diff --git a/src/extra_types.rs b/src/extra_types.rs index 48058ff4..c3b2d832 100644 --- a/src/extra_types.rs +++ b/src/extra_types.rs @@ -3,6 +3,7 @@ use std::str::FromStr; use geo_types::{Line as RustLineSegment, LineString, Point as RustPoint, Rect as RustRect}; use macaddr::{MacAddr6 as RustMacAddr6, MacAddr8 as RustMacAddr8}; use pyo3::{ + conversion::FromPyObjectBound, pyclass, pymethods, types::{PyModule, PyModuleMethods}, Bound, Py, PyAny, PyResult, Python, @@ -10,13 +11,11 @@ use pyo3::{ use serde_json::Value; use crate::{ - exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, + exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, value_converter::{ additional_types::{Circle as RustCircle, Line as RustLine}, dto::enums::PythonDTO, - funcs::from_python::{ - build_flat_geo_coords, build_geo_coords, py_sequence_into_postgres_array, - }, + from_python::{build_flat_geo_coords, build_geo_coords, py_sequence_into_postgres_array}, models::serde_value::build_serde_value, }, }; @@ -155,7 +154,7 @@ macro_rules! build_json_py_type { impl $st_name { #[new] #[allow(clippy::missing_errors_doc)] - pub fn new_class(value: &Bound<'_, PyAny>) -> RustPSQLDriverPyResult { + pub fn new_class(value: &Bound<'_, PyAny>) -> PSQLPyResult { Ok(Self { inner: build_serde_value(value)?, }) @@ -191,7 +190,7 @@ macro_rules! build_macaddr_type { impl $st_name { #[new] #[allow(clippy::missing_errors_doc)] - pub fn new_class(value: &str) -> RustPSQLDriverPyResult { + pub fn new_class(value: &str) -> PSQLPyResult { Ok(Self { inner: <$rust_type>::from_str(value)?, }) @@ -252,7 +251,7 @@ build_geo_type!(Circle, RustCircle); impl Point { #[new] #[allow(clippy::missing_errors_doc)] - pub fn new_point(value: Py) -> RustPSQLDriverPyResult { + pub fn new_point(value: Py) -> PSQLPyResult { let point_coords = build_geo_coords(value, Some(1))?; Ok(Self { @@ -265,7 +264,7 @@ impl Point { impl Box { #[new] #[allow(clippy::missing_errors_doc)] - pub fn new_box(value: Py) -> RustPSQLDriverPyResult { + pub fn new_box(value: Py) -> PSQLPyResult { let box_coords = build_geo_coords(value, Some(2))?; Ok(Self { @@ -278,7 +277,7 @@ impl Box { impl Path { #[new] #[allow(clippy::missing_errors_doc)] - pub fn new_path(value: Py) -> RustPSQLDriverPyResult { + pub fn new_path(value: Py) -> PSQLPyResult { let path_coords = build_geo_coords(value, None)?; Ok(Self { @@ -291,7 +290,7 @@ impl Path { impl Line { #[new] #[allow(clippy::missing_errors_doc)] - pub fn new_line(value: Py) -> RustPSQLDriverPyResult { + pub fn new_line(value: Py) -> PSQLPyResult { let line_coords = build_flat_geo_coords(value, Some(3))?; Ok(Self { @@ -304,7 +303,7 @@ impl Line { impl LineSegment { #[new] #[allow(clippy::missing_errors_doc)] - pub fn new_line_segment(value: Py) -> RustPSQLDriverPyResult { + pub fn new_line_segment(value: Py) -> PSQLPyResult { let line_segment_coords = build_geo_coords(value, Some(2))?; Ok(Self { @@ -317,7 +316,7 @@ impl LineSegment { impl Circle { #[new] #[allow(clippy::missing_errors_doc)] - pub fn new_circle(value: Py) -> RustPSQLDriverPyResult { + pub fn new_circle(value: Py) -> PSQLPyResult { let circle_coords = build_flat_geo_coords(value, Some(3))?; Ok(Self { inner: RustCircle::new(circle_coords[0], circle_coords[1], circle_coords[2]), @@ -352,7 +351,7 @@ macro_rules! build_array_type { /// /// # Errors /// May return Err Result if cannot convert sequence to array. - pub fn _convert_to_python_dto(&self) -> RustPSQLDriverPyResult { + pub fn _convert_to_python_dto(&self) -> PSQLPyResult { return Python::with_gil(|gil| { let binding = &self.inner; let bound_inner = Ok::<&pyo3::Bound<'_, pyo3::PyAny>, RustPSQLDriverError>( diff --git a/src/lib.rs b/src/lib.rs index e0e1fe11..6be59c75 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,7 @@ pub mod format_helpers; pub mod query_result; pub mod row_factories; pub mod runtime; +pub mod statement; pub mod value_converter; use common::add_module; diff --git a/src/query_result.rs b/src/query_result.rs index da393f89..cda02a8b 100644 --- a/src/query_result.rs +++ b/src/query_result.rs @@ -1,10 +1,7 @@ use pyo3::{prelude::*, pyclass, pymethods, types::PyDict, Py, PyAny, Python, ToPyObject}; use tokio_postgres::Row; -use crate::{ - exceptions::rust_errors::RustPSQLDriverPyResult, - value_converter::funcs::to_python::postgres_to_py, -}; +use crate::{exceptions::rust_errors::PSQLPyResult, value_converter::to_python::postgres_to_py}; /// Convert postgres `Row` into Python Dict. /// @@ -18,7 +15,7 @@ fn row_to_dict<'a>( py: Python<'a>, postgres_row: &'a Row, custom_decoders: &Option>, -) -> RustPSQLDriverPyResult> { +) -> PSQLPyResult> { let python_dict = PyDict::new(py); for (column_idx, column) in postgres_row.columns().iter().enumerate() { let python_type = postgres_to_py(py, postgres_row, column, column_idx, custom_decoders)?; @@ -30,7 +27,7 @@ fn row_to_dict<'a>( #[pyclass(name = "QueryResult")] #[allow(clippy::module_name_repetitions)] pub struct PSQLDriverPyQueryResult { - inner: Vec, + pub inner: Vec, } impl PSQLDriverPyQueryResult { @@ -65,7 +62,7 @@ impl PSQLDriverPyQueryResult { &self, py: Python<'_>, custom_decoders: Option>, - ) -> RustPSQLDriverPyResult> { + ) -> PSQLPyResult> { let mut result: Vec> = vec![]; for row in &self.inner { result.push(row_to_dict(py, row, &custom_decoders)?); @@ -80,11 +77,7 @@ impl PSQLDriverPyQueryResult { /// May return Err Result if can not convert /// postgres type to python or create new Python class. #[allow(clippy::needless_pass_by_value)] - pub fn as_class<'a>( - &'a self, - py: Python<'a>, - as_class: Py, - ) -> RustPSQLDriverPyResult> { + pub fn as_class<'a>(&'a self, py: Python<'a>, as_class: Py) -> PSQLPyResult> { let mut res: Vec> = vec![]; for row in &self.inner { let pydict: pyo3::Bound<'_, PyDict> = row_to_dict(py, row, &None)?; @@ -108,7 +101,7 @@ impl PSQLDriverPyQueryResult { py: Python<'a>, row_factory: Py, custom_decoders: Option>, - ) -> RustPSQLDriverPyResult> { + ) -> PSQLPyResult> { let mut res: Vec> = vec![]; for row in &self.inner { let pydict: pyo3::Bound<'_, PyDict> = row_to_dict(py, row, &custom_decoders)?; @@ -155,7 +148,7 @@ impl PSQLDriverSinglePyQueryResult { &self, py: Python<'_>, custom_decoders: Option>, - ) -> RustPSQLDriverPyResult> { + ) -> PSQLPyResult> { Ok(row_to_dict(py, &self.inner, &custom_decoders)?.to_object(py)) } @@ -167,11 +160,7 @@ impl PSQLDriverSinglePyQueryResult { /// postgres type to python, can not create new Python class /// or there are no results. #[allow(clippy::needless_pass_by_value)] - pub fn as_class<'a>( - &'a self, - py: Python<'a>, - as_class: Py, - ) -> RustPSQLDriverPyResult> { + pub fn as_class<'a>(&'a self, py: Python<'a>, as_class: Py) -> PSQLPyResult> { let pydict: pyo3::Bound<'_, PyDict> = row_to_dict(py, &self.inner, &None)?; Ok(as_class.call(py, (), Some(&pydict))?) } @@ -189,7 +178,7 @@ impl PSQLDriverSinglePyQueryResult { py: Python<'a>, row_factory: Py, custom_decoders: Option>, - ) -> RustPSQLDriverPyResult> { + ) -> PSQLPyResult> { let pydict = row_to_dict(py, &self.inner, &custom_decoders)?.to_object(py); Ok(row_factory.call(py, (pydict,), None)?) } diff --git a/src/row_factories.rs b/src/row_factories.rs index 3a2d2de8..e867df0a 100644 --- a/src/row_factories.rs +++ b/src/row_factories.rs @@ -4,11 +4,11 @@ use pyo3::{ wrap_pyfunction, Bound, Py, PyAny, PyResult, Python, ToPyObject, }; -use crate::exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}; +use crate::exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}; #[pyfunction] #[allow(clippy::needless_pass_by_value)] -fn tuple_row(py: Python<'_>, dict_: Py) -> RustPSQLDriverPyResult> { +fn tuple_row(py: Python<'_>, dict_: Py) -> PSQLPyResult> { let dict_ = dict_.downcast_bound::(py).map_err(|_| { RustPSQLDriverError::RustToPyValueConversionError( "as_tuple accepts only dict as a parameter".into(), @@ -29,7 +29,7 @@ impl class_row { } #[allow(clippy::needless_pass_by_value)] - fn __call__(&self, py: Python<'_>, dict_: Py) -> RustPSQLDriverPyResult> { + fn __call__(&self, py: Python<'_>, dict_: Py) -> PSQLPyResult> { let dict_ = dict_.downcast_bound::(py).map_err(|_| { RustPSQLDriverError::RustToPyValueConversionError( "as_tuple accepts only dict as a parameter".into(), diff --git a/src/runtime.rs b/src/runtime.rs index 05889d99..ee6281de 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -1,7 +1,7 @@ use futures_util::Future; use pyo3::{IntoPyObject, Py, PyAny, Python}; -use crate::exceptions::rust_errors::RustPSQLDriverPyResult; +use crate::exceptions::rust_errors::PSQLPyResult; #[allow(clippy::missing_panics_doc)] #[allow(clippy::module_name_repetitions)] @@ -18,9 +18,9 @@ pub fn tokio_runtime() -> &'static tokio::runtime::Runtime { /// # Errors /// /// May return Err Result if future acts incorrect. -pub fn rustdriver_future(py: Python<'_>, future: F) -> RustPSQLDriverPyResult> +pub fn rustdriver_future(py: Python<'_>, future: F) -> PSQLPyResult> where - F: Future> + Send + 'static, + F: Future> + Send + 'static, T: for<'py> IntoPyObject<'py>, { let res = diff --git a/src/statement/cache.rs b/src/statement/cache.rs new file mode 100644 index 00000000..a6fbc131 --- /dev/null +++ b/src/statement/cache.rs @@ -0,0 +1,50 @@ +use std::collections::HashMap; + +use once_cell::sync::Lazy; +use postgres_types::Type; +use tokio::sync::RwLock; +use tokio_postgres::Statement; + +use super::{query::QueryString, traits::hash_str}; + +#[derive(Default)] +pub(crate) struct StatementsCache(HashMap); + +impl StatementsCache { + pub fn add_cache(&mut self, query: &QueryString, inner_stmt: &Statement) { + self.0 + .insert(query.hash(), StatementCacheInfo::new(query, inner_stmt)); + } + + pub fn get_cache(&self, querystring: &String) -> Option { + let qs_hash = hash_str(&querystring); + + if let Some(cache_info) = self.0.get(&qs_hash) { + return Some(cache_info.clone()); + } + + None + } +} + +#[derive(Clone)] +pub(crate) struct StatementCacheInfo { + pub(crate) query: QueryString, + pub(crate) inner_stmt: Statement, +} + +impl StatementCacheInfo { + fn new(query: &QueryString, inner_stmt: &Statement) -> Self { + return Self { + query: query.clone(), + inner_stmt: inner_stmt.clone(), + }; + } + + pub(crate) fn types(&self) -> Vec { + self.inner_stmt.params().to_vec() + } +} + +pub(crate) static STMTS_CACHE: Lazy> = + Lazy::new(|| RwLock::new(Default::default())); diff --git a/src/statement/mod.rs b/src/statement/mod.rs new file mode 100644 index 00000000..e027eaea --- /dev/null +++ b/src/statement/mod.rs @@ -0,0 +1,7 @@ +pub mod cache; +pub mod parameters; +pub mod query; +pub mod statement; +pub mod statement_builder; +pub mod traits; +pub mod utils; diff --git a/src/statement/parameters.rs b/src/statement/parameters.rs new file mode 100644 index 00000000..baeded5d --- /dev/null +++ b/src/statement/parameters.rs @@ -0,0 +1,255 @@ +use std::iter::zip; + +use postgres_types::{ToSql, Type}; +use pyo3::{ + conversion::FromPyObjectBound, + types::{PyAnyMethods, PyMapping}, + Py, PyObject, PyTypeCheck, Python, +}; + +use crate::{ + exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, + value_converter::{dto::enums::PythonDTO, from_python::py_to_rust}, +}; + +pub type QueryParameter = (dyn ToSql + Sync); + +pub(crate) struct ParametersBuilder { + parameters: Option, + types: Option>, +} + +impl ParametersBuilder { + pub fn new(parameters: &Option, types: Option>) -> Self { + Self { + parameters: parameters.clone(), + types, + } + } + + pub fn prepare( + self, + parameters_names: Option>, + ) -> PSQLPyResult { + let prepared_parameters = + Python::with_gil(|gil| self.prepare_parameters(gil, parameters_names))?; + + Ok(prepared_parameters) + } + + fn prepare_parameters( + self, + gil: Python<'_>, + parameters_names: Option>, + ) -> PSQLPyResult { + if self.parameters.is_none() { + return Ok(PreparedParameters::default()); + } + + let sequence_typed = self.as_type::>(gil); + let mapping_typed = self.downcast_as::(gil); + let mut prepared_parameters: Option = None; + + match (sequence_typed, mapping_typed) { + (Some(sequence), None) => { + prepared_parameters = + Some(SequenceParametersBuilder::new(sequence, self.types).prepare(gil)?); + } + (None, Some(mapping)) => { + if let Some(parameters_names) = parameters_names { + prepared_parameters = Some( + MappingParametersBuilder::new(mapping, self.types) + .prepare(gil, parameters_names)?, + ) + } + } + _ => {} + } + + if let Some(prepared_parameters) = prepared_parameters { + return Ok(prepared_parameters); + } + + return Err(RustPSQLDriverError::PyToRustValueConversionError( + "Parameters must be sequence or mapping".into(), + )); + } + + fn as_type FromPyObjectBound<'a, 'py>>(&self, gil: Python<'_>) -> Option { + if let Some(parameters) = &self.parameters { + let extracted_param = parameters.extract::(gil); + + if let Ok(extracted_param) = extracted_param { + return Some(extracted_param); + } + + return None; + } + + None + } + + fn downcast_as(&self, gil: Python<'_>) -> Option> { + if let Some(parameters) = &self.parameters { + let extracted_param = parameters.downcast_bound::(gil); + + if let Ok(extracted_param) = extracted_param { + return Some(extracted_param.clone().unbind()); + } + + return None; + } + + None + } +} + +pub(crate) struct MappingParametersBuilder { + map_parameters: Py, + types: Option>, +} + +impl MappingParametersBuilder { + fn new(map_parameters: Py, types: Option>) -> Self { + Self { + map_parameters, + types, + } + } + + fn prepare( + self, + gil: Python<'_>, + parameters_names: Vec, + ) -> PSQLPyResult { + if self.types.is_some() { + return self.prepare_typed(gil, parameters_names); + } + + self.prepare_not_typed(gil, parameters_names) + } + + fn prepare_typed( + self, + gil: Python<'_>, + parameters_names: Vec, + ) -> PSQLPyResult { + let converted_parameters = self + .extract_parameters(gil, parameters_names)? + .iter() + .map(|parameter| py_to_rust(parameter.bind(gil))) + .collect::>>()?; + + Ok(PreparedParameters::new(converted_parameters, vec![])) // TODO: change vec![] to real types. + } + + fn prepare_not_typed( + self, + gil: Python<'_>, + parameters_names: Vec, + ) -> PSQLPyResult { + let converted_parameters = self + .extract_parameters(gil, parameters_names)? + .iter() + .map(|parameter| py_to_rust(parameter.bind(gil))) + .collect::>>()?; + + Ok(PreparedParameters::new(converted_parameters, vec![])) // TODO: change vec![] to real types. + } + + fn extract_parameters( + &self, + gil: Python<'_>, + parameters_names: Vec, + ) -> PSQLPyResult> { + let mut params_as_pyobject: Vec = vec![]; + + for param_name in parameters_names { + match self.map_parameters.bind(gil).get_item(¶m_name) { + Ok(param_value) => params_as_pyobject.push(param_value.unbind()), + Err(_) => { + return Err(RustPSQLDriverError::PyToRustValueConversionError( + format!("Cannot find parameter with name <{}>", param_name).into(), + )) + } + } + } + + Ok(params_as_pyobject) + } +} + +pub(crate) struct SequenceParametersBuilder { + seq_parameters: Vec, + types: Option>, +} + +impl SequenceParametersBuilder { + fn new(seq_parameters: Vec, types: Option>) -> Self { + Self { + seq_parameters: seq_parameters, + types, + } + } + + fn prepare(self, gil: Python<'_>) -> PSQLPyResult { + if self.types.is_some() { + return self.prepare_typed(gil); + } + + self.prepare_not_typed(gil) + } + + fn prepare_typed(self, gil: Python<'_>) -> PSQLPyResult { + let converted_parameters = self + .seq_parameters + .iter() + .map(|parameter| py_to_rust(parameter.bind(gil))) + .collect::>>()?; + + Ok(PreparedParameters::new(converted_parameters, vec![])) // TODO: change vec![] to real types. + + // Ok(prepared_parameters) // TODO: put there normal convert with types + } + + fn prepare_not_typed(self, gil: Python<'_>) -> PSQLPyResult { + let converted_parameters = self + .seq_parameters + .iter() + .map(|parameter| py_to_rust(parameter.bind(gil))) + .collect::>>()?; + + Ok(PreparedParameters::new(converted_parameters, vec![])) // TODO: change vec![] to real types. + } +} + +#[derive(Default, Clone, Debug)] +pub struct PreparedParameters { + parameters: Vec, + types: Vec, +} + +impl PreparedParameters { + pub fn new(parameters: Vec, types: Vec) -> Self { + Self { parameters, types } + } + + pub fn params(&self) -> Box<[&(dyn ToSql + Sync)]> { + let params_ref = &self.parameters; + params_ref + .iter() + .map(|param| param as &QueryParameter) + .collect::>() + .into_boxed_slice() + } + + pub fn params_typed(&self) -> Box<[(&(dyn ToSql + Sync), Type)]> { + let params_ref = &self.parameters; + let types = self.types.clone(); + let params_types = zip(params_ref, types); + params_types + .map(|(param, type_)| (param as &QueryParameter, type_)) + .collect::>() + .into_boxed_slice() + } +} diff --git a/src/statement/query.rs b/src/statement/query.rs new file mode 100644 index 00000000..7f87cede --- /dev/null +++ b/src/statement/query.rs @@ -0,0 +1,92 @@ +use std::fmt::Display; + +use regex::Regex; + +use crate::value_converter::consts::KWARGS_PARAMS_REGEXP; + +use super::traits::hash_str; + +#[derive(Clone)] +pub struct QueryString { + pub(crate) initial_qs: String, + // This field are used when kwargs passed + // from python side as parameters. + pub(crate) converted_qs: Option, +} + +impl Display for QueryString { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.query()) + } +} + +impl QueryString { + pub fn new(initial_qs: &String) -> Self { + return Self { + initial_qs: initial_qs.clone(), + converted_qs: None, + }; + } + + pub(crate) fn query(&self) -> &str { + if let Some(converted_qs) = &self.converted_qs { + return converted_qs.query(); + } + + return &self.initial_qs; + } + + pub(crate) fn hash(&self) -> u64 { + hash_str(&self.initial_qs) + } + + pub(crate) fn process_qs(&mut self) { + if !self.is_kwargs_parametrized() { + return (); + } + + let mut counter = 0; + let mut parameters_names = Vec::new(); + + let re = Regex::new(KWARGS_PARAMS_REGEXP).unwrap(); + let result = re.replace_all(&self.initial_qs, |caps: ®ex::Captures| { + let parameter_idx = caps[1].to_string(); + + parameters_names.push(parameter_idx.clone()); + counter += 1; + + format!("${}", &counter) + }); + + self.converted_qs = Some(ConvertedQueryString::new(result.into(), parameters_names)); + } + + fn is_kwargs_parametrized(&self) -> bool { + Regex::new(KWARGS_PARAMS_REGEXP) + .unwrap() + .is_match(&self.initial_qs) + } +} + +#[derive(Clone)] +pub(crate) struct ConvertedQueryString { + converted_qs: String, + params_names: Vec, +} + +impl ConvertedQueryString { + fn new(converted_qs: String, params_names: Vec) -> Self { + Self { + converted_qs, + params_names, + } + } + + fn query(&self) -> &str { + &self.converted_qs + } + + pub(crate) fn params_names(&self) -> &Vec { + &self.params_names + } +} diff --git a/src/statement/statement.rs b/src/statement/statement.rs new file mode 100644 index 00000000..4c3a6e9b --- /dev/null +++ b/src/statement/statement.rs @@ -0,0 +1,30 @@ +use postgres_types::{ToSql, Type}; + +use super::{parameters::PreparedParameters, query::QueryString}; + +#[derive(Clone)] +pub struct PsqlpyStatement { + query: QueryString, + prepared_parameters: PreparedParameters, +} + +impl PsqlpyStatement { + pub fn new(query: QueryString, prepared_parameters: PreparedParameters) -> Self { + Self { + query, + prepared_parameters, + } + } + + pub fn sql_stmt(&self) -> &str { + self.query.query() + } + + pub fn params(&self) -> Box<[&(dyn ToSql + Sync)]> { + self.prepared_parameters.params() + } + + pub fn params_typed(&self) -> Box<[(&(dyn ToSql + Sync), Type)]> { + self.prepared_parameters.params_typed() + } +} diff --git a/src/statement/statement_builder.rs b/src/statement/statement_builder.rs new file mode 100644 index 00000000..07e003da --- /dev/null +++ b/src/statement/statement_builder.rs @@ -0,0 +1,100 @@ +use pyo3::PyObject; +use tokio_postgres::Statement; + +use crate::{driver::inner_connection::PsqlpyConnection, exceptions::rust_errors::PSQLPyResult}; + +use super::{ + cache::{StatementCacheInfo, STMTS_CACHE}, + parameters::ParametersBuilder, + query::QueryString, + statement::PsqlpyStatement, +}; + +pub struct StatementBuilder<'a> { + querystring: String, + parameters: Option, + inner_conn: &'a PsqlpyConnection, + prepared: bool, +} + +impl<'a> StatementBuilder<'a> { + pub fn new( + querystring: String, + parameters: Option, + inner_conn: &'a PsqlpyConnection, + prepared: Option, + ) -> Self { + Self { + querystring, + parameters, + inner_conn, + prepared: prepared.unwrap_or(true), + } + } + + pub async fn build(self) -> PSQLPyResult { + { + let stmt_cache_guard = STMTS_CACHE.read().await; + if let Some(cached) = stmt_cache_guard.get_cache(&self.querystring) { + return self.build_with_cached(cached); + } + } + + self.build_no_cached().await + } + + fn build_with_cached(self, cached: StatementCacheInfo) -> PSQLPyResult { + let raw_parameters = ParametersBuilder::new(&self.parameters, Some(cached.types())); + + let parameters_names = if let Some(converted_qs) = &cached.query.converted_qs { + Some(converted_qs.params_names().clone()) + } else { + None + }; + + let prepared_parameters = raw_parameters.prepare(parameters_names)?; + + return Ok(PsqlpyStatement::new(cached.query, prepared_parameters)); + } + + async fn build_no_cached(self) -> PSQLPyResult { + let mut querystring = QueryString::new(&self.querystring); + querystring.process_qs(); + + let prepared_stmt = self.prepare_query(&querystring).await?; + let parameters_builder = + ParametersBuilder::new(&self.parameters, Some(prepared_stmt.params().to_vec())); + + if !self.prepared { + Self::drop_prepared(self.inner_conn, &prepared_stmt).await?; + } + + let parameters_names = if let Some(converted_qs) = &querystring.converted_qs { + Some(converted_qs.params_names().clone()) + } else { + None + }; + + let prepared_parameters = parameters_builder.prepare(parameters_names)?; + + { + self.write_to_cache(&querystring, &prepared_stmt).await; + } + let statement = PsqlpyStatement::new(querystring, prepared_parameters); + + return Ok(statement); + } + + async fn write_to_cache(&self, query: &QueryString, inner_stmt: &Statement) { + let mut stmt_cache_guard = STMTS_CACHE.write().await; + stmt_cache_guard.add_cache(query, inner_stmt); + } + + async fn prepare_query(&self, query: &QueryString) -> PSQLPyResult { + self.inner_conn.prepare(query.query()).await + } + + async fn drop_prepared(inner_conn: &PsqlpyConnection, stmt: &Statement) -> PSQLPyResult<()> { + inner_conn.drop_prepared(stmt).await + } +} diff --git a/src/statement/traits.rs b/src/statement/traits.rs new file mode 100644 index 00000000..a79f8bdd --- /dev/null +++ b/src/statement/traits.rs @@ -0,0 +1,8 @@ +use std::hash::{DefaultHasher, Hash, Hasher}; + +pub(crate) fn hash_str(string: &String) -> u64 { + let mut hasher = DefaultHasher::new(); + string.hash(&mut hasher); + + hasher.finish() +} diff --git a/src/statement/utils.rs b/src/statement/utils.rs new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/src/statement/utils.rs @@ -0,0 +1 @@ + diff --git a/src/value_converter/consts.rs b/src/value_converter/consts.rs index 40fa932b..e5ff56e4 100644 --- a/src/value_converter/consts.rs +++ b/src/value_converter/consts.rs @@ -8,6 +8,8 @@ use pyo3::{ Bound, Py, PyResult, Python, }; +pub static KWARGS_PARAMS_REGEXP: &str = r"\$\(([^)]+)\)p"; + pub static DECIMAL_CLS: GILOnceCell> = GILOnceCell::new(); pub static TIMEDELTA_CLS: GILOnceCell> = GILOnceCell::new(); pub static KWARGS_QUERYSTRINGS: Lazy)>>> = diff --git a/src/value_converter/dto/converter_impls.rs b/src/value_converter/dto/converter_impls.rs index 97675af3..64948f29 100644 --- a/src/value_converter/dto/converter_impls.rs +++ b/src/value_converter/dto/converter_impls.rs @@ -10,30 +10,28 @@ use rust_decimal::Decimal; use uuid::Uuid; use crate::{ - exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, + exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, extra_types::{self, PythonDecimal, PythonUUID}, value_converter::{ additional_types::NonePyType, - funcs::from_python::{ - extract_datetime_from_python_object_attrs, py_sequence_into_postgres_array, - }, + from_python::{extract_datetime_from_python_object_attrs, py_sequence_into_postgres_array}, models::serde_value::build_serde_value, - traits::PythonToDTO, + traits::ToPythonDTO, }, }; use super::enums::PythonDTO; -impl PythonToDTO for NonePyType { - fn to_python_dto(_python_param: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult { +impl ToPythonDTO for NonePyType { + fn to_python_dto(_python_param: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult { Ok(PythonDTO::PyNone) } } macro_rules! construct_simple_type_matcher { ($match_type:ty, $kind:path) => { - impl PythonToDTO for $match_type { - fn to_python_dto(python_param: &Bound<'_, PyAny>) -> RustPSQLDriverPyResult { + impl ToPythonDTO for $match_type { + fn to_python_dto(python_param: &Bound<'_, PyAny>) -> PSQLPyResult { Ok($kind(python_param.extract::<$match_type>()?)) } } @@ -51,8 +49,8 @@ construct_simple_type_matcher!(i64, PythonDTO::PyIntI64); construct_simple_type_matcher!(NaiveDate, PythonDTO::PyDate); construct_simple_type_matcher!(NaiveTime, PythonDTO::PyTime); -impl PythonToDTO for PyDateTime { - fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult { +impl ToPythonDTO for PyDateTime { + fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult { let timestamp_tz = python_param.extract::>(); if let Ok(pydatetime_tz) = timestamp_tz { return Ok(PythonDTO::PyDateTimeTz(pydatetime_tz)); @@ -74,8 +72,8 @@ impl PythonToDTO for PyDateTime { } } -impl PythonToDTO for PyDelta { - fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult { +impl ToPythonDTO for PyDelta { + fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult { let duration = python_param.extract::()?; if let Some(interval) = Interval::from_duration(duration) { return Ok(PythonDTO::PyInterval(interval)); @@ -86,8 +84,8 @@ impl PythonToDTO for PyDelta { } } -impl PythonToDTO for PyDict { - fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult { +impl ToPythonDTO for PyDict { + fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult { let serde_value = build_serde_value(python_param)?; return Ok(PythonDTO::PyJsonb(serde_value)); @@ -96,14 +94,22 @@ impl PythonToDTO for PyDict { macro_rules! construct_extra_type_matcher { ($match_type:ty, $kind:path) => { - impl PythonToDTO for $match_type { - fn to_python_dto(python_param: &Bound<'_, PyAny>) -> RustPSQLDriverPyResult { + impl ToPythonDTO for $match_type { + fn to_python_dto(python_param: &Bound<'_, PyAny>) -> PSQLPyResult { Ok($kind(python_param.extract::<$match_type>()?.inner())) } } }; } +construct_extra_type_matcher!(extra_types::Text, PythonDTO::PyText); +construct_extra_type_matcher!(extra_types::VarChar, PythonDTO::PyVarChar); +construct_extra_type_matcher!(extra_types::SmallInt, PythonDTO::PyIntI16); +construct_extra_type_matcher!(extra_types::Integer, PythonDTO::PyIntI32); +construct_extra_type_matcher!(extra_types::BigInt, PythonDTO::PyIntI64); +construct_extra_type_matcher!(extra_types::Float32, PythonDTO::PyFloat32); +construct_extra_type_matcher!(extra_types::Float64, PythonDTO::PyFloat64); +construct_extra_type_matcher!(extra_types::Money, PythonDTO::PyMoney); construct_extra_type_matcher!(extra_types::JSONB, PythonDTO::PyJsonb); construct_extra_type_matcher!(extra_types::JSON, PythonDTO::PyJson); construct_extra_type_matcher!(extra_types::MacAddr6, PythonDTO::PyMacAddr6); @@ -114,33 +120,34 @@ construct_extra_type_matcher!(extra_types::Path, PythonDTO::PyPath); construct_extra_type_matcher!(extra_types::Line, PythonDTO::PyLine); construct_extra_type_matcher!(extra_types::LineSegment, PythonDTO::PyLineSegment); construct_extra_type_matcher!(extra_types::Circle, PythonDTO::PyCircle); +construct_extra_type_matcher!(extra_types::PgVector, PythonDTO::PyPgVector); -impl PythonToDTO for PythonDecimal { - fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult { +impl ToPythonDTO for PythonDecimal { + fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult { Ok(PythonDTO::PyDecimal(Decimal::from_str_exact( python_param.str()?.extract::<&str>()?, )?)) } } -impl PythonToDTO for PythonUUID { - fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult { +impl ToPythonDTO for PythonUUID { + fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult { Ok(PythonDTO::PyUUID(Uuid::parse_str( python_param.str()?.extract::<&str>()?, )?)) } } -impl PythonToDTO for extra_types::PythonArray { - fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult { +impl ToPythonDTO for extra_types::PythonArray { + fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult { Ok(PythonDTO::PyArray(py_sequence_into_postgres_array( python_param, )?)) } } -impl PythonToDTO for IpAddr { - fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult { +impl ToPythonDTO for IpAddr { + fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult { if let Ok(id_address) = python_param.extract::() { return Ok(PythonDTO::PyIpAddress(id_address)); } @@ -151,8 +158,8 @@ impl PythonToDTO for IpAddr { } } -impl PythonToDTO for extra_types::PythonEnum { - fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult { +impl ToPythonDTO for extra_types::PythonEnum { + fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult { let string = python_param.extract::()?; return Ok(PythonDTO::PyString(string)); } @@ -160,8 +167,8 @@ impl PythonToDTO for extra_types::PythonEnum { macro_rules! construct_array_type_matcher { ($match_type:ty) => { - impl PythonToDTO for $match_type { - fn to_python_dto(python_param: &Bound<'_, PyAny>) -> RustPSQLDriverPyResult { + impl ToPythonDTO for $match_type { + fn to_python_dto(python_param: &Bound<'_, PyAny>) -> PSQLPyResult { python_param .extract::<$match_type>()? ._convert_to_python_dto() diff --git a/src/value_converter/dto/impls.rs b/src/value_converter/dto/impls.rs index b634d8b8..58debfdc 100644 --- a/src/value_converter/dto/impls.rs +++ b/src/value_converter/dto/impls.rs @@ -12,7 +12,7 @@ use pyo3::{Bound, IntoPyObject, PyAny, Python}; use tokio_postgres::types::{to_sql_checked, Type}; use crate::{ - exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, + exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, value_converter::{ additional_types::{Circle, Line, RustLineSegment, RustLineString, RustPoint, RustRect}, models::serde_value::pythondto_array_to_serde, @@ -53,7 +53,7 @@ impl PythonDTO { /// /// # Errors /// May return Err Result if there is no support for passed python type. - pub fn array_type(&self) -> RustPSQLDriverPyResult { + pub fn array_type(&self) -> PSQLPyResult { match self { PythonDTO::PyBool(_) => Ok(tokio_postgres::types::Type::BOOL_ARRAY), PythonDTO::PyUUID(_) => Ok(tokio_postgres::types::Type::UUID_ARRAY), @@ -96,7 +96,7 @@ impl PythonDTO { /// /// # Errors /// May return Err Result if cannot convert python type into rust. - pub fn to_serde_value(&self) -> RustPSQLDriverPyResult { + pub fn to_serde_value(&self) -> PSQLPyResult { match self { PythonDTO::PyNone => Ok(Value::Null), PythonDTO::PyBool(pybool) => Ok(json!(pybool)), diff --git a/src/value_converter/funcs/from_python.rs b/src/value_converter/from_python.rs similarity index 64% rename from src/value_converter/funcs/from_python.rs rename to src/value_converter/from_python.rs index adad8879..b104c993 100644 --- a/src/value_converter/funcs/from_python.rs +++ b/src/value_converter/from_python.rs @@ -18,7 +18,7 @@ use pyo3::{ }; use crate::{ - exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, + exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, extra_types::{self}, value_converter::{ consts::KWARGS_QUERYSTRINGS, dto::enums::PythonDTO, @@ -26,6 +26,8 @@ use crate::{ }, }; +use super::{additional_types::NonePyType, consts::KWARGS_PARAMS_REGEXP, traits::ToPythonDTO}; + /// Convert single python parameter to `PythonDTO` enum. /// /// # Errors @@ -33,9 +35,9 @@ use crate::{ /// May return Err Result if python type doesn't have support yet /// or value of the type is incorrect. #[allow(clippy::too_many_lines)] -pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult { +pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult { if parameter.is_none() { - return Ok(PythonDTO::PyNone); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { @@ -45,387 +47,251 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult< } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyBool(parameter.extract::()?)); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyBytes(parameter.extract::>()?)); + return as ToPythonDTO>::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyText( - parameter.extract::()?.inner(), - )); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyVarChar( - parameter.extract::()?.inner(), - )); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyString(parameter.extract::()?)); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyFloat64(parameter.extract::()?)); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyFloat32( - parameter.extract::()?.inner(), - )); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyFloat64( - parameter.extract::()?.inner(), - )); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyIntI16( - parameter.extract::()?.inner(), - )); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyIntI32( - parameter.extract::()?.inner(), - )); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyIntI64( - parameter.extract::()?.inner(), - )); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyMoney( - parameter.extract::()?.inner(), - )); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyIntI32(parameter.extract::()?)); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - let timestamp_tz = parameter.extract::>(); - if let Ok(pydatetime_tz) = timestamp_tz { - return Ok(PythonDTO::PyDateTimeTz(pydatetime_tz)); - } - - let timestamp_no_tz = parameter.extract::(); - if let Ok(pydatetime_no_tz) = timestamp_no_tz { - return Ok(PythonDTO::PyDateTime(pydatetime_no_tz)); - } - - let timestamp_tz = extract_datetime_from_python_object_attrs(parameter); - if let Ok(pydatetime_tz) = timestamp_tz { - return Ok(PythonDTO::PyDateTimeTz(pydatetime_tz)); - } - - return Err(RustPSQLDriverError::PyToRustValueConversionError( - "Can not convert you datetime to rust type".into(), - )); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyDate(parameter.extract::()?)); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyTime(parameter.extract::()?)); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - let duration = parameter.extract::()?; - if let Some(interval) = Interval::from_duration(duration) { - return Ok(PythonDTO::PyInterval(interval)); - } - return Err(RustPSQLDriverError::PyToRustValueConversionError( - "Cannot convert timedelta from Python to inner Rust type.".to_string(), - )); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() | parameter.is_instance_of::() { - return Ok(PythonDTO::PyArray(py_sequence_into_postgres_array( - parameter, - )?)); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - let dict = parameter.downcast::().map_err(|error| { - RustPSQLDriverError::PyToRustValueConversionError(format!( - "Can't cast to inner dict: {error}" - )) - })?; - - let mut serde_map: Map = Map::new(); - - for dict_item in dict.items() { - let py_list = dict_item.downcast::().map_err(|error| { - RustPSQLDriverError::PyToRustValueConversionError(format!( - "Cannot cast to list: {error}" - )) - })?; - - let key = py_list.get_item(0)?.extract::()?; - let value = py_to_rust(&py_list.get_item(1)?)?; - - serde_map.insert(key, value.to_serde_value()?); - } - - return Ok(PythonDTO::PyJsonb(Value::Object(serde_map))); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyJsonb( - parameter.extract::()?.inner(), - )); + return ::to_python_dto(parameter); + // return Ok(PythonDTO::PyJsonb( + // parameter.extract::()?.inner(), + // )); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyJson( - parameter.extract::()?.inner(), - )); + return ::to_python_dto(parameter); + // return Ok(PythonDTO::PyJson( + // parameter.extract::()?.inner(), + // )); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyMacAddr6( - parameter.extract::()?.inner(), - )); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyMacAddr8( - parameter.extract::()?.inner(), - )); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyPoint( - parameter.extract::()?.inner(), - )); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyBox( - parameter.extract::()?.inner(), - )); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyPath( - parameter.extract::()?.inner(), - )); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyLine( - parameter.extract::()?.inner(), - )); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyLineSegment( - parameter.extract::()?.inner(), - )); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyCircle( - parameter.extract::()?.inner(), - )); + return ::to_python_dto(parameter); } if parameter.get_type().name()? == "UUID" { - return Ok(PythonDTO::PyUUID(Uuid::parse_str( - parameter.str()?.extract::<&str>()?, - )?)); + return ::to_python_dto(parameter); } if parameter.get_type().name()? == "decimal.Decimal" || parameter.get_type().name()? == "Decimal" { - return Ok(PythonDTO::PyDecimal(Decimal::from_str_exact( - parameter.str()?.extract::<&str>()?, - )?)); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyPgVector( - parameter.extract::()?.inner(), - )); + return ::to_python_dto(parameter); } - if let Ok(id_address) = parameter.extract::() { - return Ok(PythonDTO::PyIpAddress(id_address)); + if let Ok(_) = parameter.extract::() { + return ::to_python_dto(parameter); } // It's used for Enum. @@ -502,9 +368,7 @@ pub fn extract_datetime_from_python_object_attrs( /// May return Err Result if cannot convert at least one element. #[allow(clippy::cast_possible_truncation)] #[allow(clippy::cast_possible_wrap)] -pub fn py_sequence_into_postgres_array( - parameter: &Bound, -) -> RustPSQLDriverPyResult> { +pub fn py_sequence_into_postgres_array(parameter: &Bound) -> PSQLPyResult> { let mut py_seq = parameter .downcast::() .map_err(|_| { @@ -562,9 +426,7 @@ pub fn py_sequence_into_postgres_array( /// /// # Errors /// May return Err Result if cannot convert element into Rust one. -pub fn py_sequence_into_flat_vec( - parameter: &Bound, -) -> RustPSQLDriverPyResult> { +pub fn py_sequence_into_flat_vec(parameter: &Bound) -> PSQLPyResult> { let py_seq = parameter.downcast::().map_err(|_| { RustPSQLDriverError::PyToRustValueConversionError( "PostgreSQL ARRAY type can be made only from python Sequence".into(), @@ -597,85 +459,6 @@ pub fn py_sequence_into_flat_vec( Ok(final_vec) } -/// Convert parameters come from python. -/// -/// Parameters for `execute()` method can be either -/// a list or a tuple or a set. -/// -/// We parse every parameter from python object and return -/// Vector of out `PythonDTO`. -/// -/// # Errors -/// -/// May return Err Result if can't convert python object. -#[allow(clippy::needless_pass_by_value)] -pub fn convert_parameters_and_qs( - querystring: String, - parameters: Option>, -) -> RustPSQLDriverPyResult<(String, Vec)> { - let Some(parameters) = parameters else { - return Ok((querystring, vec![])); - }; - - let res = Python::with_gil(|gil| { - let params = parameters.extract::>>(gil).map_err(|_| { - RustPSQLDriverError::PyToRustValueConversionError( - "Cannot convert you parameters argument into Rust type, please use List/Tuple" - .into(), - ) - }); - if let Ok(params) = params { - return Ok((querystring, convert_seq_parameters(params)?)); - } - - let kw_params = parameters.downcast_bound::(gil); - if let Ok(kw_params) = kw_params { - return convert_kwargs_parameters(kw_params, &querystring); - } - - Err(RustPSQLDriverError::PyToRustValueConversionError( - "Parameters must be sequence or mapping".into(), - )) - })?; - - Ok(res) -} - -pub fn convert_kwargs_parameters<'a>( - kw_params: &Bound<'_, PyMapping>, - querystring: &'a str, -) -> RustPSQLDriverPyResult<(String, Vec)> { - let mut result_vec: Vec = vec![]; - let (changed_string, params_names) = parse_kwargs_qs(querystring); - - for param_name in params_names { - match kw_params.get_item(¶m_name) { - Ok(param) => result_vec.push(py_to_rust(¶m)?), - Err(_) => { - return Err(RustPSQLDriverError::PyToRustValueConversionError( - format!("Cannot find parameter with name <{param_name}> in parameters").into(), - )) - } - } - } - - Ok((changed_string, result_vec)) -} - -pub fn convert_seq_parameters( - seq_params: Vec>, -) -> RustPSQLDriverPyResult> { - let mut result_vec: Vec = vec![]; - Python::with_gil(|gil| { - for parameter in seq_params { - result_vec.push(py_to_rust(parameter.bind(gil))?); - } - Ok::<(), RustPSQLDriverError>(()) - })?; - - Ok(result_vec) -} - /// Convert two python parameters(x and y) to Coord from `geo_type`. /// Also it checks that passed values is int or float. /// @@ -683,7 +466,7 @@ pub fn convert_seq_parameters( /// /// May return error if cannot convert Python type into Rust one. /// May return error if parameters type isn't correct. -fn convert_py_to_rust_coord_values(parameters: Vec>) -> RustPSQLDriverPyResult> { +fn convert_py_to_rust_coord_values(parameters: Vec>) -> PSQLPyResult> { Python::with_gil(|gil| { let mut coord_values_vec: Vec = vec![]; @@ -737,7 +520,7 @@ fn convert_py_to_rust_coord_values(parameters: Vec>) -> RustPSQLDriver pub fn build_geo_coords( py_parameters: Py, allowed_length_option: Option, -) -> RustPSQLDriverPyResult> { +) -> PSQLPyResult> { let mut result_vec: Vec = vec![]; result_vec = Python::with_gil(|gil| { @@ -811,7 +594,7 @@ pub fn build_geo_coords( pub fn build_flat_geo_coords( py_parameters: Py, allowed_length_option: Option, -) -> RustPSQLDriverPyResult> { +) -> PSQLPyResult> { Python::with_gil(|gil| { let allowed_length = allowed_length_option.unwrap_or_default(); @@ -845,7 +628,7 @@ pub fn build_flat_geo_coords( /// /// May return error if cannot convert Python type into Rust one. /// May return error if parameters type isn't correct. -fn py_sequence_to_rust(bind_parameters: &Bound) -> RustPSQLDriverPyResult>> { +fn py_sequence_to_rust(bind_parameters: &Bound) -> PSQLPyResult>> { let mut coord_values_sequence_vec: Vec> = vec![]; if bind_parameters.is_instance_of::() { @@ -877,7 +660,7 @@ fn py_sequence_to_rust(bind_parameters: &Bound) -> RustPSQLDriverPyResult } fn parse_kwargs_qs(querystring: &str) -> (String, Vec) { - let re = regex::Regex::new(r"\$\(([^)]+)\)p").unwrap(); + let re = regex::Regex::new(KWARGS_PARAMS_REGEXP).unwrap(); { let kq_read = KWARGS_QUERYSTRINGS.read().unwrap(); diff --git a/src/value_converter/funcs/mod.rs b/src/value_converter/funcs/mod.rs deleted file mode 100644 index 4db4cd38..00000000 --- a/src/value_converter/funcs/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod from_python; -pub mod to_python; diff --git a/src/value_converter/mod.rs b/src/value_converter/mod.rs index 7d08bf3f..41c42284 100644 --- a/src/value_converter/mod.rs +++ b/src/value_converter/mod.rs @@ -1,7 +1,8 @@ pub mod additional_types; pub mod consts; pub mod dto; -pub mod funcs; +pub mod from_python; pub mod models; +pub mod to_python; pub mod traits; pub mod utils; diff --git a/src/value_converter/models/serde_value.rs b/src/value_converter/models/serde_value.rs index 0bf6652f..71239c2b 100644 --- a/src/value_converter/models/serde_value.rs +++ b/src/value_converter/models/serde_value.rs @@ -1,7 +1,6 @@ use postgres_array::{Array, Dimension}; use postgres_types::FromSql; use serde_json::{json, Map, Value}; -use uuid::Uuid; use pyo3::{ types::{PyAnyMethods, PyDict, PyDictMethods, PyList, PyTuple}, @@ -10,10 +9,9 @@ use pyo3::{ use tokio_postgres::types::Type; use crate::{ - exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, + exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, value_converter::{ - dto::enums::PythonDTO, - funcs::{from_python::py_to_rust, to_python::build_python_from_serde_value}, + dto::enums::PythonDTO, from_python::py_to_rust, to_python::build_python_from_serde_value, }, }; @@ -54,10 +52,7 @@ impl<'a> FromSql<'a> for InternalSerdeValue { } } -fn serde_value_from_list( - gil: Python<'_>, - bind_value: &Bound<'_, PyAny>, -) -> RustPSQLDriverPyResult { +fn serde_value_from_list(gil: Python<'_>, bind_value: &Bound<'_, PyAny>) -> PSQLPyResult { let mut result_vec: Vec = vec![]; let params = bind_value.extract::>>()?; @@ -79,7 +74,7 @@ fn serde_value_from_list( Ok(json!(result_vec)) } -fn serde_value_from_dict(bind_value: &Bound<'_, PyAny>) -> RustPSQLDriverPyResult { +fn serde_value_from_dict(bind_value: &Bound<'_, PyAny>) -> PSQLPyResult { let dict = bind_value.downcast::().map_err(|error| { RustPSQLDriverError::PyToRustValueConversionError(format!( "Can't cast to inner dict: {error}" @@ -109,7 +104,7 @@ fn serde_value_from_dict(bind_value: &Bound<'_, PyAny>) -> RustPSQLDriverPyResul /// # Errors /// May return error if cannot convert Python type into Rust one. #[allow(clippy::needless_pass_by_value)] -pub fn build_serde_value(value: &Bound<'_, PyAny>) -> RustPSQLDriverPyResult { +pub fn build_serde_value(value: &Bound<'_, PyAny>) -> PSQLPyResult { Python::with_gil(|gil| { if value.is_instance_of::() { return serde_value_from_list(gil, value); @@ -126,7 +121,7 @@ pub fn build_serde_value(value: &Bound<'_, PyAny>) -> RustPSQLDriverPyResult>) -> RustPSQLDriverPyResult { +pub fn pythondto_array_to_serde(array: Option>) -> PSQLPyResult { match array { Some(array) => inner_pythondto_array_to_serde( array.dimensions(), @@ -145,7 +140,7 @@ fn inner_pythondto_array_to_serde( data: &[&PythonDTO], dimension_index: usize, mut lower_bound: usize, -) -> RustPSQLDriverPyResult { +) -> PSQLPyResult { let current_dimension = dimensions.get(dimension_index); if let Some(current_dimension) = current_dimension { diff --git a/src/value_converter/params_converters.rs b/src/value_converter/params_converters.rs new file mode 100644 index 00000000..e69de29b diff --git a/src/value_converter/funcs/to_python.rs b/src/value_converter/to_python.rs similarity index 97% rename from src/value_converter/funcs/to_python.rs rename to src/value_converter/to_python.rs index e65a0085..5dbfd7ce 100644 --- a/src/value_converter/funcs/to_python.rs +++ b/src/value_converter/to_python.rs @@ -17,7 +17,7 @@ use pyo3::{ }; use crate::{ - exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, + exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, value_converter::{ additional_types::{ Circle, Line, RustLineSegment, RustLineString, RustMacAddr6, RustMacAddr8, RustPoint, @@ -35,10 +35,7 @@ use pgvector::Vector as PgVector; /// Convert serde `Value` into Python object. /// # Errors /// May return Err Result if cannot add new value to Python Dict. -pub fn build_python_from_serde_value( - py: Python<'_>, - value: Value, -) -> RustPSQLDriverPyResult> { +pub fn build_python_from_serde_value(py: Python<'_>, value: Value) -> PSQLPyResult> { match value { Value::Array(massive) => { let mut result_vec: Vec> = vec![]; @@ -112,7 +109,7 @@ fn composite_field_postgres_to_py<'a, T: FromSql<'a>>( type_: &Type, buf: &mut &'a [u8], is_simple: bool, -) -> RustPSQLDriverPyResult { +) -> PSQLPyResult { if is_simple { return T::from_sql_nullable(type_, Some(buf)).map_err(|err| { RustPSQLDriverError::RustToPyValueConversionError(format!( @@ -196,7 +193,7 @@ fn postgres_bytes_to_py( type_: &Type, buf: &mut &[u8], is_simple: bool, -) -> RustPSQLDriverPyResult> { +) -> PSQLPyResult> { match *type_ { // ---------- Bytes Types ---------- // Convert BYTEA type into Vector, then into PyBytes @@ -524,7 +521,7 @@ pub fn other_postgres_bytes_to_py( type_: &Type, buf: &mut &[u8], is_simple: bool, -) -> RustPSQLDriverPyResult> { +) -> PSQLPyResult> { if type_.name() == "vector" { let vector = composite_field_postgres_to_py::>(type_, buf, is_simple)?; match vector { @@ -550,7 +547,7 @@ pub fn composite_postgres_to_py( fields: &Vec, buf: &mut &[u8], custom_decoders: &Option>, -) -> RustPSQLDriverPyResult> { +) -> PSQLPyResult> { let result_py_dict: Bound<'_, PyDict> = PyDict::new_bound(py); let num_fields = postgres_types::private::read_be_i32(buf).map_err(|err| { @@ -619,7 +616,7 @@ pub fn raw_bytes_data_process( column_name: &str, column_type: &Type, custom_decoders: &Option>, -) -> RustPSQLDriverPyResult> { +) -> PSQLPyResult> { if let Some(custom_decoders) = custom_decoders { let py_encoder_func = custom_decoders .bind(py) @@ -658,7 +655,7 @@ pub fn postgres_to_py( column: &Column, column_i: usize, custom_decoders: &Option>, -) -> RustPSQLDriverPyResult> { +) -> PSQLPyResult> { let raw_bytes_data = row.col_buffer(column_i); if let Some(mut raw_bytes_data) = raw_bytes_data { return raw_bytes_data_process( @@ -679,7 +676,7 @@ pub fn postgres_to_py( /// /// May return error if cannot convert Python type into Rust one. /// May return error if parameters type isn't correct. -fn py_sequence_to_rust(bind_parameters: &Bound) -> RustPSQLDriverPyResult>> { +fn py_sequence_to_rust(bind_parameters: &Bound) -> PSQLPyResult>> { let mut coord_values_sequence_vec: Vec> = vec![]; if bind_parameters.is_instance_of::() { diff --git a/src/value_converter/traits.rs b/src/value_converter/traits.rs index ca44a7d0..261ee16d 100644 --- a/src/value_converter/traits.rs +++ b/src/value_converter/traits.rs @@ -1,9 +1,9 @@ use pyo3::PyAny; -use crate::exceptions::rust_errors::RustPSQLDriverPyResult; +use crate::exceptions::rust_errors::PSQLPyResult; use super::dto::enums::PythonDTO; -pub trait PythonToDTO { - fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult; +pub trait ToPythonDTO { + fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult; } From af51f438c5e9e0f7685bd37ab38453153b137da8 Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Sat, 3 May 2025 18:12:59 +0200 Subject: [PATCH 03/15] Full value converter refactor --- python/tests/test_value_converter.py | 42 +++- src/exceptions/rust_errors.rs | 2 +- src/extra_types.rs | 69 +++--- src/statement/parameters.rs | 30 +-- src/statement/statement.rs | 2 +- src/value_converter/consts.rs | 3 - src/value_converter/dto/converter_impls.rs | 153 ++++++------ src/value_converter/dto/enums.rs | 7 +- src/value_converter/dto/funcs.rs | 33 +++ src/value_converter/dto/impls.rs | 60 ++--- src/value_converter/dto/mod.rs | 1 + src/value_converter/from_python.rs | 262 ++++++++++++++------- src/value_converter/models/serde_value.rs | 7 +- src/value_converter/params_converters.rs | 0 src/value_converter/to_python.rs | 70 ------ src/value_converter/traits.rs | 8 + 16 files changed, 435 insertions(+), 314 deletions(-) create mode 100644 src/value_converter/dto/funcs.rs delete mode 100644 src/value_converter/params_converters.rs diff --git a/python/tests/test_value_converter.py b/python/tests/test_value_converter.py index 34361b22..c35baec1 100644 --- a/python/tests/test_value_converter.py +++ b/python/tests/test_value_converter.py @@ -141,7 +141,6 @@ async def test_as_class( ("INT2", SmallInt(12), 12), ("INT4", Integer(121231231), 121231231), ("INT8", BigInt(99999999999999999), 99999999999999999), - ("MONEY", BigInt(99999999999999999), 99999999999999999), ("MONEY", Money(99999999999999999), 99999999999999999), ("NUMERIC(5, 2)", Decimal("120.12"), Decimal("120.12")), ("FLOAT8", 32.12329864501953, 32.12329864501953), @@ -270,11 +269,6 @@ async def test_as_class( [Money(99999999999999999), Money(99999999999999999)], [99999999999999999, 99999999999999999], ), - ( - "MONEY ARRAY", - [[Money(99999999999999999)], [Money(99999999999999999)]], - [[99999999999999999], [99999999999999999]], - ), ( "NUMERIC(5, 2) ARRAY", [Decimal("121.23"), Decimal("188.99")], @@ -666,6 +660,37 @@ async def test_deserialization_simple_into_python( postgres_type: str, py_value: Any, expected_deserialized: Any, +) -> None: + """Test how types can cast from Python and to Python.""" + connection = await psql_pool.connection() + table_name = f"for_test{uuid.uuid4().hex}" + await connection.execute(f"DROP TABLE IF EXISTS {table_name}") + create_table_query = f""" + CREATE TABLE {table_name} (test_field {postgres_type}) + """ + insert_data_query = f""" + INSERT INTO {table_name} VALUES ($1) + """ + await connection.execute(querystring=create_table_query) + await connection.execute( + querystring=insert_data_query, + parameters=[py_value], + ) + + raw_result = await connection.execute( + querystring=f"SELECT test_field FROM {table_name}", + ) + + assert raw_result.result()[0]["test_field"] == expected_deserialized + + await connection.execute(f"DROP TABLE IF EXISTS {table_name}") + + +async def test_aboba( + psql_pool: ConnectionPool, + postgres_type: str = "INT2", + py_value: Any = 2, + expected_deserialized: Any = 2, ) -> None: """Test how types can cast from Python and to Python.""" connection = await psql_pool.connection() @@ -1175,11 +1200,6 @@ async def test_empty_array( MoneyArray([Money(99999999999999999), Money(99999999999999999)]), [99999999999999999, 99999999999999999], ), - ( - "MONEY ARRAY", - MoneyArray([[Money(99999999999999999)], [Money(99999999999999999)]]), - [[99999999999999999], [99999999999999999]], - ), ( "NUMERIC(5, 2) ARRAY", NumericArray([Decimal("121.23"), Decimal("188.99")]), diff --git a/src/exceptions/rust_errors.rs b/src/exceptions/rust_errors.rs index b6694da1..94b89fa0 100644 --- a/src/exceptions/rust_errors.rs +++ b/src/exceptions/rust_errors.rs @@ -76,7 +76,7 @@ pub enum RustPSQLDriverError { #[error("Can't convert value from driver to python type: {0}")] RustToPyValueConversionError(String), - #[error("Can't convert value from python to rust type: {0}")] + #[error("{0}")] PyToRustValueConversionError(String), #[error("Python exception: {0}.")] diff --git a/src/extra_types.rs b/src/extra_types.rs index c3b2d832..b3411eae 100644 --- a/src/extra_types.rs +++ b/src/extra_types.rs @@ -2,8 +2,8 @@ use std::str::FromStr; use geo_types::{Line as RustLineSegment, LineString, Point as RustPoint, Rect as RustRect}; use macaddr::{MacAddr6 as RustMacAddr6, MacAddr8 as RustMacAddr8}; +use postgres_types::Type; use pyo3::{ - conversion::FromPyObjectBound, pyclass, pymethods, types::{PyModule, PyModuleMethods}, Bound, Py, PyAny, PyResult, Python, @@ -325,7 +325,7 @@ impl Circle { } macro_rules! build_array_type { - ($st_name:ident, $kind:path) => { + ($st_name:ident, $kind:path, $elem_kind:path) => { #[pyclass] #[derive(Clone)] pub struct $st_name { @@ -347,11 +347,15 @@ macro_rules! build_array_type { self.inner.clone() } + pub fn element_type() -> Type { + $elem_kind + } + /// Convert incoming sequence from python to internal `PythonDTO`. /// /// # Errors /// May return Err Result if cannot convert sequence to array. - pub fn _convert_to_python_dto(&self) -> PSQLPyResult { + pub fn _convert_to_python_dto(&self, elem_type: &Type) -> PSQLPyResult { return Python::with_gil(|gil| { let binding = &self.inner; let bound_inner = Ok::<&pyo3::Bound<'_, pyo3::PyAny>, RustPSQLDriverError>( @@ -359,6 +363,7 @@ macro_rules! build_array_type { )?; Ok::($kind(py_sequence_into_postgres_array( bound_inner, + elem_type, )?)) }); } @@ -366,33 +371,37 @@ macro_rules! build_array_type { }; } -build_array_type!(BoolArray, PythonDTO::PyBoolArray); -build_array_type!(UUIDArray, PythonDTO::PyUuidArray); -build_array_type!(VarCharArray, PythonDTO::PyVarCharArray); -build_array_type!(TextArray, PythonDTO::PyTextArray); -build_array_type!(Int16Array, PythonDTO::PyInt16Array); -build_array_type!(Int32Array, PythonDTO::PyInt32Array); -build_array_type!(Int64Array, PythonDTO::PyInt64Array); -build_array_type!(Float32Array, PythonDTO::PyFloat32Array); -build_array_type!(Float64Array, PythonDTO::PyFloat64Array); -build_array_type!(MoneyArray, PythonDTO::PyMoneyArray); -build_array_type!(IpAddressArray, PythonDTO::PyIpAddressArray); -build_array_type!(JSONBArray, PythonDTO::PyJSONBArray); -build_array_type!(JSONArray, PythonDTO::PyJSONArray); -build_array_type!(DateArray, PythonDTO::PyDateArray); -build_array_type!(TimeArray, PythonDTO::PyTimeArray); -build_array_type!(DateTimeArray, PythonDTO::PyDateTimeArray); -build_array_type!(DateTimeTZArray, PythonDTO::PyDateTimeTZArray); -build_array_type!(MacAddr6Array, PythonDTO::PyMacAddr6Array); -build_array_type!(MacAddr8Array, PythonDTO::PyMacAddr8Array); -build_array_type!(NumericArray, PythonDTO::PyNumericArray); -build_array_type!(PointArray, PythonDTO::PyPointArray); -build_array_type!(BoxArray, PythonDTO::PyBoxArray); -build_array_type!(PathArray, PythonDTO::PyPathArray); -build_array_type!(LineArray, PythonDTO::PyLineArray); -build_array_type!(LsegArray, PythonDTO::PyLsegArray); -build_array_type!(CircleArray, PythonDTO::PyCircleArray); -build_array_type!(IntervalArray, PythonDTO::PyIntervalArray); +build_array_type!(BoolArray, PythonDTO::PyBoolArray, Type::BOOL); +build_array_type!(UUIDArray, PythonDTO::PyUuidArray, Type::UUID); +build_array_type!(VarCharArray, PythonDTO::PyVarCharArray, Type::VARCHAR); +build_array_type!(TextArray, PythonDTO::PyTextArray, Type::TEXT); +build_array_type!(Int16Array, PythonDTO::PyInt16Array, Type::INT2); +build_array_type!(Int32Array, PythonDTO::PyInt32Array, Type::INT4); +build_array_type!(Int64Array, PythonDTO::PyInt64Array, Type::INT8); +build_array_type!(Float32Array, PythonDTO::PyFloat32Array, Type::FLOAT4); +build_array_type!(Float64Array, PythonDTO::PyFloat64Array, Type::FLOAT8); +build_array_type!(MoneyArray, PythonDTO::PyMoneyArray, Type::MONEY); +build_array_type!(IpAddressArray, PythonDTO::PyIpAddressArray, Type::INET); +build_array_type!(JSONBArray, PythonDTO::PyJSONBArray, Type::JSONB); +build_array_type!(JSONArray, PythonDTO::PyJSONArray, Type::JSON); +build_array_type!(DateArray, PythonDTO::PyDateArray, Type::DATE); +build_array_type!(TimeArray, PythonDTO::PyTimeArray, Type::TIME); +build_array_type!(DateTimeArray, PythonDTO::PyDateTimeArray, Type::TIMESTAMP); +build_array_type!( + DateTimeTZArray, + PythonDTO::PyDateTimeTZArray, + Type::TIMESTAMPTZ +); +build_array_type!(MacAddr6Array, PythonDTO::PyMacAddr6Array, Type::MACADDR); +build_array_type!(MacAddr8Array, PythonDTO::PyMacAddr8Array, Type::MACADDR8); +build_array_type!(NumericArray, PythonDTO::PyNumericArray, Type::NUMERIC); +build_array_type!(PointArray, PythonDTO::PyPointArray, Type::POINT); +build_array_type!(BoxArray, PythonDTO::PyBoxArray, Type::BOX); +build_array_type!(PathArray, PythonDTO::PyPathArray, Type::PATH); +build_array_type!(LineArray, PythonDTO::PyLineArray, Type::LINE); +build_array_type!(LsegArray, PythonDTO::PyLsegArray, Type::LSEG); +build_array_type!(CircleArray, PythonDTO::PyCircleArray, Type::CIRCLE); +build_array_type!(IntervalArray, PythonDTO::PyIntervalArray, Type::INTERVAL); #[allow(clippy::module_name_repetitions)] #[allow(clippy::missing_errors_doc)] diff --git a/src/statement/parameters.rs b/src/statement/parameters.rs index baeded5d..0a2d9105 100644 --- a/src/statement/parameters.rs +++ b/src/statement/parameters.rs @@ -9,7 +9,10 @@ use pyo3::{ use crate::{ exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, - value_converter::{dto::enums::PythonDTO, from_python::py_to_rust}, + value_converter::{ + dto::enums::PythonDTO, + from_python::{from_python_typed, from_python_untyped}, + }, }; pub type QueryParameter = (dyn ToSql + Sync); @@ -137,7 +140,7 @@ impl MappingParametersBuilder { let converted_parameters = self .extract_parameters(gil, parameters_names)? .iter() - .map(|parameter| py_to_rust(parameter.bind(gil))) + .map(|parameter| from_python_untyped(parameter.bind(gil))) .collect::>>()?; Ok(PreparedParameters::new(converted_parameters, vec![])) // TODO: change vec![] to real types. @@ -151,7 +154,7 @@ impl MappingParametersBuilder { let converted_parameters = self .extract_parameters(gil, parameters_names)? .iter() - .map(|parameter| py_to_rust(parameter.bind(gil))) + .map(|parameter| from_python_untyped(parameter.bind(gil))) .collect::>>()?; Ok(PreparedParameters::new(converted_parameters, vec![])) // TODO: change vec![] to real types. @@ -193,30 +196,29 @@ impl SequenceParametersBuilder { } fn prepare(self, gil: Python<'_>) -> PSQLPyResult { - if self.types.is_some() { - return self.prepare_typed(gil); + let types = self.types.clone(); + + if types.is_some() { + return self.prepare_typed(gil, types.clone().unwrap()); } self.prepare_not_typed(gil) } - fn prepare_typed(self, gil: Python<'_>) -> PSQLPyResult { - let converted_parameters = self - .seq_parameters - .iter() - .map(|parameter| py_to_rust(parameter.bind(gil))) + fn prepare_typed(self, gil: Python<'_>, types: Vec) -> PSQLPyResult { + let zipped_params_types = zip(self.seq_parameters, &types); + let converted_parameters = zipped_params_types + .map(|(parameter, type_)| from_python_typed(parameter.bind(gil), &type_)) .collect::>>()?; - Ok(PreparedParameters::new(converted_parameters, vec![])) // TODO: change vec![] to real types. - - // Ok(prepared_parameters) // TODO: put there normal convert with types + Ok(PreparedParameters::new(converted_parameters, types)) } fn prepare_not_typed(self, gil: Python<'_>) -> PSQLPyResult { let converted_parameters = self .seq_parameters .iter() - .map(|parameter| py_to_rust(parameter.bind(gil))) + .map(|parameter| from_python_untyped(parameter.bind(gil))) .collect::>>()?; Ok(PreparedParameters::new(converted_parameters, vec![])) // TODO: change vec![] to real types. diff --git a/src/statement/statement.rs b/src/statement/statement.rs index 4c3a6e9b..4cfdc09c 100644 --- a/src/statement/statement.rs +++ b/src/statement/statement.rs @@ -9,7 +9,7 @@ pub struct PsqlpyStatement { } impl PsqlpyStatement { - pub fn new(query: QueryString, prepared_parameters: PreparedParameters) -> Self { + pub(crate) fn new(query: QueryString, prepared_parameters: PreparedParameters) -> Self { Self { query, prepared_parameters, diff --git a/src/value_converter/consts.rs b/src/value_converter/consts.rs index e5ff56e4..82a34f0f 100644 --- a/src/value_converter/consts.rs +++ b/src/value_converter/consts.rs @@ -1,5 +1,4 @@ use once_cell::sync::Lazy; -use postgres_types::ToSql; use std::{collections::HashMap, sync::RwLock}; use pyo3::{ @@ -35,5 +34,3 @@ pub fn get_timedelta_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> { }) .map(|ty| ty.bind(py)) } - -pub type QueryParameter = (dyn ToSql + Sync); diff --git a/src/value_converter/dto/converter_impls.rs b/src/value_converter/dto/converter_impls.rs index 64948f29..1e6fa7be 100644 --- a/src/value_converter/dto/converter_impls.rs +++ b/src/value_converter/dto/converter_impls.rs @@ -2,11 +2,13 @@ use std::net::IpAddr; use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime}; use pg_interval::Interval; +use postgres_types::Type; use pyo3::{ types::{PyAnyMethods, PyDateTime, PyDelta, PyDict}, Bound, PyAny, }; use rust_decimal::Decimal; +use serde::de::IntoDeserializer; use uuid::Uuid; use crate::{ @@ -16,11 +18,11 @@ use crate::{ additional_types::NonePyType, from_python::{extract_datetime_from_python_object_attrs, py_sequence_into_postgres_array}, models::serde_value::build_serde_value, - traits::ToPythonDTO, + traits::{ToPythonDTO, ToPythonDTOArray}, }, }; -use super::enums::PythonDTO; +use super::{enums::PythonDTO, funcs::array_type_to_single_type}; impl ToPythonDTO for NonePyType { fn to_python_dto(_python_param: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult { @@ -28,7 +30,7 @@ impl ToPythonDTO for NonePyType { } } -macro_rules! construct_simple_type_matcher { +macro_rules! construct_simple_type_converter { ($match_type:ty, $kind:path) => { impl ToPythonDTO for $match_type { fn to_python_dto(python_param: &Bound<'_, PyAny>) -> PSQLPyResult { @@ -38,16 +40,16 @@ macro_rules! construct_simple_type_matcher { }; } -construct_simple_type_matcher!(bool, PythonDTO::PyBool); -construct_simple_type_matcher!(Vec, PythonDTO::PyBytes); -construct_simple_type_matcher!(String, PythonDTO::PyString); -construct_simple_type_matcher!(f32, PythonDTO::PyFloat32); -construct_simple_type_matcher!(f64, PythonDTO::PyFloat64); -construct_simple_type_matcher!(i16, PythonDTO::PyIntI16); -construct_simple_type_matcher!(i32, PythonDTO::PyIntI32); -construct_simple_type_matcher!(i64, PythonDTO::PyIntI64); -construct_simple_type_matcher!(NaiveDate, PythonDTO::PyDate); -construct_simple_type_matcher!(NaiveTime, PythonDTO::PyTime); +construct_simple_type_converter!(bool, PythonDTO::PyBool); +construct_simple_type_converter!(Vec, PythonDTO::PyBytes); +construct_simple_type_converter!(String, PythonDTO::PyString); +construct_simple_type_converter!(f32, PythonDTO::PyFloat32); +construct_simple_type_converter!(f64, PythonDTO::PyFloat64); +construct_simple_type_converter!(i16, PythonDTO::PyIntI16); +construct_simple_type_converter!(i32, PythonDTO::PyIntI32); +construct_simple_type_converter!(i64, PythonDTO::PyIntI64); +construct_simple_type_converter!(NaiveDate, PythonDTO::PyDate); +construct_simple_type_converter!(NaiveTime, PythonDTO::PyTime); impl ToPythonDTO for PyDateTime { fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult { @@ -92,7 +94,7 @@ impl ToPythonDTO for PyDict { } } -macro_rules! construct_extra_type_matcher { +macro_rules! construct_extra_type_converter { ($match_type:ty, $kind:path) => { impl ToPythonDTO for $match_type { fn to_python_dto(python_param: &Bound<'_, PyAny>) -> PSQLPyResult { @@ -102,25 +104,26 @@ macro_rules! construct_extra_type_matcher { }; } -construct_extra_type_matcher!(extra_types::Text, PythonDTO::PyText); -construct_extra_type_matcher!(extra_types::VarChar, PythonDTO::PyVarChar); -construct_extra_type_matcher!(extra_types::SmallInt, PythonDTO::PyIntI16); -construct_extra_type_matcher!(extra_types::Integer, PythonDTO::PyIntI32); -construct_extra_type_matcher!(extra_types::BigInt, PythonDTO::PyIntI64); -construct_extra_type_matcher!(extra_types::Float32, PythonDTO::PyFloat32); -construct_extra_type_matcher!(extra_types::Float64, PythonDTO::PyFloat64); -construct_extra_type_matcher!(extra_types::Money, PythonDTO::PyMoney); -construct_extra_type_matcher!(extra_types::JSONB, PythonDTO::PyJsonb); -construct_extra_type_matcher!(extra_types::JSON, PythonDTO::PyJson); -construct_extra_type_matcher!(extra_types::MacAddr6, PythonDTO::PyMacAddr6); -construct_extra_type_matcher!(extra_types::MacAddr8, PythonDTO::PyMacAddr8); -construct_extra_type_matcher!(extra_types::Point, PythonDTO::PyPoint); -construct_extra_type_matcher!(extra_types::Box, PythonDTO::PyBox); -construct_extra_type_matcher!(extra_types::Path, PythonDTO::PyPath); -construct_extra_type_matcher!(extra_types::Line, PythonDTO::PyLine); -construct_extra_type_matcher!(extra_types::LineSegment, PythonDTO::PyLineSegment); -construct_extra_type_matcher!(extra_types::Circle, PythonDTO::PyCircle); -construct_extra_type_matcher!(extra_types::PgVector, PythonDTO::PyPgVector); +construct_extra_type_converter!(extra_types::Text, PythonDTO::PyText); +construct_extra_type_converter!(extra_types::VarChar, PythonDTO::PyVarChar); +construct_extra_type_converter!(extra_types::SmallInt, PythonDTO::PyIntI16); +construct_extra_type_converter!(extra_types::Integer, PythonDTO::PyIntI32); +construct_extra_type_converter!(extra_types::BigInt, PythonDTO::PyIntI64); +construct_extra_type_converter!(extra_types::Float32, PythonDTO::PyFloat32); +construct_extra_type_converter!(extra_types::Float64, PythonDTO::PyFloat64); +construct_extra_type_converter!(extra_types::Money, PythonDTO::PyMoney); +construct_extra_type_converter!(extra_types::JSONB, PythonDTO::PyJsonb); +construct_extra_type_converter!(extra_types::JSON, PythonDTO::PyJson); +construct_extra_type_converter!(extra_types::MacAddr6, PythonDTO::PyMacAddr6); +construct_extra_type_converter!(extra_types::MacAddr8, PythonDTO::PyMacAddr8); +construct_extra_type_converter!(extra_types::Point, PythonDTO::PyPoint); +construct_extra_type_converter!(extra_types::Box, PythonDTO::PyBox); +construct_extra_type_converter!(extra_types::Path, PythonDTO::PyPath); +construct_extra_type_converter!(extra_types::Line, PythonDTO::PyLine); +construct_extra_type_converter!(extra_types::LineSegment, PythonDTO::PyLineSegment); +construct_extra_type_converter!(extra_types::Circle, PythonDTO::PyCircle); +construct_extra_type_converter!(extra_types::PgVector, PythonDTO::PyPgVector); +construct_extra_type_converter!(extra_types::CustomType, PythonDTO::PyCustomType); impl ToPythonDTO for PythonDecimal { fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult { @@ -138,11 +141,16 @@ impl ToPythonDTO for PythonUUID { } } -impl ToPythonDTO for extra_types::PythonArray { - fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult { - Ok(PythonDTO::PyArray(py_sequence_into_postgres_array( - python_param, - )?)) +impl ToPythonDTOArray for extra_types::PythonArray { + fn to_python_dto( + python_param: &pyo3::Bound<'_, PyAny>, + array_type: Type, + ) -> PSQLPyResult { + let elem_type = array_type_to_single_type(&array_type); + Ok(PythonDTO::PyArray( + py_sequence_into_postgres_array(python_param, &elem_type)?, + array_type, + )) } } @@ -160,47 +168,54 @@ impl ToPythonDTO for IpAddr { impl ToPythonDTO for extra_types::PythonEnum { fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult { - let string = python_param.extract::()?; - return Ok(PythonDTO::PyString(string)); + if let Ok(value_attr) = python_param.getattr("value") { + if let Ok(possible_string) = value_attr.extract::() { + return Ok(PythonDTO::PyString(possible_string)); + } + } + + Err(RustPSQLDriverError::PyToRustValueConversionError( + "Cannot convert Enum to inner type".into(), + )) } } -macro_rules! construct_array_type_matcher { +macro_rules! construct_array_type_converter { ($match_type:ty) => { impl ToPythonDTO for $match_type { fn to_python_dto(python_param: &Bound<'_, PyAny>) -> PSQLPyResult { python_param .extract::<$match_type>()? - ._convert_to_python_dto() + ._convert_to_python_dto(&Self::element_type()) } } }; } -construct_array_type_matcher!(extra_types::BoolArray); -construct_array_type_matcher!(extra_types::UUIDArray); -construct_array_type_matcher!(extra_types::VarCharArray); -construct_array_type_matcher!(extra_types::TextArray); -construct_array_type_matcher!(extra_types::Int16Array); -construct_array_type_matcher!(extra_types::Int32Array); -construct_array_type_matcher!(extra_types::Int64Array); -construct_array_type_matcher!(extra_types::Float32Array); -construct_array_type_matcher!(extra_types::Float64Array); -construct_array_type_matcher!(extra_types::MoneyArray); -construct_array_type_matcher!(extra_types::IpAddressArray); -construct_array_type_matcher!(extra_types::JSONBArray); -construct_array_type_matcher!(extra_types::JSONArray); -construct_array_type_matcher!(extra_types::DateArray); -construct_array_type_matcher!(extra_types::TimeArray); -construct_array_type_matcher!(extra_types::DateTimeArray); -construct_array_type_matcher!(extra_types::DateTimeTZArray); -construct_array_type_matcher!(extra_types::MacAddr6Array); -construct_array_type_matcher!(extra_types::MacAddr8Array); -construct_array_type_matcher!(extra_types::NumericArray); -construct_array_type_matcher!(extra_types::PointArray); -construct_array_type_matcher!(extra_types::BoxArray); -construct_array_type_matcher!(extra_types::PathArray); -construct_array_type_matcher!(extra_types::LineArray); -construct_array_type_matcher!(extra_types::LsegArray); -construct_array_type_matcher!(extra_types::CircleArray); -construct_array_type_matcher!(extra_types::IntervalArray); +construct_array_type_converter!(extra_types::BoolArray); +construct_array_type_converter!(extra_types::UUIDArray); +construct_array_type_converter!(extra_types::VarCharArray); +construct_array_type_converter!(extra_types::TextArray); +construct_array_type_converter!(extra_types::Int16Array); +construct_array_type_converter!(extra_types::Int32Array); +construct_array_type_converter!(extra_types::Int64Array); +construct_array_type_converter!(extra_types::Float32Array); +construct_array_type_converter!(extra_types::Float64Array); +construct_array_type_converter!(extra_types::MoneyArray); +construct_array_type_converter!(extra_types::IpAddressArray); +construct_array_type_converter!(extra_types::JSONBArray); +construct_array_type_converter!(extra_types::JSONArray); +construct_array_type_converter!(extra_types::DateArray); +construct_array_type_converter!(extra_types::TimeArray); +construct_array_type_converter!(extra_types::DateTimeArray); +construct_array_type_converter!(extra_types::DateTimeTZArray); +construct_array_type_converter!(extra_types::MacAddr6Array); +construct_array_type_converter!(extra_types::MacAddr8Array); +construct_array_type_converter!(extra_types::NumericArray); +construct_array_type_converter!(extra_types::PointArray); +construct_array_type_converter!(extra_types::BoxArray); +construct_array_type_converter!(extra_types::PathArray); +construct_array_type_converter!(extra_types::LineArray); +construct_array_type_converter!(extra_types::LsegArray); +construct_array_type_converter!(extra_types::CircleArray); +construct_array_type_converter!(extra_types::IntervalArray); diff --git a/src/value_converter/dto/enums.rs b/src/value_converter/dto/enums.rs index 00e88a10..a90f1527 100644 --- a/src/value_converter/dto/enums.rs +++ b/src/value_converter/dto/enums.rs @@ -2,6 +2,7 @@ use chrono::{self, DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime}; use geo_types::{Line as LineSegment, LineString, Point, Rect}; use macaddr::{MacAddr6, MacAddr8}; use pg_interval::Interval; +use postgres_types::Type; use rust_decimal::Decimal; use serde_json::Value; use std::{fmt::Debug, net::IpAddr}; @@ -34,9 +35,9 @@ pub enum PythonDTO { PyDateTimeTz(DateTime), PyInterval(Interval), PyIpAddress(IpAddr), - PyList(Vec), - PyArray(Array), - PyTuple(Vec), + PyList(Vec, Type), + PyArray(Array, Type), + PyTuple(Vec, Type), PyJsonb(Value), PyJson(Value), PyMacAddr6(MacAddr6), diff --git a/src/value_converter/dto/funcs.rs b/src/value_converter/dto/funcs.rs new file mode 100644 index 00000000..116db7d0 --- /dev/null +++ b/src/value_converter/dto/funcs.rs @@ -0,0 +1,33 @@ +use postgres_types::Type; + +pub fn array_type_to_single_type(array_type: &Type) -> Type { + match *array_type { + Type::BOOL_ARRAY => Type::BOOL, + Type::UUID_ARRAY => Type::UUID_ARRAY, + Type::VARCHAR_ARRAY => Type::VARCHAR, + Type::TEXT_ARRAY => Type::TEXT, + Type::INT2_ARRAY => Type::INT2, + Type::INT4_ARRAY => Type::INT4, + Type::INT8_ARRAY => Type::INT8, + Type::FLOAT4_ARRAY => Type::FLOAT4, + Type::FLOAT8_ARRAY => Type::FLOAT8, + Type::MONEY_ARRAY => Type::MONEY, + Type::INET_ARRAY => Type::INET, + Type::JSON_ARRAY => Type::JSON, + Type::JSONB_ARRAY => Type::JSONB, + Type::DATE_ARRAY => Type::DATE, + Type::TIME_ARRAY => Type::TIME, + Type::TIMESTAMP_ARRAY => Type::TIMESTAMP, + Type::TIMESTAMPTZ_ARRAY => Type::TIMESTAMPTZ, + Type::INTERVAL_ARRAY => Type::INTERVAL, + Type::MACADDR_ARRAY => Type::MACADDR, + Type::MACADDR8_ARRAY => Type::MACADDR8, + Type::POINT_ARRAY => Type::POINT, + Type::BOX_ARRAY => Type::BOX, + Type::PATH_ARRAY => Type::PATH, + Type::LINE_ARRAY => Type::LINE, + Type::LSEG_ARRAY => Type::LSEG, + Type::CIRCLE_ARRAY => Type::CIRCLE, + _ => Type::ANY, + } +} diff --git a/src/value_converter/dto/impls.rs b/src/value_converter/dto/impls.rs index 58debfdc..bd48ddb3 100644 --- a/src/value_converter/dto/impls.rs +++ b/src/value_converter/dto/impls.rs @@ -39,7 +39,9 @@ impl<'py> IntoPyObject<'py> for PythonDTO { PythonDTO::PyIntU64(pyint) => Ok(pyint.into_pyobject(py)?.into_any()), PythonDTO::PyFloat32(pyfloat) => Ok(pyfloat.into_pyobject(py)?.into_any()), PythonDTO::PyFloat64(pyfloat) => Ok(pyfloat.into_pyobject(py)?.into_any()), - _ => unreachable!(), + _ => { + unreachable!() + } } } } @@ -108,7 +110,7 @@ impl PythonDTO { PythonDTO::PyIntU64(pyint) => Ok(json!(pyint)), PythonDTO::PyFloat32(pyfloat) => Ok(json!(pyfloat)), PythonDTO::PyFloat64(pyfloat) => Ok(json!(pyfloat)), - PythonDTO::PyList(pylist) => { + PythonDTO::PyList(pylist, _) => { let mut vec_serde_values: Vec = vec![]; for py_object in pylist { @@ -117,7 +119,9 @@ impl PythonDTO { Ok(json!(vec_serde_values)) } - PythonDTO::PyArray(array) => Ok(json!(pythondto_array_to_serde(Some(array.clone()))?)), + PythonDTO::PyArray(array, _) => { + Ok(json!(pythondto_array_to_serde(Some(array.clone()))?)) + } PythonDTO::PyJsonb(py_dict) | PythonDTO::PyJson(py_dict) => Ok(py_dict.clone()), _ => Err(RustPSQLDriverError::PyToRustValueConversionError( "Cannot convert your type into Rust type".into(), @@ -238,30 +242,32 @@ impl ToSql for PythonDTO { PythonDTO::PyCircle(pycircle) => { <&Circle as ToSql>::to_sql(&pycircle, ty, out)?; } - PythonDTO::PyList(py_iterable) | PythonDTO::PyTuple(py_iterable) => { - let mut items = Vec::new(); - for inner in py_iterable { - items.push(inner); - } - if items.is_empty() { - return_is_null_true = true; - } else { - items.to_sql(&items[0].array_type()?, out)?; - } - } - PythonDTO::PyArray(array) => { - if let Some(first_elem) = array.iter().nth(0) { - match first_elem.array_type() { - Ok(ok_type) => { - array.to_sql(&ok_type, out)?; - } - Err(_) => { - return Err(RustPSQLDriverError::PyToRustValueConversionError( - "Cannot define array type.".into(), - ))? - } - } - } + PythonDTO::PyList(py_iterable, type_) | PythonDTO::PyTuple(py_iterable, type_) => { + return py_iterable.to_sql(type_, out); + // let mut items = Vec::new(); + // for inner in py_iterable { + // items.push(inner); + // } + // if items.is_empty() { + // return_is_null_true = true; + // } else { + // items.to_sql(&items[0].array_type()?, out)?; + // } + } + PythonDTO::PyArray(array, type_) => { + return array.to_sql(type_, out); + // if let Some(first_elem) = array.iter().nth(0) { + // match first_elem.array_type() { + // Ok(ok_type) => { + // array.to_sql(&ok_type, out)?; + // } + // Err(_) => { + // return Err(RustPSQLDriverError::PyToRustValueConversionError( + // "Cannot define array type.".into(), + // ))? + // } + // } + // } } PythonDTO::PyJsonb(py_dict) | PythonDTO::PyJson(py_dict) => { <&Value as ToSql>::to_sql(&py_dict, ty, out)?; diff --git a/src/value_converter/dto/mod.rs b/src/value_converter/dto/mod.rs index 5be9ae5b..49985cf1 100644 --- a/src/value_converter/dto/mod.rs +++ b/src/value_converter/dto/mod.rs @@ -1,3 +1,4 @@ pub mod converter_impls; pub mod enums; +pub mod funcs; pub mod impls; diff --git a/src/value_converter/from_python.rs b/src/value_converter/from_python.rs index b104c993..57307f29 100644 --- a/src/value_converter/from_python.rs +++ b/src/value_converter/from_python.rs @@ -2,17 +2,14 @@ use chrono::{self, DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, T use chrono_tz::Tz; use geo_types::{coord, Coord}; use itertools::Itertools; -use pg_interval::Interval; use postgres_array::{Array, Dimension}; -use rust_decimal::Decimal; -use serde_json::{Map, Value}; +use postgres_types::Type; use std::net::IpAddr; -use uuid::Uuid; use pyo3::{ types::{ - PyAnyMethods, PyBool, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyDictMethods, PyFloat, - PyInt, PyList, PyMapping, PySequence, PySet, PyString, PyTime, PyTuple, PyTypeMethods, + PyAnyMethods, PyBool, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyFloat, PyInt, PyList, + PySequence, PySet, PyString, PyTime, PyTuple, PyTypeMethods, }, Bound, Py, PyAny, Python, }; @@ -20,13 +17,13 @@ use pyo3::{ use crate::{ exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, extra_types::{self}, - value_converter::{ - consts::KWARGS_QUERYSTRINGS, dto::enums::PythonDTO, - utils::extract_value_from_python_object_or_raise, - }, + value_converter::{dto::enums::PythonDTO, utils::extract_value_from_python_object_or_raise}, }; -use super::{additional_types::NonePyType, consts::KWARGS_PARAMS_REGEXP, traits::ToPythonDTO}; +use super::{ + additional_types::NonePyType, + traits::{ToPythonDTO, ToPythonDTOArray}, +}; /// Convert single python parameter to `PythonDTO` enum. /// @@ -35,17 +32,11 @@ use super::{additional_types::NonePyType, consts::KWARGS_PARAMS_REGEXP, traits:: /// May return Err Result if python type doesn't have support yet /// or value of the type is incorrect. #[allow(clippy::too_many_lines)] -pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult { +pub fn from_python_untyped(parameter: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult { if parameter.is_none() { return ::to_python_dto(parameter); } - if parameter.is_instance_of::() { - return Ok(PythonDTO::PyCustomType( - parameter.extract::()?.inner(), - )); - } - if parameter.is_instance_of::() { return ::to_python_dto(parameter); } @@ -115,7 +106,7 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult } if parameter.is_instance_of::() | parameter.is_instance_of::() { - return ::to_python_dto(parameter); + return ::to_python_dto(parameter, Type::ANY); } if parameter.is_instance_of::() { @@ -124,16 +115,10 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult if parameter.is_instance_of::() { return ::to_python_dto(parameter); - // return Ok(PythonDTO::PyJsonb( - // parameter.extract::()?.inner(), - // )); } if parameter.is_instance_of::() { return ::to_python_dto(parameter); - // return Ok(PythonDTO::PyJson( - // parameter.extract::()?.inner(), - // )); } if parameter.is_instance_of::() { @@ -178,6 +163,162 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult return ::to_python_dto(parameter); } + if let Ok(converted_array) = from_python_array_typed(parameter) { + return Ok(converted_array); + } + + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); + } + + if parameter.extract::().is_ok() { + return ::to_python_dto(parameter); + } + + if parameter.getattr("value").is_ok() { + return ::to_python_dto(parameter); + } + + Err(RustPSQLDriverError::PyToRustValueConversionError(format!( + "Can not covert you type {parameter} into inner one", + ))) +} + +/// Convert single python parameter to `PythonDTO` enum. +/// +/// # Errors +/// +/// May return Err Result if python type doesn't have support yet +/// or value of the type is incorrect. +#[allow(clippy::too_many_lines)] +pub fn from_python_typed( + parameter: &pyo3::Bound<'_, PyAny>, + type_: &Type, +) -> PSQLPyResult { + println!("{:?} {:?}", type_, parameter); + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); + } + + if parameter.is_none() { + return ::to_python_dto(parameter); + } + + if parameter.get_type().name()? == "UUID" { + return ::to_python_dto(parameter); + } + + if parameter.get_type().name()? == "decimal.Decimal" + || parameter.get_type().name()? == "Decimal" + { + return ::to_python_dto(parameter); + } + + if parameter.is_instance_of::() | parameter.is_instance_of::() { + return ::to_python_dto( + parameter, + type_.clone(), + ); + } + + if let Ok(converted_array) = from_python_array_typed(parameter) { + return Ok(converted_array); + } + + match *type_ { + Type::BYTEA => return as ToPythonDTO>::to_python_dto(parameter), + Type::TEXT => { + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); + } + return ::to_python_dto(parameter); + } + Type::VARCHAR => { + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); + } + return ::to_python_dto(parameter); + } + Type::XML => return ::to_python_dto(parameter), + Type::BOOL => return ::to_python_dto(parameter), + Type::INT2 => { + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); + } + return ::to_python_dto(parameter); + } + Type::INT4 => { + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); + } + return ::to_python_dto(parameter); + } + Type::INT8 => { + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); + } + return ::to_python_dto(parameter); + } + Type::MONEY => { + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); + } + return ::to_python_dto(parameter); + } + Type::FLOAT4 => { + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); + } + return ::to_python_dto(parameter); + } + Type::FLOAT8 => { + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); + } + return ::to_python_dto(parameter); + } + Type::INET => return ::to_python_dto(parameter), + Type::DATE => return ::to_python_dto(parameter), + Type::TIME => return ::to_python_dto(parameter), + Type::TIMESTAMP | Type::TIMESTAMPTZ => { + return ::to_python_dto(parameter) + } + Type::INTERVAL => return ::to_python_dto(parameter), + Type::JSONB => { + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); + } + + return ::to_python_dto(parameter); + } + Type::JSON => { + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); + } + + return ::to_python_dto(parameter); + } + Type::MACADDR => return ::to_python_dto(parameter), + Type::MACADDR8 => return ::to_python_dto(parameter), + Type::POINT => return ::to_python_dto(parameter), + Type::BOX => return ::to_python_dto(parameter), + Type::PATH => return ::to_python_dto(parameter), + Type::LINE => return ::to_python_dto(parameter), + Type::LSEG => return ::to_python_dto(parameter), + Type::CIRCLE => return ::to_python_dto(parameter), + _ => {} + } + + if let Ok(converted_value) = from_python_untyped(parameter) { + return Ok(converted_value); + } + + Err(RustPSQLDriverError::PyToRustValueConversionError(format!( + "Can not covert you type {parameter} into {type_}", + ))) +} + +fn from_python_array_typed(parameter: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult { if parameter.is_instance_of::() { return ::to_python_dto(parameter); } @@ -286,25 +427,8 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult return ::to_python_dto(parameter); } - if parameter.is_instance_of::() { - return ::to_python_dto(parameter); - } - - if let Ok(_) = parameter.extract::() { - return ::to_python_dto(parameter); - } - - // It's used for Enum. - // If StrEnum is used on Python side, - // we simply stop at the `is_instance_of::``. - if let Ok(value_attr) = parameter.getattr("value") { - if let Ok(possible_string) = value_attr.extract::() { - return Ok(PythonDTO::PyString(possible_string)); - } - } - Err(RustPSQLDriverError::PyToRustValueConversionError(format!( - "Can not covert you type {parameter} into inner one", + "Cannot convert parameter in extra types Array", ))) } @@ -368,7 +492,10 @@ pub fn extract_datetime_from_python_object_attrs( /// May return Err Result if cannot convert at least one element. #[allow(clippy::cast_possible_truncation)] #[allow(clippy::cast_possible_wrap)] -pub fn py_sequence_into_postgres_array(parameter: &Bound) -> PSQLPyResult> { +pub fn py_sequence_into_postgres_array( + parameter: &Bound, + type_: &Type, +) -> PSQLPyResult> { let mut py_seq = parameter .downcast::() .map_err(|_| { @@ -413,7 +540,7 @@ pub fn py_sequence_into_postgres_array(parameter: &Bound) -> PSQLPyResult } } - let array_data = py_sequence_into_flat_vec(parameter)?; + let array_data = py_sequence_into_flat_vec(parameter, type_)?; match postgres_array::Array::from_parts_no_panic(array_data, dimensions) { Ok(result_array) => Ok(result_array), Err(err) => Err(RustPSQLDriverError::PyToRustValueConversionError(format!( @@ -426,7 +553,10 @@ pub fn py_sequence_into_postgres_array(parameter: &Bound) -> PSQLPyResult /// /// # Errors /// May return Err Result if cannot convert element into Rust one. -pub fn py_sequence_into_flat_vec(parameter: &Bound) -> PSQLPyResult> { +pub fn py_sequence_into_flat_vec( + parameter: &Bound, + type_: &Type, +) -> PSQLPyResult> { let py_seq = parameter.downcast::().map_err(|_| { RustPSQLDriverError::PyToRustValueConversionError( "PostgreSQL ARRAY type can be made only from python Sequence".into(), @@ -441,17 +571,17 @@ pub fn py_sequence_into_flat_vec(parameter: &Bound) -> PSQLPyResult() { - final_vec.push(py_to_rust(&ok_seq_elem)?); + final_vec.push(from_python_typed(&ok_seq_elem, type_)?); continue; } let possible_next_seq = ok_seq_elem.downcast::(); if let Ok(next_seq) = possible_next_seq { - let mut next_vec = py_sequence_into_flat_vec(next_seq)?; + let mut next_vec = py_sequence_into_flat_vec(next_seq, type_)?; final_vec.append(&mut next_vec); } else { - final_vec.push(py_to_rust(&ok_seq_elem)?); + final_vec.push(from_python_typed(&ok_seq_elem, type_)?); continue; } } @@ -481,7 +611,7 @@ fn convert_py_to_rust_coord_values(parameters: Vec>) -> PSQLPyResult coord_values_vec.push(f64::from(pyint)), PythonDTO::PyIntI32(pyint) => coord_values_vec.push(f64::from(pyint)), @@ -658,35 +788,3 @@ fn py_sequence_to_rust(bind_parameters: &Bound) -> PSQLPyResult>, RustPSQLDriverError>(coord_values_sequence_vec) } - -fn parse_kwargs_qs(querystring: &str) -> (String, Vec) { - let re = regex::Regex::new(KWARGS_PARAMS_REGEXP).unwrap(); - - { - let kq_read = KWARGS_QUERYSTRINGS.read().unwrap(); - let qs = kq_read.get(querystring); - - if let Some(qs) = qs { - return qs.clone(); - } - }; - - let mut counter = 0; - let mut sequence = Vec::new(); - - let result = re.replace_all(querystring, |caps: ®ex::Captures| { - let account_id = caps[1].to_string(); - - sequence.push(account_id.clone()); - counter += 1; - - format!("${}", &counter) - }); - - let mut kq_write = KWARGS_QUERYSTRINGS.write().unwrap(); - kq_write.insert( - querystring.to_string(), - (result.clone().into(), sequence.clone()), - ); - (result.into(), sequence) -} diff --git a/src/value_converter/models/serde_value.rs b/src/value_converter/models/serde_value.rs index 71239c2b..392e3fd0 100644 --- a/src/value_converter/models/serde_value.rs +++ b/src/value_converter/models/serde_value.rs @@ -11,7 +11,8 @@ use tokio_postgres::types::Type; use crate::{ exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, value_converter::{ - dto::enums::PythonDTO, from_python::py_to_rust, to_python::build_python_from_serde_value, + dto::enums::PythonDTO, from_python::from_python_untyped, + to_python::build_python_from_serde_value, }, }; @@ -60,7 +61,7 @@ fn serde_value_from_list(gil: Python<'_>, bind_value: &Bound<'_, PyAny>) -> PSQL for inner in params { let inner_bind = inner.bind(gil); if inner_bind.is_instance_of::() { - let python_dto = py_to_rust(inner_bind)?; + let python_dto = from_python_untyped(inner_bind)?; result_vec.push(python_dto.to_serde_value()?); } else if inner_bind.is_instance_of::() { let serde_value = build_serde_value(inner.bind(gil))?; @@ -91,7 +92,7 @@ fn serde_value_from_dict(bind_value: &Bound<'_, PyAny>) -> PSQLPyResult { })?; let key = py_list.get_item(0)?.extract::()?; - let value = py_to_rust(&py_list.get_item(1)?)?; + let value = from_python_untyped(&py_list.get_item(1)?)?; serde_map.insert(key, value.to_serde_value()?); } diff --git a/src/value_converter/params_converters.rs b/src/value_converter/params_converters.rs deleted file mode 100644 index e69de29b..00000000 diff --git a/src/value_converter/to_python.rs b/src/value_converter/to_python.rs index 5dbfd7ce..b3bf2af5 100644 --- a/src/value_converter/to_python.rs +++ b/src/value_converter/to_python.rs @@ -73,38 +73,6 @@ pub fn build_python_from_serde_value(py: Python<'_>, value: Value) -> PSQLPyResu } } -fn parse_kwargs_qs(querystring: &str) -> (String, Vec) { - let re = regex::Regex::new(r"\$\(([^)]+)\)p").unwrap(); - - { - let kq_read = KWARGS_QUERYSTRINGS.read().unwrap(); - let qs = kq_read.get(querystring); - - if let Some(qs) = qs { - return qs.clone(); - } - }; - - let mut counter = 0; - let mut sequence = Vec::new(); - - let result = re.replace_all(querystring, |caps: ®ex::Captures| { - let account_id = caps[1].to_string(); - - sequence.push(account_id.clone()); - counter += 1; - - format!("${}", &counter) - }); - - let mut kq_write = KWARGS_QUERYSTRINGS.write().unwrap(); - kq_write.insert( - querystring.to_string(), - (result.clone().into(), sequence.clone()), - ); - (result.into(), sequence) -} - fn composite_field_postgres_to_py<'a, T: FromSql<'a>>( type_: &Type, buf: &mut &'a [u8], @@ -668,41 +636,3 @@ pub fn postgres_to_py( } Ok(py.None()) } - -/// Convert Python sequence to Rust vector. -/// Also it checks that sequence has set/list/tuple type. -/// -/// # Errors -/// -/// May return error if cannot convert Python type into Rust one. -/// May return error if parameters type isn't correct. -fn py_sequence_to_rust(bind_parameters: &Bound) -> PSQLPyResult>> { - let mut coord_values_sequence_vec: Vec> = vec![]; - - if bind_parameters.is_instance_of::() { - let bind_pyset_parameters = bind_parameters.downcast::().unwrap(); - - for one_parameter in bind_pyset_parameters { - let extracted_parameter = one_parameter.extract::>().map_err(|_| { - RustPSQLDriverError::PyToRustValueConversionError( - format!("Error on sequence type extraction, please use correct list/tuple/set, {bind_parameters}") - ) - })?; - coord_values_sequence_vec.push(extracted_parameter); - } - } else if bind_parameters.is_instance_of::() - | bind_parameters.is_instance_of::() - { - coord_values_sequence_vec = bind_parameters.extract::>>().map_err(|_| { - RustPSQLDriverError::PyToRustValueConversionError( - format!("Error on sequence type extraction, please use correct list/tuple/set, {bind_parameters}") - ) - })?; - } else { - return Err(RustPSQLDriverError::PyToRustValueConversionError(format!( - "Invalid sequence type, please use list/tuple/set, {bind_parameters}" - ))); - }; - - Ok::>, RustPSQLDriverError>(coord_values_sequence_vec) -} diff --git a/src/value_converter/traits.rs b/src/value_converter/traits.rs index 261ee16d..d9d3512e 100644 --- a/src/value_converter/traits.rs +++ b/src/value_converter/traits.rs @@ -1,3 +1,4 @@ +use postgres_types::Type; use pyo3::PyAny; use crate::exceptions::rust_errors::PSQLPyResult; @@ -7,3 +8,10 @@ use super::dto::enums::PythonDTO; pub trait ToPythonDTO { fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult; } + +pub trait ToPythonDTOArray { + fn to_python_dto( + python_param: &pyo3::Bound<'_, PyAny>, + array_type_: Type, + ) -> PSQLPyResult; +} From 602f00e727b2ef2cb272cff9eb75b0126ca89ffc Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Sun, 4 May 2025 18:44:55 +0200 Subject: [PATCH 04/15] Full value converter refactor --- python/tests/test_value_converter.py | 30 +++- src/driver/connection.rs | 23 ++- src/driver/connection_pool.rs | 38 ++++- src/driver/connection_pool_builder.rs | 12 ++ src/driver/inner_connection.rs | 158 ++++++++++++--------- src/driver/listener/core.rs | 10 +- src/statement/cache.rs | 2 +- src/statement/mod.rs | 1 - src/statement/query.rs | 2 +- src/statement/statement.rs | 20 ++- src/statement/statement_builder.rs | 71 +++++---- src/statement/traits.rs | 8 -- src/statement/utils.rs | 7 + src/value_converter/dto/converter_impls.rs | 1 - src/value_converter/dto/impls.rs | 21 --- src/value_converter/from_python.rs | 1 - src/value_converter/models/decimal.rs | 2 +- src/value_converter/to_python.rs | 38 +++++ 18 files changed, 301 insertions(+), 144 deletions(-) delete mode 100644 src/statement/traits.rs diff --git a/python/tests/test_value_converter.py b/python/tests/test_value_converter.py index c35baec1..b0ec5c8d 100644 --- a/python/tests/test_value_converter.py +++ b/python/tests/test_value_converter.py @@ -139,13 +139,18 @@ async def test_as_class( ), ("BOOL", True, True), ("INT2", SmallInt(12), 12), + ("INT2", 12, 12), ("INT4", Integer(121231231), 121231231), + ("INT4", 121231231, 121231231), ("INT8", BigInt(99999999999999999), 99999999999999999), + ("INT8", 99999999999999999, 99999999999999999), ("MONEY", Money(99999999999999999), 99999999999999999), + ("MONEY", 99999999999999999, 99999999999999999), ("NUMERIC(5, 2)", Decimal("120.12"), Decimal("120.12")), - ("FLOAT8", 32.12329864501953, 32.12329864501953), ("FLOAT4", Float32(32.12329864501953), 32.12329864501953), + ("FLOAT4", 32.12329864501953, 32.12329864501953), ("FLOAT8", Float64(32.12329864501953), 32.12329864501953), + ("FLOAT8", 32.12329864501953, 32.12329864501953), ("DATE", now_datetime.date(), now_datetime.date()), ("TIME", now_datetime.time(), now_datetime.time()), ("TIMESTAMP", now_datetime, now_datetime), @@ -426,6 +431,29 @@ async def test_as_class( [[{"array": "json"}], [{"one more": "test"}]], ], ), + ( + "JSON ARRAY", + [ + { + "test": ["something", 123, "here"], + "nested": ["JSON"], + }, + { + "test": ["something", 123, "here"], + "nested": ["JSON"], + }, + ], + [ + { + "test": ["something", 123, "here"], + "nested": ["JSON"], + }, + { + "test": ["something", 123, "here"], + "nested": ["JSON"], + }, + ], + ), ( "JSON ARRAY", [ diff --git a/src/driver/connection.rs b/src/driver/connection.rs index 8f2a4b40..469ece0b 100644 --- a/src/driver/connection.rs +++ b/src/driver/connection.rs @@ -25,6 +25,7 @@ pub struct Connection { db_client: Option>, db_pool: Option, pg_config: Arc, + prepare: bool, } impl Connection { @@ -33,11 +34,13 @@ impl Connection { db_client: Option>, db_pool: Option, pg_config: Arc, + prepare: bool, ) -> Self { Connection { db_client, db_pool, pg_config, + prepare, } } @@ -54,7 +57,7 @@ impl Connection { impl Default for Connection { fn default() -> Self { - Connection::new(None, None, Arc::new(Config::default())) + Connection::new(None, None, Arc::new(Config::default()), true) } } @@ -138,11 +141,16 @@ impl Connection { } async fn __aenter__<'a>(self_: Py) -> PSQLPyResult> { - let (db_client, db_pool) = pyo3::Python::with_gil(|gil| { + let (db_client, db_pool, prepare) = pyo3::Python::with_gil(|gil| { let self_ = self_.borrow(gil); - (self_.db_client.clone(), self_.db_pool.clone()) + ( + self_.db_client.clone(), + self_.db_pool.clone(), + self_.prepare, + ) }); + let db_pool_2 = db_pool.clone(); if db_client.is_some() { return Ok(self_); } @@ -155,7 +163,11 @@ impl Connection { .await??; pyo3::Python::with_gil(|gil| { let mut self_ = self_.borrow_mut(gil); - self_.db_client = Some(Arc::new(PsqlpyConnection::PoolConn(db_connection))); + self_.db_client = Some(Arc::new(PsqlpyConnection::PoolConn( + db_connection, + db_pool_2.unwrap(), + prepare, + ))); }); return Ok(self_); } @@ -209,7 +221,8 @@ impl Connection { let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); if let Some(db_client) = db_client { - return db_client.execute(querystring, parameters, prepared).await; + let res = db_client.execute(querystring, parameters, prepared).await; + return res; } Err(RustPSQLDriverError::ConnectionClosedError) diff --git a/src/driver/connection_pool.rs b/src/driver/connection_pool.rs index 0c52c256..1ef2d8f9 100644 --- a/src/driver/connection_pool.rs +++ b/src/driver/connection_pool.rs @@ -1,5 +1,6 @@ use crate::runtime::tokio_runtime; use deadpool_postgres::{Manager, ManagerConfig, Pool, RecyclingMethod}; +use postgres_types::Type; use pyo3::{pyclass, pyfunction, pymethods, Py, PyAny}; use std::sync::Arc; use tokio_postgres::Config; @@ -46,6 +47,7 @@ use super::{ ca_file=None, max_db_pool_size=None, conn_recycling_method=None, + prepare=None, ))] #[allow(clippy::too_many_arguments)] pub fn connect( @@ -75,6 +77,7 @@ pub fn connect( ca_file: Option, max_db_pool_size: Option, conn_recycling_method: Option, + prepare: Option, ) -> PSQLPyResult { if let Some(max_db_pool_size) = max_db_pool_size { if max_db_pool_size < 2 { @@ -139,6 +142,7 @@ pub fn connect( pg_config: Arc::new(pg_config), ca_file: ca_file, ssl_mode: ssl_mode, + prepare: prepare.unwrap_or(true), }) } @@ -207,6 +211,7 @@ pub struct ConnectionPool { pg_config: Arc, ca_file: Option, ssl_mode: Option, + prepare: bool, } impl ConnectionPool { @@ -216,14 +221,20 @@ impl ConnectionPool { pg_config: Config, ca_file: Option, ssl_mode: Option, + prepare: Option, ) -> Self { ConnectionPool { pool: pool, pg_config: Arc::new(pg_config), ca_file: ca_file, ssl_mode: ssl_mode, + prepare: prepare.unwrap_or(true), } } + + pub fn remove_prepared_stmt(&mut self, query: &str, types: &[Type]) { + self.pool.manager().statement_caches.remove(query, types); + } } #[pymethods] @@ -260,6 +271,7 @@ impl ConnectionPool { conn_recycling_method=None, ssl_mode=None, ca_file=None, + prepare=None, ))] #[allow(clippy::too_many_arguments)] pub fn new( @@ -289,6 +301,7 @@ impl ConnectionPool { conn_recycling_method: Option, ssl_mode: Option, ca_file: Option, + prepare: Option, ) -> PSQLPyResult { connect( dsn, @@ -317,6 +330,7 @@ impl ConnectionPool { ca_file, max_db_pool_size, conn_recycling_method, + prepare, ) } @@ -360,22 +374,28 @@ impl ConnectionPool { #[must_use] pub fn acquire(&self) -> Connection { - Connection::new(None, Some(self.pool.clone()), self.pg_config.clone()) + Connection::new( + None, + Some(self.pool.clone()), + self.pg_config.clone(), + self.prepare, + ) } #[must_use] #[allow(clippy::needless_pass_by_value)] pub fn listener(self_: pyo3::Py) -> Listener { - let (pg_config, ca_file, ssl_mode) = pyo3::Python::with_gil(|gil| { + let (pg_config, ca_file, ssl_mode, prepare) = pyo3::Python::with_gil(|gil| { let b_gil = self_.borrow(gil); ( b_gil.pg_config.clone(), b_gil.ca_file.clone(), b_gil.ssl_mode, + b_gil.prepare, ) }); - Listener::new(pg_config, ca_file, ssl_mode) + Listener::new(pg_config, ca_file, ssl_mode, prepare) } /// Return new single connection. @@ -383,10 +403,11 @@ impl ConnectionPool { /// # Errors /// May return Err Result if cannot get new connection from the pool. pub async fn connection(self_: pyo3::Py) -> PSQLPyResult { - let (db_pool, pg_config) = pyo3::Python::with_gil(|gil| { + let (db_pool, pg_config, prepare) = pyo3::Python::with_gil(|gil| { let slf = self_.borrow(gil); - (slf.pool.clone(), slf.pg_config.clone()) + (slf.pool.clone(), slf.pg_config.clone(), slf.prepare) }); + let db_pool_2 = db_pool.clone(); let db_connection = tokio_runtime() .spawn(async move { Ok::(db_pool.get().await?) @@ -394,9 +415,14 @@ impl ConnectionPool { .await??; Ok(Connection::new( - Some(Arc::new(PsqlpyConnection::PoolConn(db_connection))), + Some(Arc::new(PsqlpyConnection::PoolConn( + db_connection, + db_pool_2.clone(), + prepare, + ))), None, pg_config, + prepare, )) } diff --git a/src/driver/connection_pool_builder.rs b/src/driver/connection_pool_builder.rs index 42cdd641..ea311642 100644 --- a/src/driver/connection_pool_builder.rs +++ b/src/driver/connection_pool_builder.rs @@ -18,6 +18,7 @@ pub struct ConnectionPoolBuilder { conn_recycling_method: Option, ca_file: Option, ssl_mode: Option, + prepare: Option, } #[pymethods] @@ -31,6 +32,7 @@ impl ConnectionPoolBuilder { conn_recycling_method: None, ca_file: None, ssl_mode: None, + prepare: None, } } @@ -68,6 +70,7 @@ impl ConnectionPoolBuilder { self.config.clone(), self.ca_file.clone(), self.ssl_mode, + self.prepare, )) } @@ -80,6 +83,15 @@ impl ConnectionPoolBuilder { self_ } + /// Set ca_file for ssl_mode in PostgreSQL. + fn prepare(self_: Py, prepare: bool) -> Py { + Python::with_gil(|gil| { + let mut self_ = self_.borrow_mut(gil); + self_.prepare = Some(prepare); + }); + self_ + } + /// Set size to the connection pool. /// /// # Error diff --git a/src/driver/inner_connection.rs b/src/driver/inner_connection.rs index a7e9d233..c463be64 100644 --- a/src/driver/inner_connection.rs +++ b/src/driver/inner_connection.rs @@ -1,5 +1,5 @@ use bytes::Buf; -use deadpool_postgres::Object; +use deadpool_postgres::{Object, Pool}; use postgres_types::{ToSql, Type}; use pyo3::{Py, PyAny, Python}; use std::vec; @@ -14,7 +14,7 @@ use crate::{ #[allow(clippy::module_name_repetitions)] pub enum PsqlpyConnection { - PoolConn(Object), + PoolConn(Object, Pool, bool), SingleConn(Client), } @@ -23,9 +23,18 @@ impl PsqlpyConnection { /// /// # Errors /// May return Err if cannot prepare statement. - pub async fn prepare(&self, query: &str) -> PSQLPyResult { + pub async fn prepare(&self, query: &str, prepared: bool) -> PSQLPyResult { match self { - PsqlpyConnection::PoolConn(pconn) => return Ok(pconn.prepare_cached(query).await?), + PsqlpyConnection::PoolConn(pconn, _, _) => { + if prepared { + return Ok(pconn.prepare_cached(query).await?); + } else { + println!("999999"); + let prepared = pconn.prepare(query).await?; + self.drop_prepared(&prepared).await?; + return Ok(prepared); + } + } PsqlpyConnection::SingleConn(sconn) => return Ok(sconn.prepare(query).await?), } } @@ -35,28 +44,18 @@ impl PsqlpyConnection { /// # Errors /// May return Err if cannot prepare statement. pub async fn drop_prepared(&self, stmt: &Statement) -> PSQLPyResult<()> { - let query = format!("DEALLOCATE PREPARE {}", stmt.name()); + let deallocate_query = format!("DEALLOCATE PREPARE {}", stmt.name()); match self { - PsqlpyConnection::PoolConn(pconn) => return Ok(pconn.batch_execute(&query).await?), - PsqlpyConnection::SingleConn(sconn) => return Ok(sconn.batch_execute(&query).await?), + PsqlpyConnection::PoolConn(pconn, _, _) => { + return Ok(pconn.batch_execute(&deallocate_query).await?) + } + PsqlpyConnection::SingleConn(sconn) => { + return Ok(sconn.batch_execute(&deallocate_query).await?) + } } } - /// Prepare and delete statement. - /// - /// # Errors - /// Can return Err if cannot prepare statement. - pub async fn prepare_then_drop(&self, query: &str) -> PSQLPyResult> { - let types: Vec; - - let stmt = self.prepare(query).await?; - types = stmt.params().to_vec(); - self.drop_prepared(&stmt).await?; - - Ok(types) - } - - /// Prepare cached statement. + /// Execute statement with parameters. /// /// # Errors /// May return Err if cannot execute statement. @@ -69,20 +68,43 @@ impl PsqlpyConnection { T: ?Sized + ToStatement, { match self { - PsqlpyConnection::PoolConn(pconn) => return Ok(pconn.query(statement, params).await?), + PsqlpyConnection::PoolConn(pconn, _, _) => { + return Ok(pconn.query(statement, params).await?) + } PsqlpyConnection::SingleConn(sconn) => { return Ok(sconn.query(statement, params).await?) } } } - /// Prepare cached statement. + /// Execute statement with parameters. + /// + /// # Errors + /// May return Err if cannot execute statement. + pub async fn query_typed( + &self, + statement: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> PSQLPyResult> { + match self { + PsqlpyConnection::PoolConn(pconn, _, _) => { + return Ok(pconn.query_typed(statement, params).await?) + } + PsqlpyConnection::SingleConn(sconn) => { + return Ok(sconn.query_typed(statement, params).await?) + } + } + } + + /// Batch execute statement. /// /// # Errors /// May return Err if cannot execute statement. pub async fn batch_execute(&self, query: &str) -> PSQLPyResult<()> { match self { - PsqlpyConnection::PoolConn(pconn) => return Ok(pconn.batch_execute(query).await?), + PsqlpyConnection::PoolConn(pconn, _, _) => { + return Ok(pconn.batch_execute(query).await?) + } PsqlpyConnection::SingleConn(sconn) => return Ok(sconn.batch_execute(query).await?), } } @@ -90,6 +112,21 @@ impl PsqlpyConnection { /// Prepare cached statement. /// /// # Errors + /// May return Err if cannot execute copy data. + pub async fn copy_in(&self, statement: &T) -> PSQLPyResult> + where + T: ?Sized + ToStatement, + U: Buf + 'static + Send, + { + match self { + PsqlpyConnection::PoolConn(pconn, _, _) => return Ok(pconn.copy_in(statement).await?), + PsqlpyConnection::SingleConn(sconn) => return Ok(sconn.copy_in(statement).await?), + } + } + + /// Executes a statement which returns a single row, returning it. + /// + /// # Errors /// May return Err if cannot execute statement. pub async fn query_one( &self, @@ -100,7 +137,7 @@ impl PsqlpyConnection { T: ?Sized + ToStatement, { match self { - PsqlpyConnection::PoolConn(pconn) => { + PsqlpyConnection::PoolConn(pconn, _, _) => { return Ok(pconn.query_one(statement, params).await?) } PsqlpyConnection::SingleConn(sconn) => { @@ -123,17 +160,20 @@ impl PsqlpyConnection { let result = if prepared { self.query( - &self.prepare(&statement.sql_stmt()).await.map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot prepare statement, error - {err}" - )) - })?, + &self + .prepare(&statement.raw_query(), true) + .await + .map_err(|err| { + RustPSQLDriverError::ConnectionExecuteError(format!( + "Cannot prepare statement, error - {err}" + )) + })?, &statement.params(), ) .await .map_err(|err| RustPSQLDriverError::ConnectionExecuteError(format!("{err}")))? } else { - self.query(statement.sql_stmt(), &statement.params()) + self.query(statement.raw_query(), &statement.params()) .await .map_err(|err| RustPSQLDriverError::ConnectionExecuteError(format!("{err}")))? }; @@ -153,21 +193,19 @@ impl PsqlpyConnection { let prepared = prepared.unwrap_or(true); - let result = if prepared { - self.query( - &self.prepare(statement.sql_stmt()).await.map_err(|err| { + let result = match prepared { + true => self + .query(statement.statement_query()?, &statement.params()) + .await + .map_err(|err| { RustPSQLDriverError::ConnectionExecuteError(format!( "Cannot prepare statement, error - {err}" )) })?, - &statement.params(), - ) - .await - .map_err(|err| RustPSQLDriverError::ConnectionExecuteError(format!("{err}")))? - } else { - self.query(statement.sql_stmt(), &statement.params()) + false => self + .query_typed(statement.raw_query(), &statement.params_typed()) .await - .map_err(|err| RustPSQLDriverError::ConnectionExecuteError(format!("{err}")))? + .map_err(|err| RustPSQLDriverError::ConnectionExecuteError(format!("{err}")))?, }; Ok(PSQLDriverPyQueryResult::new(result)) @@ -196,19 +234,19 @@ impl PsqlpyConnection { for statement in statements { let querystring_result = if prepared { - let prepared_stmt = &self.prepare(&statement.sql_stmt()).await; + let prepared_stmt = &self.prepare(&statement.raw_query(), true).await; if let Err(error) = prepared_stmt { return Err(RustPSQLDriverError::ConnectionExecuteError(format!( "Cannot prepare statement in execute_many, operation rolled back {error}", ))); } self.query( - &self.prepare(&statement.sql_stmt()).await?, + &self.prepare(&statement.raw_query(), true).await?, &statement.params(), ) .await } else { - self.query(statement.sql_stmt(), &statement.params()).await + self.query(statement.raw_query(), &statement.params()).await }; if let Err(error) = querystring_result { @@ -235,17 +273,20 @@ impl PsqlpyConnection { let result = if prepared { self.query_one( - &self.prepare(&statement.sql_stmt()).await.map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot prepare statement, error - {err}" - )) - })?, + &self + .prepare(&statement.raw_query(), true) + .await + .map_err(|err| { + RustPSQLDriverError::ConnectionExecuteError(format!( + "Cannot prepare statement, error - {err}" + )) + })?, &statement.params(), ) .await .map_err(|err| RustPSQLDriverError::ConnectionExecuteError(format!("{err}")))? } else { - self.query_one(statement.sql_stmt(), &statement.params()) + self.query_one(statement.raw_query(), &statement.params()) .await .map_err(|err| RustPSQLDriverError::ConnectionExecuteError(format!("{err}")))? }; @@ -281,19 +322,4 @@ impl PsqlpyConnection { None => Ok(gil.None()), }); } - - /// Prepare cached statement. - /// - /// # Errors - /// May return Err if cannot execute copy data. - pub async fn copy_in(&self, statement: &T) -> PSQLPyResult> - where - T: ?Sized + ToStatement, - U: Buf + 'static + Send, - { - match self { - PsqlpyConnection::PoolConn(pconn) => return Ok(pconn.copy_in(statement).await?), - PsqlpyConnection::SingleConn(sconn) => return Ok(sconn.copy_in(statement).await?), - } - } } diff --git a/src/driver/listener/core.rs b/src/driver/listener/core.rs index 16b323d8..4a9580af 100644 --- a/src/driver/listener/core.rs +++ b/src/driver/listener/core.rs @@ -42,14 +42,19 @@ pub struct Listener { impl Listener { #[must_use] - pub fn new(pg_config: Arc, ca_file: Option, ssl_mode: Option) -> Self { + pub fn new( + pg_config: Arc, + ca_file: Option, + ssl_mode: Option, + prepare: bool, + ) -> Self { Listener { pg_config: pg_config.clone(), ca_file, ssl_mode, channel_callbacks: Arc::default(), listen_abort_handler: Option::default(), - connection: Connection::new(None, None, pg_config.clone()), + connection: Connection::new(None, None, pg_config.clone(), prepare), receiver: Option::default(), listen_query: Arc::default(), is_listened: Arc::new(RwLock::new(false)), @@ -222,6 +227,7 @@ impl Listener { Some(Arc::new(PsqlpyConnection::SingleConn(client))), None, self.pg_config.clone(), + false, ); self.is_started = true; diff --git a/src/statement/cache.rs b/src/statement/cache.rs index a6fbc131..7d78898d 100644 --- a/src/statement/cache.rs +++ b/src/statement/cache.rs @@ -5,7 +5,7 @@ use postgres_types::Type; use tokio::sync::RwLock; use tokio_postgres::Statement; -use super::{query::QueryString, traits::hash_str}; +use super::{query::QueryString, utils::hash_str}; #[derive(Default)] pub(crate) struct StatementsCache(HashMap); diff --git a/src/statement/mod.rs b/src/statement/mod.rs index e027eaea..c894b9a8 100644 --- a/src/statement/mod.rs +++ b/src/statement/mod.rs @@ -3,5 +3,4 @@ pub mod parameters; pub mod query; pub mod statement; pub mod statement_builder; -pub mod traits; pub mod utils; diff --git a/src/statement/query.rs b/src/statement/query.rs index 7f87cede..2b08aa62 100644 --- a/src/statement/query.rs +++ b/src/statement/query.rs @@ -4,7 +4,7 @@ use regex::Regex; use crate::value_converter::consts::KWARGS_PARAMS_REGEXP; -use super::traits::hash_str; +use super::utils::hash_str; #[derive(Clone)] pub struct QueryString { diff --git a/src/statement/statement.rs b/src/statement/statement.rs index 4cfdc09c..a93d9cd5 100644 --- a/src/statement/statement.rs +++ b/src/statement/statement.rs @@ -1,4 +1,7 @@ use postgres_types::{ToSql, Type}; +use tokio_postgres::Statement; + +use crate::exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}; use super::{parameters::PreparedParameters, query::QueryString}; @@ -6,20 +9,33 @@ use super::{parameters::PreparedParameters, query::QueryString}; pub struct PsqlpyStatement { query: QueryString, prepared_parameters: PreparedParameters, + prepared_statement: Option, } impl PsqlpyStatement { - pub(crate) fn new(query: QueryString, prepared_parameters: PreparedParameters) -> Self { + pub(crate) fn new( + query: QueryString, + prepared_parameters: PreparedParameters, + prepared_statement: Option, + ) -> Self { Self { query, prepared_parameters, + prepared_statement, } } - pub fn sql_stmt(&self) -> &str { + pub fn raw_query(&self) -> &str { self.query.query() } + pub fn statement_query(&self) -> PSQLPyResult<&Statement> { + match &self.prepared_statement { + Some(prepared_stmt) => return Ok(prepared_stmt), + None => return Err(RustPSQLDriverError::ConnectionExecuteError("No".into())), + } + } + pub fn params(&self) -> Box<[&(dyn ToSql + Sync)]> { self.prepared_parameters.params() } diff --git a/src/statement/statement_builder.rs b/src/statement/statement_builder.rs index 07e003da..863ba400 100644 --- a/src/statement/statement_builder.rs +++ b/src/statement/statement_builder.rs @@ -1,10 +1,11 @@ use pyo3::PyObject; +use tokio::sync::RwLockWriteGuard; use tokio_postgres::Statement; use crate::{driver::inner_connection::PsqlpyConnection, exceptions::rust_errors::PSQLPyResult}; use super::{ - cache::{StatementCacheInfo, STMTS_CACHE}, + cache::{StatementCacheInfo, StatementsCache, STMTS_CACHE}, parameters::ParametersBuilder, query::QueryString, statement::PsqlpyStatement, @@ -33,14 +34,17 @@ impl<'a> StatementBuilder<'a> { } pub async fn build(self) -> PSQLPyResult { - { - let stmt_cache_guard = STMTS_CACHE.read().await; - if let Some(cached) = stmt_cache_guard.get_cache(&self.querystring) { - return self.build_with_cached(cached); + if !self.prepared { + { + let stmt_cache_guard = STMTS_CACHE.read().await; + if let Some(cached) = stmt_cache_guard.get_cache(&self.querystring) { + return self.build_with_cached(cached); + } } } - self.build_no_cached().await + let stmt_cache_guard = STMTS_CACHE.write().await; + self.build_no_cached(stmt_cache_guard).await } fn build_with_cached(self, cached: StatementCacheInfo) -> PSQLPyResult { @@ -54,21 +58,24 @@ impl<'a> StatementBuilder<'a> { let prepared_parameters = raw_parameters.prepare(parameters_names)?; - return Ok(PsqlpyStatement::new(cached.query, prepared_parameters)); + return Ok(PsqlpyStatement::new( + cached.query, + prepared_parameters, + None, + )); } - async fn build_no_cached(self) -> PSQLPyResult { + async fn build_no_cached( + self, + cache_guard: RwLockWriteGuard<'_, StatementsCache>, + ) -> PSQLPyResult { let mut querystring = QueryString::new(&self.querystring); querystring.process_qs(); - let prepared_stmt = self.prepare_query(&querystring).await?; + let prepared_stmt = self.prepare_query(&querystring, self.prepared).await?; let parameters_builder = ParametersBuilder::new(&self.parameters, Some(prepared_stmt.params().to_vec())); - if !self.prepared { - Self::drop_prepared(self.inner_conn, &prepared_stmt).await?; - } - let parameters_names = if let Some(converted_qs) = &querystring.converted_qs { Some(converted_qs.params_names().clone()) } else { @@ -77,24 +84,34 @@ impl<'a> StatementBuilder<'a> { let prepared_parameters = parameters_builder.prepare(parameters_names)?; - { - self.write_to_cache(&querystring, &prepared_stmt).await; + match self.prepared { + true => { + return Ok(PsqlpyStatement::new( + querystring, + prepared_parameters, + Some(prepared_stmt), + )) + } + false => { + { + self.write_to_cache(cache_guard, &querystring, &prepared_stmt) + .await; + } + return Ok(PsqlpyStatement::new(querystring, prepared_parameters, None)); + } } - let statement = PsqlpyStatement::new(querystring, prepared_parameters); - - return Ok(statement); - } - - async fn write_to_cache(&self, query: &QueryString, inner_stmt: &Statement) { - let mut stmt_cache_guard = STMTS_CACHE.write().await; - stmt_cache_guard.add_cache(query, inner_stmt); } - async fn prepare_query(&self, query: &QueryString) -> PSQLPyResult { - self.inner_conn.prepare(query.query()).await + async fn write_to_cache( + &self, + mut cache_guard: RwLockWriteGuard<'_, StatementsCache>, + query: &QueryString, + inner_stmt: &Statement, + ) { + cache_guard.add_cache(query, inner_stmt); } - async fn drop_prepared(inner_conn: &PsqlpyConnection, stmt: &Statement) -> PSQLPyResult<()> { - inner_conn.drop_prepared(stmt).await + async fn prepare_query(&self, query: &QueryString, prepared: bool) -> PSQLPyResult { + self.inner_conn.prepare(query.query(), prepared).await } } diff --git a/src/statement/traits.rs b/src/statement/traits.rs deleted file mode 100644 index a79f8bdd..00000000 --- a/src/statement/traits.rs +++ /dev/null @@ -1,8 +0,0 @@ -use std::hash::{DefaultHasher, Hash, Hasher}; - -pub(crate) fn hash_str(string: &String) -> u64 { - let mut hasher = DefaultHasher::new(); - string.hash(&mut hasher); - - hasher.finish() -} diff --git a/src/statement/utils.rs b/src/statement/utils.rs index 8b137891..a79f8bdd 100644 --- a/src/statement/utils.rs +++ b/src/statement/utils.rs @@ -1 +1,8 @@ +use std::hash::{DefaultHasher, Hash, Hasher}; +pub(crate) fn hash_str(string: &String) -> u64 { + let mut hasher = DefaultHasher::new(); + string.hash(&mut hasher); + + hasher.finish() +} diff --git a/src/value_converter/dto/converter_impls.rs b/src/value_converter/dto/converter_impls.rs index 1e6fa7be..f50529bc 100644 --- a/src/value_converter/dto/converter_impls.rs +++ b/src/value_converter/dto/converter_impls.rs @@ -8,7 +8,6 @@ use pyo3::{ Bound, PyAny, }; use rust_decimal::Decimal; -use serde::de::IntoDeserializer; use uuid::Uuid; use crate::{ diff --git a/src/value_converter/dto/impls.rs b/src/value_converter/dto/impls.rs index bd48ddb3..3450dfd0 100644 --- a/src/value_converter/dto/impls.rs +++ b/src/value_converter/dto/impls.rs @@ -244,30 +244,9 @@ impl ToSql for PythonDTO { } PythonDTO::PyList(py_iterable, type_) | PythonDTO::PyTuple(py_iterable, type_) => { return py_iterable.to_sql(type_, out); - // let mut items = Vec::new(); - // for inner in py_iterable { - // items.push(inner); - // } - // if items.is_empty() { - // return_is_null_true = true; - // } else { - // items.to_sql(&items[0].array_type()?, out)?; - // } } PythonDTO::PyArray(array, type_) => { return array.to_sql(type_, out); - // if let Some(first_elem) = array.iter().nth(0) { - // match first_elem.array_type() { - // Ok(ok_type) => { - // array.to_sql(&ok_type, out)?; - // } - // Err(_) => { - // return Err(RustPSQLDriverError::PyToRustValueConversionError( - // "Cannot define array type.".into(), - // ))? - // } - // } - // } } PythonDTO::PyJsonb(py_dict) | PythonDTO::PyJson(py_dict) => { <&Value as ToSql>::to_sql(&py_dict, ty, out)?; diff --git a/src/value_converter/from_python.rs b/src/value_converter/from_python.rs index 57307f29..fa1d5c60 100644 --- a/src/value_converter/from_python.rs +++ b/src/value_converter/from_python.rs @@ -195,7 +195,6 @@ pub fn from_python_typed( parameter: &pyo3::Bound<'_, PyAny>, type_: &Type, ) -> PSQLPyResult { - println!("{:?} {:?}", type_, parameter); if parameter.is_instance_of::() { return ::to_python_dto(parameter); } diff --git a/src/value_converter/models/decimal.rs b/src/value_converter/models/decimal.rs index 13d009cc..44a898a1 100644 --- a/src/value_converter/models/decimal.rs +++ b/src/value_converter/models/decimal.rs @@ -1,5 +1,5 @@ use postgres_types::{FromSql, Type}; -use pyo3::{types::PyAnyMethods, PyObject, Python, ToPyObject}; +use pyo3::{types::PyAnyMethods, Bound, IntoPyObject, PyAny, PyObject, Python, ToPyObject}; use rust_decimal::Decimal; use crate::value_converter::consts::get_decimal_cls; diff --git a/src/value_converter/to_python.rs b/src/value_converter/to_python.rs index b3bf2af5..047cd4c4 100644 --- a/src/value_converter/to_python.rs +++ b/src/value_converter/to_python.rs @@ -636,3 +636,41 @@ pub fn postgres_to_py( } Ok(py.None()) } + +/// Convert Python sequence to Rust vector. +/// Also it checks that sequence has set/list/tuple type. +/// +/// # Errors +/// +/// May return error if cannot convert Python type into Rust one. +/// May return error if parameters type isn't correct. +fn py_sequence_to_rust(bind_parameters: &Bound) -> PSQLPyResult>> { + let mut coord_values_sequence_vec: Vec> = vec![]; + + if bind_parameters.is_instance_of::() { + let bind_pyset_parameters = bind_parameters.downcast::().unwrap(); + + for one_parameter in bind_pyset_parameters { + let extracted_parameter = one_parameter.extract::>().map_err(|_| { + RustPSQLDriverError::PyToRustValueConversionError( + format!("Error on sequence type extraction, please use correct list/tuple/set, {bind_parameters}") + ) + })?; + coord_values_sequence_vec.push(extracted_parameter); + } + } else if bind_parameters.is_instance_of::() + | bind_parameters.is_instance_of::() + { + coord_values_sequence_vec = bind_parameters.extract::>>().map_err(|_| { + RustPSQLDriverError::PyToRustValueConversionError( + format!("Error on sequence type extraction, please use correct list/tuple/set, {bind_parameters}") + ) + })?; + } else { + return Err(RustPSQLDriverError::PyToRustValueConversionError(format!( + "Invalid sequence type, please use list/tuple/set, {bind_parameters}" + ))); + }; + + Ok::>, RustPSQLDriverError>(coord_values_sequence_vec) +} From 19fe58d52ffa03761e01fd207b4a2790d9e7f607 Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Sun, 4 May 2025 20:36:53 +0200 Subject: [PATCH 05/15] Full value converter refactor --- src/driver/connection.rs | 8 ++------ src/driver/connection_pool.rs | 7 +------ src/driver/inner_connection.rs | 32 ++++++++++++++++++------------ src/statement/query.rs | 4 ++-- src/statement/statement.rs | 2 +- src/statement/statement_builder.rs | 6 ++---- src/value_converter/to_python.rs | 1 - 7 files changed, 27 insertions(+), 33 deletions(-) diff --git a/src/driver/connection.rs b/src/driver/connection.rs index 469ece0b..d38b71f9 100644 --- a/src/driver/connection.rs +++ b/src/driver/connection.rs @@ -150,7 +150,6 @@ impl Connection { ) }); - let db_pool_2 = db_pool.clone(); if db_client.is_some() { return Ok(self_); } @@ -163,11 +162,8 @@ impl Connection { .await??; pyo3::Python::with_gil(|gil| { let mut self_ = self_.borrow_mut(gil); - self_.db_client = Some(Arc::new(PsqlpyConnection::PoolConn( - db_connection, - db_pool_2.unwrap(), - prepare, - ))); + self_.db_client = + Some(Arc::new(PsqlpyConnection::PoolConn(db_connection, prepare))); }); return Ok(self_); } diff --git a/src/driver/connection_pool.rs b/src/driver/connection_pool.rs index 1ef2d8f9..16454de0 100644 --- a/src/driver/connection_pool.rs +++ b/src/driver/connection_pool.rs @@ -407,7 +407,6 @@ impl ConnectionPool { let slf = self_.borrow(gil); (slf.pool.clone(), slf.pg_config.clone(), slf.prepare) }); - let db_pool_2 = db_pool.clone(); let db_connection = tokio_runtime() .spawn(async move { Ok::(db_pool.get().await?) @@ -415,11 +414,7 @@ impl ConnectionPool { .await??; Ok(Connection::new( - Some(Arc::new(PsqlpyConnection::PoolConn( - db_connection, - db_pool_2.clone(), - prepare, - ))), + Some(Arc::new(PsqlpyConnection::PoolConn(db_connection, prepare))), None, pg_config, prepare, diff --git a/src/driver/inner_connection.rs b/src/driver/inner_connection.rs index c463be64..797c9749 100644 --- a/src/driver/inner_connection.rs +++ b/src/driver/inner_connection.rs @@ -14,7 +14,7 @@ use crate::{ #[allow(clippy::module_name_repetitions)] pub enum PsqlpyConnection { - PoolConn(Object, Pool, bool), + PoolConn(Object, bool), SingleConn(Client), } @@ -25,13 +25,14 @@ impl PsqlpyConnection { /// May return Err if cannot prepare statement. pub async fn prepare(&self, query: &str, prepared: bool) -> PSQLPyResult { match self { - PsqlpyConnection::PoolConn(pconn, _, _) => { + PsqlpyConnection::PoolConn(pconn, _) => { if prepared { return Ok(pconn.prepare_cached(query).await?); } else { - println!("999999"); + pconn.batch_execute("BEGIN").await?; let prepared = pconn.prepare(query).await?; self.drop_prepared(&prepared).await?; + pconn.batch_execute("COMMIT").await?; return Ok(prepared); } } @@ -46,8 +47,9 @@ impl PsqlpyConnection { pub async fn drop_prepared(&self, stmt: &Statement) -> PSQLPyResult<()> { let deallocate_query = format!("DEALLOCATE PREPARE {}", stmt.name()); match self { - PsqlpyConnection::PoolConn(pconn, _, _) => { - return Ok(pconn.batch_execute(&deallocate_query).await?) + PsqlpyConnection::PoolConn(pconn, _) => { + let res = Ok(pconn.batch_execute(&deallocate_query).await?); + res } PsqlpyConnection::SingleConn(sconn) => { return Ok(sconn.batch_execute(&deallocate_query).await?) @@ -68,7 +70,7 @@ impl PsqlpyConnection { T: ?Sized + ToStatement, { match self { - PsqlpyConnection::PoolConn(pconn, _, _) => { + PsqlpyConnection::PoolConn(pconn, _) => { return Ok(pconn.query(statement, params).await?) } PsqlpyConnection::SingleConn(sconn) => { @@ -87,7 +89,7 @@ impl PsqlpyConnection { params: &[(&(dyn ToSql + Sync), Type)], ) -> PSQLPyResult> { match self { - PsqlpyConnection::PoolConn(pconn, _, _) => { + PsqlpyConnection::PoolConn(pconn, _) => { return Ok(pconn.query_typed(statement, params).await?) } PsqlpyConnection::SingleConn(sconn) => { @@ -102,9 +104,7 @@ impl PsqlpyConnection { /// May return Err if cannot execute statement. pub async fn batch_execute(&self, query: &str) -> PSQLPyResult<()> { match self { - PsqlpyConnection::PoolConn(pconn, _, _) => { - return Ok(pconn.batch_execute(query).await?) - } + PsqlpyConnection::PoolConn(pconn, _) => return Ok(pconn.batch_execute(query).await?), PsqlpyConnection::SingleConn(sconn) => return Ok(sconn.batch_execute(query).await?), } } @@ -119,7 +119,7 @@ impl PsqlpyConnection { U: Buf + 'static + Send, { match self { - PsqlpyConnection::PoolConn(pconn, _, _) => return Ok(pconn.copy_in(statement).await?), + PsqlpyConnection::PoolConn(pconn, _) => return Ok(pconn.copy_in(statement).await?), PsqlpyConnection::SingleConn(sconn) => return Ok(sconn.copy_in(statement).await?), } } @@ -137,7 +137,7 @@ impl PsqlpyConnection { T: ?Sized + ToStatement, { match self { - PsqlpyConnection::PoolConn(pconn, _, _) => { + PsqlpyConnection::PoolConn(pconn, _) => { return Ok(pconn.query_one(statement, params).await?) } PsqlpyConnection::SingleConn(sconn) => { @@ -202,8 +202,14 @@ impl PsqlpyConnection { "Cannot prepare statement, error - {err}" )) })?, + // false => { + // self + // .query_typed(statement.raw_query(), &statement.params_typed()) + // .await + // .map_err(|err| RustPSQLDriverError::ConnectionExecuteError(format!("{err}")))? + // }, false => self - .query_typed(statement.raw_query(), &statement.params_typed()) + .query_typed("SELECT * FROM users", &[]) .await .map_err(|err| RustPSQLDriverError::ConnectionExecuteError(format!("{err}")))?, }; diff --git a/src/statement/query.rs b/src/statement/query.rs index 2b08aa62..108fe756 100644 --- a/src/statement/query.rs +++ b/src/statement/query.rs @@ -6,7 +6,7 @@ use crate::value_converter::consts::KWARGS_PARAMS_REGEXP; use super::utils::hash_str; -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct QueryString { pub(crate) initial_qs: String, // This field are used when kwargs passed @@ -68,7 +68,7 @@ impl QueryString { } } -#[derive(Clone)] +#[derive(Clone, Debug)] pub(crate) struct ConvertedQueryString { converted_qs: String, params_names: Vec, diff --git a/src/statement/statement.rs b/src/statement/statement.rs index a93d9cd5..addaae89 100644 --- a/src/statement/statement.rs +++ b/src/statement/statement.rs @@ -5,7 +5,7 @@ use crate::exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}; use super::{parameters::PreparedParameters, query::QueryString}; -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct PsqlpyStatement { query: QueryString, prepared_parameters: PreparedParameters, diff --git a/src/statement/statement_builder.rs b/src/statement/statement_builder.rs index 863ba400..5954f88c 100644 --- a/src/statement/statement_builder.rs +++ b/src/statement/statement_builder.rs @@ -93,10 +93,8 @@ impl<'a> StatementBuilder<'a> { )) } false => { - { - self.write_to_cache(cache_guard, &querystring, &prepared_stmt) - .await; - } + self.write_to_cache(cache_guard, &querystring, &prepared_stmt) + .await; return Ok(PsqlpyStatement::new(querystring, prepared_parameters, None)); } } diff --git a/src/value_converter/to_python.rs b/src/value_converter/to_python.rs index 047cd4c4..3d65565b 100644 --- a/src/value_converter/to_python.rs +++ b/src/value_converter/to_python.rs @@ -23,7 +23,6 @@ use crate::{ Circle, Line, RustLineSegment, RustLineString, RustMacAddr6, RustMacAddr8, RustPoint, RustRect, }, - consts::KWARGS_QUERYSTRINGS, models::{ decimal::InnerDecimal, interval::InnerInterval, serde_value::InternalSerdeValue, uuid::InternalUuid, From 3e89ffd50bab16355228e4d8ae26db3fcf9c0715 Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Sun, 4 May 2025 20:38:12 +0200 Subject: [PATCH 06/15] Full value converter refactor --- src/driver/inner_connection.rs | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/driver/inner_connection.rs b/src/driver/inner_connection.rs index 797c9749..5b28d12b 100644 --- a/src/driver/inner_connection.rs +++ b/src/driver/inner_connection.rs @@ -202,14 +202,8 @@ impl PsqlpyConnection { "Cannot prepare statement, error - {err}" )) })?, - // false => { - // self - // .query_typed(statement.raw_query(), &statement.params_typed()) - // .await - // .map_err(|err| RustPSQLDriverError::ConnectionExecuteError(format!("{err}")))? - // }, false => self - .query_typed("SELECT * FROM users", &[]) + .query_typed(statement.raw_query(), &statement.params_typed()) .await .map_err(|err| RustPSQLDriverError::ConnectionExecuteError(format!("{err}")))?, }; From 8453ab7128d4d61a27bc941bdd9921b9fed4a704 Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Sun, 4 May 2025 21:06:13 +0200 Subject: [PATCH 07/15] Full value converter refactor --- python/tests/test_value_converter.py | 5 +++++ src/driver/inner_connection.rs | 2 -- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/python/tests/test_value_converter.py b/python/tests/test_value_converter.py index b0ec5c8d..022afea2 100644 --- a/python/tests/test_value_converter.py +++ b/python/tests/test_value_converter.py @@ -1187,6 +1187,11 @@ async def test_empty_array( VarCharArray([]), [], ), + ( + "VARCHAR ARRAY", + [], + [], + ), ( "TEXT ARRAY", TextArray([]), diff --git a/src/driver/inner_connection.rs b/src/driver/inner_connection.rs index 5b28d12b..c671229a 100644 --- a/src/driver/inner_connection.rs +++ b/src/driver/inner_connection.rs @@ -29,10 +29,8 @@ impl PsqlpyConnection { if prepared { return Ok(pconn.prepare_cached(query).await?); } else { - pconn.batch_execute("BEGIN").await?; let prepared = pconn.prepare(query).await?; self.drop_prepared(&prepared).await?; - pconn.batch_execute("COMMIT").await?; return Ok(prepared); } } From 284cc344c278a70d0e892c5ae87ffc1c51b2497c Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Sun, 4 May 2025 22:58:57 +0200 Subject: [PATCH 08/15] Full value converter refactor --- python/tests/test_value_converter.py | 685 +++++++++----------------- src/driver/common_options.rs | 2 +- src/driver/connection_pool.rs | 73 +-- src/driver/connection_pool_builder.rs | 9 - 4 files changed, 266 insertions(+), 503 deletions(-) diff --git a/python/tests/test_value_converter.py b/python/tests/test_value_converter.py index 022afea2..ce2f05ed 100644 --- a/python/tests/test_value_converter.py +++ b/python/tests/test_value_converter.py @@ -239,448 +239,6 @@ async def test_as_class( datetime.timedelta(days=100, microseconds=100), datetime.timedelta(days=100, microseconds=100), ), - ( - "VARCHAR ARRAY", - ["Some String", "Some String"], - ["Some String", "Some String"], - ), - ( - "TEXT ARRAY", - [Text("Some String"), Text("Some String")], - ["Some String", "Some String"], - ), - ("BOOL ARRAY", [True, False], [True, False]), - ("BOOL ARRAY", [[True], [False]], [[True], [False]]), - ("INT2 ARRAY", [SmallInt(12), SmallInt(100)], [12, 100]), - ("INT2 ARRAY", [[SmallInt(12)], [SmallInt(100)]], [[12], [100]]), - ("INT4 ARRAY", [Integer(121231231), Integer(121231231)], [121231231, 121231231]), - ( - "INT4 ARRAY", - [[Integer(121231231)], [Integer(121231231)]], - [[121231231], [121231231]], - ), - ( - "INT8 ARRAY", - [BigInt(99999999999999999), BigInt(99999999999999999)], - [99999999999999999, 99999999999999999], - ), - ( - "INT8 ARRAY", - [[BigInt(99999999999999999)], [BigInt(99999999999999999)]], - [[99999999999999999], [99999999999999999]], - ), - ( - "MONEY ARRAY", - [Money(99999999999999999), Money(99999999999999999)], - [99999999999999999, 99999999999999999], - ), - ( - "NUMERIC(5, 2) ARRAY", - [Decimal("121.23"), Decimal("188.99")], - [Decimal("121.23"), Decimal("188.99")], - ), - ( - "NUMERIC(5, 2) ARRAY", - [[Decimal("121.23")], [Decimal("188.99")]], - [[Decimal("121.23")], [Decimal("188.99")]], - ), - ( - "FLOAT8 ARRAY", - [32.12329864501953, 32.12329864501953], - [32.12329864501953, 32.12329864501953], - ), - ( - "FLOAT8 ARRAY", - [[32.12329864501953], [32.12329864501953]], - [[32.12329864501953], [32.12329864501953]], - ), - ( - "DATE ARRAY", - [now_datetime.date(), now_datetime.date()], - [now_datetime.date(), now_datetime.date()], - ), - ( - "DATE ARRAY", - [[now_datetime.date()], [now_datetime.date()]], - [[now_datetime.date()], [now_datetime.date()]], - ), - ( - "TIME ARRAY", - [now_datetime.time(), now_datetime.time()], - [now_datetime.time(), now_datetime.time()], - ), - ( - "TIME ARRAY", - [[now_datetime.time()], [now_datetime.time()]], - [[now_datetime.time()], [now_datetime.time()]], - ), - ("TIMESTAMP ARRAY", [now_datetime, now_datetime], [now_datetime, now_datetime]), - ( - "TIMESTAMP ARRAY", - [[now_datetime], [now_datetime]], - [[now_datetime], [now_datetime]], - ), - ( - "TIMESTAMPTZ ARRAY", - [now_datetime_with_tz, now_datetime_with_tz], - [now_datetime_with_tz, now_datetime_with_tz], - ), - ( - "TIMESTAMPTZ ARRAY", - [now_datetime_with_tz, now_datetime_with_tz_in_asia_jakarta], - [now_datetime_with_tz, now_datetime_with_tz_in_asia_jakarta], - ), - ( - "TIMESTAMPTZ ARRAY", - [[now_datetime_with_tz], [now_datetime_with_tz]], - [[now_datetime_with_tz], [now_datetime_with_tz]], - ), - ( - "UUID ARRAY", - [uuid_, uuid_], - [str(uuid_), str(uuid_)], - ), - ( - "UUID ARRAY", - [[uuid_], [uuid_]], - [[str(uuid_)], [str(uuid_)]], - ), - ( - "INET ARRAY", - [IPv4Address("192.0.0.1"), IPv4Address("192.0.0.1")], - [IPv4Address("192.0.0.1"), IPv4Address("192.0.0.1")], - ), - ( - "INET ARRAY", - [[IPv4Address("192.0.0.1")], [IPv4Address("192.0.0.1")]], - [[IPv4Address("192.0.0.1")], [IPv4Address("192.0.0.1")]], - ), - ( - "JSONB ARRAY", - [ - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - ], - [ - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - ], - ), - ( - "JSONB ARRAY", - [ - [ - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - ], - [ - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - ], - ], - [ - [ - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - ], - [ - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - ], - ], - ), - ( - "JSONB ARRAY", - [ - JSONB([{"array": "json"}, {"one more": "test"}]), - JSONB([{"array": "json"}, {"one more": "test"}]), - ], - [ - [{"array": "json"}, {"one more": "test"}], - [{"array": "json"}, {"one more": "test"}], - ], - ), - ( - "JSONB ARRAY", - [ - JSONB([[{"array": "json"}], [{"one more": "test"}]]), - JSONB([[{"array": "json"}], [{"one more": "test"}]]), - ], - [ - [[{"array": "json"}], [{"one more": "test"}]], - [[{"array": "json"}], [{"one more": "test"}]], - ], - ), - ( - "JSON ARRAY", - [ - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - ], - [ - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - ], - ), - ( - "JSON ARRAY", - [ - JSON( - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - ), - JSON( - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - ), - ], - [ - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - ], - ), - ( - "JSON ARRAY", - [ - [ - JSON( - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - ), - ], - [ - JSON( - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - ), - ], - ], - [ - [ - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - ], - [ - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - ], - ], - ), - ( - "JSON ARRAY", - [ - JSON([{"array": "json"}, {"one more": "test"}]), - JSON([{"array": "json"}, {"one more": "test"}]), - ], - [ - [{"array": "json"}, {"one more": "test"}], - [{"array": "json"}, {"one more": "test"}], - ], - ), - ( - "JSON ARRAY", - [ - JSON([[{"array": "json"}], [{"one more": "test"}]]), - JSON([[{"array": "json"}], [{"one more": "test"}]]), - ], - [ - [[{"array": "json"}], [{"one more": "test"}]], - [[{"array": "json"}], [{"one more": "test"}]], - ], - ), - ( - "POINT ARRAY", - [ - Point([1.5, 2]), - Point([2, 3]), - ], - [ - (1.5, 2.0), - (2.0, 3.0), - ], - ), - ( - "POINT ARRAY", - [ - [Point([1.5, 2])], - [Point([2, 3])], - ], - [ - [(1.5, 2.0)], - [(2.0, 3.0)], - ], - ), - ( - "BOX ARRAY", - [ - Box([3.5, 3, 9, 9]), - Box([8.5, 8, 9, 9]), - ], - [ - ((9.0, 9.0), (3.5, 3.0)), - ((9.0, 9.0), (8.5, 8.0)), - ], - ), - ( - "BOX ARRAY", - [ - [Box([3.5, 3, 9, 9])], - [Box([8.5, 8, 9, 9])], - ], - [ - [((9.0, 9.0), (3.5, 3.0))], - [((9.0, 9.0), (8.5, 8.0))], - ], - ), - ( - "PATH ARRAY", - [ - Path([(3.5, 3), (9, 9), (8, 8)]), - Path([(3.5, 3), (6, 6), (3.5, 3)]), - ], - [ - [(3.5, 3.0), (9.0, 9.0), (8.0, 8.0)], - ((3.5, 3.0), (6.0, 6.0), (3.5, 3.0)), - ], - ), - ( - "PATH ARRAY", - [ - [Path([(3.5, 3), (9, 9), (8, 8)])], - [Path([(3.5, 3), (6, 6), (3.5, 3)])], - ], - [ - [[(3.5, 3.0), (9.0, 9.0), (8.0, 8.0)]], - [((3.5, 3.0), (6.0, 6.0), (3.5, 3.0))], - ], - ), - ( - "LINE ARRAY", - [ - Line([-2, 1, 2]), - Line([1, -2, 3]), - ], - [ - (-2.0, 1.0, 2.0), - (1.0, -2.0, 3.0), - ], - ), - ( - "LINE ARRAY", - [ - [Line([-2, 1, 2])], - [Line([1, -2, 3])], - ], - [ - [(-2.0, 1.0, 2.0)], - [(1.0, -2.0, 3.0)], - ], - ), - ( - "LSEG ARRAY", - [ - LineSegment({(1, 2), (9, 9)}), - LineSegment([(5.6, 3.1), (4, 5)]), - ], - [ - [(1.0, 2.0), (9.0, 9.0)], - [(5.6, 3.1), (4.0, 5.0)], - ], - ), - ( - "LSEG ARRAY", - [ - [LineSegment({(1, 2), (9, 9)})], - [LineSegment([(5.6, 3.1), (4, 5)])], - ], - [ - [[(1.0, 2.0), (9.0, 9.0)]], - [[(5.6, 3.1), (4.0, 5.0)]], - ], - ), - ( - "CIRCLE ARRAY", - [ - Circle([1.7, 2.8, 3]), - Circle([5, 1.8, 10]), - ], - [ - ((1.7, 2.8), 3.0), - ((5.0, 1.8), 10.0), - ], - ), - ( - "CIRCLE ARRAY", - [ - [Circle([1.7, 2.8, 3])], - [Circle([5, 1.8, 10])], - ], - [ - [((1.7, 2.8), 3.0)], - [((5.0, 1.8), 10.0)], - ], - ), - ( - "INTERVAL ARRAY", - [ - datetime.timedelta(days=100, microseconds=100), - datetime.timedelta(days=100, microseconds=100), - ], - [ - datetime.timedelta(days=100, microseconds=100), - datetime.timedelta(days=100, microseconds=100), - ], - ), ], ) async def test_deserialization_simple_into_python( @@ -1177,37 +735,29 @@ async def test_empty_array( @pytest.mark.parametrize( ("postgres_type", "py_value", "expected_deserialized"), [ + ("VARCHAR ARRAY", [], []), ( "VARCHAR ARRAY", VarCharArray(["Some String", "Some String"]), ["Some String", "Some String"], ), - ( - "VARCHAR ARRAY", - VarCharArray([]), - [], - ), - ( - "VARCHAR ARRAY", - [], - [], - ), - ( - "TEXT ARRAY", - TextArray([]), - [], - ), + ("VARCHAR ARRAY", VarCharArray([]), []), + ("TEXT ARRAY", [], []), + ("TEXT ARRAY", TextArray([]), []), ( "TEXT ARRAY", TextArray([Text("Some String"), Text("Some String")]), ["Some String", "Some String"], ), + ("BOOL ARRAY", [], []), ("BOOL ARRAY", BoolArray([]), []), ("BOOL ARRAY", BoolArray([True, False]), [True, False]), ("BOOL ARRAY", BoolArray([[True], [False]]), [[True], [False]]), + ("INT2 ARRAY", [], []), ("INT2 ARRAY", Int16Array([]), []), ("INT2 ARRAY", Int16Array([SmallInt(12), SmallInt(100)]), [12, 100]), ("INT2 ARRAY", Int16Array([[SmallInt(12)], [SmallInt(100)]]), [[12], [100]]), + ("INT4 ARRAY", [], []), ( "INT4 ARRAY", Int32Array([Integer(121231231), Integer(121231231)]), @@ -1218,6 +768,7 @@ async def test_empty_array( Int32Array([[Integer(121231231)], [Integer(121231231)]]), [[121231231], [121231231]], ), + ("INT8 ARRAY", [], []), ( "INT8 ARRAY", Int64Array([BigInt(99999999999999999), BigInt(99999999999999999)]), @@ -1228,11 +779,13 @@ async def test_empty_array( Int64Array([[BigInt(99999999999999999)], [BigInt(99999999999999999)]]), [[99999999999999999], [99999999999999999]], ), + ("MONEY ARRAY", [], []), ( "MONEY ARRAY", MoneyArray([Money(99999999999999999), Money(99999999999999999)]), [99999999999999999, 99999999999999999], ), + ("NUMERIC(5, 2) ARRAY", [], []), ( "NUMERIC(5, 2) ARRAY", NumericArray([Decimal("121.23"), Decimal("188.99")]), @@ -1243,6 +796,13 @@ async def test_empty_array( NumericArray([[Decimal("121.23")], [Decimal("188.99")]]), [[Decimal("121.23")], [Decimal("188.99")]], ), + ("FLOAT4 ARRAY", [], []), + ( + "FLOAT4 ARRAY", + [32.12329864501953, 32.12329864501953], + [32.12329864501953, 32.12329864501953], + ), + ("FLOAT8 ARRAY", [], []), ( "FLOAT8 ARRAY", Float64Array([32.12329864501953, 32.12329864501953]), @@ -1253,6 +813,7 @@ async def test_empty_array( Float64Array([[32.12329864501953], [32.12329864501953]]), [[32.12329864501953], [32.12329864501953]], ), + ("DATE ARRAY", [], []), ( "DATE ARRAY", DateArray([now_datetime.date(), now_datetime.date()]), @@ -1263,6 +824,7 @@ async def test_empty_array( DateArray([[now_datetime.date()], [now_datetime.date()]]), [[now_datetime.date()], [now_datetime.date()]], ), + ("TIME ARRAY", [], []), ( "TIME ARRAY", TimeArray([now_datetime.time(), now_datetime.time()]), @@ -1273,6 +835,7 @@ async def test_empty_array( TimeArray([[now_datetime.time()], [now_datetime.time()]]), [[now_datetime.time()], [now_datetime.time()]], ), + ("TIMESTAMP ARRAY", [], []), ( "TIMESTAMP ARRAY", DateTimeArray([now_datetime, now_datetime]), @@ -1283,6 +846,7 @@ async def test_empty_array( DateTimeArray([[now_datetime], [now_datetime]]), [[now_datetime], [now_datetime]], ), + ("TIMESTAMPTZ ARRAY", [], []), ( "TIMESTAMPTZ ARRAY", DateTimeTZArray([now_datetime_with_tz, now_datetime_with_tz]), @@ -1293,16 +857,13 @@ async def test_empty_array( DateTimeTZArray([[now_datetime_with_tz], [now_datetime_with_tz]]), [[now_datetime_with_tz], [now_datetime_with_tz]], ), - ( - "UUID ARRAY", - UUIDArray([uuid_, uuid_]), - [str(uuid_), str(uuid_)], - ), + ("UUID ARRAY", [], []), ( "UUID ARRAY", UUIDArray([[uuid_], [uuid_]]), [[str(uuid_)], [str(uuid_)]], ), + ("INET ARRAY", [], []), ( "INET ARRAY", IpAddressArray([IPv4Address("192.0.0.1"), IPv4Address("192.0.0.1")]), @@ -1313,6 +874,30 @@ async def test_empty_array( IpAddressArray([[IPv4Address("192.0.0.1")], [IPv4Address("192.0.0.1")]]), [[IPv4Address("192.0.0.1")], [IPv4Address("192.0.0.1")]], ), + ("JSONB ARRAY", [], []), + ( + "JSONB ARRAY", + [ + { + "test": ["something", 123, "here"], + "nested": ["JSON"], + }, + { + "test": ["something", 123, "here"], + "nested": ["JSON"], + }, + ], + [ + { + "test": ["something", 123, "here"], + "nested": ["JSON"], + }, + { + "test": ["something", 123, "here"], + "nested": ["JSON"], + }, + ], + ), ( "JSONB ARRAY", JSONBArray( @@ -1397,6 +982,55 @@ async def test_empty_array( [[{"array": "json"}], [{"one more": "test"}]], ], ), + ("JSON ARRAY", [], []), + ( + "JSON ARRAY", + [ + { + "test": ["something", 123, "here"], + "nested": ["JSON"], + }, + { + "test": ["something", 123, "here"], + "nested": ["JSON"], + }, + ], + [ + { + "test": ["something", 123, "here"], + "nested": ["JSON"], + }, + { + "test": ["something", 123, "here"], + "nested": ["JSON"], + }, + ], + ), + ( + "JSON ARRAY", + JSONArray( + [ + { + "test": ["something", 123, "here"], + "nested": ["JSON"], + }, + { + "test": ["something", 123, "here"], + "nested": ["JSON"], + }, + ], + ), + [ + { + "test": ["something", 123, "here"], + "nested": ["JSON"], + }, + { + "test": ["something", 123, "here"], + "nested": ["JSON"], + }, + ], + ), ( "JSON ARRAY", JSONArray( @@ -1489,6 +1123,17 @@ async def test_empty_array( [[{"array": "json"}], [{"one more": "test"}]], ], ), + ( + "POINT ARRAY", + [ + Point([1.5, 2]), + Point([2, 3]), + ], + [ + (1.5, 2.0), + (2.0, 3.0), + ], + ), ( "POINT ARRAY", PointArray( @@ -1502,6 +1147,17 @@ async def test_empty_array( (2.0, 3.0), ], ), + ( + "POINT ARRAY", + [ + [Point([1.5, 2])], + [Point([2, 3])], + ], + [ + [(1.5, 2.0)], + [(2.0, 3.0)], + ], + ), ( "POINT ARRAY", PointArray( @@ -1515,6 +1171,18 @@ async def test_empty_array( [(2.0, 3.0)], ], ), + ("BOX ARRAY", [], []), + ( + "BOX ARRAY", + [ + Box([3.5, 3, 9, 9]), + Box([8.5, 8, 9, 9]), + ], + [ + ((9.0, 9.0), (3.5, 3.0)), + ((9.0, 9.0), (8.5, 8.0)), + ], + ), ( "BOX ARRAY", BoxArray( @@ -1541,6 +1209,18 @@ async def test_empty_array( [((9.0, 9.0), (8.5, 8.0))], ], ), + ("PATH ARRAY", [], []), + ( + "PATH ARRAY", + [ + Path([(3.5, 3), (9, 9), (8, 8)]), + Path([(3.5, 3), (6, 6), (3.5, 3)]), + ], + [ + [(3.5, 3.0), (9.0, 9.0), (8.0, 8.0)], + ((3.5, 3.0), (6.0, 6.0), (3.5, 3.0)), + ], + ), ( "PATH ARRAY", PathArray( @@ -1554,6 +1234,17 @@ async def test_empty_array( ((3.5, 3.0), (6.0, 6.0), (3.5, 3.0)), ], ), + ( + "PATH ARRAY", + [ + [Path([(3.5, 3), (9, 9), (8, 8)])], + [Path([(3.5, 3), (6, 6), (3.5, 3)])], + ], + [ + [[(3.5, 3.0), (9.0, 9.0), (8.0, 8.0)]], + [((3.5, 3.0), (6.0, 6.0), (3.5, 3.0))], + ], + ), ( "PATH ARRAY", PathArray( @@ -1567,6 +1258,18 @@ async def test_empty_array( [((3.5, 3.0), (6.0, 6.0), (3.5, 3.0))], ], ), + ("LINE ARRAY", [], []), + ( + "LINE ARRAY", + [ + Line([-2, 1, 2]), + Line([1, -2, 3]), + ], + [ + (-2.0, 1.0, 2.0), + (1.0, -2.0, 3.0), + ], + ), ( "LINE ARRAY", LineArray( @@ -1580,6 +1283,17 @@ async def test_empty_array( (1.0, -2.0, 3.0), ], ), + ( + "LINE ARRAY", + [ + [Line([-2, 1, 2])], + [Line([1, -2, 3])], + ], + [ + [(-2.0, 1.0, 2.0)], + [(1.0, -2.0, 3.0)], + ], + ), ( "LINE ARRAY", LineArray( @@ -1593,6 +1307,18 @@ async def test_empty_array( [(1.0, -2.0, 3.0)], ], ), + ("LSEG ARRAY", [], []), + ( + "LSEG ARRAY", + [ + LineSegment({(1, 2), (9, 9)}), + LineSegment([(5.6, 3.1), (4, 5)]), + ], + [ + [(1.0, 2.0), (9.0, 9.0)], + [(5.6, 3.1), (4.0, 5.0)], + ], + ), ( "LSEG ARRAY", LsegArray( @@ -1606,6 +1332,17 @@ async def test_empty_array( [(5.6, 3.1), (4.0, 5.0)], ], ), + ( + "LSEG ARRAY", + [ + [LineSegment({(1, 2), (9, 9)})], + [LineSegment([(5.6, 3.1), (4, 5)])], + ], + [ + [[(1.0, 2.0), (9.0, 9.0)]], + [[(5.6, 3.1), (4.0, 5.0)]], + ], + ), ( "LSEG ARRAY", LsegArray( @@ -1619,6 +1356,18 @@ async def test_empty_array( [[(5.6, 3.1), (4.0, 5.0)]], ], ), + ("CIRCLE ARRAY", [], []), + ( + "CIRCLE ARRAY", + [ + Circle([1.7, 2.8, 3]), + Circle([5, 1.8, 10]), + ], + [ + ((1.7, 2.8), 3.0), + ((5.0, 1.8), 10.0), + ], + ), ( "CIRCLE ARRAY", CircleArray( @@ -1645,6 +1394,18 @@ async def test_empty_array( [((5.0, 1.8), 10.0)], ], ), + ("INTERVAL ARRAY", [], []), + ( + "INTERVAL ARRAY", + [ + [datetime.timedelta(days=100, microseconds=100)], + [datetime.timedelta(days=100, microseconds=100)], + ], + [ + [datetime.timedelta(days=100, microseconds=100)], + [datetime.timedelta(days=100, microseconds=100)], + ], + ), ( "INTERVAL ARRAY", IntervalArray( diff --git a/src/driver/common_options.rs b/src/driver/common_options.rs index aebc5837..a76d37dd 100644 --- a/src/driver/common_options.rs +++ b/src/driver/common_options.rs @@ -64,7 +64,7 @@ impl TargetSessionAttrs { } #[pyclass(eq, eq_int)] -#[derive(Clone, Copy, PartialEq)] +#[derive(Clone, Copy, PartialEq, Debug)] pub enum SslMode { /// Do not use TLS. Disable, diff --git a/src/driver/connection_pool.rs b/src/driver/connection_pool.rs index 16454de0..aa897012 100644 --- a/src/driver/connection_pool.rs +++ b/src/driver/connection_pool.rs @@ -15,6 +15,23 @@ use super::{ utils::{build_connection_config, build_manager, build_tls}, }; +#[derive(Debug, Clone)] +pub struct ConnectionPoolConf { + pub ca_file: Option, + pub ssl_mode: Option, + pub prepare: bool, +} + +impl ConnectionPoolConf { + fn new(ca_file: Option, ssl_mode: Option, prepare: bool) -> Self { + Self { + ca_file, + ssl_mode, + prepare, + } + } +} + /// Make new connection pool. /// /// # Errors @@ -47,7 +64,6 @@ use super::{ ca_file=None, max_db_pool_size=None, conn_recycling_method=None, - prepare=None, ))] #[allow(clippy::too_many_arguments)] pub fn connect( @@ -77,7 +93,6 @@ pub fn connect( ca_file: Option, max_db_pool_size: Option, conn_recycling_method: Option, - prepare: Option, ) -> PSQLPyResult { if let Some(max_db_pool_size) = max_db_pool_size { if max_db_pool_size < 2 { @@ -137,13 +152,9 @@ pub fn connect( let pool = db_pool_builder.build()?; - Ok(ConnectionPool { - pool: pool, - pg_config: Arc::new(pg_config), - ca_file: ca_file, - ssl_mode: ssl_mode, - prepare: prepare.unwrap_or(true), - }) + Ok(ConnectionPool::build( + pool, pg_config, ca_file, ssl_mode, None, + )) } #[pyclass] @@ -209,9 +220,7 @@ impl ConnectionPoolStatus { pub struct ConnectionPool { pool: Pool, pg_config: Arc, - ca_file: Option, - ssl_mode: Option, - prepare: bool, + pool_conf: ConnectionPoolConf, } impl ConnectionPool { @@ -226,9 +235,7 @@ impl ConnectionPool { ConnectionPool { pool: pool, pg_config: Arc::new(pg_config), - ca_file: ca_file, - ssl_mode: ssl_mode, - prepare: prepare.unwrap_or(true), + pool_conf: ConnectionPoolConf::new(ca_file, ssl_mode, prepare.unwrap_or(true)), } } @@ -271,7 +278,6 @@ impl ConnectionPool { conn_recycling_method=None, ssl_mode=None, ca_file=None, - prepare=None, ))] #[allow(clippy::too_many_arguments)] pub fn new( @@ -301,7 +307,6 @@ impl ConnectionPool { conn_recycling_method: Option, ssl_mode: Option, ca_file: Option, - prepare: Option, ) -> PSQLPyResult { connect( dsn, @@ -330,7 +335,6 @@ impl ConnectionPool { ca_file, max_db_pool_size, conn_recycling_method, - prepare, ) } @@ -378,24 +382,24 @@ impl ConnectionPool { None, Some(self.pool.clone()), self.pg_config.clone(), - self.prepare, + self.pool_conf.prepare, ) } #[must_use] #[allow(clippy::needless_pass_by_value)] pub fn listener(self_: pyo3::Py) -> Listener { - let (pg_config, ca_file, ssl_mode, prepare) = pyo3::Python::with_gil(|gil| { + let (pg_config, pool_conf) = pyo3::Python::with_gil(|gil| { let b_gil = self_.borrow(gil); - ( - b_gil.pg_config.clone(), - b_gil.ca_file.clone(), - b_gil.ssl_mode, - b_gil.prepare, - ) + (b_gil.pg_config.clone(), b_gil.pool_conf.clone()) }); - Listener::new(pg_config, ca_file, ssl_mode, prepare) + Listener::new( + pg_config, + pool_conf.ca_file, + pool_conf.ssl_mode, + pool_conf.prepare, + ) } /// Return new single connection. @@ -403,9 +407,13 @@ impl ConnectionPool { /// # Errors /// May return Err Result if cannot get new connection from the pool. pub async fn connection(self_: pyo3::Py) -> PSQLPyResult { - let (db_pool, pg_config, prepare) = pyo3::Python::with_gil(|gil| { + let (db_pool, pg_config, pool_conf) = pyo3::Python::with_gil(|gil| { let slf = self_.borrow(gil); - (slf.pool.clone(), slf.pg_config.clone(), slf.prepare) + ( + slf.pool.clone(), + slf.pg_config.clone(), + slf.pool_conf.clone(), + ) }); let db_connection = tokio_runtime() .spawn(async move { @@ -414,10 +422,13 @@ impl ConnectionPool { .await??; Ok(Connection::new( - Some(Arc::new(PsqlpyConnection::PoolConn(db_connection, prepare))), + Some(Arc::new(PsqlpyConnection::PoolConn( + db_connection, + pool_conf.prepare, + ))), None, pg_config, - prepare, + pool_conf.prepare, )) } diff --git a/src/driver/connection_pool_builder.rs b/src/driver/connection_pool_builder.rs index ea311642..0cd7432b 100644 --- a/src/driver/connection_pool_builder.rs +++ b/src/driver/connection_pool_builder.rs @@ -83,15 +83,6 @@ impl ConnectionPoolBuilder { self_ } - /// Set ca_file for ssl_mode in PostgreSQL. - fn prepare(self_: Py, prepare: bool) -> Py { - Python::with_gil(|gil| { - let mut self_ = self_.borrow_mut(gil); - self_.prepare = Some(prepare); - }); - self_ - } - /// Set size to the connection pool. /// /// # Error From b5bfec2aac8addb161b80644bc856c73f325dedf Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Sun, 4 May 2025 23:04:20 +0200 Subject: [PATCH 09/15] Full value converter refactor --- src/driver/inner_connection.rs | 2 +- src/value_converter/to_python.rs | 38 -------------------------------- 2 files changed, 1 insertion(+), 39 deletions(-) diff --git a/src/driver/inner_connection.rs b/src/driver/inner_connection.rs index c671229a..d8acc4d8 100644 --- a/src/driver/inner_connection.rs +++ b/src/driver/inner_connection.rs @@ -1,5 +1,5 @@ use bytes::Buf; -use deadpool_postgres::{Object, Pool}; +use deadpool_postgres::Object; use postgres_types::{ToSql, Type}; use pyo3::{Py, PyAny, Python}; use std::vec; diff --git a/src/value_converter/to_python.rs b/src/value_converter/to_python.rs index 3d65565b..c0801bac 100644 --- a/src/value_converter/to_python.rs +++ b/src/value_converter/to_python.rs @@ -635,41 +635,3 @@ pub fn postgres_to_py( } Ok(py.None()) } - -/// Convert Python sequence to Rust vector. -/// Also it checks that sequence has set/list/tuple type. -/// -/// # Errors -/// -/// May return error if cannot convert Python type into Rust one. -/// May return error if parameters type isn't correct. -fn py_sequence_to_rust(bind_parameters: &Bound) -> PSQLPyResult>> { - let mut coord_values_sequence_vec: Vec> = vec![]; - - if bind_parameters.is_instance_of::() { - let bind_pyset_parameters = bind_parameters.downcast::().unwrap(); - - for one_parameter in bind_pyset_parameters { - let extracted_parameter = one_parameter.extract::>().map_err(|_| { - RustPSQLDriverError::PyToRustValueConversionError( - format!("Error on sequence type extraction, please use correct list/tuple/set, {bind_parameters}") - ) - })?; - coord_values_sequence_vec.push(extracted_parameter); - } - } else if bind_parameters.is_instance_of::() - | bind_parameters.is_instance_of::() - { - coord_values_sequence_vec = bind_parameters.extract::>>().map_err(|_| { - RustPSQLDriverError::PyToRustValueConversionError( - format!("Error on sequence type extraction, please use correct list/tuple/set, {bind_parameters}") - ) - })?; - } else { - return Err(RustPSQLDriverError::PyToRustValueConversionError(format!( - "Invalid sequence type, please use list/tuple/set, {bind_parameters}" - ))); - }; - - Ok::>, RustPSQLDriverError>(coord_values_sequence_vec) -} From 0464162b0c296b8701f42bce529cef1bf48807b7 Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Mon, 5 May 2025 00:28:58 +0200 Subject: [PATCH 10/15] Added 14, 15, 16, 17 version of PostgreSQL to tests --- .github/workflows/test.yaml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 7d81fbd3..76dbdda8 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -36,6 +36,7 @@ jobs: strategy: matrix: py_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + postgres_version: ["14", "15", "16", "17"] job: - os: ubuntu-latest ssl_cmd: sudo apt-get update && sudo apt-get install libssl-dev openssl @@ -43,12 +44,13 @@ jobs: steps: - uses: actions/checkout@v1 - name: Setup Postgres - uses: ./.github/actions/setup_postgres/ + uses: ikalnytskyi/action-setup-postgres@v7 with: username: postgres password: postgres database: psqlpy_test - ssl_on: "on" + ssl: true + postgres-version: ${{ matrix.postgres_version }} id: postgres - uses: actions-rs/toolchain@v1 with: From 7c6d373e1c8e72698891e768751ee9fdab534430 Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Mon, 5 May 2025 00:35:11 +0200 Subject: [PATCH 11/15] Added 14, 15, 16, 17 version of PostgreSQL to tests --- .github/workflows/test.yaml | 2 +- python/tests/conftest.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 76dbdda8..f23ba46c 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -32,7 +32,7 @@ jobs: token: ${{ secrets.GITHUB_TOKEN }} args: -p psqlpy --all-features -- -W clippy::all -W clippy::pedantic pytest: - name: ${{matrix.job.os}}-${{matrix.py_version}} + name: ${{matrix.job.os}}-${{matrix.py_version}}-${{ matrix.postgres_version }} strategy: matrix: py_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] diff --git a/python/tests/conftest.py b/python/tests/conftest.py index 30426e5f..1ee7e9b4 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -85,7 +85,7 @@ def number_database_records() -> int: @pytest.fixture def ssl_cert_file() -> str: - return os.environ.get("POSTGRES_CERT_FILE", "./root.crt") + return os.environ.get("POSTGRES_CERT_FILE", "./server.crt") @pytest.fixture From 47a441c37cc89aa446ef3a0984f6fe166d161635 Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Mon, 5 May 2025 00:39:38 +0200 Subject: [PATCH 12/15] Added 14, 15, 16, 17 version of PostgreSQL to tests --- .github/workflows/test.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index f23ba46c..aee295cd 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -44,6 +44,7 @@ jobs: steps: - uses: actions/checkout@v1 - name: Setup Postgres + id: postgres uses: ikalnytskyi/action-setup-postgres@v7 with: username: postgres @@ -66,4 +67,6 @@ jobs: - name: Install tox run: pip install "tox-gh>=1.2,<2" - name: Run pytest + env: + POSTGRES_CERT_FILE: "${{ steps.postgres.outputs.certificate-path }}" run: tox -v -c tox.ini From 28c6d4ee4eba2eca5e4ebf5ac726bcbdcc3397f9 Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Mon, 5 May 2025 00:40:22 +0200 Subject: [PATCH 13/15] Added 14, 15, 16, 17 version of PostgreSQL to tests --- .github/workflows/test.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index aee295cd..eb8f6197 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -52,7 +52,6 @@ jobs: database: psqlpy_test ssl: true postgres-version: ${{ matrix.postgres_version }} - id: postgres - uses: actions-rs/toolchain@v1 with: toolchain: stable From 7a5362b9055450318e0cc512d51cc8bd9e1abe17 Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Mon, 5 May 2025 00:46:04 +0200 Subject: [PATCH 14/15] Added 14, 15, 16, 17 version of PostgreSQL to tests --- .github/workflows/test.yaml | 6 ++++-- python/tests/conftest.py | 5 ++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index eb8f6197..144e3bfd 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -35,8 +35,10 @@ jobs: name: ${{matrix.job.os}}-${{matrix.py_version}}-${{ matrix.postgres_version }} strategy: matrix: - py_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] - postgres_version: ["14", "15", "16", "17"] + # py_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + py_version: ["3.9"] + # postgres_version: ["14", "15", "16", "17"] + postgres_version: ["14"] job: - os: ubuntu-latest ssl_cmd: sudo apt-get update && sudo apt-get install libssl-dev openssl diff --git a/python/tests/conftest.py b/python/tests/conftest.py index 1ee7e9b4..a9bfc4d3 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -85,7 +85,10 @@ def number_database_records() -> int: @pytest.fixture def ssl_cert_file() -> str: - return os.environ.get("POSTGRES_CERT_FILE", "./server.crt") + return os.environ.get( + "POSTGRES_CERT_FILE", + "/home/runner/work/_temp/pgdata/server.crt", + ) @pytest.fixture From 936529ae75e97120c57819b2bbeacb0ff82321ec Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Mon, 5 May 2025 00:50:28 +0200 Subject: [PATCH 15/15] Added 14, 15, 16, 17 version of PostgreSQL to tests --- .github/workflows/test.yaml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 144e3bfd..eb8f6197 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -35,10 +35,8 @@ jobs: name: ${{matrix.job.os}}-${{matrix.py_version}}-${{ matrix.postgres_version }} strategy: matrix: - # py_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] - py_version: ["3.9"] - # postgres_version: ["14", "15", "16", "17"] - postgres_version: ["14"] + py_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + postgres_version: ["14", "15", "16", "17"] job: - os: ubuntu-latest ssl_cmd: sudo apt-get update && sudo apt-get install libssl-dev openssl