diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala index b7c335c6cfcfe..9f606b698d30c 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala @@ -193,6 +193,11 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends return } + // SPARK-53339: Post the Started event here, right after the CAS succeeds, to ensure that + // postStarted() is never called when interrupt() has already transitioned the state to + // interrupted. This eliminates the race between postStarted() and interrupt(). + executeHolder.eventsManager.postStarted() + // `withSession` ensures that session-specific artifacts (such as JARs and class files) are // available during processing. executeHolder.sessionHolder.withSession { session => diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala index fcf01d5d29ab3..286163e135d2a 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala @@ -188,9 +188,7 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) { request.getPlan.getOpTypeCase match { case proto.Plan.OpTypeCase.COMMAND => request.getPlan.getCommand case proto.Plan.OpTypeCase.ROOT => request.getPlan.getRoot - case _ => - throw new UnsupportedOperationException( - s"${request.getPlan.getOpTypeCase} not supported.") + case _ => request.getPlan } val event = SparkListenerConnectOperationStarted( @@ -248,8 +246,11 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) { * Post @link org.apache.spark.sql.connect.service.SparkListenerConnectOperationCanceled. */ def postCanceled(): Unit = { + // SPARK-53339: Pending is included to handle the case where interrupt() is called before + // postStarted() transitions the status from Pending to Started. assertStatus( List( + ExecuteStatus.Pending, ExecuteStatus.Started, ExecuteStatus.Analyzed, ExecuteStatus.ReadyForExecution, @@ -269,8 +270,11 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) { * The message of the error thrown during the request. */ def postFailed(errorMessage: String): Unit = { + // SPARK-53339: Pending is included to handle the case where postStarted() itself throws + // an exception (e.g., session state check failure) before transitioning from Pending. assertStatus( List( + ExecuteStatus.Pending, ExecuteStatus.Started, ExecuteStatus.Analyzed, ExecuteStatus.ReadyForExecution, diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala index bb51438ce90f6..2a936b526d96a 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS import org.apache.spark.sql.connect.IllegalStateErrors import org.apache.spark.sql.connect.config.Connect.{CONNECT_EXECUTE_MANAGER_ABANDONED_TOMBSTONES_SIZE, CONNECT_EXECUTE_MANAGER_DETACHED_TIMEOUT, CONNECT_EXECUTE_MANAGER_MAINTENANCE_INTERVAL} import org.apache.spark.sql.connect.execution.ExecuteGrpcResponseSender +import org.apache.spark.sql.connect.planner.InvalidInputErrors import org.apache.spark.util.ThreadUtils // Unique key identifying execution by combination of user, session and operation id @@ -191,7 +192,16 @@ private[connect] class SparkConnectExecutionManager() extends Logging { responseObserver: StreamObserver[proto.ExecutePlanResponse]): ExecuteHolder = { val executeHolder = createExecuteHolder(executeKey, request, sessionHolder) try { - executeHolder.eventsManager.postStarted() + // SPARK-53339: Validate the plan before starting the execution thread. + // postStarted() was moved into executeInternal(), so invalid plans that previously + // caused postStarted() to throw (and thus triggered removeExecuteHolder in this + // catch block) now fail asynchronously inside the execution thread. This early + // validation ensures that invalid plans are still caught synchronously here. + request.getPlan.getOpTypeCase match { + case proto.Plan.OpTypeCase.ROOT | proto.Plan.OpTypeCase.COMMAND => // valid + case other => + throw InvalidInputErrors.invalidOneOfField(other, request.getPlan.getDescriptorForType) + } executeHolder.start() } catch { // Errors raised before the execution holder has finished spawning a thread are considered diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala index a96d0ab977c5c..f5349a48330c3 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala @@ -138,6 +138,36 @@ class ExecuteEventsManagerSuite .isInstanceOf[SparkListenerConnectOperationCanceled]) } + test("SPARK-53339: post canceled from Pending state") { + val events = setupEvents(ExecuteStatus.Pending) + events.postCanceled() + assert(events.status == ExecuteStatus.Canceled) + assert(events.terminationReason.contains(TerminationReason.Canceled)) + } + + test("SPARK-53339: post failed from Pending state") { + val events = setupEvents(ExecuteStatus.Pending) + events.postFailed(DEFAULT_ERROR) + assert(events.status == ExecuteStatus.Failed) + assert(events.terminationReason.contains(TerminationReason.Failed)) + } + + test("SPARK-53339: Pending to Canceled to Closed transition") { + val events = setupEvents(ExecuteStatus.Pending) + events.postCanceled() + events.postClosed() + assert(events.status == ExecuteStatus.Closed) + assert(events.terminationReason.contains(TerminationReason.Canceled)) + } + + test("SPARK-53339: Pending to Failed to Closed transition") { + val events = setupEvents(ExecuteStatus.Pending) + events.postFailed(DEFAULT_ERROR) + events.postClosed() + assert(events.status == ExecuteStatus.Closed) + assert(events.terminationReason.contains(TerminationReason.Failed)) + } + test("SPARK-43923: post failed") { val events = setupEvents(ExecuteStatus.Started) events.postFailed(DEFAULT_ERROR)