Skip to content

Commit adaed42

Browse files
authored
Arc partition values in TableSchema (#19137)
This should avoid some clones. I also fixed a bug when `with_table_partition_cols` is called multiple times on the same instance; it now appends the new partition columns to the existing ones instead of replacing them.
1 parent e8384fb commit adaed42

File tree

1 file changed

+104
-3
lines changed

1 file changed

+104
-3
lines changed

datafusion/datasource/src/table_schema.rs

Lines changed: 104 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ pub struct TableSchema {
7070
///
7171
/// These columns are NOT present in the data files but are appended to each
7272
/// row during query execution based on the file's location.
73-
table_partition_cols: Vec<FieldRef>,
73+
table_partition_cols: Arc<Vec<FieldRef>>,
7474

7575
/// The complete table schema: file_schema columns followed by partition columns.
7676
///
@@ -121,7 +121,7 @@ impl TableSchema {
121121
builder.extend(table_partition_cols.iter().cloned());
122122
Self {
123123
file_schema,
124-
table_partition_cols,
124+
table_partition_cols: Arc::new(table_partition_cols),
125125
table_schema: Arc::new(builder.finish()),
126126
}
127127
}
@@ -140,7 +140,15 @@ impl TableSchema {
140140
/// into [`TableSchema::with_table_partition_cols`] if you have partition columns at construction time
141141
/// since it avoids re-computing the table schema.
142142
pub fn with_table_partition_cols(mut self, partition_cols: Vec<FieldRef>) -> Self {
143-
self.table_partition_cols = partition_cols;
143+
if self.table_partition_cols.is_empty() {
144+
self.table_partition_cols = Arc::new(partition_cols);
145+
} else {
146+
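// Append to existing partition columns. NOTE(review): `Arc::get_mut` returns `None` whenever another `Arc` or `Weak` handle to this Vec is alive (e.g. a clone of this `TableSchema`, if it is `Clone` — confirm), so the `expect` below can panic even though this function takes `mut self`; `Arc::make_mut` would copy-on-write instead.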
147+
let table_partition_cols = Arc::get_mut(&mut self.table_partition_cols).expect(
148+
"Expected to be the sole owner of table_partition_cols since this function accepts mut self",
149+
);
150+
table_partition_cols.extend(partition_cols);
151+
}
144152
let mut builder = SchemaBuilder::from(self.file_schema.as_ref());
145153
builder.extend(self.table_partition_cols.iter().cloned());
146154
self.table_schema = Arc::new(builder.finish());
@@ -176,3 +184,96 @@ impl From<SchemaRef> for TableSchema {
176184
Self::from_file_schema(schema)
177185
}
178186
}
187+
188+
#[cfg(test)]
mod tests {
    use super::TableSchema;
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Constructing a `TableSchema` from a file schema and partition columns
    /// must keep both inputs intact and expose their concatenation as the
    /// full table schema.
    #[test]
    fn test_table_schema_creation() {
        let data_fields = vec![
            Field::new("user_id", DataType::Int64, false),
            Field::new("amount", DataType::Float64, false),
        ];
        let partition_fields = vec![
            Field::new("date", DataType::Utf8, false),
            Field::new("region", DataType::Utf8, false),
        ];

        let file_schema = Arc::new(Schema::new(data_fields.clone()));
        let partition_cols: Vec<_> =
            partition_fields.iter().cloned().map(Arc::new).collect();

        let schema = TableSchema::new(Arc::clone(&file_schema), partition_cols.clone());

        // The file schema is stored as given.
        assert_eq!(schema.file_schema().as_ref(), file_schema.as_ref());

        // Partition columns come back in the order they were supplied.
        let stored = schema.table_partition_cols();
        assert_eq!(stored.len(), 2);
        for (idx, expected) in partition_cols.iter().enumerate() {
            assert_eq!(&stored[idx], expected);
        }

        // Full table schema: file fields first, partition fields after.
        let expected_schema = Schema::new(
            data_fields
                .into_iter()
                .chain(partition_fields)
                .collect::<Vec<_>>(),
        );
        assert_eq!(schema.table_schema().as_ref(), &expected_schema);
    }

    /// Calling `with_table_partition_cols` on a schema that already has
    /// partition columns must append the new columns rather than replace
    /// the existing ones.
    #[test]
    fn test_add_multiple_partition_columns() {
        let id = Field::new("id", DataType::Int32, false);
        let country = Field::new("country", DataType::Utf8, false);
        let city = Field::new("city", DataType::Utf8, false);
        let year = Field::new("year", DataType::Int32, false);

        let file_schema = Arc::new(Schema::new(vec![id.clone()]));

        // Start with one partition column, then add two more.
        let updated =
            TableSchema::new(Arc::clone(&file_schema), vec![Arc::new(country.clone())])
                .with_table_partition_cols(vec![
                    Arc::new(city.clone()),
                    Arc::new(year.clone()),
                ]);

        // The file schema is unaffected by adding partition columns.
        assert_eq!(updated.file_schema().as_ref(), file_schema.as_ref());

        // All three partition columns are present, original one first.
        let names: Vec<&str> = updated
            .table_partition_cols()
            .iter()
            .map(|field| field.name().as_str())
            .collect();
        assert_eq!(names, ["country", "city", "year"]);

        // The table schema was rebuilt to include every partition column.
        let expected_schema = Schema::new(vec![id, country, city, year]);
        assert_eq!(updated.table_schema().as_ref(), &expected_schema);
    }
}

0 commit comments

Comments
 (0)