diff --git a/src/erc7730/common/binary.py b/src/erc7730/common/binary.py
index 0477ed2a..a570be02 100644
--- a/src/erc7730/common/binary.py
+++ b/src/erc7730/common/binary.py
@@ -13,43 +13,56 @@ def from_hex(value: str) -> bytes:
     return bytes.fromhex(value.removeprefix("0x"))
 
 
-def tlv(tag: int | IntEnum, *value: bytes | str | None) -> bytes:
+def tlv(tag: int | IntEnum, value: bytes | str | None = None) -> bytes:
     """
     Encode a value in TLV format (Tag-Length-Value)
 
-    If value is not encoded, it will be encoded as ASCII.
-    If input string is not ASCII, and UnicodeEncodeError is raised.
+    Tag and length are DER encoded. If tag value or length exceed 255, an OverflowError is raised.
 
-    If encoded value is longer than 255 bytes, an OverflowError is raised.
+    If value is not encoded, it will be encoded as ASCII.
+    If input string is not ASCII, a UnicodeEncodeError is raised.
 
     @param tag: the tag (can be an enum)
-    @param value: the value (can be already encoded, or a string)
+    @param value: the value (can be already encoded, a string or None)
     @return: encoded TLV
     """
-    values_encoded = bytearray()
-    for v in value:
-        if v is not None:
-            values_encoded.extend(v.encode("ascii", errors="strict") if isinstance(v, str) else v)
-    return (
-        (tag.value if isinstance(tag, IntEnum) else tag).to_bytes(1, "big")
-        + len(values_encoded).to_bytes(1, "big")
-        + values_encoded
-    )
+    return der_encode_int(tag.value if isinstance(tag, IntEnum) else tag) + length_value(value)
 
 
-def length_value(value: bytes | str | None) -> bytes:
+
+def length_value(
+    value: bytes | str | None,
+) -> bytes:
     """
-    Prepend the length of the value encoded on 1 byte to the value itself.
+    Prepend the length (DER encoded) of the value encoded to the value itself.
+    If length exceeds 255 bytes, an OverflowError is raised.
 
     If value is not encoded, it will be encoded as ASCII.
-    If input string is not ASCII, and UnicodeEncodeError is raised.
-
-    If encoded value is longer than 255 bytes, an OverflowError is raised.
+    If input string is not ASCII, a UnicodeEncodeError is raised.
+    If value is neither a string, bytes-like nor None, a TypeError is raised.
 
     @param value: the value (can be already encoded, or a string)
-    @return: encoded TLV
+    @return: length-prefixed value
     """
     if value is None:
         return (0).to_bytes(1, "big")
-    value_encoded = value.encode("ascii", errors="strict") if isinstance(value, str) else value
-    return len(value_encoded).to_bytes(1, "big") + value_encoded
+    match value:
+        case str():
+            value_encoded = value.encode("ascii", errors="strict")
+        case bytes() | bytearray() | memoryview():
+            value_encoded = bytes(value)
+        case _:
+            raise TypeError(f"expected bytes-like, str or None, got {type(value).__name__}")
+    return der_encode_int(len(value_encoded)) + value_encoded
+
+
+def der_encode_int(value: int) -> bytes:
+    """
+    Encode an integer in DER format.
+    If value exceeds 255, an OverflowError is raised.
+
+    @param value: the integer to encode
+    @return: DER encoded byte array
+    """
+    value_bytes = value.to_bytes(1, "big")  # raises OverflowError if value >= 256
+    return (0x81).to_bytes(1, "big") + value_bytes if value >= 0x80 else value_bytes
diff --git a/tests/common/test_binary.py b/tests/common/test_binary.py
new file mode 100644
index 00000000..35843e17
--- /dev/null
+++ b/tests/common/test_binary.py
@@ -0,0 +1,39 @@
+from enum import IntEnum
+
+import pytest
+
+from erc7730.common.binary import tlv
+
+
+class _Tag(IntEnum):
+    FIELD = 5
+
+
+@pytest.mark.parametrize(
+    "tag, value, expected",
+    [
+        (1, None, b"\x01\x00"),
+        (1, b"\xab", b"\x01\x01\xab"),
+        (1, bytearray(b"\xab"), b"\x01\x01\xab"),
+        (1, "hi", b"\x01\x02hi"),
+        (_Tag.FIELD, b"\xff", b"\x05\x01\xff"),
+        (0x80, None, b"\x81\x80\x00"),
+        (1, b"\x00" * 128, b"\x01\x81\x80" + b"\x00" * 128),
+    ],
+)
+def test_tlv(tag: int | IntEnum, value: bytes | str | None, expected: bytes) -> None:
+    assert tlv(tag, value) == expected
+
+
+@pytest.mark.parametrize(
+    "tag, value, exc",
+    [
+        (256, None, OverflowError),
+        (1, b"\x00" * 256, OverflowError),
+        (1, "là-haut", UnicodeEncodeError),
+        (1, 42, TypeError),
+    ],
+)
+def test_tlv_errors(tag: int, value: bytes | str | None, exc: type[Exception]) -> None:
+    with pytest.raises(exc):
+        tlv(tag, value)