[SPARK-55568][SQL] Separate schema construction from field stats collection #54343
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.datasources.parquet

import scala.collection.mutable

import org.apache.spark.SparkRuntimeException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.internal.SQLConf
@@ -93,74 +95,31 @@ class InferVariantShreddingSchema(val schema: StructType) {
  private val COUNT_METADATA_KEY = "COUNT"

  /**
   * Return an appropriate schema for shredding a Variant value.
   * It is similar to the SchemaOfVariant expression, but the rules are somewhat different,
   * because we want the types to be consistent with what will be allowed during shredding.
   * E.g. SchemaOfVariant will consider the common type across Integer and Double to be double,
   * but we consider it to be VariantType, since shredding will not allow those types to be
   * written to the same typed_value.
   * We also maintain metadata on struct fields to track how frequently they occur. Rare fields
   * are dropped in the final schema.
   */
  private def schemaOf(v: Variant, maxDepth: Int): DataType = v.getType match {
    case Type.OBJECT =>
      if (maxDepth <= 0) return VariantType
      val size = v.objectSize()
      val fields = new Array[StructField](size)
      for (i <- 0 until size) {
        val field = v.getFieldAtIndex(i)
        fields(i) = StructField(field.key, schemaOf(field.value, maxDepth - 1),
          metadata = new MetadataBuilder().putLong(COUNT_METADATA_KEY, 1).build())
      }
      // According to the variant spec, object fields must be sorted alphabetically. So we don't
      // have to sort, but just need to validate they are sorted.
      for (i <- 1 until size) {
        if (fields(i - 1).name >= fields(i).name) {
          throw new SparkRuntimeException(
            errorClass = "MALFORMED_VARIANT",
            messageParameters = Map.empty
          )
        }
      }
      StructType(fields)
    case Type.ARRAY =>
      if (maxDepth <= 0) return VariantType
      var elementType: DataType = NullType
      for (i <- 0 until v.arraySize()) {
        elementType = mergeSchema(elementType, schemaOf(v.getElementAtIndex(i), maxDepth - 1))
      }
      ArrayType(elementType)
    case Type.NULL => NullType
    case Type.BOOLEAN => BooleanType
    case Type.LONG =>
      // Compute the smallest decimal that can contain this value.
      // This will allow us to merge with decimal later without introducing excessive precision.
      // If we only end up encountering integer values, we'll convert back to LongType when we
      // finalize.
      val d = BigDecimal(v.getLong())
      val precision = d.precision
      if (precision <= Decimal.MAX_LONG_DIGITS) {
        DecimalType(precision, 0)
      } else {
        // Value is too large for Decimal(18, 0), so record its type as long.
        LongType
  // Node for tree-based field tracking
  private case class FieldNode(
      var dataType: DataType,
      var rowCount: Int = 0, // Count of distinct rows containing this field
      var lastSeenRow: Int = -1, // Last row index that incremented rowCount
      var arrayElementCount: Long = 0, // Total occurrences across all array elements
      children: mutable.Map[String, FieldNode] = mutable.Map.empty,
      var arrayElementNode: Option[FieldNode] = None
Author: Added a field to track Array, instead of relying on …
  ) {

    def getOrCreateChild(fieldName: String): FieldNode = {
      children.getOrElseUpdate(fieldName, FieldNode(NullType))
    }

    def hasChildren: Boolean = children.nonEmpty

    def getChildren: Seq[(String, FieldNode)] = children.toSeq

    def getOrCreateArrayElement(): FieldNode = {
      arrayElementNode.getOrElse {
        val node = FieldNode(NullType)
        arrayElementNode = Some(node)
        node
      }
    case Type.STRING => StringType
    case Type.DOUBLE => DoubleType
    case Type.DECIMAL =>
      // Don't strip trailing zeros to determine scale. Even if we allow scale relaxation during
      // shredding, it's useful to take trailing zeros as a hint that the extra digits may be
      // used in later values, and use the larger scale.
      val d = Decimal(v.getDecimalWithOriginalScale())
      DecimalType(d.precision, d.scale)
    case Type.DATE => DateType
    case Type.TIMESTAMP => TimestampType
    case Type.TIMESTAMP_NTZ => TimestampNTZType
    case Type.FLOAT => FloatType
    case Type.BINARY => BinaryType
    // Spark doesn't support UUID, so shred it as an untyped value.
    case Type.UUID => VariantType
  }

  private def getFieldCount(field: StructField): Long = {
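To make the node tree concrete, here is a minimal, self-contained sketch. It uses a simplified stand-in for the PR's FieldNode (a String type summary instead of Spark's DataType), so the names and types are illustrative, not the actual implementation:

  import scala.collection.mutable

  // Simplified stand-in for FieldNode: one node per object key, plus a
  // dedicated node shared by all elements of an array.
  case class Node(
      var typeSummary: String = "null",
      children: mutable.Map[String, Node] = mutable.Map.empty,
      var arrayElementNode: Option[Node] = None) {
    def getOrCreateChild(name: String): Node = children.getOrElseUpdate(name, Node())
    def getOrCreateArrayElement(): Node =
      arrayElementNode.getOrElse { val n = Node(); arrayElementNode = Some(n); n }
  }

  // For a document like {"a": {"b": 1}, "c": [2, 3]}, the tree has one node
  // per object key, and both elements of "c" share a single element node.
  val root = Node()
  root.getOrCreateChild("a").getOrCreateChild("b").typeSummary = "long"
  root.getOrCreateChild("c").getOrCreateArrayElement().typeSummary = "long"

  assert(root.children("a").children("b").typeSummary == "long")
  assert(root.children("c").arrayElementNode.get.typeSummary == "long")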
@@ -351,36 +310,254 @@ class InferVariantShreddingSchema(val schema: StructType) {
  }

  def inferSchema(rows: Seq[InternalRow]): StructType = {
    // For each path to a Variant value, iterate over all rows and update the inferred schema.
    // Add the result to a map, which we'll use to update the full schema.
    // maxShreddedFieldsPerFile is a global max for all fields, so initialize it here.
    // For each variant path, collect field statistics using a single pass.
    val maxFields = MaxFields(maxShreddedFieldsPerFile)

    val inferredSchemas = pathsToVariant.map { path =>
      var numNonNullValues = 0
      val simpleSchema = rows.foldLeft(NullType: DataType) {
        case (partialSchema, row) =>
          getValueAtPath(schema, row, path).map { variantVal =>
            numNonNullValues += 1
            val v = new Variant(variantVal.getValue, variantVal.getMetadata)
            val schemaOfRow = schemaOf(v, maxShreddingDepth)
            mergeSchema(partialSchema, schemaOfRow)
            // If getValueAtPath returned None, the value is null in this row; just ignore.
          }
          .getOrElse(partialSchema)
        // If we didn't find any non-null rows, use an unshredded schema.
      }
      val rootNode = FieldNode(NullType)
      var numNonNullVariants = 0

      // Don't infer a schema for fields that appear in less than 10% of rows.
      // Ensure that minCardinality is at least 1 if we have any rows.
      val minCardinality = (numNonNullValues + 9) / 10
      // Single pass: process all rows for this variant path.
      rows.zipWithIndex.foreach { case (row, rowIdx) =>
        getValueAtPath(schema, row, path).foreach { variantVal =>
          numNonNullVariants += 1
          val v = new Variant(variantVal.getValue, variantVal.getMetadata)
          rootNode.dataType = mergeSchema(rootNode.dataType, inferPrimitiveType(v, 0))
          // Traverse variant and update field stats tree.
          collectFieldStats(v, rootNode, rowIdx, 0, inArrayContext = false)
        }
      }

      // Build final schema from collected statistics.
      val minCardinality = (numNonNullVariants + 9) / 10
      val simpleSchema = buildSchemaFromStats(
        rootNode,
        minCardinality,
        numNonNullVariants,
        inArrayContext = false,
        isArray = rootNode.arrayElementNode.isDefined)
      val finalizedSchema = finalizeSimpleSchema(simpleSchema, minCardinality, maxFields)
      val shreddingSchema = SparkShreddingUtils.variantShreddingSchema(finalizedSchema)
      val schemaWithMetadata = SparkShreddingUtils.addWriteShreddingMetadata(shreddingSchema)
      (path, schemaWithMetadata)
    }.toMap

    // Insert each inferred schema into the full schema.
    updateSchema(schema, inferredSchemas)
  }
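As a worked example of the 10% cutoff above (a sketch; the helper name is ours, the formula is from the diff): (numNonNullVariants + 9) / 10 is the ceiling of 10% of the non-null count, so it is at least 1 whenever any non-null variant was seen.

  // Ceiling of 10%: at least 1 once any non-null variant is seen.
  def minCardinality(numNonNullVariants: Int): Int = (numNonNullVariants + 9) / 10

  assert(minCardinality(1) == 1)
  assert(minCardinality(10) == 1)
  assert(minCardinality(11) == 2)  // a field must appear in >= 2 of 11 rows
  assert(minCardinality(25) == 3)  // >= 3 of 25 rows (the ceiling of 10%)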
  /**
   * Recursively traverse a variant value and build the field statistics tree.
   * For each field encountered, record its type and track the distinct row count.
   * For fields inside arrays, also increment the occurrence count.
   */
  private def collectFieldStats(
      v: Variant,
      currentNode: FieldNode,
      rowIdx: Int,
      depth: Int,
      inArrayContext: Boolean): Unit = {

    if (depth >= maxShreddingDepth) return

    v.getType match {
      case Type.OBJECT =>
        val size = v.objectSize()
        // Validate fields are sorted (per variant spec).
        for (i <- 1 until size) {
          val prevKey = v.getFieldAtIndex(i - 1).key
          val currKey = v.getFieldAtIndex(i).key
          if (prevKey >= currKey) {
            throw new SparkRuntimeException(
              errorClass = "MALFORMED_VARIANT",
              messageParameters = Map.empty
            )
          }
        }

        // Process each field.
        for (i <- 0 until size) {
          val field = v.getFieldAtIndex(i)
          val fieldName = field.key

          // Get or create child node (O(1) map access - no path string building!)
          val childNode = currentNode.getOrCreateChild(fieldName)

          // Track distinct row count (deduplicate using lastSeenRow).
          if (childNode.lastSeenRow != rowIdx) {
            childNode.rowCount += 1
            childNode.lastSeenRow = rowIdx
          }

          // Track occurrence count for array elements.
          if (inArrayContext) {
            childNode.arrayElementCount += 1
          }

          // Infer and merge the type.
          val fieldType = inferPrimitiveType(field.value, depth)
          childNode.dataType = mergeSchema(childNode.dataType, fieldType)

          // Recurse into nested structures (pass the child node, not a path string).
          collectFieldStats(field.value, childNode, rowIdx, depth + 1, inArrayContext)
        }

      case Type.ARRAY =>
        val arrayNode = currentNode.getOrCreateArrayElement()

        // Track distinct row count for the array field itself.
Contributor: Why is this check for …

Author: The check for arrayNode is needed for a row with a nested array like …

Contributor: Sorry, I think I wasn't quite asking the right question. I agree that we'll increment … Stepping back, I'm not sure I understand why this distinction between row count and array element count is needed. Could we just have a single …
        if (arrayNode.lastSeenRow != rowIdx) {
          arrayNode.rowCount += 1
          arrayNode.lastSeenRow = rowIdx
        }

        val arraySize = v.arraySize()
        if (arraySize > 0) {
          // Process array elements.
          for (i <- 0 until arraySize) {
            val element = v.getElementAtIndex(i)
            val elementTypeClass = element.getType

            // For primitives, infer and merge the type directly.
            // For objects/arrays, collectFieldStats handles the type via field traversal.
            if (elementTypeClass != Type.OBJECT && elementTypeClass != Type.ARRAY) {
              val primitiveType = inferPrimitiveType(element, depth)
              arrayNode.dataType = mergeSchema(arrayNode.dataType, primitiveType)
            }

            // Recurse into the element to collect nested fields, now IN array context.
            collectFieldStats(element, arrayNode, rowIdx, depth + 1, inArrayContext = true)
          }
        }

      case _ =>
    }
  }
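The review thread above asks why both counters are needed. A minimal sketch of the distinction, using a simplified stand-in node (not the PR's code): for a single row containing {"a": [{"b": 1}, {"b": 2}]}, "b" is visited once per array element within the same row, so rowCount is deduplicated via lastSeenRow while arrayElementCount counts every occurrence.

  // Simplified stand-in keeping only the two counters under discussion.
  case class Node(
      var rowCount: Int = 0,
      var lastSeenRow: Int = -1,
      var arrayElementCount: Long = 0)

  val b = Node()
  val rowIdx = 0
  for (_ <- 0 until 2) {           // two array elements containing "b" in row 0
    if (b.lastSeenRow != rowIdx) { // deduplicate per row
      b.rowCount += 1
      b.lastSeenRow = rowIdx
    }
    b.arrayElementCount += 1       // counts every occurrence inside arrays
  }

  assert(b.rowCount == 1)          // "b" appeared in 1 distinct row
  assert(b.arrayElementCount == 2) // but in 2 array elements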
  /**
   * Infer the type of a variant value without recursive field collection.
   * For objects and arrays, return a marker type; recursive collection is done separately.
   */
  private def inferPrimitiveType(v: Variant, depth: Int): DataType = {
    if (depth >= maxShreddingDepth) return VariantType

    v.getType match {
      case Type.OBJECT =>
        // Return an empty struct as a marker; fields are collected separately.
        StructType(Seq.empty)
      case Type.ARRAY =>
        // Return an array with a null element as a marker; elements are processed separately.
        ArrayType(NullType)
      case Type.NULL => NullType
      case Type.BOOLEAN => BooleanType
      case Type.LONG =>
        val d = BigDecimal(v.getLong())
        val precision = d.precision
        if (precision <= Decimal.MAX_LONG_DIGITS) {
          DecimalType(precision, 0)
        } else {
          LongType
        }
      case Type.STRING => StringType
      case Type.DOUBLE => DoubleType
      case Type.DECIMAL =>
        val d = Decimal(v.getDecimalWithOriginalScale())
        DecimalType(d.precision, d.scale)
      case Type.DATE => DateType
      case Type.TIMESTAMP => TimestampType
      case Type.TIMESTAMP_NTZ => TimestampNTZType
      case Type.FLOAT => FloatType
      case Type.BINARY => BinaryType
      case Type.UUID => VariantType
    }
  }
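A worked example of the LONG branch above (the assertions are ours; Decimal.MAX_LONG_DIGITS is 18 in Spark): the smallest decimal that holds each integer is recorded, and values wider than 18 digits fall back to LongType.

  assert(BigDecimal(7L).precision == 1)                    // -> DecimalType(1, 0)
  assert(BigDecimal(42L).precision == 2)                   // -> DecimalType(2, 0)
  assert(BigDecimal(999999999999999999L).precision == 18)  // still fits Decimal(18, 0)
  assert(BigDecimal(1000000000000000000L).precision == 19) // > 18 digits -> LongType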
  /**
   * Build a schema from the collected field statistics tree.
   * For fields in array contexts, use arrayElementCount / total rows.
   * For top-level fields, use the distinct row count.
   */
  private def buildSchemaFromStats(
      currentNode: FieldNode,
      minCardinality: Int,
      numNonNullVariants: Int,
      inArrayContext: Boolean,
      isArray: Boolean): DataType = {

    // If this node is an array, use the array-element node to build the element type.
    if (isArray) {
      val arrayElementNodeOpt = currentNode.arrayElementNode
      if (arrayElementNodeOpt.isDefined && arrayElementNodeOpt.get.rowCount >= minCardinality) {
        val arrayElementNode = arrayElementNodeOpt.get
        val elementType = buildSchemaFromStats(
          arrayElementNode,
          minCardinality,
          numNonNullVariants,
          inArrayContext = true,
          isArray = arrayElementNode.arrayElementNode.isDefined
        )
        return ArrayType(
          if (elementType == VariantType) arrayElementNode.dataType else elementType)
      }
      return currentNode.dataType
    }

    // Get all direct children, filter by cardinality, sort by cardinality descending,
    // take the top N, then sort alphabetically for determinism.
    val maxStructSize = Math.min(1000, maxShreddedFieldsPerFile)
Author: This is to limit candidates per struct node; final file-level enforcement is in finalizeSimpleSchema.
    val children = currentNode.getChildren
      .filter { case (_, childNode) =>
        val cardinality = if (inArrayContext) {
          childNode.arrayElementCount
        } else {
          childNode.rowCount
        }
        cardinality >= minCardinality
      }
      .sortBy { case (fieldName, childNode) =>
        val cardinality = if (inArrayContext) {
          childNode.arrayElementCount
        } else {
          childNode.rowCount
        }
        // Sort by cardinality descending, then by name ascending for stability.
        (-cardinality, fieldName)
      }
      .take(maxStructSize)
      .sortBy(_._1) // Sort alphabetically.

    if (children.isEmpty) {
      return VariantType
    }

    // Build the struct from the children.
    val fields = children.map { case (fieldName, childNode) =>
      val fieldType = childNode.dataType match {
        case StructType(_) =>
          buildSchemaFromStats(
            childNode, minCardinality, numNonNullVariants, inArrayContext, isArray = false)

        case ArrayType(_, _) =>
          buildSchemaFromStats(
            childNode,
            minCardinality,
            numNonNullVariants,
            inArrayContext = true,
            isArray = childNode.arrayElementNode.isDefined)

        case other => other
      }

      val cardinality = if (inArrayContext) {
        childNode.arrayElementCount
      } else {
        childNode.rowCount
      }

      StructField(fieldName, fieldType,
        metadata = new MetadataBuilder().putLong(COUNT_METADATA_KEY, cardinality).build())
    }

    StructType(fields.toSeq)
  }
}
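To make the candidate-selection order concrete, here is a small sketch of the same filter/sort/take/sort pipeline over hypothetical (fieldName, cardinality) pairs (the names and counts are ours, not from the PR):

  // Hypothetical per-field cardinalities; minCardinality = 3, maxStructSize = 2.
  val stats = Seq("b" -> 7L, "a" -> 7L, "c" -> 2L, "d" -> 9L)

  val kept = stats
    .filter { case (_, n) => n >= 3L }       // drop rare fields ("c")
    .sortBy { case (name, n) => (-n, name) } // frequency desc, name asc for ties
    .take(2)                                 // cap candidates per struct
    .sortBy(_._1)                            // alphabetical for determinism

  assert(kept == Seq("a" -> 7L, "d" -> 9L))  // "d" (9) kept; "a" beats "b" on the tie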
Contributor: Can you comment on what dataType will be for structs and arrays?

Author: dataType is the type summary:

- for objects: StructType(Seq.empty); the actual schema comes from the children
- for arrays: ArrayType(NullType); the actual schema comes from the new arrayElementNode
- otherwise, dataType is the merged scalar type

Contributor: Thanks, I just meant that you might want to add a shortened version of this as a comment, to clarify that dataType will not contain the full nested type.
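Along the lines the reviewer suggests, a possible short comment for the dataType field could read as follows (the wording is a suggestion, not part of the PR):

  // `dataType` is a type summary only, never the full nested type:
  //  - objects: StructType(Seq.empty) marker; the real schema comes from `children`
  //  - arrays:  ArrayType(NullType) marker; the real element schema comes from `arrayElementNode`
  //  - scalars: the merged scalar type across all observed values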