diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt index a74a86a7..09ae5cd3 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt @@ -125,7 +125,7 @@ internal constructor( ) .await() - object : BaseInvocationHandle(handlerContext, responseSerde) { + object : BaseInvocationHandle(this, responseSerde) { override suspend fun invocationId(): String = invocationIdAsyncResult.poll().await() } } @@ -136,7 +136,7 @@ internal constructor( responseTypeTag: TypeTag, ): InvocationHandle = resolveSerde(responseTypeTag).let { responseSerde -> - object : BaseInvocationHandle(handlerContext, responseSerde) { + object : BaseInvocationHandle(this, responseSerde) { override suspend fun invocationId(): String = invocationId } } @@ -200,6 +200,14 @@ internal constructor( return AwakeableHandleImpl(this, id) } + override suspend fun signal(name: String, typeTag: TypeTag): DurableFuture { + checkNotInsideRun() + val serde: Serde = resolveSerde(typeTag) + return SingleDurableFutureImpl(handlerContext.signal(name).await()).simpleMap { + serde.deserialize(it) + } + } + override fun random(): RestateRandom { return this.random } diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt index 45b8830f..020f3855 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt @@ -201,6 +201,21 @@ sealed interface Context { */ fun awakeableHandle(id: String): AwakeableHandle + /** + * Create a [DurableFuture] waiting on a named signal targeting the current invocation. + * + * Signals are identified by `(invocationId, name)`. The resolution can arrive before or after the + * handler starts waiting on the signal — there's no need to pre-register. + * + * Another invocation can resolve or reject the signal using [signalHandle]. + * + * @param name the signal name. + * @param typeTag the response type tag to use for deserializing the signal result. + * @return a [DurableFuture] that resolves to the signal value (or rejects with a + * [dev.restate.sdk.common.TerminalException]). + */ + suspend fun signal(name: String, typeTag: TypeTag): DurableFuture + /** * Create a [RestateRandom] instance inherently predictable, seeded on the * [dev.restate.sdk.common.InvocationId], which is not secret. @@ -336,6 +351,15 @@ suspend inline fun Context.awakeable(): Awakeable { return this.awakeable(typeTag()) } +/** + * Create a [DurableFuture] waiting on a named signal targeting the current invocation. + * + * @see Context.signal + */ +suspend inline fun Context.signal(name: String): DurableFuture { + return this.signal(name, typeTag()) +} + /** * This interface can be used only within shared handlers of virtual objects. It extends [Context] * adding access to the virtual object instance key-value state storage. @@ -629,6 +653,14 @@ sealed interface InvocationHandle { /** @return the output of this invocation, if present. */ suspend fun output(): Output + + /** + * Get a [SignalHandle] for resolving or rejecting a named signal on this invocation. The + * receiving handler can await on the signal using [Context.signal]. + * + * @param name the signal name. + */ + suspend fun signal(name: String): SignalHandle } /** @@ -677,6 +709,35 @@ suspend inline fun AwakeableHandle.resolve(payload: T) { return this.resolve(typeTag(), payload) } +/** + * Handle to resolve or reject a named signal on a target invocation. + * + * Unlike awakeables, signals are identified by `(invocationId, name)` and do not need to be + * pre-registered: the resolution can arrive before or after the handler starts waiting. + */ +sealed interface SignalHandle { + /** + * Resolve the signal with the given value. + * + * @param typeTag used to serialize the result payload. + * @param payload the result payload. + */ + suspend fun resolve(typeTag: TypeTag, payload: T) + + /** + * Reject the signal with the given reason. The handler awaiting the signal will receive a + * terminal error with [reason] as the message. + * + * @param reason the rejection reason. + */ + suspend fun reject(reason: String) +} + +/** Resolve the signal with the given value. */ +suspend inline fun SignalHandle.resolve(payload: T) { + return this.resolve(typeTag(), payload) +} + /** * A [DurablePromise] is a durable, distributed version of a Kotlin's Deferred, or more commonly of * a future/promise. Restate keeps track of the [DurablePromise] across restarts/failures. @@ -965,6 +1026,17 @@ suspend fun awakeableHandle(id: String): AwakeableHandle { return context().awakeableHandle(id) } +/** + * Create a [DurableFuture] waiting on a named signal targeting the current invocation. + * + * @throws IllegalStateException if called outside of a Restate handler + * @see Context.signal + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend inline fun signal(name: String): DurableFuture { + return context().signal(name, typeTag()) +} + /** * Get an [InvocationHandle] for an already existing invocation. * diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/futures.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/futures.kt index 2d05e48c..32496189 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/futures.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/futures.kt @@ -191,9 +191,12 @@ internal constructor( internal abstract class BaseInvocationHandle internal constructor( - private val handlerContext: HandlerContext, + private val contextImpl: ContextImpl, private val responseSerde: Serde, ) : InvocationHandle { + private val handlerContext: HandlerContext + get() = contextImpl.handlerContext + override suspend fun cancel() { checkNotInsideRun() val ignored = handlerContext.cancelInvocation(invocationId()).await() @@ -214,6 +217,11 @@ internal constructor( .simpleMap { it.map { responseSerde.deserialize(it) } } .await() } + + override suspend fun signal(name: String): SignalHandle { + val resolvedId = invocationId() + return SignalHandleImpl(contextImpl, resolvedId, name) + } } internal class AwakeableImpl @@ -237,6 +245,24 @@ internal class AwakeableHandleImpl(val contextImpl: ContextImpl, val id: String) } } +internal class SignalHandleImpl( + val contextImpl: ContextImpl, + val invocationId: String, + val name: String, +) : SignalHandle { + override suspend fun resolve(typeTag: TypeTag, payload: T) { + checkNotInsideRun() + contextImpl.handlerContext + .resolveSignal(invocationId, name, contextImpl.resolveAndSerialize(typeTag, payload)) + .await() + } + + override suspend fun reject(reason: String) { + checkNotInsideRun() + contextImpl.handlerContext.rejectSignal(invocationId, name, TerminalException(reason)).await() + } +} + internal class SelectClauseImpl(override val durableFuture: DurableFuture) : SelectClause @PublishedApi diff --git a/sdk-api/src/main/java/dev/restate/sdk/Context.java b/sdk-api/src/main/java/dev/restate/sdk/Context.java index a31338e3..1ff5fc1b 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/Context.java +++ b/sdk-api/src/main/java/dev/restate/sdk/Context.java @@ -478,6 +478,36 @@ default Awakeable awakeable(Class clazz) { */ AwakeableHandle awakeableHandle(String id); + /** + * Create a {@link DurableFuture} waiting on a named signal targeting the current invocation. + * + *

Signals are identified by {@code (invocationId, name)}. The resolution can arrive before or + * after the handler starts waiting on the signal — there's no need to pre-register. + * + *

Another invocation can resolve or reject the signal using {@link + * SignalHandle#resolve(TypeTag, Object)} / {@link SignalHandle#reject(String)}. + * + * @param name the signal name. + * @param clazz the response type to use for deserializing the signal result. When using generic + * types, use {@link #signal(String, TypeTag)} instead. + * @return a {@link DurableFuture} that resolves to the signal value (or rejects with a {@link + * TerminalException}). + */ + default DurableFuture signal(String name, Class clazz) { + return signal(name, TypeTag.of(clazz)); + } + + /** + * Create a {@link DurableFuture} waiting on a named signal targeting the current invocation. + * + * @param name the signal name. + * @param typeTag the response type tag to use for deserializing the signal result. + * @return a {@link DurableFuture} that resolves to the signal value (or rejects with a {@link + * TerminalException}). + * @see #signal(String, Class) + */ + DurableFuture signal(String name, TypeTag typeTag); + /** * Returns a deterministic random. * diff --git a/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java b/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java index 9d1b5df5..9a0259af 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java +++ b/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java @@ -82,7 +82,7 @@ public Optional get(StateKey key) { checkNotInsideRun(); return DurableFuture.fromAsyncResult( Util.awaitCompletableFuture(handlerContext.get(key.name())), serviceExecutor) - .mapWithoutExecutor(opt -> opt.map(serdeFactory.create(key.serdeInfo())::deserialize)) + .map(opt -> opt.map(serdeFactory.create(key.serdeInfo())::deserialize)) .await(); } @@ -227,6 +227,30 @@ public Output getOutput() { serviceExecutor) .await(); } + + @Override + public SignalHandle signal(String name) { + String invocationId = invocationId(); + return new SignalHandle() { + @Override + public void resolve(TypeTag typeTag, T payload) { + checkNotInsideRun(); + Util.awaitCompletableFuture( + handlerContext.resolveSignal( + invocationId, + name, + Util.executeOrFail( + handlerContext, serdeFactory.create(typeTag)::serialize, payload))); + } + + @Override + public void reject(String reason) { + checkNotInsideRun(); + Util.awaitCompletableFuture( + handlerContext.rejectSignal(invocationId, name, new TerminalException(reason))); + } + }; + } } @Override @@ -249,7 +273,7 @@ public DurableFuture runAsync( return DurableFuture.fromAsyncResult( Util.awaitCompletableFuture(handlerContext.submitRun(name, runClosure)), serviceExecutor) - .mapWithoutExecutor(serde::deserialize); + .map(serde::deserialize); } private void executeRunAction( @@ -325,6 +349,14 @@ public void reject(String reason) { }; } + @Override + public DurableFuture signal(String name, TypeTag typeTag) throws TerminalException { + checkNotInsideRun(); + Serde serde = serdeFactory.create(typeTag); + AsyncResult result = Util.awaitCompletableFuture(handlerContext.signal(name)); + return DurableFuture.fromAsyncResult(result, serviceExecutor).map(serde::deserialize); + } + @Override public RestateRandom random() { return this.random; @@ -338,7 +370,7 @@ public DurableFuture future() { checkNotInsideRun(); AsyncResult result = Util.awaitCompletableFuture(handlerContext.promise(key.name())); return DurableFuture.fromAsyncResult(result, serviceExecutor) - .mapWithoutExecutor(serdeFactory.create(key.serdeInfo())::deserialize); + .map(serdeFactory.create(key.serdeInfo())::deserialize); } @Override diff --git a/sdk-api/src/main/java/dev/restate/sdk/InvocationHandle.java b/sdk-api/src/main/java/dev/restate/sdk/InvocationHandle.java index 4bb4636f..4a68d0a9 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/InvocationHandle.java +++ b/sdk-api/src/main/java/dev/restate/sdk/InvocationHandle.java @@ -31,4 +31,12 @@ public interface InvocationHandle { * @return the output of this invocation, if present. */ Output getOutput(); + + /** + * Get a {@link SignalHandle} for resolving or rejecting a named signal on this invocation. The + * receiving handler can await on the signal using {@link Context#signal(String, Class)}. + * + * @param name the signal name. + */ + SignalHandle signal(String name); } diff --git a/sdk-api/src/main/java/dev/restate/sdk/SignalHandle.java b/sdk-api/src/main/java/dev/restate/sdk/SignalHandle.java new file mode 100644 index 00000000..fb588672 --- /dev/null +++ b/sdk-api/src/main/java/dev/restate/sdk/SignalHandle.java @@ -0,0 +1,48 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk; + +import dev.restate.serde.TypeTag; + +/** + * Handle to resolve or reject a named signal on a target invocation. Acquired via {@link + * InvocationHandle#signal(String)}. + * + *

Unlike awakeables, signals are identified by {@code (invocationId, name)} and do not need to + * be pre-registered: the resolution can arrive before or after the handler starts waiting on the + * signal. + */ +public interface SignalHandle { + + /** + * Resolve the signal with the given value. + * + * @param typeTag used to serialize the result payload. + * @param payload the result payload. MUST NOT be null. + */ + void resolve(TypeTag typeTag, T payload); + + /** + * Resolve the signal with the given value. + * + * @param clazz used to serialize the result payload. + * @param payload the result payload. MUST NOT be null. + */ + default void resolve(Class clazz, T payload) { + resolve(TypeTag.of(clazz), payload); + } + + /** + * Reject the signal with the given reason. The handler awaiting the signal will receive a + * terminal error with {@code reason} as the message. + * + * @param reason the rejection reason. MUST NOT be null. + */ + void reject(String reason); +} diff --git a/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerContext.java b/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerContext.java index 8fd10b5c..61443a91 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerContext.java +++ b/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerContext.java @@ -106,6 +106,17 @@ record Awakeable(String id, AsyncResult asyncResult) {} CompletableFuture> rejectPromise(String key, TerminalException reason); + // ----- Named signals + // + // Signals are identified by (invocationId, name). Unlike awakeables, signals do not need to be + // pre-registered: the resolution can arrive before or after the handler starts waiting. + + CompletableFuture> signal(String name); + + CompletableFuture resolveSignal(String invocationId, String name, Slice payload); + + CompletableFuture rejectSignal(String invocationId, String name, TerminalException reason); + CompletableFuture cancelInvocation(String invocationId); CompletableFuture> attachInvocation(String invocationId); diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/AsyncResults.java b/sdk-core/src/main/java/dev/restate/sdk/core/AsyncResults.java index 4f128cca..a24db3b2 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/AsyncResults.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/AsyncResults.java @@ -190,24 +190,23 @@ private static CompletableFuture compose( try { failureMapper .apply((TerminalException) throwable) - .whenCompleteAsync( + .whenComplete( (u, mapperT) -> { if (ExceptionUtils.isTerminalException(mapperT)) { downstreamFuture.completeExceptionally(mapperT); } else if (mapperT != null) { - ctx.failWithoutContextSwitch(mapperT); + ctx.fail(ExceptionUtils.unwrapCompletionException(mapperT)); downstreamFuture.completeExceptionally( AbortedExecutionException.INSTANCE); } else { downstreamFuture.complete(u); } - }, - ctx.stateMachineExecutor()); + }); } catch (Throwable mapperT) { if (ExceptionUtils.isTerminalException(mapperT)) { downstreamFuture.completeExceptionally(mapperT); } else { - ctx.failWithoutContextSwitch(mapperT); + ctx.fail(ExceptionUtils.unwrapCompletionException(mapperT)); downstreamFuture.completeExceptionally(AbortedExecutionException.INSTANCE); } } @@ -223,24 +222,23 @@ private static CompletableFuture compose( try { successMapper .apply(t) - .whenCompleteAsync( + .whenComplete( (u, mapperT) -> { if (ExceptionUtils.isTerminalException(mapperT)) { downstreamFuture.completeExceptionally(mapperT); } else if (mapperT != null) { - ctx.failWithoutContextSwitch(mapperT); + ctx.fail(ExceptionUtils.unwrapCompletionException(mapperT)); downstreamFuture.completeExceptionally( AbortedExecutionException.INSTANCE); } else { downstreamFuture.complete(u); } - }, - ctx.stateMachineExecutor()); + }); } catch (Throwable mapperT) { if (ExceptionUtils.isTerminalException(mapperT)) { downstreamFuture.completeExceptionally(mapperT); } else { - ctx.failWithoutContextSwitch(mapperT); + ctx.fail(ExceptionUtils.unwrapCompletionException(mapperT)); downstreamFuture.completeExceptionally(AbortedExecutionException.INSTANCE); } } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/ExceptionUtils.java b/sdk-core/src/main/java/dev/restate/sdk/core/ExceptionUtils.java index 92a3593b..4bcdd756 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/ExceptionUtils.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/ExceptionUtils.java @@ -11,11 +11,29 @@ import dev.restate.sdk.common.AbortedExecutionException; import dev.restate.sdk.common.TerminalException; import java.util.Optional; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ExecutionException; import java.util.function.Predicate; public final class ExceptionUtils { private ExceptionUtils() {} + /** + * Unwrap the {@link CompletionException}/{@link ExecutionException} wrappers introduced by the + * {@link java.util.concurrent.CompletableFuture} machinery, returning the underlying cause. The + * reported error message and stacktrace should reflect the user-thrown exception, not the + * executor plumbing. + */ + public static Throwable unwrapCompletionException(Throwable throwable) { + Throwable current = throwable; + while ((current instanceof CompletionException || current instanceof ExecutionException) + && current.getCause() != null + && current.getCause() != current) { + current = current.getCause(); + } + return current; + } + @SuppressWarnings("unchecked") public static void sneakyThrow(Throwable e) throws E { throw (E) e; @@ -53,7 +71,7 @@ public static Optional findProtocolException(Throwable throwa return findCause(throwable, t -> t instanceof ProtocolException); } - public static boolean containsSuspendedException(Throwable throwable) { + public static boolean containsAbortedExecutionException(Throwable throwable) { return findCause(throwable, t -> t == AbortedExecutionException.INSTANCE).isPresent(); } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java b/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java index a2938b55..32a407a4 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java @@ -27,7 +27,6 @@ import java.time.Instant; import java.util.*; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.Executor; import java.util.function.Consumer; import java.util.stream.Stream; import org.apache.logging.log4j.LogManager; @@ -154,11 +153,6 @@ public InvocationState getInvocationState() { return this.stateMachine.state(); } - @Override - public Executor stateMachineExecutor() { - return Runnable::run; - } - @Override public CompletableFuture>> get(String name) { return catchExceptions( @@ -346,6 +340,28 @@ public CompletableFuture> rejectPromise(String key, TerminalEx HandlerContextImpl::parseEmptyOrFailure)); } + @Override + public CompletableFuture> signal(String name) { + return catchExceptions( + () -> + AsyncResults.single( + this, + this.stateMachine.createSignalHandle(name), + HandlerContextImpl::parseSuccessOrFailure)); + } + + @Override + public CompletableFuture resolveSignal(String invocationId, String name, Slice payload) { + return this.catchExceptions( + () -> this.stateMachine.completeSignal(invocationId, name, payload)); + } + + @Override + public CompletableFuture rejectSignal( + String invocationId, String name, TerminalException reason) { + return this.catchExceptions(() -> this.stateMachine.completeSignal(invocationId, name, reason)); + } + @Override public CompletableFuture cancelInvocation(String invocationId) { return this.catchExceptions(() -> this.stateMachine.cancelInvocation(invocationId)); diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextInternal.java b/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextInternal.java index b89b18c9..d6a9bfe8 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextInternal.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextInternal.java @@ -16,7 +16,6 @@ import dev.restate.sdk.endpoint.definition.HandlerContext; import java.time.Duration; import java.util.List; -import java.util.concurrent.Executor; import java.util.stream.Collectors; import org.jspecify.annotations.Nullable; @@ -59,7 +58,5 @@ void proposeRunFailure( InvocationState getInvocationState(); - Executor stateMachineExecutor(); - void failWithoutContextSwitch(Throwable throwable); } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/RequestProcessorImpl.java b/sdk-core/src/main/java/dev/restate/sdk/core/RequestProcessorImpl.java index 15f54089..a35a6047 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/RequestProcessorImpl.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/RequestProcessorImpl.java @@ -200,7 +200,7 @@ private CompletableFuture writeOutputAndEnd( private CompletableFuture end( HandlerContextInternal contextInternal, @Nullable Throwable exception) { - if (exception == null || ExceptionUtils.containsSuspendedException(exception)) { + if (exception == null || ExceptionUtils.containsAbortedExecutionException(exception)) { contextInternal.close(); } else if (contextInternal.getInvocationState() != InvocationState.CLOSED) { if (ExceptionUtils.isTerminalException(exception)) { diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/MockRequestResponse.java b/sdk-core/src/test/java/dev/restate/sdk/core/MockRequestResponse.java index a5488af1..250a0a58 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/MockRequestResponse.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/MockRequestResponse.java @@ -41,7 +41,14 @@ public boolean buffered() { @Override public void executeTest(TestDefinition definition) { - Executor syscallsExecutor = Executors.newSingleThreadExecutor(); + Executor syscallsExecutor = + Executors.newSingleThreadExecutor( + runnable -> { + Thread t = new Thread(runnable, "coreExecutor"); + if (t.isDaemon()) t.setDaemon(false); + if (t.getPriority() != Thread.NORM_PRIORITY) t.setPriority(Thread.NORM_PRIORITY); + return t; + }); ServiceDefinition serviceDefinition = definition.getServiceDefinition(); diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/JavaAPITests.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/JavaAPITests.java index 1d6ef41c..f326f70d 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/JavaAPITests.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/JavaAPITests.java @@ -51,7 +51,8 @@ public Stream definitions() { new UserFailuresTest(), new RandomTest(), new CodegenTest(), - new ReflectionTest()); + new ReflectionTest(), + new SerdeThreadTrampoliningTest()); } public static TestInvocationBuilder testDefinitionForService( diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/SerdeThreadTrampoliningTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/SerdeThreadTrampoliningTest.java new file mode 100644 index 00000000..6e98ec1f --- /dev/null +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/SerdeThreadTrampoliningTest.java @@ -0,0 +1,97 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.javaapi; + +import static dev.restate.sdk.core.statemachine.ProtoUtils.*; +import static dev.restate.sdk.core.statemachine.ProtoUtils.END_MESSAGE; +import static org.assertj.core.api.Assertions.assertThat; + +import dev.restate.common.Slice; +import dev.restate.sdk.HandlerRunner; +import dev.restate.sdk.ObjectContext; +import dev.restate.sdk.common.StateKey; +import dev.restate.sdk.core.TestDefinitions; +import dev.restate.sdk.core.TestSerdes; +import dev.restate.sdk.endpoint.definition.HandlerDefinition; +import dev.restate.sdk.endpoint.definition.HandlerType; +import dev.restate.sdk.endpoint.definition.ServiceDefinition; +import dev.restate.sdk.endpoint.definition.ServiceType; +import dev.restate.serde.Serde; +import dev.restate.serde.jackson.JacksonSerdeFactory; +import java.util.List; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.stream.Stream; +import org.jspecify.annotations.NonNull; + +public class SerdeThreadTrampoliningTest implements TestDefinitions.TestSuite { + + private static final ThreadLocal THREAD_LOCAL = new ThreadLocal<>(); + + private static final StateKey STATE = + StateKey.of( + "STATE", + new Serde() { + @Override + public Slice serialize(String value) { + throw new IllegalStateException("Unexpected call to serialize"); + } + + @Override + public String deserialize(@NonNull Slice value) { + assertThreadLocal(); + return TestSerdes.STRING.deserialize(value); + } + }); + + private static void assertThreadLocal() { + assertThat(THREAD_LOCAL.get()).isEqualTo("UserThread"); + } + + @Override + public Stream definitions() { + Executor executor = Executors.newCachedThreadPool(); + Executor wrappedExecutor = + runnable -> + executor.execute( + () -> { + THREAD_LOCAL.set("UserThread"); + try { + runnable.run(); + } finally { + THREAD_LOCAL.remove(); + } + }); + + return Stream.of( + TestDefinitions.testInvocation( + ServiceDefinition.of( + "SerdeThreadTrampolining", + ServiceType.VIRTUAL_OBJECT, + List.of( + HandlerDefinition.of( + "run", + HandlerType.EXCLUSIVE, + Serde.VOID, + TestSerdes.STRING, + HandlerRunner.of( + (ctx, unused) -> { + assertThreadLocal(); + String result = ((ObjectContext) ctx).get(STATE).get(); + assertThreadLocal(); + return result; + }, + new JacksonSerdeFactory(), + new HandlerRunner.Options().setExecutor(wrappedExecutor))))), + "run") + .withInput( + startMessage(2), inputCmd("Francesco"), getEagerStateCmd(STATE.name(), "Value")) + .expectingOutput(outputCmd("Value"), END_MESSAGE)); + } +} diff --git a/sdk-fake-api/src/main/java/dev/restate/sdk/fake/FakeContext.java b/sdk-fake-api/src/main/java/dev/restate/sdk/fake/FakeContext.java index 9ca13e61..751a9155 100644 --- a/sdk-fake-api/src/main/java/dev/restate/sdk/fake/FakeContext.java +++ b/sdk-fake-api/src/main/java/dev/restate/sdk/fake/FakeContext.java @@ -105,6 +105,11 @@ public AwakeableHandle awakeableHandle(String s) { return inner.awakeableHandle(s); } + @Override + public DurableFuture signal(String name, TypeTag typeTag) { + return inner.signal(name, typeTag); + } + @Override public RestateRandom random() { return inner.random(); diff --git a/sdk-fake-api/src/main/java/dev/restate/sdk/fake/FakeHandlerContext.java b/sdk-fake-api/src/main/java/dev/restate/sdk/fake/FakeHandlerContext.java index 8459d10c..ee525e39 100644 --- a/sdk-fake-api/src/main/java/dev/restate/sdk/fake/FakeHandlerContext.java +++ b/sdk-fake-api/src/main/java/dev/restate/sdk/fake/FakeHandlerContext.java @@ -255,6 +255,25 @@ public CompletableFuture> rejectPromise(String s, TerminalExce "FakeHandlerContext doesn't currently support mocking this operation"); } + @Override + public CompletableFuture> signal(String name) { + throw new UnsupportedOperationException( + "FakeHandlerContext doesn't currently support mocking this operation"); + } + + @Override + public CompletableFuture resolveSignal(String invocationId, String name, Slice payload) { + throw new UnsupportedOperationException( + "FakeHandlerContext doesn't currently support mocking this operation"); + } + + @Override + public CompletableFuture rejectSignal( + String invocationId, String name, TerminalException reason) { + throw new UnsupportedOperationException( + "FakeHandlerContext doesn't currently support mocking this operation"); + } + @Override public CompletableFuture cancelInvocation(String s) { throw new UnsupportedOperationException(