diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/InferVariantShreddingSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/InferVariantShreddingSchema.scala index 1ebb61968150b..b8c3d93a10522 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/InferVariantShreddingSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/InferVariantShreddingSchema.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.parquet +import scala.collection.mutable + import org.apache.spark.SparkRuntimeException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.internal.SQLConf @@ -93,74 +95,31 @@ class InferVariantShreddingSchema(val schema: StructType) { private val COUNT_METADATA_KEY = "COUNT" - /** - * Return an appropriate schema for shredding a Variant value. - * It is similar to the SchemaOfVariant expression, but the rules are somewhat different, because - * we want the types to be consistent with what will be allowed during shredding. E.g. - * SchemaOfVariant will consider the common type across Integer and Double to be double, but we - * consider it to be VariantType, since shredding will not allow those types to be written to - * the same typed_value. - * We also maintain metadata on struct fields to track how frequently they occur. Rare fields - * are dropped in the final schema. - */ - private def schemaOf(v: Variant, maxDepth: Int): DataType = v.getType match { - case Type.OBJECT => - if (maxDepth <= 0) return VariantType - val size = v.objectSize() - val fields = new Array[StructField](size) - for (i <- 0 until size) { - val field = v.getFieldAtIndex(i) - fields(i) = StructField(field.key, schemaOf(field.value, maxDepth - 1), - metadata = new MetadataBuilder().putLong(COUNT_METADATA_KEY, 1).build()) - } - // According to the variant spec, object fields must be sorted alphabetically. So we don't - // have to sort, but just need to validate they are sorted. - for (i <- 1 until size) { - if (fields(i - 1).name >= fields(i).name) { - throw new SparkRuntimeException( - errorClass = "MALFORMED_VARIANT", - messageParameters = Map.empty - ) - } - } - StructType(fields) - case Type.ARRAY => - if (maxDepth <= 0) return VariantType - var elementType: DataType = NullType - for (i <- 0 until v.arraySize()) { - elementType = mergeSchema(elementType, schemaOf(v.getElementAtIndex(i), maxDepth - 1)) - } - ArrayType(elementType) - case Type.NULL => NullType - case Type.BOOLEAN => BooleanType - case Type.LONG => - // Compute the smallest decimal that can contain this value. - // This will allow us to merge with decimal later without introducing excessive precision. - // If we only end up encountering integer values, we'll convert back to LongType when we - // finalize. - val d = BigDecimal(v.getLong()) - val precision = d.precision - if (precision <= Decimal.MAX_LONG_DIGITS) { - DecimalType(precision, 0) - } else { - // Value is too large for Decimal(18, 0), so record its type as long. - LongType + // Node for tree-based field tracking + private case class FieldNode( + var dataType: DataType, // type summary of the field, not fully defined + var rowCount: Int = 0, // Count of distinct rows containing this field + var lastSeenRow: Int = -1, // Last row index that incremented rowCount + var arrayElementCount: Long = 0, // Total occurrences across all array elements + children: mutable.Map[String, FieldNode] = mutable.Map.empty, + var arrayElementNode: Option[FieldNode] = None + ) { + + def getOrCreateChild(fieldName: String): FieldNode = { + children.getOrElseUpdate(fieldName, FieldNode(NullType)) + } + + def hasChildren: Boolean = children.nonEmpty + + def getChildren: Seq[(String, FieldNode)] = children.toSeq + + def getOrCreateArrayElement(): FieldNode = { + arrayElementNode.getOrElse { + val node = FieldNode(NullType) + arrayElementNode = Some(node) + node } - case Type.STRING => StringType - case Type.DOUBLE => DoubleType - case Type.DECIMAL => - // Don't strip trailing zeros to determine scale. Even if we allow scale relaxation during - // shredding, it's useful to take trailing zeros as a hint that the extra digits may be used - // in later values, and use the larger scale. - val d = Decimal(v.getDecimalWithOriginalScale()) - DecimalType(d.precision, d.scale) - case Type.DATE => DateType - case Type.TIMESTAMP => TimestampType - case Type.TIMESTAMP_NTZ => TimestampNTZType - case Type.FLOAT => FloatType - case Type.BINARY => BinaryType - // Spark doesn't support UUID, so shred it as an untyped value. - case Type.UUID => VariantType + } } private def getFieldCount(field: StructField): Long = { @@ -351,36 +310,251 @@ class InferVariantShreddingSchema(val schema: StructType) { } def inferSchema(rows: Seq[InternalRow]): StructType = { - // For each path to a Variant value, iterate over all rows and update the inferred schema. - // Add the result to a map, which we'll use to update the full schema. - // maxShreddedFieldsPerFile is a global max for all fields, so initialize it here. + // For each variant path, collect field statistics using a single pass val maxFields = MaxFields(maxShreddedFieldsPerFile) + val inferredSchemas = pathsToVariant.map { path => - var numNonNullValues = 0 - val simpleSchema = rows.foldLeft(NullType: DataType) { - case (partialSchema, row) => - getValueAtPath(schema, row, path).map { variantVal => - numNonNullValues += 1 - val v = new Variant(variantVal.getValue, variantVal.getMetadata) - val schemaOfRow = schemaOf(v, maxShreddingDepth) - mergeSchema(partialSchema, schemaOfRow) - // If getValueAtPath returned None, the value is null in this row; just ignore. - } - .getOrElse(partialSchema) - // If we didn't find any non-null rows, use an unshredded schema. - } + val rootNode = FieldNode(NullType) + var numNonNullVariants = 0 - // Don't infer a schema for fields that appear in less than 10% of rows. - // Ensure that minCardinality is at least 1 if we have any rows. - val minCardinality = (numNonNullValues + 9) / 10 + // Single pass: process all rows for this variant path + rows.zipWithIndex.foreach { case (row, rowIdx) => + getValueAtPath(schema, row, path).foreach { variantVal => + numNonNullVariants += 1 + val v = new Variant(variantVal.getValue, variantVal.getMetadata) + rootNode.dataType = mergeSchema(rootNode.dataType, inferPrimitiveType(v, 0)) + // Traverse variant and update field stats tree + collectFieldStats(v, rootNode, rowIdx, 0, inArrayContext = false) + } + } + // Build final schema from collected statistics + val minCardinality = (numNonNullVariants + 9) / 10 + val simpleSchema = buildSchemaFromStats( + rootNode, + minCardinality, + numNonNullVariants, + inArrayContext = false, + isArray = rootNode.arrayElementNode.isDefined) val finalizedSchema = finalizeSimpleSchema(simpleSchema, minCardinality, maxFields) val shreddingSchema = SparkShreddingUtils.variantShreddingSchema(finalizedSchema) val schemaWithMetadata = SparkShreddingUtils.addWriteShreddingMetadata(shreddingSchema) (path, schemaWithMetadata) }.toMap - // Insert each inferred schema into the full schema. + // Insert each inferred schema into the full schema updateSchema(schema, inferredSchemas) } + + /** + * Recursively traverse a variant value and build field statistics tree. + * For each field encountered, record its type and track distinct row count. + * For fields inside arrays, also increment the occurrence count. + */ + private def collectFieldStats( + v: Variant, + currentNode: FieldNode, + rowIdx: Int, + depth: Int, + inArrayContext: Boolean): Unit = { + + if (depth >= maxShreddingDepth) return + + v.getType match { + case Type.OBJECT => + val size = v.objectSize() + // Validate fields are sorted (per variant spec) + for (i <- 1 until size) { + val prevKey = v.getFieldAtIndex(i - 1).key + val currKey = v.getFieldAtIndex(i).key + if (prevKey >= currKey) { + throw new SparkRuntimeException( + errorClass = "MALFORMED_VARIANT", + messageParameters = Map.empty + ) + } + } + + // Process each field + for (i <- 0 until size) { + val field = v.getFieldAtIndex(i) + val fieldName = field.key + + // Get or create child node (O(1) map access - no path string building!) + val childNode = currentNode.getOrCreateChild(fieldName) + + // Track row-level presence only outside array context. + if (inArrayContext) { + childNode.arrayElementCount += 1 + } else if (childNode.lastSeenRow != rowIdx) { + childNode.rowCount += 1 + childNode.lastSeenRow = rowIdx + } + + // Infer and merge type + val fieldType = inferPrimitiveType(field.value, depth) + childNode.dataType = mergeSchema(childNode.dataType, fieldType) + + // Recurse into nested structures (pass child node, not path string) + collectFieldStats(field.value, childNode, rowIdx, depth + 1, inArrayContext) + } + + case Type.ARRAY => + val arrayNode = currentNode.getOrCreateArrayElement() + + // Track distinct row count for the array field itself + if (arrayNode.lastSeenRow != rowIdx) { + arrayNode.rowCount += 1 + arrayNode.lastSeenRow = rowIdx + } + + val arraySize = v.arraySize() + if (arraySize > 0) { + // Process array elements + for (i <- 0 until arraySize) { + val element = v.getElementAtIndex(i) + val elementTypeClass = element.getType + + // For primitives, infer and merge type directly + // For objects/arrays, collectFieldStats handles type via field traversal + if (elementTypeClass != Type.OBJECT && elementTypeClass != Type.ARRAY) { + val primitiveType = inferPrimitiveType(element, depth) + arrayNode.dataType = mergeSchema(arrayNode.dataType, primitiveType) + } + + // Recurse into element to collect nested fields, now IN array context + collectFieldStats(element, arrayNode, rowIdx, depth + 1, inArrayContext = true) + } + } + + case _ => + } + } + + /** + * Infer the type of a variant value without recursive field collection. + * For objects and arrays, return a marker type; recursive collection is done separately. + */ + private def inferPrimitiveType(v: Variant, depth: Int): DataType = { + if (depth >= maxShreddingDepth) return VariantType + + v.getType match { + case Type.OBJECT => + // Return empty struct as marker; fields collected separately + StructType(Seq.empty) + case Type.ARRAY => + // Return array with null element as marker; elements processed separately + ArrayType(NullType) + case Type.NULL => NullType + case Type.BOOLEAN => BooleanType + case Type.LONG => + val d = BigDecimal(v.getLong()) + val precision = d.precision + if (precision <= Decimal.MAX_LONG_DIGITS) { + DecimalType(precision, 0) + } else { + LongType + } + case Type.STRING => StringType + case Type.DOUBLE => DoubleType + case Type.DECIMAL => + val d = Decimal(v.getDecimalWithOriginalScale()) + DecimalType(d.precision, d.scale) + case Type.DATE => DateType + case Type.TIMESTAMP => TimestampType + case Type.TIMESTAMP_NTZ => TimestampNTZType + case Type.FLOAT => FloatType + case Type.BINARY => BinaryType + case Type.UUID => VariantType + } + } + + /** + * Build a schema from collected field statistics tree. + * For fields in array contexts, use arrayElementCount / total rows. + * For top-level fields, use distinct row count. + */ + private def buildSchemaFromStats( + currentNode: FieldNode, + minCardinality: Int, + numNonNullVariants: Int, + inArrayContext: Boolean, + isArray: Boolean): DataType = { + + // If this node is an array, use array-element node to build element type. + if (isArray) { + val arrayElementNodeOpt = currentNode.arrayElementNode + if (arrayElementNodeOpt.isDefined && arrayElementNodeOpt.get.rowCount >= minCardinality) { + val arrayElementNode = arrayElementNodeOpt.get + val elementType = buildSchemaFromStats( + arrayElementNode, + minCardinality, + numNonNullVariants, + inArrayContext = true, + isArray = arrayElementNode.arrayElementNode.isDefined + ) + return ArrayType( + if (elementType == VariantType) arrayElementNode.dataType else elementType) + } + return currentNode.dataType + } + + // Get all direct children, filter by cardinality, sort by cardinality descending, + // take top N, then sort alphabetically for determinism. + val maxStructSize = Math.min(1000, maxShreddedFieldsPerFile) + val children = currentNode.getChildren + .filter { case (_, childNode) => + val cardinality = if (inArrayContext) { + childNode.arrayElementCount + } else { + childNode.rowCount + } + cardinality >= minCardinality + } + .sortBy { case (fieldName, childNode) => + val cardinality = if (inArrayContext) { + childNode.arrayElementCount + } else { + childNode.rowCount + } + // Sort by cardinality descending, then by name ascending for stability + (-cardinality, fieldName) + } + .take(maxStructSize) + .sortBy(_._1) // Sort alphabetically + + if (children.isEmpty) { + return VariantType + } + + // Build struct from children + val fields = children.map { case (fieldName, childNode) => + val fieldType = childNode.dataType match { + case StructType(_) => + buildSchemaFromStats( + childNode, minCardinality, numNonNullVariants, inArrayContext, isArray = false) + + case ArrayType(_, _) => + buildSchemaFromStats( + childNode, + minCardinality, + numNonNullVariants, + inArrayContext = true, + isArray = childNode.arrayElementNode.isDefined) + + case other => other + } + + val cardinality = if (inArrayContext) { + childNode.arrayElementCount + } else { + childNode.rowCount + } + + StructField(fieldName, fieldType, + metadata = new MetadataBuilder().putLong(COUNT_METADATA_KEY, cardinality).build()) + } + + StructType(fields.toSeq) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/VariantInferShreddingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/VariantInferShreddingSuite.scala index 49a43fffafb34..56b398749f7e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/VariantInferShreddingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/VariantInferShreddingSuite.scala @@ -203,13 +203,10 @@ class VariantInferShreddingSuite extends QueryTest with SharedSparkSession with } testWithTempDir("infer shredding key as data") { dir => - // The first 10 fields in each object include the row ID in the field name, so they'll be - // unique. Because we impose a 1000-field limit when building up the schema, we'll end up - // dropping all but the first 1000, so we won't include the non-unique fields in the schema. - // Since the unique names are below the count threshold, we'll end up with an unshredded - // schema. - // In the future, we could consider trying to improve this by dropping the least-common fields - // when we hit the limit of 1000. + // The first 50 fields include the row ID in the field name, so they're + // unique (low cardinality). The last 50 fields are shared across all rows + // (high cardinality). With cardinality-based sorting, + // we now correctly shred the high-cardinality last_* fields val bigObject = (0 until 100).map { i => if (i < 50) { s""" "first_${i}_' || id || '": {"x": $i, "y": "${i + 1}"} """ @@ -226,15 +223,23 @@ class VariantInferShreddingSuite extends QueryTest with SharedSparkSession with val footers = getFooters(dir) assert(footers.size == 1) - // We can't call checkFileSchema, because it only handles the case of one Variant column in - // the file. - val largeExpected = SparkShreddingUtils.variantShreddingSchema(DataType.fromDDL("variant")) + // With cardinality-based sorting, v should now have a shredded schema + // for the high-cardinality last_* fields (not an unshredded schema like + // master would produce). v2 should still be shredded correctly. + val actual = getFileSchema(dir) + val v_schema = actual.fields(0).dataType.asInstanceOf[StructType] + val v2_schema = actual.fields(1).dataType.asInstanceOf[StructType] + + // v should have shredded typed_value (struct with nested last_* fields) + assert(v_schema.fieldNames.contains("typed_value")) + val v_typed = v_schema("typed_value").dataType.asInstanceOf[StructType] + assert(v_typed.fields.exists(_.name.startsWith("last_"))) + + // v2 should be fully shredded val smallExpected = SparkShreddingUtils.variantShreddingSchema( DataType.fromDDL("struct")) - val actual = getFileSchema(dir) - assert(actual == StructType(Seq( - StructField("v", largeExpected, nullable = false), - StructField("v2", smallExpected, nullable = false)))) + assert(v2_schema == smallExpected) + checkStringAndSchema(dir, df) } @@ -634,4 +639,111 @@ class VariantInferShreddingSuite extends QueryTest with SharedSparkSession with checkFileSchema(expected, dir) checkAnswer(spark.read.parquet(dir.getAbsolutePath), df.collect()) } + + testWithTempDir("special characters in field names - dots") { dir => + val df = spark.sql( + """ + |select parse_json( + | '{"field.with.dots": ' || id || ', "another.dotted.field": "value"}' + |) as v + |from range(0, 100, 1, 1) + """.stripMargin) + df.write.mode("overwrite").parquet(dir.getAbsolutePath) + + // Verify the schema contains fields with dots + val schema = getFileSchema(dir) + val vSchema = schema("v").dataType.asInstanceOf[StructType] + val typedValue = vSchema("typed_value").dataType.asInstanceOf[StructType] + assert(typedValue.fieldNames.contains("another.dotted.field")) + assert(typedValue.fieldNames.contains("field.with.dots")) + + // Verify we can read the data back + val result = spark.read.parquet(dir.getAbsolutePath) + assert(result.count() == 100) + } + + testWithTempDir("special characters in field names - brackets") { dir => + val df = spark.sql( + """ + |select parse_json( + | '{"field[0]": ' || id || ', "another[key]": "value"}' + |) as v + |from range(0, 100, 1, 1) + """.stripMargin) + df.write.mode("overwrite").parquet(dir.getAbsolutePath) + + // Verify the schema contains fields with brackets + val schema = getFileSchema(dir) + val vSchema = schema("v").dataType.asInstanceOf[StructType] + val typedValue = vSchema("typed_value").dataType.asInstanceOf[StructType] + assert(typedValue.fieldNames.contains("another[key]")) + assert(typedValue.fieldNames.contains("field[0]")) + + // Verify we can read the data back + val result = spark.read.parquet(dir.getAbsolutePath) + assert(result.count() == 100) + } + + testWithTempDir("special characters in field names - mixed") { dir => + val df = spark.sql( + """ + |select parse_json( + | '{"a.b[0]": ' || id || ', "c[d].e": "value", "normal_field": 42}' + |) as v + |from range(0, 100, 1, 1) + """.stripMargin) + df.write.mode("overwrite").parquet(dir.getAbsolutePath) + + // Verify the schema contains fields with mixed special characters + val schema = getFileSchema(dir) + val vSchema = schema("v").dataType.asInstanceOf[StructType] + val typedValue = vSchema("typed_value").dataType.asInstanceOf[StructType] + assert(typedValue.fieldNames.contains("a.b[0]")) + assert(typedValue.fieldNames.contains("c[d].e")) + assert(typedValue.fieldNames.contains("normal_field")) + + // Verify we can read the data back + val result = spark.read.parquet(dir.getAbsolutePath) + assert(result.count() == 100) + } + + testWithTempDir("special characters in field names - literal empty brackets") { dir => + val df = spark.sql( + """ + |select parse_json( + | '{"[]": ' || id || ', "normal_field": "value"}' + |) as v + |from range(0, 100, 1, 1) + """.stripMargin) + df.write.mode("overwrite").parquet(dir.getAbsolutePath) + + val schema = getFileSchema(dir) + val vSchema = schema("v").dataType.asInstanceOf[StructType] + val typedValue = vSchema("typed_value").dataType.asInstanceOf[StructType] + assert(typedValue.fieldNames.contains("[]")) + assert(typedValue.fieldNames.contains("normal_field")) + + val result = spark.read.parquet(dir.getAbsolutePath) + assert(result.count() == 100) + } + + testWithTempDir("special characters in field names - literal empty brackets with array") { dir => + val df = spark.sql( + """ + |select parse_json( + | '{"[]": ' || id || ', "arr": [' || id || ', ' || (id + 1) || ']}' + |) as v + |from range(0, 100, 1, 1) + """.stripMargin) + df.write.mode("overwrite").parquet(dir.getAbsolutePath) + + val schema = getFileSchema(dir) + val vSchema = schema("v").dataType.asInstanceOf[StructType] + val typedValue = vSchema("typed_value").dataType.asInstanceOf[StructType] + assert(typedValue.fieldNames.contains("[]")) + assert(typedValue.fieldNames.contains("arr")) + + val result = spark.read.parquet(dir.getAbsolutePath) + assert(result.count() == 100) + } }