@@ -70,7 +70,7 @@ pub struct TableSchema {
7070 ///
7171 /// These columns are NOT present in the data files but are appended to each
7272 /// row during query execution based on the file's location.
73- table_partition_cols : Vec < FieldRef > ,
73+ table_partition_cols : Arc < Vec < FieldRef > > ,
7474
7575 /// The complete table schema: file_schema columns followed by partition columns.
7676 ///
@@ -121,7 +121,7 @@ impl TableSchema {
121121 builder. extend ( table_partition_cols. iter ( ) . cloned ( ) ) ;
122122 Self {
123123 file_schema,
124- table_partition_cols,
124+ table_partition_cols : Arc :: new ( table_partition_cols ) ,
125125 table_schema : Arc :: new ( builder. finish ( ) ) ,
126126 }
127127 }
@@ -140,7 +140,18 @@ impl TableSchema {
140140 /// into [`TableSchema::with_table_partition_cols`] if you have partition columns at construction time
141141 /// since it avoids re-computing the table schema.
142142 pub fn with_table_partition_cols ( mut self , partition_cols : Vec < FieldRef > ) -> Self {
143- self . table_partition_cols = partition_cols;
143+ if self . table_partition_cols . is_empty ( ) {
144+ self . table_partition_cols = Arc :: new ( partition_cols) ;
145+ } else {
146+ // Append to existing partition columns
147+ self . table_partition_cols = Arc :: new (
148+ self . table_partition_cols
149+ . iter ( )
150+ . cloned ( )
151+ . chain ( partition_cols)
152+ . collect ( ) ,
153+ ) ;
154+ }
144155 let mut builder = SchemaBuilder :: from ( self . file_schema . as_ref ( ) ) ;
145156 builder. extend ( self . table_partition_cols . iter ( ) . cloned ( ) ) ;
146157 self . table_schema = Arc :: new ( builder. finish ( ) ) ;
@@ -176,3 +187,96 @@ impl From<SchemaRef> for TableSchema {
176187 Self :: from_file_schema ( schema)
177188 }
178189}
190+
191+ #[ cfg( test) ]
192+ mod tests {
193+ use super :: TableSchema ;
194+ use arrow:: datatypes:: { DataType , Field , Schema } ;
195+ use std:: sync:: Arc ;
196+
197+ #[ test]
198+ fn test_table_schema_creation ( ) {
199+ let file_schema = Arc :: new ( Schema :: new ( vec ! [
200+ Field :: new( "user_id" , DataType :: Int64 , false ) ,
201+ Field :: new( "amount" , DataType :: Float64 , false ) ,
202+ ] ) ) ;
203+
204+ let partition_cols = vec ! [
205+ Arc :: new( Field :: new( "date" , DataType :: Utf8 , false ) ) ,
206+ Arc :: new( Field :: new( "region" , DataType :: Utf8 , false ) ) ,
207+ ] ;
208+
209+ let table_schema = TableSchema :: new ( file_schema. clone ( ) , partition_cols. clone ( ) ) ;
210+
211+ // Verify file schema
212+ assert_eq ! ( table_schema. file_schema( ) . as_ref( ) , file_schema. as_ref( ) ) ;
213+
214+ // Verify partition columns
215+ assert_eq ! ( table_schema. table_partition_cols( ) . len( ) , 2 ) ;
216+ assert_eq ! ( table_schema. table_partition_cols( ) [ 0 ] , partition_cols[ 0 ] ) ;
217+ assert_eq ! ( table_schema. table_partition_cols( ) [ 1 ] , partition_cols[ 1 ] ) ;
218+
219+ // Verify full table schema
220+ let expected_fields = vec ! [
221+ Field :: new( "user_id" , DataType :: Int64 , false ) ,
222+ Field :: new( "amount" , DataType :: Float64 , false ) ,
223+ Field :: new( "date" , DataType :: Utf8 , false ) ,
224+ Field :: new( "region" , DataType :: Utf8 , false ) ,
225+ ] ;
226+ let expected_schema = Schema :: new ( expected_fields) ;
227+ assert_eq ! ( table_schema. table_schema( ) . as_ref( ) , & expected_schema) ;
228+ }
229+
230+ #[ test]
231+ fn test_add_multiple_partition_columns ( ) {
232+ let file_schema =
233+ Arc :: new ( Schema :: new ( vec ! [ Field :: new( "id" , DataType :: Int32 , false ) ] ) ) ;
234+
235+ let initial_partition_cols =
236+ vec ! [ Arc :: new( Field :: new( "country" , DataType :: Utf8 , false ) ) ] ;
237+
238+ let table_schema = TableSchema :: new ( file_schema. clone ( ) , initial_partition_cols) ;
239+
240+ let additional_partition_cols = vec ! [
241+ Arc :: new( Field :: new( "city" , DataType :: Utf8 , false ) ) ,
242+ Arc :: new( Field :: new( "year" , DataType :: Int32 , false ) ) ,
243+ ] ;
244+
245+ let updated_table_schema =
246+ table_schema. with_table_partition_cols ( additional_partition_cols) ;
247+
248+ // Verify file schema remains unchanged
249+ assert_eq ! (
250+ updated_table_schema. file_schema( ) . as_ref( ) ,
251+ file_schema. as_ref( )
252+ ) ;
253+
254+ // Verify partition columns
255+ assert_eq ! ( updated_table_schema. table_partition_cols( ) . len( ) , 3 ) ;
256+ assert_eq ! (
257+ updated_table_schema. table_partition_cols( ) [ 0 ] . name( ) ,
258+ "country"
259+ ) ;
260+ assert_eq ! (
261+ updated_table_schema. table_partition_cols( ) [ 1 ] . name( ) ,
262+ "city"
263+ ) ;
264+ assert_eq ! (
265+ updated_table_schema. table_partition_cols( ) [ 2 ] . name( ) ,
266+ "year"
267+ ) ;
268+
269+ // Verify full table schema
270+ let expected_fields = vec ! [
271+ Field :: new( "id" , DataType :: Int32 , false ) ,
272+ Field :: new( "country" , DataType :: Utf8 , false ) ,
273+ Field :: new( "city" , DataType :: Utf8 , false ) ,
274+ Field :: new( "year" , DataType :: Int32 , false ) ,
275+ ] ;
276+ let expected_schema = Schema :: new ( expected_fields) ;
277+ assert_eq ! (
278+ updated_table_schema. table_schema( ) . as_ref( ) ,
279+ & expected_schema
280+ ) ;
281+ }
282+ }
0 commit comments