diff --git a/datafusion/datasource/src/table_schema.rs b/datafusion/datasource/src/table_schema.rs
index ff0e78801887..a45cdbaaea07 100644
--- a/datafusion/datasource/src/table_schema.rs
+++ b/datafusion/datasource/src/table_schema.rs
@@ -70,7 +70,7 @@ pub struct TableSchema {
     ///
     /// These columns are NOT present in the data files but are appended to each
     /// row during query execution based on the file's location.
-    table_partition_cols: Vec<FieldRef>,
+    table_partition_cols: Arc<Vec<FieldRef>>,
 
     /// The complete table schema: file_schema columns followed by partition columns.
     ///
@@ -121,7 +121,7 @@ impl TableSchema {
         builder.extend(table_partition_cols.iter().cloned());
         Self {
             file_schema,
-            table_partition_cols,
+            table_partition_cols: Arc::new(table_partition_cols),
             table_schema: Arc::new(builder.finish()),
         }
     }
@@ -140,7 +140,15 @@ impl TableSchema {
     /// into [`TableSchema::with_table_partition_cols`] if you have partition columns at construction time
     /// since it avoids re-computing the table schema.
     pub fn with_table_partition_cols(mut self, partition_cols: Vec<FieldRef>) -> Self {
-        self.table_partition_cols = partition_cols;
+        if self.table_partition_cols.is_empty() {
+            self.table_partition_cols = Arc::new(partition_cols);
+        } else {
+            // Append to the existing partition columns. Owning `self` does not
+            // guarantee sole ownership of the `Arc` (it may be shared with other
+            // handles), so `Arc::get_mut(..).expect(..)` could panic. `Arc::make_mut`
+            // mutates in place when unique and clones the `Vec` first when shared.
+            Arc::make_mut(&mut self.table_partition_cols).extend(partition_cols);
+        }
         let mut builder = SchemaBuilder::from(self.file_schema.as_ref());
         builder.extend(self.table_partition_cols.iter().cloned());
         self.table_schema = Arc::new(builder.finish());
@@ -176,3 +184,96 @@ impl From<SchemaRef> for TableSchema {
         Self::from_file_schema(schema)
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::TableSchema;
+    use arrow::datatypes::{DataType, Field, Schema};
+    use std::sync::Arc;
+
+    #[test]
+    fn test_table_schema_creation() {
+        let file_schema = Arc::new(Schema::new(vec![
+            Field::new("user_id", DataType::Int64, false),
+            Field::new("amount", DataType::Float64, false),
+        ]));
+
+        let partition_cols = vec![
+            Arc::new(Field::new("date", DataType::Utf8, false)),
+            Arc::new(Field::new("region", DataType::Utf8, false)),
+        ];
+
+        let table_schema = TableSchema::new(file_schema.clone(), partition_cols.clone());
+
+        // Verify file schema
+        assert_eq!(table_schema.file_schema().as_ref(), file_schema.as_ref());
+
+        // Verify partition columns
+        assert_eq!(table_schema.table_partition_cols().len(), 2);
+        assert_eq!(table_schema.table_partition_cols()[0], partition_cols[0]);
+        assert_eq!(table_schema.table_partition_cols()[1], partition_cols[1]);
+
+        // Verify full table schema
+        let expected_fields = vec![
+            Field::new("user_id", DataType::Int64, false),
+            Field::new("amount", DataType::Float64, false),
+            Field::new("date", DataType::Utf8, false),
+            Field::new("region", DataType::Utf8, false),
+        ];
+        let expected_schema = Schema::new(expected_fields);
+        assert_eq!(table_schema.table_schema().as_ref(), &expected_schema);
+    }
+
+    #[test]
+    fn test_add_multiple_partition_columns() {
+        let file_schema =
+            Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
+
+        let initial_partition_cols =
+            vec![Arc::new(Field::new("country", DataType::Utf8, false))];
+
+        let table_schema = TableSchema::new(file_schema.clone(), initial_partition_cols);
+
+        let additional_partition_cols = vec![
+            Arc::new(Field::new("city", DataType::Utf8, false)),
+            Arc::new(Field::new("year", DataType::Int32, false)),
+        ];
+
+        let updated_table_schema =
+            table_schema.with_table_partition_cols(additional_partition_cols);
+
+        // Verify file schema remains unchanged
+        assert_eq!(
+            updated_table_schema.file_schema().as_ref(),
+            file_schema.as_ref()
+        );
+
+        // Verify partition columns
+        assert_eq!(updated_table_schema.table_partition_cols().len(), 3);
+        assert_eq!(
+            updated_table_schema.table_partition_cols()[0].name(),
+            "country"
+        );
+        assert_eq!(
+            updated_table_schema.table_partition_cols()[1].name(),
+            "city"
+        );
+        assert_eq!(
+            updated_table_schema.table_partition_cols()[2].name(),
+            "year"
+        );
+
+        // Verify full table schema
+        let expected_fields = vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("country", DataType::Utf8, false),
+            Field::new("city", DataType::Utf8, false),
+            Field::new("year", DataType::Int32, false),
+        ];
+        let expected_schema = Schema::new(expected_fields);
+        assert_eq!(
+            updated_table_schema.table_schema().as_ref(),
+            &expected_schema
+        );
+    }
+}