2020import java .util .ArrayList ;
2121import java .util .Arrays ;
2222import java .util .Collections ;
23+ import java .util .HashSet ;
2324import java .util .LinkedHashMap ;
2425import java .util .List ;
2526import java .util .Map ;
27+ import java .util .Set ;
2628import java .util .concurrent .CompletionException ;
2729
2830import com .apollographql .federation .graphqljava ._Entity ;
3840import org .springframework .graphql .data .method .annotation .support .HandlerDataFetcherExceptionResolver ;
3941import org .springframework .graphql .execution .ErrorType ;
4042import org .springframework .lang .Nullable ;
43+ import org .springframework .util .Assert ;
4144
4245/**
4346 * DataFetcher that handles the "_entities" query by invoking
@@ -65,6 +68,7 @@ final class EntitiesDataFetcher implements DataFetcher<Mono<DataFetcherResult<Li
6568 public Mono <DataFetcherResult <List <Object >>> get (DataFetchingEnvironment environment ) {
6669 List <Map <String , Object >> representations = environment .getArgument (_Entity .argumentName );
6770
71+ Set <String > batched = new HashSet <>();
6872 List <Mono <Object >> monoList = new ArrayList <>();
6973 for (int index = 0 ; index < representations .size (); index ++) {
7074 Map <String , Object > map = representations .get (index );
@@ -79,15 +83,27 @@ public Mono<DataFetcherResult<List<Object>>> get(DataFetchingEnvironment environ
7983 monoList .add (resolveException (ex , environment , null , index ));
8084 continue ;
8185 }
82- monoList .add (invokeResolver (environment , handlerMethod , map , index ));
86+
87+ if (!handlerMethod .isBatchHandlerMethod ()) {
88+ monoList .add (invokeEntityMethod (environment , handlerMethod , map , index ));
89+ }
90+ else if (batched .contains (typename )) {
91+ // zip needs a value, this will be replaced by batch results
92+ monoList .add (Mono .just (Collections .emptyMap ()));
93+ }
94+ else {
95+ EntityBatchDelegate batchDelegate = new EntityBatchDelegate (environment , handlerMethod , typename );
96+ monoList .add (batchDelegate .invokeEntityBatchMethod ());
97+ batched .add (typename );
98+ }
8399 }
84100 return Mono .zip (monoList , Arrays ::asList ).map (EntitiesDataFetcher ::toDataFetcherResult );
85101 }
86102
87- private Mono <Object > invokeResolver (
103+ private Mono <Object > invokeEntityMethod (
88104 DataFetchingEnvironment env , EntityHandlerMethod handlerMethod , Map <String , Object > map , int index ) {
89105
90- return handlerMethod .getEntity (env , map , index )
106+ return handlerMethod .getEntity (env , map )
91107 .switchIfEmpty (Mono .error (new RepresentationNotResolvedException (map , handlerMethod )))
92108 .onErrorResume ((ex ) -> resolveException (ex , env , handlerMethod , index ));
93109 }
@@ -96,7 +112,7 @@ private Mono<Object> resolveException(
96112 Throwable ex , DataFetchingEnvironment env , @ Nullable EntityHandlerMethod handlerMethod , int index ) {
97113
98114 Throwable theEx = (ex instanceof CompletionException ) ? ex .getCause () : ex ;
99- DataFetchingEnvironment theEnv = new EntityDataFetchingEnvironment (env , index );
115+ DataFetchingEnvironment theEnv = new IndexedDataFetchingEnvironment (env , index );
100116 Object handler = (handlerMethod != null ) ? handlerMethod .getBean () : null ;
101117
102118 return this .exceptionResolver .resolveException (theEx , theEnv , handler )
@@ -120,6 +136,9 @@ private static DataFetcherResult<List<Object>> toDataFetcherResult(List<Object>
120136 List <GraphQLError > errors = new ArrayList <>();
121137 for (int i = 0 ; i < entities .size (); i ++) {
122138 Object entity = entities .get (i );
139+ if (entity instanceof EntityBatchDelegate delegate ) {
140+ delegate .processResults (entities , errors );
141+ }
123142 if (entity instanceof ErrorContainer errorContainer ) {
124143 errors .addAll (errorContainer .errors ());
125144 entities .set (i , null );
@@ -129,11 +148,80 @@ private static DataFetcherResult<List<Object>> toDataFetcherResult(List<Object>
129148 }
130149
131150
132- private static class EntityDataFetchingEnvironment extends DelegatingDataFetchingEnvironment {
151+ private class EntityBatchDelegate {
152+
153+ private final DataFetchingEnvironment environment ;
154+
155+ private final EntityHandlerMethod handlerMethod ;
156+
157+ private final List <Map <String , Object >> representations = new ArrayList <>();
158+
159+ private final List <Integer > indexes = new ArrayList <>();
160+
161+ @ Nullable
162+ private List <?> resultList ;
163+
164+ EntityBatchDelegate (DataFetchingEnvironment env , EntityHandlerMethod handlerMethod , String typeName ) {
165+ this .environment = env ;
166+ this .handlerMethod = handlerMethod ;
167+ List <Map <String , Object >> maps = env .getArgument (_Entity .argumentName );
168+ for (int i = 0 ; i < maps .size (); i ++) {
169+ Map <String , Object > map = maps .get (i );
170+ if (typeName .equals (map .get ("__typename" ))) {
171+ this .representations .add (map );
172+ this .indexes .add (i );
173+ }
174+ }
175+ }
176+
177+ Mono <Object > invokeEntityBatchMethod () {
178+ return this .handlerMethod .getEntities (this .environment , this .representations )
179+ .mapNotNull ((result ) -> (((List <?>) result ).isEmpty ()) ? null : result )
180+ .switchIfEmpty (Mono .defer (this ::handleEmptyResult ))
181+ .onErrorResume (this ::handleErrorResult )
182+ .map ((result ) -> {
183+ this .resultList = (List <?>) result ;
184+ return this ;
185+ });
186+ }
187+
188+ Mono <Object > handleEmptyResult () {
189+ List <Mono <Object >> exceptions = new ArrayList <>(this .indexes .size ());
190+ for (int i = 0 ; i < this .indexes .size (); i ++) {
191+ Map <String , Object > map = this .representations .get (i );
192+ Exception ex = new RepresentationNotResolvedException (map , this .handlerMethod );
193+ exceptions .add (resolveException (ex , this .environment , this .handlerMethod , this .indexes .get (i )));
194+ }
195+ return Mono .zip (exceptions , Arrays ::asList );
196+ }
197+
198+ Mono <List <Object >> handleErrorResult (Throwable ex ) {
199+ List <Mono <Object >> list = new ArrayList <>();
200+ for (Integer index : this .indexes ) {
201+ list .add (resolveException (ex , this .environment , this .handlerMethod , index ));
202+ }
203+ return Mono .zip (list , Arrays ::asList );
204+ }
205+
206+ void processResults (List <Object > entities , List <GraphQLError > errors ) {
207+ Assert .state (this .resultList != null , "Expected resultList" );
208+ for (int i = 0 ; i < this .resultList .size (); i ++) {
209+ Object entity = this .resultList .get (i );
210+ if (entity instanceof ErrorContainer errorContainer ) {
211+ errors .addAll (errorContainer .errors ());
212+ entity = null ;
213+ }
214+ entities .set (this .indexes .get (i ), entity );
215+ }
216+ }
217+ }
218+
219+
220+ private static class IndexedDataFetchingEnvironment extends DelegatingDataFetchingEnvironment {
133221
134222 private final ExecutionStepInfo executionStepInfo ;
135223
136- EntityDataFetchingEnvironment (DataFetchingEnvironment env , int index ) {
224+ IndexedDataFetchingEnvironment (DataFetchingEnvironment env , int index ) {
137225 super (env );
138226 this .executionStepInfo = ExecutionStepInfo .newExecutionStepInfo (env .getExecutionStepInfo ())
139227 .path (env .getExecutionStepInfo ().getPath ().segment (index ))
0 commit comments