Skip to content

Commit c96a405

Browse files
committed
refactor: jvm platform input stream wrapper and testcases
1 parent 3b19f78 commit c96a405

File tree

4 files changed

+75
-16
lines changed

4 files changed

+75
-16
lines changed

src/commonMain/kotlin/space/iseki/bencoding/BencodingSerializationException.kt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,13 @@ package space.iseki.bencoding
22

33
import kotlinx.serialization.SerializationException
44

5-
class BencodingSerializationException(
5+
open class BencodingSerializationException(
66
override val message: String = "",
77
override val cause: Throwable? = null,
88
) : SerializationException()
9+
10+
class BencodingDecodeException(
11+
val reason: String,
12+
val position: Long,
13+
cause: Throwable? = null,
14+
) : BencodingSerializationException("decode failed at $position, $reason", cause)

src/commonMain/kotlin/space/iseki/bencoding/IO.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ internal enum class Symbol {
1010
}
1111

1212
internal interface I {
13-
val pos: Int
13+
val pos: Long
1414
fun lookahead(): Symbol
1515
fun readText(): ByteArray
1616
fun readNumber(): Long
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package space.iseki.bencoding
2+
3+
import kotlin.test.Test
4+
import kotlin.test.assertTrue
5+
6+
class BencodingDecodeExceptionTest{
7+
8+
@Test
9+
fun test(){
10+
val th = checkNotNull(runCatching { throw BencodingDecodeException("test", 10) }.exceptionOrNull())
11+
assertTrue { th is BencodingDecodeException && th.position == 10L && th.reason == "test" }
12+
}
13+
}

src/jvmMain/kotlin/space/iseki/bencoding/InputStreamInput.kt

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,17 @@ package space.iseki.bencoding
22

33
import kotlinx.serialization.KSerializer
44
import kotlinx.serialization.serializer
5+
import java.io.EOFException
56
import java.io.InputStream
67

78
private const val UNINITIALIZED = -2
89
private const val EOF = -1
910

10-
internal class InputStreamI(private val inputStream: InputStream) : I {
11-
override var pos: Int = 0
12-
private set
11+
12+
internal class InputStreamI(inputStream: InputStream) : I {
13+
private val inputStream = CounteredInputStream(inputStream)
14+
override val pos: Long
15+
get() = inputStream.pos
1316

1417
override fun lookahead(): Symbol = when (la()) {
1518
'l'.code -> Symbol.List
@@ -21,16 +24,28 @@ internal class InputStreamI(private val inputStream: InputStream) : I {
2124
else -> unrecognizedInput()
2225
}
2326

24-
override fun readText(): ByteArray = readLength().let { n -> inputStream.readNBytes(n).also { pos += n } }
27+
override fun readText(): ByteArray {
28+
val len = readLength()
29+
if (len == 0) return ByteArray(0)
30+
val buffer = ByteArray(len)
31+
var p = 0
32+
while (p < buffer.size) {
33+
p = inputStream.read(buffer, p, buffer.size - p)
34+
if (p == -1) {
35+
unexpectedEOF("read $len bytes as text")
36+
}
37+
}
38+
return buffer
39+
}
2540

2641
override fun readNumber(): Long {
2742
var n = 0L
2843
var factor = 1
29-
if (read() != 'i'.code) unrecognizedInput("readNumber")
44+
if (read() != 'i'.code) unrecognizedInput("read number")
3045
while (true) {
3146
when (val i = read()) {
3247
'-'.code -> {
33-
if (n != 0L || factor != 1) unrecognizedInput("readNumber")
48+
if (n != 0L || factor != 1) unrecognizedInput("read number")
3449
factor = -1
3550
}
3651

@@ -39,7 +54,8 @@ internal class InputStreamI(private val inputStream: InputStream) : I {
3954
}
4055

4156
'e'.code -> break
42-
else -> unrecognizedInput("readNumber")
57+
EOF -> unexpectedEOF("read number")
58+
else -> unrecognizedInput("read number")
4359
}
4460
}
4561
return n
@@ -49,20 +65,28 @@ internal class InputStreamI(private val inputStream: InputStream) : I {
4965
when (lookahead()) {
5066
Symbol.EOF -> return
5167
Symbol.Dict, Symbol.List, Symbol.End -> read()
52-
Symbol.Text -> readLength().let { n -> inputStream.skipNBytes(n.toLong()); pos += n }
68+
Symbol.Text -> {
69+
val len = readLength()
70+
try {
71+
inputStream.skipNBytes(len.toLong())
72+
} catch (ex: EOFException) {
73+
unexpectedEOF("skip $len bytes")
74+
}
75+
}
76+
5377
Symbol.Integer -> readNumber()
5478
}
5579
}
5680

5781
private var _la = UNINITIALIZED
5882

5983
private fun la() = when (_la) {
60-
UNINITIALIZED -> inputStream.read().also { _la = it;pos++ }
84+
UNINITIALIZED -> inputStream.read().also { _la = it }
6185
else -> _la
6286
}
6387

6488
private fun read() = when (_la) {
65-
UNINITIALIZED -> inputStream.read().also { pos++ }
89+
UNINITIALIZED -> inputStream.read()
6690
else -> _la.also { _la = UNINITIALIZED }
6791
}
6892

@@ -72,21 +96,37 @@ internal class InputStreamI(private val inputStream: InputStream) : I {
7296
when (val i = read()) {
7397
in '0'.code..'9'.code -> l = l * 10 + (i - '0'.code)
7498
':'.code -> break
75-
else -> unrecognizedInput("readLength")
99+
EOF -> unexpectedEOF("read length")
100+
else -> unrecognizedInput("read length")
76101
}
77102
if (l < 0) error("length overflow")
78103
}
104+
79105
return l
80106
}
81107

108+
private fun unexpectedEOF(during: String = ""): Nothing = when {
109+
during.isEmpty() -> "unexpected EOF"
110+
else -> "unexpected EOF during $during"
111+
}.let { error(it) }
112+
82113
private fun unrecognizedInput(during: String = ""): Nothing = when {
83114
during.isEmpty() -> "unrecognized input"
84115
else -> "unrecognized input, during $during"
85-
}.let { throw BencodingSerializationException(it) }
116+
}.let { throw BencodingDecodeException(it, pos) }
86117

87-
@Suppress("SameParameterValue")
88-
private fun error(msg: String): Nothing = throw BencodingSerializationException(msg)
118+
private fun error(reason: String, cause: Throwable? = null): Nothing =
119+
throw BencodingDecodeException(reason, pos, cause)
89120

121+
private class CounteredInputStream(private val inputStream: InputStream) : InputStream() {
122+
var pos = 0L
123+
private set
124+
125+
override fun read(): Int = inputStream.read().also { if (it > 0) pos += it }
126+
override fun skip(n: Long): Long = inputStream.skip(n).also { if (it > 0) pos += it }
127+
override fun read(b: ByteArray, off: Int, len: Int): Int =
128+
inputStream.read(b, off, len).also { if (it > 0) pos += it }
129+
}
90130
}
91131

92132
inline fun <reified T> InputStream.decodeInBencoding() = decodeInBencoding(serializer<T>())

0 commit comments

Comments
 (0)