Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions zstd-kmp/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ android {

testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"

consumerProguardFiles("src/androidMain/consumer-rules.pro")

ndk {
abiFilters += listOf("x86", "x86_64", "armeabi-v7a", "arm64-v8a")
}
Expand Down
61 changes: 18 additions & 43 deletions zstd-kmp/native/ZstdKmp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,6 @@
#include <jni.h>
#include <zstd.h>

/**
* Support for operating on JVM objects from native code.
*
* Pass a pointer to this to all JNI functions that operate on JVM objects.
*/
class JniZstd {
public:
JniZstd(JNIEnv *env, jclass zstdCompressorClass, jclass zstdDecompressorClass);

jfieldID zstdCompressorOutputBytesProcessed;
jfieldID zstdCompressorInputBytesProcessed;
jfieldID zstdDecompressorOutputBytesProcessed;
jfieldID zstdDecompressorInputBytesProcessed;
};

JniZstd::JniZstd(JNIEnv *env, jclass zstdCompressorClass, jclass zstdDecompressorClass)
: zstdCompressorOutputBytesProcessed(env->GetFieldID(zstdCompressorClass, "outputBytesProcessed", "I")),
zstdCompressorInputBytesProcessed(env->GetFieldID(zstdCompressorClass, "inputBytesProcessed", "I")),
zstdDecompressorOutputBytesProcessed(env->GetFieldID(zstdDecompressorClass, "outputBytesProcessed", "I")),
zstdDecompressorInputBytesProcessed(env->GetFieldID(zstdDecompressorClass, "inputBytesProcessed", "I")) {
}

extern "C" JNIEXPORT jstring JNICALL
Java_com_squareup_zstd_JniZstdKt_jniGetErrorName(JNIEnv* env, jobject type, jlong code) {
auto codeSizeT = static_cast<size_t>(code);
Expand All @@ -46,13 +24,6 @@ Java_com_squareup_zstd_JniZstdKt_jniGetErrorName(JNIEnv* env, jobject type, jlon
return env->NewStringUTF(errorString);
}

extern "C" JNIEXPORT jlong JNICALL
Java_com_squareup_zstd_JniZstdKt_createJniZstd(JNIEnv* env, jclass type) {
auto zstdCompressorClass = env->FindClass("com/squareup/zstd/ZstdCompressor");
auto zstdDecompressorClass = env->FindClass("com/squareup/zstd/ZstdDecompressor");
auto jniZstd = new JniZstd(env, zstdCompressorClass, zstdDecompressorClass);
return reinterpret_cast<jlong>(jniZstd);
}

extern "C" JNIEXPORT jlong JNICALL
Java_com_squareup_zstd_JniZstdKt_createZstdCompressor(JNIEnv* env, jclass type) {
Expand All @@ -66,9 +37,8 @@ Java_com_squareup_zstd_JniZstdCompressor_setParameter(JNIEnv* env, jobject type,
return ZSTD_CCtx_setParameter(cctx, static_cast<ZSTD_cParameter>(param), value);
}

extern "C" JNIEXPORT jlong JNICALL
Java_com_squareup_zstd_JniZstdCompressor_compressStream2(JNIEnv* env, jobject type, jlong jniZstdPointer, jlong cctxPointer, jbyteArray outputByteArray, jint outputEnd, jint outputStart, jbyteArray inputByteArray, jint inputEnd, jint inputStart, jint mode) {
auto jniZstd = reinterpret_cast<JniZstd*>(jniZstdPointer);
extern "C" JNIEXPORT jlongArray JNICALL
Java_com_squareup_zstd_JniZstdCompressor_compressStream2(JNIEnv* env, jobject type, jlong cctxPointer, jbyteArray outputByteArray, jint outputEnd, jint outputStart, jbyteArray inputByteArray, jint inputEnd, jint inputStart, jint mode) {
auto cctx = reinterpret_cast<ZSTD_CCtx*>(cctxPointer);

auto inputByteArrayElements = env->GetByteArrayElements(inputByteArray, NULL);
Expand All @@ -85,13 +55,16 @@ Java_com_squareup_zstd_JniZstdCompressor_compressStream2(JNIEnv* env, jobject ty
result = -ZSTD_error_GENERIC;
}

env->SetIntField(type, jniZstd->zstdCompressorOutputBytesProcessed, static_cast<jint>(zstdOutput.pos) - outputStart);
env->SetIntField(type, jniZstd->zstdCompressorInputBytesProcessed, static_cast<jint>(zstdInput.pos) - inputStart);

env->ReleaseByteArrayElements(inputByteArray, inputByteArrayElements, JNI_ABORT);
env->ReleaseByteArrayElements(outputByteArray, outputByteArrayElements, 0);

return result;
jlong results[3];
results[0] = static_cast<jlong>(result);
results[1] = static_cast<jlong>(zstdInput.pos) - inputStart;
results[2] = static_cast<jlong>(zstdOutput.pos) - outputStart;
jlongArray resultArray = env->NewLongArray(3);
env->SetLongArrayRegion(resultArray, 0, 3, results);
return resultArray;
}

extern "C" JNIEXPORT void JNICALL
Expand All @@ -106,9 +79,8 @@ Java_com_squareup_zstd_JniZstdKt_createZstdDecompressor(JNIEnv* env, jclass type
return reinterpret_cast<jlong>(dctx);
}

extern "C" JNIEXPORT jlong JNICALL
Java_com_squareup_zstd_JniZstdDecompressor_decompressStream(JNIEnv* env, jobject type, jlong jniZstdPointer, jlong dctxPointer, jbyteArray outputByteArray, jint outputEnd, jint outputStart, jbyteArray inputByteArray, jint inputEnd, jint inputStart) {
auto jniZstd = reinterpret_cast<JniZstd*>(jniZstdPointer);
extern "C" JNIEXPORT jlongArray JNICALL
Java_com_squareup_zstd_JniZstdDecompressor_decompressStream(JNIEnv* env, jobject type, jlong dctxPointer, jbyteArray outputByteArray, jint outputEnd, jint outputStart, jbyteArray inputByteArray, jint inputEnd, jint inputStart) {
auto dctx = reinterpret_cast<ZSTD_DCtx*>(dctxPointer);

auto inputByteArrayElements = env->GetByteArrayElements(inputByteArray, NULL);
Expand All @@ -124,13 +96,16 @@ Java_com_squareup_zstd_JniZstdDecompressor_decompressStream(JNIEnv* env, jobject
result = -ZSTD_error_GENERIC;
}

env->SetIntField(type, jniZstd->zstdDecompressorOutputBytesProcessed, static_cast<jint>(zstdOutput.pos) - outputStart);
env->SetIntField(type, jniZstd->zstdDecompressorInputBytesProcessed, static_cast<jint>(zstdInput.pos) - inputStart);

env->ReleaseByteArrayElements(inputByteArray, inputByteArrayElements, JNI_ABORT);
env->ReleaseByteArrayElements(outputByteArray, outputByteArrayElements, 0);

return result;
jlong results[3];
results[0] = static_cast<jlong>(result);
results[1] = static_cast<jlong>(zstdInput.pos) - inputStart;
results[2] = static_cast<jlong>(zstdOutput.pos) - outputStart;
jlongArray resultArray = env->NewLongArray(3);
env->SetLongArrayRegion(resultArray, 0, 3, results);
return resultArray;
}

extern "C" JNIEXPORT void JNICALL
Expand Down
7 changes: 7 additions & 0 deletions zstd-kmp/src/androidMain/consumer-rules.pro
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Keep classes with native methods so R8 doesn't rename them.
# JNI function names in the native library are derived from the fully-qualified
# class and method names, so renaming would break the JNI linkage.
-keepclasseswithmembers class com.squareup.zstd.** {
native <methods>;
}

23 changes: 18 additions & 5 deletions zstd-kmp/src/jniMain/kotlin/com/squareup/zstd/JniZstd.kt
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,25 @@
*/
package com.squareup.zstd

internal val jniZstdPointer: Long = run {
loadNativeLibrary()
createJniZstd()
}
internal val jniLibraryLoaded: Unit by lazy { loadNativeLibrary() }

/**
* Result of a native stream operation. The native function returns a [LongArray] with three
* elements; this class provides named access to each.
*/
internal class StreamResult(private val values: LongArray) {
/** The zstd result code. */
val result: Long
get() = values[0]

@JvmName("createJniZstd") internal external fun createJniZstd(): Long
/** The number of input bytes consumed. */
val inputBytesProcessed: Int
get() = values[1].toInt()

/** The number of output bytes produced. */
val outputBytesProcessed: Int
get() = values[2].toInt()
}

@JvmName("jniGetErrorName") internal external fun jniGetErrorName(code: Long): String?

Expand Down
37 changes: 22 additions & 15 deletions zstd-kmp/src/jniMain/kotlin/com/squareup/zstd/JniZstdCompressor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ package com.squareup.zstd

internal class JniZstdCompressor : ZstdCompressor() {
@JvmField
var cctxPointer =
var cctxPointer = run {
jniLibraryLoaded
createZstdCompressor().also {
if (it == 0L) throw OutOfMemoryError("createZstdCompressor failed")
}
}

override fun setParameter(param: Int, value: Int): Long = setParameter(cctxPointer, param, value)

Expand All @@ -32,18 +34,24 @@ internal class JniZstdCompressor : ZstdCompressor() {
inputEnd: Int,
inputStart: Int,
mode: Int,
): Long =
compressStream2(
jniPointer = jniZstdPointer,
cctxPointer = cctxPointer,
outputByteArray = outputByteArray,
outputEnd = outputEnd,
outputStart = outputStart,
inputByteArray = inputByteArray,
inputEnd = inputEnd,
inputStart = inputStart,
mode = mode,
)
): Long {
val streamResult =
StreamResult(
compressStream2(
cctxPointer = cctxPointer,
outputByteArray = outputByteArray,
outputEnd = outputEnd,
outputStart = outputStart,
inputByteArray = inputByteArray,
inputEnd = inputEnd,
inputStart = inputStart,
mode = mode,
)
)
inputBytesProcessed = streamResult.inputBytesProcessed
outputBytesProcessed = streamResult.outputBytesProcessed
return streamResult.result
}

override fun close() {
val cctxPointerToClose = cctxPointer
Expand All @@ -56,7 +64,6 @@ internal class JniZstdCompressor : ZstdCompressor() {
private external fun setParameter(cctxPointer: Long, param: Int, value: Int): Long

private external fun compressStream2(
jniPointer: Long,
cctxPointer: Long,
outputByteArray: ByteArray,
outputEnd: Int,
Expand All @@ -65,7 +72,7 @@ internal class JniZstdCompressor : ZstdCompressor() {
inputEnd: Int,
inputStart: Int,
mode: Int,
): Long
): LongArray

private external fun close(cctxPointer: Long)
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ package com.squareup.zstd

internal class JniZstdDecompressor : ZstdDecompressor() {
@JvmField
var dctxPointer =
var dctxPointer = run {
jniLibraryLoaded
createZstdDecompressor().also {
if (it == 0L) throw OutOfMemoryError("createZstdDecompressor failed")
}
}

override fun decompressStream(
outputByteArray: ByteArray,
Expand All @@ -29,17 +31,23 @@ internal class JniZstdDecompressor : ZstdDecompressor() {
inputByteArray: ByteArray,
inputEnd: Int,
inputStart: Int,
): Long =
decompressStream(
jniPointer = jniZstdPointer,
dctxPointer = dctxPointer,
outputByteArray = outputByteArray,
outputEnd = outputEnd,
outputStart = outputStart,
inputByteArray = inputByteArray,
inputEnd = inputEnd,
inputStart = inputStart,
)
): Long {
val streamResult =
StreamResult(
decompressStream(
dctxPointer = dctxPointer,
outputByteArray = outputByteArray,
outputEnd = outputEnd,
outputStart = outputStart,
inputByteArray = inputByteArray,
inputEnd = inputEnd,
inputStart = inputStart,
)
)
inputBytesProcessed = streamResult.inputBytesProcessed
outputBytesProcessed = streamResult.outputBytesProcessed
return streamResult.result
}

override fun close() {
val cctxPointerToClose = dctxPointer
Expand All @@ -50,15 +58,14 @@ internal class JniZstdDecompressor : ZstdDecompressor() {
}

private external fun decompressStream(
jniPointer: Long,
dctxPointer: Long,
outputByteArray: ByteArray,
outputEnd: Int,
outputStart: Int,
inputByteArray: ByteArray,
inputEnd: Int,
inputStart: Int,
): Long
): LongArray

private external fun close(cctxPointer: Long)
}
5 changes: 4 additions & 1 deletion zstd-kmp/src/jniMain/kotlin/com/squareup/zstd/Zstd.jni.kt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@

package com.squareup.zstd

actual fun getErrorName(code: Long): String? = jniGetErrorName(code)
actual fun getErrorName(code: Long): String? {
jniLibraryLoaded
return jniGetErrorName(code)
}

actual fun zstdCompressor(): ZstdCompressor = JniZstdCompressor()

Expand Down
Loading