Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.execution.datasources.parquet

import scala.collection.mutable

import org.apache.spark.SparkRuntimeException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -93,74 +95,31 @@ class InferVariantShreddingSchema(val schema: StructType) {

private val COUNT_METADATA_KEY = "COUNT"

/**
 * Tree node used to collect per-field statistics while traversing Variant values.
 *
 * `dataType` holds only a type summary, not the full nested type:
 *  - For objects it is the marker `StructType(Seq.empty)`; the actual schema comes from
 *    `children`.
 *  - For arrays it is the marker `ArrayType(NullType)`; the actual element schema comes from
 *    `arrayElementNode`.
 *  - For primitives it is the merged primitive type seen so far.
 */
private case class FieldNode(
    var dataType: DataType,
    var rowCount: Int = 0, // Count of distinct rows containing this field.
    var lastSeenRow: Int = -1, // Last row index that incremented rowCount (dedup guard).
    var arrayElementCount: Long = 0, // Total occurrences across all array elements.
    children: mutable.Map[String, FieldNode] = mutable.Map.empty,
    var arrayElementNode: Option[FieldNode] = None) {

  // Look up the statistics node for a child field, creating an empty one on first sight.
  def getOrCreateChild(fieldName: String): FieldNode = {
    children.getOrElseUpdate(fieldName, FieldNode(NullType))
  }

  def hasChildren: Boolean = children.nonEmpty

  def getChildren: Seq[(String, FieldNode)] = children.toSeq

  // Return the array-element node, creating it on first use. A dedicated node is used for
  // array elements instead of a synthetic "[]" child-name marker.
  def getOrCreateArrayElement(): FieldNode = {
    arrayElementNode.getOrElse {
      val node = FieldNode(NullType)
      arrayElementNode = Some(node)
      node
    }
  }
}

private def getFieldCount(field: StructField): Long = {
Expand Down Expand Up @@ -351,36 +310,254 @@ class InferVariantShreddingSchema(val schema: StructType) {
}

def inferSchema(rows: Seq[InternalRow]): StructType = {
  // For each path to a Variant value, collect field statistics in a single pass over the rows,
  // then build a shredding schema from those statistics and merge it into the full schema.
  // maxShreddedFieldsPerFile is a global max for all fields, so initialize it here.
  val maxFields = MaxFields(maxShreddedFieldsPerFile)

  val inferredSchemas = pathsToVariant.map { path =>
    val rootNode = FieldNode(NullType)
    var numNonNullVariants = 0

    // Single pass: process all rows for this variant path.
    rows.zipWithIndex.foreach { case (row, rowIdx) =>
      // If getValueAtPath returns None, the value is null in this row; just ignore it.
      getValueAtPath(schema, row, path).foreach { variantVal =>
        numNonNullVariants += 1
        val v = new Variant(variantVal.getValue, variantVal.getMetadata)
        rootNode.dataType = mergeSchema(rootNode.dataType, inferPrimitiveType(v, 0))
        // Traverse the variant and update the field-statistics tree.
        collectFieldStats(v, rootNode, rowIdx, 0, inArrayContext = false)
      }
    }

    // Don't infer a schema for fields that appear in less than 10% of rows.
    // Rounding up ensures minCardinality is at least 1 if we have any non-null rows.
    val minCardinality = (numNonNullVariants + 9) / 10
    val simpleSchema = buildSchemaFromStats(
      rootNode,
      minCardinality,
      numNonNullVariants,
      inArrayContext = false,
      isArray = rootNode.arrayElementNode.isDefined)
    val finalizedSchema = finalizeSimpleSchema(simpleSchema, minCardinality, maxFields)
    val shreddingSchema = SparkShreddingUtils.variantShreddingSchema(finalizedSchema)
    val schemaWithMetadata = SparkShreddingUtils.addWriteShreddingMetadata(shreddingSchema)
    (path, schemaWithMetadata)
  }.toMap

  // Insert each inferred schema into the full schema.
  updateSchema(schema, inferredSchemas)
}

/**
 * Recursively traverse a variant value and build the field-statistics tree.
 * For each object field encountered, record its type summary and track the distinct-row count.
 * For fields inside arrays, also increment the per-element occurrence count.
 *
 * @param v              the variant value to traverse
 * @param currentNode    statistics node corresponding to `v`'s position in the tree
 * @param rowIdx         index of the row being processed (used to deduplicate row counts)
 * @param depth          current nesting depth; traversal stops at maxShreddingDepth
 * @param inArrayContext true if `v` is inside an array element
 */
private def collectFieldStats(
    v: Variant,
    currentNode: FieldNode,
    rowIdx: Int,
    depth: Int,
    inArrayContext: Boolean): Unit = {

  if (depth >= maxShreddingDepth) return

  v.getType match {
    case Type.OBJECT =>
      val size = v.objectSize()
      // Validate that fields are sorted alphabetically, as required by the variant spec.
      for (i <- 1 until size) {
        val prevKey = v.getFieldAtIndex(i - 1).key
        val currKey = v.getFieldAtIndex(i).key
        if (prevKey >= currKey) {
          throw new SparkRuntimeException(
            errorClass = "MALFORMED_VARIANT",
            messageParameters = Map.empty
          )
        }
      }

      // Process each field.
      for (i <- 0 until size) {
        val field = v.getFieldAtIndex(i)
        val fieldName = field.key

        // Get or create the child node: O(1) map access, no path-string building.
        val childNode = currentNode.getOrCreateChild(fieldName)

        // Track the distinct-row count, deduplicating via lastSeenRow.
        if (childNode.lastSeenRow != rowIdx) {
          childNode.rowCount += 1
          childNode.lastSeenRow = rowIdx
        }

        // Track the occurrence count for fields inside array elements.
        if (inArrayContext) {
          childNode.arrayElementCount += 1
        }

        // Infer and merge the type summary for this field.
        val fieldType = inferPrimitiveType(field.value, depth)
        childNode.dataType = mergeSchema(childNode.dataType, fieldType)

        // Recurse into nested structures, passing the child node rather than a path string.
        collectFieldStats(field.value, childNode, rowIdx, depth + 1, inArrayContext)
      }

    case Type.ARRAY =>
      val arrayNode = currentNode.getOrCreateArrayElement()

      // Track the distinct-row count for the array itself. The lastSeenRow guard matters for
      // nested arrays like [[1], [2]]: the inner array node is visited once per outer element,
      // so without the guard its rowCount would be inflated within a single row.
      if (arrayNode.lastSeenRow != rowIdx) {
        arrayNode.rowCount += 1
        arrayNode.lastSeenRow = rowIdx
      }

      val arraySize = v.arraySize()
      if (arraySize > 0) {
        // Process array elements.
        for (i <- 0 until arraySize) {
          val element = v.getElementAtIndex(i)
          val elementTypeClass = element.getType

          // For primitives, infer and merge the type directly. For objects/arrays, the
          // recursive call below handles the type via field traversal.
          if (elementTypeClass != Type.OBJECT && elementTypeClass != Type.ARRAY) {
            val primitiveType = inferPrimitiveType(element, depth)
            arrayNode.dataType = mergeSchema(arrayNode.dataType, primitiveType)
          }

          // Recurse into the element to collect nested fields, now IN array context.
          collectFieldStats(element, arrayNode, rowIdx, depth + 1, inArrayContext = true)
        }
      }

    case _ => // Primitives have no nested fields; their types are merged by the caller.
  }
}

/**
 * Infer the type of a single variant value without recursing into children.
 * Objects and arrays yield marker types (an empty struct / a null-element array); their real
 * contents are collected separately via the field-statistics tree.
 */
private def inferPrimitiveType(v: Variant, depth: Int): DataType = {
  if (depth >= maxShreddingDepth) {
    VariantType
  } else {
    v.getType match {
      case Type.OBJECT => StructType(Seq.empty) // Marker; fields are collected separately.
      case Type.ARRAY => ArrayType(NullType) // Marker; elements are processed separately.
      case Type.NULL => NullType
      case Type.BOOLEAN => BooleanType
      case Type.LONG =>
        // Record an integer as the smallest decimal that can hold it, so it can merge with
        // decimals later without excess precision. Values too wide for Decimal(18, 0) stay long.
        val digits = BigDecimal(v.getLong()).precision
        if (digits <= Decimal.MAX_LONG_DIGITS) DecimalType(digits, 0) else LongType
      case Type.STRING => StringType
      case Type.DOUBLE => DoubleType
      case Type.DECIMAL =>
        // Keep the original scale, including any trailing zeros, as written.
        val dec = Decimal(v.getDecimalWithOriginalScale())
        DecimalType(dec.precision, dec.scale)
      case Type.DATE => DateType
      case Type.TIMESTAMP => TimestampType
      case Type.TIMESTAMP_NTZ => TimestampNTZType
      case Type.FLOAT => FloatType
      case Type.BINARY => BinaryType
      case Type.UUID => VariantType // Spark doesn't support UUID; shred as an untyped value.
    }
  }
}

/**
 * Build a schema from the collected field-statistics tree.
 * For fields in array contexts, cardinality is the occurrence count across array elements;
 * for top-level fields, it is the distinct-row count. Fields below `minCardinality` are
 * dropped, and a node with no surviving children falls back to VariantType.
 */
private def buildSchemaFromStats(
    currentNode: FieldNode,
    minCardinality: Int,
    numNonNullVariants: Int,
    inArrayContext: Boolean,
    isArray: Boolean): DataType = {

  // Cardinality metric for a node: element occurrences inside arrays, distinct rows otherwise.
  def cardinalityOf(node: FieldNode): Long =
    if (inArrayContext) node.arrayElementCount else node.rowCount

  // If this node is an array, use the array-element node to build the element type.
  if (isArray) {
    val arrayElementNodeOpt = currentNode.arrayElementNode
    if (arrayElementNodeOpt.isDefined && arrayElementNodeOpt.get.rowCount >= minCardinality) {
      val arrayElementNode = arrayElementNodeOpt.get
      val elementType = buildSchemaFromStats(
        arrayElementNode,
        minCardinality,
        numNonNullVariants,
        inArrayContext = true,
        isArray = arrayElementNode.arrayElementNode.isDefined
      )
      return ArrayType(
        if (elementType == VariantType) arrayElementNode.dataType else elementType)
    }
    return currentNode.dataType
  }

  // Get all direct children, filter by cardinality, sort by cardinality descending,
  // take top N, then sort alphabetically for determinism.
  // This only limits candidates per struct node; the file-level limit is enforced later in
  // finalizeSimpleSchema.
  val maxStructSize = Math.min(1000, maxShreddedFieldsPerFile)
  val children = currentNode.getChildren
    .filter { case (_, childNode) => cardinalityOf(childNode) >= minCardinality }
    .sortBy { case (fieldName, childNode) =>
      // Sort by cardinality descending, then by name ascending for stability.
      (-cardinalityOf(childNode), fieldName)
    }
    .take(maxStructSize)
    .sortBy(_._1) // Sort alphabetically.

  if (children.isEmpty) {
    return VariantType
  }

  // Build the struct from the surviving children, recursing through struct/array markers.
  val fields = children.map { case (fieldName, childNode) =>
    val fieldType = childNode.dataType match {
      case StructType(_) =>
        buildSchemaFromStats(
          childNode, minCardinality, numNonNullVariants, inArrayContext, isArray = false)

      case ArrayType(_, _) =>
        buildSchemaFromStats(
          childNode,
          minCardinality,
          numNonNullVariants,
          inArrayContext = true,
          isArray = childNode.arrayElementNode.isDefined)

      case other => other
    }

    StructField(fieldName, fieldType,
      metadata = new MetadataBuilder()
        .putLong(COUNT_METADATA_KEY, cardinalityOf(childNode)).build())
  }

  StructType(fields.toSeq)
}
}
Loading