diff --git a/docs/source/user-guide/latest/compatibility.md b/docs/source/user-guide/latest/compatibility.md index 1ed805513b..797b4d6022 100644 --- a/docs/source/user-guide/latest/compatibility.md +++ b/docs/source/user-guide/latest/compatibility.md @@ -141,6 +141,19 @@ Cast operations in Comet fall into three levels of support: Spark. - **N/A**: Spark does not support this cast. +### String to Decimal + +Comet's native `CAST(string AS DECIMAL)` implementation matches Apache Spark's behavior, +including: + +- Leading and trailing ASCII whitespace is trimmed before parsing. +- Null bytes (`\u0000`) at the start or end of a string are trimmed, matching Spark's + `UTF8String` behavior. Null bytes embedded in the middle of a string produce `NULL`. +- Fullwidth Unicode digits (U+FF10–U+FF19, e.g. `１２３.４５`) are treated as their ASCII + equivalents, so `CAST('１２３.４５' AS DECIMAL(10,2))` returns `123.45`. +- Scientific notation (e.g. `1.23E+5`) is supported. +- Special values (`inf`, `infinity`, `nan`) produce `NULL`. + ### String to Timestamp Comet's native `CAST(string AS TIMESTAMP)` implementation supports all timestamp formats accepted diff --git a/native/spark-expr/src/conversion_funcs/string.rs b/native/spark-expr/src/conversion_funcs/string.rs index fbad964ec3..13a2b8ba56 100644 --- a/native/spark-expr/src/conversion_funcs/string.rs +++ b/native/spark-expr/src/conversion_funcs/string.rs @@ -438,6 +438,40 @@ fn cast_string_to_decimal256_impl( )) } +/// Normalize fullwidth Unicode digits (U+FF10–U+FF19) to their ASCII equivalents. +/// +/// Spark's UTF8String parser treats fullwidth digits as numerically equivalent to +/// ASCII digits, e.g. "１２３.４５" parses as 123.45. Each fullwidth digit encodes +/// to exactly three UTF-8 bytes: [0xEF, 0xBC, 0x90+n] for digit n. The ASCII +/// equivalent is 0x30+n, so the conversion is: third_byte - 0x60. 
+/// +/// All other bytes (ASCII or other multi-byte sequences) are passed through +/// unchanged, so the output is valid UTF-8 whenever the input is. +fn normalize_fullwidth_digits(s: &str) -> String { + let bytes = s.as_bytes(); + let mut out = Vec::with_capacity(s.len()); + let mut i = 0; + while i < bytes.len() { + if i + 2 < bytes.len() + && bytes[i] == 0xEF + && bytes[i + 1] == 0xBC + && bytes[i + 2] >= 0x90 + && bytes[i + 2] <= 0x99 + { + // e.g. 0x91 - 0x60 = 0x31 = b'1' + out.push(bytes[i + 2] - 0x60); + i += 3; + } else { + out.push(bytes[i]); + i += 1; + } + } + // SAFETY: we only replace valid 3-byte UTF-8 sequences [EF BC 9X] with a + // single ASCII byte; all other bytes are copied unchanged, preserving the + // UTF-8 invariant of the input. + unsafe { String::from_utf8_unchecked(out) } +} + /// Parse a decimal string into mantissa and scale /// e.g., "123.45" -> (12345, 2), "-0.001" -> (-1, 3) , 0e50 -> (0,50) etc /// Parse a string to decimal following Spark's behavior @@ -446,16 +480,30 @@ fn parse_string_to_decimal(input_str: &str, precision: u8, scale: i8) -> SparkRe let mut start = 0; let mut end = string_bytes.len(); - // trim whitespaces - while start < end && string_bytes[start].is_ascii_whitespace() { + // Trim ASCII whitespace and null bytes from both ends. Spark's UTF8String + // trims null bytes the same way it trims whitespace: "123\u0000" and + // "\u0000123" both parse as 123. Null bytes in the middle are not trimmed + // and will fail the digit validation in parse_decimal_str, producing NULL. + while start < end && (string_bytes[start].is_ascii_whitespace() || string_bytes[start] == 0) { start += 1; } - while end > start && string_bytes[end - 1].is_ascii_whitespace() { + while end > start && (string_bytes[end - 1].is_ascii_whitespace() || string_bytes[end - 1] == 0) + { end -= 1; } let trimmed = &input_str[start..end]; + // Normalize fullwidth digits to ASCII. 
Fast path skips the allocation for + // pure-ASCII strings, which is the common case. + let normalized; + let trimmed = if trimmed.bytes().any(|b| b > 0x7F) { + normalized = normalize_fullwidth_digits(trimmed); + normalized.as_str() + } else { + trimmed + }; + if trimmed.is_empty() { return Ok(None); } diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala index 3ebc5197cc..c9b9c433bf 100644 --- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala +++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala @@ -211,9 +211,7 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim { case DataTypes.FloatType | DataTypes.DoubleType => Compatible() case _: DecimalType => - // https://github.com/apache/datafusion-comet/issues/325 - Incompatible(Some("""Does not support fullwidth unicode digits (e.g \\uFF10) - |or strings containing null bytes (e.g \\u0000)""".stripMargin)) + Compatible() case DataTypes.DateType => // https://github.com/apache/datafusion-comet/issues/327 Compatible(Some("Only supports years between 262143 BC and 262142 AD")) diff --git a/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala b/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala index ba37f8c94b..eebe1351aa 100644 --- a/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala +++ b/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala @@ -327,7 +327,21 @@ trait ShimSparkErrorConverter { try { DataType.fromDDL(typeName) } catch { - case _: Exception => StringType + case _: Exception => + // fromDDL rejects types that are syntactically invalid in SQL DDL, such as + // DECIMAL(p,s) with a negative scale (valid when allowNegativeScaleOfDecimal=true). + // Parse those manually rather than silently falling back to StringType. 
+ if (typeName.toUpperCase.startsWith("DECIMAL(") && typeName.endsWith(")")) { + val inner = typeName.substring("DECIMAL(".length, typeName.length - 1) + val parts = inner.split(",") + if (parts.length == 2) { + try { + DataTypes.createDecimalType(parts(0).trim.toInt, parts(1).trim.toInt) + } catch { + case _: Exception => StringType + } + } else StringType + } else StringType } } } diff --git a/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala index 1d140e190f..005aa1548e 100644 --- a/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala +++ b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala @@ -323,7 +323,21 @@ trait ShimSparkErrorConverter { try { DataType.fromDDL(typeName) } catch { - case _: Exception => StringType + case _: Exception => + // fromDDL rejects types that are syntactically invalid in SQL DDL, such as + // DECIMAL(p,s) with a negative scale (valid when allowNegativeScaleOfDecimal=true). + // Parse those manually rather than silently falling back to StringType. 
+ if (typeName.toUpperCase.startsWith("DECIMAL(") && typeName.endsWith(")")) { + val inner = typeName.substring("DECIMAL(".length, typeName.length - 1) + val parts = inner.split(",") + if (parts.length == 2) { + try { + DataTypes.createDecimalType(parts(0).trim.toInt, parts(1).trim.toInt) + } catch { + case _: Exception => StringType + } + } else StringType + } else StringType } } } diff --git a/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala index a787fb8014..fc7d8277ed 100644 --- a/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala +++ b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala @@ -328,7 +328,21 @@ trait ShimSparkErrorConverter { try { DataType.fromDDL(typeName) } catch { - case _: Exception => StringType + case _: Exception => + // fromDDL rejects types that are syntactically invalid in SQL DDL, such as + // DECIMAL(p,s) with a negative scale (valid when allowNegativeScaleOfDecimal=true). + // Parse those manually rather than silently falling back to StringType. 
+ if (typeName.toUpperCase.startsWith("DECIMAL(") && typeName.endsWith(")")) { + val inner = typeName.substring("DECIMAL(".length, typeName.length - 1) + val parts = inner.split(",") + if (parts.length == 2) { + try { + DataTypes.createDecimalType(parts(0).trim.toInt, parts(1).trim.toInt) + } catch { + case _: Exception => StringType + } + } else StringType + } else StringType } } } diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index 8a71c08eb8..92af20c433 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -767,102 +767,112 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { } // This is to pass the first `all cast combinations are covered` - ignore("cast StringType to DecimalType(10,2)") { + test("cast StringType to DecimalType(10,2)") { val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a") castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = false) } - test("cast StringType to DecimalType(10,2) (does not support fullwidth unicode digits)") { - withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") { - val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a") - Seq(true, false).foreach(ansiEnabled => - castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = ansiEnabled)) - } + test("cast StringType to DecimalType(10,2) fuzz") { + val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a") + Seq(true, false).foreach(ansiEnabled => + castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = ansiEnabled)) } test("cast StringType to DecimalType(2,2)") { - withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") { - val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a") - Seq(true, false).foreach(ansiEnabled => - castTest(values, 
DataTypes.createDecimalType(2, 2), testAnsi = ansiEnabled)) - } + val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a") + Seq(true, false).foreach(ansiEnabled => + castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = ansiEnabled)) } test("cast StringType to DecimalType check if right exception message is thrown") { - withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") { - val values = Seq("d11307\n").toDF("a") - Seq(true, false).foreach(ansiEnabled => - castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = ansiEnabled)) - } + val values = Seq("d11307\n").toDF("a") + Seq(true, false).foreach(ansiEnabled => + castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = ansiEnabled)) } test("cast StringType to DecimalType(2,2) check if right exception is being thrown") { - withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") { - val values = gen.generateInts(10000).map(" " + _).toDF("a") - Seq(true, false).foreach(ansiEnabled => - castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = ansiEnabled)) - } + val values = gen.generateInts(10000).map(" " + _).toDF("a") + Seq(true, false).foreach(ansiEnabled => + castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = ansiEnabled)) } test("cast StringType to DecimalType(38,10) high precision - check 0 mantissa") { - withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") { - val values = Seq("0e31", "000e3375", "0e40", "0E+695", "0e5887677").toDF("a") - Seq(true, false).foreach(ansiEnabled => - castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = ansiEnabled)) - } + val values = Seq("0e31", "000e3375", "0e40", "0E+695", "0e5887677").toDF("a") + Seq(true, false).foreach(ansiEnabled => + castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = ansiEnabled)) } test("cast StringType to DecimalType(38,10) high precision") { - 
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") { - val values = gen.generateStrings(dataSize, numericPattern, 38).toDF("a") - Seq(true, false).foreach(ansiEnabled => - castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = ansiEnabled)) - } + val values = gen.generateStrings(dataSize, numericPattern, 38).toDF("a") + Seq(true, false).foreach(ansiEnabled => + castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = ansiEnabled)) + } + + test("cast StringType to DecimalType - null bytes and fullwidth digits") { + // Spark trims null bytes (\u0000) from both ends of a string before parsing, + // matching its whitespace-trim behavior. Null bytes in the middle produce NULL. + // Fullwidth digits (U+FF10-U+FF19) are treated as numeric equivalents to ASCII digits. + val values = Seq( + // null byte positions + "123\u0000", + "\u0000123", + "12\u00003", + "1\u00002\u00003", + "\u0000", + // null byte with decimal point + "12\u0000.45", + "12.\u000045", + // fullwidth digits (U+FF10-U+FF19) + "123.45", // "123.45" in fullwidth + "123", + "-123.45", + "+123.45", + "123.45E2", + // mixed fullwidth and ASCII + "123.45", + null).toDF("a") + castTest(values, DataTypes.createDecimalType(10, 2)) } test("cast StringType to DecimalType(10,2) basic values") { - withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") { - val values = Seq( - "123.45", - "-67.89", - "-67.89", - "-67.895", - "67.895", - "0.001", - "999.99", - "123.456", - "123.45D", - ".5", - "5.", - "+123.45", - " 123.45 ", - "inf", - "", - "abc", - null).toDF("a") - Seq(true, false).foreach(ansiEnabled => - castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = ansiEnabled)) - } + val values = Seq( + "123.45", + "-67.89", + "-67.89", + "-67.895", + "67.895", + "0.001", + "999.99", + "123.456", + "123.45D", + ".5", + "5.", + "+123.45", + " 123.45 ", + "inf", + "", + "abc", + null).toDF("a") + Seq(true, false).foreach(ansiEnabled => + 
castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = ansiEnabled)) } test("cast StringType to Decimal type scientific notation") { - withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") { - val values = Seq( - "1.23E-5", - "1.23e10", - "1.23E+10", - "-1.23e-5", - "1e5", - "1E-2", - "-1.5e3", - "1.23E0", - "0e0", - "1.23e", - "e5", - null).toDF("a") - Seq(true, false).foreach(ansiEnabled => - castTest(values, DataTypes.createDecimalType(23, 8), testAnsi = ansiEnabled)) - } + val values = Seq( + "1.23E-5", + "1.23e10", + "1.23E+10", + "-1.23e-5", + "1e5", + "1E-2", + "-1.5e3", + "1.23E0", + "0e0", + "1.23e", + "e5", + null).toDF("a") + Seq(true, false).foreach(ansiEnabled => + castTest(values, DataTypes.createDecimalType(23, 8), testAnsi = ansiEnabled)) } test("cast StringType to BinaryType") { @@ -1310,6 +1320,19 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { castTest(generateDecimalsPrecision10Scale2(), DataTypes.createDecimalType(10, 4)) } + test("cast StringType to DecimalType with negative scale (allowNegativeScaleOfDecimal)") { + // With allowNegativeScaleOfDecimal=true, Spark allows DECIMAL(p, s) where s < 0. + // The value is rounded to the nearest 10^|s| — e.g. DECIMAL(10,-4) rounds to + // the nearest 10000. This requires the legacy SQL parser config to be enabled. + withSQLConf("spark.sql.legacy.allowNegativeScaleOfDecimal" -> "true") { + val values = + Seq("12500", "15000", "99990000", "-12500", "0", "0.001", "abc", null).toDF("a") + // testTry=false: try_cast uses SQL string interpolation (toType.sql → "DECIMAL(10,-4)") + // which the SQL parser rejects regardless of allowNegativeScaleOfDecimal. + castTest(values, DataTypes.createDecimalType(10, -4), testTry = false) + } + } + test("cast between decimals with negative precision") { // cast to negative scale checkSparkAnswerMaybeThrows(