Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions java-runtime/src/main/scala/Readiness.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package org.lyranthe.fs2_grpc
package java_runtime
package shared

import cats.Applicative
import cats.effect.Concurrent
import cats.effect.concurrent.{Deferred, Ref}
import cats.implicits._

// Readiness implements respect for GRPC's backpressure on the sender
// - i.e. it delays sending into a channel if that channel
// is full.
private [java_runtime] class ReadinessImpl[F[_]](
waiting: Ref[F, Option[Deferred[F, Unit]]]
)(implicit F: Concurrent[F]) extends Readiness[F] {
def signal: F[Unit] = {
waiting.getAndSet(None).flatMap {
case None => F.unit
case Some(wake) => wake.complete(())
}
}

def whenReady(isReady: F[Boolean], action: F[Unit]): F[Unit] = {
isReady.ifM(action, {
Deferred[F, Unit].flatMap { wakeup =>
waiting.set(wakeup.some) *>
isReady.ifM(signal, F.unit) *> // trigger manually in case onReady was invoked before we installed wakeup
wakeup.get *>
action
}
})
}
}

private [java_runtime] trait Readiness[F[_]] {
def signal: F[Unit]

def whenReady(isReady: F[Boolean], action: F[Unit]): F[Unit]
}

private [java_runtime] object Readiness {
def apply[F[_]](implicit F: Concurrent[F]): F[Readiness[F]] =
Ref[F].of(Option.empty[Deferred[F, Unit]]).map(new ReadinessImpl(_))

def noop[F[_]](implicit F: Applicative[F]): Readiness[F] = new Readiness[F] {
override def signal: F[Unit] = F.unit

override def whenReady(isReady: F[Boolean], action: F[Unit]): F[Unit] = action
}
}
46 changes: 29 additions & 17 deletions java-runtime/src/main/scala/client/Fs2ClientCall.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ package client

import cats.effect._
import cats.implicits._
import io.grpc.{Metadata, _}
import fs2._
import io.grpc._
import org.lyranthe.fs2_grpc.java_runtime.shared.Readiness

