11/*
2- * Copyright 2002-2023 the original author or authors.
2+ * Copyright 2002-2024 the original author or authors.
33 *
44 * Licensed under the Apache License, Version 2.0 (the "License");
55 * you may not use this file except in compliance with the License.
@@ -68,12 +68,13 @@ final class WebSocketGraphQlTransport implements GraphQlTransport {
6868
6969 private final Mono <GraphQlSession > graphQlSessionMono ;
7070
71- private final long keepalive ;
71+ @ Nullable
72+ private final Duration keepAlive ;
7273
7374
7475 WebSocketGraphQlTransport (
7576 URI url , @ Nullable HttpHeaders headers , WebSocketClient client , CodecConfigurer codecConfigurer ,
76- WebSocketGraphQlClientInterceptor interceptor , long keepalive ) {
77+ WebSocketGraphQlClientInterceptor interceptor , @ Nullable Duration keepAlive ) {
7778
7879 Assert .notNull (url , "URI is required" );
7980 Assert .notNull (client , "WebSocketClient is required" );
@@ -83,9 +84,9 @@ final class WebSocketGraphQlTransport implements GraphQlTransport {
8384 this .url = url ;
8485 this .headers .putAll ((headers != null ) ? headers : HttpHeaders .EMPTY );
8586 this .webSocketClient = client ;
86- this .keepalive = keepalive ;
87+ this .keepAlive = keepAlive ;
8788
88- this .graphQlSessionHandler = new GraphQlSessionHandler (codecConfigurer , interceptor , keepalive );
89+ this .graphQlSessionHandler = new GraphQlSessionHandler (codecConfigurer , interceptor , keepAlive );
8990
9091 this .graphQlSessionMono = initGraphQlSession (this .url , this .headers , client , this .graphQlSessionHandler )
9192 .cacheInvalidateWhen (GraphQlSession ::notifyWhenClosed );
@@ -166,8 +167,9 @@ public Flux<GraphQlResponse> executeSubscription(GraphQlRequest request) {
166167 return this .graphQlSessionMono .flatMapMany ((session ) -> session .executeSubscription (request ));
167168 }
168169
169- public long getKeepAlive () {
170- return keepalive ;
170+ @ Nullable
171+ Duration getKeepAlive () {
172+ return this .keepAlive ;
171173 }
172174
173175
@@ -191,15 +193,18 @@ private static class GraphQlSessionHandler implements WebSocketHandler {
191193
192194 private final AtomicBoolean stopped = new AtomicBoolean ();
193195
194- private final long keepalive ;
196+ @ Nullable
197+ private final Duration keepAlive ;
198+
195199
200+ GraphQlSessionHandler (
201+ CodecConfigurer codecConfigurer , WebSocketGraphQlClientInterceptor interceptor ,
202+ @ Nullable Duration keepAlive ) {
196203
197- GraphQlSessionHandler (CodecConfigurer codecConfigurer , WebSocketGraphQlClientInterceptor interceptor ,
198- long keepalive ) {
199204 this .codecDelegate = new CodecDelegate (codecConfigurer );
200205 this .interceptor = interceptor ;
201206 this .graphQlSessionSink = Sinks .unsafe ().one ();
202- this .keepalive = keepalive ;
207+ this .keepAlive = keepAlive ;
203208 }
204209
205210
@@ -257,7 +262,7 @@ public Mono<Void> handle(WebSocketSession session) {
257262 session .send (connectionInitMono .concatWith (graphQlSession .getRequestFlux ())
258263 .map ((message ) -> this .codecDelegate .encode (session , message )));
259264
260- Flux <Void > receiveCompletion = session .receive ()
265+ Mono <Void > receiveCompletion = session .receive ()
261266 .flatMap ((webSocketMessage ) -> {
262267 if (sessionNotInitialized ()) {
263268 try {
@@ -303,20 +308,22 @@ public Mono<Void> handle(WebSocketSession session) {
303308 }
304309 }
305310 return Mono .empty ();
306- });
307-
308- if (keepalive > 0 ) {
309- Duration keepAliveDuration = Duration .ofSeconds (keepalive );
310- receiveCompletion = receiveCompletion
311- .mergeWith (Flux .interval (keepAliveDuration , keepAliveDuration )
312- .flatMap (i -> {
313- graphQlSession .sendPing (null );
314- return Mono .empty ();
315- })
316- );
311+ })
312+ .mergeWith ((this .keepAlive != null ) ?
313+ Flux .interval (this .keepAlive , this .keepAlive )
314+ .filter ((aLong ) -> graphQlSession .checkSentOrReceivedMessagesAndClear ())
315+ .doOnNext ((aLong ) -> graphQlSession .sendPing ())
316+ .then () :
317+ Flux .empty ())
318+ .then ();
319+
320+ if (this .keepAlive != null ) {
321+ Flux .interval (this .keepAlive , this .keepAlive )
322+ .filter ((aLong ) -> graphQlSession .checkSentOrReceivedMessagesAndClear ())
323+ .doOnNext ((aLong ) -> graphQlSession .sendPing ())
324+ .subscribe ();
317325 }
318326
319-
320327 return Mono .zip (sendCompletion , receiveCompletion .then ()).then ();
321328 }
322329
@@ -413,6 +420,8 @@ private static class GraphQlSession {
413420
414421 private final Map <String , RequestState > requestStateMap = new ConcurrentHashMap <>();
415422
423+ private boolean hasReceivedMessages ;
424+
416425
417426 GraphQlSession (WebSocketSession webSocketSession ) {
418427 this .connection = DisposableConnection .from (webSocketSession );
@@ -483,11 +492,16 @@ void sendPong(@Nullable Map<String, Object> payload) {
483492 this .requestSink .sendRequest (message );
484493 }
485494
486- public void sendPing (@ Nullable Map < String , Object > payload ) {
487- GraphQlWebSocketMessage message = GraphQlWebSocketMessage .ping (payload );
495+ void sendPing () {
496+ GraphQlWebSocketMessage message = GraphQlWebSocketMessage .ping (null );
488497 this .requestSink .sendRequest (message );
489498 }
490499
500+ boolean checkSentOrReceivedMessagesAndClear () {
501+ boolean received = this .hasReceivedMessages ;
502+ this .hasReceivedMessages = false ;
503+ return (this .requestSink .checkSentMessagesAndClear () || received );
504+ }
491505
492506 // Inbound messages
493507
@@ -504,6 +518,8 @@ void handleNext(GraphQlWebSocketMessage message) {
504518 return ;
505519 }
506520
521+ this .hasReceivedMessages = true ;
522+
507523 if (requestState instanceof SingleResponseRequestState ) {
508524 this .requestStateMap .remove (id );
509525 }
@@ -631,6 +647,8 @@ private static final class RequestSink {
631647 @ Nullable
632648 private FluxSink <GraphQlWebSocketMessage > requestSink ;
633649
650+ private boolean hasSentMessages ;
651+
634652 private final Flux <GraphQlWebSocketMessage > requestFlux = Flux .create ((sink ) -> {
635653 Assert .state (this .requestSink == null , "Expected single subscriber only for outbound messages" );
636654 this .requestSink = sink ;
@@ -642,9 +660,16 @@ Flux<GraphQlWebSocketMessage> getRequestFlux() {
642660
643661 void sendRequest (GraphQlWebSocketMessage message ) {
644662 Assert .state (this .requestSink != null , "Unexpected request before Flux is subscribed to" );
663+ this .hasSentMessages = true ;
645664 this .requestSink .next (message );
646665 }
647666
667+ boolean checkSentMessagesAndClear () {
668+ boolean result = this .hasSentMessages ;
669+ this .hasSentMessages = false ;
670+ return result ;
671+ }
672+
648673 }
649674
650675
0 commit comments