@@ -70,7 +70,7 @@ pub struct TableSchema {
7070 ///
7171 /// These columns are NOT present in the data files but are appended to each
7272 /// row during query execution based on the file's location.
73- table_partition_cols : Vec < FieldRef > ,
73+ table_partition_cols : Arc < Vec < FieldRef > > ,
7474
7575 /// The complete table schema: file_schema columns followed by partition columns.
7676 ///
@@ -121,7 +121,7 @@ impl TableSchema {
121121 builder. extend ( table_partition_cols. iter ( ) . cloned ( ) ) ;
122122 Self {
123123 file_schema,
124- table_partition_cols,
124+ table_partition_cols : Arc :: new ( table_partition_cols ) ,
125125 table_schema : Arc :: new ( builder. finish ( ) ) ,
126126 }
127127 }
@@ -140,7 +140,15 @@ impl TableSchema {
140140 /// into [`TableSchema::with_table_partition_cols`] if you have partition columns at construction time
141141 /// since it avoids re-computing the table schema.
142142 pub fn with_table_partition_cols ( mut self , partition_cols : Vec < FieldRef > ) -> Self {
143- self . table_partition_cols = partition_cols;
143+ if self . table_partition_cols . is_empty ( ) {
144+ self . table_partition_cols = Arc :: new ( partition_cols) ;
145+ } else {
146+ // Append to existing partition columns
147+ let table_partition_cols = Arc :: get_mut ( & mut self . table_partition_cols ) . expect (
148+ "Expected to be the sole owner of table_partition_cols since this function accepts mut self" ,
149+ ) ;
150+ table_partition_cols. extend ( partition_cols) ;
151+ }
144152 let mut builder = SchemaBuilder :: from ( self . file_schema . as_ref ( ) ) ;
145153 builder. extend ( self . table_partition_cols . iter ( ) . cloned ( ) ) ;
146154 self . table_schema = Arc :: new ( builder. finish ( ) ) ;
@@ -176,3 +184,96 @@ impl From<SchemaRef> for TableSchema {
176184 Self :: from_file_schema ( schema)
177185 }
178186}
187+
188+ #[ cfg( test) ]
189+ mod tests {
190+ use super :: TableSchema ;
191+ use arrow:: datatypes:: { DataType , Field , Schema } ;
192+ use std:: sync:: Arc ;
193+
194+ #[ test]
195+ fn test_table_schema_creation ( ) {
196+ let file_schema = Arc :: new ( Schema :: new ( vec ! [
197+ Field :: new( "user_id" , DataType :: Int64 , false ) ,
198+ Field :: new( "amount" , DataType :: Float64 , false ) ,
199+ ] ) ) ;
200+
201+ let partition_cols = vec ! [
202+ Arc :: new( Field :: new( "date" , DataType :: Utf8 , false ) ) ,
203+ Arc :: new( Field :: new( "region" , DataType :: Utf8 , false ) ) ,
204+ ] ;
205+
206+ let table_schema = TableSchema :: new ( file_schema. clone ( ) , partition_cols. clone ( ) ) ;
207+
208+ // Verify file schema
209+ assert_eq ! ( table_schema. file_schema( ) . as_ref( ) , file_schema. as_ref( ) ) ;
210+
211+ // Verify partition columns
212+ assert_eq ! ( table_schema. table_partition_cols( ) . len( ) , 2 ) ;
213+ assert_eq ! ( table_schema. table_partition_cols( ) [ 0 ] , partition_cols[ 0 ] ) ;
214+ assert_eq ! ( table_schema. table_partition_cols( ) [ 1 ] , partition_cols[ 1 ] ) ;
215+
216+ // Verify full table schema
217+ let expected_fields = vec ! [
218+ Field :: new( "user_id" , DataType :: Int64 , false ) ,
219+ Field :: new( "amount" , DataType :: Float64 , false ) ,
220+ Field :: new( "date" , DataType :: Utf8 , false ) ,
221+ Field :: new( "region" , DataType :: Utf8 , false ) ,
222+ ] ;
223+ let expected_schema = Schema :: new ( expected_fields) ;
224+ assert_eq ! ( table_schema. table_schema( ) . as_ref( ) , & expected_schema) ;
225+ }
226+
227+ #[ test]
228+ fn test_add_multiple_partition_columns ( ) {
229+ let file_schema =
230+ Arc :: new ( Schema :: new ( vec ! [ Field :: new( "id" , DataType :: Int32 , false ) ] ) ) ;
231+
232+ let initial_partition_cols =
233+ vec ! [ Arc :: new( Field :: new( "country" , DataType :: Utf8 , false ) ) ] ;
234+
235+ let table_schema = TableSchema :: new ( file_schema. clone ( ) , initial_partition_cols) ;
236+
237+ let additional_partition_cols = vec ! [
238+ Arc :: new( Field :: new( "city" , DataType :: Utf8 , false ) ) ,
239+ Arc :: new( Field :: new( "year" , DataType :: Int32 , false ) ) ,
240+ ] ;
241+
242+ let updated_table_schema =
243+ table_schema. with_table_partition_cols ( additional_partition_cols) ;
244+
245+ // Verify file schema remains unchanged
246+ assert_eq ! (
247+ updated_table_schema. file_schema( ) . as_ref( ) ,
248+ file_schema. as_ref( )
249+ ) ;
250+
251+ // Verify partition columns
252+ assert_eq ! ( updated_table_schema. table_partition_cols( ) . len( ) , 3 ) ;
253+ assert_eq ! (
254+ updated_table_schema. table_partition_cols( ) [ 0 ] . name( ) ,
255+ "country"
256+ ) ;
257+ assert_eq ! (
258+ updated_table_schema. table_partition_cols( ) [ 1 ] . name( ) ,
259+ "city"
260+ ) ;
261+ assert_eq ! (
262+ updated_table_schema. table_partition_cols( ) [ 2 ] . name( ) ,
263+ "year"
264+ ) ;
265+
266+ // Verify full table schema
267+ let expected_fields = vec ! [
268+ Field :: new( "id" , DataType :: Int32 , false ) ,
269+ Field :: new( "country" , DataType :: Utf8 , false ) ,
270+ Field :: new( "city" , DataType :: Utf8 , false ) ,
271+ Field :: new( "year" , DataType :: Int32 , false ) ,
272+ ] ;
273+ let expected_schema = Schema :: new ( expected_fields) ;
274+ assert_eq ! (
275+ updated_table_schema. table_schema( ) . as_ref( ) ,
276+ & expected_schema
277+ ) ;
278+ }
279+ }
0 commit comments