final case class UnaryResult[A](value: Option[A], status: Option[GrpcStatus])
final case class GrpcStatus(status: Status, trailers: Metadata)
Expand All @@ -15,7 +16,6 @@ class Fs2ClientCall[F[_], Request, Response] private[client] (
errorAdapter: StatusRuntimeException => Option[Exception],
prefetchN: Int
) {

private val ea: PartialFunction[Throwable, Throwable] = { case e: StatusRuntimeException =>
errorAdapter(e).getOrElse(e)
}
Expand All @@ -29,7 +29,13 @@ class Fs2ClientCall[F[_], Request, Response] private[client] (
private def request(numMessages: Int)(implicit F: Sync[F]): F[Unit] =
F.delay(call.request(numMessages))

private def sendMessage(message: Request)(implicit F: Sync[F]): F[Unit] =
private def isReady(implicit F: Sync[F]): F[Boolean] = F.delay(call.isReady)

private def sendMessageWhenReady(readiness: Readiness[F])(implicit F: Concurrent[F]): Request => F[Unit] = {
message => readiness.whenReady(isReady, sendMessageImmediately(message))
}

private def sendMessageImmediately(message: Request)(implicit F: Sync[F]): F[Unit] =
F.delay(call.sendMessage(message))

private def start(listener: ClientCall.Listener[Response], metadata: Metadata)(implicit F: Sync[F]): F[Unit] =
Expand All @@ -40,10 +46,10 @@ class Fs2ClientCall[F[_], Request, Response] private[client] (
): F[A] = createListener.flatTap(start(_, headers))

def sendSingleMessage(message: Request)(implicit F: Sync[F]): F[Unit] =
sendMessage(message) *> halfClose
sendMessageImmediately(message) *> halfClose

def sendStream(stream: Stream[F, Request])(implicit F: Sync[F]): Stream[F, Unit] =
stream.evalMap(sendMessage) ++ Stream.eval(halfClose)
private def sendStream(readiness: Readiness[F], stream: Stream[F, Request])(implicit F: Concurrent[F]): Stream[F, Unit] =
stream.evalMap(sendMessageWhenReady(readiness)) ++ Stream.eval(halfClose)

private def handleExitCase(
cancelComplete: Boolean
Expand All @@ -54,32 +60,38 @@ class Fs2ClientCall[F[_], Request, Response] private[client] (
}

def unaryToUnaryCall(message: Request, headers: Metadata)(implicit F: ConcurrentEffect[F]): F[Response] =
F.bracketCase(startListener(Fs2UnaryClientCallListener[F, Response], headers) <* request(1))(l =>
F.bracketCase(startListener(Fs2UnaryClientCallListener[F, Response](F.unit), headers) <* request(1))(l =>
sendSingleMessage(message) *> l.getValue.adaptError(ea)
)(handleExitCase(cancelComplete = false))

def streamingToUnaryCall(messages: Stream[F, Request], headers: Metadata)(implicit
F: ConcurrentEffect[F]
): F[Response] =
F.bracketCase(startListener(Fs2UnaryClientCallListener[F, Response], headers) <* request(1))(l =>
Stream.eval(l.getValue.adaptError(ea)).concurrently(sendStream(messages)).compile.lastOrError
)(handleExitCase(cancelComplete = false))
): F[Response] = {
Readiness[F].flatMap { readiness =>
F.bracketCase(startListener(Fs2UnaryClientCallListener[F, Response](readiness.signal), headers) <* request(1))(l =>
Stream.eval(l.getValue.adaptError(ea)).concurrently(sendStream(readiness, messages)).compile.lastOrError
)(handleExitCase(cancelComplete = false))
}
}

def unaryToStreamingCall(message: Request, headers: Metadata)(implicit F: ConcurrentEffect[F]): Stream[F, Response] =
Stream
.bracketCase(startListener(Fs2StreamClientCallListener[F, Response](request, prefetchN), headers) <* request(1))(
.bracketCase(startListener(Fs2StreamClientCallListener[F, Response](request, F.unit, prefetchN), headers) <* request(1))(
handleExitCase(cancelComplete = true)
)
.flatMap(Stream.eval_(sendSingleMessage(message)) ++ _.stream.adaptError(ea))

def streamingToStreamingCall(messages: Stream[F, Request], headers: Metadata)(implicit
F: ConcurrentEffect[F]
): Stream[F, Response] =
Stream
.bracketCase(startListener(Fs2StreamClientCallListener[F, Response](request, prefetchN), headers) <* request(1))(
): Stream[F, Response] = {
Stream.eval(Readiness[F]).flatMap { readiness =>
Stream.bracketCase(startListener(
Fs2StreamClientCallListener[F, Response](request, readiness.signal, prefetchN), headers
) <* request(1))(
handleExitCase(cancelComplete = true)
)
.flatMap(_.stream.adaptError(ea).concurrently(sendStream(messages)))
).flatMap(_.stream.adaptError(ea).concurrently(sendStream(readiness, messages)))
}
}
}

object Fs2ClientCall {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ import fs2.Stream
import io.grpc.{ClientCall, Metadata, Status}

private[client] class Fs2StreamClientCallListener[F[_]: Effect, Response](
ingest: StreamIngest[F, Response]
ingest: StreamIngest[F, Response],
signalReadiness: F[Unit],
) extends ClientCall.Listener[Response] {
override def onReady(): Unit = signalReadiness.unsafeRun()

override def onMessage(message: Response): Unit =
ingest.onMessage(message).unsafeRun()
Expand All @@ -25,10 +27,11 @@ private[client] object Fs2StreamClientCallListener {

def apply[F[_], Response](
request: Int => F[Unit],
signalReadiness: F[Unit],
prefetchN: Int
)(implicit F: ConcurrentEffect[F]): F[Fs2StreamClientCallListener[F, Response]] =
StreamIngest[F, Response](request, prefetchN).map(
new Fs2StreamClientCallListener[F, Response](_)
StreamIngest[F, Response](request, prefetchN).map(streamIngest =>
new Fs2StreamClientCallListener[F, Response](streamIngest, signalReadiness)
)

}
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@ import io.grpc._

private[client] class Fs2UnaryClientCallListener[F[_], Response](
grpcStatus: Deferred[F, GrpcStatus],
value: Ref[F, Option[Response]]
value: Ref[F, Option[Response]],
signalReadiness: F[Unit]
)(implicit
F: Effect[F]
) extends ClientCall.Listener[Response] {

override def onReady(): Unit = signalReadiness.unsafeRun()

override def onClose(status: Status, trailers: Metadata): Unit =
grpcStatus.complete(GrpcStatus(status, trailers)).unsafeRun()

Expand Down Expand Up @@ -46,9 +49,9 @@ private[client] class Fs2UnaryClientCallListener[F[_], Response](

private[client] object Fs2UnaryClientCallListener {

def apply[F[_]: ConcurrentEffect, Response]: F[Fs2UnaryClientCallListener[F, Response]] = {
def apply[F[_]: ConcurrentEffect, Response](signalReadiness: F[Unit]): F[Fs2UnaryClientCallListener[F, Response]] = {
(Deferred[F, GrpcStatus], Ref.of[F, Option[Response]](none)).mapN((response, value) =>
new Fs2UnaryClientCallListener[F, Response](response, value)
new Fs2UnaryClientCallListener[F, Response](response, value, signalReadiness)
)
}

Expand Down
6 changes: 3 additions & 3 deletions java-runtime/src/main/scala/client/StreamIngest.scala
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
package org.lyranthe.fs2_grpc
package java_runtime
package client
package client // TODO shared?

import cats.syntax.all._
import cats.effect._
import fs2.Stream
import fs2.concurrent.InspectableQueue

private[client] trait StreamIngest[F[_], T] {
private[java_runtime] trait StreamIngest[F[_], T] {
def onMessage(msg: T): F[Unit]
def onClose(status: GrpcStatus): F[Unit]
def messages: Stream[F, T]
}

private[client] object StreamIngest {
private[java_runtime] object StreamIngest {

def apply[F[_]: ConcurrentEffect, T](
request: Int => F[Unit],
Expand Down
8 changes: 7 additions & 1 deletion java-runtime/src/main/scala/server/Fs2ServerCall.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package server
import cats.effect._
import cats.syntax.functor._
import io.grpc._
import org.lyranthe.fs2_grpc.java_runtime.shared.Readiness

// TODO: Add attributes, compression, message compression.
private[server] class Fs2ServerCall[F[_], Request, Response](val call: ServerCall[Request, Response]) extends AnyVal {
Expand All @@ -14,7 +15,12 @@ private[server] class Fs2ServerCall[F[_], Request, Response](val call: ServerCal
def closeStream(status: Status, trailers: Metadata)(implicit F: Sync[F]): F[Unit] =
F.delay(call.close(status, trailers))

def sendMessage(message: Response)(implicit F: Sync[F]): F[Unit] =
private def isReady(implicit F: Sync[F]): F[Boolean] = F.delay(call.isReady)

def sendMessageWhenReady(readiness: Readiness[F])(implicit F: Sync[F]): Response => F[Unit] =
message => readiness.whenReady(isReady, sendMessageImmediately(message))

def sendMessageImmediately(message: Response)(implicit F: Sync[F]): F[Unit] =
F.delay(call.sendMessage(message))

def request(numMessages: Int)(implicit F: Sync[F]): F[Unit] =
Expand Down
31 changes: 20 additions & 11 deletions java-runtime/src/main/scala/server/Fs2ServerCallHandler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import cats.effect._
import cats.implicits._
import fs2._
import io.grpc._
import org.lyranthe.fs2_grpc.java_runtime.shared.Readiness

class Fs2ServerCallHandler[F[_]](val dummy: Boolean = false) extends AnyVal {

Expand All @@ -15,7 +16,7 @@ class Fs2ServerCallHandler[F[_]](val dummy: Boolean = false) extends AnyVal {
)(implicit F: ConcurrentEffect[F]): ServerCallHandler[Request, Response] =
new ServerCallHandler[Request, Response] {
def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = {
val listener = Fs2UnaryServerCallListener[F](call, options).unsafeRun()
val listener = Fs2UnaryServerCallListener[F](call, F.unit, options).unsafeRun()
listener.unsafeUnaryResponse(headers, _ flatMap { request => implementation(request, headers) })
listener
}
Expand All @@ -27,12 +28,16 @@ class Fs2ServerCallHandler[F[_]](val dummy: Boolean = false) extends AnyVal {
)(implicit F: ConcurrentEffect[F]): ServerCallHandler[Request, Response] =
new ServerCallHandler[Request, Response] {
def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = {
val listener = Fs2UnaryServerCallListener[F](call, options).unsafeRun()
listener.unsafeStreamResponse(
new Metadata(),
v => Stream.eval(v) flatMap { request => implementation(request, headers) }
)
listener
Readiness[F].flatMap { readiness =>
Fs2UnaryServerCallListener[F](call, readiness.signal, options).map { listener =>
listener.unsafeStreamResponse(
readiness,
new Metadata(),
v => Stream.eval(v) flatMap { request => implementation(request, headers) }
)
listener
}
}.unsafeRun()
}
}

Expand All @@ -42,7 +47,7 @@ class Fs2ServerCallHandler[F[_]](val dummy: Boolean = false) extends AnyVal {
)(implicit F: ConcurrentEffect[F]): ServerCallHandler[Request, Response] =
new ServerCallHandler[Request, Response] {
def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = {
val listener = Fs2StreamServerCallListener[F](call, options).unsafeRun()
val listener = Fs2StreamServerCallListener[F](call, F.unit, options).unsafeRun()
listener.unsafeUnaryResponse(headers, implementation(_, headers))
listener
}
Expand All @@ -54,9 +59,13 @@ class Fs2ServerCallHandler[F[_]](val dummy: Boolean = false) extends AnyVal {
)(implicit F: ConcurrentEffect[F]): ServerCallHandler[Request, Response] =
new ServerCallHandler[Request, Response] {
def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = {
val listener = Fs2StreamServerCallListener[F](call, options).unsafeRun()
listener.unsafeStreamResponse(headers, implementation(_, headers))
listener

Readiness[F].flatMap { readiness =>
Fs2StreamServerCallListener[F](call, readiness.signal, options).map { listener=>
listener.unsafeStreamResponse(readiness, headers, implementation(_, headers))
listener
}
}.unsafeRun()
}
}
}
Expand Down
11 changes: 6 additions & 5 deletions java-runtime/src/main/scala/server/Fs2ServerCallListener.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import cats.effect.concurrent.Deferred
import cats.implicits._
import fs2.Stream
import io.grpc.{Metadata, Status, StatusException, StatusRuntimeException}
import org.lyranthe.fs2_grpc.java_runtime.shared.Readiness

private[server] trait Fs2ServerCallListener[F[_], G[_], Request, Response] {
def source: G[Request]
Expand All @@ -26,10 +27,10 @@ private[server] trait Fs2ServerCallListener[F[_], G[_], Request, Response] {
}

private def handleUnaryResponse(headers: Metadata, response: F[Response])(implicit F: Sync[F]): F[Unit] =
call.sendHeaders(headers) *> call.request(1) *> response >>= call.sendMessage
call.sendHeaders(headers) *> call.request(1) *> response >>= call.sendMessageImmediately

private def handleStreamResponse(headers: Metadata, response: Stream[F, Response])(implicit F: Sync[F]): F[Unit] =
call.sendHeaders(headers) *> call.request(1) *> response.evalMap(call.sendMessage).compile.drain
private def handleStreamResponse(readiness: Readiness[F], headers: Metadata, response: Stream[F, Response])(implicit F: Sync[F]): F[Unit] =
call.sendHeaders(headers) *> call.request(1) *> response.evalMap(call.sendMessageWhenReady(readiness)).compile.drain

private def unsafeRun(f: F[Unit])(implicit F: ConcurrentEffect[F]): Unit = {
val bracketed = F.guaranteeCase(f) {
Expand All @@ -47,8 +48,8 @@ private[server] trait Fs2ServerCallListener[F[_], G[_], Request, Response] {
): Unit =
unsafeRun(handleUnaryResponse(headers, implementation(source)))

def unsafeStreamResponse(headers: Metadata, implementation: G[Request] => Stream[F, Response])(implicit
def unsafeStreamResponse(readiness: Readiness[F], headers: Metadata, implementation: G[Request] => Stream[F, Response])(implicit
F: ConcurrentEffect[F]
): Unit =
unsafeRun(handleStreamResponse(headers, implementation(source)))
unsafeRun(handleStreamResponse(readiness, headers, implementation(source)))
}
Loading