diff --git a/java-runtime/src/main/scala/Readiness.scala b/java-runtime/src/main/scala/Readiness.scala new file mode 100644 index 00000000..9ad6bcd4 --- /dev/null +++ b/java-runtime/src/main/scala/Readiness.scala @@ -0,0 +1,50 @@ +package org.lyranthe.fs2_grpc +package java_runtime +package shared + +import cats.Applicative +import cats.effect.Concurrent +import cats.effect.concurrent.{Deferred, Ref} +import cats.implicits._ + +// Readiness implements respect for GRPC's backpressure on the sender +// - i.e. it delays sending into a channel if that channel +// is full. +private [java_runtime] class ReadinessImpl[F[_]]( + waiting: Ref[F, Option[Deferred[F, Unit]]] +)(implicit F: Concurrent[F]) extends Readiness[F] { + def signal: F[Unit] = { + waiting.getAndSet(None).flatMap { + case None => F.unit + case Some(wake) => wake.complete(()) + } + } + + def whenReady(isReady: F[Boolean], action: F[Unit]): F[Unit] = { + isReady.ifM(action, { + Deferred[F, Unit].flatMap { wakeup => + waiting.set(wakeup.some) *> + isReady.ifM(signal, F.unit) *> // trigger manually in case onReady was invoked before we installed wakeup + wakeup.get *> + action + } + }) + } +} + +private [java_runtime] trait Readiness[F[_]] { + def signal: F[Unit] + + def whenReady(isReady: F[Boolean], action: F[Unit]): F[Unit] +} + +private [java_runtime] object Readiness { + def apply[F[_]](implicit F: Concurrent[F]): F[Readiness[F]] = + Ref[F].of(Option.empty[Deferred[F, Unit]]).map(new ReadinessImpl(_)) + + def noop[F[_]](implicit F: Applicative[F]): Readiness[F] = new Readiness[F] { + override def signal: F[Unit] = F.unit + + override def whenReady(isReady: F[Boolean], action: F[Unit]): F[Unit] = action + } +} diff --git a/java-runtime/src/main/scala/client/Fs2ClientCall.scala b/java-runtime/src/main/scala/client/Fs2ClientCall.scala index 1650aedf..221b2fbd 100644 --- a/java-runtime/src/main/scala/client/Fs2ClientCall.scala +++ b/java-runtime/src/main/scala/client/Fs2ClientCall.scala @@ -4,8 +4,9 @@ package client import cats.effect._ import cats.implicits._ -import io.grpc.{Metadata, _} import fs2._ +import io.grpc._ +import org.lyranthe.fs2_grpc.java_runtime.shared.Readiness final case class UnaryResult[A](value: Option[A], status: Option[GrpcStatus]) final case class GrpcStatus(status: Status, trailers: Metadata) @@ -15,7 +16,6 @@ class Fs2ClientCall[F[_], Request, Response] private[client] ( errorAdapter: StatusRuntimeException => Option[Exception], prefetchN: Int ) { - private val ea: PartialFunction[Throwable, Throwable] = { case e: StatusRuntimeException => errorAdapter(e).getOrElse(e) } @@ -29,7 +29,13 @@ class Fs2ClientCall[F[_], Request, Response] private[client] ( private def request(numMessages: Int)(implicit F: Sync[F]): F[Unit] = F.delay(call.request(numMessages)) - private def sendMessage(message: Request)(implicit F: Sync[F]): F[Unit] = + private def isReady(implicit F: Sync[F]): F[Boolean] = F.delay(call.isReady) + + private def sendMessageWhenReady(readiness: Readiness[F])(implicit F: Concurrent[F]): Request => F[Unit] = { + message => readiness.whenReady(isReady, sendMessageImmediately(message)) + } + + private def sendMessageImmediately(message: Request)(implicit F: Sync[F]): F[Unit] = F.delay(call.sendMessage(message)) private def start(listener: ClientCall.Listener[Response], metadata: Metadata)(implicit F: Sync[F]): F[Unit] = @@ -40,10 +46,10 @@ class Fs2ClientCall[F[_], Request, Response] private[client] ( ): F[A] = createListener.flatTap(start(_, headers)) def sendSingleMessage(message: Request)(implicit F: Sync[F]): F[Unit] = - sendMessage(message) *> halfClose + sendMessageImmediately(message) *> halfClose - def sendStream(stream: Stream[F, Request])(implicit F: Sync[F]): Stream[F, Unit] = - stream.evalMap(sendMessage) ++ Stream.eval(halfClose) + private def sendStream(readiness: Readiness[F], stream: Stream[F, Request])(implicit F: Concurrent[F]): Stream[F, Unit] = + stream.evalMap(sendMessageWhenReady(readiness)) ++ Stream.eval(halfClose) private def handleExitCase( cancelComplete: Boolean @@ -54,32 +60,38 @@ class Fs2ClientCall[F[_], Request, Response] private[client] ( } def unaryToUnaryCall(message: Request, headers: Metadata)(implicit F: ConcurrentEffect[F]): F[Response] = - F.bracketCase(startListener(Fs2UnaryClientCallListener[F, Response], headers) <* request(1))(l => + F.bracketCase(startListener(Fs2UnaryClientCallListener[F, Response](F.unit), headers) <* request(1))(l => sendSingleMessage(message) *> l.getValue.adaptError(ea) )(handleExitCase(cancelComplete = false)) def streamingToUnaryCall(messages: Stream[F, Request], headers: Metadata)(implicit F: ConcurrentEffect[F] - ): F[Response] = - F.bracketCase(startListener(Fs2UnaryClientCallListener[F, Response], headers) <* request(1))(l => - Stream.eval(l.getValue.adaptError(ea)).concurrently(sendStream(messages)).compile.lastOrError - )(handleExitCase(cancelComplete = false)) + ): F[Response] = { + Readiness[F].flatMap { readiness => + F.bracketCase(startListener(Fs2UnaryClientCallListener[F, Response](readiness.signal), headers) <* request(1))(l => + Stream.eval(l.getValue.adaptError(ea)).concurrently(sendStream(readiness, messages)).compile.lastOrError + )(handleExitCase(cancelComplete = false)) + } + } def unaryToStreamingCall(message: Request, headers: Metadata)(implicit F: ConcurrentEffect[F]): Stream[F, Response] = Stream - .bracketCase(startListener(Fs2StreamClientCallListener[F, Response](request, prefetchN), headers) <* request(1))( + .bracketCase(startListener(Fs2StreamClientCallListener[F, Response](request, F.unit, prefetchN), headers) <* request(1))( handleExitCase(cancelComplete = true) ) .flatMap(Stream.eval_(sendSingleMessage(message)) ++ _.stream.adaptError(ea)) def streamingToStreamingCall(messages: Stream[F, Request], headers: Metadata)(implicit F: ConcurrentEffect[F] - ): Stream[F, Response] = - Stream - .bracketCase(startListener(Fs2StreamClientCallListener[F, Response](request, prefetchN), headers) <* request(1))( + ): Stream[F, Response] = { + Stream.eval(Readiness[F]).flatMap { readiness => + Stream.bracketCase(startListener( + Fs2StreamClientCallListener[F, Response](request, readiness.signal, prefetchN), headers + ) <* request(1))( handleExitCase(cancelComplete = true) - ) - .flatMap(_.stream.adaptError(ea).concurrently(sendStream(messages))) + ).flatMap(_.stream.adaptError(ea).concurrently(sendStream(readiness, messages))) + } + } } object Fs2ClientCall { diff --git a/java-runtime/src/main/scala/client/Fs2StreamClientCallListener.scala b/java-runtime/src/main/scala/client/Fs2StreamClientCallListener.scala index 7d3d73de..29ea8e54 100644 --- a/java-runtime/src/main/scala/client/Fs2StreamClientCallListener.scala +++ b/java-runtime/src/main/scala/client/Fs2StreamClientCallListener.scala @@ -8,8 +8,10 @@ import fs2.Stream import io.grpc.{ClientCall, Metadata, Status} private[client] class Fs2StreamClientCallListener[F[_]: Effect, Response]( - ingest: StreamIngest[F, Response] + ingest: StreamIngest[F, Response], + signalReadiness: F[Unit], ) extends ClientCall.Listener[Response] { + override def onReady(): Unit = signalReadiness.unsafeRun() override def onMessage(message: Response): Unit = ingest.onMessage(message).unsafeRun() @@ -25,10 +27,11 @@ private[client] object Fs2StreamClientCallListener { def apply[F[_], Response]( request: Int => F[Unit], + signalReadiness: F[Unit], prefetchN: Int )(implicit F: ConcurrentEffect[F]): F[Fs2StreamClientCallListener[F, Response]] = - StreamIngest[F, Response](request, prefetchN).map( - new Fs2StreamClientCallListener[F, Response](_) + StreamIngest[F, Response](request, prefetchN).map(streamIngest => + new Fs2StreamClientCallListener[F, Response](streamIngest, signalReadiness) ) } diff --git a/java-runtime/src/main/scala/client/Fs2UnaryClientCallListener.scala b/java-runtime/src/main/scala/client/Fs2UnaryClientCallListener.scala index de01883f..3269abf1 100644 --- a/java-runtime/src/main/scala/client/Fs2UnaryClientCallListener.scala +++ b/java-runtime/src/main/scala/client/Fs2UnaryClientCallListener.scala @@ -9,11 +9,14 @@ import io.grpc._ private[client] class Fs2UnaryClientCallListener[F[_], Response]( grpcStatus: Deferred[F, GrpcStatus], - value: Ref[F, Option[Response]] + value: Ref[F, Option[Response]], + signalReadiness: F[Unit] )(implicit F: Effect[F] ) extends ClientCall.Listener[Response] { + override def onReady(): Unit = signalReadiness.unsafeRun() + override def onClose(status: Status, trailers: Metadata): Unit = grpcStatus.complete(GrpcStatus(status, trailers)).unsafeRun() @@ -46,9 +49,9 @@ private[client] class Fs2UnaryClientCallListener[F[_], Response]( private[client] object Fs2UnaryClientCallListener { - def apply[F[_]: ConcurrentEffect, Response]: F[Fs2UnaryClientCallListener[F, Response]] = { + def apply[F[_]: ConcurrentEffect, Response](signalReadiness: F[Unit]): F[Fs2UnaryClientCallListener[F, Response]] = { (Deferred[F, GrpcStatus], Ref.of[F, Option[Response]](none)).mapN((response, value) => - new Fs2UnaryClientCallListener[F, Response](response, value) + new Fs2UnaryClientCallListener[F, Response](response, value, signalReadiness) ) } diff --git a/java-runtime/src/main/scala/client/StreamIngest.scala b/java-runtime/src/main/scala/client/StreamIngest.scala index 524f4964..4bd633ee 100644 --- a/java-runtime/src/main/scala/client/StreamIngest.scala +++ b/java-runtime/src/main/scala/client/StreamIngest.scala @@ -1,19 +1,19 @@ package org.lyranthe.fs2_grpc package java_runtime -package client +package client // TODO shared? import cats.syntax.all._ import cats.effect._ import fs2.Stream import fs2.concurrent.InspectableQueue -private[client] trait StreamIngest[F[_], T] { +private[java_runtime] trait StreamIngest[F[_], T] { def onMessage(msg: T): F[Unit] def onClose(status: GrpcStatus): F[Unit] def messages: Stream[F, T] } -private[client] object StreamIngest { +private[java_runtime] object StreamIngest { def apply[F[_]: ConcurrentEffect, T]( request: Int => F[Unit], diff --git a/java-runtime/src/main/scala/server/Fs2ServerCall.scala b/java-runtime/src/main/scala/server/Fs2ServerCall.scala index 598e86bd..b9a67e20 100644 --- a/java-runtime/src/main/scala/server/Fs2ServerCall.scala +++ b/java-runtime/src/main/scala/server/Fs2ServerCall.scala @@ -5,6 +5,7 @@ package server import cats.effect._ import cats.syntax.functor._ import io.grpc._ +import org.lyranthe.fs2_grpc.java_runtime.shared.Readiness // TODO: Add attributes, compression, message compression. private[server] class Fs2ServerCall[F[_], Request, Response](val call: ServerCall[Request, Response]) extends AnyVal { @@ -14,7 +15,12 @@ private[server] class Fs2ServerCall[F[_], Request, Response](val call: ServerCal def closeStream(status: Status, trailers: Metadata)(implicit F: Sync[F]): F[Unit] = F.delay(call.close(status, trailers)) - def sendMessage(message: Response)(implicit F: Sync[F]): F[Unit] = + private def isReady(implicit F: Sync[F]): F[Boolean] = F.delay(call.isReady) + + def sendMessageWhenReady(readiness: Readiness[F])(implicit F: Sync[F]): Response => F[Unit] = + message => readiness.whenReady(isReady, sendMessageImmediately(message)) + + def sendMessageImmediately(message: Response)(implicit F: Sync[F]): F[Unit] = F.delay(call.sendMessage(message)) def request(numMessages: Int)(implicit F: Sync[F]): F[Unit] = diff --git a/java-runtime/src/main/scala/server/Fs2ServerCallHandler.scala b/java-runtime/src/main/scala/server/Fs2ServerCallHandler.scala index 9e394ebf..ac6a1e03 100644 --- a/java-runtime/src/main/scala/server/Fs2ServerCallHandler.scala +++ b/java-runtime/src/main/scala/server/Fs2ServerCallHandler.scala @@ -6,6 +6,7 @@ import cats.effect._ import cats.implicits._ import fs2._ import io.grpc._ +import org.lyranthe.fs2_grpc.java_runtime.shared.Readiness class Fs2ServerCallHandler[F[_]](val dummy: Boolean = false) extends AnyVal { @@ -15,7 +16,7 @@ class Fs2ServerCallHandler[F[_]](val dummy: Boolean = false) extends AnyVal { )(implicit F: ConcurrentEffect[F]): ServerCallHandler[Request, Response] = new ServerCallHandler[Request, Response] { def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = { - val listener = Fs2UnaryServerCallListener[F](call, options).unsafeRun() + val listener = Fs2UnaryServerCallListener[F](call, F.unit, options).unsafeRun() listener.unsafeUnaryResponse(headers, _ flatMap { request => implementation(request, headers) }) listener } @@ -27,12 +28,16 @@ class Fs2ServerCallHandler[F[_]](val dummy: Boolean = false) extends AnyVal { )(implicit F: ConcurrentEffect[F]): ServerCallHandler[Request, Response] = new ServerCallHandler[Request, Response] { def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = { - val listener = Fs2UnaryServerCallListener[F](call, options).unsafeRun() - listener.unsafeStreamResponse( - new Metadata(), - v => Stream.eval(v) flatMap { request => implementation(request, headers) } - ) - listener + Readiness[F].flatMap { readiness => + Fs2UnaryServerCallListener[F](call, readiness.signal, options).map { listener => + listener.unsafeStreamResponse( + readiness, + new Metadata(), + v => Stream.eval(v) flatMap { request => implementation(request, headers) } + ) + listener + } + }.unsafeRun() } } @@ -42,7 +47,7 @@ class Fs2ServerCallHandler[F[_]](val dummy: Boolean = false) extends AnyVal { )(implicit F: ConcurrentEffect[F]): ServerCallHandler[Request, Response] = new ServerCallHandler[Request, Response] { def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = { - val listener = Fs2StreamServerCallListener[F](call, options).unsafeRun() + val listener = Fs2StreamServerCallListener[F](call, F.unit, options).unsafeRun() listener.unsafeUnaryResponse(headers, implementation(_, headers)) listener } @@ -54,9 +59,13 @@ class Fs2ServerCallHandler[F[_]](val dummy: Boolean = false) extends AnyVal { )(implicit F: ConcurrentEffect[F]): ServerCallHandler[Request, Response] = new ServerCallHandler[Request, Response] { def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = { - val listener = Fs2StreamServerCallListener[F](call, options).unsafeRun() - listener.unsafeStreamResponse(headers, implementation(_, headers)) - listener + + Readiness[F].flatMap { readiness => + Fs2StreamServerCallListener[F](call, readiness.signal, options).map { listener=> + listener.unsafeStreamResponse(readiness, headers, implementation(_, headers)) + listener + } + }.unsafeRun() } } } diff --git a/java-runtime/src/main/scala/server/Fs2ServerCallListener.scala b/java-runtime/src/main/scala/server/Fs2ServerCallListener.scala index f4e3ce59..26918c27 100644 --- a/java-runtime/src/main/scala/server/Fs2ServerCallListener.scala +++ b/java-runtime/src/main/scala/server/Fs2ServerCallListener.scala @@ -7,6 +7,7 @@ import cats.effect.concurrent.Deferred import cats.implicits._ import fs2.Stream import io.grpc.{Metadata, Status, StatusException, StatusRuntimeException} +import org.lyranthe.fs2_grpc.java_runtime.shared.Readiness private[server] trait Fs2ServerCallListener[F[_], G[_], Request, Response] { def source: G[Request] @@ -26,10 +27,10 @@ private[server] trait Fs2ServerCallListener[F[_], G[_], Request, Response] { } private def handleUnaryResponse(headers: Metadata, response: F[Response])(implicit F: Sync[F]): F[Unit] = - call.sendHeaders(headers) *> call.request(1) *> response >>= call.sendMessage + call.sendHeaders(headers) *> call.request(1) *> response >>= call.sendMessageImmediately - private def handleStreamResponse(headers: Metadata, response: Stream[F, Response])(implicit F: Sync[F]): F[Unit] = - call.sendHeaders(headers) *> call.request(1) *> response.evalMap(call.sendMessage).compile.drain + private def handleStreamResponse(readiness: Readiness[F], headers: Metadata, response: Stream[F, Response])(implicit F: Sync[F]): F[Unit] = + call.sendHeaders(headers) *> call.request(1) *> response.evalMap(call.sendMessageWhenReady(readiness)).compile.drain private def unsafeRun(f: F[Unit])(implicit F: ConcurrentEffect[F]): Unit = { val bracketed = F.guaranteeCase(f) { @@ -47,8 +48,8 @@ private[server] trait Fs2ServerCallListener[F[_], G[_], Request, Response] { ): Unit = unsafeRun(handleUnaryResponse(headers, implementation(source))) - def unsafeStreamResponse(headers: Metadata, implementation: G[Request] => Stream[F, Response])(implicit + def unsafeStreamResponse(readiness: Readiness[F], headers: Metadata, implementation: G[Request] => Stream[F, Response])(implicit F: ConcurrentEffect[F] ): Unit = - unsafeRun(handleStreamResponse(headers, implementation(source))) + unsafeRun(handleStreamResponse(readiness, headers, implementation(source))) } diff --git a/java-runtime/src/main/scala/server/Fs2StreamServerCallListener.scala b/java-runtime/src/main/scala/server/Fs2StreamServerCallListener.scala index 50daf5ba..55f0c781 100644 --- a/java-runtime/src/main/scala/server/Fs2StreamServerCallListener.scala +++ b/java-runtime/src/main/scala/server/Fs2StreamServerCallListener.scala @@ -5,31 +5,33 @@ package server import cats.effect.concurrent.Deferred import cats.effect.{ConcurrentEffect, Effect} import cats.implicits._ -import io.grpc.ServerCall -import fs2.concurrent.Queue import fs2._ +import io.grpc.ServerCall +import org.lyranthe.fs2_grpc.java_runtime.client.StreamIngest class Fs2StreamServerCallListener[F[_], Request, Response] private ( - requestQ: Queue[F, Option[Request]], + ingest: StreamIngest[F, Option[Request]], + signalReadiness: F[Unit], val isCancelled: Deferred[F, Unit], val call: Fs2ServerCall[F, Request, Response] )(implicit F: Effect[F]) extends ServerCall.Listener[Request] with Fs2ServerCallListener[F, Stream[F, *], Request, Response] { + override def onReady(): Unit = signalReadiness.unsafeRun() + override def onCancel(): Unit = { isCancelled.complete(()).unsafeRun() } override def onMessage(message: Request): Unit = { - call.call.request(1) - requestQ.enqueue1(message.some).unsafeRun() + ingest.onMessage(Some(message)).unsafeRun() } - override def onHalfClose(): Unit = requestQ.enqueue1(none).unsafeRun() + override def onHalfClose(): Unit = ingest.onMessage(none).unsafeRun() override def source: Stream[F, Request] = - requestQ.dequeue.unNoneTerminate + ingest.messages.unNoneTerminate } object Fs2StreamServerCallListener { @@ -37,15 +39,17 @@ object Fs2StreamServerCallListener { def apply[Request, Response]( call: ServerCall[Request, Response], + signalReadiness: F[Unit], options: ServerCallOptions = ServerCallOptions.default )(implicit F: ConcurrentEffect[F] - ): F[Fs2StreamServerCallListener[F, Request, Response]] = + ): F[Fs2StreamServerCallListener[F, Request, Response]] = { for { - inputQ <- Queue.unbounded[F, Option[Request]] isCancelled <- Deferred[F, Unit] serverCall <- Fs2ServerCall[F, Request, Response](call, options) - } yield new Fs2StreamServerCallListener[F, Request, Response](inputQ, isCancelled, serverCall) + ingest <- StreamIngest[F, Option[Request]](serverCall.request, options.prefetchN) + } yield new Fs2StreamServerCallListener[F, Request, Response](ingest, signalReadiness, isCancelled, serverCall) + } } def apply[F[_]] = new PartialFs2StreamServerCallListener[F] diff --git a/java-runtime/src/main/scala/server/Fs2UnaryServerCallListener.scala b/java-runtime/src/main/scala/server/Fs2UnaryServerCallListener.scala index 965c24b5..98715dad 100644 --- a/java-runtime/src/main/scala/server/Fs2UnaryServerCallListener.scala +++ b/java-runtime/src/main/scala/server/Fs2UnaryServerCallListener.scala @@ -9,6 +9,7 @@ import io.grpc._ class Fs2UnaryServerCallListener[F[_], Request, Response] private ( request: Ref[F, Option[Request]], + signalReadiness: F[Unit], isComplete: Deferred[F, Unit], val isCancelled: Deferred[F, Unit], val call: Fs2ServerCall[F, Request, Response] @@ -18,6 +19,8 @@ class Fs2UnaryServerCallListener[F[_], Request, Response] private ( import Fs2UnaryServerCallListener._ + override def onReady(): Unit = signalReadiness.unsafeRun() + override def onCancel(): Unit = { isCancelled.complete(()).unsafeRun() } @@ -57,6 +60,7 @@ object Fs2UnaryServerCallListener { def apply[Request, Response]( call: ServerCall[Request, Response], + signalReadiness: F[Unit], options: ServerCallOptions = ServerCallOptions.default )(implicit F: ConcurrentEffect[F] @@ -66,7 +70,7 @@ object Fs2UnaryServerCallListener { isComplete <- Deferred[F, Unit] isCancelled <- Deferred[F, Unit] serverCall <- Fs2ServerCall[F, Request, Response](call, options) - } yield new Fs2UnaryServerCallListener[F, Request, Response](request, isComplete, isCancelled, serverCall) + } yield new Fs2UnaryServerCallListener[F, Request, Response](request, signalReadiness, isComplete, isCancelled, serverCall) } def apply[F[_]] = new PartialFs2UnaryServerCallListener[F] diff --git a/java-runtime/src/main/scala/server/models.scala b/java-runtime/src/main/scala/server/models.scala index 8ac2a33f..0ad7f157 100644 --- a/java-runtime/src/main/scala/server/models.scala +++ b/java-runtime/src/main/scala/server/models.scala @@ -5,14 +5,17 @@ package server sealed abstract class ServerCompressor(val name: String) extends Product with Serializable case object GzipCompressor extends ServerCompressor("gzip") -abstract class ServerCallOptions private (val compressor: Option[ServerCompressor]) { - def copy(compressor: Option[ServerCompressor] = this.compressor): ServerCallOptions = - new ServerCallOptions(compressor) {} +abstract class ServerCallOptions private (val compressor: Option[ServerCompressor], val prefetchN: Int) { + def copy(compressor: Option[ServerCompressor] = this.compressor, prefetchN: Int = this.prefetchN): ServerCallOptions = + new ServerCallOptions(compressor, prefetchN) {} def withServerCompressor(compressor: Option[ServerCompressor]): ServerCallOptions = - copy(compressor) + copy(compressor = compressor) + + def withPrefetchN(prefetchN: Int): ServerCallOptions = + copy(prefetchN = prefetchN) } object ServerCallOptions { - val default: ServerCallOptions = new ServerCallOptions(None) {} + val default: ServerCallOptions = new ServerCallOptions(None, 1) {} } diff --git a/java-runtime/src/test/scala/client/ClientSuite.scala b/java-runtime/src/test/scala/client/ClientSuite.scala index e5b2c478..308a530f 100644 --- a/java-runtime/src/test/scala/client/ClientSuite.scala +++ b/java-runtime/src/test/scala/client/ClientSuite.scala @@ -138,6 +138,39 @@ object ClientSuite extends SimpleTestSuite { assertEquals(dummy.requested, 1) } + test("stream to streamingToUnary - send respects readiness") { + implicit val ec: TestContext = TestContext() + implicit val cs: ContextShift[IO] = IO.contextShift(ec) + + val dummy = new DummyClientCall() + val client = fs2ClientCall(dummy) + val requests = Stream.emits(List("a", "b", "c", "d", "e")) + .unchunk + .map { value => + if (value == "c") dummy.setIsReady(false) + value + } + + val result = client + .streamingToUnaryCall(requests, new Metadata()) + .unsafeToFuture() + + ec.tick() + + assertEquals(dummy.messagesSent.size, 2) + assertEquals(result.value, None) + + dummy.setIsReady(true) + ec.tick() + + dummy.listener.get.onMessage(1) + dummy.listener.get.onClose(Status.OK, new Metadata()) + ec.tick() + + assertEquals(result.value, Some(Success(1))) + assertEquals(dummy.messagesSent.size, 5) + } + test("0-length to streamingToUnary") { implicit val ec: TestContext = TestContext() @@ -247,6 +280,41 @@ object ClientSuite extends SimpleTestSuite { assertEquals(dummy.requested, 2) } + test("stream to streamingToStreaming respects readiness") { + implicit val ec: TestContext = TestContext() + implicit val cs: ContextShift[IO] = IO.contextShift(ec) + + val dummy = new DummyClientCall() + val client = fs2ClientCall(dummy) + val requests = Stream.emits(List("a", "b", "c", "d", "e")) + .unchunk + .map { value => + if (value == "c") dummy.setIsReady(false) + value + } + + val result = client + .streamingToStreamingCall(requests, new Metadata()) + .compile + .toList + .unsafeToFuture() + + ec.tick() + + assertEquals(dummy.messagesSent.size, 2) + assertEquals(result.value, None) + + dummy.setIsReady(true) + ec.tick() + + assertEquals(dummy.messagesSent.size, 5) + + dummy.listener.get.onClose(Status.OK, new Metadata()) + ec.tick() + + assertEquals(result.value, Some(Success(Nil))) + } + test("cancellation for streamingToStreaming") { implicit val ec: TestContext = TestContext() diff --git a/java-runtime/src/test/scala/client/DummyClientCall.scala b/java-runtime/src/test/scala/client/DummyClientCall.scala index 9791ce48..6bd0972b 100644 --- a/java-runtime/src/test/scala/client/DummyClientCall.scala +++ b/java-runtime/src/test/scala/client/DummyClientCall.scala @@ -6,6 +6,7 @@ import scala.collection.mutable.ArrayBuffer import io.grpc._ class DummyClientCall extends ClientCall[String, Int] { + var ready = true var requested: Int = 0 val messagesSent: ArrayBuffer[String] = ArrayBuffer[String]() var listener: Option[ClientCall.Listener[Int]] = None @@ -25,4 +26,13 @@ class DummyClientCall extends ClientCall[String, Int] { messagesSent += message () } + + override def isReady: Boolean = ready + + def setIsReady(value: Boolean): Unit = { + ready = value + if (value) { + listener.foreach(_.onReady()) + } + } } diff --git a/java-runtime/src/test/scala/server/DummyServerCall.scala b/java-runtime/src/test/scala/server/DummyServerCall.scala index c30d1ea4..761b592c 100644 --- a/java-runtime/src/test/scala/server/DummyServerCall.scala +++ b/java-runtime/src/test/scala/server/DummyServerCall.scala @@ -7,8 +7,13 @@ import scala.collection.mutable.ArrayBuffer class DummyServerCall extends ServerCall[String, Int] { val messages: ArrayBuffer[Int] = ArrayBuffer[Int]() var currentStatus: Option[Status] = None + var requested: Int = 0 + private var ready = true + + override def request(numMessages: Int): Unit = { + requested += numMessages + } - override def request(numMessages: Int): Unit = () override def sendMessage(message: Int): Unit = { messages += message () @@ -21,4 +26,13 @@ class DummyServerCall extends ServerCall[String, Int] { currentStatus = Some(status) } override def isCancelled: Boolean = false + + override def isReady: Boolean = ready + + def setIsReady(value: Boolean, listener: ServerCall.Listener[_]): Unit = { + ready = value + if (ready) { + listener.onReady() + } + } } diff --git a/java-runtime/src/test/scala/server/ServerSuite.scala b/java-runtime/src/test/scala/server/ServerSuite.scala index 84b64c6a..654b2076 100644 --- a/java-runtime/src/test/scala/server/ServerSuite.scala +++ b/java-runtime/src/test/scala/server/ServerSuite.scala @@ -2,12 +2,14 @@ package org.lyranthe.fs2_grpc package java_runtime package server -import cats.effect.{ContextShift, IO} +import cats.effect.concurrent.Deferred import cats.effect.laws.util.TestContext +import cats.effect.{ContextShift, IO} import cats.implicits._ import fs2._ import io.grpc._ import minitest._ +import org.lyranthe.fs2_grpc.java_runtime.shared.Readiness object ServerSuite extends SimpleTestSuite { @@ -23,7 +25,7 @@ object ServerSuite extends SimpleTestSuite { implicit val cs: ContextShift[IO] = IO.contextShift(ec) val dummy = new DummyServerCall - val listener = Fs2UnaryServerCallListener[IO](dummy, options).unsafeRunSync() + val listener = Fs2UnaryServerCallListener[IO](dummy, IO.unit, options).unsafeRunSync() listener.unsafeUnaryResponse(new Metadata(), _.map(_.length)) listener.onMessage("123") @@ -43,7 +45,7 @@ object ServerSuite extends SimpleTestSuite { implicit val cs: ContextShift[IO] = IO.contextShift(ec) val dummy = new DummyServerCall - val listener = Fs2UnaryServerCallListener[IO](dummy).unsafeRunSync() + val listener = Fs2UnaryServerCallListener[IO](dummy, IO.unit).unsafeRunSync() listener.unsafeUnaryResponse(new Metadata(), _.map(_.length)) listener.onCancel() @@ -65,7 +67,7 @@ object ServerSuite extends SimpleTestSuite { implicit val cs: ContextShift[IO] = IO.contextShift(ec) val dummy = new DummyServerCall - val listener = Fs2UnaryServerCallListener[IO](dummy, options).unsafeRunSync() + val listener = Fs2UnaryServerCallListener[IO](dummy, IO.unit, options).unsafeRunSync() listener.unsafeUnaryResponse(new Metadata(), _.map(_.length)) listener.onMessage("123") @@ -103,9 +105,9 @@ object ServerSuite extends SimpleTestSuite { implicit val cs: ContextShift[IO] = IO.contextShift(ec) val dummy = new DummyServerCall - val listener = Fs2UnaryServerCallListener[IO].apply[String, Int](dummy, options).unsafeRunSync() + val listener = Fs2UnaryServerCallListener[IO].apply[String, Int](dummy, IO.unit, options).unsafeRunSync() - listener.unsafeStreamResponse(new Metadata(), s => Stream.eval(s).map(_.length).repeat.take(5)) + listener.unsafeStreamResponse(Readiness.noop, new Metadata(), s => Stream.eval(s).map(_.length).repeat.take(5)) listener.onMessage("123") listener.onHalfClose() @@ -123,9 +125,9 @@ object ServerSuite extends SimpleTestSuite { implicit val cs: ContextShift[IO] = IO.contextShift(ec) val dummy = new DummyServerCall - val listener = Fs2StreamServerCallListener[IO].apply[String, Int](dummy).unsafeRunSync() + val listener = Fs2StreamServerCallListener[IO].apply[String, Int](dummy, IO.unit).unsafeRunSync() - listener.unsafeStreamResponse(new Metadata(), _ => Stream.emit(3).repeat.take(5)) + listener.unsafeStreamResponse(Readiness.noop, new Metadata(), _ => Stream.emit(3).repeat.take(5)) listener.onHalfClose() ec.tick() @@ -142,9 +144,9 @@ object ServerSuite extends SimpleTestSuite { implicit val cs: ContextShift[IO] = IO.contextShift(ec) val dummy = new DummyServerCall - val listener = Fs2StreamServerCallListener[IO].apply[String, Int](dummy).unsafeRunSync() + val listener = Fs2StreamServerCallListener[IO].apply[String, Int](dummy, IO.unit).unsafeRunSync() - listener.unsafeStreamResponse(new Metadata(), _ => Stream.emit(3).repeat.take(5)) + listener.unsafeStreamResponse(Readiness.noop, new Metadata(), _ => Stream.emit(3).repeat.take(5)) listener.onCancel() @@ -165,9 +167,9 @@ object ServerSuite extends SimpleTestSuite { implicit val cs: ContextShift[IO] = IO.contextShift(ec) val dummy = new DummyServerCall - val listener = Fs2StreamServerCallListener[IO].apply[String, Int](dummy, options).unsafeRunSync() + val listener = Fs2StreamServerCallListener[IO].apply[String, Int](dummy, IO.unit, options).unsafeRunSync() - listener.unsafeStreamResponse(new Metadata(), _.map(_.length).intersperse(0)) + listener.unsafeStreamResponse(Readiness.noop, new Metadata(), _.map(_.length).intersperse(0)) listener.onMessage("a") listener.onMessage("ab") listener.onHalfClose() @@ -186,9 +188,10 @@ object ServerSuite extends SimpleTestSuite { implicit val cs: ContextShift[IO] = IO.contextShift(ec) val dummy = new DummyServerCall - val listener = Fs2StreamServerCallListener[IO].apply[String, Int](dummy).unsafeRunSync() + val listener = Fs2StreamServerCallListener[IO].apply[String, Int](dummy, IO.unit).unsafeRunSync() listener.unsafeStreamResponse( + Readiness.noop, new Metadata(), _.map(_.length) ++ Stream.emit(0) ++ Stream.raiseError[IO](new RuntimeException("hello")) ) @@ -205,6 +208,66 @@ object ServerSuite extends SimpleTestSuite { assertEquals(dummy.currentStatus.get.isOk, false) } + test("streamingToStreaming send respects isReady") { + implicit val ec: TestContext = TestContext() + implicit val cs: ContextShift[IO] = IO.contextShift(ec) + + val dummy = new DummyServerCall + val readiness = Readiness[IO].unsafeRunSync() + val listener = Fs2StreamServerCallListener[IO].apply[String, Int](dummy, readiness.signal).unsafeRunSync() + + listener.unsafeStreamResponse( + readiness, + new Metadata(), + requests => unreadyAfterTwoEmissions(dummy, listener).concurrently(requests) + ) + + ec.tick() + + assertEquals(dummy.messages.toList, List(1, 2)) + + dummy.setIsReady(true, listener) + ec.tick() + + assertEquals(dummy.messages.length, 5) + assertEquals(dummy.messages.toList, List(1, 2, 3, 4, 5)) + } + + test("unaryToStreaming send respects isReady") { + implicit val ec: TestContext = TestContext() + implicit val cs: ContextShift[IO] = IO.contextShift(ec) + + val dummy = new DummyServerCall + val readiness = Readiness[IO].unsafeRunSync() + val listener = Fs2UnaryServerCallListener[IO].apply[String, Int](dummy, readiness.signal).unsafeRunSync() + + listener.unsafeStreamResponse( + readiness, + new Metadata(), + _ => unreadyAfterTwoEmissions(dummy, listener) + ) + + listener.onMessage("a") + ec.tick() + + assertEquals(dummy.messages.toList, List(1, 2)) + + dummy.setIsReady(true, listener) + ec.tick() + + assertEquals(dummy.messages.length, 5) + assertEquals(dummy.messages.toList, List(1, 2, 3, 4, 5)) + } + + private def unreadyAfterTwoEmissions(dummy: DummyServerCall, listener: ServerCall.Listener[_]) = { + Stream.emits(List(1, 2, 3, 4, 5)) + .unchunk + .map { value => + if (value == 3) dummy.setIsReady(false, listener) + value + } + } + test("streaming to unary")(streamingToUnary()) test("streaming to unary with compression")(streamingToUnary(compressionOps)) @@ -218,7 +281,7 @@ object ServerSuite extends SimpleTestSuite { _.compile.foldMonoid.map(_.length) val dummy = new DummyServerCall - val listener = Fs2StreamServerCallListener[IO].apply[String, Int](dummy, options).unsafeRunSync() + val listener = Fs2StreamServerCallListener[IO].apply[String, Int](dummy, IO.unit, options).unsafeRunSync() listener.unsafeUnaryResponse(new Metadata(), implementation) listener.onMessage("ab") @@ -233,4 +296,47 @@ object ServerSuite extends SimpleTestSuite { assertEquals(dummy.currentStatus.get.isOk, true) } + test("streamingToUnary back pressure") { + implicit val ec: TestContext = TestContext() + implicit val cs: ContextShift[IO] = IO.contextShift(ec) + + val deferred = Deferred[IO, Unit].unsafeRunSync() + val implementation: Stream[IO, String] => IO[Int] = { + requests => requests.evalMap(_ => deferred.get).compile.drain.as(1) } + + val dummy = new DummyServerCall + val listener = Fs2StreamServerCallListener[IO].apply[String, Int](dummy, IO.unit).unsafeRunSync() + + listener.unsafeUnaryResponse(new Metadata(), implementation) + ec.tick() + + assertEquals(dummy.requested, 1) + + listener.onMessage("1") + ec.tick() + + listener.onMessage("2") + listener.onMessage("3") + ec.tick() + + // requested should ideally be 2, however StreamIngest can double-request in some execution + // orderings if the push() is followed by pop() before the push checks the queue length. + val initialRequested = dummy.requested + assert(initialRequested == 2 || initialRequested == 3, s"expected requested to be 2 or 3, got ${initialRequested}") + + // don't request any more messages while downstream is blocked + listener.onMessage("4") + listener.onMessage("5") + listener.onMessage("6") + ec.tick() + + assertEquals(dummy.requested - initialRequested, 0) + + // allow all messages through, the final pop() will trigger a new request + deferred.complete(()).unsafeToFuture() + ec.tick() + + assertEquals(dummy.requested - initialRequested, 1) + } + }