Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 22 additions & 18 deletions runtime/src/main/scala/fs2/grpc/client/Fs2ClientCall.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@ package fs2
package grpc
package client

import cats.syntax.all._
import cats.effect.{Async, Resource}
import cats.effect.std.Dispatcher
import cats.effect.{Async, Resource, SyncIO}
import cats.syntax.all._
import fs2.grpc.client.internal.Fs2UnaryCallHandler
import io.grpc.{Metadata, _}
import fs2.grpc.shared.StreamOutput
import io.grpc._

final case class UnaryResult[A](value: Option[A], status: Option[GrpcStatus])
final case class GrpcStatus(status: Status, trailers: Metadata)
Expand All @@ -51,35 +52,39 @@ class Fs2ClientCall[F[_], Request, Response] private[client] (
private def request(numMessages: Int): F[Unit] =
F.delay(call.request(numMessages))

private def sendMessage(message: Request): F[Unit] =
F.delay(call.sendMessage(message))

private def start[A <: ClientCall.Listener[Response]](createListener: F[A], md: Metadata): F[A] =
createListener.flatTap(l => F.delay(call.start(l, md)))

private def sendSingleMessage(message: Request): F[Unit] =
sendMessage(message) *> halfClose

private def sendStream(stream: Stream[F, Request]): Stream[F, Unit] =
stream.evalMap(sendMessage) ++ Stream.eval(halfClose)
F.delay(call.sendMessage(message)) *> halfClose

//

def unaryToUnaryCall(message: Request, headers: Metadata): F[Response] =
Fs2UnaryCallHandler.unary(call, options, message, headers)

def streamingToUnaryCall(messages: Stream[F, Request], headers: Metadata): F[Response] =
Fs2UnaryCallHandler.stream(call, options, messages, headers)
StreamOutput.client(call, dispatcher).flatMap { output =>
Fs2UnaryCallHandler.stream(call, options, messages, output, headers)
}

def unaryToStreamingCall(message: Request, md: Metadata): Stream[F, Response] =
Stream
.resource(mkStreamListenerR(md))
.resource(mkStreamListenerR(md, SyncIO.unit))
.flatMap(Stream.exec(sendSingleMessage(message)) ++ _.stream.adaptError(ea))

def streamingToStreamingCall(messages: Stream[F, Request], md: Metadata): Stream[F, Response] =
def streamingToStreamingCall(messages: Stream[F, Request], md: Metadata): Stream[F, Response] = {
val listenerAndOutput = Resource.eval(StreamOutput.client(call, dispatcher)).flatMap { output =>
mkStreamListenerR(md, output.onReady).map((_, output))
}

Stream
.resource(mkStreamListenerR(md))
.flatMap(_.stream.adaptError(ea).concurrently(sendStream(messages)))
.resource(listenerAndOutput)
.flatMap { case (listener, output) =>
listener.stream.adaptError(ea)
.concurrently(output.writeStream(messages) ++ Stream.eval(halfClose))
}
}

//

Expand All @@ -89,10 +94,9 @@ class Fs2ClientCall[F[_], Request, Response] private[client] (
case (_, Resource.ExitCase.Errored(t)) => cancel(t.getMessage.some, t.some)
}

private def mkStreamListenerR(md: Metadata): Resource[F, Fs2StreamClientCallListener[F, Response]] = {

private def mkStreamListenerR(md: Metadata, signalReadiness: SyncIO[Unit]): Resource[F, Fs2StreamClientCallListener[F, Response]] = {
val prefetchN = options.prefetchN.max(1)
val create = Fs2StreamClientCallListener.create[F, Response](request, dispatcher, prefetchN)
val create = Fs2StreamClientCallListener.create[F, Response](request, signalReadiness, dispatcher, prefetchN)
val acquire = start(create, md) <* request(prefetchN)
val release = handleExitCase(cancelSucceed = true)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@ package fs2
package grpc
package client

import cats.effect.SyncIO
import cats.implicits._
import cats.effect.kernel.Concurrent
import cats.effect.std.Dispatcher
import io.grpc.{ClientCall, Metadata, Status}

class Fs2StreamClientCallListener[F[_], Response] private (
ingest: StreamIngest[F, Response],
signalReadiness: SyncIO[Unit],
dispatcher: Dispatcher[F]
) extends ClientCall.Listener[Response] {

Expand All @@ -39,18 +41,21 @@ class Fs2StreamClientCallListener[F[_], Response] private (
override def onClose(status: Status, trailers: Metadata): Unit =
dispatcher.unsafeRunSync(ingest.onClose(GrpcStatus(status, trailers)))

override def onReady(): Unit = signalReadiness.unsafeRunSync()

val stream: Stream[F, Response] = ingest.messages
}

object Fs2StreamClientCallListener {

private[client] def create[F[_]: Concurrent, Response](
request: Int => F[Unit],
signalReadiness: SyncIO[Unit],
dispatcher: Dispatcher[F],
prefetchN: Int
): F[Fs2StreamClientCallListener[F, Response]] =
StreamIngest[F, Response](request, prefetchN).map(
new Fs2StreamClientCallListener[F, Response](_, dispatcher)
new Fs2StreamClientCallListener[F, Response](_, signalReadiness, dispatcher)
)

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,14 @@

package fs2.grpc.client.internal

import cats.effect.Sync
import cats.effect.SyncIO
import cats.effect.kernel.{Async, Outcome, Ref}
import cats.effect.syntax.all._
import cats.effect.kernel.Async
import cats.effect.kernel.Outcome
import cats.effect.kernel.Ref
import cats.syntax.functor._
import cats.effect.{Sync, SyncIO}
import cats.syntax.flatMap._
import cats.syntax.functor._
import fs2._
import fs2.grpc.client.ClientOptions
import fs2.grpc.shared.StreamOutput
import io.grpc._

private[client] object Fs2UnaryCallHandler {
Expand Down Expand Up @@ -65,7 +63,8 @@ private[client] object Fs2UnaryCallHandler {
class Done[R] extends ReceiveState[R]

private def mkListener[Response](
state: Ref[SyncIO, ReceiveState[Response]]
state: Ref[SyncIO, ReceiveState[Response]],
signalReadiness: SyncIO[Unit]
): ClientCall.Listener[Response] =
new ClientCall.Listener[Response] {
override def onMessage(message: Response): Unit =
Expand Down Expand Up @@ -110,6 +109,8 @@ private[client] object Fs2UnaryCallHandler {
}
}
}.unsafeRunSync()

override def onReady(): Unit = signalReadiness.unsafeRunSync()
}

def unary[F[_], Request, Response](
Expand All @@ -119,7 +120,7 @@ private[client] object Fs2UnaryCallHandler {
headers: Metadata
)(implicit F: Async[F]): F[Response] = F.async[Response] { cb =>
ReceiveState.init(cb, options.errorAdapter).map { state =>
call.start(mkListener[Response](state), headers)
call.start(mkListener[Response](state, SyncIO.unit), headers)
// Initially ask for two responses from flow-control so that if a misbehaving server
// sends more than one responses, we can catch it and fail it in the listener.
call.request(2)
Expand All @@ -133,18 +134,15 @@ private[client] object Fs2UnaryCallHandler {
call: ClientCall[Request, Response],
options: ClientOptions,
messages: Stream[F, Request],
output: StreamOutput[F, Request],
headers: Metadata
)(implicit F: Async[F]): F[Response] = F.async[Response] { cb =>
ReceiveState.init(cb, options.errorAdapter).flatMap { state =>
call.start(mkListener[Response](state), headers)
call.start(mkListener[Response](state, output.onReady), headers)
// Initially ask for two responses from flow-control so that if a misbehaving server
// sends more than one responses, we can catch it and fail it in the listener.
call.request(2)
messages
.map(call.sendMessage)
.compile
.drain
.guaranteeCase {
output.writeStream(messages).compile.drain.guaranteeCase {
case Outcome.Succeeded(_) => F.delay(call.halfClose())
case Outcome.Errored(e) => F.delay(call.cancel(e.getMessage, e))
case Outcome.Canceled() => onCancel(call)
Expand Down
2 changes: 1 addition & 1 deletion runtime/src/main/scala/fs2/grpc/server/Fs2ServerCall.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ private[server] class Fs2ServerCall[F[_], Request, Response](val call: ServerCal
def closeStream(status: Status, trailers: Metadata)(implicit F: Sync[F]): F[Unit] =
F.delay(call.close(status, trailers))

def sendMessage(message: Response)(implicit F: Sync[F]): F[Unit] =
def sendSingleMessage(message: Response)(implicit F: Sync[F]): F[Unit] =
F.delay(call.sendMessage(message))

def request(numMessages: Int)(implicit F: Sync[F]): F[Unit] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ package server

import cats.effect._
import cats.effect.std.Dispatcher
import cats.syntax.all._
import fs2.grpc.server.internal.Fs2UnaryServerCallHandler
import fs2.grpc.shared.StreamOutput
import io.grpc._

class Fs2ServerCallHandler[F[_]: Async] private (
Expand All @@ -47,7 +49,7 @@ class Fs2ServerCallHandler[F[_]: Async] private (
implementation: (Stream[F, Request], Metadata) => F[Response]
): ServerCallHandler[Request, Response] = new ServerCallHandler[Request, Response] {
def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = {
val listener = dispatcher.unsafeRunSync(Fs2StreamServerCallListener[F](call, dispatcher, options))
val listener = dispatcher.unsafeRunSync(Fs2StreamServerCallListener[F](call, SyncIO.unit, dispatcher, options))
listener.unsafeUnaryResponse(new Metadata(), implementation(_, headers))
listener
}
Expand All @@ -57,8 +59,10 @@ class Fs2ServerCallHandler[F[_]: Async] private (
implementation: (Stream[F, Request], Metadata) => Stream[F, Response]
): ServerCallHandler[Request, Response] = new ServerCallHandler[Request, Response] {
def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = {
val listener = dispatcher.unsafeRunSync(Fs2StreamServerCallListener[F](call, dispatcher, options))
listener.unsafeStreamResponse(new Metadata(), implementation(_, headers))
val (listener, streamOutput) = dispatcher.unsafeRunSync(StreamOutput.server(call, dispatcher).flatMap { output =>
Fs2StreamServerCallListener[F](call, output.onReady, dispatcher, options).map((_, output))
})
listener.unsafeStreamResponse(streamOutput, new Metadata(), implementation(_, headers))
listener
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ package fs2
package grpc
package server

import cats.syntax.all._
import cats.effect._
import cats.effect.std.Dispatcher
import cats.syntax.all._
import fs2.grpc.shared.StreamOutput
import io.grpc.{Metadata, Status, StatusException, StatusRuntimeException}

private[server] trait Fs2ServerCallListener[F[_], G[_], Request, Response] {
Expand All @@ -49,10 +50,10 @@ private[server] trait Fs2ServerCallListener[F[_], G[_], Request, Response] {
}

private def handleUnaryResponse(headers: Metadata, response: F[Response])(implicit F: Sync[F]): F[Unit] =
call.sendHeaders(headers) *> call.request(1) *> response >>= call.sendMessage
call.sendHeaders(headers) *> call.request(1) *> response >>= call.sendSingleMessage

private def handleStreamResponse(headers: Metadata, response: Stream[F, Response])(implicit F: Sync[F]): F[Unit] =
call.sendHeaders(headers) *> call.request(1) *> response.evalMap(call.sendMessage).compile.drain
private def handleStreamResponse(headers: Metadata, sendResponse: Stream[F, Unit])(implicit F: Sync[F]): F[Unit] =
call.sendHeaders(headers) *> call.request(1) *> sendResponse.compile.drain

private def unsafeRun(f: F[Unit])(implicit F: Async[F]): Unit = {
val bracketed = F.guaranteeCase(f) {
Expand All @@ -70,8 +71,8 @@ private[server] trait Fs2ServerCallListener[F[_], G[_], Request, Response] {
): Unit =
unsafeRun(handleUnaryResponse(headers, implementation(source)))

def unsafeStreamResponse(headers: Metadata, implementation: G[Request] => Stream[F, Response])(implicit
def unsafeStreamResponse(streamOutput: StreamOutput[F, Response], headers: Metadata, implementation: G[Request] => Stream[F, Response])(implicit
F: Async[F]
): Unit =
unsafeRun(handleStreamResponse(headers, implementation(source)))
unsafeRun(handleStreamResponse(headers, streamOutput.writeStream(implementation(source))))
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ package server
import cats.Functor
import cats.syntax.all._
import cats.effect.kernel.Deferred
import cats.effect.Async
import cats.effect.{Async, SyncIO}
import cats.effect.std.{Dispatcher, Queue}
import io.grpc.ServerCall

class Fs2StreamServerCallListener[F[_], Request, Response] private (
requestQ: Queue[F, Option[Request]],
signalReadiness: SyncIO[Unit],
val isCancelled: Deferred[F, Unit],
val call: Fs2ServerCall[F, Request, Response],
val dispatcher: Dispatcher[F]
Expand All @@ -47,6 +48,8 @@ class Fs2StreamServerCallListener[F[_], Request, Response] private (
dispatcher.unsafeRunSync(requestQ.offer(message.some))
}

override def onReady(): Unit = signalReadiness.unsafeRunSync()

override def onHalfClose(): Unit =
dispatcher.unsafeRunSync(requestQ.offer(none))

Expand All @@ -60,13 +63,14 @@ object Fs2StreamServerCallListener {

private[server] def apply[Request, Response](
call: ServerCall[Request, Response],
signalReadiness: SyncIO[Unit],
dispatcher: Dispatcher[F],
options: ServerOptions
)(implicit F: Async[F]): F[Fs2StreamServerCallListener[F, Request, Response]] = for {
inputQ <- Queue.unbounded[F, Option[Request]]
isCancelled <- Deferred[F, Unit]
serverCall <- Fs2ServerCall[F, Request, Response](call, options)
} yield new Fs2StreamServerCallListener[F, Request, Response](inputQ, isCancelled, serverCall, dispatcher)
} yield new Fs2StreamServerCallListener[F, Request, Response](inputQ, signalReadiness, isCancelled, serverCall, dispatcher)

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,19 @@ private[server] object Fs2ServerCall {
}

private[server] final class Fs2ServerCall[Request, Response](
call: ServerCall[Request, Response]
call: ServerCall[Request, Response],
) {

import Fs2ServerCall.Cancel

def stream[F[_]](response: Stream[F, Response], dispatcher: Dispatcher[F])(implicit F: Sync[F]): SyncIO[Cancel] =
def stream[F[_]](sendStream: Stream[F, Response] => Stream[F, Unit], response: Stream[F, Response], dispatcher: Dispatcher[F])(implicit F: Sync[F]): SyncIO[Cancel] =
run(
response.pull.peek1
.flatMap {
case Some((_, stream)) =>
Pull.suspend {
call.sendHeaders(new Metadata())
stream.map(call.sendMessage).pull.echo
sendStream(stream).pull.echo
}
case None => Pull.done
}
Expand Down
Loading