Skip to content

Commit 009b99c

Browse files
authored
feat: add glacier request body checksum (#379)
1 parent 396859c commit 009b99c

File tree

9 files changed

+318
-11
lines changed

9 files changed

+318
-11
lines changed

codegen/protocol-tests/build.gradle.kts

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ fun generateSmithyBuild(tests: List<ProtocolTest>): String {
106106
"kotlin-codegen": {
107107
"service": "${test.serviceShapeId}",
108108
"package": {
109-
"name": "aws.sdk.kotlin.protocoltest.${test.packageName}",
109+
"name": "aws.sdk.kotlin.services.${test.packageName}",
110110
"version": "1.0"
111111
},
112112
"build": {
@@ -145,12 +145,18 @@ open class ProtocolTestTask : DefaultTask() {
145145
@get:Input
146146
var plugin: String = ""
147147

148+
/**
149+
* The build directory for the task
150+
*/
151+
val generatedBuildDir: File
152+
@OutputDirectory
153+
get() = project.buildDir.resolve("smithyprojections/${project.name}/$protocol/$plugin")
154+
148155
@TaskAction
149156
fun runTests() {
150157
require(protocol.isNotEmpty()) { "protocol name must be specified" }
151158
require(plugin.isNotEmpty()) { "plugin name must be specified" }
152159

153-
val generatedBuildDir = project.file("${project.buildDir}/smithyprojections/${project.name}/$protocol/$plugin")
154160
println("[$protocol] buildDir: $generatedBuildDir")
155161
if (!generatedBuildDir.exists()) {
156162
throw GradleException("$generatedBuildDir does not exist")
@@ -169,14 +175,22 @@ open class ProtocolTestTask : DefaultTask() {
169175
}
170176
}
171177

172-
173-
174178
enabledProtocols.forEach {
175-
tasks.register<ProtocolTestTask>("testProtocol-${it.projectionName}") {
179+
val protocolName = it.projectionName
180+
181+
val protocolTestTask = tasks.register<ProtocolTestTask>("testProtocol-$protocolName") {
176182
dependsOn(tasks["generateSdk"])
177183
group = "Verification"
178-
protocol = it.projectionName
184+
protocol = protocolName
179185
plugin = "kotlin-codegen"
186+
}.get()
187+
188+
// FIXME This is a hack to work around how protocol tests aren't in the actual service model and thus codegen
189+
// separately from service customizations.
190+
tasks.create<Copy>("copyStaticFiles-$protocolName") {
191+
from(rootProject.projectDir.resolve("services/$protocolName/common/src"))
192+
into(protocolTestTask.generatedBuildDir.resolve("src/main/kotlin/"))
193+
tasks["generateSdk"].finalizedBy(this)
180194
}
181195
}
182196

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0.
4+
*/
5+
6+
package aws.sdk.kotlin.codegen.customization.glacier
7+
8+
import aws.sdk.kotlin.codegen.sdkId
9+
import software.amazon.smithy.kotlin.codegen.KotlinSettings
10+
import software.amazon.smithy.kotlin.codegen.core.KotlinWriter
11+
import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes
12+
import software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration
13+
import software.amazon.smithy.kotlin.codegen.model.buildSymbol
14+
import software.amazon.smithy.kotlin.codegen.model.expectShape
15+
import software.amazon.smithy.kotlin.codegen.model.isStreaming
16+
import software.amazon.smithy.kotlin.codegen.rendering.protocol.HttpFeatureMiddleware
17+
import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator
18+
import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolMiddleware
19+
import software.amazon.smithy.kotlin.codegen.utils.getOrNull
20+
import software.amazon.smithy.model.Model
21+
import software.amazon.smithy.model.shapes.OperationShape
22+
import software.amazon.smithy.model.shapes.ServiceShape
23+
import software.amazon.smithy.model.shapes.StructureShape
24+
25+
/**
 * Codegen integration that appends a `GlacierBodyChecksum` middleware to the protocol middleware list.
 *
 * The integration is active only for a service whose SDK ID is "Glacier" (compared case-insensitively), and the
 * middleware itself is enabled only for operations that have a streaming input member.
 */
public class GlacierBodyChecksum : KotlinIntegration {
    private val glacierBodyChecksumMiddleware = object : HttpFeatureMiddleware() {
        override val name: String = "GlacierBodyChecksum"
        override val order: Byte = 127 // Must come after AwsSignatureVersion4

        /**
         * Returns `true` when the operation has an input structure with at least one member that is streaming
         * itself or targets a streaming shape; returns `false` when the operation has no input.
         */
        override fun isEnabledFor(ctx: ProtocolGenerator.GenerationContext, op: OperationShape): Boolean {
            val inputId = op.input.getOrNull() ?: return false
            val inputShape = ctx.model.expectShape<StructureShape>(inputId)
            return inputShape.members().any { member ->
                member.isStreaming || ctx.model.expectShape(member.target).isStreaming
            }
        }

        /**
         * Adds the imports required by the rendered configuration and writes the statements that assign a
         * `TreeHasherImpl` (1MB chunks, SHA-256) to `treeHasher`.
         */
        override fun renderConfigure(writer: KotlinWriter) {
            with(writer) {
                addImport(RuntimeTypes.Utils.Sha256)
                addImport(glacierSymbol("GlacierBodyChecksum"))
                addImport(glacierSymbol("TreeHasherImpl"))

                write("val chunkSizeBytes = 1024 * 1024 // 1MB")
                write("treeHasher = TreeHasherImpl(chunkSizeBytes) { Sha256() }")
            }
        }

        /** Builds a symbol called [name] in the Glacier service's `internal` package. */
        private fun glacierSymbol(name: String) = buildSymbol {
            this.name = name
            namespace = "aws.sdk.kotlin.services.glacier.internal"
        }
    }

    /** Enables this integration only when the service's SDK ID equals "Glacier", ignoring case. */
    override fun enabledForService(model: Model, settings: KotlinSettings): Boolean {
        val service = model.expectShape<ServiceShape>(settings.service)
        return service.sdkId.equals("Glacier", ignoreCase = true)
    }

    /** Returns the already-resolved middleware followed by the Glacier body checksum middleware. */
    override fun customizeMiddleware(
        ctx: ProtocolGenerator.GenerationContext,
        resolved: List<ProtocolMiddleware>,
    ): List<ProtocolMiddleware> = resolved + glacierBodyChecksumMiddleware
}

codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/protocols/core/AwsHttpBindingProtocolGenerator.kt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,6 @@ abstract class AwsHttpBindingProtocolGenerator : HttpBindingProtocolGenerator()
7878
// FIXME - document type not fully supported yet, see https://github.com/awslabs/smithy-kotlin/issues/123
7979
"PutAndGetInlineDocumentsInput",
8080

81-
// Glacier customizations
82-
"GlacierChecksums", // smithy-kotlin#164
83-
"GlacierMultipartChecksums", // smithy-kotlin#164
84-
8581
// aws-sdk-kotlin#390
8682
"RestJsonHttpWithHeaderMemberNoModeledBody",
8783
"RestJsonHttpWithNoModeledBody",

codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/protocols/middleware/AwsSignatureVersion4.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import software.amazon.smithy.model.traits.OptionalAuthTrait
2626
*/
2727
open class AwsSignatureVersion4(private val signingServiceName: String) : HttpFeatureMiddleware() {
2828
override val name: String = AwsRuntimeTypes.Signing.AwsSigV4SigningMiddleware.name
29-
override val order: Byte = 127
29+
override val order: Byte = 126 // Must come before GlacierBodyChecksum
3030

3131
init {
3232
require(signingServiceName.isNotEmpty()) { "signingServiceName must be specified" }

codegen/smithy-aws-kotlin-codegen/src/main/resources/META-INF/services/software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ aws.sdk.kotlin.codegen.customization.glacier.GlacierAddVersionHeader
1414
aws.sdk.kotlin.codegen.customization.glacier.GlacierAccountIdDefault
1515
aws.sdk.kotlin.codegen.customization.polly.PollyPresigner
1616
aws.sdk.kotlin.codegen.customization.BoxServices
17+
aws.sdk.kotlin.codegen.customization.glacier.GlacierBodyChecksum

services/build.gradle.kts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ subprojects {
3838
sourceSets.getByName("test") {
3939
kotlin.srcDir("common/test")
4040
kotlin.srcDir("generated-src/test")
41+
42+
dependencies {
43+
implementation(kotlin("test-junit5"))
44+
implementation(project(":aws-runtime:testing"))
45+
}
4146
}
4247
}
4348

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0.
4+
*/
5+
6+
package aws.sdk.kotlin.services.glacier.internal
7+
8+
import aws.sdk.kotlin.services.glacier.model.GlacierException
9+
import aws.smithy.kotlin.runtime.client.operationName
10+
import aws.smithy.kotlin.runtime.http.Feature
11+
import aws.smithy.kotlin.runtime.http.FeatureKey
12+
import aws.smithy.kotlin.runtime.http.HttpBody
13+
import aws.smithy.kotlin.runtime.http.HttpClientFeatureFactory
14+
import aws.smithy.kotlin.runtime.http.operation.SdkHttpOperation
15+
import aws.smithy.kotlin.runtime.http.request.headers
16+
import aws.smithy.kotlin.runtime.util.Sha256
17+
import aws.smithy.kotlin.runtime.util.encodeToHex
18+
19+
private const val defaultChunkSizeBytes = 1024 * 1024 // 1MB

/**
 * HTTP feature that computes hashes of the request body with a [TreeHasher] and sets them as the
 * `X-Amz-Content-Sha256` and `X-Amz-Sha256-Tree-Hash` request headers.
 *
 * Streaming bodies must be replayable; otherwise the request fails with a [GlacierException].
 */
internal class GlacierBodyChecksum(config: Config) : Feature {
    private val treeHasher = config.treeHasher

    /** Configuration for [GlacierBodyChecksum]. */
    internal class Config {
        /** The hasher used to compute body hashes. Defaults to SHA-256 over 1MB chunks. */
        internal var treeHasher: TreeHasher = TreeHasherImpl(defaultChunkSizeBytes) { Sha256() }
    }

    /**
     * Registers an interceptor on the operation's finalize phase that hashes the body, sets the two checksum
     * headers (hex-encoded) and then hands the request to the next handler.
     */
    override fun <I, O> install(operation: SdkHttpOperation<I, O>) {
        operation.execution.finalize.intercept { request, next ->
            val payload = request.subject.body

            // Only non-replayable streams are rejected; every other body is passed to the hasher.
            val isOneShotStream = payload is HttpBody.Streaming && !payload.isReplayable
            if (isOneShotStream) {
                val opName = request.context.operationName ?: "This operation"
                throw GlacierException("$opName requires a byte array or replayable stream")
            }

            val hashes = treeHasher.calculateHashes(payload)
            val fullHashHex = hashes.fullHash.encodeToHex()
            val treeHashHex = hashes.treeHash.encodeToHex()
            request.subject.headers {
                set("X-Amz-Content-Sha256", fullHashHex)
                set("X-Amz-Sha256-Tree-Hash", treeHashHex)
            }

            next.call(request)
        }
    }

    internal companion object Feature : HttpClientFeatureFactory<Config, GlacierBodyChecksum> {
        override val key: FeatureKey<GlacierBodyChecksum> = FeatureKey("GlacierBodyChecksum")

        /** Creates the feature from a fresh [Config] customized by [block]. */
        override fun create(block: Config.() -> Unit): GlacierBodyChecksum =
            GlacierBodyChecksum(Config().apply(block))
    }
}
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0.
4+
*/
5+
6+
package aws.sdk.kotlin.services.glacier.internal
7+
8+
import aws.smithy.kotlin.runtime.http.HttpBody
9+
import aws.smithy.kotlin.runtime.util.HashFunction
10+
import kotlinx.coroutines.flow.*
11+
import kotlin.math.min
12+
13+
/**
 * A function that supplies a [HashFunction] when invoked.
 */
internal typealias HashSupplier = () -> HashFunction
14+
15+
/**
 * Holds the two digests produced by a [TreeHasher] calculation.
 * @param fullHash The hash of the entire payload, taken as a whole.
 * @param treeHash The composite hash of the payload, built from the hashes of its chunks.
 */
internal class Hashes(
    val fullHash: ByteArray,
    val treeHash: ByteArray,
)
21+
22+
/**
 * Calculates [Hashes] for an HTTP body using a hash tree. The algorithm is described under
 * [Computing Checksums](https://docs.aws.amazon.com/amazonglacier/latest/dev/checksum-calculations.html) in the
 * Glacier service guide.
 */
internal interface TreeHasher {
    /**
     * Calculates the hashes of [body].
     * @param body The [HttpBody] whose content is hashed.
     * @return The [Hashes] resulting from the calculation.
     */
    suspend fun calculateHashes(body: HttpBody): Hashes
}
35+
36+
/**
 * The default implementation of a [TreeHasher].
 *
 * The body is read in chunks of at most [chunkSizeBytes] bytes. Every chunk is fed to a single "full" hash and is
 * also hashed on its own to form a leaf of the tree. Leaves are then combined pairwise, level by level, until one
 * hash (the tree hash) remains.
 *
 * @param chunkSizeBytes The maximum number of bytes per chunk. Must be positive; otherwise construction fails with
 * an [IllegalArgumentException].
 * @param hashSupplier Invoked once for every digest that is computed (the full hash, each leaf, and each pair).
 */
internal class TreeHasherImpl(private val chunkSizeBytes: Int, private val hashSupplier: HashSupplier) : TreeHasher {
    init {
        // A non-positive chunk size can never make progress when chunking a stream.
        require(chunkSizeBytes > 0) { "chunkSizeBytes must be positive, was $chunkSizeBytes" }
    }

    /**
     * Perform the hash calculation.
     * @param body The [HttpBody] over which to calculate hashes.
     * @return A [Hashes] containing the full hash and the tree hash of [body].
     */
    override suspend fun calculateHashes(body: HttpBody): Hashes {
        val full = hashSupplier()
        val leafHashes = mutableListOf<ByteArray>()

        body.chunks().collect { chunk ->
            full.update(chunk)
            leafHashes.add(chunk.hash())
        }

        // Edge case for empty bodies: the tree consists of the hash of zero bytes.
        var level: List<ByteArray> = leafHashes.ifEmpty { listOf(byteArrayOf().hash()) }

        while (level.size > 1) {
            level = level.chunked(2) { pair ->
                if (pair.size == 1) {
                    // An unpaired trailing hash is promoted to the next level unchanged.
                    pair[0]
                } else {
                    val combined = hashSupplier()
                    combined.update(pair[0])
                    combined.update(pair[1])
                    combined.digest()
                }
            }
        }

        return Hashes(full.digest(), level.single())
    }

    /** Hashes the receiver with a hash function obtained from [hashSupplier]. */
    private fun ByteArray.hash(): ByteArray = hashSupplier().apply { update(this@hash) }.digest()

    /**
     * Splits the receiver into a flow of non-empty chunks of at most [chunkSizeBytes] bytes. An empty body yields
     * an empty flow. A streaming body is `reset()` once it has been read to the end.
     */
    private suspend fun HttpBody.chunks(): Flow<ByteArray> = when (this) {
        is HttpBody.Empty -> flowOf()

        is HttpBody.Bytes -> {
            val payload = bytes()
            val size = payload.size
            val chunkStarts = 0 until size step chunkSizeBytes
            chunkStarts.asFlow().map { start ->
                // Compute the length from the remaining byte count so `start + chunkSizeBytes` cannot overflow.
                payload.copyOfRange(start, start + min(chunkSizeBytes, size - start))
            }
        }

        is HttpBody.Streaming -> flow {
            val channel = readFrom()
            while (!channel.isClosedForRead) {
                val chunk = channel.readRemaining(chunkSizeBytes)
                // The channel may be drained but not yet closed, in which case the read comes back empty. An
                // empty chunk must not become a leaf or it would change the tree hash.
                if (chunk.isNotEmpty()) {
                    emit(chunk)
                }
            }
            reset()
        }
    }
}
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0.
4+
*/
5+
6+
package aws.sdk.kotlin.services.glacier.internal
7+
8+
import aws.sdk.kotlin.runtime.testing.runSuspendTest
9+
import aws.smithy.kotlin.runtime.content.ByteStream
10+
import aws.smithy.kotlin.runtime.http.content.ByteArrayContent
11+
import aws.smithy.kotlin.runtime.http.toHttpBody
12+
import aws.smithy.kotlin.runtime.io.SdkByteChannel
13+
import aws.smithy.kotlin.runtime.io.SdkByteReadChannel
14+
import aws.smithy.kotlin.runtime.util.HashFunction
15+
import aws.smithy.kotlin.runtime.util.Sha256
16+
import aws.smithy.kotlin.runtime.util.encodeToHex
17+
import kotlinx.coroutines.async
18+
import kotlinx.coroutines.withTimeout
19+
import kotlin.test.Test
20+
import kotlin.test.assertContentEquals
21+
import kotlin.test.assertEquals
22+
import kotlin.test.fail
23+
24+
private const val megabyte = 1024 * 1024
25+
26+
/**
 * Tests for [TreeHasherImpl]: one using a deterministic fake hash function and one using SHA-256 over a
 * replayable stream.
 */
class TreeHasherTest {
    @Test
    fun testCalculateHashes() = runSuspendTest {
        val chunkSize = 3
        val firstChunk = byteArrayOf(1, 2, 3)
        val secondChunk = byteArrayOf(4, 5, 6)

        // Each element added once (thus, ∑(c[n] + 1))
        val expectedFullHash = byteArrayOf(7, 9, 11)
        // Elements added twice (thus, ∑(c[n] + 2))
        val expectedTreeHash = byteArrayOf(9, 11, 13)

        val hasher = TreeHasherImpl(chunkSize) { RollingSumHashFunction(chunkSize) }
        val hashes = hasher.calculateHashes(ByteArrayContent(firstChunk + secondChunk))

        assertContentEquals(expectedFullHash, hashes.fullHash)
        assertContentEquals(expectedTreeHash, hashes.treeHash)
    }

    @Test
    fun integrationTestCalculateHashes() = runSuspendTest {
        val alphabet = "abcdefghijklmnopqrstuvwxyz".encodeToByteArray() // 26 bytes

        val byteStream = object : ByteStream.ReplayableStream() {
            override fun newReader(): SdkByteReadChannel {
                val channel = SdkByteChannel()
                async {
                    // For sanity, bail out after 10s
                    withTimeout(10_000) {
                        // This will yield len(alphabet) megabytes of content
                        repeat(megabyte) {
                            channel.writeFully(alphabet)
                        }
                    }
                    channel.close()
                }
                return channel
            }
        }

        val hasher = TreeHasherImpl(megabyte) { Sha256() }
        val hashes = hasher.calculateHashes(byteStream.toHttpBody())

        val actualFullHash = hashes.fullHash.encodeToHex()
        val actualTreeHash = hashes.treeHash.encodeToHex()
        assertEquals("74df7872289a84fa31b6ae4cfdbac34ef911cfe9357e842c600a060da6a899ae", actualFullHash)
        assertEquals("a1c6d421d75f727ce97e6998ab79e6a1cc08ee9502f541f9c5748b462c4dc83f", actualTreeHash)
    }
}
70+
71+
/**
 * A fake hash function for tests that keeps a positional rolling sum. In this algorithm:
 * * the hash size equals the chunk size
 * * every [update] call adds each input byte, plus 1, to the byte at the same position of the rolling hash (the
 * extra 1 differentiates full from tree hashes)
 */
class RollingSumHashFunction(private val chunkSize: Int) : HashFunction {
    private val rollingHash = ByteArray(chunkSize)

    /** Returns the rolling hash array itself (not a copy). */
    override fun digest(): ByteArray = rollingHash

    /** Fails the test: resetting is not expected to happen. */
    override fun reset(): Nothing = fail("reset should not have been called")

    /** Asserts that [input] is exactly [chunkSize] bytes long, then folds it into the rolling hash. */
    override fun update(input: ByteArray) {
        assertEquals(chunkSize, input.size, "Chunk size must be exactly $chunkSize")
        input.forEachIndexed { position, value ->
            rollingHash[position] = (rollingHash[position] + value + 1).toByte()
        }
    }
}

0 commit comments

Comments
 (0)