diff --git a/pyroaring/abstract_bitmap.pxi b/pyroaring/abstract_bitmap.pxi index 9608703..39b9e4e 100644 --- a/pyroaring/abstract_bitmap.pxi +++ b/pyroaring/abstract_bitmap.pxi @@ -12,11 +12,16 @@ try: except NameError: # python 3 pass -cdef croaring.roaring_bitmap_t *deserialize_ptr(bytes buff): + +cdef croaring.roaring_bitmap_t *deserialize_ptr(const unsigned char[:] buff): cdef croaring.roaring_bitmap_t *ptr cdef const char *reason_failure = NULL + + cdef char* buffer_ptr = &buff[0] + buff_size = len(buff) - ptr = croaring.roaring_bitmap_portable_deserialize_safe(buff, buff_size) + ptr = croaring.roaring_bitmap_portable_deserialize_safe(buffer_ptr, buff_size) + if ptr == NULL: raise ValueError("Could not deserialize bitmap") # Validate the bitmap @@ -26,11 +31,14 @@ cdef croaring.roaring_bitmap_t *deserialize_ptr(bytes buff): raise ValueError(f"Invalid bitmap after deserialization: {reason_failure.decode('utf-8')}") return ptr -cdef croaring.roaring64_bitmap_t *deserialize64_ptr(bytes buff): +cdef croaring.roaring64_bitmap_t *deserialize64_ptr(const unsigned char[:] buff): cdef croaring.roaring64_bitmap_t *ptr cdef const char *reason_failure = NULL + + cdef char* buffer_ptr = &buff[0] + buff_size = len(buff) - ptr = croaring.roaring64_bitmap_portable_deserialize_safe(buff, buff_size) + ptr = croaring.roaring64_bitmap_portable_deserialize_safe(buffer_ptr, buff_size) if ptr == NULL: raise ValueError("Could not deserialize bitmap") # Validate the bitmap @@ -760,7 +768,7 @@ cdef class AbstractBitMap: @classmethod - def deserialize(cls, bytes buff): + def deserialize(cls, const unsigned char[:] buff): """ Generate a bitmap from the given serialization. See AbstractBitMap.serialize for the reverse operation. @@ -1221,7 +1229,7 @@ cdef class AbstractBitMap64: @classmethod - def deserialize(cls, bytes buff): + def deserialize(cls, const unsigned char[:] buff): """ Generate a bitmap from the given serialization. See AbstractBitMap64.serialize for the reverse operation. diff --git a/test.py b/test.py index e57369d..be51da7 100755 --- a/test.py +++ b/test.py @@ -143,21 +143,37 @@ def bitmap_sample(bitmap: AbstractBitMap, size: int) -> list[int]: return [bitmap[i] for i in indices] def assert_is_not(self, bitmap1: AbstractBitMap, bitmap2: AbstractBitMap) -> None: + add1 = remove1 = add2 = remove2 = -1 if isinstance(bitmap1, BitMap): if bitmap1: - bitmap1.remove(bitmap1[0]) + remove1 = bitmap1[0] + bitmap1.remove(remove1) else: - bitmap1.add(27) + add1 = 27 + bitmap1.add(add1) elif isinstance(bitmap2, BitMap): if bitmap2: - bitmap2.remove(bitmap1[0]) + remove2 = bitmap2[0] + bitmap2.remove(remove2) else: - bitmap2.add(27) + add2 = 27 + bitmap2.add(add2) else: # The two are non-mutable, cannot do anything... return if bitmap1 == bitmap2: pytest.fail( 'The two bitmaps are identical (modifying one also modifies the other).') + # Restore the bitmaps to their original point + else: + if add1 >= 0: + bitmap1.remove(add1) + if remove1 >= 0: + bitmap1.add(remove1) + if add2 >= 0: + bitmap2.remove(add2) + if remove2 >= 0: + bitmap2.add(remove2) + class TestBasic(Util): @@ -874,6 +890,34 @@ def test_serialization( assert isinstance(new_bm, cls2) self.assert_is_not(old_bm, new_bm) + @given(bitmap_cls, bitmap_cls, hyp_many_collections) + def test_deserialization_from_memoryview( + self, + cls1: type[EitherBitMap], + cls2: type[EitherBitMap], + values: list[HypCollection] + ) -> None: + old_bms = [cls1(vals) for vals in values] + + # Create a memoryview with all of the items concatenated into a single bytes + # object. + serialized = [bm.serialize() for bm in old_bms] + sizes = [len(ser) for ser in serialized] + starts = [0] + for s in sizes: + starts.append(s + starts[-1]) + + combined = b''.join(serialized) + mutable_combined = bytearray(combined) + + for source in (combined, mutable_combined): + with memoryview(source) as mv: + new_bms = [cls2.deserialize(mv[start: start + size])for start, size in zip(starts, sizes)] + for old_bm, new_bm in zip(old_bms, new_bms): + assert old_bm == new_bm + assert isinstance(new_bm, cls2) + self.assert_is_not(old_bm, new_bm) + @given(bitmap_cls, hyp_collection, st.integers(min_value=2, max_value=pickle.HIGHEST_PROTOCOL)) def test_pickle_protocol( self,