diff --git a/java-runtime/src/main/scala/client/Fs2ClientCall.scala b/java-runtime/src/main/scala/client/Fs2ClientCall.scala index 280e60e5..6390602e 100644 --- a/java-runtime/src/main/scala/client/Fs2ClientCall.scala +++ b/java-runtime/src/main/scala/client/Fs2ClientCall.scala @@ -3,6 +3,7 @@ package java_runtime package client import cats.effect._ +import cats.effect.concurrent.{Deferred, Ref} import cats.implicits._ import io.grpc.{Metadata, _} import fs2._ @@ -10,7 +11,19 @@ import fs2._ final case class UnaryResult[A](value: Option[A], status: Option[GrpcStatus]) final case class GrpcStatus(status: Status, trailers: Metadata) -class Fs2ClientCall[F[_], Request, Response] private[client] (val call: ClientCall[Request, Response]) extends AnyVal { +class Fs2ClientCall[F[_], Request, Response] private[client] (val call: ClientCall[Request, Response], + val wakeOnReady: Ref[F, Option[Deferred[F, Unit]]]) { + def onReady()(implicit F: Sync[F]): F[Unit] = { + wakeOnReady + .modify({ + case None => (None, F.unit) + case Some(wake) => (None, wake.complete(())) + }) + .flatten + } + + private def isReady(implicit F: Sync[F]): F[Boolean] = + F.delay(call.isReady) private def cancel(message: Option[String], cause: Option[Throwable])(implicit F: Sync[F]): F[Unit] = F.delay(call.cancel(message.orNull, cause.orNull)) @@ -21,13 +34,26 @@ class Fs2ClientCall[F[_], Request, Response] private[client] (val call: ClientCa private def request(numMessages: Int)(implicit F: Sync[F]): F[Unit] = F.delay(call.request(numMessages)) - private def sendMessage(message: Request)(implicit F: Sync[F]): F[Unit] = + private def sendMessage(message: Request)(implicit F: Sync[F]): F[Unit] = { F.delay(call.sendMessage(message)) + } + + private def sendMessageOrDelay(message: Request)(implicit F: Concurrent[F]): F[Unit] = { + isReady.ifM( + sendMessage(message), { + Deferred[F, Unit].flatMap { wakeup => + wakeOnReady.set(wakeup.some) *> + isReady.ifM(sendMessage(message), wakeup.get *> sendMessage(message)) + } + } + ) + } private def start(listener: ClientCall.Listener[Response], metadata: Metadata)(implicit F: Sync[F]): F[Unit] = F.delay(call.start(listener, metadata)) - def startListener[A <: ClientCall.Listener[Response]](createListener: F[A], headers: Metadata)(implicit F: Sync[F]): F[A] = { + def startListener[A <: ClientCall.Listener[Response]](createListener: F[A], headers: Metadata)( + implicit F: Sync[F]): F[A] = { createListener.flatTap(start(_, headers)) <* request(1) } @@ -35,8 +61,8 @@ class Fs2ClientCall[F[_], Request, Response] private[client] (val call: ClientCa sendMessage(message) *> halfClose } - def sendStream(stream: Stream[F, Request])(implicit F: Sync[F]): Stream[F, Unit] = { - stream.evalMap(sendMessage) ++ Stream.eval(halfClose) + def sendStream(stream: Stream[F, Request])(implicit F: Concurrent[F]): Stream[F, Unit] = { + stream.evalMap(sendMessageOrDelay) ++ Stream.eval(halfClose) } def handleCallError( @@ -82,8 +108,13 @@ object Fs2ClientCall { channel: Channel, methodDescriptor: MethodDescriptor[Request, Response], callOptions: CallOptions)(implicit F: Sync[F]): F[Fs2ClientCall[F, Request, Response]] = - F.delay(new Fs2ClientCall(channel.newCall[Request, Response](methodDescriptor, callOptions))) + apply(channel.newCall[Request, Response](methodDescriptor, callOptions)) + def apply[Request, Response](call: ClientCall[Request, Response])( + implicit F: Sync[F]): F[Fs2ClientCall[F, Request, Response]] = + for { + wakeOnReady <- Ref[F].of(none[Deferred[F, Unit]]) + } yield new Fs2ClientCall(call, wakeOnReady) } def apply[F[_]]: PartiallyAppliedClientCall[F] = diff --git a/java-runtime/src/main/scala/server/Fs2ServerCall.scala b/java-runtime/src/main/scala/server/Fs2ServerCall.scala index b986b04e..846619af 100644 --- a/java-runtime/src/main/scala/server/Fs2ServerCall.scala +++ b/java-runtime/src/main/scala/server/Fs2ServerCall.scala @@ -3,10 +3,25 @@ package java_runtime package server import cats.effect._ +import cats.effect.concurrent.{Deferred, Ref} +import cats.implicits._ import io.grpc._ // TODO: Add attributes, compression, message compression. -private[server] class Fs2ServerCall[F[_], Request, Response](val call: ServerCall[Request, Response]) extends AnyVal { +private[server] class Fs2ServerCall[F[_], Request, Response](val call: ServerCall[Request, Response], + val wakeOnReady: Ref[F, Option[Deferred[F, Unit]]]) { + def onReady()(implicit F: Sync[F]): F[Unit] = { + wakeOnReady + .modify({ + case None => (None, F.unit) + case Some(wake) => (None, wake.complete(())) + }) + .flatten + } + + def isReady(implicit F: Sync[F]): F[Boolean] = + F.delay(call.isReady) + def sendHeaders(headers: Metadata)(implicit F: Sync[F]): F[Unit] = F.delay(call.sendHeaders(headers)) @@ -16,11 +31,24 @@ private[server] class Fs2ServerCall[F[_], Request, Response](val call: ServerCal def sendMessage(message: Response)(implicit F: Sync[F]): F[Unit] = F.delay(call.sendMessage(message)) + def sendMessageOrDelay(message: Response)(implicit F: Concurrent[F]): F[Unit] = + isReady.ifM( + sendMessage(message), { + Deferred[F, Unit].flatMap { wakeup => + wakeOnReady.set(wakeup.some) *> + isReady.ifM(sendMessage(message), wakeup.get *> sendMessage(message)) + } + } + ) + def request(numMessages: Int)(implicit F: Sync[F]): F[Unit] = F.delay(call.request(numMessages)) } private[server] object Fs2ServerCall { - def apply[F[_], Request, Response](call: ServerCall[Request, Response]): Fs2ServerCall[F, Request, Response] = - new Fs2ServerCall[F, Request, Response](call) + def apply[F[_], Request, Response](call: ServerCall[Request, Response])( + implicit F: Concurrent[F]): F[Fs2ServerCall[F, Request, Response]] = + for { + wakeOnReady <- Ref[F].of(none[Deferred[F, Unit]]) + } yield new Fs2ServerCall[F, Request, Response](call, wakeOnReady) } diff --git a/java-runtime/src/main/scala/server/Fs2StreamServerCallListener.scala b/java-runtime/src/main/scala/server/Fs2StreamServerCallListener.scala index 324dfbe1..ddaebc38 100644 --- a/java-runtime/src/main/scala/server/Fs2StreamServerCallListener.scala +++ b/java-runtime/src/main/scala/server/Fs2StreamServerCallListener.scala @@ -15,6 +15,9 @@ class Fs2StreamServerCallListener[F[_], Request, Response] private ( val call: Fs2ServerCall[F, Request, Response])(implicit F: Effect[F]) extends ServerCall.Listener[Request] with Fs2ServerCallListener[F, Stream[F, ?], Request, Response] { + override def onReady(): Unit = { + call.onReady().unsafeRun() + } override def onCancel(): Unit = { isCancelled.complete(()).unsafeRun() @@ -40,10 +43,8 @@ object Fs2StreamServerCallListener { for { inputQ <- Queue.unbounded[F, Option[Request]] isCancelled <- Deferred[F, Unit] - } yield - new Fs2StreamServerCallListener[F, Request, Response](inputQ, - isCancelled, - Fs2ServerCall[F, Request, Response](call)) + serverCall <- Fs2ServerCall[F, Request, Response](call) + } yield new Fs2StreamServerCallListener[F, Request, Response](inputQ, isCancelled, serverCall) } def apply[F[_]] = new PartialFs2StreamServerCallListener[F] diff --git a/java-runtime/src/main/scala/server/Fs2UnaryServerCallListener.scala b/java-runtime/src/main/scala/server/Fs2UnaryServerCallListener.scala index f403190d..aa772640 100644 --- a/java-runtime/src/main/scala/server/Fs2UnaryServerCallListener.scala +++ b/java-runtime/src/main/scala/server/Fs2UnaryServerCallListener.scala @@ -17,6 +17,10 @@ class Fs2UnaryServerCallListener[F[_], Request, Response] private ( import Fs2UnaryServerCallListener._ + override def onReady(): Unit = { + call.onReady().unsafeRun() + } + override def onCancel(): Unit = { isCancelled.complete(()).unsafeRun() } @@ -62,11 +66,8 @@ object Fs2UnaryServerCallListener { request <- Ref.of[F, Option[Request]](none) isComplete <- Deferred[F, Unit] isCancelled <- Deferred[F, Unit] - } yield - new Fs2UnaryServerCallListener[F, Request, Response](request, - isComplete, - isCancelled, - Fs2ServerCall[F, Request, Response](call)) + serverCall <- Fs2ServerCall[F, Request, Response](call) + } yield new Fs2UnaryServerCallListener[F, Request, Response](request, isComplete, isCancelled, serverCall) } def apply[F[_]] = new PartialFs2UnaryServerCallListener[F] diff --git a/java-runtime/src/test/scala/client/ClientSuite.scala b/java-runtime/src/test/scala/client/ClientSuite.scala index e337699f..58ebba60 100644 --- a/java-runtime/src/test/scala/client/ClientSuite.scala +++ b/java-runtime/src/test/scala/client/ClientSuite.scala @@ -20,7 +20,7 @@ object ClientSuite extends SimpleTestSuite { implicit val cs: ContextShift[IO] = IO.contextShift(ec) val dummy = new DummyClientCall() - val client = new Fs2ClientCall[IO, String, Int](dummy) + val client = Fs2ClientCall[IO](dummy).unsafeRunSync() val result = client.unaryToUnaryCall("hello", new Metadata()).unsafeToFuture() dummy.listener.get.onMessage(5) @@ -44,7 +44,7 @@ object ClientSuite extends SimpleTestSuite { implicit val timer: Timer[IO] = ec.timer val dummy = new DummyClientCall() - val client = new Fs2ClientCall[IO, String, Int](dummy) + val client = Fs2ClientCall[IO](dummy).unsafeRunSync() val result = client.unaryToUnaryCall("hello", new Metadata()).timeout(1.second).unsafeToFuture() ec.tick() @@ -68,7 +68,7 @@ object ClientSuite extends SimpleTestSuite { implicit val cs: ContextShift[IO] = IO.contextShift(ec) val dummy = new DummyClientCall() - val client = new Fs2ClientCall[IO, String, Int](dummy) + val client = Fs2ClientCall[IO](dummy).unsafeRunSync() val result = client.unaryToUnaryCall("hello", new Metadata()).unsafeToFuture() dummy.listener.get.onClose(Status.OK, new Metadata()) @@ -87,9 +87,8 @@ object ClientSuite extends SimpleTestSuite { implicit val ec: TestContext = TestContext() implicit val cs: ContextShift[IO] = IO.contextShift(ec) - val dummy = new DummyClientCall() - val client = new Fs2ClientCall[IO, String, Int](dummy) + val client = Fs2ClientCall[IO](dummy).unsafeRunSync() val result = client.unaryToUnaryCall("hello", new Metadata()).unsafeToFuture() dummy.listener.get.onMessage(5) @@ -113,9 +112,8 @@ object ClientSuite extends SimpleTestSuite { implicit val ec: TestContext = TestContext() implicit val cs: ContextShift[IO] = IO.contextShift(ec) - val dummy = new DummyClientCall() - val client = new Fs2ClientCall[IO, String, Int](dummy) + val client = Fs2ClientCall[IO](dummy).unsafeRunSync() val result = client .streamingToUnaryCall(Stream.emits(List("a", "b", "c")), new Metadata()) .unsafeToFuture() @@ -140,9 +138,8 @@ object ClientSuite extends SimpleTestSuite { implicit val ec: TestContext = TestContext() implicit val cs: ContextShift[IO] = IO.contextShift(ec) - val dummy = new DummyClientCall() - val client = new Fs2ClientCall[IO, String, Int](dummy) + val client = Fs2ClientCall[IO](dummy).unsafeRunSync() val result = client .streamingToUnaryCall(Stream.empty, new Metadata()) .unsafeToFuture() @@ -168,7 +165,7 @@ object ClientSuite extends SimpleTestSuite { implicit val cs: ContextShift[IO] = IO.contextShift(ec) val dummy = new DummyClientCall() - val client = new Fs2ClientCall[IO, String, Int](dummy) + val client = Fs2ClientCall[IO](dummy).unsafeRunSync() val result = client.unaryToStreamingCall("hello", new Metadata()).compile.toList.unsafeToFuture() dummy.listener.get.onMessage(1) @@ -194,7 +191,7 @@ object ClientSuite extends SimpleTestSuite { implicit val cs: ContextShift[IO] = IO.contextShift(ec) val dummy = new DummyClientCall() - val client = new Fs2ClientCall[IO, String, Int](dummy) + val client = Fs2ClientCall[IO](dummy).unsafeRunSync() val result = client .streamingToStreamingCall(Stream.emits(List("a", "b", "c", "d", "e")), new Metadata()) @@ -225,7 +222,7 @@ object ClientSuite extends SimpleTestSuite { implicit val timer: Timer[IO] = ec.timer val dummy = new DummyClientCall() - val client = new Fs2ClientCall[IO, String, Int](dummy) + val client = Fs2ClientCall[IO](dummy).unsafeRunSync() val result = client .streamingToStreamingCall(Stream.emits(List("a", "b", "c", "d", "e")), new Metadata()) @@ -255,7 +252,7 @@ object ClientSuite extends SimpleTestSuite { implicit val cs: ContextShift[IO] = IO.contextShift(ec) val dummy = new DummyClientCall() - val client = new Fs2ClientCall[IO, String, Int](dummy) + val client = Fs2ClientCall[IO](dummy).unsafeRunSync() val result = client .streamingToStreamingCall(Stream.emits(List("a", "b", "c", "d", "e")), new Metadata()) @@ -287,7 +284,7 @@ object ClientSuite extends SimpleTestSuite { } test("resource awaits termination of managed channel") { - implicit val ec: TestContext = TestContext() + implicit val ec: TestContext = TestContext() import implicits._ val result = ManagedChannelBuilder.forAddress("127.0.0.1", 0).resource[IO].use(IO.pure).unsafeToFuture()