diff --git a/include/trx/trx.tpp b/include/trx/trx.tpp index e2bd7a9..3a9dd8c 100644 --- a/include/trx/trx.tpp +++ b/include/trx/trx.tpp @@ -359,7 +359,7 @@ TrxFile
::_create_trx_from_pointer(json header, std::tuple shape = std::make_tuple(static_cast(trx->header["NB_VERTICES"].int_value()), 3); trx->streamlines->mmap_pos = - trx::_create_memmap(filename, shape, "r+", ext.substr(1, ext.size() - 1), mem_adress); + trx::_create_memmap(filename, shape, "r+", ext, mem_adress); trx::detail::remap(trx->streamlines->_data, trx->streamlines->mmap_pos.data(), shape); } diff --git a/src/detail/dtype_helpers.cpp b/src/detail/dtype_helpers.cpp index 8bf056c..2d07b09 100644 --- a/src/detail/dtype_helpers.cpp +++ b/src/detail/dtype_helpers.cpp @@ -27,7 +27,9 @@ int _sizeof_dtype(const std::string &dtype) { return sizeof(float); if (dtype == "float64") return sizeof(double); - return sizeof(std::uint16_t); // default to 16-bit float size + if (dtype == "float16") + return sizeof(std::uint16_t); + throw TrxDTypeError("Unrecognized dtype: " + dtype); } std::string _get_dtype(const std::string &dtype) { diff --git a/tests/test_trx_mmap.cpp b/tests/test_trx_mmap.cpp index 09e8224..b689d57 100644 --- a/tests/test_trx_mmap.cpp +++ b/tests/test_trx_mmap.cpp @@ -526,7 +526,8 @@ TEST(TrxFileMemmap, __sizeof_dtype_values) { EXPECT_EQ(trx::detail::_sizeof_dtype("int64"), sizeof(int64_t)); EXPECT_EQ(trx::detail::_sizeof_dtype("float32"), sizeof(float)); EXPECT_EQ(trx::detail::_sizeof_dtype("float64"), sizeof(double)); - EXPECT_EQ(trx::detail::_sizeof_dtype("unknown"), sizeof(uint16_t)); + EXPECT_EQ(trx::detail::_sizeof_dtype("float16"), sizeof(uint16_t)); + EXPECT_THROW(trx::detail::_sizeof_dtype("unknown"), trx::TrxDTypeError); } // asserts dtype code mapping.