Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import org.springframework.web.reactive.function.client.WebClient.RequestHeaders
import reactor.core.publisher.Flux
import reactor.core.publisher.Mono
import reactor.util.context.Context
import reactor.util.retry.Retry
import kotlin.coroutines.CoroutineContext

/**
Expand Down Expand Up @@ -226,7 +227,7 @@ inline fun <reified T : Any> WebClient.ResponseSpec.toEntityFlux(): Mono<Respons
toEntityFlux(object : ParameterizedTypeReference<T>() {})

/**
* Extension for [WebClient.ResponseSpec.toEntity] providing a `toEntity<Foo>()` variant
* Extension for [WebClient.ResponseSpec.toEntity] providing a `awaitEntity<Foo>()` variant
* leveraging Kotlin reified type parameters and allows [kotlin.coroutines.CoroutineContext]
* propagation to the [CoExchangeFilterFunction]. This extension is not subject to type erasure
* and retains actual generic type arguments.
Expand All @@ -240,6 +241,22 @@ suspend inline fun <reified T : Any> WebClient.ResponseSpec.awaitEntity(): Respo
}
}

/**
* Extension for [WebClient.ResponseSpec.toEntity] providing a `awaitEntityWithRetry<Foo>(Retry)` variant
* leveraging Kotlin reified type parameters and allows [kotlin.coroutines.CoroutineContext]
* propagation to the [CoExchangeFilterFunction]. This extension is not subject to type erasure
* and retains actual generic type arguments.
*
* @param retrySpec the [Retry] strategy passed to the [Mono.retryWhen]
* @param T the type of the body
*/
suspend inline fun <reified T : Any> WebClient.ResponseSpec.awaitEntityWithRetry(retrySpec: Retry): ResponseEntity<T> {
val context = currentCoroutineContext().minusKey(Job.Key)
return withContext(context.toReactorContext()) {
toEntity<T>().retryWhen(retrySpec).awaitSingle()
}
}

private val contextPropagationPresent = ClassUtils.isPresent("io.micrometer.context.ContextSnapshotFactory",
WebClient::class.java.classLoader)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,10 @@ import org.springframework.web.reactive.function.client.CoExchangeFilterFunction
import reactor.core.publisher.Flux
import reactor.core.publisher.Hooks
import reactor.core.publisher.Mono
import reactor.util.retry.Retry
import java.time.Duration
import java.util.concurrent.CompletableFuture
import java.util.concurrent.atomic.AtomicInteger
import java.util.function.Function
import kotlin.coroutines.AbstractCoroutineContextElement
import kotlin.coroutines.CoroutineContext
Expand Down Expand Up @@ -247,6 +249,42 @@ class WebClientExtensionsTests {
}
}

@Test
fun `ResponseSpec#awaitEntityWithRetry with coroutine context propagation`() {
val exchangeFunction = mockk<ExchangeFunction>()
val mockResponse = mockk<ClientResponse>()
val mockClientHeaders = mockk<ClientResponse.Headers>()
val foo = mockk<Foo>()
val slot = slot<ClientRequest>()
val atomicInteger = AtomicInteger(0)
every { exchangeFunction.exchange(capture(slot)) } answers {
if (atomicInteger.getAndIncrement() < 2) {
Mono.error(Exception())
} else {
Mono.just(mockResponse)
}
}
every { mockResponse.statusCode() } returns HttpStatus.OK
every { mockResponse.headers() } returns mockClientHeaders
every { mockClientHeaders.asHttpHeaders() } returns HttpHeaders()
every { mockResponse.bodyToMono(object : ParameterizedTypeReference<Foo>() {}) } returns Mono.just(foo)
runBlocking(FooContextElement(foo)) {
val responseEntity = WebClient.builder()
.exchangeFunction(exchangeFunction)
.filter(object : CoExchangeFilterFunction() {
override suspend fun filter(request: ClientRequest, next: CoExchangeFunction): ClientResponse {
assertThat(currentCoroutineContext()[FooContextElement.Key]!!.foo).isEqualTo(foo)
return next.exchange(request)
}
})
.build().get().uri("/path").retrieve().awaitEntityWithRetry<Foo>(Retry.max(2))
val capturedContext = slot.captured.attribute(COROUTINE_CONTEXT_ATTRIBUTE).get() as CoroutineContext
assertThat(atomicInteger.get()).isEqualTo(3)
assertThat(capturedContext[FooContextElement.Key]!!.foo).isEqualTo(foo)
assertThat(responseEntity.body).isEqualTo(foo)
}
}

@Test
fun `ResponseSpec#awaitEntity with coroutine context propagation to multiple CoExchangeFilterFunctions`() {
val exchangeFunction = mockk<ExchangeFunction>()
Expand Down