22# -*- coding: utf-8 -*-
33from typing import Any , Generic , Iterable , Sequence , Type
44
5- from sqlalchemy import Row , RowMapping , select
6- from sqlalchemy import delete as sa_delete
7- from sqlalchemy import update as sa_update
5+ from sqlalchemy import Row , RowMapping , Select , delete , select , update
86from sqlalchemy .ext .asyncio import AsyncSession
97
108from sqlalchemy_crud_plus .errors import MultipleResultsError
@@ -16,7 +14,13 @@ class CRUDPlus(Generic[Model]):
1614 def __init__ (self , model : Type [Model ]):
1715 self .model = model
1816
19- async def create_model (self , session : AsyncSession , obj : CreateSchema , commit : bool = False , ** kwargs ) -> Model :
17+ async def create_model (
18+ self ,
19+ session : AsyncSession ,
20+ obj : CreateSchema ,
21+ commit : bool = False ,
22+ ** kwargs ,
23+ ) -> Model :
2024 """
2125 Create a new instance of a model
2226
@@ -36,7 +40,10 @@ async def create_model(self, session: AsyncSession, obj: CreateSchema, commit: b
3640 return ins
3741
3842 async def create_models (
39- self , session : AsyncSession , obj : Iterable [CreateSchema ], commit : bool = False
43+ self ,
44+ session : AsyncSession ,
45+ obj : Iterable [CreateSchema ],
46+ commit : bool = False ,
4047 ) -> list [Model ]:
4148 """
4249 Create new instances of a model
@@ -79,6 +86,35 @@ async def select_model_by_column(self, session: AsyncSession, **kwargs) -> Model
7986 query = await session .execute (stmt )
8087 return query .scalars ().first ()
8188
89+ async def select (self , ** kwargs ) -> Select :
90+ """
91+ Construct the SQLAlchemy selection
92+
93+ :param kwargs: Query expressions.
94+ :return:
95+ """
96+ filters = parse_filters (self .model , ** kwargs )
97+ stmt = select (self .model ).where (* filters )
98+ return stmt
99+
100+ async def select_order (
101+ self ,
102+ sort_columns : str | list [str ],
103+ sort_orders : str | list [str ] | None = None ,
104+ ** kwargs ,
105+ ) -> Select :
106+ """
107+ Constructing SQLAlchemy selection with sorting
108+
109+ :param kwargs: Query expressions.
110+ :param sort_columns: more details see apply_sorting
111+ :param sort_orders: more details see apply_sorting
112+ :return:
113+ """
114+ stmt = await self .select (** kwargs )
115+ sorted_stmt = apply_sorting (self .model , stmt , sort_columns , sort_orders )
116+ return sorted_stmt
117+
82118 async def select_models (self , session : AsyncSession , ** kwargs ) -> Sequence [Row [Any ] | RowMapping | Any ]:
83119 """
84120 Query all rows
@@ -87,13 +123,16 @@ async def select_models(self, session: AsyncSession, **kwargs) -> Sequence[Row[A
87123 :param kwargs: Query expressions.
88124 :return:
89125 """
90- filters = parse_filters (self .model , ** kwargs )
91- stmt = select (self .model ).where (* filters )
126+ stmt = await self .select (** kwargs )
92127 query = await session .execute (stmt )
93128 return query .scalars ().all ()
94129
95130 async def select_models_order (
96- self , session : AsyncSession , sort_columns : str | list [str ], sort_orders : str | list [str ] | None = None , ** kwargs
131+ self ,
132+ session : AsyncSession ,
133+ sort_columns : str | list [str ],
134+ sort_orders : str | list [str ] | None = None ,
135+ ** kwargs ,
97136 ) -> Sequence [Row | RowMapping | Any ] | None :
98137 """
99138 Query all rows and sort by columns
@@ -103,14 +142,16 @@ async def select_models_order(
103142 :param sort_orders: more details see apply_sorting
104143 :return:
105144 """
106- filters = parse_filters (self .model , ** kwargs )
107- stmt = select (self .model ).where (* filters )
108- stmt_sort = apply_sorting (self .model , stmt , sort_columns , sort_orders )
109- query = await session .execute (stmt_sort )
145+ stmt = await self .select_order (sort_columns , sort_orders , ** kwargs )
146+ query = await session .execute (stmt )
110147 return query .scalars ().all ()
111148
112149 async def update_model (
113- self , session : AsyncSession , pk : int , obj : UpdateSchema | dict [str , Any ], commit : bool = False
150+ self ,
151+ session : AsyncSession ,
152+ pk : int ,
153+ obj : UpdateSchema | dict [str , Any ],
154+ commit : bool = False ,
114155 ) -> int :
115156 """
116157 Update an instance by model's primary key
@@ -125,7 +166,7 @@ async def update_model(
125166 instance_data = obj
126167 else :
127168 instance_data = obj .model_dump (exclude_unset = True )
128- stmt = sa_update (self .model ).where (self .model .id == pk ).values (** instance_data )
169+ stmt = update (self .model ).where (self .model .id == pk ).values (** instance_data )
129170 result = await session .execute (stmt )
130171 if commit :
131172 await session .commit ()
@@ -157,13 +198,18 @@ async def update_model_by_column(
157198 instance_data = obj
158199 else :
159200 instance_data = obj .model_dump (exclude_unset = True )
160- stmt = sa_update (self .model ).where (* filters ).values (** instance_data ) # type: ignore
201+ stmt = update (self .model ).where (* filters ).values (** instance_data ) # type: ignore
161202 result = await session .execute (stmt )
162203 if commit :
163204 await session .commit ()
164205 return result .rowcount # type: ignore
165206
166- async def delete_model (self , session : AsyncSession , pk : int , commit : bool = False ) -> int :
207+ async def delete_model (
208+ self ,
209+ session : AsyncSession ,
210+ pk : int ,
211+ commit : bool = False ,
212+ ) -> int :
167213 """
168214 Delete an instance by model's primary key
169215
@@ -172,7 +218,7 @@ async def delete_model(self, session: AsyncSession, pk: int, commit: bool = Fals
172218 :param commit: If `True`, commits the transaction immediately. Default is `False`.
173219 :return:
174220 """
175- stmt = sa_delete (self .model ).where (self .model .id == pk )
221+ stmt = delete (self .model ).where (self .model .id == pk )
176222 result = await session .execute (stmt )
177223 if commit :
178224 await session .commit ()
@@ -204,10 +250,10 @@ async def delete_model_by_column(
204250 raise MultipleResultsError (f'Only one record is expected to be delete, found { total_count } records.' )
205251 if logical_deletion :
206252 deleted_flag = {deleted_flag_column : True }
207- stmt = sa_update (self .model ).where (* filters ).values (** deleted_flag )
253+ stmt = update (self .model ).where (* filters ).values (** deleted_flag )
208254 else :
209- stmt = sa_delete (self .model ).where (* filters )
210- await session .execute (stmt )
255+ stmt = delete (self .model ).where (* filters )
256+ result = await session .execute (stmt )
211257 if commit :
212258 await session .commit ()
213- return total_count
259+ return result . rowcount # type: ignore
0 commit comments