Skip to content

Commit 29c1875

Browse files
authored
Fix 4bit tensor unpacking (#118)
4bit tensor unpacking to numpy array was buggy before this fix. I updated the logic to make sure we correctly handle the bytes when converting to numpy. Added unit tests for all numeric dtypes. --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent c7879f7 commit 29c1875

File tree

5 files changed

+154
-124
lines changed

5 files changed

+154
-124
lines changed

src/onnx_ir/_core.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -657,15 +657,13 @@ def _load(self):
657657
self._array = np.empty(self.shape.numpy(), dtype=self.dtype.numpy())
658658
return
659659
# Map the whole file into the memory
660-
# TODO(justinchuby): Verify if this would exhaust the memory address space
661660
with open(self.path, "rb") as f:
662661
self.raw = mmap.mmap(
663662
f.fileno(),
664663
0,
665664
access=mmap.ACCESS_READ,
666665
)
667-
# Handle the byte order correctly by always using little endian
668-
dt = np.dtype(self.dtype.numpy()).newbyteorder("<")
666+
669667
if self.dtype in {
670668
_enums.DataType.INT4,
671669
_enums.DataType.UINT4,
@@ -675,16 +673,18 @@ def _load(self):
675673
dt = np.dtype(np.uint8).newbyteorder("<")
676674
count = self.size // 2 + self.size % 2
677675
else:
676+
# Handle the byte order correctly by always using little endian
677+
dt = np.dtype(self.dtype.numpy()).newbyteorder("<")
678678
count = self.size
679+
679680
self._array = np.frombuffer(self.raw, dtype=dt, offset=self.offset or 0, count=count)
680681
shape = self.shape.numpy()
681-
if self.dtype == _enums.DataType.INT4:
682-
# Unpack the int4 arrays
683-
self._array = _type_casting.unpack_int4(self._array, shape)
684-
elif self.dtype == _enums.DataType.UINT4:
685-
self._array = _type_casting.unpack_uint4(self._array, shape)
686-
elif self.dtype == _enums.DataType.FLOAT4E2M1:
687-
self._array = _type_casting.unpack_float4e2m1(self._array, shape)
682+
683+
if self.dtype.bitwidth == 4:
684+
# Unpack the 4bit arrays
685+
self._array = _type_casting.unpack_4bitx2(self._array, shape).view(
686+
self.dtype.numpy()
687+
)
688688
else:
689689
self._array = self._array.reshape(shape)
690690

@@ -1071,15 +1071,7 @@ def numpy(self) -> np.ndarray:
10711071
"""
10721072
array = self.numpy_packed()
10731073
# ONNX IR returns the unpacked arrays
1074-
if self.dtype == _enums.DataType.INT4:
1075-
return _type_casting.unpack_int4(array, self.shape.numpy())
1076-
if self.dtype == _enums.DataType.UINT4:
1077-
return _type_casting.unpack_uint4(array, self.shape.numpy())
1078-
if self.dtype == _enums.DataType.FLOAT4E2M1:
1079-
return _type_casting.unpack_float4e2m1(array, self.shape.numpy())
1080-
raise TypeError(
1081-
f"PackedTensor only supports INT4, UINT4, FLOAT4E2M1, but got {self.dtype}"
1082-
)
1074+
return _type_casting.unpack_4bitx2(array, self.shape.numpy()).view(self.dtype.numpy())
10831075

10841076
def numpy_packed(self) -> npt.NDArray[np.uint8]:
10851077
"""Return the tensor as a packed array."""

src/onnx_ir/_core_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2082,7 +2082,7 @@ def test_initialize_with_torch_tensor(self, _: str, dtype: ir.DataType):
20822082
)
20832083
np.testing.assert_array_equal(
20842084
tensor.numpy(),
2085-
_type_casting._unpack_uint4_as_uint8(
2085+
_type_casting.unpack_4bitx2(
20862086
packed_data.numpy(force=True).view(np.uint8), dims=[2, 4]
20872087
).view(dtype.numpy()),
20882088
)

src/onnx_ir/_type_casting.py

Lines changed: 1 addition & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
# Copyright (c) ONNX Project Contributors
22
# SPDX-License-Identifier: Apache-2.0
33
"""Numpy utilities for non-native type operation."""
4-
# TODO(justinchuby): Upstream the logic to onnx
54

65
from __future__ import annotations
76

87
import typing
98
from collections.abc import Sequence
109

11-
import ml_dtypes
1210
import numpy as np
1311

1412
if typing.TYPE_CHECKING:
@@ -28,9 +26,7 @@ def pack_4bitx2(array: np.ndarray) -> npt.NDArray[np.uint8]:
2826
return array_flat[0::2] | array_flat[1::2] # type: ignore[return-type]
2927

3028

31-
def _unpack_uint4_as_uint8(
32-
data: npt.NDArray[np.uint8], dims: Sequence[int]
33-
) -> npt.NDArray[np.uint8]:
29+
def unpack_4bitx2(data: npt.NDArray[np.uint8], dims: Sequence[int]) -> npt.NDArray[np.uint8]:
3430
"""Convert a packed uint4 array to unpacked uint4 array represented as uint8.
3531
3632
Args:
@@ -52,56 +48,3 @@ def _unpack_uint4_as_uint8(
5248
result = result[:-1]
5349
result.resize(dims, refcheck=False)
5450
return result
55-
56-
57-
def unpack_uint4(
58-
data: npt.NDArray[np.uint8], dims: Sequence[int]
59-
) -> npt.NDArray[ml_dtypes.uint4]:
60-
"""Convert a packed uint4 array to unpacked uint4 array represented as uint8.
61-
62-
Args:
63-
data: A numpy array.
64-
dims: The dimensions are used to reshape the unpacked buffer.
65-
66-
Returns:
67-
A numpy array of int8/uint8 reshaped to dims.
68-
"""
69-
return _unpack_uint4_as_uint8(data, dims).view(ml_dtypes.uint4)
70-
71-
72-
def _extend_int4_sign_bits(x: npt.NDArray[np.uint8]) -> npt.NDArray[np.int8]:
73-
"""Extend 4-bit signed integer to 8-bit signed integer."""
74-
return np.where((x >> 3) == 0, x, x | 0xF0).astype(np.int8)
75-
76-
77-
def unpack_int4(
78-
data: npt.NDArray[np.uint8], dims: Sequence[int]
79-
) -> npt.NDArray[ml_dtypes.int4]:
80-
"""Convert a packed (signed) int4 array to unpacked int4 array represented as int8.
81-
82-
The sign bit is extended to the most significant bit of the int8.
83-
84-
Args:
85-
data: A numpy array.
86-
dims: The dimensions are used to reshape the unpacked buffer.
87-
88-
Returns:
89-
A numpy array of int8 reshaped to dims.
90-
"""
91-
unpacked = _unpack_uint4_as_uint8(data, dims)
92-
return _extend_int4_sign_bits(unpacked).view(ml_dtypes.int4)
93-
94-
95-
def unpack_float4e2m1(
96-
data: npt.NDArray[np.uint8], dims: Sequence[int]
97-
) -> npt.NDArray[ml_dtypes.float4_e2m1fn]:
98-
"""Convert a packed float4e2m1 array to unpacked float4e2m1 array.
99-
100-
Args:
101-
data: A numpy array.
102-
dims: The dimensions are used to reshape the unpacked buffer.
103-
104-
Returns:
105-
A numpy array of float32 reshaped to dims.
106-
"""
107-
return _unpack_uint4_as_uint8(data, dims).view(ml_dtypes.float4_e2m1fn)

src/onnx_ir/serde.py

Lines changed: 72 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@
7474

7575
if typing.TYPE_CHECKING:
7676
import google.protobuf.internal.containers as proto_containers
77-
import numpy.typing as npt
7877

7978
logger = logging.getLogger(__name__)
8079

@@ -117,13 +116,6 @@ def _little_endian_dtype(dtype) -> np.dtype:
117116
return np.dtype(dtype).newbyteorder("<")
118117

119118

120-
def _unflatten_complex(
121-
array: npt.NDArray[np.float32 | np.float64],
122-
) -> npt.NDArray[np.complex64 | np.complex128]:
123-
"""Convert the real representation of a complex dtype to the complex dtype."""
124-
return array[::2] + 1j * array[1::2]
125-
126-
127119
@typing.overload
128120
def from_proto(proto: onnx.ModelProto) -> _core.Model: ... # type: ignore[overload-overlap]
129121
@typing.overload
@@ -391,54 +383,88 @@ def numpy(self) -> np.ndarray:
391383
"Cannot convert external tensor to numpy array. Use ir.ExternalTensor instead."
392384
)
393385

386+
shape = self._proto.dims
387+
394388
if self._proto.HasField("raw_data"):
395-
array = np.frombuffer(self._proto.raw_data, dtype=dtype.numpy().newbyteorder("<"))
396-
# Cannot return now, because we may need to unpack 4bit tensors
397-
elif dtype == _enums.DataType.STRING:
398-
return np.array(self._proto.string_data).reshape(self._proto.dims)
399-
elif self._proto.int32_data:
400-
array = np.array(self._proto.int32_data, dtype=_little_endian_dtype(np.int32))
401-
if dtype in {_enums.DataType.FLOAT16, _enums.DataType.BFLOAT16}:
402-
# Reinterpret the int32 as float16 or bfloat16
403-
array = array.astype(np.uint16).view(dtype.numpy())
404-
elif dtype in {
389+
if dtype.bitwidth == 4:
390+
return _type_casting.unpack_4bitx2(
391+
np.frombuffer(self._proto.raw_data, dtype=np.uint8), shape
392+
).view(dtype.numpy())
393+
return np.frombuffer(
394+
self._proto.raw_data, dtype=dtype.numpy().newbyteorder("<")
395+
).reshape(shape)
396+
if dtype == _enums.DataType.STRING:
397+
return np.array(self._proto.string_data).reshape(shape)
398+
if self._proto.int32_data:
399+
assert dtype in {
400+
_enums.DataType.BFLOAT16,
401+
_enums.DataType.BOOL,
402+
_enums.DataType.FLOAT16,
403+
_enums.DataType.FLOAT4E2M1,
405404
_enums.DataType.FLOAT8E4M3FN,
406405
_enums.DataType.FLOAT8E4M3FNUZ,
407406
_enums.DataType.FLOAT8E5M2,
408407
_enums.DataType.FLOAT8E5M2FNUZ,
409-
}:
410-
array = array.astype(np.uint8).view(dtype.numpy())
411-
elif self._proto.int64_data:
412-
array = np.array(self._proto.int64_data, dtype=_little_endian_dtype(np.int64))
413-
elif self._proto.uint64_data:
408+
_enums.DataType.INT16,
409+
_enums.DataType.INT32,
410+
_enums.DataType.INT4,
411+
_enums.DataType.INT8,
412+
_enums.DataType.UINT16,
413+
_enums.DataType.UINT4,
414+
_enums.DataType.UINT8,
415+
}, f"Unsupported dtype {dtype} for int32_data"
416+
array = np.array(self._proto.int32_data, dtype=_little_endian_dtype(np.int32))
417+
if dtype.bitwidth == 32:
418+
return array.reshape(shape)
419+
if dtype.bitwidth == 16:
420+
# Reinterpret the int32 as float16 or bfloat16
421+
return array.astype(np.uint16).view(dtype.numpy()).reshape(shape)
422+
if dtype.bitwidth == 8:
423+
return array.astype(np.uint8).view(dtype.numpy()).reshape(shape)
424+
if dtype.bitwidth == 4:
425+
return _type_casting.unpack_4bitx2(array.astype(np.uint8), shape).view(
426+
dtype.numpy()
427+
)
428+
raise ValueError(
429+
f"Unsupported dtype {dtype} for int32_data with bitwidth {dtype.bitwidth}"
430+
)
431+
if self._proto.int64_data:
432+
assert dtype in {
433+
_enums.DataType.INT64,
434+
}, f"Unsupported dtype {dtype} for int64_data"
435+
return np.array(
436+
self._proto.int64_data, dtype=_little_endian_dtype(np.int64)
437+
).reshape(shape)
438+
if self._proto.uint64_data:
439+
assert dtype in {
440+
_enums.DataType.UINT64,
441+
_enums.DataType.UINT32,
442+
}, f"Unsupported dtype {dtype} for uint64_data"
414443
array = np.array(self._proto.uint64_data, dtype=_little_endian_dtype(np.uint64))
415-
elif self._proto.float_data:
444+
if dtype == _enums.DataType.UINT32:
445+
return array.astype(np.uint32).reshape(shape)
446+
return array.reshape(shape)
447+
if self._proto.float_data:
448+
assert dtype in {
449+
_enums.DataType.FLOAT,
450+
_enums.DataType.COMPLEX64,
451+
}, f"Unsupported dtype {dtype} for float_data"
416452
array = np.array(self._proto.float_data, dtype=_little_endian_dtype(np.float32))
417453
if dtype == _enums.DataType.COMPLEX64:
418-
array = _unflatten_complex(array)
419-
elif self._proto.double_data:
454+
return array.view(np.complex64).reshape(shape)
455+
return array.reshape(shape)
456+
if self._proto.double_data:
457+
assert dtype in {
458+
_enums.DataType.DOUBLE,
459+
_enums.DataType.COMPLEX128,
460+
}, f"Unsupported dtype {dtype} for double_data"
420461
array = np.array(self._proto.double_data, dtype=_little_endian_dtype(np.float64))
421462
if dtype == _enums.DataType.COMPLEX128:
422-
array = _unflatten_complex(array)
423-
else:
424-
# Empty tensor
425-
if not self._proto.dims:
426-
# When dims not precent and there is no data, we return an empty array
427-
return np.array([], dtype=dtype.numpy())
428-
else:
429-
# Otherwise we return a size 0 array with the correct shape
430-
return np.zeros(self._proto.dims, dtype=dtype.numpy())
431-
432-
if dtype == _enums.DataType.INT4:
433-
return _type_casting.unpack_int4(array.astype(np.uint8), self._proto.dims)
434-
elif dtype == _enums.DataType.UINT4:
435-
return _type_casting.unpack_uint4(array.astype(np.uint8), self._proto.dims)
436-
elif dtype == _enums.DataType.FLOAT4E2M1:
437-
return _type_casting.unpack_float4e2m1(array.astype(np.uint8), self._proto.dims)
438-
else:
439-
# Otherwise convert to the correct dtype and reshape
440-
# Note we cannot use view() here because the storage dtype may not be the same size as the target
441-
return array.astype(dtype.numpy()).reshape(self._proto.dims)
463+
return array.view(np.complex128).reshape(shape)
464+
return array.reshape(shape)
465+
466+
# Empty tensor. We return a size 0 array with the correct shape
467+
return np.zeros(shape, dtype=dtype.numpy())
442468

443469
def tobytes(self) -> bytes:
444470
"""Return the tensor as a byte string conformed to the ONNX specification, in little endian.

src/onnx_ir/serde_test.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) ONNX Project Contributors
22
# SPDX-License-Identifier: Apache-2.0
3+
import itertools
34
import unittest
45

56
import google.protobuf.text_format
@@ -346,6 +347,74 @@ def test_tensor_proto_tensor_empty_tensor(self):
346347
# Test dlpack
347348
np.testing.assert_array_equal(np.from_dlpack(tensor), tensor.numpy())
348349

350+
@parameterized.parameterized.expand(
351+
[
352+
(name, dtype, array)
353+
for (name, dtype), array in itertools.product(
354+
[
355+
("FLOAT", ir.DataType.FLOAT),
356+
("UINT8", ir.DataType.UINT8),
357+
("INT8", ir.DataType.INT8),
358+
("UINT16", ir.DataType.UINT16),
359+
("INT16", ir.DataType.INT16),
360+
("INT32", ir.DataType.INT32),
361+
("INT64", ir.DataType.INT64),
362+
("BOOL", ir.DataType.BOOL),
363+
("FLOAT16", ir.DataType.FLOAT16),
364+
("DOUBLE", ir.DataType.DOUBLE),
365+
("UINT32", ir.DataType.UINT32),
366+
("UINT64", ir.DataType.UINT64),
367+
("COMPLEX64", ir.DataType.COMPLEX64),
368+
("COMPLEX128", ir.DataType.COMPLEX128),
369+
("BFLOAT16", ir.DataType.BFLOAT16),
370+
("FLOAT8E4M3FN", ir.DataType.FLOAT8E4M3FN),
371+
("FLOAT8E4M3FNUZ", ir.DataType.FLOAT8E4M3FNUZ),
372+
("FLOAT8E5M2", ir.DataType.FLOAT8E5M2),
373+
("FLOAT8E5M2FNUZ", ir.DataType.FLOAT8E5M2FNUZ),
374+
("UINT4", ir.DataType.UINT4),
375+
("INT4", ir.DataType.INT4),
376+
("FLOAT4E2M1", ir.DataType.FLOAT4E2M1),
377+
],
378+
[
379+
np.array(
380+
[
381+
[-1000, -6, -1, -0.0, +0.0],
382+
[0.1, 0.25, 1, float("inf"), -float("inf")],
383+
[float("NaN"), -float("NaN"), 1000, 6.0, 0.001],
384+
],
385+
),
386+
np.array(42),
387+
np.array([]),
388+
np.array([[[], [], []]]),
389+
],
390+
)
391+
]
392+
)
393+
def test_round_trip_numpy_conversion_from_raw_data(
394+
self, _: str, onnx_dtype: ir.DataType, original_array: np.ndarray
395+
):
396+
original_array = original_array.astype(onnx_dtype.numpy())
397+
ir_tensor = ir.Tensor(original_array, name="test_tensor")
398+
proto = serde.to_proto(ir_tensor)
399+
if original_array.size > 0:
400+
self.assertGreater(len(proto.raw_data), 0)
401+
# tensor_proto_tensor from raw_data
402+
tensor_proto_tensor = serde.from_proto(proto)
403+
roundtrip_array = tensor_proto_tensor.numpy()
404+
if onnx_dtype in {
405+
ir.DataType.FLOAT8E5M2FNUZ,
406+
ir.DataType.FLOAT8E5M2,
407+
ir.DataType.FLOAT8E4M3FN,
408+
ir.DataType.BFLOAT16,
409+
}:
410+
# There is a bug in ml_dtypes that causes equality checks to fail for these dtypes
411+
# See https://github.com/jax-ml/ml_dtypes/issues/301
412+
self.assertEqual(roundtrip_array.shape, original_array.shape)
413+
self.assertEqual(roundtrip_array.dtype, original_array.dtype)
414+
self.assertEqual(roundtrip_array.tobytes(), original_array.tobytes())
415+
else:
416+
np.testing.assert_equal(roundtrip_array, original_array, strict=True)
417+
349418

350419
class DeserializeGraphTest(unittest.TestCase):
351420
def test_deserialize_graph_handles_unsorted_graph(self):

0 commit comments

Comments
 (0)