diff --git a/zstd-kmp/build.gradle.kts b/zstd-kmp/build.gradle.kts index 169ed2f..97664ba 100644 --- a/zstd-kmp/build.gradle.kts +++ b/zstd-kmp/build.gradle.kts @@ -120,6 +120,8 @@ android { testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner" + consumerProguardFiles("src/androidMain/consumer-rules.pro") + ndk { abiFilters += listOf("x86", "x86_64", "armeabi-v7a", "arm64-v8a") } diff --git a/zstd-kmp/native/ZstdKmp.cpp b/zstd-kmp/native/ZstdKmp.cpp index b8804ed..6329a25 100644 --- a/zstd-kmp/native/ZstdKmp.cpp +++ b/zstd-kmp/native/ZstdKmp.cpp @@ -16,28 +16,6 @@ #include #include -/** - * Support for operating on JVM objects from native code. - * - * Pass a pointer to this to all JNI functions that operate on JVM objects. - */ -class JniZstd { -public: - JniZstd(JNIEnv *env, jclass zstdCompressorClass, jclass zstdDecompressorClass); - - jfieldID zstdCompressorOutputBytesProcessed; - jfieldID zstdCompressorInputBytesProcessed; - jfieldID zstdDecompressorOutputBytesProcessed; - jfieldID zstdDecompressorInputBytesProcessed; -}; - -JniZstd::JniZstd(JNIEnv *env, jclass zstdCompressorClass, jclass zstdDecompressorClass) - : zstdCompressorOutputBytesProcessed(env->GetFieldID(zstdCompressorClass, "outputBytesProcessed", "I")), - zstdCompressorInputBytesProcessed(env->GetFieldID(zstdCompressorClass, "inputBytesProcessed", "I")), - zstdDecompressorOutputBytesProcessed(env->GetFieldID(zstdDecompressorClass, "outputBytesProcessed", "I")), - zstdDecompressorInputBytesProcessed(env->GetFieldID(zstdDecompressorClass, "inputBytesProcessed", "I")) { -} - extern "C" JNIEXPORT jstring JNICALL Java_com_squareup_zstd_JniZstdKt_jniGetErrorName(JNIEnv* env, jobject type, jlong code) { auto codeSizeT = static_cast(code); @@ -46,13 +24,6 @@ Java_com_squareup_zstd_JniZstdKt_jniGetErrorName(JNIEnv* env, jobject type, jlon return env->NewStringUTF(errorString); } -extern "C" JNIEXPORT jlong JNICALL -Java_com_squareup_zstd_JniZstdKt_createJniZstd(JNIEnv* env, jclass type) { - auto zstdCompressorClass = env->FindClass("com/squareup/zstd/ZstdCompressor"); - auto zstdDecompressorClass = env->FindClass("com/squareup/zstd/ZstdDecompressor"); - auto jniZstd = new JniZstd(env, zstdCompressorClass, zstdDecompressorClass); - return reinterpret_cast(jniZstd); -} extern "C" JNIEXPORT jlong JNICALL Java_com_squareup_zstd_JniZstdKt_createZstdCompressor(JNIEnv* env, jclass type) { @@ -66,9 +37,8 @@ Java_com_squareup_zstd_JniZstdCompressor_setParameter(JNIEnv* env, jobject type, return ZSTD_CCtx_setParameter(cctx, static_cast(param), value); } -extern "C" JNIEXPORT jlong JNICALL -Java_com_squareup_zstd_JniZstdCompressor_compressStream2(JNIEnv* env, jobject type, jlong jniZstdPointer, jlong cctxPointer, jbyteArray outputByteArray, jint outputEnd, jint outputStart, jbyteArray inputByteArray, jint inputEnd, jint inputStart, jint mode) { - auto jniZstd = reinterpret_cast(jniZstdPointer); +extern "C" JNIEXPORT jlongArray JNICALL +Java_com_squareup_zstd_JniZstdCompressor_compressStream2(JNIEnv* env, jobject type, jlong cctxPointer, jbyteArray outputByteArray, jint outputEnd, jint outputStart, jbyteArray inputByteArray, jint inputEnd, jint inputStart, jint mode) { auto cctx = reinterpret_cast(cctxPointer); auto inputByteArrayElements = env->GetByteArrayElements(inputByteArray, NULL); @@ -85,13 +55,16 @@ Java_com_squareup_zstd_JniZstdCompressor_compressStream2(JNIEnv* env, jobject ty result = -ZSTD_error_GENERIC; } - env->SetIntField(type, jniZstd->zstdCompressorOutputBytesProcessed, static_cast(zstdOutput.pos) - outputStart); - env->SetIntField(type, jniZstd->zstdCompressorInputBytesProcessed, static_cast(zstdInput.pos) - inputStart); - env->ReleaseByteArrayElements(inputByteArray, inputByteArrayElements, JNI_ABORT); env->ReleaseByteArrayElements(outputByteArray, outputByteArrayElements, 0); - return result; + jlong results[3]; + results[0] = static_cast(result); + results[1] = static_cast(zstdInput.pos) - inputStart; + results[2] = static_cast(zstdOutput.pos) - outputStart; + jlongArray resultArray = env->NewLongArray(3); + env->SetLongArrayRegion(resultArray, 0, 3, results); + return resultArray; } extern "C" JNIEXPORT void JNICALL @@ -106,9 +79,8 @@ Java_com_squareup_zstd_JniZstdKt_createZstdDecompressor(JNIEnv* env, jclass type return reinterpret_cast(dctx); } -extern "C" JNIEXPORT jlong JNICALL -Java_com_squareup_zstd_JniZstdDecompressor_decompressStream(JNIEnv* env, jobject type, jlong jniZstdPointer, jlong dctxPointer, jbyteArray outputByteArray, jint outputEnd, jint outputStart, jbyteArray inputByteArray, jint inputEnd, jint inputStart) { - auto jniZstd = reinterpret_cast(jniZstdPointer); +extern "C" JNIEXPORT jlongArray JNICALL +Java_com_squareup_zstd_JniZstdDecompressor_decompressStream(JNIEnv* env, jobject type, jlong dctxPointer, jbyteArray outputByteArray, jint outputEnd, jint outputStart, jbyteArray inputByteArray, jint inputEnd, jint inputStart) { auto dctx = reinterpret_cast(dctxPointer); auto inputByteArrayElements = env->GetByteArrayElements(inputByteArray, NULL); @@ -124,13 +96,16 @@ Java_com_squareup_zstd_JniZstdDecompressor_decompressStream(JNIEnv* env, jobject result = -ZSTD_error_GENERIC; } - env->SetIntField(type, jniZstd->zstdDecompressorOutputBytesProcessed, static_cast(zstdOutput.pos) - outputStart); - env->SetIntField(type, jniZstd->zstdDecompressorInputBytesProcessed, static_cast(zstdInput.pos) - inputStart); - env->ReleaseByteArrayElements(inputByteArray, inputByteArrayElements, JNI_ABORT); env->ReleaseByteArrayElements(outputByteArray, outputByteArrayElements, 0); - return result; + jlong results[3]; + results[0] = static_cast(result); + results[1] = static_cast(zstdInput.pos) - inputStart; + results[2] = static_cast(zstdOutput.pos) - outputStart; + jlongArray resultArray = env->NewLongArray(3); + env->SetLongArrayRegion(resultArray, 0, 3, results); + return resultArray; } extern "C" JNIEXPORT void JNICALL diff --git a/zstd-kmp/src/androidMain/consumer-rules.pro b/zstd-kmp/src/androidMain/consumer-rules.pro new file mode 100644 index 0000000..103fe73 --- /dev/null +++ b/zstd-kmp/src/androidMain/consumer-rules.pro @@ -0,0 +1,7 @@ +# Keep classes with native methods so R8 doesn't rename them. +# JNI function names in the native library are derived from the fully-qualified +# class and method names, so renaming would break the JNI linkage. +-keepclasseswithmembers class com.squareup.zstd.** { + native ; +} + diff --git a/zstd-kmp/src/jniMain/kotlin/com/squareup/zstd/JniZstd.kt b/zstd-kmp/src/jniMain/kotlin/com/squareup/zstd/JniZstd.kt index 04a9220..77ee565 100644 --- a/zstd-kmp/src/jniMain/kotlin/com/squareup/zstd/JniZstd.kt +++ b/zstd-kmp/src/jniMain/kotlin/com/squareup/zstd/JniZstd.kt @@ -15,12 +15,25 @@ */ package com.squareup.zstd -internal val jniZstdPointer: Long = run { - loadNativeLibrary() - createJniZstd() -} +internal val jniLibraryLoaded: Unit by lazy { loadNativeLibrary() } + +/** + * Result of a native stream operation. The native function returns a [LongArray] with three + * elements; this class provides named access to each. + */ +internal class StreamResult(private val values: LongArray) { + /** The zstd result code. */ + val result: Long + get() = values[0] -@JvmName("createJniZstd") internal external fun createJniZstd(): Long + /** The number of input bytes consumed. */ + val inputBytesProcessed: Int + get() = values[1].toInt() + + /** The number of output bytes produced. */ + val outputBytesProcessed: Int + get() = values[2].toInt() +} @JvmName("jniGetErrorName") internal external fun jniGetErrorName(code: Long): String? diff --git a/zstd-kmp/src/jniMain/kotlin/com/squareup/zstd/JniZstdCompressor.kt b/zstd-kmp/src/jniMain/kotlin/com/squareup/zstd/JniZstdCompressor.kt index dac31a5..26143e6 100644 --- a/zstd-kmp/src/jniMain/kotlin/com/squareup/zstd/JniZstdCompressor.kt +++ b/zstd-kmp/src/jniMain/kotlin/com/squareup/zstd/JniZstdCompressor.kt @@ -17,10 +17,12 @@ package com.squareup.zstd internal class JniZstdCompressor : ZstdCompressor() { @JvmField - var cctxPointer = + var cctxPointer = run { + jniLibraryLoaded createZstdCompressor().also { if (it == 0L) throw OutOfMemoryError("createZstdCompressor failed") } + } override fun setParameter(param: Int, value: Int): Long = setParameter(cctxPointer, param, value) @@ -32,18 +34,24 @@ internal class JniZstdCompressor : ZstdCompressor() { inputEnd: Int, inputStart: Int, mode: Int, - ): Long = - compressStream2( - jniPointer = jniZstdPointer, - cctxPointer = cctxPointer, - outputByteArray = outputByteArray, - outputEnd = outputEnd, - outputStart = outputStart, - inputByteArray = inputByteArray, - inputEnd = inputEnd, - inputStart = inputStart, - mode = mode, - ) + ): Long { + val streamResult = + StreamResult( + compressStream2( + cctxPointer = cctxPointer, + outputByteArray = outputByteArray, + outputEnd = outputEnd, + outputStart = outputStart, + inputByteArray = inputByteArray, + inputEnd = inputEnd, + inputStart = inputStart, + mode = mode, + ) + ) + inputBytesProcessed = streamResult.inputBytesProcessed + outputBytesProcessed = streamResult.outputBytesProcessed + return streamResult.result + } override fun close() { val cctxPointerToClose = cctxPointer @@ -56,7 +64,6 @@ internal class JniZstdCompressor : ZstdCompressor() { private external fun setParameter(cctxPointer: Long, param: Int, value: Int): Long private external fun compressStream2( - jniPointer: Long, cctxPointer: Long, outputByteArray: ByteArray, outputEnd: Int, @@ -65,7 +72,7 @@ internal class JniZstdCompressor : ZstdCompressor() { inputEnd: Int, inputStart: Int, mode: Int, - ): Long + ): LongArray private external fun close(cctxPointer: Long) } diff --git a/zstd-kmp/src/jniMain/kotlin/com/squareup/zstd/JniZstdDecompressor.kt b/zstd-kmp/src/jniMain/kotlin/com/squareup/zstd/JniZstdDecompressor.kt index 0a7bd11..1b58ed5 100644 --- a/zstd-kmp/src/jniMain/kotlin/com/squareup/zstd/JniZstdDecompressor.kt +++ b/zstd-kmp/src/jniMain/kotlin/com/squareup/zstd/JniZstdDecompressor.kt @@ -17,10 +17,12 @@ package com.squareup.zstd internal class JniZstdDecompressor : ZstdDecompressor() { @JvmField - var dctxPointer = + var dctxPointer = run { + jniLibraryLoaded createZstdDecompressor().also { if (it == 0L) throw OutOfMemoryError("createZstdDecompressor failed") } + } override fun decompressStream( outputByteArray: ByteArray, @@ -29,17 +31,23 @@ internal class JniZstdDecompressor : ZstdDecompressor() { inputByteArray: ByteArray, inputEnd: Int, inputStart: Int, - ): Long = - decompressStream( - jniPointer = jniZstdPointer, - dctxPointer = dctxPointer, - outputByteArray = outputByteArray, - outputEnd = outputEnd, - outputStart = outputStart, - inputByteArray = inputByteArray, - inputEnd = inputEnd, - inputStart = inputStart, - ) + ): Long { + val streamResult = + StreamResult( + decompressStream( + dctxPointer = dctxPointer, + outputByteArray = outputByteArray, + outputEnd = outputEnd, + outputStart = outputStart, + inputByteArray = inputByteArray, + inputEnd = inputEnd, + inputStart = inputStart, + ) + ) + inputBytesProcessed = streamResult.inputBytesProcessed + outputBytesProcessed = streamResult.outputBytesProcessed + return streamResult.result + } override fun close() { val cctxPointerToClose = dctxPointer @@ -50,7 +58,6 @@ internal class JniZstdDecompressor : ZstdDecompressor() { } private external fun decompressStream( - jniPointer: Long, dctxPointer: Long, outputByteArray: ByteArray, outputEnd: Int, @@ -58,7 +65,7 @@ internal class JniZstdDecompressor : ZstdDecompressor() { inputByteArray: ByteArray, inputEnd: Int, inputStart: Int, - ): Long + ): LongArray private external fun close(cctxPointer: Long) } diff --git a/zstd-kmp/src/jniMain/kotlin/com/squareup/zstd/Zstd.jni.kt b/zstd-kmp/src/jniMain/kotlin/com/squareup/zstd/Zstd.jni.kt index f2aed42..5d4a2d7 100644 --- a/zstd-kmp/src/jniMain/kotlin/com/squareup/zstd/Zstd.jni.kt +++ b/zstd-kmp/src/jniMain/kotlin/com/squareup/zstd/Zstd.jni.kt @@ -18,7 +18,10 @@ package com.squareup.zstd -actual fun getErrorName(code: Long): String? = jniGetErrorName(code) +actual fun getErrorName(code: Long): String? { + jniLibraryLoaded + return jniGetErrorName(code) +} actual fun zstdCompressor(): ZstdCompressor = JniZstdCompressor()