From fe817fb1bf9461b52626b17ccd724294516c0008 Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Wed, 17 Jun 2026 16:01:45 +0200 Subject: [PATCH 1/2] Add signals API. Also for Java API make sure we run Serdes in the user handler thread, and not on the state machine thread. --- .../dev/restate/sdk/kotlin/ContextImpl.kt | 12 +++- .../main/kotlin/dev/restate/sdk/kotlin/api.kt | 72 +++++++++++++++++++ .../kotlin/dev/restate/sdk/kotlin/futures.kt | 28 +++++++- .../main/java/dev/restate/sdk/Context.java | 30 ++++++++ .../java/dev/restate/sdk/ContextImpl.java | 38 +++++++++- .../dev/restate/sdk/InvocationHandle.java | 8 +++ .../java/dev/restate/sdk/SignalHandle.java | 48 +++++++++++++ .../endpoint/definition/HandlerContext.java | 11 +++ .../restate/sdk/core/HandlerContextImpl.java | 22 ++++++ .../dev/restate/sdk/fake/FakeContext.java | 5 ++ .../restate/sdk/fake/FakeHandlerContext.java | 19 +++++ 11 files changed, 287 insertions(+), 6 deletions(-) create mode 100644 sdk-api/src/main/java/dev/restate/sdk/SignalHandle.java diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt index a74a86a7d..09ae5cd38 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt @@ -125,7 +125,7 @@ internal constructor( ) .await() - object : BaseInvocationHandle(handlerContext, responseSerde) { + object : BaseInvocationHandle(this, responseSerde) { override suspend fun invocationId(): String = invocationIdAsyncResult.poll().await() } } @@ -136,7 +136,7 @@ internal constructor( responseTypeTag: TypeTag, ): InvocationHandle = resolveSerde(responseTypeTag).let { responseSerde -> - object : BaseInvocationHandle(handlerContext, responseSerde) { + object : BaseInvocationHandle(this, responseSerde) { override suspend fun invocationId(): String = invocationId } } @@ -200,6 +200,14 @@ internal constructor( return AwakeableHandleImpl(this, id) } + override suspend fun signal(name: String, typeTag: TypeTag): DurableFuture { + checkNotInsideRun() + val serde: Serde = resolveSerde(typeTag) + return SingleDurableFutureImpl(handlerContext.signal(name).await()).simpleMap { + serde.deserialize(it) + } + } + override fun random(): RestateRandom { return this.random } diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt index 45b8830f3..020f38559 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt @@ -201,6 +201,21 @@ sealed interface Context { */ fun awakeableHandle(id: String): AwakeableHandle + /** + * Create a [DurableFuture] waiting on a named signal targeting the current invocation. + * + * Signals are identified by `(invocationId, name)`. The resolution can arrive before or after the + * handler starts waiting on the signal — there's no need to pre-register. + * + * Another invocation can resolve or reject the signal using [signalHandle]. + * + * @param name the signal name. + * @param typeTag the response type tag to use for deserializing the signal result. + * @return a [DurableFuture] that resolves to the signal value (or rejects with a + * [dev.restate.sdk.common.TerminalException]). + */ + suspend fun signal(name: String, typeTag: TypeTag): DurableFuture + /** * Create a [RestateRandom] instance inherently predictable, seeded on the * [dev.restate.sdk.common.InvocationId], which is not secret. @@ -336,6 +351,15 @@ suspend inline fun Context.awakeable(): Awakeable { return this.awakeable(typeTag()) } +/** + * Create a [DurableFuture] waiting on a named signal targeting the current invocation. + * + * @see Context.signal + */ +suspend inline fun Context.signal(name: String): DurableFuture { + return this.signal(name, typeTag()) +} + /** * This interface can be used only within shared handlers of virtual objects. It extends [Context] * adding access to the virtual object instance key-value state storage. @@ -629,6 +653,14 @@ sealed interface InvocationHandle { /** @return the output of this invocation, if present. */ suspend fun output(): Output + + /** + * Get a [SignalHandle] for resolving or rejecting a named signal on this invocation. The + * receiving handler can await on the signal using [Context.signal]. + * + * @param name the signal name. + */ + suspend fun signal(name: String): SignalHandle } /** @@ -677,6 +709,35 @@ suspend inline fun AwakeableHandle.resolve(payload: T) { return this.resolve(typeTag(), payload) } +/** + * Handle to resolve or reject a named signal on a target invocation. + * + * Unlike awakeables, signals are identified by `(invocationId, name)` and do not need to be + * pre-registered: the resolution can arrive before or after the handler starts waiting. + */ +sealed interface SignalHandle { + /** + * Resolve the signal with the given value. + * + * @param typeTag used to serialize the result payload. + * @param payload the result payload. + */ + suspend fun resolve(typeTag: TypeTag, payload: T) + + /** + * Reject the signal with the given reason. The handler awaiting the signal will receive a + * terminal error with [reason] as the message. + * + * @param reason the rejection reason. + */ + suspend fun reject(reason: String) +} + +/** Resolve the signal with the given value. */ +suspend inline fun SignalHandle.resolve(payload: T) { + return this.resolve(typeTag(), payload) +} + /** * A [DurablePromise] is a durable, distributed version of a Kotlin's Deferred, or more commonly of * a future/promise. Restate keeps track of the [DurablePromise] across restarts/failures. @@ -965,6 +1026,17 @@ suspend fun awakeableHandle(id: String): AwakeableHandle { return context().awakeableHandle(id) } +/** + * Create a [DurableFuture] waiting on a named signal targeting the current invocation. + * + * @throws IllegalStateException if called outside of a Restate handler + * @see Context.signal + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend inline fun signal(name: String): DurableFuture { + return context().signal(name, typeTag()) +} + /** * Get an [InvocationHandle] for an already existing invocation. * diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/futures.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/futures.kt index 2d05e48c1..32496189e 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/futures.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/futures.kt @@ -191,9 +191,12 @@ internal constructor( internal abstract class BaseInvocationHandle internal constructor( - private val handlerContext: HandlerContext, + private val contextImpl: ContextImpl, private val responseSerde: Serde, ) : InvocationHandle { + private val handlerContext: HandlerContext + get() = contextImpl.handlerContext + override suspend fun cancel() { checkNotInsideRun() val ignored = handlerContext.cancelInvocation(invocationId()).await() @@ -214,6 +217,11 @@ internal constructor( .simpleMap { it.map { responseSerde.deserialize(it) } } .await() } + + override suspend fun signal(name: String): SignalHandle { + val resolvedId = invocationId() + return SignalHandleImpl(contextImpl, resolvedId, name) + } } internal class AwakeableImpl @@ -237,6 +245,24 @@ internal class AwakeableHandleImpl(val contextImpl: ContextImpl, val id: String) } } +internal class SignalHandleImpl( + val contextImpl: ContextImpl, + val invocationId: String, + val name: String, +) : SignalHandle { + override suspend fun resolve(typeTag: TypeTag, payload: T) { + checkNotInsideRun() + contextImpl.handlerContext + .resolveSignal(invocationId, name, contextImpl.resolveAndSerialize(typeTag, payload)) + .await() + } + + override suspend fun reject(reason: String) { + checkNotInsideRun() + contextImpl.handlerContext.rejectSignal(invocationId, name, TerminalException(reason)).await() + } +} + internal class SelectClauseImpl(override val durableFuture: DurableFuture) : SelectClause @PublishedApi diff --git a/sdk-api/src/main/java/dev/restate/sdk/Context.java b/sdk-api/src/main/java/dev/restate/sdk/Context.java index a31338e3a..1ff5fc1b6 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/Context.java +++ b/sdk-api/src/main/java/dev/restate/sdk/Context.java @@ -478,6 +478,36 @@ default Awakeable awakeable(Class clazz) { */ AwakeableHandle awakeableHandle(String id); + /** + * Create a {@link DurableFuture} waiting on a named signal targeting the current invocation. + * + *

Signals are identified by {@code (invocationId, name)}. The resolution can arrive before or + * after the handler starts waiting on the signal — there's no need to pre-register. + * + *

Another invocation can resolve or reject the signal using {@link + * SignalHandle#resolve(TypeTag, Object)} / {@link SignalHandle#reject(String)}. + * + * @param name the signal name. + * @param clazz the response type to use for deserializing the signal result. When using generic + * types, use {@link #signal(String, TypeTag)} instead. + * @return a {@link DurableFuture} that resolves to the signal value (or rejects with a {@link + * TerminalException}). + */ + default DurableFuture signal(String name, Class clazz) { + return signal(name, TypeTag.of(clazz)); + } + + /** + * Create a {@link DurableFuture} waiting on a named signal targeting the current invocation. + * + * @param name the signal name. + * @param typeTag the response type tag to use for deserializing the signal result. + * @return a {@link DurableFuture} that resolves to the signal value (or rejects with a {@link + * TerminalException}). + * @see #signal(String, Class) + */ + DurableFuture signal(String name, TypeTag typeTag); + /** * Returns a deterministic random. * diff --git a/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java b/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java index 9d1b5df5d..9a0259afa 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java +++ b/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java @@ -82,7 +82,7 @@ public Optional get(StateKey key) { checkNotInsideRun(); return DurableFuture.fromAsyncResult( Util.awaitCompletableFuture(handlerContext.get(key.name())), serviceExecutor) - .mapWithoutExecutor(opt -> opt.map(serdeFactory.create(key.serdeInfo())::deserialize)) + .map(opt -> opt.map(serdeFactory.create(key.serdeInfo())::deserialize)) .await(); } @@ -227,6 +227,30 @@ public Output getOutput() { serviceExecutor) .await(); } + + @Override + public SignalHandle signal(String name) { + String invocationId = invocationId(); + return new SignalHandle() { + @Override + public void resolve(TypeTag typeTag, T payload) { + checkNotInsideRun(); + Util.awaitCompletableFuture( + handlerContext.resolveSignal( + invocationId, + name, + Util.executeOrFail( + handlerContext, serdeFactory.create(typeTag)::serialize, payload))); + } + + @Override + public void reject(String reason) { + checkNotInsideRun(); + Util.awaitCompletableFuture( + handlerContext.rejectSignal(invocationId, name, new TerminalException(reason))); + } + }; + } } @Override @@ -249,7 +273,7 @@ public DurableFuture runAsync( return DurableFuture.fromAsyncResult( Util.awaitCompletableFuture(handlerContext.submitRun(name, runClosure)), serviceExecutor) - .mapWithoutExecutor(serde::deserialize); + .map(serde::deserialize); } private void executeRunAction( @@ -325,6 +349,14 @@ public void reject(String reason) { }; } + @Override + public DurableFuture signal(String name, TypeTag typeTag) throws TerminalException { + checkNotInsideRun(); + Serde serde = serdeFactory.create(typeTag); + AsyncResult result = Util.awaitCompletableFuture(handlerContext.signal(name)); + return DurableFuture.fromAsyncResult(result, serviceExecutor).map(serde::deserialize); + } + @Override public RestateRandom random() { return this.random; @@ -338,7 +370,7 @@ public DurableFuture future() { checkNotInsideRun(); AsyncResult result = Util.awaitCompletableFuture(handlerContext.promise(key.name())); return DurableFuture.fromAsyncResult(result, serviceExecutor) - .mapWithoutExecutor(serdeFactory.create(key.serdeInfo())::deserialize); + .map(serdeFactory.create(key.serdeInfo())::deserialize); } @Override diff --git a/sdk-api/src/main/java/dev/restate/sdk/InvocationHandle.java b/sdk-api/src/main/java/dev/restate/sdk/InvocationHandle.java index 4bb4636f1..4a68d0a93 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/InvocationHandle.java +++ b/sdk-api/src/main/java/dev/restate/sdk/InvocationHandle.java @@ -31,4 +31,12 @@ public interface InvocationHandle { * @return the output of this invocation, if present. */ Output getOutput(); + + /** + * Get a {@link SignalHandle} for resolving or rejecting a named signal on this invocation. The + * receiving handler can await on the signal using {@link Context#signal(String, Class)}. + * + * @param name the signal name. + */ + SignalHandle signal(String name); } diff --git a/sdk-api/src/main/java/dev/restate/sdk/SignalHandle.java b/sdk-api/src/main/java/dev/restate/sdk/SignalHandle.java new file mode 100644 index 000000000..fb5886729 --- /dev/null +++ b/sdk-api/src/main/java/dev/restate/sdk/SignalHandle.java @@ -0,0 +1,48 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk; + +import dev.restate.serde.TypeTag; + +/** + * Handle to resolve or reject a named signal on a target invocation. Acquired via {@link + * InvocationHandle#signal(String)}. + * + *

Unlike awakeables, signals are identified by {@code (invocationId, name)} and do not need to + * be pre-registered: the resolution can arrive before or after the handler starts waiting on the + * signal. + */ +public interface SignalHandle { + + /** + * Resolve the signal with the given value. + * + * @param typeTag used to serialize the result payload. + * @param payload the result payload. MUST NOT be null. + */ + void resolve(TypeTag typeTag, T payload); + + /** + * Resolve the signal with the given value. + * + * @param clazz used to serialize the result payload. + * @param payload the result payload. MUST NOT be null. + */ + default void resolve(Class clazz, T payload) { + resolve(TypeTag.of(clazz), payload); + } + + /** + * Reject the signal with the given reason. The handler awaiting the signal will receive a + * terminal error with {@code reason} as the message. + * + * @param reason the rejection reason. MUST NOT be null. + */ + void reject(String reason); +} diff --git a/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerContext.java b/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerContext.java index 8fd10b5c4..61443a919 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerContext.java +++ b/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerContext.java @@ -106,6 +106,17 @@ record Awakeable(String id, AsyncResult asyncResult) {} CompletableFuture> rejectPromise(String key, TerminalException reason); + // ----- Named signals + // + // Signals are identified by (invocationId, name). Unlike awakeables, signals do not need to be + // pre-registered: the resolution can arrive before or after the handler starts waiting. + + CompletableFuture> signal(String name); + + CompletableFuture resolveSignal(String invocationId, String name, Slice payload); + + CompletableFuture rejectSignal(String invocationId, String name, TerminalException reason); + CompletableFuture cancelInvocation(String invocationId); CompletableFuture> attachInvocation(String invocationId); diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java b/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java index a2938b55f..980b64d78 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java @@ -346,6 +346,28 @@ public CompletableFuture> rejectPromise(String key, TerminalEx HandlerContextImpl::parseEmptyOrFailure)); } + @Override + public CompletableFuture> signal(String name) { + return catchExceptions( + () -> + AsyncResults.single( + this, + this.stateMachine.createSignalHandle(name), + HandlerContextImpl::parseSuccessOrFailure)); + } + + @Override + public CompletableFuture resolveSignal(String invocationId, String name, Slice payload) { + return this.catchExceptions( + () -> this.stateMachine.completeSignal(invocationId, name, payload)); + } + + @Override + public CompletableFuture rejectSignal( + String invocationId, String name, TerminalException reason) { + return this.catchExceptions(() -> this.stateMachine.completeSignal(invocationId, name, reason)); + } + @Override public CompletableFuture cancelInvocation(String invocationId) { return this.catchExceptions(() -> this.stateMachine.cancelInvocation(invocationId)); diff --git a/sdk-fake-api/src/main/java/dev/restate/sdk/fake/FakeContext.java b/sdk-fake-api/src/main/java/dev/restate/sdk/fake/FakeContext.java index 9ca13e61c..751a9155a 100644 --- a/sdk-fake-api/src/main/java/dev/restate/sdk/fake/FakeContext.java +++ b/sdk-fake-api/src/main/java/dev/restate/sdk/fake/FakeContext.java @@ -105,6 +105,11 @@ public AwakeableHandle awakeableHandle(String s) { return inner.awakeableHandle(s); } + @Override + public DurableFuture signal(String name, TypeTag typeTag) { + return inner.signal(name, typeTag); + } + @Override public RestateRandom random() { return inner.random(); diff --git a/sdk-fake-api/src/main/java/dev/restate/sdk/fake/FakeHandlerContext.java b/sdk-fake-api/src/main/java/dev/restate/sdk/fake/FakeHandlerContext.java index 8459d10c8..ee525e39f 100644 --- a/sdk-fake-api/src/main/java/dev/restate/sdk/fake/FakeHandlerContext.java +++ b/sdk-fake-api/src/main/java/dev/restate/sdk/fake/FakeHandlerContext.java @@ -255,6 +255,25 @@ public CompletableFuture> rejectPromise(String s, TerminalExce "FakeHandlerContext doesn't currently support mocking this operation"); } + @Override + public CompletableFuture> signal(String name) { + throw new UnsupportedOperationException( + "FakeHandlerContext doesn't currently support mocking this operation"); + } + + @Override + public CompletableFuture resolveSignal(String invocationId, String name, Slice payload) { + throw new UnsupportedOperationException( + "FakeHandlerContext doesn't currently support mocking this operation"); + } + + @Override + public CompletableFuture rejectSignal( + String invocationId, String name, TerminalException reason) { + throw new UnsupportedOperationException( + "FakeHandlerContext doesn't currently support mocking this operation"); + } + @Override public CompletableFuture cancelInvocation(String s) { throw new UnsupportedOperationException( From 9ce0aa55b344a9c7bd42bcb8f14fa85b4438bed7 Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Thu, 18 Jun 2026 11:58:43 +0200 Subject: [PATCH 2/2] Fix all the story around thread trampolining, issues and co. Now the mapped promise will try to resolve in whatever thread they are, and just trampoline back to the core executor when needed (for fail) --- .../dev/restate/sdk/core/AsyncResults.java | 18 ++-- .../dev/restate/sdk/core/ExceptionUtils.java | 20 +++- .../restate/sdk/core/HandlerContextImpl.java | 6 -- .../sdk/core/HandlerContextInternal.java | 3 - .../sdk/core/RequestProcessorImpl.java | 2 +- .../restate/sdk/core/MockRequestResponse.java | 9 +- .../sdk/core/javaapi/JavaAPITests.java | 3 +- .../javaapi/SerdeThreadTrampoliningTest.java | 97 +++++++++++++++++++ 8 files changed, 135 insertions(+), 23 deletions(-) create mode 100644 sdk-core/src/test/java/dev/restate/sdk/core/javaapi/SerdeThreadTrampoliningTest.java diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/AsyncResults.java b/sdk-core/src/main/java/dev/restate/sdk/core/AsyncResults.java index 4f128cca6..a24db3b24 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/AsyncResults.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/AsyncResults.java @@ -190,24 +190,23 @@ private static CompletableFuture compose( try { failureMapper .apply((TerminalException) throwable) - .whenCompleteAsync( + .whenComplete( (u, mapperT) -> { if (ExceptionUtils.isTerminalException(mapperT)) { downstreamFuture.completeExceptionally(mapperT); } else if (mapperT != null) { - ctx.failWithoutContextSwitch(mapperT); + ctx.fail(ExceptionUtils.unwrapCompletionException(mapperT)); downstreamFuture.completeExceptionally( AbortedExecutionException.INSTANCE); } else { downstreamFuture.complete(u); } - }, - ctx.stateMachineExecutor()); + }); } catch (Throwable mapperT) { if (ExceptionUtils.isTerminalException(mapperT)) { downstreamFuture.completeExceptionally(mapperT); } else { - ctx.failWithoutContextSwitch(mapperT); + ctx.fail(ExceptionUtils.unwrapCompletionException(mapperT)); downstreamFuture.completeExceptionally(AbortedExecutionException.INSTANCE); } } @@ -223,24 +222,23 @@ private static CompletableFuture compose( try { successMapper .apply(t) - .whenCompleteAsync( + .whenComplete( (u, mapperT) -> { if (ExceptionUtils.isTerminalException(mapperT)) { downstreamFuture.completeExceptionally(mapperT); } else if (mapperT != null) { - ctx.failWithoutContextSwitch(mapperT); + ctx.fail(ExceptionUtils.unwrapCompletionException(mapperT)); downstreamFuture.completeExceptionally( AbortedExecutionException.INSTANCE); } else { downstreamFuture.complete(u); } - }, - ctx.stateMachineExecutor()); + }); } catch (Throwable mapperT) { if (ExceptionUtils.isTerminalException(mapperT)) { downstreamFuture.completeExceptionally(mapperT); } else { - ctx.failWithoutContextSwitch(mapperT); + ctx.fail(ExceptionUtils.unwrapCompletionException(mapperT)); downstreamFuture.completeExceptionally(AbortedExecutionException.INSTANCE); } } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/ExceptionUtils.java b/sdk-core/src/main/java/dev/restate/sdk/core/ExceptionUtils.java index 92a3593b0..4bcdd7566 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/ExceptionUtils.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/ExceptionUtils.java @@ -11,11 +11,29 @@ import dev.restate.sdk.common.AbortedExecutionException; import dev.restate.sdk.common.TerminalException; import java.util.Optional; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ExecutionException; import java.util.function.Predicate; public final class ExceptionUtils { private ExceptionUtils() {} + /** + * Unwrap the {@link CompletionException}/{@link ExecutionException} wrappers introduced by the + * {@link java.util.concurrent.CompletableFuture} machinery, returning the underlying cause. The + * reported error message and stacktrace should reflect the user-thrown exception, not the + * executor plumbing. + */ + public static Throwable unwrapCompletionException(Throwable throwable) { + Throwable current = throwable; + while ((current instanceof CompletionException || current instanceof ExecutionException) + && current.getCause() != null + && current.getCause() != current) { + current = current.getCause(); + } + return current; + } + @SuppressWarnings("unchecked") public static void sneakyThrow(Throwable e) throws E { throw (E) e; @@ -53,7 +71,7 @@ public static Optional findProtocolException(Throwable throwa return findCause(throwable, t -> t instanceof ProtocolException); } - public static boolean containsSuspendedException(Throwable throwable) { + public static boolean containsAbortedExecutionException(Throwable throwable) { return findCause(throwable, t -> t == AbortedExecutionException.INSTANCE).isPresent(); } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java b/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java index 980b64d78..32a407a44 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java @@ -27,7 +27,6 @@ import java.time.Instant; import java.util.*; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.Executor; import java.util.function.Consumer; import java.util.stream.Stream; import org.apache.logging.log4j.LogManager; @@ -154,11 +153,6 @@ public InvocationState getInvocationState() { return this.stateMachine.state(); } - @Override - public Executor stateMachineExecutor() { - return Runnable::run; - } - @Override public CompletableFuture>> get(String name) { return catchExceptions( diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextInternal.java b/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextInternal.java index b89b18c9c..d6a9bfe8c 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextInternal.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextInternal.java @@ -16,7 +16,6 @@ import dev.restate.sdk.endpoint.definition.HandlerContext; import java.time.Duration; import java.util.List; -import java.util.concurrent.Executor; import java.util.stream.Collectors; import org.jspecify.annotations.Nullable; @@ -59,7 +58,5 @@ void proposeRunFailure( InvocationState getInvocationState(); - Executor stateMachineExecutor(); - void failWithoutContextSwitch(Throwable throwable); } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/RequestProcessorImpl.java b/sdk-core/src/main/java/dev/restate/sdk/core/RequestProcessorImpl.java index 15f540894..a35a60475 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/RequestProcessorImpl.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/RequestProcessorImpl.java @@ -200,7 +200,7 @@ private CompletableFuture writeOutputAndEnd( private CompletableFuture end( HandlerContextInternal contextInternal, @Nullable Throwable exception) { - if (exception == null || ExceptionUtils.containsSuspendedException(exception)) { + if (exception == null || ExceptionUtils.containsAbortedExecutionException(exception)) { contextInternal.close(); } else if (contextInternal.getInvocationState() != InvocationState.CLOSED) { if (ExceptionUtils.isTerminalException(exception)) { diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/MockRequestResponse.java b/sdk-core/src/test/java/dev/restate/sdk/core/MockRequestResponse.java index a5488af1f..250a0a58f 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/MockRequestResponse.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/MockRequestResponse.java @@ -41,7 +41,14 @@ public boolean buffered() { @Override public void executeTest(TestDefinition definition) { - Executor syscallsExecutor = Executors.newSingleThreadExecutor(); + Executor syscallsExecutor = + Executors.newSingleThreadExecutor( + runnable -> { + Thread t = new Thread(runnable, "coreExecutor"); + if (t.isDaemon()) t.setDaemon(false); + if (t.getPriority() != Thread.NORM_PRIORITY) t.setPriority(Thread.NORM_PRIORITY); + return t; + }); ServiceDefinition serviceDefinition = definition.getServiceDefinition(); diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/JavaAPITests.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/JavaAPITests.java index 1d6ef41c0..f326f70da 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/JavaAPITests.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/JavaAPITests.java @@ -51,7 +51,8 @@ public Stream definitions() { new UserFailuresTest(), new RandomTest(), new CodegenTest(), - new ReflectionTest()); + new ReflectionTest(), + new SerdeThreadTrampoliningTest()); } public static TestInvocationBuilder testDefinitionForService( diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/SerdeThreadTrampoliningTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/SerdeThreadTrampoliningTest.java new file mode 100644 index 000000000..6e98ec1f8 --- /dev/null +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/SerdeThreadTrampoliningTest.java @@ -0,0 +1,97 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.javaapi; + +import static dev.restate.sdk.core.statemachine.ProtoUtils.*; +import static dev.restate.sdk.core.statemachine.ProtoUtils.END_MESSAGE; +import static org.assertj.core.api.Assertions.assertThat; + +import dev.restate.common.Slice; +import dev.restate.sdk.HandlerRunner; +import dev.restate.sdk.ObjectContext; +import dev.restate.sdk.common.StateKey; +import dev.restate.sdk.core.TestDefinitions; +import dev.restate.sdk.core.TestSerdes; +import dev.restate.sdk.endpoint.definition.HandlerDefinition; +import dev.restate.sdk.endpoint.definition.HandlerType; +import dev.restate.sdk.endpoint.definition.ServiceDefinition; +import dev.restate.sdk.endpoint.definition.ServiceType; +import dev.restate.serde.Serde; +import dev.restate.serde.jackson.JacksonSerdeFactory; +import java.util.List; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.stream.Stream; +import org.jspecify.annotations.NonNull; + +public class SerdeThreadTrampoliningTest implements TestDefinitions.TestSuite { + + private static final ThreadLocal THREAD_LOCAL = new ThreadLocal<>(); + + private static final StateKey STATE = + StateKey.of( + "STATE", + new Serde() { + @Override + public Slice serialize(String value) { + throw new IllegalStateException("Unexpected call to serialize"); + } + + @Override + public String deserialize(@NonNull Slice value) { + assertThreadLocal(); + return TestSerdes.STRING.deserialize(value); + } + }); + + private static void assertThreadLocal() { + assertThat(THREAD_LOCAL.get()).isEqualTo("UserThread"); + } + + @Override + public Stream definitions() { + Executor executor = Executors.newCachedThreadPool(); + Executor wrappedExecutor = + runnable -> + executor.execute( + () -> { + THREAD_LOCAL.set("UserThread"); + try { + runnable.run(); + } finally { + THREAD_LOCAL.remove(); + } + }); + + return Stream.of( + TestDefinitions.testInvocation( + ServiceDefinition.of( + "SerdeThreadTrampolining", + ServiceType.VIRTUAL_OBJECT, + List.of( + HandlerDefinition.of( + "run", + HandlerType.EXCLUSIVE, + Serde.VOID, + TestSerdes.STRING, + HandlerRunner.of( + (ctx, unused) -> { + assertThreadLocal(); + String result = ((ObjectContext) ctx).get(STATE).get(); + assertThreadLocal(); + return result; + }, + new JacksonSerdeFactory(), + new HandlerRunner.Options().setExecutor(wrappedExecutor))))), + "run") + .withInput( + startMessage(2), inputCmd("Francesco"), getEagerStateCmd(STATE.name(), "Value")) + .expectingOutput(outputCmd("Value"), END_MESSAGE)); + } +}