From b27789d86a4eeb205e256998fb460a0d164a7399 Mon Sep 17 00:00:00 2001
From: Parth Chandra
Date: Fri, 10 Apr 2026 15:31:19 -0700
Subject: [PATCH] fix: checkSparkAnswer displays incorrect labels

---
 .../apache/comet/CometExpressionSuite.scala  | 13 +++++
 .../org/apache/spark/sql/CometTestBase.scala | 55 ++++++++++++++++++-
 2 files changed, 67 insertions(+), 1 deletion(-)

diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index b64bce6a79..2f2b25cc7d 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -58,6 +58,19 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   val DIVIDE_BY_ZERO_EXCEPTION_MSG =
     """Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead"""
 
+  // Verifies that a checkCometAnswer mismatch labels the expected rows as the Spark answer and
+  // the rows collected from the DataFrame under test as the Comet answer.
+  test("checkCometAnswer labels Spark and Comet answers on mismatch") {
+    val cometDf = Seq((1, "apple"), (2, "banana"), (3, "cherry")).toDF("id", "fruit")
+    val sparkAnswer = Seq(Row(1, "apple"), Row(2, "BANANA"), Row(3, "cherry"))
+    val failure = intercept[org.scalatest.exceptions.TestFailedException] {
+      checkCometAnswer(cometDf, sparkAnswer)
+    }
+    val message = failure.getMessage
+    assert(message.contains("== Spark Answer - 3 =="))
+    assert(message.contains("== Comet Answer - 3 =="))
+  }
+
   test("sort floating point with negative zero") {
     val schema = StructType(
       Seq(
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index afee83aa63..c91099c9e0 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -37,6 +37,8 @@ import org.apache.parquet.hadoop.example.{ExampleParquetWriter, GroupWriteSuppor
 import org.apache.parquet.schema.{MessageType, MessageTypeParser}
 import org.apache.spark._
 import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE, SHUFFLE_MANAGER}
+import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.catalyst.util.sideBySide
 import org.apache.spark.sql.comet.CometPlanChecker
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
 import org.apache.spark.sql.execution._
@@ -128,7 +130,7 @@ abstract class CometTestBase
     if (withTol.isDefined) {
       checkAnswerWithTolerance(dfComet, expected, withTol.get)
     } else {
-      checkAnswer(dfComet, expected)
+      checkCometAnswer(dfComet, expected)
     }
 
     if (assertCometNative) {
@@ -358,6 +360,57 @@ abstract class CometTestBase
     }
   }
 
+  /**
+   * Compares the Comet DataFrame result against the expected Spark answer, using labels that
+   * correctly identify which side is Comet and which is Spark. This avoids the misleading "Spark
+   * Answer" label that Spark's built-in `checkAnswer` would apply to the Comet result.
+   *
+   * @param cometDf
+   *   the DataFrame under test; it is collected by this method
+   * @param sparkAnswer
+   *   the expected rows, reported under the "Spark Answer" label
+   */
+  protected def checkCometAnswer(cometDf: DataFrame, sparkAnswer: Seq[Row]): Unit = {
+    // True when the logical plan contains a Sort node; forwarded to QueryTest.prepareAnswer.
+    val isSorted = cometDf.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
+    val cometAnswer =
+      try cometDf.collect().toSeq
+      catch {
+        case e: Exception =>
+          // Pass the exception as the cause so the test report keeps the original failure.
+          fail(
+            s"""Exception thrown while executing query in Comet:
+               |${cometDf.queryExecution}
+               |== Exception ==
+               |$e
+               |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
+          """.stripMargin,
+            e)
+      }
+    // Prepare each side once; the same rows feed both the comparison and the failure report.
+    val preparedSpark = QueryTest.prepareAnswer(sparkAnswer, isSorted)
+    val preparedComet = QueryTest.prepareAnswer(cometAnswer, isSorted)
+    if (!QueryTest.compare(preparedSpark, preparedComet)) {
+      // Schema of the first row, or "struct<>" when there is no row or the row has no schema.
+      def rowType(rows: Seq[Row]): String =
+        rows.headOption.flatMap(r => Option(r.schema)).fold("struct<>")(_.catalogString)
+      val sparkColumn =
+        s"== Spark Answer - ${sparkAnswer.size} ==" +: rowType(sparkAnswer) +:
+          preparedSpark.map(_.toString())
+      val cometColumn =
+        s"== Comet Answer - ${cometAnswer.size} ==" +: rowType(cometAnswer) +:
+          preparedComet.map(_.toString())
+      fail(s"""Results do not match for query:
+           |Timezone: ${java.util.TimeZone.getDefault}
+           |Timezone Env: ${sys.env.getOrElse("TZ", "")}
+           |
+           |${cometDf.queryExecution}
+           |== Results ==
+           |${sideBySide(sparkColumn, cometColumn).mkString("\n")}
+        """.stripMargin)
+    }
+  }
+
   /**
    * A helper function for comparing Comet DataFrame with Spark result using absolute tolerance.
    */